From 8da15ecd83807e2526e112da732c823c701583f6 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Fri, 10 May 2024 16:50:48 +0800 Subject: [PATCH 001/706] add set_module for x86inductorquantizer's static quant Signed-off-by: yiliu30 --- .../pt2e/test_x86inductor_quantizer.py | 99 +++++++++ third_party/kineto | 2 +- torch/ao/quantization/quantizer/utils.py | 64 +++++- .../quantizer/x86_inductor_quantizer.py | 193 +++++++++++++++--- .../quantizer/xnnpack_quantizer.py | 35 +--- 5 files changed, 325 insertions(+), 68 deletions(-) diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py index 218b30bd9e33..7be04dcbd7d9 100644 --- a/test/quantization/pt2e/test_x86inductor_quantizer.py +++ b/test/quantization/pt2e/test_x86inductor_quantizer.py @@ -1865,6 +1865,105 @@ def test_qat_dynamic_quant_linear(self): is_qat=True, ) + @skipIfNoX86 + def test_set_module_name(self): + """Test that quantize the specific submodule.""" + + class Sub(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(5, 5) + + def forward(self, x): + return self.linear(x) + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(5, 5) + self.sub = Sub() + + def forward(self, x): + x = self.linear(x) + x = self.sub(x) + return x + + m = M().eval() + example_inputs = (torch.randn(3, 5),) + # Set global to no quantization and then default config for a specific submodule. + quantizer = X86InductorQuantizer() + quantizer.set_module_name( + "sub", xiq.get_default_x86_inductor_quantization_config() + ) + node_occurrence = { + torch.ops.aten.linear.default: 2, + # input and output for the second linear + torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + node_list = [ + # first linear is not quantized + torch.ops.aten.linear.default, + # second linear is quantized + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_channel.default, + torch.ops.aten.linear.default, + ] + self._test_quantizer(m, example_inputs, quantizer, node_occurrence, node_list) + + @skipIfNoX86 + def test_set_module_name_with_underscores(self) -> None: + """Test that if a module name has an underscore, we can still quantize it.""" + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + # This module name has underscores, which can be part of a mangled + # name. + self.foo_bar = torch.nn.Linear(2, 2) + self.baz = torch.nn.Linear(2, 2) + + def forward(self, x): + return self.baz(self.foo_bar(x)) + + # Set global to no quantization and then default config for a specific submodule. + quantizer = X86InductorQuantizer() + quantizer.set_module_name( + "foo_bar", xiq.get_default_x86_inductor_quantization_config() + ) + example_inputs = (torch.randn(2, 2),) + m = M().eval() + m = capture_pre_autograd_graph(m, example_inputs) + m = prepare_pt2e(m, quantizer) + # Use a linear count instead of names because the names might change, but + # the order should be the same. + count = 0 + for n in m.graph.nodes: + if n.op == "call_function" and n.target == torch.ops.aten.linear.default: + # Get the weight observer to see the per-channel vs per-tensor. + weight_observer_node = n.args[1] + if count == 0: + # for foo_bar. + self.assertEqual( + weight_observer_node.op, + "call_module", + f"The op of linear({count})'s weight_observer_node is {weight_observer_node.op} instead call_module", + ) + observer_instance = getattr(m, weight_observer_node.target) + self.assertEqual( + observer_instance.qscheme, torch.per_channel_symmetric + ) + else: + # For baz it should have no observer at all. + self.assertNotEqual( + weight_observer_node.op, + "call_module", + f"The op of linear({count})'s weight_observer_node is {weight_observer_node.op} instead call_module", + ) + count += 1 + @skipIfNoX86 def test_filter_conv2d_recipe(self): """ diff --git a/third_party/kineto b/third_party/kineto index 3a81076cc970..327ac5052cf2 160000 --- a/third_party/kineto +++ b/third_party/kineto @@ -1 +1 @@ -Subproject commit 3a81076cc97092666f319846f32f36b73ce2293e +Subproject commit 327ac5052cf25238fc769ab421c680d19b848eb3 diff --git a/torch/ao/quantization/quantizer/utils.py b/torch/ao/quantization/quantizer/utils.py index f25d0916018b..14a373274428 100644 --- a/torch/ao/quantization/quantizer/utils.py +++ b/torch/ao/quantization/quantizer/utils.py @@ -1,4 +1,4 @@ -from typing import List +from typing import Callable, List, Optional from torch.ao.quantization.pt2e.utils import _is_sym_size_node @@ -47,3 +47,65 @@ def _node_only_used_for_sym_size(node: Node, partition_nodes: List[Node]): ((user not in partition_nodes) or _is_sym_size_node(user)) for user in node.users ) + + +def _get_module_name_filter(module_name: str): + """Get the module_name_filter function for a given module name, the filter accepts + a node and checks if the node comes from a module that has certain module name + + For example: + node: linear_op = call_function[...](...) # comes from a module with name blocks.sub.linear1 + + + >> module_name_filter = _get_module_name_filter("blocks.sub") + >> print(module_name_filter(node)) + True # the node is from "blocks.sub" based on the fully qualified name "blocks.sub.linear1" + """ + + def module_name_filter(n: Node) -> bool: + # example: { + # 'L__self___sub': ("L['self'].sub", ), + # 'L__self___sub_linear': ("L['self'].sub.linear", ) + # } + # get_attr nodes doesn't have nn_module_stack? + nn_module_stack = n.meta.get("nn_module_stack", {}) + + def _normalize_path(n): + prefix = 0 + # TODO This is non standard behavior and should be removed when we migrate off capture_pre_autograd_graph. + if n.startswith("L['self']."): + prefix = len("L['self'].") + return n[prefix:] + + names = [_normalize_path(n) for n, _ in nn_module_stack.values()] + return module_name in names + + return module_name_filter + + +def _is_annotated(nodes: List[Node]): + """ + Given a list of nodes (that represents an operator pattern), + check if any of the node is annotated, return True if any of the node + is annotated, otherwise return False + """ + annotated = False + for node in nodes: + annotated = annotated or ( + "quantization_annotation" in node.meta + and node.meta["quantization_annotation"]._annotated + ) + return annotated + + +def _skip_annotate( + nodes: List[Node], filter_fn: Optional[Callable[[Node], bool]] = None +): + skip_annotate = False + # 1) Skip annotate if any node is already annotated + if _is_annotated(nodes): + skip_annotate = True + # 2) TODO: Skip annotate if a) filter_fn is provided and b) any node does not pass the filter + if filter_fn and any(not filter_fn(node) for node in nodes): + skip_annotate = True + return skip_annotate diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index 4cc05e46c6a7..e41e05e5ab01 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -14,6 +14,7 @@ Set, Tuple, TYPE_CHECKING, + Union, ) import torch @@ -36,6 +37,10 @@ Quantizer, SharedQuantizationSpec, ) +from torch.ao.quantization.quantizer.utils import ( + _get_module_name_filter, + _skip_annotate, +) from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( _is_annotated, get_bias_qspec, @@ -144,6 +149,46 @@ def _mark_nodes_as_annotated(nodes: List[Node]): node.meta[QUANT_ANNOTATION_KEY]._annotated = True +AnnotatorType = Callable[ + [ + "X86InductorQuantizer", + torch.fx.GraphModule, + Optional[QuantizationConfig], + Optional[Callable[[Node], bool]], + ], + Optional[List[List[Node]]], +] + +X86_ANNOTATOR_COLLECTIONS: Dict[str, Dict[str, AnnotatorType]] = { + "STATIC": {}, + "DYNAMIC": {}, + "QAT": {}, +} + +STATIC_ANNOTATORS = X86_ANNOTATOR_COLLECTIONS["STATIC"] +DYNAMIC_ANNOTATORS = X86_ANNOTATOR_COLLECTIONS["DYNAMIC"] +QAT_ANNOTATORS = X86_ANNOTATOR_COLLECTIONS["QAT"] + + +AnnotatorCollectionType = Dict[str, AnnotatorType] + + +def register_annotator( + annotators_list: Union[AnnotatorCollectionType, List[AnnotatorCollectionType]], + annotator_name: Optional[str] = None, +): + def decorator(annotator: AnnotatorType): + nonlocal annotators_list, annotator_name + if not isinstance(annotators_list, list): + annotators_list = [annotators_list] + annotator_name = annotator_name or annotator.__name__ + for annotators in annotators_list: + annotators[annotator_name] = annotator + return annotator + + return decorator + + def _is_node_annotated(_node): """ return True if the node is annotated, otherwise return False @@ -303,6 +348,7 @@ def __init__(self): self.operator_type_qconfig: Dict[ torch._ops.OpOverloadPacket, Optional[QuantizationConfig] ] = {} + self.module_name_config: Dict[str, Optional[QuantizationConfig]] = {} @classmethod def get_supported_quantization_configs(cls) -> List[QuantizationConfig]: @@ -372,6 +418,19 @@ def set_module_type_qconfig( ) return self + def set_module_name( + self, module_name: str, quantization_config: Optional[QuantizationConfig] + ): + """Set quantization_config for a submodule with name: `module_name`, for example: + quantizer.set_module_name("blocks.sub"), it will quantize all supported operator/operator + patterns in the submodule with this module name with the given `quantization_config` + """ + assert ( + quantization_config is not None + ), " quantization_config == None is not supported yet" + self.module_name_config[module_name] = quantization_config + return self + def _set_aten_operator_qconfig( self, operator_type: torch._ops.OpOverloadPacket, @@ -510,6 +569,37 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: model = self._annotate_for_static_quantization_config(model) return model + def _annotate_by_module_name( + self, + model: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, + ) -> torch.fx.GraphModule: + # TODO: implement the support for None to be canceling out previous annotations + if quantization_config is None: + return model + + if quantization_config.is_qat: + for annotator_func in QAT_ANNOTATORS.values(): + annotator_func(self, model, quantization_config, filter_fn) + for annotator_func in STATIC_ANNOTATORS.values(): + annotator_func(self, model, quantization_config, filter_fn) + return model + + def _annotate_static_quantization_config_by_module_name( + self, model: torch.fx.GraphModule + ) -> torch.fx.GraphModule: + for module_name, config in self.module_name_config.items(): + self._annotate_by_module_name( + model, config, _get_module_name_filter(module_name) + ) + return model + + def _annotate_static_quantization_config_by_op_type_and_global_config(self, model): + self._annotate_conv2d_fusion_pattern(model) + self._annotate_linear_fusion_pattern(model) + self._annotate_matmul_pattern(model) + def _annotate_for_static_quantization_config( self, model: torch.fx.GraphModule ) -> torch.fx.GraphModule: @@ -525,9 +615,12 @@ def _annotate_for_static_quantization_config( """ # Step1: Recipe of fusion patterns like conv/linear. - self._annotate_conv2d_fusion_pattern(model) - self._annotate_linear_fusion_pattern(model) - self._annotate_matmul(model) + if self.module_name_config: + self._annotate_static_quantization_config_by_module_name(model) + if self.operator_type_qconfig or self.global_config: + self._annotate_static_quantization_config_by_op_type_and_global_config( + model + ) # Step2: Recipe to propagate annotation for patterns beside conv/linear. # Go through all the nodes from start to end. @@ -561,7 +654,9 @@ def _annotate_qat_conv2d_fusion_pattern( self._annotate_qat_conv2d_bn(model, config) def _annotate_qat_conv2d_bn_binary_unary( - self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig + self, + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig, ) -> None: fused_partitions = find_sequential_partitions( gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d, operator.add, torch.nn.ReLU] @@ -782,26 +877,38 @@ def _annotate_linear_fusion_pattern(self, model: torch.fx.GraphModule): self._annotate_linear_unary(model, config) self._annotate_linear(model, config) - def _annotate_matmul(self, model: torch.fx.GraphModule): + def _annotate_matmul_pattern(self, model: torch.fx.GraphModule): if config := self._get_aten_operator_qconfig(torch.ops.aten.matmul.default): - for node in model.graph.nodes: - if node.target == torch.ops.aten.matmul.default and not _is_annotated( - [node] - ): - input_qspec_map = {} - matmul_node = node - for input_node in matmul_node.args: - input_qspec_map[input_node] = get_input_act_qspec(config) - matmul_node.meta[ - QUANT_ANNOTATION_KEY - ] = _X86InductorQuantizationAnnotation( - input_qspec_map=input_qspec_map, - _annotated=True, - _is_output_of_quantized_pattern=True, - ) + self._annotate_matmul(model, config) + + @register_annotator(STATIC_ANNOTATORS) + def _annotate_matmul( + self, + model: torch.fx.GraphModule, + quantization_config: QuantizationConfig, + filter_fn: Optional[Callable[[Node], bool]] = None, + ): + for node in model.graph.nodes: + if node.target != torch.ops.aten.matmul.default: + continue + if _skip_annotate([node], filter_fn): + continue + input_qspec_map = {} + matmul_node = node + for input_node in matmul_node.args: + input_qspec_map[input_node] = get_input_act_qspec(quantization_config) + matmul_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + @register_annotator(STATIC_ANNOTATORS) def _annotate_conv2d_binary_unary( - self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig + self, + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig, + filter_fn: Optional[Callable[[Node], bool]] = None, ) -> None: # Conv2d + add + unary op fused_partitions = find_sequential_partitions( @@ -829,7 +936,7 @@ def _annotate_conv2d_binary_unary( ): # No conv node found to be fused with add continue - if _is_annotated([unary_node, binary_node, conv_node]): + if _skip_annotate([unary_node, binary_node, conv_node], filter_fn): continue self._annotate_conv_node_helper(conv_node, False, quantization_config) binary_node_input_qspec_map = {} @@ -845,8 +952,12 @@ def _annotate_conv2d_binary_unary( _is_output_of_quantized_pattern=True, ) + @register_annotator(STATIC_ANNOTATORS) def _annotate_conv2d_binary( - self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig + self, + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig, + filter_fn: Optional[Callable[[Node], bool]] = None, ) -> None: # Conv2d + add fused_partitions = find_sequential_partitions( @@ -875,7 +986,7 @@ def _annotate_conv2d_binary( ): # No conv node found to be fused with add continue - if _is_annotated([binary_node, conv_node]): + if _skip_annotate([binary_node, conv_node], filter_fn): continue self._annotate_conv_node_helper(conv_node, False, quantization_config) binary_node_input_qspec_map = {} @@ -888,8 +999,12 @@ def _annotate_conv2d_binary( _is_output_of_quantized_pattern=True, ) + @register_annotator(STATIC_ANNOTATORS) def _annotate_conv2d_unary( - self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig + self, + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig, + filter_fn: Optional[Callable[[Node], bool]] = None, ) -> None: fused_partitions = [] unary_patterns = [ @@ -915,7 +1030,7 @@ def _annotate_conv2d_unary( or conv_node.target != torch.ops.aten.conv2d.default ): continue - if _is_annotated([unary_node, conv_node]): + if _skip_annotate([unary_node, conv_node], filter_fn): continue self._annotate_conv_node_helper(conv_node, False, quantization_config) unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( @@ -923,8 +1038,12 @@ def _annotate_conv2d_unary( _is_output_of_quantized_pattern=True, ) + @register_annotator(STATIC_ANNOTATORS) def _annotate_conv2d( - self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig + self, + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig, + filter_fn: Optional[Callable[[Node], bool]] = None, ) -> None: conv_partitions = get_source_partitions( gm.graph, [torch.nn.Conv2d, torch.nn.functional.conv2d] @@ -940,7 +1059,7 @@ def _annotate_conv2d( ): raise ValueError(f"{conv_node} is not an aten conv2d operator") # skip annotation if it is already annotated - if _is_annotated([conv_node]): + if _skip_annotate([conv_node], filter_fn): continue self._annotate_conv_node_helper(conv_node, True, quantization_config) @@ -1099,8 +1218,12 @@ def _annotate_output_for_int8_in_int8_out_pattern(self, node: Node) -> None: self._annotate_output_share_observer_as_input(input_node, node) return + @register_annotator(STATIC_ANNOTATORS) def _annotate_linear( - self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig + self, + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig, + filter_fn: Optional[Callable[[Node], bool]] = None, ) -> None: linear_partitions = get_source_partitions( gm.graph, [torch.nn.Linear, torch.nn.functional.linear] @@ -1119,12 +1242,16 @@ def _annotate_linear( ): raise ValueError(f"{linear_node} is not an aten linear operator") # skip annotation if it is already annotated - if _is_annotated([linear_node]): + if _skip_annotate([linear_node], filter_fn): continue self._annotate_linear_node_helper(linear_node, True, quantization_config) + @register_annotator(STATIC_ANNOTATORS) def _annotate_linear_unary( - self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig + self, + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig, + filter_fn: Optional[Callable[[Node], bool]] = None, ) -> None: postop_list = [ torch.nn.ReLU, @@ -1146,7 +1273,7 @@ def _annotate_linear_unary( torch.ops.aten.linear.default, ): continue - if _is_annotated([unary_node, linear_node]): + if _skip_annotate([unary_node, linear_node], filter_fn): continue self._annotate_linear_node_helper(linear_node, False, quantization_config) unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( @@ -1154,10 +1281,12 @@ def _annotate_linear_unary( _is_output_of_quantized_pattern=True, ) + @register_annotator(STATIC_ANNOTATORS) def _annotate_linear_binary_unary( self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig, + filter_fn: Optional[Callable[[Node], bool]] = None, ) -> None: # linear + binary_op + (optional) unary op binary_op_list = [operator.add] diff --git a/torch/ao/quantization/quantizer/xnnpack_quantizer.py b/torch/ao/quantization/quantizer/xnnpack_quantizer.py index f3d1b6ca8b39..e13a79f39267 100644 --- a/torch/ao/quantization/quantizer/xnnpack_quantizer.py +++ b/torch/ao/quantization/quantizer/xnnpack_quantizer.py @@ -22,6 +22,7 @@ ) from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer +from torch.ao.quantization.quantizer.utils import _get_module_name_filter from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( _convert_scalars_to_attrs, @@ -192,40 +193,6 @@ def _get_supported_config_and_operators() -> List[OperatorConfig]: return _get_supported_symmetric_config_and_operators() -def _get_module_name_filter(module_name: str): - """Get the module_name_filter function for a given module name, the filter accepts - a node and checks if the node comes from a module that has certain module name - - For example: - node: linear_op = call_function[...](...) # comes from a module with name blocks.sub.linear1 - - - >> module_name_filter = _get_module_name_filter("blocks.sub") - >> print(module_name_filter(node)) - True # the node is from "blocks.sub" based on the fully qualified name "blocks.sub.linear1" - """ - - def module_name_filter(n: Node) -> bool: - # example: { - # 'L__self___sub': ("L['self'].sub", ), - # 'L__self___sub_linear': ("L['self'].sub.linear", ) - # } - # get_attr nodes doesn't have nn_module_stack? - nn_module_stack = n.meta.get("nn_module_stack", {}) - - def _normalize_path(n): - prefix = 0 - # TODO This is non standard behavior and should be removed when we migrate off capture_pre_autograd_graph. - if n.startswith("L['self']."): - prefix = len("L['self'].") - return n[prefix:] - - names = [_normalize_path(n) for n, _ in nn_module_stack.values()] - return module_name in names - - return module_name_filter - - def _get_module_type_filter(tp: Callable): """Get the module_type_filter function for a given module type, the filter accepts a node and checks if the node comes from a module that has certain module type From 4729d7b66e700f268988088285d860562a69fd46 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Sat, 11 May 2024 11:27:44 +0800 Subject: [PATCH 002/706] support for dynamic Signed-off-by: yiliu30 --- .../pt2e/test_x86inductor_quantizer.py | 41 +++++++++ .../quantizer/x86_inductor_quantizer.py | 85 +++++++++++++------ 2 files changed, 102 insertions(+), 24 deletions(-) diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py index 7be04dcbd7d9..1975e384b3ca 100644 --- a/test/quantization/pt2e/test_x86inductor_quantizer.py +++ b/test/quantization/pt2e/test_x86inductor_quantizer.py @@ -1964,6 +1964,47 @@ def forward(self, x): ) count += 1 + @skipIfNoX86 + def test_set_module_for_dynamic_quant(self): + """Test that quantize the specific submodule for dynamic quantization.""" + + with override_quantized_engine("x86"), torch.no_grad(): + for is_qat in [False, True]: + m = TestHelperModules.SelfAttnLikeModule(input_dim=64).eval() + example_inputs = (torch.randn(1, 4, 64),) + # only quantize `self.q_proj` `self.v_proj` + dynamic_config = xiq.get_default_x86_inductor_quantization_config( + is_dynamic=True, is_qat=is_qat + ) + quantizer = ( + X86InductorQuantizer() + .set_module_name("q_proj", dynamic_config) + .set_module_name("v_proj", dynamic_config) + ) + node_occurrence = { + # for quantize input + torch.ops.quantized_decomposed.choose_qparams.tensor: 1, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1, + # each for q_proj and v_proj + # torch.ops.quantized_decomposed.quantize_per_channel.default: 2, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, + } + node_list = [ + torch.ops.quantized_decomposed.choose_qparams.tensor, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.aten.linear.default, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + is_qat=True, + ) + @skipIfNoX86 def test_filter_conv2d_recipe(self): """ diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index e41e05e5ab01..f6eaed879eeb 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -3,6 +3,7 @@ import itertools import operator import warnings +from collections import OrderedDict from dataclasses import dataclass from typing import ( Any, @@ -160,14 +161,15 @@ def _mark_nodes_as_annotated(nodes: List[Node]): ] X86_ANNOTATOR_COLLECTIONS: Dict[str, Dict[str, AnnotatorType]] = { - "STATIC": {}, - "DYNAMIC": {}, - "QAT": {}, + "STATIC": OrderedDict(), + "DYNAMIC": OrderedDict(), + "STATIC_QAT_ONLY": OrderedDict(), } STATIC_ANNOTATORS = X86_ANNOTATOR_COLLECTIONS["STATIC"] DYNAMIC_ANNOTATORS = X86_ANNOTATOR_COLLECTIONS["DYNAMIC"] -QAT_ANNOTATORS = X86_ANNOTATOR_COLLECTIONS["QAT"] +STATIC_QAT_ONLY_ANNOTATORS = X86_ANNOTATOR_COLLECTIONS["STATIC_QAT_ONLY"] +# For static QAT, apply the `STATIC_QAT_ONLY_ANNOTATORS` and `STATIC_ANNOTATORS` in order. AnnotatorCollectionType = Dict[str, AnnotatorType] @@ -569,7 +571,7 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: model = self._annotate_for_static_quantization_config(model) return model - def _annotate_by_module_name( + def _annotate_by_single_config( self, model: torch.fx.GraphModule, quantization_config: Optional[QuantizationConfig], @@ -579,23 +581,35 @@ def _annotate_by_module_name( if quantization_config is None: return model - if quantization_config.is_qat: - for annotator_func in QAT_ANNOTATORS.values(): - annotator_func(self, model, quantization_config, filter_fn) - for annotator_func in STATIC_ANNOTATORS.values(): + if ( + quantization_config.input_activation + and quantization_config.input_activation.is_dynamic + ): + annotators = DYNAMIC_ANNOTATORS + else: + annotators = STATIC_ANNOTATORS.copy() + if quantization_config.is_qat: + # Apply QAT-specific annotators first + qat_annotators = STATIC_QAT_ONLY_ANNOTATORS.copy() + qat_annotators.update(STATIC_ANNOTATORS) + annotators = qat_annotators + + for annotator_func in annotators.values(): annotator_func(self, model, quantization_config, filter_fn) return model - def _annotate_static_quantization_config_by_module_name( + def _annotate_quantization_by_module_name_config( self, model: torch.fx.GraphModule ) -> torch.fx.GraphModule: for module_name, config in self.module_name_config.items(): - self._annotate_by_module_name( + self._annotate_by_single_config( model, config, _get_module_name_filter(module_name) ) return model - def _annotate_static_quantization_config_by_op_type_and_global_config(self, model): + def _annotate_static_quantization_by_op_type_and_global_config( + self, model: torch.fx.GraphModule + ): self._annotate_conv2d_fusion_pattern(model) self._annotate_linear_fusion_pattern(model) self._annotate_matmul_pattern(model) @@ -616,11 +630,9 @@ def _annotate_for_static_quantization_config( # Step1: Recipe of fusion patterns like conv/linear. if self.module_name_config: - self._annotate_static_quantization_config_by_module_name(model) + self._annotate_quantization_by_module_name_config(model) if self.operator_type_qconfig or self.global_config: - self._annotate_static_quantization_config_by_op_type_and_global_config( - model - ) + self._annotate_static_quantization_by_op_type_and_global_config(model) # Step2: Recipe to propagate annotation for patterns beside conv/linear. # Go through all the nodes from start to end. @@ -640,6 +652,15 @@ def _annotate_for_static_quantization_config( def _annotate_for_dynamic_quantization_config( self, model: torch.fx.GraphModule + ) -> torch.fx.GraphModule: + if self.module_name_config: + self._annotate_quantization_by_module_name_config(model) + if self.operator_type_qconfig or self.global_config: + self._annotate_dynamic_quantization_by_op_type_and_global_config(model) + return model + + def _annotate_dynamic_quantization_by_op_type_and_global_config( + self, model: torch.fx.GraphModule ) -> torch.fx.GraphModule: self._annotate_linear_fusion_pattern(model) return model @@ -653,10 +674,12 @@ def _annotate_qat_conv2d_fusion_pattern( self._annotate_qat_conv2d_bn_unary(model, config) self._annotate_qat_conv2d_bn(model, config) + @register_annotator(STATIC_QAT_ONLY_ANNOTATORS) def _annotate_qat_conv2d_bn_binary_unary( self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig, + filter_fn: Optional[Callable[[Node], bool]] = None, ) -> None: fused_partitions = find_sequential_partitions( gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d, operator.add, torch.nn.ReLU] @@ -696,7 +719,9 @@ def _annotate_qat_conv2d_bn_binary_unary( ): continue - if _is_annotated([unary_node, binary_node, bn_output_node, conv_node]): + if _skip_annotate( + [unary_node, binary_node, bn_output_node, conv_node], filter_fn + ): continue self._annotate_conv_node_helper(conv_node, False, quantization_config) @@ -721,8 +746,12 @@ def _annotate_qat_conv2d_bn_binary_unary( nodes_to_mark_annotated.extend(list(unary_partition.nodes)) _mark_nodes_as_annotated(nodes_to_mark_annotated) + @register_annotator(STATIC_QAT_ONLY_ANNOTATORS) def _annotate_qat_conv2d_bn_binary( - self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig + self, + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig, + filter_fn: Optional[Callable[[Node], bool]] = None, ) -> None: fused_partitions = find_sequential_partitions( gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d, operator.add] @@ -756,7 +785,7 @@ def _annotate_qat_conv2d_bn_binary( ): continue - if _is_annotated([binary_node, bn_output_node, conv_node]): + if _skip_annotate([binary_node, bn_output_node, conv_node], filter_fn): continue self._annotate_conv_node_helper(conv_node, False, quantization_config) @@ -777,8 +806,12 @@ def _annotate_qat_conv2d_bn_binary( nodes_to_mark_annotated.extend(list(binary_partition.nodes)) _mark_nodes_as_annotated(nodes_to_mark_annotated) + @register_annotator(STATIC_QAT_ONLY_ANNOTATORS) def _annotate_qat_conv2d_bn_unary( - self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig + self, + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig, + filter_fn: Optional[Callable[[Node], bool]] = None, ) -> None: fused_partitions = [] unary_patterns = [ @@ -810,7 +843,7 @@ def _annotate_qat_conv2d_bn_unary( ): continue - if _is_annotated([unary_node, bn_output_node, conv_node]): + if _skip_annotate([unary_node, bn_output_node, conv_node], filter_fn): continue self._annotate_conv_node_helper(conv_node, False, quantization_config) @@ -825,8 +858,12 @@ def _annotate_qat_conv2d_bn_unary( nodes_to_mark_annotated.extend(list(unary_partition.nodes)) _mark_nodes_as_annotated(nodes_to_mark_annotated) + @register_annotator(STATIC_QAT_ONLY_ANNOTATORS) def _annotate_qat_conv2d_bn( - self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig + self, + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig, + filter_fn: Optional[Callable[[Node], bool]] = None, ) -> None: fused_partitions = find_sequential_partitions( gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d] @@ -843,7 +880,7 @@ def _annotate_qat_conv2d_bn( ): continue - if _is_annotated([bn_output_node, conv_node]): + if _skip_annotate([bn_output_node, conv_node], filter_fn): continue self._annotate_conv_node_helper(conv_node, False, quantization_config) @@ -1218,7 +1255,7 @@ def _annotate_output_for_int8_in_int8_out_pattern(self, node: Node) -> None: self._annotate_output_share_observer_as_input(input_node, node) return - @register_annotator(STATIC_ANNOTATORS) + @register_annotator([STATIC_ANNOTATORS, DYNAMIC_ANNOTATORS]) def _annotate_linear( self, gm: torch.fx.GraphModule, From 737a9065d878de439fe3824290fc82e20c6e8a61 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Mon, 13 May 2024 08:48:13 +0800 Subject: [PATCH 003/706] enhance UT Signed-off-by: yiliu30 --- .../pt2e/test_x86inductor_quantizer.py | 15 ++++++++++++--- .../quantizer/x86_inductor_quantizer.py | 11 ++++++++--- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py index 1975e384b3ca..2cc0e0172c81 100644 --- a/test/quantization/pt2e/test_x86inductor_quantizer.py +++ b/test/quantization/pt2e/test_x86inductor_quantizer.py @@ -1965,7 +1965,7 @@ def forward(self, x): count += 1 @skipIfNoX86 - def test_set_module_for_dynamic_quant(self): + def test_set_module_name_for_dynamic_quant(self): """Test that quantize the specific submodule for dynamic quantization.""" with override_quantized_engine("x86"), torch.no_grad(): @@ -1982,18 +1982,27 @@ def test_set_module_for_dynamic_quant(self): .set_module_name("v_proj", dynamic_config) ) node_occurrence = { - # for quantize input + # ops for quantizing/de-quantizing input torch.ops.quantized_decomposed.choose_qparams.tensor: 1, torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1, torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1, # each for q_proj and v_proj - # torch.ops.quantized_decomposed.quantize_per_channel.default: 2, torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, } node_list = [ + # ops for quantizing/de-quantizing input torch.ops.quantized_decomposed.choose_qparams.tensor, torch.ops.quantized_decomposed.quantize_per_tensor.tensor, torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + # op for de-quantizing `q_proj`'s weight + torch.ops.quantized_decomposed.dequantize_per_channel.default, + # q_proj + torch.ops.aten.linear.default, + # k_proj + torch.ops.aten.linear.default, + # op for de-quantizing `v_proj`'s weight + torch.ops.quantized_decomposed.dequantize_per_channel.default, + # v_proj torch.ops.aten.linear.default, ] self._test_quantizer( diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index f6eaed879eeb..346b5a70cadc 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -179,8 +179,13 @@ def register_annotator( annotators_list: Union[AnnotatorCollectionType, List[AnnotatorCollectionType]], annotator_name: Optional[str] = None, ): - def decorator(annotator: AnnotatorType): - nonlocal annotators_list, annotator_name + def decorator( + annotator: AnnotatorType, + annotators_list: Union[ + AnnotatorCollectionType, List[AnnotatorCollectionType] + ] = annotators_list, + annotator_name: Optional[str] = annotator_name, + ) -> AnnotatorType: if not isinstance(annotators_list, list): annotators_list = [annotators_list] annotator_name = annotator_name or annotator.__name__ @@ -577,7 +582,7 @@ def _annotate_by_single_config( quantization_config: Optional[QuantizationConfig], filter_fn: Optional[Callable[[Node], bool]] = None, ) -> torch.fx.GraphModule: - # TODO: implement the support for None to be canceling out previous annotations + # implement the support for None to be canceling out previous annotations if quantization_config is None: return model From af337843a38e66265f7606f26e2adebd066ab8ee Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Mon, 13 May 2024 10:06:40 +0800 Subject: [PATCH 004/706] clean code Signed-off-by: yiliu30 --- torch/ao/quantization/quantizer/utils.py | 2 +- .../quantizer/x86_inductor_quantizer.py | 43 ++++++++++--------- .../quantizer/xnnpack_quantizer_utils.py | 16 +------ 3 files changed, 24 insertions(+), 37 deletions(-) diff --git a/torch/ao/quantization/quantizer/utils.py b/torch/ao/quantization/quantizer/utils.py index 14a373274428..d6dd373a7ba7 100644 --- a/torch/ao/quantization/quantizer/utils.py +++ b/torch/ao/quantization/quantizer/utils.py @@ -105,7 +105,7 @@ def _skip_annotate( # 1) Skip annotate if any node is already annotated if _is_annotated(nodes): skip_annotate = True - # 2) TODO: Skip annotate if a) filter_fn is provided and b) any node does not pass the filter + # 2) Skip annotate if a) filter_fn is provided and b) any node fails the filter if filter_fn and any(not filter_fn(node) for node in nodes): skip_annotate = True return skip_annotate diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index 346b5a70cadc..82e8014cd9ed 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -3,7 +3,6 @@ import itertools import operator import warnings -from collections import OrderedDict from dataclasses import dataclass from typing import ( Any, @@ -18,6 +17,8 @@ Union, ) +from typing_extensions import TypeAlias + import torch import torch.nn.functional as F from torch.ao.quantization.fake_quantize import ( @@ -40,10 +41,10 @@ ) from torch.ao.quantization.quantizer.utils import ( _get_module_name_filter, + _is_annotated, _skip_annotate, ) from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( - _is_annotated, get_bias_qspec, get_input_act_qspec, get_output_act_qspec, @@ -150,7 +151,7 @@ def _mark_nodes_as_annotated(nodes: List[Node]): node.meta[QUANT_ANNOTATION_KEY]._annotated = True -AnnotatorType = Callable[ +AnnotatorType: TypeAlias = Callable[ [ "X86InductorQuantizer", torch.fx.GraphModule, @@ -160,30 +161,29 @@ def _mark_nodes_as_annotated(nodes: List[Node]): Optional[List[List[Node]]], ] -X86_ANNOTATOR_COLLECTIONS: Dict[str, Dict[str, AnnotatorType]] = { - "STATIC": OrderedDict(), - "DYNAMIC": OrderedDict(), - "STATIC_QAT_ONLY": OrderedDict(), +X86_ANNOTATORS_REGISTRY: Dict[str, Dict[str, AnnotatorType]] = { + "STATIC": {}, + "DYNAMIC": {}, + "STATIC_QAT_ONLY": {}, } -STATIC_ANNOTATORS = X86_ANNOTATOR_COLLECTIONS["STATIC"] -DYNAMIC_ANNOTATORS = X86_ANNOTATOR_COLLECTIONS["DYNAMIC"] -STATIC_QAT_ONLY_ANNOTATORS = X86_ANNOTATOR_COLLECTIONS["STATIC_QAT_ONLY"] -# For static QAT, apply the `STATIC_QAT_ONLY_ANNOTATORS` and `STATIC_ANNOTATORS` in order. +AnnotatorsType: TypeAlias = Dict[str, AnnotatorType] - -AnnotatorCollectionType = Dict[str, AnnotatorType] +# Annotators collection +STATIC_ANNOTATORS: AnnotatorsType = X86_ANNOTATORS_REGISTRY["STATIC"] +DYNAMIC_ANNOTATORS: AnnotatorsType = X86_ANNOTATORS_REGISTRY["DYNAMIC"] +STATIC_QAT_ONLY_ANNOTATORS: AnnotatorsType = X86_ANNOTATORS_REGISTRY["STATIC_QAT_ONLY"] +# For static QAT, apply the `STATIC_QAT_ONLY_ANNOTATORS` and `STATIC_ANNOTATORS` in order. def register_annotator( - annotators_list: Union[AnnotatorCollectionType, List[AnnotatorCollectionType]], + annotators_list: Union[AnnotatorsType, List[AnnotatorsType]], annotator_name: Optional[str] = None, ): + # register annotator functions into one or more annotator collections. def decorator( annotator: AnnotatorType, - annotators_list: Union[ - AnnotatorCollectionType, List[AnnotatorCollectionType] - ] = annotators_list, + annotators_list: Union[AnnotatorsType, List[AnnotatorsType]] = annotators_list, annotator_name: Optional[str] = annotator_name, ) -> AnnotatorType: if not isinstance(annotators_list, list): @@ -582,6 +582,8 @@ def _annotate_by_single_config( quantization_config: Optional[QuantizationConfig], filter_fn: Optional[Callable[[Node], bool]] = None, ) -> torch.fx.GraphModule: + """Select the annotator functions according to the `quantization_config` and apply.""" + # implement the support for None to be canceling out previous annotations if quantization_config is None: return model @@ -592,12 +594,11 @@ def _annotate_by_single_config( ): annotators = DYNAMIC_ANNOTATORS else: - annotators = STATIC_ANNOTATORS.copy() + annotators = STATIC_ANNOTATORS if quantization_config.is_qat: # Apply QAT-specific annotators first - qat_annotators = STATIC_QAT_ONLY_ANNOTATORS.copy() - qat_annotators.update(STATIC_ANNOTATORS) - annotators = qat_annotators + for annotator_func in STATIC_QAT_ONLY_ANNOTATORS.values(): + annotator_func(self, model, quantization_config, filter_fn) for annotator_func in annotators.values(): annotator_func(self, model, quantization_config, filter_fn) diff --git a/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py b/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py index 9f1732e57370..d5595a136990 100644 --- a/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py +++ b/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py @@ -26,6 +26,7 @@ from torch.ao.quantization.quantizer.utils import ( _annotate_input_qspec_map, _annotate_output_qspec, + _is_annotated, ) from torch.fx import Node from torch.fx.passes.utils.matcher_with_name_node_map_utils import ( @@ -94,21 +95,6 @@ class OperatorConfig(NamedTuple): operators: List[OperatorPatternType] -def _is_annotated(nodes: List[Node]): - """ - Given a list of nodes (that represents an operator pattern), - check if any of the node is annotated, return True if any of the node - is annotated, otherwise return False - """ - annotated = False - for node in nodes: - annotated = annotated or ( - "quantization_annotation" in node.meta - and node.meta["quantization_annotation"]._annotated - ) - return annotated - - def _mark_nodes_as_annotated(nodes: List[Node]): for node in nodes: if node is not None: From 669e0e53ad6efe535c896c5f898092733cdacdc0 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Mon, 13 May 2024 10:14:32 +0800 Subject: [PATCH 005/706] rename set_module_name to set_module_name_qconfig Signed-off-by: yiliu30 --- .../pt2e/test_x86inductor_quantizer.py | 14 +++++++------- .../quantizer/x86_inductor_quantizer.py | 4 ++-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py index 2cc0e0172c81..4b9c6dfb0d18 100644 --- a/test/quantization/pt2e/test_x86inductor_quantizer.py +++ b/test/quantization/pt2e/test_x86inductor_quantizer.py @@ -1866,7 +1866,7 @@ def test_qat_dynamic_quant_linear(self): ) @skipIfNoX86 - def test_set_module_name(self): + def test_set_module_name_qconfig(self): """Test that quantize the specific submodule.""" class Sub(torch.nn.Module): @@ -1892,7 +1892,7 @@ def forward(self, x): example_inputs = (torch.randn(3, 5),) # Set global to no quantization and then default config for a specific submodule. quantizer = X86InductorQuantizer() - quantizer.set_module_name( + quantizer.set_module_name_qconfig( "sub", xiq.get_default_x86_inductor_quantization_config() ) node_occurrence = { @@ -1914,7 +1914,7 @@ def forward(self, x): self._test_quantizer(m, example_inputs, quantizer, node_occurrence, node_list) @skipIfNoX86 - def test_set_module_name_with_underscores(self) -> None: + def test_set_module_name_qconfig_with_underscores(self) -> None: """Test that if a module name has an underscore, we can still quantize it.""" class M(torch.nn.Module): @@ -1930,7 +1930,7 @@ def forward(self, x): # Set global to no quantization and then default config for a specific submodule. quantizer = X86InductorQuantizer() - quantizer.set_module_name( + quantizer.set_module_name_qconfig( "foo_bar", xiq.get_default_x86_inductor_quantization_config() ) example_inputs = (torch.randn(2, 2),) @@ -1965,7 +1965,7 @@ def forward(self, x): count += 1 @skipIfNoX86 - def test_set_module_name_for_dynamic_quant(self): + def test_set_module_name_qconfig_for_dynamic_quant(self): """Test that quantize the specific submodule for dynamic quantization.""" with override_quantized_engine("x86"), torch.no_grad(): @@ -1978,8 +1978,8 @@ def test_set_module_name_for_dynamic_quant(self): ) quantizer = ( X86InductorQuantizer() - .set_module_name("q_proj", dynamic_config) - .set_module_name("v_proj", dynamic_config) + .set_module_name_qconfig("q_proj", dynamic_config) + .set_module_name_qconfig("v_proj", dynamic_config) ) node_occurrence = { # ops for quantizing/de-quantizing input diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index 82e8014cd9ed..ceeeab021ee4 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -425,11 +425,11 @@ def set_module_type_qconfig( ) return self - def set_module_name( + def set_module_name_qconfig( self, module_name: str, quantization_config: Optional[QuantizationConfig] ): """Set quantization_config for a submodule with name: `module_name`, for example: - quantizer.set_module_name("blocks.sub"), it will quantize all supported operator/operator + quantizer.set_module_name_qconfig("blocks.sub"), it will quantize all supported operator/operator patterns in the submodule with this module name with the given `quantization_config` """ assert ( From 642fb63545e35b8352b1d723fab6d52f750f69db Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Mon, 13 May 2024 10:46:09 +0800 Subject: [PATCH 006/706] update the submodule Signed-off-by: yiliu30 --- third_party/kineto | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/kineto b/third_party/kineto index 327ac5052cf2..3a81076cc970 160000 --- a/third_party/kineto +++ b/third_party/kineto @@ -1 +1 @@ -Subproject commit 327ac5052cf25238fc769ab421c680d19b848eb3 +Subproject commit 3a81076cc97092666f319846f32f36b73ce2293e From b8308f9dd3c7fbebc27b9847dfdf774a60760ac6 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Mon, 13 May 2024 16:30:17 +0800 Subject: [PATCH 007/706] rename config to qconfig Signed-off-by: yiliu30 --- .../quantizer/x86_inductor_quantizer.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index ceeeab021ee4..87d4a4c8e817 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -355,7 +355,7 @@ def __init__(self): self.operator_type_qconfig: Dict[ torch._ops.OpOverloadPacket, Optional[QuantizationConfig] ] = {} - self.module_name_config: Dict[str, Optional[QuantizationConfig]] = {} + self.module_name_qconfig: Dict[str, Optional[QuantizationConfig]] = {} @classmethod def get_supported_quantization_configs(cls) -> List[QuantizationConfig]: @@ -435,7 +435,7 @@ def set_module_name_qconfig( assert ( quantization_config is not None ), " quantization_config == None is not supported yet" - self.module_name_config[module_name] = quantization_config + self.module_name_qconfig[module_name] = quantization_config return self def _set_aten_operator_qconfig( @@ -604,12 +604,12 @@ def _annotate_by_single_config( annotator_func(self, model, quantization_config, filter_fn) return model - def _annotate_quantization_by_module_name_config( + def _annotate_quantization_by_module_name_qconfig( self, model: torch.fx.GraphModule ) -> torch.fx.GraphModule: - for module_name, config in self.module_name_config.items(): + for module_name, qconfig in self.module_name_qconfig.items(): self._annotate_by_single_config( - model, config, _get_module_name_filter(module_name) + model, qconfig, _get_module_name_filter(module_name) ) return model @@ -635,8 +635,8 @@ def _annotate_for_static_quantization_config( """ # Step1: Recipe of fusion patterns like conv/linear. - if self.module_name_config: - self._annotate_quantization_by_module_name_config(model) + if self.module_name_qconfig: + self._annotate_quantization_by_module_name_qconfig(model) if self.operator_type_qconfig or self.global_config: self._annotate_static_quantization_by_op_type_and_global_config(model) @@ -659,8 +659,8 @@ def _annotate_for_static_quantization_config( def _annotate_for_dynamic_quantization_config( self, model: torch.fx.GraphModule ) -> torch.fx.GraphModule: - if self.module_name_config: - self._annotate_quantization_by_module_name_config(model) + if self.module_name_qconfig: + self._annotate_quantization_by_module_name_qconfig(model) if self.operator_type_qconfig or self.global_config: self._annotate_dynamic_quantization_by_op_type_and_global_config(model) return model From aa185f5e7322e201c0b3122f0fd5a61d15f2d790 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Wed, 15 May 2024 21:33:58 +0800 Subject: [PATCH 008/706] unified apply qconfig Signed-off-by: yiliu30 --- torch/ao/quantization/quantizer/utils.py | 102 ++++++ .../quantizer/x86_inductor_quantizer.py | 291 +++++++++++++----- .../quantizer/xnnpack_quantizer.py | 51 +-- 3 files changed, 325 insertions(+), 119 deletions(-) diff --git a/torch/ao/quantization/quantizer/utils.py b/torch/ao/quantization/quantizer/utils.py index d6dd373a7ba7..832aad36598d 100644 --- a/torch/ao/quantization/quantizer/utils.py +++ b/torch/ao/quantization/quantizer/utils.py @@ -1,3 +1,5 @@ +import logging +import os from typing import Callable, List, Optional from torch.ao.quantization.pt2e.utils import _is_sym_size_node @@ -5,6 +7,9 @@ from torch.ao.quantization.quantizer.quantizer import QuantizationAnnotation from torch.fx import Node +log = logging.getLogger(__name__) +log.setLevel(os.environ.get("LOGLEVEL", "ERROR")) + def _annotate_input_qspec_map(node: Node, input_node: Node, qspec): quantization_annotation = node.meta.get( @@ -83,6 +88,53 @@ def _normalize_path(n): return module_name_filter +def _get_module_type_filter(tp: Callable): + """Get the module_type_filter function for a given module type, the filter accepts + a node and checks if the node comes from a module that has certain module type + + For example: + node: linear_op = call_function[...](...) # comes from a module with type Block -> Sub -> Linear + + + >> module_type_filter = _get_module_type_filter(Sub) # submodule with type `Sub`, under the `Block` submodule + >> print(module_type_filter(node)) + True # the node is from the submodule `Sub` (same for `Block` and `Linear` as well) + """ + + tp_str = tp.__module__ + "." + tp.__qualname__ + # import pdb; pdb.set_trace() + + def module_type_filter(n: Node) -> bool: + # example: { + # 'L__self___sub': ("L['self'].sub", ), + # 'L__self___sub_linear': ("L['self'].sub.linear", ) + # } + nn_module_stack = n.meta.get("nn_module_stack", {}) + types = [] + for _, t in nn_module_stack.values(): + # export() returns str, but older APIs (e.g. capture_pre_autograd_graph) + # return type. Handle both cases. + if isinstance(t, type): + t = t.__module__ + "." + t.__qualname__ + types.append(t) + + return tp_str in types + + return module_type_filter + + +def _get_not_module_type_or_name_filter( + tp_list: List[Callable], module_name_list: List[str] +) -> Callable[[Node], bool]: + module_type_filters = [_get_module_type_filter(tp) for tp in tp_list] + module_name_list_filters = [_get_module_name_filter(m) for m in module_name_list] + + def not_module_type_or_name_filter(n: Node) -> bool: + return not any(f(n) for f in module_type_filters + module_name_list_filters) + + return not_module_type_or_name_filter + + def _is_annotated(nodes: List[Node]): """ Given a list of nodes (that represents an operator pattern), @@ -98,6 +150,13 @@ def _is_annotated(nodes: List[Node]): return annotated +class CurrentStage: + is_global = False + + +current_stage = CurrentStage() + + def _skip_annotate( nodes: List[Node], filter_fn: Optional[Callable[[Node], bool]] = None ): @@ -105,7 +164,50 @@ def _skip_annotate( # 1) Skip annotate if any node is already annotated if _is_annotated(nodes): skip_annotate = True + return skip_annotate # 2) Skip annotate if a) filter_fn is provided and b) any node fails the filter + # filter_fn result + # case 1, + # filter_fn False, False, False + # not filter_fn True, True, True + # any True + # -> skip + + # no node named as user specific + + # case 2, + # filter_fn True, False, False + # not filter_fn False, True, True + # any True + # -> skip + + # some node are not user specific + + # case 3, + # filter_fn True + # not filter_fn False + # any False + # -> not skip + # all node are user specific + if current_stage.is_global and filter_fn is not None: + for node in nodes: + if filter_fn(node): + log.warning("not skip nodes %s", nodes) + return False if filter_fn and any(not filter_fn(node) for node in nodes): + log.warning("skip nodes %s", nodes) skip_annotate = True + return skip_annotate + # import pdb; pdb.set_trace() + log.warning("not skip nodes %s", nodes) return skip_annotate + + +# def dump_function_name_decorator(func): +# def wrapper(*args, **kwargs): +# log.warning(f"entering {func.__name__}") +# result = func(*args, **kwargs) +# log.warning(f"exiting {func.__name__}") +# return result + +# return wrapper diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index 87d4a4c8e817..c5ec7e444b9d 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -43,6 +43,8 @@ _get_module_name_filter, _is_annotated, _skip_annotate, + current_stage, + log, ) from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( get_bias_qspec, @@ -108,6 +110,125 @@ class _X86InductorQuantizationAnnotation(QuantizationAnnotation): QUANT_ANNOTATION_KEY = "quantization_annotation" +def _get_operator_type_filter(operator_type: Callable): + """Get the operator_type_filter function for a given operator type, the filter accepts + a node and checks if the node has certain module type + + For example: + node: linear_op = call_function[...](...) # linear_op.target if torch.ops.aten.linear.default + + + >> operator_type_filter = _get_operator_type_filter(torch.ops.aten.linear.default) + >> print(operator_type_filter(node)) + True # the node's target is `torch.ops.aten.linear.default` + """ + + def operator_type_filter(n: Node) -> bool: + result = n.target == operator_type + # log.warning( + # f"!!for node: {n}, (name: {n.name}, target:{n.target}) n.target equal operator_type ({operator_type})? {result}" + # ) + log.warning( + "!!for node: %s, (name: %s, target:%s) n.target equal operator_type (%s))? %s", + n, + n.name, + n.target, + operator_type, + result, + ) + return result + + return operator_type_filter + + +def _x86_get_not_module_type_or_name_filter( + tp_list: List[torch._ops.OpOverloadPacket], module_name_list: List[str] +) -> Callable[[Node], bool]: + # Check if the node is 1) belong to the `default_quantizable_ops` and 2) not be marked + # by `set_module_name_qconfig`, or `set_module_type_qconfig` `set_function_type_qconfig`. + + operator_type_filters = [_get_operator_type_filter(tp) for tp in tp_list] + module_name_list_filters = [_get_module_name_filter(m) for m in module_name_list] + + def not_module_type_or_name_filter(n: Node) -> bool: + # For global_config, only quantize the `default_quantizable_ops` + belong_to_default_quantizable_ops = n.target in default_quantizable_ops + # if n.target not in default_quantizable_ops: + # result1 = False + # for f in module_name_list_filters: + # log.warning(f"not module_name for node: {n}, (name: {n.name}, target:{n.target}), f:{f} f(n) is {f(n)}") + + # for f in operator_type_filters: + # log.warning(f"not module type: {n}, (name: {n.name}, target:{n.target}), f:{f} f(n) is {f(n)}") + + not_module_type_or_module_name_node = not any( + f(n) for f in operator_type_filters + module_name_list_filters + ) + final_result = ( + belong_to_default_quantizable_ops and not_module_type_or_module_name_node + ) + + # rewrite it not use the f-string + log.warning( + ( + "for node: %s, (name: %s, target:%s), node in default_quantizable_ops? %s;" + "node is not used by operator_type_filters or module_name_list_filters? %s" + ), + n, + n.name, + n.target, + belong_to_default_quantizable_ops, + not_module_type_or_module_name_node, + ) + return final_result + + return not_module_type_or_name_filter + + +def _node_checker_for_module_name_qconfig(nodes, filter_fn): + skip_annotate = False + # 1) Skip annotate if any node is already annotated + if _is_annotated(nodes): + skip_annotate = True + return skip_annotate + # 2) Skip annotate if a) filter_fn is provided and b) any node fails the filter + # filter_fn result + # case 1, + # filter_fn False, False, False + # not filter_fn True, True, True + # any True + # -> skip + + # no node named as user specific + + # case 2, + # filter_fn True, False, False + # not filter_fn False, True, True + # any True + # -> skip + + # some node are not user specific + + # case 3, + # filter_fn True + # not filter_fn False + # any False + # -> not skip + # all node are user specific + if current_stage.is_global and filter_fn is not None: + for node in nodes: + if filter_fn(node): + log.warning("not skip nodes %s", nodes) + return False + if filter_fn and any(not filter_fn(node) for node in nodes): + log.warning("skip nodes %s", nodes) + skip_annotate = True + return skip_annotate + # import pdb; pdb.set_trace() + log.warning("not skip nodes %s", nodes) + return skip_annotate + + def _map_module_function_to_aten_operator_type(): module_function_to_aten_operator: Dict[Callable, torch._ops.OpOverloadPacket] = {} map_list = ( @@ -161,10 +282,12 @@ def _mark_nodes_as_annotated(nodes: List[Node]): Optional[List[List[Node]]], ] +from collections import OrderedDict + X86_ANNOTATORS_REGISTRY: Dict[str, Dict[str, AnnotatorType]] = { - "STATIC": {}, - "DYNAMIC": {}, - "STATIC_QAT_ONLY": {}, + "STATIC": OrderedDict(), + "DYNAMIC": OrderedDict(), + "STATIC_QAT_ONLY": OrderedDict(), } AnnotatorsType: TypeAlias = Dict[str, AnnotatorType] @@ -356,6 +479,7 @@ def __init__(self): torch._ops.OpOverloadPacket, Optional[QuantizationConfig] ] = {} self.module_name_qconfig: Dict[str, Optional[QuantizationConfig]] = {} + self._module_type_qconfig = {} @classmethod def get_supported_quantization_configs(cls) -> List[QuantizationConfig]: @@ -396,6 +520,7 @@ def set_function_type_qconfig( function_type: Callable, quantization_config: Optional[QuantizationConfig], ) -> "X86InductorQuantizer": + self._module_type_qconfig[function_type] = quantization_config if function_type in X86InductorQuantizer.module_function_to_aten_operator_type: self._set_aten_operator_qconfig( X86InductorQuantizer.module_function_to_aten_operator_type[ @@ -414,6 +539,7 @@ def set_module_type_qconfig( module_type: torch.nn.Module, quantization_config: Optional[QuantizationConfig], ) -> "X86InductorQuantizer": + self._module_type_qconfig[module_type] = quantization_config if module_type in X86InductorQuantizer.module_function_to_aten_operator_type: self._set_aten_operator_qconfig( X86InductorQuantizer.module_function_to_aten_operator_type[module_type], @@ -571,7 +697,7 @@ def _get_input_idx_for_binary_node( def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: """just handling global spec for now""" if self.global_config and self.global_config.input_activation.is_dynamic: # type: ignore[union-attr] - model = self._annotate_for_dynamic_quantization_config(model) + model = self._annotate_for_static_quantization_config(model) else: model = self._annotate_for_static_quantization_config(model) return model @@ -607,10 +733,25 @@ def _annotate_by_single_config( def _annotate_quantization_by_module_name_qconfig( self, model: torch.fx.GraphModule ) -> torch.fx.GraphModule: + module_name_list = list(self.module_name_qconfig.keys()) + for module_name, qconfig in self.module_name_qconfig.items(): self._annotate_by_single_config( model, qconfig, _get_module_name_filter(module_name) ) + tp_list = list(self.operator_type_qconfig.keys()) + for operator_type, qconfig in self.operator_type_qconfig.items(): + self._annotate_by_single_config( + model, qconfig, _get_operator_type_filter(operator_type) + ) + log.warning("start to handle global config") + if self.global_config: + current_stage.is_global = True + self._annotate_by_single_config( + model, + self.global_config, + _x86_get_not_module_type_or_name_filter(tp_list, module_name_list), + ) return model def _annotate_static_quantization_by_op_type_and_global_config( @@ -635,10 +776,13 @@ def _annotate_for_static_quantization_config( """ # Step1: Recipe of fusion patterns like conv/linear. - if self.module_name_qconfig: - self._annotate_quantization_by_module_name_qconfig(model) - if self.operator_type_qconfig or self.global_config: - self._annotate_static_quantization_by_op_type_and_global_config(model) + + self._annotate_quantization_by_module_name_qconfig(model) + + # if self.module_name_qconfig: + # self._annotate_quantization_by_module_name_qconfig(model) + # if self.operator_type_qconfig or self.global_config: + # self._annotate_static_quantization_by_op_type_and_global_config(model) # Step2: Recipe to propagate annotation for patterns beside conv/linear. # Go through all the nodes from start to end. @@ -661,8 +805,8 @@ def _annotate_for_dynamic_quantization_config( ) -> torch.fx.GraphModule: if self.module_name_qconfig: self._annotate_quantization_by_module_name_qconfig(model) - if self.operator_type_qconfig or self.global_config: - self._annotate_dynamic_quantization_by_op_type_and_global_config(model) + # if self.operator_type_qconfig or self.global_config: + # self._annotate_dynamic_quantization_by_op_type_and_global_config(model) return model def _annotate_dynamic_quantization_by_op_type_and_global_config( @@ -1261,69 +1405,6 @@ def _annotate_output_for_int8_in_int8_out_pattern(self, node: Node) -> None: self._annotate_output_share_observer_as_input(input_node, node) return - @register_annotator([STATIC_ANNOTATORS, DYNAMIC_ANNOTATORS]) - def _annotate_linear( - self, - gm: torch.fx.GraphModule, - quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[Node], bool]] = None, - ) -> None: - linear_partitions = get_source_partitions( - gm.graph, [torch.nn.Linear, torch.nn.functional.linear] - ) - linear_partitions = list( - itertools.chain.from_iterable(linear_partitions.values()) - ) - for partition in linear_partitions: - if len(partition.output_nodes) > 1: - raise ValueError( - "Linear partition cannot have more than one output node" - ) - linear_node = partition.output_nodes[0] - if linear_node.op != "call_function" or linear_node.target not in ( - torch.ops.aten.linear.default, - ): - raise ValueError(f"{linear_node} is not an aten linear operator") - # skip annotation if it is already annotated - if _skip_annotate([linear_node], filter_fn): - continue - self._annotate_linear_node_helper(linear_node, True, quantization_config) - - @register_annotator(STATIC_ANNOTATORS) - def _annotate_linear_unary( - self, - gm: torch.fx.GraphModule, - quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[Node], bool]] = None, - ) -> None: - postop_list = [ - torch.nn.ReLU, - torch.nn.LeakyReLU, - torch.nn.Tanh, - torch.nn.GELU, - ] - fused_partitions: List[tuple] = [] - for postop in postop_list: - fused_partitions = fused_partitions + find_sequential_partitions( - gm, [torch.nn.Linear, postop] - ) - for fused_partition in fused_partitions: - linear_partition, unary_partition = fused_partition - linear_node, unary_node = self._get_output_nodes_of_partitions( - [linear_partition, unary_partition] - ) - if linear_node.op != "call_function" or linear_node.target not in ( - torch.ops.aten.linear.default, - ): - continue - if _skip_annotate([unary_node, linear_node], filter_fn): - continue - self._annotate_linear_node_helper(linear_node, False, quantization_config) - unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( - _annotated=True, - _is_output_of_quantized_pattern=True, - ) - @register_annotator(STATIC_ANNOTATORS) def _annotate_linear_binary_unary( self, @@ -1407,6 +1488,70 @@ def _annotate_linear_binary_unary( _is_output_of_quantized_pattern=True, ) + @register_annotator(STATIC_ANNOTATORS) + def _annotate_linear_unary( + self, + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig, + filter_fn: Optional[Callable[[Node], bool]] = None, + ) -> None: + postop_list = [ + torch.nn.ReLU, + torch.nn.LeakyReLU, + torch.nn.Tanh, + torch.nn.GELU, + ] + fused_partitions: List[tuple] = [] + for postop in postop_list: + fused_partitions = fused_partitions + find_sequential_partitions( + gm, [torch.nn.Linear, postop] + ) + for fused_partition in fused_partitions: + linear_partition, unary_partition = fused_partition + linear_node, unary_node = self._get_output_nodes_of_partitions( + [linear_partition, unary_partition] + ) + if linear_node.op != "call_function" or linear_node.target not in ( + torch.ops.aten.linear.default, + ): + continue + if _skip_annotate([unary_node, linear_node], filter_fn): + continue + self._annotate_linear_node_helper(linear_node, False, quantization_config) + unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + + @register_annotator([STATIC_ANNOTATORS, DYNAMIC_ANNOTATORS]) + def _annotate_linear( + self, + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig, + filter_fn: Optional[Callable[[Node], bool]] = None, + ) -> None: + linear_partitions = get_source_partitions( + gm.graph, [torch.nn.Linear, torch.nn.functional.linear] + ) + linear_partitions = list( + itertools.chain.from_iterable(linear_partitions.values()) + ) + for partition in linear_partitions: + log.warning("for partition: %s", partition) + if len(partition.output_nodes) > 1: + raise ValueError( + "Linear partition cannot have more than one output node" + ) + linear_node = partition.output_nodes[0] + if linear_node.op != "call_function" or linear_node.target not in ( + torch.ops.aten.linear.default, + ): + raise ValueError(f"{linear_node} is not an aten linear operator") + # skip annotation if it is already annotated + if _skip_annotate([linear_node], filter_fn): + continue + self._annotate_linear_node_helper(linear_node, True, quantization_config) + def validate(self, model: torch.fx.GraphModule) -> None: pass diff --git a/torch/ao/quantization/quantizer/xnnpack_quantizer.py b/torch/ao/quantization/quantizer/xnnpack_quantizer.py index e13a79f39267..4649c3546f05 100644 --- a/torch/ao/quantization/quantizer/xnnpack_quantizer.py +++ b/torch/ao/quantization/quantizer/xnnpack_quantizer.py @@ -22,7 +22,11 @@ ) from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer -from torch.ao.quantization.quantizer.utils import _get_module_name_filter +from torch.ao.quantization.quantizer.utils import ( + _get_module_name_filter, + _get_module_type_filter, + _get_not_module_type_or_name_filter, +) from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( _convert_scalars_to_attrs, @@ -193,51 +197,6 @@ def _get_supported_config_and_operators() -> List[OperatorConfig]: return _get_supported_symmetric_config_and_operators() -def _get_module_type_filter(tp: Callable): - """Get the module_type_filter function for a given module type, the filter accepts - a node and checks if the node comes from a module that has certain module type - - For example: - node: linear_op = call_function[...](...) # comes from a module with type Block -> Sub -> Linear - - - >> module_type_filter = _get_module_type_filter(Sub) # submodule with type `Sub`, under the `Block` submodule - >> print(module_type_filter(node)) - True # the node is from the submodule `Sub` (same for `Block` and `Linear` as well) - """ - - tp_str = tp.__module__ + "." + tp.__qualname__ - - def module_type_filter(n: Node) -> bool: - # example: { - # 'L__self___sub': ("L['self'].sub", ), - # 'L__self___sub_linear': ("L['self'].sub.linear", ) - # } - nn_module_stack = n.meta.get("nn_module_stack", {}) - types = [] - for _, t in nn_module_stack.values(): - # export() returns str, but older APIs (e.g. capture_pre_autograd_graph) - # return type. Handle both cases. - if isinstance(t, type): - t = t.__module__ + "." + t.__qualname__ - types.append(t) - return tp_str in types - - return module_type_filter - - -def _get_not_module_type_or_name_filter( - tp_list: List[Callable], module_name_list: List[str] -) -> Callable[[Node], bool]: - module_type_filters = [_get_module_type_filter(tp) for tp in tp_list] - module_name_list_filters = [_get_module_name_filter(m) for m in module_name_list] - - def not_module_type_or_name_filter(n: Node) -> bool: - return not any(f(n) for f in module_type_filters + module_name_list_filters) - - return not_module_type_or_name_filter - - class XNNPACKQuantizer(Quantizer): supported_config_and_operators = _get_supported_config_and_operators() STATIC_QAT_ONLY_OPS = [ From 9cfaddb8222d695ebf9290e1d23c39cead102987 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Thu, 16 May 2024 12:51:53 +0800 Subject: [PATCH 009/706] fixed propogation annotate Signed-off-by: yiliu30 --- .../quantizer/x86_inductor_quantizer.py | 57 +++++++++++++------ 1 file changed, 41 insertions(+), 16 deletions(-) diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index c5ec7e444b9d..898d17480d45 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -224,7 +224,6 @@ def _node_checker_for_module_name_qconfig(nodes, filter_fn): log.warning("skip nodes %s", nodes) skip_annotate = True return skip_annotate - # import pdb; pdb.set_trace() log.warning("not skip nodes %s", nodes) return skip_annotate @@ -728,6 +727,10 @@ def _annotate_by_single_config( for annotator_func in annotators.values(): annotator_func(self, model, quantization_config, filter_fn) + + self._annotate_propagation_quantizable_pattern_entry(model, quantization_config, filter_fn) + self._annotate_output_for_int8_in_int8_out_pattern_entry(model, quantization_config, filter_fn) + return model def _annotate_quantization_by_module_name_qconfig( @@ -788,15 +791,15 @@ def _annotate_for_static_quantization_config( # Go through all the nodes from start to end. # Recipe refer to https://github.com/intel/intel-extension-for-pytorch/blob/ # 90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_recipe.py#L538 - for node in model.graph.nodes: - self._annotate_propagation_quantizable_pattern(node) + # for node in model.graph.nodes: + # self._annotate_propagation_quantizable_pattern(node) - # Step3: For quantizable ops, such as maxpool2d, we need to quantize its output if it is quantized - # in inputs. So, we can fuse dq-operator-q into a quantized op. - # Refer to https://github.com/intel/intel-extension-for-pytorch/blob/ - # 90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_recipe.py#L487 - for node in model.graph.nodes: - self._annotate_output_for_int8_in_int8_out_pattern(node) + # # Step3: For quantizable ops, such as maxpool2d, we need to quantize its output if it is quantized + # # in inputs. So, we can fuse dq-operator-q into a quantized op. + # # Refer to https://github.com/intel/intel-extension-for-pytorch/blob/ + # # 90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_recipe.py#L487 + # for node in model.graph.nodes: + # self._annotate_output_for_int8_in_int8_out_pattern(node) return model @@ -1251,7 +1254,8 @@ def _annotate_conv2d( self._annotate_conv_node_helper(conv_node, True, quantization_config) def _annotate_maxpool2d( - self, node: Node, quantization_config: QuantizationConfig + self, node: Node, quantization_config: QuantizationConfig, + filter_fn ) -> None: if node.target is not torch.ops.aten.max_pool2d.default: return @@ -1262,6 +1266,8 @@ def _annotate_maxpool2d( ] ): return + if _skip_annotate([maxpool_node], filter_fn): + return input_node = maxpool_node.args[0] assert isinstance(input_node, Node) input_qspec_map = {} @@ -1298,14 +1304,20 @@ def _annotate_cat( _annotated=True, _is_output_of_quantized_pattern=True, ) - - def _annotate_propagation_quantizable_pattern(self, node: Node) -> None: + + def _annotate_propagation_quantizable_pattern_entry(self, model, quantization_config, filter_fn): + for node in model.graph.nodes: + # if _skip_annotate([node], filter_fn): + # continue + self._annotate_propagation_quantizable_pattern(node, quantization_config, filter_fn) + + def _annotate_propagation_quantizable_pattern(self, node: Node, quantization_config, filter_fn) -> None: # Propagate annotation to quantizable patterns. if ( (node.target in propagation_quantizable_ops) and (not _is_any_annotated([node])) and (node.op == "call_function") - and (quantization_config := self._get_aten_operator_qconfig(node.target)) # type: ignore[arg-type] + # and (quantization_config := self._get_aten_operator_qconfig(node.target)) # type: ignore[arg-type] ): def is_all_inputs_connected_to_quantized_op(input_nodes): @@ -1320,7 +1332,7 @@ def is_all_inputs_connected_to_quantized_op(input_nodes): input_nodes_to_check = [node.all_input_nodes[0]] if not is_all_inputs_connected_to_quantized_op(input_nodes_to_check): return - self._annotate_maxpool2d(node, quantization_config) + self._annotate_maxpool2d(node, quantization_config, filter_fn) return elif node.target is torch.ops.aten.cat.default: input_nodes_to_check = node.all_input_nodes @@ -1361,8 +1373,18 @@ def _annotate_output_share_observer_as_input( edge_or_node ) return - - def _annotate_output_for_int8_in_int8_out_pattern(self, node: Node) -> None: + + + def _annotate_output_for_int8_in_int8_out_pattern_entry(self, model, quantization_config, filter_fn): + for node in model.graph.nodes: + # if _skip_annotate([node], filter_fn): + # continue + # if quantization_config is None: + # return + + self._annotate_output_for_int8_in_int8_out_pattern(node, quantization_config, filter_fn) + + def _annotate_output_for_int8_in_int8_out_pattern(self, node: Node, quantization_config, filter_fn) -> None: r""" Check and insert observer at output of node in int8_in_int8_out_ops if needed. Recipe refers to https://github.com/intel/intel-extension-for-pytorch/blob/ @@ -1382,6 +1404,9 @@ def _annotate_output_for_int8_in_int8_out_pattern(self, node: Node) -> None: ] ): return + # !!! Do not skip this, this annotator annotate a annotated node???? + # if _skip_annotate([maxpool_node], filter_fn): + # return # Get the quantization_annotation from getitem_node maxpool_node_quantization_annotation = ( maxpool_node.meta[QUANT_ANNOTATION_KEY] From 886c0763c147c5960b020fb6e3dd214079aaa52f Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Fri, 17 May 2024 13:09:05 +0800 Subject: [PATCH 010/706] add more UTs Signed-off-by: yiliu30 --- .../pt2e/test_x86inductor_quantizer.py | 110 ++++++++++++++++++ .../quantizer/x86_inductor_quantizer.py | 76 +++++++----- 2 files changed, 160 insertions(+), 26 deletions(-) diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py index 4b9c6dfb0d18..5ff7a261bd8e 100644 --- a/test/quantization/pt2e/test_x86inductor_quantizer.py +++ b/test/quantization/pt2e/test_x86inductor_quantizer.py @@ -538,6 +538,7 @@ def _test_quantizer( expected_node_occurrence, expected_node_list=None, is_qat=False, + debug=False, ): m_eager = model.train() if is_qat else model.eval() @@ -556,6 +557,8 @@ def _test_quantizer( prepare_model = copy.deepcopy(m) m = convert_pt2e(m) convert_model = copy.deepcopy(m) + if debug: + convert_model.print_readable(True) pt2_quant_output = m(*example_inputs) node_occurrence = { ns.call_function(k): v for k, v in expected_node_occurrence.items() @@ -1913,6 +1916,113 @@ def forward(self, x): ] self._test_quantizer(m, example_inputs, quantizer, node_occurrence, node_list) + @skipIfNoX86 + def test_set_module_name_and_set_module_type_case2(self): + """ + + All linear are quantized except the second one. + """ + + class Sub(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(5, 5) + + def forward(self, x): + return self.linear(x) + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(5, 5) + self.sub = Sub() + + def forward(self, x): + x = self.linear(x) + x = self.sub(x) + return x + + m = M().eval() + example_inputs = (torch.randn(3, 5),) + # Set global to no quantization and then default config for a specific submodule. + quantizer = X86InductorQuantizer() + quantizer.set_module_name_qconfig("sub", None).set_module_type_qconfig( + torch.nn.Linear, xiq.get_default_x86_inductor_quantization_config() + ) + + node_occurrence = { + torch.ops.aten.linear.default: 2, + # input and output for the first linear + torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + node_list = [ + # first linear is quantized + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_channel.default, + torch.ops.aten.linear.default, + # second linear is not quantized + torch.ops.aten.linear.default, + ] + self._test_quantizer( + m, example_inputs, quantizer, node_occurrence, node_list, debug=True + ) + + @skipIfNoX86 + def test_set_module_name_and_set_module_type(self): + """ + All linear are not quantized except the second one. + """ + + class Sub(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(5, 5) + + def forward(self, x): + return self.linear(x) + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(5, 5) + self.sub = Sub() + + def forward(self, x): + x = self.linear(x) + x = self.sub(x) + return x + + m = M().eval() + example_inputs = (torch.randn(3, 5),) + # Set global to no quantization and then default config for a specific submodule. + quantizer = X86InductorQuantizer() + quantizer.set_module_name_qconfig( + "sub", xiq.get_default_x86_inductor_quantization_config() + ).set_module_type_qconfig(torch.nn.Linear, None) + + node_occurrence = { + torch.ops.aten.linear.default: 2, + # input and output for the second linear + torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + node_list = [ + # first linear is not quantized + torch.ops.aten.linear.default, + # second linear is quantized + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_channel.default, + torch.ops.aten.linear.default, + ] + self._test_quantizer( + m, example_inputs, quantizer, node_occurrence, node_list, debug=True + ) + @skipIfNoX86 def test_set_module_name_qconfig_with_underscores(self) -> None: """Test that if a module name has an underscore, we can still quantize it.""" diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index 898d17480d45..73f357d1c5f0 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -110,7 +110,7 @@ class _X86InductorQuantizationAnnotation(QuantizationAnnotation): QUANT_ANNOTATION_KEY = "quantization_annotation" -def _get_operator_type_filter(operator_type: Callable): +def _get_operator_type_filter(operator_type: Callable, module_name_list): """Get the operator_type_filter function for a given operator type, the filter accepts a node and checks if the node has certain module type @@ -122,8 +122,12 @@ def _get_operator_type_filter(operator_type: Callable): >> print(operator_type_filter(node)) True # the node's target is `torch.ops.aten.linear.default` """ + module_name_list_filters = [_get_module_name_filter(m) for m in module_name_list] def operator_type_filter(n: Node) -> bool: + not_module_name_node = not any(f(n) for f in module_name_list_filters) + # if not not_module_name_node: + # return False result = n.target == operator_type # log.warning( # f"!!for node: {n}, (name: {n.name}, target:{n.target}) n.target equal operator_type ({operator_type})? {result}" @@ -136,7 +140,7 @@ def operator_type_filter(n: Node) -> bool: operator_type, result, ) - return result + return not_module_name_node and result return operator_type_filter @@ -147,8 +151,12 @@ def _x86_get_not_module_type_or_name_filter( # Check if the node is 1) belong to the `default_quantizable_ops` and 2) not be marked # by `set_module_name_qconfig`, or `set_module_type_qconfig` `set_function_type_qconfig`. - operator_type_filters = [_get_operator_type_filter(tp) for tp in tp_list] - module_name_list_filters = [_get_module_name_filter(m) for m in module_name_list] + operator_type_filters = [ + _get_operator_type_filter(tp, module_name_list) for tp in tp_list + ] + module_name_list_filters: List[ + Callable + ] = [] # [_get_module_name_filter(m) for m in module_name_list] def not_module_type_or_name_filter(n: Node) -> bool: # For global_config, only quantize the `default_quantizable_ops` @@ -557,9 +565,9 @@ def set_module_name_qconfig( quantizer.set_module_name_qconfig("blocks.sub"), it will quantize all supported operator/operator patterns in the submodule with this module name with the given `quantization_config` """ - assert ( - quantization_config is not None - ), " quantization_config == None is not supported yet" + # assert ( + # quantization_config is not None + # ), " quantization_config == None is not supported yet" self.module_name_qconfig[module_name] = quantization_config return self @@ -709,7 +717,7 @@ def _annotate_by_single_config( ) -> torch.fx.GraphModule: """Select the annotator functions according to the `quantization_config` and apply.""" - # implement the support for None to be canceling out previous annotations + # For `quantization_config`, skip the annotation, and it won't be annotated by the next stage either. if quantization_config is None: return model @@ -727,9 +735,13 @@ def _annotate_by_single_config( for annotator_func in annotators.values(): annotator_func(self, model, quantization_config, filter_fn) - - self._annotate_propagation_quantizable_pattern_entry(model, quantization_config, filter_fn) - self._annotate_output_for_int8_in_int8_out_pattern_entry(model, quantization_config, filter_fn) + + self._annotate_propagation_quantizable_pattern_entry( + model, quantization_config, filter_fn + ) + self._annotate_output_for_int8_in_int8_out_pattern_entry( + model, quantization_config, filter_fn + ) return model @@ -745,7 +757,9 @@ def _annotate_quantization_by_module_name_qconfig( tp_list = list(self.operator_type_qconfig.keys()) for operator_type, qconfig in self.operator_type_qconfig.items(): self._annotate_by_single_config( - model, qconfig, _get_operator_type_filter(operator_type) + model, + qconfig, + _get_operator_type_filter(operator_type, module_name_list), ) log.warning("start to handle global config") if self.global_config: @@ -1254,8 +1268,7 @@ def _annotate_conv2d( self._annotate_conv_node_helper(conv_node, True, quantization_config) def _annotate_maxpool2d( - self, node: Node, quantization_config: QuantizationConfig, - filter_fn + self, node: Node, quantization_config: QuantizationConfig, filter_fn ) -> None: if node.target is not torch.ops.aten.max_pool2d.default: return @@ -1304,14 +1317,20 @@ def _annotate_cat( _annotated=True, _is_output_of_quantized_pattern=True, ) - - def _annotate_propagation_quantizable_pattern_entry(self, model, quantization_config, filter_fn): + + def _annotate_propagation_quantizable_pattern_entry( + self, model, quantization_config, filter_fn + ): for node in model.graph.nodes: # if _skip_annotate([node], filter_fn): # continue - self._annotate_propagation_quantizable_pattern(node, quantization_config, filter_fn) - - def _annotate_propagation_quantizable_pattern(self, node: Node, quantization_config, filter_fn) -> None: + self._annotate_propagation_quantizable_pattern( + node, quantization_config, filter_fn + ) + + def _annotate_propagation_quantizable_pattern( + self, node: Node, quantization_config, filter_fn + ) -> None: # Propagate annotation to quantizable patterns. if ( (node.target in propagation_quantizable_ops) @@ -1373,18 +1392,23 @@ def _annotate_output_share_observer_as_input( edge_or_node ) return - - - def _annotate_output_for_int8_in_int8_out_pattern_entry(self, model, quantization_config, filter_fn): + + def _annotate_output_for_int8_in_int8_out_pattern_entry( + self, model, quantization_config, filter_fn + ): for node in model.graph.nodes: # if _skip_annotate([node], filter_fn): # continue # if quantization_config is None: # return - - self._annotate_output_for_int8_in_int8_out_pattern(node, quantization_config, filter_fn) - - def _annotate_output_for_int8_in_int8_out_pattern(self, node: Node, quantization_config, filter_fn) -> None: + + self._annotate_output_for_int8_in_int8_out_pattern( + node, quantization_config, filter_fn + ) + + def _annotate_output_for_int8_in_int8_out_pattern( + self, node: Node, quantization_config, filter_fn + ) -> None: r""" Check and insert observer at output of node in int8_in_int8_out_ops if needed. Recipe refers to https://github.com/intel/intel-extension-for-pytorch/blob/ From 87019b47ef236feb7ff7dace9b1de4ccc822d659 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Fri, 17 May 2024 16:47:00 +0800 Subject: [PATCH 011/706] unifed the set_module_name, set_module_type Signed-off-by: yiliu30 --- .../quantizer/x86_inductor_quantizer.py | 381 +++++------------- 1 file changed, 108 insertions(+), 273 deletions(-) diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index 73f357d1c5f0..714cc3a221b9 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -14,11 +14,8 @@ Set, Tuple, TYPE_CHECKING, - Union, ) -from typing_extensions import TypeAlias - import torch import torch.nn.functional as F from torch.ao.quantization.fake_quantize import ( @@ -41,7 +38,6 @@ ) from torch.ao.quantization.quantizer.utils import ( _get_module_name_filter, - _is_annotated, _skip_annotate, current_stage, log, @@ -126,21 +122,16 @@ def _get_operator_type_filter(operator_type: Callable, module_name_list): def operator_type_filter(n: Node) -> bool: not_module_name_node = not any(f(n) for f in module_name_list_filters) - # if not not_module_name_node: - # return False - result = n.target == operator_type - # log.warning( - # f"!!for node: {n}, (name: {n.name}, target:{n.target}) n.target equal operator_type ({operator_type})? {result}" - # ) + has_certain_operator_type = n.target == operator_type log.warning( "!!for node: %s, (name: %s, target:%s) n.target equal operator_type (%s))? %s", n, n.name, n.target, operator_type, - result, + has_certain_operator_type, ) - return not_module_name_node and result + return not_module_name_node and has_certain_operator_type return operator_type_filter @@ -193,49 +184,6 @@ def not_module_type_or_name_filter(n: Node) -> bool: return not_module_type_or_name_filter -def _node_checker_for_module_name_qconfig(nodes, filter_fn): - skip_annotate = False - # 1) Skip annotate if any node is already annotated - if _is_annotated(nodes): - skip_annotate = True - return skip_annotate - # 2) Skip annotate if a) filter_fn is provided and b) any node fails the filter - # filter_fn result - # case 1, - # filter_fn False, False, False - # not filter_fn True, True, True - # any True - # -> skip - - # no node named as user specific - - # case 2, - # filter_fn True, False, False - # not filter_fn False, True, True - # any True - # -> skip - - # some node are not user specific - - # case 3, - # filter_fn True - # not filter_fn False - # any False - # -> not skip - # all node are user specific - if current_stage.is_global and filter_fn is not None: - for node in nodes: - if filter_fn(node): - log.warning("not skip nodes %s", nodes) - return False - if filter_fn and any(not filter_fn(node) for node in nodes): - log.warning("skip nodes %s", nodes) - skip_annotate = True - return skip_annotate - log.warning("not skip nodes %s", nodes) - return skip_annotate - - def _map_module_function_to_aten_operator_type(): module_function_to_aten_operator: Dict[Callable, torch._ops.OpOverloadPacket] = {} map_list = ( @@ -279,53 +227,6 @@ def _mark_nodes_as_annotated(nodes: List[Node]): node.meta[QUANT_ANNOTATION_KEY]._annotated = True -AnnotatorType: TypeAlias = Callable[ - [ - "X86InductorQuantizer", - torch.fx.GraphModule, - Optional[QuantizationConfig], - Optional[Callable[[Node], bool]], - ], - Optional[List[List[Node]]], -] - -from collections import OrderedDict - -X86_ANNOTATORS_REGISTRY: Dict[str, Dict[str, AnnotatorType]] = { - "STATIC": OrderedDict(), - "DYNAMIC": OrderedDict(), - "STATIC_QAT_ONLY": OrderedDict(), -} - -AnnotatorsType: TypeAlias = Dict[str, AnnotatorType] - -# Annotators collection -STATIC_ANNOTATORS: AnnotatorsType = X86_ANNOTATORS_REGISTRY["STATIC"] -DYNAMIC_ANNOTATORS: AnnotatorsType = X86_ANNOTATORS_REGISTRY["DYNAMIC"] -STATIC_QAT_ONLY_ANNOTATORS: AnnotatorsType = X86_ANNOTATORS_REGISTRY["STATIC_QAT_ONLY"] -# For static QAT, apply the `STATIC_QAT_ONLY_ANNOTATORS` and `STATIC_ANNOTATORS` in order. - - -def register_annotator( - annotators_list: Union[AnnotatorsType, List[AnnotatorsType]], - annotator_name: Optional[str] = None, -): - # register annotator functions into one or more annotator collections. - def decorator( - annotator: AnnotatorType, - annotators_list: Union[AnnotatorsType, List[AnnotatorsType]] = annotators_list, - annotator_name: Optional[str] = annotator_name, - ) -> AnnotatorType: - if not isinstance(annotators_list, list): - annotators_list = [annotators_list] - annotator_name = annotator_name or annotator.__name__ - for annotators in annotators_list: - annotators[annotator_name] = annotator - return annotator - - return decorator - - def _is_node_annotated(_node): """ return True if the node is annotated, otherwise return False @@ -704,48 +605,24 @@ def _get_input_idx_for_binary_node( def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: """just handling global spec for now""" if self.global_config and self.global_config.input_activation.is_dynamic: # type: ignore[union-attr] - model = self._annotate_for_static_quantization_config(model) + model = self._annotate_entry(model) else: - model = self._annotate_for_static_quantization_config(model) + model = self._annotate_entry(model) return model - def _annotate_by_single_config( - self, - model: torch.fx.GraphModule, - quantization_config: Optional[QuantizationConfig], - filter_fn: Optional[Callable[[Node], bool]] = None, - ) -> torch.fx.GraphModule: - """Select the annotator functions according to the `quantization_config` and apply.""" - - # For `quantization_config`, skip the annotation, and it won't be annotated by the next stage either. - if quantization_config is None: - return model - - if ( - quantization_config.input_activation - and quantization_config.input_activation.is_dynamic - ): - annotators = DYNAMIC_ANNOTATORS - else: - annotators = STATIC_ANNOTATORS - if quantization_config.is_qat: - # Apply QAT-specific annotators first - for annotator_func in STATIC_QAT_ONLY_ANNOTATORS.values(): - annotator_func(self, model, quantization_config, filter_fn) - - for annotator_func in annotators.values(): - annotator_func(self, model, quantization_config, filter_fn) + def _annotate_by_single_config(self, model, config, filter_fn): + if config is None: + return + self._annotate_all_conv2d_fusion_pattern(model, config, filter_fn) + self._annotate_all_linear_fusion_pattern(model, config, filter_fn) + self._annotate_matmul(model, config, filter_fn) - self._annotate_propagation_quantizable_pattern_entry( - model, quantization_config, filter_fn - ) + self._annotate_propagation_quantizable_pattern_entry(model, config, filter_fn) self._annotate_output_for_int8_in_int8_out_pattern_entry( - model, quantization_config, filter_fn + model, config, filter_fn ) - return model - - def _annotate_quantization_by_module_name_qconfig( + def _annotate_quantization_with_all_qconfig( self, model: torch.fx.GraphModule ) -> torch.fx.GraphModule: module_name_list = list(self.module_name_qconfig.keys()) @@ -761,7 +638,6 @@ def _annotate_quantization_by_module_name_qconfig( qconfig, _get_operator_type_filter(operator_type, module_name_list), ) - log.warning("start to handle global config") if self.global_config: current_stage.is_global = True self._annotate_by_single_config( @@ -771,16 +647,7 @@ def _annotate_quantization_by_module_name_qconfig( ) return model - def _annotate_static_quantization_by_op_type_and_global_config( - self, model: torch.fx.GraphModule - ): - self._annotate_conv2d_fusion_pattern(model) - self._annotate_linear_fusion_pattern(model) - self._annotate_matmul_pattern(model) - - def _annotate_for_static_quantization_config( - self, model: torch.fx.GraphModule - ) -> torch.fx.GraphModule: + def _annotate_entry(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: r""" High-level description of quantization recipe for X86 Inductor Backend: Step 1: Apply quantization recipe for fusion patterns of conv/linear to enable int8 data type actively. @@ -794,10 +661,10 @@ def _annotate_for_static_quantization_config( # Step1: Recipe of fusion patterns like conv/linear. - self._annotate_quantization_by_module_name_qconfig(model) + self._annotate_quantization_with_all_qconfig(model) # if self.module_name_qconfig: - # self._annotate_quantization_by_module_name_qconfig(model) + # self._annotate_quantization_with_all_qconfig(model) # if self.operator_type_qconfig or self.global_config: # self._annotate_static_quantization_by_op_type_and_global_config(model) @@ -817,31 +684,15 @@ def _annotate_for_static_quantization_config( return model - def _annotate_for_dynamic_quantization_config( - self, model: torch.fx.GraphModule - ) -> torch.fx.GraphModule: - if self.module_name_qconfig: - self._annotate_quantization_by_module_name_qconfig(model) - # if self.operator_type_qconfig or self.global_config: - # self._annotate_dynamic_quantization_by_op_type_and_global_config(model) - return model - - def _annotate_dynamic_quantization_by_op_type_and_global_config( - self, model: torch.fx.GraphModule - ) -> torch.fx.GraphModule: - self._annotate_linear_fusion_pattern(model) - return model - - def _annotate_qat_conv2d_fusion_pattern( - self, model: torch.fx.GraphModule, config: QuantizationConfig + def _annotate_all_qat_conv2d_fusion_pattern( + self, model: torch.fx.GraphModule, config: QuantizationConfig, filter_fn ): # Annotate QAT Specific patterns - self._annotate_qat_conv2d_bn_binary_unary(model, config) - self._annotate_qat_conv2d_bn_binary(model, config) - self._annotate_qat_conv2d_bn_unary(model, config) - self._annotate_qat_conv2d_bn(model, config) + self._annotate_qat_conv2d_bn_binary_unary(model, config, filter_fn) + self._annotate_qat_conv2d_bn_binary(model, config, filter_fn) + self._annotate_qat_conv2d_bn_unary(model, config, filter_fn) + self._annotate_qat_conv2d_bn(model, config, filter_fn) - @register_annotator(STATIC_QAT_ONLY_ANNOTATORS) def _annotate_qat_conv2d_bn_binary_unary( self, gm: torch.fx.GraphModule, @@ -913,7 +764,6 @@ def _annotate_qat_conv2d_bn_binary_unary( nodes_to_mark_annotated.extend(list(unary_partition.nodes)) _mark_nodes_as_annotated(nodes_to_mark_annotated) - @register_annotator(STATIC_QAT_ONLY_ANNOTATORS) def _annotate_qat_conv2d_bn_binary( self, gm: torch.fx.GraphModule, @@ -973,7 +823,6 @@ def _annotate_qat_conv2d_bn_binary( nodes_to_mark_annotated.extend(list(binary_partition.nodes)) _mark_nodes_as_annotated(nodes_to_mark_annotated) - @register_annotator(STATIC_QAT_ONLY_ANNOTATORS) def _annotate_qat_conv2d_bn_unary( self, gm: torch.fx.GraphModule, @@ -1025,7 +874,6 @@ def _annotate_qat_conv2d_bn_unary( nodes_to_mark_annotated.extend(list(unary_partition.nodes)) _mark_nodes_as_annotated(nodes_to_mark_annotated) - @register_annotator(STATIC_QAT_ONLY_ANNOTATORS) def _annotate_qat_conv2d_bn( self, gm: torch.fx.GraphModule, @@ -1063,29 +911,27 @@ def _annotate_qat_conv2d_bn( nodes_to_mark_annotated.extend(list(bn_partition.nodes)) _mark_nodes_as_annotated(nodes_to_mark_annotated) - def _annotate_conv2d_fusion_pattern(self, model: torch.fx.GraphModule): - if config := self._get_aten_operator_qconfig(torch.ops.aten.conv2d.default): - if config.is_qat: - # Annotate QAT specific pattern: mainly due to BN not folded in prepare_qat - self._annotate_qat_conv2d_fusion_pattern(model, config) - self._annotate_conv2d_binary_unary(model, config) - self._annotate_conv2d_binary(model, config) - self._annotate_conv2d_unary(model, config) - self._annotate_conv2d(model, config) - - def _annotate_linear_fusion_pattern(self, model: torch.fx.GraphModule): - if config := self._get_aten_operator_qconfig(torch.ops.aten.linear.default): - if config.input_activation and not config.input_activation.is_dynamic: - # Weiwen: Dynamic Quant of linear unary will be supported in next step - self._annotate_linear_binary_unary(model, config) - self._annotate_linear_unary(model, config) - self._annotate_linear(model, config) - - def _annotate_matmul_pattern(self, model: torch.fx.GraphModule): - if config := self._get_aten_operator_qconfig(torch.ops.aten.matmul.default): - self._annotate_matmul(model, config) - - @register_annotator(STATIC_ANNOTATORS) + def _annotate_all_conv2d_fusion_pattern( + self, model: torch.fx.GraphModule, config, filter_fn + ): + # if config := self._get_aten_operator_qconfig(torch.ops.aten.conv2d.default): + if config.is_qat: + # Annotate QAT specific pattern: mainly due to BN not folded in prepare_qat + self._annotate_all_qat_conv2d_fusion_pattern(model, config, filter_fn) + self._annotate_conv2d_binary_unary(model, config, filter_fn) + self._annotate_conv2d_binary(model, config, filter_fn) + self._annotate_conv2d_unary(model, config, filter_fn) + self._annotate_conv2d(model, config, filter_fn) + + def _annotate_all_linear_fusion_pattern( + self, model: torch.fx.GraphModule, config, filter_fn + ): + if config.input_activation and not config.input_activation.is_dynamic: + # Weiwen: Dynamic Quant of linear unary will be supported in next step + self._annotate_linear_binary_unary(model, config, filter_fn) + self._annotate_linear_unary(model, config, filter_fn) + self._annotate_linear(model, config, filter_fn) + def _annotate_matmul( self, model: torch.fx.GraphModule, @@ -1107,7 +953,6 @@ def _annotate_matmul( _is_output_of_quantized_pattern=True, ) - @register_annotator(STATIC_ANNOTATORS) def _annotate_conv2d_binary_unary( self, gm: torch.fx.GraphModule, @@ -1156,7 +1001,6 @@ def _annotate_conv2d_binary_unary( _is_output_of_quantized_pattern=True, ) - @register_annotator(STATIC_ANNOTATORS) def _annotate_conv2d_binary( self, gm: torch.fx.GraphModule, @@ -1203,7 +1047,6 @@ def _annotate_conv2d_binary( _is_output_of_quantized_pattern=True, ) - @register_annotator(STATIC_ANNOTATORS) def _annotate_conv2d_unary( self, gm: torch.fx.GraphModule, @@ -1242,7 +1085,6 @@ def _annotate_conv2d_unary( _is_output_of_quantized_pattern=True, ) - @register_annotator(STATIC_ANNOTATORS) def _annotate_conv2d( self, gm: torch.fx.GraphModule, @@ -1397,11 +1239,6 @@ def _annotate_output_for_int8_in_int8_out_pattern_entry( self, model, quantization_config, filter_fn ): for node in model.graph.nodes: - # if _skip_annotate([node], filter_fn): - # continue - # if quantization_config is None: - # return - self._annotate_output_for_int8_in_int8_out_pattern( node, quantization_config, filter_fn ) @@ -1454,7 +1291,68 @@ def _annotate_output_for_int8_in_int8_out_pattern( self._annotate_output_share_observer_as_input(input_node, node) return - @register_annotator(STATIC_ANNOTATORS) + def _annotate_linear( + self, + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig, + filter_fn: Optional[Callable[[Node], bool]] = None, + ) -> None: + linear_partitions = get_source_partitions( + gm.graph, [torch.nn.Linear, torch.nn.functional.linear] + ) + linear_partitions = list( + itertools.chain.from_iterable(linear_partitions.values()) + ) + for partition in linear_partitions: + log.warning("for partition: %s", partition) + if len(partition.output_nodes) > 1: + raise ValueError( + "Linear partition cannot have more than one output node" + ) + linear_node = partition.output_nodes[0] + if linear_node.op != "call_function" or linear_node.target not in ( + torch.ops.aten.linear.default, + ): + raise ValueError(f"{linear_node} is not an aten linear operator") + # skip annotation if it is already annotated + if _skip_annotate([linear_node], filter_fn): + continue + self._annotate_linear_node_helper(linear_node, True, quantization_config) + + def _annotate_linear_unary( + self, + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig, + filter_fn: Optional[Callable[[Node], bool]] = None, + ) -> None: + postop_list = [ + torch.nn.ReLU, + torch.nn.LeakyReLU, + torch.nn.Tanh, + torch.nn.GELU, + ] + fused_partitions: List[tuple] = [] + for postop in postop_list: + fused_partitions = fused_partitions + find_sequential_partitions( + gm, [torch.nn.Linear, postop] + ) + for fused_partition in fused_partitions: + linear_partition, unary_partition = fused_partition + linear_node, unary_node = self._get_output_nodes_of_partitions( + [linear_partition, unary_partition] + ) + if linear_node.op != "call_function" or linear_node.target not in ( + torch.ops.aten.linear.default, + ): + continue + if _skip_annotate([unary_node, linear_node], filter_fn): + continue + self._annotate_linear_node_helper(linear_node, False, quantization_config) + unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + def _annotate_linear_binary_unary( self, gm: torch.fx.GraphModule, @@ -1516,8 +1414,9 @@ def _annotate_linear_binary_unary( if unary_node is None else [unary_node, binary_node, linear_node] ) - if _is_annotated(node_list): + if _skip_annotate(node_list, filter_fn): continue + self._annotate_linear_node_helper( linear_node, False, quantization_config ) @@ -1537,70 +1436,6 @@ def _annotate_linear_binary_unary( _is_output_of_quantized_pattern=True, ) - @register_annotator(STATIC_ANNOTATORS) - def _annotate_linear_unary( - self, - gm: torch.fx.GraphModule, - quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[Node], bool]] = None, - ) -> None: - postop_list = [ - torch.nn.ReLU, - torch.nn.LeakyReLU, - torch.nn.Tanh, - torch.nn.GELU, - ] - fused_partitions: List[tuple] = [] - for postop in postop_list: - fused_partitions = fused_partitions + find_sequential_partitions( - gm, [torch.nn.Linear, postop] - ) - for fused_partition in fused_partitions: - linear_partition, unary_partition = fused_partition - linear_node, unary_node = self._get_output_nodes_of_partitions( - [linear_partition, unary_partition] - ) - if linear_node.op != "call_function" or linear_node.target not in ( - torch.ops.aten.linear.default, - ): - continue - if _skip_annotate([unary_node, linear_node], filter_fn): - continue - self._annotate_linear_node_helper(linear_node, False, quantization_config) - unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( - _annotated=True, - _is_output_of_quantized_pattern=True, - ) - - @register_annotator([STATIC_ANNOTATORS, DYNAMIC_ANNOTATORS]) - def _annotate_linear( - self, - gm: torch.fx.GraphModule, - quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[Node], bool]] = None, - ) -> None: - linear_partitions = get_source_partitions( - gm.graph, [torch.nn.Linear, torch.nn.functional.linear] - ) - linear_partitions = list( - itertools.chain.from_iterable(linear_partitions.values()) - ) - for partition in linear_partitions: - log.warning("for partition: %s", partition) - if len(partition.output_nodes) > 1: - raise ValueError( - "Linear partition cannot have more than one output node" - ) - linear_node = partition.output_nodes[0] - if linear_node.op != "call_function" or linear_node.target not in ( - torch.ops.aten.linear.default, - ): - raise ValueError(f"{linear_node} is not an aten linear operator") - # skip annotation if it is already annotated - if _skip_annotate([linear_node], filter_fn): - continue - self._annotate_linear_node_helper(linear_node, True, quantization_config) - def validate(self, model: torch.fx.GraphModule) -> None: pass From 06a5e0c2f044c5d4b6d427b7cc7524d804347226 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Fri, 17 May 2024 17:01:50 +0800 Subject: [PATCH 012/706] refine code Signed-off-by: yiliu30 --- .../quantizer/x86_inductor_quantizer.py | 43 +++---------------- 1 file changed, 6 insertions(+), 37 deletions(-) diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index 714cc3a221b9..463a64435cf8 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -139,47 +139,22 @@ def operator_type_filter(n: Node) -> bool: def _x86_get_not_module_type_or_name_filter( tp_list: List[torch._ops.OpOverloadPacket], module_name_list: List[str] ) -> Callable[[Node], bool]: - # Check if the node is 1) belong to the `default_quantizable_ops` and 2) not be marked - # by `set_module_name_qconfig`, or `set_module_type_qconfig` `set_function_type_qconfig`. + # Check if the node is 1) belong to the `default_quantizable_ops` and 2) not be marked by `set_module_name_qconfig`, + # or `set_module_type_qconfig` `set_function_type_qconfig`. + # Only call the `operator_type_filters` is enough, since each filter of `operator_type_filters` will check + # the `module_name_list_filters`. operator_type_filters = [ _get_operator_type_filter(tp, module_name_list) for tp in tp_list ] - module_name_list_filters: List[ - Callable - ] = [] # [_get_module_name_filter(m) for m in module_name_list] def not_module_type_or_name_filter(n: Node) -> bool: # For global_config, only quantize the `default_quantizable_ops` belong_to_default_quantizable_ops = n.target in default_quantizable_ops - # if n.target not in default_quantizable_ops: - # result1 = False - # for f in module_name_list_filters: - # log.warning(f"not module_name for node: {n}, (name: {n.name}, target:{n.target}), f:{f} f(n) is {f(n)}") - - # for f in operator_type_filters: - # log.warning(f"not module type: {n}, (name: {n.name}, target:{n.target}), f:{f} f(n) is {f(n)}") - not_module_type_or_module_name_node = not any( - f(n) for f in operator_type_filters + module_name_list_filters - ) - final_result = ( - belong_to_default_quantizable_ops and not_module_type_or_module_name_node - ) - - # rewrite it not use the f-string - log.warning( - ( - "for node: %s, (name: %s, target:%s), node in default_quantizable_ops? %s;" - "node is not used by operator_type_filters or module_name_list_filters? %s" - ), - n, - n.name, - n.target, - belong_to_default_quantizable_ops, - not_module_type_or_module_name_node, + f(n) for f in operator_type_filters ) - return final_result + return belong_to_default_quantizable_ops and not_module_type_or_module_name_node return not_module_type_or_name_filter @@ -387,7 +362,6 @@ def __init__(self): torch._ops.OpOverloadPacket, Optional[QuantizationConfig] ] = {} self.module_name_qconfig: Dict[str, Optional[QuantizationConfig]] = {} - self._module_type_qconfig = {} @classmethod def get_supported_quantization_configs(cls) -> List[QuantizationConfig]: @@ -428,7 +402,6 @@ def set_function_type_qconfig( function_type: Callable, quantization_config: Optional[QuantizationConfig], ) -> "X86InductorQuantizer": - self._module_type_qconfig[function_type] = quantization_config if function_type in X86InductorQuantizer.module_function_to_aten_operator_type: self._set_aten_operator_qconfig( X86InductorQuantizer.module_function_to_aten_operator_type[ @@ -447,7 +420,6 @@ def set_module_type_qconfig( module_type: torch.nn.Module, quantization_config: Optional[QuantizationConfig], ) -> "X86InductorQuantizer": - self._module_type_qconfig[module_type] = quantization_config if module_type in X86InductorQuantizer.module_function_to_aten_operator_type: self._set_aten_operator_qconfig( X86InductorQuantizer.module_function_to_aten_operator_type[module_type], @@ -466,9 +438,6 @@ def set_module_name_qconfig( quantizer.set_module_name_qconfig("blocks.sub"), it will quantize all supported operator/operator patterns in the submodule with this module name with the given `quantization_config` """ - # assert ( - # quantization_config is not None - # ), " quantization_config == None is not supported yet" self.module_name_qconfig[module_name] = quantization_config return self From b26aedf5293d61a176e9cf9d0c7babaa92e005d7 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Mon, 20 May 2024 08:22:02 +0800 Subject: [PATCH 013/706] add docs Signed-off-by: yiliu30 --- torch/ao/quantization/quantizer/utils.py | 127 +--------------- .../quantizer/x86_inductor_quantizer.py | 137 +++++++++++------- .../quantizer/xnnpack_quantizer.py | 53 ++++++- .../quantizer/xnnpack_quantizer_utils.py | 16 +- 4 files changed, 149 insertions(+), 184 deletions(-) diff --git a/torch/ao/quantization/quantizer/utils.py b/torch/ao/quantization/quantizer/utils.py index 832aad36598d..366b6b06cff0 100644 --- a/torch/ao/quantization/quantizer/utils.py +++ b/torch/ao/quantization/quantizer/utils.py @@ -1,6 +1,6 @@ import logging import os -from typing import Callable, List, Optional +from typing import List from torch.ao.quantization.pt2e.utils import _is_sym_size_node @@ -86,128 +86,3 @@ def _normalize_path(n): return module_name in names return module_name_filter - - -def _get_module_type_filter(tp: Callable): - """Get the module_type_filter function for a given module type, the filter accepts - a node and checks if the node comes from a module that has certain module type - - For example: - node: linear_op = call_function[...](...) # comes from a module with type Block -> Sub -> Linear - - - >> module_type_filter = _get_module_type_filter(Sub) # submodule with type `Sub`, under the `Block` submodule - >> print(module_type_filter(node)) - True # the node is from the submodule `Sub` (same for `Block` and `Linear` as well) - """ - - tp_str = tp.__module__ + "." + tp.__qualname__ - # import pdb; pdb.set_trace() - - def module_type_filter(n: Node) -> bool: - # example: { - # 'L__self___sub': ("L['self'].sub", ), - # 'L__self___sub_linear': ("L['self'].sub.linear", ) - # } - nn_module_stack = n.meta.get("nn_module_stack", {}) - types = [] - for _, t in nn_module_stack.values(): - # export() returns str, but older APIs (e.g. capture_pre_autograd_graph) - # return type. Handle both cases. - if isinstance(t, type): - t = t.__module__ + "." + t.__qualname__ - types.append(t) - - return tp_str in types - - return module_type_filter - - -def _get_not_module_type_or_name_filter( - tp_list: List[Callable], module_name_list: List[str] -) -> Callable[[Node], bool]: - module_type_filters = [_get_module_type_filter(tp) for tp in tp_list] - module_name_list_filters = [_get_module_name_filter(m) for m in module_name_list] - - def not_module_type_or_name_filter(n: Node) -> bool: - return not any(f(n) for f in module_type_filters + module_name_list_filters) - - return not_module_type_or_name_filter - - -def _is_annotated(nodes: List[Node]): - """ - Given a list of nodes (that represents an operator pattern), - check if any of the node is annotated, return True if any of the node - is annotated, otherwise return False - """ - annotated = False - for node in nodes: - annotated = annotated or ( - "quantization_annotation" in node.meta - and node.meta["quantization_annotation"]._annotated - ) - return annotated - - -class CurrentStage: - is_global = False - - -current_stage = CurrentStage() - - -def _skip_annotate( - nodes: List[Node], filter_fn: Optional[Callable[[Node], bool]] = None -): - skip_annotate = False - # 1) Skip annotate if any node is already annotated - if _is_annotated(nodes): - skip_annotate = True - return skip_annotate - # 2) Skip annotate if a) filter_fn is provided and b) any node fails the filter - # filter_fn result - # case 1, - # filter_fn False, False, False - # not filter_fn True, True, True - # any True - # -> skip - - # no node named as user specific - - # case 2, - # filter_fn True, False, False - # not filter_fn False, True, True - # any True - # -> skip - - # some node are not user specific - - # case 3, - # filter_fn True - # not filter_fn False - # any False - # -> not skip - # all node are user specific - if current_stage.is_global and filter_fn is not None: - for node in nodes: - if filter_fn(node): - log.warning("not skip nodes %s", nodes) - return False - if filter_fn and any(not filter_fn(node) for node in nodes): - log.warning("skip nodes %s", nodes) - skip_annotate = True - return skip_annotate - # import pdb; pdb.set_trace() - log.warning("not skip nodes %s", nodes) - return skip_annotate - - -# def dump_function_name_decorator(func): -# def wrapper(*args, **kwargs): -# log.warning(f"entering {func.__name__}") -# result = func(*args, **kwargs) -# log.warning(f"exiting {func.__name__}") -# return result - -# return wrapper diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index 463a64435cf8..a2f4445a21e6 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -36,12 +36,7 @@ Quantizer, SharedQuantizationSpec, ) -from torch.ao.quantization.quantizer.utils import ( - _get_module_name_filter, - _skip_annotate, - current_stage, - log, -) +from torch.ao.quantization.quantizer.utils import _get_module_name_filter, log from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( get_bias_qspec, get_input_act_qspec, @@ -106,6 +101,59 @@ class _X86InductorQuantizationAnnotation(QuantizationAnnotation): QUANT_ANNOTATION_KEY = "quantization_annotation" +class CurrentStage: + is_global = False + + +current_stage = CurrentStage() + + +def _skip_annotate( + nodes: List[Node], filter_fn: Optional[Callable[[Node], bool]] = None +): + skip_annotate = False + # 1) Skip annotate if any node is already annotated + if _is_any_annotated(nodes): + skip_annotate = True + return skip_annotate + # 2) Skip annotate if a) filter_fn is provided and b) any node fails the filter + # filter_fn result + # case 1, + # filter_fn False, False, False + # not filter_fn True, True, True + # any True + # -> skip + + # no node named as user specific + + # case 2, + # filter_fn True, False, False + # not filter_fn False, True, True + # any True + # -> skip + + # some node are not user specific + + # case 3, + # filter_fn True + # not filter_fn False + # any False + # -> not skip + # all node are user specific + if current_stage.is_global and filter_fn is not None: + for node in nodes: + if filter_fn(node): + log.warning("not skip nodes %s", nodes) + return False + if filter_fn and any(not filter_fn(node) for node in nodes): + log.warning("skip nodes %s", nodes) + skip_annotate = True + return skip_annotate + # import pdb; pdb.set_trace() + log.warning("not skip nodes %s", nodes) + return skip_annotate + + def _get_operator_type_filter(operator_type: Callable, module_name_list): """Get the operator_type_filter function for a given operator type, the filter accepts a node and checks if the node has certain module type @@ -136,7 +184,7 @@ def operator_type_filter(n: Node) -> bool: return operator_type_filter -def _x86_get_not_module_type_or_name_filter( +def _get_not_operator_type_or_name_filter( tp_list: List[torch._ops.OpOverloadPacket], module_name_list: List[str] ) -> Callable[[Node], bool]: # Check if the node is 1) belong to the `default_quantizable_ops` and 2) not be marked by `set_module_name_qconfig`, @@ -572,28 +620,13 @@ def _get_input_idx_for_binary_node( return conv_gemm_node_idx, extra_input_node_idx def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: - """just handling global spec for now""" - if self.global_config and self.global_config.input_activation.is_dynamic: # type: ignore[union-attr] - model = self._annotate_entry(model) - else: - model = self._annotate_entry(model) - return model - - def _annotate_by_single_config(self, model, config, filter_fn): - if config is None: - return - self._annotate_all_conv2d_fusion_pattern(model, config, filter_fn) - self._annotate_all_linear_fusion_pattern(model, config, filter_fn) - self._annotate_matmul(model, config, filter_fn) - - self._annotate_propagation_quantizable_pattern_entry(model, config, filter_fn) - self._annotate_output_for_int8_in_int8_out_pattern_entry( - model, config, filter_fn - ) - - def _annotate_quantization_with_all_qconfig( - self, model: torch.fx.GraphModule - ) -> torch.fx.GraphModule: + """ + 1) Annotate each node according the user's qconfig with following order: + `module_name_qconfig`, `module_type_qconfig`, and `global_config`. + 2) Skip nodes already annotated by an earlier stage. For example, + if `linear1` has been annotated in the `module_name_config` stage, + it will not be re-annotated in the `module_type_config` or `global_config` stages. + """ module_name_list = list(self.module_name_qconfig.keys()) for module_name, qconfig in self.module_name_qconfig.items(): @@ -612,12 +645,17 @@ def _annotate_quantization_with_all_qconfig( self._annotate_by_single_config( model, self.global_config, - _x86_get_not_module_type_or_name_filter(tp_list, module_name_list), + _get_not_operator_type_or_name_filter(tp_list, module_name_list), ) return model - def _annotate_entry(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: - r""" + def _annotate_by_single_config( + self, + model: torch.fx.GraphModule, + config: Optional[QuantizationConfig], + filter_fn: Callable, + ) -> None: + """ High-level description of quantization recipe for X86 Inductor Backend: Step 1: Apply quantization recipe for fusion patterns of conv/linear to enable int8 data type actively. Step 2: Propagate quantization annotation for patterns besides conv/linear. Go through the pattern in model @@ -627,31 +665,29 @@ def _annotate_entry(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: such as maxpool2d, which only supports output with int8 data type when the input is with int8 data type, we need to annotate the output of this pattern. """ + if config is None: + return # Step1: Recipe of fusion patterns like conv/linear. - - self._annotate_quantization_with_all_qconfig(model) - - # if self.module_name_qconfig: - # self._annotate_quantization_with_all_qconfig(model) - # if self.operator_type_qconfig or self.global_config: - # self._annotate_static_quantization_by_op_type_and_global_config(model) + self._annotate_all_conv2d_fusion_pattern(model, config, filter_fn) + self._annotate_all_linear_fusion_pattern(model, config, filter_fn) + self._annotate_matmul(model, config, filter_fn) # Step2: Recipe to propagate annotation for patterns beside conv/linear. # Go through all the nodes from start to end. # Recipe refer to https://github.com/intel/intel-extension-for-pytorch/blob/ # 90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_recipe.py#L538 - # for node in model.graph.nodes: - # self._annotate_propagation_quantizable_pattern(node) - # # Step3: For quantizable ops, such as maxpool2d, we need to quantize its output if it is quantized - # # in inputs. So, we can fuse dq-operator-q into a quantized op. - # # Refer to https://github.com/intel/intel-extension-for-pytorch/blob/ - # # 90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_recipe.py#L487 - # for node in model.graph.nodes: - # self._annotate_output_for_int8_in_int8_out_pattern(node) + self._annotate_all_propagation_quantizable_pattern(model, config, filter_fn) - return model + # Step3: For quantizable ops, such as maxpool2d, we need to quantize its output if it is quantized + # in inputs. So, we can fuse dq-operator-q into a quantized op. + # Refer to https://github.com/intel/intel-extension-for-pytorch/blob/ + # 90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_recipe.py#L487 + + self._annotate_output_for_int8_in_int8_out_pattern_entry( + model, config, filter_fn + ) def _annotate_all_qat_conv2d_fusion_pattern( self, model: torch.fx.GraphModule, config: QuantizationConfig, filter_fn @@ -883,7 +919,6 @@ def _annotate_qat_conv2d_bn( def _annotate_all_conv2d_fusion_pattern( self, model: torch.fx.GraphModule, config, filter_fn ): - # if config := self._get_aten_operator_qconfig(torch.ops.aten.conv2d.default): if config.is_qat: # Annotate QAT specific pattern: mainly due to BN not folded in prepare_qat self._annotate_all_qat_conv2d_fusion_pattern(model, config, filter_fn) @@ -1129,12 +1164,10 @@ def _annotate_cat( _is_output_of_quantized_pattern=True, ) - def _annotate_propagation_quantizable_pattern_entry( + def _annotate_all_propagation_quantizable_pattern( self, model, quantization_config, filter_fn ): for node in model.graph.nodes: - # if _skip_annotate([node], filter_fn): - # continue self._annotate_propagation_quantizable_pattern( node, quantization_config, filter_fn ) diff --git a/torch/ao/quantization/quantizer/xnnpack_quantizer.py b/torch/ao/quantization/quantizer/xnnpack_quantizer.py index 4649c3546f05..38232db54cbd 100644 --- a/torch/ao/quantization/quantizer/xnnpack_quantizer.py +++ b/torch/ao/quantization/quantizer/xnnpack_quantizer.py @@ -22,11 +22,7 @@ ) from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer -from torch.ao.quantization.quantizer.utils import ( - _get_module_name_filter, - _get_module_type_filter, - _get_not_module_type_or_name_filter, -) +from torch.ao.quantization.quantizer.utils import _get_module_name_filter from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( _convert_scalars_to_attrs, @@ -197,6 +193,53 @@ def _get_supported_config_and_operators() -> List[OperatorConfig]: return _get_supported_symmetric_config_and_operators() +def _get_module_type_filter(tp: Callable): + """Get the module_type_filter function for a given module type, the filter accepts + a node and checks if the node comes from a module that has certain module type + + For example: + node: linear_op = call_function[...](...) # comes from a module with type Block -> Sub -> Linear + + + >> module_type_filter = _get_module_type_filter(Sub) # submodule with type `Sub`, under the `Block` submodule + >> print(module_type_filter(node)) + True # the node is from the submodule `Sub` (same for `Block` and `Linear` as well) + """ + + tp_str = tp.__module__ + "." + tp.__qualname__ + # import pdb; pdb.set_trace() + + def module_type_filter(n: Node) -> bool: + # example: { + # 'L__self___sub': ("L['self'].sub", ), + # 'L__self___sub_linear': ("L['self'].sub.linear", ) + # } + nn_module_stack = n.meta.get("nn_module_stack", {}) + types = [] + for _, t in nn_module_stack.values(): + # export() returns str, but older APIs (e.g. capture_pre_autograd_graph) + # return type. Handle both cases. + if isinstance(t, type): + t = t.__module__ + "." + t.__qualname__ + types.append(t) + + return tp_str in types + + return module_type_filter + + +def _get_not_module_type_or_name_filter( + tp_list: List[Callable], module_name_list: List[str] +) -> Callable[[Node], bool]: + module_type_filters = [_get_module_type_filter(tp) for tp in tp_list] + module_name_list_filters = [_get_module_name_filter(m) for m in module_name_list] + + def not_module_type_or_name_filter(n: Node) -> bool: + return not any(f(n) for f in module_type_filters + module_name_list_filters) + + return not_module_type_or_name_filter + + class XNNPACKQuantizer(Quantizer): supported_config_and_operators = _get_supported_config_and_operators() STATIC_QAT_ONLY_OPS = [ diff --git a/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py b/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py index d5595a136990..9f1732e57370 100644 --- a/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py +++ b/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py @@ -26,7 +26,6 @@ from torch.ao.quantization.quantizer.utils import ( _annotate_input_qspec_map, _annotate_output_qspec, - _is_annotated, ) from torch.fx import Node from torch.fx.passes.utils.matcher_with_name_node_map_utils import ( @@ -95,6 +94,21 @@ class OperatorConfig(NamedTuple): operators: List[OperatorPatternType] +def _is_annotated(nodes: List[Node]): + """ + Given a list of nodes (that represents an operator pattern), + check if any of the node is annotated, return True if any of the node + is annotated, otherwise return False + """ + annotated = False + for node in nodes: + annotated = annotated or ( + "quantization_annotation" in node.meta + and node.meta["quantization_annotation"]._annotated + ) + return annotated + + def _mark_nodes_as_annotated(nodes: List[Node]): for node in nodes: if node is not None: From a01340a39652b135bf3f9c70ac668d5f80e82e90 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Mon, 20 May 2024 08:26:19 +0800 Subject: [PATCH 014/706] clean code Signed-off-by: yiliu30 --- torch/ao/quantization/quantizer/utils.py | 5 ----- .../quantizer/x86_inductor_quantizer.py | 15 +-------------- .../quantization/quantizer/xnnpack_quantizer.py | 2 -- 3 files changed, 1 insertion(+), 21 deletions(-) diff --git a/torch/ao/quantization/quantizer/utils.py b/torch/ao/quantization/quantizer/utils.py index 366b6b06cff0..77cfc22d73be 100644 --- a/torch/ao/quantization/quantizer/utils.py +++ b/torch/ao/quantization/quantizer/utils.py @@ -1,5 +1,3 @@ -import logging -import os from typing import List from torch.ao.quantization.pt2e.utils import _is_sym_size_node @@ -7,9 +5,6 @@ from torch.ao.quantization.quantizer.quantizer import QuantizationAnnotation from torch.fx import Node -log = logging.getLogger(__name__) -log.setLevel(os.environ.get("LOGLEVEL", "ERROR")) - def _annotate_input_qspec_map(node: Node, input_node: Node, qspec): quantization_annotation = node.meta.get( diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index a2f4445a21e6..f09f496dc5b8 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -36,7 +36,7 @@ Quantizer, SharedQuantizationSpec, ) -from torch.ao.quantization.quantizer.utils import _get_module_name_filter, log +from torch.ao.quantization.quantizer.utils import _get_module_name_filter from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( get_bias_qspec, get_input_act_qspec, @@ -143,14 +143,10 @@ def _skip_annotate( if current_stage.is_global and filter_fn is not None: for node in nodes: if filter_fn(node): - log.warning("not skip nodes %s", nodes) return False if filter_fn and any(not filter_fn(node) for node in nodes): - log.warning("skip nodes %s", nodes) skip_annotate = True return skip_annotate - # import pdb; pdb.set_trace() - log.warning("not skip nodes %s", nodes) return skip_annotate @@ -171,14 +167,6 @@ def _get_operator_type_filter(operator_type: Callable, module_name_list): def operator_type_filter(n: Node) -> bool: not_module_name_node = not any(f(n) for f in module_name_list_filters) has_certain_operator_type = n.target == operator_type - log.warning( - "!!for node: %s, (name: %s, target:%s) n.target equal operator_type (%s))? %s", - n, - n.name, - n.target, - operator_type, - has_certain_operator_type, - ) return not_module_name_node and has_certain_operator_type return operator_type_filter @@ -1306,7 +1294,6 @@ def _annotate_linear( itertools.chain.from_iterable(linear_partitions.values()) ) for partition in linear_partitions: - log.warning("for partition: %s", partition) if len(partition.output_nodes) > 1: raise ValueError( "Linear partition cannot have more than one output node" diff --git a/torch/ao/quantization/quantizer/xnnpack_quantizer.py b/torch/ao/quantization/quantizer/xnnpack_quantizer.py index 38232db54cbd..e13a79f39267 100644 --- a/torch/ao/quantization/quantizer/xnnpack_quantizer.py +++ b/torch/ao/quantization/quantizer/xnnpack_quantizer.py @@ -207,7 +207,6 @@ def _get_module_type_filter(tp: Callable): """ tp_str = tp.__module__ + "." + tp.__qualname__ - # import pdb; pdb.set_trace() def module_type_filter(n: Node) -> bool: # example: { @@ -222,7 +221,6 @@ def module_type_filter(n: Node) -> bool: if isinstance(t, type): t = t.__module__ + "." + t.__qualname__ types.append(t) - return tp_str in types return module_type_filter From 912099adc83e71f98bd2edbc9adf48d095ba41b4 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Mon, 20 May 2024 08:39:03 +0800 Subject: [PATCH 015/706] remove useless code Signed-off-by: yiliu30 --- .../quantizer/x86_inductor_quantizer.py | 21 +++++-------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index f09f496dc5b8..d7f31dee2825 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -101,13 +101,6 @@ class _X86InductorQuantizationAnnotation(QuantizationAnnotation): QUANT_ANNOTATION_KEY = "quantization_annotation" -class CurrentStage: - is_global = False - - -current_stage = CurrentStage() - - def _skip_annotate( nodes: List[Node], filter_fn: Optional[Callable[[Node], bool]] = None ): @@ -140,14 +133,11 @@ def _skip_annotate( # any False # -> not skip # all node are user specific - if current_stage.is_global and filter_fn is not None: - for node in nodes: - if filter_fn(node): - return False - if filter_fn and any(not filter_fn(node) for node in nodes): - skip_annotate = True - return skip_annotate - return skip_annotate + + if filter_fn and any(filter_fn(node) for node in nodes): + return False + + return True def _get_operator_type_filter(operator_type: Callable, module_name_list): @@ -629,7 +619,6 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: _get_operator_type_filter(operator_type, module_name_list), ) if self.global_config: - current_stage.is_global = True self._annotate_by_single_config( model, self.global_config, From 4f4bab0d57dae31b612f45f10ca7801e5732dc84 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Mon, 20 May 2024 09:50:05 +0800 Subject: [PATCH 016/706] clean code Signed-off-by: yiliu30 --- .../quantizer/x86_inductor_quantizer.py | 76 +++++++------------ 1 file changed, 27 insertions(+), 49 deletions(-) diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index d7f31dee2825..2534a0b9d28d 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -107,33 +107,8 @@ def _skip_annotate( skip_annotate = False # 1) Skip annotate if any node is already annotated if _is_any_annotated(nodes): - skip_annotate = True - return skip_annotate - # 2) Skip annotate if a) filter_fn is provided and b) any node fails the filter - # filter_fn result - # case 1, - # filter_fn False, False, False - # not filter_fn True, True, True - # any True - # -> skip - - # no node named as user specific - - # case 2, - # filter_fn True, False, False - # not filter_fn False, True, True - # any True - # -> skip - - # some node are not user specific - - # case 3, - # filter_fn True - # not filter_fn False - # any False - # -> not skip - # all node are user specific - + return True + # 2) Not skip annotate if a) filter_fn is provided and b) any node passed the filter if filter_fn and any(filter_fn(node) for node in nodes): return False @@ -142,7 +117,7 @@ def _skip_annotate( def _get_operator_type_filter(operator_type: Callable, module_name_list): """Get the operator_type_filter function for a given operator type, the filter accepts - a node and checks if the node has certain module type + a node and checks if the node has certain operator type. For example: node: linear_op = call_function[...](...) # linear_op.target if torch.ops.aten.linear.default @@ -480,15 +455,6 @@ def _set_aten_operator_qconfig( ) return self - def _get_aten_operator_qconfig( - self, - operator_type: torch._ops.OpOverloadPacket, - ) -> Optional[QuantizationConfig]: - if operator_type in self.operator_type_qconfig: - assert operator_type in quantizable_ops - return self.operator_type_qconfig[operator_type] - return self.global_config if operator_type in default_quantizable_ops else None - def _annotate_conv_node_helper( self, conv_node: torch.fx.Node, @@ -597,6 +563,12 @@ def _get_input_idx_for_binary_node( assert isinstance(extra_input_node, Node) return conv_gemm_node_idx, extra_input_node_idx + def _check_qconfig(self) -> None: + """Check if the qconfig is valid. + Currently, not support mixed static and dynamic quantization config.""" + # TODO: + pass + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: """ 1) Annotate each node according the user's qconfig with following order: @@ -605,6 +577,9 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: if `linear1` has been annotated in the `module_name_config` stage, it will not be re-annotated in the `module_type_config` or `global_config` stages. """ + + self._check_qconfig() + module_name_list = list(self.module_name_qconfig.keys()) for module_name, qconfig in self.module_name_qconfig.items(): @@ -1091,7 +1066,10 @@ def _annotate_conv2d( self._annotate_conv_node_helper(conv_node, True, quantization_config) def _annotate_maxpool2d( - self, node: Node, quantization_config: QuantizationConfig, filter_fn + self, + node: Node, + quantization_config: QuantizationConfig, + filter_fn: Optional[Callable] = None, ) -> None: if node.target is not torch.ops.aten.max_pool2d.default: return @@ -1157,7 +1135,6 @@ def _annotate_propagation_quantizable_pattern( (node.target in propagation_quantizable_ops) and (not _is_any_annotated([node])) and (node.op == "call_function") - # and (quantization_config := self._get_aten_operator_qconfig(node.target)) # type: ignore[arg-type] ): def is_all_inputs_connected_to_quantized_op(input_nodes): @@ -1215,7 +1192,10 @@ def _annotate_output_share_observer_as_input( return def _annotate_output_for_int8_in_int8_out_pattern_entry( - self, model, quantization_config, filter_fn + self, + model: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig] = None, + filter_fn: Optional[Callable] = None, ): for node in model.graph.nodes: self._annotate_output_for_int8_in_int8_out_pattern( @@ -1223,7 +1203,10 @@ def _annotate_output_for_int8_in_int8_out_pattern_entry( ) def _annotate_output_for_int8_in_int8_out_pattern( - self, node: Node, quantization_config, filter_fn + self, + node: Node, + quantization_config: Optional[QuantizationConfig] = None, + filter_fn: Optional[Callable] = None, ) -> None: r""" Check and insert observer at output of node in int8_in_int8_out_ops if needed. @@ -1231,11 +1214,7 @@ def _annotate_output_for_int8_in_int8_out_pattern( 90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_utils.py#L495 """ edge_or_node: Tuple[Node, Node] - if ( - (node.target in int8_in_int8_out_ops) - and (_is_any_annotated([node])) - and (quantization_config := self._get_aten_operator_qconfig(node.target)) # type: ignore[arg-type] - ): + if (node.target in int8_in_int8_out_ops) and (_is_any_annotated([node])): if node.target == torch.ops.aten.max_pool2d.default: maxpool_node = node if not _is_all_annotated( @@ -1244,9 +1223,8 @@ def _annotate_output_for_int8_in_int8_out_pattern( ] ): return - # !!! Do not skip this, this annotator annotate a annotated node???? - # if _skip_annotate([maxpool_node], filter_fn): - # return + # Don't check the `filter_fn` here, as we want to annotate + # the output of the node that's being annotated. # Get the quantization_annotation from getitem_node maxpool_node_quantization_annotation = ( maxpool_node.meta[QUANT_ANNOTATION_KEY] From d8c38b3def341236423a02c0eee0c511a446675b Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Mon, 20 May 2024 10:39:31 +0800 Subject: [PATCH 017/706] add config checker Signed-off-by: yiliu30 --- .../quantizer/x86_inductor_quantizer.py | 62 ++++++++++++++++--- 1 file changed, 54 insertions(+), 8 deletions(-) diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index 2534a0b9d28d..d7d6b3d73359 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -104,7 +104,6 @@ class _X86InductorQuantizationAnnotation(QuantizationAnnotation): def _skip_annotate( nodes: List[Node], filter_fn: Optional[Callable[[Node], bool]] = None ): - skip_annotate = False # 1) Skip annotate if any node is already annotated if _is_any_annotated(nodes): return True @@ -116,8 +115,10 @@ def _skip_annotate( def _get_operator_type_filter(operator_type: Callable, module_name_list): - """Get the operator_type_filter function for a given operator type, the filter accepts - a node and checks if the node has certain operator type. + """Get the operator_type_filter function for a given operator type and module name list. + + The filter accept a node and checks if 1) the node has certain operator type, + and 2) the node does not marked by `set_module_name_qconfig`. For example: node: linear_op = call_function[...](...) # linear_op.target if torch.ops.aten.linear.default @@ -358,7 +359,7 @@ class X86InductorQuantizer(Quantizer): def __init__(self): super().__init__() - self.global_config: QuantizationConfig = None # type: ignore[assignment] + self.global_config: Optional[QuantizationConfig] = None # type: ignore[assignment] self.operator_type_qconfig: Dict[ torch._ops.OpOverloadPacket, Optional[QuantizationConfig] ] = {} @@ -565,9 +566,52 @@ def _get_input_idx_for_binary_node( def _check_qconfig(self) -> None: """Check if the qconfig is valid. - Currently, not support mixed static and dynamic quantization config.""" - # TODO: - pass + + Currently, not support mixed static and dynamic quantization config. + If the qconfig is mixed, the subsequent configuration will be skipped. + """ + + def _need_skip_cur_config( + qconfig: Optional[QuantizationConfig], _pre_mode: Optional[bool] + ): + input_act_config = getattr(qconfig, "input_activation", None) + if input_act_config: + qconfig_is_dynamic = input_act_config.is_dynamic + if _pre_mode is not None and _pre_mode != qconfig_is_dynamic: + warnings.warn( + "Mixed dynamic and static quantization config is not supported. \ + The subsequent configuration will be skipped." + ) + return _pre_mode, True + else: + if _pre_mode is None: + _pre_mode = qconfig_is_dynamic + return _pre_mode, False + + _pre_mode = None + + tmp_module_name_qconfig: Dict[str, Optional[QuantizationConfig]] = {} + for module_name, qconfig in self.module_name_qconfig.items(): + _pre_mode, need_skip = _need_skip_cur_config(qconfig, _pre_mode) + if not need_skip: + tmp_module_name_qconfig[module_name] = qconfig + self.module_name_qconfig = tmp_module_name_qconfig + + tmp_operator_type_qconfig: Dict[ + torch._ops.OpOverloadPacket, Optional[QuantizationConfig] + ] = {} + for operator_type, qconfig in self.operator_type_qconfig.items(): + _pre_mode, need_skip = _need_skip_cur_config(qconfig, _pre_mode) + if not need_skip: + tmp_operator_type_qconfig[operator_type] = qconfig + self.operator_type_qconfig = tmp_operator_type_qconfig + + if self.global_config: + _pre_mode, need_skip = _need_skip_cur_config(self.global_config, _pre_mode) + if not need_skip: + self.global_config = self.global_config + else: + self.global_config = None def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: """ @@ -581,11 +625,11 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: self._check_qconfig() module_name_list = list(self.module_name_qconfig.keys()) - for module_name, qconfig in self.module_name_qconfig.items(): self._annotate_by_single_config( model, qconfig, _get_module_name_filter(module_name) ) + tp_list = list(self.operator_type_qconfig.keys()) for operator_type, qconfig in self.operator_type_qconfig.items(): self._annotate_by_single_config( @@ -593,12 +637,14 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: qconfig, _get_operator_type_filter(operator_type, module_name_list), ) + if self.global_config: self._annotate_by_single_config( model, self.global_config, _get_not_operator_type_or_name_filter(tp_list, module_name_list), ) + return model def _annotate_by_single_config( From 49deed2e5a3a2a070bf4503a9cb62bcf39038c06 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Mon, 20 May 2024 11:35:02 +0800 Subject: [PATCH 018/706] add more UTs Signed-off-by: yiliu30 --- .../pt2e/test_x86inductor_quantizer.py | 53 +++++++++++++++++++ .../quantizer/x86_inductor_quantizer.py | 22 +++++--- 2 files changed, 67 insertions(+), 8 deletions(-) diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py index 5ff7a261bd8e..8fdf3683a049 100644 --- a/test/quantization/pt2e/test_x86inductor_quantizer.py +++ b/test/quantization/pt2e/test_x86inductor_quantizer.py @@ -2124,6 +2124,59 @@ def test_set_module_name_qconfig_for_dynamic_quant(self): is_qat=True, ) + @skipIfNoX86 + def test_set_mixed_static_and_dynamic(self): + """Test that mixed static and dynamic quantization for a module.""" + + with override_quantized_engine("x86"), torch.no_grad(): + m = TestHelperModules.SelfAttnLikeModule(input_dim=64).eval() + example_inputs = (torch.randn(1, 4, 64),) + # quantize `self.q_proj` + static_config = xiq.get_default_x86_inductor_quantization_config( + is_dynamic=False + ) + dynamic_config = xiq.get_default_x86_inductor_quantization_config( + is_dynamic=True + ) + # quantize `self.v_proj` with static config + # quantize `self.q_proj` with dynamic config (will be skipped) + quantizer = ( + X86InductorQuantizer() + .set_module_name_qconfig("q_proj", static_config) + .set_module_name_qconfig("v_proj", dynamic_config) + ) + node_occurrence = { + # ops for quantizing/de-quantizing input + torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, + # only q_proj be quantized + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + node_list = [ + # ops for quantizing/de-quantizing input + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + # op for de-quantizing `q_proj`'s weight + torch.ops.quantized_decomposed.dequantize_per_channel.default, + # q_proj + torch.ops.aten.linear.default, + # k_proj + torch.ops.aten.linear.default, + # `v_proj`'s weight will not be quantized + # torch.ops.quantized_decomposed.dequantize_per_channel.default, + # v_proj + torch.ops.aten.linear.default, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + is_qat=True, + debug=True, + ) + @skipIfNoX86 def test_filter_conv2d_recipe(self): """ diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index d7d6b3d73359..07059c701a11 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -150,7 +150,7 @@ def _get_not_operator_type_or_name_filter( _get_operator_type_filter(tp, module_name_list) for tp in tp_list ] - def not_module_type_or_name_filter(n: Node) -> bool: + def not_operator_type_or_name_filter(n: Node) -> bool: # For global_config, only quantize the `default_quantizable_ops` belong_to_default_quantizable_ops = n.target in default_quantizable_ops not_module_type_or_module_name_node = not any( @@ -158,7 +158,7 @@ def not_module_type_or_name_filter(n: Node) -> bool: ) return belong_to_default_quantizable_ops and not_module_type_or_module_name_node - return not_module_type_or_name_filter + return not_operator_type_or_name_filter def _map_module_function_to_aten_operator_type(): @@ -572,15 +572,15 @@ def _check_qconfig(self) -> None: """ def _need_skip_cur_config( - qconfig: Optional[QuantizationConfig], _pre_mode: Optional[bool] + qconfig: Optional[QuantizationConfig], _pre_mode: Optional[bool], msg: str ): input_act_config = getattr(qconfig, "input_activation", None) if input_act_config: qconfig_is_dynamic = input_act_config.is_dynamic if _pre_mode is not None and _pre_mode != qconfig_is_dynamic: warnings.warn( - "Mixed dynamic and static quantization config is not supported. \ - The subsequent configuration will be skipped." + "Mixed dynamic and static quantization config is not supported." + f"The configuration for {msg} will be skipped." ) return _pre_mode, True else: @@ -592,7 +592,9 @@ def _need_skip_cur_config( tmp_module_name_qconfig: Dict[str, Optional[QuantizationConfig]] = {} for module_name, qconfig in self.module_name_qconfig.items(): - _pre_mode, need_skip = _need_skip_cur_config(qconfig, _pre_mode) + _pre_mode, need_skip = _need_skip_cur_config( + qconfig, _pre_mode, module_name + ) if not need_skip: tmp_module_name_qconfig[module_name] = qconfig self.module_name_qconfig = tmp_module_name_qconfig @@ -601,13 +603,17 @@ def _need_skip_cur_config( torch._ops.OpOverloadPacket, Optional[QuantizationConfig] ] = {} for operator_type, qconfig in self.operator_type_qconfig.items(): - _pre_mode, need_skip = _need_skip_cur_config(qconfig, _pre_mode) + _pre_mode, need_skip = _need_skip_cur_config( + qconfig, _pre_mode, str(operator_type) + ) if not need_skip: tmp_operator_type_qconfig[operator_type] = qconfig self.operator_type_qconfig = tmp_operator_type_qconfig if self.global_config: - _pre_mode, need_skip = _need_skip_cur_config(self.global_config, _pre_mode) + _pre_mode, need_skip = _need_skip_cur_config( + self.global_config, _pre_mode, "global" + ) if not need_skip: self.global_config = self.global_config else: From 9d626909d984301ca4d9d2298ec70d417a98fe1f Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Mon, 20 May 2024 12:00:56 +0800 Subject: [PATCH 019/706] refine docs Signed-off-by: yiliu30 --- test/quantization/pt2e/test_x86inductor_quantizer.py | 9 +++++---- .../ao/quantization/quantizer/x86_inductor_quantizer.py | 5 +++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py index 8fdf3683a049..de83d5f916a7 100644 --- a/test/quantization/pt2e/test_x86inductor_quantizer.py +++ b/test/quantization/pt2e/test_x86inductor_quantizer.py @@ -1918,7 +1918,7 @@ def forward(self, x): @skipIfNoX86 def test_set_module_name_and_set_module_type_case2(self): - """ + """Test that set `module_name_qconfig` and `module_type_qconfig` at the same time. All linear are quantized except the second one. """ @@ -1972,7 +1972,8 @@ def forward(self, x): @skipIfNoX86 def test_set_module_name_and_set_module_type(self): - """ + """Test that set `module_name_qconfig` and `module_type_qconfig` at the same time. + All linear are not quantized except the second one. """ @@ -2138,8 +2139,8 @@ def test_set_mixed_static_and_dynamic(self): dynamic_config = xiq.get_default_x86_inductor_quantization_config( is_dynamic=True ) - # quantize `self.v_proj` with static config - # quantize `self.q_proj` with dynamic config (will be skipped) + # set `self.v_proj` with static config + # set `self.q_proj` with dynamic config (will be skipped) quantizer = ( X86InductorQuantizer() .set_module_name_qconfig("q_proj", static_config) diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index 07059c701a11..a9d9592f4e49 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -620,8 +620,9 @@ def _need_skip_cur_config( self.global_config = None def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: - """ - 1) Annotate each node according the user's qconfig with following order: + """Annotate the model with quantization configuration. + + 1) Annotate each node according users's qconfig with following order: `module_name_qconfig`, `module_type_qconfig`, and `global_config`. 2) Skip nodes already annotated by an earlier stage. For example, if `linear1` has been annotated in the `module_name_config` stage, From 9946b48b6b64f55cfafd5b4aa91ab22d166dc5fc Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Tue, 21 May 2024 10:13:31 +0800 Subject: [PATCH 020/706] refine UTs Signed-off-by: yiliu30 --- .../pt2e/test_x86inductor_quantizer.py | 62 ++++++++-------- .../quantizer/x86_inductor_quantizer.py | 70 ++++++++++--------- 2 files changed, 69 insertions(+), 63 deletions(-) diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py index de83d5f916a7..e618647555b4 100644 --- a/test/quantization/pt2e/test_x86inductor_quantizer.py +++ b/test/quantization/pt2e/test_x86inductor_quantizer.py @@ -1911,7 +1911,6 @@ def forward(self, x): # second linear is quantized torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, - torch.ops.quantized_decomposed.dequantize_per_channel.default, torch.ops.aten.linear.default, ] self._test_quantizer(m, example_inputs, quantizer, node_occurrence, node_list) @@ -1920,7 +1919,7 @@ def forward(self, x): def test_set_module_name_and_set_module_type_case2(self): """Test that set `module_name_qconfig` and `module_type_qconfig` at the same time. - All linear are quantized except the second one. + All linear are quantized except the last one. """ class Sub(torch.nn.Module): @@ -1934,11 +1933,13 @@ def forward(self, x): class M(torch.nn.Module): def __init__(self): super().__init__() - self.linear = torch.nn.Linear(5, 5) + self.linear1 = torch.nn.Linear(5, 10) + self.linear2 = torch.nn.Linear(10, 5) self.sub = Sub() def forward(self, x): - x = self.linear(x) + x = self.linear1(x) + x = self.linear2(x) x = self.sub(x) return x @@ -1951,19 +1952,19 @@ def forward(self, x): ) node_occurrence = { - torch.ops.aten.linear.default: 2, - # input and output for the first linear - torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, - torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, - torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + torch.ops.aten.linear.default: 3, + # quantize the input and output of the first linear + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, } node_list = [ - # first linear is quantized + # first and second linear are quantized torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, - torch.ops.quantized_decomposed.dequantize_per_channel.default, torch.ops.aten.linear.default, - # second linear is not quantized + torch.ops.aten.linear.default, + # last linear is not quantized torch.ops.aten.linear.default, ] self._test_quantizer( @@ -1971,10 +1972,10 @@ def forward(self, x): ) @skipIfNoX86 - def test_set_module_name_and_set_module_type(self): + def test_set_module_name_and_set_module_type_case1(self): """Test that set `module_name_qconfig` and `module_type_qconfig` at the same time. - All linear are not quantized except the second one. + All linear are not quantized except the last one. """ class Sub(torch.nn.Module): @@ -1988,11 +1989,13 @@ def forward(self, x): class M(torch.nn.Module): def __init__(self): super().__init__() - self.linear = torch.nn.Linear(5, 5) + self.linear1 = torch.nn.Linear(5, 10) + self.linear2 = torch.nn.Linear(10, 5) self.sub = Sub() def forward(self, x): - x = self.linear(x) + x = self.linear1(x) + x = self.linear2(x) x = self.sub(x) return x @@ -2005,19 +2008,19 @@ def forward(self, x): ).set_module_type_qconfig(torch.nn.Linear, None) node_occurrence = { - torch.ops.aten.linear.default: 2, - # input and output for the second linear + torch.ops.aten.linear.default: 3, + # quantize the input of the last linear torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, } node_list = [ - # first linear is not quantized + # first and second linear is not quantized torch.ops.aten.linear.default, - # second linear is quantized + torch.ops.aten.linear.default, + # last linear is quantized torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, - torch.ops.quantized_decomposed.dequantize_per_channel.default, torch.ops.aten.linear.default, ] self._test_quantizer( @@ -2031,8 +2034,7 @@ def test_set_module_name_qconfig_with_underscores(self) -> None: class M(torch.nn.Module): def __init__(self): super().__init__() - # This module name has underscores, which can be part of a mangled - # name. + # This module name has underscores, which can be part of a mangled name. self.foo_bar = torch.nn.Linear(2, 2) self.baz = torch.nn.Linear(2, 2) @@ -2097,7 +2099,7 @@ def test_set_module_name_qconfig_for_dynamic_quant(self): torch.ops.quantized_decomposed.choose_qparams.tensor: 1, torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1, torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1, - # each for q_proj and v_proj + # ops for dequantize the weight of q_proj and v_proj torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, } node_list = [ @@ -2105,14 +2107,14 @@ def test_set_module_name_qconfig_for_dynamic_quant(self): torch.ops.quantized_decomposed.choose_qparams.tensor, torch.ops.quantized_decomposed.quantize_per_tensor.tensor, torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, - # op for de-quantizing `q_proj`'s weight - torch.ops.quantized_decomposed.dequantize_per_channel.default, + # # op for de-quantizing `q_proj`'s weight, disable this check + # torch.ops.quantized_decomposed.dequantize_per_channel.default, # q_proj torch.ops.aten.linear.default, # k_proj torch.ops.aten.linear.default, - # op for de-quantizing `v_proj`'s weight - torch.ops.quantized_decomposed.dequantize_per_channel.default, + # # op for de-quantizing `v_proj`'s weight, disable this check + # torch.ops.quantized_decomposed.dequantize_per_channel.default, # v_proj torch.ops.aten.linear.default, ] @@ -2157,8 +2159,8 @@ def test_set_mixed_static_and_dynamic(self): # ops for quantizing/de-quantizing input torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, - # op for de-quantizing `q_proj`'s weight - torch.ops.quantized_decomposed.dequantize_per_channel.default, + # op for de-quantizing `q_proj`'s weight, disable this check + # torch.ops.quantized_decomposed.dequantize_per_channel.default, # q_proj torch.ops.aten.linear.default, # k_proj diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index a9d9592f4e49..751db41df9c3 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -107,6 +107,7 @@ def _skip_annotate( # 1) Skip annotate if any node is already annotated if _is_any_annotated(nodes): return True + # 2) Not skip annotate if a) filter_fn is provided and b) any node passed the filter if filter_fn and any(filter_fn(node) for node in nodes): return False @@ -121,12 +122,14 @@ def _get_operator_type_filter(operator_type: Callable, module_name_list): and 2) the node does not marked by `set_module_name_qconfig`. For example: - node: linear_op = call_function[...](...) # linear_op.target if torch.ops.aten.linear.default + # linear_op.target if torch.ops.aten.linear.default + # linear_op is traced from `fc3` + node: linear_op = call_function[...](...) - >> operator_type_filter = _get_operator_type_filter(torch.ops.aten.linear.default) + >> operator_type_filter = _get_operator_type_filter(torch.ops.aten.linear.default, ["fc1", "fc2"]) >> print(operator_type_filter(node)) - True # the node's target is `torch.ops.aten.linear.default` + True # the node's target is `torch.ops.aten.linear.default` and not marked by `set_module_name_qconfig` """ module_name_list_filters = [_get_module_name_filter(m) for m in module_name_list] @@ -141,22 +144,26 @@ def operator_type_filter(n: Node) -> bool: def _get_not_operator_type_or_name_filter( tp_list: List[torch._ops.OpOverloadPacket], module_name_list: List[str] ) -> Callable[[Node], bool]: - # Check if the node is 1) belong to the `default_quantizable_ops` and 2) not be marked by `set_module_name_qconfig`, - # or `set_module_type_qconfig` `set_function_type_qconfig`. + """Get the not_operator_type_or_name_filter function for a given operator type list and module name list. + + The filter accept a node and checks if 1) the node does not marked by `set_module_name_qconfig`, + or `set_module_type_qconfig` `set_function_type_qconfig`, and 2) the node's type + is belong to `default_quantizable_ops`. + """ - # Only call the `operator_type_filters` is enough, since each filter of `operator_type_filters` will check - # the `module_name_list_filters`. + # Only call the `operator_type_filters` is enough, since each filter of `operator_type_filters` + # will check the `module_name_list_filters`. operator_type_filters = [ _get_operator_type_filter(tp, module_name_list) for tp in tp_list ] def not_operator_type_or_name_filter(n: Node) -> bool: # For global_config, only quantize the `default_quantizable_ops` - belong_to_default_quantizable_ops = n.target in default_quantizable_ops - not_module_type_or_module_name_node = not any( + is_default_quantizable_op = n.target in default_quantizable_ops + not_operator_type_or_module_name_node = not any( f(n) for f in operator_type_filters ) - return belong_to_default_quantizable_ops and not_module_type_or_module_name_node + return is_default_quantizable_op and not_operator_type_or_module_name_node return not_operator_type_or_name_filter @@ -439,6 +446,8 @@ def set_module_name_qconfig( """Set quantization_config for a submodule with name: `module_name`, for example: quantizer.set_module_name_qconfig("blocks.sub"), it will quantize all supported operator/operator patterns in the submodule with this module name with the given `quantization_config` + + The supported operators include `quantizable_ops` and `propagation_quantizable_ops`. """ self.module_name_qconfig[module_name] = quantization_config return self @@ -568,13 +577,14 @@ def _check_qconfig(self) -> None: """Check if the qconfig is valid. Currently, not support mixed static and dynamic quantization config. - If the qconfig is mixed, the subsequent configuration will be skipped. + If the mixture is detected, the subsequent configuration will be skipped. """ def _need_skip_cur_config( qconfig: Optional[QuantizationConfig], _pre_mode: Optional[bool], msg: str - ): + ) -> Tuple[Optional[bool], bool]: input_act_config = getattr(qconfig, "input_activation", None) + need_skip = False if input_act_config: qconfig_is_dynamic = input_act_config.is_dynamic if _pre_mode is not None and _pre_mode != qconfig_is_dynamic: @@ -582,11 +592,11 @@ def _need_skip_cur_config( "Mixed dynamic and static quantization config is not supported." f"The configuration for {msg} will be skipped." ) - return _pre_mode, True + need_skip = True else: if _pre_mode is None: _pre_mode = qconfig_is_dynamic - return _pre_mode, False + return _pre_mode, need_skip _pre_mode = None @@ -623,10 +633,10 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: """Annotate the model with quantization configuration. 1) Annotate each node according users's qconfig with following order: - `module_name_qconfig`, `module_type_qconfig`, and `global_config`. - 2) Skip nodes already annotated by an earlier stage. For example, - if `linear1` has been annotated in the `module_name_config` stage, - it will not be re-annotated in the `module_type_config` or `global_config` stages. + `module_name_qconfig`, `operator_type_qconfig`, and `global_config`. + 2) Skip nodes already annotated by earlier stage. For example, + if `linear1` has been annotated in the `module_name_config` stage, it will + not be re-annotated in the `operator_type_qconfig` or `global_config` stages. """ self._check_qconfig() @@ -660,7 +670,8 @@ def _annotate_by_single_config( config: Optional[QuantizationConfig], filter_fn: Callable, ) -> None: - """ + """Annotate the model with a quantization configuration. + High-level description of quantization recipe for X86 Inductor Backend: Step 1: Apply quantization recipe for fusion patterns of conv/linear to enable int8 data type actively. Step 2: Propagate quantization annotation for patterns besides conv/linear. Go through the pattern in model @@ -690,9 +701,7 @@ def _annotate_by_single_config( # Refer to https://github.com/intel/intel-extension-for-pytorch/blob/ # 90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_recipe.py#L487 - self._annotate_output_for_int8_in_int8_out_pattern_entry( - model, config, filter_fn - ) + self._annotate_output_for_int8_in_int8_out_pattern_entry(model) def _annotate_all_qat_conv2d_fusion_pattern( self, model: torch.fx.GraphModule, config: QuantizationConfig, filter_fn @@ -1122,7 +1131,6 @@ def _annotate_maxpool2d( self, node: Node, quantization_config: QuantizationConfig, - filter_fn: Optional[Callable] = None, ) -> None: if node.target is not torch.ops.aten.max_pool2d.default: return @@ -1133,8 +1141,7 @@ def _annotate_maxpool2d( ] ): return - if _skip_annotate([maxpool_node], filter_fn): - return + input_node = maxpool_node.args[0] assert isinstance(input_node, Node) input_qspec_map = {} @@ -1197,12 +1204,15 @@ def is_all_inputs_connected_to_quantized_op(input_nodes): return False return True + if _skip_annotate([node], filter_fn): + return + if node.target is torch.ops.aten.max_pool2d.default: # Recipe of maxpool2d: check input arg[0] of maxpool2d is quantized or not input_nodes_to_check = [node.all_input_nodes[0]] if not is_all_inputs_connected_to_quantized_op(input_nodes_to_check): return - self._annotate_maxpool2d(node, quantization_config, filter_fn) + self._annotate_maxpool2d(node, quantization_config) return elif node.target is torch.ops.aten.cat.default: input_nodes_to_check = node.all_input_nodes @@ -1247,19 +1257,13 @@ def _annotate_output_share_observer_as_input( def _annotate_output_for_int8_in_int8_out_pattern_entry( self, model: torch.fx.GraphModule, - quantization_config: Optional[QuantizationConfig] = None, - filter_fn: Optional[Callable] = None, ): for node in model.graph.nodes: - self._annotate_output_for_int8_in_int8_out_pattern( - node, quantization_config, filter_fn - ) + self._annotate_output_for_int8_in_int8_out_pattern(node) def _annotate_output_for_int8_in_int8_out_pattern( self, node: Node, - quantization_config: Optional[QuantizationConfig] = None, - filter_fn: Optional[Callable] = None, ) -> None: r""" Check and insert observer at output of node in int8_in_int8_out_ops if needed. From d765d003d9b7ca1be68f78b51233cac61eb0901e Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Tue, 21 May 2024 10:35:52 +0800 Subject: [PATCH 021/706] refine the UTs Signed-off-by: yiliu30 --- .../pt2e/test_x86inductor_quantizer.py | 29 ++++++++++--------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py index e618647555b4..626543b2b2de 100644 --- a/test/quantization/pt2e/test_x86inductor_quantizer.py +++ b/test/quantization/pt2e/test_x86inductor_quantizer.py @@ -1900,9 +1900,10 @@ def forward(self, x): ) node_occurrence = { torch.ops.aten.linear.default: 2, - # input and output for the second linear + # quantize and dequantize the input for input of the second linear torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, + # dequantize the weight of the second linear torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, } node_list = [ @@ -1945,7 +1946,7 @@ def forward(self, x): m = M().eval() example_inputs = (torch.randn(3, 5),) - # Set global to no quantization and then default config for a specific submodule. + # Set `sub` to no quantization and then default config for a all `Linear`. quantizer = X86InductorQuantizer() quantizer.set_module_name_qconfig("sub", None).set_module_type_qconfig( torch.nn.Linear, xiq.get_default_x86_inductor_quantization_config() @@ -1953,9 +1954,10 @@ def forward(self, x): node_occurrence = { torch.ops.aten.linear.default: 3, - # quantize the input and output of the first linear + # quantize and dequantize the input and output of the first and second linear torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + # dequantize the weight of the first and second linear torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, } node_list = [ @@ -2009,13 +2011,14 @@ def forward(self, x): node_occurrence = { torch.ops.aten.linear.default: 3, - # quantize the input of the last linear + # quantize and dequantize the input of the last linear torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, + # dequantize the weight of the last linear torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, } node_list = [ - # first and second linear is not quantized + # first and second linear are not quantized torch.ops.aten.linear.default, torch.ops.aten.linear.default, # last linear is quantized @@ -2095,25 +2098,25 @@ def test_set_module_name_qconfig_for_dynamic_quant(self): .set_module_name_qconfig("v_proj", dynamic_config) ) node_occurrence = { - # ops for quantizing/de-quantizing input + # quantize and dequantize the input torch.ops.quantized_decomposed.choose_qparams.tensor: 1, torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1, torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1, - # ops for dequantize the weight of q_proj and v_proj + # dequantize the weight of q_proj and v_proj torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, } node_list = [ - # ops for quantizing/de-quantizing input + # quantize and dequantize the input torch.ops.quantized_decomposed.choose_qparams.tensor, torch.ops.quantized_decomposed.quantize_per_tensor.tensor, torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, - # # op for de-quantizing `q_proj`'s weight, disable this check + # # op for de-quantizing `q_proj`'s weight, disable this check to avoid random error. # torch.ops.quantized_decomposed.dequantize_per_channel.default, # q_proj torch.ops.aten.linear.default, # k_proj torch.ops.aten.linear.default, - # # op for de-quantizing `v_proj`'s weight, disable this check + # # op for de-quantizing `v_proj`'s weight, disable this check to avoid random error. # torch.ops.quantized_decomposed.dequantize_per_channel.default, # v_proj torch.ops.aten.linear.default, @@ -2149,17 +2152,17 @@ def test_set_mixed_static_and_dynamic(self): .set_module_name_qconfig("v_proj", dynamic_config) ) node_occurrence = { - # ops for quantizing/de-quantizing input + # quantize and dequantize the input torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, - # only q_proj be quantized + # only q_proj be quantized, dequantize its weight torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, } node_list = [ # ops for quantizing/de-quantizing input torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, - # op for de-quantizing `q_proj`'s weight, disable this check + # op for de-quantizing `q_proj`'s weight, disable this check to avoid random error. # torch.ops.quantized_decomposed.dequantize_per_channel.default, # q_proj torch.ops.aten.linear.default, From b9f34e03e70dd6c1324d678892640af5b5444c98 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Wed, 22 May 2024 08:41:19 +0800 Subject: [PATCH 022/706] refine the docstring Signed-off-by: yiliu30 --- .../pt2e/test_x86inductor_quantizer.py | 179 +++++++++--------- .../quantizer/x86_inductor_quantizer.py | 11 +- 2 files changed, 99 insertions(+), 91 deletions(-) diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py index 626543b2b2de..8a433e13bcf1 100644 --- a/test/quantization/pt2e/test_x86inductor_quantizer.py +++ b/test/quantization/pt2e/test_x86inductor_quantizer.py @@ -1900,16 +1900,16 @@ def forward(self, x): ) node_occurrence = { torch.ops.aten.linear.default: 2, - # quantize and dequantize the input for input of the second linear + # quantize and dequantize the input of the second linear (`sub`) torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, - # dequantize the weight of the second linear + # dequantize the weight of the second linear (`sub`) torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, } node_list = [ # first linear is not quantized torch.ops.aten.linear.default, - # second linear is quantized + # second linear (`sub`) is quantized torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.linear.default, @@ -1917,10 +1917,60 @@ def forward(self, x): self._test_quantizer(m, example_inputs, quantizer, node_occurrence, node_list) @skipIfNoX86 - def test_set_module_name_and_set_module_type_case2(self): + def test_set_module_name_qconfig_with_underscores(self) -> None: + """Test that if a module name has an underscore, we can still quantize it.""" + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + # This module name has underscores, which can be part of a mangled name. + self.foo_bar = torch.nn.Linear(2, 2) + self.baz = torch.nn.Linear(2, 2) + + def forward(self, x): + return self.baz(self.foo_bar(x)) + + # Set global to no quantization and then default config for a specific submodule. + quantizer = X86InductorQuantizer() + quantizer.set_module_name_qconfig( + "foo_bar", xiq.get_default_x86_inductor_quantization_config() + ) + example_inputs = (torch.randn(2, 2),) + m = M().eval() + m = capture_pre_autograd_graph(m, example_inputs) + m = prepare_pt2e(m, quantizer) + # Use a linear count instead of names because the names might change, but + # the order should be the same. + count = 0 + for n in m.graph.nodes: + if n.op == "call_function" and n.target == torch.ops.aten.linear.default: + # Get the weight observer to see the per-channel vs per-tensor. + weight_observer_node = n.args[1] + if count == 0: + # for foo_bar. + self.assertEqual( + weight_observer_node.op, + "call_module", + f"The op of linear({count})'s weight_observer_node is {weight_observer_node.op} instead call_module", + ) + observer_instance = getattr(m, weight_observer_node.target) + self.assertEqual( + observer_instance.qscheme, torch.per_channel_symmetric + ) + else: + # For baz it should have no observer at all. + self.assertNotEqual( + weight_observer_node.op, + "call_module", + f"The op of linear({count})'s weight_observer_node is {weight_observer_node.op} instead call_module", + ) + count += 1 + + @skipIfNoX86 + def test_set_module_name_and_set_module_type_case1(self): """Test that set `module_name_qconfig` and `module_type_qconfig` at the same time. - All linear are quantized except the last one. + All linear are not quantized except the last one. """ class Sub(torch.nn.Module): @@ -1946,27 +1996,28 @@ def forward(self, x): m = M().eval() example_inputs = (torch.randn(3, 5),) - # Set `sub` to no quantization and then default config for a all `Linear`. + # Set `sub` with default config and then no quantization for all `Linear`. + # The config set by `set_module_name_qconfig` has higher priority than `set_module_type_qconfig`. quantizer = X86InductorQuantizer() - quantizer.set_module_name_qconfig("sub", None).set_module_type_qconfig( - torch.nn.Linear, xiq.get_default_x86_inductor_quantization_config() - ) + quantizer.set_module_name_qconfig( + "sub", xiq.get_default_x86_inductor_quantization_config() + ).set_module_type_qconfig(torch.nn.Linear, None) node_occurrence = { torch.ops.aten.linear.default: 3, - # quantize and dequantize the input and output of the first and second linear - torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, - torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, - # dequantize the weight of the first and second linear - torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, + # quantize and dequantize the input of the last linear + torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, + # dequantize the weight of the last linear + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, } node_list = [ - # first and second linear are quantized - torch.ops.quantized_decomposed.quantize_per_tensor.default, - torch.ops.quantized_decomposed.dequantize_per_tensor.default, + # first and second linear are not quantized torch.ops.aten.linear.default, torch.ops.aten.linear.default, - # last linear is not quantized + # last linear is quantized + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.linear.default, ] self._test_quantizer( @@ -1974,10 +2025,10 @@ def forward(self, x): ) @skipIfNoX86 - def test_set_module_name_and_set_module_type_case1(self): + def test_set_module_name_and_set_module_type_case2(self): """Test that set `module_name_qconfig` and `module_type_qconfig` at the same time. - All linear are not quantized except the last one. + All linear are quantized except the last one. """ class Sub(torch.nn.Module): @@ -2003,83 +2054,33 @@ def forward(self, x): m = M().eval() example_inputs = (torch.randn(3, 5),) - # Set global to no quantization and then default config for a specific submodule. + # Set `sub` to None and then default config for a all `Linear`. quantizer = X86InductorQuantizer() - quantizer.set_module_name_qconfig( - "sub", xiq.get_default_x86_inductor_quantization_config() - ).set_module_type_qconfig(torch.nn.Linear, None) + quantizer.set_module_name_qconfig("sub", None).set_module_type_qconfig( + torch.nn.Linear, xiq.get_default_x86_inductor_quantization_config() + ) node_occurrence = { torch.ops.aten.linear.default: 3, - # quantize and dequantize the input of the last linear - torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, - torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, - # dequantize the weight of the last linear - torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + # quantize and dequantize the input and output of the first and second linear + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + # dequantize the weight of the first and second linear + torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, } node_list = [ - # first and second linear are not quantized - torch.ops.aten.linear.default, - torch.ops.aten.linear.default, - # last linear is quantized + # first and second linear are quantized torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.linear.default, + torch.ops.aten.linear.default, + # last linear is not quantized + torch.ops.aten.linear.default, ] self._test_quantizer( m, example_inputs, quantizer, node_occurrence, node_list, debug=True ) - @skipIfNoX86 - def test_set_module_name_qconfig_with_underscores(self) -> None: - """Test that if a module name has an underscore, we can still quantize it.""" - - class M(torch.nn.Module): - def __init__(self): - super().__init__() - # This module name has underscores, which can be part of a mangled name. - self.foo_bar = torch.nn.Linear(2, 2) - self.baz = torch.nn.Linear(2, 2) - - def forward(self, x): - return self.baz(self.foo_bar(x)) - - # Set global to no quantization and then default config for a specific submodule. - quantizer = X86InductorQuantizer() - quantizer.set_module_name_qconfig( - "foo_bar", xiq.get_default_x86_inductor_quantization_config() - ) - example_inputs = (torch.randn(2, 2),) - m = M().eval() - m = capture_pre_autograd_graph(m, example_inputs) - m = prepare_pt2e(m, quantizer) - # Use a linear count instead of names because the names might change, but - # the order should be the same. - count = 0 - for n in m.graph.nodes: - if n.op == "call_function" and n.target == torch.ops.aten.linear.default: - # Get the weight observer to see the per-channel vs per-tensor. - weight_observer_node = n.args[1] - if count == 0: - # for foo_bar. - self.assertEqual( - weight_observer_node.op, - "call_module", - f"The op of linear({count})'s weight_observer_node is {weight_observer_node.op} instead call_module", - ) - observer_instance = getattr(m, weight_observer_node.target) - self.assertEqual( - observer_instance.qscheme, torch.per_channel_symmetric - ) - else: - # For baz it should have no observer at all. - self.assertNotEqual( - weight_observer_node.op, - "call_module", - f"The op of linear({count})'s weight_observer_node is {weight_observer_node.op} instead call_module", - ) - count += 1 - @skipIfNoX86 def test_set_module_name_qconfig_for_dynamic_quant(self): """Test that quantize the specific submodule for dynamic quantization.""" @@ -2131,13 +2132,15 @@ def test_set_module_name_qconfig_for_dynamic_quant(self): ) @skipIfNoX86 - def test_set_mixed_static_and_dynamic(self): - """Test that mixed static and dynamic quantization for a module.""" + def test_set_module_name_with_mixed_static_and_dynamic(self): + """Test that mixed static and dynamic quantization for a module. + + Currently, mixed static and dynamic quantization is not supported. The subsequent config will be ignored. + """ with override_quantized_engine("x86"), torch.no_grad(): m = TestHelperModules.SelfAttnLikeModule(input_dim=64).eval() example_inputs = (torch.randn(1, 4, 64),) - # quantize `self.q_proj` static_config = xiq.get_default_x86_inductor_quantization_config( is_dynamic=False ) @@ -2159,10 +2162,10 @@ def test_set_mixed_static_and_dynamic(self): torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, } node_list = [ - # ops for quantizing/de-quantizing input + # quantize and dequantize the input torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, - # op for de-quantizing `q_proj`'s weight, disable this check to avoid random error. + # dequantize `q_proj`'s weight, disable this check to avoid random error. # torch.ops.quantized_decomposed.dequantize_per_channel.default, # q_proj torch.ops.aten.linear.default, diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index 751db41df9c3..b5ef864f9846 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -632,11 +632,16 @@ def _need_skip_cur_config( def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: """Annotate the model with quantization configuration. - 1) Annotate each node according users's qconfig with following order: + Note: + 1. Annotate each node according to the users's qconfig in the following order: `module_name_qconfig`, `operator_type_qconfig`, and `global_config`. - 2) Skip nodes already annotated by earlier stage. For example, - if `linear1` has been annotated in the `module_name_config` stage, it will + 2. Skip nodes that have already been annotated by an earlier stage. For example, + if `linear1` has been annotated during in the `module_name_config` stage, it will not be re-annotated in the `operator_type_qconfig` or `global_config` stages. + 3. For the config is `None`, the annotation will be skipped. + + For each pair of (module_name_or_operator_type_or_global, qconfig), a filter function is created. + This filter function checks if the node is marked by current stage and not marked by previous stage. """ self._check_qconfig() From 5a9a6017542cb91def9c19a7101d4938f78c793e Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Thu, 23 May 2024 08:36:23 +0800 Subject: [PATCH 023/706] rename some funcs and enhance config checker Signed-off-by: yiliu30 --- .../quantizer/x86_inductor_quantizer.py | 145 +++++++++--------- 1 file changed, 72 insertions(+), 73 deletions(-) diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index b5ef864f9846..cbcb155e1b22 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -360,6 +360,26 @@ def _get_supported_config_and_operators() -> List[OperatorConfig]: return _get_supported_x86_inductor_config_and_operators() +from functools import wraps + + +def config_checker(method: Callable) -> Callable: + @wraps(method) + def wrapper( + self: "X86InductorQuantizer", + name: Any, + quantization_config: Optional["QuantizationConfig"], + ) -> "X86InductorQuantizer": + if self._need_skip_config(quantization_config): + warnings.warn( + f"Skip the quantization config for {name} by X86InductorQuantizer.", + ) + return self + return method(self, name, quantization_config) + + return wrapper + + class X86InductorQuantizer(Quantizer): supported_config_and_operators = _get_supported_config_and_operators() module_function_to_aten_operator_type = _map_module_function_to_aten_operator_type() @@ -371,6 +391,8 @@ def __init__(self): torch._ops.OpOverloadPacket, Optional[QuantizationConfig] ] = {} self.module_name_qconfig: Dict[str, Optional[QuantizationConfig]] = {} + self._is_dynamic = None + self._is_qat = None @classmethod def get_supported_quantization_configs(cls) -> List[QuantizationConfig]: @@ -394,7 +416,41 @@ def get_supported_operator_for_quantization_config( return ops return [] + def _need_skip_config( + self, quantization_config: Optional[QuantizationConfig] + ) -> bool: + """Check if the given quantization config is valid for this quantizer. + + Note: Mixed static/dynamic configurations or mixed QAT/non-QAT configurations are not supported. + If such a mix is detected, the configuration will be skipped. + """ + if quantization_config is None: + return False + + if self._is_qat is None: + self._is_qat = quantization_config.is_qat + else: + if self._is_qat != quantization_config.is_qat: + warnings.warn( + "Mixed QAT and Non-QAT quantization config is not supported." + ) + return True + input_activation_spec = quantization_config.input_activation + if input_activation_spec is not None: + if self._is_dynamic is None: + self._is_dynamic = input_activation_spec.is_dynamic + else: + if self._is_dynamic != input_activation_spec.is_dynamic: + warnings.warn( + "Mixed dynamic and static quantization config is not supported." + ) + return True + return False + def set_global(self, quantization_config: QuantizationConfig): + if self._need_skip_config(quantization_config): + warnings.warn("Skip the global quantization config.") + return self self.global_config = quantization_config return self @@ -406,6 +462,7 @@ def get_global_quantization_config(self): ) return self.global_config + @config_checker def set_function_type_qconfig( self, function_type: Callable, @@ -424,6 +481,7 @@ def set_function_type_qconfig( ) return self + @config_checker def set_module_type_qconfig( self, module_type: torch.nn.Module, @@ -440,6 +498,7 @@ def set_module_type_qconfig( ) return self + @config_checker def set_module_name_qconfig( self, module_name: str, quantization_config: Optional[QuantizationConfig] ): @@ -570,67 +629,9 @@ def _get_input_idx_for_binary_node( conv_gemm_node_idx = 1 extra_input_node_idx = 0 extra_input_node = binary_node.args[extra_input_node_idx] # type: ignore[index] - assert isinstance(extra_input_node, Node) - return conv_gemm_node_idx, extra_input_node_idx - - def _check_qconfig(self) -> None: - """Check if the qconfig is valid. - - Currently, not support mixed static and dynamic quantization config. - If the mixture is detected, the subsequent configuration will be skipped. - """ - - def _need_skip_cur_config( - qconfig: Optional[QuantizationConfig], _pre_mode: Optional[bool], msg: str - ) -> Tuple[Optional[bool], bool]: - input_act_config = getattr(qconfig, "input_activation", None) - need_skip = False - if input_act_config: - qconfig_is_dynamic = input_act_config.is_dynamic - if _pre_mode is not None and _pre_mode != qconfig_is_dynamic: - warnings.warn( - "Mixed dynamic and static quantization config is not supported." - f"The configuration for {msg} will be skipped." - ) - need_skip = True - else: - if _pre_mode is None: - _pre_mode = qconfig_is_dynamic - return _pre_mode, need_skip - - _pre_mode = None - - tmp_module_name_qconfig: Dict[str, Optional[QuantizationConfig]] = {} - for module_name, qconfig in self.module_name_qconfig.items(): - _pre_mode, need_skip = _need_skip_cur_config( - qconfig, _pre_mode, module_name - ) - if not need_skip: - tmp_module_name_qconfig[module_name] = qconfig - self.module_name_qconfig = tmp_module_name_qconfig - - tmp_operator_type_qconfig: Dict[ - torch._ops.OpOverloadPacket, Optional[QuantizationConfig] - ] = {} - for operator_type, qconfig in self.operator_type_qconfig.items(): - _pre_mode, need_skip = _need_skip_cur_config( - qconfig, _pre_mode, str(operator_type) - ) - if not need_skip: - tmp_operator_type_qconfig[operator_type] = qconfig - self.operator_type_qconfig = tmp_operator_type_qconfig - - if self.global_config: - _pre_mode, need_skip = _need_skip_cur_config( - self.global_config, _pre_mode, "global" - ) - if not need_skip: - self.global_config = self.global_config - else: - self.global_config = None def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: - """Annotate the model with quantization configuration. + """Annotate the model with quantization configurations. Note: 1. Annotate each node according to the users's qconfig in the following order: @@ -644,24 +645,22 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: This filter function checks if the node is marked by current stage and not marked by previous stage. """ - self._check_qconfig() - module_name_list = list(self.module_name_qconfig.keys()) for module_name, qconfig in self.module_name_qconfig.items(): - self._annotate_by_single_config( + self._annotate_by_config( model, qconfig, _get_module_name_filter(module_name) ) tp_list = list(self.operator_type_qconfig.keys()) for operator_type, qconfig in self.operator_type_qconfig.items(): - self._annotate_by_single_config( + self._annotate_by_config( model, qconfig, _get_operator_type_filter(operator_type, module_name_list), ) if self.global_config: - self._annotate_by_single_config( + self._annotate_by_config( model, self.global_config, _get_not_operator_type_or_name_filter(tp_list, module_name_list), @@ -669,7 +668,7 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: return model - def _annotate_by_single_config( + def _annotate_by_config( self, model: torch.fx.GraphModule, config: Optional[QuantizationConfig], @@ -690,8 +689,8 @@ def _annotate_by_single_config( return # Step1: Recipe of fusion patterns like conv/linear. - self._annotate_all_conv2d_fusion_pattern(model, config, filter_fn) - self._annotate_all_linear_fusion_pattern(model, config, filter_fn) + self._annotate_conv2d_fusion_pattern(model, config, filter_fn) + self._annotate_linear_fusion_pattern(model, config, filter_fn) self._annotate_matmul(model, config, filter_fn) # Step2: Recipe to propagate annotation for patterns beside conv/linear. @@ -699,7 +698,7 @@ def _annotate_by_single_config( # Recipe refer to https://github.com/intel/intel-extension-for-pytorch/blob/ # 90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_recipe.py#L538 - self._annotate_all_propagation_quantizable_pattern(model, config, filter_fn) + self._annotate_propagation_quantizable_pattern_entry(model, config, filter_fn) # Step3: For quantizable ops, such as maxpool2d, we need to quantize its output if it is quantized # in inputs. So, we can fuse dq-operator-q into a quantized op. @@ -708,7 +707,7 @@ def _annotate_by_single_config( self._annotate_output_for_int8_in_int8_out_pattern_entry(model) - def _annotate_all_qat_conv2d_fusion_pattern( + def _annotate_qat_conv2d_fusion_pattern( self, model: torch.fx.GraphModule, config: QuantizationConfig, filter_fn ): # Annotate QAT Specific patterns @@ -935,18 +934,18 @@ def _annotate_qat_conv2d_bn( nodes_to_mark_annotated.extend(list(bn_partition.nodes)) _mark_nodes_as_annotated(nodes_to_mark_annotated) - def _annotate_all_conv2d_fusion_pattern( + def _annotate_conv2d_fusion_pattern( self, model: torch.fx.GraphModule, config, filter_fn ): if config.is_qat: # Annotate QAT specific pattern: mainly due to BN not folded in prepare_qat - self._annotate_all_qat_conv2d_fusion_pattern(model, config, filter_fn) + self._annotate_qat_conv2d_fusion_pattern(model, config, filter_fn) self._annotate_conv2d_binary_unary(model, config, filter_fn) self._annotate_conv2d_binary(model, config, filter_fn) self._annotate_conv2d_unary(model, config, filter_fn) self._annotate_conv2d(model, config, filter_fn) - def _annotate_all_linear_fusion_pattern( + def _annotate_linear_fusion_pattern( self, model: torch.fx.GraphModule, config, filter_fn ): if config.input_activation and not config.input_activation.is_dynamic: @@ -1184,7 +1183,7 @@ def _annotate_cat( _is_output_of_quantized_pattern=True, ) - def _annotate_all_propagation_quantizable_pattern( + def _annotate_propagation_quantizable_pattern_entry( self, model, quantization_config, filter_fn ): for node in model.graph.nodes: From a3b0129422552205bf5b5cfcb8e0fe20d4e96c68 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Thu, 23 May 2024 08:38:29 +0800 Subject: [PATCH 024/706] rename annotate func name Signed-off-by: yiliu30 --- torch/ao/quantization/quantizer/x86_inductor_quantizer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index cbcb155e1b22..7f2cad0dbded 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -647,20 +647,20 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: module_name_list = list(self.module_name_qconfig.keys()) for module_name, qconfig in self.module_name_qconfig.items(): - self._annotate_by_config( + self._annotate_with_config( model, qconfig, _get_module_name_filter(module_name) ) tp_list = list(self.operator_type_qconfig.keys()) for operator_type, qconfig in self.operator_type_qconfig.items(): - self._annotate_by_config( + self._annotate_with_config( model, qconfig, _get_operator_type_filter(operator_type, module_name_list), ) if self.global_config: - self._annotate_by_config( + self._annotate_with_config( model, self.global_config, _get_not_operator_type_or_name_filter(tp_list, module_name_list), @@ -668,7 +668,7 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: return model - def _annotate_by_config( + def _annotate_with_config( self, model: torch.fx.GraphModule, config: Optional[QuantizationConfig], From ec263d963a0085697a51abcb86e3ef1a6f156f71 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Thu, 23 May 2024 09:10:28 +0800 Subject: [PATCH 025/706] revert some change Signed-off-by: yiliu30 --- .../quantizer/x86_inductor_quantizer.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index 7f2cad0dbded..17d495be826a 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -360,11 +360,8 @@ def _get_supported_config_and_operators() -> List[OperatorConfig]: return _get_supported_x86_inductor_config_and_operators() -from functools import wraps - - def config_checker(method: Callable) -> Callable: - @wraps(method) + @functools.wraps(method) def wrapper( self: "X86InductorQuantizer", name: Any, @@ -629,6 +626,8 @@ def _get_input_idx_for_binary_node( conv_gemm_node_idx = 1 extra_input_node_idx = 0 extra_input_node = binary_node.args[extra_input_node_idx] # type: ignore[index] + assert isinstance(extra_input_node, Node) + return conv_gemm_node_idx, extra_input_node_idx def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: """Annotate the model with quantization configurations. @@ -685,8 +684,6 @@ def _annotate_with_config( such as maxpool2d, which only supports output with int8 data type when the input is with int8 data type, we need to annotate the output of this pattern. """ - if config is None: - return # Step1: Recipe of fusion patterns like conv/linear. self._annotate_conv2d_fusion_pattern(model, config, filter_fn) @@ -937,7 +934,7 @@ def _annotate_qat_conv2d_bn( def _annotate_conv2d_fusion_pattern( self, model: torch.fx.GraphModule, config, filter_fn ): - if config.is_qat: + if self._is_qat: # Annotate QAT specific pattern: mainly due to BN not folded in prepare_qat self._annotate_qat_conv2d_fusion_pattern(model, config, filter_fn) self._annotate_conv2d_binary_unary(model, config, filter_fn) @@ -948,7 +945,7 @@ def _annotate_conv2d_fusion_pattern( def _annotate_linear_fusion_pattern( self, model: torch.fx.GraphModule, config, filter_fn ): - if config.input_activation and not config.input_activation.is_dynamic: + if not self._is_dynamic: # Weiwen: Dynamic Quant of linear unary will be supported in next step self._annotate_linear_binary_unary(model, config, filter_fn) self._annotate_linear_unary(model, config, filter_fn) @@ -957,7 +954,7 @@ def _annotate_linear_fusion_pattern( def _annotate_matmul( self, model: torch.fx.GraphModule, - quantization_config: QuantizationConfig, + quantization_config: Optional[QuantizationConfig], filter_fn: Optional[Callable[[Node], bool]] = None, ): for node in model.graph.nodes: From 105429c9558f70c5bf845d4bab006d182e56da31 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Thu, 23 May 2024 09:10:28 +0800 Subject: [PATCH 026/706] revert some change Signed-off-by: yiliu30 --- torch/ao/quantization/quantizer/utils.py | 34 ------ .../quantizer/x86_inductor_quantizer.py | 106 +++++++++++++----- .../quantizer/xnnpack_quantizer.py | 35 +++++- 3 files changed, 112 insertions(+), 63 deletions(-) diff --git a/torch/ao/quantization/quantizer/utils.py b/torch/ao/quantization/quantizer/utils.py index 77cfc22d73be..f25d0916018b 100644 --- a/torch/ao/quantization/quantizer/utils.py +++ b/torch/ao/quantization/quantizer/utils.py @@ -47,37 +47,3 @@ def _node_only_used_for_sym_size(node: Node, partition_nodes: List[Node]): ((user not in partition_nodes) or _is_sym_size_node(user)) for user in node.users ) - - -def _get_module_name_filter(module_name: str): - """Get the module_name_filter function for a given module name, the filter accepts - a node and checks if the node comes from a module that has certain module name - - For example: - node: linear_op = call_function[...](...) # comes from a module with name blocks.sub.linear1 - - - >> module_name_filter = _get_module_name_filter("blocks.sub") - >> print(module_name_filter(node)) - True # the node is from "blocks.sub" based on the fully qualified name "blocks.sub.linear1" - """ - - def module_name_filter(n: Node) -> bool: - # example: { - # 'L__self___sub': ("L['self'].sub", ), - # 'L__self___sub_linear': ("L['self'].sub.linear", ) - # } - # get_attr nodes doesn't have nn_module_stack? - nn_module_stack = n.meta.get("nn_module_stack", {}) - - def _normalize_path(n): - prefix = 0 - # TODO This is non standard behavior and should be removed when we migrate off capture_pre_autograd_graph. - if n.startswith("L['self']."): - prefix = len("L['self'].") - return n[prefix:] - - names = [_normalize_path(n) for n, _ in nn_module_stack.values()] - return module_name in names - - return module_name_filter diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index 7f2cad0dbded..95cd4e50524e 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -36,7 +36,8 @@ Quantizer, SharedQuantizationSpec, ) -from torch.ao.quantization.quantizer.utils import _get_module_name_filter + +# from torch.ao.quantization.quantizer.utils import _get_module_name_filter from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( get_bias_qspec, get_input_act_qspec, @@ -102,19 +103,56 @@ class _X86InductorQuantizationAnnotation(QuantizationAnnotation): def _skip_annotate( - nodes: List[Node], filter_fn: Optional[Callable[[Node], bool]] = None + nodes: List[Node], filter_fn: Optional[Callable[[List[Node]], bool]] = None ): # 1) Skip annotate if any node is already annotated if _is_any_annotated(nodes): return True # 2) Not skip annotate if a) filter_fn is provided and b) any node passed the filter - if filter_fn and any(filter_fn(node) for node in nodes): + if filter_fn and filter_fn(nodes): return False return True +def _get_module_name_filter(module_name: str): + """Get the module_name_filter function for a given module name, the filter accepts + a node and checks if the node comes from a module that has certain module name + + For example: + node: linear_op = call_function[...](...) # comes from a module with name blocks.sub.linear1 + + + >> module_name_filter = _get_module_name_filter("blocks.sub") + >> print(module_name_filter(node)) + True # the node is from "blocks.sub" based on the fully qualified name "blocks.sub.linear1" + """ + + def module_name_filter(n: Node) -> bool: + # example: { + # 'L__self___sub': ("L['self'].sub", ), + # 'L__self___sub_linear': ("L['self'].sub.linear", ) + # } + # get_attr nodes doesn't have nn_module_stack? + nn_module_stack = n.meta.get("nn_module_stack", {}) + + def _normalize_path(n): + prefix = 0 + # TODO This is non standard behavior and should be removed when we migrate off capture_pre_autograd_graph. + if n.startswith("L['self']."): + prefix = len("L['self'].") + return n[prefix:] + + names = [_normalize_path(n) for n, _ in nn_module_stack.values()] + return module_name in names + + def check_all_node(nodes: List[Node]) -> bool: + return all(module_name_filter(n) for n in nodes) + + return check_all_node + + def _get_operator_type_filter(operator_type: Callable, module_name_list): """Get the operator_type_filter function for a given operator type and module name list. @@ -141,6 +179,23 @@ def operator_type_filter(n: Node) -> bool: return operator_type_filter +def _get_operator_type_qconfig_filter(operator_type: Callable): + def operator_type_qconfig_filter(nodes: List[Node]): + # Return True, if the first node has the certain operator type + has_certain_operator_type = nodes[0].target == operator_type + return has_certain_operator_type + + return operator_type_qconfig_filter + + +def _get_global_config_filter(): + def global_config_filter(nodes: List[Node]): + # Return True, if the first node has not been annotated + return nodes[0].target in default_quantizable_ops + + return global_config_filter + + def _get_not_operator_type_or_name_filter( tp_list: List[torch._ops.OpOverloadPacket], module_name_list: List[str] ) -> Callable[[Node], bool]: @@ -360,11 +415,8 @@ def _get_supported_config_and_operators() -> List[OperatorConfig]: return _get_supported_x86_inductor_config_and_operators() -from functools import wraps - - def config_checker(method: Callable) -> Callable: - @wraps(method) + @functools.wraps(method) def wrapper( self: "X86InductorQuantizer", name: Any, @@ -629,6 +681,8 @@ def _get_input_idx_for_binary_node( conv_gemm_node_idx = 1 extra_input_node_idx = 0 extra_input_node = binary_node.args[extra_input_node_idx] # type: ignore[index] + assert isinstance(extra_input_node, Node) + return conv_gemm_node_idx, extra_input_node_idx def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: """Annotate the model with quantization configurations. @@ -645,25 +699,23 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: This filter function checks if the node is marked by current stage and not marked by previous stage. """ - module_name_list = list(self.module_name_qconfig.keys()) for module_name, qconfig in self.module_name_qconfig.items(): self._annotate_with_config( model, qconfig, _get_module_name_filter(module_name) ) - tp_list = list(self.operator_type_qconfig.keys()) for operator_type, qconfig in self.operator_type_qconfig.items(): self._annotate_with_config( model, qconfig, - _get_operator_type_filter(operator_type, module_name_list), + _get_operator_type_qconfig_filter(operator_type) ) if self.global_config: self._annotate_with_config( model, self.global_config, - _get_not_operator_type_or_name_filter(tp_list, module_name_list), + _get_global_config_filter(), ) return model @@ -685,8 +737,6 @@ def _annotate_with_config( such as maxpool2d, which only supports output with int8 data type when the input is with int8 data type, we need to annotate the output of this pattern. """ - if config is None: - return # Step1: Recipe of fusion patterns like conv/linear. self._annotate_conv2d_fusion_pattern(model, config, filter_fn) @@ -720,7 +770,7 @@ def _annotate_qat_conv2d_bn_binary_unary( self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[Node], bool]] = None, + filter_fn: Optional[Callable[[List[Node]], bool]] = None, ) -> None: fused_partitions = find_sequential_partitions( gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d, operator.add, torch.nn.ReLU] @@ -791,7 +841,7 @@ def _annotate_qat_conv2d_bn_binary( self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[Node], bool]] = None, + filter_fn: Optional[Callable[[List[Node]], bool]] = None, ) -> None: fused_partitions = find_sequential_partitions( gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d, operator.add] @@ -850,7 +900,7 @@ def _annotate_qat_conv2d_bn_unary( self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[Node], bool]] = None, + filter_fn: Optional[Callable[[List[Node]], bool]] = None, ) -> None: fused_partitions = [] unary_patterns = [ @@ -901,7 +951,7 @@ def _annotate_qat_conv2d_bn( self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[Node], bool]] = None, + filter_fn: Optional[Callable[[List[Node]], bool]] = None, ) -> None: fused_partitions = find_sequential_partitions( gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d] @@ -937,7 +987,7 @@ def _annotate_qat_conv2d_bn( def _annotate_conv2d_fusion_pattern( self, model: torch.fx.GraphModule, config, filter_fn ): - if config.is_qat: + if self._is_qat: # Annotate QAT specific pattern: mainly due to BN not folded in prepare_qat self._annotate_qat_conv2d_fusion_pattern(model, config, filter_fn) self._annotate_conv2d_binary_unary(model, config, filter_fn) @@ -948,7 +998,7 @@ def _annotate_conv2d_fusion_pattern( def _annotate_linear_fusion_pattern( self, model: torch.fx.GraphModule, config, filter_fn ): - if config.input_activation and not config.input_activation.is_dynamic: + if not self._is_dynamic: # Weiwen: Dynamic Quant of linear unary will be supported in next step self._annotate_linear_binary_unary(model, config, filter_fn) self._annotate_linear_unary(model, config, filter_fn) @@ -957,8 +1007,8 @@ def _annotate_linear_fusion_pattern( def _annotate_matmul( self, model: torch.fx.GraphModule, - quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[Node], bool]] = None, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[List[Node]], bool]] = None, ): for node in model.graph.nodes: if node.target != torch.ops.aten.matmul.default: @@ -979,7 +1029,7 @@ def _annotate_conv2d_binary_unary( self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[Node], bool]] = None, + filter_fn: Optional[Callable[[List[Node]], bool]] = None, ) -> None: # Conv2d + add + unary op fused_partitions = find_sequential_partitions( @@ -1027,7 +1077,7 @@ def _annotate_conv2d_binary( self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[Node], bool]] = None, + filter_fn: Optional[Callable[[List[Node]], bool]] = None, ) -> None: # Conv2d + add fused_partitions = find_sequential_partitions( @@ -1073,7 +1123,7 @@ def _annotate_conv2d_unary( self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[Node], bool]] = None, + filter_fn: Optional[Callable[[List[Node]], bool]] = None, ) -> None: fused_partitions = [] unary_patterns = [ @@ -1111,7 +1161,7 @@ def _annotate_conv2d( self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[Node], bool]] = None, + filter_fn: Optional[Callable[[List[Node]], bool]] = None, ) -> None: conv_partitions = get_source_partitions( gm.graph, [torch.nn.Conv2d, torch.nn.functional.conv2d] @@ -1313,7 +1363,7 @@ def _annotate_linear( self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[Node], bool]] = None, + filter_fn: Optional[Callable[[List[Node]], bool]] = None, ) -> None: linear_partitions = get_source_partitions( gm.graph, [torch.nn.Linear, torch.nn.functional.linear] @@ -1340,7 +1390,7 @@ def _annotate_linear_unary( self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[Node], bool]] = None, + filter_fn: Optional[Callable[[List[Node]], bool]] = None, ) -> None: postop_list = [ torch.nn.ReLU, @@ -1374,7 +1424,7 @@ def _annotate_linear_binary_unary( self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[Node], bool]] = None, + filter_fn: Optional[Callable[[List[Node]], bool]] = None, ) -> None: # linear + binary_op + (optional) unary op binary_op_list = [operator.add] diff --git a/torch/ao/quantization/quantizer/xnnpack_quantizer.py b/torch/ao/quantization/quantizer/xnnpack_quantizer.py index e13a79f39267..f3d1b6ca8b39 100644 --- a/torch/ao/quantization/quantizer/xnnpack_quantizer.py +++ b/torch/ao/quantization/quantizer/xnnpack_quantizer.py @@ -22,7 +22,6 @@ ) from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer -from torch.ao.quantization.quantizer.utils import _get_module_name_filter from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( _convert_scalars_to_attrs, @@ -193,6 +192,40 @@ def _get_supported_config_and_operators() -> List[OperatorConfig]: return _get_supported_symmetric_config_and_operators() +def _get_module_name_filter(module_name: str): + """Get the module_name_filter function for a given module name, the filter accepts + a node and checks if the node comes from a module that has certain module name + + For example: + node: linear_op = call_function[...](...) # comes from a module with name blocks.sub.linear1 + + + >> module_name_filter = _get_module_name_filter("blocks.sub") + >> print(module_name_filter(node)) + True # the node is from "blocks.sub" based on the fully qualified name "blocks.sub.linear1" + """ + + def module_name_filter(n: Node) -> bool: + # example: { + # 'L__self___sub': ("L['self'].sub", ), + # 'L__self___sub_linear': ("L['self'].sub.linear", ) + # } + # get_attr nodes doesn't have nn_module_stack? + nn_module_stack = n.meta.get("nn_module_stack", {}) + + def _normalize_path(n): + prefix = 0 + # TODO This is non standard behavior and should be removed when we migrate off capture_pre_autograd_graph. + if n.startswith("L['self']."): + prefix = len("L['self'].") + return n[prefix:] + + names = [_normalize_path(n) for n, _ in nn_module_stack.values()] + return module_name in names + + return module_name_filter + + def _get_module_type_filter(tp: Callable): """Get the module_type_filter function for a given module type, the filter accepts a node and checks if the node comes from a module that has certain module type From 74df47f38f9a18a0fce24183598cfaa820b0595c Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Thu, 23 May 2024 09:49:35 +0800 Subject: [PATCH 027/706] refine the filter_fn Signed-off-by: yiliu30 --- torch/ao/quantization/quantizer/x86_inductor_quantizer.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index 95cd4e50524e..8ff9cf80bd77 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -146,7 +146,7 @@ def _normalize_path(n): names = [_normalize_path(n) for n, _ in nn_module_stack.values()] return module_name in names - + # TODO: def check_all_node(nodes: List[Node]) -> bool: return all(module_name_filter(n) for n in nodes) @@ -180,18 +180,20 @@ def operator_type_filter(n: Node) -> bool: def _get_operator_type_qconfig_filter(operator_type: Callable): + # TODO: def operator_type_qconfig_filter(nodes: List[Node]): # Return True, if the first node has the certain operator type - has_certain_operator_type = nodes[0].target == operator_type + has_certain_operator_type = nodes[-1].target == operator_type return has_certain_operator_type return operator_type_qconfig_filter def _get_global_config_filter(): + # TODO: def global_config_filter(nodes: List[Node]): # Return True, if the first node has not been annotated - return nodes[0].target in default_quantizable_ops + return any (node.target in default_quantizable_ops for node in nodes) return global_config_filter From 979e9658293d6b56cdcabc63871f15616b93dba8 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Fri, 24 May 2024 18:06:44 +0800 Subject: [PATCH 028/706] fixed the filter_fn Signed-off-by: yiliu30 --- .../pt2e/test_x86inductor_quantizer.py | 76 ++++++- .../quantizer/x86_inductor_quantizer.py | 186 ++++++++---------- 2 files changed, 153 insertions(+), 109 deletions(-) diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py index 8a433e13bcf1..54c682ceba1f 100644 --- a/test/quantization/pt2e/test_x86inductor_quantizer.py +++ b/test/quantization/pt2e/test_x86inductor_quantizer.py @@ -1916,6 +1916,69 @@ def forward(self, x): ] self._test_quantizer(m, example_inputs, quantizer, node_occurrence, node_list) + @skipIfNoX86 + def test_set_module_name_qconfig_case2(self): + """Test that quantize the specific submodule.""" + + class Sub(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(5, 10) + self.relu1 = torch.nn.ReLU(inplace=False) + self.linear2 = torch.nn.Linear(10, 5) + + def forward(self, x): + x = self.linear1(x) + x = self.relu1(x) + x = self.linear2(x) + return x + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(5, 5) + self.sub = Sub() + + def forward(self, x): + x = self.linear(x) + x = self.sub(x) + return x + + m = M().eval() + example_inputs = (torch.randn(3, 5),) + # Set global to no quantization and then default config for a specific submodule. + quantizer = X86InductorQuantizer() + quantizer.set_module_name_qconfig( + "sub", xiq.get_default_x86_inductor_quantization_config() + ) + node_occurrence = { + torch.ops.aten.linear.default: 3, + # quantize and dequantize the input of the two linear layers from `sub` + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + # dequantize the weight of the two linear layers from `sub` + torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, + } + # node_list = None + node_list = [ + # first linear is not quantized + torch.ops.aten.linear.default, + # two linear layers from `sub` are quantized + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.linear.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.linear.default, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + @skipIfNoX86 def test_set_module_name_qconfig_with_underscores(self) -> None: """Test that if a module name has an underscore, we can still quantize it.""" @@ -2021,7 +2084,11 @@ def forward(self, x): torch.ops.aten.linear.default, ] self._test_quantizer( - m, example_inputs, quantizer, node_occurrence, node_list, debug=True + m, + example_inputs, + quantizer, + node_occurrence, + node_list, ) @skipIfNoX86 @@ -2078,7 +2145,11 @@ def forward(self, x): torch.ops.aten.linear.default, ] self._test_quantizer( - m, example_inputs, quantizer, node_occurrence, node_list, debug=True + m, + example_inputs, + quantizer, + node_occurrence, + node_list, ) @skipIfNoX86 @@ -2183,7 +2254,6 @@ def test_set_module_name_with_mixed_static_and_dynamic(self): node_occurrence, node_list, is_qat=True, - debug=True, ) @skipIfNoX86 diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index 8ff9cf80bd77..c24b2ad4f8f4 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -68,6 +68,7 @@ class _X86InductorQuantizationAnnotation(QuantizationAnnotation): # * Node as output node of a fusion pattern. # * The fusion pattern supports int8 data type. # * The fusion pattern has inputs annotated to insert observer. + # * The quantization_config is not `None`. _is_output_of_quantized_pattern: bool = False @@ -116,20 +117,22 @@ def _skip_annotate( return True -def _get_module_name_filter(module_name: str): - """Get the module_name_filter function for a given module name, the filter accepts - a node and checks if the node comes from a module that has certain module name +def _get_module_name_filter(module_name: str) -> Callable[[List[Node]], bool]: + """Get a filter function for a given module name. - For example: - node: linear_op = call_function[...](...) # comes from a module with name blocks.sub.linear1 + The filter function that takes a list of nodes (as determined by the annotate function) + and returns True if **all** nodes come from the specified module name, and False otherwise. + For example: + linear_1: "f32[3, 10]" = torch.ops.aten.linear.default(...) # comes from a module with name `sub.linear1` + relu: "f32[3, 10]" = torch.ops.aten.relu.default(linear_1); # comes from a module with name `sub.relu1` - >> module_name_filter = _get_module_name_filter("blocks.sub") - >> print(module_name_filter(node)) - True # the node is from "blocks.sub" based on the fully qualified name "blocks.sub.linear1" + >> module_name_filter = _get_module_name_filter("sub") + >> print(module_name_filter([relu, linear_1])) + # True # These two nodes are from "sub" and determined by `_annotate_linear_unary` function. """ - def module_name_filter(n: Node) -> bool: + def _node_filter(n: Node) -> bool: # example: { # 'L__self___sub': ("L['self'].sub", ), # 'L__self___sub_linear': ("L['self'].sub.linear", ) @@ -146,41 +149,19 @@ def _normalize_path(n): names = [_normalize_path(n) for n, _ in nn_module_stack.values()] return module_name in names - # TODO: - def check_all_node(nodes: List[Node]) -> bool: - return all(module_name_filter(n) for n in nodes) - - return check_all_node - -def _get_operator_type_filter(operator_type: Callable, module_name_list): - """Get the operator_type_filter function for a given operator type and module name list. + def module_name_filter(nodes: List[Node]) -> bool: + all_nodes_from_module_name: bool = all(_node_filter(n) for n in nodes) + return all_nodes_from_module_name - The filter accept a node and checks if 1) the node has certain operator type, - and 2) the node does not marked by `set_module_name_qconfig`. - - For example: - # linear_op.target if torch.ops.aten.linear.default - # linear_op is traced from `fc3` - node: linear_op = call_function[...](...) - - - >> operator_type_filter = _get_operator_type_filter(torch.ops.aten.linear.default, ["fc1", "fc2"]) - >> print(operator_type_filter(node)) - True # the node's target is `torch.ops.aten.linear.default` and not marked by `set_module_name_qconfig` - """ - module_name_list_filters = [_get_module_name_filter(m) for m in module_name_list] - - def operator_type_filter(n: Node) -> bool: - not_module_name_node = not any(f(n) for f in module_name_list_filters) - has_certain_operator_type = n.target == operator_type - return not_module_name_node and has_certain_operator_type - - return operator_type_filter + return module_name_filter def _get_operator_type_qconfig_filter(operator_type: Callable): - # TODO: + # TODO: Currently, we pass a list of nodes determined by the annotate function, but only one op is the target op. + # For example, [relu, conv] + # 1) Only pass the anchor node? + # 2) If one node is annotated, the rest of the nodes are annotated? def operator_type_qconfig_filter(nodes: List[Node]): # Return True, if the first node has the certain operator type has_certain_operator_type = nodes[-1].target == operator_type @@ -190,41 +171,16 @@ def operator_type_qconfig_filter(nodes: List[Node]): def _get_global_config_filter(): - # TODO: def global_config_filter(nodes: List[Node]): - # Return True, if the first node has not been annotated - return any (node.target in default_quantizable_ops for node in nodes) + # TODO: double-check + # Have the same issue as `_get_operator_type_qconfig_filter` + # Return True if any node belongs to the `default_quantizable_ops` + nodes_is_default = any(node.target in default_quantizable_ops for node in nodes) + return nodes_is_default return global_config_filter -def _get_not_operator_type_or_name_filter( - tp_list: List[torch._ops.OpOverloadPacket], module_name_list: List[str] -) -> Callable[[Node], bool]: - """Get the not_operator_type_or_name_filter function for a given operator type list and module name list. - - The filter accept a node and checks if 1) the node does not marked by `set_module_name_qconfig`, - or `set_module_type_qconfig` `set_function_type_qconfig`, and 2) the node's type - is belong to `default_quantizable_ops`. - """ - - # Only call the `operator_type_filters` is enough, since each filter of `operator_type_filters` - # will check the `module_name_list_filters`. - operator_type_filters = [ - _get_operator_type_filter(tp, module_name_list) for tp in tp_list - ] - - def not_operator_type_or_name_filter(n: Node) -> bool: - # For global_config, only quantize the `default_quantizable_ops` - is_default_quantizable_op = n.target in default_quantizable_ops - not_operator_type_or_module_name_node = not any( - f(n) for f in operator_type_filters - ) - return is_default_quantizable_op and not_operator_type_or_module_name_node - - return not_operator_type_or_name_filter - - def _map_module_function_to_aten_operator_type(): module_function_to_aten_operator: Dict[Callable, torch._ops.OpOverloadPacket] = {} map_list = ( @@ -306,6 +262,14 @@ def _is_quantized_op_pt2e(node: torch.fx.Node): return False quantization_annotation = node.meta.get(QUANT_ANNOTATION_KEY, None) assert isinstance(quantization_annotation, _X86InductorQuantizationAnnotation) + # TODO: + # Conv + # | + # Relu <- is annotated by `None` quantization config, `_is_output_of_quantized_pattern` is marked to False + # | + # Maxpool + # | + # return quantization_annotation._is_output_of_quantized_pattern @@ -426,7 +390,7 @@ def wrapper( ) -> "X86InductorQuantizer": if self._need_skip_config(quantization_config): warnings.warn( - f"Skip the quantization config for {name} by X86InductorQuantizer.", + f"Skip the quantization config for {name}.", ) return self return method(self, name, quantization_config) @@ -473,11 +437,12 @@ def get_supported_operator_for_quantization_config( def _need_skip_config( self, quantization_config: Optional[QuantizationConfig] ) -> bool: - """Check if the given quantization config is valid for this quantizer. + """Check if the given quantization config is valid for X86InductorQuantizer. Note: Mixed static/dynamic configurations or mixed QAT/non-QAT configurations are not supported. If such a mix is detected, the configuration will be skipped. """ + need_skip = False if quantization_config is None: return False @@ -488,7 +453,7 @@ def _need_skip_config( warnings.warn( "Mixed QAT and Non-QAT quantization config is not supported." ) - return True + need_skip = True input_activation_spec = quantization_config.input_activation if input_activation_spec is not None: if self._is_dynamic is None: @@ -498,8 +463,8 @@ def _need_skip_config( warnings.warn( "Mixed dynamic and static quantization config is not supported." ) - return True - return False + need_skip = True + return need_skip def set_global(self, quantization_config: QuantizationConfig): if self._need_skip_config(quantization_config): @@ -599,7 +564,7 @@ def _annotate_conv_node_helper( conv_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( input_qspec_map=input_qspec_map, _annotated=True, - _is_output_of_quantized_pattern=True, + _is_output_of_quantized_pattern=quantization_config is not None, ) else: conv_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( @@ -637,7 +602,7 @@ def _annotate_linear_node_helper( linear_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( input_qspec_map=input_qspec_map, _annotated=True, - _is_output_of_quantized_pattern=True, + _is_output_of_quantized_pattern=quantization_config is not None, ) else: linear_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( @@ -708,9 +673,7 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: for operator_type, qconfig in self.operator_type_qconfig.items(): self._annotate_with_config( - model, - qconfig, - _get_operator_type_qconfig_filter(operator_type) + model, qconfig, _get_operator_type_qconfig_filter(operator_type) ) if self.global_config: @@ -720,6 +683,13 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: _get_global_config_filter(), ) + # After annotating the model with quantization configurations, we need to annotate the output of quantizable ops. + # For example, if we annotated maxpool2d to quantize its inputs, we need to quantize its output as well. + # So, we can fuse dq-operator-q into a quantized op. + # Refer to https://github.com/intel/intel-extension-for-pytorch/blob/ + # 90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_recipe.py#L487 + self._annotate_output_for_int8_in_int8_out_pattern_entry(model) + return model def _annotate_with_config( @@ -730,6 +700,7 @@ def _annotate_with_config( ) -> None: """Annotate the model with a quantization configuration. + # TODO update the note High-level description of quantization recipe for X86 Inductor Backend: Step 1: Apply quantization recipe for fusion patterns of conv/linear to enable int8 data type actively. Step 2: Propagate quantization annotation for patterns besides conv/linear. Go through the pattern in model @@ -752,13 +723,6 @@ def _annotate_with_config( self._annotate_propagation_quantizable_pattern_entry(model, config, filter_fn) - # Step3: For quantizable ops, such as maxpool2d, we need to quantize its output if it is quantized - # in inputs. So, we can fuse dq-operator-q into a quantized op. - # Refer to https://github.com/intel/intel-extension-for-pytorch/blob/ - # 90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_recipe.py#L487 - - self._annotate_output_for_int8_in_int8_out_pattern_entry(model) - def _annotate_qat_conv2d_fusion_pattern( self, model: torch.fx.GraphModule, config: QuantizationConfig, filter_fn ): @@ -831,7 +795,7 @@ def _annotate_qat_conv2d_bn_binary_unary( # TODO Remove the annotate of output in QAT when qat util support pattern matcher. output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] _annotated=True, - _is_output_of_quantized_pattern=True, + _is_output_of_quantized_pattern=quantization_config is not None, ) nodes_to_mark_annotated = list(conv_partition.nodes) nodes_to_mark_annotated.extend(list(bn_partition.nodes)) @@ -891,7 +855,7 @@ def _annotate_qat_conv2d_bn_binary( # TODO Remove the annotate of output in QAT when qat util support pattern matcher. output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] _annotated=True, - _is_output_of_quantized_pattern=True, + _is_output_of_quantized_pattern=quantization_config is not None, ) nodes_to_mark_annotated = list(conv_partition.nodes) nodes_to_mark_annotated.extend(list(bn_partition.nodes)) @@ -942,7 +906,7 @@ def _annotate_qat_conv2d_bn_unary( # TODO Remove the annotate of output in QAT when qat util support pattern matcher. output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] _annotated=True, - _is_output_of_quantized_pattern=True, + _is_output_of_quantized_pattern=quantization_config is not None, ) nodes_to_mark_annotated = list(conv_partition.nodes) nodes_to_mark_annotated.extend(list(bn_partition.nodes)) @@ -980,7 +944,7 @@ def _annotate_qat_conv2d_bn( # TODO Remove the annotate of output in QAT when qat util support pattern matcher. output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] _annotated=True, - _is_output_of_quantized_pattern=True, + _is_output_of_quantized_pattern=quantization_config is not None, ) nodes_to_mark_annotated = list(conv_partition.nodes) nodes_to_mark_annotated.extend(list(bn_partition.nodes)) @@ -1024,7 +988,7 @@ def _annotate_matmul( matmul_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( input_qspec_map=input_qspec_map, _annotated=True, - _is_output_of_quantized_pattern=True, + _is_output_of_quantized_pattern=quantization_config is not None, ) def _annotate_conv2d_binary_unary( @@ -1072,7 +1036,7 @@ def _annotate_conv2d_binary_unary( ) unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( _annotated=True, - _is_output_of_quantized_pattern=True, + _is_output_of_quantized_pattern=quantization_config is not None, ) def _annotate_conv2d_binary( @@ -1118,7 +1082,7 @@ def _annotate_conv2d_binary( binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( input_qspec_map=binary_node_input_qspec_map, _annotated=True, - _is_output_of_quantized_pattern=True, + _is_output_of_quantized_pattern=quantization_config is not None, ) def _annotate_conv2d_unary( @@ -1156,7 +1120,7 @@ def _annotate_conv2d_unary( self._annotate_conv_node_helper(conv_node, False, quantization_config) unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( _annotated=True, - _is_output_of_quantized_pattern=True, + _is_output_of_quantized_pattern=quantization_config is not None, ) def _annotate_conv2d( @@ -1186,7 +1150,7 @@ def _annotate_conv2d( def _annotate_maxpool2d( self, node: Node, - quantization_config: QuantizationConfig, + quantization_config: Optional[QuantizationConfig], ) -> None: if node.target is not torch.ops.aten.max_pool2d.default: return @@ -1205,7 +1169,7 @@ def _annotate_maxpool2d( maxpool_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( input_qspec_map=input_qspec_map, _annotated=True, - _is_output_of_quantized_pattern=True, + _is_output_of_quantized_pattern=quantization_config is not None, ) def _annotate_cat( @@ -1232,7 +1196,7 @@ def _annotate_cat( cat_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( input_qspec_map=input_qspec_map, _annotated=True, - _is_output_of_quantized_pattern=True, + _is_output_of_quantized_pattern=quantization_config is not None, ) def _annotate_propagation_quantizable_pattern_entry( @@ -1264,11 +1228,20 @@ def is_all_inputs_connected_to_quantized_op(input_nodes): return if node.target is torch.ops.aten.max_pool2d.default: - # Recipe of maxpool2d: check input arg[0] of maxpool2d is quantized or not - input_nodes_to_check = [node.all_input_nodes[0]] - if not is_all_inputs_connected_to_quantized_op(input_nodes_to_check): - return - self._annotate_maxpool2d(node, quantization_config) + if quantization_config is None: + # TODO: + # If quantization_config is None, we mark the `_annotated` as True with a empty `input_qspec_map`. + # Handle the `cat` and other propagation patterns as well. + self._annotate_maxpool2d(node, quantization_config) + else: + # Recipe of maxpool2d: check input arg[0] of maxpool2d is quantized or not + input_nodes_to_check = [node.all_input_nodes[0]] + if not is_all_inputs_connected_to_quantized_op( + input_nodes_to_check + ): + return + + self._annotate_maxpool2d(node, quantization_config) return elif node.target is torch.ops.aten.cat.default: input_nodes_to_check = node.all_input_nodes @@ -1288,7 +1261,7 @@ def is_all_inputs_connected_to_quantized_op(input_nodes): node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( input_qspec_map=input_qspec_map, _annotated=True, - _is_output_of_quantized_pattern=True, + _is_output_of_quantized_pattern=quantization_config is not None, ) return @@ -1419,7 +1392,7 @@ def _annotate_linear_unary( self._annotate_linear_node_helper(linear_node, False, quantization_config) unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( _annotated=True, - _is_output_of_quantized_pattern=True, + _is_output_of_quantized_pattern=quantization_config is not None, ) def _annotate_linear_binary_unary( @@ -1495,14 +1468,15 @@ def _annotate_linear_binary_unary( ] = _X86InductorQuantizationAnnotation( input_qspec_map={}, _annotated=True, - _is_output_of_quantized_pattern=(not has_unary), + _is_output_of_quantized_pattern=quantization_config is not None + and (not has_unary), ) if unary_node is not None: unary_node.meta[ QUANT_ANNOTATION_KEY ] = _X86InductorQuantizationAnnotation( _annotated=True, - _is_output_of_quantized_pattern=True, + _is_output_of_quantized_pattern=quantization_config is not None, ) def validate(self, model: torch.fx.GraphModule) -> None: From 72880b5b2e179a8fc1a64d972f8a135adf9fc1c6 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Mon, 27 May 2024 17:16:23 +0800 Subject: [PATCH 029/706] refine code Signed-off-by: yiliu30 --- .../quantizer/x86_inductor_quantizer.py | 254 +++++++++++------- 1 file changed, 161 insertions(+), 93 deletions(-) diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index c24b2ad4f8f4..3a37bbdbcd74 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -14,6 +14,7 @@ Set, Tuple, TYPE_CHECKING, + Union, ) import torch @@ -157,28 +158,27 @@ def module_name_filter(nodes: List[Node]) -> bool: return module_name_filter -def _get_operator_type_qconfig_filter(operator_type: Callable): - # TODO: Currently, we pass a list of nodes determined by the annotate function, but only one op is the target op. - # For example, [relu, conv] - # 1) Only pass the anchor node? - # 2) If one node is annotated, the rest of the nodes are annotated? +def _get_operator_type_qconfig_filter( + operator_type: Callable, +) -> Callable[[List[Node]], bool]: def operator_type_qconfig_filter(nodes: List[Node]): - # Return True, if the first node has the certain operator type - has_certain_operator_type = nodes[-1].target == operator_type - return has_certain_operator_type + num_nodes_with_operator_type = sum( + node.target == operator_type for node in nodes + ) + return num_nodes_with_operator_type == 1 return operator_type_qconfig_filter -def _get_global_config_filter(): - def global_config_filter(nodes: List[Node]): - # TODO: double-check - # Have the same issue as `_get_operator_type_qconfig_filter` - # Return True if any node belongs to the `default_quantizable_ops` - nodes_is_default = any(node.target in default_quantizable_ops for node in nodes) - return nodes_is_default - - return global_config_filter +def _global_config_filter(nodes: List[Node]) -> bool: + num_nodes_in_default_quantizable_ops = sum( + node.target in default_quantizable_ops for node in nodes + ) + if num_nodes_in_default_quantizable_ops > 1: + raise NotImplementedError( + "Multiple nodes in on pattern belong to default quantizable ops." + ) + return num_nodes_in_default_quantizable_ops == 1 def _map_module_function_to_aten_operator_type(): @@ -262,14 +262,6 @@ def _is_quantized_op_pt2e(node: torch.fx.Node): return False quantization_annotation = node.meta.get(QUANT_ANNOTATION_KEY, None) assert isinstance(quantization_annotation, _X86InductorQuantizationAnnotation) - # TODO: - # Conv - # | - # Relu <- is annotated by `None` quantization config, `_is_output_of_quantized_pattern` is marked to False - # | - # Maxpool - # | - # return quantization_annotation._is_output_of_quantized_pattern @@ -381,6 +373,15 @@ def _get_supported_config_and_operators() -> List[OperatorConfig]: return _get_supported_x86_inductor_config_and_operators() +def _annotate_nodes_not_quantize(nodes: Union[Node, List[Node]]) -> None: + if not isinstance(nodes, list): + nodes = [nodes] + for node in nodes: + node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + _annotated=True + ) + + def config_checker(method: Callable) -> Callable: @functools.wraps(method) def wrapper( @@ -550,6 +551,9 @@ def _annotate_conv_node_helper( quantization_config: QuantizationConfig, ) -> None: """Helper function to annotate the conv node""" + if quantization_config is None: + _annotate_nodes_not_quantize(conv_node) + return input_qspec_map = {} input_node = conv_node.args[0] assert isinstance(input_node, Node) @@ -564,7 +568,7 @@ def _annotate_conv_node_helper( conv_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( input_qspec_map=input_qspec_map, _annotated=True, - _is_output_of_quantized_pattern=quantization_config is not None, + _is_output_of_quantized_pattern=True, ) else: conv_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( @@ -579,6 +583,9 @@ def _annotate_linear_node_helper( quantization_config: QuantizationConfig, ) -> None: """Helper function to annotate the linear node""" + if quantization_config is None: + _annotate_nodes_not_quantize(linear_node) + return input_qspec_map = {} assert linear_node.target in (torch.ops.aten.linear.default,) has_bias = len(linear_node.args) == 3 @@ -602,7 +609,7 @@ def _annotate_linear_node_helper( linear_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( input_qspec_map=input_qspec_map, _annotated=True, - _is_output_of_quantized_pattern=quantization_config is not None, + _is_output_of_quantized_pattern=True, ) else: linear_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( @@ -680,7 +687,7 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: self._annotate_with_config( model, self.global_config, - _get_global_config_filter(), + _global_config_filter, ) # After annotating the model with quantization configurations, we need to annotate the output of quantizable ops. @@ -700,7 +707,7 @@ def _annotate_with_config( ) -> None: """Annotate the model with a quantization configuration. - # TODO update the note + # TODO update the note High-level description of quantization recipe for X86 Inductor Backend: Step 1: Apply quantization recipe for fusion patterns of conv/linear to enable int8 data type actively. Step 2: Propagate quantization annotation for patterns besides conv/linear. Go through the pattern in model @@ -783,20 +790,27 @@ def _annotate_qat_conv2d_bn_binary_unary( self._annotate_conv_node_helper(conv_node, False, quantization_config) - binary_node_input_qspec_map = {} - binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec( - quantization_config - ) - binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( - input_qspec_map=binary_node_input_qspec_map, - _annotated=True, - ) - unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( - # TODO Remove the annotate of output in QAT when qat util support pattern matcher. - output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] - _annotated=True, - _is_output_of_quantized_pattern=quantization_config is not None, - ) + if quantization_config is not None: + binary_node_input_qspec_map = {} + binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec( + quantization_config + ) + binary_node.meta[ + QUANT_ANNOTATION_KEY + ] = _X86InductorQuantizationAnnotation( + input_qspec_map=binary_node_input_qspec_map, + _annotated=True, + ) + unary_node.meta[ + QUANT_ANNOTATION_KEY + ] = _X86InductorQuantizationAnnotation( + # TODO Remove the annotate of output in QAT when qat util support pattern matcher. + output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + else: + _annotate_nodes_not_quantize([binary_node, unary_node]) nodes_to_mark_annotated = list(conv_partition.nodes) nodes_to_mark_annotated.extend(list(bn_partition.nodes)) nodes_to_mark_annotated.extend(list(binary_partition.nodes)) @@ -846,17 +860,22 @@ def _annotate_qat_conv2d_bn_binary( self._annotate_conv_node_helper(conv_node, False, quantization_config) - binary_node_input_qspec_map = {} - binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec( - quantization_config - ) - binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( - input_qspec_map=binary_node_input_qspec_map, - # TODO Remove the annotate of output in QAT when qat util support pattern matcher. - output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] - _annotated=True, - _is_output_of_quantized_pattern=quantization_config is not None, - ) + if quantization_config is not None: + binary_node_input_qspec_map = {} + binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec( + quantization_config + ) + binary_node.meta[ + QUANT_ANNOTATION_KEY + ] = _X86InductorQuantizationAnnotation( + input_qspec_map=binary_node_input_qspec_map, + # TODO Remove the annotate of output in QAT when qat util support pattern matcher. + output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + else: + _annotate_nodes_not_quantize(binary_node) nodes_to_mark_annotated = list(conv_partition.nodes) nodes_to_mark_annotated.extend(list(bn_partition.nodes)) nodes_to_mark_annotated.extend(list(binary_partition.nodes)) @@ -902,12 +921,17 @@ def _annotate_qat_conv2d_bn_unary( continue self._annotate_conv_node_helper(conv_node, False, quantization_config) - unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( - # TODO Remove the annotate of output in QAT when qat util support pattern matcher. - output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] - _annotated=True, - _is_output_of_quantized_pattern=quantization_config is not None, - ) + if quantization_config is not None: + unary_node.meta[ + QUANT_ANNOTATION_KEY + ] = _X86InductorQuantizationAnnotation( + # TODO Remove the annotate of output in QAT when qat util support pattern matcher. + output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + else: + _annotate_nodes_not_quantize(unary_node) nodes_to_mark_annotated = list(conv_partition.nodes) nodes_to_mark_annotated.extend(list(bn_partition.nodes)) nodes_to_mark_annotated.extend(list(unary_partition.nodes)) @@ -938,14 +962,17 @@ def _annotate_qat_conv2d_bn( continue self._annotate_conv_node_helper(conv_node, False, quantization_config) - bn_output_node.meta[ - QUANT_ANNOTATION_KEY - ] = _X86InductorQuantizationAnnotation( - # TODO Remove the annotate of output in QAT when qat util support pattern matcher. - output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] - _annotated=True, - _is_output_of_quantized_pattern=quantization_config is not None, - ) + if quantization_config is not None: + bn_output_node.meta[ + QUANT_ANNOTATION_KEY + ] = _X86InductorQuantizationAnnotation( + # TODO Remove the annotate of output in QAT when qat util support pattern matcher. + output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + else: + _annotate_nodes_not_quantize(bn_output_node) nodes_to_mark_annotated = list(conv_partition.nodes) nodes_to_mark_annotated.extend(list(bn_partition.nodes)) _mark_nodes_as_annotated(nodes_to_mark_annotated) @@ -981,6 +1008,11 @@ def _annotate_matmul( continue if _skip_annotate([node], filter_fn): continue + + if quantization_config is None: + _annotate_nodes_not_quantize(node) + continue + input_qspec_map = {} matmul_node = node for input_node in matmul_node.args: @@ -988,7 +1020,7 @@ def _annotate_matmul( matmul_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( input_qspec_map=input_qspec_map, _annotated=True, - _is_output_of_quantized_pattern=quantization_config is not None, + _is_output_of_quantized_pattern=True, ) def _annotate_conv2d_binary_unary( @@ -1025,6 +1057,11 @@ def _annotate_conv2d_binary_unary( continue if _skip_annotate([unary_node, binary_node, conv_node], filter_fn): continue + + if quantization_config is None: + _annotate_nodes_not_quantize([conv_node, binary_node, unary_node]) + continue + self._annotate_conv_node_helper(conv_node, False, quantization_config) binary_node_input_qspec_map = {} binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec( @@ -1036,7 +1073,7 @@ def _annotate_conv2d_binary_unary( ) unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( _annotated=True, - _is_output_of_quantized_pattern=quantization_config is not None, + _is_output_of_quantized_pattern=True, ) def _annotate_conv2d_binary( @@ -1074,6 +1111,11 @@ def _annotate_conv2d_binary( continue if _skip_annotate([binary_node, conv_node], filter_fn): continue + + if quantization_config is None: + _annotate_nodes_not_quantize([conv_node, binary_node]) + continue + self._annotate_conv_node_helper(conv_node, False, quantization_config) binary_node_input_qspec_map = {} binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec( @@ -1082,7 +1124,7 @@ def _annotate_conv2d_binary( binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( input_qspec_map=binary_node_input_qspec_map, _annotated=True, - _is_output_of_quantized_pattern=quantization_config is not None, + _is_output_of_quantized_pattern=True, ) def _annotate_conv2d_unary( @@ -1117,10 +1159,15 @@ def _annotate_conv2d_unary( continue if _skip_annotate([unary_node, conv_node], filter_fn): continue + + if quantization_config is None: + _annotate_nodes_not_quantize([conv_node, unary_node]) + continue + self._annotate_conv_node_helper(conv_node, False, quantization_config) unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( _annotated=True, - _is_output_of_quantized_pattern=quantization_config is not None, + _is_output_of_quantized_pattern=True, ) def _annotate_conv2d( @@ -1154,6 +1201,12 @@ def _annotate_maxpool2d( ) -> None: if node.target is not torch.ops.aten.max_pool2d.default: return + if quantization_config is None: + node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + _annotated=True + ) + return + maxpool_node = node if _is_any_annotated( [ @@ -1169,12 +1222,17 @@ def _annotate_maxpool2d( maxpool_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( input_qspec_map=input_qspec_map, _annotated=True, - _is_output_of_quantized_pattern=quantization_config is not None, + _is_output_of_quantized_pattern=True, ) def _annotate_cat( self, node: Node, quantization_config: QuantizationConfig ) -> None: + if quantization_config is None: + node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + _annotated=True + ) + return cat_node = node input_nodes = cat_node.args[0] assert isinstance(input_nodes, Sequence) @@ -1196,7 +1254,7 @@ def _annotate_cat( cat_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( input_qspec_map=input_qspec_map, _annotated=True, - _is_output_of_quantized_pattern=quantization_config is not None, + _is_output_of_quantized_pattern=True, ) def _annotate_propagation_quantizable_pattern_entry( @@ -1227,21 +1285,23 @@ def is_all_inputs_connected_to_quantized_op(input_nodes): if _skip_annotate([node], filter_fn): return + # For `global_config` is not None but `max_pool2d` is `None`, we annotate + # the `max_pool2d` with `_X86InductorQuantizationAnnotation(_annotated=True)`. + if quantization_config is None: + _annotate_nodes_not_quantize(node) + return + if node.target is torch.ops.aten.max_pool2d.default: - if quantization_config is None: - # TODO: - # If quantization_config is None, we mark the `_annotated` as True with a empty `input_qspec_map`. - # Handle the `cat` and other propagation patterns as well. - self._annotate_maxpool2d(node, quantization_config) - else: - # Recipe of maxpool2d: check input arg[0] of maxpool2d is quantized or not - input_nodes_to_check = [node.all_input_nodes[0]] - if not is_all_inputs_connected_to_quantized_op( - input_nodes_to_check - ): - return - - self._annotate_maxpool2d(node, quantization_config) + # Recipe of maxpool2d: check input arg[0] of maxpool2d is quantized or not + input_nodes_to_check = [node.all_input_nodes[0]] + if not is_all_inputs_connected_to_quantized_op(input_nodes_to_check): + if quantization_config is not None: + warnings.warn( + f"The input of maxpool2d is not quantized, skip annotate maxpool2d with config {quantization_config}." + ) + return + + self._annotate_maxpool2d(node, quantization_config) return elif node.target is torch.ops.aten.cat.default: input_nodes_to_check = node.all_input_nodes @@ -1261,7 +1321,7 @@ def is_all_inputs_connected_to_quantized_op(input_nodes): node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( input_qspec_map=input_qspec_map, _annotated=True, - _is_output_of_quantized_pattern=quantization_config is not None, + _is_output_of_quantized_pattern=True, ) return @@ -1389,10 +1449,15 @@ def _annotate_linear_unary( continue if _skip_annotate([unary_node, linear_node], filter_fn): continue + + if quantization_config is None: + _annotate_nodes_not_quantize([linear_node, unary_node]) + continue + self._annotate_linear_node_helper(linear_node, False, quantization_config) unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( _annotated=True, - _is_output_of_quantized_pattern=quantization_config is not None, + _is_output_of_quantized_pattern=True, ) def _annotate_linear_binary_unary( @@ -1459,6 +1524,10 @@ def _annotate_linear_binary_unary( if _skip_annotate(node_list, filter_fn): continue + if quantization_config is None: + _annotate_nodes_not_quantize(node_list) + continue + self._annotate_linear_node_helper( linear_node, False, quantization_config ) @@ -1468,15 +1537,14 @@ def _annotate_linear_binary_unary( ] = _X86InductorQuantizationAnnotation( input_qspec_map={}, _annotated=True, - _is_output_of_quantized_pattern=quantization_config is not None - and (not has_unary), + _is_output_of_quantized_pattern=(not has_unary), ) if unary_node is not None: unary_node.meta[ QUANT_ANNOTATION_KEY ] = _X86InductorQuantizationAnnotation( _annotated=True, - _is_output_of_quantized_pattern=quantization_config is not None, + _is_output_of_quantized_pattern=True, ) def validate(self, model: torch.fx.GraphModule) -> None: From 03836767ae41cfd7df7cfaba01b9a083f6fc7f9b Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Tue, 28 May 2024 08:27:26 +0800 Subject: [PATCH 030/706] refine docstring Signed-off-by: yiliu30 --- torch/ao/quantization/quantizer/utils.py | 34 +++++ .../quantizer/x86_inductor_quantizer.py | 118 +++++++++--------- .../quantizer/xnnpack_quantizer.py | 35 +----- 3 files changed, 92 insertions(+), 95 deletions(-) diff --git a/torch/ao/quantization/quantizer/utils.py b/torch/ao/quantization/quantizer/utils.py index f25d0916018b..77cfc22d73be 100644 --- a/torch/ao/quantization/quantizer/utils.py +++ b/torch/ao/quantization/quantizer/utils.py @@ -47,3 +47,37 @@ def _node_only_used_for_sym_size(node: Node, partition_nodes: List[Node]): ((user not in partition_nodes) or _is_sym_size_node(user)) for user in node.users ) + + +def _get_module_name_filter(module_name: str): + """Get the module_name_filter function for a given module name, the filter accepts + a node and checks if the node comes from a module that has certain module name + + For example: + node: linear_op = call_function[...](...) # comes from a module with name blocks.sub.linear1 + + + >> module_name_filter = _get_module_name_filter("blocks.sub") + >> print(module_name_filter(node)) + True # the node is from "blocks.sub" based on the fully qualified name "blocks.sub.linear1" + """ + + def module_name_filter(n: Node) -> bool: + # example: { + # 'L__self___sub': ("L['self'].sub", ), + # 'L__self___sub_linear': ("L['self'].sub.linear", ) + # } + # get_attr nodes doesn't have nn_module_stack? + nn_module_stack = n.meta.get("nn_module_stack", {}) + + def _normalize_path(n): + prefix = 0 + # TODO This is non standard behavior and should be removed when we migrate off capture_pre_autograd_graph. + if n.startswith("L['self']."): + prefix = len("L['self'].") + return n[prefix:] + + names = [_normalize_path(n) for n, _ in nn_module_stack.values()] + return module_name in names + + return module_name_filter diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index 3a37bbdbcd74..fdb5490404e3 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -38,7 +38,7 @@ SharedQuantizationSpec, ) -# from torch.ao.quantization.quantizer.utils import _get_module_name_filter +from torch.ao.quantization.quantizer.utils import _get_module_name_filter from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( get_bias_qspec, get_input_act_qspec, @@ -107,76 +107,82 @@ class _X86InductorQuantizationAnnotation(QuantizationAnnotation): def _skip_annotate( nodes: List[Node], filter_fn: Optional[Callable[[List[Node]], bool]] = None ): + """Determines whether to skip annotation for a list of nodes.""" + # 1) Skip annotate if any node is already annotated if _is_any_annotated(nodes): return True - # 2) Not skip annotate if a) filter_fn is provided and b) any node passed the filter + # 2) Proceed with annotation if a) a filter function is provided + # and b) the given nodes list passes the filter function check. if filter_fn and filter_fn(nodes): return False return True -def _get_module_name_filter(module_name: str) -> Callable[[List[Node]], bool]: - """Get a filter function for a given module name. +def _create_module_name_filter(module_name: str) -> Callable[[List[Node]], bool]: + """Create a filter function for a given module name. - The filter function that takes a list of nodes (as determined by the annotate function) - and returns True if **all** nodes come from the specified module name, and False otherwise. + The filter function takes a list of nodes (as determined by the annotate function) + and return True if *all* nodes come from the specified module name, False otherwise. For example: linear_1: "f32[3, 10]" = torch.ops.aten.linear.default(...) # comes from a module with name `sub.linear1` relu: "f32[3, 10]" = torch.ops.aten.relu.default(linear_1); # comes from a module with name `sub.relu1` - >> module_name_filter = _get_module_name_filter("sub") + >> module_name_filter = _create_module_name_filter_inner("sub") >> print(module_name_filter([relu, linear_1])) # True # These two nodes are from "sub" and determined by `_annotate_linear_unary` function. """ - def _node_filter(n: Node) -> bool: - # example: { - # 'L__self___sub': ("L['self'].sub", ), - # 'L__self___sub_linear': ("L['self'].sub.linear", ) - # } - # get_attr nodes doesn't have nn_module_stack? - nn_module_stack = n.meta.get("nn_module_stack", {}) - - def _normalize_path(n): - prefix = 0 - # TODO This is non standard behavior and should be removed when we migrate off capture_pre_autograd_graph. - if n.startswith("L['self']."): - prefix = len("L['self'].") - return n[prefix:] - - names = [_normalize_path(n) for n, _ in nn_module_stack.values()] - return module_name in names - - def module_name_filter(nodes: List[Node]) -> bool: - all_nodes_from_module_name: bool = all(_node_filter(n) for n in nodes) + filter_fn = _get_module_name_filter(module_name) + + def check_all_nodes_from_module(nodes: List[Node]) -> bool: + all_nodes_from_module_name: bool = all(filter_fn(n) for n in nodes) return all_nodes_from_module_name - return module_name_filter + return check_all_nodes_from_module -def _get_operator_type_qconfig_filter( +def _create_operator_type_filter( operator_type: Callable, ) -> Callable[[List[Node]], bool]: - def operator_type_qconfig_filter(nodes: List[Node]): + """Create a filter function for a given operator type. + + The filter function takes a list of nodes and returns True if it contains + exactly one node with the specified operator type, False otherwise. + + For example: + linear_1: "f32[3, 10]" = torch.ops.aten.linear.default(...) # comes from a module with name `sub.linear1` + relu: "f32[3, 10]" = torch.ops.aten.relu.default(linear_1); # comes from a module with name `sub.relu1` + + >> operator_type_filter = _create_operator_type_filter(torch.ops.aten.linear.default) + >> print(operator_type_filter([relu, linear_1])) + # True # These two nodes are determined by `_annotate_linear_unary` function and the second node is `linear`. + """ + + def operator_type_filter(nodes: List[Node]): num_nodes_with_operator_type = sum( node.target == operator_type for node in nodes ) return num_nodes_with_operator_type == 1 - return operator_type_qconfig_filter + return operator_type_filter def _global_config_filter(nodes: List[Node]) -> bool: + """Filter function for global configuration. + + This filter function takes a list of nodes and returns True if there is exactly one node + in the list that is a default quantizable operation, False otherwise. + """ num_nodes_in_default_quantizable_ops = sum( node.target in default_quantizable_ops for node in nodes ) if num_nodes_in_default_quantizable_ops > 1: raise NotImplementedError( - "Multiple nodes in on pattern belong to default quantizable ops." + "Several nodes within a single pattern are default quantizable operations." ) return num_nodes_in_default_quantizable_ops == 1 @@ -438,10 +444,10 @@ def get_supported_operator_for_quantization_config( def _need_skip_config( self, quantization_config: Optional[QuantizationConfig] ) -> bool: - """Check if the given quantization config is valid for X86InductorQuantizer. + """Check if the provided quantization config is valid for X86InductorQuantizer. Note: Mixed static/dynamic configurations or mixed QAT/non-QAT configurations are not supported. - If such a mix is detected, the configuration will be skipped. + If such a mix is detected, the configuration will be marked for skipping.. """ need_skip = False if quantization_config is None: @@ -659,28 +665,27 @@ def _get_input_idx_for_binary_node( return conv_gemm_node_idx, extra_input_node_idx def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: - """Annotate the model with quantization configurations. + """Annotate the given model with quantization configurations. - Note: - 1. Annotate each node according to the users's qconfig in the following order: + Annotation contracts: + 1. Annotate each node according to the user's qconfig in the following order: `module_name_qconfig`, `operator_type_qconfig`, and `global_config`. - 2. Skip nodes that have already been annotated by an earlier stage. For example, - if `linear1` has been annotated during in the `module_name_config` stage, it will - not be re-annotated in the `operator_type_qconfig` or `global_config` stages. - 3. For the config is `None`, the annotation will be skipped. + 2. Skip nodes that have already been annotated in earlier stage. For example, + if `linear1` has been annotated during in the `module_name_config` stage, + it won't be re-annotated in the 'operator_type_qconfig' or 'global_config' stages. + 3. For config is `None`, the node will be annotated with `_X86InductorQuantizationAnnotation(_annotated=True)`. For each pair of (module_name_or_operator_type_or_global, qconfig), a filter function is created. - This filter function checks if the node is marked by current stage and not marked by previous stage. + This filter function checks if the node is marked by current stage and not annotated by the previous stage. """ - for module_name, qconfig in self.module_name_qconfig.items(): self._annotate_with_config( - model, qconfig, _get_module_name_filter(module_name) + model, qconfig, _create_module_name_filter(module_name) ) for operator_type, qconfig in self.operator_type_qconfig.items(): self._annotate_with_config( - model, qconfig, _get_operator_type_qconfig_filter(operator_type) + model, qconfig, _create_operator_type_filter(operator_type) ) if self.global_config: @@ -690,11 +695,12 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: _global_config_filter, ) - # After annotating the model with quantization configurations, we need to annotate the output of quantizable ops. - # For example, if we annotated maxpool2d to quantize its inputs, we need to quantize its output as well. - # So, we can fuse dq-operator-q into a quantized op. + # Once we've annotated the model with quantization configurations, we also need to annotate + # the output of quantizable operations. For example, if we annotated `maxpool2d` to quantize its inputs, + # we will quantize its output accordingly. This enables us to fuse the dq-operator-q into a quantized op. # Refer to https://github.com/intel/intel-extension-for-pytorch/blob/ # 90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_recipe.py#L487 + self._annotate_output_for_int8_in_int8_out_pattern_entry(model) return model @@ -705,17 +711,13 @@ def _annotate_with_config( config: Optional[QuantizationConfig], filter_fn: Callable, ) -> None: - """Annotate the model with a quantization configuration. + """Annotate the model with the given quantization configuration. - # TODO update the note High-level description of quantization recipe for X86 Inductor Backend: Step 1: Apply quantization recipe for fusion patterns of conv/linear to enable int8 data type actively. Step 2: Propagate quantization annotation for patterns besides conv/linear. Go through the pattern in model from start to the end. If a pattern supports computation with int8 data type and inputs connected to quantized patterns, annotate its inputs as quantized pattern. - Step 3: Since in step 2, we only annotate the inputs of quantized pattern. For some quantized patterns, - such as maxpool2d, which only supports output with int8 data type when the input is with int8 data type, - we need to annotate the output of this pattern. """ # Step1: Recipe of fusion patterns like conv/linear. @@ -1202,9 +1204,7 @@ def _annotate_maxpool2d( if node.target is not torch.ops.aten.max_pool2d.default: return if quantization_config is None: - node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( - _annotated=True - ) + _annotate_nodes_not_quantize(node) return maxpool_node = node @@ -1229,9 +1229,7 @@ def _annotate_cat( self, node: Node, quantization_config: QuantizationConfig ) -> None: if quantization_config is None: - node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( - _annotated=True - ) + _annotate_nodes_not_quantize(node) return cat_node = node input_nodes = cat_node.args[0] @@ -1285,8 +1283,6 @@ def is_all_inputs_connected_to_quantized_op(input_nodes): if _skip_annotate([node], filter_fn): return - # For `global_config` is not None but `max_pool2d` is `None`, we annotate - # the `max_pool2d` with `_X86InductorQuantizationAnnotation(_annotated=True)`. if quantization_config is None: _annotate_nodes_not_quantize(node) return diff --git a/torch/ao/quantization/quantizer/xnnpack_quantizer.py b/torch/ao/quantization/quantizer/xnnpack_quantizer.py index f3d1b6ca8b39..e13a79f39267 100644 --- a/torch/ao/quantization/quantizer/xnnpack_quantizer.py +++ b/torch/ao/quantization/quantizer/xnnpack_quantizer.py @@ -22,6 +22,7 @@ ) from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer +from torch.ao.quantization.quantizer.utils import _get_module_name_filter from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( _convert_scalars_to_attrs, @@ -192,40 +193,6 @@ def _get_supported_config_and_operators() -> List[OperatorConfig]: return _get_supported_symmetric_config_and_operators() -def _get_module_name_filter(module_name: str): - """Get the module_name_filter function for a given module name, the filter accepts - a node and checks if the node comes from a module that has certain module name - - For example: - node: linear_op = call_function[...](...) # comes from a module with name blocks.sub.linear1 - - - >> module_name_filter = _get_module_name_filter("blocks.sub") - >> print(module_name_filter(node)) - True # the node is from "blocks.sub" based on the fully qualified name "blocks.sub.linear1" - """ - - def module_name_filter(n: Node) -> bool: - # example: { - # 'L__self___sub': ("L['self'].sub", ), - # 'L__self___sub_linear': ("L['self'].sub.linear", ) - # } - # get_attr nodes doesn't have nn_module_stack? - nn_module_stack = n.meta.get("nn_module_stack", {}) - - def _normalize_path(n): - prefix = 0 - # TODO This is non standard behavior and should be removed when we migrate off capture_pre_autograd_graph. - if n.startswith("L['self']."): - prefix = len("L['self'].") - return n[prefix:] - - names = [_normalize_path(n) for n, _ in nn_module_stack.values()] - return module_name in names - - return module_name_filter - - def _get_module_type_filter(tp: Callable): """Get the module_type_filter function for a given module type, the filter accepts a node and checks if the node comes from a module that has certain module type From c8ad6e8a57ef5526fab4e46a953d51514fef5761 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Tue, 28 May 2024 15:07:42 +0800 Subject: [PATCH 031/706] refine the UTs Signed-off-by: yiliu30 --- .../pt2e/test_x86inductor_quantizer.py | 22 +++++++------- .../quantizer/x86_inductor_quantizer.py | 30 +++++++++---------- 2 files changed, 26 insertions(+), 26 deletions(-) diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py index 54c682ceba1f..467aa22d7835 100644 --- a/test/quantization/pt2e/test_x86inductor_quantizer.py +++ b/test/quantization/pt2e/test_x86inductor_quantizer.py @@ -1893,7 +1893,7 @@ def forward(self, x): m = M().eval() example_inputs = (torch.randn(3, 5),) - # Set global to no quantization and then default config for a specific submodule. + # Set global to `None` and then default config for a specific submodule. quantizer = X86InductorQuantizer() quantizer.set_module_name_qconfig( "sub", xiq.get_default_x86_inductor_quantization_config() @@ -1946,17 +1946,17 @@ def forward(self, x): m = M().eval() example_inputs = (torch.randn(3, 5),) - # Set global to no quantization and then default config for a specific submodule. + # Set global to `None` and then default config for a specific submodule. quantizer = X86InductorQuantizer() quantizer.set_module_name_qconfig( "sub", xiq.get_default_x86_inductor_quantization_config() ) node_occurrence = { torch.ops.aten.linear.default: 3, - # quantize and dequantize the input of the two linear layers from `sub` + # quantize and dequantize the input of two linear layers from `sub` torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, - # dequantize the weight of the two linear layers from `sub` + # dequantize the weight of two linear layers from `sub` torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, } # node_list = None @@ -2059,7 +2059,7 @@ def forward(self, x): m = M().eval() example_inputs = (torch.randn(3, 5),) - # Set `sub` with default config and then no quantization for all `Linear`. + # Set `sub` with default config and then `None` for all `Linear`. # The config set by `set_module_name_qconfig` has higher priority than `set_module_type_qconfig`. quantizer = X86InductorQuantizer() quantizer.set_module_name_qconfig( @@ -2160,7 +2160,7 @@ def test_set_module_name_qconfig_for_dynamic_quant(self): for is_qat in [False, True]: m = TestHelperModules.SelfAttnLikeModule(input_dim=64).eval() example_inputs = (torch.randn(1, 4, 64),) - # only quantize `self.q_proj` `self.v_proj` + # only quantize `q_proj` `v_proj` dynamic_config = xiq.get_default_x86_inductor_quantization_config( is_dynamic=True, is_qat=is_qat ) @@ -2206,7 +2206,7 @@ def test_set_module_name_qconfig_for_dynamic_quant(self): def test_set_module_name_with_mixed_static_and_dynamic(self): """Test that mixed static and dynamic quantization for a module. - Currently, mixed static and dynamic quantization is not supported. The subsequent config will be ignored. + Currently, mixed static/dynamic quantization is not supported. The subsequent config will be ignored. """ with override_quantized_engine("x86"), torch.no_grad(): @@ -2218,8 +2218,8 @@ def test_set_module_name_with_mixed_static_and_dynamic(self): dynamic_config = xiq.get_default_x86_inductor_quantization_config( is_dynamic=True ) - # set `self.v_proj` with static config - # set `self.q_proj` with dynamic config (will be skipped) + # set `v_proj` with static config + # set `q_proj` with dynamic config (will be skipped) quantizer = ( X86InductorQuantizer() .set_module_name_qconfig("q_proj", static_config) @@ -2229,7 +2229,7 @@ def test_set_module_name_with_mixed_static_and_dynamic(self): # quantize and dequantize the input torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, - # only q_proj be quantized, dequantize its weight + # only `q_proj`` was quantized, dequantize its weight torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, } node_list = [ @@ -2242,7 +2242,7 @@ def test_set_module_name_with_mixed_static_and_dynamic(self): torch.ops.aten.linear.default, # k_proj torch.ops.aten.linear.default, - # `v_proj`'s weight will not be quantized + # not quantize `v_proj`'s weight # torch.ops.quantized_decomposed.dequantize_per_channel.default, # v_proj torch.ops.aten.linear.default, diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index fdb5490404e3..3b40bcdab8ff 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -106,14 +106,14 @@ class _X86InductorQuantizationAnnotation(QuantizationAnnotation): def _skip_annotate( nodes: List[Node], filter_fn: Optional[Callable[[List[Node]], bool]] = None -): - """Determines whether to skip annotation for a list of nodes.""" +) -> bool: + """Determine whether to skip annotation for a list of nodes.""" # 1) Skip annotate if any node is already annotated if _is_any_annotated(nodes): return True - # 2) Proceed with annotation if a) a filter function is provided + # 2) Proceed annotate if a) a filter function is provided # and b) the given nodes list passes the filter function check. if filter_fn and filter_fn(nodes): return False @@ -133,7 +133,7 @@ def _create_module_name_filter(module_name: str) -> Callable[[List[Node]], bool] >> module_name_filter = _create_module_name_filter_inner("sub") >> print(module_name_filter([relu, linear_1])) - # True # These two nodes are from "sub" and determined by `_annotate_linear_unary` function. + # True # These two nodes are determined by `_annotate_linear_unary` function and from "sub". """ filter_fn = _get_module_name_filter(module_name) @@ -380,6 +380,7 @@ def _get_supported_config_and_operators() -> List[OperatorConfig]: def _annotate_nodes_not_quantize(nodes: Union[Node, List[Node]]) -> None: + """Annotate nodes to exclude them from quantization (their `quantization_config` is `None`).""" if not isinstance(nodes, list): nodes = [nodes] for node in nodes: @@ -391,16 +392,16 @@ def _annotate_nodes_not_quantize(nodes: Union[Node, List[Node]]) -> None: def config_checker(method: Callable) -> Callable: @functools.wraps(method) def wrapper( - self: "X86InductorQuantizer", + quantizer: "X86InductorQuantizer", name: Any, quantization_config: Optional["QuantizationConfig"], ) -> "X86InductorQuantizer": - if self._need_skip_config(quantization_config): + if quantizer._need_skip_config(quantization_config): warnings.warn( f"Skip the quantization config for {name}.", ) - return self - return method(self, name, quantization_config) + return quantizer + return method(quantizer, name, quantization_config) return wrapper @@ -411,7 +412,7 @@ class X86InductorQuantizer(Quantizer): def __init__(self): super().__init__() - self.global_config: Optional[QuantizationConfig] = None # type: ignore[assignment] + self.global_config: Optional[QuantizationConfig] = None self.operator_type_qconfig: Dict[ torch._ops.OpOverloadPacket, Optional[QuantizationConfig] ] = {} @@ -446,7 +447,7 @@ def _need_skip_config( ) -> bool: """Check if the provided quantization config is valid for X86InductorQuantizer. - Note: Mixed static/dynamic configurations or mixed QAT/non-QAT configurations are not supported. + Mixed static/dynamic configurations or mixed QAT/non-QAT configurations are not supported. If such a mix is detected, the configuration will be marked for skipping.. """ need_skip = False @@ -670,9 +671,9 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: Annotation contracts: 1. Annotate each node according to the user's qconfig in the following order: `module_name_qconfig`, `operator_type_qconfig`, and `global_config`. - 2. Skip nodes that have already been annotated in earlier stage. For example, - if `linear1` has been annotated during in the `module_name_config` stage, - it won't be re-annotated in the 'operator_type_qconfig' or 'global_config' stages. + 2. Avoid re-annotating nodes already annotated in prior stages. For example, + if `linear1` has been annotated by `module_name_qconfig`, it won't be annotated again + during the processing of the 'operator_type_qconfig' or 'global_config'. 3. For config is `None`, the node will be annotated with `_X86InductorQuantizationAnnotation(_annotated=True)`. For each pair of (module_name_or_operator_type_or_global, qconfig), a filter function is created. @@ -1365,8 +1366,7 @@ def _annotate_output_for_int8_in_int8_out_pattern( ] ): return - # Don't check the `filter_fn` here, as we want to annotate - # the output of the node that's being annotated. + # Get the quantization_annotation from getitem_node maxpool_node_quantization_annotation = ( maxpool_node.meta[QUANT_ANNOTATION_KEY] From 5845bb4d85d80a942721599b44fe4e6869d89200 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Wed, 29 May 2024 08:38:42 +0800 Subject: [PATCH 032/706] refine config checker Signed-off-by: yiliu30 --- .../pt2e/test_x86inductor_quantizer.py | 2 +- .../quantizer/x86_inductor_quantizer.py | 204 +++++++++++------- 2 files changed, 124 insertions(+), 82 deletions(-) diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py index 467aa22d7835..1d58bcc8fe84 100644 --- a/test/quantization/pt2e/test_x86inductor_quantizer.py +++ b/test/quantization/pt2e/test_x86inductor_quantizer.py @@ -2199,7 +2199,7 @@ def test_set_module_name_qconfig_for_dynamic_quant(self): quantizer, node_occurrence, node_list, - is_qat=True, + is_qat=is_qat, ) @skipIfNoX86 diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index 3b40bcdab8ff..a31022dc6e07 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -17,6 +17,8 @@ Union, ) +from typing_extensions import TypeAlias + import torch import torch.nn.functional as F from torch.ao.quantization.fake_quantize import ( @@ -54,6 +56,9 @@ SourcePartition, ) +FilterFn: TypeAlias = Callable[[List[Node]], bool] + + if TYPE_CHECKING: from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor @@ -104,9 +109,7 @@ class _X86InductorQuantizationAnnotation(QuantizationAnnotation): QUANT_ANNOTATION_KEY = "quantization_annotation" -def _skip_annotate( - nodes: List[Node], filter_fn: Optional[Callable[[List[Node]], bool]] = None -) -> bool: +def _skip_annotate(nodes: List[Node], filter_fn: Optional[FilterFn] = None) -> bool: """Determine whether to skip annotation for a list of nodes.""" # 1) Skip annotate if any node is already annotated @@ -121,7 +124,7 @@ def _skip_annotate( return True -def _create_module_name_filter(module_name: str) -> Callable[[List[Node]], bool]: +def _create_module_name_filter(module_name: str) -> FilterFn: """Create a filter function for a given module name. The filter function takes a list of nodes (as determined by the annotate function) @@ -147,7 +150,7 @@ def check_all_nodes_from_module(nodes: List[Node]) -> bool: def _create_operator_type_filter( operator_type: Callable, -) -> Callable[[List[Node]], bool]: +) -> FilterFn: """Create a filter function for a given operator type. The filter function takes a list of nodes and returns True if it contains @@ -389,7 +392,7 @@ def _annotate_nodes_not_quantize(nodes: Union[Node, List[Node]]) -> None: ) -def config_checker(method: Callable) -> Callable: +def _config_checker(method: Callable) -> Callable: @functools.wraps(method) def wrapper( quantizer: "X86InductorQuantizer", @@ -406,6 +409,12 @@ def wrapper( return wrapper +@dataclass +class _QuantizationMode: + is_qat: Optional[bool] + is_dynamic: Optional[bool] + + class X86InductorQuantizer(Quantizer): supported_config_and_operators = _get_supported_config_and_operators() module_function_to_aten_operator_type = _map_module_function_to_aten_operator_type() @@ -417,8 +426,6 @@ def __init__(self): torch._ops.OpOverloadPacket, Optional[QuantizationConfig] ] = {} self.module_name_qconfig: Dict[str, Optional[QuantizationConfig]] = {} - self._is_dynamic = None - self._is_qat = None @classmethod def get_supported_quantization_configs(cls) -> List[QuantizationConfig]: @@ -442,6 +449,23 @@ def get_supported_operator_for_quantization_config( return ops return [] + def _get_current_quantization_mode(self) -> _QuantizationMode: + """Retrieves the current quantization mode based on all configurations.""" + is_qat = None + is_dynamic = None + + for qconfig in ( + list(self.module_name_qconfig.values()) + + list(self.operator_type_qconfig.values()) + + [self.global_config] + ): + if qconfig is not None: + is_qat = qconfig.is_qat + input_activation_spec = qconfig.input_activation + if input_activation_spec is not None: + is_dynamic = input_activation_spec.is_dynamic + return _QuantizationMode(is_qat=is_qat, is_dynamic=is_dynamic) + def _need_skip_config( self, quantization_config: Optional[QuantizationConfig] ) -> bool: @@ -450,28 +474,27 @@ def _need_skip_config( Mixed static/dynamic configurations or mixed QAT/non-QAT configurations are not supported. If such a mix is detected, the configuration will be marked for skipping.. """ - need_skip = False if quantization_config is None: return False - if self._is_qat is None: - self._is_qat = quantization_config.is_qat - else: - if self._is_qat != quantization_config.is_qat: + need_skip = False + current_mode = self._get_current_quantization_mode() + if ( + current_mode.is_qat is not None + and current_mode.is_qat != quantization_config.is_qat + ): + warnings.warn("Mixed QAT and Non-QAT quantization config is not supported.") + need_skip = True + if current_mode.is_dynamic is not None: + input_activation_spec = quantization_config.input_activation + if ( + input_activation_spec is not None + and current_mode.is_dynamic != input_activation_spec.is_dynamic + ): warnings.warn( - "Mixed QAT and Non-QAT quantization config is not supported." + "Mixed dynamic and static quantization config is not supported." ) need_skip = True - input_activation_spec = quantization_config.input_activation - if input_activation_spec is not None: - if self._is_dynamic is None: - self._is_dynamic = input_activation_spec.is_dynamic - else: - if self._is_dynamic != input_activation_spec.is_dynamic: - warnings.warn( - "Mixed dynamic and static quantization config is not supported." - ) - need_skip = True return need_skip def set_global(self, quantization_config: QuantizationConfig): @@ -489,7 +512,7 @@ def get_global_quantization_config(self): ) return self.global_config - @config_checker + @_config_checker def set_function_type_qconfig( self, function_type: Callable, @@ -508,7 +531,7 @@ def set_function_type_qconfig( ) return self - @config_checker + @_config_checker def set_module_type_qconfig( self, module_type: torch.nn.Module, @@ -525,7 +548,7 @@ def set_module_type_qconfig( ) return self - @config_checker + @_config_checker def set_module_name_qconfig( self, module_name: str, quantization_config: Optional[QuantizationConfig] ): @@ -555,7 +578,7 @@ def _annotate_conv_node_helper( self, conv_node: torch.fx.Node, annotate_output: bool, - quantization_config: QuantizationConfig, + quantization_config: Optional[QuantizationConfig], ) -> None: """Helper function to annotate the conv node""" if quantization_config is None: @@ -587,7 +610,7 @@ def _annotate_linear_node_helper( self, linear_node: torch.fx.Node, annotate_output: bool, - quantization_config: QuantizationConfig, + quantization_config: Optional[QuantizationConfig], ) -> None: """Helper function to annotate the linear node""" if quantization_config is None: @@ -679,14 +702,14 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: For each pair of (module_name_or_operator_type_or_global, qconfig), a filter function is created. This filter function checks if the node is marked by current stage and not annotated by the previous stage. """ - for module_name, qconfig in self.module_name_qconfig.items(): + for module_name, quantization_config in self.module_name_qconfig.items(): self._annotate_with_config( - model, qconfig, _create_module_name_filter(module_name) + model, quantization_config, _create_module_name_filter(module_name) ) - for operator_type, qconfig in self.operator_type_qconfig.items(): + for operator_type, quantization_config in self.operator_type_qconfig.items(): self._annotate_with_config( - model, qconfig, _create_operator_type_filter(operator_type) + model, quantization_config, _create_operator_type_filter(operator_type) ) if self.global_config: @@ -709,8 +732,8 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: def _annotate_with_config( self, model: torch.fx.GraphModule, - config: Optional[QuantizationConfig], - filter_fn: Callable, + quantization_config: Optional[QuantizationConfig], + filter_fn: FilterFn, ) -> None: """Annotate the model with the given quantization configuration. @@ -722,31 +745,36 @@ def _annotate_with_config( """ # Step1: Recipe of fusion patterns like conv/linear. - self._annotate_conv2d_fusion_pattern(model, config, filter_fn) - self._annotate_linear_fusion_pattern(model, config, filter_fn) - self._annotate_matmul(model, config, filter_fn) + self._annotate_conv2d_fusion_pattern(model, quantization_config, filter_fn) + self._annotate_linear_fusion_pattern(model, quantization_config, filter_fn) + self._annotate_matmul(model, quantization_config, filter_fn) # Step2: Recipe to propagate annotation for patterns beside conv/linear. # Go through all the nodes from start to end. # Recipe refer to https://github.com/intel/intel-extension-for-pytorch/blob/ # 90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_recipe.py#L538 - self._annotate_propagation_quantizable_pattern_entry(model, config, filter_fn) + self._annotate_propagation_quantizable_pattern_entry( + model, quantization_config, filter_fn + ) def _annotate_qat_conv2d_fusion_pattern( - self, model: torch.fx.GraphModule, config: QuantizationConfig, filter_fn + self, + model: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, ): # Annotate QAT Specific patterns - self._annotate_qat_conv2d_bn_binary_unary(model, config, filter_fn) - self._annotate_qat_conv2d_bn_binary(model, config, filter_fn) - self._annotate_qat_conv2d_bn_unary(model, config, filter_fn) - self._annotate_qat_conv2d_bn(model, config, filter_fn) + self._annotate_qat_conv2d_bn_binary_unary(model, quantization_config, filter_fn) + self._annotate_qat_conv2d_bn_binary(model, quantization_config, filter_fn) + self._annotate_qat_conv2d_bn_unary(model, quantization_config, filter_fn) + self._annotate_qat_conv2d_bn(model, quantization_config, filter_fn) def _annotate_qat_conv2d_bn_binary_unary( self, gm: torch.fx.GraphModule, - quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[List[Node]], bool]] = None, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, ) -> None: fused_partitions = find_sequential_partitions( gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d, operator.add, torch.nn.ReLU] @@ -823,8 +851,8 @@ def _annotate_qat_conv2d_bn_binary_unary( def _annotate_qat_conv2d_bn_binary( self, gm: torch.fx.GraphModule, - quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[List[Node]], bool]] = None, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, ) -> None: fused_partitions = find_sequential_partitions( gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d, operator.add] @@ -887,8 +915,8 @@ def _annotate_qat_conv2d_bn_binary( def _annotate_qat_conv2d_bn_unary( self, gm: torch.fx.GraphModule, - quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[List[Node]], bool]] = None, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, ) -> None: fused_partitions = [] unary_patterns = [ @@ -943,8 +971,8 @@ def _annotate_qat_conv2d_bn_unary( def _annotate_qat_conv2d_bn( self, gm: torch.fx.GraphModule, - quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[List[Node]], bool]] = None, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, ) -> None: fused_partitions = find_sequential_partitions( gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d] @@ -981,30 +1009,41 @@ def _annotate_qat_conv2d_bn( _mark_nodes_as_annotated(nodes_to_mark_annotated) def _annotate_conv2d_fusion_pattern( - self, model: torch.fx.GraphModule, config, filter_fn + self, + model: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, ): - if self._is_qat: + if (quantization_config is None) or (quantization_config.is_qat): # Annotate QAT specific pattern: mainly due to BN not folded in prepare_qat - self._annotate_qat_conv2d_fusion_pattern(model, config, filter_fn) - self._annotate_conv2d_binary_unary(model, config, filter_fn) - self._annotate_conv2d_binary(model, config, filter_fn) - self._annotate_conv2d_unary(model, config, filter_fn) - self._annotate_conv2d(model, config, filter_fn) + self._annotate_qat_conv2d_fusion_pattern( + model, quantization_config, filter_fn + ) + self._annotate_conv2d_binary_unary(model, quantization_config, filter_fn) + self._annotate_conv2d_binary(model, quantization_config, filter_fn) + self._annotate_conv2d_unary(model, quantization_config, filter_fn) + self._annotate_conv2d(model, quantization_config, filter_fn) def _annotate_linear_fusion_pattern( - self, model: torch.fx.GraphModule, config, filter_fn + self, + model: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, ): - if not self._is_dynamic: + if (quantization_config is None) or ( + quantization_config.input_activation + and not quantization_config.input_activation.is_dynamic + ): # Weiwen: Dynamic Quant of linear unary will be supported in next step - self._annotate_linear_binary_unary(model, config, filter_fn) - self._annotate_linear_unary(model, config, filter_fn) - self._annotate_linear(model, config, filter_fn) + self._annotate_linear_binary_unary(model, quantization_config, filter_fn) + self._annotate_linear_unary(model, quantization_config, filter_fn) + self._annotate_linear(model, quantization_config, filter_fn) def _annotate_matmul( self, model: torch.fx.GraphModule, quantization_config: Optional[QuantizationConfig], - filter_fn: Optional[Callable[[List[Node]], bool]] = None, + filter_fn: Optional[FilterFn] = None, ): for node in model.graph.nodes: if node.target != torch.ops.aten.matmul.default: @@ -1029,8 +1068,8 @@ def _annotate_matmul( def _annotate_conv2d_binary_unary( self, gm: torch.fx.GraphModule, - quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[List[Node]], bool]] = None, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, ) -> None: # Conv2d + add + unary op fused_partitions = find_sequential_partitions( @@ -1082,8 +1121,8 @@ def _annotate_conv2d_binary_unary( def _annotate_conv2d_binary( self, gm: torch.fx.GraphModule, - quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[List[Node]], bool]] = None, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, ) -> None: # Conv2d + add fused_partitions = find_sequential_partitions( @@ -1133,8 +1172,8 @@ def _annotate_conv2d_binary( def _annotate_conv2d_unary( self, gm: torch.fx.GraphModule, - quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[List[Node]], bool]] = None, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, ) -> None: fused_partitions = [] unary_patterns = [ @@ -1176,8 +1215,8 @@ def _annotate_conv2d_unary( def _annotate_conv2d( self, gm: torch.fx.GraphModule, - quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[List[Node]], bool]] = None, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, ) -> None: conv_partitions = get_source_partitions( gm.graph, [torch.nn.Conv2d, torch.nn.functional.conv2d] @@ -1257,9 +1296,12 @@ def _annotate_cat( ) def _annotate_propagation_quantizable_pattern_entry( - self, model, quantization_config, filter_fn + self, + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, ): - for node in model.graph.nodes: + for node in gm.graph.nodes: self._annotate_propagation_quantizable_pattern( node, quantization_config, filter_fn ) @@ -1393,8 +1435,8 @@ def _annotate_output_for_int8_in_int8_out_pattern( def _annotate_linear( self, gm: torch.fx.GraphModule, - quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[List[Node]], bool]] = None, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, ) -> None: linear_partitions = get_source_partitions( gm.graph, [torch.nn.Linear, torch.nn.functional.linear] @@ -1420,8 +1462,8 @@ def _annotate_linear( def _annotate_linear_unary( self, gm: torch.fx.GraphModule, - quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[List[Node]], bool]] = None, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, ) -> None: postop_list = [ torch.nn.ReLU, @@ -1459,8 +1501,8 @@ def _annotate_linear_unary( def _annotate_linear_binary_unary( self, gm: torch.fx.GraphModule, - quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[List[Node]], bool]] = None, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, ) -> None: # linear + binary_op + (optional) unary op binary_op_list = [operator.add] From cabfb520cb61b12727fb1aefcffd68b55ee3381f Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Wed, 29 May 2024 15:13:16 +0800 Subject: [PATCH 033/706] add more UTs Signed-off-by: yiliu30 --- .../pt2e/test_x86inductor_quantizer.py | 272 +++++++++--------- .../quantizer/x86_inductor_quantizer.py | 4 + 2 files changed, 148 insertions(+), 128 deletions(-) diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py index 1d58bcc8fe84..ea9ffd800def 100644 --- a/test/quantization/pt2e/test_x86inductor_quantizer.py +++ b/test/quantization/pt2e/test_x86inductor_quantizer.py @@ -1870,55 +1870,10 @@ def test_qat_dynamic_quant_linear(self): @skipIfNoX86 def test_set_module_name_qconfig(self): - """Test that quantize the specific submodule.""" + """Test case for quantizing a specific submodule by configuring `set_module_name_qconfig`. - class Sub(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(5, 5) - - def forward(self, x): - return self.linear(x) - - class M(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(5, 5) - self.sub = Sub() - - def forward(self, x): - x = self.linear(x) - x = self.sub(x) - return x - - m = M().eval() - example_inputs = (torch.randn(3, 5),) - # Set global to `None` and then default config for a specific submodule. - quantizer = X86InductorQuantizer() - quantizer.set_module_name_qconfig( - "sub", xiq.get_default_x86_inductor_quantization_config() - ) - node_occurrence = { - torch.ops.aten.linear.default: 2, - # quantize and dequantize the input of the second linear (`sub`) - torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, - torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, - # dequantize the weight of the second linear (`sub`) - torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, - } - node_list = [ - # first linear is not quantized - torch.ops.aten.linear.default, - # second linear (`sub`) is quantized - torch.ops.quantized_decomposed.quantize_per_tensor.default, - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - torch.ops.aten.linear.default, - ] - self._test_quantizer(m, example_inputs, quantizer, node_occurrence, node_list) - - @skipIfNoX86 - def test_set_module_name_qconfig_case2(self): - """Test that quantize the specific submodule.""" + Expect that all linear layers within the submodule `sub` are quantized. + """ class Sub(torch.nn.Module): def __init__(self): @@ -1959,11 +1914,10 @@ def forward(self, x): # dequantize the weight of two linear layers from `sub` torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, } - # node_list = None node_list = [ # first linear is not quantized torch.ops.aten.linear.default, - # two linear layers from `sub` are quantized + # two Q/DQ pairs for two linear layers from `sub` are quantized torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.linear.default, @@ -2030,26 +1984,18 @@ def forward(self, x): count += 1 @skipIfNoX86 - def test_set_module_name_and_set_module_type_case1(self): + def test_set_module_name_and_module_type_case1(self): """Test that set `module_name_qconfig` and `module_type_qconfig` at the same time. - All linear are not quantized except the last one. + Expect that all linear layers are not quantized except the last one. """ - class Sub(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(5, 5) - - def forward(self, x): - return self.linear(x) - class M(torch.nn.Module): def __init__(self): super().__init__() self.linear1 = torch.nn.Linear(5, 10) self.linear2 = torch.nn.Linear(10, 5) - self.sub = Sub() + self.sub = torch.nn.Linear(5, 5) def forward(self, x): x = self.linear1(x) @@ -2092,26 +2038,18 @@ def forward(self, x): ) @skipIfNoX86 - def test_set_module_name_and_set_module_type_case2(self): + def test_set_module_name_and_module_type_case2(self): """Test that set `module_name_qconfig` and `module_type_qconfig` at the same time. - All linear are quantized except the last one. + Expect that all linear are quantized except the last one. """ - class Sub(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(5, 5) - - def forward(self, x): - return self.linear(x) - class M(torch.nn.Module): def __init__(self): super().__init__() self.linear1 = torch.nn.Linear(5, 10) self.linear2 = torch.nn.Linear(10, 5) - self.sub = Sub() + self.sub = torch.nn.Linear(5, 5) def forward(self, x): x = self.linear1(x) @@ -2121,7 +2059,7 @@ def forward(self, x): m = M().eval() example_inputs = (torch.randn(3, 5),) - # Set `sub` to None and then default config for a all `Linear`. + # Set `sub` with None and then default config for a all `Linear`. quantizer = X86InductorQuantizer() quantizer.set_module_name_qconfig("sub", None).set_module_type_qconfig( torch.nn.Linear, xiq.get_default_x86_inductor_quantization_config() @@ -2136,10 +2074,13 @@ def forward(self, x): torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, } node_list = [ - # first and second linear are quantized + # Q/DQ for first lienar torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.linear.default, + # Q/DQ for second lienar + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.linear.default, # last linear is not quantized torch.ops.aten.linear.default, @@ -2154,7 +2095,7 @@ def forward(self, x): @skipIfNoX86 def test_set_module_name_qconfig_for_dynamic_quant(self): - """Test that quantize the specific submodule for dynamic quantization.""" + """Test that quantize a specific submodule for dynamic quantization.""" with override_quantized_engine("x86"), torch.no_grad(): for is_qat in [False, True]: @@ -2182,14 +2123,10 @@ def test_set_module_name_qconfig_for_dynamic_quant(self): torch.ops.quantized_decomposed.choose_qparams.tensor, torch.ops.quantized_decomposed.quantize_per_tensor.tensor, torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, - # # op for de-quantizing `q_proj`'s weight, disable this check to avoid random error. - # torch.ops.quantized_decomposed.dequantize_per_channel.default, # q_proj torch.ops.aten.linear.default, # k_proj torch.ops.aten.linear.default, - # # op for de-quantizing `v_proj`'s weight, disable this check to avoid random error. - # torch.ops.quantized_decomposed.dequantize_per_channel.default, # v_proj torch.ops.aten.linear.default, ] @@ -2203,58 +2140,137 @@ def test_set_module_name_qconfig_for_dynamic_quant(self): ) @skipIfNoX86 - def test_set_module_name_with_mixed_static_and_dynamic(self): - """Test that mixed static and dynamic quantization for a module. + def test_set_module_name_with_mixed_configs(self): + """Test case for setting module names with mixed static/dynamic or QAT/non-QAT configurations. - Currently, mixed static/dynamic quantization is not supported. The subsequent config will be ignored. + The config for 'v_proj' will always be ignored and raise a warning. """ - with override_quantized_engine("x86"), torch.no_grad(): - m = TestHelperModules.SelfAttnLikeModule(input_dim=64).eval() - example_inputs = (torch.randn(1, 4, 64),) - static_config = xiq.get_default_x86_inductor_quantization_config( - is_dynamic=False - ) - dynamic_config = xiq.get_default_x86_inductor_quantization_config( - is_dynamic=True - ) - # set `v_proj` with static config - # set `q_proj` with dynamic config (will be skipped) - quantizer = ( - X86InductorQuantizer() - .set_module_name_qconfig("q_proj", static_config) - .set_module_name_qconfig("v_proj", dynamic_config) - ) - node_occurrence = { - # quantize and dequantize the input - torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, - torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, - # only `q_proj`` was quantized, dequantize its weight - torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, - } - node_list = [ - # quantize and dequantize the input - torch.ops.quantized_decomposed.quantize_per_tensor.default, - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - # dequantize `q_proj`'s weight, disable this check to avoid random error. - # torch.ops.quantized_decomposed.dequantize_per_channel.default, - # q_proj - torch.ops.aten.linear.default, - # k_proj - torch.ops.aten.linear.default, - # not quantize `v_proj`'s weight - # torch.ops.quantized_decomposed.dequantize_per_channel.default, - # v_proj - torch.ops.aten.linear.default, - ] - self._test_quantizer( - m, - example_inputs, - quantizer, - node_occurrence, - node_list, - is_qat=True, - ) + with self.assertWarns(UserWarning) as context: + for q_is_dynamic, v_is_dynamic, q_is_qat, v_is_qat in itertools.product( + [False, True], repeat=4 + ): + if q_is_dynamic == v_is_dynamic and q_is_qat == v_is_qat: + continue + m = TestHelperModules.SelfAttnLikeModule(input_dim=64).eval() + example_inputs = (torch.randn(1, 4, 64),) + quantizer = ( + X86InductorQuantizer() + .set_module_name_qconfig( + "q_proj", + xiq.get_default_x86_inductor_quantization_config( + is_qat=q_is_qat, is_dynamic=q_is_dynamic + ), + ) + .set_module_name_qconfig( + "v_proj", + xiq.get_default_x86_inductor_quantization_config( + is_qat=v_is_qat, is_dynamic=v_is_dynamic + ), + ) + ) + quant_op = ( + torch.ops.quantized_decomposed.quantize_per_tensor.default + if not q_is_dynamic + else torch.ops.quantized_decomposed.quantize_per_tensor.tensor + ) + dequant_op = ( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + if not q_is_dynamic + else torch.ops.quantized_decomposed.dequantize_per_tensor.tensor + ) + node_occurrence = { + # quantize and dequantize the input + quant_op: 1, + dequant_op: 1, + # only `q_proj` was quantized, dequantize its weight + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + node_list = [ + # quantize and dequantize the input + quant_op, + dequant_op, + # q_proj + torch.ops.aten.linear.default, + # k_proj/v_proj + torch.ops.aten.linear.default, + torch.ops.aten.linear.default, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + is_qat=q_is_qat, + ) + warning_msg = ( + "Mixed QAT and Non-QAT" + if q_is_qat != v_is_qat + else "Mixed dynamic and static" + ) + self.assertTrue( + any( + warning_msg in msg + for msg in [str(w.message) for w in context.warnings] + ) + ) + + @skipIfNoX86 + def test_set_module_name_and_module_type_with_mixed_configs(self): + """Test that set `module_name_qconfig` and `module_type_qconfig` at the same time with mixed the configs. + + Expect that all linear are quantized with static quantization except the last one. + """ + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(5, 10) + self.linear2 = torch.nn.Linear(10, 5) + self.sub = torch.nn.Linear(5, 5) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + x = self.sub(x) + return x + + m = M().eval() + example_inputs = (torch.randn(3, 5),) + # Set `sub` with static config and then dynamic config for a all `Linear`(ignored). + quantizer = X86InductorQuantizer() + quantizer.set_module_name_qconfig( + "sub", xiq.get_default_x86_inductor_quantization_config(is_dynamic=False) + ).set_module_type_qconfig( + torch.nn.Linear, + xiq.get_default_x86_inductor_quantization_config(is_dynamic=True), + ) + + node_occurrence = { + torch.ops.aten.linear.default: 3, + # quantize and dequantize the input of the last linear + torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, + # dequantize the weight of the last linear + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + node_list = [ + # first and second linear are not quantized + torch.ops.aten.linear.default, + torch.ops.aten.linear.default, + # Q/DQ pairs for the last linear + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.linear.default, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + ) @skipIfNoX86 def test_filter_conv2d_recipe(self): diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index a31022dc6e07..b83388ad978b 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -169,6 +169,10 @@ def operator_type_filter(nodes: List[Node]): num_nodes_with_operator_type = sum( node.target == operator_type for node in nodes ) + if num_nodes_with_operator_type > 1: + raise NotImplementedError( + f"Several nodes within a single pattern are {operator_type}." + ) return num_nodes_with_operator_type == 1 return operator_type_filter From d492b5e38358a081a4f741177f0d4fe9098ccea1 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Wed, 29 May 2024 15:50:10 +0800 Subject: [PATCH 034/706] fixed the typos Signed-off-by: yiliu30 --- .../pt2e/test_x86inductor_quantizer.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py index ea9ffd800def..a64c12a8f2a1 100644 --- a/test/quantization/pt2e/test_x86inductor_quantizer.py +++ b/test/quantization/pt2e/test_x86inductor_quantizer.py @@ -1917,7 +1917,7 @@ def forward(self, x): node_list = [ # first linear is not quantized torch.ops.aten.linear.default, - # two Q/DQ pairs for two linear layers from `sub` are quantized + # two Q/DQ pairs for two linear layers from `sub` torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.linear.default, @@ -1947,7 +1947,7 @@ def __init__(self): def forward(self, x): return self.baz(self.foo_bar(x)) - # Set global to no quantization and then default config for a specific submodule. + # Set global to no quantization and then default config for a specific submodule whose name includes an underscore. quantizer = X86InductorQuantizer() quantizer.set_module_name_qconfig( "foo_bar", xiq.get_default_x86_inductor_quantization_config() @@ -2041,7 +2041,7 @@ def forward(self, x): def test_set_module_name_and_module_type_case2(self): """Test that set `module_name_qconfig` and `module_type_qconfig` at the same time. - Expect that all linear are quantized except the last one. + Expect that all linear layers are quantized except the last one. """ class M(torch.nn.Module): @@ -2170,14 +2170,14 @@ def test_set_module_name_with_mixed_configs(self): ) ) quant_op = ( - torch.ops.quantized_decomposed.quantize_per_tensor.default - if not q_is_dynamic - else torch.ops.quantized_decomposed.quantize_per_tensor.tensor + torch.ops.quantized_decomposed.quantize_per_tensor.tensor + if q_is_dynamic + else torch.ops.quantized_decomposed.quantize_per_tensor.default ) dequant_op = ( - torch.ops.quantized_decomposed.dequantize_per_tensor.default - if not q_is_dynamic - else torch.ops.quantized_decomposed.dequantize_per_tensor.tensor + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor + if q_is_dynamic + else torch.ops.quantized_decomposed.dequantize_per_tensor.default ) node_occurrence = { # quantize and dequantize the input @@ -2220,7 +2220,7 @@ def test_set_module_name_with_mixed_configs(self): def test_set_module_name_and_module_type_with_mixed_configs(self): """Test that set `module_name_qconfig` and `module_type_qconfig` at the same time with mixed the configs. - Expect that all linear are quantized with static quantization except the last one. + Expect that all linear layers are quantized with static quantization except the last one. """ class M(torch.nn.Module): From bf966588f157e22970376f9bb67fac98959fd004 Mon Sep 17 00:00:00 2001 From: Aaron Gokaslan Date: Wed, 29 May 2024 14:23:38 +0000 Subject: [PATCH 035/706] [BE][Ez]: Update cudnn_frontend submodule to v1.4.0 (#127175) Updates the cudnn_frontend submodule to the latest 1.4.0 version. Should be a straightforward, header-only submodule update. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127175 Approved by: https://github.com/ezyang, https://github.com/malfet --- third_party/cudnn_frontend | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/cudnn_frontend b/third_party/cudnn_frontend index 150798fe9765..b740542818f3 160000 --- a/third_party/cudnn_frontend +++ b/third_party/cudnn_frontend @@ -1 +1 @@ -Subproject commit 150798fe976556078f443fdb059a1ff0361f58a2 +Subproject commit b740542818f36857acf7f9853f749bbad4118c65 From ade075444fa26e1c50161503d8becc4218c35118 Mon Sep 17 00:00:00 2001 From: "Andrew M. James" Date: Tue, 28 May 2024 17:56:40 +0000 Subject: [PATCH 036/706] [dynamo] Support numpy.dtype (#124481) Pull Request resolved: https://github.com/pytorch/pytorch/pull/124481 Approved by: https://github.com/lezcano --- test/dynamo/test_functions.py | 5 ++ .../TestFromBuffer.test_basic_little_dtype0 | 0 .../TestFromBuffer.test_basic_little_dtype1 | 0 .../TestFromBuffer.test_basic_little_dtype2 | 0 test/torch_np/numpy_tests/core/test_dtype.py | 2 + .../numpy_tests/core/test_multiarray.py | 9 ++- .../numpy_tests/core/test_scalarmath.py | 2 +- .../numpy_tests/lib/test_arraysetops.py | 5 +- .../numpy_tests/lib/test_function_base.py | 10 +-- .../numpy_tests/lib/test_histograms.py | 2 +- torch/_dynamo/variables/misc.py | 63 ++++++++++++++----- 11 files changed, 72 insertions(+), 26 deletions(-) delete mode 100644 test/dynamo_expected_failures/TestFromBuffer.test_basic_little_dtype0 delete mode 100644 test/dynamo_expected_failures/TestFromBuffer.test_basic_little_dtype1 delete mode 100644 test/dynamo_expected_failures/TestFromBuffer.test_basic_little_dtype2 diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 472e9c56bae6..d919ba57b49b 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -1614,6 +1614,11 @@ def test_ndarray_builtin_functions(x): def test_numpy_dtype_argument_to_function(x): return np.ones_like(x, dtype=np.float64) + @make_test + def test_numpy_dtype_call_in_function(x): + dt = np.dtype("float") + return np.full_like(x, 2.4, dtype=dt) + @make_test def test_numpy_linalg(x): return np.linalg.norm(x.numpy(), axis=0) diff --git a/test/dynamo_expected_failures/TestFromBuffer.test_basic_little_dtype0 b/test/dynamo_expected_failures/TestFromBuffer.test_basic_little_dtype0 deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestFromBuffer.test_basic_little_dtype1 b/test/dynamo_expected_failures/TestFromBuffer.test_basic_little_dtype1 deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestFromBuffer.test_basic_little_dtype2 b/test/dynamo_expected_failures/TestFromBuffer.test_basic_little_dtype2 deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/torch_np/numpy_tests/core/test_dtype.py b/test/torch_np/numpy_tests/core/test_dtype.py index ccff28135a1f..00ead3f705af 100644 --- a/test/torch_np/numpy_tests/core/test_dtype.py +++ b/test/torch_np/numpy_tests/core/test_dtype.py @@ -21,6 +21,7 @@ subtest, TEST_WITH_TORCHDYNAMO, TestCase, + xfailIfTorchDynamo, xpassIfTorchDynamo, ) @@ -68,6 +69,7 @@ def test_equivalent_dtype_hashing(self): assert_(left == right) assert_(hash(left) == hash(right)) + @xfailIfTorchDynamo # TypeError -> InternalTorchDynamoError def test_invalid_types(self): # Make sure invalid type strings raise an error diff --git a/test/torch_np/numpy_tests/core/test_multiarray.py b/test/torch_np/numpy_tests/core/test_multiarray.py index bf9aab8ebcee..3dec3c2cdddd 100644 --- a/test/torch_np/numpy_tests/core/test_multiarray.py +++ b/test/torch_np/numpy_tests/core/test_multiarray.py @@ -291,6 +291,7 @@ def test_otherflags(self): assert_equal(self.a.flags["X"], False) assert_equal(self.a.flags["WRITEBACKIFCOPY"], False) + @xfail # invalid dtype def test_string_align(self): a = np.zeros(4, dtype=np.dtype("|S4")) assert_(a.flags.aligned) @@ -298,6 +299,7 @@ def test_string_align(self): a = np.zeros(5, dtype=np.dtype("|S4")) assert_(a.flags.aligned) + @xfail # structured dtypes def test_void_align(self): a = np.zeros(4, dtype=np.dtype([("a", "i4"), ("b", "i4")])) assert_(a.flags.aligned) @@ -1856,7 +1858,7 @@ def test_searchsorted_floats(self, a): y = np.searchsorted(x, x[-1]) assert_equal(y, 2) - @xpassIfTorchDynamo # ( + @xfail # ( # reason="'searchsorted_out_cpu' not implemented for 'ComplexDouble'" # ) def test_searchsorted_complex(self): @@ -5983,6 +5985,11 @@ def test_unnamed_fields(self): self._check("i:f0:", [("f0", "i")]) +# NOTE: xpassIfTorchDynamo below +# 1. TODO: torch._numpy does not handle/model _CopyMode +# 2. order= keyword not supported (probably won't be) +# 3. Under TEST_WITH_TORCHDYNAMO many of these make it through due +# to a graph break leaving the _CopyMode to only be handled by numpy. @skipif(numpy.__version__ < "1.23", reason="CopyMode is new in NumPy 1.22") @xpassIfTorchDynamo @instantiate_parametrized_tests diff --git a/test/torch_np/numpy_tests/core/test_scalarmath.py b/test/torch_np/numpy_tests/core/test_scalarmath.py index 9c535aefe016..8099ca8c4c32 100644 --- a/test/torch_np/numpy_tests/core/test_scalarmath.py +++ b/test/torch_np/numpy_tests/core/test_scalarmath.py @@ -732,7 +732,7 @@ def test_numpy_abs(self, dtype): @instantiate_parametrized_tests class TestBitShifts(TestCase): - @parametrize("type_code", np.typecodes["AllInteger"]) + @parametrize("type_code", np.typecodes["Integer"] + "B") @parametrize("op", [operator.rshift, operator.lshift]) def test_shift_all_bits(self, type_code, op): """Shifts where the shift amount is the width of the type or wider""" diff --git a/test/torch_np/numpy_tests/lib/test_arraysetops.py b/test/torch_np/numpy_tests/lib/test_arraysetops.py index 34176ee3f3b7..73897bea6981 100644 --- a/test/torch_np/numpy_tests/lib/test_arraysetops.py +++ b/test/torch_np/numpy_tests/lib/test_arraysetops.py @@ -3,7 +3,7 @@ """Test functions for 1D array set operations. """ -from unittest import skipIf +from unittest import expectedFailure as xfail, skipIf import numpy @@ -34,7 +34,7 @@ @skipIf(numpy.__version__ < "1.24", reason="NP_VER: fails on NumPy 1.23.x") -@xpassIfTorchDynamo # (reason="TODO") +@skipIf(True, reason="TODO implement these ops") @instantiate_parametrized_tests class TestSetOps(TestCase): def test_intersect1d(self): @@ -531,6 +531,7 @@ def test_in1d_both_arrays_are_object(self): result = np.in1d(ar1, ar2) assert_array_equal(result, expected) + @xfail def test_in1d_both_arrays_have_structured_dtype(self): # Test arrays of a structured data type containing an integer field # and a field of dtype `object` allowing for arbitrary Python objects diff --git a/test/torch_np/numpy_tests/lib/test_function_base.py b/test/torch_np/numpy_tests/lib/test_function_base.py index d0eda87b0108..aea6c8ee38d9 100644 --- a/test/torch_np/numpy_tests/lib/test_function_base.py +++ b/test/torch_np/numpy_tests/lib/test_function_base.py @@ -3259,7 +3259,7 @@ def test_keepdims_2(self): subtest( [1, 7], decorators=[ - xpassIfTorchDynamo, + skip(reason="Keepdims wrapper incorrect for multiple q"), ], ), ], @@ -3273,13 +3273,13 @@ def test_keepdims_2(self): subtest( (0, 1), decorators=[ - xpassIfTorchDynamo, + skip(reason="Tuple axes"), ], ), subtest( (-3, -1), decorators=[ - xpassIfTorchDynamo, + skip(reason="Tuple axes"), ], ), ], @@ -3839,13 +3839,13 @@ def test_keepdims_2(self): subtest( (0, 1), decorators=[ - xpassIfTorchDynamo, + skip(reason="Tuple axes"), ], ), subtest( (-3, -1), decorators=[ - xpassIfTorchDynamo, + skip(reason="Tuple axes"), ], ), ], diff --git a/test/torch_np/numpy_tests/lib/test_histograms.py b/test/torch_np/numpy_tests/lib/test_histograms.py index 954fbf111484..7f8c145a05de 100644 --- a/test/torch_np/numpy_tests/lib/test_histograms.py +++ b/test/torch_np/numpy_tests/lib/test_histograms.py @@ -353,7 +353,7 @@ def test_signed_overflow_bounds(self): self.do_signed_overflow_bounds(np.short) self.do_signed_overflow_bounds(np.intc) - @xpassIfTorchDynamo # (reason="int->float conversin loses precision") + @xfail # (reason="int->float conversin loses precision") def test_signed_overflow_bounds_2(self): self.do_signed_overflow_bounds(np.int_) self.do_signed_overflow_bounds(np.longlong) diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 9dc5bc52ae76..c053f04662a9 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -847,25 +847,34 @@ def can_constant_fold_through(cls, fn): assert len(mod) >= 2 and mod[:2] == ["torch", "_numpy"] return fn in cls.constant_fold_functions + @classmethod + def get_constant_collection_for_func(cls, fn): + mod = fn.__module__.split(".") + assert len(mod) >= 2 and mod[:2] == ["torch", "_numpy"] + return np_constant_collections_map.get(fn, None) + def call_function( self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" ) -> "VariableTracker": if not config.trace_numpy: unimplemented(f"numpy.{self.value}()") - import numpy as np - from ..utils import numpy_to_tensor_wrapper from .tensor import NumpyNdarrayVariable - # lookup method name in tnp. Things like np.dtype(float) are not supported yet. - if self.value.__name__ == "dtype": + func = get_np_to_tnp_map().get(self.value) + if func is None: unimplemented( - f"numpy dtype function is not supported yet. Got type {type(self.value)}." + f"Can't find numpy function {self.value} in torch._numpy. " + " Please file an issue to request support for this function." ) - elif self.value in (np.iinfo, np.finfo): + + # We are dealing with a function that produces a const collection type (np.dtype, np.iinfo/np.finfo) + if ( + collection_variable_typ := self.get_constant_collection_for_func(func) + ) is not None: try: - return NumpyTypeInfoVariable( + return collection_variable_typ( self.value( *[x.as_python_constant() for x in args], **{k: v.as_python_constant() for k, v in kwargs.items()}, @@ -875,14 +884,7 @@ def call_function( unimplemented( f"{self.value.__name__} with non-const args: {args} {kwargs}" ) - else: # We are dealing with a callable. - func = get_np_to_tnp_map().get(self.value) - if func is None: - unimplemented( - f"Can't find numpy function {self.value} in torch._numpy. " - " Please file an issue to request support for this function." - ) - + else: if ( func.__module__ == "torch._numpy.random" and config.use_numpy_random_stream @@ -1091,9 +1093,14 @@ class ConstantLikeVariable(VariableTracker): _error_prefix = "ConstantLikeVariable" try: - from numpy import floating as np_floating + from numpy import ( + dtype as np_dtype, + floating as np_floating, + generic as np_generic, + ) except ImportError: np_floating = type("invalid_type", (), {}) + np_dtype = type("invalid_type", (), {}) def __init__(self, value, **kwargs): super().__init__(**kwargs) @@ -1132,6 +1139,11 @@ def var_getattr(self, tx, name: str) -> VariableTracker: result = getattr(self.value, name) if isinstance(result, self.np_floating): result = float(result) + if isinstance(result, self.np_dtype): + return NumpyDTypeVariable(result) + if isinstance(result, type) and issubclass(result, self.np_generic): + # things like x.dtype.type + return NumpyVariable(result) if variables.ConstantVariable.is_literal(result): return variables.ConstantVariable.create(result) return GetAttrVariable(self, name) @@ -1156,3 +1168,22 @@ def __init__(self, **kwargs): class NumpyTypeInfoVariable(ConstantLikeVariable): _error_prefix = "np.iinfo/np.finfo" + + +class NumpyDTypeVariable(ConstantLikeVariable): + _error_prefix = "np.dtype[...]" + + def as_proxy(self): + """Similar to how numpy dtype descriptors (e.g. np.float32 ) are handled by NumpyVariable: + + np.dtype() objects are serialized as strings, torch._numpy wrappers will normalize to the torch dtype. + This also handles unsupported things nicely (i.e. structured arrays and object arrays). + """ + return self.value.type.__name__ + + +np_constant_collections_map = { + tnp.finfo: NumpyTypeInfoVariable, + tnp.iinfo: NumpyTypeInfoVariable, + tnp.dtype: NumpyDTypeVariable, +} From 9a8e8101a8379e2fb3cf8a45d4cac368d7203a51 Mon Sep 17 00:00:00 2001 From: Derek <54462961+TASPlasma@users.noreply.github.com> Date: Wed, 29 May 2024 14:55:40 +0000 Subject: [PATCH 037/706] Fix wording in nn.Linear docstring. (#127240) Definition (Linear Transformation): A mapping $T : V \to W$ between $F$-vector spaces $V,W$ is called a *linear transformation* if and only if a) $T(u+v)=T(u)+T(v)$, b) $T(cv)=cT(v)$ for all $u, v \in V$, $c \in F$. Consequently, $T(0_V)=0_W$. Thus $x \mapsto xA^T+b$ for nonzero $b$ is **not** a linear transformation, but is often referred to as an affine linear transformation. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127240 Approved by: https://github.com/soulitzer, https://github.com/albanD --- torch/nn/modules/linear.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/nn/modules/linear.py b/torch/nn/modules/linear.py index 720c1ca01c15..54981596f7ee 100644 --- a/torch/nn/modules/linear.py +++ b/torch/nn/modules/linear.py @@ -47,7 +47,7 @@ def forward(self, input: Tensor) -> Tensor: class Linear(Module): - r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`. + r"""Applies an affine linear transformation to the incoming data: :math:`y = xA^T + b`. This module supports :ref:`TensorFloat32`. From 80a8fc07b2314df18b6898238b3ba4e6be94e842 Mon Sep 17 00:00:00 2001 From: "Andrew M. James" Date: Tue, 28 May 2024 17:56:41 +0000 Subject: [PATCH 038/706] [dynamo] Handle np.iinfo/finfo/dtype as input (#124482) Pull Request resolved: https://github.com/pytorch/pytorch/pull/124482 Approved by: https://github.com/lezcano ghstack dependencies: #124481 --- test/dynamo/test_functions.py | 78 +++++++++++++++++++ .../numpy_tests/core/test_multiarray.py | 1 + torch/_dynamo/trace_rules.py | 12 +++ torch/_dynamo/variables/builder.py | 20 ++++- 4 files changed, 110 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index d919ba57b49b..20b9fadcf015 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -2197,6 +2197,84 @@ def inner(): self.assertTrue(same(program(input1, input2), input1 + input1)) + @parametrize("int_or_float", ("int", "float")) + def test_np_constant_collections_as_input(self, int_or_float): + info_func = getattr(np, f"{int_or_float[0]}info") + dt_string_arg = f"{int_or_float}16" + np_dt_attr = getattr(np, dt_string_arg) + + dt_args = [dt_string_arg, np_dt_attr] + arg_variants_iter = itertools.chain( + dt_args, map(np.dtype, dt_args), map(info_func, dt_args) + ) + + def func(a, b, info_or_dt): + return a + info_func(info_or_dt).max + + opt_fn = torch.compile(func) + + a = torch.randn(2) + b = torch.randn(2) + eager_result = func(a, b, dt_args[0]) + + for arg in arg_variants_iter: + opt_result = opt_fn(a, b, arg) + self.assertTrue(same(opt_result, eager_result)) + + @parametrize( + "typ, info_func", + [ + (int, np.iinfo), + (float, np.finfo), + ], + name_fn=lambda t, _: t.__name__, + ) + def test_np_constant_collections_guards(self, typ, info_func): + def func_info(a, info): + return a + info.max + + def func_dtype(a, dt): + return a + info_func(dt).max + + dt_args = [ + np.dtype(typ), + np.ones((1,), dtype=typ).dtype, + np.dtype(np.dtype(typ).name), + np.dtype(typ.__name__), + ] + cnts_1 = torch._dynamo.testing.CompileCounter() + opt_fn_dtype = torch._dynamo.optimize(cnts_1)(func_dtype) + a = torch.zeros(3, dtype=typ) + for arg in dt_args: + r = opt_fn_dtype(a, arg) + # each should produce an identical arg + self.assertEqual(cnts_1.frame_count, 1) + + cnts_2 = torch._dynamo.testing.CompileCounter() + opt_fn_info = torch._dynamo.optimize(cnts_2)(func_info) + info_args = [info_func(dt) for dt in dt_args] + for arg in info_args: + r = opt_fn_info(a, arg) + + # each should produce an identical arg + self.assertEqual(cnts_2.frame_count, 1) + + if typ is float: + dt_extra = np.dtype(np.float16) + else: + dt_extra = np.dtype(np.int16) + info_extra = info_func(dt_extra) + + eager_result_dtype = func_dtype(a, dt_extra) + compile_result_dtype = opt_fn_dtype(a, dt_extra) + self.assertEqual(cnts_1.frame_count, 2) + self.assertEqual(eager_result_dtype, compile_result_dtype) + + eager_result_info = func_info(a, info_extra) + compile_result_info = opt_fn_info(a, info_extra) + self.assertEqual(cnts_2.frame_count, 2) + self.assertEqual(eager_result_info, compile_result_info) + def test_compare_constant_and_tensor(self): for op in [ operator.lt, diff --git a/test/torch_np/numpy_tests/core/test_multiarray.py b/test/torch_np/numpy_tests/core/test_multiarray.py index 3dec3c2cdddd..76af79f62084 100644 --- a/test/torch_np/numpy_tests/core/test_multiarray.py +++ b/test/torch_np/numpy_tests/core/test_multiarray.py @@ -6018,6 +6018,7 @@ def test_scalars(self): with pytest.raises(ValueError): np.array(pyscalar, dtype=np.int64, copy=np._CopyMode.NEVER) + @xfail # TODO: handle `_CopyMode` properly in torch._numpy def test_compatible_cast(self): # Some types are compatible even though they are different, no # copy is necessary for them. This is mostly true for some integers diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 6be6e4965ce1..cccb80fb0c77 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -3094,6 +3094,18 @@ def is_numpy(obj) -> bool: return isinstance(obj, (np.ndarray, np.generic)) or id(obj) in _numpy_function_ids +def is_numpy_dtype(obj) -> bool: + if np is None: + return False + return isinstance(obj, np.dtype) + + +def is_numpy_type_info(obj) -> bool: + if np is None: + return False + return isinstance(obj, (np.finfo, np.iinfo)) + + BUILTIN_SKIPLIST = ( abc, collections, diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 10a79ed8ff31..2d0543f8b147 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -70,7 +70,12 @@ Source, TupleIteratorGetItemSource, ) -from ..trace_rules import is_callable_allowed, is_numpy +from ..trace_rules import ( + is_callable_allowed, + is_numpy, + is_numpy_dtype, + is_numpy_type_info, +) from ..utils import ( build_checkpoint_variable, clone_input, @@ -151,6 +156,8 @@ LambdaVariable, LoggingLoggerVariable, MethodWrapperVariable, + NumpyDTypeVariable, + NumpyTypeInfoVariable, NumpyVariable, PythonModuleVariable, RegexPatternVariable, @@ -625,6 +632,17 @@ def build_key_value(i, k, v): else GuardBuilder.TYPE_MATCH ) return NumpyVariable(value, source=self.source) + elif is_numpy_dtype(value): + self.install_guards(GuardBuilder.ID_MATCH) + return NumpyDTypeVariable(value, source=self.source) + elif is_numpy_type_info(value): + if isinstance(value, np.iinfo): + self.install_guards(GuardBuilder.TYPE_MATCH) + dt_source = AttrSource(self.source, "dtype") + install_guard(dt_source.make_guard(GuardBuilder.ID_MATCH)) + else: + self.install_guards(GuardBuilder.ID_MATCH) + return NumpyTypeInfoVariable(value, source=self.source) # NB: These can't be put in type_dispatch, they have to run later elif CollectiveFunctionRewriteVariable.can_rewrite(value): self.install_guards(GuardBuilder.FUNCTION_MATCH) From c69562caf9475496207bd24ee670337d426f6e03 Mon Sep 17 00:00:00 2001 From: cyy Date: Wed, 29 May 2024 16:08:48 +0000 Subject: [PATCH 039/706] [Caffe2]Remove more caffe2 files (#126628) They are not used. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126628 Approved by: https://github.com/albanD --- caffe2/cuda_rtc/CMakeLists.txt | 11 - caffe2/cuda_rtc/common_rtc.h | 131 - caffe2/cuda_rtc/elemenntwise_rtc_gpu.cc | 129 - caffe2/cuda_rtc/pool_op_rtc_gpu.cc | 340 --- caffe2/quantization/__init__.py | 0 caffe2/test/assets/squeeze_predict_net.pb | Bin 6176 -> 0 bytes caffe2/test/caffe2_gtest_main.cc | 46 - caffe2/utils/hip/math_blas_gpu_test.cc | 379 --- caffe2/utils/math-detail.h | 90 - caffe2/utils/math.h | 467 ---- caffe2/utils/math/broadcast.cu | 110 - caffe2/utils/math/elementwise.cu | 918 ------- caffe2/utils/math/reduce.cu | 593 ----- caffe2/utils/math/reduce.cuh | 61 - caffe2/utils/math/transpose.cu | 233 -- caffe2/utils/math_gpu.cu | 2871 --------------------- caffe2/utils/math_gpu_test.cc | 429 --- caffe2/utils/math_test.cc | 523 ---- 18 files changed, 7331 deletions(-) delete mode 100644 caffe2/cuda_rtc/CMakeLists.txt delete mode 100644 caffe2/cuda_rtc/common_rtc.h delete mode 100644 caffe2/cuda_rtc/elemenntwise_rtc_gpu.cc delete mode 100644 caffe2/cuda_rtc/pool_op_rtc_gpu.cc delete mode 100644 caffe2/quantization/__init__.py delete mode 100644 caffe2/test/assets/squeeze_predict_net.pb delete mode 100644 caffe2/test/caffe2_gtest_main.cc delete mode 100644 caffe2/utils/hip/math_blas_gpu_test.cc delete mode 100644 caffe2/utils/math-detail.h delete mode 100644 caffe2/utils/math.h delete mode 100644 caffe2/utils/math/broadcast.cu delete mode 100644 caffe2/utils/math/elementwise.cu delete mode 100644 caffe2/utils/math/reduce.cu delete mode 100644 caffe2/utils/math/reduce.cuh delete mode 100644 caffe2/utils/math/transpose.cu delete mode 100644 caffe2/utils/math_gpu.cu delete mode 100644 caffe2/utils/math_gpu_test.cc delete mode 100644 caffe2/utils/math_test.cc diff --git a/caffe2/cuda_rtc/CMakeLists.txt b/caffe2/cuda_rtc/CMakeLists.txt deleted file mode 100644 index 6bb289b79d72..000000000000 --- a/caffe2/cuda_rtc/CMakeLists.txt +++ /dev/null @@ -1,11 +0,0 @@ -if(USE_CUDA) - set(Caffe2_CUDA_RTC_GPU_SRC - "${CMAKE_CURRENT_SOURCE_DIR}/elemenntwise_rtc_gpu.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/pool_op_rtc_gpu.cc" - ) - - set(Caffe2_GPU_SRCS ${Caffe2_GPU_SRCS} ${Caffe2_CUDA_RTC_GPU_SRC}) - set(Caffe2_GPU_SRCS ${Caffe2_GPU_SRCS} PARENT_SCOPE) -else() - message(STATUS "CUDA RTC operators skipped due to no CUDA support") -endif() diff --git a/caffe2/cuda_rtc/common_rtc.h b/caffe2/cuda_rtc/common_rtc.h deleted file mode 100644 index 0fa6bad7a0c4..000000000000 --- a/caffe2/cuda_rtc/common_rtc.h +++ /dev/null @@ -1,131 +0,0 @@ -#ifndef CAFFE2_CUDA_RTC_COMMON_RTC_H_ -#define CAFFE2_CUDA_RTC_COMMON_RTC_H_ - -#include -#include - -#include -#include - -#define NVRTC_CHECK(condition) \ - do { \ - nvrtcResult result = condition; \ - if (result != NVRTC_SUCCESS) { \ - LOG(FATAL) << "Error at: " << __FILE__ << ":" << __LINE__ << ": " \ - << nvrtcGetErrorString(result); \ - } \ - } while (0) - -namespace caffe2 { - -template -class CudaRTCFunction { - public: - CudaRTCFunction() : module_loaded_(false) {} - ~CudaRTCFunction() { - if (module_loaded_) { - CUDA_DRIVERAPI_ENFORCE(cuModuleUnload(module_)); - } - } - - // TODO: this function is nontrivial and since CudaRTCFunction uses CRTP, it - // may potentially increase the binary size. In that case, move common parts - // into a separate function. - template - void Compile(Args... args) { - string src = static_cast(this)->GetSource(args...); - string name = static_cast(this)->KernelName(args...); - VLOG(1) << "function name: " << name; - VLOG(1) << "function src:\n" << src; - // Actually do the compiling. - nvrtcProgram prog; - NVRTC_CHECK( - nvrtcCreateProgram(&prog, src.c_str(), nullptr, 0, nullptr, nullptr)); - // Compile the program. - // TODO(Yangqing): how to find the current gpu architecture instead of hard - // coding it? - const char* nvrtc_opts[] = { - "--gpu-architecture=compute_35", "--use_fast_math"}; - nvrtcResult compile_result = nvrtcCompileProgram(prog, 2, nvrtc_opts); - if (compile_result != NVRTC_SUCCESS) { - size_t log_size; - NVRTC_CHECK(nvrtcGetProgramLogSize(prog, &log_size)); - std::string nvrtc_log(log_size, '\0'); - NVRTC_CHECK(nvrtcGetProgramLog(prog, &nvrtc_log[0])); - LOG(FATAL) << "Compilation failure for nvrtc(" - << nvrtcGetErrorString(compile_result) << "): \n" - << nvrtc_log; - } - size_t ptx_size; - NVRTC_CHECK(nvrtcGetPTXSize(prog, &ptx_size)); - vector nvrtc_ptx(ptx_size); - NVRTC_CHECK(nvrtcGetPTX(prog, nvrtc_ptx.data())); - NVRTC_CHECK(nvrtcDestroyProgram(&prog)); - // After compilation, load the module. - if (module_loaded_) { - CUDA_DRIVERAPI_ENFORCE(cuModuleUnload(module_)); - } - CUDA_DRIVERAPI_ENFORCE( - cuModuleLoadDataEx(&module_, nvrtc_ptx.data(), 0, 0, 0)); - module_loaded_ = true; - CUDA_DRIVERAPI_ENFORCE( - cuModuleGetFunction(&kernel_, module_, name.c_str())); - } - - template - void Launch( - unsigned int gx, - unsigned int gy, - unsigned int gz, - unsigned int bx, - unsigned int by, - unsigned int bz, - unsigned int shared_mem, - cudaStream_t stream, - Args... args) { - CAFFE_ENFORCE( - module_loaded_, "Cannot call Launch before a module is loaded."); - void* args_voidp[] = {&args...}; - CUDA_DRIVERAPI_ENFORCE(cuLaunchKernel( - kernel_, gx, gy, gz, bx, by, bz, shared_mem, stream, args_voidp, 0)); - } - - void LaunchEx( - unsigned int gx, - unsigned int gy, - unsigned int gz, - unsigned int bx, - unsigned int by, - unsigned int bz, - unsigned int shared_mem, - cudaStream_t stream, - void** extra) { - CAFFE_ENFORCE( - module_loaded_, "Cannot call Launch before a module is loaded."); - CUDA_DRIVERAPI_ENFORCE(cuLaunchKernel( - kernel_, gx, gy, gz, bx, by, bz, shared_mem, stream, nullptr, extra)); - } - - private: - bool module_loaded_; - CUmodule module_; - CUfunction kernel_; -}; - -// TODO: this is in no way unique and is just a hack right now. -inline std::string GetUniqueName() { - static constexpr int len = 20; - static const char alpha[] = - "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"; - - std::stringstream ss; - ss << "_cuda_kernel_"; - for (const auto i : c10::irange(len)) { - ss << alpha[rand() % (sizeof(alpha) - 1)]; - } - return ss.str(); -} - -} // namespace caffe2 - -#endif // CAFFE2_CUDA_RTC_COMMON_RTC_H_ diff --git a/caffe2/cuda_rtc/elemenntwise_rtc_gpu.cc b/caffe2/cuda_rtc/elemenntwise_rtc_gpu.cc deleted file mode 100644 index dfa3981731e7..000000000000 --- a/caffe2/cuda_rtc/elemenntwise_rtc_gpu.cc +++ /dev/null @@ -1,129 +0,0 @@ -#include "caffe2/core/common_gpu.h" -#include "caffe2/core/context_gpu.h" -#include "caffe2/core/operator.h" -#include "caffe2/cuda_rtc/common_rtc.h" - -namespace caffe2 { -namespace { -class ElementwiseRTCFunction : public CudaRTCFunction { - public: - ElementwiseRTCFunction() : CudaRTCFunction(), name_(GetUniqueName()) {} - - template - string KernelName(Args... /*args*/) { - return name_; - } - - template - string GetSource(Args... args); - - private: - string name_; -}; - -template <> -string ElementwiseRTCFunction::GetSource( - int input_size, - int output_size, - const string command_string) { - std::stringstream ss; - ss << "extern \"C\" __global__ void " << name_ - << "(const size_t nthreads, \n"; - // Insert the parameter list. - int remain_params = input_size + output_size; - for (int i = 0; i < input_size; ++i) { - ss << "const float* in" << i << ((remain_params--) ? ", \n" : ""); - } - for (int i = 0; i < output_size; ++i) { - ss << "float* out" << i << ((remain_params--) ? ", \n" : ""); - } - ss << ") {\n" - "for (int index = blockIdx.x * blockDim.x + threadIdx.x;\n" - "index < nthreads; index += blockDim.x * gridDim.x) {\n" - << command_string << "\n" - << "}\n}"; - return ss.str(); -} -} // namespace - -/** - * A GPU operator that can generate limited elementwise operations. - * - * ElementwiseRTCOp allows one to do a simple and limited thing: it takes in - * multiple inputs and multiple outputs, as well as a raw string argument - * rtc_src. The runtime then generates the following kernel code: - * - * __global__ void kernel_name(const size_t nthreads, ...) { - * for(int index = blockIdx.x * blockDim.x + threadIdx.x; - * index < nthreads; index += blockDim.x * gridDim.x) { - * rtc_src - * } - * } - * where the "..." part is auto generated, so one can refer to the input and - * output as in0, in1, ..., out0, out1... in the rtc_src string. - * - * For example, if one wants to do a vector multiplication, one can take two - * inputs and one outputs, and write rtc_src as - * out0[index] = in0[index] * in1[index]; - * - * This op is currently highly experimental. We do not have a gradient - * registered for it either. - */ -class ElementwiseRTCOp final : public Operator { - public: - ElementwiseRTCOp(const OperatorDef& operator_def, Workspace* ws) - : Operator(operator_def, ws) { - const string src = OperatorBase::GetSingleArgument("rtc_src", ""); - CAFFE_ENFORCE(src.size(), "Op should have a non-zero source code size."); - func_.Compile(InputSize(), OutputSize(), src); - } - ~ElementwiseRTCOp() override {} - - bool RunOnDevice() override { - static_assert( - sizeof(void*) == sizeof(size_t), - "The argbuffer relies on the assumption that void* and " - "size_t have the same size."); - vector argBuffer_vec(InputSize() + OutputSize() + 1); - size_t* argBuffer = argBuffer_vec.data(); - CAFFE_ENFORCE( - Input(0).numel() < std::numeric_limits::max(), - "The kernel function currently only supports int index."); - argBuffer[0] = Input(0).numel(); - void** ptr_buffer = reinterpret_cast(argBuffer + 1); - for (int i = 0; i < InputSize(); ++i) { - ptr_buffer[i] = const_cast(Input(i).data()); - } - for (int i = 0; i < OutputSize(); ++i) { - Output(i)->ResizeLike(Input(0)); - ptr_buffer[i + InputSize()] = Output(i)->mutable_data(); - } - size_t argBufferSize = sizeof(argBuffer); - void* config[] = { - CU_LAUNCH_PARAM_BUFFER_POINTER, - argBuffer, - CU_LAUNCH_PARAM_BUFFER_SIZE, - &argBufferSize, - CU_LAUNCH_PARAM_END}; - func_.LaunchEx( - CAFFE_GET_BLOCKS(Input(0).numel()), - 1, - 1, - CAFFE_CUDA_NUM_THREADS, - 1, - 1, - 0, - context_.cuda_stream(), - config); - return true; - } - - private: - ElementwiseRTCFunction func_; -}; - -namespace { -REGISTER_CUDA_OPERATOR_WITH_ENGINE(ElementwiseRTC, NVRTC, ElementwiseRTCOp); -} - -} // namespace caffe2 diff --git a/caffe2/cuda_rtc/pool_op_rtc_gpu.cc b/caffe2/cuda_rtc/pool_op_rtc_gpu.cc deleted file mode 100644 index 8ec14e1223ae..000000000000 --- a/caffe2/cuda_rtc/pool_op_rtc_gpu.cc +++ /dev/null @@ -1,340 +0,0 @@ -#include - -#include "caffe2/core/common_gpu.h" -#include "caffe2/core/context_gpu.h" -#include "caffe2/cuda_rtc/common_rtc.h" -#include "caffe2/operators/pool_op.h" - -namespace caffe2 { -namespace { -class AveragePool {}; -class MaxPool {}; -} // namespace - -namespace { - -// The max pool forward function, with parameters written in const int. -const char kMaxPoolForwardNCHWSource[] = R"( -extern "C" -__global__ void %s(const float* bottom_data, float* top_data) { - const int nthreads = %d; - const int channels = %d; - const int height = %d; - const int width = %d; - const int pooled_height = %d; - const int pooled_width = %d; - const int kernel_h = %d; - const int kernel_w = %d; - const int stride_h = %d; - const int stride_w = %d; - const int pad_t = %d; - const int pad_l = %d; - for (int index = blockIdx.x * blockDim.x + threadIdx.x; - index < nthreads; index += blockDim.x * gridDim.x) { - int pw = index %% pooled_width; - int ph = (index / pooled_width) %% pooled_height; - int c = (index / (pooled_width * pooled_height)) %% channels; - int n = index / (pooled_width * pooled_height * channels); - int hstart = ph * stride_h - pad_t; - int wstart = pw * stride_w - pad_l; - int hend = min(hstart + kernel_h, height); - int wend = min(wstart + kernel_w, width); - hstart = max(hstart, 0); - wstart = max(wstart, 0); - float maxval = -1.0e37f; - const float* bdata_offset = bottom_data + n * channels * height * width; - for (int h = hstart; h < hend; ++h) { - for (int w = wstart; w < wend; ++w) { - maxval = fmaxf( - bdata_offset[c * height * width + h * width + w], maxval); - } - } - top_data[index] = maxval; - } -} -)"; - -// The max pool forward function, with parameters written in const int. -const char kMaxPoolBackwardNCHWSource[] = R"( -extern "C" -__global__ void %s( - const float* const bottom_data, const float* const top_data, - const float* const top_diff, float* const bottom_diff) { - const int nthreads = %d; - const int num = %d; - const int channels = %d; - const int height = %d; - const int width = %d; - const int pooled_height = %d; - const int pooled_width = %d; - const int kernel_h = %d; - const int kernel_w = %d; - const int stride_h = %d; - const int stride_w = %d; - const int pad_t = %d; - const int pad_l = %d; - for (int index = blockIdx.x * blockDim.x + threadIdx.x; - index < nthreads; index += blockDim.x * gridDim.x) { - const int w = index %% width + pad_l; - const int h = (index / width) %% height + pad_t; - const int c = (index / width / height) %% channels; - const int n = index / width / height / channels; - const int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1; - const int phend = min(h / stride_h + 1, pooled_height); - const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1; - const int pwend = min(w / stride_w + 1, pooled_width); - const int top_offset = - (n * channels + c) * pooled_height * pooled_width; - bottom_diff[index] = 0; - for (int ph = phstart; ph < phend; ++ph) { - for (int pw = pwstart; pw < pwend; ++pw) { - int top_local_offset = top_offset + ph * pooled_width + pw; - if (bottom_data[index] == top_data[top_local_offset]) { - bottom_diff[index] += top_diff[top_local_offset]; - } - } - } - } -} -)"; - -class MaxPoolRTCFunction : public CudaRTCFunction { - public: - MaxPoolRTCFunction() : CudaRTCFunction(), name_(GetUniqueName()) {} - - template - string KernelName(Args... /*args*/) { - return name_; - } - - template - string GetSource(Args... args); - - private: - string name_; -}; - -class MaxPoolGradientRTCFunction - : public CudaRTCFunction { - public: - MaxPoolGradientRTCFunction() : CudaRTCFunction(), name_(GetUniqueName()) {} - - template - string KernelName(Args... /*args*/) { - return name_; - } - - template - string GetSource(Args... args); - - private: - string name_; -}; - -template <> -string MaxPoolRTCFunction::GetSource( - const int output_size, - const int channels, - const int height, - const int width, - const int pooled_height, - const int pooled_width, - const int kernel_h, - const int kernel_w, - const int stride_h, - const int stride_w, - const int pad_t, - const int pad_l) { - char buffer[65536]; - int nbytes = snprintf( - buffer, - 65536, - kMaxPoolForwardNCHWSource, - name_.c_str(), - output_size, - channels, - height, - width, - pooled_height, - pooled_width, - kernel_h, - kernel_w, - stride_h, - stride_w, - pad_t, - pad_l); - TORCH_DCHECK_GE(nbytes, 0); - TORCH_DCHECK_LT(nbytes, 65536); - return string(buffer); -} - -template <> -string MaxPoolGradientRTCFunction::GetSource( - const int output_size, - const int num, - const int channels, - const int height, - const int width, - const int pooled_height, - const int pooled_width, - const int kernel_h, - const int kernel_w, - const int stride_h, - const int stride_w, - const int pad_t, - const int pad_l) { - char buffer[65536]; - int nbytes = snprintf( - buffer, - 65536, - kMaxPoolBackwardNCHWSource, - name_.c_str(), - output_size, - num, - channels, - height, - width, - pooled_height, - pooled_width, - kernel_h, - kernel_w, - stride_h, - stride_w, - pad_t, - pad_l); - TORCH_DCHECK_GE(nbytes, 0); - TORCH_DCHECK_LT(nbytes, 65536); - return string(buffer); -} - -} // namespace - -class MaxPoolRTCOp final : public ConvPoolOpBase { - public: - MaxPoolRTCOp(const OperatorDef& operator_def, Workspace* ws) - : ConvPoolOpBase(operator_def, ws) { - CAFFE_ENFORCE_EQ( - order_, StorageOrder::NCHW, "Currently only NCHW is supported."); - } - ~MaxPoolRTCOp() override {} - - bool RunOnDeviceWithOrderNCHW() override { - auto& X = Input(0); - auto output_sizes = - ConvPoolOpBase::GetOutputSize(X, X.dim32(1)); - auto* Y = Output(0, output_sizes, at::dtype()); - - if (input_dims_ != X.sizes()) { - // recompile - VLOG(1) << "MaxPool RTC recompiling"; - CAFFE_ENFORCE_LT(Y->numel(), std::numeric_limits::max()); - func_.Compile( - static_cast(Y->numel()), - X.dim32(1), - X.dim32(2), - X.dim32(3), - Y->dim32(2), - Y->dim32(3), - kernel_h(), - kernel_w(), - stride_h(), - stride_w(), - pad_t(), - pad_l()); - input_dims_ = X.sizes().vec(); - } - // Carry out the pooling computation. - func_.Launch( - CAFFE_GET_BLOCKS(Y->numel()), - 1, - 1, - CAFFE_CUDA_NUM_THREADS, - 1, - 1, - 0, - context_.cuda_stream(), - X.data(), - Y->mutable_data()); - return true; - } - - bool RunOnDeviceWithOrderNHWC() override { - LOG(FATAL) << "Not implemented."; - return false; - } - - private: - MaxPoolRTCFunction func_; - vector input_dims_; -}; - -class MaxPoolGradientRTCOp final : public ConvPoolOpBase { - public: - MaxPoolGradientRTCOp(const OperatorDef& operator_def, Workspace* ws) - : ConvPoolOpBase(operator_def, ws) { - CAFFE_ENFORCE_EQ( - order_, StorageOrder::NCHW, "Currently only NCHW is supported."); - } - ~MaxPoolGradientRTCOp() override {} - - bool RunOnDeviceWithOrderNCHW() override { - auto& X = Input(0); - auto& Y = Input(1); - auto& dY = Input(2); - CAFFE_ENFORCE_EQ(dY.dim(), 4); - - auto* dX = Output(0, X.sizes(), at::dtype()); - ConvPoolOpBase::ComputePads({X.dim32(2), X.dim32(3)}); - if (input_dims_ != X.sizes()) { - VLOG(1) << "MaxPoolGradient RTC recompiling"; - CAFFE_ENFORCE_LT(X.numel(), std::numeric_limits::max()); - func_.Compile( - static_cast(X.numel()), - X.dim32(0), - X.dim32(1), - X.dim32(2), - X.dim32(3), - dY.dim32(2), - dY.dim32(3), - kernel_h(), - kernel_w(), - stride_h(), - stride_w(), - pad_t(), - pad_l()); - input_dims_ = X.sizes().vec(); - } - func_.Launch( - CAFFE_GET_BLOCKS(X.numel()), - 1, - 1, - CAFFE_CUDA_NUM_THREADS, - 1, - 1, - 0, - context_.cuda_stream(), - X.data(), - Y.data(), - dY.data(), - dX->mutable_data()); - return true; - } - - bool RunOnDeviceWithOrderNHWC() override { - LOG(FATAL) << "Not implemented."; - return false; - } - - private: - MaxPoolGradientRTCFunction func_; - vector input_dims_; -}; - -namespace { -REGISTER_CUDA_OPERATOR_WITH_ENGINE(MaxPool, NVRTC, MaxPoolRTCOp); -REGISTER_CUDA_OPERATOR_WITH_ENGINE( - MaxPoolGradient, - NVRTC, - MaxPoolGradientRTCOp); -} // namespace -} // namespace caffe2 diff --git a/caffe2/quantization/__init__.py b/caffe2/quantization/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/caffe2/test/assets/squeeze_predict_net.pb b/caffe2/test/assets/squeeze_predict_net.pb deleted file mode 100644 index ac4c476b91cc6a17d26b8e2adfe65c8e05e8951a..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6176 zcmb`L%Wu;_5XMvb2wL7NK)rCu3FT5snm%$tJcPJ_cu5?Ph@8fhsA(ONlm`BKc6VmX zH)|6%5;yzvHxGM$+1P6u_NCbvX7kK6kBjpkyZuW!7}(9Av^Ji*3vFYzX7lN2m~UTc zH_iGa8*Vqec$!bo^YQk&y=%x{hP}_n7xq)BPKsi@+iJef7H{3rKl8P3n9X83%%`o5 zH_u;vYTqZfFGamM>F>R* zDI}uaT%_(Tqzb)#uH1zdeCTz0LJrE~Y;J8IPy55s*{rqYNOjjDbZlW%DACcWAx3FT z3~GPBZ~Q5g#9$~9gGy@yr-D~#1f87Q zB2xa=YKehly0?%Di4mv!T@f)NI42R)*-^yE=~8085HaGlVX5i-_r^H6-AVI+#scP$0$N1jT?m2-aL1VqC|Ts z>MiiI5{i&a_ZCtiisCF$6u~(O#eI#UoGwL?7ccP=F$XooD2<6hEp;YL42BXhsOYKS z6(Z)KN(@Ro6_8T9V`3nZrUF80ZypaOMxwnG^%lmw5;2fW_ZCtiG2$$V5y3f$m;+6W zoTd;XULxkOh8U$WF{q`^go(jWA_f&b6}&>k99D@ziKhZmYIjTwMAB41NbSwz!Nf?k zm!jUnm{%eOlIh+;DkMgnB{3p6ClPa~iILM3V#G_t9Muq`G$sbM)R{0b7)r#TqNjpa zh?t`)F(~m=KuYb7iGfI(3J9sac|4dHiS|;|TNv|7#6U9LTS$e(h_fU{1m`4Tjx;fH zx|EoY^mQlR<%!O+*=ei!Vp>d!i+TIH*_>wcQE{))Xgq0OGtJSgKhJ0L?KS&OXEj%k z{eIdb_XOo=V7Ew1o#lPQ0e=F+)A{~K4|`ybTGyWb$fwy!9&N3aU)i`n(TMg(;;pr| zH8)Sj#ULB^efH7$iMzW8W{XN$Gb_H%&$5Mk9IfUDI_`xn#GY&+%E6#l-H5xa(FQ%W z>Ex!}zpP4;%cB^%yu>weKgHQXA4=nFvRwc*QHoqpQ(RCJKb)ne(m0!Z7eY;xA{W#Y z7u3X0FR7_C&L&?LP!pxd1vSM5HSrx=YATJh$txt(L@9DXO>sd@d`pv>O5<$u5&|_* zid;}rTu>AL=cT67IGg;9K~0n*7t|CN)Wm-Zsi`#1CJQ0dL@9DXO>t38S(+|;ax{1b HKXCIKU?E^P diff --git a/caffe2/test/caffe2_gtest_main.cc b/caffe2/test/caffe2_gtest_main.cc deleted file mode 100644 index 920b79ef4d65..000000000000 --- a/caffe2/test/caffe2_gtest_main.cc +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright 2006, Google Inc. -// All rights reserved. -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are -// met: -// -// * Redistributions of source code must retain the above copyright -// notice, this list of conditions and the following disclaimer. -// * Redistributions in binary form must reproduce the above -// copyright notice, this list of conditions and the following disclaimer -// in the documentation and/or other materials provided with the -// distribution. -// * Neither the name of Google Inc. nor the names of its -// contributors may be used to endorse or promote products derived from -// this software without specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -#include - -#include -#include "caffe2/core/flags.h" -#include "caffe2/core/init.h" - -C10_DEFINE_string( - caffe_test_root, - "gen/", - "The root of the caffe test folder."); - -GTEST_API_ int main(int argc, char** argv) { - // std::cout << "Running main() from gtest_main.cc\n"; - testing::InitGoogleTest(&argc, argv); - caffe2::GlobalInit(&argc, &argv); - return RUN_ALL_TESTS(); -} diff --git a/caffe2/utils/hip/math_blas_gpu_test.cc b/caffe2/utils/hip/math_blas_gpu_test.cc deleted file mode 100644 index 07d4bf11f5a4..000000000000 --- a/caffe2/utils/hip/math_blas_gpu_test.cc +++ /dev/null @@ -1,379 +0,0 @@ -#include -#include "caffe2/core/blob.h" -#include "caffe2/core/context.h" -#include "caffe2/core/hip/context_gpu.h" -#include "caffe2/core/tensor.h" -#include "caffe2/operators/utility_ops.h" -#include "caffe2/proto/caffe2_pb.h" -#include "caffe2/utils/conversions.h" -#include "caffe2/utils/math.h" - -namespace caffe2 { - -TEST(MathROCBLASTest, GemmNoTransNoTrans) { - if (!HasHipGPU()) - return; - Workspace ws; - DeviceOption option; - option.set_device_type(PROTO_HIP); - HIPContext context(option); - - Blob* blobX = ws.CreateBlob("X"); - Blob* blobW = ws.CreateBlob("W"); - Blob* blobY = ws.CreateBlob("Y"); - Blob* blobY_host = ws.CreateBlob("Y_host"); - - vector shapeX{5, 10}; - vector shapeW{10, 6}; - vector shapeY{5, 6}; - auto* tensorX = BlobGetMutableTensor(blobX, HIP); - tensorX->Resize(shapeX); - auto* tensorW = BlobGetMutableTensor(blobW, HIP); - tensorW->Resize(shapeW); - auto* tensorY = BlobGetMutableTensor(blobY, HIP); - tensorY->Resize(shapeY); - auto* tensorY_host = BlobGetMutableTensor(blobY_host, CPU); - tensorY_host->Resize(shapeY); - - EXPECT_EQ(tensorX->size(), 50); - EXPECT_EQ(tensorW->size(), 60); - EXPECT_EQ(tensorY->size(), 30); - - math::Set( - tensorX->size(), 1, tensorX->mutable_data(), &context); - math::Set( - tensorW->size(), 1, tensorW->mutable_data(), &context); - - const float kOne = 1.0; - const float kPointFive = 0.5; - const float kZero = 0.0; - math::Gemm( - CblasNoTrans, - CblasNoTrans, - 5, - 6, - 10, - kOne, - tensorX->template data(), - tensorW->template data(), - kZero, - tensorY->mutable_data(), - &context); - context.FinishDeviceComputation(); - tensorY_host->CopyFrom(*tensorY); - EXPECT_EQ(tensorY_host->size(), 30); - for (int i = 0; i < tensorY_host->size(); ++i) { - TORCH_CHECK_EQ(tensorY_host->data()[i], 10) << i; - } - - // Test Accumulate - math::Gemm( - CblasNoTrans, - CblasNoTrans, - 5, - 6, - 10, - kOne, - tensorX->template data(), - tensorW->template data(), - kPointFive, - tensorY->mutable_data(), - &context); - context.FinishDeviceComputation(); - tensorY_host->CopyFrom(*tensorY); - EXPECT_EQ(tensorY_host->size(), 30); - for (int i = 0; i < tensorY_host->size(); ++i) { - TORCH_CHECK_EQ(tensorY_host->data()[i], 15) << i; - } - - // Test Accumulate - math::Gemm( - CblasNoTrans, - CblasNoTrans, - 5, - 6, - 10, - kPointFive, - tensorX->template data(), - tensorW->template data(), - kOne, - tensorY->mutable_data(), - &context); - context.FinishDeviceComputation(); - tensorY_host->CopyFrom(*tensorY); - EXPECT_EQ(tensorY_host->size(), 30); - for (int i = 0; i < tensorY_host->size(); ++i) { - TORCH_CHECK_EQ(tensorY_host->data()[i], 20) << i; - } -} - -TEST(MathROCBLASTest, GemmNoTransTrans) { - if (!HasHipGPU()) - return; - Workspace ws; - DeviceOption option; - option.set_device_type(PROTO_HIP); - HIPContext context(option); - - Blob* blobX = ws.CreateBlob("X"); - Blob* blobW = ws.CreateBlob("W"); - Blob* blobY = ws.CreateBlob("Y"); - Blob* blobY_host = ws.CreateBlob("Y_host"); - - vector shapeX{5, 10}; - vector shapeW{6, 10}; - vector shapeY{5, 6}; - auto* tensorX = BlobGetMutableTensor(blobX, HIP); - tensorX->Resize(shapeX); - auto* tensorW = BlobGetMutableTensor(blobW, HIP); - tensorW->Resize(shapeW); - auto* tensorY = BlobGetMutableTensor(blobY, HIP); - tensorY->Resize(shapeY); - auto* tensorY_host = BlobGetMutableTensor(blobY_host, CPU); - tensorY_host->Resize(shapeY); - - EXPECT_EQ(tensorX->size(), 50); - EXPECT_EQ(tensorW->size(), 60); - EXPECT_EQ(tensorY->size(), 30); - - math::Set( - tensorX->size(), 1, tensorX->mutable_data(), &context); - math::Set( - tensorW->size(), 1, tensorW->mutable_data(), &context); - - const float kOne = 1.0; - const float kPointFive = 0.5; - const float kZero = 0.0; - math::Gemm( - CblasNoTrans, - CblasTrans, - 5, - 6, - 10, - kOne, - tensorX->template data(), - tensorW->template data(), - kZero, - tensorY->mutable_data(), - &context); - context.FinishDeviceComputation(); - tensorY_host->CopyFrom(*tensorY); - EXPECT_EQ(tensorY_host->size(), 30); - for (int i = 0; i < tensorY_host->size(); ++i) { - TORCH_CHECK_EQ(tensorY_host->data()[i], 10) << i; - } - - // Test Accumulate - math::Gemm( - CblasNoTrans, - CblasTrans, - 5, - 6, - 10, - kOne, - tensorX->template data(), - tensorW->template data(), - kPointFive, - tensorY->mutable_data(), - &context); - context.FinishDeviceComputation(); - tensorY_host->CopyFrom(*tensorY); - EXPECT_EQ(tensorY_host->size(), 30); - for (int i = 0; i < tensorY_host->size(); ++i) { - TORCH_CHECK_EQ(tensorY_host->data()[i], 15) << i; - } - - math::Gemm( - CblasNoTrans, - CblasTrans, - 5, - 6, - 10, - kPointFive, - tensorX->template data(), - tensorW->template data(), - kOne, - tensorY->mutable_data(), - &context); - context.FinishDeviceComputation(); - tensorY_host->CopyFrom(*tensorY); - EXPECT_EQ(tensorY_host->size(), 30); - for (int i = 0; i < tensorY_host->size(); ++i) { - TORCH_CHECK_EQ(tensorY_host->data()[i], 20) << i; - } -} - -TEST(MathROCBLASTest, GemvNoTrans) { - if (!HasHipGPU()) - return; - Workspace ws; - DeviceOption option; - option.set_device_type(PROTO_HIP); - HIPContext context(option); - - Blob* blobA = ws.CreateBlob("A"); - Blob* blobX = ws.CreateBlob("X"); - Blob* blobY = ws.CreateBlob("Y"); - Blob* blobY_host = ws.CreateBlob("Y_host"); - - vector shapeA{5, 10}; - vector shapeX{10}; - vector shapeY{5}; - auto* tensorA = BlobGetMutableTensor(blobA, HIP); - tensorA->Resize(shapeA); - auto* tensorX = BlobGetMutableTensor(blobX, HIP); - tensorX->Resize(shapeX); - auto* tensorY = BlobGetMutableTensor(blobY, HIP); - tensorY->Resize(shapeY); - auto* tensorY_host = BlobGetMutableTensor(blobY_host, CPU); - tensorY_host->Resize(shapeY); - - EXPECT_EQ(tensorA->size(), 50); - EXPECT_EQ(tensorX->size(), 10); - EXPECT_EQ(tensorY->size(), 5); - math::Set( - tensorA->size(), 1, tensorA->mutable_data(), &context); - math::Set( - tensorX->size(), 1, tensorX->mutable_data(), &context); - - const float kOne = 1.0; - const float kPointFive = 0.5; - const float kZero = 0.0; - math::Gemv( - CblasNoTrans, - 5, - 10, - kOne, - tensorA->data(), - tensorX->data(), - kZero, - tensorY->mutable_data(), - &context); - context.FinishDeviceComputation(); - tensorY_host->CopyFrom(*tensorY); - for (int i = 0; i < tensorY_host->size(); ++i) { - TORCH_CHECK_EQ(tensorY_host->data()[i], 10) << i; - } - - // Test Accumulate - math::Gemv( - CblasNoTrans, - 5, - 10, - kOne, - tensorA->data(), - tensorX->data(), - kPointFive, - tensorY->mutable_data(), - &context); - context.FinishDeviceComputation(); - tensorY_host->CopyFrom(*tensorY); - for (int i = 0; i < tensorY_host->size(); ++i) { - TORCH_CHECK_EQ(tensorY_host->data()[i], 15) << i; - } - - // Test Accumulate - math::Gemv( - CblasNoTrans, - 5, - 10, - kPointFive, - tensorA->data(), - tensorX->data(), - kOne, - tensorY->mutable_data(), - &context); - context.FinishDeviceComputation(); - tensorY_host->CopyFrom(*tensorY); - for (int i = 0; i < tensorY_host->size(); ++i) { - TORCH_CHECK_EQ(tensorY_host->data()[i], 20) << i; - } -} - -TEST(MathROCBLASTest, GemvTrans) { - if (!HasHipGPU()) - return; - Workspace ws; - DeviceOption option; - option.set_device_type(PROTO_HIP); - HIPContext context(option); - - Blob* blobA = ws.CreateBlob("A"); - Blob* blobX = ws.CreateBlob("X"); - Blob* blobY = ws.CreateBlob("Y"); - Blob* blobY_host = ws.CreateBlob("Y_host"); - - vector shapeA{6, 10}; - vector shapeX{6}; - vector shapeY{10}; - auto* tensorA = BlobGetMutableTensor(blobA, HIP); - tensorA->Resize(shapeA); - auto* tensorX = BlobGetMutableTensor(blobX, HIP); - tensorX->Resize(shapeX); - auto* tensorY = BlobGetMutableTensor(blobY, HIP); - tensorY->Resize(shapeY); - auto* tensorY_host = BlobGetMutableTensor(blobY_host, CPU); - tensorY_host->Resize(shapeY); - - EXPECT_EQ(tensorA->size(), 60); - EXPECT_EQ(tensorX->size(), 6); - EXPECT_EQ(tensorY->size(), 10); - math::Set( - tensorA->size(), 1, tensorA->mutable_data(), &context); - math::Set( - tensorX->size(), 1, tensorX->mutable_data(), &context); - - const float kOne = 1.0; - const float kPointFive = 0.5; - const float kZero = 0.0; - math::Gemv( - CblasTrans, - 6, - 10, - kOne, - tensorA->data(), - tensorX->data(), - kZero, - tensorY->mutable_data(), - &context); - context.FinishDeviceComputation(); - tensorY_host->CopyFrom(*tensorY); - for (int i = 0; i < tensorY_host->size(); ++i) { - TORCH_CHECK_EQ(tensorY_host->data()[i], 6) << i; - } - - // Test Accumulate - math::Gemv( - CblasTrans, - 6, - 10, - kOne, - tensorA->data(), - tensorX->data(), - kPointFive, - tensorY->mutable_data(), - &context); - context.FinishDeviceComputation(); - tensorY_host->CopyFrom(*tensorY); - for (int i = 0; i < tensorY_host->size(); ++i) { - TORCH_CHECK_EQ(tensorY_host->data()[i], 9) << i; - } - - // Test Accumulate - math::Gemv( - CblasTrans, - 6, - 10, - kPointFive, - tensorA->data(), - tensorX->data(), - kOne, - tensorY->mutable_data(), - &context); - context.FinishDeviceComputation(); - tensorY_host->CopyFrom(*tensorY); - for (int i = 0; i < tensorY_host->size(); ++i) { - TORCH_CHECK_EQ(tensorY_host->data()[i], 12) << i; - } -} -} // namespace caffe2 diff --git a/caffe2/utils/math-detail.h b/caffe2/utils/math-detail.h deleted file mode 100644 index f2ecc711995a..000000000000 --- a/caffe2/utils/math-detail.h +++ /dev/null @@ -1,90 +0,0 @@ -#ifndef CAFFE2_UTILS_MATH_DETAIL_H_ -#define CAFFE2_UTILS_MATH_DETAIL_H_ -namespace caffe2 { - -class CPUContext; - -namespace math { -namespace detail { - -// proxy to a class because of partial specialization limitations for functions - -template -struct ScaleImpl { - inline void operator()( - const int N, - const float alpha, - const T* x, - T* y, - Context* context) { - Scale(N, alpha, x, y, context); - } -}; - -// Put light-weight implementations in .h file to enable inlining -template -struct ScaleImpl { - inline void operator()( - const int N, - const float alpha, - const T* x, - T* y, - CPUContext* /*context*/) { - TORCH_DCHECK_EQ(N, 1); - *y = *x * alpha; - } -}; - -template -struct AxpyImpl { - inline void operator()( - const int N, - const float alpha, - const T* x, - T* y, - Context* context) { - Axpy(N, alpha, x, y, context); - } -}; - -// Put light-weight implementations in .h file to enable inlining -template -struct AxpyImpl { - inline void operator()( - const int N, - const float alpha, - const T* x, - T* y, - CPUContext* /*context*/) { - TORCH_DCHECK_EQ(N, 1); - *y += *x * alpha; - } -}; - - -} // namespace detail - -template -inline void ScaleFixedSize( - const int N, - const float alpha, - const T* x, - T* y, - Context* context) { - detail::ScaleImpl()(N, alpha, x, y, context); -} - -template -inline void AxpyFixedSize( - const int N, - const float alpha, - const T* x, - T* y, - Context* context) { - detail::AxpyImpl()(N, alpha, x, y, context); -} - -} // namespace math -} // namespace caffe2 - -#endif // CAFFE2_UTILS_MATH_DETAIL_H_ diff --git a/caffe2/utils/math.h b/caffe2/utils/math.h deleted file mode 100644 index 6acc50e8e748..000000000000 --- a/caffe2/utils/math.h +++ /dev/null @@ -1,467 +0,0 @@ -#ifndef CAFFE2_UTILS_MATH_H_ -#define CAFFE2_UTILS_MATH_H_ -// This is a simple translation from the old Caffe math interfaces. We aim to -// still keep it simple, so all platforms would be able to support it fairly -// easily. - -// We include the cblas header here so that we can obtain the macros from cblas. -extern "C" { -#include "caffe2/utils/cblas.h" -} - -#ifdef CAFFE2_USE_ACCELERATE -#include -#endif // CAFFE2_USE_ACCELERATE - -#include "caffe2/core/common.h" -#include "caffe2/core/types.h" -#include "caffe2/utils/math/broadcast.h" -#include "caffe2/utils/math/elementwise.h" -#include "caffe2/utils/math/reduce.h" -#include "caffe2/utils/math/transpose.h" -#include "caffe2/utils/math/utils.h" - -namespace caffe2 { - -// TODO: Change dims related arguments to int64_t? -class Tensor; - -// An empty class as a placeholder for a math function that has no specific -// engine specified. -class TORCH_API DefaultEngine {}; - -namespace math { - -#define C10_DECLARE_COMPARE_OP(Comp) \ - template \ - void Rowwise##Comp( \ - const int rows, \ - const int cols, \ - const T* A, \ - const T* B, \ - bool* C, \ - Context* context); \ - \ - template \ - void Colwise##Comp( \ - const int rows, \ - const int cols, \ - const T* A, \ - const T* B, \ - bool* C, \ - Context* context); \ - \ - template \ - void Comp( \ - const int A_ndim, \ - const int* A_dims, \ - const int B_ndim, \ - const int* B_dims, \ - const T* A, \ - const T* B, \ - bool* C, \ - Context* context); - -C10_DECLARE_COMPARE_OP(EQ) -C10_DECLARE_COMPARE_OP(NE) -C10_DECLARE_COMPARE_OP(LT) -C10_DECLARE_COMPARE_OP(LE) -C10_DECLARE_COMPARE_OP(GT) -C10_DECLARE_COMPARE_OP(GE) - -#undef C10_DECLARE_COMPARE_OP - -#define C10_DECLARE_BINARY_OP(Func) \ - template \ - void Rowwise##Func( \ - const int rows, \ - const int cols, \ - const T* A, \ - const T* B, \ - T* C, \ - Context* context); \ - \ - template \ - void Colwise##Func( \ - const int rows, \ - const int cols, \ - const T* A, \ - const T* B, \ - T* C, \ - Context* context); \ - \ - template \ - void Func( \ - const int A_ndim, \ - const int* A_dims, \ - const int B_ndim, \ - const int* B_dims, \ - const T* A, \ - const T* B, \ - T* C, \ - Context* context); - -C10_DECLARE_BINARY_OP(Add) -C10_DECLARE_BINARY_OP(Sub) -C10_DECLARE_BINARY_OP(Mul) -C10_DECLARE_BINARY_OP(Div) - -C10_DECLARE_BINARY_OP(And) -C10_DECLARE_BINARY_OP(Or) -C10_DECLARE_BINARY_OP(Xor) - -C10_DECLARE_BINARY_OP(BitwiseAnd) -C10_DECLARE_BINARY_OP(BitwiseOr) -C10_DECLARE_BINARY_OP(BitwiseXor) - -#undef C10_DECLARE_BINARY_OP - -// Broadcasts X with X_dims to Y with Y_dims. -template -TORCH_API void Broadcast( - const int X_ndim, - const int* X_dims, - const int Y_ndim, - const int* Y_dims, - const T alpha, - const T* X, - T* Y, - Context* context, - bool allow_broadcast_fastpath=false); - -// Computes inv_std from variance. -template -TORCH_API void InvStd( - const int N, - const T epsilon, - const T* var, - T* inv_std, - Context* context); - -// Adds batch sub-tensors elementwise to output. Stripe is the stripe length -// and N is the number of elements to add (size of Y). -template -TORCH_API void AddStripedBatch( - const int N, - const T* first, - T* y, - const int stripe, - const int batch, - Context* context); - -// Compute the row-wise max of a N*D matrix X, and write it to a N -// dimensional vector y. -template -TORCH_API void -RowwiseMax(const int N, const int D, const T* x, T* y, Context* context); - -// Compute the column-wise max of a N*D matrix X, and write it to a D -// dimensional vector y. -template -TORCH_API void -ColwiseMax(const int N, const int D, const T* x, T* y, Context* context); - -// Elemwise maximum of vector x and scalar alpha. y[i] = max(x[i], alpha) -template -TORCH_API void -Maximum(const int N, const float alpha, const T* x, T* y, Context* context); - -// Decaf gemm provides a simpler interface to the gemm functions, with the -// limitation that the data has to be contiguous in memory. -template -TORCH_API void Gemm( - const CBLAS_TRANSPOSE trans_A, - const CBLAS_TRANSPOSE trans_B, - const int M, - const int N, - const int K, - const float alpha, - const T* A, - const T* B, - const float beta, - T* C, - Context* context, - TensorProto::DataType math_type = TensorProto_DataType_FLOAT); - -// We also provide a gemm that has explicit lda, ldb and ldc specified. -// In most cases you probably want to use the function above, though. -template -TORCH_API void GemmEx( - const CBLAS_TRANSPOSE trans_A, - const CBLAS_TRANSPOSE trans_B, - const int M, - const int N, - const int K, - const T alpha, - const T* A, - const int lda, - const T* B, - const int ldb, - const T beta, - T* C, - const int ldc, - Context* context); - -// GemmBatched provides a simple abstraction into library routines -template -TORCH_API void GemmBatched( - const CBLAS_TRANSPOSE trans_A, - const CBLAS_TRANSPOSE trans_B, - const int batch_size, - const int M, - const int N, - const int K, - const float alpha, - const T** A, - const T** B, - const float beta, - T** C, - Context* context, - TensorProto::DataType math_type = TensorProto_DataType_FLOAT); - -template -TORCH_API void GemmStridedBatched( - const CBLAS_TRANSPOSE trans_A, - const CBLAS_TRANSPOSE trans_B, - const int batch_size, - const int M, - const int N, - const int K, - const float alpha, - const T* A, - const int A_stride, - const T* B, - const int B_stride, - const float beta, - T* C, - const int C_stride, - Context* context, - TensorProto::DataType math_type = TensorProto_DataType_FLOAT); - -// Gemv always takes in a M*N matrix A, and depending on whether we set TransA -// to Trans, the output is: -// CblasNoTrans: x is an N dim vector and y is an M dim vector. -// CblasTrans: x is an M dim vector and y is an N dim vector. -template -TORCH_API void Gemv( - const CBLAS_TRANSPOSE trans_A, - const int M, - const int N, - const float alpha, - const T* A, - const T* x, - const float beta, - T* y, - Context* context, - TensorProto::DataType math_type = TensorProto_DataType_FLOAT); - -template -TORCH_API void -RandUniform(const size_t n, const T a, const T b, T* r, Context* context); - -// Generate n values that sum up to a fixed sum -// and subject to a restriction a <= x <= b for each x generated -template -TORCH_API void RandFixedSum( - const size_t n, - const T a, - const T b, - const T sum, - T* r, - Context* context); - -template -TORCH_API void RandUniformUnique( - const size_t n, - const T a, - const T b, - T* r, - const size_t m, - const T* avoid, - Context* context); - -// Generate n values from synthetic data distribution, -// define by unique accesses and stack distances -template -TORCH_API void -RandSyntheticData(const size_t n, const T a, const T b, T* r, Context* context); - -template -TORCH_API void -RandGaussian(const size_t n, const T mean, const T std, T* r, Context* context); - -// Dot matrix of vector a and b, and writes the result to a single value y. -template -TORCH_API void -Dot(const int N, const T* a, const T* b, T* y, Context* context); - -// Sum of vector x, and writes the result to a single value y. -template -TORCH_API void Sum( - const int N, - const T* x, - T* y, - Context* context, - Tensor* scratch_ptr = nullptr); - -// Sum of squares of vector x, and writes the result to a single value y. -template -TORCH_API void SumSqr( - const int N, - const T* x, - T* y, - Context* context, - Tensor* scratch_ptr = nullptr); - -// Select does index selection of the rows a N*D matrix x, and gives the N -// dimensional vector y that contains the selected data. -template -TORCH_API void Select( - const int N, - const int D, - const T* x, - const int* idx, - T* y, - Context* context); - -// groups must be 1 for GPU -// For NHWC order with groups > 1, the result will be layout in -// NHW G RS C/G order to make data within the same group to be contiguous. -// For NCHW order, groups doesn't make any difference because we're doing Im2Col -// for each N and C is the slowest moving dimension among CHW. -template -TORCH_API void Im2Col( - const int channels, - const int height, - const int width, - const int kernel_h, - const int kernel_w, - const int dilation_h, - const int dilation_w, - const int pad_t, - const int pad_l, - const int pad_b, - const int pad_r, - const int stride_h, - const int stride_w, - const T* img_data, - T* col_data, - Context* context, - const int groups = 1); - -// groups must be 1 for GPU -template -TORCH_API void Im2ColNd( - const int N, - const int img_size, - const int col_size, - const int* img_shape, - const int* col_shape, - const int* kernel_shape, - const int* stride, - const int* dilation, - const int* pad, - const T* img_data, - T* col_data, - Context* context, - const int groups = 1); - -// groups must be 1 for GPU -// For NHWC order with groups > 1, the result will be layout in -// NHW G RS C/G order to make data within the same group to be contiguous. -// For NCHW order, groups doesn't make any difference because we're doing Im2Col -// for each N and C is the slowest moving dimension among CHW. -template -TORCH_API void Col2Im( - const int channels, - const int height, - const int width, - const int patch_h, - const int patch_w, - const int dilation_h, - const int dilation_w, - const int pad_t, - const int pad_l, - const int pad_b, - const int pad_r, - const int stride_h, - const int stride_w, - const T* col_data, - T* img_data, - Context* context, - const int groups = 1); - -// groups must be 1 for GPU -// For NHWC order with groups > 1, the result will be layout in -// NHW G RS C/G order to make data within the same group to be contiguous. -// For NCHW order, groups doesn't make any difference because we're doing Im2Col -// for each N and C is the slowest moving dimension among CHW. -template -TORCH_API void Col2ImNd( - const int N, - const int img_size, - const int col_size, - const int* img_shape, - const int* col_shape, - const int* kernel_shape, - const int* stride, - const int* dilation, - const int* pad, - const T* col_data, - T* img_data, - Context* context, - const int groups = 1); - -// Applies a per-channel bias value to each channel of the input -// image. image_size is H * W -template -TORCH_API void BiasCHW( - const T* bias, - const T* bias_multiplier, - const int bias_channels, - const int image_size, - T* image, - Context* context); - -template -TORCH_API void CopyMatrix( - const size_t item_size, - const int M, - const int N, - const void* A, - const int lda, - void* B, - const int ldb, - Context* context, - TypeMeta::Copy copy = nullptr); - -template -TORCH_API void CopyMatrix( - const int M, - const int N, - const T* A, - const int lda, - T* B, - const int ldb, - Context* context); - -template -TORCH_API void CopyMatrix( - const int M, - const int N, - const T* A, - const int A_outer_stride, - const int A_inner_stride, - T* B, - const int B_outer_stride, - const int B_inner_stride, - Context* context); - -template -TORCH_API void CopyVector(const int N, const T* A, T* B, Context* context); - -} // namespace math -} // namespace caffe2 - -#include "caffe2/utils/math-detail.h" -#endif // CAFFE2_UTILS_MATH_H_ diff --git a/caffe2/utils/math/broadcast.cu b/caffe2/utils/math/broadcast.cu deleted file mode 100644 index 8c0c57951926..000000000000 --- a/caffe2/utils/math/broadcast.cu +++ /dev/null @@ -1,110 +0,0 @@ -#include "caffe2/utils/math/broadcast.h" - -#include "caffe2/core/context_gpu.h" -#include "caffe2/utils/math/utils.h" - -namespace caffe2 { -namespace math { - -namespace { - -template -__global__ void AffineChannelNCHWCUDAKernel( - const int C, - const int M, - const int HxW, - const T* X, - const T* scale, - const T* bias, - T* Y); - -template <> -__global__ void AffineChannelNCHWCUDAKernel( - const int C, - const int M, - const int HxW, - const float* X, - const float* scale, - const float* bias, - float* Y) { - const int nc = blockIdx.x / M; - const int c = nc % C; - const int w = blockIdx.x % M * CAFFE_CUDA_NUM_THREADS + threadIdx.x; - if (w < HxW) { - const int index = nc * HxW + w; -#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) - Y[index] = fmaf(__ldg(X + index), __ldg(scale + c), __ldg(bias + c)); -#else - Y[index] = fmaf(X[index], scale[c], bias[c]); -#endif - } -} - -template -__global__ void AffineChannelNHWCCUDAKernel( - const int C, - const T* X, - const T* scale, - const T* bias, - T* Y); - -template <> -__global__ void AffineChannelNHWCCUDAKernel( - const int C, - const float* X, - const float* scale, - const float* bias, - float* Y) { - const int c = blockIdx.y * CAFFE_CUDA_NUM_THREADS + threadIdx.x; - if (c < C) { - const int index = blockIdx.x * C + c; -#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) - Y[index] = fmaf(__ldg(X + index), __ldg(scale + c), __ldg(bias + c)); -#else - Y[index] = fmaf(X[index], scale[c], bias[c]); -#endif - } -} - -} // namespace - -#define CAFFE2_SPECIALIZED_CUDA_AFFINE_CHANNEL(T) \ - template <> \ - CAFFE2_CUDA_EXPORT void AffineChannel( \ - const int N, \ - const int C, \ - const int HxW, \ - const T* X, \ - const T* scale, \ - const T* bias, \ - T* Y, \ - CUDAContext* context) { \ - const int M = DivUp(HxW, CAFFE_CUDA_NUM_THREADS); \ - AffineChannelNCHWCUDAKernel \ - <<cuda_stream()>>>( \ - C, M, HxW, X, scale, bias, Y); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } \ - template <> \ - CAFFE2_CUDA_EXPORT void AffineChannel( \ - const int N, \ - const int C, \ - const int HxW, \ - const T* X, \ - const T* scale, \ - const T* bias, \ - T* Y, \ - CUDAContext* context) { \ - const int M = DivUp(C, CAFFE_CUDA_NUM_THREADS); \ - AffineChannelNHWCCUDAKernel \ - <<cuda_stream()>>>(C, X, scale, bias, Y); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } -CAFFE2_SPECIALIZED_CUDA_AFFINE_CHANNEL(float) -#undef CAFFE2_SPECIALIZED_CUDA_AFFINE_CHANNEL - -} // namespace math -} // namespace caffe2 diff --git a/caffe2/utils/math/elementwise.cu b/caffe2/utils/math/elementwise.cu deleted file mode 100644 index d1911ae4db4c..000000000000 --- a/caffe2/utils/math/elementwise.cu +++ /dev/null @@ -1,918 +0,0 @@ -#include "caffe2/utils/math/elementwise.h" - -#include - -#include -#include -#include -#include - -#include "caffe2/core/context_gpu.h" -#include "caffe2/utils/conversions.h" -#include "caffe2/utils/math/half_utils.h" -#include "caffe2/utils/math/utils.h" - -namespace caffe2 { -namespace math { - -namespace { - -template -__global__ void SinCosCUDAKernel(const int N, const T* X, T* S, T* C) { - const int i = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x; - if (i < N) { -#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) - c10::cuda::compat::sincos(__ldg(X + i), S + i, C + i); -#else - c10::cuda::compat::sincos(X[i], S + i, C + i); -#endif - } -} - -#if defined(USE_ROCM) - -template -__global__ void AxpyCUDAKernel( - const std::int64_t N, - const TAlpha alpha, - const TData* X, - TData* Y) { - const int64_t index = static_cast(blockIdx.x) * - static_cast(CAFFE_CUDA_NUM_THREADS) + - static_cast(threadIdx.x); - if (index < N) { - Y[index] += static_cast(alpha) * __ldg(X + index); - } -} - -template -__global__ void AxpyCUDAKernel( - const std::int64_t N, - const TAlpha* alpha, - const TData* X, - TData* Y) { - __shared__ TData a; - if (threadIdx.x == 0) { - a = static_cast(__ldg(alpha)); - } - __syncthreads(); - const int64_t index = static_cast(blockIdx.x) * - static_cast(CAFFE_CUDA_NUM_THREADS) + - static_cast(threadIdx.x); - if (index < N) { - Y[index] += a * __ldg(X + index); - } -} - -#define DELEGATE_HALF_AXPY_CUDA_KERNEL(TAlpha, FMAFunc) \ - template <> \ - __global__ void AxpyCUDAKernel( \ - const std::int64_t N, \ - const TAlpha alpha, \ - const at::Half* X, \ - at::Half* Y) { \ - const int64_t index = static_cast(blockIdx.x) * \ - static_cast(CAFFE_CUDA_NUM_THREADS) + \ - static_cast(threadIdx.x); \ - if (index < N) { \ - Y[index] = convert::To(FMAFunc( \ - alpha, \ - convert::To(X[index]), \ - convert::To(Y[index]))); \ - } \ - } \ - template <> \ - __global__ void AxpyCUDAKernel( \ - const std::int64_t N, \ - const TAlpha* alpha, \ - const at::Half* X, \ - at::Half* Y) { \ - __shared__ TAlpha a; \ - if (threadIdx.x == 0) { \ - a = __ldg(alpha); \ - } \ - __syncthreads(); \ - const int64_t index = static_cast(blockIdx.x) * \ - static_cast(CAFFE_CUDA_NUM_THREADS) + \ - static_cast(threadIdx.x); \ - if (index < N) { \ - Y[index] = convert::To(FMAFunc( \ - a, \ - convert::To(X[index]), \ - convert::To(Y[index]))); \ - } \ - } -DELEGATE_HALF_AXPY_CUDA_KERNEL(float, fmaf) -#undef DELEGATE_HALF_AXPY_CUDA_KERNEL - -#endif // USE_ROCM - -template -__global__ void AxpbyCUDAKernel( - const std::int64_t N, - const TAlpha alpha, - const TData* X, - const TAlpha beta, - TData* Y); - -template -__global__ void AxpbyCUDAKernel( - const std::int64_t N, - const TAlpha* alpha, - const TData* X, - const TAlpha* beta, - TData* Y); - -#define DELEGATE_AXPBY_CUDA_KERNEL(TAlpha, TData, FMAFunc) \ - template <> \ - __global__ void AxpbyCUDAKernel( \ - const std::int64_t N, \ - const TAlpha alpha, \ - const TData* X, \ - const TAlpha beta, \ - TData* Y) { \ - const int64_t index = static_cast(blockIdx.x) * \ - static_cast(CAFFE_CUDA_NUM_THREADS) + \ - static_cast(threadIdx.x); \ - if (index < N) { \ - Y[index] = FMAFunc( \ - static_cast(alpha), \ - X[index], \ - static_cast(beta) * Y[index]); \ - } \ - } \ - template <> \ - __global__ void AxpbyCUDAKernel( \ - const std::int64_t N, \ - const TAlpha* alpha, \ - const TData* X, \ - const TAlpha* beta, \ - TData* Y) { \ - __shared__ TData a; \ - __shared__ TData b; \ - if (threadIdx.x == 0) { \ - a = static_cast(*alpha); \ - b = static_cast(*beta); \ - } \ - __syncthreads(); \ - const int64_t index = static_cast(blockIdx.x) * \ - static_cast(CAFFE_CUDA_NUM_THREADS) + \ - static_cast(threadIdx.x); \ - if (index < N) { \ - Y[index] = FMAFunc(a, X[index], b * Y[index]); \ - } \ - } -DELEGATE_AXPBY_CUDA_KERNEL(float, float, fmaf) -DELEGATE_AXPBY_CUDA_KERNEL(float, double, fma) -#undef DELEGATE_AXPBY_CUDA_KERNEL - -#define DELEGATE_HALF_AXPBY_CUDA_KERNEL(TAlpha, FMAFunc) \ - template <> \ - __global__ void AxpbyCUDAKernel( \ - const std::int64_t N, \ - const TAlpha alpha, \ - const at::Half* X, \ - const TAlpha beta, \ - at::Half* Y) { \ - const int64_t index = static_cast(blockIdx.x) * \ - static_cast(CAFFE_CUDA_NUM_THREADS) + \ - static_cast(threadIdx.x); \ - if (index < N) { \ - Y[index] = convert::To(FMAFunc( \ - alpha, \ - convert::To(X[index]), \ - beta * convert::To(Y[index]))); \ - } \ - } \ - template <> \ - __global__ void AxpbyCUDAKernel( \ - const std::int64_t N, \ - const TAlpha* alpha, \ - const at::Half* X, \ - const TAlpha* beta, \ - at::Half* Y) { \ - __shared__ TAlpha a; \ - __shared__ TAlpha b; \ - if (threadIdx.x == 0) { \ - a = *alpha; \ - b = *beta; \ - } \ - __syncthreads(); \ - const int64_t index = static_cast(blockIdx.x) * \ - static_cast(CAFFE_CUDA_NUM_THREADS) + \ - static_cast(threadIdx.x); \ - if (index < N) { \ - Y[index] = convert::To(FMAFunc( \ - a, \ - convert::To(X[index]), \ - b * convert::To(Y[index]))); \ - } \ - } -DELEGATE_HALF_AXPBY_CUDA_KERNEL(float, fmaf) -#undef DELEGATE_HALF_AXPBY_CUDA_KERNEL - -template -__global__ void ScaleCUDAKernel( - const std::int64_t N, - const TAlpha alpha, - const TData* X, - TData* Y); - -template -__global__ void ScaleCUDAKernel( - const std::int64_t N, - const TAlpha* alpha, - const TData* X, - TData* Y); - -#define CAFFE2_SPECIALIZED_SCALE_CUDA_KERNEL(TAlpha, TData) \ - template <> \ - __global__ void ScaleCUDAKernel( \ - const std::int64_t N, const TAlpha alpha, const TData* X, TData* Y) { \ - const int64_t index = static_cast(blockIdx.x) * \ - static_cast(CAFFE_CUDA_NUM_THREADS) + \ - static_cast(threadIdx.x); \ - if (index < N) { \ - Y[index] = static_cast(alpha) * X[index]; \ - } \ - } \ - template <> \ - __global__ void ScaleCUDAKernel( \ - const std::int64_t N, const TAlpha* alpha, const TData* X, TData* Y) { \ - __shared__ TData a; \ - if (threadIdx.x == 0) { \ - a = static_cast(*alpha); \ - } \ - __syncthreads(); \ - const int64_t index = static_cast(blockIdx.x) * \ - static_cast(CAFFE_CUDA_NUM_THREADS) + \ - static_cast(threadIdx.x); \ - if (index < N) { \ - Y[index] = a * X[index]; \ - } \ - } -CAFFE2_SPECIALIZED_SCALE_CUDA_KERNEL(float, float) -CAFFE2_SPECIALIZED_SCALE_CUDA_KERNEL(double, double) -CAFFE2_SPECIALIZED_SCALE_CUDA_KERNEL(float, double) -CAFFE2_SPECIALIZED_SCALE_CUDA_KERNEL(std::int32_t, std::int32_t) -CAFFE2_SPECIALIZED_SCALE_CUDA_KERNEL(std::int64_t, std::int64_t) -#undef CAFFE2_SPECIALIZED_SCALE_CUDA_KERNEL - -#define CAFFE2_SPECIALIZED_HALF_SCALE_CUDA_KERNEL(TAlpha) \ - template <> \ - __global__ void ScaleCUDAKernel( \ - const std::int64_t N, \ - const TAlpha alpha, \ - const at::Half* X, \ - at::Half* Y) { \ - const int64_t index = static_cast(blockIdx.x) * \ - static_cast(CAFFE_CUDA_NUM_THREADS) + \ - static_cast(threadIdx.x); \ - if (index < N) { \ - Y[index] = convert::To( \ - alpha * convert::To(X[index])); \ - } \ - } \ - template <> \ - __global__ void ScaleCUDAKernel( \ - const std::int64_t N, \ - const TAlpha* alpha, \ - const at::Half* X, \ - at::Half* Y) { \ - __shared__ TAlpha a; \ - if (threadIdx.x == 0) { \ - a = *alpha; \ - } \ - __syncthreads(); \ - const int64_t index = static_cast(blockIdx.x) * \ - static_cast(CAFFE_CUDA_NUM_THREADS) + \ - static_cast(threadIdx.x); \ - if (index < N) { \ - Y[index] = convert::To( \ - a * convert::To(X[index])); \ - } \ - } -CAFFE2_SPECIALIZED_HALF_SCALE_CUDA_KERNEL(float) -#undef CAFFE2_SPECIALIZED_HALF_SCALE_CUDA_KERNEL - -} // namespace - -#define CAFFE2_SPECIALIZED_CUDA_SET(T) \ - template <> \ - CAFFE2_CUDA_EXPORT void Set( \ - const std::int64_t N, const T alpha, T* Y, CUDAContext* context) { \ - if (N == 0) { \ - return; \ - } \ - if (alpha == T(0)) { \ - C10_CUDA_CHECK(cudaMemsetAsync(Y, 0, sizeof(T) * N, context->cuda_stream())); \ - } else { \ - thrust::fill( \ - thrust::cuda::par.on(context->cuda_stream()), Y, Y + N, alpha); \ - } \ - } -CAFFE2_SPECIALIZED_CUDA_SET(bool) -CAFFE2_SPECIALIZED_CUDA_SET(char) -CAFFE2_SPECIALIZED_CUDA_SET(std::int8_t) -CAFFE2_SPECIALIZED_CUDA_SET(std::int16_t) -CAFFE2_SPECIALIZED_CUDA_SET(std::int32_t) -CAFFE2_SPECIALIZED_CUDA_SET(std::int64_t) -CAFFE2_SPECIALIZED_CUDA_SET(std::uint8_t) -CAFFE2_SPECIALIZED_CUDA_SET(std::uint16_t) -CAFFE2_SPECIALIZED_CUDA_SET(float) -CAFFE2_SPECIALIZED_CUDA_SET(double) -CAFFE2_SPECIALIZED_CUDA_SET(at::Half) -CAFFE2_SPECIALIZED_CUDA_SET(at::BFloat16) -#undef CAFFE2_SPECIALIZED_CUDA_SET - -#define DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(T, Func, DeviceFunc) \ - template <> \ - CAFFE2_CUDA_EXPORT void Func( \ - const int N, const T* X, T* Y, CUDAContext* context) { \ - if (N > 0) { \ - thrust::transform( \ - thrust::cuda::par.on(context->cuda_stream()), \ - X, \ - X + N, \ - Y, \ - [] __device__(const T x) { return DeviceFunc(x); }); \ - } \ - } -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Exp, expf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Log, logf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Log1p, log1pf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sin, sinf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Asin, asinf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Cos, cosf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Acos, acosf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Tan, tanf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Atan, atanf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sinh, sinhf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Cosh, coshf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Tanh, tanhf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Abs, fabsf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Inv, utils::Inv) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, Inv, utils::Inv) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sqr, utils::Square) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sqrt, sqrtf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Rsqrt, rsqrtf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION( - std::int32_t, - Cube, - utils::Cube) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION( - std::int64_t, - Cube, - utils::Cube) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Cube, utils::Cube) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, Cube, utils::Cube) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Cbrt, cbrtf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Erf, erff) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, Erf, erf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, CdfNorm, normcdff) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, CdfNorm, normcdf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(bool, Not, utils::Not) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION( - std::int32_t, - Neg, - utils::Negate) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION( - std::int64_t, - Neg, - utils::Negate) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Neg, utils::Negate) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, Neg, utils::Negate) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION( - std::int32_t, - Sign, - utils::Sign) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION( - std::int64_t, - Sign, - utils::Sign) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sign, utils::Sign) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, Sign, utils::Sign) -#undef DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION - -#define DELEGATE_CUDA_POWX(T, DeviceFunc) \ - template <> \ - CAFFE2_CUDA_EXPORT void Powx( \ - const int N, const T* A, const T b, T* Y, CUDAContext* context) { \ - thrust::transform( \ - thrust::cuda::par.on(context->cuda_stream()), \ - A, \ - A + N, \ - Y, \ - [b] __device__(const T x) { return DeviceFunc(x, b); }); \ - } -DELEGATE_CUDA_POWX(float, powf) -#undef DELEGATE_CUDA_POWX - -#define CAFFE2_SPECIALIZED_CUDA_SINCOS(T) \ - template <> \ - CAFFE2_CUDA_EXPORT void SinCos( \ - const int N, const T* X, T* S, T* C, CUDAContext* context) { \ - if (N > 0) { \ - const int K = DivUp(N, CAFFE_CUDA_NUM_THREADS); \ - SinCosCUDAKernel \ - <<cuda_stream()>>>( \ - N, X, S, C); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } \ - } -CAFFE2_SPECIALIZED_CUDA_SINCOS(float) -CAFFE2_SPECIALIZED_CUDA_SINCOS(double) -#undef CAFFE2_SPECIALIZED_CUDA_SINCOS - -#define DELEGATE_CUDA_SCALE(T, CuBLASFunc) \ - template <> \ - CAFFE2_CUDA_EXPORT void Scale( \ - const std::int64_t N, \ - const T alpha, \ - const T* X, \ - T* Y, \ - CUDAContext* context) { \ - if (N == 0) { \ - return; \ - } \ - if (Y == X) { \ - CUBLAS_ENFORCE(cublasSetPointerMode( \ - context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); \ - CUBLAS_ENFORCE(CuBLASFunc(context->cublas_handle(), N, &alpha, Y, 1)); \ - } else { \ - const std::int64_t M = DivUp(N, CAFFE_CUDA_NUM_THREADS); \ - ScaleCUDAKernel \ - <<cuda_stream()>>>( \ - N, alpha, X, Y); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } \ - } \ - template <> \ - CAFFE2_CUDA_EXPORT void Scale( \ - const std::int64_t N, \ - const T* alpha, \ - const T* X, \ - T* Y, \ - CUDAContext* context) { \ - if (N == 0) { \ - return; \ - } \ - if (Y == X) { \ - CUBLAS_ENFORCE(cublasSetPointerMode( \ - context->cublas_handle(), CUBLAS_POINTER_MODE_DEVICE)); \ - CUBLAS_ENFORCE(CuBLASFunc(context->cublas_handle(), N, alpha, Y, 1)); \ - } else { \ - const std::int64_t M = DivUp(N, CAFFE_CUDA_NUM_THREADS); \ - ScaleCUDAKernel \ - <<cuda_stream()>>>( \ - N, alpha, X, Y); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } \ - } -DELEGATE_CUDA_SCALE(float, cublasSscal) -DELEGATE_CUDA_SCALE(double, cublasDscal) -#undef DELEGATE_CUDA_SCALE - -#if !defined(USE_ROCM) - -#define DELEGATE_CUDA_SCALE_EX( \ - TAlpha, TData, kAlphaType, kDataType, kExecutionType) \ - template <> \ - CAFFE2_CUDA_EXPORT void Scale( \ - const std::int64_t N, \ - const TAlpha alpha, \ - const TData* X, \ - TData* Y, \ - CUDAContext* context) { \ - if (N == 0) { \ - return; \ - } \ - if (Y == X) { \ - CUBLAS_ENFORCE(cublasSetPointerMode( \ - context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); \ - CUBLAS_ENFORCE(cublasScalEx( \ - context->cublas_handle(), \ - N, \ - &alpha, \ - kAlphaType, \ - Y, \ - kDataType, \ - 1, \ - kExecutionType)); \ - } else { \ - const std::int64_t M = DivUp(N, CAFFE_CUDA_NUM_THREADS); \ - ScaleCUDAKernel \ - <<cuda_stream()>>>( \ - N, alpha, X, Y); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } \ - } \ - template <> \ - CAFFE2_CUDA_EXPORT void Scale( \ - const std::int64_t N, \ - const TAlpha* alpha, \ - const TData* X, \ - TData* Y, \ - CUDAContext* context) { \ - if (N == 0) { \ - return; \ - } \ - if (Y == X) { \ - CUBLAS_ENFORCE(cublasSetPointerMode( \ - context->cublas_handle(), CUBLAS_POINTER_MODE_DEVICE)); \ - CUBLAS_ENFORCE(cublasScalEx( \ - context->cublas_handle(), \ - N, \ - alpha, \ - kAlphaType, \ - Y, \ - kDataType, \ - 1, \ - kExecutionType)); \ - } else { \ - const std::int64_t M = DivUp(N, CAFFE_CUDA_NUM_THREADS); \ - ScaleCUDAKernel \ - <<cuda_stream()>>>( \ - N, alpha, X, Y); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } \ - } -DELEGATE_CUDA_SCALE_EX(float, double, CUDA_R_32F, CUDA_R_64F, CUDA_R_64F) -DELEGATE_CUDA_SCALE_EX(float, at::Half, CUDA_R_32F, CUDA_R_16F, CUDA_R_32F) -#undef DELEGATE_CUDA_SCALE_EX - -#endif // USE_ROCM - -#define CAFFE2_SPECIALIZED_CUDA_SCALE(TAlpha, TData) \ - template <> \ - CAFFE2_CUDA_EXPORT void Scale( \ - const std::int64_t N, \ - const TAlpha alpha, \ - const TData* X, \ - TData* Y, \ - CUDAContext* context) { \ - if (N > 0) { \ - const std::int64_t M = DivUp(N, CAFFE_CUDA_NUM_THREADS); \ - ScaleCUDAKernel \ - <<cuda_stream()>>>( \ - N, alpha, X, Y); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } \ - } \ - template <> \ - CAFFE2_CUDA_EXPORT void Scale( \ - const std::int64_t N, \ - const TAlpha* alpha, \ - const TData* X, \ - TData* Y, \ - CUDAContext* context) { \ - if (N > 0) { \ - const std::int64_t M = DivUp(N, CAFFE_CUDA_NUM_THREADS); \ - ScaleCUDAKernel \ - <<cuda_stream()>>>( \ - N, *alpha, X, Y); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } \ - } -CAFFE2_SPECIALIZED_CUDA_SCALE(std::int32_t, std::int32_t) -CAFFE2_SPECIALIZED_CUDA_SCALE(std::int64_t, std::int64_t) - -#if defined(USE_ROCM) -CAFFE2_SPECIALIZED_CUDA_SCALE(float, double) -CAFFE2_SPECIALIZED_CUDA_SCALE(float, at::Half) -#endif // USE_ROCM -#undef CAFFE2_SPECIALIZED_CUDA_SCALE - -#define DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(T, Func, DeviceFunc) \ - template <> \ - CAFFE2_CUDA_EXPORT void Func( \ - const int N, const T* A, const T* B, T* C, CUDAContext* context) { \ - if (N > 0) { \ - thrust::transform( \ - thrust::cuda::par.on(context->cuda_stream()), \ - A, \ - A + N, \ - B, \ - C, \ - DeviceFunc); \ - } \ - } -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( - std::int32_t, - Add, - thrust::plus()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( - std::int64_t, - Add, - thrust::plus()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float, Add, thrust::plus()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(double, Add, thrust::plus()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(at::Half, Add, utils::HalfAddFunctor()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( - std::int32_t, - Sub, - thrust::minus()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( - std::int64_t, - Sub, - thrust::minus()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float, Sub, thrust::minus()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(double, Sub, thrust::minus()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(at::Half, Sub, utils::HalfSubFunctor()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( - std::int32_t, - Mul, - thrust::multiplies()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( - std::int64_t, - Mul, - thrust::multiplies()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float, Mul, thrust::multiplies()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(double, Mul, thrust::multiplies()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(at::Half, Mul, utils::HalfMulFunctor()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( - std::int32_t, - Div, - thrust::divides()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( - std::int64_t, - Div, - thrust::divides()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float, Div, thrust::divides()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(double, Div, thrust::divides()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(at::Half, Div, utils::HalfDivFunctor()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float, Min, thrust::minimum()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(double, Min, thrust::minimum()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float, Max, thrust::maximum()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(double, Max, thrust::maximum()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(bool, And, thrust::logical_and()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(bool, Or, thrust::logical_or()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(bool, Xor, thrust::bit_xor()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(bool, BitwiseAnd, thrust::bit_and()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( - std::int32_t, - BitwiseAnd, - thrust::bit_and()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( - std::int64_t, - BitwiseAnd, - thrust::bit_and()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(bool, BitwiseOr, thrust::bit_or()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( - std::int32_t, - BitwiseOr, - thrust::bit_or()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( - std::int64_t, - BitwiseOr, - thrust::bit_or()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(bool, BitwiseXor, thrust::bit_xor()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( - std::int32_t, - BitwiseXor, - thrust::bit_xor()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( - std::int64_t, - BitwiseXor, - thrust::bit_xor()) -#undef DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION - -#define DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(T, Func, DeviceComp) \ - template <> \ - CAFFE2_CUDA_EXPORT void Func( \ - const int N, const T* A, const T* B, bool* C, CUDAContext* context) { \ - if (N > 0) { \ - thrust::transform( \ - thrust::cuda::par.on(context->cuda_stream()), \ - A, \ - A + N, \ - B, \ - C, \ - DeviceComp); \ - } \ - } -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(bool, EQ, thrust::equal_to()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( - std::int32_t, - EQ, - thrust::equal_to()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( - std::int64_t, - EQ, - thrust::equal_to()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(float, EQ, thrust::equal_to()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(double, EQ, thrust::equal_to()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(bool, NE, thrust::not_equal_to()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( - std::int32_t, - NE, - thrust::not_equal_to()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( - std::int64_t, - NE, - thrust::not_equal_to()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(float, NE, thrust::not_equal_to()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( - double, - NE, - thrust::not_equal_to()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(bool, LT, thrust::less()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( - std::int32_t, - LT, - thrust::less()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( - std::int64_t, - LT, - thrust::less()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(float, LT, thrust::less()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(double, LT, thrust::less()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(bool, LE, thrust::less_equal()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( - std::int32_t, - LE, - thrust::less_equal()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( - std::int64_t, - LE, - thrust::less_equal()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(float, LE, thrust::less_equal()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(double, LE, thrust::less_equal()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(bool, GT, thrust::greater()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( - std::int32_t, - GT, - thrust::greater()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( - std::int64_t, - GT, - thrust::greater()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(float, GT, thrust::greater()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(double, GT, thrust::greater()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(bool, GE, thrust::greater_equal()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( - std::int32_t, - GE, - thrust::greater_equal()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( - std::int64_t, - GE, - thrust::greater_equal()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(float, GE, thrust::greater_equal()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( - double, - GE, - thrust::greater_equal()) -#undef DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION - -#define DELEGATE_CUDA_AXPY(T, CuBLASFunc) \ - template <> \ - CAFFE2_CUDA_EXPORT void Axpy( \ - const std::int64_t N, \ - const T alpha, \ - const T* X, \ - T* Y, \ - CUDAContext* context) { \ - CUBLAS_ENFORCE(cublasSetPointerMode( \ - context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); \ - CUBLAS_ENFORCE( \ - CuBLASFunc(context->cublas_handle(), N, &alpha, X, 1, Y, 1)); \ - } \ - template <> \ - CAFFE2_CUDA_EXPORT void Axpy( \ - const std::int64_t N, \ - const T* alpha, \ - const T* X, \ - T* Y, \ - CUDAContext* context) { \ - CUBLAS_ENFORCE(cublasSetPointerMode( \ - context->cublas_handle(), CUBLAS_POINTER_MODE_DEVICE)); \ - CUBLAS_ENFORCE( \ - cublasSaxpy(context->cublas_handle(), N, alpha, X, 1, Y, 1)); \ - } -DELEGATE_CUDA_AXPY(float, cublasSaxpy) -#undef DELEGATE_CUDA_AXPY - -#if !defined(USE_ROCM) - -#define DELEGATE_CUDA_AXPY_EX( \ - TAlpha, TData, kAlphaType, kDataType, kExecutionType) \ - template <> \ - CAFFE2_CUDA_EXPORT void Axpy( \ - const std::int64_t N, \ - const TAlpha alpha, \ - const TData* X, \ - TData* Y, \ - CUDAContext* context) { \ - CUBLAS_ENFORCE(cublasSetPointerMode( \ - context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); \ - CUBLAS_ENFORCE(cublasAxpyEx( \ - context->cublas_handle(), \ - N, \ - &alpha, \ - kAlphaType, \ - X, \ - kDataType, \ - 1, \ - Y, \ - kDataType, \ - 1, \ - kExecutionType)); \ - } \ - template <> \ - CAFFE2_CUDA_EXPORT void Axpy( \ - const std::int64_t N, \ - const TAlpha* alpha, \ - const TData* X, \ - TData* Y, \ - CUDAContext* context) { \ - CUBLAS_ENFORCE(cublasSetPointerMode( \ - context->cublas_handle(), CUBLAS_POINTER_MODE_DEVICE)); \ - CUBLAS_ENFORCE(cublasAxpyEx( \ - context->cublas_handle(), \ - N, \ - alpha, \ - kAlphaType, \ - X, \ - kDataType, \ - 1, \ - Y, \ - kDataType, \ - 1, \ - kExecutionType)); \ - } -DELEGATE_CUDA_AXPY_EX(float, double, CUDA_R_32F, CUDA_R_64F, CUDA_R_64F) -DELEGATE_CUDA_AXPY_EX(float, at::Half, CUDA_R_32F, CUDA_R_16F, CUDA_R_32F) -#undef DELEGATE_CUDA_AXPY_EX - -#else // USE_ROCM - -#define CAFFE2_SPECIALIZED_CUDA_AXPY(TAlpha, TData) \ - template <> \ - CAFFE2_CUDA_EXPORT void Axpy( \ - const std::int64_t N, \ - const TAlpha alpha, \ - const TData* X, \ - TData* Y, \ - CUDAContext* context) { \ - const std::int64_t M = DivUp(N, CAFFE_CUDA_NUM_THREADS); \ - AxpyCUDAKernel \ - <<cuda_stream()>>>( \ - N, alpha, X, Y); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } \ - template <> \ - CAFFE2_CUDA_EXPORT void Axpy( \ - const std::int64_t N, \ - const TAlpha* alpha, \ - const TData* X, \ - TData* Y, \ - CUDAContext* context) { \ - const std::int64_t M = DivUp(N, CAFFE_CUDA_NUM_THREADS); \ - AxpyCUDAKernel \ - <<cuda_stream()>>>( \ - N, alpha, X, Y); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } -CAFFE2_SPECIALIZED_CUDA_AXPY(float, double) -CAFFE2_SPECIALIZED_CUDA_AXPY(float, at::Half) -#undef CAFFE2_SPECIALIZED_CUDA_AXPY - -#endif // USE_ROCM - -#define CAFFE2_SPECIALIZED_CUDA_AXPBY(TAlpha, TData) \ - template <> \ - CAFFE2_CUDA_EXPORT void Axpby( \ - const std::int64_t N, \ - const TAlpha alpha, \ - const TData* X, \ - const TAlpha beta, \ - TData* Y, \ - CUDAContext* context) { \ - const std::int64_t M = DivUp(N, CAFFE_CUDA_NUM_THREADS); \ - AxpbyCUDAKernel \ - <<cuda_stream()>>>( \ - N, alpha, X, beta, Y); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } \ - template <> \ - CAFFE2_CUDA_EXPORT void Axpby( \ - const std::int64_t N, \ - const TAlpha* alpha, \ - const TData* X, \ - const TAlpha* beta, \ - TData* Y, \ - CUDAContext* context) { \ - const std::int64_t M = DivUp(N, CAFFE_CUDA_NUM_THREADS); \ - AxpbyCUDAKernel \ - <<cuda_stream()>>>( \ - N, alpha, X, beta, Y); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } -CAFFE2_SPECIALIZED_CUDA_AXPBY(float, float) -CAFFE2_SPECIALIZED_CUDA_AXPBY(float, double) -CAFFE2_SPECIALIZED_CUDA_AXPBY(float, at::Half) -#undef CAFFE2_SPECIALIZED_CUDA_AXPBY - -} // namespace math -} // namespace caffe2 diff --git a/caffe2/utils/math/reduce.cu b/caffe2/utils/math/reduce.cu deleted file mode 100644 index d59cbd387753..000000000000 --- a/caffe2/utils/math/reduce.cu +++ /dev/null @@ -1,593 +0,0 @@ -#include "caffe2/utils/math/reduce.h" - -#include -#include -#include -#include -#include -#include "caffe2/utils/cub_namespace.cuh" -#include - -#include -#include -#include - -#include "caffe2/core/context_gpu.h" -#include "caffe2/utils/math/elementwise.h" -#include "caffe2/utils/math/reduce.cuh" -#include "caffe2/utils/math/utils.h" - -namespace caffe2 { -namespace math { - -namespace { - -template -__global__ void RowwiseReduceCUDAKernel( - const int cols, - const Reducer reducer, - const T init, - const T alpha, - const T* X, - T* Y) { - __shared__ typename BlockReduce::TempStorage temp_storage; - const int r = blockIdx.x; - T val = init; - for (int c = threadIdx.x; c < cols; c += blockDim.x) { -#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) - val = reducer(val, __ldg(X + r * cols + c)); -#else - val = reducer(val, X[r * cols + c]); -#endif - } - val = BlockReduce(temp_storage).Reduce(val, reducer); - if (threadIdx.x == 0) { - Y[r] = val * alpha; - } -} - -template -__global__ void ColwiseReduceCUDAKernel( - const int rows, - const int cols, - const Reducer reducer, - const T init, - const T alpha, - const T* X, - T* Y) { - __shared__ typename BlockReduce::TempStorage temp_storage; - const int c = blockIdx.x; - T val = init; - for (int r = threadIdx.x; r < rows; r += blockDim.x) { -#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) - val = reducer(val, __ldg(X + r * cols + c)); -#else - val = reducer(val, X[r * cols + c]); -#endif - } - val = BlockReduce(temp_storage).Reduce(val, reducer); - if (threadIdx.x == 0) { - Y[c] = val * alpha; - } -} - -template -__global__ void BothEndsReduceCUDAKernel( - const int M, - const int N, - const int K, - const Reducer reducer, - const T init, - const T alpha, - const T* X, - T* Y) { - __shared__ typename BlockReduce2D::TempStorage - temp_storage; - const int n = blockIdx.x; - T val = init; - for (int m = threadIdx.x; m < M; m += blockDim.x) { - for (int k = threadIdx.y; k < K; k += blockDim.y) { -#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) - val = reducer(val, __ldg(X + (m * N + n) * K + k)); -#else - val = reducer(val, X[(m * N + n) * K + k]); -#endif - } - } - val = BlockReduce2D(temp_storage) - .Reduce(val, reducer); - if (threadIdx.x == 0 && threadIdx.y == 0) { - Y[n] = val * alpha; - } -} - -template -__global__ void ReduceTensorCUDAKernel( - const int inner_size, - const SimpleArray X_strides, - const SimpleArray Y_dims, - const Reducer reducer, - const T init, - const T alpha, - const T* X, - T* Y) { - __shared__ typename BlockReduce::TempStorage temp_storage; - const int x = blockIdx.x; - T val = init; - for (int y = threadIdx.x; y < inner_size; y += blockDim.x) { - int X_index = 0; - int Y_index = x * inner_size + y; -#pragma unroll - for (int d = D - 1; d >= 0; --d) { - X_index += Y_index % Y_dims.data[d] * X_strides.data[d]; - Y_index /= Y_dims.data[d]; - } -#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) - val = reducer(val, __ldg(X + X_index)); -#else - val = reducer(val, X[X_index]); -#endif - } - val = BlockReduce(temp_storage).Reduce(val, reducer); - if (threadIdx.x == 0) { - Y[x] = val * alpha; - } -} - -template -void ReduceTensorCUDAImpl( - const int outer_size, - const int inner_size, - const int* dims, - const int* axes, - const Reducer& reducer, - const T init, - const T alpha, - const T* X, - T* Y, - CUDAContext* context) { - SimpleArray X_strides; - SimpleArray Y_dims; - utils::ComputeTransposedStrides(D, dims, axes, X_strides.data); - for (int i = 0; i < D; ++i) { - Y_dims.data[i] = dims[axes[i]]; - } - ReduceTensorCUDAKernel - <<cuda_stream()>>>( - inner_size, X_strides, Y_dims, reducer, init, alpha, X, Y); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -template -void ReduceTensorCUDA( - const int ndim, - const int* X_dims, - const int* Y_dims, - const Reducer& reducer, - const T init, - const T alpha, - const T* X, - T* Y, - CUDAContext* context) { - CAFFE_ENFORCE(utils::CheckReduceDims(ndim, X_dims, Y_dims)); - const int X_size = - std::accumulate(X_dims, X_dims + ndim, 1, std::multiplies()); - const int Y_size = - std::accumulate(Y_dims, Y_dims + ndim, 1, std::multiplies()); - if (X_size == 0) { - Set(Y_size, init * alpha, Y, context); - return; - } - if (std::equal(X_dims, X_dims + ndim, Y_dims)) { - Scale(X_size, alpha, X, Y, context); - return; - } - int rows; - int cols; - if (utils::IsRowwiseReduce(ndim, X_dims, Y_dims, &rows, &cols)) { - RowwiseReduceCUDAKernel - <<cuda_stream()>>>( - cols, reducer, init, alpha, X, Y); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - return; - } - if (utils::IsColwiseReduce(ndim, X_dims, Y_dims, &rows, &cols)) { - ColwiseReduceCUDAKernel - <<cuda_stream()>>>( - rows, cols, reducer, init, alpha, X, Y); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - return; - } - int M; - int N; - int K; - if (utils::IsBothEndsReduce(ndim, X_dims, Y_dims, &M, &N, &K)) { - DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK_WITH_TYPE_2( - K, - BothEndsReduceCUDAKernel, - T, - Reducer, - N, - context->cuda_stream(), - M, - N, - K, - reducer, - init, - alpha, - X, - Y); - return; - } - std::vector axes(ndim); - utils::ComputeTransposeAxesForReduceOp(ndim, Y_dims, axes.data()); - const int outer_size = Y_size; - const int inner_size = X_size / Y_size; - DISPATCH_FUNCTION_BY_VALUE_WITH_TYPE_2( - ndim, - ReduceTensorCUDAImpl, - T, - Reducer, - outer_size, - inner_size, - X_dims, - axes.data(), - reducer, - init, - alpha, - X, - Y, - context); -} - -template -__global__ void -RowwiseMomentsCUDAKernel(const int cols, const T* X, T* mean, T* var) { - __shared__ typename BlockReduce::TempStorage m_storage; - __shared__ typename BlockReduce::TempStorage v_storage; - const T scale = T(1) / static_cast(cols); - const int r = blockIdx.x; - T m_val = 0; - T v_val = 0; - for (int c = threadIdx.x; c < cols; c += blockDim.x) { - const int X_index = r * cols + c; -#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) - m_val += __ldg(X + X_index); - v_val += __ldg(X + X_index) * __ldg(X + X_index); -#else - m_val += X[X_index]; - v_val += X[X_index] * X[X_index]; -#endif - } - m_val = BlockReduce(m_storage).Sum(m_val); - v_val = BlockReduce(v_storage).Sum(v_val); - if (threadIdx.x == 0) { - const T mu = m_val * scale; - mean[r] = mu; - var[r] = v_val * scale - mu * mu; - } -} - -template -__global__ void ColwiseMomentsCUDAKernel( - const int rows, - const int cols, - const T* X, - T* mean, - T* var) { - __shared__ typename BlockReduce::TempStorage m_storage; - __shared__ typename BlockReduce::TempStorage v_storage; - const T scale = T(1) / static_cast(rows); - const int c = blockIdx.x; - T m_val = 0; - T v_val = 0; - for (int r = threadIdx.x; r < rows; r += blockDim.x) { - const int X_index = r * cols + c; -#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) - m_val += __ldg(X + X_index); - v_val += __ldg(X + X_index) * __ldg(X + X_index); -#else - m_val += X[X_index]; - v_val += X[X_index] * X[X_index]; -#endif - } - m_val = BlockReduce(m_storage).Sum(m_val); - v_val = BlockReduce(v_storage).Sum(v_val); - if (threadIdx.x == 0) { - const T mu = m_val * scale; - mean[c] = mu; - var[c] = v_val * scale - mu * mu; - } -} - -template -__global__ void BothEndsMomentsCUDAKernel( - const int M, - const int N, - const int K, - const T* X, - T* mean, - T* var) { - __shared__ - typename BlockReduce2D::TempStorage m_storage; - __shared__ - typename BlockReduce2D::TempStorage v_storage; - const T scale = T(1) / static_cast(M * K); - const int n = blockIdx.x; - T m_val = 0; - T v_val = 0; - for (int m = threadIdx.x; m < M; m += blockDim.x) { - for (int k = threadIdx.y; k < K; k += blockDim.y) { - const int X_index = (m * N + n) * K + k; -#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) - m_val += __ldg(X + X_index); - v_val += __ldg(X + X_index) * __ldg(X + X_index); -#else - m_val += X[X_index]; - v_val += X[X_index] * X[X_index]; -#endif - } - } - m_val = BlockReduce2D(m_storage).Sum(m_val); - v_val = BlockReduce2D(v_storage).Sum(v_val); - if (threadIdx.x == 0 && threadIdx.y == 0) { - const T mu = m_val * scale; - mean[n] = mu; - var[n] = v_val * scale - mu * mu; - } -} - -template -__global__ void MomentsCUDAKernel( - const int inner_size, - const SimpleArray X_strides, - const SimpleArray Y_dims, - const T* X, - T* mean, - T* var) { - __shared__ typename BlockReduce::TempStorage m_storage; - __shared__ typename BlockReduce::TempStorage v_storage; - const T scale = T(1) / static_cast(inner_size); - const int x = blockIdx.x; - T m_val = 0; - T v_val = 0; - for (int y = threadIdx.x; y < inner_size; y += blockDim.x) { - int X_index = 0; - int Y_index = x * inner_size + y; -#pragma unroll - for (int d = D - 1; d >= 0; --d) { - X_index += Y_index % Y_dims.data[d] * X_strides.data[d]; - Y_index /= Y_dims.data[d]; - } -#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) - m_val += __ldg(X + X_index); - v_val += __ldg(X + X_index) * __ldg(X + X_index); -#else - m_val += X[X_index]; - v_val += X[X_index] * X[X_index]; -#endif - } - m_val = BlockReduce(m_storage).Sum(m_val); - v_val = BlockReduce(v_storage).Sum(v_val); - if (threadIdx.x == 0) { - const T mu = m_val * scale; - mean[x] = mu; - var[x] = v_val * scale - mu * mu; - } -} - -template -void MomentsCUDAImpl( - const int outer_size, - const int inner_size, - const int* dims, - const int* axes, - const T* X, - T* mean, - T* var, - CUDAContext* context) { - SimpleArray X_strides; - SimpleArray Y_dims; - utils::ComputeTransposedStrides(D, dims, axes, X_strides.data); - for (int i = 0; i < D; ++i) { - Y_dims.data[i] = dims[axes[i]]; - } - MomentsCUDAKernel - <<cuda_stream()>>>( - inner_size, X_strides, Y_dims, X, mean, var); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -template -void MomentsCUDA( - const int ndim, - const int* X_dims, - const int* Y_dims, - const T* X, - T* mean, - T* var, - CUDAContext* context) { - CAFFE_ENFORCE(utils::CheckReduceDims(ndim, X_dims, Y_dims)); - const int X_size = - std::accumulate(X_dims, X_dims + ndim, 1, std::multiplies()); - const int Y_size = - std::accumulate(Y_dims, Y_dims + ndim, 1, std::multiplies()); - if (X_size == 0) { - Set(Y_size, T(0), mean, context); - Set(Y_size, T(0), var, context); - return; - } - if (std::equal(X_dims, X_dims + ndim, Y_dims)) { - C10_CUDA_CHECK(cudaMemcpyAsync( - mean, - X, - sizeof(T) * X_size, - cudaMemcpyDeviceToDevice, - context->cuda_stream())); - Set(Y_size, T(0), var, context); - return; - } - int rows; - int cols; - if (utils::IsRowwiseReduce(ndim, X_dims, Y_dims, &rows, &cols)) { - RowwiseMomentsCUDAKernel - <<cuda_stream()>>>( - cols, X, mean, var); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - return; - } - if (utils::IsColwiseReduce(ndim, X_dims, Y_dims, &rows, &cols)) { - ColwiseMomentsCUDAKernel - <<cuda_stream()>>>( - rows, cols, X, mean, var); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - return; - } - int M; - int N; - int K; - if (utils::IsBothEndsReduce(ndim, X_dims, Y_dims, &M, &N, &K)) { - DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK_WITH_TYPE_1( - K, - BothEndsMomentsCUDAKernel, - T, - N, - context->cuda_stream(), - M, - N, - K, - X, - mean, - var); - return; - } - std::vector axes(ndim); - utils::ComputeTransposeAxesForReduceOp(ndim, Y_dims, axes.data()); - const int outer_size = Y_size; - const int inner_size = X_size / Y_size; - DISPATCH_FUNCTION_BY_VALUE_WITH_TYPE_1( - ndim, - MomentsCUDAImpl, - T, - outer_size, - inner_size, - X_dims, - axes.data(), - X, - mean, - var, - context); -} - -} // namespace - -#define DELEGATE_CUDA_REDUCE_FUNCTION(T, Func, Reducer, kInit) \ - template <> \ - CAFFE2_CUDA_EXPORT void Func( \ - const int ndim, \ - const int* X_dims, \ - const int* Y_dims, \ - const T alpha, \ - const T* X, \ - T* Y, \ - CUDAContext* context, \ - bool) { \ - ReduceTensorCUDA( \ - ndim, X_dims, Y_dims, Reducer(), kInit, alpha, X, Y, context); \ - } -DELEGATE_CUDA_REDUCE_FUNCTION( - std::int32_t, - ReduceMin, - cub::Min, - std::numeric_limits::max()) -DELEGATE_CUDA_REDUCE_FUNCTION( - std::int64_t, - ReduceMin, - cub::Min, - std::numeric_limits::max()) -DELEGATE_CUDA_REDUCE_FUNCTION( - float, - ReduceMin, - cub::Min, - std::numeric_limits::max()) -DELEGATE_CUDA_REDUCE_FUNCTION( - double, - ReduceMin, - cub::Min, - std::numeric_limits::max()) -DELEGATE_CUDA_REDUCE_FUNCTION( - std::int32_t, - ReduceMax, - cub::Max, - std::numeric_limits::lowest()) -DELEGATE_CUDA_REDUCE_FUNCTION( - std::int64_t, - ReduceMax, - cub::Max, - std::numeric_limits::lowest()) -DELEGATE_CUDA_REDUCE_FUNCTION( - float, - ReduceMax, - cub::Max, - std::numeric_limits::lowest()) -DELEGATE_CUDA_REDUCE_FUNCTION( - double, - ReduceMax, - cub::Max, - std::numeric_limits::lowest()) -DELEGATE_CUDA_REDUCE_FUNCTION(std::int32_t, ReduceSum, cub::Sum, 0) -DELEGATE_CUDA_REDUCE_FUNCTION(std::int64_t, ReduceSum, cub::Sum, 0LL) -DELEGATE_CUDA_REDUCE_FUNCTION(float, ReduceSum, cub::Sum, 0.0f) -DELEGATE_CUDA_REDUCE_FUNCTION(double, ReduceSum, cub::Sum, 0.0) -#undef DELEGATE_CUDA_REDUCE_FUNCTION - -#define CAFFE2_SPECIALIZED_CUDA_REDUCE_MEAN(T) \ - template <> \ - CAFFE2_CUDA_EXPORT void ReduceMean( \ - const int ndim, \ - const int* X_dims, \ - const int* Y_dims, \ - const T alpha, \ - const T* X, \ - T* Y, \ - CUDAContext* context, \ - bool) { \ - int scale = 1; \ - for (int i = 0; i < ndim; ++i) { \ - if (Y_dims[i] == 1) { \ - scale *= X_dims[i]; \ - } \ - } \ - ReduceTensorCUDA( \ - ndim, \ - X_dims, \ - Y_dims, \ - cub::Sum(), \ - T(0), \ - alpha / static_cast(scale), \ - X, \ - Y, \ - context); \ - } -CAFFE2_SPECIALIZED_CUDA_REDUCE_MEAN(float) -#undef CAFFE2_SPECIALIZED_CUDA_REDUCE_MEAN - -#define CAFFE2_SPECIALIZED_CUDA_MOMENTS(T) \ - template <> \ - CAFFE2_CUDA_EXPORT void Moments( \ - const int ndim, \ - const int* X_dims, \ - const int* Y_dims, \ - const T* X, \ - T* mean, \ - T* var, \ - CUDAContext* context, \ - bool) { \ - MomentsCUDA(ndim, X_dims, Y_dims, X, mean, var, context); \ - } -CAFFE2_SPECIALIZED_CUDA_MOMENTS(float) -CAFFE2_SPECIALIZED_CUDA_MOMENTS(double) -#undef CAFFE2_SPECIALIZED_CUDA_MOMENTS - -} // namespace math -} // namespace caffe2 diff --git a/caffe2/utils/math/reduce.cuh b/caffe2/utils/math/reduce.cuh deleted file mode 100644 index 18bdca11b9de..000000000000 --- a/caffe2/utils/math/reduce.cuh +++ /dev/null @@ -1,61 +0,0 @@ -#ifndef CAFFE2_UTILS_MATH_REDUCE_CUH_ -#define CAFFE2_UTILS_MATH_REDUCE_CUH_ - -#include "caffe2/utils/cub_namespace.cuh" -#include - -#include "caffe2/core/common_gpu.h" - -namespace caffe2 { - -template -using BlockReduce = cub::BlockReduce; - -template -using BlockReduce2D = cub:: - BlockReduce; - -#define DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK_WITH_TYPE_1( \ - size, Func, T, grid_dim, cuda_stream, ...) \ - do { \ - if (size >= 128) { \ - Func \ - <<>>(__VA_ARGS__); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } else if (size >= 64) { \ - Func<<>>(__VA_ARGS__); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } else if (size >= 32) { \ - Func<<>>(__VA_ARGS__); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } else { \ - Func<<>>(__VA_ARGS__); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } \ - } while (false) - -#define DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK_WITH_TYPE_2( \ - size, Func, T1, T2, grid_dim, cuda_stream, ...) \ - do { \ - if (size >= 128) { \ - Func \ - <<>>(__VA_ARGS__); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } else if (size >= 64) { \ - Func \ - <<>>(__VA_ARGS__); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } else if (size >= 32) { \ - Func \ - <<>>(__VA_ARGS__); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } else { \ - Func \ - <<>>(__VA_ARGS__); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } \ - } while (false) - -} // namespace caffe2 - -#endif // CAFFE2_UTILS_MATH_REDUCE_CUH_ diff --git a/caffe2/utils/math/transpose.cu b/caffe2/utils/math/transpose.cu deleted file mode 100644 index c3e213190856..000000000000 --- a/caffe2/utils/math/transpose.cu +++ /dev/null @@ -1,233 +0,0 @@ -#include "caffe2/utils/math/transpose.h" - -#include -#include -#include - -#include "caffe2/core/common_gpu.h" -#include "caffe2/core/context_gpu.h" -#include "caffe2/utils/math/utils.h" - -namespace caffe2 { -namespace math { - -namespace { - -constexpr int kTileDim = 32; -constexpr int kBlockRows = 8; - -// Splits the original matrix into submatrices with size 32 * 32. -// Each block transposes one submatrix by loading it into shared memory. -// Reference https://devblogs.nvidia.com/efficient-matrix-transpose-cuda-cc/ -template -__global__ void BatchTranspose2DCUDAKernel( - const TIndex H, - const TIndex W, - const TIndex dh, - const TIndex dw, - const TData* X, - TData* Y) { - __shared__ TData tile[kTileDim][kTileDim + 1]; - const TIndex n = blockIdx.x / (dh * dw); - const TIndex k = blockIdx.x % (dh * dw); - const TIndex r = k / dw; - const TIndex c = k % dw; - const TIndex offset = n * H * W; - int x = c * kTileDim + threadIdx.x; - int y = r * kTileDim + threadIdx.y; - if (x < W) { - for (int i = 0; threadIdx.y + i < kTileDim && y + i < H; i += kBlockRows) { -#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) - tile[threadIdx.y + i][threadIdx.x] = __ldg(X + offset + (y + i) * W + x); -#else - tile[threadIdx.y + i][threadIdx.x] = X[offset + (y + i) * W + x]; -#endif - } - } - __syncthreads(); - x = r * kTileDim + threadIdx.x; - y = c * kTileDim + threadIdx.y; - if (x < H) { - for (int i = 0; threadIdx.y + i < kTileDim && y + i < W; i += kBlockRows) { - Y[offset + (y + i) * H + x] = tile[threadIdx.x][threadIdx.y + i]; - } - } -} - -template -void BatchTranspose2DCUDAImpl( - const TIndex N, - const TIndex H, - const TIndex W, - const TData* X, - TData* Y, - CUDAContext* context) { - const TIndex dh = DivUp(H, kTileDim); - const TIndex dw = DivUp(W, kTileDim); - BatchTranspose2DCUDAKernel - <<cuda_stream()>>>( - H, W, dh, dw, X, Y); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -#define DELEGATE_TRANSPOSE_2D_CUDA_IMPL(TIndex, TData, CuBLASFunc) \ - template <> \ - void BatchTranspose2DCUDAImpl( \ - const TIndex N, \ - const TIndex H, \ - const TIndex W, \ - const TData* X, \ - TData* Y, \ - CUDAContext* context) { \ - if (N == 1) { \ - const TData kAlpha = TData(1); \ - const TData kBeta = TData(0); \ - CUBLAS_ENFORCE(cublasSetPointerMode( \ - context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); \ - CUBLAS_ENFORCE(CuBLASFunc( \ - context->cublas_handle(), \ - CUBLAS_OP_T, \ - CUBLAS_OP_N, \ - H, \ - W, \ - &kAlpha, \ - X, \ - W, \ - &kBeta, \ - Y, \ - H, \ - Y, \ - H)); \ - } else { \ - const TIndex dh = DivUp(H, kTileDim); \ - const TIndex dw = DivUp(W, kTileDim); \ - BatchTranspose2DCUDAKernel \ - <<cuda_stream()>>>(H, W, dh, dw, X, Y); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } \ - } -DELEGATE_TRANSPOSE_2D_CUDA_IMPL(std::int32_t, float, cublasSgeam) -DELEGATE_TRANSPOSE_2D_CUDA_IMPL(std::int64_t, float, cublasSgeam) -DELEGATE_TRANSPOSE_2D_CUDA_IMPL(std::int32_t, double, cublasDgeam) -DELEGATE_TRANSPOSE_2D_CUDA_IMPL(std::int64_t, double, cublasDgeam) -#undef DELEGATE_TRANSPOSE_2D_CUDA_IMPL - -template -__global__ void TransposeCUDAKernel( - const TIndex size, - const SimpleArray X_strides, - const SimpleArray Y_dims, - const TData* X, - TData* Y) { - const int Y_index = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x; - if (Y_index < size) { - TIndex X_index = 0; - TIndex v = Y_index; -#pragma unroll - for (int i = D - 1; i >= 0; --i) { - X_index += v % Y_dims.data[i] * X_strides.data[i]; - v /= Y_dims.data[i]; - } -#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) - Y[Y_index] = __ldg(X + X_index); -#else - Y[Y_index] = X[X_index]; -#endif - } -} - -template -void TransposeCUDAImpl( - const TIndex* dims, - const int* axes, - const TData* X, - TData* Y, - CUDAContext* context) { - SimpleArray X_strides; - SimpleArray Y_dims; - utils::ComputeTransposedStrides(D, dims, axes, X_strides.data); - TIndex size = 1; - for (int i = 0; i < D; ++i) { - Y_dims.data[i] = dims[axes[i]]; - size *= dims[i]; - } - const TIndex M = DivUp(size, CAFFE_CUDA_NUM_THREADS); - TransposeCUDAKernel - <<cuda_stream()>>>( - size, X_strides, Y_dims, X, Y); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -} // namespace - -#define CAFFE2_SPECIALIZED_CUDA_TRANSPOSE(TIndex, TData) \ - template <> \ - CAFFE2_CUDA_EXPORT void Transpose( \ - const int ndim, \ - const TIndex* dims, \ - const int* axes, \ - const TData* X, \ - TData* Y, \ - CUDAContext* context) { \ - const TIndex size = std::accumulate( \ - dims, dims + ndim, TIndex(1), std::multiplies()); \ - if (size == 0) { \ - return; \ - } \ - if (utils::IsIdentityPermutation(ndim, axes)) { \ - context->template CopySameDevice(size, X, Y); \ - return; \ - } \ - if (utils::IsBatchTranspose2D(ndim, axes)) { \ - const int H = dims[ndim - 2]; \ - const int W = dims[ndim - 1]; \ - const int N = size / (H * W); \ - BatchTranspose2DCUDAImpl(N, H, W, X, Y, context); \ - return; \ - } \ - DISPATCH_FUNCTION_BY_VALUE_WITH_TYPE_2( \ - ndim, TransposeCUDAImpl, TIndex, TData, dims, axes, X, Y, context); \ - } -CAFFE2_SPECIALIZED_CUDA_TRANSPOSE(std::int32_t, float) -CAFFE2_SPECIALIZED_CUDA_TRANSPOSE(std::int64_t, float) -CAFFE2_SPECIALIZED_CUDA_TRANSPOSE(std::int32_t, double) -CAFFE2_SPECIALIZED_CUDA_TRANSPOSE(std::int64_t, double) -CAFFE2_SPECIALIZED_CUDA_TRANSPOSE(std::int32_t, std::int32_t) -CAFFE2_SPECIALIZED_CUDA_TRANSPOSE(std::int64_t, std::int32_t) -CAFFE2_SPECIALIZED_CUDA_TRANSPOSE(std::int32_t, std::int64_t) -CAFFE2_SPECIALIZED_CUDA_TRANSPOSE(std::int64_t, std::int64_t) -#undef CAFFE2_SPECIALIZED_CUDA_TRANSPOSE - -#define CAFFE2_SPECIALIZED_CUDA_NCHW2NHWC(T) \ - template <> \ - CAFFE2_CUDA_EXPORT void NCHW2NHWC( \ - const int N, \ - const int C, \ - const int HxW, \ - const T* X, \ - T* Y, \ - CUDAContext* context) { \ - BatchTranspose2DCUDAImpl(N, C, HxW, X, Y, context); \ - } -CAFFE2_SPECIALIZED_CUDA_NCHW2NHWC(float) -#undef CAFFE2_SPECIALIZED_CUDA_NCHW2NHWC - -#define CAFFE2_SPECIALIZED_CUDA_NHWC2NCHW(T) \ - template <> \ - CAFFE2_CUDA_EXPORT void NHWC2NCHW( \ - const int N, \ - const int C, \ - const int HxW, \ - const T* X, \ - T* Y, \ - CUDAContext* context) { \ - BatchTranspose2DCUDAImpl(N, HxW, C, X, Y, context); \ - } -CAFFE2_SPECIALIZED_CUDA_NHWC2NCHW(float) -#undef CAFFE2_SPECIALIZED_CUDA_NHWC2NCHW - -} // namespace math -} // namespace caffe2 diff --git a/caffe2/utils/math_gpu.cu b/caffe2/utils/math_gpu.cu deleted file mode 100644 index e6dfbf85039f..000000000000 --- a/caffe2/utils/math_gpu.cu +++ /dev/null @@ -1,2871 +0,0 @@ -// Implements the math functions for GPU. - -#include "caffe2/utils/math.h" - -#include -#include -#include -#include - -#include -#include -#include "caffe2/utils/cub_namespace.cuh" - -#include -#include -#include - -#include "caffe2/core/context_gpu.h" -#include "caffe2/utils/GpuAtomics.cuh" -#include "caffe2/utils/conversions.h" - -#include "caffe2/utils/fixed_divisor.h" -// TODO: Move this to fixed_divisor.h -#if defined(USE_ROCM) -#define FIXED_DIVISOR int32_t -#define FIXED_DIVISOR_DIV(d, n) (n / d) -#define FIXED_DIVISOR_MOD(d, n) (n % d) -#define FIXED_DIVISOR_DIV_MOD(d, n, q, r) \ - do { \ - const auto n_copy = n; \ - *q = n_copy / d; \ - *r = n_copy % d; \ - } while (0) -#else // USE_ROCM -#define FIXED_DIVISOR FixedDivisor -#define FIXED_DIVISOR_DIV(d, n) (d.Div(n)) -#define FIXED_DIVISOR_MOD(d, n) (d.Mod(n)) -#define FIXED_DIVISOR_DIV_MOD(d, n, q, r) (d.DivMod(n, q, r)) -#endif // USE_ROCM - -#if defined(USE_ROCM) -#define CUBLAS_HALF_TYPE hipblasHalf -#define HIPBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT -// until we use hipblas v2 -// hipify correctly maps things like CUDA_R_16F to HIP_R_16F, -// however hipblas v1 is still using its custom type -#ifndef HIPBLAS_V2 -#define HIP_R_16F HIPBLAS_R_16F -#define HIP_R_32F HIPBLAS_R_32F -#endif // HIPBLAS_V2 -#else // USE_ROCM -#define CUBLAS_HALF_TYPE __half -#endif // USE_ROCM - -#include "caffe2/utils/math/utils.h" - -#if THRUST_VERSION >= 100800 -#define THRUST_SUPPORTS_PER_THREAD -#endif // THRUST_VERSION >= 100800 - -namespace caffe2 { -namespace math { - -namespace { - -#define DELEGATE_SIMPLE_HOST_DEVICE_BINARY_FUNCTOR(Func, expr) \ - template \ - struct Func##Functor { \ - inline __host__ __device__ T \ - operator()(const T& lhs, const T& rhs) const { \ - return lhs expr rhs; \ - } \ - }; \ - template <> \ - struct Func##Functor { \ - inline __host__ __device__ at::Half operator()( \ - const at::Half& lhs, \ - const at::Half& rhs) const { \ - return convert::To(convert::To( \ - lhs) expr convert::To(rhs)); \ - } \ - }; -DELEGATE_SIMPLE_HOST_DEVICE_BINARY_FUNCTOR(Add, +) -DELEGATE_SIMPLE_HOST_DEVICE_BINARY_FUNCTOR(Sub, -) -DELEGATE_SIMPLE_HOST_DEVICE_BINARY_FUNCTOR(Mul, *) -DELEGATE_SIMPLE_HOST_DEVICE_BINARY_FUNCTOR(Div, /) -#undef DELEGATE_SIMPLE_HOST_DEVICE_BINARY_FUNCTOR - -template -__global__ void SimpleBinaryOpCUDAKernel( - const int N, - const BinaryOperator op, - const TIn* A, - const TIn* B, - TOut* C) { - CUDA_1D_KERNEL_LOOP(i, N) { - C[i] = op(A[i], B[i]); - } -} - -template -__global__ void RowwiseBinaryOpCUDAKenel( - const int size, - const FIXED_DIVISOR cols, - const BinaryOperator op, - const TIn* A, - const TIn* B, - TOut* C) { - CUDA_1D_KERNEL_LOOP(C_index, size) { - const int j = FIXED_DIVISOR_MOD(cols, C_index); - const int A_index = broadcast_1st ? j : C_index; - const int B_index = broadcast_1st ? C_index : j; - C[C_index] = op(A[A_index], B[B_index]); - } -} - -template -__global__ void ColwiseBinaryOpCUDAKenel( - const int size, - const FIXED_DIVISOR cols, - const BinaryOperator op, - const TIn* A, - const TIn* B, - TOut* C) { - CUDA_1D_KERNEL_LOOP(C_index, size) { - const int i = FIXED_DIVISOR_DIV(cols, C_index); - const int A_index = broadcast_1st ? i : C_index; - const int B_index = broadcast_1st ? C_index : i; - C[C_index] = op(A[A_index], B[B_index]); - } -} - -template -__global__ void BroadcastBinaryOpCUDAKernel( - const int size, - const SimpleArray A_strides, - const SimpleArray B_strides, - const SimpleArray C_dims, - const BinaryOperator op, - const TIn* A, - const TIn* B, - TOut* C) { - CUDA_1D_KERNEL_LOOP(C_index, size) { - int A_index = 0; - int B_index = 0; - int C_index_val = C_index; -#pragma unroll - for (int i = D - 1; i >= 0; --i) { - int d; - FIXED_DIVISOR_DIV_MOD(C_dims.data[i], C_index_val, &C_index_val, &d); - A_index += d * A_strides.data[i]; - B_index += d * B_strides.data[i]; - } - C[C_index] = op(A[A_index], B[B_index]); - } -} - -template -CAFFE2_CUDA_EXPORT void BinaryOpWith2DBroadcasting( - const int rows, - const int cols, - const bool rowwise_broadcast, - const bool broadcast_1st, - const BinaryOperator& op, - const TIn* A, - const TIn* B, - TOut* C, - CUDAContext* context) { - if (rows == 0 || cols == 0) { - return; - } - const int size = rows * cols; - const FIXED_DIVISOR cols_div(cols); - if (rowwise_broadcast) { - if (broadcast_1st) { - RowwiseBinaryOpCUDAKenel - <<cuda_stream()>>>(size, cols_div, op, A, B, C); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - } else { - RowwiseBinaryOpCUDAKenel - <<cuda_stream()>>>(size, cols_div, op, A, B, C); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - } - } else { - if (broadcast_1st) { - ColwiseBinaryOpCUDAKenel - <<cuda_stream()>>>(size, cols_div, op, A, B, C); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - } else { - ColwiseBinaryOpCUDAKenel - <<cuda_stream()>>>(size, cols_div, op, A, B, C); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - } - } -} - -template -CAFFE2_CUDA_EXPORT void BroadcastBinaryOpImpl( - const int* A_dims, - const int* B_dims, - const int* C_dims, - const BinaryOperator& op, - const TIn* A, - const TIn* B, - TOut* C, - CUDAContext* context) { - SimpleArray A_strides_array; - SimpleArray B_strides_array; - SimpleArray C_dims_array; - int A_stride = 1; - int B_stride = 1; - for (int i = D - 1; i >= 0; --i) { - if (C_dims[i] == 0) { - return; - } - A_strides_array.data[i] = A_dims[i] == 1 ? 0 : A_stride; - B_strides_array.data[i] = B_dims[i] == 1 ? 0 : B_stride; - A_stride *= A_dims[i]; - B_stride *= B_dims[i]; - C_dims_array.data[i] = FIXED_DIVISOR(C_dims[i]); - } - const int size = - std::accumulate(C_dims, C_dims + D, 1, std::multiplies()); - BroadcastBinaryOpCUDAKernel - <<cuda_stream()>>>( - size, A_strides_array, B_strides_array, C_dims_array, op, A, B, C); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -template -CAFFE2_CUDA_EXPORT void BroadcastBinaryOp( - const int A_ndim, - const int* A_dims, - const int B_ndim, - const int* B_dims, - const BinaryOperator& op, - const TIn* A, - const TIn* B, - TOut* C, - CUDAContext* context) { - const int ndim = std::max(A_ndim, B_ndim); - std::vector A_dims_array(ndim); - std::vector B_dims_array(ndim); - std::vector C_dims_array(ndim); - utils::ComputeBroadcastBinaryOpDims( - A_ndim, - A_dims, - B_ndim, - B_dims, - A_dims_array.data(), - B_dims_array.data(), - C_dims_array.data()); - if (A_dims_array == B_dims_array) { - const int size = std::accumulate( - C_dims_array.cbegin(), C_dims_array.cend(), 1, std::multiplies()); - SimpleBinaryOpCUDAKernel - <<cuda_stream()>>>(size, op, A, B, C); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - return; - } - int rows; - int cols; - bool broadcast_1st; - if (utils::IsRowwiseBroadcastBinaryOp( - ndim, - A_dims_array.data(), - B_dims_array.data(), - &rows, - &cols, - &broadcast_1st)) { - BinaryOpWith2DBroadcasting( - rows, cols, true, broadcast_1st, op, A, B, C, context); - return; - } - if (utils::IsColwiseBroadcastBinaryOp( - ndim, - A_dims_array.data(), - B_dims_array.data(), - &rows, - &cols, - &broadcast_1st)) { - BinaryOpWith2DBroadcasting( - rows, cols, false, broadcast_1st, op, A, B, C, context); - return; - } - DISPATCH_FUNCTION_BY_VALUE_WITH_TYPE_3( - ndim, - BroadcastBinaryOpImpl, - TIn, - TOut, - BinaryOperator, - A_dims_array.data(), - B_dims_array.data(), - C_dims_array.data(), - op, - A, - B, - C, - context); -} - -} // namespace - -#define DELEGATE_2D_BROADCAST_CUDA_BINARY_FUNCTION(TIn, TOut, Func, Op) \ - template <> \ - CAFFE2_CUDA_EXPORT void Rowwise##Func( \ - const int rows, \ - const int cols, \ - const TIn* A, \ - const TIn* B, \ - TOut* C, \ - CUDAContext* context) { \ - if (rows == 0 || cols == 0) { \ - return; \ - } \ - const int size = rows * cols; \ - const FIXED_DIVISOR cols_div(cols); \ - RowwiseBinaryOpCUDAKenel, true> \ - <<cuda_stream()>>>(size, cols_div, Op(), A, B, C); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } \ - template <> \ - CAFFE2_CUDA_EXPORT void Rowwise##Func( \ - const int rows, \ - const int cols, \ - const TIn* A, \ - const TIn* B, \ - TOut* C, \ - CUDAContext* context) { \ - if (rows == 0 || cols == 0) { \ - return; \ - } \ - const int size = rows * cols; \ - const FIXED_DIVISOR cols_div(cols); \ - RowwiseBinaryOpCUDAKenel, false> \ - <<cuda_stream()>>>(size, cols_div, Op(), A, B, C); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } \ - template <> \ - CAFFE2_CUDA_EXPORT void Colwise##Func( \ - const int rows, \ - const int cols, \ - const TIn* A, \ - const TIn* B, \ - TOut* C, \ - CUDAContext* context) { \ - if (rows == 0 || cols == 0) { \ - return; \ - } \ - const int size = rows * cols; \ - const FIXED_DIVISOR cols_div(cols); \ - ColwiseBinaryOpCUDAKenel, true> \ - <<cuda_stream()>>>(size, cols_div, Op(), A, B, C); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } \ - template <> \ - CAFFE2_CUDA_EXPORT void Colwise##Func( \ - const int rows, \ - const int cols, \ - const TIn* A, \ - const TIn* B, \ - TOut* C, \ - CUDAContext* context) { \ - if (rows == 0 || cols == 0) { \ - return; \ - } \ - const int size = rows * cols; \ - const FIXED_DIVISOR cols_div(cols); \ - ColwiseBinaryOpCUDAKenel, false> \ - <<cuda_stream()>>>(size, cols_div, Op(), A, B, C); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } - -#define DEFINE_2D_BROADCAST_CUDA_COMPARE_FUNCTION(Func, Op) \ - DELEGATE_2D_BROADCAST_CUDA_BINARY_FUNCTION(std::int32_t, bool, Func, Op) \ - DELEGATE_2D_BROADCAST_CUDA_BINARY_FUNCTION(std::int64_t, bool, Func, Op) \ - DELEGATE_2D_BROADCAST_CUDA_BINARY_FUNCTION(float, bool, Func, Op) \ - DELEGATE_2D_BROADCAST_CUDA_BINARY_FUNCTION(double, bool, Func, Op) \ - DELEGATE_2D_BROADCAST_CUDA_BINARY_FUNCTION(bool, bool, Func, Op) - -DEFINE_2D_BROADCAST_CUDA_COMPARE_FUNCTION(EQ, thrust::equal_to) -DEFINE_2D_BROADCAST_CUDA_COMPARE_FUNCTION(NE, thrust::not_equal_to) -DEFINE_2D_BROADCAST_CUDA_COMPARE_FUNCTION(LT, thrust::less) -DEFINE_2D_BROADCAST_CUDA_COMPARE_FUNCTION(LE, thrust::less_equal) -DEFINE_2D_BROADCAST_CUDA_COMPARE_FUNCTION(GT, thrust::greater) -DEFINE_2D_BROADCAST_CUDA_COMPARE_FUNCTION(GE, thrust::greater_equal) - -#undef DEFINE_2D_BROADCAST_CUDA_COMPARE_FUNCTION - -#define DEFINE_2D_BROADCAST_CUDA_BINARY_FUNCTION(Func, Op) \ - DELEGATE_2D_BROADCAST_CUDA_BINARY_FUNCTION( \ - std::int32_t, std::int32_t, Func, Op) \ - DELEGATE_2D_BROADCAST_CUDA_BINARY_FUNCTION( \ - std::int64_t, std::int64_t, Func, Op) \ - DELEGATE_2D_BROADCAST_CUDA_BINARY_FUNCTION(float, float, Func, Op) \ - DELEGATE_2D_BROADCAST_CUDA_BINARY_FUNCTION(double, double, Func, Op) \ - DELEGATE_2D_BROADCAST_CUDA_BINARY_FUNCTION(at::Half, at::Half, Func, Op) - -DEFINE_2D_BROADCAST_CUDA_BINARY_FUNCTION(Add, AddFunctor) -DEFINE_2D_BROADCAST_CUDA_BINARY_FUNCTION(Sub, SubFunctor) -DEFINE_2D_BROADCAST_CUDA_BINARY_FUNCTION(Mul, MulFunctor) -DEFINE_2D_BROADCAST_CUDA_BINARY_FUNCTION(Div, DivFunctor) - -#undef DEFINE_2D_BROADCAST_CUDA_BINARY_FUNCTION - -DELEGATE_2D_BROADCAST_CUDA_BINARY_FUNCTION(bool, bool, And, thrust::logical_and) -DELEGATE_2D_BROADCAST_CUDA_BINARY_FUNCTION(bool, bool, Or, thrust::logical_or) -DELEGATE_2D_BROADCAST_CUDA_BINARY_FUNCTION(bool, bool, Xor, thrust::bit_xor) - -#define DEFINE_2D_BROADCAST_CUDA_BITWISE_BINARY_FUNCTION(Func, Op) \ - DELEGATE_2D_BROADCAST_CUDA_BINARY_FUNCTION(bool, bool, Func, Op) \ - DELEGATE_2D_BROADCAST_CUDA_BINARY_FUNCTION( \ - std::int32_t, std::int32_t, Func, Op) \ - DELEGATE_2D_BROADCAST_CUDA_BINARY_FUNCTION( \ - std::int64_t, std::int64_t, Func, Op) - -DEFINE_2D_BROADCAST_CUDA_BITWISE_BINARY_FUNCTION(BitwiseAnd, thrust::bit_and) -DEFINE_2D_BROADCAST_CUDA_BITWISE_BINARY_FUNCTION(BitwiseOr, thrust::bit_or) -DEFINE_2D_BROADCAST_CUDA_BITWISE_BINARY_FUNCTION(BitwiseXor, thrust::bit_xor) - -#undef DEFINE_2D_BROADCAST_CUDA_BITWISE_BINARY_FUNCTION - -#undef DELEGATE_2D_BROADCAST_CUDA_BINARY_FUNCTION - -#define DELEGATE_BROADCAST_CUDA_BINARY_FUNCTION(TIn, TOut, Func, Op) \ - template <> \ - CAFFE2_CUDA_EXPORT void Func( \ - const int A_ndim, \ - const int* A_dims, \ - const int B_ndim, \ - const int* B_dims, \ - const TIn* A, \ - const TIn* B, \ - TOut* C, \ - CUDAContext* context) { \ - BroadcastBinaryOp>( \ - A_ndim, A_dims, B_ndim, B_dims, Op(), A, B, C, context); \ - } - -#define DEFINE_BROADCAST_CUDA_COMPARE_FUNCTION(Func, Op) \ - DELEGATE_BROADCAST_CUDA_BINARY_FUNCTION(std::int32_t, bool, Func, Op) \ - DELEGATE_BROADCAST_CUDA_BINARY_FUNCTION(std::int64_t, bool, Func, Op) \ - DELEGATE_BROADCAST_CUDA_BINARY_FUNCTION(float, bool, Func, Op) \ - DELEGATE_BROADCAST_CUDA_BINARY_FUNCTION(double, bool, Func, Op) \ - DELEGATE_BROADCAST_CUDA_BINARY_FUNCTION(bool, bool, Func, Op) - -DEFINE_BROADCAST_CUDA_COMPARE_FUNCTION(EQ, thrust::equal_to) -DEFINE_BROADCAST_CUDA_COMPARE_FUNCTION(NE, thrust::not_equal_to) -DEFINE_BROADCAST_CUDA_COMPARE_FUNCTION(LT, thrust::less) -DEFINE_BROADCAST_CUDA_COMPARE_FUNCTION(LE, thrust::less_equal) -DEFINE_BROADCAST_CUDA_COMPARE_FUNCTION(GT, thrust::greater) -DEFINE_BROADCAST_CUDA_COMPARE_FUNCTION(GE, thrust::greater_equal) - -#undef DEFINE_BROADCAST_CUDA_COMPARE_FUNCTION - -#define DEFINE_BROADCAST_CUDA_BINARY_FUNCTION(Func, Op) \ - DELEGATE_BROADCAST_CUDA_BINARY_FUNCTION( \ - std::int32_t, std::int32_t, Func, Op) \ - DELEGATE_BROADCAST_CUDA_BINARY_FUNCTION( \ - std::int64_t, std::int64_t, Func, Op) \ - DELEGATE_BROADCAST_CUDA_BINARY_FUNCTION(float, float, Func, Op) \ - DELEGATE_BROADCAST_CUDA_BINARY_FUNCTION(double, double, Func, Op) \ - DELEGATE_BROADCAST_CUDA_BINARY_FUNCTION(at::Half, at::Half, Func, Op) - -DEFINE_BROADCAST_CUDA_BINARY_FUNCTION(Add, AddFunctor) -DEFINE_BROADCAST_CUDA_BINARY_FUNCTION(Sub, SubFunctor) -DEFINE_BROADCAST_CUDA_BINARY_FUNCTION(Mul, MulFunctor) -DEFINE_BROADCAST_CUDA_BINARY_FUNCTION(Div, DivFunctor) - -#undef DEFINE_BROADCAST_CUDA_BINARY_FUNCTION - -DELEGATE_BROADCAST_CUDA_BINARY_FUNCTION(bool, bool, And, thrust::logical_and) -DELEGATE_BROADCAST_CUDA_BINARY_FUNCTION(bool, bool, Or, thrust::logical_or) -DELEGATE_BROADCAST_CUDA_BINARY_FUNCTION(bool, bool, Xor, thrust::bit_xor) - -#define DEFINE_BROADCAST_CUDA_BITWISE_BINARY_FUNCTION(Func, Op) \ - DELEGATE_BROADCAST_CUDA_BINARY_FUNCTION(bool, bool, Func, Op) \ - DELEGATE_BROADCAST_CUDA_BINARY_FUNCTION( \ - std::int32_t, std::int32_t, Func, Op) \ - DELEGATE_BROADCAST_CUDA_BINARY_FUNCTION(std::int64_t, std::int64_t, Func, Op) - -DEFINE_BROADCAST_CUDA_BITWISE_BINARY_FUNCTION(BitwiseAnd, thrust::bit_and) -DEFINE_BROADCAST_CUDA_BITWISE_BINARY_FUNCTION(BitwiseOr, thrust::bit_or) -DEFINE_BROADCAST_CUDA_BITWISE_BINARY_FUNCTION(BitwiseXor, thrust::bit_xor) - -#undef DEFINE_BROADCAST_CUDA_BITWISE_BINARY_FUNCTION - -#undef DELEGATE_BROADCAST_CUDA_BINARY_FUNCTION - -#define DELEGATE_REDUCTION_FUNCTION(T, Funcname, func) \ - template <> \ - CAFFE2_CUDA_EXPORT void Funcname( \ - const int N, \ - const T* src, \ - T* dst, \ - Tensor* scratch_ptr, \ - CUDAContext* context) { \ - size_t memRequired = 0; \ - cub::DeviceReduce::func( \ - nullptr, memRequired, src, dst, N, context->cuda_stream()); \ - auto buffer_size = \ - static_cast((memRequired + sizeof(T) - 1) / sizeof(T)); \ - scratch_ptr->Resize(std::vector{buffer_size}); \ - cub::DeviceReduce::func( \ - static_cast(scratch_ptr->mutable_data()), \ - memRequired, \ - src, \ - dst, \ - N, \ - context->cuda_stream()); \ - } - -DELEGATE_REDUCTION_FUNCTION(float, ReduceMin, Min) -DELEGATE_REDUCTION_FUNCTION(float, ReduceMax, Max) -DELEGATE_REDUCTION_FUNCTION(int32_t, ReduceMax, Max) -DELEGATE_REDUCTION_FUNCTION(int64_t, ReduceMax, Max) - -#undef DELEGATE_REDUCTION_FUNCTION - -// Caffe2 gemm provides a simpler interface to the gemm functions, with the -// limitation that the data has to be contiguous in memory. -template <> -CAFFE2_CUDA_EXPORT void Gemm( - const CBLAS_TRANSPOSE trans_A, - const CBLAS_TRANSPOSE trans_B, - const int M, - const int N, - const int K, - const float alpha, - const float* A, - const float* B, - const float beta, - float* C, - CUDAContext* context, - TensorProto::DataType math_type) { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - const int lda = (trans_A == CblasNoTrans) ? K : M; - const int ldb = (trans_B == CblasNoTrans) ? N : K; - const cublasOperation_t cu_trans_A = - (trans_A == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - const cublasOperation_t cu_trans_B = - (trans_B == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - CUBLAS_ENFORCE( - cublasSetPointerMode(context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); - CUBLAS_ENFORCE(cublasSgemm( - context->cublas_handle(), - cu_trans_B, - cu_trans_A, - N, - M, - K, - &alpha, - B, - ldb, - A, - lda, - &beta, - C, - N)); -} - -template <> -CAFFE2_CUDA_EXPORT void Gemm( - const CBLAS_TRANSPOSE trans_A, - const CBLAS_TRANSPOSE trans_B, - const int M, - const int N, - const int K, - const float alpha, - const at::Half* A, - const at::Half* B, - const float beta, - at::Half* C, - CUDAContext* context, - TensorProto::DataType math_type) { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - const int lda = (trans_A == CblasNoTrans) ? K : M; - const int ldb = (trans_B == CblasNoTrans) ? N : K; - const cublasOperation_t cu_trans_A = - (trans_A == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - const cublasOperation_t cu_trans_B = - (trans_B == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - if (math_type == TensorProto_DataType_FLOAT) { - CUBLAS_ENFORCE(cublasSetPointerMode( - context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); -#if defined(USE_ROCM) - // hipblas doesn't support hipblasSgemmEx type API. - // It has more general hipblasGemmEx API which is more close to cublasGemmEx. - // hipblasGemmEx does D = alpha*op( A )*op( B ) + beta*C, - // whereas cublasSgemmEx does C = alpha*op( A )*op( B ) + beta*C - HIPBLAS_ENFORCE(hipblasGemmEx( - context->hipblas_handle(), - cu_trans_B, - cu_trans_A, - N, - M, - K, - &alpha, - B, - HIPBLAS_R_16F, - ldb, - A, - HIPBLAS_R_16F, - lda, - &beta, - C, - HIPBLAS_R_16F, - N, - HIPBLAS_COMPUTE_32F, - HIPBLAS_GEMM_DEFAULT)); -#else - CUBLAS_ENFORCE(cublasSgemmEx( - context->cublas_handle(), - cu_trans_B, - cu_trans_A, - N, - M, - K, - &alpha, - B, - CUDA_R_16F, - ldb, - A, - CUDA_R_16F, - lda, - &beta, - C, - CUDA_R_16F, - N)); -#endif // USE_ROCM - } else if (math_type == TensorProto_DataType_FLOAT16) { - // convert alpha, beta from float -> __half - const __half alpha_fp16 = at::Half(alpha); - const __half beta_fp16 = at::Half(beta); - // call cublasHgemm - CUBLAS_ENFORCE(cublasSetPointerMode( - context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); - CUBLAS_ENFORCE(cublasHgemm( - context->cublas_handle(), - cu_trans_B, - cu_trans_A, - N, - M, - K, - reinterpret_cast(&alpha_fp16), - reinterpret_cast(B), - ldb, - reinterpret_cast(A), - lda, - reinterpret_cast(&beta_fp16), - reinterpret_cast(C), - N)); - } else { - // fail - CAFFE_THROW("Unsupported math type"); - } -} - -template <> -CAFFE2_CUDA_EXPORT void BiasCHW( - const float* bias, - const float* bias_multiplier, - const int bias_channels, - const int image_size, - float* image, - CUDAContext* context) { - Gemm( - CblasNoTrans, - CblasNoTrans, - bias_channels, - image_size, - 1, - 1, - bias, - bias_multiplier, - 1, - image, - context); -} - -template <> -CAFFE2_CUDA_EXPORT void GemmBatched( - const CBLAS_TRANSPOSE trans_A, - const CBLAS_TRANSPOSE trans_B, - const int batch_size, - const int M, - const int N, - const int K, - const float alpha, - const float** A, - const float** B, - const float beta, - float** C, - CUDAContext* context, - TensorProto::DataType math_type) { -#if defined(USE_ROCM) - // loop over matrices in the batch - for (int i = 0; i < batch_size; ++i) { - Gemm( - trans_A, - trans_B, - M, - N, - K, - alpha, - A[i], - B[i], - beta, - C[i], - context, - math_type); - } -#else - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - const int lda = (trans_A == CblasNoTrans) ? K : M; - const int ldb = (trans_B == CblasNoTrans) ? N : K; - const int ldc = N; - const cublasOperation_t cu_trans_A = - (trans_A == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - const cublasOperation_t cu_trans_B = - (trans_B == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - thrust::device_vector A_device(A, A + batch_size); - thrust::device_vector B_device(B, B + batch_size); - thrust::device_vector C_device(C, C + batch_size); - CUBLAS_ENFORCE( - cublasSetPointerMode(context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); - CUBLAS_ENFORCE(cublasSgemmBatched( - context->cublas_handle(), - cu_trans_B, - cu_trans_A, - N, - M, - K, - &alpha, - B_device.data().get(), - ldb, - A_device.data().get(), - lda, - &beta, - C_device.data().get(), - ldc, - batch_size)); -#endif -} - -template <> -CAFFE2_CUDA_EXPORT void GemmStridedBatched( - const CBLAS_TRANSPOSE trans_A, - const CBLAS_TRANSPOSE trans_B, - const int batch_size, - const int M, - const int N, - const int K, - const float alpha, - const float* A, - const int A_stride, - const float* B, - const int B_stride, - const float beta, - float* C, - const int C_stride, - CUDAContext* context, - TensorProto::DataType math_type) { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - const int lda = (trans_A == CblasNoTrans) ? K : M; - const int ldb = (trans_B == CblasNoTrans) ? N : K; - const int ldc = N; - const cublasOperation_t cu_trans_A = - (trans_A == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - const cublasOperation_t cu_trans_B = - (trans_B == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - CUBLAS_ENFORCE( - cublasSetPointerMode(context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); - CUBLAS_ENFORCE(cublasSgemmStridedBatched( - context->cublas_handle(), - cu_trans_B, - cu_trans_A, - N, - M, - K, - &alpha, - B, - ldb, - B_stride, - A, - lda, - A_stride, - &beta, - C, - ldc, - C_stride, - batch_size)); -} - -template <> -CAFFE2_CUDA_EXPORT void GemmBatched( - const CBLAS_TRANSPOSE trans_A, - const CBLAS_TRANSPOSE trans_B, - const int batch_size, - const int M, - const int N, - const int K, - const float alpha, - const at::Half** A, - const at::Half** B, - const float beta, - at::Half** C, - CUDAContext* context, - TensorProto::DataType math_type) { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - const int lda = (trans_A == CblasNoTrans) ? K : M; - const int ldb = (trans_B == CblasNoTrans) ? N : K; - const int ldc = N; - const cublasOperation_t cu_trans_A = - (trans_A == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - const cublasOperation_t cu_trans_B = - (trans_B == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - if (math_type == TensorProto_DataType_FLOAT) { - thrust::device_vector A_device(A, A + batch_size); - thrust::device_vector B_device(B, B + batch_size); - thrust::device_vector C_device(C, C + batch_size); - CUBLAS_ENFORCE(cublasSetPointerMode( - context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); -#if defined(USE_ROCM) - auto compute_type = HIPBLAS_COMPUTE_32F; -#else - auto compute_type = CUDA_R_32F; -#endif - CUBLAS_ENFORCE(cublasGemmBatchedEx( - context->cublas_handle(), - cu_trans_B, - cu_trans_A, - N, - M, - K, - &alpha, - B_device.data().get(), - CUDA_R_16F, - ldb, - A_device.data().get(), - CUDA_R_16F, - lda, - &beta, - C_device.data().get(), - CUDA_R_16F, - ldc, - batch_size, - compute_type, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - } else if (math_type == TensorProto_DataType_FLOAT16) { - // Convert alpha, beta from float -> __half - const __half alpha_fp16 = at::Half(alpha); - const __half beta_fp16 = at::Half(beta); - thrust::host_vector A_array(batch_size); - thrust::host_vector B_array(batch_size); - thrust::host_vector<__half*> C_array(batch_size); - for (int i = 0; i < batch_size; ++i) { - A_array[i] = reinterpret_cast(A[i]); - B_array[i] = reinterpret_cast(B[i]); - C_array[i] = reinterpret_cast<__half*>(C[i]); - } - thrust::device_vector A_device( - A_array.cbegin(), A_array.cend()); - thrust::device_vector B_device( - B_array.cbegin(), B_array.cend()); - thrust::device_vector<__half*> C_device(C_array.cbegin(), C_array.cend()); - CUBLAS_ENFORCE(cublasSetPointerMode( - context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); - CUBLAS_ENFORCE(cublasHgemmBatched( - context->cublas_handle(), - cu_trans_B, - cu_trans_A, - N, - M, - K, - reinterpret_cast(&alpha_fp16), - reinterpret_cast(B_device.data().get()), - ldb, - reinterpret_cast(A_device.data().get()), - lda, - reinterpret_cast(&beta_fp16), - reinterpret_cast(C_device.data().get()), - ldc, - batch_size)); - } else { - CAFFE_THROW("Unsupported math type"); - } -} - -template <> -CAFFE2_CUDA_EXPORT void GemmStridedBatched( - const CBLAS_TRANSPOSE trans_A, - const CBLAS_TRANSPOSE trans_B, - const int batch_size, - const int M, - const int N, - const int K, - const float alpha, - const at::Half* A, - const int A_stride, - const at::Half* B, - const int B_stride, - const float beta, - at::Half* C, - const int C_stride, - CUDAContext* context, - TensorProto::DataType math_type) { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - const int lda = (trans_A == CblasNoTrans) ? K : M; - const int ldb = (trans_B == CblasNoTrans) ? N : K; - const int ldc = N; - const cublasOperation_t cu_trans_A = - (trans_A == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - const cublasOperation_t cu_trans_B = - (trans_B == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - if (math_type == TensorProto_DataType_FLOAT) { - CUBLAS_ENFORCE(cublasSetPointerMode( - context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); -#if defined(USE_ROCM) - auto compute_type = HIPBLAS_COMPUTE_32F; -#else - auto compute_type = CUDA_R_32F; -#endif - CUBLAS_ENFORCE(cublasGemmStridedBatchedEx( - context->cublas_handle(), - cu_trans_B, - cu_trans_A, - N, - M, - K, - &alpha, - B, - CUDA_R_16F, - ldb, - B_stride, - A, - CUDA_R_16F, - lda, - A_stride, - &beta, - C, - CUDA_R_16F, - ldc, - C_stride, - batch_size, - compute_type, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - } else if (math_type == TensorProto_DataType_FLOAT16) { - // Convert alpha, beta from float -> __half - const __half alpha_fp16 = at::Half(alpha); - const __half beta_fp16 = at::Half(beta); - CUBLAS_ENFORCE(cublasSetPointerMode( - context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); - CUBLAS_ENFORCE(cublasHgemmStridedBatched( - context->cublas_handle(), - cu_trans_B, - cu_trans_A, - N, - M, - K, - reinterpret_cast(&alpha_fp16), - reinterpret_cast(B), - ldb, - B_stride, - reinterpret_cast(A), - lda, - A_stride, - reinterpret_cast(&beta_fp16), - reinterpret_cast(C), - ldc, - C_stride, - batch_size)); - } else { - CAFFE_THROW("Unsupported math type"); - } -} - -template <> -CAFFE2_CUDA_EXPORT void Gemv( - const CBLAS_TRANSPOSE trans_A, - const int M, - const int N, - const float alpha, - const float* A, - const float* x, - const float beta, - float* y, - CUDAContext* context, - TensorProto::DataType math_type) { - const cublasOperation_t cu_trans_A = - (trans_A == CblasNoTrans) ? CUBLAS_OP_T : CUBLAS_OP_N; - CUBLAS_ENFORCE( - cublasSetPointerMode(context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); - CUBLAS_ENFORCE(cublasSgemv( - context->cublas_handle(), - cu_trans_A, - N, - M, - &alpha, - A, - N, - x, - 1, - &beta, - y, - 1)); -} - -template <> -CAFFE2_CUDA_EXPORT void Gemv( - const CBLAS_TRANSPOSE trans_A, - const int M, - const int N, - const float alpha, - const at::Half* A, - const at::Half* x, - const float beta, - at::Half* y, - CUDAContext* context, - TensorProto::DataType math_type) { - const cublasOperation_t cu_trans_A = - (trans_A == CblasNoTrans) ? CUBLAS_OP_T : CUBLAS_OP_N; - - // sort out what we need to call cublasSgemmEx / cublasHgemm - const int m = (cu_trans_A == CUBLAS_OP_N) ? N : M; - const int k = (cu_trans_A == CUBLAS_OP_N) ? M : N; - const int lda = (cu_trans_A == CUBLAS_OP_N) ? m : k; - const int ldc = m; - - if (math_type == TensorProto_DataType_FLOAT) { - CUBLAS_ENFORCE(cublasSetPointerMode( - context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); -#if defined(USE_ROCM) - // hipblas doesn't support hipblasSgemmEx type API. - // It has more general hipblasGemmEx API which is more close to cublasGemmEx. - // hipblasGemmEx does D = alpha*op( A )*op( B ) + beta*C, - // whereas cublasSgemmEx does C = alpha*op( A )*op( B ) + beta*C - HIPBLAS_ENFORCE(hipblasGemmEx( - context->hipblas_handle(), - cu_trans_A, - HIPBLAS_OP_N, - m, - 1, - k, - &alpha, - A, - HIPBLAS_R_16F, - lda, - x, - HIPBLAS_R_16F, - k, - &beta, - y, - HIPBLAS_R_16F, - ldc, - HIPBLAS_COMPUTE_32F, - HIPBLAS_GEMM_DEFAULT)); -#else - CUBLAS_ENFORCE(cublasSgemmEx( - context->cublas_handle(), - cu_trans_A, - CUBLAS_OP_N, - m, - 1, - k, - &alpha, - A, - CUDA_R_16F, - lda, - x, - CUDA_R_16F, - k, - &beta, - y, - CUDA_R_16F, - ldc)); -#endif // USE_ROCM - } else if (math_type == TensorProto_DataType_FLOAT16) { - const __half alpha_fp16 = at::Half(alpha); - const __half beta_fp16 = at::Half(beta); - CUBLAS_ENFORCE(cublasSetPointerMode( - context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); - CUBLAS_ENFORCE(cublasHgemm( - context->cublas_handle(), - cu_trans_A, - CUBLAS_OP_N, - m, - 1, - k, - reinterpret_cast(&alpha_fp16), - reinterpret_cast(A), - lda, - reinterpret_cast(x), - k, - reinterpret_cast(&beta_fp16), - reinterpret_cast(y), - ldc)); - } else { - // fail - CAFFE_THROW("Unsupported math type"); - } -} - -#if !defined(USE_ROCM) - -// No change, but required. Defer to default CUDA engine -template <> -CAFFE2_CUDA_EXPORT void Gemm( - const CBLAS_TRANSPOSE trans_A, - const CBLAS_TRANSPOSE trans_B, - const int M, - const int N, - const int K, - const float alpha, - const float* A, - const float* B, - const float beta, - float* C, - CUDAContext* context, - TensorProto::DataType math_type) { - return Gemm( - trans_A, trans_B, M, N, K, alpha, A, B, beta, C, context, math_type); -} - -template <> -CAFFE2_CUDA_EXPORT void Gemm( - const CBLAS_TRANSPOSE trans_A, - const CBLAS_TRANSPOSE trans_B, - const int M, - const int N, - const int K, - const float alpha, - const at::Half* A, - const at::Half* B, - const float beta, - at::Half* C, - CUDAContext* context, - TensorProto::DataType math_type) { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - const int lda = (trans_A == CblasNoTrans) ? K : M; - const int ldb = (trans_B == CblasNoTrans) ? N : K; - const cublasOperation_t cu_trans_A = - (trans_A == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - const cublasOperation_t cu_trans_B = - (trans_B == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - - // enable TensorCore for this call on this handle - if (TensorCoreAvailable()) { - CUBLAS_ENFORCE( - cublasSetMathMode(context->cublas_handle(), CUBLAS_TENSOR_OP_MATH)); - } - - CUBLAS_ENFORCE( - cublasSetPointerMode(context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); - CUBLAS_ENFORCE(cublasGemmEx( - context->cublas_handle(), - cu_trans_B, - cu_trans_A, - N, - M, - K, - &alpha, - B, - CUDA_R_16F, - ldb, - A, - CUDA_R_16F, - lda, - &beta, - C, - CUDA_R_16F, - N, - CUDA_R_32F, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - - // Now disable TensorCore math for subsequent calls to this handle - if (TensorCoreAvailable()) { - CUBLAS_ENFORCE( - cublasSetMathMode(context->cublas_handle(), CUBLAS_DEFAULT_MATH)); - } -} - -template <> -CAFFE2_CUDA_EXPORT void GemmBatched( - const CBLAS_TRANSPOSE trans_A, - const CBLAS_TRANSPOSE trans_B, - const int batch_size, - const int M, - const int N, - const int K, - const float alpha, - const float** A, - const float** B, - const float beta, - float** C, - CUDAContext* context, - TensorProto::DataType math_type) { - GemmBatched( - trans_A, - trans_B, - batch_size, - M, - N, - K, - alpha, - A, - B, - beta, - C, - context, - math_type); -} - -template <> -CAFFE2_CUDA_EXPORT void GemmBatched( - const CBLAS_TRANSPOSE trans_A, - const CBLAS_TRANSPOSE trans_B, - const int batch_size, - const int M, - const int N, - const int K, - const float alpha, - const at::Half** A, - const at::Half** B, - const float beta, - at::Half** C, - CUDAContext* context, - TensorProto::DataType math_type) { - GemmBatched( - trans_A, - trans_B, - batch_size, - M, - N, - K, - alpha, - A, - B, - beta, - C, - context, - math_type); -} - -template <> -CAFFE2_CUDA_EXPORT void -GemmStridedBatched( - const CBLAS_TRANSPOSE trans_A, - const CBLAS_TRANSPOSE trans_B, - const int batch_size, - const int M, - const int N, - const int K, - const float alpha, - const float* A, - const int A_stride, - const float* B, - const int B_stride, - const float beta, - float* C, - const int C_stride, - CUDAContext* context, - TensorProto::DataType math_type) { - GemmStridedBatched( - trans_A, - trans_B, - batch_size, - M, - N, - K, - alpha, - A, - A_stride, - B, - B_stride, - beta, - C, - C_stride, - context, - math_type); -} - -template <> -CAFFE2_CUDA_EXPORT void -GemmStridedBatched( - const CBLAS_TRANSPOSE trans_A, - const CBLAS_TRANSPOSE trans_B, - const int batch_size, - const int M, - const int N, - const int K, - const float alpha, - const at::Half* A, - const int A_stride, - const at::Half* B, - const int B_stride, - const float beta, - at::Half* C, - const int C_stride, - CUDAContext* context, - TensorProto::DataType math_type) { - GemmStridedBatched( - trans_A, - trans_B, - batch_size, - M, - N, - K, - alpha, - A, - A_stride, - B, - B_stride, - beta, - C, - C_stride, - context, - math_type); -} - -template <> -CAFFE2_CUDA_EXPORT void Gemv( - const CBLAS_TRANSPOSE trans_A, - const int M, - const int N, - const float alpha, - const float* A, - const float* x, - const float beta, - float* y, - CUDAContext* context, - TensorProto::DataType math_type) { - Gemv( - trans_A, M, N, alpha, A, x, beta, y, context, math_type); -} - -template <> -CAFFE2_CUDA_EXPORT void Gemv( - const CBLAS_TRANSPOSE trans_A, - const int M, - const int N, - const float alpha, - const at::Half* A, - const at::Half* x, - const float beta, - at::Half* y, - CUDAContext* context, - TensorProto::DataType math_type) { - Gemv( - trans_A, M, N, alpha, A, x, beta, y, context, math_type); -} - -#endif - -template <> -CAFFE2_CUDA_EXPORT void GemmEx( - const CBLAS_TRANSPOSE trans_A, - const CBLAS_TRANSPOSE trans_B, - const int M, - const int N, - const int K, - const float alpha, - const float* A, - const int lda, - const float* B, - const int ldb, - const float beta, - float* C, - const int ldc, - CUDAContext* context) { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - const cublasOperation_t cu_trans_A = - (trans_A == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - const cublasOperation_t cu_trans_B = - (trans_B == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - CUBLAS_ENFORCE( - cublasSetPointerMode(context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); - CUBLAS_ENFORCE(cublasSgemm( - context->cublas_handle(), - cu_trans_B, - cu_trans_A, - N, - M, - K, - &alpha, - B, - ldb, - A, - lda, - &beta, - C, - ldc)); -} - -// Batched Add variants -namespace { - -template -__global__ void AddStripedBatchKernel( - const int N, - const T* first, - T* Y, - const int stripe, - const int batch) { - for (int j = 0; j < batch; j++) { - const T* x = first + j * stripe; - CUDA_1D_KERNEL_LOOP(i, N) { - float tmpY = convert::To(Y[i]); - tmpY += convert::To(x[i]); - Y[i] = convert::To(tmpY); - } - } -} -} // namespace - -#define CAFFE2_SPECIALIZED_CUDA_ADD_STRIPED_BATCH(T) \ - template <> \ - CAFFE2_CUDA_EXPORT void AddStripedBatch( \ - const int N, \ - const T* first, \ - T* Y, \ - const int stripe, \ - const int batch, \ - CUDAContext* context) { \ - AddStripedBatchKernel \ - <<cuda_stream()>>>(N, first, Y, stripe, batch); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } - -CAFFE2_SPECIALIZED_CUDA_ADD_STRIPED_BATCH(float); -CAFFE2_SPECIALIZED_CUDA_ADD_STRIPED_BATCH(at::Half); -#undef CAFFE2_SPECIALIZED_CUDA_ADD_STRIPED_BATCH - -namespace { -template -__global__ void -UniformShift(const size_t N, const float min, const float max, T* x) { - float scale = max - min; - CUDA_1D_KERNEL_LOOP(i, N) { - x[i] = convert::To(convert::To(x[i]) * scale + min); - } -} - -__global__ void -UniformIntFit(const size_t N, const int min, const int max, unsigned int* x) { - int* x_int = reinterpret_cast(x); - int range = (max - min + 1); - CUDA_1D_KERNEL_LOOP(i, N) { - x_int[i] = min + static_cast(x[i] % range); - } -} -} // namespace - -template <> -CAFFE2_CUDA_EXPORT void RandUniform( - const size_t n, - const float min, - const float max, - float* r, - CUDAContext* context) { - CURAND_ENFORCE(curandGenerateUniform(context->curand_generator(), r, n)); - UniformShift - <<cuda_stream()>>>(n, min, max, r); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -template <> -CAFFE2_CUDA_EXPORT void RandUniform( - const size_t n, - const double min, - const double max, - double* r, - CUDAContext* context) { - CURAND_ENFORCE( - curandGenerateUniformDouble(context->curand_generator(), r, n)); - UniformShift - <<cuda_stream()>>>(n, min, max, r); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -template <> -CAFFE2_CUDA_EXPORT void RandUniform( - const size_t n, - const int min, - const int max, - int* r, - CUDAContext* context) { - CURAND_ENFORCE(curandGenerate( - context->curand_generator(), reinterpret_cast(r), n)); - UniformIntFit<<< - CAFFE_GET_BLOCKS(n), - CAFFE_CUDA_NUM_THREADS, - 0, - context->cuda_stream()>>>( - n, min, max, reinterpret_cast(r)); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -template -size_t HandleOddLengthRandGaussian( - const size_t n, - const T mean, - const T std, - T* r, - CUDAContext* context) { - if (n % 2 == 1) { - std::default_random_engine generator; - std::normal_distribution distribution(mean, std); - const T random_value = distribution(generator); - Set(1, random_value, r + (n - 1), context); - return n - 1; - } - return n; -} - -template <> -CAFFE2_CUDA_EXPORT void RandGaussian( - const size_t n, - const float mean, - const float std, - float* r, - CUDAContext* context) { - // If n is odd, we add a random Gaussian value at the end manually - // and generate n-1 random values using curandGenerateNormal. - // curandGenerateNormal requires n to be even. - const size_t even_n = - HandleOddLengthRandGaussian(n, mean, std, r, context); - CURAND_ENFORCE( - curandGenerateNormal(context->curand_generator(), r, even_n, mean, std)); -} - -template <> -CAFFE2_CUDA_EXPORT void RandGaussian( - const size_t n, - const double mean, - const double std, - double* r, - CUDAContext* context) { - const size_t even_n = - HandleOddLengthRandGaussian(n, mean, std, r, context); - CURAND_ENFORCE(curandGenerateNormalDouble( - context->curand_generator(), r, even_n, mean, std)); -} - -template <> -CAFFE2_CUDA_EXPORT void Dot( - const int n, - const float* a, - const float* b, - float* y, - CUDAContext* context) { - CUBLAS_ENFORCE(cublasSetPointerMode( - context->cublas_handle(), CUBLAS_POINTER_MODE_DEVICE)); - CUBLAS_ENFORCE(cublasSdot(context->cublas_handle(), n, a, 1, b, 1, y)); -} - -template <> -CAFFE2_CUDA_EXPORT void Dot( - const int n, - const at::Half* a, - const at::Half* b, - at::Half* y, - CUDAContext* context) { - // execute with 32-bit math - CUBLAS_ENFORCE(cublasSetPointerMode( - context->cublas_handle(), CUBLAS_POINTER_MODE_DEVICE)); - CUBLAS_ENFORCE(cublasDotEx( - context->cublas_handle(), - n, - a, - CUDA_R_16F, - 1, - b, - CUDA_R_16F, - 1, - y, - CUDA_R_16F, - CUDA_R_32F)); -} - -// A previous version of caffe2 used Thrust but it turns out that thrust -// reduction has an implicit scratch space allocation and deallocation, which -// may interfere with NCCL and create a deadlock. Hence we are using a custom -// reduction here. -#define SUM_KERNEL_NTHREADS 128 -template -__global__ void SumKernel(const int N, const T* X, T* Y, bool square) { - const int idx = threadIdx.x; - __shared__ float reduction_buffer[SUM_KERNEL_NTHREADS]; - - reduction_buffer[idx] = 0; - - // A multilevel reduction. - // N -> 128 - if (!square) { - for (int i = idx; i < N; i += SUM_KERNEL_NTHREADS) { - reduction_buffer[idx] += convert::To(X[i]); - } - } else { - for (int i = idx; i < N; i += SUM_KERNEL_NTHREADS) { - float Xi = convert::To(X[i]); - reduction_buffer[idx] += Xi * Xi; - } - } - __syncthreads(); - // 128 -> 32 - if (idx < 32) { - reduction_buffer[idx] += reduction_buffer[idx + 32] + - reduction_buffer[idx + 64] + reduction_buffer[idx + 96]; - } - __syncthreads(); - // 32 -> 1 - if (idx == 0) { - float tmp = 0; - for (int i = 0; i < 32; ++i) { - tmp += reduction_buffer[i]; - } - *Y = convert::To(tmp); - } -} - -// According to the benchmarks script -// caffe2/caffe2/experiments/python/device_reduce_sum_bench.py, -// device reduce is slower for N <= 10000. -#define DEVICE_REDUCE_SIZE_THRESHOLD 10000 - -namespace { - -template -__global__ void SumConvertKernel(float* sum, T* dest) { - *dest = convert::To(*sum); -} - -template -CAFFE2_CUDA_EXPORT void SumGenericIter( - const int N, - IterT it, - T*& dest, - CUDAContext* context, - Tensor* scratch_ptr) { - size_t memRequired = 0; - cub::DeviceReduce::Sum( - nullptr, memRequired, it, dest, N, context->cuda_stream()); - auto buffer_size = - static_cast((memRequired + sizeof(T) - 1) / sizeof(T)); - if (!dest) { - // allocate one more T at the end of scratch for dest - scratch_ptr->Resize(std::vector{buffer_size + 1}); - dest = scratch_ptr->template mutable_data() + buffer_size; - } else { - scratch_ptr->Resize(std::vector{buffer_size}); - } - cub::DeviceReduce::Sum( - static_cast(scratch_ptr->template mutable_data()), - memRequired, - it, - dest, - N, - context->cuda_stream()); -} -} // namespace - -template <> -CAFFE2_CUDA_EXPORT void Sum( - const int N, - const float* x, - float* y, - CUDAContext* context, - Tensor* scratch_ptr) { - if (scratch_ptr && N > DEVICE_REDUCE_SIZE_THRESHOLD) { - SumGenericIter(N, x, y, context, scratch_ptr); - } else { - SumKernel<<<1, SUM_KERNEL_NTHREADS, 0, context->cuda_stream()>>>( - N, x, y, false); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - } -} - -template <> -CAFFE2_CUDA_EXPORT void Sum( - const int N, - const int32_t* x, - int32_t* y, - CUDAContext* context, - Tensor* scratch_ptr) { - if (scratch_ptr && N > DEVICE_REDUCE_SIZE_THRESHOLD) { - SumGenericIter(N, x, y, context, scratch_ptr); - } else { - SumKernel<<<1, SUM_KERNEL_NTHREADS, 0, context->cuda_stream()>>>( - N, x, y, false); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - } -} - -namespace { -template -struct FloatTransform { - inline __host__ __device__ float operator()(const T v) const { - return convert::To(v); - } -}; -} // namespace - -#define CAFFE2_MATH_SUM_FUNC(T) \ - template <> \ - CAFFE2_CUDA_EXPORT void Sum( \ - const int N, \ - const T* x, \ - T* y, \ - CUDAContext* context, \ - Tensor* scratch_ptr) { \ - if (scratch_ptr && N > DEVICE_REDUCE_SIZE_THRESHOLD) { \ - FloatTransform transform; \ - cub::TransformInputIterator, const T*> it( \ - x, transform); \ - float* sum = nullptr; \ - SumGenericIter(N, it, sum, context, scratch_ptr); \ - SumConvertKernel<<<1, 1, 0, context->cuda_stream()>>>(sum, y); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } else { \ - SumKernel<<<1, SUM_KERNEL_NTHREADS, 0, context->cuda_stream()>>>( \ - N, x, y, false); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } \ - } - -CAFFE2_MATH_SUM_FUNC(at::Half) -#undef CAFFE2_MATH_SUM_FUNC - -namespace { -template -struct SqrTransform { - inline __host__ __device__ T operator()(const T v) const { - return v * v; - } -}; -} // namespace - -template <> -CAFFE2_CUDA_EXPORT void SumSqr( - const int N, - const float* x, - float* y, - CUDAContext* context, - Tensor* scratch_ptr) { - if (scratch_ptr && N > DEVICE_REDUCE_SIZE_THRESHOLD) { - SqrTransform transform; - cub::TransformInputIterator, const float*> it( - x, transform); - SumGenericIter(N, it, y, context, scratch_ptr); - } else { - SumKernel<<<1, SUM_KERNEL_NTHREADS, 0, context->cuda_stream()>>>( - N, x, y, true); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - } -} - -#define CAFFE2_MATH_SUMSQR_FUNC(T) \ - template <> \ - CAFFE2_CUDA_EXPORT void SumSqr( \ - const int N, \ - const T* x, \ - T* y, \ - CUDAContext* context, \ - Tensor* scratch_ptr) { \ - if (scratch_ptr && N > DEVICE_REDUCE_SIZE_THRESHOLD) { \ - FloatTransform float_transform; \ - cub::TransformInputIterator, const T*> \ - float_it(x, float_transform); \ - SqrTransform sqr_transform; \ - cub::TransformInputIterator< \ - float, \ - SqrTransform, \ - decltype(float_it)> \ - it(float_it, sqr_transform); \ - float* sum = nullptr; \ - SumGenericIter(N, it, sum, context, scratch_ptr); \ - SumConvertKernel<<<1, 1, 0, context->cuda_stream()>>>(sum, y); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } else { \ - SumKernel<<<1, SUM_KERNEL_NTHREADS, 0, context->cuda_stream()>>>( \ - N, x, y, true); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } \ - } - -CAFFE2_MATH_SUMSQR_FUNC(at::Half) -#undef CAFFE2_MATH_SUMSQR_FUNC -#undef DEVICE_REDUCE_SIZE_THRESHOLD - -namespace { -template -__global__ void -SelectKernel(const int N, const int D, const T* x, const int* idx, T* y) { - CUDA_1D_KERNEL_LOOP(i, N) { - y[i] = x[i * D + idx[i]]; - } -} -} // namespace - -template <> -CAFFE2_CUDA_EXPORT void Select( - const int N, - const int D, - const float* x, - const int* idx, - float* y, - CUDAContext* context) { - SelectKernel - <<cuda_stream()>>>(N, D, x, idx, y); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -template <> -CAFFE2_CUDA_EXPORT void Select( - const int N, - const int D, - const at::Half* x, - const int* idx, - at::Half* y, - CUDAContext* context) { - SelectKernel - <<cuda_stream()>>>(N, D, x, idx, y); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -namespace { - -template -__global__ void Im2ColNCHWCUDAKernel( - const int n, - const int input_h, - const int input_w, - const int kernel_h, - const int kernel_w, - const int dilation_h, - const int dilation_w, - const int pad_t, - const int pad_l, - const int stride_h, - const int stride_w, - const int output_h, - const int output_w, - const T* img_data, - T* col_data) { - CUDA_1D_KERNEL_LOOP(index, n) { - const int w_out = index % output_w; - const int h_index = index / output_w; - const int h_out = h_index % output_h; - const int channel_in = h_index / output_h; - const int channel_out = channel_in * kernel_h * kernel_w; - const int h_in = h_out * stride_h - pad_t; - const int w_in = w_out * stride_w - pad_l; - const int output_size = output_h * output_w; - T* col_data_ptr = - col_data + (channel_out * output_h + h_out) * output_w + w_out; - const T* img_data_ptr = - img_data + (channel_in * input_h + h_in) * input_w + w_in; - int dh = 0; - for (int i = 0; i < kernel_h; ++i) { - int dw = 0; - for (int j = 0; j < kernel_w; ++j) { - const int h = h_in + dh; - const int w = w_in + dw; -#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) - *col_data_ptr = utils::IsAGeZeroAndALtB(h, input_h) && - utils::IsAGeZeroAndALtB(w, input_w) - ? __ldg(img_data_ptr + dh * input_w + dw) - : 0; -#else - *col_data_ptr = utils::IsAGeZeroAndALtB(h, input_h) && - utils::IsAGeZeroAndALtB(w, input_w) - ? img_data_ptr[dh * input_w + dw] - : 0; -#endif - col_data_ptr += output_size; - dw += dilation_w; - } - dh += dilation_h; - } - } -} - -template -__global__ void Im2ColNHWCCUDAKernel( - const int n, - const int input_h, - const int input_w, - const int kernel_h, - const int kernel_w, - const int dilation_h, - const int dilation_w, - const int pad_t, - const int pad_l, - const int stride_h, - const int stride_w, - const int output_w, - const int channels, - const T* img_data, - T* col_data) { - CUDA_1D_KERNEL_LOOP(index, n) { - const int channel_in = index % channels; - const int w_out = index / channels % output_w; - const int h_out = index / channels / output_w; - const int h_in = h_out * stride_h - pad_t; - const int w_in = w_out * stride_w - pad_l; - T* col_data_ptr = col_data + - (h_out * output_w + w_out) * channels * kernel_h * kernel_w + - channel_in; - int dh = 0; - for (int i = 0; i < kernel_h; ++i) { - int dw = 0; - for (int j = 0; j < kernel_w; ++j) { - const int h = h_in + dh; - const int w = w_in + dw; -#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) - *col_data_ptr = utils::IsAGeZeroAndALtB(h, input_h) && - utils::IsAGeZeroAndALtB(w, input_w) - ? __ldg(img_data + (h * input_w + w) * channels + channel_in) - : 0; -#else - *col_data_ptr = utils::IsAGeZeroAndALtB(h, input_h) && - utils::IsAGeZeroAndALtB(w, input_w) - ? img_data[(h * input_w + w) * channels + channel_in] - : 0; -#endif - col_data_ptr += channels; - dw += dilation_w; - } - dh += dilation_h; - } - } -} - -template -__global__ void Col2ImNCHWCUDAKernel( - const int n, - const int input_h, - const int input_w, - const int patch_h, - const int patch_w, - const int dilation_h, - const int dilation_w, - const int pad_t, - const int pad_l, - const int stride_h, - const int stride_w, - const int output_h, - const int output_w, - const T* col_data, - T* img_data) { - const int dpatch_h = dilation_h * (patch_h - 1) + 1; - const int dpatch_w = dilation_w * (patch_w - 1) + 1; - - CUDA_1D_KERNEL_LOOP(index, n) { - T val = 0; - const int w = index % input_w + pad_l; - const int h = index / input_w % input_h + pad_t; - const int c = index / (input_h * input_w); - - // compute the start and end of the output - const int w_col_start = (w < dpatch_w) ? 0 : (w - dpatch_w) / stride_w + 1; - const int w_col_end = min(w / stride_w + 1, output_w); - const int h_col_start = (h < dpatch_h) ? 0 : (h - dpatch_h) / stride_h + 1; - const int h_col_end = min(h / stride_h + 1, output_h); - - for (int h_col = h_col_start; h_col < h_col_end; ++h_col) { - for (int w_col = w_col_start; w_col < w_col_end; ++w_col) { - int h_k = (h - h_col * stride_h); - int w_k = (w - w_col * stride_w); - if (h_k % dilation_h == 0 && w_k % dilation_w == 0) { - h_k /= dilation_h; - w_k /= dilation_w; - const int col_data_index = - (((c * patch_h + h_k) * patch_w + w_k) * output_h + h_col) * - output_w + - w_col; -#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) - val += __ldg(col_data + col_data_index); -#else - val += col_data[col_data_index]; -#endif - } - } - } - img_data[index] = val; - } -} - -template -__global__ void Col2ImNHWCCUDAKernel( - const int n, - const int input_w, - const int channels, - const int patch_h, - const int patch_w, - const int dilation_h, - const int dilation_w, - const int pad_t, - const int pad_l, - const int stride_h, - const int stride_w, - const int output_h, - const int output_w, - const T* col_data, - T* img_data) { - const int dpatch_h = dilation_h * (patch_h - 1) + 1; - const int dpatch_w = dilation_w * (patch_w - 1) + 1; - - CUDA_1D_KERNEL_LOOP(index, n) { - T val = 0; - const int c = index % channels; - const int w = index / channels % input_w + pad_l; - const int h = index / channels / input_w + pad_t; - // compute the start and end of the output - const int w_col_start = (w < dpatch_w) ? 0 : (w - dpatch_w) / stride_w + 1; - const int w_col_end = min(w / stride_w + 1, output_w); - const int h_col_start = (h < dpatch_h) ? 0 : (h - dpatch_h) / stride_h + 1; - const int h_col_end = min(h / stride_h + 1, output_h); - const int channels_col = patch_h * patch_w * channels; - - for (int h_col = h_col_start; h_col < h_col_end; ++h_col) { - for (int w_col = w_col_start; w_col < w_col_end; ++w_col) { - int h_k = h - h_col * stride_h; - int w_k = w - w_col * stride_w; - if (h_k % dilation_h == 0 && w_k % dilation_w == 0) { - h_k /= dilation_h; - w_k /= dilation_w; - const int c_col = (h_k * patch_w + w_k) * channels + c; -#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) - val += __ldg( - col_data + (h_col * output_w + w_col) * channels_col + c_col); -#else - val += col_data[(h_col * output_w + w_col) * channels_col + c_col]; -#endif - } - } - } - img_data[index] = val; - } -} - -template -__global__ void Im2ColNdNCHWCUDAKernel( - const int outer_size, - const int inner_size, - const int kernel_size, - SimpleArray img_shape, - SimpleArray col_shape, - SimpleArray kernel_shape, - SimpleArray stride, - SimpleArray dilation, - SimpleArray pad, - const T* X_data, - T* Y_data) { - int d_offset[N]; - int d_iter[N]; - for (int i = blockIdx.x; i < outer_size; i += gridDim.x) { - int offset_i = i; -#pragma unroll - for (int d_i = N - 1; d_i >= 0; --d_i) { - d_offset[d_i] = offset_i % kernel_shape.data[d_i]; - offset_i /= kernel_shape.data[d_i]; - } - for (int j = threadIdx.x; j < inner_size; j += blockDim.x) { - int offset_j = j; -#pragma unroll - for (int d_i = N - 1; d_i >= 0; --d_i) { - d_iter[d_i] = offset_j % col_shape.data[d_i + 1]; - offset_j /= col_shape.data[d_i + 1]; - } - const int col_index = i * inner_size + j; - int img_index = i / kernel_size; - bool is_padding = false; -#pragma unroll - for (int d_i = 0; d_i < N; ++d_i) { - const int d_img = d_iter[d_i] * stride.data[d_i] - pad.data[d_i] + - d_offset[d_i] * dilation.data[d_i]; - is_padding |= !utils::IsAGeZeroAndALtB(d_img, img_shape.data[d_i + 1]); - img_index = img_index * img_shape.data[d_i + 1] + d_img; - } -#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) - if (!kCol2Im) { - Y_data[col_index] = is_padding ? 0 : __ldg(X_data + img_index); - } else if (!is_padding) { - gpu_atomic_add(Y_data + img_index, __ldg(X_data + col_index)); - } -#else - if (!kCol2Im) { - Y_data[col_index] = is_padding ? 0 : X_data[img_index]; - } else if (!is_padding) { - gpu_atomic_add(Y_data + img_index, X_data[col_index]); - } -#endif - } - } -} - -template -CAFFE2_CUDA_EXPORT void Im2ColNdNCHWCUDAImpl( - const int img_size, - const int col_size, - const int* img_shape, - const int* col_shape, - const int* kernel_shape, - const int* stride, - const int* dilation, - const int* pad, - const float* img_data, - float* col_data, - CUDAContext* context) { - const int outer_size = col_shape[0]; - const int inner_size = col_size / outer_size; - const int kernel_size = std::accumulate( - kernel_shape, kernel_shape + N, 1, std::multiplies()); - SimpleArray img_shape_array; - SimpleArray col_shape_array; - SimpleArray kernel_shape_array; - SimpleArray stride_array; - SimpleArray dilation_array; - SimpleArray pad_array; - std::memcpy(img_shape_array.data, img_shape, (N + 1) * sizeof(int)); - std::memcpy(col_shape_array.data, col_shape, (N + 1) * sizeof(int)); - std::memcpy(kernel_shape_array.data, kernel_shape, N * sizeof(int)); - std::memcpy(stride_array.data, stride, N * sizeof(int)); - std::memcpy(dilation_array.data, dilation, N * sizeof(int)); - std::memcpy(pad_array.data, pad, N * sizeof(int)); - Im2ColNdNCHWCUDAKernel - <<cuda_stream()>>>( - outer_size, - inner_size, - kernel_size, - img_shape_array, - col_shape_array, - kernel_shape_array, - stride_array, - dilation_array, - pad_array, - img_data, - col_data); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -template -CAFFE2_CUDA_EXPORT void Col2ImNdNCHWCUDAImpl( - const int img_size, - const int col_size, - const int* img_shape, - const int* col_shape, - const int* kernel_shape, - const int* stride, - const int* dilation, - const int* pad, - const float* col_data, - float* img_data, - CUDAContext* context) { - const int outer_size = col_shape[0]; - const int inner_size = col_size / outer_size; - const int kernel_size = std::accumulate( - kernel_shape, kernel_shape + N, 1, std::multiplies()); - SimpleArray img_shape_array; - SimpleArray col_shape_array; - SimpleArray kernel_shape_array; - SimpleArray stride_array; - SimpleArray dilation_array; - SimpleArray pad_array; - std::memcpy(img_shape_array.data, img_shape, (N + 1) * sizeof(int)); - std::memcpy(col_shape_array.data, col_shape, (N + 1) * sizeof(int)); - std::memcpy(kernel_shape_array.data, kernel_shape, N * sizeof(int)); - std::memcpy(stride_array.data, stride, N * sizeof(int)); - std::memcpy(dilation_array.data, dilation, N * sizeof(int)); - std::memcpy(pad_array.data, pad, N * sizeof(int)); - Set(img_size, 0, img_data, context); - Im2ColNdNCHWCUDAKernel - <<cuda_stream()>>>( - outer_size, - inner_size, - kernel_size, - img_shape_array, - col_shape_array, - kernel_shape_array, - stride_array, - dilation_array, - pad_array, - col_data, - img_data); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -} // namespace - -template <> -CAFFE2_CUDA_EXPORT void Im2Col( - const int channels, - const int height, - const int width, - const int kernel_h, - const int kernel_w, - const int dilation_h, - const int dilation_w, - const int pad_t, - const int pad_l, - const int pad_b, - const int pad_r, - const int stride_h, - const int stride_w, - const float* img_data, - float* col_data, - CUDAContext* context, - const int /* groups */) { - const int dkernel_h = dilation_h * (kernel_h - 1) + 1; - const int dkernel_w = dilation_w * (kernel_w - 1) + 1; - const int output_h = (height + pad_t + pad_b - dkernel_h) / stride_h + 1; - const int output_w = (width + pad_l + pad_r - dkernel_w) / stride_w + 1; - const int num_kernels = channels * output_h * output_w; - Im2ColNCHWCUDAKernel - <<cuda_stream()>>>( - num_kernels, - height, - width, - kernel_h, - kernel_w, - dilation_h, - dilation_w, - pad_t, - pad_l, - stride_h, - stride_w, - output_h, - output_w, - img_data, - col_data); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -template <> -CAFFE2_CUDA_EXPORT void Im2Col( - const int channels, - const int height, - const int width, - const int kernel_h, - const int kernel_w, - const int dilation_h, - const int dilation_w, - const int pad_t, - const int pad_l, - const int pad_b, - const int pad_r, - const int stride_h, - const int stride_w, - const float* img_data, - float* col_data, - CUDAContext* context, - const int groups) { - CAFFE_ENFORCE_EQ(groups, 1, "groups must be 1 for GPU NHWC Im2Col"); - - const int dkernel_h = dilation_h * (kernel_h - 1) + 1; - const int dkernel_w = dilation_w * (kernel_w - 1) + 1; - const int output_h = (height + pad_t + pad_b - dkernel_h) / stride_h + 1; - const int output_w = (width + pad_l + pad_r - dkernel_w) / stride_w + 1; - const int num_kernels = output_h * output_w * channels; - Im2ColNHWCCUDAKernel - <<cuda_stream()>>>( - num_kernels, - height, - width, - kernel_h, - kernel_w, - dilation_h, - dilation_w, - pad_t, - pad_l, - stride_h, - stride_w, - output_w, - channels, - img_data, - col_data); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -template <> -CAFFE2_CUDA_EXPORT void Col2Im( - const int channels, - const int height, - const int width, - const int kernel_h, - const int kernel_w, - const int dilation_h, - const int dilation_w, - const int pad_t, - const int pad_l, - const int pad_b, - const int pad_r, - const int stride_h, - const int stride_w, - const float* col_data, - float* img_data, - CUDAContext* context, - const int /* groups */) { - // In NCHW, the number of groups doesn't affect Col2Im. - const int dkernel_h = dilation_h * (kernel_h - 1) + 1; - const int dkernel_w = dilation_w * (kernel_w - 1) + 1; - const int output_h = (height + pad_t + pad_b - dkernel_h) / stride_h + 1; - const int output_w = (width + pad_l + pad_r - dkernel_w) / stride_w + 1; - const int num_kernels = channels * height * width; - Col2ImNCHWCUDAKernel - <<cuda_stream()>>>( - num_kernels, - height, - width, - kernel_h, - kernel_w, - dilation_h, - dilation_w, - pad_t, - pad_l, - stride_h, - stride_w, - output_h, - output_w, - col_data, - img_data); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -template <> -CAFFE2_CUDA_EXPORT void Col2Im( - const int channels, - const int height, - const int width, - const int kernel_h, - const int kernel_w, - const int dilation_h, - const int dilation_w, - const int pad_t, - const int pad_l, - const int pad_b, - const int pad_r, - const int stride_h, - const int stride_w, - const float* col_data, - float* img_data, - CUDAContext* context, - const int groups) { - CAFFE_ENFORCE_EQ(groups, 1, "groups must be 1 for GPU NHWC Col2Im"); - - const int dkernel_h = dilation_h * (kernel_h - 1) + 1; - const int dkernel_w = dilation_w * (kernel_w - 1) + 1; - const int output_h = (height + pad_t + pad_b - dkernel_h) / stride_h + 1; - const int output_w = (width + pad_l + pad_r - dkernel_w) / stride_w + 1; - const int num_kernels = height * width * channels; - Col2ImNHWCCUDAKernel - <<cuda_stream()>>>( - num_kernels, - width, - channels, - kernel_h, - kernel_w, - dilation_h, - dilation_w, - pad_t, - pad_l, - stride_h, - stride_w, - output_h, - output_w, - col_data, - img_data); -C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -template <> -CAFFE2_CUDA_EXPORT void Im2ColNd( - const int N, - const int img_size, - const int col_size, - const int* img_shape, - const int* col_shape, - const int* kernel_shape, - const int* stride, - const int* dilation, - const int* pad, - const float* img_data, - float* col_data, - CUDAContext* context, - const int /* groups */) { - // In NCHW, the number of groups doesn't affect Im2Col. - DISPATCH_FUNCTION_BY_VALUE_WITH_TYPE_1( - N, - Im2ColNdNCHWCUDAImpl, - float, - img_size, - col_size, - img_shape, - col_shape, - kernel_shape, - stride, - dilation, - pad, - img_data, - col_data, - context); -} - -template <> -CAFFE2_CUDA_EXPORT void Im2ColNd( - const int N, - const int img_size, - const int col_size, - const int* img_shape, - const int* col_shape, - const int* kernel_shape, - const int* stride, - const int* dilation, - const int* pad, - const float* img_data, - float* col_data, - CUDAContext* context, - const int groups) { - CAFFE_NOT_IMPLEMENTED; -} - -template <> -CAFFE2_CUDA_EXPORT void Col2ImNd( - const int N, - const int img_size, - const int col_size, - const int* img_shape, - const int* col_shape, - const int* kernel_shape, - const int* stride, - const int* dilation, - const int* pad, - const float* col_data, - float* img_data, - CUDAContext* context, - int /* groups */) { - // In NCHW, the number of groups doesn't affect Col2Im. - DISPATCH_FUNCTION_BY_VALUE_WITH_TYPE_1( - N, - Col2ImNdNCHWCUDAImpl, - float, - img_size, - col_size, - img_shape, - col_shape, - kernel_shape, - stride, - dilation, - pad, - col_data, - img_data, - context); -} - -template <> -CAFFE2_CUDA_EXPORT void Col2ImNd( - const int N, - const int img_size, - const int col_size, - const int* img_shape, - const int* col_shape, - const int* kernel_shape, - const int* stride, - const int* dilation, - const int* pad, - const float* col_data, - float* img_data, - CUDAContext* context, - int groups) { - CAFFE_NOT_IMPLEMENTED; -} - -template <> -CAFFE2_CUDA_EXPORT void CopyMatrix( - const size_t itemsize, - const int M, - const int N, - const void* A, - const int lda, - void* B, - const int ldb, - CUDAContext* context, - TypeMeta::Copy copy) { - CAFFE_ENFORCE(!copy, "Copy constructor is not supported in CUDA context"); - cudaMemcpy2DAsync( - B, - ldb * itemsize, - A, - lda * itemsize, - N * itemsize, - M, - cudaMemcpyDeviceToDevice, - context->cuda_stream()); -} - -#define CAFFE2_SPECIALIZED_CUDA_COPY_MATRIX(T) \ - template <> \ - void CopyMatrix( \ - const int M, \ - const int N, \ - const T* A, \ - const int lda, \ - T* B, \ - const int ldb, \ - CUDAContext* context) { \ - if (M == 0 || N == 0) { \ - return; \ - } \ - cudaMemcpy2DAsync( \ - B, \ - sizeof(T) * ldb, \ - A, \ - sizeof(T) * lda, \ - sizeof(T) * N, \ - M, \ - cudaMemcpyDeviceToDevice, \ - context->cuda_stream()); \ - } -CAFFE2_SPECIALIZED_CUDA_COPY_MATRIX(float) -CAFFE2_SPECIALIZED_CUDA_COPY_MATRIX(double) -CAFFE2_SPECIALIZED_CUDA_COPY_MATRIX(int) -CAFFE2_SPECIALIZED_CUDA_COPY_MATRIX(int64_t) -#undef CAFFE2_SPECIALIZED_CUDA_COPY_MATRIX - -template <> -CAFFE2_CUDA_EXPORT void CopyVector( - const int N, - const float* src, - float* dst, - CUDAContext* context) { - if (src != dst && N > 0) { - C10_CUDA_CHECK(cudaMemcpyAsync( - dst, - src, - sizeof(float) * N, - cudaMemcpyDeviceToDevice, - context->cuda_stream())); - } -} - -template <> -CAFFE2_CUDA_EXPORT void CopyVector( - const int N, - const int* src, - int* dst, - CUDAContext* context) { - if (src != dst && N > 0) { - C10_CUDA_CHECK(cudaMemcpyAsync( - dst, - src, - sizeof(int) * N, - cudaMemcpyDeviceToDevice, - context->cuda_stream())); - } -} - -namespace { - -template -using BlockReduce = cub::BlockReduce; - -template -__global__ void RowwiseReduceKernel( - const int rows, - const int cols, - const Reducer reducer, - const T init, - const T alpha, - const T* X, - T* Y) { - __shared__ typename BlockReduce::TempStorage temp_storage; - for (int i = blockIdx.x; i < rows; i += gridDim.x) { - T val = init; - for (int j = threadIdx.x; j < cols; j += blockDim.x) { - val = reducer(X[i * cols + j], val); - } - val = BlockReduce(temp_storage).Reduce(val, reducer); - if (threadIdx.x == 0) { - Y[i] = val * alpha; - } - __syncthreads(); - } -} - -template -__global__ void ColwiseReduceKernel( - const int rows, - const int cols, - const Reducer reducer, - const T init, - const T alpha, - const T* X, - T* Y) { - __shared__ typename BlockReduce::TempStorage temp_storage; - for (int i = blockIdx.x; i < cols; i += gridDim.x) { - T val = init; - for (int j = threadIdx.x; j < rows; j += blockDim.x) { - val = reducer(X[j * cols + i], val); - } - val = BlockReduce(temp_storage).Reduce(val, reducer); - if (threadIdx.x == 0) { - Y[i] = val * alpha; - } - __syncthreads(); - } -} - -} // namespace - -#define CAFFE2_SPECIALIZED_CUDA_ROWWISE_MAX(T) \ - template <> \ - CAFFE2_CUDA_EXPORT void RowwiseMax( \ - const int N, const int D, const T* x, T* y, CUDAContext* context) { \ - RowwiseReduceKernel<<< \ - std::min(N, CAFFE_MAXIMUM_NUM_BLOCKS), \ - CAFFE_CUDA_NUM_THREADS, \ - 0, \ - context->cuda_stream()>>>( \ - N, D, cub::Max(), std::numeric_limits::lowest(), T(1), x, y); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } -CAFFE2_SPECIALIZED_CUDA_ROWWISE_MAX(float) -#undef CAFFE2_SPECIALIZED_CUDA_ROWWISE_MAX - -#define CAFFE2_SPECIALIZED_CUDA_COLWISE_MAX(T) \ - template <> \ - CAFFE2_CUDA_EXPORT void ColwiseMax( \ - const int N, const int D, const T* x, T* y, CUDAContext* context) { \ - ColwiseReduceKernel<<< \ - std::min(D, CAFFE_MAXIMUM_NUM_BLOCKS), \ - CAFFE_CUDA_NUM_THREADS, \ - 0, \ - context->cuda_stream()>>>( \ - N, D, cub::Max(), std::numeric_limits::lowest(), T(1), x, y); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } -CAFFE2_SPECIALIZED_CUDA_COLWISE_MAX(float) -#undef CAFFE2_SPECIALIZED_CUDA_COLWISE_MAX - -namespace { -__global__ void -maximum_kernel(const int N, const float alpha, const float* x, float* y) { - CUDA_1D_KERNEL_LOOP(i, N) { - y[i] = fmaxf(x[i], alpha); - } -} -} // namespace - -template <> -CAFFE2_CUDA_EXPORT void Maximum( - const int N, - const float alpha, - const float* x, - float* y, - CUDAContext* context) { - maximum_kernel<<< - std::min(N, CAFFE_MAXIMUM_NUM_BLOCKS), - CAFFE_CUDA_NUM_THREADS, - 0, - context->cuda_stream()>>>(N, alpha, x, y); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -namespace { - -template -__global__ void BroadcastCUDAKernel( - const int Y_size, - const SimpleArray X_strides, - const SimpleArray Y_dims, - const T alpha, - const T* X, - T* Y) { - CUDA_1D_KERNEL_LOOP(Y_index, Y_size) { - int X_index = 0; - int Y_index_val = Y_index; -#pragma unroll - for (int i = D - 1; i >= 0; --i) { - int d; - FIXED_DIVISOR_DIV_MOD(Y_dims.data[i], Y_index_val, &Y_index_val, &d); - X_index += d * X_strides.data[i]; - } -#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) - Y[Y_index] = __ldg(X + X_index) * alpha; -#else - Y[Y_index] = X[X_index] * alpha; -#endif - } -} - -template -CAFFE2_CUDA_EXPORT void BroadcastCUDAImpl( - const int X_ndim, - const int* X_dims, - const int* Y_dims, - const T alpha, - const T* X, - T* Y, - CUDAContext* context) { - SimpleArray X_strides_array; - SimpleArray Y_dims_array; - const int d = D - X_ndim; - std::fill(X_strides_array.data, X_strides_array.data + d, 0); - int cur_stride = 1; - for (int i = D - 1; i >= d; --i) { - CAFFE_ENFORCE(X_dims[i - d] == 1 || X_dims[i - d] == Y_dims[i]); - X_strides_array.data[i] = X_dims[i - d] == 1 ? 0 : cur_stride; - cur_stride *= X_dims[i - d]; - } - for (int i = 0; i < D; ++i) { - if (Y_dims[i] == 0) { - return; - } - Y_dims_array.data[i] = FIXED_DIVISOR(Y_dims[i]); - } - const int Y_size = - std::accumulate(Y_dims, Y_dims + D, 1, std::multiplies()); - BroadcastCUDAKernel - <<cuda_stream()>>>( - Y_size, X_strides_array, Y_dims_array, alpha, X, Y); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -} // namespace - -#define CAFFE2_SPECIALIZED_CUDA_BROADCAST(T) \ - template <> \ - CAFFE2_CUDA_EXPORT void Broadcast( \ - const int X_ndim, \ - const int* X_dims, \ - const int Y_ndim, \ - const int* Y_dims, \ - const T alpha, \ - const T* X, \ - T* Y, \ - CUDAContext* context, \ - bool) { \ - CAFFE_ENFORCE_LE(X_ndim, Y_ndim); \ - DISPATCH_FUNCTION_BY_VALUE_WITH_TYPE_1( \ - Y_ndim, \ - BroadcastCUDAImpl, \ - T, \ - X_ndim, \ - X_dims, \ - Y_dims, \ - alpha, \ - X, \ - Y, \ - context); \ - } -CAFFE2_SPECIALIZED_CUDA_BROADCAST(std::int32_t) -CAFFE2_SPECIALIZED_CUDA_BROADCAST(std::int64_t) -CAFFE2_SPECIALIZED_CUDA_BROADCAST(float) -CAFFE2_SPECIALIZED_CUDA_BROADCAST(double) -#undef CAFFE2_SPECIALIZED_CUDA_BROADCAST - -namespace { - -template -__global__ void -InvStdCUDAKernel(const int N, const T epsilon, const T* var, T* inv_std); - -#define DELEGATE_INV_STD_KERNEL_FUNCTION(T, Func) \ - template <> \ - __global__ void InvStdCUDAKernel( \ - const int N, const T epsilon, const T* var, T* inv_std) { \ - CUDA_1D_KERNEL_LOOP(i, N) { \ - inv_std[i] = Func(var[i] + epsilon); \ - } \ - } -DELEGATE_INV_STD_KERNEL_FUNCTION(float, rsqrtf) -#undef DELEGATE_INV_STD_KERNEL_FUNCTION - -} // namespace - -#define CAFFE2_SPECIALIZED_CUDA_INV_STD(T) \ - template <> \ - CAFFE2_CUDA_EXPORT void InvStd( \ - const int N, \ - const T epsilon, \ - const T* var, \ - T* inv_std, \ - CUDAContext* context) { \ - InvStdCUDAKernel \ - <<cuda_stream()>>>(N, epsilon, var, inv_std); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } -CAFFE2_SPECIALIZED_CUDA_INV_STD(float) -#undef CAFFE2_SPECIALIZED_CUDA_INV_STD - -} // namespace math -} // namespace caffe2 diff --git a/caffe2/utils/math_gpu_test.cc b/caffe2/utils/math_gpu_test.cc deleted file mode 100644 index 330a724162cd..000000000000 --- a/caffe2/utils/math_gpu_test.cc +++ /dev/null @@ -1,429 +0,0 @@ -#include -#include -#include -#include - -#include - -#include "caffe2/core/context.h" -#include "caffe2/core/context_gpu.h" -#include "caffe2/core/flags.h" -#include "caffe2/operators/utility_ops.h" -#include "caffe2/utils/math.h" - -C10_DECLARE_string(caffe_test_root); - -namespace caffe2 { - -void executeGpuBinaryOpTest( - int shapex0, - int shapex1, - int shapey, - std::function input0, - std::function input1, - std::function operation, - std::function correct_output) { - if (!HasCudaGPU()) - return; - Workspace ws; - DeviceOption option; - option.set_device_type(PROTO_CUDA); - CUDAContext context(option); - - Blob* blobx0 = ws.CreateBlob("X0"); - Blob* blobx1 = ws.CreateBlob("X1"); - Blob* bloby = ws.CreateBlob("Y"); - Blob* bloby_host = ws.CreateBlob("Y_host"); - - auto* tensorx0 = BlobGetMutableTensor(blobx0, CUDA); - auto* tensorx1 = BlobGetMutableTensor(blobx1, CUDA); - auto* tensory = BlobGetMutableTensor(bloby, CUDA); - - vector shapex0_vector{shapex0}; - vector shapex1_vector{shapex1}; - vector shapey_vector{shapey}; - - tensorx0->Resize(shapex0_vector); - tensorx1->Resize(shapex1_vector); - tensory->Resize(shapey_vector); - - for (int i = 0; i < shapex0; i++) { - math::Set( - 1, input0(i), tensorx0->mutable_data() + i, &context); - } - for (int i = 0; i < shapex1; i++) { - math::Set( - 1, input1(i), tensorx1->mutable_data() + i, &context); - } - operation( - shapex0, - shapex1, - tensorx0->template data(), - tensorx1->template data(), - tensory->mutable_data(), - &context); - context.FinishDeviceComputation(); - - // Copy result to CPU so we can inspect it - auto* tensory_host = BlobGetMutableTensor(bloby_host, CPU); - tensory_host->CopyFrom(*tensory); - - for (int i = 0; i < shapey; ++i) { - EXPECT_EQ(tensory_host->data()[i], correct_output(i)); - } -} - -TEST(MathUtilGPUTest, testAddStripedBatch) { - if (!HasCudaGPU()) - return; - Workspace ws; - DeviceOption option; - option.set_device_type(PROTO_CUDA); - CUDAContext context(option); - Blob* blobx = ws.CreateBlob("X"); - Blob* bloby = ws.CreateBlob("Y"); - Blob* bloby_host = ws.CreateBlob("Y_host"); - - vector shapex{33 * 9, 25}; - vector shapey{33, 25}; - - auto* tensorx = BlobGetMutableTensor(blobx, CUDA); - tensorx->Resize(shapex); - int stripe = 33 * 25; - vector tot(33, 0.0); - for (int j = 0; j < 9; j++) { - // Have different values for each line - for (int k = 0; k < 33; k++) { - math::Set( - 33, - 1.0 + j + k, - tensorx->mutable_data() + j * stripe + k * 25, - &context); - tot[k] += 1.0 + j + k; - } - } - - auto* tensory = BlobGetMutableTensor(bloby, CUDA); - tensory->Resize(shapey); - math::Set( - stripe, 0.0, tensory->mutable_data(), &context); - - math::AddStripedBatch( - stripe, - tensorx->template data(), - tensory->mutable_data(), - stripe, - 9, - &context); - context.FinishDeviceComputation(); - - // Copy result to CPU so we can inspect it - auto* tensory_host = BlobGetMutableTensor(bloby_host, CPU); - tensory_host->CopyFrom(*tensory); - - for (int k = 0; k < 33; k++) { - for (int i = 0; i < 25; i++) { - EXPECT_EQ(tensory_host->data()[k * 25 + i], tot[k]); - } - } -} - -TEST(MathUtilGPUTest, testReduceMin) { - executeGpuBinaryOpTest( - 6, - 1, - 1, - [](int /*i*/) { return 11.0f; }, - [](int /*i*/) { return 0.0f; }, - [](int N0, - int /*N1*/, - const float* src0, - const float* /*src1*/, - float* dst, - CUDAContext* context) { - Tensor aux(CUDA); - math::ReduceMin(N0, src0, dst, &aux, context); - }, - [](int /*i*/) { return 11.0f; }); - executeGpuBinaryOpTest( - 6, - 1, - 1, - [](int i) { return i == 3 ? 11.0f : 17.0f; }, - [](int /*i*/) { return 0.0f; }, - [](int N0, - int /*N1*/, - const float* src0, - const float* /*src1*/, - float* dst, - CUDAContext* context) { - Tensor aux(CUDA); - math::ReduceMin(N0, src0, dst, &aux, context); - }, - [](int /*i*/) { return 11.0f; }); -} - -TEST(MathUtilGPUTest, testReduceMax) { - executeGpuBinaryOpTest( - 6, - 1, - 1, - [](int /*i*/) { return 11.0f; }, - [](int /*i*/) { return 0.0f; }, - [](int N0, - int /*N1*/, - const float* src0, - const float* /*src1*/, - float* dst, - CUDAContext* context) { - Tensor aux(CUDA); - math::ReduceMax(N0, src0, dst, &aux, context); - }, - [](int /*i*/) { return 11.0f; }); - executeGpuBinaryOpTest( - 6, - 1, - 1, - [](int i) { return i == 3 ? 17.0f : 11.0f; }, - [](int /*i*/) { return 0.0f; }, - [](int N0, - int /*N1*/, - const float* src0, - const float* /*src1*/, - float* dst, - CUDAContext* context) { - Tensor aux(CUDA); - math::ReduceMax(N0, src0, dst, &aux, context); - }, - [](int /*i*/) { return 17.0f; }); -} - -TEST(MathUtilGPUTest, testCopyVector) { - executeGpuBinaryOpTest( - 6, - 1, - 6, - [](int i) { return 5.0f - i; }, - [](int /*i*/) { return 0.0f; }, - [](int N0, - int /*N1*/, - const float* src0, - const float* /*src1*/, - float* dst, - CUDAContext* context) { - math::CopyVector(N0, src0, dst, context); - }, - [](int i) { return 5.0f - i; }); -} - -namespace { - -class GemmBatchedGPUTest - : public testing::TestWithParam> { - protected: - void SetUp() override { - if (!HasCudaGPU()) { - return; - } - option_.set_device_type(PROTO_CUDA); - cuda_context_ = make_unique(option_); - Blob* X_blob = ws_.CreateBlob("X"); - Blob* W_blob = ws_.CreateBlob("W"); - Blob* Y_blob = ws_.CreateBlob("Y"); - X_ = BlobGetMutableTensor(X_blob, CUDA); - W_ = BlobGetMutableTensor(W_blob, CUDA); - Y_ = BlobGetMutableTensor(Y_blob, CUDA); - X_->Resize(std::vector{3, 5, 10}); - W_->Resize(std::vector{3, 6, 10}); - Y_->Resize(std::vector{3, 5, 6}); - math::Set( - X_->numel(), 1.0f, X_->mutable_data(), cuda_context_.get()); - math::Set( - W_->numel(), 1.0f, W_->mutable_data(), cuda_context_.get()); - trans_X_ = std::get<0>(GetParam()); - trans_W_ = std::get<1>(GetParam()); - } - - void RunGemmBatched(const float alpha, const float beta) { - const float* X_data = X_->template data(); - const float* W_data = W_->template data(); - float* Y_data = Y_->template mutable_data(); - const int X_stride = 5 * 10; - const int W_stride = 6 * 10; - const int Y_stride = 5 * 6; - std::array X_array = { - X_data, X_data + X_stride, X_data + 2 * X_stride}; - std::array W_array = { - W_data, W_data + W_stride, W_data + 2 * W_stride}; - std::array Y_array = { - Y_data, Y_data + Y_stride, Y_data + 2 * Y_stride}; - math::GemmBatched( - trans_X_ ? CblasTrans : CblasNoTrans, - trans_W_ ? CblasTrans : CblasNoTrans, - 3, - 5, - 6, - 10, - alpha, - X_array.data(), - W_array.data(), - beta, - Y_array.data(), - cuda_context_.get()); - } - - void RunGemmStridedBatched(const float alpha, const float beta) { - const float* X_data = X_->template data(); - const float* W_data = W_->template data(); - float* Y_data = Y_->template mutable_data(); - const int X_stride = 5 * 10; - const int W_stride = 6 * 10; - const int Y_stride = 5 * 6; - math::GemmStridedBatched( - trans_X_ ? CblasTrans : CblasNoTrans, - trans_W_ ? CblasTrans : CblasNoTrans, - 3, - 5, - 6, - 10, - alpha, - X_data, - X_stride, - W_data, - W_stride, - beta, - Y_data, - Y_stride, - cuda_context_.get()); - } - - void VerifyOutput(const float value) const { - Tensor Y_cpu(*Y_, CPU); - for (int i = 0; i < Y_cpu.numel(); ++i) { - EXPECT_FLOAT_EQ(value, Y_cpu.template data()[i]); - } - } - - Workspace ws_; - DeviceOption option_; - std::unique_ptr cuda_context_; - Tensor* X_ = nullptr; - Tensor* W_ = nullptr; - Tensor* Y_ = nullptr; - bool trans_X_; - bool trans_W_; -}; - -TEST_P(GemmBatchedGPUTest, GemmBatchedGPUFloatTest) { - if (!HasCudaGPU()) { - return; - } - RunGemmBatched(1.0f, 0.0f); - VerifyOutput(10.0f); - RunGemmBatched(1.0f, 0.5f); - VerifyOutput(15.0f); - RunGemmBatched(0.5f, 1.0f); - VerifyOutput(20.0f); -} - -TEST_P(GemmBatchedGPUTest, GemmStridedBatchedGPUFloatTest) { - if (!HasCudaGPU()) { - return; - } - RunGemmStridedBatched(1.0f, 0.0f); - VerifyOutput(10.0f); - RunGemmStridedBatched(1.0f, 0.5f); - VerifyOutput(15.0f); - RunGemmStridedBatched(0.5f, 1.0f); - VerifyOutput(20.0f); -} - -INSTANTIATE_TEST_CASE_P( - GemmBatchedGPUTrans, - GemmBatchedGPUTest, - testing::Combine(testing::Bool(), testing::Bool())); - -class BroadcastGPUTest : public testing::Test { - protected: - void SetUp() override { - if (!HasCudaGPU()) { - return; - } - option_.set_device_type(PROTO_CUDA); - cuda_context_ = make_unique(option_); - Blob* blob_x = ws_.CreateBlob("X"); - Blob* blob_y = ws_.CreateBlob("Y"); - X_ = BlobGetMutableTensor(blob_x, CUDA); - Y_ = BlobGetMutableTensor(blob_y, CUDA); - } - - void SetUpData( - const std::vector& X_dims, - const std::vector& Y_dims, - const std::vector& X_data) { - X_->Resize(X_dims); - Y_->Resize(Y_dims); - ASSERT_EQ(X_data.size(), X_->numel()); - cuda_context_->CopyFromCPU( - X_data.size(), X_data.data(), X_->mutable_data()); - } - - void VerifyResult(const std::vector& expected_output) { - Blob* blob_y_host = ws_.CreateBlob("Y_host"); - auto* Y_host = BlobGetMutableTensor(blob_y_host, CPU); - Y_host->CopyFrom(*Y_); - ASSERT_EQ(expected_output.size(), Y_host->numel()); - for (std::size_t i = 0; i < expected_output.size(); ++i) { - EXPECT_FLOAT_EQ(expected_output[i], Y_host->data()[i]); - } - } - - void RunBroadcastTest( - const std::vector& X_dims, - const std::vector& Y_dims, - const std::vector& X_data, - const std::vector& Y_data) { - SetUpData(X_dims, Y_dims, X_data); - math::Broadcast( - X_dims.size(), - X_dims.data(), - Y_dims.size(), - Y_dims.data(), - 1.0f, - X_->data(), - Y_->mutable_data(), - cuda_context_.get()); - VerifyResult(Y_data); - } - - Workspace ws_; - DeviceOption option_; - std::unique_ptr cuda_context_; - Tensor* X_ = nullptr; - Tensor* Y_ = nullptr; -}; - -TEST_F(BroadcastGPUTest, BroadcastGPUFloatTest) { - if (!HasCudaGPU()) { - return; - } - RunBroadcastTest({2}, {2}, {1.0f, 2.0f}, {1.0f, 2.0f}); - RunBroadcastTest({1}, {2}, {1.0f}, {1.0f, 1.0f}); - RunBroadcastTest({1}, {2, 2}, {1.0f}, {1.0f, 1.0f, 1.0f, 1.0f}); - RunBroadcastTest({2, 1}, {2, 2}, {1.0f, 2.0f}, {1.0f, 1.0f, 2.0f, 2.0f}); - RunBroadcastTest( - {2, 1}, - {2, 2, 2}, - {1.0f, 2.0f}, - {1.0f, 1.0f, 2.0f, 2.0f, 1.0f, 1.0f, 2.0f, 2.0f}); -} - -} // namespace - -} // namespace caffe2 diff --git a/caffe2/utils/math_test.cc b/caffe2/utils/math_test.cc deleted file mode 100644 index 0389a10f29e0..000000000000 --- a/caffe2/utils/math_test.cc +++ /dev/null @@ -1,523 +0,0 @@ -#include -#include -#include - -#include - -#include "caffe2/core/blob.h" -#include "caffe2/core/context.h" -#include "caffe2/core/tensor.h" -#include "caffe2/proto/caffe2_pb.h" -#include "caffe2/utils/conversions.h" -#include "caffe2/utils/math.h" - -#include - -namespace caffe2 { - -TEST(MathTest, GemmNoTransNoTrans) { - DeviceOption option; - CPUContext cpu_context(option); - Tensor X(std::vector{5, 10}, CPU); - Tensor W(std::vector{10, 6}, CPU); - Tensor Y(std::vector{5, 6}, CPU); - EXPECT_EQ(X.numel(), 50); - EXPECT_EQ(W.numel(), 60); - math::Set( - X.numel(), 1, X.mutable_data(), &cpu_context); - math::Set( - W.numel(), 1, W.mutable_data(), &cpu_context); - EXPECT_EQ(Y.numel(), 30); - for (int i = 0; i < X.numel(); ++i) { - TORCH_CHECK_EQ(X.data()[i], 1); - } - for (int i = 0; i < W.numel(); ++i) { - TORCH_CHECK_EQ(W.data()[i], 1); - } - - const float kOne = 1.0; - const float kPointFive = 0.5; - const float kZero = 0.0; - math::Gemm( - CblasNoTrans, - CblasNoTrans, - 5, - 6, - 10, - kOne, - X.data(), - W.data(), - kZero, - Y.mutable_data(), - &cpu_context); - EXPECT_EQ(Y.numel(), 30); - for (int i = 0; i < Y.numel(); ++i) { - TORCH_CHECK_EQ(Y.data()[i], 10) << i; - } - // Test Accumulate - math::Gemm( - CblasNoTrans, - CblasNoTrans, - 5, - 6, - 10, - kOne, - X.data(), - W.data(), - kPointFive, - Y.mutable_data(), - &cpu_context); - EXPECT_EQ(Y.numel(), 30); - for (int i = 0; i < Y.numel(); ++i) { - TORCH_CHECK_EQ(Y.data()[i], 15) << i; - } - // Test Accumulate - math::Gemm( - CblasNoTrans, - CblasNoTrans, - 5, - 6, - 10, - kPointFive, - X.data(), - W.data(), - kOne, - Y.mutable_data(), - &cpu_context); - EXPECT_EQ(Y.numel(), 30); - for (int i = 0; i < Y.numel(); ++i) { - TORCH_CHECK_EQ(Y.data()[i], 20) << i; - } -} - -TEST(MathTest, GemmNoTransTrans) { - DeviceOption option; - CPUContext cpu_context(option); - Tensor X(std::vector{5, 10}, CPU); - Tensor W(std::vector{6, 10}, CPU); - Tensor Y(std::vector{5, 6}, CPU); - EXPECT_EQ(X.numel(), 50); - EXPECT_EQ(W.numel(), 60); - math::Set( - X.numel(), 1, X.mutable_data(), &cpu_context); - math::Set( - W.numel(), 1, W.mutable_data(), &cpu_context); - EXPECT_EQ(Y.numel(), 30); - for (int i = 0; i < X.numel(); ++i) { - TORCH_CHECK_EQ(X.data()[i], 1); - } - for (int i = 0; i < W.numel(); ++i) { - TORCH_CHECK_EQ(W.data()[i], 1); - } - - const float kOne = 1.0; - const float kPointFive = 0.5; - const float kZero = 0.0; - math::Gemm( - CblasNoTrans, - CblasTrans, - 5, - 6, - 10, - kOne, - X.data(), - W.data(), - kZero, - Y.mutable_data(), - &cpu_context); - EXPECT_EQ(Y.numel(), 30); - for (int i = 0; i < Y.numel(); ++i) { - TORCH_CHECK_EQ(Y.data()[i], 10) << i; - } - // Test Accumulate - math::Gemm( - CblasNoTrans, - CblasTrans, - 5, - 6, - 10, - kOne, - X.data(), - W.data(), - kPointFive, - Y.mutable_data(), - &cpu_context); - EXPECT_EQ(Y.numel(), 30); - for (int i = 0; i < Y.numel(); ++i) { - TORCH_CHECK_EQ(Y.data()[i], 15) << i; - } - math::Gemm( - CblasNoTrans, - CblasTrans, - 5, - 6, - 10, - kPointFive, - X.data(), - W.data(), - kOne, - Y.mutable_data(), - &cpu_context); - EXPECT_EQ(Y.numel(), 30); - for (int i = 0; i < Y.numel(); ++i) { - TORCH_CHECK_EQ(Y.data()[i], 20) << i; - } -} - -namespace { - -// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) -class GemmBatchedTest - : public testing::TestWithParam> { - protected: - void SetUp() override { - cpu_context_ = make_unique(option_); - ReinitializeTensor( - &X_, std::vector{3, 5, 10}, at::dtype().device(CPU)); - ReinitializeTensor( - &W_, std::vector{3, 6, 10}, at::dtype().device(CPU)); - ReinitializeTensor( - &Y_, std::vector{3, 5, 6}, at::dtype().device(CPU)); - math::Set( - X_.numel(), 1, X_.mutable_data(), cpu_context_.get()); - math::Set( - W_.numel(), 1, W_.mutable_data(), cpu_context_.get()); - trans_X_ = std::get<0>(GetParam()); - trans_W_ = std::get<1>(GetParam()); - } - - void RunGemmBatched(const float alpha, const float beta) { - const float* X_data = X_.template data(); - const float* W_data = W_.template data(); - float* Y_data = Y_.template mutable_data(); - const int X_stride = 5 * 10; - const int W_stride = 6 * 10; - const int Y_stride = 5 * 6; - std::array X_array = { - X_data, X_data + X_stride, X_data + 2 * X_stride}; - std::array W_array = { - W_data, W_data + W_stride, W_data + 2 * W_stride}; - std::array Y_array = { - Y_data, Y_data + Y_stride, Y_data + 2 * Y_stride}; - math::GemmBatched( - trans_X_ ? CblasTrans : CblasNoTrans, - trans_W_ ? CblasTrans : CblasNoTrans, - 3, - 5, - 6, - 10, - alpha, - X_array.data(), - W_array.data(), - beta, - Y_array.data(), - cpu_context_.get()); - } - - void RunGemmStridedBatched(const float alpha, const float beta) { - const float* X_data = X_.template data(); - const float* W_data = W_.template data(); - float* Y_data = Y_.template mutable_data(); - const int X_stride = 5 * 10; - const int W_stride = 6 * 10; - const int Y_stride = 5 * 6; - math::GemmStridedBatched( - trans_X_ ? CblasTrans : CblasNoTrans, - trans_W_ ? CblasTrans : CblasNoTrans, - 3, - 5, - 6, - 10, - alpha, - X_data, - X_stride, - W_data, - W_stride, - beta, - Y_data, - Y_stride, - cpu_context_.get()); - } - - void VerifyOutput(const float value) const { - for (int i = 0; i < Y_.numel(); ++i) { - EXPECT_FLOAT_EQ(value, Y_.template data()[i]); - } - } - - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - DeviceOption option_; - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - std::unique_ptr cpu_context_; - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - Tensor X_; - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - Tensor W_; - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - Tensor Y_; - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - bool trans_X_; - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - bool trans_W_; -}; - -TEST_P(GemmBatchedTest, GemmBatchedFloatTest) { - RunGemmBatched(1.0f, 0.0f); - VerifyOutput(10.0f); - RunGemmBatched(1.0f, 0.5f); - VerifyOutput(15.0f); - RunGemmBatched(0.5f, 1.0f); - VerifyOutput(20.0f); -} - -TEST_P(GemmBatchedTest, GemmStridedBatchedFloatTest) { - RunGemmStridedBatched(1.0f, 0.0f); - VerifyOutput(10.0f); - RunGemmStridedBatched(1.0f, 0.5f); - VerifyOutput(15.0f); - RunGemmStridedBatched(0.5f, 1.0f); - VerifyOutput(20.0f); -} - -INSTANTIATE_TEST_CASE_P( - GemmBatchedTrans, - GemmBatchedTest, - testing::Combine(testing::Bool(), testing::Bool())); - -} // namespace - -TEST(MathTest, GemvNoTrans) { - DeviceOption option; - CPUContext cpu_context(option); - Tensor A(std::vector{5, 10}, CPU); - Tensor X(std::vector{10}, CPU); - Tensor Y(std::vector{5}, CPU); - EXPECT_EQ(A.numel(), 50); - EXPECT_EQ(X.numel(), 10); - math::Set( - A.numel(), 1, A.mutable_data(), &cpu_context); - math::Set( - X.numel(), 1, X.mutable_data(), &cpu_context); - EXPECT_EQ(Y.numel(), 5); - for (int i = 0; i < A.numel(); ++i) { - TORCH_CHECK_EQ(A.data()[i], 1); - } - for (int i = 0; i < X.numel(); ++i) { - TORCH_CHECK_EQ(X.data()[i], 1); - } - - const float kOne = 1.0; - const float kPointFive = 0.5; - const float kZero = 0.0; - math::Gemv( - CblasNoTrans, - 5, - 10, - kOne, - A.data(), - X.data(), - kZero, - Y.mutable_data(), - &cpu_context); - for (int i = 0; i < Y.numel(); ++i) { - TORCH_CHECK_EQ(Y.data()[i], 10) << i; - } - // Test Accumulate - math::Gemv( - CblasNoTrans, - 5, - 10, - kOne, - A.data(), - X.data(), - kPointFive, - Y.mutable_data(), - &cpu_context); - for (int i = 0; i < Y.numel(); ++i) { - TORCH_CHECK_EQ(Y.data()[i], 15) << i; - } - // Test Accumulate - math::Gemv( - CblasNoTrans, - 5, - 10, - kPointFive, - A.data(), - X.data(), - kOne, - Y.mutable_data(), - &cpu_context); - for (int i = 0; i < Y.numel(); ++i) { - TORCH_CHECK_EQ(Y.data()[i], 20) << i; - } -} - -TEST(MathTest, GemvTrans) { - DeviceOption option; - CPUContext cpu_context(option); - Tensor A(std::vector{6, 10}, CPU); - Tensor X(std::vector{6}, CPU); - Tensor Y(std::vector{10}, CPU); - EXPECT_EQ(A.numel(), 60); - EXPECT_EQ(X.numel(), 6); - math::Set( - A.numel(), 1, A.mutable_data(), &cpu_context); - math::Set( - X.numel(), 1, X.mutable_data(), &cpu_context); - EXPECT_EQ(Y.numel(), 10); - for (int i = 0; i < A.numel(); ++i) { - TORCH_CHECK_EQ(A.data()[i], 1); - } - for (int i = 0; i < X.numel(); ++i) { - TORCH_CHECK_EQ(X.data()[i], 1); - } - - const float kOne = 1.0; - const float kPointFive = 0.5; - const float kZero = 0.0; - math::Gemv( - CblasTrans, - 6, - 10, - kOne, - A.data(), - X.data(), - kZero, - Y.mutable_data(), - &cpu_context); - for (int i = 0; i < Y.numel(); ++i) { - TORCH_CHECK_EQ(Y.data()[i], 6) << i; - } - // Test Accumulate - math::Gemv( - CblasTrans, - 6, - 10, - kOne, - A.data(), - X.data(), - kPointFive, - Y.mutable_data(), - &cpu_context); - for (int i = 0; i < Y.numel(); ++i) { - TORCH_CHECK_EQ(Y.data()[i], 9) << i; - } - // Test Accumulate - math::Gemv( - CblasTrans, - 6, - 10, - kPointFive, - A.data(), - X.data(), - kOne, - Y.mutable_data(), - &cpu_context); - for (int i = 0; i < Y.numel(); ++i) { - TORCH_CHECK_EQ(Y.data()[i], 12) << i; - } -} - -TEST(MathTest, FloatToHalfConversion) { - float a = 1.0f; - float b = 1.75f; - float c = 128.125f; - - float converted_a = static_cast(at::Half(a)); - float converted_b = static_cast(at::Half(b)); - float converted_c = static_cast(at::Half(c)); - - TORCH_CHECK_EQ(a, converted_a); - TORCH_CHECK_EQ(b, converted_b); - TORCH_CHECK_EQ(c, converted_c); -} - -namespace { - -class BroadcastTest : public testing::Test { - protected: - void SetUp() override { - cpu_context_ = make_unique(option_); - } - - void RunBroadcastTest( - const std::vector& X_dims, - const std::vector& Y_dims, - const std::vector& X_data, - const std::vector& Y_data) { - std::vector X_dims_64; - std::vector Y_dims_64; - std::copy(X_dims.cbegin(), X_dims.cend(), std::back_inserter(X_dims_64)); - std::copy(Y_dims.cbegin(), Y_dims.cend(), std::back_inserter(Y_dims_64)); - ReinitializeTensor(&X_, X_dims_64, at::dtype().device(CPU)); - ReinitializeTensor(&Y_, Y_dims_64, at::dtype().device(CPU)); - ASSERT_EQ(X_data.size(), X_.numel()); - cpu_context_->CopyFromCPU( - X_data.size(), X_data.data(), X_.mutable_data()); - for (bool allow_broadcast_fastpath : {false, true}) { - math::Broadcast( - X_dims.size(), - X_dims.data(), - Y_dims.size(), - Y_dims.data(), - 1.0f, - X_.data(), - Y_.mutable_data(), - cpu_context_.get(), - allow_broadcast_fastpath); - ASSERT_EQ(Y_data.size(), Y_.numel()); - for (const auto i : c10::irange(Y_data.size())) { - EXPECT_FLOAT_EQ(Y_data[i], Y_.data()[i]); - } - } - } - - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - DeviceOption option_; - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - std::unique_ptr cpu_context_; - - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - Tensor X_; - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - Tensor Y_; -}; - -TEST_F(BroadcastTest, BroadcastFloatTest) { - RunBroadcastTest({2}, {2}, {1.0f, 2.0f}, {1.0f, 2.0f}); - RunBroadcastTest({1}, {2}, {1.0f}, {1.0f, 1.0f}); - RunBroadcastTest({1}, {2, 2}, {1.0f}, {1.0f, 1.0f, 1.0f, 1.0f}); - RunBroadcastTest({2, 1}, {2, 2}, {1.0f, 2.0f}, {1.0f, 1.0f, 2.0f, 2.0f}); - RunBroadcastTest({1, 2}, {2, 2}, {1.0f, 2.0f}, {1.0f, 2.0f, 1.0f, 2.0f}); - RunBroadcastTest( - {2, 1}, - {2, 2, 2}, - {1.0f, 2.0f}, - {1.0f, 1.0f, 2.0f, 2.0f, 1.0f, 1.0f, 2.0f, 2.0f}); - RunBroadcastTest( - {1, 2}, - {2, 2, 2}, - {1.0f, 2.0f}, - {1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f}); -} - -class RandFixedSumTest : public testing::Test { - protected: - void SetUp() override { - cpu_context_ = make_unique(option_); - } - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - DeviceOption option_; - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - std::unique_ptr cpu_context_; -}; - -TEST_F(RandFixedSumTest, UpperBound) { - std::vector l(20); - math::RandFixedSum( - 20, 1, 1000, 1000, l.data(), cpu_context_.get()); -} - -} // namespace - -} // namespace caffe2 From 7de13524571f816b96e1d391b7ce0137786d0239 Mon Sep 17 00:00:00 2001 From: cyy Date: Wed, 29 May 2024 16:13:58 +0000 Subject: [PATCH 040/706] [1/N] Replace exceptions with static_assert(false) in some templates (#127371) This PR tries to report some failures at build time. Once the build fails, it generally indicates that we can wrap the code inside some conditional macros, and it is a hint to further reduce the built code size. The sizeof operations were used to ensure that the assertion dependents on specific template instantiations. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127371 Approved by: https://github.com/ezyang, https://github.com/Skylion007 --- aten/src/ATen/cuda/CUDABlas.h | 33 ++--- aten/src/ATen/cuda/CUDADataType.h | 3 +- aten/src/ATen/native/BlasKernel.cpp | 4 + .../ATen/native/cpu/ScatterGatherKernel.cpp | 2 +- aten/src/ATen/native/cuda/linalg/CUDASolver.h | 135 +++++++----------- .../native/sparse/cuda/SparseCUDABlas.cpp | 4 +- .../ATen/native/sparse/cuda/SparseMatMul.cu | 2 +- 7 files changed, 71 insertions(+), 112 deletions(-) diff --git a/aten/src/ATen/cuda/CUDABlas.h b/aten/src/ATen/cuda/CUDABlas.h index d418dc53af38..2c6cef95f79f 100644 --- a/aten/src/ATen/cuda/CUDABlas.h +++ b/aten/src/ATen/cuda/CUDABlas.h @@ -48,7 +48,7 @@ class PointerModeGuard { template inline void gemm(CUDABLAS_GEMM_ARGTYPES(Dtype)) { - AT_ERROR("at::cuda::blas::gemm: not implemented for ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype),"at::cuda::blas::gemm: not implemented"); } template <> @@ -66,7 +66,7 @@ void gemm(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)); template inline void gemm_internal(CUDABLAS_GEMM_ARGTYPES(Dtype)) { - AT_ERROR("at::cuda::blas::gemm_internal: not implemented for ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype),"at::cuda::blas::gemm_internal: not implemented"); } template <> @@ -154,7 +154,7 @@ void scaled_gemm( template inline void bgemm(CUDABLAS_BGEMM_ARGTYPES(Dtype)) { - AT_ERROR("at::cuda::blas::bgemm: not implemented for ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype),"at::cuda::blas::bgemm: not implemented"); } template <> @@ -172,7 +172,7 @@ void bgemm(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); template inline void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(Dtype)) { - AT_ERROR("at::cuda::blas::bgemm_internal: not implemented for ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype),"at::cuda::blas::bgemm_internal: not implemented"); } template <> @@ -195,7 +195,7 @@ void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); template inline void trsm(CUDABLAS_TRSM_ARGTYPES(Dtype)) { - TORCH_INTERNAL_ASSERT(false, "at::cuda::blas::trsm: not implemented for ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), "at::cuda::blas::trsm: not implemented"); } template <> @@ -215,10 +215,7 @@ TORCH_CUDA_CU_API void trsm>(CUDABLAS_TRSM_ARGTYPES(c10::co template inline void trsmBatched(CUDABLAS_TRSM_BATCHED_ARGTYPES(Dtype)) { - TORCH_INTERNAL_ASSERT( - false, - "at::cuda::blas::trsmBatched: not implemented for ", - typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), "at::cuda::blas::trsmBatched: not implemented"); } template <> @@ -238,7 +235,7 @@ TORCH_CUDA_CU_API void trsmBatched>(CUDABLAS_TRSM_BATCHED_A template inline void gemv(CUDABLAS_GEMV_ARGTYPES(Dtype)) { - AT_ERROR("at::cuda::blas::gemv: not implemented for ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), "at::cuda::blas::gemv: not implemented"); } template <> @@ -262,7 +259,7 @@ void gemv(CUDABLAS_GEMV_ARGTYPES(at::BFloat16)); template inline void dot(CUDABLAS_DOT_ARGTYPES(Dtype)) { - AT_ERROR("at::cuda::blas::dot: not implemented for ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype),"at::cuda::blas::dot: not implemented"); } template <> @@ -280,7 +277,7 @@ void dot>(CUDABLAS_DOT_ARGTYPES(c10::complex)); template inline void vdot(CUDABLAS_DOT_ARGTYPES(Dtype)) { - AT_ERROR("at::cuda::blas::vdot: not implemented for ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype),"at::cuda::blas::vdot: not implemented"); } template <> @@ -295,8 +292,7 @@ void vdot>(CUDABLAS_DOT_ARGTYPES(c10::complex)); template void getrsBatched(CUDABLAS_GETRS_ARGTYPES(Dtype)) { - TORCH_INTERNAL_ASSERT(false, "at::cuda::blas::getrsBatched: not implemented for ", - typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype),"at::cuda::blas::getrsBatched: not implemented"); } template<> TORCH_CUDA_CU_API void getrsBatched(CUDABLAS_GETRS_ARGTYPES(float)); @@ -313,10 +309,7 @@ TORCH_CUDA_CU_API void getrsBatched>(CUDABLAS_GETRS_ARGTYPE template void geqrfBatched(CUDABLAS_GEQRF_BATCHED_ARGTYPES(Dtype)) { - TORCH_INTERNAL_ASSERT( - false, - "at::cuda::blas::geqrfBatched: not implemented for ", - typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), "at::cuda::blas::geqrfBatched: not implemented"); } template <> TORCH_CUDA_CU_API void geqrfBatched(CUDABLAS_GEQRF_BATCHED_ARGTYPES(float)); @@ -334,7 +327,7 @@ TORCH_CUDA_CU_API void geqrfBatched>( template void getrfBatched(CUDABLAS_GETRF_ARGTYPES(Dtype)) { - TORCH_CHECK(false, "at::cuda::blas::getrfBatched: not implemented for ", typeid(Dtype).name()); + TORCH_CHECK(false, "at::cuda::blas::getrfBatched: not implemented"); } template<> TORCH_CUDA_CU_API void getrfBatched(CUDABLAS_GETRF_ARGTYPES(float)); @@ -350,7 +343,7 @@ TORCH_CUDA_CU_API void getrfBatched>(CUDABLAS_GETRF_ARGTYPES template void gelsBatched(CUDABLAS_GELS_BATCHED_ARGTYPES(Dtype)) { - TORCH_INTERNAL_ASSERT(false, "at::cuda::blas::gelsBatched: not implemented for ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype),"at::cuda::blas::gelsBatched: not implemented"); } template<> diff --git a/aten/src/ATen/cuda/CUDADataType.h b/aten/src/ATen/cuda/CUDADataType.h index 8615bcdae911..1696bb3a0f44 100644 --- a/aten/src/ATen/cuda/CUDADataType.h +++ b/aten/src/ATen/cuda/CUDADataType.h @@ -9,7 +9,8 @@ namespace at::cuda { template cudaDataType getCudaDataType() { - TORCH_INTERNAL_ASSERT(false, "Cannot convert type ", typeid(scalar_t).name(), " to cudaDataType.") + static_assert(false && sizeof(scalar_t), "Cannot convert type to cudaDataType."); + return {}; } template<> inline cudaDataType getCudaDataType() { diff --git a/aten/src/ATen/native/BlasKernel.cpp b/aten/src/ATen/native/BlasKernel.cpp index bc601885b54e..fb4289eb989a 100644 --- a/aten/src/ATen/native/BlasKernel.cpp +++ b/aten/src/ATen/native/BlasKernel.cpp @@ -559,12 +559,14 @@ template inline void scal(int64_t n, scalar_t a, scalar_t *x, int64_t incx) { if (n == 1) incx = 1; +#if AT_BUILD_WITH_BLAS() if (blas_impl::scal_use_fast_path(n, incx)) { int i_n = (int)n; int i_incx = (int)incx; blas_impl::scal_fast_path(&i_n, &a, x, &i_incx); return; } +#endif for (const auto i : c10::irange(n)) { if (a == scalar_t(0)) { x[i * incx] = 0; @@ -578,6 +580,7 @@ template void gemv(char trans, int64_t m, int64_t n, scalar_t alpha, const scalar_t *a, int64_t lda, const scalar_t *x, int64_t incx, scalar_t beta, scalar_t *y, int64_t incy) { if(n == 1) lda = m; +#if AT_BUILD_WITH_BLAS() if (blas_impl::gemv_use_fast_path(m, n, lda, incx, incy)) { TORCH_CHECK(lda >= std::max(1L, m), "lda should be at least max(1,", m, "), but have ", lda); int i_m = (int)m; @@ -588,6 +591,7 @@ void gemv(char trans, int64_t m, int64_t n, scalar_t alpha, const scalar_t *a, i blas_impl::gemv_fast_path(&trans, &i_m, &i_n, &alpha, a, &i_lda, x, &i_incx, &beta, y, &i_incy); return; } +#endif using opmath_t = at::opmath_type; if ((trans == 'T') || (trans == 't')) { diff --git a/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp b/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp index bcfc26c7df7d..95119b5ac085 100644 --- a/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp +++ b/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp @@ -655,7 +655,7 @@ std::pair radix_sort_parallel( const int64_t elements_count, const int64_t max_value) { TORCH_INTERNAL_ASSERT(false, "radix_sort_parallel: ATen not compiled with FBGEMM support"); - std::make_pair(nullptr, nullptr); + return std::make_pair(nullptr, nullptr); } } diff --git a/aten/src/ATen/native/cuda/linalg/CUDASolver.h b/aten/src/ATen/native/cuda/linalg/CUDASolver.h index b8901d1d6f5d..9b17086646d8 100644 --- a/aten/src/ATen/native/cuda/linalg/CUDASolver.h +++ b/aten/src/ATen/native/cuda/linalg/CUDASolver.h @@ -18,7 +18,7 @@ namespace solver { template void getrf(CUDASOLVER_GETRF_ARGTYPES(Dtype)) { - TORCH_CHECK(false, "at::cuda::solver::getrf: not implemented for ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), "at::cuda::solver::getrf: not implemented"); } template<> void getrf(CUDASOLVER_GETRF_ARGTYPES(float)); @@ -35,7 +35,7 @@ void getrf>(CUDASOLVER_GETRF_ARGTYPES(c10::complex)); template void getrs(CUDASOLVER_GETRS_ARGTYPES(Dtype)) { - TORCH_CHECK(false, "at::cuda::solver::getrs: not implemented for ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), "at::cuda::solver::getrs: not implemented"); } template<> void getrs(CUDASOLVER_GETRS_ARGTYPES(float)); @@ -51,10 +51,8 @@ void getrs>(CUDASOLVER_GETRS_ARGTYPES(c10::complex)); template void sytrf_bufferSize(CUDASOLVER_SYTRF_BUFFER_ARGTYPES(Dtype)) { - TORCH_CHECK( - false, - "at::cuda::solver::sytrf_bufferSize: not implemented for ", - typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), + "at::cuda::solver::sytrf_bufferSize: not implemented"); } template <> void sytrf_bufferSize(CUDASOLVER_SYTRF_BUFFER_ARGTYPES(float)); @@ -73,10 +71,8 @@ void sytrf_bufferSize>( template void sytrf(CUDASOLVER_SYTRF_ARGTYPES(Dtype)) { - TORCH_CHECK( - false, - "at::cuda::solver::sytrf: not implemented for ", - typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), + "at::cuda::solver::sytrf: not implemented"); } template <> void sytrf(CUDASOLVER_SYTRF_ARGTYPES(float)); @@ -93,7 +89,7 @@ void sytrf>(CUDASOLVER_SYTRF_ARGTYPES(c10::complex)); template void gesvd_buffersize(CUDASOLVER_GESVD_BUFFERSIZE_ARGTYPES()) { - TORCH_CHECK(false, "at::cuda::solver::gesvd_buffersize: not implemented for ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), "at::cuda::solver::gesvd_buffersize: not implemented"); } template<> void gesvd_buffersize(CUDASOLVER_GESVD_BUFFERSIZE_ARGTYPES()); @@ -111,7 +107,7 @@ void gesvd_buffersize>(CUDASOLVER_GESVD_BUFFERSIZE_ARGTYPES template void gesvd(CUDASOLVER_GESVD_ARGTYPES(Dtype, Vtype)) { - TORCH_CHECK(false, "at::cuda::solver::gesvd: not implemented for ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), "at::cuda::solver::gesvd: not implemented"); } template<> void gesvd(CUDASOLVER_GESVD_ARGTYPES(float, float)); @@ -129,7 +125,7 @@ void gesvd>(CUDASOLVER_GESVD_ARGTYPES(c10::complex, template void gesvdj_buffersize(CUDASOLVER_GESVDJ_BUFFERSIZE_ARGTYPES(Dtype, Vtype)) { - TORCH_INTERNAL_ASSERT(false, "at::cuda::solver::gesvdj_buffersize: not implemented for ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), "at::cuda::solver::gesvdj_buffersize: not implemented"); } template<> void gesvdj_buffersize(CUDASOLVER_GESVDJ_BUFFERSIZE_ARGTYPES(float, float)); @@ -147,7 +143,7 @@ void gesvdj_buffersize>(CUDASOLVER_GESVDJ_BUFFERSIZE_ARGTYP template void gesvdj(CUDASOLVER_GESVDJ_ARGTYPES(Dtype, Vtype)) { - TORCH_INTERNAL_ASSERT(false, "at::cuda::solver::gesvdj: not implemented for ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), "at::cuda::solver::gesvdj: not implemented"); } template<> void gesvdj(CUDASOLVER_GESVDJ_ARGTYPES(float, float)); @@ -165,7 +161,7 @@ void gesvdj>(CUDASOLVER_GESVDJ_ARGTYPES(c10::complex void gesvdjBatched(CUDASOLVER_GESVDJ_BATCHED_ARGTYPES(Dtype, Vtype)) { - TORCH_INTERNAL_ASSERT(false, "at::cuda::solver::gesvdj: not implemented for ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), "at::cuda::solver::gesvdj: not implemented"); } template<> void gesvdjBatched(CUDASOLVER_GESVDJ_BATCHED_ARGTYPES(float, float)); @@ -183,7 +179,7 @@ void gesvdjBatched>(CUDASOLVER_GESVDJ_BATCHED_ARGTYPES(c10: template void gesvdaStridedBatched_buffersize(CUDASOLVER_GESVDA_STRIDED_BATCHED_BUFFERSIZE_ARGTYPES(Dtype, Vtype)) { - TORCH_INTERNAL_ASSERT(false, "at::cuda::solver::gesvdaStridedBatched_buffersize: not implemented for ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), "at::cuda::solver::gesvdaStridedBatched_buffersize: not implemented"); } template<> void gesvdaStridedBatched_buffersize(CUDASOLVER_GESVDA_STRIDED_BATCHED_BUFFERSIZE_ARGTYPES(float, float)); @@ -203,7 +199,7 @@ void gesvdaStridedBatched_buffersize>(CUDASOLVER_GESVDA_STR template void gesvdaStridedBatched(CUDASOLVER_GESVDA_STRIDED_BATCHED_ARGTYPES(Dtype, Vtype)) { - TORCH_INTERNAL_ASSERT(false, "at::cuda::solver::gesvdaStridedBatched: not implemented for ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), "at::cuda::solver::gesvdaStridedBatched: not implemented"); } template<> void gesvdaStridedBatched(CUDASOLVER_GESVDA_STRIDED_BATCHED_ARGTYPES(float, float)); @@ -220,7 +216,7 @@ void gesvdaStridedBatched>(CUDASOLVER_GESVDA_STRIDED_BATCHE template void potrf(CUDASOLVER_POTRF_ARGTYPES(Dtype)) { - TORCH_INTERNAL_ASSERT(false, "at::cuda::solver::potrf: not implemented for ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), "at::cuda::solver::potrf: not implemented"); } template<> void potrf(CUDASOLVER_POTRF_ARGTYPES(float)); @@ -237,7 +233,7 @@ void potrf>(CUDASOLVER_POTRF_ARGTYPES(c10::complex) template void potrf_buffersize(CUDASOLVER_POTRF_BUFFERSIZE_ARGTYPES(Dtype)) { - TORCH_INTERNAL_ASSERT(false, "at::cuda::solver::potrf_buffersize: not implemented for ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), "at::cuda::solver::potrf_buffersize: not implemented"); } template<> void potrf_buffersize(CUDASOLVER_POTRF_BUFFERSIZE_ARGTYPES(float)); @@ -254,7 +250,7 @@ void potrf_buffersize>(CUDASOLVER_POTRF_BUFFERSIZE_ARGTYPES template void potrfBatched(CUDASOLVER_POTRF_BATCHED_ARGTYPES(Dtype)) { - TORCH_INTERNAL_ASSERT(false, "at::cuda::solver::potrfBatched: not implemented for ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), "at::cuda::solver::potrfBatched: not implemented"); } template<> void potrfBatched(CUDASOLVER_POTRF_BATCHED_ARGTYPES(float)); @@ -270,10 +266,8 @@ void potrfBatched>(CUDASOLVER_POTRF_BATCHED_ARGTYPES(c10::c template void geqrf_bufferSize(CUDASOLVER_GEQRF_BUFFERSIZE_ARGTYPES(scalar_t)) { - TORCH_CHECK( - false, - "at::cuda::solver::geqrf_bufferSize: not implemented for ", - typeid(scalar_t).name()); + static_assert(false&&sizeof(scalar_t), + "at::cuda::solver::geqrf_bufferSize: not implemented"); } template <> void geqrf_bufferSize(CUDASOLVER_GEQRF_BUFFERSIZE_ARGTYPES(float)); @@ -292,10 +286,8 @@ void geqrf_bufferSize>( template void geqrf(CUDASOLVER_GEQRF_ARGTYPES(scalar_t)) { - TORCH_CHECK( - false, - "at::cuda::solver::geqrf: not implemented for ", - typeid(scalar_t).name()); + static_assert(false&&sizeof(scalar_t), + "at::cuda::solver::geqrf: not implemented"); } template <> void geqrf(CUDASOLVER_GEQRF_ARGTYPES(float)); @@ -312,7 +304,7 @@ void geqrf>( template void potrs(CUDASOLVER_POTRS_ARGTYPES(Dtype)) { - TORCH_INTERNAL_ASSERT(false, "at::cuda::solver::potrs: not implemented for ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), "at::cuda::solver::potrs: not implemented"); } template<> void potrs(CUDASOLVER_POTRS_ARGTYPES(float)); @@ -329,7 +321,7 @@ void potrs>(CUDASOLVER_POTRS_ARGTYPES(c10::complex) template void potrsBatched(CUDASOLVER_POTRS_BATCHED_ARGTYPES(Dtype)) { - TORCH_INTERNAL_ASSERT(false, "at::cuda::solver::potrsBatched: not implemented for ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), "at::cuda::solver::potrsBatched: not implemented"); } template<> void potrsBatched(CUDASOLVER_POTRS_BATCHED_ARGTYPES(float)); @@ -347,10 +339,7 @@ void potrsBatched>(CUDASOLVER_POTRS_BATCHED_ARGTYPES(c10::c template void orgqr_buffersize(CUDASOLVER_ORGQR_BUFFERSIZE_ARGTYPES(Dtype)) { - TORCH_CHECK( - false, - "at::cuda::solver::orgqr_buffersize: not implemented for ", - typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), "at::cuda::solver::orgqr_buffersize: not implemented"); } template <> void orgqr_buffersize(CUDASOLVER_ORGQR_BUFFERSIZE_ARGTYPES(float)); @@ -368,10 +357,7 @@ void orgqr_buffersize>(CUDASOLVER_ORGQR_BUFFERSIZE_ARGTYPES template void orgqr(CUDASOLVER_ORGQR_ARGTYPES(Dtype)) { - TORCH_CHECK( - false, - "at::cuda::solver::orgqr: not implemented for ", - typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), "at::cuda::solver::orgqr: not implemented"); } template <> void orgqr(CUDASOLVER_ORGQR_ARGTYPES(float)); @@ -389,10 +375,8 @@ void orgqr>(CUDASOLVER_ORGQR_ARGTYPES(c10::complex) template void ormqr_bufferSize(CUDASOLVER_ORMQR_BUFFERSIZE_ARGTYPES(Dtype)) { - TORCH_INTERNAL_ASSERT( - false, - "at::cuda::solver::ormqr_bufferSize: not implemented for ", - typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), + "at::cuda::solver::ormqr_bufferSize: not implemented"); } template <> void ormqr_bufferSize(CUDASOLVER_ORMQR_BUFFERSIZE_ARGTYPES(float)); @@ -412,10 +396,8 @@ void ormqr_bufferSize>( template void ormqr(CUDASOLVER_ORMQR_ARGTYPES(Dtype)) { - TORCH_INTERNAL_ASSERT( - false, - "at::cuda::solver::ormqr: not implemented for ", - typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), + "at::cuda::solver::ormqr: not implemented"); } template <> void ormqr(CUDASOLVER_ORMQR_ARGTYPES(float)); @@ -431,7 +413,8 @@ void ormqr>( template cudaDataType get_cusolver_datatype() { - TORCH_CHECK(false, "cusolver doesn't support data type ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), "cusolver doesn't support data type"); + return {}; } template<> cudaDataType get_cusolver_datatype(); template<> cudaDataType get_cusolver_datatype(); @@ -459,10 +442,8 @@ void xpotrs( template void syevd_bufferSize(CUDASOLVER_SYEVD_BUFFERSIZE_ARGTYPES(scalar_t, value_t)) { - TORCH_INTERNAL_ASSERT( - false, - "at::cuda::solver::syevd_bufferSize: not implemented for ", - typeid(scalar_t).name()); + static_assert(false&&sizeof(scalar_t), + "at::cuda::solver::syevd_bufferSize: not implemented"); } template <> @@ -485,10 +466,8 @@ void syevd_bufferSize, double>( template void syevd(CUDASOLVER_SYEVD_ARGTYPES(scalar_t, value_t)) { - TORCH_INTERNAL_ASSERT( - false, - "at::cuda::solver::syevd: not implemented for ", - typeid(scalar_t).name()); + static_assert(false&&sizeof(scalar_t), + "at::cuda::solver::syevd: not implemented"); } template <> @@ -509,10 +488,8 @@ void syevd, double>( template void syevj_bufferSize(CUDASOLVER_SYEVJ_BUFFERSIZE_ARGTYPES(scalar_t, value_t)) { - TORCH_INTERNAL_ASSERT( - false, - "at::cuda::solver::syevj_bufferSize: not implemented for ", - typeid(scalar_t).name()); + static_assert(false&&sizeof(scalar_t), + "at::cuda::solver::syevj_bufferSize: not implemented"); } template <> @@ -535,10 +512,7 @@ void syevj_bufferSize, double>( template void syevj(CUDASOLVER_SYEVJ_ARGTYPES(scalar_t, value_t)) { - TORCH_INTERNAL_ASSERT( - false, - "at::cuda::solver::syevj: not implemented for ", - typeid(scalar_t).name()); + static_assert(false&&sizeof(scalar_t), "at::cuda::solver::syevj: not implemented"); } template <> @@ -560,10 +534,8 @@ void syevj, double>( template void syevjBatched_bufferSize( CUDASOLVER_SYEVJ_BATCHED_BUFFERSIZE_ARGTYPES(scalar_t, value_t)) { - TORCH_INTERNAL_ASSERT( - false, - "at::cuda::solver::syevjBatched_bufferSize: not implemented for ", - typeid(scalar_t).name()); + static_assert(false&&sizeof(scalar_t), + "at::cuda::solver::syevjBatched_bufferSize: not implemented"); } template <> @@ -586,10 +558,8 @@ void syevjBatched_bufferSize, double>( template void syevjBatched(CUDASOLVER_SYEVJ_BATCHED_ARGTYPES(scalar_t, value_t)) { - TORCH_INTERNAL_ASSERT( - false, - "at::cuda::solver::syevjBatched: not implemented for ", - typeid(scalar_t).name()); + static_assert(false&&sizeof(scalar_t), + "at::cuda::solver::syevjBatched: not implemented"); } template <> @@ -612,10 +582,8 @@ void syevjBatched, double>( template void xgeqrf_bufferSize(CUDASOLVER_XGEQRF_BUFFERSIZE_ARGTYPES(scalar_t)) { - TORCH_INTERNAL_ASSERT( - false, - "at::cuda::solver::xgeqrf_bufferSize: not implemented for ", - typeid(scalar_t).name()); + static_assert(false&&sizeof(scalar_t), + "at::cuda::solver::xgeqrf_bufferSize: not implemented"); } template <> @@ -637,10 +605,7 @@ void xgeqrf_bufferSize>( template void xgeqrf(CUDASOLVER_XGEQRF_ARGTYPES(scalar_t)) { - TORCH_INTERNAL_ASSERT( - false, - "at::cuda::solver::xgeqrf: not implemented for ", - typeid(scalar_t).name()); + static_assert(false&&sizeof(scalar_t), "at::cuda::solver::xgeqrf: not implemented"); } template <> @@ -663,10 +628,8 @@ void xgeqrf>( template void xsyevd_bufferSize( CUDASOLVER_XSYEVD_BUFFERSIZE_ARGTYPES(scalar_t, value_t)) { - TORCH_INTERNAL_ASSERT( - false, - "at::cuda::solver::xsyevd_bufferSize: not implemented for ", - typeid(scalar_t).name()); + static_assert(false&&sizeof(scalar_t), + "at::cuda::solver::xsyevd_bufferSize: not implemented"); } template <> @@ -691,10 +654,8 @@ void xsyevd_bufferSize, double>( template void xsyevd(CUDASOLVER_XSYEVD_ARGTYPES(scalar_t, value_t)) { - TORCH_INTERNAL_ASSERT( - false, - "at::cuda::solver::xsyevd: not implemented for ", - typeid(scalar_t).name()); + static_assert(false&&sizeof(scalar_t), + "at::cuda::solver::xsyevd: not implemented"); } template <> diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp b/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp index f3aabc63e2a2..d5f654097677 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp +++ b/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp @@ -204,7 +204,7 @@ void csrmm2( T alpha, T *csrvala, int *csrrowptra, int *csrcolinda, T *b, int64_t ldb, T beta, T *c, int64_t ldc) { - TORCH_INTERNAL_ASSERT(false, "cusparse csr MM only supports data type of float, double, cfloat and cdouble."); + static_assert(false&&sizeof(T), "cusparse csr MM only supports data type of float, double, cfloat and cdouble."); } template<> void csrmm2( @@ -381,7 +381,7 @@ void csrmm2( T alpha, T *csrvala, int *csrrowptra, int *csrcolinda, T *b, int64_t ldb, T beta, T *c, int64_t ldc) { - TORCH_INTERNAL_ASSERT(false, "cusparse csr MM only supports data type of float, double, cfloat and cdouble."); + static_assert(false&&sizeof(T), "cusparse csr MM only supports data type of float, double, cfloat and cdouble."); } template<> void csrmm2( diff --git a/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu b/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu index 20af0ee866a5..88c3ee05ab53 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu @@ -387,7 +387,7 @@ struct CusparseMatrixMultiplyOp { Tensor &output_values, Tensor &output_indices) { - TORCH_INTERNAL_ASSERT(false, "cusparse csr sparse-sparse MM only supports data type of float and double."); + static_assert(false&&sizeof(scalar_t), "cusparse csr sparse-sparse MM only supports data type of float and double."); } }; From e8e327ba823374640cd7a64ad26dae618eaabc92 Mon Sep 17 00:00:00 2001 From: cyy Date: Wed, 29 May 2024 17:05:25 +0000 Subject: [PATCH 041/706] Cover clang-tidy to torch/csrc/onnx/init.cpp (#127393) Enabling it will not cause issues. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127393 Approved by: https://github.com/Skylion007 --- .lintrunner.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index 2dc4305d9ab9..1e0a2f37fcf4 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -235,7 +235,6 @@ exclude_patterns = [ 'torch/csrc/jit/serialization/import_legacy.cpp', 'torch/csrc/jit/serialization/export.cpp', 'torch/csrc/lazy/**/*', - 'torch/csrc/onnx/init.cpp', 'torch/csrc/mps/**/*', ] init_command = [ From cc6e72d8822d0b6ac05a9334c1db295177ee52ea Mon Sep 17 00:00:00 2001 From: Richard Barnes Date: Wed, 29 May 2024 17:11:45 +0000 Subject: [PATCH 042/706] Drop caffe2 core tests and some other stuff (#127089) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/127089 Approved by: https://github.com/Skylion007 --- caffe2/README.md | 19 - caffe2/VERSION_NUMBER | 1 - caffe2/core/blob_gpu_test.cc | 227 ---- caffe2/core/blob_test.cc | 1306 ----------------------- caffe2/core/context_gpu_test.cc | 161 --- caffe2/core/context_test.cc | 38 - caffe2/core/event_gpu_test.cc | 50 - caffe2/core/event_test.cc | 41 - caffe2/core/graph_test.cc | 200 ---- caffe2/core/init_test.cc | 72 -- caffe2/core/module_test.cc | 78 -- caffe2/core/net_async_tracing_test.cc | 114 -- caffe2/core/net_dag_utils_test.cc | 296 ----- caffe2/core/net_gpu_test.cc | 130 --- caffe2/core/net_simple_refcount_test.cc | 70 -- caffe2/core/net_test.cc | 1122 ------------------- caffe2/core/observer_test.cc | 183 ---- caffe2/core/operator_gpu_test.cc | 63 -- caffe2/core/operator_schema_test.cc | 279 ----- caffe2/core/operator_test.cc | 634 ----------- caffe2/core/parallel_net_test.cc | 322 ------ caffe2/core/plan_executor_test.cc | 414 ------- caffe2/core/serialization_test.cc | 101 -- caffe2/core/stats_test.cc | 151 --- caffe2/core/timer_test.cc | 65 -- caffe2/core/transform_test.cc | 460 -------- caffe2/core/workspace_test.cc | 149 --- caffe2/release-notes.md | 175 --- caffe2/requirements.txt | 4 - 29 files changed, 6925 deletions(-) delete mode 100644 caffe2/README.md delete mode 100644 caffe2/VERSION_NUMBER delete mode 100644 caffe2/core/blob_gpu_test.cc delete mode 100644 caffe2/core/blob_test.cc delete mode 100644 caffe2/core/context_gpu_test.cc delete mode 100644 caffe2/core/context_test.cc delete mode 100644 caffe2/core/event_gpu_test.cc delete mode 100644 caffe2/core/event_test.cc delete mode 100644 caffe2/core/graph_test.cc delete mode 100644 caffe2/core/init_test.cc delete mode 100644 caffe2/core/module_test.cc delete mode 100644 caffe2/core/net_async_tracing_test.cc delete mode 100644 caffe2/core/net_dag_utils_test.cc delete mode 100644 caffe2/core/net_gpu_test.cc delete mode 100644 caffe2/core/net_simple_refcount_test.cc delete mode 100644 caffe2/core/net_test.cc delete mode 100644 caffe2/core/observer_test.cc delete mode 100644 caffe2/core/operator_gpu_test.cc delete mode 100644 caffe2/core/operator_schema_test.cc delete mode 100644 caffe2/core/operator_test.cc delete mode 100644 caffe2/core/parallel_net_test.cc delete mode 100644 caffe2/core/plan_executor_test.cc delete mode 100644 caffe2/core/serialization_test.cc delete mode 100644 caffe2/core/stats_test.cc delete mode 100644 caffe2/core/timer_test.cc delete mode 100644 caffe2/core/transform_test.cc delete mode 100644 caffe2/core/workspace_test.cc delete mode 100644 caffe2/release-notes.md delete mode 100644 caffe2/requirements.txt diff --git a/caffe2/README.md b/caffe2/README.md deleted file mode 100644 index 13171fca23bb..000000000000 --- a/caffe2/README.md +++ /dev/null @@ -1,19 +0,0 @@ -# Caffe2 - -Caffe2 is a lightweight, modular, and scalable deep learning framework. Building on the original [Caffe](http://caffe.berkeleyvision.org), Caffe2 is designed with expression, speed, and modularity in mind. - -## Questions and Feedback - -Please use GitHub issues (https://github.com/pytorch/pytorch/issues) to ask questions, report bugs, and request new features. - -### Further Resources on [Caffe2.ai](http://caffe2.ai) - -* [Installation](http://caffe2.ai/docs/getting-started.html) -* [Learn More](http://caffe2.ai/docs/learn-more.html) -* [Upgrading to Caffe2](http://caffe2.ai/docs/caffe-migration.html) -* [Datasets](http://caffe2.ai/docs/datasets.html) -* [Model Zoo](http://caffe2.ai/docs/zoo.html) -* [Tutorials](http://caffe2.ai/docs/tutorials.html) -* [Operators Catalogue](http://caffe2.ai/docs/operators-catalogue.html) -* [C++ API](http://caffe2.ai/doxygen-c/html/classes.html) -* [Python API](http://caffe2.ai/doxygen-python/html/namespaces.html) diff --git a/caffe2/VERSION_NUMBER b/caffe2/VERSION_NUMBER deleted file mode 100644 index 100435be135a..000000000000 --- a/caffe2/VERSION_NUMBER +++ /dev/null @@ -1 +0,0 @@ -0.8.2 diff --git a/caffe2/core/blob_gpu_test.cc b/caffe2/core/blob_gpu_test.cc deleted file mode 100644 index de6ea99c0395..000000000000 --- a/caffe2/core/blob_gpu_test.cc +++ /dev/null @@ -1,227 +0,0 @@ -#include // NOLINT - -#include -#include "caffe2/core/blob.h" -#include "caffe2/core/blob_serialization.h" -#include "caffe2/core/common_gpu.h" -#include "caffe2/core/context_gpu.h" -#include "caffe2/proto/caffe2_pb.h" - -namespace caffe2 { -namespace { - -template class TensorGPUTest : public ::testing::Test {}; -template class TensorGPUDeathTest : public ::testing::Test {}; -typedef ::testing::Types TensorTypes; -TYPED_TEST_CASE(TensorGPUTest, TensorTypes); -TYPED_TEST_CASE(TensorGPUDeathTest, TensorTypes); - -TYPED_TEST(TensorGPUTest, TensorInitializedEmpty) { - if (!caffe2::HasCudaGPU()) return; - Tensor tensor(CUDA); - EXPECT_EQ(tensor.numel(), 0); - EXPECT_EQ(tensor.dim(), 1); - vector dims(3); - dims[0] = 2; - dims[1] = 3; - dims[2] = 5; - tensor.Resize(dims); - EXPECT_EQ(tensor.dim(), 3); - EXPECT_EQ(tensor.dim32(0), 2); - EXPECT_EQ(tensor.dim32(1), 3); - EXPECT_EQ(tensor.dim32(2), 5); - EXPECT_TRUE(tensor.mutable_data() != nullptr); - EXPECT_TRUE(tensor.data() != nullptr); -} - -TYPED_TEST(TensorGPUTest, TensorInitializedNonEmpty) { - if (!HasCudaGPU()) return; - vector dims(3); - dims[0] = 2; - dims[1] = 3; - dims[2] = 5; - Tensor tensor(dims, CUDA); - EXPECT_EQ(tensor.dim(), 3); - EXPECT_EQ(tensor.dim32(0), 2); - EXPECT_EQ(tensor.dim32(1), 3); - EXPECT_EQ(tensor.dim32(2), 5); - EXPECT_TRUE(tensor.mutable_data() != nullptr); - EXPECT_TRUE(tensor.data() != nullptr); - dims[0] = 7; - dims[1] = 11; - dims[2] = 13; - dims.push_back(17); - tensor.Resize(dims); - EXPECT_EQ(tensor.dim(), 4); - EXPECT_EQ(tensor.dim32(0), 7); - EXPECT_EQ(tensor.dim32(1), 11); - EXPECT_EQ(tensor.dim32(2), 13); - EXPECT_EQ(tensor.dim32(3), 17); - EXPECT_TRUE(tensor.mutable_data() != nullptr); - EXPECT_TRUE(tensor.data() != nullptr); -} - -TYPED_TEST(TensorGPUTest, TensorAlias) { - if (!HasCudaGPU()) return; - vector dims(3); - dims[0] = 2; - dims[1] = 3; - dims[2] = 5; - Tensor tensor(dims, CUDA); - EXPECT_TRUE(tensor.mutable_data() != nullptr); - Tensor other_tensor = tensor.Alias(); - EXPECT_TRUE(tensor.data() != nullptr); - EXPECT_TRUE(other_tensor.data() != nullptr); - EXPECT_EQ(tensor.data(), other_tensor.data()); -} - -TYPED_TEST(TensorGPUTest, TensorAliasCanUseDifferentShapes) { - if (!HasCudaGPU()) return; - vector dims(3); - dims[0] = 2; - dims[1] = 3; - dims[2] = 5; - vector alternate_dims(1); - alternate_dims[0] = 2 * 3 * 5; - Tensor tensor(dims, CUDA); - EXPECT_TRUE(tensor.mutable_data() != nullptr); - Tensor other_tensor = tensor.Alias(); - other_tensor.Resize(alternate_dims); - EXPECT_EQ(other_tensor.dim(), 1); - EXPECT_EQ(other_tensor.dim32(0), alternate_dims[0]); - EXPECT_TRUE(tensor.data() != nullptr); - EXPECT_TRUE(other_tensor.data() != nullptr); - EXPECT_EQ(tensor.data(), other_tensor.data()); -} - -TYPED_TEST(TensorGPUTest, NoLongerAliasAfterNumelChanges) { - if (!HasCudaGPU()) return; - vector dims(3); - dims[0] = 2; - dims[1] = 3; - dims[2] = 5; - Tensor tensor(dims, CUDA); - EXPECT_TRUE(tensor.mutable_data() != nullptr); - Tensor other_tensor = tensor.Alias(); - EXPECT_EQ(tensor.data(), other_tensor.data()); - auto* old_pointer = other_tensor.data(); - - dims[0] = 7; - tensor.Resize(dims); - EXPECT_EQ(old_pointer, other_tensor.data()); - EXPECT_NE(old_pointer, tensor.mutable_data()); -} - -TYPED_TEST(TensorGPUDeathTest, CannotAccessDataWhenEmpty) { - if (!HasCudaGPU()) return; - ::testing::FLAGS_gtest_death_test_style = "threadsafe"; - Tensor tensor(CUDA); - EXPECT_EQ(tensor.dim(), 1); - EXPECT_EQ(tensor.numel(), 0); - EXPECT_THROW(tensor.data(), EnforceNotMet); -} - -#define TEST_SERIALIZATION_GPU_WITH_TYPE(TypeParam, field_name) \ - TEST(TensorGPUTest, TensorSerialization_##TypeParam) { \ - if (!HasCudaGPU()) { \ - return; \ - } \ - Blob blob; \ - Tensor cpu_tensor(CPU); \ - cpu_tensor.Resize(2, 3); \ - for (int i = 0; i < 6; ++i) { \ - cpu_tensor.mutable_data()[i] = static_cast(i); \ - } \ - BlobGetMutableTensor(&blob, CUDA)->CopyFrom(cpu_tensor); \ - string serialized = SerializeBlob(blob, "test"); \ - BlobProto proto; \ - CAFFE_ENFORCE(proto.ParseFromString(serialized)); \ - EXPECT_EQ(proto.name(), "test"); \ - EXPECT_EQ(proto.type(), "Tensor"); \ - EXPECT_TRUE(proto.has_tensor()); \ - const TensorProto& tensor_proto = proto.tensor(); \ - EXPECT_EQ( \ - tensor_proto.data_type(), \ - TypeMetaToDataType(TypeMeta::Make())); \ - EXPECT_EQ(tensor_proto.field_name##_size(), 6); \ - for (int i = 0; i < 6; ++i) { \ - EXPECT_EQ(tensor_proto.field_name(i), static_cast(i)); \ - } \ - Blob new_blob; \ - EXPECT_NO_THROW(DeserializeBlob(serialized, &new_blob)); \ - EXPECT_TRUE(BlobIsTensorType(new_blob, CUDA)); \ - Tensor new_cpu_tensor(blob.Get(), CPU); \ - EXPECT_EQ(new_cpu_tensor.dim(), 2); \ - EXPECT_EQ(new_cpu_tensor.size(0), 2); \ - EXPECT_EQ(new_cpu_tensor.size(1), 3); \ - for (int i = 0; i < 6; ++i) { \ - EXPECT_EQ( \ - cpu_tensor.data()[i], \ - new_cpu_tensor.data()[i]); \ - } \ - } - -TEST_SERIALIZATION_GPU_WITH_TYPE(bool, int32_data) -TEST_SERIALIZATION_GPU_WITH_TYPE(double, double_data) -TEST_SERIALIZATION_GPU_WITH_TYPE(float, float_data) -TEST_SERIALIZATION_GPU_WITH_TYPE(int, int32_data) -TEST_SERIALIZATION_GPU_WITH_TYPE(int8_t, int32_data) -TEST_SERIALIZATION_GPU_WITH_TYPE(int16_t, int32_data) -TEST_SERIALIZATION_GPU_WITH_TYPE(uint8_t, int32_data) -TEST_SERIALIZATION_GPU_WITH_TYPE(uint16_t, int32_data) -TEST_SERIALIZATION_GPU_WITH_TYPE(int64_t, int64_data) - -TEST(TensorConstruction, ReinitializeTensorTest) { - if (!HasCudaGPU()) return; - Tensor x = caffe2::empty({1}, at::dtype().device(CUDA, 0)); - auto* data_before = x.template mutable_data(); - // We'll only compare device_type in ReinitializeTensor, - // so no tensor reallocation will happen here - ReinitializeTensor(&x, {1}, at::dtype().device(CUDA)); - auto* data_after = x.template mutable_data(); - EXPECT_EQ(data_before, data_after); -} - -TEST(TensorTest, TensorSerializationMultiDevices) { - Blob blob; - Tensor tensor(CPU); - tensor.Resize(2, 3); - for (int i = 0; i < 6; ++i) { - tensor.mutable_data()[i] = i; - } - for (int gpu_id = 0; gpu_id < NumCudaDevices(); ++gpu_id) { - CUDAGuard guard(gpu_id); - CUDAContext context(gpu_id); // switch to the current gpu - blob.Reset(new Tensor(tensor, CUDA)); - string serialized = SerializeBlob(blob, "test"); - BlobProto proto; - CAFFE_ENFORCE(proto.ParseFromString(serialized)); - EXPECT_EQ(proto.name(), "test"); - EXPECT_TRUE(proto.has_tensor()); - const TensorProto& tensor_proto = proto.tensor(); - EXPECT_EQ(tensor_proto.data_type(), TensorProto::FLOAT); - EXPECT_EQ(tensor_proto.float_data_size(), 6); - for (int i = 0; i < 6; ++i) { - EXPECT_EQ(tensor_proto.float_data(i), i); - } - EXPECT_TRUE(tensor_proto.has_device_detail()); - EXPECT_EQ(tensor_proto.device_detail().device_type(), PROTO_CUDA); - EXPECT_EQ(tensor_proto.device_detail().device_id(), gpu_id); - // Test if the restored blob is still of the same device. - blob.Reset(); - EXPECT_NO_THROW(DeserializeBlob(serialized, &blob)); - EXPECT_TRUE(BlobIsTensorType(blob, CUDA)); - EXPECT_EQ(GetGPUIDForPointer(blob.Get().data()), - gpu_id); - // Test if we force the restored blob on a different device, we - // can still get so. - blob.Reset(); - proto.mutable_tensor()->mutable_device_detail()->set_device_id(0); - EXPECT_NO_THROW(DeserializeBlob(proto.SerializeAsString(), &blob)); - EXPECT_TRUE(BlobIsTensorType(blob, CUDA)); - EXPECT_EQ(GetGPUIDForPointer(blob.Get().data()), 0); - } -} - -} // namespace -} // namespace caffe2 diff --git a/caffe2/core/blob_test.cc b/caffe2/core/blob_test.cc deleted file mode 100644 index a7e3a8d27e23..000000000000 --- a/caffe2/core/blob_test.cc +++ /dev/null @@ -1,1306 +0,0 @@ -#include -#include -#include - -#include -#include "c10/util/Registry.h" -#include "caffe2/core/blob.h" -#include "caffe2/core/blob_serialization.h" -#include "caffe2/core/common.h" -#include "caffe2/core/context.h" -#include "caffe2/core/db.h" -#include "caffe2/core/operator.h" -#include "caffe2/core/qtensor.h" -#include "caffe2/core/qtensor_serialization.h" -#include "caffe2/core/tensor.h" -#include "caffe2/core/test_utils.h" -#include "caffe2/core/types.h" -#include "caffe2/core/workspace.h" -#include "caffe2/proto/caffe2_pb.h" -#include "caffe2/utils/proto_utils.h" - -C10_DEFINE_int64(caffe2_test_big_tensor_size, 100000000, ""); -C10_DECLARE_int(caffe2_tensor_chunk_size); -C10_DECLARE_bool(caffe2_serialize_fp16_as_bytes); -C10_DECLARE_bool(caffe2_serialize_using_bytes_as_holder); - -namespace caffe2 { -using namespace ::caffe2::db; -namespace { -class BlobTestFoo { - public: - int32_t val; -}; -class BlobTestBar {}; -class BlobTestNonDefaultConstructible { - public: - BlobTestNonDefaultConstructible() = delete; - BlobTestNonDefaultConstructible(int x) : val(x) {} - int32_t val; -}; -} // namespace - -CAFFE_KNOWN_TYPE_NOEXPORT(BlobTestFoo); -CAFFE_KNOWN_TYPE_NOEXPORT(BlobTestBar); -CAFFE_KNOWN_TYPE_NOEXPORT(BlobTestNonDefaultConstructible); - -class BlobTestFooSerializer : public BlobSerializerBase { - public: - // NOLINTNEXTLINE(modernize-use-equals-default) - BlobTestFooSerializer() {} - // NOLINTNEXTLINE(modernize-use-equals-default) - ~BlobTestFooSerializer() override {} - /** - * Serializes a Blob. Note that this blob has to contain Tensor, - * otherwise this function produces a fatal error. - */ - void Serialize( - const void* pointer, - TypeMeta typeMeta, - const string& name, - SerializationAcceptor acceptor) override { - CAFFE_ENFORCE(typeMeta.Match()); - - BlobProto blob_proto; - blob_proto.set_name(name); - blob_proto.set_type("BlobTestFoo"); - // For simplicity we will just serialize the 4-byte content as a string. - blob_proto.set_content(std::string( - reinterpret_cast( - &static_cast(pointer)->val), - sizeof(int32_t))); - acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto)); - } -}; - -class BlobTestFooDeserializer : public BlobDeserializerBase { - public: - void Deserialize(const BlobProto& proto, Blob* blob) override { - blob->GetMutable()->val = - reinterpret_cast(proto.content().c_str())[0]; - } -}; - -REGISTER_BLOB_SERIALIZER((TypeMeta::Id()), BlobTestFooSerializer); -REGISTER_BLOB_DESERIALIZER(BlobTestFoo, BlobTestFooDeserializer); - -namespace { - -TEST(BlobTest, Blob) { - Blob blob; - - int* int_unused CAFFE2_UNUSED = blob.GetMutable(); - EXPECT_TRUE(blob.IsType()); - EXPECT_FALSE(blob.IsType()); - EXPECT_FALSE(BlobIsTensorType(blob, CPU)); - - BlobTestFoo* foo_unused CAFFE2_UNUSED = blob.GetMutable(); - EXPECT_TRUE(blob.IsType()); - EXPECT_FALSE(blob.IsType()); - EXPECT_FALSE(BlobIsTensorType(blob, CPU)); - - Tensor* tensor_unused CAFFE2_UNUSED = BlobGetMutableTensor(&blob, CPU); - EXPECT_TRUE(BlobIsTensorType(blob, CPU)); - EXPECT_FALSE(blob.IsType()); - EXPECT_FALSE(blob.IsType()); -} - -TEST(BlobTest, BlobUninitialized) { - Blob blob; - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - ASSERT_THROW(blob.Get(), EnforceNotMet); -} - -TEST(BlobTest, BlobWrongType) { - Blob blob; - BlobTestFoo* foo_unused CAFFE2_UNUSED = blob.GetMutable(); - EXPECT_TRUE(blob.IsType()); - EXPECT_FALSE(blob.IsType()); - // When not null, we should only call with the right type. - EXPECT_NE(&blob.Get(), nullptr); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - ASSERT_THROW(blob.Get(), EnforceNotMet); -} - -TEST(BlobTest, BlobReset) { - Blob blob; - std::unique_ptr foo(new BlobTestFoo()); - EXPECT_TRUE(blob.Reset(foo.release()) != nullptr); - // Also test that Reset works. - blob.Reset(); -} - -TEST(BlobTest, BlobMove) { - Blob blob1; - std::unique_ptr foo(new BlobTestFoo()); - auto* fooPtr = foo.get(); - EXPECT_TRUE(blob1.Reset(foo.release()) != nullptr); - Blob blob2; - blob2 = std::move(blob1); - // NOLINTNEXTLINE(bugprone-use-after-move,hicpp-avoid-goto,clang-analyzer-cplusplus.Move,cppcoreguidelines-avoid-goto) - ASSERT_THROW(blob1.Get(), EnforceNotMet); - EXPECT_EQ(&blob2.Get(), fooPtr); - Blob blob3{std::move(blob2)}; - EXPECT_EQ(&blob3.Get(), fooPtr); -} - -TEST(BlobTest, BlobNonConstructible) { - Blob blob; - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - ASSERT_THROW(blob.Get(), EnforceNotMet); - // won't work because it's not default constructible - // blob.GetMutable(); - EXPECT_FALSE( - blob.GetMutableOrNull() != nullptr); - EXPECT_TRUE(blob.Reset(new BlobTestNonDefaultConstructible(42)) != nullptr); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - ASSERT_NO_THROW(blob.Get()); - ASSERT_TRUE( - blob.GetMutableOrNull() != nullptr); - EXPECT_EQ(blob.Get().val, 42); - blob.GetMutableOrNull()->val = 37; - EXPECT_EQ(blob.Get().val, 37); -} - -TEST(BlobTest, BlobShareExternalPointer) { - Blob blob; - std::unique_ptr foo(new BlobTestFoo()); - EXPECT_EQ(blob.ShareExternal(foo.get()), foo.get()); - EXPECT_TRUE(blob.IsType()); - // Also test that Reset works. - blob.Reset(); -} - -TEST(BlobTest, BlobShareExternalObject) { - Blob blob; - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - BlobTestFoo foo; - EXPECT_EQ(blob.ShareExternal(&foo), &foo); - EXPECT_TRUE(blob.IsType()); - // Also test that Reset works. - blob.Reset(); -} - -TEST(BlobTest, StringSerialization) { - const std::string kTestString = "Hello world?"; - Blob blob; - *blob.GetMutable() = kTestString; - - string serialized = SerializeBlob(blob, "test"); - BlobProto proto; - CHECK(proto.ParseFromString(serialized)); - EXPECT_EQ(proto.name(), "test"); - EXPECT_EQ(proto.type(), "std::string"); - EXPECT_FALSE(proto.has_tensor()); - EXPECT_EQ(proto.content(), kTestString); -} - -TEST(TensorNonTypedTest, TensorChangeType) { - vector dims(3); - dims[0] = 2; - dims[1] = 3; - dims[2] = 5; - Tensor tensor(dims, CPU); - - auto* ptr = tensor.mutable_data(); - EXPECT_TRUE(ptr != nullptr); - EXPECT_TRUE(tensor.data() != nullptr); - EXPECT_TRUE(tensor.dtype().Match()); - - // int and float are same size, so should retain the pointer - // NB: this is only true when the use_count of the underlying Storage is 1, if - // the underlying Storage is shared between multiple Tensors We'll create a - // new Storage when the data type changes - EXPECT_TRUE(tensor.mutable_data() == (float*)ptr); - EXPECT_TRUE(tensor.data() == (const float*)ptr); - EXPECT_TRUE(tensor.dtype().Match()); - - // at::Half is smaller, so still should share buffer - EXPECT_TRUE(tensor.mutable_data() == (at::Half*)ptr); - EXPECT_TRUE(tensor.data() == (const at::Half*)ptr); - EXPECT_TRUE(tensor.dtype().Match()); - - // share the data with other tensor so that the pointer won't be reused - // when we reallocate - Tensor other_tensor = tensor.Alias(); - // but double is bigger, so it should allocate a new one - auto* doubleptr = tensor.mutable_data(); - EXPECT_TRUE(doubleptr != (double*)ptr); - EXPECT_TRUE(doubleptr != nullptr); - EXPECT_TRUE(tensor.data() != nullptr); - EXPECT_TRUE(tensor.dtype().Match()); -} - -TEST(TensorNonTypedTest, NonDefaultConstructible) { - vector dims(3); - dims[0] = 2; - dims[1] = 3; - dims[2] = 5; - Tensor tensor(dims, CPU); - - // this doesn't compile - good! - // auto* ptr = tensor.mutable_data(); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - EXPECT_THROW( - tensor.raw_mutable_data( - TypeMeta::Make()), - EnforceNotMet); -} - -template -class TensorCPUTest : public ::testing::Test {}; -template -class TensorCPUDeathTest : public ::testing::Test {}; -typedef ::testing::Types TensorTypes; -TYPED_TEST_CASE(TensorCPUTest, TensorTypes); -TYPED_TEST_CASE(TensorCPUDeathTest, TensorTypes); - -TYPED_TEST(TensorCPUTest, TensorInitializedEmpty) { - Tensor tensor(CPU); - EXPECT_EQ(tensor.dim(), 1); - EXPECT_EQ(tensor.numel(), 0); - vector dims(3); - dims[0] = 2; - dims[1] = 3; - dims[2] = 5; - tensor.Resize(dims); - EXPECT_EQ(tensor.dim(), 3); - EXPECT_EQ(tensor.dim32(0), 2); - EXPECT_EQ(tensor.dim32(1), 3); - EXPECT_EQ(tensor.dim32(2), 5); - EXPECT_EQ(tensor.numel(), 2 * 3 * 5); - EXPECT_TRUE(tensor.mutable_data() != nullptr); - EXPECT_TRUE(tensor.data() != nullptr); -} - -TYPED_TEST(TensorCPUTest, TensorInitializedNonEmpty) { - vector dims(3); - dims[0] = 2; - dims[1] = 3; - dims[2] = 5; - Tensor tensor(dims, CPU); - EXPECT_EQ(tensor.dim(), 3); - EXPECT_EQ(tensor.dim32(0), 2); - EXPECT_EQ(tensor.dim32(1), 3); - EXPECT_EQ(tensor.dim32(2), 5); - EXPECT_TRUE(tensor.mutable_data() != nullptr); - EXPECT_TRUE(tensor.data() != nullptr); - dims[0] = 7; - dims[1] = 11; - dims[2] = 13; - dims.push_back(17); - tensor.Resize(dims); - EXPECT_EQ(tensor.dim(), 4); - EXPECT_EQ(tensor.dim32(0), 7); - EXPECT_EQ(tensor.dim32(1), 11); - EXPECT_EQ(tensor.dim32(2), 13); - EXPECT_EQ(tensor.dim32(3), 17); - EXPECT_TRUE(tensor.mutable_data() != nullptr); - EXPECT_TRUE(tensor.data() != nullptr); -} - -TYPED_TEST(TensorCPUTest, TensorInitializedZeroDim) { - vector dims(3); - dims[0] = 2; - dims[1] = 0; - dims[2] = 5; - Tensor tensor(dims, CPU); - EXPECT_EQ(tensor.dim(), 3); - EXPECT_EQ(tensor.dim32(0), 2); - EXPECT_EQ(tensor.dim32(1), 0); - EXPECT_EQ(tensor.dim32(2), 5); - EXPECT_TRUE(tensor.mutable_data() == nullptr); - EXPECT_TRUE(tensor.data() == nullptr); -} - -TYPED_TEST(TensorCPUTest, TensorResizeZeroDim) { - vector dims(3); - dims[0] = 2; - dims[1] = 3; - dims[2] = 5; - Tensor tensor(dims, CPU); - EXPECT_EQ(tensor.dim(), 3); - EXPECT_EQ(tensor.dim32(0), 2); - EXPECT_EQ(tensor.dim32(1), 3); - EXPECT_EQ(tensor.dim32(2), 5); - EXPECT_TRUE(tensor.mutable_data() != nullptr); - EXPECT_TRUE(tensor.data() != nullptr); - - dims[0] = 7; - dims[1] = 0; - dims[2] = 13; - tensor.Resize(dims); - EXPECT_EQ(tensor.numel(), 0); - EXPECT_EQ(tensor.dim(), 3); - EXPECT_EQ(tensor.dim32(0), 7); - EXPECT_EQ(tensor.dim32(1), 0); - EXPECT_EQ(tensor.dim32(2), 13); - // output value can be arbitrary, but the call to data() shouldn't crash - tensor.mutable_data(); - tensor.data(); -} - -TYPED_TEST(TensorCPUTest, TensorInitializedScalar) { - vector dims; - Tensor tensor(dims, CPU); - EXPECT_EQ(tensor.dim(), 0); - EXPECT_EQ(tensor.numel(), 1); - EXPECT_TRUE(tensor.mutable_data() != nullptr); - EXPECT_TRUE(tensor.data() != nullptr); -} - -TYPED_TEST(TensorCPUTest, TensorAlias) { - vector dims(3); - dims[0] = 2; - dims[1] = 3; - dims[2] = 5; - Tensor tensor(dims, CPU); - EXPECT_TRUE(tensor.mutable_data() != nullptr); - Tensor other_tensor = tensor.Alias(); - EXPECT_TRUE(tensor.data() != nullptr); - EXPECT_TRUE(other_tensor.data() != nullptr); - EXPECT_EQ(tensor.data(), other_tensor.data()); - // Set one value, check the other - for (int i = 0; i < tensor.numel(); ++i) { - tensor.mutable_data()[i] = i; - EXPECT_EQ(other_tensor.data()[i], i); - } -} - -TYPED_TEST(TensorCPUTest, TensorShareDataRawPointer) { - vector dims(3); - dims[0] = 2; - dims[1] = 3; - dims[2] = 5; - // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays) - std::unique_ptr raw_buffer(new TypeParam[2 * 3 * 5]); - Tensor tensor(dims, CPU); - tensor.ShareExternalPointer(raw_buffer.get()); - EXPECT_EQ(tensor.mutable_data(), raw_buffer.get()); - EXPECT_EQ(tensor.data(), raw_buffer.get()); - // Set one value, check the other - for (int i = 0; i < tensor.numel(); ++i) { - raw_buffer.get()[i] = i; - EXPECT_EQ(tensor.data()[i], i); - } -} - -TYPED_TEST(TensorCPUTest, TensorShareDataRawPointerWithMeta) { - vector dims(3); - dims[0] = 2; - dims[1] = 3; - dims[2] = 5; - // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays) - std::unique_ptr raw_buffer(new TypeParam[2 * 3 * 5]); - Tensor tensor(dims, CPU); - TypeMeta meta = TypeMeta::Make(); - tensor.ShareExternalPointer(raw_buffer.get(), meta); - EXPECT_EQ(tensor.mutable_data(), raw_buffer.get()); - EXPECT_EQ(tensor.data(), raw_buffer.get()); - // Set one value, check the other - for (int i = 0; i < tensor.numel(); ++i) { - raw_buffer.get()[i] = i; - EXPECT_EQ(tensor.data()[i], i); - } -} - -TYPED_TEST(TensorCPUTest, TensorAliasCanUseDifferentShapes) { - vector dims(3); - dims[0] = 2; - dims[1] = 3; - dims[2] = 5; - vector alternate_dims(1); - alternate_dims[0] = 2 * 3 * 5; - Tensor tensor(dims, CPU); - EXPECT_TRUE(tensor.mutable_data() != nullptr); - Tensor other_tensor = tensor.Alias(); - other_tensor.Resize(alternate_dims); - EXPECT_EQ(other_tensor.dim(), 1); - EXPECT_EQ(other_tensor.dim32(0), alternate_dims[0]); - EXPECT_TRUE(tensor.data() != nullptr); - EXPECT_TRUE(other_tensor.data() != nullptr); - EXPECT_EQ(tensor.data(), other_tensor.data()); - // Set one value, check the other - for (int i = 0; i < tensor.numel(); ++i) { - tensor.mutable_data()[i] = i; - EXPECT_EQ(other_tensor.data()[i], i); - } -} - -TYPED_TEST(TensorCPUTest, NoLongerAliassAfterNumelChanges) { - vector dims(3); - dims[0] = 2; - dims[1] = 3; - dims[2] = 5; - Tensor tensor(dims, CPU); - EXPECT_TRUE(tensor.mutable_data() != nullptr); - Tensor other_tensor = tensor.Alias(); - EXPECT_EQ(tensor.data(), other_tensor.data()); - auto* old_pointer = other_tensor.data(); - - dims[0] = 7; - tensor.Resize(dims); - EXPECT_EQ(old_pointer, other_tensor.data()); - EXPECT_NE(old_pointer, tensor.mutable_data()); -} - -TYPED_TEST(TensorCPUTest, NoLongerAliasAfterFreeMemory) { - vector dims(3); - dims[0] = 2; - dims[1] = 3; - dims[2] = 5; - Tensor tensor(dims, CPU); - EXPECT_TRUE(tensor.mutable_data() != nullptr); - Tensor other_tensor = tensor.Alias(); - EXPECT_EQ(tensor.data(), other_tensor.data()); - auto* old_pointer = other_tensor.data(); - - tensor.FreeMemory(); - EXPECT_EQ(old_pointer, other_tensor.data()); - EXPECT_NE(old_pointer, tensor.mutable_data()); -} - -TYPED_TEST(TensorCPUTest, KeepOnShrink) { - // Set flags (defaults) - FLAGS_caffe2_keep_on_shrink = true; - FLAGS_caffe2_max_keep_on_shrink_memory = LLONG_MAX; - - vector dims{2, 3, 5}; - Tensor tensor(dims, CPU); - TypeParam* ptr = tensor.mutable_data(); - EXPECT_TRUE(ptr != nullptr); - // Expanding - will reallocate - tensor.Resize(3, 4, 6); - TypeParam* larger_ptr = tensor.mutable_data(); - EXPECT_TRUE(larger_ptr != nullptr); - - // This check can fail when malloc() returns the same recently freed address - // EXPECT_NE(ptr, larger_ptr); - - // Shrinking - will not reallocate - tensor.Resize(1, 2, 4); - TypeParam* smaller_ptr = tensor.mutable_data(); - EXPECT_TRUE(smaller_ptr != nullptr); - EXPECT_EQ(larger_ptr, smaller_ptr); - // resize to 0 in the meantime; - tensor.Resize(3, 0, 6); - // Expanding but still under capacity - will not reallocate - tensor.Resize(2, 3, 5); - TypeParam* new_ptr = tensor.mutable_data(); - EXPECT_TRUE(new_ptr != nullptr); - EXPECT_EQ(larger_ptr, new_ptr); -} - -TYPED_TEST(TensorCPUTest, MaxKeepOnShrink) { - // Set flags - FLAGS_caffe2_keep_on_shrink = true; - FLAGS_caffe2_max_keep_on_shrink_memory = 8 * 4 * sizeof(TypeParam); - - vector dims{1, 8, 8}; - Tensor tensor(dims, CPU); - TypeParam* ptr = tensor.mutable_data(); - EXPECT_TRUE(ptr != nullptr); - // Shrinking - will not reallocate - tensor.Resize(1, 7, 8); - TypeParam* smaller_ptr = tensor.mutable_data(); - EXPECT_TRUE(smaller_ptr != nullptr); - EXPECT_EQ(ptr, smaller_ptr); - // Resize to more than maximum shrink, should reallocate - tensor.Resize(1, 1, 8); - TypeParam* new_ptr = tensor.mutable_data(); - EXPECT_TRUE(new_ptr != nullptr); - - // This check can fail when malloc() returns the same recently freed address - // EXPECT_NE(ptr, new_ptr); - - // Restore default flags - FLAGS_caffe2_max_keep_on_shrink_memory = LLONG_MAX; -} - -TYPED_TEST(TensorCPUDeathTest, CannotAccessRawDataWhenEmpty) { - Tensor tensor(CPU); - EXPECT_EQ(tensor.dim(), 1); - EXPECT_EQ(tensor.numel(), 0); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - ASSERT_ANY_THROW(tensor.raw_data()); -} - -TYPED_TEST(TensorCPUDeathTest, CannotAccessDataWhenEmpty) { - Tensor tensor(CPU); - EXPECT_EQ(tensor.dim(), 1); - EXPECT_EQ(tensor.numel(), 0); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - ASSERT_ANY_THROW(tensor.data()); -} - -TEST(TensorTest, TensorNonFundamentalType) { - Tensor tensor(vector{2, 3, 4}, CPU); - EXPECT_TRUE(tensor.mutable_data() != nullptr); - const std::string* ptr = tensor.data(); - for (int i = 0; i < tensor.numel(); ++i) { - EXPECT_TRUE(ptr[i] == ""); - } -} - -TEST(TensorTest, TensorNonFundamentalTypeClone) { - Tensor tensor(vector{2, 3, 4}, CPU); - std::string* ptr = tensor.mutable_data(); - EXPECT_TRUE(ptr != nullptr); - for (int i = 0; i < tensor.numel(); ++i) { - EXPECT_TRUE(ptr[i] == ""); - ptr[i] = "filled"; - } - Tensor dst_tensor = tensor.Clone(); - const std::string* dst_ptr = dst_tensor.data(); - for (int i = 0; i < dst_tensor.numel(); ++i) { - EXPECT_TRUE(dst_ptr[i] == "filled"); - } - // Change the original tensor - for (int i = 0; i < tensor.numel(); ++i) { - EXPECT_TRUE(ptr[i] == "filled"); - ptr[i] = "changed"; - } - // Confirm that the cloned tensor is not affect - for (int i = 0; i < dst_tensor.numel(); ++i) { - EXPECT_TRUE(dst_ptr[i] == "filled"); - } -} - -TEST(TensorTest, Tensor64BitDimension) { - // Initialize a large tensor. - int64_t large_number = - static_cast(std::numeric_limits::max()) + 1; - Tensor tensor(vector{large_number}, CPU); - EXPECT_EQ(tensor.dim(), 1); - EXPECT_EQ(tensor.size(0), large_number); - EXPECT_EQ(tensor.numel(), large_number); - try { - EXPECT_TRUE(tensor.mutable_data() != nullptr); - } catch (const EnforceNotMet& e) { - string msg = e.what(); - size_t found = msg.find("posix_memalign"); - if (found != string::npos) { - msg = msg.substr(0, msg.find('\n')); - LOG(WARNING) << msg; - LOG(WARNING) << "Out of memory issue with posix_memalign;\n"; - return; - } else { - throw e; - } - } - EXPECT_EQ(tensor.nbytes(), large_number * sizeof(char)); - EXPECT_EQ(tensor.itemsize(), sizeof(char)); - // Try to go even larger, but this time we will not do mutable_data because we - // do not have a large enough memory. - tensor.Resize(large_number, 100); - EXPECT_EQ(tensor.dim(), 2); - EXPECT_EQ(tensor.size(0), large_number); - EXPECT_EQ(tensor.size(1), 100); - EXPECT_EQ(tensor.numel(), large_number * 100); -} - -TEST(TensorTest, UndefinedTensor) { - Tensor x; - EXPECT_FALSE(x.defined()); -} - -TEST(TensorTest, CopyAndAssignment) { - Tensor x(CPU); - x.Resize(16, 17); - testing::randomFill(x.template mutable_data(), 16 * 17); - EXPECT_TRUE(x.defined()); - - // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) - Tensor y(x); - // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) - Tensor z = x; - testing::assertTensorEquals(x, y, 0); - testing::assertTensorEquals(x, z, 0); -} - -TEST(TensorDeathTest, CannotCastDownLargeDims) { - int64_t large_number = - static_cast(std::numeric_limits::max()) + 1; - Tensor tensor(vector{large_number}, CPU); - EXPECT_EQ(tensor.dim(), 1); - EXPECT_EQ(tensor.size(0), large_number); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - ASSERT_THROW(tensor.dim32(0), EnforceNotMet); -} - -#define TEST_SERIALIZATION_WITH_TYPE(TypeParam, field_name) \ - TEST(TensorTest, TensorSerialization_##TypeParam) { \ - Blob blob; \ - Tensor* tensor = BlobGetMutableTensor(&blob, CPU); \ - tensor->Resize(2, 3); \ - for (int i = 0; i < 6; ++i) { \ - tensor->mutable_data()[i] = static_cast(i); \ - } \ - string serialized = SerializeBlob(blob, "test"); \ - BlobProto proto; \ - CHECK(proto.ParseFromString(serialized)); \ - EXPECT_EQ(proto.name(), "test"); \ - EXPECT_EQ(proto.type(), "Tensor"); \ - EXPECT_TRUE(proto.has_tensor()); \ - const TensorProto& tensor_proto = proto.tensor(); \ - EXPECT_EQ( \ - tensor_proto.data_type(), \ - TypeMetaToDataType(TypeMeta::Make())); \ - EXPECT_EQ(tensor_proto.field_name##_size(), 6); \ - for (int i = 0; i < 6; ++i) { \ - EXPECT_EQ(tensor_proto.field_name(i), static_cast(i)); \ - } \ - Blob new_blob; \ - EXPECT_NO_THROW(DeserializeBlob(serialized, &new_blob)); \ - EXPECT_TRUE(BlobIsTensorType(new_blob, CPU)); \ - const TensorCPU& new_tensor = blob.Get(); \ - EXPECT_EQ(new_tensor.dim(), 2); \ - EXPECT_EQ(new_tensor.size(0), 2); \ - EXPECT_EQ(new_tensor.size(1), 3); \ - for (int i = 0; i < 6; ++i) { \ - EXPECT_EQ( \ - tensor->data()[i], new_tensor.data()[i]); \ - } \ - } \ - \ - TEST(EmptyTensorTest, TensorSerialization_##TypeParam) { \ - Blob blob; \ - TensorCPU* tensor = BlobGetMutableTensor(&blob, CPU); \ - tensor->Resize(0, 3); \ - tensor->mutable_data(); \ - string serialized = SerializeBlob(blob, "test"); \ - BlobProto proto; \ - CHECK(proto.ParseFromString(serialized)); \ - EXPECT_EQ(proto.name(), "test"); \ - EXPECT_EQ(proto.type(), "Tensor"); \ - EXPECT_TRUE(proto.has_tensor()); \ - const TensorProto& tensor_proto = proto.tensor(); \ - EXPECT_EQ( \ - tensor_proto.data_type(), \ - TypeMetaToDataType(TypeMeta::Make())); \ - EXPECT_EQ(tensor_proto.field_name##_size(), 0); \ - Blob new_blob; \ - EXPECT_NO_THROW(DeserializeBlob(serialized, &new_blob)); \ - EXPECT_TRUE(BlobIsTensorType(new_blob, CPU)); \ - const TensorCPU& new_tensor = blob.Get(); \ - EXPECT_EQ(new_tensor.dim(), 2); \ - EXPECT_EQ(new_tensor.size(0), 0); \ - EXPECT_EQ(new_tensor.size(1), 3); \ - } - -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,hicpp-avoid-goto,cppcoreguidelines-avoid-goto) -TEST_SERIALIZATION_WITH_TYPE(bool, int32_data) -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,hicpp-avoid-goto,cppcoreguidelines-avoid-goto) -TEST_SERIALIZATION_WITH_TYPE(double, double_data) -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,hicpp-avoid-goto,cppcoreguidelines-avoid-goto) -TEST_SERIALIZATION_WITH_TYPE(float, float_data) -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,hicpp-avoid-goto,cppcoreguidelines-avoid-goto) -TEST_SERIALIZATION_WITH_TYPE(int, int32_data) -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,hicpp-avoid-goto,cppcoreguidelines-avoid-goto) -TEST_SERIALIZATION_WITH_TYPE(int8_t, int32_data) -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,hicpp-avoid-goto,cppcoreguidelines-avoid-goto) -TEST_SERIALIZATION_WITH_TYPE(int16_t, int32_data) -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,hicpp-avoid-goto,cppcoreguidelines-avoid-goto) -TEST_SERIALIZATION_WITH_TYPE(uint8_t, int32_data) -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,hicpp-avoid-goto,cppcoreguidelines-avoid-goto) -TEST_SERIALIZATION_WITH_TYPE(uint16_t, int32_data) -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,hicpp-avoid-goto,cppcoreguidelines-avoid-goto) -TEST_SERIALIZATION_WITH_TYPE(int64_t, int64_data) - -TEST(TensorTest, TensorSerialization_CustomType) { - Blob blob; - TensorCPU* tensor = BlobGetMutableTensor(&blob, CPU); - tensor->Resize(2, 3); - for (int i = 0; i < 6; ++i) { - tensor->mutable_data()[i].val = i; - } - string serialized = SerializeBlob(blob, "test"); - BlobProto proto; - CHECK(proto.ParseFromString(serialized)); - EXPECT_EQ(proto.name(), "test"); - EXPECT_EQ(proto.type(), "Tensor"); - Blob new_blob; - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - EXPECT_NO_THROW(DeserializeBlob(serialized, &new_blob)); - EXPECT_TRUE(BlobIsTensorType(new_blob, CPU)); - const TensorCPU& new_tensor = blob.Get(); - EXPECT_EQ(new_tensor.dim(), 2); - EXPECT_EQ(new_tensor.size(0), 2); - EXPECT_EQ(new_tensor.size(1), 3); - for (int i = 0; i < 6; ++i) { - EXPECT_EQ( - new_tensor.data()[i].val, - tensor->data()[i].val); - } -} - -TEST(TensorTest, Half) { - const int64_t kSize = 3000000; - Blob blob; - TensorCPU* tensor = BlobGetMutableTensor(&blob, CPU); - tensor->Resize(kSize); - for (int i = 0; i < tensor->numel(); ++i) { - tensor->mutable_data()[i].x = i % 10000; - } - string serialized = SerializeBlob(blob, "test"); - BlobProto proto; - CHECK(proto.ParseFromString(serialized)); - EXPECT_EQ(proto.name(), "test"); - EXPECT_EQ(proto.type(), "Tensor"); - EXPECT_TRUE(proto.has_tensor()); - const TensorProto& tensor_proto = proto.tensor(); - EXPECT_EQ( - tensor_proto.data_type(), TypeMetaToDataType(TypeMeta::Make())); - if (FLAGS_caffe2_serialize_fp16_as_bytes) { - EXPECT_EQ(tensor_proto.byte_data().size(), 2 * kSize); - for (int i = 0; i < kSize; ++i) { - auto value = tensor->mutable_data()[i].x; - auto low_bits = static_cast(value & 0xff); - auto high_bits = static_cast(value >> 8); - EXPECT_EQ(tensor_proto.byte_data()[2 * i], low_bits); - EXPECT_EQ(tensor_proto.byte_data()[2 * i + 1], high_bits); - } - } else { - EXPECT_EQ(tensor_proto.int32_data().size(), kSize); - } - Blob new_blob; - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - EXPECT_NO_THROW(DeserializeBlob(serialized, &new_blob)); - EXPECT_TRUE(BlobIsTensorType(new_blob, CPU)); - const TensorCPU& new_tensor = blob.Get(); - EXPECT_EQ(new_tensor.dim(), 1); - EXPECT_EQ(new_tensor.size(0), kSize); - for (int i = 0; i < kSize; ++i) { - EXPECT_EQ(new_tensor.data()[i].x, i % 10000); - } -} - -TEST(TensorTest, TensorFactory) { - Tensor a = empty({1, 2, 3}, at::device(CPU).dtype()); - EXPECT_NE(a.data(), nullptr); - a.mutable_data()[0] = 3.0; - Tensor b = empty({1, 2, 3}, at::device(CPU).dtype()); - EXPECT_NE(b.data(), nullptr); - b.mutable_data()[0] = 3; -} - -TEST(QTensorTest, QTensorSerialization) { - Blob blob; - QTensor* qtensor = blob.GetMutable>(); - qtensor->SetPrecision(5); - qtensor->SetSigned(false); - qtensor->SetScale(1.337); - qtensor->SetBias(-1.337); - qtensor->Resize(std::vector{2, 3}); - // "Randomly" set bits. - srand(0); - for (int i = 0; i < 6; ++i) { - for (int j = 0; j < 5; ++j) { - // NOLINTNEXTLINE(clang-analyzer-security.insecureAPI.rand) - qtensor->SetBitAtIndex(j, i, rand() % 2); - } - } - - string serialized = SerializeBlob(blob, "test"); - BlobProto proto; - CHECK(proto.ParseFromString(serialized)); - EXPECT_EQ(proto.name(), "test"); - EXPECT_EQ(proto.type(), "QTensor"); - EXPECT_TRUE(proto.has_qtensor()); - const QTensorProto& qtensor_proto = proto.qtensor(); - - EXPECT_EQ(qtensor_proto.precision(), qtensor->precision()); - EXPECT_EQ(qtensor_proto.scale(), qtensor->scale()); - EXPECT_EQ(qtensor_proto.bias(), qtensor->bias()); - EXPECT_EQ(qtensor_proto.is_signed(), qtensor->is_signed()); - - Blob new_blob; - DeserializeBlob(serialized, &new_blob); - EXPECT_TRUE(new_blob.IsType>()); - const QTensor& new_qtensor = blob.Get>(); - EXPECT_EQ(new_qtensor.ndim(), 2); - EXPECT_EQ(new_qtensor.dim32(0), 2); - EXPECT_EQ(new_qtensor.dim32(1), 3); - for (int i = 0; i < 6; ++i) { - for (int j = 0; j < 5; ++j) { - EXPECT_EQ(qtensor->GetBitAtIndex(j, i), new_qtensor.GetBitAtIndex(j, i)); - } - } -} - -using StringMap = std::vector>; - -class VectorCursor : public db::Cursor { - public: - explicit VectorCursor(StringMap* data) : data_(data) { - pos_ = 0; - } - // NOLINTNEXTLINE(modernize-use-equals-default) - ~VectorCursor() override {} - void Seek(const string& /* unused */) override {} - void SeekToFirst() override {} - void Next() override { - ++pos_; - } - string key() override { - return (*data_)[pos_].first; - } - string value() override { - return (*data_)[pos_].second; - } - bool Valid() override { - return pos_ < data_->size(); - } - - private: - StringMap* data_ = nullptr; - size_t pos_ = 0; -}; - -class VectorDB : public db::DB { - public: - VectorDB(const string& source, db::Mode mode) - : DB(source, mode), name_(source) {} - ~VectorDB() override { - data_.erase(name_); - } - void Close() override {} - std::unique_ptr NewCursor() override { - return make_unique(getData()); - } - std::unique_ptr NewTransaction() override { - CAFFE_THROW("Not implemented"); - } - static void registerData(const string& name, StringMap&& data) { - std::lock_guard guard(dataRegistryMutex_); - data_[name] = std::move(data); - } - - private: - StringMap* getData() { - auto it = data_.find(name_); - CAFFE_ENFORCE(it != data_.end(), "Can't find ", name_); - return &(it->second); - } - - private: - string name_; - static std::mutex dataRegistryMutex_; - static std::map data_; -}; - -std::mutex VectorDB::dataRegistryMutex_; -std::map VectorDB::data_; - -REGISTER_CAFFE2_DB(vector_db, VectorDB); - -template -class TypedTensorTest : public ::testing::Test {}; -typedef ::testing:: - Types - TensorDataTypes; -TYPED_TEST_CASE(TypedTensorTest, TensorDataTypes); - -TYPED_TEST(TypedTensorTest, BigTensorSerialization) { - int64_t d1 = 2; - int64_t d2 = FLAGS_caffe2_test_big_tensor_size - ? FLAGS_caffe2_test_big_tensor_size / d1 - : static_cast(std::numeric_limits::max()) + 1; - int64_t size = d1 * d2; - string db_source = (string)std::tmpnam(nullptr); - VLOG(1) << "db_source: " << db_source; - - { - VLOG(1) << "Test begin"; - Blob blob; - Tensor* tensor = BlobGetMutableTensor(&blob, CPU); - VLOG(1) << "Allocating blob"; - tensor->Resize(d1, d2); - auto mutableData = tensor->mutable_data(); - VLOG(1) << "Filling out the blob"; - for (int64_t i = 0; i < size; ++i) { - mutableData[i] = static_cast(i); - } - StringMap data; - std::mutex mutex; - auto acceptor = [&](const std::string& key, const std::string& value) { - std::lock_guard guard(mutex); - data.emplace_back(key, value); - }; - SerializeBlob(blob, "test", acceptor); - VectorDB::registerData(db_source, std::move(data)); - VLOG(1) << "finished writing to DB"; - } - - { - DeviceOption option; - option.set_device_type(PROTO_CPU); - Argument db_type_arg = MakeArgument("db_type", "vector_db"); - Argument absolute_path_arg = MakeArgument("absolute_path", true); - Argument db_source_arg = MakeArgument("db", db_source); - auto op_def = CreateOperatorDef( - "Load", - "", - std::vector{}, - std::vector({"test"}), - std::vector{db_type_arg, db_source_arg, absolute_path_arg}, - option, - "DUMMY_ENGINE"); - Workspace ws; - auto load_op = CreateOperator(op_def, &ws); - EXPECT_TRUE(load_op != nullptr); - VLOG(1) << "Running operator"; - - load_op->Run(); - VLOG(1) << "Reading blob from workspace"; - auto new_blob = ws.GetBlob("test"); - EXPECT_TRUE(BlobIsTensorType(*new_blob, CPU)); - const auto& new_tensor = new_blob->Get(); - - EXPECT_EQ(new_tensor.dim(), d1); - EXPECT_EQ(new_tensor.size(0), d1); - EXPECT_EQ(new_tensor.size(1), d2); - for (int64_t i = 0; i < size; ++i) { - EXPECT_EQ(static_cast(i), new_tensor.data()[i]); - } - } -} - -struct DummyType { - /* This struct is used to test serialization and deserialization of huge - * blobs, that are not tensors. - */ - - /* implicit */ DummyType(int n_chunks_init = 0) : n_chunks(n_chunks_init) {} - std::string serialize(const std::string& name, const int32_t chunk_id) const { - BlobProto blobProto; - blobProto.set_name(name); - blobProto.set_type("DummyType"); - std::string content(""); - blobProto.set_content(content); - blobProto.set_content_num_chunks(n_chunks); - blobProto.set_content_chunk_id(chunk_id); - return blobProto.SerializeAsString(); - } - void deserialize(const BlobProto& /* unused */) { - ++n_chunks; - } - int n_chunks; -}; - -class DummyTypeSerializer : public BlobSerializerBase { - public: - // NOLINTNEXTLINE(modernize-use-equals-default) - DummyTypeSerializer() {} - // NOLINTNEXTLINE(modernize-use-equals-default) - ~DummyTypeSerializer() override {} - void Serialize( - const void* pointer, - TypeMeta typeMeta, - const string& name, - SerializationAcceptor acceptor) override { - CAFFE_ENFORCE(typeMeta.Match()); - const auto& container = *static_cast(pointer); - for (int k = 0; k < container.n_chunks; ++k) { - std::string serialized_chunk = container.serialize(name, k); - acceptor( - c10::str(name, kChunkIdSeparator, k), std::move(serialized_chunk)); - } - } -}; - -class DummyTypeDeserializer : public BlobDeserializerBase { - public: - void Deserialize(const BlobProto& proto, Blob* blob) override { - auto* container = blob->GetMutable(); - container->deserialize(proto); - } -}; -} // namespace - -CAFFE_KNOWN_TYPE_NOEXPORT(DummyType); - -namespace { -REGISTER_BLOB_SERIALIZER((TypeMeta::Id()), DummyTypeSerializer); -C10_REGISTER_TYPED_CLASS( - BlobDeserializerRegistry, - "DummyType", - DummyTypeDeserializer); - -TEST(ContentChunks, Serialization) { - string db_source = (string)std::tmpnam(nullptr); - VLOG(1) << "db_source: " << db_source; - - { - VLOG(1) << "Test begin"; - Blob blob; - DummyType* container = blob.GetMutable(); - VLOG(1) << "Allocating blob"; - container->n_chunks = 10; - VLOG(1) << "Filling out the blob"; - StringMap data; - std::mutex mutex; - auto acceptor = [&](const std::string& key, const std::string& value) { - std::lock_guard guard(mutex); - data.emplace_back(key, value); - }; - SerializeBlob(blob, "test", acceptor); - VectorDB::registerData(db_source, std::move(data)); - VLOG(1) << "finished writing to DB"; - } - - { - DeviceOption option; - option.set_device_type(PROTO_CPU); - Argument db_type_arg = MakeArgument("db_type", "vector_db"); - Argument absolute_path_arg = MakeArgument("absolute_path", true); - Argument db_source_arg = MakeArgument("db", db_source); - auto op_def = CreateOperatorDef( - "Load", - "", - std::vector{}, - std::vector({"test"}), - std::vector{db_type_arg, db_source_arg, absolute_path_arg}, - option, - "DUMMY_ENGINE"); - Workspace ws; - auto load_op = CreateOperator(op_def, &ws); - EXPECT_TRUE(load_op != nullptr); - VLOG(1) << "Running operator"; - - load_op->Run(); - VLOG(1) << "Reading blob from workspace"; - auto new_blob = ws.GetBlob("test"); - EXPECT_TRUE(new_blob->IsType()); - const auto& container = new_blob->Get(); - EXPECT_EQ(container.n_chunks, 10); - } -} - -TEST(CustomChunkSize, BigTensorSerialization) { - int64_t d1 = 2; - int64_t d2 = FLAGS_caffe2_test_big_tensor_size - ? FLAGS_caffe2_test_big_tensor_size / d1 - : static_cast(std::numeric_limits::max()) + 1; - BlobSerializationOptions options; - - Blob blob; - TensorCPU* tensor = BlobGetMutableTensor(&blob, CPU); - tensor->Resize(d1, d2); - tensor->mutable_data(); - std::mutex mutex; - int counter = 0; - auto acceptor = [&](const std::string& /*key*/, - const std::string& /*value*/) { - std::lock_guard guard(mutex); - counter++; - }; - options.set_chunk_size(d1 * d2); - SerializeBlob(blob, "test", acceptor, options); - EXPECT_EQ(counter, 1); - - counter = 0; - options.set_chunk_size((d1 * d2) / 2 + 1); - SerializeBlob(blob, "test", acceptor, options); - EXPECT_EQ(counter, 2); - - counter = 0; - options.set_chunk_size(-1); - SerializeBlob(blob, "test", acceptor, options); - EXPECT_EQ(counter, 1); -} - -TEST(QTensor, QTensorSizingTest) { - vector dims(3); - dims[0] = 2; - dims[1] = 3; - dims[2] = 5; - QTensor qtensor(dims, 3); - EXPECT_TRUE(qtensor.mutable_data() != nullptr); - EXPECT_EQ(qtensor.nbytes(), 12); - EXPECT_EQ(qtensor.size(), 30); -} - -TEST(BlobTest, CastingMessage) { - Blob b; - b.GetMutable(); - b.Get(); - try { - b.Get(); - FAIL() << "Should have thrown"; - } catch (const EnforceNotMet& e) { - string msg = e.what_without_backtrace(); - LOG(INFO) << msg; - EXPECT_NE(msg.find("BlobTestFoo"), std::string::npos) << msg; - EXPECT_NE(msg.find("BlobTestBar"), std::string::npos) << msg; - } -} - -TEST(TensorConstruction, UninitializedCopyTest) { - Tensor x(CPU); - Tensor y(x, CPU); - Tensor z = x.Clone(); - EXPECT_FALSE(x.dtype_initialized()); - EXPECT_FALSE(y.dtype_initialized()); - LOG(INFO) << "z.size()" << z.numel(); - EXPECT_FALSE(z.dtype_initialized()); -} - -TEST(TensorConstruction, CopyConstructorTest) { - Tensor x(CPU); - x.Resize(5); - x.mutable_data()[0] = 1; - Tensor y = x.Clone(); - Tensor z(x, CPU); - - EXPECT_EQ(*x.data(), 1); - EXPECT_EQ(*y.data(), 1); - EXPECT_EQ(*z.data(), 1); - x.mutable_data()[0] = 5; - EXPECT_EQ(*x.data(), 5); - EXPECT_EQ(*y.data(), 1); - EXPECT_EQ(*z.data(), 1); -} - -TEST(TensorConstruction, MoveAssignmentOpTest) { - Tensor x(CPU); - x.Resize(5); - x.mutable_data()[0] = 1; - Tensor y(CPU); - y = std::move(x); - - EXPECT_EQ(*y.data(), 1); -} - -TEST(TensorSerialization, MistakenlySerializingDtypeUninitializedTensor) { - // This test preserves a legacy behavior that dtype-unitialized tensors can - // go through serialization. We want to kill this behavior - when it's done, - // remove this test - Blob blob; - Tensor* x = BlobGetMutableTensor(&blob, CPU); - x->Resize(0); - string output; - SerializeBlob( - blob, - "foo", - [&output](const string& /*blobName*/, const std::string& data) { - output = data; - }); - BlobProto b; - CHECK(b.ParseFromString(output)); - LOG(INFO) << "serialized proto: " << b.DebugString(); - - Blob new_blob; - // Deserializing an empty Tensor gives a {0}-dim, float CPU Tensor - DeserializeBlob(output, &new_blob); - const Tensor& new_tensor = new_blob.Get(); - LOG(INFO) << "tensor " << new_tensor.DebugString(); - EXPECT_TRUE(new_tensor.dtype_initialized()); - LOG(INFO) << "dtype:" << new_tensor.dtype(); - EXPECT_EQ(0, new_tensor.numel()); - EXPECT_EQ(1, new_tensor.dim()); -} - -static caffe2::BlobProto CreateProtoWithInt32Data( - const caffe2::TensorProto::DataType& dataType, - size_t numEl, - bool useCached = true) { - static std::map protos; - if (useCached && protos.count(dataType)) { - return protos[dataType]; - } - caffe2::BlobProto proto; - proto.set_type("Tensor"); - auto tensor = proto.mutable_tensor(); - tensor->add_dims(numEl); - tensor->add_dims(1); - tensor->set_data_type(dataType); - tensor->set_name("test_feature"); - tensor->mutable_device_detail()->set_device_type(0); - tensor->mutable_segment()->set_begin(0); - tensor->mutable_segment()->set_end(numEl); - for (size_t i = 0; i < numEl; ++i) { - int32_t data = 0; - switch (dataType) { - case caffe2::TensorProto_DataType_INT32: - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,clang-analyzer-security.insecureAPI.rand) - data = static_cast(rand() % 0xffffffff); - break; - case caffe2::TensorProto_DataType_BOOL: - // NOLINTNEXTLINE(clang-analyzer-security.insecureAPI.rand) - data = static_cast(rand() % 0x00000001); - break; - case caffe2::TensorProto_DataType_UINT8: - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,clang-analyzer-security.insecureAPI.rand) - data = static_cast(rand() % 0x000000ff); - break; - case caffe2::TensorProto_DataType_INT8: - // NOLINTNEXTLINE(bugprone-signed-char-misuse,cppcoreguidelines-avoid-magic-numbers,clang-analyzer-security.insecureAPI.rand) - data = static_cast(rand() % 0x000000ff); - break; - case caffe2::TensorProto_DataType_UINT16: - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,clang-analyzer-security.insecureAPI.rand) - data = static_cast(rand() % 0x0000ffff); - break; - case caffe2::TensorProto_DataType_INT16: - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,clang-analyzer-security.insecureAPI.rand) - data = static_cast(rand() % 0x0000ffff); - break; - case caffe2::TensorProto_DataType_FLOAT16: - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,clang-analyzer-security.insecureAPI.rand) - data = static_cast(rand() % 0x0000ffff); - break; - default: - continue; - } - tensor->add_int32_data(data); - } - protos[dataType] = proto; - return proto; -} - -void TestDataType( - const caffe2::TensorProto::DataType& dataType, - std::string dataTypeName) { - LOG(INFO) << dataTypeName; - FLAGS_caffe2_serialize_using_bytes_as_holder = true; - int numEl = 1000; - // Proto with int32 - auto protoInt32 = CreateProtoWithInt32Data(dataType, numEl, false); - caffe2::Blob blobInt32; - DeserializeBlob(protoInt32, &blobInt32); - auto serializedStr = SerializeBlob(blobInt32, protoInt32.name()); - caffe2::BlobProto protoBytes; - // Proto with bytes - protoBytes.ParseFromString(serializedStr); - caffe2::Blob blobBytes; - DeserializeBlob(protoBytes, &blobBytes); - FLAGS_caffe2_serialize_using_bytes_as_holder = false; - // Proto with int32 from proto with bytes - protoBytes.ParseFromString(SerializeBlob(blobBytes, protoBytes.name())); - EXPECT_EQ(numEl, protoInt32.tensor().int32_data_size()); - EXPECT_EQ(numEl, protoBytes.tensor().int32_data_size()); - for (int i = 0; i < numEl; ++i) { - EXPECT_EQ( - protoInt32.tensor().int32_data(i), protoBytes.tensor().int32_data(i)); - } -} - -TEST(TensorSerialization, TestCorrectness) { - FLAGS_caffe2_serialize_using_bytes_as_holder = true; - TestDataType( - caffe2::TensorProto_DataType_INT32, "TensorProto_DataType_INT32"); - TestDataType(caffe2::TensorProto_DataType_BOOL, "TensorProto_DataType_BOOL"); - TestDataType( - caffe2::TensorProto_DataType_UINT8, "TensorProto_DataType_UINT8"); - TestDataType(caffe2::TensorProto_DataType_INT8, "TensorProto_DataType_INT8"); - TestDataType( - caffe2::TensorProto_DataType_UINT16, "TensorProto_DataType_UINT16"); - TestDataType( - caffe2::TensorProto_DataType_INT16, "TensorProto_DataType_INT16"); - TestDataType( - caffe2::TensorProto_DataType_FLOAT16, "TensorProto_DataType_FLOAT16"); -} - -} // namespace -} // namespace caffe2 diff --git a/caffe2/core/context_gpu_test.cc b/caffe2/core/context_gpu_test.cc deleted file mode 100644 index 9eb92b429ef0..000000000000 --- a/caffe2/core/context_gpu_test.cc +++ /dev/null @@ -1,161 +0,0 @@ -#include -#include -#include -#include -#include - -#include "caffe2/core/context_gpu.h" -#include - -namespace caffe2 { - -TEST(CUDATest, HasCudaRuntime) { - EXPECT_TRUE(HasCudaRuntime()); -} - -TEST(CUDAContextTest, TestAllocDealloc) { - if (!HasCudaGPU()) return; - CUDAContext context(0); - context.SwitchToDevice(); - auto data = CUDAContext::New(10 * sizeof(float)); - EXPECT_NE(data.get(), nullptr); -} - -TEST(CUDAContextTest, TestSetGetDeviceWithoutCaffeMode) { - // For a while, set full device control to be true. - for (int i = 0; i < NumCudaDevices(); ++i) { - CaffeCudaSetDevice(i); - EXPECT_EQ(CaffeCudaGetDevice(), i); - } - for (int i = NumCudaDevices() - 1; i >= 0; --i) { - CaffeCudaSetDevice(i); - EXPECT_EQ(CaffeCudaGetDevice(), i); - } -} - -TEST(CUDAContextTest, MemoryPoolAllocateDealloc) { - if (!HasCudaGPU()) - return; - if (GetCudaMemoryPoolType() == CudaMemoryPoolType::NONE) { - LOG(ERROR) << "Choose a memory type that is not none to test memory pool."; - return; - } - const int nbytes = 1048576; - for (int i = 0; i < NumCudaDevices(); ++i) { - LOG(INFO) << "Device " << i << " of " << NumCudaDevices(); - CUDAGuard guard(i); - auto allocated = CUDAContext::New(nbytes); - EXPECT_NE(allocated, nullptr); - cudaPointerAttributes attr; - CUDA_ENFORCE(cudaPointerGetAttributes(&attr, allocated.get())); - EXPECT_EQ(attr.type, cudaMemoryTypeDevice); - EXPECT_EQ(attr.device, i); - void* prev_allocated = allocated.get(); - allocated.clear(); - auto new_allocated = CUDAContext::New(nbytes); - // With a pool, the above allocation should yield the same address. - EXPECT_EQ(new_allocated.get(), prev_allocated); - // But, if we are allocating something larger, we will have a different - // chunk of memory. - auto larger_allocated = CUDAContext::New(nbytes * 2); - EXPECT_NE(larger_allocated.get(), prev_allocated); - } -} - -cudaStream_t getStreamForHandle(cublasHandle_t handle) { - cudaStream_t stream = nullptr; - CUBLAS_ENFORCE(cublasGetStream(handle, &stream)); - TORCH_CHECK_NOTNULL(stream); - return stream; -} - -TEST(CUDAContextTest, TestSameThreadSameObject) { - if (!HasCudaGPU()) return; - CUDAContext context_a(0); - CUDAContext context_b(0); - EXPECT_EQ(context_a.cuda_stream(), context_b.cuda_stream()); - EXPECT_EQ(context_a.cublas_handle(), context_b.cublas_handle()); - EXPECT_EQ( - context_a.cuda_stream(), getStreamForHandle(context_b.cublas_handle())); - // CuRAND generators are context-local. - EXPECT_NE(context_a.curand_generator(), context_b.curand_generator()); -} - -TEST(CUDAContextTest, TestSameThreadTempObject) { - if (!HasCudaGPU()) - return; - CUDAContext context_outer(0); // gpu id - context_outer.SwitchToDevice(); - - if (NumCudaDevices() >= 2) { - auto before_stream = context_outer.cuda_stream(); - - // try to mess up current device - CUDAContext context_different_device(1); - context_different_device.SwitchToDevice(10); - - // go back - context_outer.SwitchToDevice(); - EXPECT_EQ(context_outer.cuda_stream(), before_stream); - - // do nothing - infers the current device and stream - CUDAContext context_noop; - EXPECT_EQ(context_outer.cuda_stream(), before_stream); - EXPECT_EQ(context_noop.cuda_stream(), before_stream); - - - // override stream - the previous context is not valid any more until - // SwitchToDevice is called again (needs to be refactored into proper guard) - CUDAContext context_override; - context_override.SwitchToDevice(1); // logical stream id - EXPECT_NE(context_override.cuda_stream(), before_stream); - // note, that accessing streams from context_outer and context_noop is not - // semantically valid any more - } -} - -TEST(CUDAContextTest, TestSameThreadDifferntObjectIfDifferentDevices) { - if (NumCudaDevices() > 1) { - CUDAContext context_a(0); - CUDAContext context_b(1); - EXPECT_NE(context_a.cuda_stream(), context_b.cuda_stream()); - EXPECT_NE(context_a.cublas_handle(), context_b.cublas_handle()); - EXPECT_NE( - context_a.cuda_stream(), getStreamForHandle(context_b.cublas_handle())); - EXPECT_NE(context_a.curand_generator(), context_b.curand_generator()); - } -} - -namespace { -// A test function to return a stream address from a temp CUDA context. You -// should not use that stream though, because the actual stream is destroyed -// after thread exit. -void TEST_GetStreamAddress(cudaStream_t* ptr) { - CUDAContext context(0); - context.SwitchToDevice(); - *ptr = context.cuda_stream(); - // Sleep for a while so we have concurrent thread executions - std::this_thread::sleep_for(std::chrono::seconds(1)); -} -} // namespace - -TEST(CUDAContextTest, TestDifferntThreadDifferentobject) { - if (!HasCudaGPU()) return; - std::array temp = {0}; - // Same thread - TEST_GetStreamAddress(&temp[0]); - TEST_GetStreamAddress(&temp[1]); - EXPECT_TRUE(temp[0] != nullptr); - EXPECT_TRUE(temp[1] != nullptr); - EXPECT_EQ(temp[0], temp[1]); - // Different threads - std::thread thread_a(TEST_GetStreamAddress, &temp[0]); - std::thread thread_b(TEST_GetStreamAddress, &temp[1]); - thread_a.join(); - thread_b.join(); - EXPECT_TRUE(temp[0] != nullptr); - EXPECT_TRUE(temp[1] != nullptr); - EXPECT_NE(temp[0], temp[1]); -} - -} // namespace caffe2 diff --git a/caffe2/core/context_test.cc b/caffe2/core/context_test.cc deleted file mode 100644 index 304f973576c1..000000000000 --- a/caffe2/core/context_test.cc +++ /dev/null @@ -1,38 +0,0 @@ -#include - -#include -#include -#include "caffe2/core/context.h" -#include "caffe2/proto/caffe2_pb.h" - -namespace caffe2 { - -TEST(CPUContextTest, TestAllocAlignment) { - for (int i = 1; i < 10; ++i) { - auto data = CPUContext::New(i); - EXPECT_EQ((reinterpret_cast(data.get()) % gAlignment), 0); - // data is freed when out of scope - } -} - -TEST(CPUContextTest, TestAllocDealloc) { - auto data_ptr = CPUContext::New(10 * sizeof(float)); - float* data = static_cast(data_ptr.get()); - EXPECT_NE(data, nullptr); - auto dst_data_ptr = CPUContext::New(10 * sizeof(float)); - float* dst_data = static_cast(dst_data_ptr.get()); - EXPECT_NE(dst_data, nullptr); - for (int i = 0; i < 10; ++i) { - // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) - data[i] = i; - } - DeviceOption option; - CPUContext context(option); - context.CopyToCPU(10, data, dst_data); - for (int i = 0; i < 10; ++i) { - EXPECT_FLOAT_EQ(dst_data[i], i); - } - // data_ptr is freed when out of scope -} - -} // namespace caffe2 diff --git a/caffe2/core/event_gpu_test.cc b/caffe2/core/event_gpu_test.cc deleted file mode 100644 index 18fe152198e2..000000000000 --- a/caffe2/core/event_gpu_test.cc +++ /dev/null @@ -1,50 +0,0 @@ -#include -#include "caffe2/core/context.h" -#include "caffe2/core/context_gpu.h" -#include "caffe2/core/event.h" - -namespace caffe2 { - -TEST(EventCUDATest, EventBasics) { - if (!HasCudaGPU()) - return; - DeviceOption device_cpu; - device_cpu.set_device_type(PROTO_CPU); - DeviceOption device_cuda; - device_cuda.set_device_type(PROTO_CUDA); - - CPUContext context_cpu(device_cpu); - CUDAContext context_cuda(device_cuda); - - Event event_cpu(device_cpu); - Event event_cuda(device_cuda); - - // CPU context and event interactions - context_cpu.Record(&event_cpu); - event_cpu.SetFinished(); - event_cpu.Finish(); - context_cpu.WaitEvent(event_cpu); - - event_cpu.Reset(); - event_cpu.Record(CPU, &context_cpu); - event_cpu.SetFinished(); - event_cpu.Wait(CPU, &context_cpu); - - // CUDA context and event interactions - context_cuda.SwitchToDevice(); - context_cuda.Record(&event_cuda); - context_cuda.WaitEvent(event_cuda); - event_cuda.Finish(); - - event_cuda.Reset(); - event_cuda.Record(CUDA, &context_cuda); - event_cuda.Wait(CUDA, &context_cuda); - - // CPU context waiting for CUDA event - context_cpu.WaitEvent(event_cuda); - - // CUDA context waiting for CPU event - context_cuda.WaitEvent(event_cpu); -} - -} // namespace caffe2 diff --git a/caffe2/core/event_test.cc b/caffe2/core/event_test.cc deleted file mode 100644 index ef25ae891e9a..000000000000 --- a/caffe2/core/event_test.cc +++ /dev/null @@ -1,41 +0,0 @@ -#include -#include "caffe2/core/context.h" -#include "caffe2/core/event.h" - -namespace caffe2 { - -TEST(EventCPUTest, EventBasics) { - DeviceOption device_option; - device_option.set_device_type(PROTO_CPU); - Event event(device_option); - CPUContext context; - - context.Record(&event); - event.SetFinished(); - - context.WaitEvent(event); - event.Finish(); - - event.Reset(); - event.Record(CPU, &context); - event.SetFinished(); - event.Wait(CPU, &context); -} - -TEST(EventCPUTest, EventErrors) { - DeviceOption device_option; - device_option.set_device_type(PROTO_CPU); - Event event(device_option); - - event.SetFinished(); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_THROW(event.SetFinished("error"), caffe2::EnforceNotMet); - ASSERT_EQ(event.ErrorMessage(), "No error"); - - event.Reset(); - event.SetFinished("error 1"); - event.SetFinished("error 2"); - ASSERT_EQ(event.ErrorMessage(), "error 1"); -} - -} // namespace caffe2 diff --git a/caffe2/core/graph_test.cc b/caffe2/core/graph_test.cc deleted file mode 100644 index 8aa4f1610793..000000000000 --- a/caffe2/core/graph_test.cc +++ /dev/null @@ -1,200 +0,0 @@ -#include -#include "caffe2/core/graph.h" -#include "caffe2/core/net.h" -#include "caffe2/core/operator.h" - -namespace caffe2 { - -namespace { - -using transform::Graph; - -static std::atomic counter; - -class GraphDummyOp final : public OperatorBase { - public: - using OperatorBase::OperatorBase; - bool Run(int /* unused */) override { - counter.fetch_add(1); - return true; - } -}; - -REGISTER_CPU_OPERATOR(GraphDummyOp1, GraphDummyOp); - -OPERATOR_SCHEMA(GraphDummyOp1) - .NumInputs(0, INT_MAX) - .NumOutputs(0, INT_MAX) - .AllowInplace({{0, 0}, {1, 1}}); - -REGISTER_CPU_OPERATOR(GraphDummyOp2, GraphDummyOp); - -OPERATOR_SCHEMA(GraphDummyOp2) - .NumInputs(0, INT_MAX) - .NumOutputs(0, INT_MAX) - .AllowInplace({{0, 0}, {1, 1}}); - -REGISTER_CPU_OPERATOR(GraphDummyOp3, GraphDummyOp); - -OPERATOR_SCHEMA(GraphDummyOp3) - .NumInputs(0, INT_MAX) - .NumOutputs(0, INT_MAX) - .AllowInplace({{0, 0}, {1, 1}}); - -// Checks if two netdefs are in terms of type, input, and output. -void compare_netdefs(const NetDef& net_a, const NetDef& net_b) { - EXPECT_EQ(net_a.op_size(), net_b.op_size()); - for (int i = 0; i < net_a.op_size(); i++) { - EXPECT_EQ(net_a.op(i).type(), net_b.op(i).type()); - EXPECT_EQ(net_a.op(i).input_size(), net_b.op(i).input_size()); - for (int j = 0; j < net_a.op(i).input_size(); j++) { - EXPECT_EQ(net_a.op(i).input(j), net_b.op(i).input(j)); - } - EXPECT_EQ(net_a.op(i).output_size(), net_b.op(i).output_size()); - for (int j = 0; j < net_a.op(i).output_size(); j++) { - EXPECT_EQ(net_a.op(i).output(j), net_b.op(i).output(j)); - } - } -} - -TEST(GraphTest, TestGenerateGraphChain) { - Workspace ws; - ws.CreateBlob("in"); - NetDef netdef; - AddOp(&netdef, "GraphDummyOp1", {"in"}, {"mid1"}); - AddOp(&netdef, "GraphDummyOp2", {"mid1"}, {"mid2"}); - AddOp(&netdef, "GraphDummyOp1", {"mid2"}, {"mid3"}); - AddOp(&netdef, "GraphDummyOp2", {"mid3"}, {"out"}); - Graph g(netdef); - EXPECT_EQ(g.size(), 4); - for (int i = 0; i < 4; i++) { - if (i < 3) { - EXPECT_EQ(g.node(i).children.size(), 1); - EXPECT_TRUE(g.node(i).children.count(i + 1)); - } - if (i > 0) { - EXPECT_EQ(g.node(i).parents.size(), 1); - EXPECT_TRUE(g.node(i).parents.count(i - 1)); - } - } - NetDef retrieved_net = g.GetNetDef(); - compare_netdefs(retrieved_net, netdef); -} - -TEST(GraphTest, TestGenerateGraphChainInPlace) { - Workspace ws; - ws.CreateBlob("in"); - NetDef netdef; - AddOp(&netdef, "GraphDummyOp1", {"in"}, {"out"}); - AddOp(&netdef, "GraphDummyOp2", {"out"}, {"out"}); - AddOp(&netdef, "GraphDummyOp1", {"out"}, {"out"}); - AddOp(&netdef, "GraphDummyOp2", {"out"}, {"out"}); - Graph g(netdef); - EXPECT_EQ(g.size(), 4); - for (int i = 0; i < 4; i++) { - if (i < 3) { - EXPECT_EQ(g.node(i).children.size(), 1); - EXPECT_TRUE(g.node(i).children.count(i + 1)); - } - if (i > 0) { - EXPECT_EQ(g.node(i).parents.size(), 1); - EXPECT_TRUE(g.node(i).parents.count(i - 1)); - } - } - NetDef retrieved_net = g.GetNetDef(); - compare_netdefs(retrieved_net, netdef); -} - -// Diamond Graph -TEST(GraphTest, TestGenerateGraphBranch) { - Workspace ws; - ws.CreateBlob("in"); - NetDef netdef; - - AddOp(&netdef, "GraphDummyOp1", {"in"}, {"mid1"}); - AddOp(&netdef, "GraphDummyOp2", {"mid1"}, {"mid2"}); - AddOp(&netdef, "GraphDummyOp2", {"mid1"}, {"mid3"}); - AddOp(&netdef, "GraphDummyOp3", {"mid2", "mid3"}, {"out"}); - - Graph g(netdef); - - EXPECT_EQ(g.size(), 4); - EXPECT_EQ(g.node(0).parents.size(), 0); - EXPECT_EQ(g.node(0).children.size(), 2); - EXPECT_EQ(g.node(1).parents.size(), 1); - EXPECT_EQ(g.node(1).children.size(), 1); - EXPECT_EQ(g.node(2).parents.size(), 1); - EXPECT_EQ(g.node(2).children.size(), 1); - EXPECT_EQ(g.node(3).parents.size(), 2); - EXPECT_EQ(g.node(3).children.size(), 0); - - NetDef retrieved_net = g.GetNetDef(); - compare_netdefs(retrieved_net, netdef); -} - -// Double Diamond Graph, reused names -TEST(GraphTest, TestReusedInputs) { - Workspace ws; - ws.CreateBlob("in"); - NetDef netdef; - - AddOp(&netdef, "GraphDummyOp1", {"in"}, {"in"}); - AddOp(&netdef, "GraphDummyOp2", {"in"}, {"mid1"}); - AddOp(&netdef, "GraphDummyOp2", {"in"}, {"mid2"}); - AddOp(&netdef, "GraphDummyOp3", {"mid1", "mid2"}, {"in"}); - AddOp(&netdef, "GraphDummyOp2", {"in"}, {"mid1"}); - AddOp(&netdef, "GraphDummyOp2", {"in"}, {"mid2"}); - AddOp(&netdef, "GraphDummyOp3", {"mid1", "mid2"}, {"in"}); - - Graph g(netdef); - - EXPECT_EQ(g.size(), 7); - EXPECT_EQ(g.node(0).parents.size(), 0); - EXPECT_EQ(g.node(0).children.size(), 2); - EXPECT_EQ(g.node(1).parents.size(), 1); - EXPECT_EQ(g.node(1).children.size(), 1); - EXPECT_EQ(g.node(2).parents.size(), 1); - EXPECT_EQ(g.node(2).children.size(), 1); - EXPECT_EQ(g.node(3).parents.size(), 2); - EXPECT_EQ(g.node(3).children.size(), 2); - EXPECT_EQ(g.node(4).parents.size(), 1); - EXPECT_EQ(g.node(4).children.size(), 1); - EXPECT_EQ(g.node(5).parents.size(), 1); - EXPECT_EQ(g.node(5).children.size(), 1); - EXPECT_EQ(g.node(6).parents.size(), 2); - EXPECT_EQ(g.node(6).children.size(), 0); - - NetDef retrieved_net = g.GetNetDef(); - compare_netdefs(retrieved_net, netdef); -} - -TEST(GraphTest, TestGetPerimeter) { - Workspace ws; - ws.CreateBlob("in"); - NetDef netdef; - - AddOp(&netdef, "GraphDummyOp1", {"in"}, {"in"}); - AddOp(&netdef, "GraphDummyOp2", {"in"}, {"mid1"}); - AddOp(&netdef, "GraphDummyOp2", {"in"}, {"mid2"}); - AddOp(&netdef, "GraphDummyOp3", {"mid1", "mid2"}, {"in"}); - AddOp(&netdef, "GraphDummyOp2", {"in"}, {"mid1"}); - AddOp(&netdef, "GraphDummyOp2", {"in"}, {"mid2"}); - AddOp(&netdef, "GraphDummyOp1", {"mid1", "mid2"}, {"in"}); - - Graph g(netdef); - std::vector subgraph = {3}; - - auto subgraph_input = g.GetSubgraphInput(subgraph); - EXPECT_EQ(subgraph_input.size(), 2); - EXPECT_EQ(subgraph_input[0], std::make_pair(string("mid1"), 1)); - EXPECT_EQ(subgraph_input[1], std::make_pair(string("mid2"), 2)); - - auto subgraph_output = g.GetSubgraphOutput(subgraph); - EXPECT_EQ(subgraph_output.size(), 2); - EXPECT_EQ(subgraph_output[0], std::make_pair(string("in"), 4)); - EXPECT_EQ(subgraph_output[1], std::make_pair(string("in"), 5)); -} - -} // namespace - -} // namespace caffe2 diff --git a/caffe2/core/init_test.cc b/caffe2/core/init_test.cc deleted file mode 100644 index b94d610f5a91..000000000000 --- a/caffe2/core/init_test.cc +++ /dev/null @@ -1,72 +0,0 @@ -#include -#include - -#include -#include "caffe2/core/init.h" -#include "caffe2/core/logging.h" - -namespace caffe2 { -namespace { -bool gTestInitFunctionHasBeenRun = false; -bool gTestFailInitFunctionHasBeenRun = false; - -bool TestInitFunction(int*, char***) { - gTestInitFunctionHasBeenRun = true; - return true; -} - -bool TestFailInitFunction(int*, char***) { - gTestFailInitFunctionHasBeenRun = true; - return false; -} - -REGISTER_CAFFE2_INIT_FUNCTION( - TestInitFunction, - &TestInitFunction, - "Just a test to see if GlobalInit invokes " - "registered functions correctly."); - -int dummy_argc = 1; -const char* dummy_name = "foo"; -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,cppcoreguidelines-pro-type-const-cast) -char** dummy_argv = const_cast(&dummy_name); -} // namespace - -TEST(InitTest, TestInitFunctionHasRun) { - caffe2::GlobalInit(&dummy_argc, &dummy_argv); - EXPECT_TRUE(gTestInitFunctionHasBeenRun); - EXPECT_FALSE(gTestFailInitFunctionHasBeenRun); -} - -TEST(InitTest, CanRerunGlobalInit) { - caffe2::GlobalInit(&dummy_argc, &dummy_argv); - EXPECT_TRUE(caffe2::GlobalInit(&dummy_argc, &dummy_argv)); -} - -void LateRegisterInitFunction() { - ::caffe2::InitRegisterer testInitFunc( - TestInitFunction, false, "This should succeed but warn"); -} - -void LateRegisterEarlyInitFunction() { - ::caffe2::InitRegisterer testSecondInitFunc( - TestInitFunction, true, "This should fail for early init"); -} - -void LateRegisterFailInitFunction() { - ::caffe2::InitRegisterer testSecondInitFunc( - TestFailInitFunction, false, "This should fail for failed init"); -} - -TEST(InitTest, FailLateRegisterInitFunction) { - caffe2::GlobalInit(&dummy_argc, &dummy_argv); - LateRegisterInitFunction(); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - EXPECT_THROW(LateRegisterEarlyInitFunction(), ::c10::Error); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - EXPECT_THROW(LateRegisterFailInitFunction(), ::c10::Error); - EXPECT_TRUE(gTestInitFunctionHasBeenRun); - EXPECT_TRUE(gTestFailInitFunctionHasBeenRun); -} - -} // namespace caffe2 diff --git a/caffe2/core/module_test.cc b/caffe2/core/module_test.cc deleted file mode 100644 index 585451d23b10..000000000000 --- a/caffe2/core/module_test.cc +++ /dev/null @@ -1,78 +0,0 @@ -#include -#include - -#include "caffe2/core/module.h" -#include "caffe2/core/operator.h" -#include -#include "caffe2/core/logging.h" - -// An explicitly defined module, testing correctness when we statically link a -// module -CAFFE2_MODULE(caffe2_module_test_static, "Static module for testing."); - -namespace caffe2 { - -class Caffe2ModuleTestStaticDummyOp : public OperatorBase { - public: - using OperatorBase::OperatorBase; - bool Run(int /* unused */ /*stream_id*/) override { - return true; - } - virtual string type() { - return "base"; - } -}; - -REGISTER_CPU_OPERATOR( - Caffe2ModuleTestStaticDummy, Caffe2ModuleTestStaticDummyOp); -OPERATOR_SCHEMA(Caffe2ModuleTestStaticDummy); - -TEST(ModuleTest, StaticModule) { - const string name = "caffe2_module_test_static"; - const auto& modules = CurrentModules(); - EXPECT_EQ(modules.count(name), 1); - EXPECT_TRUE(HasModule(name)); - - // LoadModule should not raise an error, since the module is already present. - LoadModule(name); - // Even a non-existing path should not cause error. - LoadModule(name, "/does/not/exist.so"); - EXPECT_EQ(modules.count(name), 1); - EXPECT_TRUE(HasModule(name)); - - // The module will then introduce the Caffe2ModuleTestStaticDummyOp. - OperatorDef op_def; - Workspace ws; - op_def.set_type("Caffe2ModuleTestStaticDummy"); - unique_ptr op = CreateOperator(op_def, &ws); - EXPECT_NE(nullptr, op.get()); -} - -#ifdef CAFFE2_BUILD_SHARED_LIBS -TEST(ModuleTest, DynamicModule) { - const string name = "caffe2_module_test_dynamic"; - const auto& modules = CurrentModules(); - EXPECT_EQ(modules.count(name), 0); - EXPECT_FALSE(HasModule(name)); - - // Before loading, we should not be able to create the op. - OperatorDef op_def; - Workspace ws; - op_def.set_type("Caffe2ModuleTestDynamicDummy"); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - EXPECT_THROW( - CreateOperator(op_def, &ws), - EnforceNotMet); - - // LoadModule should load the proper module. - LoadModule(name); - EXPECT_EQ(modules.count(name), 1); - EXPECT_TRUE(HasModule(name)); - - // The module will then introduce the Caffe2ModuleTestDynamicDummyOp. - unique_ptr op_after_load = CreateOperator(op_def, &ws); - EXPECT_NE(nullptr, op_after_load.get()); -} -#endif - -} // namespace caffe2 diff --git a/caffe2/core/net_async_tracing_test.cc b/caffe2/core/net_async_tracing_test.cc deleted file mode 100644 index 10a81ada9255..000000000000 --- a/caffe2/core/net_async_tracing_test.cc +++ /dev/null @@ -1,114 +0,0 @@ -/** - * Copyright (c) 2016-present, Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "caffe2/core/net_async_tracing.h" - -namespace caffe2 { - -namespace tracing { - -void testExtractShardId(const string& name, int expectedId) { - EXPECT_EQ(extractShardId(name), expectedId); -} - -TEST(NetAsyncTracingTest, ExtractShardId) { - testExtractShardId("ABCDEFshard:1705!!A", 1705); - // Should use the last one - testExtractShardId("ABCDEFshard:4324!!Ashard:01220b", 1220); - // Nothing to extract - testExtractShardId("ABCDEFsha:222", -1); - // Regular cases - testExtractShardId("FC:shard:0", 0); - testExtractShardId("FC:shard:10", 10); - testExtractShardId("FC:shard:15", 15); -} - -TEST(NetAsyncTracingTest, EveryKIteration) { - const auto spec = R"DOC( - name: "example" - type: "async_scheduling" - arg { - name: "enable_tracing" - i: 1 - } - arg { - name: "tracing_mode" - s: "EVERY_K_ITERATIONS" - } - arg { - name: "tracing_filepath" - s: "/tmp" - } - arg { - name: "trace_every_nth_batch" - i: 1 - } - arg { - name: "dump_every_nth_batch" - i: 1 - } - op { - output: "out" - type: "UniformFill" - } -)DOC"; - - NetDef net_def; - CAFFE_ENFORCE(TextFormat::ParseFromString(spec, &net_def)); - - Workspace ws; - std::unique_ptr net(CreateNet(net_def, &ws)); - net->Run(); -} - -TEST(NetAsyncTracingTest, GlobalTimeSlice) { - const auto spec = R"DOC( - name: "example" - type: "async_scheduling" - arg { - name: "enable_tracing" - i: 1 - } - arg { - name: "tracing_filepath" - s: "/tmp" - } - arg { - name: "trace_for_n_ms" - i: 1 - } - arg { - name: "trace_every_n_ms" - i: 1 - } - op { - output: "out" - type: "UniformFill" - } -)DOC"; - - NetDef net_def; - CAFFE_ENFORCE(TextFormat::ParseFromString(spec, &net_def)); - - Workspace ws; - std::unique_ptr net(CreateNet(net_def, &ws)); - net->Run(); -} - -} // namespace tracing - -} // namespace caffe2 diff --git a/caffe2/core/net_dag_utils_test.cc b/caffe2/core/net_dag_utils_test.cc deleted file mode 100644 index dfbb56614301..000000000000 --- a/caffe2/core/net_dag_utils_test.cc +++ /dev/null @@ -1,296 +0,0 @@ -#include -#include "caffe2/core/net_dag_utils.h" -#include "caffe2/core/operator.h" - -namespace caffe2 { - -namespace { -class DummySyncOp final : public Operator { - public: - DummySyncOp(const OperatorDef& operator_def, Workspace* ws) - : Operator(operator_def, ws) {} - - bool RunOnDevice() override { - return true; - } -}; - -class DummyAsyncOp final : public Operator { - public: - DummyAsyncOp(const OperatorDef& operator_def, Workspace* ws) - : Operator(operator_def, ws) {} - - bool RunOnDevice() override { - return true; - } - - bool HasAsyncPart() const override { - return true; - } -}; - -REGISTER_CPU_OPERATOR(DagUtilTestDummySync, DummySyncOp); -REGISTER_CPU_OPERATOR(DagUtilTestDummyAsync, DummyAsyncOp); - -OPERATOR_SCHEMA(DagUtilTestDummySync) - .NumInputs(0, INT_MAX) - .NumOutputs(0, INT_MAX); -OPERATOR_SCHEMA(DagUtilTestDummyAsync) - .NumInputs(0, INT_MAX) - .NumOutputs(0, INT_MAX); - -class DagUtilTestContext { - public: - DagUtilTestContext(const std::string& spec, Workspace* ws) { - net_def_ = std::make_shared(); - CAFFE_ENFORCE(TextFormat::ParseFromString(spec, net_def_.get())); - operator_nodes_ = dag_utils::prepareOperatorNodes(net_def_, ws); - } - - dag_utils::ExecutionChains computeChains() { - return dag_utils::computeGroups(operator_nodes_); - } - - private: - std::shared_ptr net_def_{nullptr}; - std::vector operator_nodes_; -}; - -void PrintChains(const dag_utils::ExecutionChains& chains) { - for (const auto& kv : chains) { - std::stringstream ss; - ss << kv.first << ": "; - for (const auto& v : kv.second) { - ss << v << ", "; - } - LOG(INFO) << ss.str(); - } -} -} // namespace - -TEST(DagUtilTest, Empty) { - const auto spec = R"DOC( - name: "test0" - type: "async_scheduling" - )DOC"; - Workspace ws; - DagUtilTestContext t(spec, &ws); - auto chains = t.computeChains(); - EXPECT_TRUE(chains.empty()); -} - -// 4 sync ops forming a diamond -TEST(DagUtilTest, AllSync) { - const auto spec = R"DOC( - name: "test1" - type: "async_scheduling" - external_input: "in" - op { - input: "in" - output: "n1" - type: "DagUtilTestDummySync" - } - op { - input: "n1" - output: "n2" - type: "DagUtilTestDummySync" - } - op { - input: "n1" - output: "n3" - type: "DagUtilTestDummySync" - } - op { - input: "n2" - input: "n3" - output: "out" - type: "DagUtilTestDummySync" - } - )DOC"; - Workspace ws; - ws.CreateBlob("in"); - DagUtilTestContext t(spec, &ws); - auto chains = t.computeChains(); - dag_utils::ExecutionChains expected{{0, {0, 1, 2, 3}}}; - EXPECT_EQ(chains, expected); -} - -// 3 async ops forming an L shape -TEST(DagUtilTest, AllAsync) { - const auto spec = R"DOC( - name: "test2" - type: "async_scheduling" - external_input: "in0" - external_input: "in1" - op { - input: "in0" - output: "n1" - type: "DagUtilTestDummyAsync" - } - op { - input: "in1" - output: "n2" - type: "DagUtilTestDummyAsync" - } - op { - input: "n1" - output: "n3" - type: "DagUtilTestDummyAsync" - } - )DOC"; - Workspace ws; - ws.CreateBlob("in0"); - ws.CreateBlob("in1"); - DagUtilTestContext t(spec, &ws); - auto chains = t.computeChains(); - dag_utils::ExecutionChains expected{{0, {0}}, {1, {1}}, {2, {2}}}; - EXPECT_EQ(chains, expected); -} - -// 3 sync ops and 1 async op (#2) forming a diamond -TEST(DagUtilTest, Mixed0) { - const auto spec = R"DOC( - name: "test3" - type: "async_scheduling" - external_input: "in" - op { - input: "in" - output: "n1" - type: "DagUtilTestDummySync" - } - op { - input: "n1" - output: "n2" - type: "DagUtilTestDummySync" - } - op { - input: "n1" - output: "n3" - type: "DagUtilTestDummyAsync" - } - op { - input: "n2" - input: "n3" - output: "out" - type: "DagUtilTestDummySync" - } - )DOC"; - Workspace ws; - ws.CreateBlob("in"); - DagUtilTestContext t(spec, &ws); - auto chains = t.computeChains(); - dag_utils::ExecutionChains expected{{0, {0, 1}}, {2, {2}}, {3, {3}}}; - EXPECT_EQ(chains, expected); -} - -// 3 sync ops and 1 async op (#2) forming a Y shape -TEST(DagUtilTest, Mixed1) { - const auto spec = R"DOC( - name: "test3" - type: "async_scheduling" - external_input: "in0" - external_input: "in1" - op { - input: "in0" - output: "n1" - type: "DagUtilTestDummySync" - } - op { - input: "in1" - output: "n2" - type: "DagUtilTestDummySync" - } - op { - input: "n1" - input: "n2" - output: "n3" - type: "DagUtilTestDummyAsync" - } - op { - input: "n3" - output: "out" - type: "DagUtilTestDummySync" - } - )DOC"; - Workspace ws; - ws.CreateBlob("in0"); - ws.CreateBlob("in1"); - DagUtilTestContext t(spec, &ws); - auto chains = t.computeChains(); - dag_utils::ExecutionChains expected{{0, {0, 1}}, {2, {2}}, {3, {3}}}; - EXPECT_EQ(chains, expected); -} -// More complicated mixed case. * means async -// 0* -> 1* -> 2 -// | -// 3 -> 4 -> 5 -// | | -// | 6 -// - -> 8* -// 7* -/ -TEST(DagUtilTest, Mixed2) { - const auto spec = R"DOC( - name: "test4" - type: "async_scheduling" - external_input: "in0" - external_input: "in1" - external_input: "in2" - op { - input: "in0" - output: "n1" - type: "DagUtilTestDummyAsync" - } - op { - input: "n1" - output: "n2" - type: "DagUtilTestDummyAsync" - } - op { - input: "n2" - output: "out0" - type: "DagUtilTestDummySync" - } - op { - input: "in1" - output: "n3" - type: "DagUtilTestDummySync" - } - op { - input: "n1" - input: "n3" - output: "n4" - type: "DagUtilTestDummySync" - } - op { - input: "n4" - output: "out1" - type: "DagUtilTestDummySync" - } - op { - input: "n3" - output: "out2" - type: "DagUtilTestDummySync" - } - op { - input: "in2" - output: "n7" - type: "DagUtilTestDummyAsync" - } - op { - input: "n3" - input: "n7" - output: "out3" - type: "DagUtilTestDummyAsync" - } - )DOC"; - Workspace ws; - ws.CreateBlob("in0"); - ws.CreateBlob("in1"); - ws.CreateBlob("in2"); - DagUtilTestContext t(spec, &ws); - auto chains = t.computeChains(); - dag_utils::ExecutionChains expected{ - {0, {0}}, {1, {1}}, {3, {3, 6}}, {4, {4, 2, 5}}, {7, {7}}, {8, {8}}}; - EXPECT_EQ(chains, expected); -} -} // namespace caffe2 diff --git a/caffe2/core/net_gpu_test.cc b/caffe2/core/net_gpu_test.cc deleted file mode 100644 index 1eb6fa513a23..000000000000 --- a/caffe2/core/net_gpu_test.cc +++ /dev/null @@ -1,130 +0,0 @@ -#include -#include "caffe2/core/common_gpu.h" -#include "caffe2/core/net.h" -#include "caffe2/core/net_async_base.h" -#include "caffe2/core/operator.h" -#include "caffe2/core/scope_guard.h" - -namespace caffe2 { - -namespace { - -static std::atomic counter; - -// A net test dummy op that does nothing but scaffolding. Here, we -// inherit from OperatorBase because we instantiate on both CPU and -// GPU. In general, you want to only inherit from Operator. -class NetTestDummyOp final : public OperatorBase { - public: - using OperatorBase::OperatorBase; - - NetTestDummyOp(const OperatorDef& operator_def, Workspace* ws) - : OperatorBase(operator_def, ws), - fail_(OperatorBase::GetSingleArgument("fail", false)) {} - - bool Run(int /* unused */ /*stream_id*/) override { - if (fail_) { - return false; - } - counter.fetch_add(1); - return true; - } - - // Simulate CUDA operator behavior - bool HasAsyncPart() const override { - return debug_def().device_option().device_type() == PROTO_CUDA; - } - - bool SupportsAsyncScheduling() const override { - return debug_def().device_option().device_type() == PROTO_CUDA; - } - - protected: - const bool fail_; -}; - -REGISTER_CPU_OPERATOR(NetTestDummy, NetTestDummyOp); -REGISTER_CUDA_OPERATOR(NetTestDummy, NetTestDummyOp); -REGISTER_CPU_OPERATOR(NetTestDummy2, NetTestDummyOp); -REGISTER_CUDA_OPERATOR(NetTestDummy2, NetTestDummyOp); - -OPERATOR_SCHEMA(NetTestDummy) - .NumInputs(0, INT_MAX) - .NumOutputs(0, INT_MAX) - .AllowInplace({{0, 0}, {1, 1}}); -OPERATOR_SCHEMA(NetTestDummy2) - .NumInputs(0, INT_MAX) - .NumOutputs(0, INT_MAX) - .AllowInplace({{1, 0}}); - -} // namespace - -void testExecution(std::unique_ptr& net, int num_ops) { - // Run 100 times - for (int i = 0; i < 100; i++) { - counter.exchange(0); - net.get()->Run(); - ASSERT_EQ(num_ops, counter.load()); - } -} - -void checkChainingAndRun( - const char* spec, - const dag_utils::ExecutionChains& expected) { - Workspace ws; - ws.CreateBlob("in"); - NetDef net_def; - CAFFE_ENFORCE(TextFormat::ParseFromString(spec, &net_def)); - { - net_def.set_num_workers(4); - std::unique_ptr net(CreateNet(net_def, &ws)); - auto* dag = dynamic_cast_if_rtti(net.get()); - TORCH_CHECK_NOTNULL(dag); - const auto& chains = dag->TEST_execution_chains(); - EXPECT_EQ(chains, expected); - testExecution(net, net_def.op().size()); - } -} - -TEST(NetTest, DISABLED_ChainingForDifferentDevices) { - const auto spec = R"DOC( - name: "example" - type: "dag" - external_input: "in" - op { - input: "in" - output: "hidden" - type: "NetTestDummy" - } - op { - input: "hidden" - output: "out" - type: "NetTestDummy" - device_option { - device_type: 1 - } - } - op { - input: "out" - output: "out2" - type: "NetTestDummy" - device_option { - device_type: 1 - } - } - op { - input: "out2" - output: "out3" - type: "NetTestDummy" - device_option { - device_type: 1 - device_id: 1 - } - } -)DOC"; - if (HasCudaGPU() && NumCudaDevices() >= 2) { - checkChainingAndRun(spec, {{0, {0, 1, 2}}, {3, {3}}}); - } -} - -} // namespace caffe2 diff --git a/caffe2/core/net_simple_refcount_test.cc b/caffe2/core/net_simple_refcount_test.cc deleted file mode 100644 index 14acf998064a..000000000000 --- a/caffe2/core/net_simple_refcount_test.cc +++ /dev/null @@ -1,70 +0,0 @@ -#include -#include "c10/util/StringUtil.h" -#include "caffe2/core/net.h" -#include "caffe2/core/net_async_scheduling.h" -#include "caffe2/core/operator.h" -#include "caffe2/core/scope_guard.h" - -#include - -namespace caffe2 { - -namespace { - -// A net test dummy op that does nothing but scaffolding. Here, we -// inherit from OperatorBase because we instantiate on both CPU and -// GPU. In general, you want to only inherit from Operator. -class NetSimpleRefCountTestOp final : public Operator { - public: - NetSimpleRefCountTestOp(const OperatorDef& operator_def, Workspace* ws) - : Operator(operator_def, ws) {} - USE_OPERATOR_FUNCTIONS(CPUContext); - - bool RunOnDevice() override { - const int32_t& input = OperatorBase::Input(0); - int32_t* output = OperatorBase::Output(0); - *output = input + 1; - return true; - } -}; - -REGISTER_CPU_OPERATOR(NetSimpleRefCountTest, NetSimpleRefCountTestOp); - -OPERATOR_SCHEMA(NetSimpleRefCountTest).NumInputs(1).NumOutputs(1); - -TEST(NetSimpleRefCountTest, TestCorrectness) { - Workspace ws; - *(ws.CreateBlob("a")->GetMutable()) = 1; - NetDef net_def; - net_def.set_type("simple_refcount"); - net_def.add_op()->CopyFrom( - CreateOperatorDef("NetSimpleRefCountTest", "", {"a"}, {"b"})); - net_def.add_op()->CopyFrom( - CreateOperatorDef("NetSimpleRefCountTest", "", {"b"}, {"c"})); - net_def.add_op()->CopyFrom( - CreateOperatorDef("NetSimpleRefCountTest", "", {"b"}, {"d"})); - net_def.add_op()->CopyFrom( - CreateOperatorDef("NetSimpleRefCountTest", "", {"c"}, {"e"})); - // After execution, what should look like is: - // a = 1 - // b = deallocated - // c = deallocated - // d = 3 - // e = 4 - std::unique_ptr net(CreateNet(net_def, &ws)); - net->Run(); - // Note on ASSERT vs EXPECT: ASSERT will quit directly if condition not - // met, which is why we guard IsType<> calls with ASSERT so that the - // subsequent Get() calls do not product an exception. - ASSERT_TRUE(ws.GetBlob("a")->IsType()); - EXPECT_EQ(ws.GetBlob("a")->Get(), 1); - EXPECT_EQ(ws.GetBlob("b")->GetRaw(), nullptr); - EXPECT_EQ(ws.GetBlob("c")->GetRaw(), nullptr); - ASSERT_TRUE(ws.GetBlob("d")->IsType()); - EXPECT_EQ(ws.GetBlob("d")->Get(), 3); - ASSERT_TRUE(ws.GetBlob("e")->IsType()); - EXPECT_EQ(ws.GetBlob("e")->Get(), 4); -} - -} // namespace -} // namespace caffe2 diff --git a/caffe2/core/net_test.cc b/caffe2/core/net_test.cc deleted file mode 100644 index a1c80eca6790..000000000000 --- a/caffe2/core/net_test.cc +++ /dev/null @@ -1,1122 +0,0 @@ -#include -#include "c10/util/StringUtil.h" -#include "caffe2/core/net.h" -#include "caffe2/core/net_async_scheduling.h" -#include "caffe2/core/operator.h" -#include "caffe2/core/scope_guard.h" - -#include - -namespace caffe2 { - -namespace { - -static std::atomic counter; - -// A net test dummy op that does nothing but scaffolding. Here, we -// inherit from OperatorBase because we instantiate on both CPU and -// GPU. In general, you want to only inherit from Operator. -class NetTestDummyOp final : public OperatorBase { - public: - using OperatorBase::OperatorBase; - - NetTestDummyOp(const OperatorDef& operator_def, Workspace* ws) - : OperatorBase(operator_def, ws), - fail_(OperatorBase::GetSingleArgument("fail", false)) {} - - bool Run(int /* unused */ /*stream_id*/) override { - if (fail_) { - return false; - } - counter.fetch_add(1); - return true; - } - - // Simulate CUDA operator behavior - bool HasAsyncPart() const override { - return debug_def().device_option().device_type() == PROTO_CUDA; - } - - bool SupportsAsyncScheduling() const override { - return debug_def().device_option().device_type() == PROTO_CUDA; - } - - protected: - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - const bool fail_; -}; - -REGISTER_CPU_OPERATOR(NetTestDummy, NetTestDummyOp); -REGISTER_CUDA_OPERATOR(NetTestDummy, NetTestDummyOp); -REGISTER_CPU_OPERATOR(NetTestDummy2, NetTestDummyOp); -REGISTER_CUDA_OPERATOR(NetTestDummy2, NetTestDummyOp); - -OPERATOR_SCHEMA(NetTestDummy) - .NumInputs(0, INT_MAX) - .NumOutputs(0, INT_MAX) - .AllowInplace({{0, 0}, {1, 1}}); -OPERATOR_SCHEMA(NetTestDummy2) - .NumInputs(0, INT_MAX) - .NumOutputs(0, INT_MAX) - .AllowInplace({{1, 0}}); - -unique_ptr CreateNetTestHelper( - Workspace* ws, - const vector& input, - const vector& output) { - NetDef net_def; - { - auto& op = *(net_def.add_op()); - op.set_type("NetTestDummy"); - op.add_input("in"); - op.add_output("hidden"); - } - { - auto& op = *(net_def.add_op()); - op.set_type("NetTestDummy"); - op.add_input("hidden"); - op.add_output("out"); - } - - for (const auto& name : input) { - net_def.add_external_input(name); - } - for (const auto& name : output) { - net_def.add_external_output(name); - } - return CreateNet(net_def, ws); -} - -} // namespace - -TEST(NetTest, ConstructionNoDeclaredInputOutput) { - Workspace ws; - ws.CreateBlob("in"); - unique_ptr net( - CreateNetTestHelper(&ws, vector(), vector())); - EXPECT_TRUE(net.get() != nullptr); -} - -TEST(NetTest, ConstructionDeclaredInput) { - Workspace ws; - ws.CreateBlob("in"); - unique_ptr net( - CreateNetTestHelper(&ws, vector{"in"}, vector())); - EXPECT_TRUE(net.get() != nullptr); -} - -TEST(NetTest, ConstructionDeclaredOutput) { - Workspace ws; - ws.CreateBlob("in"); - unique_ptr net( - CreateNetTestHelper(&ws, vector(), vector{"out"})); - EXPECT_TRUE(net.get() != nullptr); -} - -TEST(NetTest, DeclaredInputInsufficient) { - Workspace ws; - ws.CreateBlob("in"); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_THROW( - CreateNetTestHelper(&ws, vector{"unuseful_in"}, vector()), - EnforceNotMet); -} - -TEST(NetDeathTest, DeclaredOutputNotMet) { - Workspace ws; - ws.CreateBlob("in"); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_THROW( - CreateNetTestHelper( - &ws, vector(), vector{"unproduced_out"}), - EnforceNotMet); -} - -void testExecution(std::unique_ptr& net, int num_ops) { - // Run 100 times - for (int i = 0; i < 100; i++) { - counter.exchange(0); - net.get()->Run(); - ASSERT_EQ(num_ops, counter.load()); - } -} - -void checkChainingAndRun( - const char* spec, - const dag_utils::ExecutionChains& expected) { - Workspace ws; - ws.CreateBlob("in"); - NetDef net_def; - CAFFE_ENFORCE(TextFormat::ParseFromString(spec, &net_def)); - { - net_def.set_num_workers(4); - std::unique_ptr net(CreateNet(net_def, &ws)); - auto* dag = dynamic_cast_if_rtti(net.get()); - TORCH_CHECK_NOTNULL(dag); - const auto& chains = dag->TEST_execution_chains(); - EXPECT_TRUE(chains == expected); - testExecution(net, net_def.op().size()); - } -} - -void checkNumChainsAndRun(const char* spec, const int expected_num_chains) { - Workspace ws; - - NetDef net_def; - CAFFE_ENFORCE(TextFormat::ParseFromString(spec, &net_def)); - net_def.set_num_workers(4); - - // Create all external inputs - // NOLINTNEXTLINE(performance-for-range-copy) - for (auto inp : net_def.external_input()) { - ws.CreateBlob(inp); - } - - { - std::unique_ptr net(CreateNet(net_def, &ws)); - auto* dag = dynamic_cast_if_rtti(net.get()); - TORCH_CHECK_NOTNULL(dag); - const auto& chains = dag->TEST_execution_chains(); - EXPECT_EQ(expected_num_chains, chains.size()); - testExecution(net, net_def.op().size()); - } -} - -TEST(NetTest, DISABLED_ChainingForLinearModel) { - const auto spec = R"DOC( - name: "example" - type: "dag" - external_input: "in" - op { - input: "in" - output: "hidden" - type: "NetTestDummy" - } - op { - input: "hidden" - output: "out" - type: "NetTestDummy" - } -)DOC"; - checkChainingAndRun(spec, {{0, {0, 1}}}); -} - -TEST(NetTest, DISABLED_ChainingForFork) { - const auto spec = R"DOC( - name: "example" - type: "dag" - external_input: "in" - op { - input: "in" - output: "hidden" - type: "NetTestDummy" - } - op { - input: "hidden" - output: "out1" - type: "NetTestDummy" - } - op { - input: "hidden" - output: "out2" - type: "NetTestDummy" - } -)DOC"; - checkChainingAndRun(spec, {{0, {0}}, {1, {1}}, {2, {2}}}); -} - -// TEST(NetTest, ChainingForJoinWithAncestor) { -// const auto spec = R"DOC( -// name: "example" -// type: "dag" -// external_input: "in" -// op { -// input: "in" -// output: "hidden" -// type: "NetTestDummy" -// } -// op { -// input: "hidden" -// output: "out1" -// type: "NetTestDummy" -// } -// op { -// input: "hidden" -// output: "out2" -// type: "NetTestDummy" -// } -// op { -// input: "hidden" -// input: "out2" -// type: "NetTestDummy" -// } -// )DOC"; -// checkChainingAndRun(spec, {{0, {0}}, {1, {1}}, {2, {2, 3}}}); -// } - -TEST(NetTest, DISABLED_ChainingForForkJoin) { - const auto spec = R"DOC( - name: "example" - type: "dag" - external_input: "in" - op { - input: "in" - output: "hidden1" - type: "NetTestDummy" - } - op { - input: "in" - output: "hidden2" - type: "NetTestDummy" - } - op { - input: "hidden1" - input: "hidden2" - output: "out" - type: "NetTestDummy" - } - op { - input: "out" - output: "out2" - type: "NetTestDummy" - } -)DOC"; - checkChainingAndRun(spec, {{0, {0}}, {1, {1}}, {2, {2, 3}}}); -} - -TEST(NetTest, DISABLED_ChainingForwardBackward) { - const auto spec = R"DOC( - name: "gpu_0" - type: "dag" - op { - input: "in" - input: "fc_0_w" - input: "fc_0_b" - output: "fc_0" - name: "0" - type: "NetTestDummy" - } - op { - input: "fc_0" - output: "fc_0" - name: "1" - type: "NetTestDummy" - } - op { - input: "fc_0" - input: "fc_1_w" - input: "fc_1_b" - output: "fc_1" - name: "2" - type: "NetTestDummy" - } - op { - input: "fc_1" - output: "fc_1" - name: "3" - type: "NetTestDummy" - } - op { - input: "fc_1" - input: "fc_2_w" - input: "fc_2_b" - output: "fc_2" - name: "4" - type: "NetTestDummy" - } - op { - input: "fc_2" - output: "fc_2" - name: "5" - type: "NetTestDummy" - } - op { - input: "fc_2" - input: "fc_3_w" - input: "fc_3_b" - output: "fc_3" - name: "6" - type: "NetTestDummy" - } - op { - input: "fc_3" - output: "fc_3" - name: "7" - type: "NetTestDummy" - } - op { - input: "fc_3" - input: "fc_4_w" - input: "fc_4_b" - output: "fc_4" - name: "8" - type: "NetTestDummy" - } - op { - input: "fc_4" - output: "fc_4" - name: "9" - type: "NetTestDummy" - } - op { - input: "fc_4" - input: "in2" - output: "LabelCrossEntropy" - name: "10" - type: "NetTestDummy" - } - op { - input: "LabelCrossEntropy" - output: "AveragedLoss" - name: "11" - type: "NetTestDummy" - } - op { - input: "AveragedLoss" - output: "AveragedLoss_autogen_grad" - name: "12" - type: "NetTestDummy" - } - op { - input: "LabelCrossEntropy" - input: "AveragedLoss_autogen_grad" - output: "LabelCrossEntropy_grad" - name: "13" - type: "NetTestDummy" - } - op { - input: "fc_4" - input: "label" - input: "LabelCrossEntropy_grad" - output: "fc_4_grad" - name: "14" - type: "NetTestDummy2" - } - op { - input: "fc_4" - input: "fc_4_grad" - output: "fc_4_grad" - name: "15" - type: "NetTestDummy2" - } - op { - input: "fc_3" - input: "fc_4_w" - input: "fc_4_grad" - output: "fc_4_w_grad" - output: "fc_4_b_grad" - output: "fc_3_grad" - name: "16" - type: "NetTestDummy" - } - op { - input: "fc_3" - input: "fc_3_grad" - output: "fc_3_grad" - name: "17" - type: "NetTestDummy2" - } - op { - input: "fc_2" - input: "fc_3_w" - input: "fc_3_grad" - output: "fc_3_w_grad" - output: "fc_3_b_grad" - output: "fc_2_grad" - name: "18" - type: "NetTestDummy" - } - op { - input: "fc_2" - input: "fc_2_grad" - output: "fc_2_grad" - name: "19" - type: "NetTestDummy2" - } - op { - input: "fc_1" - input: "fc_2_w" - input: "fc_2_grad" - output: "fc_2_w_grad" - output: "fc_2_b_grad" - output: "fc_1_grad" - name: "20" - type: "NetTestDummy" - } - op { - input: "fc_1" - input: "fc_1_grad" - output: "fc_1_grad" - name: "21" - type: "NetTestDummy2" - } - op { - input: "fc_0" - input: "fc_1_w" - input: "fc_1_grad" - output: "fc_1_w_grad" - output: "fc_1_b_grad" - output: "fc_0_grad" - name: "22" - type: "NetTestDummy" - } - op { - input: "fc_0" - input: "fc_0_grad" - output: "fc_0_grad" - name: "23" - type: "NetTestDummy2" - } - op { - input: "in" - input: "fc_0_w" - input: "fc_0_grad" - output: "fc_0_w_grad" - output: "fc_0_b_grad" - output: "data_grad" - name: "24" - type: "NetTestDummy" - } - external_input: "in" - external_input: "in2" - external_input: "LR" - external_input: "fc_0_w" - external_input: "fc_0_b" - external_input: "fc_1_w" - external_input: "fc_1_b" - external_input: "fc_2_w" - external_input: "fc_2_b" - external_input: "fc_3_w" - external_input: "fc_3_b" - external_input: "fc_4_w" - external_input: "fc_4_b" - external_input: "label" - )DOC"; - checkNumChainsAndRun(spec, 1); -} - -TEST(NetTest, DISABLED_ChainingForHogwildModel) { - const auto spec = R"DOC( - name: "example" - type: "dag" - external_input: "in" - op { - input: "in" - output: "hidden1" - type: "NetTestDummy" - } - op { - input: "hidden1" - output: "mid1" - type: "NetTestDummy" - } - op { - input: "mid1" - output: "out1" - type: "NetTestDummy" - } - op { - input: "in" - output: "hidden2" - type: "NetTestDummy" - } - op { - input: "hidden2" - output: "mid2" - type: "NetTestDummy" - } - op { - input: "mid2" - output: "out2" - type: "NetTestDummy" - } -)DOC"; - checkNumChainsAndRun(spec, 2); -} - -TEST(NetTest, DISABLED_FailingOperator) { - const auto spec = R"DOC( - name: "example" - type: "dag" - external_input: "in" - op { - input: "in" - output: "hidden" - type: "NetTestDummy" - } - op { - input: "hidden" - output: "out" - type: "NetTestDummy" - arg { - name: "fail" - i: 1 - } - } -)DOC"; - - Workspace ws; - ws.CreateBlob("in"); - - NetDef net_def; - CAFFE_ENFORCE(TextFormat::ParseFromString(spec, &net_def)); - - { - net_def.set_num_workers(4); - std::unique_ptr net(CreateNet(net_def, &ws)); - for (int i = 0; i < 10; i++) { - counter.exchange(0); - bool run_result = false; - try { - run_result = net->Run(); - } catch (const std::exception&) { - // async_scheduling would throw - } - ASSERT_FALSE(run_result); - - ASSERT_EQ(1, counter.load()); - } - } -} - -const int kTestPoolSize = 4; - -class ExecutorHelperDummyOp final : public OperatorBase { - public: - using OperatorBase::OperatorBase; - - ExecutorHelperDummyOp(const OperatorDef& operator_def, Workspace* ws) - : OperatorBase(operator_def, ws) {} - - bool Run(int /* unused */ /*stream_id*/) override { - auto helper = GetExecutorHelper(); - CAFFE_ENFORCE(helper); - auto pool = helper->GetPool(device_option()); - CAFFE_ENFORCE(pool); - auto pool_size = pool->size(); - CAFFE_ENFORCE_EQ(pool_size, kTestPoolSize); - return true; - } -}; - -REGISTER_CPU_OPERATOR(ExecutorHelperDummy, ExecutorHelperDummyOp); - -OPERATOR_SCHEMA(ExecutorHelperDummy); - -TEST(NetTest, OperatorWithExecutorHelper) { - const auto spec = R"DOC( - name: "example" - type: "async_scheduling" - op { - type: "ExecutorHelperDummy" - } -)DOC"; - - NetDef net_def; - CAFFE_ENFORCE(TextFormat::ParseFromString(spec, &net_def)); - - Workspace ws; - net_def.set_num_workers(kTestPoolSize); - std::unique_ptr net(CreateNet(net_def, &ws)); - ASSERT_TRUE(net->Run()); -} - -TEST(NetTest, DISABLED_OperatorWithDisabledEvent) { - const auto spec = R"DOC( - name: "example" - type: "async_scheduling" - external_input: "in" - op { - input: "in" - output: "out" - type: "NetTestDummy" - arg { - name: "fail" - i: 1 - } - } -)DOC"; - - Workspace ws; - ws.CreateBlob("in"); - - NetDef net_def; - CAFFE_ENFORCE(TextFormat::ParseFromString(spec, &net_def)); - - { - std::unique_ptr net(CreateNet(net_def, &ws)); - net->GetOperators()[0]->DisableEvent(); - // async_scheduling propagates exception - bool caught_exception = false; - try { - net->Run(); - } catch (const std::exception& e) { - caught_exception = true; - } - ASSERT_TRUE(caught_exception); - } -} - -TEST(NetTest, ExecutorOverride) { - const auto spec = R"DOC( - name: "example" - type: "dag" - )DOC"; - - NetDef net_def; - CAFFE_ENFORCE(TextFormat::ParseFromString(spec, &net_def)); - - { - Workspace ws; - auto old = FLAGS_caffe2_override_executor; - auto g = MakeGuard([&]() { FLAGS_caffe2_override_executor = old; }); - FLAGS_caffe2_override_executor = "dag,async_scheduling"; - - std::unique_ptr net(CreateNet(net_def, &ws)); - auto async_net = - caffe2::dynamic_cast_if_rtti(net.get()); - ASSERT_TRUE(async_net != nullptr); - } -} - -TEST(NetTest, AsyncEmptyNet) { - const auto spec = R"DOC( - name: "example" - type: "async_scheduling" - )DOC"; - - Workspace ws; - NetDef net_def; - CAFFE_ENFORCE(TextFormat::ParseFromString(spec, &net_def)); - - { - std::unique_ptr net(CreateNet(net_def, &ws)); - bool caught_exception = false; - try { - ASSERT_TRUE(net->Run()); - } catch (const std::exception& e) { - caught_exception = true; - } - ASSERT_FALSE(caught_exception); - } -} - -TEST(NetTest, DISABLED_RunAsyncFailure) { - const auto spec = R"DOC( - name: "example" - type: "async_scheduling" - op { - input: "in" - output: "out" - type: "NetTestDummy" - arg { - name: "fail" - i: 1 - } - } - )DOC"; - - Workspace ws; - ws.CreateBlob("in"); - - NetDef net_def; - CAFFE_ENFORCE(TextFormat::ParseFromString(spec, &net_def)); - - { - std::unique_ptr net(CreateNet(net_def, &ws)); - - bool caught_exception = false; - try { - ASSERT_FALSE(net->Run()); - } catch (const std::exception& e) { - caught_exception = true; - } - ASSERT_TRUE(caught_exception); - } -} - -TEST(NetTest, NoTypeNet) { - const auto spec = R"DOC( - name: "no_type_net" - )DOC"; - - Workspace ws; - NetDef net_def; - CAFFE_ENFORCE(TextFormat::ParseFromString(spec, &net_def)); - - { - std::unique_ptr net(CreateNet(net_def, &ws)); - ASSERT_TRUE(net); - } -} - -class NotFinishingOp final : public Operator { - public: - NotFinishingOp(const OperatorDef& operator_def, Workspace* ws) - : Operator(operator_def, ws) {} - - bool RunOnDevice() override { - // never calls SetFinished - return true; - } - - bool HasAsyncPart() const override { - return true; - } -}; - -REGISTER_CPU_OPERATOR(NotFinishingOp, NotFinishingOp); - -OPERATOR_SCHEMA(NotFinishingOp); - -TEST(NetTest, PendingOpsAndNetFailure) { - const auto spec = R"DOC( - name: "example" - type: "async_scheduling" - op { - type: "NotFinishingOp" - } - op { - type: "NetTestDummy" - arg { - name: "fail" - i: 1 - } - } -)DOC"; - - NetDef net_def; - CAFFE_ENFORCE(TextFormat::ParseFromString(spec, &net_def)); - - Workspace ws; - std::unique_ptr net(CreateNet(net_def, &ws)); - - try { - // net is not stuck and returns false - ASSERT_FALSE(net->Run()); - } catch (const caffe2::AsyncNetCancelled&) { - // Cancellation exception is fine since if the ops run concurrently the - // NotFinishingOp may be cancelled with an exception. - } -} - -class AsyncErrorOp final : public Operator { - public: - AsyncErrorOp(const OperatorDef& operator_def, Workspace* ws) - : Operator(operator_def, ws), - throw_(OperatorBase::GetSingleArgument("throw", false)), - fail_in_sync_( - OperatorBase::GetSingleArgument("fail_in_sync", false)), - sleep_time_s_(OperatorBase::GetSingleArgument("sleep_time", 1)), - error_msg_(OperatorBase::GetSingleArgument( - "error_msg", - "Error")) {} - - bool RunOnDevice() override { - if (fail_in_sync_) { - if (throw_) { - throw std::logic_error(error_msg_); - } else { - return false; - } - } else { - if (thread_) { - thread_->join(); - } - thread_ = std::make_unique([this]() { - try { - std::this_thread::sleep_for(std::chrono::seconds(sleep_time_s_)); - if (throw_) { - throw std::logic_error(error_msg_); - } else { - if (!cancel_.test_and_set()) { - event().SetFinished(error_msg_.c_str()); - } - } - } catch (...) { - if (!cancel_.test_and_set()) { - event().SetFinishedWithException(error_msg_.c_str()); - } - } - }); - return true; - } - } - - bool HasAsyncPart() const override { - return true; - } - - void CancelAsyncCallback() override { - cancel_.test_and_set(); - } - - ~AsyncErrorOp() override { - if (thread_) { - thread_->join(); - } - } - - private: - std::unique_ptr thread_; - bool throw_; - bool fail_in_sync_; - int sleep_time_s_; - std::string error_msg_; - std::atomic_flag cancel_ = ATOMIC_FLAG_INIT; -}; - -REGISTER_CPU_OPERATOR(AsyncErrorOp, AsyncErrorOp); -OPERATOR_SCHEMA(AsyncErrorOp); - -std::unique_ptr AsyncErrorNet( - Workspace* ws, - const std::string& net_name, - bool throw_, - bool fail_in_sync) { - std::string spec_template = R"DOC( - name: "" - type: "async_scheduling" - op { - type: "AsyncErrorOp" - arg { - name: "throw" - i: - } - arg { - name: "fail_in_sync" - i: - } - } - )DOC"; - - std::string spec = spec_template; - ReplaceAll(spec, "", net_name.c_str()); - ReplaceAll(spec, "", throw_ ? "1" : "0"); - ReplaceAll(spec, "", fail_in_sync ? "1" : "0"); - - NetDef net_def; - CAFFE_ENFORCE(TextFormat::ParseFromString(spec, &net_def)); - return CreateNet(net_def, ws); -} - -TEST(NetTest, AsyncErrorOpTest) { - Workspace ws; - - // Throw in sync part - auto net = AsyncErrorNet(&ws, "net1", /*throw_*/ true, /*fail_in_sync*/ true); -#ifdef CAFFE2_USE_EXCEPTION_PTR - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_THROW(net->Run(), std::logic_error); -#endif - - // Return false in sync part - net = AsyncErrorNet(&ws, "net2", /*throw_*/ false, /*fail_in_sync*/ true); - ASSERT_FALSE(net->Run()); - - // SetFinishedWithException in async part - net = AsyncErrorNet(&ws, "net3", /*throw_*/ true, /*fail_in_sync*/ false); -#ifdef CAFFE2_USE_EXCEPTION_PTR - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_THROW(net->Run(), std::logic_error); -#endif - - // SetFinished(err) in async part - net = AsyncErrorNet(&ws, "net4", /*throw_*/ false, /*fail_in_sync*/ false); - ASSERT_FALSE(net->Run()); -} - -TEST(NetTest, AsyncErrorTimingsTest) { - Workspace ws; - std::string spec = R"DOC( - name: "net" - type: "async_scheduling" - op { - type: "AsyncErrorOp" - arg { - name: "throw" - i: 1 - } - arg { - name: "fail_in_sync" - i: 0 - } - arg { - name: "sleep_time" - i: 2 - } - arg { - name: "error_msg" - s: "Error1" - } - } - op { - type: "AsyncErrorOp" - arg { - name: "throw" - i: 1 - } - arg { - name: "fail_in_sync" - i: 0 - } - arg { - name: "sleep_time" - i: 1 - } - arg { - name: "error_msg" - s: "Error2" - } - } - )DOC"; - - NetDef net_def; - CAFFE_ENFORCE(TextFormat::ParseFromString(spec, &net_def)); - auto net = CreateNet(net_def, &ws); - - try { - net->Run(); - } catch (const std::logic_error& e) { - ASSERT_TRUE(std::string(e.what()) == "Error2"); - } catch (...) { - FAIL() << "Expected std::logic_error thrown"; - } -} - -class SyncErrorOp final : public Operator { - public: - SyncErrorOp(const OperatorDef& operator_def, Workspace* ws) - : Operator(operator_def, ws), - fail_(OperatorBase::GetSingleArgument("fail", true)), - throw_(OperatorBase::GetSingleArgument("throw", false)) {} - - bool RunOnDevice() override { - if (fail_) { - if (throw_) { - throw std::logic_error("Error"); - } else { - return false; - } - } else { - return true; - } - } - - // NOLINTNEXTLINE(modernize-use-equals-default) - ~SyncErrorOp() override {} - - private: - bool fail_; - bool throw_; -}; - -REGISTER_CPU_OPERATOR(SyncErrorOp, SyncErrorOp); -OPERATOR_SCHEMA(SyncErrorOp); - -std::unique_ptr -ChainErrorNet(Workspace* ws, const std::string& net_name, bool throw_) { - std::string spec_template = R"DOC( - name: "" - type: "async_scheduling" - op { - type: "SyncErrorOp" - arg { - name: "fail" - i: 1 - } - arg { - name: "throw" - i: - } - } - op { - type: "SyncErrorOp" - arg { - name: "fail" - i: 0 - } - } - )DOC"; - - std::string spec = spec_template; - ReplaceAll(spec, "", net_name.c_str()); - ReplaceAll(spec, "", throw_ ? "1" : "0"); - - NetDef net_def; - CAFFE_ENFORCE(TextFormat::ParseFromString(spec, &net_def)); - return CreateNet(net_def, ws); -} - -TEST(NetTest, ChainErrorTest) { - Workspace ws; - - auto net = ChainErrorNet(&ws, "net1", /*throw_*/ true); -#ifdef CAFFE2_USE_EXCEPTION_PTR - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_THROW(net->Run(), std::logic_error); -#endif - - net = ChainErrorNet(&ws, "net2", /*throw_*/ false); - ASSERT_FALSE(net->Run()); -} - -void testProfDAGNetErrorCase(bool test_error) { - std::string spec_template = R"DOC( - name: "prof_dag_error_test_net" - type: "prof_dag" - external_input: "in" - op { - input: "in" - output: "hidden" - type: "SyncErrorOp" - arg { - name: "fail" - i: - } - arg { - name: "throw" - i: 0 - } - } - op { - input: "hidden" - output: "out" - type: "SyncErrorOp" - arg { - name: "fail" - i: 0 - } - } - )DOC"; - - Workspace ws; - ws.CreateBlob("in"); - - NetDef net_def; - std::string net_spec = spec_template; - ReplaceAll(net_spec, "", test_error ? "1" : "0"); - CAFFE_ENFORCE(TextFormat::ParseFromString(net_spec, &net_def)); - auto net = CreateNet(net_def, &ws); - - // with failing op - net runs return false, without - true - for (auto num_runs = 0; num_runs < 10; ++num_runs) { - auto ret = net->Run(); - ASSERT_TRUE(test_error ? !ret : ret); - } - - // with failing op - prof_dag handles invalid runs and returns empty stats, - // without - returns stats for each op - auto* prof_dag = dynamic_cast_if_rtti(net.get()); - TORCH_CHECK_NOTNULL(prof_dag); - auto stats_proto = prof_dag->GetPerOperatorCost(); - ASSERT_EQ( - stats_proto.stats_size(), test_error ? 0 : net->GetOperators().size()); -} - -TEST(NetTest, ProfDAGNetErrorTest) { - testProfDAGNetErrorCase(/*test_error=*/false); - testProfDAGNetErrorCase(/*test_error=*/true); -} - -} // namespace caffe2 diff --git a/caffe2/core/observer_test.cc b/caffe2/core/observer_test.cc deleted file mode 100644 index 50faf92e8414..000000000000 --- a/caffe2/core/observer_test.cc +++ /dev/null @@ -1,183 +0,0 @@ -#include -#include "c10/util/Registry.h" -#include "caffe2/core/common.h" -#include "caffe2/core/net.h" -#include "caffe2/core/net_simple.h" -#include "caffe2/core/observer.h" -#include "caffe2/core/operator.h" -#include "caffe2/core/scope_guard.h" - -namespace caffe2 { - -namespace { - -static std::atomic counter; - -template -class DummyObserver final : public ObserverBase { - public: - explicit DummyObserver(T* subject_) : ObserverBase(subject_) {} - void Start() override; - void Stop() override; - - // NOLINTNEXTLINE(modernize-use-equals-default) - ~DummyObserver() override {} -}; - -template <> -void DummyObserver::Start() { - vector operators = subject_->GetOperators(); - for (auto& op : operators) { - op->AttachObserver(std::make_unique>(op)); - } - counter.fetch_add(1000); -} - -template <> -void DummyObserver::Start() { - counter.fetch_add(100); -} - -template <> -void DummyObserver::Stop() { - counter.fetch_add(10); -} - -template <> -void DummyObserver::Stop() { - counter.fetch_add(1); -} - -class ObsTestDummyOp final : public OperatorBase { - public: - using OperatorBase::OperatorBase; - bool Run(int /* unused */) override { - StartAllObservers(); - StopAllObservers(); - return true; - } -}; - -REGISTER_CPU_OPERATOR(ObsTestDummy, ObsTestDummyOp); -REGISTER_CUDA_OPERATOR(ObsTestDummy, ObsTestDummyOp); - -OPERATOR_SCHEMA(ObsTestDummy) - .NumInputs(0, INT_MAX) - .NumOutputs(0, INT_MAX) - .AllowInplace({{0, 0}, {1, 1}}); - -unique_ptr CreateNetTestHelper(Workspace* ws, bool isDAG = false) { - NetDef net_def; - if (isDAG) { - net_def.set_type("dag"); - } - { - auto& op = *(net_def.add_op()); - op.set_type("ObsTestDummy"); - op.add_input("in"); - op.add_output("hidden"); - } - { - auto& op = *(net_def.add_op()); - op.set_type("ObsTestDummy"); - op.add_input("hidden"); - op.add_output("out"); - } - net_def.add_external_input("in"); - net_def.add_external_output("out"); - - return CreateNet(net_def, ws); -} -} - -TEST(ObserverTest, TestNotify) { - auto count_before = counter.load(); - Workspace ws; - ws.CreateBlob("in"); - NetDef net_def; - unique_ptr net(CreateNetTestHelper(&ws)); - EXPECT_EQ(caffe2::dynamic_cast_if_rtti(net.get()), net.get()); - unique_ptr> net_ob = - make_unique>(net.get()); - net.get()->AttachObserver(std::move(net_ob)); - net.get()->Run(); - auto count_after = counter.load(); - EXPECT_EQ(1212, count_after - count_before); -} - -TEST(ObserverTest, TestUniqueMap) { - auto count_before = counter.load(); - Workspace ws; - ws.CreateBlob("in"); - NetDef net_def; - unique_ptr net(CreateNetTestHelper(&ws)); - EXPECT_EQ(caffe2::dynamic_cast_if_rtti(net.get()), net.get()); - unique_ptr> net_ob = - make_unique>(net.get()); - auto* ref = net.get()->AttachObserver(std::move(net_ob)); - net.get()->Run(); - unique_ptr::Observer> test = - net.get()->DetachObserver(ref); - auto count_after = counter.load(); - EXPECT_EQ(1212, count_after - count_before); -} - -TEST(ObserverTest, TestNotifyAfterDetach) { - auto count_before = counter.load(); - Workspace ws; - ws.CreateBlob("in"); - NetDef net_def; - unique_ptr net(CreateNetTestHelper(&ws)); - unique_ptr> net_ob = - make_unique>(net.get()); - auto* ob = net.get()->AttachObserver(std::move(net_ob)); - net.get()->DetachObserver(ob); - net.get()->Run(); - auto count_after = counter.load(); - EXPECT_EQ(0, count_after - count_before); -} - -TEST(ObserverTest, TestDAGNetBase) { - auto count_before = counter.load(); - Workspace ws; - ws.CreateBlob("in"); - NetDef net_def; - unique_ptr net(CreateNetTestHelper(&ws, true)); - unique_ptr> net_ob = - make_unique>(net.get()); - net.get()->AttachObserver(std::move(net_ob)); - net.get()->Run(); - auto count_after = counter.load(); - EXPECT_EQ(1212, count_after - count_before); -} - -#if 0 -// This test intermittently segfaults, -// see https://github.com/pytorch/pytorch/issues/9137 -TEST(ObserverTest, TestMultipleNetBase) { - Workspace ws; - ws.CreateBlob("in"); - NetDef net_def; - unique_ptr net(CreateNetTestHelper(&ws, true)); - EXPECT_EQ(caffe2::dynamic_cast_if_rtti(net.get()), net.get()); - - // There may be some default observers - const size_t prev_num = net.get()->NumObservers(); - const int num_tests = 100; - vector::Observer*> observers; - for (int i = 0; i < num_tests; ++i) { - unique_ptr> net_ob = - make_unique>(net.get()); - observers.emplace_back(net.get()->AttachObserver(std::move(net_ob))); - } - - net.get()->Run(); - - for (const auto& observer : observers) { - net.get()->DetachObserver(observer); - } - - EXPECT_EQ(net.get()->NumObservers(), prev_num); -} -#endif -} // namespace caffe2 diff --git a/caffe2/core/operator_gpu_test.cc b/caffe2/core/operator_gpu_test.cc deleted file mode 100644 index 80c58b7a3c75..000000000000 --- a/caffe2/core/operator_gpu_test.cc +++ /dev/null @@ -1,63 +0,0 @@ -#include - -#include -#include "caffe2/core/common_gpu.h" -#include "caffe2/core/operator.h" - -namespace caffe2 { - -class JustTest : public OperatorBase { - public: - using OperatorBase::OperatorBase; - bool Run(int /* unused */ /*stream_id*/) override { - return true; - } - virtual std::string type() { - return "BASE"; - } -}; - -class JustTestCUDA : public JustTest { - public: - using JustTest::JustTest; - bool Run(int /* unused */ /*stream_id*/) override { - return true; - } - std::string type() override { - return "CUDA"; - } -}; - -class JustTestCUDNN : public JustTest { - public: - using JustTest::JustTest; - bool Run(int /* unused */ /*stream_id*/) override { - return true; - } - std::string type() override { - return "CUDNN"; - } -}; - -OPERATOR_SCHEMA(JustTest).NumInputs(0, 1).NumOutputs(0, 1); -REGISTER_CUDA_OPERATOR(JustTest, JustTestCUDA); -REGISTER_CUDNN_OPERATOR(JustTest, JustTestCUDNN); - -TEST(EnginePrefTest, GPUDeviceDefaultPreferredEngines) { - if (!HasCudaGPU()) - return; - OperatorDef op_def; - Workspace ws; - op_def.mutable_device_option()->set_device_type(PROTO_CUDA); - op_def.set_type("JustTest"); - - { - const auto op = CreateOperator(op_def, &ws); - EXPECT_NE(nullptr, op.get()); - // CUDNN should be taken as it's in the default global preferred engines - // list - EXPECT_EQ(static_cast(op.get())->type(), "CUDNN"); - } -} - -} // namespace caffe2 diff --git a/caffe2/core/operator_schema_test.cc b/caffe2/core/operator_schema_test.cc deleted file mode 100644 index 5e54cf7d37dd..000000000000 --- a/caffe2/core/operator_schema_test.cc +++ /dev/null @@ -1,279 +0,0 @@ -#include "caffe2/core/logging.h" -#include "caffe2/core/operator.h" -#include "caffe2/core/operator_schema.h" -#include "caffe2/utils/proto_utils.h" - -#include - -namespace caffe2 { - -OPERATOR_SCHEMA(OpSchemaTestOp) - .NumInputs(1).NumOutputs(1) - .SetDoc(R"DOC(Test Documentation)DOC") - .Input(0, "in0", "dummy input.") - .Output(0, "out0", "dummy output."); - -TEST(OperatorSchemaTest, BasicSchema) { - const OpSchema* schema = OpSchemaRegistry::Schema("OpSchemaTestOp"); -#ifdef CAFFE2_NO_OPERATOR_SCHEMA - EXPECT_TRUE(schema == nullptr); - return; -#endif - EXPECT_TRUE(schema != nullptr); - EXPECT_TRUE(schema->doc() != nullptr); - OperatorDef def1 = CreateOperatorDef( - "OpSchemaTestOp", "", - vector{"in"}, vector{"out"}); - EXPECT_TRUE(schema->Verify(def1)); - OperatorDef def2 = CreateOperatorDef( - "OpSchemaTestOp", "", - vector{"in1", "in2"}, vector{"out"}); - EXPECT_FALSE(schema->Verify(def2)); - OperatorDef def3 = CreateOperatorDef( - "OpSchemaTestOp", "", - vector{"in"}, vector{"out1", "out2"}); - EXPECT_FALSE(schema->Verify(def3)); -} - -OPERATOR_SCHEMA(OpSchemaSpecifiedInputOutputOp) - .NumInputs({2, 4}).NumOutputs({1, 3}); - -TEST(OperatorSchemaTest, SpecifiedInputOutput) { - const OpSchema* schema - = OpSchemaRegistry::Schema("OpSchemaSpecifiedInputOutputOp"); -#ifdef CAFFE2_NO_OPERATOR_SCHEMA - EXPECT_TRUE(schema == nullptr); - return; -#endif - EXPECT_TRUE(schema != nullptr); - OperatorDef def1 = CreateOperatorDef( - "OpSchemaSpecifiedInputOutputOp", "", - vector{"in"}, vector{"out"}); - EXPECT_FALSE(schema->Verify(def1)); - OperatorDef def2 = CreateOperatorDef( - "OpSchemaSpecifiedInputOutputOp", "", - vector{"in1", "in2"}, vector{"out"}); - EXPECT_TRUE(schema->Verify(def2)); - OperatorDef def3 = CreateOperatorDef( - "OpSchemaSpecifiedInputOutputOp", "", - vector{"in1", "in2"}, vector{"out1", "out2"}); - EXPECT_FALSE(schema->Verify(def3)); -} - -OPERATOR_SCHEMA(OpSchemaInputOutputRelationOp) - .NumInputsOutputs([](int in, int out) { - return out == in || out == in * 2; - }); - -TEST(OperatorSchemaTest, InputOutputRelation) { - const OpSchema* schema - = OpSchemaRegistry::Schema("OpSchemaInputOutputRelationOp"); -#ifdef CAFFE2_NO_OPERATOR_SCHEMA - EXPECT_TRUE(schema == nullptr); - return; -#endif - EXPECT_TRUE(schema != nullptr); - OperatorDef def1 = CreateOperatorDef( - "OpSchemaInputOutputRelationOp", "", - vector{"in"}, vector{"out"}); - EXPECT_TRUE(schema->Verify(def1)); - OperatorDef def2 = CreateOperatorDef( - "OpSchemaInputOutputRelationOp", "", - vector{"in"}, vector{"out1", "out2"}); - EXPECT_TRUE(schema->Verify(def2)); - OperatorDef def3 = CreateOperatorDef( - "OpSchemaInputOutputRelationOp", "", - vector{"in1", "in2", "in3"}, vector{"out1", "out2"}); - EXPECT_FALSE(schema->Verify(def3)); -} - -OPERATOR_SCHEMA(OpSchemaSameInputOutputOp) - .SameNumberOfOutput(); - -TEST(OperatorSchemaTest, SameInputOutput) { - const OpSchema* schema = - OpSchemaRegistry::Schema("OpSchemaSameInputOutputOp"); -#ifdef CAFFE2_NO_OPERATOR_SCHEMA - EXPECT_TRUE(schema == nullptr); - return; -#endif - OperatorDef def1 = CreateOperatorDef( - "OpSchemaSameInputOutputOp", "", - vector{"in"}, vector{"out"}); - EXPECT_TRUE(schema->Verify(def1)); - OperatorDef def2 = CreateOperatorDef( - "OpSchemaSameInputOutputOp", "", - vector{"in1", "in2"}, vector{"out1", "out2"}); - EXPECT_TRUE(schema->Verify(def2)); - OperatorDef def3 = CreateOperatorDef( - "OpSchemaSameInputOutputOp", "", - vector{"in1", "in2"}, vector{"out1", "out2", "out3"}); - EXPECT_FALSE(schema->Verify(def3)); -} - -OPERATOR_SCHEMA(OpSchemaCalculateOutputOp) - .NumInputs(1, 5).NumOutputs(2, 6) - .OutputCalculator([](int n) { return n + 1; }); - -TEST(OperatorSchemaTest, CalculateOutput) { - const OpSchema* schema = - OpSchemaRegistry::Schema("OpSchemaCalculateOutputOp"); -#ifdef CAFFE2_NO_OPERATOR_SCHEMA - EXPECT_TRUE(schema == nullptr); - return; -#endif - OperatorDef def1 = CreateOperatorDef( - "OpSchemaCalculateOutputOp", "", - vector{"in"}, vector{"out"}); - EXPECT_FALSE(schema->Verify(def1)); - OperatorDef def2 = CreateOperatorDef( - "OpSchemaCalculateOutputOp", "", - vector{"in1", "in2"}, vector{"out1", "out2"}); - EXPECT_FALSE(schema->Verify(def2)); - OperatorDef def3 = CreateOperatorDef( - "OpSchemaCalculateOutputOp", "", - vector{"in1", "in2"}, vector{"out1", "out2", "out3"}); - EXPECT_TRUE(schema->Verify(def3)); -} - -OPERATOR_SCHEMA(OpSchemaInplace) - .NumInputs(2).NumOutputs(2) - .AllowInplace({{0, 0}}) - .EnforceInplace({{1, 1}}); - -TEST(OperatorSchemaTest, Inplace) { - const OpSchema* schema = - OpSchemaRegistry::Schema("OpSchemaInplace"); -#ifdef CAFFE2_NO_OPERATOR_SCHEMA - EXPECT_TRUE(schema == nullptr); - return; -#endif - OperatorDef def1 = CreateOperatorDef( - "OpSchemaInplace", "", - vector{"in1", "in2"}, vector{"out1", "in2"}); - EXPECT_TRUE(schema->Verify(def1)); - OperatorDef def2 = CreateOperatorDef( - "OpSchemaInplace", "", - vector{"in1", "in2"}, vector{"in1", "in2"}); - EXPECT_TRUE(schema->Verify(def2)); - OperatorDef def3 = CreateOperatorDef( - "OpSchemaInplace", "", - vector{"in1", "in2"}, vector{"in1", "out2"}); - EXPECT_FALSE(schema->Verify(def3)); - OperatorDef def4 = CreateOperatorDef( - "OpSchemaInplace", "", - vector{"in1", "in2"}, vector{"out1", "out2"}); - EXPECT_FALSE(schema->Verify(def4)); -} - -OPERATOR_SCHEMA(OpSchemaSameInputOutputTensorInference).IdenticalTypeAndShape(); - -TEST(OperatorSchemaTest, TensorInferenceIdentical) { - const OpSchema* schema = - OpSchemaRegistry::Schema("OpSchemaSameInputOutputTensorInference"); -#ifdef CAFFE2_NO_OPERATOR_SCHEMA - EXPECT_TRUE(schema == nullptr); - return; -#endif - OperatorDef def = CreateOperatorDef( - "OpSchemaSameInputOutputTensorInference", - "", - vector{"in"}, - vector{"out"}); - vector shapes(1); - shapes[0].set_data_type(TensorProto::FLOAT); - shapes[0].add_dims(1); - shapes[0].add_dims(2); - shapes[0].add_dims(3); - vector out = schema->InferTensor(def, shapes); - EXPECT_EQ(out.size(), 1); - EXPECT_EQ(out[0].SerializeAsString(), shapes[0].SerializeAsString()); -} - -OPERATOR_SCHEMA(OpSchemaArbitraryTensorInference) - .TensorInferenceFunction( - [](const OperatorDef&, const vector&) { - vector shapes(1); - shapes[0].set_data_type(TensorProto::FLOAT); - shapes[0].add_dims(1701); - return shapes; - }); - -TEST(OperatorSchemaTest, TensorInferenceArbitrary) { - const OpSchema* schema = - OpSchemaRegistry::Schema("OpSchemaArbitraryTensorInference"); -#ifdef CAFFE2_NO_OPERATOR_SCHEMA - EXPECT_TRUE(schema == nullptr); - return; -#endif - OperatorDef def = CreateOperatorDef( - "OpSchemaArbitraryTensorInference", - "", - vector{"in"}, - vector{"out"}); - vector out = schema->InferTensor(def, vector()); - EXPECT_EQ(out.size(), 1); - EXPECT_EQ(out[0].data_type(), TensorProto::FLOAT); - EXPECT_EQ(out[0].dims_size(), 1); - EXPECT_EQ(out[0].dims(0), 1701); -} - -TEST(OperatorSchemaTest, TestCastSchema) { - // This tests a use case of the schema: the Cast op takes in the def and - // deduces the - // schema from the "to" argument. - const OpSchema* schema = OpSchemaRegistry::Schema("Cast"); -#ifdef CAFFE2_NO_OPERATOR_SCHEMA - EXPECT_TRUE(schema == nullptr); - return; -#endif - if (!schema) { - // Compiled without the Cast op. - return; - } - OperatorDef def = CreateOperatorDef( - "Cast", - "", - vector{"in"}, - vector{"out"}, - vector{MakeArgument("to", TensorProto::UINT8)}); - vector out = schema->InferTensor(def, vector(1)); - EXPECT_EQ(out.size(), 1); - // Data type should be inferred. - EXPECT_EQ(out[0].data_type(), TensorProto::UINT8); - // Dim should not be set (same as input); - EXPECT_EQ(out[0].dims_size(), 0); -} - -OPERATOR_SCHEMA(OpSchemaCostInference) - .NumInputs(2) - .NumOutputs(2) - .CostInferenceFunction([](const OperatorDef& /*def*/, - const vector& inputs) { - struct OpSchema::Cost c; - c.flops = 2 * inputs[0].dims(0) * inputs[0].dims(1) * inputs[1].dims(1); - return c; - }); - -TEST(OperatorSchemaTest, TestCostInference) { - const OpSchema* schema = OpSchemaRegistry::Schema("OpSchemaCostInference"); -#ifdef CAFFE2_NO_OPERATOR_SCHEMA - EXPECT_TRUE(schema == nullptr); - return; -#endif - if (!schema) { - return; - } - OperatorDef def = CreateOperatorDef( - "OpSchemaCostInference", "", vector{"in"}, vector{"out"}); - vector shapes(2); - shapes[0].set_data_type(TensorProto::FLOAT); - shapes[0].add_dims(10); - shapes[0].add_dims(10); - shapes[1].set_data_type(TensorProto::FLOAT); - shapes[1].add_dims(10); - shapes[1].add_dims(10); - EXPECT_EQ(2000, schema->InferCost(def, shapes).flops); -} - -} // namespace caffe2 diff --git a/caffe2/core/operator_test.cc b/caffe2/core/operator_test.cc deleted file mode 100644 index afebacc71dc3..000000000000 --- a/caffe2/core/operator_test.cc +++ /dev/null @@ -1,634 +0,0 @@ -#include - -#include "caffe2/core/net.h" -#include "caffe2/core/operator.h" -#include - -namespace caffe2 { - -// Since we instantiate this on CPU and GPU (but don't want a -// CUDAContext dependency, we use OperatorBase. In general, you only -// want to inherit from Operator in your code. -class JustTest : public OperatorBase { - public: - using OperatorBase::OperatorBase; - bool Run(int /* unused */ /*stream_id*/) override { - return true; - } - virtual string type() { - return "base"; - } -}; - -class JustTestAndNeverConstructs : public JustTest { - public: - JustTestAndNeverConstructs(const OperatorDef& def, Workspace* ws) - : JustTest(def, ws) { - throw UnsupportedOperatorFeature("I just don't construct."); - } - bool Run(int /* unused */ /*stream_id*/) override { - return true; - } - string type() override { - return "FOO"; - } -}; - -class JustTestAndDoesConstruct : public JustTest { - public: - using JustTest::JustTest; - bool Run(int /* unused */ /*stream_id*/) override { - return true; - } - string type() override { - return "BAR"; - } -}; - -class JustTestWithSomeOutput : public JustTest { - public: - using JustTest::JustTest; - bool Run(int /* unused */ /*stream_id*/) override { - *OperatorBase::Output(0) = 5; - return true; - } - string type() override { - return "SETTING_SOME_OUTPUT"; - } -}; - -OPERATOR_SCHEMA(JustTest).NumInputs(0, 1).NumOutputs(0, 1); -OPERATOR_SCHEMA(JustTestCPUOnly).NumInputs(0, 1).NumOutputs(0, 1); -OPERATOR_SCHEMA(JustTestWithSomeOutput); - -REGISTER_CPU_OPERATOR(JustTest, JustTest); -REGISTER_CPU_OPERATOR(JustTestCPUOnly, JustTest); -REGISTER_CPU_OPERATOR_WITH_ENGINE(JustTest, FOO, JustTestAndNeverConstructs); -REGISTER_CPU_OPERATOR_WITH_ENGINE(JustTest, BAR, JustTestAndDoesConstruct); -REGISTER_CPU_OPERATOR_WITH_ENGINE(JustTest, BAZ, JustTestAndDoesConstruct); -REGISTER_CUDA_OPERATOR(JustTest, JustTest); -REGISTER_CPU_OPERATOR(JustTestWithSomeOutput, JustTestWithSomeOutput); - -TEST(OperatorTest, DeviceTypeRegistryWorks) { - EXPECT_EQ(gDeviceTypeRegistry()->count(CPU), 1); -} - -TEST(OperatorTest, RegistryWorks) { - OperatorDef op_def; - Workspace ws; - op_def.set_type("JustTest"); - unique_ptr op = CreateOperator(op_def, &ws); - EXPECT_NE(nullptr, op.get()); - // After introducing events, CUDA operator creation has to have CUDA compiled - // as it needs to instantiate an Event object with CUDAContext. Thus we will - // guard this test below. - if (HasCudaRuntime()) { - op_def.mutable_device_option()->set_device_type(PROTO_CUDA); - op = CreateOperator(op_def, &ws); - EXPECT_NE(nullptr, op.get()); - } -} - -TEST(OperatorTest, RegistryWrongDevice) { - OperatorDef op_def; - Workspace ws; - op_def.set_type("JustTypeCPUOnly"); - op_def.mutable_device_option()->set_device_type(PROTO_CUDA); - try { - CreateOperator(op_def, &ws); - LOG(FATAL) << "No exception was thrown"; - } catch (const std::exception& e) { - LOG(INFO) << "Exception " << e.what(); - } -} - -TEST(OperatorTest, ExceptionWorks) { - OperatorDef op_def; - Workspace ws; - op_def.set_type("ThrowException"); - unique_ptr op = CreateOperator(op_def, &ws); - // Note: we do not do ASSERT_THROW in order to print out - // the error message for inspection. - try { - op->Run(); - // This should not happen - exception should throw above. - LOG(FATAL) << "This should not happen."; - } catch (const EnforceNotMet& err) { - LOG(INFO) << err.what(); - } - try { - op->RunAsync(); - // This should not happen - exception should throw above. - LOG(FATAL) << "This should not happen."; - } catch (const EnforceNotMet& err) { - LOG(INFO) << err.what(); - } -} - -TEST(OperatorTest, FallbackIfEngineDoesNotBuild) { - OperatorDef op_def; - Workspace ws; - op_def.set_type("JustTest"); - op_def.set_engine("FOO"); - unique_ptr op = CreateOperator(op_def, &ws); - EXPECT_NE(nullptr, op.get()); - EXPECT_EQ(static_cast(op.get())->type(), "base"); -} - -TEST(OperatorTest, MultipleEngineChoices) { - OperatorDef op_def; - Workspace ws; - op_def.set_type("JustTest"); - op_def.set_engine("FOO,BAR"); - unique_ptr op = CreateOperator(op_def, &ws); - EXPECT_NE(nullptr, op.get()); - EXPECT_EQ(static_cast(op.get())->type(), "BAR"); -} - -TEST(OperatorTest, CannotUseUninitializedBlob) { - Workspace ws; - OperatorDef op_def; - op_def.set_name("JustTest0"); - op_def.set_type("JustTest"); - op_def.add_input("input"); - op_def.add_output("output"); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - ASSERT_THROW(CreateOperator(op_def, &ws), EnforceNotMet); -} - -TEST(OperatorTest, TestParameterAccess) { - OperatorDef op_def; - Workspace ws; - op_def.set_name("JustTest0"); - op_def.set_type("JustTest"); - op_def.add_input("input"); - op_def.add_output("output"); - AddArgument("arg0", 0.1, &op_def); - AddArgument>("arg1", vector{1, 2}, &op_def); - AddArgument("arg2", "argstring", &op_def); - EXPECT_NE(ws.CreateBlob("input"), nullptr); - OperatorBase op(op_def, &ws); - EXPECT_FLOAT_EQ(op.GetSingleArgument("arg0", 0.0), 0.1); - vector i = op.GetRepeatedArgument("arg1"); - EXPECT_EQ(i.size(), 2); - EXPECT_EQ(i[0], 1); - EXPECT_EQ(i[1], 2); - EXPECT_EQ(op.GetSingleArgument("arg2", "default"), "argstring"); - auto default1 = op.GetRepeatedArgument("arg3", {2, 3}); - EXPECT_EQ(default1.size(), 2); - EXPECT_EQ(default1[0], 2); - EXPECT_EQ(default1[1], 3); - auto default2 = op.GetRepeatedArgument("arg4"); - EXPECT_EQ(default2.size(), 0); -} - -TEST(OperatorTest, CannotAccessParameterWithWrongType) { - OperatorDef op_def; - Workspace ws; - op_def.set_name("JustTest0"); - op_def.set_type("JustTest"); - op_def.add_input("input"); - op_def.add_output("output"); - AddArgument("arg0", 0.1f, &op_def); - EXPECT_NE(ws.CreateBlob("input"), nullptr); - OperatorBase op(op_def, &ws); - EXPECT_FLOAT_EQ(op.GetSingleArgument("arg0", 0.0), 0.1); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - ASSERT_THROW(op.GetSingleArgument("arg0", 0), EnforceNotMet); -} - -#if GTEST_HAS_DEATH_TEST -TEST(OperatorDeathTest, DISABLED_CannotAccessRepeatedParameterWithWrongType) { - OperatorDef op_def; - Workspace ws; - op_def.set_name("JustTest0"); - op_def.set_type("JustTest"); - op_def.add_input("input"); - op_def.add_output("output"); - AddArgument>("arg0", vector{0.1f}, &op_def); - EXPECT_NE(ws.CreateBlob("input"), nullptr); - OperatorBase op(op_def, &ws); - auto args = op.GetRepeatedArgument("arg0"); - EXPECT_EQ(args.size(), 1); - EXPECT_FLOAT_EQ(args[0], 0.1f); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - EXPECT_DEATH(op.GetRepeatedArgument("arg0"), - "Argument does not have the right field: expected ints"); -} -#endif - -TEST(OperatorTest, TestDefaultValue) { - OperatorDef op_def; - Workspace ws; - OperatorBase op(op_def, &ws); - EXPECT_FLOAT_EQ(op.GetSingleArgument("arg-nonexisting", 0.5f), 0.5f); -} - -TEST(OperatorTest, TestSetUp) { - Workspace ws; - OperatorDef op_def; - op_def.set_name("JustTest0"); - op_def.set_type("JustTest"); - op_def.add_input("input"); - op_def.add_output("output"); - EXPECT_NE(nullptr, ws.CreateBlob("input")); - unique_ptr op(CreateOperator(op_def, &ws)); - EXPECT_NE(nullptr, op.get()); - EXPECT_TRUE(ws.HasBlob("output")); -} - -TEST(OperatorTest, TestSetUpInputOutputCount) { - Workspace ws; - OperatorDef op_def; - op_def.set_name("JustTest0"); - op_def.set_type("JustTest"); - op_def.add_input("input"); - op_def.add_input("input2"); - op_def.add_output("output"); - EXPECT_NE(nullptr, ws.CreateBlob("input")); - EXPECT_NE(nullptr, ws.CreateBlob("input2")); -#ifndef CAFFE2_NO_OPERATOR_SCHEMA - // JustTest will only accept one single input. - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - ASSERT_ANY_THROW(CreateOperator(op_def, &ws)); -#endif - - op_def.clear_input(); - op_def.add_input("input"); - op_def.add_output("output2"); -#ifndef CAFFE2_NO_OPERATOR_SCHEMA - // JustTest will only produce one single output. - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - ASSERT_ANY_THROW(CreateOperator(op_def, &ws)); -#endif -} - -TEST(OperatorTest, TestOutputValues) { - NetDef net_def; - net_def.set_name("NetForTest"); - OperatorDef op_def; - Workspace ws; - op_def.set_name("JustTest1"); - op_def.set_type("JustTestWithSomeOutput"); - op_def.add_output("output"); - // JustTest will only produce one single output. - net_def.add_op()->CopyFrom(op_def); - unique_ptr net(CreateNet(net_def, &ws)); - EXPECT_TRUE(net->Run()); - EXPECT_TRUE(ws.HasBlob("output")); - EXPECT_EQ(ws.GetBlob("output")->Get(), 5); -} - -NetDef GetNetDefForTest() { - NetDef net_def; - OperatorDef op_def; - net_def.set_name("NetForTest"); - op_def.set_name("JustTest0"); - op_def.set_type("JustTest"); - op_def.add_input("input"); - op_def.add_output("hidden"); - net_def.add_op()->CopyFrom(op_def); - op_def.set_name("JustTest1"); - op_def.set_input(0, "hidden"); - op_def.set_output(0, "output"); - net_def.add_op()->CopyFrom(op_def); - return net_def; -} - -TEST(NetTest, TestScaffoldingSimpleNet) { - NetDef net_def = GetNetDefForTest(); - net_def.set_type("simple"); - Workspace ws; - EXPECT_NE(nullptr, ws.CreateBlob("input")); - unique_ptr net(CreateNet(net_def, &ws)); - EXPECT_NE(nullptr, net.get()); - EXPECT_TRUE(ws.HasBlob("input")); - EXPECT_TRUE(ws.HasBlob("hidden")); - EXPECT_TRUE(ws.HasBlob("output")); - EXPECT_TRUE(net->Run()); -} - -TEST(NetTest, TestScaffoldingDAGNet) { - NetDef net_def = GetNetDefForTest(); - net_def.set_type("dag"); - net_def.set_num_workers(1); - Workspace ws; - EXPECT_NE(nullptr, ws.CreateBlob("input")); - unique_ptr net(CreateNet(net_def, &ws)); - EXPECT_NE(nullptr, net.get()); - EXPECT_TRUE(ws.HasBlob("input")); - EXPECT_TRUE(ws.HasBlob("hidden")); - EXPECT_TRUE(ws.HasBlob("output")); - EXPECT_TRUE(net->Run()); -} - -class FooGradientOp : public JustTest { - public: - using JustTest::JustTest; - string type() override { - return "FooGradient"; - } -}; - -class FooGradientDummyEngineOp : public JustTest { - public: - using JustTest::JustTest; - string type() override { - return "FooGradientDummyEngine"; - } -}; - -class GetFooGradient : public GradientMakerBase { - using GradientMakerBase::GradientMakerBase; - vector GetGradientDefs() override { - return vector{ - CreateOperatorDef( - "FooGradient", "", - std::vector{GO(0)}, - std::vector{GI(0)})}; - } -}; - -GRADIENT_OPERATOR_SCHEMA(FooGradient).NumInputs(1).NumOutputs(1); -REGISTER_CPU_GRADIENT_OPERATOR(FooGradient, FooGradientOp) -REGISTER_CPU_GRADIENT_OPERATOR_WITH_ENGINE( - FooGradient, - DUMMY_ENGINE, - FooGradientDummyEngineOp) -REGISTER_GRADIENT(Foo, GetFooGradient); - -TEST(OperatorGradientRegistryTest, GradientSimple) { - Argument arg = MakeArgument("arg", 1); - DeviceOption option; - option.set_device_type(PROTO_CPU); - OperatorDef def = CreateOperatorDef( - "Foo", "", std::vector{"in"}, std::vector{"out"}, - std::vector{arg}, option, "DUMMY_ENGINE"); - vector g_output(1); - g_output[0].dense_ = "out_grad"; - GradientOpsMeta meta = GetGradientForOp(def, g_output); - // Check the names, input and output. - EXPECT_EQ(meta.ops_.size(), 1); - const OperatorDef& grad_op_def = meta.ops_[0]; - EXPECT_EQ(grad_op_def.type(), "FooGradient"); - EXPECT_EQ(grad_op_def.name(), ""); - EXPECT_EQ(grad_op_def.input_size(), 1); - EXPECT_EQ(grad_op_def.output_size(), 1); - EXPECT_EQ(grad_op_def.input(0), "out_grad"); - EXPECT_EQ(grad_op_def.output(0), "in_grad"); - // Checks the engine, device option and arguments. - EXPECT_EQ(grad_op_def.engine(), "DUMMY_ENGINE"); - EXPECT_EQ(grad_op_def.device_option().device_type(), PROTO_CPU); - EXPECT_EQ(grad_op_def.arg_size(), 1); - EXPECT_EQ( - grad_op_def.arg(0).SerializeAsString(), - MakeArgument("arg", 1).SerializeAsString()); - // Checks the gradient name for input. - EXPECT_EQ(meta.g_input_.size(), 1); - EXPECT_TRUE(meta.g_input_[0].IsDense()); - EXPECT_EQ(meta.g_input_[0].dense_, "in_grad"); - - Workspace ws; - EXPECT_NE(ws.CreateBlob("out_grad"), nullptr); - unique_ptr grad_op = CreateOperator(grad_op_def, &ws); - EXPECT_NE(nullptr, grad_op.get()); - EXPECT_EQ( - static_cast(grad_op.get())->type(), "FooGradientDummyEngine"); -} - -TEST(EnginePrefTest, PerOpEnginePref) { - OperatorDef op_def; - Workspace ws; - op_def.set_type("JustTest"); - - SetPerOpEnginePref({{CPU, {{"JustTest", {"BAR"}}}}}); - { - const auto op = CreateOperator(op_def, &ws); - EXPECT_NE(nullptr, op.get()); - EXPECT_EQ(static_cast(op.get())->type(), "BAR"); - } - // clear - SetPerOpEnginePref({}); - - // Invalid operator type - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - ASSERT_THROW( - SetPerOpEnginePref({{CPU, {{"NO_EXIST", {"BAR"}}}}}), EnforceNotMet); -} - -TEST(EnginePrefTest, GlobalEnginePref) { - OperatorDef op_def; - Workspace ws; - op_def.set_type("JustTest"); - - SetGlobalEnginePref({{CPU, {"FOO", "BAR"}}}); - { - const auto op = CreateOperator(op_def, &ws); - EXPECT_NE(nullptr, op.get()); - EXPECT_EQ(static_cast(op.get())->type(), "BAR"); - } - // clear - SetGlobalEnginePref({}); - - SetGlobalEnginePref({{CPU, {"FOO"}}}); - { - const auto op = CreateOperator(op_def, &ws); - EXPECT_NE(nullptr, op.get()); - EXPECT_EQ(static_cast(op.get())->type(), "base"); - } - // clear - SetGlobalEnginePref({}); - - // Invalid device type - // This check is no longer necessary with the enum class - // ASSERT_THROW(SetGlobalEnginePref({{8888, {"FOO"}}}), EnforceNotMet); -} - -TEST(EnginePrefTest, GlobalEnginePrefAndPerOpEnginePref) { - OperatorDef op_def; - Workspace ws; - op_def.set_type("JustTest"); - - SetPerOpEnginePref({{CPU, {{"JustTest", {"BAR"}}}}}); - SetGlobalEnginePref({{CPU, {"BAZ"}}}); - { - const auto op = CreateOperator(op_def, &ws); - EXPECT_NE(nullptr, op.get()); - // per op pref takes precedence - EXPECT_EQ(static_cast(op.get())->type(), "BAR"); - } - // clear - SetPerOpEnginePref({}); - SetGlobalEnginePref({}); -} - -TEST(EnginePrefTest, GlobalEnginePrefAndPerOpEnginePrefAndOpDef) { - OperatorDef op_def; - Workspace ws; - op_def.set_type("JustTest"); - op_def.set_engine("BAR"); - - SetPerOpEnginePref({{CPU, {{"JustTest", {"BAZ"}}}}}); - SetGlobalEnginePref({{CPU, {"BAZ"}}}); - { - const auto op = CreateOperator(op_def, &ws); - EXPECT_NE(nullptr, op.get()); - // operator_def takes precedence - EXPECT_EQ(static_cast(op.get())->type(), "BAR"); - } - // clear - SetPerOpEnginePref({}); - SetGlobalEnginePref({}); -} - -TEST(EnginePrefTest, SetOpEnginePref) { - OperatorDef op_def; - Workspace ws; - op_def.set_type("JustTest"); - - SetPerOpEnginePref({{CPU, {{"JustTest", {"BAZ"}}}}}); - SetOpEnginePref("JustTest", {{CPU, {"BAR"}}}); - { - const auto op = CreateOperator(op_def, &ws); - EXPECT_NE(nullptr, op.get()); - // operator_def takes precedence - EXPECT_EQ(static_cast(op.get())->type(), "BAR"); - } - // clear - SetPerOpEnginePref({}); - SetGlobalEnginePref({}); -} - -TEST(EnginePrefTest, SetDefaultEngine) { - OperatorDef op_def; - Workspace ws; - op_def.set_type("JustTest"); - - SetPerOpEnginePref({{CPU, {{"JustTest", {"DEFAULT"}}}}}); - SetGlobalEnginePref({{CPU, {"BAR"}}}); - { - const auto op = CreateOperator(op_def, &ws); - EXPECT_NE(nullptr, op.get()); - // operator_def takes precedence - EXPECT_EQ(static_cast(op.get())->type(), "base"); - } - // clear - SetPerOpEnginePref({}); - SetGlobalEnginePref({}); -} - -class JustTestWithRequiredArg : public JustTest { - public: - using JustTest::JustTest; - bool Run(int /* unused */ /*stream_id*/) override { - return true; - } - string type() override { - return "JustTestWithRequiredArg"; - } -}; - -REGISTER_CPU_OPERATOR(JustTestWithRequiredArg, JustTestWithRequiredArg); -OPERATOR_SCHEMA(JustTestWithRequiredArg) - .NumInputs(0, 1) - .NumOutputs(0, 1) - .Arg("test_arg", "this arg is required", true); - -TEST(RequiredArg, Basic) { - OperatorDef op_def; - Workspace ws; - op_def.set_type("JustTestWithRequiredArg"); - - { - try { - CreateOperator(op_def, &ws); - LOG(FATAL) << "No exception was thrown"; - } catch (const std::exception& e) { - LOG(INFO) << "Exception thrown (expected): " << e.what(); - } - } - - { - op_def.add_arg()->CopyFrom(MakeArgument("test_arg", 1)); - const auto op = CreateOperator(op_def, &ws); - EXPECT_NE(nullptr, op.get()); - EXPECT_EQ( - static_cast(op.get())->type(), "JustTestWithRequiredArg"); - } -} - -class JustTestWithStandardIsTestArg : public JustTest { - public: - using JustTest::JustTest; - bool Run(int /* unused */ /*stream_id*/) override { - return true; - } - string type() override { - return "JustTestWithStandardIsTestArg"; - } -}; - -REGISTER_CPU_OPERATOR( - JustTestWithStandardIsTestArg, - JustTestWithStandardIsTestArg); -OPERATOR_SCHEMA(JustTestWithStandardIsTestArg) - .NumInputs(0, 1) - .NumOutputs(0, 1) - .ArgIsTest("this is_test arg is required"); - -TEST(IsTestArg, standard) { - OperatorDef op_def; - Workspace ws; - op_def.set_type("JustTestWithStandardIsTestArg"); - - { - try { - CreateOperator(op_def, &ws); - LOG(FATAL) << "No exception was thrown"; - } catch (const std::exception& e) { - LOG(INFO) << "Exception thrown (expected): " << e.what(); - } - } - - { - op_def.add_arg()->CopyFrom(MakeArgument(OpSchema::Arg_IsTest, 1)); - const auto op = CreateOperator(op_def, &ws); - EXPECT_NE(nullptr, op.get()); - EXPECT_EQ( - static_cast(op.get())->type(), - "JustTestWithStandardIsTestArg"); - } -} - -class JustTestWithNonStandardIsTestArg : public JustTest { - public: - using JustTest::JustTest; - bool Run(int /* unused */ /*stream_id*/) override { - return true; - } - string type() override { - return "JustTestWithNonStandardIsTestArg"; - } -}; - -REGISTER_CPU_OPERATOR( - JustTestWithNonStandardIsTestArg, - JustTestWithNonStandardIsTestArg); -OPERATOR_SCHEMA(JustTestWithNonStandardIsTestArg) - .NumInputs(0, 1) - .NumOutputs(0, 1) - .Arg(OpSchema::Arg_IsTest, "this is_test arg is not required"); - -TEST(IsTestArg, non_standard) { - OperatorDef op_def; - Workspace ws; - op_def.set_type("JustTestWithNonStandardIsTestArg"); - - const auto op = CreateOperator(op_def, &ws); - EXPECT_NE(nullptr, op.get()); - EXPECT_EQ( - static_cast(op.get())->type(), - "JustTestWithNonStandardIsTestArg"); -} - -} // namespace caffe2 diff --git a/caffe2/core/parallel_net_test.cc b/caffe2/core/parallel_net_test.cc deleted file mode 100644 index 7b17faba3150..000000000000 --- a/caffe2/core/parallel_net_test.cc +++ /dev/null @@ -1,322 +0,0 @@ -#include // NOLINT -#include // NOLINT - -#include -#include "caffe2/core/net.h" -#include "caffe2/core/operator.h" - -namespace caffe2 { - -// When measuring time, we relax the measured time by +- 40ms. -#ifndef _WIN32 -const int kTimeThreshold = 40; -#else -// Even more so on Windows -const int kTimeThreshold = 50; -#endif - -// SleepOp basically sleeps for a given number of seconds. -// We allow arbitrary inputs and at most one output so that we can -// test scaffolding of networks. If the output is 1, it will be filled with -// vector with two elements: start time and end time. -class SleepOp final : public Operator { - public: - SleepOp(const OperatorDef& operator_def, Workspace* ws) - : Operator(operator_def, ws), - ms_(OperatorBase::GetSingleArgument("ms", 1000)) { - TORCH_DCHECK_GT(ms_, 0); - TORCH_DCHECK_LT(ms_, 3600 * 1000) << "Really? This long?"; - } - - bool RunOnDevice() override { - auto start = std::chrono::high_resolution_clock::now(); - std::this_thread::sleep_for(std::chrono::milliseconds(ms_)); - auto end = std::chrono::high_resolution_clock::now(); - if (OperatorBase::OutputSize()) { - vector* output = OperatorBase::Output>(0); - output->resize(2); - (*output)[0] = start.time_since_epoch().count(); - (*output)[1] = end.time_since_epoch().count(); - } - return true; - } - - private: - int ms_; -}; - -OPERATOR_SCHEMA(Sleep).NumInputs(0, INT_MAX).NumOutputs(0, 1); - -REGISTER_CPU_OPERATOR(Sleep, SleepOp); -REGISTER_CUDA_OPERATOR(Sleep, SleepOp); - -// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) -const char kSleepNetDefString[] = - " name: \"sleepnet\"" - " type: \"dag\"" - " num_workers: 2" - " op {" - " output: \"sleep1\"" - " name: \"sleep1\"" - " type: \"Sleep\"" - " arg {" - " name: \"ms\"" - " i: 100" - " }" - " }" - " op {" - " input: \"sleep1\"" - " output: \"sleep2\"" - " name: \"sleep2\"" - " type: \"Sleep\"" - " arg {" - " name: \"ms\"" - " i: 100" - " }" - " }" - " op {" - " output: \"sleep3\"" - " name: \"sleep3\"" - " type: \"Sleep\"" - " arg {" - " name: \"ms\"" - " i: 150" - " }" - " }"; - -namespace { -// Run a network and get its duration in milliseconds. -int RunNetAndGetDuration(const string& net_def_str, const string& type) { - NetDef net_def; - CAFFE_ENFORCE(TextFormat::ParseFromString(net_def_str, &net_def)); - net_def.set_type(type); - Workspace ws; - unique_ptr net(CreateNet(net_def, &ws)); - CAFFE_ENFORCE(net.get() != nullptr); - // Run once to kick in potential initialization (can be slower) - CAFFE_ENFORCE(net->Run()); - // Now run and time it - auto start_time = std::chrono::system_clock::now(); - CAFFE_ENFORCE(net->Run()); - // Inspect the time - it should be around 200 milliseconds, since sleep3 can - // run in parallel with sleep1 and sleep2. - auto duration = std::chrono::duration_cast( - std::chrono::system_clock::now() - start_time); - int milliseconds = duration.count(); - return milliseconds; -} -} // namespace - -TEST(DAGNetTest, TestDAGNetTiming) { - int ms = RunNetAndGetDuration(string(kSleepNetDefString), "dag"); - EXPECT_NEAR(ms, 200, kTimeThreshold); -} - -// For sanity check, we also test the sequential time - it should take 0.35 -// seconds instead since everything has to be sequential. -TEST(SimpleNetTest, TestSimpleNetTiming) { - int ms = RunNetAndGetDuration(string(kSleepNetDefString), "simple"); - EXPECT_NEAR(ms, 350, kTimeThreshold); -} - -// This network has two operators reading the same blob at the same time. This -// should not change anything and the DAG should still make sleep2 and sleep3 -// run in parallel. -// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) -const char kSleepNetDefStringReadAfterRead[] = - " name: \"sleepnet\"" - " type: \"dag\"" - " num_workers: 2" - " op {" - " output: \"sleep1\"" - " name: \"sleep1\"" - " type: \"Sleep\"" - " arg {" - " name: \"ms\"" - " i: 100" - " }" - " }" - " op {" - " input: \"sleep1\"" - " output: \"sleep2\"" - " name: \"sleep2\"" - " type: \"Sleep\"" - " arg {" - " name: \"ms\"" - " i: 100" - " }" - " }" - " op {" - " input: \"sleep1\"" - " output: \"sleep3\"" - " name: \"sleep3\"" - " type: \"Sleep\"" - " arg {" - " name: \"ms\"" - " i: 150" - " }" - " }"; - -TEST(DAGNetTest, TestDAGNetTimingReadAfterRead) { - int ms = RunNetAndGetDuration(string(kSleepNetDefStringReadAfterRead), "dag"); - EXPECT_NEAR(ms, 250, kTimeThreshold); -} - -// For sanity check, we also test the sequential time - it should take 0.35 -// seconds instead since everything has to be sequential. -TEST(SimpleNetTest, TestSimpleNetTimingReadAfterRead) { - int ms = - RunNetAndGetDuration(string(kSleepNetDefStringReadAfterRead), "simple"); - EXPECT_NEAR(ms, 350, kTimeThreshold); -} - -// This network has two operators writing out the sleep2 blob. As a result, the -// operator sleep2-again creates a write after write dependency and the whole -// process should be sequential. -// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) -const char kSleepNetDefStringWriteAfterWrite[] = - " name: \"sleepnet\"" - " type: \"dag\"" - " num_workers: 2" - " op {" - " output: \"sleep1\"" - " name: \"sleep1\"" - " type: \"Sleep\"" - " arg {" - " name: \"ms\"" - " i: 100" - " }" - " }" - " op {" - " input: \"sleep1\"" - " output: \"sleep2\"" - " name: \"sleep2\"" - " type: \"Sleep\"" - " arg {" - " name: \"ms\"" - " i: 100" - " }" - " }" - " op {" - " output: \"sleep2\"" - " name: \"sleep2-again\"" - " type: \"Sleep\"" - " arg {" - " name: \"ms\"" - " i: 150" - " }" - " }"; - -TEST(DAGNetTest, TestDAGNetTimingWriteAfterWrite) { - int ms = - RunNetAndGetDuration(string(kSleepNetDefStringWriteAfterWrite), "dag"); - EXPECT_NEAR(ms, 350, kTimeThreshold); -} - -TEST(SimpleNetTest, TestSimpleNetTimingWriteAfterWrite) { - int ms = - RunNetAndGetDuration(string(kSleepNetDefStringWriteAfterWrite), "simple"); - EXPECT_NEAR(ms, 350, kTimeThreshold); -} - -// This network has an operator writing to sleep1 while another operator is -// accessing it. As a result, the operator sleep1-again creates a write after -// read dependency and the whole process should be sequential. -// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) -const char kSleepNetDefStringWriteAfterRead[] = - " name: \"sleepnet\"" - " type: \"dag\"" - " num_workers: 2" - " op {" - " output: \"sleep1\"" - " name: \"sleep1\"" - " type: \"Sleep\"" - " arg {" - " name: \"ms\"" - " i: 100" - " }" - " }" - " op {" - " input: \"sleep1\"" - " output: \"sleep2\"" - " name: \"sleep2\"" - " type: \"Sleep\"" - " arg {" - " name: \"ms\"" - " i: 100" - " }" - " }" - " op {" - " output: \"sleep1\"" - " name: \"sleep1-again\"" - " type: \"Sleep\"" - " arg {" - " name: \"ms\"" - " i: 150" - " }" - " }"; - -TEST(DAGNetTest, TestDAGNetTimingWriteAfterRead) { - int ms = - RunNetAndGetDuration(string(kSleepNetDefStringWriteAfterRead), "dag"); - EXPECT_NEAR(ms, 350, kTimeThreshold); -} - -TEST(SimpleNetTest, TestSimpleNetTimingWriteAfterRead) { - int ms = - RunNetAndGetDuration(string(kSleepNetDefStringWriteAfterRead), "simple"); - EXPECT_NEAR(ms, 350, kTimeThreshold); -} - -// This network has an operator writing to sleep1 while another -// operator has a control dependency on it. As a result, the operator -// sleep1-again creates a write after read dependency and the whole -// process should be sequential. -// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) -const char kSleepNetDefStringControlDependency[] = R"DOC( - name: "sleepnet" - type: "dag" - num_workers: 2 - op { - output: "sleep1" - name: "sleep1" - type: "Sleep" - arg { - name: "ms" - i: 100 - } - } - op { - control_input: "sleep1" - output: "sleep2" - name: "sleep2" - type: "Sleep" - arg { - name: "ms" - i: 100 - } - } - op { - output: "sleep1" - name: "sleep1-again" - type: "Sleep" - arg { - name: "ms" - i: 150 - } - } -)DOC"; - -TEST(DAGNetTest, TestDAGNetTimingControlDependency) { - int ms = - RunNetAndGetDuration(string(kSleepNetDefStringControlDependency), "dag"); - EXPECT_NEAR(ms, 350, kTimeThreshold); -} - -TEST(SimpleNetTest, TestSimpleNetTimingControlDependency) { - int ms = RunNetAndGetDuration( - string(kSleepNetDefStringControlDependency), "simple"); - EXPECT_NEAR(ms, 350, kTimeThreshold); -} - -} // namespace caffe2 diff --git a/caffe2/core/plan_executor_test.cc b/caffe2/core/plan_executor_test.cc deleted file mode 100644 index 7a54403805ec..000000000000 --- a/caffe2/core/plan_executor_test.cc +++ /dev/null @@ -1,414 +0,0 @@ -#ifndef ANDROID - -#include -#include "caffe2/core/init.h" -#include "caffe2/core/operator.h" -#include "caffe2/core/plan_executor.h" - -namespace caffe2 { - -TEST(PlanExecutorTest, EmptyPlan) { - PlanDef plan_def; - Workspace ws; - EXPECT_TRUE(ws.RunPlan(plan_def)); -} - -namespace { -static std::atomic cancelCount{0}; -static std::atomic stuckRun{false}; -} // namespace - -class StuckBlockingOp final : public Operator { - public: - StuckBlockingOp(const OperatorDef& operator_def, Workspace* ws) - : Operator(operator_def, ws) {} - - bool RunOnDevice() override { - // StuckBlockingOp runs and notifies ErrorOp. - stuckRun = true; - - while (!cancelled_) { - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - } - - return true; - } - - void Cancel() override { - LOG(INFO) << "cancelled StuckBlockingOp."; - cancelCount += 1; - cancelled_ = true; - } - - private: - std::atomic cancelled_{false}; -}; - -REGISTER_CPU_OPERATOR(StuckBlocking, StuckBlockingOp); -OPERATOR_SCHEMA(StuckBlocking).NumInputs(0).NumOutputs(0); - -class NoopOp final : public Operator { - public: - NoopOp(const OperatorDef& operator_def, Workspace* ws) - : Operator(operator_def, ws) {} - - bool RunOnDevice() override { - // notify Error op we've ran. - stuckRun = true; - return true; - } -}; - -REGISTER_CPU_OPERATOR(Noop, NoopOp); -OPERATOR_SCHEMA(Noop).NumInputs(0).NumOutputs(0); - - -class StuckAsyncOp final : public Operator { - public: - StuckAsyncOp(const OperatorDef& operator_def, Workspace* ws) - : Operator(operator_def, ws) {} - - bool RunOnDevice() override { - // notify Error op we've ran. - stuckRun = true; - // explicitly don't call SetFinished so this gets stuck - return true; - } - - void CancelAsyncCallback() override { - LOG(INFO) << "cancelled"; - cancelCount += 1; - } - - bool HasAsyncPart() const override { - return true; - } -}; - -REGISTER_CPU_OPERATOR(StuckAsync, StuckAsyncOp); -OPERATOR_SCHEMA(StuckAsync).NumInputs(0).NumOutputs(0); - -class TestError : public std::exception { - const char* what() const noexcept override { - return "test error"; - } -}; - -class ErrorOp final : public Operator { - public: - ErrorOp(const OperatorDef& operator_def, Workspace* ws) - : Operator(operator_def, ws) {} - - bool RunOnDevice() override { - // Wait for StuckAsyncOp or StuckBlockingOp to run first. - while (!stuckRun) { - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - } - throw TestError(); - return true; - } -}; - -REGISTER_CPU_OPERATOR(Error, ErrorOp); -OPERATOR_SCHEMA(Error).NumInputs(0).NumOutputs(0); - -static std::atomic blockingErrorRuns{0}; -class BlockingErrorOp final : public Operator { - public: - BlockingErrorOp(const OperatorDef& operator_def, Workspace* ws) - : Operator(operator_def, ws) {} - - bool RunOnDevice() override { - // First n op executions should block and then start throwing errors. - if (blockingErrorRuns.fetch_sub(1) >= 1) { - LOG(INFO) << "blocking"; - while (true) { - std::this_thread::sleep_for(std::chrono::hours(10)); - } - } else { - LOG(INFO) << "throwing"; - throw TestError(); - } - } -}; - -REGISTER_CPU_OPERATOR(BlockingError, BlockingErrorOp); -OPERATOR_SCHEMA(BlockingError).NumInputs(0).NumOutputs(0); - -PlanDef parallelErrorPlan() { - PlanDef plan_def; - - auto* stuck_net = plan_def.add_network(); - stuck_net->set_name("stuck_net"); - stuck_net->set_type("async_scheduling"); - { - auto* op = stuck_net->add_op(); - op->set_type("StuckAsync"); - } - - auto* error_net = plan_def.add_network(); - error_net->set_name("error_net"); - error_net->set_type("async_scheduling"); - { - auto op = error_net->add_op(); - op->set_type("Error"); - } - - auto* execution_step = plan_def.add_execution_step(); - execution_step->set_concurrent_substeps(true); - { - auto* substep = execution_step->add_substep(); - substep->add_network(stuck_net->name()); - } - { - auto* substep = execution_step->add_substep(); - substep->add_network(error_net->name()); - } - - return plan_def; -} - -PlanDef parallelErrorPlanWithCancellableStuckNet() { - // Set a plan with two nets: one stuck net with blocking operator that never - // returns; one error net with error op that throws. - PlanDef plan_def; - - auto* stuck_blocking_net = plan_def.add_network(); - stuck_blocking_net->set_name("stuck_blocking_net"); - { - auto* op = stuck_blocking_net->add_op(); - op->set_type("StuckBlocking"); - } - - auto* error_net = plan_def.add_network(); - error_net->set_name("error_net"); - { - auto* op = error_net->add_op(); - op->set_type("Error"); - } - - auto* execution_step = plan_def.add_execution_step(); - execution_step->set_concurrent_substeps(true); - { - auto* substep = execution_step->add_substep(); - substep->add_network(stuck_blocking_net->name()); - } - { - auto* substep = execution_step->add_substep(); - substep->add_network(error_net->name()); - } - - return plan_def; -} - -PlanDef reporterErrorPlanWithCancellableStuckNet() { - // Set a plan with a concurrent net and a reporter net: one stuck net with - // blocking operator that never returns; one reporter net with error op - // that throws. - PlanDef plan_def; - - auto* stuck_blocking_net = plan_def.add_network(); - stuck_blocking_net->set_name("stuck_blocking_net"); - { - auto* op = stuck_blocking_net->add_op(); - op->set_type("StuckBlocking"); - } - - auto* error_net = plan_def.add_network(); - error_net->set_name("error_net"); - { - auto* op = error_net->add_op(); - op->set_type("Error"); - } - - auto* execution_step = plan_def.add_execution_step(); - execution_step->set_concurrent_substeps(true); - { - auto* substep = execution_step->add_substep(); - substep->add_network(stuck_blocking_net->name()); - } - { - auto* substep = execution_step->add_substep(); - substep->set_run_every_ms(1); - substep->add_network(error_net->name()); - } - - return plan_def; -} - -struct HandleExecutorThreadExceptionsGuard { - HandleExecutorThreadExceptionsGuard(int timeout = 60) { - globalInit({ - "caffe2", - "--caffe2_handle_executor_threads_exceptions=1", - "--caffe2_plan_executor_exception_timeout=" + - caffe2::to_string(timeout), - }); - } - - ~HandleExecutorThreadExceptionsGuard() { - globalInit({ - "caffe2", - }); - } - - HandleExecutorThreadExceptionsGuard( - const HandleExecutorThreadExceptionsGuard&) = delete; - void operator=(const HandleExecutorThreadExceptionsGuard&) = delete; - - private: - void globalInit(std::vector args) { - std::vector args_ptrs; - for (auto& arg : args) { - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast,performance-inefficient-vector-operation) - args_ptrs.push_back(const_cast(arg.data())); - } - char** new_argv = args_ptrs.data(); - int new_argc = args.size(); - CAFFE_ENFORCE(GlobalInit(&new_argc, &new_argv)); - } -}; - -TEST(PlanExecutorTest, ErrorAsyncPlan) { - HandleExecutorThreadExceptionsGuard guard; - - cancelCount = 0; - PlanDef plan_def = parallelErrorPlan(); - Workspace ws; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_THROW(ws.RunPlan(plan_def), TestError); - ASSERT_EQ(cancelCount, 1); -} - -// death tests not supported on mobile -#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) -TEST(PlanExecutorTest, BlockingErrorPlan) { - // TSAN doesn't play nicely with death tests -#if defined(__has_feature) -#if __has_feature(thread_sanitizer) - return; -#endif -#endif - - testing::GTEST_FLAG(death_test_style) = "threadsafe"; - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_DEATH( - [] { - HandleExecutorThreadExceptionsGuard guard(/*timeout=*/1); - - PlanDef plan_def; - - std::string plan_def_template = R"DOC( - network { - name: "net" - op { - type: "BlockingError" - } - } - execution_step { - num_concurrent_instances: 2 - substep { - network: "net" - } - } - )DOC"; - - CAFFE_ENFORCE( - TextFormat::ParseFromString(plan_def_template, &plan_def)); - Workspace ws; - blockingErrorRuns = 1; - ws.RunPlan(plan_def); - FAIL() << "shouldn't have reached this point"; - }(), - "failed to stop concurrent workers after exception: test error"); -} -#endif - -TEST(PlanExecutorTest, ErrorPlanWithCancellableStuckNet) { - HandleExecutorThreadExceptionsGuard guard; - - cancelCount = 0; - PlanDef plan_def = parallelErrorPlanWithCancellableStuckNet(); - Workspace ws; - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_THROW(ws.RunPlan(plan_def), TestError); - ASSERT_EQ(cancelCount, 1); -} - -TEST(PlanExecutorTest, ReporterErrorPlanWithCancellableStuckNet) { - HandleExecutorThreadExceptionsGuard guard; - - cancelCount = 0; - PlanDef plan_def = reporterErrorPlanWithCancellableStuckNet(); - Workspace ws; - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_THROW(ws.RunPlan(plan_def), TestError); - ASSERT_EQ(cancelCount, 1); -} - -PlanDef shouldStopWithCancelPlan() { - // Set a plan with a looping net with should_stop_blob set and a concurrent - // net that throws an error. The error should cause should_stop to return - // false and end the concurrent net. - PlanDef plan_def; - - auto* should_stop_net = plan_def.add_network(); - { - auto* op = should_stop_net->add_op(); - op->set_type("Noop"); - } - should_stop_net->set_name("should_stop_net"); - should_stop_net->set_type("async_scheduling"); - - auto* error_net = plan_def.add_network(); - error_net->set_name("error_net"); - { - auto* op = error_net->add_op(); - op->set_type("Error"); - } - - auto* execution_step = plan_def.add_execution_step(); - execution_step->set_concurrent_substeps(true); - { - auto* substep = execution_step->add_substep(); - execution_step->set_concurrent_substeps(true); - substep->set_name("concurrent_should_stop"); - substep->set_should_stop_blob("should_stop_blob"); - auto* substep2 = substep->add_substep(); - substep2->set_name("should_stop_net"); - substep2->add_network(should_stop_net->name()); - substep2->set_num_iter(10); - } - { - auto* substep = execution_step->add_substep(); - substep->set_name("error_step"); - substep->add_network(error_net->name()); - } - - return plan_def; -} - -TEST(PlanExecutorTest, ShouldStopWithCancel) { - HandleExecutorThreadExceptionsGuard guard; - - stuckRun = false; - PlanDef plan_def = shouldStopWithCancelPlan(); - Workspace ws; - - Blob* blob = ws.CreateBlob("should_stop_blob"); - Tensor* tensor = BlobGetMutableTensor(blob, CPU); - const vector& shape{1}; - tensor->Resize(shape); - tensor->mutable_data()[0] = false; - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_THROW(ws.RunPlan(plan_def), TestError); - ASSERT_TRUE(stuckRun); -} - -} // namespace caffe2 - -#endif diff --git a/caffe2/core/serialization_test.cc b/caffe2/core/serialization_test.cc deleted file mode 100644 index 902a3e01e677..000000000000 --- a/caffe2/core/serialization_test.cc +++ /dev/null @@ -1,101 +0,0 @@ -#include - -#include -#include -#include "caffe2/core/blob.h" -#include "caffe2/core/blob_serialization.h" - -// NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays -C10_DEFINE_bool( - caffe2_test_generate_unknown_dtype_blob, - false, - "Recompute and log the serialized blob data for the " - "TensorSerialization.TestUnknownDType test"); - -using namespace caffe2; - -namespace { - -// This data was computed by serializing a 10-element int32_t tensor, -// but with the data_type field set to 4567. This allows us to test the -// behavior of the code when deserializing data from a future version of the -// code that has new data types that our code does not understand. -constexpr c10::string_view kFutureDtypeBlob( - "\x0a\x09\x74\x65\x73\x74\x5f\x62\x6c\x6f\x62\x12\x06\x54\x65\x6e" - "\x73\x6f\x72\x1a\x28\x08\x0a\x08\x01\x10\xd7\x23\x22\x0a\x00\x01" - "\x02\x03\x04\x05\x06\x07\x08\x09\x3a\x09\x74\x65\x73\x74\x5f\x62" - "\x6c\x6f\x62\x42\x02\x08\x00\x5a\x04\x08\x00\x10\x0a", - 61); -// The same tensor with the data_type actually set to TensorProto_DataType_INT32 -constexpr c10::string_view kInt32DtypeBlob( - "\x0a\x09\x74\x65\x73\x74\x5f\x62\x6c\x6f\x62\x12\x06\x54\x65\x6e" - "\x73\x6f\x72\x1a\x27\x08\x0a\x08\x01\x10\x02\x22\x0a\x00\x01\x02" - "\x03\x04\x05\x06\x07\x08\x09\x3a\x09\x74\x65\x73\x74\x5f\x62\x6c" - "\x6f\x62\x42\x02\x08\x00\x5a\x04\x08\x00\x10\x0a", - 60); - -void logBlob(c10::string_view data) { - constexpr size_t kBytesPerLine = 16; - constexpr size_t kCharsPerEncodedByte = 4; - std::vector hexStr; - hexStr.resize((kBytesPerLine * kCharsPerEncodedByte) + 1); - hexStr[kBytesPerLine * kCharsPerEncodedByte] = '\0'; - size_t lineIdx = 0; - for (char c : data) { - snprintf( - hexStr.data() + (kCharsPerEncodedByte * lineIdx), - kCharsPerEncodedByte + 1, - "\\x%02x", - static_cast(c)); - ++lineIdx; - if (lineIdx >= kBytesPerLine) { - LOG(INFO) << " \"" << hexStr.data() << "\""; - lineIdx = 0; - } - } - if (lineIdx > 0) { - hexStr[lineIdx * kCharsPerEncodedByte] = '\0'; - LOG(INFO) << " \"" << hexStr.data() << "\""; - } -} - -} // namespace - -TEST(TensorSerialization, TestUnknownDType) { - // This code was used to generate the blob data listed above. - constexpr size_t kTestTensorSize = 10; - if (FLAGS_caffe2_test_generate_unknown_dtype_blob) { - Blob blob; - auto* blobTensor = BlobGetMutableTensor(&blob, CPU); - blobTensor->Resize(kTestTensorSize, 1); - auto *tensorData = blobTensor->mutable_data(); - for (unsigned n = 0; n < kTestTensorSize; ++n) { - tensorData[n] = n; - } - auto data = SerializeBlob(blob, "test_blob"); - LOG(INFO) << "test blob: size=" << data.size(); - logBlob(data); - } - - // Test deserializing the normal INT32 data, - // just to santity check that deserialization works - Blob i32Blob; - DeserializeBlob(std::string(kInt32DtypeBlob), &i32Blob); - const auto& tensor = BlobGetTensor(i32Blob, c10::DeviceType::CPU); - EXPECT_EQ(kTestTensorSize, tensor.numel()); - EXPECT_EQ(TypeMeta::Make(), tensor.dtype()); - const auto* tensor_data = tensor.template data(); - for (unsigned i = 0; i < kTestTensorSize; ++i) { - EXPECT_EQ(static_cast(i), tensor_data[i]); - } - - // Now test deserializing our blob with an unknown data type - Blob futureDtypeBlob; - try { - DeserializeBlob(std::string(kFutureDtypeBlob), &futureDtypeBlob); - FAIL() << "DeserializeBlob() should have failed"; - } catch (const std::exception& ex) { - EXPECT_STREQ( - "Cannot deserialize tensor: unrecognized data type", ex.what()); - } -} diff --git a/caffe2/core/stats_test.cc b/caffe2/core/stats_test.cc deleted file mode 100644 index ab61e7a2f84b..000000000000 --- a/caffe2/core/stats_test.cc +++ /dev/null @@ -1,151 +0,0 @@ -#include -#include -#include - -#include "caffe2/core/stats.h" -#include - -namespace caffe2 { -namespace { - -struct MyCaffeClass { - explicit MyCaffeClass(const std::string& name) : stats_(name) {} - - void tryRun(int) {} - - void run(int numRuns) { - try { - CAFFE_EVENT(stats_, num_runs, numRuns); - tryRun(numRuns); - CAFFE_EVENT(stats_, num_successes); - } catch (std::exception& e) { - CAFFE_EVENT(stats_, num_failures, 1, "arg_to_usdt", e.what()); - } - CAFFE_EVENT(stats_, usdt_only, 1, "arg_to_usdt"); - } - - private: - struct MyStats { - // NOLINTNEXTLINE(modernize-pass-by-value) - CAFFE_STAT_CTOR(MyStats); - CAFFE_EXPORTED_STAT(num_runs); - CAFFE_EXPORTED_STAT(num_successes); - CAFFE_EXPORTED_STAT(num_failures); - CAFFE_STAT(usdt_only); - } stats_; -}; - -ExportedStatMap filterMap( - const ExportedStatMap& map, - const ExportedStatMap& keys) { - ExportedStatMap filtered; - for (const auto& kv : map) { - if (keys.count(kv.first) > 0) { - filtered.insert(kv); - } - } - return filtered; -} - -#define EXPECT_SUBSET(map, sub) EXPECT_EQ(filterMap((map), (sub)), (sub)) - -TEST(StatsTest, StatsTestClass) { - MyCaffeClass a("first"); - MyCaffeClass b("second"); - for (int i = 0; i < 10; ++i) { - a.run(10); - b.run(5); - } - EXPECT_SUBSET( - ExportedStatMap({ - {"first/num_runs", 100}, - {"first/num_successes", 10}, - {"first/num_failures", 0}, - {"second/num_runs", 50}, - {"second/num_successes", 10}, - {"second/num_failures", 0}, - }), - toMap(StatRegistry::get().publish())); -} - -TEST(StatsTest, StatsTestDuration) { - struct TestStats { - // NOLINTNEXTLINE(modernize-pass-by-value) - CAFFE_STAT_CTOR(TestStats); - CAFFE_STAT(count); - CAFFE_AVG_EXPORTED_STAT(time_ns); - }; - TestStats stats("stats"); - CAFFE_DURATION(stats, time_ns) { - std::this_thread::sleep_for(std::chrono::microseconds(1)); - } - - ExportedStatList data; - StatRegistry::get().publish(data); - auto map = toMap(data); - auto countIt = map.find("stats/time_ns/count"); - auto sumIt = map.find("stats/time_ns/sum"); - EXPECT_TRUE(countIt != map.end() && sumIt != map.end()); - EXPECT_EQ(countIt->second, 1); - EXPECT_GT(sumIt->second, 0); -} - -TEST(StatsTest, StatsTestSimple) { - struct TestStats { - // NOLINTNEXTLINE(modernize-pass-by-value) - CAFFE_STAT_CTOR(TestStats); - CAFFE_STAT(s1); - CAFFE_STAT(s2); - CAFFE_EXPORTED_STAT(s3); - }; - TestStats i1("i1"); - TestStats i2("i2"); - CAFFE_EVENT(i1, s1); - CAFFE_EVENT(i1, s2); - CAFFE_EVENT(i1, s3, 1); - CAFFE_EVENT(i1, s3, -1); - CAFFE_EVENT(i2, s3, 2); - - ExportedStatList data; - StatRegistry::get().publish(data); - EXPECT_SUBSET(toMap(data), ExportedStatMap({{"i1/s3", 0}, {"i2/s3", 2}})); - - StatRegistry reg2; - reg2.update(data); - reg2.update(data); - - EXPECT_SUBSET( - toMap(reg2.publish(true)), ExportedStatMap({{"i1/s3", 0}, {"i2/s3", 4}})); - EXPECT_SUBSET( - toMap(reg2.publish()), ExportedStatMap({{"i1/s3", 0}, {"i2/s3", 0}})); -} - -TEST(StatsTest, StatsTestStatic) { - struct TestStats { - // NOLINTNEXTLINE(modernize-pass-by-value) - CAFFE_STAT_CTOR(TestStats); - CAFFE_STATIC_STAT(cpuUsage); - CAFFE_STATIC_STAT(memUsage); - }; - TestStats i1("i1"); - TestStats i2("i2"); - CAFFE_EVENT(i1, cpuUsage, 95); - CAFFE_EVENT(i2, memUsage, 80); - - ExportedStatList data; - StatRegistry::get().publish(data); - EXPECT_SUBSET( - toMap(data), ExportedStatMap({{"i1/cpuUsage", 95}, {"i2/memUsage", 80}})); - - CAFFE_EVENT(i1, cpuUsage, 80); - CAFFE_EVENT(i1, memUsage, 50); - CAFFE_EVENT(i2, memUsage, 90); - - StatRegistry::get().publish(data); - EXPECT_SUBSET( - toMap(data), - ExportedStatMap( - {{"i1/cpuUsage", 80}, {"i1/memUsage", 50}, {"i2/memUsage", 90}})); -} -} // namespace -} // namespace caffe2 diff --git a/caffe2/core/timer_test.cc b/caffe2/core/timer_test.cc deleted file mode 100644 index 8ffb2f21af03..000000000000 --- a/caffe2/core/timer_test.cc +++ /dev/null @@ -1,65 +0,0 @@ -#include -#include -#include - -#include "caffe2/core/timer.h" -#include - -namespace caffe2 { -namespace { - -TEST(TimerTest, Test) { - Timer timer; - - // A timer auto-starts when it is constructed. - std::this_thread::sleep_for(std::chrono::microseconds(1)); - EXPECT_GT(timer.NanoSeconds(), 0); - - // Sleep for a while, and get the time. - timer.Start(); - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - float ns = timer.NanoSeconds(); - float us = timer.MicroSeconds(); - float ms = timer.MilliSeconds(); - - // Time should be at least accurate +- 10%. (30% on Windows) -#ifndef _WIN32 - EXPECT_NEAR(ns, 100000000, 10000000); - EXPECT_NEAR(us, 100000, 10000); - EXPECT_NEAR(ms, 100, 10); -#else - EXPECT_NEAR(ns, 100000000, 30000000); - EXPECT_NEAR(us, 100000, 30000); - EXPECT_NEAR(ms, 100, 30); -#endif - - // Test restarting the clock. - timer.Start(); - EXPECT_LT(timer.MicroSeconds(), 1000); -} - -TEST(TimerTest, TestLatency) { - constexpr int iter = 1000; - float latency = 0; - Timer timer; - for (int i = 0; i < iter; ++i) { - timer.Start(); - latency += timer.NanoSeconds(); - } - std::cout << "Average nanosecond latency is: " << latency / iter << std::endl; - latency = 0; - for (int i = 0; i < iter; ++i) { - timer.Start(); - latency += timer.MicroSeconds(); - } - std::cout << "Average microsecond latency is: " << latency / iter << std::endl; - latency = 0; - for (int i = 0; i < iter; ++i) { - timer.Start(); - latency += timer.MilliSeconds(); - } - std::cout << "Average millisecond latency is: " << latency / iter << std::endl; -} - -} // namespace -} // namespace caffe2 diff --git a/caffe2/core/transform_test.cc b/caffe2/core/transform_test.cc deleted file mode 100644 index 0dc6ba92c7f9..000000000000 --- a/caffe2/core/transform_test.cc +++ /dev/null @@ -1,460 +0,0 @@ -#include -#include "caffe2/core/net.h" -#include "caffe2/core/operator.h" -#include "caffe2/core/transform.h" - -namespace caffe2 { - -namespace { - -using transform::Graph; - -static std::atomic counter; - -class TransformDummyOp final : public OperatorBase { - public: - using OperatorBase::OperatorBase; - bool Run(int /* unused */) override { - counter.fetch_add(1); - return true; - } -}; - -REGISTER_CPU_OPERATOR(TransformDummyOp1, TransformDummyOp); - -OPERATOR_SCHEMA(TransformDummyOp1) - .NumInputs(0, INT_MAX) - .NumOutputs(0, INT_MAX) - .AllowInplace({{0, 0}, {1, 1}}); - -REGISTER_CPU_OPERATOR(TransformDummyOp2, TransformDummyOp); - -OPERATOR_SCHEMA(TransformDummyOp2) - .NumInputs(0, INT_MAX) - .NumOutputs(0, INT_MAX) - .AllowInplace({{0, 0}, {1, 1}}); - -REGISTER_CPU_OPERATOR(TransformDummyOp3, TransformDummyOp); - -OPERATOR_SCHEMA(TransformDummyOp3) - .NumInputs(0, INT_MAX) - .NumOutputs(0, INT_MAX) - .AllowInplace({{0, 0}, {1, 1}}); - -/** - * This TransformDummy transform will find all subgraphs of shape - * (TransformDummyOp1 -> TransformDummyOp2) and replaces them with - * (TransformDummyOp3). Simple unit test. - */ -class DummyTransform : public Transform { - public: - // Finds all patterns of the form (TransformDummyOp1 -> TransformDummyOp2) - bool PatternRule(const Graph& g, const std::vector& subgraph, int idx) - override { - if (subgraph.size() >= pattern_chain.size()) { - return false; - } - // which index are we trying to append the new node to? - auto pattern_idx = subgraph.size(); - // type doesn't match - if (g.node(idx).op.type() != pattern_chain[pattern_idx]) { - return false; - } - // not that head, and doesn't have exactly 1 parent - if (pattern_idx > 0 && g.node(idx).parents.size() != 1) { - return false; - } - // not that tail, and doesn't have exactly 1 child - if (pattern_idx < pattern_chain.size() - 1 && - g.node(idx).children.size() != 1) { - return false; - } - - return true; - } - - // Checks if the subgraph matched is (TransformDummyOp1 -> TransformDummyOp2) - bool ValidatorRule(const Graph& g, const std::vector& subgraph) - override { - if (subgraph.size() == 2) { - if (g.node(subgraph[0]).op.type() == "TransformDummyOp1" && - g.node(subgraph[1]).op.type() == "TransformDummyOp2") { - return true; - } - } - return false; - } - - // Replaces a match of (TransformDummyOp1 -> TransformDummyOp2) with - // (TransformDummyOp3) - bool ReplaceRule(const std::vector& match, Graph* g_ptr) override { - CHECK(g_ptr); - auto& g = *g_ptr; - OperatorDef new_op; - new_op.set_type("TransformDummyOp3"); - int new_idx = g.size(); - - std::map> new_op_children; - std::map> new_op_parents; - - // for each node parent in the head of the match, connect it to our new node - for (const auto& edge : g.node(match[0]).parents) { - int parent = edge.first; - for (const auto& blob : edge.second) { - g.node(parent).children[new_idx].push_back(blob); - new_op_parents[parent].push_back(blob); - } - } - for (const string& blob : g.node(match[0]).op.input()) { - new_op.add_input(blob); - } - - // for each child in the tail of the match, connect it to our new node - for (const auto& edge : g.node(match[1]).children) { - int child = edge.first; - for (const auto& blob : edge.second) { - g.node(child).parents[new_idx].push_back(blob); - new_op_children[child].push_back(blob); - } - } - for (const string& blob : g.node(match[1]).op.output()) { - new_op.add_output(blob); - } - - g.DeactivateSubgraph(match); - - g.push_node(transform::Node(new_op, true, new_op_parents, new_op_children)); - return true; - } - - private: - const std::vector pattern_chain = {"TransformDummyOp1", - "TransformDummyOp2"}; -}; - -REGISTER_TRANSFORM(TransformDummySwap, DummyTransform) - -TEST(TransformTest, TestPatternMatch) { - Workspace ws; - ws.CreateBlob("in"); - NetDef netdef; - - AddOp(&netdef, "TransformDummyOp1", {"in"}, {"mid1"}); - AddOp(&netdef, "TransformDummyOp2", {"mid1"}, {"mid2"}); - AddOp(&netdef, "TransformDummyOp1", {"mid2"}, {"mid3"}); - AddOp(&netdef, "TransformDummyOp2", {"mid3"}, {"out"}); - - auto t = CreateTransform("TransformDummySwap"); - Graph g(netdef); - auto matches = t->PatternMatch(g); - - EXPECT_EQ(matches.size(), 2); - EXPECT_EQ(matches[0][0], 0); - EXPECT_EQ(matches[0][1], 1); - EXPECT_EQ(matches[1][0], 2); - EXPECT_EQ(matches[1][1], 3); -} - -TEST(TransformTest, TestReplacePattern) { - Workspace ws; - ws.CreateBlob("in"); - NetDef netdef; - - AddOp(&netdef, "TransformDummyOp1", {"in"}, {"mid1"}); - AddOp(&netdef, "TransformDummyOp2", {"mid1"}, {"mid2"}); - AddOp(&netdef, "TransformDummyOp1", {"mid2"}, {"mid3"}); - AddOp(&netdef, "TransformDummyOp2", {"mid3"}, {"out"}); - - auto t = CreateTransform("TransformDummySwap"); - Graph g(netdef); - std::vector> matches = {{0, 1}, {2, 3}}; - t->ReplacePattern(matches, &g); - - EXPECT_EQ(g.size(), 6); - EXPECT_FALSE(g.is_node_active(0)); - EXPECT_FALSE(g.is_node_active(1)); - EXPECT_FALSE(g.is_node_active(2)); - EXPECT_FALSE(g.is_node_active(3)); - EXPECT_TRUE(g.is_node_active(4)); - EXPECT_TRUE(g.is_node_active(5)); - - EXPECT_EQ(g.node(4).children.size(), 1); - EXPECT_EQ(g.node(4).parents.size(), 0); - EXPECT_TRUE(g.node(4).children.count(5)); - - NetDef replaced_netdef = g.GetNetDef(); - - EXPECT_EQ(replaced_netdef.op().size(), 2); - EXPECT_EQ(replaced_netdef.op(0).type(), "TransformDummyOp3"); - EXPECT_EQ(replaced_netdef.op(0).input(0), "in"); - EXPECT_EQ(replaced_netdef.op(1).type(), "TransformDummyOp3"); - EXPECT_EQ(replaced_netdef.op(1).output(0), "out"); -} - -TEST(TransformTest, TestTransformApply) { - Workspace ws; - ws.CreateBlob("in"); - NetDef netdef; - AddOp(&netdef, "TransformDummyOp1", {"in"}, {"mid1"}); - AddOp(&netdef, "TransformDummyOp2", {"mid1"}, {"mid2"}); - AddOp(&netdef, "TransformDummyOp1", {"mid2"}, {"mid3"}); - AddOp(&netdef, "TransformDummyOp2", {"mid3"}, {"out"}); - - NetDef replaced_netdef = ApplyTransform("TransformDummySwap", netdef); - - EXPECT_EQ(replaced_netdef.op().size(), 2); - EXPECT_EQ(replaced_netdef.op(0).type(), "TransformDummyOp3"); - EXPECT_EQ(replaced_netdef.op(0).input(0), "in"); - EXPECT_EQ(replaced_netdef.op(1).type(), "TransformDummyOp3"); - EXPECT_EQ(replaced_netdef.op(1).output(0), "out"); -} - -/** - * Transform with Sorted Order matching. - * Matches two operators of type TransformDummyOp1, even if disconnected. - * These operators will be given in execution order, - * but doesn't need connectivity. - * Changes them to TransformDummyOp2. - */ -class SortedDummyTransform : public Transform { - public: - SortedDummyTransform() { - SetPatternMatchType(SORTED_WRT_EXECUTION_ORDER); - } - bool PatternRule(const Graph& g, const std::vector& subgraph, int idx) - override { - if (g.node(idx).op.type() != "TransformDummyOp1") { - return false; - } - return true; - } - bool ValidatorRule(const Graph& g, const std::vector& subgraph) - override { - if (subgraph.size() == 2) { - if (g.node(subgraph[0]).op.type() == "TransformDummyOp1" && - g.node(subgraph[1]).op.type() == "TransformDummyOp1") { - return true; - } - } - return false; - } - bool ReplaceRule(const std::vector& match, Graph* g_ptr) override { - CHECK(g_ptr); - for (const auto& x : match) { - g_ptr->node(x).op.set_type("TransformDummyOp2"); - } - return true; - } -}; - -REGISTER_TRANSFORM(SortedTransformDummySwap, SortedDummyTransform) - -TEST(TransformTest, TestPatternMatchTypeSortedOrder) { - Workspace ws; - ws.CreateBlob("in"); - NetDef netdef; - - AddOp(&netdef, "TransformDummyOp1", {"in"}, {"mid1"}); - AddOp(&netdef, "TransformDummyOp3", {"mid1"}, {"mid2"}); - AddOp(&netdef, "TransformDummyOp1", {"mid2"}, {"mid3"}); - AddOp(&netdef, "TransformDummyOp3", {"mid3"}, {"out"}); - - auto t = CreateTransform("SortedTransformDummySwap"); - NetDef replaced_netdef = t->ApplyTo(netdef); - - EXPECT_EQ(replaced_netdef.op().size(), 4); - EXPECT_EQ(replaced_netdef.op(0).type(), "TransformDummyOp2"); - EXPECT_EQ(replaced_netdef.op(2).type(), "TransformDummyOp2"); -} - -/** - * General subgraph transform. - * Matches a TransformDummyOp1, and a TransformDummyOp2. - * Order doesn't matter. Connectedness doesn't matter. - * Turns them into TransformDummyOp3. - */ -class GeneralDummyTransform : public Transform { - public: - GeneralDummyTransform() { - SetPatternMatchType(GENERAL); - } - bool PatternRule(const Graph& g, const std::vector& subgraph, int idx) - override { - if (subgraph.size() == 0 && g.node(idx).op.type() == "TransformDummyOp1") { - return true; - } - if (subgraph.size() == 1 && g.node(idx).op.type() == "TransformDummyOp2") { - return true; - } - return false; - } - bool ValidatorRule(const Graph& g, const std::vector& subgraph) - override { - if (subgraph.size() == 2) { - if (g.node(subgraph[0]).op.type() == "TransformDummyOp1" && - g.node(subgraph[1]).op.type() == "TransformDummyOp2") { - return true; - } - } - return false; - } - bool ReplaceRule(const std::vector& match, Graph* g_ptr) override { - CHECK(g_ptr); - for (const auto& x : match) { - g_ptr->node(x).op.set_type("TransformDummyOp3"); - } - return true; - } -}; - -REGISTER_TRANSFORM(GeneralTransformDummySwap, GeneralDummyTransform) - -TEST(TransformTest, TestPatternMatchTypeGeneral) { - Workspace ws; - ws.CreateBlob("in"); - NetDef netdef; - - AddOp(&netdef, "TransformDummyOp2", {"in"}, {"mid1"}); - AddOp(&netdef, "TransformDummyOp3", {"mid1"}, {"mid2"}); - AddOp(&netdef, "TransformDummyOp1", {"mid2"}, {"mid3"}); - AddOp(&netdef, "TransformDummyOp3", {"mid3"}, {"out"}); - - auto t = CreateTransform("GeneralTransformDummySwap"); - NetDef replaced_netdef = t->ApplyTo(netdef); - - EXPECT_EQ(replaced_netdef.op().size(), 4); - EXPECT_EQ(replaced_netdef.op(0).type(), "TransformDummyOp3"); - EXPECT_EQ(replaced_netdef.op(2).type(), "TransformDummyOp3"); -} - -class TransformSleepFastOp final : public OperatorBase { - public: - using OperatorBase::OperatorBase; - bool Run(int /* unused */) override { - std::this_thread::sleep_for(std::chrono::milliseconds(30)); - return true; - } -}; - -REGISTER_CPU_OPERATOR(TransformSleepFastOp, TransformSleepFastOp); - -OPERATOR_SCHEMA(TransformSleepFastOp) - .NumInputs(0, INT_MAX) - .NumOutputs(0, INT_MAX) - .AllowInplace({{0, 0}, {1, 1}}); - -class TransformSleepSlowOp final : public OperatorBase { - public: - using OperatorBase::OperatorBase; - bool Run(int /* unused */) override { - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - return true; - } -}; - -REGISTER_CPU_OPERATOR(TransformSleepSlowOp, TransformSleepSlowOp); - -OPERATOR_SCHEMA(TransformSleepSlowOp) - .NumInputs(0, INT_MAX) - .NumOutputs(0, INT_MAX) - .AllowInplace({{0, 0}, {1, 1}}); - -/** - * This TransformDummy transform will find all operators of type old_type, - * and replace them with type new_type. - */ -class TypeSwapTransform : public Transform { - public: - // Determine the actual strings through inheriting from derived type. - // NOLINTNEXTLINE(modernize-pass-by-value) - explicit TypeSwapTransform(string old_type, string new_type) - : old_type(old_type), new_type(new_type) {} - - // Really simple, only accept if it's a FastSleepOp, and no match so far. - bool PatternRule(const Graph& g, const std::vector& subgraph, int idx) - override { - if (subgraph.size() == 0 && g.node(idx).op.type() == old_type) { - return true; - } - return false; - } - // Checks if the subgraph matched is a FastSleepOp - bool ValidatorRule(const Graph& g, const std::vector& subgraph) - override { - if (subgraph.size() == 1) { - if (g.node(subgraph[0]).op.type() == old_type) { - return true; - } - } - return false; - } - // Replaces op of original type to new type. - bool ReplaceRule(const std::vector& match, Graph* g_ptr) override { - CHECK(g_ptr); - auto& g = *g_ptr; - g.node(match[0]).op.set_type(new_type); - return true; - } - - private: - string old_type; - string new_type; -}; - -class FastToSlowTransform : public TypeSwapTransform { - public: - explicit FastToSlowTransform() - : TypeSwapTransform("TransformSleepFastOp", "TransformSleepSlowOp") {} -}; - -REGISTER_TRANSFORM(FastToSlow, FastToSlowTransform); - -class SlowToFastTransform : public TypeSwapTransform { - public: - explicit SlowToFastTransform() - : TypeSwapTransform("TransformSleepSlowOp", "TransformSleepFastOp") {} -}; - -REGISTER_TRANSFORM(SlowToFast, SlowToFastTransform); - -TEST(TransformTest, TestApplyTransformIfFasterIsFaster) { - NetDef init_netdef; - AddOp(&init_netdef, "ConstantFill", {}, {"in"}); - - NetDef netdef; - AddOp(&netdef, "TransformDummyOp1", {"in"}, {"mid"}); - AddOp(&netdef, "TransformSleepSlowOp", {"mid"}, {"out"}); - netdef.add_external_input("in"); // This is important for this function. - - // Make sure the transform would work normally. - auto transformed_net = ApplyTransform("SlowToFast", netdef); - EXPECT_EQ(transformed_net.op(1).type(), "TransformSleepFastOp"); - - // Should be still transform normally. - auto mystery_net = - ApplyTransformIfFaster("SlowToFast", netdef, init_netdef, 5, 10, 1.01); - EXPECT_EQ(mystery_net.op(1).type(), "TransformSleepFastOp"); -} - -TEST(TransformTest, TestApplyTransformIfFasterButSlower) { - NetDef init_netdef; - AddOp(&init_netdef, "ConstantFill", {}, {"in"}); - - NetDef netdef; - AddOp(&netdef, "TransformDummyOp1", {"in"}, {"mid"}); - AddOp(&netdef, "TransformSleepFastOp", {"mid"}, {"out"}); - netdef.add_external_input("in"); // This is important for this function. - - // Make sure the transform would work normally. - auto transformed_net = ApplyTransform("FastToSlow", netdef); - EXPECT_EQ(transformed_net.op(1).type(), "TransformSleepSlowOp"); - - // Should not actually change! - auto mystery_net = - ApplyTransformIfFaster("FastToSlow", netdef, init_netdef, 5, 10, 1.01); - EXPECT_EQ(mystery_net.op(1).type(), "TransformSleepFastOp"); -} - -} // namespace - -} // namespace caffe2 diff --git a/caffe2/core/workspace_test.cc b/caffe2/core/workspace_test.cc deleted file mode 100644 index c3f6ff0fb48f..000000000000 --- a/caffe2/core/workspace_test.cc +++ /dev/null @@ -1,149 +0,0 @@ -#include - -#include "caffe2/core/operator.h" -#include - -namespace caffe2 { - -class WorkspaceTestFoo {}; - -CAFFE_KNOWN_TYPE(WorkspaceTestFoo); - -TEST(WorkspaceTest, BlobAccess) { - Workspace ws; - - EXPECT_FALSE(ws.HasBlob("nonexisting")); - EXPECT_EQ(ws.GetBlob("nonexisting"), nullptr); - - EXPECT_EQ(ws.GetBlob("newblob"), nullptr); - EXPECT_NE(nullptr, ws.CreateBlob("newblob")); - EXPECT_NE(nullptr, ws.GetBlob("newblob")); - EXPECT_TRUE(ws.HasBlob("newblob")); - - // Different names should still be not created. - EXPECT_FALSE(ws.HasBlob("nonexisting")); - EXPECT_EQ(ws.GetBlob("nonexisting"), nullptr); - - // Check if the returned Blob is OK for all operations - Blob* blob = ws.GetBlob("newblob"); - int* int_unused CAFFE2_UNUSED = blob->GetMutable(); - EXPECT_TRUE(blob->IsType()); - EXPECT_FALSE(blob->IsType()); - EXPECT_NE(&blob->Get(), nullptr); - - // Re-creating the blob does not change the content as long as it already - // exists. - EXPECT_NE(nullptr, ws.CreateBlob("newblob")); - EXPECT_TRUE(blob->IsType()); - EXPECT_FALSE(blob->IsType()); - // When not null, we should only call with the right type. - EXPECT_NE(&blob->Get(), nullptr); - - // Re-creating the blob through CreateLocalBlob does not change the content - // either. - EXPECT_NE(nullptr, ws.CreateLocalBlob("newblob")); - EXPECT_TRUE(blob->IsType()); - EXPECT_NE(&blob->Get(), nullptr); - - // test removing blob - EXPECT_FALSE(ws.HasBlob("nonexisting")); - EXPECT_FALSE(ws.RemoveBlob("nonexisting")); - EXPECT_TRUE(ws.HasBlob("newblob")); - EXPECT_TRUE(ws.RemoveBlob("newblob")); - EXPECT_FALSE(ws.HasBlob("newblob")); -} - -TEST(WorkspaceTest, RunEmptyPlan) { - PlanDef plan_def; - Workspace ws; - EXPECT_TRUE(ws.RunPlan(plan_def)); -} - -TEST(WorkspaceTest, Sharing) { - Workspace parent; - EXPECT_FALSE(parent.HasBlob("a")); - EXPECT_TRUE(parent.CreateBlob("a")); - EXPECT_TRUE(parent.GetBlob("a")); - { - Workspace child(&parent); - // Child can access parent blobs - EXPECT_TRUE(child.HasBlob("a")); - EXPECT_TRUE(child.GetBlob("a")); - // Child can create local blobs - EXPECT_FALSE(child.HasBlob("b")); - EXPECT_FALSE(child.GetBlob("b")); - EXPECT_TRUE(child.CreateBlob("b")); - EXPECT_TRUE(child.GetBlob("b")); - // Parent cannot access child blobs - EXPECT_FALSE(parent.GetBlob("b")); - EXPECT_FALSE(parent.HasBlob("b")); - // Parent can create duplicate names - EXPECT_TRUE(parent.CreateBlob("b")); - // But child has local overrides - EXPECT_NE(child.GetBlob("b"), parent.GetBlob("b")); - // Child can create a blob that already exists in the parent - EXPECT_TRUE(child.CreateBlob("a")); - EXPECT_EQ(child.GetBlob("a"), parent.GetBlob("a")); - // Child can create a local blob for the blob already exists in the parent - EXPECT_TRUE(child.CreateLocalBlob("a")); - // But the local blob will be different from the one in parent workspace - EXPECT_NE(child.GetBlob("a"), parent.GetBlob("a")); - } -} - -TEST(WorkspaceTest, BlobMapping) { - Workspace parent; - EXPECT_FALSE(parent.HasBlob("a")); - EXPECT_TRUE(parent.CreateBlob("a")); - EXPECT_TRUE(parent.GetBlob("a")); - { - std::unordered_map forwarded_blobs; - forwarded_blobs["inner_a"] = "a"; - Workspace child(&parent, forwarded_blobs); - EXPECT_FALSE(child.HasBlob("a")); - EXPECT_TRUE(child.HasBlob("inner_a")); - EXPECT_TRUE(child.GetBlob("inner_a")); - Workspace ws; - EXPECT_TRUE(ws.CreateBlob("b")); - forwarded_blobs.clear(); - forwarded_blobs["inner_b"] = "b"; - child.AddBlobMapping(&ws, forwarded_blobs); - EXPECT_FALSE(child.HasBlob("b")); - EXPECT_TRUE(child.HasBlob("inner_b")); - EXPECT_TRUE(child.GetBlob("inner_b")); - } -} - -/** - * Checks that Workspace::ForEach(f) applies f on the specified set of - * workspaces in any order. - */ -static void forEachCheck(std::initializer_list workspaces) { - std::unordered_set expected(workspaces); - std::unordered_set actual; - Workspace::ForEach([&](Workspace* ws) { - auto inserted = actual.insert(ws).second; - EXPECT_TRUE(inserted); - }); - EXPECT_EQ(actual, expected); -} - -TEST(WorkspaceTest, ForEach) { - forEachCheck({}); - - { - Workspace ws1; - forEachCheck({&ws1}); - - { - Workspace ws2; - forEachCheck({&ws1, &ws2}); - } - - forEachCheck({&ws1}); - } - - forEachCheck({}); -} - -} // namespace caffe2 diff --git a/caffe2/release-notes.md b/caffe2/release-notes.md deleted file mode 100644 index d449e98f78e3..000000000000 --- a/caffe2/release-notes.md +++ /dev/null @@ -1,175 +0,0 @@ -# Caffe2 v0.7.0 Release Notes - -## Installation - -This build is confirmed for: - -* Ubuntu 14.04 -* Ubuntu 16.06 - -### Required Dependencies - -```bash -sudo apt-get update -sudo apt-get install -y --no-install-recommends \ - build-essential \ - cmake \ - git \ - libgoogle-glog-dev \ - libprotobuf-dev \ - protobuf-compiler \ - python-dev \ - python-pip -sudo pip install numpy protobuf -``` - -### Optional GPU Support - -If you plan to use GPU instead of CPU only, then you should install NVIDIA CUDA and cuDNN, a GPU-accelerated library of primitives for deep neural networks. -[NVIDIA's detailed instructions](http://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#ubuntu-installation) or if you're feeling lucky try the quick install set of commands below. - -**Update your graphics card drivers first!** Otherwise you may suffer from a wide range of difficult to diagnose errors. - -**For Ubuntu 14.04** - -```bash -sudo apt-get update && sudo apt-get install wget -y --no-install-recommends -wget "http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1404/x86_64/cuda-repo-ubuntu1404_8.0.61-1_amd64.deb" -sudo dpkg -i cuda-repo-ubuntu1404_8.0.61-1_amd64.deb -sudo apt-get update -sudo apt-get install cuda -``` - -**For Ubuntu 16.04** - -```bash -sudo apt-get update && sudo apt-get install wget -y --no-install-recommends -wget "http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1604/x86_64/cuda-repo-ubuntu1604_8.0.61-1_amd64.deb" -sudo dpkg -i cuda-repo-ubuntu1604_8.0.61-1_amd64.deb -sudo apt-get update -sudo apt-get install cuda -``` - -#### Install cuDNN (all Ubuntu versions) - -``` -CUDNN_URL="http://developer.download.nvidia.com/compute/redist/cudnn/v5.1/cudnn-8.0-linux-x64-v5.1.tgz" -wget ${CUDNN_URL} -sudo tar -xzf cudnn-8.0-linux-x64-v5.1.tgz -C /usr/local -rm cudnn-8.0-linux-x64-v5.1.tgz && sudo ldconfig -``` - -### Optional Dependencies - -> Note `libgflags2` is for Ubuntu 14.04. `libgflags-dev` is for Ubuntu 16.04. - -```bash -# for Ubuntu 14.04 -sudo apt-get install -y --no-install-recommends libgflags2 -``` - -```bash -# for Ubuntu 16.04 -sudo apt-get install -y --no-install-recommends libgflags-dev -``` - -```bash -# for both Ubuntu 14.04 and 16.04 -sudo apt-get install -y --no-install-recommends \ - libgtest-dev \ - libiomp-dev \ - libleveldb-dev \ - liblmdb-dev \ - libopencv-dev \ - libopenmpi-dev \ - libsnappy-dev \ - openmpi-bin \ - openmpi-doc \ - python-pydot -sudo pip install \ - flask \ - graphviz \ - hypothesis \ - jupyter \ - matplotlib \ - pydot python-nvd3 \ - pyyaml \ - requests \ - scikit-image \ - scipy \ - setuptools \ - tornado -``` - -### Clone & Build - -```bash -git clone --recursive https://github.com/caffe2/caffe2.git && cd caffe2 -make && cd build && sudo make install -python -c 'from caffe2.python import core' 2>/dev/null && echo "Success" || echo "Failure" -``` - -Run this command below to test if your GPU build was a success. You will get a test output either way, but it will warn you at the top of the output if CPU was used instead along with other errors like missing libraries. - -```bash -python -m caffe2.python.operator_test.relu_op_test -``` - -### Environment Variables - -These environment variables may assist you depending on your current configuration. When using the install instructions above on the AWS Deep Learning AMI you don't need to set these variables. However, our Docker scripts built on Ubuntu-14.04 or NVIDIA's CUDA images seem to benefit from having these set. If you ran into problems with the build tests above then these are good things to check. Echo them first and see what you have and possibly append or replace with these directories. Also visit the troubleshooting section below. - -```bash -echo $PYTHONPATH -# export PYTHONPATH=/usr/local:$PYTHONPATH -# export PYTHONPATH=$PYTHONPATH:/home/ubuntu/caffe2/build -echo $LD_LIBRARY_PATH -# export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH -``` - -### Setting Up Tutorials & Jupyter Server - -If you're running this all on a cloud computer, you probably won't have a UI or way to view the IPython notebooks by default. Typically, you would launch them locally with `ipython notebook` and you would see a localhost:8888 webpage pop up with the directory of notebooks running. The following example will show you how to launch the Jupyter server and connect to remotely via an SSH tunnel. - -First configure your cloud server to accept port 8889, or whatever you want, but change the port in the following commands. On AWS you accomplish this by adding a rule to your server's security group allowing a TCP inbound on port 8889. Otherwise you would adjust iptables for this. - -Next you launch the Jupyter server. - -``` -jupyter notebook --no-browser --port=8889 -``` - -Then create the SSH tunnel. This will pass the cloud server's Jupyter instance to your localhost 8888 port for you to use locally. The example below is templated after how you would connect AWS, where `your-public-cert.pem` is your own public certificate and `ubuntu@super-rad-GPU-instance.compute-1.amazonaws.com` is your login to your cloud server. You can easily grab this on AWS by going to Instances > Connect and copy the part after `ssh` and swap that out in the command below. - -``` -ssh -N -f -L localhost:8888:localhost:8889 -i "your-public-cert.pem" ubuntu@super-rad-GPU-instance.compute-1.amazonaws.com -``` - -### Troubleshooting - -|Python errors|| -|----|-----| -|Python version | [Python](https://www.python.org/) is core to run Caffe2. We currently require [Python2.7](https://www.python.org/download/releases/2.7/). *Ubuntu 14.04 and greater have Python built in by default*, and that can be used to run Caffe2. To check your version: `python --version`| -|Solution | If you want the developer version of python, you could install the `dev` package for Python: `sudo apt-get install python-dev`| -|Python environment | You may have another version of Python installed or need to support Python version 3 for other projects.| -|Solution | Try virtualenv or Anaconda. The [Anaconda](https://www.continuum.io/downloads) platform provides a single script to install many of the necessary packages for Caffe2, including Python. Using Anaconda is outside the scope of these instructions, but if you are interested, it may work well for you.| -|pip version | If you plan to use Python with Caffe2 then you need pip.| -|Solution | `sudo apt-get install python-pip` and also try using pip2 instead of pip.| -|"AttributeError: 'module' object has no attribute 'MakeArgument'" | Occurs when calling `core.CreateOperator`| -|Solution | Check your install directory (`/usr/local/`), and remove the folder `/caffe2/python/utils`| - -|Building from source|| -|----|-----| -|OS version | Caffe2 requires Ubuntu 14.04 or greater.| -|git | While you can download the Caffe2 source code and submodules directly from GitHub as a zip, using git makes it much easier.| -|Solution | `sudo apt-get install git`| -|protobuf | You may experience an error related to protobuf during the make step.| -|Solution | Make sure you've installed protobuf in **both** of these two ways: `sudo apt-get install libprotobuf-dev protobuf-compiler && sudo pip install protobuf`| -|libgflags2 error | This optional dependency is for Ubuntu 14.04.| -|Solution | Use `apt-get install libgflags-dev` for Ubuntu 16.04.| - -|GPU Support|| -|----|-----| -|GPU errors | Unsupported GPU or wrong version| -|Solution | You need to know the specific `deb` for your version of Linux. `sudo dpkg -i| |cuda-repo-__.deb` Refer to NVIDIA's [installation guide](http://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#ubuntu-installation).| -|Build issues | Be warned that installing CUDA and cuDNN will increase the size of your build by about 4GB, so plan to have at least 12GB for your Ubuntu disk size.| diff --git a/caffe2/requirements.txt b/caffe2/requirements.txt deleted file mode 100644 index aa8d2be43aa5..000000000000 --- a/caffe2/requirements.txt +++ /dev/null @@ -1,4 +0,0 @@ -numpy -enum34 -pyyaml -requests From 8a31c2aa84e29f861f401735dd26ec3d6b7a39d0 Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Wed, 29 May 2024 17:15:25 +0000 Subject: [PATCH 043/706] [export] allow complex guards as runtime asserts (#127129) With the current state of export's dynamic shapes, we struggle with guards and constraints that are beyond the current dynamic shapes language, expressed with dims and derived dims. While we can compile and guarantee correctness for guards within the current language (e.g. min/max ranges, linear relationships, integer divisibility) we struggle to dynamically compile guards which extend beyond that. For these "complex" guards, we typically do either of the following: 1) raise a constraint violation error, along the lines of "not all values of in the specified range satisfy ", with or without suggested fixes, 2) specialize to the provided static values and suggest removing dynamism, or 3) fail compilation due to some arbitrary unsupported case. Previous [work](https://github.com/pytorch/pytorch/pull/124949) went towards resolving this by disabling forced specializations, instead allowing the user to fail at runtime with incorrect inputs. In this PR, relying on [hybrid backed-unbacked symints](https://github.com/pytorch/pytorch/issues/121749), [deferred runtime asserts](https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/runtime_assert.py), and the function [_is_supported_equivalence()](https://github.com/pytorch/pytorch/blob/d7de4c9d809697b36ae0fd9e16815f6e3b4d985b/torch/fx/experimental/symbolic_shapes.py#L1824), we add a flag `_allow_complex_guards_as_runtime_asserts` which allows the user to compile exported programs containing these guards and maintain dynamism, while adding correctness checks as runtime assertions in the graph. Hybrid backed-unbacked symints allow us to easily bypass "implicit" guards emitted from computation - guards that we ~expect to be true. Popular examples revolve around reshapes: ``` # reshape def forward(self, x, y): # x: [s0, s1], y: [s2] return x.reshape([-1]) + y # guard s0 * s1 = s2 This leads to the following exported program class GraphModule(torch.nn.Module): def forward(self, x: "f32[s0, s1]", y: "f32[s2]"): sym_size_int: "Sym(s2)" = torch.ops.aten.sym_size.int(y, 0) mul: "Sym(-s2)" = -1 * sym_size_int; sym_size_int = None sym_size_int_1: "Sym(s0)" = torch.ops.aten.sym_size.int(x, 0) sym_size_int_2: "Sym(s1)" = torch.ops.aten.sym_size.int(x, 1) mul_1: "Sym(s0*s1)" = sym_size_int_1 * sym_size_int_2; sym_size_int_1 = sym_size_int_2 = None add: "Sym(s0*s1 - s2)" = mul + mul_1; mul = mul_1 = None eq: "Sym(Eq(s0*s1 - s2, 0))" = add == 0; add = None _assert_scalar = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(s0*s1 - s2, 0) on node 'eq'"); eq = None view: "f32[s0*s1]" = torch.ops.aten.view.default(x, [-1]); x = None add_1: "f32[s0*s1]" = torch.ops.aten.add.Tensor(view, y); view = y = None return (add_1,) ``` Another case is symbol divisibility: ``` def forward(self, x): # x: [s0, s1] return x.reshape([-1, x.shape[0] - 1]) # Eq(Mod(s0 * s1, s0 - 1), 0) ``` Applying deferred runtime asserts also helps dynamic compilation for "explicit" complex guards that typically cause problems for export. For example we can generate runtime asserts for not-equal guards, and complex conditions like the following: ``` class Foo(torch.nn.Module): def forward(self, x, y): # check that negation of first guard also shows up as runtime assertion if x.shape[0] == y.shape[0]: # False return x + y elif x.shape[0] == y.shape[0] ** 3: # False return x + 2, y + 3 elif x.shape[0] ** 2 == y.shape[0] * 3: # True return x * 2.0, y * 3.0 ``` For the above graph we will generate 3 runtime assertions: the negation of the first 2, and the 3rd condition as a guard. One additional benefit here over the current state of exported programs is that this adds further correctness guarantees - previously with explicit complex guards, if compilation succeeded, the guards would be ignored at runtime, treated as given. As shown above, the runtime asserts appear as math ops in the graph, generated by the sympy interpreter, resulting in an _assert_scalar call. There is an option to avoid adding these asserts into the graph, by setting `TORCH_DYNAMO_DO_NOT_EMIT_RUNTIME_ASSERTS=1`. This results in the "original" computation graph, with dynamism, and any incorrect inputs will fail on ops during runtime. Further work could go into prettifying the printer, so the majority of the graph isn't guard-related. Ideally this PR would subsume and remove the recently added [_disable_forced_specializations](https://github.com/pytorch/pytorch/pull/124949) flag, but that flag still handles one additional case of specialization: single-variable equalities where the symbol is solvable for a concrete value: see this [PR](https://github.com/pytorch/pytorch/pull/126925) This PR doesn't change any behavior around data-dependent errors/unbacked symints yet, that could be further work. NOTE: will take naming change suggestions for the flag :) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127129 Approved by: https://github.com/avikchaudhuri --- test/dynamo/test_misc.py | 4 +- test/export/test_export.py | 177 ++++++++++++++++++++--- torch/_dynamo/config.py | 9 ++ torch/_dynamo/eval_frame.py | 4 + torch/_dynamo/output_graph.py | 2 + torch/_export/non_strict_utils.py | 20 ++- torch/export/_trace.py | 44 +++++- torch/fx/experimental/symbolic_shapes.py | 107 +++++++++----- torch/fx/passes/runtime_assert.py | 28 +++- torch/utils/_sympy/reference.py | 4 +- 10 files changed, 326 insertions(+), 73 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 3ed06a55c837..6a44c5603dc9 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -9195,8 +9195,8 @@ def test_shape_env_equal_constructor(self): ShapeEnv not equal: field values don't match: ==> settings: values don't match. - > Left: ShapeEnvSettings(allow_scalar_outputs=False, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False) - > Right: ShapeEnvSettings(allow_scalar_outputs=True, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False) + > Left: ShapeEnvSettings(allow_scalar_outputs=False, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False, _allow_complex_guards_as_runtime_asserts=False) + > Right: ShapeEnvSettings(allow_scalar_outputs=True, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False, _allow_complex_guards_as_runtime_asserts=False) """, ) self._replay_and_check(main) diff --git a/test/export/test_export.py b/test/export/test_export.py index 5c1dbe2602ca..fcc0db08c701 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -5039,8 +5039,9 @@ def forward(self, x): export(f, (inputs,), dynamic_shapes=dynamic_shapes) def test_disable_forced_specializations(self): - # case 1 - # check disable_forced_specializations flag behaves correctly + # check that _disable_forced_specializations and _allow_complex_guards_as_runtime_asserts flags + # both behave correctly, avoiding forced specializations and deferring to runtime. + # case 1: modulo guards from torch.export import dims class Mod4Reshape(torch.nn.Module): @@ -5055,31 +5056,36 @@ def forward(self, x): r".*dx = .* must be specialized to 10 because the guards generated for it are too complex(.*\n)*" r".*dy = .* must be specialized to 72 because the guards generated for it are too complex(.*\n)*", ): - torch.export._trace._export( + export( Mod4Reshape(), inputs, dynamic_shapes={"x": (dx, dy)}, - strict=False, - _disable_forced_specializations=False, ) - ep = torch.export._trace._export( + + torch.export._trace._export( # just check this successfully compiles Mod4Reshape(), inputs, dynamic_shapes={"x": (dx, dy)}, strict=False, _disable_forced_specializations=True, ) + ep = torch.export._trace._export( + Mod4Reshape(), + inputs, + dynamic_shapes={"x": (dx, dy)}, + _allow_complex_guards_as_runtime_asserts=True, + ) out1 = ep.module()(torch.randn(8, 7)) self.assertEqual(out1.shape, torch.ones(7, 4, 2).shape) - out2 = ep.module()(torch.randn(4, 3)) - self.assertEqual(out2.shape, torch.ones(3, 4, 1).shape) + out2 = ep.module()(torch.randn(12, 11)) + self.assertEqual(out2.shape, torch.ones(11, 4, 3).shape) with self.assertRaisesRegex( RuntimeError, - r"shape .*7, 4, -1.* is invalid for input of size 64", + r"Runtime assertion failed for expression Eq\(Mod\(s0\*s1, 4\*s0 \- 4\), 0\) on node 'eq.*'", ): ep.module()(torch.randn(8, 8)) # fail - # case 2 + # case 2: 2d reshape class FreeReshape(torch.nn.Module): def forward(self, x, y, z): return x.reshape([-1]) + y.reshape([-1]) + z # s0*s1 = s2*s3 = s4 @@ -5090,9 +5096,9 @@ def forward(self, x, y, z): torch.randn(48), ) dynamic_shapes = { - "x": [Dim(f"dx{i}") for i in range(2)], - "y": [Dim(f"dy{i}") for i in range(2)], - "z": [Dim(f"dz{i}") for i in range(1)], + "x": [Dim(f"dx{i}", min=2) for i in range(2)], + "y": [Dim(f"dy{i}", min=2) for i in range(2)], + "z": [Dim(f"dz{i}", min=4) for i in range(1)], } with self.assertRaisesRegex( # this will force specialize torch._dynamo.exc.UserError, @@ -5100,32 +5106,85 @@ def forward(self, x, y, z): r".*dx0 = .* must be specialized to 6 because the guards generated for it are too complex(.*\n)*" r".*dx1 = .* must be specialized to 8 because the guards generated for it are too complex(.*\n)*", ): - torch.export._trace._export( + export( FreeReshape(), inputs, dynamic_shapes=dynamic_shapes, - strict=False, - _disable_forced_specializations=False, ) - ep = torch.export._trace._export( + torch.export._trace._export( FreeReshape(), inputs, dynamic_shapes=dynamic_shapes, strict=False, _disable_forced_specializations=True, ) + ep = torch.export._trace._export( + FreeReshape(), + inputs, + dynamic_shapes=dynamic_shapes, + _allow_complex_guards_as_runtime_asserts=True, + ) out1 = ep.module()(torch.randn(48, 1), torch.randn(4, 12), torch.randn(48)) self.assertEqual(out1.shape, torch.ones(48).shape) out2 = ep.module()(torch.randn(5, 8), torch.randn(4, 10), torch.randn(40)) self.assertEqual(out2.shape, torch.ones(40).shape) with self.assertRaisesRegex( RuntimeError, - r"The size of tensor a .* must match the size of tensor b .* at non-singleton dimension 0", + r"Runtime assertion failed for expression Eq\(s0\*s1 \- s2\*s3, 0\) on node 'eq.*'", ): # fail only at runtime ep.module()(torch.randn(5, 8), torch.randn(4, 5), torch.randn(30)) # fail + # case 3: 3d reshape (previously failing with different issue) + class Reshape3d(torch.nn.Module): + def forward(self, x, y): + return x.reshape([-1]) + y # s0*s1*s2 = s3 + + inputs = ( + torch.randn(4, 3, 2), + torch.randn(24), + ) + dynamic_shapes = { + "x": (Dim("dx0", min=2), Dim("dx1", min=2), Dim("dx2", min=2)), + "y": (Dim("dy", min=8),), + } + with self.assertRaisesRegex( # this will force specialize + torch._dynamo.exc.UserError, + r".*Specializations unexpectedly required(.*\n)*" + r"Suggested fixes:(.*\n)*" + r".*dx0 = 4(.*\n)*" + r".*dx1 = 3(.*\n)*" + r".*dx2 = 2(.*\n)*" + r".*dy = 24(.*\n)*", + ): + export( + Reshape3d(), + inputs, + dynamic_shapes=dynamic_shapes, + ) + + torch.export._trace._export( + Reshape3d(), + inputs, + dynamic_shapes=dynamic_shapes, + strict=False, + _disable_forced_specializations=True, + ) + ep = torch.export._trace._export( + Reshape3d(), + inputs, + dynamic_shapes=dynamic_shapes, + _allow_complex_guards_as_runtime_asserts=True, + ) + out1 = ep.module()(torch.randn(9, 7, 2), torch.randn(126)) + self.assertEqual(out1.shape, torch.ones(126).shape) + with self.assertRaisesRegex( + RuntimeError, + r"Runtime assertion failed for expression Eq\(s0\*s1\*s2 \- s3, 0\) on node 'eq.*'", + ): # fail only at runtime + ep.module()(torch.randn(4, 3, 2), torch.randn(10)) # fail + def test_disable_forced_specializations_errors(self): - # check error messages with disable_forced_specializations=False/True + # check error messages with disable_forced_specializations = False/True class Foo(torch.nn.Module): def forward(self, w, x, y, z): return w.reshape([-1]) + x, y + z # simple: s0*s1 = s2, s3 = s4 @@ -5142,7 +5201,7 @@ def forward(self, w, x, y, z): "y": [Dim("dy")], # y & z incorrect, export is supposed to fail. "z": [Dim("dz")], # suggested fix should be to match these up. } - with self.assertRaisesRegex( # if disable=False, suggested fixes should specialize 3, 4, 12. + with self.assertRaisesRegex( # if allow = False, suggested fixes should specialize 3, 4, 12. torch._dynamo.exc.UserError, r".*Specializations unexpectedly required(.*\n)*" r"Suggested fixes:(.*\n)*" @@ -5172,6 +5231,84 @@ def forward(self, w, x, y, z): _disable_forced_specializations=True, ) + def test_reshape_view_helper(self): + # see: https://github.com/pytorch/pytorch/issues/126607 + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x = x.view(x.size(1), -1) + # torch/_refs/__init__/_reshape_view_helper() will generate guards on reshape kernel(?) + # Ne(s0, 20), so that reshape isn't no-op + # Ne(Mod(s0, 20), 0), so that reshape needs to first flatten [s0, 20, 16] -> [s0*20, 16] + # then split_dim -> [20, s0, 16] + # check that these show up in graph + return torch.nn.functional.softmax( + x, dim=0 + ) # don't think softmax actually creates any issues, just part of original test + + model = Model() + x = torch.rand(1024, 20, 16) + dynamic_shapes = {"x": {0: Dim("batch")}} + ep = torch.export._trace._export( + model, + (x,), + dynamic_shapes=dynamic_shapes, + _allow_complex_guards_as_runtime_asserts=True, + ) + with self.assertRaisesRegex( + RuntimeError, + r"Runtime assertion failed for expression Ne\(s0, 20\)", + ): + ep.module()(torch.randn(20, 20, 16)) + with self.assertRaisesRegex( + RuntimeError, + r"Runtime assertion failed for expression Ne\(Mod\(s0, 20\), 0\)", + ): + ep.module()(torch.randn(400, 20, 16)) + ep.module()(torch.randn(42, 20, 16)) + + def test_allow_explicit_guards_as_runtime_asserts(self): + # check that explicit guards are treated as runtime assertions + class Foo(torch.nn.Module): + def forward(self, x, y): + # check that negation of first guard also shows up as runtime assertion + if x.shape[0] == y.shape[0]: # False + return x + y + elif x.shape[0] == y.shape[0] ** 3: # False + return x + 2, y + 3 + elif x.shape[0] ** 2 == y.shape[0] * 3: # True + return x * 2.0, y * 3.0 + + inputs = (torch.randn(6), torch.randn(12)) + dynamic_shapes = {"x": [Dim("dx", min=4)], "y": [Dim("dy", min=4)]} + ep = torch.export._trace._export( + Foo(), + inputs, + dynamic_shapes=dynamic_shapes, + _allow_complex_guards_as_runtime_asserts=True, + ) + # check forward pass + out0, out1 = ep.module()(torch.randn(9), torch.randn(27)) + self.assertEqual(out0.shape, torch.ones(9).shape) + self.assertEqual(out1.shape, torch.ones(27).shape) + with self.assertRaisesRegex( + RuntimeError, + r"Runtime assertion failed for expression Ne\(s0 \- s1, 0\)", + ): # fail only at runtime + ep.module()(torch.randn(4), torch.randn(4)) # fail + with self.assertRaisesRegex( + RuntimeError, + r"Runtime assertion failed for expression Ne\(s0 \- s1\**3, 0\)", + ): + ep.module()(torch.randn(64), torch.randn(4)) # fail + with self.assertRaisesRegex( + RuntimeError, + r"Runtime assertion failed for expression Eq\(s0\**2 \- 3\*s1, 0\)", + ): + ep.module()(torch.randn(10), torch.randn(9)) # fail + def test_constant_aliasing(self): class M1(torch.nn.Module): def __init__(self, m2, foo): diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index 62138127befd..6f4219a03b18 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -225,6 +225,15 @@ def is_fbcode(): os.environ.get("TORCHDYNAMO_CAPTURE_DYNAMIC_OUTPUT_SHAPE_OPS", "0") == "1" ) +# hybrid backed unbacked symints +prefer_deferred_runtime_asserts_over_guards = False + +# For complex dynamic shapes guards that we're unable to specify with dynamo/export's +# range constraints + dims + derived dims language, we raise constraint violation +# errors or specialize by default. If set to True, this flag avoids crashing/specialization, +# and allows complex guards as runtime assertions in the graph. +_allow_complex_guards_as_runtime_asserts = False + # By default, dynamo will treat all ints as backed SymInts, which means (1) it # will wait to see the int change over multiple runs before generalizing and # (2) it will still always 0/1 specialize an int. When true, this knob diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index fe06995771e0..fa9311f2c18a 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -1129,6 +1129,8 @@ def export( assume_static_by_default: bool = False, same_signature: bool = True, disable_constraint_solver: bool = False, + prefer_deferred_runtime_asserts_over_guards: bool = False, + _allow_complex_guards_as_runtime_asserts: bool = False, _log_export_usage: bool = True, **extra_kwargs, ) -> Callable[..., ExportResult]: @@ -1304,6 +1306,8 @@ def result_capturing_wrapper(*graph_inputs): automatic_dynamic_shapes=False, capture_dynamic_output_shape_ops=True, capture_scalar_outputs=True, + prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards, + _allow_complex_guards_as_runtime_asserts=_allow_complex_guards_as_runtime_asserts, ): opt_f = optimize_assert( dynamo_normalization_capturing_compiler, diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index e2bf4e2b3ed6..3ab386cf35b2 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -292,6 +292,8 @@ def __init__( tracked_fakes=self.tracked_fakes, allow_scalar_outputs=config.capture_scalar_outputs, allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops, + prefer_deferred_runtime_asserts_over_guards=config.prefer_deferred_runtime_asserts_over_guards, + _allow_complex_guards_as_runtime_asserts=config._allow_complex_guards_as_runtime_asserts, co_fields=self.co_fields, ) diff --git a/torch/_export/non_strict_utils.py b/torch/_export/non_strict_utils.py index 8f67a3cd258e..d15cb29f28df 100644 --- a/torch/_export/non_strict_utils.py +++ b/torch/_export/non_strict_utils.py @@ -111,7 +111,12 @@ def make_fake_params_buffers( def make_fake_inputs( - nn_module, args, kwargs, dynamic_shapes, _is_torch_jit_trace=False + nn_module, + args, + kwargs, + dynamic_shapes, + _is_torch_jit_trace=False, + _allow_complex_guards_as_runtime_asserts=False, ): """ Given an nn module, example inputs, and constraints, return a new fake mode, @@ -156,13 +161,22 @@ def make_fake_inputs( "co_firstlineno": code.co_firstlineno, } fake_mode = FakeTensorMode( - shape_env=ShapeEnv(tracked_fakes=[], co_fields=co_fields), + shape_env=ShapeEnv( + tracked_fakes=[], + co_fields=co_fields, + prefer_deferred_runtime_asserts_over_guards=_allow_complex_guards_as_runtime_asserts, + _allow_complex_guards_as_runtime_asserts=_allow_complex_guards_as_runtime_asserts, + ), allow_non_fake_inputs=True, export=True, ) else: fake_mode = FakeTensorMode( - shape_env=ShapeEnv(tracked_fakes=[]), + shape_env=ShapeEnv( + tracked_fakes=[], + prefer_deferred_runtime_asserts_over_guards=_allow_complex_guards_as_runtime_asserts, + _allow_complex_guards_as_runtime_asserts=_allow_complex_guards_as_runtime_asserts, + ), allow_non_fake_inputs=True, ) if fake_mode.shape_env is None or fake_mode.shape_env.tracked_fakes is None: diff --git a/torch/export/_trace.py b/torch/export/_trace.py index 31c933b4518a..976fddf0c874 100644 --- a/torch/export/_trace.py +++ b/torch/export/_trace.py @@ -2,6 +2,7 @@ import functools import inspect import logging +import os import re import time import warnings @@ -481,6 +482,7 @@ def _export_to_torch_ir( *, preserve_module_call_signature: Tuple[str, ...] = (), disable_constraint_solver: bool = False, + _allow_complex_guards_as_runtime_asserts: bool = False, restore_fqn: bool = True, _log_export_usage: bool = True, same_signature: bool = True, @@ -513,6 +515,10 @@ def _export_to_torch_ir( assume_static_by_default=True, tracing_mode="symbolic", disable_constraint_solver=disable_constraint_solver, + # currently the following 2 flags are tied together for export purposes, + # but untangle for sake of dynamo export api + prefer_deferred_runtime_asserts_over_guards=_allow_complex_guards_as_runtime_asserts, + _allow_complex_guards_as_runtime_asserts=_allow_complex_guards_as_runtime_asserts, _log_export_usage=_log_export_usage, same_signature=same_signature, )( @@ -547,6 +553,10 @@ def _export_to_aten_ir( pre_dispatch=False, _is_torch_jit_trace=False, ): + # set this to False if env variable is specified + if os.environ.get("TORCH_DYNAMO_DO_NOT_EMIT_RUNTIME_ASSERTS", "0") == "1": + should_insert_runtime_assertion = False + # [NOTE] If the user is exporting under training mode, we want to detect if there is any # state change in the autograd global state and error. If the user is exporting under inference # mode, we don't care. At predispatch level, we don't care about the state change. @@ -1043,6 +1053,7 @@ def _strict_export( pre_dispatch: bool, original_state_dict: Dict[str, Any], orig_in_spec: TreeSpec, + _allow_complex_guards_as_runtime_asserts: bool, _disable_forced_specializations: Optional[bool], _is_torch_jit_trace: bool, ): @@ -1053,6 +1064,7 @@ def _strict_export( dynamic_shapes, preserve_module_call_signature=preserve_module_call_signature, restore_fqn=False, # don't need to restore because we will do it later + _allow_complex_guards_as_runtime_asserts=_allow_complex_guards_as_runtime_asserts, _log_export_usage=False, ) @@ -1215,6 +1227,7 @@ def _non_strict_export( pre_dispatch: bool, original_state_dict: Dict[str, Any], orig_in_spec: TreeSpec, + _allow_complex_guards_as_runtime_asserts: bool, _disable_forced_specializations: Optional[bool], _is_torch_jit_trace: bool, ): @@ -1283,7 +1296,12 @@ def forward(self, *args, **kwargs): equalities_inputs, original_signature, ) = make_fake_inputs( - mod, args, kwargs, dynamic_shapes, _is_torch_jit_trace=_is_torch_jit_trace + mod, + args, + kwargs, + dynamic_shapes, + _is_torch_jit_trace=_is_torch_jit_trace, + _allow_complex_guards_as_runtime_asserts=_allow_complex_guards_as_runtime_asserts, # for shape env initialization ) fake_params_buffers = make_fake_params_buffers(fake_mode, _get_params_buffers(mod)) @@ -1346,6 +1364,7 @@ def _export( strict: bool = True, preserve_module_call_signature: Tuple[str, ...] = (), pre_dispatch: bool = False, + _allow_complex_guards_as_runtime_asserts: bool = False, _disable_forced_specializations: Optional[bool] = False, _is_torch_jit_trace: bool = False, ) -> ExportedProgram: @@ -1378,13 +1397,23 @@ def _export( preserve_module_call_signature: A list of submodule paths for which the original calling conventions are preserved as metadata. + _allow_complex_guards_as_runtime_asserts: + With the current dynamic shapes language for dims and derived dims, we can run into constraints + that are not expressible with the language. For example, flattening a matrix and adding to a vector, + both fully dynamic (i.e. x.reshape([-1]) + y) emits a guard s0 * s1 = s2, which is not expressible. + By default, we either raise a constraint violation error or specialize to static values. + If this flag is set to True, we avoid erroring out and instead allow complex constraints to exist as runtime + assertions in the graph. The sympy interpreter (torch/utils/_sympy/interp.py) will produce the math ops + required to compute and assert the value of the guard (e.g. sym_size_int, eq, _assert_scalar). + Additionally, if TORCH_DYNAMO_DO_NOT_EMIT_RUNTIME_ASSERTS=1 is specified, we will allow complex constraints + while not emitting runtime asserts, returning a cleaner graph with lesser guarantees around dynamic shapes. + _disable_forced_specializations: - By default, some inferred dynamic shapes guards/constraints that are not expressible with the current - dynamic shapes language will lead to specialization to the concrete input values provided. - If _disable_forced_specializations is set to True, we will not specialize, and will not perform runtime - checks on such produced guards. Instead, we allow the user to specify arbitrary shapes, - and fail during runtime if the inputs are invalid. Constraints expressible with the language - (e.g. ranges, linear derived dims) will still be enforced. + Similar to _allow_complex_guards_as_runtime_asserts, but only avoids specializing to static values if set to True. + For complex guards that don't specialize, this flag doesn't have any effect. Ideally this would be subsumed by + _allow_complex_guards_as_runtime_asserts, but this handles one additional case: single-variable equalities where + the symbol is solvable for a concrete value (e.g. Eq(s0 // 4, 400) -> s0 = 1600). If set to True, this flag will + avoid specializations. Direct equalities (e.g. s0 = 4), will still specialize. Returns: An ExportedProgram containing the traced method. @@ -1432,6 +1461,7 @@ def _export( pre_dispatch, original_state_dict, orig_in_spec, + _allow_complex_guards_as_runtime_asserts, _disable_forced_specializations, _is_torch_jit_trace, ) diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 9a9d7baa21ef..7492009e517a 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -1191,6 +1191,19 @@ def _assert_symbol_context(symbolic_context): assert isinstance(symbolic_context, SymbolicContext), "Invalid symbolic_context object" assert type(symbolic_context) is not SymbolicContext, "Illegal usage of symbolic_context ABC" +def _is_supported_equivalence(expr): + # Currently supported Dim ops are linear expressions with integer coefficients. + # So check that expr only contains +, *, ints, and a single occurrence of a symbol. + # (See also documentation of dynamic_shapes._DerivedDim.) + if isinstance(expr, (sympy.Add, sympy.Mul)): + if len(expr.args) > 2: + return False + lhs, rhs = expr.args + return ( + (_is_supported_equivalence(lhs) and isinstance(rhs, sympy.Integer)) or + (isinstance(lhs, sympy.Integer) and _is_supported_equivalence(rhs)) + ) + return isinstance(expr, sympy.Symbol) @dataclass(frozen=True) class SymbolicContext: @@ -1526,7 +1539,14 @@ class DimConstraints: Solutions are "static" values or simplified "dynamic" constraints. """ - def __init__(self, symbol_to_source, var_to_val, marked_dynamic, source_name_to_debug_name): + def __init__( + self, + symbol_to_source, + var_to_val, + marked_dynamic, + source_name_to_debug_name, + _allow_complex_guards_as_runtime_asserts=False, + ): # We try to solve systems of inequalities with 1 free variable. self._univariate_inequalities: Dict[sympy.Symbol, Set[sympy.Expr]] = defaultdict(set) # Among them, we prioritize solving for a free variable that has equalities. @@ -1568,6 +1588,9 @@ def __init__(self, symbol_to_source, var_to_val, marked_dynamic, source_name_to_ # symbols that are marked dynamic self._marked_dynamic = marked_dynamic + # for constraints we can't express with the dynamic shapes language, defer as runtime asserts in export + self._allow_complex_guards_as_runtime_asserts = _allow_complex_guards_as_runtime_asserts + def rewrite_with_congruences(self, s, expr): """ Eliminate expressions of the form b // d and b % d while adding congruences of the form b % d == k. @@ -1831,7 +1854,7 @@ def solve( symbolic_equivalences = self._symbolic_equivalences self._symbolic_equivalences = [] for source, expr in symbolic_equivalences: - if not _disable_forced_specializations and not self._is_supported_equivalence(expr): + if not _disable_forced_specializations and not _is_supported_equivalence(expr): for s in expr.free_symbols: self._force_specialization(s) sexpr = self._dcp._print_Symbol(s) @@ -1842,19 +1865,6 @@ def solve( for source, expr in self._symbolic_equivalences: self._dynamic_results.add(f"{self._dcp.print_source(source)} == {self._dcp.doprint(expr)}") - @classmethod - def _is_supported_equivalence(cls, expr): - # Currently supported Dim ops are linear expressions with integer coefficients. - # So check that expr only contains +, *, ints, and a single occurrence of a symbol. - # (See also documentation of dynamic_shapes._DerivedDim.) - if isinstance(expr, (sympy.Add, sympy.Mul)): - lhs, rhs = expr.args - return ( - (cls._is_supported_equivalence(lhs) and isinstance(rhs, sympy.Integer)) or - (isinstance(lhs, sympy.Integer) and cls._is_supported_equivalence(rhs)) - ) - return isinstance(expr, sympy.Symbol) - @classmethod def _is_supported_congruence(cls, congruence): base, divisor = congruence.args @@ -2211,7 +2221,7 @@ def relation_with_digit(expr, op, digit): other = c["eq"] if isinstance(other, int): others.append(f"{k} = {other}") - elif self._is_supported_equivalence(other): + elif _is_supported_equivalence(other): others.append(f"{k} = {other}") else: min_ = c.get("min", None) @@ -2339,6 +2349,7 @@ class ShapeEnvSettings: specialize_zero_one: bool duck_shape: bool prefer_deferred_runtime_asserts_over_guards: bool + _allow_complex_guards_as_runtime_asserts: bool class ShapeEnv: @@ -2432,6 +2443,10 @@ def _init( # in guards is helpful, since these guards in some sense are overly # pedantic. See also https://github.com/pytorch/pytorch/issues/121749 prefer_deferred_runtime_asserts_over_guards=False, + # When True, does not emit or raise constraint violation errors on + # implicit guards generated by ops, and defers to runtime assertions + # in the graph instead. For export. + _allow_complex_guards_as_runtime_asserts=False, # XXX Add any new settings that could affect FakeTensor evaluation # to: torch._subclasses.fake_tensor._ShapeEnvSettings ): @@ -2444,6 +2459,7 @@ def _init( specialize_zero_one=specialize_zero_one, duck_shape=duck_shape, prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards, + _allow_complex_guards_as_runtime_asserts=_allow_complex_guards_as_runtime_asserts, ) self.guards: List[ShapeGuard] = [] @@ -2629,6 +2645,10 @@ def duck_shape(self): def prefer_deferred_runtime_asserts_over_guards(self): return self.settings.prefer_deferred_runtime_asserts_over_guards + @property + def _allow_complex_guards_as_runtime_asserts(self): + return self.settings._allow_complex_guards_as_runtime_asserts + def check_equal(self, other: "ShapeEnv") -> None: """Compare another ShapeEnv for equivalence """ @@ -3932,6 +3952,7 @@ def track_symfloat(source, val): self.var_to_val, set(symbol_to_constraints.keys()), self.source_name_to_debug_name, + self._allow_complex_guards_as_runtime_asserts, ) if not _simplified: @@ -4626,6 +4647,9 @@ def _set_replacement(self, a: "sympy.Symbol", tgt: "sympy.Expr", msg: str) -> No # Precondition: a == tgt assert isinstance(a, sympy.Symbol) + if self._allow_complex_guards_as_runtime_asserts and not _is_supported_equivalence(tgt): + return # continuing leads to placeholder shapes having complex expressions that we can't resolve + # Handles nested tensor symbolic variables which don't have # var_to_range bounds tgt_bound = None @@ -5153,34 +5177,43 @@ def compute_concrete_val(): # is no longer necessary) self._maybe_guard_rel(g) - stack = CapturedTraceback.extract(skip=1) - guard = ShapeGuard(g, stack) - self.guards.append(guard) + if not self._allow_complex_guards_as_runtime_asserts: + # at this point, we've evaluated the concrete expr value, and have + # flipped/negated the guard if necessary. Now we know what to guard + # or defer to runtime assert on. + stack = CapturedTraceback.extract(skip=1) + guard = ShapeGuard(g, stack) + self.guards.append(guard) + else: + # it's fine to defer simple guards here without checking, + # the _maybe_guard_rel() call above will set replacements if possible, + # and so the result here will be statically known + self.defer_runtime_assert(g, f"evaluate_expr: {orig_expr}") + except Exception: if fresh: self._remove_fx_node(node) raise else: if not self._suppress_guards_tls(): - assert guard is not None - - self._log_guard("eval", g, forcing_spec=forcing_spec) + if guard is not None: # we might have deferred this to runtime assert + self._log_guard("eval", g, forcing_spec=forcing_spec) - for s in g.free_symbols: - self.symbol_guard_counter[s] += 1 - # Forcing_spec to avoid infinite recursion - if ( - not forcing_spec and - config.symbol_guard_limit_before_specialize is not None and - self.symbol_guard_counter[s] > config.symbol_guard_limit_before_specialize - ): - # Force specialization - self.log.info( - "symbol_guard_limit_before_specialize=%s exceeded on %s", - config.symbol_guard_limit_before_specialize, - s - ) - self.evaluate_expr(s, forcing_spec=True) + for s in g.free_symbols: + self.symbol_guard_counter[s] += 1 + # Forcing_spec to avoid infinite recursion + if ( + not forcing_spec and + config.symbol_guard_limit_before_specialize is not None and + self.symbol_guard_counter[s] > config.symbol_guard_limit_before_specialize + ): + # Force specialization + self.log.info( + "symbol_guard_limit_before_specialize=%s exceeded on %s", + config.symbol_guard_limit_before_specialize, + s + ) + self.evaluate_expr(s, forcing_spec=True) else: self._log_guard("eval [guard suppressed]", g, forcing_spec=forcing_spec) diff --git a/torch/fx/passes/runtime_assert.py b/torch/fx/passes/runtime_assert.py index e32b5a13fb78..12dc62cc16e7 100644 --- a/torch/fx/passes/runtime_assert.py +++ b/torch/fx/passes/runtime_assert.py @@ -90,16 +90,30 @@ def insert_deferred_runtime_asserts( lazy_format_graph_code(f"pre insert_deferred_runtime_asserts {name}", gm), ) + # deduplicate unassociated runtime assertions + # we could do better, some guards might be redundant, + # e.g. Eq(s0, 4) & Eq(2*s0, 8) + # but unclear how to handle all of that right now. + # TODO(pianpwk): better way of doing this + new_ras = [] + ras_exprs: Set[sympy.Expr] = set() + for ras in ras_by_symbol.pop(None, []): # type: ignore[call-overload] + if ras.expr not in ras_exprs: + new_ras.append(ras) + ras_exprs.add(ras.expr) + ras_by_symbol[None] = new_ras # type: ignore[index] + # We are going to mutate the dict symbol_to_proxy: Dict[sympy.Symbol, fx.Proxy] = {} placeholders = set() last_placeholder = None for node in graph.nodes: if node.op != "placeholder": - last_placeholder = node break + last_placeholder = node placeholders.add(node) - assert last_placeholder is not None + if last_placeholder is None: # no placeholders, just insert before first node + last_placeholder = next(iter(graph.nodes)) # Identify what symbols we need to reify. This isn't strictly needed # but helps reduce churn on the graph @@ -137,6 +151,7 @@ def add_runtime_asserts(ras): ), ) + inserted_sym_nodes = 0 # for inserting unassociated runtime asserts nodes = list(graph.nodes) for i, node in enumerate(nodes[:-1]): # Placeholders can match symbols, but when we destructure them @@ -164,6 +179,8 @@ def match_symbol(symint, cb): ): symbol_to_proxy[s] = fx.Proxy(cb()) log.debug("symbol_to_proxy[%s] = %s", s, symbol_to_proxy[s]) + nonlocal inserted_sym_nodes + inserted_sym_nodes += 1 match_symbol(example_value, lambda: node) if isinstance(t := example_value, torch.Tensor): @@ -191,8 +208,13 @@ def match_symbol(symint, cb): # Handle asserts that aren't associated with any symbol. This # doesn't really have to be in the loop as it will only run once, # it just needs to happen right after the placeholders. + # insert this after placeholders & added sym nodes, and before non-placeholders. if node not in placeholders: - add_runtime_asserts(ras_by_symbol.pop(None, [])) # type: ignore[call-overload] + last_sym_node = last_placeholder + for _ in range(inserted_sym_nodes): + last_sym_node = last_sym_node.next + with graph.inserting_before(last_sym_node.next): + add_runtime_asserts(ras_by_symbol.pop(None, [])) # type: ignore[call-overload] defs = [] diff --git a/torch/utils/_sympy/reference.py b/torch/utils/_sympy/reference.py index 8bd688b0c0c9..881b9d616eb5 100644 --- a/torch/utils/_sympy/reference.py +++ b/torch/utils/_sympy/reference.py @@ -71,7 +71,9 @@ def square(x): @staticmethod def mod(x, y): ret = abs(x) % abs(y) - if x < 0: + # without check: + # tracing will fail trying to go through control-flow if x is Proxy() + if isinstance(x, (int, sympy.Number)) and x < 0: ret *= -1 return ret From 28de9143a3f07dfd21bf8078444f6c15f829b29b Mon Sep 17 00:00:00 2001 From: rzou Date: Wed, 29 May 2024 06:22:55 -0700 Subject: [PATCH 044/706] opcheck should be usable without optional dependencies (#127292) This PR excises opcheck's dependency on torch.testing._internal.common_utils, (which comes with dependencies on expecttest and hypothesis). We do this by moving what we need to torch.testing._utils and adding a test for it. Fixes #126870, #126871 Test Plan: - new tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/127292 Approved by: https://github.com/williamwen42 ghstack dependencies: #127291 --- test/test_custom_ops.py | 15 ++++++ torch/testing/__init__.py | 1 + .../_internal/common_methods_invocations.py | 19 +------ torch/testing/_internal/common_utils.py | 34 ++----------- .../testing/_internal/optests/aot_autograd.py | 2 +- torch/testing/_internal/optests/make_fx.py | 2 +- torch/testing/_utils.py | 50 +++++++++++++++++++ 7 files changed, 74 insertions(+), 49 deletions(-) create mode 100644 torch/testing/_utils.py diff --git a/test/test_custom_ops.py b/test/test_custom_ops.py index db03098a0fec..e2af3efaa98a 100644 --- a/test/test_custom_ops.py +++ b/test/test_custom_ops.py @@ -3154,6 +3154,21 @@ def test_opcheck_bad_op(self): }, ) + def test_opcheck_does_not_require_extra_deps(self): + # torch.testing._internal.common_utils comes with a lot of additional + # test-time dependencies. Since opcheck is public API, it should be + # usable only with pytorch install-time dependencies. + cmd = [ + sys.executable, + "-c", + "import torch; import sys; \ + x = torch.randn(3, requires_grad=True); \ + torch.library.opcheck(torch.ops.aten.sin.default, (x,)); \ + assert 'expecttest' not in sys.modules; \ + assert 'torch.testing._internal.common_utils' not in sys.modules", + ] + subprocess.check_output(cmd, shell=False) + only_for = ("cpu", "cuda") instantiate_device_type_tests(TestCustomOpTesting, globals(), only_for=only_for) diff --git a/torch/testing/__init__.py b/torch/testing/__init__.py index 58b8f828e354..352ce67e074a 100644 --- a/torch/testing/__init__.py +++ b/torch/testing/__init__.py @@ -1,3 +1,4 @@ from torch._C import FileCheck as FileCheck +from . import _utils from ._comparison import assert_allclose, assert_close as assert_close from ._creation import make_tensor as make_tensor diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index f551beb759cf..62958b136364 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -36,9 +36,10 @@ make_fullrank_matrices_with_distinct_singular_values, TEST_WITH_ROCM, IS_WINDOWS, IS_MACOS, TEST_SCIPY, torch_to_numpy_dtype_dict, numpy_to_torch_dtype, TEST_WITH_ASAN, - GRADCHECK_NONDET_TOL, freeze_rng_state, slowTest, TEST_WITH_SLOW, + GRADCHECK_NONDET_TOL, slowTest, TEST_WITH_SLOW, TEST_WITH_TORCHINDUCTOR ) +from torch.testing._utils import wrapper_set_seed import torch._refs as refs # noqa: F401 import torch._refs.nn.functional @@ -11299,22 +11300,6 @@ def reference_mse_loss(input, target, reduction="mean"): return se -def wrapper_set_seed(op, *args, **kwargs): - """Wrapper to set seed manually for some functions like dropout - See: https://github.com/pytorch/pytorch/pull/62315#issuecomment-896143189 for more details. - """ - with freeze_rng_state(): - torch.manual_seed(42) - output = op(*args, **kwargs) - - if isinstance(output, torch.Tensor) and output.device.type == "lazy": - # We need to call mark step inside freeze_rng_state so that numerics - # match eager execution - torch._lazy.mark_step() - - return output - - def reference_layer_norm(inp: np.ndarray, normalized_shape: Tuple[int], weight=None, bias=None, eps=1e-5): return reference_native_layer_norm(inp, normalized_shape, weight, bias, eps)[0] diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 5f9ef602d518..2237ec67c500 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -96,7 +96,6 @@ from torch.utils._import_utils import _check_module_exists import torch.utils._pytree as pytree -from .composite_compliance import no_dispatch try: import pytest has_pytest = True @@ -104,6 +103,10 @@ has_pytest = False +def freeze_rng_state(*args, **kwargs): + return torch.testing._utils.freeze_rng_state(*args, **kwargs) + + # Class to keep track of test flags configurable by environment variables. # Flags set here are intended to be read-only and should not be modified after # definition. @@ -1949,35 +1952,6 @@ def set_rng_seed(seed): np.random.seed(seed) -disable_functorch = torch._C._DisableFuncTorch - - -@contextlib.contextmanager -def freeze_rng_state(): - # no_dispatch needed for test_composite_compliance - # Some OpInfos use freeze_rng_state for rng determinism, but - # test_composite_compliance overrides dispatch for all torch functions - # which we need to disable to get and set rng state - with no_dispatch(), disable_functorch(): - rng_state = torch.get_rng_state() - if torch.cuda.is_available(): - cuda_rng_state = torch.cuda.get_rng_state() - try: - yield - finally: - # Modes are not happy with torch.cuda.set_rng_state - # because it clones the state (which could produce a Tensor Subclass) - # and then grabs the new tensor's data pointer in generator.set_state. - # - # In the long run torch.cuda.set_rng_state should probably be - # an operator. - # - # NB: Mode disable is to avoid running cross-ref tests on thes seeding - with no_dispatch(), disable_functorch(): - if torch.cuda.is_available(): - torch.cuda.set_rng_state(cuda_rng_state) - torch.set_rng_state(rng_state) - @contextlib.contextmanager def set_default_dtype(dtype): saved_dtype = torch.get_default_dtype() diff --git a/torch/testing/_internal/optests/aot_autograd.py b/torch/testing/_internal/optests/aot_autograd.py index 13ce9e883789..4f281c777175 100644 --- a/torch/testing/_internal/optests/aot_autograd.py +++ b/torch/testing/_internal/optests/aot_autograd.py @@ -2,7 +2,7 @@ import torch import torch.utils._pytree as pytree -from torch.testing._internal.common_methods_invocations import wrapper_set_seed +from torch.testing._utils import wrapper_set_seed from functorch.compile import compiled_function, min_cut_rematerialization_partition, nop from .make_fx import randomize import re diff --git a/torch/testing/_internal/optests/make_fx.py b/torch/testing/_internal/optests/make_fx.py index 95f746a31af3..83cefd18bc05 100644 --- a/torch/testing/_internal/optests/make_fx.py +++ b/torch/testing/_internal/optests/make_fx.py @@ -2,7 +2,7 @@ import torch from torch.fx.experimental.proxy_tensor import make_fx -from torch.testing._internal.common_methods_invocations import wrapper_set_seed +from torch.testing._utils import wrapper_set_seed import torch.utils._pytree as pytree diff --git a/torch/testing/_utils.py b/torch/testing/_utils.py new file mode 100644 index 000000000000..b85860eeff03 --- /dev/null +++ b/torch/testing/_utils.py @@ -0,0 +1,50 @@ +import contextlib + +import torch + +# Common testing utilities for use in public testing APIs. +# NB: these should all be importable without optional dependencies +# (like numpy and expecttest). + + +def wrapper_set_seed(op, *args, **kwargs): + """Wrapper to set seed manually for some functions like dropout + See: https://github.com/pytorch/pytorch/pull/62315#issuecomment-896143189 for more details. + """ + with freeze_rng_state(): + torch.manual_seed(42) + output = op(*args, **kwargs) + + if isinstance(output, torch.Tensor) and output.device.type == "lazy": + # We need to call mark step inside freeze_rng_state so that numerics + # match eager execution + torch._lazy.mark_step() # type: ignore[attr-defined] + + return output + + +@contextlib.contextmanager +def freeze_rng_state(): + # no_dispatch needed for test_composite_compliance + # Some OpInfos use freeze_rng_state for rng determinism, but + # test_composite_compliance overrides dispatch for all torch functions + # which we need to disable to get and set rng state + with torch.utils._mode_utils.no_dispatch(), torch._C._DisableFuncTorch(): + rng_state = torch.get_rng_state() + if torch.cuda.is_available(): + cuda_rng_state = torch.cuda.get_rng_state() + try: + yield + finally: + # Modes are not happy with torch.cuda.set_rng_state + # because it clones the state (which could produce a Tensor Subclass) + # and then grabs the new tensor's data pointer in generator.set_state. + # + # In the long run torch.cuda.set_rng_state should probably be + # an operator. + # + # NB: Mode disable is to avoid running cross-ref tests on thes seeding + with torch.utils._mode_utils.no_dispatch(), torch._C._DisableFuncTorch(): + if torch.cuda.is_available(): + torch.cuda.set_rng_state(cuda_rng_state) # type: ignore[possibly-undefined] + torch.set_rng_state(rng_state) From 8b5cbb7c685cb65fca584c02cdb8e4a35e5678c9 Mon Sep 17 00:00:00 2001 From: Danielle Pintz <38207072+daniellepintz@users.noreply.github.com> Date: Wed, 29 May 2024 17:29:04 +0000 Subject: [PATCH 045/706] Improve NLLLoss docs (#127346) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127346 Approved by: https://github.com/mikaylagawarecki --- torch/nn/modules/loss.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/torch/nn/modules/loss.py b/torch/nn/modules/loss.py index fb7172e9ae54..4324c1df144d 100644 --- a/torch/nn/modules/loss.py +++ b/torch/nn/modules/loss.py @@ -167,8 +167,8 @@ class NLLLoss(_WeightedLoss): the meantime, specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` - Shape: - - Input: :math:`(N, C)` or :math:`(C)`, where `C = number of classes`, or + Shape:: + - Input: :math:`(N, C)` or :math:`(C)`, where `C = number of classes`, `N = batch size`, or :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of `K`-dimensional loss. - Target: :math:`(N)` or :math:`()`, where each value is @@ -181,27 +181,29 @@ class NLLLoss(_WeightedLoss): Examples:: - >>> m = nn.LogSoftmax(dim=1) - >>> loss = nn.NLLLoss() - >>> # input is of size N x C = 3 x 5 + >>> log_softmax = nn.LogSoftmax(dim=1) + >>> loss_fn = nn.NLLLoss() + >>> # input to NLLLoss is of size N x C = 3 x 5 >>> input = torch.randn(3, 5, requires_grad=True) - >>> # each element in target has to have 0 <= value < C + >>> # each element in target must have 0 <= value < C >>> target = torch.tensor([1, 0, 4]) - >>> output = loss(m(input), target) - >>> output.backward() + >>> loss = loss_fn(log_softmax(input), target) + >>> loss.backward() >>> >>> >>> # 2D loss example (used, for example, with image inputs) >>> N, C = 5, 4 - >>> loss = nn.NLLLoss() - >>> # input is of size N x C x height x width + >>> loss_fn = nn.NLLLoss() >>> data = torch.randn(N, 16, 10, 10) >>> conv = nn.Conv2d(16, C, (3, 3)) - >>> m = nn.LogSoftmax(dim=1) - >>> # each element in target has to have 0 <= value < C + >>> log_softmax = nn.LogSoftmax(dim=1) + >>> # output of conv forward is of shape [N, C, 8, 8] + >>> output = log_softmax(conv(data)) + >>> # each element in target must have 0 <= value < C >>> target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C) - >>> output = loss(m(conv(data)), target) - >>> output.backward() + >>> # input to NLLLoss is of size N x C x height (8) x width (8) + >>> loss = loss_fn(output, target) + >>> loss.backward() """ __constants__ = ['ignore_index', 'reduction'] ignore_index: int From 090a031d6f146da3691e99f5f86220d766af4b10 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Tue, 28 May 2024 11:49:14 -0700 Subject: [PATCH 046/706] Use bit_cast instead of UB type-pun-via-union in Half.h (#127321) Summary: Type punning via union has undefined behavior due to the strict aliasing rule. bit_cast does the same thing safely (using memcpy under the hood). Test Plan: CI Godbolt demonstrates that doing this via memcpy still generates the same instructions: https://godbolt.org/z/PhePzd4Ex Pull Request resolved: https://github.com/pytorch/pytorch/pull/127321 Approved by: https://github.com/ezyang, https://github.com/Skylion007 --- c10/util/Half.h | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/c10/util/Half.h b/c10/util/Half.h index af3435941e48..afc90f106a6f 100644 --- a/c10/util/Half.h +++ b/c10/util/Half.h @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -330,20 +331,12 @@ inline uint16_t fp16_ieee_from_fp32_value(float f) { } #if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) -constexpr inline float16_t fp16_from_bits(uint16_t h) { - union { - uint16_t as_bits; - float16_t as_value; - } fp16 = {h}; - return fp16.as_value; +inline float16_t fp16_from_bits(uint16_t h) { + return c10::bit_cast(h); } -constexpr inline uint16_t fp16_to_bits(float16_t f) { - union { - float16_t as_value; - uint16_t as_bits; - } fp16 = {.as_value = f}; - return fp16.as_bits; +inline uint16_t fp16_to_bits(float16_t f) { + return c10::bit_cast(f); } // According to https://godbolt.org/z/8s14GvEjo it would translate to single From d938170314fa89acaad6b06fbbaac6b98f1e618f Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Wed, 29 May 2024 18:22:29 +0000 Subject: [PATCH 047/706] Add torchao nightly testing workflow (#126885) Add and test torchao nightly testing workflow. This workflow will be triggered under the following conditions: 1. If the PR has ciflow/torchao label 2. Manual trigger It will run the torchao benchmark on torchbench/timm/huggingface model workloads with 5 configs (noquant, autoquant, int8dynamic, int8weightonly, int4weightonly). The output will be updated to the PT2 Dashboard: https://hud.pytorch.org/benchmark/compilers Pull Request resolved: https://github.com/pytorch/pytorch/pull/126885 Approved by: https://github.com/huydhn --- .ci/pytorch/common_utils.sh | 11 ++++ .ci/pytorch/test.sh | 106 +++++++++++++++++++++++++++++++++- .github/pytorch-probot.yml | 1 + .github/workflows/torchao.yml | 85 +++++++++++++++++++++++++++ 4 files changed, 200 insertions(+), 3 deletions(-) create mode 100644 .github/workflows/torchao.yml diff --git a/.ci/pytorch/common_utils.sh b/.ci/pytorch/common_utils.sh index 51297f7bfff8..71e98cfaa721 100644 --- a/.ci/pytorch/common_utils.sh +++ b/.ci/pytorch/common_utils.sh @@ -158,6 +158,17 @@ function install_torchvision() { fi } +function install_torchao() { + # Set ARCH list so that we can build fp16 with SM75+, the logic is copied from + # pytorch/builder + # https://github.com/pytorch/ao/blob/main/packaging/env_var_script_linux.sh#L16C1-L19 + TORCH_CUDA_ARCH_LIST="8.0;8.6" + if [[ ${CU_VERSION:-} == "cu124" ]]; then + TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};9.0" + fi + pip_install --no-use-pep517 --user "git+https://github.com/pytorch/ao.git" +} + function install_tlparse() { pip_install --user "tlparse==0.3.7" PATH="$(python -m site --user-base)/bin:$PATH" diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 190f99204e9c..76d7e259f365 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -483,6 +483,89 @@ test_perf_for_dashboard() { done } +test_torchao_perf_for_dashboard() { + TEST_REPORTS_DIR=$(pwd)/test/test-reports + mkdir -p "$TEST_REPORTS_DIR" + + local suite="$1" + shift + + local backend=torchao + local modes=() + if [[ "$DASHBOARD_TAG" == *training-true* ]]; then + modes+=(training) + fi + if [[ "$DASHBOARD_TAG" == *inference-true* ]]; then + modes+=(inference) + fi + # TODO: All the accuracy tests can be skipped once the CI accuracy checking is stable enough + local targets=(accuracy performance) + local torchao_backends=(noquant int8dynamic int8weightonly int4weightonly autoquant) + + for mode in "${modes[@]}"; do + if [[ "$mode" == "inference" ]]; then + dtype=bfloat16 + elif [[ "$mode" == "training" ]]; then + dtype=amp + fi + for target in "${targets[@]}"; do + local target_flag=("--${target}") + if [[ "$target" == "performance" ]]; then + target_flag+=( --cold-start-latency) + elif [[ "$target" == "accuracy" ]]; then + target_flag+=( --no-translation-validation) + fi + + for torchao_backend in "${torchao_backends[@]}"; do + if [[ "$DASHBOARD_TAG" == *${torchao_backend}-true* ]]; then + python "benchmarks/dynamo/$suite.py" \ + "${target_flag[@]}" --"$mode" --"$dtype" --quantization "${torchao_backend}" "$@" \ + --output "$TEST_REPORTS_DIR/${backend}_${torchao_backend}_${suite}_${dtype}_${mode}_cuda_${target}.csv" + fi + done + done + done +} + +test_single_torchao_benchmark() { + # Usage: test_single_torchao_benchmark huggingface 0 --args-for-script + + # Use test-reports directory under test folder will allow the CI to automatically pick up + # the test reports and upload them to S3. Need to use full path here otherwise the script + # will bark about file not found later on + TEST_REPORTS_DIR=$(pwd)/test/test-reports + mkdir -p "$TEST_REPORTS_DIR" + + local name="$1" + shift + local suite="$1" + shift + # shard id is mandatory, even if it is not passed + local shard_id="$1" + shift + + local partition_flags=() + if [[ -n "$NUM_TEST_SHARDS" && -n "$shard_id" ]]; then + partition_flags=( --total-partitions "$NUM_TEST_SHARDS" --partition-id "$shard_id" ) + fi + + test_torchao_perf_for_dashboard "$suite" \ + "${TORCHAO_BENCHMARK_FLAGS[@]}" "$@" "${partition_flags[@]}" + +} + +test_torchao_benchmark() { + # Usage: test_torchao_benchmark huggingface 0 + TEST_REPORTS_DIR=$(pwd)/test/test-reports + + local suite="$1" + shift + local shard_id="$1" + shift + + test_single_torchao_benchmark "inference" "$suite" "$shard_id" --inference --bfloat16 "$@" +} + test_single_dynamo_benchmark() { # Usage: test_single_dynamo_benchmark inductor_inference huggingface 0 --args-for-script @@ -1220,15 +1303,15 @@ elif [[ "${TEST_CONFIG}" == *inductor_distributed* ]]; then test_inductor_distributed elif [[ "${TEST_CONFIG}" == *inductor-micro-benchmark* ]]; then test_inductor_micro_benchmark -elif [[ "${TEST_CONFIG}" == *huggingface* ]]; then +elif [[ "${TEST_CONFIG}" == *inductor_huggingface* ]]; then install_torchvision id=$((SHARD_NUMBER-1)) test_dynamo_benchmark huggingface "$id" -elif [[ "${TEST_CONFIG}" == *timm* ]]; then +elif [[ "${TEST_CONFIG}" == *inductor_timm* ]]; then install_torchvision id=$((SHARD_NUMBER-1)) test_dynamo_benchmark timm_models "$id" -elif [[ "${TEST_CONFIG}" == *torchbench* ]]; then +elif [[ "${TEST_CONFIG}" == *inductor_torchbench* ]]; then if [[ "${TEST_CONFIG}" == *cpu_inductor* ]]; then install_torchaudio cpu else @@ -1259,6 +1342,23 @@ elif [[ "${TEST_CONFIG}" == *torchbench* ]]; then fi PYTHONPATH=$(pwd)/torchbench test_dynamo_benchmark torchbench "$id" fi +elif [[ "${TEST_CONFIG}" == *torchao_huggingface* ]]; then + install_torchao + install_torchvision + id=$((SHARD_NUMBER-1)) + test_torchao_benchmark huggingface "$id" +elif [[ "${TEST_CONFIG}" == *torchao_timm* ]]; then + install_torchao + install_torchvision + id=$((SHARD_NUMBER-1)) + test_torchao_benchmark timm_models "$id" +elif [[ "${TEST_CONFIG}" == *torchao_torchbench* ]]; then + install_torchao + install_torchaudio cuda + install_torchvision + id=$((SHARD_NUMBER-1)) + checkout_install_torchbench + PYTHONPATH=$(pwd)/torchbench test_torchao_benchmark torchbench "$id" elif [[ "${TEST_CONFIG}" == *inductor_cpp_wrapper_abi_compatible* ]]; then install_torchvision test_inductor_cpp_wrapper_abi_compatible diff --git a/.github/pytorch-probot.yml b/.github/pytorch-probot.yml index d54346f81650..ab5cb0deba87 100644 --- a/.github/pytorch-probot.yml +++ b/.github/pytorch-probot.yml @@ -18,6 +18,7 @@ ciflow_push_tags: - ciflow/unstable - ciflow/xpu - ciflow/torchbench +- ciflow/torchao retryable_workflows: - pull - trunk diff --git a/.github/workflows/torchao.yml b/.github/workflows/torchao.yml new file mode 100644 index 000000000000..0854eb099e92 --- /dev/null +++ b/.github/workflows/torchao.yml @@ -0,0 +1,85 @@ +name: torchao + +on: + push: + tags: + - ciflow/torchao/* + workflow_dispatch: + inputs: + noquant: + description: Run noquant? + required: false + type: boolean + default: true + int8dynamic: + description: Run int8dynamic? + required: false + type: boolean + default: true + int8weightonly: + description: Run int8weightonly? + required: false + type: boolean + default: true + int4weightonly: + description: Run int4weightonly? + required: false + type: boolean + default: true + autoquant: + description: Run autoquant? + required: false + type: boolean + default: true + benchmark_configs: + description: The list of configs used the benchmark + required: false + type: string + default: torchao_huggingface_perf,torchao_timm_perf,torchao_torchbench_perf + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} + cancel-in-progress: true + +permissions: read-all + +jobs: + linux-focal-cuda12_1-py3_10-gcc9-torchao-build: + name: cuda12.1-py3.10-gcc9-sm80 + uses: ./.github/workflows/_linux-build.yml + with: + build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm80 + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9-inductor-benchmarks + cuda-arch-list: '8.0' + test-matrix: | + { include: [ + { config: "torchao_huggingface_perf", shard: 1, num_shards: 3, runner: "linux.gcp.a100.large" }, + { config: "torchao_huggingface_perf", shard: 2, num_shards: 3, runner: "linux.gcp.a100.large" }, + { config: "torchao_huggingface_perf", shard: 3, num_shards: 3, runner: "linux.gcp.a100.large" }, + { config: "torchao_timm_perf", shard: 1, num_shards: 5, runner: "linux.gcp.a100.large" }, + { config: "torchao_timm_perf", shard: 2, num_shards: 5, runner: "linux.gcp.a100.large" }, + { config: "torchao_timm_perf", shard: 3, num_shards: 5, runner: "linux.gcp.a100.large" }, + { config: "torchao_timm_perf", shard: 4, num_shards: 5, runner: "linux.gcp.a100.large" }, + { config: "torchao_timm_perf", shard: 5, num_shards: 5, runner: "linux.gcp.a100.large" }, + { config: "torchao_torchbench_perf", shard: 1, num_shards: 4, runner: "linux.gcp.a100.large" }, + { config: "torchao_torchbench_perf", shard: 2, num_shards: 4, runner: "linux.gcp.a100.large" }, + { config: "torchao_torchbench_perf", shard: 3, num_shards: 4, runner: "linux.gcp.a100.large" }, + { config: "torchao_torchbench_perf", shard: 4, num_shards: 4, runner: "linux.gcp.a100.large" }, + ]} + selected-test-configs: ${{ inputs.benchmark_configs }} + secrets: + HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} + + linux-focal-cuda12_1-py3_10-gcc9-torchao-test: + name: cuda12.1-py3.10-gcc9-sm80 + uses: ./.github/workflows/_linux-test.yml + needs: linux-focal-cuda12_1-py3_10-gcc9-torchao-build + with: + build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm80 + dashboard-tag: noquant-${{ inputs.noquant }}-int8dynamic-${{ inputs.int8dynamic }}-int8weightonly-${{ inputs.int8weightonly }}-int4weightonly-${{ inputs.int4weightonly }}-autoquant-${{ inputs.autoquant }} + docker-image: ${{ needs.linux-focal-cuda12_1-py3_10-gcc9-torchao-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-cuda12_1-py3_10-gcc9-torchao-build.outputs.test-matrix }} + use-gha: anything-non-empty-to-use-gha + timeout-minutes: 720 + secrets: + HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} From 9cc0d56fdc4be3873582aa23e49b051aa27a6d2d Mon Sep 17 00:00:00 2001 From: cyy Date: Wed, 29 May 2024 18:30:48 +0000 Subject: [PATCH 048/706] Remove unused variables in tests (#127379) Reland test fixes in #127161 and reduce reduce_ops_test into floating point types. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127379 Approved by: https://github.com/ezyang --- .../core/boxing/impl/kernel_lambda_legacy_test.cpp | 12 ++---------- .../core/op_registration/op_registration_test.cpp | 1 - aten/src/ATen/test/pow_test.cpp | 8 +------- aten/src/ATen/test/reduce_ops_test.cpp | 5 ++--- aten/src/ATen/test/scalar_test.cpp | 1 - 5 files changed, 5 insertions(+), 22 deletions(-) diff --git a/aten/src/ATen/core/boxing/impl/kernel_lambda_legacy_test.cpp b/aten/src/ATen/core/boxing/impl/kernel_lambda_legacy_test.cpp index 39dceafab006..8db6abad6c33 100644 --- a/aten/src/ATen/core/boxing/impl/kernel_lambda_legacy_test.cpp +++ b/aten/src/ATen/core/boxing/impl/kernel_lambda_legacy_test.cpp @@ -731,8 +731,7 @@ TEST(OperatorRegistrationTestLegacyLambdaBasedKernel, givenFallbackKernelWithout } TEST(OperatorRegistrationTestLegacyLambdaBasedKernel, givenKernelWithOptionalInputs_withoutOutput_whenRegistered_thenCanBeCalled) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - bool called; + bool called = false; std::optional called_arg2 = c10::nullopt; std::optional called_arg3 = c10::nullopt; std::optional called_arg4 = c10::nullopt; @@ -771,8 +770,7 @@ TEST(OperatorRegistrationTestLegacyLambdaBasedKernel, givenKernelWithOptionalInp } TEST(OperatorRegistrationTestLegacyLambdaBasedKernel, givenKernelWithOptionalInputs_withOutput_whenRegistered_thenCanBeCalled) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - bool called; + bool called = false; std::optional called_arg2 = c10::nullopt; std::optional called_arg3 = c10::nullopt; std::optional called_arg4 = c10::nullopt; @@ -814,12 +812,6 @@ TEST(OperatorRegistrationTestLegacyLambdaBasedKernel, givenKernelWithOptionalInp } TEST(OperatorRegistrationTestLegacyLambdaBasedKernel, givenKernelWithOptionalInputs_withMultipleOutputs_whenRegistered_thenCanBeCalled) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - bool called; - std::optional called_arg2 = c10::nullopt; - std::optional called_arg3 = c10::nullopt; - std::optional called_arg4 = c10::nullopt; - auto registrar = RegisterOperators().op( "_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? arg4) -> (Tensor?, int?, str?)", [] (Tensor arg1, const std::optional& arg2, std::optional arg3, std::optional arg4) { diff --git a/aten/src/ATen/core/op_registration/op_registration_test.cpp b/aten/src/ATen/core/op_registration/op_registration_test.cpp index 707269de902e..5d12aa3d35d7 100644 --- a/aten/src/ATen/core/op_registration/op_registration_test.cpp +++ b/aten/src/ATen/core/op_registration/op_registration_test.cpp @@ -307,7 +307,6 @@ void stackBasedKernel(const OperatorHandle&, c10::Stack* stack) { } TEST(OperatorRegistrationTest, whenRegisteringMultipleKernelsByNameAndNoneCanInferSchema_thenFails) { - bool called_kernel = false; expectThrows([&] { auto registrar1 = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options() .kernel<&stackBasedKernel>(c10::DispatchKey::CPU) diff --git a/aten/src/ATen/test/pow_test.cpp b/aten/src/ATen/test/pow_test.cpp index fb3b073f29f3..95bb48b341f5 100644 --- a/aten/src/ATen/test/pow_test.cpp +++ b/aten/src/ATen/test/pow_test.cpp @@ -10,12 +10,6 @@ #include #include -#ifdef _WIN32 -#define DISABLED_ON_WINDOWS(x) DISABLED_##x -#else -#define DISABLED_ON_WINDOWS(x) x -#endif - using namespace at; namespace { @@ -204,7 +198,7 @@ void tensor_pow_tensor(const Vals vals, c10::ScalarType vals_dtype, Pows pows, c std::cout.precision(dbl::max_digits10); const auto vals_tensor = torch::tensor(vals, vals_dtype); - for (const auto shift : c10::irange(pows.size())) { + for ([[maybe_unused]] const auto shirt : c10::irange(pows.size())) { const auto pows_tensor = torch::tensor(pows, pows_dtype); const auto actual_pow = vals_tensor.pow(pows_tensor); diff --git a/aten/src/ATen/test/reduce_ops_test.cpp b/aten/src/ATen/test/reduce_ops_test.cpp index bcae3fdc51f9..a9ce7e4cf8f4 100644 --- a/aten/src/ATen/test/reduce_ops_test.cpp +++ b/aten/src/ATen/test/reduce_ops_test.cpp @@ -9,9 +9,8 @@ TEST(ReduceOpsTest, MaxValuesAndMinValues) { const int W = 10; const int H = 10; if (hasCUDA()) { - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - for (const auto dtype : {kHalf, kFloat, kDouble, kShort, kInt, kLong}) { - auto a = at::rand({H, W}, TensorOptions(kCUDA).dtype(at::kHalf)); + for (const auto dtype : {kHalf, kFloat, kDouble}) { + auto a = at::rand({H, W}, TensorOptions(kCUDA).dtype(dtype)); ASSERT_FLOAT_EQ( a.amax(c10::IntArrayRef{0, 1}).item(), a.max().item() diff --git a/aten/src/ATen/test/scalar_test.cpp b/aten/src/ATen/test/scalar_test.cpp index c10e8386d683..0d7b62b44d21 100644 --- a/aten/src/ATen/test/scalar_test.cpp +++ b/aten/src/ATen/test/scalar_test.cpp @@ -82,7 +82,6 @@ TEST(TestScalar, TestScalar) { // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) ASSERT_NO_THROW(gen.set_current_seed(std::random_device()())); } - auto&& C = at::globalContext(); if (at::hasCUDA()) { auto t2 = zeros({4, 4}, at::kCUDA); cout << &t2 << "\n"; From ff65b18fcfb4bb3867cb31946fce76498031236c Mon Sep 17 00:00:00 2001 From: lancerts Date: Wed, 29 May 2024 18:53:14 +0000 Subject: [PATCH 049/706] Update the is_causal explaination in the SDPA doc (#127209) Fixes #126873 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127209 Approved by: https://github.com/drisspg --- torch/nn/functional.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 805e0b40cdd2..8d0d43087b2c 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -5088,8 +5088,10 @@ def forward(self, ...): A boolean mask where a value of True indicates that the element *should* take part in attention. A float mask of the same type as query, key, value that is added to the attention score. dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied - is_causal (bool): If true, assumes upper left causal attention masking and errors if both attn_mask and is_causal - are set. + is_causal (bool): If set to true, the attention masking is a lower triangular matrix when the mask is a + square matrix. The attention masking has the form of the upper left causal bias due to the alignment + (see :class:`torch.nn.attention.bias.CausalBias`) when the mask is a non-square matrix. + An error is thrown if both attn_mask and is_causal are set. scale (optional float, keyword-only): Scaling factor applied prior to softmax. If None, the default value is set to :math:`\frac{1}{\sqrt{E}}`. From 5196ef1b59382e86d7636312872ca78d551fe2aa Mon Sep 17 00:00:00 2001 From: laithsakka Date: Wed, 29 May 2024 11:55:32 -0700 Subject: [PATCH 050/706] support builtin id function on user defined object variables. (#127146) Fix: https://github.com/pytorch/pytorch/pull/127146 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127146 Approved by: https://github.com/anijain2305 ghstack dependencies: #126444 --- test/dynamo/test_misc.py | 53 ++++++++++++++++++++++++++++++ torch/_dynamo/variables/builtin.py | 6 ++++ 2 files changed, 59 insertions(+) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 6a44c5603dc9..abc4f52dfbfb 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -4231,6 +4231,59 @@ def fn_has_breaks(x): opt_fn(x) self.assertEqual(cnts.frame_count, 2) + def test_id_guarded_object(self): + class UDO: + @torch.compile(backend="eager") + def call(self, x, ref_id): + self_id = id(self) + if self_id == ref_id: + x = torch.mul(x, 1.0) + else: + x = torch.mul(x, 0) + return x + + # Make sure we do recompile when id(self) is executed on + # different self objects. + x = torch.ones(2) + obj1 = UDO() + obj1_id = id(obj1) + self.assertEqual(obj1.call(x, obj1_id), torch.ones(2)) + + obj2 = UDO() + # if we do not install ID_MATCH: ___check_obj_id(L['self'], xxx) this fails. + self.assertEqual(obj2.call(x, obj1_id), torch.zeros(2)) + + def test_id_guarded_module(self): + class M(torch.nn.Module): + def forward(self, x, ref_id): + self_id = id(self) + if self_id == ref_id: + x = torch.mul(x, 1.0) + else: + x = torch.mul(x, 0) + return x + + cnts = torch._dynamo.testing.CompileCounter() + + # Make sure we do recompile when id(self) is executed on + # different self objects. + x = torch.ones(2) + m1 = M() + m1_id = id(m1) + opt_m1 = torch._dynamo.optimize(cnts, nopython=True)(m1) + self.assertEqual(opt_m1(x, m1_id), torch.ones(2)) + self.assertEqual(opt_m1(x, m1_id), torch.ones(2)) + + self.assertEqual(cnts.frame_count, 1) + self.assertEqual(cnts.op_count, 1) + + m2 = M() + opt_m2 = torch._dynamo.optimize(cnts, nopython=True)(m2) + # if we do not install ID_MATCH: ___check_obj_id(L['self'], xxx) this fails. + self.assertEqual(opt_m2(x, m1_id), torch.zeros(2)) + self.assertEqual(cnts.frame_count, 2) + self.assertEqual(cnts.op_count, 2) + def test_id_of_nn_module(self): class M(torch.nn.Module): def forward(self, x, ref_id): diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 605f56b3047d..306a17c018f9 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1813,6 +1813,12 @@ def call_id(self, tx, *args): nn_mod_variable = args[0] mod = tx.output.get_submodule(nn_mod_variable.module_key) return variables.ConstantVariable.create(id(mod)) + elif len(args) == 1 and isinstance( + args[0], variables.UserDefinedObjectVariable + ): + install_guard(args[0].source.make_guard(GuardBuilder.ID_MATCH)) + constant_result = id(args[0].value) + return variables.ConstantVariable.create(constant_result) else: unimplemented(f"call_id with args {args}") From 90f4b3fcb24bfd892d7e23f79a457443b79c6d39 Mon Sep 17 00:00:00 2001 From: Nikita Shulga <2453524+malfet@users.noreply.github.com> Date: Wed, 29 May 2024 19:08:20 +0000 Subject: [PATCH 051/706] PyTorch Distributed security assumptions (#127403) To highlight, that PyTorch Distributed should only be used in a trusted environment and never on the nodes with open network access, which is very similar in spirit to https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md#running-a-tensorflow-server Thanks to @Xbalien and @K1ingzzz for drawing attention to missing documentation on distributed workloads security assumptions Pull Request resolved: https://github.com/pytorch/pytorch/pull/127403 Approved by: https://github.com/wconstab --- SECURITY.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/SECURITY.md b/SECURITY.md index e8e0249fc896..a6f676ef39be 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -5,6 +5,7 @@ - [Untrusted models](#untrusted-models) - [Untrusted inputs](#untrusted-inputs) - [Data privacy](#data-privacy) + - [Using distributed features](#using-distributed-features) ## Reporting Security Issues @@ -54,3 +55,9 @@ If applicable, prepare your model against bad inputs and prompt injections. Some **Take special security measures if your model if you train models with sensitive data**. Prioritize [sandboxing](https://developers.google.com/code-sandboxing) your models and: - Do not feed sensitive data to untrusted model (even if runs in a sandboxed environment) - If you consider publishing a model that was partially trained with sensitive data, be aware that data can potentially be recovered from the trained weights (especially if model overfits). + +### Using distributed features + +PyTorch can be used for distributed computing, and as such there is a `torch.distributed` package. PyTorch Distributed features are intended for internal communication only. They are not built for use in untrusted environments or networks. + +For performance reasons, none of the PyTorch Distributed primitives (including c10d, RPC, and TCPStore) include any authorization protocol and will send messages unencrypted. They accept connections from anywhere, and execute the workload sent without performing any checks. Therefore, if you run a PyTorch Distributed program on your network, anybody with access to the network can execute arbitrary code with the privileges of the user running PyTorch. From 601c5e085df3cf0d769911e7fa025aa114a1a2b3 Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Wed, 29 May 2024 08:18:19 -0700 Subject: [PATCH 052/706] Add _foreach_max (#127187) This PR adds _foreach_max support, the second reduction foreach op we have :D I did have to change the autogen slightly for foreach. I can promise that the existing foreach ops' derivative behavior has not changed as I've added a skip list for the harder requirement I am setting (that the arg list should match in length). I needed to add this requirement as there is another wrong max (the one that does take in a dim for reduction) that keeps getting matched first. Caveats! - We do not fast path if the shapes, dtypes, device, the regular shebang for foreach are not met. We fall back to slowpath! - MORE IMPORTANTLY, we also do not fast path for int8 and int16 and bool, but that's really a skill issue on my end as I've hardcoded -INFINITY into the CUDA kernels, and -INFINITY is not defined for small ints. It'd be nice to know how to do this properly, but that work can also come later. - This does NOT support empty Tensors in the list, because the original max op also does not support empty Tensors. ~I think this should be allowed though, and this PR may come later.~ I understand why this is not allowed. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127187 Approved by: https://github.com/albanD --- aten/src/ATen/native/ForeachOpsKernels.cpp | 11 ++ aten/src/ATen/native/cuda/ForeachReduceOp.cu | 186 ++++++++++++++++++ aten/src/ATen/native/cuda/block_reduce.cuh | 2 +- aten/src/ATen/native/native_functions.yaml | 8 + ...asDecompTest.test_has_decomposition.expect | 2 + test/test_foreach.py | 42 ++-- torch/_meta_registrations.py | 1 + .../_internal/common_methods_invocations.py | 60 +++++- torchgen/api/autograd.py | 16 ++ 9 files changed, 314 insertions(+), 14 deletions(-) diff --git a/aten/src/ATen/native/ForeachOpsKernels.cpp b/aten/src/ATen/native/ForeachOpsKernels.cpp index 34c71a886862..9656e2aa4f72 100644 --- a/aten/src/ATen/native/ForeachOpsKernels.cpp +++ b/aten/src/ATen/native/ForeachOpsKernels.cpp @@ -35,6 +35,7 @@ #include #include #include +#include #include #include #include @@ -55,6 +56,7 @@ #include #include #include +#include #include #include #include @@ -448,6 +450,15 @@ std::vector foreach_tensor_norm_slow( return result; } +std::vector foreach_tensor_max_slow(TensorList tensors) { + check_foreach_api_restrictions(tensors); + std::vector result; + for (const auto& t : tensors) { + result.emplace_back(at::max(t)); + } + return result; +} + std::vector foreach_scalar_pow_list_kernel_slow( const Scalar& self, TensorList exponent) { diff --git a/aten/src/ATen/native/cuda/ForeachReduceOp.cu b/aten/src/ATen/native/cuda/ForeachReduceOp.cu index 885c5d021e8c..04b7c12e9a1a 100644 --- a/aten/src/ATen/native/cuda/ForeachReduceOp.cu +++ b/aten/src/ATen/native/cuda/ForeachReduceOp.cu @@ -16,6 +16,7 @@ #include #include #else +#include #include #include @@ -44,6 +45,191 @@ struct TensorListAddresses { const void* addresses[MAX_TENSORS_PER_KERNEL]; }; +template < + typename T, + int depth = 1, + int r_args_depth = 1, + int res_arg_index = 0> +struct LpMaxFunctor { + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListMetadata& tl, + T* output_per_tensor_ptr, + const int max_chunks_per_tensor) { + const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; + const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* x = (T*)tl.addresses[0][tensor_loc]; + x += chunk_idx * chunk_size; + n -= chunk_idx * chunk_size; + + __shared__ T s_vals[512]; + T vals[kILP]; + T r_x[kILP]; + for (int64_t i = 0; i < kILP; i++) { + vals[i] = T(-INFINITY); + r_x[i] = T(-INFINITY); + } + + if (n % kILP == 0 && (chunk_size & kILP) == 0 && is_aligned(x)) { + for (int64_t i_start = threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_x, x, 0, i_start); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + vals[ii] = max_propagate_nan(vals[ii], r_x[ii]); + } + } + } else { + for (int64_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + vals[ii] = max_propagate_nan(vals[ii], x[i]); + } + } + } + } + + auto val = T(-INFINITY); + for (int i = 0; i < kILP; i++) { + val = max_propagate_nan(val, vals[i]); + } + auto final_val = at::native::cuda_utils::BlockReduceMax(val, s_vals); + + if (threadIdx.x == 0) { + output_per_tensor_ptr + [(tl.start_tensor_this_launch + tensor_loc) * max_chunks_per_tensor + + chunk_idx] = final_val; + } + } +}; + +template +__global__ void lpmax_cleanup( + const T* output_per_tensor, + TensorListAddresses addr_struct, + int max_chunks_per_tensor) { + __shared__ T vals[512]; + const T* output_this_tensor = + output_per_tensor + blockIdx.x * max_chunks_per_tensor; + T val = -INFINITY; + for (size_t i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x) { + val = max_propagate_nan(val, output_this_tensor[i]); + } + T final_val = at::native::cuda_utils::BlockReduceMax(val, vals); + if (threadIdx.x == 0) { + *(T*)addr_struct.addresses[blockIdx.x] = final_val; + } +} + +std::vector foreach_tensor_max_cuda(TensorList tensors) { + check_foreach_api_restrictions(tensors); + // we currently use -INF as the identity value to compare against, which + // does not work for int8, int16, nor bool. Fall back to slow path here. + const bool has_small_int_or_bool = + std::any_of(tensors.begin(), tensors.end(), [](const auto& t) { + const auto scalar_type = t.scalar_type(); + return scalar_type == at::ScalarType::Short || + scalar_type == at::ScalarType::Char || + scalar_type == at::ScalarType::Bool; + }); + if (!can_use_fast_route(tensors) || has_small_int_or_bool) { + return foreach_tensor_max_slow(tensors); + } + + // for parity with max in ReduceAllOps.cpp, though I think max(empty) should + // eventually be allowed. + TORCH_CHECK( + std::all_of( + tensors.begin(), + tensors.end(), + [](const auto& t) { return t.numel() > 0; }), + "max(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument."); + + const size_t ntensors = tensors.size(); + int max_chunks_per_tensor = -1; + + for (const auto t : c10::irange(ntensors)) { + int max_chunks_this_tensor = + (tensors[t].numel() + kChunkSize - 1) / kChunkSize; + if (max_chunks_this_tensor > max_chunks_per_tensor) { + max_chunks_per_tensor = max_chunks_this_tensor; + } + } + const auto options = tensors[0].options(); + auto output_per_tensor = at::zeros( + {static_cast(ntensors) * max_chunks_per_tensor}, options); + + std::vector vec_res; + vec_res.reserve(ntensors); + for (const auto i : c10::irange(ntensors)) { + vec_res.push_back(at::empty({}, options)); + } + + auto tensor_lists = std::vector>{tensors.vec()}; + + AT_DISPATCH_ALL_TYPES_AND3( + kHalf, + kBFloat16, + kBool, + tensor_lists[0][0].scalar_type(), + "foreach_tensor_max_cuda_scalar_type", + [&]() { + multi_tensor_apply<1>( + tensor_lists, + LpMaxFunctor(), + output_per_tensor.mutable_data_ptr(), + max_chunks_per_tensor); + + C10_CUDA_KERNEL_LAUNCH_CHECK(); + const at::cuda::OptionalCUDAGuard device_guard( + device_of(output_per_tensor)); + auto stream = at::cuda::getCurrentCUDAStream(); + + const size_t num_kernels = ceil_div(ntensors, MAX_TENSORS_PER_KERNEL); + for (const auto i : c10::irange(num_kernels)) { + const size_t num_tensors_this_kernel = + (i < num_kernels - 1 || ntensors % MAX_TENSORS_PER_KERNEL == 0) + ? MAX_TENSORS_PER_KERNEL + : (ntensors % MAX_TENSORS_PER_KERNEL); + + TensorListAddresses addr_struct; + for (const auto j : c10::irange(num_tensors_this_kernel)) { + addr_struct.addresses[j] = vec_res[i * MAX_TENSORS_PER_KERNEL + j] + .mutable_data_ptr(); + } + + lpmax_cleanup<<>>( + output_per_tensor.const_data_ptr() + + i * MAX_TENSORS_PER_KERNEL * max_chunks_per_tensor, + addr_struct, + max_chunks_per_tensor); + } + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + + // correctly assign values to only non-empty slots, as the empty slots should + // get skipped + std::vector result; + result.reserve(ntensors); + int i = 0; + for (const auto& t : tensors) { + if (t.numel() != 0) { + result.emplace_back(vec_res[i]); + i++; + } else { + result.emplace_back(at::empty({}, options)); + } + } + return result; +} + template < typename T, NormType norm_type, diff --git a/aten/src/ATen/native/cuda/block_reduce.cuh b/aten/src/ATen/native/cuda/block_reduce.cuh index e8fd69c0aec9..c1e003ca8e53 100644 --- a/aten/src/ATen/native/cuda/block_reduce.cuh +++ b/aten/src/ATen/native/cuda/block_reduce.cuh @@ -103,7 +103,7 @@ __inline__ __device__ T BlockReduceMax(T val, T* shared) { shared[wid] = val; } __syncthreads(); - val = (tid < B::Warps()) ? shared[lid] : T(0); + val = (tid < B::Warps()) ? shared[lid] : T(-INFINITY); if (wid == 0) { val = WarpReduceMax(val); } diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index e396ccb67672..41968a72fd8e 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -11119,6 +11119,14 @@ CUDA: foreach_tensor_log2_cuda_ autogen: _foreach_log2.out +- func: _foreach_max(Tensor[] self) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CPU: foreach_tensor_max_slow + CUDA: foreach_tensor_max_cuda + autogen: _foreach_max.out + - func: _foreach_neg(Tensor[] self) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index 669c3d91e849..f9bf58c5f474 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -248,6 +248,8 @@ aten::_foreach_log2 aten::_foreach_log2.out aten::_foreach_log2_ aten::_foreach_log_ +aten::_foreach_max +aten::_foreach_max.out aten::_foreach_maximum.List aten::_foreach_maximum.List_out aten::_foreach_maximum.Scalar diff --git a/test/test_foreach.py b/test/test_foreach.py index c46ff8ae21b6..58595b628dc7 100644 --- a/test/test_foreach.py +++ b/test/test_foreach.py @@ -88,7 +88,9 @@ def __call__(self, inputs, is_cuda, expect_fastpath, **kwargs): actual = self.func(*inputs, **kwargs) keys = tuple([e.key for e in p.key_averages()]) mta_called = any("multi_tensor_apply_kernel" in k for k in keys) - assert mta_called == (expect_fastpath and (not zero_size)) + assert mta_called == ( + expect_fastpath and (not zero_size) + ), f"{mta_called=}, {expect_fastpath=}, {zero_size=}" else: actual = self.func(*inputs, **kwargs) if self.is_inplace: @@ -922,7 +924,10 @@ def test_pointwise_op_tensors_on_different_devices(self, device, dtype, op): # note: BFloat16 has the same number of exponent bits as FP32 # so if squared L2 norm overflows in BF16, then it also overflows in FP32. @onlyCUDA - @ops(foreach_reduce_op_db, allowed_dtypes=(torch.half, torch.bfloat16)) + @ops( + [o for o in foreach_reduce_op_db if "norm" in o.name], + allowed_dtypes=(torch.half, torch.bfloat16), + ) def test_foreach_l2_large_value_input(self, device, dtype, op): ord, N = 2, 10 max_value = torch.finfo(dtype).max @@ -976,14 +981,20 @@ def test_big_num_tensors(self, device, dtype, op, use_cuda_graph): import math - for ord in (1, 2, math.inf): + if op.name == "_foreach_norm": + ords = (1, 2, math.inf) + else: + ords = (None,) + + for ord in ords: + kwargs = {"ord": ord} if ord else {} if not use_cuda_graph: actual = fn( inputs=[tensorlist], is_cuda=True, expect_fastpath=True, - ord=ord, zero_size=False, + **kwargs, ) else: # When using CUDA graphs and the tensor metadata doesn't fit in @@ -993,9 +1004,9 @@ def test_big_num_tensors(self, device, dtype, op, use_cuda_graph): # test verifies multi_tensor_apply's behavior in the scenario. g = torch.cuda.CUDAGraph() with torch.cuda.graph(g): - actual = fn.func(tensorlist, ord=ord) + actual = fn.func(tensorlist, **kwargs) g.replay() - expect = ref_fn(inputs=[tensorlist], ord=ord) + expect = ref_fn(inputs=[tensorlist], **kwargs) self.assertEqual(expect, actual, equal_nan=True) @@ -1003,16 +1014,23 @@ def test_big_num_tensors(self, device, dtype, op, use_cuda_graph): @ops(foreach_reduce_op_db) def test_foreach_reduce_large_input(self, device, dtype, op): # test inputs larger than kChunkSize = 65536 - ord, N = 2, 65536 * 2 - disable_fastpath = True - if ord in (1, 2) and dtype in floating_types_and(torch.half, torch.bfloat16): - disable_fastpath = False + N = 65536 * 2 + disable_fastpath = dtype in (torch.int8, torch.int16, torch.bool) + kwargs = {} + if op.name == "_foreach_norm": + ord = 2 + disable_fastpath = not ( + ord in (1, 2) + and dtype in floating_types_and(torch.half, torch.bfloat16) + ) + kwargs["ord"] = ord + inputs = ([make_tensor((N,), dtype=dtype, device=device, noncontiguous=False)],) wrapped_op, ref, _, _ = self._get_funcs(op) self.assertEqual( - ref(inputs, ord=ord), + ref(inputs, **kwargs), wrapped_op( - inputs, self.is_cuda, not disable_fastpath, ord=ord, zero_size=False + inputs, self.is_cuda, not disable_fastpath, zero_size=False, **kwargs ), ) diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 89b452bca505..d25866d0abb9 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -3176,6 +3176,7 @@ def register(op): aten._foreach_log10, aten._foreach_log1p, aten._foreach_log2, + aten._foreach_max, aten._foreach_neg, aten._foreach_reciprocal, aten._foreach_round, diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 62958b136364..4e17fcd5d277 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -9166,6 +9166,7 @@ def __init__( self._set_rightmost_arg_types( rightmost_supports_scalar, rightmost_supports_scalarlist, rightmost_supports_tensor, ) + self._intersperse_empty = (True, False) def _set_rightmost_arg_types( self, @@ -9330,7 +9331,7 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): # add empty tensor interspersion to test fully fixing #100701 for num_tensors, rightmost_arg_type, intersperse_empty_tensors in itertools.product( - num_input_tensors, self._rightmost_arg_types, (True, False)): + num_input_tensors, self._rightmost_arg_types, self._intersperse_empty): if intersperse_empty_tensors and (num_tensors != max(num_input_tensors) or str(device) == 'cpu'): # generate interspersed empty tensors for only 1 N on non-cpu device to lessen redundancy continue @@ -9364,6 +9365,24 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): ) +class foreach_max_sample_func(foreach_inputs_sample_func): + def __init__( + self, + arity: int, + rightmost_supports_scalar: bool, + rightmost_supports_scalarlist: bool, + rightmost_supports_tensor: bool = False, + ) -> None: + super().__init__(arity, rightmost_supports_scalar, rightmost_supports_scalarlist, rightmost_supports_tensor) + self._intersperse_empty = (False,) + + def sample_zero_size_tensor_inputs(self, opinfo, device, dtype, requires_grad, **kwargs): + return [] + + def _should_disable_fastpath(self, opinfo, rightmost_arg, rightmost_arg_type, dtype): + return dtype in (torch.int8, torch.int16, torch.bool) + + class foreach_norm_sample_func(foreach_inputs_sample_func): def sample_zero_size_tensor_inputs(self, opinfo, device, dtype, requires_grad, **kwargs): assert "num_input_tensors" not in kwargs @@ -11099,6 +11118,45 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): ] foreach_reduce_op_db: List[ForeachFuncInfo] = [ + ForeachFuncInfo( + "max", + sample_inputs_func=foreach_max_sample_func(1, False, False), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_autodiff", + dtypes=(torch.complex128, torch.complex64), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_foreach_reduce_large_input", + dtypes=(torch.complex128, torch.complex64), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace", + dtypes=(torch.complex128, torch.complex64), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_outplace", + dtypes=(torch.complex128, torch.complex64), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_outplace", + dtypes=(torch.complex128, torch.complex64), + ), + ), + ), ForeachFuncInfo( "norm", sample_inputs_func=foreach_norm_sample_func(1, False, False), diff --git a/torchgen/api/autograd.py b/torchgen/api/autograd.py index 1a55211b9990..10b011741d55 100644 --- a/torchgen/api/autograd.py +++ b/torchgen/api/autograd.py @@ -321,6 +321,17 @@ def is_foreach_func(f: NativeFunction) -> bool: "_foreach_mul.Tensor", "_foreach_div.Tensor", } +# The following do not support the alpha kwarg, which the nonforeach versions support. +_skip_argument_len_check = { + "_foreach_add.Scalar", + "_foreach_add_.Scalar", + "_foreach_add.ScalarList", + "_foreach_add_.ScalarList", + "_foreach_sub.Scalar", + "_foreach_sub_.Scalar", + "_foreach_sub.ScalarList", + "_foreach_sub_.ScalarList", +} # Checks if `function_schema` is a native, non-foreach function which `f`, a foreach function @@ -335,6 +346,11 @@ def is_reference_for_foreach( not function_schema.name.name.inplace or str(f.func.name) in _foreach_with_inplace_ref ) + and ( + str(f.func.name) in _skip_argument_len_check + or len(f.func.arguments.flat_non_out) + == len(function_schema.arguments.flat_non_out) + ) and all( ref_arg.type in (arg.type, getattr(arg.type, "elem", None)) for arg, ref_arg in zip( From 05e99154eef3d89ff6909f2653d2ec6cfdf26f90 Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Wed, 29 May 2024 08:18:19 -0700 Subject: [PATCH 053/706] Allow int vals to go down the fastpath for _foreach_max (#127303) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127303 Approved by: https://github.com/albanD ghstack dependencies: #127187 --- aten/src/ATen/native/cuda/ForeachReduceOp.cu | 22 +++++-------------- aten/src/ATen/native/cuda/block_reduce.cuh | 2 +- test/test_foreach.py | 2 +- .../_internal/common_methods_invocations.py | 3 ++- 4 files changed, 10 insertions(+), 19 deletions(-) diff --git a/aten/src/ATen/native/cuda/ForeachReduceOp.cu b/aten/src/ATen/native/cuda/ForeachReduceOp.cu index 04b7c12e9a1a..7c2a389351a2 100644 --- a/aten/src/ATen/native/cuda/ForeachReduceOp.cu +++ b/aten/src/ATen/native/cuda/ForeachReduceOp.cu @@ -68,8 +68,8 @@ struct LpMaxFunctor { T vals[kILP]; T r_x[kILP]; for (int64_t i = 0; i < kILP; i++) { - vals[i] = T(-INFINITY); - r_x[i] = T(-INFINITY); + vals[i] = T(std::numeric_limits::lowest()); + r_x[i] = T(std::numeric_limits::lowest()); } if (n % kILP == 0 && (chunk_size & kILP) == 0 && is_aligned(x)) { @@ -96,7 +96,7 @@ struct LpMaxFunctor { } } - auto val = T(-INFINITY); + auto val = T(std::numeric_limits::lowest()); for (int i = 0; i < kILP; i++) { val = max_propagate_nan(val, vals[i]); } @@ -118,7 +118,7 @@ __global__ void lpmax_cleanup( __shared__ T vals[512]; const T* output_this_tensor = output_per_tensor + blockIdx.x * max_chunks_per_tensor; - T val = -INFINITY; + T val = std::numeric_limits::lowest(); for (size_t i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x) { val = max_propagate_nan(val, output_this_tensor[i]); } @@ -130,21 +130,11 @@ __global__ void lpmax_cleanup( std::vector foreach_tensor_max_cuda(TensorList tensors) { check_foreach_api_restrictions(tensors); - // we currently use -INF as the identity value to compare against, which - // does not work for int8, int16, nor bool. Fall back to slow path here. - const bool has_small_int_or_bool = - std::any_of(tensors.begin(), tensors.end(), [](const auto& t) { - const auto scalar_type = t.scalar_type(); - return scalar_type == at::ScalarType::Short || - scalar_type == at::ScalarType::Char || - scalar_type == at::ScalarType::Bool; - }); - if (!can_use_fast_route(tensors) || has_small_int_or_bool) { + if (!can_use_fast_route(tensors)) { return foreach_tensor_max_slow(tensors); } - // for parity with max in ReduceAllOps.cpp, though I think max(empty) should - // eventually be allowed. + // for parity with max in ReduceAllOps.cpp, as max(empty) is ??? TORCH_CHECK( std::all_of( tensors.begin(), diff --git a/aten/src/ATen/native/cuda/block_reduce.cuh b/aten/src/ATen/native/cuda/block_reduce.cuh index c1e003ca8e53..df757a11761b 100644 --- a/aten/src/ATen/native/cuda/block_reduce.cuh +++ b/aten/src/ATen/native/cuda/block_reduce.cuh @@ -103,7 +103,7 @@ __inline__ __device__ T BlockReduceMax(T val, T* shared) { shared[wid] = val; } __syncthreads(); - val = (tid < B::Warps()) ? shared[lid] : T(-INFINITY); + val = (tid < B::Warps()) ? shared[lid] : T(std::numeric_limits::lowest()); if (wid == 0) { val = WarpReduceMax(val); } diff --git a/test/test_foreach.py b/test/test_foreach.py index 58595b628dc7..8465d538187c 100644 --- a/test/test_foreach.py +++ b/test/test_foreach.py @@ -1015,7 +1015,7 @@ def test_big_num_tensors(self, device, dtype, op, use_cuda_graph): def test_foreach_reduce_large_input(self, device, dtype, op): # test inputs larger than kChunkSize = 65536 N = 65536 * 2 - disable_fastpath = dtype in (torch.int8, torch.int16, torch.bool) + disable_fastpath = False kwargs = {} if op.name == "_foreach_norm": ord = 2 diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 4e17fcd5d277..9a83bf5b0038 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -9380,7 +9380,7 @@ def sample_zero_size_tensor_inputs(self, opinfo, device, dtype, requires_grad, * return [] def _should_disable_fastpath(self, opinfo, rightmost_arg, rightmost_arg_type, dtype): - return dtype in (torch.int8, torch.int16, torch.bool) + return False class foreach_norm_sample_func(foreach_inputs_sample_func): @@ -11125,6 +11125,7 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): supports_inplace_autograd=True, supports_forward_ad=True, decorators=( + # no complex support for ordering ops like max DecorateInfo( unittest.expectedFailure, "TestForeach", From 82a370ae3aacbbc1d8ec65deda0147999d055765 Mon Sep 17 00:00:00 2001 From: Sam Larsen Date: Tue, 28 May 2024 18:07:22 -0700 Subject: [PATCH 054/706] Revert "Refresh OpOverloadPacket if a new OpOverload gets added (#126863)" (#127366) This reverts commit ed734178abc99bc1d83ad2c61d3a1e4d4f5d20c8. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127366 Approved by: https://github.com/zou3519 --- test/jit/test_list_dict.py | 31 +------------------------------ test/test_custom_ops.py | 24 ------------------------ torch/_ops.py | 26 +++++--------------------- torch/library.py | 16 +--------------- 4 files changed, 7 insertions(+), 90 deletions(-) diff --git a/test/jit/test_list_dict.py b/test/jit/test_list_dict.py index 90fa24e43506..f3d314dbac77 100644 --- a/test/jit/test_list_dict.py +++ b/test/jit/test_list_dict.py @@ -5,7 +5,7 @@ import sys import types import unittest -from collections import defaultdict, OrderedDict +from collections import OrderedDict from textwrap import dedent from typing import Any, Dict, List, NamedTuple, Optional, Tuple @@ -2966,32 +2966,3 @@ def test_reference_semantics(self): self.assertEqual(len(l), 3) self.assertTrue(3 in l) self.assertEqual(l[2], 3) - - def test_defaultdict(self): - def get_dict(): - test_dict = defaultdict(list) - return test_dict - - class Test(torch.nn.Module): - segments_groupby_col: Dict[str, List[str]] - - def __init__(self): - super().__init__() - self.segments_groupby_col = get_dict() - self.col1 = "a" - self.col2 = "b" - - def forward(self): - if self.col1 in self.segments_groupby_col.keys(): - return 1 - else: - return 2 - - test = Test() - test_script = torch.jit.script(test) - test_script.segments_groupby_col - - # Smoketest for flakiness. Takes around 2s. - for i in range(300): - test = Test() - test_script = torch.jit.script(test) diff --git a/test/test_custom_ops.py b/test/test_custom_ops.py index e2af3efaa98a..1239ff8e0ebd 100644 --- a/test/test_custom_ops.py +++ b/test/test_custom_ops.py @@ -2850,30 +2850,6 @@ def f(x: Tensor) -> Tensor: y = f(x) self.assertEqual(y, x.sin()) - @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") - def test_overloading(self): - called_f = 0 - called_f1 = 0 - - @torch.library.custom_op("_torch_testing::f", mutates_args=()) - def f(x: Tensor) -> Tensor: - nonlocal called_f - called_f += 1 - return x.clone() - - x = torch.randn(2, 3) - torch.ops._torch_testing.f(x) - self.assertEqual(called_f, 1) - - @torch.library.custom_op("_torch_testing::f.overload", mutates_args=()) - def f1(x: Tensor, y: Tensor) -> Tensor: - nonlocal called_f1 - called_f1 += 1 - return x.clone() - - torch.ops._torch_testing.f(x, x) - self.assertEqual(called_f1, 1) - def test_disallows_output_aliasing(self): @torch.library.custom_op("_torch_testing::f", mutates_args=()) def f(x: Tensor) -> Tensor: diff --git a/torch/_ops.py b/torch/_ops.py index 83a7b6b849df..0b19c75a51aa 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -1161,10 +1161,8 @@ def __getattr__(self, op_name): # for overloads and raise an exception if there are more than one. namespace_name = self.name qualified_op_name = f"{namespace_name}::{op_name}" - module_name = self.__module__ + "." + namespace_name - try: - op, overload_names = _get_packet(qualified_op_name, module_name) + op, overload_names = torch._C._jit_get_operation(qualified_op_name) if op is None: raise AttributeError( f"'_OpNamespace' '{self.name}' object has no attribute '{op_name}'" @@ -1176,7 +1174,10 @@ def __getattr__(self, op_name): f"'_OpNamespace' '{self.name}' object has no attribute '{op_name}'" ) from e - op.__module__ = module_name + # let the script frontend know that op is identical to the builtin op + # with qualified_op_name + torch.jit._builtins._register_builtin(op, qualified_op_name) + op.__module__ = self.__module__ + "." + namespace_name opoverloadpacket = OpOverloadPacket( qualified_op_name, op_name, op, overload_names ) @@ -1188,23 +1189,6 @@ def __getattr__(self, op_name): return opoverloadpacket -def _get_packet(qualname, op_module): - op, overload_names = torch._C._jit_get_operation(qualname) - if op is not None: - # let the script frontend know that op is identical to the builtin op - # with qualified_op_name - torch.jit._builtins._register_builtin(op, qualname) - op.__module__ = op_module - return op, overload_names - - -def _refresh_packet(packet): - op, overload_names = _get_packet(packet._qualified_op_name, packet._op.__module__) - assert op is not None - packet._op = op - packet._overload_names = overload_names - - class _PyOpNamespace(_OpNamespace): def __init__(self, name, ops): super().__init__(name) diff --git a/torch/library.py b/torch/library.py index f771141ec436..48055da5b55c 100644 --- a/torch/library.py +++ b/torch/library.py @@ -109,22 +109,8 @@ def define(self, schema, alias_analysis="", *, tags=()): assert self.m is not None if isinstance(tags, torch.Tag): tags = (tags,) - - name = schema.split("(")[0] - packet_name = name.split(".")[0] if "." in name else name - has_preexisting_packet = hasattr(torch.ops, self.ns) and hasattr(getattr(torch.ops, self.ns), packet_name) - result = self.m.define(schema, alias_analysis, tuple(tags)) - name = schema.split("(")[0] - qualname = self.ns + "::" + name - - # If the OpOverloadPacket exists already, then this means we're adding a - # new OpOverload for it. Refresh the packet to include the new OpOverload. - if has_preexisting_packet: - ns = getattr(torch.ops, self.ns) - packet = getattr(ns, packet_name) - torch._ops._refresh_packet(packet) - + qualname = self.ns + "::" + schema.split("(")[0] self._op_defs.add(qualname) _defs.add(qualname) return result From c404b2968cfe1163fff1802a6c1b71d5579a729b Mon Sep 17 00:00:00 2001 From: Kwanghoon An Date: Wed, 29 May 2024 19:33:26 +0000 Subject: [PATCH 055/706] Support min/max carry over for eager mode from_float method (#127309) Summary: After QAT is completed or given pre-tuned weight observer via tunable PTQ algorithm, it should not over-write again with a given weight, at least for static QAT never. Dynamic QAT also does not require to re-run weight observer again by design. This is a fix Test Plan: Signals Differential Revision: D57747749 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127309 Approved by: https://github.com/jerryzh168 --- .../eager/test_quantize_eager_qat.py | 80 +++++++++++++------ .../ao/nn/intrinsic/qat/modules/conv_fused.py | 26 +++--- .../nn/intrinsic/qat/modules/linear_fused.py | 2 +- .../nn/intrinsic/qat/modules/linear_relu.py | 4 +- .../quantized/dynamic/modules/linear_relu.py | 4 +- .../nn/intrinsic/quantized/modules/bn_relu.py | 8 +- .../intrinsic/quantized/modules/conv_add.py | 8 +- .../intrinsic/quantized/modules/conv_relu.py | 12 +-- .../quantized/modules/linear_relu.py | 8 +- torch/ao/nn/qat/modules/conv.py | 14 ++-- torch/ao/nn/qat/modules/embedding_ops.py | 4 +- torch/ao/nn/qat/modules/linear.py | 2 +- torch/ao/nn/quantizable/modules/rnn.py | 2 +- .../ao/nn/quantized/dynamic/modules/linear.py | 2 +- torch/ao/nn/quantized/dynamic/modules/rnn.py | 24 +++--- torch/ao/nn/quantized/modules/__init__.py | 4 +- torch/ao/nn/quantized/modules/activation.py | 14 ++-- torch/ao/nn/quantized/modules/batchnorm.py | 10 +-- torch/ao/nn/quantized/modules/conv.py | 16 ++-- torch/ao/nn/quantized/modules/dropout.py | 2 +- .../ao/nn/quantized/modules/embedding_ops.py | 4 +- .../quantized/modules/functional_modules.py | 2 +- torch/ao/nn/quantized/modules/linear.py | 12 ++- .../ao/nn/quantized/modules/normalization.py | 10 +-- .../ao/nn/quantized/reference/modules/rnn.py | 2 +- .../nn/quantized/reference/modules/sparse.py | 2 +- .../ao/nn/sparse/quantized/dynamic/linear.py | 2 +- torch/ao/nn/sparse/quantized/linear.py | 2 +- torch/ao/quantization/quantize.py | 27 +++++-- 29 files changed, 178 insertions(+), 131 deletions(-) diff --git a/test/quantization/eager/test_quantize_eager_qat.py b/test/quantization/eager/test_quantize_eager_qat.py index 52f169b1d5b6..31ffa3104b65 100644 --- a/test/quantization/eager/test_quantize_eager_qat.py +++ b/test/quantization/eager/test_quantize_eager_qat.py @@ -2,62 +2,63 @@ import copy import math + import torch -import torch.nn as nn -import torch.backends.mkldnn -from torch.nn import Conv2d, BatchNorm2d, ReLU, init -from torch.ao.nn.intrinsic.qat import ConvBn2d, ConvBnReLU2d -from torch.nn.modules.utils import _pair -import torch.ao.nn.quantized as nnq -import torch.ao.nn.quantized.dynamic as nnqd -import torch.ao.nn.qat as nnqat import torch.ao.nn.intrinsic.qat as nniqat +import torch.ao.nn.qat as nnqat import torch.ao.nn.qat.dynamic as nnqatd +import torch.ao.nn.quantized as nnq +import torch.ao.nn.quantized.dynamic as nnqd +import torch.backends.mkldnn +import torch.nn as nn +import torch.testing._internal.hypothesis_utils as hu + +from hypothesis import given, strategies as st +from torch.ao.nn.intrinsic.qat import ConvBn2d, ConvBnReLU2d from torch.ao.quantization import ( - prepare, convert, - prepare_qat, - quantize_qat, - QuantStub, - DeQuantStub, - default_qconfig, - default_qat_qconfig, default_embedding_qat_qconfig, + default_qat_qconfig, + default_qconfig, default_symmetric_qnnpack_qat_qconfig, - get_default_qat_qconfig, + DeQuantStub, FixedQParamsFakeQuantize, FusedMovingAvgObsFakeQuantize, + get_default_qat_qconfig, get_embedding_qat_module_mappings, get_embedding_static_quant_module_mappings, NoopObserver, + prepare, + prepare_qat, + quantize_qat, + QuantStub, ) from torch.ao.quantization.qconfig import qconfig_equals +from torch.nn import BatchNorm2d, Conv2d, init, ReLU +from torch.nn.modules.utils import _pair from torch.testing._internal.common_quantization import ( DeFusedEmbeddingBagLinear, - QuantizationTestCase, - QuantStubModel, - ManualLinearQATModel, - ManualDropoutQATModel, - ManualLinearDynamicQATModel, ManualConvLinearQATModel, ManualConvLinearSymmQATModel, + ManualDropoutQATModel, ManualEmbeddingBagLinear, - TwoLayerLinearModel, + ManualLinearDynamicQATModel, + ManualLinearQATModel, + QuantizationTestCase, + QuantStubModel, test_only_eval_fn, test_only_train_fn, + TwoLayerLinearModel, ) from torch.testing._internal.common_quantized import ( + override_qengines, override_quantized_engine, supported_qengines, - override_qengines, ) from torch.testing._internal.common_utils import skipIfNoXNNPACK -from hypothesis import given -from hypothesis import strategies as st -import torch.testing._internal.hypothesis_utils as hu hu.assert_deadline_disabled() from functools import reduce @@ -1099,6 +1100,33 @@ def test_linear_bn_workflow(self): self.assertTrue(type(mq[1]) == nnq.Linear) self.assertTrue(type(mq[2]) == nn.Identity) + + @skipIfNoXNNPACK + @override_qengines + def test_linear_precomputed_fake_quant(self): + qengine = torch.backends.quantized.engine + if qengine != "qnnpack": + return # Only qnnpack support symmetric quantization + m_ref = nn.Linear(4, 4) + + m_ref_copy = copy.deepcopy(m_ref) + qconfig = default_qconfig + m_ref_copy.qconfig = qconfig + weight_post_process = copy.deepcopy(qconfig.weight()) + activation = copy.deepcopy(qconfig.activation()) + activation(torch.randn(4, 4)) + m_ref_copy.activation_post_process = activation + m_ref_copy = nnq.Linear.from_float(m_ref_copy) + weight_post_process = qconfig.weight() + weight_post_process.min_val = torch.tensor(-1) + weight_post_process.max_val = torch.tensor(1) + m_ref.weight_post_process = weight_post_process + m_ref.activation_post_process = activation + m_ref.qconfig = qconfig + m_ref = nnq.Linear.from_float(m_ref, use_precomputed_fake_quant=True) + self.assertTrue(m_ref._weight_bias()[0].q_scale != m_ref_copy._weight_bias()[0].q_scale) + + if __name__ == '__main__': raise RuntimeError("This test file is not meant to be run directly, use:\n\n" "\tpython test/test_quantization.py TESTNAME\n\n" diff --git a/torch/ao/nn/intrinsic/qat/modules/conv_fused.py b/torch/ao/nn/intrinsic/qat/modules/conv_fused.py index 906206e18e64..3aa068e382d7 100644 --- a/torch/ao/nn/intrinsic/qat/modules/conv_fused.py +++ b/torch/ao/nn/intrinsic/qat/modules/conv_fused.py @@ -289,7 +289,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, miss missing_keys, unexpected_keys, error_msgs) @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): r"""Create a qat module from a float module or qparams_dict Args: `mod` a float module, either produced by torch.ao.quantization utilities @@ -453,8 +453,8 @@ def forward(self, input): return F.relu(ConvBn1d._forward(self, input)) @classmethod - def from_float(cls, mod): - return super().from_float(mod) + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(mod, use_precomputed_fake_quant) class ConvReLU1d(nnqat.Conv1d, nni._FusedModule): r"""A ConvReLU1d module is a fused module of Conv1d and ReLU, attached with @@ -490,8 +490,8 @@ def forward(self, input): self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)) @classmethod - def from_float(cls, mod): - return super().from_float(mod) + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant) class ConvBn2d(_ConvBnNd, nn.Conv2d): r""" @@ -585,8 +585,8 @@ def forward(self, input): return F.relu(ConvBn2d._forward(self, input)) @classmethod - def from_float(cls, mod): - return super().from_float(mod) + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(mod, use_precomputed_fake_quant) class ConvReLU2d(nnqat.Conv2d, nni._FusedModule): r"""A ConvReLU2d module is a fused module of Conv2d and ReLU, attached with @@ -622,8 +622,8 @@ def forward(self, input): self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)) @classmethod - def from_float(cls, mod): - return super().from_float(mod) + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant) class ConvBn3d(_ConvBnNd, nn.Conv3d): r""" @@ -758,8 +758,8 @@ def forward(self, input): return F.relu(ConvBn3d._forward(self, input)) @classmethod - def from_float(cls, mod): - return super().from_float(mod) + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant) class ConvReLU3d(nnqat.Conv3d, nni._FusedModule): r"""A ConvReLU3d module is a fused module of Conv3d and ReLU, attached with @@ -813,8 +813,8 @@ def forward(self, input): ) @classmethod - def from_float(cls, mod): - return super().from_float(mod) + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant) def update_bn_stats(mod): if type(mod) in {ConvBnReLU1d, ConvBnReLU2d, ConvBnReLU3d, ConvBn1d, ConvBn2d, ConvBn3d}: diff --git a/torch/ao/nn/intrinsic/qat/modules/linear_fused.py b/torch/ao/nn/intrinsic/qat/modules/linear_fused.py index 5b67283dce4b..fb7ac4545bb3 100644 --- a/torch/ao/nn/intrinsic/qat/modules/linear_fused.py +++ b/torch/ao/nn/intrinsic/qat/modules/linear_fused.py @@ -133,7 +133,7 @@ def train(self, mode=True): return self @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): r"""Create a qat module from a float module or qparams_dict Args: `mod' a float module, either produced by torch.ao.quantization diff --git a/torch/ao/nn/intrinsic/qat/modules/linear_relu.py b/torch/ao/nn/intrinsic/qat/modules/linear_relu.py index 97f7a1dbc339..7319c882b0aa 100644 --- a/torch/ao/nn/intrinsic/qat/modules/linear_relu.py +++ b/torch/ao/nn/intrinsic/qat/modules/linear_relu.py @@ -36,8 +36,8 @@ def forward(self, input): return F.relu(F.linear(input, self.weight_fake_quant(self.weight), self.bias)) @classmethod - def from_float(cls, mod): - return super().from_float(mod) + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(mod, use_precomputed_fake_quant) def to_float(self): linear = torch.nn.Linear(self.in_features, self.out_features, self.bias is not None) diff --git a/torch/ao/nn/intrinsic/quantized/dynamic/modules/linear_relu.py b/torch/ao/nn/intrinsic/quantized/dynamic/modules/linear_relu.py index a0bccdc0e3d3..9d0467c4cd57 100644 --- a/torch/ao/nn/intrinsic/quantized/dynamic/modules/linear_relu.py +++ b/torch/ao/nn/intrinsic/quantized/dynamic/modules/linear_relu.py @@ -47,8 +47,8 @@ def _get_name(self): return 'DynamicQuantizedLinearReLU' @classmethod - def from_float(cls, mod): - return super().from_float(mod) + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant) @classmethod def from_reference(cls, ref_qlinear_relu): diff --git a/torch/ao/nn/intrinsic/quantized/modules/bn_relu.py b/torch/ao/nn/intrinsic/quantized/modules/bn_relu.py index 856fa43aac99..32c1d0eeb351 100644 --- a/torch/ao/nn/intrinsic/quantized/modules/bn_relu.py +++ b/torch/ao/nn/intrinsic/quantized/modules/bn_relu.py @@ -37,9 +37,9 @@ def _get_name(self): return 'QuantizedBNReLU2d' @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): # TODO: Add qat support for BNReLU2d - return super().from_float(mod) + return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant) @classmethod def from_reference(cls, bn_relu, output_scale, output_zero_point): @@ -73,9 +73,9 @@ def _get_name(self): return 'QuantizedBNReLU3d' @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): # TODO: Add qat support for BNReLU3d - return super().from_float(mod) + return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant) @classmethod def from_reference(cls, bn_relu, output_scale, output_zero_point): diff --git a/torch/ao/nn/intrinsic/quantized/modules/conv_add.py b/torch/ao/nn/intrinsic/quantized/modules/conv_add.py index 6e46aa8915e4..a369d2b7cec7 100644 --- a/torch/ao/nn/intrinsic/quantized/modules/conv_add.py +++ b/torch/ao/nn/intrinsic/quantized/modules/conv_add.py @@ -42,8 +42,8 @@ def _get_name(self): return 'QuantizedConvAdd2d' @classmethod - def from_float(cls, mod): - return super().from_float(mod) + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant) @classmethod def from_reference(cls, ref_qconv, output_scale, output_zero_point): @@ -85,8 +85,8 @@ def _get_name(self): return 'QuantizedConvAddReLU2d' @classmethod - def from_float(cls, mod): - return super().from_float(mod) + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant) @classmethod def from_reference(cls, ref_qconv, output_scale, output_zero_point): diff --git a/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py b/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py index 5cdc9004c99c..10011e52b3ef 100644 --- a/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py +++ b/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py @@ -53,13 +53,13 @@ def _get_name(self): return 'QuantizedConvReLU1d' @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU1d: assert mod.bn.running_var is not None and mod.bn.running_mean is not None mod.weight, mod.bias = fuse_conv_bn_weights( mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var, mod.bn.eps, mod.bn.weight, mod.bn.bias) - return super().from_float(mod) + return super().from_float(mod, use_precomputed_fake_quant) @classmethod def from_reference(cls, ref_qconv, output_scale, output_zero_point): @@ -103,13 +103,13 @@ def _get_name(self): return 'QuantizedConvReLU2d' @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU2d: assert mod.bn.running_var is not None and mod.bn.running_mean is not None mod.weight, mod.bias = fuse_conv_bn_weights( mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var, mod.bn.eps, mod.bn.weight, mod.bn.bias) - return super().from_float(mod) + return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant) @classmethod def from_reference(cls, ref_qconv, output_scale, output_zero_point): @@ -154,7 +154,7 @@ def _get_name(self): return 'QuantizedConvReLU3d' @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU3d: assert mod.bn.running_var is not None and mod.bn.running_mean is not None mod.weight, mod.bias = fuse_conv_bn_weights( @@ -166,7 +166,7 @@ def from_float(cls, mod): mod.bn.weight, mod.bn.bias, ) - return super().from_float(mod) + return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant) @classmethod def from_reference(cls, ref_qconv, output_scale, output_zero_point): diff --git a/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py b/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py index e774a72dc822..ed64cba253b2 100644 --- a/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py +++ b/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py @@ -40,8 +40,8 @@ def _get_name(self): return 'QuantizedLinearReLU' @classmethod - def from_float(cls, mod): - return super().from_float(mod) + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(mod, use_precomputed_fake_quant) @classmethod def from_reference(cls, ref_linear_relu, output_scale, output_zero_point): @@ -77,7 +77,7 @@ def _get_name(self): return 'QuantizedLinearLeakyReLU' @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): assert type(mod) == nni.LinearLeakyReLU, 'Input float module should be LinearLeakyReLU' assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined' activation_post_process = mod.activation_post_process @@ -144,7 +144,7 @@ def _get_name(self): return 'QuantizedLinearTanh' @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): assert type(mod) == nni.LinearTanh, 'Input float module should be LinearTanh' assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined' activation_post_process = mod.activation_post_process diff --git a/torch/ao/nn/qat/modules/conv.py b/torch/ao/nn/qat/modules/conv.py index 2b588d84a74e..0f56708fb84a 100644 --- a/torch/ao/nn/qat/modules/conv.py +++ b/torch/ao/nn/qat/modules/conv.py @@ -44,7 +44,7 @@ def forward(self, input): return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias) @staticmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): r"""Create a qat module from a float module Args: @@ -150,8 +150,8 @@ def __init__(self, dtype=dtype) @classmethod - def from_float(cls, mod): - return super().from_float(cls, mod) + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant) class Conv2d(_ConvNd, nn.Conv2d): r""" @@ -208,8 +208,8 @@ def forward(self, input): return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias) @classmethod - def from_float(cls, mod): - return super().from_float(cls, mod) + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant) class Conv3d(_ConvNd, nn.Conv3d): r""" @@ -266,5 +266,5 @@ def forward(self, input): return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias) @classmethod - def from_float(cls, mod): - return super().from_float(cls, mod) + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant) diff --git a/torch/ao/nn/qat/modules/embedding_ops.py b/torch/ao/nn/qat/modules/embedding_ops.py index da7f33363742..499d872ba049 100644 --- a/torch/ao/nn/qat/modules/embedding_ops.py +++ b/torch/ao/nn/qat/modules/embedding_ops.py @@ -42,7 +42,7 @@ def forward(self, input) -> Tensor: self.sparse) @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): r"""Create a qat module from a float module Args: `mod` a float module, either produced by torch.ao.quantization utilities @@ -112,7 +112,7 @@ def forward(self, input, offsets=None, per_sample_weights=None) -> Tensor: self.padding_idx) @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): r"""Create a qat module from a float module Args: `mod` a float module, either produced by torch.ao.quantization utilities diff --git a/torch/ao/nn/qat/modules/linear.py b/torch/ao/nn/qat/modules/linear.py index 99d43ed3f6c2..a7083401cb21 100644 --- a/torch/ao/nn/qat/modules/linear.py +++ b/torch/ao/nn/qat/modules/linear.py @@ -41,7 +41,7 @@ def forward(self, input): return F.linear(input, self.weight_fake_quant(self.weight), self.bias) @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): r"""Create a qat module from a float module or qparams_dict Args: `mod` a float module, either produced by torch.ao.quantization utilities or directly from user diff --git a/torch/ao/nn/quantizable/modules/rnn.py b/torch/ao/nn/quantizable/modules/rnn.py index 2c57d1ae9bc5..7c4eebafefbb 100644 --- a/torch/ao/nn/quantizable/modules/rnn.py +++ b/torch/ao/nn/quantizable/modules/rnn.py @@ -122,7 +122,7 @@ def from_params(cls, wi, wh, bi=None, bh=None): return cell @classmethod - def from_float(cls, other): + def from_float(cls, other, use_precomputed_fake_quant=False): assert type(other) == cls._FLOAT_MODULE assert hasattr(other, 'qconfig'), "The float module must have 'qconfig'" observed = cls.from_params(other.weight_ih, other.weight_hh, diff --git a/torch/ao/nn/quantized/dynamic/modules/linear.py b/torch/ao/nn/quantized/dynamic/modules/linear.py index bf77aa04f0cb..85b89b75fe58 100644 --- a/torch/ao/nn/quantized/dynamic/modules/linear.py +++ b/torch/ao/nn/quantized/dynamic/modules/linear.py @@ -77,7 +77,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): r"""Create a dynamic quantized module from a float module or qparams_dict Args: diff --git a/torch/ao/nn/quantized/dynamic/modules/rnn.py b/torch/ao/nn/quantized/dynamic/modules/rnn.py index dac1b820d50a..c81771a71889 100644 --- a/torch/ao/nn/quantized/dynamic/modules/rnn.py +++ b/torch/ao/nn/quantized/dynamic/modules/rnn.py @@ -268,7 +268,7 @@ def weight_bias_name(ihhh, layer, suffix): self._all_weight_values = torch.nn.ModuleList(_all_weight_values) @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): assert type(mod) in {torch.nn.LSTM, torch.nn.GRU}, 'nn.quantized.dynamic.RNNBase.from_float only works for nn.LSTM and nn.GRU' assert hasattr( @@ -495,8 +495,8 @@ def forward(self, input, hx=None): return self.forward_tensor(input, hx) @classmethod - def from_float(cls, mod): - return super().from_float(mod) + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant) @classmethod def from_reference(cls, ref_mod): @@ -747,8 +747,8 @@ def forward(self, input, hx=None): return self.forward_tensor(input, hx) @classmethod - def from_float(cls, mod): - return super().from_float(mod) + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant) @classmethod def from_reference(cls, ref_mod): @@ -839,7 +839,7 @@ def check_forward_hidden(self, input: Tensor, hx: Tensor, hidden_label: str = '' f"hidden{hidden_label} has inconsistent hidden_size: got {hx.size(1)}, expected {self.hidden_size}") @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): assert type(mod) in {torch.nn.LSTMCell, torch.nn.GRUCell, torch.nn.RNNCell}, 'nn.quantized.dynamic.RNNCellBase.from_float \ @@ -1012,8 +1012,8 @@ def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: return ret @classmethod - def from_float(cls, mod): - return super().from_float(mod) + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant) class LSTMCell(RNNCellBase): @@ -1055,8 +1055,8 @@ def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) -> self.bias_ih, self.bias_hh) @classmethod - def from_float(cls, mod): - return super().from_float(mod) + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant) class GRUCell(RNNCellBase): @@ -1096,5 +1096,5 @@ def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: ) @classmethod - def from_float(cls, mod): - return super().from_float(mod) + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant) diff --git a/torch/ao/nn/quantized/modules/__init__.py b/torch/ao/nn/quantized/modules/__init__.py index 668f765fe3ef..f539db753a47 100644 --- a/torch/ao/nn/quantized/modules/__init__.py +++ b/torch/ao/nn/quantized/modules/__init__.py @@ -98,7 +98,7 @@ def forward(self, X): int(self.zero_point), self.dtype) @staticmethod - def from_float(mod): + def from_float(mod, use_precomputed_fake_quant=False): assert hasattr(mod, 'activation_post_process') scale, zero_point = mod.activation_post_process.calculate_qparams() return Quantize(scale.float().item(), zero_point.long().item(), mod.activation_post_process.dtype) @@ -127,5 +127,5 @@ def forward(self, Xq): return Xq.dequantize() @staticmethod - def from_float(mod): + def from_float(mod, use_precomputed_fake_quant=False): return DeQuantize() diff --git a/torch/ao/nn/quantized/modules/activation.py b/torch/ao/nn/quantized/modules/activation.py index 6fcd223e5049..094ac63fb0af 100644 --- a/torch/ao/nn/quantized/modules/activation.py +++ b/torch/ao/nn/quantized/modules/activation.py @@ -46,7 +46,7 @@ def _get_name(self): return 'QuantizedReLU6' @staticmethod - def from_float(mod): + def from_float(mod, use_precomputed_fake_quant=False): return ReLU6(mod.inplace) class Hardswish(torch.nn.Hardswish): @@ -69,7 +69,7 @@ def _get_name(self): return 'QuantizedHardswish' @staticmethod - def from_float(mod): + def from_float(mod, use_precomputed_fake_quant=False): scale, zero_point = mod.activation_post_process.calculate_qparams() return Hardswish(float(scale), int(zero_point)) @@ -98,7 +98,7 @@ def _get_name(self): return 'QuantizedELU' @staticmethod - def from_float(mod): + def from_float(mod, use_precomputed_fake_quant=False): scale, zero_point = mod.activation_post_process.calculate_qparams() return ELU(float(scale), int(zero_point), mod.alpha) @@ -129,7 +129,7 @@ def _get_name(self): return 'QuantizedLeakyReLU' @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): scale, zero_point = mod.activation_post_process.calculate_qparams() return cls(float(scale), int(zero_point), mod.negative_slope, mod.inplace) @@ -154,7 +154,7 @@ def forward(self, input): return torch.ops.quantized.sigmoid(input, self.output_scale, self.output_zero_point) @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): output_scale, output_zero_point = mod.activation_post_process.calculate_qparams() return cls(float(output_scale), int(output_zero_point)) @@ -187,7 +187,7 @@ def _get_name(self): return 'QuantizedSoftmax' @staticmethod - def from_float(mod): + def from_float(mod, use_precomputed_fake_quant=False): scale, zero_point = mod.activation_post_process.calculate_qparams() return Softmax(mod.dim, float(scale), int(zero_point)) @@ -269,7 +269,7 @@ def _get_name(self): return 'QuantizedPReLU' @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): scale, zero_point = mod.activation_post_process.calculate_qparams() qprelu = cls(float(scale), int(zero_point), mod.num_parameters) float_wt = mod.weight.float() diff --git a/torch/ao/nn/quantized/modules/batchnorm.py b/torch/ao/nn/quantized/modules/batchnorm.py index bfef31268cff..3644a314e9e8 100644 --- a/torch/ao/nn/quantized/modules/batchnorm.py +++ b/torch/ao/nn/quantized/modules/batchnorm.py @@ -14,7 +14,7 @@ def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None self.register_buffer('zero_point', torch.tensor(0, **factory_kwargs)) @staticmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): activation_post_process = mod.activation_post_process if type(mod) == cls._NNI_BN_RELU_MODULE: mod = mod[0] @@ -72,8 +72,8 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: self.running_var, self.eps, self.scale, self.zero_point) @classmethod - def from_float(cls, mod): - return _BatchNorm.from_float(cls, mod) + def from_float(cls, mod, use_precomputed_fake_quant=False): + return _BatchNorm.from_float(cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant) class BatchNorm3d(_BatchNorm): r"""This is the quantized version of :class:`~torch.nn.BatchNorm3d`. @@ -102,5 +102,5 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: self.running_var, self.eps, self.scale, self.zero_point) @classmethod - def from_float(cls, mod): - return _BatchNorm.from_float(cls, mod) + def from_float(cls, mod, use_precomputed_fake_quant=False): + return _BatchNorm.from_float(cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant) diff --git a/torch/ao/nn/quantized/modules/conv.py b/torch/ao/nn/quantized/modules/conv.py index ad1a51ee9c3b..5e41aa5bfdaf 100644 --- a/torch/ao/nn/quantized/modules/conv.py +++ b/torch/ao/nn/quantized/modules/conv.py @@ -215,7 +215,7 @@ def get_qconv(cls, mod, activation_post_process, weight_post_process=None): return qconv @staticmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): if hasattr(mod, "weight_fake_quant"): # assert type(mod) == cls.__QAT_MODULE, " nnq." + cls.__name__ + \ # ".from_float only works for " + cls.__QAT_MODULE.__name__ @@ -368,14 +368,14 @@ def forward(self, input): return ops.quantized.conv1d(input, self._packed_params, self.scale, self.zero_point) @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): r"""Creates a quantized module from a float module or qparams_dict. Args: mod (Module): a float module, either produced by torch.ao.quantization utilities or provided by the user """ - return _ConvNd.from_float(cls, mod) + return _ConvNd.from_float(cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant) class Conv2d(_ConvNd): @@ -469,14 +469,14 @@ def forward(self, input): input, self._packed_params, self.scale, self.zero_point) @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): r"""Creates a quantized module from a float module or qparams_dict. Args: mod (Module): a float module, either produced by torch.ao.quantization utilities or provided by the user """ - return _ConvNd.from_float(cls, mod) + return _ConvNd.from_float(cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant) class Conv3d(_ConvNd): @@ -571,14 +571,14 @@ def forward(self, input): input, self._packed_params, self.scale, self.zero_point) @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): r"""Creates a quantized module from a float module or qparams_dict. Args: mod (Module): a float module, either produced by torch.ao.quantization utilities or provided by the user """ - return _ConvNd.from_float(cls, mod) + return _ConvNd.from_float(cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant) # === Transposed Convolutions === MOD = TypeVar('MOD', bound=nn.modules.conv._ConvNd) @@ -609,7 +609,7 @@ def _input_padding(self, kernel_size: List[int], dilation: List[int], padding: L return res @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): r"""Creates a quantized module from a float module or qparams_dict. Args: mod (Module): a float module, either produced by torch.ao.quantization diff --git a/torch/ao/nn/quantized/modules/dropout.py b/torch/ao/nn/quantized/modules/dropout.py index 64110ab53bed..759113bdbf25 100644 --- a/torch/ao/nn/quantized/modules/dropout.py +++ b/torch/ao/nn/quantized/modules/dropout.py @@ -19,7 +19,7 @@ def _get_name(self): return 'QuantizedDropout' @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): return cls(mod.p, mod.inplace) @classmethod diff --git a/torch/ao/nn/quantized/modules/embedding_ops.py b/torch/ao/nn/quantized/modules/embedding_ops.py index 25de7fa9b3cf..dc6f66a0d4eb 100644 --- a/torch/ao/nn/quantized/modules/embedding_ops.py +++ b/torch/ao/nn/quantized/modules/embedding_ops.py @@ -137,7 +137,7 @@ def weight(self): return self._packed_params._weight() @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): r"""Create a quantized embedding module from a float module Args: @@ -241,7 +241,7 @@ def _get_name(self): return 'QuantizedEmbeddingBag' @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): r"""Create a quantized embedding_bag module from a float module Args: diff --git a/torch/ao/nn/quantized/modules/functional_modules.py b/torch/ao/nn/quantized/modules/functional_modules.py index 96408457a449..4cb135dee0ec 100644 --- a/torch/ao/nn/quantized/modules/functional_modules.py +++ b/torch/ao/nn/quantized/modules/functional_modules.py @@ -239,7 +239,7 @@ def matmul(self, x: Tensor, y: Tensor) -> Tensor: return r @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): assert type(mod) == FloatFunctional, \ "QFunctional.from_float expects an instance of FloatFunctional" scale, zero_point = mod.activation_post_process.calculate_qparams() # type: ignore[operator] diff --git a/torch/ao/nn/quantized/modules/linear.py b/torch/ao/nn/quantized/modules/linear.py index 9d988104a71d..cbc01b092f3a 100644 --- a/torch/ao/nn/quantized/modules/linear.py +++ b/torch/ao/nn/quantized/modules/linear.py @@ -240,12 +240,14 @@ def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None: self._packed_params.set_weight_bias(w, b) @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): r"""Create a quantized module from an observed float module Args: mod (Module): a float module, either produced by torch.ao.quantization utilities or provided by the user + use_precomputed_fake_quant (bool): if True, the module will reuse min/max + values from the precomputed fake quant module. """ if hasattr(mod, 'weight_fake_quant'): if type_before_parametrizations(mod) == nniqat.LinearBn1d: @@ -267,8 +269,12 @@ def from_float(cls, mod): activation_post_process = mod.activation_post_process if type_before_parametrizations(mod) == nni.LinearReLU: mod = mod[0] - weight_post_process = mod.qconfig.weight() - weight_post_process(mod.weight) + weight_post_process = mod.qconfig.weight() if not hasattr(mod, "weight_fake_quant") else mod.weight_fake_quant + + if not use_precomputed_fake_quant: + # Observer may not have been called yet + # Observer might have been called in the previous stage via PTQ algorithm e.g. AdaRound + weight_post_process(mod.weight) dtype = weight_post_process.dtype act_scale, act_zp = activation_post_process.calculate_qparams() assert dtype == torch.qint8, 'Weight observer must have dtype torch.qint8' diff --git a/torch/ao/nn/quantized/modules/normalization.py b/torch/ao/nn/quantized/modules/normalization.py index f798a241e324..e7c5c85a4527 100644 --- a/torch/ao/nn/quantized/modules/normalization.py +++ b/torch/ao/nn/quantized/modules/normalization.py @@ -30,7 +30,7 @@ def _get_name(self): return 'QuantizedLayerNorm' @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): scale, zero_point = mod.activation_post_process.calculate_qparams() new_mod = cls( mod.normalized_shape, mod.weight, mod.bias, float(scale), @@ -71,7 +71,7 @@ def _get_name(self): return 'QuantizedGroupNorm' @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): scale, zero_point = mod.activation_post_process.calculate_qparams() new_mod = cls( mod.num_groups, mod.num_channels, mod.weight, mod.bias, float(scale), int(zero_point), @@ -105,7 +105,7 @@ def _get_name(self): return 'QuantizedInstanceNorm1d' @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): scale, zero_point = mod.activation_post_process.calculate_qparams() new_mod = cls( mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point), @@ -145,7 +145,7 @@ def _get_name(self): return 'QuantizedInstanceNorm2d' @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): scale, zero_point = mod.activation_post_process.calculate_qparams() new_mod = cls( mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point), @@ -185,7 +185,7 @@ def _get_name(self): return 'QuantizedInstanceNorm3d' @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): scale, zero_point = mod.activation_post_process.calculate_qparams() new_mod = cls( mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point), diff --git a/torch/ao/nn/quantized/reference/modules/rnn.py b/torch/ao/nn/quantized/reference/modules/rnn.py index 4120338ce271..978c1d69f30a 100644 --- a/torch/ao/nn/quantized/reference/modules/rnn.py +++ b/torch/ao/nn/quantized/reference/modules/rnn.py @@ -213,7 +213,7 @@ def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) -> return ret @classmethod - def from_float(cls, mod, weight_qparams_dict): + def from_float(cls, mod, weight_qparams_dict, use_precomputed_fake_quant=False): ref_mod = cls( mod.input_size, mod.hidden_size, diff --git a/torch/ao/nn/quantized/reference/modules/sparse.py b/torch/ao/nn/quantized/reference/modules/sparse.py index 4890402b875a..973eb05bd3b3 100644 --- a/torch/ao/nn/quantized/reference/modules/sparse.py +++ b/torch/ao/nn/quantized/reference/modules/sparse.py @@ -76,7 +76,7 @@ def forward(self, input: Tensor, offsets: Optional[Tensor] = None, per_sample_we self.padding_idx) @classmethod - def from_float(cls, mod, weight_qparams): + def from_float(cls, mod, weight_qparams, use_precomputed_fake_quant=False): return cls( mod.num_embeddings, mod.embedding_dim, diff --git a/torch/ao/nn/sparse/quantized/dynamic/linear.py b/torch/ao/nn/sparse/quantized/dynamic/linear.py index 5347b682fb5a..bc5cb99fced2 100644 --- a/torch/ao/nn/sparse/quantized/dynamic/linear.py +++ b/torch/ao/nn/sparse/quantized/dynamic/linear.py @@ -92,7 +92,7 @@ def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor], self._packed_params.set_weight_bias(w, b, row_block_size, col_block_size) @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): r"""Create a quantized sparse dynamic module from a float module. We only care about the convert at this stage, no need for observers just yet. diff --git a/torch/ao/nn/sparse/quantized/linear.py b/torch/ao/nn/sparse/quantized/linear.py index 71caa8cbab61..9d1c8f332172 100644 --- a/torch/ao/nn/sparse/quantized/linear.py +++ b/torch/ao/nn/sparse/quantized/linear.py @@ -146,7 +146,7 @@ def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor], self._packed_params.set_weight_bias(w, b, row_block_size, col_block_size) @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): r"""Create a quantized sparse module from a float module. We only care about the convert at this stage, no need for observers just yet. diff --git a/torch/ao/quantization/quantize.py b/torch/ao/quantization/quantize.py index 794cb142220d..17cec9f35908 100644 --- a/torch/ao/quantization/quantize.py +++ b/torch/ao/quantization/quantize.py @@ -235,6 +235,13 @@ def insert_activation_post_process(m, special_act_post_process=None): if has_no_children_ignoring_parametrizations(module) and not isinstance(module, torch.nn.Sequential) \ and type_before_parametrizations(module) in qconfig_propagation_list: insert_activation_post_process(module) + # This is a special case for AdaRound eager mode + # AdaRound contains weight_fake_quant to be propagated from API to convert + # leaf node check with a number of children looks naive assumption that blocks + # Adding an exception case for AdaRound + if hasattr(module, "weight_fake_quant") and not isinstance(module, torch.nn.Sequential) \ + and type_before_parametrizations(module) in qconfig_propagation_list: + insert_activation_post_process(module) def _get_unique_devices_(module): return {p.device for p in module.parameters()} | \ @@ -520,7 +527,8 @@ def quantize_qat(model, run_fn, run_args, inplace=False): def convert( module, mapping=None, inplace=False, remove_qconfig=True, - is_reference=False, convert_custom_config_dict=None): + is_reference=False, convert_custom_config_dict=None, + use_precomputed_fake_quant=False): r"""Converts submodules in input module to a different module according to `mapping` by calling `from_float` method on the target module class. And remove qconfig at the end if remove_qconfig is set to True. @@ -533,6 +541,7 @@ def convert( `inplace`: carry out model transformations in-place, the original module is mutated `convert_custom_config_dict`: custom configuration dictionary for convert function + `use_precomputed_fake_quant`: a flag to enable use of precomputed fake quant .. code-block:: python @@ -552,14 +561,16 @@ def convert( module = copy.deepcopy(module) _convert( module, mapping, inplace=True, is_reference=is_reference, - convert_custom_config_dict=convert_custom_config_dict) + convert_custom_config_dict=convert_custom_config_dict, + use_precomputed_fake_quant=use_precomputed_fake_quant) if remove_qconfig: _remove_qconfig(module) return module def _convert( module, mapping=None, inplace=False, - is_reference=False, convert_custom_config_dict=None): + is_reference=False, convert_custom_config_dict=None, + use_precomputed_fake_quant=False): r"""Converts submodules in input module to a different module according to `mapping` by calling `from_float` method on the target module class @@ -571,6 +582,7 @@ def _convert( inplace: carry out model transformations in-place, the original module is mutated is_reference: a flag to enable quantized reference module + use_precomputed_fake_quant: a flag to enable use of precomputed fake quant """ if mapping is None: @@ -589,15 +601,16 @@ def _convert( if not isinstance(mod, _FusedModule) and \ type_before_parametrizations(mod) not in custom_module_class_mapping: _convert(mod, mapping, True, # inplace - is_reference, convert_custom_config_dict) - reassign[name] = swap_module(mod, mapping, custom_module_class_mapping) + is_reference, convert_custom_config_dict, + use_precomputed_fake_quant=use_precomputed_fake_quant) + reassign[name] = swap_module(mod, mapping, custom_module_class_mapping, use_precomputed_fake_quant) for key, value in reassign.items(): module._modules[key] = value return module -def swap_module(mod, mapping, custom_module_class_mapping): +def swap_module(mod, mapping, custom_module_class_mapping, use_precomputed_fake_quant=False): r"""Swaps the module if it has a quantized counterpart and it has an `observer` attached. @@ -623,7 +636,7 @@ def swap_module(mod, mapping, custom_module_class_mapping): weight_qparams = get_qparam_dict(weight_post_process) new_mod = qmod.from_float(mod, weight_qparams) else: - new_mod = qmod.from_float(mod) + new_mod = qmod.from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant) swapped = True if swapped: From 24a4bfdcc202c30b207de269c12adfde09b4c52d Mon Sep 17 00:00:00 2001 From: Kwanghoon An Date: Wed, 29 May 2024 20:05:27 +0000 Subject: [PATCH 056/706] [AdaRound] Make versatile for data / extra param for callback function (#126891) Summary: For Speech sequential model, there could be a case where model(data) does not work correctly for feed forward, Speech model uses a different type of Criterion (a.k.a loss function) to feed a data on individual components like encoder, predictor, joiner. Hence we need extra parameter to pass feedforward wrapper Differential Revision: D57680391 Pull Request resolved: https://github.com/pytorch/pytorch/pull/126891 Approved by: https://github.com/jerryzh168 --- .../core/experimental/test_adaround_eager.py | 25 ++++++- .../experimental/adaround_optimization.py | 65 ++++++++++++------- 2 files changed, 62 insertions(+), 28 deletions(-) diff --git a/test/quantization/core/experimental/test_adaround_eager.py b/test/quantization/core/experimental/test_adaround_eager.py index 33a16f21bd0f..a0a2f8f8aa03 100644 --- a/test/quantization/core/experimental/test_adaround_eager.py +++ b/test/quantization/core/experimental/test_adaround_eager.py @@ -29,14 +29,20 @@ def feedforawrd_callback( ) -> None: model(data) - def run_adaround(self, model, img_data): + def feedforawrd_callback_with_wrapper(self, model, data, wrapper) -> None: + wrapper(model, data) + + def run_adaround(self, model, img_data, wrapper=None): adaround_optimizer = AdaptiveRoundingOptimizer( model, - self.feedforawrd_callback, + self.feedforawrd_callback + if wrapper is None + else self.feedforawrd_callback_with_wrapper, forward_wrapper, img_data, max_iter=100, batch_size=10, + feed_forward_wrapper=wrapper, ) adarounded_model = adaround_optimizer.run_adaround() return adarounded_model @@ -63,6 +69,17 @@ def get_fake_quant(self, model): module.weight.data.copy_(fake_quant_module) return hard_fake_quant_model + def get_feed_forward_wrapper(self): + class FeedForwardWrapper(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, model, sample): + return model(sample) + + wrapper_module = FeedForwardWrapper() + return wrapper_module + def test_linear_chain(self): class LinearChain(nn.Module): def __init__(self): @@ -79,7 +96,9 @@ def forward(self, x): float_model = LinearChain() img_data = [torch.rand(10, 3, dtype=torch.float) for _ in range(50)] - adarounded_model = self.run_adaround(float_model, img_data) + adarounded_model = self.run_adaround( + float_model, img_data, self.get_feed_forward_wrapper() + ) fq_model = self.get_fake_quant(float_model) rand_input = torch.rand(10, 3) with torch.no_grad(): diff --git a/torch/ao/quantization/experimental/adaround_optimization.py b/torch/ao/quantization/experimental/adaround_optimization.py index 7304f885a6f3..808b7abe2c78 100644 --- a/torch/ao/quantization/experimental/adaround_optimization.py +++ b/torch/ao/quantization/experimental/adaround_optimization.py @@ -1,5 +1,4 @@ import copy -import logging from typing import Any, Callable, List, Optional, Tuple, Type, Union import torch @@ -12,16 +11,21 @@ from torch.nn.parallel import DataParallel from torch.utils.data import DataLoader, TensorDataset -logger: logging.Logger = logging.getLogger(__name__) - class AdaptiveRoundingOptimizer: def __init__( self, model: Union[torch.nn.Module, torch.nn.DataParallel], - callback: Callable[[torch.nn.Module, List[Any]], None], + callback: Callable[ + [ + Union[torch.nn.Module, torch.nn.DataParallel], + Any, + Optional[torch.nn.Module], + ], + None, + ], forward_hook_wrapper: Callable[[List[torch.Tensor]], Callable], - data: List[Any], + data: Any, observer: Type[torch.ao.quantization.observer.ObserverBase] = MinMaxObserver, max_iter=10000, dtype: torch.dtype = torch.qint8, @@ -29,8 +33,14 @@ def __init__( quant_max=127, qscheme: torch.qscheme = torch.per_tensor_symmetric, batch_size: int = 256, + feed_forward_wrapper: Optional[torch.nn.Module] = None, ): - self.model = model + if torch.cuda.is_available(): + self.model = model.cuda() + if torch.cuda.device_count() > 1: + self.model = torch.nn.DataParallel(model) + else: + self.model = model self.q_model = copy.deepcopy(self.model) self.device = torch.device("cuda") if torch.cuda.is_available() else None self.callback = callback @@ -47,20 +57,27 @@ def __init__( self.quant_min = quant_min self.quant_max = quant_max self.qscheme = qscheme + self.feed_forward_wrapper = feed_forward_wrapper def run_adaround(self) -> torch.nn.Module: layer_list: List[Tuple[str, torch.nn.Module, torch.nn.Module]] = [] for (name, module), q_module in zip( self.model.named_modules(), self.q_model.modules() ): + if isinstance(module, torch.nn.ReLU): + # Disable all inplace operations + module.inplace = False + if isinstance(q_module, torch.nn.ReLU): + # Disable all inplace operations + q_module.inplace = False if isinstance(module, (torch.nn.Conv1d, torch.nn.Linear)): # Knowing activation ahead-of-time would be helpful for asymmetric formulation # But this is challenging in eager mode, but graph module. layer_list.append((name, module, q_module)) - logger.info(f"Total number of layers : {len(layer_list)}") # noqa: G004 + print(f"Total number of layers : {len(layer_list)}") # noqa: G004 for name, module, q_module in layer_list: - logger.info( + print( f"Kick start adaptive rounding on {name} module {module}" # noqa: G004 ) self.optimize_adaptive_rounding( @@ -87,10 +104,15 @@ def get_data_inp_out( handler2 = q_module.register_forward_hook( self.forward_hook_wrapper(quant_fetcher) ) + if torch.cuda.is_available(): + # Somehow, we need to move the model continuously + # Otherwise, the model will be lowered to CPU misteriously + self.model = self.model.cuda() + self.q_model = self.q_model.cuda() for data_ in data: with torch.no_grad(): - self.callback(self.model, data_) - self.callback(self.q_model, data_) + self.callback(self.model, data_, self.feed_forward_wrapper) + self.callback(self.q_model, data_, self.feed_forward_wrapper) fp32_output = fp32_fetcher[1] quant_input = quant_fetcher[0] fp_out.append(fp32_output) @@ -137,7 +159,7 @@ def _compute_and_display_local_losses( out_soft_quant = self.feed_forward(q_inp, q_w_soft_round, q_module) soft_quant_loss = F.mse_loss(out_soft_quant, fp_out) hard_quant_loss = F.mse_loss(out_hard_quant, fp_out) - logger.info( + print( f"soft quant loss: {soft_quant_loss.item()} hard quant loss: {hard_quant_loss.item()}" # noqa: G004 ) @@ -162,13 +184,9 @@ def optimize_adaptive_rounding( optimizer = torch.optim.Adam([ada_quantizer.V]) inp, out, fp_in = self.get_data_inp_out(module, q_module, self.data) - logger.info("==================== Before adaround ====================") - test_in, test_out, fp_test_in = self.get_data_inp_out( - module, q_module, self.data[0] - ) - + print("==================== Before adaround ====================") assert ( - torch.abs(test_out[0] - module(fp_test_in[0])).sum().item() == 0 + torch.abs(out[0] - module(fp_in[0])).sum().item() == 0 ), "In-placed activation is detected, please do not use activation in-placed" # Stack the tensors in each list into a single tensor # Assuming inp and out are your lists of tensors @@ -177,9 +195,7 @@ def optimize_adaptive_rounding( dataset = TensorDataset(inp_tensor, out_tensor) dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True) - self._compute_and_display_local_losses( - ada_quantizer, q_module, test_in[0], test_out[0] - ) + self._compute_and_display_local_losses(ada_quantizer, q_module, inp[0], out[0]) global_idx = 0 one_iter = len(out) // self.batch_size for iteration in range(self.max_iter // one_iter): @@ -191,6 +207,7 @@ def optimize_adaptive_rounding( q_out = torch.nn.functional.conv1d( q_inp, q_weight, + bias=q_module.bias, stride=q_module.stride, padding=q_module.padding, dilation=q_module.dilation, @@ -219,14 +236,12 @@ def optimize_adaptive_rounding( if global_idx >= self.max_iter: break if iteration % 30 == 0: - logger.info( + print( f"glob iter {global_idx} regularization_loss {regularization_loss.item()} " # noqa: G004 f"reconstruction_loss {reconstruction_loss.item()}" # noqa: G004 ) - logger.info("==================== After adaround ====================") - self._compute_and_display_local_losses( - ada_quantizer, q_module, test_in[0], test_out[0] - ) + print("==================== After adaround ====================") + self._compute_and_display_local_losses(ada_quantizer, q_module, inp[0], out[0]) ada_quantizer.use_soft_rounding = True ada_quantizer.V.requires_grad = False From 92d081e22842b4e4ed330b604e1f35ea3646c334 Mon Sep 17 00:00:00 2001 From: Yuanhao Ji Date: Wed, 29 May 2024 20:09:52 +0000 Subject: [PATCH 057/706] [Docs] Add `str` type to `cuda.get_device_name()` and `cuda. get_device_capability()` function (#126743) Fixes #126400 The `get_device_name()` and `get_device_capability()` allow passing in a string, but it's not stated in the doc. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126743 Approved by: https://github.com/eqy, https://github.com/kit1980 --- torch/cuda/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index 8c19788d1055..7cbb53012fe1 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -420,7 +420,7 @@ def get_device_name(device: Optional[_device_t] = None) -> str: r"""Get the name of a device. Args: - device (torch.device or int, optional): device for which to return the + device (torch.device or int or str, optional): device for which to return the name. This function is a no-op if this argument is a negative integer. It uses the current device, given by :func:`~torch.cuda.current_device`, if :attr:`device` is ``None`` (default). @@ -435,7 +435,7 @@ def get_device_capability(device: Optional[_device_t] = None) -> Tuple[int, int] r"""Get the cuda capability of a device. Args: - device (torch.device or int, optional): device for which to return the + device (torch.device or int or str, optional): device for which to return the device capability. This function is a no-op if this argument is a negative integer. It uses the current device, given by :func:`~torch.cuda.current_device`, if :attr:`device` is ``None`` From 84b5aa9a68083aa633376b3b9c5d2c1a47901b45 Mon Sep 17 00:00:00 2001 From: cyy Date: Wed, 29 May 2024 20:36:58 +0000 Subject: [PATCH 058/706] [Caffe2] [Reland] Remove Caffe2 proto files (#127394) Reland of #126134, which was reverted due to the wrong base. Now that #126705 has been relanded, it's time to remand this one. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127394 Approved by: https://github.com/r-barnes --- BUILD.bazel | 3 - caffe2/proto/BUILD.bazel | 18 --- caffe2/proto/caffe2_legacy.proto | 50 -------- caffe2/proto/caffe2_legacy_pb2.pyi | 58 ---------- caffe2/proto/hsm.proto | 62 ---------- caffe2/proto/hsm_pb2.pyi | 109 ------------------ caffe2/proto/metanet.proto | 50 -------- caffe2/proto/metanet_pb2.pyi | 160 -------------------------- caffe2/proto/predictor_consts.proto | 36 ------ caffe2/proto/predictor_consts_pb2.pyi | 63 ---------- caffe2/proto/prof_dag.proto | 68 ----------- caffe2/proto/prof_dag_pb2.pyi | 126 -------------------- 12 files changed, 803 deletions(-) delete mode 100644 caffe2/proto/caffe2_legacy.proto delete mode 100644 caffe2/proto/caffe2_legacy_pb2.pyi delete mode 100644 caffe2/proto/hsm.proto delete mode 100644 caffe2/proto/hsm_pb2.pyi delete mode 100644 caffe2/proto/metanet.proto delete mode 100644 caffe2/proto/metanet_pb2.pyi delete mode 100644 caffe2/proto/predictor_consts.proto delete mode 100644 caffe2/proto/predictor_consts_pb2.pyi delete mode 100644 caffe2/proto/prof_dag.proto delete mode 100644 caffe2/proto/prof_dag_pb2.pyi diff --git a/BUILD.bazel b/BUILD.bazel index 6d01ff42305c..ecbeaab9bbf8 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -552,7 +552,6 @@ cc_library( ":caffe2_core_macros", ":caffe2_for_aten_headers", "//caffe2/proto:caffe2_pb", - "//caffe2/proto:cc_proto", ], ) @@ -574,7 +573,6 @@ cc_library( ":caffe2_perfkernels_avx2", ":caffe2_perfkernels_avx512", "//caffe2/proto:caffe2_pb", - "//caffe2/proto:cc_proto", "//third_party/miniz-2.1.0:miniz", "@com_google_protobuf//:protobuf", "@eigen", @@ -782,7 +780,6 @@ cc_library( deps = [ ":caffe2", ":torch_headers", - "//caffe2/proto:torch_cc_proto", "@kineto", ] + if_cuda([ "@cuda//:nvToolsExt", diff --git a/caffe2/proto/BUILD.bazel b/caffe2/proto/BUILD.bazel index dcffaac0e3de..58766661ac67 100644 --- a/caffe2/proto/BUILD.bazel +++ b/caffe2/proto/BUILD.bazel @@ -35,21 +35,3 @@ proto_library( srcs = ["torch.proto"], deps = [":caffe2_proto"], ) - -cc_proto_library( - name = "cc_proto", - visibility = ["//:__pkg__"], - deps = [":proto"], -) - -proto_library( - name = "proto", - srcs = [ - "caffe2_legacy.proto", - "hsm.proto", - "metanet.proto", - "predictor_consts.proto", - "prof_dag.proto", - ], - deps = [":caffe2_proto"], -) diff --git a/caffe2/proto/caffe2_legacy.proto b/caffe2/proto/caffe2_legacy.proto deleted file mode 100644 index 4fb2cda002fe..000000000000 --- a/caffe2/proto/caffe2_legacy.proto +++ /dev/null @@ -1,50 +0,0 @@ -syntax = "proto2"; - -package caffe2; - -// Original Caffe1 Datum copy: this is used in image input op to allow us to -// load caffe1 serialized datum without having to regenerate the database. -message CaffeDatum { - optional int32 channels = 1; - optional int32 height = 2; - optional int32 width = 3; - // the actual image data, in bytes - optional bytes data = 4; - optional int32 label = 5; - // Optionally, the datum could also hold float data. - repeated float float_data = 6; - // If true data contains an encoded image that need to be decoded - optional bool encoded = 7 [ default = false ]; -} - -enum LegacyPadding { - NOTSET = 0; // Do not use old-stype padding strategies. - - // VALID and SAME are two strategies adopted in Google DistBelief: it forces - // the input shape as follows. For SAME, the output is: - // R_out = ceil(float(R) / float(S)) - // C_out = ceil(float(C) / float(S)) - // where R and C are row and column, S is the stride, and K is the kernel. - // The number of padded pixels is then computed as - // Pr = ((R_out - 1) * S + K - R) - // Pc = ((C_out - 1) * S + K - C) - // When Pr and Pc are even numbers, both sides (left and right, or top and - // bottom) get half each. When Pr and Pc are odd numbers, the right and the - // bottom gets the one additional padding pixel. - // For VALID, padding values of 0 are always used. - VALID = 1; - SAME = 2; - - // CAFFE_LEGACY_POOLING is a flag that notifies the code to use the old Caffe - // padding strategy. - // Basically, in caffe2, after padding the convolution and pooling use the - // same computation strategy: half-windows at the right and bottom are - // discarded. In Caffe, convolution follows this strategy but if there are - // some pixels in the half-windows, the pooling layer will actually put one - // additional output. If you set LegacyPadding to this, we will compute the - // equivalent padding strategy in caffe2 so that the output size is - // backward compatible with Caffe. - // THIS IS NOW DEPRECATED. ANY non-conventional use has to be manually - // converted. - CAFFE_LEGACY_POOLING = 3; -} diff --git a/caffe2/proto/caffe2_legacy_pb2.pyi b/caffe2/proto/caffe2_legacy_pb2.pyi deleted file mode 100644 index eaee65471eef..000000000000 --- a/caffe2/proto/caffe2_legacy_pb2.pyi +++ /dev/null @@ -1,58 +0,0 @@ -""" -@generated by mypy-protobuf. Do not edit manually! -isort:skip_file -""" -import builtins -import google.protobuf.descriptor -import google.protobuf.internal.containers -import google.protobuf.internal.enum_type_wrapper -import google.protobuf.message -import typing -import typing_extensions - -DESCRIPTOR: google.protobuf.descriptor.FileDescriptor = ... - -global___LegacyPadding = LegacyPadding -class _LegacyPadding(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[LegacyPadding], type): - DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor = ... - NOTSET = LegacyPadding.V(0) - VALID = LegacyPadding.V(1) - SAME = LegacyPadding.V(2) - CAFFE_LEGACY_POOLING = LegacyPadding.V(3) -class LegacyPadding(metaclass=_LegacyPadding): - V = typing.NewType('V', int) -NOTSET = LegacyPadding.V(0) -VALID = LegacyPadding.V(1) -SAME = LegacyPadding.V(2) -CAFFE_LEGACY_POOLING = LegacyPadding.V(3) - -class CaffeDatum(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - CHANNELS_FIELD_NUMBER: int - HEIGHT_FIELD_NUMBER: int - WIDTH_FIELD_NUMBER: int - DATA_FIELD_NUMBER: int - LABEL_FIELD_NUMBER: int - FLOAT_DATA_FIELD_NUMBER: int - ENCODED_FIELD_NUMBER: int - channels: int = ... - height: int = ... - width: int = ... - data: bytes = ... - label: int = ... - float_data: google.protobuf.internal.containers.RepeatedScalarFieldContainer[float] = ... - encoded: bool = ... - - def __init__(self, - *, - channels : typing.Optional[int] = ..., - height : typing.Optional[int] = ..., - width : typing.Optional[int] = ..., - data : typing.Optional[bytes] = ..., - label : typing.Optional[int] = ..., - float_data : typing.Optional[typing.Iterable[float]] = ..., - encoded : typing.Optional[bool] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"channels",b"channels",u"data",b"data",u"encoded",b"encoded",u"height",b"height",u"label",b"label",u"width",b"width"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"channels",b"channels",u"data",b"data",u"encoded",b"encoded",u"float_data",b"float_data",u"height",b"height",u"label",b"label",u"width",b"width"]) -> None: ... -global___CaffeDatum = CaffeDatum diff --git a/caffe2/proto/hsm.proto b/caffe2/proto/hsm.proto deleted file mode 100644 index 2e3152cc332e..000000000000 --- a/caffe2/proto/hsm.proto +++ /dev/null @@ -1,62 +0,0 @@ -syntax = "proto2"; - -package caffe2; - -// Hierarchical Softmax protobuffer convention: -// The HSM operator requires a hierarchy of vocabulary words in the form of a -// tree from the user. This tree is expressed using the proto format. -// TreeProto points to the root NodeProto which can recursively contain children -// NodeProtos (internal nodes) or word_ids (leaf nodes). - -// The aforementioned TreeProto is internally translated into a list of word_ids -// tagged with a list of NodeProtos that lie in the path from the root to that -// word_id using hsm_util.create_hierarchy(tree_proto). -// Specifically, HierarchyProto contains a list of PathProtos. Each PathProto -// belongs to a word_id and contains a list of PathNodeProtos. Each -// PathNodeProto contains information about the number of children the node has -// (length), the index of the child node that lies in the path from root to -// word_id (target) and a cumulative sum of children nodes (index; this acts as -// the weight parameter matrix offset). - -// Each node in the hierarchy contains links to either leaf nodes or more -// non-terminal nodes -message NodeProto { - // Links to non-terminal children nodes - repeated NodeProto children = 1; - // Links to terminal (leaf) nodes - repeated int32 word_ids = 2; - optional int32 offset = 3; - optional string name = 4; - repeated float scores = 5; -} - -// Protobuf format to accept hierarchy for hierarchical softmax operator. -// TreeProto points to the root node. -message TreeProto { - optional NodeProto root_node = 1; -} - -// Internal Protobuf format which represents the path in the tree hierarchy for -// each word in the vocabulary. -message HierarchyProto { - optional int32 size = 1; - repeated PathProto paths = 2; -} - -// Each PathProto belongs to a word and is an array of nodes in the -// path from the root to the leaf (which is the word itself) in the tree. -message PathProto { - optional int32 word_id = 1; - repeated PathNodeProto path_nodes = 2; -} - -// Represents a node in the path from the root node all the way down to the -// word (leaf). -message PathNodeProto { - // Parameter matrix offset for this node - optional int32 index = 1; - // Number of children - optional int32 length = 2; - // Index of the next node in the path - optional int32 target = 3; -} diff --git a/caffe2/proto/hsm_pb2.pyi b/caffe2/proto/hsm_pb2.pyi deleted file mode 100644 index 86a47f58d17c..000000000000 --- a/caffe2/proto/hsm_pb2.pyi +++ /dev/null @@ -1,109 +0,0 @@ -""" -@generated by mypy-protobuf. Do not edit manually! -isort:skip_file -""" -import builtins -import google.protobuf.descriptor -import google.protobuf.internal.containers -import google.protobuf.message -import typing -import typing_extensions - -DESCRIPTOR: google.protobuf.descriptor.FileDescriptor = ... - -class NodeProto(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - CHILDREN_FIELD_NUMBER: int - WORD_IDS_FIELD_NUMBER: int - OFFSET_FIELD_NUMBER: int - NAME_FIELD_NUMBER: int - SCORES_FIELD_NUMBER: int - word_ids: google.protobuf.internal.containers.RepeatedScalarFieldContainer[int] = ... - offset: int = ... - name: typing.Text = ... - scores: google.protobuf.internal.containers.RepeatedScalarFieldContainer[float] = ... - - @property - def children(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___NodeProto]: ... - - def __init__(self, - *, - children : typing.Optional[typing.Iterable[global___NodeProto]] = ..., - word_ids : typing.Optional[typing.Iterable[int]] = ..., - offset : typing.Optional[int] = ..., - name : typing.Optional[typing.Text] = ..., - scores : typing.Optional[typing.Iterable[float]] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"name",b"name",u"offset",b"offset"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"children",b"children",u"name",b"name",u"offset",b"offset",u"scores",b"scores",u"word_ids",b"word_ids"]) -> None: ... -global___NodeProto = NodeProto - -class TreeProto(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - ROOT_NODE_FIELD_NUMBER: int - - @property - def root_node(self) -> global___NodeProto: ... - - def __init__(self, - *, - root_node : typing.Optional[global___NodeProto] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"root_node",b"root_node"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"root_node",b"root_node"]) -> None: ... -global___TreeProto = TreeProto - -class HierarchyProto(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - SIZE_FIELD_NUMBER: int - PATHS_FIELD_NUMBER: int - size: int = ... - - @property - def paths(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___PathProto]: ... - - def __init__(self, - *, - size : typing.Optional[int] = ..., - paths : typing.Optional[typing.Iterable[global___PathProto]] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"size",b"size"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"paths",b"paths",u"size",b"size"]) -> None: ... -global___HierarchyProto = HierarchyProto - -class PathProto(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - WORD_ID_FIELD_NUMBER: int - PATH_NODES_FIELD_NUMBER: int - word_id: int = ... - - @property - def path_nodes(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___PathNodeProto]: ... - - def __init__(self, - *, - word_id : typing.Optional[int] = ..., - path_nodes : typing.Optional[typing.Iterable[global___PathNodeProto]] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"word_id",b"word_id"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"path_nodes",b"path_nodes",u"word_id",b"word_id"]) -> None: ... -global___PathProto = PathProto - -class PathNodeProto(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - INDEX_FIELD_NUMBER: int - LENGTH_FIELD_NUMBER: int - TARGET_FIELD_NUMBER: int - index: int = ... - length: int = ... - target: int = ... - - def __init__(self, - *, - index : typing.Optional[int] = ..., - length : typing.Optional[int] = ..., - target : typing.Optional[int] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"index",b"index",u"length",b"length",u"target",b"target"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"index",b"index",u"length",b"length",u"target",b"target"]) -> None: ... -global___PathNodeProto = PathNodeProto diff --git a/caffe2/proto/metanet.proto b/caffe2/proto/metanet.proto deleted file mode 100644 index 8008610ac0fa..000000000000 --- a/caffe2/proto/metanet.proto +++ /dev/null @@ -1,50 +0,0 @@ -syntax = "proto2"; - -import "caffe2/proto/caffe2.proto"; - -package caffe2; - -message ModelInfo { - optional string project = 1; - optional string modelClass = 2; - optional string version = 3; - optional string predictorType = 4 [ default = "SINGLE_PREDICTOR" ]; - optional string modelId = 5; -} - -message BlobsMap { - required string key = 1; - repeated string value = 2; -} - -message NetsMap { - required string key = 1; - required NetDef value = 2; -} - -message PlansMap { - required string key = 1; - required PlanDef value = 2; -} - -message StringMap { - required string key = 1; - required string value = 2; -} - -message MetaNetDef { - repeated BlobsMap blobs = 1; - // Text-format serialized NetDefs. - repeated NetsMap nets = 2; - // Info about where the model comes from. Possible use cases: - // 1) sanity check or diagnose - // 2) provide info for evaluation. - optional ModelInfo modelInfo = 3; - repeated PlansMap plans = 4; - repeated StringMap applicationSpecificInfo = 5; - repeated string blobsOrder = 6; - repeated string preLoadBlobs = 7; - optional TensorBoundShapes tensorBoundShapes = 8; - repeated string requestOnlyEmbeddings = 9; - optional AOTConfig aotConfig = 10; -} diff --git a/caffe2/proto/metanet_pb2.pyi b/caffe2/proto/metanet_pb2.pyi deleted file mode 100644 index 096fd90df876..000000000000 --- a/caffe2/proto/metanet_pb2.pyi +++ /dev/null @@ -1,160 +0,0 @@ -""" -@generated by mypy-protobuf. Do not edit manually! -isort:skip_file -""" -import builtins -import caffe2.proto.caffe2_pb2 -import google.protobuf.descriptor -import google.protobuf.internal.containers -import google.protobuf.message -import typing -import typing_extensions - -DESCRIPTOR: google.protobuf.descriptor.FileDescriptor = ... - -class ModelInfo(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - PROJECT_FIELD_NUMBER: int - MODELCLASS_FIELD_NUMBER: int - VERSION_FIELD_NUMBER: int - PREDICTORTYPE_FIELD_NUMBER: int - MODELID_FIELD_NUMBER: int - project: typing.Text = ... - modelClass: typing.Text = ... - version: typing.Text = ... - predictorType: typing.Text = ... - modelId: typing.Text = ... - - def __init__(self, - *, - project : typing.Optional[typing.Text] = ..., - modelClass : typing.Optional[typing.Text] = ..., - version : typing.Optional[typing.Text] = ..., - predictorType : typing.Optional[typing.Text] = ..., - modelId : typing.Optional[typing.Text] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"modelClass",b"modelClass",u"modelId",b"modelId",u"predictorType",b"predictorType",u"project",b"project",u"version",b"version"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"modelClass",b"modelClass",u"modelId",b"modelId",u"predictorType",b"predictorType",u"project",b"project",u"version",b"version"]) -> None: ... -global___ModelInfo = ModelInfo - -class BlobsMap(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - KEY_FIELD_NUMBER: int - VALUE_FIELD_NUMBER: int - key: typing.Text = ... - value: google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text] = ... - - def __init__(self, - *, - key : typing.Optional[typing.Text] = ..., - value : typing.Optional[typing.Iterable[typing.Text]] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"key",b"key"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"key",b"key",u"value",b"value"]) -> None: ... -global___BlobsMap = BlobsMap - -class NetsMap(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - KEY_FIELD_NUMBER: int - VALUE_FIELD_NUMBER: int - key: typing.Text = ... - - @property - def value(self) -> caffe2.proto.caffe2_pb2.NetDef: ... - - def __init__(self, - *, - key : typing.Optional[typing.Text] = ..., - value : typing.Optional[caffe2.proto.caffe2_pb2.NetDef] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"key",b"key",u"value",b"value"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"key",b"key",u"value",b"value"]) -> None: ... -global___NetsMap = NetsMap - -class PlansMap(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - KEY_FIELD_NUMBER: int - VALUE_FIELD_NUMBER: int - key: typing.Text = ... - - @property - def value(self) -> caffe2.proto.caffe2_pb2.PlanDef: ... - - def __init__(self, - *, - key : typing.Optional[typing.Text] = ..., - value : typing.Optional[caffe2.proto.caffe2_pb2.PlanDef] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"key",b"key",u"value",b"value"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"key",b"key",u"value",b"value"]) -> None: ... -global___PlansMap = PlansMap - -class StringMap(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - KEY_FIELD_NUMBER: int - VALUE_FIELD_NUMBER: int - key: typing.Text = ... - value: typing.Text = ... - - def __init__(self, - *, - key : typing.Optional[typing.Text] = ..., - value : typing.Optional[typing.Text] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"key",b"key",u"value",b"value"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"key",b"key",u"value",b"value"]) -> None: ... -global___StringMap = StringMap - -class MetaNetDef(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - BLOBS_FIELD_NUMBER: int - NETS_FIELD_NUMBER: int - MODELINFO_FIELD_NUMBER: int - PLANS_FIELD_NUMBER: int - APPLICATIONSPECIFICINFO_FIELD_NUMBER: int - BLOBSORDER_FIELD_NUMBER: int - PRELOADBLOBS_FIELD_NUMBER: int - TENSORBOUNDSHAPES_FIELD_NUMBER: int - REQUESTONLYEMBEDDINGS_FIELD_NUMBER: int - AOTCONFIG_FIELD_NUMBER: int - blobsOrder: google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text] = ... - preLoadBlobs: google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text] = ... - requestOnlyEmbeddings: google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text] = ... - - @property - def blobs(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___BlobsMap]: ... - - @property - def nets(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___NetsMap]: ... - - @property - def modelInfo(self) -> global___ModelInfo: ... - - @property - def plans(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___PlansMap]: ... - - @property - def applicationSpecificInfo(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___StringMap]: ... - - @property - def tensorBoundShapes(self) -> caffe2.proto.caffe2_pb2.TensorBoundShapes: ... - - @property - def aotConfig(self) -> caffe2.proto.caffe2_pb2.AOTConfig: ... - - def __init__(self, - *, - blobs : typing.Optional[typing.Iterable[global___BlobsMap]] = ..., - nets : typing.Optional[typing.Iterable[global___NetsMap]] = ..., - modelInfo : typing.Optional[global___ModelInfo] = ..., - plans : typing.Optional[typing.Iterable[global___PlansMap]] = ..., - applicationSpecificInfo : typing.Optional[typing.Iterable[global___StringMap]] = ..., - blobsOrder : typing.Optional[typing.Iterable[typing.Text]] = ..., - preLoadBlobs : typing.Optional[typing.Iterable[typing.Text]] = ..., - tensorBoundShapes : typing.Optional[caffe2.proto.caffe2_pb2.TensorBoundShapes] = ..., - requestOnlyEmbeddings : typing.Optional[typing.Iterable[typing.Text]] = ..., - aotConfig : typing.Optional[caffe2.proto.caffe2_pb2.AOTConfig] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"aotConfig",b"aotConfig",u"modelInfo",b"modelInfo",u"tensorBoundShapes",b"tensorBoundShapes"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"aotConfig",b"aotConfig",u"applicationSpecificInfo",b"applicationSpecificInfo",u"blobs",b"blobs",u"blobsOrder",b"blobsOrder",u"modelInfo",b"modelInfo",u"nets",b"nets",u"plans",b"plans",u"preLoadBlobs",b"preLoadBlobs",u"requestOnlyEmbeddings",b"requestOnlyEmbeddings",u"tensorBoundShapes",b"tensorBoundShapes"]) -> None: ... -global___MetaNetDef = MetaNetDef diff --git a/caffe2/proto/predictor_consts.proto b/caffe2/proto/predictor_consts.proto deleted file mode 100644 index d45ecb8396c7..000000000000 --- a/caffe2/proto/predictor_consts.proto +++ /dev/null @@ -1,36 +0,0 @@ -syntax = "proto2"; - -package caffe2; - -message PredictorConsts { - // Important - to ensure ordered traversal of the DB, these must be - // set in the given (lexicographic) order in the input DBReader. - optional string META_NET_DEF = 1 [ default = "!!META_NET_DEF" ]; - - // The key the Predictor sets in the global workspace for DBReader - // consumed by the LoadOp in GLOBAL_INIT_NET. - - optional string PREDICTOR_DBREADER = 2 [ default = "!!PREDICTOR_DBREADER" ]; - - // Blob types used in MetaNetDef blobs - optional string PARAMETERS_BLOB_TYPE = 3 [ default = "PARAMETERS_BLOB_TYPE" ]; - optional string INPUTS_BLOB_TYPE = 4 [ default = "INPUTS_BLOB_TYPE" ]; - optional string OUTPUTS_BLOB_TYPE = 5 [ default = "OUTPUTS_BLOB_TYPE" ]; - - // Net types used in MetaNetDef nets - optional string GLOBAL_INIT_NET_TYPE = 6 [ default = "GLOBAL_INIT_NET_TYPE" ]; - optional string PREDICT_INIT_NET_TYPE = 7 - [ default = "PREDICT_INIT_NET_TYPE" ]; - optional string PREDICT_NET_TYPE = 8 [ default = "PREDICT_NET_TYPE" ]; - optional string SINGLE_PREDICTOR = 9 [ default = "SINGLE_PREDICTOR" ]; - optional string MULTI_PREDICTOR = 10 [ default = "MULTI_PREDICTOR" ]; - optional string TRAIN_INIT_PLAN_TYPE = 11 - [ default = "TRAIN_INIT_PLAN_TYPE" ]; - optional string TRAIN_PLAN_TYPE = 12 [ default = "TRAIN_PLAN_TYPE" ]; - - // Shape info blob name - optional string SHAPE_INFO_BLOB = 13 [ default = "SHAPE_INFO_BLOB" ]; - // Sequential blob reader name - optional string DEFERRED_BLOB_READER = 14 - [ default = "__DEFERRED_BLOB_READER__" ]; -} diff --git a/caffe2/proto/predictor_consts_pb2.pyi b/caffe2/proto/predictor_consts_pb2.pyi deleted file mode 100644 index 83b62ae0e949..000000000000 --- a/caffe2/proto/predictor_consts_pb2.pyi +++ /dev/null @@ -1,63 +0,0 @@ -""" -@generated by mypy-protobuf. Do not edit manually! -isort:skip_file -""" -import builtins -import google.protobuf.descriptor -import google.protobuf.message -import typing -import typing_extensions - -DESCRIPTOR: google.protobuf.descriptor.FileDescriptor = ... - -class PredictorConsts(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - META_NET_DEF_FIELD_NUMBER: int - PREDICTOR_DBREADER_FIELD_NUMBER: int - PARAMETERS_BLOB_TYPE_FIELD_NUMBER: int - INPUTS_BLOB_TYPE_FIELD_NUMBER: int - OUTPUTS_BLOB_TYPE_FIELD_NUMBER: int - GLOBAL_INIT_NET_TYPE_FIELD_NUMBER: int - PREDICT_INIT_NET_TYPE_FIELD_NUMBER: int - PREDICT_NET_TYPE_FIELD_NUMBER: int - SINGLE_PREDICTOR_FIELD_NUMBER: int - MULTI_PREDICTOR_FIELD_NUMBER: int - TRAIN_INIT_PLAN_TYPE_FIELD_NUMBER: int - TRAIN_PLAN_TYPE_FIELD_NUMBER: int - SHAPE_INFO_BLOB_FIELD_NUMBER: int - DEFERRED_BLOB_READER_FIELD_NUMBER: int - META_NET_DEF: typing.Text = ... - PREDICTOR_DBREADER: typing.Text = ... - PARAMETERS_BLOB_TYPE: typing.Text = ... - INPUTS_BLOB_TYPE: typing.Text = ... - OUTPUTS_BLOB_TYPE: typing.Text = ... - GLOBAL_INIT_NET_TYPE: typing.Text = ... - PREDICT_INIT_NET_TYPE: typing.Text = ... - PREDICT_NET_TYPE: typing.Text = ... - SINGLE_PREDICTOR: typing.Text = ... - MULTI_PREDICTOR: typing.Text = ... - TRAIN_INIT_PLAN_TYPE: typing.Text = ... - TRAIN_PLAN_TYPE: typing.Text = ... - SHAPE_INFO_BLOB: typing.Text = ... - DEFERRED_BLOB_READER: typing.Text = ... - - def __init__(self, - *, - META_NET_DEF : typing.Optional[typing.Text] = ..., - PREDICTOR_DBREADER : typing.Optional[typing.Text] = ..., - PARAMETERS_BLOB_TYPE : typing.Optional[typing.Text] = ..., - INPUTS_BLOB_TYPE : typing.Optional[typing.Text] = ..., - OUTPUTS_BLOB_TYPE : typing.Optional[typing.Text] = ..., - GLOBAL_INIT_NET_TYPE : typing.Optional[typing.Text] = ..., - PREDICT_INIT_NET_TYPE : typing.Optional[typing.Text] = ..., - PREDICT_NET_TYPE : typing.Optional[typing.Text] = ..., - SINGLE_PREDICTOR : typing.Optional[typing.Text] = ..., - MULTI_PREDICTOR : typing.Optional[typing.Text] = ..., - TRAIN_INIT_PLAN_TYPE : typing.Optional[typing.Text] = ..., - TRAIN_PLAN_TYPE : typing.Optional[typing.Text] = ..., - SHAPE_INFO_BLOB : typing.Optional[typing.Text] = ..., - DEFERRED_BLOB_READER : typing.Optional[typing.Text] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"DEFERRED_BLOB_READER",b"DEFERRED_BLOB_READER",u"GLOBAL_INIT_NET_TYPE",b"GLOBAL_INIT_NET_TYPE",u"INPUTS_BLOB_TYPE",b"INPUTS_BLOB_TYPE",u"META_NET_DEF",b"META_NET_DEF",u"MULTI_PREDICTOR",b"MULTI_PREDICTOR",u"OUTPUTS_BLOB_TYPE",b"OUTPUTS_BLOB_TYPE",u"PARAMETERS_BLOB_TYPE",b"PARAMETERS_BLOB_TYPE",u"PREDICTOR_DBREADER",b"PREDICTOR_DBREADER",u"PREDICT_INIT_NET_TYPE",b"PREDICT_INIT_NET_TYPE",u"PREDICT_NET_TYPE",b"PREDICT_NET_TYPE",u"SHAPE_INFO_BLOB",b"SHAPE_INFO_BLOB",u"SINGLE_PREDICTOR",b"SINGLE_PREDICTOR",u"TRAIN_INIT_PLAN_TYPE",b"TRAIN_INIT_PLAN_TYPE",u"TRAIN_PLAN_TYPE",b"TRAIN_PLAN_TYPE"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"DEFERRED_BLOB_READER",b"DEFERRED_BLOB_READER",u"GLOBAL_INIT_NET_TYPE",b"GLOBAL_INIT_NET_TYPE",u"INPUTS_BLOB_TYPE",b"INPUTS_BLOB_TYPE",u"META_NET_DEF",b"META_NET_DEF",u"MULTI_PREDICTOR",b"MULTI_PREDICTOR",u"OUTPUTS_BLOB_TYPE",b"OUTPUTS_BLOB_TYPE",u"PARAMETERS_BLOB_TYPE",b"PARAMETERS_BLOB_TYPE",u"PREDICTOR_DBREADER",b"PREDICTOR_DBREADER",u"PREDICT_INIT_NET_TYPE",b"PREDICT_INIT_NET_TYPE",u"PREDICT_NET_TYPE",b"PREDICT_NET_TYPE",u"SHAPE_INFO_BLOB",b"SHAPE_INFO_BLOB",u"SINGLE_PREDICTOR",b"SINGLE_PREDICTOR",u"TRAIN_INIT_PLAN_TYPE",b"TRAIN_INIT_PLAN_TYPE",u"TRAIN_PLAN_TYPE",b"TRAIN_PLAN_TYPE"]) -> None: ... -global___PredictorConsts = PredictorConsts diff --git a/caffe2/proto/prof_dag.proto b/caffe2/proto/prof_dag.proto deleted file mode 100644 index ab427a1c66fa..000000000000 --- a/caffe2/proto/prof_dag.proto +++ /dev/null @@ -1,68 +0,0 @@ -syntax = "proto2"; - -package caffe2; - -// A few notes about the Caffe2's protobuffer convention: -// (1) Most objects are registered by their types, such as operators and nets. -// For these, we have a string-type field "type" for registration purposes. -// (2) We do not use extension because that used to create quite some conflicts -// in Caffe's protobuf design. -// (3) We have not used any proto3 specific features, such as Any or Map. This -// is mainly for backward compatibility purposes but we may consider using -// those in the future. - -// A two number summary for a value. It also has count for restoring. -message TwoNumberStatsProto { - optional float mean = 1; - optional float stddev = 2; - optional int64 count = 3; -} - -// Blob profiling information. Profile for a blob is created every time -// a node outputs to the blob. -message BlobProfile { - // Name of the blob (corresponds to OperatorDef.output). - optional string name = 1; // required - - // Profiling statistics. - optional TwoNumberStatsProto bytes_used = 3; -} - -// Protobuf format to serialize profiler data. -message ProfDAGProto { - // The name for the operator - required string name = 1; - // The mean execution time - required float mean = 2; - // The standard deviation - required float stddev = 3; - - // New field to represent the numbers above, and with count. - optional TwoNumberStatsProto execution_time = 4; - - // Blob profiles that this node outputs. - repeated BlobProfile output_profile = 5; - - // The extra_info from the operator device option. - repeated string extra_info = 7; -} - -// Operator profiling information. -// -// Note: The indices for elements of 'stats' and the indices of -// 'output_profile' inside each 'stats' are assumed to match the -// indices of 'op' elements of a corresponding NetDef and the 'output' -// indices within each 'op'. -message ProfDAGProtos { - repeated ProfDAGProto stats = 1; - optional string net_name = 2; - repeated OpProfile ops_stats = 3; -} - -// Represents specification of an operation cost. -message OpProfile { - optional string idx = 1; - optional string net_name = 2; - optional string type = 3; - optional float exec_time_secs = 4; -} diff --git a/caffe2/proto/prof_dag_pb2.pyi b/caffe2/proto/prof_dag_pb2.pyi deleted file mode 100644 index 98affd51fd0b..000000000000 --- a/caffe2/proto/prof_dag_pb2.pyi +++ /dev/null @@ -1,126 +0,0 @@ -""" -@generated by mypy-protobuf. Do not edit manually! -isort:skip_file -""" -import builtins -import google.protobuf.descriptor -import google.protobuf.internal.containers -import google.protobuf.message -import typing -import typing_extensions - -DESCRIPTOR: google.protobuf.descriptor.FileDescriptor = ... - -class TwoNumberStatsProto(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - MEAN_FIELD_NUMBER: int - STDDEV_FIELD_NUMBER: int - COUNT_FIELD_NUMBER: int - mean: float = ... - stddev: float = ... - count: int = ... - - def __init__(self, - *, - mean : typing.Optional[float] = ..., - stddev : typing.Optional[float] = ..., - count : typing.Optional[int] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"count",b"count",u"mean",b"mean",u"stddev",b"stddev"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"count",b"count",u"mean",b"mean",u"stddev",b"stddev"]) -> None: ... -global___TwoNumberStatsProto = TwoNumberStatsProto - -class BlobProfile(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - NAME_FIELD_NUMBER: int - BYTES_USED_FIELD_NUMBER: int - name: typing.Text = ... - - @property - def bytes_used(self) -> global___TwoNumberStatsProto: ... - - def __init__(self, - *, - name : typing.Optional[typing.Text] = ..., - bytes_used : typing.Optional[global___TwoNumberStatsProto] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"bytes_used",b"bytes_used",u"name",b"name"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"bytes_used",b"bytes_used",u"name",b"name"]) -> None: ... -global___BlobProfile = BlobProfile - -class ProfDAGProto(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - NAME_FIELD_NUMBER: int - MEAN_FIELD_NUMBER: int - STDDEV_FIELD_NUMBER: int - EXECUTION_TIME_FIELD_NUMBER: int - OUTPUT_PROFILE_FIELD_NUMBER: int - EXTRA_INFO_FIELD_NUMBER: int - name: typing.Text = ... - mean: float = ... - stddev: float = ... - extra_info: google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text] = ... - - @property - def execution_time(self) -> global___TwoNumberStatsProto: ... - - @property - def output_profile(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___BlobProfile]: ... - - def __init__(self, - *, - name : typing.Optional[typing.Text] = ..., - mean : typing.Optional[float] = ..., - stddev : typing.Optional[float] = ..., - execution_time : typing.Optional[global___TwoNumberStatsProto] = ..., - output_profile : typing.Optional[typing.Iterable[global___BlobProfile]] = ..., - extra_info : typing.Optional[typing.Iterable[typing.Text]] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"execution_time",b"execution_time",u"mean",b"mean",u"name",b"name",u"stddev",b"stddev"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"execution_time",b"execution_time",u"extra_info",b"extra_info",u"mean",b"mean",u"name",b"name",u"output_profile",b"output_profile",u"stddev",b"stddev"]) -> None: ... -global___ProfDAGProto = ProfDAGProto - -class ProfDAGProtos(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - STATS_FIELD_NUMBER: int - NET_NAME_FIELD_NUMBER: int - OPS_STATS_FIELD_NUMBER: int - net_name: typing.Text = ... - - @property - def stats(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ProfDAGProto]: ... - - @property - def ops_stats(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___OpProfile]: ... - - def __init__(self, - *, - stats : typing.Optional[typing.Iterable[global___ProfDAGProto]] = ..., - net_name : typing.Optional[typing.Text] = ..., - ops_stats : typing.Optional[typing.Iterable[global___OpProfile]] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"net_name",b"net_name"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"net_name",b"net_name",u"ops_stats",b"ops_stats",u"stats",b"stats"]) -> None: ... -global___ProfDAGProtos = ProfDAGProtos - -class OpProfile(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - IDX_FIELD_NUMBER: int - NET_NAME_FIELD_NUMBER: int - TYPE_FIELD_NUMBER: int - EXEC_TIME_SECS_FIELD_NUMBER: int - idx: typing.Text = ... - net_name: typing.Text = ... - type: typing.Text = ... - exec_time_secs: float = ... - - def __init__(self, - *, - idx : typing.Optional[typing.Text] = ..., - net_name : typing.Optional[typing.Text] = ..., - type : typing.Optional[typing.Text] = ..., - exec_time_secs : typing.Optional[float] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"exec_time_secs",b"exec_time_secs",u"idx",b"idx",u"net_name",b"net_name",u"type",b"type"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"exec_time_secs",b"exec_time_secs",u"idx",b"idx",u"net_name",b"net_name",u"type",b"type"]) -> None: ... -global___OpProfile = OpProfile From 38a33c3202a7098019504abd8220797edc4f5fe4 Mon Sep 17 00:00:00 2001 From: JackCaoG Date: Wed, 29 May 2024 20:37:23 +0000 Subject: [PATCH 059/706] don't call .item in onehot for XLA (#127335) We found that `nn.function.one_hot` will cause a graph break due to the item call in the native implementation. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127335 Approved by: https://github.com/ezyang --- aten/src/ATen/native/Onehot.cpp | 4 ++-- test/test_nn.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/native/Onehot.cpp b/aten/src/ATen/native/Onehot.cpp index 97c35599f791..ffd19b2e93a9 100644 --- a/aten/src/ATen/native/Onehot.cpp +++ b/aten/src/ATen/native/Onehot.cpp @@ -43,7 +43,7 @@ Tensor one_hot(const Tensor &self, int64_t num_classes) { // non-empty tensor if (self.device().type() != at::kCUDA && self.device().type() != at::kMPS && - self.device().type() != at::kPrivateUse1) { + self.device().type() != at::kPrivateUse1 && self.device().type() != at::kXLA) { // for cuda, rely on device assert thrown by scatter TORCH_CHECK(self.min().item().toLong() >= 0, "Class values must be non-negative."); } @@ -51,7 +51,7 @@ Tensor one_hot(const Tensor &self, int64_t num_classes) { num_classes = self.max().item().toLong() + 1; } else { if (self.device().type() != at::kCUDA && self.device().type() != at::kMPS && - self.device().type() != at::kPrivateUse1) { + self.device().type() != at::kPrivateUse1 && self.device().type() != at::kXLA) { // rely on device asserts from scatter to avoid sync here TORCH_CHECK(num_classes > self.max().item().toLong(), "Class values must be smaller than num_classes."); } else { diff --git a/test/test_nn.py b/test/test_nn.py index d49c9bc1eec4..c6f3a61e972b 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -8936,7 +8936,9 @@ def test_linear_empty(self, device): _test_module_empty_input(self, mod, inp) def test_one_hot(self, device): - if self.device_type != 'cuda': # cuda throws device assert for invalid data + # cuda throws device assert for invalid data + # xla ignores out of bound indices + if self.device_type != 'cuda' and self.device_type != 'xla': with self.assertRaises(RuntimeError): torch.nn.functional.one_hot(torch.tensor([3, 4, -1, 0], device=device), -1) From d99b115eb3a824c92468b6a1ab583bafd2ab2e46 Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Wed, 29 May 2024 21:17:09 +0000 Subject: [PATCH 060/706] Fix delete old branches workflow (#127442) The ubuntu runner started using 2.45.1 (prev 2.43.2), which includes https://github.com/git/git/commit/1f49f7506f0d840e048ba78f7e0544c407568b58 (changes +00:00 timezone to Z) Python versions prior to 3.11 do not support Z when parsing isoformat, so update the workflow to use 3.11 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127442 Approved by: https://github.com/huydhn, https://github.com/malfet --- .github/workflows/delete_old_branches.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/delete_old_branches.yml b/.github/workflows/delete_old_branches.yml index 04a0521419a8..eabb98e32065 100644 --- a/.github/workflows/delete_old_branches.yml +++ b/.github/workflows/delete_old_branches.yml @@ -29,7 +29,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v4 with: - python-version: '3.8' + python-version: '3.11' architecture: x64 check-latest: false From b0ef363972203b163cddc95e4c6054b8221c2300 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Wed, 29 May 2024 10:49:27 -0700 Subject: [PATCH 061/706] [dtensor] rename _Partial -> Partial for all imports (#127420) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127420 Approved by: https://github.com/awgu --- test/distributed/_tensor/test_dtensor.py | 8 ++++---- test/distributed/_tensor/test_dtensor_compile.py | 10 +++++----- test/distributed/_tensor/test_math_ops.py | 2 +- test/distributed/_tensor/test_matrix_ops.py | 10 +++++----- test/distributed/_tensor/test_op_strategy.py | 8 ++++---- test/distributed/_tensor/test_pointwise_ops.py | 12 ++++++------ test/distributed/_tensor/test_redistribute.py | 14 +++++++------- test/distributed/_tensor/test_tensor_ops.py | 14 +++++++------- torch/distributed/_tensor/_redistribute.py | 8 ++++---- torch/distributed/_tensor/_utils.py | 4 ++-- torch/distributed/_tensor/api.py | 8 ++++---- torch/distributed/_tensor/ops/basic_strategy.py | 8 ++++---- torch/distributed/_tensor/ops/embedding_ops.py | 10 +++++----- torch/distributed/_tensor/ops/math_ops.py | 8 ++++---- torch/distributed/_tensor/ops/pointwise_ops.py | 4 ++-- torch/distributed/_tensor/ops/random_ops.py | 2 +- torch/distributed/_tensor/ops/tensor_ops.py | 14 +++++++------- torch/distributed/_tensor/ops/utils.py | 4 ++-- torch/distributed/_tensor/ops/view_ops.py | 2 +- torch/distributed/_tensor/placement_types.py | 14 +++++++------- 20 files changed, 82 insertions(+), 82 deletions(-) diff --git a/test/distributed/_tensor/test_dtensor.py b/test/distributed/_tensor/test_dtensor.py index 70a67b8e0b93..531245057e1f 100644 --- a/test/distributed/_tensor/test_dtensor.py +++ b/test/distributed/_tensor/test_dtensor.py @@ -14,7 +14,7 @@ init_device_mesh, ) from torch.distributed._tensor.debug import CommDebugMode -from torch.distributed._tensor.placement_types import _Partial, Replicate, Shard +from torch.distributed._tensor.placement_types import Partial, Replicate, Shard from torch.distributed.tensor.parallel import ( ColwiseParallel, parallelize_module, @@ -174,7 +174,7 @@ def test_from_local(self): ddp_tensor = DTensor.from_local(local_tensor, device_mesh, replica_spec) self.assertEqual(ddp_tensor.size(), local_tensor.size()) - partial_spec = [_Partial()] + partial_spec = [Partial()] partial_tensor = DTensor.from_local(local_tensor, device_mesh, partial_spec) self.assertEqual(partial_tensor.size(), local_tensor.size()) @@ -330,7 +330,7 @@ def test_to_local_grad_hint(self): with comm_mode: local_out = sharded_dtensor.redistribute(placements=[Replicate()]).to_local( - grad_placements=[_Partial()] + grad_placements=[Partial()] ) local_out.backward(torch.ones_like(local_out)) @@ -362,7 +362,7 @@ def test_full_tensor_grad_hint(self): global_tensor = torch.ones(8, 3, requires_grad=True) sharded_dtensor = distribute_tensor(global_tensor, device_mesh, placements) - local_out = sharded_dtensor.full_tensor(grad_placements=[_Partial()]) + local_out = sharded_dtensor.full_tensor(grad_placements=[Partial()]) local_out.sum().backward() replica_grad = sharded_dtensor.grad.full_tensor() diff --git a/test/distributed/_tensor/test_dtensor_compile.py b/test/distributed/_tensor/test_dtensor_compile.py index f40cb4999858..325d18be79f3 100644 --- a/test/distributed/_tensor/test_dtensor_compile.py +++ b/test/distributed/_tensor/test_dtensor_compile.py @@ -17,10 +17,10 @@ DeviceMesh, DTensor, init_device_mesh, + Partial, Replicate, Shard, ) -from torch.distributed._tensor.placement_types import _Partial from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( checkpoint_wrapper, CheckpointImpl, @@ -121,7 +121,7 @@ def fn(x): compiled_fn = torch.compile(backend="aot_eager", fullgraph=True)(fn) - for x in [Shard(0), Replicate(), _Partial()]: + for x in [Shard(0), Replicate(), Partial()]: opt_fn = fn(x) compiled_out = compiled_fn(x) self.assertEqual(opt_fn, compiled_out) @@ -313,7 +313,7 @@ def fn(x): x_dt = DTensor.from_local( x, mesh, - [_Partial()], + [Partial()], run_check=False, shape=(10, 257, 160), stride=(41120, 160, 1), @@ -354,7 +354,7 @@ def fn(x): x_dt = DTensor.from_local( x, mesh, - [_Partial()], + [Partial()], run_check=False, shape=(10, 257, 160), stride=(41120, 160, 1), @@ -515,7 +515,7 @@ def fn(x): return x + x x = torch.randn(4, 4, requires_grad=True) - x_dt = DTensor.from_local(x, mesh, [_Partial()], run_check=False) + x_dt = DTensor.from_local(x, mesh, [Partial()], run_check=False) y = torch.randn(4, 4, requires_grad=True) y_dt = DTensor.from_local(y, mesh, [Replicate()], run_check=False) diff --git a/test/distributed/_tensor/test_math_ops.py b/test/distributed/_tensor/test_math_ops.py index 6f8015bfd0a4..d2ea73ae8c87 100644 --- a/test/distributed/_tensor/test_math_ops.py +++ b/test/distributed/_tensor/test_math_ops.py @@ -371,7 +371,7 @@ def _replicate_fn(name, module, device_mesh): if elementwise_affine: # if input is sharded on any outer dimension, the gradient of weight - # and bias should be _Partial + # and bias should be Partial dim_map = x_dist._spec.dim_map outer_dims = range(norm_idx) needs_reduction = any(dim_map[d] >= 0 for d in outer_dims) diff --git a/test/distributed/_tensor/test_matrix_ops.py b/test/distributed/_tensor/test_matrix_ops.py index fa3f9272c63e..7889ed46ca5e 100644 --- a/test/distributed/_tensor/test_matrix_ops.py +++ b/test/distributed/_tensor/test_matrix_ops.py @@ -10,7 +10,7 @@ from torch.distributed._tensor.api import DTensor from torch.distributed._tensor.debug import CommDebugMode from torch.distributed._tensor.placement_types import ( - _Partial, + Partial, Placement, Replicate, Shard, @@ -77,7 +77,7 @@ def test_addmm_auto_redistribute(self): # test if addmm output is a partial self.assertIsInstance(dist_res, DTensor) - self.assertIsInstance(dist_res.placements[0], _Partial) + self.assertIsInstance(dist_res.placements[0], Partial) # test if result is the same as tensor dist_local_res = dist_res.full_tensor() @@ -144,11 +144,11 @@ def test_t_partial(self): da = distribute_tensor(a, device_mesh, [Shard(1)]) db = distribute_tensor(b, device_mesh, [Shard(0)]) - # mm(da, db) should return a _Partial tensor. - # transposing it should keep it _Partial + # mm(da, db) should return a Partial tensor. + # transposing it should keep it Partial dc = torch.mm(da, db).t() - self.assertTrue(isinstance(dc.placements[0], _Partial)) + self.assertTrue(isinstance(dc.placements[0], Partial)) # check that the local and distributed op results match self.assertEqual( diff --git a/test/distributed/_tensor/test_op_strategy.py b/test/distributed/_tensor/test_op_strategy.py index 0cb469e1c405..5194d5bf7d89 100644 --- a/test/distributed/_tensor/test_op_strategy.py +++ b/test/distributed/_tensor/test_op_strategy.py @@ -11,8 +11,8 @@ gen_einsum_strategies, ) from torch.distributed._tensor.placement_types import ( - _Partial, DTensorSpec, + Partial, Replicate, Shard, TensorMeta, @@ -139,7 +139,7 @@ def test_redistribute_cost_mesh_1d(self): mesh_1d = self.build_device_mesh() shard_placement = (Shard(0),) replica_placement = (Replicate(),) - partial_placement = (_Partial(),) + partial_placement = (Partial(),) global_tensor = torch.randn(10, 10) global_tensor_meta = self._extract_tensor_meta(global_tensor) @@ -174,7 +174,7 @@ def test_redistribute_cost_latency(self): mesh = self.build_device_mesh() shard0_placement = (Shard(0),) - partial_placement = (_Partial(),) + partial_placement = (Partial(),) shard1_placement = (Shard(1),) shard0_tensor_meta = self._extract_tensor_meta(torch.randn(8)) @@ -220,7 +220,7 @@ def test_redistribute_cost_mesh_2d(self): ) shard_placement = (Shard(0), Shard(0)) replica_placement = (Replicate(), Replicate()) - partial_placement = (_Partial(), _Partial()) + partial_placement = (Partial(), Partial()) global_tensor = torch.randn(8, 8) global_tensor_meta = self._extract_tensor_meta(global_tensor) diff --git a/test/distributed/_tensor/test_pointwise_ops.py b/test/distributed/_tensor/test_pointwise_ops.py index 4b25efdc9105..f0103bad2de6 100644 --- a/test/distributed/_tensor/test_pointwise_ops.py +++ b/test/distributed/_tensor/test_pointwise_ops.py @@ -11,7 +11,7 @@ from torch.distributed._tensor import DeviceMesh, distribute_tensor, DTensor from torch.distributed._tensor.placement_types import ( - _Partial, + Partial, Placement, Replicate, Shard, @@ -141,15 +141,15 @@ def _run_sharded_elementwise_ops( def test_partial_add(self): device_mesh = self.build_device_mesh() - d_1 = DTensor.from_local(torch.rand(2, 2), device_mesh, [_Partial()]) - d_2 = DTensor.from_local(torch.rand(2, 2), device_mesh, [_Partial()]) + d_1 = DTensor.from_local(torch.rand(2, 2), device_mesh, [Partial()]) + d_2 = DTensor.from_local(torch.rand(2, 2), device_mesh, [Partial()]) d_3 = d_1 + d_2 self.assertTrue(d_3._spec.placements[0].is_partial()) def test_partial_mul(self): device_mesh = self.build_device_mesh() - d_1 = DTensor.from_local(torch.ones(2, 2), device_mesh, [_Partial()]) - d_2 = DTensor.from_local(torch.ones(2, 2), device_mesh, [_Partial()]) + d_1 = DTensor.from_local(torch.ones(2, 2), device_mesh, [Partial()]) + d_2 = DTensor.from_local(torch.ones(2, 2), device_mesh, [Partial()]) d_3 = d_1 * d_2 self.assertTrue(d_3._spec.placements[0].is_replicate()) self.assertEqual(d_3.to_local(), torch.ones(2, 2) * (self.world_size**2)) @@ -256,7 +256,7 @@ def test_dropout_errors(self): with self.assertRaisesRegex(RuntimeError, "supported"): self._run_sharded_elementwise_ops( device_mesh=device_mesh, - placements=[_Partial("sum")], + placements=[Partial("sum")], input_size=(8, 5), op=torch.nn.functional.dropout, ) diff --git a/test/distributed/_tensor/test_redistribute.py b/test/distributed/_tensor/test_redistribute.py index c97682b606c7..1d2673a6a7bc 100644 --- a/test/distributed/_tensor/test_redistribute.py +++ b/test/distributed/_tensor/test_redistribute.py @@ -6,7 +6,7 @@ import torch from torch.distributed._tensor import DeviceMesh, distribute_tensor, DTensor from torch.distributed._tensor.debug import CommDebugMode -from torch.distributed._tensor.placement_types import _Partial, Replicate, Shard +from torch.distributed._tensor.placement_types import Partial, Replicate, Shard from torch.testing._internal.common_utils import run_tests @@ -105,7 +105,7 @@ def test_replicate_to_local_partial_grad(self): with comm_mode: out = replica_tensor.redistribute(placements=[Replicate()]).to_local( - grad_placements=[_Partial()] + grad_placements=[Partial()] ) out.backward(torch.ones_like(out)) @@ -168,7 +168,7 @@ def test_partial_to_replicate_forward_backward(self): # backward should work as expected device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) partial_local = torch.ones(12, 3, device=self.device_type, requires_grad=True) - partial_spec = [_Partial()] + partial_spec = [Partial()] replica_spec = [Replicate()] comm_mode = CommDebugMode() @@ -199,11 +199,11 @@ def test_partial_to_replicate_forward_backward(self): def test_replicate_to_partial(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) local_tensor = torch.randn(12, 3, device=self.device_type, requires_grad=True) - partial_spec = _Partial() + partial_spec = Partial() replica_spec = Replicate() # 1) test replicate -> partial forward replica_tensor = distribute_tensor(local_tensor, device_mesh, [replica_spec]) - with self.assertRaisesRegex(RuntimeError, "Can not redistribute to _Partial"): + with self.assertRaisesRegex(RuntimeError, "Can not redistribute to Partial"): partial_tensor = replica_tensor.redistribute(device_mesh, [partial_spec]) from torch.distributed._tensor._redistribute import Redistribute @@ -246,7 +246,7 @@ def test_replicate_to_partial(self): @with_comms def test_partial_to_shard(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) - partial_spec = [_Partial()] + partial_spec = [Partial()] my_rank = device_mesh.get_rank() input_sizes_and_shard_dim = [ @@ -441,7 +441,7 @@ def test_multi_dim_mesh(self): possibilities = [Replicate()] + [Shard(i) for i in range(full_tensor.ndim)] all_outputs = list(itertools.product(*(mesh_shape.ndim * [possibilities]))) all_inputs = list( - itertools.product(*(mesh_shape.ndim * [possibilities + [_Partial()]])) + itertools.product(*(mesh_shape.ndim * [possibilities + [Partial()]])) ) for inputs in all_inputs: diff --git a/test/distributed/_tensor/test_tensor_ops.py b/test/distributed/_tensor/test_tensor_ops.py index 2d8d726da865..24e527533315 100644 --- a/test/distributed/_tensor/test_tensor_ops.py +++ b/test/distributed/_tensor/test_tensor_ops.py @@ -4,7 +4,7 @@ import torch from torch.distributed._tensor import DeviceMesh, distribute_tensor, DTensor from torch.distributed._tensor.debug import CommDebugMode -from torch.distributed._tensor.placement_types import _Partial, Replicate, Shard +from torch.distributed._tensor.placement_types import Partial, Replicate, Shard from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( @@ -92,7 +92,7 @@ def test_inplace_op(self): # test inplace op self and other dtensor with other specs # and make sure out spec not change shard_spec = [Shard(0)] - partial_spec = [_Partial()] + partial_spec = [Partial()] dt_to_inplace_add = distribute_tensor(input_tensor, mesh, shard_spec) partial_grad = DTensor.from_local(torch.randn(12, 3), mesh, partial_spec) res = dt_to_inplace_add.add_(partial_grad) @@ -168,7 +168,7 @@ def test_ones_like(self): @with_comms def test_ones_like_partial_sum(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) - shard_spec = [_Partial()] + shard_spec = [Partial()] input_tensor = torch.randn(4, 8, requires_grad=True) dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec) @@ -181,7 +181,7 @@ def test_ones_like_partial_sum(self): @with_comms def test_fill_inplace_partial_sum(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) - shard_spec = [_Partial()] + shard_spec = [Partial()] input_tensor = torch.randn(4, 8, requires_grad=True) dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec) @@ -197,7 +197,7 @@ def test_fill_inplace_partial_sum(self): @with_comms def test_zeros_like_partial_sum(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) - shard_spec = [_Partial()] + shard_spec = [Partial()] input_tensor = torch.randn(4, 8, requires_grad=True) dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec) @@ -236,8 +236,8 @@ def test_stack(self): mesh_2d = DeviceMesh( self.device_type, torch.arange(self.world_size).reshape(2, 2) ) - partial_replicate_placement = [_Partial(), Replicate()] - partial_placement = [_Partial(), _Partial()] + partial_replicate_placement = [Partial(), Replicate()] + partial_placement = [Partial(), Partial()] partial_replicate_dt = DTensor.from_local( torch.randn(4, 8), mesh_2d, partial_replicate_placement diff --git a/torch/distributed/_tensor/_redistribute.py b/torch/distributed/_tensor/_redistribute.py index 5cef7dbb047c..b72db29157f8 100644 --- a/torch/distributed/_tensor/_redistribute.py +++ b/torch/distributed/_tensor/_redistribute.py @@ -7,8 +7,8 @@ import torch.distributed._tensor.api as dtensor from torch.distributed._tensor.device_mesh import DeviceMesh from torch.distributed._tensor.placement_types import ( - _Partial, DTensorSpec, + Partial, Placement, Replicate, Shard, @@ -177,7 +177,7 @@ def redistribute_local_tensor( if target.is_replicate(): # Case 1: target is Replicate if current.is_partial(): - partial_spec = cast(_Partial, current) + partial_spec = cast(Partial, current) new_local_tensor = partial_spec._reduce_value( local_tensor, device_mesh, i ) @@ -195,7 +195,7 @@ def redistribute_local_tensor( target_placement = cast(Shard, target) target_dim = target_placement.dim if current.is_partial(): - partial_spec = cast(_Partial, current) + partial_spec = cast(Partial, current) new_local_tensor = partial_spec._reduce_shard_value( local_tensor, device_mesh, i, target_placement ) @@ -219,7 +219,7 @@ def redistribute_local_tensor( ) elif target.is_partial(): if current.is_replicate(): - partial_spec = cast(_Partial, target) + partial_spec = cast(Partial, target) # skip the replicate to partial transformation when we are in backward pass # In this case we keep the grad as replicate, this is because we don't # want to convert the replicated gradients back to partial, although diff --git a/torch/distributed/_tensor/_utils.py b/torch/distributed/_tensor/_utils.py index 08c381dd3d1d..a3cc8ee5a602 100644 --- a/torch/distributed/_tensor/_utils.py +++ b/torch/distributed/_tensor/_utils.py @@ -4,8 +4,8 @@ import torch.distributed._tensor.api as dtensor from torch._prims_common import ShapeType from torch.distributed._tensor.placement_types import ( - _Partial, DTensorSpec, + Partial, Placement, Replicate, Shard, @@ -178,7 +178,7 @@ def compute_global_tensor_info( if i != shard_dim and tensor_stride[i] >= tensor_stride[shard_dim]: # rescale the stride by the shard size tensor_stride[i] = tensor_stride[i] * mesh_dim_size - elif not isinstance(placement, (Replicate, _Partial)): + elif not isinstance(placement, (Replicate, Partial)): raise RuntimeError(f"placement type {type(placement)} not supported!") return tensor_shape, tensor_stride diff --git a/torch/distributed/_tensor/api.py b/torch/distributed/_tensor/api.py index c0c0e1470df5..287d07a7c868 100644 --- a/torch/distributed/_tensor/api.py +++ b/torch/distributed/_tensor/api.py @@ -15,8 +15,8 @@ ) from torch.distributed._tensor._utils import compute_global_tensor_info from torch.distributed._tensor.placement_types import ( - _Partial, DTensorSpec, + Partial, Placement, Replicate, Shard, @@ -275,10 +275,10 @@ def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): ) def __coerce_tangent_metadata__(self): - if not any(isinstance(p, _Partial) for p in self.placements): + if not any(isinstance(p, Partial) for p in self.placements): return self placements = [ - Replicate() if isinstance(p, _Partial) else p for p in self.placements + Replicate() if isinstance(p, Partial) else p for p in self.placements ] return self.redistribute(device_mesh=self.device_mesh, placements=placements) @@ -456,7 +456,7 @@ def redistribute( for i, placement in enumerate(placements): if placement.is_partial(): raise RuntimeError( - "Can not redistribute to _Partial, _Partial is for internal use only!" + "Can not redistribute to Partial, redistributing to Partial is for internal use only!" ) elif isinstance(placement, Shard) and placement.dim < 0: # normalize shard dim to be positive diff --git a/torch/distributed/_tensor/ops/basic_strategy.py b/torch/distributed/_tensor/ops/basic_strategy.py index 6274be44cd67..cc28cc19d370 100644 --- a/torch/distributed/_tensor/ops/basic_strategy.py +++ b/torch/distributed/_tensor/ops/basic_strategy.py @@ -5,8 +5,8 @@ from torch.distributed._tensor._op_schema import OpStrategy, PlacementStrategy from torch.distributed._tensor.placement_types import ( - _Partial, DTensorSpec, + Partial, Placement, Replicate, Shard, @@ -126,7 +126,7 @@ def gen_einsum_strategies( # split contracting dim for contracting_dim in edims.contracting_dims: - placement_list = [_Partial()] + placement_list = [Partial()] for input_dim in input_dims: input_contracting_dim = input_dim.index(contracting_dim) placement_list.append(Shard(input_contracting_dim)) @@ -157,9 +157,9 @@ def gen_einsum_strategies( # linearity strategy if linearity: - linearity_placement_list: List[Placement] = [_Partial()] + linearity_placement_list: List[Placement] = [Partial()] for input_dim in input_dims: - linearity_placement_list.append(_Partial()) + linearity_placement_list.append(Partial()) mesh_dim_strategies.append(linearity_placement_list) all_mesh_dim_strategies.append(mesh_dim_strategies) diff --git a/torch/distributed/_tensor/ops/embedding_ops.py b/torch/distributed/_tensor/ops/embedding_ops.py index e79bdd13cd8c..f861c5fcbd57 100644 --- a/torch/distributed/_tensor/ops/embedding_ops.py +++ b/torch/distributed/_tensor/ops/embedding_ops.py @@ -19,8 +19,8 @@ ) from torch.distributed._tensor.placement_types import ( - _Partial, DTensorSpec, + Partial, Placement, Replicate, Shard, @@ -42,7 +42,7 @@ def materialize_mask(self, mask): def release_mask(self): # TODO: evaluate if we need to release the mask buffer or the buffer - # can just have the same lifetime as the _Partial placement + # can just have the same lifetime as the Partial placement if self.data is None: raise RuntimeError("MaskBuffer has not been materialized") self.data = None @@ -62,7 +62,7 @@ def apply_mask(self, tensor): @dataclass(frozen=True) -class _MaskPartial(_Partial): +class _MaskPartial(Partial): """ A partial mask placement devised for rowwise sharded embedding op, where we need to mask and adjust the indices to the local embedding shard, embedding masking @@ -275,11 +275,11 @@ def embedding_dense_backward_strategy( # batch dim sharding, weight replicated, grad_out/input have same sharding # that can shard on any dim, weight grad partial for input_dim in range(len(indices_shape)): - batch_sharding = [_Partial(), Shard(input_dim), Shard(input_dim)] + batch_sharding = [Partial(), Shard(input_dim), Shard(input_dim)] single_mesh_dim_strategies.append(batch_sharding) # grad_out partial, input replicate, weight grad keep partial - partial_sharding = [_Partial(), _Partial(), Replicate()] + partial_sharding = [Partial(), Partial(), Replicate()] single_mesh_dim_strategies.append(partial_sharding) all_mesh_dim_strategies.append(single_mesh_dim_strategies) diff --git a/torch/distributed/_tensor/ops/math_ops.py b/torch/distributed/_tensor/ops/math_ops.py index 9a02f798f8ac..91d20a9dd8cb 100644 --- a/torch/distributed/_tensor/ops/math_ops.py +++ b/torch/distributed/_tensor/ops/math_ops.py @@ -25,8 +25,8 @@ register_op_strategy, ) from torch.distributed._tensor.placement_types import ( - _Partial, DTensorSpec, + Partial, Placement, Replicate, Shard, @@ -52,7 +52,7 @@ class NormReduction: @dataclass(frozen=True) -class _NormPartial(_Partial): +class _NormPartial(Partial): """ This placement is used for partial vector norm. @@ -229,7 +229,7 @@ def map_placements_after_reduction( """ new_placements: List[Placement] = [] for placement in placements: - if isinstance(placement, (Replicate, _Partial)): + if isinstance(placement, (Replicate, Partial)): new_placements.append(placement) else: assert isinstance(placement, Shard) @@ -247,7 +247,7 @@ def map_placements_after_reduction( def get_placement_from_reduction_op(reduction_op: ReductionOpType) -> Placement: if isinstance(reduction_op, NormReduction): return _NormPartial(norm_type=reduction_op.norm_type) - return _Partial(reduction_op) + return Partial(reduction_op) def common_reduction_strategy( diff --git a/torch/distributed/_tensor/ops/pointwise_ops.py b/torch/distributed/_tensor/ops/pointwise_ops.py index 4a6fc0458119..ab80f783cf5b 100644 --- a/torch/distributed/_tensor/ops/pointwise_ops.py +++ b/torch/distributed/_tensor/ops/pointwise_ops.py @@ -22,8 +22,8 @@ register_op_strategy, ) from torch.distributed._tensor.placement_types import ( - _Partial, DTensorSpec, + Partial, Placement, Replicate, Shard, @@ -460,7 +460,7 @@ def common_pointwise_strategy( common_ndim = len(common_shape) new_shard_dim = common_ndim - len(spec_to_follow.shape) + shard_dim out_placements.append(Shard(new_shard_dim)) - elif isinstance(placement, _Partial) and not linearity: + elif isinstance(placement, Partial) and not linearity: # clear the partial placemnet if op does not support linearity # by default we just replicate the partial, need to see if this # is optimal for all cases diff --git a/torch/distributed/_tensor/ops/random_ops.py b/torch/distributed/_tensor/ops/random_ops.py index 3f33d16cc152..390dc419ecd7 100644 --- a/torch/distributed/_tensor/ops/random_ops.py +++ b/torch/distributed/_tensor/ops/random_ops.py @@ -24,7 +24,7 @@ def random_op_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: arg_spec = arg_strategy.output_spec if is_tensor_partial(arg_spec): # TODO: figure out how inplace random op should behave when it's partial - raise RuntimeError(f"{op_schema.op} with _Partial is not supported yet!") + raise RuntimeError(f"{op_schema.op} with Partial is not supported yet!") random_strategy.strategies.append(PlacementStrategy(output_specs=arg_spec)) return random_strategy diff --git a/torch/distributed/_tensor/ops/tensor_ops.py b/torch/distributed/_tensor/ops/tensor_ops.py index 54a607d58c55..b42fdcc6cc08 100644 --- a/torch/distributed/_tensor/ops/tensor_ops.py +++ b/torch/distributed/_tensor/ops/tensor_ops.py @@ -26,8 +26,8 @@ register_prop_rule, ) from torch.distributed._tensor.placement_types import ( - _Partial, DTensorSpec, + Partial, Placement, Replicate, Shard, @@ -103,7 +103,7 @@ def equal_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: output_spec = DTensorSpec( mesh=arg_spec.mesh, placements=tuple( - Replicate() if isinstance(p, _Partial) else p + Replicate() if isinstance(p, Partial) else p for p in arg_spec.placements ), ) @@ -154,7 +154,7 @@ def create_like_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: output_spec = DTensorSpec( mesh=arg_spec.mesh, placements=tuple( - Replicate() if isinstance(p, _Partial) else p + Replicate() if isinstance(p, Partial) else p for p in arg_spec.placements ), ) @@ -613,7 +613,7 @@ def prop_index(op_schema: OpSchema) -> OutputSharding: # 2. Other dimensions of values_spec can remain sharded if they are so. # For indices: # Indices can be either sharded or replicated. All index tensors need to be sharded - # in a compatible way, following the pointwise rule (including resolving _Partial + # in a compatible way, following the pointwise rule (including resolving Partial # into either sharded or replicated) values_spec, multi_indices_spec = op_schema.args_schema @@ -683,7 +683,7 @@ def place(vp: Placement, ip: Placement) -> Placement: ) if isinstance(ip, Shard): return Shard(ip.dim + insert_dim) - # _Partial or Replicated + # Partial or Replicated return vp value_placements = tuple( @@ -737,13 +737,13 @@ def split_rule(op_schema: OpSchema) -> OutputSharding: dim = cast(int, op_schema.args_schema[2]) if len(op_schema.args_schema) > 2 else 0 dim = normalize_dim(dim, ndim) - # TODO: tensor to split cannot have _Partial + # TODO: tensor to split cannot have Partial # in its placements for now. Will need to # support in future. if input_spec.sums: raise NotImplementedError( f"splitting distributed tensor with " - f"_Partial placement is not implemented!\n" + f"Partial placement is not implemented!\n" f"DTensorSpec={input_spec}" ) diff --git a/torch/distributed/_tensor/ops/utils.py b/torch/distributed/_tensor/ops/utils.py index 149e690cedc4..98f65eab610b 100644 --- a/torch/distributed/_tensor/ops/utils.py +++ b/torch/distributed/_tensor/ops/utils.py @@ -8,8 +8,8 @@ from torch.distributed._tensor._op_schema import OpStrategy, RuntimeSchemaInfo from torch.distributed._tensor.api import DTensor from torch.distributed._tensor.placement_types import ( - _Partial, DTensorSpec, + Partial, Placement, Replicate, Shard, @@ -193,7 +193,7 @@ def map_placements_after_broadcast( """Map each placement based on the output shape after broadcast.""" new_placements: List[Placement] = [] for placement in placements: - if isinstance(placement, (Replicate, _Partial)): + if isinstance(placement, (Replicate, Partial)): new_placements.append(placement) else: assert isinstance(placement, Shard) diff --git a/torch/distributed/_tensor/ops/view_ops.py b/torch/distributed/_tensor/ops/view_ops.py index 449526f13a43..303d802bc7bc 100644 --- a/torch/distributed/_tensor/ops/view_ops.py +++ b/torch/distributed/_tensor/ops/view_ops.py @@ -439,7 +439,7 @@ def dim_reduction( ndim: int, dim_or_dims: Optional[Union[int, Sequence[int]]], keepdim: bool ) -> DimMap: """ - General fallback for reduction ops where _Partial() does not apply. + General fallback for reduction ops where Partial() does not apply. This will cause incoming tensor to be replicated on the reducing dimensions. """ diff --git a/torch/distributed/_tensor/placement_types.py b/torch/distributed/_tensor/placement_types.py index d90bcb6c258a..5cb5aaf55fa9 100644 --- a/torch/distributed/_tensor/placement_types.py +++ b/torch/distributed/_tensor/placement_types.py @@ -32,7 +32,7 @@ def is_replicate(self) -> bool: return isinstance(self, Replicate) def is_partial(self) -> bool: - return isinstance(self, _Partial) + return isinstance(self, Partial) @dataclass(frozen=True) @@ -412,7 +412,7 @@ class Partial(Placement): def _reduce_value( self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int ) -> torch.Tensor: - # _Partial placement contract #1: + # Partial placement contract #1: # _reduce_value: reduce the value of the tensor on the mesh dimension return funcol.all_reduce( tensor, reduceOp=self.reduce_op, group=(mesh, mesh_dim) @@ -425,7 +425,7 @@ def _reduce_shard_value( mesh_dim: int, shard_spec: Placement, ) -> torch.Tensor: - # _Partial placement contract #2: + # Partial placement contract #2: # _reduce_shard_value: reduce_scatter the value of the tensor over the mesh dimension shard_spec = cast(Shard, shard_spec) return shard_spec._reduce_shard_tensor(tensor, mesh, self.reduce_op, mesh_dim) @@ -433,7 +433,7 @@ def _reduce_shard_value( def _partition_value( self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int ) -> torch.Tensor: - # _Partial placement contract #3: + # Partial placement contract #3: # _partition_value: partition the value of a replicated tensor on the mesh dimension # _partition_value is the conjugate operation of _reduce_value @@ -446,7 +446,7 @@ def _partition_value( return tensor / num_chunks def __eq__(self, other: object) -> bool: - if not isinstance(other, _Partial): + if not isinstance(other, Partial): return False return self.reduce_op == other.reduce_op @@ -457,7 +457,7 @@ def __repr__(self) -> str: """ machine readable representation of the Partial placement """ - return f"_Partial({self.reduce_op})" + return f"Partial({self.reduce_op})" def __str__(self) -> str: """ @@ -668,7 +668,7 @@ def from_dim_map( # find all mesh dims that need pending reductions for s in sums: - placements[s] = _Partial() + placements[s] = Partial() for i, m in enumerate(dim_map): if m >= 0: From 9257a0698b57acc5607ee6fe31a16fdd93af1731 Mon Sep 17 00:00:00 2001 From: PaliC Date: Wed, 29 May 2024 11:11:55 -0700 Subject: [PATCH 062/706] [Split Build] Load dependencies from libtorch in __init__.py (#126826) This PR makes it such that we search for a libtorch wheel when initializing pytorch in order to find the necessary shared libraries. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126826 Approved by: https://github.com/huydhn, https://github.com/atalman, https://github.com/ZainRizvi --- c10/cuda/CMakeLists.txt | 3 ++- setup.py | 41 +++++++++++++++++-------------- tools/setup_helpers/env.py | 2 ++ torch/__init__.py | 49 ++++++++++++++++++++++++++++++++++++-- 4 files changed, 74 insertions(+), 21 deletions(-) diff --git a/c10/cuda/CMakeLists.txt b/c10/cuda/CMakeLists.txt index 893a85562976..3327dab4779b 100644 --- a/c10/cuda/CMakeLists.txt +++ b/c10/cuda/CMakeLists.txt @@ -14,6 +14,8 @@ configure_file( if(BUILD_LIBTORCHLESS) find_library(C10_CUDA_LIB c10_cuda PATHS $ENV{LIBTORCH_LIB_PATH} NO_DEFAULT_PATH) +else() + set(C10_CUDA_LIB c10_cuda) endif() # Note: if you want to add ANY dependency to the c10 library, make sure you @@ -75,7 +77,6 @@ if(NOT BUILD_LIBTORCHLESS) $ $ $) - set(C10_CUDA_LIB c10_cuda) # ---[ Installation # Note: for now, we will put all export path into one single Caffe2Targets group diff --git a/setup.py b/setup.py index 65d81b4b01cb..e2529335bcc6 100644 --- a/setup.py +++ b/setup.py @@ -226,18 +226,6 @@ def _get_package_path(package_name): BUILD_LIBTORCH_WHL = os.getenv("BUILD_LIBTORCH_WHL", "0") == "1" BUILD_PYTHON_ONLY = os.getenv("BUILD_PYTHON_ONLY", "0") == "1" - -# set up appropriate env variables -if BUILD_LIBTORCH_WHL: - # Set up environment variables for ONLY building libtorch.so and not libtorch_python.so - # functorch is not supported without python - os.environ["BUILD_FUNCTORCH"] = "OFF" - - -if BUILD_PYTHON_ONLY: - os.environ["BUILD_LIBTORCHLESS"] = "ON" - os.environ["LIBTORCH_LIB_PATH"] = f"{_get_package_path('libtorch')}/lib" - python_min_version = (3, 8, 0) python_min_version_str = ".".join(map(str, python_min_version)) if sys.version_info < python_min_version: @@ -265,9 +253,26 @@ def _get_package_path(package_name): from tools.build_pytorch_libs import build_caffe2 from tools.generate_torch_version import get_torch_version from tools.setup_helpers.cmake import CMake -from tools.setup_helpers.env import build_type, IS_DARWIN, IS_LINUX, IS_WINDOWS +from tools.setup_helpers.env import ( + build_type, + IS_DARWIN, + IS_LINUX, + IS_WINDOWS, + LIBTORCH_PKG_NAME, +) from tools.setup_helpers.generate_linker_script import gen_linker_script +# set up appropriate env variables +if BUILD_LIBTORCH_WHL: + # Set up environment variables for ONLY building libtorch.so and not libtorch_python.so + # functorch is not supported without python + os.environ["BUILD_FUNCTORCH"] = "OFF" + + +if BUILD_PYTHON_ONLY: + os.environ["BUILD_LIBTORCHLESS"] = "ON" + os.environ["LIBTORCH_LIB_PATH"] = f"{_get_package_path(LIBTORCH_PKG_NAME)}/lib" + ################################################################################ # Parameters parsed from environment ################################################################################ @@ -342,7 +347,7 @@ def report(*args): # Version, create_version_file, and package_name ################################################################################ -DEFAULT_PACKAGE_NAME = "libtorch" if BUILD_LIBTORCH_WHL else "torch" +DEFAULT_PACKAGE_NAME = LIBTORCH_PKG_NAME if BUILD_LIBTORCH_WHL else "torch" package_name = os.getenv("TORCH_PACKAGE_NAME", DEFAULT_PACKAGE_NAME) package_type = os.getenv("PACKAGE_TYPE", "wheel") @@ -1133,7 +1138,7 @@ def main(): ] if BUILD_PYTHON_ONLY: - install_requires.append("libtorch") + install_requires.append(LIBTORCH_PKG_NAME) use_prioritized_text = str(os.getenv("USE_PRIORITIZED_TEXT_FOR_LD", "")) if ( @@ -1442,9 +1447,9 @@ def main(): if parts[0] == "torch": modified_packages.append(DEFAULT_PACKAGE_NAME + package[len("torch") :]) packages = modified_packages - package_dir = {"libtorch": "torch"} - torch_package_dir_name = "libtorch" - package_data = {"libtorch": torch_package_data} + package_dir = {LIBTORCH_PKG_NAME: "torch"} + torch_package_dir_name = LIBTORCH_PKG_NAME + package_data = {LIBTORCH_PKG_NAME: torch_package_data} extensions = [] else: torch_package_dir_name = "torch" diff --git a/tools/setup_helpers/env.py b/tools/setup_helpers/env.py index d87e97a2bb5a..eed5198ca9f2 100644 --- a/tools/setup_helpers/env.py +++ b/tools/setup_helpers/env.py @@ -21,6 +21,8 @@ BUILD_DIR = "build" +LIBTORCH_PKG_NAME = "libtorchsplit" + def check_env_flag(name: str, default: str = "") -> bool: return os.getenv(name, default).upper() in ["ON", "1", "YES", "TRUE", "Y"] diff --git a/torch/__init__.py b/torch/__init__.py index c2bf4a802838..a2492c40a949 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -17,6 +17,9 @@ import ctypes import inspect import threading +import pdb +import importlib +import importlib.util # multipy/deploy is setting this import before importing torch, this is the most # reliable way we have to detect if we're running within deploy. @@ -165,11 +168,53 @@ def _preload_cuda_deps(lib_folder, lib_name): raise ValueError(f"{lib_name} not found in the system path {sys.path}") ctypes.CDLL(lib_path) - # See Note [Global dependencies] def _load_global_deps() -> None: + + LIBTORCH_PKG_NAME = "libtorchsplit" + + def find_package_path(package_name): + spec = importlib.util.find_spec(package_name) + if spec: + # The package might be a namespace package, so get_data may fail + try: + loader = spec.loader + if loader is not None: + file_path = loader.get_filename() # type: ignore[attr-defined] + return os.path.dirname(file_path) + except AttributeError: + pass + return None + + def load_shared_libraries(library_path): + lib_dir = os.path.join(library_path, 'lib') + if not os.path.exists(lib_dir): + return + + # Determine the file extension based on the platform + if platform.system() == 'Darwin': + lib_ext = '.dylib' + else: + lib_ext = '.so' + + # Find all shared library files with the appropriate extension + library_files = [f for f in os.listdir(lib_dir) if f.endswith(lib_ext)] + if not library_files: + return + + for lib_file in library_files: + lib_path = os.path.join(lib_dir, lib_file) + try: + ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL) + except OSError as err: + print(f"Failed to load {lib_path}: {err}") + if _running_with_deploy() or platform.system() == 'Windows': return + split_build_lib_name = LIBTORCH_PKG_NAME + library_path = find_package_path(split_build_lib_name) + if library_path: + load_shared_libraries(library_path) lib_name = 'libtorch_global_deps' + ('.dylib' if platform.system() == 'Darwin' else '.so') here = os.path.abspath(__file__) @@ -1268,7 +1313,7 @@ def _check_tensor_all(cond, message=None): # noqa: F811 # For Python Array API (https://data-apis.org/array-api/latest/API_specification/constants.html) and # NumPy consistency (https://numpy.org/devdocs/reference/constants.html) -from math import e , nan , inf , pi +from math import e, nan , inf , pi newaxis: None = None __all__.extend(['e', 'pi', 'nan', 'inf', 'newaxis']) From 3174e6cb8e2d37210c7569e51dc6a9522110e0f3 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Wed, 29 May 2024 22:58:41 +0000 Subject: [PATCH 063/706] [Temp][CI] Run older MPS tests/Mac builds on MacOS 13 (#127428) To avoid ambiguity while migration outlined in https://github.com/pytorch-labs/pytorch-gha-infra/pull/399 is in progress. Otherwise, MPS jobs for Ventura can be accidentally scheduled on Sonoma or builds, which might result in flaky failures on trunk. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127428 Approved by: https://github.com/huydhn --- .github/workflows/mac-mps.yml | 4 ++-- .github/workflows/trunk.yml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/mac-mps.yml b/.github/workflows/mac-mps.yml index da98d01550a4..8554df29d1f6 100644 --- a/.github/workflows/mac-mps.yml +++ b/.github/workflows/mac-mps.yml @@ -19,13 +19,13 @@ jobs: with: sync-tag: macos-py3-arm64-build build-environment: macos-13-py3-arm64 - runner-type: macos-m1-stable + runner-type: macos-m1-13 build-generates-artifacts: true # To match the one pre-installed in the m1 runners python-version: 3.9.12 test-matrix: | { include: [ - { config: "mps", shard: 1, num_shards: 1, runner: "macos-m1-stable" }, + { config: "mps", shard: 1, num_shards: 1, runner: "macos-m1-13" }, { config: "mps", shard: 1, num_shards: 1, runner: "macos-m2-14" }, ]} diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index 9da73c8addb7..f0567393d5fa 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -149,7 +149,7 @@ jobs: with: sync-tag: macos-py3-arm64-build build-environment: macos-13-py3-arm64 - runner-type: macos-m1-stable + runner-type: macos-m1-13 build-generates-artifacts: true # To match the one pre-installed in the m1 runners python-version: 3.9.12 @@ -172,7 +172,7 @@ jobs: python-version: 3.9.12 test-matrix: | { include: [ - { config: "mps", shard: 1, num_shards: 1, runner: "macos-m1-stable" }, + { config: "mps", shard: 1, num_shards: 1, runner: "macos-m1-13" }, { config: "mps", shard: 1, num_shards: 1, runner: "macos-m1-14" }, ]} From 15a7916c0ec14c45ce7b1c4e4e6c43fc8cf4e221 Mon Sep 17 00:00:00 2001 From: Darshan Sanghani Date: Wed, 29 May 2024 23:16:14 +0000 Subject: [PATCH 064/706] Ability to capture Process Groups information into Execution Traces (#126995) Contains a method added to the ExecutionTraceObserver class to record the snapshot of the current process group config upon tracing start. Unit test: ``` (pytorch) [dsang@devgpu021.nha2 ~/github/pytorch-fork (viable/strict)]$ touch /tmp/barrier && TEMP_DIR="/tmp" BACKEND="nccl" WORLD_SIZE="2" python test/distributed/test_distributed_spawn.py -v TestDistBackendWithSpawn.test_ddp_profiling_execution_trace /home/dsang/github/pytorch-fork/torch/distributed/optim/__init__.py:28: UserWarning: TorchScript support for functional optimizers isdeprecated and will be removed in a future PyTorch release. Consider using the torch.compile optimizer instead. warn("TorchScript support for functional optimizers is" test_ddp_profiling_execution_trace (__main__.TestDistBackendWithSpawn.test_ddp_profiling_execution_trace) ... /home/dsang/github/pytorch-fork/torch/distributed/optim/__init__.py:28: UserWarning: TorchScript support for functional optimizers isdeprecated and will be removed in a future PyTorch release. Consider using the torch.compile optimizer instead. warn("TorchScript support for functional optimizers is" /home/dsang/github/pytorch-fork/torch/distributed/optim/__init__.py:28: UserWarning: TorchScript support for functional optimizers isdeprecated and will be removed in a future PyTorch release. Consider using the torch.compile optimizer instead. warn("TorchScript support for functional optimizers is" NCCL version 2.20.5+cuda12.0 [rank1]:[W523 16:06:01.705774398 reducer.cpp:1400] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) [rank0]:[W523 16:06:01.705905760 reducer.cpp:1400] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) [rank1]:[W523 16:06:01.715182258 execution_trace_observer.cpp:819] Enabling Execution Trace Observer printing pg info into trace [rank0]:[W523 16:06:01.715841805 execution_trace_observer.cpp:819] Enabling Execution Trace Observer printing pg info into trace [rank1]:[W523 16:06:01.727881877 execution_trace_observer.cpp:831] Disabling Execution Trace Observer [rank0]:[W523 16:06:01.728792871 execution_trace_observer.cpp:831] Disabling Execution Trace Observer Execution trace saved at /tmp/tmpdsov4ngi.et.json [{'id': 3, 'name': '## process_group:init ##', 'ctrl_deps': 2, 'inputs': {'values': ['[{"pg_name": "0", "pg_desc": "default_pg", "backend_config": "cuda:nccl", "ranks": [], "group_size": 2, "group_count": 1}]'], 'shapes': [[]], 'types': ['String']}, 'outputs': {'values': [], 'shapes': [], 'types': []}, 'attrs': [{'name': 'rf_id', 'type': 'uint64', 'value': 1}, {'name': 'fw_parent', 'type': 'uint64', 'value': 0}, {'name': 'seq_id', 'type': 'int64', 'value': -1}, {'name': 'scope', 'type': 'uint64', 'value': 7}, {'name': 'tid', 'type': 'uint64', 'value': 1}, {'name': 'fw_tid', 'type': 'uint64', 'value': 0}, {'name': 'op_schema', 'type': 'string', 'value': ''}, {'name': 'kernel_backend', 'type': 'string', 'value': ''}, {'name': 'kernel_file', 'type': 'string', 'value': ''}]}] Execution trace saved at /tmp/tmpsdiqy6az.et.json [{'id': 3, 'name': '## process_group:init ##', 'ctrl_deps': 2, 'inputs': {'values': ['[{"pg_name": "0", "pg_desc": "default_pg", "backend_config": "cuda:nccl", "ranks": [], "group_size": 2, "group_count": 1}]'], 'shapes': [[]], 'types': ['String']}, 'outputs': {'values': [], 'shapes': [], 'types': []}, 'attrs': [{'name': 'rf_id', 'type': 'uint64', 'value': 1}, {'name': 'fw_parent', 'type': 'uint64', 'value': 0}, {'name': 'seq_id', 'type': 'int64', 'value': -1}, {'name': 'scope', 'type': 'uint64', 'value': 7}, {'name': 'tid', 'type': 'uint64', 'value': 1}, {'name': 'fw_tid', 'type': 'uint64', 'value': 0}, {'name': 'op_schema', 'type': 'string', 'value': ''}, {'name': 'kernel_backend', 'type': 'string', 'value': ''}, {'name': 'kernel_file', 'type': 'string', 'value': ''}]}] ok ---------------------------------------------------------------------- Ran 1 test in 24.447s OK ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/126995 Approved by: https://github.com/briancoutinho, https://github.com/sraikund16 --- torch/profiler/profiler.py | 14 ++++++++++++++ .../_internal/distributed/distributed_test.py | 4 ++-- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/torch/profiler/profiler.py b/torch/profiler/profiler.py index a9f65104a99e..3847da03ab8e 100644 --- a/torch/profiler/profiler.py +++ b/torch/profiler/profiler.py @@ -849,6 +849,7 @@ def start(self): if self._registered and not self._execution_trace_running: _enable_execution_trace_observer() self._execution_trace_running = True + self._record_pg_config() def stop(self): """ @@ -875,3 +876,16 @@ def get_output_file_path(self) -> str: "A callback to the ET profiler needs to be registered " "first before getting the output file path" ) + + def _record_pg_config(self) -> None: + # Records the PG config info to the trace as node: + # ## process_group:init ## + if ( + self.is_registered + and torch.distributed.is_available() + and torch.distributed.is_initialized() + ): + pg_config_info = torch.distributed.distributed_c10d._world.pg_config_info + torch.autograd._record_function_with_args_enter( + "## process_group:init ##", json.dumps(pg_config_info) + ) diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index b9873b9950fa..92fff2623b31 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -6995,7 +6995,8 @@ def _validate_execution_trace_nccl(self, et_file: str) -> None: """ with open(et_file) as f: et = json.load(f) - + pg_cfg_node = [n for n in et["nodes"] if n["name"] == "## process_group:init ##"] + self.assertGreaterEqual(len(pg_cfg_node), 1) nccl_meta_nodes = [n for n in et["nodes"] if n["name"] == "record_param_comms"] self.assertEqual(len(nccl_meta_nodes), 3) per_coll_meta = defaultdict(list) @@ -7052,7 +7053,6 @@ def test_ddp_profiling_execution_trace(self): fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False) fp.close() et_file = fp.name - et = ExecutionTraceObserver().register_callback(et_file) # first profiler context need not have ET From 0fa2c5b049e40d5a21410d24fd4fd49cbc47f866 Mon Sep 17 00:00:00 2001 From: lezcano Date: Tue, 28 May 2024 18:24:44 +0000 Subject: [PATCH 065/706] Fix mask propagation in the presence of where (#125574) Before, when calling ops.where, masks were not properly propagated. We now restrict the optimisation to `ops.masked`, which I think it was what the original code intended to do. I'm not 100% sure that even in the masked case this code is not introducing some bugs, but this is a strict improvement over the previous state. Pull Request resolved: https://github.com/pytorch/pytorch/pull/125574 Approved by: https://github.com/peterbell10 ghstack dependencies: #114471, #126783 --- test/inductor/test_torchinductor.py | 9 +++++++++ torch/_inductor/codegen/common.py | 5 +---- torch/_inductor/codegen/triton.py | 10 +++------- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 4f65504d5696..c6f5e61ac0f9 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -1306,6 +1306,15 @@ def reflection_pad_left(x, n): expect = reflection_pad_left(x, 3) self.assertEqual(expect, actual) + def test_index_propagation_device_assert_masked(self): + def fn(a): + idx = torch.arange(a.size(0), device=a.device) + padded_idx = torch.constant_pad_nd(idx, (1050, 0)) + padded_idx = torch.where(padded_idx >= 0, padded_idx, padded_idx) + return a[padded_idx] + + self.common(fn, (torch.randn(1024),)) + @skipIfRocm @config.patch(debug_index_asserts=False) def test_neg_index(self): diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index e471fefabe1a..3e238203b770 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -1637,10 +1637,7 @@ def indirect_indexing( pos = var.bounds & ValueRanges(0, sympy.oo) new_bounds = new_bounds | pos - new_var = self.cse.generate(self.compute, stm, bounds=new_bounds) - # Propagate the mask as mask propagation when using where is not correct - new_var.update_on_args("index_wrap", (var,), {}) - var = new_var + var = self.cse.generate(self.compute, stm, bounds=new_bounds) sympy_var = parent_handler.indirect_indexing(var, size, check) if generate_assert(check): diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index d83680198e7d..d239e711db1f 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -432,12 +432,6 @@ def __init__(self, name, bounds: ValueRanges[Any]): self.mask_vars: Set[str] = set() def update_on_args(self, name, args, kwargs): - # When making a variable that is going to be used in indirect indexing - # if a where clause is used it should mean that the result is always a - # valid index, so you shouldn't include any of the dependent variables - # in the resulting load mask - if name == "where": - return for arg in args: if isinstance(arg, TritonCSEVariable): self.mask_vars.update(arg.mask_vars) @@ -889,7 +883,9 @@ def masked(mask, body, other): f"tl.full({result}.shape, {constant_repr(other)}, {result}.dtype)", bounds=ValueRanges.wrap(other), ) - return ops.where(new_mask, result, other) + ret = ops.where(new_mask, result, other) + ret.mask_vars.discard(new_mask) + return ret @staticmethod def load_seed(name, offset): From 8ea1dc874880de9bbcf32699c4098a3c5e3f3b20 Mon Sep 17 00:00:00 2001 From: cyy Date: Wed, 29 May 2024 23:17:56 +0000 Subject: [PATCH 066/706] Use Python::NumPy target (#127399) Now that we use FindPython, use it again for numpy detection. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127399 Approved by: https://github.com/malfet --- cmake/Dependencies.cmake | 31 ++++++++----------- cmake/Modules/FindNumPy.cmake | 57 ----------------------------------- tools/setup_helpers/cmake.py | 4 +-- tools/setup_helpers/numpy_.py | 24 --------------- torch/CMakeLists.txt | 2 +- 5 files changed, 15 insertions(+), 103 deletions(-) delete mode 100644 cmake/Modules/FindNumPy.cmake delete mode 100644 tools/setup_helpers/numpy_.py diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index a582a3e6ec05..e15b55cd16ed 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -864,7 +864,11 @@ if(BUILD_PYTHON) # These should fill in the rest of the variables, like versions, but resepct # the variables we set above - find_package(Python COMPONENTS Interpreter Development) + if(USE_NUMPY) + find_package(Python COMPONENTS Interpreter Development NumPy) + else() + find_package(Python COMPONENTS Interpreter Development) + endif() if(NOT Python_Development_FOUND) message(FATAL_ERROR @@ -876,29 +880,20 @@ if(BUILD_PYTHON) "Found Python libraries version ${Python_VERSION}. Python < 3.8 is no longer supported by PyTorch.") endif() - # When building pytorch, we pass this in directly from setup.py, and - # don't want to overwrite it because we trust python more than cmake - if(NUMPY_INCLUDE_DIR) - set(NUMPY_FOUND ON) - elseif(USE_NUMPY) - find_package(NumPy) - if(NOT NUMPY_FOUND) - message(WARNING "NumPy could not be found. Not building with NumPy. Suppress this warning with -DUSE_NUMPY=OFF") - endif() - endif() - - if(Python_Interpreter_FOUND AND Python_Development_FOUND) + if(Python_Interpreter_FOUND) add_library(python::python INTERFACE IMPORTED) target_include_directories(python::python SYSTEM INTERFACE ${Python_INCLUDE_DIRS}) if(WIN32) target_link_libraries(python::python INTERFACE ${Python_LIBRARIES}) endif() - caffe2_update_option(USE_NUMPY OFF) - if(NUMPY_FOUND) - caffe2_update_option(USE_NUMPY ON) - add_library(numpy::numpy INTERFACE IMPORTED) - target_include_directories(numpy::numpy SYSTEM INTERFACE ${NUMPY_INCLUDE_DIR}) + if(USE_NUMPY) + if(NOT Python_NumPy_FOUND) + message(WARNING "NumPy could not be found. Not building with NumPy. Suppress this warning with -DUSE_NUMPY=OFF") + caffe2_update_option(USE_NUMPY OFF) + else() + caffe2_update_option(USE_NUMPY ON) + endif() endif() # Observers are required in the python build caffe2_update_option(USE_OBSERVERS ON) diff --git a/cmake/Modules/FindNumPy.cmake b/cmake/Modules/FindNumPy.cmake deleted file mode 100644 index 2c43b95bdcf6..000000000000 --- a/cmake/Modules/FindNumPy.cmake +++ /dev/null @@ -1,57 +0,0 @@ -# - Find the NumPy libraries -# This module finds if NumPy is installed, and sets the following variables -# indicating where it is. -# -# TODO: Update to provide the libraries and paths for linking npymath lib. -# -# NUMPY_FOUND - was NumPy found -# NUMPY_VERSION - the version of NumPy found as a string -# NUMPY_VERSION_MAJOR - the major version number of NumPy -# NUMPY_VERSION_MINOR - the minor version number of NumPy -# NUMPY_VERSION_PATCH - the patch version number of NumPy -# NUMPY_VERSION_DECIMAL - e.g. version 1.6.1 is 10601 -# NUMPY_INCLUDE_DIR - path to the NumPy include files - -unset(NUMPY_VERSION) -unset(NUMPY_INCLUDE_DIR) - -if(Python_Interpreter_FOUND) - execute_process(COMMAND "${Python_EXECUTABLE}" "-c" - "import numpy as n; print(n.__version__); print(n.get_include());" - RESULT_VARIABLE __result - OUTPUT_VARIABLE __output - OUTPUT_STRIP_TRAILING_WHITESPACE) - - if(__result MATCHES 0) - string(REGEX REPLACE ";" "\\\\;" __values ${__output}) - string(REGEX REPLACE "\r?\n" ";" __values ${__values}) - list(GET __values 0 NUMPY_VERSION) - list(GET __values 1 NUMPY_INCLUDE_DIR) - - string(REGEX MATCH "^([0-9])+\\.([0-9])+\\.([0-9])+" __ver_check "${NUMPY_VERSION}") - if(NOT "${__ver_check}" STREQUAL "") - set(NUMPY_VERSION_MAJOR ${CMAKE_MATCH_1}) - set(NUMPY_VERSION_MINOR ${CMAKE_MATCH_2}) - set(NUMPY_VERSION_PATCH ${CMAKE_MATCH_3}) - math(EXPR NUMPY_VERSION_DECIMAL - "(${NUMPY_VERSION_MAJOR} * 10000) + (${NUMPY_VERSION_MINOR} * 100) + ${NUMPY_VERSION_PATCH}") - string(REGEX REPLACE "\\\\" "/" NUMPY_INCLUDE_DIR ${NUMPY_INCLUDE_DIR}) - else() - unset(NUMPY_VERSION) - unset(NUMPY_INCLUDE_DIR) - message(STATUS "Requested NumPy version and include path, but got instead:\n${__output}\n") - endif() - endif() -else() - message(STATUS "To find NumPy Python interpretator is required to be found.") -endif() - -include(FindPackageHandleStandardArgs) -find_package_handle_standard_args(NumPy REQUIRED_VARS NUMPY_INCLUDE_DIR NUMPY_VERSION - VERSION_VAR NUMPY_VERSION) - -if(NUMPY_FOUND) - message(STATUS "NumPy ver. ${NUMPY_VERSION} found (include: ${NUMPY_INCLUDE_DIR})") -endif() - -caffe_clear_vars(__result __output __error_value __values __ver_check __error_value) diff --git a/tools/setup_helpers/cmake.py b/tools/setup_helpers/cmake.py index 3e3e06d54115..4d10b3db1aa3 100644 --- a/tools/setup_helpers/cmake.py +++ b/tools/setup_helpers/cmake.py @@ -13,7 +13,6 @@ from . import which from .cmake_utils import CMakeValue, get_cmake_cache_variables_from_file from .env import BUILD_DIR, check_negative_env_flag, IS_64BIT, IS_DARWIN, IS_WINDOWS -from .numpy_ import NUMPY_INCLUDE_DIR, USE_NUMPY def _mkdir_p(d: str) -> None: @@ -285,7 +284,7 @@ def generate( "BUILD_TEST": build_test, # Most library detection should go to CMake script, except this one, which Python can do a much better job # due to NumPy's inherent Pythonic nature. - "USE_NUMPY": USE_NUMPY, + "USE_NUMPY": not check_negative_env_flag("USE_NUMPY"), } ) @@ -309,7 +308,6 @@ def generate( args, Python_EXECUTABLE=sys.executable, TORCH_BUILD_VERSION=version, - NUMPY_INCLUDE_DIR=NUMPY_INCLUDE_DIR, **build_options, ) diff --git a/tools/setup_helpers/numpy_.py b/tools/setup_helpers/numpy_.py deleted file mode 100644 index e93fcfd24707..000000000000 --- a/tools/setup_helpers/numpy_.py +++ /dev/null @@ -1,24 +0,0 @@ -"""NumPy helper. - -Note: If you plan to add a library detection script like this one, consider it twice. Most library detection should go -to CMake script. This one is an exception, because Python code can do a much better job due to NumPy's inherent Pythonic -nature. -""" - -from .env import check_negative_env_flag - - -# Set USE_NUMPY to what the user wants, because even if we fail here, cmake -# will check for the presence of NumPy again (`cmake/Dependencies.cmake`). -USE_NUMPY = not check_negative_env_flag("USE_NUMPY") -NUMPY_INCLUDE_DIR = None - -if USE_NUMPY: - try: - import numpy as np - except ImportError: - pass - else: - # To reach here, the user must has not disabled NumPy build and the - # NumPy library is present in the system. - NUMPY_INCLUDE_DIR = np.get_include() diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index d212b17e0e8e..b4db57488f02 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -300,7 +300,7 @@ add_dependencies(torch_python Caffe2_PROTO) add_dependencies(torch_python onnx_proto) # Avoid numpy for the DEPLOY build if(USE_NUMPY) - target_link_libraries(torch_python PRIVATE numpy::numpy) + target_link_libraries(torch_python PRIVATE Python::NumPy) target_compile_definitions(torch_python PRIVATE USE_NUMPY) endif() From 7931eee5c5ebcdf468bff4d308510b03355cd909 Mon Sep 17 00:00:00 2001 From: Shan19900305 Date: Wed, 29 May 2024 23:19:30 +0000 Subject: [PATCH 067/706] Support torch.dtype as parameter in pybind11 cpp extension. (#126865) Support torch.dtype as parameter in pybind11 cpp extension. Example: ` cpp_extension.my_ops(self, other, torch.dtype) ` @ezyang @bdhirsh Co-authored-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/126865 Approved by: https://github.com/ezyang --- test/cpp_extensions/extension.cpp | 6 ++++++ test/test_cpp_extensions_aot.py | 4 ++++ torch/csrc/utils/pybind.h | 32 +++++++++++++++++++++++++++++- torch/csrc/utils/tensor_dtypes.cpp | 8 ++------ torch/csrc/utils/tensor_dtypes.h | 8 +++----- 5 files changed, 46 insertions(+), 12 deletions(-) diff --git a/test/cpp_extensions/extension.cpp b/test/cpp_extensions/extension.cpp index 1de9e0397111..0b609e82e0c5 100644 --- a/test/cpp_extensions/extension.cpp +++ b/test/cpp_extensions/extension.cpp @@ -2,6 +2,7 @@ // test include_dirs in setuptools.setup with relative path #include +#include torch::Tensor sigmoid_add(torch::Tensor x, torch::Tensor y) { return x.sigmoid() + y.sigmoid(); @@ -31,6 +32,10 @@ torch::Tensor random_tensor() { return torch::randn({1}); } +at::ScalarType get_math_type(at::ScalarType other) { + return at::toOpMathType(other); +} + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("sigmoid_add", &sigmoid_add, "sigmoid(x) + sigmoid(y)"); m.def( @@ -52,4 +57,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("get_symint", []() { return c10::SymInt(1); }); m.def("get_symintarrayref", []() { return at::SymIntArrayRef({1, 2, 3}); }); m.def("get_tensor", []() { return random_tensor(); }); + m.def("get_math_type", &get_math_type); } diff --git a/test/test_cpp_extensions_aot.py b/test/test_cpp_extensions_aot.py index 3e5ce5cfcef4..eb6d43e4cf00 100644 --- a/test/test_cpp_extensions_aot.py +++ b/test/test_cpp_extensions_aot.py @@ -55,6 +55,10 @@ def test_extension_function(self): y = torch.randn(4, 4) z = cpp_extension.sigmoid_add(x, y) self.assertEqual(z, x.sigmoid() + y.sigmoid()) + # test pybind support torch.dtype cast. + self.assertEqual( + str(torch.float32), str(cpp_extension.get_math_type(torch.half)) + ) def test_extension_module(self): mm = cpp_extension.MatrixMultiplier(4, 8) diff --git a/torch/csrc/utils/pybind.h b/torch/csrc/utils/pybind.h index 553738b8999b..19874d2e29b2 100644 --- a/torch/csrc/utils/pybind.h +++ b/torch/csrc/utils/pybind.h @@ -10,6 +10,7 @@ #include #include +#include #include #include #include @@ -189,6 +190,35 @@ struct type_caster { } }; +template <> +struct type_caster { + public: + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + PYBIND11_TYPE_CASTER(at::ScalarType, _("torch.dtype")); + + // PYBIND11_TYPE_CASTER defines a member field called value. at::ScalarType + // cannot be default-initialized, we provide this constructor to explicitly + // initialize that field. The value doesn't matter as it will be overwritten + // after a successful call to load. + type_caster() : value(at::kFloat) {} + + bool load(handle src, bool) { + PyObject* obj = src.ptr(); + if (THPDtype_Check(obj)) { + value = reinterpret_cast(obj)->scalar_type; + return true; + } + return false; + } + + static handle cast( + const at::ScalarType& src, + return_value_policy /* policy */, + handle /* parent */) { + return Py_NewRef(torch::getTHPDtype(src)); + } +}; + template <> struct type_caster { public: @@ -206,7 +236,7 @@ struct type_caster { if (THPStream_Check(obj)) { value = c10::Stream::unpack3( ((THPStream*)obj)->stream_id, - ((THPStream*)obj)->device_index, + static_cast(((THPStream*)obj)->device_index), static_cast(((THPStream*)obj)->device_type)); return true; } diff --git a/torch/csrc/utils/tensor_dtypes.cpp b/torch/csrc/utils/tensor_dtypes.cpp index 200e04eaddb0..5290392d900f 100644 --- a/torch/csrc/utils/tensor_dtypes.cpp +++ b/torch/csrc/utils/tensor_dtypes.cpp @@ -1,14 +1,11 @@ #include #include #include -#include #include #include #include -#include -namespace torch { -namespace utils { +namespace torch::utils { std::pair getDtypeNames(at::ScalarType scalarType) { switch (scalarType) { @@ -125,5 +122,4 @@ void initializeDtypes() { } } -} // namespace utils -} // namespace torch +} // namespace torch::utils diff --git a/torch/csrc/utils/tensor_dtypes.h b/torch/csrc/utils/tensor_dtypes.h index 32b769971d03..9a947b380e92 100644 --- a/torch/csrc/utils/tensor_dtypes.h +++ b/torch/csrc/utils/tensor_dtypes.h @@ -1,15 +1,13 @@ #pragma once -#include +#include #include #include -namespace torch { -namespace utils { +namespace torch::utils { std::pair getDtypeNames(at::ScalarType scalarType); void initializeDtypes(); -} // namespace utils -} // namespace torch +} // namespace torch::utils From 76fc58c1601365d90d74cdf0eb22cca3f8d742b4 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Tue, 28 May 2024 06:48:27 -0700 Subject: [PATCH 068/706] Document the legacy constructor for Tensor (#122625) Fixes https://github.com/pytorch/pytorch/issues/122408 Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/122625 Approved by: https://github.com/albanD --- docs/source/tensors.rst | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst index 218c83d0a373..7bfa8704f5e5 100644 --- a/docs/source/tensors.rst +++ b/docs/source/tensors.rst @@ -212,6 +212,37 @@ Tensor class reference (see :ref:`tensor-creation-ops`). - To create a tensor with similar type but different size as another tensor, use ``tensor.new_*`` creation ops. + - There is a legacy constructor ``torch.Tensor`` whose use is discouraged. + Use :func:`torch.tensor` instead. + +.. method:: Tensor.__init__(self, data) + + This constructor is deprecated, we recommend using :func:`torch.tensor` instead. + What this constructor does depends on the type of ``data``. + + * If ``data`` is a Tensor, returns an alias to the original Tensor. Unlike + :func:`torch.tensor`, this tracks autograd and will propagate gradients to + the original Tensor. ``device`` kwarg is not supported for this ``data`` type. + + * If ``data`` is a sequence or nested sequence, create a tensor of the default + dtype (typically ``torch.float32``) whose data is the values in the + sequences, performing coercions if necessary. Notably, this differs from + :func:`torch.tensor` in that this constructor will always construct a float + tensor, even if the inputs are all integers. + + * If ``data`` is a :class:`torch.Size`, returns an empty tensor of that size. + + This constructor does not support explicitly specifying ``dtype`` or ``device`` of + the returned tensor. We recommend using :func:`torch.tensor` which provides this + functionality. + + Args: + data (array_like): The tensor to construct from. + + Keyword args: + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if None, same :class:`torch.device` as this tensor. + .. autoattribute:: Tensor.T .. autoattribute:: Tensor.H From f14dc3bde82aeda53a0d23aa69198c43b64cf77f Mon Sep 17 00:00:00 2001 From: Lei Ding <69283446+Dmovic@users.noreply.github.com> Date: Wed, 29 May 2024 23:58:06 +0000 Subject: [PATCH 069/706] Fix check message (#126951) As title. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126951 Approved by: https://github.com/Skylion007, https://github.com/kit1980 --- aten/src/ATen/native/AdaptiveAveragePooling3d.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aten/src/ATen/native/AdaptiveAveragePooling3d.cpp b/aten/src/ATen/native/AdaptiveAveragePooling3d.cpp index bbd4f68d40d0..54cde5aad4c0 100644 --- a/aten/src/ATen/native/AdaptiveAveragePooling3d.cpp +++ b/aten/src/ATen/native/AdaptiveAveragePooling3d.cpp @@ -310,7 +310,7 @@ Tensor adaptive_avg_pool3d_symint(Tensor const& input, SymIntArrayRef output_siz TORCH_CHECK(output_size.size() == 3, "adaptive_avg_pool3d: output_size must be 3"); TORCH_CHECK( (output_size[0] >= 0 && output_size[1] >= 0 && output_size[2] >= 0), - "adaptive_avg_pool2d: elements of output_size must be greater than or equal to 0 ", + "adaptive_avg_pool3d: elements of output_size must be greater than or equal to 0 ", "but received {", output_size[0], ", ", output_size[1], ",", output_size[2], "}"); if (output_size[0] == 1 && output_size[1] == 1 && output_size[2] == 1 && !input.is_xpu()) { From d66f12674cfe0151a86dc10b8de216f83bf42e6e Mon Sep 17 00:00:00 2001 From: Jiashen Cao Date: Thu, 30 May 2024 00:08:09 +0000 Subject: [PATCH 070/706] Handle tuple and dict during TorchScript to ExportedProgram conversion (#127341) * Add some test cases for testing List, Tuple, and Dict * Refactor the conversion code slightly * Add a logic to handle Dict Pull Request resolved: https://github.com/pytorch/pytorch/pull/127341 Approved by: https://github.com/SherlockNoMad, https://github.com/angelayi --- test/export/test_converter.py | 53 +++++++++++++++++++++++++++++++---- torch/_export/converter.py | 42 ++++++++++++++++++++++----- 2 files changed, 83 insertions(+), 12 deletions(-) diff --git a/test/export/test_converter.py b/test/export/test_converter.py index ab6c3b802418..b6d0e54a59e1 100644 --- a/test/export/test_converter.py +++ b/test/export/test_converter.py @@ -1,6 +1,9 @@ # Owner(s): ["oncall: export"] import torch + +import torch.utils._pytree as pytree + from torch._dynamo.test_case import TestCase from torch._export.converter import TS2EPConverter @@ -8,18 +11,58 @@ class TestConverter(TestCase): + def _check_equal_ts_ep_converter(self, mod, inp): + ts_model = torch.jit.script(mod) + ep = TS2EPConverter(ts_model, inp).convert() + ep_out, _ = pytree.tree_flatten(ep.module()(*inp)) + orig_out, _ = pytree.tree_flatten(mod(*inp)) + self.assertEqual(len(ep_out), len(orig_out)) + for ep_t, orig_t in zip(ep_out, orig_out): + self.assertEqual(ep_t.shape, orig_t.shape) + self.assertTrue(torch.allclose(ep_t, orig_t)) + def test_ts2ep_converter_basic(self): - class Module(torch.nn.Module): + class MSingle(torch.nn.Module): def forward(self, x, y): return x + y - m = Module() + class MMulti(torch.nn.Module): + def forward(self, x, y): + x = x.cos() + 1 + y = y.sin() - 1 + return x, y + inp = (torch.ones(1, 3), torch.ones(1, 3)) + self._check_equal_ts_ep_converter(MSingle(), inp) + self._check_equal_ts_ep_converter(MMulti(), inp) - ts_model = torch.jit.script(m) - ep = TS2EPConverter(ts_model, inp).convert() + def test_ts2ep_converter_container_output(self): + # Output is a List. + class MOutputList(torch.nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor): + a = x * x + b = y + y + return [a, b] + + # Output is a Tuple. + class MOutputTuple(torch.nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor): + a = x * x + b = y + y + return (a, b) + + # Output is a Dict. + class MOutputDict(torch.nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor): + a = x * x + b = y + y + return {"data": {"mul": a, "add": b}} + + inp = (torch.tensor(4), torch.tensor(4)) - torch.testing.assert_close(ep.module()(*inp)[0], m(*inp)) + self._check_equal_ts_ep_converter(MOutputList(), inp) + self._check_equal_ts_ep_converter(MOutputTuple(), inp) + self._check_equal_ts_ep_converter(MOutputDict(), inp) if __name__ == "__main__": diff --git a/torch/_export/converter.py b/torch/_export/converter.py index d902c5f1ac55..459f534ca636 100644 --- a/torch/_export/converter.py +++ b/torch/_export/converter.py @@ -33,8 +33,6 @@ def replacement(im, dim, scale): replaced_patterns = subgraph_rewriter.replace_pattern(gm, pattern, replacement) - print(replaced_patterns) - def normalize_name(name: str) -> str: return name.replace(".", "_") @@ -76,7 +74,9 @@ def __init__( self.input_specs: List[InputSpec] = [] self.output_specs: List[OutputSpec] = [] - self.name_to_node: Dict[str, Union[torch.fx.Node, List[torch.fx.Node]]] = {} + self.name_to_node: Dict[ + str, Union[torch.fx.Node, List[torch.fx.Node], Dict[Any, torch.fx.Node]] + ] = {} self.constant_map: Dict[str, Any] = {} self.attribute_map: Dict[str, Any] = {} self.tensor_constants: Dict[str, torch.Tensor] = {} @@ -236,12 +236,35 @@ def convert_aten_op(self, node: torch._C.Node): def convert_prim_ListConstruct(self, node: torch._C.Node): output_list = [] - for input in node.inputs(): - output_list.append(self.get_fx_value(input)) + for inp in node.inputs(): + output_list.append(self.get_fx_value(inp)) output_name = node.output().debugName() self.name_to_node[output_name] = output_list + def convert_prim_DictConstruct(self, node: torch._C.Node): + output_dict = {} + k, v = None, None + for i, inp in enumerate(node.inputs()): + # We assume key value are stored in pair in the DictConstruct. + # The first element is the key and the following is the value. + if i % 2 == 0: + k = self.get_fx_value(inp) + else: + v = self.get_fx_value(inp) + assert ( + k is not None and v is not None + ), "DictConstruct has an empty key value pair." + output_dict[k] = v + k, v = None, None + + assert ( + k is None and v is None + ), "DictConstruct has an odd number of elements (violating our assumption)." + + output_name = node.output().debugName() + self.name_to_node[output_name] = output_dict + def convert_aten_Int(self, node: torch._C.Node): # converts aten::Int as aten._to_copy + aten::_local_scalar_dense target = torch.ops.aten._to_copy.default @@ -324,8 +347,11 @@ def convert_node(self, node: torch._C.Node): self.convert_prim_GetAttr(node) elif node_kind == "prim::NumToTensor": self.convert_prim_NumToTensor(node) - elif node_kind == "prim::ListConstruct": + elif node_kind in {"prim::ListConstruct", "prim::TupleConstruct"}: + # Tuple is just a non-mutable List, so we can handle them together. self.convert_prim_ListConstruct(node) + elif node_kind == "prim::DictConstruct": + self.convert_prim_DictConstruct(node) # elif node_kind == "aten::Int": # convert_aten_Int(node) elif node_kind == "aten::_convolution": @@ -354,7 +380,9 @@ def convert_graph_outputs(self): ) ) - self.fx_graph.output(args) + self.fx_graph.output( + args[0] + ) # Get rid of an extra list wrapped around final output. def retrace_as_exported_program(self, gm: torch.fx.GraphModule): # TODO: adjust input orders to match GraphSignature convention From 49ad90349d57c35ab83f40c28d8b18caefb416d1 Mon Sep 17 00:00:00 2001 From: saadelkouari Date: Thu, 30 May 2024 00:50:26 +0000 Subject: [PATCH 071/706] Correct error message for aten::_local_scalar_dense on meta tensor (#124554) registering a meta for aten::_local_scalar_dense with a different error message. Fixes pytorch#119588 Co-authored-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/124554 Approved by: https://github.com/ezyang --- test/test_meta.py | 5 +++++ torch/_meta_registrations.py | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/test/test_meta.py b/test/test_meta.py index ebd91e71c29f..a5368fbfaee7 100644 --- a/test/test_meta.py +++ b/test/test_meta.py @@ -1724,6 +1724,11 @@ def f(): out = f() self.assertEqual(out.shape, [10, 16]) + def test_local_scalar_dense_call(self): + with self.assertRaisesRegex(RuntimeError, "cannot be called on meta tensors"): + meta_tensor = torch.randn(1, device='meta') + meta_tensor.item() + instantiate_device_type_tests(TestMeta, globals()) def print_op_str_if_not_supported(op_str): diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index d25866d0abb9..95e3aa1eebf3 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -6345,6 +6345,11 @@ def meta_channel_shuffle(input, groups): ) +@register_meta(aten._local_scalar_dense) +def meta_local_scalar_dense(self: Tensor): + raise RuntimeError("Tensor.item() cannot be called on meta tensors") + + def _create_unary_float_meta_func(func): @register_meta(func) @out_wrapper() From 1abcac9dab500c743746f6ecf386c4a9975435d4 Mon Sep 17 00:00:00 2001 From: rzou Date: Wed, 29 May 2024 11:27:52 -0700 Subject: [PATCH 072/706] New Custom Ops Documentation landing page (#127400) We create a new landing page for PyTorch custom ops (suggested by jansel). All of our error messages will link here, and I'll work with the docs team to see if we can boost SEO for this page. NB: the landing page links some non-searchable webpages. Two of those (the Python custom ops tutorial and C++ custom ops tutorial) will turn into actual webpages when PyTorch 2.4 comes around. I'll make the third one (the Custom Operators Manual) once it stabilizes (we continously add new things to it and the length means that we might want to create a custom website for it to make the presentation more ingestable). Test Plan: - view docs preview. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127400 Approved by: https://github.com/jansel ghstack dependencies: #127291, #127292 --- docs/source/export.rst | 20 ++++----- docs/source/library.rst | 5 ++- docs/source/notes/custom_operators.rst | 56 ++++++++++++++++++++++++++ docs/source/notes/extending.rst | 22 +++++----- 4 files changed, 79 insertions(+), 24 deletions(-) create mode 100644 docs/source/notes/custom_operators.rst diff --git a/docs/source/export.rst b/docs/source/export.rst index a4217e8081ba..c6134d187b66 100644 --- a/docs/source/export.rst +++ b/docs/source/export.rst @@ -632,23 +632,17 @@ number of paths. In such cases, users will need to rewrite their code using special control flow operators. Currently, we support :ref:`torch.cond ` to express if-else like control flow (more coming soon!). -Missing Meta Kernels for Operators -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Missing Fake/Meta/Abstract Kernels for Operators +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -When tracing, a META implementation (or "meta kernel") is required for all -operators. This is used to reason about the input/output shapes for this -operator. +When tracing, a FakeTensor kernel (aka meta kernel, abstract impl) is +required for all operators. This is used to reason about the input/output shapes +for this operator. -To register a meta kernel for a C++ Custom Operator, please refer to -`this documentation `__. - -The official API for registering custom meta kernels for custom ops implemented -in python is currently undergoing development. While the final API is being -refined, you can refer to the documentation -`here `_. +Please see :func:`torch.library.register_fake` for more details. In the unfortunate case where your model uses an ATen operator that is does not -have a meta kernel implementation yet, please file an issue. +have a FakeTensor kernel implementation yet, please file an issue. Read More diff --git a/docs/source/library.rst b/docs/source/library.rst index 236da45f93c1..f632d93d1ec4 100644 --- a/docs/source/library.rst +++ b/docs/source/library.rst @@ -1,3 +1,5 @@ +.. _torch-library-docs: + torch.library =================================== .. py:module:: torch.library @@ -9,7 +11,8 @@ custom operators, and extending operators defined with PyTorch's C++ operator registration APIs (e.g. aten operators). For a detailed guide on effectively using these APIs, please see -`this gdoc `_ +Please see :ref:`custom-ops-landing-page` +for more details on how to effectively use these APIs. Testing custom ops ------------------ diff --git a/docs/source/notes/custom_operators.rst b/docs/source/notes/custom_operators.rst new file mode 100644 index 000000000000..2cdf214351b0 --- /dev/null +++ b/docs/source/notes/custom_operators.rst @@ -0,0 +1,56 @@ +.. _custom-ops-landing-page: + +PyTorch Custom Operators Landing Page +===================================== + +PyTorch offers a large library of operators that work on Tensors (e.g. :func:`torch.add`, +:func:`torch.sum`, etc). However, you may wish to bring a new custom operation to PyTorch +and get it to work with subsystems like :func:`torch.compile`, autograd, and :func:`torch.vmap`. +In order to do so, you must register the custom operation with PyTorch via the Python +:ref:`torch-library-docs` or C++ TORCH_LIBRARY APIs. + +TL;DR +----- + +How do I author a custom op from Python? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. + [comment] TODO(rzou): The following will be a link to a tutorial on the PyTorch tutorials site in 2.4 + +Please see the `Python Custom Operators tutorial `_ + + +How do I integrate custom C++ and/or CUDA code with PyTorch? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. + [comment] TODO(rzou): The following will be a link to a tutorial on the PyTorch tutorials site in 2.4 + +Please see the `Custom C++ and CUDA Operators tutorial `_ + + +For more details +^^^^^^^^^^^^^^^^ + +Please see `The Custom Operators Manual (gdoc) `_ +(we're working on moving the information to our docs site). We recommend that you +first read one of the tutorials above and then use the Custom Operators Manual as a reference; +it is not meant to be read head to toe. + +When should I create a Custom Operator? +--------------------------------------- +If your operation is expressible as a composition of built-in PyTorch operators +then please write it as a Python function and call it instead of creating a +custom operator. Use the operator registration APIs to create a custom op if you +are calling into some library that PyTorch doesn't understand (e.g. custom C/C++ code, +a custom CUDA kernel, or Python bindings to C/C++/CUDA extensions). + +Why should I create a Custom Operator? +-------------------------------------- + +It is possible to use a C/C++/CUDA kernel by grabbing a Tensor's data pointer +and passing it to a pybind'ed kernel. However, this approach doesn't compose with +PyTorch subsystems like autograd, torch.compile, vmap, and more. In order +for an operation to compose with PyTorch subsystems, it must be registered +via the operator registration APIs. diff --git a/docs/source/notes/extending.rst b/docs/source/notes/extending.rst index 80796375c3fe..bf69d0e012f6 100644 --- a/docs/source/notes/extending.rst +++ b/docs/source/notes/extending.rst @@ -4,6 +4,18 @@ Extending PyTorch In this note we'll cover ways of extending :mod:`torch.nn`, :mod:`torch.autograd`, :mod:`torch`, and writing custom C++ extensions. +Adding new operators +-------------------- + +PyTorch offers a large library of operators that work on Tensors (e.g. :func:`torch.add`, +:func:`torch.sum`, etc). However, you may wish to bring a new custom operation to PyTorch +and have it behave like PyTorch's built-in operators. In order to do so, you must +register the custom operation with PyTorch via the Python :ref:`torch-library-docs` or C++ TORCH_LIBRARY +APIs. + + +Please see :ref:`custom-ops-landing-page` for more details. + .. _extending-autograd: Extending :mod:`torch.autograd` @@ -968,13 +980,3 @@ Which prints the following, with extra comments:: Dispatch Log: aten.mul.Tensor(*(tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]), 2), **{}) Dispatch Log: aten.detach.default(*(tensor([2., 2., 2., 2., 2., 2., 2., 2., 2., 2.]),), **{}) Dispatch Log: aten.detach.default(*(tensor([2., 2., 2., 2., 2., 2., 2., 2., 2., 2.]),), **{}) - - -Writing custom C++ extensions ------------------------------ - -See this -`PyTorch tutorial `_ -for a detailed explanation and examples. - -Documentations are available at :doc:`../cpp_extension`. From 67739d8c6ff0d6331b570a547346920bed810838 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 30 May 2024 01:16:57 +0000 Subject: [PATCH 073/706] Revert "[Submodule] Remove deprecated USE_TBB option and TBB submodule (#127051)" This reverts commit 699db7988d84d163ebb6919f78885e4630182a7a. Reverted https://github.com/pytorch/pytorch/pull/127051 on behalf of https://github.com/PaliC due to This PR needs to be synced using the import button as there is a bug in our diff train ([comment](https://github.com/pytorch/pytorch/pull/127051#issuecomment-2138496995)) --- .ci/pytorch/build.sh | 5 +- .ci/pytorch/test.sh | 18 + .gitmodules | 4 + BUILD.bazel | 14 +- CMakeLists.txt | 6 + WORKSPACE | 10 + aten/src/ATen/CMakeLists.txt | 10 + aten/src/ATen/Config.h.in | 1 + aten/src/ATen/Parallel.h | 2 + aten/src/ATen/ParallelCommon.cpp | 2 + aten/src/ATen/ParallelNativeTBB.cpp | 115 ++++++ aten/src/ATen/ParallelNativeTBB.h | 52 +++ aten/src/ATen/ParallelThreadPoolNative.cpp | 2 +- aten/src/ATen/cpu/tbb/CMakeLists.txt | 391 ++++++++++++++++++ .../ATen/cpu/tbb/extra/version_string.ver.in | 11 + buckbuild.bzl | 4 + build_variables.bzl | 1 + caffe2/CMakeLists.txt | 19 + cmake/Dependencies.cmake | 29 ++ cmake/Modules/FindMKL.cmake | 4 +- cmake/Modules/FindMKLDNN.cmake | 2 +- cmake/Summary.cmake | 4 + cmake/public/utils.cmake | 3 + defs.bzl | 2 + setup.py | 8 + third_party/mkl-dnn.BUILD | 5 +- third_party/mkl.BUILD | 5 +- third_party/tbb | 1 + third_party/tbb.BUILD | 75 ++++ third_party/tbb.patch | 34 ++ torch/testing/_internal/common_modules.py | 4 +- torch/testing/_internal/common_optimizers.py | 3 +- torch/testing/_internal/common_utils.py | 33 ++ torch/utils/cpp_extension.py | 3 + 34 files changed, 863 insertions(+), 19 deletions(-) create mode 100644 aten/src/ATen/ParallelNativeTBB.cpp create mode 100644 aten/src/ATen/ParallelNativeTBB.h create mode 100644 aten/src/ATen/cpu/tbb/CMakeLists.txt create mode 100644 aten/src/ATen/cpu/tbb/extra/version_string.ver.in create mode 160000 third_party/tbb create mode 100644 third_party/tbb.BUILD create mode 100644 third_party/tbb.patch diff --git a/.ci/pytorch/build.sh b/.ci/pytorch/build.sh index 187e6d788bdd..130b770a2cc2 100755 --- a/.ci/pytorch/build.sh +++ b/.ci/pytorch/build.sh @@ -44,7 +44,10 @@ if [[ "$BUILD_ENVIRONMENT" == *cuda11* ]]; then fi fi -if [[ ${BUILD_ENVIRONMENT} == *"parallelnative"* ]]; then +if [[ ${BUILD_ENVIRONMENT} == *"paralleltbb"* ]]; then + export ATEN_THREADING=TBB + export USE_TBB=1 +elif [[ ${BUILD_ENVIRONMENT} == *"parallelnative"* ]]; then export ATEN_THREADING=NATIVE fi diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 76d7e259f365..0dcf5cd0b388 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -776,6 +776,7 @@ test_aten() { ${SUDO} ln -sf "$TORCH_LIB_DIR"/libmkldnn* "$TEST_BASE_DIR" ${SUDO} ln -sf "$TORCH_LIB_DIR"/libnccl* "$TEST_BASE_DIR" ${SUDO} ln -sf "$TORCH_LIB_DIR"/libtorch* "$TEST_BASE_DIR" + ${SUDO} ln -sf "$TORCH_LIB_DIR"/libtbb* "$TEST_BASE_DIR" ls "$TEST_BASE_DIR" aten/tools/run_tests.sh "$TEST_BASE_DIR" @@ -800,6 +801,21 @@ test_without_numpy() { popd } +# pytorch extensions require including torch/extension.h which includes all.h +# which includes utils.h which includes Parallel.h. +# So you can call for instance parallel_for() from your extension, +# but the compilation will fail because of Parallel.h has only declarations +# and definitions are conditionally included Parallel.h(see last lines of Parallel.h). +# I tried to solve it #39612 and #39881 by including Config.h into Parallel.h +# But if Pytorch is built with TBB it provides Config.h +# that has AT_PARALLEL_NATIVE_TBB=1(see #3961 or #39881) and it means that if you include +# torch/extension.h which transitively includes Parallel.h +# which transitively includes tbb.h which is not available! +if [[ "${BUILD_ENVIRONMENT}" == *tbb* ]]; then + sudo mkdir -p /usr/include/tbb + sudo cp -r "$PWD"/third_party/tbb/include/tbb/* /usr/include/tbb +fi + test_libtorch() { local SHARD="$1" @@ -813,6 +829,7 @@ test_libtorch() { ln -sf "$TORCH_LIB_DIR"/libc10* "$TORCH_BIN_DIR" ln -sf "$TORCH_LIB_DIR"/libshm* "$TORCH_BIN_DIR" ln -sf "$TORCH_LIB_DIR"/libtorch* "$TORCH_BIN_DIR" + ln -sf "$TORCH_LIB_DIR"/libtbb* "$TORCH_BIN_DIR" ln -sf "$TORCH_LIB_DIR"/libnvfuser* "$TORCH_BIN_DIR" export CPP_TESTS_DIR="${TORCH_BIN_DIR}" @@ -949,6 +966,7 @@ test_rpc() { # test reporting process to function as expected. ln -sf "$TORCH_LIB_DIR"/libtorch* "$TORCH_BIN_DIR" ln -sf "$TORCH_LIB_DIR"/libc10* "$TORCH_BIN_DIR" + ln -sf "$TORCH_LIB_DIR"/libtbb* "$TORCH_BIN_DIR" CPP_TESTS_DIR="${TORCH_BIN_DIR}" python test/run_test.py --cpp --verbose -i cpp/test_cpp_rpc } diff --git a/.gitmodules b/.gitmodules index 476f11fd945c..fe30ac3d0e5b 100644 --- a/.gitmodules +++ b/.gitmodules @@ -82,6 +82,10 @@ ignore = dirty path = third_party/foxi url = https://github.com/houseroad/foxi.git +[submodule "third_party/tbb"] + path = third_party/tbb + url = https://github.com/01org/tbb + branch = tbb_2018 [submodule "android/libs/fbjni"] ignore = dirty path = android/libs/fbjni diff --git a/BUILD.bazel b/BUILD.bazel index ecbeaab9bbf8..8c1aa2729101 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -125,6 +125,10 @@ filegroup( data = [":generate-code"], ) +exports_files( + srcs = ["aten/src/ATen/cpu/tbb/extra/version_string.ver.in"], +) + # ATen filegroup( name = "aten_base_cpp", @@ -271,6 +275,7 @@ header_template_rule( "@AT_BUILD_WITH_LAPACK@": "1", "@AT_PARALLEL_OPENMP@": "0", "@AT_PARALLEL_NATIVE@": "1", + "@AT_PARALLEL_NATIVE_TBB@": "0", "@AT_BLAS_F2C@": "0", "@AT_BLAS_USE_CBLAS_DOT@": "1", }, @@ -354,9 +359,6 @@ cc_library( ":aten_src_ATen_config", ] + generated_cpu_cpp + aten_ufunc_generated_cpu_sources("aten/src/ATen/{}"), copts = ATEN_COPTS, - linkopts = [ - "-ldl", - ], data = if_cuda( [":libcaffe2_nvrtc.so"], [], @@ -770,9 +772,6 @@ cc_library( ], )) + torch_sources, copts = TORCH_COPTS, - linkopts = [ - "-lrt", - ], defines = [ "CAFFE2_NIGHTLY_VERSION=20200115", ], @@ -792,9 +791,6 @@ cc_library( cc_library( name = "shm", srcs = glob(["torch/lib/libshm/*.cpp"]), - linkopts = [ - "-lrt", - ], deps = [ ":torch", ], diff --git a/CMakeLists.txt b/CMakeLists.txt index 335f5750648c..10a92dcc7c2c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -362,6 +362,9 @@ cmake_dependent_option( cmake_dependent_option( USE_TENSORPIPE "Use TensorPipe. Only available if USE_DISTRIBUTED is on." ON "USE_DISTRIBUTED" OFF) +option(USE_TBB "Use TBB (Deprecated)" OFF) +cmake_dependent_option( + USE_SYSTEM_TBB "Use system-provided Intel TBB." OFF "USE_TBB" OFF) option(ONNX_ML "Enable traditional ONNX ML API." ON) option(HAVE_SOVERSION "Whether to add SOVERSION to the shared objects" OFF) option(BUILD_LIBTORCH_CPU_WITH_DEBUG @@ -480,6 +483,9 @@ if(USE_SYSTEM_LIBS) if(USE_NCCL) set(USE_SYSTEM_NCCL ON) endif() + if(USE_TBB) + set(USE_SYSTEM_TBB ON) + endif() endif() # Used when building Caffe2 through setup.py diff --git a/WORKSPACE b/WORKSPACE index 5b4f2f2e3375..f7e604332213 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -168,6 +168,16 @@ new_local_repository( path = "third_party/opentelemetry-cpp", ) +new_patched_local_repository( + name = "tbb", + build_file = "//third_party:tbb.BUILD", + patch_strip = 1, + patches = [ + "@//third_party:tbb.patch", + ], + path = "third_party/tbb", +) + new_local_repository( name = "tensorpipe", build_file = "//third_party:tensorpipe.BUILD", diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 9fa7a1f2305b..9ec458fda45e 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -349,6 +349,16 @@ endif() list(APPEND ATen_CPU_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/..) +if(USE_TBB) + if(USE_SYSTEM_TBB) + message("ATen is compiled with system-provided Intel TBB.") + else() + message("ATen is compiled with Intel TBB (${TBB_ROOT_DIR}).") + endif() + list(APPEND ATen_CPU_INCLUDE ${TBB_INCLUDE_DIR}) + list(APPEND ATen_CPU_DEPENDENCY_LIBS TBB::tbb) +endif() + if(BLAS_FOUND) if($ENV{TH_BINARY_BUILD}) message(STATUS "TH_BINARY_BUILD detected. Enabling special linkage.") diff --git a/aten/src/ATen/Config.h.in b/aten/src/ATen/Config.h.in index fdd2ac2bc5f7..93b8e0434f1a 100644 --- a/aten/src/ATen/Config.h.in +++ b/aten/src/ATen/Config.h.in @@ -17,5 +17,6 @@ #define AT_BUILD_WITH_LAPACK() @AT_BUILD_WITH_LAPACK@ #define AT_PARALLEL_OPENMP @AT_PARALLEL_OPENMP@ #define AT_PARALLEL_NATIVE @AT_PARALLEL_NATIVE@ +#define AT_PARALLEL_NATIVE_TBB @AT_PARALLEL_NATIVE_TBB@ #define AT_BLAS_F2C() @AT_BLAS_F2C@ #define AT_BLAS_USE_CBLAS_DOT() @AT_BLAS_USE_CBLAS_DOT@ diff --git a/aten/src/ATen/Parallel.h b/aten/src/ATen/Parallel.h index 966e29c0289f..ff14f568d22a 100644 --- a/aten/src/ATen/Parallel.h +++ b/aten/src/ATen/Parallel.h @@ -153,6 +153,8 @@ TORCH_API int intraop_default_num_threads(); #include // IWYU pragma: keep #elif AT_PARALLEL_NATIVE #include // IWYU pragma: keep +#elif AT_PARALLEL_NATIVE_TBB +#include // IWYU pragma: keep #endif #include // IWYU pragma: keep diff --git a/aten/src/ATen/ParallelCommon.cpp b/aten/src/ATen/ParallelCommon.cpp index e5d9bb83c016..0504a066eef5 100644 --- a/aten/src/ATen/ParallelCommon.cpp +++ b/aten/src/ATen/ParallelCommon.cpp @@ -80,6 +80,8 @@ std::string get_parallel_info() { ss << "OpenMP"; #elif AT_PARALLEL_NATIVE ss << "native thread pool"; + #elif AT_PARALLEL_NATIVE_TBB + ss << "native thread pool and TBB"; #endif #ifdef C10_MOBILE ss << " [mobile]"; diff --git a/aten/src/ATen/ParallelNativeTBB.cpp b/aten/src/ATen/ParallelNativeTBB.cpp new file mode 100644 index 000000000000..06be418f7d9c --- /dev/null +++ b/aten/src/ATen/ParallelNativeTBB.cpp @@ -0,0 +1,115 @@ +#include +#if AT_PARALLEL_NATIVE_TBB +#include +#include +#include + +#include +#include + +#include +#define TBB_PREVIEW_GLOBAL_CONTROL 1 +#include + +#ifdef _OPENMP +#include +#endif + +#if AT_MKL_ENABLED() +#include +#endif + +namespace at { + +namespace { +static thread_local tbb::task_group tg_; +thread_local int this_thread_id{0}; + +std::mutex global_thread_mutex_; +std::shared_ptr global_thread_limit_ = nullptr; +std::atomic num_intraop_threads_{-1}; + +void _internal_set_num_threads(int nthreads) { + TORCH_INTERNAL_ASSERT(nthreads > 0); + { + std::unique_lock lk(global_thread_mutex_); + // This is an antipattern and we shouldn't be constraining the number of + // threads in library code. + // TODO: Think of a smarter way to leverage tbb::thread_arena to limit the + // number of slots instead of the number of threads. + global_thread_limit_ = std::make_shared( + tbb::global_control::max_allowed_parallelism, nthreads); + num_intraop_threads_.store(nthreads); + } +} +} + +void init_num_threads() { + #ifdef _OPENMP + omp_set_num_threads(1); + #endif + + #if AT_MKL_ENABLED() + mkl_set_num_threads(1); + #endif + + int nthreads = num_intraop_threads_.load(); + if (nthreads < 0) { + nthreads = intraop_default_num_threads(); + } + _internal_set_num_threads(nthreads); +} + +void set_num_threads(int nthreads) { + TORCH_CHECK(nthreads > 0); + + _internal_set_num_threads(nthreads); +} + +int get_num_threads() { + at::internal::lazy_init_num_threads(); + return tbb::global_control::active_value( + tbb::global_control::max_allowed_parallelism); +} + +int get_thread_num() { + return this_thread_id; +} + +namespace internal { +void set_thread_num(int id) { + this_thread_id = id; +} +} + +bool in_parallel_region() { + return tbb::this_task_arena::current_thread_index() >= 0; +} + +void intraop_launch(std::function func) { + if (get_num_threads() > 1) { + tg_.run(func); + } else { + func(); + } +} + +c10::intrusive_ptr intraop_launch_future( + std::function func) { + auto future = c10::make_intrusive(NoneType::get()); + if (get_num_threads() > 1) { + tg_.run( + [func, future]() { + func(); + future->markCompleted(); + } + ); + } else { + func(); + future->markCompleted(); + } + return future; +} + +} // namespace at +#endif diff --git a/aten/src/ATen/ParallelNativeTBB.h b/aten/src/ATen/ParallelNativeTBB.h new file mode 100644 index 000000000000..9193e06ed695 --- /dev/null +++ b/aten/src/ATen/ParallelNativeTBB.h @@ -0,0 +1,52 @@ +#pragma once + +#include +#include +#include + +#include + +#ifdef _WIN32 +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN +#endif +#endif +#include + +#define INTRA_OP_PARALLEL + +namespace at::internal { + +template +inline void invoke_parallel( + const int64_t begin, + const int64_t end, + const int64_t grain_size, + const F& f) { + // Choose number of tasks based on grain size and number of threads. + int64_t chunk_size = divup((end - begin), get_num_threads()); + // Make sure each task is at least grain_size size. + chunk_size = std::max(grain_size, chunk_size); + + std::atomic_flag err_flag = ATOMIC_FLAG_INIT; + std::exception_ptr eptr; + tbb::parallel_for( + tbb::blocked_range(begin, end, chunk_size), + [&eptr, &err_flag, f](const tbb::blocked_range& r) { + try { + internal::ThreadIdGuard tid_guard( + tbb::this_task_arena::current_thread_index()); + f(r.begin(), r.end()); + } catch (...) { + if (!err_flag.test_and_set()) { + eptr = std::current_exception(); + } + } + }, + tbb::static_partitioner{}); + if (eptr) { + std::rethrow_exception(eptr); + } +} + +} // namespace at::internal diff --git a/aten/src/ATen/ParallelThreadPoolNative.cpp b/aten/src/ATen/ParallelThreadPoolNative.cpp index 3ea51bb5e683..a9d5095f32a9 100644 --- a/aten/src/ATen/ParallelThreadPoolNative.cpp +++ b/aten/src/ATen/ParallelThreadPoolNative.cpp @@ -1,5 +1,5 @@ #include -#if AT_PARALLEL_OPENMP || AT_PARALLEL_NATIVE +#if AT_PARALLEL_OPENMP || AT_PARALLEL_NATIVE || AT_PARALLEL_NATIVE_TBB #include #include #include diff --git a/aten/src/ATen/cpu/tbb/CMakeLists.txt b/aten/src/ATen/cpu/tbb/CMakeLists.txt new file mode 100644 index 000000000000..6e946d5d13d5 --- /dev/null +++ b/aten/src/ATen/cpu/tbb/CMakeLists.txt @@ -0,0 +1,391 @@ +# Based on https://github.com/wjakob/tbb/blob/master/CMakeLists.txt +# All credit goes to Wenzel Jakob! + +cmake_minimum_required(VERSION 2.8.12 FATAL_ERROR) +project(tbb CXX) + +include(CheckCXXCompilerFlag) +include(CheckCXXSourceRuns) + +if(POLICY CMP0058) + cmake_policy(SET CMP0058 NEW) +endif() + +if(NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES) + message(STATUS "Setting build type to 'Release' as none was specified.") + set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose the type of build." FORCE) + set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" + "MinSizeRel" "RelWithDebInfo") +endif() + +if(NOT TBB_ROOT_DIR) + set(TBB_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}") +endif() +if(NOT TBB_INSTALL_EXPORT_NAME) + set(TBB_INSTALL_EXPORT_NAME "Caffe2Targets") +endif() +if(NOT TBB_INSTALL_EXPORT_DESTINATION) + set(TBB_INSTALL_EXPORT_DESTINATION lib) +endif() +if(NOT TBB_INSTALL_RUNTIME_DIR) + set(TBB_INSTALL_RUNTIME_DIR bin) +endif() +if(NOT TBB_INSTALL_LIBRARY_DIR) + set(TBB_INSTALL_LIBRARY_DIR lib) +endif() +if(NOT TBB_INSTALL_ARCHIVE_DIR) + set(TBB_INSTALL_ARCHIVE_DIR lib) +endif() +if(NOT TBB_INSTALL_INCLUDE_DIR) + set(TBB_INSTALL_INCLUDE_DIR "${TBB_ROOT_DIR}/include") +endif() + +set(TBB_INCLUDES + "${TBB_ROOT_DIR}/include" + "${TBB_ROOT_DIR}/src" + "${TBB_ROOT_DIR}/src/rml/include" + ${CMAKE_CURRENT_BINARY_DIR}) + +option(TBB_BUILD_SHARED "Build TBB shared library" ON) +option(TBB_BUILD_STATIC "Build TBB static library" ON) +option(TBB_BUILD_TBBMALLOC "Build TBB malloc library" ON) +option(TBB_BUILD_TBBMALLOC_PROXY "Build TBB malloc proxy library" ON) +option(TBB_BUILD_TESTS "Build TBB tests and enable testing infrastructure" ON) +option(TBB_CI_BUILD "Is this a continuous integration build?" OFF) + +if(APPLE) + set(CMAKE_MACOSX_RPATH ON) +endif() + +file(GLOB tbb_src "${TBB_ROOT_DIR}/src/tbb/*.cpp" "${TBB_ROOT_DIR}/src/old/*.cpp") +list(APPEND tbb_src ${TBB_ROOT_DIR}/src/rml/client/rml_tbb.cpp) +file(GLOB to_remove "${TBB_ROOT_DIR}/src/old/test*.cpp") +if(NOT "${to_remove}" STREQUAL "") + list(REMOVE_ITEM tbb_src ${to_remove}) +endif() + +set(tbbmalloc_static_src + src/tbbmalloc/backend.cpp + src/tbbmalloc/large_objects.cpp + src/tbbmalloc/backref.cpp + src/tbbmalloc/tbbmalloc.cpp + src/tbbmalloc/frontend.cpp + src/tbb/itt_notify.cpp) + +set(tbbmalloc_src ${tbbmalloc_static_src}) + +set(tbbmalloc_proxy_src + src/tbbmalloc/proxy.cpp + src/tbbmalloc/tbb_function_replacement.cpp) + +if(CMAKE_SYSTEM_PROCESSOR MATCHES "(i386|x86_64)") + if(NOT APPLE AND NOT MINGW) + add_definitions(-DDO_ITT_NOTIFY) + endif() +endif() + +if(APPLE) + # Disable annoying "has no symbols" warnings + set(CMAKE_C_ARCHIVE_CREATE " Scr ") + set(CMAKE_CXX_ARCHIVE_CREATE " Scr ") + set(CMAKE_C_ARCHIVE_FINISH " -no_warning_for_no_symbols -c ") + set(CMAKE_CXX_ARCHIVE_FINISH " -no_warning_for_no_symbols -c ") +endif() + +macro(CHECK_CXX_COMPILER_AND_LINKER_FLAGS _RESULT _CXX_FLAGS _LINKER_FLAGS) + set(CMAKE_REQUIRED_FLAGS ${_CXX_FLAGS}) + set(CMAKE_REQUIRED_LIBRARIES ${_LINKER_FLAGS}) + set(CMAKE_REQUIRED_QUIET TRUE) + check_cxx_source_runs("#include \nint main(int argc, char **argv) { std::cout << \"test\"; return 0; }" ${_RESULT}) + set(CMAKE_REQUIRED_FLAGS "") + set(CMAKE_REQUIRED_LIBRARIES "") +endmacro() + +# Prefer libc++ in conjunction with Clang +if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") + if(CMAKE_CXX_FLAGS MATCHES "-stdlib=libc\\+\\+") + message(STATUS "TBB: using libc++.") + else() + CHECK_CXX_COMPILER_AND_LINKER_FLAGS(HAS_LIBCPP "-stdlib=libc++" "-stdlib=libc++") + if(HAS_LIBCPP) + string(APPEND CMAKE_CXX_FLAGS " -stdlib=libc++ -D_LIBCPP_VERSION") + string(APPEND CMAKE_EXE_LINKER_FLAGS " -stdlib=libc++") + string(APPEND CMAKE_SHARED_LINKER_FLAGS " -stdlib=libc++") + message(STATUS "TBB: using libc++.") + else() + message(STATUS "TBB: NOT using libc++.") + endif() + endif() +endif() + +if(UNIX) + add_definitions(-DUSE_PTHREAD) + + check_cxx_compiler_flag("-std=c++17" SUPPORTS_STDCXX17) + if(SUPPORTS_STDCXX17) + set(CMAKE_CXX_FLAGS "-std=c++17 ${CMAKE_CXX_FLAGS}") + endif() + + check_cxx_compiler_flag("-mrtm -Werror" SUPPORTS_MRTM) + if(SUPPORTS_MRTM) + set(CMAKE_CXX_FLAGS "-mrtm ${CMAKE_CXX_FLAGS}") + endif() + +elseif(WIN32) + if(MSVC) + cmake_minimum_required(VERSION 3.1) + enable_language(ASM_MASM) + set(CMAKE_CXX_FLAGS "/GS- /Zc:wchar_t /Zc:forScope /DUSE_WINTHREAD ${CMAKE_CXX_FLAGS}") + set(CMAKE_CXX_FLAGS "/D_CRT_SECURE_NO_DEPRECATE /D_WIN32_WINNT=0x0600 ${CMAKE_CXX_FLAGS}") + check_cxx_compiler_flag("/volatile:iso" SUPPORTS_VOLATILE_FLAG) + if(SUPPORTS_VOLATILE_FLAG) + set(CMAKE_CXX_FLAGS "/volatile:iso ${CMAKE_CXX_FLAGS}") + endif() + set(CMAKE_CXX_FLAGS "/wd4267 /wd4800 /wd4146 /wd4244 /wd4577 /wd4018 ${CMAKE_CXX_FLAGS}") + if(NOT CMAKE_SIZEOF_VOID_P) + message(FATAL_ERROR "'CMAKE_SIZEOF_VOID_P' is undefined. Please delete your build directory and rerun CMake again!") + endif() + + if(CMAKE_SIZEOF_VOID_P EQUAL 8) + list(APPEND tbb_src "${TBB_ROOT_DIR}/src/tbb/intel64-masm/atomic_support.asm") + list(APPEND tbb_src "${TBB_ROOT_DIR}/src/tbb/intel64-masm/itsx.asm") + list(APPEND tbb_src "${TBB_ROOT_DIR}/src/tbb/intel64-masm/intel64_misc.asm") + list(APPEND tbbmalloc_src "${TBB_ROOT_DIR}/src/tbb/intel64-masm/atomic_support.asm") + set(CMAKE_ASM_MASM_FLAGS "/DEM64T=1 ${CMAKE_ASM_MASM_FLAGS}") + else() + list(APPEND tbb_src "${TBB_ROOT_DIR}/src/tbb/ia32-masm/atomic_support.asm" + "${TBB_ROOT_DIR}/src/tbb/ia32-masm/itsx.asm src/tbb/ia32-masm/lock_byte.asm") + # Enable SAFESEH feature for assembly (x86 builds only). + set(CMAKE_ASM_MASM_FLAGS "/safeseh ${CMAKE_ASM_MASM_FLAGS}") + endif() + elseif(MINGW) + add_definitions(-DUSE_WINTHREAD) + add_definitions(-D_WIN32_WINNT=0x0502) + set(CMAKE_CXX_FLAGS "-mthreads ${CMAKE_CXX_FLAGS}") + endif() +endif() + +if(MSVC) + set(ENABLE_RTTI "/EHsc /GR ") + set(DISABLE_RTTI "/EHs- /GR- ") +elseif(UNIX) + set(ENABLE_RTTI "-frtti -fexceptions ") + set(DISABLE_RTTI "-fno-rtti -fno-exceptions ") +endif() + +##-------- +# - Added TBB_USE_GLIBCXX_VERSION macro to specify the version of GNU +# libstdc++ when it cannot be properly recognized, e.g. when used +# with Clang on Linux* OS. Inspired by a contribution from David A. +if(NOT TBB_USE_GLIBCXX_VERSION AND UNIX AND NOT APPLE) + if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") + # using Clang + string(REPLACE "." "0" TBB_USE_GLIBCXX_VERSION ${CMAKE_CXX_COMPILER_VERSION}) + endif() +endif() + +if(TBB_USE_GLIBCXX_VERSION) + add_definitions(-DTBB_USE_GLIBCXX_VERSION=${TBB_USE_GLIBCXX_VERSION}) +endif() + +##------- + +if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") + check_cxx_compiler_flag("-flifetime-dse=1" SUPPORTS_FLIFETIME) + if(SUPPORTS_FLIFETIME) + add_definitions(-flifetime-dse=1) + endif() +endif() + +# Linker export definitions +if(APPLE) + set(ARCH_PREFIX "mac") +elseif(WIN32) + set(ARCH_PREFIX "win") +else() + set(ARCH_PREFIX "lin") +endif() + +if(CMAKE_SIZEOF_VOID_P EQUAL 8) + set(ARCH_PREFIX "${ARCH_PREFIX}64") +else() + set(ARCH_PREFIX "${ARCH_PREFIX}32") +endif() + +if(MINGW) + set(ARCH_PREFIX "${ARCH_PREFIX}-gcc") + # there's no win32-gcc-tbb-export.def, use lin32-tbb-export.def + execute_process(COMMAND ${CMAKE_COMMAND} -E copy ${TBB_ROOT_DIR}/src/tbb/lin32-tbb-export.def ${TBB_ROOT_DIR}/src/tbb/win32-gcc-tbb-export.def) +endif() + +if(MSVC) + add_custom_command(OUTPUT tbb.def + COMMAND ${CMAKE_CXX_COMPILER} /TC /EP ${TBB_ROOT_DIR}/src/tbb/${ARCH_PREFIX}-tbb-export.def -I ${TBB_ROOT_DIR}/include > tbb.def + MAIN_DEPENDENCY ${TBB_ROOT_DIR}/src/tbb/${ARCH_PREFIX}-tbb-export.def + COMMENT "Preprocessing tbb.def" + ) + + add_custom_command(OUTPUT tbbmalloc.def + COMMAND ${CMAKE_CXX_COMPILER} /TC /EP ${TBB_ROOT_DIR}/src/tbbmalloc/${ARCH_PREFIX}-tbbmalloc-export.def -I ${TBB_ROOT_DIR}/include > tbbmalloc.def + MAIN_DEPENDENCY ${TBB_ROOT_DIR}/src/tbbmalloc/${ARCH_PREFIX}-tbbmalloc-export.def + COMMENT "Preprocessing tbbmalloc.def" + ) +else() + add_custom_command(OUTPUT tbb.def + COMMAND ${CMAKE_CXX_COMPILER} -xc++ -E ${TBB_ROOT_DIR}/src/tbb/${ARCH_PREFIX}-tbb-export.def -I ${TBB_ROOT_DIR}/include -o tbb.def + MAIN_DEPENDENCY ${TBB_ROOT_DIR}/src/tbb/${ARCH_PREFIX}-tbb-export.def + COMMENT "Preprocessing tbb.def" + ) + + add_custom_command(OUTPUT tbbmalloc.def + COMMAND ${CMAKE_CXX_COMPILER} -xc++ -E ${TBB_ROOT_DIR}/src/tbbmalloc/${ARCH_PREFIX}-tbbmalloc-export.def -I ${TBB_ROOT_DIR}/include -o tbbmalloc.def + MAIN_DEPENDENCY ${TBB_ROOT_DIR}/src/tbbmalloc/${ARCH_PREFIX}-tbbmalloc-export.def + COMMENT "Preprocessing tbbmalloc.def" + ) +endif() + +add_custom_target(tbb_def_files DEPENDS tbb.def tbbmalloc.def) + +# TBB library +if(TBB_BUILD_STATIC) + add_library(tbb_static STATIC ${tbb_src}) + target_include_directories(tbb_static PRIVATE ${TBB_INCLUDES}) + set_property(TARGET tbb_static APPEND PROPERTY COMPILE_DEFINITIONS "__TBB_BUILD=1") + set_property(TARGET tbb_static APPEND_STRING PROPERTY COMPILE_FLAGS ${ENABLE_RTTI}) + install(TARGETS tbb_static + EXPORT ${TBB_INSTALL_EXPORT_NAME} DESTINATION ${TBB_INSTALL_EXPORT_DESTINATION} + ARCHIVE DESTINATION ${TBB_INSTALL_ARCHIVE_DIR}) + if(MSVC) + target_compile_definitions(tbb_static PUBLIC __TBB_NO_IMPLICIT_LINKAGE=1) + endif() + + if(UNIX AND NOT APPLE) + target_link_libraries(tbb_static PUBLIC pthread dl) + endif() +endif() + +if(TBB_BUILD_SHARED) + add_library(tbb SHARED ${tbb_src}) + target_include_directories(tbb PRIVATE ${TBB_INCLUDES}) + set_property(TARGET tbb APPEND PROPERTY COMPILE_DEFINITIONS "__TBB_BUILD=1") + set_property(TARGET tbb APPEND_STRING PROPERTY COMPILE_FLAGS ${ENABLE_RTTI}) + add_dependencies(tbb tbb_def_files) + + if(APPLE) + set_property(TARGET tbb APPEND PROPERTY LINK_FLAGS "-Wl,-exported_symbols_list,\"${CMAKE_CURRENT_BINARY_DIR}/tbb.def\"") + elseif(MSVC) + set_property(TARGET tbb APPEND PROPERTY LINK_FLAGS "/DEF:\"${CMAKE_CURRENT_BINARY_DIR}/tbb.def\"") + else() + set_property(TARGET tbb APPEND PROPERTY LINK_FLAGS "-Wl,-version-script,\"${CMAKE_CURRENT_BINARY_DIR}/tbb.def\"") + endif() + + install(TARGETS tbb + EXPORT ${TBB_INSTALL_EXPORT_NAME} DESTINATION ${TBB_INSTALL_EXPORT_DESTINATION} + LIBRARY DESTINATION ${TBB_INSTALL_LIBRARY_DIR} + ARCHIVE DESTINATION ${TBB_INSTALL_ARCHIVE_DIR} + RUNTIME DESTINATION ${TBB_INSTALL_RUNTIME_DIR}) + if(UNIX AND NOT APPLE) + target_link_libraries(tbb PUBLIC pthread dl) + endif() + if(MSVC) + target_compile_definitions(tbb PUBLIC __TBB_NO_IMPLICIT_LINKAGE=1) + endif() +endif() + + +if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") + # Quench a warning on GCC + set_source_files_properties(${TBB_ROOT_DIR}/src/tbb/governor.cpp COMPILE_FLAGS "-Wno-missing-field-initializers ") +elseif("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") + # Quench a warning on Clang + set_source_files_properties(${TBB_ROOT_DIR}/src/tbb/itt_notify.cpp COMPILE_FLAGS "-Wno-varargs ") +elseif(MSVC) + # Quench a warning on MSVC + set_source_files_properties(${TBB_ROOT_DIR}/src/tbb/scheduler.cpp COMPILE_FLAGS "/wd4458 ") +endif() + +if(TBB_BUILD_TBBMALLOC) + # TBB malloc library + if(TBB_BUILD_STATIC) + add_library(tbbmalloc_static STATIC ${tbbmalloc_static_src}) + target_include_directories(tbbmalloc_static PRIVATE ${TBB_INCLUDES}) + set_property(TARGET tbbmalloc_static APPEND PROPERTY COMPILE_DEFINITIONS "__TBBMALLOC_BUILD=1") + set_property(TARGET tbbmalloc_static APPEND_STRING PROPERTY COMPILE_FLAGS ${DISABLE_RTTI}) + if(MSVC) + target_compile_definitions(tbbmalloc_static PUBLIC __TBB_NO_IMPLICIT_LINKAGE=1 __TBBMALLOC_NO_IMPLICIT_LINKAGE=1) + endif() + install(TARGETS tbbmalloc_static + EXPORT ${TBB_INSTALL_EXPORT_NAME} DESTINATION ${TBB_INSTALL_EXPORT_DESTINATION} + ARCHIVE DESTINATION ${TBB_INSTALL_ARCHIVE_DIR}) + endif() + + if(TBB_BUILD_SHARED) + add_library(tbbmalloc SHARED ${tbbmalloc_src}) + target_include_directories(tbbmalloc PRIVATE ${TBB_INCLUDES}) + set_property(TARGET tbbmalloc APPEND PROPERTY COMPILE_DEFINITIONS "__TBBMALLOC_BUILD=1") + set_property(TARGET tbbmalloc APPEND_STRING PROPERTY COMPILE_FLAGS ${DISABLE_RTTI}) + add_dependencies(tbbmalloc tbb_def_files) + if(APPLE) + set_property(TARGET tbbmalloc APPEND PROPERTY LINK_FLAGS "-Wl,-exported_symbols_list,\"${CMAKE_CURRENT_BINARY_DIR}/tbbmalloc.def\"") + elseif(MSVC) + set_property(TARGET tbbmalloc APPEND PROPERTY LINK_FLAGS "/DEF:\"${CMAKE_CURRENT_BINARY_DIR}/tbbmalloc.def\"") + else() + set_property(TARGET tbbmalloc APPEND PROPERTY LINK_FLAGS "-Wl,-version-script,\"${CMAKE_CURRENT_BINARY_DIR}/tbbmalloc.def\"") + endif() + if(MSVC) + target_compile_definitions(tbbmalloc PUBLIC __TBB_NO_IMPLICIT_LINKAGE=1 __TBBMALLOC_NO_IMPLICIT_LINKAGE=1) + endif() + install(TARGETS tbbmalloc + EXPORT ${TBB_INSTALL_EXPORT_NAME} DESTINATION ${TBB_INSTALL_EXPORT_DESTINATION} + LIBRARY DESTINATION ${TBB_INSTALL_LIBRARY_DIR} + ARCHIVE DESTINATION ${TBB_INSTALL_ARCHIVE_DIR} + RUNTIME DESTINATION ${TBB_INSTALL_RUNTIME_DIR}) + if(UNIX AND NOT APPLE) + target_link_libraries(tbbmalloc PUBLIC pthread dl) + endif() + endif() +endif() + +if(TBB_BUILD_TBBMALLOC_PROXY) + # TBB malloc proxy library + if(TBB_BUILD_STATIC) + add_library(tbbmalloc_proxy_static STATIC ${tbbmalloc_proxy_src}) + set_property(TARGET tbbmalloc_proxy_static APPEND PROPERTY COMPILE_DEFINITIONS "__TBBMALLOC_BUILD=1") + set_property(TARGET tbbmalloc_proxy_static APPEND_STRING PROPERTY COMPILE_FLAGS ${DISABLE_RTTI}) + install(TARGETS tbbmalloc_proxy_static + EXPORT ${TBB_INSTALL_EXPORT_NAME} DESTINATION ${TBB_INSTALL_EXPORT_DESTINATION} + ARCHIVE DESTINATION ${TBB_INSTALL_ARCHIVE_DIR}) + endif() + + if(TBB_BUILD_SHARED) + add_library(tbbmalloc_proxy SHARED ${tbbmalloc_proxy_src}) + set_property(TARGET tbbmalloc_proxy APPEND PROPERTY COMPILE_DEFINITIONS "__TBBMALLOC_BUILD=1") + set_property(TARGET tbbmalloc_proxy APPEND_STRING PROPERTY COMPILE_FLAGS ${DISABLE_RTTI}) + target_link_libraries(tbbmalloc_proxy PUBLIC tbbmalloc) + install(TARGETS tbbmalloc_proxy + EXPORT ${TBB_INSTALL_EXPORT_NAME} DESTINATION ${TBB_INSTALL_EXPORT_DESTINATION} + LIBRARY DESTINATION ${TBB_INSTALL_LIBRARY_DIR} + ARCHIVE DESTINATION ${TBB_INSTALL_ARCHIVE_DIR} + RUNTIME DESTINATION ${TBB_INSTALL_RUNTIME_DIR}) + if(UNIX AND NOT APPLE) + target_link_libraries(tbbmalloc_proxy PUBLIC pthread dl) + endif() + endif() +endif() + +install(DIRECTORY "${TBB_ROOT_DIR}/include/tbb" DESTINATION ${TBB_INSTALL_INCLUDE_DIR}) + +# version_string.ver +if(UNIX) + execute_process(COMMAND date "+%a, %d %b %Y %H:%M:%S %z" + OUTPUT_VARIABLE _configure_date + OUTPUT_STRIP_TRAILING_WHITESPACE) +elseif(WIN32) + execute_process(COMMAND cmd " /C date /T" + OUTPUT_VARIABLE _configure_date + OUTPUT_STRIP_TRAILING_WHITESPACE) +else() + set(_configure_date "Unknown") +endif() +include_directories(${CMAKE_BINARY_DIR}) +configure_file(extra/version_string.ver.in version_string.ver @ONLY) diff --git a/aten/src/ATen/cpu/tbb/extra/version_string.ver.in b/aten/src/ATen/cpu/tbb/extra/version_string.ver.in new file mode 100644 index 000000000000..bb9f96e8f295 --- /dev/null +++ b/aten/src/ATen/cpu/tbb/extra/version_string.ver.in @@ -0,0 +1,11 @@ +#define __TBB_VERSION_STRINGS(N) \ +#N": BUILD_HOST @CMAKE_SYSTEM_NAME@" ENDL \ +#N": BUILD_OS @CMAKE_SYSTEM@" ENDL \ +#N": BUILD_KERNEL @CMAKE_SYSTEM_VERSION@" ENDL \ +#N": BUILD_GCC @CMAKE_CXX_COMPILER_ID@" ENDL \ +#N": BUILD_LIBC Unknown" ENDL \ +#N": BUILD_LD Unknown" ENDL \ +#N": BUILD_TARGET Unknown" ENDL \ +#N": BUILD_COMMAND Unknown" ENDL + +#define __TBB_DATETIME "@_configure_date@" diff --git a/buckbuild.bzl b/buckbuild.bzl index 649ebe668365..4c4fc9a89a28 100644 --- a/buckbuild.bzl +++ b/buckbuild.bzl @@ -261,6 +261,7 @@ def get_aten_preprocessor_flags(): "-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION", "-DAT_PARALLEL_OPENMP_FBXPLAT=0", "-DAT_PARALLEL_NATIVE_FBXPLAT=1", + "-DAT_PARALLEL_NATIVE_TBB_FBXPLAT=0", "-DUSE_LAPACK_FBXPLAT=0", "-DAT_BLAS_F2C_FBXPLAT=0", "-DAT_BLAS_USE_CBLAS_DOT_FBXPLAT=0", @@ -1111,6 +1112,9 @@ def define_buck_targets( "@AT_PARALLEL_NATIVE@", "AT_PARALLEL_NATIVE_FBXPLAT", "--replace", + "@AT_PARALLEL_NATIVE_TBB@", + "AT_PARALLEL_NATIVE_TBB_FBXPLAT", + "--replace", "@AT_BUILD_WITH_LAPACK@", "USE_LAPACK_FBXPLAT", "--replace", diff --git a/build_variables.bzl b/build_variables.bzl index 8b5ac4f46d7c..5770aca4c9e0 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -999,6 +999,7 @@ aten_cpu_source_non_codegen_list = [ "aten/src/ATen/NestedTensorImpl.cpp", "aten/src/ATen/ParallelCommon.cpp", "aten/src/ATen/ParallelNative.cpp", + "aten/src/ATen/ParallelNativeTBB.cpp", "aten/src/ATen/ParallelOpenMP.cpp", "aten/src/ATen/ParallelThreadPoolNative.cpp", "aten/src/ATen/PythonTorchFunctionTLS.cpp", diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 1e29044e19fd..73c6467075ba 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -16,11 +16,14 @@ endif() # ATen parallelism settings # OMP - OpenMP for intra-op, native thread pool for inter-op parallelism # NATIVE - using native thread pool for intra- and inter-op parallelism +# TBB - using TBB for intra- and native thread pool for inter-op parallelism if(INTERN_BUILD_MOBILE) set(ATEN_THREADING "NATIVE" CACHE STRING "ATen parallel backend") else() if(USE_OPENMP) set(ATEN_THREADING "OMP" CACHE STRING "ATen parallel backend") + elseif(USE_TBB) + set(ATEN_THREADING "TBB" CACHE STRING "ATen parallel backend") else() set(ATEN_THREADING "NATIVE" CACHE STRING "ATen parallel backend") endif() @@ -28,12 +31,19 @@ endif() set(AT_PARALLEL_OPENMP 0) set(AT_PARALLEL_NATIVE 0) +set(AT_PARALLEL_NATIVE_TBB 0) message(STATUS "Using ATen parallel backend: ${ATEN_THREADING}") if("${ATEN_THREADING}" STREQUAL "OMP") set(AT_PARALLEL_OPENMP 1) elseif("${ATEN_THREADING}" STREQUAL "NATIVE") set(AT_PARALLEL_NATIVE 1) +elseif("${ATEN_THREADING}" STREQUAL "TBB") + if(NOT USE_TBB) + message(FATAL_ERROR "Using TBB backend but USE_TBB is off") + endif() + message(WARNING "ATEN TBB Threading is deprectated.") + set(AT_PARALLEL_NATIVE_TBB 1) else() message(FATAL_ERROR "Unknown ATen parallel backend: ${ATEN_THREADING}") endif() @@ -1213,6 +1223,11 @@ if(CMAKE_CXX_COMPILER_ID MATCHES "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "GNU" set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/../aten/src/ATen/native/quantized/qlinear_unpack.cpp PROPERTIES COMPILE_FLAGS -Wno-deprecated-declarations) endif() +if(USE_TBB) + list(APPEND ATen_CPU_INCLUDE ${TBB_INCLUDE_DIR}) + target_link_libraries(torch_cpu PUBLIC TBB::tbb) +endif() + target_include_directories(torch_cpu PRIVATE ${ATen_CPU_INCLUDE}) target_include_directories(torch_cpu PRIVATE @@ -1690,6 +1705,10 @@ if(BUILD_SHARED_LIBS) target_link_libraries(torch_global_deps ${Caffe2_PUBLIC_CUDA_DEPENDENCY_LIBS}) target_link_libraries(torch_global_deps torch::cudart torch::nvtoolsext) endif() + if(USE_TBB) + target_link_libraries(torch_global_deps TBB::tbb) + endif() + install(TARGETS torch_global_deps DESTINATION "${TORCH_INSTALL_LIB_DIR}") endif() diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index e15b55cd16ed..fb89afd9fff7 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -134,6 +134,35 @@ else() "Cannot find threading library. PyTorch requires Threads to compile.") endif() +if(USE_TBB) + if(USE_SYSTEM_TBB) + find_package(TBB 2018.0 REQUIRED CONFIG COMPONENTS tbb) + + get_target_property(TBB_INCLUDE_DIR TBB::tbb INTERFACE_INCLUDE_DIRECTORIES) + else() + message(STATUS "Compiling TBB from source") + # Unset our restrictive C++ flags here and reset them later. + # Remove this once we use proper target_compile_options. + set(OLD_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS}) + set(CMAKE_CXX_FLAGS) + + set(TBB_ROOT_DIR "${PROJECT_SOURCE_DIR}/third_party/tbb") + set(TBB_BUILD_STATIC OFF CACHE BOOL " " FORCE) + set(TBB_BUILD_SHARED ON CACHE BOOL " " FORCE) + set(TBB_BUILD_TBBMALLOC OFF CACHE BOOL " " FORCE) + set(TBB_BUILD_TBBMALLOC_PROXY OFF CACHE BOOL " " FORCE) + set(TBB_BUILD_TESTS OFF CACHE BOOL " " FORCE) + add_subdirectory(${PROJECT_SOURCE_DIR}/aten/src/ATen/cpu/tbb) + set_property(TARGET tbb tbb_def_files PROPERTY FOLDER "dependencies") + + set(CMAKE_CXX_FLAGS ${OLD_CMAKE_CXX_FLAGS}) + + set(TBB_INCLUDE_DIR "${TBB_ROOT_DIR}/include") + + add_library(TBB::tbb ALIAS tbb) + endif() +endif() + # ---[ protobuf if(CAFFE2_CMAKE_BUILDING_WITH_MAIN_REPO) if(USE_LITE_PROTO) diff --git a/cmake/Modules/FindMKL.cmake b/cmake/Modules/FindMKL.cmake index daaa5dd24f00..01de7c7cec15 100644 --- a/cmake/Modules/FindMKL.cmake +++ b/cmake/Modules/FindMKL.cmake @@ -71,8 +71,8 @@ IF (NOT "${MKL_THREADING}" STREQUAL "SEQ" AND MESSAGE(FATAL_ERROR "Invalid MKL_THREADING (${MKL_THREADING}), should be one of: SEQ, TBB, OMP") ENDIF() -IF ("${MKL_THREADING}" STREQUAL "TBB" AND NOT TARGET TBB::tbb) - MESSAGE(FATAL_ERROR "MKL_THREADING is TBB but TBB is not found") +IF ("${MKL_THREADING}" STREQUAL "TBB" AND NOT USE_TBB) + MESSAGE(FATAL_ERROR "MKL_THREADING is TBB but USE_TBB is turned off") ENDIF() MESSAGE(STATUS "MKL_THREADING = ${MKL_THREADING}") diff --git a/cmake/Modules/FindMKLDNN.cmake b/cmake/Modules/FindMKLDNN.cmake index 9e002c939e5b..f6a19812c83d 100644 --- a/cmake/Modules/FindMKLDNN.cmake +++ b/cmake/Modules/FindMKLDNN.cmake @@ -101,7 +101,7 @@ IF(NOT MKLDNN_FOUND) IF(NOT MKLDNN_CPU_RUNTIME) SET(MKLDNN_CPU_RUNTIME "OMP" CACHE STRING "") ELSEIF(MKLDNN_CPU_RUNTIME STREQUAL "TBB") - IF(TARGET TBB::tbb) + IF(USE_TBB) MESSAGE(STATUS "MKL-DNN is using TBB") SET(TBB_cmake_included TRUE) diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake index 99b6521328d6..329bdd19a6cd 100644 --- a/cmake/Summary.cmake +++ b/cmake/Summary.cmake @@ -151,6 +151,10 @@ function(caffe2_print_configuration_summary) message(STATUS " USE_OBSERVERS : ${USE_OBSERVERS}") message(STATUS " USE_OPENCL : ${USE_OPENCL}") message(STATUS " USE_OPENMP : ${USE_OPENMP}") + message(STATUS " USE_TBB : ${USE_TBB}") + if(${USE_TBB}) + message(STATUS " USE_SYSTEM_TBB : ${USE_SYSTEM_TBB}") + endif() message(STATUS " USE_MIMALLOC : ${USE_MIMALLOC}") message(STATUS " USE_VULKAN : ${USE_VULKAN}") if(${USE_VULKAN}) diff --git a/cmake/public/utils.cmake b/cmake/public/utils.cmake index 0f5da8e6cae2..597c22c75dd0 100644 --- a/cmake/public/utils.cmake +++ b/cmake/public/utils.cmake @@ -317,6 +317,9 @@ function(caffe2_binary_target target_name_or_src) if(DEFINED Caffe2_MODULES) target_link_libraries(${__target} ${Caffe2_MODULES}) endif() + if(USE_TBB AND NOT USE_SYSTEM_TBB) + target_include_directories(${__target} PUBLIC ${TBB_INCLUDE_DIR}) + endif() install(TARGETS ${__target} DESTINATION bin) endfunction() diff --git a/defs.bzl b/defs.bzl index 6ea4b1219325..83aa9383d7c4 100644 --- a/defs.bzl +++ b/defs.bzl @@ -64,6 +64,8 @@ def get_cpu_parallel_backend_flags(): defs = [] if parallel_backend == "openmp": defs.append("-DAT_PARALLEL_OPENMP_FBCODE=1") + elif parallel_backend == "tbb": + defs.append("-DAT_PARALLEL_NATIVE_TBB_FBCODE=1") elif parallel_backend == "native": defs.append("-DAT_PARALLEL_NATIVE_FBCODE=1") else: diff --git a/setup.py b/setup.py index e2529335bcc6..9826207de73b 100644 --- a/setup.py +++ b/setup.py @@ -179,6 +179,13 @@ # possible values: # OMP - use OpenMP for intra-op and native backend for inter-op tasks # NATIVE - use native thread pool for both intra- and inter-op tasks +# TBB - using TBB for intra- and native thread pool for inter-op parallelism +# +# USE_TBB +# enable TBB support +# +# USE_SYSTEM_TBB +# Use system-provided Intel TBB. # # USE_SYSTEM_LIBS (work in progress) # Use system-provided libraries to satisfy the build dependencies. @@ -364,6 +371,7 @@ def get_submodule_folders(): for name in [ "gloo", "cpuinfo", + "tbb", "onnx", "foxi", "QNNPACK", diff --git a/third_party/mkl-dnn.BUILD b/third_party/mkl-dnn.BUILD index bb7f34107892..9a688a52b1cf 100644 --- a/third_party/mkl-dnn.BUILD +++ b/third_party/mkl-dnn.BUILD @@ -130,7 +130,10 @@ cc_library( ], deps = [ "@mkl", - ], + ] + select({ + "@pytorch//tools/config:thread_sanitizer": [], + "//conditions:default": ["@tbb"], + }), defines = [ "DNNL_ENABLE_MAX_CPU_ISA", "DNNL_ENABLE_CONCURRENT_EXEC", diff --git a/third_party/mkl.BUILD b/third_party/mkl.BUILD index b7abb0e035ad..c3115f164a66 100644 --- a/third_party/mkl.BUILD +++ b/third_party/mkl.BUILD @@ -12,7 +12,10 @@ cc_library( "libmkl_vml_avx2.so", "libmkl_vml_avx512.so", "libmkl_vml_def.so", - ], + ] + select({ + "@pytorch//tools/config:thread_sanitizer": [], + "//conditions:default": ["libmkl_tbb_thread.so"], + }), visibility = ["//visibility:public"], deps = ["@mkl_headers"], ) diff --git a/third_party/tbb b/third_party/tbb new file mode 160000 index 000000000000..a51a90bc609b --- /dev/null +++ b/third_party/tbb @@ -0,0 +1 @@ +Subproject commit a51a90bc609bb73db8ea13841b5cf7aa4344d4a9 diff --git a/third_party/tbb.BUILD b/third_party/tbb.BUILD new file mode 100644 index 000000000000..b11e65847331 --- /dev/null +++ b/third_party/tbb.BUILD @@ -0,0 +1,75 @@ +load("@rules_cc//cc:defs.bzl", "cc_library") +load("@pytorch//third_party:substitution.bzl", "template_rule") + +licenses(["notice"]) # Apache 2.0 + +template_rule( + name = "version_string", + src = "@//:aten/src/ATen/cpu/tbb/extra/version_string.ver.in", + out = "version_string.h", + substitutions = { + "@CMAKE_SYSTEM_NAME@": "Unknown", + "@CMAKE_SYSTEM@": "Unknown", + "@CMAKE_SYSTEM_VERSION@": "Unknown", + "@CMAKE_CXX_COMPILER_ID@": "Unknown", + "@_configure_date@": "Unknown", + } +) + +cc_library( + name = "tbb", + srcs = [":version_string"] + glob( + [ + "src/old/*.h", + "src/rml/client/*.h", + "src/rml/include/*.h", + "src/rml/server/*.h", + "src/tbb/*.h", + "src/tbb/tools_api/*.h", + "src/tbb/tools_api/legacy/*.h", + "src/old/*.cpp", + "src/tbb/*.cpp", + ], + exclude = ["src/old/test_*.cpp"], + ) + ["src/rml/client/rml_tbb.cpp"], + hdrs = glob( + [ + "include/tbb/*", + "include/tbb/compat/*", + "include/tbb/internal/*", + "include/tbb/machine/*", + ], + exclude = ["include/tbb/scalable_allocator.h"], + ), + copts = [ + "-Iexternal/tbb/src/rml/include", + "-Iexternal/tbb/src", + "-pthread", + "-DDO_ITT_NOTIFY=1", + "-DUSE_PTHREAD=1", + "-D__TBB_BUILD=1", + "-D__TBB_DYNAMIC_LOAD_ENABLED=0", + "-D__TBB_SOURCE_DIRECTLY_INCLUDED=1", + "-fno-sanitize=vptr", + "-fno-sanitize=thread", + ], + defines = [ + # TBB Cannot detect the standard library version when using clang with libstdc++. + # See https://github.com/01org/tbb/issues/22 + "TBB_USE_GLIBCXX_VERSION=(_GLIBCXX_RELEASE*10000)", + "TBB_PREVIEW_GLOBAL_CONTROL=1", + "TBB_PREVIEW_LOCAL_OBSERVER=1", + "__TBB_ALLOW_MUTABLE_FUNCTORS=1", + ], + includes = [ + "include", + "src/tbb/tools_api", + ], + linkopts = [ + "-ldl", + "-lpthread", + "-lrt", + ], + textual_hdrs = ["src/tbb/tools_api/ittnotify_static.c"], + visibility = ["//visibility:public"], +) diff --git a/third_party/tbb.patch b/third_party/tbb.patch new file mode 100644 index 000000000000..4a1f6845b774 --- /dev/null +++ b/third_party/tbb.patch @@ -0,0 +1,34 @@ +diff --git a/src/rml/server/rml_server.cpp b/src/rml/server/rml_server.cpp +index 2508465..1e22ad2 100644 +--- a/src/rml/server/rml_server.cpp ++++ b/src/rml/server/rml_server.cpp +@@ -3279,10 +3279,10 @@ extern "C" void __KMP_call_with_my_server_info( ::rml::server_info_callback_t cb + /* + * RML server info + */ +-#include "version_string.ver" ++#include "version_string.h" + + #ifndef __TBB_VERSION_STRINGS +-#pragma message("Warning: version_string.ver isn't generated properly by version_info.sh script!") ++#pragma message("Warning: version_string.h isn't generated properly by version_info.sh script!") + #endif + + // We use the build time as the RML server info. TBB is required to build RML, so we make it the same as the TBB build time. +diff --git a/src/tbb/tbb_version.h b/src/tbb/tbb_version.h +index dcaa55b..4981a8a 100644 +--- a/src/tbb/tbb_version.h ++++ b/src/tbb/tbb_version.h +@@ -25,10 +25,10 @@ + #ifndef ENDL + #define ENDL "\n" + #endif +-#include "version_string.ver" ++#include "version_string.h" + + #ifndef __TBB_VERSION_STRINGS +-#pragma message("Warning: version_string.ver isn't generated properly by version_info.sh script!") ++#pragma message("Warning: version_string.h isn't generated properly by version_info.sh script!") + // here is an example of macros value: + #define __TBB_VERSION_STRINGS \ + "TBB: BUILD_HOST\tUnknown\n" \ diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py index 5e7e3739695d..ffd0e6f95a87 100644 --- a/torch/testing/_internal/common_modules.py +++ b/torch/testing/_internal/common_modules.py @@ -24,7 +24,7 @@ marginrankingloss_reference, multimarginloss_reference, multilabelmarginloss_reference, nllloss_reference, nlllossNd_reference, smoothl1loss_reference, softmarginloss_reference, get_reduction) from torch.testing._internal.common_utils import ( - freeze_rng_state, skipIfMps, GRADCHECK_NONDET_TOL, TEST_WITH_ROCM, IS_WINDOWS, + freeze_rng_state, set_single_threaded_if_parallel_tbb, skipIfMps, GRADCHECK_NONDET_TOL, TEST_WITH_ROCM, IS_WINDOWS, skipIfTorchDynamo) from types import ModuleType from typing import List, Tuple, Type, Set, Dict @@ -235,7 +235,7 @@ def __init__(self, self.is_lazy = issubclass(module_cls, torch.nn.modules.lazy.LazyModuleMixin) def get_decorators(self, test_class, test_name, device, dtype, param_kwargs): - result = [] + result = [set_single_threaded_if_parallel_tbb] for decorator in self.decorators: if isinstance(decorator, DecorateInfo): if decorator.is_active(test_class, test_name, device, dtype, param_kwargs): diff --git a/torch/testing/_internal/common_optimizers.py b/torch/testing/_internal/common_optimizers.py index c7122c8666d4..0f179e0c23fb 100644 --- a/torch/testing/_internal/common_optimizers.py +++ b/torch/testing/_internal/common_optimizers.py @@ -39,6 +39,7 @@ from torch.testing._internal.common_methods_invocations import DecorateInfo from torch.testing._internal.common_utils import ( _TestParametrizer, + set_single_threaded_if_parallel_tbb, skipIfMps, skipIfTorchDynamo, TEST_WITH_TORCHDYNAMO, @@ -160,7 +161,7 @@ def __init__( self.supports_fused_on = supports_fused_on def get_decorators(self, test_class, test_name, device, dtype, param_kwargs): - result = [] + result = [set_single_threaded_if_parallel_tbb] for decorator in self.decorators: if isinstance(decorator, DecorateInfo): if decorator.is_active( diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 2237ec67c500..475ab977cdb4 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -1497,6 +1497,8 @@ def wrapper(*args, **kwargs): # See: https://github.com/pytorch/pytorch/pull/59402#issuecomment-858811135 TestEnvironment.def_flag("TEST_CUDA_MEM_LEAK_CHECK", env_var="PYTORCH_TEST_CUDA_MEM_LEAK_CHECK") +# True if CI is running TBB-enabled Pytorch +IS_TBB = "tbb" in os.getenv("BUILD_ENVIRONMENT", "") # Dict of NumPy dtype -> torch dtype (when the correspondence exists) numpy_to_torch_dtype_dict = { @@ -1873,6 +1875,19 @@ def wrapper(*args, **kwargs): fn(*args, **kwargs) return wrapper + +def skipIfTBB(message="This test makes TBB sad"): + def dec_fn(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + if IS_TBB: + raise unittest.SkipTest(message) + else: + fn(*args, **kwargs) + return wrapper + return dec_fn + + def skip_if_pytest(fn): @wraps(fn) def wrapped(*args, **kwargs): @@ -4708,6 +4723,24 @@ def dtype_name(dtype): } +def set_single_threaded_if_parallel_tbb(fn): + """Set test to be single threaded for parallel tbb. + + See https://github.com/pytorch/pytorch/issues/64571#issuecomment-914691883 + """ + if not IS_TBB: + return fn + + @wraps(fn) + def wrap_fn(*args, **kwargs): + num_threads = torch.get_num_threads() + torch.set_num_threads(1) + try: + return fn(*args, **kwargs) + finally: + torch.set_num_threads(num_threads) + return wrap_fn + @functools.lru_cache def get_cycles_per_ms() -> float: diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py index 913947ea84c7..b625d2dbd40f 100644 --- a/torch/utils/cpp_extension.py +++ b/torch/utils/cpp_extension.py @@ -1880,6 +1880,9 @@ def _prepare_ldflags(extra_ldflags, with_cuda, verbose, is_standalone): if not is_standalone: extra_ldflags.append('-ltorch_python') + if is_standalone and "TBB" in torch.__config__.parallel_info(): + extra_ldflags.append('-ltbb') + if is_standalone: extra_ldflags.append(f"-Wl,-rpath,{TORCH_LIB_PATH}") From da39461d615ab7d4867fb31ff00d186fb7bc2954 Mon Sep 17 00:00:00 2001 From: SandishKumarHN Date: Thu, 30 May 2024 01:47:38 +0000 Subject: [PATCH 074/706] [optim] Move test_grad_scaling_autocast_fused_optimizers to test_cuda.py (#126418) this PR address the comments in this PR #124904 - Move test_grad_scaling_autocast_fused_optimizers to test_cuda.py - Combine _grad_scaling_autocast_fused_optimizers into test_grad_scaling_autocast_fused_optimizers - Move to OptimizerInfo framework. - For failing tests test_grad_scaling_autocast_fused_optimizers AdamW_cuda_float32, Adam_cuda_float32 - Added toleranceOverride in this PR - created a issue #127000 ``` > (c2env) [sandish@devgpu166.ash6 ~/pytorch (refactoroptimizers)]$ python test/test_cuda.py -k test_grad_scaling_autocast_fused_optimizers -v /home/sandish/pytorch/torch/backends/cudnn/__init__.py:106: UserWarning: PyTorch was compiled without cuDNN/MIOpen support. To use cuDNN/MIOpen, rebuild PyTorch making sure the library is visible to the build system. warnings.warn( /home/sandish/pytorch/torch/backends/cudnn/__init__.py:106: UserWarning: PyTorch was compiled without cuDNN/MIOpen support. To use cuDNN/MIOpen, rebuild PyTorch making sure the library is visible to the build system. warnings.warn( test_grad_scaling_autocast_fused_optimizers_Adagrad_cpu_float32 (__main__.TestCudaOptimsCPU) ... {'fused': True} {'fused': True} {'weight_decay': 0.1, 'fused': True} {'weight_decay': 0.1, 'fused': True} {'weight_decay': 0.1, 'maximize': True, 'fused': True} {'weight_decay': 0.1, 'maximize': True, 'fused': True} {'lr': 0.1, 'fused': True} {'lr': 0.1, 'fused': True} {'initial_accumulator_value': 0.1, 'weight_decay': 0.1, 'fused': True} {'initial_accumulator_value': 0.1, 'weight_decay': 0.1, 'fused': True} {'lr': 0.1, 'lr_decay': 0.5, 'weight_decay': 0.1, 'fused': True} {'lr': 0.1, 'lr_decay': 0.5, 'weight_decay': 0.1, 'fused': True} {'lr': tensor(0.0010), 'fused': True} {'lr': tensor(0.0010), 'fused': True} ok test_grad_scaling_autocast_fused_optimizers_AdamW_cpu_float32 (__main__.TestCudaOptimsCPU) ... {'fused': True} {'fused': True} {'lr': 0.01, 'fused': True} {'lr': 0.01, 'fused': True} {'weight_decay': 0.1, 'fused': True} {'weight_decay': 0.1, 'fused': True} {'weight_decay': 0.1, 'maximize': True, 'fused': True} {'weight_decay': 0.1, 'maximize': True, 'fused': True} {'weight_decay': 0.1, 'amsgrad': True, 'fused': True} {'weight_decay': 0.1, 'amsgrad': True, 'fused': True} ok test_grad_scaling_autocast_fused_optimizers_Adam_cpu_float32 (__main__.TestCudaOptimsCPU) ... {'fused': True} {'fused': True} {'lr': 0.01, 'fused': True} {'lr': 0.01, 'fused': True} {'weight_decay': 0.1, 'fused': True} {'weight_decay': 0.1, 'fused': True} {'weight_decay': 0.1, 'maximize': True, 'fused': True} {'weight_decay': 0.1, 'maximize': True, 'fused': True} {'weight_decay': 0.1, 'amsgrad': True, 'fused': True} {'weight_decay': 0.1, 'amsgrad': True, 'fused': True} ok test_grad_scaling_autocast_fused_optimizers_SGD_cpu_float32 (__main__.TestCudaOptimsCPU) ... {'fused': True} {'fused': True} {'lr': 0.01, 'fused': True} {'lr': 0.01, 'fused': True} {'lr': tensor(0.0010), 'fused': True} {'lr': tensor(0.0010), 'fused': True} {'momentum': 0.9, 'fused': True} {'momentum': 0.9, 'fused': True} {'momentum': 0.9, 'dampening': 0.5, 'fused': True} {'momentum': 0.9, 'dampening': 0.5, 'fused': True} {'momentum': 0.9, 'weight_decay': 0.1, 'fused': True} {'momentum': 0.9, 'weight_decay': 0.1, 'fused': True} {'momentum': 0.9, 'nesterov': True, 'weight_decay': 0.1, 'fused': True} {'momentum': 0.9, 'nesterov': True, 'weight_decay': 0.1, 'fused': True} {'weight_decay': 0.1, 'maximize': True, 'fused': True} {'weight_decay': 0.1, 'maximize': True, 'fused': True} ok test_grad_scaling_autocast_fused_optimizers_Adagrad_cuda_float32 (__main__.TestCudaOptimsCUDA) ... skipped 'cuda is not supported for fused on Adagrad' test_grad_scaling_autocast_fused_optimizers_AdamW_cuda_float32 (__main__.TestCudaOptimsCUDA) ... {'fused': True} {'fused': True} {'lr': 0.01, 'fused': True} {'lr': 0.01, 'fused': True} {'weight_decay': 0.1, 'fused': True} {'weight_decay': 0.1, 'fused': True} {'weight_decay': 0.1, 'maximize': True, 'fused': True} {'weight_decay': 0.1, 'maximize': True, 'fused': True} {'weight_decay': 0.1, 'amsgrad': True, 'fused': True} {'weight_decay': 0.1, 'amsgrad': True, 'fused': True} {'capturable': True, 'fused': True} {'capturable': True, 'fused': True} {'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'fused': True} {'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'fused': True} {'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'fused': True} {'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'fused': True} ok test_grad_scaling_autocast_fused_optimizers_Adam_cuda_float32 (__main__.TestCudaOptimsCUDA) ... {'fused': True} {'fused': True} {'lr': 0.01, 'fused': True} {'lr': 0.01, 'fused': True} {'weight_decay': 0.1, 'fused': True} {'weight_decay': 0.1, 'fused': True} {'weight_decay': 0.1, 'maximize': True, 'fused': True} {'weight_decay': 0.1, 'maximize': True, 'fused': True} {'weight_decay': 0.1, 'amsgrad': True, 'fused': True} {'weight_decay': 0.1, 'amsgrad': True, 'fused': True} {'capturable': True, 'fused': True} {'capturable': True, 'fused': True} {'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'fused': True} {'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'fused': True} {'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'fused': True} {'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'fused': True} ok test_grad_scaling_autocast_fused_optimizers_SGD_cuda_float32 (__main__.TestCudaOptimsCUDA) ... {'fused': True} {'fused': True} {'lr': 0.01, 'fused': True} {'lr': 0.01, 'fused': True} {'lr': tensor(0.0010), 'fused': True} {'lr': tensor(0.0010), 'fused': True} {'momentum': 0.9, 'fused': True} {'momentum': 0.9, 'fused': True} {'momentum': 0.9, 'dampening': 0.5, 'fused': True} {'momentum': 0.9, 'dampening': 0.5, 'fused': True} {'momentum': 0.9, 'weight_decay': 0.1, 'fused': True} {'momentum': 0.9, 'weight_decay': 0.1, 'fused': True} {'momentum': 0.9, 'nesterov': True, 'weight_decay': 0.1, 'fused': True} {'momentum': 0.9, 'nesterov': True, 'weight_decay': 0.1, 'fused': True} {'weight_decay': 0.1, 'maximize': True, 'fused': True} {'weight_decay': 0.1, 'maximize': True, 'fused': True} ok ---------------------------------------------------------------------- Ran 8 tests in 16.117s OK (skipped=1) > lintrunner test/test_cuda.py ---------------------------------------------------------------------- ok No lint issues. > lintrunner torch/testing/_internal/common_optimizers.py ---------------------------------------------------------------------- ok No lint issues. ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/126418 Approved by: https://github.com/janeyx99 --- test/test_cuda.py | 83 ++++++++++++++++- test/test_optim.py | 96 +------------------- torch/testing/_internal/common_optimizers.py | 25 +++++ 3 files changed, 108 insertions(+), 96 deletions(-) diff --git a/test/test_cuda.py b/test/test_cuda.py index 6ce7555519d7..785f0499df05 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -29,6 +29,7 @@ ) from torch.testing._internal.autocast_test_lists import AutocastTestLists from torch.testing._internal.common_cuda import ( + _create_scaling_case, _get_torch_cuda_version, TEST_CUDNN, TEST_MULTIGPU, @@ -36,8 +37,9 @@ from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, onlyCUDA, + onlyNativeDeviceTypes, ) -from torch.testing._internal.common_optimizers import optim_db, optims +from torch.testing._internal.common_optimizers import optim_db, optims, TensorTracker from torch.testing._internal.common_utils import ( freeze_rng_state, gcIfJetson, @@ -4741,6 +4743,85 @@ class TestCudaOptims(TestCase): # These tests will be instantiate with instantiate_device_type_tests # to apply the new OptimizerInfo structure. + @onlyNativeDeviceTypes + @optims( + [optim for optim in optim_db if "fused" in optim.supported_impls], + dtypes=[torch.float32], + ) + def test_grad_scaling_autocast_fused_optimizers(self, device, dtype, optim_info): + device = device.split(":")[0] + if device not in optim_info.supports_fused_on: + self.skipTest( + f"{device} is not supported for fused on {optim_info.optim_cls.__name__}" + ) + optim_inputs = optim_info.optim_inputs_func(device=device) + optim_cls = optim_info.optim_cls + for optim_input in optim_inputs: + for _separate_unscale in (True, False): + kwargs = optim_input.kwargs + kwargs["fused"] = True + torch.manual_seed(20) + ( + mod_control, + mod_scaling, + opt_control, + opt_scaling, + data, + loss_fn, + _, + ) = _create_scaling_case( + optimizer_ctor=optim_cls, optimizer_kwargs=kwargs, device=device + ) + optimizer_kwargs = deepcopy(kwargs) + optimizer_kwargs["fused"] = False + if "lr" not in kwargs: + # _create_scaling_case will set lr = 1.0 if optimizer_kwargs do not set lr + optimizer_kwargs["lr"] = 1.0 + opt_control = optim_cls(mod_control.parameters(), **optimizer_kwargs) + scaler_scaling = torch.amp.GradScaler(device, init_scale=128.0) + scaler_control = torch.amp.GradScaler(device, init_scale=128.0) + tracker = TensorTracker() + for input, target in data: + opt_control.zero_grad() + with torch.autocast(device_type=device, dtype=torch.half): + output_control = mod_control(input) + loss_control = loss_fn(output_control, target) + scaler_control.scale(loss_control).backward() + scaler_control.step(opt_control) + scaler_control.update() + + opt_scaling.zero_grad() + with torch.autocast(device_type=device, dtype=torch.half): + output_scaling = mod_scaling(input) + loss_scaling = loss_fn(output_scaling, target) + scaler_scaling.scale(loss_scaling).backward() + if _separate_unscale: + scaler_scaling.unscale_(opt_scaling) + scaler_scaling.step(opt_scaling) + scaler_scaling.update() + + tracker.add(loss_control) + tracker.pop_check_set(loss_scaling, self) + for param_control, param_scaling in zip( + mod_control.parameters(), mod_scaling.parameters() + ): + tracker.add(param_control.grad) + tracker.pop_check_set(param_scaling.grad, self) + tracker.add(param_control) + tracker.pop_check_set(param_scaling, self) + + state_control, state_scaling = ( + opt_control.state[param_control], + opt_scaling.state[param_scaling], + ) + + for k in state_control: + actual = state_scaling[k] + if k == "step": + actual = actual.squeeze() + tracker.add(state_control[k]) + tracker.pop_check_set(actual, self) + @onlyCUDA @unittest.skipIf( not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" diff --git a/test/test_optim.py b/test/test_optim.py index 9e3ee50ff302..3ab57fecd833 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -20,7 +20,7 @@ register_optimizer_step_post_hook, register_optimizer_step_pre_hook, ) -from torch.testing._internal.common_cuda import _create_scaling_case, TEST_MULTIGPU +from torch.testing._internal.common_cuda import TEST_MULTIGPU from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, largeTensorTest, @@ -1954,100 +1954,6 @@ def test_fused_cpu_matches_cuda(self, device, dtype, optim_info): optimizers.append(optimizer) self._compare_between(inpts, models, optimizers) - @onlyNativeDeviceTypes - @optims( - [optim for optim in optim_db if "fused" in optim.supported_impls], - dtypes=[torch.float32], - ) - def test_grad_scaling_autocast_fused_optimizers(self, device, dtype, optim_info): - # This ut is from test_cuda.py test_grad_scaling_autocast_fused_optimizers - # but only test Adam/AdamW on CPU - # TODO: haozhe, support SGD and unified this ut with the CUDA only one - if device not in optim_info.supports_fused_on: - self.skipTest( - f"{device} is not supported for fused on {optim_info.optim_cls.__name__}" - ) - optim_inputs = optim_info.optim_inputs_func(device=device) - optim_cls = optim_info.optim_cls - for optim_input in optim_inputs: - kwargs = optim_input.kwargs - kwargs["fused"] = True - for _separate_unscale in (True, False): - self._grad_scaling_autocast_fused_optimizers( - device=device, - optimizer_ctor=optim_cls, - optimizer_kwargs=kwargs, - separate_unscale=_separate_unscale, - ) - - def _grad_scaling_autocast_fused_optimizers( - self, device, optimizer_ctor, optimizer_kwargs, separate_unscale - ): - torch.manual_seed(20) - ( - mod_control, - mod_scaling, - opt_control, - opt_scaling, - data, - loss_fn, - _, - ) = _create_scaling_case( - optimizer_ctor=optimizer_ctor, - optimizer_kwargs=optimizer_kwargs, - device="cpu", - ) - kwargs = deepcopy(optimizer_kwargs) - kwargs["fused"] = False - if "lr" not in optimizer_kwargs: - # _create_scaling_case will set lr = 1.0 if optimizer_kwargs do not set lr - kwargs["lr"] = 1.0 - opt_control = optimizer_ctor(mod_control.parameters(), **kwargs) - - scaler_scaling = torch.amp.GradScaler(device, init_scale=128.0) - scaler_control = torch.amp.GradScaler(device, init_scale=128.0) - tracker = TensorTracker() - for input, target in data: - opt_control.zero_grad() - with torch.autocast(device_type=device, dtype=torch.half): - output_control = mod_control(input) - loss_control = loss_fn(output_control, target) - scaler_control.scale(loss_control).backward() - scaler_control.step(opt_control) - scaler_control.update() - - opt_scaling.zero_grad() - with torch.autocast(device_type=device, dtype=torch.half): - output_scaling = mod_scaling(input) - loss_scaling = loss_fn(output_scaling, target) - scaler_scaling.scale(loss_scaling).backward() - if separate_unscale: - scaler_scaling.unscale_(opt_scaling) - scaler_scaling.step(opt_scaling) - scaler_scaling.update() - - tracker.add(loss_control) - tracker.pop_check_set(loss_scaling, self) - for param_control, param_scaling in zip( - mod_control.parameters(), mod_scaling.parameters() - ): - tracker.add(param_control.grad) - tracker.pop_check_set(param_scaling.grad, self) - tracker.add(param_control) - tracker.pop_check_set(param_scaling, self) - - state_control, state_scaling = ( - opt_control.state[param_control], - opt_scaling.state[param_scaling], - ) - - for k in state_control: - actual = state_scaling[k] - if k == "step": - actual = actual.squeeze() - tracker.add(state_control[k]) - tracker.pop_check_set(actual, self) - @onlyCUDA @optims( [o for o in optim_db if "foreach" in o.supported_impls], dtypes=[torch.float32] diff --git a/torch/testing/_internal/common_optimizers.py b/torch/testing/_internal/common_optimizers.py index 0f179e0c23fb..bca785dd9543 100644 --- a/torch/testing/_internal/common_optimizers.py +++ b/torch/testing/_internal/common_optimizers.py @@ -1254,6 +1254,17 @@ def _get_optim_inputs_including_global_cliquey_kwargs( "TestOptimRenewed", "test_fused_matches_forloop", ), + DecorateInfo( + # Note on tolerances: + # Tracking through #127000 + toleranceOverride( + { + torch.float32: tol(atol=3e-5, rtol=1.3e-06), + } + ), + "TestCudaOptims", + "test_grad_scaling_autocast_fused_optimizers", + ), ), skips=( DecorateInfo( @@ -1370,6 +1381,20 @@ def _get_optim_inputs_including_global_cliquey_kwargs( "TestOptimRenewed", "test_fused_matches_forloop", ), + # Note on tolerances: + # Tracking through #127000 + DecorateInfo( + toleranceOverride( + { + torch.float32: tol( + atol=3e-5, + rtol=1.3e-06, + ) + } + ), + "TestCudaOptims", + "test_grad_scaling_autocast_fused_optimizers", + ), ), skips=( DecorateInfo( From e1c322112a3d7b128b42e27f68bc9a714bfd9a09 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Tue, 28 May 2024 12:54:01 -0700 Subject: [PATCH 075/706] [compiled autograd] torch.compile API (#125880) - enter existing compiled autograd ctx manager before entering torch.compile frames Pull Request resolved: https://github.com/pytorch/pytorch/pull/125880 Approved by: https://github.com/jansel --- test/inductor/test_compiled_autograd.py | 170 +++++++++++++++++++++++- torch/_dynamo/config.py | 4 + torch/_dynamo/eval_frame.py | 32 ++++- 3 files changed, 203 insertions(+), 3 deletions(-) diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index 87299d796f6c..2daacc308071 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -11,7 +11,7 @@ import torch import torch.nn as nn from torch import _inductor as inductor -from torch._dynamo import compiled_autograd +from torch._dynamo import compiled_autograd, config from torch._dynamo.utils import counters from torch._inductor.test_case import run_tests, TestCase from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA @@ -54,10 +54,14 @@ def hook3(gI, gO): class TestCompiledAutograd(TestCase): def setUp(self) -> None: super().setUp() + torch._logging.set_logs(compiled_autograd_verbose=False) + config.compiled_autograd = False compiled_autograd.reset() def tearDown(self) -> None: super().tearDown() + torch._logging.set_logs(compiled_autograd_verbose=False) + config.compiled_autograd = False compiled_autograd.reset() def check_output_and_recompiles( @@ -230,6 +234,170 @@ def fn(): self.check_output_and_recompiles(fn) + def test_torch_compile_api_inductor(self): + def fn(): + torch.manual_seed(123) + model = torch.nn.Sequential( + torch.nn.Linear(4, 4), + torch.nn.Sigmoid(), + ) + + res = [] + for _ in range(3): + x = torch.randn([1, 4]) + + result = model(x).sum() + result.backward() + res.append(model[0].weight.grad) + res.append(model[0].bias.grad) + model.zero_grad() + return res + + expected = fn() + with config.patch(compiled_autograd=True): + compiled_fn = torch.compile(fn) + actual = compiled_fn() + self.assertEqual(expected, actual) + self.assertEqual(counters["compiled_autograd"]["captures"], 1) + + def test_torch_compile_api_aot_eager(self): + def fn(): + torch.manual_seed(123) + model = torch.nn.Sequential( + torch.nn.Linear(4, 4), + torch.nn.Sigmoid(), + ) + + res = [] + for _ in range(3): + x = torch.randn([1, 4]) + + result = model(x).sum() + result.backward() + res.append(model[0].weight.grad) + res.append(model[0].bias.grad) + model.zero_grad() + return res + + expected = fn() + with config.patch(compiled_autograd=True): + compiled_fn = torch.compile(fn, backend="aot_eager") + actual = compiled_fn() + self.assertEqual(expected, actual) + self.assertEqual(counters["compiled_autograd"]["captures"], 1) + + def test_torch_compile_api_eager(self): + def fn(): + torch.manual_seed(123) + model = torch.nn.Sequential( + torch.nn.Linear(4, 4), + torch.nn.Sigmoid(), + ) + + res = [] + for _ in range(3): + x = torch.randn([1, 4]) + + result = model(x).sum() + result.backward() + res.append(model[0].weight.grad) + res.append(model[0].bias.grad) + model.zero_grad() + return res + + expected = fn() + with config.patch(compiled_autograd=True): + compiled_fn = torch.compile(fn, backend="eager") + actual = compiled_fn() + self.assertEqual(expected, actual) + self.assertEqual(counters["compiled_autograd"]["captures"], 1) + + def test_multiple_torch_compile(self): + model = torch.nn.Sequential( + torch.nn.Linear(4, 4), + torch.nn.Sigmoid(), + ) + x = torch.randn([1, 4]) + + def fn(): + result = model(x).sum() + result.backward() + + model2 = torch.nn.Linear(4, 4) + x2 = torch.randn([1, 4]) + + def fn2(): + result = model2(x2).sum() + result.backward() + + no_ca1 = torch.compile(fn) + no_ca1() + self.assertEqual(counters["compiled_autograd"]["captures"], 0) + counters.clear() + + with config.patch(compiled_autograd=True): + with_ca = torch.compile(fn2) + with_ca() + self.assertEqual(counters["compiled_autograd"]["captures"], 1) + counters.clear() + + no_ca2 = torch.compile(fn) + no_ca2() + self.assertEqual(counters["compiled_autograd"]["captures"], 0) + + def test_torch_compile_graph_break(self): + model = torch.nn.Sequential( + torch.nn.Linear(4, 4), + torch.nn.Sigmoid(), + ) + x = torch.randn([1, 4]) + + @torch._dynamo.disable() + def fn(): + result = model(x).sum() + result.backward() + + with config.patch(compiled_autograd=True): + opt_fn = torch.compile(fn) + opt_fn() + + self.assertEqual(counters["compiled_autograd"]["captures"], 1) + + def test_torch_compile_graph_break2(self): + model = torch.nn.Sequential( + torch.nn.Linear(4, 4), + torch.nn.Sigmoid(), + ) + x = torch.randn([1, 4]) + + @torch._dynamo.disable() + def inner_fn(loss): + loss.backward() + + def fn(): + result = model(x).sum() + inner_fn(result) + + with config.patch(compiled_autograd=True): + opt_fn = torch.compile(fn) + opt_fn() + + self.assertEqual(counters["compiled_autograd"]["captures"], 1) + + def test_torch_compile_only_backward_call(self): + model = torch.nn.Sequential( + torch.nn.Linear(4, 4), + torch.nn.Sigmoid(), + ) + x = torch.randn([1, 4]) + + result = model(x).sum() + with config.patch(compiled_autograd=True): + opt_bwd = torch.compile(lambda: result.backward()) + opt_bwd() + + self.assertEqual(counters["compiled_autograd"]["captures"], 1) + def test_dynamo_boxed(self): def get_placeholders(gm_): placeholders = [] diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index 6f4219a03b18..212021859c46 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -454,6 +454,10 @@ def default_debug_dir_root(): # WARNING: this is an experimental flag and is subject to change. _experimental_support_context_fn_in_torch_utils_checkpoint = False +# Enables the Compiled Autograd engine to trace .backward() calls made under torch.compile(). +# Note: AOT Autograd will still trace joint graphs. +compiled_autograd = False + if TYPE_CHECKING: from torch.utils._config_typing import * # noqa: F401, F403 diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index fa9311f2c18a..626b206cfe48 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -493,6 +493,9 @@ def __init__( export=False, dynamic=None, compiler_config=None, + rebuild_ctx: Optional[ + Callable[[], Union[OptimizeContext, _NullDecorator]] + ] = None, ): def on_enter(): install_generation_tagging_init() @@ -508,6 +511,17 @@ def on_enter(): compiler_config=compiler_config, ) + if config.compiled_autograd: + assert rebuild_ctx is not None + + def call_compiled_autograd(): + compiler_fn = rebuild_ctx() + ctx = torch._dynamo.compiled_autograd.enable(compiler_fn) + ctx.__enter__() + return functools.partial(ctx.__exit__, None, None, None) + + self.enter_exit_hooks.append(call_compiled_autograd) + class RunOnlyContext(_TorchDynamoContext): def __init__(self): @@ -577,6 +591,7 @@ def _optimize_catch_errors( export=False, dynamic=None, compiler_config=None, + rebuild_ctx=None, ): return OptimizeContext( convert_frame.catch_errors_wrapper(compile_fn, hooks), @@ -585,6 +600,7 @@ def _optimize_catch_errors( export=export, dynamic=dynamic, compiler_config=compiler_config, + rebuild_ctx=rebuild_ctx, ) @@ -635,7 +651,15 @@ def is_inductor_supported(): return False -def optimize( +def optimize(*args, **kwargs): + def rebuild_ctx(): + return optimize(*args, **kwargs) + + return _optimize(rebuild_ctx, *args, **kwargs) + + +def _optimize( + rebuild_ctx: Callable[[], Union[OptimizeContext, _NullDecorator]], backend="inductor", *, nopython=False, @@ -643,7 +667,7 @@ def optimize( guard_fail_fn=None, disable=False, dynamic=None, -): +) -> Union[OptimizeContext, _NullDecorator]: """ The main entrypoint of TorchDynamo. Do graph capture and call backend() to optimize extracted graphs. @@ -691,6 +715,7 @@ def toy_example(a, b): backend, dynamic=dynamic, hooks=hooks, + rebuild_ctx=rebuild_ctx, ) # The backend function is stashed in the callable returned by # _optimize_catch_errors in the field _torchdynamo_orig_callable. This can @@ -703,6 +728,7 @@ def toy_example(a, b): compiler_config=backend.get_compiler_config() if hasattr(backend, "get_compiler_config") else None, + rebuild_ctx=rebuild_ctx, ) @@ -1466,6 +1492,7 @@ def optimize_assert( export=False, export_constraints=None, dynamic=None, + rebuild_ctx=None, ): """ The same as `torch._dynamo.optimize(backend, nopython=True)` @@ -1483,6 +1510,7 @@ def optimize_assert( backend_ctx_ctor, export=export, dynamic=dynamic, + rebuild_ctx=rebuild_ctx, ) From 3d541835d509910fceca00fc5a916e9718c391d8 Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Thu, 30 May 2024 02:21:05 +0000 Subject: [PATCH 076/706] distributed debug handlers (#126601) This adds debug handlers as described in: * https://gist.github.com/d4l3k/828b7be585c7615e85b2c448b308d925 (public copy) * https://docs.google.com/document/d/1la68szcS6wUYElUUX-P6zXgkPA8lnfzpagMTPys3aQ8/edit (internal copy) This is only adding the C++ pieces that will be used from the main process. The Python and torchrun pieces will be added in a follow up PR. This adds 2 handlers out of the box: * `/handler/ping` for testing purposes * `/handler/dump_nccl_trace_pickle` as a POC integration with Flight Recorder Test plan: ``` python test/distributed/elastic/test_control_plane.py ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/126601 Approved by: https://github.com/kurman, https://github.com/c-p-i-o --- BUILD.bazel | 1 + WORKSPACE | 6 + build_variables.bzl | 2 + caffe2/CMakeLists.txt | 3 + cmake/Dependencies.cmake | 4 + docs/source/distributed.elastic.rst | 1 + docs/source/elastic/control_plane.rst | 10 + .../distributed/elastic/test_control_plane.py | 86 +++++++++ third_party/cpp-httplib.BUILD | 10 + torch/CMakeLists.txt | 2 + torch/_C/_distributed_c10d.pyi | 4 + .../distributed/c10d/ProcessGroupNCCL.cpp | 8 + .../c10d/control_plane/Handlers.cpp | 75 ++++++++ .../c10d/control_plane/Handlers.hpp | 67 +++++++ .../c10d/control_plane/WorkerServer.cpp | 178 ++++++++++++++++++ .../c10d/control_plane/WorkerServer.hpp | 28 +++ torch/csrc/distributed/c10d/init.cpp | 12 ++ torch/distributed/elastic/control_plane.py | 51 +++++ 18 files changed, 548 insertions(+) create mode 100644 docs/source/elastic/control_plane.rst create mode 100644 test/distributed/elastic/test_control_plane.py create mode 100644 third_party/cpp-httplib.BUILD create mode 100644 torch/csrc/distributed/c10d/control_plane/Handlers.cpp create mode 100644 torch/csrc/distributed/c10d/control_plane/Handlers.hpp create mode 100644 torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp create mode 100644 torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp create mode 100644 torch/distributed/elastic/control_plane.py diff --git a/BUILD.bazel b/BUILD.bazel index 8c1aa2729101..e61290ca2cab 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -780,6 +780,7 @@ cc_library( ":caffe2", ":torch_headers", "@kineto", + "@cpp-httplib", ] + if_cuda([ "@cuda//:nvToolsExt", "@cutlass", diff --git a/WORKSPACE b/WORKSPACE index f7e604332213..9f32aea703dc 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -178,6 +178,12 @@ new_patched_local_repository( path = "third_party/tbb", ) +new_local_repository( + name = "cpp-httplib", + build_file = "//third_party:cpp-httplib.BUILD", + path = "third_party/cpp-httplib", +) + new_local_repository( name = "tensorpipe", build_file = "//third_party:tensorpipe.BUILD", diff --git a/build_variables.bzl b/build_variables.bzl index 5770aca4c9e0..26986506ec8b 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -515,6 +515,8 @@ libtorch_distributed_base_sources = [ "torch/csrc/distributed/c10d/sequence_num.cpp", "torch/csrc/distributed/c10d/socket.cpp", "torch/csrc/distributed/c10d/Work.cpp", + "torch/csrc/distributed/c10d/control_plane/Handlers.cpp", + "torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp", ] # These files are only supported on Linux (and others) but not on Windows. diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 73c6467075ba..54a31185c127 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1244,6 +1244,9 @@ if(USE_KINETO) ${TORCH_ROOT}/third_party/kineto/libkineto/src) endif() +target_include_directories(torch_cpu PRIVATE + ${TORCH_ROOT}/third_party/cpp-httplib) + install(DIRECTORY "${TORCH_SRC_DIR}/csrc" DESTINATION ${TORCH_INSTALL_INCLUDE_DIR}/torch FILES_MATCHING PATTERN "*.h" PATTERN "*.hpp") diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index fb89afd9fff7..e9fd67018da7 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -1728,3 +1728,7 @@ endif() # Include google/FlatBuffers include(${CMAKE_CURRENT_LIST_DIR}/FlatBuffers.cmake) + +# Include cpp-httplib +add_library(httplib INTERFACE IMPORTED) +target_include_directories(httplib SYSTEM INTERFACE ${PROJECT_SOURCE_DIR}/third_party/cpp-httplib) diff --git a/docs/source/distributed.elastic.rst b/docs/source/distributed.elastic.rst index 24d33d1982df..0aabb560c9c8 100644 --- a/docs/source/distributed.elastic.rst +++ b/docs/source/distributed.elastic.rst @@ -29,6 +29,7 @@ Documentation elastic/metrics elastic/events elastic/subprocess_handler + elastic/control_plane .. toctree:: :maxdepth: 1 diff --git a/docs/source/elastic/control_plane.rst b/docs/source/elastic/control_plane.rst new file mode 100644 index 000000000000..c37454cf1b0a --- /dev/null +++ b/docs/source/elastic/control_plane.rst @@ -0,0 +1,10 @@ +Control Plane +============= + +.. automodule:: torch.distributed.elastic.control_plane +.. currentmodule:: torch.distributed.elastic.control_plane + +This module contains optional helpers that add extra debug and control handlers +into your application. + +.. autofunction:: torch.distributed.elastic.control_plane.worker_main diff --git a/test/distributed/elastic/test_control_plane.py b/test/distributed/elastic/test_control_plane.py new file mode 100644 index 000000000000..c9ae512f2718 --- /dev/null +++ b/test/distributed/elastic/test_control_plane.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python3 +# Owner(s): ["oncall: distributed"] + +import json +import os +import pickle +import socket +import tempfile +from contextlib import contextmanager + +from urllib3.connection import HTTPConnection +from urllib3.connectionpool import HTTPConnectionPool + +from torch.distributed.elastic.control_plane import ( + TORCH_WORKER_SERVER_SOCKET, + worker_main, +) +from torch.testing._internal.common_utils import requires_cuda, run_tests, TestCase + + +class UnixHTTPConnection(HTTPConnection): + def __init__(self, socket_path: str) -> None: + super().__init__("localhost") + + self.socket_path = socket_path + + def connect(self) -> None: + self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self.sock.connect(self.socket_path) + + +class UnixHTTPConnectionPool(HTTPConnectionPool): + def __init__(self, socket_path: str) -> None: + super().__init__("localhost") + + self.socket_path = socket_path + + def _new_conn(self): + return UnixHTTPConnection(self.socket_path) + + +@contextmanager +def local_worker_server() -> None: + with tempfile.TemporaryDirectory() as tmpdir: + socket_path = os.path.join(tmpdir, "socket.sock") + os.environ[TORCH_WORKER_SERVER_SOCKET] = socket_path + + with worker_main(): + pool = UnixHTTPConnectionPool(socket_path) + yield pool + + +class WorkerServerTest(TestCase): + def test_worker_server(self) -> None: + with local_worker_server() as pool: + resp = pool.request("GET", "/") + self.assertEqual(resp.status, 200) + self.assertEqual( + resp.data, + b"""

torch.distributed.WorkerServer

+Handler names +""", + ) + + resp = pool.request("POST", "/handler/ping") + self.assertEqual(resp.status, 200) + self.assertEqual(resp.data, b"pong") + + resp = pool.request("GET", "/handler/") + self.assertEqual(resp.status, 200) + self.assertIn("ping", json.loads(resp.data)) + + resp = pool.request("POST", "/handler/nonexistant") + self.assertEqual(resp.status, 404) + self.assertIn(b"Handler nonexistant not found:", resp.data) + + @requires_cuda + def test_dump_nccl_trace_pickle(self) -> None: + with local_worker_server() as pool: + resp = pool.request("POST", "/handler/dump_nccl_trace_pickle") + self.assertEqual(resp.status, 200) + out = pickle.loads(resp.data) + + +if __name__ == "__main__": + run_tests() diff --git a/third_party/cpp-httplib.BUILD b/third_party/cpp-httplib.BUILD new file mode 100644 index 000000000000..3cd0c3dbe94b --- /dev/null +++ b/third_party/cpp-httplib.BUILD @@ -0,0 +1,10 @@ +load("@rules_cc//cc:defs.bzl", "cc_library") + +cc_library( + name = "cpp-httplib", + hdrs = ["httplib.h"], + includes = [ + "/", + ], + visibility = ["//visibility:public"], +) diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index b4db57488f02..1cf1fe2cf599 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -68,6 +68,7 @@ set(TORCH_PYTHON_INCLUDE_DIRECTORIES ${TORCH_ROOT}/third_party/onnx ${TORCH_ROOT}/third_party/flatbuffers/include ${TORCH_ROOT}/third_party/kineto/libkineto/include + ${TORCH_ROOT}/third_party/cpp-httplib ${TORCH_SRC_DIR}/csrc ${TORCH_SRC_DIR}/csrc/api/include @@ -80,6 +81,7 @@ set(TORCH_PYTHON_LINK_LIBRARIES python::python pybind::pybind11 opentelemetry::api + httplib shm fmt::fmt-header-only ATEN_CPU_FILES_GEN_LIB) diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index 1a3e4ea63342..5594d6153b07 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -632,3 +632,7 @@ class ProcessGroupCudaP2P(Backend): storage_offset: Optional[int] = 0, ) -> torch.Tensor: ... def _shutdown(self) -> None: ... + +class _WorkerServer: + def __init__(self, socket_path: str) -> None: ... + def shutdown(self) -> None: ... diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 91ce50a4183f..1ea278a44e3c 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -28,6 +28,7 @@ #include #include #include +#include #include #include @@ -355,6 +356,13 @@ std::string dump_nccl_trace() { } #endif +// TODO(c-p-i-o): add a JSON endpoint. +control_plane::RegisterHandler dumpHandler{ + "dump_nccl_trace_pickle", + [](const control_plane::Request&, control_plane::Response& res) { + res.setContent(dump_nccl_trace(), "application/octet-stream"); + }}; + std::optional)>>& get_cpp_trace_dumper() { static std::optional< diff --git a/torch/csrc/distributed/c10d/control_plane/Handlers.cpp b/torch/csrc/distributed/c10d/control_plane/Handlers.cpp new file mode 100644 index 000000000000..e29f1e3a2ac3 --- /dev/null +++ b/torch/csrc/distributed/c10d/control_plane/Handlers.cpp @@ -0,0 +1,75 @@ +#include + +#include +#include +#include +#include + +namespace c10d { +namespace control_plane { + +namespace { + +class HandlerRegistry { + public: + void registerHandler(const std::string& name, HandlerFunc f) { + std::unique_lock lock(handlersMutex_); + + if (handlers_.find(name) != handlers_.end()) { + throw std::runtime_error( + fmt::format("Handler {} already registered", name)); + } + + handlers_[name] = f; + } + + HandlerFunc getHandler(const std::string& name) { + std::shared_lock lock(handlersMutex_); + + auto it = handlers_.find(name); + if (it == handlers_.end()) { + throw std::runtime_error(fmt::format("Failed to find handler {}", name)); + } + return handlers_[name]; + } + + std::vector getHandlerNames() { + std::shared_lock lock(handlersMutex_); + + std::vector names; + for (const auto& [name, _] : handlers_) { + names.push_back(name); + } + return names; + } + + private: + std::shared_mutex handlersMutex_{}; + std::unordered_map handlers_{}; +}; + +HandlerRegistry& getHandlerRegistry() { + static HandlerRegistry registry; + return registry; +} + +RegisterHandler pingHandler{"ping", [](const Request&, Response& res) { + res.setContent("pong", "text/plain"); + }}; + +} // namespace + +void registerHandler(const std::string& name, HandlerFunc f) { + return getHandlerRegistry().registerHandler(name, f); +} + +HandlerFunc getHandler(const std::string& name) { + return getHandlerRegistry().getHandler(name); +} + +std::vector getHandlerNames() { + return getHandlerRegistry().getHandlerNames(); +} + +} // namespace control_plane +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/control_plane/Handlers.hpp b/torch/csrc/distributed/c10d/control_plane/Handlers.hpp new file mode 100644 index 000000000000..0c1063054931 --- /dev/null +++ b/torch/csrc/distributed/c10d/control_plane/Handlers.hpp @@ -0,0 +1,67 @@ +#pragma once + +#include +#include + +#include + +namespace c10d { +namespace control_plane { + +// Request represents a request to the handler. This conceptually maps to an +// HTTP request but could be called via other transports. +class TORCH_API Request { + public: + virtual ~Request() = default; + + virtual const std::string& body() = 0; +}; + +// Response represents a response to the handler. This conceptually maps to an +// HTTP response but could be called via other transports. +class TORCH_API Response { + public: + virtual ~Response() = default; + + // Set the response body to the provided string. + // TODO: add support for chunked responses + virtual void setContent( + std::string&& content, + const std::string& content_type) = 0; + + // Set the response status code. + // These should match standard HTTP status codes. + virtual void setStatus(int status) = 0; +}; + +using HandlerFunc = std::function; + +// Registers a handler. The name needs to be unique and can be called by using +// getHandler directly or via WorkerServer for remote requests. +// These handlers are called from a background C++ thread concurrently with the +// main thread. These handlers need to be thread safe and not cause issues +// during Python training. +TORCH_API void registerHandler(const std::string& name, HandlerFunc f); + +// Fetches a handler by name. +TORCH_API HandlerFunc getHandler(const std::string& name); + +TORCH_API std::vector getHandlerNames(); + +// Registers a handler statically. +// See registerHandler for more details. +class TORCH_API RegisterHandler { + public: + RegisterHandler(const std::string& name, HandlerFunc f) { + registerHandler(name, f); + } + + // disable move, copy + RegisterHandler(const RegisterHandler&) = delete; + RegisterHandler(RegisterHandler&&) = delete; + RegisterHandler& operator=(const RegisterHandler&) = delete; + RegisterHandler& operator=(RegisterHandler&&) = delete; +}; + +} // namespace control_plane +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp b/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp new file mode 100644 index 000000000000..14d287e9607f --- /dev/null +++ b/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp @@ -0,0 +1,178 @@ +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace c10d { +namespace control_plane { + +namespace { +class RequestImpl : public Request { + public: + RequestImpl(const httplib::Request& req) : req_(req) {} + + const std::string& body() override { + return req_.body; + } + + private: + const httplib::Request& req_; +}; + +class ResponseImpl : public Response { + public: + ResponseImpl(httplib::Response& res) : res_(res) {} + + void setStatus(int status) override { + res_.status = status; + } + + void setContent(std::string&& content, const std::string& content_type) + override { + res_.set_content(std::move(content), content_type); + } + + private: + httplib::Response& res_; +}; + +std::string jsonStrEscape(const std::string& str) { + std::ostringstream ostream; + for (char ch : str) { + if (ch == '"') { + ostream << "\\\""; + } else if (ch == '\\') { + ostream << "\\\\"; + } else if (ch == '\b') { + ostream << "\\b"; + } else if (ch == '\f') { + ostream << "\\f"; + } else if (ch == '\n') { + ostream << "\\n"; + } else if (ch == '\r') { + ostream << "\\r"; + } else if (ch == '\t') { + ostream << "\\t"; + } else if ('\x00' <= ch && ch <= '\x1f') { + ostream << "\\u" << std::hex << std::setw(4) << std::setfill('0') + << static_cast(ch); + } else { + ostream << ch; + } + } + return ostream.str(); +} +} // namespace + +WorkerServer::WorkerServer(const std::string& socketFile) { + // using unix sockets + server_.set_address_family(AF_UNIX); + + // adjust keep alives as it stops the server from shutting down quickly + server_.set_keep_alive_timeout(1); // second, default is 5 + server_.set_keep_alive_max_count( + 30); // wait max 30 seconds before closing socket + + server_.Get("/", [](const httplib::Request& req, httplib::Response& res) { + res.set_content( + R"BODY(

torch.distributed.WorkerServer

+Handler names +)BODY", + "text/html"); + }); + server_.Get( + "/handler/", [](const httplib::Request& req, httplib::Response& res) { + std::ostringstream body; + body << "["; + bool first = true; + for (const auto& name : getHandlerNames()) { + if (!first) { + body << ","; + } + first = false; + + body << "\"" << jsonStrEscape(name) << "\""; + } + body << "]"; + + res.set_content(body.str(), "application/json"); + }); + server_.Post( + "/handler/:handler", + [](const httplib::Request& req, httplib::Response& res) { + auto handler_name = req.path_params.at("handler"); + HandlerFunc handler; + try { + handler = getHandler(handler_name); + } catch (const std::exception& e) { + res.status = 404; + res.set_content( + fmt::format("Handler {} not found: {}", handler_name, e.what()), + "text/plain"); + return; + } + RequestImpl torchReq{req}; + ResponseImpl torchRes{res}; + + try { + handler(torchReq, torchRes); + } catch (const std::exception& e) { + res.status = 500; + res.set_content( + fmt::format("Handler {} failed: {}", handler_name, e.what()), + "text/plain"); + return; + } catch (...) { + res.status = 500; + res.set_content( + fmt::format( + "Handler {} failed with unknown exception", handler_name), + "text/plain"); + return; + } + }); + + if (std::filesystem::exists(socketFile)) { + throw std::runtime_error(fmt::format("{} already exists", socketFile)); + } + + C10D_WARNING("Server listening to {}", socketFile); + if (!server_.bind_to_port(socketFile, 80)) { + throw std::runtime_error(fmt::format("Error binding to {}", socketFile)); + } + + serverThread_ = std::thread([this]() { + try { + if (!server_.listen_after_bind()) { + throw std::runtime_error("failed to listen"); + } + } catch (std::exception& e) { + C10D_ERROR("Error while running server: {}", e.what()); + throw; + } + C10D_WARNING("Server exited"); + }); +} + +void WorkerServer::shutdown() { + C10D_WARNING("Server shutting down"); + server_.stop(); + serverThread_.join(); +} + +WorkerServer::~WorkerServer() { + if (serverThread_.joinable()) { + C10D_WARNING("WorkerServer destructor called without shutdown"); + shutdown(); + } +} + +} // namespace control_plane +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp b/torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp new file mode 100644 index 000000000000..7d64038f0b01 --- /dev/null +++ b/torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp @@ -0,0 +1,28 @@ +#pragma once + +#include +#include +#include + +#include + +#include +#include + +namespace c10d { +namespace control_plane { + +class TORCH_API WorkerServer : public c10::intrusive_ptr_target { + public: + WorkerServer(const std::string& socketFile); + ~WorkerServer(); + + void shutdown(); + + private: + httplib::Server server_; + std::thread serverThread_; +}; + +} // namespace control_plane +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 2aaf9009a246..9f0122e78332 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #ifndef _WIN32 #include @@ -3164,6 +3165,17 @@ such as `dist.all_reduce(tensor, async_op=True)`. return py::bytes(::c10d::dump_nccl_trace()); }); #endif + + intrusive_ptr_class_<::c10d::control_plane::WorkerServer>( + module, "_WorkerServer", R"( +)") + .def( + py::init([](const std::string& socketPath) { + return c10::make_intrusive<::c10d::control_plane::WorkerServer>( + socketPath); + }), + py::arg("socket_path")) + .def("shutdown", &::c10d::control_plane::WorkerServer::shutdown); Py_RETURN_TRUE; } diff --git a/torch/distributed/elastic/control_plane.py b/torch/distributed/elastic/control_plane.py new file mode 100644 index 000000000000..160383637865 --- /dev/null +++ b/torch/distributed/elastic/control_plane.py @@ -0,0 +1,51 @@ +import os +from contextlib import contextmanager, ExitStack +from typing import Generator + +from torch.distributed.elastic.multiprocessing.errors import record + +__all__ = [ + "worker_main", +] + +TORCH_WORKER_SERVER_SOCKET = "TORCH_WORKER_SERVER_SOCKET" + + +@contextmanager +def _worker_server(socket_path: str) -> Generator[None, None, None]: + from torch._C._distributed_c10d import _WorkerServer + + server = _WorkerServer(socket_path) + try: + yield + finally: + server.shutdown() + + +@contextmanager +@record +def worker_main() -> Generator[None, None, None]: + """ + This is a context manager that wraps your main entry function. This combines + the existing ``errors.record`` logic as well as a new ``_WorkerServer`` that + exposes handlers via a unix socket specified by + ``Torch_WORKER_SERVER_SOCKET``. + + Example + + :: + + @worker_main() + def main(): + pass + + if __name__=="__main__": + main() + + """ + with ExitStack() as stack: + socket_path = os.environ.get(TORCH_WORKER_SERVER_SOCKET) + if socket_path is not None: + stack.enter_context(_worker_server(socket_path)) + + yield From e0fc1ab6257af45b36192f7bcb4d707cb5734a2b Mon Sep 17 00:00:00 2001 From: chilli Date: Wed, 29 May 2024 13:16:10 -0700 Subject: [PATCH 077/706] Forward fix for templates + views (#127446) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127446 Approved by: https://github.com/eellison --- test/inductor/test_cuda_repro.py | 23 +++++++++++++++++++++++ torch/_inductor/codegen/simd.py | 5 +++-- torch/_inductor/select_algorithm.py | 2 +- 3 files changed, 27 insertions(+), 3 deletions(-) diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py index 386fb36a635e..6ead970a7884 100644 --- a/test/inductor/test_cuda_repro.py +++ b/test/inductor/test_cuda_repro.py @@ -1181,6 +1181,29 @@ def outer_reduce(x): self.assertEqual(outer_reduce(a), out) self.assertTrue("for roffset" not in code) + def test_epilogue_fusion_with_view(self): + class ToyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) + self.linear = torch.nn.Linear(262144, 100) + self.relu = torch.nn.ReLU() + + def forward(self, x): + x = self.conv(x) + x = x.view(x.size(0), -1) + return self.relu(self.linear(x)) + + m = ToyModel().to(device="cuda:0") + input_tensor = torch.randn(32, 3, 64, 64).to(device="cuda:0") + from torch._inductor.utils import fresh_inductor_cache + + with fresh_inductor_cache(): + cm = torch.compile(m, mode="max-autotune") + out = cm(input_tensor) + out2 = m(input_tensor) + self.assertEqual(out, out2, atol=1e-3, rtol=1e-3) + if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index 23c602c10e5d..9140b1887f7f 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -1400,8 +1400,9 @@ def codegen_template( for node in [template_node, *epilogue_nodes]: node.mark_run() partial_code = render() - for node in epilogue_nodes: - node.codegen(kernel.split_and_set_ranges(node.get_ranges())) + with kernel.set_subgraph_body(""): + for node in epilogue_nodes: + node.codegen(kernel.split_and_set_ranges(node.get_ranges())) if not isinstance(partial_code, str): partial_code.finalize_hook("") diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 7aafcfe31488..8361566e5f08 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -150,7 +150,7 @@ def __init__( @contextlib.contextmanager def set_subgraph_body(self, body_name: str): old_body = self.body - assert body_name in self.subgraph_bodies + assert body_name in self.subgraph_bodies, body_name self.body = self.subgraph_bodies[body_name] yield self.body = old_body From f58fc16e8f059232f452a333f32e14ff681e12af Mon Sep 17 00:00:00 2001 From: James Wu Date: Wed, 29 May 2024 13:02:26 -0700 Subject: [PATCH 078/706] [easy?] Move AsyncCompile to a different file (#127235) By moving AsyncCompile to its own file, we can import codecache without running the side effects of AsyncCompile. This will be important for AOTAutogradCaching, where we want to share some implementation details with codecache.py without spawning new processes. To conservatively maintain the same behavior elsewhere, every time we import codecache, I've added an import to torch._inductor.async_compile (except in autograd_cache.py, where the explicit goal is to not do this) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127235 Approved by: https://github.com/aorenste, https://github.com/oulgen, https://github.com/masnesral --- test/inductor/test_codecache.py | 2 +- test/inductor/test_cudacodecache.py | 3 +- test/inductor/test_halide.py | 1 + test/inductor/test_kernel_benchmark.py | 1 + test/inductor/test_triton_wrapper.py | 1 + torch/_inductor/async_compile.py | 239 +++++++++++++++++++ torch/_inductor/autotune_process.py | 1 + torch/_inductor/codecache.py | 207 +--------------- torch/_inductor/codegen/cpp_wrapper_cpu.py | 3 +- torch/_inductor/codegen/wrapper.py | 4 +- torch/_inductor/compile_fx.py | 2 + torch/_inductor/compile_worker/__main__.py | 2 +- torch/_inductor/runtime/triton_heuristics.py | 1 - torch/_inductor/scheduler.py | 1 + torch/_inductor/select_algorithm.py | 1 + torch/testing/_internal/inductor_utils.py | 2 +- 16 files changed, 257 insertions(+), 214 deletions(-) create mode 100644 torch/_inductor/async_compile.py diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py index 994786740a65..af12454df3c0 100644 --- a/test/inductor/test_codecache.py +++ b/test/inductor/test_codecache.py @@ -12,8 +12,8 @@ from torch._dynamo import reset from torch._dynamo.utils import counters from torch._inductor import config, metrics +from torch._inductor.async_compile import AsyncCompile from torch._inductor.codecache import ( - AsyncCompile, cuda_compile_command, CUDACodeCache, FxGraphCachePickler, diff --git a/test/inductor/test_cudacodecache.py b/test/inductor/test_cudacodecache.py index 33a179a9abc7..ac26f6a6656c 100644 --- a/test/inductor/test_cudacodecache.py +++ b/test/inductor/test_cudacodecache.py @@ -6,7 +6,8 @@ import torch from torch._inductor import config -from torch._inductor.codecache import AsyncCompile, CUDACodeCache +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.codecache import CUDACodeCache from torch._inductor.codegen.cuda.cuda_env import nvcc_exist from torch._inductor.exc import CUDACompileError from torch._inductor.test_case import TestCase as InductorTestCase diff --git a/test/inductor/test_halide.py b/test/inductor/test_halide.py index 52227c20d1ff..158a669dad2f 100644 --- a/test/inductor/test_halide.py +++ b/test/inductor/test_halide.py @@ -3,6 +3,7 @@ import unittest import torch +import torch._inductor.async_compile from torch._inductor.codecache import HalideCodeCache from torch._inductor.runtime.hints import HalideInputSpec, HalideMeta from torch._inductor.test_case import run_tests, TestCase diff --git a/test/inductor/test_kernel_benchmark.py b/test/inductor/test_kernel_benchmark.py index 23804e08f23f..87ddb0bec2e6 100644 --- a/test/inductor/test_kernel_benchmark.py +++ b/test/inductor/test_kernel_benchmark.py @@ -6,6 +6,7 @@ from unittest.mock import patch import torch +import torch._inductor.async_compile from torch._dynamo.testing import rand_strided from torch._inductor import config from torch._inductor.codecache import PyCodeCache diff --git a/test/inductor/test_triton_wrapper.py b/test/inductor/test_triton_wrapper.py index 24ba84ebf86a..7f7ded46182a 100644 --- a/test/inductor/test_triton_wrapper.py +++ b/test/inductor/test_triton_wrapper.py @@ -4,6 +4,7 @@ import sys import torch +import torch._inductor.async_compile from torch._inductor.codecache import PyCodeCache from torch._inductor.test_case import run_tests, TestCase from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU diff --git a/torch/_inductor/async_compile.py b/torch/_inductor/async_compile.py new file mode 100644 index 000000000000..c163df9bd878 --- /dev/null +++ b/torch/_inductor/async_compile.py @@ -0,0 +1,239 @@ +from __future__ import annotations + +import functools +import logging +import multiprocessing +import os +import sys +from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor +from functools import partial +from time import time +from typing import Any, Callable, Dict, List, Optional, Set + +import torch +from torch._dynamo.device_interface import get_registered_device_interfaces +from torch._inductor import config +from torch._inductor.codecache import ( + CodeCacheFuture, + CppCodeCache, + CppPythonBindingsCodeCache, + CUDACodeCache, + HalideCodeCache, + LambdaFuture, + TritonCodeCache, + TritonFuture, +) +from torch._inductor.compile_worker.subproc_pool import ( + _warm_process_pool, + AnyPool, + SubprocPool, +) +from torch._inductor.compile_worker.watchdog import _async_compile_initializer + +from torch._inductor.runtime.compile_tasks import ( + _set_triton_ptxas_path, + _worker_compile_triton, +) +from torch._inductor.runtime.hints import HalideMeta + +from torch.hub import _Faketqdm, tqdm + +# timing metrics for time spent in the compilation +_cumulative_compile_time = 0.0 +_t0: Optional[float] = None + +kernel_code_log = torch._logging.getArtifactLogger(__name__, "kernel_code") + + +def caching_device_properties(): + for _, device_interface in get_registered_device_interfaces(): + if device_interface.is_available(): + device_interface.Worker.get_device_properties() + + +def _compile_start() -> None: + global _t0 + if _t0 is None: + _t0 = time() + + +def _compile_end() -> None: + global _cumulative_compile_time, _t0 + if _t0 is not None: + t1 = time() + _cumulative_compile_time += t1 - _t0 + _t0 = None + # print("CUMULATIVE COMPILE TIME", _cumulative_compile_time) + + +_IS_WINDOWS = sys.platform == "win32" + +log = logging.getLogger(__name__) + + +# Used to keep track of all process pools invoked so far. +_pool_set: Set[AnyPool] = set() + + +def shutdown_compile_workers() -> None: + """Shut down all outstanding compile-worker pools.""" + for pool in _pool_set: + pool.shutdown() + after_fork() + + +def after_fork(): + """Reset pools to initial state without shutting them down""" + _pool_set.clear() + AsyncCompile.process_pool.cache_clear() + + +try: + os.register_at_fork(after_in_child=after_fork) +except AttributeError: + pass # register_at_fork does not exists on windows + + +class AsyncCompile: + def __init__(self) -> None: + pass + + @staticmethod + @functools.lru_cache(1) + def pool() -> ThreadPoolExecutor: + assert config.compile_threads > 1 + return ThreadPoolExecutor(config.compile_threads) + + @staticmethod + @functools.lru_cache(1) + def process_pool() -> AnyPool: + assert config.compile_threads > 1 + pool: AnyPool + if config.worker_start_method == "subprocess": + # Wrapper around ProcessPoolExecutor forks in a new process we control + pool = SubprocPool(config.compile_threads) + else: + # ensure properties have been calculated before processes + # are forked + caching_device_properties() + ctx = multiprocessing.get_context(config.worker_start_method) + pool = ProcessPoolExecutor( + config.compile_threads, + mp_context=ctx, + initializer=partial(_async_compile_initializer, os.getpid()), + ) + # when this pool is created in a subprocess object, the normal exit handler + # doesn't run, and we need to register our own handler. + # exitpriority has to be high, because another one of the finalizers will + # kill the worker thread that sends the shutdown message to the workers... + multiprocessing.util.Finalize(None, pool.shutdown, exitpriority=sys.maxsize) + + _pool_set.add(pool) + return pool + + @classmethod + def warm_pool(cls) -> None: + if config.compile_threads <= 1: + return + _compile_start() + _warm_process_pool(cls.process_pool(), config.compile_threads) + _compile_end() + + @classmethod + def submit(cls, task: Callable[..., Any]) -> Any: + if config.compile_threads <= 1: + return task() + return cls.pool().submit(task) + + def triton(self, kernel_name: str, source_code: str, device_str: str = "cuda"): + kernel_code_log.info("Triton Kernel:\n%s", source_code) + _compile_start() + _set_triton_ptxas_path() + + kernel = TritonCodeCache.load(kernel_name, source_code) + if config.compile_threads > 1: + return TritonFuture( + kernel, + self.process_pool().submit( + _worker_compile_triton, + kernel._reload_in_subproc, + ), + ) + else: + kernel.precompile() + return kernel + + def multi_kernel(self, *args, **kwargs) -> Any: + from torch._inductor.codegen.multi_kernel import MultiKernelCall + + # no need to call this in parallel since the sub-kernels are already parallel tasks + return MultiKernelCall(*args, **kwargs) + + def cpp(self, source_code: str): + kernel_code_log.info("CPP Kernel:\n%s", source_code) + if config.compile_threads <= 1: + return CppCodeCache.load(source_code).kernel + else: + get_result = CppCodeCache.load_async(source_code, submit_fn=self.submit) + return LambdaFuture(lambda: get_result().kernel) + + def cpp_pybinding(self, argtypes: List[str], source_code: str): + kernel_code_log.info("CPP+Bindings Kernel:\n%s", source_code) + if config.compile_threads <= 1: + return CppPythonBindingsCodeCache.load_pybinding(argtypes, source_code) + else: + get_result = CppPythonBindingsCodeCache.load_pybinding_async( + argtypes, source_code, submit_fn=self.submit + ) + return LambdaFuture(get_result) + + def cuda(self, source_code, dst_file_ext): + kernel_code_log.info("CUDA Kernel:\n%s", source_code) + + def task(): + return CUDACodeCache.load(source_code, dst_file_ext)[0] + + return self.submit(task) + + def halide(self, meta: HalideMeta, source_code: str): + kernel_code_log.info("Halide Kernel:\n%r\n%s", meta, source_code) + if config.compile_threads <= 1: + return HalideCodeCache.generate_halide(meta, source_code) + else: + get_result = HalideCodeCache.generate_halide_async( + meta, source_code, submit_fn=self.submit + ) + return LambdaFuture(get_result) + + def wait(self, scope: Dict[str, Any]) -> None: + num_kernels = len( + [ + value + for key, value in scope.items() + if isinstance(value, (Future, CodeCacheFuture)) + ] + ) + pbar = tqdm( + total=num_kernels, + desc="Inductor Compilation", + disable=config.disable_progress, + delay=0, + ) + if config.compile_threads > 1: + for key, result in scope.items(): + if config.verbose_progress and not isinstance(pbar, _Faketqdm): + pbar.set_postfix_str(key) + if isinstance(result, (Future, CodeCacheFuture)): + scope[key] = result.result() + pbar.update(1) + + _compile_end() + + +if ( + os.environ.get("TORCH_TNT_IN_USE", "0") == "1" + or os.environ.get("TORCH_WARM_POOL", "1") != "1" +): + pass +else: + AsyncCompile.warm_pool() diff --git a/torch/_inductor/autotune_process.py b/torch/_inductor/autotune_process.py index 7db4d2a3291c..5e04211639d6 100644 --- a/torch/_inductor/autotune_process.py +++ b/torch/_inductor/autotune_process.py @@ -25,6 +25,7 @@ ) import torch +import torch._inductor.async_compile from torch import multiprocessing from torch._dynamo.testing import rand_strided diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index c2e6b4f0d95d..62c7252db493 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -9,7 +9,6 @@ import io import json import logging -import multiprocessing import os import pickle import pkgutil @@ -26,7 +25,7 @@ import threading import warnings from bisect import bisect_right -from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor +from concurrent.futures import Future from copy import copy from ctypes import c_void_p, cdll, CDLL from functools import partial @@ -49,22 +48,13 @@ ) import torch -from torch._dynamo.device_interface import get_registered_device_interfaces from torch._dynamo.utils import counters, dynamo_timed from torch._inductor import config, exc, metrics from torch._inductor.codegen.cuda import cuda_env -from torch._inductor.compile_worker.subproc_pool import ( - _warm_process_pool, - AnyPool, - SubprocPool, -) -from torch._inductor.compile_worker.watchdog import _async_compile_initializer from torch._inductor.runtime.compile_tasks import ( _module_to_triton_kernel, _reload_python_module, _reload_python_module_in_subproc, - _set_triton_ptxas_path, - _worker_compile_triton, ) from torch._inductor.runtime.hints import HalideMeta from torch._inductor.runtime.runtime_utils import cache_dir @@ -82,7 +72,6 @@ from torch._inductor.graph import GraphLowering from torch._inductor.ir import ChoiceCaller -from torch.hub import _Faketqdm, tqdm _HERE = os.path.abspath(__file__) _TORCH_PATH = os.path.dirname(os.path.dirname(_HERE)) @@ -114,31 +103,11 @@ def use_global_cache() -> bool: output_code_log = torch._logging.getArtifactLogger(__name__, "output_code") -kernel_code_log = torch._logging.getArtifactLogger(__name__, "kernel_code") LOCK_TIMEOUT = 600 _IS_WINDOWS = sys.platform == "win32" -# timing metrics for time spent in the compilation -_cumulative_compile_time = 0.0 -_t0: Optional[float] = None - - -def _compile_start() -> None: - global _t0 - if _t0 is None: - _t0 = time() - - -def _compile_end() -> None: - global _cumulative_compile_time, _t0 - if _t0 is not None: - t1 = time() - _cumulative_compile_time += t1 - _t0 - _t0 = None - # print("CUMULATIVE COMPILE TIME", _cumulative_compile_time) - log = logging.getLogger(__name__) @@ -3205,12 +3174,6 @@ def load(cls, source_code, dst_file_ext) -> Tuple[DLLWrapper, str, str]: return (DLLWrapper(dst_file_path), hash_key, source_code_path) -def caching_device_properties(): - for _, device_interface in get_registered_device_interfaces(): - if device_interface.is_available(): - device_interface.Worker.get_device_properties() - - class CodeCacheFuture: def result(self): raise NotImplementedError @@ -3244,171 +3207,3 @@ def __init__(self, result_fn): def result(self): return self.result_fn() - - -# Used to keep track of all process pools invoked so far. -_pool_set: Set[AnyPool] = set() - - -def shutdown_compile_workers() -> None: - """Shut down all outstanding compile-worker pools.""" - for pool in _pool_set: - pool.shutdown() - after_fork() - - -def after_fork(): - """Reset pools to initial state without shutting them down""" - _pool_set.clear() - AsyncCompile.process_pool.cache_clear() - - -try: - os.register_at_fork(after_in_child=after_fork) -except AttributeError: - pass # register_at_fork does not exists on windows - - -class AsyncCompile: - def __init__(self) -> None: - pass - - @staticmethod - @functools.lru_cache(1) - def pool() -> ThreadPoolExecutor: - assert config.compile_threads > 1 - return ThreadPoolExecutor(config.compile_threads) - - @staticmethod - @functools.lru_cache(1) - def process_pool() -> AnyPool: - assert config.compile_threads > 1 - pool: AnyPool - if config.worker_start_method == "subprocess": - # Wrapper around ProcessPoolExecutor forks in a new process we control - pool = SubprocPool(config.compile_threads) - else: - # ensure properties have been calculated before processes - # are forked - caching_device_properties() - ctx = multiprocessing.get_context(config.worker_start_method) - pool = ProcessPoolExecutor( - config.compile_threads, - mp_context=ctx, - initializer=partial(_async_compile_initializer, os.getpid()), - ) - # when this pool is created in a subprocess object, the normal exit handler - # doesn't run, and we need to register our own handler. - # exitpriority has to be high, because another one of the finalizers will - # kill the worker thread that sends the shutdown message to the workers... - multiprocessing.util.Finalize(None, pool.shutdown, exitpriority=sys.maxsize) - - _pool_set.add(pool) - return pool - - @classmethod - def warm_pool(cls) -> None: - if config.compile_threads <= 1: - return - _compile_start() - _warm_process_pool(cls.process_pool(), config.compile_threads) - _compile_end() - - @classmethod - def submit(cls, task: Callable[..., Any]) -> Any: - if config.compile_threads <= 1: - return task() - return cls.pool().submit(task) - - def triton(self, kernel_name: str, source_code: str, device_str: str = "cuda"): - kernel_code_log.info("Triton Kernel:\n%s", source_code) - _compile_start() - _set_triton_ptxas_path() - - kernel = TritonCodeCache.load(kernel_name, source_code) - if config.compile_threads > 1: - return TritonFuture( - kernel, - self.process_pool().submit( - _worker_compile_triton, - kernel._reload_in_subproc, - ), - ) - else: - kernel.precompile() - return kernel - - def multi_kernel(self, *args, **kwargs) -> Any: - from torch._inductor.codegen.multi_kernel import MultiKernelCall - - # no need to call this in parallel since the sub-kernels are already parallel tasks - return MultiKernelCall(*args, **kwargs) - - def cpp(self, source_code: str): - kernel_code_log.info("CPP Kernel:\n%s", source_code) - if config.compile_threads <= 1: - return CppCodeCache.load(source_code).kernel - else: - get_result = CppCodeCache.load_async(source_code, submit_fn=self.submit) - return LambdaFuture(lambda: get_result().kernel) - - def cpp_pybinding(self, argtypes: List[str], source_code: str): - kernel_code_log.info("CPP+Bindings Kernel:\n%s", source_code) - if config.compile_threads <= 1: - return CppPythonBindingsCodeCache.load_pybinding(argtypes, source_code) - else: - get_result = CppPythonBindingsCodeCache.load_pybinding_async( - argtypes, source_code, submit_fn=self.submit - ) - return LambdaFuture(get_result) - - def cuda(self, source_code, dst_file_ext): - kernel_code_log.info("CUDA Kernel:\n%s", source_code) - - def task(): - return CUDACodeCache.load(source_code, dst_file_ext)[0] - - return self.submit(task) - - def halide(self, meta: HalideMeta, source_code: str): - kernel_code_log.info("Halide Kernel:\n%r\n%s", meta, source_code) - if config.compile_threads <= 1: - return HalideCodeCache.generate_halide(meta, source_code) - else: - get_result = HalideCodeCache.generate_halide_async( - meta, source_code, submit_fn=self.submit - ) - return LambdaFuture(get_result) - - def wait(self, scope: Dict[str, Any]) -> None: - num_kernels = len( - [ - value - for key, value in scope.items() - if isinstance(value, (Future, CodeCacheFuture)) - ] - ) - pbar = tqdm( - total=num_kernels, - desc="Inductor Compilation", - disable=config.disable_progress, - delay=0, - ) - if config.compile_threads > 1: - for key, result in scope.items(): - if config.verbose_progress and not isinstance(pbar, _Faketqdm): - pbar.set_postfix_str(key) - if isinstance(result, (Future, CodeCacheFuture)): - scope[key] = result.result() - pbar.update(1) - - _compile_end() - - -if ( - os.environ.get("TORCH_TNT_IN_USE", "0") == "1" - or os.environ.get("TORCH_WARM_POOL", "1") != "1" -): - pass -else: - AsyncCompile.warm_pool() diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 1259418fc09e..cefd7e96acce 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -9,10 +9,11 @@ from sympy import Expr import torch + +import torch._inductor.async_compile import torch._ops from torch.fx.experimental.symbolic_shapes import ConvertIntKey, DivideByKey from .. import config, ir - from ..codecache import CudaKernelParamCache from ..utils import cache_on_self, sympy_product from ..virtualized import V diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 02f4fee19bb0..030c73833a0e 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -38,7 +38,7 @@ from torch.utils._sympy.singleton_int import SingletonInt from torch.utils._sympy.symbol import symbol_is_type, SymT -from .. import codecache, config, ir +from .. import async_compile, config, ir from ..ir import ReinterpretView from ..runtime import triton_heuristics from ..runtime.hints import DeviceProperties @@ -506,7 +506,7 @@ def write_header(self) -> None: from torch._inductor.codegen.memory_planning import _align as align from torch import device, empty_strided - from {codecache.__name__} import AsyncCompile + from {async_compile.__name__} import AsyncCompile from torch._inductor.select_algorithm import extern_kernels from torch._inductor.codegen.multi_kernel import MultiKernelCall diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index 7eca31da87b4..26d75669a206 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -11,6 +11,8 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union from unittest import mock +import torch._inductor.async_compile + import torch.fx import torch.utils._pytree as pytree diff --git a/torch/_inductor/compile_worker/__main__.py b/torch/_inductor/compile_worker/__main__.py index 6cd1d1e600ac..e478a5345675 100644 --- a/torch/_inductor/compile_worker/__main__.py +++ b/torch/_inductor/compile_worker/__main__.py @@ -3,7 +3,7 @@ import sys import typing -from torch._inductor.codecache import caching_device_properties +from torch._inductor.async_compile import caching_device_properties from torch._inductor.compile_worker.subproc_pool import Pipe, SubprocMain from torch._inductor.compile_worker.watchdog import _async_compile_initializer from torch._inductor.runtime.compile_tasks import _set_triton_ptxas_path diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 82c8f9a4fb71..5a27f7a08cdc 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -748,7 +748,6 @@ def save_cuda_kernel(self, grid, stream, launcher): # User defined triton kernels will have arbitrary kwarg names "meta": launcher.config.kwargs, } - from torch._inductor.codecache import CudaKernelParamCache binary = ( diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 3b8a13c49cb1..8a64d5941a46 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -28,6 +28,7 @@ import sympy import torch +import torch._inductor.async_compile from torch._dynamo.utils import counters, dynamo_timed from torch._inductor.metrics import get_metric_table, is_metric_table_enabled from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 8361566e5f08..bd48144d7bfe 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -22,6 +22,7 @@ from filelock import FileLock import torch +import torch._inductor.async_compile from torch._dynamo.testing import rand_strided from torch._dynamo.utils import counters, identity, preserve_rng_state diff --git a/torch/testing/_internal/inductor_utils.py b/torch/testing/_internal/inductor_utils.py index e8db1e394b96..d441988d4bd2 100644 --- a/torch/testing/_internal/inductor_utils.py +++ b/torch/testing/_internal/inductor_utils.py @@ -5,7 +5,7 @@ import unittest import functools from subprocess import CalledProcessError - +import torch._inductor.async_compile from torch._inductor.codecache import CppCodeCache from torch.utils._triton import has_triton from torch.testing._internal.common_utils import ( From 998f38814c038f0a63a097f097f965cd03413a07 Mon Sep 17 00:00:00 2001 From: Anshul Sinha <50644008+sinhaanshul@users.noreply.github.com> Date: Tue, 28 May 2024 16:41:11 -0700 Subject: [PATCH 079/706] [dtensor][debug] added c10d allgather, allgather_coalesced, and allgather_into_tensor_coalesced tracing to CommDebugMode (#127334) **Summary** Added c10d allgather, allgather_coalesced, and allgather_into_tensor_coalesced tracing to CommDebugMode and edited test case in test_comm_mode to include added features. **Test Plan** pytest test/distributed/_tensor/debug/test_comm_mode.py Pull Request resolved: https://github.com/pytorch/pytorch/pull/127334 Approved by: https://github.com/XilunWu, https://github.com/yifuwang ghstack dependencies: #127025, #127029, #127040, #127134 --- .../_tensor/debug/test_comm_mode.py | 28 +++++++++++++++++++ torch/distributed/_tensor/debug/comm_mode.py | 5 +++- 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/test/distributed/_tensor/debug/test_comm_mode.py b/test/distributed/_tensor/debug/test_comm_mode.py index 4143da2bd88c..0962d9ca600a 100644 --- a/test/distributed/_tensor/debug/test_comm_mode.py +++ b/test/distributed/_tensor/debug/test_comm_mode.py @@ -160,6 +160,34 @@ def test_comm_mode_with_c10d(self): comm_counts = comm_mode.get_comm_counts() self.assertEqual(comm_counts[c10d_ops.scatter_], 1) + # tests c10d all_gather tracing + output_list = [] + + with comm_mode: + dist.all_gather(output_list, inp, None) + + comm_counts = comm_mode.get_comm_counts() + self.assertEqual(comm_counts[c10d_ops.allgather_], 1) + + # tests c10d allgather_coalesced_ tracing + output_list = [] + + with comm_mode: + dist.all_gather_coalesced(output_list, [inp], None) + + comm_counts = comm_mode.get_comm_counts() + self.assertEqual(comm_counts[c10d_ops.allgather_coalesced_], 1) + + # tests c10d allgather_into_tensor_coalesced_ tracing + comm_mode = CommDebugMode() + with comm_mode as A, dist._coalescing_manager() as B: + # dist.all_reduce_coalesced(inp) + dist.all_gather_into_tensor(all_gather_out, inp) + + comm_counts = comm_mode.get_comm_counts() + self.assertEqual(comm_mode.get_total_counts(), 1) + self.assertEqual(comm_counts[c10d_ops.allgather_into_tensor_coalesced_], 1) + @requires_nccl() def test_comm_mode_with_c10d_allreduce_coalesced(self): world_pg = self.world_pg diff --git a/torch/distributed/_tensor/debug/comm_mode.py b/torch/distributed/_tensor/debug/comm_mode.py index 62e10a160384..d566da546d21 100644 --- a/torch/distributed/_tensor/debug/comm_mode.py +++ b/torch/distributed/_tensor/debug/comm_mode.py @@ -25,9 +25,12 @@ } c10d_collective_ops = { - c10d_ops.allreduce_, c10d_ops._allgather_base_, c10d_ops._reduce_scatter_base_, + c10d_ops.allgather_, + c10d_ops.allgather_coalesced_, + c10d_ops.allgather_into_tensor_coalesced_, + c10d_ops.allreduce_, c10d_ops.allreduce_coalesced_, c10d_ops.broadcast_, c10d_ops.gather_, From 15cc9f2e7e7b2b175f24755925dc38d4d430905d Mon Sep 17 00:00:00 2001 From: Anshul Sinha <50644008+sinhaanshul@users.noreply.github.com> Date: Tue, 28 May 2024 16:41:12 -0700 Subject: [PATCH 080/706] [dtensor][be] added checksAssert function and refactored test cases (#127356) **Summary** Added c10d checksAsserts functions to reduce written lines of code and refactored test cases. Merged one test case into another. **Test Plan** pytest test/distributed/_tensor/debug/test_comm_mode.py Pull Request resolved: https://github.com/pytorch/pytorch/pull/127356 Approved by: https://github.com/XilunWu ghstack dependencies: #127025, #127029, #127040, #127134, #127334 --- .../_tensor/debug/test_comm_mode.py | 66 +++++++------------ 1 file changed, 25 insertions(+), 41 deletions(-) diff --git a/test/distributed/_tensor/debug/test_comm_mode.py b/test/distributed/_tensor/debug/test_comm_mode.py index 0962d9ca600a..6cb94c860024 100644 --- a/test/distributed/_tensor/debug/test_comm_mode.py +++ b/test/distributed/_tensor/debug/test_comm_mode.py @@ -33,6 +33,13 @@ def setUp(self): self.device_type = "cuda" if torch.cuda.is_available() else "cpu" self.world_pg = dist.distributed_c10d._get_default_group() + def checksAssert(self, comm_mode, key, expected_value, expected_total_value): + comm_counts = comm_mode.get_comm_counts() + self.assertEqual(comm_mode.get_total_counts(), expected_total_value) + self.assertEqual(comm_counts[key], expected_value) + + return + def test_comm_mode(self): world_pg = self.world_pg @@ -115,50 +122,48 @@ def test_comm_mode_with_c10d(self): all_gather_out = inp.new_empty(self.world_size * 2, 8, 16) comm_mode = CommDebugMode() + + # tests c10d all_reduce tracing with comm_mode: dist.all_reduce(inp) - comm_counts = comm_mode.get_comm_counts() - self.assertEqual(comm_counts[c10d_ops.allreduce_], 1) + self.checksAssert(comm_mode, c10d_ops.allreduce_, 1, 1) + # tests c10d all_gather_into_tensor tracing with comm_mode: dist.all_gather_into_tensor(all_gather_out, inp) - comm_counts = comm_mode.get_comm_counts() - self.assertEqual(comm_counts[c10d_ops._allgather_base_], 1) + self.checksAssert(comm_mode, c10d_ops._allgather_base_, 1, 1) + # tests c10d reduce_scatter tracing with comm_mode: dist.reduce_scatter_tensor(inp, all_gather_out) - comm_counts = comm_mode.get_comm_counts() - self.assertEqual(comm_counts[c10d_ops._reduce_scatter_base_], 1) + self.checksAssert(comm_mode, c10d_ops._reduce_scatter_base_, 1, 1) + # tests c10d broadcast tracing with comm_mode: dist.broadcast(inp, 0) - comm_counts = comm_mode.get_comm_counts() - self.assertEqual(comm_counts[c10d_ops.broadcast_], 1) + self.checksAssert(comm_mode, c10d_ops.broadcast_, 1, 1) # tests c10d gather tracing with comm_mode: dist.gather(inp, None, 0) - comm_counts = comm_mode.get_comm_counts() - self.assertEqual(comm_counts[c10d_ops.gather_], 1) + self.checksAssert(comm_mode, c10d_ops.gather_, 1, 1) # tests c10d reduce tracing with comm_mode: dist.reduce(inp, 0) - comm_counts = comm_mode.get_comm_counts() - self.assertEqual(comm_counts[c10d_ops.reduce_], 1) + self.checksAssert(comm_mode, c10d_ops.reduce_, 1, 1) # tests c10d scatter tracing with comm_mode: dist.scatter(inp, None, 0) - comm_counts = comm_mode.get_comm_counts() - self.assertEqual(comm_counts[c10d_ops.scatter_], 1) + self.checksAssert(comm_mode, c10d_ops.scatter_, 1, 1) # tests c10d all_gather tracing output_list = [] @@ -166,8 +171,7 @@ def test_comm_mode_with_c10d(self): with comm_mode: dist.all_gather(output_list, inp, None) - comm_counts = comm_mode.get_comm_counts() - self.assertEqual(comm_counts[c10d_ops.allgather_], 1) + self.checksAssert(comm_mode, c10d_ops.allgather_, 1, 1) # tests c10d allgather_coalesced_ tracing output_list = [] @@ -175,39 +179,19 @@ def test_comm_mode_with_c10d(self): with comm_mode: dist.all_gather_coalesced(output_list, [inp], None) - comm_counts = comm_mode.get_comm_counts() - self.assertEqual(comm_counts[c10d_ops.allgather_coalesced_], 1) + self.checksAssert(comm_mode, c10d_ops.allgather_coalesced_, 1, 1) # tests c10d allgather_into_tensor_coalesced_ tracing - comm_mode = CommDebugMode() - with comm_mode as A, dist._coalescing_manager() as B: - # dist.all_reduce_coalesced(inp) + with comm_mode, dist._coalescing_manager(): dist.all_gather_into_tensor(all_gather_out, inp) - comm_counts = comm_mode.get_comm_counts() - self.assertEqual(comm_mode.get_total_counts(), 1) - self.assertEqual(comm_counts[c10d_ops.allgather_into_tensor_coalesced_], 1) + self.checksAssert(comm_mode, c10d_ops.allgather_into_tensor_coalesced_, 1, 1) - @requires_nccl() - def test_comm_mode_with_c10d_allreduce_coalesced(self): - world_pg = self.world_pg - - inp = torch.rand(2, 8, 16).cuda() - all_gather_out = inp.new_empty(self.world_size * 2, 8, 16) - - comm_mode = CommDebugMode() + # tests c10d allreduce_coalesced with comm_mode: dist.all_reduce_coalesced(inp) - dist.all_gather_into_tensor(all_gather_out, inp) - dist.reduce_scatter_tensor(inp, all_gather_out) - dist.broadcast(inp, 0) - comm_counts = comm_mode.get_comm_counts() - self.assertEqual(comm_mode.get_total_counts(), 4) - self.assertEqual(comm_counts[c10d_ops.allreduce_coalesced_], 1) - self.assertEqual(comm_counts[c10d_ops._allgather_base_], 1) - self.assertEqual(comm_counts[c10d_ops._reduce_scatter_base_], 1) - self.assertEqual(comm_counts[c10d_ops.broadcast_], 1) + self.checksAssert(comm_mode, c10d_ops.allreduce_coalesced_, 1, 1) if __name__ == "__main__": From 3947731887a8829608cc2aec672dbb4420ede38f Mon Sep 17 00:00:00 2001 From: laithsakka Date: Wed, 29 May 2024 15:31:39 -0700 Subject: [PATCH 081/706] enable test_parameter_free_dynamic_shapes test when nn module inlining is on (#127424) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127424 Approved by: https://github.com/mlazos ghstack dependencies: #126444, #127146 --- test/dynamo/test_dynamic_shapes.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/test/dynamo/test_dynamic_shapes.py b/test/dynamo/test_dynamic_shapes.py index 4ceed0fad3dd..df86179657ca 100644 --- a/test/dynamo/test_dynamic_shapes.py +++ b/test/dynamo/test_dynamic_shapes.py @@ -85,10 +85,11 @@ def make_dynamic_cls(cls): DynamicShapesReproTests.test_dynamic_shapes_float_guard_dynamic_shapes # noqa: F821 ) - # TODO model is somehow not being freed when z3 is available - unittest.expectedFailure( - DynamicShapesMiscTests.test_parameter_free_dynamic_shapes # noqa: F821 - ) + if not config.inline_inbuilt_nn_modules: + # TODO model is somehow not being freed when z3 is available + unittest.expectedFailure( + DynamicShapesMiscTests.test_parameter_free_dynamic_shapes # noqa: F821 + ) unittest.expectedFailure( # Test is only valid without dynamic shapes From 5d316c81bea0f789183c7e9d5578b7e9776e3757 Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Thu, 30 May 2024 04:50:52 +0000 Subject: [PATCH 082/706] [Inductor] Add 0 initialization to Triton masked loads (#127311) For a masked `tl.load` operation, the Triton language specifies that values masked out (i.e. where the mask evaluates to false) are undefined in the output of the load. Triton provides an optional `other` parameter which, when included, provides an explicit value to use for masked out values from the load. If the output from a masked load without the `other` parameter is used in a conditional, unexpected behavior can occur. Despite the language specification, all Triton backends currently in use by PyTorch Inductor (NVIDIA, AMD, and Intel) 0-initialize masked loads if `other` is not present (we recently changed the Intel backend behavior to match NVIDIA and AMD because that's what our users expect, even if we are not following the Triton spec to the tee). This PR attempts to "future-proof" Inductor for new backends (or perhaps changes in the current backends? - we did not see any performance change from 0-initializing in the Intel XPU backend but one could imagine compiler optimizations to remove paths that depend on undefined) to add an explicit `other` in instances where later conditionals depend on the `tl.load` output. I also removed an exception to `other` behavior for boolean loads, which was put in place for a Triton bug that should be fixed. I added `other` to the getting started documentation as a clue that masked load behavior requires explicit initialization if, even though I don't expect `undef` values to cause the example code to fail if the underlying output is not 0-initialized. Finally, I added other to the `make_load` function in `select_algorithm.py`, though I wasn't able to determine if that function was actually being called. Fixes #126535 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127311 Approved by: https://github.com/jansel --- docs/source/torch.compiler_get_started.rst | 2 +- test/inductor/test_cuda_repro.py | 4 ++-- torch/_inductor/codegen/triton.py | 9 +-------- torch/_inductor/runtime/triton_helpers.py | 2 +- torch/_inductor/select_algorithm.py | 2 +- 5 files changed, 6 insertions(+), 13 deletions(-) diff --git a/docs/source/torch.compiler_get_started.rst b/docs/source/torch.compiler_get_started.rst index 624b351d6fa8..caec0760acc7 100644 --- a/docs/source/torch.compiler_get_started.rst +++ b/docs/source/torch.compiler_get_started.rst @@ -64,7 +64,7 @@ the following: xindex = xoffset + tl.arange(0, XBLOCK)[:] xmask = xindex < xnumel x0 = xindex - tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp0 = tl.load(in_ptr0 + (x0), xmask, other=0.0) tmp1 = tl.cos(tmp0) tmp2 = tl.sin(tmp1) tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp2, xmask) diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py index 6ead970a7884..c1ce2769a658 100644 --- a/test/inductor/test_cuda_repro.py +++ b/test/inductor/test_cuda_repro.py @@ -417,8 +417,8 @@ def kernel(in_out_ptr0, in_ptr0, xnumel, XBLOCK: tl.constexpr): block_start = pid * XBLOCK offsets = block_start + tl.arange(0, XBLOCK) mask = offsets < xnumel - x = tl.load(in_out_ptr0 + offsets, mask=mask) - y = tl.load(in_ptr0 + offsets, mask=mask) + x = tl.load(in_out_ptr0 + offsets, mask=mask, other=0.0) + y = tl.load(in_ptr0 + offsets, mask=mask, other=0.0) output = x + y tl.store(in_out_ptr0 + offsets, output, mask=mask) diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index d239e711db1f..785f79d91503 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -1291,14 +1291,7 @@ def load(self, name: str, index: sympy.Expr): ep = ", eviction_policy='evict_first'" else: ep = "" - # "other" below is a workaround for https://github.com/openai/triton/issues/737 - # for bool, even though it's likely subject to the same bug, setting `other` leads - # to LLVM errors so we are skipping it for now - if ( - (has_tmpmask or has_rindex) - and V.graph.get_dtype(name) != torch.bool - and indexing.has_mask() - ): + if (has_tmpmask or has_rindex) and indexing.has_mask(): other = ", other=0.0" else: other = "" diff --git a/torch/_inductor/runtime/triton_helpers.py b/torch/_inductor/runtime/triton_helpers.py index 71b746bdf49a..9179e5cd676a 100644 --- a/torch/_inductor/runtime/triton_helpers.py +++ b/torch/_inductor/runtime/triton_helpers.py @@ -195,7 +195,7 @@ def bucketize_binary_search( while full_range > 1: mid = (high + low) // 2 mask = mid < OFFSETS_SIZE - bucket_upper_bound = tl.load(offsets_ptr + mid, mask=mask) + bucket_upper_bound = tl.load(offsets_ptr + mid, mask=mask, other=0.0) if right: is_above = values >= bucket_upper_bound else: diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index bd48144d7bfe..124c6aea6125 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -453,7 +453,7 @@ def make_load(self, name, indices, mask): index = " + ".join( f"{texpr(self.rename_indexing(s))} * {i}" for s, i in zip(stride, indices) ) - return f"tl.load({name} + ({index}), {mask})" + return f"tl.load({name} + ({index}), {mask}, other=0.0)" def template_env(self): """ From d44ab8ba6de5821754ce2bc5fd3d97d294468b1a Mon Sep 17 00:00:00 2001 From: William Wen Date: Wed, 29 May 2024 10:56:23 -0700 Subject: [PATCH 083/706] [dynamo] utility to generate bytecode from template function (#127359) This will be helpful in reducing some of the hardcoded and python-version-dependent bytecode generation in various places in dynamo - e.g. resume function generation and object reconstruction. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127359 Approved by: https://github.com/jansel ghstack dependencies: #127329 --- test/dynamo/test_bytecode_utils.py | 115 ++++++++++++++++++++++- torch/_dynamo/bytecode_transformation.py | 114 ++++++++++++++++++++++ torch/_dynamo/testing.py | 6 ++ 3 files changed, 234 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_bytecode_utils.py b/test/dynamo/test_bytecode_utils.py index 3bbf7270b06b..0e813c883785 100644 --- a/test/dynamo/test_bytecode_utils.py +++ b/test/dynamo/test_bytecode_utils.py @@ -8,7 +8,7 @@ import torch import torch._dynamo.test_case from torch._dynamo import bytecode_analysis, bytecode_transformation -from torch._dynamo.testing import skipIfNotPy311 +from torch._dynamo.testing import skipIfNotPy311, skipIfNotPy312 class BytecodeTests(torch._dynamo.test_case.TestCase): @@ -414,6 +414,119 @@ def test_remove_dead_code_with_exn_table_entries(self): self.assertEqual(tab[0].end, 4) self.assertEqual(tab[0].target, 6) + def test_bytecode_from_template(self): + def fn(d1): + for k, v in d1.items(): + d2[k] = v + + varname_map = {"d1": "var1", "d2": "var2", "k": "var3", "v": "var4"} + insts = bytecode_transformation.bytecode_from_template(fn, varname_map) + for inst in insts: + self.assertIsNone(inst.starts_line) + if inst.opname.startswith("LOAD"): + self.assertNotIn(inst.argval, varname_map) + if inst.opname not in ("LOAD_GLOBAL", "LOAD_ATTR"): + self.assertIsNone(inst.arg) + self.assertFalse(inst.opname.startswith("RETURN")) + + @skipIfNotPy311 + def test_bytecode_from_template_noprefix(self): + # Test that 3.11+ prefix instructions are removed + def gen_fn(): + cl = None + + def fn(): + return cl + + return fn + + fn = gen_fn() + + dis_insts = list(dis.get_instructions(fn)) + names = {inst.opname for inst in dis_insts} + self.assertIn("RESUME", names) + self.assertIn("COPY_FREE_VARS", names) + + insts = bytecode_transformation.bytecode_from_template(fn) + names = {inst.opname for inst in insts} + self.assertNotIn("RESUME", names) + self.assertNotIn("COPY_FREE_VARS", names) + + def test_bytecode_from_template_noreturn1(self): + # Test that functions with multiple returns will have their + # returns replaced with jumps to the end + def fn(): + if x: + return y + z = 3 + return z + + dis_insts = list(dis.get_instructions(fn)) + dis_returns = list(filter(lambda x: x.opname.startswith("RETURN"), dis_insts)) + self.assertGreater(len(dis_returns), 1) + self.assertTrue(dis_insts[-1].opname.startswith("RETURN")) + + insts = bytecode_transformation.bytecode_from_template(fn, noprefix=False) + self.assertEqual(insts[-1].opname, "NOP") + self.assertEqual(len(dis_insts), len(insts)) + for i0, i1 in zip(dis_insts, insts): + if i0.opname.startswith("RETURN"): + if i1 is insts[-1]: + continue + self.assertIn("JUMP", i1.opname) + self.assertIs(i1.target, insts[-1]) + + # Should work with 3.10, but testing with 3.11+ is sufficient. + # In 3.8, `fn` ends with a RETURN_VALUE. + @skipIfNotPy311 + def test_bytecode_from_template_noreturn2(self): + # Test function that doesn't end with RETURN_VALUE + def fn(): + if x: + return x + if x: + return x + raise RuntimeError + + dis_insts = list(dis.get_instructions(fn)) + self.assertFalse(dis_insts[-1].opname.startswith("RETURN")) + + insts = bytecode_transformation.bytecode_from_template(fn, noprefix=False) + self.assertEqual(insts[-1].opname, "NOP") + self.assertEqual(insts[-2].opname, dis_insts[-1].opname) + self.assertEqual(len(dis_insts) + 1, len(insts)) + for i0, i1 in zip(dis_insts, insts): + if i0.opname.startswith("RETURN"): + self.assertIn("JUMP", i1.opname) + self.assertIs(i1.target, insts[-1]) + + @skipIfNotPy312 + def test_bytecode_from_template_noreturn_const(self): + # Test 3.12+ RETURN_CONST + def fn(): + if x: + return 1 + return 0 + + dis_insts = list(dis.get_instructions(fn)) + dis_return_consts = list( + filter(lambda x: x.opname == "RETURN_CONST", dis_insts) + ) + self.assertGreater(len(dis_return_consts), 1) + self.assertTrue(dis_insts[-1].opname == "RETURN_CONST") + + insts = bytecode_transformation.bytecode_from_template(fn, noprefix=False) + self.assertEqual(insts[-1].opname, "NOP") + insts_i = 0 + for i, inst in enumerate(dis_insts): + if inst.opname == "RETURN_CONST": + self.assertEqual(insts[insts_i].opname, "LOAD_CONST") + insts_i += 1 + if insts_i != len(insts) - 1: + self.assertIn("JUMP", insts[insts_i].opname) + self.assertIs(insts[insts_i].target, insts[-1]) + insts_i += 1 + class BytecodeHookTests(torch._dynamo.test_case.TestCase): def test_bytecode_hook(self): diff --git a/torch/_dynamo/bytecode_transformation.py b/torch/_dynamo/bytecode_transformation.py index dec673b0e910..f07fe1c7a0e0 100644 --- a/torch/_dynamo/bytecode_transformation.py +++ b/torch/_dynamo/bytecode_transformation.py @@ -1117,6 +1117,23 @@ def should_compute_arg(): instructions[i].arg = idx +def clear_instruction_args(instructions): + # Clear the instruction arg for instructions that have argvals. + # Useful for using dis'd bytecode within generated bytecode. + for inst in instructions: + if ( + inst.argval is not _NotProvided + and ( + inst.opcode in HAS_LOCAL + or inst.opcode in HAS_NAME + or inst.opcode in HAS_FREE + or inst.opcode in HAS_CONST + ) + and inst.opname not in ("LOAD_GLOBAL", "LOAD_ATTR", "LOAD_SUPER_ATTR") + ): + inst.arg = None + + def get_code_keys() -> List[str]: # Python 3.11 changes to code keys are not fully documented. # See https://github.com/python/cpython/blob/3.11/Objects/clinic/codeobject.c.h#L24 @@ -1247,3 +1264,100 @@ def unique_id(name) -> str: def is_generator(code: types.CodeType) -> bool: co_generator = 0x20 return (code.co_flags & co_generator) > 0 + + +def bytecode_from_template(fn, varname_map=None, noreturn=True, noprefix=True): + """Generates bytecode from a template function `fn` for use in + dynamo bytecode generation. + + For example, we can generate Python-version-independent bytecode + for looping through a dictionary and copying the values to a new dictionary. + + def template(d1, d2): + for k, v in d1.items(): + d2[k] = v + + + or a try block: + + def template(): + try: + dummy1 + except: + dummy2 + raise + dummy3 + + Args: + fn: a function template to generate bytecode from + varname_map: a mapping of `fn`'s varnames to new names. This + map will be applied to the generated bytecode's varnames. + For example, local variables in `fn` can be replaced with + new names that are generated by `OutputGraph.new_var`. + noreturn: remove all RETURN_* bytecodes and replace them with a jump + to the end of the bytecode. + noprefix: remove prefix bytecodes (all bytecode before the first RESUME, inclusive). + """ + insts = cleaned_instructions(fn.__code__) + clear_instruction_args(insts) + + if noprefix: + for i, inst in enumerate(insts): + if inst.opname == "RESUME": + insts = insts[i + 1 :] + break + + for inst in insts: + # If we don't reset starts_line, then the generated + # bytecode's line number will be based on fn's. + inst.starts_line = None + if varname_map and inst.argval in varname_map: + inst.argval = varname_map[inst.argval] + + if noreturn: + if sys.version_info >= (3, 12): + # replace RETURN_CONST with LOAD_CONST RETURN_VALUE + new_insts = [] + for inst in insts: + if inst.opname == "RETURN_CONST": + inst.opcode = dis.opmap["LOAD_CONST"] + inst.opname = "LOAD_CONST" + new_insts.append(inst) + # no need to propagate target/exn table + new_insts.append(create_instruction("RETURN_VALUE")) + else: + new_insts.append(inst) + insts = new_insts + + returns = [] + for inst in insts: + if inst.opname == "RETURN_VALUE": + returns.append(inst) + + if len(returns) == 1 and returns[0] is insts[-1]: + # only 1 return at the end - just pop it + insts.pop(-1) + elif len(returns) > 0: + # create jump target - if the last inst is a return, + # we can replace it with a NOP and make that the jump target. + if insts[-1] is returns[-1]: + insts[-1].opname = "NOP" + insts[-1].opcode = dis.opmap["NOP"] + insts[-1].arg = None + insts[-1].argval = _NotProvided + returns.pop(-1) + else: + insts.append(create_instruction("NOP")) + + # replace returns with jumps + for inst in returns: + # don't replace inst with new instruction + # due to targetting/exn table/etc. + jump_inst = create_jump_absolute(insts[-1]) + inst.opname = jump_inst.opname + inst.opcode = jump_inst.opcode + inst.arg = jump_inst.arg + inst.argval = jump_inst.argval + inst.target = jump_inst.target + + return insts diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py index 9e9abe84228b..99b6607afead 100644 --- a/torch/_dynamo/testing.py +++ b/torch/_dynamo/testing.py @@ -343,6 +343,12 @@ def skipIfNotPy311(fn): return unittest.skip(fn) +def skipIfNotPy312(fn): + if sys.version_info >= (3, 12): + return fn + return unittest.skip(fn) + + def xfailIfPy312(fn): if sys.version_info >= (3, 12): return unittest.expectedFailure(fn) From cd06ae0cb80104cc63095040d1c209f2cf57ed9c Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Wed, 29 May 2024 17:14:05 -0700 Subject: [PATCH 084/706] Relax use_count constraints for swap_tensors when AccumulateGrad holds a reference (#127313) ### Before this PR: `torch.utils.swap_tensors(a, b)` required the `use_count` of `a` and `b` to be 1 ```python a = torch.randn(2, 3, requires_grad=True) b = torch.randn(2, 4) out = a * 2 out.sum().backward() # Calling swap_tensors here would fail due to the reference held by AccumulateGrad node, which is not cleaned up after backward # torch.utils.swap_tensors(a, b) del out # Calling swap_tensors here would pass torch.utils.swap_tensors(a, b) ``` ### After this PR: `torch.utils.swap_tensors(a, b)` requires the `use_count` of `a` and `b` to be 1 or 2 IF the second reference is held by `AccumulateGrad` A pre-hook will be registered on the `AccumulateGrad` node so that it will fail if it is called (i.e. if user attempts to backward through the graph). ```python a = torch.randn(2, 3, requires_grad=True) b = torch.randn(2, 4) out = a * 2 out.sum().backward() # Calling swap_tensors here is ok torch.utils.swap_tensors(a, b) # If we ever backward to the AccumulateGrad node it will error that it was poisoned by swap_tensors ``` ### Application to `nn.Module` This issue is especially pertinent in context of `nn.Module` where parameters will have `AccumulateGrad` nodes initialized after forward. Specifically, this is intended to address https://github.com/pytorch/pytorch/pull/126814#issuecomment-2127777866. Previously, this would fail at the `m.cpu()` but we want users to be able to do something like the following, and instead raise an error if the user ever attempts to backward through the poisoned `AccumulateGrad` node ```python import torch import torch.nn as nn m = nn.Linear(3, 5) inp = torch.randn(2, 3) out = m(inp) out.sum().backward() m.cpu() ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/127313 Approved by: https://github.com/soulitzer --- test/test_modules.py | 17 ++++++++++++++-- test/test_nn.py | 22 +++++++++++++++------ test/test_torch.py | 7 ++----- torch/csrc/Module.cpp | 12 ++---------- torch/csrc/autograd/python_variable.cpp | 8 ++++++++ torch/overrides.py | 1 + torch/utils/__init__.py | 26 +++++++++++++++++++++++++ 7 files changed, 70 insertions(+), 23 deletions(-) diff --git a/test/test_modules.py b/test/test_modules.py index ab05e9df4355..e854eec8add7 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -863,7 +863,8 @@ def test_errors(self, device, dtype, module_info, training): else: raise NotImplementedError(f"Unknown error type {error_input.error_on}") - @modules([module for module in module_db if not module.is_lazy]) + # Only run this test for float32 because the test loops over all the dtypes + @modules([module for module in module_db if not module.is_lazy], allowed_dtypes=[torch.float32]) @parametrize('swap', [True, False]) @parametrize('set_grad', [True, False]) @wrapSwapTensorsTest() @@ -879,6 +880,7 @@ def test_to(self, device, dtype, module_info, training, swap, set_grad): for module_input in module_inputs: c_args, c_kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs + args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs m = module_cls(*c_args, **c_kwargs) @@ -896,6 +898,17 @@ def _to(m, set_grad=False): setattr(m, n, new_b) _to(m, set_grad=set_grad) + # Check .to() can be run after forward and backward with swap + has_params = len(list(m.parameters())) > 0 + if swap and not set_grad and has_params: + out = m(*args, **kwargs) + if isinstance(out, tuple): + out = out[0] + out.sum().backward() + m.to(dtype=torch.half) + # reset + m.to(dtype=torch.float32) + prev_device, prev_dtype = device, dtype for device_, dtype_ in product(devices, dtypes): # if device/dtype do not change, grad.to(device, dtype) is a no-op so @@ -903,6 +916,7 @@ def _to(m, set_grad=False): # parameters will be wrapped in an nn.Parameter before swapping # which will cause the ._cdata to change g_no_swap = device_ == prev_device and dtype_ == prev_dtype + prev_prev_device, prev_prev_dtype = prev_device, prev_dtype prev_device, prev_dtype = device_, dtype_ p_ids_before = [id(p) for p in m.parameters()] @@ -940,7 +954,6 @@ def _to(m, set_grad=False): self.assertTrue(all(a == b for a, b in zip(g_cdatas_before, g_cdatas_after))) self.assertTrue(all(a == b for a, b in zip(g_ids_before, g_ids_after))) - @modules([module for module in module_db if not module.is_lazy], allowed_dtypes=[torch.float32]) @parametrize('swap', [True, False]) @wrapSwapTensorsTest() diff --git a/test/test_nn.py b/test/test_nn.py index c6f3a61e972b..6dfac4f7ca1b 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -1594,19 +1594,29 @@ def add_one_inplace(t): finally: torch.__future__.set_overwrite_module_params_on_conversion(False) - def test_swap_module_params_fails_after_forward(self): + def test_swap_module_params_poisons_acc_grad(self): try: torch.__future__.set_swap_module_params_on_conversion(True) + # (1) backward cannot be run after _apply + # forward will init AccumulateGrad nodes, which bumps use_count of parameters' at::Tensors + # additionally, if any Tensors are saved for backward, their use_count will be bumped m = torch.nn.Linear(2, 3) inp = torch.randn(2, 2) - # forward will init AccumulateGrad nodes, which bumps use_count of parameters' at::Tensors out = m(inp) - with self.assertRaisesRegex(RuntimeError, re.escape("_apply(): Couldn't swap Linear.weight")): - m.half() - del out - # works as expected now m.half() self.assertTrue(all(p.dtype == torch.float16 for p in m.parameters())) + with self.assertRaisesRegex(RuntimeError, "Trying to execute AccumulateGrad node that was poisoned by swap_tensors"): + out.sum().backward() + # (2) _apply can be run after backward() + # After running backward, all the references generated by "save for backward" will be cleared + # So the use_count will be 2 (1 from Tensor itself, and 1 from AccumulateGrad node), swap_tensors + # should allow this. + inp2 = torch.randn(2, 2, dtype=torch.half) + out2 = m(inp2) + out2.sum().backward() + m.float() + self.assertTrue(all(p.dtype == torch.float32 for p in m.parameters())) + out3 = m(inp) finally: torch.__future__.set_swap_module_params_on_conversion(False) diff --git a/test/test_torch.py b/test/test_torch.py index 0d8a672b93cf..717943f43646 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -10623,12 +10623,9 @@ def test_swap_basic(self): if t1.is_floating_point(): t3 = t1.clone().detach().requires_grad_(True) out = t3 * 2 - with self.assertRaisesRegex(RuntimeError, "Expected single reference to a's"): - torch.utils.swap_tensors(t3, t2) - del out - # Now succeeds torch.utils.swap_tensors(t3, t2) - torch.utils.swap_tensors(t1, t2) + with self.assertRaisesRegex(RuntimeError, "AccumulateGrad node that was poisoned by swap_tensors"): + out.sum().backward() wr = weakref.ref(t1) with self.assertRaisesRegex(RuntimeError, "has weakref"): diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 9ff9131435f4..5fad0f0a9541 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -375,22 +375,14 @@ PyObject* THPModule_swap_tensor_impl(PyObject* _unused, PyObject* args) { THPVariable* a = reinterpret_cast(a_); THPVariable* b = reinterpret_cast(b_); - TORCH_CHECK( - a->cdata->use_count() == 1, - "Expected single reference to a's Tensor object but got ", - a->cdata->use_count()); - TORCH_CHECK( - b->cdata->use_count() == 1, - "Expected single reference to b's Tensor object but got ", - b->cdata->use_count()); // weak_use_count() adds 1 if use_count is non-zero TORCH_CHECK( a->cdata->weak_use_count() == 1, - "Expected no weakrefs to a's Tensor object but got ", + "Expected no weakrefs to t1's Tensor object but got ", a->cdata->weak_use_count() - 1); TORCH_CHECK( b->cdata->weak_use_count() == 1, - "Expected no weakrefs to b's Tensor object but got ", + "Expected no weakrefs to t2's Tensor object but got ", b->cdata->weak_use_count() - 1); // Swap the Tensor Impl diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index 078b0f92124c..65f4b0efd3c1 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -1615,6 +1615,13 @@ int THPVariable_set_imag(PyObject* self, PyObject* imag, void* unused) { END_HANDLE_TH_ERRORS_RET(-1) } +PyObject* THPVariable__use_count(PyObject* self, PyObject* noargs) { + HANDLE_TH_ERRORS + const auto& t = THPVariable_Unpack(self); + return THPUtils_packUInt64(t.use_count()); + END_HANDLE_TH_ERRORS +} + // properties are registered here because we are currently only able to bind // them manually. TODO: make declarable in native_functions // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables) @@ -1766,6 +1773,7 @@ static PyMethodDef extra_methods[] = { THPVariable_rev_view_func_unsafe, METH_O, nullptr}, + {"_use_count", THPVariable__use_count, METH_NOARGS, nullptr}, {nullptr}}; struct THPVariableMeta { diff --git a/torch/overrides.py b/torch/overrides.py index 6c521bc7003b..509568900983 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -357,6 +357,7 @@ def get_ignored_functions() -> Set[Callable]: Tensor._is_any_true, Tensor._addmm_activation, Tensor.to_padded_tensor, + Tensor._use_count, } diff --git a/torch/utils/__init__.py b/torch/utils/__init__.py index ccdad48eca97..a5ca0329a794 100644 --- a/torch/utils/__init__.py +++ b/torch/utils/__init__.py @@ -46,6 +46,32 @@ def swap_attr(name): setattr(t1, name, (getattr(t2, name))) setattr(t2, name, tmp) + def error_pre_hook(grad_outputs): + raise RuntimeError("Trying to execute AccumulateGrad node that was poisoned by swap_tensors " + "this can happen when you try to run backward on a tensor that was swapped. " + "For a module m with `torch.__future__.set_swap_module_params_on_conversion(True)` " + "you should not change the device or dtype of the module (e.g. `m.cpu()` or `m.half()`) " + "between running forward and backward. To resolve this, please only change the " + "device/dtype before running forward (or after both forward and backward).") + + def check_use_count(t, name='t1'): + use_count = t._use_count() + error_str = (f"Expected use_count of {name} to be 1 or 2 with an AccumulateGrad node but got {use_count} " + f"make sure you are not holding references to the tensor in other places.") + if use_count > 1: + if use_count == 2 and t.is_leaf: + accum_grad_node = torch.autograd.graph.get_gradient_edge(t).node + # Make sure that the accumulate_grad node was not lazy_init-ed by get_gradient_edge + if t._use_count() == 2: + accum_grad_node.register_prehook(error_pre_hook) + else: + raise RuntimeError(error_str) + else: + raise RuntimeError(error_str) + + check_use_count(t1, 't1') + check_use_count(t2, 't2') + # Swap the types # Note that this will fail if there are mismatched slots swap_attr("__class__") From 705346bf8db09175f48b2c8323d29e46c5f7c58b Mon Sep 17 00:00:00 2001 From: titaiwangms Date: Thu, 30 May 2024 07:08:42 +0000 Subject: [PATCH 085/706] [ONNX] Skip optimizer when it fails (#127349) continue #127039 (1) Skip optimizer when it fails (2) Update onnx, ort, and onnx-script (3) The update to onnx-script results in the actual optimizer and rewriter enabling in this PR, and https://github.com/pytorch/pytorch/pull/123379 did not update onnx-script. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127349 Approved by: https://github.com/justinchuby --- .ci/docker/common/install_onnx.sh | 6 +- .../test_dynamo_with_onnxruntime_backend.py | 12 +- test/onnx/pytorch_test_common.py | 34 +++- test/onnx/test_fx_op_consistency.py | 148 +++++++----------- test/onnx/test_fx_passes.py | 14 ++ test/onnx/test_fx_to_onnx.py | 48 ++++-- test/onnx/test_fx_to_onnx_with_onnxruntime.py | 64 +++----- test/onnx/test_pytorch_onnx_onnxruntime.py | 3 + torch/onnx/_internal/exporter.py | 6 + 9 files changed, 177 insertions(+), 158 deletions(-) diff --git a/.ci/docker/common/install_onnx.sh b/.ci/docker/common/install_onnx.sh index a1a5fde7d2f5..a91c798fcdf2 100755 --- a/.ci/docker/common/install_onnx.sh +++ b/.ci/docker/common/install_onnx.sh @@ -30,10 +30,10 @@ pip_install \ pip_install coloredlogs packaging -pip_install onnxruntime==1.17.0 -pip_install onnx==1.15.0 +pip_install onnxruntime==1.18 +pip_install onnx==1.16.0 # pip_install "onnxscript@git+https://github.com/microsoft/onnxscript@3e869ef8ccf19b5ebd21c10d3e9c267c9a9fa729" --no-deps -pip_install onnxscript==0.1.0.dev20240315 --no-deps +pip_install onnxscript==0.1.0.dev20240523 --no-deps # Cache the transformers model to be used later by ONNX tests. We need to run the transformers # package to download the model. By default, the model is cached at ~/.cache/huggingface/hub/ diff --git a/test/onnx/dynamo/test_dynamo_with_onnxruntime_backend.py b/test/onnx/dynamo/test_dynamo_with_onnxruntime_backend.py index 1d5127b18603..951e7cfd7c54 100644 --- a/test/onnx/dynamo/test_dynamo_with_onnxruntime_backend.py +++ b/test/onnx/dynamo/test_dynamo_with_onnxruntime_backend.py @@ -24,7 +24,6 @@ from torch.testing._internal import common_utils sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - import onnx_test_common @@ -782,8 +781,9 @@ def record_onnx_model_transform(onnx_model): result = compiled_model() self.assertEqual(len(recorded_models), 1) + # NOTE: Constant folded by optimizer self.assertTrue( - "aten_add" in [node.op_type for node in recorded_models[0].graph.node] + "Constant" in [node.op_type for node in recorded_models[0].graph.node] ) self.assertEqual(result, torch.ones(4, 8)) @@ -822,11 +822,11 @@ def example_model(x: torch.Tensor): # Part 2: Change the ONNX model seen by the transform so that # ORT receives a different model. + # NOTE: the function is optimized away by optimizer def replace_relu_with_sigmoid(onnx_model): - for function in onnx_model.functions: - for node in function.node: - if node.op_type == "Relu": - node.op_type = "Sigmoid" + for node in onnx_model.graph.node: + if node.op_type == "Relu": + node.op_type = "Sigmoid" def another_example_model(x: torch.Tensor): y = torch.relu(x) diff --git a/test/onnx/pytorch_test_common.py b/test/onnx/pytorch_test_common.py index b9b5e9859bab..6fdbf4e92839 100644 --- a/test/onnx/pytorch_test_common.py +++ b/test/onnx/pytorch_test_common.py @@ -205,10 +205,10 @@ def xfail_dynamic_fx_test( Args: reason: The reason for xfailing dynamic exporting test. model_type (TorchModelType): The model type to xfail dynamic exporting test for. - When None, model type is not used to skip dynamic tests. + When None, model type is not used to xfail dynamic tests. Returns: - A decorator for skipping dynamic exporting test. + A decorator for xfailing dynamic exporting test. """ def skip_dec(func): @@ -225,6 +225,36 @@ def wrapper(self, *args, **kwargs): return skip_dec +def xfail_op_level_debug_test( + error_message: str, + model_type: Optional[TorchModelType] = None, + reason: Optional[str] = None, +): + """Xfail op level debug test. + + Args: + reason: The reason for xfailing op level debug test. + model_type (TorchModelType): The model type to xfail dynamic exporting test for. + When None, model type is not used to xfail op level debug tests. + + Returns: + A decorator for xfailing op level debug test. + """ + + def skip_dec(func): + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + if self.op_level_debug and ( + not model_type or self.model_type == model_type + ): + return xfail(error_message, reason)(func)(self, *args, **kwargs) + return func(self, *args, **kwargs) + + return wrapper + + return skip_dec + + def skip_dynamic_fx_test(reason: str, model_type: TorchModelType = None): """Skip dynamic exporting test. diff --git a/test/onnx/test_fx_op_consistency.py b/test/onnx/test_fx_op_consistency.py index 4c71aafa473e..760ede50bd6d 100644 --- a/test/onnx/test_fx_op_consistency.py +++ b/test/onnx/test_fx_op_consistency.py @@ -147,6 +147,7 @@ def skip_torchlib_forward_compatibility( ), xfail( "__rmatmul__", + dtypes=(torch.float16,), reason="fixme: Assertion error: result mismatch", ), xfail( @@ -217,11 +218,6 @@ def skip_torchlib_forward_compatibility( dtypes=onnx_test_common.COMPLEX_TYPES, reason=onnx_test_common.reason_dynamo_does_not_support("Addr", "complex64") ), - xfail( - "all", - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "Op (ReduceXXX) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0") - ), xfail( "allclose", reason=onnx_test_common.reason_dynamo_does_not_support("Allclose") @@ -240,11 +236,6 @@ def skip_torchlib_forward_compatibility( dtypes=(torch.int16, *onnx_test_common.BOOL_TYPES), reason=onnx_test_common.reason_onnx_does_not_support("ReduceMin", "bool, int16"), ), - xfail( - "any", - reason=onnx_test_common.reason_onnx_does_not_support( - "Op (ReduceXXX) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0") - ), xfail( "arange", dtypes=(torch.uint8,), @@ -346,10 +337,6 @@ def skip_torchlib_forward_compatibility( "chalf", reason="fixme: ONNX shape type inference error: Invalid tensor data type 0." ), - xfail( - "chunk", dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Chunk", "bool") - ), xfail( "chunk", dtypes=(torch.uint8, torch.int8, torch.int16,), @@ -424,6 +411,16 @@ def skip_torchlib_forward_compatibility( "cross", reason=onnx_test_common.reason_onnx_script_does_not_support("linalg_cross"), ), + xfail( + "diag", + dtypes=onnx_test_common.BOOL_TYPES, + reason=onnx_test_common.reason_onnx_runtime_does_not_support("Diagonal", "bool"), + ), + xfail( + "diagonal_copy", + dtypes=onnx_test_common.BOOL_TYPES, + reason=onnx_test_common.reason_onnx_runtime_does_not_support("Diagonal", "bool"), + ), xfail( "dot", dtypes=(torch.uint8, torch.int8, torch.int16,), reason=onnx_test_common.reason_onnx_does_not_support("MatMul", "uint8, int8, int16") @@ -523,6 +520,11 @@ def skip_torchlib_forward_compatibility( dtypes=onnx_test_common.COMPLEX_TYPES, reason=onnx_test_common.reason_dynamo_does_not_support("full_like", "complex64") ), + xfail( + "gather", + reason="HandleNegativeAxis(int64_t, int64_t) IsAxisInRange(axis, tensor_rank) was \ + false. axis 0 is not in valid range [-0,-1]" + ), xfail( "geometric", reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), @@ -532,6 +534,11 @@ def skip_torchlib_forward_compatibility( dtypes=onnx_test_common.BOOL_TYPES, reason=onnx_test_common.reason_onnx_script_does_not_support("Heaviside", "bool"), ), + xfail( + "index_add", + dtypes=(torch.float16,), + reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterND", "int64, int32, bool"), + ), xfail( "index_fill", dtypes=onnx_test_common.COMPLEX_TYPES, @@ -539,7 +546,7 @@ def skip_torchlib_forward_compatibility( ), xfail( "index_put", - dtypes=onnx_test_common.BOOL_TYPES, + dtypes=onnx_test_common.BOOL_TYPES + (torch.float16,), reason=onnx_test_common.reason_onnx_script_does_not_support("index_put", "bool"), ), xfail( @@ -547,6 +554,11 @@ def skip_torchlib_forward_compatibility( dtypes=(torch.uint8, torch.int8, torch.int16,), reason=onnx_test_common.reason_onnx_script_does_not_support("Add", "int8, int16"), ), + xfail( + "index_put", + dtypes=(torch.float16,), + reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterND", "float16"), + ), xfail( "isnan", dtypes=onnx_test_common.INT_TYPES + onnx_test_common.BOOL_TYPES, @@ -624,11 +636,6 @@ def skip_torchlib_forward_compatibility( dtypes=(torch.float16,), reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438", ), - xfail( - "logcumsumexp", - reason=onnx_test_common.reason_onnx_does_not_support( - "Op (ReduceXXX) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0") - ), xfail( "logical_and", dtypes=onnx_test_common.FLOAT_TYPES + onnx_test_common.INT_TYPES, @@ -649,12 +656,7 @@ def skip_torchlib_forward_compatibility( dtypes=onnx_test_common.FLOAT_TYPES + onnx_test_common.INT_TYPES, reason=onnx_test_common.reason_onnx_script_does_not_support("Xor", "float, int"), ), - xfail( - "logsumexp", - dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("ReduceLogSumExp", "bool, int"), - ), - xfail( + skip( "masked.logsumexp", reason="fixme: https://github.com/onnx/onnx/issues/4986", ), @@ -724,12 +726,9 @@ def skip_torchlib_forward_compatibility( xfail( "max", variant_name="reduction_with_dim", + dtypes=(torch.int64,), reason="https://github.com/onnx/onnx/issues/4986", ), - xfail( - "mean", - reason="(ReduceMean) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0", - ), xfail( "min", variant_name="reduction_no_dim", @@ -864,6 +863,11 @@ def skip_torchlib_forward_compatibility( dtypes=onnx_test_common.COMPLEX_TYPES, reason="fixme: Assertion error: result mismatch", ), + xfail( + "nn.functional.cosine_embedding_loss", + dtypes=onnx_test_common.BOOL_TYPES, + reason=onnx_test_common.reason_onnx_runtime_does_not_support("CosineEmbeddingLoss", "bool"), + ), xfail( "nn.functional.ctc_loss", reason=onnx_test_common.reason_dynamo_does_not_support("aten.ctc_loss.default"), @@ -954,6 +958,16 @@ def skip_torchlib_forward_compatibility( variant_name="reflect", reason="fixme: Assertion error: result mismatch", ), + xfail( + "nn.functional.pixel_shuffle", + dtypes=(torch.int32, torch.int64) + onnx_test_common.BOOL_TYPES, + reason="fixme: ONNX Runtime does not support int32/64 inputs", + ), + xfail( + "nn.functional.poisson_nll_loss", + dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, + reason="fixme: result mismatch with NaN.", + ), xfail( "nn.functional.rrelu", reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), @@ -1131,31 +1145,11 @@ def skip_torchlib_forward_compatibility( dtypes=(torch.float16,), reason="fixme: Assertion error: result mismatch", ), - xfail( - "split", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Split, SplitToSequence", "bool"), - ), - xfail( - "split", - variant_name="list_args", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Split, SplitToSequence", "bool"), - ), - xfail( - "split_with_sizes", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Split, SplitToSequence", "bool"), - ), xfail( "square", dtypes=(torch.int8, torch.uint8, torch.int16), reason=onnx_test_common.reason_onnx_runtime_does_not_support("Pow", "int8, uint8, int16"), ), - xfail( - "squeeze", - reason="fixme: Assertion error: result mismatch", - ), xfail( "squeeze", variant_name="multiple", @@ -1213,11 +1207,6 @@ def skip_torchlib_forward_compatibility( dtypes=onnx_test_common.INT_TYPES, reason=onnx_test_common.reason_onnx_does_not_support("Floor", "int"), ), - xfail( - "unbind", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Split, SplitToSequence", "bool"), - ), xfail( "unflatten", dtypes=onnx_test_common.BOOL_TYPES, @@ -1240,16 +1229,6 @@ def skip_torchlib_forward_compatibility( dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, reason=onnx_test_common.reason_onnx_script_does_not_support("Floor", "bool, int"), ), - xfail( - "unsafe_split", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Split, SplitToSequence", "bool"), - ), - xfail( - "unsafe_chunk", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Split, SplitToSequence", "bool"), - ), xfail( "where", dtypes=onnx_test_common.BOOL_TYPES, @@ -1415,8 +1394,10 @@ def skip_torchlib_forward_compatibility( ), xfail( "index_add", - matcher=lambda sample: len(sample.input.shape) < 2, - reason="fixme: https://github.com/microsoft/onnxscript/issues/1212", + matcher=lambda sample: len(sample.input.shape) == 0, + reason=onnx_test_common.reason_onnx_runtime_does_not_support( + "ScatterND", "0-D tensor" + ), ), xfail( "index_add", @@ -1425,8 +1406,10 @@ def skip_torchlib_forward_compatibility( ), xfail( "index_copy", - matcher=lambda sample: len(sample.input.shape) < 2, - reason="fixme: https://github.com/microsoft/onnxscript/issues/1212", + matcher=lambda sample: len(sample.input.shape) == 0, + reason=onnx_test_common.reason_onnx_runtime_does_not_support( + "ScatterND", "0-D tensor" + ), ), xfail( "index_copy", @@ -1457,12 +1440,6 @@ def skip_torchlib_forward_compatibility( matcher=lambda sample: len(sample.input.shape) == 0, reason="fixme: LogSoftMax does not support empty tensor as input", ), - xfail( - "logsumexp", - matcher=lambda sample: isinstance(sample.input, torch.Tensor) - and len(sample.input.shape) == 0, - reason="fixme: IsScalar", - ), skip( "masked.log_softmax", matcher=lambda sample: len(sample.input.shape) == 0, @@ -1473,12 +1450,6 @@ def skip_torchlib_forward_compatibility( matcher=lambda sample: torch.numel(sample.input) == 0, reason="values of matmul of [m, 0] and [0, n] matrices are undefined", ), - xfail( - "min", - variant_name="reduction_with_dim", - matcher=lambda sample: len(sample.input.shape) == 0, - reason="fixme: https://github.com/onnx/onnx/issues/4986", - ), skip( "mm", matcher=lambda sample: torch.numel(sample.input) == 0, @@ -1570,8 +1541,7 @@ def skip_torchlib_forward_compatibility( xfail( "nn.functional.instance_norm", model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - matcher=lambda sample: sample.kwargs.get("running_mean") is not None - or sample.input.dtype in (torch.float16,), + matcher=lambda sample: sample.kwargs.get("running_mean") is not None, reason="fixme: KeyError: 'self___kwargs__running_mean'", ), xfail( @@ -1580,6 +1550,11 @@ def skip_torchlib_forward_compatibility( and sample.kwargs.get("padding") == 1, reason="FIXME: After https://github.com/microsoft/onnxruntime/issues/15446 is fixed", ), + xfail( + "nn.functional.pixel_shuffle", + matcher=lambda sample: sample.input.numel() == 0, + reason="fixme: ORT does not support empty tensor as input", + ), xfail( "nonzero", matcher=lambda sample: len(sample.input.shape) == 0 @@ -1625,12 +1600,6 @@ def skip_torchlib_forward_compatibility( matcher=lambda sample: len(sample.input.shape) == 0, reason="fixme: LogSoftMax does not support empty tensor as input", ), - xfail( - "t", - matcher=lambda sample: isinstance(sample.input, torch.Tensor) - and len(sample.input.shape) < 2, - reason="fixme: IsScalar", - ), xfail( "unflatten", reason="Logic not implemented for size 0 inputs in op.Reshape", @@ -2016,6 +1985,7 @@ class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime): "nn.functional.poisson_nll_loss": [3e-2, 1e-3], "nn.functional.nll_loss": [3e-2, 1e-3], "native_batch_norm": [3e-2, 1e-3], + "norm": [1e-2, 1e-2], "dot": [3e-2, 1e-3], "logit": [3e-2, 1e-3], "rsub": [3e-2, 1e-3], diff --git a/test/onnx/test_fx_passes.py b/test/onnx/test_fx_passes.py index 9ebbf11646dc..8389f3912075 100644 --- a/test/onnx/test_fx_passes.py +++ b/test/onnx/test_fx_passes.py @@ -1,4 +1,6 @@ # Owner(s): ["module: onnx"] +import pytorch_test_common + import torch import torch._dynamo import torch.fx @@ -96,6 +98,10 @@ def func(x, y, z): @common_utils.instantiate_parametrized_tests class TestModularizePass(common_utils.TestCase): + @pytorch_test_common.xfail( + error_message="'torch_nn_modules_activation_GELU_used_gelu_1' not found", + reason="optimizer", + ) @common_utils.parametrize( "is_exported_program", [ @@ -146,6 +152,10 @@ def forward(self, x, y): ) self.assertFalse(any("ReLU" in name for name in function_proto_names)) + @pytorch_test_common.xfail( + error_message="'torch_nn_modules_activation_ReLU_relu_1' not found", + reason="optimizer", + ) @common_utils.parametrize( "is_exported_program", [ @@ -187,6 +197,10 @@ def forward(self, x, y): self.assertIn("torch_nn_modules_activation_ReLU_relu_1", function_proto_names) self.assertIn("torch_nn_modules_activation_ReLU_relu_2", function_proto_names) + @pytorch_test_common.xfail( + error_message="'torch_nn_modules_activation_ReLU_inner_module_relu_1' not found", + reason="optimizer", + ) @common_utils.parametrize( "is_exported_program", [ diff --git a/test/onnx/test_fx_to_onnx.py b/test/onnx/test_fx_to_onnx.py index b660b0525dba..6369ff3872d4 100644 --- a/test/onnx/test_fx_to_onnx.py +++ b/test/onnx/test_fx_to_onnx.py @@ -171,9 +171,13 @@ def forward(self, input): torch.argmax(input, dim=1, keepdim=True), ) - _ = dynamo_export( - ArgminArgmaxModel(), model_input, export_options=self.export_options - ) + # NOTE: KeyError: dim raised in optimizer + with self.assertWarnsOnceRegex( + UserWarning, "ONNXScript optimizer failed. Skipping optimization." + ): + _ = dynamo_export( + ArgminArgmaxModel(), model_input, export_options=self.export_options + ) def test_multiple_outputs_op_with_evaluator(self): class TopKModel(torch.nn.Module): @@ -182,7 +186,8 @@ def forward(self, x): return torch.sum(values) x = torch.arange(1.0, 6.0, requires_grad=True) - onnx_program = dynamo_export(TopKModel(), x, export_options=self.export_options) + + _ = dynamo_export(TopKModel(), x, export_options=self.export_options) def test_unsupported_indices_fake_tensor_generated_with_op_level_debug(self): class EmbedModelWithoutPaddingIdx(torch.nn.Module): @@ -364,11 +369,13 @@ def _assert_node_outputs_has_value_info( node: onnx.NodeProto, value_infos: Mapping[str, onnx.ValueInfoProto], local_functions: Mapping[Tuple[str, str], onnx.FunctionProto], + exclude_names_in_value_info, function_id: str = "", ): for output in node.output: name = f"{function_id}/{output}" if function_id else output - self.assertIn(name, value_infos) + if name not in exclude_names_in_value_info: + self.assertIn(name, value_infos) if node.domain.startswith("pkg.onnxscript.torch_lib"): # No shape info available for values inside torchlib functions. return @@ -378,13 +385,25 @@ def _assert_node_outputs_has_value_info( for node in function.node: function_id = f"{function.domain}::{function.name}" _assert_node_outputs_has_value_info( - node, value_infos, local_functions, function_id + node, + value_infos, + local_functions, + exclude_names_in_value_info, + function_id, ) type_infos = {vi.name: vi for vi in model_proto.graph.value_info} functions = {(f.domain, f.name): f for f in model_proto.functions} + # NOTE: inputs, outputs, and initializers are not included in value_info spec + exclude_names_in_value_info = ( + [input.name for input in model_proto.graph.input] + + [output.name for output in model_proto.graph.output] + + [init.name for init in model_proto.graph.initializer] + ) for node in model_proto.graph.node: - _assert_node_outputs_has_value_info(node, type_infos, functions) + _assert_node_outputs_has_value_info( + node, type_infos, functions, exclude_names_in_value_info + ) def test_dynamo_export_retains_readable_parameter_and_buffer_names(self): class SubModule(torch.nn.Module): @@ -424,10 +443,11 @@ def forward(self, tensor_x: torch.Tensor): model = MNISTModel() onnx_program = torch.onnx.dynamo_export(model, tensor_x) model_proto = onnx_program.model_proto - self.assertEqual( - {initializer.name for initializer in model_proto.graph.initializer}, - {*model.state_dict().keys()}, - ) + + # NOTE: initializers could be optimized away by onnx optimizer + onnx_initilizers = {init.name for init in model_proto.graph.initializer} + torch_weights = {*model.state_dict().keys()} + self.assertTrue(onnx_initilizers.issubset(torch_weights)) @common_utils.parametrize( "checkpoint_type", @@ -708,7 +728,11 @@ def forward(self, input: torch.Tensor): input = input.to(float8_type) return input + torch.tensor(1.0, dtype=float8_type) - _ = torch.onnx.dynamo_export(Float8Module(), torch.randn(1, 2, 3, 4)) + # NOTE: shape inference error raised in optimizer due to unsupported dtype + with self.assertWarnsOnceRegex( + UserWarning, "ONNXScript optimizer failed. Skipping optimization." + ): + _ = torch.onnx.dynamo_export(Float8Module(), torch.randn(1, 2, 3, 4)) def test_export_with_logging_logger(self): logger = logging.getLogger(__name__) diff --git a/test/onnx/test_fx_to_onnx_with_onnxruntime.py b/test/onnx/test_fx_to_onnx_with_onnxruntime.py index 149b9dc987bb..5345e0219c14 100644 --- a/test/onnx/test_fx_to_onnx_with_onnxruntime.py +++ b/test/onnx/test_fx_to_onnx_with_onnxruntime.py @@ -577,9 +577,6 @@ def forward(self, x): x = torch.randn(1, 1, 1, 32, device=torch.device("cuda")) self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(func, (x,)) - # NOTE:The test was meant to test the empty bounding box case, but it is not - # supported. When we have vision model examples, we will have a better test case - # to demonstrate in FX and FX exporter. def test_view_dynamic_zero_dim(self): class ViewModel(torch.nn.Module): def forward(self, input): @@ -587,12 +584,11 @@ def forward(self, input): return input.view(1, -1) x = torch.ones(2) - # y = torch.empty(0) + y = torch.empty(0) self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime( ViewModel(), (x,), - # additional_test_inputs=[((y,),)], # TODO: Without `additional_test_inputs` arg, dynamic shape cannot be verified - skip_dynamic_shapes_check=True, # Has static shape for dynamic_shapes=True due to 0/1 specialization + additional_test_inputs=[((y,),)], ) def test_flatten_dynamic_axes(self): @@ -666,6 +662,11 @@ def forward(self, x): @pytorch_test_common.xfail_if_model_type_is_exportedprogram( error_message="Trying to flatten user inputs with exported input tree spec" ) + @pytorch_test_common.xfail_dynamic_fx_test( + error_message="!(it.GetName().empty())", + reason="With after onnx==1.16, constant folding in optimizer causes this error.", + model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE, + ) def test_gpt2_tiny_from_config(self): # Model config = transformers.GPT2Config( @@ -1145,6 +1146,11 @@ def create_kwargs(): reason="Dynamic shape check is not expected for exported program in this test suite.", model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, ) + @pytorch_test_common.xfail_dynamic_fx_test( + error_message="!(it.GetName().empty())", + reason="With after onnx==1.16, constant folding in optimizer causes this error.", + model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE, + ) @pytorch_test_common.xfail_if_model_type_is_not_exportedprogram( error_message="Expected 4 inputs, got 2", reason="https://github.com/pytorch/pytorch/issues/115745", @@ -1259,13 +1265,14 @@ def create_model(): model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, ) @pytorch_test_common.xfail_dynamic_fx_test( - error_message="NOT_IMPLEMENTED : Could not find an implementation for Trilu(14) node", - reason="Need to check Trilu node in the ONNX graph", + error_message="scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool", + reason="Dynamo error: scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool", model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE, ) - @pytorch_test_common.xfail_if_model_type_is_not_exportedprogram( - error_message="NOT_IMPLEMENTED : Could not find an implementation for Trilu(14) node", - reason="Need to check Trilu node in the ONNX graph", + @pytorch_test_common.xfail_op_level_debug_test( + error_message="Could not find an implementation for Trilu(14) node", + reason="ORT error during op level dubug", + model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE, ) @pytorch_test_common.xfail_if_model_type_is_exportedprogram( error_message="aot_autograd expected to have an entirely functional graph", @@ -1477,41 +1484,6 @@ def create_kwargs(): model_type=self.model_type, ) - @pytorch_test_common.skip_dynamic_fx_test( - reason="Dynamic shape check is not expected for exported program in this test suite.", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - ) - @pytorch_test_common.xfail_if_model_type_is_not_exportedprogram( - error_message="Expected 4 inputs, got 2", - reason="https://github.com/pytorch/pytorch/issues/115745", - ) - def test_fake_tensor_mode_huggingface_tiny_gpt2_torch_load(self): - model_name = "sshleifer/tiny-gpt2" - device = "cpu" - - def create_model(): - return transformers.AutoModel.from_pretrained(model_name).to(device).eval() - - def create_args(): - tokenizer = transformers.AutoTokenizer.from_pretrained(model_name) - kwargs = tokenizer("Hello world!", return_tensors="pt") - input_ids = kwargs["input_ids"] - attention_mask = kwargs["attention_mask"] - return input_ids, None, attention_mask - - def create_pytorch_only_extra_kwargs(): - return {"return_dict": False} - - self._test_fake_tensor_mode_exporter( - "huggingface_sshleifer_tiny-gpt2", - create_model, - create_args, - create_pytorch_only_extra_kwargs, - load_checkpoint_during_init=self.load_checkpoint_during_init, - export_within_fake_mode=self.export_within_fake_mode, - model_type=self.model_type, - ) - if __name__ == "__main__": common_utils.run_tests() diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index 816bcfc3b8df..e49d5d3bceeb 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -2585,6 +2585,9 @@ def forward(self, x, update): update = torch.randn(4, 1, 3, 2) self.run_test(IndexPutModel2(), (x, update)) + @unittest.skip( + "regression in 1.18: https://github.com/microsoft/onnxruntime/issues/20855" + ) @skipIfUnsupportedMinOpsetVersion(11) def test_index_put_loop(self): @torch.jit.script diff --git a/torch/onnx/_internal/exporter.py b/torch/onnx/_internal/exporter.py index 7eefc5a917b0..cf9f1cd747e5 100644 --- a/torch/onnx/_internal/exporter.py +++ b/torch/onnx/_internal/exporter.py @@ -1276,6 +1276,12 @@ def export(self) -> ONNXProgram: "ONNXScript optimizer is not available. Skipping optimization. " "Please `pip install onnxscript -U` to enable post-export optimization." ) + except Exception as e: + warnings.warn( + "ONNXScript optimizer failed. Skipping optimization. " + "\n\nPLEASE REPORT A BUG AT https://github.com/microsoft/onnxscript/issues " + f"\n\nDetail:\n{e}" + ) return torch.onnx.ONNXProgram( onnx_model, From 1071437169b1f64806250457af7350c1e80f27ee Mon Sep 17 00:00:00 2001 From: Yifu Wang Date: Wed, 29 May 2024 16:00:34 -0700 Subject: [PATCH 086/706] Introduce cuda_p2p based fused_all_gather_matmul and fused_matmul_reduce_scatter (#126634) Pull Request resolved: https://github.com/pytorch/pytorch/pull/126634 Approved by: https://github.com/Chillee, https://github.com/wanchaol --- test/distributed/test_cuda_p2p.py | 79 +++++ torch/distributed/_cuda_p2p/__init__.py | 365 +++++++++++++++++++++++- 2 files changed, 443 insertions(+), 1 deletion(-) diff --git a/test/distributed/test_cuda_p2p.py b/test/distributed/test_cuda_p2p.py index 14ff4bd3d0eb..1e743896bc7b 100644 --- a/test/distributed/test_cuda_p2p.py +++ b/test/distributed/test_cuda_p2p.py @@ -6,9 +6,12 @@ import torch.distributed as dist from torch.distributed._cuda_p2p import ( + _fused_all_gather_matmul_fallback, + _fused_matmul_reduce_scatter_fallback, get_cuda_p2p_backend, get_p2p_buffer_size, is_cuda_p2p_group, + p2p_usage_counter, ) from torch.testing._internal.common_distributed import ( MultiProcessTestCase, @@ -16,6 +19,8 @@ skip_if_lt_x_gpu, ) from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, run_tests, skip_but_pass_in_sandcastle_if, skipIfRocm, @@ -43,6 +48,7 @@ def requires_cuda_p2p_access(): ) +@instantiate_parametrized_tests @requires_nccl() @requires_cuda_p2p_access() class ProcessGroupCudaP2PTest(MultiProcessTestCase): @@ -137,6 +143,79 @@ def test_p2p_buffer(self) -> None: torch.cuda.synchronize() dist.destroy_process_group() + @skipIfRocm + @skip_if_lt_x_gpu(2) + @parametrize("gather_dim", [0, 1]) + def test_fused_all_gather_matmul(self, gather_dim: int) -> None: + B = 8 + M = 64 + N = 16 + K = 32 + BUFFER_SIZE = B * M * K // self.world_size * 4 + + self._init_process_group(BUFFER_SIZE) + group = dist.group.WORLD + rank = self.rank + world_size = self.world_size + + torch.manual_seed(42 + rank) + A_shard = torch.rand(B, M // self.world_size, K, device="cuda") + Bs = [torch.rand(K, N, device="cuda") for _ in range(3)] + + ag_output_0, mm_outputs_0 = _fused_all_gather_matmul_fallback( + A_shard, Bs, gather_dim=gather_dim, group_name=group.group_name + ) + with p2p_usage_counter() as counter: + ag_output_1, mm_outputs_1 = torch.ops.cuda_p2p.fused_all_gather_matmul( + A_shard, Bs, gather_dim=gather_dim, group_name=group.group_name + ) + assert counter["fused_all_gather_matmul"] == 1 + + assert torch.allclose(ag_output_0, ag_output_1) + assert ag_output_0.stride() == ag_output_1.stride() + for mm_output_0, mm_output_1 in zip(mm_outputs_0, mm_outputs_1): + assert torch.allclose(mm_output_0, mm_output_1) + assert mm_output_0.stride(), mm_output_1.stride() + + dist.barrier() + torch.cuda.synchronize() + dist.destroy_process_group() + + @skipIfRocm + @skip_if_lt_x_gpu(2) + @parametrize("scatter_dim", [0, 1]) + def test_fused_matmul_reduce_scatter(self, scatter_dim: int) -> None: + B = 8 + M = 64 + N = 16 + K = 32 + BUFFER_SIZE = B * M * N // self.world_size * 4 * 2 + + self._init_process_group(BUFFER_SIZE) + group = dist.group.WORLD + rank = self.rank + world_size = self.world_size + + torch.manual_seed(42 + rank) + A = torch.rand(B, M, K, device="cuda") + B = torch.rand(K, N, device="cuda") + + output_0 = _fused_matmul_reduce_scatter_fallback( + A, B, "avg", scatter_dim=scatter_dim, group_name=group.group_name + ) + with p2p_usage_counter() as counter: + output_1 = torch.ops.cuda_p2p.fused_matmul_reduce_scatter( + A, B, "avg", scatter_dim=scatter_dim, group_name=group.group_name + ) + assert counter["fused_matmul_reduce_scatter"] == 1 + + assert torch.allclose(output_0, output_1) + assert output_0.stride() == output_1.stride() + + dist.barrier() + torch.cuda.synchronize() + dist.destroy_process_group() + if __name__ == "__main__": run_tests() diff --git a/torch/distributed/_cuda_p2p/__init__.py b/torch/distributed/_cuda_p2p/__init__.py index f91d1f29d98b..4d07bfcbf067 100644 --- a/torch/distributed/_cuda_p2p/__init__.py +++ b/torch/distributed/_cuda_p2p/__init__.py @@ -1,7 +1,8 @@ +from collections import defaultdict from contextlib import contextmanager from functools import partial -from typing import Callable, cast, List, Tuple, Union +from typing import Callable, cast, Dict, List, Optional, Tuple, Union import torch import torch.distributed._functional_collectives as funcol @@ -121,3 +122,365 @@ def get_p2p_buffer_size(group: c10d.ProcessGroup) -> int: extended_api=True, devices=["cuda"], ) + + +_test_with_non_cuda_p2p_group: bool = False + + +@contextmanager +def test_with_non_cuda_p2p_group(): + """ + Force ops in this file to work with non-cuda_p2p groups for testing + purposes. Not thread safe. + """ + global _test_with_non_cuda_p2p_group + prev = _test_with_non_cuda_p2p_group + try: + _test_with_non_cuda_p2p_group = True + yield + finally: + _test_with_non_cuda_p2p_group = prev + + +_current_p2p_usage_counter: Optional[Dict[str, int]] = None + + +@contextmanager +def p2p_usage_counter(): + """ + Record the number of ops that utilized p2p capability for testing purposes. + Fallbacks are excluded. + """ + global _current_p2p_usage_counter + prev = _current_p2p_usage_counter + try: + _current_p2p_usage_counter = defaultdict(int) + yield _current_p2p_usage_counter + finally: + _current_p2p_usage_counter = prev + + +def _pipelined_all_gather_and_consume( + shard: torch.Tensor, + shard_consumer: Callable[[torch.Tensor, int], None], + ag_out: torch.Tensor, + group: c10d.ProcessGroup, +) -> None: + """ + Perform the following logic with micro-pipelined computation and + communication: + + tensor = all_gather_tensor(shard, gather_dim=1, group=group) + chunks = tensor.chunk(group.size()) + for src_rank, chunk in enumerate(chunks): + shard_consumer(chunk, src_rank) + + NOTE: + - The shard passed to shard consumer will always be contiguous. + """ + p2p_buf_sz_req = shard.numel() * shard.element_size() + if get_p2p_buffer_size(group) < p2p_buf_sz_req: + # We preferred the caller to handle fallback so that the computation + # doesn't need to be decomposed. + raise RuntimeError( + f"_pipelined_all_gather_and_consume on input with shape={shard.shape} " + f"and dtype={shard.dtype} requires {p2p_buf_sz_req} bytes of p2p buffers " + f"(got {get_p2p_buffer_size(group)} bytes)." + ) + + backend = get_cuda_p2p_backend(group) + group_size = group.size() + rank = group.rank() + + backend.stream().wait_stream(torch.cuda.current_stream()) + local_p2p_buf = backend.get_p2p_buffer(rank, shard.shape, shard.dtype) + + chunks = ag_out.chunk(group.size()) + + # While consuming local shard, copy it to the local p2p buffer + # in another stream. + shard_consumer(shard, rank) + chunks[rank].copy_(shard) + + with torch.cuda.stream(backend.stream()): + local_p2p_buf.copy_(shard) + work = backend.intra_node_barrier() + work.wait() + + # At this point, all ranks have copied their local shard to + # their local p2p buffer. Each rank can now copy and consume + # remote shards. + for i in range(1, group_size): + if i % 2 == 0: + stream = torch.cuda.current_stream() + else: + stream = backend.stream() + remote_rank = (i + rank) % group_size + remote_p2p_buf = backend.get_p2p_buffer(remote_rank, shard.shape, shard.dtype) + with torch.cuda.stream(stream): + chunks[remote_rank].copy_(remote_p2p_buf) + shard_consumer(chunks[remote_rank], remote_rank) + + torch.cuda.current_stream().wait_stream(backend.stream()) + + with torch.cuda.stream(backend.stream()): + work = backend.intra_node_barrier() + work.wait() + + +def _pipelined_produce_and_all2all( + chunk_producer: Callable[[int, torch.Tensor], None], + output: torch.Tensor, + group: c10d.ProcessGroup, +) -> None: + """ + Perform the following logic with micro-pipelined computation and + communication: + + chunks = [ + chunk_producer(dst_rank, chunks[dst_rank]) + for dst_rank in range(group.size()): + ] + dist.all_to_all_single(output=output, input=torch.cat(chunks)) + """ + group_size = group.size() + rank = group.rank() + + out_chunks = output.chunk(group_size) + p2p_buf_sz_req = out_chunks[0].numel() * out_chunks[0].element_size() * 2 + if get_p2p_buffer_size(group) < p2p_buf_sz_req: + # We preferred the caller to handle fallback so that the computation + # doesn't need to be decomposed. + raise RuntimeError( + f"_pipelined_produce_and_all2all on output with shape={output.shape} " + f"and dtype={output.dtype} requires {p2p_buf_sz_req} bytes of p2p buffers " + f"(got {get_p2p_buffer_size(group)} bytes)." + ) + + backend = get_cuda_p2p_backend(group) + backend.stream().wait_stream(torch.cuda.current_stream()) + + def get_p2p_buf(rank: int, idx: int) -> torch.Tensor: + assert idx in (0, 1) + offset = 0 if idx == 0 else out_chunks[0].numel() + return backend.get_p2p_buffer( + rank, out_chunks[0].shape, out_chunks[0].dtype, offset + ) + + # Prepare two local p2p buffers, so that a remote rank can pull the result + # of step [i] in one p2p buffer while the local rank can compute the + # result of step [i+1] and write it directly the other p2p buffer. + local_p2p_buf_0 = get_p2p_buf(rank, 0) + local_p2p_buf_1 = get_p2p_buf(rank, 1) + + # Directly write the local result to the destination. + # No need to go through the p2p buffers. + chunk_producer(rank, out_chunks[rank]) + + with torch.cuda.stream(backend.stream()): + chunk_producer((rank + 1) % group_size, local_p2p_buf_0) + backend.intra_node_barrier() + remote_p2p_buf = get_p2p_buf((rank - 1) % group_size, 0) + out_chunks[(rank - 1) % group_size].copy_(remote_p2p_buf) + + for step in range(2, group_size): + remote_rank = (rank - step) % group_size + if step % 2 == 0: + stream = torch.cuda.current_stream() + p2p_buf = local_p2p_buf_1 + remote_p2p_buf = get_p2p_buf(remote_rank, 1) + else: + stream = backend.stream() + p2p_buf = local_p2p_buf_0 + remote_p2p_buf = get_p2p_buf(remote_rank, 0) + with torch.cuda.stream(stream): + chunk_producer((rank + step) % group_size, p2p_buf) + backend.intra_node_barrier() + out_chunks[remote_rank].copy_(remote_p2p_buf) + + torch.cuda.current_stream().wait_stream(backend.stream()) + backend.intra_node_barrier() + + +lib = torch.library.Library("cuda_p2p", "DEF") # noqa: TOR901 +lib.define( + "fused_all_gather_matmul(Tensor A, Tensor[] Bs, int gather_dim, str group_name) -> (Tensor, Tensor[])" +) +lib.define( + "fused_matmul_reduce_scatter(Tensor A, Tensor B, str reduce_op, int scatter_dim, str group_name) -> Tensor" +) + + +@torch.library.impl(lib, "fused_all_gather_matmul", "Meta") +def _fused_all_gather_matmul_fallback( + A_shard: torch.Tensor, + Bs: List[torch.Tensor], + gather_dim: int, + group_name: str, +) -> Tuple[torch.Tensor, List[torch.Tensor]]: + group_size = c10d._get_group_size_by_name(group_name) + A = torch.ops._c10d_functional.all_gather_into_tensor( + A_shard.contiguous(), group_size, group_name + ) + A = torch.ops._c10d_functional.wait_tensor(A) + A = A.view(group_size, *A_shard.shape).movedim(gather_dim + 1, 1).flatten(0, 1) + return A.movedim(0, gather_dim), [ + torch.matmul(A, B).movedim(0, gather_dim) for B in Bs + ] + + +@torch.library.impl(lib, "fused_all_gather_matmul", "CUDA") +def _fused_all_gather_matmul( + A_shard: torch.Tensor, + Bs: List[torch.Tensor], + gather_dim: int, + group_name: str, +) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """ + Perform the following logic with micro-pipelined computation and + communication: + + all_gather_tensor(A_shard, gather_dim, group_name) @ B + """ + if A_shard.dim() < 2: + raise ValueError("A_shard must be a matrix") + for B in Bs: + if B.dim() != 2: + raise ValueError("B must be a matrix") + if gather_dim < 0 or gather_dim >= A_shard.dim(): + raise ValueError("Invalid gather_dim") + + group = c10d._resolve_process_group(group_name) + p2p_buf_sz_req = A_shard.numel() * A_shard.element_size() + if ( + _test_with_non_cuda_p2p_group + or get_p2p_buffer_size(group) < p2p_buf_sz_req + # Pipelining a mamtul with split-k is not supported + or gather_dim == len(A_shard.shape) - 1 + ): + return _fused_all_gather_matmul_fallback(A_shard, Bs, gather_dim, group_name) + + if _current_p2p_usage_counter is not None: + _current_p2p_usage_counter["fused_all_gather_matmul"] += 1 + + # Move the gather_dim to the front and flatten the tensor into a 2D matrix. + # The flattened tensor doesn't need to be contiguous (for computation + # efficiency), as _pipelined_all_gather_and_consume guarantees that shards + # passed to shard_consumer are contiguous. + x = A_shard.movedim(gather_dim, 0) + leading_dims = [group.size()] + list(x.shape[:-1]) + x = x.flatten(0, -2) + + # Helper function for reverting the above transformation + def unflatten(t): + return t.view(*leading_dims, -1).flatten(0, 1).movedim(0, gather_dim) + + ag_out = x.new_empty( + x.shape[0] * group.size(), + x.shape[1], + ) + outputs = [ + x.new_empty( + x.shape[0] * group.size(), + B.shape[1], + ) + for B in Bs + ] + output_shards = [output.chunk(group.size()) for output in outputs] + + # Computing block-wise matmul along the first dim of A + def shard_consumer(shard: torch.Tensor, rank: int) -> None: + for idx, B in enumerate(Bs): + torch.mm(shard, B, out=output_shards[idx][rank]) + + _pipelined_all_gather_and_consume( + x, + shard_consumer, + ag_out, + group, + ) + return unflatten(ag_out), [unflatten(output) for output in outputs] + + +@torch.library.impl(lib, "fused_matmul_reduce_scatter", "Meta") +def _fused_matmul_reduce_scatter_fallback( + A: torch.Tensor, + B: torch.Tensor, + reduce_op: str, + scatter_dim: int, + group_name: str, +) -> torch.Tensor: + res = funcol.reduce_scatter_tensor(A @ B, reduce_op, scatter_dim, group_name) + res = funcol.wait_tensor(res) + return res + + +@torch.library.impl(lib, "fused_matmul_reduce_scatter", "CUDA") +def _fused_matmul_reduce_scatter( + A: torch.Tensor, + B: torch.Tensor, + reduce_op: str, + scatter_dim: int, + group_name: str, +) -> torch.Tensor: + """ + Perform the following logic with micro-pipelined computation and + communication: + + reduce_scatter_tensor(A @ B, reduce_op, scatter_dim, group_name) + + NOTE: + - The K dim across ranks are currently accumulated with bf16 with results + in accuracy loss. + """ + if A.dim() < 2: + raise ValueError("A_shard must be a matrix") + if scatter_dim < 0 or scatter_dim >= A.dim(): + raise ValueError("Invalid gather_dim") + if B.dim() != 2: + raise ValueError("B must be a matrix") + if reduce_op == "sum": + reduce_fn = partial(torch.sum, dim=0) + elif reduce_op == "avg": + reduce_fn = partial(torch.mean, dim=0) + else: + raise ValueError("reduce_op must be sum or avg") + + group = c10d._resolve_process_group(group_name) + out_shape = [*A.shape[:-1], B.shape[1]] + out_shape[scatter_dim] //= group.size() + p2p_buf_sz_req = torch.Size(out_shape).numel() * A.element_size() * 2 + if _test_with_non_cuda_p2p_group or get_p2p_buffer_size(group) < p2p_buf_sz_req: + return _fused_matmul_reduce_scatter_fallback( + A, B, reduce_op, scatter_dim, group_name + ) + + if _current_p2p_usage_counter is not None: + _current_p2p_usage_counter["fused_matmul_reduce_scatter"] += 1 + + # Move the gather_dim to the front and flatten the tensor into a 2D matrix + x = A.movedim(scatter_dim, 0) + leading_dims = [group.size()] + list(x.shape[:-1]) + leading_dims[1] //= group.size() + x = x.flatten(0, -2) + shards = x.chunk(group.size()) + + # Computing block-wise matmul along the first dim of A + def chunk_producer(rank: int, out: torch.Tensor) -> None: + torch.matmul(shards[rank], B, out=out) + + stacked_partials = x.new_empty(x.shape[0], B.shape[1]) + + _pipelined_produce_and_all2all( + chunk_producer, + stacked_partials, + group, + ) + # Ensures that the transpose and reduction produce contiguous result + # in a single reduction kernel. + return reduce_fn( + stacked_partials.view(*leading_dims, -1) + .movedim(1, scatter_dim + 1) + .movedim(0, scatter_dim), + dim=scatter_dim, + ) From 30d98611a3a35287c47ded9647f0b4c81fbdf036 Mon Sep 17 00:00:00 2001 From: chuanqiw Date: Thu, 30 May 2024 12:10:12 +0000 Subject: [PATCH 087/706] [CI] add xpu test in periodic workflow (#126410) Works for https://github.com/pytorch/pytorch/issues/114850 Pull Request resolved: https://github.com/pytorch/pytorch/pull/126410 Approved by: https://github.com/EikanWang, https://github.com/atalman --- .github/workflows/periodic.yml | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/.github/workflows/periodic.yml b/.github/workflows/periodic.yml index 716a72cc6d23..cf2da41c1c44 100644 --- a/.github/workflows/periodic.yml +++ b/.github/workflows/periodic.yml @@ -242,3 +242,30 @@ jobs: build-environment: linux-focal-rocm6.1-py3.8 docker-image: ${{ needs.linux-focal-rocm6_1-py3_8-build.outputs.docker-image }} test-matrix: ${{ needs.linux-focal-rocm6_1-py3_8-build.outputs.test-matrix }} + + linux-jammy-xpu-py3_8-build: + name: linux-jammy-xpu-py3.8 + uses: ./.github/workflows/_linux-build.yml + with: + build-environment: linux-jammy-xpu-py3.8 + docker-image-name: pytorch-linux-jammy-xpu-2024.0-py3 + runner: linux.2xlarge + test-matrix: | + { include: [ + { config: "default", shard: 1, num_shards: 4, runner: "linux.idc.xpu" }, + { config: "default", shard: 2, num_shards: 4, runner: "linux.idc.xpu" }, + { config: "default", shard: 3, num_shards: 4, runner: "linux.idc.xpu" }, + { config: "default", shard: 4, num_shards: 4, runner: "linux.idc.xpu" }, + ]} + + linux-jammy-xpu-py3_8-test: + name: linux-jammy-xpu-py3.8 + uses: ./.github/workflows/_xpu-test.yml + needs: linux-jammy-xpu-py3_8-build + permissions: + id-token: write + contents: read + with: + build-environment: linux-jammy-xpu-py3.8 + docker-image: ${{ needs.linux-jammy-xpu-py3_8-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-xpu-py3_8-build.outputs.test-matrix }} From 9f73c65b8f644d599ff3ff53927b738cfbb7d191 Mon Sep 17 00:00:00 2001 From: Dmitry Rogozhkin Date: Thu, 30 May 2024 12:10:31 +0000 Subject: [PATCH 088/706] xpu: pass MAX_JOBS building xpu_mkldnn_proj (#126562) mkldnn is quite big project and MAX_JOBS support is essential when building on a system with big number of cpus and limited memory. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126562 Approved by: https://github.com/jgong5, https://github.com/guangyey, https://github.com/albanD --- cmake/Modules/FindMKLDNN.cmake | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/cmake/Modules/FindMKLDNN.cmake b/cmake/Modules/FindMKLDNN.cmake index f6a19812c83d..2e33654cc484 100644 --- a/cmake/Modules/FindMKLDNN.cmake +++ b/cmake/Modules/FindMKLDNN.cmake @@ -28,6 +28,14 @@ IF(NOT MKLDNN_FOUND) endif() set(DNNL_MAKE_COMMAND "cmake" "--build" ".") + include(ProcessorCount) + ProcessorCount(proc_cnt) + if ((DEFINED ENV{MAX_JOBS}) AND ("$ENV{MAX_JOBS}" LESS_EQUAL ${proc_cnt})) + list(APPEND DNNL_MAKE_COMMAND "-j" "$ENV{MAX_JOBS}") + if(CMAKE_GENERATOR MATCHES "Make|Ninja") + list(APPEND DNNL_MAKE_COMMAND "--" "-l" "$ENV{MAX_JOBS}") + endif() + endif() ExternalProject_Add(xpu_mkldnn_proj SOURCE_DIR ${MKLDNN_ROOT} PREFIX ${XPU_MKLDNN_DIR_PREFIX} From cdeb242fc977210e211fd77b217320205c9f4042 Mon Sep 17 00:00:00 2001 From: "haozhe.zhu" Date: Wed, 29 May 2024 05:29:15 -0700 Subject: [PATCH 089/706] [inductor] fix mkldnn linear binary fusion check ut (#127296) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In this PR: (1)Fix the unary fusion for bf16 conv/linear. Previously we registered same fusion pattern for `bf16. fp16`. And we do not check the dtype while matching the pattern. This results the `fp16` case matched the `bf16` pattern but in later replacement, we found that we have a float16 here which is not expected, so we do not fuse them. We fix it by checking dtypes to avoid `fp16` case matched `bf16` pattern. ``` def _is_valid_computation_unary_fusion(computation_op, lowp_dtype=None): def fn(match): matched = _is_single_computation_op(computation_op, **lowp_dtype**)(match) # previously we do not check lowp_dtype here ``` It is not exposed before because we only check the match count, and the match count is anyway correct because we matched the pattern. To address this, we add check on number of `generated_kernel`. If it is not fused, there will be an additional kernel to compute the post op. (2)Previous the ut ``` python test/inductor/test_mkldnn_pattern_matcher.py -k test_linear_binary ``` dose not check the fusion status, fix it in this PR. (3)Extend `test_conv_binary` to test with lp. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127296 Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5, https://github.com/jansel --- test/inductor/test_mkldnn_pattern_matcher.py | 72 +++++++++++++++++--- torch/_inductor/fx_passes/mkldnn_fusion.py | 14 ++-- 2 files changed, 73 insertions(+), 13 deletions(-) diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py index 756de35df84c..92a9dd59a5dc 100644 --- a/test/inductor/test_mkldnn_pattern_matcher.py +++ b/test/inductor/test_mkldnn_pattern_matcher.py @@ -10,7 +10,7 @@ from torch._dynamo import config as dynamo_config from torch._dynamo.utils import counters from torch._export import capture_pre_autograd_graph -from torch._inductor import config +from torch._inductor import config, metrics from torch._inductor.test_case import run_tests, TestCase from torch._inductor.utils import run_and_get_code from torch.ao.quantization.quantize_pt2e import ( @@ -264,6 +264,7 @@ def forward(self, x): memory_format, dtype, ) in options: + metrics.reset() if dim == 4: x_shape = (1, 3, 56, 56) else: @@ -284,6 +285,18 @@ def forward(self, x): # Has extra dtype conversion nodes for autocast. match_nodes += 2 self._test_common(mod, (v,), 2, match_nodes, check_autocast=dtype) + generated_kernel_count = 0 + if dtype != torch.float32: + # "to_dtype" for input + generated_kernel_count = 1 + if memory_format == torch.contiguous_format: + # "to_dtype + to_channel_last" for input, "to_contiguous" for output + generated_kernel_count = 2 + if memory_format == torch.channels_last_3d: + # for float conv3d, the output for eager is channel last, we will generate "to_contiguous" for output + # for lp conv3d, the output for eager is channel last too, we will only generate "to_dtype" + generated_kernel_count = 1 + self.assertEqual(metrics.generated_kernel_count, generated_kernel_count) def test_conv2d_unary_cpu(self): self._test_conv_unary_cpu_base(dim=4) @@ -321,6 +334,7 @@ def forward(self, x): dtypes.append(torch.float16) options = itertools.product(unary_list, [True, False], dtypes) for unary_fn, bias, dtype in options: + metrics.reset() mod = M(unary_fn, 10, 30, bias=bias).eval() # only fuse for linear when the dtype is bf16 mod = mod @@ -335,6 +349,8 @@ def forward(self, x): self._test_common( mod, (v,), matcher_count, matcher_nodes, check_autocast=dtype ) + # only generated 1 kernel for "to" + self.assertEqual(metrics.generated_kernel_count, 1) @unittest.skipIf(not TEST_MKL, "Test requires MKL") def test_linear_fp32(self): @@ -386,6 +402,7 @@ def forward(self, x): ) for unary_fn, memory_format, dtype in options: + metrics.reset() x_shape = (1, 3, 28, 28) mod = M(unary_fn).eval() @@ -401,6 +418,14 @@ def forward(self, x): # Has extra dtype conversion nodes for autocast. match_nodes += 2 self._test_common(mod, (v,), 2, match_nodes, check_autocast=dtype) + generated_kernel_count = 0 + if dtype != torch.float32: + # "to" for input + generated_kernel_count = 1 + if memory_format == torch.contiguous_format: + # "to_dtype + to_channel_last" for input, "to_contiguous" for output + generated_kernel_count = 2 + self.assertEqual(metrics.generated_kernel_count, generated_kernel_count) def _test_conv_binary_base(self, dim=4): assert dim == 4 or dim == 5 @@ -430,19 +455,29 @@ def forward(self, x): else: return self.binary_fn(x1, x2) + dtypes = [ + torch.float, + ] + if torch.ops.mkldnn._is_mkldnn_bf16_supported(): + dtypes.append(torch.bfloat16) + if torch.ops.mkldnn._is_mkldnn_fp16_supported(): + dtypes.append(torch.float16) cl_format = torch.channels_last if dim == 4 else torch.channels_last_3d test_memory_format = [torch.contiguous_format, cl_format] options = itertools.product( binary_list, [True, False], test_memory_format, + dtypes, ) for ( binary_fn, has_relu, memory_format, + dtype, ) in options: + metrics.reset() if dim == 4: x_shape = (1, 3, 56, 56) else: @@ -457,7 +492,19 @@ def forward(self, x): match_nodes = binary_list[binary_fn][1] if has_relu: match_nodes += 1 - self._test_common(mod, (v,), match_count, match_nodes + 2) + self._test_common( + mod, (v,), match_count, match_nodes + 2, check_autocast=dtype + ) + generated_kernel_count = 0 + if dtype != torch.float32: + # "to_dtype" for input + generated_kernel_count = 1 + if memory_format == torch.contiguous_format: + # "to_dtype + to_channel_last" for input, "to_contiguous" for output + generated_kernel_count = 2 + elif memory_format == torch.channels_last_3d: + generated_kernel_count = 1 + self.assertEqual(metrics.generated_kernel_count, generated_kernel_count) def test_conv2d_binary(self): self._test_conv_binary_base(dim=4) @@ -489,7 +536,7 @@ def forward(self, x, y): ) out_feature = 30 for binary_fn, input_shape, bias, dtype in options: - torch._dynamo.reset() + metrics.reset() # addmm(mm) + (linear+add) match_count = 2 match_nodes = 3 @@ -498,13 +545,20 @@ def forward(self, x, y): # view + linear + view(joint_graph+freeze pass) match_count = match_count + 5 if is_inplace else match_count + 3 match_nodes = match_nodes + 7 if is_inplace else match_nodes + 5 - mod = M(binary_fn, input_shape[-1], out_feature, bias).to(dtype).eval() - v = torch.randn(input_shape).to(dtype) + mod = M(binary_fn, input_shape[-1], out_feature, bias).eval() + v = torch.randn(input_shape) other = torch.randn(input_shape[:-1] + [out_feature]).to(dtype) - mod_c = torch.compile(mod) - out, code = run_and_get_code(mod_c, v, other) - self.assertEqual(out, mod(v, other), rtol=1e-2, atol=1e-2) - # TODO - assert fusions work code + self._test_common( + mod, + ( + v, + other, + ), + match_count, + match_nodes, + check_autocast=dtype, + ) + self.assertEqual(metrics.generated_kernel_count, 1) def test_multi_linear_share_same_input(self): # llama pattern. diff --git a/torch/_inductor/fx_passes/mkldnn_fusion.py b/torch/_inductor/fx_passes/mkldnn_fusion.py index 3edb4a397932..5d1a723fa58a 100644 --- a/torch/_inductor/fx_passes/mkldnn_fusion.py +++ b/torch/_inductor/fx_passes/mkldnn_fusion.py @@ -197,9 +197,15 @@ def _binary_fusion_v1(computation_call, binary_fn): def _binary_fusion_v2(computation_call, binary_fn): return CallFunction(binary_fn, computation_call, KeywordArg("other")) - def _is_single_computation_op(computation_op): + def _is_single_computation_op(computation_op, lowp_dtype=None): def fn(match): computation_nodes = filter_nodes(match.nodes, computation_op) + + if lowp_dtype: + output_node_meta = match.output_node().meta.get("val") + if output_node_meta.dtype != lowp_dtype: + return False + if len(computation_nodes) < 1: return False if any(n.args[-3] != "none" for n in computation_nodes): @@ -210,7 +216,7 @@ def fn(match): def _is_valid_computation_unary_fusion(computation_op, lowp_dtype=None): def fn(match): - matched = _is_single_computation_op(computation_op)(match) + matched = _is_single_computation_op(computation_op, lowp_dtype)(match) computation_node = filter_nodes(match.nodes, computation_op)[0] if lowp_dtype: conversion_dtype_nodes = filter_nodes( @@ -249,7 +255,7 @@ def fn(match, *args, **kwargs): def _register_leaky_relu_fusion_lowering(pattern, computation_op, lowp_dtype=None): @register_lowering_pattern( - pattern, extra_check=_is_single_computation_op(computation_op) + pattern, extra_check=_is_single_computation_op(computation_op, lowp_dtype) ) def fn(match, *args, **kwargs): negative_slope = kwargs.get("negative_slope") @@ -291,7 +297,7 @@ def fn(match, *args, **kwargs): def _register_hardtanh_fusion_lowering(pattern, computation_op, lowp_dtype=None): @register_lowering_pattern( - pattern, extra_check=_is_single_computation_op(computation_op) + pattern, extra_check=_is_single_computation_op(computation_op, lowp_dtype) ) def fn(match, *args, **kwargs): min_value = kwargs.get("min_value") From 3f5d8636aaa34c7a78213bf32cc30a946ff57b46 Mon Sep 17 00:00:00 2001 From: Sam Larsen Date: Wed, 29 May 2024 18:24:22 -0700 Subject: [PATCH 090/706] [inductor] Copy RedisRemoteCacheBackend into pytorch (#127480) Summary: We need an implementation of RedisRemoteCacheBackend with the same API that we're using for FbMemcacheRemoteFxGraphCacheBackend. So we'll stop using the Triton implementation and adapt a version for use by inductor. I also renamed parameters and cache entries to match our cache terminology. Test Plan: Ran this command twice and inspected log output to ensure I got cache hits: ``` TORCH_LOGS=+torch._inductor.codecache TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE=1 python benchmarks/dynamo/torchbench.py --performance --inductor --device cuda --training --amp --print-compilation-time --only dcgan ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/127480 Approved by: https://github.com/oulgen --- mypy.ini | 7 +++ test/inductor/test_codecache.py | 2 +- test/inductor/test_max_autotune.py | 2 +- torch/_inductor/codecache.py | 14 +++--- torch/_inductor/remote_cache.py | 46 ++++++++++++++++++++ torch/_inductor/runtime/triton_heuristics.py | 4 +- 6 files changed, 65 insertions(+), 10 deletions(-) create mode 100644 torch/_inductor/remote_cache.py diff --git a/mypy.ini b/mypy.ini index 48bd363ef6d1..7d51847da44b 100644 --- a/mypy.ini +++ b/mypy.ini @@ -294,3 +294,10 @@ ignore_missing_imports = True [mypy-torch_xla.*] ignore_missing_imports = True + +# +# Third party dependencies that are optional. +# + +[mypy-redis] +ignore_missing_imports = True \ No newline at end of file diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py index af12454df3c0..a96d9aa67e82 100644 --- a/test/inductor/test_codecache.py +++ b/test/inductor/test_codecache.py @@ -197,7 +197,7 @@ def put(self, filename, data): cache_module = ( "triton.runtime.fb_memcache.FbMemcacheRemoteFxGraphCacheBackend" if config.is_fbcode() - else "triton.runtime.cache.RedisRemoteCacheBackend" + else "torch._inductor.remote_cache.RedisRemoteCacheBackend" ) with config.patch( diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index cef7d610ee4d..eed927a98644 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -269,7 +269,7 @@ def put(self, filename, data): cache_module = ( "triton.runtime.fb_memcache.FbMemcacheRemoteAutotuneCacheBackend" if config.is_fbcode() - else "triton.runtime.cache.RedisRemoteCacheBackend" + else "torch._inductor.remote_cache.RedisRemoteCacheBackend" ) with config.patch( diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 62c7252db493..c47f01751482 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -976,16 +976,16 @@ def load( if remote: cache_id = "fx-graph-v1" try: - import triton - if config.is_fbcode(): - remote_cache = triton.runtime.fb_memcache.FbMemcacheRemoteFxGraphCacheBackend( - cache_id + from triton.runtime.fb_memcache import ( + FbMemcacheRemoteFxGraphCacheBackend, ) + + remote_cache = FbMemcacheRemoteFxGraphCacheBackend(cache_id) else: - remote_cache = triton.runtime.cache.RedisRemoteCacheBackend( - cache_id - ) + from torch._inductor.remote_cache import RedisRemoteCacheBackend + + remote_cache = RedisRemoteCacheBackend(cache_id) except Exception: remote_cache = None log.warning("Unable to create a remote cache", exc_info=True) diff --git a/torch/_inductor/remote_cache.py b/torch/_inductor/remote_cache.py new file mode 100644 index 000000000000..7c40f603c4d9 --- /dev/null +++ b/torch/_inductor/remote_cache.py @@ -0,0 +1,46 @@ +import os +from abc import abstractmethod + + +class RemoteCacheBackend: + """ + A backend implementation for accessing a remote/distributed cache. + """ + + def __init__(self, cache_id: str): + pass + + @abstractmethod + def get(self, key: str): + pass + + @abstractmethod + def put(self, key: str, data: bytes): + pass + + +class RedisRemoteCacheBackend(RemoteCacheBackend): + """ + A Redis implementation of a remote/distributed cache. + """ + + def __init__(self, cache_id: str): + import redis + + self._cache_id = cache_id + self._key_fmt = os.environ.get( + "TORCHINDUCTOR_REDIS_KEY_FORMAT", "pt2:{cache_id}:{key}" + ) + self._redis = redis.Redis( + host=os.environ.get("TRITON_REDIS_HOST", "localhost"), + port=int(os.environ.get("TRITON_REDIS_PORT", 6379)), + ) + + def _get_key(self, key: str) -> str: + return self._key_fmt.format(cache_id=self._cache_id, key=key) + + def get(self, key: str): + return self._redis.get(self._get_key(key)) + + def put(self, key: str, data: bytes): + return self._redis.set(self._get_key(key), data) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 5a27f7a08cdc..75584a60c0ff 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -1078,7 +1078,9 @@ def cached_autotune( key ) else: - remote_cache = triton.runtime.cache.RedisRemoteCacheBackend(key) + from torch._inductor.remote_cache import RedisRemoteCacheBackend + + remote_cache = RedisRemoteCacheBackend(key) except Exception: remote_cache = None log.warning("Unable to create a remote cache", exc_info=True) From 6e0eeecc7cd4dc389683e35d1f2e34738e09e597 Mon Sep 17 00:00:00 2001 From: albanD Date: Thu, 30 May 2024 13:29:23 +0000 Subject: [PATCH 091/706] Add back private function torch.cuda.amp.autocast_mode._cast (#127433) This is unfortunately used in a few places in the wild: https://github.com/search?q=torch.cuda.amp.autocast_mode._cast&type=code Pull Request resolved: https://github.com/pytorch/pytorch/pull/127433 Approved by: https://github.com/zou3519, https://github.com/guangyey --- torch/cuda/amp/autocast_mode.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/torch/cuda/amp/autocast_mode.py b/torch/cuda/amp/autocast_mode.py index e50206c70577..09a44f50c90b 100644 --- a/torch/cuda/amp/autocast_mode.py +++ b/torch/cuda/amp/autocast_mode.py @@ -50,6 +50,11 @@ def __call__(self, func): return super().__call__(func) +# Preserved only for BC reasons +def _cast(value, dtype): + return torch.amp.autocast_mode._cast(value, "cuda", dtype) + + @deprecated( "`torch.cuda.amp.custom_fwd(args...)` is deprecated. " "Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.", From ce63b676f365b291e7cb966e3299a8121bf99b3c Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 30 May 2024 13:53:31 +0000 Subject: [PATCH 092/706] Revert "[compiled autograd] torch.compile API (#125880)" This reverts commit e1c322112a3d7b128b42e27f68bc9a714bfd9a09. Reverted https://github.com/pytorch/pytorch/pull/125880 on behalf of https://github.com/atalman due to sorry your PR broke lint, need to revert ([comment](https://github.com/pytorch/pytorch/pull/125880#issuecomment-2139605376)) --- test/inductor/test_compiled_autograd.py | 170 +----------------------- torch/_dynamo/config.py | 4 - torch/_dynamo/eval_frame.py | 32 +---- 3 files changed, 3 insertions(+), 203 deletions(-) diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index 2daacc308071..87299d796f6c 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -11,7 +11,7 @@ import torch import torch.nn as nn from torch import _inductor as inductor -from torch._dynamo import compiled_autograd, config +from torch._dynamo import compiled_autograd from torch._dynamo.utils import counters from torch._inductor.test_case import run_tests, TestCase from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA @@ -54,14 +54,10 @@ def hook3(gI, gO): class TestCompiledAutograd(TestCase): def setUp(self) -> None: super().setUp() - torch._logging.set_logs(compiled_autograd_verbose=False) - config.compiled_autograd = False compiled_autograd.reset() def tearDown(self) -> None: super().tearDown() - torch._logging.set_logs(compiled_autograd_verbose=False) - config.compiled_autograd = False compiled_autograd.reset() def check_output_and_recompiles( @@ -234,170 +230,6 @@ def fn(): self.check_output_and_recompiles(fn) - def test_torch_compile_api_inductor(self): - def fn(): - torch.manual_seed(123) - model = torch.nn.Sequential( - torch.nn.Linear(4, 4), - torch.nn.Sigmoid(), - ) - - res = [] - for _ in range(3): - x = torch.randn([1, 4]) - - result = model(x).sum() - result.backward() - res.append(model[0].weight.grad) - res.append(model[0].bias.grad) - model.zero_grad() - return res - - expected = fn() - with config.patch(compiled_autograd=True): - compiled_fn = torch.compile(fn) - actual = compiled_fn() - self.assertEqual(expected, actual) - self.assertEqual(counters["compiled_autograd"]["captures"], 1) - - def test_torch_compile_api_aot_eager(self): - def fn(): - torch.manual_seed(123) - model = torch.nn.Sequential( - torch.nn.Linear(4, 4), - torch.nn.Sigmoid(), - ) - - res = [] - for _ in range(3): - x = torch.randn([1, 4]) - - result = model(x).sum() - result.backward() - res.append(model[0].weight.grad) - res.append(model[0].bias.grad) - model.zero_grad() - return res - - expected = fn() - with config.patch(compiled_autograd=True): - compiled_fn = torch.compile(fn, backend="aot_eager") - actual = compiled_fn() - self.assertEqual(expected, actual) - self.assertEqual(counters["compiled_autograd"]["captures"], 1) - - def test_torch_compile_api_eager(self): - def fn(): - torch.manual_seed(123) - model = torch.nn.Sequential( - torch.nn.Linear(4, 4), - torch.nn.Sigmoid(), - ) - - res = [] - for _ in range(3): - x = torch.randn([1, 4]) - - result = model(x).sum() - result.backward() - res.append(model[0].weight.grad) - res.append(model[0].bias.grad) - model.zero_grad() - return res - - expected = fn() - with config.patch(compiled_autograd=True): - compiled_fn = torch.compile(fn, backend="eager") - actual = compiled_fn() - self.assertEqual(expected, actual) - self.assertEqual(counters["compiled_autograd"]["captures"], 1) - - def test_multiple_torch_compile(self): - model = torch.nn.Sequential( - torch.nn.Linear(4, 4), - torch.nn.Sigmoid(), - ) - x = torch.randn([1, 4]) - - def fn(): - result = model(x).sum() - result.backward() - - model2 = torch.nn.Linear(4, 4) - x2 = torch.randn([1, 4]) - - def fn2(): - result = model2(x2).sum() - result.backward() - - no_ca1 = torch.compile(fn) - no_ca1() - self.assertEqual(counters["compiled_autograd"]["captures"], 0) - counters.clear() - - with config.patch(compiled_autograd=True): - with_ca = torch.compile(fn2) - with_ca() - self.assertEqual(counters["compiled_autograd"]["captures"], 1) - counters.clear() - - no_ca2 = torch.compile(fn) - no_ca2() - self.assertEqual(counters["compiled_autograd"]["captures"], 0) - - def test_torch_compile_graph_break(self): - model = torch.nn.Sequential( - torch.nn.Linear(4, 4), - torch.nn.Sigmoid(), - ) - x = torch.randn([1, 4]) - - @torch._dynamo.disable() - def fn(): - result = model(x).sum() - result.backward() - - with config.patch(compiled_autograd=True): - opt_fn = torch.compile(fn) - opt_fn() - - self.assertEqual(counters["compiled_autograd"]["captures"], 1) - - def test_torch_compile_graph_break2(self): - model = torch.nn.Sequential( - torch.nn.Linear(4, 4), - torch.nn.Sigmoid(), - ) - x = torch.randn([1, 4]) - - @torch._dynamo.disable() - def inner_fn(loss): - loss.backward() - - def fn(): - result = model(x).sum() - inner_fn(result) - - with config.patch(compiled_autograd=True): - opt_fn = torch.compile(fn) - opt_fn() - - self.assertEqual(counters["compiled_autograd"]["captures"], 1) - - def test_torch_compile_only_backward_call(self): - model = torch.nn.Sequential( - torch.nn.Linear(4, 4), - torch.nn.Sigmoid(), - ) - x = torch.randn([1, 4]) - - result = model(x).sum() - with config.patch(compiled_autograd=True): - opt_bwd = torch.compile(lambda: result.backward()) - opt_bwd() - - self.assertEqual(counters["compiled_autograd"]["captures"], 1) - def test_dynamo_boxed(self): def get_placeholders(gm_): placeholders = [] diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index 212021859c46..6f4219a03b18 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -454,10 +454,6 @@ def default_debug_dir_root(): # WARNING: this is an experimental flag and is subject to change. _experimental_support_context_fn_in_torch_utils_checkpoint = False -# Enables the Compiled Autograd engine to trace .backward() calls made under torch.compile(). -# Note: AOT Autograd will still trace joint graphs. -compiled_autograd = False - if TYPE_CHECKING: from torch.utils._config_typing import * # noqa: F401, F403 diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 626b206cfe48..fa9311f2c18a 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -493,9 +493,6 @@ def __init__( export=False, dynamic=None, compiler_config=None, - rebuild_ctx: Optional[ - Callable[[], Union[OptimizeContext, _NullDecorator]] - ] = None, ): def on_enter(): install_generation_tagging_init() @@ -511,17 +508,6 @@ def on_enter(): compiler_config=compiler_config, ) - if config.compiled_autograd: - assert rebuild_ctx is not None - - def call_compiled_autograd(): - compiler_fn = rebuild_ctx() - ctx = torch._dynamo.compiled_autograd.enable(compiler_fn) - ctx.__enter__() - return functools.partial(ctx.__exit__, None, None, None) - - self.enter_exit_hooks.append(call_compiled_autograd) - class RunOnlyContext(_TorchDynamoContext): def __init__(self): @@ -591,7 +577,6 @@ def _optimize_catch_errors( export=False, dynamic=None, compiler_config=None, - rebuild_ctx=None, ): return OptimizeContext( convert_frame.catch_errors_wrapper(compile_fn, hooks), @@ -600,7 +585,6 @@ def _optimize_catch_errors( export=export, dynamic=dynamic, compiler_config=compiler_config, - rebuild_ctx=rebuild_ctx, ) @@ -651,15 +635,7 @@ def is_inductor_supported(): return False -def optimize(*args, **kwargs): - def rebuild_ctx(): - return optimize(*args, **kwargs) - - return _optimize(rebuild_ctx, *args, **kwargs) - - -def _optimize( - rebuild_ctx: Callable[[], Union[OptimizeContext, _NullDecorator]], +def optimize( backend="inductor", *, nopython=False, @@ -667,7 +643,7 @@ def _optimize( guard_fail_fn=None, disable=False, dynamic=None, -) -> Union[OptimizeContext, _NullDecorator]: +): """ The main entrypoint of TorchDynamo. Do graph capture and call backend() to optimize extracted graphs. @@ -715,7 +691,6 @@ def toy_example(a, b): backend, dynamic=dynamic, hooks=hooks, - rebuild_ctx=rebuild_ctx, ) # The backend function is stashed in the callable returned by # _optimize_catch_errors in the field _torchdynamo_orig_callable. This can @@ -728,7 +703,6 @@ def toy_example(a, b): compiler_config=backend.get_compiler_config() if hasattr(backend, "get_compiler_config") else None, - rebuild_ctx=rebuild_ctx, ) @@ -1492,7 +1466,6 @@ def optimize_assert( export=False, export_constraints=None, dynamic=None, - rebuild_ctx=None, ): """ The same as `torch._dynamo.optimize(backend, nopython=True)` @@ -1510,7 +1483,6 @@ def optimize_assert( backend_ctx_ctor, export=export, dynamic=dynamic, - rebuild_ctx=rebuild_ctx, ) From 3fb8a0b627e7463cfe6f8d88763507ea664efd84 Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Thu, 23 May 2024 22:06:39 +0000 Subject: [PATCH 093/706] Fix nextafter in inductor CPP codegen (#126876) Pull Request resolved: https://github.com/pytorch/pytorch/pull/126876 Approved by: https://github.com/peterbell10, https://github.com/jgong5 --- test/inductor/test_torchinductor_opinfo.py | 1 + torch/_inductor/codegen/cpp.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 2a7995de4e0e..b66c0ce0832f 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -440,6 +440,7 @@ def wrapper_noop_set_seed(op, *args, **kwargs): "triu", "cummax", "cummin", + "nextafter", "_chunk_cat", "constant_pad_nd", } diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 92a9c285d2b1..a14f93e14e6b 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -1236,8 +1236,8 @@ def log2(x): return f"{x}.log2()" @staticmethod - def nextafter(x): - return f"{x}.nextafter()" + def nextafter(x, y): + return f"{x}.nextafter({y})" @staticmethod def copysign(a, b): From 48538d3d144ebc38244afac107a5a4505dc57d22 Mon Sep 17 00:00:00 2001 From: lezcano Date: Thu, 30 May 2024 10:06:48 +0000 Subject: [PATCH 094/706] Implement svd_lowrank and pca_lowrank for complex numbers (#125580) We fix a number of bugs previously present in the complex implementation. We also heavily simplify the implementation, using, among other things, that we now have conjugate views. I saw there is a comment regarding how slow some checks on this function are. As such, I removed quite a few of the combinations of inputs to make the OpInfo lighter. I still left a couple relevant examples to not regress coverage though. Fixes https://github.com/pytorch/pytorch/issues/122188 Pull Request resolved: https://github.com/pytorch/pytorch/pull/125580 Approved by: https://github.com/pearu, https://github.com/peterbell10 --- test/functorch/test_ops.py | 3 + test/test_linalg.py | 21 ++-- torch/_dynamo/trace_rules.py | 3 - torch/_linalg_utils.py | 23 +--- torch/_lobpcg.py | 12 +- torch/_lowrank.py | 116 ++++++++---------- .../_internal/common_methods_invocations.py | 78 ++++++++---- 7 files changed, 125 insertions(+), 131 deletions(-) diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py index c8df820c7c9c..bd75a8d0bb74 100644 --- a/test/functorch/test_ops.py +++ b/test/functorch/test_ops.py @@ -608,6 +608,7 @@ def abs_if_complex(t): "nn.functional.batch_norm", {torch.float32: tol(atol=4e-05, rtol=5e-05)} ), tol1("nn.functional.conv2d", {torch.float32: tol(atol=4e-05, rtol=5e-05)}), + tol1("svd_lowrank", {torch.float32: tol(atol=5e-05, rtol=5e-05)}), tol1("pca_lowrank", {torch.float32: tol(atol=5e-05, rtol=5e-05)}), tol1( "nn.functional.multi_head_attention_forward", @@ -2366,6 +2367,8 @@ def fn(input, weight, bias): "linalg.pinv", "hermitian", {torch.float32: tol(atol=5e-06, rtol=5e-06)} ), tol1("nn.functional.conv3d", {torch.float32: tol(atol=5e-04, rtol=9e-03)}), + tol1("svd_lowrank", {torch.float32: tol(atol=5e-05, rtol=5e-05)}), + tol1("pca_lowrank", {torch.float32: tol(atol=5e-05, rtol=5e-05)}), ), ) def test_vmap_autograd_grad(self, device, dtype, op): diff --git a/test/test_linalg.py b/test/test_linalg.py index e6fb5fdca250..5963b1e50448 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -2416,7 +2416,7 @@ def test_nuclear_norm_exceptions_old(self, device): @skipCUDAIfNoCusolver @skipCPUIfNoLapack - @dtypes(torch.double) + @dtypes(torch.double, torch.cdouble) def test_svd_lowrank(self, device, dtype): from torch.testing._internal.common_utils import random_lowrank_matrix, random_sparse_matrix @@ -2439,14 +2439,12 @@ def run_subtest(actual_rank, matrix_size, batches, device, svd_lowrank, **option # check if u, s, v is a SVD u, s, v = u[..., :q], s[..., :q], v[..., :q] - A = u.matmul(s.diag_embed()).matmul(v.mT) + A = (u * s.unsqueeze(-2)).matmul(v.mH) self.assertEqual(A, a, rtol=1e-7, atol=2e-7) - # check if svd_lowrank produces same singular values as torch.svd - U, S, V = torch.svd(a) - self.assertEqual(s.shape, S.shape) - self.assertEqual(u.shape, U.shape) - self.assertEqual(v.shape, V.shape) + # check if svd_lowrank produces same singular values as linalg.svdvals + U, S, Vh = torch.linalg.svd(a, full_matrices=False) + V = Vh.mH self.assertEqual(s, S) if density == 1: @@ -2454,10 +2452,11 @@ def run_subtest(actual_rank, matrix_size, batches, device, svd_lowrank, **option # # check if pairs (u, U) and (v, V) span the same # subspaces, respectively - u, s, v = u[..., :actual_rank], s[..., :actual_rank], v[..., :actual_rank] - U, S, V = U[..., :actual_rank], S[..., :actual_rank], V[..., :actual_rank] - self.assertEqual(u.mT.matmul(U).det().abs(), torch.ones(batches, device=device, dtype=dtype)) - self.assertEqual(v.mT.matmul(V).det().abs(), torch.ones(batches, device=device, dtype=dtype)) + u, v = u[..., :actual_rank], v[..., :actual_rank] + U, V = U[..., :actual_rank], V[..., :actual_rank] + expected_ones = u.mH.matmul(U).det().abs() + self.assertEqual(expected_ones, torch.ones_like(expected_ones)) + self.assertEqual(v.mH.matmul(V).det().abs(), torch.ones_like(expected_ones)) all_batches = [(), (1,), (3,), (2, 3)] for actual_rank, size, all_batches in [ # noqa: B020 diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index cccb80fb0c77..90f0667fecfb 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -2292,7 +2292,6 @@ "torch._linalg_utils._symeig", "torch._linalg_utils.basis", "torch._linalg_utils.bform", - "torch._linalg_utils.conjugate", "torch._linalg_utils.eig", "torch._linalg_utils.get_floating_dtype", "torch._linalg_utils.is_sparse", @@ -2302,8 +2301,6 @@ "torch._linalg_utils.qform", "torch._linalg_utils.solve", "torch._linalg_utils.symeig", - "torch._linalg_utils.transjugate", - "torch._linalg_utils.transpose", "torch._load_global_deps", "torch._lowrank._svd_lowrank", "torch._lowrank.get_approximate_basis", diff --git a/torch/_linalg_utils.py b/torch/_linalg_utils.py index c9d5cde41f60..198decab4826 100644 --- a/torch/_linalg_utils.py +++ b/torch/_linalg_utils.py @@ -43,30 +43,9 @@ def matmul(A: Optional[Tensor], B: Tensor) -> Tensor: return torch.matmul(A, B) -def conjugate(A): - """Return conjugate of tensor A. - - .. note:: If A's dtype is not complex, A is returned. - """ - if A.is_complex(): - return A.conj() - return A - - -def transpose(A): - """Return transpose of a matrix or batches of matrices.""" - ndim = len(A.shape) - return A.transpose(ndim - 1, ndim - 2) - - -def transjugate(A): - """Return transpose conjugate of a matrix or batches of matrices.""" - return conjugate(transpose(A)) - - def bform(X: Tensor, A: Optional[Tensor], Y: Tensor) -> Tensor: """Return bilinear form of matrices: :math:`X^T A Y`.""" - return matmul(transpose(X), matmul(A, Y)) + return matmul(X.mT, matmul(A, Y)) def qform(A: Optional[Tensor], S: Tensor): diff --git a/torch/_lobpcg.py b/torch/_lobpcg.py index 6ca1e7294217..864b5dc6245f 100644 --- a/torch/_lobpcg.py +++ b/torch/_lobpcg.py @@ -924,7 +924,7 @@ def _update_ortho(self): S_, mm( Z[:, n - nc :], - _utils.basis(_utils.transpose(Z[: n - nc, n - nc :])), + _utils.basis(Z[: n - nc, n - nc :].mT), ), ) np = P.shape[-1] @@ -1045,7 +1045,7 @@ def _get_svqb( # The original algorithm 4 from [DuerschPhD2015]. d_col = (d**-0.5).reshape(d.shape[0], 1) - DUBUD = (UBU * d_col) * _utils.transpose(d_col) + DUBUD = (UBU * d_col) * d_col.mT E, Z = _utils.symeig(DUBUD) t = tau * abs(E).max() if drop: @@ -1057,7 +1057,7 @@ def _get_svqb( else: E[(torch.where(E < t))[0]] = t - return torch.matmul(U * _utils.transpose(d_col), Z * E**-0.5) + return torch.matmul(U * d_col.mT, Z * E**-0.5) def _get_ortho(self, U, V): """Return B-orthonormal U with columns are B-orthogonal to V. @@ -1105,7 +1105,7 @@ def _get_ortho(self, U, V): BV_norm = torch.norm(mm_B(self.B, V)) BU = mm_B(self.B, U) - VBU = mm(_utils.transpose(V), BU) + VBU = mm(V.mT, BU) i = j = 0 stats = "" for i in range(i_max): @@ -1125,7 +1125,7 @@ def _get_ortho(self, U, V): self.ivars["ortho_j"] = j return U BU = mm_B(self.B, U) - UBU = mm(_utils.transpose(U), BU) + UBU = mm(U.mT, BU) U_norm = torch.norm(U) BU_norm = torch.norm(BU) R = UBU - torch.eye(UBU.shape[-1], device=UBU.device, dtype=UBU.dtype) @@ -1136,7 +1136,7 @@ def _get_ortho(self, U, V): self.fvars[vkey] = rerr if rerr < tau_ortho: break - VBU = mm(_utils.transpose(V), BU) + VBU = mm(V.mT, BU) VBU_norm = torch.norm(VBU) U_norm = torch.norm(U) rerr = float(VBU_norm) * float(BV_norm * U_norm) ** -1 diff --git a/torch/_lowrank.py b/torch/_lowrank.py index 7a920ef4a455..c739cc37178e 100644 --- a/torch/_lowrank.py +++ b/torch/_lowrank.py @@ -17,7 +17,8 @@ def get_approximate_basis( """Return tensor :math:`Q` with :math:`q` orthonormal columns such that :math:`Q Q^H A` approximates :math:`A`. If :math:`M` is specified, then :math:`Q` is such that :math:`Q Q^H (A - M)` - approximates :math:`A - M`. + approximates :math:`A - M`. without instantiating any tensors + of the size of :math:`A` or :math:`M`. .. note:: The implementation is based on the Algorithm 4.4 from Halko et al, 2009. @@ -46,7 +47,7 @@ def get_approximate_basis( default value 2 is more than enough. M (Tensor, optional): the input tensor's mean of size - :math:`(*, 1, n)`. + :math:`(*, m, n)`. References:: - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding @@ -57,27 +58,27 @@ def get_approximate_basis( """ niter = 2 if niter is None else niter - m, n = A.shape[-2:] - dtype = _utils.get_floating_dtype(A) + dtype = _utils.get_floating_dtype(A) if not A.is_complex() else A.dtype matmul = _utils.matmul - R = torch.randn(n, q, dtype=dtype, device=A.device) + R = torch.randn(A.shape[-1], q, dtype=dtype, device=A.device) # The following code could be made faster using torch.geqrf + torch.ormqr # but geqrf is not differentiable - A_H = _utils.transjugate(A) - if M is None: - Q = torch.linalg.qr(matmul(A, R)).Q - for i in range(niter): - Q = torch.linalg.qr(matmul(A_H, Q)).Q - Q = torch.linalg.qr(matmul(A, Q)).Q - else: - M_H = _utils.transjugate(M) - Q = torch.linalg.qr(matmul(A, R) - matmul(M, R)).Q - for i in range(niter): - Q = torch.linalg.qr(matmul(A_H, Q) - matmul(M_H, Q)).Q - Q = torch.linalg.qr(matmul(A, Q) - matmul(M, Q)).Q + X = matmul(A, R) + if M is not None: + X = X - matmul(M, R) + Q = torch.linalg.qr(X).Q + for i in range(niter): + X = matmul(A.mH, Q) + if M is not None: + X = X - matmul(M.mH, Q) + Q = torch.linalg.qr(X).Q + X = matmul(A, Q) + if M is not None: + X = X - matmul(M, Q) + Q = torch.linalg.qr(X).Q return Q @@ -89,19 +90,26 @@ def svd_lowrank( ) -> Tuple[Tensor, Tensor, Tensor]: r"""Return the singular value decomposition ``(U, S, V)`` of a matrix, batches of matrices, or a sparse matrix :math:`A` such that - :math:`A \approx U diag(S) V^T`. In case :math:`M` is given, then + :math:`A \approx U \operatorname{diag}(S) V^{\text{H}}`. In case :math:`M` is given, then SVD is computed for the matrix :math:`A - M`. .. note:: The implementation is based on the Algorithm 5.1 from Halko et al, 2009. - .. note:: To obtain repeatable results, reset the seed for the - pseudorandom number generator + .. note:: For an adequate approximation of a k-rank matrix + :math:`A`, where k is not known in advance but could be + estimated, the number of :math:`Q` columns, q, can be + choosen according to the following criteria: in general, + :math:`k <= q <= min(2*k, m, n)`. For large low-rank + matrices, take :math:`q = k + 5..10`. If k is + relatively small compared to :math:`min(m, n)`, choosing + :math:`q = k + 0..2` may be sufficient. - .. note:: The input is assumed to be a low-rank matrix. + .. note:: This is a randomized method. To obtain repeatable results, + set the seed for the pseudorandom number generator .. note:: In general, use the full-rank SVD implementation - :func:`torch.linalg.svd` for dense matrices due to its 10-fold + :func:`torch.linalg.svd` for dense matrices due to its 10x higher performance characteristics. The low-rank SVD will be useful for huge sparse matrices that :func:`torch.linalg.svd` cannot handle. @@ -116,7 +124,7 @@ def svd_lowrank( integer, and defaults to 2 M (Tensor, optional): the input tensor's mean of size - :math:`(*, 1, n)`, which will be broadcasted + :math:`(*, m, n)`, which will be broadcasted to the size of A in this function. References:: @@ -144,48 +152,30 @@ def _svd_lowrank( niter: Optional[int] = 2, M: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor, Tensor]: + # Algorithm 5.1 in Halko et al 2009 + q = 6 if q is None else q m, n = A.shape[-2:] matmul = _utils.matmul - if M is None: - M_t = None - else: + if M is not None: M = M.broadcast_to(A.size()) - M_t = _utils.transpose(M) - A_t = _utils.transpose(A) - - # Algorithm 5.1 in Halko et al 2009, slightly modified to reduce - # the number conjugate and transpose operations - if m < n or n > q: - # computing the SVD approximation of a transpose in - # order to keep B shape minimal (the m < n case) or the V - # shape small (the n > q case) - Q = get_approximate_basis(A_t, q, niter=niter, M=M_t) - Q_c = _utils.conjugate(Q) - if M is None: - B_t = matmul(A, Q_c) - else: - B_t = matmul(A, Q_c) - matmul(M, Q_c) - assert B_t.shape[-2] == m, (B_t.shape, m) - assert B_t.shape[-1] == q, (B_t.shape, q) - assert B_t.shape[-1] <= B_t.shape[-2], B_t.shape - U, S, Vh = torch.linalg.svd(B_t, full_matrices=False) - V = Vh.mH - V = Q.matmul(V) - else: - Q = get_approximate_basis(A, q, niter=niter, M=M) - Q_c = _utils.conjugate(Q) - if M is None: - B = matmul(A_t, Q_c) - else: - B = matmul(A_t, Q_c) - matmul(M_t, Q_c) - B_t = _utils.transpose(B) - assert B_t.shape[-2] == q, (B_t.shape, q) - assert B_t.shape[-1] == n, (B_t.shape, n) - assert B_t.shape[-1] <= B_t.shape[-2], B_t.shape - U, S, Vh = torch.linalg.svd(B_t, full_matrices=False) - V = Vh.mH - U = Q.matmul(U) + + # Assume that A is tall + if m < n: + A = A.mH + if M is not None: + M = M.mH + + Q = get_approximate_basis(A, q, niter=niter, M=M) + B = matmul(Q.mH, A) + if M is not None: + B = B - matmul(Q.mH, M) + U, S, Vh = torch.linalg.svd(B, full_matrices=False) + V = Vh.mH + U = Q.matmul(U) + + if m < n: + U, V = V, U return U, S, V @@ -198,7 +188,7 @@ def pca_lowrank( This function returns a namedtuple ``(U, S, V)`` which is the nearly optimal approximation of a singular value decomposition of - a centered matrix :math:`A` such that :math:`A = U diag(S) V^T`. + a centered matrix :math:`A` such that :math:`A \approx U \operatorname{diag}(S) V^{\text{H}}` .. note:: The relation of ``(U, S, V)`` to PCA is as follows: @@ -293,7 +283,7 @@ def pca_lowrank( ) ones_m1_t = torch.ones(A.shape[:-2] + (1, m), dtype=dtype, device=A.device) - M = _utils.transpose(torch.sparse.mm(C_t, ones_m1_t)) + M = torch.sparse.mm(C_t, ones_m1_t).mT return _svd_lowrank(A, q, niter=niter, M=M) else: C = A.mean(dim=(-2,), keepdim=True) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 9a83bf5b0038..f59cca11becc 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -2101,7 +2101,7 @@ def error_inputs_T(self, device, has_ndims_error=False): r'to reverse their shape is not supported\.')) -def sample_inputs_singular_matrix_factors(op_info, device, dtype, requires_grad=False, **kwargs): +def sample_inputs_singular_matrix_factors(op_info, device, dtype, requires_grad=False): """ This function produces two tensors of shape (*, m, k) and (*, n, k) with k <= min(m, n). Their matrix product could be used to generate tensor of shape (*, m, n) of rank k. @@ -2115,13 +2115,18 @@ def sample_inputs_singular_matrix_factors(op_info, device, dtype, requires_grad= for k in range(min(3, m, n)): a = make_arg((*batch, m, k)) b = make_arg((*batch, n, k)) - yield SampleInput(a, b, **kwargs) + yield a, b def sample_inputs_svd_lowrank(op_info, device, dtype, requires_grad=False, **kwargs): - for sample in sample_inputs_singular_matrix_factors(op_info, device, dtype, requires_grad, **kwargs): - *batch, m, k = sample.input.shape - *_, n, _ = sample.args[0].shape + # Function that's well defined on the outputs for complex inputs + def fn(usv): + U, S, V = usv + return U @ V.mH, S + + for (a, b) in sample_inputs_singular_matrix_factors(op_info, device, dtype, requires_grad): + *batch, m, k = a.shape + n = b.shape[-2] # NOTE: since svd_lowrank relies on non rank-revealing SVD, # it inherits the problem of unstable behavior with repeated @@ -2130,20 +2135,13 @@ def sample_inputs_svd_lowrank(op_info, device, dtype, requires_grad=False, **kwa # we can only use k for q. # This issues could be resolved with using a rank-revealing SVD # which does not include "zero" singular values. - op_kwargs = { - 'q': k, - 'M': None - } - - # without M specified - yield clone_sample(sample, **op_kwargs) + yield SampleInput(a, b, q=k, M=None).with_metadata(output_process_fn_grad=fn) - # now with M - # TODO: fix bug in the documentation for svd_lowrank: - # M has to be (*, m, n), and not (*, 1, n) as written - # in the documentation - op_kwargs['M'] = make_tensor((*batch, m, n), dtype=dtype, device=device, requires_grad=requires_grad) - yield clone_sample(sample, **op_kwargs) + for (a, b) in sample_inputs_singular_matrix_factors(op_info, device, dtype, requires_grad): + *batch, m, k = a.shape + n = b.shape[-2] + M = make_tensor((*batch, m, n), dtype=dtype, device=device, requires_grad=requires_grad) + yield SampleInput(a, b, q=k, M=M).with_metadata(output_process_fn_grad=fn) def chunk_iter(iterable, size): it = iter(iterable) @@ -17700,10 +17698,11 @@ def reference_flatten(input, start_dim=0, end_dim=-1): lambda a, b, **kwargs: torch.svd_lowrank(a @ b.mT, **kwargs), *args, **kwargs ), - dtypes=floating_types(), + dtypes=floating_and_complex_types(), # Runs very slowly on slow gradcheck - alternatively reduce input sizes gradcheck_fast_mode=True, supports_out=False, + # Due to the use of randomness check_batched_grad=False, check_batched_gradgrad=False, check_batched_forward_grad=False, @@ -17711,14 +17710,29 @@ def reference_flatten(input, start_dim=0, end_dim=-1): supports_forward_ad=True, sample_inputs_func=sample_inputs_svd_lowrank, decorators=[skipCUDAIfNoCusolver, skipCPUIfNoLapack, with_tf32_off, - DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-03, rtol=1e-03)}), - 'TestCommon', 'test_noncontiguous_samples', - device_type='cuda')], + DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-03, rtol=1e-03), + torch.complex64: tol(atol=1e-02, rtol=1e-02)}), + 'TestCommon', 'test_noncontiguous_samples'), + # FIXME This should be the following, but the toleranceOverride does not seem to do anything! + # DecorateInfo(toleranceOverride({torch.complex128: tol(atol=1e-04, rtol=1e-04)}), + # 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'), + DecorateInfo(unittest.skip("See comment above"), + 'TestFwdGradients', + 'test_fn_fwgrad_bwgrad', + dtypes=[torch.complex128]), + DecorateInfo(unittest.skip("See comment above"), + 'TestBwdGradientsCUDA', + 'test_fn_gradgrad', + dtypes=[torch.complex128]), + ], skips=( # test does not work with passing lambda for op DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 + DecorateInfo(unittest.expectedFailure, 'TestSchemaCheckModeOpInfo', 'test_schema_correctness', + dtypes=(torch.complex64, torch.complex128)), DecorateInfo(slowTest, 'TestCompositeCompliance', 'test_forward_ad'), )), OpInfo('pca_lowrank', @@ -17726,7 +17740,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1): lambda a, b, **kwargs: torch.pca_lowrank(a @ b.mT, **kwargs), *args, **kwargs ), - dtypes=floating_types(), + dtypes=floating_and_complex_types(), # Runs very slowly on slow gradcheck - alternatively reduce input sizes gradcheck_fast_mode=True, supports_out=False, @@ -17737,13 +17751,25 @@ def reference_flatten(input, start_dim=0, end_dim=-1): supports_fwgrad_bwgrad=True, sample_inputs_func=sample_inputs_pca_lowrank, decorators=[skipCUDAIfNoCusolver, skipCPUIfNoLapack, with_tf32_off, - DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-03, rtol=1e-03)}), - 'TestCommon', 'test_noncontiguous_samples', - device_type='cuda')], + DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-03, rtol=1e-03), + torch.complex64: tol(atol=4e-02, rtol=4e-02)}), + 'TestCommon', 'test_noncontiguous_samples'), + # FIXME This should be the following, but the toleranceOverride does not seem to do anything! + # DecorateInfo(toleranceOverride({torch.complex128: tol(atol=1e-04, rtol=1e-04)}), + # 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'), + DecorateInfo(unittest.skip("See comment above"), + 'TestFwdGradients', + 'test_fn_fwgrad_bwgrad', + dtypes=[torch.complex128]), + + ], skips=( # test does not work with passing lambda for op DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 + DecorateInfo(unittest.expectedFailure, 'TestSchemaCheckModeOpInfo', 'test_schema_correctness', + dtypes=(torch.complex64, torch.complex128)), DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), )), BinaryUfuncInfo('polar', From 18a3f781e6382e2222d7c30c18136267407f9953 Mon Sep 17 00:00:00 2001 From: lezcano Date: Thu, 30 May 2024 10:11:10 +0000 Subject: [PATCH 095/706] Reduce number of samples in {svd,pca}_lowrank OpInfos (#127199) We don't need to generate so many samples for these very expensive ops. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127199 Approved by: https://github.com/peterbell10, https://github.com/zou3519 ghstack dependencies: #125580 --- test/functorch/test_ops.py | 5 +++++ .../_internal/common_methods_invocations.py | 17 ++++++----------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py index bd75a8d0bb74..62e43843bdba 100644 --- a/test/functorch/test_ops.py +++ b/test/functorch/test_ops.py @@ -460,6 +460,11 @@ class TestOperators(TestCase): {torch.float32: tol(atol=3e-04, rtol=3e-04)}, device_type="cuda", ), + tol1( + "svd_lowrank", + {torch.float32: tol(atol=5e-05, rtol=7e-06)}, + device_type="mps", + ), tol1( "linalg.tensorsolve", {torch.float32: tol(atol=3e-04, rtol=3e-04)}, diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index f59cca11becc..1570942abfc8 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -2108,14 +2108,13 @@ def sample_inputs_singular_matrix_factors(op_info, device, dtype, requires_grad= """ make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) - batches = [(), (0, ), (2, ), (1, 1)] - size = [1, 5, 10] - + batches = [(), (2,)] + size = [3, 4] for batch, m, n in product(batches, size, size): - for k in range(min(3, m, n)): - a = make_arg((*batch, m, k)) - b = make_arg((*batch, n, k)) - yield a, b + k = 2 + a = make_arg((*batch, m, k)) + b = make_arg((*batch, n, k)) + yield a, b def sample_inputs_svd_lowrank(op_info, device, dtype, requires_grad=False, **kwargs): @@ -17720,10 +17719,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1): 'TestFwdGradients', 'test_fn_fwgrad_bwgrad', dtypes=[torch.complex128]), - DecorateInfo(unittest.skip("See comment above"), - 'TestBwdGradientsCUDA', - 'test_fn_gradgrad', - dtypes=[torch.complex128]), ], skips=( # test does not work with passing lambda for op From c9beea13ace3409090750189e18c7c50ad7a7fb5 Mon Sep 17 00:00:00 2001 From: rzou Date: Wed, 29 May 2024 11:27:52 -0700 Subject: [PATCH 096/706] Rewrite existing links to custom ops gdocs with the landing page (#127423) NB: these links will be live after the docs build happens, which is once a day. Test Plan: - existing tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/127423 Approved by: https://github.com/jansel, https://github.com/williamwen42 ghstack dependencies: #127291, #127292, #127400 --- aten/src/ATen/core/MetaFallbackKernel.cpp | 4 ++-- c10/core/StorageImpl.cpp | 2 +- c10/core/TensorImpl.h | 2 +- torch/_dynamo/output_graph.py | 2 +- torch/library.py | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/aten/src/ATen/core/MetaFallbackKernel.cpp b/aten/src/ATen/core/MetaFallbackKernel.cpp index 2a7c34b17076..e87f641f9eb1 100644 --- a/aten/src/ATen/core/MetaFallbackKernel.cpp +++ b/aten/src/ATen/core/MetaFallbackKernel.cpp @@ -16,8 +16,8 @@ static void metaFallback( "fake impl or Meta kernel registered. You may have run into this message " "while using an operator with PT2 compilation APIs (torch.compile/torch.export); " "in order to use this operator with those APIs you'll need to add a fake impl. " - "Please see the following doc for next steps: " - "https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit"); + "Please see the following for next steps: " + "https://pytorch.org/docs/main/notes/custom_operators.html"); } TORCH_LIBRARY_IMPL(_, Meta, m) { diff --git a/c10/core/StorageImpl.cpp b/c10/core/StorageImpl.cpp index 9dd6f5f43131..2b5bbdb86c8a 100644 --- a/c10/core/StorageImpl.cpp +++ b/c10/core/StorageImpl.cpp @@ -18,7 +18,7 @@ void throwNullDataPtrError() { "If you're using torch.compile/export/fx, it is likely that we are erroneously " "tracing into a custom kernel. To fix this, please wrap the custom kernel into " "an opaque custom op. Please see the following for details: " - "https://docs.google.com/document/d/1W--T6wz8IY8fOI0Vm8BF44PdBgs283QvpelJZWieQWQ"); + "https://pytorch.org/docs/main/notes/custom_operators.html"); } // NOTE: [FakeTensor.data_ptr deprecation] diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index e49a66c916ff..877c1c09543c 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -1580,7 +1580,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { "If you're using torch.compile/export/fx, it is likely that we are erroneously " "tracing into a custom kernel. To fix this, please wrap the custom kernel into " "an opaque custom op. Please see the following for details: " - "https://docs.google.com/document/d/1W--T6wz8IY8fOI0Vm8BF44PdBgs283QvpelJZWieQWQ\n" + "https://pytorch.org/docs/main/notes/custom_operators.html\n" "If you're using Caffe2, Caffe2 uses a lazy allocation, so you will need to call " "mutable_data() or raw_mutable_data() to actually allocate memory."); // Caller does the type check. diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 3ab386cf35b2..21cd2e889e90 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -1676,7 +1676,7 @@ def example_value_from_input_node(self, node: torch.fx.Node): "(and fall back to eager-mode PyTorch) on all ops " "that have do not have the 'pt2_compliant_tag'. " "Please see the following doc for how to mark this op as PT2 compliant " - "https://docs.google.com/document/d/1W--T6wz8IY8fOI0Vm8BF44PdBgs283QvpelJZWieQWQ" + "https://pytorch.org/docs/main/notes/custom_operators.html" ) diff --git a/torch/library.py b/torch/library.py index 48055da5b55c..a69e16950f7e 100644 --- a/torch/library.py +++ b/torch/library.py @@ -556,7 +556,7 @@ def register_fake( This API may be used as a decorator (see examples). For a detailed guide on custom ops, please see - https://docs.google.com/document/d/1W--T6wz8IY8fOI0Vm8BF44PdBgs283QvpelJZWieQWQ/edit + https://pytorch.org/docs/main/notes/custom_operators.html Examples: >>> import torch From ffe506e85350a505be5698c871d50b2fc614406d Mon Sep 17 00:00:00 2001 From: rzou Date: Wed, 29 May 2024 17:44:09 -0700 Subject: [PATCH 097/706] Better graph break msg (and warning) on Dynamo x Python C++ extension (#127301) Dynamo graph breaks on Python C/C++ extensions (e.g. pybinded functions). The usual way to handle this is to turn those extensions into custom ops. This PR adds a nicer graph break message and also changes it to unconditionally warn on this graph break (because graph break messages are usually not visible). Fixes https://github.com/pytorch/pytorch/issues/126799 Test Plan: - new test Pull Request resolved: https://github.com/pytorch/pytorch/pull/127301 Approved by: https://github.com/jansel ghstack dependencies: #127291, #127292, #127400, #127423 --- test/dynamo/test_misc.py | 33 ++++++++++++++++++++++++++++ torch/_dynamo/variables/functions.py | 26 ++++++++++++++++++++-- 2 files changed, 57 insertions(+), 2 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index abc4f52dfbfb..d9611028f789 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -33,6 +33,7 @@ import torch.onnx.operators import torch.utils._pytree as pytree +import torch.utils.cpp_extension from torch import Tensor from torch._C import FileCheck from torch._dynamo import allow_in_graph @@ -223,6 +224,38 @@ def fn(x): with self.assertRaises(TypeError): fn(torch.randn(16)) + def test_cpp_extension_recommends_custom_ops(self): + cpp_source = """ + #include + at::Tensor foobar(const at::Tensor& x) { + return x.clone(); + } + """ + module = torch.utils.cpp_extension.load_inline( + name="mylib", + cpp_sources=cpp_source, + functions="foobar", + verbose=True, + ) + + x = torch.ones(2, 2, requires_grad=True) + counters.clear() + + @torch.compile(backend="eager") + def f(x): + return module.foobar(x) + + with self.assertWarnsOnceRegex( + UserWarning, ".*https://pytorch.org/docs/main/notes/custom_operators.html.*" + ): + f(x) + self.assertEqual(len(counters["graph_break"]), 1) + first_graph_break = list(counters["graph_break"].keys())[0] + self.assertExpectedInline( + first_graph_break, + """Graph break due to unsupported builtin mylib.PyCapsule.foobar. This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind). If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround. If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/docs/main/notes/custom_operators.html for more details) or, if it is traceable, use torch.compiler.allow_in_graph.""", + ) + def test_callpacked(self): def call_packed(args): a, b, c = args diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 3fab4413cb0f..745e29af4929 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -6,6 +6,7 @@ import inspect import itertools import types +import warnings from typing import Dict, List, Optional, TYPE_CHECKING, Union import torch @@ -634,9 +635,30 @@ def wraps(fn): else: try: path = inspect.getfile(self.value) + msg = f"'skip function {self.value.__qualname__} in file {path}'" except TypeError: - path = f"Builtin {self.value.__name__}" - msg = f"'skip function {self.value.__qualname__} in file {path}'" + known_python_builtin_modules = {"_abc", "_warnings"} + if self.value.__module__ in known_python_builtin_modules: + msg = ( + f"Graph break due to unsupported Python builtin {self.value.__module__}.{self.value.__qualname__}. " + f"Please file an issue on GitHub " + f"so the PyTorch team can add support for it. " + ) + else: + msg = ( + f"Graph break due to unsupported builtin {self.value.__module__}.{self.value.__qualname__}. " + f"This function is either a Python builtin (e.g. _warnings.warn) " + f"or a third-party C/C++ Python extension (perhaps created with pybind). " + f"If it is a Python builtin, please file an issue on GitHub " + f"so the PyTorch team can add support for it and see the next case for a workaround. " + f"If it is a third-party C/C++ Python extension, please " + f"either wrap it into a PyTorch-understood custom operator " + f"(see https://pytorch.org/docs/main/notes/custom_operators.html " + f"for more details) or, if it is traceable, use " + f"torch.compiler.allow_in_graph." + ) + # also warn on it because most users won't see the graph break message + warnings.warn(msg) msg += f"', {self.reason}'" if self.reason else "" unimplemented(msg) From d2df0f56a3046d9f8bf8eb34fd787bf506647d3c Mon Sep 17 00:00:00 2001 From: Aaron Orenstein Date: Wed, 29 May 2024 20:10:48 +0000 Subject: [PATCH 098/706] Fix compilation_latency regression caused by #127060 (#127326) It seems that while #127060 improved the speed for tacotron2 it introduced a compilation_latency regression for some of the TIMM benchmarks. The original change was to precompute the Dep metadata - but apparently some benchmarks have few enough overlaps that precomputing O(n) deps was slower than ignoring O(n^2) deps. So change it to go back to computing the Dep metadata on demand but to then cache the result. `dm_nfnet_f0` was a good example because on the dashboard it showed an increase from 140s -> 154s. ``` python benchmarks/dynamo/timm_models.py --performance --cold-start-latency --training --amp --backend inductor --dynamic-shapes --dynamic-batch-only --device cuda --total-partitions 5 --partition-id 1 --output timm-0.csv --only dm_nfnet_f0 ``` Looking at the compilation_latency result. On viable (d6e3e8980): 172.777958 176.725071 177.907955 On viable with #127060 and #127061 fully backed out: 158.305166 158.688560 160.791187 On viable w/ this change: 160.094164 160.201845 161.752157 I think that's probably close enough considering the variance. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127326 Approved by: https://github.com/oulgen --- torch/_inductor/scheduler.py | 47 +++++++++++++++++------------------- 1 file changed, 22 insertions(+), 25 deletions(-) diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 8a64d5941a46..5a13c7f3cae4 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -162,8 +162,6 @@ class BaseSchedulerNode: group: Tuple[torch.device, Tuple[Tuple[sympy.Expr, ...], ...]] read_writes: dependencies.ReadWrites unmet_dependencies: Set[Dep] - # Processed deps used while scoring fusion - read_and_write_deps_with_hint: Set[Tuple[Dep, int]] def __init__(self, scheduler: "Scheduler", node: ir.Buffer) -> None: self.scheduler: Scheduler = scheduler @@ -252,25 +250,6 @@ def set_read_writes(self, rw: dependencies.ReadWrites) -> None: self.unmet_dependencies = self.read_writes.reads self.prune_deps() - # read_and_write_deps_with_hint are a summary of read_writes used by - # score_fusion_memory() - def dep_size_hint(dep: Dep) -> int: - try: - if dep.has_unbacked_symbols(): - return 0 - return dep.numbytes_hint() - except KeyError: - # In at least one test (test/inductor/test_torchbind.py) we - # create a StarDep that doesn't exist in the graph and calling - # `has_unbacked_symbols()` throws an error. - return 0 - - self.read_and_write_deps_with_hint = { - (dep, hint) - for dep in itertools.chain(self.read_writes.reads, self.read_writes.writes) - if (hint := dep_size_hint(dep)) > 0 - } - def op_counts(self) -> Counter[str]: return self.read_writes.op_counts @@ -1394,9 +1373,12 @@ def merge(self, other: "NodeUser") -> "NodeUser": class Scheduler: + __dep_size_hint_cache: Dict[Dep, int] + @dynamo_timed def __init__(self, nodes: List[ir.Buffer]) -> None: super().__init__() + self.__dep_size_hint_cache = {} V.graph.scheduler = self self.backends: Dict[torch.device, BaseScheduling] = {} self.post_grad_graph_id = next(_post_grad_graph_counter) @@ -2505,6 +2487,22 @@ def score_fusion( proximity_score, ) + def dep_size_hint(self, dep: Dep) -> int: + res = 0 + if dep not in self.__dep_size_hint_cache: + try: + if not dep.has_unbacked_symbols(): + res = dep.numbytes_hint() + except KeyError: + # In at least one test (test/inductor/test_torchbind.py) we + # create a StarDep that doesn't exist in the graph and calling + # `has_unbacked_symbols()` throws an error. + pass + self.__dep_size_hint_cache[dep] = res + else: + res = self.__dep_size_hint_cache[dep] + return res + def score_fusion_memory( self, node1: BaseSchedulerNode, node2: BaseSchedulerNode ) -> int: @@ -2512,11 +2510,10 @@ def score_fusion_memory( The first term in our fusion score that estimates number of saved memory operations. """ - return sum( - hint - for dep, hint in node1.read_and_write_deps_with_hint - & node2.read_and_write_deps_with_hint + common_memory_deps = (node1.read_writes.reads | node1.read_writes.writes) & ( + node2.read_writes.reads | node2.read_writes.writes ) + return sum(self.dep_size_hint(dep) for dep in common_memory_deps) def get_possible_fusions_with_highest_priority( self, possible_fusions: List[Tuple[BaseSchedulerNode, BaseSchedulerNode]] From 576c5ef1dd5affd8425ddb07c58f3c9e0d1acfe8 Mon Sep 17 00:00:00 2001 From: Shunting Zhang Date: Wed, 29 May 2024 16:14:28 -0700 Subject: [PATCH 099/706] [inductor] fix some tests in test_max_autotune.py (#127472) Fix https://github.com/pytorch/pytorch/issues/126176 . We should not use torch.empty to generate input data if we are gonna do any accuracy test. torch.empty may return NaN. In that cause both the reference and the actual result may contain NaN at the same index. But `NaN != NaN` so the test fail. Also if torch.empty returns NaN is not deterministic. It may depends on other tests running earlier. Generating random data instead of calling torch.empty fixes the problem. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127472 Approved by: https://github.com/eellison, https://github.com/jansel --- test/inductor/test_max_autotune.py | 44 +++++++++++------------------- 1 file changed, 16 insertions(+), 28 deletions(-) diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index eed927a98644..bd74ea58ad59 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -9,7 +9,7 @@ from torch import multiprocessing as mp, nn from torch._dynamo import reset from torch._dynamo.exc import BackendCompilerFailed -from torch._dynamo.testing import reset_rng_state +from torch._dynamo.testing import rand_strided, reset_rng_state from torch._inductor import config from torch._inductor.autotune_process import ( BenchmarkRequest, @@ -674,12 +674,10 @@ def test_non_contiguous_input_mm(self): Make sure the triton template can work with non-contiguous inputs without crash. Check https://github.com/pytorch/pytorch/issues/125437 for more details. """ - x = torch.empty_strided( + x = rand_strided( (50257, 32768), (1, 50304), dtype=torch.bfloat16, device="cuda" ) - y = torch.empty_strided( - (32768, 768), (768, 1), dtype=torch.bfloat16, device="cuda" - ) + y = rand_strided((32768, 768), (768, 1), dtype=torch.bfloat16, device="cuda") @torch.compile(mode="max-autotune") def f(x, y): @@ -687,16 +685,14 @@ def f(x, y): ref = x @ y act = f(x, y) - self.assertTrue(torch.allclose(ref, act, atol=4 * 1e-3, rtol=4 * 1e-3)) + self.assertTrue(torch.allclose(ref, act, atol=1e-2, rtol=1e-2)) def test_non_contiguous_input_addmm(self): - b = torch.empty((768), dtype=torch.bfloat16, device="cuda") - x = torch.empty_strided( + b = torch.randn((768), dtype=torch.bfloat16, device="cuda") + x = rand_strided( (50257, 32768), (1, 50304), dtype=torch.bfloat16, device="cuda" ) - y = torch.empty_strided( - (32768, 768), (768, 1), dtype=torch.bfloat16, device="cuda" - ) + y = rand_strided((32768, 768), (768, 1), dtype=torch.bfloat16, device="cuda") @torch.compile(mode="max-autotune") def f(x, y): @@ -704,13 +700,13 @@ def f(x, y): ref = torch.addmm(b, x, y) act = f(x, y) - self.assertTrue(torch.allclose(ref, act, atol=4 * 1e-3, rtol=4 * 1e-3)) + self.assertTrue(torch.allclose(ref, act, atol=1e-2, rtol=1e-2)) def test_non_contiguous_input_bmm(self): - x = torch.empty_strided( + x = rand_strided( (1, 50257, 32768), (0, 1, 50304), dtype=torch.bfloat16, device="cuda" ) - y = torch.empty_strided( + y = rand_strided( (1, 32768, 768), (0, 768, 1), dtype=torch.bfloat16, device="cuda" ) @@ -720,22 +716,14 @@ def f(x, y): ref = torch.bmm(x, y) act = f(x, y) - self.assertTrue(torch.allclose(ref, act, atol=4 * 1e-3, rtol=4 * 1e-3)) + self.assertTrue(torch.allclose(ref, act, atol=1e-2, rtol=1e-2)) def test_non_contiguous_input_mm_plus_mm(self): - x1 = torch.empty_strided( - (50257, 32768), (1, 50304), dtype=torch.bfloat16, device="cuda" - ) - y1 = torch.empty_strided( - (32768, 768), (768, 1), dtype=torch.bfloat16, device="cuda" - ) + x1 = rand_strided((50257, 32768), (1, 50304), device="cuda") + y1 = rand_strided((32768, 768), (768, 1), device="cuda") - x2 = torch.empty_strided( - (50257, 32768), (1, 50304), dtype=torch.bfloat16, device="cuda" - ) - y2 = torch.empty_strided( - (32768, 768), (768, 1), dtype=torch.bfloat16, device="cuda" - ) + x2 = rand_strided((50257, 32768), (1, 50304), device="cuda") + y2 = rand_strided((32768, 768), (768, 1), device="cuda") @torch.compile(mode="max-autotune") def f(x1, y1, x2, y2): @@ -743,7 +731,7 @@ def f(x1, y1, x2, y2): ref = x1 @ y1 + x2 @ y2 act = f(x1, y1, x2, y2) - self.assertTrue(torch.allclose(ref, act, atol=4 * 1e-3, rtol=4 * 1e-3)) + self.assertTrue(torch.allclose(ref, act, atol=1e-2, rtol=1e-2)) @config.patch( max_autotune=True, From be7be9fa166b7748a3e75ec327811419126a57ff Mon Sep 17 00:00:00 2001 From: cyy Date: Thu, 30 May 2024 16:19:53 +0000 Subject: [PATCH 100/706] [Distributed] [8/N] Fix clang-tidy warnings in torch/csrc/distributed/c10d (#125102) This PR continues to clean clang-tidy warnings in torch/csrc/distributed/c10d, following https://github.com/pytorch/pytorch/pull/124987. Pull Request resolved: https://github.com/pytorch/pytorch/pull/125102 Approved by: https://github.com/ezyang --- torch/csrc/distributed/c10d/Functional.cpp | 18 +++-- torch/csrc/distributed/c10d/Store.hpp | 5 +- .../distributed/c10d/TCPStoreLibUvBackend.cpp | 5 +- torch/csrc/distributed/c10d/init.cpp | 2 +- .../c10d/quantization/quantization.cpp | 31 +++----- .../c10d/quantization/quantization.h | 10 +-- .../c10d/quantization/quantization_gpu.cu | 78 +++++++++---------- .../c10d/quantization/quantization_gpu.h | 10 +-- torch/csrc/distributed/c10d/sequence_num.cpp | 3 +- torch/csrc/distributed/c10d/socket.cpp | 2 +- 10 files changed, 67 insertions(+), 97 deletions(-) diff --git a/torch/csrc/distributed/c10d/Functional.cpp b/torch/csrc/distributed/c10d/Functional.cpp index 9d525f0d5640..2485999e7a00 100644 --- a/torch/csrc/distributed/c10d/Functional.cpp +++ b/torch/csrc/distributed/c10d/Functional.cpp @@ -199,7 +199,7 @@ at::Tensor all_gather_into_tensor( at::Tensor& all_gather_into_tensor_out( at::Tensor& input, int64_t group_size, - std::string group_name, + const std::string& group_name, at::Tensor& output) { c10d::AllgatherOptions opts; @@ -463,9 +463,9 @@ class ReduceScatterTensor static torch::autograd::Variable forward( torch::autograd::AutogradContext* ctx, const at::Tensor& input, - std::string reduce_op, + const std::string& reduce_op, int64_t group_size, - std::string group_name) { + const std::string& group_name) { TORCH_CHECK(reduce_op == "sum", "Only sum reduce op is supported"); ctx->saved_data["group_size"] = group_size; @@ -510,9 +510,9 @@ class ReduceScatterTensor at::Tensor reduce_scatter_tensor_autograd( const at::Tensor& input, - std::string reduce_op, + const std::string& reduce_op, int64_t group_size, - std::string group_name) { + const std::string& group_name) { return ReduceScatterTensor::apply(input, reduce_op, group_size, group_name); } @@ -523,7 +523,7 @@ class AllGatherIntoTensor torch::autograd::AutogradContext* ctx, const at::Tensor& input, int64_t group_size, - std::string group_name) { + const std::string& group_name) { ctx->saved_data["group_size"] = group_size; ctx->saved_data["group_name"] = group_name; @@ -566,7 +566,7 @@ class AllGatherIntoTensor at::Tensor all_gather_into_tensor_autograd( const at::Tensor& input, int64_t group_size, - std::string group_name) { + const std::string& group_name) { return AllGatherIntoTensor::apply(input, group_size, group_name); } @@ -607,7 +607,7 @@ at::Tensor shard_dim_alltoall( const at::Tensor& input, int64_t gather_dim, int64_t shard_dim, - std::string group_name) { + const std::string& group_name) { auto group = c10d::resolve_process_group(group_name); auto group_size = group->getSize(); std::vector output_sizes = input.sizes().vec(); @@ -619,12 +619,14 @@ at::Tensor shard_dim_alltoall( } output_sizes[shard_dim] = output_sizes[shard_dim] / group_size; std::vector inputs; + inputs.reserve(group_size); auto length = output_sizes[shard_dim]; for (int i = 0; i < group_size; i++) { inputs.push_back(input.narrow(shard_dim, i * length, length).contiguous()); } // allocate outputs std::vector outputs; + outputs.reserve(group_size); for (int i = 0; i < group_size; i++) { outputs.push_back(input.new_empty(output_sizes).contiguous()); } diff --git a/torch/csrc/distributed/c10d/Store.hpp b/torch/csrc/distributed/c10d/Store.hpp index 626a9e3b688b..061d0ed620ba 100644 --- a/torch/csrc/distributed/c10d/Store.hpp +++ b/torch/csrc/distributed/c10d/Store.hpp @@ -106,8 +106,7 @@ class StoreTimeoutGuard { explicit StoreTimeoutGuard( Store& store, const std::chrono::milliseconds& timeout) - : store_(store) { - oldTimeout_ = store.getTimeout(); + : store_(store), oldTimeout_(store.getTimeout()) { store.setTimeout(timeout); } @@ -123,7 +122,7 @@ class StoreTimeoutGuard { private: Store& store_; - std::chrono::milliseconds oldTimeout_; + std::chrono::milliseconds oldTimeout_{}; }; } // namespace c10d diff --git a/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp b/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp index f33cbb019401..845803c5e17e 100644 --- a/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp +++ b/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp @@ -752,7 +752,7 @@ class UvClient : public UvTcpSocket { if (!stream.read_key(key)) return false; - auto data = store->get(key); + const auto& data = store->get(key); StreamWriter sw(iptr()); sw.write_vector(data); sw.send(); @@ -884,8 +884,7 @@ class UvClient : public UvTcpSocket { return false; } - auto data = store->get(key); - sw.write_vector(data); + sw.write_vector(store->get(key)); } sw.send(); diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 9f0122e78332..c4b9a9823c84 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -2032,7 +2032,7 @@ communication mechanism. self->registerOnCompletionHook( [hookWrapper = ::c10d::PythonOnCompletionHook(std::move( hook))](std::shared_ptr<::c10d::WorkInfo> workInfo) { - hookWrapper(std::move(workInfo)); + hookWrapper(workInfo); }); }, py::arg("hook"), diff --git a/torch/csrc/distributed/c10d/quantization/quantization.cpp b/torch/csrc/distributed/c10d/quantization/quantization.cpp index 8ed6d97d6d80..2d4fa2ba3812 100644 --- a/torch/csrc/distributed/c10d/quantization/quantization.cpp +++ b/torch/csrc/distributed/c10d/quantization/quantization.cpp @@ -2,10 +2,7 @@ #include #include -namespace torch { -namespace distributed { -namespace c10d { -namespace quantization { +namespace torch::distributed::c10d::quantization { // TODO: The kernels are copied from fbgemm_gpu, we should dedup them later @@ -31,11 +28,9 @@ static void BFloat16QuantizedToFloat_ref( const size_t nrows, const size_t ncols, float* const output) { - const int32_t output_columns = ncols; - for (const auto row : c10::irange(nrows)) { const at::BFloat16* input_row = input + row * ncols; - float* output_row = output + row * output_columns; + float* output_row = output + row * ncols; for (const auto col : c10::irange(ncols)) { uint32_t val_fp32 = static_cast( @@ -52,11 +47,9 @@ at::Tensor _float_to_bfloat16_cpu(const at::Tensor& input) { TENSOR_NDIM_EQUALS(input, 2); const auto input_sizes = input.sizes(); - const int32_t nrows = input_sizes[0]; - const int32_t ncols = input_sizes[1]; - const int32_t output_columns = ncols; - auto output = - at::empty({nrows, output_columns}, input.options().dtype(at::kHalf)); + const auto nrows = input_sizes[0]; + const auto ncols = input_sizes[1]; + auto output = at::empty({nrows, ncols}, input.options().dtype(at::kHalf)); FloatToBFloat16Quantized_ref( input.const_data_ptr(), @@ -73,13 +66,10 @@ at::Tensor _bfloat16_to_float_cpu(const at::Tensor& input) { TENSOR_NDIM_EQUALS(input, 2); const auto input_sizes = input.sizes(); - const int32_t nrows = input_sizes[0]; - const int32_t ncols = input_sizes[1]; - const int32_t output_columns = ncols; + const auto nrows = input_sizes[0]; + const auto ncols = input_sizes[1]; - auto output = at::empty( - {nrows, output_columns}, // 4 = sizeof(float) - input.options().dtype(at::kFloat)); // + auto output = at::empty({nrows, ncols}, input.options().dtype(at::kFloat)); BFloat16QuantizedToFloat_ref( reinterpret_cast(input.const_data_ptr()), nrows, @@ -99,7 +89,4 @@ TORCH_LIBRARY_IMPL(quantization, CPU, m) { m.impl("_FloatToBfloat16Quantized", _float_to_bfloat16_cpu); } -} // namespace quantization -} // namespace c10d -} // namespace distributed -} // namespace torch +} // namespace torch::distributed::c10d::quantization diff --git a/torch/csrc/distributed/c10d/quantization/quantization.h b/torch/csrc/distributed/c10d/quantization/quantization.h index 8cf3455ce79b..3d2f23de421b 100644 --- a/torch/csrc/distributed/c10d/quantization/quantization.h +++ b/torch/csrc/distributed/c10d/quantization/quantization.h @@ -8,15 +8,9 @@ #include #include -namespace torch { -namespace distributed { -namespace c10d { -namespace quantization { +namespace torch::distributed::c10d::quantization { at::Tensor _float_to_bfloat16_cpu(const at::Tensor& input); at::Tensor _bfloat16_to_float_cpu(const at::Tensor& input); -} // namespace quantization -} // namespace c10d -} // namespace distributed -} // namespace torch +} // namespace torch::distributed::c10d::quantization diff --git a/torch/csrc/distributed/c10d/quantization/quantization_gpu.cu b/torch/csrc/distributed/c10d/quantization/quantization_gpu.cu index 48cc7cfc4f3e..480cfb91cfb1 100644 --- a/torch/csrc/distributed/c10d/quantization/quantization_gpu.cu +++ b/torch/csrc/distributed/c10d/quantization/quantization_gpu.cu @@ -9,16 +9,16 @@ // FP32 -> BF16 kernel __global__ void _float_to_bfloat16_cuda_kernel( const float* __restrict__ input, - const int nrows, - const int ncols, + const size_t nrows, + const size_t ncols, uint16_t* __restrict__ output) { - const int row_incre = blockDim.y * gridDim.y; - const int col_incre = blockDim.x * gridDim.x; - for (int row = blockIdx.y * blockDim.y + threadIdx.y; row < nrows; + const auto row_incre = blockDim.y * gridDim.y; + const auto col_incre = blockDim.x * gridDim.x; + for (auto row = blockIdx.y * blockDim.y + threadIdx.y; row < nrows; row += row_incre) { const float* input_row = input + row * ncols; uint16_t* output_row = output + row * ncols; - for (int col = blockIdx.x * blockDim.x + threadIdx.x; col < ncols; + for (auto col = blockIdx.x * blockDim.x + threadIdx.x; col < ncols; col += col_incre) { // Add 2^15 and right shift 16 to do round-nearest output_row[col] = @@ -31,14 +31,14 @@ __global__ void _float_to_bfloat16_cuda_kernel( // BF16 -> FP32 kernel __global__ void _bfloat16_to_float_cuda_kernel( const uint16_t* __restrict__ input, - const int nrows, - const int ncols, + const size_t nrows, + const size_t ncols, float* __restrict__ output) { - const int row_incre = blockDim.y * gridDim.y; - const int col_incre = blockDim.x * gridDim.x; - for (int row = blockIdx.y * blockDim.y + threadIdx.y; row < nrows; + const auto row_incre = blockDim.y * gridDim.y; + const auto col_incre = blockDim.x * gridDim.x; + for (auto row = blockIdx.y * blockDim.y + threadIdx.y; row < nrows; row += row_incre) { - for (int col = blockIdx.x * blockDim.x + threadIdx.x; col < ncols; + for (auto col = blockIdx.x * blockDim.x + threadIdx.x; col < ncols; col += col_incre) { const uint16_t* input_row = input + row * ncols; float* output_row = output + row * ncols; @@ -50,10 +50,7 @@ __global__ void _bfloat16_to_float_cuda_kernel( } } -namespace torch { -namespace distributed { -namespace c10d { -namespace quantization { +namespace torch::distributed::c10d::quantization { at::Tensor _float_to_bfloat16_cuda(const at::Tensor& input) { TENSOR_ON_CUDA_GPU(input); @@ -63,27 +60,28 @@ at::Tensor _float_to_bfloat16_cuda(const at::Tensor& input) { at::cuda::OptionalCUDAGuard device_guard; device_guard.set_index(input.get_device()); - const int nrows = input.size(0); - const int ncols = input.size(1); - const int output_columns = ncols; + const auto nrows = input.size(0); + const auto ncols = input.size(1); + const size_t output_columns = ncols; auto output = at::empty( - {nrows, output_columns}, + {nrows, ncols}, #if HAS_NCCL_BF16_DATATYPE input.options().dtype(at::kBFloat16)); #else input.options().dtype(at::kHalf)); #endif - if (nrows == 0 || output_columns == 0) { + if (nrows == 0 || ncols == 0) { return output; } - constexpr int threads_per_block = 256; - const int blockDim_x = std::min(output_columns, threads_per_block); + constexpr size_t threads_per_block = 256; + const auto blockDim_x = std::min(output_columns, threads_per_block); dim3 blockDim(blockDim_x, threads_per_block / blockDim_x); - const int gridDim_x = (output_columns + blockDim.x - 1) / blockDim.x; - const int gridDim_y = std::min((nrows + blockDim.y - 1) / blockDim.y, 65535u); + const auto gridDim_x = (output_columns + blockDim.x - 1) / blockDim.x; + const auto gridDim_y = + std::min((nrows + blockDim.y - 1) / blockDim.y, 65535u); dim3 gridDim(gridDim_x, gridDim_y); _float_to_bfloat16_cuda_kernel<<< @@ -113,24 +111,25 @@ at::Tensor _bfloat16_to_float_cuda(const at::Tensor& input) { at::cuda::OptionalCUDAGuard device_guard; device_guard.set_index(input.get_device()); - const int nrows = input.size(0); - const int ncols = input.size(1); - const int output_columns = ncols; + const auto nrows = input.size(0); + const auto ncols = input.size(1); + const size_t output_columns = ncols; auto output = at::empty( - {nrows, output_columns}, // 4 = sizeof(float) + {nrows, ncols}, // 4 = sizeof(float) input.options().dtype(at::kFloat)); // at::kBytes for uint8_t - if (nrows == 0 || output_columns == 0) { + if (nrows == 0 || ncols == 0) { return output; } - constexpr int threads_per_block = 256; + constexpr size_t threads_per_block = 256; - const int blockDim_x = std::min(output_columns, threads_per_block); + const auto blockDim_x = std::min(output_columns, threads_per_block); dim3 blockDim(blockDim_x, threads_per_block / blockDim_x); - const int gridDim_x = (output_columns + blockDim.x - 1) / blockDim.x; - const int gridDim_y = std::min((nrows + blockDim.y - 1) / blockDim.y, 65535u); + const auto gridDim_x = (output_columns + blockDim.x - 1) / blockDim.x; + const auto gridDim_y = + std::min((nrows + blockDim.y - 1) / blockDim.y, 65535u); dim3 gridDim(gridDim_x, gridDim_y); _bfloat16_to_float_cuda_kernel<<< @@ -152,14 +151,11 @@ at::Tensor _bfloat16_to_float_cuda(const at::Tensor& input) { } #define DISPATCH_TO_CUDA(name, function) \ - m.impl(name, torch::dispatch(c10::DispatchKey::CUDA, TORCH_FN(function))) + m.impl(name, torch::dispatch(c10::DispatchKey::CUDA, TORCH_FN(function))) TORCH_LIBRARY_IMPL(quantization, CUDA, m) { - DISPATCH_TO_CUDA("_Bfloat16QuantizedToFloat", _bfloat16_to_float_cuda); - DISPATCH_TO_CUDA("_FloatToBfloat16Quantized", _float_to_bfloat16_cuda); + DISPATCH_TO_CUDA("_Bfloat16QuantizedToFloat", _bfloat16_to_float_cuda); + DISPATCH_TO_CUDA("_FloatToBfloat16Quantized", _float_to_bfloat16_cuda); } -} // namespace quantization -} // namespace c10d -} // namespace distributed -} // namespace torch +} // namespace torch::distributed::c10d::quantization diff --git a/torch/csrc/distributed/c10d/quantization/quantization_gpu.h b/torch/csrc/distributed/c10d/quantization/quantization_gpu.h index 90bfc083b39d..f865599595d3 100644 --- a/torch/csrc/distributed/c10d/quantization/quantization_gpu.h +++ b/torch/csrc/distributed/c10d/quantization/quantization_gpu.h @@ -8,15 +8,9 @@ #include #include -namespace torch { -namespace distributed { -namespace c10d { -namespace quantization { +namespace torch::distributed::c10d::quantization { at::Tensor _float_to_bfloat16_cuda(const at::Tensor& input); at::Tensor _bfloat16_to_float_cuda(const at::Tensor& input); -} // namespace quantization -} // namespace c10d -} // namespace distributed -} // namespace torch +} // namespace torch::distributed::c10d::quantization diff --git a/torch/csrc/distributed/c10d/sequence_num.cpp b/torch/csrc/distributed/c10d/sequence_num.cpp index 6ea35820179e..fd76247199f6 100644 --- a/torch/csrc/distributed/c10d/sequence_num.cpp +++ b/torch/csrc/distributed/c10d/sequence_num.cpp @@ -1,11 +1,10 @@ #include -#include #include #include namespace c10d { -SequenceNum::SequenceNum() : num_(c10::nullopt) {} +SequenceNum::SequenceNum() = default; SequenceNum::SequenceNum(const uint64_t num) : num_(num) {} diff --git a/torch/csrc/distributed/c10d/socket.cpp b/torch/csrc/distributed/c10d/socket.cpp index 093a47a076b0..6cbaa018762e 100644 --- a/torch/csrc/distributed/c10d/socket.cpp +++ b/torch/csrc/distributed/c10d/socket.cpp @@ -670,7 +670,7 @@ class SocketConnectOp { static const std::chrono::seconds delay_duration_; - enum class ConnectResult { Success, Error, Retry }; + enum class ConnectResult : uint8_t { Success, Error, Retry }; public: SocketConnectOp( From ea5c17de9050db2fc97bd38d209c740631b85ab7 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 30 May 2024 16:23:06 +0000 Subject: [PATCH 101/706] Revert "Add torchao nightly testing workflow (#126885)" This reverts commit d938170314fa89acaad6b06fbbaac6b98f1e618f. Reverted https://github.com/pytorch/pytorch/pull/126885 on behalf of https://github.com/atalman due to Broke inductor periodic test ([comment](https://github.com/pytorch/pytorch/pull/126885#issuecomment-2140139486)) --- .ci/pytorch/common_utils.sh | 11 ---- .ci/pytorch/test.sh | 106 +--------------------------------- .github/pytorch-probot.yml | 1 - .github/workflows/torchao.yml | 85 --------------------------- 4 files changed, 3 insertions(+), 200 deletions(-) delete mode 100644 .github/workflows/torchao.yml diff --git a/.ci/pytorch/common_utils.sh b/.ci/pytorch/common_utils.sh index 71e98cfaa721..51297f7bfff8 100644 --- a/.ci/pytorch/common_utils.sh +++ b/.ci/pytorch/common_utils.sh @@ -158,17 +158,6 @@ function install_torchvision() { fi } -function install_torchao() { - # Set ARCH list so that we can build fp16 with SM75+, the logic is copied from - # pytorch/builder - # https://github.com/pytorch/ao/blob/main/packaging/env_var_script_linux.sh#L16C1-L19 - TORCH_CUDA_ARCH_LIST="8.0;8.6" - if [[ ${CU_VERSION:-} == "cu124" ]]; then - TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};9.0" - fi - pip_install --no-use-pep517 --user "git+https://github.com/pytorch/ao.git" -} - function install_tlparse() { pip_install --user "tlparse==0.3.7" PATH="$(python -m site --user-base)/bin:$PATH" diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 0dcf5cd0b388..6a9c81fb79dc 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -483,89 +483,6 @@ test_perf_for_dashboard() { done } -test_torchao_perf_for_dashboard() { - TEST_REPORTS_DIR=$(pwd)/test/test-reports - mkdir -p "$TEST_REPORTS_DIR" - - local suite="$1" - shift - - local backend=torchao - local modes=() - if [[ "$DASHBOARD_TAG" == *training-true* ]]; then - modes+=(training) - fi - if [[ "$DASHBOARD_TAG" == *inference-true* ]]; then - modes+=(inference) - fi - # TODO: All the accuracy tests can be skipped once the CI accuracy checking is stable enough - local targets=(accuracy performance) - local torchao_backends=(noquant int8dynamic int8weightonly int4weightonly autoquant) - - for mode in "${modes[@]}"; do - if [[ "$mode" == "inference" ]]; then - dtype=bfloat16 - elif [[ "$mode" == "training" ]]; then - dtype=amp - fi - for target in "${targets[@]}"; do - local target_flag=("--${target}") - if [[ "$target" == "performance" ]]; then - target_flag+=( --cold-start-latency) - elif [[ "$target" == "accuracy" ]]; then - target_flag+=( --no-translation-validation) - fi - - for torchao_backend in "${torchao_backends[@]}"; do - if [[ "$DASHBOARD_TAG" == *${torchao_backend}-true* ]]; then - python "benchmarks/dynamo/$suite.py" \ - "${target_flag[@]}" --"$mode" --"$dtype" --quantization "${torchao_backend}" "$@" \ - --output "$TEST_REPORTS_DIR/${backend}_${torchao_backend}_${suite}_${dtype}_${mode}_cuda_${target}.csv" - fi - done - done - done -} - -test_single_torchao_benchmark() { - # Usage: test_single_torchao_benchmark huggingface 0 --args-for-script - - # Use test-reports directory under test folder will allow the CI to automatically pick up - # the test reports and upload them to S3. Need to use full path here otherwise the script - # will bark about file not found later on - TEST_REPORTS_DIR=$(pwd)/test/test-reports - mkdir -p "$TEST_REPORTS_DIR" - - local name="$1" - shift - local suite="$1" - shift - # shard id is mandatory, even if it is not passed - local shard_id="$1" - shift - - local partition_flags=() - if [[ -n "$NUM_TEST_SHARDS" && -n "$shard_id" ]]; then - partition_flags=( --total-partitions "$NUM_TEST_SHARDS" --partition-id "$shard_id" ) - fi - - test_torchao_perf_for_dashboard "$suite" \ - "${TORCHAO_BENCHMARK_FLAGS[@]}" "$@" "${partition_flags[@]}" - -} - -test_torchao_benchmark() { - # Usage: test_torchao_benchmark huggingface 0 - TEST_REPORTS_DIR=$(pwd)/test/test-reports - - local suite="$1" - shift - local shard_id="$1" - shift - - test_single_torchao_benchmark "inference" "$suite" "$shard_id" --inference --bfloat16 "$@" -} - test_single_dynamo_benchmark() { # Usage: test_single_dynamo_benchmark inductor_inference huggingface 0 --args-for-script @@ -1321,15 +1238,15 @@ elif [[ "${TEST_CONFIG}" == *inductor_distributed* ]]; then test_inductor_distributed elif [[ "${TEST_CONFIG}" == *inductor-micro-benchmark* ]]; then test_inductor_micro_benchmark -elif [[ "${TEST_CONFIG}" == *inductor_huggingface* ]]; then +elif [[ "${TEST_CONFIG}" == *huggingface* ]]; then install_torchvision id=$((SHARD_NUMBER-1)) test_dynamo_benchmark huggingface "$id" -elif [[ "${TEST_CONFIG}" == *inductor_timm* ]]; then +elif [[ "${TEST_CONFIG}" == *timm* ]]; then install_torchvision id=$((SHARD_NUMBER-1)) test_dynamo_benchmark timm_models "$id" -elif [[ "${TEST_CONFIG}" == *inductor_torchbench* ]]; then +elif [[ "${TEST_CONFIG}" == *torchbench* ]]; then if [[ "${TEST_CONFIG}" == *cpu_inductor* ]]; then install_torchaudio cpu else @@ -1360,23 +1277,6 @@ elif [[ "${TEST_CONFIG}" == *inductor_torchbench* ]]; then fi PYTHONPATH=$(pwd)/torchbench test_dynamo_benchmark torchbench "$id" fi -elif [[ "${TEST_CONFIG}" == *torchao_huggingface* ]]; then - install_torchao - install_torchvision - id=$((SHARD_NUMBER-1)) - test_torchao_benchmark huggingface "$id" -elif [[ "${TEST_CONFIG}" == *torchao_timm* ]]; then - install_torchao - install_torchvision - id=$((SHARD_NUMBER-1)) - test_torchao_benchmark timm_models "$id" -elif [[ "${TEST_CONFIG}" == *torchao_torchbench* ]]; then - install_torchao - install_torchaudio cuda - install_torchvision - id=$((SHARD_NUMBER-1)) - checkout_install_torchbench - PYTHONPATH=$(pwd)/torchbench test_torchao_benchmark torchbench "$id" elif [[ "${TEST_CONFIG}" == *inductor_cpp_wrapper_abi_compatible* ]]; then install_torchvision test_inductor_cpp_wrapper_abi_compatible diff --git a/.github/pytorch-probot.yml b/.github/pytorch-probot.yml index ab5cb0deba87..d54346f81650 100644 --- a/.github/pytorch-probot.yml +++ b/.github/pytorch-probot.yml @@ -18,7 +18,6 @@ ciflow_push_tags: - ciflow/unstable - ciflow/xpu - ciflow/torchbench -- ciflow/torchao retryable_workflows: - pull - trunk diff --git a/.github/workflows/torchao.yml b/.github/workflows/torchao.yml deleted file mode 100644 index 0854eb099e92..000000000000 --- a/.github/workflows/torchao.yml +++ /dev/null @@ -1,85 +0,0 @@ -name: torchao - -on: - push: - tags: - - ciflow/torchao/* - workflow_dispatch: - inputs: - noquant: - description: Run noquant? - required: false - type: boolean - default: true - int8dynamic: - description: Run int8dynamic? - required: false - type: boolean - default: true - int8weightonly: - description: Run int8weightonly? - required: false - type: boolean - default: true - int4weightonly: - description: Run int4weightonly? - required: false - type: boolean - default: true - autoquant: - description: Run autoquant? - required: false - type: boolean - default: true - benchmark_configs: - description: The list of configs used the benchmark - required: false - type: string - default: torchao_huggingface_perf,torchao_timm_perf,torchao_torchbench_perf - -concurrency: - group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} - cancel-in-progress: true - -permissions: read-all - -jobs: - linux-focal-cuda12_1-py3_10-gcc9-torchao-build: - name: cuda12.1-py3.10-gcc9-sm80 - uses: ./.github/workflows/_linux-build.yml - with: - build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm80 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9-inductor-benchmarks - cuda-arch-list: '8.0' - test-matrix: | - { include: [ - { config: "torchao_huggingface_perf", shard: 1, num_shards: 3, runner: "linux.gcp.a100.large" }, - { config: "torchao_huggingface_perf", shard: 2, num_shards: 3, runner: "linux.gcp.a100.large" }, - { config: "torchao_huggingface_perf", shard: 3, num_shards: 3, runner: "linux.gcp.a100.large" }, - { config: "torchao_timm_perf", shard: 1, num_shards: 5, runner: "linux.gcp.a100.large" }, - { config: "torchao_timm_perf", shard: 2, num_shards: 5, runner: "linux.gcp.a100.large" }, - { config: "torchao_timm_perf", shard: 3, num_shards: 5, runner: "linux.gcp.a100.large" }, - { config: "torchao_timm_perf", shard: 4, num_shards: 5, runner: "linux.gcp.a100.large" }, - { config: "torchao_timm_perf", shard: 5, num_shards: 5, runner: "linux.gcp.a100.large" }, - { config: "torchao_torchbench_perf", shard: 1, num_shards: 4, runner: "linux.gcp.a100.large" }, - { config: "torchao_torchbench_perf", shard: 2, num_shards: 4, runner: "linux.gcp.a100.large" }, - { config: "torchao_torchbench_perf", shard: 3, num_shards: 4, runner: "linux.gcp.a100.large" }, - { config: "torchao_torchbench_perf", shard: 4, num_shards: 4, runner: "linux.gcp.a100.large" }, - ]} - selected-test-configs: ${{ inputs.benchmark_configs }} - secrets: - HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} - - linux-focal-cuda12_1-py3_10-gcc9-torchao-test: - name: cuda12.1-py3.10-gcc9-sm80 - uses: ./.github/workflows/_linux-test.yml - needs: linux-focal-cuda12_1-py3_10-gcc9-torchao-build - with: - build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm80 - dashboard-tag: noquant-${{ inputs.noquant }}-int8dynamic-${{ inputs.int8dynamic }}-int8weightonly-${{ inputs.int8weightonly }}-int4weightonly-${{ inputs.int4weightonly }}-autoquant-${{ inputs.autoquant }} - docker-image: ${{ needs.linux-focal-cuda12_1-py3_10-gcc9-torchao-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-cuda12_1-py3_10-gcc9-torchao-build.outputs.test-matrix }} - use-gha: anything-non-empty-to-use-gha - timeout-minutes: 720 - secrets: - HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} From b506d37331a6c58c846ad6c822fa3c6d2caf62cf Mon Sep 17 00:00:00 2001 From: Daniil Kutz Date: Thu, 30 May 2024 16:25:02 +0000 Subject: [PATCH 102/706] Fix multiple errors while parsing NativeFunctions from YAML (#127413) Fixing multiple errors in parse_native_yaml when loading NativeFunctions from Yaml file. Add assertions that validates parsed data. Fixes #127404, #127405, #127406, #127407, #127408, #127409, #127410, #127411 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127413 Approved by: https://github.com/ezyang --- torchgen/gen.py | 8 +++++++- torchgen/model.py | 17 ++++++++++++++++- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/torchgen/gen.py b/torchgen/gen.py index d715361146ea..a1c1a8f957f3 100644 --- a/torchgen/gen.py +++ b/torchgen/gen.py @@ -165,9 +165,11 @@ def parse_native_yaml_struct( rs: List[NativeFunction] = [] bs: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] = defaultdict(dict) for e in es: + assert isinstance(e, dict), f"expected to be dict: {e}" assert isinstance(e.get("__line__"), int), e loc = Location(path, e["__line__"]) funcs = e.get("func") + assert funcs is not None, f"missed 'func' in {e}" with context(lambda: f"in {loc}:\n {funcs}"): func, m = NativeFunction.from_yaml(e, loc, valid_tags, ignore_keys) rs.append(func) @@ -268,7 +270,11 @@ def error_check_native_functions(funcs: Sequence[NativeFunction]) -> None: base_func_map[f.func.name.name].append(f) for f in funcs: if f.structured_delegate is not None: - delegate_func = func_map[f.structured_delegate] + delegate_func = func_map.get(f.structured_delegate) + assert delegate_func is not None, ( + f"{f.func.name} is marked as a structured_delegate pointing to " + f"{f.structured_delegate}, but {f.structured_delegate} is missing." + ) assert delegate_func.structured, ( f"{f.func.name} is marked as a structured_delegate pointing to " f"{f.structured_delegate}, but {f.structured_delegate} is not marked as structured. " diff --git a/torchgen/model.py b/torchgen/model.py index 2706f234c56b..bed8f262f592 100644 --- a/torchgen/model.py +++ b/torchgen/model.py @@ -626,6 +626,9 @@ def from_yaml( assert device_check_s is None or isinstance( device_check_s, str ), f"not a str: {device_check_s}" + assert ( + device_check_s is None or device_check_s in DeviceCheckType.__members__ + ), f"illegal device_check: {device_check_s}" device_check: DeviceCheckType if device_check_s is None: device_check = DeviceCheckType.ExactSame @@ -706,7 +709,12 @@ def from_yaml( for ks, v in raw_dispatch.items(): if ks == "__line__": continue # not worth tracking line numbers for dispatch entries - assert isinstance(ks, str), e + assert isinstance( + ks, str + ), f"illegal dispatch key '{ks}' in {raw_dispatch}" + assert isinstance( + v, str + ), f"illegal dispatch value '{v}' in {raw_dispatch}" for k in ks.split(","): dispatch_key = DispatchKey.parse(k.strip()) num_dispatch_keys += 1 @@ -2006,8 +2014,12 @@ def alias_info(self) -> Optional[Annotation]: def parse(arg: str) -> "Argument": name: str default: Optional[str] + assert " " in arg, f"illegal argument '{arg}'" type_and_annot, name_and_default = arg.rsplit(" ", 1) if "=" in name_and_default: + assert ( + name_and_default.count("=") == 1 + ), f"illegal argument with default value: '{name_and_default}'" name, default = name_and_default.split("=") else: name = name_and_default @@ -2792,6 +2804,9 @@ def parse(src: object) -> "Precompute": ) arg, with_list_raw = raw_replace_item.split(" -> ") + assert ( + " " not in arg + ), f"illegal kernel param name '{arg}' in precomputed parameters'" with_list = with_list_raw.split(",") with_list_args = [Argument.parse(name.strip()) for name in with_list] replace[arg] = with_list_args From 8777443d73e1db5d83d3962587a03e0a544b1144 Mon Sep 17 00:00:00 2001 From: cyy Date: Thu, 30 May 2024 16:26:33 +0000 Subject: [PATCH 103/706] Remove FindMatlabMex.cmake (#127414) It is not used anymore. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127414 Approved by: https://github.com/ezyang --- cmake/Modules/FindMatlabMex.cmake | 48 ------------------------------- 1 file changed, 48 deletions(-) delete mode 100644 cmake/Modules/FindMatlabMex.cmake diff --git a/cmake/Modules/FindMatlabMex.cmake b/cmake/Modules/FindMatlabMex.cmake deleted file mode 100644 index 28ae65e7cbba..000000000000 --- a/cmake/Modules/FindMatlabMex.cmake +++ /dev/null @@ -1,48 +0,0 @@ -# This module looks for MatlabMex compiler -# Defines variables: -# Matlab_DIR - Matlab root dir -# Matlab_mex - path to mex compiler -# Matlab_mexext - path to mexext - -if(MSVC) - foreach(__ver "9.30" "7.14" "7.11" "7.10" "7.9" "7.8" "7.7") - get_filename_component(__matlab_root "[HKEY_LOCAL_MACHINE\\SOFTWARE\\MathWorks\\MATLAB\\${__ver};MATLABROOT]" ABSOLUTE) - if(__matlab_root) - break() - endif() - endforeach() -endif() - -if(APPLE) - foreach(__ver "R2014b" "R2014a" "R2013b" "R2013a" "R2012b" "R2012a" "R2011b" "R2011a" "R2010b" "R2010a") - if(EXISTS /Applications/MATLAB_${__ver}.app) - set(__matlab_root /Applications/MATLAB_${__ver}.app) - break() - endif() - endforeach() -endif() - -if(UNIX) - execute_process(COMMAND which matlab OUTPUT_STRIP_TRAILING_WHITESPACE - OUTPUT_VARIABLE __out RESULT_VARIABLE __res) - - if(__res MATCHES 0) # Suppress `readlink` warning if `which` returned nothing - execute_process(COMMAND which matlab COMMAND xargs readlink - COMMAND xargs dirname COMMAND xargs dirname COMMAND xargs echo -n - OUTPUT_VARIABLE __matlab_root OUTPUT_STRIP_TRAILING_WHITESPACE) - endif() -endif() - - -find_path(Matlab_DIR NAMES bin/mex bin/mexext PATHS ${__matlab_root} - DOC "Matlab directory" NO_DEFAULT_PATH) - -find_program(Matlab_mex NAMES mex mex.bat HINTS ${Matlab_DIR} PATH_SUFFIXES bin NO_DEFAULT_PATH) -find_program(Matlab_mexext NAMES mexext mexext.bat HINTS ${Matlab_DIR} PATH_SUFFIXES bin NO_DEFAULT_PATH) - -include(FindPackageHandleStandardArgs) -find_package_handle_standard_args(MatlabMex DEFAULT_MSG Matlab_mex Matlab_mexext) - -if(MATLABMEX_FOUND) - mark_as_advanced(Matlab_mex Matlab_mexext) -endif() From e9a6bbbf7c60583fb2fba132c15122bb12c728ec Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 30 May 2024 17:01:02 +0000 Subject: [PATCH 104/706] Revert "[CI] add xpu test in periodic workflow (#126410)" This reverts commit 30d98611a3a35287c47ded9647f0b4c81fbdf036. Reverted https://github.com/pytorch/pytorch/pull/126410 on behalf of https://github.com/malfet due to Let's sync up on the test strategy/policies here ([comment](https://github.com/pytorch/pytorch/pull/126410#issuecomment-2140269549)) --- .github/workflows/periodic.yml | 27 --------------------------- 1 file changed, 27 deletions(-) diff --git a/.github/workflows/periodic.yml b/.github/workflows/periodic.yml index cf2da41c1c44..716a72cc6d23 100644 --- a/.github/workflows/periodic.yml +++ b/.github/workflows/periodic.yml @@ -242,30 +242,3 @@ jobs: build-environment: linux-focal-rocm6.1-py3.8 docker-image: ${{ needs.linux-focal-rocm6_1-py3_8-build.outputs.docker-image }} test-matrix: ${{ needs.linux-focal-rocm6_1-py3_8-build.outputs.test-matrix }} - - linux-jammy-xpu-py3_8-build: - name: linux-jammy-xpu-py3.8 - uses: ./.github/workflows/_linux-build.yml - with: - build-environment: linux-jammy-xpu-py3.8 - docker-image-name: pytorch-linux-jammy-xpu-2024.0-py3 - runner: linux.2xlarge - test-matrix: | - { include: [ - { config: "default", shard: 1, num_shards: 4, runner: "linux.idc.xpu" }, - { config: "default", shard: 2, num_shards: 4, runner: "linux.idc.xpu" }, - { config: "default", shard: 3, num_shards: 4, runner: "linux.idc.xpu" }, - { config: "default", shard: 4, num_shards: 4, runner: "linux.idc.xpu" }, - ]} - - linux-jammy-xpu-py3_8-test: - name: linux-jammy-xpu-py3.8 - uses: ./.github/workflows/_xpu-test.yml - needs: linux-jammy-xpu-py3_8-build - permissions: - id-token: write - contents: read - with: - build-environment: linux-jammy-xpu-py3.8 - docker-image: ${{ needs.linux-jammy-xpu-py3_8-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-xpu-py3_8-build.outputs.test-matrix }} From 12d6446507df794e5f1f563250bbbd8bbd08044b Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 30 May 2024 17:18:23 +0000 Subject: [PATCH 105/706] Revert "[inductor] fix mkldnn linear binary fusion check ut (#127296)" This reverts commit cdeb242fc977210e211fd77b217320205c9f4042. Reverted https://github.com/pytorch/pytorch/pull/127296 on behalf of https://github.com/huydhn due to Sorry for reverting you change but one of the tests is failing on trunk ROCm. Please help fix and reland the change https://github.com/pytorch/pytorch/actions/runs/9302535020/job/25606932572 ([comment](https://github.com/pytorch/pytorch/pull/127296#issuecomment-2140334323)) --- test/inductor/test_mkldnn_pattern_matcher.py | 72 +++----------------- torch/_inductor/fx_passes/mkldnn_fusion.py | 14 ++-- 2 files changed, 13 insertions(+), 73 deletions(-) diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py index 92a9dd59a5dc..756de35df84c 100644 --- a/test/inductor/test_mkldnn_pattern_matcher.py +++ b/test/inductor/test_mkldnn_pattern_matcher.py @@ -10,7 +10,7 @@ from torch._dynamo import config as dynamo_config from torch._dynamo.utils import counters from torch._export import capture_pre_autograd_graph -from torch._inductor import config, metrics +from torch._inductor import config from torch._inductor.test_case import run_tests, TestCase from torch._inductor.utils import run_and_get_code from torch.ao.quantization.quantize_pt2e import ( @@ -264,7 +264,6 @@ def forward(self, x): memory_format, dtype, ) in options: - metrics.reset() if dim == 4: x_shape = (1, 3, 56, 56) else: @@ -285,18 +284,6 @@ def forward(self, x): # Has extra dtype conversion nodes for autocast. match_nodes += 2 self._test_common(mod, (v,), 2, match_nodes, check_autocast=dtype) - generated_kernel_count = 0 - if dtype != torch.float32: - # "to_dtype" for input - generated_kernel_count = 1 - if memory_format == torch.contiguous_format: - # "to_dtype + to_channel_last" for input, "to_contiguous" for output - generated_kernel_count = 2 - if memory_format == torch.channels_last_3d: - # for float conv3d, the output for eager is channel last, we will generate "to_contiguous" for output - # for lp conv3d, the output for eager is channel last too, we will only generate "to_dtype" - generated_kernel_count = 1 - self.assertEqual(metrics.generated_kernel_count, generated_kernel_count) def test_conv2d_unary_cpu(self): self._test_conv_unary_cpu_base(dim=4) @@ -334,7 +321,6 @@ def forward(self, x): dtypes.append(torch.float16) options = itertools.product(unary_list, [True, False], dtypes) for unary_fn, bias, dtype in options: - metrics.reset() mod = M(unary_fn, 10, 30, bias=bias).eval() # only fuse for linear when the dtype is bf16 mod = mod @@ -349,8 +335,6 @@ def forward(self, x): self._test_common( mod, (v,), matcher_count, matcher_nodes, check_autocast=dtype ) - # only generated 1 kernel for "to" - self.assertEqual(metrics.generated_kernel_count, 1) @unittest.skipIf(not TEST_MKL, "Test requires MKL") def test_linear_fp32(self): @@ -402,7 +386,6 @@ def forward(self, x): ) for unary_fn, memory_format, dtype in options: - metrics.reset() x_shape = (1, 3, 28, 28) mod = M(unary_fn).eval() @@ -418,14 +401,6 @@ def forward(self, x): # Has extra dtype conversion nodes for autocast. match_nodes += 2 self._test_common(mod, (v,), 2, match_nodes, check_autocast=dtype) - generated_kernel_count = 0 - if dtype != torch.float32: - # "to" for input - generated_kernel_count = 1 - if memory_format == torch.contiguous_format: - # "to_dtype + to_channel_last" for input, "to_contiguous" for output - generated_kernel_count = 2 - self.assertEqual(metrics.generated_kernel_count, generated_kernel_count) def _test_conv_binary_base(self, dim=4): assert dim == 4 or dim == 5 @@ -455,29 +430,19 @@ def forward(self, x): else: return self.binary_fn(x1, x2) - dtypes = [ - torch.float, - ] - if torch.ops.mkldnn._is_mkldnn_bf16_supported(): - dtypes.append(torch.bfloat16) - if torch.ops.mkldnn._is_mkldnn_fp16_supported(): - dtypes.append(torch.float16) cl_format = torch.channels_last if dim == 4 else torch.channels_last_3d test_memory_format = [torch.contiguous_format, cl_format] options = itertools.product( binary_list, [True, False], test_memory_format, - dtypes, ) for ( binary_fn, has_relu, memory_format, - dtype, ) in options: - metrics.reset() if dim == 4: x_shape = (1, 3, 56, 56) else: @@ -492,19 +457,7 @@ def forward(self, x): match_nodes = binary_list[binary_fn][1] if has_relu: match_nodes += 1 - self._test_common( - mod, (v,), match_count, match_nodes + 2, check_autocast=dtype - ) - generated_kernel_count = 0 - if dtype != torch.float32: - # "to_dtype" for input - generated_kernel_count = 1 - if memory_format == torch.contiguous_format: - # "to_dtype + to_channel_last" for input, "to_contiguous" for output - generated_kernel_count = 2 - elif memory_format == torch.channels_last_3d: - generated_kernel_count = 1 - self.assertEqual(metrics.generated_kernel_count, generated_kernel_count) + self._test_common(mod, (v,), match_count, match_nodes + 2) def test_conv2d_binary(self): self._test_conv_binary_base(dim=4) @@ -536,7 +489,7 @@ def forward(self, x, y): ) out_feature = 30 for binary_fn, input_shape, bias, dtype in options: - metrics.reset() + torch._dynamo.reset() # addmm(mm) + (linear+add) match_count = 2 match_nodes = 3 @@ -545,20 +498,13 @@ def forward(self, x, y): # view + linear + view(joint_graph+freeze pass) match_count = match_count + 5 if is_inplace else match_count + 3 match_nodes = match_nodes + 7 if is_inplace else match_nodes + 5 - mod = M(binary_fn, input_shape[-1], out_feature, bias).eval() - v = torch.randn(input_shape) + mod = M(binary_fn, input_shape[-1], out_feature, bias).to(dtype).eval() + v = torch.randn(input_shape).to(dtype) other = torch.randn(input_shape[:-1] + [out_feature]).to(dtype) - self._test_common( - mod, - ( - v, - other, - ), - match_count, - match_nodes, - check_autocast=dtype, - ) - self.assertEqual(metrics.generated_kernel_count, 1) + mod_c = torch.compile(mod) + out, code = run_and_get_code(mod_c, v, other) + self.assertEqual(out, mod(v, other), rtol=1e-2, atol=1e-2) + # TODO - assert fusions work code def test_multi_linear_share_same_input(self): # llama pattern. diff --git a/torch/_inductor/fx_passes/mkldnn_fusion.py b/torch/_inductor/fx_passes/mkldnn_fusion.py index 5d1a723fa58a..3edb4a397932 100644 --- a/torch/_inductor/fx_passes/mkldnn_fusion.py +++ b/torch/_inductor/fx_passes/mkldnn_fusion.py @@ -197,15 +197,9 @@ def _binary_fusion_v1(computation_call, binary_fn): def _binary_fusion_v2(computation_call, binary_fn): return CallFunction(binary_fn, computation_call, KeywordArg("other")) - def _is_single_computation_op(computation_op, lowp_dtype=None): + def _is_single_computation_op(computation_op): def fn(match): computation_nodes = filter_nodes(match.nodes, computation_op) - - if lowp_dtype: - output_node_meta = match.output_node().meta.get("val") - if output_node_meta.dtype != lowp_dtype: - return False - if len(computation_nodes) < 1: return False if any(n.args[-3] != "none" for n in computation_nodes): @@ -216,7 +210,7 @@ def fn(match): def _is_valid_computation_unary_fusion(computation_op, lowp_dtype=None): def fn(match): - matched = _is_single_computation_op(computation_op, lowp_dtype)(match) + matched = _is_single_computation_op(computation_op)(match) computation_node = filter_nodes(match.nodes, computation_op)[0] if lowp_dtype: conversion_dtype_nodes = filter_nodes( @@ -255,7 +249,7 @@ def fn(match, *args, **kwargs): def _register_leaky_relu_fusion_lowering(pattern, computation_op, lowp_dtype=None): @register_lowering_pattern( - pattern, extra_check=_is_single_computation_op(computation_op, lowp_dtype) + pattern, extra_check=_is_single_computation_op(computation_op) ) def fn(match, *args, **kwargs): negative_slope = kwargs.get("negative_slope") @@ -297,7 +291,7 @@ def fn(match, *args, **kwargs): def _register_hardtanh_fusion_lowering(pattern, computation_op, lowp_dtype=None): @register_lowering_pattern( - pattern, extra_check=_is_single_computation_op(computation_op, lowp_dtype) + pattern, extra_check=_is_single_computation_op(computation_op) ) def fn(match, *args, **kwargs): min_value = kwargs.get("min_value") From f9937afd4f87fbb4844642ae2f587b13b5caa08c Mon Sep 17 00:00:00 2001 From: James Wu Date: Thu, 30 May 2024 10:33:38 -0700 Subject: [PATCH 106/706] Add noqa to prevent lint warnings (#127545) This is to prevent the import from being removed due to unused import. What's annoying about this is that it's not consistently running: lintrunner doesn't warn me on this PR even without the comment, but it does on other PRs Pull Request resolved: https://github.com/pytorch/pytorch/pull/127545 Approved by: https://github.com/masnesral --- test/inductor/test_halide.py | 2 +- test/inductor/test_kernel_benchmark.py | 2 +- test/inductor/test_triton_wrapper.py | 2 +- torch/_inductor/autotune_process.py | 2 +- torch/_inductor/codegen/cpp_wrapper_cpu.py | 2 +- torch/_inductor/compile_fx.py | 2 +- torch/_inductor/scheduler.py | 2 +- torch/_inductor/select_algorithm.py | 2 +- torch/testing/_internal/inductor_utils.py | 2 +- 9 files changed, 9 insertions(+), 9 deletions(-) diff --git a/test/inductor/test_halide.py b/test/inductor/test_halide.py index 158a669dad2f..9b923bd1981d 100644 --- a/test/inductor/test_halide.py +++ b/test/inductor/test_halide.py @@ -3,7 +3,7 @@ import unittest import torch -import torch._inductor.async_compile +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools from torch._inductor.codecache import HalideCodeCache from torch._inductor.runtime.hints import HalideInputSpec, HalideMeta from torch._inductor.test_case import run_tests, TestCase diff --git a/test/inductor/test_kernel_benchmark.py b/test/inductor/test_kernel_benchmark.py index 87ddb0bec2e6..ffe0300d8aad 100644 --- a/test/inductor/test_kernel_benchmark.py +++ b/test/inductor/test_kernel_benchmark.py @@ -6,7 +6,7 @@ from unittest.mock import patch import torch -import torch._inductor.async_compile +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools from torch._dynamo.testing import rand_strided from torch._inductor import config from torch._inductor.codecache import PyCodeCache diff --git a/test/inductor/test_triton_wrapper.py b/test/inductor/test_triton_wrapper.py index 7f7ded46182a..f0d3ad829d45 100644 --- a/test/inductor/test_triton_wrapper.py +++ b/test/inductor/test_triton_wrapper.py @@ -4,7 +4,7 @@ import sys import torch -import torch._inductor.async_compile +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools from torch._inductor.codecache import PyCodeCache from torch._inductor.test_case import run_tests, TestCase from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU diff --git a/torch/_inductor/autotune_process.py b/torch/_inductor/autotune_process.py index 5e04211639d6..c9462d788e8d 100644 --- a/torch/_inductor/autotune_process.py +++ b/torch/_inductor/autotune_process.py @@ -25,7 +25,7 @@ ) import torch -import torch._inductor.async_compile +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools from torch import multiprocessing from torch._dynamo.testing import rand_strided diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index cefd7e96acce..90a0702d4b8a 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -10,7 +10,7 @@ import torch -import torch._inductor.async_compile +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools import torch._ops from torch.fx.experimental.symbolic_shapes import ConvertIntKey, DivideByKey from .. import config, ir diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index 26d75669a206..c9a253709f40 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -11,7 +11,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union from unittest import mock -import torch._inductor.async_compile +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools import torch.fx import torch.utils._pytree as pytree diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 5a13c7f3cae4..a7517575d888 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -28,7 +28,7 @@ import sympy import torch -import torch._inductor.async_compile +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools from torch._dynamo.utils import counters, dynamo_timed from torch._inductor.metrics import get_metric_table, is_metric_table_enabled from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 124c6aea6125..13bfdcb60fda 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -22,7 +22,7 @@ from filelock import FileLock import torch -import torch._inductor.async_compile +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools from torch._dynamo.testing import rand_strided from torch._dynamo.utils import counters, identity, preserve_rng_state diff --git a/torch/testing/_internal/inductor_utils.py b/torch/testing/_internal/inductor_utils.py index d441988d4bd2..1078a189f69c 100644 --- a/torch/testing/_internal/inductor_utils.py +++ b/torch/testing/_internal/inductor_utils.py @@ -5,7 +5,7 @@ import unittest import functools from subprocess import CalledProcessError -import torch._inductor.async_compile +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools from torch._inductor.codecache import CppCodeCache from torch.utils._triton import has_triton from torch.testing._internal.common_utils import ( From 7827afca1440c897100ee8afeb0da69e3120199e Mon Sep 17 00:00:00 2001 From: Chen Lai Date: Wed, 29 May 2024 21:40:26 -0700 Subject: [PATCH 107/706] Copy the constant folding pass to the pass under export/passes folder (#127456) It's a generic pass and I'm trying to find a good place to host it. It's currently needed by quantization flow. See context in D55930580, it's too much effort to land a fix in the inductor folder. Differential Revision: [D57934182](https://our.internmc.facebook.com/intern/diff/D57934182/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127456 Approved by: https://github.com/angelayi --- torch/_export/passes/constant_folding.py | 297 +++++++++++++++++++++++ torch/ao/quantization/quantize_pt2e.py | 2 +- 2 files changed, 298 insertions(+), 1 deletion(-) create mode 100644 torch/_export/passes/constant_folding.py diff --git a/torch/_export/passes/constant_folding.py b/torch/_export/passes/constant_folding.py new file mode 100644 index 000000000000..54b7a1565924 --- /dev/null +++ b/torch/_export/passes/constant_folding.py @@ -0,0 +1,297 @@ +import collections +from collections import defaultdict +from typing import Any, Callable, Dict, Optional + +import torch +import torch.utils._pytree as pytree + +aten = torch.ops.aten + +# We would like to split modules into two subgraphs for runtime weight updates to work correctly. +# The use case and more information could be found at: +# https://docs.google.com/document/d/1inZC-8KarJ6gKB7G9egmYLx1V_dKX_apxon0w4zPC0Q/edit?usp=sharing +META_TAG = "MODULE_TYPE" +MODULE_TAG = "_MAIN_MODULE" +CONST_MODULE_TAG = "_CONST_MODULE" + + +def replace_node_with_constant(gm, node, constant, name=None): + g = gm.graph + + if name: + qualname = name + else: + if not hasattr(gm, "_frozen_param_count"): + gm._frozen_param_count = 0 + i = gm._frozen_param_count + + while True: + qualname = f"_frozen_param{i}" + if not hasattr(gm, qualname): + break + i += 1 + + gm._frozen_param_count = i + 1 + + with g.inserting_before(node): + new_input_node = g.create_node("get_attr", qualname, (), {}) + node.replace_all_uses_with(new_input_node) + new_input_node.meta.update(node.meta) + g.erase_node(node) + + # needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning + gm.register_buffer(qualname, constant) + setattr(gm, qualname, constant) + + +class ConstantFolder(torch.fx.Interpreter): + def __init__( + self, + gm, + skip_constructors=False, + ): + super().__init__(gm) + self.node_replacements: Dict[torch.fx.Node, Any] = {} + self.replaced_uses: Dict[torch.fx.Node, int] = collections.Counter() + self.unknown_value = object() + self.skip_constructors: bool = skip_constructors + + # overwrite this to deallocate env values if their only remaining use + # is the output + self.user_to_last_uses = self.node_to_last_non_output_use() + + def is_impure(self, node: torch.fx.node.Node): + if ( + node.target == torch.ops.prims.convert_element_type.default + and node.args[0].op == "get_attr" # type: ignore[union-attr] + and node.args[0].meta["val"].dtype == torch.int8 # type: ignore[union-attr] + and node.args[1] == torch.bfloat16 + ): + # For int8_weight -> dq -> bf16_weight + return True + if node.target in [ + torch.ops.quantized_decomposed.dequantize_per_channel.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + ]: + # For the pattern fp32_weight -> q -> dq + # We only folding fp32_weight -> q + # int8_weight and leave dq in graph to be fused + return True + return False + + def node_to_last_non_output_use(self): + last_non_output_use = collections.defaultdict(list) + seen_uses = set() + output_node = next(iter(reversed(self.module.graph.nodes))) + + for node in reversed(self.module.graph.nodes): + if node.target == "output": + continue + + def add_use(inp): + if inp in seen_uses: + return + + seen_uses.add(inp) + last_non_output_use[node].append(inp) + + # In-place is fine since we don't mutate + pytree.tree_map_only_(torch.fx.Node, add_use, (node.args, node.kwargs)) + + # if this node is only used in output, we want to gc it right away + if len(node.users) == 1 and output_node in node.users: + last_non_output_use[node].append(node) + + return last_non_output_use + + def run_node(self, node): + if node.target == "output": + # because we remove nodes from env on last non output use, + # re-define them now or we'll get error in interpreter + def set_env(arg): + self.env[arg] = self.unknown_value + + # In-place is fine since we don't mutate + pytree.tree_map_only_(torch.fx.Node, set_env, node.args) + return super().run_node(node) + + args, kwargs = self.fetch_args_kwargs_from_env(node) + flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs) + + # We need to do this weird thing because in cases where flattened_inputs + # contains a ScriptObject, equality checking results in a type error if + # the types are different. + if any( + type(self.unknown_value) == type(input_) and self.unknown_value == input_ + for input_ in flattened_inputs + ): + return self.unknown_value + + # TODO - fix errors with this + if ( + node.op == "call_function" + and node.target == aten._efficientzerotensor.default + ): + return self.unknown_value + + # TODO - constant folding triton kernel returns the inputs -- fix this + if ( + node.op == "call_function" + and node.name == "triton_kernel_wrapper_functional_proxy" + ): + return self.unknown_value + + # skip constructors, since inductor generates optimal code for them already + # and turning into tensor would result in an additional global memory read + # TODO - more complicated strategy + if ( + self.skip_constructors + and node.op != "get_attr" + and not any(isinstance(e, torch.Tensor) for e in flattened_inputs) + ): + return self.unknown_value + + # All mutations should either be removed or on inputs which we did not make constant + if ( + isinstance(node.target, torch._ops.OpOverload) + and torch.Tag.nondeterministic_seeded in node.target.tags + ): + return self.unknown_value + + out = super().run_node(node) + + if node.op != "get_attr" and isinstance(out, torch.Tensor): + if out.device.type == "meta": + return out + + if not self.insertable_tensor_check(out): + return out + + if self.is_impure(node): + return self.unknown_value + + self.add_node_replacement(node, out) + + flattened_node_inps = pytree.arg_tree_leaves(*node.args, **node.kwargs) + + for n in flattened_node_inps: + if not isinstance(n, torch.fx.Node): + continue + + self.replaced_uses[n] += 1 + + for to_delete in self.user_to_last_uses.get(node, []): + if self.replaced_uses[to_delete] == len(to_delete.users): + self.node_replacements.pop(to_delete, None) + + return out + + def insertable_tensor_check(self, tensor: torch.Tensor) -> bool: + return True + + def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None: + self.node_replacements[node] = tensor + + def run(self): + env = {} + for n in self.module.graph.find_nodes(op="placeholder"): + env[n] = self.unknown_value + return super().run(initial_env=env) + + +@torch.utils._python_dispatch._disable_current_modes() +def constant_fold(gm, constraint_fn: Optional[Callable[[torch.fx.Node], bool]] = None): + cf = ConstantFolder(gm, skip_constructors=True) + cf.run() + + for node, constant in cf.node_replacements.items(): + if constraint_fn is not None and not constraint_fn(node): + continue + replace_node_with_constant(gm, node, constant) + + erased_params = [] + # Get all attr users by looking up the graph instead from node.users, because in this case + # _tensor_constant0 and _tensor_constant0_1 are actually refereing to the same tensor. + + # opcode name target args kwargs + # ------------- ------------------- ---------------- --------------------------- -------- + # placeholder arg0_1 arg0 () {} + # get_attr _tensor_constant0 state () {} + # call_function add aten.add.Tensor (arg0_1, _tensor_constant0) {} + # get_attr _tensor_constant0_1 state () {} + # call_function add_ aten.add_.Tensor (_tensor_constant0_1, 1) {} + # output output output ([add],) {} + + get_attr_node_users = defaultdict(list) + for node in gm.graph.nodes: + if node.op == "get_attr": + get_attr_node_users[node.target].extend(node.users.keys()) + for node in gm.graph.find_nodes(op="get_attr"): + if node.op == "get_attr" and len(get_attr_node_users[node.target]) == 0: + if hasattr(gm, node.target): + delattr(gm, node.target) + erased_params.append(node) + for node in erased_params: + gm.graph.erase_node(node) + + gm.graph.eliminate_dead_code() + gm.graph.lint() + gm.recompile() + + +@torch.utils._python_dispatch._disable_current_modes() +def constant_graph_tag(gm: torch.fx.GraphModule): + cf = ConstantFolder(gm, skip_constructors=True) + cf.run() + + for node in gm.graph.nodes: + if ( + node.op == "get_attr" + or node in cf.node_replacements + or node in cf.replaced_uses + ): + node.meta[META_TAG] = CONST_MODULE_TAG + else: + node.meta[META_TAG] = MODULE_TAG + + +def run_and_get_constant_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + """ + Construct a GraphModule which corresponds to the part which could be + constant folded in provided gm. + """ + + constant_graph_tag(gm) + # We rewrite the tags, if it's a constant being directly consumed, without + # any folding opportunity, we keep it in main gm. + for node in gm.graph.find_nodes(op="get_attr"): + used_to_fold = False + for u in node.users: + if u.meta[META_TAG] == CONST_MODULE_TAG: + used_to_fold = True + break + if not used_to_fold: + node.meta[META_TAG] = MODULE_TAG + + new_graph = torch.fx.Graph() + + node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {} + output_nodes = [] + for node in gm.graph.nodes: + if node.meta[META_TAG] == MODULE_TAG: + continue + + new_node = new_graph.node_copy(node, lambda x: node_remapping[x]) + node_remapping[node] = new_node + + for user in node.users: + if user.meta[META_TAG] == MODULE_TAG: + output_nodes.append(new_node) + break + + new_graph.output(tuple(output_nodes)) + new_graph.lint() + new_gm = torch.fx.GraphModule(gm, new_graph) + + return new_gm diff --git a/torch/ao/quantization/quantize_pt2e.py b/torch/ao/quantization/quantize_pt2e.py index d9919aa2e9c5..b312d89911a5 100644 --- a/torch/ao/quantization/quantize_pt2e.py +++ b/torch/ao/quantization/quantize_pt2e.py @@ -26,7 +26,7 @@ from torch.fx.passes.infra.pass_manager import PassManager from torch.ao.quantization.pt2e.duplicate_dq_pass import DuplicateDQPass from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ -from torch._inductor.constant_folding import constant_fold +from torch._export.passes.constant_folding import constant_fold __all__ = [ "prepare_pt2e", From 39cf2f8e66d42a4587b2f722ec77c66e5ce1d46d Mon Sep 17 00:00:00 2001 From: Ali Waheed Date: Thu, 30 May 2024 18:13:22 +0000 Subject: [PATCH 108/706] Added sorting notes for eig/eigvals (#127492) Fixes #58034 @lezcano , Added suggested comments for eig and eigvals in the documentation. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127492 Approved by: https://github.com/lezcano, https://github.com/kit1980 --- torch/linalg/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index 29df838bab54..0637f3f7b83c 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -457,6 +457,8 @@ Also supports batches of matrices, and if :attr:`A` is a batch of matrices then the output has the same batch dimensions. +The returned eigenvalues are not guaranteed to be in any specific order. + .. note:: The eigenvalues and eigenvectors of a real matrix may be complex. """ + fr""" @@ -559,6 +561,8 @@ Also supports batches of matrices, and if :attr:`A` is a batch of matrices then the output has the same batch dimensions. +The returned eigenvalues are not guaranteed to be in any specific order. + .. note:: The eigenvalues of a real matrix may be complex, as the roots of a real polynomial may be complex. The eigenvalues of a matrix are always well-defined, even when the matrix is not diagonalizable. From bfdec93395f675a0e5a59e95aef9104ac8f5081a Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Wed, 29 May 2024 17:14:06 -0700 Subject: [PATCH 109/706] Default XLA to use swap_tensors path in nn.Module._apply (#126814) Pull Request resolved: https://github.com/pytorch/pytorch/pull/126814 Approved by: https://github.com/JackCaoG, https://github.com/albanD ghstack dependencies: #127313 --- test/test_nn.py | 4 ++-- torch/nn/modules/module.py | 10 ++++++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/test/test_nn.py b/test/test_nn.py index 6dfac4f7ca1b..6bcb4017e4b5 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -8184,9 +8184,9 @@ def test_batchnorm_large_batch(self, device, dtype): @dtypes(torch.float, torch.double, torch.bfloat16, torch.complex128) def test_conv_empty_input(self, device, dtype): def help(input, conv, memory_format): - ref_out = conv(input) + ref_out = conv(input).detach() conv_cl = conv.to(memory_format=memory_format) - out_cl = conv_cl(input) + out_cl = conv_cl(input).detach() self.assertEqual(ref_out, out_cl) input_cl = input.to(memory_format=memory_format) out_cl2 = conv(input_cl) diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index 58129acd48a3..80f7876f28fd 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -794,6 +794,13 @@ def compute_should_use_set_data(tensor, tensor_applied): should_use_swap_tensors = torch.__future__.get_swap_module_params_on_conversion() + def compute_should_use_swap_tensors(tensor, tensor_applied): + return (should_use_swap_tensors + # subclasses may have multiple child tensors so we need to use swap_tensors + or is_traceable_wrapper_subclass(tensor_applied) + or tensor.device.type == 'xla' + or tensor_applied.device.type == 'xla') + for key, param in self._parameters.items(): if param is None: continue @@ -804,8 +811,7 @@ def compute_should_use_set_data(tensor, tensor_applied): param_applied = fn(param) p_should_use_set_data = compute_should_use_set_data(param, param_applied) - # subclasses may have multiple child tensors so we need to use swap_tensors - p_should_use_swap_tensors = should_use_swap_tensors or is_traceable_wrapper_subclass(param_applied) + p_should_use_swap_tensors = compute_should_use_swap_tensors(param, param_applied) param_grad = param.grad if p_should_use_swap_tensors: From fa426b096b3635daab6ce26b44d50f3baab5a4e5 Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Wed, 29 May 2024 17:14:06 -0700 Subject: [PATCH 110/706] Default meta device to use swap_tensors in nn.Module._apply (.to_empty and .to('meta')) (#126819) Pull Request resolved: https://github.com/pytorch/pytorch/pull/126819 Approved by: https://github.com/albanD ghstack dependencies: #127313, #126814 --- test/test_modules.py | 27 +++++++++++++++++---------- torch/nn/modules/module.py | 2 ++ 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/test/test_modules.py b/test/test_modules.py index e854eec8add7..601cf5cefdf9 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -983,16 +983,23 @@ def test_to_empty(self, device, dtype, module_info, swap, training): p_ids_after = [id(p) for p in m.parameters()] p_cdatas_after = [p._cdata for p in m.parameters()] - if swap: - # id same, ._cdata differs --> swapped cdata of THPVariable - self.assertTrue(all(a == b for a, b in zip(p_ids_before, p_ids_after))) - self.assertTrue(all(a != b for a, b in zip(p_cdatas_before, p_cdatas_after))) - else: - # id and ._cdata differ - # meta and device have different shallow copy types, so this will create a new - # parameter and assign it to the module - self.assertTrue(all(a != b for a, b in zip(p_ids_before, p_ids_after))) - self.assertTrue(all(a != b for a, b in zip(p_cdatas_before, p_cdatas_after))) + # id same, ._cdata differs --> swapped cdata of THPVariable + # Technically, meta and device have different shallow copy types, so when swap=False it will create a new + # parameter and assign it to the module BUT we opt into swap_tensors when either one is on meta. + self.assertTrue(all(a == b for a, b in zip(p_ids_before, p_ids_after))) + self.assertTrue(all(a != b for a, b in zip(p_cdatas_before, p_cdatas_after))) + + # Test the opposite direction device --> meta + m = m.to(device="meta") + + p_ids_after_meta = [id(p) for p in m.parameters()] + p_cdatas_after_meta = [p._cdata for p in m.parameters()] + + # id same, ._cdata differs --> swapped cdata of THPVariable + # Technically, meta and device have different shallow copy types, so when swap=False it will create a new + # parameter and assign it to the module BUT we opt into swap_tensors when either one is on meta. + self.assertTrue(all(a == b for a, b in zip(p_ids_after, p_ids_after_meta))) + self.assertTrue(all(a != b for a, b in zip(p_cdatas_after, p_cdatas_after_meta))) instantiate_device_type_tests(TestModule, globals(), allow_mps=True) diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index 80f7876f28fd..2e65bb97c659 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -798,6 +798,8 @@ def compute_should_use_swap_tensors(tensor, tensor_applied): return (should_use_swap_tensors # subclasses may have multiple child tensors so we need to use swap_tensors or is_traceable_wrapper_subclass(tensor_applied) + or tensor.device.type == 'meta' + or tensor_applied.device.type == 'meta' or tensor.device.type == 'xla' or tensor_applied.device.type == 'xla') From 4afc5c7bb9f20a2b8abd33066c29ac8e30846272 Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Thu, 30 May 2024 18:35:44 +0000 Subject: [PATCH 111/706] [torchscript] Handle prim::device and prim::dtype (#127466) - Support prim::device and prim::dtype during torchscript migration to export - Add unit tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/127466 Approved by: https://github.com/SherlockNoMad --- test/export/test_converter.py | 44 ++++++++++++++++++++++++++++++++++ torch/_export/converter.py | 45 ++++++++++++++++++++++++++++------- 2 files changed, 81 insertions(+), 8 deletions(-) diff --git a/test/export/test_converter.py b/test/export/test_converter.py index b6d0e54a59e1..64cea8cf8ac9 100644 --- a/test/export/test_converter.py +++ b/test/export/test_converter.py @@ -1,5 +1,7 @@ # Owner(s): ["oncall: export"] +import unittest + import torch import torch.utils._pytree as pytree @@ -9,6 +11,8 @@ from torch.testing._internal.common_utils import run_tests +requires_cuda = unittest.skipUnless(torch.cuda.is_available(), "requires cuda") + class TestConverter(TestCase): def _check_equal_ts_ep_converter(self, mod, inp): @@ -64,6 +68,46 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): self._check_equal_ts_ep_converter(MOutputTuple(), inp) self._check_equal_ts_ep_converter(MOutputDict(), inp) + def test_prim_device(self): + class Module(torch.nn.Module): + def forward(self, x): + device = x.device + return torch.ones(2, 3, device=device) + + inp = (torch.rand(3, 4),) + self._check_equal_ts_ep_converter(Module(), inp) + + @requires_cuda + def test_prim_device_cuda(self): + class Module(torch.nn.Module): + def forward(self, x): + device = x.device + return torch.ones(2, 3, device=device) + + inp = (torch.rand((3, 4), device="cuda:0"),) + self._check_equal_ts_ep_converter(Module(), inp) + + def test_prim_dtype(self): + class Module(torch.nn.Module): + def forward(self, x): + dtype = x.dtype + return torch.ones(2, 3, dtype=dtype) + + for dtype in [ + torch.float32, + torch.double, + ]: + inp = (torch.rand((3, 4), dtype=dtype),) + self._check_equal_ts_ep_converter(Module(), inp) + + for dtype in [ + torch.uint8, + torch.int8, + torch.int32, + ]: + inp = (torch.randint(high=128, size=(3, 4), dtype=dtype),) + self._check_equal_ts_ep_converter(Module(), inp) + if __name__ == "__main__": run_tests() diff --git a/torch/_export/converter.py b/torch/_export/converter.py index 459f534ca636..7e6812985bad 100644 --- a/torch/_export/converter.py +++ b/torch/_export/converter.py @@ -5,6 +5,7 @@ from torch.export.exported_program import ExportedProgram from torch.export.graph_signature import ( + ConstantArgument, InputKind, InputSpec, OutputKind, @@ -201,6 +202,20 @@ def convert_prim_Constant(self, node: torch._C.Node): self.constant_map[name] = value + def convert_prim_device(self, node: torch._C.Node): + input_type = node.input().type() + if input_type.isSubtypeOf(torch._C.TensorType.get()): + device = input_type.device() # type: ignore[attr-defined] + output_name = node.output().debugName() + self.constant_map[output_name] = device + else: + raise ValueError(f"Unsupported JitType ({input_type}) when get device") + + def convert_prim_dtype(self, node: torch._C.Node): + dtype = node.input().type().dtype() + output_name = node.output().debugName() + self.constant_map[output_name] = dtype + def convert_prim_GetAttr(self, node: torch._C.Node): def get_attr(name: str): if name in self.attribute_map: @@ -350,6 +365,10 @@ def convert_node(self, node: torch._C.Node): elif node_kind in {"prim::ListConstruct", "prim::TupleConstruct"}: # Tuple is just a non-mutable List, so we can handle them together. self.convert_prim_ListConstruct(node) + elif node_kind == "prim::device": + self.convert_prim_device(node) + elif node_kind == "prim::dtype": + self.convert_prim_dtype(node) elif node_kind == "prim::DictConstruct": self.convert_prim_DictConstruct(node) # elif node_kind == "aten::Int": @@ -369,17 +388,27 @@ def convert_graph_outputs(self): output_name = graph_output.debugName() if output_name in self.name_to_node: args.append(self.name_to_node[output_name]) + self.output_specs.append( + OutputSpec( + OutputKind.USER_OUTPUT, + arg=TensorArgument(name=output_name), + target=output_name, + ) + ) + elif output_name in self.constant_map: + args.append(self.constant_map[output_name]) + self.output_specs.append( + OutputSpec( + OutputKind.USER_OUTPUT, + arg=ConstantArgument( + name=output_name, value=self.constant_map[output_name] + ), + target=output_name, + ) + ) else: raise ValueError(f"Output {output_name} not found") - self.output_specs.append( - OutputSpec( - OutputKind.USER_OUTPUT, - arg=TensorArgument(name=output_name), - target=output_name, - ) - ) - self.fx_graph.output( args[0] ) # Get rid of an extra list wrapped around final output. From 2cb6f20867525f5f1906facfddf3d62aa85d4b04 Mon Sep 17 00:00:00 2001 From: Rohan Varma Date: Thu, 30 May 2024 19:10:53 +0000 Subject: [PATCH 112/706] Warn env vars only once during program (#127046) This avoids logs being excessively noisy in some training runs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127046 Approved by: https://github.com/kwen2501, https://github.com/wconstab --- torch/csrc/distributed/c10d/Utils.hpp | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/torch/csrc/distributed/c10d/Utils.hpp b/torch/csrc/distributed/c10d/Utils.hpp index b193c8971b57..a03337e97514 100644 --- a/torch/csrc/distributed/c10d/Utils.hpp +++ b/torch/csrc/distributed/c10d/Utils.hpp @@ -38,6 +38,11 @@ TORCH_API std::vector getTensorShapes( // Use -2 to represent unset state of env vars #define C10D_ENV_NOT_SET -2 +#define WARN_ENV_VAR_ONCE(deprecated_env, new_env) \ + TORCH_WARN_ONCE( \ + "Environment variable " + deprecated_env + " is deprecated; use " + \ + new_env + " instead"); + // Turns at::IntArrayRef into "(1, 2, 3, 4)". inline std::string toString(at::IntArrayRef l) { std::stringstream ss; @@ -102,9 +107,7 @@ inline std::string getCvarString( if (val == nullptr) { continue; } else if (i) { - TORCH_WARN( - "Environment variable " + env[i] + " is deprecated; use " + env[0] + - " instead"); + WARN_ENV_VAR_ONCE(env[i], env[0]); } ret = val; @@ -129,9 +132,7 @@ inline int getCvarInt(const std::vector& env, int def) { if (val == nullptr) { continue; } else if (i) { - TORCH_WARN( - "Environment variable " + env[i] + " is deprecated; use " + env[0] + - " instead"); + WARN_ENV_VAR_ONCE(env[i], env[0]); } try { @@ -160,9 +161,7 @@ inline bool getCvarBool(const std::vector& env, bool def) { if (val_ == nullptr) { continue; } else if (i) { - TORCH_WARN( - "Environment variable " + env[i] + " is deprecated; use " + env[0] + - " instead"); + WARN_ENV_VAR_ONCE(env[i], env[0]); } std::string val = std::string(val_); From 19333d1eb9b8965edd6c8a52fd59b5c67b4fb523 Mon Sep 17 00:00:00 2001 From: Prachi Gupta Date: Thu, 30 May 2024 19:26:58 +0000 Subject: [PATCH 113/706] [ROCm] Update triton pin to fix libtanh issue (#125396) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/125396 Approved by: https://github.com/pruthvistony, https://github.com/nmacchioni --- .ci/docker/ci_commit_pins/triton-rocm.txt | 2 +- test/inductor/test_cpu_cpp_wrapper.py | 14 ++++++++++++-- test/inductor/test_triton_kernels.py | 1 + 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/.ci/docker/ci_commit_pins/triton-rocm.txt b/.ci/docker/ci_commit_pins/triton-rocm.txt index 2df035af1fdd..15f681977a12 100644 --- a/.ci/docker/ci_commit_pins/triton-rocm.txt +++ b/.ci/docker/ci_commit_pins/triton-rocm.txt @@ -1 +1 @@ -bbe6246e37d8aa791c67daaf9d9d61b26c9ccfdc +01cbe5045a6898c9a925f01435c8277b2fe6afcc diff --git a/test/inductor/test_cpu_cpp_wrapper.py b/test/inductor/test_cpu_cpp_wrapper.py index 0888f3ad47a1..10744a675fbf 100644 --- a/test/inductor/test_cpu_cpp_wrapper.py +++ b/test/inductor/test_cpu_cpp_wrapper.py @@ -9,7 +9,7 @@ from torch.testing._internal.common_device_type import ( get_desired_device_type_test_bases, ) -from torch.testing._internal.common_utils import IS_MACOS, slowTest +from torch.testing._internal.common_utils import IS_MACOS, slowTest, TEST_WITH_ROCM from torch.testing._internal.inductor_utils import HAS_CPU @@ -68,7 +68,17 @@ class DynamicShapesCppWrapperCpuTests(InductorTestCase): ("cpp_wrapper",), is_skip=True ), } - +if TEST_WITH_ROCM: + test_failures_cpp_wrapper.update( + { + "test_linear_packed": test_torchinductor.TestFailure( + ("cpp_wrapper"), is_skip=True + ), + "test_linear_packed_dynamic_shapes": test_torchinductor.TestFailure( + ("cpp_wrapper"), is_skip=True + ), + } + ) if config.abi_compatible: xfail_list = [ "test_conv2d_binary_inplace_fusion_failed_cpu", diff --git a/test/inductor/test_triton_kernels.py b/test/inductor/test_triton_kernels.py index accab8beae6b..41b20188f635 100644 --- a/test/inductor/test_triton_kernels.py +++ b/test/inductor/test_triton_kernels.py @@ -586,6 +586,7 @@ def call_triton( self.assertEqual(int_result, resulti) @requires_cuda + @skipIfRocm def test_triton_kernel_constants(self): @triton.jit def mulC_kernel( From ff23c5b7d7aa6fa9965e07738bc9ae925db3b041 Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Thu, 30 May 2024 19:57:32 +0000 Subject: [PATCH 114/706] [cudagraph] improve log for mutating static input tensor addresses (#127145) Summary: This diff adds more log for cudagraph when static input tensor mutates. For each placeholder whose static input tensor address mutates, we log its name, changed data pointer address, and the input stack trace. Since some placeholder may have empty stack trace, we find its first user with an non-empty stack trace and print this stack trace instead. Test Plan: buck2 run fbcode//caffe2/test/inductor:cudagraph_trees -- --r test_static_inputs_address_mutation_log Differential Revision: D57805118 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127145 Approved by: https://github.com/eellison --- test/inductor/test_cudagraph_trees.py | 41 +++++++++++++++++++++++++++ torch/_inductor/cudagraph_trees.py | 17 +++++++---- torch/_inductor/cudagraph_utils.py | 14 +++++++++ 3 files changed, 67 insertions(+), 5 deletions(-) diff --git a/test/inductor/test_cudagraph_trees.py b/test/inductor/test_cudagraph_trees.py index 1ac9af7bc6e7..58c819b804ff 100644 --- a/test/inductor/test_cudagraph_trees.py +++ b/test/inductor/test_cudagraph_trees.py @@ -1729,6 +1729,47 @@ def test_storage_access_error(self): with self.assertRaisesRegex(Exception, "custom error msg"): device = x.untyped_storage() + def test_static_inputs_address_mutation_log(self): + class Goo(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(2, 2, device="cuda") + + def forward(self, x) -> torch.Tensor: + return self.linear(x) + + class Foo(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.static_tensor = torch.zeros((2, 2), device="cuda") + self.goo = Goo() + + def forward(self, x) -> torch.Tensor: + self.static_tensor.add_(torch.ones((2, 2), device="cuda")) + return self.static_tensor + x + self.goo(x) + + foo = Foo() + foo = torch.compile(foo, mode="reduce-overhead") + inp = torch.rand((2, 2), device="cuda") + + for _ in range(3): + foo(inp) + + # mutates static input tensors' addresses + foo.static_tensor = torch.ones((2, 2), device="cuda") + foo.goo.linear.bias = torch.nn.Parameter(torch.ones((2,), device="cuda")) + + with self.assertRaisesRegex( + Exception, + r"static input data pointer changed.\n" + r"input name: primals_2. data pointer changed from .* to .*. input stack trace: None\n" + r"input name: primals_3. data pointer changed from .* to .*. input stack trace:.*," + r" in forward\n.* self.static_tensor.add\_\(torch.ones\(\(2, 2\), device=\"cuda\"\)\).*\n\n", + ): + self.curr_node().run( + [foo.goo.linear.weight, foo.goo.linear.bias, foo.static_tensor, inp] + ) + instantiate_parametrized_tests(CudaGraphTreeTests) if __name__ == "__main__": diff --git a/torch/_inductor/cudagraph_trees.py b/torch/_inductor/cudagraph_trees.py index 6fe00710b0af..e7a1f3364823 100644 --- a/torch/_inductor/cudagraph_trees.py +++ b/torch/_inductor/cudagraph_trees.py @@ -80,6 +80,7 @@ from torch._inductor.cudagraph_utils import ( check_for_mutation, FunctionID, + get_placeholder_stack_trace, log_cudagraph_skip_and_bump_counter, WrappedFunction, ) @@ -960,11 +961,17 @@ def check_static_inputs_are_stable(self, new_inputs): self.static_input_data_ptrs[i] for i in self.non_managed_static_input_idxs ] - for t, data_ptr in zip(static_tensors, data_ptrs): - torch._check( - t.data_ptr() == data_ptr, - lambda: f"static input data pointer changed from {data_ptr} to {t.data_ptr()}", - ) + error_msg = "static input data pointer changed.\n" + for i, (t, data_ptr) in enumerate(zip(static_tensors, data_ptrs)): + index = self.non_managed_static_input_idxs[i] + if t.data_ptr() != data_ptr: + placeholder = self.wrapped_function.placeholders[index] + error_msg = ( + f"{error_msg}input name: {placeholder.name}. " + f"data pointer changed from {data_ptr} to {t.data_ptr()}. " + f"input stack trace: {get_placeholder_stack_trace(placeholder)}\n" + ) + torch._check(False, lambda: error_msg) def run_first_inputs(self, new_inputs): if config.triton.fast_path_cudagraph_asserts: diff --git a/torch/_inductor/cudagraph_utils.py b/torch/_inductor/cudagraph_utils.py index c87022fcb788..a1ac4936f417 100644 --- a/torch/_inductor/cudagraph_utils.py +++ b/torch/_inductor/cudagraph_utils.py @@ -162,3 +162,17 @@ def check_for_mutation_ignore_cuda_graph_managed_tensor( else: has_mutation = len(compiled_graph.mutated_inputs) != 0 return None if not has_mutation else default_msg + + +def get_placeholder_stack_trace(placeholder: torch.fx.Node) -> Optional[str]: + """ + Gets the first non-empty stack trace of a placeholder or its users. + """ + if placeholder.stack_trace: + return placeholder.stack_trace + + for user in placeholder.users: + if user.stack_trace: + return user.stack_trace + + return None From aa3d041830e414ed53b7626ae8dd33560912df7e Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Wed, 29 May 2024 09:47:03 -0700 Subject: [PATCH 115/706] [pipelining] Fix block comments for doc rendering (#127418) Previous: image image New: https://docs-preview.pytorch.org/pytorch/pytorch/127418/distributed.pipelining.html Pull Request resolved: https://github.com/pytorch/pytorch/pull/127418 Approved by: https://github.com/wconstab --- torch/distributed/pipelining/PipelineStage.py | 1 + torch/distributed/pipelining/_IR.py | 38 ++++++++++--------- 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/torch/distributed/pipelining/PipelineStage.py b/torch/distributed/pipelining/PipelineStage.py index 93b67696bc79..7196770994bc 100644 --- a/torch/distributed/pipelining/PipelineStage.py +++ b/torch/distributed/pipelining/PipelineStage.py @@ -1067,6 +1067,7 @@ class ManualPipelineStage(_PipelineStageBase): as opposed to the PipelineStage class that is outputed from pipeline(). This class extends the `_PipelineStageBase` class and can similarly be used in `PipelineScheule`. + Args: submodule (nn.Module): The PyTorch module wrapped by this stage. stage_index (int): The ID of this stage. diff --git a/torch/distributed/pipelining/_IR.py b/torch/distributed/pipelining/_IR.py index 68465cc6cd0b..7fe8eab83d97 100644 --- a/torch/distributed/pipelining/_IR.py +++ b/torch/distributed/pipelining/_IR.py @@ -323,13 +323,13 @@ def pipe_split(): no-op if your annotated module is run eagerly. Example: - >>> # xdoctest: +SKIP - >>> def forward(self, x): - >>> x = torch.mm(x, self.mm_param) - >>> x = torch.relu(x) - >>> pipe_split() - >>> x = self.lin(x) - >>> return x + >>> # xdoctest: +SKIP + >>> def forward(self, x): + >>> x = torch.mm(x, self.mm_param) + >>> x = torch.relu(x) + >>> pipe_split() + >>> x = self.lin(x) + >>> return x The above example will be split into two stages. """ @@ -1129,15 +1129,16 @@ def pipeline( ) -# Context manager for setting `args_chunk_spec` during creation of Pipe class ArgsChunkSpec: """ + Context manager for setting `args_chunk_spec` during creation of Pipe + Example: - >>> # xdoctest: +SKIP - >>> # There are three positional arguments to the model, and - >>> # we are chunking them along dimension 0, 0 and 1, respectively - >>> with ArgsChunkSpec((0, 0, 1)): - >>> pipe = pipeline(model, num_chunks, example_args) + >>> # xdoctest: +SKIP + >>> # There are three positional arguments to the model, and + >>> # we are chunking them along dimension 0, 0 and 1, respectively + >>> with ArgsChunkSpec((0, 0, 1)): + >>> pipe = pipeline(model, num_chunks, example_args) """ def __init__( @@ -1159,14 +1160,15 @@ def __exit__(self, exc_type, exc_val, traceback): Pipe.args_chunk_spec = None -# Context manager for setting `kwargs_chunk_spec` during creation of Pipe class KwargsChunkSpec: """ + Context manager for setting `kwargs_chunk_spec` during creation of Pipe + Example: - >>> # xdoctest: +SKIP - >>> # Chunk dimension 0 for the "id" argument, 1 for the "mask" argument - >>> with KwargsChunkSpec({"id": 0, "mask": 1}): - >>> pipe = pipeline(model, num_chunks, (), example_kwargs) + >>> # xdoctest: +SKIP + >>> # Chunk dimension 0 for the "id" argument, 1 for the "mask" argument + >>> with KwargsChunkSpec({"id": 0, "mask": 1}): + >>> pipe = pipeline(model, num_chunks, (), example_kwargs) """ def __init__( From cce2192396274a52d37ca6fea06f27edced9148d Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Tue, 28 May 2024 13:25:54 -0700 Subject: [PATCH 116/706] [pipelining] Support calling multiple recv fwd/bwd ops (#127084) Currently, only a single `get_fwd_recv_ops` or `get_bwd_recv_ops` can be called before `forward_one_chunk` and `backward_one_chunk` since they both share the same chunk_id counter. This creates a separate `recv_chunk_id` counter so that recvs can be accumulated. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127084 Approved by: https://github.com/wconstab --- torch/distributed/pipelining/PipelineStage.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/torch/distributed/pipelining/PipelineStage.py b/torch/distributed/pipelining/PipelineStage.py index 7196770994bc..f5aac602faba 100644 --- a/torch/distributed/pipelining/PipelineStage.py +++ b/torch/distributed/pipelining/PipelineStage.py @@ -128,13 +128,18 @@ def __init__( self._outputs_meta: Optional[Tuple[torch.Tensor, ...]] = None # map microbatch ID to list of forward tensor args self.fwd_cache: Dict[int, Tuple[Any, List[torch.Tensor]]] = {} - # Current forward chunk id + # Current forward chunk id to be used in computation self.fwd_chunk_id: int = 0 - # Current backward chunk id + # Current backward chunk id to be used in computation self.bwd_chunk_id: int = 0 # Caching chunk outputs for final output merge or reduction self.output_chunks: List[Any] = [] + # Current forward chunk id to be used in recv + self.recv_fwd_chunk_id: int = 0 + # Current backward chunk id to be used in recv + self.recv_bwd_chunk_id: int = 0 + # Create stage id to group rank mapping # In interleaved case, `group_rank` is stage index % group size. self.stage_index_to_group_rank: Dict[int, int] = {} @@ -267,15 +272,16 @@ def get_fwd_recv_ops(self) -> List[dist.P2POp]: Returns a list of ops that are needed to receive the input arguments for this stage. """ - recv_infos: Tuple[InputInfo, ...] = self.args_recv_info[self.fwd_chunk_id] + recv_infos: Tuple[InputInfo, ...] = self.args_recv_info[self.recv_fwd_chunk_id] # In case there is backward pass, set requires_grad for receive buffers # before first forward - if self.has_backward and not self.set_requires_grad[self.fwd_chunk_id]: + if self.has_backward and not self.set_requires_grad[self.recv_fwd_chunk_id]: for a in recv_infos: if isinstance(a, _RecvInfo): a.buffer.requires_grad_(True) + self.recv_fwd_chunk_id += 1 return self._get_recv_ops(recv_infos) def get_bwd_recv_ops(self) -> List[dist.P2POp]: @@ -288,11 +294,12 @@ def get_bwd_recv_ops(self) -> List[dist.P2POp]: # Create bwd recv infra lazily recv_infos = self.grad_recv_info.setdefault( - self.bwd_chunk_id, + self.recv_bwd_chunk_id, # `grad_recv_info` is a mirror of `act_send_info` self._create_grad_recv_info(self.act_send_info), ) + self.recv_bwd_chunk_id += 1 return self._get_recv_ops(recv_infos) def get_fwd_send_ops(self) -> List[dist.P2POp]: @@ -370,6 +377,8 @@ def clear_runtime_states(self) -> None: # Reset pointers self.fwd_chunk_id = 0 self.bwd_chunk_id = 0 + self.recv_fwd_chunk_id = 0 + self.recv_bwd_chunk_id = 0 # map microbatch ID to list of forward tensor args self.fwd_cache.clear() # Caching chunk outputs for final output merge or reduction From 846f79e61ab2caab5cef0cc46e79f439ac9634ab Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 30 May 2024 20:45:31 +0000 Subject: [PATCH 117/706] Revert "Reduce number of samples in {svd,pca}_lowrank OpInfos (#127199)" This reverts commit 18a3f781e6382e2222d7c30c18136267407f9953. Reverted https://github.com/pytorch/pytorch/pull/127199 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is failing MacOS trunk job https://hud.pytorch.org/pytorch/pytorch/commit/18a3f781e6382e2222d7c30c18136267407f9953#25619618844 ([comment](https://github.com/pytorch/pytorch/pull/127199#issuecomment-2140834363)) --- test/functorch/test_ops.py | 5 ----- .../_internal/common_methods_invocations.py | 17 +++++++++++------ 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py index 62e43843bdba..bd75a8d0bb74 100644 --- a/test/functorch/test_ops.py +++ b/test/functorch/test_ops.py @@ -460,11 +460,6 @@ class TestOperators(TestCase): {torch.float32: tol(atol=3e-04, rtol=3e-04)}, device_type="cuda", ), - tol1( - "svd_lowrank", - {torch.float32: tol(atol=5e-05, rtol=7e-06)}, - device_type="mps", - ), tol1( "linalg.tensorsolve", {torch.float32: tol(atol=3e-04, rtol=3e-04)}, diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 1570942abfc8..f59cca11becc 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -2108,13 +2108,14 @@ def sample_inputs_singular_matrix_factors(op_info, device, dtype, requires_grad= """ make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) - batches = [(), (2,)] - size = [3, 4] + batches = [(), (0, ), (2, ), (1, 1)] + size = [1, 5, 10] + for batch, m, n in product(batches, size, size): - k = 2 - a = make_arg((*batch, m, k)) - b = make_arg((*batch, n, k)) - yield a, b + for k in range(min(3, m, n)): + a = make_arg((*batch, m, k)) + b = make_arg((*batch, n, k)) + yield a, b def sample_inputs_svd_lowrank(op_info, device, dtype, requires_grad=False, **kwargs): @@ -17719,6 +17720,10 @@ def reference_flatten(input, start_dim=0, end_dim=-1): 'TestFwdGradients', 'test_fn_fwgrad_bwgrad', dtypes=[torch.complex128]), + DecorateInfo(unittest.skip("See comment above"), + 'TestBwdGradientsCUDA', + 'test_fn_gradgrad', + dtypes=[torch.complex128]), ], skips=( # test does not work with passing lambda for op From ad1b18ab2fe103dc608bd379e61054fac296fdeb Mon Sep 17 00:00:00 2001 From: Zain Rizvi Date: Thu, 30 May 2024 21:08:45 +0000 Subject: [PATCH 118/706] Add repo-specific scale config files (#127566) Part of moving pytorch/pytorch CI infra to a Linux foundation run AWS account. For self-hosted runners that can run jobs from just a single repo, the runner scalers expect them to be stored in the repo itself. These scale-config files define how the linux foundation's self-hosted runners are configured. These will apply to runners that only are available to the pytorch/pytorch and pytorch/pytorch-canary repos Pull Request resolved: https://github.com/pytorch/pytorch/pull/127566 Approved by: https://github.com/zxiiro, https://github.com/huydhn, https://github.com/atalman --- .github/lf-canary-scale-config.yml | 8 ++++++++ .github/lf-scale-config.yml | 8 ++++++++ 2 files changed, 16 insertions(+) create mode 100644 .github/lf-canary-scale-config.yml create mode 100644 .github/lf-scale-config.yml diff --git a/.github/lf-canary-scale-config.yml b/.github/lf-canary-scale-config.yml new file mode 100644 index 000000000000..6aeca46f484e --- /dev/null +++ b/.github/lf-canary-scale-config.yml @@ -0,0 +1,8 @@ +# Defines runner types provisioned by by LF Self-hosted runners for pytorch/pytorch-canary and their labels. +runner_types: + lf.c.linux.2xlarge: + disk_size: 150 + instance_type: c5.2xlarge + is_ephemeral: false + max_available: 3120 + os: linux diff --git a/.github/lf-scale-config.yml b/.github/lf-scale-config.yml new file mode 100644 index 000000000000..758b7bd90314 --- /dev/null +++ b/.github/lf-scale-config.yml @@ -0,0 +1,8 @@ +# Defines runner types provisioned by by LF Self-hosted runners for pytorch/pytorch and their labels. +runner_types: + lf.linux.2xlarge: + disk_size: 150 + instance_type: c5.2xlarge + is_ephemeral: false + max_available: 3120 + os: linux From bf2f5e70dd4035e265714d87c69d602b80e563be Mon Sep 17 00:00:00 2001 From: cyy Date: Thu, 30 May 2024 21:13:17 +0000 Subject: [PATCH 119/706] Fix warnings in SmallVector (#127250) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/127250 Approved by: https://github.com/ezyang --- c10/test/util/small_vector_test.cpp | 8 +++---- c10/util/SmallVector.cpp | 2 +- c10/util/SmallVector.h | 33 +++++++++++------------------ 3 files changed, 17 insertions(+), 26 deletions(-) diff --git a/c10/test/util/small_vector_test.cpp b/c10/test/util/small_vector_test.cpp index e05d21ce88f1..1efe4d4910e0 100644 --- a/c10/test/util/small_vector_test.cpp +++ b/c10/test/util/small_vector_test.cpp @@ -576,8 +576,8 @@ TYPED_TEST(SmallVectorTest, EraseTest) { SCOPED_TRACE("EraseTest"); this->makeSequence(this->theVector, 1, 3); - const auto& theConstVector = this->theVector; - this->theVector.erase(theConstVector.begin()); + auto& theVector = this->theVector; + this->theVector.erase(theVector.begin()); this->assertValuesInOrder(this->theVector, 2u, 2, 3); } @@ -586,8 +586,8 @@ TYPED_TEST(SmallVectorTest, EraseRangeTest) { SCOPED_TRACE("EraseRangeTest"); this->makeSequence(this->theVector, 1, 3); - const auto& theConstVector = this->theVector; - this->theVector.erase(theConstVector.begin(), theConstVector.begin() + 2); + auto& theVector = this->theVector; + this->theVector.erase(theVector.begin(), theVector.begin() + 2); this->assertValuesInOrder(this->theVector, 1u, 3); } diff --git a/c10/util/SmallVector.cpp b/c10/util/SmallVector.cpp index 14b2fa9eb671..e30cdbf8dd3b 100644 --- a/c10/util/SmallVector.cpp +++ b/c10/util/SmallVector.cpp @@ -123,7 +123,7 @@ void* SmallVectorBase::mallocForGrow( // Note: Moving this function into the header may cause performance regression. template void SmallVectorBase::grow_pod( - void* FirstEl, + const void* FirstEl, size_t MinSize, size_t TSize) { size_t NewCapacity = getNewCapacity(MinSize, TSize, this->capacity()); diff --git a/c10/util/SmallVector.h b/c10/util/SmallVector.h index 919553811454..cbcfbc52cb8a 100644 --- a/c10/util/SmallVector.h +++ b/c10/util/SmallVector.h @@ -38,11 +38,6 @@ #include #include -C10_CLANG_DIAGNOSTIC_PUSH() -#if C10_CLANG_HAS_WARNING("-Wshorten-64-to-32") -C10_CLANG_DIAGNOSTIC_IGNORE("-Wshorten-64-to-32") -#endif - namespace c10 { /// This is all the stuff common to all SmallVectors. @@ -75,7 +70,7 @@ class C10_API SmallVectorBase { /// This is an implementation of the grow() method which only works /// on POD-like data types and is out of line to reduce code duplication. /// This function will report a fatal error if it cannot increase capacity. - void grow_pod(void* FirstEl, size_t MinSize, size_t TSize); + void grow_pod(const void* FirstEl, size_t MinSize, size_t TSize); public: SmallVectorBase() = delete; @@ -112,8 +107,10 @@ using SmallVectorSizeType = /// Figure out the offset of the first element. template struct SmallVectorAlignmentAndSize { + // NOLINTNEXTLINE(*c-arrays*) alignas(SmallVectorBase>) char Base[sizeof( SmallVectorBase>)]; + // NOLINTNEXTLINE(*c-arrays*) alignas(T) char FirstEl[sizeof(T)]; }; @@ -246,7 +243,7 @@ class SmallVectorTemplateCommon bool ReferencesStorage = false; int64_t Index = -1; - if (!U::TakesParamByValue) { + if constexpr (!U::TakesParamByValue) { if (C10_UNLIKELY(This->isReferenceToStorage(&Elt))) { ReferencesStorage = true; Index = &Elt - This->begin(); @@ -306,7 +303,7 @@ class SmallVectorTemplateCommon size_type size_in_bytes() const { return size() * sizeof(T); } - size_type max_size() const { + constexpr size_type max_size() const { return std::min(this->SizeTypeMax(), size_type(-1) / sizeof(T)); } @@ -475,6 +472,7 @@ class SmallVectorTemplateBase : public SmallVectorTemplateCommon { this->set_size(this->size() + 1); } + // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved) void push_back(T&& Elt) { T* EltPtr = reserveForParamAndGetAddress(Elt); ::new ((void*)this->end()) T(::std::move(*EltPtr)); @@ -788,13 +786,9 @@ class SmallVectorImpl : public SmallVectorTemplateBase { assign(RHS.begin(), RHS.end()); } - iterator erase(const_iterator CI) { - // Just cast away constness because this is a non-const member function. - iterator I = const_cast(CI); - + iterator erase(iterator I) { assert( - this->isReferenceToStorage(CI) && - "Iterator to erase is out of bounds."); + this->isReferenceToStorage(I) && "Iterator to erase is out of bounds."); iterator N = I; // Shift all elts down one. @@ -804,11 +798,7 @@ class SmallVectorImpl : public SmallVectorTemplateBase { return (N); } - iterator erase(const_iterator CS, const_iterator CE) { - // Just cast away constness because this is a non-const member function. - iterator S = const_cast(CS); - iterator E = const_cast(CE); - + iterator erase(iterator S, iterator E) { assert(this->isRangeInStorage(S, E) && "Range to erase is out of bounds."); iterator N = S; @@ -1402,6 +1392,7 @@ class /* LLVM_GSL_OWNER */ SmallVector : public SmallVectorImpl, .end())>::iterator_category, std::input_iterator_tag>, int> = 0> + // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) SmallVector& operator=(Container&& C) { this->assign(C.begin(), C.end()); return *this; @@ -1439,6 +1430,7 @@ using ValueTypeFromRangeType = std::remove_const_t< /// SmallVector with elements of the vector. This is useful, for example, /// when you want to iterate a range and then sort the results. template +// NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) SmallVector, Size> to_vector(R&& Range) { return {std::begin(Range), std::end(Range)}; } @@ -1447,6 +1439,7 @@ SmallVector< ValueTypeFromRangeType, CalculateSmallVectorDefaultInlinedElements< ValueTypeFromRangeType>::value> +// NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) to_vector(R&& Range) { return {std::begin(Range), std::end(Range)}; } @@ -1472,5 +1465,3 @@ inline void swap( } } // end namespace std - -C10_CLANG_DIAGNOSTIC_POP() From 094183dba61bb9f8aeb691058fe8a7f95a1d5a1c Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Thu, 30 May 2024 21:18:28 +0000 Subject: [PATCH 120/706] [torchbench][pt2] Enable Huggingface and Timm models for interal buck runner (#127460) Summary: Add huggingface and timm model runs to the internal pt2 benchmark runner. Test Plan: Tesing huggingface model: ``` $ buck2 run mode/opt //pytorch/benchmark:pt2 -- --only BlenderbotSmallForCausalLM --performance --training --device=cuda --amp 33/ 33 +0 frames 2s 13 graphs 13 graph calls 0/ -12 = 0% ops 0% time ``` Testing timm model: ``` $ buck2 run mode/opt //pytorch/benchmark:pt2 -- --only coat_lite_mini --performance --training --device=cuda --amp loading model: 0it [00:11, ?it/s] cuda train coat_lite_mini 8/ 8 +0 frames 4s 2 graphs 2 graph calls 0/ -1 = 0% ops 0% time ``` Differential Revision: D57930582 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127460 Approved by: https://github.com/HDCharles, https://github.com/huydhn --- benchmarks/dynamo/huggingface.py | 5 ++++- benchmarks/dynamo/timm_models.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/benchmarks/dynamo/huggingface.py b/benchmarks/dynamo/huggingface.py index 5e139783c196..dca2915a07b2 100755 --- a/benchmarks/dynamo/huggingface.py +++ b/benchmarks/dynamo/huggingface.py @@ -7,7 +7,10 @@ import sys import warnings -from common import BenchmarkRunner, download_retry_decorator, main, reset_rng_state +try: + from .common import BenchmarkRunner, download_retry_decorator, main, reset_rng_state +except ImportError: + from common import BenchmarkRunner, download_retry_decorator, main, reset_rng_state import torch diff --git a/benchmarks/dynamo/timm_models.py b/benchmarks/dynamo/timm_models.py index db29a9bf365a..75a12517698e 100755 --- a/benchmarks/dynamo/timm_models.py +++ b/benchmarks/dynamo/timm_models.py @@ -7,7 +7,10 @@ import sys import warnings -from common import BenchmarkRunner, download_retry_decorator, main +try: + from .common import BenchmarkRunner, download_retry_decorator, main +except ImportError: + from common import BenchmarkRunner, download_retry_decorator, main import torch From 6849b8041124cbf4660e50cf4b8cb480aadc0d79 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Thu, 30 May 2024 21:22:39 +0000 Subject: [PATCH 121/706] Add `ninja` as dev dependency (#127380) `ninja` is required to build C++ extensions in tests. ```pytb ERROR: test_autograd_cpp_node (__main__.TestCompiledAutograd) ---------------------------------------------------------------------- Traceback (most recent call last): File "/home/PanXuehai/Projects/pytorch/torch/testing/_internal/common_utils.py", line 2741, in wrapper method(*args, **kwargs) File "test/inductor/test_compiled_autograd.py", line 1061, in test_autograd_cpp_node module = torch.utils.cpp_extension.load_inline( File "/home/PanXuehai/Projects/pytorch/torch/utils/cpp_extension.py", line 1643, in load_inline return _jit_compile( File "/home/PanXuehai/Projects/pytorch/torch/utils/cpp_extension.py", line 1718, in _jit_compile _write_ninja_file_and_build_library( File "/home/PanXuehai/Projects/pytorch/torch/utils/cpp_extension.py", line 1800, in _write_ninja_file_and_build_library verify_ninja_availability() File "/home/PanXuehai/Projects/pytorch/torch/utils/cpp_extension.py", line 1849, in verify_ninja_availability raise RuntimeError("Ninja is required to load C++ extensions") RuntimeError: Ninja is required to load C++ extensions To execute this test, run the following from the base repo dir: python test/inductor/test_compiled_autograd.py -k TestCompiledAutograd.test_autograd_cpp_node ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/127380 Approved by: https://github.com/ezyang --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 09259eb5c23c..cc1616a1d99c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,6 +15,7 @@ networkx jinja2 fsspec lintrunner +ninja # setuptools was removed from default python install setuptools ; python_version >= "3.12" packaging From f471482eb215bf7eeeca7aaa5dcc082cd3dbb958 Mon Sep 17 00:00:00 2001 From: dilililiwhy Date: Thu, 30 May 2024 21:33:39 +0000 Subject: [PATCH 122/706] Try to include NCCL related header file with macro USE_C10D_NCCL (#127501) Fixes #ISSUE_NUMBER Try to include NCCL related header file with macro USE_C10D_NCCL, so that third-party device compilation will not be interrupted. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127501 Approved by: https://github.com/ezyang --- torch/csrc/distributed/c10d/TraceUtils.h | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/torch/csrc/distributed/c10d/TraceUtils.h b/torch/csrc/distributed/c10d/TraceUtils.h index 575bb0451f18..e8dadb6537e0 100644 --- a/torch/csrc/distributed/c10d/TraceUtils.h +++ b/torch/csrc/distributed/c10d/TraceUtils.h @@ -1,16 +1,19 @@ #pragma once -#include #include #include #include #include -#include #include #include #include #include #include +#ifdef USE_C10D_NCCL +#include +#include +#endif + #include #include #include From a288b95d4e5ceed327c5bdb9696331aa87688d60 Mon Sep 17 00:00:00 2001 From: hippocookie Date: Thu, 30 May 2024 21:34:13 +0000 Subject: [PATCH 123/706] Enable UFMT on test_shape_ops.py test_show_pickle.py test_sort_and_select.py (#127165) Fixes some files in #123062 Run lintrunner on files: test_shape_ops.py test_show_pickle.py test_sort_and_select.py ```bash $ lintrunner --take UFMT --all-files ok No lint issues. Successfully applied all patches. ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/127165 Approved by: https://github.com/ezyang --- .lintrunner.toml | 3 - test/test_shape_ops.py | 243 ++++++++----- test/test_show_pickle.py | 13 +- test/test_sort_and_select.py | 646 ++++++++++++++++++++++------------- 4 files changed, 587 insertions(+), 318 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index 1e0a2f37fcf4..db202cfd0fa4 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1115,9 +1115,6 @@ exclude_patterns = [ 'test/test_segment_reductions.py', 'test/test_serialization.py', 'test/test_set_default_mobile_cpu_allocator.py', - 'test/test_shape_ops.py', - 'test/test_show_pickle.py', - 'test/test_sort_and_select.py', 'test/test_sparse.py', 'test/test_sparse_csr.py', 'test/test_sparse_semi_structured.py', diff --git a/test/test_shape_ops.py b/test/test_shape_ops.py index 47acfff9c6d4..5ea4139888bc 100644 --- a/test/test_shape_ops.py +++ b/test/test_shape_ops.py @@ -1,22 +1,40 @@ # Owner(s): ["module: tests"] -import torch -import numpy as np - -from itertools import product, combinations, permutations, chain -from functools import partial import random -import warnings import unittest +import warnings +from functools import partial + +from itertools import chain, combinations, permutations, product + +import numpy as np +import torch from torch import nan from torch.testing import make_tensor -from torch.testing._internal.common_utils import ( - TestCase, run_tests, skipIfTorchDynamo, torch_to_numpy_dtype_dict, IS_JETSON, TEST_PRIVATEUSE1_DEVICE_TYPE) from torch.testing._internal.common_device_type import ( - instantiate_device_type_tests, onlyCPU, onlyCUDA, dtypes, onlyNativeDeviceTypes, - dtypesIfCUDA, largeTensorTest) -from torch.testing._internal.common_dtype import all_types_and_complex_and, all_types, all_types_and + dtypes, + dtypesIfCUDA, + instantiate_device_type_tests, + largeTensorTest, + onlyCPU, + onlyCUDA, + onlyNativeDeviceTypes, +) +from torch.testing._internal.common_dtype import ( + all_types, + all_types_and, + all_types_and_complex_and, +) +from torch.testing._internal.common_utils import ( + IS_JETSON, + run_tests, + skipIfTorchDynamo, + TEST_PRIVATEUSE1_DEVICE_TYPE, + TestCase, + torch_to_numpy_dtype_dict, +) + # TODO: replace with make_tensor def _generate_input(shape, dtype, device, with_extremal): @@ -29,17 +47,19 @@ def _generate_input(shape, dtype, device, with_extremal): x = torch.randn(*shape, device=device) * random.randint(30, 100) x = x.to(torch.bfloat16) else: - x = torch.randn(*shape, dtype=dtype, device=device) * random.randint(30, 100) + x = torch.randn(*shape, dtype=dtype, device=device) * random.randint( + 30, 100 + ) x[torch.randn(*shape) > 0.5] = 0 if with_extremal and dtype.is_floating_point: # Use extremal values - x[torch.randn(*shape) > 0.5] = float('nan') - x[torch.randn(*shape) > 0.5] = float('inf') - x[torch.randn(*shape) > 0.5] = float('-inf') + x[torch.randn(*shape) > 0.5] = float("nan") + x[torch.randn(*shape) > 0.5] = float("inf") + x[torch.randn(*shape) > 0.5] = float("-inf") elif with_extremal and dtype.is_complex: - x[torch.randn(*shape) > 0.5] = complex('nan') - x[torch.randn(*shape) > 0.5] = complex('inf') - x[torch.randn(*shape) > 0.5] = complex('-inf') + x[torch.randn(*shape) > 0.5] = complex("nan") + x[torch.randn(*shape) > 0.5] = complex("inf") + x[torch.randn(*shape) > 0.5] = complex("-inf") elif dtype == torch.bool: x = torch.zeros(shape, dtype=dtype, device=device) x[torch.randn(*shape) > 0.5] = True @@ -48,8 +68,8 @@ def _generate_input(shape, dtype, device, with_extremal): return x -class TestShapeOps(TestCase): +class TestShapeOps(TestCase): # TODO: update to work on CUDA, too @onlyCPU def test_unbind(self, device): @@ -71,7 +91,7 @@ def test_tolist(self, device): tensor0D = torch.tensor(list0D) self.assertEqual(tensor0D.tolist(), list0D) - table1D = [1., 2., 3.] + table1D = [1.0, 2.0, 3.0] tensor1D = torch.tensor(table1D) storage = torch.Storage(table1D) self.assertEqual(tensor1D.tolist(), table1D) @@ -102,19 +122,29 @@ def test_movedim_invalid(self, device, dtype): fn(x, 0, 5) # Mismatch in size of `source` and `destination` - with self.assertRaisesRegex(RuntimeError, "movedim: Invalid source or destination dims:"): - fn(x, (1, 0), (0, )) - - with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `source`"): + with self.assertRaisesRegex( + RuntimeError, "movedim: Invalid source or destination dims:" + ): + fn(x, (1, 0), (0,)) + + with self.assertRaisesRegex( + RuntimeError, "movedim: repeated dim in `source`" + ): fn(x, (0, 0), (0, 1)) - with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `source`"): + with self.assertRaisesRegex( + RuntimeError, "movedim: repeated dim in `source`" + ): fn(x, (0, 1, 0), (0, 1, 2)) - with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `destination`"): + with self.assertRaisesRegex( + RuntimeError, "movedim: repeated dim in `destination`" + ): fn(x, (0, 1), (1, 1)) - with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `destination`"): + with self.assertRaisesRegex( + RuntimeError, "movedim: repeated dim in `destination`" + ): fn(x, (0, 1, 2), (1, 0, 1)) @dtypes(torch.int64, torch.float, torch.complex128) @@ -137,8 +167,12 @@ def test_movedim(self, device, dtype): # Integer `source` and `destination` torch_fn = partial(fn, source=src_dim, destination=dst_dim) - np_fn = partial(np.moveaxis, source=src_dim, destination=dst_dim) - self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None) + np_fn = partial( + np.moveaxis, source=src_dim, destination=dst_dim + ) + self.compare_with_numpy( + torch_fn, np_fn, x, device=None, dtype=None + ) if nd == 0: continue @@ -148,9 +182,13 @@ def make_index_negative(sequence, idx): sequence[random_idx] = sequence[random_idx] - nd return tuple(src_sequence) - for src_sequence in permutations(range(nd), r=random.randint(1, nd)): + for src_sequence in permutations( + range(nd), r=random.randint(1, nd) + ): # Sequence `source` and `destination` - dst_sequence = tuple(random.sample(range(nd), len(src_sequence))) + dst_sequence = tuple( + random.sample(range(nd), len(src_sequence)) + ) # Randomly change a dim to a negative dim representation of itself. random_prob = random.random() @@ -166,9 +204,15 @@ def make_index_negative(sequence, idx): random_idx = random.randint(0, len(src_sequence) - 1) src_sequence = make_index_negative(src_sequence, random_idx) - torch_fn = partial(fn, source=src_sequence, destination=dst_sequence) - np_fn = partial(np.moveaxis, source=src_sequence, destination=dst_sequence) - self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None) + torch_fn = partial( + fn, source=src_sequence, destination=dst_sequence + ) + np_fn = partial( + np.moveaxis, source=src_sequence, destination=dst_sequence + ) + self.compare_with_numpy( + torch_fn, np_fn, x, device=None, dtype=None + ) # Move dim to same position x = torch.randn(2, 3, 5, 7, 11) @@ -213,10 +257,7 @@ def test_diagonal(self, device): def test_diagonal_multidim(self, device, dtype): x = torch.randn(10, 11, 12, 13, dtype=dtype, device=device) xn = x.numpy() - for args in [(2, 2, 3), - (2,), - (-2, 1, 2), - (0, -2, -1)]: + for args in [(2, 2, 3), (2,), (-2, 1, 2), (0, -2, -1)]: result = torch.diagonal(x, *args) expected = xn.diagonal(*args) self.assertEqual(expected.shape, result.shape) @@ -270,14 +311,22 @@ def generate_clamp_baseline(self, device, dtype, *, min_vals, max_vals, with_nan max_vals = max_vals.cpu().numpy() # Use NumPy implementation as reference - X_clamped = torch.tensor(np.clip(X.cpu().numpy(), a_min=min_vals, a_max=max_vals), device=device) + X_clamped = torch.tensor( + np.clip(X.cpu().numpy(), a_min=min_vals, a_max=max_vals), device=device + ) return X, X_clamped # Tests clamp and its alias, clip @dtypes(torch.int64, torch.float32) def test_clamp(self, device, dtype): - op_list = (torch.clamp, torch.Tensor.clamp, torch.Tensor.clamp_, - torch.clip, torch.Tensor.clip, torch.Tensor.clip_) + op_list = ( + torch.clamp, + torch.Tensor.clamp, + torch.Tensor.clamp_, + torch.clip, + torch.Tensor.clip, + torch.Tensor.clip_, + ) # min/max argument product args = product((-10, None), (10, None)) @@ -287,10 +336,9 @@ def test_clamp(self, device, dtype): if min_val is None and max_val is None: continue - X, Y_expected = self.generate_clamp_baseline(device, dtype, - min_vals=min_val, - max_vals=max_val, - with_nans=False) + X, Y_expected = self.generate_clamp_baseline( + device, dtype, min_vals=min_val, max_vals=max_val, with_nans=False + ) # Test op X1 = X.clone() # So that the in-place ops do not change X @@ -304,8 +352,14 @@ def test_clamp(self, device, dtype): self.assertEqual(Y_expected, Y_out) def test_clamp_propagates_nans(self, device): - op_list = (torch.clamp, torch.Tensor.clamp, torch.Tensor.clamp_, - torch.clip, torch.Tensor.clip, torch.Tensor.clip_) + op_list = ( + torch.clamp, + torch.Tensor.clamp, + torch.Tensor.clamp_, + torch.clip, + torch.Tensor.clip, + torch.Tensor.clip_, + ) # min/max argument product args = product((-10, None), (10, None)) @@ -315,10 +369,13 @@ def test_clamp_propagates_nans(self, device): if min_val is None and max_val is None: continue - X, Y_expected = self.generate_clamp_baseline(device, torch.float, - min_vals=min_val, - max_vals=max_val, - with_nans=True) + X, Y_expected = self.generate_clamp_baseline( + device, + torch.float, + min_vals=min_val, + max_vals=max_val, + with_nans=True, + ) Y_expected = torch.isnan(Y_expected) # Test op @@ -334,7 +391,7 @@ def test_clamp_propagates_nans(self, device): def test_clamp_raises_arg_errors(self, device): X = torch.randn(100, dtype=torch.float, device=device) - error_msg = 'At least one of \'min\' or \'max\' must not be None' + error_msg = "At least one of 'min' or 'max' must not be None" with self.assertRaisesRegex(RuntimeError, error_msg): X.clamp() with self.assertRaisesRegex(RuntimeError, error_msg): @@ -369,18 +426,22 @@ def all_t(): self.assertEqual(in_t.flip(p_dims), out_t) if len(p_dims) > 0: # Wrap 1st dim - self.assertEqual(in_t.flip((-n + p_dims[0],) + p_dims[1:]), out_t) + self.assertEqual( + in_t.flip((-n + p_dims[0],) + p_dims[1:]), out_t + ) def gen_data(): # Basic tests data = make_from_data([1, 2, 3, 4, 5, 6, 7, 8]).view(2, 2, 2) nonctg = make_from_size((2, 2, 2), noncontiguous=True).copy_(data) - dims_result = ((0, make_from_data([5, 6, 7, 8, 1, 2, 3, 4]).view(2, 2, 2)), - (1, make_from_data([3, 4, 1, 2, 7, 8, 5, 6]).view(2, 2, 2)), - (2, make_from_data([2, 1, 4, 3, 6, 5, 8, 7]).view(2, 2, 2)), - ((0, 1), make_from_data([7, 8, 5, 6, 3, 4, 1, 2]).view(2, 2, 2)), - ((0, 1, 2), make_from_data([8, 7, 6, 5, 4, 3, 2, 1]).view(2, 2, 2))) + dims_result = ( + (0, make_from_data([5, 6, 7, 8, 1, 2, 3, 4]).view(2, 2, 2)), + (1, make_from_data([3, 4, 1, 2, 7, 8, 5, 6]).view(2, 2, 2)), + (2, make_from_data([2, 1, 4, 3, 6, 5, 8, 7]).view(2, 2, 2)), + ((0, 1), make_from_data([7, 8, 5, 6, 3, 4, 1, 2]).view(2, 2, 2)), + ((0, 1, 2), make_from_data([8, 7, 6, 5, 4, 3, 2, 1]).view(2, 2, 2)), + ) for in_tensor, (dims, out_tensor) in product((data, nonctg), dims_result): yield in_tensor, dims, out_tensor @@ -393,7 +454,9 @@ def gen_data(): yield in_t, 1, in_t # Transposed - in_t = make_from_data([1, 2, 3, 4, 5, 6, 7, 8]).view(2, 2, 2).transpose(0, 1) + in_t = ( + make_from_data([1, 2, 3, 4, 5, 6, 7, 8]).view(2, 2, 2).transpose(0, 1) + ) dims = (0, 1, 2) out_t = make_from_data([8, 7, 4, 3, 6, 5, 2, 1]).view(2, 2, 2) yield in_t, dims, out_t @@ -411,7 +474,9 @@ def gen_data(): if device == "cpu" and dtype != torch.bfloat16: for mf in [torch.contiguous_format, torch.channels_last]: for c in [2, 3, 8, 16]: - in_t = make_from_size((2, c, 32, 32)).contiguous(memory_format=mf) + in_t = make_from_size((2, c, 32, 32)).contiguous( + memory_format=mf + ) np_in_t = in_t.numpy() np_out_t = np_in_t[:, :, :, ::-1].copy() @@ -464,7 +529,9 @@ def gen_data(): size = [2, 3, 4] data = make_from_size(size) possible_dims = range(len(size)) - test_dims = chain(combinations(possible_dims, 1), combinations(possible_dims, 2)) + test_dims = chain( + combinations(possible_dims, 1), combinations(possible_dims, 2) + ) for dims in test_dims: self.assertEqual(size, list(data.flip(dims).size())) @@ -483,7 +550,6 @@ def test_flip_errors(self, device, dtype): self.assertRaises(IndexError, lambda: data.flip(0, 1, 2, 3)) self.assertRaises(IndexError, lambda: data.flip(3)) - def _rand_shape(self, dim, min_size, max_size): return tuple(torch.randint(min_size, max_size + 1, (dim,))) @@ -504,8 +570,10 @@ def test_flip_numpy(self, device, dtype): self.compare_with_numpy(torch_fn, np_fn, data) @onlyCUDA # CPU is too slow - @largeTensorTest('17GB') # 4 tensors of 4GB (in, out) x (torch, numpy) + 1GB - @largeTensorTest("81GB", "cpu") # even for CUDA test, sufficient system memory is required + @largeTensorTest("17GB") # 4 tensors of 4GB (in, out) x (torch, numpy) + 1GB + @largeTensorTest( + "81GB", "cpu" + ) # even for CUDA test, sufficient system memory is required @unittest.skipIf(IS_JETSON, "Too large for Jetson") def test_flip_large_tensor(self, device): t_in = torch.empty(2**32 + 1, dtype=torch.uint8).random_() @@ -569,7 +637,9 @@ def test_rot90(self, device): # test tensor with more than 2D data = torch.arange(1, 9, device=device).view(2, 2, 2) - self.assertEqual(torch.tensor([2, 4, 1, 3, 6, 8, 5, 7]).view(2, 2, 2), data.rot90(1, [1, 2])) + self.assertEqual( + torch.tensor([2, 4, 1, 3, 6, 8, 5, 7]).view(2, 2, 2), data.rot90(1, [1, 2]) + ) self.assertEqual(data.rot90(1, [1, -1]), data.rot90(1, [1, 2])) # test for errors @@ -601,7 +671,6 @@ def test_nonzero_no_warning(self, device): @dtypes(*all_types_and(torch.half, torch.bool, torch.bfloat16)) def test_nonzero(self, device, dtype): - shapes = [ torch.Size((12,)), torch.Size((12, 1)), @@ -616,7 +685,9 @@ def gen_nontrivial_input(shape, dtype, device): return torch.randint(2, shape, device=device, dtype=dtype) else: # windows does not work for bfloat16 randing - return torch.randint(2, shape, device=device, dtype=torch.float).to(dtype) + return torch.randint(2, shape, device=device, dtype=torch.float).to( + dtype + ) for shape in shapes: tensor = gen_nontrivial_input(shape, dtype, device) @@ -624,20 +695,31 @@ def gen_nontrivial_input(shape, dtype, device): dst2 = tensor.nonzero(as_tuple=False) dst3 = torch.empty([], dtype=torch.long, device=device) torch.nonzero(tensor, out=dst3) - if self.device_type != 'xla': + if self.device_type != "xla": # xla does not raise runtime error self.assertRaisesRegex( RuntimeError, "scalar type Long", - lambda: torch.nonzero(tensor, out=torch.empty([], dtype=torch.float, device=device)) + lambda: torch.nonzero( + tensor, out=torch.empty([], dtype=torch.float, device=device) + ), ) - if self.device_type == 'cuda' or self.device_type == TEST_PRIVATEUSE1_DEVICE_TYPE: + if ( + self.device_type == "cuda" + or self.device_type == TEST_PRIVATEUSE1_DEVICE_TYPE + ): self.assertRaisesRegex( RuntimeError, "on the same device", - lambda: torch.nonzero(tensor, out=torch.empty([], dtype=torch.long)) + lambda: torch.nonzero( + tensor, out=torch.empty([], dtype=torch.long) + ), ) - np_array = tensor.cpu().numpy() if dtype != torch.bfloat16 else tensor.float().cpu().numpy() + np_array = ( + tensor.cpu().numpy() + if dtype != torch.bfloat16 + else tensor.float().cpu().numpy() + ) np_result = torch.from_numpy(np.stack(np_array.nonzero())).t() self.assertEqual(dst1.cpu(), np_result, atol=0, rtol=0) self.assertEqual(dst2.cpu(), np_result, atol=0, rtol=0) @@ -656,7 +738,9 @@ def test_nonzero_astuple_out(self, device): with self.assertRaises(RuntimeError): torch.nonzero(t, as_tuple=True, out=out) - self.assertEqual(torch.nonzero(t, as_tuple=False, out=out), torch.nonzero(t, out=out)) + self.assertEqual( + torch.nonzero(t, as_tuple=False, out=out), torch.nonzero(t, out=out) + ) # Verifies that JIT script cannot handle the as_tuple kwarg # See Issue https://github.com/pytorch/pytorch/issues/45499. @@ -684,7 +768,9 @@ def _foo(t): def test_nonzero_discontiguous(self, device): shape = (4, 4) tensor = torch.randint(2, shape, device=device) - tensor_nc = torch.empty(shape[0], shape[1] * 2, device=device)[:, ::2].copy_(tensor) + tensor_nc = torch.empty(shape[0], shape[1] * 2, device=device)[:, ::2].copy_( + tensor + ) dst1 = tensor.nonzero(as_tuple=False) dst2 = tensor_nc.nonzero(as_tuple=False) self.assertEqual(dst1, dst2, atol=0, rtol=0) @@ -695,7 +781,9 @@ def test_nonzero_discontiguous(self, device): self.assertEqual(data_ptr, dst3.data_ptr()) self.assertEqual(dst1, dst3, atol=0, rtol=0) # discontiguous out - dst4 = torch.empty(dst1.size(0), dst1.size(1) * 2, dtype=torch.long, device=device)[:, ::2] + dst4 = torch.empty( + dst1.size(0), dst1.size(1) * 2, dtype=torch.long, device=device + )[:, ::2] data_ptr = dst4.data_ptr() strides = dst4.stride() torch.nonzero(tensor, out=dst4) @@ -710,7 +798,7 @@ def test_nonzero_non_diff(self, device): @dtypes(torch.int64, torch.float, torch.complex128) def test_sparse_dense_dim(self, device, dtype): - for shape in [(), (2, ), (2, 3)]: + for shape in [(), (2,), (2, 3)]: if dtype.is_complex or dtype.is_floating_point: x = torch.rand(shape, device=device, dtype=dtype) else: @@ -718,7 +806,8 @@ def test_sparse_dense_dim(self, device, dtype): self.assertEqual(x.sparse_dim(), 0) self.assertEqual(x.dense_dim(), len(shape)) + instantiate_device_type_tests(TestShapeOps, globals()) -if __name__ == '__main__': +if __name__ == "__main__": run_tests() diff --git a/test/test_show_pickle.py b/test/test_show_pickle.py index 929584943007..48b459e12eac 100644 --- a/test/test_show_pickle.py +++ b/test/test_show_pickle.py @@ -1,15 +1,16 @@ # Owner(s): ["oncall: mobile"] -import unittest import io import tempfile +import unittest + import torch import torch.utils.show_pickle -from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS +from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase -class TestShowPickle(TestCase): +class TestShowPickle(TestCase): @unittest.skipIf(IS_WINDOWS, "Can't re-open temp file on Windows") def test_scripted_model(self): class MyCoolModule(torch.nn.Module): @@ -26,11 +27,13 @@ def forward(self, x): torch.jit.save(m, tmp) tmp.flush() buf = io.StringIO() - torch.utils.show_pickle.main(["", tmp.name + "@*/data.pkl"], output_stream=buf) + torch.utils.show_pickle.main( + ["", tmp.name + "@*/data.pkl"], output_stream=buf + ) output = buf.getvalue() self.assertRegex(output, "MyCoolModule") self.assertRegex(output, "weight") -if __name__ == '__main__': +if __name__ == "__main__": run_tests() diff --git a/test/test_sort_and_select.py b/test/test_sort_and_select.py index 7709131e6102..73414bb28be3 100644 --- a/test/test_sort_and_select.py +++ b/test/test_sort_and_select.py @@ -1,44 +1,68 @@ # Owner(s): ["module: tests"] -import torch -import numpy as np - import random -from torch import nan from itertools import permutations, product +import numpy as np +import torch +from torch import nan + from torch.testing import make_tensor -from torch.testing._internal.common_dtype import all_types, all_types_and, floating_types_and, integral_types -from torch.testing._internal.common_utils import \ - (TestCase, run_tests, slowTest, skipIfTorchDynamo) -from torch.testing._internal.common_device_type import \ - (instantiate_device_type_tests, dtypes, onlyNativeDeviceTypes, - onlyCUDA, dtypesIfCUDA, dtypesIfCPU, onlyCPU, largeTensorTest) +from torch.testing._internal.common_device_type import ( + dtypes, + dtypesIfCPU, + dtypesIfCUDA, + instantiate_device_type_tests, + largeTensorTest, + onlyCPU, + onlyCUDA, + onlyNativeDeviceTypes, +) +from torch.testing._internal.common_dtype import ( + all_types, + all_types_and, + floating_types_and, + integral_types, +) +from torch.testing._internal.common_utils import ( + run_tests, + skipIfTorchDynamo, + slowTest, + TestCase, +) # TODO: remove this SIZE = 100 -class TestSortAndSelect(TestCase): +class TestSortAndSelect(TestCase): def assertIsOrdered(self, order, x, mxx, ixx, task): SIZE = x.size(1) - if order == 'descending': + if order == "descending": + def check_order(a, b): # `a != a` because we put NaNs # at the end of ascending sorted lists, # and the beginning of descending ones. return ((a != a) | (a >= b)).all().item() - elif order == 'ascending': + + elif order == "ascending": + def check_order(a, b): # see above return ((b != b) | (a <= b)).all().item() + else: - error(f'unknown order "{order}", must be "ascending" or "descending"') # noqa: F821 + error( # noqa: F821 + f'unknown order "{order}", must be "ascending" or "descending"' + ) are_ordered = True for k in range(1, SIZE): - self.assertTrue(check_order(mxx[:, k - 1], mxx[:, k]), - f'torch.sort ({order}) values unordered for {task}') + self.assertTrue( + check_order(mxx[:, k - 1], mxx[:, k]), + f"torch.sort ({order}) values unordered for {task}", + ) seen = set() indicesCorrect = True @@ -50,8 +74,11 @@ def check_order(a, b): for k in range(size0): seen.clear() for j in range(size): - self.assertEqual(x[k][ixx[k][j]], mxx[k][j], - msg=f'torch.sort ({order}) indices wrong for {task}') + self.assertEqual( + x[k][ixx[k][j]], + mxx[k][j], + msg=f"torch.sort ({order}) indices wrong for {task}", + ) seen.add(ixx[k][j]) self.assertEqual(len(seen), size) @@ -79,19 +106,22 @@ def test_sort(self, device): self.assertEqual(x.argsort(), res1ind) # Test sorting of random numbers - self.assertIsOrdered('ascending', x, res2val, res2ind, 'random') + self.assertIsOrdered("ascending", x, res2val, res2ind, "random") # Test simple sort self.assertEqual( torch.sort(torch.tensor((50, 40, 30, 20, 10), device=device))[0], torch.tensor((10, 20, 30, 40, 50), device=device), - atol=0, rtol=0 + atol=0, + rtol=0, ) # Test that we still have proper sorting with duplicate keys x = torch.floor(torch.rand(4, SIZE, device=device) * 10) torch.sort(x, out=(res2val, res2ind)) - self.assertIsOrdered('ascending', x, res2val, res2ind, 'random with duplicate keys') + self.assertIsOrdered( + "ascending", x, res2val, res2ind, "random with duplicate keys" + ) # DESCENDING SORT x = torch.rand(4, SIZE, device=device) @@ -107,35 +137,41 @@ def test_sort(self, device): self.assertEqual(x.argsort(x.dim() - 1, True), res1ind) # Test sorting of random numbers - self.assertIsOrdered('descending', x, res2val, res2ind, 'random') + self.assertIsOrdered("descending", x, res2val, res2ind, "random") # Test simple sort task self.assertEqual( - torch.sort(torch.tensor((10, 20, 30, 40, 50), device=device), 0, True)[0], + torch.sort(torch.tensor((10, 20, 30, 40, 50), device=device), 0, True)[ + 0 + ], torch.tensor((50, 40, 30, 20, 10), device=device), - atol=0, rtol=0 + atol=0, + rtol=0, ) # Test that we still have proper sorting with duplicate keys - self.assertIsOrdered('descending', x, res2val, res2ind, 'random with duplicate keys') + self.assertIsOrdered( + "descending", x, res2val, res2ind, "random with duplicate keys" + ) # Test argument sorting with and without stable x = torch.tensor([1, 10, 2, 2, 3, 7, 7, 8, 9, 9] * 3) - self.assertEqual(torch.argsort(x, stable=True), torch.sort(x, stable=True).indices) - self.assertEqual(torch.argsort(x, stable=False), torch.sort(x, stable=False).indices) + self.assertEqual( + torch.argsort(x, stable=True), torch.sort(x, stable=True).indices + ) + self.assertEqual( + torch.argsort(x, stable=False), torch.sort(x, stable=False).indices + ) self.assertEqual(torch.argsort(x), torch.sort(x).indices) - # Test sorting with NaNs x = torch.rand(4, SIZE, device=device) - x[1][2] = float('NaN') - x[3][0] = float('NaN') + x[1][2] = float("NaN") + x[3][0] = float("NaN") torch.sort(x, out=(res2val, res2ind)) - self.assertIsOrdered('ascending', x, res2val, res2ind, - 'random with NaNs') + self.assertIsOrdered("ascending", x, res2val, res2ind, "random with NaNs") torch.sort(x, out=(res2val, res2ind), descending=True) - self.assertIsOrdered('descending', x, res2val, res2ind, - 'random with NaNs') + self.assertIsOrdered("descending", x, res2val, res2ind, "random with NaNs") def test_sort_stable_none(self): # Called sort with stable=None used to trigger an assertion @@ -169,19 +205,19 @@ def test_stable_sort(self, device, dtype): _, idx = x.sort(stable=True) self.assertEqual( idx[:ncopies], - torch.arange(start=0, end=2 * ncopies, step=2, device=device) + torch.arange(start=0, end=2 * ncopies, step=2, device=device), ) self.assertEqual( idx[ncopies:], - torch.arange(start=1, end=2 * ncopies, step=2, device=device) + torch.arange(start=1, end=2 * ncopies, step=2, device=device), ) @onlyCUDA @dtypes(torch.uint8) - @largeTensorTest('200GB') # Unfortunately 80GB A100 is not large enough + @largeTensorTest("200GB") # Unfortunately 80GB A100 is not large enough def test_sort_large(self, device, dtype): t0 = torch.randperm(8192, device=device).to(dtype) - t = t0.view(1, 8192).expand(2 ** 18 + 1, -1).contiguous() + t = t0.view(1, 8192).expand(2**18 + 1, -1).contiguous() v, i = t.sort() del t iv, im = i.var_mean(dim=0) @@ -193,7 +229,6 @@ def test_sort_large(self, device, dtype): self.assertEqual(vm, torch.arange(255, dtype=dtype, device=device)) self.assertEqual(im, t0.sort().indices) - @dtypes(torch.float32) def test_sort_restride(self, device, dtype): # Input: non-contiguous (stride: 5) 3-element array @@ -223,14 +258,24 @@ def _test_sort_discontiguous(self, device, dtype): n = t.size(dim) # assert ordered - self.assertTrue((r1.values.narrow(dim, 1, n - 1) >= r1.values.narrow(dim, 0, n - 1)).all()) + self.assertTrue( + ( + r1.values.narrow(dim, 1, n - 1) + >= r1.values.narrow(dim, 0, n - 1) + ).all() + ) # assert that different segments does not mix, which can easily happen # if the stride is not handled correctly - self.assertTrue((t.unsqueeze(-1).transpose(dim, -1) == r1.values.unsqueeze(-1)).any(dim=dim).any(dim=-1).all()) + self.assertTrue( + (t.unsqueeze(-1).transpose(dim, -1) == r1.values.unsqueeze(-1)) + .any(dim=dim) + .any(dim=-1) + .all() + ) # assert stride is preserved - if self.device_type == 'cuda': + if self.device_type == "cuda": # FIXME: this behavior should be true for all cases, not # just the one specified in if condition self.assertEqual(r1.values.stride(), t.stride()) @@ -262,7 +307,9 @@ def test_sort_1d_output_discontiguous(self, device, dtype): @dtypes(*integral_types()) def test_sort_1d_parallel(self, device, dtype): low = 0 if dtype == torch.uint8 else -128 - tensor = torch.randint(low=low, high=127, size=(100000, ), device=device, dtype=dtype) + tensor = torch.randint( + low=low, high=127, size=(100000,), device=device, dtype=dtype + ) vals, _ = torch.sort(tensor, stable=True) self.assertEqual(True, torch.all(vals[:-1] <= vals[1:])) @@ -283,9 +330,9 @@ def test_topk_1d_output_discontiguous(self, device, dtype): @dtypes(*all_types_and(torch.half, torch.bfloat16)) def test_stable_sort_against_numpy(self, device, dtype): if dtype in floating_types_and(torch.float16, torch.bfloat16): - inf = float('inf') - neg_inf = -float('inf') - nan = float('nan') + inf = float("inf") + neg_inf = -float("inf") + nan = float("nan") else: if dtype != torch.bool: # no torch.iinfo support for torch.bool @@ -305,7 +352,7 @@ def generate_samples(): # binary strings yield (torch.tensor([0, 1] * size, dtype=dtype, device=device), 0) - if self.device_type == 'cuda': + if self.device_type == "cuda": return yield (torch.tensor([0, 1] * 100, dtype=dtype, device=device), 0) @@ -326,13 +373,21 @@ def repeated_index_fill(t, dim, idxs, vals): # for each dimension. n_fill_vals = 3 # cardinality of (inf, neg_inf, nan) for dim in range(len(sizes)): - idxs = (torch.randint(high=size, size=(size // 10,)) for i in range(n_fill_vals)) + idxs = ( + torch.randint(high=size, size=(size // 10,)) + for i in range(n_fill_vals) + ) vals = (inf, neg_inf, nan) - subsets = chain.from_iterable(combinations(list(zip(idxs, vals)), r) - for r in range(1, n_fill_vals + 1)) + subsets = chain.from_iterable( + combinations(list(zip(idxs, vals)), r) + for r in range(1, n_fill_vals + 1) + ) for subset in subsets: idxs_subset, vals_subset = zip(*subset) - yield (repeated_index_fill(x, dim, idxs_subset, vals_subset), dim) + yield ( + repeated_index_fill(x, dim, idxs_subset, vals_subset), + dim, + ) for sample, dim in generate_samples(): _, idx_torch = sample.sort(dim=dim, stable=True) @@ -340,7 +395,7 @@ def repeated_index_fill(t, dim, idxs, vals): sample_numpy = sample.float().cpu().numpy() else: sample_numpy = sample.cpu().numpy() - idx_numpy = np.argsort(sample_numpy, axis=dim, kind='stable') + idx_numpy = np.argsort(sample_numpy, axis=dim, kind="stable") self.assertEqual(idx_torch, idx_numpy) @dtypes(*all_types_and(torch.half, torch.bfloat16)) @@ -349,7 +404,9 @@ def test(shape): tensor = make_tensor(shape, dtype=dtype, device=device, low=-9, high=9) if tensor.size() != torch.Size([]): if dtype is torch.bfloat16: - expected = torch.from_numpy(np.msort(tensor.float().cpu().numpy())).bfloat16() + expected = torch.from_numpy( + np.msort(tensor.float().cpu().numpy()) + ).bfloat16() else: expected = torch.from_numpy(np.msort(tensor.cpu().numpy())) else: @@ -364,11 +421,15 @@ def test(shape): shapes = ( [], - [0, ], - [20, ], + [ + 0, + ], + [ + 20, + ], [1, 20], [30, 30], - [10, 20, 30] + [10, 20, 30], ) for shape in shapes: test(shape) @@ -414,9 +475,12 @@ def compare(t, k, dim, dir): sortKVal, sortKInd = topKViaSort(t, k, dim, dir) compareTensors(t, sortKVal, sortKInd, topKVal, topKInd, dim) - t = torch.rand(random.randint(1, SIZE), - random.randint(1, SIZE), - random.randint(1, SIZE), device=device) + t = torch.rand( + random.randint(1, SIZE), + random.randint(1, SIZE), + random.randint(1, SIZE), + device=device, + ) for _kTries in range(3): for _dimTries in range(3): @@ -457,91 +521,94 @@ def test_topk_arguments(self, device): self.assertRaises(TypeError, lambda: q.topk(4, True)) def test_unique_dim(self, device): - self.assertFalse(hasattr(torch, 'unique_dim')) + self.assertFalse(hasattr(torch, "unique_dim")) def run_test(device, dtype): - x = torch.tensor([[[1., 1.], - [0., 1.], - [2., 1.], - [0., 1.]], - [[1., 1.], - [0., 1.], - [2., 1.], - [0., 1.]]], - dtype=dtype, - device=device) + x = torch.tensor( + [ + [[1.0, 1.0], [0.0, 1.0], [2.0, 1.0], [0.0, 1.0]], + [[1.0, 1.0], [0.0, 1.0], [2.0, 1.0], [0.0, 1.0]], + ], + dtype=dtype, + device=device, + ) x_empty = torch.empty(5, 0, dtype=dtype, device=device) x_ill_formed_empty = torch.empty(5, 0, 0, dtype=dtype, device=device) - x_ill_formed_empty_another = torch.empty(5, 0, 5, dtype=dtype, device=device) + x_ill_formed_empty_another = torch.empty( + 5, 0, 5, dtype=dtype, device=device + ) if dtype in floating_types_and(torch.float16, torch.bfloat16): - x_nan = torch.tensor([float("nan"), 0, 0, float("nan"), float("nan"), 1], dtype=dtype, device=device) - expected_unique_dim0 = torch.tensor([[[1., 1.], - [0., 1.], - [2., 1.], - [0., 1.]]], - dtype=dtype, - device=device) + x_nan = torch.tensor( + [float("nan"), 0, 0, float("nan"), float("nan"), 1], + dtype=dtype, + device=device, + ) + expected_unique_dim0 = torch.tensor( + [[[1.0, 1.0], [0.0, 1.0], [2.0, 1.0], [0.0, 1.0]]], + dtype=dtype, + device=device, + ) expected_inverse_dim0 = torch.tensor([0, 0]) expected_counts_dim0 = torch.tensor([2]) - expected_unique_dim1 = torch.tensor([[[0., 1.], - [1., 1.], - [2., 1.]], - [[0., 1.], - [1., 1.], - [2., 1.]]], - dtype=dtype, - device=device) - expected_unique_dim1_bool = torch.tensor([[[False, True], [True, True]], - [[False, True], [True, True]]], - dtype=torch.bool, - device=device) + expected_unique_dim1 = torch.tensor( + [ + [[0.0, 1.0], [1.0, 1.0], [2.0, 1.0]], + [[0.0, 1.0], [1.0, 1.0], [2.0, 1.0]], + ], + dtype=dtype, + device=device, + ) + expected_unique_dim1_bool = torch.tensor( + [[[False, True], [True, True]], [[False, True], [True, True]]], + dtype=torch.bool, + device=device, + ) expected_inverse_dim1 = torch.tensor([1, 0, 2, 0]) expected_inverse_dim1_bool = torch.tensor([1, 0, 1, 0]) expected_counts_dim1 = torch.tensor([2, 1, 1]) expected_counts_dim1_bool = torch.tensor([2, 2]) - expected_unique_dim2 = torch.tensor([[[1., 1.], - [0., 1.], - [2., 1.], - [0., 1.]], - [[1., 1.], - [0., 1.], - [2., 1.], - [0., 1.]]], - dtype=dtype, - device=device) + expected_unique_dim2 = torch.tensor( + [ + [[1.0, 1.0], [0.0, 1.0], [2.0, 1.0], [0.0, 1.0]], + [[1.0, 1.0], [0.0, 1.0], [2.0, 1.0], [0.0, 1.0]], + ], + dtype=dtype, + device=device, + ) expected_inverse_dim2 = torch.tensor([0, 1]) expected_counts_dim2 = torch.tensor([1, 1]) expected_unique_empty = torch.empty(5, 0, dtype=dtype, device=device) expected_inverse_empty = torch.tensor([], dtype=torch.long, device=device) expected_counts_empty = torch.tensor([], dtype=torch.long, device=device) if dtype in floating_types_and(torch.float16, torch.bfloat16): - expected_unique_nan = torch.tensor([float("nan"), 0, float("nan"), float("nan"), 1], dtype=dtype, device=device) - expected_inverse_nan = torch.tensor([0, 1, 1, 2, 3, 4], dtype=torch.long, device=device) - expected_counts_nan = torch.tensor([1, 2, 1, 1, 1], dtype=torch.long, device=device) + expected_unique_nan = torch.tensor( + [float("nan"), 0, float("nan"), float("nan"), 1], + dtype=dtype, + device=device, + ) + expected_inverse_nan = torch.tensor( + [0, 1, 1, 2, 3, 4], dtype=torch.long, device=device + ) + expected_counts_nan = torch.tensor( + [1, 2, 1, 1, 1], dtype=torch.long, device=device + ) # dim0 x_unique = torch.unique(x, dim=0) self.assertEqual(expected_unique_dim0, x_unique) - x_unique, x_inverse = torch.unique( - x, - return_inverse=True, - dim=0) + x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=0) self.assertEqual(expected_unique_dim0, x_unique) self.assertEqual(expected_inverse_dim0, x_inverse) x_unique, x_counts = torch.unique( - x, - return_inverse=False, - return_counts=True, - dim=0) + x, return_inverse=False, return_counts=True, dim=0 + ) self.assertEqual(expected_unique_dim0, x_unique) self.assertEqual(expected_counts_dim0, x_counts) x_unique, x_inverse, x_counts = torch.unique( - x, - return_inverse=True, - return_counts=True, - dim=0) + x, return_inverse=True, return_counts=True, dim=0 + ) self.assertEqual(expected_unique_dim0, x_unique) self.assertEqual(expected_inverse_dim0, x_inverse) self.assertEqual(expected_counts_dim0, x_counts) @@ -553,10 +620,7 @@ def run_test(device, dtype): else: self.assertEqual(expected_unique_dim1, x_unique) - x_unique, x_inverse = torch.unique( - x, - return_inverse=True, - dim=1) + x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=1) if x.dtype == torch.bool: self.assertEqual(expected_unique_dim1_bool, x_unique) self.assertEqual(expected_inverse_dim1_bool, x_inverse) @@ -565,10 +629,8 @@ def run_test(device, dtype): self.assertEqual(expected_inverse_dim1, x_inverse) x_unique, x_counts = torch.unique( - x, - return_inverse=False, - return_counts=True, - dim=1) + x, return_inverse=False, return_counts=True, dim=1 + ) if x.dtype == torch.bool: self.assertEqual(expected_unique_dim1_bool, x_unique) self.assertEqual(expected_counts_dim1_bool, x_counts) @@ -577,10 +639,8 @@ def run_test(device, dtype): self.assertEqual(expected_counts_dim1, x_counts) x_unique, x_inverse, x_counts = torch.unique( - x, - return_inverse=True, - return_counts=True, - dim=1) + x, return_inverse=True, return_counts=True, dim=1 + ) if x.dtype == torch.bool: self.assertEqual(expected_unique_dim1_bool, x_unique) self.assertEqual(expected_inverse_dim1_bool, x_inverse) @@ -594,36 +654,27 @@ def run_test(device, dtype): x_unique = torch.unique(x, dim=2) self.assertEqual(expected_unique_dim2, x_unique) - x_unique, x_inverse = torch.unique( - x, - return_inverse=True, - dim=2) + x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=2) self.assertEqual(expected_unique_dim2, x_unique) self.assertEqual(expected_inverse_dim2, x_inverse) x_unique, x_counts = torch.unique( - x, - return_inverse=False, - return_counts=True, - dim=2) + x, return_inverse=False, return_counts=True, dim=2 + ) self.assertEqual(expected_unique_dim2, x_unique) self.assertEqual(expected_counts_dim2, x_counts) x_unique, x_inverse, x_counts = torch.unique( - x, - return_inverse=True, - return_counts=True, - dim=2) + x, return_inverse=True, return_counts=True, dim=2 + ) self.assertEqual(expected_unique_dim2, x_unique) self.assertEqual(expected_inverse_dim2, x_inverse) self.assertEqual(expected_counts_dim2, x_counts) # test empty tensor x_unique, x_inverse, x_counts = torch.unique( - x_empty, - return_inverse=True, - return_counts=True, - dim=1) + x_empty, return_inverse=True, return_counts=True, dim=1 + ) self.assertEqual(expected_unique_empty, x_unique) self.assertEqual(expected_inverse_empty, x_inverse) self.assertEqual(expected_counts_empty, x_counts) @@ -631,10 +682,8 @@ def run_test(device, dtype): # test tensor with nan if dtype in floating_types_and(torch.float16, torch.bfloat16): x_unique, x_inverse, x_counts = torch.unique( - x_nan, - return_inverse=True, - return_counts=True, - dim=0) + x_nan, return_inverse=True, return_counts=True, dim=0 + ) self.assertEqual(expected_unique_nan, x_unique) self.assertEqual(expected_inverse_nan, x_inverse) self.assertEqual(expected_counts_nan, x_counts) @@ -643,10 +692,8 @@ def run_test(device, dtype): # Checking for runtime error, as this is the expected behaviour with self.assertRaises(RuntimeError): torch.unique( - x_ill_formed_empty, - return_inverse=True, - return_counts=True, - dim=1) + x_ill_formed_empty, return_inverse=True, return_counts=True, dim=1 + ) # test along dim2 with self.assertRaises(RuntimeError): @@ -654,46 +701,66 @@ def run_test(device, dtype): x_ill_formed_empty_another, return_inverse=True, return_counts=True, - dim=2) + dim=2, + ) # test consecutive version y = torch.tensor( - [[0, 1], - [0, 1], - [0, 1], - [1, 2], - [1, 2], - [3, 4], - [0, 1], - [0, 1], - [3, 4], - [1, 2]], + [ + [0, 1], + [0, 1], + [0, 1], + [1, 2], + [1, 2], + [3, 4], + [0, 1], + [0, 1], + [3, 4], + [1, 2], + ], dtype=dtype, - device=device + device=device, ) # test tensor with nan if dtype in floating_types_and(torch.float16, torch.bfloat16): - y_nan = torch.tensor([float("nan"), 0, 0, float("nan"), float("nan"), 1], dtype=dtype, device=device) + y_nan = torch.tensor( + [float("nan"), 0, 0, float("nan"), float("nan"), 1], + dtype=dtype, + device=device, + ) expected_y_unique = torch.tensor( - [[0, 1], - [1, 2], - [3, 4], - [0, 1], - [3, 4], - [1, 2]], + [[0, 1], [1, 2], [3, 4], [0, 1], [3, 4], [1, 2]], dtype=dtype, - device=device + device=device, + ) + expected_y_inverse = torch.tensor( + [0, 0, 0, 1, 1, 2, 3, 3, 4, 5], dtype=torch.int64, device=device + ) + expected_y_counts = torch.tensor( + [3, 2, 1, 2, 1, 1], dtype=torch.int64, device=device + ) + expected_y_inverse_bool = torch.tensor( + [0, 0, 0, 1, 1, 1, 2, 2, 3, 3], dtype=torch.int64, device=device + ) + expected_y_counts_bool = torch.tensor( + [3, 3, 2, 2], dtype=torch.int64, device=device ) - expected_y_inverse = torch.tensor([0, 0, 0, 1, 1, 2, 3, 3, 4, 5], dtype=torch.int64, device=device) - expected_y_counts = torch.tensor([3, 2, 1, 2, 1, 1], dtype=torch.int64, device=device) - expected_y_inverse_bool = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 3, 3], dtype=torch.int64, device=device) - expected_y_counts_bool = torch.tensor([3, 3, 2, 2], dtype=torch.int64, device=device) if dtype in floating_types_and(torch.float16, torch.bfloat16): - expected_y_unique_nan = torch.tensor([float("nan"), 0, float("nan"), float("nan"), 1], dtype=dtype, device=device) - expected_y_inverse_nan = torch.tensor([0, 1, 1, 2, 3, 4], dtype=torch.long, device=device) - expected_y_counts_nan = torch.tensor([1, 2, 1, 1, 1], dtype=torch.long, device=device) - - y_unique, y_inverse, y_counts = torch.unique_consecutive(y, return_inverse=True, return_counts=True, dim=0) + expected_y_unique_nan = torch.tensor( + [float("nan"), 0, float("nan"), float("nan"), 1], + dtype=dtype, + device=device, + ) + expected_y_inverse_nan = torch.tensor( + [0, 1, 1, 2, 3, 4], dtype=torch.long, device=device + ) + expected_y_counts_nan = torch.tensor( + [1, 2, 1, 1, 1], dtype=torch.long, device=device + ) + + y_unique, y_inverse, y_counts = torch.unique_consecutive( + y, return_inverse=True, return_counts=True, dim=0 + ) if x.dtype == torch.bool: self.assertEqual(expected_y_inverse_bool, y_inverse) self.assertEqual(expected_y_counts_bool, y_counts) @@ -704,23 +771,27 @@ def run_test(device, dtype): # test tensor with nan if dtype in floating_types_and(torch.float16, torch.bfloat16): y_unique, y_inverse, y_counts = torch.unique_consecutive( - y_nan, - return_inverse=True, - return_counts=True, - dim=0) + y_nan, return_inverse=True, return_counts=True, dim=0 + ) self.assertEqual(expected_y_unique_nan, y_unique) self.assertEqual(expected_y_inverse_nan, y_inverse) self.assertEqual(expected_y_counts_nan, y_counts) # Test dim is sorted same as NumPy with dims >= 3 - x = torch.tensor([[[[1, 0, 1, 0, 1, 1], - [0, 1, 1, 0, 1, 1]], - [[0, 1, 1, 0, 0, 1], - [0, 0, 0, 1, 0, 0]]], - [[[0, 1, 0, 1, 1, 1], - [0, 1, 1, 0, 1, 1]], - [[0, 0, 1, 1, 0, 1], - [1, 1, 0, 0, 0, 0]]]], dtype=dtype, device=device) + x = torch.tensor( + [ + [ + [[1, 0, 1, 0, 1, 1], [0, 1, 1, 0, 1, 1]], + [[0, 1, 1, 0, 0, 1], [0, 0, 0, 1, 0, 0]], + ], + [ + [[0, 1, 0, 1, 1, 1], [0, 1, 1, 0, 1, 1]], + [[0, 0, 1, 1, 0, 1], [1, 1, 0, 0, 0, 0]], + ], + ], + dtype=dtype, + device=device, + ) xn = x.cpu().numpy() for d in range(x.dim()): t = torch.unique(x, dim=d) @@ -750,15 +821,20 @@ def test_topk_noncontiguous_gpu(self, device): def _test_topk_dtype(self, device, dtype, integral, size): if integral: - a = torch.randint(torch.iinfo(dtype).min, torch.iinfo(dtype).max, - size=(size,), dtype=dtype, device=device) + a = torch.randint( + torch.iinfo(dtype).min, + torch.iinfo(dtype).max, + size=(size,), + dtype=dtype, + device=device, + ) else: a = torch.randn(size=(size,), dtype=dtype, device=device) - sort_topk = a.sort()[0][-(size // 2):].flip(0) + sort_topk = a.sort()[0][-(size // 2) :].flip(0) topk = a.topk(size // 2) - self.assertEqual(sort_topk, topk[0]) # check values - self.assertEqual(sort_topk, a[topk[1]]) # check indices + self.assertEqual(sort_topk, topk[0]) # check values + self.assertEqual(sort_topk, a[topk[1]]) # check indices @dtypes(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64) def test_topk_integral(self, device, dtype): @@ -770,7 +846,6 @@ def test_topk_integral(self, device, dtype): @dtypes(torch.bfloat16, torch.half) def test_topk_lower_precision(self, device, dtype): - small = 10 large = 4096 verylarge = 8192 # multi_block topk on cuda @@ -780,14 +855,20 @@ def test_topk_lower_precision(self, device, dtype): @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16)) @dtypes(torch.float, torch.double, torch.bfloat16, torch.half) def test_topk_nonfinite(self, device, dtype): - x = torch.tensor([float('nan'), float('inf'), 1e4, 0, -1e4, -float('inf')], device=device, dtype=dtype) + x = torch.tensor( + [float("nan"), float("inf"), 1e4, 0, -1e4, -float("inf")], + device=device, + dtype=dtype, + ) val, idx = x.topk(4) - expect = torch.tensor([float('nan'), float('inf'), 1e4, 0], device=device, dtype=dtype) + expect = torch.tensor( + [float("nan"), float("inf"), 1e4, 0], device=device, dtype=dtype + ) self.assertEqual(val, expect) self.assertEqual(idx, [0, 1, 2, 3]) val, idx = x.topk(4, largest=False) - expect = torch.tensor([-float('inf'), -1e4, 0, 1e4], device=device, dtype=dtype) + expect = torch.tensor([-float("inf"), -1e4, 0, 1e4], device=device, dtype=dtype) self.assertEqual(val, expect) self.assertEqual(idx, [5, 4, 3, 2]) @@ -796,13 +877,13 @@ def test_topk_4d(self, device): large = 8192 for size in (small, large): x = torch.ones(2, size, 2, 2, device=device) - x[:, 1, :, :] *= 2. + x[:, 1, :, :] *= 2.0 x[:, 10, :, :] *= 1.5 val, ind = torch.topk(x, k=2, dim=1) expected_ind = torch.ones(2, 2, 2, 2, dtype=torch.long, device=device) expected_ind[:, 1, :, :] = 10 expected_val = torch.ones(2, 2, 2, 2, device=device) - expected_val[:, 0, :, :] *= 2. + expected_val[:, 0, :, :] *= 2.0 expected_val[:, 1, :, :] *= 1.5 self.assertEqual(val, expected_val, atol=0, rtol=0) self.assertEqual(ind, expected_ind, atol=0, rtol=0) @@ -838,7 +919,17 @@ def _test_unique_scalar_empty(self, dtype, device, f): self.assertEqual(inverse, expected_inverse) self.assertEqual(counts, expected_counts) - def _test_unique_with_expects(self, device, dtype, f, x, expected_unique, expected_inverse, expected_counts, additional_shape): + def _test_unique_with_expects( + self, + device, + dtype, + f, + x, + expected_unique, + expected_inverse, + expected_counts, + additional_shape, + ): def ensure_tuple(x): if isinstance(x, torch.Tensor): return (x,) @@ -847,7 +938,9 @@ def ensure_tuple(x): for return_inverse in [True, False]: for return_counts in [True, False]: # test with expected - ret = ensure_tuple(f(x, return_inverse=return_inverse, return_counts=return_counts)) + ret = ensure_tuple( + f(x, return_inverse=return_inverse, return_counts=return_counts) + ) self.assertEqual(len(ret), 1 + int(return_inverse) + int(return_counts)) self.assertEqual(expected_unique, ret[0]) if return_inverse: @@ -858,7 +951,9 @@ def ensure_tuple(x): # tests per-element unique on a higher rank tensor. y = x.view(additional_shape) - y_unique, y_inverse, y_counts = f(y, return_inverse=True, return_counts=True) + y_unique, y_inverse, y_counts = f( + y, return_inverse=True, return_counts=True + ) self.assertEqual(expected_unique, y_unique) self.assertEqual(expected_inverse.view(additional_shape), y_inverse) self.assertEqual(expected_counts, y_counts) @@ -872,9 +967,17 @@ def ensure_tuple(x): return x if dtype is torch.bool: - x = torch.tensor([True, False, False, False, True, False, True, False], dtype=torch.bool, device=device) - expected_unique = torch.tensor([False, True], dtype=torch.bool, device=device) - expected_inverse = torch.tensor([1, 0, 0, 0, 1, 0, 1, 0], dtype=torch.long, device=device) + x = torch.tensor( + [True, False, False, False, True, False, True, False], + dtype=torch.bool, + device=device, + ) + expected_unique = torch.tensor( + [False, True], dtype=torch.bool, device=device + ) + expected_inverse = torch.tensor( + [1, 0, 0, 0, 1, 0, 1, 0], dtype=torch.long, device=device + ) expected_counts = torch.tensor([5, 3], dtype=torch.long, device=device) else: x = torch.tensor([1, 2, 3, 2, 8, 5, 2, 3], dtype=dtype, device=device) @@ -890,18 +993,29 @@ def ensure_tuple(x): x_sliced = torch.empty(x.size(0) * 2, dtype=dtype, device=device)[::2].copy_(x) xs = (x, x_sliced) for f, x in product(fs, xs): - self._test_unique_with_expects(device, dtype, f, x, expected_unique, expected_inverse, expected_counts, (2, 2, 2)) + self._test_unique_with_expects( + device, + dtype, + f, + x, + expected_unique, + expected_inverse, + expected_counts, + (2, 2, 2), + ) self._test_unique_scalar_empty(dtype, device, f) # test unsorted unique fs = ( lambda x, **kwargs: torch.unique(x, sorted=False, **kwargs), - lambda x, **kwargs: x.unique(sorted=False, **kwargs) + lambda x, **kwargs: x.unique(sorted=False, **kwargs), ) for f, x in product(fs, xs): self._test_unique_scalar_empty(dtype, device, f) for return_inverse, return_counts in product((True, False), repeat=2): - ret = ensure_tuple(f(x, return_inverse=return_inverse, return_counts=return_counts)) + ret = ensure_tuple( + f(x, return_inverse=return_inverse, return_counts=return_counts) + ) self.assertEqual(len(ret), 1 + int(return_inverse) + int(return_counts)) x_list = x.tolist() x_unique_list = ret[0].tolist() @@ -924,18 +1038,40 @@ def ensure_tuple(x): @dtypes(*all_types_and(torch.half, torch.bool)) def test_unique_consecutive(self, device, dtype): if dtype is torch.bool: - x = torch.tensor([True, False, False, False, True, True, False, False, False], dtype=torch.bool, device=device) - expected_unique = torch.tensor([True, False, True, False], dtype=torch.bool, device=device) - expected_inverse = torch.tensor([0, 1, 1, 1, 2, 2, 3, 3, 3], dtype=torch.long, device=device) - expected_counts = torch.tensor([1, 3, 2, 3], dtype=torch.long, device=device) + x = torch.tensor( + [True, False, False, False, True, True, False, False, False], + dtype=torch.bool, + device=device, + ) + expected_unique = torch.tensor( + [True, False, True, False], dtype=torch.bool, device=device + ) + expected_inverse = torch.tensor( + [0, 1, 1, 1, 2, 2, 3, 3, 3], dtype=torch.long, device=device + ) + expected_counts = torch.tensor( + [1, 3, 2, 3], dtype=torch.long, device=device + ) else: x = torch.tensor([1, 2, 2, 2, 5, 5, 2, 2, 3], dtype=dtype, device=device) expected_unique = torch.tensor([1, 2, 5, 2, 3], dtype=dtype, device=device) expected_inverse = torch.tensor([0, 1, 1, 1, 2, 2, 3, 3, 4], device=device) expected_counts = torch.tensor([1, 3, 2, 2, 1], device=device) - for f in [torch.unique_consecutive, lambda x, **kwargs: x.unique_consecutive(**kwargs)]: - self._test_unique_with_expects(device, dtype, f, x, expected_unique, expected_inverse, expected_counts, (3, 3)) + for f in [ + torch.unique_consecutive, + lambda x, **kwargs: x.unique_consecutive(**kwargs), + ]: + self._test_unique_with_expects( + device, + dtype, + f, + x, + expected_unique, + expected_inverse, + expected_counts, + (3, 3), + ) self._test_unique_scalar_empty(dtype, device, f) @dtypes(torch.double) @@ -991,7 +1127,7 @@ def test_kthvalue(self, device, dtype): self.assertEqual(x, x0, atol=0, rtol=0) # simple test case (with repetitions) - y = torch.tensor((3., 5, 4, 1, 1, 5), dtype=dtype, device=device) + y = torch.tensor((3.0, 5, 4, 1, 1, 5), dtype=dtype, device=device) self.assertEqual(torch.kthvalue(y, 3)[0], 3, atol=0, rtol=0) self.assertEqual(torch.kthvalue(y, 2)[0], 1, atol=0, rtol=0) @@ -1007,7 +1143,7 @@ def test_kthvalue(self, device, dtype): self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], atol=0, rtol=0) @dtypes(torch.float) - @onlyNativeDeviceTypes # Fails on XLA + @onlyNativeDeviceTypes # Fails on XLA def test_kthvalue_scalar(self, device, dtype): # Test scalar input (test case from https://github.com/pytorch/pytorch/issues/30818) # Tests that passing a scalar tensor or 1D tensor with 1 element work either way @@ -1029,7 +1165,9 @@ def assert_isin_equal(a, b): # multi-dim tensor, multi-dim tensor a = torch.arange(24, device=device, dtype=dtype).reshape([2, 3, 4]) - b = torch.tensor([[10, 20, 30], [0, 1, 3], [11, 22, 33]], device=device, dtype=dtype) + b = torch.tensor( + [[10, 20, 30], [0, 1, 3], [11, 22, 33]], device=device, dtype=dtype + ) assert_isin_equal(a, b) # zero-dim tensor @@ -1073,16 +1211,56 @@ def define_expected(lst, invert=False): c = torch.isin(a, b, assume_unique=True, invert=invert) self.assertEqual(c, ec) - a = torch.tensor([5, 4, 5, 3, 4, 4, 3, 4, 3, 5, 2, 1, 5, 5], device=device, dtype=dtype) + a = torch.tensor( + [5, 4, 5, 3, 4, 4, 3, 4, 3, 5, 2, 1, 5, 5], + device=device, + dtype=dtype, + ) b = torch.tensor([2, 3, 4] * mult, device=device, dtype=dtype) - ec = define_expected([False, True, False, True, True, True, True, True, True, - False, True, False, False, False], invert=invert) + ec = define_expected( + [ + False, + True, + False, + True, + True, + True, + True, + True, + True, + False, + True, + False, + False, + False, + ], + invert=invert, + ) c = torch.isin(a, b, invert=invert) self.assertEqual(c, ec) - b = torch.tensor([2, 3, 4] * mult + [5, 5, 4] * mult, device=device, dtype=dtype) - ec = define_expected([True, True, True, True, True, True, True, True, True, True, - True, False, True, True], invert=invert) + b = torch.tensor( + [2, 3, 4] * mult + [5, 5, 4] * mult, device=device, dtype=dtype + ) + ec = define_expected( + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + True, + True, + ], + invert=invert, + ) c = torch.isin(a, b, invert=invert) self.assertEqual(c, ec) @@ -1108,12 +1286,14 @@ def define_expected(lst, invert=False): for assume_unique in [False, True]: a = torch.arange(6, device=device, dtype=dtype).reshape([2, 3]) b = torch.arange(3, 30, device=device, dtype=dtype) - ec = define_expected([[False, False, False], [True, True, True]], invert=invert) + ec = define_expected( + [[False, False, False], [True, True, True]], invert=invert + ) c = torch.isin(a, b, invert=invert, assume_unique=assume_unique) self.assertEqual(c, ec) def test_isin_different_dtypes(self, device): - supported_types = all_types() if device == 'cpu' else all_types_and(torch.half) + supported_types = all_types() if device == "cpu" else all_types_and(torch.half) for mult in [1, 10]: for assume_unique in [False, True]: for dtype1, dtype2 in product(supported_types, supported_types): @@ -1127,18 +1307,18 @@ def test_isin_different_dtypes(self, device): @dtypes(*all_types()) def test_isin_different_devices(self, device, dtype): a = torch.arange(6, device=device, dtype=dtype).reshape([2, 3]) - b = torch.arange(3, 30, device='cpu', dtype=dtype) + b = torch.arange(3, 30, device="cpu", dtype=dtype) with self.assertRaises(RuntimeError): torch.isin(a, b) - c = torch.arange(6, device='cpu', dtype=dtype).reshape([2, 3]) + c = torch.arange(6, device="cpu", dtype=dtype).reshape([2, 3]) d = torch.arange(3, 30, device=device, dtype=dtype) with self.assertRaises(RuntimeError): torch.isin(c, d) @dtypes(*integral_types()) def test_sort_overflow(self, device, dtype): - " Regression test for https://github.com/pytorch/pytorch/issues/111189 " + "Regression test for https://github.com/pytorch/pytorch/issues/111189" prev_num_threads = torch.get_num_threads() try: low = 0 if dtype == torch.uint8 else -1 @@ -1153,5 +1333,5 @@ def test_sort_overflow(self, device, dtype): instantiate_device_type_tests(TestSortAndSelect, globals()) -if __name__ == '__main__': +if __name__ == "__main__": run_tests() From 4ee003abdfee3eef701735d6835db0e7b9f8ec61 Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Thu, 30 May 2024 17:26:49 +0100 Subject: [PATCH 124/706] [inductor] Repeat should not return a view (#127533) Fixes #127474 `as_strided` unwraps views and looks at the underlying storage, so it isn't legal to lower `repeat`, which should return a new storage, into a view. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127533 Approved by: https://github.com/lezcano --- test/inductor/test_torchinductor.py | 13 +++++++++++++ .../test_torchinductor_codegen_dynamic_shapes.py | 1 + torch/_inductor/lowering.py | 2 +- 3 files changed, 15 insertions(+), 1 deletion(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index c6f5e61ac0f9..5b5a56394372 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -4292,6 +4292,19 @@ def fn(x): (torch.randn([1, 2, 4, 8]),), ) + def test_repeat_as_strided(self): + # Reproducer for #127474 + + def fn(x): + view_size = (3, 2) + full = x.repeat((3, 2)) + view = torch.as_strided(full, view_size, full.stride()) + result = view + view + + return result + + self.common(fn, (torch.randn(1, 1),)) + def test_repeat_interleave(self): def fn(x): return ( diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py index b1ccc49df499..58deed4460d8 100644 --- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py @@ -135,6 +135,7 @@ def run(*ex, **kwargs): "test_zeros_dynamic_shapes": TestFailure(("cpu",)), "test_uint_dynamic_shapes": TestFailure(("cpu",)), "test_issue102546_dynamic_shapes": TestFailure(("cpu",)), + "test_repeat_as_strided_dynamic_shapes": TestFailure(("cpu",)), # # Failed to find for loop/triton kernel: # diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index fcf77cae6e3a..20b0082eb1d9 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -892,7 +892,7 @@ def repeat(x, repeats): if zero_tensor: return empty(new_size, dtype=x.get_dtype(), device=x.get_device()) if all((a == 1 or b == 1) for a, b in zip(repeats, old_size)): - return expand(x, new_size) + return clone(expand(x, new_size)) x_loader: Callable[[Any], Any] From e02971fcfb3a972781c713cfa2f677d451f0306a Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 30 May 2024 22:06:46 +0000 Subject: [PATCH 125/706] Revert "Enable UFMT on test_shape_ops.py test_show_pickle.py test_sort_and_select.py (#127165)" This reverts commit a288b95d4e5ceed327c5bdb9696331aa87688d60. Reverted https://github.com/pytorch/pytorch/pull/127165 on behalf of https://github.com/atalman due to lint is failing ([comment](https://github.com/pytorch/pytorch/pull/127165#issuecomment-2140930658)) --- .lintrunner.toml | 3 + test/test_shape_ops.py | 243 +++++-------- test/test_show_pickle.py | 13 +- test/test_sort_and_select.py | 646 +++++++++++++---------------------- 4 files changed, 318 insertions(+), 587 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index db202cfd0fa4..1e0a2f37fcf4 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1115,6 +1115,9 @@ exclude_patterns = [ 'test/test_segment_reductions.py', 'test/test_serialization.py', 'test/test_set_default_mobile_cpu_allocator.py', + 'test/test_shape_ops.py', + 'test/test_show_pickle.py', + 'test/test_sort_and_select.py', 'test/test_sparse.py', 'test/test_sparse_csr.py', 'test/test_sparse_semi_structured.py', diff --git a/test/test_shape_ops.py b/test/test_shape_ops.py index 5ea4139888bc..47acfff9c6d4 100644 --- a/test/test_shape_ops.py +++ b/test/test_shape_ops.py @@ -1,40 +1,22 @@ # Owner(s): ["module: tests"] +import torch +import numpy as np + +from itertools import product, combinations, permutations, chain +from functools import partial import random -import unittest import warnings -from functools import partial - -from itertools import chain, combinations, permutations, product - -import numpy as np -import torch +import unittest from torch import nan from torch.testing import make_tensor -from torch.testing._internal.common_device_type import ( - dtypes, - dtypesIfCUDA, - instantiate_device_type_tests, - largeTensorTest, - onlyCPU, - onlyCUDA, - onlyNativeDeviceTypes, -) -from torch.testing._internal.common_dtype import ( - all_types, - all_types_and, - all_types_and_complex_and, -) from torch.testing._internal.common_utils import ( - IS_JETSON, - run_tests, - skipIfTorchDynamo, - TEST_PRIVATEUSE1_DEVICE_TYPE, - TestCase, - torch_to_numpy_dtype_dict, -) - + TestCase, run_tests, skipIfTorchDynamo, torch_to_numpy_dtype_dict, IS_JETSON, TEST_PRIVATEUSE1_DEVICE_TYPE) +from torch.testing._internal.common_device_type import ( + instantiate_device_type_tests, onlyCPU, onlyCUDA, dtypes, onlyNativeDeviceTypes, + dtypesIfCUDA, largeTensorTest) +from torch.testing._internal.common_dtype import all_types_and_complex_and, all_types, all_types_and # TODO: replace with make_tensor def _generate_input(shape, dtype, device, with_extremal): @@ -47,19 +29,17 @@ def _generate_input(shape, dtype, device, with_extremal): x = torch.randn(*shape, device=device) * random.randint(30, 100) x = x.to(torch.bfloat16) else: - x = torch.randn(*shape, dtype=dtype, device=device) * random.randint( - 30, 100 - ) + x = torch.randn(*shape, dtype=dtype, device=device) * random.randint(30, 100) x[torch.randn(*shape) > 0.5] = 0 if with_extremal and dtype.is_floating_point: # Use extremal values - x[torch.randn(*shape) > 0.5] = float("nan") - x[torch.randn(*shape) > 0.5] = float("inf") - x[torch.randn(*shape) > 0.5] = float("-inf") + x[torch.randn(*shape) > 0.5] = float('nan') + x[torch.randn(*shape) > 0.5] = float('inf') + x[torch.randn(*shape) > 0.5] = float('-inf') elif with_extremal and dtype.is_complex: - x[torch.randn(*shape) > 0.5] = complex("nan") - x[torch.randn(*shape) > 0.5] = complex("inf") - x[torch.randn(*shape) > 0.5] = complex("-inf") + x[torch.randn(*shape) > 0.5] = complex('nan') + x[torch.randn(*shape) > 0.5] = complex('inf') + x[torch.randn(*shape) > 0.5] = complex('-inf') elif dtype == torch.bool: x = torch.zeros(shape, dtype=dtype, device=device) x[torch.randn(*shape) > 0.5] = True @@ -68,8 +48,8 @@ def _generate_input(shape, dtype, device, with_extremal): return x - class TestShapeOps(TestCase): + # TODO: update to work on CUDA, too @onlyCPU def test_unbind(self, device): @@ -91,7 +71,7 @@ def test_tolist(self, device): tensor0D = torch.tensor(list0D) self.assertEqual(tensor0D.tolist(), list0D) - table1D = [1.0, 2.0, 3.0] + table1D = [1., 2., 3.] tensor1D = torch.tensor(table1D) storage = torch.Storage(table1D) self.assertEqual(tensor1D.tolist(), table1D) @@ -122,29 +102,19 @@ def test_movedim_invalid(self, device, dtype): fn(x, 0, 5) # Mismatch in size of `source` and `destination` - with self.assertRaisesRegex( - RuntimeError, "movedim: Invalid source or destination dims:" - ): - fn(x, (1, 0), (0,)) - - with self.assertRaisesRegex( - RuntimeError, "movedim: repeated dim in `source`" - ): + with self.assertRaisesRegex(RuntimeError, "movedim: Invalid source or destination dims:"): + fn(x, (1, 0), (0, )) + + with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `source`"): fn(x, (0, 0), (0, 1)) - with self.assertRaisesRegex( - RuntimeError, "movedim: repeated dim in `source`" - ): + with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `source`"): fn(x, (0, 1, 0), (0, 1, 2)) - with self.assertRaisesRegex( - RuntimeError, "movedim: repeated dim in `destination`" - ): + with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `destination`"): fn(x, (0, 1), (1, 1)) - with self.assertRaisesRegex( - RuntimeError, "movedim: repeated dim in `destination`" - ): + with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `destination`"): fn(x, (0, 1, 2), (1, 0, 1)) @dtypes(torch.int64, torch.float, torch.complex128) @@ -167,12 +137,8 @@ def test_movedim(self, device, dtype): # Integer `source` and `destination` torch_fn = partial(fn, source=src_dim, destination=dst_dim) - np_fn = partial( - np.moveaxis, source=src_dim, destination=dst_dim - ) - self.compare_with_numpy( - torch_fn, np_fn, x, device=None, dtype=None - ) + np_fn = partial(np.moveaxis, source=src_dim, destination=dst_dim) + self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None) if nd == 0: continue @@ -182,13 +148,9 @@ def make_index_negative(sequence, idx): sequence[random_idx] = sequence[random_idx] - nd return tuple(src_sequence) - for src_sequence in permutations( - range(nd), r=random.randint(1, nd) - ): + for src_sequence in permutations(range(nd), r=random.randint(1, nd)): # Sequence `source` and `destination` - dst_sequence = tuple( - random.sample(range(nd), len(src_sequence)) - ) + dst_sequence = tuple(random.sample(range(nd), len(src_sequence))) # Randomly change a dim to a negative dim representation of itself. random_prob = random.random() @@ -204,15 +166,9 @@ def make_index_negative(sequence, idx): random_idx = random.randint(0, len(src_sequence) - 1) src_sequence = make_index_negative(src_sequence, random_idx) - torch_fn = partial( - fn, source=src_sequence, destination=dst_sequence - ) - np_fn = partial( - np.moveaxis, source=src_sequence, destination=dst_sequence - ) - self.compare_with_numpy( - torch_fn, np_fn, x, device=None, dtype=None - ) + torch_fn = partial(fn, source=src_sequence, destination=dst_sequence) + np_fn = partial(np.moveaxis, source=src_sequence, destination=dst_sequence) + self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None) # Move dim to same position x = torch.randn(2, 3, 5, 7, 11) @@ -257,7 +213,10 @@ def test_diagonal(self, device): def test_diagonal_multidim(self, device, dtype): x = torch.randn(10, 11, 12, 13, dtype=dtype, device=device) xn = x.numpy() - for args in [(2, 2, 3), (2,), (-2, 1, 2), (0, -2, -1)]: + for args in [(2, 2, 3), + (2,), + (-2, 1, 2), + (0, -2, -1)]: result = torch.diagonal(x, *args) expected = xn.diagonal(*args) self.assertEqual(expected.shape, result.shape) @@ -311,22 +270,14 @@ def generate_clamp_baseline(self, device, dtype, *, min_vals, max_vals, with_nan max_vals = max_vals.cpu().numpy() # Use NumPy implementation as reference - X_clamped = torch.tensor( - np.clip(X.cpu().numpy(), a_min=min_vals, a_max=max_vals), device=device - ) + X_clamped = torch.tensor(np.clip(X.cpu().numpy(), a_min=min_vals, a_max=max_vals), device=device) return X, X_clamped # Tests clamp and its alias, clip @dtypes(torch.int64, torch.float32) def test_clamp(self, device, dtype): - op_list = ( - torch.clamp, - torch.Tensor.clamp, - torch.Tensor.clamp_, - torch.clip, - torch.Tensor.clip, - torch.Tensor.clip_, - ) + op_list = (torch.clamp, torch.Tensor.clamp, torch.Tensor.clamp_, + torch.clip, torch.Tensor.clip, torch.Tensor.clip_) # min/max argument product args = product((-10, None), (10, None)) @@ -336,9 +287,10 @@ def test_clamp(self, device, dtype): if min_val is None and max_val is None: continue - X, Y_expected = self.generate_clamp_baseline( - device, dtype, min_vals=min_val, max_vals=max_val, with_nans=False - ) + X, Y_expected = self.generate_clamp_baseline(device, dtype, + min_vals=min_val, + max_vals=max_val, + with_nans=False) # Test op X1 = X.clone() # So that the in-place ops do not change X @@ -352,14 +304,8 @@ def test_clamp(self, device, dtype): self.assertEqual(Y_expected, Y_out) def test_clamp_propagates_nans(self, device): - op_list = ( - torch.clamp, - torch.Tensor.clamp, - torch.Tensor.clamp_, - torch.clip, - torch.Tensor.clip, - torch.Tensor.clip_, - ) + op_list = (torch.clamp, torch.Tensor.clamp, torch.Tensor.clamp_, + torch.clip, torch.Tensor.clip, torch.Tensor.clip_) # min/max argument product args = product((-10, None), (10, None)) @@ -369,13 +315,10 @@ def test_clamp_propagates_nans(self, device): if min_val is None and max_val is None: continue - X, Y_expected = self.generate_clamp_baseline( - device, - torch.float, - min_vals=min_val, - max_vals=max_val, - with_nans=True, - ) + X, Y_expected = self.generate_clamp_baseline(device, torch.float, + min_vals=min_val, + max_vals=max_val, + with_nans=True) Y_expected = torch.isnan(Y_expected) # Test op @@ -391,7 +334,7 @@ def test_clamp_propagates_nans(self, device): def test_clamp_raises_arg_errors(self, device): X = torch.randn(100, dtype=torch.float, device=device) - error_msg = "At least one of 'min' or 'max' must not be None" + error_msg = 'At least one of \'min\' or \'max\' must not be None' with self.assertRaisesRegex(RuntimeError, error_msg): X.clamp() with self.assertRaisesRegex(RuntimeError, error_msg): @@ -426,22 +369,18 @@ def all_t(): self.assertEqual(in_t.flip(p_dims), out_t) if len(p_dims) > 0: # Wrap 1st dim - self.assertEqual( - in_t.flip((-n + p_dims[0],) + p_dims[1:]), out_t - ) + self.assertEqual(in_t.flip((-n + p_dims[0],) + p_dims[1:]), out_t) def gen_data(): # Basic tests data = make_from_data([1, 2, 3, 4, 5, 6, 7, 8]).view(2, 2, 2) nonctg = make_from_size((2, 2, 2), noncontiguous=True).copy_(data) - dims_result = ( - (0, make_from_data([5, 6, 7, 8, 1, 2, 3, 4]).view(2, 2, 2)), - (1, make_from_data([3, 4, 1, 2, 7, 8, 5, 6]).view(2, 2, 2)), - (2, make_from_data([2, 1, 4, 3, 6, 5, 8, 7]).view(2, 2, 2)), - ((0, 1), make_from_data([7, 8, 5, 6, 3, 4, 1, 2]).view(2, 2, 2)), - ((0, 1, 2), make_from_data([8, 7, 6, 5, 4, 3, 2, 1]).view(2, 2, 2)), - ) + dims_result = ((0, make_from_data([5, 6, 7, 8, 1, 2, 3, 4]).view(2, 2, 2)), + (1, make_from_data([3, 4, 1, 2, 7, 8, 5, 6]).view(2, 2, 2)), + (2, make_from_data([2, 1, 4, 3, 6, 5, 8, 7]).view(2, 2, 2)), + ((0, 1), make_from_data([7, 8, 5, 6, 3, 4, 1, 2]).view(2, 2, 2)), + ((0, 1, 2), make_from_data([8, 7, 6, 5, 4, 3, 2, 1]).view(2, 2, 2))) for in_tensor, (dims, out_tensor) in product((data, nonctg), dims_result): yield in_tensor, dims, out_tensor @@ -454,9 +393,7 @@ def gen_data(): yield in_t, 1, in_t # Transposed - in_t = ( - make_from_data([1, 2, 3, 4, 5, 6, 7, 8]).view(2, 2, 2).transpose(0, 1) - ) + in_t = make_from_data([1, 2, 3, 4, 5, 6, 7, 8]).view(2, 2, 2).transpose(0, 1) dims = (0, 1, 2) out_t = make_from_data([8, 7, 4, 3, 6, 5, 2, 1]).view(2, 2, 2) yield in_t, dims, out_t @@ -474,9 +411,7 @@ def gen_data(): if device == "cpu" and dtype != torch.bfloat16: for mf in [torch.contiguous_format, torch.channels_last]: for c in [2, 3, 8, 16]: - in_t = make_from_size((2, c, 32, 32)).contiguous( - memory_format=mf - ) + in_t = make_from_size((2, c, 32, 32)).contiguous(memory_format=mf) np_in_t = in_t.numpy() np_out_t = np_in_t[:, :, :, ::-1].copy() @@ -529,9 +464,7 @@ def gen_data(): size = [2, 3, 4] data = make_from_size(size) possible_dims = range(len(size)) - test_dims = chain( - combinations(possible_dims, 1), combinations(possible_dims, 2) - ) + test_dims = chain(combinations(possible_dims, 1), combinations(possible_dims, 2)) for dims in test_dims: self.assertEqual(size, list(data.flip(dims).size())) @@ -550,6 +483,7 @@ def test_flip_errors(self, device, dtype): self.assertRaises(IndexError, lambda: data.flip(0, 1, 2, 3)) self.assertRaises(IndexError, lambda: data.flip(3)) + def _rand_shape(self, dim, min_size, max_size): return tuple(torch.randint(min_size, max_size + 1, (dim,))) @@ -570,10 +504,8 @@ def test_flip_numpy(self, device, dtype): self.compare_with_numpy(torch_fn, np_fn, data) @onlyCUDA # CPU is too slow - @largeTensorTest("17GB") # 4 tensors of 4GB (in, out) x (torch, numpy) + 1GB - @largeTensorTest( - "81GB", "cpu" - ) # even for CUDA test, sufficient system memory is required + @largeTensorTest('17GB') # 4 tensors of 4GB (in, out) x (torch, numpy) + 1GB + @largeTensorTest("81GB", "cpu") # even for CUDA test, sufficient system memory is required @unittest.skipIf(IS_JETSON, "Too large for Jetson") def test_flip_large_tensor(self, device): t_in = torch.empty(2**32 + 1, dtype=torch.uint8).random_() @@ -637,9 +569,7 @@ def test_rot90(self, device): # test tensor with more than 2D data = torch.arange(1, 9, device=device).view(2, 2, 2) - self.assertEqual( - torch.tensor([2, 4, 1, 3, 6, 8, 5, 7]).view(2, 2, 2), data.rot90(1, [1, 2]) - ) + self.assertEqual(torch.tensor([2, 4, 1, 3, 6, 8, 5, 7]).view(2, 2, 2), data.rot90(1, [1, 2])) self.assertEqual(data.rot90(1, [1, -1]), data.rot90(1, [1, 2])) # test for errors @@ -671,6 +601,7 @@ def test_nonzero_no_warning(self, device): @dtypes(*all_types_and(torch.half, torch.bool, torch.bfloat16)) def test_nonzero(self, device, dtype): + shapes = [ torch.Size((12,)), torch.Size((12, 1)), @@ -685,9 +616,7 @@ def gen_nontrivial_input(shape, dtype, device): return torch.randint(2, shape, device=device, dtype=dtype) else: # windows does not work for bfloat16 randing - return torch.randint(2, shape, device=device, dtype=torch.float).to( - dtype - ) + return torch.randint(2, shape, device=device, dtype=torch.float).to(dtype) for shape in shapes: tensor = gen_nontrivial_input(shape, dtype, device) @@ -695,31 +624,20 @@ def gen_nontrivial_input(shape, dtype, device): dst2 = tensor.nonzero(as_tuple=False) dst3 = torch.empty([], dtype=torch.long, device=device) torch.nonzero(tensor, out=dst3) - if self.device_type != "xla": + if self.device_type != 'xla': # xla does not raise runtime error self.assertRaisesRegex( RuntimeError, "scalar type Long", - lambda: torch.nonzero( - tensor, out=torch.empty([], dtype=torch.float, device=device) - ), + lambda: torch.nonzero(tensor, out=torch.empty([], dtype=torch.float, device=device)) ) - if ( - self.device_type == "cuda" - or self.device_type == TEST_PRIVATEUSE1_DEVICE_TYPE - ): + if self.device_type == 'cuda' or self.device_type == TEST_PRIVATEUSE1_DEVICE_TYPE: self.assertRaisesRegex( RuntimeError, "on the same device", - lambda: torch.nonzero( - tensor, out=torch.empty([], dtype=torch.long) - ), + lambda: torch.nonzero(tensor, out=torch.empty([], dtype=torch.long)) ) - np_array = ( - tensor.cpu().numpy() - if dtype != torch.bfloat16 - else tensor.float().cpu().numpy() - ) + np_array = tensor.cpu().numpy() if dtype != torch.bfloat16 else tensor.float().cpu().numpy() np_result = torch.from_numpy(np.stack(np_array.nonzero())).t() self.assertEqual(dst1.cpu(), np_result, atol=0, rtol=0) self.assertEqual(dst2.cpu(), np_result, atol=0, rtol=0) @@ -738,9 +656,7 @@ def test_nonzero_astuple_out(self, device): with self.assertRaises(RuntimeError): torch.nonzero(t, as_tuple=True, out=out) - self.assertEqual( - torch.nonzero(t, as_tuple=False, out=out), torch.nonzero(t, out=out) - ) + self.assertEqual(torch.nonzero(t, as_tuple=False, out=out), torch.nonzero(t, out=out)) # Verifies that JIT script cannot handle the as_tuple kwarg # See Issue https://github.com/pytorch/pytorch/issues/45499. @@ -768,9 +684,7 @@ def _foo(t): def test_nonzero_discontiguous(self, device): shape = (4, 4) tensor = torch.randint(2, shape, device=device) - tensor_nc = torch.empty(shape[0], shape[1] * 2, device=device)[:, ::2].copy_( - tensor - ) + tensor_nc = torch.empty(shape[0], shape[1] * 2, device=device)[:, ::2].copy_(tensor) dst1 = tensor.nonzero(as_tuple=False) dst2 = tensor_nc.nonzero(as_tuple=False) self.assertEqual(dst1, dst2, atol=0, rtol=0) @@ -781,9 +695,7 @@ def test_nonzero_discontiguous(self, device): self.assertEqual(data_ptr, dst3.data_ptr()) self.assertEqual(dst1, dst3, atol=0, rtol=0) # discontiguous out - dst4 = torch.empty( - dst1.size(0), dst1.size(1) * 2, dtype=torch.long, device=device - )[:, ::2] + dst4 = torch.empty(dst1.size(0), dst1.size(1) * 2, dtype=torch.long, device=device)[:, ::2] data_ptr = dst4.data_ptr() strides = dst4.stride() torch.nonzero(tensor, out=dst4) @@ -798,7 +710,7 @@ def test_nonzero_non_diff(self, device): @dtypes(torch.int64, torch.float, torch.complex128) def test_sparse_dense_dim(self, device, dtype): - for shape in [(), (2,), (2, 3)]: + for shape in [(), (2, ), (2, 3)]: if dtype.is_complex or dtype.is_floating_point: x = torch.rand(shape, device=device, dtype=dtype) else: @@ -806,8 +718,7 @@ def test_sparse_dense_dim(self, device, dtype): self.assertEqual(x.sparse_dim(), 0) self.assertEqual(x.dense_dim(), len(shape)) - instantiate_device_type_tests(TestShapeOps, globals()) -if __name__ == "__main__": +if __name__ == '__main__': run_tests() diff --git a/test/test_show_pickle.py b/test/test_show_pickle.py index 48b459e12eac..929584943007 100644 --- a/test/test_show_pickle.py +++ b/test/test_show_pickle.py @@ -1,16 +1,15 @@ # Owner(s): ["oncall: mobile"] +import unittest import io import tempfile -import unittest - import torch import torch.utils.show_pickle -from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase - +from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS class TestShowPickle(TestCase): + @unittest.skipIf(IS_WINDOWS, "Can't re-open temp file on Windows") def test_scripted_model(self): class MyCoolModule(torch.nn.Module): @@ -27,13 +26,11 @@ def forward(self, x): torch.jit.save(m, tmp) tmp.flush() buf = io.StringIO() - torch.utils.show_pickle.main( - ["", tmp.name + "@*/data.pkl"], output_stream=buf - ) + torch.utils.show_pickle.main(["", tmp.name + "@*/data.pkl"], output_stream=buf) output = buf.getvalue() self.assertRegex(output, "MyCoolModule") self.assertRegex(output, "weight") -if __name__ == "__main__": +if __name__ == '__main__': run_tests() diff --git a/test/test_sort_and_select.py b/test/test_sort_and_select.py index 73414bb28be3..7709131e6102 100644 --- a/test/test_sort_and_select.py +++ b/test/test_sort_and_select.py @@ -1,68 +1,44 @@ # Owner(s): ["module: tests"] -import random -from itertools import permutations, product - -import numpy as np import torch +import numpy as np + +import random from torch import nan +from itertools import permutations, product from torch.testing import make_tensor -from torch.testing._internal.common_device_type import ( - dtypes, - dtypesIfCPU, - dtypesIfCUDA, - instantiate_device_type_tests, - largeTensorTest, - onlyCPU, - onlyCUDA, - onlyNativeDeviceTypes, -) -from torch.testing._internal.common_dtype import ( - all_types, - all_types_and, - floating_types_and, - integral_types, -) -from torch.testing._internal.common_utils import ( - run_tests, - skipIfTorchDynamo, - slowTest, - TestCase, -) +from torch.testing._internal.common_dtype import all_types, all_types_and, floating_types_and, integral_types +from torch.testing._internal.common_utils import \ + (TestCase, run_tests, slowTest, skipIfTorchDynamo) +from torch.testing._internal.common_device_type import \ + (instantiate_device_type_tests, dtypes, onlyNativeDeviceTypes, + onlyCUDA, dtypesIfCUDA, dtypesIfCPU, onlyCPU, largeTensorTest) # TODO: remove this SIZE = 100 - class TestSortAndSelect(TestCase): + def assertIsOrdered(self, order, x, mxx, ixx, task): SIZE = x.size(1) - if order == "descending": - + if order == 'descending': def check_order(a, b): # `a != a` because we put NaNs # at the end of ascending sorted lists, # and the beginning of descending ones. return ((a != a) | (a >= b)).all().item() - - elif order == "ascending": - + elif order == 'ascending': def check_order(a, b): # see above return ((b != b) | (a <= b)).all().item() - else: - error( # noqa: F821 - f'unknown order "{order}", must be "ascending" or "descending"' - ) + error(f'unknown order "{order}", must be "ascending" or "descending"') # noqa: F821 are_ordered = True for k in range(1, SIZE): - self.assertTrue( - check_order(mxx[:, k - 1], mxx[:, k]), - f"torch.sort ({order}) values unordered for {task}", - ) + self.assertTrue(check_order(mxx[:, k - 1], mxx[:, k]), + f'torch.sort ({order}) values unordered for {task}') seen = set() indicesCorrect = True @@ -74,11 +50,8 @@ def check_order(a, b): for k in range(size0): seen.clear() for j in range(size): - self.assertEqual( - x[k][ixx[k][j]], - mxx[k][j], - msg=f"torch.sort ({order}) indices wrong for {task}", - ) + self.assertEqual(x[k][ixx[k][j]], mxx[k][j], + msg=f'torch.sort ({order}) indices wrong for {task}') seen.add(ixx[k][j]) self.assertEqual(len(seen), size) @@ -106,22 +79,19 @@ def test_sort(self, device): self.assertEqual(x.argsort(), res1ind) # Test sorting of random numbers - self.assertIsOrdered("ascending", x, res2val, res2ind, "random") + self.assertIsOrdered('ascending', x, res2val, res2ind, 'random') # Test simple sort self.assertEqual( torch.sort(torch.tensor((50, 40, 30, 20, 10), device=device))[0], torch.tensor((10, 20, 30, 40, 50), device=device), - atol=0, - rtol=0, + atol=0, rtol=0 ) # Test that we still have proper sorting with duplicate keys x = torch.floor(torch.rand(4, SIZE, device=device) * 10) torch.sort(x, out=(res2val, res2ind)) - self.assertIsOrdered( - "ascending", x, res2val, res2ind, "random with duplicate keys" - ) + self.assertIsOrdered('ascending', x, res2val, res2ind, 'random with duplicate keys') # DESCENDING SORT x = torch.rand(4, SIZE, device=device) @@ -137,41 +107,35 @@ def test_sort(self, device): self.assertEqual(x.argsort(x.dim() - 1, True), res1ind) # Test sorting of random numbers - self.assertIsOrdered("descending", x, res2val, res2ind, "random") + self.assertIsOrdered('descending', x, res2val, res2ind, 'random') # Test simple sort task self.assertEqual( - torch.sort(torch.tensor((10, 20, 30, 40, 50), device=device), 0, True)[ - 0 - ], + torch.sort(torch.tensor((10, 20, 30, 40, 50), device=device), 0, True)[0], torch.tensor((50, 40, 30, 20, 10), device=device), - atol=0, - rtol=0, + atol=0, rtol=0 ) # Test that we still have proper sorting with duplicate keys - self.assertIsOrdered( - "descending", x, res2val, res2ind, "random with duplicate keys" - ) + self.assertIsOrdered('descending', x, res2val, res2ind, 'random with duplicate keys') # Test argument sorting with and without stable x = torch.tensor([1, 10, 2, 2, 3, 7, 7, 8, 9, 9] * 3) - self.assertEqual( - torch.argsort(x, stable=True), torch.sort(x, stable=True).indices - ) - self.assertEqual( - torch.argsort(x, stable=False), torch.sort(x, stable=False).indices - ) + self.assertEqual(torch.argsort(x, stable=True), torch.sort(x, stable=True).indices) + self.assertEqual(torch.argsort(x, stable=False), torch.sort(x, stable=False).indices) self.assertEqual(torch.argsort(x), torch.sort(x).indices) + # Test sorting with NaNs x = torch.rand(4, SIZE, device=device) - x[1][2] = float("NaN") - x[3][0] = float("NaN") + x[1][2] = float('NaN') + x[3][0] = float('NaN') torch.sort(x, out=(res2val, res2ind)) - self.assertIsOrdered("ascending", x, res2val, res2ind, "random with NaNs") + self.assertIsOrdered('ascending', x, res2val, res2ind, + 'random with NaNs') torch.sort(x, out=(res2val, res2ind), descending=True) - self.assertIsOrdered("descending", x, res2val, res2ind, "random with NaNs") + self.assertIsOrdered('descending', x, res2val, res2ind, + 'random with NaNs') def test_sort_stable_none(self): # Called sort with stable=None used to trigger an assertion @@ -205,19 +169,19 @@ def test_stable_sort(self, device, dtype): _, idx = x.sort(stable=True) self.assertEqual( idx[:ncopies], - torch.arange(start=0, end=2 * ncopies, step=2, device=device), + torch.arange(start=0, end=2 * ncopies, step=2, device=device) ) self.assertEqual( idx[ncopies:], - torch.arange(start=1, end=2 * ncopies, step=2, device=device), + torch.arange(start=1, end=2 * ncopies, step=2, device=device) ) @onlyCUDA @dtypes(torch.uint8) - @largeTensorTest("200GB") # Unfortunately 80GB A100 is not large enough + @largeTensorTest('200GB') # Unfortunately 80GB A100 is not large enough def test_sort_large(self, device, dtype): t0 = torch.randperm(8192, device=device).to(dtype) - t = t0.view(1, 8192).expand(2**18 + 1, -1).contiguous() + t = t0.view(1, 8192).expand(2 ** 18 + 1, -1).contiguous() v, i = t.sort() del t iv, im = i.var_mean(dim=0) @@ -229,6 +193,7 @@ def test_sort_large(self, device, dtype): self.assertEqual(vm, torch.arange(255, dtype=dtype, device=device)) self.assertEqual(im, t0.sort().indices) + @dtypes(torch.float32) def test_sort_restride(self, device, dtype): # Input: non-contiguous (stride: 5) 3-element array @@ -258,24 +223,14 @@ def _test_sort_discontiguous(self, device, dtype): n = t.size(dim) # assert ordered - self.assertTrue( - ( - r1.values.narrow(dim, 1, n - 1) - >= r1.values.narrow(dim, 0, n - 1) - ).all() - ) + self.assertTrue((r1.values.narrow(dim, 1, n - 1) >= r1.values.narrow(dim, 0, n - 1)).all()) # assert that different segments does not mix, which can easily happen # if the stride is not handled correctly - self.assertTrue( - (t.unsqueeze(-1).transpose(dim, -1) == r1.values.unsqueeze(-1)) - .any(dim=dim) - .any(dim=-1) - .all() - ) + self.assertTrue((t.unsqueeze(-1).transpose(dim, -1) == r1.values.unsqueeze(-1)).any(dim=dim).any(dim=-1).all()) # assert stride is preserved - if self.device_type == "cuda": + if self.device_type == 'cuda': # FIXME: this behavior should be true for all cases, not # just the one specified in if condition self.assertEqual(r1.values.stride(), t.stride()) @@ -307,9 +262,7 @@ def test_sort_1d_output_discontiguous(self, device, dtype): @dtypes(*integral_types()) def test_sort_1d_parallel(self, device, dtype): low = 0 if dtype == torch.uint8 else -128 - tensor = torch.randint( - low=low, high=127, size=(100000,), device=device, dtype=dtype - ) + tensor = torch.randint(low=low, high=127, size=(100000, ), device=device, dtype=dtype) vals, _ = torch.sort(tensor, stable=True) self.assertEqual(True, torch.all(vals[:-1] <= vals[1:])) @@ -330,9 +283,9 @@ def test_topk_1d_output_discontiguous(self, device, dtype): @dtypes(*all_types_and(torch.half, torch.bfloat16)) def test_stable_sort_against_numpy(self, device, dtype): if dtype in floating_types_and(torch.float16, torch.bfloat16): - inf = float("inf") - neg_inf = -float("inf") - nan = float("nan") + inf = float('inf') + neg_inf = -float('inf') + nan = float('nan') else: if dtype != torch.bool: # no torch.iinfo support for torch.bool @@ -352,7 +305,7 @@ def generate_samples(): # binary strings yield (torch.tensor([0, 1] * size, dtype=dtype, device=device), 0) - if self.device_type == "cuda": + if self.device_type == 'cuda': return yield (torch.tensor([0, 1] * 100, dtype=dtype, device=device), 0) @@ -373,21 +326,13 @@ def repeated_index_fill(t, dim, idxs, vals): # for each dimension. n_fill_vals = 3 # cardinality of (inf, neg_inf, nan) for dim in range(len(sizes)): - idxs = ( - torch.randint(high=size, size=(size // 10,)) - for i in range(n_fill_vals) - ) + idxs = (torch.randint(high=size, size=(size // 10,)) for i in range(n_fill_vals)) vals = (inf, neg_inf, nan) - subsets = chain.from_iterable( - combinations(list(zip(idxs, vals)), r) - for r in range(1, n_fill_vals + 1) - ) + subsets = chain.from_iterable(combinations(list(zip(idxs, vals)), r) + for r in range(1, n_fill_vals + 1)) for subset in subsets: idxs_subset, vals_subset = zip(*subset) - yield ( - repeated_index_fill(x, dim, idxs_subset, vals_subset), - dim, - ) + yield (repeated_index_fill(x, dim, idxs_subset, vals_subset), dim) for sample, dim in generate_samples(): _, idx_torch = sample.sort(dim=dim, stable=True) @@ -395,7 +340,7 @@ def repeated_index_fill(t, dim, idxs, vals): sample_numpy = sample.float().cpu().numpy() else: sample_numpy = sample.cpu().numpy() - idx_numpy = np.argsort(sample_numpy, axis=dim, kind="stable") + idx_numpy = np.argsort(sample_numpy, axis=dim, kind='stable') self.assertEqual(idx_torch, idx_numpy) @dtypes(*all_types_and(torch.half, torch.bfloat16)) @@ -404,9 +349,7 @@ def test(shape): tensor = make_tensor(shape, dtype=dtype, device=device, low=-9, high=9) if tensor.size() != torch.Size([]): if dtype is torch.bfloat16: - expected = torch.from_numpy( - np.msort(tensor.float().cpu().numpy()) - ).bfloat16() + expected = torch.from_numpy(np.msort(tensor.float().cpu().numpy())).bfloat16() else: expected = torch.from_numpy(np.msort(tensor.cpu().numpy())) else: @@ -421,15 +364,11 @@ def test(shape): shapes = ( [], - [ - 0, - ], - [ - 20, - ], + [0, ], + [20, ], [1, 20], [30, 30], - [10, 20, 30], + [10, 20, 30] ) for shape in shapes: test(shape) @@ -475,12 +414,9 @@ def compare(t, k, dim, dir): sortKVal, sortKInd = topKViaSort(t, k, dim, dir) compareTensors(t, sortKVal, sortKInd, topKVal, topKInd, dim) - t = torch.rand( - random.randint(1, SIZE), - random.randint(1, SIZE), - random.randint(1, SIZE), - device=device, - ) + t = torch.rand(random.randint(1, SIZE), + random.randint(1, SIZE), + random.randint(1, SIZE), device=device) for _kTries in range(3): for _dimTries in range(3): @@ -521,94 +457,91 @@ def test_topk_arguments(self, device): self.assertRaises(TypeError, lambda: q.topk(4, True)) def test_unique_dim(self, device): - self.assertFalse(hasattr(torch, "unique_dim")) + self.assertFalse(hasattr(torch, 'unique_dim')) def run_test(device, dtype): - x = torch.tensor( - [ - [[1.0, 1.0], [0.0, 1.0], [2.0, 1.0], [0.0, 1.0]], - [[1.0, 1.0], [0.0, 1.0], [2.0, 1.0], [0.0, 1.0]], - ], - dtype=dtype, - device=device, - ) + x = torch.tensor([[[1., 1.], + [0., 1.], + [2., 1.], + [0., 1.]], + [[1., 1.], + [0., 1.], + [2., 1.], + [0., 1.]]], + dtype=dtype, + device=device) x_empty = torch.empty(5, 0, dtype=dtype, device=device) x_ill_formed_empty = torch.empty(5, 0, 0, dtype=dtype, device=device) - x_ill_formed_empty_another = torch.empty( - 5, 0, 5, dtype=dtype, device=device - ) + x_ill_formed_empty_another = torch.empty(5, 0, 5, dtype=dtype, device=device) if dtype in floating_types_and(torch.float16, torch.bfloat16): - x_nan = torch.tensor( - [float("nan"), 0, 0, float("nan"), float("nan"), 1], - dtype=dtype, - device=device, - ) - expected_unique_dim0 = torch.tensor( - [[[1.0, 1.0], [0.0, 1.0], [2.0, 1.0], [0.0, 1.0]]], - dtype=dtype, - device=device, - ) + x_nan = torch.tensor([float("nan"), 0, 0, float("nan"), float("nan"), 1], dtype=dtype, device=device) + expected_unique_dim0 = torch.tensor([[[1., 1.], + [0., 1.], + [2., 1.], + [0., 1.]]], + dtype=dtype, + device=device) expected_inverse_dim0 = torch.tensor([0, 0]) expected_counts_dim0 = torch.tensor([2]) - expected_unique_dim1 = torch.tensor( - [ - [[0.0, 1.0], [1.0, 1.0], [2.0, 1.0]], - [[0.0, 1.0], [1.0, 1.0], [2.0, 1.0]], - ], - dtype=dtype, - device=device, - ) - expected_unique_dim1_bool = torch.tensor( - [[[False, True], [True, True]], [[False, True], [True, True]]], - dtype=torch.bool, - device=device, - ) + expected_unique_dim1 = torch.tensor([[[0., 1.], + [1., 1.], + [2., 1.]], + [[0., 1.], + [1., 1.], + [2., 1.]]], + dtype=dtype, + device=device) + expected_unique_dim1_bool = torch.tensor([[[False, True], [True, True]], + [[False, True], [True, True]]], + dtype=torch.bool, + device=device) expected_inverse_dim1 = torch.tensor([1, 0, 2, 0]) expected_inverse_dim1_bool = torch.tensor([1, 0, 1, 0]) expected_counts_dim1 = torch.tensor([2, 1, 1]) expected_counts_dim1_bool = torch.tensor([2, 2]) - expected_unique_dim2 = torch.tensor( - [ - [[1.0, 1.0], [0.0, 1.0], [2.0, 1.0], [0.0, 1.0]], - [[1.0, 1.0], [0.0, 1.0], [2.0, 1.0], [0.0, 1.0]], - ], - dtype=dtype, - device=device, - ) + expected_unique_dim2 = torch.tensor([[[1., 1.], + [0., 1.], + [2., 1.], + [0., 1.]], + [[1., 1.], + [0., 1.], + [2., 1.], + [0., 1.]]], + dtype=dtype, + device=device) expected_inverse_dim2 = torch.tensor([0, 1]) expected_counts_dim2 = torch.tensor([1, 1]) expected_unique_empty = torch.empty(5, 0, dtype=dtype, device=device) expected_inverse_empty = torch.tensor([], dtype=torch.long, device=device) expected_counts_empty = torch.tensor([], dtype=torch.long, device=device) if dtype in floating_types_and(torch.float16, torch.bfloat16): - expected_unique_nan = torch.tensor( - [float("nan"), 0, float("nan"), float("nan"), 1], - dtype=dtype, - device=device, - ) - expected_inverse_nan = torch.tensor( - [0, 1, 1, 2, 3, 4], dtype=torch.long, device=device - ) - expected_counts_nan = torch.tensor( - [1, 2, 1, 1, 1], dtype=torch.long, device=device - ) + expected_unique_nan = torch.tensor([float("nan"), 0, float("nan"), float("nan"), 1], dtype=dtype, device=device) + expected_inverse_nan = torch.tensor([0, 1, 1, 2, 3, 4], dtype=torch.long, device=device) + expected_counts_nan = torch.tensor([1, 2, 1, 1, 1], dtype=torch.long, device=device) # dim0 x_unique = torch.unique(x, dim=0) self.assertEqual(expected_unique_dim0, x_unique) - x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=0) + x_unique, x_inverse = torch.unique( + x, + return_inverse=True, + dim=0) self.assertEqual(expected_unique_dim0, x_unique) self.assertEqual(expected_inverse_dim0, x_inverse) x_unique, x_counts = torch.unique( - x, return_inverse=False, return_counts=True, dim=0 - ) + x, + return_inverse=False, + return_counts=True, + dim=0) self.assertEqual(expected_unique_dim0, x_unique) self.assertEqual(expected_counts_dim0, x_counts) x_unique, x_inverse, x_counts = torch.unique( - x, return_inverse=True, return_counts=True, dim=0 - ) + x, + return_inverse=True, + return_counts=True, + dim=0) self.assertEqual(expected_unique_dim0, x_unique) self.assertEqual(expected_inverse_dim0, x_inverse) self.assertEqual(expected_counts_dim0, x_counts) @@ -620,7 +553,10 @@ def run_test(device, dtype): else: self.assertEqual(expected_unique_dim1, x_unique) - x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=1) + x_unique, x_inverse = torch.unique( + x, + return_inverse=True, + dim=1) if x.dtype == torch.bool: self.assertEqual(expected_unique_dim1_bool, x_unique) self.assertEqual(expected_inverse_dim1_bool, x_inverse) @@ -629,8 +565,10 @@ def run_test(device, dtype): self.assertEqual(expected_inverse_dim1, x_inverse) x_unique, x_counts = torch.unique( - x, return_inverse=False, return_counts=True, dim=1 - ) + x, + return_inverse=False, + return_counts=True, + dim=1) if x.dtype == torch.bool: self.assertEqual(expected_unique_dim1_bool, x_unique) self.assertEqual(expected_counts_dim1_bool, x_counts) @@ -639,8 +577,10 @@ def run_test(device, dtype): self.assertEqual(expected_counts_dim1, x_counts) x_unique, x_inverse, x_counts = torch.unique( - x, return_inverse=True, return_counts=True, dim=1 - ) + x, + return_inverse=True, + return_counts=True, + dim=1) if x.dtype == torch.bool: self.assertEqual(expected_unique_dim1_bool, x_unique) self.assertEqual(expected_inverse_dim1_bool, x_inverse) @@ -654,27 +594,36 @@ def run_test(device, dtype): x_unique = torch.unique(x, dim=2) self.assertEqual(expected_unique_dim2, x_unique) - x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=2) + x_unique, x_inverse = torch.unique( + x, + return_inverse=True, + dim=2) self.assertEqual(expected_unique_dim2, x_unique) self.assertEqual(expected_inverse_dim2, x_inverse) x_unique, x_counts = torch.unique( - x, return_inverse=False, return_counts=True, dim=2 - ) + x, + return_inverse=False, + return_counts=True, + dim=2) self.assertEqual(expected_unique_dim2, x_unique) self.assertEqual(expected_counts_dim2, x_counts) x_unique, x_inverse, x_counts = torch.unique( - x, return_inverse=True, return_counts=True, dim=2 - ) + x, + return_inverse=True, + return_counts=True, + dim=2) self.assertEqual(expected_unique_dim2, x_unique) self.assertEqual(expected_inverse_dim2, x_inverse) self.assertEqual(expected_counts_dim2, x_counts) # test empty tensor x_unique, x_inverse, x_counts = torch.unique( - x_empty, return_inverse=True, return_counts=True, dim=1 - ) + x_empty, + return_inverse=True, + return_counts=True, + dim=1) self.assertEqual(expected_unique_empty, x_unique) self.assertEqual(expected_inverse_empty, x_inverse) self.assertEqual(expected_counts_empty, x_counts) @@ -682,8 +631,10 @@ def run_test(device, dtype): # test tensor with nan if dtype in floating_types_and(torch.float16, torch.bfloat16): x_unique, x_inverse, x_counts = torch.unique( - x_nan, return_inverse=True, return_counts=True, dim=0 - ) + x_nan, + return_inverse=True, + return_counts=True, + dim=0) self.assertEqual(expected_unique_nan, x_unique) self.assertEqual(expected_inverse_nan, x_inverse) self.assertEqual(expected_counts_nan, x_counts) @@ -692,8 +643,10 @@ def run_test(device, dtype): # Checking for runtime error, as this is the expected behaviour with self.assertRaises(RuntimeError): torch.unique( - x_ill_formed_empty, return_inverse=True, return_counts=True, dim=1 - ) + x_ill_formed_empty, + return_inverse=True, + return_counts=True, + dim=1) # test along dim2 with self.assertRaises(RuntimeError): @@ -701,66 +654,46 @@ def run_test(device, dtype): x_ill_formed_empty_another, return_inverse=True, return_counts=True, - dim=2, - ) + dim=2) # test consecutive version y = torch.tensor( - [ - [0, 1], - [0, 1], - [0, 1], - [1, 2], - [1, 2], - [3, 4], - [0, 1], - [0, 1], - [3, 4], - [1, 2], - ], + [[0, 1], + [0, 1], + [0, 1], + [1, 2], + [1, 2], + [3, 4], + [0, 1], + [0, 1], + [3, 4], + [1, 2]], dtype=dtype, - device=device, + device=device ) # test tensor with nan if dtype in floating_types_and(torch.float16, torch.bfloat16): - y_nan = torch.tensor( - [float("nan"), 0, 0, float("nan"), float("nan"), 1], - dtype=dtype, - device=device, - ) + y_nan = torch.tensor([float("nan"), 0, 0, float("nan"), float("nan"), 1], dtype=dtype, device=device) expected_y_unique = torch.tensor( - [[0, 1], [1, 2], [3, 4], [0, 1], [3, 4], [1, 2]], + [[0, 1], + [1, 2], + [3, 4], + [0, 1], + [3, 4], + [1, 2]], dtype=dtype, - device=device, - ) - expected_y_inverse = torch.tensor( - [0, 0, 0, 1, 1, 2, 3, 3, 4, 5], dtype=torch.int64, device=device - ) - expected_y_counts = torch.tensor( - [3, 2, 1, 2, 1, 1], dtype=torch.int64, device=device - ) - expected_y_inverse_bool = torch.tensor( - [0, 0, 0, 1, 1, 1, 2, 2, 3, 3], dtype=torch.int64, device=device - ) - expected_y_counts_bool = torch.tensor( - [3, 3, 2, 2], dtype=torch.int64, device=device + device=device ) + expected_y_inverse = torch.tensor([0, 0, 0, 1, 1, 2, 3, 3, 4, 5], dtype=torch.int64, device=device) + expected_y_counts = torch.tensor([3, 2, 1, 2, 1, 1], dtype=torch.int64, device=device) + expected_y_inverse_bool = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 3, 3], dtype=torch.int64, device=device) + expected_y_counts_bool = torch.tensor([3, 3, 2, 2], dtype=torch.int64, device=device) if dtype in floating_types_and(torch.float16, torch.bfloat16): - expected_y_unique_nan = torch.tensor( - [float("nan"), 0, float("nan"), float("nan"), 1], - dtype=dtype, - device=device, - ) - expected_y_inverse_nan = torch.tensor( - [0, 1, 1, 2, 3, 4], dtype=torch.long, device=device - ) - expected_y_counts_nan = torch.tensor( - [1, 2, 1, 1, 1], dtype=torch.long, device=device - ) - - y_unique, y_inverse, y_counts = torch.unique_consecutive( - y, return_inverse=True, return_counts=True, dim=0 - ) + expected_y_unique_nan = torch.tensor([float("nan"), 0, float("nan"), float("nan"), 1], dtype=dtype, device=device) + expected_y_inverse_nan = torch.tensor([0, 1, 1, 2, 3, 4], dtype=torch.long, device=device) + expected_y_counts_nan = torch.tensor([1, 2, 1, 1, 1], dtype=torch.long, device=device) + + y_unique, y_inverse, y_counts = torch.unique_consecutive(y, return_inverse=True, return_counts=True, dim=0) if x.dtype == torch.bool: self.assertEqual(expected_y_inverse_bool, y_inverse) self.assertEqual(expected_y_counts_bool, y_counts) @@ -771,27 +704,23 @@ def run_test(device, dtype): # test tensor with nan if dtype in floating_types_and(torch.float16, torch.bfloat16): y_unique, y_inverse, y_counts = torch.unique_consecutive( - y_nan, return_inverse=True, return_counts=True, dim=0 - ) + y_nan, + return_inverse=True, + return_counts=True, + dim=0) self.assertEqual(expected_y_unique_nan, y_unique) self.assertEqual(expected_y_inverse_nan, y_inverse) self.assertEqual(expected_y_counts_nan, y_counts) # Test dim is sorted same as NumPy with dims >= 3 - x = torch.tensor( - [ - [ - [[1, 0, 1, 0, 1, 1], [0, 1, 1, 0, 1, 1]], - [[0, 1, 1, 0, 0, 1], [0, 0, 0, 1, 0, 0]], - ], - [ - [[0, 1, 0, 1, 1, 1], [0, 1, 1, 0, 1, 1]], - [[0, 0, 1, 1, 0, 1], [1, 1, 0, 0, 0, 0]], - ], - ], - dtype=dtype, - device=device, - ) + x = torch.tensor([[[[1, 0, 1, 0, 1, 1], + [0, 1, 1, 0, 1, 1]], + [[0, 1, 1, 0, 0, 1], + [0, 0, 0, 1, 0, 0]]], + [[[0, 1, 0, 1, 1, 1], + [0, 1, 1, 0, 1, 1]], + [[0, 0, 1, 1, 0, 1], + [1, 1, 0, 0, 0, 0]]]], dtype=dtype, device=device) xn = x.cpu().numpy() for d in range(x.dim()): t = torch.unique(x, dim=d) @@ -821,20 +750,15 @@ def test_topk_noncontiguous_gpu(self, device): def _test_topk_dtype(self, device, dtype, integral, size): if integral: - a = torch.randint( - torch.iinfo(dtype).min, - torch.iinfo(dtype).max, - size=(size,), - dtype=dtype, - device=device, - ) + a = torch.randint(torch.iinfo(dtype).min, torch.iinfo(dtype).max, + size=(size,), dtype=dtype, device=device) else: a = torch.randn(size=(size,), dtype=dtype, device=device) - sort_topk = a.sort()[0][-(size // 2) :].flip(0) + sort_topk = a.sort()[0][-(size // 2):].flip(0) topk = a.topk(size // 2) - self.assertEqual(sort_topk, topk[0]) # check values - self.assertEqual(sort_topk, a[topk[1]]) # check indices + self.assertEqual(sort_topk, topk[0]) # check values + self.assertEqual(sort_topk, a[topk[1]]) # check indices @dtypes(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64) def test_topk_integral(self, device, dtype): @@ -846,6 +770,7 @@ def test_topk_integral(self, device, dtype): @dtypes(torch.bfloat16, torch.half) def test_topk_lower_precision(self, device, dtype): + small = 10 large = 4096 verylarge = 8192 # multi_block topk on cuda @@ -855,20 +780,14 @@ def test_topk_lower_precision(self, device, dtype): @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16)) @dtypes(torch.float, torch.double, torch.bfloat16, torch.half) def test_topk_nonfinite(self, device, dtype): - x = torch.tensor( - [float("nan"), float("inf"), 1e4, 0, -1e4, -float("inf")], - device=device, - dtype=dtype, - ) + x = torch.tensor([float('nan'), float('inf'), 1e4, 0, -1e4, -float('inf')], device=device, dtype=dtype) val, idx = x.topk(4) - expect = torch.tensor( - [float("nan"), float("inf"), 1e4, 0], device=device, dtype=dtype - ) + expect = torch.tensor([float('nan'), float('inf'), 1e4, 0], device=device, dtype=dtype) self.assertEqual(val, expect) self.assertEqual(idx, [0, 1, 2, 3]) val, idx = x.topk(4, largest=False) - expect = torch.tensor([-float("inf"), -1e4, 0, 1e4], device=device, dtype=dtype) + expect = torch.tensor([-float('inf'), -1e4, 0, 1e4], device=device, dtype=dtype) self.assertEqual(val, expect) self.assertEqual(idx, [5, 4, 3, 2]) @@ -877,13 +796,13 @@ def test_topk_4d(self, device): large = 8192 for size in (small, large): x = torch.ones(2, size, 2, 2, device=device) - x[:, 1, :, :] *= 2.0 + x[:, 1, :, :] *= 2. x[:, 10, :, :] *= 1.5 val, ind = torch.topk(x, k=2, dim=1) expected_ind = torch.ones(2, 2, 2, 2, dtype=torch.long, device=device) expected_ind[:, 1, :, :] = 10 expected_val = torch.ones(2, 2, 2, 2, device=device) - expected_val[:, 0, :, :] *= 2.0 + expected_val[:, 0, :, :] *= 2. expected_val[:, 1, :, :] *= 1.5 self.assertEqual(val, expected_val, atol=0, rtol=0) self.assertEqual(ind, expected_ind, atol=0, rtol=0) @@ -919,17 +838,7 @@ def _test_unique_scalar_empty(self, dtype, device, f): self.assertEqual(inverse, expected_inverse) self.assertEqual(counts, expected_counts) - def _test_unique_with_expects( - self, - device, - dtype, - f, - x, - expected_unique, - expected_inverse, - expected_counts, - additional_shape, - ): + def _test_unique_with_expects(self, device, dtype, f, x, expected_unique, expected_inverse, expected_counts, additional_shape): def ensure_tuple(x): if isinstance(x, torch.Tensor): return (x,) @@ -938,9 +847,7 @@ def ensure_tuple(x): for return_inverse in [True, False]: for return_counts in [True, False]: # test with expected - ret = ensure_tuple( - f(x, return_inverse=return_inverse, return_counts=return_counts) - ) + ret = ensure_tuple(f(x, return_inverse=return_inverse, return_counts=return_counts)) self.assertEqual(len(ret), 1 + int(return_inverse) + int(return_counts)) self.assertEqual(expected_unique, ret[0]) if return_inverse: @@ -951,9 +858,7 @@ def ensure_tuple(x): # tests per-element unique on a higher rank tensor. y = x.view(additional_shape) - y_unique, y_inverse, y_counts = f( - y, return_inverse=True, return_counts=True - ) + y_unique, y_inverse, y_counts = f(y, return_inverse=True, return_counts=True) self.assertEqual(expected_unique, y_unique) self.assertEqual(expected_inverse.view(additional_shape), y_inverse) self.assertEqual(expected_counts, y_counts) @@ -967,17 +872,9 @@ def ensure_tuple(x): return x if dtype is torch.bool: - x = torch.tensor( - [True, False, False, False, True, False, True, False], - dtype=torch.bool, - device=device, - ) - expected_unique = torch.tensor( - [False, True], dtype=torch.bool, device=device - ) - expected_inverse = torch.tensor( - [1, 0, 0, 0, 1, 0, 1, 0], dtype=torch.long, device=device - ) + x = torch.tensor([True, False, False, False, True, False, True, False], dtype=torch.bool, device=device) + expected_unique = torch.tensor([False, True], dtype=torch.bool, device=device) + expected_inverse = torch.tensor([1, 0, 0, 0, 1, 0, 1, 0], dtype=torch.long, device=device) expected_counts = torch.tensor([5, 3], dtype=torch.long, device=device) else: x = torch.tensor([1, 2, 3, 2, 8, 5, 2, 3], dtype=dtype, device=device) @@ -993,29 +890,18 @@ def ensure_tuple(x): x_sliced = torch.empty(x.size(0) * 2, dtype=dtype, device=device)[::2].copy_(x) xs = (x, x_sliced) for f, x in product(fs, xs): - self._test_unique_with_expects( - device, - dtype, - f, - x, - expected_unique, - expected_inverse, - expected_counts, - (2, 2, 2), - ) + self._test_unique_with_expects(device, dtype, f, x, expected_unique, expected_inverse, expected_counts, (2, 2, 2)) self._test_unique_scalar_empty(dtype, device, f) # test unsorted unique fs = ( lambda x, **kwargs: torch.unique(x, sorted=False, **kwargs), - lambda x, **kwargs: x.unique(sorted=False, **kwargs), + lambda x, **kwargs: x.unique(sorted=False, **kwargs) ) for f, x in product(fs, xs): self._test_unique_scalar_empty(dtype, device, f) for return_inverse, return_counts in product((True, False), repeat=2): - ret = ensure_tuple( - f(x, return_inverse=return_inverse, return_counts=return_counts) - ) + ret = ensure_tuple(f(x, return_inverse=return_inverse, return_counts=return_counts)) self.assertEqual(len(ret), 1 + int(return_inverse) + int(return_counts)) x_list = x.tolist() x_unique_list = ret[0].tolist() @@ -1038,40 +924,18 @@ def ensure_tuple(x): @dtypes(*all_types_and(torch.half, torch.bool)) def test_unique_consecutive(self, device, dtype): if dtype is torch.bool: - x = torch.tensor( - [True, False, False, False, True, True, False, False, False], - dtype=torch.bool, - device=device, - ) - expected_unique = torch.tensor( - [True, False, True, False], dtype=torch.bool, device=device - ) - expected_inverse = torch.tensor( - [0, 1, 1, 1, 2, 2, 3, 3, 3], dtype=torch.long, device=device - ) - expected_counts = torch.tensor( - [1, 3, 2, 3], dtype=torch.long, device=device - ) + x = torch.tensor([True, False, False, False, True, True, False, False, False], dtype=torch.bool, device=device) + expected_unique = torch.tensor([True, False, True, False], dtype=torch.bool, device=device) + expected_inverse = torch.tensor([0, 1, 1, 1, 2, 2, 3, 3, 3], dtype=torch.long, device=device) + expected_counts = torch.tensor([1, 3, 2, 3], dtype=torch.long, device=device) else: x = torch.tensor([1, 2, 2, 2, 5, 5, 2, 2, 3], dtype=dtype, device=device) expected_unique = torch.tensor([1, 2, 5, 2, 3], dtype=dtype, device=device) expected_inverse = torch.tensor([0, 1, 1, 1, 2, 2, 3, 3, 4], device=device) expected_counts = torch.tensor([1, 3, 2, 2, 1], device=device) - for f in [ - torch.unique_consecutive, - lambda x, **kwargs: x.unique_consecutive(**kwargs), - ]: - self._test_unique_with_expects( - device, - dtype, - f, - x, - expected_unique, - expected_inverse, - expected_counts, - (3, 3), - ) + for f in [torch.unique_consecutive, lambda x, **kwargs: x.unique_consecutive(**kwargs)]: + self._test_unique_with_expects(device, dtype, f, x, expected_unique, expected_inverse, expected_counts, (3, 3)) self._test_unique_scalar_empty(dtype, device, f) @dtypes(torch.double) @@ -1127,7 +991,7 @@ def test_kthvalue(self, device, dtype): self.assertEqual(x, x0, atol=0, rtol=0) # simple test case (with repetitions) - y = torch.tensor((3.0, 5, 4, 1, 1, 5), dtype=dtype, device=device) + y = torch.tensor((3., 5, 4, 1, 1, 5), dtype=dtype, device=device) self.assertEqual(torch.kthvalue(y, 3)[0], 3, atol=0, rtol=0) self.assertEqual(torch.kthvalue(y, 2)[0], 1, atol=0, rtol=0) @@ -1143,7 +1007,7 @@ def test_kthvalue(self, device, dtype): self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], atol=0, rtol=0) @dtypes(torch.float) - @onlyNativeDeviceTypes # Fails on XLA + @onlyNativeDeviceTypes # Fails on XLA def test_kthvalue_scalar(self, device, dtype): # Test scalar input (test case from https://github.com/pytorch/pytorch/issues/30818) # Tests that passing a scalar tensor or 1D tensor with 1 element work either way @@ -1165,9 +1029,7 @@ def assert_isin_equal(a, b): # multi-dim tensor, multi-dim tensor a = torch.arange(24, device=device, dtype=dtype).reshape([2, 3, 4]) - b = torch.tensor( - [[10, 20, 30], [0, 1, 3], [11, 22, 33]], device=device, dtype=dtype - ) + b = torch.tensor([[10, 20, 30], [0, 1, 3], [11, 22, 33]], device=device, dtype=dtype) assert_isin_equal(a, b) # zero-dim tensor @@ -1211,56 +1073,16 @@ def define_expected(lst, invert=False): c = torch.isin(a, b, assume_unique=True, invert=invert) self.assertEqual(c, ec) - a = torch.tensor( - [5, 4, 5, 3, 4, 4, 3, 4, 3, 5, 2, 1, 5, 5], - device=device, - dtype=dtype, - ) + a = torch.tensor([5, 4, 5, 3, 4, 4, 3, 4, 3, 5, 2, 1, 5, 5], device=device, dtype=dtype) b = torch.tensor([2, 3, 4] * mult, device=device, dtype=dtype) - ec = define_expected( - [ - False, - True, - False, - True, - True, - True, - True, - True, - True, - False, - True, - False, - False, - False, - ], - invert=invert, - ) + ec = define_expected([False, True, False, True, True, True, True, True, True, + False, True, False, False, False], invert=invert) c = torch.isin(a, b, invert=invert) self.assertEqual(c, ec) - b = torch.tensor( - [2, 3, 4] * mult + [5, 5, 4] * mult, device=device, dtype=dtype - ) - ec = define_expected( - [ - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - False, - True, - True, - ], - invert=invert, - ) + b = torch.tensor([2, 3, 4] * mult + [5, 5, 4] * mult, device=device, dtype=dtype) + ec = define_expected([True, True, True, True, True, True, True, True, True, True, + True, False, True, True], invert=invert) c = torch.isin(a, b, invert=invert) self.assertEqual(c, ec) @@ -1286,14 +1108,12 @@ def define_expected(lst, invert=False): for assume_unique in [False, True]: a = torch.arange(6, device=device, dtype=dtype).reshape([2, 3]) b = torch.arange(3, 30, device=device, dtype=dtype) - ec = define_expected( - [[False, False, False], [True, True, True]], invert=invert - ) + ec = define_expected([[False, False, False], [True, True, True]], invert=invert) c = torch.isin(a, b, invert=invert, assume_unique=assume_unique) self.assertEqual(c, ec) def test_isin_different_dtypes(self, device): - supported_types = all_types() if device == "cpu" else all_types_and(torch.half) + supported_types = all_types() if device == 'cpu' else all_types_and(torch.half) for mult in [1, 10]: for assume_unique in [False, True]: for dtype1, dtype2 in product(supported_types, supported_types): @@ -1307,18 +1127,18 @@ def test_isin_different_dtypes(self, device): @dtypes(*all_types()) def test_isin_different_devices(self, device, dtype): a = torch.arange(6, device=device, dtype=dtype).reshape([2, 3]) - b = torch.arange(3, 30, device="cpu", dtype=dtype) + b = torch.arange(3, 30, device='cpu', dtype=dtype) with self.assertRaises(RuntimeError): torch.isin(a, b) - c = torch.arange(6, device="cpu", dtype=dtype).reshape([2, 3]) + c = torch.arange(6, device='cpu', dtype=dtype).reshape([2, 3]) d = torch.arange(3, 30, device=device, dtype=dtype) with self.assertRaises(RuntimeError): torch.isin(c, d) @dtypes(*integral_types()) def test_sort_overflow(self, device, dtype): - "Regression test for https://github.com/pytorch/pytorch/issues/111189" + " Regression test for https://github.com/pytorch/pytorch/issues/111189 " prev_num_threads = torch.get_num_threads() try: low = 0 if dtype == torch.uint8 else -1 @@ -1333,5 +1153,5 @@ def test_sort_overflow(self, device, dtype): instantiate_device_type_tests(TestSortAndSelect, globals()) -if __name__ == "__main__": +if __name__ == '__main__': run_tests() From 3b88c27c46923e2beaeb7d1e75e20383c1926963 Mon Sep 17 00:00:00 2001 From: Huy Do Date: Thu, 30 May 2024 22:40:48 +0000 Subject: [PATCH 126/706] Mark DynamicShapesExportTests::test_retracibility_dict_container_inp_out as slow (#127558) Same as https://github.com/pytorch/pytorch/pull/117896, another slowpoke `DynamicShapesExportTests::test_retracibility_dict_container_inp_out` shows up on recently on MacOS. For example, https://ossci-raw-job-status.s3.amazonaws.com/log/25585713394 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127558 Approved by: https://github.com/clee2000 --- test/dynamo/test_dynamic_shapes.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/dynamo/test_dynamic_shapes.py b/test/dynamo/test_dynamic_shapes.py index df86179657ca..175ed573391b 100644 --- a/test/dynamo/test_dynamic_shapes.py +++ b/test/dynamo/test_dynamic_shapes.py @@ -100,6 +100,10 @@ def make_dynamic_cls(cls): DynamicShapesExportTests.test_retracibility_dynamic_shapes = slowTest( # noqa: F821 DynamicShapesExportTests.test_retracibility_dynamic_shapes # noqa: F821 ) +# Also take more than 30m as of 15cc9f2e7e7b2b175f24755925dc38d4d430905d +DynamicShapesExportTests.test_retracibility_dict_container_inp_out_dynamic_shapes = slowTest( # noqa: F821 + DynamicShapesExportTests.test_retracibility_dict_container_inp_out_dynamic_shapes # noqa: F821 +) if __name__ == "__main__": from torch._dynamo.test_case import run_tests From 029af29e6d2a91bc8a6e15d445fec42a49c0454e Mon Sep 17 00:00:00 2001 From: laithsakka Date: Wed, 29 May 2024 17:00:01 -0700 Subject: [PATCH 127/706] support operator.index function (#127440) Fix https://github.com/pytorch/pytorch/issues/127426 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127440 Approved by: https://github.com/mlazos ghstack dependencies: #126444, #127146, #127424 --- test/dynamo/test_functions.py | 18 ++++++++++++++++++ torch/_dynamo/variables/builtin.py | 8 ++++++++ 2 files changed, 26 insertions(+) diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 20b9fadcf015..e2baebf60321 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -2312,6 +2312,24 @@ def test(x, y): test(True, False) test(torch.ones(4, dtype=torch.float32), 1.1) + def test_index(self): + def fn(x, t): + v = operator.index(x) + torch.mul(t, v) + + def test(a, b): + self.assertEqual(opt_fn(a, b), fn(a, b)) + + for dynamic in [True, False]: + torch._dynamo.reset() + opt_fn = torch._dynamo.optimize(dynamic=dynamic)(fn) + t = torch.ones(1) + test(10, t) + test(-100, t) + test(10, t) + test(False, t) + test(True, t) + def test_truth(self): def fn(x, y): return operator.truth(x) and bool(y) diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 306a17c018f9..a9f6272d0571 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1129,6 +1129,14 @@ def call_pos(self, tx, arg: "VariableTracker"): ) return pos_method.call_function(tx, [], {}) + def call_index(self, tx, arg: "VariableTracker"): + if isinstance(arg, variables.TensorVariable): + unimplemented("unsupported index(tensor)") + + arg = guard_if_dyn(arg) + constant_value = operator.index(arg) + return variables.ConstantVariable.create(constant_value) + def call_round(self, tx, arg, *args, **kwargs): # Call arg.__round__() round_method = BuiltinVariable(getattr).call_function( From f9a1bc2c65b6d199e19577d1c75677d2084ed086 Mon Sep 17 00:00:00 2001 From: Rohan Varma Date: Mon, 22 Apr 2024 16:21:13 -0700 Subject: [PATCH 128/706] [FSDP] Remove _sync_module_states (#124678) Remove this unused API Differential Revision: [D56445639](https://our.internmc.facebook.com/intern/diff/D56445639/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/124678 Approved by: https://github.com/awgu --- torch/distributed/fsdp/_init_utils.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/torch/distributed/fsdp/_init_utils.py b/torch/distributed/fsdp/_init_utils.py index 5b811f50a032..295400220b4f 100644 --- a/torch/distributed/fsdp/_init_utils.py +++ b/torch/distributed/fsdp/_init_utils.py @@ -1099,24 +1099,6 @@ def _sync_module_params_and_buffers( ) -def _sync_module_states( - params: List[nn.Parameter], - buffers: List[torch.Tensor], - process_group: dist.ProcessGroup, -) -> None: - # Assumes that each call to this method passes in disjoint `params` and - # and `buffers` across calls, so there is no chance of re-synchronizing - params_and_buffers = [param.detach() for param in params] + [ - buffer.detach() for buffer in buffers - ] - _check_module_states_for_sync_module_states(params_and_buffers) - _sync_params_and_buffers( - process_group, - params_and_buffers, - PARAM_BROADCAST_BUCKET_SIZE, - src=0, - ) - def _check_module_states_for_sync_module_states( module_states: List[torch.Tensor], From a3c00e43319f6d3d448bfe893cf2a79c4bd6ddda Mon Sep 17 00:00:00 2001 From: eellison Date: Wed, 29 May 2024 17:56:58 -0700 Subject: [PATCH 129/706] [Easy] Move V.fake_mode inside of replace_by_example (#127494) Was writing docs and saw that we always have this duplicated usage. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127494 Approved by: https://github.com/shunting314, https://github.com/aorenste --- torch/_inductor/fx_passes/joint_graph.py | 7 ++---- torch/_inductor/fx_passes/post_grad.py | 12 ++++------ torch/_inductor/pattern_matcher.py | 30 +++++++++++++++--------- 3 files changed, 25 insertions(+), 24 deletions(-) diff --git a/torch/_inductor/fx_passes/joint_graph.py b/torch/_inductor/fx_passes/joint_graph.py index 3302dfd63292..bf282ee72ba8 100644 --- a/torch/_inductor/fx_passes/joint_graph.py +++ b/torch/_inductor/fx_passes/joint_graph.py @@ -7,7 +7,6 @@ import torch import torch._guards from torch._inductor.constant_folding import ConstantFolder -from torch._inductor.virtualized import V from torch.fx.experimental.symbolic_shapes import statically_known_true from torch.multiprocessing.reductions import StorageWeakRef @@ -463,8 +462,7 @@ def repl(inp, other): max_ = torch.amax(inp, dim=dim, keepdim=keepdim) return (inp - max_) * (sign * other) - with V.fake_mode: - match.replace_by_example(repl, [inp, other]) + match.replace_by_example(repl, [inp, other]) for reverse, to_dtype in itertools.product((False, True), repeat=2): @@ -491,8 +489,7 @@ def repl(inp, other): max_ = torch.amax(inp, dim=dim, keepdim=keepdim) return (inp - max_) / (sign * other) - with V.fake_mode: - match.replace_by_example(repl, [inp, other]) + match.replace_by_example(repl, [inp, other]) for to_dtype in (False, True): diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index 585d261787e4..b18577a02ffc 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -347,8 +347,7 @@ def repl(*shape): # only replace the output node, not all nodes match.nodes = [match.output_node()] - with V.fake_mode: - match.replace_by_example(repl, list(shape)) + match.replace_by_example(repl, list(shape)) def shape_of_mm(a, b): @@ -708,8 +707,7 @@ def decomp(*flat_args): args, kwargs = pytree.tree_unflatten(flat_args, spec) return auto_functionalized_dense(*args, only_clone_these_tensors, **kwargs) - with V.fake_mode: - match.replace_by_example(decomp, flat_args, run_dce=False) + match.replace_by_example(decomp, flat_args, run_dce=False) graph_pass.apply(graph) for node in graph.find_nodes( @@ -825,8 +823,7 @@ def unfuse_bias_add_to_pointwise(match: Match, mat1, mat2, *, inp): def repl(inp, x1, x2): return x1 @ x2 + inp - with V.fake_mode: - match.replace_by_example(repl, [inp, mat1, mat2]) + match.replace_by_example(repl, [inp, mat1, mat2]) def is_valid_addmm_fusion(match): @@ -869,8 +866,7 @@ def addmm(match, mat1, mat2, *, inp): def repl(inp, mat1, mat2): return aten.addmm(inp, mat1, mat2) - with V.fake_mode: - match.replace_by_example(repl, [inp, mat1, mat2]) + match.replace_by_example(repl, [inp, mat1, mat2]) def check_shape_cuda_and_fused_int_mm_mul_enabled(match): diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py index b9b66874aba3..e91873ea933a 100644 --- a/torch/_inductor/pattern_matcher.py +++ b/torch/_inductor/pattern_matcher.py @@ -1,5 +1,7 @@ from __future__ import annotations +import contextlib + import dataclasses import functools import importlib @@ -133,17 +135,23 @@ def replace_with_graph(self, replacement_graph, args): def replace_by_example(self, replacement_fn, args, trace_fn=None, run_dce=True): assert self.ctx - if trace_fn is None: - trace_fn = functools.partial(fwd_only, run_dce=run_dce) - replacement = trace_fn( - replacement_fn, torch.fx.map_arg(args, lambda arg: arg.meta["val"]) - ) - ReplacementPatternEntry.replace_with_graph( - self, - self.ctx.graph, - replacement, - args, - ) + + from torch._inductor.virtualized import V + + context = V.fake_mode if V.fake_mode is not None else contextlib.nullcontext + + with context: + if trace_fn is None: + trace_fn = functools.partial(fwd_only, run_dce=run_dce) + replacement = trace_fn( + replacement_fn, torch.fx.map_arg(args, lambda arg: arg.meta["val"]) + ) + ReplacementPatternEntry.replace_with_graph( + self, + self.ctx.graph, + replacement, + args, + ) class FailedMatch(RuntimeError): From 74b89b9283373b2137fcef74508dc9a38c8097c9 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Wed, 29 May 2024 17:11:07 -0700 Subject: [PATCH 130/706] Extract dot-product functions from fp16_gemv_trans gemv kernels (#127435) Summary: Refactoring step before we attempt to use these to implement a less bad fp16 GEMM. Test Plan: Existing tests. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127435 Approved by: https://github.com/malfet --- aten/src/ATen/native/BlasKernel.cpp | 102 +++++++++++++++------------- 1 file changed, 55 insertions(+), 47 deletions(-) diff --git a/aten/src/ATen/native/BlasKernel.cpp b/aten/src/ATen/native/BlasKernel.cpp index fb4289eb989a..df0512f95390 100644 --- a/aten/src/ATen/native/BlasKernel.cpp +++ b/aten/src/ATen/native/BlasKernel.cpp @@ -218,10 +218,9 @@ static inline float16_t reduce(float16x8_t x) { /* * NOTE [ GGML Copyright Notice ] - * The below reduce overload and - * fp16_gemv_trans_fp16_arith_by_dot_products function is adapted from - * llama.cpp's ggml_vec_dot_f16 and surrounding utility functions, so - * here is the required copyright notice: + * The below reduce overload and fp16_dot_with_fp16_arith function is + * adapted from llama.cpp's ggml_vec_dot_f16 and surrounding utility + * functions, so here is the required copyright notice: * * MIT License * @@ -279,29 +278,33 @@ static inline float16x8_t f16_fma(float16x8_t a, float16x8_t b, float16x8_t c) { #endif } +static float fp16_dot_with_fp16_arith(const float16_t* x, const float16_t* a, int len) { + float16x8_t sum[kF16RegistersPerIteration] = {vdupq_n_f16(0)}; + + const auto len_aligned = len & ~(kF16ElementsPerIteration - 1); + for (int j = 0; j < len_aligned ; j += kF16ElementsPerIteration) { + for (int k = 0; k < kF16RegistersPerIteration; ++k) { + const auto temp_x = vld1q_f16(x + j + k * kF16ElementsPerRegister); + const auto temp_a = vld1q_f16(a + j + k * kF16ElementsPerRegister); + sum[k] = f16_fma(sum[k], temp_x, temp_a); + } + } + auto reducedSum = reduce(sum); + + for (int j = len_aligned; j < len; ++j) { + reducedSum += x[j] * a[j]; + } + return reducedSum; +} + // Rather than unrolling to process multiple rows (transposed columns) // of matrix A at once as done in fp16_gemv_trans_fp16_arith, unroll // along an individual dot product. static void fp16_gemv_trans_fp16_arith_by_dot_products(const int m, const int n, const float16_t* a, const int lda, const float16_t *x, float16_t* y, int incy) { parallel_for(0, n, 1, [&](int begin, int end) { - for (int i = begin; i < end; ++i) { - float16x8_t sum[kF16RegistersPerIteration] = {vdupq_n_f16(0)}; - - const auto m_aligned = m & ~(kF16ElementsPerIteration - 1); - for (int j = 0; j < m_aligned ; j += kF16ElementsPerIteration) { - for (int k = 0; k < kF16RegistersPerIteration; ++k) { - const auto temp_x = vld1q_f16(x + j + k * kF16ElementsPerRegister); - const auto temp_a = vld1q_f16(a + lda * i + j + k * kF16ElementsPerRegister); - sum[k] = f16_fma(sum[k], temp_x, temp_a); - } - } - auto reducedSum = reduce(sum); - - for (int j = m_aligned; j < m; ++j) { - reducedSum += x[j] * a[lda * i + j]; - } - y[i * incy] = reducedSum; - } + for (int i = begin; i < end; ++i) { + y[i * incy] = fp16_dot_with_fp16_arith(x, a + lda * i, m); + } }); } @@ -341,10 +344,10 @@ static inline float32x4_t f32_fma_high_f16(float32x4_t a, float16x8_t b, float16 #endif } -// The below reduce overload and -// fp16_gemv_trans_fp32_arith_by_dot_products are adapted from -// llama.cpp's ggml_vec_dot_f32 and surrounding utility functions. See -// NOTE [ GGML Copyright Notice ] above for the required notice. +// The below reduce overload and fp16_dot_with_fp32_arith are adapted +// from llama.cpp's ggml_vec_dot_f32 and surrounding utility +// functions. See NOTE [ GGML Copyright Notice ] above for the +// required notice. // We need the shift for reduce(), hence the extra constants. static constexpr auto kF32ElementsPerIterationShift = 5; @@ -372,32 +375,37 @@ static inline double reduce(float32x4_t x[kF32RegistersPerIteration]) { return vaddvq_f32(x[0]); } +static float fp16_dot_with_fp32_arith(const float16_t* x, const float16_t* a, int len) { + float32x4_t sum[kF32RegistersPerIteration] = {vdupq_n_f32(0)}; + const auto len_aligned = len & ~(kF32ElementsPerIteration - 1); + for (int j = 0; j < len_aligned ; j += kF32ElementsPerIteration) { + const auto* x_ = x + j; + const auto* a_ = a + j; + c10::ForcedUnroll{}([x_, a_, &sum](auto k) { + // Load a pair of f32 registers at a time. + const auto temp_x = vld1q_f16(x_ + k * 2 * kF32ElementsPerRegister); + const auto temp_a = vld1q_f16(a_ + k * 2 * kF32ElementsPerRegister); + + sum[2 * k] = f32_fma_low_f16(sum[2 * k], temp_x, temp_a); + sum[2 * k + 1] = f32_fma_high_f16(sum[2 * k + 1], temp_x, temp_a); + }); + } + auto reducedSum = reduce(sum); + + for (int j = len_aligned; j < len; ++j) { + reducedSum += x[j] * a[j]; + } + return reducedSum; +} + // On my Apple M1 Macbook (which is ARM v8.5 and thus has the // instructions f32_fma_{low,high}_f16 is targeting), this kernel has // equivalent performance to the fp16-native kernel. static void fp16_gemv_trans_fp32_arith_by_dot_products(const int m, const int n, const float16_t* a, const int lda, const float16_t *x, float16_t* y, int incy) { parallel_for(0, n, 1, [&](int begin, int end) { - for (int i = begin; i < end; ++i) { - float32x4_t sum[kF32RegistersPerIteration] = {vdupq_n_f32(0)}; - - const auto m_aligned = m & ~(kF32ElementsPerIteration - 1); - for (int j = 0; j < m_aligned ; j += kF32ElementsPerIteration) { - c10::ForcedUnroll{}([x, a, lda, i, j, &sum](auto k) { - // Load a pair of f32 registers at a time. - const auto temp_x = vld1q_f16(x + j + k * 2 * kF32ElementsPerRegister); - const auto temp_a = vld1q_f16(a + lda * i + j + k * 2 * kF32ElementsPerRegister); - - sum[2 * k] = f32_fma_low_f16(sum[2 * k], temp_x, temp_a); - sum[2 * k + 1] = f32_fma_high_f16(sum[2 * k + 1], temp_x, temp_a); - }); - } - auto reducedSum = reduce(sum); - - for (int j = m_aligned; j < m; ++j) { - reducedSum += x[j] * a[lda * i + j]; - } - y[i * incy] = reducedSum; - } + for (int i = begin; i < end; ++i) { + y[i * incy] = fp16_dot_with_fp32_arith(x, a + lda * i, m); + } }); } From 603bde1de376cd242bf365ef361a281809c8e6ab Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Wed, 29 May 2024 17:11:08 -0700 Subject: [PATCH 131/706] Use efficient ARM fp16 dot product for gemm_transa_ general case (#127451) Summary: This doesn't change the overall gemm algorithm away from repeated dot products, just uses our efficient fp16 dot product developed for the gemv case. It seems to improve performance for every prompt length I tested. Test Plan: Use https://github.com/malfet/llm_experiments/blob/main/benchmarks/benchmark_torch_mm.py , edited to test only the trans_b (really gemm_transa_) case for the sizes outlined in the output. Before: ``` Matrix-vector: m=8, n=128, k=1 ==================== trans_b torch.float32 1.05 usec trans_b torch.float16 0.97 usec trans_b torch.bfloat16 1.06 usec m=128, n=8, k=1 ==================== trans_b torch.float32 0.80 usec trans_b torch.float16 0.97 usec trans_b torch.bfloat16 1.00 usec m=4096, n=4096, k=1 ==================== trans_b torch.float32 2160.75 usec trans_b torch.float16 659.77 usec trans_b torch.bfloat16 3800.13 usec m=11008, n=4096, k=1 ==================== trans_b torch.float32 6343.68 usec trans_b torch.float16 1789.42 usec trans_b torch.bfloat16 10098.34 usec m=4096, n=11008, k=1 ==================== trans_b torch.float32 6217.20 usec trans_b torch.float16 1874.47 usec trans_b torch.bfloat16 10490.30 usec m=32000, n=4096, k=1 ==================== trans_b torch.float32 17934.45 usec trans_b torch.float16 5323.81 usec trans_b torch.bfloat16 29320.80 usec Matrix-matrix (prompt len 4: m=8, n=128, k=4 ==================== trans_b torch.float32 2.40 usec trans_b torch.float16 1.22 usec trans_b torch.bfloat16 1.22 usec m=128, n=8, k=4 ==================== trans_b torch.float32 1.52 usec trans_b torch.float16 1.33 usec trans_b torch.bfloat16 1.77 usec m=4096, n=4096, k=4 ==================== trans_b torch.float32 4317.09 usec trans_b torch.float16 15541.04 usec trans_b torch.bfloat16 15032.29 usec m=11008, n=4096, k=4 ==================== trans_b torch.float32 6191.19 usec trans_b torch.float16 40436.29 usec trans_b torch.bfloat16 40626.93 usec m=4096, n=11008, k=4 ==================== trans_b torch.float32 6049.22 usec trans_b torch.float16 42367.16 usec trans_b torch.bfloat16 42482.43 usec m=32000, n=4096, k=4 ==================== trans_b torch.float32 17611.36 usec trans_b torch.float16 117368.54 usec trans_b torch.bfloat16 116958.85 usec Matrix-matrix (prompt len 8: m=8, n=128, k=8 ==================== trans_b torch.float32 1.04 usec trans_b torch.float16 1.71 usec trans_b torch.bfloat16 1.74 usec m=128, n=8, k=8 ==================== trans_b torch.float32 2.10 usec trans_b torch.float16 2.01 usec trans_b torch.bfloat16 2.91 usec m=4096, n=4096, k=8 ==================== trans_b torch.float32 2456.23 usec trans_b torch.float16 30112.76 usec trans_b torch.bfloat16 29941.58 usec m=11008, n=4096, k=8 ==================== trans_b torch.float32 6236.12 usec trans_b torch.float16 80361.22 usec trans_b torch.bfloat16 80466.64 usec m=4096, n=11008, k=8 ==================== trans_b torch.float32 6236.10 usec trans_b torch.float16 82990.74 usec trans_b torch.bfloat16 83899.80 usec m=32000, n=4096, k=8 ==================== trans_b torch.float32 17606.43 usec trans_b torch.float16 234397.38 usec trans_b torch.bfloat16 237057.29 usec Matrix-matrix (prompt len 16: m=8, n=128, k=16 ==================== trans_b torch.float32 1.31 usec trans_b torch.float16 2.67 usec trans_b torch.bfloat16 2.72 usec m=128, n=8, k=16 ==================== trans_b torch.float32 1.66 usec trans_b torch.float16 3.36 usec trans_b torch.bfloat16 5.18 usec m=4096, n=4096, k=16 ==================== trans_b torch.float32 2504.24 usec trans_b torch.float16 60896.53 usec trans_b torch.bfloat16 59852.49 usec m=11008, n=4096, k=16 ==================== trans_b torch.float32 6407.11 usec trans_b torch.float16 163294.92 usec trans_b torch.bfloat16 161199.10 usec m=4096, n=11008, k=16 ==================== trans_b torch.float32 6132.30 usec trans_b torch.float16 167244.77 usec trans_b torch.bfloat16 170064.35 usec m=32000, n=4096, k=16 ==================== trans_b torch.float32 17635.56 usec trans_b torch.float16 475020.00 usec trans_b torch.bfloat16 476332.29 usec Matrix-matrix (prompt len 32: m=8, n=128, k=32 ==================== trans_b torch.float32 1.40 usec trans_b torch.float16 4.67 usec trans_b torch.bfloat16 4.80 usec m=128, n=8, k=32 ==================== trans_b torch.float32 1.24 usec trans_b torch.float16 6.10 usec trans_b torch.bfloat16 10.03 usec m=4096, n=4096, k=32 ==================== trans_b torch.float32 2660.63 usec trans_b torch.float16 122436.04 usec trans_b torch.bfloat16 121687.96 usec m=11008, n=4096, k=32 ==================== trans_b torch.float32 6405.60 usec trans_b torch.float16 324708.42 usec trans_b torch.bfloat16 324866.67 usec m=4096, n=11008, k=32 ==================== trans_b torch.float32 6566.74 usec trans_b torch.float16 330801.04 usec trans_b torch.bfloat16 332561.79 usec m=32000, n=4096, k=32 ==================== trans_b torch.float32 18610.84 usec trans_b torch.float16 944578.75 usec trans_b torch.bfloat16 940674.33 usec Matrix-matrix (prompt len 128: m=8, n=128, k=128 ==================== trans_b torch.float32 2.48 usec trans_b torch.float16 16.43 usec trans_b torch.bfloat16 17.11 usec m=128, n=8, k=128 ==================== trans_b torch.float32 1.83 usec trans_b torch.float16 22.31 usec trans_b torch.bfloat16 37.00 usec m=4096, n=4096, k=128 ==================== trans_b torch.float32 4806.59 usec trans_b torch.float16 485338.83 usec trans_b torch.bfloat16 478835.08 usec m=11008, n=4096, k=128 ==================== trans_b torch.float32 12109.51 usec trans_b torch.float16 1300928.58 usec trans_b torch.bfloat16 1293181.63 usec m=4096, n=11008, k=128 ==================== trans_b torch.float32 11223.70 usec trans_b torch.float16 1326119.92 usec trans_b torch.bfloat16 1330395.12 usec m=32000, n=4096, k=128 ==================== trans_b torch.float32 33485.34 usec trans_b torch.float16 3869227.17 usec trans_b torch.bfloat16 3792905.00 usec ``` After: ``` Matrix-vector: m=8, n=128, k=1 ==================== trans_b torch.float32 0.75 usec trans_b torch.float16 0.71 usec trans_b torch.bfloat16 0.81 usec m=128, n=8, k=1 ==================== trans_b torch.float32 0.75 usec trans_b torch.float16 0.93 usec trans_b torch.bfloat16 0.98 usec m=4096, n=4096, k=1 ==================== trans_b torch.float32 2194.31 usec trans_b torch.float16 661.27 usec trans_b torch.bfloat16 3758.42 usec m=11008, n=4096, k=1 ==================== trans_b torch.float32 5792.04 usec trans_b torch.float16 1789.98 usec trans_b torch.bfloat16 10120.67 usec m=4096, n=11008, k=1 ==================== trans_b torch.float32 6101.22 usec trans_b torch.float16 1927.34 usec trans_b torch.bfloat16 10469.47 usec m=32000, n=4096, k=1 ==================== trans_b torch.float32 18353.20 usec trans_b torch.float16 5161.06 usec trans_b torch.bfloat16 29601.69 usec Matrix-matrix (prompt len 4: m=8, n=128, k=4 ==================== trans_b torch.float32 2.14 usec trans_b torch.float16 0.85 usec trans_b torch.bfloat16 1.19 usec m=128, n=8, k=4 ==================== trans_b torch.float32 1.47 usec trans_b torch.float16 1.85 usec trans_b torch.bfloat16 1.75 usec m=4096, n=4096, k=4 ==================== trans_b torch.float32 4416.40 usec trans_b torch.float16 2688.36 usec trans_b torch.bfloat16 14987.33 usec m=11008, n=4096, k=4 ==================== trans_b torch.float32 6140.24 usec trans_b torch.float16 7467.26 usec trans_b torch.bfloat16 40295.52 usec m=4096, n=11008, k=4 ==================== trans_b torch.float32 6143.10 usec trans_b torch.float16 7298.04 usec trans_b torch.bfloat16 41393.43 usec m=32000, n=4096, k=4 ==================== trans_b torch.float32 17650.72 usec trans_b torch.float16 21346.63 usec trans_b torch.bfloat16 116849.98 usec Matrix-matrix (prompt len 8: m=8, n=128, k=8 ==================== trans_b torch.float32 1.05 usec trans_b torch.float16 1.03 usec trans_b torch.bfloat16 1.69 usec m=128, n=8, k=8 ==================== trans_b torch.float32 2.05 usec trans_b torch.float16 3.08 usec trans_b torch.bfloat16 2.95 usec m=4096, n=4096, k=8 ==================== trans_b torch.float32 2323.99 usec trans_b torch.float16 5265.45 usec trans_b torch.bfloat16 29942.40 usec m=11008, n=4096, k=8 ==================== trans_b torch.float32 6202.01 usec trans_b torch.float16 14677.90 usec trans_b torch.bfloat16 80625.18 usec m=4096, n=11008, k=8 ==================== trans_b torch.float32 6112.05 usec trans_b torch.float16 14340.52 usec trans_b torch.bfloat16 82799.99 usec m=32000, n=4096, k=8 ==================== trans_b torch.float32 17650.65 usec trans_b torch.float16 42551.43 usec trans_b torch.bfloat16 236081.08 usec Matrix-matrix (prompt len 16: m=8, n=128, k=16 ==================== trans_b torch.float32 1.26 usec trans_b torch.float16 1.34 usec trans_b torch.bfloat16 2.69 usec m=128, n=8, k=16 ==================== trans_b torch.float32 1.60 usec trans_b torch.float16 5.81 usec trans_b torch.bfloat16 5.34 usec m=4096, n=4096, k=16 ==================== trans_b torch.float32 2328.05 usec trans_b torch.float16 10526.58 usec trans_b torch.bfloat16 60028.28 usec m=11008, n=4096, k=16 ==================== trans_b torch.float32 6243.35 usec trans_b torch.float16 28505.08 usec trans_b torch.bfloat16 163670.15 usec m=4096, n=11008, k=16 ==================== trans_b torch.float32 5870.11 usec trans_b torch.float16 28597.89 usec trans_b torch.bfloat16 165404.88 usec m=32000, n=4096, k=16 ==================== trans_b torch.float32 17746.27 usec trans_b torch.float16 83393.87 usec trans_b torch.bfloat16 472313.13 usec Matrix-matrix (prompt len 32: m=8, n=128, k=32 ==================== trans_b torch.float32 1.35 usec trans_b torch.float16 2.01 usec trans_b torch.bfloat16 4.68 usec m=128, n=8, k=32 ==================== trans_b torch.float32 1.19 usec trans_b torch.float16 10.98 usec trans_b torch.bfloat16 10.13 usec m=4096, n=4096, k=32 ==================== trans_b torch.float32 2525.29 usec trans_b torch.float16 23106.71 usec trans_b torch.bfloat16 122987.04 usec m=11008, n=4096, k=32 ==================== trans_b torch.float32 6131.34 usec trans_b torch.float16 57537.41 usec trans_b torch.bfloat16 327825.00 usec m=4096, n=11008, k=32 ==================== trans_b torch.float32 6395.01 usec trans_b torch.float16 57456.33 usec trans_b torch.bfloat16 331325.58 usec m=32000, n=4096, k=32 ==================== trans_b torch.float32 19078.68 usec trans_b torch.float16 167735.08 usec trans_b torch.bfloat16 975736.88 usec Matrix-matrix (prompt len 128: m=8, n=128, k=128 ==================== trans_b torch.float32 2.40 usec trans_b torch.float16 6.07 usec trans_b torch.bfloat16 16.83 usec m=128, n=8, k=128 ==================== trans_b torch.float32 1.78 usec trans_b torch.float16 40.35 usec trans_b torch.bfloat16 37.21 usec m=4096, n=4096, k=128 ==================== trans_b torch.float32 4827.60 usec trans_b torch.float16 84341.24 usec trans_b torch.bfloat16 478917.75 usec m=11008, n=4096, k=128 ==================== trans_b torch.float32 11879.96 usec trans_b torch.float16 226484.33 usec trans_b torch.bfloat16 1289465.50 usec m=4096, n=11008, k=128 ==================== trans_b torch.float32 10707.75 usec trans_b torch.float16 229200.58 usec trans_b torch.bfloat16 1327416.67 usec m=32000, n=4096, k=128 ==================== trans_b torch.float32 33306.32 usec trans_b torch.float16 662898.21 usec trans_b torch.bfloat16 3815866.63 usec ``` torch.float16 performance seems to be improved for all except the m=128, n=8, k=128 case, where it is roughly neutral. This case motivated the addition of the "first-tier tail fixup" in the dot kernel. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127451 Approved by: https://github.com/malfet ghstack dependencies: #127435 --- aten/src/ATen/native/BlasKernel.cpp | 22 ++++++++++++++++++++-- aten/src/ATen/native/cpu/BlasKernel.cpp | 19 +++++++++++++------ 2 files changed, 33 insertions(+), 8 deletions(-) diff --git a/aten/src/ATen/native/BlasKernel.cpp b/aten/src/ATen/native/BlasKernel.cpp index df0512f95390..930b2aed7b79 100644 --- a/aten/src/ATen/native/BlasKernel.cpp +++ b/aten/src/ATen/native/BlasKernel.cpp @@ -344,6 +344,10 @@ static inline float32x4_t f32_fma_high_f16(float32x4_t a, float16x8_t b, float16 #endif } +static inline float32x4_t f32_fma_f16(float32x4_t a, float16x4_t b, float16x4_t c) { + return f32_fma_low_f16(a, vcombine_f16(b, vdup_n_f16(0)), vcombine_f16(c, vdup_n_f16(0))); +} + // The below reduce overload and fp16_dot_with_fp32_arith are adapted // from llama.cpp's ggml_vec_dot_f32 and surrounding utility // functions. See NOTE [ GGML Copyright Notice ] above for the @@ -375,7 +379,7 @@ static inline double reduce(float32x4_t x[kF32RegistersPerIteration]) { return vaddvq_f32(x[0]); } -static float fp16_dot_with_fp32_arith(const float16_t* x, const float16_t* a, int len) { +float fp16_dot_with_fp32_arith(const float16_t* x, const float16_t* a, int64_t len) { float32x4_t sum[kF32RegistersPerIteration] = {vdupq_n_f32(0)}; const auto len_aligned = len & ~(kF32ElementsPerIteration - 1); for (int j = 0; j < len_aligned ; j += kF32ElementsPerIteration) { @@ -392,7 +396,21 @@ static float fp16_dot_with_fp32_arith(const float16_t* x, const float16_t* a, in } auto reducedSum = reduce(sum); - for (int j = len_aligned; j < len; ++j) { + // First-tier tail fixup: make sure we handle workloads that can + // benefit from vectorization, but don't fit into our fully unrolled + // loop above. + float32x4_t tailSum = vdupq_n_f32(0); + const auto len_aligned_4 = len & ~3; + for (int j = len_aligned; j < len_aligned_4; j += 4) { + const auto temp_x = vld1_f16(x + j); + const auto temp_a = vld1_f16(a + j); + tailSum = f32_fma_f16(tailSum, temp_x, temp_a); + } + auto reducedTail = vpaddq_f32(tailSum, tailSum); + reducedSum += vgetq_lane_f32(vpaddq_f32(reducedTail, reducedTail), 0); + + // Second-tier tail fixup: handle all workloads. + for (int j = len_aligned_4; j < len; ++j) { reducedSum += x[j] * a[j]; } return reducedSum; diff --git a/aten/src/ATen/native/cpu/BlasKernel.cpp b/aten/src/ATen/native/cpu/BlasKernel.cpp index 587809ea57c8..387d6840999a 100644 --- a/aten/src/ATen/native/cpu/BlasKernel.cpp +++ b/aten/src/ATen/native/cpu/BlasKernel.cpp @@ -33,6 +33,11 @@ void fp16_gemv_trans( const float beta, float16_t* y, const int incy); + +float fp16_dot_with_fp32_arith( + const float16_t* x, + const float16_t* a, + int64_t len); } #endif @@ -308,18 +313,20 @@ void gemm_notrans_( } -inline float32x4_t load_as_float32x4(const Half* ptr) { - return vcvt_f32_f16(vld1_f16(reinterpret_cast(ptr))); -} - inline float32x4_t load_as_float32x4(const BFloat16* ptr) { int32x4_t shift = vdupq_n_s32(16); uint32x4_t as_int = vmovl_u16(vld1_u16(reinterpret_cast(ptr))); return vreinterpretq_f32_u32(vshlq_u32(as_int, shift)); } -template -static float compute_dot(const T* a, const T* b, int64_t l) { +static float compute_dot(const at::Half* a, const at::Half* b, int64_t len) { + return at::native::blas_impl::fp16_dot_with_fp32_arith( + reinterpret_cast(a), + reinterpret_cast(b), + len); +} + +static float compute_dot(const at::BFloat16* a, const at::BFloat16* b, int64_t l) { if ((l&3) != 0) { return sum(l, [&](int64_t i) -> float { return float(a[i]) * float(b[i]); From 620ec081ec350173cb55d05755f9ce7af708c4ae Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Wed, 29 May 2024 17:11:09 -0700 Subject: [PATCH 132/706] Extract inner loops into separate function for ARM64 fp16_dot_with_fp32_arith (#127476) Summary: Preparing to generalize to bf16. (This should not be committed unless the following bf16 PR is committed!) Test Plan: Spot-checked llm_experiments benchmark result to make sure it didn't regress. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127476 Approved by: https://github.com/malfet ghstack dependencies: #127435, #127451 --- aten/src/ATen/native/BlasKernel.cpp | 44 ++++++++++++++++++++--------- 1 file changed, 30 insertions(+), 14 deletions(-) diff --git a/aten/src/ATen/native/BlasKernel.cpp b/aten/src/ATen/native/BlasKernel.cpp index 930b2aed7b79..79076544fc33 100644 --- a/aten/src/ATen/native/BlasKernel.cpp +++ b/aten/src/ATen/native/BlasKernel.cpp @@ -379,19 +379,37 @@ static inline double reduce(float32x4_t x[kF32RegistersPerIteration]) { return vaddvq_f32(x[0]); } -float fp16_dot_with_fp32_arith(const float16_t* x, const float16_t* a, int64_t len) { +static C10_ALWAYS_INLINE void fp16_dot_with_fp32_arith_main_inner_loop( + const float16_t* vec1, + const float16_t* vec2, + float32x4_t sum[kF32RegistersPerIteration], + int registerPairIndex) { + // Load a pair of f32 registers at a time. + const auto temp_vec1 = vld1q_f16(&vec1[registerPairIndex * 2 * kF32ElementsPerRegister]); + const auto temp_vec2 = vld1q_f16(&vec2[registerPairIndex * 2 * kF32ElementsPerRegister]); + + sum[2 * registerPairIndex] = f32_fma_low_f16(sum[2 * registerPairIndex], temp_vec1, temp_vec2); + sum[2 * registerPairIndex + 1] = f32_fma_high_f16(sum[2 * registerPairIndex + 1], temp_vec1, temp_vec2); +} + +static C10_ALWAYS_INLINE void fp16_dot_with_fp32_arith_vectorized_tail_inner_loop( + const float16_t* vec1, + const float16_t* vec2, + float32x4_t* tailSum, + int idx) { + const auto temp_vec1 = vld1_f16(&vec1[idx]); + const auto temp_vec2 = vld1_f16(&vec2[idx]); + *tailSum = f32_fma_f16(*tailSum, temp_vec1, temp_vec2); +} + +float fp16_dot_with_fp32_arith(const float16_t* vec1, const float16_t* vec2, int64_t len) { float32x4_t sum[kF32RegistersPerIteration] = {vdupq_n_f32(0)}; const auto len_aligned = len & ~(kF32ElementsPerIteration - 1); for (int j = 0; j < len_aligned ; j += kF32ElementsPerIteration) { - const auto* x_ = x + j; - const auto* a_ = a + j; - c10::ForcedUnroll{}([x_, a_, &sum](auto k) { - // Load a pair of f32 registers at a time. - const auto temp_x = vld1q_f16(x_ + k * 2 * kF32ElementsPerRegister); - const auto temp_a = vld1q_f16(a_ + k * 2 * kF32ElementsPerRegister); - - sum[2 * k] = f32_fma_low_f16(sum[2 * k], temp_x, temp_a); - sum[2 * k + 1] = f32_fma_high_f16(sum[2 * k + 1], temp_x, temp_a); + const auto* vec1_ = vec1 + j; + const auto* vec2_ = vec2 + j; + c10::ForcedUnroll{}([vec1_, vec2_, &sum](auto k) { + fp16_dot_with_fp32_arith_main_inner_loop(vec1_, vec2_, sum, k); }); } auto reducedSum = reduce(sum); @@ -402,16 +420,14 @@ float fp16_dot_with_fp32_arith(const float16_t* x, const float16_t* a, int64_t l float32x4_t tailSum = vdupq_n_f32(0); const auto len_aligned_4 = len & ~3; for (int j = len_aligned; j < len_aligned_4; j += 4) { - const auto temp_x = vld1_f16(x + j); - const auto temp_a = vld1_f16(a + j); - tailSum = f32_fma_f16(tailSum, temp_x, temp_a); + fp16_dot_with_fp32_arith_vectorized_tail_inner_loop(vec1, vec2, &tailSum, j); } auto reducedTail = vpaddq_f32(tailSum, tailSum); reducedSum += vgetq_lane_f32(vpaddq_f32(reducedTail, reducedTail), 0); // Second-tier tail fixup: handle all workloads. for (int j = len_aligned_4; j < len; ++j) { - reducedSum += x[j] * a[j]; + reducedSum += vec1[j] * vec2[j]; } return reducedSum; } From 214dd44608a92802f9c13471451ae09cf6b25fd0 Mon Sep 17 00:00:00 2001 From: Shuqiang Zhang Date: Wed, 29 May 2024 15:42:43 -0700 Subject: [PATCH 133/706] [c10d] add Work's numel to logger for debugging purposes (#127468) Summary: We have seen some cases that all ranks call into a collective but it got stuck probably due to incorrect sizes of the tensors. Adding the size info into logging for debugging Also, taking this chance to consolidate all logger related status metrics in to one struct Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/127468 Approved by: https://github.com/wconstab --- .../distributed/c10d/ProcessGroupNCCL.cpp | 59 +++++++++++-------- .../distributed/c10d/ProcessGroupNCCL.hpp | 52 +++++++++------- 2 files changed, 65 insertions(+), 46 deletions(-) diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 1ea278a44e3c..be2853efc113 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -1248,9 +1248,9 @@ void ProcessGroupNCCL::heartbeatMonitor() { "Received a dump signal from this local rank and will ", "start to dump the debug info. ", "Last enqueued NCCL work: ", - lastEnqueuedSeq_, + pgStatus_.lastEnqueuedSeq, ", last completed NCCL work: ", - lastCompletedSeq_, + pgStatus_.lastCompletedSeq, "."); exitMsg = c10::str( "ProcessGroupNCCL's watchdog detected an exception from the local rank. ", @@ -1310,9 +1310,9 @@ void ProcessGroupNCCL::heartbeatMonitor() { timeOutRank, ", and will start to dump the debug info. ", "Last enqueued NCCL work: ", - lastEnqueuedSeq_, + pgStatus_.lastEnqueuedSeq, ", last completed NCCL work: ", - lastCompletedSeq_, + pgStatus_.lastCompletedSeq, "."); exitMsg = c10::str( "ProcessGroupNCCL's watchdog detected a dump signal from rank ", @@ -1578,9 +1578,9 @@ void ProcessGroupNCCL::watchdogHandler() { logPrefix(), "NCCL Work update periodically: ", "last enqueued NCCL work: ", - lastEnqueuedSeq_, + pgStatus_.lastEnqueuedSeq, ", last completed NCCL work: ", - lastCompletedSeq_, + pgStatus_.lastCompletedSeq, "."); #endif auto logger = ::c10d::C10dLogger::getLogger(); @@ -1593,13 +1593,19 @@ void ProcessGroupNCCL::watchdogHandler() { data.integers["pg_id"] = uid_; data.integers["rank"] = rank_; data.integers["global_rank"] = globalRank(); - data.integers["last_enqueued_work"] = lastEnqueuedSeq_; - data.integers["last_started_work"] = lastStartedSeq_; - data.integers["last_completed_work"] = lastCompletedSeq_; + data.integers["last_enqueued_work"] = pgStatus_.lastEnqueuedSeq; + data.integers["last_started_work"] = pgStatus_.lastStartedSeq; + data.integers["last_completed_work"] = pgStatus_.lastCompletedSeq; + data.integers["last_enqueued_numel_in"] = pgStatus_.lastEnqueuedNumelIn; + data.integers["last_enqueued_numel_out"] = pgStatus_.lastEnqueuedNumelOut; + data.integers["last_completed_numel_in"] = pgStatus_.lastCompletedNumelIn; + data.integers["last_completed_numel_out"] = + pgStatus_.lastCompletedNumelOut; // logging strings - data.strings["last_enqueued_work_name"] = lastEnqueuedWorkName_; - data.strings["last_started_work_name"] = lastStartedWorkName_; - data.strings["last_completed_work_name"] = lastCompletedWorkName_; + data.strings["last_enqueued_work_name"] = pgStatus_.lastEnqueuedWorkName; + data.strings["last_started_work_name"] = pgStatus_.lastStartedWorkName; + data.strings["last_completed_work_name"] = + pgStatus_.lastCompletedWorkName; data.strings["pg_name"] = pg_name_; data.strings["pg_desc"] = pg_desc_; logger->log(data); @@ -1626,9 +1632,9 @@ void ProcessGroupNCCL::watchdogHandler() { "Exception (either an error or timeout) detected by watchdog at work: ", work.seq_, ", last enqueued NCCL work: ", - lastEnqueuedSeq_, + pgStatus_.lastEnqueuedSeq, ", last completed NCCL work: ", - lastCompletedSeq_, + pgStatus_.lastCompletedSeq, "."); // try to dump flight records if exception happens. // Flight recorder behavior should be independent of desync Debug @@ -1671,9 +1677,9 @@ void ProcessGroupNCCL::watchdogHandler() { "Timeout at NCCL work: ", work.seq_, ", last enqueued NCCL work: ", - lastEnqueuedSeq_, + pgStatus_.lastEnqueuedSeq, ", last completed NCCL work: ", - lastCompletedSeq_, + pgStatus_.lastCompletedSeq, "."); if (desyncDebug_) { try { @@ -1708,18 +1714,20 @@ void ProcessGroupNCCL::watchdogHandler() { } // a work could be started but not completed, so we should not update - // lastStartedSeq_ and lastStartedOpName_ if the work state is checked + // lastStartedSeq and lastStartedOpName if the work state is checked // multiple times after the start - if (lastStartedSeq_ < static_cast(work.seq_) && + if (pgStatus_.lastStartedSeq < static_cast(work.seq_) && work.isStarted()) { - lastStartedSeq_ = work.seq_; - lastStartedWorkName_ = opTypeToString(work.opType_); + pgStatus_.lastStartedSeq = work.seq_; + pgStatus_.lastStartedWorkName = opTypeToString(work.opType_); } // Clean up completed work if (work.isCompleted()) { - lastCompletedSeq_ = work.seq_; - lastCompletedWorkName_ = opTypeToString(work.opType_); + pgStatus_.lastCompletedSeq = work.seq_; + pgStatus_.lastCompletedWorkName = opTypeToString(work.opType_); + pgStatus_.lastCompletedNumelIn = work.numelIn_; + pgStatus_.lastCompletedNumelOut = work.numelOut_; NCCLTraceBuffer::get()->retire_id(work.trace_id_, true); if (onCompletionHook_) { // Move Work object to completedWorkList_ to be consumed by the hook @@ -2348,8 +2356,11 @@ void ProcessGroupNCCL::workEnqueue( // needs to be destructed in user thread. Otherwise will // get deadlock. Here we enqueue work without outputs_. workMetaList_.emplace_back(*work); - lastEnqueuedSeq_ = work->seq_; - lastEnqueuedWorkName_ = opTypeToString(work->opType_); + // update the PG status related to the last enqueued work + pgStatus_.lastEnqueuedSeq = work->seq_; + pgStatus_.lastEnqueuedWorkName = opTypeToString(work->opType_); + pgStatus_.lastEnqueuedNumelIn = work->numelIn_; + pgStatus_.lastEnqueuedNumelOut = work->numelOut_; lastWorkListUpdateTime_ = std::chrono::steady_clock::now(); } } diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index 1655de8a7848..117c24ebfb82 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -437,6 +437,34 @@ class TORCH_API ProcessGroupNCCL : public Backend { std::string group_name; }; + // A struct to hold the latest status of the process group. + struct ProcessGroupStatus { + // the sequential number of the last collective enqueued into workMetaList_ + // This is useful for indentifying a rank that has not join a collective + // initialized to be -1 to indicate no collective has been enqueued + int64_t lastEnqueuedSeq{-1}; + // the sequential number of the last collective started as the kernel + int64_t lastStartedSeq{-1}; + // the sequential number of the last colletive completed marked by + // the watchdog thread + // initialized to be -1 to indicate no collective has been completed + int64_t lastCompletedSeq{-1}; + + // the name of the last collective enqueued into workMetaList_ + std::string lastEnqueuedWorkName; + // the name of the last collective started as the kernel + std::string lastStartedWorkName; + // the name of the last collective completed + std::string lastCompletedWorkName; + + // the sizes of the last work enqueued + size_t lastEnqueuedNumelIn; + size_t lastEnqueuedNumelOut; + // the sizes of the last work completed + size_t lastCompletedNumelIn; + size_t lastCompletedNumelOut; + }; + // If you wish to create multiple process groups, each with a potentially // different rank and size, you can do so by passing a new store instance // to each one. If you have only a single store object, you can @@ -1071,28 +1099,6 @@ class TORCH_API ProcessGroupNCCL : public Backend { // the ProcessGroup uint64_t op_id_{0}; - // the sequential number of the last collective enqueued into workMetaList_ - // This is useful for indentifying a rank that has not join a collective - // initialized to be -1 to indicate no collective has been enqueued - int64_t lastEnqueuedSeq_{-1}; - - // the name of the last collective enqueued into workMetaList_ - std::string lastEnqueuedWorkName_; - - // the sequential number of the last collective started as the kernel - int64_t lastStartedSeq_{-1}; - - // the name of the last collective started as the kernel - std::string lastStartedWorkName_; - - // the sequential number of the last colletive completed marked by - // the watchdog thread - // initialized to be -1 to indicate no collective has been completed - int64_t lastCompletedSeq_{-1}; - - // the name of the last collective completed - std::string lastCompletedWorkName_; - std::exception_ptr watchDogException_ = nullptr; size_t uid_; @@ -1103,6 +1109,8 @@ class TORCH_API ProcessGroupNCCL : public Backend { // Number of devices on this node. int localDeviceCount_{0}; + + ProcessGroupStatus pgStatus_; }; TORCH_API std::string dump_nccl_trace(); From e72232f8f032b970b74da18200678b3a4617bf95 Mon Sep 17 00:00:00 2001 From: wz337 Date: Thu, 30 May 2024 23:55:18 +0000 Subject: [PATCH 134/706] [DeviceMesh] Adding nD slicing support back (#127465) Fixes #126530 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127465 Approved by: https://github.com/wconstab --- test/distributed/test_device_mesh.py | 46 ++++++++++-- torch/distributed/device_mesh.py | 101 ++++++++++++++++++--------- 2 files changed, 108 insertions(+), 39 deletions(-) diff --git a/test/distributed/test_device_mesh.py b/test/distributed/test_device_mesh.py index 8f70ee2f0b7d..03457de14b68 100644 --- a/test/distributed/test_device_mesh.py +++ b/test/distributed/test_device_mesh.py @@ -420,16 +420,16 @@ def world_size(self): @with_comms def test_raises_no_mesh_dim_found(self): - with self.assertRaisesRegex(KeyError, "No `mesh_dim_names` found."): + with self.assertRaisesRegex( + RuntimeError, "Cannot slice a DeviceMesh without mesh_dim_names!" + ): mesh = init_device_mesh(self.device_type, (2, 4)) child_mesh = mesh["DP"] @with_comms def test_raises_invalid_mesh_dim_name(self): - child_mesh_dim_name = "PP" - with self.assertRaisesRegex( - KeyError, f"Mesh dimension '{child_mesh_dim_name}' does not exist." - ): + child_mesh_dim_name = ("PP",) + with self.assertRaisesRegex(KeyError, "Invalid mesh_dim_name"): mesh_dim_names = ("DP", "TP") mesh = init_device_mesh( self.device_type, (2, 4), mesh_dim_names=mesh_dim_names @@ -437,7 +437,7 @@ def test_raises_invalid_mesh_dim_name(self): child_mesh = mesh[child_mesh_dim_name] @with_comms - def test_get_item(self): + def test_get_item_2d(self): mesh_shape = (2, 4) mesh_dim_names = ("DP", "TP") mesh_2d = init_device_mesh( @@ -467,9 +467,41 @@ def test_get_item_1d(self): dp_mesh = mesh["dp"] self.assertEqual(dp_mesh, mesh) - with self.assertRaisesRegex(RuntimeError, "Invalid mesh_dim_name"): + with self.assertRaisesRegex(KeyError, "Invalid mesh_dim_name"): dp_mesh = mesh["dim0"] + @with_comms + def test_get_item_3d(self): + mesh_shape = (2, 2, 2) + mesh_dim_names = ("Replicate", "Shard", "TP") + mesh_3d = init_device_mesh( + self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names + ) + + tp_group = [[0, 1], [2, 3], [4, 5], [6, 7]] + tp_group_idx = int(self.rank / 2) + self.assertEqual(mesh_3d["TP"].mesh.tolist(), tp_group[tp_group_idx]) + + shard_group = [[0, 2], [1, 3], [4, 6], [5, 7]] + shard_group_idx = self.rank % 2 + self.rank // 4 * 2 + self.assertEqual(mesh_3d["Shard"].mesh.tolist(), shard_group[shard_group_idx]) + + replicate_group = [[0, 4], [1, 5], [2, 6], [3, 7]] + replicate_group_idx = self.rank % 4 + self.assertEqual( + mesh_3d["Replicate"].mesh.tolist(), replicate_group[replicate_group_idx] + ) + + # We support both UX for nD slicing. + # mesh_3d[["Replicate", "Shard"]] or mesh_3d["Replicate", "Shard"] + hsdp_mesh_1 = mesh_3d[["Replicate", "Shard"]] + hsdp_mesh_2 = mesh_3d["Replicate", "Shard"] + hsdp_group = [[[0, 2], [4, 6]], [[1, 3], [5, 7]]] + hsdp_group_idx = self.rank % 2 + self.assertEqual(hsdp_mesh_1.mesh.tolist(), hsdp_group[hsdp_group_idx]) + self.assertEqual(hsdp_mesh_2.mesh.tolist(), hsdp_group[hsdp_group_idx]) + self.assertEqual(hsdp_mesh_1, hsdp_mesh_2) + @with_comms def test_cache_and_reuse_submesh_slice_result(self): mesh = init_device_mesh(self.device_type, (2, 4), mesh_dim_names=("dp", "tp")) diff --git a/torch/distributed/device_mesh.py b/torch/distributed/device_mesh.py index 57b8fa1cf564..fe23fda513db 100644 --- a/torch/distributed/device_mesh.py +++ b/torch/distributed/device_mesh.py @@ -69,31 +69,47 @@ def get_current_mesh(self) -> "DeviceMesh": return self.mesh_stack[-1] def create_child_mesh( - self, device_mesh: "DeviceMesh", mesh_dim: int, mesh_dim_name: str + self, parent_mesh: "DeviceMesh", submesh_dim_names: Tuple[str] ) -> "DeviceMesh": - # swap the current dim to the last dim then reshape to flatten out other - # dims, so we can just extract the list of ranks which contains cur_rank. - cur_rank = device_mesh.get_rank() - pg_ranks_by_dim = device_mesh.mesh.swapdims(-1, mesh_dim).reshape( - -1, device_mesh.mesh.size(mesh_dim) - ) + # submesh_dims are the mesh dimension of the submesh in the parent mesh. + submesh_dims = [ + not_none(parent_mesh.mesh_dim_names).index(mesh_dim_name) + for mesh_dim_name in submesh_dim_names + ] + submesh_dim_sizes = [ + parent_mesh.mesh.size(mesh_dim) for mesh_dim in submesh_dims + ] - for mesh_1d in pg_ranks_by_dim: - sub_mesh = DeviceMesh( - device_mesh.device_type, - mesh_1d, - mesh_dim_names=(mesh_dim_name,), + mesh_dims_remained = list(range(parent_mesh.mesh.ndim)) + for submesh_dim in submesh_dims: + mesh_dims_remained.remove(submesh_dim) + + # pg_ranks_by_dim is the size of [number of local ranks of the outermost submesh dimension, *sub_mesh_dims] + # This means on each local rank of the outermost slice mesh dim, we have a tensor of submesh size with + # the pg ranks of the submesh. From this, we can extract the submesh mesh tensor contains the current rank. + pg_ranks_by_dim = parent_mesh.mesh.permute( + *mesh_dims_remained, *submesh_dims + ).reshape(-1, *submesh_dim_sizes) + + cur_rank = parent_mesh.get_rank() + for mesh_nd in pg_ranks_by_dim: + # Every rank needs to participate in this DeviceMesh creation even if the cur_rank is not in mesh_nd + submesh = DeviceMesh( + parent_mesh.device_type, + mesh_nd, + mesh_dim_names=submesh_dim_names, _init_backend=False, ) - if cur_rank in mesh_1d: - res_sub_mesh = sub_mesh + if cur_rank in mesh_nd: + res_submesh = submesh + + res_submesh._parent_mesh = parent_mesh + res_submesh._dim_group_infos = [ + parent_mesh._dim_group_infos[mesh_dim] for mesh_dim in submesh_dims + ] + self.child_to_parent_mapping[res_submesh] = parent_mesh - res_sub_mesh._dim_group_infos = [device_mesh._dim_group_infos[mesh_dim]] # type: ignore[possibly-undefined] - res_sub_mesh._parent_mesh = device_mesh - # Assign the current DeviceMesh as the parent of the child DeviceMesh. - # We need to update the mappings after the child mesh hash update. - self.child_to_parent_mapping[res_sub_mesh] = device_mesh - return res_sub_mesh + return res_submesh def get_parent_mesh(self, device_mesh: "DeviceMesh") -> Optional["DeviceMesh"]: return self.child_to_parent_mapping.get(device_mesh, None) @@ -367,14 +383,14 @@ def __eq__(self, other: object) -> bool: and self._thread_id == other._thread_id ) - def __getitem__(self, mesh_dim_name: str) -> "DeviceMesh": + def __getitem__(self, mesh_dim_names: Union[str, Tuple[str]]) -> "DeviceMesh": """ Slice the current DeviceMesh based on the mesh_dim_name given to create a child DeviceMesh. Args: - mesh_dim_name (str): the name of the mesh dimension of the parent DeviceMesh - to create a child DeviceMesh for. + mesh_dim_name (Union[str, Tuple[str]]): the name or the tuple of names of the + mesh dimension of the parent DeviceMesh to create the child DeviceMesh for. Returns: A :class:`DeviceMesh` object @@ -395,16 +411,37 @@ def __getitem__(self, mesh_dim_name: str) -> "DeviceMesh": >>> # of cross-host(dim 0), and within-host (dim 1). >>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]]) """ - if self.mesh.ndim == 1: - if self.mesh_dim_names and mesh_dim_name == self.mesh_dim_names[0]: - return self - else: - raise RuntimeError( - f"Invalid mesh_dim_name {mesh_dim_name} specified." - ) + if not self.mesh_dim_names: + raise RuntimeError("Cannot slice a DeviceMesh without mesh_dim_names!") + + mesh_dim_names = ( + (mesh_dim_names,) if isinstance(mesh_dim_names, str) else mesh_dim_names + ) + + error_msg = ( + f"Invalid mesh_dim_name {mesh_dim_names} specified. " + f"Valid mesh_dim_names should be a contiguous subsequence of {self.mesh_dim_names}." + ) + + if mesh_dim_names == self.mesh_dim_names: + return self + elif len(mesh_dim_names) > len(self.mesh_dim_names) or not all( + mesh_dim_name in self.mesh_dim_names for mesh_dim_name in mesh_dim_names + ): + raise KeyError(error_msg) + # Check if the user-provided slicing is a valid contiguous subsequence of the mesh_dim_names + # of the current DeviceMesh. + else: + outermost_dim_name = mesh_dim_names[0] + outermost_dim_idx = self.mesh_dim_names.index(outermost_dim_name) + for i, j in zip( + mesh_dim_names, + self.mesh_dim_names[outermost_dim_idx : len(mesh_dim_names)], + ): + if i != j: + raise KeyError(error_msg) - mesh_dim = _mesh_resources.get_mesh_dim_by_name(self, mesh_dim_name) - submesh = _mesh_resources.create_child_mesh(self, mesh_dim, mesh_dim_name) + submesh = _mesh_resources.create_child_mesh(self, mesh_dim_names) return submesh def get_group( From a2bff4dc8cf0127c8ab2f78e661c156c40768efc Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Fri, 31 May 2024 00:00:11 +0000 Subject: [PATCH 135/706] Fix lint (#127584) Trivial fix after https://github.com/pytorch/pytorch/pull/124678 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127584 Approved by: https://github.com/huydhn --- torch/distributed/fsdp/_init_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torch/distributed/fsdp/_init_utils.py b/torch/distributed/fsdp/_init_utils.py index 295400220b4f..2364b1871206 100644 --- a/torch/distributed/fsdp/_init_utils.py +++ b/torch/distributed/fsdp/_init_utils.py @@ -1099,7 +1099,6 @@ def _sync_module_params_and_buffers( ) - def _check_module_states_for_sync_module_states( module_states: List[torch.Tensor], ) -> None: From fc73d07e5e66eb7e0d599cc69f91686106c4a7d2 Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Fri, 31 May 2024 00:17:25 +0000 Subject: [PATCH 136/706] [c10d] Decorate methods in `NCCLUtils.hpp` with `TORCH_API` (#127550) Summary: User-defined PyTorch modules that uses `C10D_NCCL_CHECK` run into undefined symbol errors when loaded by `torch.library.load()`, because they have not been exported. This change exports the symbols needed to resolve those runtime errors. Test Plan: PyTorch CI Differential Revision: D57977944 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127550 Approved by: https://github.com/Skylion007 --- torch/csrc/distributed/c10d/NCCLUtils.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp index 5690c0591a7a..165c514bbd27 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.hpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp @@ -173,14 +173,14 @@ namespace c10d { TORCH_API size_t hashTensors(const std::vector& tensors); -std::string getNcclVersion(); -std::string ncclGetErrorWithVersion(ncclResult_t error); +TORCH_API std::string getNcclVersion(); +TORCH_API std::string ncclGetErrorWithVersion(ncclResult_t error); bool nccl_use_nonblocking(); int nccl_nonblocking_timeout(); // Provides additional detail into NCCL error codes based on when these are // thrown in the NCCL codebase. -std::string getNcclErrorDetailStr( +TORCH_API std::string getNcclErrorDetailStr( ncclResult_t error, std::optional processGroupFailureReason = c10::nullopt); From af5ed054162695dee84475b58a069772050f07d1 Mon Sep 17 00:00:00 2001 From: atalman Date: Fri, 31 May 2024 00:30:10 +0000 Subject: [PATCH 137/706] Include triton in py3.12 binaries (#127547) Additional Builder PR: https://github.com/pytorch/builder/pull/1846/ Pull Request resolved: https://github.com/pytorch/pytorch/pull/127547 Approved by: https://github.com/williamwen42 --- .circleci/scripts/binary_populate_env.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.circleci/scripts/binary_populate_env.sh b/.circleci/scripts/binary_populate_env.sh index 287423641d77..a45a2c9754ba 100755 --- a/.circleci/scripts/binary_populate_env.sh +++ b/.circleci/scripts/binary_populate_env.sh @@ -76,8 +76,8 @@ TRITON_VERSION=$(cat $PYTORCH_ROOT/.ci/docker/triton_version.txt) # Here PYTORCH_EXTRA_INSTALL_REQUIREMENTS is already set for the all the wheel builds hence append TRITON_CONSTRAINT if [[ "$PACKAGE_TYPE" =~ .*wheel.* && -n "${PYTORCH_EXTRA_INSTALL_REQUIREMENTS:-}" ]]; then - # Only linux Python < 3.12 are supported wheels for triton - TRITON_CONSTRAINT="platform_system == 'Linux' and platform_machine == 'x86_64' and python_version < '3.12'" + # Only linux Python < 3.13 are supported wheels for triton + TRITON_CONSTRAINT="platform_system == 'Linux' and platform_machine == 'x86_64' and python_version < '3.13'" TRITON_REQUIREMENT="triton==${TRITON_VERSION}; ${TRITON_CONSTRAINT}" if [[ -n "$PYTORCH_BUILD_VERSION" && "$PYTORCH_BUILD_VERSION" =~ .*dev.* ]]; then TRITON_SHORTHASH=$(cut -c1-10 $PYTORCH_ROOT/.ci/docker/ci_commit_pins/triton.txt) From f6e303fa47b6eb431db7a80a95f56574dfddc297 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 31 May 2024 00:43:13 +0000 Subject: [PATCH 138/706] Revert "[DeviceMesh] Adding nD slicing support back (#127465)" This reverts commit e72232f8f032b970b74da18200678b3a4617bf95. Reverted https://github.com/pytorch/pytorch/pull/127465 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is failing lint https://hud.pytorch.org/pytorch/pytorch/commit/e72232f8f032b970b74da18200678b3a4617bf95, the error does not like look trivial fix, so I revert the change for a forward fix ([comment](https://github.com/pytorch/pytorch/pull/127465#issuecomment-2141051630)) --- test/distributed/test_device_mesh.py | 46 ++---------- torch/distributed/device_mesh.py | 101 +++++++++------------------ 2 files changed, 39 insertions(+), 108 deletions(-) diff --git a/test/distributed/test_device_mesh.py b/test/distributed/test_device_mesh.py index 03457de14b68..8f70ee2f0b7d 100644 --- a/test/distributed/test_device_mesh.py +++ b/test/distributed/test_device_mesh.py @@ -420,16 +420,16 @@ def world_size(self): @with_comms def test_raises_no_mesh_dim_found(self): - with self.assertRaisesRegex( - RuntimeError, "Cannot slice a DeviceMesh without mesh_dim_names!" - ): + with self.assertRaisesRegex(KeyError, "No `mesh_dim_names` found."): mesh = init_device_mesh(self.device_type, (2, 4)) child_mesh = mesh["DP"] @with_comms def test_raises_invalid_mesh_dim_name(self): - child_mesh_dim_name = ("PP",) - with self.assertRaisesRegex(KeyError, "Invalid mesh_dim_name"): + child_mesh_dim_name = "PP" + with self.assertRaisesRegex( + KeyError, f"Mesh dimension '{child_mesh_dim_name}' does not exist." + ): mesh_dim_names = ("DP", "TP") mesh = init_device_mesh( self.device_type, (2, 4), mesh_dim_names=mesh_dim_names @@ -437,7 +437,7 @@ def test_raises_invalid_mesh_dim_name(self): child_mesh = mesh[child_mesh_dim_name] @with_comms - def test_get_item_2d(self): + def test_get_item(self): mesh_shape = (2, 4) mesh_dim_names = ("DP", "TP") mesh_2d = init_device_mesh( @@ -467,41 +467,9 @@ def test_get_item_1d(self): dp_mesh = mesh["dp"] self.assertEqual(dp_mesh, mesh) - with self.assertRaisesRegex(KeyError, "Invalid mesh_dim_name"): + with self.assertRaisesRegex(RuntimeError, "Invalid mesh_dim_name"): dp_mesh = mesh["dim0"] - @with_comms - def test_get_item_3d(self): - mesh_shape = (2, 2, 2) - mesh_dim_names = ("Replicate", "Shard", "TP") - mesh_3d = init_device_mesh( - self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names - ) - - tp_group = [[0, 1], [2, 3], [4, 5], [6, 7]] - tp_group_idx = int(self.rank / 2) - self.assertEqual(mesh_3d["TP"].mesh.tolist(), tp_group[tp_group_idx]) - - shard_group = [[0, 2], [1, 3], [4, 6], [5, 7]] - shard_group_idx = self.rank % 2 + self.rank // 4 * 2 - self.assertEqual(mesh_3d["Shard"].mesh.tolist(), shard_group[shard_group_idx]) - - replicate_group = [[0, 4], [1, 5], [2, 6], [3, 7]] - replicate_group_idx = self.rank % 4 - self.assertEqual( - mesh_3d["Replicate"].mesh.tolist(), replicate_group[replicate_group_idx] - ) - - # We support both UX for nD slicing. - # mesh_3d[["Replicate", "Shard"]] or mesh_3d["Replicate", "Shard"] - hsdp_mesh_1 = mesh_3d[["Replicate", "Shard"]] - hsdp_mesh_2 = mesh_3d["Replicate", "Shard"] - hsdp_group = [[[0, 2], [4, 6]], [[1, 3], [5, 7]]] - hsdp_group_idx = self.rank % 2 - self.assertEqual(hsdp_mesh_1.mesh.tolist(), hsdp_group[hsdp_group_idx]) - self.assertEqual(hsdp_mesh_2.mesh.tolist(), hsdp_group[hsdp_group_idx]) - self.assertEqual(hsdp_mesh_1, hsdp_mesh_2) - @with_comms def test_cache_and_reuse_submesh_slice_result(self): mesh = init_device_mesh(self.device_type, (2, 4), mesh_dim_names=("dp", "tp")) diff --git a/torch/distributed/device_mesh.py b/torch/distributed/device_mesh.py index fe23fda513db..57b8fa1cf564 100644 --- a/torch/distributed/device_mesh.py +++ b/torch/distributed/device_mesh.py @@ -69,47 +69,31 @@ def get_current_mesh(self) -> "DeviceMesh": return self.mesh_stack[-1] def create_child_mesh( - self, parent_mesh: "DeviceMesh", submesh_dim_names: Tuple[str] + self, device_mesh: "DeviceMesh", mesh_dim: int, mesh_dim_name: str ) -> "DeviceMesh": - # submesh_dims are the mesh dimension of the submesh in the parent mesh. - submesh_dims = [ - not_none(parent_mesh.mesh_dim_names).index(mesh_dim_name) - for mesh_dim_name in submesh_dim_names - ] - submesh_dim_sizes = [ - parent_mesh.mesh.size(mesh_dim) for mesh_dim in submesh_dims - ] + # swap the current dim to the last dim then reshape to flatten out other + # dims, so we can just extract the list of ranks which contains cur_rank. + cur_rank = device_mesh.get_rank() + pg_ranks_by_dim = device_mesh.mesh.swapdims(-1, mesh_dim).reshape( + -1, device_mesh.mesh.size(mesh_dim) + ) - mesh_dims_remained = list(range(parent_mesh.mesh.ndim)) - for submesh_dim in submesh_dims: - mesh_dims_remained.remove(submesh_dim) - - # pg_ranks_by_dim is the size of [number of local ranks of the outermost submesh dimension, *sub_mesh_dims] - # This means on each local rank of the outermost slice mesh dim, we have a tensor of submesh size with - # the pg ranks of the submesh. From this, we can extract the submesh mesh tensor contains the current rank. - pg_ranks_by_dim = parent_mesh.mesh.permute( - *mesh_dims_remained, *submesh_dims - ).reshape(-1, *submesh_dim_sizes) - - cur_rank = parent_mesh.get_rank() - for mesh_nd in pg_ranks_by_dim: - # Every rank needs to participate in this DeviceMesh creation even if the cur_rank is not in mesh_nd - submesh = DeviceMesh( - parent_mesh.device_type, - mesh_nd, - mesh_dim_names=submesh_dim_names, + for mesh_1d in pg_ranks_by_dim: + sub_mesh = DeviceMesh( + device_mesh.device_type, + mesh_1d, + mesh_dim_names=(mesh_dim_name,), _init_backend=False, ) - if cur_rank in mesh_nd: - res_submesh = submesh - - res_submesh._parent_mesh = parent_mesh - res_submesh._dim_group_infos = [ - parent_mesh._dim_group_infos[mesh_dim] for mesh_dim in submesh_dims - ] - self.child_to_parent_mapping[res_submesh] = parent_mesh + if cur_rank in mesh_1d: + res_sub_mesh = sub_mesh - return res_submesh + res_sub_mesh._dim_group_infos = [device_mesh._dim_group_infos[mesh_dim]] # type: ignore[possibly-undefined] + res_sub_mesh._parent_mesh = device_mesh + # Assign the current DeviceMesh as the parent of the child DeviceMesh. + # We need to update the mappings after the child mesh hash update. + self.child_to_parent_mapping[res_sub_mesh] = device_mesh + return res_sub_mesh def get_parent_mesh(self, device_mesh: "DeviceMesh") -> Optional["DeviceMesh"]: return self.child_to_parent_mapping.get(device_mesh, None) @@ -383,14 +367,14 @@ def __eq__(self, other: object) -> bool: and self._thread_id == other._thread_id ) - def __getitem__(self, mesh_dim_names: Union[str, Tuple[str]]) -> "DeviceMesh": + def __getitem__(self, mesh_dim_name: str) -> "DeviceMesh": """ Slice the current DeviceMesh based on the mesh_dim_name given to create a child DeviceMesh. Args: - mesh_dim_name (Union[str, Tuple[str]]): the name or the tuple of names of the - mesh dimension of the parent DeviceMesh to create the child DeviceMesh for. + mesh_dim_name (str): the name of the mesh dimension of the parent DeviceMesh + to create a child DeviceMesh for. Returns: A :class:`DeviceMesh` object @@ -411,37 +395,16 @@ def __getitem__(self, mesh_dim_names: Union[str, Tuple[str]]) -> "DeviceMesh": >>> # of cross-host(dim 0), and within-host (dim 1). >>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]]) """ - if not self.mesh_dim_names: - raise RuntimeError("Cannot slice a DeviceMesh without mesh_dim_names!") - - mesh_dim_names = ( - (mesh_dim_names,) if isinstance(mesh_dim_names, str) else mesh_dim_names - ) - - error_msg = ( - f"Invalid mesh_dim_name {mesh_dim_names} specified. " - f"Valid mesh_dim_names should be a contiguous subsequence of {self.mesh_dim_names}." - ) - - if mesh_dim_names == self.mesh_dim_names: - return self - elif len(mesh_dim_names) > len(self.mesh_dim_names) or not all( - mesh_dim_name in self.mesh_dim_names for mesh_dim_name in mesh_dim_names - ): - raise KeyError(error_msg) - # Check if the user-provided slicing is a valid contiguous subsequence of the mesh_dim_names - # of the current DeviceMesh. - else: - outermost_dim_name = mesh_dim_names[0] - outermost_dim_idx = self.mesh_dim_names.index(outermost_dim_name) - for i, j in zip( - mesh_dim_names, - self.mesh_dim_names[outermost_dim_idx : len(mesh_dim_names)], - ): - if i != j: - raise KeyError(error_msg) + if self.mesh.ndim == 1: + if self.mesh_dim_names and mesh_dim_name == self.mesh_dim_names[0]: + return self + else: + raise RuntimeError( + f"Invalid mesh_dim_name {mesh_dim_name} specified." + ) - submesh = _mesh_resources.create_child_mesh(self, mesh_dim_names) + mesh_dim = _mesh_resources.get_mesh_dim_by_name(self, mesh_dim_name) + submesh = _mesh_resources.create_child_mesh(self, mesh_dim, mesh_dim_name) return submesh def get_group( From da9fb670d2283a7dfffac565ec6065feec8325bc Mon Sep 17 00:00:00 2001 From: feifan <37650440+huihoaan@users.noreply.github.com> Date: Fri, 31 May 2024 01:11:13 +0000 Subject: [PATCH 139/706] Nadam support the flag for "maximize" (#127214) Fixes https://github.com/pytorch/pytorch/issues/126642 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127214 Approved by: https://github.com/janeyx99 --- torch/optim/nadam.py | 34 ++++++++++++++++---- torch/testing/_internal/common_optimizers.py | 6 +++- 2 files changed, 33 insertions(+), 7 deletions(-) diff --git a/torch/optim/nadam.py b/torch/optim/nadam.py index 75a6f49be262..fd1f8ab0e718 100644 --- a/torch/optim/nadam.py +++ b/torch/optim/nadam.py @@ -12,6 +12,7 @@ _get_capturable_supported_devices, _get_scalar_dtype, _get_value, + _maximize_doc, _stack_if_compiling, _use_grad_for_differentiable, _view_as_real, @@ -34,6 +35,7 @@ def __init__( decoupled_weight_decay: bool = False, *, foreach: Optional[bool] = None, + maximize: bool = False, capturable: bool = False, differentiable: bool = False, ): @@ -56,6 +58,7 @@ def __init__( weight_decay=weight_decay, momentum_decay=momentum_decay, decoupled_weight_decay=decoupled_weight_decay, + maximize=maximize, foreach=foreach, capturable=capturable, differentiable=differentiable, @@ -65,6 +68,7 @@ def __init__( def __setstate__(self, state): super().__setstate__(state) for group in self.param_groups: + group.setdefault("maximize", False) group.setdefault("foreach", None) group.setdefault("capturable", False) group.setdefault("differentiable", False) @@ -188,6 +192,7 @@ def step(self, closure=None): weight_decay=group["weight_decay"], momentum_decay=group["momentum_decay"], eps=group["eps"], + maximize=group["maximize"], decoupled_weight_decay=group["decoupled_weight_decay"], foreach=group["foreach"], capturable=group["capturable"], @@ -207,12 +212,15 @@ def step(self, closure=None): &\textbf{input} : \gamma_t \text{ (lr)}, \: \beta_1,\beta_2 \text{ (betas)}, \: \theta_0 \text{ (params)}, \: f(\theta) \text{ (objective)} \\ &\hspace{13mm} \: \lambda \text{ (weight decay)}, \:\psi \text{ (momentum decay)} \\ - &\hspace{13mm} \: \textit{decoupled\_weight\_decay} \\ + &\hspace{13mm} \: \textit{decoupled\_weight\_decay}, \:\textit{maximize} \\ &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)}, v_0 \leftarrow 0 \text{ ( second moment)} \\[-1.ex] &\rule{110mm}{0.4pt} \\ &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ - &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\ + &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}\textbf{else} \\ + &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} \\ &\hspace{5mm} \textbf{if} \: \lambda \neq 0 \\ &\hspace{10mm}\textbf{if} \: \textit{decoupled\_weight\_decay} \\ @@ -249,6 +257,7 @@ def step(self, closure=None): decoupled_weight_decay (bool, optional): whether to use decoupled weight decay as in AdamW to obtain NAdamW (default: False) {_foreach_doc} + {_maximize_doc} {_capturable_doc} {_differentiable_doc} @@ -276,12 +285,13 @@ def _single_tensor_nadam( momentum_decay: float, eps: float, decoupled_weight_decay: bool, + maximize: bool, capturable: bool, differentiable: bool, has_complex: bool, ): for i, param in enumerate(params): - grad = grads[i] + grad = grads[i] if not maximize else -grads[i] exp_avg = exp_avgs[i] exp_avg_sq = exp_avg_sqs[i] mu_product = mu_products[i] @@ -369,6 +379,7 @@ def _multi_tensor_nadam( momentum_decay: float, eps: float, decoupled_weight_decay: bool, + maximize: bool, capturable: bool, differentiable: bool, has_complex: bool, @@ -406,6 +417,9 @@ def _multi_tensor_nadam( grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_avg_sqs ) + if maximize: + grouped_grads = torch._foreach_neg(grouped_grads) # type: ignore[assignment] + # Update steps # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just @@ -422,9 +436,15 @@ def _multi_tensor_nadam( # Perform stepweight decay torch._foreach_mul_(grouped_params, 1 - lr * weight_decay) else: - grouped_grads = torch._foreach_add( # type: ignore[assignment] - grouped_grads, grouped_params, alpha=weight_decay - ) + # Re-use the intermediate memory (grouped_grads) already allocated for maximize + if maximize: + torch._foreach_add_( + grouped_grads, grouped_params, alpha=weight_decay + ) + else: + grouped_grads = torch._foreach_add( # type: ignore[assignment] + grouped_grads, grouped_params, alpha=weight_decay + ) # Decay the first and second moment running average coefficient torch._foreach_lerp_(grouped_exp_avgs, grouped_grads, 1 - beta1) @@ -560,6 +580,7 @@ def nadam( capturable: bool = False, differentiable: bool = False, has_complex: bool = False, + maximize: bool = False, *, beta1: float, beta2: float, @@ -608,6 +629,7 @@ def nadam( lr=lr, weight_decay=weight_decay, momentum_decay=momentum_decay, + maximize=maximize, decoupled_weight_decay=decoupled_weight_decay, eps=eps, capturable=capturable, diff --git a/torch/testing/_internal/common_optimizers.py b/torch/testing/_internal/common_optimizers.py index bca785dd9543..43d0124a6021 100644 --- a/torch/testing/_internal/common_optimizers.py +++ b/torch/testing/_internal/common_optimizers.py @@ -642,7 +642,6 @@ def optim_error_inputs_func_lbfgs(device, dtype): return error_inputs -# Weird story bro, NAdam and RAdam do not have maximize. def optim_inputs_func_nadam(device, dtype=None): cuda_supported_configs = [ OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"), @@ -695,6 +694,11 @@ def optim_inputs_func_nadam(device, dtype=None): }, desc="decoupled_weight_decay", ), + OptimizerInput( + params=None, + kwargs={"weight_decay": 0.1, "maximize": True}, + desc="maximize", + ), ] + (cuda_supported_configs if "cuda" in str(device) else []) From d44daebdbcdd1435b770a0dc1c5b9fbc1979932b Mon Sep 17 00:00:00 2001 From: cyy Date: Fri, 31 May 2024 01:20:45 +0000 Subject: [PATCH 140/706] [Submodule] Remove deprecated USE_TBB option and TBB submodule (#127051) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127051 Approved by: https://github.com/cpuhrsch, https://github.com/malfet --- .ci/pytorch/build.sh | 5 +- .ci/pytorch/test.sh | 18 - .gitmodules | 4 - BUILD.bazel | 14 +- CMakeLists.txt | 6 - WORKSPACE | 10 - aten/src/ATen/CMakeLists.txt | 10 - aten/src/ATen/Config.h.in | 1 - aten/src/ATen/Parallel.h | 2 - aten/src/ATen/ParallelCommon.cpp | 2 - aten/src/ATen/ParallelNativeTBB.cpp | 115 ------ aten/src/ATen/ParallelNativeTBB.h | 52 --- aten/src/ATen/ParallelThreadPoolNative.cpp | 2 +- aten/src/ATen/cpu/tbb/CMakeLists.txt | 391 ------------------ .../ATen/cpu/tbb/extra/version_string.ver.in | 11 - buckbuild.bzl | 4 - build_variables.bzl | 1 - caffe2/CMakeLists.txt | 19 - cmake/Dependencies.cmake | 29 -- cmake/Modules/FindMKL.cmake | 4 +- cmake/Modules/FindMKLDNN.cmake | 2 +- cmake/Summary.cmake | 4 - cmake/public/utils.cmake | 3 - defs.bzl | 2 - setup.py | 8 - third_party/mkl-dnn.BUILD | 5 +- third_party/mkl.BUILD | 5 +- third_party/tbb | 1 - third_party/tbb.BUILD | 75 ---- third_party/tbb.patch | 34 -- torch/testing/_internal/common_modules.py | 4 +- torch/testing/_internal/common_optimizers.py | 3 +- torch/testing/_internal/common_utils.py | 33 -- torch/utils/cpp_extension.py | 3 - 34 files changed, 19 insertions(+), 863 deletions(-) delete mode 100644 aten/src/ATen/ParallelNativeTBB.cpp delete mode 100644 aten/src/ATen/ParallelNativeTBB.h delete mode 100644 aten/src/ATen/cpu/tbb/CMakeLists.txt delete mode 100644 aten/src/ATen/cpu/tbb/extra/version_string.ver.in delete mode 160000 third_party/tbb delete mode 100644 third_party/tbb.BUILD delete mode 100644 third_party/tbb.patch diff --git a/.ci/pytorch/build.sh b/.ci/pytorch/build.sh index 130b770a2cc2..187e6d788bdd 100755 --- a/.ci/pytorch/build.sh +++ b/.ci/pytorch/build.sh @@ -44,10 +44,7 @@ if [[ "$BUILD_ENVIRONMENT" == *cuda11* ]]; then fi fi -if [[ ${BUILD_ENVIRONMENT} == *"paralleltbb"* ]]; then - export ATEN_THREADING=TBB - export USE_TBB=1 -elif [[ ${BUILD_ENVIRONMENT} == *"parallelnative"* ]]; then +if [[ ${BUILD_ENVIRONMENT} == *"parallelnative"* ]]; then export ATEN_THREADING=NATIVE fi diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 6a9c81fb79dc..190f99204e9c 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -693,7 +693,6 @@ test_aten() { ${SUDO} ln -sf "$TORCH_LIB_DIR"/libmkldnn* "$TEST_BASE_DIR" ${SUDO} ln -sf "$TORCH_LIB_DIR"/libnccl* "$TEST_BASE_DIR" ${SUDO} ln -sf "$TORCH_LIB_DIR"/libtorch* "$TEST_BASE_DIR" - ${SUDO} ln -sf "$TORCH_LIB_DIR"/libtbb* "$TEST_BASE_DIR" ls "$TEST_BASE_DIR" aten/tools/run_tests.sh "$TEST_BASE_DIR" @@ -718,21 +717,6 @@ test_without_numpy() { popd } -# pytorch extensions require including torch/extension.h which includes all.h -# which includes utils.h which includes Parallel.h. -# So you can call for instance parallel_for() from your extension, -# but the compilation will fail because of Parallel.h has only declarations -# and definitions are conditionally included Parallel.h(see last lines of Parallel.h). -# I tried to solve it #39612 and #39881 by including Config.h into Parallel.h -# But if Pytorch is built with TBB it provides Config.h -# that has AT_PARALLEL_NATIVE_TBB=1(see #3961 or #39881) and it means that if you include -# torch/extension.h which transitively includes Parallel.h -# which transitively includes tbb.h which is not available! -if [[ "${BUILD_ENVIRONMENT}" == *tbb* ]]; then - sudo mkdir -p /usr/include/tbb - sudo cp -r "$PWD"/third_party/tbb/include/tbb/* /usr/include/tbb -fi - test_libtorch() { local SHARD="$1" @@ -746,7 +730,6 @@ test_libtorch() { ln -sf "$TORCH_LIB_DIR"/libc10* "$TORCH_BIN_DIR" ln -sf "$TORCH_LIB_DIR"/libshm* "$TORCH_BIN_DIR" ln -sf "$TORCH_LIB_DIR"/libtorch* "$TORCH_BIN_DIR" - ln -sf "$TORCH_LIB_DIR"/libtbb* "$TORCH_BIN_DIR" ln -sf "$TORCH_LIB_DIR"/libnvfuser* "$TORCH_BIN_DIR" export CPP_TESTS_DIR="${TORCH_BIN_DIR}" @@ -883,7 +866,6 @@ test_rpc() { # test reporting process to function as expected. ln -sf "$TORCH_LIB_DIR"/libtorch* "$TORCH_BIN_DIR" ln -sf "$TORCH_LIB_DIR"/libc10* "$TORCH_BIN_DIR" - ln -sf "$TORCH_LIB_DIR"/libtbb* "$TORCH_BIN_DIR" CPP_TESTS_DIR="${TORCH_BIN_DIR}" python test/run_test.py --cpp --verbose -i cpp/test_cpp_rpc } diff --git a/.gitmodules b/.gitmodules index fe30ac3d0e5b..476f11fd945c 100644 --- a/.gitmodules +++ b/.gitmodules @@ -82,10 +82,6 @@ ignore = dirty path = third_party/foxi url = https://github.com/houseroad/foxi.git -[submodule "third_party/tbb"] - path = third_party/tbb - url = https://github.com/01org/tbb - branch = tbb_2018 [submodule "android/libs/fbjni"] ignore = dirty path = android/libs/fbjni diff --git a/BUILD.bazel b/BUILD.bazel index e61290ca2cab..9eff26e01ca9 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -125,10 +125,6 @@ filegroup( data = [":generate-code"], ) -exports_files( - srcs = ["aten/src/ATen/cpu/tbb/extra/version_string.ver.in"], -) - # ATen filegroup( name = "aten_base_cpp", @@ -275,7 +271,6 @@ header_template_rule( "@AT_BUILD_WITH_LAPACK@": "1", "@AT_PARALLEL_OPENMP@": "0", "@AT_PARALLEL_NATIVE@": "1", - "@AT_PARALLEL_NATIVE_TBB@": "0", "@AT_BLAS_F2C@": "0", "@AT_BLAS_USE_CBLAS_DOT@": "1", }, @@ -359,6 +354,9 @@ cc_library( ":aten_src_ATen_config", ] + generated_cpu_cpp + aten_ufunc_generated_cpu_sources("aten/src/ATen/{}"), copts = ATEN_COPTS, + linkopts = [ + "-ldl", + ], data = if_cuda( [":libcaffe2_nvrtc.so"], [], @@ -772,6 +770,9 @@ cc_library( ], )) + torch_sources, copts = TORCH_COPTS, + linkopts = [ + "-lrt", + ], defines = [ "CAFFE2_NIGHTLY_VERSION=20200115", ], @@ -792,6 +793,9 @@ cc_library( cc_library( name = "shm", srcs = glob(["torch/lib/libshm/*.cpp"]), + linkopts = [ + "-lrt", + ], deps = [ ":torch", ], diff --git a/CMakeLists.txt b/CMakeLists.txt index 10a92dcc7c2c..335f5750648c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -362,9 +362,6 @@ cmake_dependent_option( cmake_dependent_option( USE_TENSORPIPE "Use TensorPipe. Only available if USE_DISTRIBUTED is on." ON "USE_DISTRIBUTED" OFF) -option(USE_TBB "Use TBB (Deprecated)" OFF) -cmake_dependent_option( - USE_SYSTEM_TBB "Use system-provided Intel TBB." OFF "USE_TBB" OFF) option(ONNX_ML "Enable traditional ONNX ML API." ON) option(HAVE_SOVERSION "Whether to add SOVERSION to the shared objects" OFF) option(BUILD_LIBTORCH_CPU_WITH_DEBUG @@ -483,9 +480,6 @@ if(USE_SYSTEM_LIBS) if(USE_NCCL) set(USE_SYSTEM_NCCL ON) endif() - if(USE_TBB) - set(USE_SYSTEM_TBB ON) - endif() endif() # Used when building Caffe2 through setup.py diff --git a/WORKSPACE b/WORKSPACE index 9f32aea703dc..4169e0dbce1d 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -168,16 +168,6 @@ new_local_repository( path = "third_party/opentelemetry-cpp", ) -new_patched_local_repository( - name = "tbb", - build_file = "//third_party:tbb.BUILD", - patch_strip = 1, - patches = [ - "@//third_party:tbb.patch", - ], - path = "third_party/tbb", -) - new_local_repository( name = "cpp-httplib", build_file = "//third_party:cpp-httplib.BUILD", diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 9ec458fda45e..9fa7a1f2305b 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -349,16 +349,6 @@ endif() list(APPEND ATen_CPU_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/..) -if(USE_TBB) - if(USE_SYSTEM_TBB) - message("ATen is compiled with system-provided Intel TBB.") - else() - message("ATen is compiled with Intel TBB (${TBB_ROOT_DIR}).") - endif() - list(APPEND ATen_CPU_INCLUDE ${TBB_INCLUDE_DIR}) - list(APPEND ATen_CPU_DEPENDENCY_LIBS TBB::tbb) -endif() - if(BLAS_FOUND) if($ENV{TH_BINARY_BUILD}) message(STATUS "TH_BINARY_BUILD detected. Enabling special linkage.") diff --git a/aten/src/ATen/Config.h.in b/aten/src/ATen/Config.h.in index 93b8e0434f1a..fdd2ac2bc5f7 100644 --- a/aten/src/ATen/Config.h.in +++ b/aten/src/ATen/Config.h.in @@ -17,6 +17,5 @@ #define AT_BUILD_WITH_LAPACK() @AT_BUILD_WITH_LAPACK@ #define AT_PARALLEL_OPENMP @AT_PARALLEL_OPENMP@ #define AT_PARALLEL_NATIVE @AT_PARALLEL_NATIVE@ -#define AT_PARALLEL_NATIVE_TBB @AT_PARALLEL_NATIVE_TBB@ #define AT_BLAS_F2C() @AT_BLAS_F2C@ #define AT_BLAS_USE_CBLAS_DOT() @AT_BLAS_USE_CBLAS_DOT@ diff --git a/aten/src/ATen/Parallel.h b/aten/src/ATen/Parallel.h index ff14f568d22a..966e29c0289f 100644 --- a/aten/src/ATen/Parallel.h +++ b/aten/src/ATen/Parallel.h @@ -153,8 +153,6 @@ TORCH_API int intraop_default_num_threads(); #include // IWYU pragma: keep #elif AT_PARALLEL_NATIVE #include // IWYU pragma: keep -#elif AT_PARALLEL_NATIVE_TBB -#include // IWYU pragma: keep #endif #include // IWYU pragma: keep diff --git a/aten/src/ATen/ParallelCommon.cpp b/aten/src/ATen/ParallelCommon.cpp index 0504a066eef5..e5d9bb83c016 100644 --- a/aten/src/ATen/ParallelCommon.cpp +++ b/aten/src/ATen/ParallelCommon.cpp @@ -80,8 +80,6 @@ std::string get_parallel_info() { ss << "OpenMP"; #elif AT_PARALLEL_NATIVE ss << "native thread pool"; - #elif AT_PARALLEL_NATIVE_TBB - ss << "native thread pool and TBB"; #endif #ifdef C10_MOBILE ss << " [mobile]"; diff --git a/aten/src/ATen/ParallelNativeTBB.cpp b/aten/src/ATen/ParallelNativeTBB.cpp deleted file mode 100644 index 06be418f7d9c..000000000000 --- a/aten/src/ATen/ParallelNativeTBB.cpp +++ /dev/null @@ -1,115 +0,0 @@ -#include -#if AT_PARALLEL_NATIVE_TBB -#include -#include -#include - -#include -#include - -#include -#define TBB_PREVIEW_GLOBAL_CONTROL 1 -#include - -#ifdef _OPENMP -#include -#endif - -#if AT_MKL_ENABLED() -#include -#endif - -namespace at { - -namespace { -static thread_local tbb::task_group tg_; -thread_local int this_thread_id{0}; - -std::mutex global_thread_mutex_; -std::shared_ptr global_thread_limit_ = nullptr; -std::atomic num_intraop_threads_{-1}; - -void _internal_set_num_threads(int nthreads) { - TORCH_INTERNAL_ASSERT(nthreads > 0); - { - std::unique_lock lk(global_thread_mutex_); - // This is an antipattern and we shouldn't be constraining the number of - // threads in library code. - // TODO: Think of a smarter way to leverage tbb::thread_arena to limit the - // number of slots instead of the number of threads. - global_thread_limit_ = std::make_shared( - tbb::global_control::max_allowed_parallelism, nthreads); - num_intraop_threads_.store(nthreads); - } -} -} - -void init_num_threads() { - #ifdef _OPENMP - omp_set_num_threads(1); - #endif - - #if AT_MKL_ENABLED() - mkl_set_num_threads(1); - #endif - - int nthreads = num_intraop_threads_.load(); - if (nthreads < 0) { - nthreads = intraop_default_num_threads(); - } - _internal_set_num_threads(nthreads); -} - -void set_num_threads(int nthreads) { - TORCH_CHECK(nthreads > 0); - - _internal_set_num_threads(nthreads); -} - -int get_num_threads() { - at::internal::lazy_init_num_threads(); - return tbb::global_control::active_value( - tbb::global_control::max_allowed_parallelism); -} - -int get_thread_num() { - return this_thread_id; -} - -namespace internal { -void set_thread_num(int id) { - this_thread_id = id; -} -} - -bool in_parallel_region() { - return tbb::this_task_arena::current_thread_index() >= 0; -} - -void intraop_launch(std::function func) { - if (get_num_threads() > 1) { - tg_.run(func); - } else { - func(); - } -} - -c10::intrusive_ptr intraop_launch_future( - std::function func) { - auto future = c10::make_intrusive(NoneType::get()); - if (get_num_threads() > 1) { - tg_.run( - [func, future]() { - func(); - future->markCompleted(); - } - ); - } else { - func(); - future->markCompleted(); - } - return future; -} - -} // namespace at -#endif diff --git a/aten/src/ATen/ParallelNativeTBB.h b/aten/src/ATen/ParallelNativeTBB.h deleted file mode 100644 index 9193e06ed695..000000000000 --- a/aten/src/ATen/ParallelNativeTBB.h +++ /dev/null @@ -1,52 +0,0 @@ -#pragma once - -#include -#include -#include - -#include - -#ifdef _WIN32 -#ifndef WIN32_LEAN_AND_MEAN -#define WIN32_LEAN_AND_MEAN -#endif -#endif -#include - -#define INTRA_OP_PARALLEL - -namespace at::internal { - -template -inline void invoke_parallel( - const int64_t begin, - const int64_t end, - const int64_t grain_size, - const F& f) { - // Choose number of tasks based on grain size and number of threads. - int64_t chunk_size = divup((end - begin), get_num_threads()); - // Make sure each task is at least grain_size size. - chunk_size = std::max(grain_size, chunk_size); - - std::atomic_flag err_flag = ATOMIC_FLAG_INIT; - std::exception_ptr eptr; - tbb::parallel_for( - tbb::blocked_range(begin, end, chunk_size), - [&eptr, &err_flag, f](const tbb::blocked_range& r) { - try { - internal::ThreadIdGuard tid_guard( - tbb::this_task_arena::current_thread_index()); - f(r.begin(), r.end()); - } catch (...) { - if (!err_flag.test_and_set()) { - eptr = std::current_exception(); - } - } - }, - tbb::static_partitioner{}); - if (eptr) { - std::rethrow_exception(eptr); - } -} - -} // namespace at::internal diff --git a/aten/src/ATen/ParallelThreadPoolNative.cpp b/aten/src/ATen/ParallelThreadPoolNative.cpp index a9d5095f32a9..3ea51bb5e683 100644 --- a/aten/src/ATen/ParallelThreadPoolNative.cpp +++ b/aten/src/ATen/ParallelThreadPoolNative.cpp @@ -1,5 +1,5 @@ #include -#if AT_PARALLEL_OPENMP || AT_PARALLEL_NATIVE || AT_PARALLEL_NATIVE_TBB +#if AT_PARALLEL_OPENMP || AT_PARALLEL_NATIVE #include #include #include diff --git a/aten/src/ATen/cpu/tbb/CMakeLists.txt b/aten/src/ATen/cpu/tbb/CMakeLists.txt deleted file mode 100644 index 6e946d5d13d5..000000000000 --- a/aten/src/ATen/cpu/tbb/CMakeLists.txt +++ /dev/null @@ -1,391 +0,0 @@ -# Based on https://github.com/wjakob/tbb/blob/master/CMakeLists.txt -# All credit goes to Wenzel Jakob! - -cmake_minimum_required(VERSION 2.8.12 FATAL_ERROR) -project(tbb CXX) - -include(CheckCXXCompilerFlag) -include(CheckCXXSourceRuns) - -if(POLICY CMP0058) - cmake_policy(SET CMP0058 NEW) -endif() - -if(NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES) - message(STATUS "Setting build type to 'Release' as none was specified.") - set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose the type of build." FORCE) - set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" - "MinSizeRel" "RelWithDebInfo") -endif() - -if(NOT TBB_ROOT_DIR) - set(TBB_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}") -endif() -if(NOT TBB_INSTALL_EXPORT_NAME) - set(TBB_INSTALL_EXPORT_NAME "Caffe2Targets") -endif() -if(NOT TBB_INSTALL_EXPORT_DESTINATION) - set(TBB_INSTALL_EXPORT_DESTINATION lib) -endif() -if(NOT TBB_INSTALL_RUNTIME_DIR) - set(TBB_INSTALL_RUNTIME_DIR bin) -endif() -if(NOT TBB_INSTALL_LIBRARY_DIR) - set(TBB_INSTALL_LIBRARY_DIR lib) -endif() -if(NOT TBB_INSTALL_ARCHIVE_DIR) - set(TBB_INSTALL_ARCHIVE_DIR lib) -endif() -if(NOT TBB_INSTALL_INCLUDE_DIR) - set(TBB_INSTALL_INCLUDE_DIR "${TBB_ROOT_DIR}/include") -endif() - -set(TBB_INCLUDES - "${TBB_ROOT_DIR}/include" - "${TBB_ROOT_DIR}/src" - "${TBB_ROOT_DIR}/src/rml/include" - ${CMAKE_CURRENT_BINARY_DIR}) - -option(TBB_BUILD_SHARED "Build TBB shared library" ON) -option(TBB_BUILD_STATIC "Build TBB static library" ON) -option(TBB_BUILD_TBBMALLOC "Build TBB malloc library" ON) -option(TBB_BUILD_TBBMALLOC_PROXY "Build TBB malloc proxy library" ON) -option(TBB_BUILD_TESTS "Build TBB tests and enable testing infrastructure" ON) -option(TBB_CI_BUILD "Is this a continuous integration build?" OFF) - -if(APPLE) - set(CMAKE_MACOSX_RPATH ON) -endif() - -file(GLOB tbb_src "${TBB_ROOT_DIR}/src/tbb/*.cpp" "${TBB_ROOT_DIR}/src/old/*.cpp") -list(APPEND tbb_src ${TBB_ROOT_DIR}/src/rml/client/rml_tbb.cpp) -file(GLOB to_remove "${TBB_ROOT_DIR}/src/old/test*.cpp") -if(NOT "${to_remove}" STREQUAL "") - list(REMOVE_ITEM tbb_src ${to_remove}) -endif() - -set(tbbmalloc_static_src - src/tbbmalloc/backend.cpp - src/tbbmalloc/large_objects.cpp - src/tbbmalloc/backref.cpp - src/tbbmalloc/tbbmalloc.cpp - src/tbbmalloc/frontend.cpp - src/tbb/itt_notify.cpp) - -set(tbbmalloc_src ${tbbmalloc_static_src}) - -set(tbbmalloc_proxy_src - src/tbbmalloc/proxy.cpp - src/tbbmalloc/tbb_function_replacement.cpp) - -if(CMAKE_SYSTEM_PROCESSOR MATCHES "(i386|x86_64)") - if(NOT APPLE AND NOT MINGW) - add_definitions(-DDO_ITT_NOTIFY) - endif() -endif() - -if(APPLE) - # Disable annoying "has no symbols" warnings - set(CMAKE_C_ARCHIVE_CREATE " Scr ") - set(CMAKE_CXX_ARCHIVE_CREATE " Scr ") - set(CMAKE_C_ARCHIVE_FINISH " -no_warning_for_no_symbols -c ") - set(CMAKE_CXX_ARCHIVE_FINISH " -no_warning_for_no_symbols -c ") -endif() - -macro(CHECK_CXX_COMPILER_AND_LINKER_FLAGS _RESULT _CXX_FLAGS _LINKER_FLAGS) - set(CMAKE_REQUIRED_FLAGS ${_CXX_FLAGS}) - set(CMAKE_REQUIRED_LIBRARIES ${_LINKER_FLAGS}) - set(CMAKE_REQUIRED_QUIET TRUE) - check_cxx_source_runs("#include \nint main(int argc, char **argv) { std::cout << \"test\"; return 0; }" ${_RESULT}) - set(CMAKE_REQUIRED_FLAGS "") - set(CMAKE_REQUIRED_LIBRARIES "") -endmacro() - -# Prefer libc++ in conjunction with Clang -if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") - if(CMAKE_CXX_FLAGS MATCHES "-stdlib=libc\\+\\+") - message(STATUS "TBB: using libc++.") - else() - CHECK_CXX_COMPILER_AND_LINKER_FLAGS(HAS_LIBCPP "-stdlib=libc++" "-stdlib=libc++") - if(HAS_LIBCPP) - string(APPEND CMAKE_CXX_FLAGS " -stdlib=libc++ -D_LIBCPP_VERSION") - string(APPEND CMAKE_EXE_LINKER_FLAGS " -stdlib=libc++") - string(APPEND CMAKE_SHARED_LINKER_FLAGS " -stdlib=libc++") - message(STATUS "TBB: using libc++.") - else() - message(STATUS "TBB: NOT using libc++.") - endif() - endif() -endif() - -if(UNIX) - add_definitions(-DUSE_PTHREAD) - - check_cxx_compiler_flag("-std=c++17" SUPPORTS_STDCXX17) - if(SUPPORTS_STDCXX17) - set(CMAKE_CXX_FLAGS "-std=c++17 ${CMAKE_CXX_FLAGS}") - endif() - - check_cxx_compiler_flag("-mrtm -Werror" SUPPORTS_MRTM) - if(SUPPORTS_MRTM) - set(CMAKE_CXX_FLAGS "-mrtm ${CMAKE_CXX_FLAGS}") - endif() - -elseif(WIN32) - if(MSVC) - cmake_minimum_required(VERSION 3.1) - enable_language(ASM_MASM) - set(CMAKE_CXX_FLAGS "/GS- /Zc:wchar_t /Zc:forScope /DUSE_WINTHREAD ${CMAKE_CXX_FLAGS}") - set(CMAKE_CXX_FLAGS "/D_CRT_SECURE_NO_DEPRECATE /D_WIN32_WINNT=0x0600 ${CMAKE_CXX_FLAGS}") - check_cxx_compiler_flag("/volatile:iso" SUPPORTS_VOLATILE_FLAG) - if(SUPPORTS_VOLATILE_FLAG) - set(CMAKE_CXX_FLAGS "/volatile:iso ${CMAKE_CXX_FLAGS}") - endif() - set(CMAKE_CXX_FLAGS "/wd4267 /wd4800 /wd4146 /wd4244 /wd4577 /wd4018 ${CMAKE_CXX_FLAGS}") - if(NOT CMAKE_SIZEOF_VOID_P) - message(FATAL_ERROR "'CMAKE_SIZEOF_VOID_P' is undefined. Please delete your build directory and rerun CMake again!") - endif() - - if(CMAKE_SIZEOF_VOID_P EQUAL 8) - list(APPEND tbb_src "${TBB_ROOT_DIR}/src/tbb/intel64-masm/atomic_support.asm") - list(APPEND tbb_src "${TBB_ROOT_DIR}/src/tbb/intel64-masm/itsx.asm") - list(APPEND tbb_src "${TBB_ROOT_DIR}/src/tbb/intel64-masm/intel64_misc.asm") - list(APPEND tbbmalloc_src "${TBB_ROOT_DIR}/src/tbb/intel64-masm/atomic_support.asm") - set(CMAKE_ASM_MASM_FLAGS "/DEM64T=1 ${CMAKE_ASM_MASM_FLAGS}") - else() - list(APPEND tbb_src "${TBB_ROOT_DIR}/src/tbb/ia32-masm/atomic_support.asm" - "${TBB_ROOT_DIR}/src/tbb/ia32-masm/itsx.asm src/tbb/ia32-masm/lock_byte.asm") - # Enable SAFESEH feature for assembly (x86 builds only). - set(CMAKE_ASM_MASM_FLAGS "/safeseh ${CMAKE_ASM_MASM_FLAGS}") - endif() - elseif(MINGW) - add_definitions(-DUSE_WINTHREAD) - add_definitions(-D_WIN32_WINNT=0x0502) - set(CMAKE_CXX_FLAGS "-mthreads ${CMAKE_CXX_FLAGS}") - endif() -endif() - -if(MSVC) - set(ENABLE_RTTI "/EHsc /GR ") - set(DISABLE_RTTI "/EHs- /GR- ") -elseif(UNIX) - set(ENABLE_RTTI "-frtti -fexceptions ") - set(DISABLE_RTTI "-fno-rtti -fno-exceptions ") -endif() - -##-------- -# - Added TBB_USE_GLIBCXX_VERSION macro to specify the version of GNU -# libstdc++ when it cannot be properly recognized, e.g. when used -# with Clang on Linux* OS. Inspired by a contribution from David A. -if(NOT TBB_USE_GLIBCXX_VERSION AND UNIX AND NOT APPLE) - if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") - # using Clang - string(REPLACE "." "0" TBB_USE_GLIBCXX_VERSION ${CMAKE_CXX_COMPILER_VERSION}) - endif() -endif() - -if(TBB_USE_GLIBCXX_VERSION) - add_definitions(-DTBB_USE_GLIBCXX_VERSION=${TBB_USE_GLIBCXX_VERSION}) -endif() - -##------- - -if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") - check_cxx_compiler_flag("-flifetime-dse=1" SUPPORTS_FLIFETIME) - if(SUPPORTS_FLIFETIME) - add_definitions(-flifetime-dse=1) - endif() -endif() - -# Linker export definitions -if(APPLE) - set(ARCH_PREFIX "mac") -elseif(WIN32) - set(ARCH_PREFIX "win") -else() - set(ARCH_PREFIX "lin") -endif() - -if(CMAKE_SIZEOF_VOID_P EQUAL 8) - set(ARCH_PREFIX "${ARCH_PREFIX}64") -else() - set(ARCH_PREFIX "${ARCH_PREFIX}32") -endif() - -if(MINGW) - set(ARCH_PREFIX "${ARCH_PREFIX}-gcc") - # there's no win32-gcc-tbb-export.def, use lin32-tbb-export.def - execute_process(COMMAND ${CMAKE_COMMAND} -E copy ${TBB_ROOT_DIR}/src/tbb/lin32-tbb-export.def ${TBB_ROOT_DIR}/src/tbb/win32-gcc-tbb-export.def) -endif() - -if(MSVC) - add_custom_command(OUTPUT tbb.def - COMMAND ${CMAKE_CXX_COMPILER} /TC /EP ${TBB_ROOT_DIR}/src/tbb/${ARCH_PREFIX}-tbb-export.def -I ${TBB_ROOT_DIR}/include > tbb.def - MAIN_DEPENDENCY ${TBB_ROOT_DIR}/src/tbb/${ARCH_PREFIX}-tbb-export.def - COMMENT "Preprocessing tbb.def" - ) - - add_custom_command(OUTPUT tbbmalloc.def - COMMAND ${CMAKE_CXX_COMPILER} /TC /EP ${TBB_ROOT_DIR}/src/tbbmalloc/${ARCH_PREFIX}-tbbmalloc-export.def -I ${TBB_ROOT_DIR}/include > tbbmalloc.def - MAIN_DEPENDENCY ${TBB_ROOT_DIR}/src/tbbmalloc/${ARCH_PREFIX}-tbbmalloc-export.def - COMMENT "Preprocessing tbbmalloc.def" - ) -else() - add_custom_command(OUTPUT tbb.def - COMMAND ${CMAKE_CXX_COMPILER} -xc++ -E ${TBB_ROOT_DIR}/src/tbb/${ARCH_PREFIX}-tbb-export.def -I ${TBB_ROOT_DIR}/include -o tbb.def - MAIN_DEPENDENCY ${TBB_ROOT_DIR}/src/tbb/${ARCH_PREFIX}-tbb-export.def - COMMENT "Preprocessing tbb.def" - ) - - add_custom_command(OUTPUT tbbmalloc.def - COMMAND ${CMAKE_CXX_COMPILER} -xc++ -E ${TBB_ROOT_DIR}/src/tbbmalloc/${ARCH_PREFIX}-tbbmalloc-export.def -I ${TBB_ROOT_DIR}/include -o tbbmalloc.def - MAIN_DEPENDENCY ${TBB_ROOT_DIR}/src/tbbmalloc/${ARCH_PREFIX}-tbbmalloc-export.def - COMMENT "Preprocessing tbbmalloc.def" - ) -endif() - -add_custom_target(tbb_def_files DEPENDS tbb.def tbbmalloc.def) - -# TBB library -if(TBB_BUILD_STATIC) - add_library(tbb_static STATIC ${tbb_src}) - target_include_directories(tbb_static PRIVATE ${TBB_INCLUDES}) - set_property(TARGET tbb_static APPEND PROPERTY COMPILE_DEFINITIONS "__TBB_BUILD=1") - set_property(TARGET tbb_static APPEND_STRING PROPERTY COMPILE_FLAGS ${ENABLE_RTTI}) - install(TARGETS tbb_static - EXPORT ${TBB_INSTALL_EXPORT_NAME} DESTINATION ${TBB_INSTALL_EXPORT_DESTINATION} - ARCHIVE DESTINATION ${TBB_INSTALL_ARCHIVE_DIR}) - if(MSVC) - target_compile_definitions(tbb_static PUBLIC __TBB_NO_IMPLICIT_LINKAGE=1) - endif() - - if(UNIX AND NOT APPLE) - target_link_libraries(tbb_static PUBLIC pthread dl) - endif() -endif() - -if(TBB_BUILD_SHARED) - add_library(tbb SHARED ${tbb_src}) - target_include_directories(tbb PRIVATE ${TBB_INCLUDES}) - set_property(TARGET tbb APPEND PROPERTY COMPILE_DEFINITIONS "__TBB_BUILD=1") - set_property(TARGET tbb APPEND_STRING PROPERTY COMPILE_FLAGS ${ENABLE_RTTI}) - add_dependencies(tbb tbb_def_files) - - if(APPLE) - set_property(TARGET tbb APPEND PROPERTY LINK_FLAGS "-Wl,-exported_symbols_list,\"${CMAKE_CURRENT_BINARY_DIR}/tbb.def\"") - elseif(MSVC) - set_property(TARGET tbb APPEND PROPERTY LINK_FLAGS "/DEF:\"${CMAKE_CURRENT_BINARY_DIR}/tbb.def\"") - else() - set_property(TARGET tbb APPEND PROPERTY LINK_FLAGS "-Wl,-version-script,\"${CMAKE_CURRENT_BINARY_DIR}/tbb.def\"") - endif() - - install(TARGETS tbb - EXPORT ${TBB_INSTALL_EXPORT_NAME} DESTINATION ${TBB_INSTALL_EXPORT_DESTINATION} - LIBRARY DESTINATION ${TBB_INSTALL_LIBRARY_DIR} - ARCHIVE DESTINATION ${TBB_INSTALL_ARCHIVE_DIR} - RUNTIME DESTINATION ${TBB_INSTALL_RUNTIME_DIR}) - if(UNIX AND NOT APPLE) - target_link_libraries(tbb PUBLIC pthread dl) - endif() - if(MSVC) - target_compile_definitions(tbb PUBLIC __TBB_NO_IMPLICIT_LINKAGE=1) - endif() -endif() - - -if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") - # Quench a warning on GCC - set_source_files_properties(${TBB_ROOT_DIR}/src/tbb/governor.cpp COMPILE_FLAGS "-Wno-missing-field-initializers ") -elseif("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") - # Quench a warning on Clang - set_source_files_properties(${TBB_ROOT_DIR}/src/tbb/itt_notify.cpp COMPILE_FLAGS "-Wno-varargs ") -elseif(MSVC) - # Quench a warning on MSVC - set_source_files_properties(${TBB_ROOT_DIR}/src/tbb/scheduler.cpp COMPILE_FLAGS "/wd4458 ") -endif() - -if(TBB_BUILD_TBBMALLOC) - # TBB malloc library - if(TBB_BUILD_STATIC) - add_library(tbbmalloc_static STATIC ${tbbmalloc_static_src}) - target_include_directories(tbbmalloc_static PRIVATE ${TBB_INCLUDES}) - set_property(TARGET tbbmalloc_static APPEND PROPERTY COMPILE_DEFINITIONS "__TBBMALLOC_BUILD=1") - set_property(TARGET tbbmalloc_static APPEND_STRING PROPERTY COMPILE_FLAGS ${DISABLE_RTTI}) - if(MSVC) - target_compile_definitions(tbbmalloc_static PUBLIC __TBB_NO_IMPLICIT_LINKAGE=1 __TBBMALLOC_NO_IMPLICIT_LINKAGE=1) - endif() - install(TARGETS tbbmalloc_static - EXPORT ${TBB_INSTALL_EXPORT_NAME} DESTINATION ${TBB_INSTALL_EXPORT_DESTINATION} - ARCHIVE DESTINATION ${TBB_INSTALL_ARCHIVE_DIR}) - endif() - - if(TBB_BUILD_SHARED) - add_library(tbbmalloc SHARED ${tbbmalloc_src}) - target_include_directories(tbbmalloc PRIVATE ${TBB_INCLUDES}) - set_property(TARGET tbbmalloc APPEND PROPERTY COMPILE_DEFINITIONS "__TBBMALLOC_BUILD=1") - set_property(TARGET tbbmalloc APPEND_STRING PROPERTY COMPILE_FLAGS ${DISABLE_RTTI}) - add_dependencies(tbbmalloc tbb_def_files) - if(APPLE) - set_property(TARGET tbbmalloc APPEND PROPERTY LINK_FLAGS "-Wl,-exported_symbols_list,\"${CMAKE_CURRENT_BINARY_DIR}/tbbmalloc.def\"") - elseif(MSVC) - set_property(TARGET tbbmalloc APPEND PROPERTY LINK_FLAGS "/DEF:\"${CMAKE_CURRENT_BINARY_DIR}/tbbmalloc.def\"") - else() - set_property(TARGET tbbmalloc APPEND PROPERTY LINK_FLAGS "-Wl,-version-script,\"${CMAKE_CURRENT_BINARY_DIR}/tbbmalloc.def\"") - endif() - if(MSVC) - target_compile_definitions(tbbmalloc PUBLIC __TBB_NO_IMPLICIT_LINKAGE=1 __TBBMALLOC_NO_IMPLICIT_LINKAGE=1) - endif() - install(TARGETS tbbmalloc - EXPORT ${TBB_INSTALL_EXPORT_NAME} DESTINATION ${TBB_INSTALL_EXPORT_DESTINATION} - LIBRARY DESTINATION ${TBB_INSTALL_LIBRARY_DIR} - ARCHIVE DESTINATION ${TBB_INSTALL_ARCHIVE_DIR} - RUNTIME DESTINATION ${TBB_INSTALL_RUNTIME_DIR}) - if(UNIX AND NOT APPLE) - target_link_libraries(tbbmalloc PUBLIC pthread dl) - endif() - endif() -endif() - -if(TBB_BUILD_TBBMALLOC_PROXY) - # TBB malloc proxy library - if(TBB_BUILD_STATIC) - add_library(tbbmalloc_proxy_static STATIC ${tbbmalloc_proxy_src}) - set_property(TARGET tbbmalloc_proxy_static APPEND PROPERTY COMPILE_DEFINITIONS "__TBBMALLOC_BUILD=1") - set_property(TARGET tbbmalloc_proxy_static APPEND_STRING PROPERTY COMPILE_FLAGS ${DISABLE_RTTI}) - install(TARGETS tbbmalloc_proxy_static - EXPORT ${TBB_INSTALL_EXPORT_NAME} DESTINATION ${TBB_INSTALL_EXPORT_DESTINATION} - ARCHIVE DESTINATION ${TBB_INSTALL_ARCHIVE_DIR}) - endif() - - if(TBB_BUILD_SHARED) - add_library(tbbmalloc_proxy SHARED ${tbbmalloc_proxy_src}) - set_property(TARGET tbbmalloc_proxy APPEND PROPERTY COMPILE_DEFINITIONS "__TBBMALLOC_BUILD=1") - set_property(TARGET tbbmalloc_proxy APPEND_STRING PROPERTY COMPILE_FLAGS ${DISABLE_RTTI}) - target_link_libraries(tbbmalloc_proxy PUBLIC tbbmalloc) - install(TARGETS tbbmalloc_proxy - EXPORT ${TBB_INSTALL_EXPORT_NAME} DESTINATION ${TBB_INSTALL_EXPORT_DESTINATION} - LIBRARY DESTINATION ${TBB_INSTALL_LIBRARY_DIR} - ARCHIVE DESTINATION ${TBB_INSTALL_ARCHIVE_DIR} - RUNTIME DESTINATION ${TBB_INSTALL_RUNTIME_DIR}) - if(UNIX AND NOT APPLE) - target_link_libraries(tbbmalloc_proxy PUBLIC pthread dl) - endif() - endif() -endif() - -install(DIRECTORY "${TBB_ROOT_DIR}/include/tbb" DESTINATION ${TBB_INSTALL_INCLUDE_DIR}) - -# version_string.ver -if(UNIX) - execute_process(COMMAND date "+%a, %d %b %Y %H:%M:%S %z" - OUTPUT_VARIABLE _configure_date - OUTPUT_STRIP_TRAILING_WHITESPACE) -elseif(WIN32) - execute_process(COMMAND cmd " /C date /T" - OUTPUT_VARIABLE _configure_date - OUTPUT_STRIP_TRAILING_WHITESPACE) -else() - set(_configure_date "Unknown") -endif() -include_directories(${CMAKE_BINARY_DIR}) -configure_file(extra/version_string.ver.in version_string.ver @ONLY) diff --git a/aten/src/ATen/cpu/tbb/extra/version_string.ver.in b/aten/src/ATen/cpu/tbb/extra/version_string.ver.in deleted file mode 100644 index bb9f96e8f295..000000000000 --- a/aten/src/ATen/cpu/tbb/extra/version_string.ver.in +++ /dev/null @@ -1,11 +0,0 @@ -#define __TBB_VERSION_STRINGS(N) \ -#N": BUILD_HOST @CMAKE_SYSTEM_NAME@" ENDL \ -#N": BUILD_OS @CMAKE_SYSTEM@" ENDL \ -#N": BUILD_KERNEL @CMAKE_SYSTEM_VERSION@" ENDL \ -#N": BUILD_GCC @CMAKE_CXX_COMPILER_ID@" ENDL \ -#N": BUILD_LIBC Unknown" ENDL \ -#N": BUILD_LD Unknown" ENDL \ -#N": BUILD_TARGET Unknown" ENDL \ -#N": BUILD_COMMAND Unknown" ENDL - -#define __TBB_DATETIME "@_configure_date@" diff --git a/buckbuild.bzl b/buckbuild.bzl index 4c4fc9a89a28..649ebe668365 100644 --- a/buckbuild.bzl +++ b/buckbuild.bzl @@ -261,7 +261,6 @@ def get_aten_preprocessor_flags(): "-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION", "-DAT_PARALLEL_OPENMP_FBXPLAT=0", "-DAT_PARALLEL_NATIVE_FBXPLAT=1", - "-DAT_PARALLEL_NATIVE_TBB_FBXPLAT=0", "-DUSE_LAPACK_FBXPLAT=0", "-DAT_BLAS_F2C_FBXPLAT=0", "-DAT_BLAS_USE_CBLAS_DOT_FBXPLAT=0", @@ -1112,9 +1111,6 @@ def define_buck_targets( "@AT_PARALLEL_NATIVE@", "AT_PARALLEL_NATIVE_FBXPLAT", "--replace", - "@AT_PARALLEL_NATIVE_TBB@", - "AT_PARALLEL_NATIVE_TBB_FBXPLAT", - "--replace", "@AT_BUILD_WITH_LAPACK@", "USE_LAPACK_FBXPLAT", "--replace", diff --git a/build_variables.bzl b/build_variables.bzl index 26986506ec8b..20822ba95cf2 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -1001,7 +1001,6 @@ aten_cpu_source_non_codegen_list = [ "aten/src/ATen/NestedTensorImpl.cpp", "aten/src/ATen/ParallelCommon.cpp", "aten/src/ATen/ParallelNative.cpp", - "aten/src/ATen/ParallelNativeTBB.cpp", "aten/src/ATen/ParallelOpenMP.cpp", "aten/src/ATen/ParallelThreadPoolNative.cpp", "aten/src/ATen/PythonTorchFunctionTLS.cpp", diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 54a31185c127..02823729d71a 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -16,14 +16,11 @@ endif() # ATen parallelism settings # OMP - OpenMP for intra-op, native thread pool for inter-op parallelism # NATIVE - using native thread pool for intra- and inter-op parallelism -# TBB - using TBB for intra- and native thread pool for inter-op parallelism if(INTERN_BUILD_MOBILE) set(ATEN_THREADING "NATIVE" CACHE STRING "ATen parallel backend") else() if(USE_OPENMP) set(ATEN_THREADING "OMP" CACHE STRING "ATen parallel backend") - elseif(USE_TBB) - set(ATEN_THREADING "TBB" CACHE STRING "ATen parallel backend") else() set(ATEN_THREADING "NATIVE" CACHE STRING "ATen parallel backend") endif() @@ -31,19 +28,12 @@ endif() set(AT_PARALLEL_OPENMP 0) set(AT_PARALLEL_NATIVE 0) -set(AT_PARALLEL_NATIVE_TBB 0) message(STATUS "Using ATen parallel backend: ${ATEN_THREADING}") if("${ATEN_THREADING}" STREQUAL "OMP") set(AT_PARALLEL_OPENMP 1) elseif("${ATEN_THREADING}" STREQUAL "NATIVE") set(AT_PARALLEL_NATIVE 1) -elseif("${ATEN_THREADING}" STREQUAL "TBB") - if(NOT USE_TBB) - message(FATAL_ERROR "Using TBB backend but USE_TBB is off") - endif() - message(WARNING "ATEN TBB Threading is deprectated.") - set(AT_PARALLEL_NATIVE_TBB 1) else() message(FATAL_ERROR "Unknown ATen parallel backend: ${ATEN_THREADING}") endif() @@ -1223,11 +1213,6 @@ if(CMAKE_CXX_COMPILER_ID MATCHES "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "GNU" set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/../aten/src/ATen/native/quantized/qlinear_unpack.cpp PROPERTIES COMPILE_FLAGS -Wno-deprecated-declarations) endif() -if(USE_TBB) - list(APPEND ATen_CPU_INCLUDE ${TBB_INCLUDE_DIR}) - target_link_libraries(torch_cpu PUBLIC TBB::tbb) -endif() - target_include_directories(torch_cpu PRIVATE ${ATen_CPU_INCLUDE}) target_include_directories(torch_cpu PRIVATE @@ -1708,10 +1693,6 @@ if(BUILD_SHARED_LIBS) target_link_libraries(torch_global_deps ${Caffe2_PUBLIC_CUDA_DEPENDENCY_LIBS}) target_link_libraries(torch_global_deps torch::cudart torch::nvtoolsext) endif() - if(USE_TBB) - target_link_libraries(torch_global_deps TBB::tbb) - endif() - install(TARGETS torch_global_deps DESTINATION "${TORCH_INSTALL_LIB_DIR}") endif() diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index e9fd67018da7..f953146605ed 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -134,35 +134,6 @@ else() "Cannot find threading library. PyTorch requires Threads to compile.") endif() -if(USE_TBB) - if(USE_SYSTEM_TBB) - find_package(TBB 2018.0 REQUIRED CONFIG COMPONENTS tbb) - - get_target_property(TBB_INCLUDE_DIR TBB::tbb INTERFACE_INCLUDE_DIRECTORIES) - else() - message(STATUS "Compiling TBB from source") - # Unset our restrictive C++ flags here and reset them later. - # Remove this once we use proper target_compile_options. - set(OLD_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS}) - set(CMAKE_CXX_FLAGS) - - set(TBB_ROOT_DIR "${PROJECT_SOURCE_DIR}/third_party/tbb") - set(TBB_BUILD_STATIC OFF CACHE BOOL " " FORCE) - set(TBB_BUILD_SHARED ON CACHE BOOL " " FORCE) - set(TBB_BUILD_TBBMALLOC OFF CACHE BOOL " " FORCE) - set(TBB_BUILD_TBBMALLOC_PROXY OFF CACHE BOOL " " FORCE) - set(TBB_BUILD_TESTS OFF CACHE BOOL " " FORCE) - add_subdirectory(${PROJECT_SOURCE_DIR}/aten/src/ATen/cpu/tbb) - set_property(TARGET tbb tbb_def_files PROPERTY FOLDER "dependencies") - - set(CMAKE_CXX_FLAGS ${OLD_CMAKE_CXX_FLAGS}) - - set(TBB_INCLUDE_DIR "${TBB_ROOT_DIR}/include") - - add_library(TBB::tbb ALIAS tbb) - endif() -endif() - # ---[ protobuf if(CAFFE2_CMAKE_BUILDING_WITH_MAIN_REPO) if(USE_LITE_PROTO) diff --git a/cmake/Modules/FindMKL.cmake b/cmake/Modules/FindMKL.cmake index 01de7c7cec15..daaa5dd24f00 100644 --- a/cmake/Modules/FindMKL.cmake +++ b/cmake/Modules/FindMKL.cmake @@ -71,8 +71,8 @@ IF (NOT "${MKL_THREADING}" STREQUAL "SEQ" AND MESSAGE(FATAL_ERROR "Invalid MKL_THREADING (${MKL_THREADING}), should be one of: SEQ, TBB, OMP") ENDIF() -IF ("${MKL_THREADING}" STREQUAL "TBB" AND NOT USE_TBB) - MESSAGE(FATAL_ERROR "MKL_THREADING is TBB but USE_TBB is turned off") +IF ("${MKL_THREADING}" STREQUAL "TBB" AND NOT TARGET TBB::tbb) + MESSAGE(FATAL_ERROR "MKL_THREADING is TBB but TBB is not found") ENDIF() MESSAGE(STATUS "MKL_THREADING = ${MKL_THREADING}") diff --git a/cmake/Modules/FindMKLDNN.cmake b/cmake/Modules/FindMKLDNN.cmake index 2e33654cc484..b93f9229fc23 100644 --- a/cmake/Modules/FindMKLDNN.cmake +++ b/cmake/Modules/FindMKLDNN.cmake @@ -109,7 +109,7 @@ IF(NOT MKLDNN_FOUND) IF(NOT MKLDNN_CPU_RUNTIME) SET(MKLDNN_CPU_RUNTIME "OMP" CACHE STRING "") ELSEIF(MKLDNN_CPU_RUNTIME STREQUAL "TBB") - IF(USE_TBB) + IF(TARGET TBB::tbb) MESSAGE(STATUS "MKL-DNN is using TBB") SET(TBB_cmake_included TRUE) diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake index 329bdd19a6cd..99b6521328d6 100644 --- a/cmake/Summary.cmake +++ b/cmake/Summary.cmake @@ -151,10 +151,6 @@ function(caffe2_print_configuration_summary) message(STATUS " USE_OBSERVERS : ${USE_OBSERVERS}") message(STATUS " USE_OPENCL : ${USE_OPENCL}") message(STATUS " USE_OPENMP : ${USE_OPENMP}") - message(STATUS " USE_TBB : ${USE_TBB}") - if(${USE_TBB}) - message(STATUS " USE_SYSTEM_TBB : ${USE_SYSTEM_TBB}") - endif() message(STATUS " USE_MIMALLOC : ${USE_MIMALLOC}") message(STATUS " USE_VULKAN : ${USE_VULKAN}") if(${USE_VULKAN}) diff --git a/cmake/public/utils.cmake b/cmake/public/utils.cmake index 597c22c75dd0..0f5da8e6cae2 100644 --- a/cmake/public/utils.cmake +++ b/cmake/public/utils.cmake @@ -317,9 +317,6 @@ function(caffe2_binary_target target_name_or_src) if(DEFINED Caffe2_MODULES) target_link_libraries(${__target} ${Caffe2_MODULES}) endif() - if(USE_TBB AND NOT USE_SYSTEM_TBB) - target_include_directories(${__target} PUBLIC ${TBB_INCLUDE_DIR}) - endif() install(TARGETS ${__target} DESTINATION bin) endfunction() diff --git a/defs.bzl b/defs.bzl index 83aa9383d7c4..6ea4b1219325 100644 --- a/defs.bzl +++ b/defs.bzl @@ -64,8 +64,6 @@ def get_cpu_parallel_backend_flags(): defs = [] if parallel_backend == "openmp": defs.append("-DAT_PARALLEL_OPENMP_FBCODE=1") - elif parallel_backend == "tbb": - defs.append("-DAT_PARALLEL_NATIVE_TBB_FBCODE=1") elif parallel_backend == "native": defs.append("-DAT_PARALLEL_NATIVE_FBCODE=1") else: diff --git a/setup.py b/setup.py index 9826207de73b..e2529335bcc6 100644 --- a/setup.py +++ b/setup.py @@ -179,13 +179,6 @@ # possible values: # OMP - use OpenMP for intra-op and native backend for inter-op tasks # NATIVE - use native thread pool for both intra- and inter-op tasks -# TBB - using TBB for intra- and native thread pool for inter-op parallelism -# -# USE_TBB -# enable TBB support -# -# USE_SYSTEM_TBB -# Use system-provided Intel TBB. # # USE_SYSTEM_LIBS (work in progress) # Use system-provided libraries to satisfy the build dependencies. @@ -371,7 +364,6 @@ def get_submodule_folders(): for name in [ "gloo", "cpuinfo", - "tbb", "onnx", "foxi", "QNNPACK", diff --git a/third_party/mkl-dnn.BUILD b/third_party/mkl-dnn.BUILD index 9a688a52b1cf..bb7f34107892 100644 --- a/third_party/mkl-dnn.BUILD +++ b/third_party/mkl-dnn.BUILD @@ -130,10 +130,7 @@ cc_library( ], deps = [ "@mkl", - ] + select({ - "@pytorch//tools/config:thread_sanitizer": [], - "//conditions:default": ["@tbb"], - }), + ], defines = [ "DNNL_ENABLE_MAX_CPU_ISA", "DNNL_ENABLE_CONCURRENT_EXEC", diff --git a/third_party/mkl.BUILD b/third_party/mkl.BUILD index c3115f164a66..b7abb0e035ad 100644 --- a/third_party/mkl.BUILD +++ b/third_party/mkl.BUILD @@ -12,10 +12,7 @@ cc_library( "libmkl_vml_avx2.so", "libmkl_vml_avx512.so", "libmkl_vml_def.so", - ] + select({ - "@pytorch//tools/config:thread_sanitizer": [], - "//conditions:default": ["libmkl_tbb_thread.so"], - }), + ], visibility = ["//visibility:public"], deps = ["@mkl_headers"], ) diff --git a/third_party/tbb b/third_party/tbb deleted file mode 160000 index a51a90bc609b..000000000000 --- a/third_party/tbb +++ /dev/null @@ -1 +0,0 @@ -Subproject commit a51a90bc609bb73db8ea13841b5cf7aa4344d4a9 diff --git a/third_party/tbb.BUILD b/third_party/tbb.BUILD deleted file mode 100644 index b11e65847331..000000000000 --- a/third_party/tbb.BUILD +++ /dev/null @@ -1,75 +0,0 @@ -load("@rules_cc//cc:defs.bzl", "cc_library") -load("@pytorch//third_party:substitution.bzl", "template_rule") - -licenses(["notice"]) # Apache 2.0 - -template_rule( - name = "version_string", - src = "@//:aten/src/ATen/cpu/tbb/extra/version_string.ver.in", - out = "version_string.h", - substitutions = { - "@CMAKE_SYSTEM_NAME@": "Unknown", - "@CMAKE_SYSTEM@": "Unknown", - "@CMAKE_SYSTEM_VERSION@": "Unknown", - "@CMAKE_CXX_COMPILER_ID@": "Unknown", - "@_configure_date@": "Unknown", - } -) - -cc_library( - name = "tbb", - srcs = [":version_string"] + glob( - [ - "src/old/*.h", - "src/rml/client/*.h", - "src/rml/include/*.h", - "src/rml/server/*.h", - "src/tbb/*.h", - "src/tbb/tools_api/*.h", - "src/tbb/tools_api/legacy/*.h", - "src/old/*.cpp", - "src/tbb/*.cpp", - ], - exclude = ["src/old/test_*.cpp"], - ) + ["src/rml/client/rml_tbb.cpp"], - hdrs = glob( - [ - "include/tbb/*", - "include/tbb/compat/*", - "include/tbb/internal/*", - "include/tbb/machine/*", - ], - exclude = ["include/tbb/scalable_allocator.h"], - ), - copts = [ - "-Iexternal/tbb/src/rml/include", - "-Iexternal/tbb/src", - "-pthread", - "-DDO_ITT_NOTIFY=1", - "-DUSE_PTHREAD=1", - "-D__TBB_BUILD=1", - "-D__TBB_DYNAMIC_LOAD_ENABLED=0", - "-D__TBB_SOURCE_DIRECTLY_INCLUDED=1", - "-fno-sanitize=vptr", - "-fno-sanitize=thread", - ], - defines = [ - # TBB Cannot detect the standard library version when using clang with libstdc++. - # See https://github.com/01org/tbb/issues/22 - "TBB_USE_GLIBCXX_VERSION=(_GLIBCXX_RELEASE*10000)", - "TBB_PREVIEW_GLOBAL_CONTROL=1", - "TBB_PREVIEW_LOCAL_OBSERVER=1", - "__TBB_ALLOW_MUTABLE_FUNCTORS=1", - ], - includes = [ - "include", - "src/tbb/tools_api", - ], - linkopts = [ - "-ldl", - "-lpthread", - "-lrt", - ], - textual_hdrs = ["src/tbb/tools_api/ittnotify_static.c"], - visibility = ["//visibility:public"], -) diff --git a/third_party/tbb.patch b/third_party/tbb.patch deleted file mode 100644 index 4a1f6845b774..000000000000 --- a/third_party/tbb.patch +++ /dev/null @@ -1,34 +0,0 @@ -diff --git a/src/rml/server/rml_server.cpp b/src/rml/server/rml_server.cpp -index 2508465..1e22ad2 100644 ---- a/src/rml/server/rml_server.cpp -+++ b/src/rml/server/rml_server.cpp -@@ -3279,10 +3279,10 @@ extern "C" void __KMP_call_with_my_server_info( ::rml::server_info_callback_t cb - /* - * RML server info - */ --#include "version_string.ver" -+#include "version_string.h" - - #ifndef __TBB_VERSION_STRINGS --#pragma message("Warning: version_string.ver isn't generated properly by version_info.sh script!") -+#pragma message("Warning: version_string.h isn't generated properly by version_info.sh script!") - #endif - - // We use the build time as the RML server info. TBB is required to build RML, so we make it the same as the TBB build time. -diff --git a/src/tbb/tbb_version.h b/src/tbb/tbb_version.h -index dcaa55b..4981a8a 100644 ---- a/src/tbb/tbb_version.h -+++ b/src/tbb/tbb_version.h -@@ -25,10 +25,10 @@ - #ifndef ENDL - #define ENDL "\n" - #endif --#include "version_string.ver" -+#include "version_string.h" - - #ifndef __TBB_VERSION_STRINGS --#pragma message("Warning: version_string.ver isn't generated properly by version_info.sh script!") -+#pragma message("Warning: version_string.h isn't generated properly by version_info.sh script!") - // here is an example of macros value: - #define __TBB_VERSION_STRINGS \ - "TBB: BUILD_HOST\tUnknown\n" \ diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py index ffd0e6f95a87..5e7e3739695d 100644 --- a/torch/testing/_internal/common_modules.py +++ b/torch/testing/_internal/common_modules.py @@ -24,7 +24,7 @@ marginrankingloss_reference, multimarginloss_reference, multilabelmarginloss_reference, nllloss_reference, nlllossNd_reference, smoothl1loss_reference, softmarginloss_reference, get_reduction) from torch.testing._internal.common_utils import ( - freeze_rng_state, set_single_threaded_if_parallel_tbb, skipIfMps, GRADCHECK_NONDET_TOL, TEST_WITH_ROCM, IS_WINDOWS, + freeze_rng_state, skipIfMps, GRADCHECK_NONDET_TOL, TEST_WITH_ROCM, IS_WINDOWS, skipIfTorchDynamo) from types import ModuleType from typing import List, Tuple, Type, Set, Dict @@ -235,7 +235,7 @@ def __init__(self, self.is_lazy = issubclass(module_cls, torch.nn.modules.lazy.LazyModuleMixin) def get_decorators(self, test_class, test_name, device, dtype, param_kwargs): - result = [set_single_threaded_if_parallel_tbb] + result = [] for decorator in self.decorators: if isinstance(decorator, DecorateInfo): if decorator.is_active(test_class, test_name, device, dtype, param_kwargs): diff --git a/torch/testing/_internal/common_optimizers.py b/torch/testing/_internal/common_optimizers.py index 43d0124a6021..bb8375e35cfd 100644 --- a/torch/testing/_internal/common_optimizers.py +++ b/torch/testing/_internal/common_optimizers.py @@ -39,7 +39,6 @@ from torch.testing._internal.common_methods_invocations import DecorateInfo from torch.testing._internal.common_utils import ( _TestParametrizer, - set_single_threaded_if_parallel_tbb, skipIfMps, skipIfTorchDynamo, TEST_WITH_TORCHDYNAMO, @@ -161,7 +160,7 @@ def __init__( self.supports_fused_on = supports_fused_on def get_decorators(self, test_class, test_name, device, dtype, param_kwargs): - result = [set_single_threaded_if_parallel_tbb] + result = [] for decorator in self.decorators: if isinstance(decorator, DecorateInfo): if decorator.is_active( diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 475ab977cdb4..2237ec67c500 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -1497,8 +1497,6 @@ def wrapper(*args, **kwargs): # See: https://github.com/pytorch/pytorch/pull/59402#issuecomment-858811135 TestEnvironment.def_flag("TEST_CUDA_MEM_LEAK_CHECK", env_var="PYTORCH_TEST_CUDA_MEM_LEAK_CHECK") -# True if CI is running TBB-enabled Pytorch -IS_TBB = "tbb" in os.getenv("BUILD_ENVIRONMENT", "") # Dict of NumPy dtype -> torch dtype (when the correspondence exists) numpy_to_torch_dtype_dict = { @@ -1875,19 +1873,6 @@ def wrapper(*args, **kwargs): fn(*args, **kwargs) return wrapper - -def skipIfTBB(message="This test makes TBB sad"): - def dec_fn(fn): - @wraps(fn) - def wrapper(*args, **kwargs): - if IS_TBB: - raise unittest.SkipTest(message) - else: - fn(*args, **kwargs) - return wrapper - return dec_fn - - def skip_if_pytest(fn): @wraps(fn) def wrapped(*args, **kwargs): @@ -4723,24 +4708,6 @@ def dtype_name(dtype): } -def set_single_threaded_if_parallel_tbb(fn): - """Set test to be single threaded for parallel tbb. - - See https://github.com/pytorch/pytorch/issues/64571#issuecomment-914691883 - """ - if not IS_TBB: - return fn - - @wraps(fn) - def wrap_fn(*args, **kwargs): - num_threads = torch.get_num_threads() - torch.set_num_threads(1) - try: - return fn(*args, **kwargs) - finally: - torch.set_num_threads(num_threads) - return wrap_fn - @functools.lru_cache def get_cycles_per_ms() -> float: diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py index b625d2dbd40f..913947ea84c7 100644 --- a/torch/utils/cpp_extension.py +++ b/torch/utils/cpp_extension.py @@ -1880,9 +1880,6 @@ def _prepare_ldflags(extra_ldflags, with_cuda, verbose, is_standalone): if not is_standalone: extra_ldflags.append('-ltorch_python') - if is_standalone and "TBB" in torch.__config__.parallel_info(): - extra_ldflags.append('-ltbb') - if is_standalone: extra_ldflags.append(f"-Wl,-rpath,{TORCH_LIB_PATH}") From 7646825c3eb687030c4f873b01312be0eed80174 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 31 May 2024 01:21:23 +0000 Subject: [PATCH 141/706] Revert "distributed debug handlers (#126601)" This reverts commit 3d541835d509910fceca00fc5a916e9718c391d8. Reverted https://github.com/pytorch/pytorch/pull/126601 on behalf of https://github.com/PaliC due to breaking internal typechecking tests ([comment](https://github.com/pytorch/pytorch/pull/126601#issuecomment-2141076987)) --- BUILD.bazel | 1 - WORKSPACE | 6 - build_variables.bzl | 2 - caffe2/CMakeLists.txt | 3 - cmake/Dependencies.cmake | 4 - docs/source/distributed.elastic.rst | 1 - docs/source/elastic/control_plane.rst | 10 - .../distributed/elastic/test_control_plane.py | 86 --------- third_party/cpp-httplib.BUILD | 10 - torch/CMakeLists.txt | 2 - torch/_C/_distributed_c10d.pyi | 4 - .../distributed/c10d/ProcessGroupNCCL.cpp | 8 - .../c10d/control_plane/Handlers.cpp | 75 -------- .../c10d/control_plane/Handlers.hpp | 67 ------- .../c10d/control_plane/WorkerServer.cpp | 178 ------------------ .../c10d/control_plane/WorkerServer.hpp | 28 --- torch/csrc/distributed/c10d/init.cpp | 12 -- torch/distributed/elastic/control_plane.py | 51 ----- 18 files changed, 548 deletions(-) delete mode 100644 docs/source/elastic/control_plane.rst delete mode 100644 test/distributed/elastic/test_control_plane.py delete mode 100644 third_party/cpp-httplib.BUILD delete mode 100644 torch/csrc/distributed/c10d/control_plane/Handlers.cpp delete mode 100644 torch/csrc/distributed/c10d/control_plane/Handlers.hpp delete mode 100644 torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp delete mode 100644 torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp delete mode 100644 torch/distributed/elastic/control_plane.py diff --git a/BUILD.bazel b/BUILD.bazel index 9eff26e01ca9..ecbeaab9bbf8 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -781,7 +781,6 @@ cc_library( ":caffe2", ":torch_headers", "@kineto", - "@cpp-httplib", ] + if_cuda([ "@cuda//:nvToolsExt", "@cutlass", diff --git a/WORKSPACE b/WORKSPACE index 4169e0dbce1d..5b4f2f2e3375 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -168,12 +168,6 @@ new_local_repository( path = "third_party/opentelemetry-cpp", ) -new_local_repository( - name = "cpp-httplib", - build_file = "//third_party:cpp-httplib.BUILD", - path = "third_party/cpp-httplib", -) - new_local_repository( name = "tensorpipe", build_file = "//third_party:tensorpipe.BUILD", diff --git a/build_variables.bzl b/build_variables.bzl index 20822ba95cf2..8b5ac4f46d7c 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -515,8 +515,6 @@ libtorch_distributed_base_sources = [ "torch/csrc/distributed/c10d/sequence_num.cpp", "torch/csrc/distributed/c10d/socket.cpp", "torch/csrc/distributed/c10d/Work.cpp", - "torch/csrc/distributed/c10d/control_plane/Handlers.cpp", - "torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp", ] # These files are only supported on Linux (and others) but not on Windows. diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 02823729d71a..1e29044e19fd 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1229,9 +1229,6 @@ if(USE_KINETO) ${TORCH_ROOT}/third_party/kineto/libkineto/src) endif() -target_include_directories(torch_cpu PRIVATE - ${TORCH_ROOT}/third_party/cpp-httplib) - install(DIRECTORY "${TORCH_SRC_DIR}/csrc" DESTINATION ${TORCH_INSTALL_INCLUDE_DIR}/torch FILES_MATCHING PATTERN "*.h" PATTERN "*.hpp") diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index f953146605ed..e15b55cd16ed 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -1699,7 +1699,3 @@ endif() # Include google/FlatBuffers include(${CMAKE_CURRENT_LIST_DIR}/FlatBuffers.cmake) - -# Include cpp-httplib -add_library(httplib INTERFACE IMPORTED) -target_include_directories(httplib SYSTEM INTERFACE ${PROJECT_SOURCE_DIR}/third_party/cpp-httplib) diff --git a/docs/source/distributed.elastic.rst b/docs/source/distributed.elastic.rst index 0aabb560c9c8..24d33d1982df 100644 --- a/docs/source/distributed.elastic.rst +++ b/docs/source/distributed.elastic.rst @@ -29,7 +29,6 @@ Documentation elastic/metrics elastic/events elastic/subprocess_handler - elastic/control_plane .. toctree:: :maxdepth: 1 diff --git a/docs/source/elastic/control_plane.rst b/docs/source/elastic/control_plane.rst deleted file mode 100644 index c37454cf1b0a..000000000000 --- a/docs/source/elastic/control_plane.rst +++ /dev/null @@ -1,10 +0,0 @@ -Control Plane -============= - -.. automodule:: torch.distributed.elastic.control_plane -.. currentmodule:: torch.distributed.elastic.control_plane - -This module contains optional helpers that add extra debug and control handlers -into your application. - -.. autofunction:: torch.distributed.elastic.control_plane.worker_main diff --git a/test/distributed/elastic/test_control_plane.py b/test/distributed/elastic/test_control_plane.py deleted file mode 100644 index c9ae512f2718..000000000000 --- a/test/distributed/elastic/test_control_plane.py +++ /dev/null @@ -1,86 +0,0 @@ -#!/usr/bin/env python3 -# Owner(s): ["oncall: distributed"] - -import json -import os -import pickle -import socket -import tempfile -from contextlib import contextmanager - -from urllib3.connection import HTTPConnection -from urllib3.connectionpool import HTTPConnectionPool - -from torch.distributed.elastic.control_plane import ( - TORCH_WORKER_SERVER_SOCKET, - worker_main, -) -from torch.testing._internal.common_utils import requires_cuda, run_tests, TestCase - - -class UnixHTTPConnection(HTTPConnection): - def __init__(self, socket_path: str) -> None: - super().__init__("localhost") - - self.socket_path = socket_path - - def connect(self) -> None: - self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - self.sock.connect(self.socket_path) - - -class UnixHTTPConnectionPool(HTTPConnectionPool): - def __init__(self, socket_path: str) -> None: - super().__init__("localhost") - - self.socket_path = socket_path - - def _new_conn(self): - return UnixHTTPConnection(self.socket_path) - - -@contextmanager -def local_worker_server() -> None: - with tempfile.TemporaryDirectory() as tmpdir: - socket_path = os.path.join(tmpdir, "socket.sock") - os.environ[TORCH_WORKER_SERVER_SOCKET] = socket_path - - with worker_main(): - pool = UnixHTTPConnectionPool(socket_path) - yield pool - - -class WorkerServerTest(TestCase): - def test_worker_server(self) -> None: - with local_worker_server() as pool: - resp = pool.request("GET", "/") - self.assertEqual(resp.status, 200) - self.assertEqual( - resp.data, - b"""

torch.distributed.WorkerServer

-Handler names -""", - ) - - resp = pool.request("POST", "/handler/ping") - self.assertEqual(resp.status, 200) - self.assertEqual(resp.data, b"pong") - - resp = pool.request("GET", "/handler/") - self.assertEqual(resp.status, 200) - self.assertIn("ping", json.loads(resp.data)) - - resp = pool.request("POST", "/handler/nonexistant") - self.assertEqual(resp.status, 404) - self.assertIn(b"Handler nonexistant not found:", resp.data) - - @requires_cuda - def test_dump_nccl_trace_pickle(self) -> None: - with local_worker_server() as pool: - resp = pool.request("POST", "/handler/dump_nccl_trace_pickle") - self.assertEqual(resp.status, 200) - out = pickle.loads(resp.data) - - -if __name__ == "__main__": - run_tests() diff --git a/third_party/cpp-httplib.BUILD b/third_party/cpp-httplib.BUILD deleted file mode 100644 index 3cd0c3dbe94b..000000000000 --- a/third_party/cpp-httplib.BUILD +++ /dev/null @@ -1,10 +0,0 @@ -load("@rules_cc//cc:defs.bzl", "cc_library") - -cc_library( - name = "cpp-httplib", - hdrs = ["httplib.h"], - includes = [ - "/", - ], - visibility = ["//visibility:public"], -) diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index 1cf1fe2cf599..b4db57488f02 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -68,7 +68,6 @@ set(TORCH_PYTHON_INCLUDE_DIRECTORIES ${TORCH_ROOT}/third_party/onnx ${TORCH_ROOT}/third_party/flatbuffers/include ${TORCH_ROOT}/third_party/kineto/libkineto/include - ${TORCH_ROOT}/third_party/cpp-httplib ${TORCH_SRC_DIR}/csrc ${TORCH_SRC_DIR}/csrc/api/include @@ -81,7 +80,6 @@ set(TORCH_PYTHON_LINK_LIBRARIES python::python pybind::pybind11 opentelemetry::api - httplib shm fmt::fmt-header-only ATEN_CPU_FILES_GEN_LIB) diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index 5594d6153b07..1a3e4ea63342 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -632,7 +632,3 @@ class ProcessGroupCudaP2P(Backend): storage_offset: Optional[int] = 0, ) -> torch.Tensor: ... def _shutdown(self) -> None: ... - -class _WorkerServer: - def __init__(self, socket_path: str) -> None: ... - def shutdown(self) -> None: ... diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index be2853efc113..ce2dc9d072b4 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -28,7 +28,6 @@ #include #include #include -#include #include #include @@ -356,13 +355,6 @@ std::string dump_nccl_trace() { } #endif -// TODO(c-p-i-o): add a JSON endpoint. -control_plane::RegisterHandler dumpHandler{ - "dump_nccl_trace_pickle", - [](const control_plane::Request&, control_plane::Response& res) { - res.setContent(dump_nccl_trace(), "application/octet-stream"); - }}; - std::optional)>>& get_cpp_trace_dumper() { static std::optional< diff --git a/torch/csrc/distributed/c10d/control_plane/Handlers.cpp b/torch/csrc/distributed/c10d/control_plane/Handlers.cpp deleted file mode 100644 index e29f1e3a2ac3..000000000000 --- a/torch/csrc/distributed/c10d/control_plane/Handlers.cpp +++ /dev/null @@ -1,75 +0,0 @@ -#include - -#include -#include -#include -#include - -namespace c10d { -namespace control_plane { - -namespace { - -class HandlerRegistry { - public: - void registerHandler(const std::string& name, HandlerFunc f) { - std::unique_lock lock(handlersMutex_); - - if (handlers_.find(name) != handlers_.end()) { - throw std::runtime_error( - fmt::format("Handler {} already registered", name)); - } - - handlers_[name] = f; - } - - HandlerFunc getHandler(const std::string& name) { - std::shared_lock lock(handlersMutex_); - - auto it = handlers_.find(name); - if (it == handlers_.end()) { - throw std::runtime_error(fmt::format("Failed to find handler {}", name)); - } - return handlers_[name]; - } - - std::vector getHandlerNames() { - std::shared_lock lock(handlersMutex_); - - std::vector names; - for (const auto& [name, _] : handlers_) { - names.push_back(name); - } - return names; - } - - private: - std::shared_mutex handlersMutex_{}; - std::unordered_map handlers_{}; -}; - -HandlerRegistry& getHandlerRegistry() { - static HandlerRegistry registry; - return registry; -} - -RegisterHandler pingHandler{"ping", [](const Request&, Response& res) { - res.setContent("pong", "text/plain"); - }}; - -} // namespace - -void registerHandler(const std::string& name, HandlerFunc f) { - return getHandlerRegistry().registerHandler(name, f); -} - -HandlerFunc getHandler(const std::string& name) { - return getHandlerRegistry().getHandler(name); -} - -std::vector getHandlerNames() { - return getHandlerRegistry().getHandlerNames(); -} - -} // namespace control_plane -} // namespace c10d diff --git a/torch/csrc/distributed/c10d/control_plane/Handlers.hpp b/torch/csrc/distributed/c10d/control_plane/Handlers.hpp deleted file mode 100644 index 0c1063054931..000000000000 --- a/torch/csrc/distributed/c10d/control_plane/Handlers.hpp +++ /dev/null @@ -1,67 +0,0 @@ -#pragma once - -#include -#include - -#include - -namespace c10d { -namespace control_plane { - -// Request represents a request to the handler. This conceptually maps to an -// HTTP request but could be called via other transports. -class TORCH_API Request { - public: - virtual ~Request() = default; - - virtual const std::string& body() = 0; -}; - -// Response represents a response to the handler. This conceptually maps to an -// HTTP response but could be called via other transports. -class TORCH_API Response { - public: - virtual ~Response() = default; - - // Set the response body to the provided string. - // TODO: add support for chunked responses - virtual void setContent( - std::string&& content, - const std::string& content_type) = 0; - - // Set the response status code. - // These should match standard HTTP status codes. - virtual void setStatus(int status) = 0; -}; - -using HandlerFunc = std::function; - -// Registers a handler. The name needs to be unique and can be called by using -// getHandler directly or via WorkerServer for remote requests. -// These handlers are called from a background C++ thread concurrently with the -// main thread. These handlers need to be thread safe and not cause issues -// during Python training. -TORCH_API void registerHandler(const std::string& name, HandlerFunc f); - -// Fetches a handler by name. -TORCH_API HandlerFunc getHandler(const std::string& name); - -TORCH_API std::vector getHandlerNames(); - -// Registers a handler statically. -// See registerHandler for more details. -class TORCH_API RegisterHandler { - public: - RegisterHandler(const std::string& name, HandlerFunc f) { - registerHandler(name, f); - } - - // disable move, copy - RegisterHandler(const RegisterHandler&) = delete; - RegisterHandler(RegisterHandler&&) = delete; - RegisterHandler& operator=(const RegisterHandler&) = delete; - RegisterHandler& operator=(RegisterHandler&&) = delete; -}; - -} // namespace control_plane -} // namespace c10d diff --git a/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp b/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp deleted file mode 100644 index 14d287e9607f..000000000000 --- a/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp +++ /dev/null @@ -1,178 +0,0 @@ -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -namespace c10d { -namespace control_plane { - -namespace { -class RequestImpl : public Request { - public: - RequestImpl(const httplib::Request& req) : req_(req) {} - - const std::string& body() override { - return req_.body; - } - - private: - const httplib::Request& req_; -}; - -class ResponseImpl : public Response { - public: - ResponseImpl(httplib::Response& res) : res_(res) {} - - void setStatus(int status) override { - res_.status = status; - } - - void setContent(std::string&& content, const std::string& content_type) - override { - res_.set_content(std::move(content), content_type); - } - - private: - httplib::Response& res_; -}; - -std::string jsonStrEscape(const std::string& str) { - std::ostringstream ostream; - for (char ch : str) { - if (ch == '"') { - ostream << "\\\""; - } else if (ch == '\\') { - ostream << "\\\\"; - } else if (ch == '\b') { - ostream << "\\b"; - } else if (ch == '\f') { - ostream << "\\f"; - } else if (ch == '\n') { - ostream << "\\n"; - } else if (ch == '\r') { - ostream << "\\r"; - } else if (ch == '\t') { - ostream << "\\t"; - } else if ('\x00' <= ch && ch <= '\x1f') { - ostream << "\\u" << std::hex << std::setw(4) << std::setfill('0') - << static_cast(ch); - } else { - ostream << ch; - } - } - return ostream.str(); -} -} // namespace - -WorkerServer::WorkerServer(const std::string& socketFile) { - // using unix sockets - server_.set_address_family(AF_UNIX); - - // adjust keep alives as it stops the server from shutting down quickly - server_.set_keep_alive_timeout(1); // second, default is 5 - server_.set_keep_alive_max_count( - 30); // wait max 30 seconds before closing socket - - server_.Get("/", [](const httplib::Request& req, httplib::Response& res) { - res.set_content( - R"BODY(

torch.distributed.WorkerServer

-Handler names -)BODY", - "text/html"); - }); - server_.Get( - "/handler/", [](const httplib::Request& req, httplib::Response& res) { - std::ostringstream body; - body << "["; - bool first = true; - for (const auto& name : getHandlerNames()) { - if (!first) { - body << ","; - } - first = false; - - body << "\"" << jsonStrEscape(name) << "\""; - } - body << "]"; - - res.set_content(body.str(), "application/json"); - }); - server_.Post( - "/handler/:handler", - [](const httplib::Request& req, httplib::Response& res) { - auto handler_name = req.path_params.at("handler"); - HandlerFunc handler; - try { - handler = getHandler(handler_name); - } catch (const std::exception& e) { - res.status = 404; - res.set_content( - fmt::format("Handler {} not found: {}", handler_name, e.what()), - "text/plain"); - return; - } - RequestImpl torchReq{req}; - ResponseImpl torchRes{res}; - - try { - handler(torchReq, torchRes); - } catch (const std::exception& e) { - res.status = 500; - res.set_content( - fmt::format("Handler {} failed: {}", handler_name, e.what()), - "text/plain"); - return; - } catch (...) { - res.status = 500; - res.set_content( - fmt::format( - "Handler {} failed with unknown exception", handler_name), - "text/plain"); - return; - } - }); - - if (std::filesystem::exists(socketFile)) { - throw std::runtime_error(fmt::format("{} already exists", socketFile)); - } - - C10D_WARNING("Server listening to {}", socketFile); - if (!server_.bind_to_port(socketFile, 80)) { - throw std::runtime_error(fmt::format("Error binding to {}", socketFile)); - } - - serverThread_ = std::thread([this]() { - try { - if (!server_.listen_after_bind()) { - throw std::runtime_error("failed to listen"); - } - } catch (std::exception& e) { - C10D_ERROR("Error while running server: {}", e.what()); - throw; - } - C10D_WARNING("Server exited"); - }); -} - -void WorkerServer::shutdown() { - C10D_WARNING("Server shutting down"); - server_.stop(); - serverThread_.join(); -} - -WorkerServer::~WorkerServer() { - if (serverThread_.joinable()) { - C10D_WARNING("WorkerServer destructor called without shutdown"); - shutdown(); - } -} - -} // namespace control_plane -} // namespace c10d diff --git a/torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp b/torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp deleted file mode 100644 index 7d64038f0b01..000000000000 --- a/torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp +++ /dev/null @@ -1,28 +0,0 @@ -#pragma once - -#include -#include -#include - -#include - -#include -#include - -namespace c10d { -namespace control_plane { - -class TORCH_API WorkerServer : public c10::intrusive_ptr_target { - public: - WorkerServer(const std::string& socketFile); - ~WorkerServer(); - - void shutdown(); - - private: - httplib::Server server_; - std::thread serverThread_; -}; - -} // namespace control_plane -} // namespace c10d diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index c4b9a9823c84..6f6dae326065 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -8,7 +8,6 @@ #include #include #include -#include #include #ifndef _WIN32 #include @@ -3165,17 +3164,6 @@ such as `dist.all_reduce(tensor, async_op=True)`. return py::bytes(::c10d::dump_nccl_trace()); }); #endif - - intrusive_ptr_class_<::c10d::control_plane::WorkerServer>( - module, "_WorkerServer", R"( -)") - .def( - py::init([](const std::string& socketPath) { - return c10::make_intrusive<::c10d::control_plane::WorkerServer>( - socketPath); - }), - py::arg("socket_path")) - .def("shutdown", &::c10d::control_plane::WorkerServer::shutdown); Py_RETURN_TRUE; } diff --git a/torch/distributed/elastic/control_plane.py b/torch/distributed/elastic/control_plane.py deleted file mode 100644 index 160383637865..000000000000 --- a/torch/distributed/elastic/control_plane.py +++ /dev/null @@ -1,51 +0,0 @@ -import os -from contextlib import contextmanager, ExitStack -from typing import Generator - -from torch.distributed.elastic.multiprocessing.errors import record - -__all__ = [ - "worker_main", -] - -TORCH_WORKER_SERVER_SOCKET = "TORCH_WORKER_SERVER_SOCKET" - - -@contextmanager -def _worker_server(socket_path: str) -> Generator[None, None, None]: - from torch._C._distributed_c10d import _WorkerServer - - server = _WorkerServer(socket_path) - try: - yield - finally: - server.shutdown() - - -@contextmanager -@record -def worker_main() -> Generator[None, None, None]: - """ - This is a context manager that wraps your main entry function. This combines - the existing ``errors.record`` logic as well as a new ``_WorkerServer`` that - exposes handlers via a unix socket specified by - ``Torch_WORKER_SERVER_SOCKET``. - - Example - - :: - - @worker_main() - def main(): - pass - - if __name__=="__main__": - main() - - """ - with ExitStack() as stack: - socket_path = os.environ.get(TORCH_WORKER_SERVER_SOCKET) - if socket_path is not None: - stack.enter_context(_worker_server(socket_path)) - - yield From d535de1747c397c45012fb5c728b46c3cc402f7a Mon Sep 17 00:00:00 2001 From: Shunting Zhang Date: Thu, 30 May 2024 12:38:07 -0700 Subject: [PATCH 142/706] [inductor] remove reordering_reindex (#127367) This fixes the loop ordering issue for avg_pool2d here (https://github.com/pytorch/pytorch/issues/126255#issuecomment-2117931529). The reason we can not fuse the 2 kernels for avg_pool2d is due to ComputedBuffer.iter_reordering_reindex. Take a simpler example: ``` def f(x, y): """ Add a matmul since inductor may force layout for output. """ return (x.sum(dim=-1) + 1) @ y # Make the first 2 dimension not able to merge on purpose so that # ComputedBuffer.iter_reoredering_reindex will be updated. x = rand_strided([20, 20, 30], [30, 900, 1], device="cuda") y = torch.randn(20, 20) ``` Suppose x.sum is stored to x2. The computed buffer for x2 will remember that we have reordered it's first and second dimension (i.e. loop order [1, 0]). Later one when we decide the loop order for x2 when computing 'x2 + 1' , we decide to pick loop order [1, 0] according to the stride analysis. And then we use the saved ComputedBuffer.iter_reordering_reindex to further reorder the loop order. The net effect is that we use loop order [0, 1] which cause the pointwise kernel not able to fuse with the reduction kernel. I feel that we don't need ComputedBuffer.iter_reordering_reindex. And test result shows removing it has neutral impact on the dashboard [link](https://hud.pytorch.org/benchmark/compilers?startTime=Wed%2C%2022%20May%202024%2017%3A30%3A29%20GMT&stopTime=Wed%2C%2029%20May%202024%2017%3A30%3A29%20GMT&granularity=hour&suite=torchbench&mode=training&dtype=amp&lBranch=gh/shunting314/153/head&lCommit=195f42cf1a414d2d1a0422b8a081a85ff52b7d20&rBranch=main&rCommit=d6e3e89804c4063827ea21ffcd3d865e5fe365d9) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127367 Approved by: https://github.com/jansel --- benchmarks/dynamo/timm_models.py | 1 + test/inductor/test_loop_ordering.py | 57 +++++++++++++++++++++++++++++ torch/_inductor/ir.py | 30 +++------------ 3 files changed, 64 insertions(+), 24 deletions(-) create mode 100644 test/inductor/test_loop_ordering.py diff --git a/benchmarks/dynamo/timm_models.py b/benchmarks/dynamo/timm_models.py index 75a12517698e..d5cdc533da43 100755 --- a/benchmarks/dynamo/timm_models.py +++ b/benchmarks/dynamo/timm_models.py @@ -74,6 +74,7 @@ def pip_install(package): "hrnet_w18", "inception_v3", "mixer_b16_224", + "mobilenetv3_large_100", "sebotnet33ts_256", "selecsls42b", } diff --git a/test/inductor/test_loop_ordering.py b/test/inductor/test_loop_ordering.py new file mode 100644 index 000000000000..856d849b880f --- /dev/null +++ b/test/inductor/test_loop_ordering.py @@ -0,0 +1,57 @@ +# Owner(s): ["module: inductor"] + +import torch +from torch._dynamo.testing import rand_strided +from torch._dynamo.utils import same +from torch._inductor import config as inductor_config, metrics +from torch._inductor.test_case import run_tests, TestCase +from torch.testing._internal.inductor_utils import HAS_CUDA + + +@inductor_config.patch( + { + "benchmark_kernel": True, + "triton.unique_kernel_names": True, + } +) +class LoopOrderingTest(TestCase): + def do_acc_test(self, f, *args): + expect = f(*args) + actual = torch.compile(f)(*args) + self.assertTrue(same(expect, actual, tol=1e-3)) + + def test_for_reordering_reindex(self): + """ + ComputedBuffer.iter_reoredering_reindex can cause some fusion + opportunitiies being skipped. + + In this test case, Inductor generates 2 triton kernels before. + By removing ComputedBuffer.iter_reoredering_reindex, we can fuse those + two kernels into a single one. + """ + + def f(x, y): + """ + Add a matmul since inductor may force layout for output. + """ + return (x.sum(dim=-1) + 1) @ y + + A, B = 20, 30 + # Make the first 2 dimension not able to merge on purpose so that + # ComputedBuffer.iter_reoredering_reindex will be updated. + x = rand_strided([A, A, B], [B, B * A + 300, 1], device="cuda") + y = torch.randn(A, A) + + self.do_acc_test(f, x, y) + self.assertEqual(1, metrics.generated_kernel_count) + expected_num_bytes = 0 + expected_num_bytes += A * A * B + A * A # for the fused reduction + expected_num_bytes += A * A * 3 # for matmul + expected_num_bytes *= x.itemsize + self.assertEqual(expected_num_bytes, metrics.num_bytes_accessed) + + +if __name__ == "__main__": + if HAS_CUDA: + torch.set_default_device("cuda") + run_tests() diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 59b563c4e660..696641533500 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -3415,17 +3415,9 @@ def simplify_and_reorder( *body.writes_name2expr.values(), ] - # the reordering_reindex in reads' simplify_reorder_and_tile - reordering_reindex = [same_reorder(range(len(index_vars)))] * len(memory_addrs) - for i, reads_buf in enumerate(reads_bufs): - if isinstance(reads_buf, ComputedBuffer) and hasattr( - reads_buf, "iter_reordering_reindex" - ): - reordering_reindex[i] = reads_buf.iter_reordering_reindex # type: ignore[has-type] - - def simplify_and_reorder(x_vars, support_vars, sizes, reordering_reindex=None): + def simplify_and_reorder(x_vars, support_vars, sizes): sizes, reindex0, reindex1 = self._apply_loop_reordering( - x_vars, support_vars, sizes, memory_addrs, reordering_reindex + x_vars, support_vars, sizes, memory_addrs ) # for NHWC: reindex0([0,1,2,3]) = [0,2,3,1], reindex1([0,1,2,3]) = [0,3,2,1] x_vars = reindex0(x_vars) @@ -3442,16 +3434,15 @@ def simplify_and_reorder(x_vars, support_vars, sizes, reordering_reindex=None): return sizes, reindex, reindex1 support_vars = index_vars + reduce_vars - iter_ranges, iter_reindex, iter_reordering_reindex = simplify_and_reorder( - index_vars, support_vars, index_size, reordering_reindex + iter_ranges, iter_reindex, _ = simplify_and_reorder( + index_vars, + support_vars, + index_size, ) reduce_ranges, reduce_reindex, _ = simplify_and_reorder( reduce_vars, support_vars, reduce_size ) - # remember the reordering if not have loop collapse. - if len(iter_ranges) == len(index_vars): - self.iter_reordering_reindex = iter_reordering_reindex # retrace the loop body with simplification and reordering applied (iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze( iter_ranges, reduce_ranges, prefix="z" @@ -3467,7 +3458,6 @@ def _apply_loop_reordering( support_vars, sizes, memory_addrs, - reordering_reindex=None, priority_idx=None, ): """ @@ -3486,14 +3476,6 @@ def _apply_loop_reordering( assert len(strides) == len(memory_addrs) and len(strides[0]) == len( index_vars ) - # consider both layout(strides) and reordering(reordering_reindex) - if reordering_reindex is not None: - for i in range(len(memory_addrs)): - try: - strides[i] = reordering_reindex[i](strides[i]) - # if len(order) != len(strides), do not reorder - except AssertionError: - pass order = list(reversed(pick_loop_order(strides, sizes, priority_idx))) except Exception: if config.debug: From b1792a622dee8f529c05e83541195b9c642a54b3 Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Fri, 31 May 2024 01:52:57 +0000 Subject: [PATCH 143/706] [pipelining] handle param aliasing (#127471) Adds support for parameter aliasing in pipelining. Does this by reading the state_dict, and creating a map of id -> valid tensor FQNs (to be used in _sink_params). Assigns additional FQN attributes that may be used, runs _sink_params(), and then deletes unused attributes. Shares some similarity with how export's unflattener does it. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127471 Approved by: https://github.com/kwen2501 --- test/distributed/pipelining/model_registry.py | 21 +++++ test/distributed/pipelining/test_pipe.py | 40 +++++--- torch/distributed/pipelining/_IR.py | 92 ++++++++++++++++--- 3 files changed, 128 insertions(+), 25 deletions(-) diff --git a/test/distributed/pipelining/model_registry.py b/test/distributed/pipelining/model_registry.py index babc1cfa1096..5f0c9baf3b1e 100644 --- a/test/distributed/pipelining/model_registry.py +++ b/test/distributed/pipelining/model_registry.py @@ -51,6 +51,27 @@ def forward(self, x, y=torch.zeros(DEFAULT_BATCH_SIZE, DEFAULT_DHID)): return x +class ModelWithParamAlias(torch.nn.Module): + default_dhid = 512 + default_batch_size = 256 + + def __init__(self, d_hid: int = default_dhid): + super().__init__() + self.mm_param1 = self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) + self.lin1 = self.lin0 = torch.nn.Linear(d_hid, d_hid) + + def forward(self, x, y): + x = torch.mm(x, self.mm_param0) + x = x + y + x = self.lin0(x) + x = torch.relu(x) + pipe_split() + x = torch.mm(x, self.mm_param1) + x = self.lin1(x) + x = torch.relu(x) + return x + + # MLP Layer class MLPModule(torch.nn.Module): def __init__(self, d_hid: int): diff --git a/test/distributed/pipelining/test_pipe.py b/test/distributed/pipelining/test_pipe.py index a9e283d3cedf..df053bd6c249 100644 --- a/test/distributed/pipelining/test_pipe.py +++ b/test/distributed/pipelining/test_pipe.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates # Owner(s): ["oncall: distributed"] -from model_registry import MLPModule +from model_registry import MLPModule, ModelWithParamAlias import torch from torch.distributed.pipelining import pipe_split, pipeline @@ -64,8 +64,21 @@ def forward(self, x, y): return x - y +EXPECTED_N_STAGES = { + ExampleCode: 4, + MultiMLP: 4, + ModelWithParamAlias: 2, +} + +# Currently, we don't enforce full set equality on the FQNs between the original +# and pipelined models, because in the multi-use param case, PP will deduplicate +# the FQNs from the state_dict. +# TODO +CHECK_FQN_SET_EQUALITY = False + + class PipeTests(TestCase): - @parametrize("ModelClass", [ExampleCode, MultiMLP]) + @parametrize("ModelClass", [ExampleCode, MultiMLP, ModelWithParamAlias]) def test_model_split(self, ModelClass): mod = ModelClass() x = torch.randn(batch_size, d_hid) @@ -77,7 +90,9 @@ def test_model_split(self, ModelClass): example_args=(x, y), ) - assert pipe.num_stages == 4, f"nstages = {pipe.num_stages}, expect 4" + assert ( + pipe.num_stages == EXPECTED_N_STAGES[ModelClass] + ), f"nstages = {pipe.num_stages}, expect {EXPECTED_N_STAGES[ModelClass]}" ref_out = mod(x, y) out = pipe(x, y)[0] @@ -90,14 +105,17 @@ def test_model_split(self, ModelClass): new_names = set() for idx in range(pipe.num_stages): stage_mod = pipe.get_stage_module(idx) - new_names.update(stage_mod.state_dict().keys()) - - assert ( - old_names == new_names - ), f""" - old names {old_names} - new names {new_names} - """ + stage_fqns = set(stage_mod.state_dict().keys()) + assert stage_fqns.issubset(old_names) + new_names.update(stage_fqns) + + if CHECK_FQN_SET_EQUALITY: + assert ( + old_names == new_names + ), f""" + old names {old_names} + new names {new_names} + """ print("Qualname check passed") diff --git a/torch/distributed/pipelining/_IR.py b/torch/distributed/pipelining/_IR.py index 7fe8eab83d97..c7ea787f98b5 100644 --- a/torch/distributed/pipelining/_IR.py +++ b/torch/distributed/pipelining/_IR.py @@ -2,16 +2,22 @@ import copy import logging import operator +from collections import defaultdict from dataclasses import dataclass from enum import Enum from inspect import Parameter, signature, Signature from types import MethodType -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union import torch import torch.fx as fx from torch.export import ExportedProgram -from torch.export.unflatten import _assign_attr, _AttrKind, _sink_params +from torch.export.unflatten import ( + _assign_attr, + _AttrKind, + _sink_params, + InterpreterModule, +) from torch.fx.node import map_aggregate from torch.fx.passes.split_module import split_module @@ -751,6 +757,18 @@ def delete_user_reference(node, user): # To be accumulated in `move_param_to_callee`. to_delete = list() + def _recursive_getattr_with_parent(mod, fqn): + # Returns getattr call given a nested FQN, and the last parent + atoms = fqn.split(".") + for atom in atoms[:-1]: + if not hasattr(mod, atom): + return None, None + mod = getattr(mod, atom) + if not hasattr(mod, atoms[-1]): + return mod, None + attr = getattr(mod, atoms[-1]) + return mod, attr + def move_param_to_callee( root, callee_name, @@ -766,12 +784,7 @@ def move_param_to_callee( # `atoms` is a list of strings representing the path to the # parameter in the original model atoms = param_fqn.split(".") - # Recursively find the parent of the parameter - mod_itr = split - for atom in atoms[:-1]: - mod_itr = getattr(mod_itr, atom) - # Get the parameter (it is still under the root module) - param_val = getattr(mod_itr, atoms[-1]) + mod_itr, param_val = _recursive_getattr_with_parent(split, param_fqn) # Check whether the parameter is a buffer or a parameter is_buffer = atoms[-1] in mod_itr._buffers @@ -837,6 +850,37 @@ def move_param_to_callee( node.target, ) + # [aliasing] store tensor id -> list of FQNs, built from state dict + # Also assign non-persistent buffers + id_to_fqns: Dict[int, Set[str]] = defaultdict(set) + for fqn, tensor in mod.state_dict(keep_vars=True).items(): + id_to_fqns[id(tensor)].add(fqn) + for fqn, tensor in mod.named_buffers(): + id_to_fqns[id(tensor)].add(fqn) + + # After moving the params to their corresponding hierarchies, we also + # need to move the `get_attr` nodes from the root of the graph to those + # hierarchies. + # [aliasing] use id -> fqn mapping to list out all valid FQNs + inputs_to_state: Dict[str, List[str]] = {} + for attr in attr_nodes: + _, tensor = _recursive_getattr_with_parent(mod, attr.target) + inputs_to_state[attr.name] = list(id_to_fqns[id(tensor)]) + + # [aliasing] for each submodule split, assign attributes on FQNs that may be used. + # We determine this based on whether or not the FQN attribute parent exists. + # i.e. if the last submodule exists, assign the attribute. + added_attributes: Dict[str, List[str]] = defaultdict(list) + for fqn, tensor in mod.state_dict(keep_vars=True).items(): + for name, submod in split.named_children(): + if isinstance(submod, fx.GraphModule): + parent, child = _recursive_getattr_with_parent(submod, fqn) + if ( + parent and child is None + ): # parent exists, attribute doesn't -> assign + added_attributes[name].append(fqn) + setattr(parent, fqn.split(".")[-1], tensor) + # Deferral deletion: Remove the original attributes (to params) from the # root GraphModule for mod_itr, last_atom in to_delete: @@ -846,12 +890,6 @@ def move_param_to_callee( # This is expected if the parameter is used in multiple stages pass - # After moving the params to their corresponding hierarchies, we also - # need to move the `get_attr` nodes from the root of the graph to those - # hierarchies. - inputs_to_state: Dict[str, List[str]] = { - attr.name: [attr.target] for attr in attr_nodes - } # This is done by (1) `_sink_params` at each submodule; for name, submod in split.named_children(): if isinstance(submod, fx.GraphModule): @@ -859,6 +897,32 @@ def move_param_to_callee( submod.graph.lint() submod.recompile() + # [aliasing] This step is not super necessary, but helps reduce parameter usage/memory. + # After _sink_params() routine has run, clean up unused attributes that we previously added. + # Determine this based on the get_attr nodes - if not used, remove it. + for name, attributes in added_attributes.items(): + submod = getattr(split, name) + unused_attributes = set(attributes) + # track used attributes in the submodule, running DFS on subgraph hierarchy + stack = [("", submod)] # (scope, submodule) + while stack: + scope, _mod = stack.pop() + if isinstance(_mod, (fx.GraphModule, InterpreterModule)): + for node in _mod.graph.nodes: + if node.op == "get_attr": + # get_attr might get access deeper level attribute + fqn = scope + "." + node.target if scope else node.target + if fqn in unused_attributes: # used, remove it + unused_attributes.remove(fqn) + for _name, _submod in _mod.named_children(): + stack.append((scope + "." + _name if scope else _name, _submod)) + # delete unused attributes + for attr in unused_attributes: + mod_itr, atoms = submod, attr.split(".") + for atom in atoms[:-1]: + mod_itr = getattr(mod_itr, atom) + delattr(mod_itr, atoms[-1]) + for node in attr_nodes: # And (2): remove `get_attr` node from submod's arg list for user in copy.copy(node.users): From 0aaac68c573e729aafb847b1a43cfbbc9171c787 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Wed, 29 May 2024 06:46:48 -0700 Subject: [PATCH 144/706] Add structured logging for tensor fakeification (#126879) This adds dumps of MetaTensorDesc and MetaStorageDesc to structured logs when they are triggered from Dynamo. The logs look like this: ``` V0522 08:13:25.267000 140224882566144 torch/_subclasses/meta_utils.py:195] {"describe_storage": {"id": 0, "describer_id": 0, "size": 32}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} V0522 08:13:25.267000 140224882566144 torch/_subclasses/meta_utils.py:220] {"describe_tensor": {"id": 0, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [8], "is_leaf": true, "stride": [1], "storage": 0, "view_func": "", "describer_id": 0}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} V0522 08:13:25.268000 140224882566144 torch/_subclasses/meta_utils.py:1594] {"describe_source": {"describer_id": 0, "id": 0, "source": "L['x']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} ``` The `describer_id` is used to disambiguate ids. We expect it to be unique per frame id, but if there is a bug it possibly is not. Note you will get redundant dumps when evaluation restarts. tlparse can use this to give a visualization of input tensors to a model, you could also use this to generate example inputs to run graphs on. Some care is taken to avoid redumping the tensor metadata multiple times, which would happen ordinarily because AOTAutograd refakifies everything after Dynamo, to deal with metadata mutation. Partially fixes https://github.com/pytorch/pytorch/issues/126644 Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/126879 Approved by: https://github.com/jamesjwu --- test/dynamo/test_structured_trace.py | 59 ++++++++ torch/_functorch/aot_autograd.py | 5 + torch/_logging/_internal.py | 6 +- torch/_subclasses/fake_tensor.py | 4 + torch/_subclasses/meta_utils.py | 199 +++++++++++++++++++++------ 5 files changed, 233 insertions(+), 40 deletions(-) diff --git a/test/dynamo/test_structured_trace.py b/test/dynamo/test_structured_trace.py index ea44a5e0771d..fb7a2c249a85 100644 --- a/test/dynamo/test_structured_trace.py +++ b/test/dynamo/test_structured_trace.py @@ -77,6 +77,14 @@ def format(self, record): metadata["stack"] = "STACK" if "compilation_metrics" in metadata: metadata["compilation_metrics"] = "METRICS" + if "describe_storage" in metadata: + metadata["describe_storage"]["describer_id"] = "ID" + if "describe_tensor" in metadata: + metadata["describe_tensor"]["describer_id"] = "ID" + if "view_func" in metadata["describe_tensor"]: + metadata["describe_tensor"]["view_func"] = "VIEW_FUNC" + if "describe_source" in metadata: + metadata["describe_source"]["describer_id"] = "ID" return json.dumps(metadata) @@ -136,6 +144,9 @@ def test_schedule(self): self.buffer.getvalue(), """\ {"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1000, 1000], "is_leaf": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"dynamo_output_graph": {"sizes": {"l_a_": [1000, 1000], "ones": [1000, 1000], "output": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_forward_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_hash", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -157,6 +168,9 @@ def test_cudagraphs(self): self.buffer.getvalue(), """\ {"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1000, 1000], "is_leaf": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"dynamo_output_graph": {"sizes": {"l_a_": [1000, 1000], "ones": [1000, 1000], "output": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_forward_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_hash", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -182,6 +196,12 @@ def fn(x, y): self.buffer.getvalue(), """\ {"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "is_leaf": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 1, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 1, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "is_leaf": true, "stride": [1000, 1], "storage": 1, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 1, "source": "L['y']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"dynamo_output_graph": {"sizes": {"l_x_": [1000, 1000], "l_y_": [1000, 1000], "add": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_forward_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_hash", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -191,6 +211,9 @@ def fn(x, y): {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "is_leaf": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} {"dynamo_output_graph": {"sizes": {"l_x_": [1000, 1000], "add": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"aot_forward_graph": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_hash", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} @@ -211,6 +234,9 @@ def test_example_fn(self): self.buffer.getvalue(), """\ {"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "is_leaf": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"dynamo_output_graph": {"sizes": {"l_a_": [1000, 1000], "ones": [1000, 1000], "output": [1000, 1000], "ones_1": [1000, 1000], "output_1": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_forward_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_hash", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -234,6 +260,9 @@ def test_dynamo_error(self): self.buffer.getvalue(), """\ {"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "is_leaf": true, "requires_grad": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} """, # noqa: B950 ) @@ -263,6 +292,9 @@ def throw(x): self.buffer.getvalue(), """\ {"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "is_leaf": true, "requires_grad": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"dynamo_output_graph": {"sizes": {"l_a_": [1000, 1000], "output": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_joint_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_forward_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -310,10 +342,16 @@ def forward(self, x): {"dynamo_cpp_guards_str": {}, "rank": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "rank": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 1} {"dynamo_start": {"stack": "STACK"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4194304}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024, 1024], "is_leaf": true, "stride": [1024, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} {"dynamo_output_graph": {"sizes": {"l_x_": [1024, 1024], "l__self___layers_0": [1024, 1024], "l__self___layers_1": [1024, 1024]}}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"optimize_ddp_split_graph": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"optimize_ddp_split_child": {"name": "submod_0"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"optimize_ddp_split_child": {"name": "submod_1"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4194304}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024, 1024], "is_leaf": true, "stride": [1024, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} {"aot_joint_graph": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_forward_graph": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_backward_graph": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -350,6 +388,9 @@ def fn(x): {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 1} {"dynamo_start": {"stack": "STACK"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1], "is_leaf": true, "stride": [1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} {"dynamo_output_graph": {"sizes": {"l_x_": [1], "add": [1]}}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_forward_graph": {}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_hash", "encoding": "json"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -379,11 +420,23 @@ def fn(a, b): self.buffer.getvalue(), """\ {"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 800}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [10, 20], "is_leaf": true, "stride": [20, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 1, "describer_id": "ID", "size": 2400}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 1, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [20, 30], "is_leaf": true, "stride": [30, 1], "storage": 1, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 1, "source": "L['b']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"dynamo_output_graph": {"sizes": {"l_a_": [10, 20], "l_b_": [20, 30], "matmul": [10, 30]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"dynamo_guards": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 200}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [5, 10], "is_leaf": true, "stride": [10, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"describe_storage": {"id": 1, "describer_id": "ID", "size": 600}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"describe_tensor": {"id": 1, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [10, 15], "is_leaf": true, "stride": [15, 1], "storage": 1, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 1, "source": "L['b']"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} {"dynamo_output_graph": {"sizes": {"l_a_": ["s0", "s1"], "l_b_": ["s1", "s3"], "matmul": ["s0", "s3"]}}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"dynamo_guards": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} @@ -414,11 +467,17 @@ def inner(x, ys, zs): self.buffer.getvalue(), """\ {"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1], "is_leaf": true, "stride": [1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"dynamo_output_graph": {"sizes": {"l_x_": [1], "x": [1]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"dynamo_guards": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1], "is_leaf": true, "stride": [1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} {"dynamo_output_graph": {"sizes": {"l_x_": [1], "x": [1]}}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"dynamo_guards": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index f7724a6add60..1c4fff02220d 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -509,10 +509,14 @@ def convert(idx, x): # see note [Tensor Fakification and Symbol Caching] symbolic_context = None source = None + trace = True if tracing_context := torch._guards.TracingContext.try_get(): if x in tracing_context.tensor_to_context: symbolic_context = tracing_context.tensor_to_context[x] source = symbolic_context.tensor_source + # We already fakeified this tensor in Dynamo, don't + # dump the trace for it again + trace = False if ( idx < aot_config.num_params_buffers and config.static_weight_shapes @@ -527,6 +531,7 @@ def convert(idx, x): static_shapes=False, symbolic_context=symbolic_context, source=source, + trace=trace, ) return [convert(idx, x) for idx, x in enumerate(flat_args)] diff --git a/torch/_logging/_internal.py b/torch/_logging/_internal.py index 28a57e39bf3b..798eeabc5d6b 100644 --- a/torch/_logging/_internal.py +++ b/torch/_logging/_internal.py @@ -795,7 +795,11 @@ def format(self, record): ) if self._is_trace: assert s == "" - r = f"{prefix} {json.dumps(record.metadata)}" + try: + r = f"{prefix} {json.dumps(record.metadata)}" + except TypeError: + log.warning("failing metadata: %r", record.metadata) + raise if record.payload is not None: r += "".join(f"\n\t{l}" for l in record.payload.split("\n")) return r diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 2c75847c92a1..47d4abcf77b9 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -291,6 +291,7 @@ def from_real_tensor( *, source=None, symbolic_context=None, + trace=True, ): # see note [Tensor Fakification and Symbol Caching] if not symbolic_context and not source and shape_env: @@ -333,6 +334,7 @@ def mk_fake_tensor(make_meta_t): callback=mk_fake_tensor, source=source, symbolic_context=symbolic_context, + trace=trace, ) if out is NotImplemented: raise UnsupportedFakeTensorException("meta converter nyi") @@ -1925,6 +1927,7 @@ def from_tensor( static_shapes=None, source: Optional[Source] = None, symbolic_context=None, + trace=True, ): shape_env: Optional[ShapeEnv] = self.shape_env if static_shapes is None: @@ -1940,6 +1943,7 @@ def from_tensor( shape_env=shape_env, source=source, symbolic_context=symbolic_context, + trace=trace, ) diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py index 647c03861768..08c4c9ce4e3f 100644 --- a/torch/_subclasses/meta_utils.py +++ b/torch/_subclasses/meta_utils.py @@ -1,13 +1,15 @@ from __future__ import annotations import contextlib + +import dataclasses import warnings import weakref - from dataclasses import dataclass from typing import ( Any, Callable, + ClassVar, ContextManager, Dict, List, @@ -20,6 +22,7 @@ from typing_extensions import TypeAlias import torch +from torch._C._autograd import CreationMeta from torch._C._functorch import ( _add_batch_dim, _unwrap_functional_tensor, @@ -33,13 +36,13 @@ maybe_get_level, peek_interpreter_stack, ) +from torch._logging import trace_structured from torch.utils._mode_utils import no_dispatch from torch.utils._python_dispatch import is_traceable_wrapper_subclass from torch.utils.weak import WeakIdKeyDictionary if TYPE_CHECKING: - from torch._C._autograd import CreationMeta from torch._C._functorch import CInterpreter from torch._guards import Source @@ -142,6 +145,9 @@ def is_sparse_any(t): MetaTensorId: TypeAlias = int +DESCRIBER_NEXT_ID = 0 + + class MetaTensorDescriber: """ Given a Tensor/Storage, generate a MetaTensorDesc/MetaStorageDesc @@ -154,6 +160,9 @@ class MetaTensorDescriber: """ def __init__(self, *, copy_data=False): + global DESCRIBER_NEXT_ID + self.id = DESCRIBER_NEXT_ID + DESCRIBER_NEXT_ID += 1 self.next_tensor_id: MetaTensorId = 0 self.next_storage_id: MetaStorageId = 0 # Tensor -> int @@ -161,6 +170,8 @@ def __init__(self, *, copy_data=False): # Storage -> int self.lookup_storage = WeakIdKeyDictionary() self.copy_data = copy_data + self.traced_tensors = set() + self.traced_storages = set() def get_tensor_id(self, t: torch.Tensor): if t not in self.lookup_tensor: @@ -174,19 +185,25 @@ def get_storage_id(self, s: torch.UntypedStorage): self.next_storage_id += 1 return self.lookup_storage[s] - # NB: the describe functions NOT maintain a cache and will happily regen the - # description - - def describe_storage(self, s: torch.UntypedStorage): - return MetaStorageDesc( + def describe_storage(self, s: torch.UntypedStorage, *, trace: bool = False): + r = MetaStorageDesc( id=self.get_storage_id(s), size=s.size(), # NB: We don't do the copy yet; copy happens when we start # creating the new storages data=s if self.copy_data else None, ) + if trace and r.id not in self.traced_storages: + trace_structured( + "describe_storage", + metadata_fn=lambda: r.as_json(self.id), + ) + self.traced_storages.add(r.id) + return r - def describe_tensor(self, t: torch.Tensor, recurse: bool = True): + def describe_tensor( + self, t: torch.Tensor, *, recurse: bool = True, trace: bool = False + ): is_leaf = safe_is_leaf(t) is_view = t._is_view() is_sparse = t.is_sparse @@ -218,7 +235,7 @@ def describe_tensor(self, t: torch.Tensor, recurse: bool = True): ): # NB: We actually don't use storage to do views, but might as well # put it in for accuracy - storage = self.describe_storage(t.untyped_storage()) + storage = self.describe_storage(t.untyped_storage(), trace=trace) storage_offset = t.storage_offset() stride = None @@ -239,7 +256,7 @@ def describe_tensor(self, t: torch.Tensor, recurse: bool = True): autograd_meta_from = None current_level = None if is_batchedtensor_v or is_gradtrackingtensor_v: - unwrapped = self.describe_tensor(get_unwrapped(t)) + unwrapped = self.describe_tensor(get_unwrapped(t), trace=trace) # xla and lazy tensors present as functional tensors, but we want them # to be handled specially elif is_functional and t.device.type not in ("xla", "lazy"): @@ -249,13 +266,15 @@ def describe_tensor(self, t: torch.Tensor, recurse: bool = True): ) if not is_functorch_wrapped: torch._sync(t) - unwrapped = self.describe_tensor(torch._from_functional_tensor(t)) + unwrapped = self.describe_tensor( + torch._from_functional_tensor(t), trace=trace + ) autograd_meta_from = t else: reapply_views = torch._C._functionalization_reapply_views_tls() # NB: has side effects! unwrapped = self.describe_tensor( - _unwrap_functional_tensor(t, reapply_views) + _unwrap_functional_tensor(t, reapply_views), trace=trace ) # TODO: It's pretty suspicious that functional tensors don't have # valid level and thus we just grab whatever the current level @@ -273,12 +292,15 @@ def describe_tensor(self, t: torch.Tensor, recurse: bool = True): if is_traceable_wrapper_subclass_v: assert hasattr(t, "__tensor_flatten__") raw_attrs, ctx = t.__tensor_flatten__() - attrs = {attr: self.describe_tensor(getattr(t, attr)) for attr in raw_attrs} + attrs = { + attr: self.describe_tensor(getattr(t, attr), trace=trace) + for attr in raw_attrs + } type_v = type(t) # TODO: Is it important to enable torch.inference_mode before querying # these values? - return MetaTensorDesc( + r = MetaTensorDesc( id=self.get_tensor_id(t), storage=storage, is_inference=t.is_inference(), @@ -318,22 +340,30 @@ def describe_tensor(self, t: torch.Tensor, recurse: bool = True): # TODO: I actually think recursing here is correct, but we have at # least an infinite cycle from base -> values -> base # https://github.com/pytorch/pytorch/issues/122089 - crow_indices=self.describe_tensor(t.crow_indices(), recurse=False) + crow_indices=self.describe_tensor( + t.crow_indices(), recurse=False, trace=trace + ) if recurse and t.layout in {torch.sparse_csr, torch.sparse_bsr} else None, - col_indices=self.describe_tensor(t.col_indices(), recurse=False) + col_indices=self.describe_tensor( + t.col_indices(), recurse=False, trace=trace + ) if recurse and t.layout in {torch.sparse_csr, torch.sparse_bsr} else None, - ccol_indices=self.describe_tensor(t.ccol_indices(), recurse=False) + ccol_indices=self.describe_tensor( + t.ccol_indices(), recurse=False, trace=trace + ) if recurse and t.layout in {torch.sparse_csc, torch.sparse_bsc} else None, - row_indices=self.describe_tensor(t.row_indices(), recurse=False) + row_indices=self.describe_tensor( + t.row_indices(), recurse=False, trace=trace + ) if recurse and t.layout in {torch.sparse_csc, torch.sparse_bsc} else None, - values=self.describe_tensor(t.values(), recurse=False) + values=self.describe_tensor(t.values(), recurse=False, trace=trace) if recurse and is_sparse_compressed(t) else None, - grad=self.describe_tensor(safe_grad(t)) + grad=self.describe_tensor(safe_grad(t), trace=trace) if safe_grad(t) is not None else None, creation_meta=torch._C._autograd._get_creation_meta(t) @@ -344,7 +374,7 @@ def describe_tensor(self, t: torch.Tensor, recurse: bool = True): if is_batchedtensor_v or is_gradtrackingtensor_v else None, bdim=maybe_get_bdim(t) if is_batchedtensor_v else None, - base=self.describe_tensor(t._base) + base=self.describe_tensor(t._base, trace=trace) if recurse and t._is_view() and t._base is not None else None, fake_mode=torch._subclasses.fake_tensor.maybe_get_fake_mode(t), @@ -360,6 +390,13 @@ def describe_tensor(self, t: torch.Tensor, recurse: bool = True): current_level=current_level, data=t if self.copy_data else None, ) + if trace and r.id not in self.traced_tensors: + trace_structured( + "describe_tensor", + metadata_fn=lambda: r.as_json(self.id), + ) + self.traced_tensors.add(r.id) + return r @dataclass(frozen=True) @@ -370,43 +407,58 @@ class MetaStorageDesc: # serializable in JSON, you want to do something special here anyway data: Optional[torch.UntypedStorage] + def as_json(self, describer_id): + return { + "id": self.id, + "describer_id": describer_id, + "size": self.size if isinstance(self.size, int) else repr(self.size), + } + @dataclass(frozen=True) class MetaTensorDesc: id: MetaTensorId - is_inference: bool - is_leaf: bool - requires_grad: bool ndim: int dtype: torch.dtype - is_sparse: bool - is_mkldnn: bool - is_functorch_wrapped: bool - is_batchedtensor: bool - is_legacy_batchedtensor: bool - is_gradtrackingtensor: bool - is_view: bool - is_nested: bool - is_traceable_wrapper_subclass: bool - is_functional: bool - is_conj: bool - is_neg: bool device: torch.device - layout: torch.layout + # NB: Sometimes, size, stride and storage_offset contain SymInt, in which # case this is NOT serializable. That only happens when you're # re-fakeifying a fake tensor with an existing ShapeEnv... maybe we # can get rid of this use case entirely. Notably, even if we are # fakeifying a real tensor into a fake tensor with symbolic shapes, the # size here is NOT dynamic + # NB: These also contain SymInt because wrap_meta_outputs_with_default_device_logic + # goes through this codepath. But it really should not LOL. # NB: size could potentially be None as you can override it and make it # throw an error, but we don't currently have any subclasses that do this # except C++ nested tensor but we're going to have nested int to make this # defined on NJT size: Tuple[int, ...] dynamo_dynamic_indices: List[int] + + layout: torch.layout = torch.strided + is_inference: bool = False + is_leaf: bool = False + requires_grad: bool = False + is_sparse: bool = False + is_mkldnn: bool = False + is_functorch_wrapped: bool = False + is_batchedtensor: bool = False + is_legacy_batchedtensor: bool = False + is_gradtrackingtensor: bool = False + is_view: bool = False + is_nested: bool = False + is_traceable_wrapper_subclass: bool = False + is_functional: bool = False + is_conj: bool = False + is_neg: bool = False stride: Optional[Tuple[int, ...]] = None storage_offset: int = 0 + # NB: We have a choice whether or not to store the id or a direct pointer + # to the data structure. For ease of use, we store the data structure, + # but this means that when we serialize, we have to swizzle these pointers + # back into ids (so we have accurate aliasing relationships) storage: Optional[MetaStorageDesc] = None sparse_dim: Optional[int] = None # is_sparse, is_sparse_compressed dense_dim: Optional[int] = None # is_sparse, is_sparse_compressed @@ -424,6 +476,19 @@ class MetaTensorDesc: grad: Optional[MetaTensorDesc] = None # Everything below is NOT serializable, need some more work + + _UNSERIALIZABLE: ClassVar[List[str]] = [ + "ctx", + "type", + "fake_mode", + "view_func", + "level", + "current_level", + "functorch_stack", + "autograd_meta_from", + "data", + ] + ctx: Optional[object] = None # is_traceable_wrapper_subclass type: Optional[Type] = None # is_traceable_wrapper_subclass fake_mode: Optional[FakeTensorMode] = None @@ -459,6 +524,44 @@ class MetaTensorDesc: # entirely clear how to make it all lexical again, so we haven't done # it for now. + # NB: This will reference numeric IDs, and it is assumed that you've + # already serialized everything this recursively references + def as_json(self, describer_id): + def json(k, v): + # Some best-effort debugging serialization for unserializable + # fields (feel free to add other special cases as appropriate) + if k in ["data", "autograd_meta_from"]: + return None # never repr these + if k in set(self._UNSERIALIZABLE): + return repr(v) + if isinstance(v, (torch.device, torch.dtype, torch.layout)): + return repr(v) + if isinstance(v, torch.SymInt): + return repr(v) + if isinstance(v, (tuple, list)): + return [json(k, v1) for v1 in v] + if isinstance(v, (MetaStorageDesc, MetaTensorDesc)): + return v.id + if isinstance(v, CreationMeta): + return str(v) + if k == "attrs" and isinstance(v, dict): + return {k1: v1.id for k1, v1 in v.items()} + return v + + r = { + field.name: json(field.name, getattr(self, field.name)) + for field in dataclasses.fields(self) + if not ( + getattr(self, field.name) is field.default + or ( + field.name == "dynamo_dynamic_indices" + and not getattr(self, field.name) + ) + ) + } + r.update({"describer_id": describer_id}) + return r + @property def shape(self): return self.size @@ -887,9 +990,10 @@ def symint_visitor_fn(s): def tensor_visitor_fn( visited_t: torch.Tensor, + # These arguments are never passed, we just use them to close + # over these relevant values shape_env=shape_env, callback=callback, - source=source, ): # It's possible to close over an undefined tensor (e.g. NJT's lengths). if visited_t is None: @@ -1443,6 +1547,10 @@ def __call__( callback=lambda t: t(), source=None, symbolic_context=None, + # Controls whether or not we should dump the tensor metadata to structured logs + # when source is not None. Because we refakify after Dynamo is done, + # we don't want to dump info again from AOTAutograd, it is redundant. + trace=True, ): # TODO: zero tensors? We appear to have eliminated them by # excluding complex for now @@ -1475,9 +1583,22 @@ def __call__( # non-Tensor types don't count as hit or miss return t + if source is None: + trace = False + # Describe the tensor. NB: do NOT disable ambient modes, we may need # to query them when figuring out what to put in here - t_desc = self.describer.describe_tensor(t) + t_desc = self.describer.describe_tensor(t, trace=trace) + + if trace: + trace_structured( + "describe_source", + metadata_fn=lambda: { + "describer_id": self.describer.id, + "id": t_desc.id, + "source": source.name(), + }, + ) # Do the meta-fication. Here, we disable all the ambient modes, to # better simulate what would be like to re-fakeify from a fresh From 8629f9b3f2d82025ad1652ba9852a1a05b5add7a Mon Sep 17 00:00:00 2001 From: cyy Date: Fri, 31 May 2024 03:39:45 +0000 Subject: [PATCH 145/706] Remove more unused variables in tests (#127510) Follows #127379 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127510 Approved by: https://github.com/Skylion007, https://github.com/r-barnes --- test/cpp/api/CMakeLists.txt | 3 --- test/cpp/api/serialize.cpp | 1 - test/cpp/tensorexpr/test_llvm.cpp | 1 - test/cpp/tensorexpr/test_reductions.cpp | 2 ++ torch/csrc/api/include/torch/data/dataloader/stateful.h | 6 ++---- torch/csrc/jit/api/module.h | 4 +--- 6 files changed, 5 insertions(+), 12 deletions(-) diff --git a/test/cpp/api/CMakeLists.txt b/test/cpp/api/CMakeLists.txt index b0e296ad2309..ceeb607d52a7 100644 --- a/test/cpp/api/CMakeLists.txt +++ b/test/cpp/api/CMakeLists.txt @@ -51,9 +51,6 @@ endif() add_executable(test_api ${TORCH_API_TEST_SOURCES}) target_include_directories(test_api PRIVATE ${ATen_CPU_INCLUDE}) target_link_libraries(test_api PRIVATE torch gtest) -if(NOT MSVC) - target_compile_options_if_supported(test_api -Wno-unused-variable) -endif() if(USE_CUDA) target_compile_definitions(test_api PRIVATE "USE_CUDA") diff --git a/test/cpp/api/serialize.cpp b/test/cpp/api/serialize.cpp index 1b61499c2a75..9d4d381742e1 100644 --- a/test/cpp/api/serialize.cpp +++ b/test/cpp/api/serialize.cpp @@ -806,7 +806,6 @@ TEST(SerializeTest, Optim_RMSprop) { for (const auto i : c10::irange(params1_2_.size())) { if (i != (params1_2_.size() - 1)) { auto key_ = params_[i].unsafeGetTensorImpl(); - auto key1_2_ = params1_2_[i].unsafeGetTensorImpl(); const RMSpropParamState& curr_state_ = static_cast(*(optim1_state.at(key_).get())); RMSpropParamState& curr_state1_2_ = diff --git a/test/cpp/tensorexpr/test_llvm.cpp b/test/cpp/tensorexpr/test_llvm.cpp index d469a7dfa21b..aa578a4956c6 100644 --- a/test/cpp/tensorexpr/test_llvm.cpp +++ b/test/cpp/tensorexpr/test_llvm.cpp @@ -1474,7 +1474,6 @@ TEST(LLVM, RFactorReduction) { TEST(LLVM, RFactorVectorizedReduction) { int M = 128; int N = 64; - const int kTotalSize = M * N; BufHandle a("a", {1, M, N}, kFloat); diff --git a/test/cpp/tensorexpr/test_reductions.cpp b/test/cpp/tensorexpr/test_reductions.cpp index 6a6a94c82e59..d65b5c544f6c 100644 --- a/test/cpp/tensorexpr/test_reductions.cpp +++ b/test/cpp/tensorexpr/test_reductions.cpp @@ -1092,6 +1092,7 @@ TEST(Reductions, ReduceOverSplitRfactor) { // Check the IR to verify the rfactored reduce is eliminated. // TODO: The alloc free should be eliminated here since it is size 0. + /* const std::string& verification_pattern = R"IR( # CHECK: Allocate(tmp_buf); // dtype=float, dims=[0] @@ -1102,6 +1103,7 @@ TEST(Reductions, ReduceOverSplitRfactor) { # CHECK: } # CHECK: } # CHECK: Free(tmp_buf);)IR"; + */ // TODO: rfactor output is not consistent yet, will fix (@nickg). // torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } diff --git a/torch/csrc/api/include/torch/data/dataloader/stateful.h b/torch/csrc/api/include/torch/data/dataloader/stateful.h index e8eb85861f77..22d584ce4a00 100644 --- a/torch/csrc/api/include/torch/data/dataloader/stateful.h +++ b/torch/csrc/api/include/torch/data/dataloader/stateful.h @@ -36,10 +36,8 @@ class StatefulDataLoader : public DataLoaderBase< /// Constructs the `StatefulDataLoader` from a `dataset` and some `options`. StatefulDataLoader(Dataset dataset, DataLoaderOptions options) - : super( - std::move(options), - std::make_unique(std::move(dataset))) { - for (const auto w : c10::irange(this->options_.workers)) { + : super(options, std::make_unique(std::move(dataset))) { + for ([[maybe_unused]] const auto _ : c10::irange(this->options_.workers)) { // As opposed to the stateless case, here all worker threads access the // same underlying dataset. this->workers_.emplace_back( diff --git a/torch/csrc/jit/api/module.h b/torch/csrc/jit/api/module.h index e779542e315f..92b9c96c3a6e 100644 --- a/torch/csrc/jit/api/module.h +++ b/torch/csrc/jit/api/module.h @@ -541,9 +541,7 @@ struct slot_list_impl { size_t size() const { if (!size_) { size_ = size_t(0); - // NOLINTNEXTLINE(clang-diagnostic-unused-variable) - for (const value_type& s : *(this)) { - (void)s; // Suppress unused variable warning + for ([[maybe_unused]] const value_type& _ : *(this)) { ++*size_; } } From f264745ff1a87ddd7628e64f20dc144a3872eef7 Mon Sep 17 00:00:00 2001 From: Menglu Yu Date: Fri, 31 May 2024 03:54:43 +0000 Subject: [PATCH 146/706] [interformer] batch pointwise op + unbind stack pass in post grad (#126959) Summary: Tested on H100 with single GPU, and the bs is set to 64. Test Plan: # local script ``` buck2 run mode/opt scripts/jackiexu0313/pt2:uniarch_perf_benchmark -- single-module-benchmark --provider interformer --enable_pt2 True --batch_size 64 ``` baseline: P1370993922 | Metric | Value | |:-------------------|:-------------| | Latency | 120.84 ms | | Model size | 5.93 G bytes | | Flops/example | 62.22 GB | | TFLOPS | 32.95 | | MFU | 4.12% | | Activation/example | 128.17 MB | proposal: P1371676068 config ``` torch._inductor.config.pre_grad_fusion_options = {} torch._inductor.config.post_grad_fusion_options = { "batch_aten_mul": {"min_fuse_set_size": 50}, "batch_aten_sigmoid": {"min_fuse_set_size": 50}, "batch_aten_relu": {"min_fuse_set_size": 50}, "batch_linear_post_grad": {"min_fuse_set_size": 50}, "unbind_stack_aten_pass": {}, } ``` | Metric | Value | |:-------------------|:-------------| | Latency | 117.30 ms | | Model size | 5.93 G bytes | | Flops/example | 62.65 GB | | TFLOPS | 34.18 | | MFU | 4.27% | | Activation/example | 163.12 MB | Differential Revision: D57595173 Pull Request resolved: https://github.com/pytorch/pytorch/pull/126959 Approved by: https://github.com/jackiexu1992 --- test/inductor/test_group_batch_fusion.py | 120 ++++++++++ .../_inductor/fx_passes/group_batch_fusion.py | 216 +++++++++++++++--- torch/_inductor/fx_passes/split_cat.py | 107 +++++++++ 3 files changed, 416 insertions(+), 27 deletions(-) diff --git a/test/inductor/test_group_batch_fusion.py b/test/inductor/test_group_batch_fusion.py index b203a0f63e8b..96255c54147e 100644 --- a/test/inductor/test_group_batch_fusion.py +++ b/test/inductor/test_group_batch_fusion.py @@ -2,6 +2,7 @@ import collections import unittest +from typing import List import torch import torch._inductor @@ -22,6 +23,37 @@ requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") +class TestHighwaySelfGating(torch.nn.Module): + def __init__( + self, + d_model: int, + size: int, + device="cuda", + ) -> None: + super().__init__() + self.size = size + self.device = device + self.gating_proj = torch.nn.Linear(d_model, d_model).to(self.device) + self.transform_proj = torch.nn.Linear(d_model, d_model).to(self.device) + self.gating_func = torch.nn.Sigmoid().to(self.device) + + self.d_model = d_model + + def forward( + self, + inputs: List[torch.Tensor], + ) -> torch.Tensor: + results = [] + for i in range(self.size): + x = inputs[i] + gating_proj = self.gating_proj(x) + transform_proj = self.transform_proj(x) + x = gating_proj * self.gating_func(transform_proj) + results.append(x) + + return torch.cat(results, dim=-1) + + class MyModule(torch.nn.Module): def __init__(self, z: int, has_bias: bool, device="cuda") -> None: super().__init__() @@ -221,6 +253,25 @@ def forward(self, x): return torch.cat(div, dim=1) +class TestPoitwiseOpsPostGrad(torch.nn.Module): + def __init__(self, device): + super().__init__() + self.device = device + + def forward(self, x): + inputs = torch.ops.aten.split(x.to(self.device), 500, dim=1) + x_split = torch.ops.aten.split(inputs[0].to(self.device), 50, dim=1) + y_split = torch.ops.aten.split(inputs[1].to(self.device), 50, dim=1) + tanh_1 = [torch.ops.aten.tanh(x_split[i]) for i in range(len(x_split))] + tanh_2 = [torch.ops.aten.tanh(y_split[i]) for i in range(len(y_split))] + sigmoid_1 = [torch.ops.aten.sigmoid(tanh_1[i]) for i in range(len(tanh_1))] + sigmoid_2 = [torch.ops.aten.sigmoid(tanh_2[i]) for i in range(len(tanh_2))] + relu_1 = [torch.ops.aten.relu(sigmoid_1[i]) for i in range(len(sigmoid_1))] + relu_2 = [torch.ops.aten.relu(sigmoid_2[i]) for i in range(len(sigmoid_2))] + add = [torch.ops.aten.add(relu_1[i], relu_2[i]) for i in range(len(relu_1))] + return torch.cat(add, dim=1) + + @requires_cuda @torch._inductor.config.patch( pre_grad_fusion_options={ @@ -400,6 +451,75 @@ def test_pointwise_op_fusion(self): self.compare_gradients(module, traced, rtol=1e-8, atol=1e-8) counters.clear() + @requires_cuda + @torch._inductor.config.patch( + pre_grad_fusion_options={}, + post_grad_fusion_options={ + "batch_aten_relu": {}, + "batch_aten_sigmoid": {}, + "batch_aten_tanh": {}, + "unbind_stack_aten_pass": {}, + }, + ) + def test_pointwise_op_fusion_post_grad(self): + counters.clear() + module = TestPoitwiseOpsPostGrad("cuda") + input = [torch.randn(50, 1000, requires_grad=True, device="cuda")] + traced = torch.compile(module) + ref = module(*input) + res = traced(*input) + self.compare_pred(module, traced, input) + self.assertEqual(counters["inductor"]["batch_aten_tanh"], 1) + self.assertEqual(counters["inductor"]["batch_aten_relu"], 1) + self.assertEqual(counters["inductor"]["batch_aten_sigmoid"], 1) + self.assertEqual(counters["inductor"]["unbind_stack_aten_pass"], 2) + ref.sum().backward() + res.sum().backward() + self.compare_parameters(module, traced, rtol=1e-8, atol=1e-8) + self.compare_gradients(module, traced, rtol=1e-8, atol=1e-8) + counters.clear() + + @requires_cuda + @torch._inductor.config.patch( + pre_grad_fusion_options={}, + post_grad_fusion_options={ + "batch_linear_post_grad": { + "shape_broadcast_batch_linear": True, + "fuse_nodes_with_same_users": True, + }, + "batch_aten_mul": {"fuse_nodes_with_same_parent": False}, + "batch_aten_sigmoid": {"fuse_nodes_with_same_parent": True}, + "batch_aten_add": {"fuse_nodes_with_same_parent": True}, + "normalization_aten_pass": {}, + "unbind_stack_aten_pass": {}, + }, + ) + def test_gate_fusion_post_grad(self): + counters.clear() + size = 20 + module = TestHighwaySelfGating(d_model=10, size=size) + input = [ + [ + torch.randn(10, 10, requires_grad=True, device="cuda") + for i in range(size) + ] + ] + traced = torch.compile(module) + ref = module(*input) + res = traced(*input) + self.compare_pred(module, traced, input) + self.assertEqual(counters["inductor"]["batch_linear_post_grad"], 2) + self.assertEqual(counters["inductor"]["batch_aten_sigmoid"], 1) + self.assertEqual(counters["inductor"]["batch_aten_mul"], 1) + self.assertEqual(counters["inductor"]["batch_aten_add"], 2) + self.assertEqual(counters["inductor"]["normalization_aten_pass"], 1) + self.assertEqual(counters["inductor"]["unbind_stack_aten_pass"], 5) + ref.sum().backward() + res.sum().backward() + self.compare_parameters(module, traced, rtol=1e-8, atol=1e-8) + self.compare_gradients(module, traced, rtol=1e-8, atol=1e-8) + counters.clear() + class TestBMMFusionModule(torch.nn.Module): def __init__(self): diff --git a/torch/_inductor/fx_passes/group_batch_fusion.py b/torch/_inductor/fx_passes/group_batch_fusion.py index 90c59a06bab7..92449268bcec 100644 --- a/torch/_inductor/fx_passes/group_batch_fusion.py +++ b/torch/_inductor/fx_passes/group_batch_fusion.py @@ -44,6 +44,12 @@ MAX_FUSE_SEARCH_DEPTH = 5 # The maximum tensor size that can go into the fusion group MAX_FUSE_TENSOR_SIZE_GROUP_LINEAR = 4096 +# Whether we only fuse nodes with same parent node +FUSE_NODES_WITH_SAME_PARENT = False +# Whether we enable the add broadcast in batch linear +SHAPE_BROADCAST_BATCH_LINEAR = False +# Whether we enable the fuse nodes with same users +Fuse_NODES_WITH_SAME_USERS = False # exclude these nodes from BFS # excluding get item improves optimizer compilation time by 60s @@ -55,6 +61,9 @@ "max_fuse_set_size": MAX_FUSE_SET_SIZE, "max_fuse_search_depth": MAX_FUSE_SEARCH_DEPTH, "max_fuse_tensor_size_group_linear": MAX_FUSE_TENSOR_SIZE_GROUP_LINEAR, + "fuse_nodes_with_same_parent": FUSE_NODES_WITH_SAME_PARENT, + "shape_broadcast_batch_linear": SHAPE_BROADCAST_BATCH_LINEAR, + "fuse_nodes_with_same_users": Fuse_NODES_WITH_SAME_USERS, } graph_search_options = default_graph_search_options @@ -125,14 +134,18 @@ def list_group_batch_fusions(pre_grad=True) -> List[str]: def decompose_stack(graph: torch.fx.GraphModule, input_tensors: List[Any]) -> Any: unsqueezed_inputs = [] + unsqueezed_inputs_meta = [] for input_tensor in input_tensors: unsqueezed_input = graph.call_function( aten.unsqueeze, args=(input_tensor,), kwargs={"dim": 0} ) unsqueezed_inputs.append(unsqueezed_input) + unsqueezed_input.meta["val"] = aten.unsqueeze(input_tensor.meta["val"], dim=0) # type: ignore[assignment] + unsqueezed_inputs_meta.append(unsqueezed_input.meta["val"]) stacked_inputs = graph.call_function( aten.cat, args=(unsqueezed_inputs,), kwargs={"dim": 0} ) + stacked_inputs.meta["val"] = aten.cat(unsqueezed_inputs_meta, dim=0) # type: ignore[assignment] return stacked_inputs @@ -165,19 +178,22 @@ class PostGradBatchLinearFusion(BatchFusion): """ def _addmm_node_can_be_fused(self, node: torch.fx.Node) -> bool: + # pyre-fixme[7]: Incompatible return type return ( node.kwargs.get("beta", 1.0) == 1.0 and node.kwargs.get("alpha", 1.0) == 1.0 # type: ignore[return-value] ) def _is_input_2d(self, input: torch.fx.Node) -> bool: - input_shapes = input.meta["tensor_meta"].shape + input_shapes = input.meta["val"].shape return ( len(input_shapes) == 2 and isinstance(input_shapes[0], int) and isinstance(input_shapes[1], int) ) - def match(self, node: torch.fx.Node) -> Optional[Tuple[str, int, int, int, bool]]: + def match( + self, node: torch.fx.Node + ) -> Optional[Tuple[str, int, int, int, bool, str]]: if CallFunctionVarArgs(aten.mm).match(node): input_m, weight_m = node.args bias_m = None @@ -188,13 +204,17 @@ def match(self, node: torch.fx.Node) -> Optional[Tuple[str, int, int, int, bool] bias_m, input_m, weight_m = node.args else: return None - + # get the user of the node + if self.graph_search_options.get("fuse_nodes_with_same_users", False): + users = [user.target for user in node.users.keys()] + else: + users = "" # type: ignore[assignment] # only handle the cases where inputs are 2D tensors if not self._is_input_2d(input_m) or not self._is_input_2d(weight_m): # type: ignore[arg-type] return None - m, k = input_m.meta["tensor_meta"].shape # type: ignore[union-attr] - n = weight_m.meta["tensor_meta"].shape[1] # type: ignore[union-attr] - batch_key = ("batch_linear_post_grad", m, k, n, bias_m is not None) + m, k = input_m.meta["val"].shape # type: ignore[union-attr] + n = weight_m.meta["val"].shape[1] # type: ignore[union-attr] + batch_key = ("batch_linear_post_grad", m, k, n, bias_m is not None, str(users)) return batch_key def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]): @@ -202,6 +222,9 @@ def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]): batch_weights = [] batch_biases = [] batch_nodes = [] + batch_inputs_meta = [] + batch_weights_meta = [] + batch_biases_meta = [] for node in subset: if CallFunctionVarArgs(aten.addmm.default).match(node): @@ -213,24 +236,62 @@ def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]): batch_inputs.append(input) # type: ignore[possibly-undefined] batch_weights.append(weight) # type: ignore[possibly-undefined] batch_biases.append(bias) # type: ignore[possibly-undefined] + batch_inputs_meta.append(input.meta) # type: ignore[possibly-undefined, union-attr] + batch_weights_meta.append(weight.meta) # type: ignore[possibly-undefined, union-attr] + if bias is not None: # type: ignore[possibly-undefined] + batch_biases_meta.append(bias.meta) # type: ignore[possibly-undefined, union-attr] + else: + batch_biases_meta.append(None) with graph.inserting_before(subset[-1]): fused_inputs = decompose_stack(graph, batch_inputs) fused_weights = decompose_stack(graph, batch_weights) + fused_inputs_meta_val = torch.stack( + [input["val"] for input in batch_inputs_meta] + ) + fused_weights_meta_val = torch.stack( + [weight["val"] for weight in batch_weights_meta] + ) fused_bmm = graph.call_function( aten.bmm, args=(fused_inputs, fused_weights), ) - + fused_bmm.meta["val"] = aten.bmm( + fused_inputs_meta_val, fused_weights_meta_val + ) for i, original_mm in enumerate(batch_nodes): has_bias = False with graph.inserting_after(fused_bmm): new_mm = graph.call_function(aten.select, args=((fused_bmm, 0, i))) + new_mm.meta["val"] = aten.select(fused_bmm.meta["val"], 0, i) if batch_biases[i]: has_bias = True - new_bias_add = graph.call_function( - aten.add, args=((batch_biases[i], new_mm)) - ) + # broadcast the bias to the same shape as the mm output + if self.graph_search_options.get( + "shape_broadcast_batch_linear", False + ): + broadcast_shape = torch.broadcast_shapes( + batch_biases_meta[i]["val"].shape, new_mm.meta["val"].shape + ) + broadcast_bias = graph.call_function( + aten.broadcast_to.default, + args=(batch_biases[i],), + kwargs={"size": broadcast_shape}, + ) + broadcast_bias.meta["val"] = aten.broadcast_to(batch_biases_meta[i]["val"], broadcast_shape) # type: ignore[assignment] + new_bias_add = graph.call_function( + aten.add.Tensor, args=((broadcast_bias, new_mm)) + ) + new_bias_add.meta["val"] = aten.add.Tensor( + broadcast_bias.meta["val"], new_mm.meta["val"] + ) + else: + new_bias_add = graph.call_function( + aten.add, args=((batch_biases[i], new_mm)) + ) + new_bias_add.meta["val"] = aten.add.Tensor( + batch_biases_meta[i]["val"], new_mm.meta["val"] + ) new_mm_cont = new_bias_add if has_bias else new_mm # type: ignore[possibly-undefined] original_mm.replace_all_uses_with(new_mm_cont) new_mm_cont.meta.update(original_mm.meta) @@ -241,8 +302,8 @@ def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]): @register_fusion("group_linear", pre_grad=False) class GroupLinearFusion(GroupFusion): def _addmm_node_can_be_fused(self, node: torch.fx.Node): - input_shape = node.args[1].meta["tensor_meta"].shape # type: ignore[union-attr] - weight_shape = node.args[2].meta["tensor_meta"].shape # type: ignore[union-attr] + input_shape = node.args[1].meta["val"].shape # type: ignore[union-attr] + weight_shape = node.args[2].meta["val"].shape # type: ignore[union-attr] return ( node.kwargs.get("beta", 1.0) == 1.0 and node.kwargs.get("alpha", 1.0) == 1.0 @@ -256,8 +317,8 @@ def _addmm_node_can_be_fused(self, node: torch.fx.Node): ) def _mm_node_can_be_fused(self, node: torch.fx.Node): - input_shape = node.args[0].meta["tensor_meta"].shape # type: ignore[union-attr] - weight_shape = node.args[1].meta["tensor_meta"].shape # type: ignore[union-attr] + input_shape = node.args[0].meta["val"].shape # type: ignore[union-attr] + weight_shape = node.args[1].meta["val"].shape # type: ignore[union-attr] return ( len(input_shape) == 2 and len(weight_shape) == 2 @@ -319,9 +380,9 @@ def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]): counters["inductor"]["group_linear"] += 1 -class BatchPointwiseOpsPostGradFusion(BatchPointwiseOpsFusionFactory): +class BatchPointwiseMathOpsPostGradFusion(BatchPointwiseOpsFusionFactory): """ - Batch pointwise operator (e.g., add, mul) in post grad pass. + Batch pointwise math operator (e.g., add, mul) in post grad pass. """ def __init__(self, op, **kwargs): @@ -336,11 +397,11 @@ def _pointwise_node_can_be_fused(self, node: torch.fx.Node): # its inputs, and cause dtype not same error in mm or addmm input, other = node.args return ( - input.meta["tensor_meta"].shape == other.meta["tensor_meta"].shape # type: ignore[union-attr] + input.meta["val"].shape == other.meta["val"].shape # type: ignore[union-attr] if hasattr(input, "meta") and hasattr(other, "meta") - and "tensor_meta" in input.meta # type: ignore[union-attr] - and "tensor_meta" in other.meta # type: ignore[union-attr] + and "val" in input.meta # type: ignore[union-attr] + and "val" in other.meta # type: ignore[union-attr] else False ) @@ -351,14 +412,30 @@ def match(self, node: torch.fx.Node): alpha = node.kwargs.get("alpha", 1.0) rounding_mode = node.kwargs.get("rounding_mode", None) input, other = node.args - shape = list(input.meta["tensor_meta"].shape) # type: ignore[union-attr] + shape = list(input.meta["val"].shape) # type: ignore[union-attr] + if self.graph_search_options.get("fuse_nodes_with_same_parent", False): + # only consider the linear case so far + # pyre-fixme[16] + if input.target == aten.select or other.target == aten.select: # type: ignore[union-attr] + parent = ( + # pyre-fixme[16] + input.args[0] # type: ignore[union-attr] + # pyre-fixme[16] + if input.target == aten.select # type: ignore[union-attr] + else other.args[0] # type: ignore[union-attr] + ) + else: + parent = "" + else: + parent = "" group_key = ( "batch_aten_" + self.op.__name__.lower().split(".")[0], str(shape), - str(input.meta["tensor_meta"].dtype), # type: ignore[union-attr] - str(other.meta["tensor_meta"].dtype), # type: ignore[union-attr] + str(input.meta["val"].dtype), # type: ignore[union-attr] + str(other.meta["val"].dtype), # type: ignore[union-attr] str(alpha), str(rounding_mode), + str(parent), ) else: group_key = None @@ -367,21 +444,31 @@ def match(self, node: torch.fx.Node): def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]): batch_inputs, batch_others = [], [] alpha = subset[0].kwargs.get("alpha", 1.0) + batch_inputs_meta, batch_others_meta = [], [] for node in subset: input, other = node.args batch_inputs.append(input) batch_others.append(other) + batch_inputs_meta.append(input.meta) # type: ignore[possibly-undefined, union-attr] + batch_others_meta.append(other.meta) # type: ignore[possibly-undefined, union-attr] with graph.inserting_before(subset[0]): stack_inputs = decompose_stack(graph, batch_inputs) stack_others = decompose_stack(graph, batch_others) + stack_inputs_meta = torch.stack( + [input["val"] for input in batch_inputs_meta] + ) + stack_others_meta = torch.stack( + [other["val"] for other in batch_others_meta] + ) batch_op = graph.call_function( self.op, args=(stack_inputs, stack_others), kwargs={"alpha": alpha} if self.op == aten.add.Tensor else {}, ) + batch_op.meta["val"] = self.op(stack_inputs_meta, stack_others_meta) for i, original_add in enumerate(subset): with graph.inserting_after(batch_op): new_add = graph.call_function( @@ -475,7 +562,7 @@ def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]): def is_node_meta_valid(node: Optional[torch.fx.Node]): if node is None: return True - if "example_value" not in node.meta: + if "example_value" not in node.meta and "val" not in node.meta: return False return True @@ -810,6 +897,63 @@ def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]): counters["inductor"]["batch_" + self.op.__name__.lower().split(".")[0]] += 1 +class BatchPointwiseOpsPostGradFusion(BatchPointwiseOpsFusionFactory): + """ + Batch pointwise ops (e.g., sigmoid, relu, tanh) fusion in post grad pass. + The introduced stack node may be merged in split cat. + """ + + def __init__(self, op, **kwargs): + super().__init__(op, **kwargs) + self.op = op + + def match(self, node: torch.fx.Node): + input = get_arg_value(node, 0, "input") + if CallFunctionVarArgs(self.op).match(node) and is_node_meta_valid(node): + # for relu op, we also use the inplace to construct the key + # we batch the ops with same parent to enable followup split cat + parent = node.args[0] + parent = parent.target if self.graph_search_options.get("fuse_nodes_with_same_parent", False) else "" # type: ignore[union-attr] + group_key = ( + "batch_aten_" + self.op.__name__.lower().split(".")[0], + str(input.meta["val"].shape), + str(node.kwargs.get("inplace", False)), + # pyre-fixme[16] + str(parent), + ) + else: + group_key = None + return group_key + + def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]): + batch_nodes = [] + batch_inputs = [] + batch_inputs_metadata = [] + + for node in subset: + batch_nodes.append(node) + input = get_arg_value(node, 0, "input") + batch_inputs.append(input) + batch_inputs_metadata.append(input.meta["val"]) + + with graph.inserting_before(subset[0]): + stack_inputs = decompose_stack(graph, batch_inputs) + update_stack_example_value(stack_inputs, batch_inputs_metadata) + batch_op = graph.call_function( + self.op, + args=(stack_inputs,), + ) + for i, node in enumerate(batch_nodes): + with graph.inserting_after(batch_op): + getitem = graph.call_function(aten.select, args=(batch_op, 0, i)) + node.replace_all_uses_with(getitem) + getitem.meta.update(node.meta) + graph.erase_node(node) + counters["inductor"][ + "batch_aten_" + self.op.__name__.lower().split(".")[0] + ] += 1 + + @register_fusion("batch_tanh") class BatchTanhPreGradFusion(BatchPointwiseOpsPreGradFusion): def __init__(self, **kwargs): @@ -828,26 +972,44 @@ def __init__(self, **kwargs): super().__init__(torch.nn.functional.relu, **kwargs) +@register_fusion("batch_aten_tanh", pre_grad=False) +class BatchTanhPostGradFusion(BatchPointwiseOpsPostGradFusion): + def __init__(self, **kwargs): + super().__init__(aten.tanh.default, **kwargs) + + +@register_fusion("batch_aten_sigmoid", pre_grad=False) +class BatchSigmoidPostGradFusion(BatchPointwiseOpsPostGradFusion): + def __init__(self, **kwargs): + super().__init__(aten.sigmoid.default, **kwargs) + + +@register_fusion("batch_aten_relu", pre_grad=False) +class BatchReLuPostGradFusion(BatchPointwiseOpsPostGradFusion): + def __init__(self, **kwargs): + super().__init__(aten.relu.default, **kwargs) + + @register_fusion("batch_aten_add", pre_grad=False) -class BatchAddPostGradFusion(BatchPointwiseOpsPostGradFusion): +class BatchAddPostGradFusion(BatchPointwiseMathOpsPostGradFusion): def __init__(self, **kwargs): super().__init__(aten.add.Tensor, **kwargs) @register_fusion("batch_aten_sub", pre_grad=False) -class BatchSubPostGradFusion(BatchPointwiseOpsPostGradFusion): +class BatchSubPostGradFusion(BatchPointwiseMathOpsPostGradFusion): def __init__(self, **kwargs): super().__init__(aten.sub.Tensor, **kwargs) @register_fusion("batch_aten_div", pre_grad=False) -class BatchDivPostGradFusion(BatchPointwiseOpsPostGradFusion): +class BatchDivPostGradFusion(BatchPointwiseMathOpsPostGradFusion): def __init__(self, **kwargs): super().__init__(aten.div.Tensor, **kwargs) @register_fusion("batch_aten_mul", pre_grad=False) -class BatchMulPostGradFusion(BatchPointwiseOpsPostGradFusion): +class BatchMulPostGradFusion(BatchPointwiseMathOpsPostGradFusion): def __init__(self, **kwargs): super().__init__(aten.mul.Tensor, **kwargs) diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py index ad6adf748dd2..563804f2471a 100644 --- a/torch/_inductor/fx_passes/split_cat.py +++ b/torch/_inductor/fx_passes/split_cat.py @@ -54,7 +54,9 @@ ] post_grad_pass_names = [ + "normalization_aten_pass", "decompose_mm_pass", + "unbind_stack_aten_pass", ] for pass_name in pre_grad_pass_names: @@ -1609,3 +1611,108 @@ def merge_stack_tahn_unbind(match: Match, split_sections: List[int], dim: int): split_sections = new_split_sections counters["inductor"]["merge_stack_tahn_unbind_pass"] += 1 + + +@register_graph_pattern( + CallFunctionVarArgs(torch.ops.aten.cat.default, users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_aten_pass"), +) +def normalize_cat_default_aten(match: Match, *args, **kwargs): + cat_node = match.nodes[0] + graph = match.graph + tensors = get_arg_value(cat_node, 0, "tensors") + cat_dim = get_arg_value(cat_node, 1, "dim") + if cat_dim is None: + cat_axis = cat_node.kwargs.get("axis") + if cat_axis is not None: + cat_dim = cat_axis + else: + cat_dim = 0 + if tensors is None or cat_dim is None: + log.info("couldn't find cat args") + return + assert isinstance(tensors, (list, tuple)) + for tensor in itertools.chain([cat_node], tensors): + if "val" not in tensor.meta: + log.warning("val absent for node: %s", tensor) + return + + ndim = cat_node.meta["val"].dim() + + def is_empty_tensor(x: torch.fx.Node) -> bool: + # special case where torch.ops.aten.cat.default supports cat'ing with an empty tensor + x_shape = x.meta["val"].shape + return len(x_shape) == 1 and x_shape[0] == 0 + + assert all(ndim == x.meta["val"].dim() or is_empty_tensor(x) for x in tensors) + + if cat_dim < 0: # Normalize cat dim + cat_dim += ndim + + with graph.inserting_after(cat_node): + new_cat_node = graph.call_function( + torch.ops.aten.cat.default, + args=(tensors,), + kwargs={"dim": cat_dim}, + ) + cat_node.replace_all_uses_with(new_cat_node) + new_cat_node.meta.update(cat_node.meta) + graph.erase_node(cat_node) + counters["inductor"]["normalization_aten_pass"] += 1 + + +@register_graph_pattern( + CallFunction( + torch.ops.aten.cat, + ListOf(CallFunctionVarArgs(torch.ops.aten.unsqueeze)), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("unbind_stack_aten_pass"), +) +def merge_unbind_stack_aten(match: Match, *args, **kwargs): + node = match.nodes[-1] + graph = match.graph + # pyre-fixme[6] + unsqueeze_nodes = list(node.args[0]) # type: ignore[arg-type] + cat_dim = get_arg_value(node, 1, "dim") + # check the unsqueeze nodes come from the select nodes + if not all( + get_arg_value(unsqueeze_node, 0, "input").target == torch.ops.aten.select + for unsqueeze_node in unsqueeze_nodes + ): + return + select_nodes = [ + get_arg_value(unsqueeze_node, 0, "input") for unsqueeze_node in unsqueeze_nodes + ] + parent_of_select_node = get_arg_value(select_nodes[0], 0, "input") + # check the target of select_nodes are the same + if not all( + select_node.target == torch.ops.aten.select for select_node in select_nodes + ): + return + # check the select nodes come from the same parent node + if not all( + get_arg_value(select_node, 0, "input") == parent_of_select_node + for select_node in select_nodes + ): + return + if len(unsqueeze_nodes) != len(select_nodes): + return + # check the select nodes have the same dim + if not all( + get_arg_value(select_node, 1, "dim") == cat_dim for select_node in select_nodes + ): + return + # check the select nodes have consecutive indices starting from 0 + if get_arg_value(select_nodes[0], 2, "index") != 0 or not is_sorted_and_consecutive( + [get_arg_value(select_node, 2, "index") for select_node in select_nodes] + ): + return + node.replace_all_uses_with(parent_of_select_node) + graph.erase_node(node) + for unsqueeze_node in unsqueeze_nodes: + graph.erase_node(unsqueeze_node) + for select_node in select_nodes: + if len(select_node.users) == 0: + graph.erase_node(select_node) + counters["inductor"]["unbind_stack_aten_pass"] += 1 From bb6bfd9ad849945676e5fff030c5f3923e017036 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Thu, 30 May 2024 10:40:02 -0700 Subject: [PATCH 147/706] [dynamo][compile-time] Cache the child guard managers (#127377) Reduces compile time of MobileBertForMaskedLM model from 39 seconds to 26 seconds. This was a regression introduced by #125202. Before that PR, compile time was 24 seconds. The extra two seconds is just because we are going through enormous number of guards. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127377 Approved by: https://github.com/jansel --- torch/_dynamo/guards.py | 86 +++++++++++++++++++++++------------------ 1 file changed, 49 insertions(+), 37 deletions(-) diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index a8fc77b92c11..ac46b4df0f38 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -512,6 +512,11 @@ def __init__( # limit the number of cache entries with same ID_MATCH'd object. self.id_matched_objs: Dict[str, ReferenceType[object]] = {} + # Save the guard managers to avoid repeatedly traversing sources. + self._cached_guard_managers: Dict[ + str, torch._C._dynamo.guards.GuardManager + ] = {} + def guard_on_dict_keys_and_ignore_order(self, example_value, guard): dict_mgr = self.get_guard_manager(guard) if isinstance(dict_mgr, DictGuardManager): @@ -758,6 +763,10 @@ def get_guard_manager_from_source(self, source): example_value = None source_name = source.name() + + if source_name != "" and source_name in self._cached_guard_managers: + return self._cached_guard_managers[source_name] + if source_name != "": example_value = self.get(source_name) @@ -781,7 +790,7 @@ def get_guard_manager_from_source(self, source): # RootGuardManager accepts a dict but still its not a # DictGuardManager because we will eventually move to # fastlocals. - return root_guard_manager.dict_getitem_manager( + out = root_guard_manager.dict_getitem_manager( key=source.local_name, source=source_name, example_value=example_value, @@ -791,14 +800,14 @@ def get_guard_manager_from_source(self, source): # Global manager accepts a dict but it is not a DictGuardManager # because globals dict is big and we typically guard on a very # selected items on globals. - return self.get_global_guard_manager().dict_getitem_manager( + out = self.get_global_guard_manager().dict_getitem_manager( key=source.global_name, source=source_name, example_value=example_value, guard_manager_enum=guard_manager_enum, ) elif istype(source, GlobalWeakRefSource): - return self.get_global_guard_manager().global_weakref_manager( + out = self.get_global_guard_manager().global_weakref_manager( global_name=source.global_name, source=source_name, example_value=example_value, @@ -812,7 +821,7 @@ def get_guard_manager_from_source(self, source): return root_guard_manager elif istype(source, TypeSource): assert base_guard_manager # to make mypy happy - return base_guard_manager.type_manager( + out = base_guard_manager.type_manager( source=source_name, example_value=example_value, guard_manager_enum=guard_manager_enum, @@ -822,10 +831,10 @@ def get_guard_manager_from_source(self, source): (OptimizerSource, NNModuleSource, NotNNModuleSource, FSDPNNModuleSource), ): assert base_guard_manager # to make mypy happy - return base_guard_manager + out = base_guard_manager elif istype(source, GradSource): assert base_guard_manager # to make mypy happy - return base_guard_manager.grad_manager( + out = base_guard_manager.grad_manager( source=source_name, example_value=example_value, guard_manager_enum=guard_manager_enum, @@ -834,7 +843,7 @@ def get_guard_manager_from_source(self, source): assert base_guard_manager # to make mypy happy if isinstance(base_example_value, torch.nn.Module): - return self.getattr_on_nn_module( + out = self.getattr_on_nn_module( source, base_guard_manager, base_example_value, @@ -843,13 +852,13 @@ def get_guard_manager_from_source(self, source): source_name, guard_manager_enum, ) - - return base_guard_manager.getattr_manager( - attr=source.member, - source=source_name, - example_value=example_value, - guard_manager_enum=guard_manager_enum, - ) + else: + out = base_guard_manager.getattr_manager( + attr=source.member, + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) elif istype(source, GetItemSource): assert base_guard_manager # to make mypy happy if isinstance(base_example_value, (dict, collections.OrderedDict)): @@ -858,7 +867,7 @@ def get_guard_manager_from_source(self, source): # dicts) so that GetItemSource is only for non dict objects. if isinstance(base_guard_manager, DictGuardManager): assert self.manager_guards_on_keys(base_guard_manager_enum) - return getitem_on_dict_manager( + out = getitem_on_dict_manager( source, base_guard_manager, base_example_value, @@ -871,40 +880,40 @@ def get_guard_manager_from_source(self, source): "Expecting clean index here. Likely Dynamo forgot to mark" " a dict as guard_on_key_order" ) - return base_guard_manager.dict_getitem_manager( + out = base_guard_manager.dict_getitem_manager( key=source.index, source=source_name, example_value=example_value, guard_manager_enum=guard_manager_enum, ) elif isinstance(base_example_value, list) and not source.index_is_slice: - return base_guard_manager.list_getitem_manager( + out = base_guard_manager.list_getitem_manager( key=source.index, source=source_name, example_value=example_value, guard_manager_enum=guard_manager_enum, ) elif isinstance(base_example_value, tuple) and not source.index_is_slice: - return base_guard_manager.tuple_getitem_manager( + out = base_guard_manager.tuple_getitem_manager( key=source.index, source=source_name, example_value=example_value, guard_manager_enum=guard_manager_enum, ) - - index = source.index - if source.index_is_slice: - index = source.unpack_slice() - return base_guard_manager.getitem_manager( - key=index, - source=source_name, - example_value=example_value, - guard_manager_enum=guard_manager_enum, - ) + else: + index = source.index + if source.index_is_slice: + index = source.unpack_slice() + out = base_guard_manager.getitem_manager( + key=index, + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) elif istype(source, ODictGetItemSource): if isinstance(base_guard_manager, DictGuardManager): assert self.manager_guards_on_keys(base_guard_manager_enum) - return getitem_on_dict_manager( + out = getitem_on_dict_manager( source, base_guard_manager, base_example_value, @@ -913,7 +922,7 @@ def get_guard_manager_from_source(self, source): ) else: assert base_guard_manager # to make mypy happy - return base_guard_manager.dict_getitem_manager( + out = base_guard_manager.dict_getitem_manager( key=source.index, source=source_name, example_value=example_value, @@ -923,7 +932,7 @@ def get_guard_manager_from_source(self, source): assert base_guard_manager # to make mypy happy assert callable(base_example_value) if not source.is_kw: - return base_guard_manager.func_defaults_manager( + out = base_guard_manager.func_defaults_manager( source=base_source_name, example_value=base_example_value.__defaults__, guard_manager_enum=GuardManagerType.GUARD_MANAGER, @@ -947,7 +956,7 @@ def get_guard_manager_from_source(self, source): ) assert not isinstance(dict_mgr, DictGuardManager) - return dict_mgr.dict_getitem_manager( + out = dict_mgr.dict_getitem_manager( key=source.idx_key, source=source_name, example_value=example_value, @@ -955,7 +964,7 @@ def get_guard_manager_from_source(self, source): ) elif istype(source, NumpyTensorSource): assert base_guard_manager # to make mypy happy - return base_guard_manager.lambda_manager( + out = base_guard_manager.lambda_manager( python_lambda=from_numpy, source=source_name, example_value=example_value, @@ -963,7 +972,7 @@ def get_guard_manager_from_source(self, source): ) elif istype(source, FlattenScriptObjectSource): assert base_guard_manager # to make mypy happy - return base_guard_manager.lambda_manager( + out = base_guard_manager.lambda_manager( python_lambda=lambda x: x.__obj_flatten__(), source=source_name, example_value=example_value, @@ -971,7 +980,7 @@ def get_guard_manager_from_source(self, source): ) elif istype(source, ScriptObjectQualifiedNameSource): assert base_guard_manager # to make mypy happy - return base_guard_manager.lambda_manager( + out = base_guard_manager.lambda_manager( python_lambda=lambda x: x._type().qualified_name(), source=source_name, example_value=example_value, @@ -979,7 +988,7 @@ def get_guard_manager_from_source(self, source): ) elif istype(source, TupleIteratorGetItemSource): assert base_guard_manager # to make mypy happy - return base_guard_manager.tuple_iterator_getitem_manager( + out = base_guard_manager.tuple_iterator_getitem_manager( index=source.index, source=source_name, example_value=example_value, @@ -990,7 +999,7 @@ def get_guard_manager_from_source(self, source): raise AssertionError( "ConstDictKeySource can only work on DictGuardManager" ) - return base_guard_manager.get_key_manager( + out = base_guard_manager.get_key_manager( index=source.index, source=source_name, example_value=example_value, @@ -1001,6 +1010,9 @@ def get_guard_manager_from_source(self, source): f"missing guard manager builder {source} - {source.name()}" ) + self._cached_guard_managers[source.name()] = out + return out + def get_guard_manager(self, guard: Guard): return self.get_guard_manager_from_source(guard.originating_source) From 159632aecd2c189e5042d4654eddb04d70d14aab Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Thu, 30 May 2024 10:40:02 -0700 Subject: [PATCH 148/706] [dynamo] Support hasattr on BuiltinVariable (#127372) Fixes https://github.com/pytorch/pytorch/issues/127172 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127372 Approved by: https://github.com/williamwen42, https://github.com/yanboliang ghstack dependencies: #127377 --- test/dynamo/test_repros.py | 15 +++++++++++++++ torch/_dynamo/variables/builtin.py | 2 ++ 2 files changed, 17 insertions(+) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index ae317a78d96f..4b151d8b093e 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -5026,6 +5026,21 @@ def fn(x): opt_fn = torch.compile(fn, backend="eager", fullgraph=True) self.assertEqual(fn(x), opt_fn(x)) + def test_hasattr_builtin(self): + class MyClass: + foo: int = 1 + + def func(x, m): + if getattr(type(m), "foo", 0): + return x + MyClass.foo + return x + + opt_func = torch.compile(func, backend="eager", fullgraph=True) + m = MyClass() + x = torch.zeros(()) + self.assertEqual(func(x, m), opt_func(x, m)) + self.assertEqual(func(x, 0), opt_func(x, 0)) + instantiate_parametrized_tests(ReproTests) diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index a9f6272d0571..5603496193e2 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1441,6 +1441,8 @@ def call_next(self, tx, arg: VariableTracker): def call_hasattr(self, tx, obj, attr): if attr.is_python_constant(): name = attr.as_python_constant() + if isinstance(obj, variables.BuiltinVariable): + return variables.ConstantVariable(hasattr(obj.fn, name)) return obj.call_hasattr(tx, name) def call_map(self, tx, fn, seq): From ee08cf57924a4230edad3101666890d8fe050c75 Mon Sep 17 00:00:00 2001 From: cyy Date: Fri, 31 May 2024 04:27:20 +0000 Subject: [PATCH 149/706] Improve MAGMA conditional macro in BatchLinearAlgebra.cpp (#127495) Unnecessary TORCH_CHECK(false) are changed to macro coverage as mentioned in #127371 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127495 Approved by: https://github.com/ezyang --- .../native/cuda/linalg/BatchLinearAlgebra.cpp | 213 ++++++++---------- 1 file changed, 98 insertions(+), 115 deletions(-) diff --git a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp index 2122d2af5f6a..18a1316fb567 100644 --- a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp @@ -40,8 +40,6 @@ #include #include -const bool use_magma_ = true; - namespace { struct MagmaInitializer { MagmaInitializer() { @@ -61,9 +59,6 @@ struct MagmaInitializer { #error "MAGMA release minor or micro version >= 10, please correct AT_MAGMA_VERSION" #endif -#else -const bool use_magma_ = false; - #endif namespace at::native { @@ -84,9 +79,9 @@ void magmaLdlHermitian( magma_int_t ldda, magma_int_t* ipiv, magma_int_t* info) { - TORCH_CHECK( - false, - "LDL decomposition is not available.", + static_assert( + false&&sizeof(scalar_t), + "LDL decomposition is not available." "Please rebuild with MAGMA 2.5.4+."); } @@ -1034,18 +1029,13 @@ magma_trans_t to_magma(TransposeType trans) { namespace { +#if AT_MAGMA_ENABLED() template void apply_ldl_factor_magma( const Tensor& A, const Tensor& pivots, const Tensor& info, bool upper) { -#if !AT_MAGMA_ENABLED() - TORCH_CHECK( - false, - "torch.linalg.ldl_factor: MAGMA library not found in " - "compilation. Please rebuild with MAGMA."); -#else auto batch_size = batchCount(A); magma_int_t n = magma_int_cast(A.size(-2), "A.size(-2)"); magma_int_t leading_dim = magma_int_cast(A.stride(-1), "A.stride(-1)"); @@ -1076,7 +1066,6 @@ void apply_ldl_factor_magma( } pivots.copy_(pivots_cpu); info.copy_(info_cpu); -#endif } void ldl_factor_magma( @@ -1098,6 +1087,7 @@ void ldl_factor_magma( apply_ldl_factor_magma(LD, pivots, info, upper); }); } +#endif void ldl_factor_kernel( const Tensor& LD, @@ -1110,8 +1100,10 @@ void ldl_factor_kernel( case at::LinalgBackend::Cusolver: return ldl_factor_cusolver( LD, pivots, info, upper, hermitian); +#if AT_MAGMA_ENABLED() case at::LinalgBackend::Magma: return ldl_factor_magma(LD, pivots, info, upper, hermitian); +#endif default: // By default use cusolver if available and magma otherwise. // If cusolver and magma 2.5.4+ are both available and hermitian=true, @@ -1155,12 +1147,9 @@ REGISTER_CUDA_DISPATCH(ldl_solve_stub, &ldl_solve_kernel) // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ cholesky_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +#if AT_MAGMA_ENABLED() template static void apply_cholesky_solve(Tensor& b, Tensor& A, bool upper, int64_t& info) { -#if !AT_MAGMA_ENABLED() -AT_ERROR("cholesky_solve: MAGMA library not found in " - "compilation. Please rebuild with MAGMA."); -#else magma_uplo_t uplo = upper ? MagmaUpper : MagmaLower; auto A_data = A.data_ptr(); @@ -1179,8 +1168,8 @@ AT_ERROR("cholesky_solve: MAGMA library not found in " auto b_mat_stride = matrixStride(b); magma_int_t batch_size = magma_int_cast(batchCount(A), "batchCount"); - scalar_t** A_array; - scalar_t** b_array; + scalar_t** A_array = nullptr; + scalar_t** b_array = nullptr; ALLOCATE_ARRAY(A_array, scalar_t*, batch_size); ALLOCATE_ARRAY(b_array, scalar_t*, batch_size); @@ -1197,7 +1186,7 @@ AT_ERROR("cholesky_solve: MAGMA library not found in " // Compute as many batches of 65535 possible // The number of "mini"-batches are floor(batch_size / batch_limit) // and these cover floor(batch_size / batch_limit) * batch_limit matrix solves - int64_t mini_batches = batch_size / batch_limit, mini_idx; + int64_t mini_batches = batch_size / batch_limit, mini_idx = 0; for (mini_idx = 0; mini_idx < mini_batches * batch_limit; mini_idx += batch_limit) { scalar_t** A_array_cur = &A_array[mini_idx]; scalar_t** b_array_cur = &b_array[mini_idx]; @@ -1221,7 +1210,6 @@ AT_ERROR("cholesky_solve: MAGMA library not found in " info = info_tmp; } -#endif } Tensor _cholesky_solve_helper_cuda_magma(const Tensor& self, const Tensor& A, bool upper) { @@ -1234,6 +1222,7 @@ Tensor _cholesky_solve_helper_cuda_magma(const Tensor& self, const Tensor& A, bo TORCH_CHECK(info == 0, "MAGMA cholesky_solve : invalid argument: ", -info); return self_working_copy; } +#endif // Todo: cusolverDnpotrsBatched only supports nrhs == 1 and does not have good performance. // Batched cholesky_solve is dispatched to magma. @@ -1243,14 +1232,20 @@ Tensor _cholesky_solve_helper_cuda(const Tensor& self, const Tensor& A, bool upp switch (preferred_backend) { case at::LinalgBackend::Cusolver: return _cholesky_solve_helper_cuda_cusolver(self, A, upper); +#if AT_MAGMA_ENABLED() case at::LinalgBackend::Magma: return _cholesky_solve_helper_cuda_magma(self, A, upper); +#endif default: - if (batchCount(self) == 1 || !use_magma_) { +#if !AT_MAGMA_ENABLED() + return _cholesky_solve_helper_cuda_cusolver(self, A, upper); +#else + if (batchCount(self) == 1) { return _cholesky_solve_helper_cuda_cusolver(self, A, upper); } else { return _cholesky_solve_helper_cuda_magma(self, A, upper); } +#endif } #else return _cholesky_solve_helper_cuda_magma(self, A, upper); @@ -1259,14 +1254,9 @@ Tensor _cholesky_solve_helper_cuda(const Tensor& self, const Tensor& A, bool upp // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ cholesky ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +#if AT_MAGMA_ENABLED() template static void apply_cholesky(const Tensor& self, bool upper, const Tensor& info) { -#if !AT_MAGMA_ENABLED() - TORCH_CHECK( - false, - "Calling torch.linalg.cholesky on a CUDA tensor requires compiling ", - "PyTorch with MAGMA. Please use PyTorch built with MAGMA support."); -#else magma_uplo_t uplo = upper ? MagmaUpper : MagmaLower; auto self_data = self.data_ptr(); @@ -1288,7 +1278,7 @@ static void apply_cholesky(const Tensor& self, bool upper, const Tensor& info) { auto self_mat_stride = matrixStride(self); magma_int_t batch_size = magma_int_cast(batchCount(self), "batchCount"); - scalar_t** self_array; + scalar_t** self_array = nullptr; ALLOCATE_ARRAY(self_array, scalar_t*, batch_size); @@ -1314,7 +1304,6 @@ static void apply_cholesky(const Tensor& self, bool upper, const Tensor& info) { uplo, n, self_array_cur, lda, info_array_cur, nbatches, magma_queue); } } -#endif } void cholesky_helper_magma(const Tensor& input, bool upper, const Tensor& info) { @@ -1350,6 +1339,7 @@ void cholesky_helper_magma(const Tensor& input, bool upper, const Tensor& info) } } } +#endif static void cholesky_kernel(const Tensor& input, const Tensor& info, bool upper) { #if defined(USE_LINALG_SOLVER) && !defined(USE_ROCM) @@ -1358,15 +1348,21 @@ static void cholesky_kernel(const Tensor& input, const Tensor& info, bool upper) case at::LinalgBackend::Cusolver: cholesky_helper_cusolver(input, upper, info); break; +#if AT_MAGMA_ENABLED() case at::LinalgBackend::Magma: cholesky_helper_magma(input, upper, info); break; +#endif default: - if (batchCount(input) == 1 || !use_magma_ || use_cusolver_potrf_batched_) { +#if !AT_MAGMA_ENABLED() + cholesky_helper_cusolver(input, upper, info); +#else + if (batchCount(input) == 1 || use_cusolver_potrf_batched_) { cholesky_helper_cusolver(input, upper, info); } else { cholesky_helper_magma(input, upper, info); } +#endif } #else cholesky_helper_magma(input, upper, info); @@ -1384,11 +1380,9 @@ This is an in-place routine, content of 'input' is overwritten. MAGMA requires 'infos' to reside in CPU memory. For more information see MAGMA's documentation for POTRS routine. */ +#if AT_MAGMA_ENABLED() template static void apply_cholesky_inverse(Tensor& input, Tensor& infos, bool upper) { -#if !AT_MAGMA_ENABLED() - TORCH_CHECK(false, "cholesky_inverse: MAGMA library not found in compilation. Please rebuild with MAGMA."); -#else // magmaCholeskyInverse (magma_dpotri_gpu) is slow because internally // it transfers data several times between GPU and CPU and calls lapack routine on CPU // using magmaCholeskySolveBatched is a lot faster @@ -1418,7 +1412,6 @@ static void apply_cholesky_inverse(Tensor& input, Tensor& infos, bool upper) { int64_t info_tmp = 0; apply_cholesky_solve(result_u, input_u, upper, info_tmp); infos.fill_(info_tmp); -#endif } // This is a type dispatching helper function for 'apply_cholesky_inverse' @@ -1428,6 +1421,7 @@ Tensor& cholesky_inverse_kernel_impl_magma(Tensor &result, Tensor& infos, bool u }); return result; } +#endif Tensor& cholesky_inverse_kernel_impl(Tensor &result, Tensor& infos, bool upper) { // This function calculates the inverse matrix in-place @@ -1438,20 +1432,25 @@ Tensor& cholesky_inverse_kernel_impl(Tensor &result, Tensor& infos, bool upper) switch (preferred_backend) { case at::LinalgBackend::Cusolver: return cholesky_inverse_kernel_impl_cusolver(result, infos, upper); +#if AT_MAGMA_ENABLED() case at::LinalgBackend::Magma: return cholesky_inverse_kernel_impl_magma(result, infos, upper); +#endif default: - if (batchCount(result) == 1 || - !use_magma_) { +#if !AT_MAGMA_ENABLED() + return cholesky_inverse_kernel_impl_cusolver(result, infos, upper); +#else + if (batchCount(result) == 1) { return cholesky_inverse_kernel_impl_cusolver(result, infos, upper); } else { return cholesky_inverse_kernel_impl_magma(result, infos, upper); } + +#endif } #else return cholesky_inverse_kernel_impl_magma(result, infos, upper); #endif - } REGISTER_CUDA_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl); @@ -1526,14 +1525,9 @@ static void apply_lu_factor_looped_magma(const Tensor& input, const Tensor& pivo For further details, please see the MAGMA documentation for magma_dgetrf_batched. */ +#if AT_MAGMA_ENABLED() template static void apply_lu_factor_batched_magma(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) { -#if !AT_MAGMA_ENABLED() - TORCH_CHECK( - false, - "Calling linalg.lu_factor on a CUDA tensor requires compiling ", - "PyTorch with MAGMA. Please rebuild with MAGMA."); -#else // There is a bug in lu_factor_batched_magma in MAGMA < 2.5.2, see // https://bitbucket.org/icl/magma/issues/13/getrf_batched-kernel-produces-nans-on std::tuple version; @@ -1550,7 +1544,7 @@ static void apply_lu_factor_batched_magma(const Tensor& input, const Tensor& piv magma_int_t n = magma_int_cast(input.size(-1), "n"); auto leading_dimension = std::max(1, m); - scalar_t** input_array; + scalar_t** input_array = nullptr; ALLOCATE_ARRAY(input_array, scalar_t*, batch_size); // Set up array of pointers to matrices @@ -1570,7 +1564,7 @@ static void apply_lu_factor_batched_magma(const Tensor& input, const Tensor& piv // magmaLuBatched might not set the values for it // see https://github.com/pytorch/pytorch/pull/53064 pivots.fill_(1); - magma_int_t** pivots_array; + magma_int_t** pivots_array = nullptr; ALLOCATE_ARRAY(pivots_array, magma_int_t*, batch_size); for (int64_t i = 0; i < batch_size; i++) { pivots_array[i] = &pivots_data[i * pivots_stride]; @@ -1583,7 +1577,6 @@ static void apply_lu_factor_batched_magma(const Tensor& input, const Tensor& piv // block CPU until all operations on the queue are finished // this explicit sync prevents garbage results from the subsequent magmaLuSolveBatched call from a different queue magma_queue_sync(magma_queue.get_queue()); -#endif } static void lu_factor_looped_magma(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) { @@ -1597,6 +1590,7 @@ static void lu_factor_batched_magma(const Tensor& input, const Tensor& pivots, c apply_lu_factor_batched_magma(input, pivots, infos, compute_pivots); }); } +#endif static void lu_factor(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) { auto batch_size = batchCount(input); @@ -1604,6 +1598,7 @@ static void lu_factor(const Tensor& input, const Tensor& pivots, const Tensor& i auto m = input.size(-2); auto n = input.size(-1); +#if AT_MAGMA_ENABLED() const auto lu_factor_magma = [batch_size](const Tensor& input, const Tensor& pivots, const Tensor& infos, const bool compute_pivots) { if (batch_size == 1) { lu_factor_looped_magma(input, pivots, infos, compute_pivots); @@ -1611,6 +1606,7 @@ static void lu_factor(const Tensor& input, const Tensor& pivots, const Tensor& i lu_factor_batched_magma(input, pivots, infos, compute_pivots); } }; +#endif const auto preferred_backend = at::globalContext().linalgPreferredBackend(); #ifdef USE_LINALG_SOLVER @@ -1635,9 +1631,12 @@ static void lu_factor(const Tensor& input, const Tensor& pivots, const Tensor& i lu_factor_cusolver(input, pivots, infos, compute_pivots); } else #endif // ifdef USE_LINALG_SOLVER +#if AT_MAGMA_ENABLED() if (preferred_backend == at::LinalgBackend::Magma) { lu_factor_magma(input, pivots, infos, compute_pivots); - } else { // preferred backend == default + } else +#endif + { // preferred backend == default #ifdef USE_LINALG_SOLVER #if AT_MAGMA_ENABLED() // If magma batched is buggy, we use cusolver @@ -1701,8 +1700,8 @@ AT_ERROR("triangular_solve: MAGMA library not found in " auto A_mat_stride = matrixStride(A); auto b_mat_stride = matrixStride(b); - scalar_t** A_array; - scalar_t** b_array; + scalar_t** A_array = nullptr; + scalar_t** b_array = nullptr; ALLOCATE_ARRAY(A_array, scalar_t*, batch_size); ALLOCATE_ARRAY(b_array, scalar_t*, batch_size); @@ -1720,7 +1719,7 @@ AT_ERROR("triangular_solve: MAGMA library not found in " // The number of "mini"-batches are floor(batch_size / batch_limit) // and these cover floor(batch_size / batch_limit) * batch_limit matrix solves int64_t mini_batches = batch_size / batch_limit; - int64_t mini_idx; // this is outside the loop because it is used for the case batch_size % batch_limit != 0 + int64_t mini_idx = 0; // this is outside the loop because it is used for the case batch_size % batch_limit != 0 for (mini_idx = 0; mini_idx < mini_batches * batch_limit; mini_idx += batch_limit) { scalar_t** A_array_cur = &A_array[mini_idx]; scalar_t** b_array_cur = &b_array[mini_idx]; @@ -1777,7 +1776,7 @@ Tensor& orgqr_kernel_impl(Tensor& result, const Tensor& tau) { #ifdef USE_LINALG_SOLVER return orgqr_helper_cusolver(result, tau); // cusolver #else - TORCH_CHECK(false, "Calling torch.orgqr on a CUDA tensor requires compiling ", + static_assert(false, "Calling torch.orgqr on a CUDA tensor requires compiling ", "PyTorch with cuSOLVER. Please use PyTorch built with cuSOLVER support."); #endif } @@ -1788,8 +1787,8 @@ void ormqr_kernel(const Tensor& input, const Tensor& tau, const Tensor& other, b #ifdef USE_LINALG_SOLVER ormqr_cusolver(input, tau, other, left, transpose); #else - TORCH_CHECK(false, - "Calling torch.ormqr on a CUDA tensor requires compiling ", + static_assert(false, + "Calling torch.ormqr on a CUDA tensor requires compiling " "PyTorch with cuSOLVER. Please use PyTorch built with cuSOLVER support."); #endif } @@ -1798,15 +1797,9 @@ REGISTER_CUDA_DISPATCH(ormqr_stub, &ormqr_kernel); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ qr ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +#if AT_MAGMA_ENABLED() template static void apply_geqrf(const Tensor& input, const Tensor& tau) { -#if !AT_MAGMA_ENABLED() - TORCH_CHECK( - false, - "Calling torch.geqrf on a CUDA tensor requires compiling ", - "PyTorch with MAGMA. Please use PyTorch built with MAGMA support."); -#else - magma_int_t m = magma_int_cast(input.size(-2), "m"); magma_int_t n = magma_int_cast(input.size(-1), "n"); @@ -1833,7 +1826,6 @@ static void apply_geqrf(const Tensor& input, const Tensor& tau) { checkMagmaInternalError(info, "geqrf"); } tau.copy_(tau_cpu, /*non_blocking=*/true); -#endif } // This is a type dispatching helper function for 'apply_geqrf' @@ -1842,6 +1834,7 @@ void geqrf_magma(const Tensor& input, const Tensor& tau) { apply_geqrf(input, tau); }); } +#endif void geqrf_kernel(const Tensor& input, const Tensor& tau) { #ifdef USE_LINALG_SOLVER @@ -1867,8 +1860,10 @@ void geqrf_kernel(const Tensor& input, const Tensor& tau) { // - ?geqrf2_gpu gives correct R, but doesn't allow computation of Q via ?orgqr_gpu // Refer to the below link for more details: // http://icl.cs.utk.edu/magma/forum/viewtopic.php?f=2&t=1015&p=2800&hilit=geqrf_gpu#p2800 +#if AT_MAGMA_ENABLED() case at::LinalgBackend::Magma: return geqrf_magma(input, tau); +#endif case at::LinalgBackend::Cusolver: default: return geqrf_cusolver_backend(input, tau); @@ -1880,14 +1875,9 @@ void geqrf_kernel(const Tensor& input, const Tensor& tau) { REGISTER_CUDA_DISPATCH(geqrf_stub, &geqrf_kernel); +#if AT_MAGMA_ENABLED() template static void apply_magma_eigh(const Tensor& values, const Tensor& vectors, const Tensor& infos, bool upper, bool compute_eigenvectors) { -#if !AT_MAGMA_ENABLED() - TORCH_CHECK( - false, - "Calling torch.linalg.eigh/eigvalsh on a CUDA tensor requires compiling ", - "PyTorch with MAGMA. Please use PyTorch built with MAGMA support."); -#else TORCH_INTERNAL_ASSERT_DEBUG_ONLY(values.device() == kCPU); TORCH_INTERNAL_ASSERT_DEBUG_ONLY(infos.device() == kCPU); @@ -1907,7 +1897,7 @@ static void apply_magma_eigh(const Tensor& values, const Tensor& vectors, const auto values_data = values.data_ptr(); auto infos_data = infos.data_ptr(); - scalar_t* wA; + scalar_t* wA = nullptr; ALLOCATE_ARRAY(wA, scalar_t, lda * lda); // Run once, first to get the optimum work sizes. @@ -1917,14 +1907,14 @@ static void apply_magma_eigh(const Tensor& values, const Tensor& vectors, const magma_int_t lwork = -1; scalar_t wkopt; magma_int_t liwork = -1; - magma_int_t iwkopt; + magma_int_t iwkopt = -1; magma_int_t lrwork = -1; value_t rwkopt; magmaSyevd(jobz, uplo, n, vectors_data, lda, values_data, wA, lda, &wkopt, lwork, &rwkopt, lrwork, &iwkopt, liwork, infos_data); - scalar_t* work; - magma_int_t* iwork; + scalar_t* work = nullptr; + magma_int_t* iwork = nullptr; lwork = magma_int_cast(std::max(1, real_impl(wkopt)), "work_size"); liwork = magma_int_cast(std::max(1, iwkopt), "iwork_size"); ALLOCATE_ARRAY(work, scalar_t, lwork); @@ -1951,7 +1941,6 @@ static void apply_magma_eigh(const Tensor& values, const Tensor& vectors, const return; } } -#endif } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg_eigh ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -1990,14 +1979,17 @@ void linalg_eigh_magma(const Tensor& eigenvalues, const Tensor& eigenvectors, co eigenvalues.copy_(eigenvalues_cpu); } } +#endif void linalg_eigh_kernel(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors) { #if defined(USE_LINALG_SOLVER) auto preferred_backend = at::globalContext().linalgPreferredBackend(); switch (preferred_backend) { +#if AT_MAGMA_ENABLED() case at::LinalgBackend::Magma: linalg_eigh_magma(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors); break; +#endif case at::LinalgBackend::Cusolver: default: linalg_eigh_cusolver(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors); @@ -2017,12 +2009,9 @@ This is an in-place routine, content of 'input', 'values', 'vectors' is overwrit 'infos' is an int Tensor containing error codes for each matrix in the batched input. For more information see MAGMA's documentation for GEEV routine. */ +#if AT_MAGMA_ENABLED() template void apply_linalg_eig(Tensor& values, Tensor& vectors, Tensor& input, Tensor& infos, bool compute_eigenvectors) { -#if !AT_MAGMA_ENABLED() -TORCH_CHECK(false, "Calling torch.linalg.eig on a CUDA tensor requires compiling PyTorch with MAGMA. " - "Either transfer the tensor to the CPU before calling torch.linalg.eig or recompile with MAGMA."); -#else TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.device() == at::kCPU); TORCH_INTERNAL_ASSERT_DEBUG_ONLY(values.device() == at::kCPU); TORCH_INTERNAL_ASSERT_DEBUG_ONLY(infos.device() == at::kCPU); @@ -2072,7 +2061,6 @@ TORCH_CHECK(false, "Calling torch.linalg.eig on a CUDA tensor requires compiling magmaEig(jobvl, jobvr, n, input_working_ptr, lda, values_working_ptr, lvectors_data, ldvl, rvectors_working_ptr, ldvr, work_data, lwork, rwork_data, info_working_ptr); } -#endif } // This is a type dispatching helper function for 'apply_linalg_eig' @@ -2105,10 +2093,6 @@ static void apply_svd_magma(const Tensor& A, const Tensor& S, const Tensor& Vh, const Tensor& info) { -#if !AT_MAGMA_ENABLED() -AT_ERROR("linalg.svd: MAGMA library not found in " - "compilation. Please rebuild with MAGMA."); -#else using value_t = typename c10::scalar_value_type::type; const auto A_data = A.data_ptr(); const auto U_data = compute_uv ? U.data_ptr() : nullptr; @@ -2136,7 +2120,7 @@ AT_ERROR("linalg.svd: MAGMA library not found in " rwork = static_cast(storage_rwork.mutable_data()); } - magma_int_t* iwork; + magma_int_t* iwork = nullptr; ALLOCATE_ARRAY(iwork, magma_int_t, 8 * std::min(m, n)); // Query svd for the optimal lwork size @@ -2151,7 +2135,7 @@ AT_ERROR("linalg.svd: MAGMA library not found in " &wkopt, lwork, rwork, iwork, info_data); lwork = magma_int_cast(real_impl(wkopt), "work_size"); } - scalar_t* work; + scalar_t* work = nullptr; ALLOCATE_ARRAY(work, scalar_t, lwork); for (int64_t i = 0; i < batchsize; i++) { @@ -2164,7 +2148,6 @@ AT_ERROR("linalg.svd: MAGMA library not found in " work, lwork, rwork, iwork, info_data + i); } -#endif } void svd_magma(const Tensor& A, @@ -2206,6 +2189,7 @@ void svd_magma(const Tensor& A, S.copy_(S_, /*non_blocking*/true); info.copy_(info, /*non_blocking*/true); } +#endif void svd_kernel(const Tensor& A, const bool full_matrices, @@ -2217,10 +2201,13 @@ void svd_kernel(const Tensor& A, const Tensor& info) { #ifdef USE_LINALG_SOLVER // We always use cuSOLVER unless the user has specified they want to use MAGMA +#if AT_MAGMA_ENABLED() bool use_magma = at::globalContext().linalgPreferredBackend() == at::LinalgBackend::Magma; if (use_magma) { svd_magma(A, full_matrices, compute_uv, U, S, Vh, info); - } else { + } else +#endif + { // svd_cusolver computes V rather than Vh, so we pass a view of Vh.mT // and then conjugate Vh in-place svd_cusolver(A, full_matrices, compute_uv, driver, U, S, compute_uv ? Vh.mT() : Vh, info); @@ -2251,14 +2238,9 @@ REGISTER_CUDA_DISPATCH(svd_stub, &svd_kernel) For further details, please see the MAGMA documentation for magma_dgetrs_gpu. */ +#if AT_MAGMA_ENABLED() template static void apply_lu_solve_looped_magma(const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType transpose) { -#if !AT_MAGMA_ENABLED() - TORCH_CHECK( - false, - "Calling linalg.lu_solve on a CUDA tensor requires compiling ", - "PyTorch with MAGMA. Please rebuild with MAGMA."); -#else auto trans = to_magma(transpose); auto b_data = B.data_ptr(); auto lu_data = LU.data_ptr(); @@ -2296,7 +2278,6 @@ static void apply_lu_solve_looped_magma(const Tensor& LU, const Tensor& pivots, // so we don't need to check it all the time TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info == 0); } -#endif } /* @@ -2315,12 +2296,6 @@ static void apply_lu_solve_looped_magma(const Tensor& LU, const Tensor& pivots, */ template static void apply_lu_solve_batched_magma(const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType transpose) { -#if !AT_MAGMA_ENABLED() - TORCH_CHECK( - false, - "Calling linalg.lu_solve on a CUDA tensor requires compiling ", - "PyTorch with MAGMA. Please rebuild with MAGMA."); -#else TORCH_INTERNAL_ASSERT(batchCount(B) == batchCount(LU), "batch_size of LU and B must be the same"); TORCH_INTERNAL_ASSERT(batchCount(LU) == batchCount(pivots.unsqueeze(-1)), "batch_size of LU and pivots must be the same"); auto trans = to_magma(transpose); @@ -2338,9 +2313,9 @@ static void apply_lu_solve_batched_magma(const Tensor& LU, const Tensor& pivots, auto pivots_stride = pivots.size(-1); magma_int_t batch_size = magma_int_cast(batchCount(B), "batchCount"); - magma_int_t** pivots_array; - scalar_t** lu_array; - scalar_t** b_array; + magma_int_t** pivots_array = nullptr; + scalar_t** lu_array = nullptr; + scalar_t** b_array = nullptr; ALLOCATE_ARRAY(pivots_array, magma_int_t*, batch_size); ALLOCATE_ARRAY(lu_array, scalar_t*, batch_size); @@ -2364,7 +2339,7 @@ static void apply_lu_solve_batched_magma(const Tensor& LU, const Tensor& pivots, scalar_t** b_array_cur = &b_array[mini_idx]; magma_int_t** pivots_array_cur = &pivots_array[mini_idx]; - int info; + int info = -1; magmaLuSolveBatched( n, nrhs, lu_array_cur, leading_dimension, pivots_array_cur, b_array_cur, leading_dimension, @@ -2374,7 +2349,6 @@ static void apply_lu_solve_batched_magma(const Tensor& LU, const Tensor& pivots, // so we don't need to check it all the time TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info == 0); } -#endif } static void lu_solve_batched_magma(const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType trans) { @@ -2390,6 +2364,7 @@ static void lu_solve_looped_magma(const Tensor& LU, const Tensor& pivots, const apply_lu_solve_looped_magma(LU, pivots, B, trans); }); } +#endif c10::MaybeOwned maybe_expand_lu(const Tensor& B, const Tensor& LU) { // B and LU have the same number of dimensions @@ -2424,9 +2399,11 @@ static void lu_solve_kernel(const Tensor& LU, const Tensor& pivots, const Tensor auto b = batchCount(B); auto n = LU.size(-2); auto k = B.size(-1); +#if AT_MAGMA_ENABLED() // magma implementation of LU solve cannot handle a b tensor with last dim > 1024 // See https://bitbucket.org/icl/magma/issues/19/dgesv_batched-dgetrs_batched-fails-for bool over_batched_magma_dim_limit = k > 1024; +#endif // heuristics determined from tests discussed in https://github.com/pytorch/pytorch/pull/72935 // Computes X = U^{-1}L^{-1}P^T B via triangular solves @@ -2441,7 +2418,7 @@ static void lu_solve_kernel(const Tensor& LU, const Tensor& pivots, const Tensor .set_check_mem_overlap(false) .check_all_same_dtype(false) .resize_outputs(false) - .declare_static_shape(pivots_->sizes(), /*squash_dim=*/pivots_->dim() - 1) + .declare_static_shape(pivots_->sizes(), /*squash_dims=*/pivots_->dim() - 1) .add_output(perm) .add_const_input(*pivots_) .build(); @@ -2457,7 +2434,7 @@ static void lu_solve_kernel(const Tensor& LU, const Tensor& pivots, const Tensor // B1 = P^T @ B (must be done out-of-place as B is both source and target) auto B1 = B.scatter(-2, inv_perm.unsqueeze(-1).expand_as(B), B); // B = L^{-1} @ B1 - at::linalg_solve_triangular_out(const_cast(B), *LU_, std::move(B1), /*upper=*/false, /*left=*/true, /*unitriangular=*/true); + at::linalg_solve_triangular_out(const_cast(B), *LU_, B1, /*upper=*/false, /*left=*/true, /*unitriangular=*/true); // B = U^{-1} @ B at::linalg_solve_triangular_out(const_cast(B), *LU_, B, /*upper=*/true); } else { @@ -2479,11 +2456,13 @@ static void lu_solve_kernel(const Tensor& LU, const Tensor& pivots, const Tensor }; #endif +#if AT_MAGMA_ENABLED() auto lu_solve_batched_magma_fn = [](const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType trans) { auto LU_ = maybe_expand_lu(B, LU); auto pivots_ = maybe_expand_pivots(B, pivots); lu_solve_batched_magma(*LU_, *pivots_, B, trans); }; +#endif // Preferred Backend @@ -2498,6 +2477,7 @@ static void lu_solve_kernel(const Tensor& LU, const Tensor& pivots, const Tensor return; } else #endif // ifdef USE_LINALG_SOLVER +#if AT_MAGMA_ENABLED() if (preferred_backend == at::LinalgBackend::Magma) { // Looped magma is very slow, but batched magma is buggy in these two cases if (!over_batched_magma_dim_limit && trans == TransposeType::NoTranspose) { @@ -2508,6 +2488,7 @@ static void lu_solve_kernel(const Tensor& LU, const Tensor& pivots, const Tensor } return; } +#endif // Heuristic //if (n == k) { @@ -2548,9 +2529,12 @@ static void lu_solve_kernel(const Tensor& LU, const Tensor& pivots, const Tensor } if (n <= 8) { - if (use_magma_ && !over_batched_magma_dim_limit && trans == TransposeType::NoTranspose && k >= 256) { +#if AT_MAGMA_ENABLED() + if (!over_batched_magma_dim_limit && trans == TransposeType::NoTranspose && k >= 256) { lu_solve_batched_magma_fn(LU, pivots, B, trans); - } else { + } else +#endif + { lu_solve_batched_cublas_fn(LU, pivots, B, trans); } } else if (n <= 64) { @@ -2583,12 +2567,9 @@ REGISTER_CUDA_DISPATCH(lu_solve_stub, &lu_solve_kernel); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lstsq ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +#if AT_MAGMA_ENABLED() template static void apply_gels(const Tensor& a, Tensor& b, Tensor& infos) { -#if !AT_MAGMA_ENABLED() - TORCH_CHECK(false, "torch.linalg.lstsq: MAGMA library not found in " - "compilation. Please rebuild with MAGMA."); -#else auto trans = MagmaNoTrans; auto m = magma_int_cast(a.size(-2), "m"); auto n = magma_int_cast(a.size(-1), "n"); @@ -2618,7 +2599,6 @@ static void apply_gels(const Tensor& a, Tensor& b, Tensor& infos) { hwork_ptr, lwork, infos_working_ptr); } ); -#endif } void gels_magma(const Tensor& a, Tensor& b, Tensor& infos) { @@ -2626,6 +2606,7 @@ void gels_magma(const Tensor& a, Tensor& b, Tensor& infos) { apply_gels(a, b, infos); }); } +#endif void linalg_lstsq_gels(const Tensor& A, const Tensor& B, const Tensor& /*infos*/) { // The steps for using the QR decomposition for solving least squares problems @@ -2714,8 +2695,10 @@ void gels_looped(const Tensor& a, Tensor& b, Tensor& infos) { #if defined(USE_LINALG_SOLVER) && !defined(USE_ROCM) auto preferred_backend = at::globalContext().linalgPreferredBackend(); switch (preferred_backend) { +#if AT_MAGMA_ENABLED() case at::LinalgBackend::Magma: return gels_magma(a, b, infos); +#endif case at::LinalgBackend::Cusolver: default: // linalg_lstsq_gels is a generic function that is implemented using From ec098b88b692569ad7b086eb184e44b60d47174a Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Thu, 30 May 2024 14:16:44 -0700 Subject: [PATCH 150/706] [compiled autograd] torch.compile API (#125880) - enter existing compiled autograd ctx manager before entering torch.compile frames Pull Request resolved: https://github.com/pytorch/pytorch/pull/125880 Approved by: https://github.com/jansel --- test/inductor/test_compiled_autograd.py | 170 +++++++++++++++++++++++- torch/_dynamo/config.py | 4 + torch/_dynamo/eval_frame.py | 32 ++++- 3 files changed, 203 insertions(+), 3 deletions(-) diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index 87299d796f6c..2daacc308071 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -11,7 +11,7 @@ import torch import torch.nn as nn from torch import _inductor as inductor -from torch._dynamo import compiled_autograd +from torch._dynamo import compiled_autograd, config from torch._dynamo.utils import counters from torch._inductor.test_case import run_tests, TestCase from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA @@ -54,10 +54,14 @@ def hook3(gI, gO): class TestCompiledAutograd(TestCase): def setUp(self) -> None: super().setUp() + torch._logging.set_logs(compiled_autograd_verbose=False) + config.compiled_autograd = False compiled_autograd.reset() def tearDown(self) -> None: super().tearDown() + torch._logging.set_logs(compiled_autograd_verbose=False) + config.compiled_autograd = False compiled_autograd.reset() def check_output_and_recompiles( @@ -230,6 +234,170 @@ def fn(): self.check_output_and_recompiles(fn) + def test_torch_compile_api_inductor(self): + def fn(): + torch.manual_seed(123) + model = torch.nn.Sequential( + torch.nn.Linear(4, 4), + torch.nn.Sigmoid(), + ) + + res = [] + for _ in range(3): + x = torch.randn([1, 4]) + + result = model(x).sum() + result.backward() + res.append(model[0].weight.grad) + res.append(model[0].bias.grad) + model.zero_grad() + return res + + expected = fn() + with config.patch(compiled_autograd=True): + compiled_fn = torch.compile(fn) + actual = compiled_fn() + self.assertEqual(expected, actual) + self.assertEqual(counters["compiled_autograd"]["captures"], 1) + + def test_torch_compile_api_aot_eager(self): + def fn(): + torch.manual_seed(123) + model = torch.nn.Sequential( + torch.nn.Linear(4, 4), + torch.nn.Sigmoid(), + ) + + res = [] + for _ in range(3): + x = torch.randn([1, 4]) + + result = model(x).sum() + result.backward() + res.append(model[0].weight.grad) + res.append(model[0].bias.grad) + model.zero_grad() + return res + + expected = fn() + with config.patch(compiled_autograd=True): + compiled_fn = torch.compile(fn, backend="aot_eager") + actual = compiled_fn() + self.assertEqual(expected, actual) + self.assertEqual(counters["compiled_autograd"]["captures"], 1) + + def test_torch_compile_api_eager(self): + def fn(): + torch.manual_seed(123) + model = torch.nn.Sequential( + torch.nn.Linear(4, 4), + torch.nn.Sigmoid(), + ) + + res = [] + for _ in range(3): + x = torch.randn([1, 4]) + + result = model(x).sum() + result.backward() + res.append(model[0].weight.grad) + res.append(model[0].bias.grad) + model.zero_grad() + return res + + expected = fn() + with config.patch(compiled_autograd=True): + compiled_fn = torch.compile(fn, backend="eager") + actual = compiled_fn() + self.assertEqual(expected, actual) + self.assertEqual(counters["compiled_autograd"]["captures"], 1) + + def test_multiple_torch_compile(self): + model = torch.nn.Sequential( + torch.nn.Linear(4, 4), + torch.nn.Sigmoid(), + ) + x = torch.randn([1, 4]) + + def fn(): + result = model(x).sum() + result.backward() + + model2 = torch.nn.Linear(4, 4) + x2 = torch.randn([1, 4]) + + def fn2(): + result = model2(x2).sum() + result.backward() + + no_ca1 = torch.compile(fn) + no_ca1() + self.assertEqual(counters["compiled_autograd"]["captures"], 0) + counters.clear() + + with config.patch(compiled_autograd=True): + with_ca = torch.compile(fn2) + with_ca() + self.assertEqual(counters["compiled_autograd"]["captures"], 1) + counters.clear() + + no_ca2 = torch.compile(fn) + no_ca2() + self.assertEqual(counters["compiled_autograd"]["captures"], 0) + + def test_torch_compile_graph_break(self): + model = torch.nn.Sequential( + torch.nn.Linear(4, 4), + torch.nn.Sigmoid(), + ) + x = torch.randn([1, 4]) + + @torch._dynamo.disable() + def fn(): + result = model(x).sum() + result.backward() + + with config.patch(compiled_autograd=True): + opt_fn = torch.compile(fn) + opt_fn() + + self.assertEqual(counters["compiled_autograd"]["captures"], 1) + + def test_torch_compile_graph_break2(self): + model = torch.nn.Sequential( + torch.nn.Linear(4, 4), + torch.nn.Sigmoid(), + ) + x = torch.randn([1, 4]) + + @torch._dynamo.disable() + def inner_fn(loss): + loss.backward() + + def fn(): + result = model(x).sum() + inner_fn(result) + + with config.patch(compiled_autograd=True): + opt_fn = torch.compile(fn) + opt_fn() + + self.assertEqual(counters["compiled_autograd"]["captures"], 1) + + def test_torch_compile_only_backward_call(self): + model = torch.nn.Sequential( + torch.nn.Linear(4, 4), + torch.nn.Sigmoid(), + ) + x = torch.randn([1, 4]) + + result = model(x).sum() + with config.patch(compiled_autograd=True): + opt_bwd = torch.compile(lambda: result.backward()) + opt_bwd() + + self.assertEqual(counters["compiled_autograd"]["captures"], 1) + def test_dynamo_boxed(self): def get_placeholders(gm_): placeholders = [] diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index 6f4219a03b18..212021859c46 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -454,6 +454,10 @@ def default_debug_dir_root(): # WARNING: this is an experimental flag and is subject to change. _experimental_support_context_fn_in_torch_utils_checkpoint = False +# Enables the Compiled Autograd engine to trace .backward() calls made under torch.compile(). +# Note: AOT Autograd will still trace joint graphs. +compiled_autograd = False + if TYPE_CHECKING: from torch.utils._config_typing import * # noqa: F401, F403 diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index fa9311f2c18a..8a195664b403 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -493,6 +493,9 @@ def __init__( export=False, dynamic=None, compiler_config=None, + rebuild_ctx: Optional[ + Callable[[], Union[OptimizeContext, _NullDecorator]] + ] = None, ): def on_enter(): install_generation_tagging_init() @@ -508,6 +511,17 @@ def on_enter(): compiler_config=compiler_config, ) + if config.compiled_autograd: + + def call_compiled_autograd(): + assert rebuild_ctx is not None + compiler_fn = rebuild_ctx() + ctx = torch._dynamo.compiled_autograd.enable(compiler_fn) + ctx.__enter__() + return functools.partial(ctx.__exit__, None, None, None) + + self.enter_exit_hooks.append(call_compiled_autograd) + class RunOnlyContext(_TorchDynamoContext): def __init__(self): @@ -577,6 +591,7 @@ def _optimize_catch_errors( export=False, dynamic=None, compiler_config=None, + rebuild_ctx=None, ): return OptimizeContext( convert_frame.catch_errors_wrapper(compile_fn, hooks), @@ -585,6 +600,7 @@ def _optimize_catch_errors( export=export, dynamic=dynamic, compiler_config=compiler_config, + rebuild_ctx=rebuild_ctx, ) @@ -635,7 +651,15 @@ def is_inductor_supported(): return False -def optimize( +def optimize(*args, **kwargs): + def rebuild_ctx(): + return optimize(*args, **kwargs) + + return _optimize(rebuild_ctx, *args, **kwargs) + + +def _optimize( + rebuild_ctx: Callable[[], Union[OptimizeContext, _NullDecorator]], backend="inductor", *, nopython=False, @@ -643,7 +667,7 @@ def optimize( guard_fail_fn=None, disable=False, dynamic=None, -): +) -> Union[OptimizeContext, _NullDecorator]: """ The main entrypoint of TorchDynamo. Do graph capture and call backend() to optimize extracted graphs. @@ -691,6 +715,7 @@ def toy_example(a, b): backend, dynamic=dynamic, hooks=hooks, + rebuild_ctx=rebuild_ctx, ) # The backend function is stashed in the callable returned by # _optimize_catch_errors in the field _torchdynamo_orig_callable. This can @@ -703,6 +728,7 @@ def toy_example(a, b): compiler_config=backend.get_compiler_config() if hasattr(backend, "get_compiler_config") else None, + rebuild_ctx=rebuild_ctx, ) @@ -1466,6 +1492,7 @@ def optimize_assert( export=False, export_constraints=None, dynamic=None, + rebuild_ctx=None, ): """ The same as `torch._dynamo.optimize(backend, nopython=True)` @@ -1483,6 +1510,7 @@ def optimize_assert( backend_ctx_ctor, export=export, dynamic=dynamic, + rebuild_ctx=rebuild_ctx, ) From ae47152ca86838203b2b084e6ca87a416ea255c8 Mon Sep 17 00:00:00 2001 From: Zain Rizvi Date: Fri, 31 May 2024 05:40:16 +0000 Subject: [PATCH 151/706] Expand supported labels to most self-hosted linux pull.yml workflows (#127578) Initial set of runners added in https://github.com/pytorch/pytorch/pull/127566 seem to be working. Expanding to include more machine types, especially GPU machines Pull Request resolved: https://github.com/pytorch/pytorch/pull/127578 Approved by: https://github.com/huydhn --- .github/lf-canary-scale-config.yml | 18 ++++++++++++++++++ .github/lf-scale-config.yml | 18 ++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/.github/lf-canary-scale-config.yml b/.github/lf-canary-scale-config.yml index 6aeca46f484e..628444237e52 100644 --- a/.github/lf-canary-scale-config.yml +++ b/.github/lf-canary-scale-config.yml @@ -6,3 +6,21 @@ runner_types: is_ephemeral: false max_available: 3120 os: linux + lf.c.linux.4xlarge: + disk_size: 150 + instance_type: c5.4xlarge + is_ephemeral: false + max_available: 1000 + os: linux + lf.c.linux.4xlarge.nvidia.gpu: + disk_size: 150 + instance_type: g3.4xlarge + is_ephemeral: false + max_available: 520 + os: linux + lf.c.linux.8xlarge.nvidia.gpu: + disk_size: 150 + instance_type: g3.8xlarge + is_ephemeral: false + max_available: 400 + os: linux diff --git a/.github/lf-scale-config.yml b/.github/lf-scale-config.yml index 758b7bd90314..cd4e6dc9f4f4 100644 --- a/.github/lf-scale-config.yml +++ b/.github/lf-scale-config.yml @@ -6,3 +6,21 @@ runner_types: is_ephemeral: false max_available: 3120 os: linux + lf.linux.4xlarge: + disk_size: 150 + instance_type: c5.4xlarge + is_ephemeral: false + max_available: 1000 + os: linux + lf.linux.4xlarge.nvidia.gpu: + disk_size: 150 + instance_type: g3.4xlarge + is_ephemeral: false + max_available: 520 + os: linux + lf.linux.8xlarge.nvidia.gpu: + disk_size: 150 + instance_type: g3.8xlarge + is_ephemeral: false + max_available: 400 + os: linux From b5e85b8eccfa1402287666f21fbe97aafd3fa1f6 Mon Sep 17 00:00:00 2001 From: "Tugsbayasgalan (Tugsuu) Manlaibaatar" Date: Fri, 31 May 2024 05:45:28 +0000 Subject: [PATCH 152/706] Add deferred_runtime_assertion pass after run_decompositions (#127305) Summary: We also want to reinsert the deferred_runtime passes after run_decompositions as well Test Plan: CI Reviewed By: zhxchen17 Differential Revision: D57802237 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127305 Approved by: https://github.com/BoyuanFeng --- test/export/test_export.py | 9 +++++ torch/export/exported_program.py | 59 ++++++++++++++++++++++--------- torch/fx/passes/runtime_assert.py | 21 ++++++++--- 3 files changed, 68 insertions(+), 21 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index fcc0db08c701..859a46e80b93 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -3312,6 +3312,15 @@ def forward(self, x): "torch.ops.aten._assert_scalar.default", 1, exactly=True ).run(ep.graph_module.code) + ep = ep.run_decompositions() + + FileCheck().check_count( + "torch.ops.aten.sym_constrain_range.default", 1, exactly=True + ).run(ep.graph_module.code) + FileCheck().check_count( + "torch.ops.aten._assert_scalar.default", 1, exactly=True + ).run(ep.graph_module.code) + def test_non_arg_name_dynamic_shapes_api(self): class Foo(torch.nn.Module): def forward(self, a, b): diff --git a/torch/export/exported_program.py b/torch/export/exported_program.py index cc6a9e65dd34..bfdeb5db8e0e 100644 --- a/torch/export/exported_program.py +++ b/torch/export/exported_program.py @@ -33,10 +33,13 @@ import torch.utils._pytree as pytree from torch.export._tree_utils import is_equivalent, reorder_kwargs from torch.fx._compatibility import compatibility + +from torch.fx._utils import first_call_function_nn_module_stack from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode from torch.fx.passes.infra.pass_base import PassResult from torch.fx.passes.infra.pass_manager import PassManager +from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts from .graph_signature import ( # noqa: F401 _sig_to_specs, @@ -660,6 +663,27 @@ def update_arg(old_arg, new_ph): _replace_sym_size_ops_pass(gm) + from torch._export.passes._node_metadata_hook import ( + _node_metadata_hook, + _set_node_metadata_hook, + ) + + stack_trace = ( + 'File "torch/fx/passes/runtime_assert.py", line 24, ' + "in insert_deferred_runtime_asserts" + ) + shape_env = _get_shape_env(gm) + if shape_env is not None: + with _set_node_metadata_hook( + gm, functools.partial(_node_metadata_hook, stack_trace=stack_trace) + ): + insert_deferred_runtime_asserts( + gm, + shape_env, + f"exported program: {first_call_function_nn_module_stack(gm.graph)}", + export=True, + ) + exported_program = ExportedProgram( root=gm, graph=gm.graph, @@ -799,30 +823,31 @@ def _update( ) +def _get_shape_env(gm): + vals = [ + node.meta["val"] + for node in gm.graph.nodes + if node.meta.get("val", None) is not None + ] + from torch._guards import detect_fake_mode + + fake_mode = detect_fake_mode(vals) + if fake_mode is not None: + return fake_mode.shape_env + for v in vals: + if isinstance(v, torch.SymInt): + return v.node.shape_env + + def _get_updated_range_constraints( gm: torch.fx.GraphModule, old_range_constraints: "Optional[Dict[sympy.Symbol, Any]]" = None, _is_executorch: bool = True, ) -> "Dict[sympy.Symbol, Any]": - def get_shape_env(gm): - vals = [ - node.meta["val"] - for node in gm.graph.nodes - if node.meta.get("val", None) is not None - ] - from torch._guards import detect_fake_mode - - fake_mode = detect_fake_mode(vals) - if fake_mode is not None: - return fake_mode.shape_env - for v in vals: - if isinstance(v, torch.SymInt): - return v.node.shape_env - # FIXME(tmanlaibaatar) Remove this whole branch once https://github.com/pytorch/pytorch/pull/123764 if _is_executorch: assert old_range_constraints is None - shape_env = get_shape_env(gm) + shape_env = _get_shape_env(gm) if shape_env is None: return {} range_constraints = { @@ -840,7 +865,7 @@ def get_shape_env(gm): assert old_range_constraints is not None - shape_env = get_shape_env(gm) + shape_env = _get_shape_env(gm) if shape_env is None: return {} diff --git a/torch/fx/passes/runtime_assert.py b/torch/fx/passes/runtime_assert.py index 12dc62cc16e7..05e7f31ffb4e 100644 --- a/torch/fx/passes/runtime_assert.py +++ b/torch/fx/passes/runtime_assert.py @@ -51,6 +51,9 @@ def insert_deferred_runtime_asserts( # We hash (node_name, min_val, max_val) nodes_that_already_have_sym_constraint_range = set() + + # We hash only node name here because size don't take min/max + nodes_that_already_have_sym_constraint_size = set() # TODO this only works for top-level nodes today, also # we should potentially use it not create duplicate # assert_async nodes @@ -63,6 +66,12 @@ def insert_deferred_runtime_asserts( nodes_that_already_have_sym_constraint_range.add( (node.args[0], node.kwargs["min"], node.kwargs["max"]) ) + if ( + node.op == "call_function" + and node.target == torch.ops.aten.sym_constrain_range_for_size.default + ): + assert len(node.args) == 1 + nodes_that_already_have_sym_constraint_size.add(node.args[0]) # Import sympy locally import sympy @@ -337,10 +346,14 @@ def go(node, keypath): if i0 in shape_env.size_like: if export: - graph.call_function( - torch.ops.aten.sym_constrain_range_for_size.default, - (symbol_to_proxy[i0].node,), - ) + if ( + symbol_to_proxy[i0].node + not in nodes_that_already_have_sym_constraint_size + ): + graph.call_function( + torch.ops.aten.sym_constrain_range_for_size.default, + (symbol_to_proxy[i0].node,), + ) else: graph.call_function( torch._check_is_size, (symbol_to_proxy[i0].node,) From 0c5faee372f27379d8a02260dc650554823ffd15 Mon Sep 17 00:00:00 2001 From: cyy Date: Fri, 31 May 2024 05:57:05 +0000 Subject: [PATCH 153/706] Replace python::python with Python::Module (#127485) Use found Python::Module target Pull Request resolved: https://github.com/pytorch/pytorch/pull/127485 Approved by: https://github.com/ezyang --- cmake/Dependencies.cmake | 6 ------ torch/CMakeLists.txt | 2 +- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index e15b55cd16ed..6fb3d967301f 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -881,12 +881,6 @@ if(BUILD_PYTHON) endif() if(Python_Interpreter_FOUND) - add_library(python::python INTERFACE IMPORTED) - target_include_directories(python::python SYSTEM INTERFACE ${Python_INCLUDE_DIRS}) - if(WIN32) - target_link_libraries(python::python INTERFACE ${Python_LIBRARIES}) - endif() - if(USE_NUMPY) if(NOT Python_NumPy_FOUND) message(WARNING "NumPy could not be found. Not building with NumPy. Suppress this warning with -DUSE_NUMPY=OFF") diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index b4db57488f02..c854baf286e8 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -77,7 +77,7 @@ set(TORCH_PYTHON_INCLUDE_DIRECTORIES list(APPEND TORCH_PYTHON_INCLUDE_DIRECTORIES ${LIBSHM_SRCDIR}) set(TORCH_PYTHON_LINK_LIBRARIES - python::python + Python::Module pybind::pybind11 opentelemetry::api shm From 4935a019e48c31d85c16698455719c5188c8b880 Mon Sep 17 00:00:00 2001 From: titaiwangms Date: Fri, 31 May 2024 06:35:47 +0000 Subject: [PATCH 154/706] [ONNX] Update decomposition table to core ATen ops (#127353) Fixes #125894 Previous to this PR, there are ATen core ops missing in the decomposition table because we thought they might be decomposed into prim ops, as they are under _refs. The PR picks them back according to https://github.com/pytorch/pytorch/blob/f6ef832e87a8ea01e6df93b27a2367cccb6b6171/torch/_decomp/__init__.py#L253 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127353 Approved by: https://github.com/justinchuby --- test/onnx/test_fx_op_consistency.py | 47 ++++++++++++++----- test/onnx/test_fx_to_onnx_with_onnxruntime.py | 2 +- .../onnx/_internal/fx/decomposition_table.py | 8 ++++ 3 files changed, 43 insertions(+), 14 deletions(-) diff --git a/test/onnx/test_fx_op_consistency.py b/test/onnx/test_fx_op_consistency.py index 760ede50bd6d..4a4171699e65 100644 --- a/test/onnx/test_fx_op_consistency.py +++ b/test/onnx/test_fx_op_consistency.py @@ -310,6 +310,11 @@ def skip_torchlib_forward_compatibility( "bincount", reason=onnx_test_common.reason_dynamo_does_not_support("aten.bincount.default"), ), + xfail( + "block_diag", + dtypes=onnx_test_common.COMPLEX_TYPES, + reason=onnx_test_common.reason_onnx_runtime_does_not_support("Block_diag", "complex"), + ), xfail( "bmm", dtypes=( @@ -407,10 +412,6 @@ def skip_torchlib_forward_compatibility( "combinations", reason=onnx_test_common.reason_dynamo_does_not_support("aten.masked.select"), ), - xfail( - "cross", - reason=onnx_test_common.reason_onnx_script_does_not_support("linalg_cross"), - ), xfail( "diag", dtypes=onnx_test_common.BOOL_TYPES, @@ -544,6 +545,11 @@ def skip_torchlib_forward_compatibility( dtypes=onnx_test_common.COMPLEX_TYPES, reason=onnx_test_common.reason_dynamo_does_not_support("index_fill", "complex64") ), + xfail( + "index_fill", + dtypes=onnx_test_common.INT_TYPES + onnx_test_common.BOOL_TYPES + onnx_test_common.FLOAT_TYPES, + reason="fixme: Constant input list has None. ONNXScript does not support None in constant list." + ), xfail( "index_put", dtypes=onnx_test_common.BOOL_TYPES + (torch.float16,), @@ -586,6 +592,10 @@ def skip_torchlib_forward_compatibility( variant_name="grad_oriented", reason=onnx_test_common.reason_dynamo_does_not_support("aten.linalg_lstsq.default"), ), + xfail( + "linalg.matrix_power", + reason="fixme: The values for attribute 'shape' do not match: torch.Size([2, 2]) != torch.Size([2, 2, 2])." + ), xfail( "linalg.norm", reason="fixme: Assertion error: result mismatch", @@ -963,6 +973,10 @@ def skip_torchlib_forward_compatibility( dtypes=(torch.int32, torch.int64) + onnx_test_common.BOOL_TYPES, reason="fixme: ONNX Runtime does not support int32/64 inputs", ), + xfail( + "nn.functional.pixel_unshuffle", + reason=onnx_test_common.reason_onnx_script_does_not_support("aten.pixel_unshuffle.default"), + ), xfail( "nn.functional.poisson_nll_loss", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, @@ -1107,6 +1121,11 @@ def skip_torchlib_forward_compatibility( variant_name="mean", reason="ONNX doesn't support reduce='mean' option", ), + xfail( + "sgn", + dtypes=onnx_test_common.BOOL_TYPES, + reason=onnx_test_common.reason_onnx_script_does_not_support("Sign", "bool"), + ), xfail( "sign", dtypes=onnx_test_common.BOOL_TYPES, @@ -1140,6 +1159,11 @@ def skip_torchlib_forward_compatibility( dtypes=onnx_test_common.FLOAT_TYPES, reason=onnx_test_common.reason_onnx_script_does_not_support("Erfcx"), ), + xfail( + "special.log_ndtr", + dtypes=onnx_test_common.INT_TYPES + onnx_test_common.FLOAT_TYPES, + reason="fixme: Assertion error: result mismatch", + ), xfail( "special.ndtr", dtypes=(torch.float16,), @@ -1159,15 +1183,6 @@ def skip_torchlib_forward_compatibility( "svd_lowrank", reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), ), - xfail( - "std_mean", - reason="fixme: NotImplementedError: Type promotion does not support node output of list or tuple." - ), - xfail( - "std_mean", - variant_name="unbiased", - reason="fixme: NotImplementedError: Type promotion does not support node output of list or tuple." - ), xfail( "stft", reason=onnx_test_common.reason_dynamo_does_not_support("aten._fft_r2c.default"), @@ -1961,8 +1976,10 @@ class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime): "addr": [3e-3, 4e-3], "baddbmm": [3e-2, 1e-3], "cumulative_trapezoid": [3e-2, 1e-3], + "cross": [3e-2, 2e-2], "diff": [1e-2, 5e-2], "gradient": [3e-3, 4e-3], + "linalg.cross": [1e-3, 2e-2], "linalg.multi_dot": [3e-2, 1e-3], "linalg.vecdot": [1e-2, 2e-2], "linspace": [2e-2, 2e-3], @@ -1977,6 +1994,7 @@ class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime): "nn.functional.hardsigmoid": [1e-3, 5e-3], "nn.functional.hardswish": [1e-3, 5e-3], "nn.functional.hinge_embedding_loss": [4e-1, 3e-3], + "nn.functional.huber_loss": [1e-3, 1e-2], "nn.functional.instance_norm": [1e-2, 1e-3], "nn.functional.interpolate": [1e-2, 1e-3], "nn.functional.kl_div": [2e-3, 2e-4], @@ -1984,6 +2002,8 @@ class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime): "nn.functional.local_response_norm": [1e-2, 5e-3], "nn.functional.poisson_nll_loss": [3e-2, 1e-3], "nn.functional.nll_loss": [3e-2, 1e-3], + "nn.functional.triplet_margin_loss": [2e-2, 1e-2], + "nn.functional.triplet_margin_with_distance_loss": [3e-2, 1e-2], "native_batch_norm": [3e-2, 1e-3], "norm": [1e-2, 1e-2], "dot": [3e-2, 1e-3], @@ -1993,6 +2013,7 @@ class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime): "sub": [3e-2, 1e-3], "trapezoid": [1e-3, 7e-3], "trapz": [1e-3, 7e-3], + "vdot": [1e-3, 1e-2], } fp16_low_precision_variant_dict = { diff --git a/test/onnx/test_fx_to_onnx_with_onnxruntime.py b/test/onnx/test_fx_to_onnx_with_onnxruntime.py index 5345e0219c14..b70bfbf9c4a7 100644 --- a/test/onnx/test_fx_to_onnx_with_onnxruntime.py +++ b/test/onnx/test_fx_to_onnx_with_onnxruntime.py @@ -1275,7 +1275,7 @@ def create_model(): model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE, ) @pytorch_test_common.xfail_if_model_type_is_exportedprogram( - error_message="aot_autograd expected to have an entirely functional graph", + error_message="n=copy_, n.args[0]=zeros_like, placeholders={", reason="aot_autograd doesn't support it.", ) def test_fake_tensor_mode_huggingface_openai_whisper(self): diff --git a/torch/onnx/_internal/fx/decomposition_table.py b/torch/onnx/_internal/fx/decomposition_table.py index 4f3f705ca867..5cb9be6da79d 100644 --- a/torch/onnx/_internal/fx/decomposition_table.py +++ b/torch/onnx/_internal/fx/decomposition_table.py @@ -111,4 +111,12 @@ def create_onnx_friendly_decomposition_table( ): continue decomposition_table[op_overload] = decomp_fn + + # NOTE: There are ops in core ATen and under torch._refs, + # that are not decomposed to prim::ops. We need to pick them + # back + for op_overload, decomp_fn in torch._decomp.core_aten_decompositions().items(): + if op_overload in _ONNX_SUPPORT_OP_OVERLOADS: + continue + decomposition_table[op_overload] = decomp_fn return decomposition_table From 9eefc04069670e1335d8fbcd3ed474b9174b537b Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Fri, 31 May 2024 14:59:12 +0800 Subject: [PATCH 155/706] refine the docstring Signed-off-by: yiliu30 --- .../quantizer/x86_inductor_quantizer.py | 21 +++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index b83388ad978b..d19aa7be251e 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -414,7 +414,19 @@ def wrapper( @dataclass -class _QuantizationMode: +class _CurrentQuantizationMode: + r"""Configuration defining the current quantization mode for the quantizer. + + All possible current quantization modes are listed below: + ---------------------------------------------------------------------------------------------------------- + | is_dynamic + is_qat |--------------------------------------------------------------------------------------------- + | None | True | False + ---------------------------------------------------------------------------------------------------------- + None | quantizer does not receive a non-None `quantization_config` | \ | \ + False | quantizer will not do QAT | dynamic | static + True | quantizer will do QAT | QAT + dynamic | QAT + static + """ is_qat: Optional[bool] is_dynamic: Optional[bool] @@ -453,7 +465,7 @@ def get_supported_operator_for_quantization_config( return ops return [] - def _get_current_quantization_mode(self) -> _QuantizationMode: + def _get_current_quantization_mode(self) -> _CurrentQuantizationMode: """Retrieves the current quantization mode based on all configurations.""" is_qat = None is_dynamic = None @@ -468,7 +480,7 @@ def _get_current_quantization_mode(self) -> _QuantizationMode: input_activation_spec = qconfig.input_activation if input_activation_spec is not None: is_dynamic = input_activation_spec.is_dynamic - return _QuantizationMode(is_qat=is_qat, is_dynamic=is_dynamic) + return _CurrentQuantizationMode(is_qat=is_qat, is_dynamic=is_dynamic) def _need_skip_config( self, quantization_config: Optional[QuantizationConfig] @@ -476,7 +488,8 @@ def _need_skip_config( """Check if the provided quantization config is valid for X86InductorQuantizer. Mixed static/dynamic configurations or mixed QAT/non-QAT configurations are not supported. - If such a mix is detected, the configuration will be marked for skipping.. + To avoid such a mix, we compare the incoming configuration with current configuration status. + Refer the `_CurrentQuantizationMode` definition for all possible modes. """ if quantization_config is None: return False From 2a03bf5a14d13b1f683c18f573cd93a8b1d38212 Mon Sep 17 00:00:00 2001 From: Yueming Hao Date: Fri, 31 May 2024 08:01:31 +0000 Subject: [PATCH 156/706] [inductor] fix grid z bug for large grid (#127448) Fixes #123210 https://github.com/pytorch/pytorch/blob/2f3d3ddd70e553d4c5269df699489b82b3aa25ab/torch/_inductor/runtime/triton_heuristics.py#L1733-L1753 If a kernel's y_grid is larger than 65535, it will be split into multiple z grids. The above grad_fn does this split before the kernel launch; however, the computations for yoffset and the y_grid are incorrect. For example, if we have xy numel of `(1*XBLOCK, 65537*YBLOCK)`, this function will return an [xyz]_grid with (1, 32768, 2). XBLOCK and YBLOCK here are used for the following `get_grid_dim`. Let's use their default values (4, 1024). https://github.com/pytorch/pytorch/blob/2f3d3ddd70e553d4c5269df699489b82b3aa25ab/torch/_inductor/runtime/triton_heuristics.py#L1734 [xyz]_grid = (1, 32768, 2) means the workload are divided to two z grids. Because the triton kernel generation still follows xy dimension, one of the exampled generated kernel is shown below. ```python @triton.jit def triton_(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr): ynumel = 65537*1024 xnumel = 1*4 yoffset = tl.program_id(1) * (tl.program_id(2) + 1) * YBLOCK yindex = yoffset + tl.arange(0, YBLOCK)[None, :] ymask = yindex < ynumel xoffset = tl.program_id(0) * XBLOCK xindex = xoffset + tl.arange(0, XBLOCK)[:, None] xmask = xindex < xnumel x2 = xindex y0 = yindex % 128 y1 = (yindex // 128) y3 = yindex tmp0 = tl.load(in_ptr0 + (y0 + (128*x2) + (512*y1)), xmask, eviction_policy='evict_last') tl.store(out_ptr0 + (x2 + (4*y3)), tmp0, xmask) ``` For a trition block with xyz index (0, 0, 1), its yoffset and xoffset are both 0s based on the compuation `yoffset = tl.program_id(1) * (tl.program_id(2) + 1) * YBLOCK` and `xoffset = tl.program_id(0) * XBLOCK`. So, this triton block will access the very first elements of the input. However, the correct yoffset should be `(y_index + z_index * y_grid ) * YBLOCK` which is the starting position of the 2nd z grid. At the same time, because we used `y_grid = y_grid // div` to compute the maximum number of element in y dimension, the y_grid is 32768. The total y grids is 32768*2 = 65536, which is less than the actual y grids 65537. So, we should use `y_grid = ceildiv(y_grid, div)` to compute the y grid to save the remaining grids. #123210 is not about AOTInductor, the root cause is the triton kernel generated by torchinductor. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127448 Approved by: https://github.com/eellison --- test/inductor/test_aot_inductor.py | 27 ++++++-------------- test/inductor/test_torchinductor.py | 17 ++++++++++++ test/inductor/test_triton_heuristics.py | 7 +++-- torch/_inductor/codegen/triton.py | 5 +++- torch/_inductor/runtime/triton_heuristics.py | 2 +- 5 files changed, 33 insertions(+), 25 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index dc4fe6fcbf7d..03d18f7c3f3f 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -968,29 +968,19 @@ class Model(torch.nn.Module): def __init__(self): super().__init__() - def forward(self, primals_1, primals_2, primals_5): - view = torch.ops.aten.reshape.default(primals_5, [-1, 4, 128]) + def forward(self, primals_5): + view = torch.ops.aten.reshape.default(primals_5, [-1, 2, 4]) primals_5 = None permute = torch.ops.aten.permute.default(view, [0, 2, 1]) clone = torch.ops.aten.clone.default( permute, memory_format=torch.contiguous_format ) - permute = None - view_1 = torch.ops.aten.reshape.default(clone, [-1, 4]) - clone = None - permute_1 = torch.ops.aten.permute.default(primals_1, [1, 0]) - primals_1 = None - addmm = torch.ops.aten.addmm.default(primals_2, view_1, permute_1) - primals_2 = None - return addmm - - s0 = 727828 - s1 = 512 - example_inputs = ( - torch.rand(2, 4, device=self.device), - torch.rand(2, device=self.device), - torch.rand(s0, s1, device=self.device), - ) + return clone + + # let y_grid = 65537 + s0 = 16777472 + s1 = 8 + example_inputs = (torch.rand(s0, s1, device=self.device),) self.check_model(Model(), example_inputs) def test_cond_simple(self): @@ -3065,7 +3055,6 @@ def fail_non_abi_compatible_cuda(is_skip=False): CUDA_TEST_FAILURES = { # test_failures, xfail by default, set is_skip=True to skip - "test_large_grid": fail_cuda(), "test_normal_functional": fail_abi_compatible_cuda(is_skip=True), # no runtime checks for non_abi_compatible mode "test_runtime_checks": fail_non_abi_compatible_cuda(is_skip=True), diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 5b5a56394372..12ace877df05 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -10341,6 +10341,23 @@ def test_generate_rand_fp8(self): t = rand_strided((2, 3), (3, 1), device=self.device, dtype=torch.float8_e4m3fn) self.assertTrue(t.dtype is torch.float8_e4m3fn) + def test_large_grid(self): + # https://github.com/pytorch/pytorch/issues/123210 + def fn(primals_5): + view = torch.ops.aten.reshape.default(primals_5, [-1, 2, 4]) + primals_5 = None + permute = torch.ops.aten.permute.default(view, [0, 2, 1]) + clone = torch.ops.aten.clone.default( + permute, memory_format=torch.contiguous_format + ) + return clone + + s0 = 16777472 + s1 = 8 + compiled_fn = torch._dynamo.optimize()(fn) + actual = compiled_fn(torch.ones(s0, s1)) + self.assertTrue((actual == 1).all()) + @dataclasses.dataclass class TestFailure: diff --git a/test/inductor/test_triton_heuristics.py b/test/inductor/test_triton_heuristics.py index d8c74c0a3841..c0908251f85b 100644 --- a/test/inductor/test_triton_heuristics.py +++ b/test/inductor/test_triton_heuristics.py @@ -38,7 +38,7 @@ def test_triton_config(self): def _test_artificial_zgrid(self): def forward(primals_1, primals_2, primals_5): - view = torch.ops.aten.reshape.default(primals_5, [-1, 4, 128]) + view = torch.ops.aten.reshape.default(primals_5, [-1, 2, 4]) primals_5 = None permute = torch.ops.aten.permute.default(view, [0, 2, 1]) clone = torch.ops.aten.clone.default( @@ -53,8 +53,8 @@ def forward(primals_1, primals_2, primals_5): primals_2 = None return addmm - s0 = 727828 - s1 = 512 + s0 = 16777472 + s1 = 8 args = [ torch.rand([2, 4], device=GPU_TYPE), @@ -73,7 +73,6 @@ def forward(primals_1, primals_2, primals_5): ] self.assertEqual(forward(*args), foo_c(*args)) - @unittest.skip("https://github.com/pytorch/pytorch/issues/123210") @expectedFailureXPU def test_artificial_zgrid(self): self._test_artificial_zgrid() diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 785f79d91503..4b0ea92f3bf4 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -2341,7 +2341,10 @@ def iteration_ranges_get_pid(self, entry): and not entry.has_zdim and not (isinstance(entry.numel, int) and entry.numel <= get_max_y_grid()) ): - key = f"{key} * (tl.program_id({entry.grid_dim + 1}) + 1)" + # For ynumel larger than max_ygrid, we need to use zdim. + # For each z dimension, there are tl.num_programs(1) yblocks which is passed by grad(x,y,z). + # So, we need to add tl.program_id(z) * tl.num_programs(y) *YBLOCK to get the correct yoffset. + key = f"({key} + tl.program_id({entry.grid_dim + 1}) * tl.num_programs({entry.grid_dim}))" pid = entry.pid_cache.get(key, key) if self.index_dtype != "tl.int32": return f"{pid}.to({self.index_dtype})" diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 75584a60c0ff..6629e0fe5e77 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -1738,7 +1738,7 @@ def grid_fn(meta): max_y_grid = get_max_y_grid() if znumel is None: div = ceildiv(y_grid, max_y_grid) - y_grid = y_grid // div + y_grid = ceildiv(y_grid, div) z_grid = div else: z_grid = get_grid_dim(znumel, meta.get("ZBLOCK", None)) From f4d7cdc5e63c786b1f6588eafa53bbc6d33c3826 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Thu, 30 May 2024 15:11:02 -0700 Subject: [PATCH 157/706] [dynamo] Add current instruction to BlockStackEntry (#127482) Will be used by exception handling in later PRs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127482 Approved by: https://github.com/jansel --- torch/_dynamo/symbolic_convert.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index bacb8dff9e36..a0014d3339c4 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -199,6 +199,8 @@ def _step_logger(): @dataclasses.dataclass class BlockStackEntry: + # Current instruction that pushes something to block_stack + inst: Instruction target: Instruction stack_index: Optional[int] = None with_context: Optional[ContextWrappingVariable] = None @@ -1156,11 +1158,11 @@ def jump(self, inst): def SETUP_LOOP(self, inst): # only exists in python<=3.7 - self.block_stack.append(BlockStackEntry(inst.target)) + self.block_stack.append(BlockStackEntry(inst, inst.target)) def SETUP_EXCEPT(self, inst): # only exists in python<=3.7 - self.block_stack.append(BlockStackEntry(inst.target)) + self.block_stack.append(BlockStackEntry(inst, inst.target)) def POP_BLOCK(self, inst): self.block_stack.pop() @@ -1169,7 +1171,7 @@ def SETUP_WITH(self, inst): self.setup_or_before_with(inst) def SETUP_FINALLY(self, inst): - self.block_stack.append(BlockStackEntry(inst.target)) + self.block_stack.append(BlockStackEntry(inst, inst.target)) def BEGIN_FINALLY(self, inst): self.push(None) @@ -1905,9 +1907,11 @@ def setup_or_before_with(self, inst): if target: if isinstance(self, InstructionTranslator): - self.block_stack.append(BlockStackEntry(target, len(self.stack), ctx)) + self.block_stack.append( + BlockStackEntry(inst, target, len(self.stack), ctx) + ) else: - self.block_stack.append(BlockStackEntry(target)) + self.block_stack.append(BlockStackEntry(inst, target)) self.push(exit) self.push(ctx.enter(self)) From df0c69f32d269f8cdc136c9c65d791b6b86ef5e3 Mon Sep 17 00:00:00 2001 From: IvanKobzarev Date: Thu, 30 May 2024 13:01:48 -0700 Subject: [PATCH 158/706] [inductor] Add fallback for collectives size estimation for unbacked (#127562) Differential Revision: [D57982928](https://our.internmc.facebook.com/intern/diff/D57982928) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127562 Approved by: https://github.com/yifuwang --- torch/_inductor/comm_analysis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_inductor/comm_analysis.py b/torch/_inductor/comm_analysis.py index 3d9233b34370..334ccf5b7e18 100644 --- a/torch/_inductor/comm_analysis.py +++ b/torch/_inductor/comm_analysis.py @@ -59,7 +59,7 @@ def get_collective_input_size_bytes(node: ir.IRNode) -> int: # For ease of testing numel = int(numel) else: - numel = V.graph.sizevars.size_hint(numel) + numel = V.graph.sizevars.size_hint(numel, fallback=0) sz_bytes += numel * get_dtype_size(inp.layout.dtype) return sz_bytes From a6bae1f6db3bb86c521dd3c2417f42b8f5e8d705 Mon Sep 17 00:00:00 2001 From: cyy Date: Fri, 31 May 2024 11:26:24 +0000 Subject: [PATCH 159/706] Remove more caffe2 files (#127511) Remove more caffe2 files. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127511 Approved by: https://github.com/r-barnes --- BUILD.bazel | 4 - caffe2/core/blob.h | 130 -- caffe2/core/blob_serialization_gpu.cc | 10 - caffe2/core/common_cudnn.cc | 26 - caffe2/core/common_cudnn.h | 314 ---- caffe2/core/common_gpu.cc | 253 --- caffe2/core/common_gpu.h | 475 ----- caffe2/core/context.h | 227 --- caffe2/core/context_base.h | 168 -- caffe2/core/context_gpu.cu | 669 ------- caffe2/core/context_gpu.h | 354 ---- caffe2/core/event_gpu.cc | 227 --- caffe2/core/flags.h | 4 - caffe2/core/hip/common_miopen.h | 178 -- caffe2/core/hip/common_miopen.hip | 42 - caffe2/core/hip/miopen_wrapper.h | 166 -- caffe2/core/init.h | 179 -- caffe2/core/net.h | 175 -- caffe2/core/numa.h | 3 - caffe2/core/observer.h | 164 -- caffe2/core/operator.h | 1600 ----------------- caffe2/core/operator_gradient.h | 337 ---- caffe2/core/operator_schema.h | 612 ------- caffe2/core/storage.h | 33 - caffe2/core/tensor.h | 674 ------- caffe2/core/tensor_int8.h | 21 - caffe2/core/workspace.h | 342 ---- caffe2/utils/GpuAtomics.cuh | 28 - caffe2/utils/GpuBitonicSort.cuh | 178 -- caffe2/utils/GpuDefs.cuh | 158 -- caffe2/utils/GpuScanUtils.cuh | 133 -- caffe2/utils/bench_utils.cc | 120 -- caffe2/utils/bench_utils.h | 30 - caffe2/utils/cast.h | 49 - caffe2/utils/cast_test.cc | 39 - caffe2/utils/cblas.h | 606 ------- caffe2/utils/cpu_neon.h | 53 - caffe2/utils/cpuid_test.cc | 10 - caffe2/utils/cub_namespace.cuh | 17 - caffe2/utils/eigen_utils.h | 205 --- caffe2/utils/fatal_signal_asan_no_sig_test.cc | 148 -- caffe2/utils/filler.h | 140 -- caffe2/utils/fixed_divisor_test.cc | 80 - caffe2/utils/knob_patcher.cc | 137 -- caffe2/utils/knob_patcher.h | 32 - caffe2/utils/knobs.cc | 76 - caffe2/utils/knobs.h | 26 - caffe2/utils/knobs_test.cc | 34 - caffe2/utils/map_utils.h | 19 - caffe2/utils/murmur_hash3.cc | 450 ----- caffe2/utils/murmur_hash3.h | 34 - caffe2/utils/proto_utils.cc | 715 -------- caffe2/utils/proto_utils.h | 383 ---- caffe2/utils/proto_utils_test.cc | 63 - caffe2/utils/signal_handler.h | 24 - caffe2/utils/simple_queue.h | 79 - caffe2/utils/simple_queue_test.cc | 76 - caffe2/utils/smart_tensor_printer.h | 50 - caffe2/utils/smart_tensor_printer_test.cc | 53 - caffe2/utils/zmq_helper.h | 137 -- 60 files changed, 11769 deletions(-) delete mode 100644 caffe2/core/blob.h delete mode 100644 caffe2/core/blob_serialization_gpu.cc delete mode 100644 caffe2/core/common_cudnn.cc delete mode 100644 caffe2/core/common_cudnn.h delete mode 100644 caffe2/core/common_gpu.cc delete mode 100644 caffe2/core/common_gpu.h delete mode 100644 caffe2/core/context.h delete mode 100644 caffe2/core/context_base.h delete mode 100644 caffe2/core/context_gpu.cu delete mode 100644 caffe2/core/context_gpu.h delete mode 100644 caffe2/core/event_gpu.cc delete mode 100644 caffe2/core/flags.h delete mode 100644 caffe2/core/hip/common_miopen.h delete mode 100644 caffe2/core/hip/common_miopen.hip delete mode 100644 caffe2/core/hip/miopen_wrapper.h delete mode 100644 caffe2/core/init.h delete mode 100644 caffe2/core/net.h delete mode 100644 caffe2/core/numa.h delete mode 100644 caffe2/core/observer.h delete mode 100644 caffe2/core/operator.h delete mode 100644 caffe2/core/operator_gradient.h delete mode 100644 caffe2/core/operator_schema.h delete mode 100644 caffe2/core/storage.h delete mode 100644 caffe2/core/tensor.h delete mode 100644 caffe2/core/tensor_int8.h delete mode 100644 caffe2/core/workspace.h delete mode 100644 caffe2/utils/GpuAtomics.cuh delete mode 100644 caffe2/utils/GpuBitonicSort.cuh delete mode 100644 caffe2/utils/GpuDefs.cuh delete mode 100644 caffe2/utils/GpuScanUtils.cuh delete mode 100644 caffe2/utils/bench_utils.cc delete mode 100644 caffe2/utils/bench_utils.h delete mode 100644 caffe2/utils/cast.h delete mode 100644 caffe2/utils/cast_test.cc delete mode 100644 caffe2/utils/cblas.h delete mode 100644 caffe2/utils/cpu_neon.h delete mode 100644 caffe2/utils/cpuid_test.cc delete mode 100644 caffe2/utils/cub_namespace.cuh delete mode 100644 caffe2/utils/eigen_utils.h delete mode 100644 caffe2/utils/fatal_signal_asan_no_sig_test.cc delete mode 100644 caffe2/utils/filler.h delete mode 100644 caffe2/utils/fixed_divisor_test.cc delete mode 100644 caffe2/utils/knob_patcher.cc delete mode 100644 caffe2/utils/knob_patcher.h delete mode 100644 caffe2/utils/knobs.cc delete mode 100644 caffe2/utils/knobs.h delete mode 100644 caffe2/utils/knobs_test.cc delete mode 100644 caffe2/utils/map_utils.h delete mode 100644 caffe2/utils/murmur_hash3.cc delete mode 100644 caffe2/utils/murmur_hash3.h delete mode 100644 caffe2/utils/proto_utils.cc delete mode 100644 caffe2/utils/proto_utils.h delete mode 100644 caffe2/utils/proto_utils_test.cc delete mode 100644 caffe2/utils/signal_handler.h delete mode 100644 caffe2/utils/simple_queue.h delete mode 100644 caffe2/utils/simple_queue_test.cc delete mode 100644 caffe2/utils/smart_tensor_printer.h delete mode 100644 caffe2/utils/smart_tensor_printer_test.cc delete mode 100644 caffe2/utils/zmq_helper.h diff --git a/BUILD.bazel b/BUILD.bazel index ecbeaab9bbf8..7a2c3a523dfc 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -488,10 +488,7 @@ filegroup( filegroup( name = "caffe2_utils_srcs", srcs = [ - "caffe2/utils/bench_utils.cc", "caffe2/utils/cpuid.cc", - "caffe2/utils/murmur_hash3.cc", - "caffe2/utils/proto_utils.cc", "caffe2/utils/proto_wrap.cc", "caffe2/utils/string_utils.cc", "caffe2/utils/threadpool/ThreadPool.cc", @@ -544,7 +541,6 @@ cc_library( ], ) + if_cuda(glob([ "caffe2/**/*.cuh", - "caffe2/image/*.h", ])), copts = CAFFE2_COPTS, visibility = ["//visibility:public"], diff --git a/caffe2/core/blob.h b/caffe2/core/blob.h deleted file mode 100644 index 582328092b26..000000000000 --- a/caffe2/core/blob.h +++ /dev/null @@ -1,130 +0,0 @@ -#ifndef CAFFE2_CORE_BLOB_H_ -#define CAFFE2_CORE_BLOB_H_ - -#include -#include -#include -#include -#include -#include "caffe2/core/common.h" - -#include -#include -#include "caffe2/core/logging.h" -#include "caffe2/core/tensor.h" -#include "caffe2/core/tensor_int8.h" - -namespace caffe2 { - -inline bool BlobIsInt8TensorCPUType(const Blob& blob) { - return blob.meta().Match(); -} - -inline bool BlobIsTensorType(const Blob& blob, DeviceType device_type) { - bool is_match = blob.meta().Match(); - if (!is_match) { - return false; - } - const Tensor* tensor = &blob.Get(); - return tensor && *tensor && tensor->GetDeviceType() == device_type; -} - -inline Tensor* BlobSetTensor(Blob* blob, Tensor&& tensor) { - return blob->Reset(new Tensor(std::move(tensor))); -} - -inline Tensor GetSizedTensorWithOptions( - Tensor&& previous_tensor, - at::IntArrayRef dims, - at::TensorOptions options) { - Tensor tensor = std::move(previous_tensor); - if (!tensor.defined()) { - return caffe2::empty(dims, options); - } - if (tensor.GetDevice() == options.device() || - (!tensor.GetDevice().has_index() && - tensor.GetDeviceType() == options.device().type())) { - if (tensor.sizes() != dims) { - // Resize when the dims doesn't match - tensor.Resize(dims); - } - if (tensor.dtype() == options.dtype()) { - tensor.raw_mutable_data(); - } else { - // create a new Tensor when the data_type doesn't match - return caffe2::empty(dims, options); - } - return tensor; - } - return caffe2::empty(dims, options); -} - -// need to keep both functions that returns Tensor* and the one -// returns Tensor for clangr codemod -inline Tensor* -BlobGetMutableTensor(Blob* blob, at::IntArrayRef dims, at::TensorOptions options) { - if (blob->IsType()) { - Tensor* tensor = blob->GetMutable(); - if (*tensor) { - // We only compare device_type if the index is not set since there are Tensors - // TODO: remove the extra check when all the Tensors are properly initialized - const auto tensorDevice = tensor->GetDevice(); - if (tensorDevice == options.device() || (!tensorDevice.has_index() && tensor->GetDeviceType() == options.device().type())) { - if (tensor->sizes() != dims) { - // Resize when the dims doesn't match - tensor->Resize(dims); - } - tensor->raw_mutable_data(options.dtype()); - return tensor; - } - // create a new Tensor when device doesn't match - } - } - - VLOG(1) << "Create new mutable object " << TypeMeta::TypeName() - << " dims: " << dims; - // << " options: " << options; (operator<< for Options is in at:: now) - return BlobSetTensor(blob, caffe2::empty(dims, options)); -} - -inline Tensor -XBlobGetMutableTensor(Blob* blob, at::IntArrayRef dims, at::TensorOptions options) { - return BlobGetMutableTensor(blob, dims, options)->UnsafeSharedInstance(); -} - -inline Tensor* BlobGetMutableTensor(Blob* blob, DeviceType device_type) { - if (blob->IsType()) { - Tensor* tensor = blob->GetMutable(); - if (*tensor && tensor->GetDeviceType() == device_type) { - return tensor; - } - } - - // if we're here, then either Blob didn't hold a Tensor - // or that Tensor had the wrong DeviceType. - VLOG(1) << "Create new mutable object " << TypeMeta::TypeName() - << " DeviceType:" << device_type; - - return BlobSetTensor(blob, Tensor(device_type)); -} - -inline const Tensor& BlobGetTensor(const Blob& blob, DeviceType device_type) { - if (blob.IsType()) { - const auto& tensor = blob.Get(); - if (tensor.GetDeviceType() == device_type) { - return tensor; - } - } - CAFFE_THROW("Blob didn't contain a Tensor or the device_type doesn't match"); -} - -inline Tensor BlobGetTensorOrUndefined(const Blob& blob) { - if (blob.IsType()) { - return blob.Get().UnsafeSharedInstance(); - } else { - return Tensor(); - } -} - -} // namespace caffe2 -#endif // CAFFE2_CORE_BLOB_H_ diff --git a/caffe2/core/blob_serialization_gpu.cc b/caffe2/core/blob_serialization_gpu.cc deleted file mode 100644 index 4d675354531c..000000000000 --- a/caffe2/core/blob_serialization_gpu.cc +++ /dev/null @@ -1,10 +0,0 @@ -#include "caffe2/core/blob.h" -#include "caffe2/core/blob_serialization.h" -#include "caffe2/core/context_gpu.h" - -namespace caffe2 { - -namespace { -REGISTER_BLOB_DESERIALIZER(TensorCUDA, TensorDeserializer); -} -} // namespace caffe2 diff --git a/caffe2/core/common_cudnn.cc b/caffe2/core/common_cudnn.cc deleted file mode 100644 index f8186544054a..000000000000 --- a/caffe2/core/common_cudnn.cc +++ /dev/null @@ -1,26 +0,0 @@ -#include "caffe2/core/common_cudnn.h" -#include "caffe2/core/cudnn_wrappers.h" - -#include "caffe2/core/init.h" - -namespace caffe2 { - -CuDNNWrapper::PerGPUCuDNNStates& CuDNNWrapper::cudnn_states() { - // New it (never delete) to avoid calling the destructors on process - // exit and racing against the CUDA shutdown sequence. - static auto* p = new CuDNNWrapper::PerGPUCuDNNStates(); - TORCH_CHECK_NOTNULL(p); - return *p; -} - -namespace { -bool PrintCuDNNInfo(int*, char***) { - VLOG(1) << "Caffe2 is built with CuDNN version " << CUDNN_VERSION; - return true; -} - -REGISTER_CAFFE2_INIT_FUNCTION(PrintCuDNNInfo, &PrintCuDNNInfo, - "Print CuDNN Info."); - -} // namespace -} // namespace caffe2 diff --git a/caffe2/core/common_cudnn.h b/caffe2/core/common_cudnn.h deleted file mode 100644 index b130103fb5cb..000000000000 --- a/caffe2/core/common_cudnn.h +++ /dev/null @@ -1,314 +0,0 @@ -#ifndef CAFFE2_CORE_COMMON_CUDNN_H_ -#define CAFFE2_CORE_COMMON_CUDNN_H_ - -#include -#include - -#include "caffe2/core/common.h" -#include "caffe2/core/context.h" -#include "caffe2/core/logging.h" -#include "caffe2/core/types.h" - -#ifndef CAFFE2_USE_CUDNN -#error("This Caffe2 install is not built with cudnn, so you should not include this file."); -#endif - -#include - -static_assert( - CUDNN_VERSION >= 8200, - "Caffe2 requires cudnn version 8.2 or above."); - -#define CUDNN_VERSION_MIN(major, minor, patch) \ - (major >= 9 ? CUDNN_VERSION >= ((major) * 10000 + (minor) * 100 + (patch)) : \ - CUDNN_VERSION >= ((major) * 1000 + (minor) * 100 + (patch))) - -namespace caffe2 { - -namespace internal { -/** - * A helper function to obtain cudnn error strings. - */ -inline const char* cudnnGetErrorString(cudnnStatus_t status) { - switch (status) { - case CUDNN_STATUS_SUCCESS: - return "CUDNN_STATUS_SUCCESS"; - case CUDNN_STATUS_NOT_INITIALIZED: - return "CUDNN_STATUS_NOT_INITIALIZED"; - case CUDNN_STATUS_ALLOC_FAILED: - return "CUDNN_STATUS_ALLOC_FAILED"; - case CUDNN_STATUS_BAD_PARAM: - return "CUDNN_STATUS_BAD_PARAM"; - case CUDNN_STATUS_INTERNAL_ERROR: - return "CUDNN_STATUS_INTERNAL_ERROR"; - case CUDNN_STATUS_INVALID_VALUE: - return "CUDNN_STATUS_INVALID_VALUE"; - case CUDNN_STATUS_ARCH_MISMATCH: - return "CUDNN_STATUS_ARCH_MISMATCH"; - case CUDNN_STATUS_MAPPING_ERROR: - return "CUDNN_STATUS_MAPPING_ERROR"; - case CUDNN_STATUS_EXECUTION_FAILED: - return "CUDNN_STATUS_EXECUTION_FAILED"; - case CUDNN_STATUS_NOT_SUPPORTED: - return "CUDNN_STATUS_NOT_SUPPORTED"; - case CUDNN_STATUS_LICENSE_ERROR: - return "CUDNN_STATUS_LICENSE_ERROR"; - default: - return "Unknown cudnn error number"; - } -} -} // namespace internal - -// A macro that wraps around a cudnn statement so we can check if the cudnn -// execution finishes or not. -#define CUDNN_ENFORCE(condition) \ - do { \ - cudnnStatus_t status = condition; \ - CAFFE_ENFORCE_EQ( \ - status, \ - CUDNN_STATUS_SUCCESS, \ - ", Error at: ", \ - __FILE__, \ - ":", \ - __LINE__, \ - ": ", \ - ::caffe2::internal::cudnnGetErrorString(status)); \ - } while (0) -#define CUDNN_CHECK(condition) \ - do { \ - cudnnStatus_t status = condition; \ - CHECK(status == CUDNN_STATUS_SUCCESS) \ - << ::caffe2::internal::cudnnGetErrorString(status); \ - } while (0) - -// report the version of cuDNN Caffe2 was compiled with -inline size_t cudnnCompiledVersion() { - return CUDNN_VERSION; -} -// report the runtime version of cuDNN -inline size_t cudnnRuntimeVersion() { - return cudnnGetVersion(); -} - -// Check compatibility of compiled and runtime cuDNN versions -inline void CheckCuDNNVersions() { - // Version format is major*1000 + minor*100 + patch - // If compiled with version < 7, major, minor and patch must all match - // If compiled with version >= 7, then either - // runtime_version > compiled_version - // major and minor match - bool version_match = cudnnCompiledVersion() == cudnnRuntimeVersion(); - bool compiled_with_7 = cudnnCompiledVersion() >= 7000; - bool backwards_compatible_7 = compiled_with_7 && cudnnRuntimeVersion() >= cudnnCompiledVersion(); - bool patch_compatible = compiled_with_7 && (cudnnRuntimeVersion() / 100) == (cudnnCompiledVersion() / 100); - CAFFE_ENFORCE(version_match || backwards_compatible_7 || patch_compatible, - "cuDNN compiled (", cudnnCompiledVersion(), ") and " - "runtime (", cudnnRuntimeVersion(), ") versions mismatch"); -} - -/** - * cudnnTypeWrapper is a wrapper class that allows us to refer to the cudnn type - * in a template function. The class is specialized explicitly for different - * data types below. - */ -template -class cudnnTypeWrapper; - -template <> -class cudnnTypeWrapper { - public: - static const cudnnDataType_t type = CUDNN_DATA_FLOAT; - typedef const float ScalingParamType; - typedef float BNParamType; - static ScalingParamType* kOne() { - static ScalingParamType v = 1.0; - return &v; - } - static const ScalingParamType* kZero() { - static ScalingParamType v = 0.0; - return &v; - } -}; - -template <> -class cudnnTypeWrapper { - public: - static const cudnnDataType_t type = CUDNN_DATA_INT32; - typedef const int ScalingParamType; - typedef int BNParamType; - static ScalingParamType* kOne() { - static ScalingParamType v = 1; - return &v; - } - static const ScalingParamType* kZero() { - static ScalingParamType v = 0; - return &v; - } -}; - -template <> -class cudnnTypeWrapper { - public: - static const cudnnDataType_t type = CUDNN_DATA_DOUBLE; - typedef const double ScalingParamType; - typedef double BNParamType; - static ScalingParamType* kOne() { - static ScalingParamType v = 1.0; - return &v; - } - static ScalingParamType* kZero() { - static ScalingParamType v = 0.0; - return &v; - } -}; - -template <> -class cudnnTypeWrapper { - public: - static const cudnnDataType_t type = CUDNN_DATA_HALF; - typedef const float ScalingParamType; - typedef float BNParamType; - static ScalingParamType* kOne() { - static ScalingParamType v = 1.0; - return &v; - } - static ScalingParamType* kZero() { - static ScalingParamType v = 0.0; - return &v; - } -}; - -/** - * A wrapper function to convert the Caffe storage order to cudnn storage order - * enum values. - */ -inline cudnnTensorFormat_t GetCudnnTensorFormat(const StorageOrder& order) { - switch (order) { - case StorageOrder::NHWC: - return CUDNN_TENSOR_NHWC; - case StorageOrder::NCHW: - return CUDNN_TENSOR_NCHW; - default: - LOG(FATAL) << "Unknown cudnn equivalent for order: " << order; - } - // Just to suppress compiler warnings - return CUDNN_TENSOR_NCHW; -} - -/** - * cudnnTensorDescWrapper is the placeholder that wraps around a - * cudnnTensorDescriptor_t, allowing us to do descriptor change as-needed during - * runtime. - */ -class cudnnTensorDescWrapper { - public: - cudnnTensorDescWrapper() { - CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&desc_)); - } - ~cudnnTensorDescWrapper() noexcept { - CUDNN_CHECK(cudnnDestroyTensorDescriptor(desc_)); - } - - inline cudnnTensorDescriptor_t Descriptor( - const cudnnTensorFormat_t format, - const cudnnDataType_t type, - const vector& dims, - bool* changed) { - if (type_ == type && format_ == format && dims_ == dims) { - // if not changed, simply return the current descriptor. - if (changed) - *changed = false; - return desc_; - } - CAFFE_ENFORCE_EQ( - dims.size(), 4U, "Currently only 4-dimensional descriptor supported."); - format_ = format; - type_ = type; - dims_ = dims; - CUDNN_ENFORCE(cudnnSetTensor4dDescriptor( - desc_, - format, - type, - dims_[0], - (format == CUDNN_TENSOR_NCHW ? dims_[1] : dims_[3]), - (format == CUDNN_TENSOR_NCHW ? dims_[2] : dims_[1]), - (format == CUDNN_TENSOR_NCHW ? dims_[3] : dims_[2]))); - if (changed) - *changed = true; - return desc_; - } - - template - inline cudnnTensorDescriptor_t Descriptor( - const StorageOrder& order, - const vector& dims) { - return Descriptor( - GetCudnnTensorFormat(order), cudnnTypeWrapper::type, dims, nullptr); - } - - private: - cudnnTensorDescriptor_t desc_; - cudnnTensorFormat_t format_; - cudnnDataType_t type_; - vector dims_; - C10_DISABLE_COPY_AND_ASSIGN(cudnnTensorDescWrapper); -}; - -class cudnnFilterDescWrapper { - public: - cudnnFilterDescWrapper() { - CUDNN_ENFORCE(cudnnCreateFilterDescriptor(&desc_)); - } - ~cudnnFilterDescWrapper() noexcept { - CUDNN_CHECK(cudnnDestroyFilterDescriptor(desc_)); - } - - inline cudnnFilterDescriptor_t Descriptor( - const StorageOrder& order, - const cudnnDataType_t type, - const vector& dims, - bool* changed) { - if (type_ == type && order_ == order && dims_ == dims) { - // if not changed, simply return the current descriptor. - if (changed) - *changed = false; - return desc_; - } - CAFFE_ENFORCE_EQ( - dims.size(), 4U, "Currently only 4-dimensional descriptor supported."); - order_ = order; - type_ = type; - dims_ = dims; - CUDNN_ENFORCE(cudnnSetFilter4dDescriptor( - desc_, - type, - GetCudnnTensorFormat(order), - dims_[0], - // TODO - confirm that this is correct for NHWC - (order == StorageOrder::NCHW ? dims_[1] : dims_[3]), - (order == StorageOrder::NCHW ? dims_[2] : dims_[1]), - (order == StorageOrder::NCHW ? dims_[3] : dims_[2]))); - if (changed) - *changed = true; - return desc_; - } - - template - inline cudnnFilterDescriptor_t Descriptor( - const StorageOrder& order, - const vector& dims) { - return Descriptor(order, cudnnTypeWrapper::type, dims, nullptr); - } - - private: - cudnnFilterDescriptor_t desc_; - StorageOrder order_; - cudnnDataType_t type_; - vector dims_; - C10_DISABLE_COPY_AND_ASSIGN(cudnnFilterDescWrapper); -}; - - -} // namespace caffe2 - -#endif // CAFFE2_CORE_COMMON_CUDNN_H_ diff --git a/caffe2/core/common_gpu.cc b/caffe2/core/common_gpu.cc deleted file mode 100644 index e5a26359d3f2..000000000000 --- a/caffe2/core/common_gpu.cc +++ /dev/null @@ -1,253 +0,0 @@ -#include "caffe2/core/common_gpu.h" - -#include -#include -#include -#include - -#include - -#include "caffe2/core/common.h" -#include "caffe2/core/init.h" -#include "caffe2/core/logging.h" - -namespace caffe2 { - -int NumCudaDevices() { - if (getenv("CAFFE2_DEBUG_CUDA_INIT_ORDER")) { - static bool first = true; - if (first) { - first = false; - std::cerr << "DEBUG: caffe2::NumCudaDevices() invoked for the first time" - << std::endl; - } - } - // It logs warnings on first run - return c10::cuda::device_count(); -} - -namespace { -int gDefaultGPUID = 0; -} // namespace - -void SetDefaultGPUID(const int deviceid) { - CAFFE_ENFORCE_LT( - deviceid, - NumCudaDevices(), - "The default gpu id should be smaller than the number of gpus " - "on this machine: ", - deviceid, - " vs ", - NumCudaDevices()); - gDefaultGPUID = deviceid; -} - -int GetDefaultGPUID() { return gDefaultGPUID; } - -int CaffeCudaGetDevice() { - int gpu_id = 0; - CUDA_ENFORCE(cudaGetDevice(&gpu_id)); - return gpu_id; -} - -void CaffeCudaSetDevice(const int id) { - CUDA_ENFORCE(cudaSetDevice(id)); -} - -int GetGPUIDForPointer(const void* ptr) { - cudaPointerAttributes attr; - cudaError_t err = cudaPointerGetAttributes(&attr, ptr); - - if (err == cudaErrorInvalidValue) { - // Occurs when the pointer is in the CPU address space that is - // unmanaged by CUDA; make sure the last error state is cleared, - // since it is persistent - err = cudaGetLastError(); - CHECK(err == cudaErrorInvalidValue); - return -1; - } - - // Otherwise, there must be no error - CUDA_ENFORCE(err); - - if (attr.type == cudaMemoryTypeHost) { - return -1; - } - - return attr.device; -} - -struct CudaDevicePropWrapper { - CudaDevicePropWrapper() : props(NumCudaDevices()) { - for (int i = 0; i < NumCudaDevices(); ++i) { - CUDA_ENFORCE(cudaGetDeviceProperties(&props[i], i)); - } - } - - vector props; -}; - -const cudaDeviceProp& GetDeviceProperty(const int deviceid) { - // According to C++11 standard section 6.7, static local variable init is - // thread safe. See - // https://stackoverflow.com/questions/8102125/is-local-static-variable-initialization-thread-safe-in-c11 - // for details. - static CudaDevicePropWrapper props; - CAFFE_ENFORCE_LT( - deviceid, - NumCudaDevices(), - "The gpu id should be smaller than the number of gpus ", - "on this machine: ", - deviceid, - " vs ", - NumCudaDevices()); - return props.props[deviceid]; -} - -void DeviceQuery(const int device) { - const cudaDeviceProp& prop = GetDeviceProperty(device); - std::stringstream ss; - ss << std::endl; - ss << "Device id: " << device << std::endl; - ss << "Major revision number: " << prop.major << std::endl; - ss << "Minor revision number: " << prop.minor << std::endl; - ss << "Name: " << prop.name << std::endl; - ss << "Total global memory: " << prop.totalGlobalMem << std::endl; - ss << "Total shared memory per block: " << prop.sharedMemPerBlock - << std::endl; - ss << "Total registers per block: " << prop.regsPerBlock << std::endl; - ss << "Warp size: " << prop.warpSize << std::endl; -#if !defined(USE_ROCM) - ss << "Maximum memory pitch: " << prop.memPitch << std::endl; -#endif - ss << "Maximum threads per block: " << prop.maxThreadsPerBlock - << std::endl; - ss << "Maximum dimension of block: " - << prop.maxThreadsDim[0] << ", " << prop.maxThreadsDim[1] << ", " - << prop.maxThreadsDim[2] << std::endl; - ss << "Maximum dimension of grid: " - << prop.maxGridSize[0] << ", " << prop.maxGridSize[1] << ", " - << prop.maxGridSize[2] << std::endl; - ss << "Clock rate: " << prop.clockRate << std::endl; - ss << "Total constant memory: " << prop.totalConstMem << std::endl; -#if !defined(USE_ROCM) - ss << "Texture alignment: " << prop.textureAlignment << std::endl; - ss << "Concurrent copy and execution: " - << (prop.deviceOverlap ? "Yes" : "No") << std::endl; -#endif - ss << "Number of multiprocessors: " << prop.multiProcessorCount - << std::endl; -#if !defined(USE_ROCM) - ss << "Kernel execution timeout: " - << (prop.kernelExecTimeoutEnabled ? "Yes" : "No") << std::endl; -#endif - LOG(INFO) << ss.str(); - return; -} - -bool GetCudaPeerAccessPattern(vector >* pattern) { - int gpu_count; - if (cudaGetDeviceCount(&gpu_count) != cudaSuccess) return false; - pattern->clear(); - pattern->resize(gpu_count, vector(gpu_count, false)); - for (int i = 0; i < gpu_count; ++i) { - for (int j = 0; j < gpu_count; ++j) { - int can_access = true; - if (i != j) { - if (cudaDeviceCanAccessPeer(&can_access, i, j) - != cudaSuccess) { - return false; - } - } - (*pattern)[i][j] = static_cast(can_access); - } - } - return true; -} - -bool TensorCoreAvailable() { - int device = CaffeCudaGetDevice(); - auto& prop = GetDeviceProperty(device); - - return prop.major >= 7; -} - -const char* cublasGetErrorString(cublasStatus_t error) { - switch (error) { - case CUBLAS_STATUS_SUCCESS: - return "CUBLAS_STATUS_SUCCESS"; - case CUBLAS_STATUS_NOT_INITIALIZED: - return "CUBLAS_STATUS_NOT_INITIALIZED"; - case CUBLAS_STATUS_ALLOC_FAILED: - return "CUBLAS_STATUS_ALLOC_FAILED"; - case CUBLAS_STATUS_INVALID_VALUE: - return "CUBLAS_STATUS_INVALID_VALUE"; - case CUBLAS_STATUS_ARCH_MISMATCH: - return "CUBLAS_STATUS_ARCH_MISMATCH"; - case CUBLAS_STATUS_INTERNAL_ERROR: - return "CUBLAS_STATUS_INTERNAL_ERROR"; - case CUBLAS_STATUS_MAPPING_ERROR: - return "CUBLAS_STATUS_MAPPING_ERROR"; - case CUBLAS_STATUS_EXECUTION_FAILED: - return "CUBLAS_STATUS_EXECUTION_FAILED"; - case CUBLAS_STATUS_NOT_SUPPORTED: - return "CUBLAS_STATUS_NOT_SUPPORTED"; -#if !defined(USE_ROCM) - case CUBLAS_STATUS_LICENSE_ERROR: - return "CUBLAS_STATUS_LICENSE_ERROR"; -#endif - } - // To suppress compiler warning. - return "Unrecognized cublas error string"; -} - -const char* curandGetErrorString(curandStatus_t error) { - switch (error) { - case CURAND_STATUS_SUCCESS: - return "CURAND_STATUS_SUCCESS"; - case CURAND_STATUS_VERSION_MISMATCH: - return "CURAND_STATUS_VERSION_MISMATCH"; - case CURAND_STATUS_NOT_INITIALIZED: - return "CURAND_STATUS_NOT_INITIALIZED"; - case CURAND_STATUS_ALLOCATION_FAILED: - return "CURAND_STATUS_ALLOCATION_FAILED"; - case CURAND_STATUS_TYPE_ERROR: - return "CURAND_STATUS_TYPE_ERROR"; - case CURAND_STATUS_OUT_OF_RANGE: - return "CURAND_STATUS_OUT_OF_RANGE"; - case CURAND_STATUS_LENGTH_NOT_MULTIPLE: - return "CURAND_STATUS_LENGTH_NOT_MULTIPLE"; - case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED: - return "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED"; - case CURAND_STATUS_LAUNCH_FAILURE: - return "CURAND_STATUS_LAUNCH_FAILURE"; - case CURAND_STATUS_PREEXISTING_FAILURE: - return "CURAND_STATUS_PREEXISTING_FAILURE"; - case CURAND_STATUS_INITIALIZATION_FAILED: - return "CURAND_STATUS_INITIALIZATION_FAILED"; - case CURAND_STATUS_ARCH_MISMATCH: - return "CURAND_STATUS_ARCH_MISMATCH"; - case CURAND_STATUS_INTERNAL_ERROR: - return "CURAND_STATUS_INTERNAL_ERROR"; -#if defined(USE_ROCM) - case HIPRAND_STATUS_NOT_IMPLEMENTED: - return "HIPRAND_STATUS_NOT_IMPLEMENTED"; -#endif - } - // To suppress compiler warning. - return "Unrecognized curand error string"; -} - -// Turn on the flag g_caffe2_has_cuda_linked to true for HasCudaRuntime() -// function. -namespace { -class CudaRuntimeFlagFlipper { - public: - CudaRuntimeFlagFlipper() { - internal::SetCudaRuntimeFlag(); - } -}; -static CudaRuntimeFlagFlipper g_flipper; -} // namespace - -} // namespace caffe2 diff --git a/caffe2/core/common_gpu.h b/caffe2/core/common_gpu.h deleted file mode 100644 index 011f46264b19..000000000000 --- a/caffe2/core/common_gpu.h +++ /dev/null @@ -1,475 +0,0 @@ -#ifndef CAFFE2_CORE_COMMON_GPU_H_ -#define CAFFE2_CORE_COMMON_GPU_H_ - -#include -#include -#include - -#if !defined(USE_ROCM) -#ifdef __GNUC__ -#if __GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6) -#pragma GCC diagnostic push -#endif -#pragma GCC diagnostic ignored "-Wstrict-aliasing" -#endif // __GNUC__ -#endif // USE_ROCM - -#include -#include -#include - -#include "caffe2/core/common.h" -#include "caffe2/core/logging.h" - -#include "c10/cuda/CUDAMacros.h" -#include "c10/cuda/CUDAMathCompat.h" -#include - -#define CAFFE2_CUDA_EXPORT C10_EXPORT - -// CAFFE2_CUDA_API gets translated to CAFFE2_HIP_API in hipify script, which -// causes a marco redefinition issue with the later definition of -// CAFFE2_HIP_API, so we exclude this definition when HIP is specified -#if !defined(USE_ROCM) -#define CAFFE2_CUDA_API TORCH_CUDA_CPP_API -#endif // USE_ROCM - -//TODO: [ROCm] Need to remove this after CUDA->HIP mapping is updated. -#define CAFFE2_HIP_EXPORT C10_EXPORT -#define CAFFE2_HIP_API TORCH_HIP_API - -// This is a macro defined for cuda fp16 support. In default, cuda fp16 is -// supported by NVCC 7.5, but it is also included in the Tegra X1 platform with -// a (custom?) NVCC 7.0. As a result, we would normally just check the cuda -// version here, but would also allow a use to pass in the flag -// CAFFE_HAS_CUDA_FP16 manually. - -#ifndef CAFFE_HAS_CUDA_FP16 -#define CAFFE_HAS_CUDA_FP16 -#endif // CAFFE_HAS_CUDA_FP16 - -#ifdef CAFFE_HAS_CUDA_FP16 -#include -#endif - -// cuda major revision number below which fp16 compute is not supoorted -#if !defined(USE_ROCM) -constexpr int kFp16CUDADevicePropMajor = 6; -#else -constexpr int kFp16CUDADevicePropMajor = 3; -#endif - -// Re-enable strict aliasing diagnostic if it was disabled. -#if !defined(USE_ROCM) -#ifdef __GNUC__ -#if __GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6) -#pragma GCC diagnostic pop -#endif -#endif // __GNUC__ -#endif // USE_ROCM - -/** - * The maximum number of peers that each gpu can have when doing p2p setup. - * Currently, according to NVidia documentation, each device can support a - * system-wide maximum of eight peer connections. - * When Caffe2 sets up peer access resources, if we have more than 8 gpus, - * we will enable peer access in groups of 8. - */ -#define CAFFE2_CUDA_MAX_PEER_SIZE 8 - -namespace caffe2 { - -#if !defined(USE_ROCM) -/** - * Empty class to identify TensorCore-based math - */ -class TensorCoreEngine {}; -#endif // USE_ROCM - -/** - * A runtime function to report the cuda version that Caffe2 is built with. - */ -inline int CudaVersion() { -#if defined(USE_ROCM) - return ROCM_VERSION; -#else - return CUDA_VERSION; -#endif -} - -/** - * Returns the number of devices. - */ -CAFFE2_CUDA_API int NumCudaDevices(); - -/** - * Check if the current running session has a cuda gpu present. - * - * Note that this is different from having caffe2 built with cuda. Building - * Caffe2 with cuda only guarantees that this function exists. If there are no - * cuda gpus present in the machine, or there are hardware configuration - * problems like an insufficient driver, this function will still return false, - * meaning that there is no usable GPU present. - * - * In the open source build, it is possible that Caffe2's GPU code is - * dynamically loaded, and as a result a library could be only linked to the - * CPU code, but want to test if cuda is later available or not. In this case, - * one should use HasCudaRuntime() from common.h. - */ -inline bool HasCudaGPU() { - return NumCudaDevices() > 0; -} - -/** - * Gets the current GPU id. This is a simple wrapper around cudaGetDevice(). - */ -CAFFE2_CUDA_API int CaffeCudaGetDevice(); - -/** - * Gets the current GPU id. This is a simple wrapper around cudaGetDevice(). - */ -CAFFE2_CUDA_API void CaffeCudaSetDevice(const int id); - -/** - * Gets the GPU id that the current pointer is located at. - */ -CAFFE2_CUDA_API int GetGPUIDForPointer(const void* ptr); - -/** - * Gets the device property for the given device. This function is thread safe. - * The initial run on this function is ~1ms/device; however, the results are - * cached so subsequent runs should be much faster. - */ -CAFFE2_CUDA_API const cudaDeviceProp& GetDeviceProperty(const int device); - -/** - * Runs a device query function and prints out the results to LOG(INFO). - */ -CAFFE2_CUDA_API void DeviceQuery(const int deviceid); - -/** - * Return a peer access pattern by returning a matrix (in the format of a - * nested vector) of boolean values specifying whether peer access is possible. - * - * This function returns false if anything wrong happens during the query of - * the GPU access pattern. - */ -CAFFE2_CUDA_API bool GetCudaPeerAccessPattern(vector>* pattern); - -/** - * Return the availability of TensorCores for math - */ -CAFFE2_CUDA_API bool TensorCoreAvailable(); - -/** - * Return a human readable cublas error string. - */ -CAFFE2_CUDA_API const char* cublasGetErrorString(cublasStatus_t error); - -/** - * Return a human readable curand error string. - */ -CAFFE2_CUDA_API const char* curandGetErrorString(curandStatus_t error); - -// CUDA: various checks for different function calls. -#define CUDA_ENFORCE(condition, ...) \ - do { \ - cudaError_t error = condition; \ - CAFFE_ENFORCE_EQ( \ - error, \ - cudaSuccess, \ - "Error at: ", \ - __FILE__, \ - ":", \ - __LINE__, \ - ": ", \ - cudaGetErrorString(error), \ - ##__VA_ARGS__); \ - } while (0) -#define CUDA_CHECK(condition) \ - do { \ - cudaError_t error = condition; \ - CHECK(error == cudaSuccess) << cudaGetErrorString(error); \ - } while (0) - -#define CUDA_DRIVERAPI_ENFORCE(condition) \ - do { \ - CUresult result = condition; \ - if (result != CUDA_SUCCESS) { \ - const char* msg; \ - cuGetErrorName(result, &msg); \ - CAFFE_THROW("Error at: ", __FILE__, ":", __LINE__, ": ", msg); \ - } \ - } while (0) -#define CUDA_DRIVERAPI_CHECK(condition) \ - do { \ - CUresult result = condition; \ - if (result != CUDA_SUCCESS) { \ - const char* msg; \ - cuGetErrorName(result, &msg); \ - LOG(FATAL) << "Error at: " << __FILE__ << ":" << __LINE__ << ": " \ - << msg; \ - } \ - } while (0) - -#define CUBLAS_ENFORCE(condition) \ - do { \ - cublasStatus_t status = condition; \ - CAFFE_ENFORCE_EQ( \ - status, \ - CUBLAS_STATUS_SUCCESS, \ - "Error at: ", \ - __FILE__, \ - ":", \ - __LINE__, \ - ": ", \ - ::caffe2::cublasGetErrorString(status)); \ - } while (0) -#define CUBLAS_CHECK(condition) \ - do { \ - cublasStatus_t status = condition; \ - CHECK(status == CUBLAS_STATUS_SUCCESS) \ - << ::caffe2::cublasGetErrorString(status); \ - } while (0) - -#define CURAND_ENFORCE(condition) \ - do { \ - curandStatus_t status = condition; \ - CAFFE_ENFORCE_EQ( \ - status, \ - CURAND_STATUS_SUCCESS, \ - "Error at: ", \ - __FILE__, \ - ":", \ - __LINE__, \ - ": ", \ - ::caffe2::curandGetErrorString(status)); \ - } while (0) -#define CURAND_CHECK(condition) \ - do { \ - curandStatus_t status = condition; \ - CHECK(status == CURAND_STATUS_SUCCESS) \ - << ::caffe2::curandGetErrorString(status); \ - } while (0) - -#define CUDA_1D_KERNEL_LOOP(i, n) \ - for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ - i += blockDim.x * gridDim.x) - -#define CUDA_2D_KERNEL_LOOP(i, n, j, m) \ - for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ - i += blockDim.x * gridDim.x) \ - for (size_t j = blockIdx.y * blockDim.y + threadIdx.y; j < (m); \ - j += blockDim.y * gridDim.y) - -// The following helper functions are here so that you can write a kernel call -// when you are not particularly interested in maxing out the kernels' -// performance. Usually, this will give you a reasonable speed, but if you -// really want to find the best performance, it is advised that you tune the -// size of the blocks and grids more reasonably. -// A legacy note: this is derived from the old good Caffe days, when I simply -// hard-coded the number of threads and wanted to keep backward compatibility -// for different computation capabilities. -// For more info on CUDA compute capabilities, visit the NVidia website at: -// http://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#compute-capabilities - -// The number of cuda threads to use. Since work is assigned to SMs at the -// granularity of a block, 128 is chosen to allow utilizing more SMs for -// smaller input sizes. -// 1D grid -constexpr int CAFFE_CUDA_NUM_THREADS = 128; -// 2D grid -constexpr int CAFFE_CUDA_NUM_THREADS_2D_DIMX = 16; -constexpr int CAFFE_CUDA_NUM_THREADS_2D_DIMY = 16; - -// The maximum number of blocks to use in the default kernel call. We set it to -// 4096 which would work for compute capability 2.x (where 65536 is the limit). -// This number is very carelessly chosen. Ideally, one would like to look at -// the hardware at runtime, and pick the number of blocks that makes most -// sense for the specific runtime environment. This is a todo item. -// 1D grid -constexpr int CAFFE_MAXIMUM_NUM_BLOCKS = 4096; -// 2D grid -constexpr int CAFFE_MAXIMUM_NUM_BLOCKS_2D_DIMX = 128; -constexpr int CAFFE_MAXIMUM_NUM_BLOCKS_2D_DIMY = 128; - -constexpr int kCUDAGridDimMaxX = 2147483647; -constexpr int kCUDAGridDimMaxY = 65535; -constexpr int kCUDAGridDimMaxZ = 65535; - -/** - * @brief Compute the number of blocks needed to run N threads. - */ -inline int CAFFE_GET_BLOCKS(const int N) { - return std::max( - std::min( - (N + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS, - CAFFE_MAXIMUM_NUM_BLOCKS), - // Use at least 1 block, since CUDA does not allow empty block - 1); -} - -/** - * @brief Compute the number of blocks needed to run N threads for a 2D grid - */ -inline dim3 CAFFE_GET_BLOCKS_2D(const int N, const int /* M */) { - dim3 grid; - // Not calling the 1D version for each dim to keep all constants as literals - - grid.x = std::max( - std::min( - (N + CAFFE_CUDA_NUM_THREADS_2D_DIMX - 1) / - CAFFE_CUDA_NUM_THREADS_2D_DIMX, - CAFFE_MAXIMUM_NUM_BLOCKS_2D_DIMX), - // Use at least 1 block, since CUDA does not allow empty block - 1); - - grid.y = std::max( - std::min( - (N + CAFFE_CUDA_NUM_THREADS_2D_DIMY - 1) / - CAFFE_CUDA_NUM_THREADS_2D_DIMY, - CAFFE_MAXIMUM_NUM_BLOCKS_2D_DIMY), - // Use at least 1 block, since CUDA does not allow empty block - 1); - - return grid; -} - -using CUDAGuard = c10::cuda::CUDAGuard; - -template -struct SimpleArray { - T data[N]; -}; - -constexpr int kCUDATensorMaxDims = 8; - -#define DISPATCH_FUNCTION_BY_VALUE_WITH_TYPE_1(val, Func, T, ...) \ - do { \ - CAFFE_ENFORCE_LE(val, kCUDATensorMaxDims); \ - switch (val) { \ - case 1: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - case 2: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - case 3: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - case 4: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - case 5: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - case 6: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - case 7: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - case 8: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - default: { \ - break; \ - } \ - } \ - } while (false) - -#define DISPATCH_FUNCTION_BY_VALUE_WITH_TYPE_2(val, Func, T1, T2, ...) \ - do { \ - CAFFE_ENFORCE_LE(val, kCUDATensorMaxDims); \ - switch (val) { \ - case 1: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - case 2: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - case 3: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - case 4: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - case 5: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - case 6: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - case 7: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - case 8: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - default: { \ - break; \ - } \ - } \ - } while (false) - -#define DISPATCH_FUNCTION_BY_VALUE_WITH_TYPE_3(val, Func, T1, T2, T3, ...) \ - do { \ - CAFFE_ENFORCE_LE(val, kCUDATensorMaxDims); \ - switch (val) { \ - case 1: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - case 2: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - case 3: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - case 4: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - case 5: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - case 6: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - case 7: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - case 8: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - default: { \ - break; \ - } \ - } \ - } while (false) - -} // namespace caffe2 - -#endif // CAFFE2_CORE_COMMON_GPU_H_ diff --git a/caffe2/core/context.h b/caffe2/core/context.h deleted file mode 100644 index eb46f78f8b0d..000000000000 --- a/caffe2/core/context.h +++ /dev/null @@ -1,227 +0,0 @@ -#ifndef CAFFE2_CORE_CONTEXT_H_ -#define CAFFE2_CORE_CONTEXT_H_ - -#include -#include -#include -#include - -#include -#include "caffe2/core/allocator.h" -#include "caffe2/core/context_base.h" -#include "caffe2/core/event.h" -#include "caffe2/core/logging.h" -#include "caffe2/proto/caffe2_pb.h" - -#include - -#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) -#include -#include -#include -#include -#else -#include "caffe2/core/distributions_stubs.h" -#endif - -C10_DECLARE_bool(caffe2_report_cpu_memory_usage); - -namespace caffe2 { - -/** - * A function to generate a random number seed that is unique in a best-effort - * basis, using an ever-incrementing seed and the current time. - */ -TORCH_API uint32_t RandomNumberSeed(); - -/** - * The CPU Context, representing the bare minimum of what a Context class in - * Caffe2 should implement. - * - * // TODO modify docs - * See operator.h, especially Operator, for how Context are used in - * actual operator implementations that are associated with specific devices. - * In general, the Context class is passed in as a template argument, and - * the operator can use the functions defined in the context to execute whatever - * computation it has. - * - */ -class TORCH_API CPUContext final : public BaseContext { - public: -#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) - class rand_gen_type { - public: - explicit rand_gen_type(uint64_t seed_in = default_rng_seed_val) - : engine_{seed_in} {} - - uint32_t random() { - return engine_(); - } - uint64_t random64() { - uint32_t random1 = engine_(); - uint32_t random2 = engine_(); - return (static_cast(random1) << 32) | random2; - } - - std::optional next_float_normal_sample() { - return next_float_normal_sample_; - } - std::optional next_double_normal_sample() { - return next_double_normal_sample_; - } - void set_next_float_normal_sample(std::optional randn) { - next_float_normal_sample_ = randn; - } - void set_next_double_normal_sample(std::optional randn) { - next_double_normal_sample_ = randn; - } - - private: - at::mt19937 engine_; - std::optional next_float_normal_sample_; - std::optional next_double_normal_sample_; - }; -#else - typedef std::mt19937 rand_gen_type; -#endif - - CPUContext() {} - explicit CPUContext(const DeviceOption& option) - : random_seed_(option.has_random_seed() ? option.random_seed() : 1701), - random_seed_set_(option.has_random_seed() ? true : false) { - CAFFE_ENFORCE_EQ(option.device_type(), PROTO_CPU); - } - explicit CPUContext(const at::Device& device) - : CPUContext(DeviceToOption(device)) {} - - ~CPUContext() noexcept override {} - - inline void SwitchToDevice(int64_t /*stream_id*/) override {} - - using BaseContext::SwitchToDevice; - - inline void WaitEvent(const Event& ev) override { - ev.Wait(CPU, this); - } - - inline void Record(Event* ev, const char* err_msg = nullptr) const override { - CAFFE_ENFORCE(ev, "Event must not be null."); - ev->Record(CPU, this, err_msg); - } - - inline void FinishDeviceComputation() override {} - - inline rand_gen_type* RandGenerator() { - if (!random_generator_.get()) { - random_generator_.reset(new rand_gen_type(RandSeed())); - } - return random_generator_.get(); - } - - inline uint32_t RandSeed() { - if (!random_seed_set_) { - random_seed_ = RandomNumberSeed(); - random_seed_set_ = true; - } - return static_cast(random_seed_); - } - - inline static at::DataPtr New(size_t nbytes) { - return GetCPUAllocator()->allocate(nbytes); - } - - void CopyBytesSameDevice(size_t nbytes, const void* src, void* dst) override; - - void CopyBytesFromCPU(size_t nbytes, const void* src, void* dst) override { - CopyBytesSameDevice(nbytes, src, dst); - } - - void CopyBytesToCPU(size_t nbytes, const void* src, void* dst) override { - CopyBytesSameDevice(nbytes, src, dst); - } - - bool SupportsNonFundamentalTypes() const override { - // CPU non fumdamental type copy OK - return true; - } - - template - inline void CopyBytes(size_t nbytes, const void* src, void* dst); - - template - inline void Copy(size_t n, const T* src, T* dst) { - if (c10::guts::is_fundamental::value) { - CopyBytes( - n * sizeof(T), - static_cast(src), - static_cast(dst)); - } else { - for (const auto i : c10::irange(n)) { - dst[i] = src[i]; - } - } - } - - template - inline void - CopyItems(const TypeMeta meta, size_t n, const void* src, void* dst) { - if (meta.copy()) { - meta.copy()(src, dst, n); - } else { - CopyBytes(n * meta.itemsize(), src, dst); - } - } - - // By default CPU operators don't have async device parts - static bool HasAsyncPartDefault() { - return false; - } - - static bool SupportsAsyncScheduling() { - return false; - } - - // CPU streams are not implemented and are silently ignored by CPU ops, - // return true to signal executor to schedule a CPU op - static bool IsStreamFree( - const DeviceOption& /* option */, - int /* stream_id */) { - return true; - } - - at::Device device() const override { - // TODO: numa? - return at::Device(CPU); - } - - DeviceType device_type() const override { - return CPU; - } - - static constexpr DeviceType GetDeviceType() { - return CPU; - } - - protected: - // TODO(jiayq): instead of hard-coding a generator, make it more flexible. - int random_seed_{1701}; - bool random_seed_set_{false}; - std::unique_ptr random_generator_; -}; - -template <> -inline void CPUContext::CopyBytes( - size_t nbytes, - const void* src, - void* dst) { - if (nbytes == 0) { - return; - } - CAFFE_ENFORCE(src); - CAFFE_ENFORCE(dst); - memcpy(dst, src, nbytes); -} - -} // namespace caffe2 - -#endif // CAFFE2_CORE_CONTEXT_H_ diff --git a/caffe2/core/context_base.h b/caffe2/core/context_base.h deleted file mode 100644 index cc8cc4c5bb60..000000000000 --- a/caffe2/core/context_base.h +++ /dev/null @@ -1,168 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -#include "caffe2/core/common.h" -#include "caffe2/core/logging.h" -#include "caffe2/proto/caffe2_pb.h" - -namespace caffe2 { -class Event; - -} // namespace caffe2 -namespace at { - -class BaseContext; - -/** - * Virtual interface for the Context class in Caffe2. - * - * A Context defines all the necessities to run an operator on a specific - * device. Specific Context classes needs to implement all the pure virtual - * functions in the BaseContext class. - * TODO: add docs after this is finalized. - */ -class TORCH_API BaseContext { - public: - virtual ~BaseContext() noexcept {} - - virtual Device device() const = 0; - - /* Sorry for the naming, will get rid of this in future diff */ - virtual DeviceType device_type() const = 0; - - virtual void SwitchToDevice(int64_t /*stream_id*/) = 0; - - inline void SwitchToDevice() { - SwitchToDevice(0); - } - - virtual void WaitEvent(const caffe2::Event& ev) = 0; - - virtual void Record(caffe2::Event* ev, const char* err_msg = nullptr) - const = 0; - - virtual void FinishDeviceComputation() = 0; - - // This used to be arbitrary cross-device copy, but it turns out everyone - // did direct CPU-X copy, so we just make three functions for it (to avoid - // double dispatch). This will get obsoleted by C10. where copies - // will be proper operators (and get to rely on multiple dispatch there.) - virtual void CopyBytesSameDevice( - size_t nbytes, - const void* src, - void* dst) = 0; - - virtual void CopyBytesFromCPU(size_t nbytes, const void* src, void* dst) = 0; - - virtual void CopyBytesToCPU(size_t nbytes, const void* src, void* dst) = 0; - - template - inline void CopySameDevice(size_t n, const T* src, T* dst) { - static_assert( - c10::guts::is_fundamental::value, - "CopySameDevice requires fundamental types"); - CopyBytesSameDevice( - n * sizeof(T), static_cast(src), static_cast(dst)); - } - - template - inline void CopyFromCPU(size_t n, const T* src, T* dst) { - static_assert( - c10::guts::is_fundamental::value, - "CopyFromCPU requires fundamental types"); - CopyBytesFromCPU( - n * sizeof(T), static_cast(src), static_cast(dst)); - } - - template - inline void CopyToCPU(size_t n, const T* src, T* dst) { - static_assert( - c10::guts::is_fundamental::value, "CopyToCPU requires fundamental types"); - CopyBytesToCPU( - n * sizeof(T), static_cast(src), static_cast(dst)); - } - - virtual bool SupportsNonFundamentalTypes() const { - return false; - } - - inline void EnforceMetaCopyOK() { - AT_ASSERTM( - SupportsNonFundamentalTypes(), "Context requires fundamental types"); - } - - void CopyItemsSameDevice( - const caffe2::TypeMeta meta, - size_t n, - const void* src, - void* dst) { - if (meta.copy()) { - EnforceMetaCopyOK(); - meta.copy()(src, dst, n); - } else { - CopyBytesSameDevice(n * meta.itemsize(), src, dst); - } - } - - void CopyItemsFromCPU( - const caffe2::TypeMeta meta, - size_t n, - const void* src, - void* dst) { - if (meta.copy()) { - EnforceMetaCopyOK(); - meta.copy()(src, dst, n); - } else { - CopyBytesFromCPU(n * meta.itemsize(), src, dst); - } - } - - void CopyItemsToCPU( - const caffe2::TypeMeta meta, - size_t n, - const void* src, - void* dst) { - if (meta.copy()) { - EnforceMetaCopyOK(); - meta.copy()(src, dst, n); - } else { - CopyBytesToCPU(n * meta.itemsize(), src, dst); - } - } -}; - -// Context constructor registry -C10_DECLARE_TYPED_REGISTRY( - ContextRegistry, - at::DeviceType, - at::BaseContext, - std::unique_ptr, - at::Device); - -#define REGISTER_CONTEXT(type, ...) \ - C10_REGISTER_TYPED_CLASS(ContextRegistry, type, __VA_ARGS__) - -inline std::unique_ptr CreateContext( - const at::Device& device) { - return at::ContextRegistry()->Create(device.type(), device); -} - -} // namespace at - -namespace caffe2 { - -using at::BaseContext; -using at::CreateContext; -} // namespace caffe2 diff --git a/caffe2/core/context_gpu.cu b/caffe2/core/context_gpu.cu deleted file mode 100644 index ecc933ac7fad..000000000000 --- a/caffe2/core/context_gpu.cu +++ /dev/null @@ -1,669 +0,0 @@ -#include -#include -#include -#include -#include - -#include -#include -#include -#include "cub/util_allocator.cuh" - -// Needed to be included first to check the CAFFE2_USE_CUDNN macros. -#include "caffe2/core/macros.h" - -#include "caffe2/core/blob_stats.h" -#ifdef CAFFE2_USE_CUDNN -#include "caffe2/core/common_cudnn.h" -#endif // CAFFE2_USE_CUDNN -#include "caffe2/core/context_gpu.h" -#include "caffe2/core/init.h" -#include "caffe2/core/logging.h" -#include "caffe2/core/tensor.h" -#include "caffe2/utils/string_utils.h" -#include "caffe2/utils/cub_namespace.cuh" - -C10_DEFINE_string( - caffe2_cuda_memory_pool, - "", - "Sets the memory pool used by caffe2. Possible values are " - "none, cnmem, thc and cub."); - -// For description of CUB caching allocator configuration, see -// https://nvlabs.github.io/cub/structcub_1_1_caching_device_allocator.html -C10_DEFINE_int( - caffe2_cub_bin_growth, - 8, - "If using cub as the memory allocator, sets the growth of bins " - "used by the cub pool."); -C10_DEFINE_int( - caffe2_cub_min_bin, - 3, - "If using cub as the memory allocator, sets the min number of " - "bins."); -C10_DEFINE_int( - caffe2_cub_max_bin, - 10, - "If using cub as the memory allocator, sets the max number of " - "bins."); -C10_DEFINE_int( - caffe2_cub_max_managed_mb, - 10 * 1024, - "If using cub as the memory allocators, sets the maximum amount " - "of memory managed in gigabytes"); - -C10_DEFINE_bool( - caffe2_cub_print_allocation_events, - false, - "If true CachingDeviceAllocator will print allocation and deallocation " - "events to stdout."); - -C10_DEFINE_bool( - caffe2_gpu_memory_tracking, - false, - "If set, logs changes in GPU memory allocations"); -C10_DEFINE_int( - caffe2_gpu_memory_report_interval_mb, - 128, - "The threshold in MB on how frequently to report memory changes"); - -namespace at { - -REGISTER_CONTEXT(DeviceType::CUDA, caffe2::CUDAContext); -} // namespace at - -namespace caffe2 { - -// Generic implementation - CUDA will handle the right function to call for us -void CUDAContext::CopyBytesAsync( - size_t nbytes, - const void* src, - Device src_device, - void* dst, - Device dst_device) { - // TODO: verify that the CUDA handles copy from device to device correctly - // even without SetDevice() - // TODO: verify whether source or dest device should be a priority in picking - // the stream - // NB: right now the cross-device copy logic is invoked only in the contexts - // when surrounding code explicitly manages data dependencies and sets up - // events, so it's fine. In order to make it a standalone function proper - // synchronization between stream is required - int gpu_id = 0; - if (dst_device.is_cuda()) { - gpu_id = dst_device.index(); - } else if (src_device.is_cuda()) { - gpu_id = src_device.index(); - } else { - LOG(FATAL) << "shouldn't be called with non-cuda device"; - } - CUDA_ENFORCE(cudaMemcpyAsync( - dst, - src, - nbytes, - cudaMemcpyDefault, - CUDAContext::getCudaObjects().GetStream(gpu_id))); -} - -void CUDAContext::CopyBytesSync( - size_t nbytes, - const void* src, - Device src_device, - void* dst, - Device dst_device) { - // This emulates Caffe2 original behavior where sync copy doesn't change the - // device. It's probably better for clarity to switch to the target device - // explicitly here, but in the worst case CUDA would sync for us. - // TODO: change it to CUDAGuard - CUDAContext context(-1); // take current device - CUDA_ENFORCE(cudaMemcpyAsync( - dst, src, nbytes, cudaMemcpyDefault, context.cuda_stream())); - // destructor of context synchronizes -} - -// For the CPU context, we also allow a (probably expensive) function -// to copy the data from a cuda context. Inside the function, we create -// a temporary CUDAContext object to carry out the copy. From the caller's -// side, these functions are synchronous with respect to the host, similar -// to a normal CPUContext::CopyBytes call. -template <> -inline void CPUContext::CopyBytes( - size_t nbytes, - const void* src, - void* dst) { - CUDAContext context(GetGPUIDForPointer(src)); - context.CopyBytes(nbytes, src, dst); -} -template <> -inline void CPUContext::CopyBytes( - size_t nbytes, - const void* src, - void* dst) { - CUDAContext context(GetGPUIDForPointer(dst)); - context.CopyBytes(nbytes, src, dst); -} - -} // namespace caffe2 - -namespace caffe2 { - -ThreadLocalCUDAObjects& CUDAContext::getCudaObjects() { - static thread_local ThreadLocalCUDAObjects cuda_objects_; - return cuda_objects_; -} - -// TODO(jiayq): these variables shouldn't be currently accessed during static -// initialization. We should consider moving them to a Mayer's singleton to -// be totally safe against SIOF. - -// Static global variables for setting up the memory pool. -CudaMemoryPoolType g_cuda_memory_pool_type; - -std::unique_ptr g_cub_allocator; - -// an unordered map that holds the map from the cuda memory pointer to the -// device id that it is allocated from. This is used in the cuda memory pool -// cases, where we need the device id to carry out the deletion. -// Note(jiayq): an alternate approach is to use cudaGetPointerAttributes, but -// that is usually quite slow. We might want to benchmark the speed difference -// though. -// Note(jiayq): another alternate approach is to augment the Tensor class that -// would allow one to record the device id. However, this does not address any -// non-tensor allocation and deallocation. -// Ideally, a memory pool should already have the device id information, as -// long as we are using UVA (as of CUDA 5 and later) so the addresses are -// unique. -static std::unordered_map g_cuda_device_affiliation; - -// Data structures for optional memory tracking. Access to these structures -// is guarded by the CUDAContext::mutex. -static std::unordered_map g_size_map; -static std::vector g_total_by_gpu_map(C10_COMPILE_TIME_MAX_GPUS, 0); -static std::vector g_max_by_gpu_map(C10_COMPILE_TIME_MAX_GPUS, 0); - -static long g_total_mem = 0; -static long g_last_rep = 0; - -CudaMemoryPoolType GetCudaMemoryPoolType() { - return g_cuda_memory_pool_type; -} - -/////////////////////////////////////////////////////////////////////////////// -// A wrapper to allow us to lazily initialize all cuda environments that Caffe -// uses. This gets done the first time a caffe2::CUDAContext::New() gets called -// which is probably the decisive indication that this caffe2 run is going to -// use GPUs. We avoid cuda initialization with core/init.h functionalities so -// that we have minimal resource impact in case we will need to run multiple -// caffe2 instances on a GPU machine. -/////////////////////////////////////////////////////////////////////////////// - -static void Caffe2InitializeCuda() { - // If the current run does not have any cuda devices, do nothing. - if (!HasCudaGPU()) { - VLOG(1) << "No cuda gpu present. Skipping."; - return; - } - C10_LOG_API_USAGE_ONCE("caffe2.init.cuda"); - // Check if the number of GPUs matches the expected compile-time max number - // of GPUs. - CAFFE_ENFORCE_LE( - NumCudaDevices(), - C10_COMPILE_TIME_MAX_GPUS, - "Number of CUDA devices on the machine is larger than the compiled " - "max number of gpus expected (", - C10_COMPILE_TIME_MAX_GPUS, - "). Increase that and recompile."); - - for (DeviceIndex i = 0; i < NumCudaDevices(); ++i) { - CUDAGuard g(i); - // Enable peer access. - const int peer_group = i / CAFFE2_CUDA_MAX_PEER_SIZE; - const int peer_start = peer_group * CAFFE2_CUDA_MAX_PEER_SIZE; - const int peer_end = std::min( - NumCudaDevices(), (peer_group + 1) * CAFFE2_CUDA_MAX_PEER_SIZE); - VLOG(1) << "Enabling peer access within group #" << peer_group - << ", from gpuid " << peer_start << " to " << peer_end - 1 - << ", for gpuid " << i << "."; - - for (int j = peer_start; j < peer_end; ++j) { - if (i == j) continue; - int can_access; - CUDA_ENFORCE(cudaDeviceCanAccessPeer(&can_access, i, j)); - if (can_access) { - VLOG(1) << "Enabling peer access from " << i << " to " << j; - // Note: just for future reference, the 0 here is not a gpu id, it is - // a reserved flag for cudaDeviceEnablePeerAccess that should always be - // zero currently. - // It is ok if peer access is already enabled... - cudaError_t err = C10_CUDA_ERROR_HANDLED(cudaDeviceEnablePeerAccess(j, 0)); - if ((err != cudaErrorPeerAccessAlreadyEnabled) && - (err != cudaSuccess)) { - CAFFE_THROW(cudaGetErrorString(err)); - } - cudaGetLastError(); // reset cuda error code - } - } - } - -#ifdef CAFFE2_USE_CUDNN - // Check the versions of cuDNN that were compiled and linked with are compatible - CheckCuDNNVersions(); -#endif // CAFFE2_USE_CUDNN -} - -static void SetUpCub() { - VLOG(1) << "Setting up cub memory pool."; - // Sets up the cub memory pool - try { - g_cub_allocator.reset(new cub::CachingDeviceAllocator( - FLAGS_caffe2_cub_bin_growth, - FLAGS_caffe2_cub_min_bin, - FLAGS_caffe2_cub_max_bin, - size_t(FLAGS_caffe2_cub_max_managed_mb) * 1024L * 1024L, - false, - FLAGS_caffe2_cub_print_allocation_events)); - } catch (...) { - CAFFE_THROW("Some error happened at cub initialization."); - } - VLOG(1) << "Done setting up cub memory pool."; -} - -static void Caffe2SetCUDAMemoryPool() { - if (FLAGS_caffe2_cuda_memory_pool == "" || - FLAGS_caffe2_cuda_memory_pool == "none") { - g_cuda_memory_pool_type = CudaMemoryPoolType::NONE; - } else if (FLAGS_caffe2_cuda_memory_pool == "cnmem") { - CAFFE_THROW("CNMEM is no longer used by Caffe2. Use cub instead. " - "This error message may go away in the future."); - } else if (FLAGS_caffe2_cuda_memory_pool == "cub") { - // Sets up cub. - g_cuda_memory_pool_type = CudaMemoryPoolType::CUB; - SetUpCub(); - } else if (FLAGS_caffe2_cuda_memory_pool == "thc") { - g_cuda_memory_pool_type = CudaMemoryPoolType::THC; - // Initialize caching allocator - at::globalContext().lazyInitCUDA(); - } else { - CAFFE_THROW( - "Unrecognized cuda memory pool type: ", FLAGS_caffe2_cuda_memory_pool); - } -} - -/** - * An allocator that does the CPU memory allocation with pinned memory. - * - * This is needed because if we want to do any asynchronous cuda memcpy, - * the underlying CPU memory also needs to be allocated into pinned memory - * space. As a result, whenever Caffe2 is built with GPU and there is - * GPU present during runtime, at global initialization time we will set - * the CPU memory allocator to allocate pinned memory. - * - * NB: This behavior is probably too aggressive. We should consider asking users - * to do on-demand memory pinning (like exposed in PyTorch APIs) instead. - */ -struct CAFFE2_CUDA_API PinnedCPUAllocator final : public at::Allocator { - PinnedCPUAllocator() { - baseAllocator_ = GetDefaultCPUAllocator(); - } - ~PinnedCPUAllocator() override {} - at::DataPtr allocate(size_t nbytes) override { - if (nbytes == 0) { - // replicate c10::alloc_cpu behavior - return nullptr - return {nullptr, nullptr, &Delete, at::Device(CPU)}; - } - void* data; - at::DataPtr data_ptr; - std::lock_guard lock(CUDAContext::mutex()); - if (IsNUMAEnabled()) { - at::DeleterFnPtr expected_deleter = baseAllocator_->raw_deleter(); - data_ptr = baseAllocator_->allocate(nbytes); - data = data_ptr.get(); - CAFFE_ENFORCE(data); - CUDA_ENFORCE(cudaHostRegister(data, nbytes, cudaHostRegisterDefault)); - CAFFE_ENFORCE( - data_ptr.compare_exchange_deleter(expected_deleter, &Delete), - "Failed to swap deleter (already swapped?)"); - } else { - CUDA_ENFORCE(cudaMallocHost(&data, nbytes)); - profiledCPUMemoryReporter().New(data, nbytes); - data_ptr = {data, data, &Delete, at::Device(CPU)}; - } - memset(data, 0, nbytes); - return data_ptr; - } - - at::DeleterFnPtr raw_deleter() const override { - return &Delete; - } - - void copy_data(void* dest, const void* src, std::size_t count) const final { - TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for PinnedCPUAllocator"); - } - - private: - static void Delete(void* data) { - if (!data) { - return; - } - // Caffe2 uses a lazy way to figure out if one is actually going to use GPUs - // or not. If a CUDAContext::New() call is made, inside the CUDAContext - // function we will switch the cpu side allocator to a PinnedCPUAllocator. - // But, if one calls CPUContext::New() before any cuda allocations, - // PinnedCPUAllocator can still delete the corresponding memory. - std::lock_guard lock(CUDAContext::mutex()); - if (IsNUMAEnabled()) { - CUDA_ENFORCE(cudaHostUnregister(data)); - GetDefaultCPUAllocator()->raw_deleter()(data); - } else { - cudaError_t err = C10_CUDA_ERROR_HANDLED(cudaFreeHost(data)); - profiledCPUMemoryReporter().Delete(data); - if (err == cudaErrorInvalidValue) { - free(data); - // Calling cudaGetLastError will reset the cuda error. - cudaError_t _err = cudaGetLastError(); - } else { - // For all other errors, still do a cuda check. - CUDA_ENFORCE(err); - } - } - } - - at::Allocator* baseAllocator_; -}; - -static PinnedCPUAllocator g_pinned_cpu_alloc; - -// An initialization function that sets the CPU side to use pinned cpu -// allocator. -void Caffe2UsePinnedCPUAllocator() { -#if C10_ASAN_ENABLED - // Note(jiayq): for more details, see - // https://github.com/google/sanitizers/issues/629 - LOG(WARNING) << "There are known issues between address sanitizer and " - "cudaMallocHost. As a result, caffe2 will not enable pinned " - "memory allocation in asan mode. If you are expecting any " - "behavior that depends on asan, be advised that it is not " - "turned on."; -#else - if (!HasCudaGPU()) { - VLOG(1) << "No GPU present. I won't use pinned allocator then."; - return; - } - VLOG(1) << "Caffe2 gpu: setting CPUAllocator to PinnedCPUAllocator."; - - // If CUDA is enabled, using CPU allocators other than PinnedCPUAllocator - // will cause memory corruptions. Therefore, we need to set the priority - // to highest to avoid being overwritten. - SetCPUAllocator( - &g_pinned_cpu_alloc, - std::numeric_limits::max() /* priority */); -#endif -} - -// Caffe2CudaInitializerHelper is a minimal struct whose sole purpose is to -// detect the first hint that this Caffe2 run is going to use GPU: either -// CUDAContext is initialized or CUDAContext::New is called. It then runs -// all the related cuda initialization functions. -namespace { -struct Caffe2CudaInitializerHelper { - Caffe2CudaInitializerHelper() { - // We cannot use bool because nvcc changes bool to __nv_bool which does - // not have a std::atomic instantiation. - static std::atomic first_call(1); - if (first_call.fetch_and((char)0)) { - Caffe2InitializeCuda(); - Caffe2SetCUDAMemoryPool(); - Caffe2UsePinnedCPUAllocator(); - } - } -}; -} // namespace - -/** - * A utility function to rectify the gpu id. If the context specifies the - * gpu id to be -1, it means that we will just use the current gpu id when - * the function is being called. - */ -static inline DeviceIndex RectifyGPUID(DeviceIndex gpu_id) { - return gpu_id == -1 ? CaffeCudaGetDevice() : gpu_id; -} - -CUDAContext::CUDAContext(DeviceIndex gpu_id) - : gpu_id_(RectifyGPUID(gpu_id)), random_seed_(RandomNumberSeed()) { - static Caffe2CudaInitializerHelper g_cuda_initializer_; -} - -CUDAContext::CUDAContext(const DeviceOption& option) - : gpu_id_( - option.has_device_id() ? RectifyGPUID(option.device_id()) - : CaffeCudaGetDevice()), - random_seed_( - option.has_random_seed() ? option.random_seed() - : RandomNumberSeed()) { - static Caffe2CudaInitializerHelper g_cuda_initializer_; - TORCH_DCHECK_EQ(option.device_type(), PROTO_CUDA); -} - -CUDAContext::~CUDAContext() { - try { - if (curand_generator_) { - CURAND_CHECK(curandDestroyGenerator(curand_generator_)); - } - // CUDAContext is used in 2 cases now: - // - long-lived instance inside OperatorBase in which case what happens in - // destructor doesn't really matter - // - short-lived on-the-fly instances that are utilized as CUDAGuard - in - // this case there's only one stream id (passed to SwitchToDevice) and - // it's preferrable to synchronize in the destructor - FinishDeviceComputation(); - } catch (const std::exception& e) { - LOG(ERROR) << "Encountered following in " << __FUNCTION__ << ": " << e.what(); - } -} - -// shared mutex to lock out alloc / free during NCCL launches -std::mutex& CUDAContext::mutex() { - static std::mutex m; - return m; -} - -std::vector CUDAContext::TotalMemoryByGpu() { - std::lock_guard lock(CUDAContext::mutex()); - CAFFE_ENFORCE( - FLAGS_caffe2_gpu_memory_tracking, - "Pass --caffe2_gpu_memory_tracking to enable memory stats"); - return g_total_by_gpu_map; -} - -std::vector CUDAContext::MaxMemoryByGpu() { - std::lock_guard lock(CUDAContext::mutex()); - CAFFE_ENFORCE( - FLAGS_caffe2_gpu_memory_tracking, - "Pass --caffe2_gpu_memory_tracking to enable memory stats"); - return g_max_by_gpu_map; -} - -namespace { -void TrackMemoryAlloc(size_t nbytes) { - int this_gpu = CaffeCudaGetDevice(); - g_total_by_gpu_map[this_gpu] += nbytes; - g_max_by_gpu_map[this_gpu] = - std::max(g_max_by_gpu_map[this_gpu], g_total_by_gpu_map[this_gpu]); - g_total_mem += nbytes; - if (g_total_mem - g_last_rep > - FLAGS_caffe2_gpu_memory_report_interval_mb * 1024 * 1024) { - for (int gpu = 0; gpu < g_total_by_gpu_map.size(); gpu++) { - long t = g_total_by_gpu_map[gpu]; - long max_t = g_max_by_gpu_map[gpu]; - if (max_t > 0) { - if (max_t != t) { - VLOG(1) << "GPU " << gpu << ": " << t / 1024 / 1024 << " MB" - << " (max: " << max_t / 1024 / 1024 << " MB)"; - } else { - VLOG(1) << "GPU " << gpu << ": " << t / 1024 / 1024 << " MB"; - } - } - } - VLOG(1) << "Total: " << g_total_mem / 1024 / 1024 << " MB"; - g_last_rep = g_total_mem; - } -} -} - -struct DefaultCUDAAllocator final : public at::Allocator { - DefaultCUDAAllocator() {} - ~DefaultCUDAAllocator() override {} - at::DataPtr allocate(size_t nbytes) override { - // Lock the mutex - std::lock_guard lock(CUDAContext::mutex()); - // A one-time caffe2 cuda initializer. - static Caffe2CudaInitializerHelper g_cuda_initializer_; - void* ptr = nullptr; - - if (FLAGS_caffe2_gpu_memory_tracking) { - TrackMemoryAlloc(nbytes); - } - switch (g_cuda_memory_pool_type) { - case CudaMemoryPoolType::NONE: - if (nbytes != 0) { - CUDA_ENFORCE(cudaMalloc(&ptr, nbytes)); - } - if (FLAGS_caffe2_gpu_memory_tracking) { - g_size_map[ptr] = nbytes; - g_cuda_device_affiliation[ptr] = CaffeCudaGetDevice(); - } - return {ptr, ptr, &Delete, at::Device(CUDA, CaffeCudaGetDevice())}; - case CudaMemoryPoolType::CUB: - if (nbytes != 0) { - CUDA_ENFORCE(g_cub_allocator->DeviceAllocate(&ptr, nbytes)); - } - g_cuda_device_affiliation[ptr] = CaffeCudaGetDevice(); - VLOG(2) << "CUB allocating pointer " << ptr << " on device " - << CaffeCudaGetDevice(); - if (FLAGS_caffe2_gpu_memory_tracking) { - g_size_map[ptr] = nbytes; - } - return {ptr, ptr, &Delete, at::Device(CUDA, CaffeCudaGetDevice())}; - case CudaMemoryPoolType::THC: - { - // The reason we have this stream guard here is to preserve - // the historical behavior of the 'thc' allocator in Caffe2, - // which is to put all allocations on the same (default) - // stream. This behavior is morally wrong (since passing - // allocations between streams allows for the possibility - // of you handing out some memory that an old stream - // is still working on), but it doesn't seem to cause issues - // in Caffe2 today. Our hypothesis for why this is the case - // is that Caffe2 doesn't really do very many allocations - // on the fly; instead they allocate once and then reuse - // the allocations for the whole program. In this case, - // the hazard is avoided. - // - // We intend to remove this stream guard, but the benefit - // to putting all allocations on the same stream is it - // reduces per-stream fragmentation, and this helps - // some models that are currently running with the thc - // allocator fit in memory. We will need to find some - // way of resolving this problem. - c10::cuda::CUDAStreamGuard g( - Stream( - Stream::DEFAULT, - Device(kCUDA, CaffeCudaGetDevice()) - )); - ptr = c10::cuda::CUDACachingAllocator::raw_alloc(nbytes); - } - if (FLAGS_caffe2_gpu_memory_tracking) { - g_size_map[ptr] = nbytes; - g_cuda_device_affiliation[ptr] = CaffeCudaGetDevice(); - } - return {ptr, ptr, &Delete, at::Device(CUDA, CaffeCudaGetDevice())}; - } - return {nullptr, nullptr, &Delete, at::Device(CUDA, CaffeCudaGetDevice())}; - } - - at::DeleterFnPtr raw_deleter() const override { - return &Delete; - } - - void copy_data(void* dest, const void* src, std::size_t count) const final { - TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for DefaultCUDAAllocator"); - } - - private: - static void Delete(void* ptr) { - // lock the mutex - std::lock_guard lock(CUDAContext::mutex()); - if (FLAGS_caffe2_gpu_memory_tracking) { - auto sz_it = g_size_map.find(ptr); - DCHECK(sz_it != g_size_map.end()); - auto aff_it = g_cuda_device_affiliation.find(ptr); - DCHECK(aff_it != g_cuda_device_affiliation.end()); - g_total_mem -= sz_it->second; - g_total_by_gpu_map[aff_it->second] -= sz_it->second; - g_size_map.erase(sz_it); - } - - switch (g_cuda_memory_pool_type) { - case CudaMemoryPoolType::NONE: { - // If memory pool is not set up, use simple cudaFree. - cudaError_t error = C10_CUDA_ERROR_HANDLED(cudaFree(ptr)); - // For some reason, in Python runtime we sometimes delete a data pointer - // after the cuda runtime exits - this is odd but is probably caused by - // a static workspace that pycaffe2 uses, and the destruction got - // entangled in some race condition. Anyway, since cuda runtime is - // exiting anyway, we will not need to worry about memory leak, so we - // basically ignore it. This is definitely not ideal but works for now. - if (error != cudaSuccess && error != cudaErrorCudartUnloading) { - LOG(FATAL) << "Error at: " << __FILE__ << ":" << __LINE__ << ": " - << cudaGetErrorString(error); - } - - if (FLAGS_caffe2_gpu_memory_tracking) { - g_cuda_device_affiliation.erase(g_cuda_device_affiliation.find(ptr)); - } - - break; - } - case CudaMemoryPoolType::CUB: { - auto it = g_cuda_device_affiliation.find(ptr); - DCHECK(it != g_cuda_device_affiliation.end()); - VLOG(2) << "CUB freeing pointer " << ptr << " on device " << it->second; - CUDA_ENFORCE(g_cub_allocator->DeviceFree(it->second, ptr)); - g_cuda_device_affiliation.erase(it); - break; - } - case CudaMemoryPoolType::THC: { - c10::cuda::CUDACachingAllocator::raw_delete(ptr); - if (FLAGS_caffe2_gpu_memory_tracking) { - g_cuda_device_affiliation.erase(g_cuda_device_affiliation.find(ptr)); - } - break; - } - } - } -}; - -static DefaultCUDAAllocator g_cuda_alloc; -REGISTER_ALLOCATOR(CUDA, &g_cuda_alloc); - -} // namespace caffe2 - -namespace at { -REGISTER_COPY_BYTES_FUNCTION( - DeviceType::CUDA, - DeviceType::CUDA, - caffe2::CUDAContext::CopyBytesSync, - caffe2::CUDAContext::CopyBytesAsync); - -REGISTER_COPY_BYTES_FUNCTION( - DeviceType::CUDA, - DeviceType::CPU, - caffe2::CUDAContext::CopyBytesSync, - caffe2::CUDAContext::CopyBytesAsync); - -REGISTER_COPY_BYTES_FUNCTION( - DeviceType::CPU, - DeviceType::CUDA, - caffe2::CUDAContext::CopyBytesSync, - caffe2::CUDAContext::CopyBytesAsync); -} // namespace at diff --git a/caffe2/core/context_gpu.h b/caffe2/core/context_gpu.h deleted file mode 100644 index 8490a5002e5f..000000000000 --- a/caffe2/core/context_gpu.h +++ /dev/null @@ -1,354 +0,0 @@ -#ifndef CAFFE2_CORE_CONTEXT_GPU_H_ -#define CAFFE2_CORE_CONTEXT_GPU_H_ - -#include -#include - -#include "caffe2/core/common.h" -#include "caffe2/core/common_gpu.h" -#include "caffe2/core/context.h" -#include "caffe2/core/context_base.h" -#include "caffe2/core/logging.h" -#include "caffe2/core/numa.h" -#include "caffe2/core/tensor.h" -#include "caffe2/core/types.h" -#include "caffe2/proto/caffe2_pb.h" - -// Since we are using the macro CAFFE2_USE_CUDNN, we will need to include this -// file after common.h is included. -#ifdef CAFFE2_USE_CUDNN -#include "caffe2/core/common_cudnn.h" -#endif // CAFFE2_USE_CUDNN - -#include -#include -#include -#include - -namespace caffe2 { - -enum class CudaMemoryPoolType { - NONE = 0, - CUB = 1, - THC = 2, -}; - -/** - * Gets the current memory pool type used by Caffe2. - * - * The memory pool is set up during caffe2's global initialization time. - */ -CAFFE2_CUDA_API CudaMemoryPoolType GetCudaMemoryPoolType(); - -/** - * A struct to host thread-local cuda objects. - * - * In Caffe2, each thread has its own non-default cuda stream as well as - * related objects such as cublas and curand handles. This is achieved by - * having the ThreadLocalCUDAObjects wrapper that takes care of allocating - * and deallocating these objects at the thread scope. This class is solely - * used inside CUDAContext and should not be used externally. - * - * This class manages the mapping from logical stream ID (int stream_id - * passed around in Caffe2) and CUDAStream objects. We intend to eventually - * deprecate the logical stream ID interface, but not for now. - */ -class CAFFE2_CUDA_API ThreadLocalCUDAObjects { - friend class CUDAContext; - - private: - ThreadLocalCUDAObjects() { - for (DeviceIndex i = 0; i < C10_COMPILE_TIME_MAX_GPUS; ++i) { - cuda_streams_[i] = vector(); - } - } - - // Record current stream id for the current thread. - // This is the new API we're trying to migrate use cases to and get rid of - // explicit stream id passing. For now it's invoked in - // CUDAContext::SwitchToDevice - void SetCurrentStreamId(DeviceIndex gpu, StreamId stream_id) { - // TODO: use current device id from thread local instead of passing gpu in - if (stream_id != -1) { - c10::cuda::setCurrentCUDAStream(GetCUDAStream(gpu, stream_id)); - } - } - - // Retrieves the CUDAStream corresponding to a logical stream ID, ensuring - // that it exists in cuda_streams_ if it has not been allocated yet. - c10::cuda::CUDAStream GetCUDAStream(DeviceIndex gpu, StreamId stream_id) { - vector& gpu_streams = cuda_streams_[gpu]; - while (gpu_streams.size() <= static_cast(stream_id)) { - // NB: This streams are not guaranteed to be unique; we'll - // wrap around once we run out of streams in the pool. - gpu_streams.emplace_back(c10::cuda::getStreamFromPool(/* high priority */ false, gpu)); - } - return gpu_streams[stream_id]; - } - - // Uses the logical stream id from the thread local to pick the stream - // We're going to migrate all usages to this case API instead of passing the - // stream id directly - cudaStream_t GetStream(DeviceIndex gpu) { - return c10::cuda::getCurrentCUDAStream(gpu).stream(); - } - - cudaStream_t GetStream(DeviceIndex gpu, StreamId stream_id) { - return GetCUDAStream(gpu, stream_id).stream(); - } - - // Uses the logical stream id from the thread local to pick the stream - // We're going to migrate all usages to this case API instead of passing the - // stream id directly - cublasHandle_t GetHandle(DeviceIndex gpu) { - return GetHandle(c10::cuda::getCurrentCUDAStream(gpu)); - } - - cublasHandle_t GetHandle(c10::cuda::CUDAStream cuda_stream) { - CUDAGuard guard(cuda_stream.device_index()); - // Default construct in the map if it doesn't exist, and return a mutable - // reference to it. - auto& r = cublas_handles_[cuda_stream]; - if (r == nullptr) { - CUBLAS_ENFORCE(cublasCreate(&r)); - // The default is CUBLAS_POINTER_MODE_HOST. You can override - // it after obtaining the cublas handle, but do that with - // caution. - CUBLAS_ENFORCE(cublasSetPointerMode(r, CUBLAS_POINTER_MODE_HOST)); - CUBLAS_ENFORCE(cublasSetStream(r, cuda_stream)); - } - return r; - } - -#ifdef CAFFE2_USE_CUDNN - // Uses the logical stream id from the thread local to pick the stream - // We're going to migrate all usages to this case API instead of passing the - // stream id directly - cudnnHandle_t GetCudnnHandle(DeviceIndex gpu) { - return GetCudnnHandle(c10::cuda::getCurrentCUDAStream(gpu)); - } - - cudnnHandle_t GetCudnnHandle(c10::cuda::CUDAStream cuda_stream) { - CUDAGuard guard(cuda_stream.device_index()); - auto& r = cudnn_handles_[cuda_stream]; - if (r == nullptr) { - CUDNN_ENFORCE(cudnnCreate(&r)); - CUDNN_ENFORCE(cudnnSetStream(r, cuda_stream)); - } - return r; - } -#endif // CAFFE2_USE_CUDNN - - ~ThreadLocalCUDAObjects() noexcept { - for (auto element : cublas_handles_) { - if (element.second) { - CUBLAS_CHECK(cublasDestroy(element.second)); - } - } -#ifdef CAFFE2_USE_CUDNN - for (auto element : cudnn_handles_) { - if (element.second) { -#ifdef _WIN32 - // this is because of something dumb in the ordering of - // destruction. Sometimes at exit, the cuda context would already - // be destroyed by the time this gets destroyed. This happens on - // windows with cuda 11 and cuda 12. - cudnnDestroy(element.second); -#else - CUDNN_CHECK(cudnnDestroy(element.second)); -#endif // _WIN32 - } - } -#endif // CAFFE2_USE_CUDNN - } - // WARNING: mapping from logical stream ID to c10::cuda::CUDAStream - // is NOT bijective; multiple logical stream IDs may map to the - // same underlying stream ID. - vector cuda_streams_[C10_COMPILE_TIME_MAX_GPUS]; - std::unordered_map cublas_handles_; -#ifdef CAFFE2_USE_CUDNN - std::unordered_map cudnn_handles_; -#endif // CAFFE2_USE_CUDNN -}; - -class CAFFE2_CUDA_API CUDAContext final : public BaseContext { - public: - // The default cuda context constructor. - explicit CUDAContext(DeviceIndex gpu_id = -1); - explicit CUDAContext(const DeviceOption& option); - explicit CUDAContext(Device device) - : CUDAContext(DeviceToOption(device)) {} - - ~CUDAContext() override; - - inline void SwitchToDevice(StreamId stream_id) override { - getCudaObjects().SetCurrentStreamId(gpu_id_, stream_id); - CaffeCudaSetDevice(gpu_id_); - } - - // void SwitchToDevice() - using BaseContext::SwitchToDevice; - - inline void WaitEvent(const Event& ev) override { - ev.Wait(CUDA, this); - } - - inline void Record(Event* ev, const char* err_msg = nullptr) const override { - CAFFE_ENFORCE(ev, "Event must not be null."); - ev->Record(CUDA, this, err_msg); - } - - // Note on current use cases: - // FinishDeviceComputation must be called on the same cpu thread as - // SwitchToDevice() - void FinishDeviceComputation() override { - CUDA_ENFORCE(cudaStreamSynchronize(getCudaObjects().GetStream(gpu_id_))); - } - - inline int device_id() const { - return gpu_id_; - } - - inline c10::cuda::CUDAStream stream() const { - return at::cuda::getStreamFromExternal(getCudaObjects().GetStream(gpu_id_), gpu_id_); - } - - inline cudaStream_t cuda_stream() const { - return getCudaObjects().GetStream(gpu_id_); - } - - static cudaStream_t cuda_stream(DeviceIndex gpu_id, StreamId stream_id) { - return getCudaObjects().GetStream(gpu_id, stream_id); - } - - cublasHandle_t cublas_handle() { - return getCudaObjects().GetHandle(gpu_id_); - } - -#ifdef CAFFE2_USE_CUDNN - cudnnHandle_t cudnn_handle() { - return getCudaObjects().GetCudnnHandle(gpu_id_); - } -#endif // CAFFE2_USE_CUDNN - - curandGenerator_t& curand_generator() { - if (!curand_generator_) { - CUDAGuard guard(gpu_id_); - CURAND_ENFORCE( - curandCreateGenerator(&curand_generator_, CURAND_RNG_PSEUDO_DEFAULT)); - CURAND_ENFORCE( - curandSetPseudoRandomGeneratorSeed(curand_generator_, random_seed_)); - TORCH_CHECK_NOTNULL(curand_generator_); - } - CURAND_ENFORCE(curandSetStream(curand_generator_, cuda_stream())); - return curand_generator_; - } - - inline static at::DataPtr New(size_t nbytes) { - return GetAllocator(CUDA)->allocate(nbytes); - } - - // Get a mutex to lock out cudaMalloc / cudaFree calls when - // NCCL kernels are being launched. Should remove threat of - // deadlocks - static std::mutex& mutex(); - - // Functions to query memory stats. Only available if flag - // --caffe2_gpu_memory_tracking is enabled. - static std::vector TotalMemoryByGpu(); - static std::vector MaxMemoryByGpu(); - - template - inline void CopyBytes(size_t nbytes, const void* src, void* dst) { - CUDA_ENFORCE(cudaMemcpyAsync( - dst, - src, - nbytes, - cudaMemcpyDefault, - getCudaObjects().GetStream(gpu_id_))); - } - - void CopyBytesSameDevice(size_t nbytes, const void* src, void* dst) override { - CopyBytes(nbytes, src, dst); - } - - void CopyBytesToCPU(size_t nbytes, const void* src, void* dst) override { - CopyBytes(nbytes, src, dst); - } - - void CopyBytesFromCPU(size_t nbytes, const void* src, void* dst) override { - CopyBytes(nbytes, src, dst); - } - - template - inline void Copy(int n, const T* src, T* dst) { - CopyBytes(n * sizeof(T), - static_cast(src), - static_cast(dst)); - } - - template - inline void - CopyItems(const TypeMeta meta, size_t n, const void* src, void* dst) { - CAFFE_ENFORCE(!meta.copy(), "CUDAContext requires fundamental types."); - CopyBytes(n * meta.itemsize(), src, dst); - } - - static void CopyBytesAsync( - size_t nbytes, - const void* src, - Device src_device, - void* dst, - Device dst_device); - static void CopyBytesSync( - size_t nbytes, - const void* src, - Device src_device, - void* dst, - Device dst_device); - - // By default CUDA operators have async device parts - static bool HasAsyncPartDefault() { - return true; - } - - static bool SupportsAsyncScheduling() { - return true; - } - - static bool IsStreamFree(const DeviceOption& option, StreamId stream_id) { - const auto stream = CUDAContext::cuda_stream(option.device_id(), stream_id); - const auto status = C10_CUDA_ERROR_HANDLED(cudaStreamQuery(stream)); - if (status == cudaErrorNotReady) { - // ignore and clear the error if not ready - C10_CUDA_CLEAR_ERROR(); - } else { - C10_CUDA_CHECK(status); // Reraise error - } - return status == cudaSuccess; - } - - at::Device device() const override { - return at::Device(CUDA, gpu_id_); - } - - DeviceType device_type() const override { - return CUDA; - } - - static constexpr DeviceType GetDeviceType() { - return CUDA; - } - - protected: - int gpu_id_; - int random_seed_; - curandGenerator_t curand_generator_{nullptr}; - static ThreadLocalCUDAObjects& getCudaObjects(); -}; - -using TensorCUDA = Tensor; - -} // namespace caffe2 - -#endif // CAFFE2_CORE_CONTEXT_GPU_H_ diff --git a/caffe2/core/event_gpu.cc b/caffe2/core/event_gpu.cc deleted file mode 100644 index 82000de79011..000000000000 --- a/caffe2/core/event_gpu.cc +++ /dev/null @@ -1,227 +0,0 @@ -#include "caffe2/core/context_gpu.h" -#include "caffe2/core/event_cpu.h" -#include "caffe2/core/operator.h" - -#include -#include - -namespace caffe2 { - -struct CudaEventWrapper { - explicit CudaEventWrapper(const DeviceOption& option) - : cuda_stream_(nullptr), - device_id_(option.device_id()), - status_(EventStatus::EVENT_INITIALIZED) { - CAFFE_ENFORCE(option.device_type(), PROTO_CUDA); - CUDAGuard g(device_id_); - try { - CUDA_ENFORCE(cudaEventCreateWithFlags( - &cuda_event_, cudaEventDefault | cudaEventDisableTiming)); - } catch (const Error&) { - std::cerr << "ERROR: Failed to load CUDA.\n" - << "HINT: Check that this binary contains GPU code." - << std::endl; - throw; - } - } - ~CudaEventWrapper() { - CUDAGuard g(device_id_); - CUDA_CHECK(cudaEventDestroy(cuda_event_)); - } - - cudaEvent_t cuda_event_; - cudaStream_t cuda_stream_; - int device_id_; - - std::atomic status_; - std::mutex mutex_recorded_; - std::condition_variable cv_recorded_; - std::string err_msg_; -}; - -namespace { -const std::string kNoError = "No error"; -} - -void EventCreateCUDA(const DeviceOption& option, Event* event) { - event->event_ = std::make_shared(option); -} - -void EventRecordCUDA(Event* event, const void* context, const char* err_msg) { - auto* wrapper = static_cast(event->event_.get()); - { - std::unique_lock lock(wrapper->mutex_recorded_); - - // Possible state changes: - // INITIALIZED -> SCHEDULED/FAILED - // SCHEDULED -> SUCCESS/FAILED - // SUCCESS/FAILED - terminal - // - // No further changes to cuda_event_ and cuda_stream_ after transitioning - // from INITIALIZED - // No further changes to err_msg_ after transitioning into FAILED - - CAFFE_ENFORCE_EQ( - wrapper->status_, - EventStatus::EVENT_INITIALIZED, - "Calling Record multiple times"); - - if (!err_msg) { - // When recording, one needs to make sure that the current gpu id is - // correct. - // TODO(jiayq): move the enforce logic to the caller? - const auto& current_device = CaffeCudaGetDevice(); - CAFFE_ENFORCE_EQ( - current_device, - wrapper->device_id_, - "When you call EventRecordCUDA, your current device should be the same " - "as the device specified by the event."); - CAFFE_ENFORCE_EQ( - current_device, - static_cast(context)->device_id()); - CUDA_ENFORCE(cudaEventRecord( - wrapper->cuda_event_, - static_cast(context)->cuda_stream())); - wrapper->cuda_stream_ = - static_cast(context)->cuda_stream(); - wrapper->status_ = EventStatus::EVENT_SCHEDULED; - } else { - wrapper->err_msg_ = err_msg; - wrapper->status_ = EventStatus::EVENT_FAILED; - } - } - wrapper->cv_recorded_.notify_all(); -} - -void EventFinishCUDA(const Event* event) { - auto* wrapper = static_cast(event->event_.get()); - { - std::unique_lock lock(wrapper->mutex_recorded_); - while (wrapper->status_ == EventStatus::EVENT_INITIALIZED) { - wrapper->cv_recorded_.wait(lock); - } - } - - if (wrapper->status_ == EventStatus::EVENT_SCHEDULED) { - // ok, even if event is already completed and status was not yet updated - CUDAGuard g(wrapper->device_id_); - auto cudaResult = cudaEventSynchronize(wrapper->cuda_event_); - if (cudaResult == cudaSuccess) { - wrapper->status_ = EventStatus::EVENT_SUCCESS; - } else { - const auto& err_msg = cudaGetErrorString(cudaResult); - - std::unique_lock lock(wrapper->mutex_recorded_); - wrapper->err_msg_ = err_msg; - wrapper->status_ = EventStatus::EVENT_FAILED; - } - } -} - -// Both waiter and event are CUDA. Non-blocking -void EventWaitCUDACUDA(const Event* event, void* context) { - auto* wrapper = static_cast(event->event_.get()); - { - std::unique_lock lock(wrapper->mutex_recorded_); - while (wrapper->status_ == EventStatus::EVENT_INITIALIZED) { - wrapper->cv_recorded_.wait(lock); - } - } - - if (wrapper->status_ == EventStatus::EVENT_SCHEDULED) { - // ok, even if event is already completed and status was not yet updated - auto context_stream = static_cast(context)->cuda_stream(); - auto event_stream = wrapper->cuda_stream_; - if (context_stream != event_stream) { - // CAFFE_ENFORCE_EQ( - // CaffeCudaGetDevice(), - // static_cast(context)->device_id()); - CUDA_CHECK(cudaStreamWaitEvent(context_stream, wrapper->cuda_event_, 0)); - } - } -} - -// Waiter is CPU, event is CUDA -void EventWaitCPUCUDA(const Event* event, void* context) { - EventFinishCUDA(event); -} - -// Waiter is CUDA, event is CPU -void EventWaitCUDACPU(const Event* event, void* context) { - event->Finish(); // calls EventFinishCPU -} - -EventStatus EventQueryCUDA(const Event* event) { - auto* wrapper = static_cast(event->event_.get()); - if (wrapper->status_ == EventStatus::EVENT_SCHEDULED) { - auto cudaResult = cudaEventQuery(wrapper->cuda_event_); - if (cudaResult == cudaSuccess) { - wrapper->status_ = EventStatus::EVENT_SUCCESS; - } else if (cudaResult != cudaErrorNotReady) { - const auto& err_msg = cudaGetErrorString(cudaResult); - - std::unique_lock lock(wrapper->mutex_recorded_); - wrapper->err_msg_ = err_msg; - wrapper->status_ = EventStatus::EVENT_FAILED; - } else { - // ignore and clear the error if not ready - (void)cudaGetLastError(); - } - } - return static_cast(wrapper->status_.load()); -} - -const std::string& EventErrorMessageCUDA(const Event* event) { - auto* wrapper = static_cast(event->event_.get()); - // supposed to be called after EventQueryCUDA to update status first - if (wrapper->status_ == EventStatus::EVENT_FAILED) { - return wrapper->err_msg_; - } else { - return kNoError; - } -} - -void EventSetFinishedCUDA(const Event* event, const char* err_msg) { - auto* wrapper = static_cast(event->event_.get()); - { - std::unique_lock lock(wrapper->mutex_recorded_); - - CAFFE_ENFORCE_EQ( - wrapper->status_, - EventStatus::EVENT_INITIALIZED, - "Calling SetFinished on recorded CUDA event"); - - if (!err_msg) { - wrapper->status_ = EventStatus::EVENT_SUCCESS; - } else { - wrapper->err_msg_ = err_msg; - wrapper->status_ = EventStatus::EVENT_FAILED; - } - } - wrapper->cv_recorded_.notify_all(); -} - -void EventResetCUDA(Event* event) { - auto* wrapper = static_cast(event->event_.get()); - std::unique_lock lock(wrapper->mutex_recorded_); - wrapper->status_ = EventStatus::EVENT_INITIALIZED; - wrapper->err_msg_ = ""; - wrapper->cuda_stream_ = nullptr; -} - -REGISTER_EVENT_CREATE_FUNCTION(CUDA, EventCreateCUDA); -REGISTER_EVENT_RECORD_FUNCTION(CUDA, EventRecordCUDA); -REGISTER_EVENT_WAIT_FUNCTION(CUDA, CUDA, EventWaitCUDACUDA); -REGISTER_EVENT_WAIT_FUNCTION(CPU, CUDA, EventWaitCPUCUDA); -REGISTER_EVENT_WAIT_FUNCTION(CUDA, CPU, EventWaitCUDACPU); -REGISTER_EVENT_FINISH_FUNCTION(CUDA, EventFinishCUDA); - -REGISTER_EVENT_QUERY_FUNCTION(CUDA, EventQueryCUDA); -REGISTER_EVENT_ERROR_MESSAGE_FUNCTION(CUDA, EventErrorMessageCUDA); -REGISTER_EVENT_SET_FINISHED_FUNCTION(CUDA, EventSetFinishedCUDA); -REGISTER_EVENT_RESET_FUNCTION(CUDA, EventResetCUDA); - -REGISTER_EVENT_WAIT_FUNCTION(MKLDNN, CUDA, EventWaitCPUCUDA); -REGISTER_EVENT_WAIT_FUNCTION(CUDA, MKLDNN, EventWaitCUDACPU); - -} // namespace caffe2 diff --git a/caffe2/core/flags.h b/caffe2/core/flags.h deleted file mode 100644 index 54f1f41f2fb3..000000000000 --- a/caffe2/core/flags.h +++ /dev/null @@ -1,4 +0,0 @@ -#pragma once - -#include "c10/util/Flags.h" -#include "caffe2/core/common.h" diff --git a/caffe2/core/hip/common_miopen.h b/caffe2/core/hip/common_miopen.h deleted file mode 100644 index 6901055813cb..000000000000 --- a/caffe2/core/hip/common_miopen.h +++ /dev/null @@ -1,178 +0,0 @@ -/** - * Copyright (c) 2016-present, Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef CAFFE2_CORE_COMMON_MIOPEN_H_ -#define CAFFE2_CORE_COMMON_MIOPEN_H_ - -#include -#include -#include "miopen/miopen.h" -#include "caffe2/core/common.h" -#include "caffe2/core/context.h" -#include "caffe2/core/logging.h" -#include "caffe2/core/types.h" -#include "caffe2/proto/caffe2_pb.h" - -#define MIOPEN_VERSION 1399 - -namespace caffe2 { - -namespace internal { -/** - * A helper function to obtain miopen error strings. - */ -inline const char* miopenGetErrorString(miopenStatus_t status) -{ - switch(status) - { - case miopenStatusSuccess: return "MIOPEN_STATUS_SUCCESS"; - case miopenStatusNotInitialized: return "MIOPEN_STATUS_NOT_INITIALIZED"; - case miopenStatusAllocFailed: return "MIOPEN_STATUS_ALLOC_FAILED"; - case miopenStatusBadParm: return "MIOPEN_STATUS_BAD_PARAM"; - case miopenStatusInternalError: return "MIOPEN_STATUS_INTERNAL_ERROR"; - case miopenStatusInvalidValue: return "MIOPEN_STATUS_INVALID_VALUE"; - case miopenStatusNotImplemented: return "MIOPEN_STATUS_NOT_SUPPORTED"; - case miopenStatusUnknownError: return "MIOPEN_STATUS_UNKNOWN_ERROR"; - default: return "MIOPEN_STATUS_UNKNOWN_ERROR"; - } -} -} // namespace internal - -// A macro that wraps around a miopen statement so we can check if the miopen -// execution finishes or not. -#define MIOPEN_ENFORCE(condition) \ - do \ - { \ - miopenStatus_t status = condition; \ - CAFFE_ENFORCE_EQ(status, \ - miopenStatusSuccess, \ - ", Error at: ", \ - __FILE__, \ - ":", \ - __LINE__, \ - ": ", \ - ::caffe2::internal::miopenGetErrorString(status)); \ - } while(0) -#define MIOPEN_CHECK(condition) \ - do \ - { \ - miopenStatus_t status = condition; \ - CHECK(status == miopenStatusSuccess) << ::caffe2::internal::miopenGetErrorString(status); \ - } while(0) - -// report the version of miopen Caffe2 was compiled with -inline size_t miopenCompiledVersion() { return MIOPEN_VERSION; } - -// report the runtime version of miopen -inline size_t miopenRuntimeVersion() { return MIOPEN_VERSION; } - -// Check compatibility of compiled and runtime miopen versions -inline void CheckMIOPENVersions() {} - -/** - * miopenTypeWrapper is a wrapper class that allows us to refer to the miopen type - * in a template function. The class is specialized explicitly for different - * data types below. - */ -template -class miopenTypeWrapper; - -template <> -class miopenTypeWrapper -{ - public: - static const miopenDataType_t type = miopenFloat; - typedef const float ScalingParamType; - typedef float BNParamType; - static ScalingParamType* kOne() - { - static ScalingParamType v = 1.0; - return &v; - } - static const ScalingParamType* kZero() - { - static ScalingParamType v = 0.0; - return &v; - } -}; - -template <> -class miopenTypeWrapper -{ - public: - static const miopenDataType_t type = miopenHalf; - typedef const float ScalingParamType; - typedef float BNParamType; - static ScalingParamType* kOne() - { - static ScalingParamType v = 1.0; - return &v; - } - static ScalingParamType* kZero() - { - static ScalingParamType v = 0.0; - return &v; - } -}; - -/** - * miopenTensorDescWrapper is the placeholder that wraps around a - * miopenTensorDescriptor_t, allowing us to do descriptor change as-needed during - * runtime. - */ -class miopenTensorDescWrapper -{ - public: - miopenTensorDescWrapper() { MIOPEN_ENFORCE(miopenCreateTensorDescriptor(&desc_)); } - ~miopenTensorDescWrapper() noexcept { MIOPEN_CHECK(miopenDestroyTensorDescriptor(desc_)); } - - inline miopenTensorDescriptor_t - Descriptor(const miopenDataType_t type, const vector& dims, bool* changed) - { - if(type_ == type && dims_ == dims) - { - // if not changed, simply return the current descriptor. - if(changed) - *changed = false; - return desc_; - } - CAFFE_ENFORCE_EQ( - dims.size(), 4, "MIOPEN currently only support 4-dimensional tensor descriptor"); - - type_ = type; - dims_ = dims; - MIOPEN_ENFORCE( - miopenSet4dTensorDescriptor(desc_, type, dims_[0], dims_[1], dims_[2], dims_[3])); - if(changed) - *changed = true; - return desc_; - } - - template - inline miopenTensorDescriptor_t Descriptor(const StorageOrder& order, const vector& dims) - { - return Descriptor(miopenTypeWrapper::type, dims, nullptr); - } - - private: - miopenTensorDescriptor_t desc_; - miopenDataType_t type_; - vector dims_; - C10_DISABLE_COPY_AND_ASSIGN(miopenTensorDescWrapper); -}; - -} // namespace caffe2 - -#endif // CAFFE2_CORE_COMMON_MIOPEN_H_ diff --git a/caffe2/core/hip/common_miopen.hip b/caffe2/core/hip/common_miopen.hip deleted file mode 100644 index a617bad29a3d..000000000000 --- a/caffe2/core/hip/common_miopen.hip +++ /dev/null @@ -1,42 +0,0 @@ -/** - * Copyright (c) 2016-present, Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "caffe2/core/hip/common_miopen.h" -#include "caffe2/core/hip/miopen_wrapper.h" - -#include "caffe2/core/init.h" - -namespace caffe2 { - -MIOPENWrapper::PerGPUMIOPENStates& MIOPENWrapper::miopen_states() -{ - // New it (never delete) to avoid calling the destructors on process - // exit and racing against the CUDA shutdown sequence. - static auto* p = new MIOPENWrapper::PerGPUMIOPENStates(); - TORCH_CHECK_NOTNULL(p); - return *p; -} - -namespace { -bool PrintMIOPENInfo(int*, char***) -{ - VLOG(1) << "Caffe2 is built with MIOPEN version " << MIOPEN_VERSION; - return true; -} - -REGISTER_CAFFE2_INIT_FUNCTION(PrintMIOPENInfo, &PrintMIOPENInfo, "Print MIOPEN Info."); - -} // namespace -} // namespace caffe2 diff --git a/caffe2/core/hip/miopen_wrapper.h b/caffe2/core/hip/miopen_wrapper.h deleted file mode 100644 index f60bed6c277d..000000000000 --- a/caffe2/core/hip/miopen_wrapper.h +++ /dev/null @@ -1,166 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. -#ifndef CAFFE2_CORE_MIOPEN_WRAPPERS_H_ -#define CAFFE2_CORE_MIOPEN_WRAPPERS_H_ - -#include "caffe2/core/hip/common_miopen.h" -#include "caffe2/core/hip/context_gpu.h" - -#include - -namespace caffe2 { - -class MIOPENWrapper; - -/** - * MIOPENWorkspace is a wrapper around a raw cuda pointer that holds the miopen - * scratch space. This struct is meant to be only used in MIOPENWrapper to - * provide a program-wide scratch space for MIOPEN. The reason behind it is that - * miopen function calls are usually very efficient, hence one probably does not - * want to run multiple miopen calls at the same time. As a result, one should - * not need more than one miopen workspace per device. - */ -struct MIOPENWorkspace -{ - ~MIOPENWorkspace() noexcept {} - - void* get(size_t nbytes) - { - if(nbytes_ < nbytes) - { - reset(); - data_ = HIPContext::New(nbytes); - nbytes_ = nbytes; - } - CAFFE_ENFORCE_GE(nbytes_, nbytes); - return data_.get(); - } - - void reset() - { - data_.clear(); - nbytes_ = 0; - } - - private: - at::DataPtr data_; - size_t nbytes_{0}; -}; - -// MIOPENState is the owner of the MIOPENWorkspace, and serializes all -// executions of operations that use the state onto it's own stream -// (so multiple Net workers can reuse the same workspace from -// different threads and HIP streams). -class MIOPENState -{ - public: - explicit MIOPENState(size_t gpu_id) : gpu_id_(gpu_id) - { - HIPGuard g(gpu_id_); - MIOPEN_ENFORCE(miopenCreate(&miopen_handle_)); - HIP_ENFORCE(hipEventCreate(&before_)); - HIP_ENFORCE(hipEventCreate(&after_)); - HIP_ENFORCE(hipStreamCreate(&stream_)); - MIOPEN_ENFORCE(miopenSetStream(miopen_handle_, stream_)); - } - - ~MIOPENState() noexcept - { - HIPGuard g(gpu_id_); - MIOPEN_CHECK(miopenDestroy(miopen_handle_)); - HIP_CHECK(hipStreamDestroy(stream_)); - HIP_CHECK(hipEventDestroy(after_)); - HIP_CHECK(hipEventDestroy(before_)); - } - - miopenHandle_t& miopen_handle() { return miopen_handle_; } - - MIOPENWorkspace& workspace() { return workspace_; } - - template - void execute(hipStream_t stream, F&& f) - { - HIP_ENFORCE(hipEventRecord(before_, stream)); - HIP_ENFORCE(hipStreamWaitEvent(stream_, before_, 0)); - f(this); - HIP_ENFORCE(hipEventRecord(after_, stream_)); - HIP_ENFORCE(hipStreamWaitEvent(stream, after_, 0)); - } - - private: - miopenHandle_t miopen_handle_{nullptr}; - hipEvent_t before_{nullptr}; - hipEvent_t after_{nullptr}; - hipStream_t stream_{nullptr}; - MIOPENWorkspace workspace_; - size_t gpu_id_{0}; - C10_DISABLE_COPY_AND_ASSIGN(MIOPENState); -}; - -/** - * MIOPENWrapper is a class that wraps the miopen handles and miopen workspaces. - * - * The wrapper ensures that for each thread and each gpu, there is one - * identical miopen handle, which is also associated with the thread-local - * per-device hip stream. The wrapper also hosts the device-specific miopen - * workspace (scratch space for some miopen functions). - * - */ -class MIOPENWrapper -{ - public: - /** - * Creates a miopen wrapper associated with a HIPContext object. Note that - * the HIPContext object should outlive the MIOPENWrapper. - */ - explicit MIOPENWrapper(HIPContext* context) : context_(context) {} - - /** - * Returns the inline miopen handle that executes on the current - * thread's hip_stream. - */ - miopenHandle_t inline_miopen_handle() { return context_->miopen_handle(); } - - // Executes the closure F on the MIOPENState associated with state_idx - template - void with_miopen_state(size_t state_idx, F&& f) - { - CAFFE_ENFORCE(state_idx < CAFFE2_COMPILE_TIME_MAX_MIOPEN_STATES, "Invalid state_idx"); - auto& sync_state = miopen_states()[context_->device_id()][state_idx]; - - HIPGuard dg(context_->device_id()); - - // We need to serialize execution on the MIOPENState as we can't - // allow multiple threads to race through the cudaEventRecord - // calls (so a worker thread might wait on another worker thread's - // execution) - std::lock_guard g(sync_state.mutex); - if(!sync_state.state.get()) - { - sync_state.state.reset(new MIOPENState(context_->device_id())); - } - TORCH_CHECK_NOTNULL(sync_state.state.get())->execute(context_->hip_stream(), f); - } - - protected: - // Pointer to an external cuda context that the miopen wrapper will use. - HIPContext* context_; - - static constexpr size_t CAFFE2_COMPILE_TIME_MAX_MIOPEN_STATES = 4; - - struct SyncedMIOPENState - { - std::mutex mutex; - std::unique_ptr state; - }; - - using PerGPUMIOPENStates = std::array< - std::array, - C10_COMPILE_TIME_MAX_GPUS>; - static PerGPUMIOPENStates& miopen_states(); - - C10_DISABLE_COPY_AND_ASSIGN(MIOPENWrapper); -}; - -}; // namespace caffe2 - -#endif diff --git a/caffe2/core/init.h b/caffe2/core/init.h deleted file mode 100644 index 8d0fbd3f1557..000000000000 --- a/caffe2/core/init.h +++ /dev/null @@ -1,179 +0,0 @@ -#ifndef CAFFE2_CORE_INIT_H_ -#define CAFFE2_CORE_INIT_H_ - -#include "caffe2/core/common.h" -#include "caffe2/core/flags.h" -#include "caffe2/core/logging.h" - -namespace caffe2 { - -namespace internal { -class TORCH_API Caffe2InitializeRegistry { - public: - typedef bool (*InitFunction)(int*, char***); - // Registry() is defined in .cpp file to make registration work across - // multiple shared libraries loaded with RTLD_LOCAL - static Caffe2InitializeRegistry* Registry(); - - void Register( - InitFunction function, - bool run_early, - const char* description, - const char* name = nullptr) { - if (name) { - named_functions_[name] = function; - } - if (run_early) { - // Disallow registration after GlobalInit of early init functions - CAFFE_ENFORCE(!early_init_functions_run_yet_); - early_init_functions_.emplace_back(function, description); - } else { - if (init_functions_run_yet_) { - // Run immediately, since GlobalInit already ran. This should be - // rare but we want to allow it in some cases. - LOG(WARNING) << "Running init function after GlobalInit: " - << description; - // TODO(orionr): Consider removing argc and argv for non-early - // registration. Unfortunately that would require a new InitFunction - // typedef, so not making the change right now. - // - // Note that init doesn't receive argc and argv, so the function - // might fail and we want to raise an error in that case. - int argc = 0; - char** argv = nullptr; - bool success = (function)(&argc, &argv); - CAFFE_ENFORCE(success); - } else { - // Wait until GlobalInit to run - init_functions_.emplace_back(function, description); - } - } - } - - bool RunRegisteredEarlyInitFunctions(int* pargc, char*** pargv) { - CAFFE_ENFORCE(!early_init_functions_run_yet_); - early_init_functions_run_yet_ = true; - return RunRegisteredInitFunctionsInternal( - early_init_functions_, pargc, pargv); - } - - bool RunRegisteredInitFunctions(int* pargc, char*** pargv) { - CAFFE_ENFORCE(!init_functions_run_yet_); - init_functions_run_yet_ = true; - return RunRegisteredInitFunctionsInternal(init_functions_, pargc, pargv); - } - - bool RunNamedFunction(const char* name, int* pargc, char*** pargv) { - if (named_functions_.count(name)) { - return named_functions_[name](pargc, pargv); - } - return false; - } - - private: - // Run all registered initialization functions. This has to be called AFTER - // all static initialization are finished and main() has started, since we are - // using logging. - bool RunRegisteredInitFunctionsInternal( - vector>& functions, - int* pargc, char*** pargv) { - for (const auto& init_pair : functions) { - VLOG(1) << "Running init function: " << init_pair.second; - if (!(*init_pair.first)(pargc, pargv)) { - LOG(ERROR) << "Initialization function failed."; - return false; - } - } - return true; - } - - Caffe2InitializeRegistry() {} - vector > early_init_functions_; - vector > init_functions_; - std::unordered_map named_functions_; - bool early_init_functions_run_yet_ = false; - bool init_functions_run_yet_ = false; -}; -} // namespace internal - -TORCH_API bool unsafeRunCaffe2InitFunction( - const char* name, - int* pargc = nullptr, - char*** pargv = nullptr); - -class TORCH_API InitRegisterer { - public: - InitRegisterer( - internal::Caffe2InitializeRegistry::InitFunction function, - bool run_early, - const char* description, - const char* name = nullptr) { - internal::Caffe2InitializeRegistry::Registry()->Register( - function, run_early, description, name); - } -}; - -#define REGISTER_CAFFE2_INIT_FUNCTION(name, function, description) \ - namespace { \ - ::caffe2::InitRegisterer \ - g_caffe2_initregisterer_##name(function, false, description, #name); \ - } // namespace - -#define REGISTER_CAFFE2_EARLY_INIT_FUNCTION(name, function, description) \ - namespace { \ - ::caffe2::InitRegisterer \ - g_caffe2_initregisterer_##name(function, true, description, #name); \ - } // namespace - -/** - * @brief Determine whether GlobalInit has already been run - */ -TORCH_API bool GlobalInitAlreadyRun(); - -class TORCH_API GlobalInitIsCalledGuard { - public: - GlobalInitIsCalledGuard() { - if (!GlobalInitAlreadyRun()) { - LOG(WARNING) - << "Caffe2 GlobalInit should be run before any other API calls."; - } - } -}; - -/** - * @brief Initialize the global environment of caffe2. - * - * Caffe2 uses a registration pattern for initialization functions. Custom - * initialization functions should take the signature - * bool (*func)(int*, char***) - * where the pointers to argc and argv are passed in. Caffe2 then runs the - * initialization in three phases: - * (1) Functions registered with REGISTER_CAFFE2_EARLY_INIT_FUNCTION. Note that - * since it is possible the logger is not initialized yet, any logging in - * such early init functions may not be printed correctly. - * (2) Parses Caffe-specific commandline flags, and initializes caffe logging. - * (3) Functions registered with REGISTER_CAFFE2_INIT_FUNCTION. - * If there is something wrong at each stage, the function returns false. If - * the global initialization has already been run, the function returns false - * as well. - * - * GlobalInit is re-entrant safe; a re-entrant call will no-op and exit. - * - * GlobalInit is safe to call multiple times but not idempotent; - * successive calls will parse flags and re-set caffe2 logging levels from - * flags as needed, but NOT re-run early init and init functions. - * - * GlobalInit is also thread-safe and can be called concurrently. - */ -TORCH_API bool GlobalInit(int* pargc, char*** argv); - -/** - * @brief Initialize the global environment without command line arguments - * - * This is a version of the GlobalInit where no argument is passed in. - * On mobile devices, use this global init, since we cannot pass the - * command line options to caffe2, no arguments are passed. - */ -TORCH_API bool GlobalInit(); -} // namespace caffe2 -#endif // CAFFE2_CORE_INIT_H_ diff --git a/caffe2/core/net.h b/caffe2/core/net.h deleted file mode 100644 index 0726d8e8c6c9..000000000000 --- a/caffe2/core/net.h +++ /dev/null @@ -1,175 +0,0 @@ -#ifndef CAFFE2_CORE_NET_H_ -#define CAFFE2_CORE_NET_H_ - -#include -#include -#include -#include // NOLINT -#include -#include -#include - -#include "c10/core/thread_pool.h" -#include "c10/util/Registry.h" -#include "caffe2/core/blob.h" -#include "caffe2/core/common.h" -#include "caffe2/core/logging.h" -#include "caffe2/core/observer.h" -#include "caffe2/core/operator_schema.h" -#include "caffe2/core/tensor.h" -#include "caffe2/proto/caffe2_pb.h" -#include "caffe2/utils/simple_queue.h" - -C10_DECLARE_string(caffe2_override_executor); - -namespace caffe2 { - -class NetBase; -typedef ObserverBase NetObserver; -typedef std::function(NetBase*)> - NetObserverCreator; - -class OperatorBase; -class Workspace; - -// Net is a thin struct that owns all the operators together with the operator -// contexts. -class TORCH_API NetBase : public Observable { - public: - NetBase(const std::shared_ptr& net_def, Workspace* ws); - virtual ~NetBase() noexcept {} - - virtual bool SupportsAsync() = 0; - inline const vector& events() const { - return events_; - } - - virtual void Wait() { - // by default just wait till all events are finished - for (const auto& event : events_) { - event->Finish(); - } - } - - virtual bool Run() { - if (!RunAsync()) { - LOG(ERROR) << "Failed to execute async run"; - return false; - } - Wait(); - return handleRunError(); - } - - virtual bool RunAsync(); - - virtual void Cancel(); - - /* Benchmarks a network for one individual run so that we can feed new - * inputs on additional calls. - * This function returns the number of microseconds spent - * during the benchmark - */ - virtual float TEST_Benchmark_One_Run(); - - /** - * Benchmarks a network. - * - * This function returns a vector of float recording the number of milli- - * seconds spent during the benchmark. The 0-th item is the time spent per - * each network run, and if a net instantiation supports run_individual, - * the remainder of the vector returns the number of milliseconds spent per - * operator. - */ - virtual vector TEST_Benchmark( - const int /*warmup_runs*/, - const int /*main_runs*/, - const bool /*run_individual*/); - - inline const vector& external_output() const { - return external_output_; - } - - inline const vector& external_input() const { - return external_input_; - } - - /* Used to attach Observers to operators of a Net - * - * Returns pointers to objects owned with unique_ptrs. - * Use with caution. - */ - virtual vector GetOperators() const = 0; - - const string& Name() const { - return name_; - } - - inline const NetDef& debug_def() const { - CAFFE_ENFORCE(has_debug_def(), "net_def was null!"); - return *net_def_; - } - - inline bool has_debug_def() const { - return net_def_ != nullptr; - } - - protected: - virtual bool DoRunAsync() { - CAFFE_THROW("Not implemented"); - }; - - virtual bool handleRunError() { - for (const Event* event : events_) { - if (event->Query() != EventStatus::EVENT_SUCCESS) { - CAFFE_THROW(event->ErrorMessage()); - } - } - return true; - } - - vector external_input_; - vector external_output_; - string name_; - vector events_; - std::shared_ptr net_def_; - C10_DISABLE_COPY_AND_ASSIGN(NetBase); -}; - -class TORCH_API ExecutorHelper { - public: - ExecutorHelper() {} - virtual TaskThreadPoolBase* GetPool(const DeviceOption& option) const; - virtual std::vector GetOperators() const; - virtual int GetNumWorkers() const; - virtual ~ExecutorHelper() {} -}; - -C10_DECLARE_REGISTRY( - NetRegistry, - NetBase, - const std::shared_ptr&, - Workspace*); -#define REGISTER_NET_CREATOR(key, ...) \ - C10_REGISTER_CREATOR(NetRegistry, key, __VA_ARGS__) -#define REGISTER_NET(name, ...) \ - C10_REGISTER_CLASS(NetRegistry, name, __VA_ARGS__) - -/** - * @brief Creates a network, accessing / creating blobs in the given workspace. - * - * Note that this is different from Workspace::CreateNet. The latter adds the - * created net object to the workspace's net map, while this function returns - * a standalone net object. - */ -TORCH_API unique_ptr CreateNet(const NetDef& net_def, Workspace* ws); -TORCH_API unique_ptr CreateNet( - const std::shared_ptr& net_def, - Workspace* ws); - -TORCH_API void AddGlobalNetObserverCreator(NetObserverCreator creator); - -TORCH_API void ClearGlobalNetObservers(); - -} // namespace caffe2 - -#endif // CAFFE2_CORE_NET_H_ diff --git a/caffe2/core/numa.h b/caffe2/core/numa.h deleted file mode 100644 index 8424d544fa38..000000000000 --- a/caffe2/core/numa.h +++ /dev/null @@ -1,3 +0,0 @@ -#pragma once -#include "c10/util/numa.h" -#include "caffe2/core/common.h" diff --git a/caffe2/core/observer.h b/caffe2/core/observer.h deleted file mode 100644 index 3897bb76b52a..000000000000 --- a/caffe2/core/observer.h +++ /dev/null @@ -1,164 +0,0 @@ -#pragma once - -#include -#include - -#include "caffe2/core/logging.h" - -namespace caffe2 { - -/** - * Use this to implement a Observer using the Observer Pattern template. - */ - -template -class ObserverBase { - public: - explicit ObserverBase(T* subject) : subject_(subject) {} - - virtual void Start() {} - virtual void Stop() {} - - virtual std::string debugInfo() { - return "Not implemented."; - } - - virtual ~ObserverBase() noexcept {} - - T* subject() const { - return subject_; - } - - virtual std::unique_ptr> rnnCopy(T* subject, int rnn_order) - const { - return nullptr; - } - - protected: - T* subject_; -}; - -/** - * Inherit to make your class observable. - */ -template -class Observable { - public: - Observable() = default; - - Observable(Observable&&) = default; - Observable& operator =(Observable&&) = default; - - virtual ~Observable() = default; - - C10_DISABLE_COPY_AND_ASSIGN(Observable); - - using Observer = ObserverBase; - - /* Returns a reference to the observer after addition. */ - const Observer* AttachObserver(std::unique_ptr observer) { - CAFFE_ENFORCE(observer, "Couldn't attach a null observer."); - std::unordered_set observers; - for (auto& ob : observers_list_) { - observers.insert(ob.get()); - } - - const auto* observer_ptr = observer.get(); - if (observers.count(observer_ptr)) { - return observer_ptr; - } - observers_list_.push_back(std::move(observer)); - UpdateCache(); - - return observer_ptr; - } - - /** - * Returns a unique_ptr to the removed observer. If not found, return a - * nullptr - */ - std::unique_ptr DetachObserver(const Observer* observer_ptr) { - for (auto it = observers_list_.begin(); it != observers_list_.end(); ++it) { - if (it->get() == observer_ptr) { - auto res = std::move(*it); - observers_list_.erase(it); - UpdateCache(); - return res; - } - } - return nullptr; - } - - virtual size_t NumObservers() { - return num_observers_; - } - - private: - inline static void StartObserver(Observer* observer) { - try { - observer->Start(); - } catch (const std::exception& e) { - LOG(ERROR) << "Exception from observer: " << e.what(); - } catch (...) { - LOG(ERROR) << "Exception from observer: unknown"; - } - } - - inline static void StopObserver(Observer* observer) { - try { - observer->Stop(); - } catch (const std::exception& e) { - LOG(ERROR) << "Exception from observer: " << e.what(); - } catch (...) { - LOG(ERROR) << "Exception from observer: unknown"; - } - } - - void UpdateCache() { - num_observers_ = observers_list_.size(); - if (num_observers_ != 1) { - // we cannot take advantage of the cache - return; - } - observer_cache_ = observers_list_[0].get(); - } - - public: - void StartAllObservers() { - // do not access observers_list_ unless necessary - if (num_observers_ == 0) { - return; - } else if (num_observers_ == 1) { - StartObserver(observer_cache_); - } else { - for (auto& observer : observers_list_) { - StartObserver(observer.get()); - } - } - } - - void StopAllObservers() { - // do not access observers_list_ unless necessary - if (num_observers_ == 0) { - return; - } else if (num_observers_ == 1) { - StopObserver(observer_cache_); - } else { - for (auto& observer : observers_list_) { - StopObserver(observer.get()); - } - } - } - - private: - // an on-stack cache for fast iteration; - // ideally, inside StartAllObservers and StopAllObservers, - // we should never access observers_list_ - Observer* observer_cache_; - size_t num_observers_ = 0; - - protected: - std::vector> observers_list_; -}; - -} // namespace caffe2 diff --git a/caffe2/core/operator.h b/caffe2/core/operator.h deleted file mode 100644 index 3277357b4f34..000000000000 --- a/caffe2/core/operator.h +++ /dev/null @@ -1,1600 +0,0 @@ -#ifndef CAFFE2_CORE_OPERATOR_H_ -#define CAFFE2_CORE_OPERATOR_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include "caffe2/core/blob.h" -#include "caffe2/core/common.h" -#include "caffe2/core/net.h" -#include "caffe2/core/observer.h" -#include "caffe2/core/operator_gradient.h" -#include "caffe2/core/operator_schema.h" -#include "caffe2/core/tensor.h" -#include "caffe2/core/tensor_int8.h" -#include "caffe2/core/types.h" -#include "caffe2/core/workspace.h" -#include "caffe2/proto/caffe2_pb.h" -#include "caffe2/utils/proto_utils.h" - -#if defined(EXPOSE_C2_OPS) || \ - !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) -#include -#include -#include -#endif - -C10_CLANG_DIAGNOSTIC_PUSH() -#if C10_CLANG_HAS_WARNING("-Wshorten-64-to-32") -C10_CLANG_DIAGNOSTIC_IGNORE("-Wshorten-64-to-32") -#endif - -C10_DECLARE_bool(caffe2_operator_throw_if_fp_exceptions); -C10_DECLARE_bool(caffe2_operator_throw_if_fp_overflow_exceptions); -#ifdef __GNU_LIBRARY__ -C10_DECLARE_bool(caffe2_operator_throw_on_first_occurrence_if_fp_exceptions); -#endif - -namespace c10 { -struct FunctionSchema; -} - -namespace caffe2 { - -class TORCH_API OperatorBase; -typedef ObserverBase OperatorObserver; - -class TORCH_API OperatorBase : public Observable { - public: - explicit OperatorBase(const OperatorDef& operator_def, Workspace* ws); - - /* - * Notes: All outputs ivalues must be tensors. Input ivalue list must start - * with all tensors ("inputs" in caffe2 terminology), - * followed by non-tensors ("arguments" in caffe2 terminology). - * Alternatively, inputs can be one tensor list ivalue followed by non-tensors - * to represent operators with a variable number of inputs. - */ -#if defined(EXPOSE_C2_OPS) || \ - !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) - explicit OperatorBase( - const c10::FunctionSchema& schema, - std::vector inputs, - std::vector outputs); -#endif - - virtual ~OperatorBase() noexcept; - - /** @brief Return true if the operator was instantiated with OperatorDef - * New operators should be instantiated with FunctionSchema - */ - bool isLegacyOperator() const { -#if defined(EXPOSE_C2_OPS) || \ - !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) - return !fn_schema_; -#else - return true; -#endif - } - - const c10::FunctionSchema& getFunctionSchema() const { - CAFFE_ENFORCE(!isLegacyOperator()); -#if defined(EXPOSE_C2_OPS) || \ - !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) - return *fn_schema_.get(); -#else - CAFFE_THROW("Non-legacy operators are not legal in xplat/caffe2"); -#endif - } - - /** @brief Checks if the operator has an argument of the given name. - */ - inline bool HasArgument(c10::string_view name) const { - if (isLegacyOperator()) { - CAFFE_ENFORCE(operator_def_, "operator_def was null!"); - return ArgumentHelper::HasArgument(*operator_def_, name); - } - return argumentIndexWithName(name).has_value(); - } - - // Functions that deal with arguments. Basically, this allows us to map an - // argument name to a specific type of argument that we are trying to access. - template - inline T GetSingleArgument(c10::string_view name, const T& default_value) const { - if (isLegacyOperator()) { - CAFFE_ENFORCE(operator_def_, "operator_def was null!"); - return ArgumentHelper::GetSingleArgument( - *operator_def_, name, default_value); - } -#if defined(EXPOSE_C2_OPS) || \ - !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) - auto index = argumentIndexWithName(name); - CAFFE_ENFORCE(index.has_value(), "Couldn't get index for argument!", name); - const auto& value = newstyle_inputs_[index.value()]; - return value.template to(); -#else - CAFFE_THROW("Non-legacy operators are not legal in xplat/caffe2"); -#endif - } - - template - inline bool HasSingleArgumentOfType(c10::string_view name) const { - CAFFE_ENFORCE(operator_def_, "operator_def was null!"); - return ArgumentHelper::HasSingleArgumentOfType( - *operator_def_, name); - } -#if defined(EXPOSE_C2_OPS) || \ - !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) - template - inline vector GetVectorFromIValueList(const c10::IValue& value) const { - return value.template to>().vec(); - } -#endif - - template - inline vector GetRepeatedArgument( - c10::string_view name, - const vector& default_value = {}) const; - - // Get the inputs and outputs as specific types. - template - inline const T& Input(int idx) { - static_assert( - !std::is_same::value, - "You should use Input(int, DeviceType) for " - "Tensor."); - TORCH_DCHECK_LT((size_t)idx, inputs_.size()); - try { - return inputs_.at(idx)->template Get(); - } catch (::caffe2::EnforceNotMet& enf) { - if (has_debug_def()) { - TORCH_RETHROW(enf, "Offending Blob name: ", debug_def().input(idx), "."); - } - throw enf; - } - } - - // TODO(jerryzh): Remove template - // and the type argument? - // This is to keep the API changes minimal and make refactoring - // a bit easier - template - inline const T& Input(int idx, DeviceType type) { - if (isLegacyOperator()) { - static_assert( - std::is_same::value, - "Input(int, DeviceType) is only available for Tensor"); - TORCH_DCHECK_LT((size_t)idx, inputs_.size()); - try { - // TODO(jerryzh): We'll need to check device type in Get() later - // Get() -> Get(type) - const auto& tensor = inputs_.at(idx)->template Get(); - return tensor; - } catch (::caffe2::EnforceNotMet& enf) { - if (has_debug_def()) { - TORCH_RETHROW(enf, "Offending Blob name: ", debug_def().input(idx), "."); - } - throw enf; - } - } -#if defined(EXPOSE_C2_OPS) || \ - !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) - TORCH_DCHECK_LT(0U, newstyle_inputs_.size()); - IValue ival; - if (newstyle_inputs_[0].isTensorList()) { - // if the first input is a tensor list, we get input tensors by indexing - // into that list. currently, this means that only tensors from that list - // are accessible as inputs. any hypothetical input tensors that come - // after the list are not accessible. - auto tensorList = newstyle_inputs_[0].toTensorVector(); - TORCH_DCHECK_LT((size_t)idx, tensorList.size()); - ival = tensorList[idx]; - } else { - // if the first input is not a tensor list, we get input tensors by - // indexing into the inputs. - TORCH_DCHECK_LT((size_t)idx, newstyle_inputs_.size()); - ival = newstyle_inputs_[idx]; - } - CAFFE_ENFORCE( - ival.isTensor(), - "Input(int, DeviceType) is only available for IValues that store Tensors"); - auto t = ival.toTensor(); - if (!t.is_contiguous()) { - t = t.contiguous(); - } - Tensor tensor = caffe2::Tensor(std::move(t)); - CAFFE_ENFORCE_EQ(tensor.GetDeviceType(), type); - input_tensors_[idx] = std::move(tensor); - return input_tensors_[idx]; -#else - CAFFE_THROW("Non-legacy operators are not legal in xplat/caffe2"); -#endif - } - - template - inline T* Output(int idx) { - CAFFE_ENFORCE( - isLegacyOperator(), - "Output(idx) not supported for operators exported to c10. Please use XOutput instead."); - - static_assert( - !std::is_same::value, - "You should use Output(int, DeviceType) for " - "Tensor."); - return outputs_.at(idx)->template GetMutable(); - } - - // TODO(jerryzh): Remove this template - template - inline T* Output(int idx, DeviceType type) { - if (isLegacyOperator()) { - static_assert( - std::is_same::value, - "Output(int, DeviceType) is only available for Tensor"); - // When you get a Tensor here it is not fully initialized - return BlobGetMutableTensor(outputs_.at(idx), type); - } -#if defined(EXPOSE_C2_OPS) || \ - !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) - auto &output = output_tensors_[idx]; - if (!output.defined() || output.GetDeviceType() != type) { - // Fix tensor type - output = Tensor(type); - } - return &output; -#else - CAFFE_THROW("Non-legacy operators are not legal in xplat/caffe2"); -#endif - } - - inline Tensor - XOutputTensor(int idx, at::IntArrayRef dims, at::TensorOptions options) { - CAFFE_ENFORCE_WITH_CALLER( - options.device_opt() != c10::nullopt, - "device must be provided in option."); - if (isLegacyOperator()) { - return XBlobGetMutableTensor(outputs_.at(idx), dims, options); - } - - return OutputTensor(idx, dims, options)->UnsafeSharedInstance(); - } - - void SetOutputTensor(int idx, Tensor tensor) { - if (!isLegacyOperator()) { -#if defined(EXPOSE_C2_OPS) || \ - !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) - output_tensors_[idx] = std::move(tensor); -#else - CAFFE_THROW("Non-legacy operators are not legal in xplat/caffe2"); -#endif - } else { - // update the tensor in the workspace - BlobSetTensor(outputs_.at(idx), std::move(tensor)); - } - } - - Tensor OutputTensorOrUndefined(int idx) { - if (isLegacyOperator()) { - return BlobGetTensorOrUndefined(*outputs_.at(idx)); - } - return output_tensors_[idx].UnsafeSharedInstance(); - } - - inline Tensor* - OutputTensor(int idx, at::IntArrayRef dims, at::TensorOptions options) { - if (isLegacyOperator()) { - CAFFE_ENFORCE_WITH_CALLER( - options.device_opt() != c10::nullopt, - "device must be provided in options."); - return BlobGetMutableTensor(outputs_.at(idx), dims, options); - } -#if defined(EXPOSE_C2_OPS) || \ - !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) - auto &output = output_tensors_[idx]; - output = output.defined() - ? GetSizedTensorWithOptions(std::move(output), dims, options) - : caffe2::empty(dims, options); - - return &output; -#else - CAFFE_THROW("Non-legacy operators are not legal in xplat/caffe2"); -#endif - } - - // Get output Tensor of the operator and CopyFrom the given Tensor - Tensor* OutputTensorCopyFrom( - int idx, - at::TensorOptions options, - const Tensor& src, - bool async = false) { - CAFFE_ENFORCE_WITH_CALLER( - options.device_opt() != c10::nullopt, - "device must be provided in options."); - // Ouptut Tensor will always have the same data type as `src` - if (!options.has_dtype()) { - options = options.dtype(src.dtype()); - } - CAFFE_ENFORCE_WITH_CALLER( - options.dtype() == src.dtype(), - "We don't allow change of src data type in OutputTensorCopyFrom"); - Tensor* t = OutputTensor(idx, src.sizes(), options); - t->CopyFrom(src, async); - return t; - } - - Tensor* OutputTensorAlias(int idx, const Tensor& src) { - CAFFE_ENFORCE( - isLegacyOperator(), - "OutputTensorAlias(idx, src) not (yet) supported for operators exported to c10."); - return BlobSetTensor(OutputBlob(idx), src.Alias()); - } - - template - inline T* Output(int idx, T* allocated) { - CAFFE_ENFORCE( - isLegacyOperator(), - "Output(idx, allocated) not supported for operators exported to c10. Please use XOutput."); - outputs_.at(idx)->Reset(allocated); - return allocated; - } - - inline const Blob& InputBlob(int idx) { - CAFFE_ENFORCE( - isLegacyOperator(), - "InputBlob(idx) not (yet) supported for operators exported to c10."); - return *inputs_.at(idx); - } - - inline Blob* OutputBlob(int idx) { - CAFFE_ENFORCE( - isLegacyOperator(), - "OutputBlob(idx) not (yet) supported for operators exported to c10."); - return outputs_.at(idx); - } - - // Check whether output j is an alias of input i by comparing Blob pointers, - // note this does not check if the two Blobs points to the same Tensor, or if - // the Tensor pointers point to the same TensorImpl, or if the Storages alias - inline bool IsInputOutputAlias(int i, int j) { - CAFFE_ENFORCE( - isLegacyOperator(), - "IsInputOutputAlias(i, j) not (yet) supported for operators exported to c10."); - return inputs_.at(i) == outputs_.at(j); - } - - template - inline bool InputIsType(int idx) { - CAFFE_ENFORCE( - isLegacyOperator(), - "InputIsType(idx) not (yet) supported for operators exported to c10."); - static_assert( - !std::is_same::value, - "You should use InputIsTensorType(int, DeviceType) for " - "Tensor."); - return inputs_.at(idx)->template IsType(); - } - - inline bool InputIsTensorType(int idx, DeviceType device_type) { - CAFFE_ENFORCE( - isLegacyOperator(), - "InputIsTensorType(idx, device_type) not (yet) supported for operators exported to c10."); - return BlobIsTensorType(*inputs_.at(idx), device_type); - } - - template - inline bool OutputIsType(int idx) { - CAFFE_ENFORCE( - isLegacyOperator(), - "OutputIsType(idx) not (yet) supported for operators exported to c10."); - static_assert( - !std::is_same::value, - "You should use OutputIsTensorType(int, DeviceType) for " - "Tensor."); - return outputs_.at(idx)->template IsType(); - } - - inline bool OutputIsTensorType(int idx, DeviceType type) { - CAFFE_ENFORCE( - isLegacyOperator(), - "OutputIsTensorType(idx, type) not (yet) supported for operators exported to c10."); - return BlobIsTensorType(*outputs_.at(idx), type); - } - - inline int InputSize() const { - return input_size_; - } - - inline int OutputSize() const { - if (isLegacyOperator()) { - return outputs_.size(); - } -#if defined(EXPOSE_C2_OPS) || \ - !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) - return output_tensors_.size(); -#else - CAFFE_THROW("Non-legacy operators are not legal in xplat/caffe2"); -#endif - } - inline const vector& Inputs() const { - CAFFE_ENFORCE( - isLegacyOperator(), - "Inputs() not supported for operators exported to c10."); - return inputs_; - } - inline const vector& Outputs() { - CAFFE_ENFORCE( - isLegacyOperator(), - "Outputs() not supported for operators exported to c10."); - return outputs_; - } - vector InputTensorShapes() const; - - virtual void WaitEvent(const Event& ev, int /*stream_id */ = -1) { - ev.Finish(); - } - - inline void Wait(const OperatorBase& other, int stream_id = -1) { - if (!other.IsEventDisabled()) { - WaitEvent(other.event(), stream_id); - } - } - - virtual void WaitEvents( - const std::vector& events, - int /*stream_id*/ = -1) { - for (const auto& ev : events) { - ev->Finish(); - } - } - - virtual void Finish() { - if (event_) { - event_->Finish(); - } - } - - virtual bool Run(int /* unused */ /*stream_id*/ = 0) { - CAFFE_NOT_IMPLEMENTED; - } - - virtual bool HasAsyncPart() const { - return false; - } - - virtual bool SupportsAsyncScheduling() const { - return false; - } - - virtual void CancelAsyncCallback() {} - - virtual void Cancel() {} - - // RunAsync, if implemented by the specific operators, will schedule the - // computation on the corresponding context and record the event in its - // event_ member object. If the specific operator does not support RunAsync, - // it will simply be synchronous as a fallback. - virtual bool RunAsync(int stream_id = 0); - - virtual void AddRelatedBlobInfo(EnforceNotMet* err); - - virtual std::string debug_info_string() const { - return ""; - } - - inline const OperatorDef& debug_def() const { - CAFFE_ENFORCE(has_debug_def(), "operator_def was null!"); - return *operator_def_; - } - - inline void set_debug_def( - const std::shared_ptr& operator_def) { - operator_def_ = operator_def; - } - - inline bool has_debug_def() const { - return operator_def_ != nullptr; - } - - public: - void RecordLastFailedOpNetPosition() { - if (net_position_ != kNoNetPositionSet) { - VLOG(1) << "Operator with id " << net_position_ << " failed"; - operator_ws_->last_failed_op_net_position = net_position_; - } else { - VLOG(1) << "Failed operator doesn't have id set"; - } - } - - int net_position() const { - return net_position_; - } - - void set_net_position(int idx) { - net_position_ = idx; - } - - const DeviceOption& device_option() const { - return device_option_; - } - - const Event& event() const { - CAFFE_ENFORCE(event_, "Event is disabled"); - return *event_; - } - - Event& event() { - CAFFE_ENFORCE(event_, "Event is disabled"); - return *event_; - } - - void ResetEvent() { - if (event_) { - event_->Reset(); - } - } - - void DisableEvent() { - event_ = nullptr; - } - - bool IsEventDisabled() const { - return !event_; - } - - // Internal API invoked by observers. Normal callers shouldn't invoke it. - virtual void SyncDeviceBarrierForObservers() { - CAFFE_NOT_IMPLEMENTED; - } - - // Checks whether stream is ready to execute new computation, - // used in stream allocation optimization to skip stream that is currently - // busy. Depends on context and operator's device, returns true by default - virtual bool IsStreamFree(int /* unused */) const { - return true; - } - - const std::string& type() const { - return type_; - } - - void annotate_engine(const std::string& engine) { - engine_ = engine; - } - - const std::string& engine() const { - return engine_; - } - - void SetExecutorHelper(ExecutorHelper* helper) { - helper_ = helper; - } - - ExecutorHelper* GetExecutorHelper() const { - return helper_; - } - -#if defined(EXPOSE_C2_OPS) || \ - !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) - std::vector move_output_tensors() && { - return std::move(output_tensors_); - } -#endif - - public: - static const int kNoNetPositionSet = -1; - - private: - Workspace* operator_ws_; - std::shared_ptr operator_def_; - DeviceOption device_option_; - std::string engine_; - std::string type_; - vector inputs_; - vector outputs_; - // Preferably use std::optional, but nvcc doesn't work -#if defined(EXPOSE_C2_OPS) || \ - !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) - std::unique_ptr fn_schema_; - vector newstyle_inputs_; -#endif - // HACK - // We preserve the fact that Output() returns Tensor* - // by storing Tensor in a vector owned by the - // operator. - vector input_tensors_; - vector output_tensors_; - - int input_size_; - - int net_position_{kNoNetPositionSet}; - - ExecutorHelper* helper_ = nullptr; - - protected: - virtual void RecordEvent(const char* /*err_msg*/ = nullptr) { - CAFFE_NOT_IMPLEMENTED; - } - - void SetEventFinished(const char* err_msg = nullptr) { - if (event_) { - event_->SetFinished(err_msg); - } - } - - void SetEventFinishedWithException(const char* err_msg = nullptr) { - if (event_) { - event_->SetFinishedWithException(err_msg); - } - } - - std::string getErrorMsg() { - if (has_debug_def()) { - return "Error from operator: " + ProtoDebugString(debug_def()); - } else { - return "Error from operator: no op def"; - } - } - - std::optional argumentIndexWithName(c10::string_view name) const; - - // An event used by asynchronous execution. - std::unique_ptr event_; - - C10_DISABLE_COPY_AND_ASSIGN(OperatorBase); -}; - -template <> -inline NetDef OperatorBase::GetSingleArgument( - c10::string_view name, - const NetDef& default_value) const { - if (isLegacyOperator()) { - CAFFE_ENFORCE(operator_def_, "operator_def was null!"); - return ArgumentHelper::GetSingleArgument( - *operator_def_, name, default_value); - } - CAFFE_THROW("Cannot get NetDefs from IValue"); - return NetDef(); -} - -#if defined(EXPOSE_C2_OPS) || \ - !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) -template <> -inline vector OperatorBase::GetVectorFromIValueList( - const c10::IValue& value) const { - auto vs = value.toIntVector(); - vector out; - out.reserve(vs.size()); - for (int64_t v : vs) { - out.emplace_back(v); - } - return out; -} - -template <> -inline vector OperatorBase::GetVectorFromIValueList( - const c10::IValue& value) const { - const auto& vs = value.toDoubleVector(); - vector out; - out.reserve(vs.size()); - for (double v : vs) { - out.emplace_back(v); - } - return out; -} - -template <> -inline vector OperatorBase::GetVectorFromIValueList( - const c10::IValue& value) const { - auto vs = value.template to>(); - vector out; - out.reserve(vs.size()); - for (string v : vs) { - out.emplace_back(v); - } - return out; -} - -// We need this specialisation because IValue based lists don't support -// int16_t. We need to load it as List and transform to int16_t. -template <> -inline vector OperatorBase::GetVectorFromIValueList( - const c10::IValue& value) const { - auto list = value.template to>(); - std::vector result; - result.reserve(list.size()); - for (int64_t elem : list) { - result.push_back(static_cast(elem)); - } - return result; -} -#endif - -// OP_SINGLE_ARG provides a shorter initialization choice for initialization of -// member variables for the class constructors. -#define OP_SINGLE_ARG(type, name, variable, default) \ - variable(OperatorBase::GetSingleArgument(name, (default))) - -// INPUT_TAGS and OUTPUT_TAGS are optional features to name the indices of the -// operator's inputs and outputs, in order to avoid confusion. For example, for -// a fully convolution layer that has input, weight and bias, you can define its -// input tags as: -// INPUT_TAGS(INPUT, WEIGHT, BIAS); -// And in the code, instead of doing -// auto& weight = Input(1); -// you can now do -// auto& weight = Input(WEIGHT); -// to make it more clear. -#define INPUT_TAGS(first_input, ...) \ - enum _InputTags { first_input = 0, __VA_ARGS__ } -#define OUTPUT_TAGS(first_input, ...) \ - enum _OutputTags { first_input = 0, __VA_ARGS__ } - -template -inline vector OperatorBase::GetRepeatedArgument( - c10::string_view name, - const vector& default_value) const { - if (isLegacyOperator()) { - CAFFE_ENFORCE(operator_def_, "operator_def was null!"); - return ArgumentHelper::GetRepeatedArgument( - *operator_def_, name, default_value); - } -#if defined(EXPOSE_C2_OPS) || \ - !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) - auto index = argumentIndexWithName(name); - CAFFE_ENFORCE(index.has_value(), "Couldn't get index for argument!", name); - const auto& value = newstyle_inputs_[index.value()]; - return GetVectorFromIValueList(value); -#else - CAFFE_THROW("Non-legacy operators are not legal in xplat/caffe2"); -#endif -} - -// We need this specialisation because IValue based lists don't support -// int16_t. We need to load it as List and transform to int16_t. -template <> -inline vector OperatorBase::GetRepeatedArgument( - c10::string_view name, - const vector& default_value) const { - if (isLegacyOperator()) { - CAFFE_ENFORCE(operator_def_, "operator_def was null!"); - return ArgumentHelper::GetRepeatedArgument( - *operator_def_, name, default_value); - } -#if defined(EXPOSE_C2_OPS) || \ - !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) - auto index = argumentIndexWithName(name); - CAFFE_ENFORCE(index.has_value(), "Couldn't get index for argument!", name); - const auto& value = newstyle_inputs_[index.value()]; - auto vec = GetVectorFromIValueList(value); - std::vector result; - result.reserve(vec.size()); - for (int64_t elem : vec) { - result.push_back(static_cast(elem)); - } - return result; -#else - CAFFE_THROW("Non-legacy operators are not legal in xplat/caffe2"); -#endif -} - -// Operator is the class that you usually want to derive, if your operator will -// run on different devices. You should then implement the RunOnDevice() -// function. -template -class Operator : public OperatorBase { - public: - explicit Operator(const OperatorDef& operator_def, Workspace* ws, StreamId stream = 0) - : OperatorBase(operator_def, ws), context_(operator_def.device_option()) { - // In the constructor, we switch to the device so that the child class - // constructors will run on that device. - context_.SwitchToDevice(stream); - } -#if defined(EXPOSE_C2_OPS) || \ - !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) - explicit Operator( - const c10::FunctionSchema& fn_schema, - std::vector inputs, - std::vector outputs, - StreamId stream = 0) - : OperatorBase(fn_schema, std::move(inputs), std::move(outputs)) { - // In the constructor, we switch to the device so that the child class - // constructors will run on that device. - context_.SwitchToDevice(stream); - } -#endif - ~Operator() noexcept override {} - - /// Retrieve a non-owning reference to the input at position 'idx' for this - /// operator. The returned reference is valid for the duration of the - /// RunOnDevice call. The optional 'type' parameter can be used to assert a - /// required device type for the input (by default, we assert that the tensor - /// is consistent with the device type implied by the Context parameter of an - /// Operator.) - inline const Tensor& Input( - int idx, - DeviceType type = Context::GetDeviceType()) { - return OperatorBase::template Input(idx, type); - } - - /// XOutput is a modernized version of Output which returns a Tensor - /// rather than a Tensor* (the raw pointer in the latter case is - /// useless, as Tensor is a pointer type.) - Tensor XOutput(int idx, at::IntArrayRef dims, at::TensorOptions options) { - // We'll default device to the device of the current Operator Context - if (options.device_opt() == c10::nullopt) { - return OperatorBase::XOutputTensor( - idx, dims, options.device(context_.device())); - } - return OperatorBase::XOutputTensor(idx, dims, options); - } - - /// Retrieve a non-owning pointer to the output at position 'idx', - /// initializing it to have size 'dims' and properties 'options' if - /// there is no pre-existing output or the pre-existing output does - /// not have the correct options. The returned pointer is valid for - /// the duration of the RunOnDevice call. If device is not explicitly - /// specified in options, we default to allocating output on the - /// current device of the device type implied by the Context parameter - /// of this Operator. - /// - /// Note [Operator::Output what?] - /// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - /// The contract of Operator::Output is somewhat complex; it is perhaps better - /// understood in terms of what was historically an idiomatic Caffe2 operator - /// implementation: - /// - /// void RunOnDevice() override { - /// auto* output = Output(0, output_size, dtype()); - /// float* output_ptr = output->data(); - /// // write into output_ptr - /// } - /// - /// In the simple case, this code does the following things: - /// - /// 1. Allocates a new tensor with size 'output_size' and dtype 'float' - /// (and device type whatever the Operator's device type is) - /// 2. "Registers" this tensor as the 0th output tensor of this operator - /// (Caffe2 operators don't "return" outputs; instead, outputs - /// are shoved into an output vector which the executor reads out.) - /// 3. Returns the tensor, so the operator implementation can write - /// the actual output data into the tensor. - /// - /// So what's this business with "pre-existing" outputs? Caffe2 - /// commonly applies an optimization whereby it reuses tensors on - /// subsequent runs of operators in a graph. It doesn't know ahead - /// of time what intermediate tensors it will need, so the first - /// time it runs a graph it has all of the operators create the outputs - /// necessary (as described above). However, the second time around, - /// it will reuse all of the tensors created from the first time. - /// If they are lucky, this time the Output() call is a no-op and - /// just returns the old tensor. - /// - /// However, we cannot /guarantee/ that the output size will be the - /// same the next time the Operator is called; for example, output - /// size may be data dependent and vary between runs. In this case, - /// we have to resize it to the correct size. Resizing is still - /// helpful, as we may be able to fit the output in the same - /// space that was previously used. - /// - Tensor* Output(int idx, at::IntArrayRef dims, at::TensorOptions options) { - // We'll default device to the device of the current Operator Context - if (options.device_opt() == c10::nullopt) { - return OperatorBase::OutputTensor( - idx, dims, options.device(context_.device())); - } - return OperatorBase::OutputTensor(idx, dims, options); - } - - /// Legacy: please consider using the version of Output() which also takes - /// dtype and size as arguments. - inline Tensor* Output(int idx, DeviceType type = Context::GetDeviceType()) { - return OperatorBase::template Output(idx, type); - } - - /// Get the output Tensor of an operator (allocating it if it is not - /// already initialized), and copy the contents of src into it. - /// You probably don't actually want to use this function (the fact - /// that you have a Tensor to copy from is probably a mistake: - /// you should have written the output into the output tensor, - /// from Output, directly in the first place), but this method - /// is situationally useful. - Tensor* OutputTensorCopyFrom( - int idx, - at::TensorOptions options, - const Tensor& src, - bool async = false) { - if (options.device_opt() == c10::nullopt) { - return OperatorBase::OutputTensorCopyFrom( - idx, options.device(context_.device()), src, async); - } - return OperatorBase::OutputTensorCopyFrom(idx, options, src, async); - } - - void WaitEvent(const Event& ev, int stream_id = -1) final { - if (stream_id >= 0) { - context_.SwitchToDevice(stream_id); - } - context_.WaitEvent(ev); - } - - void WaitEvents(const std::vector& events, int stream_id = -1) - final { - if (stream_id >= 0) { - context_.SwitchToDevice(stream_id); - } - for (const auto& ev : events) { - context_.WaitEvent(*ev); - } - } - - // The run function of Operator switches to the device, and then carries out - // the actual computation with RunOnDevice(). You should implement RunOnDevice - // instead of Run(). - // Note: Run does not update operator's event and can be used only with - // non-async executors that do not rely on events - bool Run(int stream_id = 0) final { - try { - StartAllObservers(); - - context_.SwitchToDevice(stream_id); - - // Clear floating point exception flags before RunOnDevice. We will test - // exception flags afterwards, and raise an error if an exception has - // happened. - if (FLAGS_caffe2_operator_throw_if_fp_exceptions || - FLAGS_caffe2_operator_throw_if_fp_overflow_exceptions) { - std::feclearexcept(FE_ALL_EXCEPT); - } - -#ifdef __GNU_LIBRARY__ - // If glibc is available, use feenableexcept that will raise exception - // right away. - int old_enabled_exceptions = 0; - if (FLAGS_caffe2_operator_throw_on_first_occurrence_if_fp_exceptions) { - if (FLAGS_caffe2_operator_throw_if_fp_exceptions || - FLAGS_caffe2_operator_throw_if_fp_overflow_exceptions) { - int flag = 0; - if (FLAGS_caffe2_operator_throw_if_fp_exceptions) { - flag |= FE_DIVBYZERO | FE_INVALID; - } - if (FLAGS_caffe2_operator_throw_if_fp_overflow_exceptions) { - flag |= FE_OVERFLOW; - } - old_enabled_exceptions = feenableexcept(flag); - } - } -#endif - bool result = RunOnDevice(); -#ifdef __GNU_LIBRARY__ - if (FLAGS_caffe2_operator_throw_on_first_occurrence_if_fp_exceptions) { - if (FLAGS_caffe2_operator_throw_if_fp_exceptions || - FLAGS_caffe2_operator_throw_if_fp_overflow_exceptions) { - fedisableexcept(FE_DIVBYZERO | FE_INVALID | FE_OVERFLOW); - std::feclearexcept(FE_ALL_EXCEPT); - feenableexcept(old_enabled_exceptions); - } - } -#endif - if (FLAGS_caffe2_operator_throw_if_fp_exceptions) { - CAFFE_ENFORCE( - !std::fetestexcept(FE_DIVBYZERO), - "Division by zero floating point exception (FE_DIVBYZERO) reported."); - CAFFE_ENFORCE( - !std::fetestexcept(FE_INVALID), - "Invalid floating point exception (FE_INVALID) reported."); - } - if (FLAGS_caffe2_operator_throw_if_fp_overflow_exceptions) { - CAFFE_ENFORCE( - !std::fetestexcept(FE_OVERFLOW), - "Overflow floating point exception (FE_OVERFLOW) reported."); - } - if (!result) { - this->RecordLastFailedOpNetPosition(); - } - context_.FinishDeviceComputation(); // throws on error - - StopAllObservers(); - - return result; - } catch (EnforceNotMet& err) { - if (has_debug_def()) { - err.add_context( - "Error from operator: \n" + ProtoDebugString(debug_def())); - AddRelatedBlobInfo(&err); - } - this->RecordLastFailedOpNetPosition(); - StopAllObservers(); - throw; - } catch (...) { - this->RecordLastFailedOpNetPosition(); - StopAllObservers(); - throw; - } - } - - bool RunAsync(int stream_id = 0) final { - try { - StartAllObservers(); - - context_.SwitchToDevice(stream_id); - auto result = RunOnDevice(); - if (result) { - if (HasAsyncPart()) { - RecordEvent(); - } else { - // Manually set CPU operator's event status to finished, - // unless this is an async CPU operator - SetEventFinished(); - } - } else { - SetEventFinished(getErrorMsg().c_str()); - this->RecordLastFailedOpNetPosition(); - } - - StopAllObservers(); - - return result; - } catch (EnforceNotMet& err) { - if (has_debug_def()) { - err.add_context( - "Error from operator: \n" + ProtoDebugString(debug_def())); - AddRelatedBlobInfo(&err); - } - SetEventFinishedWithException(err.what()); - this->RecordLastFailedOpNetPosition(); - StopAllObservers(); - throw; - } catch (const std::exception& err) { - SetEventFinishedWithException(err.what()); - this->RecordLastFailedOpNetPosition(); - StopAllObservers(); - throw; - } catch (...) { - SetEventFinishedWithException(getErrorMsg().c_str()); - this->RecordLastFailedOpNetPosition(); - StopAllObservers(); - throw; - } - } - - bool IsStreamFree(int stream_id) const override { - return context_.IsStreamFree(device_option(), stream_id); - } - - virtual bool RunOnDevice() = 0; - - // Returns whether operator has async on device part. - // CUDA operators by default have async parts, CPU operators by default - // don't have async parts and are finished after RunOnDevice call. - // Events of operators that don't have async parts are automatically set - // to finished state by RunAsync. - // Defaulting to the value from context (true for CUDA, false for CPU). - // Override in case of async CPU operators - // Async CPU operators are expected to catch all exceptions in async parts - // and set Event to finished/failed state with Event::SetFinished or - // SetFinishedWithException call. - bool HasAsyncPart() const override { - return context_.HasAsyncPartDefault(); - } - - // Returns whether operator's RunOnDevice schedules async on device part and - // can be run without waiting for parent operator's async part to be finished - // on the same device. - // Note: when true, RunOnDevice must not access the content of the input blobs - // as they might not be computed yet - // Note: when true, operator's device needs to support async scheduling: - // - supports concept of streams: async ops scheduled on the same stream are - // guaranteed to be executed in the same order they were scheduled - // - provides non-blocking cross device/cross stream synchronization - // primitives - // - // By default, assuming an op with an async part can be scheduled - // asynchronously if device supports async scheduling - bool SupportsAsyncScheduling() const override { - return HasAsyncPart() && context_.SupportsAsyncScheduling(); - } - - void SyncDeviceBarrierForObservers() override { - context_.FinishDeviceComputation(); - } - - const Context* getContext() const { - return &context_; - } - Context* getContext() { - return &context_; - } - - protected: - void RecordEvent(const char* err_msg = nullptr) final { - if (event_) { - context_.Record(event_.get(), err_msg); - } - } - - Context context_; -}; - -#define USE_OPERATOR_BASE_FUNCTIONS \ - /* using override */ using OperatorBase::HasArgument; \ - /* using override */ using OperatorBase::GetSingleArgument; \ - /* using override */ using OperatorBase::HasSingleArgumentOfType; \ - /* using override */ using OperatorBase::GetRepeatedArgument; \ - /* using override */ using OperatorBase::InputIsType; \ - /* using override */ using OperatorBase::InputSize; \ - /* using override */ using OperatorBase::Output; \ - /* using override */ using OperatorBase::Input; \ - /* using override */ using OperatorBase::OutputSize; \ - /* using override */ using OperatorBase::IsInputOutputAlias; \ - /* using override */ using OperatorBase::OutputTensorAlias - -#define USE_OPERATOR_FUNCTIONS(context) \ - USE_OPERATOR_BASE_FUNCTIONS; \ - /* using override */ using Operator::context_; \ - /* using override */ using Operator::Input; \ - /* using override */ using Operator::InputBlob; \ - /* using override */ using Operator::Output; \ - /* using override */ using Operator::OutputBlob; \ - /* using override */ using Operator::OutputTensorCopyFrom - -#define USE_OPERATOR_CONTEXT_FUNCTIONS USE_OPERATOR_FUNCTIONS(Context) - -#define USE_SIMPLE_CTOR_DTOR(name) \ - template \ - explicit name(Args&&... args) \ - : Operator(std::forward(args)...) {} \ - virtual ~name() noexcept override {} - -// Helpers to implement runtime op polymorphism. Often it's convenient to make -// an op work on different input types (e.g. i32 vs i64 indices) or special-case -// it for particular input size (e.g. ScatterWeightedSum for block size of 1 -// doesn't need to call Eigen). -// -// DispatchHelper provides compile-time generation of nested "if" statements, -// e.g. `DispatchHelper>::call(this, block_size);` -// unrolls into: -// if (block_size == 1) { -// return DoRunWithValue<1>(); -// } else if (block_size = 4) { -// return DoRunWithValue<4>(); -// } else { -// return DoRunWithValue<-1>(); -// }` -// -// DoRunWithValue implementation can use template arguments to do "if" -// statements -// or proxy to functions in math.h which often provide fixed size -// implementation. -// -// Similarly `TensorTypes(this, Input(0))` provides branching -// based on type of the first input and calls DoRunWithType. -// -// Note, that the same instance of Op class is used as the method, not class is -// templated. We might consider adding static class-level polymorphism later. -// -// Convenient macro USE_DISPATCH_HELPER is provided for declaring friendship in -// case DoRunWithValue or DoRunWithType are declared non-public. - -#define USE_DISPATCH_HELPER \ - template \ - friend struct DispatchHelper - -template -struct FixedValues {}; - -template -struct TensorTypes {}; - -// Special tag that can be listed in TensorTypes to denote that a special -// implementation in 'RunWithOtherType' needs to be called instead of failing -// Obviously this needs to be the last item in lists, e.g. -// TensorTypes -struct GenericTensorImplementation {}; - -// Same as TensorTypes but call DoRunWithType2 -template -struct TensorTypes2 {}; - -template -struct DispatchHelper; - -template -struct DispatchHelper, ExtraArgs...> { - template - static bool call(Op* op, int value) { - if (FirstVal == value) { - return op->template DoRunWithValue(); - } - return DispatchHelper, ExtraArgs...>::template call< - Op>(op, value); - } -}; - -template -struct DispatchHelper, ExtraArgs...> { - template - static bool call(Op* op, int64_t /*size*/) { - return op->template DoRunWithValue(); - } -}; - -#define C10_DEFINE_TENSOR_TYPES_DISPATCHER( \ - TensorTypes, DoRunWithType, DoRunWithOtherType) \ - template \ - struct DispatchHelper, ExtraArgs...> { \ - template \ - static bool call(Op* op, const TypeMeta meta) { \ - static_assert( \ - !std::is_same::value, \ - "GenericTensorImplementation must be the last in TensorTypes list"); \ - if (meta.Match()) { \ - return op->template DoRunWithType(); \ - } \ - return DispatchHelper, ExtraArgs...>:: \ - template call(op, meta); \ - } \ - template \ - static bool call(Op* op, const Tensor& tensor) { \ - return call(op, tensor.dtype()); \ - } \ - template \ - static bool call(Op* op, const Blob& blob) { \ - return call(op, blob.meta()); \ - } \ - }; \ - \ - template \ - struct DispatchHelper, ExtraArgs...> { \ - template \ - static bool call(Op* /* unused */, const TypeMeta meta) { \ - CAFFE_THROW("Unsupported type of tensor: ", meta.name()); \ - } \ - template \ - static bool call(Op* op, const Tensor& tensor) { \ - return call(op, tensor.dtype()); \ - } \ - template \ - static bool call(Op* op, const Blob& blob) { \ - return call(op, blob.meta()); \ - } \ - }; \ - \ - template \ - struct DispatchHelper< \ - TensorTypes, \ - ExtraArgs...> { \ - template \ - static bool call(Op* op, const TypeMeta) { \ - return op->template DoRunWithOtherType(); \ - } \ - template \ - static bool call(Op* op, const Tensor& tensor) { \ - return call(op, tensor.dtype()); \ - } \ - template \ - static bool call(Op* op, const Blob& blob) { \ - return call(op, blob.meta()); \ - } \ - }; -C10_DEFINE_TENSOR_TYPES_DISPATCHER( - TensorTypes, - DoRunWithType, - DoRunWithOtherType) -C10_DEFINE_TENSOR_TYPES_DISPATCHER( - TensorTypes2, - DoRunWithType2, - DoRunWithOtherType2) -#undef C10_DEFINE_TENSOR_TYPES_DISPATCHER - -// The device type registry. This works in two phases: -// (1) gDeviceTypeRegistry() maps the device types values to the actual operator -// registry function. -// (2) Then, one can call the operator registry function to further create the -// operators. -typedef c10::Registry< - std::string, - std::unique_ptr, - const OperatorDef&, - Workspace*> - OperatorRegistry; -typedef c10::Registry< - std::string, - std::unique_ptr, - const OperatorDef&, - Workspace*>* (*RegistryFunction)(); -TORCH_API std::map* gDeviceTypeRegistry(); - -struct TORCH_API DeviceTypeRegisterer { - explicit DeviceTypeRegisterer(DeviceType type, RegistryFunction func); -}; - -#if defined(_MSC_VER) -#define IMPORT_IF_NOT_MSVC -#else -#define IMPORT_IF_NOT_MSVC C10_IMPORT -#endif - -#define CAFFE_REGISTER_DEVICE_TYPE(type, registry_function) \ - namespace { \ - static DeviceTypeRegisterer C10_ANONYMOUS_VARIABLE( \ - DeviceType)(type, ®istry_function); \ - } - -// The operator registry. Since we are not expecting a great number of devices, -// we will simply have an if-then type command and allocate the actual -// generation to device-specific registerers. -// Note that although we have CUDA and CUDNN here, the registerers themselves do -// not depend on specific cuda or cudnn libraries. This means that we will be -// able to compile it even when there is no cuda available - we simply do not -// link any cuda or cudnn operators. -C10_DECLARE_REGISTRY( - CPUOperatorRegistry, - OperatorBase, - const OperatorDef&, - Workspace*); -#define REGISTER_CPU_OPERATOR_CREATOR(key, ...) \ - C10_REGISTER_CREATOR(CPUOperatorRegistry, key, __VA_ARGS__) -#define REGISTER_CPU_OPERATOR(name, ...) \ - IMPORT_IF_NOT_MSVC void CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(); \ - static void CAFFE2_UNUSED CAFFE_ANONYMOUS_VARIABLE_CPU##name() { \ - CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(); \ - } \ - C10_REGISTER_CLASS(CPUOperatorRegistry, name, __VA_ARGS__) -#define REGISTER_CPU_OPERATOR_STR(str_name, ...) \ - C10_REGISTER_TYPED_CLASS(CPUOperatorRegistry, str_name, __VA_ARGS__) - -#define REGISTER_CPU_OPERATOR_WITH_ENGINE(name, engine, ...) \ - C10_REGISTER_CLASS(CPUOperatorRegistry, name##_ENGINE_##engine, __VA_ARGS__) - -// Use these macros to register gradient operators. They can be automatically -// excluded from builds that don't need them (e.g., mobile). -#ifdef CAFFE2_NO_GRADIENT_OPS -#define REGISTER_CPU_GRADIENT_OPERATOR(...) /* No gradients. */ -#else -#define REGISTER_CPU_GRADIENT_OPERATOR(...) \ - C10_MACRO_EXPAND(REGISTER_CPU_OPERATOR(__VA_ARGS__)) -#endif - -#ifdef CAFFE2_NO_GRADIENT_OPS -#define REGISTER_CPU_GRADIENT_OPERATOR_WITH_ENGINE(...) /* No gradients. */ -#else -#define REGISTER_CPU_GRADIENT_OPERATOR_WITH_ENGINE(...) \ - C10_MACRO_EXPAND(REGISTER_CPU_OPERATOR_WITH_ENGINE(__VA_ARGS__)) -#endif - -C10_DECLARE_REGISTRY( - CUDAOperatorRegistry, - OperatorBase, - const OperatorDef&, - Workspace*); -#define REGISTER_CUDA_OPERATOR_CREATOR(key, ...) \ - C10_REGISTER_CREATOR(CUDAOperatorRegistry, key, __VA_ARGS__) -#define REGISTER_CUDA_OPERATOR(name, ...) \ - IMPORT_IF_NOT_MSVC void CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(); \ - static void CAFFE2_UNUSED CAFFE_ANONYMOUS_VARIABLE_CUDA##name() { \ - CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(); \ - } \ - C10_REGISTER_CLASS(CUDAOperatorRegistry, name, __VA_ARGS__) -#define REGISTER_CUDA_OPERATOR_STR(str_name, ...) \ - C10_REGISTER_TYPED_CLASS(CUDAOperatorRegistry, str_name, __VA_ARGS__) - -#define REGISTER_CUDA_OPERATOR_WITH_ENGINE(name, engine, ...) \ - C10_REGISTER_CLASS(CUDAOperatorRegistry, name##_ENGINE_##engine, __VA_ARGS__) - -// Macros for cudnn since we use it often -#define REGISTER_CUDNN_OPERATOR(name, ...) \ - REGISTER_CUDA_OPERATOR_WITH_ENGINE(name, CUDNN, __VA_ARGS__) - -// Macros for HIP operators -C10_DECLARE_REGISTRY( - HIPOperatorRegistry, - OperatorBase, - const OperatorDef&, - Workspace*); -#define REGISTER_HIP_OPERATOR_CREATOR(key, ...) \ - C10_REGISTER_CREATOR(HIPOperatorRegistry, key, __VA_ARGS__) -#define REGISTER_HIP_OPERATOR(name, ...) \ - IMPORT_IF_NOT_MSVC void CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(); \ - static void CAFFE2_UNUSED CAFFE_ANONYMOUS_VARIABLE_HIP##name() { \ - CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(); \ - } \ - C10_REGISTER_CLASS(HIPOperatorRegistry, name, __VA_ARGS__) -#define REGISTER_HIP_OPERATOR_STR(str_name, ...) \ - C10_REGISTER_TYPED_CLASS(HIPOperatorRegistry, str_name, __VA_ARGS__) - -#define REGISTER_HIP_OPERATOR_WITH_ENGINE(name, engine, ...) \ - C10_REGISTER_CLASS(HIPOperatorRegistry, name##_ENGINE_##engine, __VA_ARGS__) - -#define REGISTER_MIOPEN_OPERATOR(name, ...) \ - REGISTER_HIP_OPERATOR_WITH_ENGINE(name, MIOPEN, __VA_ARGS__) \ - REGISTER_HIP_OPERATOR_WITH_ENGINE( \ - name, CUDNN, __VA_ARGS__) // Make CUDNN an alias of MIOPEN for HIP ops - -// StaticLinkingProtector is a helper class that ensures that the Caffe2 -// library is linked correctly with whole archives (in the case of static -// linking). What happens is that when CreateOperator is called for the first -// time, it instantiates an OperatorLinkingProtector object to check if the -// operator registry is empty. If it is empty, this means that we are not -// properly linking the library. -// -// You should not need to use this class. -struct StaticLinkingProtector { - StaticLinkingProtector() { - const auto registered_ops = CPUOperatorRegistry()->Keys().size(); - // Note: this is a check failure instead of an exception, because if - // the linking is wrong, Caffe2 won't be able to run properly anyway, - // so it's better to fail loud. - // If Caffe2 is properly linked with whole archive, there should be more - // than zero registered ops. - if (registered_ops == 0) { - LOG(FATAL) - << "You might have made a build error: the Caffe2 library does not seem " - "to be linked with whole-static library option. To do so, use " - "-Wl,-force_load (clang) or -Wl,--whole-archive (gcc) to link the " - "Caffe2 library."; - } - } -}; - -// An exception that can be thrown by an operator constructor that notifies -// that it does not support the given setting. This can be usually used for -// specific engines that only implement a subset of the features required by -// the original operator schema. -// TODO(jiayq): make more feature-complete exception message. -class TORCH_API UnsupportedOperatorFeature : public std::exception { - public: - UnsupportedOperatorFeature(const string& msg) : msg_(msg) {} - const char* what() const noexcept override { - return msg_.c_str(); - } - - private: - string msg_; -}; - -// A helper macro that should ONLY be used in the operator constructor to check -// if needed features are met. If not, throws the UnsupportedOperatorFeature -// exception with the given message. -#define OPERATOR_NEEDS_FEATURE(condition, ...) \ - if (!(condition)) { \ - throw UnsupportedOperatorFeature(::c10::str(__VA_ARGS__)); \ - } - -// Creates an operator with the given operator definition. -// Throws on error and never returns nullptr -TORCH_API unique_ptr CreateOperator( - const OperatorDef& operator_def, - Workspace* ws, - int net_position = OperatorBase::kNoNetPositionSet); - -TORCH_API const std::string OpRegistryKey( - const std::string& op_type, - const std::string& engine = ""); - -// User can set the preferred engines as a list of engine names, in -// descending order of preference. -using EnginePrefType = std::vector; -// {device_type -> {operator_name -> EnginePrefType}} -using PerOpEnginePrefType = - CaffeMap>; -// {device_type -> EnginePrefType} -using GlobalEnginePrefType = CaffeMap; -TORCH_API void SetPerOpEnginePref( - const PerOpEnginePrefType& per_op_engine_pref); -TORCH_API void SetGlobalEnginePref( - const GlobalEnginePrefType& global_engine_pref); -TORCH_API void SetEnginePref( - const PerOpEnginePrefType& per_op_engine_pref, - const GlobalEnginePrefType& global_engine_pref); -TORCH_API void SetOpEnginePref( - const std::string& op_type, - const CaffeMap& op_pref); - -TORCH_API void LoadInt8TensorInfoOfBlob( - std::vector* scale, - std::vector* offset, - uint32_t* axis, - const Blob* b); - -TORCH_API TensorShape GetTensorShapeOfBlob(const Blob* b); - -TORCH_API TensorShapes InferBlobShapesAndTypes( - CaffeMap& blob_desc, - const vector& nets); - -TORCH_API TensorShapes InferBlobShapesAndTypesFromWorkspace( - Workspace* ws, - const vector& nets); - -TORCH_API TensorShapes InferBlobShapesAndTypesFromMap( - const CaffeMap>& blob_dimensions, - const vector& nets); - -TORCH_API TensorShapes InferBlobShapesAndTypesFromMap( - const CaffeMap>& blob_dimensions, - const CaffeMap& blob_types, - const vector& nets); - -TORCH_API std::map> -ValidateTensorDevices(OperatorBase& op, const OperatorDef& op_def); - -// Get a set of registered operator names -TORCH_API std::set GetRegisteredOperators(); - -// Operator logging capabilities -TORCH_API void SetOperatorLogger( - std::function tracer); -std::function GetOperatorLogger(); - -#ifndef C10_MOBILE -// This is for transferring tensor data between C2 and backends. -struct ExternalTensorDescriptor { - uint64_t dataType; - uint32_t dimensions; - const uint64_t* shape; - uint8_t isOffline = 0; - uint32_t quantizationAxis; - uint64_t quantizationParams; - const float* scales; - const int32_t* biases; - uint64_t buffer; -}; - -class ExternalTensorFunctionsBase { - public: - explicit ExternalTensorFunctionsBase() {} - virtual ~ExternalTensorFunctionsBase() {} - virtual bool isQuantized() const = 0; - virtual bool IsSameMetaType(TypeIdentifier id) = 0; - virtual void SetupExternalTensorDescriptor( - const Blob* blob, - std::vector>* shapes, - std::vector>* all_scales, - std::vector>* all_offsets, - ExternalTensorDescriptor* desc) = 0; - virtual void LoadInfoOfBlob( - const Blob* blob, - std::vector* scale, - std::vector* offset, - uint32_t* axis) = 0; - virtual TypeIdentifier GetTypeMetaId() = 0; - virtual TypeMeta GetExternalTensorType(const void* c) = 0; - virtual vector GetExternalTensorInfo( - const void* c, - size_t* capacity, - DeviceOption* device) = 0; -}; - -C10_DECLARE_TYPED_REGISTRY( - ExternalTensorFunctionsBaseRegistry, - TypeIdentifier, - ExternalTensorFunctionsBase, - std::unique_ptr); - -#define REGISTER_EXTERNAL_TENSOR_FUNCTIONS(id, ...) \ - C10_REGISTER_TYPED_CLASS(ExternalTensorFunctionsBaseRegistry, id, __VA_ARGS__) -inline unique_ptr CreateExternalTensorFunctions( - TypeIdentifier id) { - return ExternalTensorFunctionsBaseRegistry()->Create(id); -} -#endif // C10_MOBILE - -} // namespace caffe2 - -C10_CLANG_DIAGNOSTIC_POP() - -#endif // CAFFE2_CORE_OPERATOR_H_ diff --git a/caffe2/core/operator_gradient.h b/caffe2/core/operator_gradient.h deleted file mode 100644 index 5c8d97a38fd2..000000000000 --- a/caffe2/core/operator_gradient.h +++ /dev/null @@ -1,337 +0,0 @@ -#ifndef CAFFE2_CORE_OPERATOR_GRADIENT_H_ -#define CAFFE2_CORE_OPERATOR_GRADIENT_H_ - -#include "c10/util/Registry.h" -#include "caffe2/core/operator_schema.h" -#include "caffe2/proto/caffe2_pb.h" -#include "caffe2/utils/proto_utils.h" - -namespace caffe2 { - -/* @brief A struct that abstracts on top of dense and sparse blobs. - * - * For a dense blob, its gradient name should be written into dense_, and for - * a sparse blob, its gradient name should be written into indice_ for - * the sparse indices and value_ for the values. - */ -struct TORCH_API GradientWrapper { - string dense_; - string indices_; - string values_; - - inline bool IsDense() const { - return (dense_.size() != 0); - } - inline bool IsSparse() const { - return (indices_.size() != 0 || values_.size() != 0); - } - inline bool IsEmpty() const { - return (!IsDense() && !IsSparse()); - } -}; - -/** - * A struct that holds the gradient operators and related gradient maps. - */ -struct TORCH_API GradientOpsMeta { - vector ops_; - vector g_input_; - - GradientOpsMeta() {} - GradientOpsMeta( - const vector& ops, - const vector& v) - : ops_(ops), g_input_(v) {} -}; - -class TORCH_API GradientMakerBase { - public: - GradientMakerBase( - const OperatorDef& def, - const vector& g_output) - : def_(def), g_output_(g_output), g_input_(def.input_size()){}; - virtual ~GradientMakerBase() {} - virtual bool CopyDeviceOption() const { - return true; - } - virtual bool CopyEngine() const { - return true; - } - virtual bool CopyArguments() const { - return true; - } - - virtual void VerifyOp() const { - auto* schema = OpSchemaRegistry::Schema(def_.type()); - if (schema) { - CAFFE_ENFORCE( - schema->Verify(def_), - "(GradientMaker) Operator def did not pass schema checking: ", - ProtoDebugString(def_)); - } - } - - /** - * @brief Returns the gradient ops meta. - * - * If your gradient op generator only use standard input and output - * manipulations, you can simply implement GetGradientDefs() that - * returns vector. In that, you can call GI, GI_V and GI_I - * that will automatically create the gradient registration for you. - * - * If you need to do custom gradient name registration, overload this - * function directly. - */ - virtual GradientOpsMeta Get() { - VerifyOp(); - vector new_defs = GetGradientDefs(); - for (auto& opdef : new_defs) { - opdef.set_is_gradient_op(true); - } - return GradientOpsMeta(new_defs, g_input_); - }; - - const OperatorDef& Def() const { - return def_; - } - - protected: - virtual vector GetGradientDefs() { - CAFFE_NOT_IMPLEMENTED; - } - - // Helper functions to return names for the gradient computation. - // I(idx), O(idx): return the input and output names. - // GO(idx): return the name of the gradient for output idx. - // GI(idx), GI_I(idx), GI_V(idx): return the name of the gradient for - // input idx, and also registers that name into the gradient - // registry to be returned. - string I(const int i) { - CAFFE_ENFORCE((i >= 0) && (i < def_.input().size())); - return def_.input(i); - } - string O(const int i) { - CAFFE_ENFORCE((i >= 0) && (i < def_.output().size())); - return def_.output(i); - } - string GI(const int i) { - CAFFE_ENFORCE( - !g_input_.at(i).IsSparse(), - "Input ", - def_.input(i), - " already set to sparse."); - g_input_.at(i).dense_ = GradientName(def_.input(i)); - return GradientName(def_.input(i)); - } - string GI_I(const int i) { - CAFFE_ENFORCE( - !g_input_.at(i).IsDense(), - "Input ", - def_.input(i), - " already set to dense."); - g_input_.at(i).indices_ = GradientSliceIndices(def_.input(i)); - return GradientSliceIndices(def_.input(i)); - } - string GI_V(const int i) { - CAFFE_ENFORCE( - !g_input_.at(i).IsDense(), - "Input ", - def_.input(i), - " already set to dense."); - g_input_.at(i).values_ = GradientSliceValues(def_.input(i)); - return GradientSliceValues(def_.input(i)); - } - string GO(const int i) { - CAFFE_ENFORCE( - g_output_.at(i).IsDense(), - "Gradient of output ", - def_.output(i), - (g_output_.at(i).IsSparse() ? " is sparse (expected dense)." - : " is not provided!")); - return g_output_.at(i).dense_; - } - string GO_I(const int i) { - CAFFE_ENFORCE( - g_output_.at(i).IsSparse(), - "Gradient of output ", - def_.output(i), - (g_output_.at(i).IsDense() ? " is dense (expected sparse)." - : " is not provided!")); - return g_output_.at(i).indices_; - } - string GO_V(const int i) { - CAFFE_ENFORCE( - g_output_.at(i).IsSparse(), - "Gradient of output ", - def_.output(i), - (g_output_.at(i).IsDense() ? " is dense (expected sparse)." - : " is not provided!")); - return g_output_.at(i).values_; - } - const GradientWrapper& GradOut(int i) { - return g_output_.at(i); - } - - // Function to add a gradient pair to map. - void SetDense(const int i, const string& name) { - CAFFE_ENFORCE( - !g_input_.at(i).IsSparse(), - "Input ", - def_.input(i), - " already set to sparse."); - g_input_.at(i).dense_ = name; - } - void SetSparse(const int i, const string& indices, const string& values) { - CAFFE_ENFORCE( - !g_input_.at(i).IsDense(), - "Input ", - def_.input(i), - " already set to dense."); - g_input_.at(i).indices_ = indices; - g_input_.at(i).values_ = values; - } - - /** - * @brief a helper function to allow one to create one single operator - * def, which is usually the case for many simple operators. - */ - template - inline static vector SingleGradientDef(const Args&... args) { - return vector{CreateOperatorDef(args...)}; - } - - public: - /** - * Returns map that returns the parameters that the gradients are for. - */ - static CaffeMap MatchGradsToParams(const OperatorDef& op) { - // NOTE: how to go beyond string-matching? - CaffeMap m; - for (auto& out : op.output()) { - if (IsGradientBlob(out)) { - m[out] = out.substr(0, out.length() - 5); - } - } - return m; - } - - private: - // Utility functions for gradient name computation. We don't expose them - // in order to discourage the use of such names explicitly. - static string GradientName(const string& name) { - return name + "_grad"; - } - - static bool IsGradientBlob(const string& name) { - return name.length() > 5 && name.find("_grad") == name.length() - 5; - } - - static string GradientNameToParam(const string& name) { - CHECK(IsGradientBlob(name)); - return name.substr(0, name.length() - 5); - } - - static string GradientSliceIndices(const string& name) { - return name + "_grad_indices"; - } - - static string GradientSliceValues(const string& name) { - return name + "_grad_values"; - } - - protected: - // We make the member variables protected in case someone wants to write - // a fully custom Get() function. - const OperatorDef& def_; - const vector& g_output_; - vector g_input_; -}; - -/** - * @brief A helper class to indicate that the operator does not need gradient - * computation. - * - * Use the macro NO_GRADIENT to register operators that do not have gradients. - * Note that this is different fron SHOULD_NOT_DO_GRADIENT: the latter means - * that the gradient computation should not flow through it at all, and throws - * an error if it is called. - */ -class TORCH_API NoGradient : public GradientMakerBase { - using GradientMakerBase::GradientMakerBase; - vector GetGradientDefs() override { - return vector(); - } -}; - -/** - * @brief A helper class to indicate that the operator should have no gradient. - * - * This is used when the operator definition is designed to not have a gradient. - * Calling a gradient on this operator def will cause Caffe2 to quit. - */ -struct ThrowInTheTowelIfGradientIsCalled : public GradientMakerBase { - using GradientMakerBase::GradientMakerBase; - GradientOpsMeta Get() override { - CAFFE_THROW("One should not call gradient for operator ", def_.type(), "."); - } -}; - -/** - * @brief A helper class to indicate that the gradient mechanism is not ready. - * - * This should only be used sparsely when the gradient does exist, but we have - * not implemented it yet and are using this as a lazy excuse. Eventually, a - * gradient operator should be implemented. - */ -struct GradientNotImplementedYet : public GradientMakerBase { - using GradientMakerBase::GradientMakerBase; - GradientOpsMeta Get() override { - CAFFE_THROW( - "Operator ", - def_.type(), - " should have a gradient but is not implemented yet."); - } -}; - -C10_DECLARE_REGISTRY( - GradientRegistry, - GradientMakerBase, - const OperatorDef&, - const vector&); - -#ifdef CAFFE2_NO_GRADIENT_OPS - -#define REGISTER_GRADIENT(name, ...) /* No gradients. */ -#define REGISTER_GRADIENT_STR(str_name, ...) /* No gradients. */ - -#else - -#define REGISTER_GRADIENT(name, ...) \ - C10_REGISTER_CLASS(GradientRegistry, name, __VA_ARGS__) -#define REGISTER_GRADIENT_STR(str_name, ...) \ - C10_REGISTER_TYPED_CLASS(GradientRegistry, str_name, __VA_ARGS__) - -#endif - -// NO_GRADIENT means that the operator does not need any gradient computation. -#define NO_GRADIENT(name) REGISTER_GRADIENT(name, NoGradient) - -// SHOULD_NOT_DO_GRADIENT means that the operator is not designed to have -// gradient operators. If you attempt to call the gradient, a log fatal will -// occur. -#define SHOULD_NOT_DO_GRADIENT(name) \ - REGISTER_GRADIENT(name, ThrowInTheTowelIfGradientIsCalled) - -#define GRADIENT_NOT_IMPLEMENTED_YET(name) \ - REGISTER_GRADIENT(name, GradientNotImplementedYet) - -/** - * @brief Gets the GradientOpsMeta for the given operator def. - */ -TORCH_API GradientOpsMeta GetGradientForOp( - const OperatorDef& def, - const vector& g_output); - -} // namespace caffe2 - -#endif // CAFFE2_CORE_OPERATOR_GRADIENT_H_ diff --git a/caffe2/core/operator_schema.h b/caffe2/core/operator_schema.h deleted file mode 100644 index f5b9d0dc09a2..000000000000 --- a/caffe2/core/operator_schema.h +++ /dev/null @@ -1,612 +0,0 @@ -#ifndef CAFFE2_CORE_OPERATOR_SCHEMA_H_ -#define CAFFE2_CORE_OPERATOR_SCHEMA_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -namespace caffe2 { - -// A const value returned by OpSchema::CalculateOutput() if the number of -// output cannot be determined. -constexpr int kCannotComputeNumOutputs = -1; - -/** - * @brief A class to record the schema of an op. - * - * OpSchema records the common interface of an op specified by its name. This - * is optional for each operator implemented in Caffe2 but is strongly - * recommended. - * - * To register an OpSchema, one can use the macro OPERATOR_SCHEMA(name) and - * then append the various functions in the class. For example, for an op - * that takes in two inputs, one output, and the first input and output - * could be in-place, can be written as - * - * OPERATOR_SCHEMA(name) - * .NumInputs(2).NumOutputs(1).AllowInplace({{0, 0}}); - */ -class TORCH_API OpSchema { - public: - OpSchema() : OpSchema("unknown", "unknown", 0) {} - OpSchema(const string& type, const string& file, const int line); - - /** - * @brief Returns the file that the op schema is registered from. - */ - inline const string& file() const { - return file_; - } - - /** - * @brief Returns the line in file that the op schema is registered from. - */ - inline int line() const { - return line_; - } - - /** - * @brief Returns the docstring of the op schema. - */ - inline const char* doc() const { - return doc_.empty() ? nullptr : doc_.c_str(); - } - - /** - * @brief Verifies if an operator definition protobuf matches the pattern - * specified in the schema. - */ - bool Verify(const OperatorDef& def) const; - - // Functions to set the property of the operator schemas. - // Sets the number of inputs, either a fixed number or a min and a max. - - /** - * @brief A single input. - */ - OpSchema& NumInputs(int n); - /** - * @brief Input could be in range [min, max], inclusive. - */ - OpSchema& NumInputs(int min, int max); - /** - * @brief Input could be one of the values specified in allowed_input_nums. - */ - OpSchema& NumInputs(set allowed_input_nums); - /** - * @brief Input is checked with a specified function. - */ - OpSchema& NumInputs(std::function func); - - // Sets the number of outputs, either a fixed number, a min and a max, - // or a function that takes in the input number and produces an output - // number. Use only one function in the set below. - /** - * @brief A single output. - */ - OpSchema& NumOutputs(int n); - /** - * @brief Output could be in range [min, max], inclusive. - */ - OpSchema& NumOutputs(int min, int max); - /** - * @brief Output could be one of the values specified in allowed_output_nums. - */ - OpSchema& NumOutputs(set allowed_output_nums); - /** - * @brief Output is checked with a specified function. - */ - OpSchema& NumOutputs(std::function func); - - /** - * @brief Relationship between inputs and outputs is checked with a specified - * function. - */ - OpSchema& NumInputsOutputs(std::function func); - - // Set the function that can calculate the number of output based on the - // number of input. Use only one function in the set below. - /** - * @brief Set the output calculator to a user-defined function. - */ - OpSchema& OutputCalculator(std::function calc); - /** - * @brief Set the number of outputs to be the same as the number of inputs. - */ - OpSchema& SameNumberOfOutput(); - - // Sets the rule to allow optional in-place operation. - OpSchema& AllowInplace(std::function inplace); - OpSchema& AllowInplace(set> inplace); - OpSchema& AllowOneToOneInplace(); - // Sets the rule to enforce in-place operation. - OpSchema& EnforceInplace(std::function inplace); - OpSchema& EnforceInplace(set> inplace); - OpSchema& EnforceOneToOneInplace(); - - // Functions to deal with type and shape inference. Basically, this registers - // a function that takes in an OperatorDef and a series of input type and - // shape specified by TensorProto objects (whose data fields are empty), and - // produces a series of output type and shape. - typedef std::function< - vector(const OperatorDef&, const vector&)> - TensorInferenceFunctionType; - - /** - * @brief Sets the tensor inference function, which is a std::function object - * defined in operator_schema.h. - */ - OpSchema& TensorInferenceFunction(TensorInferenceFunctionType function); - - /** - * A wrapper that makes an infer tensor function to return unknown - * shape for all outputs if any one of the inputs has unknown shape - */ - - static TensorInferenceFunctionType NeedsAllInputShapes( - TensorInferenceFunctionType f); - - /** - * @brief Sets the corresponding onnx schema name - */ - OpSchema& InheritOnnxSchema(const std::string& onnx_schema_name); - - /** - * @brief Shortcut to InheritOnnxSchema(type_) - */ - OpSchema& InheritOnnxSchema() { - return InheritOnnxSchema(type_); - } - - /** - * @brief Sets the tensor inference function to produce the same output as - * the input. - */ - OpSchema& IdenticalTypeAndShape(); - OpSchema& IdenticalTypeAndShapeOfInput(int idx); - OpSchema& IdenticalTypeAndShapeOfInputDim(int idx, int dim); - OpSchema& IdenticalTypeAndShapeOfMultipleInputs(const vector& indices); - OpSchema& ScalarType(::caffe2::TensorProto_DataType dt); - - /** - * @brief A function to allow one to infer the type and shape from the op - * schema. - */ - inline vector InferTensor( - const OperatorDef& def, - const vector& input_type_shape) const { - CAFFE_ENFORCE( - Verify(def), - "(InferTensor) Operator def did not pass schema checking: ", - ProtoDebugString(def)); - return tensor_inference_function_(def, input_type_shape); - } - - /* - * @brief A struct to store various cost information about - * an operator such as FLOPs, total memory use and parameters. - */ - struct Cost { - uint64_t flops{0}; // Floating point operations. - uint64_t bytes_read{0}; // Total memory read. - uint64_t bytes_written{0}; // Total memory written. - uint64_t params_bytes{0}; // Memory read for parameters. - }; - /** - * @brief Registers a function that takes in an OperatorDef - * and a series of input shapes and returns the total "cost" - * required to run the operator via struct by value. - */ - typedef std::function< - struct Cost(const OperatorDef&, const vector&)> - CostInferenceFunctionType; - - /** - * @brief Register the Cost inference function. - */ - OpSchema& CostInferenceFunction(CostInferenceFunctionType function); - -#if 0 // def _MSC_VER - /** - * @brief Register the Cost inference function via a pointer. - */ - template :value - >:type> - inline OpSchema& CostInferenceFunction(T func) { - // Note: This is here in order to resolve an MSVC compiler issue: it - // does not automatically convert a function pointer to a std::function, - // and needs an explicit conversion. - return CostInferenceFunction(CostInferenceFunctionType(func)); - } -#endif // _MSC_VER - - bool HasCostInferenceFunction() const { - return !!cost_inference_function_; - } - - inline struct Cost InferCost( - const OperatorDef& def, - const vector& input_tensor_shape) const { - CAFFE_ENFORCE( - cost_inference_function_, "Cost inference function not defined."); - return (*cost_inference_function_)(def, input_tensor_shape); - } - - // Functions to do documentation for the operator schema. - OpSchema& SetDoc(const string& doc); - - struct Argument { - Argument(const char* name, const char* description, bool required) - : name_{name}, description_{description}, required_{required} {} - - const char* name() const { - return name_; - } - - const char* description() const { - return description_; - } - - bool is_required() const { - return required_; - } - - private: - const char* name_; - const char* description_; - const bool required_; - }; - - OpSchema& - Arg(const char* name, const char* description, bool required = false); - -#define DECLARE_STANDARD_ARG(name, str) \ - static const char* Arg_##name; \ - OpSchema& Arg##name(const char* description); - - DECLARE_STANDARD_ARG(IsTest, is_test) - -#undef DECLARE_STANDARD_ARG - - OpSchema& Input(const int n, const char* name, const char* description); - OpSchema& Output(const int n, const char* name, const char* description); - // Calls the passed function with `this` as an argument. Useful for - // adding docs for templated/macro ops. - OpSchema& FillUsing(std::function populator); - - // Remove from documentation - OpSchema& Private(); - - // This op can pass data across devices - OpSchema& InputsCanCrossDevices(); - - /** - * @brief A function to allow one to get the number of outputs based on the - * number of inputs, if this schema supports it. - */ - int CalculateOutput(int num_input) const; - - const std::string& onnx_schema() const { - return onnx_schema_; - } - - int min_input() const { - return min_input_; - } - - int max_input() const { - return max_input_; - } - - int min_output() const { - return min_output_; - } - - int max_output() const { - return max_output_; - } - - bool num_inputs_allowed(int x) const { - return num_inputs_allowed_(x); - } - - bool num_outputs_allowed(int x) const { - return num_outputs_allowed_(x); - } - - bool num_inputs_outputs_allowed(int x, int y) const { - return num_inputs_outputs_allowed_(x, y); - } - - int inf() const { - return std::numeric_limits::max(); - } - - bool inplace_enforced(int x, int y) const { - return inplace_enforced_(x, y); - } - - TORCH_API friend std::ostream& operator<<( - std::ostream& out, - const OpSchema& schema); - - const std::vector& args() const { - return args_; - } - - const std::vector>& input_desc() const { - return input_desc_; - } - const std::vector>& output_desc() const { - return output_desc_; - } - bool private_op() { - return private_; - } - bool inputs_can_cross_devices() const { - return inputs_can_cross_devices_; - } - - /** - * @brief Returns the required device location of inputs and outputs. - */ - using DeviceInferenceFunctionType = std::function< - std::pair, std::vector>( - const OperatorDef& def)>; - - OpSchema& DeviceInferenceFunction(DeviceInferenceFunctionType function); - - /** - * @brief Infer required device location of an op's inputs and outputs - */ - inline std::pair, std::vector> - InferDevice(const OperatorDef& def) const { - return device_inference_function_(def); - } - - // The helper is build sparse input with values, keys, weights and lengths; - // e.g.: - // values = [1, 2, 3, 2, 4, 6, 7, 3, 6] - // keys = [0, 1, 4, 0, 1, 2, 5, 1, 2] - // weights = [1, 2, 3, 4, 5, 6, 7, 8, 9] - // \_____/ \________/ \__/ - // lengths = [3, 4, 2] - OpSchema& WeightedValueKeyLengthInputFillers( - size_t value_index, - size_t key_index, - size_t length_index, - size_t weight_index); - - // The helper is build sparse input with values, keys, weights and lengths; - // e.g.: - // values = [1, 2, 3, 2, 4, 6, 7, 3, 6] - // keys = [0, 1, 4, 0, 1, 2, 5, 1, 2] - // \_____/ \________/ \__/ - // lengths = [3, 4, 2] - OpSchema& ValueKeyLengthInputFillers( - size_t value_index, - size_t key_index, - size_t length_index); - - // The helper is build sparse input with values and lengths; e.g.: - // values = [1, 2, 3, 2, 4, 6, 7, 3, 6] - // \_____/ \________/ \__/ - // lengths = [3, 4, 2] - OpSchema& ValueLengthInputFillers(size_t value_index, size_t length_index); - - OpSchema& DisallowInputFillers(); - - std::vector InputFillers( - const std::vector>& shapes) const; - - private: - std::vector SupplyDenseFillers( - const std::vector>& shapes); - - private: - string type_; - string file_; - string doc_; - string onnx_schema_; - std::vector args_{}; - std::vector> input_desc_{}; - std::vector> output_desc_{}; - int line_ = 0; - int min_input_ = 0; - int max_input_ = std::numeric_limits::max(); - int min_output_ = 0; - int max_output_ = std::numeric_limits::max(); - bool private_ = false; - bool inputs_can_cross_devices_ = false; - std::function num_inputs_allowed_ = [](int) { return true; }; - std::function num_outputs_allowed_ = [](int) { return true; }; - std::function num_inputs_outputs_allowed_ = [](int, int) { - return true; - }; - std::function calculate_output_; - // In default, any in-place operation is neither allowed nor enforced. - std::function inplace_allowed_ = [](int, int) { - return false; - }; - std::function inplace_enforced_ = [](int, int) { - return false; - }; - TensorInferenceFunctionType tensor_inference_function_; - std::unique_ptr cost_inference_function_ = nullptr; - DeviceInferenceFunctionType device_inference_function_; - - std::function( - const std::vector>&)> - filler_supplier_ = - [this](const std::vector>& shapes) { - return SupplyDenseFillers(shapes); - }; -}; - -/** - * @brief A registry to hold all the operator schemas. - */ -class TORCH_API OpSchemaRegistry { - public: - static OpSchema& - NewSchema(const string& key, const string& file, const int line); - - static const OpSchema* Schema(const string& key) { - auto& m = map(); - auto it = m.find(key); - if (it != m.end()) { - return &it->second; - } else { - return nullptr; - } - } - - private: - // OpSchemaRegistry should not need to be instantiated. - OpSchemaRegistry() = delete; - - /** - * @brief Returns the underlying string to OpSchema map. - * - * You should not manually manipulate the map object returned. Instead, use - * the macros defined such as OPERATOR_SCHEMA to register your operator - * schema. - * - * We wrap it inside a function to avoid the static initialization order - * fiasco. - */ - static CaffeMap& map(); -}; - -// Helper function for creating simple tensorproto with dimension and type -template -inline TensorShape CreateTensorShape( - vector dims, - ::caffe2::TensorProto_DataType dt) { - TensorShape ts; - for (T_I d : dims) { - ts.add_dims(d); - } - ts.set_data_type(dt); - return ts; -} - -// Helper function -inline vector GetDimsVector(const TensorShape& shape) { - vector dims; - for (auto d : shape.dims()) { - dims.push_back(d); - } - return dims; -} - -// Helper function -inline uint64_t nElemFromDim(const TensorShape& X, int dim = 0) { - CAFFE_ENFORCE_GE(dim, 0, "Invalid maximum index specified"); - - uint64_t nElem = 1; - for (const auto i : c10::irange(dim, X.dims_size())) { - nElem *= X.dims(i); - } - return nElem; -} - -// Helper function -inline uint64_t nElemBetweenDim(const TensorShape& X, int start, int stop) { - CAFFE_ENFORCE_GE(start, 0, "Invalid maximum index specified"); - CAFFE_ENFORCE_LE(stop, X.dims_size(), "Invalid maximum index specified"); - - uint64_t nElem = 1; - for (const auto i : c10::irange(start, stop)) { - nElem *= X.dims(i); - } - return nElem; -} - -// Helper function for infer op inputs and outputs device information. -inline std::pair, std::vector> -InferOpInputOutputDevice(const OperatorDef& op) { - auto op_schema = OpSchemaRegistry::Schema(op.type()); - if (op_schema) { - // op_schema found - return op_schema->InferDevice(op); - - } else { - // No schema for op.type registered - auto temp_schema = OpSchema(); - return temp_schema.InferDevice(op); - } -} - -template -OpSchema::Cost PointwiseCostInference( - const OperatorDef& /* unused */, - const vector& inputs) { - struct OpSchema::Cost c; - const TensorShape X = inputs[0]; - uint64_t nElemX = nElemFromDim(X); - uint64_t nElemRead = 0; - for (const auto i : c10::irange(inputs.size())) { - nElemRead += nElemFromDim(inputs[i]); - } - - c.flops = nElemX * OpsPerPoint; - auto const& X_element_size_byte = - DataTypeToTypeMeta(X.data_type()).itemsize(); - c.bytes_read = nElemRead * X_element_size_byte; - c.bytes_written = nElemX * X_element_size_byte; - return c; -} - -} // namespace caffe2 - -#if defined(_MSC_VER) -#define EXPORT_IF_NOT_MSVC -#else -#define EXPORT_IF_NOT_MSVC C10_EXPORT -#endif - -#ifndef CAFFE2_NO_OPERATOR_SCHEMA - -#define OPERATOR_SCHEMA(name) \ - EXPORT_IF_NOT_MSVC void CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(){}; \ - static OpSchema* C10_ANONYMOUS_VARIABLE(name) CAFFE2_UNUSED = \ - &OpSchemaRegistry::NewSchema(#name, __FILE__, __LINE__) - -#else // CAFFE2_NO_OPERATOR_SCHEMA - -#define OPERATOR_SCHEMA(name) \ - EXPORT_IF_NOT_MSVC void CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(){}; \ - static OpSchema* C10_ANONYMOUS_VARIABLE(name) CAFFE2_UNUSED = \ - 1 ? nullptr : &OpSchemaRegistry::NewSchema(#name, __FILE__, __LINE__) - -#endif // CAFFE2_NO_OPERATOR_SCHEMA - -#ifdef CAFFE2_NO_GRADIENT_OPS - -#define GRADIENT_OPERATOR_SCHEMA(name) \ - EXPORT_IF_NOT_MSVC void CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(){}; \ - static OpSchema* C10_ANONYMOUS_VARIABLE(name) CAFFE2_UNUSED = \ - 1 ? nullptr : &OpSchemaRegistry::NewSchema(#name, __FILE__, __LINE__) - -#else - -#define GRADIENT_OPERATOR_SCHEMA(name) OPERATOR_SCHEMA(name) - -#endif -#endif // CAFFE2_CORE_OPERATOR_SCHEMA_H_ diff --git a/caffe2/core/storage.h b/caffe2/core/storage.h deleted file mode 100644 index e9bd6ed60c0b..000000000000 --- a/caffe2/core/storage.h +++ /dev/null @@ -1,33 +0,0 @@ -#ifndef CAFFE2_CORE_STORAGE_H_ -#define CAFFE2_CORE_STORAGE_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "caffe2/core/allocator.h" -#include "caffe2/core/common.h" -#include "caffe2/core/context.h" -#include "caffe2/core/flags.h" -#include "caffe2/core/logging.h" -#include - -#include -#include -#include -#include -#include -#include - -namespace caffe2 { - -using StorageImpl = at::StorageImpl; -using Storage = at::Storage; - -} // namespace caffe2 - -#endif // CAFFE2_CORE_STORAGE_H_ diff --git a/caffe2/core/tensor.h b/caffe2/core/tensor.h deleted file mode 100644 index 1171605b9f77..000000000000 --- a/caffe2/core/tensor.h +++ /dev/null @@ -1,674 +0,0 @@ -#ifndef CAFFE2_CORE_TENSOR_H_ -#define CAFFE2_CORE_TENSOR_H_ - -#include -#include "caffe2/core/storage.h" - -#include -#include -#include -#include -#include -#include - -C10_CLANG_DIAGNOSTIC_PUSH() -#if C10_CLANG_HAS_WARNING("-Wshorten-64-to-32") -C10_CLANG_DIAGNOSTIC_IGNORE("-Wshorten-64-to-32") -#endif - -#if defined(EXPOSE_C2_OPS) || \ - !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) -namespace at { -class Tensor; -}; -#endif -namespace caffe2 { - -using at::UndefinedTensorImpl; - -/** - * @brief Tensor class holds a shared pointer to the implementation TensorImpl, - * redirects API calls to TensorImpl; - * Copying of Tensor results in sharing the same underlying implementation - * object - * - * NB: See TensorImpl for documentation on these methods. - */ -class TORCH_API Tensor final { - private: - enum Unsafe { IDoWantAliasing }; - Tensor(const Tensor& other, Unsafe _) : impl_(other.getIntrusivePtr()) {} - - protected: - using TensorImplPtr = c10::intrusive_ptr; - TensorImplPtr impl_; - - void enforce_invariants(); - - public: - Tensor() : impl_() {} - - Tensor(const Tensor& t) : impl_(t.impl_) {} - Tensor& operator=(const Tensor& t) { - impl_ = t.impl_; - return *this; - } - - Tensor(Tensor&&) = default; - Tensor& operator=(Tensor&&) = default; - - operator bool() const { - return impl_.defined(); - } - - TensorImpl* unsafeGetTensorImpl() const { - return impl_.get(); - } - - TensorImpl* unsafeReleaseTensorImpl() { - return impl_.release(); - } - - Tensor UnsafeSharedInstance() const { - return Tensor(*this, IDoWantAliasing); - } - - /** - * @brief Creates a tensor of the given device type. - * - * Note that the actual data allocation is not going to be carried out until - * you resize the tensor and then call mutable_data(). - */ - explicit Tensor(at::Device device) - : impl_(c10::make_intrusive( - Storage::create_legacy(device), - c10::computeDispatchKey(c10::nullopt, at::kStrided, device), - TypeMeta())) {} - - /** - * @brief Creates a tensor of the given dimension. - * - * Note that the actual data allocation is not going to be carried out until - * the first time mutable_data() is called. - */ - explicit Tensor(at::IntArrayRef dims, DeviceType type) : Tensor(type) { - // TODO: here, we create a Storage - // and immediately discard it in Resize() since - // reset_tensor will be true and FreeMemory will be called, - // we might want to avoid creating Storage twice? - Resize(dims); - } - - // we want to preserve index information - explicit Tensor(at::IntArrayRef dims, at::Device device) : Tensor(device) { - Resize(dims); - } - - // TODO: remove? - explicit Tensor(const vector& dims, DeviceType type) : Tensor(type) { - Resize(dims); - } - - /** - * @brief: Create a Tensor of at::DeviceType `type` and initialize it with - * src Tensor - */ - Tensor(const Tensor& src, DeviceType type) : Tensor(type) { - CopyFrom(src); - } - - /** - * @brief Mutual conversion with at::Tensor - * - * The tensor will share the same instance (data, strides, sizes, etc) but - * a different subset of APIs would be available - */ -#if defined(EXPOSE_C2_OPS) || \ - !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) - explicit Tensor(at::Tensor tensor); - - explicit operator at::Tensor() const&; - - explicit operator at::Tensor() &&; -#endif - - bool is_same(const Tensor& other) const noexcept { - return impl_ == other.impl_; - } - - Tensor Clone() const { - Tensor x(GetDevice()); - x.CopyFrom(*this); - return x; - } - - /** - * Clone self as a Tensor that share the same Storage, - * that is, both Tensors are views on the same Storage. - * If we change the sizes or strides of one Tensor, it - * does not affect the other Tensor that it shares Storage - * with. - * A similar yet different usage is `Tensor x = y;`, this - * will make x and y pointing to the same Tensor and resizing - * one of them will resize the other as well. - * - * TODO: Deduplicate this with THTensor_(newWithTensor) - * (exposed in ATen as at::alias but not otherwise available) - */ - Tensor Alias() const { - Tensor x(sizes(), GetDevice()); - if (!dtype_initialized()) { - C10_LOG_EVERY_MS(WARNING, 1000) - << "Cloning a tensor that don't have a data type (did you call mutable_data on the tensor?)"; - } - AT_ASSERTM( - storage_initialized(), - "Cloning a tensor that has no content and has size > 0"); - // set_storage already sets data_type_ of TensorImpl - x.impl_->set_storage_and_dtype(storage(), impl_->dtype()); - x.impl_->set_storage_offset(impl_->storage_offset()); - x.impl_->set_sizes_and_strides(sizes(), strides()); - return x; - } - - DeviceType GetDeviceType() const { - return impl_->device_type(); - } - - at::Device GetDevice() const { - return impl_.get()->device(); - } - - /** - * @brief Copies the data from a source tensor, with a context provided to - * carry out the underlying memcpy operation. This method respects - * caffe2_keep_on_shrink. - * - * After CopyFrom, this function guarantees that the destination tensor will - * have the same initialization state and dtype as src. This function - * preserves the DeviceType of the source tensor (so, e.g., if you allocate - * a tensor on CPU and then CopyFrom a CUDA tensor, that will to a - * CUDA-to-CPU transfer). - * - * 'async' parameter triggers async copy for CUDA tensors - */ - void CopyFrom(const Tensor& src, bool async = false); - - /** - * @brief Extend the outer-most dimension of this tensor - * to dimension of `num`. - */ - void ExtendTo(int64_t num, float growthPct) const { - CAFFE_ENFORCE_GE_WITH_CALLER(impl_->dim(), 1); - CAFFE_ENFORCE_GE_WITH_CALLER(growthPct, 0); - Extend(num - impl_->size(0), growthPct); - } - - void Extend(int64_t num, float growthPct) const { - impl_.get()->Extend(num, growthPct); - } - - /** - * @brief Shrinks the outer-most dimension to given size, keeping the data. - * - * This method guarantees that no re-allocations are carried out, which means - * that the extra capacity after the end of the shrunk tensor is maintained. - * Notably, this function does NOT respect caffe2_keep_on_shrink. - */ - void ShrinkTo(int64_t outer_dim) const { - CAFFE_ENFORCE_WITH_CALLER( - impl_->is_contiguous(), - "Right now ShrinkTo is only supported on contiguous Tensor."); - CAFFE_ENFORCE_WITH_CALLER(impl_->dim() >= 1, "Tensor must be at least 1D"); - CAFFE_ENFORCE_WITH_CALLER( - outer_dim <= impl_->size(0), - "New outer dimension must be smaller than current."); - CAFFE_ENFORCE( - impl_->storage().unique(), - "Can't call ShrinkTo on shared storage, please call Resize instead."); - impl_.get()->set_size(0, outer_dim); - } - - template - void ReserveSpace(const T& outer_dim) const { - impl_.get()->ReserveSpace(outer_dim); - } - - template - void Resize(Ts... dim_source) const { - impl_.get()->Resize(dim_source...); - } - - template - void Resize(const std::vector& dim_source) const { - impl_.get()->Resize(ArrayRef(dim_source)); - } - - /** - * Resize the tensor like the source tensor. Note that this is just a - * sugar wrapper that essentially calls Resize(src_tensor.dims()). - * This method respects caffe2_keep_on_shrink. - */ - inline void ResizeLike(const Tensor& src_tensor) const { - CAFFE_ENFORCE_WITH_CALLER( - src_tensor.is_contiguous(), - "Right now ResizeLike is only supported for contiguous Tensor."); - if (impl_ != src_tensor.impl_) { - impl_.get()->Resize(src_tensor.sizes()); - } - } - - inline void Reshape(const vector& dims) const { - impl_.get()->Reshape(dims); - } - - inline void Reshape(const vector& dims) const { - impl_.get()->Reshape(ToVectorint64_t(dims)); - } - - inline void FreeMemory() const { - impl_.get()->FreeMemory(); - } - - /** - * A utility function to print the debug string for the tensor. Note that this - * is very slow since it involves quite some string operations, so do not use - * it in your performance-critical code. - */ - string DebugString() const { - std::stringstream ss; - ss << "A Tensor of item size " << impl_->dtype().itemsize() << " and type " - << impl_->dtype().name() << " and dimension ("; - for (int d : impl_->sizes()) { - ss << d << ","; - } - ss << ")."; - return ss.str(); - } - - // To be deprecated - void ShareData(const Tensor& src) const { - impl_.get()->ShareData(*src.impl_.get()); - } - - /** - * @brief Shares the data with an externally managed pointer. - * - * This is similar to ShareData() but the source is a pointer with an advanced - * deleter option. In default, no deletion takes place, and one needs to make - * sure that the external memory is deallocated only after the tensor finishes - * using it. If a Deleter object is passed in, when this tensor is reallocated - * or freed, the deleter function is going to be called. - */ - template - void ShareExternalPointer( - T* src, - size_t nbytes = 0, - MemoryDeleter d = nullptr) const { - ShareExternalPointer((void*)src, caffe2::TypeMeta::Make(), nbytes, d); - } - - template - void ShareExternalPointer(at::DataPtr&& data_ptr, size_t nbytes = 0) const { - ShareExternalPointer( - std::move(data_ptr), caffe2::TypeMeta::Make(), nbytes); - } - - void ShareExternalPointer( - void* src, - const TypeMeta data_type, - size_t nbytes = 0, - MemoryDeleter d = nullptr) const { - CAFFE_ENFORCE_WITH_CALLER( - impl_->is_contiguous(), - "Right now ShareExternalPointer is only supported for contiguous Tensor."); - CAFFE_ENFORCE_WITH_CALLER( - data_type != ScalarType::Undefined, - "To share with a raw external pointer you need to pass in an " - "initialized data_type(TypeMeta)."); - impl_.get()->ShareExternalPointer( - at::DataPtr(src, src, d, impl_->device_type()), data_type, nbytes); - } - - void ShareExternalPointer( - at::DataPtr&& data_ptr, - const TypeMeta data_type, - size_t nbytes) { - impl_.get()->ShareExternalPointer(std::move(data_ptr), data_type, nbytes); - } - - const c10::intrusive_ptr& getIntrusivePtr() - const { - return impl_; - } - - bool defined() const { - return impl_; - } - - /** - * Returns a raw void* pointer of the underlying storage. mutable_data() - * or raw_mutable_data() must have been called prior to this function call. - */ - inline void* raw_data() const { - return impl_->mutable_data(); - } - - template - inline T* data() const { - return impl_.get()->mutable_data_dtype_initialized(); - } - - inline void* raw_mutable_data(const TypeMeta meta) const { - return impl_.get()->raw_mutable_data(meta); - } - - /** - * Returns a mutable raw pointer of the underlying storage. This can only be - * used when you know for sure that the underlying storage of the tensor is - * already created via an earlier raw_mutable_data(meta) call or a - * mutable_data() call. - * - * If the existing data does not match the desired type, it will be deleted - * and a new storage will be created. - */ - inline void* raw_mutable_data() const { - const auto& data_type = impl_->dtype(); - CAFFE_ENFORCE_WITH_CALLER( - data_type != ScalarType::Undefined, - "Calling raw_mutable_data() without meta, but the current meta is " - "of unknown type."); - return raw_mutable_data(data_type); - } - - template - inline T* mutable_data() const { - return impl_.get()->mutable_data(); - } - - /** - * Returns the number of dimensions of the data. - */ - inline int dim() const { - return impl_->dim(); - } - - /** - * (To be deprecated) Returns the number of dimensions of the data. - */ - inline int ndim() const { - return impl_->dim(); - } - - /** - * (To be deprecated) Returns the size (i.e. the number of items) of the - * tensor. - */ - inline int64_t size() const { - return impl_->numel(); - } - - /** - * Returns the number of items of the tensor. - */ - inline int64_t numel() const { - return impl_->numel(); - } - - /** - * Return the number of bytes each item takes in the tensor. - */ - inline size_t itemsize() const { - return impl_->dtype().itemsize(); - } - - /** - * Returns the total number of bytes of the storage. - * - * This is equivalent to calling size() * itemsize(). - */ - inline size_t nbytes() const { - return impl_->numel() * itemsize(); - } - - inline at::IntArrayRef sizes() const { - return impl_.get()->sizes(); - } - - inline c10::SymIntArrayRef sym_sizes() const { - return impl_->sym_sizes(); - } - - inline c10::SymInt sym_numel() const { - return impl_->sym_numel(); - } - - inline c10::SymIntArrayRef sym_strides() const { - return impl_->sym_strides(); - } - - inline int64_t size_from_dim(int k) const { - return size_from_dim_(k, impl_->sizes()); - } - - inline int64_t size_to_dim(int k) const { - return size_to_dim_(k, impl_->sizes()); - } - - inline int64_t size_between_dim(int k, int l) const { - return size_between_dim_(k, l, impl_->sizes()); - } - - /** - * Returns the 'canonical' version of a (usually) user-specified axis, - * allowing for negative indexing (e.g., -1 for the last axis). - * - * @param axis_index the axis index. - * If 0 <= index < dim(), return index. - * If -ndim <= index <= -1, return (dim() - (-index)), - * e.g., the last axis index (dim() - 1) if index == -1, - * the second to last if index == -2, etc. - * Dies on out of range index. - */ - inline int canonical_axis_index(int axis_index) const { - return canonical_axis_index_(axis_index, impl_->dim()); - } - - inline int64_t stride(int64_t dim) const { - return impl_.get()->stride(dim); - } - - inline at::IntArrayRef strides() const { - return impl_.get()->strides(); - } - - inline bool is_contiguous( - at::MemoryFormat memory_format = at::MemoryFormat::Contiguous) const { - return impl_.get()->is_contiguous(memory_format); - } - - /** - * Checks if the tensor content is of the given data type. - */ - template - inline bool IsType() const { - return impl_->dtype().Match(); - } - - /** - * Returns the TypeMeta object associated with the current data type. - */ - inline const TypeMeta dtype() const { - return impl_->dtype(); - } - - /** - * (To be deprecated) Returns the TypeMeta object associated with the current - * data type. - */ - inline const TypeMeta meta() const { - return impl_->dtype(); - } - - /** - * Returns the i-th dimension of the tensor in int. - * - * This function returns an int value instead of int64_t, which depending on - * the typedef could be int64. If you want int64 dim values, make sure you - * call dim() instead. - */ - inline int dim32(const int i) const { -#ifndef NDEBUG - CAFFE_ENFORCE_LT_WITH_CALLER( - i, static_cast(impl_->dim()), "Exceeding ndim limit"); - CAFFE_ENFORCE_GE_WITH_CALLER(i, 0, "Cannot have negative dimension index"); -#endif - // Avoid TensorImpl::size() because it is a virtual call that - // supports out-of-range indexing like Python. - auto s = impl_->sizes()[i]; - CAFFE_ENFORCE_LT_WITH_CALLER(s, std::numeric_limits::max()); - return static_cast(s); - } - - inline int64_t size(const int i) const { - return impl_->size(i); - } - - // To be deprecated - inline int64_t dim(const int i) const { - return impl_->size(i); - } - - const Storage& storage() { - return impl_->storage(); - } - - const Storage& storage() const { - return impl_->storage(); - } - - bool storage_initialized() const { - return impl_->storage_initialized(); - } - - bool dtype_initialized() const { - return impl_->dtype_initialized(); - } -}; - -/** - * Reinitialize a Tensor to given dims and options if necessary, note that - * this will not do anything if the - * Tensor already has correct size and data type - */ -TORCH_API void -ReinitializeTensor(Tensor* t, at::IntArrayRef dims, at::TensorOptions options); - -TORCH_API void ReinitializeAndCopyFrom( - Tensor* t, - at::TensorOptions options, - const Tensor& src, - bool async = false); - -using TensorCPU = Tensor; - -constexpr int k_limit_default_ = 1000; - -// TODO: the following logic can be merged into regular Tensor class methods -// after MKLMemory starts to implement Tensor interface - -// Type call registry -typedef TypeMeta (*TypeCall)(const void*); -TypeCall GetTypeCallFunction(TypeIdentifier id); -void RegisterTypeCallFunction(TypeIdentifier id, TypeCall c); - -// Shape call registry -typedef vector ( - *TensorInfoCall)(const void*, size_t* capacity, DeviceOption* device); -TensorInfoCall GetTensorInfoFunction(TypeIdentifier id); -void RegisterTensorInfoFunction(TypeIdentifier id, TensorInfoCall c); - -// resize helper function -void TensorVectorResize( - std::vector& tensors, - int size, - DeviceType type); - -// Tensor factory function -TORCH_API Tensor empty(at::IntArrayRef dims, at::TensorOptions options); - -/** - * @brief Creates a CPU tensor, and fills its contents with the given values. - * Values are copied in - */ -// TODO: can be unified with at::from_blob when Tensor is merged and string -// types are supported -template -Tensor TensorCPUFromValues(at::IntArrayRef dims, at::ArrayRef values) { - Tensor r = empty(dims, at::device(CPU).dtype()); - CAFFE_ENFORCE_EQ(values.size(), r.numel()); - CPUContext context; - context.CopyItemsFromCPU( - r.dtype(), values.size(), values.data(), r.mutable_data()); - return r; -} - -vector -GetTensorInfo(const void* c, size_t* capacity, DeviceOption* device); - -class TORCH_API TensorPrinter { - public: - explicit TensorPrinter( - const std::string& tensor_name = "", - const std::string& file_name = "", - int limit = k_limit_default_); - ~TensorPrinter(); - - template - void Print(const Tensor& tensor); - - void PrintMeta(const Tensor& tensor); - - string MetaStr(const Tensor& tensor); - - private: - bool to_file_; - int limit_; - std::unique_ptr log_file_; - std::string tensor_name_; -}; - -template -void TensorPrinter::Print(const Tensor& tensor) { - std::stringstream values_stream; - // One most likely doesn't want to print int64-number of items for visual - // inspection, so we cast down to int here. - int total_count = static_cast(std::min(tensor.numel(), int64_t(limit_))); - - const T* tensor_data = tensor.template data(); - for (int i = 0; i < total_count - 1; ++i) { - values_stream << tensor_data[i] << ","; - } - if (total_count) { - // We do not add a comma after the last item. - values_stream << tensor_data[total_count - 1]; - } - - if (to_file_) { - (*log_file_) << MetaStr(tensor) << values_stream.str() << std::endl; - } else { - // Log to console. - LOG(INFO) << MetaStr(tensor) << values_stream.str(); - } -} - -CAFFE_DECLARE_KNOWN_TYPE(Tensor, Caffe2Tensor) -} // namespace caffe2 - -C10_CLANG_DIAGNOSTIC_POP() - -namespace c10 { -template <> -struct ExclusivelyOwnedTraits : public c10::ExclusivelyOwnedTensorTraits {}; -} // namespace c10 -#endif // CAFFE2_CORE_TENSOR_H_ diff --git a/caffe2/core/tensor_int8.h b/caffe2/core/tensor_int8.h deleted file mode 100644 index b95b7b8d10e5..000000000000 --- a/caffe2/core/tensor_int8.h +++ /dev/null @@ -1,21 +0,0 @@ -#ifndef CAFFE2_TENSOR_INT8_H_ -#define CAFFE2_TENSOR_INT8_H_ - -#include "caffe2/core/context.h" -#include "caffe2/core/tensor.h" -#include "caffe2/proto/caffe2_pb.h" - -namespace caffe2 { -namespace int8 { - -struct Int8TensorCPU { - float scale{1.0}; - int32_t zero_point{0}; - // Generally stores uint8_t data, but sometimes int32_t (e.g. bias - // parameters). - Tensor t{CPU}; -}; -} // namespace int8 -} // namespace caffe2 - -#endif // CAFFE2_TENSOR_INT8_H_ diff --git a/caffe2/core/workspace.h b/caffe2/core/workspace.h deleted file mode 100644 index 04fa86fe2527..000000000000 --- a/caffe2/core/workspace.h +++ /dev/null @@ -1,342 +0,0 @@ -#ifndef CAFFE2_CORE_WORKSPACE_H_ -#define CAFFE2_CORE_WORKSPACE_H_ - -#include "caffe2/core/common.h" -#include "caffe2/core/observer.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "c10/util/Registry.h" -#include "caffe2/core/blob.h" -#include "caffe2/core/net.h" -#include "caffe2/proto/caffe2_pb.h" -#include "caffe2/utils/signal_handler.h" -#include "caffe2/utils/threadpool/ThreadPool.h" - -C10_DECLARE_bool(caffe2_print_blob_sizes_at_exit); - -namespace caffe2 { - -class NetBase; - -struct TORCH_API StopOnSignal { - StopOnSignal() - : handler_(std::make_shared( - SignalHandler::Action::STOP, - SignalHandler::Action::STOP)) {} - - StopOnSignal(const StopOnSignal& other) : handler_(other.handler_) {} - - bool operator()(int /*iter*/) { - return handler_->CheckForSignals() != SignalHandler::Action::STOP; - } - - std::shared_ptr handler_; -}; - -/** - * Workspace is a class that holds all the related objects created during - * runtime: (1) all blobs, and (2) all instantiated networks. It is the owner of - * all these objects and deals with the scaffolding logistics. - */ -class TORCH_API Workspace { - public: - typedef std::function ShouldContinue; - /** - * Initializes an empty workspace. - */ - Workspace() : Workspace(".", nullptr) {} - - /** - * Initializes an empty workspace with the given root folder. - * - * For any operators that are going to interface with the file system, such - * as load operators, they will write things under this root folder given - * by the workspace. - */ - explicit Workspace(const string& root_folder) - : Workspace(root_folder, nullptr) {} - - /** - * Initializes a workspace with a shared workspace. - * - * When we access a Blob, we will first try to access the blob that exists - * in the local workspace, and if not, access the blob that exists in the - * shared workspace. The caller keeps the ownership of the shared workspace - * and is responsible for making sure that its lifetime is longer than the - * created workspace. - */ - explicit Workspace(const Workspace* shared) : Workspace(".", shared) {} - - /** - * Initializes workspace with parent workspace, blob name remapping - * (new name -> parent blob name), no other blobs are inherited from - * parent workspace - */ - Workspace( - const Workspace* shared, - const std::unordered_map& forwarded_blobs) - : Workspace(".", nullptr) { - CAFFE_ENFORCE(shared, "Parent workspace must be specified"); - for (const auto& forwarded : forwarded_blobs) { - CAFFE_ENFORCE( - shared->HasBlob(forwarded.second), - "Invalid parent workspace blob: ", - forwarded.second); - forwarded_blobs_[forwarded.first] = - std::make_pair(shared, forwarded.second); - } - } - - /** - * Initializes a workspace with a root folder and a shared workspace. - */ - Workspace(const string& root_folder, const Workspace* shared) - : root_folder_(root_folder), shared_(shared), bookkeeper_(bookkeeper()) { - std::lock_guard guard(bookkeeper_->wsmutex); - bookkeeper_->workspaces.insert(this); - } - - ~Workspace() { - if (FLAGS_caffe2_print_blob_sizes_at_exit) { - PrintBlobSizes(); - } - // This is why we have a bookkeeper_ shared_ptr instead of a naked static! A - // naked static makes us vulnerable to out-of-order static destructor bugs. - std::lock_guard guard(bookkeeper_->wsmutex); - bookkeeper_->workspaces.erase(this); - } - - /** - * Adds blob mappings from workspace to the blobs from parent workspace. - * Creates blobs under possibly new names that redirect read/write operations - * to the blobs in the parent workspace. - * Arguments: - * parent - pointer to parent workspace - * forwarded_blobs - map from new blob name to blob name in parent's - * workspace skip_defined_blob - if set skips blobs with names that already - * exist in the workspace, otherwise throws exception - */ - void AddBlobMapping( - const Workspace* parent, - const std::unordered_map& forwarded_blobs, - bool skip_defined_blobs = false); - - /** - * Converts previously mapped tensor blobs to local blobs, copies values from - * parent workspace blobs into new local blobs. Ignores undefined blobs. - */ - template - void CopyForwardedTensors(const std::unordered_set& blobs) { - for (const auto& blob : blobs) { - auto it = forwarded_blobs_.find(blob); - if (it == forwarded_blobs_.end()) { - continue; - } - const auto& ws_blob = it->second; - const auto* parent_ws = ws_blob.first; - auto* from_blob = parent_ws->GetBlob(ws_blob.second); - CAFFE_ENFORCE(from_blob); - CAFFE_ENFORCE( - from_blob->template IsType(), - "Expected blob with tensor value", - ws_blob.second); - forwarded_blobs_.erase(blob); - auto* to_blob = CreateBlob(blob); - CAFFE_ENFORCE(to_blob); - const auto& from_tensor = from_blob->template Get(); - auto* to_tensor = BlobGetMutableTensor(to_blob, Context::GetDeviceType()); - to_tensor->CopyFrom(from_tensor); - } - } - - /** - * Return list of blobs owned by this Workspace, not including blobs - * shared from parent workspace. - */ - vector LocalBlobs() const; - - /** - * Return a list of blob names. This may be a bit slow since it will involve - * creation of multiple temp variables. For best performance, simply use - * HasBlob() and GetBlob(). - */ - vector Blobs() const; - - /** - * Return the root folder of the workspace. - */ - const string& RootFolder() { return root_folder_; } - /** - * Checks if a blob with the given name is present in the current workspace. - */ - inline bool HasBlob(const string& name) const { - // First, check the local workspace, - // Then, check the forwarding map, then the parent workspace - if (blob_map_.count(name)) { - return true; - } - - auto it = forwarded_blobs_.find(name); - if (it != forwarded_blobs_.end()) { - const auto parent_ws = it->second.first; - const auto& parent_name = it->second.second; - return parent_ws->HasBlob(parent_name); - } - - if (shared_) { - return shared_->HasBlob(name); - } - - return false; - } - - void PrintBlobSizes(); - - /** - * Creates a blob of the given name. The pointer to the blob is returned, but - * the workspace keeps ownership of the pointer. If a blob of the given name - * already exists, the creation is skipped and the existing blob is returned. - */ - Blob* CreateBlob(const string& name); - /** - * Similar to CreateBlob(), but it creates a blob in the local workspace even - * if another blob with the same name already exists in the parent workspace - * -- in such case the new blob hides the blob in parent workspace. If a blob - * of the given name already exists in the local workspace, the creation is - * skipped and the existing blob is returned. - */ - Blob* CreateLocalBlob(const string& name); - /** - * Remove the blob of the given name. Return true if removed and false if - * not exist. - * Will NOT remove from the shared workspace. - */ - bool RemoveBlob(const string& name); - /** - * Gets the blob with the given name as a const pointer. If the blob does not - * exist, a nullptr is returned. - */ - const Blob* GetBlob(const string& name) const; - /** - * Gets the blob with the given name as a mutable pointer. If the blob does - * not exist, a nullptr is returned. - */ - Blob* GetBlob(const string& name); - - /** - * Renames a local workspace blob. If blob is not found in the local blob list - * or if the target name is already present in local or any parent blob list - * the function will throw. - */ - Blob* RenameBlob(const string& old_name, const string& new_name); - - /** - * Creates a network with the given NetDef, and returns the pointer to the - * network. If there is anything wrong during the creation of the network, a - * nullptr is returned. The Workspace keeps ownership of the pointer. - * - * If there is already a net created in the workspace with the given name, - * CreateNet will overwrite it if overwrite=true is specified. Otherwise, an - * exception is thrown. - */ - NetBase* CreateNet(const NetDef& net_def, bool overwrite = false); - NetBase* CreateNet( - const std::shared_ptr& net_def, - bool overwrite = false); - /** - * Gets the pointer to a created net. The workspace keeps ownership of the - * network. - */ - NetBase* GetNet(const string& net_name); - /** - * Deletes the instantiated network with the given name. - */ - void DeleteNet(const string& net_name); - /** - * Finds and runs the instantiated network with the given name. If the network - * does not exist or there are errors running the network, the function - * returns false. - */ - bool RunNet(const string& net_name); - - /** - * Returns a list of names of the currently instantiated networks. - */ - vector Nets() const { - vector names; - for (auto& entry : net_map_) { - names.push_back(entry.first); - } - return names; - } - - /** - * Runs a plan that has multiple nets and execution steps. - */ - bool RunPlan(const PlanDef& plan_def, - ShouldContinue should_continue = StopOnSignal{}); - - /* - * Returns a CPU threadpool instance for parallel execution of - * work. The threadpool is created lazily; if no operators use it, - * then no threadpool will be created. - */ - ThreadPool* GetThreadPool(); - - // RunOperatorOnce and RunNetOnce runs an operator or net once. The difference - // between RunNet and RunNetOnce lies in the fact that RunNet allows you to - // have a persistent net object, while RunNetOnce creates a net and discards - // it on the fly - this may make things like database read and random number - // generators repeat the same thing over multiple calls. - bool RunOperatorOnce(const OperatorDef& op_def); - bool RunNetOnce(const NetDef& net_def); - - /** - * Applies a function f on each workspace that currently exists. - * - * This function is thread safe and there is no race condition between - * workspaces being passed to f in this thread and destroyed in another. - */ - template - static void ForEach(F f) { - auto bk = bookkeeper(); - std::lock_guard guard(bk->wsmutex); - for (Workspace* ws : bk->workspaces) { - f(ws); - } - } - - public: - std::atomic last_failed_op_net_position{}; - - private: - struct Bookkeeper { - std::mutex wsmutex; - std::unordered_set workspaces; - }; - - static std::shared_ptr bookkeeper(); - - std::unordered_map> blob_map_; - const string root_folder_; - const Workspace* shared_; - std::unordered_map> - forwarded_blobs_; - std::unique_ptr thread_pool_; - std::mutex thread_pool_creation_mutex_; - std::shared_ptr bookkeeper_; - std::unordered_map> net_map_; - - C10_DISABLE_COPY_AND_ASSIGN(Workspace); -}; - -} // namespace caffe2 - -#endif // CAFFE2_CORE_WORKSPACE_H_ diff --git a/caffe2/utils/GpuAtomics.cuh b/caffe2/utils/GpuAtomics.cuh deleted file mode 100644 index 2bbcc14fa7da..000000000000 --- a/caffe2/utils/GpuAtomics.cuh +++ /dev/null @@ -1,28 +0,0 @@ -#ifndef CAFFE2_UTILS_GPU_ATOMICS_H_ -#define CAFFE2_UTILS_GPU_ATOMICS_H_ - -#include - -namespace caffe2 { - -namespace { - -template -inline __device__ void gpu_atomic_add(T* address, const T val) { - atomicAdd(address, val); -} - -template <> -inline __device__ void gpu_atomic_add(float* address, const float val) { -#if defined(USE_ROCM) && defined(__gfx908__) - atomicAddNoRet(address, val); -#else - atomicAdd(address, val); -#endif -} - -} // namespace - -} // namespace caffe2 - -#endif // CAFFE2_UTILS_GPU_ATOMICS_H_ diff --git a/caffe2/utils/GpuBitonicSort.cuh b/caffe2/utils/GpuBitonicSort.cuh deleted file mode 100644 index 45cb298733a8..000000000000 --- a/caffe2/utils/GpuBitonicSort.cuh +++ /dev/null @@ -1,178 +0,0 @@ -#ifndef CAFFE2_UTILS_GPU_BITONIC_SORT_H_ -#define CAFFE2_UTILS_GPU_BITONIC_SORT_H_ - -#include "caffe2/utils/math.h" -#include "caffe2/utils/GpuDefs.cuh" - -namespace caffe2 { - -// Returns true if the given integer type is a power-of-2 (positive only) -// Note(jiayq): windows reported an error per -// https://github.com/caffe2/caffe2/issues/997 -// and as a result will make it a macro. -#ifdef _MSC_VER -#define integerIsPowerOf2(v) ((v) && !((v) & ((v) - 1))) -#else // _MSC_VER -template -constexpr bool integerIsPowerOf2(T v) { - return (v && !(v & (v - 1))); -} -#endif // _MSC_VER - -/// The maximum in-block bitonic sort we support -constexpr int kMaxBitonicSortSize = 4096; - -template -__device__ inline void swapVars(T& t1, T& t2) { - T tmp = t1; - t1 = t2; - t2 = tmp; -} - -template -__device__ inline void bitonicSwap(K& kA, V& vA, - K& kB, V& vB, - bool dir, - const Comparator& comp) { - bool swap = comp(kA, vA, kB, vB); - if (swap == dir) { - swapVars(kA, kB); - swapVars(vA, vB); - } -}; - -template -__device__ inline void bitonicSort(K* keys, - V* values, - const Comparator& comp) { - static_assert(Power2SortSize <= kMaxBitonicSortSize, - "sort size <= 4096 only supported"); - // Assume the sort is taking place in shared memory - // static_assert(Power2SortSize * (sizeof(K) + sizeof(V)) < 32768, - // "sort data too large (>32768 bytes)"); - static_assert(integerIsPowerOf2(Power2SortSize), - "sort size must be power of 2"); - static_assert(integerIsPowerOf2(ThreadsPerBlock), - "threads in block must be power of 2"); - - // If what we are sorting is too small, then not all threads - // participate - constexpr int numThreadsForSort = Power2SortSize / 2; - constexpr bool allThreads = numThreadsForSort >= ThreadsPerBlock; - - // If what we are sorting is too large, then threads must loop more - // than once - constexpr int loopPerThread = - allThreads ? numThreadsForSort / ThreadsPerBlock : 1; - -#pragma unroll - for (int size = 2; size < Power2SortSize; size *= 2) { - -#pragma unroll - for (int stride = size / 2; stride > 0; stride /= 2) { - -#pragma unroll - for (int loop = 0; loop < loopPerThread; ++loop) { - int threadId = loop * ThreadsPerBlock + threadIdx.x; - bool flag = ((threadId & (size / 2)) != 0); - - int pos = 2 * threadId - (threadId & (stride - 1)); - - if (allThreads || (threadId < numThreadsForSort)) { - bitonicSwap( - keys[pos], values[pos], - keys[pos + stride], values[pos + stride], - flag, comp); - } - - __syncthreads(); - } - } - } - -#pragma unroll - for (int stride = Power2SortSize / 2; stride > 0; stride /= 2) { - -#pragma unroll - for (int loop = 0; loop < loopPerThread; ++loop) { - int threadId = loop * ThreadsPerBlock + threadIdx.x; - - int pos = 2 * threadId - (threadId & (stride - 1)); - - if (allThreads || (threadId < numThreadsForSort)) { - bitonicSwap( - keys[pos], values[pos], - keys[pos + stride], values[pos + stride], - false, comp); - } - - __syncthreads(); - } - } -} - -template -__device__ inline void warpBitonicSort(K* keys, - V* values, - const Comparator& comp) { - // Smaller sorts should use a warp shuffle sort - static_assert(Power2SortSize > kWarpSize, - "sort not large enough"); - static_assert(integerIsPowerOf2(Power2SortSize), - "sort size must be power of 2"); - static_assert(Power2SortSize <= kMaxBitonicSortSize, - "sort size <= 4096 only supported"); - - // If what we are sorting is too large, then lanes must loop more - // than once - constexpr int loopPerThread = (Power2SortSize / 2) / kWarpSize; - int laneId = getLaneId(); - -#pragma unroll - for (int size = 2; size < Power2SortSize; size *= 2) { - -#pragma unroll - for (int stride = size / 2; stride > 0; stride /= 2) { - -#pragma unroll - for (int loop = 0; loop < loopPerThread; ++loop) { - int threadId = loop * kWarpSize + laneId; - bool flag = ((threadId & (size / 2)) != 0); - - int pos = 2 * threadId - (threadId & (stride - 1)); - - bitonicSwap( - keys[pos], values[pos], - keys[pos + stride], values[pos + stride], - flag, comp); - - __threadfence_block(); - } - } - } - -#pragma unroll - for (int stride = Power2SortSize / 2; stride > 0; stride /= 2) { - -#pragma unroll - for (int loop = 0; loop < loopPerThread; ++loop) { - int threadId = loop * kWarpSize + laneId; - - int pos = 2 * threadId - (threadId & (stride - 1)); - - bitonicSwap( - keys[pos], values[pos], - keys[pos + stride], values[pos + stride], - false, comp); - - __threadfence_block(); - } - } -} - - -} // namespace caffe2 - -#endif // CAFFE2_UTILS_GPU_BITONIC_SORT_H_ diff --git a/caffe2/utils/GpuDefs.cuh b/caffe2/utils/GpuDefs.cuh deleted file mode 100644 index fcf2c64ddcb1..000000000000 --- a/caffe2/utils/GpuDefs.cuh +++ /dev/null @@ -1,158 +0,0 @@ -#ifndef CAFFE2_UTILS_GPU_DEFS_H_ -#define CAFFE2_UTILS_GPU_DEFS_H_ - -#include - -namespace caffe2 { - -// Static definition of GPU warp size for unrolling and code generation - -#if defined(USE_ROCM) -constexpr int kWarpSize = warpSize; // = 64 (Defined in hip_runtime.h) -#else -constexpr int kWarpSize = 32; -#endif // __CUDA_ARCH__ - -// -// Interfaces to PTX instructions for which there appears to be no -// intrinsic -// - -template -struct Bitfield {}; - -template <> -struct Bitfield { - static __device__ __forceinline__ - unsigned int getBitfield(unsigned int val, int pos, int len) { -#if defined(USE_ROCM) - pos &= 0xff; - len &= 0xff; - - unsigned int m = (1u << len) - 1u; - return (val >> pos) & m; -#else - unsigned int ret; - asm("bfe.u32 %0, %1, %2, %3;" : "=r"(ret) : "r"(val), "r"(pos), "r"(len)); - return ret; -#endif // USE_ROCM - } - - static __device__ __forceinline__ - unsigned int setBitfield(unsigned int val, unsigned int toInsert, int pos, int len) { -#if defined(USE_ROCM) - pos &= 0xff; - len &= 0xff; - - unsigned int m = (1u << len) - 1u; - toInsert &= m; - toInsert <<= pos; - m <<= pos; - - return (val & ~m) | toInsert; -#else - unsigned int ret; - asm("bfi.b32 %0, %1, %2, %3, %4;" : - "=r"(ret) : "r"(toInsert), "r"(val), "r"(pos), "r"(len)); - return ret; -#endif // USE_ROCM - } -}; - -template <> -struct Bitfield { - static __device__ __forceinline__ - unsigned long long int getBitfield(unsigned long long int val, int pos, int len) { -#if defined(USE_ROCM) - pos &= 0xff; - len &= 0xff; - - unsigned long long int m = (1u << len) - 1u; - return (val >> pos) & m; -#else - unsigned long long int ret; - asm("bfe.u64 %0, %1, %2, %3;" : "=l"(ret) : "l"(val), "r"(pos), "r"(len)); - return ret; -#endif // USE_ROCM - } - - static __device__ __forceinline__ - unsigned long long int setBitfield(unsigned long long int val, unsigned long long int toInsert, int pos, int len) { -#if defined(USE_ROCM) - pos &= 0xff; - len &= 0xff; - - unsigned long long int m = (1u << len) - 1u; - toInsert &= m; - toInsert <<= pos; - m <<= pos; - - return (val & ~m) | toInsert; -#else - unsigned long long int ret; - asm("bfi.b64 %0, %1, %2, %3, %4;" : - "=l"(ret) : "l"(toInsert), "l"(val), "r"(pos), "r"(len)); - return ret; -#endif // USE_ROCM - } -}; - -__device__ __forceinline__ int getLaneId() { -#if defined(USE_ROCM) - return __lane_id(); -#else - int laneId; - asm("mov.s32 %0, %%laneid;" : "=r"(laneId) ); - return laneId; -#endif // USE_ROCM -} - -#if defined(USE_ROCM) -__device__ __forceinline__ unsigned long long int getLaneMaskLt() { - unsigned long long int m = (1ull << getLaneId()) - 1ull; - return m; -} - -__device__ __forceinline__ unsigned long long int getLaneMaskLe() { - unsigned long long int m = UINT64_MAX >> (sizeof(std::uint64_t) * CHAR_BIT - (getLaneId() + 1)); - return m; -} - -__device__ __forceinline__ unsigned long long int getLaneMaskGt() { - unsigned long long int m = getLaneMaskLe(); - return m ? ~m : m; -} - -__device__ __forceinline__ unsigned long long int getLaneMaskGe() { - unsigned long long int m = getLaneMaskLt(); - return ~m; -} -#else -__device__ __forceinline__ unsigned getLaneMaskLt() { - unsigned mask; - asm("mov.u32 %0, %%lanemask_lt;" : "=r"(mask)); - return mask; -} - -__device__ __forceinline__ unsigned getLaneMaskLe() { - unsigned mask; - asm("mov.u32 %0, %%lanemask_le;" : "=r"(mask)); - return mask; -} - -__device__ __forceinline__ unsigned getLaneMaskGt() { - unsigned mask; - asm("mov.u32 %0, %%lanemask_gt;" : "=r"(mask)); - return mask; -} - -__device__ __forceinline__ unsigned getLaneMaskGe() { - unsigned mask; - asm("mov.u32 %0, %%lanemask_ge;" : "=r"(mask)); - return mask; -} -#endif // USE_ROCM - -} // namespace caffe2 - -#endif // CAFFE2_UTILS_GPU_DEFS_H_ diff --git a/caffe2/utils/GpuScanUtils.cuh b/caffe2/utils/GpuScanUtils.cuh deleted file mode 100644 index 0f6823d8e85e..000000000000 --- a/caffe2/utils/GpuScanUtils.cuh +++ /dev/null @@ -1,133 +0,0 @@ -#ifndef CAFFE2_UTILS_GPU_SCAN_UTILS_H_ -#define CAFFE2_UTILS_GPU_SCAN_UTILS_H_ - -#include "caffe2/utils/GpuDefs.cuh" - -namespace caffe2 { - -// from the cutorch library; can probably be replaced with their CUB -// equivalents -// Collection of in-kernel scan / prefix sum utilities - -// Inclusive prefix sum using shared memory -template -__device__ void inclusivePrefixScan(T* smem, T in, T* out, BinaryFunction binop) { - // FIXME: this is a slow, simple implementation; need up/down sweep, - // prevent smem conflicts - smem[threadIdx.x] = in; - - __syncthreads(); - - for (int offset = 1; offset < blockDim.x; offset *= 2) { - T val = 0; - - if (threadIdx.x >= offset) { - val = binop(smem[threadIdx.x - offset], smem[threadIdx.x]); - } - - __syncthreads(); - if (threadIdx.x >= offset) { - smem[threadIdx.x] = val; - } - - __syncthreads(); - } - - *out = smem[threadIdx.x]; - - // Prevent write-after-read dependencies on smem usage above if necessary - if (KillWARDependency) { - __syncthreads(); - } -} - -// Exclusive prefix sum using shared memory -template -__device__ void exclusivePrefixScan(T* smem, T in, T* out, T* carry, BinaryFunction binop) { - // FIXME: crappy implementation - // We kill write-after-read dependencies separately below, hence the `false` - inclusivePrefixScan(smem, in, out, binop); - - *out -= in; - *carry = smem[blockDim.x - 1]; - - // Prevent write-after-read dependencies on smem usage above if necessary - if (KillWARDependency) { - __syncthreads(); - } -} - -// Inclusive prefix sum for binary vars using intra-warp voting + -// shared memory -template -__device__ void inclusiveBinaryPrefixScan(T* smem, bool in, T* out, BinaryFunction binop) { - // Within-warp, we use warp voting. -#if defined(USE_ROCM) - unsigned long long int vote = __ballot(in); - - T index = __popcll(getLaneMaskLe() & vote); - T carry = __popcll(vote); -#else - T vote = __ballot_sync(__activemask(), in); - T index = __popc(getLaneMaskLe() & vote); - T carry = __popc(vote); -#endif // USE_ROCM - - int warp = threadIdx.x / kWarpSize; - - // Per each warp, write out a value - if (getLaneId() == 0) { - smem[warp] = carry; - } - - __syncthreads(); - - // Sum across warps in one thread. This appears to be faster than a - // warp shuffle scan for CC 3.0+ - if (threadIdx.x == 0) { - int current = 0; - for (int i = 0; i < blockDim.x / kWarpSize; ++i) { - T v = smem[i]; - smem[i] = binop(smem[i], current); - current = binop(current, v); - } - } - - __syncthreads(); - - // load the carry from the preceding warp - if (warp >= 1) { - index = binop(index, smem[warp - 1]); - } - - *out = index; - - if (KillWARDependency) { - __syncthreads(); - } -} - -// Exclusive prefix sum for binary vars using intra-warp voting + -// shared memory -template -__device__ void exclusiveBinaryPrefixScan(T* smem, bool in, T* out, T* carry, BinaryFunction binop) { - inclusiveBinaryPrefixScan(smem, in, out, binop); - - // Inclusive to exclusive - *out -= (T) in; - - // The outgoing carry for all threads is the last warp's sum -#if defined(USE_ROCM) - *carry = smem[math::DivUp(blockDim.x, kWarpSize) - 1]; -#else - *carry = smem[(blockDim.x / kWarpSize) - 1]; -#endif // USE_ROCM - - if (KillWARDependency) { - __syncthreads(); - } -} - -} // namespace caffe2 - -#endif // CAFFE2_UTILS_GPU_SCAN_UTILS_H_ diff --git a/caffe2/utils/bench_utils.cc b/caffe2/utils/bench_utils.cc deleted file mode 100644 index baa8d34fd146..000000000000 --- a/caffe2/utils/bench_utils.cc +++ /dev/null @@ -1,120 +0,0 @@ -#if !defined(__s390x__) && !defined(__powerpc__) -#include -#else -#include -#endif -// NOLINTNEXTLINE(modernize-deprecated-headers) -#include -// NOLINTNEXTLINE(modernize-deprecated-headers) -#include - -#include "caffe2/core/logging.h" -#include "caffe2/utils/bench_utils.h" - -namespace caffe2 { - -uint32_t wipe_cache() { - static uint32_t* wipe_buffer = nullptr; - static size_t wipe_size = 0; - - if (wipe_buffer == nullptr) { -#if !defined(__s390x__) && !defined(__powerpc__) - CAFFE_ENFORCE(cpuinfo_initialize(), "failed to initialize cpuinfo"); - const cpuinfo_processor* processor = cpuinfo_get_processor(0); - if (processor->cache.l4 != nullptr) { - wipe_size = processor->cache.l4->size; - } else if (processor->cache.l3 != nullptr) { - wipe_size = processor->cache.l3->size; - } else if (processor->cache.l2 != nullptr) { - wipe_size = processor->cache.l2->size; - } else { - wipe_size = processor->cache.l1d->size; - } -#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 - /* - * On ARM precise cache size is not available, and cpuinfo may - * underestimate. Use max for uArch (see src/arm/cache.c) - */ - switch (processor->core->uarch) { - case cpuinfo_uarch_cortex_a5: - wipe_size = 512 * 1024; /* Max observed */ - break; - case cpuinfo_uarch_cortex_a7: - wipe_size = 1024 * 1024; /* uArch max */ - break; - case cpuinfo_uarch_cortex_a8: - wipe_size = 1024 * 1024; /* uArch max */ - break; - case cpuinfo_uarch_cortex_a9: - wipe_size = 1024 * 1024; /* Max observed */ - break; - case cpuinfo_uarch_cortex_a12: - case cpuinfo_uarch_cortex_a17: - wipe_size = 8 * 1024 * 1024; /* uArch max */ - break; - case cpuinfo_uarch_cortex_a15: - wipe_size = 4 * 1024 * 1024; /* uArch max */ - break; - case cpuinfo_uarch_cortex_a35: - wipe_size = 1024 * 1024; /* uArch max */ - break; - case cpuinfo_uarch_cortex_a53: - wipe_size = 2 * 1024 * 1024; /* uArch max */ - break; - case cpuinfo_uarch_cortex_a57: - wipe_size = 2 * 1024 * 1024; /* uArch max */ - break; - case cpuinfo_uarch_cortex_a72: - wipe_size = 4 * 1024 * 1024; /* uArch max */ - break; - case cpuinfo_uarch_cortex_a73: - wipe_size = 8 * 1024 * 1024; /* uArch max */ - break; - case cpuinfo_uarch_cortex_a55: - case cpuinfo_uarch_cortex_a75: - case cpuinfo_uarch_meerkat_m3: - wipe_size = 4 * 1024 * 1024; /* DynamIQ max */ - break; - default: - wipe_size = 60 * 1024 * 1024; - break; - } -#endif -#elif defined (__s390x__) - wipe_size = sysconf(_SC_LEVEL4_CACHE_SIZE); - if (wipe_size <= 0) - { - /* - * Take current max L4 cache size for s390x - */ - wipe_size = 1024 * 1024 * 1024; - } -#else - /* ppc64le */ - wipe_size = sysconf(_SC_LEVEL4_CACHE_SIZE); - if (wipe_size <= 0) { - wipe_size = sysconf(_SC_LEVEL3_CACHE_SIZE); - if (wipe_size <= 0) { - wipe_size = sysconf(_SC_LEVEL2_CACHE_SIZE); - if(wipe_size <= 0) { - wipe_size = sysconf(_SC_LEVEL1D_CACHE_SIZE); - } - } - } -#endif - LOG(INFO) << "Allocating cache wipe buffer of size " << wipe_size; - // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) - wipe_buffer = static_cast(malloc(wipe_size)); - CAFFE_ENFORCE(wipe_buffer != nullptr); - } - uint32_t hash = 0; - for (uint32_t i = 0; i * sizeof(uint32_t) < wipe_size; i += 8) { - // NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Assign) - hash ^= wipe_buffer[i]; - wipe_buffer[i] = hash; - } - /* Make sure compiler doesn't optimize the loop away */ - return hash; -} - -} /* namespace caffe2 */ diff --git a/caffe2/utils/bench_utils.h b/caffe2/utils/bench_utils.h deleted file mode 100644 index 59997edad58d..000000000000 --- a/caffe2/utils/bench_utils.h +++ /dev/null @@ -1,30 +0,0 @@ -/** - * Copyright (c) 2016-present, Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef CAFFE2_UTILS_BENCH_UTILS_H_ -#define CAFFE2_UTILS_BENCH_UTILS_H_ - -#include - -#include - -namespace caffe2 { - -TORCH_API uint32_t wipe_cache(); - -} // namespace caffe2 - -#endif // CAFFE2_UTILS_BENCH_UTILS_H_ diff --git a/caffe2/utils/cast.h b/caffe2/utils/cast.h deleted file mode 100644 index 6f9db0837946..000000000000 --- a/caffe2/utils/cast.h +++ /dev/null @@ -1,49 +0,0 @@ -#pragma once - -#include - -namespace caffe2 { - -namespace cast { - -inline TensorProto_DataType GetCastDataType(const ArgumentHelper& helper, std::string arg) { - TensorProto_DataType to; - if (helper.HasSingleArgumentOfType(arg)) { - string s = helper.GetSingleArgument(arg, "float"); - std::transform(s.begin(), s.end(), s.begin(), ::toupper); -#ifndef CAFFE2_USE_LITE_PROTO - CAFFE_ENFORCE(TensorProto_DataType_Parse(s, &to), "Unknown 'to' argument: ", s); -#else - -// Manually implement in the lite proto case. -#define X(t) \ - if (s == #t) { \ - return TensorProto_DataType_##t; \ - } - - X(FLOAT); - X(INT32); - X(BYTE); - X(STRING); - X(BOOL); - X(UINT8); - X(INT8); - X(UINT16); - X(INT16); - X(INT64); - X(FLOAT16); - X(DOUBLE); -#undef X - CAFFE_THROW("Unhandled type argument: ", s); - -#endif - } else { - to = static_cast( - helper.GetSingleArgument(arg, TensorProto_DataType_FLOAT)); - } - return to; -} - -}; // namespace cast - -}; // namespace caffe2 diff --git a/caffe2/utils/cast_test.cc b/caffe2/utils/cast_test.cc deleted file mode 100644 index 680e87b3aecc..000000000000 --- a/caffe2/utils/cast_test.cc +++ /dev/null @@ -1,39 +0,0 @@ -#include -#include - -#include - -#include "caffe2/utils/cast.h" - -namespace caffe2 { - -TEST(CastTest, GetCastDataType) { - auto castOp = [](std::string t) { - // Ensure lowercase. - std::transform(t.begin(), t.end(), t.begin(), ::tolower); - auto op = CreateOperatorDef("Cast", "", {}, {}); - AddArgument("to", t, &op); - return op; - }; - -#define X(t) \ - EXPECT_EQ( \ - TensorProto_DataType_##t, \ - cast::GetCastDataType(ArgumentHelper(castOp(#t)), "to")); - - X(FLOAT); - X(INT32); - X(BYTE); - X(STRING); - X(BOOL); - X(UINT8); - X(INT8); - X(UINT16); - X(INT16); - X(INT64); - X(FLOAT16); - X(DOUBLE); -#undef X -} - -} // namespace caffe2 diff --git a/caffe2/utils/cblas.h b/caffe2/utils/cblas.h deleted file mode 100644 index c91b8bf8c530..000000000000 --- a/caffe2/utils/cblas.h +++ /dev/null @@ -1,606 +0,0 @@ -// This is the exact cblas.h header file, placed here purely in order to get -// the enums. - -#include "caffe2/core/macros.h" - -#ifndef CBLAS_H -#ifdef CAFFE2_USE_MKL -#include -#else // CAFFE2_USE_MKL - -#ifndef CBLAS_ENUM_DEFINED_H - #define CBLAS_ENUM_DEFINED_H - enum CBLAS_ORDER {CblasRowMajor=101, CblasColMajor=102 }; - enum CBLAS_TRANSPOSE {CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113, - AtlasConj=114}; - enum CBLAS_UPLO {CblasUpper=121, CblasLower=122}; - enum CBLAS_DIAG {CblasNonUnit=131, CblasUnit=132}; - enum CBLAS_SIDE {CblasLeft=141, CblasRight=142}; -#endif - -#ifndef CBLAS_ENUM_ONLY -#define CBLAS_H -#define CBLAS_INDEX int - -int cblas_errprn(int ierr, int info, char *form, ...); -void cblas_xerbla(int p, const char *rout, const char *form, ...); - -/* - * =========================================================================== - * Prototypes for level 1 BLAS functions (complex are recast as routines) - * =========================================================================== - */ -float cblas_sdsdot(const int N, const float alpha, const float *X, - const int incX, const float *Y, const int incY); -double cblas_dsdot(const int N, const float *X, const int incX, const float *Y, - const int incY); -float cblas_sdot(const int N, const float *X, const int incX, - const float *Y, const int incY); -double cblas_ddot(const int N, const double *X, const int incX, - const double *Y, const int incY); -/* - * Functions having prefixes Z and C only - */ -void cblas_cdotu_sub(const int N, const void *X, const int incX, - const void *Y, const int incY, void *dotu); -void cblas_cdotc_sub(const int N, const void *X, const int incX, - const void *Y, const int incY, void *dotc); - -void cblas_zdotu_sub(const int N, const void *X, const int incX, - const void *Y, const int incY, void *dotu); -void cblas_zdotc_sub(const int N, const void *X, const int incX, - const void *Y, const int incY, void *dotc); - - -/* - * Functions having prefixes S D SC DZ - */ -float cblas_snrm2(const int N, const float *X, const int incX); -float cblas_sasum(const int N, const float *X, const int incX); - -double cblas_dnrm2(const int N, const double *X, const int incX); -double cblas_dasum(const int N, const double *X, const int incX); - -float cblas_scnrm2(const int N, const void *X, const int incX); -float cblas_scasum(const int N, const void *X, const int incX); - -double cblas_dznrm2(const int N, const void *X, const int incX); -double cblas_dzasum(const int N, const void *X, const int incX); - - -/* - * Functions having standard 4 prefixes (S D C Z) - */ -CBLAS_INDEX cblas_isamax(const int N, const float *X, const int incX); -CBLAS_INDEX cblas_idamax(const int N, const double *X, const int incX); -CBLAS_INDEX cblas_icamax(const int N, const void *X, const int incX); -CBLAS_INDEX cblas_izamax(const int N, const void *X, const int incX); - -/* - * =========================================================================== - * Prototypes for level 1 BLAS routines - * =========================================================================== - */ - -/* - * Routines with standard 4 prefixes (s, d, c, z) - */ -void cblas_sswap(const int N, float *X, const int incX, - float *Y, const int incY); -void cblas_scopy(const int N, const float *X, const int incX, - float *Y, const int incY); -void cblas_saxpy(const int N, const float alpha, const float *X, - const int incX, float *Y, const int incY); -void catlas_saxpby(const int N, const float alpha, const float *X, - const int incX, const float beta, float *Y, const int incY); -void catlas_sset - (const int N, const float alpha, float *X, const int incX); - -void cblas_dswap(const int N, double *X, const int incX, - double *Y, const int incY); -void cblas_dcopy(const int N, const double *X, const int incX, - double *Y, const int incY); -void cblas_daxpy(const int N, const double alpha, const double *X, - const int incX, double *Y, const int incY); -void catlas_daxpby(const int N, const double alpha, const double *X, - const int incX, const double beta, double *Y, const int incY); -void catlas_dset - (const int N, const double alpha, double *X, const int incX); - -void cblas_cswap(const int N, void *X, const int incX, - void *Y, const int incY); -void cblas_ccopy(const int N, const void *X, const int incX, - void *Y, const int incY); -void cblas_caxpy(const int N, const void *alpha, const void *X, - const int incX, void *Y, const int incY); -void catlas_caxpby(const int N, const void *alpha, const void *X, - const int incX, const void *beta, void *Y, const int incY); -void catlas_cset - (const int N, const void *alpha, void *X, const int incX); - -void cblas_zswap(const int N, void *X, const int incX, - void *Y, const int incY); -void cblas_zcopy(const int N, const void *X, const int incX, - void *Y, const int incY); -void cblas_zaxpy(const int N, const void *alpha, const void *X, - const int incX, void *Y, const int incY); -void catlas_zaxpby(const int N, const void *alpha, const void *X, - const int incX, const void *beta, void *Y, const int incY); -void catlas_zset - (const int N, const void *alpha, void *X, const int incX); - - -/* - * Routines with S and D prefix only - */ -void cblas_srotg(float *a, float *b, float *c, float *s); -void cblas_srotmg(float *d1, float *d2, float *b1, const float b2, float *P); -void cblas_srot(const int N, float *X, const int incX, - float *Y, const int incY, const float c, const float s); -void cblas_srotm(const int N, float *X, const int incX, - float *Y, const int incY, const float *P); - -void cblas_drotg(double *a, double *b, double *c, double *s); -void cblas_drotmg(double *d1, double *d2, double *b1, const double b2, double *P); -void cblas_drot(const int N, double *X, const int incX, - double *Y, const int incY, const double c, const double s); -void cblas_drotm(const int N, double *X, const int incX, - double *Y, const int incY, const double *P); - - -/* - * Routines with S D C Z CS and ZD prefixes - */ -void cblas_sscal(const int N, const float alpha, float *X, const int incX); -void cblas_dscal(const int N, const double alpha, double *X, const int incX); -void cblas_cscal(const int N, const void *alpha, void *X, const int incX); -void cblas_zscal(const int N, const void *alpha, void *X, const int incX); -void cblas_csscal(const int N, const float alpha, void *X, const int incX); -void cblas_zdscal(const int N, const double alpha, void *X, const int incX); - -/* - * Extra reference routines provided by ATLAS, but not mandated by the standard - */ -void cblas_crotg(void *a, void *b, void *c, void *s); -void cblas_zrotg(void *a, void *b, void *c, void *s); -void cblas_csrot(const int N, void *X, const int incX, void *Y, const int incY, - const float c, const float s); -void cblas_zdrot(const int N, void *X, const int incX, void *Y, const int incY, - const double c, const double s); - -/* - * =========================================================================== - * Prototypes for level 2 BLAS - * =========================================================================== - */ - -/* - * Routines with standard 4 prefixes (S, D, C, Z) - */ -void cblas_sgemv(const enum CBLAS_ORDER Order, - const enum CBLAS_TRANSPOSE TransA, const int M, const int N, - const float alpha, const float *A, const int lda, - const float *X, const int incX, const float beta, - float *Y, const int incY); -void cblas_sgbmv(const enum CBLAS_ORDER Order, - const enum CBLAS_TRANSPOSE TransA, const int M, const int N, - const int KL, const int KU, const float alpha, - const float *A, const int lda, const float *X, - const int incX, const float beta, float *Y, const int incY); -void cblas_strmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const float *A, const int lda, - float *X, const int incX); -void cblas_stbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const int K, const float *A, const int lda, - float *X, const int incX); -void cblas_stpmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const float *Ap, float *X, const int incX); -void cblas_strsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const float *A, const int lda, float *X, - const int incX); -void cblas_stbsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const int K, const float *A, const int lda, - float *X, const int incX); -void cblas_stpsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const float *Ap, float *X, const int incX); - -void cblas_dgemv(const enum CBLAS_ORDER Order, - const enum CBLAS_TRANSPOSE TransA, const int M, const int N, - const double alpha, const double *A, const int lda, - const double *X, const int incX, const double beta, - double *Y, const int incY); -void cblas_dgbmv(const enum CBLAS_ORDER Order, - const enum CBLAS_TRANSPOSE TransA, const int M, const int N, - const int KL, const int KU, const double alpha, - const double *A, const int lda, const double *X, - const int incX, const double beta, double *Y, const int incY); -void cblas_dtrmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const double *A, const int lda, - double *X, const int incX); -void cblas_dtbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const int K, const double *A, const int lda, - double *X, const int incX); -void cblas_dtpmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const double *Ap, double *X, const int incX); -void cblas_dtrsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const double *A, const int lda, double *X, - const int incX); -void cblas_dtbsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const int K, const double *A, const int lda, - double *X, const int incX); -void cblas_dtpsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const double *Ap, double *X, const int incX); - -void cblas_cgemv(const enum CBLAS_ORDER Order, - const enum CBLAS_TRANSPOSE TransA, const int M, const int N, - const void *alpha, const void *A, const int lda, - const void *X, const int incX, const void *beta, - void *Y, const int incY); -void cblas_cgbmv(const enum CBLAS_ORDER Order, - const enum CBLAS_TRANSPOSE TransA, const int M, const int N, - const int KL, const int KU, const void *alpha, - const void *A, const int lda, const void *X, - const int incX, const void *beta, void *Y, const int incY); -void cblas_ctrmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const void *A, const int lda, - void *X, const int incX); -void cblas_ctbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const int K, const void *A, const int lda, - void *X, const int incX); -void cblas_ctpmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const void *Ap, void *X, const int incX); -void cblas_ctrsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const void *A, const int lda, void *X, - const int incX); -void cblas_ctbsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const int K, const void *A, const int lda, - void *X, const int incX); -void cblas_ctpsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const void *Ap, void *X, const int incX); - -void cblas_zgemv(const enum CBLAS_ORDER Order, - const enum CBLAS_TRANSPOSE TransA, const int M, const int N, - const void *alpha, const void *A, const int lda, - const void *X, const int incX, const void *beta, - void *Y, const int incY); -void cblas_zgbmv(const enum CBLAS_ORDER Order, - const enum CBLAS_TRANSPOSE TransA, const int M, const int N, - const int KL, const int KU, const void *alpha, - const void *A, const int lda, const void *X, - const int incX, const void *beta, void *Y, const int incY); -void cblas_ztrmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const void *A, const int lda, - void *X, const int incX); -void cblas_ztbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const int K, const void *A, const int lda, - void *X, const int incX); -void cblas_ztpmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const void *Ap, void *X, const int incX); -void cblas_ztrsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const void *A, const int lda, void *X, - const int incX); -void cblas_ztbsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const int K, const void *A, const int lda, - void *X, const int incX); -void cblas_ztpsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const void *Ap, void *X, const int incX); - - -/* - * Routines with S and D prefixes only - */ -void cblas_ssymv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const float alpha, const float *A, - const int lda, const float *X, const int incX, - const float beta, float *Y, const int incY); -void cblas_ssbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const int K, const float alpha, const float *A, - const int lda, const float *X, const int incX, - const float beta, float *Y, const int incY); -void cblas_sspmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const float alpha, const float *Ap, - const float *X, const int incX, - const float beta, float *Y, const int incY); -void cblas_sger(const enum CBLAS_ORDER Order, const int M, const int N, - const float alpha, const float *X, const int incX, - const float *Y, const int incY, float *A, const int lda); -void cblas_ssyr(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const float alpha, const float *X, - const int incX, float *A, const int lda); -void cblas_sspr(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const float alpha, const float *X, - const int incX, float *Ap); -void cblas_ssyr2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const float alpha, const float *X, - const int incX, const float *Y, const int incY, float *A, - const int lda); -void cblas_sspr2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const float alpha, const float *X, - const int incX, const float *Y, const int incY, float *A); - -void cblas_dsymv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const double alpha, const double *A, - const int lda, const double *X, const int incX, - const double beta, double *Y, const int incY); -void cblas_dsbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const int K, const double alpha, const double *A, - const int lda, const double *X, const int incX, - const double beta, double *Y, const int incY); -void cblas_dspmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const double alpha, const double *Ap, - const double *X, const int incX, - const double beta, double *Y, const int incY); -void cblas_dger(const enum CBLAS_ORDER Order, const int M, const int N, - const double alpha, const double *X, const int incX, - const double *Y, const int incY, double *A, const int lda); -void cblas_dsyr(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const double alpha, const double *X, - const int incX, double *A, const int lda); -void cblas_dspr(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const double alpha, const double *X, - const int incX, double *Ap); -void cblas_dsyr2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const double alpha, const double *X, - const int incX, const double *Y, const int incY, double *A, - const int lda); -void cblas_dspr2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const double alpha, const double *X, - const int incX, const double *Y, const int incY, double *A); - - -/* - * Routines with C and Z prefixes only - */ -void cblas_chemv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const void *alpha, const void *A, - const int lda, const void *X, const int incX, - const void *beta, void *Y, const int incY); -void cblas_chbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const int K, const void *alpha, const void *A, - const int lda, const void *X, const int incX, - const void *beta, void *Y, const int incY); -void cblas_chpmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const void *alpha, const void *Ap, - const void *X, const int incX, - const void *beta, void *Y, const int incY); -void cblas_cgeru(const enum CBLAS_ORDER Order, const int M, const int N, - const void *alpha, const void *X, const int incX, - const void *Y, const int incY, void *A, const int lda); -void cblas_cgerc(const enum CBLAS_ORDER Order, const int M, const int N, - const void *alpha, const void *X, const int incX, - const void *Y, const int incY, void *A, const int lda); -void cblas_cher(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const float alpha, const void *X, const int incX, - void *A, const int lda); -void cblas_chpr(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const float alpha, const void *X, - const int incX, void *A); -void cblas_cher2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N, - const void *alpha, const void *X, const int incX, - const void *Y, const int incY, void *A, const int lda); -void cblas_chpr2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N, - const void *alpha, const void *X, const int incX, - const void *Y, const int incY, void *Ap); - -void cblas_zhemv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const void *alpha, const void *A, - const int lda, const void *X, const int incX, - const void *beta, void *Y, const int incY); -void cblas_zhbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const int K, const void *alpha, const void *A, - const int lda, const void *X, const int incX, - const void *beta, void *Y, const int incY); -void cblas_zhpmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const void *alpha, const void *Ap, - const void *X, const int incX, - const void *beta, void *Y, const int incY); -void cblas_zgeru(const enum CBLAS_ORDER Order, const int M, const int N, - const void *alpha, const void *X, const int incX, - const void *Y, const int incY, void *A, const int lda); -void cblas_zgerc(const enum CBLAS_ORDER Order, const int M, const int N, - const void *alpha, const void *X, const int incX, - const void *Y, const int incY, void *A, const int lda); -void cblas_zher(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const double alpha, const void *X, const int incX, - void *A, const int lda); -void cblas_zhpr(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const double alpha, const void *X, - const int incX, void *A); -void cblas_zher2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N, - const void *alpha, const void *X, const int incX, - const void *Y, const int incY, void *A, const int lda); -void cblas_zhpr2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N, - const void *alpha, const void *X, const int incX, - const void *Y, const int incY, void *Ap); - -/* - * =========================================================================== - * Prototypes for level 3 BLAS - * =========================================================================== - */ - -/* - * Routines with standard 4 prefixes (S, D, C, Z) - */ -void cblas_sgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, - const enum CBLAS_TRANSPOSE TransB, const int M, const int N, - const int K, const float alpha, const float *A, - const int lda, const float *B, const int ldb, - const float beta, float *C, const int ldc); -void cblas_ssymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, - const enum CBLAS_UPLO Uplo, const int M, const int N, - const float alpha, const float *A, const int lda, - const float *B, const int ldb, const float beta, - float *C, const int ldc); -void cblas_ssyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE Trans, const int N, const int K, - const float alpha, const float *A, const int lda, - const float beta, float *C, const int ldc); -void cblas_ssyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE Trans, const int N, const int K, - const float alpha, const float *A, const int lda, - const float *B, const int ldb, const float beta, - float *C, const int ldc); -void cblas_strmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, - const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, - const enum CBLAS_DIAG Diag, const int M, const int N, - const float alpha, const float *A, const int lda, - float *B, const int ldb); -void cblas_strsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, - const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, - const enum CBLAS_DIAG Diag, const int M, const int N, - const float alpha, const float *A, const int lda, - float *B, const int ldb); - -void cblas_dgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, - const enum CBLAS_TRANSPOSE TransB, const int M, const int N, - const int K, const double alpha, const double *A, - const int lda, const double *B, const int ldb, - const double beta, double *C, const int ldc); -void cblas_dsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, - const enum CBLAS_UPLO Uplo, const int M, const int N, - const double alpha, const double *A, const int lda, - const double *B, const int ldb, const double beta, - double *C, const int ldc); -void cblas_dsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE Trans, const int N, const int K, - const double alpha, const double *A, const int lda, - const double beta, double *C, const int ldc); -void cblas_dsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE Trans, const int N, const int K, - const double alpha, const double *A, const int lda, - const double *B, const int ldb, const double beta, - double *C, const int ldc); -void cblas_dtrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, - const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, - const enum CBLAS_DIAG Diag, const int M, const int N, - const double alpha, const double *A, const int lda, - double *B, const int ldb); -void cblas_dtrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, - const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, - const enum CBLAS_DIAG Diag, const int M, const int N, - const double alpha, const double *A, const int lda, - double *B, const int ldb); - -void cblas_cgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, - const enum CBLAS_TRANSPOSE TransB, const int M, const int N, - const int K, const void *alpha, const void *A, - const int lda, const void *B, const int ldb, - const void *beta, void *C, const int ldc); -void cblas_csymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, - const enum CBLAS_UPLO Uplo, const int M, const int N, - const void *alpha, const void *A, const int lda, - const void *B, const int ldb, const void *beta, - void *C, const int ldc); -void cblas_csyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE Trans, const int N, const int K, - const void *alpha, const void *A, const int lda, - const void *beta, void *C, const int ldc); -void cblas_csyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE Trans, const int N, const int K, - const void *alpha, const void *A, const int lda, - const void *B, const int ldb, const void *beta, - void *C, const int ldc); -void cblas_ctrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, - const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, - const enum CBLAS_DIAG Diag, const int M, const int N, - const void *alpha, const void *A, const int lda, - void *B, const int ldb); -void cblas_ctrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, - const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, - const enum CBLAS_DIAG Diag, const int M, const int N, - const void *alpha, const void *A, const int lda, - void *B, const int ldb); - -void cblas_zgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, - const enum CBLAS_TRANSPOSE TransB, const int M, const int N, - const int K, const void *alpha, const void *A, - const int lda, const void *B, const int ldb, - const void *beta, void *C, const int ldc); -void cblas_zsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, - const enum CBLAS_UPLO Uplo, const int M, const int N, - const void *alpha, const void *A, const int lda, - const void *B, const int ldb, const void *beta, - void *C, const int ldc); -void cblas_zsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE Trans, const int N, const int K, - const void *alpha, const void *A, const int lda, - const void *beta, void *C, const int ldc); -void cblas_zsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE Trans, const int N, const int K, - const void *alpha, const void *A, const int lda, - const void *B, const int ldb, const void *beta, - void *C, const int ldc); -void cblas_ztrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, - const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, - const enum CBLAS_DIAG Diag, const int M, const int N, - const void *alpha, const void *A, const int lda, - void *B, const int ldb); -void cblas_ztrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, - const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, - const enum CBLAS_DIAG Diag, const int M, const int N, - const void *alpha, const void *A, const int lda, - void *B, const int ldb); - - -/* - * Routines with prefixes C and Z only - */ -void cblas_chemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, - const enum CBLAS_UPLO Uplo, const int M, const int N, - const void *alpha, const void *A, const int lda, - const void *B, const int ldb, const void *beta, - void *C, const int ldc); -void cblas_cherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE Trans, const int N, const int K, - const float alpha, const void *A, const int lda, - const float beta, void *C, const int ldc); -void cblas_cher2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE Trans, const int N, const int K, - const void *alpha, const void *A, const int lda, - const void *B, const int ldb, const float beta, - void *C, const int ldc); -void cblas_zhemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, - const enum CBLAS_UPLO Uplo, const int M, const int N, - const void *alpha, const void *A, const int lda, - const void *B, const int ldb, const void *beta, - void *C, const int ldc); -void cblas_zherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE Trans, const int N, const int K, - const double alpha, const void *A, const int lda, - const double beta, void *C, const int ldc); -void cblas_zher2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE Trans, const int N, const int K, - const void *alpha, const void *A, const int lda, - const void *B, const int ldb, const double beta, - void *C, const int ldc); - -int cblas_errprn(int ierr, int info, char *form, ...); - -#endif /* end #ifdef CBLAS_ENUM_ONLY */ -#endif // CAFFE2_USE_MKL -#endif diff --git a/caffe2/utils/cpu_neon.h b/caffe2/utils/cpu_neon.h deleted file mode 100644 index 7e68d73c1bef..000000000000 --- a/caffe2/utils/cpu_neon.h +++ /dev/null @@ -1,53 +0,0 @@ -#ifndef CAFFE2_UTILS_CPU_NEON_H_ -#define CAFFE2_UTILS_CPU_NEON_H_ - -// Provides a variety of ARM NEON-specific utility functions -#if defined(__ARM_NEON__) || defined(__ARM_NEON) -#include - -namespace caffe2 { - -template -inline bool isPointerAligned(T* p, size_t align) { - return (reinterpret_cast(p) % align == 0); -} - -inline float32x4_t vert_sum_f32(float32x4_t v0, - float32x4_t v1, - float32x4_t v2, - float32x4_t v3) { - v0 = vaddq_f32(v0, v1); - v2 = vaddq_f32(v2, v3); - return vaddq_f32(v0, v2); -} - -inline float horizontal_sum_f32(float32x4_t v0, - float32x4_t v1, - float32x4_t v2, - float32x4_t v3) { - v0 = vert_sum_f32(v0, v1, v2, v3); - float32x2_t v = vadd_f32(vget_high_f32(v0), vget_low_f32(v0)); - return vget_lane_f32(vpadd_f32(v, v), 0); -} - -// Load/store functions that assume alignment - -inline float32x4_t vld1q_f32_aligned(const float* p) { - return vld1q_f32((const float*) - __builtin_assume_aligned(p, sizeof(float32x4_t))); -} - -inline void vst1q_f32_aligned(float* p, float32x4_t v) { - vst1q_f32((float*) __builtin_assume_aligned(p, sizeof(float32x4_t)), v); -} - -inline void vst4_u8_aligned(uint8_t* p, uint8x8x4_t v) { - vst4_u8((uint8_t*) - __builtin_assume_aligned(p, sizeof(uint8x8x4_t)), v); -} - -} // namespace caffe2 - -#endif // defined(__ARM_NEON__) || defined(__ARM_NEON) - -#endif // CAFFE2_UTILS_CPU_NEON_H_ diff --git a/caffe2/utils/cpuid_test.cc b/caffe2/utils/cpuid_test.cc deleted file mode 100644 index f3694f5d0bac..000000000000 --- a/caffe2/utils/cpuid_test.cc +++ /dev/null @@ -1,10 +0,0 @@ -#include -#include "caffe2/utils/cpuid.h" - -namespace caffe2 { - -TEST(CpuIdTest, ShouldAlwaysHaveMMX) { - EXPECT_TRUE(GetCpuId().mmx()); -} - -} // namespace caffe2 diff --git a/caffe2/utils/cub_namespace.cuh b/caffe2/utils/cub_namespace.cuh deleted file mode 100644 index 188a9936f9c6..000000000000 --- a/caffe2/utils/cub_namespace.cuh +++ /dev/null @@ -1,17 +0,0 @@ -#pragma once - -// cub sort support for CUB_WRAPPED_NAMESPACE is added to cub 1.13.1 in: -// https://github.com/NVIDIA/cub/pull/326 -// CUB_WRAPPED_NAMESPACE is defined globally in cmake/Dependencies.cmake -// starting from CUDA 11.5 -#if defined(CUB_WRAPPED_NAMESPACE) || defined(THRUST_CUB_WRAPPED_NAMESPACE) -#define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() true -#else -#define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() false -#endif - -#if USE_GLOBAL_CUB_WRAPPED_NAMESPACE() -namespace caffe2 { -namespace cub = ::CUB_WRAPPED_NAMESPACE::cub; -} -#endif diff --git a/caffe2/utils/eigen_utils.h b/caffe2/utils/eigen_utils.h deleted file mode 100644 index c6c34dba9b5a..000000000000 --- a/caffe2/utils/eigen_utils.h +++ /dev/null @@ -1,205 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#ifndef CAFFE2_OPERATORS_UTILS_EIGEN_H_ -#define CAFFE2_OPERATORS_UTILS_EIGEN_H_ - -#include "Eigen/Core" -#include "Eigen/Dense" - -#include -#include - -namespace caffe2 { - -// Common Eigen types that we will often use -template -using EigenMatrixMap = - Eigen::Map>; -template -using EigenArrayMap = - Eigen::Map>; -template -using EigenVectorMap = Eigen::Map>; -template -using EigenVectorArrayMap = Eigen::Map>; -template -using ConstEigenMatrixMap = - Eigen::Map>; -template -using ConstEigenArrayMap = - Eigen::Map>; -template -using ConstEigenVectorMap = - Eigen::Map>; -template -using ConstEigenVectorArrayMap = - Eigen::Map>; - -using EigenOuterStride = Eigen::OuterStride; -using EigenInnerStride = Eigen::InnerStride; -using EigenStride = Eigen::Stride; -template -using EigenOuterStridedMatrixMap = Eigen:: - Map, 0, EigenOuterStride>; -template -using EigenOuterStridedArrayMap = Eigen:: - Map, 0, EigenOuterStride>; -template -using ConstEigenOuterStridedMatrixMap = Eigen::Map< - const Eigen::Matrix, - 0, - EigenOuterStride>; -template -using ConstEigenOuterStridedArrayMap = Eigen::Map< - const Eigen::Array, - 0, - EigenOuterStride>; -template -using EigenStridedMatrixMap = Eigen:: - Map, 0, EigenStride>; -template -using EigenStridedArrayMap = - Eigen::Map, 0, EigenStride>; -template -using ConstEigenStridedMatrixMap = Eigen:: - Map, 0, EigenStride>; -template -using ConstEigenStridedArrayMap = Eigen:: - Map, 0, EigenStride>; - -// 1-d array -template -using EArrXt = Eigen::Array; -using EArrXf = Eigen::ArrayXf; -using EArrXd = Eigen::ArrayXd; -using EArrXi = Eigen::ArrayXi; -using EArrXb = EArrXt; -using EArrXI32 = EArrXt; -using EArrXU16 = EArrXt; -using EArrXU8 = EArrXt; -using EArr3U8 = Eigen::Array; - -// 2-d array, column major -template -using EArrXXt = Eigen::Array; -using EArrXXf = Eigen::ArrayXXf; -using EArrXXI32 = EArrXXt; -using EArrXXU16 = EArrXXt; -using EArrXXU8 = EArrXXt; -using EArrXXi = EArrXXt; - -// 2-d array, row major -template -using ERArrXXt = - Eigen::Array; -using ERArrXXf = ERArrXXt; -using ERArrXXI32t = ERArrXXt; -using ERArrXXU16t = ERArrXXt; -using ERArrXXU8t = ERArrXXt; -using ERArrXXi = ERArrXXt; -using ERArrXXi64t = ERArrXXt; -using ERArrXXi32t = ERArrXXt; - -// 1-d vector -template -using EVecXt = Eigen::Matrix; -using EVecXd = Eigen::VectorXd; -using EVecXf = Eigen::VectorXf; - -// 1-d row vector -using ERVecXd = Eigen::RowVectorXd; -using ERVecXf = Eigen::RowVectorXf; - -// 2-d matrix, column major -template -using EMatXt = Eigen::Matrix; -using EMatXd = Eigen::MatrixXd; -using EMatXf = Eigen::MatrixXf; -using EMatXU8 = EMatXt; -using EMatXU16 = EMatXt; - -// 2-d matrix, row major -template -using ERMatXt = - Eigen::Matrix; -using ERMatXd = ERMatXt; -using ERMatXf = ERMatXt; -using ERMatXU8 = ERMatXt; - -namespace utils { - -template -Eigen::Map> AsEArrXt(const std::vector& arr) { - return {arr.data(), static_cast(arr.size())}; -} -template -Eigen::Map> AsEArrXt(std::vector& arr) { - return {arr.data(), static_cast(arr.size())}; -} - -// return a sub array of 'array' based on indices 'indices' -template -void GetSubArray( - const Eigen::ArrayBase& array, - const Eigen::ArrayBase& indices, - Eigen::ArrayBase* out_array) { - CAFFE_ENFORCE_EQ(array.cols(), 1); - // using T = typename Derived::Scalar; - - out_array->derived().resize(indices.size()); - for (const auto i : c10::irange(indices.size())) { - TORCH_DCHECK_LT(indices[i], array.size()); - (*out_array)[i] = array[indices[i]]; - } -} - -// return a sub array of 'array' based on indices 'indices' -template -EArrXt GetSubArray( - const Eigen::ArrayBase& array, - const Eigen::ArrayBase& indices) { - using T = typename Derived::Scalar; - EArrXt ret(indices.size()); - GetSubArray(array, indices, &ret); - return ret; -} - -// return a sub array of 'array' based on indices 'indices' -template -EArrXt GetSubArray( - const Eigen::ArrayBase& array, - const std::vector& indices) { - return GetSubArray(array, AsEArrXt(indices)); -} - -// return 2d sub array of 'array' based on row indices 'row_indices' -template -void GetSubArrayRows( - const Eigen::ArrayBase& array2d, - const Eigen::ArrayBase& row_indices, - Eigen::ArrayBase* out_array) { - out_array->derived().resize(row_indices.size(), array2d.cols()); - - for (const auto i : c10::irange(row_indices.size())) { - TORCH_DCHECK_LT(row_indices[i], array2d.size()); - out_array->row(i) = - array2d.row(row_indices[i]).template cast(); - } -} - -// return indices of 1d array for elements evaluated to true -template -std::vector GetArrayIndices(const Eigen::ArrayBase& array) { - std::vector ret; - for (const auto i : c10::irange(array.size())) { - if (array[i]) { - ret.push_back(i); - } - } - return ret; -} - -} // namespace utils -} // namespace caffe2 - -#endif diff --git a/caffe2/utils/fatal_signal_asan_no_sig_test.cc b/caffe2/utils/fatal_signal_asan_no_sig_test.cc deleted file mode 100644 index 9c64102981c3..000000000000 --- a/caffe2/utils/fatal_signal_asan_no_sig_test.cc +++ /dev/null @@ -1,148 +0,0 @@ -#include "caffe2/utils/signal_handler.h" -#if defined(C10_SUPPORTS_FATAL_SIGNAL_HANDLERS) -#include -#include -#include - -#include -#include -#include - -#include "caffe2/core/common.h" - -namespace { -void* dummy_thread(void*) { - while (1) { - } - return nullptr; -} - -bool forkAndPipe( - std::string& stderrBuffer, - std::function callback) { - std::array stderrPipe; - if (pipe(stderrPipe.data()) != 0) { - perror("STDERR pipe"); - return false; - } - pid_t child = fork(); - if (child == 0) { - // Replace this process' stderr so we can read it. - if (dup2(stderrPipe[1], STDERR_FILENO) < 0) { - close(stderrPipe[0]); - close(stderrPipe[1]); - perror("dup2 STDERR"); - exit(5); - } - - // This is for the parent to work with. - close(stderrPipe[0]); - close(stderrPipe[1]); - - callback(); - exit(7); - } else if (child > 0) { - const int bufferSize = 128; - std::array buffer; - - // We want to close the writing end of the pipe right away so our - // read actually gets an EOF. - close(stderrPipe[1]); - - // wait for child to finish crashing. - int statloc; - if (wait(&statloc) < 0) { - close(stderrPipe[0]); - perror("wait"); - return false; - } - - ssize_t bytesRead; - while ((bytesRead = read(stderrPipe[0], buffer.data(), bufferSize)) > 0) { - const std::string tmp(buffer.data(), bytesRead); - std::cout << tmp; - stderrBuffer += tmp; - } - - // The child should have exited due to signal. - if (!WIFSIGNALED(statloc)) { - fprintf(stderr, "Child didn't exit because it received a signal\n"); - if (WIFEXITED(statloc)) { - fprintf(stderr, "Exited with code: %d\n", WEXITSTATUS(statloc) & 0xff); - } - return false; - } - - if (bytesRead < 0) { - perror("read"); - return false; - } - - close(stderrPipe[0]); - return true; - } else { - perror("fork"); - return false; - } -} -} // namespace - -#define _TEST_FATAL_SIGNAL(signum, name, threadCount, print, expected) \ - do { \ - std::string stderrBuffer; \ - ASSERT_TRUE(forkAndPipe(stderrBuffer, [=]() { \ - caffe2::setPrintStackTracesOnFatalSignal(print); \ - pthread_t pt; \ - for (int i = 0; i < threadCount; i++) { \ - if (pthread_create(&pt, nullptr, ::dummy_thread, nullptr)) { \ - perror("pthread_create"); \ - } \ - } \ - raise(signum); \ - })); \ - int keyPhraseCount = 0; \ - std::string keyPhrase = \ - std::string(name) + "(" + c10::to_string(signum) + ")"; \ - size_t loc = 0; \ - while ((loc = stderrBuffer.find(keyPhrase, loc)) != std::string::npos) { \ - keyPhraseCount += 1; \ - loc += 1; \ - } \ - EXPECT_GE(keyPhraseCount, expected); \ - } while (0) - -#define TEST_FATAL_SIGNAL(signum, name, threadCount) \ - _TEST_FATAL_SIGNAL(signum, name, threadCount, true, threadCount + 1) - -#define TEST_FATAL_SIGNAL_NO_PRINT(signum, name, threadCount) \ - _TEST_FATAL_SIGNAL(signum, name, threadCount, false, 0) - -TEST(fatalSignalTest, SIGABRT8) { - TEST_FATAL_SIGNAL(SIGABRT, "SIGABRT", 8); -} - -TEST(fatalSignalTest, SIGINT8) { - TEST_FATAL_SIGNAL(SIGINT, "SIGINT", 8); -} - -TEST(fatalSignalTest, SIGILL8) { - TEST_FATAL_SIGNAL(SIGILL, "SIGILL", 8); -} - -TEST(fatalSignalTest, SIGFPE8) { - TEST_FATAL_SIGNAL(SIGFPE, "SIGFPE", 8); -} - -TEST(fatalSignalTest, SIGBUS8) { - TEST_FATAL_SIGNAL(SIGBUS, "SIGBUS", 8); -} - -TEST(fatalSignalTest, SIGSEGV8) { - TEST_FATAL_SIGNAL(SIGSEGV, "SIGSEGV", 8); -} - -// Test that if we don't enable printing stack traces then we don't get any. -TEST(fatalSignalTest, SIGABRT8_NOPRINT) { - TEST_FATAL_SIGNAL_NO_PRINT(SIGABRT, "SIGABRT", 8); -} -#endif // defined(C10_SUPPORTS_FATAL_SIGNAL_HANDLERS) diff --git a/caffe2/utils/filler.h b/caffe2/utils/filler.h deleted file mode 100644 index 3d0e399ba73b..000000000000 --- a/caffe2/utils/filler.h +++ /dev/null @@ -1,140 +0,0 @@ -#ifndef CAFFE2_FILLER_H_ -#define CAFFE2_FILLER_H_ - -#include - -#include "caffe2/core/logging.h" -#include "caffe2/core/tensor.h" -#include "caffe2/utils/math.h" - -namespace caffe2 { - -// TODO: replace filler distribution enum with a better abstraction -enum FillerDistribution { FD_UNIFORM, FD_FIXEDSUM, FD_SYNTHETIC }; - -class TensorFiller { - public: - template - void Fill(Tensor* tensor, Context* context) const { - CAFFE_ENFORCE(context, "context is null"); - CAFFE_ENFORCE(tensor, "tensor is null"); - auto min = (min_ < (double)std::numeric_limits::min()) - ? std::numeric_limits::min() - : static_cast(min_); - auto max = (max_ > (double)std::numeric_limits::max()) - ? std::numeric_limits::max() - : static_cast(max_); - CAFFE_ENFORCE_LE(min, max); - - Tensor temp_tensor(shape_, Context::GetDeviceType()); - std::swap(*tensor, temp_tensor); - Type* data = tensor->template mutable_data(); - - // select distribution - switch (dist_) { - case FD_UNIFORM: { - math::RandUniform( - tensor->numel(), min, max, data, context); - break; - } - case FD_FIXEDSUM: { - auto fixed_sum = static_cast(fixed_sum_); - CAFFE_ENFORCE_LE(min * tensor->numel(), fixed_sum); - CAFFE_ENFORCE_GE(max * tensor->numel(), fixed_sum); - math::RandFixedSum( - tensor->numel(), min, max, fixed_sum_, data, context); - break; - } - case FD_SYNTHETIC: { - math::RandSyntheticData( - tensor->numel(), min, max, data, context); - break; - } - } - } - - TensorFiller& Dist(FillerDistribution dist) { - dist_ = dist; - return *this; - } - - template - TensorFiller& Min(Type min) { - min_ = (double)min; - return *this; - } - - template - TensorFiller& Max(Type max) { - max_ = (double)max; - return *this; - } - - template - TensorFiller& FixedSum(Type fixed_sum) { - dist_ = FD_FIXEDSUM; - fixed_sum_ = (double)fixed_sum; - return *this; - } - - // A helper function to construct the lengths vector for sparse features - // We try to pad least one index per batch unless the total_length is 0 - template - TensorFiller& SparseLengths(Type total_length) { - return FixedSum(total_length) - .Min(std::min(static_cast(1), total_length)) - .Max(total_length); - } - - // a helper function to construct the segments vector for sparse features - template - TensorFiller& SparseSegments(Type max_segment) { - CAFFE_ENFORCE(dist_ != FD_FIXEDSUM); - return Min(0).Max(max_segment).Dist(FD_SYNTHETIC); - } - - TensorFiller& Shape(const std::vector& shape) { - shape_ = shape; - return *this; - } - - template - TensorFiller(const std::vector& shape, Type fixed_sum) - : shape_(shape), dist_(FD_FIXEDSUM), fixed_sum_((double)fixed_sum) {} - - TensorFiller(const std::vector& shape) - : shape_(shape), dist_(FD_UNIFORM), fixed_sum_(0) {} - - TensorFiller() : TensorFiller(std::vector()) {} - - std::string DebugString() const { - std::stringstream stream; - stream << "shape = [" << shape_ << "]; min = " << min_ - << "; max = " << max_; - switch (dist_) { - case FD_FIXEDSUM: - stream << "; dist = FD_FIXEDSUM"; - break; - case FD_SYNTHETIC: - stream << "; dist = FD_SYNTHETIC"; - break; - default: - stream << "; dist = FD_UNIFORM"; - break; - } - return stream.str(); - } - - private: - std::vector shape_; - // TODO: type is unknown until a user starts to fill data; - // cast everything to double for now. - double min_ = 0.0; - double max_ = 1.0; - FillerDistribution dist_; - double fixed_sum_; -}; - -} // namespace caffe2 - -#endif // CAFFE2_FILLER_H_ diff --git a/caffe2/utils/fixed_divisor_test.cc b/caffe2/utils/fixed_divisor_test.cc deleted file mode 100644 index 6093bc764c39..000000000000 --- a/caffe2/utils/fixed_divisor_test.cc +++ /dev/null @@ -1,80 +0,0 @@ -#include "caffe2/utils/fixed_divisor.h" - -#include - -#include - -namespace caffe2 { - -namespace { - -void CompareDivMod(int32_t v, int32_t divisor) { - auto fixed = FixedDivisor(divisor); - - int native_q = v / divisor; - int native_r = v % divisor; - - int fixed_q = fixed.Div(v); - int fixed_r = fixed.Mod(v); - -#if !defined(USE_ROCM) - EXPECT_EQ(native_q, fixed_q) - << v << " / " << divisor << " magic " << fixed.magic() << " shift " - << fixed.shift() << " quot " << fixed_q << " " << native_q; - - EXPECT_EQ(native_r, fixed_r) - << v << " / " << divisor << " magic " << fixed.magic() << " shift " - << fixed.shift() << " rem " << fixed_r << " " << native_r; -#endif -} - -} // namespace - -TEST(FixedDivisorTest, FixedDivisorInt32Test) { - constexpr int32_t kMax = std::numeric_limits::max(); - - // divide by 1 - CompareDivMod(kMax, 1); - CompareDivMod(0, 1); - CompareDivMod(1, 1); - - // divide by max - CompareDivMod(kMax, kMax); - CompareDivMod(0, kMax); - CompareDivMod(1, kMax); - - // divide by random positive values - std::random_device rd; - std::uniform_int_distribution v_dist(0, kMax); - std::uniform_int_distribution q_dist(1, kMax); - - std::uniform_int_distribution v_small_dist(0, 1000); - std::uniform_int_distribution q_small_dist(1, 1000); - for (int i = 0; i < 10000; ++i) { - auto q = q_dist(rd); - auto v = v_dist(rd); - auto q_small = q_small_dist(rd); - auto v_small = v_small_dist(rd); - - // random value - CompareDivMod(v_small, q_small); - CompareDivMod(v_small, q); - CompareDivMod(v, q_small); - CompareDivMod(v, q); - - // special values - CompareDivMod(kMax, q_small); - CompareDivMod(0, q_small); - CompareDivMod(1, q_small); - CompareDivMod(kMax, q); - CompareDivMod(0, q); - CompareDivMod(1, q); - - CompareDivMod(v_small, 1); - CompareDivMod(v_small, kMax); - CompareDivMod(v, 1); - CompareDivMod(v, kMax); - } -} - -} // namespace caffe2 diff --git a/caffe2/utils/knob_patcher.cc b/caffe2/utils/knob_patcher.cc deleted file mode 100644 index e099ea61dd87..000000000000 --- a/caffe2/utils/knob_patcher.cc +++ /dev/null @@ -1,137 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and its affiliates. -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. - -#include - -#include -#include -#include - -#include "caffe2/utils/knobs.h" -#include "caffe2/utils/knob_patcher.h" - -namespace caffe2 { -namespace detail { -std::map& getRegisteredKnobs(); -} // namespace detail - -namespace { -class PatchNode { - public: - PatchNode(c10::string_view name, bool value); - ~PatchNode(); - - std::string name; - bool oldValue{false}; - // Nodes to form a linked list of existing PatchState objects for this knob. - // This allows us to restore state correctly even if KnobPatcher objects - // are destroyed in any arbitrary order. - PatchNode* prev{nullptr}; - PatchNode* next{nullptr}; -}; -} // namespace - -class KnobPatcher::PatchState : public PatchNode { - using PatchNode::PatchNode; -}; - -KnobPatcher::KnobPatcher(c10::string_view name, bool value) - : state_{std::make_unique(name, value)} {} - -KnobPatcher::~KnobPatcher() = default; -KnobPatcher::KnobPatcher(KnobPatcher&&) noexcept = default; -KnobPatcher& KnobPatcher::operator=(KnobPatcher&&) noexcept = default; - -namespace { - -class Patcher { - public: - void patch(PatchNode* node, bool value) { - std::lock_guard lock{mutex_}; - - node->oldValue = setKnobValue(node->name, value); - auto ret = patches_.emplace(node->name, node); - if (!ret.second) { - // There was already another patcher for this knob - // Append the new node to the linked list. - node->prev = ret.first->second; - CHECK(!node->prev->next); - node->prev->next = node; - ret.first->second = node; - } - } - - void unpatch(PatchNode* node) { - std::lock_guard lock{mutex_}; - - // Remove this PatchNode from the linked list - if (node->prev) { - node->prev->next = node->next; - } - if (node->next) { - // There was another patch applied after this one. - node->next->prev = node->prev; - node->next->oldValue = node->oldValue; - } else { - // This was the most recently applied patch for this knob, - // so restore the knob value. - setKnobValue(node->name, node->oldValue); - - // The patches_ map should point to this node. - // Update it to point to the previous patch, if there is one. - auto iter = patches_.find(node->name); - if (iter == patches_.end()) { - LOG(FATAL) << "patch node not found when unpatching knob value"; - } - TORCH_CHECK_EQ(iter->second, node); - if (node->prev) { - iter->second = node->prev; - } else { - patches_.erase(iter); - } - } - } - - private: - bool setKnobValue(c10::string_view name, bool value) { - auto& knobs = caffe2::detail::getRegisteredKnobs(); - auto iter = knobs.find(name); - if (iter == knobs.end()) { - throw std::invalid_argument( - "attempted to patch unknown knob \"" + std::string(name) + "\""); - } - bool oldValue = *(iter->second); - *iter->second = value; - return oldValue; - } - - std::mutex mutex_; - std::map patches_; -}; - -Patcher& getPatcher() { - static Patcher patcher; - return patcher; -} - -PatchNode::PatchNode(c10::string_view knobName, bool value) - : name{knobName} { - getPatcher().patch(this, value); -} - -PatchNode::~PatchNode() { - try { - getPatcher().unpatch(this); - } catch (const std::exception& ex) { - // This shouldn't ever happen unless we have a programming bug, but it keeps - // clang-tidy happy if we put a catch block here to handle the theoretical - // error if unpatch() calls setKnobValue() and it throws due to not finding - // the knob by name. - LOG(FATAL) << "error removing knob patch: " << ex.what(); - } -} - -} // namespace -} // namespace caffe2 diff --git a/caffe2/utils/knob_patcher.h b/caffe2/utils/knob_patcher.h deleted file mode 100644 index ec2b6277760d..000000000000 --- a/caffe2/utils/knob_patcher.h +++ /dev/null @@ -1,32 +0,0 @@ -#pragma once - -#include - -#include - -namespace caffe2 { - -/** - * Patch the value of a knob during a unit test. - * - * This forces the knob to the specified value for as long as the KnobPatcher - * object exists. When the KnobPatcher object is destroyed the knob will revert - * to its previous value. - */ -class KnobPatcher { - public: - KnobPatcher(c10::string_view name, bool value); - ~KnobPatcher(); - - KnobPatcher(KnobPatcher&&) noexcept; - KnobPatcher& operator=(KnobPatcher&&) noexcept; - KnobPatcher(const KnobPatcher&) = delete; - KnobPatcher& operator=(const KnobPatcher&) = delete; - - private: - class PatchState; - - std::unique_ptr state_; -}; - -} // namespace caffe2 diff --git a/caffe2/utils/knobs.cc b/caffe2/utils/knobs.cc deleted file mode 100644 index 63941a573edf..000000000000 --- a/caffe2/utils/knobs.cc +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// -// This is a very basic knob implementation that purely uses command line flags. -// This can be replaced with a more sophisticated implementation for use in -// other production environments. - -#include - -#include -#include - -#include "caffe2/utils/knobs.h" - -namespace caffe2 { - -namespace detail { -// Get the map of knob names to pointers to their command-line controlled -// boolean value. -std::map& getRegisteredKnobs() { - // It's safe to store the keys as string_view, since DEFINE_KNOB() ensures - // that these views always point to string literals. - static std::map registeredKnobs; - return registeredKnobs; -} -} // namespace detail - -bool CheckKnob(c10::string_view name) { - const auto& knobs = detail::getRegisteredKnobs(); - auto iter = knobs.find(name); - if (iter == knobs.end()) { - throw std::invalid_argument( - "attempted to check unknown knob \"" + std::string(name) + "\""); - } - return *iter->second; -} - -namespace { -class RegisterKnob { - public: - RegisterKnob(c10::string_view name, bool* cmdlineFlag) { - auto ret = caffe2::detail::getRegisteredKnobs().emplace(name, cmdlineFlag); - if (!ret.second) { - throw std::runtime_error("duplicate knob name: " + std::string(name)); - } - } -}; -} // namespace -} // namespace caffe2 - -/** - * Define a knob. - * - * This will define a --caffe2_knob_ command line flag to control the - * knob. - * - * The knob can be checked in code by calling CheckKnob(name) - * or CheckKnob() - */ -#define DEFINE_KNOB(name, check_fn_name, default_value, docstring) \ - C10_DEFINE_bool(caffe2_knob_##name, default_value, docstring); \ - namespace caffe2 { \ - bool CheckKnob##check_fn_name() { \ - return FLAGS_caffe2_knob_##name; \ - } \ - } \ - static caffe2::RegisterKnob _knob_##name(#name, &FLAGS_caffe2_knob_##name) - -/* - * Definitions of well-known knobs. - */ - -DEFINE_KNOB( - example_knob, - ExampleKnob, - false, - "An example knob, mainly intended for use in unit tests"); diff --git a/caffe2/utils/knobs.h b/caffe2/utils/knobs.h deleted file mode 100644 index fbebd90cf741..000000000000 --- a/caffe2/utils/knobs.h +++ /dev/null @@ -1,26 +0,0 @@ -#pragma once - -// This file contains functions for checking rollout knobs to enable staged -// roll out of specific code functionality. - -#include - -#include - -namespace caffe2 { - -/** - * Check an arbitrary knob by name. - */ -bool CheckKnob(c10::string_view name); - -/* - * The following are functions for checking specific known knob values. - * - * These APIs are more efficient than checking by name. - */ - -// An example knob, just for use in unit tests. -bool CheckKnobExampleKnob(); - -} // namespace caffe2 diff --git a/caffe2/utils/knobs_test.cc b/caffe2/utils/knobs_test.cc deleted file mode 100644 index 95f29cff2471..000000000000 --- a/caffe2/utils/knobs_test.cc +++ /dev/null @@ -1,34 +0,0 @@ -#include - -#include "caffe2/utils/knobs.h" -#include "caffe2/utils/knob_patcher.h" - -using namespace caffe2; - -TEST(KnobsTest, TestKnob) { - auto p = KnobPatcher("example_knob", false); - EXPECT_FALSE(CheckKnobExampleKnob()); - EXPECT_FALSE(CheckKnob("example_knob")); - - p = KnobPatcher("example_knob", true); - EXPECT_TRUE(CheckKnobExampleKnob()); - EXPECT_TRUE(CheckKnob("example_knob")); - - // Test nested patchers - { - auto p2 = KnobPatcher("example_knob", false); - EXPECT_FALSE(CheckKnobExampleKnob()); - EXPECT_FALSE(CheckKnob("example_knob")); - - auto p3 = KnobPatcher("example_knob", true); - EXPECT_TRUE(CheckKnobExampleKnob()); - EXPECT_TRUE(CheckKnob("example_knob")); - } - EXPECT_TRUE(CheckKnobExampleKnob()); - EXPECT_TRUE(CheckKnob("example_knob")); -} - -TEST(KnobsTest, TestUnknownKnob) { - // Unknown knob names should throw an exception - EXPECT_THROW(CheckKnob("this_knob_does_not_exist"), std::exception); -} diff --git a/caffe2/utils/map_utils.h b/caffe2/utils/map_utils.h deleted file mode 100644 index ef8ff0cab707..000000000000 --- a/caffe2/utils/map_utils.h +++ /dev/null @@ -1,19 +0,0 @@ -#pragma once - -namespace caffe2 { - -// Get value from map given key. Return supplied default value if not found -// This is a stripped down version from folly: -// https://github.com/facebook/folly/blob/5a07e203d79324b68d69f294fa38e43b9671e9b1/folly/MapUtil.h#L35-L45 -template < - class Map, - typename Key = typename Map::key_type, - typename Value = typename Map::mapped_type> -typename Map::mapped_type -get_default(const Map& map, const Key& key, Value&& dflt) { - using M = typename Map::mapped_type; - auto pos = map.find(key); - return (pos != map.end()) ? (pos->second) : M(std::forward(dflt)); -} - -} // namespace caffe2 diff --git a/caffe2/utils/murmur_hash3.cc b/caffe2/utils/murmur_hash3.cc deleted file mode 100644 index 68cce1fdd34e..000000000000 --- a/caffe2/utils/murmur_hash3.cc +++ /dev/null @@ -1,450 +0,0 @@ -//----------------------------------------------------------------------------- -// MurmurHash3 was written by Austin Appleby, and is placed in the public -// domain. The author hereby disclaims copyright to this source code. - -// Note - The x86 and x64 versions do _not_ produce the same results, as the -// algorithms are optimized for their respective platforms. You can still -// compile and run any of them on any platform, but your performance with the -// non-native version will be less than optimal. - -#include "caffe2/utils/murmur_hash3.h" - -//----------------------------------------------------------------------------- -// Platform-specific functions and macros - -// Microsoft Visual Studio - -#if defined(_MSC_VER) - -#define FORCE_INLINE __forceinline - -#include - -#define ROTL32(x, y) _rotl(x, y) -#define ROTL64(x, y) _rotl64(x, y) - -#define BIG_CONSTANT(x) (x) - -// Other compilers - -#else // defined(_MSC_VER) - -#define FORCE_INLINE inline __attribute__((__always_inline__)) - -inline uint32_t rotl32(uint32_t x, int8_t r) { - return (x << r) | (x >> (32 - r)); -} - -inline uint64_t rotl64(uint64_t x, int8_t r) { - return (x << r) | (x >> (64 - r)); -} - -#define ROTL32(x, y) rotl32(x, y) -#define ROTL64(x, y) rotl64(x, y) - -#define BIG_CONSTANT(x) (x##LLU) - -#endif // !defined(_MSC_VER) - -//----------------------------------------------------------------------------- -// Block read - if your platform needs to do endian-swapping or can only -// handle aligned reads, do the conversion here - -FORCE_INLINE uint32_t getblock32(const uint32_t* p, int i) { - return p[i]; -} - -FORCE_INLINE uint64_t getblock64(const uint64_t* p, int i) { - return p[i]; -} - -//----------------------------------------------------------------------------- -// Finalization mix - force all bits of a hash block to avalanche - -FORCE_INLINE uint32_t fmix32(uint32_t h) { - h ^= h >> 16; - h *= 0x85ebca6b; - h ^= h >> 13; - h *= 0xc2b2ae35; - h ^= h >> 16; - - return h; -} - -//---------- - -FORCE_INLINE uint64_t fmix64(uint64_t k) { - k ^= k >> 33; - k *= BIG_CONSTANT(0xff51afd7ed558ccd); - k ^= k >> 33; - k *= BIG_CONSTANT(0xc4ceb9fe1a85ec53); - k ^= k >> 33; - - return k; -} - -namespace caffe2 { - -void MurmurHash3_x86_32(const void* key, int len, uint32_t seed, void* out) { - const uint8_t* data = (const uint8_t*)key; - const int nblocks = len / 4; - - uint32_t h1 = seed; - - const uint32_t c1 = 0xcc9e2d51; - const uint32_t c2 = 0x1b873593; - - //---------- - // body - - const uint32_t* blocks = (const uint32_t*)(data + nblocks * 4); - - for (int i = -nblocks; i; i++) { - uint32_t k1 = getblock32(blocks, i); - - k1 *= c1; - k1 = ROTL32(k1, 15); - k1 *= c2; - - h1 ^= k1; - h1 = ROTL32(h1, 13); - h1 = h1 * 5 + 0xe6546b64; - } - - //---------- - // tail - - const uint8_t* tail = (const uint8_t*)(data + nblocks * 4); - - uint32_t k1 = 0; - - switch (len & 3) { - case 3: - k1 ^= tail[2] << 16; - [[fallthrough]]; - case 2: - k1 ^= tail[1] << 8; - [[fallthrough]]; - case 1: - k1 ^= tail[0]; - k1 *= c1; - k1 = ROTL32(k1, 15); - k1 *= c2; - h1 ^= k1; - }; - - //---------- - // finalization - - h1 ^= len; - - h1 = fmix32(h1); - - *(uint32_t*)out = h1; -} - -//----------------------------------------------------------------------------- - -void MurmurHash3_x86_128( - const void* key, - const int len, - uint32_t seed, - void* out) { - const uint8_t* data = (const uint8_t*)key; - const int nblocks = len / 16; - - uint32_t h1 = seed; - uint32_t h2 = seed; - uint32_t h3 = seed; - uint32_t h4 = seed; - - const uint32_t c1 = 0x239b961b; - const uint32_t c2 = 0xab0e9789; - const uint32_t c3 = 0x38b34ae5; - const uint32_t c4 = 0xa1e38b93; - - //---------- - // body - - const uint32_t* blocks = (const uint32_t*)(data + nblocks * 16); - - for (int i = -nblocks; i; i++) { - uint32_t k1 = getblock32(blocks, i * 4 + 0); - uint32_t k2 = getblock32(blocks, i * 4 + 1); - uint32_t k3 = getblock32(blocks, i * 4 + 2); - uint32_t k4 = getblock32(blocks, i * 4 + 3); - - k1 *= c1; - k1 = ROTL32(k1, 15); - k1 *= c2; - h1 ^= k1; - - h1 = ROTL32(h1, 19); - h1 += h2; - h1 = h1 * 5 + 0x561ccd1b; - - k2 *= c2; - k2 = ROTL32(k2, 16); - k2 *= c3; - h2 ^= k2; - - h2 = ROTL32(h2, 17); - h2 += h3; - h2 = h2 * 5 + 0x0bcaa747; - - k3 *= c3; - k3 = ROTL32(k3, 17); - k3 *= c4; - h3 ^= k3; - - h3 = ROTL32(h3, 15); - h3 += h4; - h3 = h3 * 5 + 0x96cd1c35; - - k4 *= c4; - k4 = ROTL32(k4, 18); - k4 *= c1; - h4 ^= k4; - - h4 = ROTL32(h4, 13); - h4 += h1; - h4 = h4 * 5 + 0x32ac3b17; - } - - //---------- - // tail - - const uint8_t* tail = (const uint8_t*)(data + nblocks * 16); - - uint32_t k1 = 0; - uint32_t k2 = 0; - uint32_t k3 = 0; - uint32_t k4 = 0; - - switch (len & 15) { - case 15: - k4 ^= tail[14] << 16; - [[fallthrough]]; - case 14: - k4 ^= tail[13] << 8; - [[fallthrough]]; - case 13: - k4 ^= tail[12] << 0; - k4 *= c4; - k4 = ROTL32(k4, 18); - k4 *= c1; - h4 ^= k4; - [[fallthrough]]; - - case 12: - k3 ^= tail[11] << 24; - [[fallthrough]]; - case 11: - k3 ^= tail[10] << 16; - [[fallthrough]]; - case 10: - k3 ^= tail[9] << 8; - [[fallthrough]]; - case 9: - k3 ^= tail[8] << 0; - k3 *= c3; - k3 = ROTL32(k3, 17); - k3 *= c4; - h3 ^= k3; - [[fallthrough]]; - - case 8: - k2 ^= tail[7] << 24; - [[fallthrough]]; - case 7: - k2 ^= tail[6] << 16; - [[fallthrough]]; - case 6: - k2 ^= tail[5] << 8; - [[fallthrough]]; - case 5: - k2 ^= tail[4] << 0; - k2 *= c2; - k2 = ROTL32(k2, 16); - k2 *= c3; - h2 ^= k2; - [[fallthrough]]; - - case 4: - k1 ^= tail[3] << 24; - [[fallthrough]]; - case 3: - k1 ^= tail[2] << 16; - [[fallthrough]]; - case 2: - k1 ^= tail[1] << 8; - [[fallthrough]]; - case 1: - k1 ^= tail[0] << 0; - k1 *= c1; - k1 = ROTL32(k1, 15); - k1 *= c2; - h1 ^= k1; - }; - - //---------- - // finalization - - h1 ^= len; - h2 ^= len; - h3 ^= len; - h4 ^= len; - - h1 += h2; - h1 += h3; - h1 += h4; - h2 += h1; - h3 += h1; - h4 += h1; - - h1 = fmix32(h1); - h2 = fmix32(h2); - h3 = fmix32(h3); - h4 = fmix32(h4); - - h1 += h2; - h1 += h3; - h1 += h4; - h2 += h1; - h3 += h1; - h4 += h1; - - ((uint32_t*)out)[0] = h1; - ((uint32_t*)out)[1] = h2; - ((uint32_t*)out)[2] = h3; - ((uint32_t*)out)[3] = h4; -} - -//----------------------------------------------------------------------------- - -void MurmurHash3_x64_128( - const void* key, - const int len, - const uint32_t seed, - void* out) { - const uint8_t* data = (const uint8_t*)key; - const int nblocks = len / 16; - - uint64_t h1 = seed; - uint64_t h2 = seed; - - const uint64_t c1 = BIG_CONSTANT(0x87c37b91114253d5); - const uint64_t c2 = BIG_CONSTANT(0x4cf5ad432745937f); - - //---------- - // body - - const uint64_t* blocks = (const uint64_t*)(data); - - for (int i = 0; i < nblocks; i++) { - uint64_t k1 = getblock64(blocks, i * 2 + 0); - uint64_t k2 = getblock64(blocks, i * 2 + 1); - - k1 *= c1; - k1 = ROTL64(k1, 31); - k1 *= c2; - h1 ^= k1; - - h1 = ROTL64(h1, 27); - h1 += h2; - h1 = h1 * 5 + 0x52dce729; - - k2 *= c2; - k2 = ROTL64(k2, 33); - k2 *= c1; - h2 ^= k2; - - h2 = ROTL64(h2, 31); - h2 += h1; - h2 = h2 * 5 + 0x38495ab5; - } - - //---------- - // tail - - const uint8_t* tail = (const uint8_t*)(data + nblocks * 16); - - uint64_t k1 = 0; - uint64_t k2 = 0; - - switch (len & 15) { - case 15: - k2 ^= ((uint64_t)tail[14]) << 48; - [[fallthrough]]; - case 14: - k2 ^= ((uint64_t)tail[13]) << 40; - [[fallthrough]]; - case 13: - k2 ^= ((uint64_t)tail[12]) << 32; - [[fallthrough]]; - case 12: - k2 ^= ((uint64_t)tail[11]) << 24; - [[fallthrough]]; - case 11: - k2 ^= ((uint64_t)tail[10]) << 16; - [[fallthrough]]; - case 10: - k2 ^= ((uint64_t)tail[9]) << 8; - [[fallthrough]]; - case 9: - k2 ^= ((uint64_t)tail[8]) << 0; - k2 *= c2; - k2 = ROTL64(k2, 33); - k2 *= c1; - h2 ^= k2; - [[fallthrough]]; - - case 8: - k1 ^= ((uint64_t)tail[7]) << 56; - [[fallthrough]]; - case 7: - k1 ^= ((uint64_t)tail[6]) << 48; - [[fallthrough]]; - case 6: - k1 ^= ((uint64_t)tail[5]) << 40; - [[fallthrough]]; - case 5: - k1 ^= ((uint64_t)tail[4]) << 32; - [[fallthrough]]; - case 4: - k1 ^= ((uint64_t)tail[3]) << 24; - [[fallthrough]]; - case 3: - k1 ^= ((uint64_t)tail[2]) << 16; - [[fallthrough]]; - case 2: - k1 ^= ((uint64_t)tail[1]) << 8; - [[fallthrough]]; - case 1: - k1 ^= ((uint64_t)tail[0]) << 0; - k1 *= c1; - k1 = ROTL64(k1, 31); - k1 *= c2; - h1 ^= k1; - }; - - //---------- - // finalization - - h1 ^= len; - h2 ^= len; - - h1 += h2; - h2 += h1; - - h1 = fmix64(h1); - h2 = fmix64(h2); - - h1 += h2; - h2 += h1; - - ((uint64_t*)out)[0] = h1; - ((uint64_t*)out)[1] = h2; -} - -} // namespace caffe2 diff --git a/caffe2/utils/murmur_hash3.h b/caffe2/utils/murmur_hash3.h deleted file mode 100644 index ea67e7151c0b..000000000000 --- a/caffe2/utils/murmur_hash3.h +++ /dev/null @@ -1,34 +0,0 @@ -//----------------------------------------------------------------------------- -// MurmurHash3 was written by Austin Appleby, and is placed in the public -// domain. The author hereby disclaims copyright to this source code. - -#pragma once - -//----------------------------------------------------------------------------- -// Platform-specific functions and macros - -// Microsoft Visual Studio - -#if defined(_MSC_VER) && (_MSC_VER < 1600) - -typedef unsigned char uint8_t; -typedef unsigned int uint32_t; -typedef unsigned __int64 uint64_t; - -// Other compilers - -#else // defined(_MSC_VER) - -#include - -#endif // !defined(_MSC_VER) - -namespace caffe2 { - -void MurmurHash3_x86_32(const void* key, int len, uint32_t seed, void* out); - -void MurmurHash3_x86_128(const void* key, int len, uint32_t seed, void* out); - -void MurmurHash3_x64_128(const void* key, int len, uint32_t seed, void* out); - -} // namespace caffe2 diff --git a/caffe2/utils/proto_utils.cc b/caffe2/utils/proto_utils.cc deleted file mode 100644 index 8fc81586f3ca..000000000000 --- a/caffe2/utils/proto_utils.cc +++ /dev/null @@ -1,715 +0,0 @@ -#include "caffe2/utils/proto_utils.h" - -#include - -#include -#include -#include -#include - -#if defined(_MSC_VER) -#include -#else -#include -#endif - -#include - -#ifndef CAFFE2_USE_LITE_PROTO -#include -#include -#else -#include -#endif // !CAFFE2_USE_LITE_PROTO - -#include - -using ::google::protobuf::MessageLite; - -namespace caffe2 { - -C10_EXPORT std::string DeviceTypeName(const int32_t& d) { - return at::DeviceTypeName(static_cast(d)); -} - -void setTotalBytesLimit(::google::protobuf::io::CodedInputStream& stream, int bytes_limit, int warning_threshold) { - #if GOOGLE_PROTOBUF_VERSION >= 3011000 - // Only take one parameter since protobuf 3.11 - stream.SetTotalBytesLimit(bytes_limit); - #else - stream.SetTotalBytesLimit(bytes_limit, warning_threshold); - #endif -} - -C10_EXPORT int DeviceId(const DeviceOption& option) { - switch (option.device_type()) { - case PROTO_CPU: - return option.numa_node_id(); - case PROTO_CUDA: - case PROTO_HIP: - return option.device_id(); - case PROTO_MKLDNN: - return option.numa_node_id(); - default: - CAFFE_THROW("Unknown device id for device type: ", option.device_type()); - } -} - -C10_EXPORT bool IsSameDevice(const DeviceOption& lhs, const DeviceOption& rhs) { - return ( - lhs.device_type() == rhs.device_type() && - lhs.device_id() == rhs.device_id() && - lhs.node_name() == rhs.node_name() && - lhs.numa_node_id() == rhs.numa_node_id()); -} - -C10_EXPORT bool IsCPUDeviceType(int device_type) { - static const std::unordered_set cpu_types{ - PROTO_CPU, - PROTO_MKLDNN, - PROTO_IDEEP, - }; - return cpu_types.count(device_type); -} - -C10_EXPORT bool IsGPUDeviceType(int device_type) { - static const std::unordered_set gpu_types{ - PROTO_CUDA, - PROTO_HIP, - }; - return gpu_types.count(device_type); -} - -C10_EXPORT bool ReadStringFromFile(const char* filename, string* str) { - std::ifstream ifs(filename, std::ios::in); - if (!ifs) { - VLOG(1) << "File cannot be opened: " << filename - << " error: " << ifs.rdstate(); - return false; - } - ifs.seekg(0, std::ios::end); - size_t n = ifs.tellg(); - str->resize(n); - ifs.seekg(0); - ifs.read(&(*str)[0], n); - return true; -} - -C10_EXPORT bool WriteStringToFile(const string& str, const char* filename) { - std::ofstream ofs(filename, std::ios::out | std::ios::trunc); - if (!ofs.is_open()) { - VLOG(1) << "File cannot be created: " << filename - << " error: " << ofs.rdstate(); - return false; - } - ofs << str; - return true; -} - -// IO-specific proto functions: we will deal with the protocol buffer lite and -// full versions differently. - -#ifdef CAFFE2_USE_LITE_PROTO - -// Lite runtime. - -namespace { -class IfstreamInputStream : public ::google::protobuf::io::CopyingInputStream { - public: - explicit IfstreamInputStream(const string& filename) - : ifs_(filename.c_str(), std::ios::in | std::ios::binary) {} - ~IfstreamInputStream() { - ifs_.close(); - } - - int Read(void* buffer, int size) { - if (!ifs_) { - return -1; - } - ifs_.read(static_cast(buffer), size); - return ifs_.gcount(); - } - - private: - std::ifstream ifs_; -}; -} // namespace - -C10_EXPORT string ProtoDebugString(const MessageLite& proto) { - string serialized = proto.SerializeAsString(); - for (char& c : serialized) { - if (c < 0x20 || c >= 0x7f) { - c = '?'; - } - } - return serialized; -} - -C10_EXPORT bool ParseProtoFromLargeString( - const string& str, - MessageLite* proto) { - ::google::protobuf::io::ArrayInputStream input_stream(str.data(), str.size()); - ::google::protobuf::io::CodedInputStream coded_stream(&input_stream); - // Set PlanDef message size limit to 2G. - setTotalBytesLimit(coded_stream, 2147483647, 512LL << 20); - return proto->ParseFromCodedStream(&coded_stream); -} - -C10_EXPORT bool ReadProtoFromBinaryFile( - const char* filename, - MessageLite* proto) { - ::google::protobuf::io::CopyingInputStreamAdaptor stream( - new IfstreamInputStream(filename)); - stream.SetOwnsCopyingStream(true); - // Total bytes hard limit / warning limit are set to 2GB and 512MB - // respectively. - ::google::protobuf::io::CodedInputStream coded_stream(&stream); - setTotalBytesLimit(coded_stream, 2147483647, 512LL << 20); - return proto->ParseFromCodedStream(&coded_stream); -} - -C10_EXPORT void WriteProtoToBinaryFile( - const MessageLite& /*proto*/, - const char* /*filename*/) { - LOG(FATAL) << "Not implemented yet."; -} - -#else // CAFFE2_USE_LITE_PROTO - -// Full protocol buffer. - -using ::google::protobuf::Message; -using ::google::protobuf::io::CodedInputStream; -using ::google::protobuf::io::CodedOutputStream; -using ::google::protobuf::io::FileInputStream; -using ::google::protobuf::io::FileOutputStream; -using ::google::protobuf::io::ZeroCopyInputStream; -using ::google::protobuf::io::ZeroCopyOutputStream; - -namespace TextFormat { -C10_EXPORT bool ParseFromString(const string& spec, Message* proto) { - string bc_spec = spec; - - { - auto num_replaced = c10::ReplaceAll(bc_spec, "cuda_gpu_id", "device_id"); - if (num_replaced) { - LOG(ERROR) << "Your model was serialized in Protobuf TextFormat and " - << "it has " << num_replaced - << " places using the deprecated field name 'cuda_gpu_id'!\n" - << spec - << "\nPlease re-export your model in Protobuf binary format " - << "to make it backward compatible for field renaming."; - } - } - - return ::google::protobuf::TextFormat::ParseFromString( - // NOLINTNEXTLINE(performance-move-const-arg) - std::move(bc_spec), proto); -} -} // namespace TextFormat - -C10_EXPORT string ProtoDebugString(const Message& proto) { - return proto.ShortDebugString(); -} - -C10_EXPORT bool ParseProtoFromLargeString(const string& str, Message* proto) { - ::google::protobuf::io::ArrayInputStream input_stream(str.data(), str.size()); - ::google::protobuf::io::CodedInputStream coded_stream(&input_stream); - // Set PlanDef message size limit to 2G. - setTotalBytesLimit(coded_stream, 2147483647, 512LL << 20); - return proto->ParseFromCodedStream(&coded_stream); -} - -C10_EXPORT bool ReadProtoFromTextFile(const char* filename, Message* proto) { - int fd = open(filename, O_RDONLY); - CAFFE_ENFORCE_NE(fd, -1, "File not found: ", filename); - FileInputStream* input = new FileInputStream(fd); - bool success = google::protobuf::TextFormat::Parse(input, proto); - delete input; - close(fd); - return success; -} - -C10_EXPORT void WriteProtoToTextFile( - const Message& proto, - const char* filename, - bool throwIfError) { - int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644); - FileOutputStream* output = new FileOutputStream(fd); - if(!google::protobuf::TextFormat::Print(proto, output)) { - if (throwIfError) { - CAFFE_THROW("Cannot write proto to text file: ", filename); - } else { - LOG(ERROR) << "Cannot write proto to text file: " << filename; - } - } - delete output; - close(fd); -} - -C10_EXPORT bool ReadProtoFromBinaryFile( - const char* filename, - MessageLite* proto) { -#if defined(_MSC_VER) // for MSC compiler binary flag needs to be specified - int fd = open(filename, O_RDONLY | O_BINARY); -#else - int fd = open(filename, O_RDONLY); -#endif - CAFFE_ENFORCE_NE(fd, -1, "File not found: ", filename); - std::unique_ptr raw_input(new FileInputStream(fd)); - std::unique_ptr coded_input( - new CodedInputStream(raw_input.get())); - // A hack to manually allow using very large protocol buffers. - #if GOOGLE_PROTOBUF_VERSION >= 3011000 - // Only take one parameter since protobuf 3.11 - coded_input->SetTotalBytesLimit(2147483647); - #else - // Total bytes hard limit / warning limit are set to 2GB and 512MB respectively. - coded_input->SetTotalBytesLimit(2147483647, 536870912); - #endif - bool success = proto->ParseFromCodedStream(coded_input.get()); - coded_input.reset(); - raw_input.reset(); - close(fd); - return success; -} - -C10_EXPORT void WriteProtoToBinaryFile( - const MessageLite& proto, - const char* filename) { - int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644); - CAFFE_ENFORCE_NE( - fd, -1, "File cannot be created: ", filename, " error number: ", errno); - std::unique_ptr raw_output(new FileOutputStream(fd)); - std::unique_ptr coded_output( - new CodedOutputStream(raw_output.get())); - CAFFE_ENFORCE(proto.SerializeToCodedStream(coded_output.get())); - coded_output.reset(); - raw_output.reset(); - close(fd); -} - -#endif // CAFFE2_USE_LITE_PROTO - -C10_EXPORT ArgumentHelper::ArgumentHelper(const OperatorDef& def) { - for (auto& arg : def.arg()) { - if (arg_map_.count(arg.name())) { - if (arg.SerializeAsString() != arg_map_[arg.name()].SerializeAsString()) { - // If there are two arguments of the same name but different contents, - // we will throw an error. - CAFFE_THROW( - "Found argument of the same name ", - arg.name(), - "but with different contents.", - ProtoDebugString(def)); - } else { - LOG(WARNING) << "Duplicated argument name [" << arg.name() - << "] found in operator def: " << ProtoDebugString(def); - } - } - arg_map_[arg.name()] = arg; - } -} - -C10_EXPORT ArgumentHelper::ArgumentHelper(const NetDef& netdef) { - for (auto& arg : netdef.arg()) { - CAFFE_ENFORCE( - arg_map_.count(arg.name()) == 0, - "Duplicated argument name [", - arg.name(), - "] found in net def: ", - ProtoDebugString(netdef)); - arg_map_[arg.name()] = arg; - } -} - -C10_EXPORT bool ArgumentHelper::HasArgument(c10::string_view name) const { -#ifdef CAFFE2_ENABLE_REDUCED_STRINGS_IN_ARGUMENT_LOOKUP - return arg_map_.count(name); -#else - return arg_map_.count(std::string(name)); -#endif -} - -namespace { -// Helper function to verify that conversion between types won't loose any -// significant bit. -template -bool SupportsLosslessConversion(const InputType& value) { - return static_cast(static_cast(value)) == value; -} -} // namespace -bool operator==(const TensorProto& l, const TensorProto& r) { - return l.SerializeAsString() == r.SerializeAsString(); -} - -std::ostream& operator<<(std::ostream& output, const TensorProto& n) { - output << n.SerializeAsString(); - return output; -} -bool operator==(const QTensorProto& l, const QTensorProto& r) { - return l.SerializeAsString() == r.SerializeAsString(); -} - -std::ostream& operator<<(std::ostream& output, const QTensorProto& n) { - output << n.SerializeAsString(); - return output; -} -bool operator==(const NetDef& l, const NetDef& r) { - return l.SerializeAsString() == r.SerializeAsString(); -} - -std::ostream& operator<<(std::ostream& output, const NetDef& n) { - output << n.SerializeAsString(); - return output; -} - -#define INSTANTIATE_GET_SINGLE_ARGUMENT( \ - T, fieldname, enforce_lossless_conversion) \ - template <> \ - C10_EXPORT T ArgumentHelper::GetSingleArgument( \ - c10::string_view name, const T& default_value) const { \ - auto it = CAFFE2_ARG_MAP_FIND(arg_map_, name); \ - if (it == arg_map_.end()) { \ - VLOG(1) << "Using default parameter value " << default_value \ - << " for parameter " << name; \ - return default_value; \ - } \ - CAFFE_ENFORCE( \ - it->second.has_##fieldname(), \ - "Argument ", \ - name, \ - " does not have the right field: expected field " #fieldname); \ - const auto& value = it->second.fieldname(); \ - if (enforce_lossless_conversion) { \ - auto supportsConversion = \ - SupportsLosslessConversion(value); \ - CAFFE_ENFORCE( \ - supportsConversion, \ - "Value", \ - value, \ - " of argument ", \ - name, \ - "cannot be represented correctly in a target type"); \ - } \ - return static_cast(value); \ - } \ - template <> \ - C10_EXPORT bool ArgumentHelper::HasSingleArgumentOfType( \ - c10::string_view name) const { \ - auto it = CAFFE2_ARG_MAP_FIND(arg_map_, name); \ - if (it == arg_map_.end()) { \ - return false; \ - } \ - return it->second.has_##fieldname(); \ - } - -INSTANTIATE_GET_SINGLE_ARGUMENT(float, f, false) -INSTANTIATE_GET_SINGLE_ARGUMENT(double, f, false) -INSTANTIATE_GET_SINGLE_ARGUMENT(bool, i, false) -INSTANTIATE_GET_SINGLE_ARGUMENT(int8_t, i, true) -INSTANTIATE_GET_SINGLE_ARGUMENT(int16_t, i, true) -INSTANTIATE_GET_SINGLE_ARGUMENT(int, i, true) -INSTANTIATE_GET_SINGLE_ARGUMENT(int64_t, i, true) -INSTANTIATE_GET_SINGLE_ARGUMENT(uint8_t, i, true) -INSTANTIATE_GET_SINGLE_ARGUMENT(uint16_t, i, true) -INSTANTIATE_GET_SINGLE_ARGUMENT(size_t, i, true) -INSTANTIATE_GET_SINGLE_ARGUMENT(string, s, false) -INSTANTIATE_GET_SINGLE_ARGUMENT(NetDef, n, false) -#undef INSTANTIATE_GET_SINGLE_ARGUMENT - -#define INSTANTIATE_GET_REPEATED_ARGUMENT( \ - T, fieldname, enforce_lossless_conversion) \ - template <> \ - C10_EXPORT std::vector ArgumentHelper::GetRepeatedArgument( \ - c10::string_view name, const std::vector& default_value) const { \ - auto it = CAFFE2_ARG_MAP_FIND(arg_map_, name); \ - if (it == arg_map_.end()) { \ - return default_value; \ - } \ - std::vector values; \ - for (const auto& v : it->second.fieldname()) { \ - if (enforce_lossless_conversion) { \ - auto supportsConversion = \ - SupportsLosslessConversion(v); \ - CAFFE_ENFORCE( \ - supportsConversion, \ - "Value", \ - v, \ - " of argument ", \ - name, \ - "cannot be represented correctly in a target type"); \ - } \ - values.push_back(static_cast(v)); \ - } \ - return values; \ - } - -INSTANTIATE_GET_REPEATED_ARGUMENT(float, floats, false) -INSTANTIATE_GET_REPEATED_ARGUMENT(double, floats, false) -INSTANTIATE_GET_REPEATED_ARGUMENT(bool, ints, false) -INSTANTIATE_GET_REPEATED_ARGUMENT(int8_t, ints, true) -INSTANTIATE_GET_REPEATED_ARGUMENT(int16_t, ints, true) -INSTANTIATE_GET_REPEATED_ARGUMENT(int, ints, true) -INSTANTIATE_GET_REPEATED_ARGUMENT(int64_t, ints, true) -INSTANTIATE_GET_REPEATED_ARGUMENT(uint8_t, ints, true) -INSTANTIATE_GET_REPEATED_ARGUMENT(uint16_t, ints, true) -INSTANTIATE_GET_REPEATED_ARGUMENT(size_t, ints, true) -INSTANTIATE_GET_REPEATED_ARGUMENT(string, strings, false) -INSTANTIATE_GET_REPEATED_ARGUMENT(NetDef, nets, false) -INSTANTIATE_GET_REPEATED_ARGUMENT(TensorProto, tensors, false) -INSTANTIATE_GET_REPEATED_ARGUMENT(QTensorProto, qtensors, false) -#undef INSTANTIATE_GET_REPEATED_ARGUMENT - -#define CAFFE2_MAKE_SINGULAR_ARGUMENT(T, fieldname) \ - template <> \ - C10_EXPORT Argument MakeArgument(const string& name, const T& value) { \ - Argument arg; \ - arg.set_name(name); \ - arg.set_##fieldname(value); \ - return arg; \ - } - -CAFFE2_MAKE_SINGULAR_ARGUMENT(bool, i) -CAFFE2_MAKE_SINGULAR_ARGUMENT(float, f) -CAFFE2_MAKE_SINGULAR_ARGUMENT(int, i) -CAFFE2_MAKE_SINGULAR_ARGUMENT(int16_t, i) -CAFFE2_MAKE_SINGULAR_ARGUMENT(int64_t, i) -CAFFE2_MAKE_SINGULAR_ARGUMENT(string, s) -#undef CAFFE2_MAKE_SINGULAR_ARGUMENT - -template <> -C10_EXPORT Argument MakeArgument(const string& name, const NetDef& value) { - Argument arg; - arg.set_name(name); - *arg.mutable_n() = value; - return arg; -} - -template <> -C10_EXPORT bool ArgumentHelper::RemoveArgument(OperatorDef& def, int index); -template <> -bool ArgumentHelper::RemoveArgument(NetDef& def, int index); - -template <> -C10_EXPORT Argument MakeArgument(const string& name, const MessageLite& value) { - Argument arg; - arg.set_name(name); - arg.set_s(value.SerializeAsString()); - return arg; -} - -#define CAFFE2_MAKE_REPEATED_ARGUMENT(T, fieldname) \ - template <> \ - C10_EXPORT Argument MakeArgument( \ - const string& name, const std::vector& value) { \ - Argument arg; \ - arg.set_name(name); \ - for (const auto& v : value) { \ - arg.add_##fieldname(v); \ - } \ - return arg; \ - } - -CAFFE2_MAKE_REPEATED_ARGUMENT(float, floats) -CAFFE2_MAKE_REPEATED_ARGUMENT(int, ints) -CAFFE2_MAKE_REPEATED_ARGUMENT(int64_t, ints) -CAFFE2_MAKE_REPEATED_ARGUMENT(string, strings) -#undef CAFFE2_MAKE_REPEATED_ARGUMENT - -C10_EXPORT bool HasOutput(const OperatorDef& op, const std::string& output) { - for (const auto& outp : op.output()) { - if (outp == output) { - return true; - } - } - return false; -} - -C10_EXPORT bool HasInput(const OperatorDef& op, const std::string& input) { - for (const auto& inp : op.input()) { - if (inp == input) { - return true; - } - } - return false; -} - -// Return the argument index or -1 if it does not exist. -C10_EXPORT int GetArgumentIndex( - const google::protobuf::RepeatedPtrField& args, - c10::string_view name) { - int index = 0; - for (const Argument& arg : args) { - if (arg.name() == name) { - return index; - } - index++; - } - return -1; -} - -C10_EXPORT const Argument& GetArgument( - const OperatorDef& def, - c10::string_view name) { - int index = GetArgumentIndex(def.arg(), name); - if (index != -1) { - return def.arg(index); - } else { - CAFFE_THROW( - "Argument named ", - name, - " does not exist in operator ", - ProtoDebugString(def)); - } -} - -C10_EXPORT const Argument& GetArgument(const NetDef& def, c10::string_view name) { - int index = GetArgumentIndex(def.arg(), name); - if (index != -1) { - return def.arg(index); - } else { - CAFFE_THROW( - "Argument named ", - name, - " does not exist in net ", - ProtoDebugString(def)); - } -} - -C10_EXPORT const Argument* GetArgumentPtr( - const OperatorDef& def, - c10::string_view name) { - int index = GetArgumentIndex(def.arg(), name); - if (index != -1) { - return &def.arg(index); - } else { - return nullptr; - } -} - -C10_EXPORT const Argument* GetArgumentPtr( - const NetDef& def, - c10::string_view name) { - int index = GetArgumentIndex(def.arg(), name); - if (index != -1) { - return &def.arg(index); - } else { - return nullptr; - } -} - -C10_EXPORT bool GetFlagArgument( - const google::protobuf::RepeatedPtrField& args, - c10::string_view name, - bool default_value) { - int index = GetArgumentIndex(args, name); - if (index != -1) { - // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) - auto arg = args.Get(index); - CAFFE_ENFORCE( - arg.has_i(), "Can't parse argument as bool: ", ProtoDebugString(arg)); - return arg.i(); - } - return default_value; -} - -C10_EXPORT bool GetFlagArgument( - const OperatorDef& def, - c10::string_view name, - bool default_value) { - return GetFlagArgument(def.arg(), name, default_value); -} - -C10_EXPORT bool -GetFlagArgument(const NetDef& def, c10::string_view name, bool default_value) { - return GetFlagArgument(def.arg(), name, default_value); -} - -template -Argument* GetMutableArgumentImpl( - const string& name, - const bool create_if_missing, - Def* def) { - for (int i = 0; i < def->arg_size(); ++i) { - if (def->arg(i).name() == name) { - return def->mutable_arg(i); - } - } - // If no argument of the right name is found... - if (create_if_missing) { - Argument* arg = def->add_arg(); - arg->set_name(name); - return arg; - } else { - return nullptr; - } -} - -C10_EXPORT Argument* GetMutableArgument( - const string& name, - const bool create_if_missing, - OperatorDef* def) { - return GetMutableArgumentImpl(name, create_if_missing, def); -} - -C10_EXPORT Argument* GetMutableArgument( - const string& name, - const bool create_if_missing, - NetDef* def) { - return GetMutableArgumentImpl(name, create_if_missing, def); -} - -C10_EXPORT void cleanupExternalInputsAndOutputs(NetDef* net) { - std::vector oldExternalInputs; - for (const auto& input : net->external_input()) { - oldExternalInputs.emplace_back(input); - } - std::vector oldExternalOutputs; - for (const auto& output : net->external_output()) { - oldExternalOutputs.emplace_back(output); - } - - net->clear_external_input(); - net->clear_external_output(); - - std::set inputSet; - for (const auto& input : oldExternalInputs) { - if (inputSet.count(input)) { - // Prevent duplicate external inputs. - continue; - } - inputSet.insert(input); - net->add_external_input(input); - } - - // Set of blobs that are external inputs or outputs of some operators. - std::set allOutputs(inputSet.begin(), inputSet.end()); - for (const auto& op : net->op()) { - for (const auto& input : op.input()) { - if (inputSet.count(input) || allOutputs.count(input)) { - continue; - } - // Add missing external inputs. - inputSet.insert(input); - net->add_external_input(input); - } - for (const auto& output : op.output()) { - allOutputs.insert(output); - } - } - - std::set outputSet; - for (const auto& output : oldExternalOutputs) { - if (!allOutputs.count(output)) { - continue; - } - if (outputSet.count(output)) { - continue; - } - outputSet.insert(output); - net->add_external_output(output); - } -} - -} // namespace caffe2 diff --git a/caffe2/utils/proto_utils.h b/caffe2/utils/proto_utils.h deleted file mode 100644 index a6903425ab4e..000000000000 --- a/caffe2/utils/proto_utils.h +++ /dev/null @@ -1,383 +0,0 @@ -#ifndef CAFFE2_UTILS_PROTO_UTILS_H_ -#define CAFFE2_UTILS_PROTO_UTILS_H_ - -#ifdef CAFFE2_USE_LITE_PROTO -#include -#else // CAFFE2_USE_LITE_PROTO -#include -#endif // !CAFFE2_USE_LITE_PROTO - -#include -#include -#include - -#include "caffe2/utils/proto_wrap.h" -#include "caffe2/proto/caffe2_pb.h" - -#ifndef C10_ANDROID -#define CAFFE2_ENABLE_REDUCED_STRINGS_IN_ARGUMENT_LOOKUP -#define CAFFE2_ARG_MAP_FIND(map, key) map.find(key) -#else -#define CAFFE2_ARG_MAP_FIND(map, key) map.find(std::string(key)) -#endif - -namespace caffe2 { - -using std::string; -using ::google::protobuf::MessageLite; - -// A wrapper function to return device name string for use in blob serialization -// / deserialization. This should have one to one correspondence with -// caffe2/proto/caffe2.proto: enum DeviceType. -// -// Note that we can't use DeviceType_Name, because that is only available in -// protobuf-full, and some platforms (like mobile) may want to use -// protobuf-lite instead. -TORCH_API std::string DeviceTypeName(const int32_t& d); - -TORCH_API int DeviceId(const DeviceOption& option); - -// Returns if the two DeviceOptions are pointing to the same device. -TORCH_API bool IsSameDevice(const DeviceOption& lhs, const DeviceOption& rhs); - -TORCH_API bool IsCPUDeviceType(int device_type); -TORCH_API bool IsGPUDeviceType(int device_type); - -// Common interfaces that reads file contents into a string. -TORCH_API bool ReadStringFromFile(const char* filename, string* str); -TORCH_API bool WriteStringToFile(const string& str, const char* filename); - -// Common interfaces that are supported by both lite and full protobuf. -TORCH_API bool ReadProtoFromBinaryFile(const char* filename, MessageLite* proto); -inline bool ReadProtoFromBinaryFile(const string filename, MessageLite* proto) { - return ReadProtoFromBinaryFile(filename.c_str(), proto); -} - -TORCH_API void WriteProtoToBinaryFile(const MessageLite& proto, const char* filename); -inline void WriteProtoToBinaryFile(const MessageLite& proto, - const string& filename) { - return WriteProtoToBinaryFile(proto, filename.c_str()); -} - -#ifdef CAFFE2_USE_LITE_PROTO - -namespace TextFormat { -inline bool ParseFromString(const string& spec, MessageLite* proto) { - LOG(FATAL) << "If you are running lite version, you should not be " - << "calling any text-format protobuffers."; - return false; -} -} // namespace TextFormat - - -TORCH_API string ProtoDebugString(const MessageLite& proto); - -TORCH_API bool ParseProtoFromLargeString(const string& str, MessageLite* proto); - -// Text format MessageLite wrappers: these functions do nothing but just -// allowing things to compile. It will produce a runtime error if you are using -// MessageLite but still want text support. -inline bool ReadProtoFromTextFile( - const char* /*filename*/, - MessageLite* /*proto*/) { - LOG(FATAL) << "If you are running lite version, you should not be " - << "calling any text-format protobuffers."; - return false; // Just to suppress compiler warning. -} -inline bool ReadProtoFromTextFile(const string filename, MessageLite* proto) { - return ReadProtoFromTextFile(filename.c_str(), proto); -} - -inline void WriteProtoToTextFile( - const MessageLite& /*proto*/, - const char* /*filename*/, - bool throwIfError = true) { - LOG(FATAL) << "If you are running lite version, you should not be " - << "calling any text-format protobuffers."; -} -inline void WriteProtoToTextFile(const MessageLite& proto, - const string& filename, - bool throwIfError = true) { - return WriteProtoToTextFile(proto, filename.c_str(), throwIfError); -} - -inline bool ReadProtoFromFile(const char* filename, MessageLite* proto) { - return (ReadProtoFromBinaryFile(filename, proto) || - ReadProtoFromTextFile(filename, proto)); -} - -inline bool ReadProtoFromFile(const string& filename, MessageLite* proto) { - return ReadProtoFromFile(filename.c_str(), proto); -} - -#else // CAFFE2_USE_LITE_PROTO - -using ::google::protobuf::Message; - -namespace TextFormat { -TORCH_API bool ParseFromString(const string& spec, Message* proto); -} // namespace TextFormat - -TORCH_API string ProtoDebugString(const Message& proto); - -TORCH_API bool ParseProtoFromLargeString(const string& str, Message* proto); - -TORCH_API bool ReadProtoFromTextFile(const char* filename, Message* proto); -inline bool ReadProtoFromTextFile(const string filename, Message* proto) { - return ReadProtoFromTextFile(filename.c_str(), proto); -} - -TORCH_API void WriteProtoToTextFile(const Message& proto, const char* filename, bool throwIfError = true); -inline void WriteProtoToTextFile(const Message& proto, const string& filename, bool throwIfError = true) { - return WriteProtoToTextFile(proto, filename.c_str(), throwIfError); -} - -// Read Proto from a file, letting the code figure out if it is text or binary. -inline bool ReadProtoFromFile(const char* filename, Message* proto) { - return (ReadProtoFromBinaryFile(filename, proto) || - ReadProtoFromTextFile(filename, proto)); -} - -inline bool ReadProtoFromFile(const string& filename, Message* proto) { - return ReadProtoFromFile(filename.c_str(), proto); -} - -#endif // CAFFE2_USE_LITE_PROTO - -template < - class IterableInputs = std::initializer_list, - class IterableOutputs = std::initializer_list, - class IterableArgs = std::initializer_list> -OperatorDef CreateOperatorDef( - const string& type, - const string& name, - const IterableInputs& inputs, - const IterableOutputs& outputs, - const IterableArgs& args, - const DeviceOption& device_option = DeviceOption(), - const string& engine = "") { - OperatorDef def; - def.set_type(type); - def.set_name(name); - for (const string& in : inputs) { - def.add_input(in); - } - for (const string& out : outputs) { - def.add_output(out); - } - for (const Argument& arg : args) { - def.add_arg()->CopyFrom(arg); - } - if (device_option.has_device_type()) { - def.mutable_device_option()->CopyFrom(device_option); - } - if (engine.size()) { - def.set_engine(engine); - } - return def; -} - -// A simplified version compared to the full CreateOperator, if you do not need -// to specify args. -template < - class IterableInputs = std::initializer_list, - class IterableOutputs = std::initializer_list> -inline OperatorDef CreateOperatorDef( - const string& type, - const string& name, - const IterableInputs& inputs, - const IterableOutputs& outputs, - const DeviceOption& device_option = DeviceOption(), - const string& engine = "") { - return CreateOperatorDef( - type, - name, - inputs, - outputs, - std::vector(), - device_option, - engine); -} - -TORCH_API bool HasOutput(const OperatorDef& op, const std::string& output); -TORCH_API bool HasInput(const OperatorDef& op, const std::string& input); - -/** - * @brief A helper class to index into arguments. - * - * This helper helps us to more easily index into a set of arguments - * that are present in the operator. To save memory, the argument helper - * does not copy the operator def, so one would need to make sure that the - * lifetime of the OperatorDef object outlives that of the ArgumentHelper. - */ -class C10_EXPORT ArgumentHelper { - public: - template - static bool HasArgument(const Def& def, c10::string_view name) { - return ArgumentHelper(def).HasArgument(name); - } - - template - static T GetSingleArgument( - const Def& def, - c10::string_view name, - const T& default_value) { - return ArgumentHelper(def).GetSingleArgument(name, default_value); - } - - template - static bool HasSingleArgumentOfType(const Def& def, c10::string_view name) { - return ArgumentHelper(def).HasSingleArgumentOfType(name); - } - - template - static std::vector GetRepeatedArgument( - const Def& def, - c10::string_view name, - const std::vector& default_value = std::vector()) { - return ArgumentHelper(def).GetRepeatedArgument(name, default_value); - } - - template - static MessageType GetMessageArgument(const Def& def, c10::string_view name) { - return ArgumentHelper(def).GetMessageArgument(name); - } - - template - static std::vector GetRepeatedMessageArgument( - const Def& def, - c10::string_view name) { - return ArgumentHelper(def).GetRepeatedMessageArgument(name); - } - - template - static bool RemoveArgument(Def& def, int index) { - if (index >= def.arg_size()) { - return false; - } - if (index < def.arg_size() - 1) { - def.mutable_arg()->SwapElements(index, def.arg_size() - 1); - } - def.mutable_arg()->RemoveLast(); - return true; - } - - explicit ArgumentHelper(const OperatorDef& def); - explicit ArgumentHelper(const NetDef& netdef); - bool HasArgument(c10::string_view name) const; - - template - T GetSingleArgument(c10::string_view name, const T& default_value) const; - template - bool HasSingleArgumentOfType(c10::string_view name) const; - template - std::vector GetRepeatedArgument( - c10::string_view name, - const std::vector& default_value = std::vector()) const; - - template - MessageType GetMessageArgument(c10::string_view name) const { - auto it = CAFFE2_ARG_MAP_FIND(arg_map_, name); - CAFFE_ENFORCE(it != arg_map_.end(), "Cannot find parameter named ", name); - MessageType message; - if (it->second.has_s()) { - CAFFE_ENFORCE( - message.ParseFromString(it->second.s()), - "Failed to parse content from the string"); - } else { - VLOG(1) << "Return empty message for parameter " << name; - } - return message; - } - - template - std::vector GetRepeatedMessageArgument(c10::string_view name) const { - auto it = CAFFE2_ARG_MAP_FIND(arg_map_, name); - CAFFE_ENFORCE(it != arg_map_.end(), "Cannot find parameter named ", name); - std::vector messages(it->second.strings_size()); - for (int i = 0; i < messages.size(); ++i) { - CAFFE_ENFORCE( - messages[i].ParseFromString(it->second.strings(i)), - "Failed to parse content from the string"); - } - return messages; - } - - private: - std::map -#endif - > arg_map_; -}; - -// **** Arguments Utils ***** - -// Helper methods to get an argument from OperatorDef or NetDef given argument -// name. Throws if argument does not exist. -TORCH_API const Argument& GetArgument(const OperatorDef& def, c10::string_view name); -TORCH_API const Argument& GetArgument(const NetDef& def, c10::string_view name); -// Helper methods to get an argument from OperatorDef or NetDef given argument -// name. Returns nullptr if argument does not exist. -TORCH_API const Argument* GetArgumentPtr(const OperatorDef& def, c10::string_view name); -TORCH_API const Argument* GetArgumentPtr(const NetDef& def, c10::string_view name); - -// Helper methods to query a boolean argument flag from OperatorDef or NetDef -// given argument name. If argument does not exist, return default value. -// Throws if argument exists but the type is not boolean. -TORCH_API bool GetFlagArgument( - const OperatorDef& def, - c10::string_view name, - bool default_value = false); -TORCH_API bool GetFlagArgument( - const NetDef& def, - c10::string_view name, - bool default_value = false); - -TORCH_API Argument* GetMutableArgument( - const string& name, - const bool create_if_missing, - OperatorDef* def); -TORCH_API Argument* GetMutableArgument( - const string& name, - const bool create_if_missing, - NetDef* def); - -template -TORCH_API Argument MakeArgument(const string& name, const T& value); - -template -inline void AddArgument(const string& name, const T& value, Def* def) { - GetMutableArgument(name, true, def)->CopyFrom(MakeArgument(name, value)); -} -// **** End Arguments Utils ***** - -bool inline operator==(const DeviceOption& dl, const DeviceOption& dr) { - return IsSameDevice(dl, dr); -} - -// Given a net, modify the external inputs/outputs if necessary so that -// the following conditions are met -// - No duplicate external inputs -// - No duplicate external outputs -// - Going through list of ops in order, all op inputs must be outputs -// from other ops, or registered as external inputs. -// - All external outputs must be outputs of some operators. -TORCH_API void cleanupExternalInputsAndOutputs(NetDef* net); - -} // namespace caffe2 - -namespace std { -template <> -struct hash { - typedef caffe2::DeviceOption argument_type; - typedef std::size_t result_type; - result_type operator()(argument_type const& device_option) const { - std::string serialized; - CAFFE_ENFORCE(device_option.SerializeToString(&serialized)); - return std::hash{}(serialized); - } -}; -} // namespace std - -#endif // CAFFE2_UTILS_PROTO_UTILS_H_ diff --git a/caffe2/utils/proto_utils_test.cc b/caffe2/utils/proto_utils_test.cc deleted file mode 100644 index 1a687690c69f..000000000000 --- a/caffe2/utils/proto_utils_test.cc +++ /dev/null @@ -1,63 +0,0 @@ -#include - -#include "caffe2/core/test_utils.h" -#include "caffe2/utils/proto_utils.h" - -namespace caffe2 { - -TEST(ProtoUtilsTest, IsSameDevice) { - DeviceOption a; - DeviceOption b; - EXPECT_TRUE(IsSameDevice(a, b)); - a.set_node_name("my_node"); - EXPECT_FALSE(IsSameDevice(a, b)); - b.set_node_name("my_node"); - EXPECT_TRUE(IsSameDevice(a, b)); - b.set_device_id(2); - EXPECT_FALSE(IsSameDevice(a, b)); - a.set_device_id(2); - EXPECT_TRUE(IsSameDevice(a, b)); - a.set_device_type(DeviceTypeProto::PROTO_CUDA); - b.set_device_type(DeviceTypeProto::PROTO_CPU); - EXPECT_FALSE(IsSameDevice(a, b)); -} - -TEST(ProtoUtilsTest, SimpleReadWrite) { - string content("The quick brown fox jumps over the lazy dog."); - string name = std::tmpnam(nullptr); - EXPECT_TRUE(WriteStringToFile(content, name.c_str())); - string read_back; - EXPECT_TRUE(ReadStringFromFile(name.c_str(), &read_back)); - EXPECT_EQ(content, read_back); -} - -TEST(ProtoUtilsTest, CleanupExternalInputsAndOutputs) { - caffe2::NetDef net; - caffe2::testing::NetMutator(&net) - .newOp("op1", {"X1", "X2"}, {"Y"}) - .newOp("op2", {"W", "Y"}, {"Z1", "Z2"}) - .newOp("op3", {"Z2", "W"}, {"O"}) - .externalInputs({"X1", "X3", "X1", "W"}) - .externalOutputs({"O", "Z2", "Z3", "O", "X3"}); - cleanupExternalInputsAndOutputs(&net); - - std::vector externalInputs; - for (const auto& inputName : net.external_input()) { - externalInputs.emplace_back(inputName); - } - // The 2nd X1 is removed because of duplication. - // X2 is added because it should be a missing external input. - std::vector expectedExternalInputs{"X1", "X3", "W", "X2"}; - EXPECT_EQ(externalInputs, expectedExternalInputs); - - std::vector externalOutputs; - for (const auto& outputName : net.external_output()) { - externalOutputs.emplace_back(outputName); - } - // Z3 is removed because it's not an output of any operator in the net. - // The 2nd O is removed because of duplication. - std::vector expectedexternalOutputs{"O", "Z2", "X3"}; - EXPECT_EQ(externalOutputs, expectedexternalOutputs); -} - -} // namespace caffe2 diff --git a/caffe2/utils/signal_handler.h b/caffe2/utils/signal_handler.h deleted file mode 100644 index 14d93a0df670..000000000000 --- a/caffe2/utils/signal_handler.h +++ /dev/null @@ -1,24 +0,0 @@ -#pragma once - -#include - -namespace caffe2 { - -#if defined(C10_SUPPORTS_FATAL_SIGNAL_HANDLERS) -class TORCH_API C2FatalSignalHandler : public c10::FatalSignalHandler { - public: - void fatalSignalHandlerPostProcess() override; - static C2FatalSignalHandler& getInstance(); - - private: - explicit C2FatalSignalHandler(); -}; - -// This works by setting up certain fatal signal handlers. Previous fatal -// signal handlers will still be called when the signal is raised. Defaults -// to being off. -TORCH_API void setPrintStackTracesOnFatalSignal(bool print); -TORCH_API bool printStackTracesOnFatalSignal(); -#endif // defined(C10_SUPPORTS_FATAL_SIGNAL_HANDLER) - -} // namespace caffe2 diff --git a/caffe2/utils/simple_queue.h b/caffe2/utils/simple_queue.h deleted file mode 100644 index c16f55223eed..000000000000 --- a/caffe2/utils/simple_queue.h +++ /dev/null @@ -1,79 +0,0 @@ -#ifndef CAFFE2_UTILS_SIMPLE_QUEUE_H_ -#define CAFFE2_UTILS_SIMPLE_QUEUE_H_ - -#include // NOLINT -#include // NOLINT -#include - -#include - -namespace caffe2 { - -// This is a very simple queue that Yangqing wrote when bottlefeeding the baby, -// so don't take it seriously. What it does is a minimal thread-safe queue that -// allows me to run network as a DAG. -// -// A usual work pattern looks like this: one or multiple producers push jobs -// into this queue, and one or multiple workers pops jobs from this queue. If -// nothing is in the queue but NoMoreJobs() is not called yet, the pop calls -// will wait. If NoMoreJobs() has been called, pop calls will return false, -// which serves as a message to the workers that they should exit. -template -class SimpleQueue { - public: - SimpleQueue() : no_more_jobs_(false) {} - - // Pops a value and writes it to the value pointer. If there is nothing in the - // queue, this will wait till a value is inserted to the queue. If there are - // no more jobs to pop, the function returns false. Otherwise, it returns - // true. - bool Pop(T* value) { - std::unique_lock mutex_lock(mutex_); - while (queue_.size() == 0 && !no_more_jobs_) cv_.wait(mutex_lock); - if (queue_.size() == 0 && no_more_jobs_) return false; - *value = queue_.front(); - queue_.pop(); - return true; - } - - int size() { - std::unique_lock mutex_lock(mutex_); - return queue_.size(); - } - - // Push pushes a value to the queue. - void Push(const T& value) { - { - std::lock_guard mutex_lock(mutex_); - CAFFE_ENFORCE(!no_more_jobs_, "Cannot push to a closed queue."); - queue_.push(value); - } - cv_.notify_one(); - } - - // NoMoreJobs() marks the close of this queue. It also notifies all waiting - // Pop() calls so that they either check out remaining jobs, or return false. - // After NoMoreJobs() is called, this queue is considered closed - no more - // Push() functions are allowed, and once existing items are all checked out - // by the Pop() functions, any more Pop() function will immediately return - // false with nothing set to the value. - void NoMoreJobs() { - { - std::lock_guard mutex_lock(mutex_); - no_more_jobs_ = true; - } - cv_.notify_all(); - } - - private: - std::mutex mutex_; - std::condition_variable cv_; - std::queue queue_; - bool no_more_jobs_{}; - // We do not allow copy constructors. - SimpleQueue(const SimpleQueue& /*src*/) {} -}; - -} // namespace caffe2 - -#endif // CAFFE2_UTILS_SIMPLE_QUEUE_H_ diff --git a/caffe2/utils/simple_queue_test.cc b/caffe2/utils/simple_queue_test.cc deleted file mode 100644 index e59f699cd15a..000000000000 --- a/caffe2/utils/simple_queue_test.cc +++ /dev/null @@ -1,76 +0,0 @@ -#include // NOLINT - -#include "caffe2/utils/simple_queue.h" -#include - -namespace caffe2 { - -static std::unique_ptr > gQueue; - -static void ConsumerFunction(int thread_idx) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int value; - while (true) { - if (!gQueue->Pop(&value)) return; - VLOG(1) << "Emitting " << value << " from thread " << thread_idx; - } -} - -static void ProducerFunction(int thread_idx, int start, int count) { - for (int i = 0; i < count; ++i) { - VLOG(1) << "Pushing " << i + start << " from thread " << thread_idx; - gQueue->Push(i + start); - } -} - - -TEST(SimpleQueueTest, SingleProducerSingleConsumer) { - // NOLINTNEXTLINE(modernize-make-unique) - gQueue.reset(new SimpleQueue()); - std::thread consumer(ConsumerFunction, 0); - for (int i = 0; i < 10; ++i) { - gQueue->Push(i); - } - gQueue->NoMoreJobs(); - consumer.join(); -} - -TEST(SimpleQueueTest, SingleProducerDoubleConsumer) { - // NOLINTNEXTLINE(modernize-make-unique) - gQueue.reset(new SimpleQueue()); - std::thread consumer0(ConsumerFunction, 0); - std::thread consumer1(ConsumerFunction, 1); - for (int i = 0; i < 10; ++i) { - gQueue->Push(i); - } - gQueue->NoMoreJobs(); - consumer0.join(); - consumer1.join(); -} - - -TEST(SimpleQueueTest, DoubleProducerDoubleConsumer) { - // NOLINTNEXTLINE(modernize-make-unique) - gQueue.reset(new SimpleQueue()); - std::thread producer0(ProducerFunction, 0, 0, 10); - std::thread producer1(ProducerFunction, 0, 10, 10); - std::thread consumer0(ConsumerFunction, 2); - std::thread consumer1(ConsumerFunction, 3); - producer0.join(); - producer1.join(); - gQueue->NoMoreJobs(); - consumer0.join(); - consumer1.join(); -} - -TEST(SimpleQueueDeathTest, CannotAddAfterQueueFinished) { - // NOLINTNEXTLINE(modernize-make-unique) - gQueue.reset(new SimpleQueue()); - gQueue->Push(0); - gQueue->NoMoreJobs(); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - ASSERT_THROW(gQueue->Push(0), EnforceNotMet); -} - - -} // namespace caffe2 diff --git a/caffe2/utils/smart_tensor_printer.h b/caffe2/utils/smart_tensor_printer.h deleted file mode 100644 index e6d96ef37ae0..000000000000 --- a/caffe2/utils/smart_tensor_printer.h +++ /dev/null @@ -1,50 +0,0 @@ -#pragma once - -#include "caffe2/core/tensor.h" - -namespace caffe2 { - -// This is a wrapper around the TensorPrinter that doesn't require the user to -// explicit specify the type of the tensor while calling the Print() method. -// It also supports a convenience function with a default constructed printer as -// a static method. -class TORCH_API SmartTensorPrinter { - public: - // The proliferation of constructors is to give the feature parity with - // TensorPrinter - // yet not repeat the default arguments explicitly in case they change in the - // future. - SmartTensorPrinter() = default; - - explicit SmartTensorPrinter(const std::string& tensor_name); - - SmartTensorPrinter( - const std::string& tensor_name, - const std::string& file_name); - - SmartTensorPrinter( - const std::string& tensor_name, - const std::string& file_name, - int limit); - - void Print(const Tensor& tensor); - - void PrintMeta(const Tensor& tensor) { - tensorPrinter_.PrintMeta(tensor); - } - - // Uses a default constructed SmartTensorPrinter - static void PrintTensor(const Tensor& tensor); - - // Uses a default constructed SmartTensorPrinter - void PrintTensorMeta(const Tensor& tensor) { - DefaultTensorPrinter().PrintMeta(tensor); - } - - private: - // Returns a thread local default constructed TensorPrinter - static SmartTensorPrinter& DefaultTensorPrinter(); - - TensorPrinter tensorPrinter_; -}; -} diff --git a/caffe2/utils/smart_tensor_printer_test.cc b/caffe2/utils/smart_tensor_printer_test.cc deleted file mode 100644 index a45573001c6e..000000000000 --- a/caffe2/utils/smart_tensor_printer_test.cc +++ /dev/null @@ -1,53 +0,0 @@ -#include "caffe2/utils/smart_tensor_printer.h" - -#include "caffe2/core/common.h" - -#include - -namespace caffe2 { - -template -std::string my_to_string(const T& value) { - return to_string(value); -} - -template <> -std::string my_to_string(const std::string& value) { - return value; -} - -template -void expect_stderr_contains(const std::vector& values) { - std::string captured_stderr = testing::internal::GetCapturedStderr(); - for (const auto& value : values) { - std::string stringValue = my_to_string(value); - EXPECT_TRUE(captured_stderr.find(stringValue) != std::string::npos); - } -} - -template -void printTensorAndCheck(const std::vector& values) { - testing::internal::CaptureStderr(); - - Tensor tensor = - TensorCPUFromValues({static_cast(values.size())}, values); - - SmartTensorPrinter::PrintTensor(tensor); - expect_stderr_contains(values); -} - -// We need real glog for this test to pass -#ifdef CAFFE2_USE_GOOGLE_GLOG - -#if !(__APPLE__) // TODO(janusz): thread_local does not work under mac. - -TEST(SmartTensorPrinterTest, SimpleTest) { - printTensorAndCheck(std::vector{1, 2, 3, 4, 5}); - printTensorAndCheck(std::vector{"bob", "alice", "facebook"}); -} - -#endif // !(__APPLE__) - -#endif // CAFFE2_USE_GOOGLE_GLOG - -} // namespace caffe2 diff --git a/caffe2/utils/zmq_helper.h b/caffe2/utils/zmq_helper.h deleted file mode 100644 index 05bc22a73c4e..000000000000 --- a/caffe2/utils/zmq_helper.h +++ /dev/null @@ -1,137 +0,0 @@ -#ifndef CAFFE2_UTILS_ZMQ_HELPER_H_ -#define CAFFE2_UTILS_ZMQ_HELPER_H_ - -#include - -#include "caffe2/core/logging.h" - -namespace caffe2 { - -class ZmqContext { - public: - explicit ZmqContext(int io_threads) : ptr_(zmq_ctx_new()) { - CAFFE_ENFORCE(ptr_ != nullptr, "Failed to create zmq context."); - int rc = zmq_ctx_set(ptr_, ZMQ_IO_THREADS, io_threads); - CAFFE_ENFORCE_EQ(rc, 0); - rc = zmq_ctx_set(ptr_, ZMQ_MAX_SOCKETS, ZMQ_MAX_SOCKETS_DFLT); - CAFFE_ENFORCE_EQ(rc, 0); - } - ~ZmqContext() { - int rc = zmq_ctx_destroy(ptr_); - CAFFE_ENFORCE_EQ(rc, 0); - } - - void* ptr() { return ptr_; } - - private: - void* ptr_; - - C10_DISABLE_COPY_AND_ASSIGN(ZmqContext); -}; - -class ZmqMessage { - public: - ZmqMessage() { - int rc = zmq_msg_init(&msg_); - CAFFE_ENFORCE_EQ(rc, 0); - } - - ~ZmqMessage() { - int rc = zmq_msg_close(&msg_); - CAFFE_ENFORCE_EQ(rc, 0); - } - - zmq_msg_t* msg() { return &msg_; } - - void* data() { return zmq_msg_data(&msg_); } - size_t size() { return zmq_msg_size(&msg_); } - - private: - zmq_msg_t msg_; - C10_DISABLE_COPY_AND_ASSIGN(ZmqMessage); -}; - -class ZmqSocket { - public: - explicit ZmqSocket(int type) - : context_(1), ptr_(zmq_socket(context_.ptr(), type)) { - CAFFE_ENFORCE(ptr_ != nullptr, "Failed to create zmq socket."); - } - - ~ZmqSocket() { - int rc = zmq_close(ptr_); - CAFFE_ENFORCE_EQ(rc, 0); - } - - void Bind(const string& addr) { - int rc = zmq_bind(ptr_, addr.c_str()); - CAFFE_ENFORCE_EQ(rc, 0); - } - - void Unbind(const string& addr) { - int rc = zmq_unbind(ptr_, addr.c_str()); - CAFFE_ENFORCE_EQ(rc, 0); - } - - void Connect(const string& addr) { - int rc = zmq_connect(ptr_, addr.c_str()); - CAFFE_ENFORCE_EQ(rc, 0); - } - - void Disconnect(const string& addr) { - int rc = zmq_disconnect(ptr_, addr.c_str()); - CAFFE_ENFORCE_EQ(rc, 0); - } - - int Send(const string& msg, int flags) { - int nbytes = zmq_send(ptr_, msg.c_str(), msg.size(), flags); - if (nbytes) { - return nbytes; - } else if (zmq_errno() == EAGAIN) { - return 0; - } else { - LOG(FATAL) << "Cannot send zmq message. Error number: " - << zmq_errno(); - return 0; - } - } - - int SendTillSuccess(const string& msg, int flags) { - CAFFE_ENFORCE(msg.size(), "You cannot send an empty message."); - int nbytes = 0; - do { - nbytes = Send(msg, flags); - } while (nbytes == 0); - return nbytes; - } - - int Recv(ZmqMessage* msg) { - int nbytes = zmq_msg_recv(msg->msg(), ptr_, 0); - if (nbytes >= 0) { - return nbytes; - } else if (zmq_errno() == EAGAIN || zmq_errno() == EINTR) { - return 0; - } else { - LOG(FATAL) << "Cannot receive zmq message. Error number: " - << zmq_errno(); - return 0; - } - } - - int RecvTillSuccess(ZmqMessage* msg) { - int nbytes = 0; - do { - nbytes = Recv(msg); - } while (nbytes == 0); - return nbytes; - } - - private: - ZmqContext context_; - void* ptr_; -}; - -} // namespace caffe2 - - -#endif // CAFFE2_UTILS_ZMQ_HELPER_H_ From 029b3ec7754e13b260ec7931d98404d1023f0ed1 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 31 May 2024 12:33:25 +0000 Subject: [PATCH 160/706] Revert "[inductor][cpp] bf16/fp16 gemm template computed with fp32 w/o epilogue fusion (#126068)" This reverts commit dae33a4961addb5847dbb362e7bb907bbfc64929. Reverted https://github.com/pytorch/pytorch/pull/126068 on behalf of https://github.com/PaliC due to failing internal tests ([comment](https://github.com/pytorch/pytorch/pull/126068#issuecomment-2141992307)) --- test/inductor/test_cpu_select_algorithm.py | 15 +-- torch/_inductor/codegen/cpp.py | 3 +- torch/_inductor/codegen/cpp_gemm_template.py | 69 +++-------- torch/_inductor/codegen/cpp_micro_gemm.py | 90 ++++---------- .../_inductor/codegen/cpp_template_kernel.py | 113 ++++++++---------- torch/_inductor/codegen/cpp_utils.py | 62 +--------- torch/_inductor/ir.py | 8 +- torch/_inductor/mkldnn_lowerings.py | 102 ++-------------- torch/_inductor/utils.py | 10 +- 9 files changed, 112 insertions(+), 360 deletions(-) diff --git a/test/inductor/test_cpu_select_algorithm.py b/test/inductor/test_cpu_select_algorithm.py index 0147ca8e24cd..aabd5bd08b15 100644 --- a/test/inductor/test_cpu_select_algorithm.py +++ b/test/inductor/test_cpu_select_algorithm.py @@ -77,11 +77,11 @@ class TestSelectAlgorithm(TestCase): @torch.no_grad @unittest.skipIf(not TEST_MKL, "Test requires MKL") @parametrize("batch_size", (1, 2, 1000)) - @parametrize("in_features", (1, 1000)) - @parametrize("out_features", (1, 1024)) + @parametrize("in_features", (1, 2, 1000)) + @parametrize("out_features", (1, 32, 1024)) @parametrize("bias", (True, False)) @parametrize("input_3d", (True, False)) - @dtypes(torch.float, torch.bfloat16, torch.half) + @dtypes(torch.float) def test_linear_static_shapes( self, batch_size, in_features, out_features, bias, input_3d, dtype ): @@ -97,14 +97,7 @@ def forward(self, x): mod = M(bias=bias).to(dtype=dtype).eval() B = (2, batch_size) if input_3d else (batch_size,) v = torch.randn(*B, in_features).to(dtype=dtype) - # For bfloat16 and half, we have to relax the tolerance - # due to the difference associave orders in different - # kernel implementations - atol, rtol = 1e-4, 1e-4 - if dtype == torch.half or dtype == torch.bfloat16: - atol, rtol = 1e-2, 1e-2 - with patch.object(select_algorithm, "VERIFY", dict(atol=atol, rtol=rtol)): - self.common(mod, (v,), atol=atol, rtol=rtol) + self.common(mod, (v,)) if ( counters["inductor"]["decompose_mm"] > 0 or counters["inductor"]["decompose_addmm"] > 0 diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index a14f93e14e6b..eabb5bbef470 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -2849,8 +2849,9 @@ def store_reduction(self, name, index, value): return self.simd_vec def __exit__(self, exc_type, exc_val, exc_tb): + assert self._orig_wrapper_code is not None # Restore the wrapper_code - V.graph.wrapper_code = self._orig_wrapper_code # type: ignore[assignment] + V.graph.wrapper_code = self._orig_wrapper_code self.exit_stack.__exit__(exc_type, exc_val, exc_tb) def __enter__(self): diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py index 18d6301d57a6..e0a4c0993549 100644 --- a/torch/_inductor/codegen/cpp_gemm_template.py +++ b/torch/_inductor/codegen/cpp_gemm_template.py @@ -147,7 +147,6 @@ def __init__( beta=1, alpha=1, ): - assert layout.dtype in [torch.float, torch.bfloat16, torch.half] super().__init__("packed_gemm", input_nodes, layout) self.beta = beta self.alpha = alpha @@ -213,13 +212,7 @@ def cache_blocking(self) -> GemmBlocking: @staticmethod def add_choices( - choices, - layout, - input_nodes, - beta=1, - alpha=1, - trans_w=False, - input_indices=None, + choices, layout, input_nodes, beta=1, alpha=1, trans_w=False, input_indices=None ): if input_indices is None: input_indices = list(range(len(input_nodes))) @@ -239,58 +232,28 @@ def reorder_and_filter(inputs, layout_or_out): w_idx = input_indices[2] return [inputs[x_idx], inputs[w_idx], inputs[inp_idx]], layout_or_out - def maybe_to_dense(inputs, layout_or_out): - new_inputs = list(inputs) - if isinstance(inputs[1], torch.Tensor): - W = inputs[1] - new_inputs[1] = W.to_dense() if W.is_mkldnn else W - return new_inputs, layout_or_out - - def normalize_shapes(inputs, layout_or_out): + def transpose_weight(inputs, layout_or_out): if not trans_w: return inputs, layout_or_out new_inputs = list(inputs) - X = inputs[0] W = inputs[1] - B = inputs[2] if len(inputs) > 2 else None if isinstance(W, ir.IRNode): - if trans_w: - if not isinstance(W, ir.TensorBox): - W = ir.TensorBox(W) - W = L.permute(W, [1, 0]) + if not isinstance(W, ir.TensorBox): + W = ir.TensorBox(W) + new_inputs[1] = L.permute(W, [1, 0]) + return new_inputs, layout_or_out else: - if trans_w: - assert isinstance(W, torch.Tensor) - W = W.transpose(0, 1) - if B is not None: - if isinstance(B, ir.IRNode): - if not isinstance(B, ir.TensorBox): - B = ir.TensorBox(B) - B = L.expand(B, (X.get_size()[0], B.get_size()[-1])) - else: - assert isinstance(B, torch.Tensor) - B = B.expand(X.shape[0], B.shape[-1]) - new_inputs[1] = W - if B is not None: - new_inputs[2] = B + assert isinstance(W, torch.Tensor) + new_inputs[1] = W.transpose(0, 1) return new_inputs, layout_or_out # TODO(jgong5): decide proper number of threads per problem size num_threads = parallel_num_threads() - new_inputs, _ = normalize_shapes( - *maybe_to_dense(*reorder_and_filter(input_nodes, layout)) - ) + new_inputs, _ = transpose_weight(*reorder_and_filter(input_nodes, layout)) m, n, k, *_ = mm_args(new_inputs[0], new_inputs[1]) micro_gemm = create_micro_gemm( - "micro_gemm", - m, - n, - k, - input_dtype=layout.dtype, - output_dtype=torch.float, - alpha=alpha, - num_threads=num_threads, + "micro_gemm", m, n, k, layout.dtype, alpha=alpha, num_threads=num_threads ) assert micro_gemm is not None _, block_n, _ = micro_gemm.register_blocking @@ -337,9 +300,7 @@ def pack_weight(inputs, layout_or_out): return new_inputs, layout_or_out def preprocessor(inputs, layout): - return pack_weight( - *normalize_shapes(*maybe_to_dense(*reorder_and_filter(inputs, layout))) - ) + return pack_weight(*transpose_weight(*reorder_and_filter(inputs, layout))) def postprocessor(output): if isinstance(output, ir.TensorBox): @@ -354,7 +315,7 @@ def postprocessor(output): W = V.graph.constants[W_node.get_name()] new_input_nodes[1] = W new_input_nodes, _ = pack_weight( - *normalize_shapes(*maybe_to_dense(new_input_nodes, layout)) + *transpose_weight(new_input_nodes, layout) ) W_packed = new_input_nodes[1] W_packed_constant = V.graph.add_tensor_constant(W_packed) @@ -397,7 +358,8 @@ def render( # type: ignore[override] template_buffer = Y Y_is_transposed = False - use_local_acc = self.layout.dtype != torch.float + # TODO(jgong5): support local accumulation + use_local_acc = False if epilogue_nodes: Y = cast(ir.Buffer, epilogue_nodes[-1]) assert Y.get_name() in V.kernel.inplace_update_buffers @@ -411,8 +373,7 @@ def render( # type: ignore[override] self.m, self.n, self.k, - input_dtype=self.layout.dtype, - output_dtype=torch.float, + self.layout.dtype, alpha=self.alpha, num_threads=self.num_threads, ) diff --git a/torch/_inductor/codegen/cpp_micro_gemm.py b/torch/_inductor/codegen/cpp_micro_gemm.py index c5e989eb2eed..649782ff158d 100644 --- a/torch/_inductor/codegen/cpp_micro_gemm.py +++ b/torch/_inductor/codegen/cpp_micro_gemm.py @@ -59,11 +59,7 @@ def __init__( def get_common_options(self): return { - "torch": torch, "kernel_name": self.name, - "input_dtype": self.input_dtype, - "output_dtype": self.output_dtype, - "compute_dtype": self.compute_dtype, "input_t": DTYPE_TO_CPP[self.input_dtype], "output_t": DTYPE_TO_CPP[self.output_dtype], "compute_t": DTYPE_TO_CPP[self.compute_dtype], @@ -140,29 +136,6 @@ def inner(cls): return inner -def generate_gemm_config( - vec_isa_cls, - register_blockings, - input_dtype=torch.float, - output_dtype=None, - compute_dtype=None, -): - if output_dtype is None: - output_dtype = input_dtype - if compute_dtype is None: - compute_dtype = output_dtype - return [ - CppMicroGemmConfig( - input_dtype, - output_dtype, - compute_dtype, - vec_isa_cls, - GemmBlocking(*blocking), - ) - for blocking in register_blockings - ] - - class CppMicroGemmRef(CppMicroGemm): """ A reference implementation of the CppMicroGemm class with naive C++ code. @@ -197,41 +170,28 @@ def codegen_define(self, kernel: CppTemplateKernel) -> str: @register_micro_gemm( - *generate_gemm_config( - VecAVX512, [(8, 48, 1), (8, 32, 1), (16, 16, 1)], input_dtype=torch.float + CppMicroGemmConfig( + torch.float32, torch.float32, torch.float32, VecAVX512, GemmBlocking(8, 48, 1) ), - *generate_gemm_config( - VecAVX512, - [(8, 48, 1), (8, 32, 1), (16, 16, 1)], - input_dtype=torch.bfloat16, - output_dtype=torch.float, + CppMicroGemmConfig( + torch.float32, torch.float32, torch.float32, VecAVX512, GemmBlocking(8, 32, 1) ), - *generate_gemm_config( - VecAVX512, - [(8, 48, 1), (8, 32, 1), (16, 16, 1)], - input_dtype=torch.half, - output_dtype=torch.float, + CppMicroGemmConfig( + torch.float32, torch.float32, torch.float32, VecAVX512, GemmBlocking(16, 16, 1) ), - *generate_gemm_config( - VecAVX2, [(4, 24, 1), (4, 16, 1), (8, 8, 1)], input_dtype=torch.float + CppMicroGemmConfig( + torch.float32, torch.float32, torch.float32, VecAVX2, GemmBlocking(4, 24, 1) ), - *generate_gemm_config( - VecAVX2, - [(4, 24, 1), (4, 16, 1), (8, 8, 1)], - input_dtype=torch.bfloat16, - output_dtype=torch.float, + CppMicroGemmConfig( + torch.float32, torch.float32, torch.float32, VecAVX2, GemmBlocking(4, 16, 1) ), - *generate_gemm_config( - VecAVX2, - [(4, 24, 1), (4, 16, 1), (8, 8, 1)], - input_dtype=torch.half, - output_dtype=torch.float, + CppMicroGemmConfig( + torch.float32, torch.float32, torch.float32, VecAVX2, GemmBlocking(8, 8, 1) ), ) class CppMicroGemmFP32Vec(CppMicroGemm): """ - This class generates the code for micro gemm using fp32 vec instructions for compute. - It supports input types of torch.float, torch.bfloat16, and torch.half with fp32 output. + This class generates the code for fp32 micro gemm using vec instructions. """ TEMPLATE_ENTRY = r""" @@ -279,23 +239,22 @@ class CppMicroGemmFP32Vec(CppMicroGemm): TEMPLATE_KERNEL = r""" template inline void {{kernel_name}}_kernel( - const {{input_t}}* __restrict__ A, - const {{input_t}}* __restrict__ B, - {{output_t}}* __restrict__ C, + const float* __restrict__ A, + const float* __restrict__ B, + float* __restrict__ C, int64_t K, int64_t lda, int64_t ldb, int64_t ldc ) { - using Vectorized = at::vec::Vectorized<{{compute_t}}>; - using VectorizedIn = at::vec::Vectorized<{{input_t}}>; + using Vectorized = at::vec::Vectorized; constexpr auto VLEN = Vectorized::size(); constexpr auto ROWS = BLOCK_M; constexpr auto COLS = BLOCK_N / VLEN; Vectorized va; - at::vec::VectorizedN<{{compute_t}}, COLS> vb; - at::vec::VectorizedN<{{compute_t}}, ROWS*COLS> vc; + at::vec::VectorizedN vb; + at::vec::VectorizedN vc; auto loadc = [&](auto i) { if constexpr (accum) { @@ -314,19 +273,14 @@ class CppMicroGemmFP32Vec(CppMicroGemm): if constexpr (col == 0) { {%- if alpha != 1 %} - va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + k]) * {{alpha}}); + va = Vectorized(A[row * lda + k] * {{alpha}}); {%- else %} - va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + k])); + va = Vectorized(A[row * lda + k]); {%- endif %} } if constexpr (row == 0) { - {%- if input_dtype == torch.bfloat16 or input_dtype == torch.float16 %} - auto b = VectorizedIn::loadu(B + k * ldb + col * VLEN, VLEN); - vb[col] = at::vec::convert<{{compute_t}}>(b); - {%- else %} vb[col] = Vectorized::loadu(B + k * ldb + col * VLEN); - {%- endif %} } constexpr int idx = row * COLS + col; @@ -395,7 +349,7 @@ def create_from_config(cls, config: CppMicroGemmConfig): if output_dtype is None: output_dtype = input_dtype if compute_dtype is None: - compute_dtype = output_dtype + compute_dtype = input_dtype if num_threads < 0: num_threads = parallel_num_threads() vec_isa = pick_vec_isa() diff --git a/torch/_inductor/codegen/cpp_template_kernel.py b/torch/_inductor/codegen/cpp_template_kernel.py index f1d4fbaaac33..5a6c6969b20c 100644 --- a/torch/_inductor/codegen/cpp_template_kernel.py +++ b/torch/_inductor/codegen/cpp_template_kernel.py @@ -13,7 +13,7 @@ from ..virtualized import V from .common import Kernel, OpOverrides from .cpp import CppKernelProxy, KernelGroup -from .cpp_utils import cexpr_index, DTYPE_TO_CPP, LocalBufferScope +from .cpp_utils import cexpr_index, DTYPE_TO_CPP def parse_expr_with_index_symbols(expr): @@ -110,13 +110,7 @@ def index(self, node: ir.Buffer, indices: List[Any]) -> str: indexer = node.layout.as_fixed().make_indexer() index = indexer(parse_expr_with_index_symbols(indices)) index = self.rename_indexing(index) - outer_name = node.get_name() - inner_name = ( - outer_name - if outer_name in self.local_buffers - else self.args.input(node.get_name()) - ) - return f"{inner_name}[{cexpr_index(index)}]" + return f"{self.args.input(node.get_name())}[{cexpr_index(index)}]" def slice_nd(self, node, ranges: List[Tuple[Any, Any]]) -> ir.ReinterpretView: """ @@ -175,50 +169,6 @@ def define_buffer(self, name, sizes: List[Any], dtype=torch.float) -> str: numel = f"{cexpr_index(buf.get_numel())}" return f"auto _{name} = std::make_unique<{ctype}[]>({numel}); auto {name} = _{name}.get();" - def store_pointwise_nodes( - self, - dst: ir.Buffer, - nodes: List[ir.IRNode], - offsets: Optional[List[sympy.Expr]] = None, - reindexer: Optional[Callable[[List[Any]], List[Any]]] = None, - ) -> str: - var_sizes = (tuple(dst.get_size()), ()) - var_ranges = {sympy.Symbol(f"z{i}"): sz for i, sz in enumerate(var_sizes[0])} - if not offsets: - offsets = [sympy.Integer(0)] * len(var_sizes[0]) - assert len(offsets) == len(var_sizes[0]) - output_index = dst.get_layout().make_indexer()(var_ranges.keys()) - kernel_group = KernelGroup() - kernel_group.args = self.args - cpp_kernel_proxy = CppKernelProxy(kernel_group) - bodies = [] - var_sizes_list = [] - for i, node in enumerate(nodes): - output_name = node.get_name() if i < len(nodes) - 1 else dst.get_name() - node = node.data if isinstance(node, ir.ComputedBuffer) else node - assert isinstance(node, ir.Pointwise), node - - def fn(*args): - assert len(args) == 2 - assert len(args[0]) == len(var_sizes[0]) - assert len(args[1]) == 0 - new_args = [arg + offset for arg, offset in zip(args[0], offsets)] # type: ignore[arg-type] - if reindexer is not None: - new_args = reindexer(new_args) - V.ops.store( - output_name, - output_index, - node.make_loader()(new_args).value, - ) - - body = ir.LoopBody(fn, (list(var_ranges.keys()), ()), var_ranges) - bodies.append(body) - var_sizes_list.append(var_sizes) - - cpp_kernel_proxy.codegen_loop_bodies(bodies, var_sizes_list) - kernel_group.finalize_kernel(cpp_kernel_proxy, []) - return kernel_group.loops_code.getvalue() - def store_output( self, dst: ir.Buffer, @@ -246,20 +196,55 @@ def store_output( needed on the indices to `epilogue_nodes` to match the indexing of `dst`. """ assert dst.get_size() == src.get_size() - if offsets: - offsets = parse_expr_with_index_symbols(offsets) if epilogue_nodes: - return self.store_pointwise_nodes(dst, epilogue_nodes, offsets, reindexer) + var_sizes = (tuple(dst.get_size()), ()) + var_ranges = { + sympy.Symbol(f"z{i}"): sz for i, sz in enumerate(var_sizes[0]) + } + + # epilogues are all pointwises, hence all indexed the same way as dst + output_index = dst.get_layout().make_indexer()(var_ranges.keys()) + + if not offsets: + offsets = [0] * len(var_sizes[0]) + assert len(offsets) == len(var_sizes[0]) + offsets = parse_expr_with_index_symbols(offsets) + + kernel_group = KernelGroup() + kernel_group.args = self.args + cpp_kernel_proxy = CppKernelProxy(kernel_group) + bodies = [] + var_sizes_list = [] + for i, node in enumerate(epilogue_nodes): + assert isinstance(node, ir.ComputedBuffer) + output_name = ( + node.get_name() if i < len(epilogue_nodes) - 1 else dst.get_name() + ) + + def fn(*args): + assert len(args) == 2 + assert len(args[0]) == len(var_sizes[0]) + assert len(args[1]) == 0 + new_args = [arg + offset for arg, offset in zip(args[0], offsets)] # type: ignore[arg-type] + if reindexer is not None: + new_args = reindexer(new_args) + V.ops.store( + output_name, + output_index, + node.data.make_loader()(new_args).value, + ) + + body = ir.LoopBody(fn, (list(var_ranges.keys()), ()), var_ranges) + bodies.append(body) + var_sizes_list.append(var_sizes) + + cpp_kernel_proxy.codegen_loop_bodies(bodies, var_sizes_list) + kernel_group.finalize_kernel(cpp_kernel_proxy, []) + return kernel_group.loops_code.getvalue() else: - if dst.get_name() != src.get_name(): - # src is local - copy = L.copy(dst, src).data.data - with LocalBufferScope(self) as scope: - scope.add_local_buffer(src) - return self.store_pointwise_nodes(dst, [copy]) - else: - assert dst.layout == src.layout - return "" + # TODO(jgong5): support local acc buffer to avoid assertion below + assert dst.get_name() == src.get_name() and dst.layout == src.layout + return "" class CppTemplateCaller(ir.ChoiceCaller): diff --git a/torch/_inductor/codegen/cpp_utils.py b/torch/_inductor/codegen/cpp_utils.py index dbe3daf1c45c..4ab33a5e26dc 100644 --- a/torch/_inductor/codegen/cpp_utils.py +++ b/torch/_inductor/codegen/cpp_utils.py @@ -1,15 +1,10 @@ -import contextlib import math from collections import namedtuple -from typing import Dict -from unittest.mock import patch import torch -from .. import ir -from ..virtualized import V -from .common import ExprPrinter, Kernel +from .common import ExprPrinter DTYPE_TO_CPP = { torch.float32: "float", @@ -246,58 +241,3 @@ def value_to_cpp(value, cpp_type): return f"std::numeric_limits<{cpp_type}>::quiet_NaN()" else: return f"static_cast<{cpp_type}>({repr(value)})" - - -class LocalBufferScope: - """ - This class creates a context that helps to generate code involving Inductor IR with - function local buffers. These buffers are constructed during the codegen process and - are used to store intermediate results such as local accumulators. We do not want to - add them to `V.graph` since they are not global and we do not want to add them as - function arguments either. So we patch the codegen processes under this scope to support - these buffers without exposure to the outside world. - """ - - def __init__(self, kernel: Kernel): - self.kernel = kernel - self.exit_stack = contextlib.ExitStack() - self.local_buffers: Dict[str, ir.Buffer] = {} - - def __enter__(self): - self.exit_stack.__enter__() - original_get_dtype = V.graph.get_dtype - - def get_dtype(name): - if name in self.local_buffers: - return self.local_buffers[name].get_dtype() - return original_get_dtype(name) - - self.exit_stack.enter_context(patch.object(V.graph, "get_dtype", get_dtype)) - - original_input = self.kernel.args.input - - def input(name): - if name in self.local_buffers: - return name - return original_input(name) - - self.exit_stack.enter_context(patch.object(self.kernel.args, "input", input)) - - original_output = self.kernel.args.output - - def output(name): - if name in self.local_buffers: - return name - return original_output(name) - - self.exit_stack.enter_context(patch.object(self.kernel.args, "output", output)) - - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.local_buffers.clear() - self.exit_stack.__exit__(exc_type, exc_val, exc_tb) - - def add_local_buffer(self, buffer: ir.Buffer): - assert buffer.get_name() not in self.local_buffers - self.local_buffers[buffer.get_name()] = buffer diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 696641533500..e2209cd3472d 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -6318,7 +6318,7 @@ def codegen(self, wrapper): ) @classmethod - def create(cls, x, w, B, attr, scalars, algorithm): + def create(cls, x, w, b, attr, scalars, algorithm): x = cls.require_contiguous(cls.realize_input(x)) w = cls.require_contiguous(cls.realize_input(w)) @@ -6326,9 +6326,9 @@ def create(cls, x, w, B, attr, scalars, algorithm): oc, ic = w.get_size() inputs = [x, w] constant_args = [attr, scalars if scalars else [-1], algorithm] - if B is not None: - B = cls.require_contiguous(cls.realize_input(B)) - inputs.append(B) + if b is not None: + b = cls.require_contiguous(cls.realize_input(b)) + inputs.append(b) else: constant_args.insert(0, None) diff --git a/torch/_inductor/mkldnn_lowerings.py b/torch/_inductor/mkldnn_lowerings.py index f9f2e66ab68c..1f64574d589b 100644 --- a/torch/_inductor/mkldnn_lowerings.py +++ b/torch/_inductor/mkldnn_lowerings.py @@ -13,25 +13,14 @@ permute, register_lowering, to_dtype, - view, -) -from .select_algorithm import ( - autotune_select_algorithm, - ChoiceCaller, - ExternKernelChoice, ) +from .select_algorithm import autotune_select_algorithm, ExternKernelChoice from .utils import use_aten_gemm_kernels, use_cpp_packed_gemm_template, use_max_autotune from .virtualized import V def register_onednn_fusion_ops(): if torch._C._has_mkldnn: - aten_mkldnn_linear_unary = ExternKernelChoice( - torch.ops.mkldnn._linear_pointwise, - "mkldnn::_linear_pointwise", - has_out_variant=False, - kernel_creator=ir.LinearUnary.create, - ) cpu_needs_realized_inputs = [ torch.ops.mkldnn._convolution_pointwise, torch.ops.mkldnn._convolution_pointwise_, @@ -139,77 +128,11 @@ def convolution_binary_inplace( @register_lowering(torch.ops.mkldnn._linear_pointwise) def linear_unary( - x: TensorBox, - w: TensorBox, - b: TensorBox, - attr, - scalars, - algorithm, - layout=None, + x: TensorBox, w: TensorBox, b: TensorBox, attr, scalars, algorithm ): - x_size = x.get_size() - if len(x_size) > 2: - # GEMM template needs 2D input, normalize input shape here - x = view(x, [-1, x_size[-1]]) - choices: List[ChoiceCaller] = [] - if len(choices) == 0 or use_aten_gemm_kernels(): - choices.append( - aten_mkldnn_linear_unary.bind( - (x, w), - layout, - B=None, - attr=attr, - scalars=scalars, - algorithm=algorithm, - ) - if b is None - else aten_mkldnn_linear_unary.bind( - (x, w, b), - layout, - attr=attr, - scalars=scalars, - algorithm=algorithm, - ) - ) - if use_max_autotune(): - transposed_w = permute(w, [1, 0]) - *_, layout, x, transposed_w = mm_args(x, transposed_w, layout=layout) - if b is not None: - b = ir.ExternKernel.realize_input(b) - # TODO(jgong5): support epilogue fusion - if ( - use_cpp_packed_gemm_template(layout, x, transposed_w) - and attr == "none" - ): - if b is None: - CppPackedGemmTemplate.add_choices( - choices, - layout, - [x, w], - trans_w=True, - ) - else: - CppPackedGemmTemplate.add_choices( - choices, - layout, - [x, w, b], - trans_w=True, - input_indices=[2, 0, 1], - ) - assert w.get_name() in V.graph.constants - input_gen_fns = { - 1: lambda x: V.graph.constants[x.get_name()], - } - result = autotune_select_algorithm( - "linear_unary", - choices, - [x, w] if b is None else [x, w, b], - layout, - input_gen_fns=input_gen_fns, + return TensorBox.create( + ir.LinearUnary.create(x, w, b, attr, scalars, algorithm) ) - if len(x_size) > 2: - result = view(result, (*x_size[:-1], result.get_size()[-1])) - return result @register_lowering(torch.ops.mkldnn._linear_pointwise.binary) def linear_binary(x: TensorBox, y: TensorBox, w: TensorBox, b: TensorBox, attr): @@ -510,7 +433,15 @@ def mkl_packed_linear( *, layout=None, ): - choices: List[ChoiceCaller] = [] + choices = ( + [ + aten_mkl_linear.bind( + (x, packed_w, orig_w), layout, B=None, batch_size=batch_size + ) + ] + if use_aten_gemm_kernels() + else [] + ) if use_max_autotune(): transposed_w = permute(orig_w, [1, 0]) *_, layout, x, transposed_w = mm_args( @@ -525,13 +456,6 @@ def mkl_packed_linear( input_indices=[0, 2], ) - if len(choices) == 0 or use_aten_gemm_kernels(): - choices.append( - aten_mkl_linear.bind( - (x, packed_w, orig_w), layout, B=None, batch_size=batch_size - ) - ) - assert packed_w.get_name() in V.graph.constants assert orig_w.get_name() in V.graph.constants # packed_w is a mkldnn tensor which we can't generate directly diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 8b66b496fd43..7a96630ef213 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1025,7 +1025,7 @@ def use_cpp_packed_gemm_template(layout, mat1, mat2): if not config.cpp.weight_prepack: return False - layout_dtypes = [torch.float32, torch.bfloat16, torch.half] + layout_dtypes = [torch.float32] m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2) # TODO(jgong5): support dynamic shapes for n or k if has_free_symbols((n, k)): @@ -1033,13 +1033,7 @@ def use_cpp_packed_gemm_template(layout, mat1, mat2): if isinstance(mat2, ir.BaseView): mat2 = mat2.unwrap_view() micro_gemm = create_micro_gemm( - "micro_gemm", - m, - n, - k, - input_dtype=layout.dtype, - output_dtype=torch.float, - num_threads=parallel_num_threads(), + "micro_gemm", m, n, k, layout.dtype, num_threads=parallel_num_threads() ) # TODO(jgong5): support n % n_block_size != 0 return ( From aaef7b29e9cc4a71463ab2049267f512ffa4d929 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Thu, 30 May 2024 21:43:46 -0400 Subject: [PATCH 161/706] Only register _inductor_test ops if not running with deploy (#127557) Internal xref: https://fb.workplace.com/groups/1405155842844877/posts/8498194410207616 Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/127557 Approved by: https://github.com/zou3519 --- torch/_inductor/test_operators.py | 33 ++++++++++++++++--------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/torch/_inductor/test_operators.py b/torch/_inductor/test_operators.py index e8421722568c..8e85f8bebbdb 100644 --- a/torch/_inductor/test_operators.py +++ b/torch/_inductor/test_operators.py @@ -2,23 +2,24 @@ from torch import Tensor from torch.autograd import Function -_test_lib_def = torch.library.Library("_inductor_test", "DEF") -_test_lib_def.define("realize(Tensor self) -> Tensor", tags=torch.Tag.pt2_compliant_tag) +if not torch._running_with_deploy(): + _test_lib_def = torch.library.Library("_inductor_test", "DEF") + _test_lib_def.define( + "realize(Tensor self) -> Tensor", tags=torch.Tag.pt2_compliant_tag + ) -_test_lib_impl = torch.library.Library("_inductor_test", "IMPL") -for dispatch_key in ("CPU", "CUDA", "Meta"): - _test_lib_impl.impl("realize", lambda x: x.clone(), dispatch_key) + _test_lib_impl = torch.library.Library("_inductor_test", "IMPL") + for dispatch_key in ("CPU", "CUDA", "Meta"): + _test_lib_impl.impl("realize", lambda x: x.clone(), dispatch_key) + class Realize(Function): + @staticmethod + def forward(ctx, x): + return torch.ops._inductor_test.realize(x) -class Realize(Function): - @staticmethod - def forward(ctx, x): - return torch.ops._inductor_test.realize(x) + @staticmethod + def backward(ctx, grad_output): + return grad_output - @staticmethod - def backward(ctx, grad_output): - return grad_output - - -def realize(x: Tensor) -> Tensor: - return Realize.apply(x) + def realize(x: Tensor) -> Tensor: + return Realize.apply(x) From 413b81789f05a33a349452fda5f8841ad362a110 Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Thu, 30 May 2024 12:45:26 -0700 Subject: [PATCH 162/706] [AOTI][refactor] Unify val_to_arg_str and val_to_cpp_arg_str (#126916) Summary: Now fallback argument type information has been passed, so time to unify val_to_arg_str and val_to_cpp_arg_str Differential Revision: [D57907751](https://our.internmc.facebook.com/intern/diff/D57907751) Pull Request resolved: https://github.com/pytorch/pytorch/pull/126916 Approved by: https://github.com/chenyang78 --- test/inductor/test_cpu_cpp_wrapper.py | 1 + torch/_inductor/codegen/cpp_wrapper_cpu.py | 245 ++++++++++----------- torch/_inductor/codegen/wrapper.py | 3 - torch/_inductor/ir.py | 53 +++-- 4 files changed, 156 insertions(+), 146 deletions(-) diff --git a/test/inductor/test_cpu_cpp_wrapper.py b/test/inductor/test_cpu_cpp_wrapper.py index 10744a675fbf..0f7430ad2696 100644 --- a/test/inductor/test_cpu_cpp_wrapper.py +++ b/test/inductor/test_cpu_cpp_wrapper.py @@ -94,6 +94,7 @@ class DynamicShapesCppWrapperCpuTests(InductorTestCase): "test_qconv2d_maxpool2d_linear_dynamic_cpu", "test_qconv2d_relu_cpu", "test_qlinear_cpu", + "test_qlinear_add_cpu", "test_qlinear_dequant_promotion_cpu", "test_qlinear_relu_cpu", ] diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 90a0702d4b8a..dff0844371c2 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -1541,8 +1541,8 @@ def codegen_int_array_var( writer=None, known_statically=False, graph=None, # for per-graph caching - is_bool=False, ): + # This is used for size/stride declaration # Because the memory planning is done in two passes (see the implementation # of self.generate), the writeline behavior is different in the two passes. # As a result, the emitted int array declarations may appear in a later @@ -1553,7 +1553,7 @@ def codegen_int_array_var( writer = self var = f"int_array_{next(self.int_array_id)}" - ctype = "int32_t" if is_bool else "int64_t" + ctype = "int64_t" if var not in self.declared_int_array_vars: self.declared_int_array_vars.add(var) if known_statically: @@ -1562,43 +1562,6 @@ def codegen_int_array_var( writer.writeline(f"const {ctype} {var}[] = {int_array};") return var - @functools.lru_cache(None) - def codegen_var_array( - self, - var_array: str, - writer=None, - known_statically=False, - graph=None, # for per-graph caching - type_hint=None, # ['int64_t', 'tensor', 'bool'] - ): - # Because the memory planning is done in two passes (see the implementation - # of self.generate), the writeline behavior is different in the two passes. - # As a result, the emitted int array declarations may appear in a later - # position of the generated code, so the second pass codegen should not - # reuse int array declarations generated in the first pass - if writer is None: - # The first pass codegen uses `self` as the writer - writer = self - if not type_hint or type_hint in ["bool", "int64_t"]: - return self.codegen_int_array_var( - var_array, - writer, - known_statically, - graph, - is_bool=type_hint == "bool", - ) - - var = f"var_array_{next(self.var_array_id)}" - assert type_hint == "tensor" - ctype = "AtenTensorHandle*" - if var not in self.declared_var_array_vars: - self.declared_var_array_vars.add(var) - if known_statically: - writer.writeline(f"static constexpr {ctype} {var}[] = {var_array};") - else: - writer.writeline(f"const {ctype} {var}[] = {var_array};") - return var - def make_buffer_allocation(self, buffer): return self.make_allocation( buffer.get_name(), @@ -2335,65 +2298,29 @@ def generate_reset_kernel_saved_flags(self): def generate_save_uncompiled_kernels(self): pass - def val_to_cpp_arg_str(self, val, type_) -> str: - if config.abi_compatible and isinstance(type_, torch.OptionalType): - if val is None: - return "0" # nullptr is not available in C - if not isinstance(type_.getElementType(), torch.TensorType): - var_name = f"var_{next(self.arg_var_id)}" - if isinstance( - type_.getElementType(), - (torch.ListType, torch.TupleType, torch.DeviceObjType), - ): - arg_str = self.val_to_arg_str(val) - if val is None: - return "{arg_str}, 0" - else: - # For datatypes with auxiliary info, we need to hoist out the extra arguments. - # NOTE: This only works if there is one additional argument, though it can easily be generalized. - main_value, aux = arg_str.rsplit(", ") - self.writeline(f"auto {var_name} = {main_value};") - return f"&{var_name}, {aux}" - else: - self.writeline(f"auto {var_name} = {self.val_to_arg_str(val)};") - return f"&{var_name}" - elif config.c_shim_version == "2": - # Similar to other data type, use pointer to denote optional tensor arg in v2 C shim - base_handle = self.val_to_arg_str(val) - if "wrap_with_raii_handle_if_needed" in base_handle: - # wrap_with_raii_handle_if_needed creates a temp RAIIAtenTensorHandle, so we need to - # explicitly store it. Otherwise, it will be destroyed before the fallback kernel call. - tmp_var_name = f"var_{next(self.arg_var_id)}" - self.writeline( - f"RAIIAtenTensorHandle {tmp_var_name} = {base_handle};" - ) - base_handle = tmp_var_name - var_name = f"var_{next(self.arg_var_id)}" - self.writeline(f"AtenTensorHandle {var_name} = {base_handle}.get();") - return f"&{var_name}" - - return self.val_to_arg_str(val, type_) + def c_type_for_prim_type(self, type_) -> str: + assert ( + config.abi_compatible + ), "c_type_for_prim_type is only used in ABI compatible mode" + if isinstance(type_, torch.OptionalType): + return f"{self.c_type_for_prim_type(type_.getElementType())}*" + elif isinstance(type_, torch.TensorType): + return "AtenTensorHandle" + elif isinstance(type_, (torch.IntType, torch.SymIntType)): + return "int64_t" + elif ( + isinstance(type_, (torch.BoolType, torch.SymBoolType, torch.EnumType)) + or repr(type_) == "ScalarType" + ): + return "int32_t" + elif isinstance(type_, torch.FloatType): + return "double" + else: + raise AssertionError(f"Unexpected type in c_type_for_prim_type: {type_=}") - def val_to_arg_str(self, val, type_=None) -> str: - if val is None: - # When None is passed as an argument, it represents an optional that does not contain a value. - if config.abi_compatible: - if type_ is None or isinstance(type_, torch.OptionalType): - return "0" # nullptr is not available in C - elif isinstance(type_, torch.TensorType): - var_name = f"var_{next(self.arg_var_id)}" - self.writeline(f"AtenTensorHandle {var_name}_handle;") - self.writeline( - f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&{var_name}_handle));" - ) - self.writeline( - f"RAIIAtenTensorHandle {var_name}({var_name}_handle);" - ) - return var_name - else: - raise AssertionError("Can not map None to a known data type") - return "c10::nullopt" - elif isinstance(val, bool): + def val_to_arg_str_for_prim_type(self, val, type_) -> str: + # TODO: not using type_ as the first step of refactoring. Will update this later. + if isinstance(val, bool): if config.abi_compatible: return "1" if val else "0" else: @@ -2417,34 +2344,104 @@ def val_to_arg_str(self, val, type_=None) -> str: else: return "-std::numeric_limits::infinity()" elif isinstance(val, (list, tuple)): - # FIXME handle embedded optional types? - result = f"{{{', '.join(self.val_to_arg_str(x) for x in val)}}}" + # FIXME: This happens because type_ is not always properly set to torch.ListType + return f"{{{', '.join(self.val_to_arg_str(x, None) for x in val)}}}" + else: + return repr(val) + + def val_to_arg_str(self, val, type_=None) -> str: + if val is None: + # None needs special care. It either represent nullopt or an empty tensor if config.abi_compatible: - assert len(val) > 0, "Empty array is not supported in C" - static = self.is_statically_known_list_of_ints(val) - type_hint = "bool" if isinstance(val[0], bool) else "int64_t" - if ( - type_ is not None - and isinstance(type_, torch._C.ListType) - and isinstance(type_.getElementType(), torch._C.OptionalType) - and isinstance( - type_.getElementType().getElementType(), torch._C.TensorType + if type_ is None or isinstance(type_, torch.OptionalType): + if type_ is not None and isinstance( + type_.getElementType(), + ( + torch.ListType, + torch.TupleType, + torch.DeviceObjType, + ), + ): + return "0, 0" + else: + return "0" # nullptr is not available in C + elif isinstance(type_, torch.TensorType): + # create an empty tensor, the equivalent of at::Tensor() + var_name = f"var_{next(self.arg_var_id)}" + self.writeline(f"AtenTensorHandle {var_name}_handle;") + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&{var_name}_handle));" ) - ): - type_hint = "tensor" - tmp_arg_list = "" - for x in val: - tmp_arg_list += f"&{x}_handle, " - result = f"{{{tmp_arg_list}}}" - # Need to pass the array length because we can't use std::vector - var_array = self.codegen_var_array( - result, - known_statically=static, - graph=self.get_codegened_graph(), - type_hint=type_hint, + self.writeline( + f"RAIIAtenTensorHandle {var_name}({var_name}_handle);" + ) + return var_name + else: + raise AssertionError("Can not map None to a known data type") + else: + if isinstance(type_, torch.TensorType): + var_name = f"var_{next(self.arg_var_id)}" + self.writeline(f"at::Tensor {var_name} = at::Tensor();") + return var_name + else: + return "std::nullopt" + + if isinstance(type_, torch.OptionalType): + element_type = type_.getElementType() + if config.abi_compatible: + if not isinstance(element_type, torch.TensorType): + var_name = f"var_{next(self.arg_var_id)}" + if isinstance( + element_type, + (torch.ListType, torch.TupleType, torch.DeviceObjType), + ): + # type_ is something like Optional[List] or Optional[Device] + arg_str = self.val_to_arg_str(val, element_type) + # For datatypes with auxiliary info, we need to hoist out the extra arguments. + # NOTE: This only works if there is one additional argument, though it can easily be generalized. + main_value, aux = arg_str.rsplit(", ") + self.writeline(f"auto {var_name} = {main_value};") + return f"&{var_name}, {aux}" + else: + self.writeline( + f"{self.c_type_for_prim_type(element_type)} {var_name} = {self.val_to_arg_str(val, element_type)};" + ) + return f"&{var_name}" + elif config.c_shim_version == "2": + # type_ is Optional[Tensor] + # Similar to other data type, use pointer to denote optional tensor arg in v2 C shim + base_handle = self.val_to_arg_str(val, element_type) + if "wrap_with_raii_handle_if_needed" in base_handle: + # wrap_with_raii_handle_if_needed creates a temp RAIIAtenTensorHandle, so we need to + # explicitly store it. Otherwise, it will be destroyed before the fallback kernel call. + tmp_var_name = f"var_{next(self.arg_var_id)}" + self.writeline( + f"RAIIAtenTensorHandle {tmp_var_name} = {base_handle};" + ) + base_handle = tmp_var_name + var_name = f"var_{next(self.arg_var_id)}" + self.writeline( + f"AtenTensorHandle {var_name} = {base_handle}.get();" + ) + return f"&{var_name}" + else: + return self.val_to_arg_str(val, element_type) + + elif isinstance(type_, torch.ListType): + assert isinstance( + val, (list, tuple) + ), f"{val} does not match with arg type {type_}" + element_type = type_.getElementType() + if config.abi_compatible: + assert len(val) > 0, "Empty array is not supported in C" + var_name = f"var_array_{next(self.var_array_id)}" + result = f"{{{', '.join(self.val_to_arg_str(x, element_type) for x in val)}}}" + self.writeline( + f"const {self.c_type_for_prim_type(element_type)} {var_name}[] = {result};" ) - return f"{var_array}, {len(val)}" + # Need to pass the array length because we can't use std::vector + return f"{var_name}, {len(val)}" else: - return result - else: - return repr(val) + return f"{{{', '.join(self.val_to_arg_str(x, element_type) for x in val)}}}" + + return self.val_to_arg_str_for_prim_type(val, type_) diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 030c73833a0e..19bb7bf3c25e 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -1411,9 +1411,6 @@ def writelines(self, lines): def enter_context(self, ctx): self.lines.append(LineContext(ctx)) - def val_to_cpp_arg_str(self, val, type_) -> str: - raise NotImplementedError - def val_to_arg_str(self, s, type_=None): from torch.utils._triton import dtype_to_string, has_triton_package diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index e2209cd3472d..90127f35363f 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -3992,13 +3992,6 @@ def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: def collect_arg_kwarg_properties(self): # if self.op_overload is torch._ops.OpOverload, we can use its schema to collect additional # information for args and kwargs, e.g. type and default value, to help with the cpp wrapper codegen - if ( - isinstance(self.op_overload, torch._ops.OpOverload) - and not self.ordered_kwargs_for_cpp_kernel - ): - self.ordered_kwargs_for_cpp_kernel = [ - x.name for x in self.op_overload._schema.arguments if x.kwarg_only - ] self.arg_properties = ( [ { @@ -4012,15 +4005,23 @@ def collect_arg_kwarg_properties(self): if isinstance(self.op_overload, torch._ops.OpOverload) else [{} for i in range(len(self.inputs))] ) - self.kwarg_properties = ( + self.allarg_properties = ( { x.name: {"type": x.real_type, "default_value": x.default_value} for x in self.op_overload._schema.arguments - if x.kwarg_only } if isinstance(self.op_overload, torch._ops.OpOverload) else {} ) + # FIXME: self.kwargs does not always match kwargs defined in schema, so sometimes + # ordered_kwargs_for_cpp_kernel is explicilty passed in. + if ( + isinstance(self.op_overload, torch._ops.OpOverload) + and not self.ordered_kwargs_for_cpp_kernel + ): + self.ordered_kwargs_for_cpp_kernel = [ + x.name for x in self.op_overload._schema.arguments if x.kwarg_only + ] def fill_non_provided_args(self, args, kwargs, convert_val_to_str=False): # Previously, we want to maintain forward-compatibility by skipping @@ -4382,7 +4383,21 @@ def apply_constraint(self): pass def codegen_const_args(self): - return map(V.graph.wrapper_code.val_to_arg_str, self.constant_args) + if V.graph.cpp_wrapper: + result = [] + for i, x in enumerate(self.constant_args): + idx = len(self.inputs) + i + type_ = ( + self.arg_properties[i].get("type") + if self.arg_properties and idx < len(self.arg_properties) + else None + ) + result.append( + V.graph.wrapper_code.val_to_arg_str(x, type_) # type: ignore[arg-type] + ) + return result + else: + return map(V.graph.wrapper_code.val_to_arg_str, self.constant_args) def codegen_args(self): args = [] @@ -4395,10 +4410,10 @@ def codegen_args(self): if V.graph.cpp_wrapper: assert self.arg_properties and i < len( self.arg_properties - ), "Invalid arg_properties accessing" + ), "Invalid access to ExternKernel.arg_properties" type_ = self.arg_properties[i].get("type") args.append( - V.graph.wrapper_code.val_to_cpp_arg_str( # type: ignore[arg-type] + V.graph.wrapper_code.val_to_arg_str( # type: ignore[arg-type] x, type_ ) ) @@ -4410,10 +4425,10 @@ def codegen_args(self): def get_kwargs_value(self, arg_name): if arg_name in self.kwargs: return self.kwargs.get(arg_name) - if self.kwarg_properties and self.kwarg_properties.get(arg_name): - return self.kwarg_properties.get(arg_name).get("default_value") # type: ignore[union-attr] + if self.allarg_properties and self.allarg_properties.get(arg_name): + return self.allarg_properties.get(arg_name).get("default_value") # type: ignore[union-attr] else: - raise AssertionError(f"{arg_name} not in self.kwarg_properties") + raise AssertionError(f"{arg_name} not in self.allarg_properties") def codegen_kwargs(self, skip_out=False): if V.graph.cpp_wrapper: @@ -4428,12 +4443,12 @@ def codegen_kwargs(self, skip_out=False): kwargs.append(v) else: type_ = ( - self.kwarg_properties.get(arg_name).get("type") # type: ignore[union-attr] - if self.kwarg_properties and arg_name in self.kwarg_properties + self.allarg_properties.get(arg_name).get("type") # type: ignore[union-attr] + if self.allarg_properties and arg_name in self.allarg_properties else None ) kwargs.append( - V.graph.wrapper_code.val_to_cpp_arg_str( # type: ignore[arg-type] + V.graph.wrapper_code.val_to_arg_str( # type: ignore[arg-type] v, type_ ) ) @@ -5416,7 +5431,7 @@ def __repr__(self): if V.graph.cpp_wrapper and isinstance(self.op_overload, torch._ops.OpOverload): args = self.fill_non_provided_args(args, kwargs) args = [ - V.graph.wrapper_code.val_to_cpp_arg_str(x, param.real_type) + V.graph.wrapper_code.val_to_arg_str(x, param.real_type) for param, x in zip(self.op_overload._schema.arguments, args) ] else: From 17c5b6508ba5d9fc380c56e2482bb9ad5c99ac5f Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Thu, 30 May 2024 12:45:27 -0700 Subject: [PATCH 163/706] [AOTI] Support _CollectiveKernel in the cpp wrapper mode (#127037) Summary: _CollectiveKernel appears in TorchBench moco training. It's a special Fallback op that requires extra care. Differential Revision: [D57911441](https://our.internmc.facebook.com/intern/diff/D57911441) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127037 Approved by: https://github.com/malfet ghstack dependencies: #126916 --- .ci/pytorch/test.sh | 2 ++ torch/_inductor/codegen/cpp_wrapper_cpu.py | 17 +++++++++-------- torch/_inductor/ir.py | 3 +++ 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 190f99204e9c..28be463fb587 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -572,6 +572,8 @@ test_inductor_torchbench_smoketest_perf() { --bfloat16 --inference --inductor --only hf_T5 --output "$TEST_REPORTS_DIR/inductor_cpp_wrapper_inference.csv" TORCHINDUCTOR_ABI_COMPATIBLE=1 TORCHINDUCTOR_CPP_WRAPPER=1 python benchmarks/dynamo/torchbench.py --device cuda --accuracy \ --bfloat16 --inference --inductor --only llama --output "$TEST_REPORTS_DIR/inductor_cpp_wrapper_inference.csv" + TORCHINDUCTOR_ABI_COMPATIBLE=1 TORCHINDUCTOR_CPP_WRAPPER=1 python benchmarks/dynamo/torchbench.py --device cuda --accuracy \ + --bfloat16 --inference --inductor --only moco --output "$TEST_REPORTS_DIR/inductor_cpp_wrapper_inference.csv" python benchmarks/dynamo/check_accuracy.py \ --actual "$TEST_REPORTS_DIR/inductor_cpp_wrapper_inference.csv" \ --expected "benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv" diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index dff0844371c2..bc157075f39b 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -2073,7 +2073,7 @@ def generate_extern_kernel_alloc_and_find_schema_if_needed( def extract_output_name(out): assert out is not None, "None, i.e. optional output is not supported" - if isinstance(out, ir.MultiOutput): + if isinstance(out, (ir.MultiOutput, ir._CollectiveKernel)): return out.get_name() elif isinstance(out, (list, tuple)): return type(out)(extract_output_name(o) for o in out) @@ -2154,17 +2154,19 @@ def generate_py_arg_inner(raw_arg, arg_type): return f"PyCapsule_New(reinterpret_cast({raw_arg.codegen_reference()}.get()), NULL, NULL)" elif isinstance(arg_type, torch.IntType): # int - return f"PyInt_FromLong({raw_arg})" + return f"PyLong_FromLongLong({raw_arg})" elif isinstance(arg_type, torch.SymIntType): # SymInt expr = ( raw_arg.node.expr if isinstance(raw_arg, torch.SymInt) else raw_arg ) - return f"PyInt_FromLong({self.expr_printer(expr)})" + return f"PyLong_FromLongLong({self.expr_printer(expr)})" elif isinstance(arg_type, torch.FloatType): return f"PyFloat_FromDouble({raw_arg})" elif isinstance(arg_type, torch.BoolType): - return f"PyBool_FromBool({raw_arg})" + return f"PyBool_FromLong({raw_arg})" + elif isinstance(arg_type, torch.StringType): + return f'PyUnicode_FromString("{raw_arg}")' else: raise NotImplementedError( f"arg type {arg_type} is not yet supported by custom_op_wrapper" @@ -2308,10 +2310,9 @@ def c_type_for_prim_type(self, type_) -> str: return "AtenTensorHandle" elif isinstance(type_, (torch.IntType, torch.SymIntType)): return "int64_t" - elif ( - isinstance(type_, (torch.BoolType, torch.SymBoolType, torch.EnumType)) - or repr(type_) == "ScalarType" - ): + elif isinstance( + type_, (torch.BoolType, torch.SymBoolType, torch.EnumType) + ) or repr(type_) in ("ScalarType", "Layout"): return "int32_t" elif isinstance(type_, torch.FloatType): return "double" diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 90127f35363f..c46cad5e41e2 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -5485,6 +5485,9 @@ def export_extern_kernel_node(self): ordered_kwargs = [ kwargs.get(key, None) for key in self.ordered_kwargs_for_cpp_kernel ] + if not V.graph.aot_mode: + # No need to serialize in the cpp wrapper JIT mode + return [*args, *ordered_kwargs] serializer = GraphModuleSerializer(None, None) # type: ignore[arg-type] named_arguments = serializer.serialize_inputs(self.op_overload, args, kwargs) # type: ignore[arg-type] From 10a92b5f84f4d2d894fa3dc6fca9be45128830b0 Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Thu, 30 May 2024 12:45:27 -0700 Subject: [PATCH 164/706] [AOTI] Fix a bool value codegen issue when calling custom ops (#127398) Summary: fixes https://github.com/pytorch/pytorch/issues/127392 Differential Revision: [D57911527](https://our.internmc.facebook.com/intern/diff/D57911527) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127398 Approved by: https://github.com/angelayi, https://github.com/chenyang78 ghstack dependencies: #126916, #127037 --- test/inductor/test_cuda_cpp_wrapper.py | 1 + torch/_inductor/codegen/cpp_wrapper_cpu.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/test/inductor/test_cuda_cpp_wrapper.py b/test/inductor/test_cuda_cpp_wrapper.py index bab01927fac6..495a6362497d 100644 --- a/test/inductor/test_cuda_cpp_wrapper.py +++ b/test/inductor/test_cuda_cpp_wrapper.py @@ -211,6 +211,7 @@ class BaseTest(NamedTuple): BaseTest("test_reduction1"), # Reduction BaseTest("test_relu"), # multiple inputs BaseTest("test_repeat_interleave_2"), + BaseTest("test_roi_align"), BaseTest("test_scalar_input"), BaseTest("test_scaled_dot_product_attention"), BaseTest("test_scaled_dot_product_efficient_attention"), diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index bc157075f39b..c24fa33676d4 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -2164,7 +2164,7 @@ def generate_py_arg_inner(raw_arg, arg_type): elif isinstance(arg_type, torch.FloatType): return f"PyFloat_FromDouble({raw_arg})" elif isinstance(arg_type, torch.BoolType): - return f"PyBool_FromLong({raw_arg})" + return f"PyBool_FromLong({1 if raw_arg else 0})" elif isinstance(arg_type, torch.StringType): return f'PyUnicode_FromString("{raw_arg}")' else: From cddb8dbebe717a13550a583546bd49a54899b094 Mon Sep 17 00:00:00 2001 From: Feny Patel Date: Fri, 31 May 2024 14:25:44 +0000 Subject: [PATCH 165/706] add workloadd events to pytorch (#127415) Summary: add workloadd events to pytorch Test Plan: CIs Differential Revision: D57914472 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127415 Approved by: https://github.com/sraikund16 --- torch/csrc/profiler/kineto_shim.cpp | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/torch/csrc/profiler/kineto_shim.cpp b/torch/csrc/profiler/kineto_shim.cpp index c2a43ce95b1d..d808555da8e4 100644 --- a/torch/csrc/profiler/kineto_shim.cpp +++ b/torch/csrc/profiler/kineto_shim.cpp @@ -47,6 +47,7 @@ const std::set kXpuTypes = { const std::set kMtiaTypes = { libkineto::ActivityType::MTIA_CCP_EVENTS, libkineto::ActivityType::MTIA_RUNTIME, + libkineto::ActivityType::MTIA_WORKLOADD, }; const std::set kPrivateUse1Types = { libkineto::ActivityType::GPU_MEMCPY, @@ -344,9 +345,7 @@ c10::DeviceType deviceTypeFromActivity(libkineto::ActivityType activity_type) { case libkineto::ActivityType::CONCURRENT_KERNEL: case libkineto::ActivityType::CUDA_SYNC: case libkineto::ActivityType::GPU_USER_ANNOTATION: - case libkineto::ActivityType::CUDA_PROFILER_RANGE: - // TODO: T151322015 - case libkineto::ActivityType::MTIA_CCP_EVENTS: { + case libkineto::ActivityType::CUDA_PROFILER_RANGE: { // PrivateUse1 kineto backend reuse above ActivityTypes, // If PrivateUse1 backend enabled, this should return // c10::DeviceType::PrivateUse1. @@ -358,6 +357,20 @@ c10::DeviceType deviceTypeFromActivity(libkineto::ActivityType activity_type) { }(); return device_type; } + // TODO: T151322015 + case libkineto::ActivityType::MTIA_CCP_EVENTS: + case libkineto::ActivityType::MTIA_WORKLOADD: { + // PrivateUse1 kineto backend reuse above ActivityTypes, + // If PrivateUse1 backend enabled, this should return + // c10::DeviceType::PrivateUse1. + c10::DeviceType device_type = []() { + if (c10::get_privateuse1_backend() != "privateuseone") { + return c10::DeviceType::PrivateUse1; + } + return c10::DeviceType::MTIA; + }(); + return device_type; + } case libkineto::ActivityType::CPU_OP: case libkineto::ActivityType::USER_ANNOTATION: case libkineto::ActivityType::EXTERNAL_CORRELATION: From 4644def4344c1e67a8364ff872bdfeac7e363098 Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Thu, 30 May 2024 15:08:29 -0700 Subject: [PATCH 166/706] Update docstring for weights_only (#127575) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127575 Approved by: https://github.com/malfet --- torch/serialization.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch/serialization.py b/torch/serialization.py index e4ad1f7e9c6e..d47a49ddf0fd 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -921,7 +921,8 @@ def load( pickle_module: module used for unpickling metadata and objects (has to match the :attr:`pickle_module` used to serialize file) weights_only: Indicates whether unpickler should be restricted to - loading only tensors, primitive types and dictionaries + loading only tensors, tensor subclasses, primitive types, dictionaries + and any types added via :func:`torch.serialization.add_safe_globals`. mmap: Indicates whether the file should be mmaped rather than loading all the storages into memory. Typically, tensor storages in the file will first be moved from disk to CPU memory, after which they are moved to the location that they were tagged with when saving, or specified by ``map_location``. This From 75e7588f47d01c3bd66aec099b537fff5d4a6d73 Mon Sep 17 00:00:00 2001 From: "xinan.lin" Date: Thu, 30 May 2024 20:17:02 -0700 Subject: [PATCH 167/706] [Inductor UT] Fix expected failure but pass for test case on Intel GPU. (#127595) The XPU expected failure test case `TritonCodeGenTests.test_codegen_config_option_dont_assume_alignment` should have been expected passed after the PR #126261 merged, but due to test flaky, this case was skiped when landing the PR. The expected failure but passed error then exposed in periodic test: https://github.com/pytorch/pytorch/actions/runs/9302864965/job/25605549183#step:14:2082. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127595 Approved by: https://github.com/EikanWang, https://github.com/chuanqi129, https://github.com/peterbell10, https://github.com/atalman --- test/inductor/test_torchinductor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 12ace877df05..111d0e1ef959 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -10520,7 +10520,6 @@ def fn(a: torch.Tensor) -> torch.Tensor: self.assertEqual(arguments_that_are_divisible_by_16_in_kernel1, (0, 1)) torch._dynamo.reset() - @expectedFailureXPU @config.patch(assume_aligned_inputs=False) def test_codegen_config_option_dont_assume_alignment(self): def fn(x: torch.Tensor) -> torch.Tensor: From a010fa9e243a2dd3215c4bd8a25eba3189163438 Mon Sep 17 00:00:00 2001 From: Zain Huda Date: Fri, 31 May 2024 15:32:06 +0000 Subject: [PATCH 168/706] [DCP] Fix variable spelling (#127565) Summary: tsia Test Plan: sandcastle Differential Revision: D57983752 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127565 Approved by: https://github.com/wz337, https://github.com/fegin --- torch/distributed/checkpoint/state_dict_saver.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/torch/distributed/checkpoint/state_dict_saver.py b/torch/distributed/checkpoint/state_dict_saver.py index b715fcdd9ae5..451603288d12 100644 --- a/torch/distributed/checkpoint/state_dict_saver.py +++ b/torch/distributed/checkpoint/state_dict_saver.py @@ -274,7 +274,7 @@ def _save_state_dict( planner = DefaultSavePlanner() assert planner is not None - global_metatadata = None + global_metadata = None ckpt_kwargs = {} if (ckpt_id := getattr(storage_writer, "checkpoint_id", None)) is not None: @@ -305,10 +305,10 @@ def local_step(): @_dcp_method_logger(**ckpt_kwargs) def global_step(all_local_plans): - nonlocal global_metatadata + nonlocal global_metadata assert planner is not None - all_local_plans, global_metatadata = planner.create_global_plan(all_local_plans) + all_local_plans, global_metadata = planner.create_global_plan(all_local_plans) all_local_plans = storage_writer.prepare_global_plan(all_local_plans) return all_local_plans @@ -325,8 +325,8 @@ def write_data(): @_dcp_method_logger(**ckpt_kwargs) def finish_checkpoint(all_results): - assert global_metatadata is not None - storage_writer.finish(metadata=global_metatadata, results=all_results) - return global_metatadata + assert global_metadata is not None + storage_writer.finish(metadata=global_metadata, results=all_results) + return global_metadata return distW.all_reduce("write", write_data, finish_checkpoint) From 6b1b8d0193515d418229dac609648534c5cbc927 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Thu, 30 May 2024 23:57:41 -0700 Subject: [PATCH 169/706] [DSD] Remove the support of Dict[nn.Module, Dict[str, Any]] state_dict (#127070) Summary: This is a very complicated signature that is hard for users to reason. Remove the support of this feature. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127070 Approved by: https://github.com/wz337 --- .../distributed/checkpoint/test_state_dict.py | 46 ------------------- torch/distributed/checkpoint/state_dict.py | 44 +----------------- 2 files changed, 2 insertions(+), 88 deletions(-) diff --git a/test/distributed/checkpoint/test_state_dict.py b/test/distributed/checkpoint/test_state_dict.py index 94ec8602bd4f..28bd137f7b55 100644 --- a/test/distributed/checkpoint/test_state_dict.py +++ b/test/distributed/checkpoint/test_state_dict.py @@ -424,52 +424,6 @@ def test_strict(self) -> None: with self.assertRaisesRegex(RuntimeError, "Missing key"): set_model_state_dict(model, model_state_dict=model_state_dict) - @with_comms - @skip_if_lt_x_gpu(1) - def test_partial(self) -> None: - model = CompositeParamModel(device=torch.device("cuda")) - - model_state_dict1 = get_model_state_dict(model) - model_state_dict1 = copy.deepcopy(model_state_dict1) - model_state_dict2 = get_model_state_dict(model, submodules={model.l}) - model_state_dict2 = copy.deepcopy(model_state_dict2) - model_state_dict3 = get_model_state_dict( - model, - submodules={model.l}, - options=StateDictOptions(keep_submodule_prefixes=False), - ) - model_state_dict3 = copy.deepcopy(model_state_dict3) - self.assertEqual(len(model_state_dict2), 2) - self.assertEqual(len(model_state_dict3), 2) - for key in model_state_dict3.keys(): - full_fqn = f"l.{key}" - value1 = model_state_dict1[full_fqn] - value2 = model_state_dict2[full_fqn] - value3 = model_state_dict3[key] - self.assertEqual(value1, value2) - self.assertEqual(value2, value3) - - zeros_state_dict = { - k: torch.zeros_like(v) for k, v in model_state_dict1.items() - } - model.load_state_dict(zeros_state_dict) - set_model_state_dict( - model, - model_state_dict=model_state_dict2, - options=StateDictOptions(strict=False), - ) - self.assertEqual(model.l.weight, model_state_dict1["l.weight"]) - self.assertEqual(model.l.bias, model_state_dict1["l.bias"]) - - model.load_state_dict(zeros_state_dict) - set_model_state_dict( - model, - model_state_dict={model.l: model_state_dict3}, - options=StateDictOptions(strict=False), - ) - self.assertEqual(model.l.weight, model_state_dict1["l.weight"]) - self.assertEqual(model.l.bias, model_state_dict1["l.bias"]) - def _test_cpu_offload_full_state_dict( self, optimizer_class: Type[Optimizer] ) -> None: diff --git a/torch/distributed/checkpoint/state_dict.py b/torch/distributed/checkpoint/state_dict.py index e7072d623012..ed0c6ab34837 100644 --- a/torch/distributed/checkpoint/state_dict.py +++ b/torch/distributed/checkpoint/state_dict.py @@ -928,32 +928,6 @@ def get_state_dict( return model_state_dict, optim_state_dict -def _unflatten_model_state_dict( - model: nn.Module, - state_dict: Union[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]], -) -> Dict[str, ValueType]: - if not state_dict: - return {} - - if isinstance(next(iter(state_dict.keys())), nn.Module): - cast_state_dict = cast(Dict[nn.Module, Dict[str, ValueType]], state_dict) - new_state_dict: Dict[str, ValueType] = {} - for submodule, sub_state_dict in cast_state_dict.items(): - for name, m in model.named_modules(): - if m != submodule: - continue - - fqns = _get_fqns(model, name) - assert len(fqns) == 1, "FQNs for a submodule should only have 1 element" - prefix = f"{next(iter(fqns))}." - new_state_dict.update( - {prefix + subfqn: value for subfqn, value in sub_state_dict.items()} - ) - return new_state_dict - else: - return cast(Dict[str, ValueType], state_dict) - - def set_model_state_dict( model: nn.Module, model_state_dict: Dict[str, ValueType], @@ -967,11 +941,7 @@ def set_model_state_dict( Args: model (nn.Module): the nn.Module to the model. - model_state_dict: (Dict[str, ValueType]): - the model state_dict to load. If the key of the ``model_state_dict`` - is nn.Module, the key is a submodule of ``model`` and the value should - be the state_dict of the submodule. When loading the state_dict, - the prefix of the submodule will be append to the state_dict. + model_state_dict: (Dict[str, ValueType]): the model state_dict to load. options (StateDictOptions): the options to control how model state_dict and optimizer state_dict should be loaded. See `StateDictOptions` for the details. @@ -983,9 +953,6 @@ def set_model_state_dict( :type model_state_dict: typing.Dict[str, ValueType] """ - model_state_dict: Dict[str, ValueType] = _unflatten_model_state_dict( - model, model_state_dict - ) with gc_context(): info = _verify_options(model, tuple(), optim_only=False, options=options) @@ -1054,11 +1021,7 @@ def set_state_dict( model (nn.Module): the nn.Module to the model. optimizers (Union[Optimizer, Iterable[Optimizer]]): The optimizers that are used to optimize ``model``. - model_state_dict: (Union[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]]): - the model state_dict to load. If the key of the ``model_state_dict`` - is nn.Module, the key is a submodule of ``model`` and the value should - be the state_dict of the submodule. When loading the state_dict, - the prefix of the submodule will be append to the state_dict. + model_state_dict: (Dict[str, ValueType]]): the model state_dict to load. optim_state_dict: OptimizerStateType: the optimizer state_dict to load. options (StateDictOptions): the options to control how @@ -1074,9 +1037,6 @@ def set_state_dict( :type optim_state_dict: typing.OptimizerStateType """ - model_state_dict: Dict[str, ValueType] = _unflatten_model_state_dict( - model, model_state_dict - ) with gc_context(): optimizers = ( (optimizers,) From bd868eeb281b602f3ee4ed773b31f89648c12650 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Thu, 30 May 2024 23:57:41 -0700 Subject: [PATCH 170/706] [DSD] Support flattening the optimizer state_dict when saving and unflattening when loading (#127071) Fixes https://github.com/pytorch/pytorch/issues/126595 **What does this PR do?** This PR unflattens the optimizer state_dict, similar to what TorchRec does. The current `get_optimizer_state_dict()` converts the parameter IDs to FQNs in order to avoid any conflict with different optimizers on different ranks. The current returned optimizer state_dict looks like the following one: ``` { "state": { "layer1.weight": {"step": 10, "exp_avg": SomeTensor, "exp_avg_sq": SomeTensor}, "layer2.weight": {"step": 10, "exp_avg": SomeTensor, "exp_avg_sq": SomeTensor}, }, "param_group": [ {"lr": 0.0, "betas": (0.9, 0.95), ..., "params": ["layer1.weight", "layer2.weight"]} ] } ``` While this can avoid the conflict and can support merging multiple optimizers use case (e.g., optimizer in backward), the current optimizer state_dict still cannot support MPMD (e.g., pipeline parallelism). The root cause is `param_group`. `param_group` cannot generate unique keys during saving -- DCP will flatten the dict but for `param_group`, DCP will get the keys like, `param_group.lr` or `param_group.params`. These keys will conflict when using pipeline parallelism. This PR flatten the optimizer state_dict to the one as the following one: ``` { "state.layer1.weight.step": 10, "state.layer2.weight.step": 10, "state.layer1.weight.exp_avg": SomeTensor, "state.layer2.weight.exp_avg": SomeTensor, "state.layer1.weight.exp_avg_sq": SomeTensor, "state.layer2.weight.exp_avg_sq": SomeTensor, "param_group.layer1.weight.lr" : 0.1, "param_group.layer2.weight.lr" : 0.1, "param_group.layer1.weight.betas" : (0.9, 0.95), "param_group.layer2.weight.betas" : (0.9, 0.95), } ``` This allows distributed state_dict (DSD) to support MPMD (e.g., pipeline parallelism). **Pros and Cons** *Pros* 1. Can support optimizer resharding (e.g., changing the parallelisms from 3D to 2D or changing the number of workers). 2. User don't need to manually add prefix to different optimizer. 3. Allow users to merge the optimizer states easily. One use case is loop-based pipeline parallelism. *Cons* 1. The implementation has a strong assumption of the structure of `param_groups` and its value. If the assumption changes or some customized optimizers do not meet the assumption, the implementations will be broken. 2. There will be extra values saved in the checkpoints. The assumption here is `param_group` generally contains scalars which are cheap to save. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127071 Approved by: https://github.com/wconstab, https://github.com/wz337 ghstack dependencies: #127070 --- .../distributed/checkpoint/test_state_dict.py | 27 ++++ torch/distributed/checkpoint/state_dict.py | 126 +++++++++++++++++- 2 files changed, 151 insertions(+), 2 deletions(-) diff --git a/test/distributed/checkpoint/test_state_dict.py b/test/distributed/checkpoint/test_state_dict.py index 28bd137f7b55..6d7cdf6d2d18 100644 --- a/test/distributed/checkpoint/test_state_dict.py +++ b/test/distributed/checkpoint/test_state_dict.py @@ -650,6 +650,33 @@ def test_fsdp_root_not_initialized(self) -> None: get_model_state_dict(fsdp_model) get_optimizer_state_dict(fsdp_model, fsdp_optim) + @with_comms + @skip_if_lt_x_gpu(2) + def test_flattened_osd(self) -> None: + device_mesh = init_device_mesh("cuda", (self.world_size,)) + model = CompositeParamModel(device=torch.device("cuda")) + fsdp_model = FSDP2(copy.deepcopy(model), mesh=device_mesh) + fsdp_optim = torch.optim.AdamW(fsdp_model.parameters()) + batch = torch.rand(8, 100, device="cuda") + fsdp_model(batch).sum().backward() + fsdp_optim.step() + fsdp_optim.zero_grad() + osd1 = get_optimizer_state_dict(fsdp_model, fsdp_optim) + osd2 = get_optimizer_state_dict( + fsdp_model, + fsdp_optim, + options=StateDictOptions(flatten_optimizer_state_dict=True), + ) + fsdp_optim2 = torch.optim.AdamW(fsdp_model.parameters()) + set_optimizer_state_dict( + fsdp_model, optimizers=fsdp_optim2, optim_state_dict=osd2 + ) + self.assertEqual(fsdp_optim.state_dict(), fsdp_optim2.state_dict()) + set_optimizer_state_dict( + fsdp_model, optimizers=fsdp_optim2, optim_state_dict=osd1 + ) + self.assertEqual(fsdp_optim.state_dict(), fsdp_optim2.state_dict()) + if __name__ == "__main__": run_tests() diff --git a/torch/distributed/checkpoint/state_dict.py b/torch/distributed/checkpoint/state_dict.py index ed0c6ab34837..137938e774d7 100644 --- a/torch/distributed/checkpoint/state_dict.py +++ b/torch/distributed/checkpoint/state_dict.py @@ -150,6 +150,7 @@ class StateDictOptions: keep_submodule_prefixes: bool = True strict: bool = True broadcast_from_rank0: bool = False + flatten_optimizer_state_dict: bool = False @dataclass @@ -382,7 +383,7 @@ def _verify_state_dict( if info.handle_optim: if ( - not (optim_state_dict and optim_state_dict[STATE]) + not optim_state_dict and not (info.cpu_offload and info.full_state_dict) and (not info.broadcast_from_rank0) ): @@ -563,6 +564,115 @@ def _init_optim_state(optim: torch.optim.Optimizer) -> None: optim.zero_grad(set_to_none=True) +def _flatten_optim_state_dict(state_dict: OptimizerStateType) -> Dict[str, ValueType]: + """ + This API flattens the optimizer state_dict to support optimizer resharding for + MPMD, e.g., pipeline parallelism. + + Without the API, the original optimizer state_dict looks like: + { + "state": { + "layer1.weight": { + "step": 10, "exp_avg": SomeTensor, "exp_avg_sq": SomeTensor + }, + "layer2.weight": { + "step": 10, "exp_avg": SomeTensor, "exp_avg_sq": SomeTensor + }, + }, + "param_group": [ + { + "lr": 0.0, + "betas": (0.9, 0.95), ..., + "params": ["layer1.weight", "layer2.weight"] + } + ] + } + + With this API, the optimizer state_dict looks like: + { + "state.layer1.weight.step": 10, + "state.layer2.weight.step": 10, + "state.layer1.weight.exp_avg": SomeTensor, + "state.layer2.weight.exp_avg": SomeTensor, + "state.layer1.weight.exp_avg_sq": SomeTensor, + "state.layer2.weight.exp_avg_sq": SomeTensor, + "param_group.layer1.weight.lr" : 0.1, + "param_group.layer2.weight.lr" : 0.1, + "param_group.layer1.weight.betas" : (0.9, 0.95), + "param_group.layer2.weight.betas" : (0.9, 0.95), + } + + Note that if any of the value is a container, like the betas in the example, + this API won't flattent it. + """ + + def _raise_if_type_not_supported(v): + if not isinstance(v, (torch.Tensor, int, float)): + raise NotImplementedError( + "Flattening optimizer state_dict only supports " + "tensor, int, float states now. " + f"Type is {type(v)}." + ) + + ret: Dict[str, ValueType] = {} + for fqn, state in cast(DictValueType, state_dict[STATE]).items(): + for k, v in cast(DictValueType, state).items(): + _raise_if_type_not_supported(v) + ret[f"{STATE}.{fqn}.{k}"] = v + + for param_group in cast(ListDictValueType, state_dict[PG]): + fqns = param_group.pop(PARAMS) + for fqn in cast(List[str], fqns): + for k, v in param_group.items(): + ret[f"{PG}.{fqn}.{k}"] = v + return ret + + +def _unflatten_optim_state_dict( + optim: torch.optim.Optimizer, + state_dict: Dict[str, ValueType], + info: _StateDictInfo, +) -> OptimizerStateType: + """ + This API unflattens the state_dict generated by _flatten_optim_state_dict(). + See the docstring of _flatten_optim_state_dict() for more detail. + """ + state: DictValueType = {} + pg_state: ListDictValueType = [] + return_osd: OptimizerStateType = {STATE: state, PG: pg_state} + + for param_group in optim.param_groups: + pg_state.append({PARAMS: []}) + for param in param_group[PARAMS]: + for fqn in info.fqn_param_mapping[param]: + params = pg_state[-1][PARAMS] + assert isinstance(params, list) # typing + params.append(fqn) + if not param.requires_grad: + continue + state[fqn] = {} + for state_name in optim.state[param].keys(): + cast(DictValueType, state[fqn])[state_name] = state_dict[ + f"{STATE}.{fqn}.{state_name}" + ] + + first_param_fqn = cast(List[str], pg_state[-1][PARAMS])[0] + for k in param_group.keys(): + if k == PARAMS: + continue + value = state_dict[f"{PG}.{first_param_fqn}.{k}"] + if k not in pg_state[-1]: + pg_state[-1][k] = value + elif pg_state[-1][k] != value: + raise RuntimeError( + "All the parameters in the same parameter group should have " + f"the same saved param_group value. But {first_param_fqn}.{k} " + f"is {value} while other(s) is {pg_state[-1][k]}." + ) + + return return_osd + + def _get_optim_state_dict( model: nn.Module, optimizers: Tuple[torch.optim.Optimizer, ...], @@ -618,6 +728,11 @@ def _get_optim_state_dict( cast(DictValueType, optim_state_dict[STATE]).update(osd[STATE]) cast(ListDictValueType, optim_state_dict[PG]).extend(osd[PG]) + if info.flatten_optimizer_state_dict: + optim_state_dict = cast( + OptimizerStateType, _flatten_optim_state_dict(optim_state_dict) + ) + if info.full_state_dict: ranks_only = tuple() if not info.cpu_offload else (0,) return _gather_state_dict( @@ -695,7 +810,14 @@ def _load_optim_state_dict( for optim in optimizers: _init_optim_state(optim) if state_dict: - optim_state_dict = _split_optim_state_dict(model, optim, state_dict, info) + if STATE in state_dict: + optim_state_dict = _split_optim_state_dict( + model, optim, state_dict, info + ) + else: + optim_state_dict = _unflatten_optim_state_dict( + optim, cast(Dict[str, ValueType], state_dict), info + ) else: optim_state_dict = {} if info.fsdp_modules: From 8b4ad3a8d9f5cd943160feb4c478d4603603ffd6 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Thu, 30 May 2024 23:57:41 -0700 Subject: [PATCH 171/706] [DSD] Unify the API signatures of set_model_state_dict and set_optimizer_state_dict (#127384) Summary: Allow the optim_state_dict argument to be a positional argument. This make sense since this is a required argument and this will make the function signature the consistent as set_model_state_dict without causing BC issues. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127384 Approved by: https://github.com/wz337 ghstack dependencies: #127070, #127071 --- torch/distributed/checkpoint/state_dict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/distributed/checkpoint/state_dict.py b/torch/distributed/checkpoint/state_dict.py index 137938e774d7..8818ace78158 100644 --- a/torch/distributed/checkpoint/state_dict.py +++ b/torch/distributed/checkpoint/state_dict.py @@ -1085,8 +1085,8 @@ def set_model_state_dict( def set_optimizer_state_dict( model: nn.Module, optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]], - *, optim_state_dict: OptimizerStateType, + *, options: Optional[StateDictOptions] = None, ) -> None: """Load the optimizers state_dict. From 64c581a1d4b85b8e62675aaf6afb9567a1ccbbc7 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Thu, 30 May 2024 23:57:42 -0700 Subject: [PATCH 172/706] [DSD] Make distributed state_dict support torch.distributed is not initialized case (#127385) Fixes https://github.com/pytorch/pytorch/issues/124942 Summary: Allow DSD to support loading the regular optimizer state_dict and can be used when torch.distributed.is_initialized() is False. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127385 Approved by: https://github.com/wz337 ghstack dependencies: #127070, #127071, #127384 --- .../distributed/checkpoint/test_state_dict.py | 28 ++++++++++++ torch/distributed/checkpoint/state_dict.py | 43 +++++++++++-------- 2 files changed, 53 insertions(+), 18 deletions(-) diff --git a/test/distributed/checkpoint/test_state_dict.py b/test/distributed/checkpoint/test_state_dict.py index 6d7cdf6d2d18..8039c487962f 100644 --- a/test/distributed/checkpoint/test_state_dict.py +++ b/test/distributed/checkpoint/test_state_dict.py @@ -42,6 +42,7 @@ from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, + MultiProcessTestCase, with_comms, ) from torch.testing._internal.distributed.common_state_dict import VerifyStateDictMixin @@ -678,5 +679,32 @@ def test_flattened_osd(self) -> None: self.assertEqual(fsdp_optim.state_dict(), fsdp_optim2.state_dict()) +class TestNoComm(MultiProcessTestCase): + def setUp(self) -> None: + super().setUp() + self._spawn_processes() + + @skip_if_lt_x_gpu(1) + def test_no_dist(self) -> None: + model = CompositeParamModel(device=torch.device("cuda")) + optim = torch.optim.AdamW(model.parameters(), lr=1e-3) + + self.assertFalse(dist.is_initialized()) + msd = get_model_state_dict( + model, options=StateDictOptions(full_state_dict=True, cpu_offload=True) + ) + for v in msd.values(): + self.assertFalse(v.is_cuda) + self.assertEqual(model.state_dict(), msd) + set_model_state_dict(model, model.state_dict()) + osd = get_optimizer_state_dict( + model, + optim, + options=StateDictOptions(full_state_dict=True, cpu_offload=True), + ) + set_optimizer_state_dict(model, optim, osd) + set_optimizer_state_dict(model, optim, optim.state_dict()) + + if __name__ == "__main__": run_tests() diff --git a/torch/distributed/checkpoint/state_dict.py b/torch/distributed/checkpoint/state_dict.py index 8818ace78158..714210910072 100644 --- a/torch/distributed/checkpoint/state_dict.py +++ b/torch/distributed/checkpoint/state_dict.py @@ -407,6 +407,24 @@ def _state_dict_fn(obj: Union[nn.Module, torch.optim.Optimizer], api: str) -> Ca return call +def _maybe_full_or_cpu_state_dict( + state_dict: Dict[str, Any], info: _StateDictInfo +) -> Dict[str, Any]: + if info.full_state_dict: + ranks_only = ( + tuple() + if (not info.cpu_offload or not torch.distributed.is_initialized()) + else (0,) + ) + return _gather_state_dict( + state_dict, cpu_offload=info.cpu_offload, ranks_only=ranks_only + ) + elif info.cpu_offload: + return _offload_state_dict_to_cpu(state_dict) + else: + return state_dict + + def _get_model_state_dict( model: nn.Module, info: _StateDictInfo ) -> Dict[str, ValueType]: @@ -471,15 +489,7 @@ def verify(key, fqn) -> bool: if torch.is_tensor(p) and p.is_meta: state_dict.pop(key) - if info.full_state_dict: - ranks_only = tuple() if not info.cpu_offload else (0,) - return _gather_state_dict( - state_dict, cpu_offload=info.cpu_offload, ranks_only=ranks_only - ) - elif info.cpu_offload: - return _offload_state_dict_to_cpu(state_dict) - else: - return state_dict + return _maybe_full_or_cpu_state_dict(state_dict, info) def _load_model_state_dict( @@ -733,15 +743,7 @@ def _get_optim_state_dict( OptimizerStateType, _flatten_optim_state_dict(optim_state_dict) ) - if info.full_state_dict: - ranks_only = tuple() if not info.cpu_offload else (0,) - return _gather_state_dict( - optim_state_dict, cpu_offload=info.cpu_offload, ranks_only=ranks_only - ) - elif info.cpu_offload: - return _offload_state_dict_to_cpu(optim_state_dict) - else: - return optim_state_dict + return _maybe_full_or_cpu_state_dict(optim_state_dict, info) def _split_optim_state_dict( @@ -770,6 +772,11 @@ def _split_optim_state_dict( return_osd: OptimizerStateType = {STATE: state, PG: pg_state} pg_mapping: Dict[int, int] = {} + if all( + isinstance(k, int) for k in cast(DictValueType, optim_state_dict[STATE]).keys() + ): + return optim_state_dict + for param_group in optim.param_groups: pg_state.append({PARAMS: []}) for param in param_group[PARAMS]: From 67f080704228be106129cbd2a2fab61aad783e52 Mon Sep 17 00:00:00 2001 From: Wei Wang Date: Fri, 31 May 2024 16:35:57 +0000 Subject: [PATCH 173/706] [Inductor] [CI] [CUDA] Skip the failed models and tests the better way (#127150) Address subtasks in https://github.com/pytorch/pytorch/issues/126692 After enabling the disabled shards, the following two models regressed (for cu124 configuration): dynamic_inductor_timm_training.csv cspdarknet53,pass,7 (expected) | cspdarknet53,fail_accuracy,7 (actual) eca_botnext26ts_256,pass,7 (expected) | eca_botnext26ts_256,fail_accuracy,7 (actual) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127150 Approved by: https://github.com/huydhn, https://github.com/eqy, https://github.com/atalman --- .ci/pytorch/test.sh | 30 +- .github/workflows/inductor.yml | 16 + .../cu124/aot_eager_huggingface_inference.csv | 185 +++++++++ .../cu124/aot_eager_huggingface_training.csv | 185 +++++++++ .../cu124/aot_eager_timm_inference.csv | 245 +++++++++++ .../cu124/aot_eager_timm_training.csv | 245 +++++++++++ .../cu124/aot_eager_torchbench_inference.csv | 381 ++++++++++++++++++ .../cu124/aot_eager_torchbench_training.csv | 289 +++++++++++++ .../aot_inductor_huggingface_inference.csv | 185 +++++++++ .../cu124/aot_inductor_timm_inference.csv | 245 +++++++++++ .../aot_inductor_torchbench_inference.csv | 353 ++++++++++++++++ .../cpu_inductor_huggingface_inference.csv | 185 +++++++++ .../cu124/cpu_inductor_timm_inference.csv | 245 +++++++++++ .../cpu_inductor_torchbench_inference.csv | 341 ++++++++++++++++ ...ynamic_aot_eager_huggingface_inference.csv | 185 +++++++++ ...dynamic_aot_eager_huggingface_training.csv | 185 +++++++++ .../dynamic_aot_eager_timm_inference.csv | 245 +++++++++++ .../cu124/dynamic_aot_eager_timm_training.csv | 245 +++++++++++ ...dynamic_aot_eager_torchbench_inference.csv | 377 +++++++++++++++++ .../dynamic_aot_eager_torchbench_training.csv | 285 +++++++++++++ ...mic_cpu_inductor_huggingface_inference.csv | 185 +++++++++ .../dynamic_cpu_inductor_timm_inference.csv | 245 +++++++++++ ...amic_cpu_inductor_torchbench_inference.csv | 301 ++++++++++++++ ...dynamic_inductor_huggingface_inference.csv | 185 +++++++++ .../dynamic_inductor_huggingface_training.csv | 185 +++++++++ .../cu124/dynamic_inductor_timm_inference.csv | 245 +++++++++++ .../cu124/dynamic_inductor_timm_training.csv | 245 +++++++++++ .../dynamic_inductor_torchbench_inference.csv | 377 +++++++++++++++++ .../dynamic_inductor_torchbench_training.csv | 285 +++++++++++++ .../dynamo_eager_huggingface_inference.csv | 185 +++++++++ .../dynamo_eager_huggingface_training.csv | 185 +++++++++ .../cu124/dynamo_eager_timm_inference.csv | 245 +++++++++++ .../cu124/dynamo_eager_timm_training.csv | 245 +++++++++++ .../dynamo_eager_torchbench_inference.csv | 381 ++++++++++++++++++ .../dynamo_eager_torchbench_training.csv | 289 +++++++++++++ .../cu124/inductor_huggingface_inference.csv | 185 +++++++++ .../cu124/inductor_huggingface_training.csv | 185 +++++++++ .../cu124/inductor_timm_inference.csv | 245 +++++++++++ .../cu124/inductor_timm_training.csv | 245 +++++++++++ .../cu124/inductor_torchbench_inference.csv | 381 ++++++++++++++++++ .../cu124/inductor_torchbench_training.csv | 289 +++++++++++++ .../cu124/update_expected.py | 172 ++++++++ 42 files changed, 10131 insertions(+), 6 deletions(-) create mode 100644 benchmarks/dynamo/ci_expected_accuracy/cu124/aot_eager_huggingface_inference.csv create mode 100644 benchmarks/dynamo/ci_expected_accuracy/cu124/aot_eager_huggingface_training.csv create mode 100644 benchmarks/dynamo/ci_expected_accuracy/cu124/aot_eager_timm_inference.csv create mode 100644 benchmarks/dynamo/ci_expected_accuracy/cu124/aot_eager_timm_training.csv create mode 100644 benchmarks/dynamo/ci_expected_accuracy/cu124/aot_eager_torchbench_inference.csv create mode 100644 benchmarks/dynamo/ci_expected_accuracy/cu124/aot_eager_torchbench_training.csv create mode 100644 benchmarks/dynamo/ci_expected_accuracy/cu124/aot_inductor_huggingface_inference.csv create mode 100644 benchmarks/dynamo/ci_expected_accuracy/cu124/aot_inductor_timm_inference.csv create mode 100644 benchmarks/dynamo/ci_expected_accuracy/cu124/aot_inductor_torchbench_inference.csv create mode 100644 benchmarks/dynamo/ci_expected_accuracy/cu124/cpu_inductor_huggingface_inference.csv create mode 100644 benchmarks/dynamo/ci_expected_accuracy/cu124/cpu_inductor_timm_inference.csv create mode 100644 benchmarks/dynamo/ci_expected_accuracy/cu124/cpu_inductor_torchbench_inference.csv create mode 100644 benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_huggingface_inference.csv create mode 100644 benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_huggingface_training.csv create mode 100644 benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_timm_inference.csv create mode 100644 benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_timm_training.csv create mode 100644 benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_torchbench_inference.csv create mode 100644 benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_torchbench_training.csv create mode 100644 benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_cpu_inductor_huggingface_inference.csv create mode 100644 benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_cpu_inductor_timm_inference.csv create mode 100644 benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_cpu_inductor_torchbench_inference.csv create mode 100644 benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_huggingface_inference.csv create mode 100644 benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_huggingface_training.csv create mode 100644 benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_timm_inference.csv create mode 100644 benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_timm_training.csv create mode 100644 benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_torchbench_inference.csv create mode 100644 benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_torchbench_training.csv create mode 100644 benchmarks/dynamo/ci_expected_accuracy/cu124/dynamo_eager_huggingface_inference.csv create mode 100644 benchmarks/dynamo/ci_expected_accuracy/cu124/dynamo_eager_huggingface_training.csv create mode 100644 benchmarks/dynamo/ci_expected_accuracy/cu124/dynamo_eager_timm_inference.csv create mode 100644 benchmarks/dynamo/ci_expected_accuracy/cu124/dynamo_eager_timm_training.csv create mode 100644 benchmarks/dynamo/ci_expected_accuracy/cu124/dynamo_eager_torchbench_inference.csv create mode 100644 benchmarks/dynamo/ci_expected_accuracy/cu124/dynamo_eager_torchbench_training.csv create mode 100644 benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_huggingface_inference.csv create mode 100644 benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_huggingface_training.csv create mode 100644 benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_timm_inference.csv create mode 100644 benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_timm_training.csv create mode 100644 benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_torchbench_inference.csv create mode 100644 benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_torchbench_training.csv create mode 100644 benchmarks/dynamo/ci_expected_accuracy/cu124/update_expected.py diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 28be463fb587..1d185747abf8 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -264,6 +264,18 @@ elif [[ $TEST_CONFIG == 'nogpu_AVX512' ]]; then export ATEN_CPU_CAPABILITY=avx2 fi +# temp workarounds for https://github.com/pytorch/pytorch/issues/126692, remove when fixed +if [[ "$BUILD_ENVIRONMENT" != *-bazel-* ]]; then + pushd test + CUDA_VERSION=$(python -c "import torch; print(torch.version.cuda)") + if [ "$CUDA_VERSION" == "12.4" ]; then + ISCUDA124="cu124" + else + ISCUDA124="" + fi + popd +fi + test_python_legacy_jit() { time python test/run_test.py --include test_jit_legacy test_jit_fuser_legacy --verbose assert_git_not_dirty @@ -364,7 +376,7 @@ test_inductor_cpp_wrapper_abi_compatible() { --output "$TEST_REPORTS_DIR/inductor_cpp_wrapper_training.csv" python benchmarks/dynamo/check_accuracy.py \ --actual "$TEST_REPORTS_DIR/inductor_cpp_wrapper_training.csv" \ - --expected "benchmarks/dynamo/ci_expected_accuracy/inductor_timm_training.csv" + --expected "benchmarks/dynamo/ci_expected_accuracy/${ISCUDA124}/inductor_timm_training.csv" } # "Global" flags for inductor benchmarking controlled by TEST_CONFIG @@ -526,10 +538,10 @@ test_single_dynamo_benchmark() { --output "$TEST_REPORTS_DIR/${name}_${suite}.csv" python benchmarks/dynamo/check_accuracy.py \ --actual "$TEST_REPORTS_DIR/${name}_$suite.csv" \ - --expected "benchmarks/dynamo/ci_expected_accuracy/${TEST_CONFIG}_${name}.csv" + --expected "benchmarks/dynamo/ci_expected_accuracy/${ISCUDA124}/${TEST_CONFIG}_${name}.csv" python benchmarks/dynamo/check_graph_breaks.py \ --actual "$TEST_REPORTS_DIR/${name}_$suite.csv" \ - --expected "benchmarks/dynamo/ci_expected_accuracy/${TEST_CONFIG}_${name}.csv" + --expected "benchmarks/dynamo/ci_expected_accuracy/${ISCUDA124}/${TEST_CONFIG}_${name}.csv" fi } @@ -576,7 +588,7 @@ test_inductor_torchbench_smoketest_perf() { --bfloat16 --inference --inductor --only moco --output "$TEST_REPORTS_DIR/inductor_cpp_wrapper_inference.csv" python benchmarks/dynamo/check_accuracy.py \ --actual "$TEST_REPORTS_DIR/inductor_cpp_wrapper_inference.csv" \ - --expected "benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv" + --expected "benchmarks/dynamo/ci_expected_accuracy/${ISCUDA124}/inductor_torchbench_inference.csv" python benchmarks/dynamo/torchbench.py --device cuda --performance --backend inductor --float16 --training \ --batch-size-file "$(realpath benchmarks/dynamo/torchbench_models_list.txt)" --only hf_Bert \ @@ -591,7 +603,13 @@ test_inductor_torchbench_smoketest_perf() { # https://github.com/pytorch/pytorch/actions/runs/7158691360/job/19491437314, # and thus we lower its threshold to reduce flakiness. If this continues to be a problem, # we switch to use some other model. - python benchmarks/dynamo/check_perf_csv.py -f "$TEST_REPORTS_DIR/inductor_inference_smoketest.csv" -t 4.9 + # Use 4.7 for cuda 12.4, change back to 4.9 after fixing https://github.com/pytorch/pytorch/issues/126692 + if [ "$CUDA_VERSION" == "12.4" ]; then + THRESHOLD=4.7 + else + THRESHOLD=4.9 + fi + python benchmarks/dynamo/check_perf_csv.py -f "$TEST_REPORTS_DIR/inductor_inference_smoketest.csv" -t $THRESHOLD # Check memory compression ratio for a few models for test in hf_Albert timm_vision_transformer; do @@ -610,7 +628,7 @@ test_inductor_torchbench_smoketest_perf() { --only $test --output "$TEST_REPORTS_DIR/inductor_warm_start_smoketest_$test.csv" python benchmarks/dynamo/check_accuracy.py \ --actual "$TEST_REPORTS_DIR/inductor_warm_start_smoketest_$test.csv" \ - --expected "benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_training.csv" + --expected "benchmarks/dynamo/ci_expected_accuracy/${ISCUDA124}/inductor_huggingface_training.csv" done } diff --git a/.github/workflows/inductor.yml b/.github/workflows/inductor.yml index cb5122e631bb..4afd87f056f3 100644 --- a/.github/workflows/inductor.yml +++ b/.github/workflows/inductor.yml @@ -140,11 +140,15 @@ jobs: { config: "inductor", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, { config: "inductor_distributed", shard: 1, num_shards: 1, runner: "linux.g5.12xlarge.nvidia.gpu" }, { config: "inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, { config: "inductor_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, { config: "inductor_torchbench", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_torchbench", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, { config: "dynamic_inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamic_inductor_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, { config: "dynamic_inductor_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, { config: "dynamic_inductor_torchbench", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamic_inductor_torchbench", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, { config: "aot_inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, { config: "aot_inductor_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, { config: "aot_inductor_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, @@ -192,6 +196,18 @@ jobs: { config: "inductor", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, ]} + linux-focal-cuda12_4-py3_10-gcc9-inductor-test-gcp: + name: cuda12.4-py3.10-gcc9-sm80 + uses: ./.github/workflows/_linux-test.yml + needs: linux-focal-cuda12_4-py3_10-gcc9-inductor-build-gcp + with: + build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm80 + docker-image: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-inductor-build-gcp.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-inductor-build-gcp.outputs.test-matrix }} + use-gha: anything-non-empty-to-use-gha + secrets: + HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} + linux-focal-cuda12_4-py3_12-gcc9-inductor-test: name: cuda12.4-py3.12-gcc9-sm86 uses: ./.github/workflows/_linux-test.yml diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_eager_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_eager_huggingface_inference.csv new file mode 100644 index 000000000000..349239b058a7 --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_eager_huggingface_inference.csv @@ -0,0 +1,185 @@ +name,accuracy,graph_breaks + + + +AlbertForMaskedLM,pass,0 + + + +AlbertForQuestionAnswering,pass,0 + + + +AllenaiLongformerBase,pass,4 + + + +BartForCausalLM,pass,0 + + + +BartForConditionalGeneration,pass,0 + + + +BertForMaskedLM,pass,0 + + + +BertForQuestionAnswering,pass,0 + + + +BlenderbotForCausalLM,pass_due_to_skip,0 + + + +BlenderbotSmallForCausalLM,pass,0 + + + +BlenderbotSmallForConditionalGeneration,pass,0 + + + +CamemBert,pass,0 + + + +DebertaForMaskedLM,pass,0 + + + +DebertaForQuestionAnswering,pass,0 + + + +DebertaV2ForMaskedLM,pass_due_to_skip,0 + + + +DebertaV2ForQuestionAnswering,pass,0 + + + +DistilBertForMaskedLM,pass,0 + + + +DistilBertForQuestionAnswering,pass,0 + + + +DistillGPT2,pass,0 + + + +ElectraForCausalLM,pass,0 + + + +ElectraForQuestionAnswering,pass,0 + + + +GPT2ForSequenceClassification,pass,2 + + + +GoogleFnet,pass,0 + + + +LayoutLMForMaskedLM,pass,0 + + + +LayoutLMForSequenceClassification,pass,2 + + + +M2M100ForConditionalGeneration,pass,0 + + + +MBartForCausalLM,pass,0 + + + +MBartForConditionalGeneration,pass,0 + + + +MT5ForConditionalGeneration,pass,0 + + + +MegatronBertForCausalLM,pass,0 + + + +MegatronBertForQuestionAnswering,pass,0 + + + +MobileBertForMaskedLM,pass,0 + + + +MobileBertForQuestionAnswering,pass,0 + + + +OPTForCausalLM,pass,0 + + + +PLBartForCausalLM,pass,0 + + + +PLBartForConditionalGeneration,pass,0 + + + +PegasusForCausalLM,pass,0 + + + +PegasusForConditionalGeneration,pass,0 + + + +RobertaForCausalLM,pass,0 + + + +RobertaForQuestionAnswering,pass,0 + + + +Speech2Text2ForCausalLM,pass,0 + + + +T5ForConditionalGeneration,pass,0 + + + +T5Small,pass,0 + + + +TrOCRForCausalLM,pass,0 + + + +XGLMForCausalLM,pass,0 + + + +XLNetLMHeadModel,pass,0 + + + +YituTechConvBert,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_eager_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_eager_huggingface_training.csv new file mode 100644 index 000000000000..a5e00513153d --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_eager_huggingface_training.csv @@ -0,0 +1,185 @@ +name,accuracy,graph_breaks + + + +AlbertForMaskedLM,pass,4 + + + +AlbertForQuestionAnswering,pass,5 + + + +AllenaiLongformerBase,pass,9 + + + +BartForCausalLM,pass,12 + + + +BartForConditionalGeneration,pass,24 + + + +BertForMaskedLM,pass,5 + + + +BertForQuestionAnswering,pass,5 + + + +BlenderbotForCausalLM,eager_fail_to_run,0 + + + +BlenderbotSmallForCausalLM,pass,12 + + + +BlenderbotSmallForConditionalGeneration,pass,24 + + + +CamemBert,pass,5 + + + +DebertaForMaskedLM,pass,5 + + + +DebertaForQuestionAnswering,pass,5 + + + +DebertaV2ForMaskedLM,pass_due_to_skip,0 + + + +DebertaV2ForQuestionAnswering,eager_1st_run_OOM,0 + + + +DistilBertForMaskedLM,pass,5 + + + +DistilBertForQuestionAnswering,pass,5 + + + +DistillGPT2,pass,5 + + + +ElectraForCausalLM,pass,4 + + + +ElectraForQuestionAnswering,pass,5 + + + +GPT2ForSequenceClassification,pass,7 + + + +GoogleFnet,pass,5 + + + +LayoutLMForMaskedLM,pass,5 + + + +LayoutLMForSequenceClassification,pass,7 + + + +M2M100ForConditionalGeneration,pass,4 + + + +MBartForCausalLM,pass,12 + + + +MBartForConditionalGeneration,pass,24 + + + +MT5ForConditionalGeneration,pass,5 + + + +MegatronBertForCausalLM,pass,5 + + + +MegatronBertForQuestionAnswering,pass,5 + + + +MobileBertForMaskedLM,pass,3 + + + +MobileBertForQuestionAnswering,pass,3 + + + +OPTForCausalLM,pass,12 + + + +PLBartForCausalLM,pass,12 + + + +PLBartForConditionalGeneration,pass,29 + + + +PegasusForCausalLM,pass,12 + + + +PegasusForConditionalGeneration,pass,23 + + + +RobertaForCausalLM,pass,5 + + + +RobertaForQuestionAnswering,pass,5 + + + +Speech2Text2ForCausalLM,pass,12 + + + +T5ForConditionalGeneration,pass,5 + + + +T5Small,pass,5 + + + +TrOCRForCausalLM,pass,12 + + + +XGLMForCausalLM,pass,12 + + + +XLNetLMHeadModel,pass,5 + + + +YituTechConvBert,pass,5 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_eager_timm_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_eager_timm_inference.csv new file mode 100644 index 000000000000..c889ba0e8d2f --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_eager_timm_inference.csv @@ -0,0 +1,245 @@ +name,accuracy,graph_breaks + + + +adv_inception_v3,pass,0 + + + +beit_base_patch16_224,pass,0 + + + +botnet26t_256,pass,0 + + + +cait_m36_384,pass,0 + + + +coat_lite_mini,pass,0 + + + +convit_base,pass,0 + + + +convmixer_768_32,pass,0 + + + +convnext_base,pass,0 + + + +crossvit_9_240,pass,0 + + + +cspdarknet53,pass,0 + + + +deit_base_distilled_patch16_224,pass,0 + + + +dla102,pass,0 + + + +dm_nfnet_f0,pass,0 + + + +dpn107,pass,0 + + + +eca_botnext26ts_256,pass,0 + + + +eca_halonext26ts,pass,0 + + + +ese_vovnet19b_dw,pass,0 + + + +fbnetc_100,pass,0 + + + +fbnetv3_b,pass,0 + + + +gernet_l,pass,0 + + + +ghostnet_100,pass,0 + + + +gluon_inception_v3,pass,0 + + + +gmixer_24_224,pass,0 + + + +gmlp_s16_224,pass,0 + + + +hrnet_w18,pass,0 + + + +inception_v3,pass,0 + + + +jx_nest_base,pass,0 + + + +lcnet_050,pass,0 + + + +levit_128,pass,0 + + + +mixer_b16_224,pass,0 + + + +mixnet_l,pass,0 + + + +mnasnet_100,pass,0 + + + +mobilenetv2_100,pass,0 + + + +mobilenetv3_large_100,pass,0 + + + +mobilevit_s,pass,0 + + + +nfnet_l0,pass,0 + + + +pit_b_224,pass,0 + + + +pnasnet5large,pass,0 + + + +poolformer_m36,pass,0 + + + +regnety_002,pass,0 + + + +repvgg_a2,pass,0 + + + +res2net101_26w_4s,pass,0 + + + +res2net50_14w_8s,pass,0 + + + +res2next50,pass,0 + + + +resmlp_12_224,pass,0 + + + +resnest101e,pass,0 + + + +rexnet_100,pass,0 + + + +sebotnet33ts_256,pass,0 + + + +selecsls42b,pass,0 + + + +spnasnet_100,pass,0 + + + +swin_base_patch4_window7_224,pass,0 + + + +swsl_resnext101_32x16d,pass,0 + + + +tf_efficientnet_b0,pass,0 + + + +tf_mixnet_l,pass,0 + + + +tinynet_a,pass,0 + + + +tnt_s_patch16_224,pass,0 + + + +twins_pcpvt_base,pass,0 + + + +visformer_small,pass,0 + + + +vit_base_patch16_224,pass,0 + + + +volo_d1_224,pass,0 + + + +xcit_large_24_p8_224,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_eager_timm_training.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_eager_timm_training.csv new file mode 100644 index 000000000000..1def1d99bd53 --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_eager_timm_training.csv @@ -0,0 +1,245 @@ +name,accuracy,graph_breaks + + + +adv_inception_v3,pass,6 + + + +beit_base_patch16_224,pass,7 + + + +botnet26t_256,pass,6 + + + +cait_m36_384,eager_fail_to_run,0 + + + +coat_lite_mini,pass,6 + + + +convit_base,pass,7 + + + +convmixer_768_32,pass,5 + + + +convnext_base,pass,7 + + + +crossvit_9_240,pass,7 + + + +cspdarknet53,pass,7 + + + +deit_base_distilled_patch16_224,pass,7 + + + +dla102,pass,7 + + + +dm_nfnet_f0,pass,6 + + + +dpn107,pass,6 + + + +eca_botnext26ts_256,pass,7 + + + +eca_halonext26ts,pass,7 + + + +ese_vovnet19b_dw,pass,7 + + + +fbnetc_100,pass,7 + + + +fbnetv3_b,pass,6 + + + +gernet_l,pass,6 + + + +ghostnet_100,pass,6 + + + +gluon_inception_v3,pass,7 + + + +gmixer_24_224,pass,6 + + + +gmlp_s16_224,pass,7 + + + +hrnet_w18,pass,5 + + + +inception_v3,pass,6 + + + +jx_nest_base,pass,7 + + + +lcnet_050,fail_accuracy,6 + + + +levit_128,pass,7 + + + +mixer_b16_224,pass,7 + + + +mixnet_l,pass,6 + + + +mnasnet_100,pass,7 + + + +mobilenetv2_100,pass,7 + + + +mobilenetv3_large_100,pass,7 + + + +mobilevit_s,pass,6 + + + +nfnet_l0,pass,7 + + + +pit_b_224,pass,6 + + + +pnasnet5large,pass,5 + + + +poolformer_m36,pass,6 + + + +regnety_002,pass,6 + + + +repvgg_a2,pass,7 + + + +res2net101_26w_4s,pass,6 + + + +res2net50_14w_8s,pass,6 + + + +res2next50,pass,6 + + + +resmlp_12_224,pass,6 + + + +resnest101e,pass,6 + + + +rexnet_100,pass,7 + + + +sebotnet33ts_256,pass,6 + + + +selecsls42b,pass,6 + + + +spnasnet_100,pass,7 + + + +swin_base_patch4_window7_224,pass,7 + + + +swsl_resnext101_32x16d,pass,6 + + + +tf_efficientnet_b0,pass,6 + + + +tf_mixnet_l,pass,6 + + + +tinynet_a,pass,6 + + + +tnt_s_patch16_224,pass,7 + + + +twins_pcpvt_base,pass,7 + + + +visformer_small,pass,7 + + + +vit_base_patch16_224,pass,7 + + + +volo_d1_224,pass,7 + + + +xcit_large_24_p8_224,pass_due_to_skip,7 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_eager_torchbench_inference.csv new file mode 100644 index 000000000000..20fb340690ac --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_eager_torchbench_inference.csv @@ -0,0 +1,381 @@ +name,accuracy,graph_breaks + + + +torchrec_dlrm,eager_fail_to_run,0 + + + +BERT_pytorch,pass,0 + + + +Background_Matting,pass_due_to_skip,0 + + + +DALLE2_pytorch,pass,12 + + + +LearningToPaint,pass,0 + + + +Super_SloMo,pass,0 + + + +alexnet,pass,0 + + + +basic_gnn_edgecnn,pass,0 + + + +basic_gnn_gcn,pass,6 + + + +basic_gnn_gin,pass,0 + + + +basic_gnn_sage,pass,0 + + + +cm3leon_generate,pass,4 + + + +dcgan,pass,0 + + + +demucs,pass,3 + + + +densenet121,pass,0 + + + +detectron2_fasterrcnn_r_101_c4,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_101_dc5,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_101_fpn,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_50_c4,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_50_dc5,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0 + + + +detectron2_fcos_r_50_fpn,pass,21 + + + +detectron2_maskrcnn_r_101_c4,eager_fail_to_run,0 + + + +detectron2_maskrcnn_r_101_fpn,eager_fail_to_run,0 + + + +detectron2_maskrcnn_r_50_c4,eager_fail_to_run,0 + + + +detectron2_maskrcnn_r_50_fpn,eager_fail_to_run,0 + + + +dlrm,pass,0 + + + +doctr_det_predictor,pass,5 + + + +doctr_reco_predictor,pass,4 + + + +drq,pass,0 + + + +fastNLP_Bert,pass,4 + + + +functorch_dp_cifar10,pass,0 + + + +functorch_maml_omniglot,pass,0 + + + +hf_Albert,pass,0 + + + +hf_Bart,pass,0 + + + +hf_Bert,pass,0 + + + +hf_Bert_large,pass,0 + + + +hf_BigBird,pass,46 + + + +hf_DistilBert,pass,0 + + + +hf_GPT2,pass,0 + + + +hf_GPT2_large,pass_due_to_skip,0 + + + +hf_Reformer,pass,5 + + + +hf_T5,pass,0 + + + +hf_T5_base,eager_fail_to_run,0 + + + +hf_T5_generate,pass,5 + + + +hf_T5_large,pass_due_to_skip,0 + + + +hf_Whisper,pass,0 + + + +hf_distil_whisper,pass,0 + + + +lennard_jones,pass,0 + + + +llama,pass,0 + + + +llama_v2_7b_16h,model_fail_to_load,0 + + + +llava,model_fail_to_load,0 + + + +maml,pass_due_to_skip,0 + + + +maml_omniglot,pass,0 + + + +mnasnet1_0,pass,0 + + + +mobilenet_v2,pass,0 + + + +mobilenet_v2_quantized_qat,model_fail_to_load,0 + + + +mobilenet_v3_large,pass,0 + + + +moco,pass,5 + + + +moondream,model_fail_to_load,0 + + + +nanogpt,pass,0 + + + +nvidia_deeprecommender,pass,0 + + + +opacus_cifar10,pass,0 + + + +phlippe_densenet,pass,0 + + + +phlippe_resnet,pass,0 + + + +pyhpc_equation_of_state,pass,0 + + + +pyhpc_isoneutral_mixing,pass,0 + + + +pyhpc_turbulent_kinetic_energy,pass,0 + + + +pytorch_CycleGAN_and_pix2pix,pass,0 + + + +pytorch_stargan,pass,0 + + + +pytorch_unet,pass,0 + + + +resnet152,pass,0 + + + +resnet18,pass,0 + + + +resnet50,pass,0 + + + +resnet50_quantized_qat,model_fail_to_load,0 + + + +resnext50_32x4d,pass,0 + + + +sam,pass,0 + + + +sam_fast,pass,0 + + + +shufflenet_v2_x1_0,pass,0 + + + +soft_actor_critic,pass,0 + + + +speech_transformer,pass,10 + + + +squeezenet1_1,pass,0 + + + +stable_diffusion_text_encoder,pass,0 + + + +stable_diffusion_unet,pass_due_to_skip,0 + + + +timm_efficientnet,pass,0 + + + +timm_regnet,pass,0 + + + +timm_resnest,pass,0 + + + +timm_vision_transformer,pass,0 + + + +timm_vision_transformer_large,pass_due_to_skip,0 + + + +timm_vovnet,pass,0 + + + +torch_multimodal_clip,pass,0 + + + +tts_angular,pass,2 + + + +vgg16,pass,0 + + + +vision_maskrcnn,pass,17 + + + +yolov3,pass,2 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_eager_torchbench_training.csv new file mode 100644 index 000000000000..5131c2e9ade4 --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_eager_torchbench_training.csv @@ -0,0 +1,289 @@ +name,accuracy,graph_breaks + + + +torchrec_dlrm,pass,6 + + + +BERT_pytorch,pass,6 + + + +Background_Matting,pass_due_to_skip,0 + + + +DALLE2_pytorch,eager_fail_to_run,0 + + + +LearningToPaint,pass,6 + + + +Super_SloMo,pass,7 + + + +alexnet,pass,6 + + + +basic_gnn_edgecnn,pass,22 + + + +basic_gnn_gcn,pass,13 + + + +basic_gnn_gin,pass,7 + + + +basic_gnn_sage,pass,7 + + + +dcgan,pass,6 + + + +demucs,pass,9 + + + +densenet121,pass,6 + + + +detectron2_maskrcnn_r_50_c4,eager_fail_to_run,0 + + + +dlrm,pass,6 + + + +drq,pass,6 + + + +fastNLP_Bert,pass,10 + + + +functorch_dp_cifar10,pass,7 + + + +functorch_maml_omniglot,pass,7 + + + +hf_Albert,pass,6 + + + +hf_Bart,pass,6 + + + +hf_Bert,pass,6 + + + +hf_Bert_large,pass,6 + + + +hf_BigBird,pass, 52 + + + +hf_DistilBert,pass,6 + + + +hf_GPT2,pass,6 + + + +hf_GPT2_large,pass_due_to_skip,0 + + + +hf_Reformer,pass,26 + + + +hf_T5_base,eager_2nd_run_OOM,0 + + + +hf_T5_large,pass_due_to_skip,0 + + + +hf_Whisper,pass,6 + + + +hf_distil_whisper,model_fail_to_load,0 + + + +lennard_jones,pass,7 + + + +llava,model_fail_to_load,0 + + + +maml_omniglot,pass,7 + + + +mnasnet1_0,pass,7 + + + +mobilenet_v2,pass,6 + + + +mobilenet_v2_quantized_qat,eager_fail_to_run,0 + + + +mobilenet_v3_large,pass,7 + + + +moco,pass,11 + + + +nanogpt,pass,7 + + + +nvidia_deeprecommender,pass,7 + + + +opacus_cifar10,eager_fail_to_run,0 + + + +phlippe_densenet,pass,6 + + + +phlippe_resnet,pass,6 + + + +pytorch_CycleGAN_and_pix2pix,pass,6 + + + +pytorch_stargan,pass,6 + + + +pytorch_unet,pass_due_to_skip,7 + + + +resnet152,pass,7 + + + +resnet18,pass,6 + + + +resnet50,pass,6 + + + +resnet50_quantized_qat,eager_fail_to_run,0 + + + +resnext50_32x4d,pass,7 + + + +sam,eager_fail_to_run,0 + + + +shufflenet_v2_x1_0,pass,6 + + + +soft_actor_critic,pass,6 + + + +speech_transformer,pass,16 + + + +squeezenet1_1,pass,6 + + + +stable_diffusion_text_encoder,pass,5 + + + +stable_diffusion_unet,pass_due_to_skip,0 + + + +timm_efficientnet,pass,7 + + + +timm_regnet,pass,6 + + + +timm_resnest,pass,7 + + + +timm_vision_transformer,pass,6 + + + +timm_vision_transformer_large,pass_due_to_skip,0 + + + +timm_vovnet,pass,6 + + + +torch_multimodal_clip,pass,7 + + + +tts_angular,pass,9 + + + +vgg16,pass,6 + + + +vision_maskrcnn,pass,34 + + + +yolov3,pass,9 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_inductor_huggingface_inference.csv new file mode 100644 index 000000000000..784d3788e335 --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_inductor_huggingface_inference.csv @@ -0,0 +1,185 @@ +name,accuracy,graph_breaks + + + +AlbertForMaskedLM,pass,0 + + + +AlbertForQuestionAnswering,pass,0 + + + +AllenaiLongformerBase,fail_to_run,0 + + + +BartForCausalLM,pass,0 + + + +BartForConditionalGeneration,pass,0 + + + +BertForMaskedLM,pass,0 + + + +BertForQuestionAnswering,pass,0 + + + +BlenderbotForCausalLM,pass_due_to_skip,0 + + + +BlenderbotSmallForCausalLM,pass,0 + + + +BlenderbotSmallForConditionalGeneration,pass,0 + + + +CamemBert,pass,0 + + + +DebertaForMaskedLM,pass,0 + + + +DebertaForQuestionAnswering,pass,0 + + + +DebertaV2ForMaskedLM,pass_due_to_skip,0 + + + +DebertaV2ForQuestionAnswering,pass,0 + + + +DistilBertForMaskedLM,pass,0 + + + +DistilBertForQuestionAnswering,pass,0 + + + +DistillGPT2,pass,0 + + + +ElectraForCausalLM,pass,0 + + + +ElectraForQuestionAnswering,pass,0 + + + +GPT2ForSequenceClassification,pass,0 + + + +GoogleFnet,pass,0 + + + +LayoutLMForMaskedLM,pass,0 + + + +LayoutLMForSequenceClassification,pass,0 + + + +M2M100ForConditionalGeneration,pass,0 + + + +MBartForCausalLM,pass,0 + + + +MBartForConditionalGeneration,pass,0 + + + +MT5ForConditionalGeneration,pass,0 + + + +MegatronBertForCausalLM,pass,0 + + + +MegatronBertForQuestionAnswering,pass,0 + + + +MobileBertForMaskedLM,pass,0 + + + +MobileBertForQuestionAnswering,pass,0 + + + +OPTForCausalLM,pass,0 + + + +PLBartForCausalLM,pass,0 + + + +PLBartForConditionalGeneration,pass,0 + + + +PegasusForCausalLM,pass,0 + + + +PegasusForConditionalGeneration,pass,0 + + + +RobertaForCausalLM,pass,0 + + + +RobertaForQuestionAnswering,pass,0 + + + +Speech2Text2ForCausalLM,pass,0 + + + +T5ForConditionalGeneration,pass,0 + + + +T5Small,pass,0 + + + +TrOCRForCausalLM,pass,0 + + + +XGLMForCausalLM,pass,0 + + + +XLNetLMHeadModel,pass,0 + + + +YituTechConvBert,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_inductor_timm_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_inductor_timm_inference.csv new file mode 100644 index 000000000000..c7e86a6d317e --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_inductor_timm_inference.csv @@ -0,0 +1,245 @@ +name,accuracy,graph_breaks + + + +adv_inception_v3,pass,0 + + + +beit_base_patch16_224,pass,0 + + + +botnet26t_256,pass,0 + + + +cait_m36_384,pass,0 + + + +coat_lite_mini,pass,0 + + + +convit_base,fail_to_run,0 + + + +convmixer_768_32,pass,0 + + + +convnext_base,pass,0 + + + +crossvit_9_240,pass,0 + + + +cspdarknet53,pass,0 + + + +deit_base_distilled_patch16_224,pass,0 + + + +dla102,pass,0 + + + +dm_nfnet_f0,pass,0 + + + +dpn107,pass,0 + + + +eca_botnext26ts_256,pass,0 + + + +eca_halonext26ts,pass,0 + + + +ese_vovnet19b_dw,pass,0 + + + +fbnetc_100,pass,0 + + + +fbnetv3_b,pass,0 + + + +gernet_l,pass,0 + + + +ghostnet_100,pass,0 + + + +gluon_inception_v3,pass,0 + + + +gmixer_24_224,pass,0 + + + +gmlp_s16_224,pass,0 + + + +hrnet_w18,pass,0 + + + +inception_v3,pass,0 + + + +jx_nest_base,pass,0 + + + +lcnet_050,pass,0 + + + +levit_128,pass,0 + + + +mixer_b16_224,pass,0 + + + +mixnet_l,pass,0 + + + +mnasnet_100,pass,0 + + + +mobilenetv2_100,pass,0 + + + +mobilenetv3_large_100,pass,0 + + + +mobilevit_s,pass,0 + + + +nfnet_l0,pass,0 + + + +pit_b_224,pass,0 + + + +pnasnet5large,pass,0 + + + +poolformer_m36,pass,0 + + + +regnety_002,pass,0 + + + +repvgg_a2,pass,0 + + + +res2net101_26w_4s,pass,0 + + + +res2net50_14w_8s,pass,0 + + + +res2next50,pass,0 + + + +resmlp_12_224,pass,0 + + + +resnest101e,pass,0 + + + +rexnet_100,pass,0 + + + +sebotnet33ts_256,pass,0 + + + +selecsls42b,pass,0 + + + +spnasnet_100,pass,0 + + + +swin_base_patch4_window7_224,pass,0 + + + +swsl_resnext101_32x16d,pass,0 + + + +tf_efficientnet_b0,pass,0 + + + +tf_mixnet_l,pass,0 + + + +tinynet_a,pass,0 + + + +tnt_s_patch16_224,pass,0 + + + +twins_pcpvt_base,pass,0 + + + +visformer_small,pass,0 + + + +vit_base_patch16_224,pass,0 + + + +volo_d1_224,pass,0 + + + +xcit_large_24_p8_224,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_inductor_torchbench_inference.csv new file mode 100644 index 000000000000..40382a4f277c --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_inductor_torchbench_inference.csv @@ -0,0 +1,353 @@ +name,accuracy,graph_breaks + + + +torchrec_dlrm,eager_fail_to_run,0 + + + +BERT_pytorch,fail_to_run,0 + + + +Background_Matting,pass_due_to_skip,0 + + + +DALLE2_pytorch,fail_to_run,0 + + + +LearningToPaint,pass,0 + + + +Super_SloMo,pass,0 + + + +alexnet,pass,0 + + + +basic_gnn_edgecnn,pass,0 + + + +basic_gnn_gcn,fail_to_run,0 + + + +basic_gnn_gin,pass,0 + + + +basic_gnn_sage,pass,0 + + + +dcgan,pass,0 + + + +demucs,pass,0 + + + +densenet121,pass,0 + + + +detectron2_fasterrcnn_r_101_c4,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_101_dc5,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_101_fpn,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_50_c4,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_50_dc5,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0 + + + +detectron2_maskrcnn_r_101_c4,eager_fail_to_run,0 + + + +detectron2_maskrcnn_r_101_fpn,eager_fail_to_run,0 + + + +detectron2_maskrcnn_r_50_c4,eager_fail_to_run,0 + + + +detectron2_maskrcnn_r_50_fpn,eager_fail_to_run,0 + + + +dlrm,pass,0 + + + +doctr_det_predictor,fail_to_run,0 + + + +doctr_reco_predictor,fail_to_run,0 + + + +drq,fail_to_run,0 + + + +functorch_dp_cifar10,pass,0 + + + +functorch_maml_omniglot,pass,0 + + + +hf_Albert,pass,0 + + + +hf_Bart,pass,0 + + + +hf_Bert,pass,0 + + + +hf_Bert_large,pass,0 + + + +hf_BigBird,fail_to_run,0 + + + +hf_DistilBert,pass,0 + + + +hf_GPT2,pass,0 + + + +hf_GPT2_large,pass_due_to_skip,0 + + + +hf_T5,pass,0 + + + +hf_T5_base,eager_fail_to_run,0 + + + +hf_T5_large,pass_due_to_skip,0 + + + +hf_Whisper,pass,0 + + + +hf_distil_whisper,pass,0 + + + +lennard_jones,pass,0 + + + +llama,fail_to_run,0 + + + +llama_v2_7b_16h,model_fail_to_load,0 + + + +llava,model_fail_to_load,0 + + + +maml,pass_due_to_skip,0 + + + +maml_omniglot,pass,0 + + + +mnasnet1_0,pass,0 + + + +mobilenet_v2,pass,0 + + + +mobilenet_v2_quantized_qat,model_fail_to_load,0 + + + +mobilenet_v3_large,pass,0 + + + +moco,fail_to_run,0 + + + +moondream,model_fail_to_load,0 + + + +nanogpt,pass,0 + + + +nvidia_deeprecommender,pass,0 + + + +phlippe_densenet,pass,0 + + + +phlippe_resnet,pass,0 + + + +pyhpc_equation_of_state,pass,0 + + + +pyhpc_isoneutral_mixing,fail_to_run,0 + + + +pyhpc_turbulent_kinetic_energy,pass,0 + + + +pytorch_CycleGAN_and_pix2pix,pass,0 + + + +pytorch_stargan,pass,0 + + + +pytorch_unet,pass,0 + + + +resnet152,pass,0 + + + +resnet18,pass,0 + + + +resnet50,pass,0 + + + +resnet50_quantized_qat,model_fail_to_load,0 + + + +resnext50_32x4d,pass,0 + + + +sam,fail_to_run,0 + + + +sam_fast,fail_to_run,0 + + + +shufflenet_v2_x1_0,pass,0 + + + +soft_actor_critic,fail_to_run,0 + + + +squeezenet1_1,pass,0 + + + +stable_diffusion_text_encoder,pass,0 + + + +stable_diffusion_unet,pass_due_to_skip,0 + + + +timm_efficientnet,pass,0 + + + +timm_regnet,pass,0 + + + +timm_resnest,pass,0 + + + +timm_vision_transformer,pass,0 + + + +timm_vision_transformer_large,pass_due_to_skip,0 + + + +timm_vovnet,pass,0 + + + +torch_multimodal_clip,pass,0 + + + +tts_angular,fail_to_run,0 + + + +vgg16,pass,0 + + + +vision_maskrcnn,fail_to_run,0 + + + +yolov3,fail_to_run,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/cpu_inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/cpu_inductor_huggingface_inference.csv new file mode 100644 index 000000000000..349239b058a7 --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/cpu_inductor_huggingface_inference.csv @@ -0,0 +1,185 @@ +name,accuracy,graph_breaks + + + +AlbertForMaskedLM,pass,0 + + + +AlbertForQuestionAnswering,pass,0 + + + +AllenaiLongformerBase,pass,4 + + + +BartForCausalLM,pass,0 + + + +BartForConditionalGeneration,pass,0 + + + +BertForMaskedLM,pass,0 + + + +BertForQuestionAnswering,pass,0 + + + +BlenderbotForCausalLM,pass_due_to_skip,0 + + + +BlenderbotSmallForCausalLM,pass,0 + + + +BlenderbotSmallForConditionalGeneration,pass,0 + + + +CamemBert,pass,0 + + + +DebertaForMaskedLM,pass,0 + + + +DebertaForQuestionAnswering,pass,0 + + + +DebertaV2ForMaskedLM,pass_due_to_skip,0 + + + +DebertaV2ForQuestionAnswering,pass,0 + + + +DistilBertForMaskedLM,pass,0 + + + +DistilBertForQuestionAnswering,pass,0 + + + +DistillGPT2,pass,0 + + + +ElectraForCausalLM,pass,0 + + + +ElectraForQuestionAnswering,pass,0 + + + +GPT2ForSequenceClassification,pass,2 + + + +GoogleFnet,pass,0 + + + +LayoutLMForMaskedLM,pass,0 + + + +LayoutLMForSequenceClassification,pass,2 + + + +M2M100ForConditionalGeneration,pass,0 + + + +MBartForCausalLM,pass,0 + + + +MBartForConditionalGeneration,pass,0 + + + +MT5ForConditionalGeneration,pass,0 + + + +MegatronBertForCausalLM,pass,0 + + + +MegatronBertForQuestionAnswering,pass,0 + + + +MobileBertForMaskedLM,pass,0 + + + +MobileBertForQuestionAnswering,pass,0 + + + +OPTForCausalLM,pass,0 + + + +PLBartForCausalLM,pass,0 + + + +PLBartForConditionalGeneration,pass,0 + + + +PegasusForCausalLM,pass,0 + + + +PegasusForConditionalGeneration,pass,0 + + + +RobertaForCausalLM,pass,0 + + + +RobertaForQuestionAnswering,pass,0 + + + +Speech2Text2ForCausalLM,pass,0 + + + +T5ForConditionalGeneration,pass,0 + + + +T5Small,pass,0 + + + +TrOCRForCausalLM,pass,0 + + + +XGLMForCausalLM,pass,0 + + + +XLNetLMHeadModel,pass,0 + + + +YituTechConvBert,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/cpu_inductor_timm_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/cpu_inductor_timm_inference.csv new file mode 100644 index 000000000000..c889ba0e8d2f --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/cpu_inductor_timm_inference.csv @@ -0,0 +1,245 @@ +name,accuracy,graph_breaks + + + +adv_inception_v3,pass,0 + + + +beit_base_patch16_224,pass,0 + + + +botnet26t_256,pass,0 + + + +cait_m36_384,pass,0 + + + +coat_lite_mini,pass,0 + + + +convit_base,pass,0 + + + +convmixer_768_32,pass,0 + + + +convnext_base,pass,0 + + + +crossvit_9_240,pass,0 + + + +cspdarknet53,pass,0 + + + +deit_base_distilled_patch16_224,pass,0 + + + +dla102,pass,0 + + + +dm_nfnet_f0,pass,0 + + + +dpn107,pass,0 + + + +eca_botnext26ts_256,pass,0 + + + +eca_halonext26ts,pass,0 + + + +ese_vovnet19b_dw,pass,0 + + + +fbnetc_100,pass,0 + + + +fbnetv3_b,pass,0 + + + +gernet_l,pass,0 + + + +ghostnet_100,pass,0 + + + +gluon_inception_v3,pass,0 + + + +gmixer_24_224,pass,0 + + + +gmlp_s16_224,pass,0 + + + +hrnet_w18,pass,0 + + + +inception_v3,pass,0 + + + +jx_nest_base,pass,0 + + + +lcnet_050,pass,0 + + + +levit_128,pass,0 + + + +mixer_b16_224,pass,0 + + + +mixnet_l,pass,0 + + + +mnasnet_100,pass,0 + + + +mobilenetv2_100,pass,0 + + + +mobilenetv3_large_100,pass,0 + + + +mobilevit_s,pass,0 + + + +nfnet_l0,pass,0 + + + +pit_b_224,pass,0 + + + +pnasnet5large,pass,0 + + + +poolformer_m36,pass,0 + + + +regnety_002,pass,0 + + + +repvgg_a2,pass,0 + + + +res2net101_26w_4s,pass,0 + + + +res2net50_14w_8s,pass,0 + + + +res2next50,pass,0 + + + +resmlp_12_224,pass,0 + + + +resnest101e,pass,0 + + + +rexnet_100,pass,0 + + + +sebotnet33ts_256,pass,0 + + + +selecsls42b,pass,0 + + + +spnasnet_100,pass,0 + + + +swin_base_patch4_window7_224,pass,0 + + + +swsl_resnext101_32x16d,pass,0 + + + +tf_efficientnet_b0,pass,0 + + + +tf_mixnet_l,pass,0 + + + +tinynet_a,pass,0 + + + +tnt_s_patch16_224,pass,0 + + + +twins_pcpvt_base,pass,0 + + + +visformer_small,pass,0 + + + +vit_base_patch16_224,pass,0 + + + +volo_d1_224,pass,0 + + + +xcit_large_24_p8_224,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/cpu_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/cpu_inductor_torchbench_inference.csv new file mode 100644 index 000000000000..fcd87f4d2454 --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/cpu_inductor_torchbench_inference.csv @@ -0,0 +1,341 @@ +name,accuracy,graph_breaks + + + +BERT_pytorch,pass,0 + + + +Background_Matting,pass_due_to_skip,0 + + + +DALLE2_pytorch,model_fail_to_load,0 + + + +LearningToPaint,pass,0 + + + +Super_SloMo,pass,0 + + + +alexnet,pass,0 + + + +basic_gnn_edgecnn,pass,0 + + + +basic_gnn_gcn,pass,6 + + + +basic_gnn_gin,pass,0 + + + +basic_gnn_sage,pass,0 + + + +dcgan,pass,0 + + + +demucs,pass,3 + + + +densenet121,pass,0 + + + +detectron2_fasterrcnn_r_101_c4,pass,42 + + + +detectron2_fasterrcnn_r_101_dc5,pass,42 + + + +detectron2_fasterrcnn_r_101_fpn,pass,46 + + + +detectron2_fasterrcnn_r_50_c4,pass,42 + + + +detectron2_fasterrcnn_r_50_dc5,pass,42 + + + +detectron2_fasterrcnn_r_50_fpn,pass,46 + + + +detectron2_fcos_r_50_fpn,pass,23 + + + +detectron2_maskrcnn_r_101_c4,fail_accuracy,57 + + + +detectron2_maskrcnn_r_101_fpn,pass,64 + + + +detectron2_maskrcnn_r_50_c4,pass,57 + + + +detectron2_maskrcnn_r_50_fpn,pass,64 + + + +dlrm,pass,0 + + + +doctr_det_predictor,pass,5 + + + +doctr_reco_predictor,pass,4 + + + +drq,pass,0 + + + +fastNLP_Bert,pass,4 + + + +functorch_dp_cifar10,pass,0 + + + +functorch_maml_omniglot,pass,0 + + + +hf_Albert,pass,0 + + + +hf_Bart,pass,0 + + + +hf_Bert,pass,0 + + + +hf_Bert_large,pass,0 + + + +hf_DistilBert,pass,0 + + + +hf_GPT2,pass,0 + + + +hf_GPT2_large,pass_due_to_skip,0 + + + +hf_Reformer,pass,5 + + + +hf_T5_base,pass,0 + + + +hf_T5_large,pass_due_to_skip,0 + + + +hf_distil_whisper,pass,0 + + + +lennard_jones,pass,0 + + + +llama,pass,0 + + + +maml,pass_due_to_skip,0 + + + +maml_omniglot,pass,0 + + + +mnasnet1_0,pass,0 + + + +mobilenet_v2,pass,0 + + + +mobilenet_v2_quantized_qat,pass,2 + + + +mobilenet_v3_large,pass,0 + + + +moco,model_fail_to_load,0 + + + +moondream,pass,0 + + + +nvidia_deeprecommender,pass,0 + + + +opacus_cifar10,pass,0 + + + +phlippe_densenet,pass,0 + + + +phlippe_resnet,pass,0 + + + +pyhpc_equation_of_state,pass,0 + + + +pyhpc_isoneutral_mixing,pass,0 + + + +pyhpc_turbulent_kinetic_energy,pass,0 + + + +pytorch_CycleGAN_and_pix2pix,pass,0 + + + +pytorch_stargan,pass,0 + + + +pytorch_unet,pass,0 + + + +resnet152,pass,0 + + + +resnet18,pass,0 + + + +resnet50,pass,0 + + + +resnet50_quantized_qat,pass,2 + + + +resnext50_32x4d,pass,0 + + + +shufflenet_v2_x1_0,pass,0 + + + +soft_actor_critic,pass,0 + + + +speech_transformer,pass,10 + + + +squeezenet1_1,pass,0 + + + +stable_diffusion_unet,pass_due_to_skip,0 + + + +timm_efficientdet,model_fail_to_load,0 + + + +timm_efficientnet,pass,0 + + + +timm_nfnet,pass,0 + + + +timm_regnet,pass,0 + + + +timm_resnest,pass,0 + + + +timm_vision_transformer,pass,0 + + + +timm_vision_transformer_large,pass_due_to_skip,0 + + + +timm_vovnet,pass,0 + + + +torch_multimodal_clip,pass,0 + + + +tts_angular,pass,2 + + + +vgg16,pass,0 + + + +vision_maskrcnn,pass,28 + + + +yolov3,pass,2 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_huggingface_inference.csv new file mode 100644 index 000000000000..349239b058a7 --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_huggingface_inference.csv @@ -0,0 +1,185 @@ +name,accuracy,graph_breaks + + + +AlbertForMaskedLM,pass,0 + + + +AlbertForQuestionAnswering,pass,0 + + + +AllenaiLongformerBase,pass,4 + + + +BartForCausalLM,pass,0 + + + +BartForConditionalGeneration,pass,0 + + + +BertForMaskedLM,pass,0 + + + +BertForQuestionAnswering,pass,0 + + + +BlenderbotForCausalLM,pass_due_to_skip,0 + + + +BlenderbotSmallForCausalLM,pass,0 + + + +BlenderbotSmallForConditionalGeneration,pass,0 + + + +CamemBert,pass,0 + + + +DebertaForMaskedLM,pass,0 + + + +DebertaForQuestionAnswering,pass,0 + + + +DebertaV2ForMaskedLM,pass_due_to_skip,0 + + + +DebertaV2ForQuestionAnswering,pass,0 + + + +DistilBertForMaskedLM,pass,0 + + + +DistilBertForQuestionAnswering,pass,0 + + + +DistillGPT2,pass,0 + + + +ElectraForCausalLM,pass,0 + + + +ElectraForQuestionAnswering,pass,0 + + + +GPT2ForSequenceClassification,pass,2 + + + +GoogleFnet,pass,0 + + + +LayoutLMForMaskedLM,pass,0 + + + +LayoutLMForSequenceClassification,pass,2 + + + +M2M100ForConditionalGeneration,pass,0 + + + +MBartForCausalLM,pass,0 + + + +MBartForConditionalGeneration,pass,0 + + + +MT5ForConditionalGeneration,pass,0 + + + +MegatronBertForCausalLM,pass,0 + + + +MegatronBertForQuestionAnswering,pass,0 + + + +MobileBertForMaskedLM,pass,0 + + + +MobileBertForQuestionAnswering,pass,0 + + + +OPTForCausalLM,pass,0 + + + +PLBartForCausalLM,pass,0 + + + +PLBartForConditionalGeneration,pass,0 + + + +PegasusForCausalLM,pass,0 + + + +PegasusForConditionalGeneration,pass,0 + + + +RobertaForCausalLM,pass,0 + + + +RobertaForQuestionAnswering,pass,0 + + + +Speech2Text2ForCausalLM,pass,0 + + + +T5ForConditionalGeneration,pass,0 + + + +T5Small,pass,0 + + + +TrOCRForCausalLM,pass,0 + + + +XGLMForCausalLM,pass,0 + + + +XLNetLMHeadModel,pass,0 + + + +YituTechConvBert,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_huggingface_training.csv new file mode 100644 index 000000000000..a5e00513153d --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_huggingface_training.csv @@ -0,0 +1,185 @@ +name,accuracy,graph_breaks + + + +AlbertForMaskedLM,pass,4 + + + +AlbertForQuestionAnswering,pass,5 + + + +AllenaiLongformerBase,pass,9 + + + +BartForCausalLM,pass,12 + + + +BartForConditionalGeneration,pass,24 + + + +BertForMaskedLM,pass,5 + + + +BertForQuestionAnswering,pass,5 + + + +BlenderbotForCausalLM,eager_fail_to_run,0 + + + +BlenderbotSmallForCausalLM,pass,12 + + + +BlenderbotSmallForConditionalGeneration,pass,24 + + + +CamemBert,pass,5 + + + +DebertaForMaskedLM,pass,5 + + + +DebertaForQuestionAnswering,pass,5 + + + +DebertaV2ForMaskedLM,pass_due_to_skip,0 + + + +DebertaV2ForQuestionAnswering,eager_1st_run_OOM,0 + + + +DistilBertForMaskedLM,pass,5 + + + +DistilBertForQuestionAnswering,pass,5 + + + +DistillGPT2,pass,5 + + + +ElectraForCausalLM,pass,4 + + + +ElectraForQuestionAnswering,pass,5 + + + +GPT2ForSequenceClassification,pass,7 + + + +GoogleFnet,pass,5 + + + +LayoutLMForMaskedLM,pass,5 + + + +LayoutLMForSequenceClassification,pass,7 + + + +M2M100ForConditionalGeneration,pass,4 + + + +MBartForCausalLM,pass,12 + + + +MBartForConditionalGeneration,pass,24 + + + +MT5ForConditionalGeneration,pass,5 + + + +MegatronBertForCausalLM,pass,5 + + + +MegatronBertForQuestionAnswering,pass,5 + + + +MobileBertForMaskedLM,pass,3 + + + +MobileBertForQuestionAnswering,pass,3 + + + +OPTForCausalLM,pass,12 + + + +PLBartForCausalLM,pass,12 + + + +PLBartForConditionalGeneration,pass,29 + + + +PegasusForCausalLM,pass,12 + + + +PegasusForConditionalGeneration,pass,23 + + + +RobertaForCausalLM,pass,5 + + + +RobertaForQuestionAnswering,pass,5 + + + +Speech2Text2ForCausalLM,pass,12 + + + +T5ForConditionalGeneration,pass,5 + + + +T5Small,pass,5 + + + +TrOCRForCausalLM,pass,12 + + + +XGLMForCausalLM,pass,12 + + + +XLNetLMHeadModel,pass,5 + + + +YituTechConvBert,pass,5 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_timm_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_timm_inference.csv new file mode 100644 index 000000000000..c889ba0e8d2f --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_timm_inference.csv @@ -0,0 +1,245 @@ +name,accuracy,graph_breaks + + + +adv_inception_v3,pass,0 + + + +beit_base_patch16_224,pass,0 + + + +botnet26t_256,pass,0 + + + +cait_m36_384,pass,0 + + + +coat_lite_mini,pass,0 + + + +convit_base,pass,0 + + + +convmixer_768_32,pass,0 + + + +convnext_base,pass,0 + + + +crossvit_9_240,pass,0 + + + +cspdarknet53,pass,0 + + + +deit_base_distilled_patch16_224,pass,0 + + + +dla102,pass,0 + + + +dm_nfnet_f0,pass,0 + + + +dpn107,pass,0 + + + +eca_botnext26ts_256,pass,0 + + + +eca_halonext26ts,pass,0 + + + +ese_vovnet19b_dw,pass,0 + + + +fbnetc_100,pass,0 + + + +fbnetv3_b,pass,0 + + + +gernet_l,pass,0 + + + +ghostnet_100,pass,0 + + + +gluon_inception_v3,pass,0 + + + +gmixer_24_224,pass,0 + + + +gmlp_s16_224,pass,0 + + + +hrnet_w18,pass,0 + + + +inception_v3,pass,0 + + + +jx_nest_base,pass,0 + + + +lcnet_050,pass,0 + + + +levit_128,pass,0 + + + +mixer_b16_224,pass,0 + + + +mixnet_l,pass,0 + + + +mnasnet_100,pass,0 + + + +mobilenetv2_100,pass,0 + + + +mobilenetv3_large_100,pass,0 + + + +mobilevit_s,pass,0 + + + +nfnet_l0,pass,0 + + + +pit_b_224,pass,0 + + + +pnasnet5large,pass,0 + + + +poolformer_m36,pass,0 + + + +regnety_002,pass,0 + + + +repvgg_a2,pass,0 + + + +res2net101_26w_4s,pass,0 + + + +res2net50_14w_8s,pass,0 + + + +res2next50,pass,0 + + + +resmlp_12_224,pass,0 + + + +resnest101e,pass,0 + + + +rexnet_100,pass,0 + + + +sebotnet33ts_256,pass,0 + + + +selecsls42b,pass,0 + + + +spnasnet_100,pass,0 + + + +swin_base_patch4_window7_224,pass,0 + + + +swsl_resnext101_32x16d,pass,0 + + + +tf_efficientnet_b0,pass,0 + + + +tf_mixnet_l,pass,0 + + + +tinynet_a,pass,0 + + + +tnt_s_patch16_224,pass,0 + + + +twins_pcpvt_base,pass,0 + + + +visformer_small,pass,0 + + + +vit_base_patch16_224,pass,0 + + + +volo_d1_224,pass,0 + + + +xcit_large_24_p8_224,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_timm_training.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_timm_training.csv new file mode 100644 index 000000000000..1def1d99bd53 --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_timm_training.csv @@ -0,0 +1,245 @@ +name,accuracy,graph_breaks + + + +adv_inception_v3,pass,6 + + + +beit_base_patch16_224,pass,7 + + + +botnet26t_256,pass,6 + + + +cait_m36_384,eager_fail_to_run,0 + + + +coat_lite_mini,pass,6 + + + +convit_base,pass,7 + + + +convmixer_768_32,pass,5 + + + +convnext_base,pass,7 + + + +crossvit_9_240,pass,7 + + + +cspdarknet53,pass,7 + + + +deit_base_distilled_patch16_224,pass,7 + + + +dla102,pass,7 + + + +dm_nfnet_f0,pass,6 + + + +dpn107,pass,6 + + + +eca_botnext26ts_256,pass,7 + + + +eca_halonext26ts,pass,7 + + + +ese_vovnet19b_dw,pass,7 + + + +fbnetc_100,pass,7 + + + +fbnetv3_b,pass,6 + + + +gernet_l,pass,6 + + + +ghostnet_100,pass,6 + + + +gluon_inception_v3,pass,7 + + + +gmixer_24_224,pass,6 + + + +gmlp_s16_224,pass,7 + + + +hrnet_w18,pass,5 + + + +inception_v3,pass,6 + + + +jx_nest_base,pass,7 + + + +lcnet_050,fail_accuracy,6 + + + +levit_128,pass,7 + + + +mixer_b16_224,pass,7 + + + +mixnet_l,pass,6 + + + +mnasnet_100,pass,7 + + + +mobilenetv2_100,pass,7 + + + +mobilenetv3_large_100,pass,7 + + + +mobilevit_s,pass,6 + + + +nfnet_l0,pass,7 + + + +pit_b_224,pass,6 + + + +pnasnet5large,pass,5 + + + +poolformer_m36,pass,6 + + + +regnety_002,pass,6 + + + +repvgg_a2,pass,7 + + + +res2net101_26w_4s,pass,6 + + + +res2net50_14w_8s,pass,6 + + + +res2next50,pass,6 + + + +resmlp_12_224,pass,6 + + + +resnest101e,pass,6 + + + +rexnet_100,pass,7 + + + +sebotnet33ts_256,pass,6 + + + +selecsls42b,pass,6 + + + +spnasnet_100,pass,7 + + + +swin_base_patch4_window7_224,pass,7 + + + +swsl_resnext101_32x16d,pass,6 + + + +tf_efficientnet_b0,pass,6 + + + +tf_mixnet_l,pass,6 + + + +tinynet_a,pass,6 + + + +tnt_s_patch16_224,pass,7 + + + +twins_pcpvt_base,pass,7 + + + +visformer_small,pass,7 + + + +vit_base_patch16_224,pass,7 + + + +volo_d1_224,pass,7 + + + +xcit_large_24_p8_224,pass_due_to_skip,7 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_torchbench_inference.csv new file mode 100644 index 000000000000..431a91d10669 --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_torchbench_inference.csv @@ -0,0 +1,377 @@ +name,accuracy,graph_breaks + + + +torchrec_dlrm,eager_fail_to_run,0 + + + +BERT_pytorch,pass,0 + + + +Background_Matting,pass_due_to_skip,0 + + + +DALLE2_pytorch,pass,12 + + + +LearningToPaint,pass,0 + + + +Super_SloMo,pass,0 + + + +alexnet,pass,0 + + + +basic_gnn_edgecnn,pass,0 + + + +basic_gnn_gcn,pass,6 + + + +basic_gnn_gin,pass,0 + + + +basic_gnn_sage,pass,0 + + + +cm3leon_generate,pass,4 + + + +dcgan,pass,0 + + + +demucs,pass,3 + + + +densenet121,pass,0 + + + +detectron2_fasterrcnn_r_101_c4,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_101_dc5,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_101_fpn,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_50_c4,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_50_dc5,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0 + + + +detectron2_fcos_r_50_fpn,pass,21 + + + +detectron2_maskrcnn_r_101_c4,eager_fail_to_run,0 + + + +detectron2_maskrcnn_r_101_fpn,eager_fail_to_run,0 + + + +detectron2_maskrcnn_r_50_c4,eager_fail_to_run,0 + + + +detectron2_maskrcnn_r_50_fpn,eager_fail_to_run,0 + + + +dlrm,pass,0 + + + +doctr_det_predictor,pass,5 + + + +doctr_reco_predictor,pass,4 + + + +drq,pass,0 + + + +fastNLP_Bert,pass,4 + + + +functorch_dp_cifar10,pass,0 + + + +functorch_maml_omniglot,pass,0 + + + +hf_Albert,pass,0 + + + +hf_Bart,pass,0 + + + +hf_Bert,pass,0 + + + +hf_Bert_large,pass,0 + + + +hf_BigBird,pass,46 + + + +hf_DistilBert,pass,0 + + + +hf_GPT2,pass,0 + + + +hf_GPT2_large,pass_due_to_skip,0 + + + +hf_Reformer,pass,5 + + + +hf_T5,pass,0 + + + +hf_T5_base,eager_fail_to_run,0 + + + +hf_T5_generate,fail_to_run,5 + + + +hf_T5_large,pass_due_to_skip,0 + + + +hf_Whisper,pass,0 + + + +hf_distil_whisper,pass,0 + + + +lennard_jones,pass,0 + + + +llama,pass,0 + + + +llama_v2_7b_16h,model_fail_to_load,0 + + + +llava,model_fail_to_load,0 + + + +maml,pass_due_to_skip,0 + + + +maml_omniglot,pass,0 + + + +mnasnet1_0,pass,0 + + + +mobilenet_v2,pass,0 + + + +mobilenet_v2_quantized_qat,model_fail_to_load,0 + + + +mobilenet_v3_large,pass,0 + + + +moco,pass,5 + + + +moondream,model_fail_to_load,0 + + + +nanogpt,pass,0 + + + +nvidia_deeprecommender,pass,0 + + + +opacus_cifar10,pass,0 + + + +phlippe_densenet,pass,0 + + + +phlippe_resnet,pass,0 + + + +pyhpc_equation_of_state,pass,0 + + + +pyhpc_isoneutral_mixing,pass,0 + + + +pyhpc_turbulent_kinetic_energy,pass,0 + + + +pytorch_CycleGAN_and_pix2pix,pass,0 + + + +pytorch_stargan,pass,0 + + + +pytorch_unet,pass,0 + + + +resnet152,pass,0 + + + +resnet18,pass,0 + + + +resnet50,pass,0 + + + +resnet50_quantized_qat,model_fail_to_load,0 + + + +resnext50_32x4d,pass,0 + + + +sam,pass,0 + + + +shufflenet_v2_x1_0,pass,0 + + + +soft_actor_critic,pass,0 + + + +speech_transformer,pass,10 + + + +squeezenet1_1,pass,0 + + + +stable_diffusion_text_encoder,pass,0 + + + +stable_diffusion_unet,pass_due_to_skip,0 + + + +timm_efficientnet,pass,0 + + + +timm_regnet,pass,0 + + + +timm_resnest,pass,0 + + + +timm_vision_transformer,pass,0 + + + +timm_vision_transformer_large,pass_due_to_skip,0 + + + +timm_vovnet,pass,0 + + + +torch_multimodal_clip,pass,0 + + + +tts_angular,pass,2 + + + +vgg16,pass,0 + + + +vision_maskrcnn,pass,17 + + + +yolov3,pass,2 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_torchbench_training.csv new file mode 100644 index 000000000000..1e1a4be4149e --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_torchbench_training.csv @@ -0,0 +1,285 @@ +name,accuracy,graph_breaks + + + +torchrec_dlrm,fail_to_run,3 + + + +BERT_pytorch,pass,6 + + + +Background_Matting,pass_due_to_skip,0 + + + +DALLE2_pytorch,eager_fail_to_run,0 + + + +LearningToPaint,pass,6 + + + +Super_SloMo,pass,7 + + + +alexnet,pass,6 + + + +basic_gnn_edgecnn,pass,22 + + + +basic_gnn_gcn,pass,13 + + + +basic_gnn_gin,pass,7 + + + +basic_gnn_sage,pass,7 + + + +dcgan,pass,6 + + + +demucs,pass,9 + + + +densenet121,pass,6 + + + +detectron2_maskrcnn_r_50_c4,eager_fail_to_run,0 + + + +dlrm,pass,6 + + + +drq,pass,6 + + + +fastNLP_Bert,pass,10 + + + +functorch_dp_cifar10,pass,7 + + + +functorch_maml_omniglot,pass,7 + + + +hf_Albert,pass,6 + + + +hf_Bart,pass,6 + + + +hf_Bert,pass,6 + + + +hf_Bert_large,pass,6 + + + +hf_BigBird,pass,52 + + + +hf_DistilBert,pass,6 + + + +hf_GPT2,pass,6 + + + +hf_GPT2_large,pass_due_to_skip,0 + + + +hf_Reformer,pass,26 + + + +hf_T5_base,eager_2nd_run_OOM,0 + + + +hf_T5_large,pass_due_to_skip,0 + + + +hf_Whisper,pass,6 + + + +hf_distil_whisper,model_fail_to_load,0 + + + +lennard_jones,pass,7 + + + +llava,model_fail_to_load,0 + + + +maml_omniglot,pass,7 + + + +mnasnet1_0,pass,7 + + + +mobilenet_v2,pass,6 + + + +mobilenet_v2_quantized_qat,eager_fail_to_run,0 + + + +mobilenet_v3_large,pass,7 + + + +moco,pass,11 + + + +nanogpt,pass,7 + + + +nvidia_deeprecommender,pass,7 + + + +opacus_cifar10,eager_fail_to_run,0 + + + +phlippe_densenet,pass,6 + + + +phlippe_resnet,pass,6 + + + +pytorch_CycleGAN_and_pix2pix,pass,6 + + + +pytorch_stargan,pass,6 + + + +pytorch_unet,pass_due_to_skip,7 + + + +resnet152,pass,7 + + + +resnet18,pass,6 + + + +resnet50,pass,6 + + + +resnet50_quantized_qat,eager_fail_to_run,0 + + + +resnext50_32x4d,pass,7 + + + +sam,eager_fail_to_run,0 + + + +shufflenet_v2_x1_0,pass,6 + + + +soft_actor_critic,pass,6 + + + +squeezenet1_1,pass,6 + + + +stable_diffusion_text_encoder,pass,5 + + + +stable_diffusion_unet,pass_due_to_skip,0 + + + +timm_efficientnet,pass,7 + + + +timm_regnet,pass,6 + + + +timm_resnest,pass,7 + + + +timm_vision_transformer,pass,6 + + + +timm_vision_transformer_large,pass_due_to_skip,0 + + + +timm_vovnet,pass,6 + + + +torch_multimodal_clip,pass,7 + + + +tts_angular,pass,9 + + + +vgg16,pass,6 + + + +vision_maskrcnn,pass,34 + + + +yolov3,pass,9 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_cpu_inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_cpu_inductor_huggingface_inference.csv new file mode 100644 index 000000000000..349239b058a7 --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_cpu_inductor_huggingface_inference.csv @@ -0,0 +1,185 @@ +name,accuracy,graph_breaks + + + +AlbertForMaskedLM,pass,0 + + + +AlbertForQuestionAnswering,pass,0 + + + +AllenaiLongformerBase,pass,4 + + + +BartForCausalLM,pass,0 + + + +BartForConditionalGeneration,pass,0 + + + +BertForMaskedLM,pass,0 + + + +BertForQuestionAnswering,pass,0 + + + +BlenderbotForCausalLM,pass_due_to_skip,0 + + + +BlenderbotSmallForCausalLM,pass,0 + + + +BlenderbotSmallForConditionalGeneration,pass,0 + + + +CamemBert,pass,0 + + + +DebertaForMaskedLM,pass,0 + + + +DebertaForQuestionAnswering,pass,0 + + + +DebertaV2ForMaskedLM,pass_due_to_skip,0 + + + +DebertaV2ForQuestionAnswering,pass,0 + + + +DistilBertForMaskedLM,pass,0 + + + +DistilBertForQuestionAnswering,pass,0 + + + +DistillGPT2,pass,0 + + + +ElectraForCausalLM,pass,0 + + + +ElectraForQuestionAnswering,pass,0 + + + +GPT2ForSequenceClassification,pass,2 + + + +GoogleFnet,pass,0 + + + +LayoutLMForMaskedLM,pass,0 + + + +LayoutLMForSequenceClassification,pass,2 + + + +M2M100ForConditionalGeneration,pass,0 + + + +MBartForCausalLM,pass,0 + + + +MBartForConditionalGeneration,pass,0 + + + +MT5ForConditionalGeneration,pass,0 + + + +MegatronBertForCausalLM,pass,0 + + + +MegatronBertForQuestionAnswering,pass,0 + + + +MobileBertForMaskedLM,pass,0 + + + +MobileBertForQuestionAnswering,pass,0 + + + +OPTForCausalLM,pass,0 + + + +PLBartForCausalLM,pass,0 + + + +PLBartForConditionalGeneration,pass,0 + + + +PegasusForCausalLM,pass,0 + + + +PegasusForConditionalGeneration,pass,0 + + + +RobertaForCausalLM,pass,0 + + + +RobertaForQuestionAnswering,pass,0 + + + +Speech2Text2ForCausalLM,pass,0 + + + +T5ForConditionalGeneration,pass,0 + + + +T5Small,pass,0 + + + +TrOCRForCausalLM,pass,0 + + + +XGLMForCausalLM,pass,0 + + + +XLNetLMHeadModel,pass,0 + + + +YituTechConvBert,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_cpu_inductor_timm_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_cpu_inductor_timm_inference.csv new file mode 100644 index 000000000000..c889ba0e8d2f --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_cpu_inductor_timm_inference.csv @@ -0,0 +1,245 @@ +name,accuracy,graph_breaks + + + +adv_inception_v3,pass,0 + + + +beit_base_patch16_224,pass,0 + + + +botnet26t_256,pass,0 + + + +cait_m36_384,pass,0 + + + +coat_lite_mini,pass,0 + + + +convit_base,pass,0 + + + +convmixer_768_32,pass,0 + + + +convnext_base,pass,0 + + + +crossvit_9_240,pass,0 + + + +cspdarknet53,pass,0 + + + +deit_base_distilled_patch16_224,pass,0 + + + +dla102,pass,0 + + + +dm_nfnet_f0,pass,0 + + + +dpn107,pass,0 + + + +eca_botnext26ts_256,pass,0 + + + +eca_halonext26ts,pass,0 + + + +ese_vovnet19b_dw,pass,0 + + + +fbnetc_100,pass,0 + + + +fbnetv3_b,pass,0 + + + +gernet_l,pass,0 + + + +ghostnet_100,pass,0 + + + +gluon_inception_v3,pass,0 + + + +gmixer_24_224,pass,0 + + + +gmlp_s16_224,pass,0 + + + +hrnet_w18,pass,0 + + + +inception_v3,pass,0 + + + +jx_nest_base,pass,0 + + + +lcnet_050,pass,0 + + + +levit_128,pass,0 + + + +mixer_b16_224,pass,0 + + + +mixnet_l,pass,0 + + + +mnasnet_100,pass,0 + + + +mobilenetv2_100,pass,0 + + + +mobilenetv3_large_100,pass,0 + + + +mobilevit_s,pass,0 + + + +nfnet_l0,pass,0 + + + +pit_b_224,pass,0 + + + +pnasnet5large,pass,0 + + + +poolformer_m36,pass,0 + + + +regnety_002,pass,0 + + + +repvgg_a2,pass,0 + + + +res2net101_26w_4s,pass,0 + + + +res2net50_14w_8s,pass,0 + + + +res2next50,pass,0 + + + +resmlp_12_224,pass,0 + + + +resnest101e,pass,0 + + + +rexnet_100,pass,0 + + + +sebotnet33ts_256,pass,0 + + + +selecsls42b,pass,0 + + + +spnasnet_100,pass,0 + + + +swin_base_patch4_window7_224,pass,0 + + + +swsl_resnext101_32x16d,pass,0 + + + +tf_efficientnet_b0,pass,0 + + + +tf_mixnet_l,pass,0 + + + +tinynet_a,pass,0 + + + +tnt_s_patch16_224,pass,0 + + + +twins_pcpvt_base,pass,0 + + + +visformer_small,pass,0 + + + +vit_base_patch16_224,pass,0 + + + +volo_d1_224,pass,0 + + + +xcit_large_24_p8_224,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_cpu_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_cpu_inductor_torchbench_inference.csv new file mode 100644 index 000000000000..ce271939b18c --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_cpu_inductor_torchbench_inference.csv @@ -0,0 +1,301 @@ +name,accuracy,graph_breaks + + + +BERT_pytorch,pass,0 + + + +Background_Matting,pass_due_to_skip,0 + + + +DALLE2_pytorch,model_fail_to_load,0 + + + +LearningToPaint,pass,0 + + + +Super_SloMo,pass,0 + + + +alexnet,pass,0 + + + +basic_gnn_edgecnn,pass,0 + + + +basic_gnn_gcn,pass,6 + + + +basic_gnn_gin,pass,0 + + + +basic_gnn_sage,pass,0 + + + +dcgan,pass,0 + + + +demucs,pass,3 + + + +densenet121,pass,0 + + + +detectron2_fcos_r_50_fpn,pass,23 + + + +dlrm,pass,0 + + + +doctr_det_predictor,pass,5 + + + +doctr_reco_predictor,pass,4 + + + +drq,pass,0 + + + +fastNLP_Bert,pass,4 + + + +functorch_dp_cifar10,pass,0 + + + +functorch_maml_omniglot,pass,0 + + + +hf_Albert,pass,0 + + + +hf_Bart,pass,0 + + + +hf_Bert,pass,0 + + + +hf_Bert_large,pass,0 + + + +hf_DistilBert,pass,0 + + + +hf_GPT2,pass,0 + + + +hf_GPT2_large,pass_due_to_skip,0 + + + +hf_Reformer,pass,5 + + + +hf_T5_base,pass,0 + + + +hf_T5_large,pass_due_to_skip,0 + + + +hf_distil_whisper,pass,0 + + + +lennard_jones,pass,0 + + + +llama,pass,0 + + + +maml,pass_due_to_skip,0 + + + +maml_omniglot,pass,0 + + + +mnasnet1_0,pass,0 + + + +mobilenet_v2,pass,0 + + + +mobilenet_v2_quantized_qat,pass,2 + + + +mobilenet_v3_large,pass,0 + + + +moco,model_fail_to_load,0 + + + +moondream,pass,0 + + + +nvidia_deeprecommender,pass,0 + + + +opacus_cifar10,pass,0 + + + +phlippe_densenet,pass,0 + + + +phlippe_resnet,pass,0 + + + +pyhpc_equation_of_state,pass,0 + + + +pyhpc_isoneutral_mixing,pass,0 + + + +pyhpc_turbulent_kinetic_energy,pass,0 + + + +pytorch_CycleGAN_and_pix2pix,pass,0 + + + +pytorch_stargan,pass,0 + + + +pytorch_unet,pass,0 + + + +resnet152,pass,0 + + + +resnet18,pass,0 + + + +resnet50,pass,0 + + + +resnet50_quantized_qat,pass,2 + + + +resnext50_32x4d,pass,0 + + + +shufflenet_v2_x1_0,pass,0 + + + +soft_actor_critic,pass,0 + + + +speech_transformer,pass,10 + + + +squeezenet1_1,pass,0 + + + +stable_diffusion_unet,pass_due_to_skip,0 + + + +timm_efficientdet,model_fail_to_load,0 + + + +timm_efficientnet,pass,0 + + + +timm_nfnet,pass,0 + + + +timm_regnet,pass,0 + + + +timm_resnest,pass,0 + + + +timm_vision_transformer,pass,0 + + + +timm_vision_transformer_large,pass_due_to_skip,0 + + + +timm_vovnet,pass,0 + + + +torch_multimodal_clip,pass,3 + + + +tts_angular,pass,2 + + + +vgg16,pass,0 + + + +vision_maskrcnn,pass,28 + + + +yolov3,pass,2 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_huggingface_inference.csv new file mode 100644 index 000000000000..349239b058a7 --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_huggingface_inference.csv @@ -0,0 +1,185 @@ +name,accuracy,graph_breaks + + + +AlbertForMaskedLM,pass,0 + + + +AlbertForQuestionAnswering,pass,0 + + + +AllenaiLongformerBase,pass,4 + + + +BartForCausalLM,pass,0 + + + +BartForConditionalGeneration,pass,0 + + + +BertForMaskedLM,pass,0 + + + +BertForQuestionAnswering,pass,0 + + + +BlenderbotForCausalLM,pass_due_to_skip,0 + + + +BlenderbotSmallForCausalLM,pass,0 + + + +BlenderbotSmallForConditionalGeneration,pass,0 + + + +CamemBert,pass,0 + + + +DebertaForMaskedLM,pass,0 + + + +DebertaForQuestionAnswering,pass,0 + + + +DebertaV2ForMaskedLM,pass_due_to_skip,0 + + + +DebertaV2ForQuestionAnswering,pass,0 + + + +DistilBertForMaskedLM,pass,0 + + + +DistilBertForQuestionAnswering,pass,0 + + + +DistillGPT2,pass,0 + + + +ElectraForCausalLM,pass,0 + + + +ElectraForQuestionAnswering,pass,0 + + + +GPT2ForSequenceClassification,pass,2 + + + +GoogleFnet,pass,0 + + + +LayoutLMForMaskedLM,pass,0 + + + +LayoutLMForSequenceClassification,pass,2 + + + +M2M100ForConditionalGeneration,pass,0 + + + +MBartForCausalLM,pass,0 + + + +MBartForConditionalGeneration,pass,0 + + + +MT5ForConditionalGeneration,pass,0 + + + +MegatronBertForCausalLM,pass,0 + + + +MegatronBertForQuestionAnswering,pass,0 + + + +MobileBertForMaskedLM,pass,0 + + + +MobileBertForQuestionAnswering,pass,0 + + + +OPTForCausalLM,pass,0 + + + +PLBartForCausalLM,pass,0 + + + +PLBartForConditionalGeneration,pass,0 + + + +PegasusForCausalLM,pass,0 + + + +PegasusForConditionalGeneration,pass,0 + + + +RobertaForCausalLM,pass,0 + + + +RobertaForQuestionAnswering,pass,0 + + + +Speech2Text2ForCausalLM,pass,0 + + + +T5ForConditionalGeneration,pass,0 + + + +T5Small,pass,0 + + + +TrOCRForCausalLM,pass,0 + + + +XGLMForCausalLM,pass,0 + + + +XLNetLMHeadModel,pass,0 + + + +YituTechConvBert,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_huggingface_training.csv new file mode 100644 index 000000000000..a5e00513153d --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_huggingface_training.csv @@ -0,0 +1,185 @@ +name,accuracy,graph_breaks + + + +AlbertForMaskedLM,pass,4 + + + +AlbertForQuestionAnswering,pass,5 + + + +AllenaiLongformerBase,pass,9 + + + +BartForCausalLM,pass,12 + + + +BartForConditionalGeneration,pass,24 + + + +BertForMaskedLM,pass,5 + + + +BertForQuestionAnswering,pass,5 + + + +BlenderbotForCausalLM,eager_fail_to_run,0 + + + +BlenderbotSmallForCausalLM,pass,12 + + + +BlenderbotSmallForConditionalGeneration,pass,24 + + + +CamemBert,pass,5 + + + +DebertaForMaskedLM,pass,5 + + + +DebertaForQuestionAnswering,pass,5 + + + +DebertaV2ForMaskedLM,pass_due_to_skip,0 + + + +DebertaV2ForQuestionAnswering,eager_1st_run_OOM,0 + + + +DistilBertForMaskedLM,pass,5 + + + +DistilBertForQuestionAnswering,pass,5 + + + +DistillGPT2,pass,5 + + + +ElectraForCausalLM,pass,4 + + + +ElectraForQuestionAnswering,pass,5 + + + +GPT2ForSequenceClassification,pass,7 + + + +GoogleFnet,pass,5 + + + +LayoutLMForMaskedLM,pass,5 + + + +LayoutLMForSequenceClassification,pass,7 + + + +M2M100ForConditionalGeneration,pass,4 + + + +MBartForCausalLM,pass,12 + + + +MBartForConditionalGeneration,pass,24 + + + +MT5ForConditionalGeneration,pass,5 + + + +MegatronBertForCausalLM,pass,5 + + + +MegatronBertForQuestionAnswering,pass,5 + + + +MobileBertForMaskedLM,pass,3 + + + +MobileBertForQuestionAnswering,pass,3 + + + +OPTForCausalLM,pass,12 + + + +PLBartForCausalLM,pass,12 + + + +PLBartForConditionalGeneration,pass,29 + + + +PegasusForCausalLM,pass,12 + + + +PegasusForConditionalGeneration,pass,23 + + + +RobertaForCausalLM,pass,5 + + + +RobertaForQuestionAnswering,pass,5 + + + +Speech2Text2ForCausalLM,pass,12 + + + +T5ForConditionalGeneration,pass,5 + + + +T5Small,pass,5 + + + +TrOCRForCausalLM,pass,12 + + + +XGLMForCausalLM,pass,12 + + + +XLNetLMHeadModel,pass,5 + + + +YituTechConvBert,pass,5 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_timm_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_timm_inference.csv new file mode 100644 index 000000000000..c889ba0e8d2f --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_timm_inference.csv @@ -0,0 +1,245 @@ +name,accuracy,graph_breaks + + + +adv_inception_v3,pass,0 + + + +beit_base_patch16_224,pass,0 + + + +botnet26t_256,pass,0 + + + +cait_m36_384,pass,0 + + + +coat_lite_mini,pass,0 + + + +convit_base,pass,0 + + + +convmixer_768_32,pass,0 + + + +convnext_base,pass,0 + + + +crossvit_9_240,pass,0 + + + +cspdarknet53,pass,0 + + + +deit_base_distilled_patch16_224,pass,0 + + + +dla102,pass,0 + + + +dm_nfnet_f0,pass,0 + + + +dpn107,pass,0 + + + +eca_botnext26ts_256,pass,0 + + + +eca_halonext26ts,pass,0 + + + +ese_vovnet19b_dw,pass,0 + + + +fbnetc_100,pass,0 + + + +fbnetv3_b,pass,0 + + + +gernet_l,pass,0 + + + +ghostnet_100,pass,0 + + + +gluon_inception_v3,pass,0 + + + +gmixer_24_224,pass,0 + + + +gmlp_s16_224,pass,0 + + + +hrnet_w18,pass,0 + + + +inception_v3,pass,0 + + + +jx_nest_base,pass,0 + + + +lcnet_050,pass,0 + + + +levit_128,pass,0 + + + +mixer_b16_224,pass,0 + + + +mixnet_l,pass,0 + + + +mnasnet_100,pass,0 + + + +mobilenetv2_100,pass,0 + + + +mobilenetv3_large_100,pass,0 + + + +mobilevit_s,pass,0 + + + +nfnet_l0,pass,0 + + + +pit_b_224,pass,0 + + + +pnasnet5large,pass,0 + + + +poolformer_m36,pass,0 + + + +regnety_002,pass,0 + + + +repvgg_a2,pass,0 + + + +res2net101_26w_4s,pass,0 + + + +res2net50_14w_8s,pass,0 + + + +res2next50,pass,0 + + + +resmlp_12_224,pass,0 + + + +resnest101e,pass,0 + + + +rexnet_100,pass,0 + + + +sebotnet33ts_256,pass,0 + + + +selecsls42b,pass,0 + + + +spnasnet_100,pass,0 + + + +swin_base_patch4_window7_224,pass,0 + + + +swsl_resnext101_32x16d,pass,0 + + + +tf_efficientnet_b0,pass,0 + + + +tf_mixnet_l,pass,0 + + + +tinynet_a,pass,0 + + + +tnt_s_patch16_224,pass,0 + + + +twins_pcpvt_base,pass,0 + + + +visformer_small,pass,0 + + + +vit_base_patch16_224,pass,0 + + + +volo_d1_224,pass,0 + + + +xcit_large_24_p8_224,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_timm_training.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_timm_training.csv new file mode 100644 index 000000000000..b1a70e91cbae --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_timm_training.csv @@ -0,0 +1,245 @@ +name,accuracy,graph_breaks + + + +adv_inception_v3,pass,6 + + + +beit_base_patch16_224,fail_accuracy,7 + + + +botnet26t_256,pass,6 + + + +cait_m36_384,eager_fail_to_run,0 + + + +coat_lite_mini,pass,6 + + + +convit_base,pass,7 + + + +convmixer_768_32,pass,5 + + + +convnext_base,pass,7 + + + +crossvit_9_240,pass,7 + + + +cspdarknet53,fail_accuracy,7 + + + +deit_base_distilled_patch16_224,pass,7 + + + +dla102,pass,7 + + + +dm_nfnet_f0,pass,6 + + + +dpn107,pass,6 + + + +eca_botnext26ts_256,pass,7 + + + +eca_halonext26ts,pass,7 + + + +ese_vovnet19b_dw,pass,7 + + + +fbnetc_100,pass,7 + + + +fbnetv3_b,pass,6 + + + +gernet_l,pass,6 + + + +ghostnet_100,pass,6 + + + +gluon_inception_v3,fail_accuracy,7 + + + +gmixer_24_224,pass,6 + + + +gmlp_s16_224,pass,7 + + + +hrnet_w18,pass,5 + + + +inception_v3,pass,6 + + + +jx_nest_base,pass,7 + + + +lcnet_050,pass,6 + + + +levit_128,pass,7 + + + +mixer_b16_224,pass,7 + + + +mixnet_l,pass,6 + + + +mnasnet_100,pass,7 + + + +mobilenetv2_100,pass,7 + + + +mobilenetv3_large_100,pass,7 + + + +mobilevit_s,pass,6 + + + +nfnet_l0,pass,7 + + + +pit_b_224,pass,6 + + + +pnasnet5large,pass,5 + + + +poolformer_m36,pass,6 + + + +regnety_002,pass,6 + + + +repvgg_a2,pass,7 + + + +res2net101_26w_4s,pass,6 + + + +res2net50_14w_8s,pass,6 + + + +res2next50,pass,6 + + + +resmlp_12_224,pass,6 + + + +resnest101e,pass,6 + + + +rexnet_100,pass,7 + + + +sebotnet33ts_256,pass,6 + + + +selecsls42b,pass,6 + + + +spnasnet_100,pass,7 + + + +swin_base_patch4_window7_224,pass,7 + + + +swsl_resnext101_32x16d,pass,6 + + + +tf_efficientnet_b0,pass,6 + + + +tf_mixnet_l,pass,6 + + + +tinynet_a,pass,6 + + + +tnt_s_patch16_224,pass,7 + + + +twins_pcpvt_base,pass,7 + + + +visformer_small,pass,7 + + + +vit_base_patch16_224,pass,7 + + + +volo_d1_224,pass,7 + + + +xcit_large_24_p8_224,pass_due_to_skip,7 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_torchbench_inference.csv new file mode 100644 index 000000000000..f652e5ffa91a --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_torchbench_inference.csv @@ -0,0 +1,377 @@ +name,accuracy,graph_breaks + + + +torchrec_dlrm,eager_fail_to_run,0 + + + +BERT_pytorch,pass,0 + + + +Background_Matting,pass_due_to_skip,0 + + + +DALLE2_pytorch,pass,12 + + + +LearningToPaint,pass,0 + + + +Super_SloMo,pass,0 + + + +alexnet,pass,0 + + + +basic_gnn_edgecnn,pass,0 + + + +basic_gnn_gcn,pass,6 + + + +basic_gnn_gin,pass,0 + + + +basic_gnn_sage,pass,0 + + + +cm3leon_generate,pass,4 + + + +dcgan,pass,0 + + + +demucs,pass,3 + + + +densenet121,pass,0 + + + +detectron2_fasterrcnn_r_101_c4,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_101_dc5,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_101_fpn,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_50_c4,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_50_dc5,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0 + + + +detectron2_fcos_r_50_fpn,pass,21 + + + +detectron2_maskrcnn_r_101_c4,eager_fail_to_run,0 + + + +detectron2_maskrcnn_r_101_fpn,eager_fail_to_run,0 + + + +detectron2_maskrcnn_r_50_c4,eager_fail_to_run,0 + + + +detectron2_maskrcnn_r_50_fpn,eager_fail_to_run,0 + + + +dlrm,pass,0 + + + +doctr_det_predictor,pass,5 + + + +doctr_reco_predictor,pass,4 + + + +drq,pass,0 + + + +fastNLP_Bert,pass,4 + + + +functorch_dp_cifar10,pass,0 + + + +functorch_maml_omniglot,pass,0 + + + +hf_Albert,pass,0 + + + +hf_Bart,pass,0 + + + +hf_Bert,pass,0 + + + +hf_Bert_large,pass,0 + + + +hf_BigBird,fail_accuracy,46 + + + +hf_DistilBert,pass,0 + + + +hf_GPT2,pass,0 + + + +hf_GPT2_large,pass_due_to_skip,0 + + + +hf_Reformer,pass,5 + + + +hf_T5,pass,0 + + + +hf_T5_base,eager_fail_to_run,0 + + + +hf_T5_generate,fail_to_run,5 + + + +hf_T5_large,pass_due_to_skip,0 + + + +hf_Whisper,pass,0 + + + +hf_distil_whisper,pass,0 + + + +lennard_jones,pass,0 + + + +llama,pass,0 + + + +llama_v2_7b_16h,model_fail_to_load,0 + + + +llava,model_fail_to_load,0 + + + +maml,pass_due_to_skip,0 + + + +maml_omniglot,pass,0 + + + +mnasnet1_0,pass,0 + + + +mobilenet_v2,pass,0 + + + +mobilenet_v2_quantized_qat,model_fail_to_load,0 + + + +mobilenet_v3_large,pass,0 + + + +moco,pass,5 + + + +moondream,model_fail_to_load,0 + + + +nanogpt,pass,0 + + + +nvidia_deeprecommender,pass,0 + + + +opacus_cifar10,pass,0 + + + +phlippe_densenet,pass,0 + + + +phlippe_resnet,pass,0 + + + +pyhpc_equation_of_state,pass,0 + + + +pyhpc_isoneutral_mixing,pass,0 + + + +pyhpc_turbulent_kinetic_energy,pass,0 + + + +pytorch_CycleGAN_and_pix2pix,pass,0 + + + +pytorch_stargan,pass,0 + + + +pytorch_unet,pass,0 + + + +resnet152,pass,0 + + + +resnet18,pass,0 + + + +resnet50,pass,0 + + + +resnet50_quantized_qat,model_fail_to_load,0 + + + +resnext50_32x4d,pass,0 + + + +sam,pass,0 + + + +shufflenet_v2_x1_0,pass,0 + + + +soft_actor_critic,pass,0 + + + +speech_transformer,pass,10 + + + +squeezenet1_1,pass,0 + + + +stable_diffusion_text_encoder,pass,0 + + + +stable_diffusion_unet,pass_due_to_skip,0 + + + +timm_efficientnet,pass,0 + + + +timm_regnet,pass,0 + + + +timm_resnest,pass,0 + + + +timm_vision_transformer,pass,0 + + + +timm_vision_transformer_large,pass_due_to_skip,0 + + + +timm_vovnet,pass,0 + + + +torch_multimodal_clip,pass,0 + + + +tts_angular,pass,2 + + + +vgg16,pass,0 + + + +vision_maskrcnn,pass,17 + + + +yolov3,pass,2 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_torchbench_training.csv new file mode 100644 index 000000000000..a3c9c3915fc5 --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_torchbench_training.csv @@ -0,0 +1,285 @@ +name,accuracy,graph_breaks + + + +torchrec_dlrm,fail_to_run,3 + + + +BERT_pytorch,pass,6 + + + +Background_Matting,pass_due_to_skip,0 + + + +DALLE2_pytorch,eager_fail_to_run,0 + + + +LearningToPaint,pass,6 + + + +Super_SloMo,pass,7 + + + +alexnet,pass,6 + + + +basic_gnn_edgecnn,pass,22 + + + +basic_gnn_gcn,pass,13 + + + +basic_gnn_gin,pass,7 + + + +basic_gnn_sage,pass,7 + + + +dcgan,pass,6 + + + +demucs,fail_to_run,4 + + + +densenet121,pass,6 + + + +detectron2_maskrcnn_r_50_c4,eager_fail_to_run,0 + + + +dlrm,pass,6 + + + +drq,pass,6 + + + +fastNLP_Bert,pass,10 + + + +functorch_dp_cifar10,pass,7 + + + +functorch_maml_omniglot,pass,7 + + + +hf_Albert,pass,6 + + + +hf_Bart,pass,6 + + + +hf_Bert,pass,6 + + + +hf_Bert_large,pass,6 + + + +hf_BigBird,pass,52 + + + +hf_DistilBert,pass,6 + + + +hf_GPT2,pass,6 + + + +hf_GPT2_large,pass_due_to_skip,0 + + + +hf_Reformer,pass,26 + + + +hf_T5_base,eager_2nd_run_OOM,0 + + + +hf_T5_large,pass_due_to_skip,0 + + + +hf_Whisper,pass,6 + + + +hf_distil_whisper,model_fail_to_load,0 + + + +lennard_jones,pass,7 + + + +llava,model_fail_to_load,0 + + + +maml_omniglot,pass,7 + + + +mnasnet1_0,pass,7 + + + +mobilenet_v2,pass,6 + + + +mobilenet_v2_quantized_qat,eager_fail_to_run,0 + + + +mobilenet_v3_large,pass,7 + + + +moco,pass,11 + + + +nanogpt,pass,7 + + + +nvidia_deeprecommender,pass,7 + + + +opacus_cifar10,eager_fail_to_run,0 + + + +phlippe_densenet,pass,6 + + + +phlippe_resnet,fail_accuracy,6 + + + +pytorch_CycleGAN_and_pix2pix,pass,6 + + + +pytorch_stargan,pass,6 + + + +pytorch_unet,pass_due_to_skip,7 + + + +resnet152,pass,7 + + + +resnet18,pass,6 + + + +resnet50,pass,6 + + + +resnet50_quantized_qat,eager_fail_to_run,0 + + + +resnext50_32x4d,pass,7 + + + +sam,eager_fail_to_run,0 + + + +shufflenet_v2_x1_0,pass,6 + + + +soft_actor_critic,pass,6 + + + +squeezenet1_1,pass,6 + + + +stable_diffusion_text_encoder,pass,5 + + + +stable_diffusion_unet,pass_due_to_skip,0 + + + +timm_efficientnet,pass,7 + + + +timm_regnet,pass,6 + + + +timm_resnest,pass,7 + + + +timm_vision_transformer,pass,6 + + + +timm_vision_transformer_large,pass_due_to_skip,0 + + + +timm_vovnet,pass,6 + + + +torch_multimodal_clip,pass,7 + + + +tts_angular,pass,9 + + + +vgg16,pass,6 + + + +vision_maskrcnn,pass,34 + + + +yolov3,pass,9 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamo_eager_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamo_eager_huggingface_inference.csv new file mode 100644 index 000000000000..349239b058a7 --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamo_eager_huggingface_inference.csv @@ -0,0 +1,185 @@ +name,accuracy,graph_breaks + + + +AlbertForMaskedLM,pass,0 + + + +AlbertForQuestionAnswering,pass,0 + + + +AllenaiLongformerBase,pass,4 + + + +BartForCausalLM,pass,0 + + + +BartForConditionalGeneration,pass,0 + + + +BertForMaskedLM,pass,0 + + + +BertForQuestionAnswering,pass,0 + + + +BlenderbotForCausalLM,pass_due_to_skip,0 + + + +BlenderbotSmallForCausalLM,pass,0 + + + +BlenderbotSmallForConditionalGeneration,pass,0 + + + +CamemBert,pass,0 + + + +DebertaForMaskedLM,pass,0 + + + +DebertaForQuestionAnswering,pass,0 + + + +DebertaV2ForMaskedLM,pass_due_to_skip,0 + + + +DebertaV2ForQuestionAnswering,pass,0 + + + +DistilBertForMaskedLM,pass,0 + + + +DistilBertForQuestionAnswering,pass,0 + + + +DistillGPT2,pass,0 + + + +ElectraForCausalLM,pass,0 + + + +ElectraForQuestionAnswering,pass,0 + + + +GPT2ForSequenceClassification,pass,2 + + + +GoogleFnet,pass,0 + + + +LayoutLMForMaskedLM,pass,0 + + + +LayoutLMForSequenceClassification,pass,2 + + + +M2M100ForConditionalGeneration,pass,0 + + + +MBartForCausalLM,pass,0 + + + +MBartForConditionalGeneration,pass,0 + + + +MT5ForConditionalGeneration,pass,0 + + + +MegatronBertForCausalLM,pass,0 + + + +MegatronBertForQuestionAnswering,pass,0 + + + +MobileBertForMaskedLM,pass,0 + + + +MobileBertForQuestionAnswering,pass,0 + + + +OPTForCausalLM,pass,0 + + + +PLBartForCausalLM,pass,0 + + + +PLBartForConditionalGeneration,pass,0 + + + +PegasusForCausalLM,pass,0 + + + +PegasusForConditionalGeneration,pass,0 + + + +RobertaForCausalLM,pass,0 + + + +RobertaForQuestionAnswering,pass,0 + + + +Speech2Text2ForCausalLM,pass,0 + + + +T5ForConditionalGeneration,pass,0 + + + +T5Small,pass,0 + + + +TrOCRForCausalLM,pass,0 + + + +XGLMForCausalLM,pass,0 + + + +XLNetLMHeadModel,pass,0 + + + +YituTechConvBert,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamo_eager_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamo_eager_huggingface_training.csv new file mode 100644 index 000000000000..a5e00513153d --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamo_eager_huggingface_training.csv @@ -0,0 +1,185 @@ +name,accuracy,graph_breaks + + + +AlbertForMaskedLM,pass,4 + + + +AlbertForQuestionAnswering,pass,5 + + + +AllenaiLongformerBase,pass,9 + + + +BartForCausalLM,pass,12 + + + +BartForConditionalGeneration,pass,24 + + + +BertForMaskedLM,pass,5 + + + +BertForQuestionAnswering,pass,5 + + + +BlenderbotForCausalLM,eager_fail_to_run,0 + + + +BlenderbotSmallForCausalLM,pass,12 + + + +BlenderbotSmallForConditionalGeneration,pass,24 + + + +CamemBert,pass,5 + + + +DebertaForMaskedLM,pass,5 + + + +DebertaForQuestionAnswering,pass,5 + + + +DebertaV2ForMaskedLM,pass_due_to_skip,0 + + + +DebertaV2ForQuestionAnswering,eager_1st_run_OOM,0 + + + +DistilBertForMaskedLM,pass,5 + + + +DistilBertForQuestionAnswering,pass,5 + + + +DistillGPT2,pass,5 + + + +ElectraForCausalLM,pass,4 + + + +ElectraForQuestionAnswering,pass,5 + + + +GPT2ForSequenceClassification,pass,7 + + + +GoogleFnet,pass,5 + + + +LayoutLMForMaskedLM,pass,5 + + + +LayoutLMForSequenceClassification,pass,7 + + + +M2M100ForConditionalGeneration,pass,4 + + + +MBartForCausalLM,pass,12 + + + +MBartForConditionalGeneration,pass,24 + + + +MT5ForConditionalGeneration,pass,5 + + + +MegatronBertForCausalLM,pass,5 + + + +MegatronBertForQuestionAnswering,pass,5 + + + +MobileBertForMaskedLM,pass,3 + + + +MobileBertForQuestionAnswering,pass,3 + + + +OPTForCausalLM,pass,12 + + + +PLBartForCausalLM,pass,12 + + + +PLBartForConditionalGeneration,pass,29 + + + +PegasusForCausalLM,pass,12 + + + +PegasusForConditionalGeneration,pass,23 + + + +RobertaForCausalLM,pass,5 + + + +RobertaForQuestionAnswering,pass,5 + + + +Speech2Text2ForCausalLM,pass,12 + + + +T5ForConditionalGeneration,pass,5 + + + +T5Small,pass,5 + + + +TrOCRForCausalLM,pass,12 + + + +XGLMForCausalLM,pass,12 + + + +XLNetLMHeadModel,pass,5 + + + +YituTechConvBert,pass,5 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamo_eager_timm_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamo_eager_timm_inference.csv new file mode 100644 index 000000000000..c889ba0e8d2f --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamo_eager_timm_inference.csv @@ -0,0 +1,245 @@ +name,accuracy,graph_breaks + + + +adv_inception_v3,pass,0 + + + +beit_base_patch16_224,pass,0 + + + +botnet26t_256,pass,0 + + + +cait_m36_384,pass,0 + + + +coat_lite_mini,pass,0 + + + +convit_base,pass,0 + + + +convmixer_768_32,pass,0 + + + +convnext_base,pass,0 + + + +crossvit_9_240,pass,0 + + + +cspdarknet53,pass,0 + + + +deit_base_distilled_patch16_224,pass,0 + + + +dla102,pass,0 + + + +dm_nfnet_f0,pass,0 + + + +dpn107,pass,0 + + + +eca_botnext26ts_256,pass,0 + + + +eca_halonext26ts,pass,0 + + + +ese_vovnet19b_dw,pass,0 + + + +fbnetc_100,pass,0 + + + +fbnetv3_b,pass,0 + + + +gernet_l,pass,0 + + + +ghostnet_100,pass,0 + + + +gluon_inception_v3,pass,0 + + + +gmixer_24_224,pass,0 + + + +gmlp_s16_224,pass,0 + + + +hrnet_w18,pass,0 + + + +inception_v3,pass,0 + + + +jx_nest_base,pass,0 + + + +lcnet_050,pass,0 + + + +levit_128,pass,0 + + + +mixer_b16_224,pass,0 + + + +mixnet_l,pass,0 + + + +mnasnet_100,pass,0 + + + +mobilenetv2_100,pass,0 + + + +mobilenetv3_large_100,pass,0 + + + +mobilevit_s,pass,0 + + + +nfnet_l0,pass,0 + + + +pit_b_224,pass,0 + + + +pnasnet5large,pass,0 + + + +poolformer_m36,pass,0 + + + +regnety_002,pass,0 + + + +repvgg_a2,pass,0 + + + +res2net101_26w_4s,pass,0 + + + +res2net50_14w_8s,pass,0 + + + +res2next50,pass,0 + + + +resmlp_12_224,pass,0 + + + +resnest101e,pass,0 + + + +rexnet_100,pass,0 + + + +sebotnet33ts_256,pass,0 + + + +selecsls42b,pass,0 + + + +spnasnet_100,pass,0 + + + +swin_base_patch4_window7_224,pass,0 + + + +swsl_resnext101_32x16d,pass,0 + + + +tf_efficientnet_b0,pass,0 + + + +tf_mixnet_l,pass,0 + + + +tinynet_a,pass,0 + + + +tnt_s_patch16_224,pass,0 + + + +twins_pcpvt_base,pass,0 + + + +visformer_small,pass,0 + + + +vit_base_patch16_224,pass,0 + + + +volo_d1_224,pass,0 + + + +xcit_large_24_p8_224,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamo_eager_timm_training.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamo_eager_timm_training.csv new file mode 100644 index 000000000000..e5464160d32f --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamo_eager_timm_training.csv @@ -0,0 +1,245 @@ +name,accuracy,graph_breaks + + + +adv_inception_v3,pass,6 + + + +beit_base_patch16_224,pass,7 + + + +botnet26t_256,pass,6 + + + +cait_m36_384,eager_fail_to_run,0 + + + +coat_lite_mini,pass,6 + + + +convit_base,pass,7 + + + +convmixer_768_32,pass,5 + + + +convnext_base,pass,7 + + + +crossvit_9_240,pass,7 + + + +cspdarknet53,pass,7 + + + +deit_base_distilled_patch16_224,pass,7 + + + +dla102,pass,7 + + + +dm_nfnet_f0,pass,6 + + + +dpn107,pass,6 + + + +eca_botnext26ts_256,pass,7 + + + +eca_halonext26ts,pass,7 + + + +ese_vovnet19b_dw,pass,7 + + + +fbnetc_100,pass,7 + + + +fbnetv3_b,pass,6 + + + +gernet_l,pass,6 + + + +ghostnet_100,pass,6 + + + +gluon_inception_v3,pass,7 + + + +gmixer_24_224,pass,6 + + + +gmlp_s16_224,pass,7 + + + +hrnet_w18,pass,5 + + + +inception_v3,pass,6 + + + +jx_nest_base,pass,7 + + + +lcnet_050,pass,6 + + + +levit_128,pass,7 + + + +mixer_b16_224,pass,7 + + + +mixnet_l,pass,6 + + + +mnasnet_100,pass,7 + + + +mobilenetv2_100,pass,7 + + + +mobilenetv3_large_100,pass,7 + + + +mobilevit_s,pass,6 + + + +nfnet_l0,pass,7 + + + +pit_b_224,pass,6 + + + +pnasnet5large,pass,5 + + + +poolformer_m36,pass,6 + + + +regnety_002,pass,6 + + + +repvgg_a2,pass,7 + + + +res2net101_26w_4s,pass,6 + + + +res2net50_14w_8s,pass,6 + + + +res2next50,pass,6 + + + +resmlp_12_224,pass,6 + + + +resnest101e,pass,6 + + + +rexnet_100,pass,7 + + + +sebotnet33ts_256,pass,6 + + + +selecsls42b,pass,6 + + + +spnasnet_100,pass,7 + + + +swin_base_patch4_window7_224,pass,7 + + + +swsl_resnext101_32x16d,pass,6 + + + +tf_efficientnet_b0,pass,6 + + + +tf_mixnet_l,pass,6 + + + +tinynet_a,pass,6 + + + +tnt_s_patch16_224,pass,7 + + + +twins_pcpvt_base,pass,7 + + + +visformer_small,pass,7 + + + +vit_base_patch16_224,pass,7 + + + +volo_d1_224,pass,7 + + + +xcit_large_24_p8_224,pass_due_to_skip,7 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamo_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamo_eager_torchbench_inference.csv new file mode 100644 index 000000000000..20fb340690ac --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamo_eager_torchbench_inference.csv @@ -0,0 +1,381 @@ +name,accuracy,graph_breaks + + + +torchrec_dlrm,eager_fail_to_run,0 + + + +BERT_pytorch,pass,0 + + + +Background_Matting,pass_due_to_skip,0 + + + +DALLE2_pytorch,pass,12 + + + +LearningToPaint,pass,0 + + + +Super_SloMo,pass,0 + + + +alexnet,pass,0 + + + +basic_gnn_edgecnn,pass,0 + + + +basic_gnn_gcn,pass,6 + + + +basic_gnn_gin,pass,0 + + + +basic_gnn_sage,pass,0 + + + +cm3leon_generate,pass,4 + + + +dcgan,pass,0 + + + +demucs,pass,3 + + + +densenet121,pass,0 + + + +detectron2_fasterrcnn_r_101_c4,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_101_dc5,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_101_fpn,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_50_c4,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_50_dc5,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0 + + + +detectron2_fcos_r_50_fpn,pass,21 + + + +detectron2_maskrcnn_r_101_c4,eager_fail_to_run,0 + + + +detectron2_maskrcnn_r_101_fpn,eager_fail_to_run,0 + + + +detectron2_maskrcnn_r_50_c4,eager_fail_to_run,0 + + + +detectron2_maskrcnn_r_50_fpn,eager_fail_to_run,0 + + + +dlrm,pass,0 + + + +doctr_det_predictor,pass,5 + + + +doctr_reco_predictor,pass,4 + + + +drq,pass,0 + + + +fastNLP_Bert,pass,4 + + + +functorch_dp_cifar10,pass,0 + + + +functorch_maml_omniglot,pass,0 + + + +hf_Albert,pass,0 + + + +hf_Bart,pass,0 + + + +hf_Bert,pass,0 + + + +hf_Bert_large,pass,0 + + + +hf_BigBird,pass,46 + + + +hf_DistilBert,pass,0 + + + +hf_GPT2,pass,0 + + + +hf_GPT2_large,pass_due_to_skip,0 + + + +hf_Reformer,pass,5 + + + +hf_T5,pass,0 + + + +hf_T5_base,eager_fail_to_run,0 + + + +hf_T5_generate,pass,5 + + + +hf_T5_large,pass_due_to_skip,0 + + + +hf_Whisper,pass,0 + + + +hf_distil_whisper,pass,0 + + + +lennard_jones,pass,0 + + + +llama,pass,0 + + + +llama_v2_7b_16h,model_fail_to_load,0 + + + +llava,model_fail_to_load,0 + + + +maml,pass_due_to_skip,0 + + + +maml_omniglot,pass,0 + + + +mnasnet1_0,pass,0 + + + +mobilenet_v2,pass,0 + + + +mobilenet_v2_quantized_qat,model_fail_to_load,0 + + + +mobilenet_v3_large,pass,0 + + + +moco,pass,5 + + + +moondream,model_fail_to_load,0 + + + +nanogpt,pass,0 + + + +nvidia_deeprecommender,pass,0 + + + +opacus_cifar10,pass,0 + + + +phlippe_densenet,pass,0 + + + +phlippe_resnet,pass,0 + + + +pyhpc_equation_of_state,pass,0 + + + +pyhpc_isoneutral_mixing,pass,0 + + + +pyhpc_turbulent_kinetic_energy,pass,0 + + + +pytorch_CycleGAN_and_pix2pix,pass,0 + + + +pytorch_stargan,pass,0 + + + +pytorch_unet,pass,0 + + + +resnet152,pass,0 + + + +resnet18,pass,0 + + + +resnet50,pass,0 + + + +resnet50_quantized_qat,model_fail_to_load,0 + + + +resnext50_32x4d,pass,0 + + + +sam,pass,0 + + + +sam_fast,pass,0 + + + +shufflenet_v2_x1_0,pass,0 + + + +soft_actor_critic,pass,0 + + + +speech_transformer,pass,10 + + + +squeezenet1_1,pass,0 + + + +stable_diffusion_text_encoder,pass,0 + + + +stable_diffusion_unet,pass_due_to_skip,0 + + + +timm_efficientnet,pass,0 + + + +timm_regnet,pass,0 + + + +timm_resnest,pass,0 + + + +timm_vision_transformer,pass,0 + + + +timm_vision_transformer_large,pass_due_to_skip,0 + + + +timm_vovnet,pass,0 + + + +torch_multimodal_clip,pass,0 + + + +tts_angular,pass,2 + + + +vgg16,pass,0 + + + +vision_maskrcnn,pass,17 + + + +yolov3,pass,2 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamo_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamo_eager_torchbench_training.csv new file mode 100644 index 000000000000..cfc524426644 --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamo_eager_torchbench_training.csv @@ -0,0 +1,289 @@ +name,accuracy,graph_breaks + + + +torchrec_dlrm,pass,6 + + + +BERT_pytorch,pass,6 + + + +Background_Matting,pass_due_to_skip,0 + + + +DALLE2_pytorch,eager_fail_to_run,0 + + + +LearningToPaint,pass,6 + + + +Super_SloMo,pass,7 + + + +alexnet,pass,6 + + + +basic_gnn_edgecnn,pass,22 + + + +basic_gnn_gcn,pass,13 + + + +basic_gnn_gin,pass,7 + + + +basic_gnn_sage,pass,7 + + + +dcgan,pass,6 + + + +demucs,pass,9 + + + +densenet121,pass,6 + + + +detectron2_maskrcnn_r_50_c4,eager_fail_to_run,0 + + + +dlrm,pass,6 + + + +drq,pass,6 + + + +fastNLP_Bert,pass,10 + + + +functorch_dp_cifar10,pass,7 + + + +functorch_maml_omniglot,pass,7 + + + +hf_Albert,pass,6 + + + +hf_Bart,pass,6 + + + +hf_Bert,pass,6 + + + +hf_Bert_large,pass,6 + + + +hf_BigBird,pass,52 + + + +hf_DistilBert,pass,6 + + + +hf_GPT2,pass,6 + + + +hf_GPT2_large,pass_due_to_skip,0 + + + +hf_Reformer,pass,26 + + + +hf_T5_base,eager_2nd_run_OOM,0 + + + +hf_T5_large,pass_due_to_skip,0 + + + +hf_Whisper,pass,6 + + + +hf_distil_whisper,model_fail_to_load,0 + + + +lennard_jones,pass,7 + + + +llava,model_fail_to_load,0 + + + +maml_omniglot,pass,7 + + + +mnasnet1_0,pass,7 + + + +mobilenet_v2,pass,6 + + + +mobilenet_v2_quantized_qat,eager_fail_to_run,0 + + + +mobilenet_v3_large,pass,7 + + + +moco,pass,11 + + + +nanogpt,pass,7 + + + +nvidia_deeprecommender,pass,7 + + + +opacus_cifar10,eager_fail_to_run,0 + + + +phlippe_densenet,pass,6 + + + +phlippe_resnet,pass,6 + + + +pytorch_CycleGAN_and_pix2pix,pass,6 + + + +pytorch_stargan,pass,6 + + + +pytorch_unet,pass_due_to_skip,7 + + + +resnet152,pass,7 + + + +resnet18,pass,6 + + + +resnet50,pass,6 + + + +resnet50_quantized_qat,eager_fail_to_run,0 + + + +resnext50_32x4d,pass,7 + + + +sam,eager_fail_to_run,0 + + + +shufflenet_v2_x1_0,pass,6 + + + +soft_actor_critic,pass,6 + + + +speech_transformer,pass,16 + + + +squeezenet1_1,pass,6 + + + +stable_diffusion_text_encoder,pass,5 + + + +stable_diffusion_unet,pass_due_to_skip,0 + + + +timm_efficientnet,pass,7 + + + +timm_regnet,pass,6 + + + +timm_resnest,pass,7 + + + +timm_vision_transformer,pass,6 + + + +timm_vision_transformer_large,pass_due_to_skip,0 + + + +timm_vovnet,pass,6 + + + +torch_multimodal_clip,pass,7 + + + +tts_angular,pass,9 + + + +vgg16,pass,6 + + + +vision_maskrcnn,pass,34 + + + +yolov3,pass,9 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_huggingface_inference.csv new file mode 100644 index 000000000000..349239b058a7 --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_huggingface_inference.csv @@ -0,0 +1,185 @@ +name,accuracy,graph_breaks + + + +AlbertForMaskedLM,pass,0 + + + +AlbertForQuestionAnswering,pass,0 + + + +AllenaiLongformerBase,pass,4 + + + +BartForCausalLM,pass,0 + + + +BartForConditionalGeneration,pass,0 + + + +BertForMaskedLM,pass,0 + + + +BertForQuestionAnswering,pass,0 + + + +BlenderbotForCausalLM,pass_due_to_skip,0 + + + +BlenderbotSmallForCausalLM,pass,0 + + + +BlenderbotSmallForConditionalGeneration,pass,0 + + + +CamemBert,pass,0 + + + +DebertaForMaskedLM,pass,0 + + + +DebertaForQuestionAnswering,pass,0 + + + +DebertaV2ForMaskedLM,pass_due_to_skip,0 + + + +DebertaV2ForQuestionAnswering,pass,0 + + + +DistilBertForMaskedLM,pass,0 + + + +DistilBertForQuestionAnswering,pass,0 + + + +DistillGPT2,pass,0 + + + +ElectraForCausalLM,pass,0 + + + +ElectraForQuestionAnswering,pass,0 + + + +GPT2ForSequenceClassification,pass,2 + + + +GoogleFnet,pass,0 + + + +LayoutLMForMaskedLM,pass,0 + + + +LayoutLMForSequenceClassification,pass,2 + + + +M2M100ForConditionalGeneration,pass,0 + + + +MBartForCausalLM,pass,0 + + + +MBartForConditionalGeneration,pass,0 + + + +MT5ForConditionalGeneration,pass,0 + + + +MegatronBertForCausalLM,pass,0 + + + +MegatronBertForQuestionAnswering,pass,0 + + + +MobileBertForMaskedLM,pass,0 + + + +MobileBertForQuestionAnswering,pass,0 + + + +OPTForCausalLM,pass,0 + + + +PLBartForCausalLM,pass,0 + + + +PLBartForConditionalGeneration,pass,0 + + + +PegasusForCausalLM,pass,0 + + + +PegasusForConditionalGeneration,pass,0 + + + +RobertaForCausalLM,pass,0 + + + +RobertaForQuestionAnswering,pass,0 + + + +Speech2Text2ForCausalLM,pass,0 + + + +T5ForConditionalGeneration,pass,0 + + + +T5Small,pass,0 + + + +TrOCRForCausalLM,pass,0 + + + +XGLMForCausalLM,pass,0 + + + +XLNetLMHeadModel,pass,0 + + + +YituTechConvBert,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_huggingface_training.csv new file mode 100644 index 000000000000..a5e00513153d --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_huggingface_training.csv @@ -0,0 +1,185 @@ +name,accuracy,graph_breaks + + + +AlbertForMaskedLM,pass,4 + + + +AlbertForQuestionAnswering,pass,5 + + + +AllenaiLongformerBase,pass,9 + + + +BartForCausalLM,pass,12 + + + +BartForConditionalGeneration,pass,24 + + + +BertForMaskedLM,pass,5 + + + +BertForQuestionAnswering,pass,5 + + + +BlenderbotForCausalLM,eager_fail_to_run,0 + + + +BlenderbotSmallForCausalLM,pass,12 + + + +BlenderbotSmallForConditionalGeneration,pass,24 + + + +CamemBert,pass,5 + + + +DebertaForMaskedLM,pass,5 + + + +DebertaForQuestionAnswering,pass,5 + + + +DebertaV2ForMaskedLM,pass_due_to_skip,0 + + + +DebertaV2ForQuestionAnswering,eager_1st_run_OOM,0 + + + +DistilBertForMaskedLM,pass,5 + + + +DistilBertForQuestionAnswering,pass,5 + + + +DistillGPT2,pass,5 + + + +ElectraForCausalLM,pass,4 + + + +ElectraForQuestionAnswering,pass,5 + + + +GPT2ForSequenceClassification,pass,7 + + + +GoogleFnet,pass,5 + + + +LayoutLMForMaskedLM,pass,5 + + + +LayoutLMForSequenceClassification,pass,7 + + + +M2M100ForConditionalGeneration,pass,4 + + + +MBartForCausalLM,pass,12 + + + +MBartForConditionalGeneration,pass,24 + + + +MT5ForConditionalGeneration,pass,5 + + + +MegatronBertForCausalLM,pass,5 + + + +MegatronBertForQuestionAnswering,pass,5 + + + +MobileBertForMaskedLM,pass,3 + + + +MobileBertForQuestionAnswering,pass,3 + + + +OPTForCausalLM,pass,12 + + + +PLBartForCausalLM,pass,12 + + + +PLBartForConditionalGeneration,pass,29 + + + +PegasusForCausalLM,pass,12 + + + +PegasusForConditionalGeneration,pass,23 + + + +RobertaForCausalLM,pass,5 + + + +RobertaForQuestionAnswering,pass,5 + + + +Speech2Text2ForCausalLM,pass,12 + + + +T5ForConditionalGeneration,pass,5 + + + +T5Small,pass,5 + + + +TrOCRForCausalLM,pass,12 + + + +XGLMForCausalLM,pass,12 + + + +XLNetLMHeadModel,pass,5 + + + +YituTechConvBert,pass,5 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_timm_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_timm_inference.csv new file mode 100644 index 000000000000..c889ba0e8d2f --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_timm_inference.csv @@ -0,0 +1,245 @@ +name,accuracy,graph_breaks + + + +adv_inception_v3,pass,0 + + + +beit_base_patch16_224,pass,0 + + + +botnet26t_256,pass,0 + + + +cait_m36_384,pass,0 + + + +coat_lite_mini,pass,0 + + + +convit_base,pass,0 + + + +convmixer_768_32,pass,0 + + + +convnext_base,pass,0 + + + +crossvit_9_240,pass,0 + + + +cspdarknet53,pass,0 + + + +deit_base_distilled_patch16_224,pass,0 + + + +dla102,pass,0 + + + +dm_nfnet_f0,pass,0 + + + +dpn107,pass,0 + + + +eca_botnext26ts_256,pass,0 + + + +eca_halonext26ts,pass,0 + + + +ese_vovnet19b_dw,pass,0 + + + +fbnetc_100,pass,0 + + + +fbnetv3_b,pass,0 + + + +gernet_l,pass,0 + + + +ghostnet_100,pass,0 + + + +gluon_inception_v3,pass,0 + + + +gmixer_24_224,pass,0 + + + +gmlp_s16_224,pass,0 + + + +hrnet_w18,pass,0 + + + +inception_v3,pass,0 + + + +jx_nest_base,pass,0 + + + +lcnet_050,pass,0 + + + +levit_128,pass,0 + + + +mixer_b16_224,pass,0 + + + +mixnet_l,pass,0 + + + +mnasnet_100,pass,0 + + + +mobilenetv2_100,pass,0 + + + +mobilenetv3_large_100,pass,0 + + + +mobilevit_s,pass,0 + + + +nfnet_l0,pass,0 + + + +pit_b_224,pass,0 + + + +pnasnet5large,pass,0 + + + +poolformer_m36,pass,0 + + + +regnety_002,pass,0 + + + +repvgg_a2,pass,0 + + + +res2net101_26w_4s,pass,0 + + + +res2net50_14w_8s,pass,0 + + + +res2next50,pass,0 + + + +resmlp_12_224,pass,0 + + + +resnest101e,pass,0 + + + +rexnet_100,pass,0 + + + +sebotnet33ts_256,pass,0 + + + +selecsls42b,pass,0 + + + +spnasnet_100,pass,0 + + + +swin_base_patch4_window7_224,pass,0 + + + +swsl_resnext101_32x16d,pass,0 + + + +tf_efficientnet_b0,pass,0 + + + +tf_mixnet_l,pass,0 + + + +tinynet_a,pass,0 + + + +tnt_s_patch16_224,pass,0 + + + +twins_pcpvt_base,pass,0 + + + +visformer_small,pass,0 + + + +vit_base_patch16_224,pass,0 + + + +volo_d1_224,pass,0 + + + +xcit_large_24_p8_224,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_timm_training.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_timm_training.csv new file mode 100644 index 000000000000..ae860db793c9 --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_timm_training.csv @@ -0,0 +1,245 @@ +name,accuracy,graph_breaks + + + +adv_inception_v3,pass,6 + + + +beit_base_patch16_224,fail_accuracy,7 + + + +botnet26t_256,pass,6 + + + +cait_m36_384,eager_fail_to_run,0 + + + +coat_lite_mini,pass,6 + + + +convit_base,pass,7 + + + +convmixer_768_32,pass,5 + + + +convnext_base,pass,7 + + + +crossvit_9_240,pass,7 + + + +cspdarknet53,pass,7 + + + +deit_base_distilled_patch16_224,pass,7 + + + +dla102,pass,7 + + + +dm_nfnet_f0,pass,6 + + + +dpn107,pass,6 + + + +eca_botnext26ts_256,pass,7 + + + +eca_halonext26ts,pass,7 + + + +ese_vovnet19b_dw,pass,7 + + + +fbnetc_100,pass,7 + + + +fbnetv3_b,pass,6 + + + +gernet_l,pass,6 + + + +ghostnet_100,pass,6 + + + +gluon_inception_v3,pass,7 + + + +gmixer_24_224,pass,6 + + + +gmlp_s16_224,pass,7 + + + +hrnet_w18,pass,5 + + + +inception_v3,pass,6 + + + +jx_nest_base,pass,7 + + + +lcnet_050,pass,6 + + + +levit_128,pass,7 + + + +mixer_b16_224,pass,7 + + + +mixnet_l,pass,6 + + + +mnasnet_100,pass,7 + + + +mobilenetv2_100,pass,7 + + + +mobilenetv3_large_100,pass,7 + + + +mobilevit_s,pass,6 + + + +nfnet_l0,pass,7 + + + +pit_b_224,pass,6 + + + +pnasnet5large,pass,5 + + + +poolformer_m36,pass,6 + + + +regnety_002,pass,6 + + + +repvgg_a2,pass,7 + + + +res2net101_26w_4s,pass,6 + + + +res2net50_14w_8s,pass,6 + + + +res2next50,pass,6 + + + +resmlp_12_224,pass,6 + + + +resnest101e,pass,6 + + + +rexnet_100,pass,7 + + + +sebotnet33ts_256,pass,6 + + + +selecsls42b,pass,6 + + + +spnasnet_100,pass,7 + + + +swin_base_patch4_window7_224,pass,7 + + + +swsl_resnext101_32x16d,pass,6 + + + +tf_efficientnet_b0,pass,6 + + + +tf_mixnet_l,pass,6 + + + +tinynet_a,pass,6 + + + +tnt_s_patch16_224,pass,7 + + + +twins_pcpvt_base,pass,7 + + + +visformer_small,pass,7 + + + +vit_base_patch16_224,pass,7 + + + +volo_d1_224,pass,7 + + + +xcit_large_24_p8_224,pass_due_to_skip,7 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_torchbench_inference.csv new file mode 100644 index 000000000000..108bc6543aa9 --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_torchbench_inference.csv @@ -0,0 +1,381 @@ +name,accuracy,graph_breaks + + + +torchrec_dlrm,eager_fail_to_run,0 + + + +BERT_pytorch,pass,0 + + + +Background_Matting,pass_due_to_skip,0 + + + +DALLE2_pytorch,pass,12 + + + +LearningToPaint,pass,0 + + + +Super_SloMo,pass,0 + + + +alexnet,pass,0 + + + +basic_gnn_edgecnn,pass,0 + + + +basic_gnn_gcn,pass,6 + + + +basic_gnn_gin,pass,0 + + + +basic_gnn_sage,pass,0 + + + +cm3leon_generate,pass,4 + + + +dcgan,pass,0 + + + +demucs,pass,3 + + + +densenet121,pass,0 + + + +detectron2_fasterrcnn_r_101_c4,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_101_dc5,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_101_fpn,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_50_c4,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_50_dc5,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0 + + + +detectron2_fcos_r_50_fpn,pass,21 + + + +detectron2_maskrcnn_r_101_c4,eager_fail_to_run,0 + + + +detectron2_maskrcnn_r_101_fpn,eager_fail_to_run,0 + + + +detectron2_maskrcnn_r_50_c4,eager_fail_to_run,0 + + + +detectron2_maskrcnn_r_50_fpn,eager_fail_to_run,0 + + + +dlrm,pass,0 + + + +doctr_det_predictor,pass,5 + + + +doctr_reco_predictor,pass,4 + + + +drq,pass,0 + + + +fastNLP_Bert,pass,4 + + + +functorch_dp_cifar10,pass,0 + + + +functorch_maml_omniglot,pass,0 + + + +hf_Albert,pass,0 + + + +hf_Bart,pass,0 + + + +hf_Bert,pass,0 + + + +hf_Bert_large,pass,0 + + + +hf_BigBird,fail_accuracy,46 + + + +hf_DistilBert,pass,0 + + + +hf_GPT2,pass,0 + + + +hf_GPT2_large,pass_due_to_skip,0 + + + +hf_Reformer,pass,5 + + + +hf_T5,pass,0 + + + +hf_T5_base,eager_fail_to_run,0 + + + +hf_T5_generate,pass,5 + + + +hf_T5_large,pass_due_to_skip,0 + + + +hf_Whisper,pass,0 + + + +hf_distil_whisper,pass,0 + + + +lennard_jones,pass,0 + + + +llama,pass,0 + + + +llama_v2_7b_16h,model_fail_to_load,0 + + + +llava,model_fail_to_load,0 + + + +maml,pass_due_to_skip,0 + + + +maml_omniglot,pass,0 + + + +mnasnet1_0,pass,0 + + + +mobilenet_v2,pass,0 + + + +mobilenet_v2_quantized_qat,model_fail_to_load,0 + + + +mobilenet_v3_large,pass,0 + + + +moco,pass,5 + + + +moondream,model_fail_to_load,0 + + + +nanogpt,pass,0 + + + +nvidia_deeprecommender,pass,0 + + + +opacus_cifar10,pass,0 + + + +phlippe_densenet,pass,0 + + + +phlippe_resnet,pass,0 + + + +pyhpc_equation_of_state,pass,0 + + + +pyhpc_isoneutral_mixing,pass,0 + + + +pyhpc_turbulent_kinetic_energy,pass,0 + + + +pytorch_CycleGAN_and_pix2pix,pass,0 + + + +pytorch_stargan,pass,0 + + + +pytorch_unet,pass,0 + + + +resnet152,pass,0 + + + +resnet18,pass,0 + + + +resnet50,pass,0 + + + +resnet50_quantized_qat,model_fail_to_load,0 + + + +resnext50_32x4d,pass,0 + + + +sam,pass,0 + + + +sam_fast,pass,0 + + + +shufflenet_v2_x1_0,pass,0 + + + +soft_actor_critic,pass,0 + + + +speech_transformer,pass,10 + + + +squeezenet1_1,pass,0 + + + +stable_diffusion_text_encoder,pass,0 + + + +stable_diffusion_unet,pass_due_to_skip,0 + + + +timm_efficientnet,pass,0 + + + +timm_regnet,pass,0 + + + +timm_resnest,pass,0 + + + +timm_vision_transformer,pass,0 + + + +timm_vision_transformer_large,pass_due_to_skip,0 + + + +timm_vovnet,pass,0 + + + +torch_multimodal_clip,pass,0 + + + +tts_angular,pass,2 + + + +vgg16,pass,0 + + + +vision_maskrcnn,pass,17 + + + +yolov3,pass,2 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_torchbench_training.csv new file mode 100644 index 000000000000..02411bef6cc5 --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_torchbench_training.csv @@ -0,0 +1,289 @@ +name,accuracy,graph_breaks + + + +torchrec_dlrm,pass,6 + + + +BERT_pytorch,pass,6 + + + +Background_Matting,pass_due_to_skip,0 + + + +DALLE2_pytorch,eager_fail_to_run,0 + + + +LearningToPaint,pass,6 + + + +Super_SloMo,pass,7 + + + +alexnet,pass,6 + + + +basic_gnn_edgecnn,pass,22 + + + +basic_gnn_gcn,pass,13 + + + +basic_gnn_gin,pass,7 + + + +basic_gnn_sage,pass,7 + + + +dcgan,pass,6 + + + +demucs,pass,9 + + + +densenet121,pass,6 + + + +detectron2_maskrcnn_r_50_c4,eager_fail_to_run,0 + + + +dlrm,pass,6 + + + +drq,pass,6 + + + +fastNLP_Bert,pass,10 + + + +functorch_dp_cifar10,pass,7 + + + +functorch_maml_omniglot,pass,7 + + + +hf_Albert,pass,6 + + + +hf_Bart,pass,6 + + + +hf_Bert,pass,6 + + + +hf_Bert_large,pass,6 + + + +hf_BigBird,pass,52 + + + +hf_DistilBert,pass,6 + + + +hf_GPT2,pass,6 + + + +hf_GPT2_large,pass_due_to_skip,0 + + + +hf_Reformer,pass,26 + + + +hf_T5_base,eager_2nd_run_OOM,0 + + + +hf_T5_large,pass_due_to_skip,0 + + + +hf_Whisper,pass,6 + + + +hf_distil_whisper,model_fail_to_load,0 + + + +lennard_jones,pass,7 + + + +llava,model_fail_to_load,0 + + + +maml_omniglot,pass,7 + + + +mnasnet1_0,pass,7 + + + +mobilenet_v2,pass,6 + + + +mobilenet_v2_quantized_qat,eager_fail_to_run,0 + + + +mobilenet_v3_large,pass,7 + + + +moco,pass,11 + + + +nanogpt,pass,7 + + + +nvidia_deeprecommender,pass,7 + + + +opacus_cifar10,eager_fail_to_run,0 + + + +phlippe_densenet,pass,6 + + + +phlippe_resnet,fail_accuracy,6 + + + +pytorch_CycleGAN_and_pix2pix,pass,6 + + + +pytorch_stargan,pass,6 + + + +pytorch_unet,pass_due_to_skip,7 + + + +resnet152,pass,7 + + + +resnet18,pass,6 + + + +resnet50,pass,6 + + + +resnet50_quantized_qat,eager_fail_to_run,0 + + + +resnext50_32x4d,pass,7 + + + +sam,eager_fail_to_run,0 + + + +shufflenet_v2_x1_0,pass,6 + + + +soft_actor_critic,pass,6 + + + +speech_transformer,pass,16 + + + +squeezenet1_1,pass,6 + + + +stable_diffusion_text_encoder,pass,5 + + + +stable_diffusion_unet,pass_due_to_skip,0 + + + +timm_efficientnet,pass,7 + + + +timm_regnet,pass,6 + + + +timm_resnest,pass,7 + + + +timm_vision_transformer,pass,6 + + + +timm_vision_transformer_large,pass_due_to_skip,0 + + + +timm_vovnet,pass,6 + + + +torch_multimodal_clip,pass,7 + + + +tts_angular,pass,9 + + + +vgg16,pass,6 + + + +vision_maskrcnn,pass,34 + + + +yolov3,pass,9 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/update_expected.py b/benchmarks/dynamo/ci_expected_accuracy/cu124/update_expected.py new file mode 100644 index 000000000000..5d73cf658c17 --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/update_expected.py @@ -0,0 +1,172 @@ +""" +Update commited CSV files used as reference points by dynamo/inductor CI. + +Currently only cares about graph breaks, so only saves those columns. + +Hardcodes a list of job names and artifacts per job, but builds the lookup +by querying github sha and finding associated github actions workflow ID and CI jobs, +downloading artifact zips, extracting CSVs and filtering them. + +Usage: + +python benchmarks/dynamo/ci_expected_accuracy.py + +Known limitations: +- doesn't handle 'retry' jobs in CI, if the same hash has more than one set of artifacts, gets the first one +""" + +import argparse +import json +import os +import pathlib +import subprocess +import sys +import urllib +from io import BytesIO +from itertools import product +from urllib.request import urlopen +from zipfile import ZipFile + +import pandas as pd +import requests + +# Note: the public query url targets this rockset lambda: +# https://console.rockset.com/lambdas/details/commons.artifacts +ARTIFACTS_QUERY_URL = "https://api.usw2a1.rockset.com/v1/public/shared_lambdas/4ca0033e-0117-41f5-b043-59cde19eff35" +CSV_LINTER = str( + pathlib.Path(__file__).absolute().parent.parent.parent.parent + / "tools/linter/adapters/no_merge_conflict_csv_linter.py" +) + + +def query_job_sha(repo, sha): + params = { + "parameters": [ + {"name": "sha", "type": "string", "value": sha}, + {"name": "repo", "type": "string", "value": repo}, + ] + } + + r = requests.post(url=ARTIFACTS_QUERY_URL, json=params) + data = r.json() + return data["results"] + + +def parse_job_name(job_str): + return (part.strip() for part in job_str.split("/")) + + +def parse_test_str(test_str): + return (part.strip() for part in test_str[6:].strip(")").split(",")) + + +S3_BASE_URL = "https://gha-artifacts.s3.amazonaws.com" + + +def get_artifacts_urls(results, suites): + urls = {} + for r in results: + if ( + r["workflowName"] in ("inductor", "inductor-periodic") + and "test" in r["jobName"] + ): + config_str, test_str = parse_job_name(r["jobName"]) + suite, shard_id, num_shards, machine, *_ = parse_test_str(test_str) + workflowId = r["workflowId"] + id = r["id"] + runAttempt = r["runAttempt"] + + if suite in suites: + artifact_filename = f"test-reports-test-{suite}-{shard_id}-{num_shards}-{machine}_{id}.zip" + s3_url = f"{S3_BASE_URL}/{repo}/{workflowId}/{runAttempt}/artifact/{artifact_filename}" + urls[(suite, int(shard_id))] = s3_url + print(f"{suite} {shard_id}, {num_shards}: {s3_url}") + return urls + + +def normalize_suite_filename(suite_name): + strs = suite_name.split("_") + subsuite = strs[-1] + if "timm" in subsuite: + subsuite = subsuite.replace("timm", "timm_models") + + return subsuite + + +def download_artifacts_and_extract_csvs(urls): + dataframes = {} + for (suite, shard), url in urls.items(): + try: + resp = urlopen(url) + subsuite = normalize_suite_filename(suite) + artifact = ZipFile(BytesIO(resp.read())) + for phase in ("training", "inference"): + name = f"test/test-reports/{phase}_{subsuite}.csv" + try: + df = pd.read_csv(artifact.open(name)) + df["graph_breaks"] = df["graph_breaks"].fillna(0).astype(int) + prev_df = dataframes.get((suite, phase), None) + dataframes[(suite, phase)] = ( + pd.concat([prev_df, df]) if prev_df is not None else df + ) + except KeyError: + print( + f"Warning: Unable to find {name} in artifacts file from {url}, continuing" + ) + except urllib.error.HTTPError: + print(f"Unable to download {url}, perhaps the CI job isn't finished?") + + return dataframes + + +def write_filtered_csvs(root_path, dataframes): + for (suite, phase), df in dataframes.items(): + out_fn = os.path.join(root_path, f"{suite}_{phase}.csv") + df.to_csv(out_fn, index=False, columns=["name", "accuracy", "graph_breaks"]) + apply_lints(out_fn) + + +def apply_lints(filename): + patch = json.loads(subprocess.check_output([sys.executable, CSV_LINTER, filename])) + if patch.get("replacement"): + with open(filename) as fd: + data = fd.read().replace(patch["original"], patch["replacement"]) + with open(filename, "w") as fd: + fd.write(data) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter + ) + + parser.add_argument("sha") + args = parser.parse_args() + + repo = "pytorch/pytorch" + + suites = { + f"{a}_{b}" + for a, b in product( + [ + "aot_eager", + "aot_inductor", + "cpu_inductor", + "dynamic_aot_eager", + "dynamic_cpu_inductor", + "dynamic_inductor", + "dynamo_eager", + "inductor", + ], + ["huggingface", "timm", "torchbench"], + ) + } + + root_path = "benchmarks/dynamo/ci_expected_accuracy/" + assert os.path.exists(root_path), f"cd and ensure {root_path} exists" + + results = query_job_sha(repo, args.sha) + urls = get_artifacts_urls(results, suites) + dataframes = download_artifacts_and_extract_csvs(urls) + write_filtered_csvs(root_path, dataframes) + print("Success. Now, confirm the changes to .csvs and `git add` them if satisfied.") From 225ec08e35c42ef3108dd7c6810f82305841a41e Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Wed, 29 May 2024 20:53:46 -0700 Subject: [PATCH 174/706] Fix typo in .ci/docker/ubuntu-cuda/Dockerfile (#127503) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127503 Approved by: https://github.com/nWEIdia, https://github.com/Skylion007 --- .ci/docker/ubuntu-cuda/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.ci/docker/ubuntu-cuda/Dockerfile b/.ci/docker/ubuntu-cuda/Dockerfile index f96ee5e3b107..cb3ea502d231 100644 --- a/.ci/docker/ubuntu-cuda/Dockerfile +++ b/.ci/docker/ubuntu-cuda/Dockerfile @@ -152,7 +152,7 @@ RUN rm install_cusparselt.sh RUN if [ -h /usr/local/cuda-11.6/cuda-11.6 ]; then rm /usr/local/cuda-11.6/cuda-11.6; fi RUN if [ -h /usr/local/cuda-11.7/cuda-11.7 ]; then rm /usr/local/cuda-11.7/cuda-11.7; fi RUN if [ -h /usr/local/cuda-12.1/cuda-12.1 ]; then rm /usr/local/cuda-12.1/cuda-12.1; fi -RUN if [ -h /usr/local/cuda-12.1/cuda-12.4 ]; then rm /usr/local/cuda-12.1/cuda-12.4; fi +RUN if [ -h /usr/local/cuda-12.4/cuda-12.4 ]; then rm /usr/local/cuda-12.4/cuda-12.4; fi USER jenkins CMD ["bash"] From 58b461d57adef45949418e0f594dd4f2892f8ece Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 31 May 2024 16:51:36 +0000 Subject: [PATCH 175/706] Revert "[ROCm] Update triton pin to fix libtanh issue (#125396)" This reverts commit 19333d1eb9b8965edd6c8a52fd59b5c67b4fb523. Reverted https://github.com/pytorch/pytorch/pull/125396 on behalf of https://github.com/atalman due to Broke nightly builds ([comment](https://github.com/pytorch/pytorch/pull/125396#issuecomment-2142638237)) --- .ci/docker/ci_commit_pins/triton-rocm.txt | 2 +- test/inductor/test_cpu_cpp_wrapper.py | 14 ++------------ test/inductor/test_triton_kernels.py | 1 - 3 files changed, 3 insertions(+), 14 deletions(-) diff --git a/.ci/docker/ci_commit_pins/triton-rocm.txt b/.ci/docker/ci_commit_pins/triton-rocm.txt index 15f681977a12..2df035af1fdd 100644 --- a/.ci/docker/ci_commit_pins/triton-rocm.txt +++ b/.ci/docker/ci_commit_pins/triton-rocm.txt @@ -1 +1 @@ -01cbe5045a6898c9a925f01435c8277b2fe6afcc +bbe6246e37d8aa791c67daaf9d9d61b26c9ccfdc diff --git a/test/inductor/test_cpu_cpp_wrapper.py b/test/inductor/test_cpu_cpp_wrapper.py index 0f7430ad2696..477193664431 100644 --- a/test/inductor/test_cpu_cpp_wrapper.py +++ b/test/inductor/test_cpu_cpp_wrapper.py @@ -9,7 +9,7 @@ from torch.testing._internal.common_device_type import ( get_desired_device_type_test_bases, ) -from torch.testing._internal.common_utils import IS_MACOS, slowTest, TEST_WITH_ROCM +from torch.testing._internal.common_utils import IS_MACOS, slowTest from torch.testing._internal.inductor_utils import HAS_CPU @@ -68,17 +68,7 @@ class DynamicShapesCppWrapperCpuTests(InductorTestCase): ("cpp_wrapper",), is_skip=True ), } -if TEST_WITH_ROCM: - test_failures_cpp_wrapper.update( - { - "test_linear_packed": test_torchinductor.TestFailure( - ("cpp_wrapper"), is_skip=True - ), - "test_linear_packed_dynamic_shapes": test_torchinductor.TestFailure( - ("cpp_wrapper"), is_skip=True - ), - } - ) + if config.abi_compatible: xfail_list = [ "test_conv2d_binary_inplace_fusion_failed_cpu", diff --git a/test/inductor/test_triton_kernels.py b/test/inductor/test_triton_kernels.py index 41b20188f635..accab8beae6b 100644 --- a/test/inductor/test_triton_kernels.py +++ b/test/inductor/test_triton_kernels.py @@ -586,7 +586,6 @@ def call_triton( self.assertEqual(int_result, resulti) @requires_cuda - @skipIfRocm def test_triton_kernel_constants(self): @triton.jit def mulC_kernel( From 8bf2c0a2030f56ab32b537fec79d4ac1a4f5f3a9 Mon Sep 17 00:00:00 2001 From: Aaron Gokaslan Date: Fri, 31 May 2024 17:01:47 +0000 Subject: [PATCH 176/706] [BE][Ez]: Update ruff to 0.4.6 (#127614) Update ruff linter to 0.4.6. Uneventful PR that fixes bugs and reduces false positives. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127614 Approved by: https://github.com/albanD --- .lintrunner.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index 1e0a2f37fcf4..eca2af96b761 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -2120,7 +2120,7 @@ init_command = [ 'python3', 'tools/linter/adapters/pip_init.py', '--dry-run={{DRYRUN}}', - 'ruff==0.4.5', + 'ruff==0.4.6', ] is_formatter = true From 1699edaabb4b6b94502c76c8240b68b7bb392b33 Mon Sep 17 00:00:00 2001 From: Iris Z <31293777+wz337@users.noreply.github.com> Date: Fri, 31 May 2024 17:06:36 +0000 Subject: [PATCH 177/706] [DeviceMesh] Adding nD slicing support back (#127465) Fixes #126530 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127465 Approved by: https://github.com/wconstab, https://github.com/wanchaol --- test/distributed/test_device_mesh.py | 46 ++++++++++-- torch/distributed/device_mesh.py | 100 ++++++++++++++++++--------- 2 files changed, 107 insertions(+), 39 deletions(-) diff --git a/test/distributed/test_device_mesh.py b/test/distributed/test_device_mesh.py index 8f70ee2f0b7d..03457de14b68 100644 --- a/test/distributed/test_device_mesh.py +++ b/test/distributed/test_device_mesh.py @@ -420,16 +420,16 @@ def world_size(self): @with_comms def test_raises_no_mesh_dim_found(self): - with self.assertRaisesRegex(KeyError, "No `mesh_dim_names` found."): + with self.assertRaisesRegex( + RuntimeError, "Cannot slice a DeviceMesh without mesh_dim_names!" + ): mesh = init_device_mesh(self.device_type, (2, 4)) child_mesh = mesh["DP"] @with_comms def test_raises_invalid_mesh_dim_name(self): - child_mesh_dim_name = "PP" - with self.assertRaisesRegex( - KeyError, f"Mesh dimension '{child_mesh_dim_name}' does not exist." - ): + child_mesh_dim_name = ("PP",) + with self.assertRaisesRegex(KeyError, "Invalid mesh_dim_name"): mesh_dim_names = ("DP", "TP") mesh = init_device_mesh( self.device_type, (2, 4), mesh_dim_names=mesh_dim_names @@ -437,7 +437,7 @@ def test_raises_invalid_mesh_dim_name(self): child_mesh = mesh[child_mesh_dim_name] @with_comms - def test_get_item(self): + def test_get_item_2d(self): mesh_shape = (2, 4) mesh_dim_names = ("DP", "TP") mesh_2d = init_device_mesh( @@ -467,9 +467,41 @@ def test_get_item_1d(self): dp_mesh = mesh["dp"] self.assertEqual(dp_mesh, mesh) - with self.assertRaisesRegex(RuntimeError, "Invalid mesh_dim_name"): + with self.assertRaisesRegex(KeyError, "Invalid mesh_dim_name"): dp_mesh = mesh["dim0"] + @with_comms + def test_get_item_3d(self): + mesh_shape = (2, 2, 2) + mesh_dim_names = ("Replicate", "Shard", "TP") + mesh_3d = init_device_mesh( + self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names + ) + + tp_group = [[0, 1], [2, 3], [4, 5], [6, 7]] + tp_group_idx = int(self.rank / 2) + self.assertEqual(mesh_3d["TP"].mesh.tolist(), tp_group[tp_group_idx]) + + shard_group = [[0, 2], [1, 3], [4, 6], [5, 7]] + shard_group_idx = self.rank % 2 + self.rank // 4 * 2 + self.assertEqual(mesh_3d["Shard"].mesh.tolist(), shard_group[shard_group_idx]) + + replicate_group = [[0, 4], [1, 5], [2, 6], [3, 7]] + replicate_group_idx = self.rank % 4 + self.assertEqual( + mesh_3d["Replicate"].mesh.tolist(), replicate_group[replicate_group_idx] + ) + + # We support both UX for nD slicing. + # mesh_3d[["Replicate", "Shard"]] or mesh_3d["Replicate", "Shard"] + hsdp_mesh_1 = mesh_3d[["Replicate", "Shard"]] + hsdp_mesh_2 = mesh_3d["Replicate", "Shard"] + hsdp_group = [[[0, 2], [4, 6]], [[1, 3], [5, 7]]] + hsdp_group_idx = self.rank % 2 + self.assertEqual(hsdp_mesh_1.mesh.tolist(), hsdp_group[hsdp_group_idx]) + self.assertEqual(hsdp_mesh_2.mesh.tolist(), hsdp_group[hsdp_group_idx]) + self.assertEqual(hsdp_mesh_1, hsdp_mesh_2) + @with_comms def test_cache_and_reuse_submesh_slice_result(self): mesh = init_device_mesh(self.device_type, (2, 4), mesh_dim_names=("dp", "tp")) diff --git a/torch/distributed/device_mesh.py b/torch/distributed/device_mesh.py index 57b8fa1cf564..a0e7b7acddeb 100644 --- a/torch/distributed/device_mesh.py +++ b/torch/distributed/device_mesh.py @@ -69,31 +69,46 @@ def get_current_mesh(self) -> "DeviceMesh": return self.mesh_stack[-1] def create_child_mesh( - self, device_mesh: "DeviceMesh", mesh_dim: int, mesh_dim_name: str + self, parent_mesh: "DeviceMesh", submesh_dim_names: Tuple[str] ) -> "DeviceMesh": - # swap the current dim to the last dim then reshape to flatten out other - # dims, so we can just extract the list of ranks which contains cur_rank. - cur_rank = device_mesh.get_rank() - pg_ranks_by_dim = device_mesh.mesh.swapdims(-1, mesh_dim).reshape( - -1, device_mesh.mesh.size(mesh_dim) - ) + # submesh_dims are the mesh dimension of the submesh in the parent mesh. + submesh_dims = [ + not_none(parent_mesh.mesh_dim_names).index(mesh_dim_name) + for mesh_dim_name in submesh_dim_names + ] + submesh_dim_sizes = [ + parent_mesh.mesh.size(mesh_dim) for mesh_dim in submesh_dims + ] - for mesh_1d in pg_ranks_by_dim: - sub_mesh = DeviceMesh( - device_mesh.device_type, - mesh_1d, - mesh_dim_names=(mesh_dim_name,), + mesh_dims_remained = list(range(parent_mesh.mesh.ndim)) + for submesh_dim in submesh_dims: + mesh_dims_remained.remove(submesh_dim) + + # pg_ranks_by_dim is the size of [number of local ranks of the outermost submesh dimension, *sub_mesh_dims] + # This means on each local rank of the outermost slice mesh dim, we have a tensor of submesh size with + # the pg ranks of the submesh. From this, we can extract the submesh mesh tensor contains the current rank. + pg_ranks_by_dim = parent_mesh.mesh.permute( + *mesh_dims_remained, *submesh_dims + ).reshape(-1, *submesh_dim_sizes) + + cur_rank = parent_mesh.get_rank() + for mesh_nd in pg_ranks_by_dim: + submesh = DeviceMesh( + parent_mesh.device_type, + mesh_nd, + mesh_dim_names=submesh_dim_names, _init_backend=False, ) - if cur_rank in mesh_1d: - res_sub_mesh = sub_mesh + if cur_rank in mesh_nd: + res_submesh = submesh + + res_submesh._parent_mesh = parent_mesh # type: ignore[possibly-undefined] + res_submesh._dim_group_infos = [ + parent_mesh._dim_group_infos[mesh_dim] for mesh_dim in submesh_dims # type: ignore[possibly-undefined] + ] + self.child_to_parent_mapping[res_submesh] = parent_mesh - res_sub_mesh._dim_group_infos = [device_mesh._dim_group_infos[mesh_dim]] # type: ignore[possibly-undefined] - res_sub_mesh._parent_mesh = device_mesh - # Assign the current DeviceMesh as the parent of the child DeviceMesh. - # We need to update the mappings after the child mesh hash update. - self.child_to_parent_mapping[res_sub_mesh] = device_mesh - return res_sub_mesh + return res_submesh def get_parent_mesh(self, device_mesh: "DeviceMesh") -> Optional["DeviceMesh"]: return self.child_to_parent_mapping.get(device_mesh, None) @@ -367,14 +382,14 @@ def __eq__(self, other: object) -> bool: and self._thread_id == other._thread_id ) - def __getitem__(self, mesh_dim_name: str) -> "DeviceMesh": + def __getitem__(self, mesh_dim_names: Union[str, Tuple[str]]) -> "DeviceMesh": """ Slice the current DeviceMesh based on the mesh_dim_name given to create a child DeviceMesh. Args: - mesh_dim_name (str): the name of the mesh dimension of the parent DeviceMesh - to create a child DeviceMesh for. + mesh_dim_name (Union[str, Tuple[str]]): the name or the tuple of names of the + mesh dimension of the parent DeviceMesh to create the child DeviceMesh for. Returns: A :class:`DeviceMesh` object @@ -395,16 +410,37 @@ def __getitem__(self, mesh_dim_name: str) -> "DeviceMesh": >>> # of cross-host(dim 0), and within-host (dim 1). >>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]]) """ - if self.mesh.ndim == 1: - if self.mesh_dim_names and mesh_dim_name == self.mesh_dim_names[0]: - return self - else: - raise RuntimeError( - f"Invalid mesh_dim_name {mesh_dim_name} specified." - ) + if not self.mesh_dim_names: + raise RuntimeError("Cannot slice a DeviceMesh without mesh_dim_names!") + + mesh_dim_names = ( + (mesh_dim_names,) if isinstance(mesh_dim_names, str) else mesh_dim_names + ) + + error_msg = ( + f"Invalid mesh_dim_name {mesh_dim_names} specified. " + f"Valid mesh_dim_names should be a contiguous subsequence of {self.mesh_dim_names}." + ) + + if mesh_dim_names == self.mesh_dim_names: + return self + elif len(mesh_dim_names) > len(self.mesh_dim_names) or not all( + mesh_dim_name in self.mesh_dim_names for mesh_dim_name in mesh_dim_names + ): + raise KeyError(error_msg) + # Check if the user-provided slicing is a valid contiguous subsequence of the mesh_dim_names + # of the current DeviceMesh. + else: + outermost_dim_name = mesh_dim_names[0] + outermost_dim_idx = self.mesh_dim_names.index(outermost_dim_name) + for i, j in zip( + mesh_dim_names, + self.mesh_dim_names[outermost_dim_idx : len(mesh_dim_names)], + ): + if i != j: + raise KeyError(error_msg) - mesh_dim = _mesh_resources.get_mesh_dim_by_name(self, mesh_dim_name) - submesh = _mesh_resources.create_child_mesh(self, mesh_dim, mesh_dim_name) + submesh = _mesh_resources.create_child_mesh(self, mesh_dim_names) return submesh def get_group( From 8d7393cb5e0c2c1928b546631135b282ee263e20 Mon Sep 17 00:00:00 2001 From: "Wang, Eikan" Date: Mon, 27 May 2024 01:50:52 +0000 Subject: [PATCH 178/706] Update triton-xpu commit pin merge rules for XPU (#127203) Add the ".ci/docker/ci_commit_pins/triton-xpu.txt" to the XPU merge rules. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127203 Approved by: https://github.com/atalman --- .github/merge_rules.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/merge_rules.yaml b/.github/merge_rules.yaml index db0ec3c51aa7..d69fff16f305 100644 --- a/.github/merge_rules.yaml +++ b/.github/merge_rules.yaml @@ -245,6 +245,7 @@ - torch/xpu/** - test/xpu/** - third_party/xpu.txt + - .ci/docker/ci_commit_pins/triton-xpu.txt approved_by: - EikanWang - jgong5 From 3e66052e168b6e0558821e404df663083edc3b1e Mon Sep 17 00:00:00 2001 From: cyy Date: Fri, 31 May 2024 17:29:06 +0000 Subject: [PATCH 179/706] Improve python3 discovery code in CMake (#127600) The improvement is based on my comments in #124613 and it also fixes the current linux-s390x-binary-manywheel CI failures. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127600 Approved by: https://github.com/Skylion007 --- cmake/Dependencies.cmake | 65 +++++++++++++++++----------------------- 1 file changed, 27 insertions(+), 38 deletions(-) diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 6fb3d967301f..1e4c4262fba1 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -834,53 +834,42 @@ else() endif() include_directories(SYSTEM ${EIGEN3_INCLUDE_DIR}) -# ---[ Python + Numpy -if(BUILD_PYTHON) - # If not given a Python installation, then use the current active Python - if(NOT Python_EXECUTABLE) - execute_process( - COMMAND "which" "python3" RESULT_VARIABLE _exitcode OUTPUT_VARIABLE _py_exe) - if(${_exitcode} EQUAL 0) - if(NOT MSVC) - string(STRIP ${_py_exe} Python_EXECUTABLE) - endif() - message(STATUS "Setting Python to ${Python_EXECUTABLE}") - endif() - endif() - # Check that Python works - set(PYTHON_VERSION) - if(DEFINED Python_EXECUTABLE) - execute_process( - COMMAND "${Python_EXECUTABLE}" "--version" - RESULT_VARIABLE _exitcode OUTPUT_VARIABLE PYTHON_VERSION) - if(NOT _exitcode EQUAL 0) - message(FATAL_ERROR "The Python executable ${Python_EXECUTABLE} cannot be run. Make sure that it is an absolute path.") - endif() - if(PYTHON_VERSION) - string(REGEX MATCH "([0-9]+)\\.([0-9]+)" PYTHON_VERSION ${PYTHON_VERSION}) +# ---[ Python Interpreter +# If not given a Python installation, then use the current active Python +if(NOT Python_EXECUTABLE) + execute_process( + COMMAND "which" "python3" RESULT_VARIABLE _exitcode OUTPUT_VARIABLE _py_exe) + if(${_exitcode} EQUAL 0) + if(NOT MSVC) + string(STRIP ${_py_exe} Python_EXECUTABLE) endif() + message(STATUS "Setting Python to ${Python_EXECUTABLE}") endif() +endif() - # These should fill in the rest of the variables, like versions, but resepct - # the variables we set above +if(BUILD_PYTHON) + set(PYTHON_COMPONENTS Development) if(USE_NUMPY) - find_package(Python COMPONENTS Interpreter Development NumPy) - else() - find_package(Python COMPONENTS Interpreter Development) + list(APPEND PYTHON_COMPONENTS NumPy) endif() + find_package(Python COMPONENTS Interpreter OPTIONAL_COMPONENTS ${PYTHON_COMPONENTS}) +else() + find_package(Python COMPONENTS Interpreter) +endif() - if(NOT Python_Development_FOUND) - message(FATAL_ERROR - "Python development libraries could not be found.") - endif() +if(NOT Python_Interpreter_FOUND) + message(FATAL_ERROR "Python3 could not be found.") +endif() - if(${Python_VERSION} VERSION_LESS 3.8) - message(FATAL_ERROR - "Found Python libraries version ${Python_VERSION}. Python < 3.8 is no longer supported by PyTorch.") - endif() +if(${Python_VERSION} VERSION_LESS 3.8) + message(FATAL_ERROR + "Found Python libraries version ${Python_VERSION}. Python < 3.8 is no longer supported by PyTorch.") +endif() - if(Python_Interpreter_FOUND) +# ---[ Python + Numpy +if(BUILD_PYTHON) + if(Python_Development_FOUND) if(USE_NUMPY) if(NOT Python_NumPy_FOUND) message(WARNING "NumPy could not be found. Not building with NumPy. Suppress this warning with -DUSE_NUMPY=OFF") From b2f5fd8efbcceecf079ac8aa67ec30693c3f2b83 Mon Sep 17 00:00:00 2001 From: angelayi Date: Fri, 31 May 2024 17:46:13 +0000 Subject: [PATCH 180/706] [ts_converter] Basic support for prim::If conversion (#127336) Script module: ``` graph(%self : __torch__.M, %x.1 : Tensor, %y.1 : Tensor): %11 : int = prim::Constant[value=1]() %5 : bool = aten::Bool(%x.1) # /data/users/angelayi/pytorch2/test/export/test_converter.py:27:19 %21 : Tensor = prim::If(%5) # /data/users/angelayi/pytorch2/test/export/test_converter.py:27:16 block0(): %8 : Tensor = aten::mul(%y.1, %y.1) # /data/users/angelayi/pytorch2/test/export/test_converter.py:28:27 -> (%8) block1(): %12 : Tensor = aten::add(%y.1, %y.1, %11) # /data/users/angelayi/pytorch2/test/export/test_converter.py:30:27 -> (%12) return (%21) ``` ExportedProgram: ``` ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, x_1: "b8[]", y_1: "i64[]"): # File: .23:9 in forward, code: cond = torch.ops.higher_order.cond(l_args_0_, cond_true_0, cond_false_0, [l_args_3_0_]); l_args_0_ = cond_true_0 = cond_false_0 = l_args_3_0_ = None true_graph_0 = self.true_graph_0 false_graph_0 = self.false_graph_0 conditional = torch.ops.higher_order.cond(x_1, true_graph_0, false_graph_0, [y_1]); x_1 = true_graph_0 = false_graph_0 = y_1 = None return (conditional,) class (torch.nn.Module): def forward(self, y_1: "i64[]"): # File: .20:6 in forward, code: mul_tensor = torch.ops.aten.mul.Tensor(l_args_3_0__1, l_args_3_0__1); l_args_3_0__1 = None mul: "i64[]" = torch.ops.aten.mul.Tensor(y_1, y_1); y_1 = None return mul class (torch.nn.Module): def forward(self, y_1: "i64[]"): # File: .21:6 in forward, code: add_tensor = torch.ops.aten.add.Tensor(l_args_3_0__1, l_args_3_0__1, alpha = 1); l_args_3_0__1 = None add: "i64[]" = torch.ops.aten.add.Tensor(y_1, y_1); y_1 = None return add ``` This PR also adds support for TupleIndex and incorporates some changes from https://github.com/pytorch/pytorch/pull/127341 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127336 Approved by: https://github.com/BoyuanFeng --- test/export/test_converter.py | 46 ++++++++++- torch/_export/converter.py | 147 +++++++++++++++++++++++++++++----- 2 files changed, 172 insertions(+), 21 deletions(-) diff --git a/test/export/test_converter.py b/test/export/test_converter.py index 64cea8cf8ac9..9ab6161fe4d7 100644 --- a/test/export/test_converter.py +++ b/test/export/test_converter.py @@ -8,6 +8,7 @@ from torch._dynamo.test_case import TestCase from torch._export.converter import TS2EPConverter +from torch.export import ExportedProgram from torch.testing._internal.common_utils import run_tests @@ -15,7 +16,7 @@ class TestConverter(TestCase): - def _check_equal_ts_ep_converter(self, mod, inp): + def _check_equal_ts_ep_converter(self, mod, inp) -> ExportedProgram: ts_model = torch.jit.script(mod) ep = TS2EPConverter(ts_model, inp).convert() ep_out, _ = pytree.tree_flatten(ep.module()(*inp)) @@ -24,6 +25,7 @@ def _check_equal_ts_ep_converter(self, mod, inp): for ep_t, orig_t in zip(ep_out, orig_out): self.assertEqual(ep_t.shape, orig_t.shape) self.assertTrue(torch.allclose(ep_t, orig_t)) + return ep def test_ts2ep_converter_basic(self): class MSingle(torch.nn.Module): @@ -108,6 +110,48 @@ def forward(self, x): inp = (torch.randint(high=128, size=(3, 4), dtype=dtype),) self._check_equal_ts_ep_converter(Module(), inp) + def test_convert_if_basic(self): + class M(torch.nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor): + if x: + return y * y + else: + return y + y + + inp = (torch.tensor(True), torch.tensor(4)) + ep = self._check_equal_ts_ep_converter(M(), inp) + + torch.testing.assert_close( + ep.module()(torch.tensor(False), torch.tensor(4)), + M()(torch.tensor(False), torch.tensor(4)), + ) + + def test_convert_if_multiple_out(self): + class M(torch.nn.Module): + def true_fn(self, y, z): + return (z * z, z + z) + + def false_fn(self, y, z): + return (y * y * y, y + y) + + def forward(self, x: torch.Tensor, y: torch.Tensor): + z = y * y + + if x: + res = self.true_fn(y, z) + else: + res = self.false_fn(y, z) + + return res[0] + res[1] + + inp = (torch.tensor(True), torch.tensor(4)) + ep = self._check_equal_ts_ep_converter(M(), inp) + + torch.testing.assert_close( + ep.module()(torch.tensor(False), torch.tensor(4)), + M()(torch.tensor(False), torch.tensor(4)), + ) + if __name__ == "__main__": run_tests() diff --git a/torch/_export/converter.py b/torch/_export/converter.py index 7e6812985bad..fb10ac6a1e66 100644 --- a/torch/_export/converter.py +++ b/torch/_export/converter.py @@ -1,3 +1,4 @@ +import operator from typing import Any, Dict, List, Optional, Set, Tuple, Union import torch @@ -54,22 +55,16 @@ def get_op_overload(node: torch._C.Node): return op_overload -class TS2EPConverter: - # TorchScript model to ExportedProgram converter +class TS2FXGraphConverter: def __init__( self, - ts_model, - sample_args: Tuple[Any, ...], - sample_kwargs: Optional[Dict[str, Any]] = None, + ts_graph: Union[torch._C.Graph, torch._C.Block], + param_names: Set[str], + buffer_names: Set[str], ): - self.ts_model = ts_model - self.ts_graph, self.params, _, _ = _create_jit_graph(ts_model, sample_args) - - self.sample_args = sample_args - self.sample_kwargs = sample_kwargs - - self.param_names: Set[str] = {name for name, _ in ts_model.named_parameters()} - self.buffer_names: Set[str] = {name for name, _ in ts_model.named_buffers()} + self.ts_graph = ts_graph + self.param_names = param_names + self.buffer_names = buffer_names self.fx_graph: torch.fx.Graph = torch.fx.Graph() self.input_specs: List[InputSpec] = [] @@ -82,6 +77,13 @@ def __init__( self.attribute_map: Dict[str, Any] = {} self.tensor_constants: Dict[str, torch.Tensor] = {} + self.subgraphs: Dict[str, torch.fx.GraphModule] = {} + + def add_subgraph(self, subgraph) -> str: + name = f"subgraph_{len(self.subgraphs)}" + self.subgraphs[name] = subgraph + return name + def get_args_kwargs(self, node: torch._C.Node, schema): args = [] kwargs = {} @@ -110,7 +112,7 @@ def get_fx_value(self, value: torch._C.Value): else: raise ValueError(f"Input {value_name} not found") - def convert(self) -> ExportedProgram: + def convert(self) -> torch.fx.GraphModule: self.convert_graph_inputs() for node in self.ts_graph.nodes(): @@ -118,14 +120,13 @@ def convert(self) -> ExportedProgram: self.convert_graph_outputs() - gm = torch.fx.GraphModule({}, self.fx_graph) + gm = torch.fx.GraphModule(self.subgraphs, self.fx_graph) inplace_optimize_sym_size_div(gm) gm.graph.lint() - ep = self.retrace_as_exported_program(gm) - return ep + return gm def convert_graph_inputs(self): for graph_input in self.ts_graph.inputs(): @@ -234,7 +235,10 @@ def get_attr(name: str): ) def convert_aten_op(self, node: torch._C.Node): - target = get_op_overload(node) + try: + target = get_op_overload(node) + except Exception as e: + raise RuntimeError(f"Unsupported node {node.kind()}") from e if target is torch.ops.aten.size.int: target = torch.ops.aten.sym_size.int @@ -280,6 +284,13 @@ def convert_prim_DictConstruct(self, node: torch._C.Node): output_name = node.output().debugName() self.name_to_node[output_name] = output_dict + def convert_prim_TupleIndex(self, node: torch._C.Node): + args = tuple(self.get_fx_value(input) for input in node.inputs()) + getitem_node = self.fx_graph.call_function(operator.getitem, args) + + output_name = node.output().debugName() + self.name_to_node[output_name] = getitem_node + def convert_aten_Int(self, node: torch._C.Node): # converts aten::Int as aten._to_copy + aten::_local_scalar_dense target = torch.ops.aten._to_copy.default @@ -352,6 +363,70 @@ def convert_aten_div(self, node: torch._C.Node): self.convert_aten_op(node) + def convert_prim_if(self, node: torch._C.Node): + inputs = list(node.inputs()) + assert len(inputs) == 1 + predicate = self.get_fx_value(inputs[0]) + + # Get union of inputs to blocks + arguments = set() + for block in node.blocks(): + block_args = set() + + # TODO: block.inputs(), not sure what theyre used for + + for block_node in block.nodes(): + for block_node_in in block_node.inputs(): + if block_node_in.debugName() in self.name_to_node: + block_args.add(block_node_in.debugName()) + + arguments.update(block_args) + + arguments = list(arguments) + + # Convert blocks to subgraphs + subgraph_nodes = [] + for block in node.blocks(): + subgraph_converter = TS2FXGraphConverter(block, set(), set()) + subgraph_converter.constant_map = self.constant_map + + for block_arg in arguments: + normalized_block_arg_name = normalize_name(block_arg) + placeholder_node = subgraph_converter.fx_graph.placeholder( + normalized_block_arg_name + ) + subgraph_converter.name_to_node[block_arg] = placeholder_node + + subgraph = subgraph_converter.convert() + subgraph_name = self.add_subgraph(subgraph) + subgraph_nodes.append(self.fx_graph.get_attr(subgraph_name)) + + assert len(subgraph_nodes) == 2 + + fx_block_args = [self.name_to_node[arg_name] for arg_name in arguments] + args = ( + predicate, + subgraph_nodes[0], + subgraph_nodes[1], + tuple(fx_block_args), + ) + + cond_node = self.fx_graph.call_function(torch.cond, args, {}) + + output_name = node.output().debugName() + self.name_to_node[output_name] = cond_node + + def convert_as_noop(self, node: torch._C.Node): + # Converts the node as a no-op by mapping its output node as arg[0] + + target = get_op_overload(node) + schema = target._schema + + args, kwargs = self.get_args_kwargs(node, schema) + + output_name = node.output().debugName() + self.name_to_node[output_name] = args[0] + def convert_node(self, node: torch._C.Node): node_kind = node.kind() if node_kind == "prim::CreateObject": @@ -371,12 +446,18 @@ def convert_node(self, node: torch._C.Node): self.convert_prim_dtype(node) elif node_kind == "prim::DictConstruct": self.convert_prim_DictConstruct(node) + elif node_kind == "prim::TupleIndex": + self.convert_prim_TupleIndex(node) # elif node_kind == "aten::Int": # convert_aten_Int(node) elif node_kind == "aten::_convolution": self.convert_aten__convolution(node) elif node_kind == "aten::div": self.convert_aten_div(node) + elif node_kind == "prim::If": + self.convert_prim_if(node) + elif node_kind == "aten::Bool": + self.convert_as_noop(node) elif node_kind.startswith("aten::"): self.convert_aten_op(node) else: @@ -413,9 +494,35 @@ def convert_graph_outputs(self): args[0] ) # Get rid of an extra list wrapped around final output. - def retrace_as_exported_program(self, gm: torch.fx.GraphModule): + +class TS2EPConverter: + # TorchScript model to ExportedProgram converter + def __init__( + self, + ts_model, + sample_args: Tuple[Any, ...], + sample_kwargs: Optional[Dict[str, Any]] = None, + ): + self.ts_model = ts_model + self.ts_graph, self.params, _, _ = _create_jit_graph(ts_model, sample_args) + + self.sample_args = sample_args + self.sample_kwargs = sample_kwargs + + self.param_names: Set[str] = {name for name, _ in ts_model.named_parameters()} + self.buffer_names: Set[str] = {name for name, _ in ts_model.named_buffers()} + + def convert(self) -> ExportedProgram: + graph_converter = TS2FXGraphConverter( + self.ts_graph, self.param_names, self.buffer_names + ) + gm = graph_converter.convert() + ep = self.retrace_as_exported_program(gm, graph_converter.tensor_constants) + return ep + + def retrace_as_exported_program(self, gm: torch.fx.GraphModule, tensor_constants): # TODO: adjust input orders to match GraphSignature convention - inputs = [*self.sample_args, *self.params, *self.tensor_constants.values()] + inputs = [*self.sample_args, *self.params, *tensor_constants.values()] ep = torch.export._trace._export( gm, From 4a0d96e496d956ae56572c66df613f44a669035f Mon Sep 17 00:00:00 2001 From: Svetlana Karslioglu Date: Fri, 31 May 2024 17:57:04 +0000 Subject: [PATCH 181/706] Add a GH action to autolabel docathon PRs (#127569) To ease oncall burden for the docathon PR reviewers and ensure all PRs are correctly labeled, adding this GH action that will look for the issue number in the PR and if that issue has a docathon-h1-2024 label, then it would propagate the labels from the issues into the PR. It should not conflict with the existing labelers because we use ``pull_request.add_to_labels`` - credit @kit1980. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127569 Approved by: https://github.com/kit1980 --- .github/scripts/docathon-label-sync.py | 52 +++++++++++++++++++++++ .github/workflows/docathon-sync-label.yml | 30 +++++++++++++ 2 files changed, 82 insertions(+) create mode 100644 .github/scripts/docathon-label-sync.py create mode 100644 .github/workflows/docathon-sync-label.yml diff --git a/.github/scripts/docathon-label-sync.py b/.github/scripts/docathon-label-sync.py new file mode 100644 index 000000000000..a10c3c3f886c --- /dev/null +++ b/.github/scripts/docathon-label-sync.py @@ -0,0 +1,52 @@ +import os +import re +import sys + +from github import Github + + +def main() -> None: + token = os.environ.get("GITHUB_TOKEN") + + repo_owner = "pytorch" + repo_name = "pytorch" + pull_request_number = int(sys.argv[1]) + + g = Github(token) + repo = g.get_repo(f"{repo_owner}/{repo_name}") + pull_request = repo.get_pull(pull_request_number) + pull_request_body = pull_request.body + # PR without description + if pull_request_body is None: + return + + # get issue number from the PR body + if not re.search(r"#\d{1,6}", pull_request_body): + print("The pull request does not mention an issue.") + return + issue_number = int(re.findall(r"#(\d{1,6})", pull_request_body)[0]) + issue = repo.get_issue(issue_number) + issue_labels = issue.labels + docathon_label_present = any( + label.name == "docathon-h1-2024" for label in issue_labels + ) + + # if the issue has a docathon label, add all labels from the issue to the PR. + if not docathon_label_present: + print("The 'docathon-h1-2024' label is not present in the issue.") + return + pull_request_labels = pull_request.get_labels() + pull_request_label_names = [label.name for label in pull_request_labels] + issue_label_names = [label.name for label in issue_labels] + labels_to_add = [ + label for label in issue_label_names if label not in pull_request_label_names + ] + if not labels_to_add: + print("The pull request already has the same labels.") + return + pull_request.add_to_labels(*labels_to_add) + print("Labels added to the pull request!") + + +if __name__ == "__main__": + main() diff --git a/.github/workflows/docathon-sync-label.yml b/.github/workflows/docathon-sync-label.yml new file mode 100644 index 000000000000..7cb1f608722d --- /dev/null +++ b/.github/workflows/docathon-sync-label.yml @@ -0,0 +1,30 @@ +name: Docathon Labels Sync + +on: + pull_request_target: + types: [opened, synchronize, edited] + branches: [main] + +jobs: + check-labels: + runs-on: ubuntu-latest + permissions: + issues: write + pull-requests: write + steps: + - name: Check out the repo + uses: actions/checkout@v2 + with: + fetch-depth: 1 + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: 3.x + - name: Install dependencies + run: | + pip install requests==2.32.3 + pip install PyGithub==2.3.0 + - name: Run Python script + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: python ./.github/scripts/docathon-label-sync.py ${{ github.event.pull_request.number }} From 0be06b08fc30e0fc90d954ce53b74863723fdd07 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 31 May 2024 18:50:49 +0000 Subject: [PATCH 182/706] [GPT-fast benchmark] Merge GPT-fast and micro benchmark output as one CSV file (#127586) Consolidate GPT-fast models benchmark with micro-benchmark, and save output as one CSV file with the same format as https://github.com/pytorch/pytorch/pull/126754#issue-2307296847. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127586 Approved by: https://github.com/Chillee --- benchmarks/gpt_fast/benchmark.py | 321 +++++-------------------- benchmarks/gpt_fast/generate.py | 308 ++++++++++++++++++++++++ benchmarks/gpt_fast/micro_benchmark.py | 103 -------- 3 files changed, 367 insertions(+), 365 deletions(-) create mode 100644 benchmarks/gpt_fast/generate.py delete mode 100644 benchmarks/gpt_fast/micro_benchmark.py diff --git a/benchmarks/gpt_fast/benchmark.py b/benchmarks/gpt_fast/benchmark.py index 083c98e4a92b..6e335ee31292 100644 --- a/benchmarks/gpt_fast/benchmark.py +++ b/benchmarks/gpt_fast/benchmark.py @@ -1,257 +1,74 @@ import argparse import csv import dataclasses -import itertools import os import time -from typing import Optional, Tuple -from mixtral_moe_model import Transformer as MixtralMoE -from mixtral_moe_quantize import ( - WeightOnlyInt8QuantHandler as MixtralMoEWeightOnlyInt8QuantHandler, -) -from model import Transformer as LLaMA -from quantize import WeightOnlyInt8QuantHandler as LLaMAWeightOnlyInt8QuantHandler +from generate import run_llama2_7b_bf16, run_llama2_7b_int8, run_mixtral_8x7b_int8 import torch -import torch._inductor.config - -torch._inductor.config.coordinate_descent_tuning = True -torch._inductor.config.triton.unique_kernel_names = True -torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future -torch._inductor.config.assert_indirect_indexing = False +import torch.nn as nn @dataclasses.dataclass class Experiment: name: str - module: type - mode: Optional[str] - quantizer: type - token_per_sec: float - memory_bandwidth: float - - -# token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB. -all_experiments = { - "llama-7b-fp16": Experiment( - "Llama-2-7b-chat-hf", - LLaMA, - "bfloat16", - LLaMAWeightOnlyInt8QuantHandler, - 94, - 1253, - ), - "llama-7b-int8": Experiment( - "Llama-2-7b-chat-hf", - LLaMA, - "int8", - LLaMAWeightOnlyInt8QuantHandler, - 144, - 957, - ), - "mixtral-int8": Experiment( # We reduced the original number of layers from 32 to 16 to adapt CI memory limitation. - "Mixtral-8x7B-v0.1", - MixtralMoE, - "int8", - MixtralMoEWeightOnlyInt8QuantHandler, - 175, - 4129, - ), -} - -DEFAULT_OUTPUT_FILE = "gpt_fast_benchmark.csv" - - -def device_sync(device): - if "cuda" in device: - torch.cuda.synchronize(device) - elif "cpu" in device: - pass - else: - print(f"device={device} is not yet suppported") - - -def multinomial_sample_one_no_sync( - probs_sort, -): # Does multinomial sampling without a cuda synchronization - q = torch.empty_like(probs_sort).exponential_(1) - return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) - - -def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None): - logits = logits / max(temperature, 1e-5) - - if top_k is not None: - v, _ = torch.topk(logits, min(top_k, logits.size(-1))) - pivot = v.select(-1, -1).unsqueeze(-1) - logits = torch.where(logits < pivot, -float("Inf"), logits) - probs = torch.nn.functional.softmax(logits, dim=-1) - return probs - - -def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None): - probs = logits_to_probs(logits[0, -1], temperature, top_k) - idx_next = multinomial_sample_one_no_sync(probs) - return idx_next, probs - - -@torch.compile(fullgraph=True) -def prefill( - model: torch.nn.Module, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs -) -> torch.Tensor: - # input_pos: [B, S] - logits = model(x, input_pos) - return sample(logits, **sampling_kwargs)[0] - - -@torch.compile(fullgraph=True, mode="reduce-overhead") -def decode_one_token( - model: torch.nn.Module, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs -) -> Tuple[torch.Tensor, torch.Tensor]: - # input_pos: [B, 1] - assert input_pos.shape[-1] == 1 - logits = model(x, input_pos) - return sample(logits, **sampling_kwargs) - - -def decode_n_tokens( - model: torch.nn.Module, - cur_token: torch.Tensor, - input_pos: torch.Tensor, - num_new_tokens: int, - **sampling_kwargs, -): - new_tokens, new_probs = [], [] - for i in range(num_new_tokens): - with torch.nn.attention.sdpa_kernel( - torch.nn.attention.SDPBackend.MATH - ): # Actually better for Inductor to codegen attention here - next_token, next_prob = decode_one_token( - model, cur_token, input_pos, **sampling_kwargs - ) - input_pos += 1 - new_tokens.append(next_token.clone()) - new_probs.append(next_prob.clone()) - cur_token = next_token.view(1, -1) - - return new_tokens, new_probs + metric: str + target: float + actual: float -@torch.no_grad() -def generate( - model: torch.nn.Module, prompt: torch.Tensor, max_new_tokens: int, **sampling_kwargs -) -> torch.Tensor: - device, dtype = prompt.device, prompt.dtype - T = prompt.size(0) - T_new = T + max_new_tokens - max_seq_length = min(T_new, model.config.block_size) - - with torch.device(device): - model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) - - # create an empty tensor of the expected final shape and fill in the current tokens - empty = torch.empty(T_new, dtype=dtype, device=device) - empty[:T] = prompt - seq = empty - input_pos = torch.arange(0, T, device=device) - - next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs) - seq[T] = next_token +def do_inference(mod, x, num_samples: int = 5): + total_time = 0 + start = -1 - input_pos = torch.tensor([T], device=device, dtype=torch.int) + for i in range(start, num_samples): + torch.cuda.synchronize("cuda") - generated_tokens, _ = decode_n_tokens( - model, next_token.view(1, -1), input_pos, max_new_tokens - 1, **sampling_kwargs - ) - seq[T + 1 :] = torch.cat(generated_tokens) - return seq + t0 = time.perf_counter() + mod(x) + if i == -1: + print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") + continue -def _load_model(x: Experiment, device="cuda", precision=torch.bfloat16): - with torch.device("meta"): - model = x.module.from_name(x.name) - model = model.to(dtype=precision) + torch.cuda.synchronize("cuda") + total_time += time.perf_counter() - t0 - if x.mode == "int8": - print("Using int8 weight-only quantization!") - model = x.quantizer(model).convert_for_runtime() + total_time = total_time / num_samples - state_dict = model.state_dict() - for k, v in state_dict.items(): - state_dict[k] = torch.nn.Parameter( - torch.randn(v.shape, device=device).to(dtype=v.dtype), - requires_grad=v.requires_grad, - ) - model.load_state_dict(state_dict, assign=True) - return model.eval() + return total_time -def _get_model_size(model): - model_size = 0 - for name, child in model.named_children(): - if not isinstance(child, torch.nn.Embedding): - model_size += sum( +def run_multi_layer_norm(): + class MultiLayerNorm(nn.Module): + def __init__(self, num_layers, normalized_shape, eps=1e-5, bias=True): + super().__init__() + self.num_layers = num_layers + self.norm_layers = nn.ModuleList( [ - p.numel() * p.dtype.itemsize - for p in itertools.chain(child.parameters(), child.buffers()) + nn.LayerNorm(normalized_shape, eps=eps, bias=bias) + for _ in range(num_layers) ] ) - return model_size - - -def run_experiment( - x: Experiment, - num_samples: int = 5, - max_new_tokens: int = 200, - top_k: int = 200, - temperature: float = 0.8, -) -> None: - device = "cuda" - print(f"Loading model {x.name}") - t0 = time.time() - model = _load_model(x) - device_sync(device=device) # MKG - print(f"Time to load model: {time.time() - t0:.02f} seconds") - - prompt = torch.tensor( - [1, 15043, 29892, 590, 1024, 338], device=device, dtype=torch.int32 - ) - prompt_length = prompt.size(0) - torch.manual_seed(1234) - model_size = _get_model_size(model) + def forward(self, x): + for layer_norm in self.norm_layers: + x = layer_norm(x) + return x - aggregate_metrics = {"tokens_per_sec": [], "memory_bandwidth": []} - start = -1 + mod = MultiLayerNorm(num_layers=8, normalized_shape=4096).to("cuda") + mod = torch.compile(mod) + input = torch.randn([512, 1024, 4096], dtype=torch.bfloat16, device="cuda") + inference_time = do_inference(mod, input) - for i in range(start, num_samples): - device_sync(device=device) # MKG + memory_bandwidth = input.numel() * input.dtype.itemsize / inference_time / 1e9 - t0 = time.perf_counter() - y = generate( - model, prompt, max_new_tokens, temperature=temperature, top_k=top_k + return [ + Experiment( + "multi_layer_norm", "memory_bandwidth(GB/s)", 92, f"{memory_bandwidth:.02f}" ) - - if i == -1: - print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") - continue - - device_sync(device=device) # MKG - t = time.perf_counter() - t0 - tokens_generated = y.size(0) - prompt_length - tokens_sec = tokens_generated / t - aggregate_metrics["tokens_per_sec"].append(tokens_sec) - aggregate_metrics["memory_bandwidth"].append(model_size * tokens_sec / 1e9) - - token_per_sec = torch.mean(torch.tensor(aggregate_metrics["tokens_per_sec"])).item() - memory_bandwidth = torch.mean( - torch.tensor(aggregate_metrics["memory_bandwidth"]) - ).item() - print(f"Average tokens/sec: {token_per_sec:.2f} tokens/sec") - print(f"Average bandwidth achieved: {memory_bandwidth:.02f} GB/s") - print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") - return token_per_sec, memory_bandwidth + ] def output_csv(output_file, headers, row): @@ -275,41 +92,27 @@ def output_csv(output_file, headers, row): writer.writerow(list(line) + ["0"] * (len(headers) - len(line))) -def main(experiments=None, output_file=DEFAULT_OUTPUT_FILE): +DEFAULT_OUTPUT_FILE = "gpt_fast_benchmark.csv" + +all_experiments = { + # A list of GPT models: LlaMa, Mixtral, etc. + run_llama2_7b_bf16, + run_llama2_7b_int8, + run_mixtral_8x7b_int8, + # A list of micro-benchmarks. + run_multi_layer_norm, +} + + +def main(output_file=DEFAULT_OUTPUT_FILE): results = [] - if experiments is None: - experiments = all_experiments - else: - experiments = {k: v for k, v in all_experiments.items() if k in experiments} - - for x in experiments.values(): - actual_token_per_sec, actual_memory_bandwidth = run_experiment(x) - token_per_sec_pct = f"{actual_token_per_sec / x.token_per_sec * 100:.2f}%" - bandwidth_pct = f"{actual_memory_bandwidth / x.memory_bandwidth * 100:.2f}%" - results.append( - ( - x.name, - x.mode, - x.token_per_sec, - f"{actual_token_per_sec:.2f}", - token_per_sec_pct, - x.memory_bandwidth, - f"{actual_memory_bandwidth:.2f}", - bandwidth_pct, - ) - ) + for func in all_experiments: + lst = func() + for x in lst: + results.append(dataclasses.astuple(x)) - headers = [ - "name", - "mode", - "token_per_sec[target]", - "token_per_sec[actual]", - "token_per_sec[pct]", - "memory_bandwidth[target]", - "memory_bandwidth[actual]", - "memory_bandwidth[pct]", - ] + headers = [field.name for field in dataclasses.fields(Experiment)] for row in results: output_csv(output_file, headers, row) @@ -317,12 +120,6 @@ def main(experiments=None, output_file=DEFAULT_OUTPUT_FILE): if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run experiments.") - parser.add_argument( - "--experiments", - nargs="*", - default=None, - help="Experiment names to run (default: all)", - ) parser.add_argument( "--output", default=DEFAULT_OUTPUT_FILE, @@ -330,4 +127,4 @@ def main(experiments=None, output_file=DEFAULT_OUTPUT_FILE): ) args = parser.parse_args() - main(experiments=args.experiments, output_file=args.output) + main(output_file=args.output) diff --git a/benchmarks/gpt_fast/generate.py b/benchmarks/gpt_fast/generate.py new file mode 100644 index 000000000000..a4e4b06c79d7 --- /dev/null +++ b/benchmarks/gpt_fast/generate.py @@ -0,0 +1,308 @@ +import dataclasses +import itertools +import time +from typing import Optional, Tuple + +from mixtral_moe_model import Transformer as MixtralMoE +from mixtral_moe_quantize import ( + WeightOnlyInt8QuantHandler as MixtralMoEWeightOnlyInt8QuantHandler, +) +from model import Transformer as LLaMA +from quantize import WeightOnlyInt8QuantHandler as LLaMAWeightOnlyInt8QuantHandler + +import torch +import torch._inductor.config + +torch._inductor.config.coordinate_descent_tuning = True +torch._inductor.config.triton.unique_kernel_names = True +torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future +torch._inductor.config.assert_indirect_indexing = False + + +@dataclasses.dataclass +class GPTModelConfig: + name: str + module: type + mode: Optional[str] + quantizer: type + token_per_sec: float + memory_bandwidth: float + + +def device_sync(device): + if "cuda" in device: + torch.cuda.synchronize(device) + elif "cpu" in device: + pass + else: + print(f"device={device} is not yet suppported") + + +def multinomial_sample_one_no_sync( + probs_sort, +): # Does multinomial sampling without a cuda synchronization + q = torch.empty_like(probs_sort).exponential_(1) + return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) + + +def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None): + logits = logits / max(temperature, 1e-5) + + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + pivot = v.select(-1, -1).unsqueeze(-1) + logits = torch.where(logits < pivot, -float("Inf"), logits) + probs = torch.nn.functional.softmax(logits, dim=-1) + return probs + + +def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None): + probs = logits_to_probs(logits[0, -1], temperature, top_k) + idx_next = multinomial_sample_one_no_sync(probs) + return idx_next, probs + + +@torch.compile(fullgraph=True) +def prefill( + model: torch.nn.Module, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs +) -> torch.Tensor: + # input_pos: [B, S] + logits = model(x, input_pos) + return sample(logits, **sampling_kwargs)[0] + + +@torch.compile(fullgraph=True, mode="reduce-overhead") +def decode_one_token( + model: torch.nn.Module, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs +) -> Tuple[torch.Tensor, torch.Tensor]: + # input_pos: [B, 1] + assert input_pos.shape[-1] == 1 + logits = model(x, input_pos) + return sample(logits, **sampling_kwargs) + + +def decode_n_tokens( + model: torch.nn.Module, + cur_token: torch.Tensor, + input_pos: torch.Tensor, + num_new_tokens: int, + **sampling_kwargs, +): + new_tokens, new_probs = [], [] + for i in range(num_new_tokens): + with torch.nn.attention.sdpa_kernel( + torch.nn.attention.SDPBackend.MATH + ): # Actually better for Inductor to codegen attention here + next_token, next_prob = decode_one_token( + model, cur_token, input_pos, **sampling_kwargs + ) + input_pos += 1 + new_tokens.append(next_token.clone()) + new_probs.append(next_prob.clone()) + cur_token = next_token.view(1, -1) + + return new_tokens, new_probs + + +@torch.no_grad() +def generate( + model: torch.nn.Module, prompt: torch.Tensor, max_new_tokens: int, **sampling_kwargs +) -> torch.Tensor: + device, dtype = prompt.device, prompt.dtype + T = prompt.size(0) + T_new = T + max_new_tokens + max_seq_length = min(T_new, model.config.block_size) + + with torch.device(device): + model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) + + # create an empty tensor of the expected final shape and fill in the current tokens + empty = torch.empty(T_new, dtype=dtype, device=device) + empty[:T] = prompt + seq = empty + input_pos = torch.arange(0, T, device=device) + + next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs) + seq[T] = next_token + + input_pos = torch.tensor([T], device=device, dtype=torch.int) + + generated_tokens, _ = decode_n_tokens( + model, next_token.view(1, -1), input_pos, max_new_tokens - 1, **sampling_kwargs + ) + seq[T + 1 :] = torch.cat(generated_tokens) + return seq + + +def _load_model(x: GPTModelConfig, device="cuda", precision=torch.bfloat16): + with torch.device("meta"): + model = x.module.from_name(x.name) + model = model.to(dtype=precision) + + if x.mode == "int8": + print("Using int8 weight-only quantization!") + model = x.quantizer(model).convert_for_runtime() + + state_dict = model.state_dict() + for k, v in state_dict.items(): + state_dict[k] = torch.nn.Parameter( + torch.randn(v.shape, device=device).to(dtype=v.dtype), + requires_grad=v.requires_grad, + ) + model.load_state_dict(state_dict, assign=True) + return model.eval() + + +def _get_model_size(model): + model_size = 0 + for name, child in model.named_children(): + if not isinstance(child, torch.nn.Embedding): + model_size += sum( + [ + p.numel() * p.dtype.itemsize + for p in itertools.chain(child.parameters(), child.buffers()) + ] + ) + return model_size + + +def run_experiment( + x: GPTModelConfig, + num_samples: int = 5, + max_new_tokens: int = 200, + top_k: int = 200, + temperature: float = 0.8, +) -> None: + device = "cuda" + print(f"Loading model {x.name}") + t0 = time.time() + model = _load_model(x) + device_sync(device=device) # MKG + print(f"Time to load model: {time.time() - t0:.02f} seconds") + + prompt = torch.tensor( + [1, 15043, 29892, 590, 1024, 338], device=device, dtype=torch.int32 + ) + prompt_length = prompt.size(0) + + torch.manual_seed(1234) + model_size = _get_model_size(model) + + aggregate_metrics = {"tokens_per_sec": [], "memory_bandwidth": []} + start = -1 + + for i in range(start, num_samples): + device_sync(device=device) # MKG + + t0 = time.perf_counter() + y = generate( + model, prompt, max_new_tokens, temperature=temperature, top_k=top_k + ) + + if i == -1: + print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") + continue + + device_sync(device=device) # MKG + t = time.perf_counter() - t0 + tokens_generated = y.size(0) - prompt_length + tokens_sec = tokens_generated / t + aggregate_metrics["tokens_per_sec"].append(tokens_sec) + aggregate_metrics["memory_bandwidth"].append(model_size * tokens_sec / 1e9) + + token_per_sec = torch.mean(torch.tensor(aggregate_metrics["tokens_per_sec"])).item() + memory_bandwidth = torch.mean( + torch.tensor(aggregate_metrics["memory_bandwidth"]) + ).item() + print(f"Average tokens/sec: {token_per_sec:.2f} tokens/sec") + print(f"Average bandwidth achieved: {memory_bandwidth:.02f} GB/s") + print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") + return token_per_sec, memory_bandwidth + + +# token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB. +def run_llama2_7b_bf16(): + from benchmark import Experiment + + model = GPTModelConfig( + "Llama-2-7b-chat-hf", + LLaMA, + "bfloat16", + LLaMAWeightOnlyInt8QuantHandler, + 94, + 1253, + ) + token_per_sec, memory_bandwidth = run_experiment(model) + return [ + Experiment( + "llama2_7b_bf16", + "token_per_sec", + model.token_per_sec, + f"{token_per_sec:.02f}", + ), + Experiment( + "llama2_7b_bf16", + "memory_bandwidth(GB/s)", + model.memory_bandwidth, + f"{memory_bandwidth:.02f}", + ), + ] + + +# token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB. +def run_llama2_7b_int8(): + from benchmark import Experiment + + model = GPTModelConfig( + "Llama-2-7b-chat-hf", + LLaMA, + "int8", + LLaMAWeightOnlyInt8QuantHandler, + 144, + 957, + ) + token_per_sec, memory_bandwidth = run_experiment(model) + return [ + Experiment( + "llama2_7b_int8", + "token_per_sec", + model.token_per_sec, + f"{token_per_sec:.02f}", + ), + Experiment( + "llama2_7b_int8", + "memory_bandwidth(GB/s)", + model.memory_bandwidth, + f"{memory_bandwidth:.02f}", + ), + ] + + +# token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB. +def run_mixtral_8x7b_int8(): + from benchmark import Experiment + + # We reduced the original number of layers from 32 to 16 to adapt CI memory limitation. + model = GPTModelConfig( + "Mixtral-8x7B-v0.1", + MixtralMoE, + "int8", + MixtralMoEWeightOnlyInt8QuantHandler, + 175, + 4129, + ) + token_per_sec, memory_bandwidth = run_experiment(model) + return [ + Experiment( + "mixtral_8x7b_int8", + "token_per_sec", + model.token_per_sec, + f"{token_per_sec:.02f}", + ), + Experiment( + "mixtral_8x7b_int8", + "memory_bandwidth(GB/s)", + model.memory_bandwidth, + f"{memory_bandwidth:.02f}", + ), + ] diff --git a/benchmarks/gpt_fast/micro_benchmark.py b/benchmarks/gpt_fast/micro_benchmark.py deleted file mode 100644 index 3c8f0865a244..000000000000 --- a/benchmarks/gpt_fast/micro_benchmark.py +++ /dev/null @@ -1,103 +0,0 @@ -import argparse -import dataclasses -import time - -import torch -import torch.nn as nn - - -@dataclasses.dataclass -class Experiment: - name: str - metric: str - target: float - actual: float - - -DEFAULT_OUTPUT_FILE = "micro_benchmark.csv" - - -def do_inference(mod, x, num_samples: int = 5): - total_time = 0 - start = -1 - - for i in range(start, num_samples): - torch.cuda.synchronize("cuda") - - t0 = time.perf_counter() - mod(x) - - if i == -1: - print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") - continue - - torch.cuda.synchronize("cuda") - total_time += time.perf_counter() - t0 - - total_time = total_time / num_samples - - return total_time - - -def run_multi_layernorm(): - class MultiLayerNorm(nn.Module): - def __init__(self, num_layers, normalized_shape, eps=1e-5, bias=True): - super().__init__() - self.num_layers = num_layers - self.norm_layers = nn.ModuleList( - [ - nn.LayerNorm(normalized_shape, eps=eps, bias=bias) - for _ in range(num_layers) - ] - ) - - def forward(self, x): - for layer_norm in self.norm_layers: - x = layer_norm(x) - return x - - mod = MultiLayerNorm(num_layers=8, normalized_shape=4096).to("cuda") - mod = torch.compile(mod) - input = torch.randn([512, 1024, 4096], dtype=torch.bfloat16, device="cuda") - inference_time = do_inference(mod, input) - - memory_bandwidth = input.numel() * input.dtype.itemsize / inference_time / 1e9 - - return [ - Experiment( - "multi_layer_norm", "memory_bandwidth(GB/s)", 92, f"{memory_bandwidth:.02f}" - ) - ] - - -all_experiments = { - run_multi_layernorm, -} - - -def main(output_file=DEFAULT_OUTPUT_FILE): - results = [] - - for func in all_experiments: - lst = func() - for x in lst: - results.append(dataclasses.astuple(x)) - - headers = [field.name for field in dataclasses.fields(Experiment)] - - from benchmark import output_csv - - for row in results: - output_csv(output_file, headers, row) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Run experiments.") - parser.add_argument( - "--output", - default=DEFAULT_OUTPUT_FILE, - help="Set the output CSV file to save the benchmark results", - ) - args = parser.parse_args() - - main(output_file=args.output) From 121c55d8d12a878b12eab00a7cebae2e2fa47ee7 Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Fri, 31 May 2024 18:54:54 +0000 Subject: [PATCH 183/706] Old branch deletion script to also delete old ciflow tags (#127625) Change branch deletion script to also delete left over ciflow tags that the bot doesn't get to, as well as the one created by triggering a workflow on HUD Example run https://github.com/pytorch/pytorch/actions/runs/9322082915/job/25662376463?pr=127625 (didn't actually delete the tag, but lists what tags it would delete) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127625 Approved by: https://github.com/huydhn --- .github/scripts/delete_old_branches.py | 58 ++++++++++++++++++++++---- 1 file changed, 49 insertions(+), 9 deletions(-) diff --git a/.github/scripts/delete_old_branches.py b/.github/scripts/delete_old_branches.py index 21b86fefa1a8..c2676ae09ea7 100644 --- a/.github/scripts/delete_old_branches.py +++ b/.github/scripts/delete_old_branches.py @@ -2,6 +2,7 @@ import os import re from datetime import datetime +from functools import lru_cache from pathlib import Path from typing import Any, Callable, Dict, List, Set @@ -187,6 +188,17 @@ def get_recent_prs() -> Dict[str, Any]: return prs_by_branch_base +@lru_cache(maxsize=1) +def get_open_prs() -> List[Dict[str, Any]]: + return paginate_graphql( + GRAPHQL_OPEN_PRS, + {"owner": "pytorch", "repo": "pytorch"}, + lambda data: False, + lambda res: res["data"]["repository"]["pullRequests"]["nodes"], + lambda res: res["data"]["repository"]["pullRequests"]["pageInfo"], + ) + + def get_branches_with_magic_label_or_open_pr() -> Set[str]: pr_infos: List[Dict[str, Any]] = paginate_graphql( GRAPHQL_NO_DELETE_BRANCH_LABEL, @@ -196,15 +208,7 @@ def get_branches_with_magic_label_or_open_pr() -> Set[str]: lambda res: res["data"]["repository"]["label"]["pullRequests"]["pageInfo"], ) - pr_infos.extend( - paginate_graphql( - GRAPHQL_OPEN_PRS, - {"owner": "pytorch", "repo": "pytorch"}, - lambda data: False, - lambda res: res["data"]["repository"]["pullRequests"]["nodes"], - lambda res: res["data"]["repository"]["pullRequests"]["pageInfo"], - ) - ) + pr_infos.extend(get_open_prs()) # Get the most recent PR for each branch base (group gh together) branch_bases = set() @@ -270,5 +274,41 @@ def delete_branches() -> None: delete_branch(git_repo, branch) +def delete_old_ciflow_tags() -> None: + # Deletes ciflow tags if they are associated with a closed PR or a specific + # commit. Lightweight tags don't have information about the date they were + # created, so we can't check how old they are. The script just assumes that + # ciflow tags should be deleted regardless of creation date. + git_repo = GitRepo(str(REPO_ROOT), "origin", debug=True) + + def delete_tag(tag: str) -> None: + print(f"Deleting tag {tag}") + ESTIMATED_TOKENS[0] += 1 + delete_branch(git_repo, f"refs/tags/{tag}") + + tags = git_repo._run_git("tag").splitlines() + open_pr_numbers = [x["number"] for x in get_open_prs()] + + for tag in tags: + try: + if ESTIMATED_TOKENS[0] > 400: + print("Estimated tokens exceeded, exiting") + break + if not tag.startswith("ciflow/"): + continue + re_match_pr = re.match(r"^ciflow\/.*\/(\d{5,6})$", tag) + re_match_sha = re.match(r"^ciflow\/.*\/([0-9a-f]{40})$", tag) + if re_match_pr: + pr_number = int(re_match_pr.group(1)) + if pr_number in open_pr_numbers: + continue + delete_tag(tag) + elif re_match_sha: + delete_tag(tag) + except Exception as e: + print(f"Failed to check tag {tag}: {e}") + + if __name__ == "__main__": delete_branches() + delete_old_ciflow_tags() From b704c7cf0f8fb8b92198ffd4cd07716a7b9efae1 Mon Sep 17 00:00:00 2001 From: Kwanghoon An Date: Fri, 31 May 2024 19:08:04 +0000 Subject: [PATCH 184/706] Re trying Support min/max carry over for eager mode from_float method (#127576) Summary: Original commit changeset: 2605900516c8 Original Phabricator Diff: D57977896 Test Plan: Re enabling due to prod failure Reviewed By: jerryzh168 Differential Revision: D57978925 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127576 Approved by: https://github.com/jerryzh168 --- torch/ao/quantization/quantize.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/torch/ao/quantization/quantize.py b/torch/ao/quantization/quantize.py index 17cec9f35908..534def354573 100644 --- a/torch/ao/quantization/quantize.py +++ b/torch/ao/quantization/quantize.py @@ -1,7 +1,7 @@ import copy import itertools import warnings - +import inspect import torch import torch.nn as nn import torch.ao.nn.quantized as nnq @@ -636,7 +636,11 @@ def swap_module(mod, mapping, custom_module_class_mapping, use_precomputed_fake_ weight_qparams = get_qparam_dict(weight_post_process) new_mod = qmod.from_float(mod, weight_qparams) else: - new_mod = qmod.from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant) + sig = inspect.signature(qmod.from_float) + if 'use_precomputed_fake_quant' in sig.parameters: + new_mod = qmod.from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant) + else: + new_mod = qmod.from_float(mod) swapped = True if swapped: From 8af1c655e5e45fd7ed3c56ec9285ff02f901e3d8 Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Thu, 30 May 2024 14:36:14 -0700 Subject: [PATCH 185/706] improve eager overhead of _disable_dynamo (#127325) it seems like `_disable_dynamo` actually has a fair amount of overhead (especially when it was added to `DTensor.__new__`: this change speeds up @wanchaol 's repro from 0.380 -> 0.312s: P1378202570 (that repro runs a vanilla MLP using 2D parallelism, and calls the DTensor constructor 1280 times). It looks like most of the slowndown is in the fact that we are repeatedly running `import torch._dynamo` and constructing an instance of `torch._dynamo.disable(fn, recursive)` on every call to the constructor - this PR caches it on the first invocation. ~~Update: I realized I cannot use `torch.compiler.is_compiling` to know when to fast-path, because when we hit a graph break, cpython will be running so it will return False.~~ ~~As a test / potential fix, I added a new config, `torch._dynamo.config._is_compiling` that is set to True **always** inside a compiled region (even on frames that are run by cpython). This definitely seems to do what I want in terms of knowing when to fastpath and avoid overhead - although interested in feedback on how reasonable this is~~ Pull Request resolved: https://github.com/pytorch/pytorch/pull/127325 Approved by: https://github.com/wanchaol, https://github.com/anijain2305 --- torch/_compile.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/torch/_compile.py b/torch/_compile.py index 354d64e9ff9f..2b00415e0eba 100644 --- a/torch/_compile.py +++ b/torch/_compile.py @@ -19,9 +19,15 @@ def _disable_dynamo(fn=None, recursive=True): @functools.wraps(fn) def inner(*args, **kwargs): - import torch._dynamo + # cache this on the first invocation to avoid adding too much overhead. + disable_fn = getattr(fn, "__dynamo_disable", None) + if disable_fn is None: + import torch._dynamo - return torch._dynamo.disable(fn, recursive)(*args, **kwargs) + disable_fn = torch._dynamo.disable(fn, recursive) + fn.__dynamo_disable = disable_fn + + return disable_fn(*args, **kwargs) return inner else: From 11034448708ef878de9f0c0bbee595961b0c9d3b Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Tue, 28 May 2024 07:21:43 -0700 Subject: [PATCH 186/706] [AOTI] Add back include_pytorch for specifying link paths (#126802) Summary: Running dashboard with the cpp wrapper mode sometimes hit erros like "undefined symbol: aoti_torch_empty_stride", although it can not be reproduced locally and seems only happen on the dashboard CI. Differential Revision: [D57911442](https://our.internmc.facebook.com/intern/diff/D57911442) Pull Request resolved: https://github.com/pytorch/pytorch/pull/126802 Approved by: https://github.com/chenyang78 ghstack dependencies: #126916, #127037 --- torch/_inductor/codecache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index c47f01751482..b5aa1d1b8a61 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -2538,7 +2538,7 @@ class CppWrapperCodeCache(CppPythonBindingsCodeCache): cache: Dict[str, Callable[[], Union[CDLL, ModuleType]]] = {} cache_clear = staticmethod(cache.clear) cpp_compile_command_flags = { - "include_pytorch": not config.abi_compatible, + "include_pytorch": True, "shared": True, } entry_function = "inductor_entry_cpp" From bbf892dd58dcd8a5683b56a365489ecc6fc67b0c Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 31 May 2024 19:35:15 +0000 Subject: [PATCH 187/706] Revert "Add back private function torch.cuda.amp.autocast_mode._cast (#127433)" This reverts commit 6e0eeecc7cd4dc389683e35d1f2e34738e09e597. Reverted https://github.com/pytorch/pytorch/pull/127433 on behalf of https://github.com/fbgheith due to depends on https://github.com/pytorch/pytorch/pull/126898 which is failing internally and needs to be reverted ([comment](https://github.com/pytorch/pytorch/pull/127433#issuecomment-2142869610)) --- torch/cuda/amp/autocast_mode.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/torch/cuda/amp/autocast_mode.py b/torch/cuda/amp/autocast_mode.py index 09a44f50c90b..e50206c70577 100644 --- a/torch/cuda/amp/autocast_mode.py +++ b/torch/cuda/amp/autocast_mode.py @@ -50,11 +50,6 @@ def __call__(self, func): return super().__call__(func) -# Preserved only for BC reasons -def _cast(value, dtype): - return torch.amp.autocast_mode._cast(value, "cuda", dtype) - - @deprecated( "`torch.cuda.amp.custom_fwd(args...)` is deprecated. " "Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.", From ea13e9a097aaa875a2b404822579b7f8b62ea291 Mon Sep 17 00:00:00 2001 From: Robert Mast Date: Fri, 31 May 2024 19:38:42 +0000 Subject: [PATCH 188/706] correct BLAS input (#126200) Fixes #32407 With this little correction to Dependencies.cmake it is possible to build an MKL-free version of Pytorch up from version v2.0.0 by explicitly choosing another MKL-free BLAS. This pullrequest fulfills the "if not already present" part of the original comment in Dependencies.cmake: "setting default preferred BLAS options if not already present." It's tested with this Action-.yml: ``` name: Build PyTorch v2.0.0 without AVX on: push: branches: - v2.0.0 pull_request: branches: - v2.0.0 jobs: build: runs-on: ubuntu-20.04 defaults: run: shell: bash -el {0} steps: - name: Checkout repository uses: actions/checkout@v4 with: #repository: 'pytorch/pytorch' #ref: 'v2.3.0' submodules: 'recursive' - uses: conda-incubator/setup-miniconda@v3 with: auto-activate-base: true activate-environment: true python-version: 3.10.13 - name: Install Dependencies - Common - Linux 2 run: | conda info conda list conda install nomkl conda install astunparse numpy ninja pyyaml setuptools cmake cffi typing_extensions future six requests dataclasses export PYTORCH_CPU_CAPABILITY=cpu export ATEN_CPU_CAPABILITY_DEFAULT=cpu export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"} export ATEN_CPU_CAPABILITY=default export USE_NNPACK=0 export MAX_JOBS=4 export USE_CUDA=0 export USE_ROCM=0 export BLAS=OpenBLAS export CMAKE_ARGS="-D CMAKE_BUILD_TYPE=Release -D USE_AVX=OFF -D USE_NNPACK=OFF -D C_HAS_AVX_2=OFF -D C_HAS_AVX2_2=OFF -D CXX_HAS_AVX_2=OFF -D CXX_HAS_AVX2_2=OFF -D CAFFE2_COMPILER_SUPPORTS_AVX512_EXTENSIONS=OFF -DPYTHON_INCLUDE_DIR=$(python -c "import sysconfig; print(sysconfig.get_path('include'))") -DPYTHON_LIBRARY=$(python -c "import sysconfig; print(sysconfig.get_config_var('LIBDIR'))") -DPYTHON_EXECUTABLE:FILEPATH=`which python`" pip install build wheel typing_extensions python setup.py bdist_wheel - name: Archive production artifacts uses: actions/upload-artifact@v4 with: name: dist-without-markdown path: | dist !dist/**/*.md ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/126200 Approved by: https://github.com/jgong5, https://github.com/kit1980 --- cmake/Dependencies.cmake | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 1e4c4262fba1..4d53b2f860d6 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -145,13 +145,23 @@ endif() set(AT_MKLDNN_ACL_ENABLED 0) # setting default preferred BLAS options if not already present. -if(NOT INTERN_BUILD_MOBILE) - set(BLAS "MKL" CACHE STRING "Selected BLAS library") -else() - set(BLAS "Eigen" CACHE STRING "Selected BLAS library") - set(AT_MKLDNN_ENABLED 0) - set(AT_MKL_ENABLED 0) +if(NOT DEFINED BLAS) + if(NOT INTERN_BUILD_MOBILE) + set(BLAS "MKL" CACHE STRING "Selected BLAS library") + else() + set(BLAS "Eigen" CACHE STRING "Selected BLAS library") + set(AT_MKLDNN_ENABLED 0) + set(AT_MKL_ENABLED 0) + endif() +elseif(NOT BLAS STREQUAL "MKL") + if(USE_MKLDNN) + message(WARNING + "You explicitly chose with BLAS to not use MKL, so disabling USE_MKLDNN. Suppress this warning with " + "-DUSE_MKLDNN=OFF.") + set(USE_MKLDNN OFF) + endif() endif() + set_property(CACHE BLAS PROPERTY STRINGS "ATLAS;BLIS;Eigen;FLAME;Generic;MKL;OpenBLAS;vecLib") message(STATUS "Trying to find preferred BLAS backend of choice: " ${BLAS}) From 033e7330211e9c9f85cd7745c221aaccd02c5c76 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 31 May 2024 19:47:24 +0000 Subject: [PATCH 189/706] Revert "[BE] wrap deprecated function/class with `typing_extensions.deprecated` (#126898)" This reverts commit 749a132fb0a8325cbad4734a563aa459ca611991. Reverted https://github.com/pytorch/pytorch/pull/126898 on behalf of https://github.com/fbgheith due to switching typing-extensions=4.3.0 to 4.9.0 causes internal failure ([comment](https://github.com/pytorch/pytorch/pull/126898#issuecomment-2142884456)) --- .github/requirements/conda-env-Linux-X64.txt | 2 +- .github/requirements/conda-env-iOS.txt | 2 +- .github/requirements/conda-env-macOS-ARM64 | 2 +- .github/requirements/conda-env-macOS-X64 | 2 +- test/distributed/_tensor/test_api.py | 2 +- .../distributed/fsdp/test_fsdp_optim_state.py | 2 +- test/functorch/test_eager_transforms.py | 6 +- test/nn/test_init.py | 2 +- test/nn/test_module_hooks.py | 9 ++- test/test_autocast.py | 4 +- test/test_autograd.py | 10 +-- test/test_cuda.py | 8 +-- test/test_optim.py | 2 +- test/test_prims.py | 2 +- test/test_pytree.py | 2 +- test/test_stateless.py | 2 +- test/test_torch.py | 4 +- torch/__init__.py | 11 ++++ torch/_dynamo/decorators.py | 3 +- torch/_dynamo/eval_frame.py | 8 +-- torch/_dynamo/external_utils.py | 5 -- torch/_functorch/apis.py | 6 +- torch/_functorch/deprecated.py | 26 ++++---- torch/_functorch/eager_transforms.py | 4 +- torch/_functorch/pytree_hacks.py | 5 +- torch/_higher_order_ops/associative_scan.py | 2 +- torch/_inductor/ir.py | 3 +- torch/_library/abstract_impl.py | 9 ++- torch/_prims_common/__init__.py | 13 ++-- torch/_utils.py | 6 +- torch/_vmap_internals.py | 10 +-- torch/ao/nn/quantizable/modules/activation.py | 11 +--- torch/ao/nn/quantized/dynamic/modules/rnn.py | 6 +- torch/ao/quantization/fx/convert.py | 18 ++--- torch/ao/quantization/fx/fuse.py | 12 ++-- torch/ao/quantization/fx/prepare.py | 24 +++---- torch/ao/quantization/qconfig.py | 22 ++----- torch/ao/quantization/quantize_fx.py | 18 ++--- torch/autograd/__init__.py | 16 ++--- torch/autograd/_functions/tensor.py | 11 ++-- torch/autograd/function.py | 16 ++--- torch/autograd/gradcheck.py | 31 ++++----- torch/autograd/profiler.py | 4 +- torch/autograd/profiler_legacy.py | 10 +-- torch/autograd/profiler_util.py | 23 +------ torch/backends/cuda/__init__.py | 20 +++--- torch/cpu/amp/autocast_mode.py | 11 ++-- torch/cpu/amp/grad_scaler.py | 10 ++- torch/cuda/_memory_viz.py | 6 +- torch/cuda/amp/autocast_mode.py | 27 ++++---- torch/cuda/amp/grad_scaler.py | 10 ++- torch/cuda/memory.py | 17 +++-- torch/cuda/nccl.py | 8 +-- torch/distributed/_composable/fully_shard.py | 20 +++--- torch/distributed/_functional_collectives.py | 3 +- .../distributed/_shard/checkpoint/__init__.py | 15 ++--- .../distributed/_shard/sharded_tensor/api.py | 11 ++-- torch/distributed/_sharded_tensor/__init__.py | 14 ++-- torch/distributed/_sharding_spec/__init__.py | 13 ++-- torch/distributed/_tensor/api.py | 2 - .../_checkpoint/checkpoint_wrapper.py | 1 - .../ddp_comm_hooks/default_hooks.py | 4 +- .../checkpoint/state_dict_loader.py | 10 ++- .../checkpoint/state_dict_saver.py | 11 ++-- torch/distributed/distributed_c10d.py | 65 +++++++++---------- torch/distributed/elastic/metrics/api.py | 12 ++-- torch/distributed/fsdp/_init_utils.py | 3 +- .../fsdp/fully_sharded_data_parallel.py | 8 +-- torch/distributed/launch.py | 22 +++---- torch/distributed/optim/__init__.py | 14 +--- torch/distributed/pipeline/__init__.py | 18 ++--- torch/distributed/tensor/parallel/_utils.py | 5 +- torch/distributions/distribution.py | 8 +-- .../multipledispatch/dispatcher.py | 30 +++++---- torch/hub.py | 10 +-- torch/jit/_script.py | 4 +- torch/jit/_trace.py | 8 +-- torch/library.py | 12 ++-- torch/multiprocessing/spawn.py | 2 +- torch/nn/functional.py | 18 ++--- torch/nn/init.py | 6 +- torch/nn/modules/activation.py | 10 +-- torch/nn/modules/container.py | 11 ++-- torch/nn/modules/conv.py | 15 +++-- torch/nn/modules/loss.py | 12 ++-- torch/nn/modules/module.py | 19 ++---- torch/nn/modules/rnn.py | 6 +- torch/nn/parallel/__init__.py | 11 +--- torch/nn/parallel/comm.py | 4 +- torch/nn/parallel/distributed.py | 7 +- torch/nn/parallel/scatter_gather.py | 8 +-- torch/nn/utils/clip_grad.py | 9 +-- torch/nn/utils/stateless.py | 14 ++-- torch/nn/utils/weight_norm.py | 9 +-- torch/optim/adadelta.py | 6 +- torch/optim/adam.py | 6 +- torch/optim/adamax.py | 6 +- torch/optim/adamw.py | 6 +- torch/optim/asgd.py | 4 +- torch/optim/nadam.py | 4 +- torch/optim/optimizer.py | 11 ++-- torch/optim/radam.py | 4 +- torch/optim/rmsprop.py | 6 +- torch/optim/rprop.py | 6 +- torch/optim/sgd.py | 2 +- torch/profiler/profiler.py | 5 +- torch/sparse/semi_structured.py | 9 +-- torch/testing/_comparison.py | 16 +++-- torch/testing/_creation.py | 2 +- .../_internal/optests/generate_tests.py | 2 +- torch/utils/_config_module.py | 12 ++-- torch/utils/_contextlib.py | 12 ++-- torch/utils/_cxx_pytree.py | 12 ++-- torch/utils/_pytree.py | 27 +++----- torch/utils/data/backward_compatibility.py | 11 +--- torch/utils/data/dataset.py | 10 +-- torch/utils/data/graph_settings.py | 10 ++- 117 files changed, 478 insertions(+), 700 deletions(-) diff --git a/.github/requirements/conda-env-Linux-X64.txt b/.github/requirements/conda-env-Linux-X64.txt index 78534c21e911..16bbc57dd3be 100644 --- a/.github/requirements/conda-env-Linux-X64.txt +++ b/.github/requirements/conda-env-Linux-X64.txt @@ -6,4 +6,4 @@ numpy=1.23.3 pyyaml=6.0 requests=2.31.0 setuptools=68.2.2 -typing-extensions=4.9.0 +typing-extensions=4.3.0 diff --git a/.github/requirements/conda-env-iOS.txt b/.github/requirements/conda-env-iOS.txt index a88a16dba4df..205c07925a01 100644 --- a/.github/requirements/conda-env-iOS.txt +++ b/.github/requirements/conda-env-iOS.txt @@ -5,4 +5,4 @@ numpy=1.23.3 pyyaml=6.0 requests=2.31.0 setuptools=68.2.2 -typing-extensions=4.9.0 +typing-extensions=4.3.0 diff --git a/.github/requirements/conda-env-macOS-ARM64 b/.github/requirements/conda-env-macOS-ARM64 index 26b034c7d6e2..951cda496403 100644 --- a/.github/requirements/conda-env-macOS-ARM64 +++ b/.github/requirements/conda-env-macOS-ARM64 @@ -2,7 +2,7 @@ numpy=1.22.3 pyyaml=6.0 setuptools=61.2.0 cmake=3.22.* -typing-extensions=4.9.0 +typing-extensions=4.3.0 dataclasses=0.8 pip=22.2.2 pillow=10.0.1 diff --git a/.github/requirements/conda-env-macOS-X64 b/.github/requirements/conda-env-macOS-X64 index 35da8324689a..95be2a082397 100644 --- a/.github/requirements/conda-env-macOS-X64 +++ b/.github/requirements/conda-env-macOS-X64 @@ -4,7 +4,7 @@ numpy=1.21.2 pyyaml=5.3 setuptools=46.0.0 cmake=3.22.* -typing-extensions=4.9.0 +typing-extensions=4.3.0 dataclasses=0.8 pip=22.2.2 pillow=10.0.1 diff --git a/test/distributed/_tensor/test_api.py b/test/distributed/_tensor/test_api.py index 21763b091a54..196bd6407b26 100644 --- a/test/distributed/_tensor/test_api.py +++ b/test/distributed/_tensor/test_api.py @@ -237,7 +237,7 @@ def output_fn(outputs, device_mesh): assert isinstance(outputs, DTensor) return outputs.to_local() - with self.assertWarnsRegex(FutureWarning, "Deprecating"): + with self.assertWarnsRegex(UserWarning, "Deprecating"): replica_module = distribute_module( module_to_replicate, device_mesh, diff --git a/test/distributed/fsdp/test_fsdp_optim_state.py b/test/distributed/fsdp/test_fsdp_optim_state.py index 29925c13f2d6..672b71d5290f 100644 --- a/test/distributed/fsdp/test_fsdp_optim_state.py +++ b/test/distributed/fsdp/test_fsdp_optim_state.py @@ -1436,7 +1436,7 @@ def should_check_method(method_name: str): def get_warning_context(): warning_regex = "`optim_input` argument is deprecated" return self.assertWarnsRegex( - expected_warning=FutureWarning, expected_regex=warning_regex + expected_warning=UserWarning, expected_regex=warning_regex ) self._run_on_all_optim_state_apis( diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py index c767810beb85..9aae6e5451a3 100644 --- a/test/functorch/test_eager_transforms.py +++ b/test/functorch/test_eager_transforms.py @@ -3258,7 +3258,7 @@ def test_deprecation_vmap(self, device): x = torch.randn(3, device=device) # functorch version of the API is deprecated - with self.assertWarnsRegex(FutureWarning, "Please use `torch.vmap`"): + with self.assertWarnsRegex(UserWarning, "Please use torch.vmap"): vmap(torch.sin) # the non-functorch version is not deprecated @@ -3276,9 +3276,7 @@ def test_deprecation_transforms(self, device, transform): new_api = getattr(torch.func, transform) # functorch version of the API is deprecated - with self.assertWarnsRegex( - FutureWarning, f"Please use `torch.func.{transform}`" - ): + with self.assertWarnsRegex(UserWarning, f"Please use torch.func.{transform}"): api(torch.sin) # the non-functorch version is not deprecated diff --git a/test/nn/test_init.py b/test/nn/test_init.py index 9ae471414474..8826fabc263b 100644 --- a/test/nn/test_init.py +++ b/test/nn/test_init.py @@ -521,7 +521,7 @@ def fn(): init.normal(x) with self.assertWarnsRegex( - FutureWarning, + UserWarning, "deprecated", msg="methods not suffixed with underscore should be deprecated", ): diff --git a/test/nn/test_module_hooks.py b/test/nn/test_module_hooks.py index f76837660302..8dbd255c6c53 100644 --- a/test/nn/test_module_hooks.py +++ b/test/nn/test_module_hooks.py @@ -1387,8 +1387,7 @@ def forward(self, l): m.register_backward_hook(noop) with self.assertWarnsRegex( - FutureWarning, - "does not take as input a single Tensor or a tuple of Tensors", + UserWarning, "does not take as input a single Tensor or a tuple of Tensors" ): m([a, b]) @@ -1401,7 +1400,7 @@ def forward(self, a, b): m.register_backward_hook(noop) with self.assertWarnsRegex( - FutureWarning, "does not return a single Tensor or a tuple of Tensors" + UserWarning, "does not return a single Tensor or a tuple of Tensors" ): m(a, b) @@ -1414,7 +1413,7 @@ def forward(self, a, b): m.register_backward_hook(noop) with self.assertWarnsRegex( - FutureWarning, "outputs are generated by different autograd Nodes" + UserWarning, "outputs are generated by different autograd Nodes" ): m(a, b) @@ -1427,7 +1426,7 @@ def forward(self, a): m.register_backward_hook(noop) with self.assertWarnsRegex( - FutureWarning, "the forward contains multiple autograd Nodes" + UserWarning, "the forward contains multiple autograd Nodes" ): m(a) diff --git a/test/test_autocast.py b/test/test_autocast.py index 24f87944990d..ce3d94318ccd 100644 --- a/test/test_autocast.py +++ b/test/test_autocast.py @@ -255,8 +255,8 @@ def test_generic_autocast(self): def test_cpu_autocast_deprecated_warning(self): with self.assertWarnsRegex( - FutureWarning, - r"`torch.cpu.amp.autocast\(args...\)` is deprecated. Please use `torch.amp.autocast\('cpu', args...\)` instead.", + DeprecationWarning, + r"torch.cpu.amp.autocast\(args...\) is deprecated. Please use torch.amp.autocast\('cpu', args...\) instead.", ): with torch.cpu.amp.autocast(): _ = torch.ones(10) diff --git a/test/test_autograd.py b/test/test_autograd.py index 911762024930..0dc9aca21041 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -154,7 +154,7 @@ def hook(*args): def test_grad_mode_class_decoration(self): # Decorating class is deprecated and should not be used - with self.assertWarnsRegex(FutureWarning, "Decorating classes is deprecated"): + with self.assertWarnsRegex(UserWarning, "Decorating classes is deprecated"): @torch.no_grad() class Foo: @@ -5937,13 +5937,13 @@ def fn(inputs): b = torch.rand(2, 2, requires_grad=True, dtype=torch.float64) with self.assertWarnsRegex( - FutureWarning, "`get_numerical_jacobian` was part of PyTorch's private API" + UserWarning, "get_numerical_jacobian was part of PyTorch's private API" ): jacobian = get_numerical_jacobian(fn, (a, b), target=a, eps=1e-6) self.assertEqual(jacobian[0], 2 * torch.eye(4, dtype=torch.double)) with self.assertWarnsRegex( - FutureWarning, "`get_numerical_jacobian` was part of PyTorch's private API" + UserWarning, "get_numerical_jacobian was part of PyTorch's private API" ): jacobian = get_numerical_jacobian(fn, (a, b), eps=1e-6) self.assertEqual(jacobian[0], 2 * torch.eye(4, dtype=torch.double)) @@ -5963,7 +5963,7 @@ def fn(x, y): outputs = fn(a, b) with self.assertWarnsRegex( - FutureWarning, "`get_analytical_jacobian` was part of PyTorch's private API" + UserWarning, "get_analytical_jacobian was part of PyTorch's private API" ): ( jacobians, @@ -5991,7 +5991,7 @@ def backward(ctx, grad_out): outputs = NonDetFunc.apply(a, 1e-6) with self.assertWarnsRegex( - FutureWarning, "`get_analytical_jacobian` was part of PyTorch's private API" + UserWarning, "get_analytical_jacobian was part of PyTorch's private API" ): ( jacobians, diff --git a/test/test_cuda.py b/test/test_cuda.py index 785f0499df05..c919158e2c4e 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -1820,10 +1820,10 @@ def backward(ctx, grad): return grad, grad self.assertRegex( - str(w[0].message), r"`torch.cuda.amp.custom_fwd\(args...\)` is deprecated." + str(w[0].message), r"torch.cuda.amp.custom_fwd\(args...\) is deprecated." ) self.assertRegex( - str(w[1].message), r"`torch.cuda.amp.custom_bwd\(args...\)` is deprecated." + str(w[1].message), r"torch.cuda.amp.custom_bwd\(args...\) is deprecated." ) mymm = MyMM.apply @@ -2016,8 +2016,8 @@ def test_autocast_checkpointing(self): def test_cuda_autocast_deprecated_warning(self): with self.assertWarnsRegex( - FutureWarning, - r"`torch.cuda.amp.autocast\(args...\)` is deprecated. Please use `torch.amp.autocast\('cuda', args...\)` instead.", + DeprecationWarning, + r"torch.cuda.amp.autocast\(args...\) is deprecated. Please use torch.amp.autocast\('cuda', args...\) instead.", ): with torch.cuda.amp.autocast(): _ = torch.ones(10) diff --git a/test/test_optim.py b/test/test_optim.py index 3ab57fecd833..d61c33e2adce 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -287,7 +287,7 @@ def test_param_group_with_lrscheduler_goes_right_direction( inpt = torch.randn(5, device=device, dtype=dtype) # avoid endless recompiles by wrapping LR in a tensor if we're compiling - lr = torch.tensor(0.01) if torch.compiler.is_compiling() else 0.01 + lr = torch.tensor(0.01) if torch._utils.is_compiling() else 0.01 optimizer = optim_cls([{"params": [weight]}, {"params": [bias], "lr": lr}]) schedulers = [scheduler_c(optimizer) for scheduler_c in schedulers_c] diff --git a/test/test_prims.py b/test/test_prims.py index f1452acd7ab3..2a1f10cc8748 100644 --- a/test/test_prims.py +++ b/test/test_prims.py @@ -338,7 +338,7 @@ def test_mul_complex(self): prims.mul(torch.randn(2), 1 + 1j) def test_check_deprecation_warning(self): - with self.assertWarnsRegex(FutureWarning, 'will be removed in the future'): + with self.assertWarnsRegex(DeprecationWarning, 'will be removed in the future'): torch._prims_common.check(True, lambda: 'message') diff --git a/test/test_pytree.py b/test/test_pytree.py index 0a1c480a8fa7..caaf4d0b53bd 100644 --- a/test/test_pytree.py +++ b/test/test_pytree.py @@ -723,7 +723,7 @@ def __init__(self, x, y): self.y = y with self.assertWarnsRegex( - FutureWarning, "torch.utils._pytree._register_pytree_node" + UserWarning, "torch.utils._pytree._register_pytree_node" ): py_pytree._register_pytree_node( DummyType, diff --git a/test/test_stateless.py b/test/test_stateless.py index 6256f2b55bf8..32ec45937059 100644 --- a/test/test_stateless.py +++ b/test/test_stateless.py @@ -901,7 +901,7 @@ def test_stateless_functional_call_warns(self): m = torch.nn.Linear(1, 1) params = dict(m.named_parameters()) x = torch.randn(3, 1) - with self.assertWarnsRegex(FutureWarning, "Please use `torch.func.functional_call`"): + with self.assertWarnsRegex(UserWarning, "Please use torch.func.functional_call"): stateless.functional_call(m, params, x) class TestPythonOptimizeMode(TestCase): diff --git a/test/test_torch.py b/test/test_torch.py index 717943f43646..05070abaf669 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -6198,8 +6198,8 @@ def test_grad_scaler_deprecated_warning(self, device): GradScaler = torch.cuda.amp.GradScaler if "cuda" == device.type else torch.cpu.amp.GradScaler with self.assertWarnsRegex( - FutureWarning, - rf"`torch.{device.type}.amp.GradScaler\(args...\)` is deprecated.", + UserWarning, + rf"torch.{device.type}.amp.GradScaler\(args...\) is deprecated.", ): _ = GradScaler(init_scale=2.0) diff --git a/torch/__init__.py b/torch/__init__.py index a2492c40a949..440b833fd079 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -1996,6 +1996,17 @@ def _register_device_module(device_type, module): from torch.func import vmap +# The function _sparse_coo_tensor_unsafe is removed from PyTorch +# Python API (v. 1.13), here we temporarily provide its replacement +# with a deprecation warning. +# TODO: remove the function for PyTorch v 1.15. +def _sparse_coo_tensor_unsafe(*args, **kwargs): + import warnings + warnings.warn('torch._sparse_coo_tensor_unsafe is deprecated, ' + 'use torch.sparse_coo_tensor(..., check_invariants=False) instead.') + kwargs['check_invariants'] = False + return torch.sparse_coo_tensor(*args, **kwargs) + # Register MPS specific decomps torch.backends.mps._init() diff --git a/torch/_dynamo/decorators.py b/torch/_dynamo/decorators.py index 2c4417d9af50..8c82bf542169 100644 --- a/torch/_dynamo/decorators.py +++ b/torch/_dynamo/decorators.py @@ -7,6 +7,7 @@ from .comptime import comptime from .eval_frame import DisableContext, innermost_fn, RunOnlyContext from .exc import IncorrectUsage +from .external_utils import is_compiling if TYPE_CHECKING: from torch._C._dynamo.eval_frame import ( # noqa: F401 @@ -272,7 +273,7 @@ def mark_static(t, index=None): Unlike mark_dynamic, this can be done inside a graph, in which case it induces specialization on the tensor. """ - if torch.compiler.is_compiling(): + if is_compiling(): if index is None: for s in t.size(): comptime.force_static(s) diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 8a195664b403..1e164b2f7895 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -798,8 +798,7 @@ def guard_export_print(guards): warnings.warn( "explain(f, *args, **kwargs) is deprecated, use explain(f)(*args, **kwargs) instead. " "If you don't migrate, we may break your explain call in the future if your user defined kwargs " - "conflict with future kwargs added to explain(f).", - FutureWarning, + "conflict with future kwargs added to explain(f)." ) return inner(*extra_args, **extra_kwargs) else: @@ -942,7 +941,7 @@ def check_signature_rewritable(graph): tb = "".join(traceback.format_list(stack)) extra = "" if len(user_stacks) > 1: - extra = f"(elided {len(user_stacks) - 1} more accesses)" + extra = f"(elided {len(user_stacks)-1} more accesses)" msg = f"{source.name()}, accessed at:\n{tb}{extra}" # TODO: option to print ALL of the stack traces at once input_errors.append(msg) @@ -1477,8 +1476,7 @@ def graph_with_interpreter(*args): warnings.warn( "export(f, *args, **kwargs) is deprecated, use export(f)(*args, **kwargs) instead. " "If you don't migrate, we may break your export call in the future if your user defined kwargs " - "conflict with future kwargs added to export(f).", - FutureWarning, + "conflict with future kwargs added to export(f)." ) return inner(*extra_args, **extra_kwargs) else: diff --git a/torch/_dynamo/external_utils.py b/torch/_dynamo/external_utils.py index 669f86c9ec59..3ba10d34b771 100644 --- a/torch/_dynamo/external_utils.py +++ b/torch/_dynamo/external_utils.py @@ -2,7 +2,6 @@ import functools from typing import List -from typing_extensions import deprecated import torch import torch.utils._pytree as pytree @@ -13,10 +12,6 @@ np = None # type: ignore[assignment] -@deprecated( - "`is_compiling` is deprecated. Use `torch.compiler.is_compiling()` instead.", - category=FutureWarning, -) def is_compiling() -> bool: """ Indicates whether we are tracing/compiling with torch.compile() or torch.export(). diff --git a/torch/_functorch/apis.py b/torch/_functorch/apis.py index 477a01583b3d..ee0c0a1984e4 100644 --- a/torch/_functorch/apis.py +++ b/torch/_functorch/apis.py @@ -188,7 +188,7 @@ def vmap( vmap does not provide general autobatching or handle variable-length sequences out of the box. """ - from torch.compiler import is_compiling + from torch._dynamo import is_compiling _check_randomness_arg(randomness) if not (chunk_size is None or chunk_size > 0): @@ -390,7 +390,7 @@ def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Calla """ # To avoid cyclical dependency. import torch._functorch.eager_transforms as eager_transforms - from torch.compiler import is_compiling + from torch._dynamo import is_compiling def wrapper(*args, **kwargs): return eager_transforms.grad_impl(func, argnums, has_aux, args, kwargs) @@ -432,8 +432,8 @@ def grad_and_value( See :func:`grad` for examples """ + from torch._dynamo import is_compiling from torch._functorch import eager_transforms - from torch.compiler import is_compiling def wrapper(*args, **kwargs): return eager_transforms.grad_and_value_impl( diff --git a/torch/_functorch/deprecated.py b/torch/_functorch/deprecated.py index bf080fcc3165..82a34f7d41c3 100644 --- a/torch/_functorch/deprecated.py +++ b/torch/_functorch/deprecated.py @@ -1,12 +1,3 @@ -""" -The APIs in this file are exposed as `functorch.*`. They are thin wrappers -around the torch.func.* APIs that have deprecation warnings -- we're trying -to move people to the torch.func.* equivalents. - -NB: We don't use *args, **kwargs in the signatures because that changes the -documentation. -""" - import textwrap import warnings from typing import Any, Callable, Optional, Tuple, Union @@ -18,16 +9,25 @@ from torch._functorch.eager_transforms import argnums_t from torch._functorch.vmap import in_dims_t, out_dims_t +""" +The APIs in this file are exposed as `functorch.*`. They are thin wrappers +around the torch.func.* APIs that have deprecation warnings -- we're trying +to move people to the torch.func.* equivalents. + +NB: We don't use *args, **kwargs in the signatures because that changes the +documentation. +""" + def get_warning(api, new_api=None, replace_newlines=False): if new_api is None: new_api = f"torch.func.{api}" warning = ( f"We've integrated functorch into PyTorch. As the final step of the \n" - f"integration, `functorch.{api}` is deprecated as of PyTorch \n" + f"integration, functorch.{api} is deprecated as of PyTorch \n" f"2.0 and will be deleted in a future version of PyTorch >= 2.3. \n" - f"Please use `{new_api}` instead; see the PyTorch 2.0 release notes \n" - f"and/or the `torch.func` migration guide for more details \n" + f"Please use {new_api} instead; see the PyTorch 2.0 release notes \n" + f"and/or the torch.func migration guide for more details \n" f"https://pytorch.org/docs/main/func.migrating.html" ) if replace_newlines: @@ -37,7 +37,7 @@ def get_warning(api, new_api=None, replace_newlines=False): def warn_deprecated(api, new_api=None): warning = get_warning(api, new_api, replace_newlines=True) - warnings.warn(warning, FutureWarning, stacklevel=2) + warnings.warn(warning, stacklevel=2) def setup_docs(functorch_api, torch_func_api=None, new_api_name=None): diff --git a/torch/_functorch/eager_transforms.py b/torch/_functorch/eager_transforms.py index 80751c9694fd..fff6bd67838f 100644 --- a/torch/_functorch/eager_transforms.py +++ b/torch/_functorch/eager_transforms.py @@ -765,7 +765,7 @@ def compute_jacobian_preallocate_and_copy(): # Dynamo does not support HOP composition if their inner function is # annotated with @functools.wraps(...). We circumvent this issue by applying # wraps only if we're not tracing with dynamo. - if not torch.compiler.is_compiling(): + if not torch._dynamo.is_compiling(): wrapper_fn = wraps(func)(wrapper_fn) return wrapper_fn @@ -1346,7 +1346,7 @@ def push_jvp(basis): # Dynamo does not support HOP composition if their inner function is # annotated with @functools.wraps(...). We circumvent this issue by applying # wraps only if we're not tracing with dynamo. - if not torch.compiler.is_compiling(): + if not torch._dynamo.is_compiling(): wrapper_fn = wraps(func)(wrapper_fn) return wrapper_fn diff --git a/torch/_functorch/pytree_hacks.py b/torch/_functorch/pytree_hacks.py index 96dea7ad1007..8c4b50bc6ad4 100644 --- a/torch/_functorch/pytree_hacks.py +++ b/torch/_functorch/pytree_hacks.py @@ -16,8 +16,7 @@ with warnings.catch_warnings(): warnings.simplefilter("always") warnings.warn( - "`torch._functorch.pytree_hacks` is deprecated and will be removed in a future release. " - "Please `use torch.utils._pytree` instead.", + "torch._functorch.pytree_hacks is deprecated and will be removed in a future release. " + "Please use torch.utils._pytree instead.", DeprecationWarning, - stacklevel=2, ) diff --git a/torch/_higher_order_ops/associative_scan.py b/torch/_higher_order_ops/associative_scan.py index e0e22eb4202f..8b406f39a64d 100644 --- a/torch/_higher_order_ops/associative_scan.py +++ b/torch/_higher_order_ops/associative_scan.py @@ -76,7 +76,7 @@ def add(x: torch.Tensor, y: torch.Tensor): assert callable(combine_fn), "combine_fn must be a callable, but got {combine_fn}" assert isinstance(dim, int), "dim must be an int, but got {type(dim)}" - if not torch.compiler.is_compiling(): + if not torch._dynamo.is_compiling(): with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(): return torch.compile(associative_scan, fullgraph=True)( combine_fn, input, dim diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index c46cad5e41e2..704a38e99e9a 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -2750,8 +2750,7 @@ def make_indexer(self): """A closure containing math to read a given element""" def indexer(index): - assert len(index) == len(self.stride) - assert len(index) == len(self.size) + assert len(index) == len(self.stride) == len(self.size) result = self.offset for idx, stride, sz in zip(index, self.stride, self.size): if sz != 1: diff --git a/torch/_library/abstract_impl.py b/torch/_library/abstract_impl.py index 2946b743ee53..14d6d8c46235 100644 --- a/torch/_library/abstract_impl.py +++ b/torch/_library/abstract_impl.py @@ -1,7 +1,7 @@ import contextlib import functools +import warnings from typing import Callable, Optional -from typing_extensions import deprecated import torch from torch._library.utils import Kernel, RegistrationHandle @@ -124,11 +124,10 @@ def __init__(self, _fake_mode, _op): self._shape_env = _fake_mode.shape_env self._op = _op - @deprecated( - "`create_unbacked_symint` is deprecated, please use `new_dynamic_size` instead", - category=FutureWarning, - ) def create_unbacked_symint(self, *, min=2, max=None) -> torch.SymInt: + warnings.warn( + "create_unbacked_symint is deprecated, please use new_dynamic_size instead" + ) return self.new_dynamic_size(min=min, max=max) def new_dynamic_size(self, *, min=0, max=None) -> torch.SymInt: diff --git a/torch/_prims_common/__init__.py b/torch/_prims_common/__init__.py index 10290535f930..68674af0a285 100644 --- a/torch/_prims_common/__init__.py +++ b/torch/_prims_common/__init__.py @@ -21,7 +21,7 @@ TYPE_CHECKING, Union, ) -from typing_extensions import deprecated, TypeAlias +from typing_extensions import TypeAlias if TYPE_CHECKING: @@ -1789,11 +1789,6 @@ def check_in_bounds_for_storage( # NOTE: This function should ideally be removed, but some Meta internal models # packaged with `torch.package` are using it, so it will have to be removed # at some point in the future when those models no longer use this function. -@deprecated( - "`torch._prims_common.check` is deprecated and will be removed in the future. " - "Please use `torch._check*` functions instead.", - category=FutureWarning, -) def check( b: bool, s: Callable[[], str], exc_type: Type[Exception] = RuntimeError ) -> None: @@ -1806,6 +1801,12 @@ def check( .. note:: This function is planned for removal in the future. Please use `torch._check*` functions instead. """ + warnings.warn( + DeprecationWarning( + "'torch._prims_common.check' will be removed in the future. Please use " + "'torch._check*' functions instead" + ) + ) torch._check_with(exc_type, b, s) diff --git a/torch/_utils.py b/torch/_utils.py index d2bb59239a30..eec2d8231d1a 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -7,7 +7,7 @@ import warnings from collections import defaultdict from typing import Any, Callable, DefaultDict, Generic, List, Optional -from typing_extensions import deprecated, ParamSpec +from typing_extensions import ParamSpec import torch @@ -852,10 +852,6 @@ def classproperty(func): return _ClassPropertyDescriptor(func) -@deprecated( - "`is_compiling` is deprecated. Use `torch.compiler.is_compiling()` instead.", - category=FutureWarning, -) def is_compiling() -> bool: """ Indicates whether we are tracing/compiling with torch.compile() or torch.export(). diff --git a/torch/_vmap_internals.py b/torch/_vmap_internals.py index 465e5dbdca1b..8440abccb239 100644 --- a/torch/_vmap_internals.py +++ b/torch/_vmap_internals.py @@ -1,6 +1,6 @@ import functools +import warnings from typing import Any, Callable, List, Optional, Tuple, Union -from typing_extensions import deprecated import torch from torch import Tensor @@ -190,14 +190,14 @@ def _get_name(func: Callable): # vmap(func)(inputs) wraps all Tensor inputs to be batched in BatchedTensors, # sends those into func, and then unwraps the output BatchedTensors. Operations # on BatchedTensors perform the batched operations that the user is asking for. -@deprecated( - "Please use `torch.vmap` instead of `torch._vmap_internals.vmap`.", - category=FutureWarning, -) def vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Callable: """ Please use torch.vmap instead of this API. """ + warnings.warn( + "Please use torch.vmap instead of torch._vmap_internals.vmap. ", + stacklevel=2, + ) return _vmap(func, in_dims, out_dims) diff --git a/torch/ao/nn/quantizable/modules/activation.py b/torch/ao/nn/quantizable/modules/activation.py index 0023faaaa162..56be29a09d62 100644 --- a/torch/ao/nn/quantizable/modules/activation.py +++ b/torch/ao/nn/quantizable/modules/activation.py @@ -224,6 +224,7 @@ def dequantize(self): return fp + @classmethod def from_observed(cls, other): # The whole flow is float -> observed -> quantized @@ -335,10 +336,7 @@ def _forward_impl(self, if attn_mask is not None: if attn_mask.dtype == torch.uint8: - warnings.warn( - "Byte tensor for `attn_mask` in `nn.MultiheadAttention` is deprecated. " - "Use bool tensor instead.", - ) + warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.") attn_mask = attn_mask.to(torch.bool) assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \ f'Only float and bool types are supported for attn_mask, not {attn_mask.dtype}' @@ -356,10 +354,7 @@ def _forward_impl(self, # convert ByteTensor key_padding_mask to bool if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: - warnings.warn( - "Byte tensor for `key_padding_mask` in `nn.MultiheadAttention` is deprecated. " - "Use bool tensor instead.", - ) + warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.") key_padding_mask = key_padding_mask.to(torch.bool) if self.bias_k is not None and self.bias_v is not None: if static_k is None and static_v is None: diff --git a/torch/ao/nn/quantized/dynamic/modules/rnn.py b/torch/ao/nn/quantized/dynamic/modules/rnn.py index c81771a71889..1cef66060719 100644 --- a/torch/ao/nn/quantized/dynamic/modules/rnn.py +++ b/torch/ao/nn/quantized/dynamic/modules/rnn.py @@ -1,6 +1,5 @@ import numbers import warnings -from typing_extensions import deprecated import torch import torch.nn as nn @@ -17,11 +16,8 @@ def _apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Ten return tensor.index_select(dim, permutation) -@deprecated( - "`apply_permutation` is deprecated, please use `tensor.index_select(dim, permutation)` instead", - category=FutureWarning, -) def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor: + warnings.warn("apply_permutation is deprecated, please use tensor.index_select(dim, permutation) instead") return _apply_permutation(tensor, permutation, dim) diff --git a/torch/ao/quantization/fx/convert.py b/torch/ao/quantization/fx/convert.py index 5aa095b49b65..ef90f8b71ece 100644 --- a/torch/ao/quantization/fx/convert.py +++ b/torch/ao/quantization/fx/convert.py @@ -949,30 +949,24 @@ def convert( if convert_custom_config is None: convert_custom_config = ConvertCustomConfig() - if isinstance(convert_custom_config, dict): + if isinstance(convert_custom_config, Dict): warnings.warn( "Passing a convert_custom_config_dict to convert is deprecated and will not be supported " - "in a future version. Please pass in a ConvertCustomConfig instead.", - FutureWarning, - ) + "in a future version. Please pass in a ConvertCustomConfig instead.") convert_custom_config = ConvertCustomConfig.from_dict(convert_custom_config) - if isinstance(qconfig_mapping, dict): + if isinstance(qconfig_mapping, Dict): warnings.warn( "Passing a QConfig dictionary to convert is deprecated and will not be supported " - "in a future version. Please pass in a QConfigMapping instead.", - FutureWarning, - ) + "in a future version. Please pass in a QConfigMapping instead.") qconfig_mapping = QConfigMapping.from_dict(qconfig_mapping) if qconfig_mapping else None qconfig_mapping = copy.deepcopy(qconfig_mapping) assert qconfig_mapping is None or isinstance(qconfig_mapping, QConfigMapping) - if isinstance(backend_config, dict): + if isinstance(backend_config, Dict): warnings.warn( "Passing a backend_config_dict to prepare is deprecated and will not be supported " - "in a future version. Please pass in a BackendConfig instead.", - FutureWarning, - ) + "in a future version. Please pass in a BackendConfig instead.") backend_config = BackendConfig.from_dict(backend_config) if backend_config is None: diff --git a/torch/ao/quantization/fx/fuse.py b/torch/ao/quantization/fx/fuse.py index 17b934efc6be..91b876997d10 100644 --- a/torch/ao/quantization/fx/fuse.py +++ b/torch/ao/quantization/fx/fuse.py @@ -52,20 +52,16 @@ def fuse( if fuse_custom_config is None: fuse_custom_config = FuseCustomConfig() - if isinstance(fuse_custom_config, dict): + if isinstance(fuse_custom_config, Dict): warnings.warn( "Passing a fuse_custom_config_dict to fuse is deprecated and will not be supported " - "in a future version. Please pass in a FuseCustomConfig instead.", - FutureWarning, - ) + "in a future version. Please pass in a FuseCustomConfig instead.") fuse_custom_config = FuseCustomConfig.from_dict(fuse_custom_config) - if isinstance(backend_config, dict): + if isinstance(backend_config, Dict): warnings.warn( "Passing a backend_config_dict to prepare is deprecated and will not be supported " - "in a future version. Please pass in a BackendConfig instead.", - FutureWarning, - ) + "in a future version. Please pass in a BackendConfig instead.") backend_config = BackendConfig.from_dict(backend_config) named_modules = dict(model.named_modules()) diff --git a/torch/ao/quantization/fx/prepare.py b/torch/ao/quantization/fx/prepare.py index ce99fc757efb..9ca91ecb4930 100644 --- a/torch/ao/quantization/fx/prepare.py +++ b/torch/ao/quantization/fx/prepare.py @@ -1749,36 +1749,28 @@ def prepare( if _equalization_config is None: _equalization_config = QConfigMapping() - if isinstance(qconfig_mapping, dict): + if isinstance(qconfig_mapping, Dict): warnings.warn( "Passing a QConfig dictionary to prepare is deprecated and will not be supported " - "in a future version. Please pass in a QConfigMapping instead.", - FutureWarning, - ) + "in a future version. Please pass in a QConfigMapping instead.") qconfig_mapping = QConfigMapping.from_dict(qconfig_mapping) - if isinstance(_equalization_config, dict): + if isinstance(_equalization_config, Dict): warnings.warn( "Passing a QConfig dictionary to prepare for equalization is deprecated and will not " - "be supported in a future version. Please pass in a QConfigMapping instead.", - FutureWarning, - ) + "be supported in a future version. Please pass in a QConfigMapping instead.") _equalization_config = QConfigMapping.from_dict(_equalization_config) - if isinstance(prepare_custom_config, dict): + if isinstance(prepare_custom_config, Dict): warnings.warn( "Passing a prepare_custom_config_dict to prepare is deprecated and will not be supported " - "in a future version. Please pass in a PrepareCustomConfig instead.", - FutureWarning, - ) + "in a future version. Please pass in a PrepareCustomConfig instead.") prepare_custom_config = PrepareCustomConfig.from_dict(prepare_custom_config) - if isinstance(backend_config, dict): + if isinstance(backend_config, Dict): warnings.warn( "Passing a backend_config_dict to prepare is deprecated and will not be supported " - "in a future version. Please pass in a BackendConfig instead.", - FutureWarning, - ) + "in a future version. Please pass in a BackendConfig instead.") backend_config = BackendConfig.from_dict(backend_config) assert isinstance(qconfig_mapping, QConfigMapping) diff --git a/torch/ao/quantization/qconfig.py b/torch/ao/quantization/qconfig.py index 88e7b47aff2b..dc8353d61729 100644 --- a/torch/ao/quantization/qconfig.py +++ b/torch/ao/quantization/qconfig.py @@ -1,6 +1,5 @@ from collections import namedtuple from typing import Optional, Any, Union, Type -from typing_extensions import deprecated import torch import torch.nn as nn @@ -107,10 +106,6 @@ def __new__(cls, activation, weight): return super().__new__(cls, activation, weight) -@deprecated( - "`QConfigDynamic` is going to be deprecated in PyTorch 1.12, please use `QConfig` instead", - category=FutureWarning, -) class QConfigDynamic(namedtuple('QConfigDynamic', ['activation', 'weight'])): """ Describes how to dynamically quantize a layer or a part of the network by providing @@ -132,6 +127,7 @@ def __new__(cls, activation=torch.nn.Identity, weight=torch.nn.Identity): if isinstance(weight, nn.Module): raise ValueError("QConfigDynamic received observer instance, please pass observer class instead. " + "Use MyObserver.with_args(x=1) to override arguments to constructor if needed") + warnings.warn("QConfigDynamic is going to be deprecated in PyTorch 1.12, please use QConfig instead") return super().__new__(cls, activation, weight) @@ -426,20 +422,16 @@ def get_default_qat_qconfig(backend='x86', version=1): weight=None, ) -@deprecated( - "`torch.ao.quantization.get_default_qconfig_dict` is deprecated and will be removed in " - "a future version. Please use `torch.ao.quantization.get_default_qconfig_mapping` instead.", - category=FutureWarning, -) def get_default_qconfig_dict(backend='x86', version=0): + warnings.warn( + "torch.ao.quantization.get_default_qconfig_dict is deprecated and will be removed in " + "a future version. Please use torch.ao.quantization.get_default_qconfig_mapping instead.") return torch.ao.quantization.get_default_qconfig_mapping(backend, version).to_dict() -@deprecated( - "`torch.ao.quantization.get_default_qat_qconfig_dict` is deprecated and will be removed in " - "a future version. Please use `torch.ao.quantization.get_default_qat_qconfig_mapping` instead.", - category=FutureWarning, -) def get_default_qat_qconfig_dict(backend='x86', version=1): + warnings.warn( + "torch.ao.quantization.get_default_qat_qconfig_dict is deprecated and will be removed in " + "a future version. Please use torch.ao.quantization.get_default_qat_qconfig_mapping instead.") return torch.ao.quantization.get_default_qat_qconfig_mapping(backend, version).to_dict() def _assert_valid_qconfig(qconfig: Optional[QConfig], diff --git a/torch/ao/quantization/quantize_fx.py b/torch/ao/quantization/quantize_fx.py index 453c0511e4d9..c9a3db87552a 100644 --- a/torch/ao/quantization/quantize_fx.py +++ b/torch/ao/quantization/quantize_fx.py @@ -117,12 +117,10 @@ def _prepare_fx( if _equalization_config is None: _equalization_config = QConfigMapping() - if isinstance(prepare_custom_config, dict): + if isinstance(prepare_custom_config, Dict): warnings.warn( "Passing a prepare_custom_config_dict to prepare is deprecated and will not be supported " - "in a future version. Please pass in a PrepareCustomConfig instead.", - FutureWarning, - ) + "in a future version. Please pass in a PrepareCustomConfig instead.") prepare_custom_config = PrepareCustomConfig.from_dict(prepare_custom_config) # swap FloatFunctional with FXFloatFunctional @@ -224,12 +222,10 @@ def fuse_fx( if fuse_custom_config is None: fuse_custom_config = FuseCustomConfig() - if isinstance(fuse_custom_config, dict): + if isinstance(fuse_custom_config, Dict): warnings.warn( "Passing a fuse_custom_config_dict to fuse is deprecated and will not be supported " - "in a future version. Please pass in a FuseCustomConfig instead.", - FutureWarning, - ) + "in a future version. Please pass in a FuseCustomConfig instead.") fuse_custom_config = FuseCustomConfig.from_dict(fuse_custom_config) torch._C._log_api_usage_once("quantization_api.quantize_fx.fuse_fx") @@ -515,12 +511,10 @@ def _convert_fx( if convert_custom_config is None: convert_custom_config = ConvertCustomConfig() - if isinstance(convert_custom_config, dict): + if isinstance(convert_custom_config, Dict): warnings.warn( "Passing a convert_custom_config_dict to convert is deprecated and will not be supported " - "in a future version. Please pass in a ConvertCustomConfig instead.", - FutureWarning, - ) + "in a future version. Please pass in a ConvertCustomConfig instead.") convert_custom_config = ConvertCustomConfig.from_dict(convert_custom_config) _check_is_graph_module(graph_module) diff --git a/torch/autograd/__init__.py b/torch/autograd/__init__.py index 4cefb143dcc0..9b5788aff227 100644 --- a/torch/autograd/__init__.py +++ b/torch/autograd/__init__.py @@ -252,20 +252,17 @@ def backward( ) if grad_variables is not None: - warnings.warn( - "`grad_variables` is deprecated. Use `grad_tensors` instead.", - FutureWarning, - ) + warnings.warn("'grad_variables' is deprecated. Use 'grad_tensors' instead.") if grad_tensors is None: grad_tensors = grad_variables else: raise RuntimeError( - "`grad_tensors` and `grad_variables` (deprecated) " - "arguments both passed to `backward()`. Please only " - "use `grad_tensors`." + "'grad_tensors' and 'grad_variables' (deprecated) " + "arguments both passed to backward(). Please only " + "use 'grad_tensors'." ) if inputs is not None and len(inputs) == 0: - raise RuntimeError("`inputs` argument to `backward()` cannot be empty.") + raise RuntimeError("'inputs' argument to backward() cannot be empty.") tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tuple(tensors) inputs = ( @@ -398,8 +395,7 @@ def grad( warnings.warn( "only_inputs argument is deprecated and is ignored now " "(defaults to True). To accumulate gradient for other " - "parts of the graph, please use torch.autograd.backward.", - FutureWarning, + "parts of the graph, please use torch.autograd.backward." ) grad_outputs_ = _tensor_or_tensors_to_tuple(grad_outputs, len(t_outputs)) diff --git a/torch/autograd/_functions/tensor.py b/torch/autograd/_functions/tensor.py index d2b3149bfc81..f091d38777fc 100644 --- a/torch/autograd/_functions/tensor.py +++ b/torch/autograd/_functions/tensor.py @@ -1,6 +1,6 @@ import operator +import warnings from functools import reduce -from typing_extensions import deprecated import torch import torch._utils @@ -9,12 +9,11 @@ class Type(Function): @staticmethod - @deprecated( - "`torch.autograd._functions.Type` is deprecated as of PyTorch 2.1, " - "please use `torch.tensor.to(dtype=dtype)` instead.", - category=FutureWarning, - ) def forward(ctx, i, dest_type): + warnings.warn( + "torch.autograd._functions.Type is deprecated as of PyTorch 2.1, please use " + "torch.tensor.to(dtype=dtype) instead." + ) ctx.input_type = type(i) ctx.input_device = -1 if not i.is_cuda else i.get_device() return i.type(dest_type) diff --git a/torch/autograd/function.py b/torch/autograd/function.py index 9aca2b2a1b32..9c624ce5d14b 100644 --- a/torch/autograd/function.py +++ b/torch/autograd/function.py @@ -4,7 +4,6 @@ import warnings from collections import OrderedDict from typing import Any, List, Optional, Tuple -from typing_extensions import deprecated import torch import torch._C as _C @@ -180,14 +179,12 @@ def mark_dirty(self, *args: torch.Tensor): """ self.dirty_tensors = args - @deprecated( - "`mark_shared_storage` is deprecated. " - "Tensors with shared storages are automatically tracked. " - "Note that calls to `set_()` are not tracked", - category=FutureWarning, - ) def mark_shared_storage(self, *pairs): - pass + warnings.warn( + "mark_shared_storage is deprecated. " + "Tensors with shared storages are automatically tracked. Note " + "that calls to `set_()` are not tracked" + ) def mark_non_differentiable(self, *args: torch.Tensor): r"""Mark outputs as non-differentiable. @@ -494,8 +491,9 @@ class Function(_SingleLevelFunction): """ def __init__(self, *args, **kwargs): + cls = self.__class__ warnings.warn( - f"{self.__class__} should not be instantiated. Methods on autograd functions" + f"{cls} should not be instantiated. Methods on autograd functions" "are all static, so you should invoke them on the class itself. " "Instantiating an autograd function will raise an " "error in a future version of PyTorch.", diff --git a/torch/autograd/gradcheck.py b/torch/autograd/gradcheck.py index a0d874038761..f2e6aa22fe94 100644 --- a/torch/autograd/gradcheck.py +++ b/torch/autograd/gradcheck.py @@ -3,7 +3,6 @@ import warnings from itertools import product from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union -from typing_extensions import deprecated import torch import torch.testing @@ -307,14 +306,6 @@ def _get_numerical_jacobian( return jacobians -@deprecated( - "`get_numerical_jacobian` was part of PyTorch's private API and not " - "meant to be exposed. We are deprecating it and it will be removed " - "in a future version of PyTorch. If you have a specific use for " - "this or feature request for this to be a stable API, please file " - "us an issue at https://github.com/pytorch/pytorch/issues/new", - category=FutureWarning, -) def get_numerical_jacobian(fn, inputs, target=None, eps=1e-3, grad_out=1.0): """Compute the numerical Jacobian for a given fn and its inputs. @@ -334,6 +325,13 @@ def get_numerical_jacobian(fn, inputs, target=None, eps=1e-3, grad_out=1.0): Note that `target` may not even be part of `input` to `fn`, so please be **very careful** in this to not clone `target`. """ + warnings.warn( + "get_numerical_jacobian was part of PyTorch's private API and not " + "meant to be exposed. We are deprecating it and it will be removed " + "in a future version of PyTorch. If you have a specific use for " + "this or feature request for this to be a stable API, please file " + "us an issue at https://github.com/pytorch/pytorch/issues/new" + ) if ( grad_out != 1.0 ): # grad_out param is only kept for backward compatibility reasons @@ -820,17 +818,16 @@ def _get_analytical_vJu_backward_mode( return reduced_jacobians -@deprecated( - "`get_analytical_jacobian` was part of PyTorch's private API and not " - "meant to be exposed. We are deprecating it and it will be removed " - "in a future version of PyTorch. If you have a specific use for " - "this or feature request for this to be a stable API, please file " - "us an issue at https://github.com/pytorch/pytorch/issues/new", - category=FutureWarning, -) def get_analytical_jacobian(inputs, output, nondet_tol=0.0, grad_out=1.0): # Replicates the behavior of the old get_analytical_jacobian before the refactor # This shares much of its code with _check_analytical_jacobian_attributes + warnings.warn( + "get_analytical_jacobian was part of PyTorch's private API and not " + "meant to be exposed. We are deprecating it and it will be removed " + "in a future version of PyTorch. If you have a specific use for " + "this or feature request for this to be a stable API, please file " + "us an issue at https://github.com/pytorch/pytorch/issues/new" + ) if ( grad_out != 1.0 ): # grad_out param is only kept for backward compatibility reasons diff --git a/torch/autograd/profiler.py b/torch/autograd/profiler.py index 38cc0e3a3b35..5da75b608a82 100644 --- a/torch/autograd/profiler.py +++ b/torch/autograd/profiler.py @@ -213,9 +213,7 @@ def __init__( self.use_cuda = use_cuda if self.use_cuda: warn( - "The attribute `use_cuda` will be deprecated soon, " - "please use ``use_device = 'cuda'`` instead.", - FutureWarning, + "The attribute `use_cuda` will be deprecated soon, please use ``use_device = 'cuda'`` instead." ) self.use_device: Optional[str] = "cuda" else: diff --git a/torch/autograd/profiler_legacy.py b/torch/autograd/profiler_legacy.py index e8b2b62019bc..f72d366a3677 100644 --- a/torch/autograd/profiler_legacy.py +++ b/torch/autograd/profiler_legacy.py @@ -1,6 +1,5 @@ import itertools -import warnings -from typing_extensions import deprecated +from warnings import warn import torch import torch.cuda @@ -24,11 +23,6 @@ __all__ = ["profile"] -@deprecated( - "`torch.autograd.profiler_legacy.profile` is deprecated and will be removed in a future release. " - "Please use `torch.profiler` instead.", - category=None, # TODO: change to `FutureWarning` -) class profile: """DEPRECATED: use torch.profiler instead.""" @@ -57,7 +51,7 @@ def __init__( self.with_modules = with_modules if self.use_cuda and not torch.cuda.is_available(): - warnings.warn("CUDA is not available, disabling CUDA profiling") + warn("CUDA is not available, disabling CUDA profiling") self.use_cuda = False if self.use_cuda: diff --git a/torch/autograd/profiler_util.py b/torch/autograd/profiler_util.py index 23243733aaa8..4833f989b82a 100644 --- a/torch/autograd/profiler_util.py +++ b/torch/autograd/profiler_util.py @@ -6,7 +6,6 @@ from operator import attrgetter from typing import Any, Dict, List, Optional, Tuple -from typing_extensions import deprecated import torch from torch.autograd import DeviceType @@ -416,10 +415,6 @@ def device_time(self): return 0.0 if self.count == 0 else 1.0 * self.device_time_total / self.count # type: ignore[attr-defined] @property - @deprecated( - "`cuda_time` is deprecated, please use `device_time` instead.", - category=FutureWarning, - ) def cuda_time(self): # To be deprecated return self.device_time @@ -543,12 +538,8 @@ def self_device_memory_usage(self): ) @property - @deprecated( - "`self_cuda_memory_usage` is deprecated. Use `self_device_memory_usage` instead.", - category=FutureWarning, - ) def self_cuda_memory_usage(self): # To be deprecated - return self.self_device_memory_usage + self.self_device_memory_usage @property def cpu_time_total(self): @@ -583,12 +574,8 @@ def device_time_total(self): return self.time_range.elapsed_us() @property - @deprecated( - "`cuda_time_total` is deprecated. Use `device_time_total` instead.", - category=FutureWarning, - ) def cuda_time_total(self): # To be deprecated - return self.device_time_total + self.device_time_total @property def self_device_time_total(self): @@ -603,12 +590,8 @@ def self_device_time_total(self): return self.device_time_total @property - @deprecated( - "`self_cuda_time_total` is deprecated. Use `self_device_time_total` instead.", - category=FutureWarning, - ) def self_cuda_time_total(self): # To be deprecated - return self.self_device_time_total + self.self_device_time_total @property def key(self): diff --git a/torch/backends/cuda/__init__.py b/torch/backends/cuda/__init__.py index c35a962ba693..f1b68a446225 100644 --- a/torch/backends/cuda/__init__.py +++ b/torch/backends/cuda/__init__.py @@ -1,7 +1,7 @@ import contextlib +import warnings from typing import Union -from typing_extensions import deprecated import torch @@ -377,15 +377,6 @@ def enable_cudnn_sdp(enabled: bool): @contextlib.contextmanager -@deprecated( - ( - "`torch.backends.cuda.sdp_kernel()` is deprecated. " - "In the future, this context manager will be removed. " - "Please see `torch.nn.attention.sdpa_kernel()` for the new context manager, " - "with updated signature." - ), - category=FutureWarning, -) def sdp_kernel( enable_flash: bool = True, enable_math: bool = True, @@ -398,6 +389,15 @@ def sdp_kernel( This context manager can be used to temporarily enable or disable any of the three backends for scaled dot product attention. Upon exiting the context manager, the previous state of the flags will be restored. """ + warnings.warn( + ( + "torch.backends.cuda.sdp_kernel() " + "is deprecated. In the future, this context manager will be removed. " + "Please see, torch.nn.attention.sdpa_kernel() for the new context manager, with updated " + "signature." + ), + FutureWarning, + ) from torch.nn.attention import sdpa_kernel backend_list = [] diff --git a/torch/cpu/amp/autocast_mode.py b/torch/cpu/amp/autocast_mode.py index b545e91dd6f4..3f0a574f7d38 100644 --- a/torch/cpu/amp/autocast_mode.py +++ b/torch/cpu/amp/autocast_mode.py @@ -1,5 +1,5 @@ +import warnings from typing import Any -from typing_extensions import deprecated import torch @@ -12,11 +12,6 @@ class autocast(torch.amp.autocast_mode.autocast): ``torch.cpu.amp.autocast(args...)`` is deprecated. Please use ``torch.amp.autocast("cpu", args...)`` instead. """ - @deprecated( - "`torch.cpu.amp.autocast(args...)` is deprecated. " - "Please use `torch.amp.autocast('cpu', args...)` instead.", - category=FutureWarning, - ) def __init__( self, enabled: bool = True, @@ -28,6 +23,10 @@ def __init__( self.device = "cpu" self.fast_dtype = dtype return + warnings.warn( + "torch.cpu.amp.autocast(args...) is deprecated. Please use torch.amp.autocast('cpu', args...) instead.", + DeprecationWarning, + ) super().__init__( "cpu", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled ) diff --git a/torch/cpu/amp/grad_scaler.py b/torch/cpu/amp/grad_scaler.py index 72b893a02a49..2c93e0100f16 100644 --- a/torch/cpu/amp/grad_scaler.py +++ b/torch/cpu/amp/grad_scaler.py @@ -1,4 +1,4 @@ -from typing_extensions import deprecated +import warnings import torch @@ -11,11 +11,6 @@ class GradScaler(torch.amp.GradScaler): ``torch.cpu.amp.GradScaler(args...)`` is deprecated. Please use ``torch.amp.GradScaler("cpu", args...)`` instead. """ - @deprecated( - "`torch.cpu.amp.GradScaler(args...)` is deprecated. " - "Please use `torch.amp.GradScaler('cpu', args...)` instead.", - category=FutureWarning, - ) def __init__( self, init_scale: float = 2.0**16, @@ -24,6 +19,9 @@ def __init__( growth_interval: int = 2000, enabled: bool = True, ) -> None: + warnings.warn( + "torch.cpu.amp.GradScaler(args...) is deprecated. Please use torch.amp.GradScaler('cpu', args...) instead." + ) super().__init__( "cpu", init_scale=init_scale, diff --git a/torch/cuda/_memory_viz.py b/torch/cuda/_memory_viz.py index a44854d1524c..587d7e9c7c5e 100644 --- a/torch/cuda/_memory_viz.py +++ b/torch/cuda/_memory_viz.py @@ -145,8 +145,8 @@ def _seg_info(seg): before_segs = {_seg_key(seg) for seg in before} after_segs = {_seg_key(seg) for seg in after} - print(f'only_before = {[a for a, _ in (before_segs - after_segs)]}') - print(f'only_after = {[a for a, _ in (after_segs - before_segs)]}') + print(f'only_before = {[a for a,_ in (before_segs - after_segs)]}') + print(f'only_after = {[a for a,_ in (after_segs - before_segs)]}') for seg in before: if _seg_key(seg) not in after_segs: @@ -382,7 +382,7 @@ def find_segment(addr): def _format_viz(data, viz_kind, device): if device is not None: - warnings.warn('device argument is deprecated, plots now contain all device', FutureWarning) + warnings.warn('device argument is deprecated, plots now contain all device') buffer = pickle.dumps(data) buffer += b'\x00' * (3 - len(buffer) % 3) # Encode the buffer with base64 diff --git a/torch/cuda/amp/autocast_mode.py b/torch/cuda/amp/autocast_mode.py index e50206c70577..e953d20cb2a5 100644 --- a/torch/cuda/amp/autocast_mode.py +++ b/torch/cuda/amp/autocast_mode.py @@ -1,6 +1,6 @@ import functools +import warnings from typing import Any -from typing_extensions import deprecated import torch @@ -13,11 +13,6 @@ class autocast(torch.amp.autocast_mode.autocast): ``torch.cuda.amp.autocast(args...)`` is deprecated. Please use ``torch.amp.autocast("cuda", args...)`` instead. """ - @deprecated( - "`torch.cuda.amp.autocast(args...)` is deprecated. " - "Please use `torch.amp.autocast('cuda', args...)` instead.", - category=FutureWarning, - ) def __init__( self, enabled: bool = True, @@ -29,6 +24,10 @@ def __init__( self.device = "cuda" self.fast_dtype = dtype return + warnings.warn( + "torch.cuda.amp.autocast(args...) is deprecated. Please use torch.amp.autocast('cuda', args...) instead.", + DeprecationWarning, + ) super().__init__( "cuda", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled ) @@ -50,29 +49,25 @@ def __call__(self, func): return super().__call__(func) -@deprecated( - "`torch.cuda.amp.custom_fwd(args...)` is deprecated. " - "Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.", - category=FutureWarning, -) def custom_fwd(fwd=None, *, cast_inputs=None): """ ``torch.cuda.amp.custom_fwd(args...)`` is deprecated. Please use ``torch.amp.custom_fwd(args..., device_type='cuda')`` instead. """ + warnings.warn( + "torch.cuda.amp.custom_fwd(args...) is deprecated. Please use torch.amp.custom_fwd(args..., device_type='cuda') instead." + ) return functools.partial(torch.amp.custom_fwd, device_type="cuda")( fwd=fwd, cast_inputs=cast_inputs ) -@deprecated( - "`torch.cuda.amp.custom_bwd(args...)` is deprecated. " - "Please use `torch.amp.custom_bwd(args..., device_type='cuda')` instead.", - category=FutureWarning, -) def custom_bwd(bwd): """ ``torch.cuda.amp.custom_bwd(args...)`` is deprecated. Please use ``torch.amp.custom_bwd(args..., device_type='cuda')`` instead. """ + warnings.warn( + "torch.cuda.amp.custom_bwd(args...) is deprecated. Please use torch.amp.custom_bwd(args..., device_type='cuda') instead." + ) return functools.partial(torch.amp.custom_bwd, device_type="cuda")(bwd) diff --git a/torch/cuda/amp/grad_scaler.py b/torch/cuda/amp/grad_scaler.py index 367f21594f1c..8263fcdb480d 100644 --- a/torch/cuda/amp/grad_scaler.py +++ b/torch/cuda/amp/grad_scaler.py @@ -1,4 +1,4 @@ -from typing_extensions import deprecated +import warnings import torch @@ -11,11 +11,6 @@ class GradScaler(torch.amp.GradScaler): ``torch.cuda.amp.GradScaler(args...)`` is deprecated. Please use ``torch.amp.GradScaler("cuda", args...)`` instead. """ - @deprecated( - "`torch.cuda.amp.GradScaler(args...)` is deprecated. " - "Please use `torch.amp.GradScaler('cuda', args...)` instead.", - category=FutureWarning, - ) def __init__( self, init_scale: float = 2.0**16, @@ -24,6 +19,9 @@ def __init__( growth_interval: int = 2000, enabled: bool = True, ) -> None: + warnings.warn( + "torch.cuda.amp.GradScaler(args...) is deprecated. Please use torch.amp.GradScaler('cuda', args...) instead." + ) super().__init__( "cuda", init_scale=init_scale, diff --git a/torch/cuda/memory.py b/torch/cuda/memory.py index 0f12395ac778..a593a3810834 100644 --- a/torch/cuda/memory.py +++ b/torch/cuda/memory.py @@ -9,7 +9,6 @@ from inspect import signature from typing import Any, Dict, Optional, Tuple, Union -from typing_extensions import deprecated import torch from torch import _C @@ -447,21 +446,21 @@ def max_memory_reserved(device: Union[Device, int] = None) -> int: return memory_stats(device=device).get("reserved_bytes.all.peak", 0) -@deprecated( - "`torch.cuda.memory_cached` has been renamed to `torch.cuda.memory_reserved`", - category=FutureWarning, -) def memory_cached(device: Union[Device, int] = None) -> int: r"""Deprecated; see :func:`~torch.cuda.memory_reserved`.""" + warnings.warn( + "torch.cuda.memory_cached has been renamed to torch.cuda.memory_reserved", + FutureWarning, + ) return memory_reserved(device=device) -@deprecated( - "`torch.cuda.max_memory_cached` has been renamed to `torch.cuda.max_memory_reserved`", - category=FutureWarning, -) def max_memory_cached(device: Union[Device, int] = None) -> int: r"""Deprecated; see :func:`~torch.cuda.max_memory_reserved`.""" + warnings.warn( + "torch.cuda.max_memory_cached has been renamed to torch.cuda.max_memory_reserved", + FutureWarning, + ) return max_memory_reserved(device=device) diff --git a/torch/cuda/nccl.py b/torch/cuda/nccl.py index 4170e20b5318..05751ab5f87b 100644 --- a/torch/cuda/nccl.py +++ b/torch/cuda/nccl.py @@ -89,9 +89,8 @@ def reduce( ) else: warnings.warn( - "`nccl.reduce` with an output tensor list is deprecated. " - "Please specify a single output tensor with argument 'output' instead instead.", - FutureWarning, + "nccl.reduce with an output tensor list is deprecated. " + "Please specify a single output tensor with argument 'output' instead instead." ) _output = outputs[root] elif not isinstance(output, torch.Tensor) and isinstance( @@ -100,8 +99,7 @@ def reduce( # User called old API with positional arguments of list of output tensors. warnings.warn( "nccl.reduce with an output tensor list is deprecated. " - "Please specify a single output tensor.", - FutureWarning, + "Please specify a single output tensor." ) _output = output[root] else: diff --git a/torch/distributed/_composable/fully_shard.py b/torch/distributed/_composable/fully_shard.py index 950a034071a4..37e3d1544cd1 100644 --- a/torch/distributed/_composable/fully_shard.py +++ b/torch/distributed/_composable/fully_shard.py @@ -1,5 +1,5 @@ +import warnings from typing import Callable, Iterable, Optional, Union -from typing_extensions import deprecated import torch import torch.distributed as dist @@ -38,13 +38,6 @@ @contract(state_cls=_FSDPState) -@deprecated( - "`torch.distributed._composable.fully_shard` is being deprecated. " - "You can continue to use the wrapper based FSDP. " - "See usage in: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/fully_sharded_data_parallel.py. " - "`torch.distributed._composable.fully_shard` will be removed after PyTorch 2.5.", - category=FutureWarning, -) def fully_shard( module: nn.Module, *, @@ -62,7 +55,16 @@ def fully_shard( Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]] ] = None, ) -> nn.Module: - """Applies ``FullyShardedDataParallel`` (FSDP) semantics to ``module``.""" + """ + Applies ``FullyShardedDataParallel` (FSDP) semantics to ``module``. + """ + warnings.warn( + "``torch.distributed._composable.fully_shard`` is being deprecated." + "You can contintue to use the wrapper based FSDP." + "See usage in: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/fully_sharded_data_parallel.py." + "``torch.distributed._composable.fully_shard`` will be removed after PyTorch 2.5." + ) + torch._C._log_api_usage_once("torch.distributed.fully_shard") # Enforce the new auto wrap policy if policy is not None and not isinstance(policy, _Policy): diff --git a/torch/distributed/_functional_collectives.py b/torch/distributed/_functional_collectives.py index 0e58f0a2b3a1..8d598713cf50 100644 --- a/torch/distributed/_functional_collectives.py +++ b/torch/distributed/_functional_collectives.py @@ -766,8 +766,7 @@ def _resolve_group_name(group: RANK_TYPES, tag: str = "") -> str: warnings.warn( "The combination of ranks + tag as process group " "identifier has been deprecated. Please switch to " - "using ProcessGroup, DeviceMesh, or group name instead.", - FutureWarning, + "using ProcessGroup, DeviceMesh, or group name instead." ) return c10d._resolve_group_name_by_ranks_and_tag(cast(List[int], group), tag) else: diff --git a/torch/distributed/_shard/checkpoint/__init__.py b/torch/distributed/_shard/checkpoint/__init__.py index 161a43f276d6..166c6f9254cf 100644 --- a/torch/distributed/_shard/checkpoint/__init__.py +++ b/torch/distributed/_shard/checkpoint/__init__.py @@ -5,15 +5,8 @@ import warnings from torch.distributed.checkpoint import * # noqa: F403 - - -with warnings.catch_warnings(): - warnings.simplefilter("always") - warnings.warn( - "`torch.distributed._shard.checkpoint` will be deprecated, " - "use `torch.distributed.checkpoint` instead", - DeprecationWarning, - stacklevel=2, - ) - +warnings.warn( + "torch.distributed._shard.checkpoint will be deprecated, use torch.distributed.checkpoint instead", + DeprecationWarning +) sys.modules['torch.distributed._shard.checkpoint'] = torch.distributed.checkpoint diff --git a/torch/distributed/_shard/sharded_tensor/api.py b/torch/distributed/_shard/sharded_tensor/api.py index 65da388d0f4f..a5e961e4bb78 100644 --- a/torch/distributed/_shard/sharded_tensor/api.py +++ b/torch/distributed/_shard/sharded_tensor/api.py @@ -10,7 +10,6 @@ cast, TYPE_CHECKING, ) -from typing_extensions import deprecated import copy import warnings from functools import reduce @@ -397,7 +396,7 @@ def shard_size(shard_md): return reduce(operator.mul, shard_md.shard_sizes) # type: ignore[attr-defined] if enforce_dtype: - warnings.warn("`enforce_dtype` is deprecated. Please use `dtype` instead.", FutureWarning) + warnings.warn("enforce_dtype is deprecated. Please use dtype instead.") rank = dist.get_rank(self._process_group) full_size = self.metadata().size @@ -738,7 +737,6 @@ def _init_from_local_shards( return sharded_tensor @classmethod - @deprecated(DEPRECATE_MSG, category=FutureWarning) def _init_from_local_tensor( cls, local_tensor: torch.Tensor, @@ -803,6 +801,8 @@ def _init_from_local_tensor( We fully rely on the user to ensure local tensor is sharded based on the sharding spec. """ + warnings.warn(DEPRECATE_MSG) + if not local_tensor.is_contiguous(): raise ValueError('local_tensor is not a contiguous Tensor.') @@ -980,7 +980,6 @@ def sharding_spec(self) -> shard_spec.ShardingSpec: """ return self._sharding_spec - @deprecated(DEPRECATE_MSG, category=FutureWarning) def reshard(self, resharding_spec: shard_spec.ShardingSpec) -> ShardedTensor: """ Reshard a sharded tensor given the ``resharding_spec``. For now, we only support @@ -1051,6 +1050,8 @@ def reshard(self, resharding_spec: shard_spec.ShardingSpec) -> ShardedTensor: tensor([[3], [3], [5], [5], [7], [7], [9], [9]]) # Rank 2 tensor([[4], [4], [6], [6], [8], [8], [10], [10]]) # Rank 3 """ + warnings.warn(DEPRECATE_MSG) + if ( not isinstance(resharding_spec, shard_spec.ChunkShardingSpec) or not isinstance(self._sharding_spec, shard_spec.ChunkShardingSpec) @@ -1095,7 +1096,6 @@ def local_tensor(self) -> torch.Tensor: return self.local_shards()[0].tensor @classmethod - @deprecated(DEPRECATE_MSG, category=FutureWarning) def __torch_function__(cls, func, types, args=(), kwargs=None): def dispatch(st: ShardedTensor, func: Callable): # Dispatch to custom user provided op first if it exists. @@ -1120,6 +1120,7 @@ def dispatch(st: ShardedTensor, func: Callable): f"torch function '{func.__name__}', with args: {args} and " f"kwargs: {kwargs} not supported for ShardedTensor!") + warnings.warn(DEPRECATE_MSG) # Find ShardedTensor instance to get process_group and sharding_spec. st_instance = None diff --git a/torch/distributed/_sharded_tensor/__init__.py b/torch/distributed/_sharded_tensor/__init__.py index 6c6694cfb081..9e6b1662589c 100644 --- a/torch/distributed/_sharded_tensor/__init__.py +++ b/torch/distributed/_sharded_tensor/__init__.py @@ -5,14 +5,8 @@ import warnings from torch.distributed._shard.sharded_tensor import * # noqa: F403 - -with warnings.catch_warnings(): - warnings.simplefilter("always") - warnings.warn( - "`torch.distributed._sharded_tensor` will be deprecated, " - "use `torch.distributed._shard.sharded_tensor` instead", - DeprecationWarning, - stacklevel=2, - ) - +warnings.warn( + "torch.distributed._sharded_tensor will be deprecated, use torch.distributed._shard.sharded_tensor instead", + DeprecationWarning +) sys.modules['torch.distributed._sharded_tensor'] = torch.distributed._shard.sharded_tensor diff --git a/torch/distributed/_sharding_spec/__init__.py b/torch/distributed/_sharding_spec/__init__.py index 21c56d5dc849..f3060005dbdd 100644 --- a/torch/distributed/_sharding_spec/__init__.py +++ b/torch/distributed/_sharding_spec/__init__.py @@ -5,15 +5,10 @@ import warnings from torch.distributed._shard.sharding_spec import * # noqa: F403 - -with warnings.catch_warnings(): - warnings.simplefilter("always") - warnings.warn( - "`torch.distributed._sharding_spec` will be deprecated, " - "use `torch.distributed._shard.sharding_spec` instead", - DeprecationWarning, - stacklevel=2, - ) +warnings.warn( + "torch.distributed._sharding_spec will be deprecated, use torch.distributed._shard.sharding_spec instead", + DeprecationWarning +) import torch.distributed._shard.sharding_spec as _sharding_spec sys.modules['torch.distributed._sharding_spec'] = _sharding_spec diff --git a/torch/distributed/_tensor/api.py b/torch/distributed/_tensor/api.py index 287d07a7c868..ba25c628a83e 100644 --- a/torch/distributed/_tensor/api.py +++ b/torch/distributed/_tensor/api.py @@ -746,7 +746,6 @@ def replicate_module_params_buffers(m: nn.Module, mesh: DeviceMesh) -> None: warnings.warn( "Deprecating input_fn that takes two arguments (inputs, device_mesh), " "please use input_fn that takes in (module, inputs, device_mesh) instead!", - FutureWarning, ) module.register_forward_pre_hook(lambda _, inputs: input_fn(inputs, device_mesh)) # type: ignore[call-arg] elif num_args == 3: @@ -766,7 +765,6 @@ def replicate_module_params_buffers(m: nn.Module, mesh: DeviceMesh) -> None: warnings.warn( "Deprecating output_fn that takes two arguments (inputs, device_mesh), " "please use output_fn that takes in (module, inputs, device_mesh) instead!", - FutureWarning, ) module.register_forward_hook( lambda mod, inputs, outputs: output_fn(outputs, device_mesh) # type: ignore[call-arg] diff --git a/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py b/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py index 0d55b37c0044..364648f1a7f7 100644 --- a/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py +++ b/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py @@ -233,7 +233,6 @@ def checkpoint_wrapper( f"Please specify {CheckpointImpl.NO_REENTRANT} as " f"{CheckpointImpl.REENTRANT} will soon be removed as " "the default and eventually deprecated.", - FutureWarning, stacklevel=1, ) return CheckpointWrapper( diff --git a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py index 791061e34f90..bff55327e847 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py @@ -85,7 +85,7 @@ def decompress(fut): decompressed_tensor.copy_(value) return decompressed_tensor - if torch.compiler.is_compiling(): + if torch._utils.is_compiling(): grad = dist._functional_collectives.all_reduce( compressed_tensor, "sum", group_to_use ) @@ -134,7 +134,7 @@ def decompress(fut): decompressed_tensor.copy_(value) return decompressed_tensor - if torch.compiler.is_compiling(): + if torch._utils.is_compiling(): grad = dist._functional_collectives.all_reduce( compressed_tensor, "sum", group_to_use ) diff --git a/torch/distributed/checkpoint/state_dict_loader.py b/torch/distributed/checkpoint/state_dict_loader.py index 6c1546e1cc0f..b7e1337e6c4f 100644 --- a/torch/distributed/checkpoint/state_dict_loader.py +++ b/torch/distributed/checkpoint/state_dict_loader.py @@ -1,7 +1,6 @@ import os import warnings from typing import Any, cast, Dict, Optional, Set, Union -from typing_extensions import deprecated import torch import torch.distributed as dist @@ -18,11 +17,6 @@ __all__ = ["load_state_dict", "load"] -@deprecated( - "`load_state_dict` is deprecated and will be removed in future versions. " - "Please use `load` instead.", - category=FutureWarning, -) def load_state_dict( state_dict: Dict[str, Any], storage_reader: StorageReader, @@ -32,6 +26,10 @@ def load_state_dict( planner: Optional[LoadPlanner] = None, ) -> None: """This method is deprecated. Please switch to 'load'.""" + warnings.warn( + "'load_state_dict' is deprecated and will be removed in future versions. " + "Please use 'load' instead." + ) storage_reader.reset() with _profile(): # TODO: test returning `load` here instead. diff --git a/torch/distributed/checkpoint/state_dict_saver.py b/torch/distributed/checkpoint/state_dict_saver.py index 451603288d12..0313f1c2ab61 100644 --- a/torch/distributed/checkpoint/state_dict_saver.py +++ b/torch/distributed/checkpoint/state_dict_saver.py @@ -3,7 +3,6 @@ import warnings from concurrent.futures import Future, ThreadPoolExecutor from typing import cast, Optional, Union -from typing_extensions import deprecated import torch import torch.distributed as dist @@ -25,11 +24,6 @@ __all__ = ["save_state_dict", "save", "async_save"] -@deprecated( - "`save_state_dict` is deprecated and will be removed in future versions." - "Please use `save` instead.", - category=FutureWarning, -) def save_state_dict( state_dict: STATE_DICT_TYPE, storage_writer: StorageWriter, @@ -39,6 +33,11 @@ def save_state_dict( planner: Optional[SavePlanner] = None, ) -> Metadata: """This method is deprecated. Please switch to 'save'.""" + warnings.warn( + "'save_state_dict' is deprecated and will be removed in future versions." + "Please use 'save' instead." + ) + storage_writer.reset() # TODO: test returning `save` here instead. diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index decf309cfec1..6fc505c78110 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -14,7 +14,6 @@ from collections import namedtuple from datetime import timedelta from typing import Any, Callable, Dict, Optional, Tuple, Union, List, TYPE_CHECKING -from typing_extensions import deprecated import torch from torch._C._distributed_c10d import ( @@ -365,12 +364,11 @@ def __init__(self): setattr(self, k, v) self.__members__ = ReduceOp.RedOpType.__members__ - @deprecated( - "`torch.distributed.reduce_op` is deprecated, " - "please use `torch.distributed.ReduceOp` instead", - category=FutureWarning, - ) def __getattribute__(self, key): + warnings.warn( + "torch.distributed.reduce_op is deprecated, please use " + "torch.distributed.ReduceOp instead" + ) return object.__getattribute__(self, key) @@ -677,8 +675,7 @@ def _get_pg_default_device(group: Optional[ProcessGroup] = None) -> torch.device warnings.warn( f"You are using a Backend {type(group)} as a ProcessGroup. " "This usage is deprecated since PyTorch 2.0. Please use a public API " - "of PyTorch Distributed instead.", - FutureWarning, + "of PyTorch Distributed instead." ) # Most users create Gloo with private API for object collectives _world.pg_default_device[group] = torch.device("cpu") @@ -832,15 +829,13 @@ def get_global_rank(group: ProcessGroup, group_rank: int) -> int: return rank raise ValueError(f"Group rank {group_rank} is not part of group {group}") - # TODO: remove this once the ecosystem moves away from it. -@deprecated( - "`torch.distributed.distributed_c10d._get_global_rank` is deprecated, " - "please use `torch.distributed.distributed_c10d.get_global_rank` instead", - category=FutureWarning, -) def _get_global_rank(group, rank) -> int: """Use get_global_rank as this method is deprecated.""" + warnings.warn( + "torch.distributed.distributed_c10d._get_global_rank is deprecated " + "please use torch.distributed.distributed_c10d.get_global_rank instead" + ) return get_global_rank(group, rank) @@ -2291,12 +2286,6 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False): work.wait() @_exception_logger -@deprecated( - "`torch.distributed.all_reduce_coalesced` will be deprecated. If you must " - "use it, please revisit our documentation later at " - "https://pytorch.org/docs/main/distributed.html#collective-functions", - category=FutureWarning, -) def all_reduce_coalesced(tensors, op=ReduceOp.SUM, group=None, async_op=False): """ WARNING: at this time individual shape checking is not implemented across nodes. @@ -2331,6 +2320,11 @@ def all_reduce_coalesced(tensors, op=ReduceOp.SUM, group=None, async_op=False): None, if not async_op or if not part of the group. """ + warnings.warn( + "torch.distributed.all_reduce_coalesced will be deprecated. If you must " + "use it, please revisit our documentation later at " + "https://pytorch.org/docs/main/distributed.html#collective-functions" + ) if isinstance(tensors, torch.Tensor): tensors = [tensors] _check_tensor_list(tensors, "tensor") @@ -3204,11 +3198,6 @@ def all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=Fal @_exception_logger -@deprecated( - "`torch.distributed._all_gather_base` is a private function and will be deprecated. " - "Please use `torch.distributed.all_gather_into_tensor` instead.", - category=FutureWarning, -) def _all_gather_base(output_tensor, input_tensor, group=None, async_op=False): """ Single tensor all gather. Gathers a single tensor from all ranks, and puts them in a single output tensor. @@ -3230,16 +3219,15 @@ def _all_gather_base(output_tensor, input_tensor, group=None, async_op=False): `all_gather_into_tensor` instead. """ + warnings.warn( + "torch.distributed._all_gather_base is a private function and will be " + "deprecated. Please use torch.distributed.all_gather_into_tensor " + "instead." + ) return all_gather_into_tensor(output_tensor, input_tensor, group, async_op) @_exception_logger -@deprecated( - "`torch.distributed.all_gather_coalesced` will be deprecated. If you must use it, " - "please revisit our documentation later at " - "https://pytorch.org/docs/main/distributed.html#collective-functions", - category=FutureWarning, -) def all_gather_coalesced( output_tensor_lists, input_tensor_list, group=None, async_op=False ): @@ -3286,6 +3274,11 @@ def all_gather_coalesced( performance improvements but users of this function should take extra care to ensure that each node passes in tensors whose shapes match across nodes. """ + warnings.warn( + "torch.distributed.all_gather_coalesced will be deprecated. If you must " + "use it, please revisit our documentation later at " + "https://pytorch.org/docs/main/distributed.html#collective-functions" + ) # We only check basic compatibility with C++ params here, C++ code will # do shape and type checking. if _rank_not_in_group(group): @@ -3615,11 +3608,6 @@ def reduce_scatter_tensor(output, input, op=ReduceOp.SUM, group=None, async_op=F work.wait() -@deprecated( - "`torch.distributed._reduce_scatter_base` is a private function and will be deprecated. " - "Please use `torch.distributed.reduce_scatter_tensor` instead.", - category=FutureWarning, -) def _reduce_scatter_base(output, input, op=ReduceOp.SUM, group=None, async_op=False): """ Reduces, then scatters a flattened tensor to all processes in a group. @@ -3640,6 +3628,11 @@ def _reduce_scatter_base(output, input, op=ReduceOp.SUM, group=None, async_op=Fa `reduce_scatter_tensor` instead. """ + warnings.warn( + "torch.distributed._reduce_scatter_base is a private function and will " + "be deprecated. Please use torch.distributed.reduce_scatter_tensor " + "instead." + ) return reduce_scatter_tensor(output, input, op, group, async_op) diff --git a/torch/distributed/elastic/metrics/api.py b/torch/distributed/elastic/metrics/api.py index 11a3930acf70..1499943c78d2 100644 --- a/torch/distributed/elastic/metrics/api.py +++ b/torch/distributed/elastic/metrics/api.py @@ -8,10 +8,10 @@ import abc import time +import warnings from collections import namedtuple from functools import wraps from typing import Dict, Optional -from typing_extensions import deprecated __all__ = ['MetricsConfig', 'MetricHandler', 'ConsoleMetricHandler', 'NullMetricHandler', 'MetricStream', 'configure', 'getStream', 'prof', 'profile', 'put_metric', 'publish_metric', 'get_elapsed_time_ms', @@ -137,7 +137,6 @@ def wrapper(*args, **kwargs): return wrap -@deprecated("Deprecated, use `@prof` instead", category=FutureWarning) def profile(group=None): """ @profile decorator adds latency and success/failure metrics to any given function. @@ -149,6 +148,8 @@ def profile(group=None): @metrics.profile("my_metric_group") def some_function(): """ + warnings.warn("Deprecated, use @prof instead", DeprecationWarning) + def wrap(func): @wraps(func) def wrapper(*args, **kwargs): @@ -186,11 +187,10 @@ def put_metric(metric_name: str, metric_value: int, metric_group: str = "torchel getStream(metric_group).add_value(metric_name, metric_value) -@deprecated( - "Deprecated, use `put_metric(metric_group)(metric_name, metric_value)` instead", - category=FutureWarning, -) def publish_metric(metric_group: str, metric_name: str, metric_value: int): + warnings.warn( + "Deprecated, use put_metric(metric_group)(metric_name, metric_value) instead" + ) metric_stream = getStream(metric_group) metric_stream.add_value(metric_name, metric_value) diff --git a/torch/distributed/fsdp/_init_utils.py b/torch/distributed/fsdp/_init_utils.py index 2364b1871206..ddd48c48a0ac 100644 --- a/torch/distributed/fsdp/_init_utils.py +++ b/torch/distributed/fsdp/_init_utils.py @@ -446,8 +446,7 @@ def _init_core_state( elif sharding_strategy == ShardingStrategy.NO_SHARD: warnings.warn( "The `NO_SHARD` sharding strategy is deprecated. If having issues, " - "please use `DistributedDataParallel` instead.", - FutureWarning, + "please use DistributedDataParallel instead.", # Level 1 is here, level 2 is from `FullyShardedDataParallel`, and # level 3 is from the true caller stacklevel=3, diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py index fdb72ce0b219..766fb76bbd09 100644 --- a/torch/distributed/fsdp/fully_sharded_data_parallel.py +++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py @@ -1201,9 +1201,8 @@ def clip_grad_norm_( def _warn_optim_input(optim_input): if optim_input is not None: warnings.warn( - "The `optim_input` argument is deprecated and will be removed after PyTorch 1.13. " - "You may remove it from your code without changing its functionality.", - FutureWarning, + "The `optim_input` argument is deprecated and will be removed after PyTorch 1.13. You may remove it " + "from your code without changing its functionality." ) @staticmethod @@ -1222,8 +1221,7 @@ def _warn_legacy_optim_state_dict(curr: str, new: str): warnings.warn( f"``FullyShardedDataParallel.{curr}``is being deprecated and is " f"replaced by ``FullyShardedDataParallel.{new}``. " - f"``FullyShardedDataParallel.{curr}`` may be removed after PyTorch 2.2.", - FutureWarning, + f"``FullyShardedDataParallel.{curr}`` may be removed after PyTorch 2.2." ) @staticmethod diff --git a/torch/distributed/launch.py b/torch/distributed/launch.py index 3efb0c3cf31d..c95804b8e8bb 100644 --- a/torch/distributed/launch.py +++ b/torch/distributed/launch.py @@ -159,7 +159,7 @@ """ -from typing_extensions import deprecated as _deprecated +import warnings from torch.distributed.run import get_args_parser, run @@ -188,17 +188,17 @@ def launch(args): run(args) -@_deprecated( - "The module torch.distributed.launch is deprecated\n" - "and will be removed in future. Use torchrun.\n" - "Note that --use-env is set by default in torchrun.\n" - "If your script expects `--local-rank` argument to be set, please\n" - "change it to read from `os.environ['LOCAL_RANK']` instead. See \n" - "https://pytorch.org/docs/stable/distributed.html#launch-utility for \n" - "further instructions\n", - category=FutureWarning, -) def main(args=None): + warnings.warn( + "The module torch.distributed.launch is deprecated\n" + "and will be removed in future. Use torchrun.\n" + "Note that --use-env is set by default in torchrun.\n" + "If your script expects `--local-rank` argument to be set, please\n" + "change it to read from `os.environ['LOCAL_RANK']` instead. See \n" + "https://pytorch.org/docs/stable/distributed.html#launch-utility for \n" + "further instructions\n", + FutureWarning, + ) args = parse_args(args) launch(args) diff --git a/torch/distributed/optim/__init__.py b/torch/distributed/optim/__init__.py index fe33265fd532..0b576c65afea 100644 --- a/torch/distributed/optim/__init__.py +++ b/torch/distributed/optim/__init__.py @@ -5,8 +5,6 @@ optimizer can use any of the local optimizer :ref:`optimizer-algorithms` to apply the gradients on each worker. """ -import warnings - import torch from torch import optim @@ -26,15 +24,9 @@ from .named_optimizer import _NamedOptimizer from .utils import as_functional_optim -with warnings.catch_warnings(): - warnings.simplefilter("always") - warnings.warn( - "`TorchScript` support for functional optimizers is deprecated " - "and will be removed in a future PyTorch release. " - "Consider using the `torch.compile` optimizer instead.", - DeprecationWarning, - stacklevel=2, - ) +from warnings import warn +warn("TorchScript support for functional optimizers is" + "deprecated and will be removed in a future PyTorch release. Consider using the torch.compile optimizer instead.") # DistributedOptimizer imports torch.distributed.rpc names, so gate availability # based on RPC being available. diff --git a/torch/distributed/pipeline/__init__.py b/torch/distributed/pipeline/__init__.py index eacd2bc99d04..5bc82f0692c1 100644 --- a/torch/distributed/pipeline/__init__.py +++ b/torch/distributed/pipeline/__init__.py @@ -1,13 +1,7 @@ import warnings - - -with warnings.catch_warnings(): - warnings.simplefilter("always") - warnings.warn( - "`torch.distributed.pipeline` is deprecated. For up-to-date pipeline parallel " - "implementation, please refer to the PiPPy library under the PyTorch " - "organization (Pipeline Parallelism for PyTorch): " - "https://github.com/pytorch/PiPPy", - DeprecationWarning, - stacklevel=2, - ) +warnings.warn( + "torch.distributed.pipeline is deprecated. For up-to-date pipeline parallel " + "implementation, please refer to the PiPPy library under the PyTorch " + "organization (Pipeline Parallelism for PyTorch): " + "https://github.com/pytorch/PiPPy" +) diff --git a/torch/distributed/tensor/parallel/_utils.py b/torch/distributed/tensor/parallel/_utils.py index 651a4cc9a847..c31170a0cd57 100644 --- a/torch/distributed/tensor/parallel/_utils.py +++ b/torch/distributed/tensor/parallel/_utils.py @@ -22,10 +22,7 @@ def _deprecate_warnings(func_name: str, extra_msg: str) -> None: """ # TODO: Will follow up with dynamo POC to make warnings.warn working with dynamo. if not is_torchdynamo_compiling(): - warnings.warn( - f"{func_name} is deprecated and will be removed soon. {extra_msg}", - FutureWarning, - ) + warnings.warn(f"{func_name} is deprecated and will be removed soon. {extra_msg}") def _validate_tp_mesh_dim( diff --git a/torch/distributions/distribution.py b/torch/distributions/distribution.py index 2fb05828a8b3..2752d710e8fb 100644 --- a/torch/distributions/distribution.py +++ b/torch/distributions/distribution.py @@ -1,6 +1,5 @@ import warnings from typing import Any, Dict, Optional, Tuple -from typing_extensions import deprecated import torch from torch.distributions import constraints @@ -172,15 +171,14 @@ def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor: """ raise NotImplementedError - @deprecated( - "`sample_n(n)` will be deprecated. Use `sample((n,))` instead.", - category=FutureWarning, - ) def sample_n(self, n: int) -> torch.Tensor: """ Generates n samples or n batches of samples if the distribution parameters are batched. """ + warnings.warn( + "sample_n will be deprecated. Use .sample((n,)) instead", UserWarning + ) return self.sample(torch.Size((n,))) def log_prob(self, value: torch.Tensor) -> torch.Tensor: diff --git a/torch/fx/experimental/unification/multipledispatch/dispatcher.py b/torch/fx/experimental/unification/multipledispatch/dispatcher.py index c46e47e5d35b..d2a8e6bfc7ff 100644 --- a/torch/fx/experimental/unification/multipledispatch/dispatcher.py +++ b/torch/fx/experimental/unification/multipledispatch/dispatcher.py @@ -1,6 +1,5 @@ from warnings import warn import inspect -from typing_extensions import deprecated from .conflict import ordering, ambiguities, super_signature, AmbiguityWarning from .utils import expand_tuples from .variadic import Variadic, isvariadic @@ -28,21 +27,24 @@ def ambiguity_warn(dispatcher, ambiguities): warn(warning_text(dispatcher.name, ambiguities), AmbiguityWarning) -@deprecated( - "`halt_ordering` is deprecated, you can safely remove this call.", - category=FutureWarning, -) def halt_ordering(): - """Deprecated interface to temporarily disable ordering.""" + """Deprecated interface to temporarily disable ordering. + """ + warn( + 'halt_ordering is deprecated, you can safely remove this call.', + DeprecationWarning, + ) -@deprecated( - "`restart_ordering` is deprecated, if you would like to eagerly order the dispatchers, " - "you should call the `reorder()` method on each dispatcher.", - category=FutureWarning, -) def restart_ordering(on_ambiguity=ambiguity_warn): - """Deprecated interface to temporarily resume ordering.""" + """Deprecated interface to temporarily resume ordering. + """ + warn( + 'restart_ordering is deprecated, if you would like to eagerly order' + 'the dispatchers, you should call the ``reorder()`` method on each' + ' dispatcher.', + DeprecationWarning, + ) def variadic_signature_matches_iter(types, full_signature): @@ -314,12 +316,14 @@ def dispatch_iter(self, *types): result = self.funcs[signature] yield result - @deprecated("`resolve()` is deprecated, use `dispatch(*types)`", category=FutureWarning) def resolve(self, types): """ Determine appropriate implementation for this type signature .. deprecated:: 0.4.4 Use ``dispatch(*types)`` instead """ + warn("resolve() is deprecated, use dispatch(*types)", + DeprecationWarning) + return self.dispatch(*types) def __getstate__(self): diff --git a/torch/hub.py b/torch/hub.py index 4ea92ed6be82..286dfbaa59b2 100644 --- a/torch/hub.py +++ b/torch/hub.py @@ -13,7 +13,6 @@ import zipfile from pathlib import Path from typing import Dict, Optional, Any -from typing_extensions import deprecated from urllib.error import HTTPError, URLError from urllib.request import urlopen, Request from urllib.parse import urlparse # noqa: F401 @@ -681,13 +680,10 @@ def _is_legacy_zip_format(filename: str) -> bool: return False -@deprecated( - 'Falling back to the old format < 1.6. This support will be ' - 'deprecated in favor of default zipfile format introduced in 1.6. ' - 'Please redo torch.save() to save it in the new zipfile format.', - category=FutureWarning, -) def _legacy_zip_load(filename: str, model_dir: str, map_location: MAP_LOCATION, weights_only: bool) -> Dict[str, Any]: + warnings.warn('Falling back to the old format < 1.6. This support will be ' + 'deprecated in favor of default zipfile format introduced in 1.6. ' + 'Please redo torch.save() to save it in the new zipfile format.') # Note: extractall() defaults to overwrite file if exists. No need to clean up beforehand. # We deliberately don't handle tarfile here since our legacy serialization format was in tar. # E.g. resnet18-5c106cde.pth which is widely used. diff --git a/torch/jit/_script.py b/torch/jit/_script.py index b77b0d2ea45f..8c223c66318c 100644 --- a/torch/jit/_script.py +++ b/torch/jit/_script.py @@ -1094,9 +1094,7 @@ def _script_impl( if optimize is not None: warnings.warn( - "`optimize` is deprecated and has no effect. " - "Use `with torch.jit.optimized_execution()` instead", - FutureWarning, + "`optimize` is deprecated and has no effect. Use `with torch.jit.optimized_execution() instead" ) # No-op for modules, functions, class instances that are already scripted diff --git a/torch/jit/_trace.py b/torch/jit/_trace.py index 2713a66a4499..9dbbce88db7b 100644 --- a/torch/jit/_trace.py +++ b/torch/jit/_trace.py @@ -978,9 +978,7 @@ def forward(self, x): return func if optimize is not None: warnings.warn( - "`optimize` is deprecated and has no effect. " - "Use `with torch.jit.optimized_execution()` instead", - FutureWarning, + "`optimize` is deprecated and has no effect. Use `with torch.jit.optimized_execution() instead" ) from torch._utils_internal import ( @@ -1187,9 +1185,7 @@ def weighted_kernel_sum(self, weight): return mod if optimize is not None: warnings.warn( - "`optimize` is deprecated and has no effect. " - "Use `with torch.jit.optimized_execution()` instead", - FutureWarning, + "`optimize` is deprecated and has no effect. Use `with torch.jit.optimized_execution() instead" ) var_lookup_fn = _create_interpreter_name_lookup_fn(0) diff --git a/torch/library.py b/torch/library.py index a69e16950f7e..68aefb84a206 100644 --- a/torch/library.py +++ b/torch/library.py @@ -1,6 +1,5 @@ from ._ops import OpOverload from typing import Any, Optional, Set, List, Union, Callable, Tuple, Dict, Sequence -from typing_extensions import deprecated import traceback import torch import weakref @@ -9,6 +8,7 @@ import re import contextlib import sys +import warnings from torch._library.custom_ops import custom_op, _maybe_get_opdef, device_types_t, CustomOpDef import torch._library as _library @@ -451,15 +451,15 @@ def wrap(f): return wrap -@deprecated( - "`torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that " - "instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.", - category=FutureWarning, -) def impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1): r"""This API was renamed to :func:`torch.library.register_fake` in PyTorch 2.4. Please use that instead. """ + warnings.warn("torch.library.impl_abstract was renamed to " + "torch.library.register_fake. Please use that instead; " + "we will remove torch.library.impl_abstract in a future " + "version of PyTorch.", + DeprecationWarning, stacklevel=2) if func is not None: _stacklevel = _stacklevel + 1 return register_fake(qualname, func, lib=lib, _stacklevel=_stacklevel) diff --git a/torch/multiprocessing/spawn.py b/torch/multiprocessing/spawn.py index 7c5a0896b436..a6ddc0102ce4 100644 --- a/torch/multiprocessing/spawn.py +++ b/torch/multiprocessing/spawn.py @@ -277,5 +277,5 @@ def spawn(fn, args=(), nprocs=1, join=True, daemon=False, start_method="spawn"): "To use a different start_method use:\n\t\t" " torch.multiprocessing.start_processes(...)" ) - warnings.warn(msg, FutureWarning) + warnings.warn(msg) return start_processes(fn, args, nprocs, join, daemon, start_method="spawn") diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 8d0d43087b2c..38d4bd5756fd 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -1818,8 +1818,7 @@ def softsign(input): # noqa: D400,D402 def _get_softmax_dim(name: str, ndim: int, stacklevel: int) -> int: warnings.warn( - f"Implicit dimension choice for {name} has been deprecated. " - "Change the call to include dim=X as an argument.", + f"Implicit dimension choice for {name} has been deprecated. Change the call to include dim=X as an argument.", stacklevel=stacklevel, ) if ndim == 0 or ndim == 1 or ndim == 3: @@ -3824,10 +3823,7 @@ def upsample(input, size=None, scale_factor=None, mode="nearest", align_corners= affects the outputs. """ - warnings.warn( - "`nn.functional.upsample` is deprecated. " - "Use `nn.functional.interpolate` instead.", - ) + warnings.warn("nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.") return interpolate(input, size, scale_factor, mode, align_corners) @@ -4147,10 +4143,7 @@ def upsample_nearest(input, size=None, scale_factor=None): # noqa: F811 {backward_reproducibility_note} """ # DeprecationWarning is ignored by default - warnings.warn( - "`nn.functional.upsample_nearest` is deprecated. " - "Use `nn.functional.interpolate` instead.", - ) + warnings.warn("nn.functional.upsample_nearest is deprecated. Use nn.functional.interpolate instead.") return interpolate(input, size, scale_factor, mode="nearest") @@ -4206,10 +4199,7 @@ def upsample_bilinear(input, size=None, scale_factor=None): # noqa: F811 {backward_reproducibility_note} """ # DeprecationWarning is ignored by default - warnings.warn( - "`nn.functional.upsample_bilinear` is deprecated. " - "Use `nn.functional.interpolate` instead.", - ) + warnings.warn("nn.functional.upsample_bilinear is deprecated. Use nn.functional.interpolate instead.") return interpolate(input, size, scale_factor, mode="bilinear", align_corners=True) diff --git a/torch/nn/init.py b/torch/nn/init.py index f5be081e7dd0..426069d780c0 100644 --- a/torch/nn/init.py +++ b/torch/nn/init.py @@ -599,11 +599,7 @@ def _make_deprecate(meth): old_name = new_name[:-1] def deprecated_init(*args, **kwargs): - warnings.warn( - f"`nn.init.{old_name}` is now deprecated in favor of `nn.init.{new_name}`.", - FutureWarning, - stacklevel=2, - ) + warnings.warn(f"nn.init.{old_name} is now deprecated in favor of nn.init.{new_name}.", stacklevel=2) return meth(*args, **kwargs) deprecated_init.__doc__ = fr""" diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py index 0d8911893011..bf15c3342d1d 100644 --- a/torch/nn/modules/activation.py +++ b/torch/nn/modules/activation.py @@ -219,16 +219,10 @@ def __init__( ) -> None: super().__init__() if min_value is not None: - warnings.warn( - "keyword argument `min_value` is deprecated and rename to `min_val`", - FutureWarning, - ) + warnings.warn("keyword argument min_value is deprecated and rename to min_val") min_val = min_value if max_value is not None: - warnings.warn( - "keyword argument `max_value` is deprecated and rename to `max_val`", - FutureWarning, - ) + warnings.warn("keyword argument max_value is deprecated and rename to max_val") max_val = max_value self.min_val = min_val diff --git a/torch/nn/modules/container.py b/torch/nn/modules/container.py index 775a826d69cc..1b5659d4b7e9 100644 --- a/torch/nn/modules/container.py +++ b/torch/nn/modules/container.py @@ -1,3 +1,4 @@ +import warnings from collections import OrderedDict, abc as container_abcs from itertools import chain, islice import operator @@ -9,7 +10,6 @@ from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, overload, Tuple, TypeVar, Union from typing_extensions import Self -from typing_extensions import deprecated __all__ = ['Container', 'Sequential', 'ModuleList', 'ModuleDict', 'ParameterList', 'ParameterDict'] @@ -29,14 +29,13 @@ def _addindent(s_, numSpaces): return s -@deprecated( - "`nn.Container` is deprecated. " - "All of it's functionality is now implemented in `nn.Module`. Subclass that instead.", - category=FutureWarning, -) class Container(Module): + def __init__(self, **kwargs: Any) -> None: super().__init__() + # DeprecationWarning is ignored by default + warnings.warn("nn.Container is deprecated. All of it's functionality " + "is now implemented in nn.Module. Subclass that instead.") for key, value in kwargs.items(): self.add_module(key, value) diff --git a/torch/nn/modules/conv.py b/torch/nn/modules/conv.py index 4ab4c8bff9fc..075d5e9865e6 100644 --- a/torch/nn/modules/conv.py +++ b/torch/nn/modules/conv.py @@ -1,4 +1,5 @@ import math +import warnings import torch from torch import Tensor @@ -12,7 +13,6 @@ from ..common_types import _size_1_t, _size_2_t, _size_3_t from typing import Optional, List, Tuple, Union -from typing_extensions import deprecated __all__ = ['Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d', 'LazyConv1d', 'LazyConv2d', 'LazyConv3d', 'LazyConvTranspose1d', 'LazyConvTranspose2d', @@ -40,6 +40,9 @@ :math:`(C_\text{in}=C_\text{in}, C_\text{out}=C_\text{in} \times \text{K}, ..., \text{groups}=C_\text{in})`."""} # noqa: B950 + + + class _ConvNd(Module): __constants__ = ['stride', 'padding', 'dilation', 'groups', @@ -607,6 +610,7 @@ def forward(self, input: Tensor) -> Tensor: return self._conv_forward(input, self.weight, self.bias) + class _ConvTransposeNd(_ConvNd): def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, @@ -1117,13 +1121,10 @@ def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Ten # `_ConvTransposeNd` is really not a mixin anymore (but multiple inheritance as # above would still work). class _ConvTransposeMixin(_ConvTransposeNd): - - @deprecated( - "`_ConvTransposeMixin` is a deprecated internal class. " - "Please consider using public APIs.", - category=FutureWarning, - ) def __init__(self, *args, **kwargs): + warnings.warn( + "_ConvTransposeMixin is a deprecated internal class. " + "Please consider using public APIs.") super().__init__(*args, **kwargs) diff --git a/torch/nn/modules/loss.py b/torch/nn/modules/loss.py index 4324c1df144d..ee034bf458a6 100644 --- a/torch/nn/modules/loss.py +++ b/torch/nn/modules/loss.py @@ -1,3 +1,5 @@ +import warnings + from .distance import PairwiseDistance from .module import Module from .. import functional as F @@ -5,7 +7,6 @@ from torch import Tensor from typing import Callable, Optional -from typing_extensions import deprecated __all__ = ['L1Loss', 'NLLLoss', 'NLLLoss2d', 'PoissonNLLLoss', 'GaussianNLLLoss', 'KLDivLoss', 'MSELoss', 'BCELoss', 'BCEWithLogitsLoss', 'HingeEmbeddingLoss', 'MultiLabelMarginLoss', @@ -217,15 +218,12 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor: return F.nll_loss(input, target, weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction) -@deprecated( - "`NLLLoss2d` has been deprecated. " - "Please use `NLLLoss` instead as a drop-in replacement and see " - "https://pytorch.org/docs/main/nn.html#torch.nn.NLLLoss for more details.", - category=FutureWarning, -) class NLLLoss2d(NLLLoss): def __init__(self, weight: Optional[Tensor] = None, size_average=None, ignore_index: int = -100, reduce=None, reduction: str = 'mean') -> None: + warnings.warn("NLLLoss2d has been deprecated. " + "Please use NLLLoss instead as a drop-in replacement and see " + "https://pytorch.org/docs/main/nn.html#torch.nn.NLLLoss for more details.") super().__init__(weight, size_average, ignore_index, reduce, reduction) diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index 2e65bb97c659..f9d400cdbb35 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -1346,8 +1346,7 @@ def _maybe_warn_non_full_backward_hook(self, inputs, result, grad_fn): warnings.warn("Using non-full backward hooks on a Module that does not return a " "single Tensor or a tuple of Tensors is deprecated and will be removed " "in future versions. This hook will be missing some of the grad_output. " - "Please use register_full_backward_hook to get the documented behavior.", - FutureWarning) + "Please use register_full_backward_hook to get the documented behavior.") return else: result = (result,) @@ -1357,8 +1356,7 @@ def _maybe_warn_non_full_backward_hook(self, inputs, result, grad_fn): warnings.warn("Using non-full backward hooks on a Module that does not take as input a " "single Tensor or a tuple of Tensors is deprecated and will be removed " "in future versions. This hook will be missing some of the grad_input. " - "Please use register_full_backward_hook to get the documented behavior.", - FutureWarning) + "Please use register_full_backward_hook to get the documented behavior.") return else: inputs = (inputs,) @@ -1368,13 +1366,11 @@ def _maybe_warn_non_full_backward_hook(self, inputs, result, grad_fn): if len(out_grad_fn) == 0 or (len(out_grad_fn) == 1 and grad_fn not in out_grad_fn): warnings.warn("Using a non-full backward hook when outputs are nested in python data structure " "is deprecated and will be removed in future versions. This hook will be missing " - "some grad_output.", - FutureWarning) + "some grad_output.") elif len(out_grad_fn) > 1: warnings.warn("Using a non-full backward hook when outputs are generated by different autograd Nodes " "is deprecated and will be removed in future versions. This hook will be missing " - "some grad_output. Please use register_full_backward_hook to get the documented behavior.", - FutureWarning) + "some grad_output. Please use register_full_backward_hook to get the documented behavior.") else: # At this point the grad_output part of the hook will most likely be correct inputs_grad_fn = {i.grad_fn for i in inputs if i.grad_fn is not None} @@ -1385,8 +1381,7 @@ def _maybe_warn_non_full_backward_hook(self, inputs, result, grad_fn): warnings.warn("Using a non-full backward hook when the forward contains multiple autograd Nodes " "is deprecated and will be removed in future versions. This hook will be missing " "some grad_input. Please use register_full_backward_hook to get the documented " - "behavior.", - FutureWarning) + "behavior.") def register_forward_pre_hook( self, @@ -1910,9 +1905,7 @@ def state_dict(self, *args, destination=None, prefix='', keep_vars=False): warnings.warn( "Positional args are being deprecated, use kwargs instead. Refer to " "https://pytorch.org/docs/main/generated/torch.nn.Module.html#torch.nn.Module.state_dict" - " for details.", - FutureWarning, - ) + " for details.") if destination is None: destination = OrderedDict() diff --git a/torch/nn/modules/rnn.py b/torch/nn/modules/rnn.py index b4bdd7824474..742bec9ebd19 100644 --- a/torch/nn/modules/rnn.py +++ b/torch/nn/modules/rnn.py @@ -3,7 +3,6 @@ import numbers import weakref from typing import List, Tuple, Optional, overload -from typing_extensions import deprecated import torch from torch import Tensor @@ -25,11 +24,8 @@ def _apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Ten return tensor.index_select(dim, permutation) -@deprecated( - "`apply_permutation` is deprecated, please use `tensor.index_select(dim, permutation)` instead", - category=FutureWarning, -) def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor: + warnings.warn("apply_permutation is deprecated, please use tensor.index_select(dim, permutation) instead") return _apply_permutation(tensor, permutation, dim) diff --git a/torch/nn/parallel/__init__.py b/torch/nn/parallel/__init__.py index adcd6bd838eb..d0708296d47a 100644 --- a/torch/nn/parallel/__init__.py +++ b/torch/nn/parallel/__init__.py @@ -1,5 +1,3 @@ -from typing_extensions import deprecated - from .parallel_apply import parallel_apply from .replicate import replicate from .data_parallel import DataParallel, data_parallel @@ -9,11 +7,8 @@ __all__ = ['replicate', 'scatter', 'parallel_apply', 'gather', 'data_parallel', 'DataParallel', 'DistributedDataParallel'] - -@deprecated( - "`torch.nn.parallel.DistributedDataParallelCPU` is deprecated, " - "please use `torch.nn.parallel.DistributedDataParallel` instead.", - category=FutureWarning, -) def DistributedDataParallelCPU(*args, **kwargs): + import warnings + warnings.warn("torch.nn.parallel.DistributedDataParallelCPU is deprecated, " + "please use torch.nn.parallel.DistributedDataParallel instead.") return DistributedDataParallel(*args, **kwargs) diff --git a/torch/nn/parallel/comm.py b/torch/nn/parallel/comm.py index 2e090f123c34..764775587d68 100644 --- a/torch/nn/parallel/comm.py +++ b/torch/nn/parallel/comm.py @@ -226,9 +226,7 @@ def gather(tensors, dim=0, destination=None, *, out=None): if destination == -1: warnings.warn( 'Using -1 to represent CPU tensor is deprecated. Please use a ' - 'device object or string instead, e.g., "cpu".', - FutureWarning, - ) + 'device object or string instead, e.g., "cpu".') destination = _get_device_index(destination, allow_cpu=True, optional=True) return torch._C._gather(tensors, dim, destination) else: diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py index 5f2013664f56..b27c960a154c 100644 --- a/torch/nn/parallel/distributed.py +++ b/torch/nn/parallel/distributed.py @@ -771,8 +771,7 @@ def __init__( # do not receive gradients. warnings.warn( "The `check_reduction` argument in `DistributedDataParallel` " - "module is deprecated. Please avoid using it.", - FutureWarning, + "module is deprecated. Please avoid using it." ) # Check that a module does not have Uninitialized parameters @@ -1467,7 +1466,7 @@ def _lazy_init(self): def _should_disable_cpp_reducer(self) -> bool: return self._use_python_reducer and ( - torch.compiler.is_compiling() or self._force_to_disable_cpp_reducer + torch._utils.is_compiling() or self._force_to_disable_cpp_reducer ) def _pre_forward(self, *inputs, **kwargs): @@ -1480,7 +1479,7 @@ def _pre_forward(self, *inputs, **kwargs): h.remove() self._accum_grad_hooks.clear() - if not self._lazy_init_ran and not torch.compiler.is_compiling(): + if not self._lazy_init_ran and not torch._utils.is_compiling(): self._lazy_init() if self._delay_all_reduce_all_params: diff --git a/torch/nn/parallel/scatter_gather.py b/torch/nn/parallel/scatter_gather.py index f6fb9d47ecbf..8daa1117bfaf 100644 --- a/torch/nn/parallel/scatter_gather.py +++ b/torch/nn/parallel/scatter_gather.py @@ -1,17 +1,13 @@ import torch from typing import Any, Dict, List, Optional, Sequence, Tuple, TypeVar, Union, overload -from typing_extensions import deprecated from ._functions import Scatter, Gather +import warnings __all__ = ['scatter', 'scatter_kwargs', 'gather'] - -@deprecated( - "`is_namedtuple` is deprecated, please use the python checks instead", - category=FutureWarning, -) def is_namedtuple(obj: Any) -> bool: # Check if type was created from collections.namedtuple or a typing.NamedTuple. + warnings.warn("is_namedtuple is deprecated, please use the python checks instead") return _is_namedtuple(obj) def _is_namedtuple(obj: Any) -> bool: diff --git a/torch/nn/utils/clip_grad.py b/torch/nn/utils/clip_grad.py index 4ac8a4e7445b..6549a6f3e2c8 100644 --- a/torch/nn/utils/clip_grad.py +++ b/torch/nn/utils/clip_grad.py @@ -1,6 +1,6 @@ +import warnings import functools from typing import Union, Iterable, List, Dict, Tuple, Optional, cast -from typing_extensions import deprecated import torch from torch import Tensor @@ -99,11 +99,6 @@ def clip_grad_norm_( return total_norm -@deprecated( - "`torch.nn.utils.clip_grad_norm` is now deprecated " - "in favor of `torch.nn.utils.clip_grad_norm_`.", - category=FutureWarning, -) def clip_grad_norm( parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2., error_if_nonfinite: bool = False, foreach: Optional[bool] = None) -> torch.Tensor: @@ -113,6 +108,8 @@ def clip_grad_norm( This method is now deprecated in favor of :func:`torch.nn.utils.clip_grad_norm_`. """ + warnings.warn("torch.nn.utils.clip_grad_norm is now deprecated in favor " + "of torch.nn.utils.clip_grad_norm_.", stacklevel=2) return clip_grad_norm_(parameters, max_norm, norm_type, error_if_nonfinite, foreach) diff --git a/torch/nn/utils/stateless.py b/torch/nn/utils/stateless.py index 660a1a484ebb..2cb6c7460d4c 100644 --- a/torch/nn/utils/stateless.py +++ b/torch/nn/utils/stateless.py @@ -1,7 +1,7 @@ import contextlib +import warnings from collections import defaultdict from typing import Any, Dict, Iterator, Optional, Set, Tuple, Union -from typing_extensions import deprecated import torch from torch import Tensor @@ -148,12 +148,6 @@ def _reparametrize_module( ) -@deprecated( - "`torch.nn.utils.stateless.functional_call` is deprecated as of PyTorch 2.0 " - "and will be removed in a future version of PyTorch. " - "Please use `torch.func.functional_call` instead which is a drop-in replacement.", - category=FutureWarning, -) def functional_call( module: "torch.nn.Module", parameters_and_buffers: Dict[str, Tensor], @@ -222,6 +216,12 @@ def functional_call( Returns: Any: the result of calling ``module``. """ + warnings.warn( + "This API is deprecated as of PyTorch 2.0 and will be removed in a future " + "version of PyTorch. Please use torch.func.functional_call instead " + "which is a drop-in replacement for this API." + ) + return _functional_call( module, parameters_and_buffers, diff --git a/torch/nn/utils/weight_norm.py b/torch/nn/utils/weight_norm.py index 6cfe4b3e526d..942a13a4eb83 100644 --- a/torch/nn/utils/weight_norm.py +++ b/torch/nn/utils/weight_norm.py @@ -2,7 +2,7 @@ from torch.nn.parameter import Parameter, UninitializedParameter from torch import _weight_norm, norm_except_dim from typing import Any, TypeVar -from typing_extensions import deprecated +import warnings from ..modules import Module __all__ = ['WeightNorm', 'weight_norm', 'remove_weight_norm'] @@ -24,12 +24,9 @@ def compute_weight(self, module: Module) -> Any: return _weight_norm(v, g, self.dim) @staticmethod - @deprecated( - "`torch.nn.utils.weight_norm` is deprecated " - "in favor of `torch.nn.utils.parametrizations.weight_norm`.", - category=FutureWarning, - ) def apply(module, name: str, dim: int) -> 'WeightNorm': + warnings.warn("torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.") + for hook in module._forward_pre_hooks.values(): if isinstance(hook, WeightNorm) and hook.name == name: raise RuntimeError(f"Cannot register two weight_norm hooks on the same parameter {name}") diff --git a/torch/optim/adadelta.py b/torch/optim/adadelta.py index 4d1a4e25319c..097c8040b63e 100644 --- a/torch/optim/adadelta.py +++ b/torch/optim/adadelta.py @@ -254,7 +254,7 @@ def _single_tensor_adadelta( has_complex: bool, ): # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices( supports_xla=False ) @@ -310,7 +310,7 @@ def _multi_tensor_adadelta( assert not differentiable, "_foreach ops don't support autograd" # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices( supports_xla=False ) @@ -413,7 +413,7 @@ def adadelta( # this check is slow during compilation, so we skip it # if it's strictly needed we can add this check back in dynamo - if not torch.compiler.is_compiling() and not all( + if not torch._utils.is_compiling() and not all( isinstance(t, torch.Tensor) for t in state_steps ): raise RuntimeError( diff --git a/torch/optim/adam.py b/torch/optim/adam.py index 1c625682fc34..fba4b2027b05 100644 --- a/torch/optim/adam.py +++ b/torch/optim/adam.py @@ -353,7 +353,7 @@ def _single_tensor_adam( step_t = state_steps[i] # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices() assert ( param.device.type == step_t.device.type @@ -466,7 +466,7 @@ def _multi_tensor_adam( ) # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices( supports_xla=False ) @@ -743,7 +743,7 @@ def adam( # this check is slow during compilation, so we skip it # if it's strictly needed we can add this check back in dynamo - if not torch.compiler.is_compiling() and not all( + if not torch._utils.is_compiling() and not all( isinstance(t, torch.Tensor) for t in state_steps ): raise RuntimeError( diff --git a/torch/optim/adamax.py b/torch/optim/adamax.py index 005327d8bb88..8af468ba8386 100644 --- a/torch/optim/adamax.py +++ b/torch/optim/adamax.py @@ -243,7 +243,7 @@ def _single_tensor_adamax( step_t = state_steps[i] # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices() assert ( param.device.type == step_t.device.type @@ -315,7 +315,7 @@ def _multi_tensor_adamax( return # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices( supports_xla=False ) @@ -424,7 +424,7 @@ def adamax( See :class:`~torch.optim.Adamax` for details. """ - if not torch.compiler.is_compiling() and not all( + if not torch._utils.is_compiling() and not all( isinstance(t, torch.Tensor) for t in state_steps ): raise RuntimeError( diff --git a/torch/optim/adamw.py b/torch/optim/adamw.py index 707ac17c361c..e58b28244083 100644 --- a/torch/optim/adamw.py +++ b/torch/optim/adamw.py @@ -354,7 +354,7 @@ def _single_tensor_adamw( step_t = state_steps[i] # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices() assert ( param.device.type == step_t.device.type @@ -467,7 +467,7 @@ def _multi_tensor_adamw( ) # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices( supports_xla=False ) @@ -728,7 +728,7 @@ def adamw( See :class:`~torch.optim.AdamW` for details. """ - if not torch.compiler.is_compiling() and not all( + if not torch._utils.is_compiling() and not all( isinstance(t, torch.Tensor) for t in state_steps ): raise RuntimeError( diff --git a/torch/optim/asgd.py b/torch/optim/asgd.py index 633a14832282..f53f8b427e9f 100644 --- a/torch/optim/asgd.py +++ b/torch/optim/asgd.py @@ -214,7 +214,7 @@ def _single_tensor_asgd( step_t = state_steps[i] # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices() assert ( param.device.type @@ -287,7 +287,7 @@ def _multi_tensor_asgd( assert not differentiable, "_foreach ops don't support autograd" # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices( supports_xla=False ) diff --git a/torch/optim/nadam.py b/torch/optim/nadam.py index fd1f8ab0e718..b860ed3ddda3 100644 --- a/torch/optim/nadam.py +++ b/torch/optim/nadam.py @@ -304,7 +304,7 @@ def _single_tensor_nadam( exp_avg_sq = torch.view_as_real(exp_avg_sq) # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices() assert ( param.device.type == mu_product.device.type == step_t.device.type @@ -390,7 +390,7 @@ def _multi_tensor_nadam( assert not differentiable, "_foreach ops don't support autograd" # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices( supports_xla=False ) diff --git a/torch/optim/optimizer.py b/torch/optim/optimizer.py index fc091e273c36..0352b7976579 100644 --- a/torch/optim/optimizer.py +++ b/torch/optim/optimizer.py @@ -24,6 +24,7 @@ import torch import torch.utils.hooks as hooks +from torch._utils import is_compiling from torch.utils._foreach_utils import ( _get_foreach_kernels_supported_devices, _get_fused_kernels_supported_devices, @@ -96,14 +97,14 @@ def _use_grad(self, *args, **kwargs): def _get_value(x): # item is significantly faster than a cpu tensor in eager mode - if not torch.jit.is_scripting() and torch.compiler.is_compiling(): + if not torch.jit.is_scripting() and is_compiling(): return x else: return x.item() if isinstance(x, torch.Tensor) else x def _stack_if_compiling(x): - if not torch.jit.is_scripting() and torch.compiler.is_compiling(): + if not torch.jit.is_scripting() and is_compiling(): return torch.stack(x) else: return x @@ -144,7 +145,7 @@ def wrapper(func): # the capturable flag. If capturable=True, this is not a problem. @functools.wraps(func) def maybe_fallback(*args, **kwargs): - if torch.compiler.is_compiling() and ( + if is_compiling() and ( not kwargs.get("capturable", False) and has_state_steps and (args[state_steps_ind] and args[state_steps_ind][0].is_cuda) @@ -417,7 +418,7 @@ def _cuda_graph_capture_health_check(self) -> None: # Thus, when compiling, inductor will determine if cudagraphs # can be enabled based on whether there is input mutation or CPU tensors. if ( - not torch.compiler.is_compiling() + not is_compiling() and torch.backends.cuda.is_built() and torch.cuda.is_available() ): @@ -504,7 +505,7 @@ def _group_tensors_by_device_and_dtype( """Groups a list of lists of tensors by device and dtype. Skips this step if we are compiling since this will occur during inductor lowering. """ - if torch.compiler.is_compiling(): + if is_compiling(): return {(None, None): (tensorlistlist, list(range(len(tensorlistlist[0]))))} else: return _group_tensors_by_device_and_dtype(tensorlistlist, with_indices) # type: ignore[return-value, arg-type] diff --git a/torch/optim/radam.py b/torch/optim/radam.py index 619f10493587..ea592185c887 100644 --- a/torch/optim/radam.py +++ b/torch/optim/radam.py @@ -271,7 +271,7 @@ def _single_tensor_radam( step_t = state_steps[i] # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices() assert ( param.device.type == step_t.device.type @@ -369,7 +369,7 @@ def _multi_tensor_radam( assert not differentiable, "_foreach ops don't support autograd" # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices( supports_xla=False ) diff --git a/torch/optim/rmsprop.py b/torch/optim/rmsprop.py index bdc3ec0b8b3f..b3375c338b40 100644 --- a/torch/optim/rmsprop.py +++ b/torch/optim/rmsprop.py @@ -276,7 +276,7 @@ def _single_tensor_rmsprop( step = state_steps[i] # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices() assert ( param.device.type == step.device.type @@ -349,7 +349,7 @@ def _multi_tensor_rmsprop( assert not differentiable, "_foreach ops don't support autograd" # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices() assert all( p.device.type == step.device.type @@ -467,7 +467,7 @@ def rmsprop( """ # this check is slow during compilation, so we skip it # if it's strictly needed we can add this check back in dynamo - if not torch.compiler.is_compiling() and not all( + if not torch._utils.is_compiling() and not all( isinstance(t, torch.Tensor) for t in state_steps ): raise RuntimeError( diff --git a/torch/optim/rprop.py b/torch/optim/rprop.py index af1854cc518a..ec40aae5c90a 100644 --- a/torch/optim/rprop.py +++ b/torch/optim/rprop.py @@ -236,7 +236,7 @@ def _single_tensor_rprop( step = state_steps[i] # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices() assert ( param.device.type == step.device.type @@ -302,7 +302,7 @@ def _multi_tensor_rprop( assert not differentiable, "_foreach ops don't support autograd" # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices() assert all( p.device.type == step.device.type @@ -414,7 +414,7 @@ def rprop( """ # this check is slow during compilation, so we skip it # if it's strictly needed we can add this check back in dynamo - if not torch.compiler.is_compiling() and not all( + if not torch._utils.is_compiling() and not all( isinstance(t, torch.Tensor) for t in state_steps ): raise RuntimeError( diff --git a/torch/optim/sgd.py b/torch/optim/sgd.py index aa3062095c6a..c0efc2443078 100644 --- a/torch/optim/sgd.py +++ b/torch/optim/sgd.py @@ -429,7 +429,7 @@ def _multi_tensor_sgd( if not device_has_sparse_grad: # handle internal item() call if lr is a tensor - if isinstance(lr, torch.Tensor) and torch.compiler.is_compiling(): + if isinstance(lr, torch.Tensor) and torch._utils.is_compiling(): grads_x_lr = torch._foreach_mul(device_grads, -lr) torch._foreach_add_(device_params, grads_x_lr) else: diff --git a/torch/profiler/profiler.py b/torch/profiler/profiler.py index 3847da03ab8e..2c16c34bb3c3 100644 --- a/torch/profiler/profiler.py +++ b/torch/profiler/profiler.py @@ -598,10 +598,7 @@ def __init__( ): activities_set = set(activities) if activities else supported_activities() if use_cuda is not None: - warn( - "`use_cuda` is deprecated, use `activities` argument instead", - FutureWarning, - ) + warn("use_cuda is deprecated, use activities argument instead") if use_cuda: activities_set.add(ProfilerActivity.CUDA) elif ProfilerActivity.CUDA in activities_set: diff --git a/torch/sparse/semi_structured.py b/torch/sparse/semi_structured.py index d592e5ef6a62..587fcc0d72ea 100644 --- a/torch/sparse/semi_structured.py +++ b/torch/sparse/semi_structured.py @@ -359,12 +359,9 @@ def to_sparse_semi_structured( [-4370, -4370, -4370, ..., -4370, -4370, -4370]], device='cuda:0', dtype=torch.int16)) """ if transposed: - warnings.warn( - "Setting transpose from `to_sparse_semi_structured` is deprecated " - "and will be removed in a future release. " - "`SparseSemiStructuredTensor` only support contiguous input tensors.", - FutureWarning, - stacklevel=2, + raise DeprecationWarning( + "Setting transpose from to_sparse_semi_structured is deprecated and will be removed in a future release." + "SparseSemiStructuredTensor only support contiguous input tensors. " ) # set from _FORCE_CUTLASS flag diff --git a/torch/testing/_comparison.py b/torch/testing/_comparison.py index 85d5adb0cd3a..e2bad14e4490 100644 --- a/torch/testing/_comparison.py +++ b/torch/testing/_comparison.py @@ -2,6 +2,7 @@ import cmath import collections.abc import contextlib +import warnings from typing import ( Any, Callable, @@ -15,7 +16,6 @@ Type, Union, ) -from typing_extensions import deprecated import torch @@ -1523,12 +1523,6 @@ def assert_close( raise error_metas[0].to_error(msg) -@deprecated( - "`torch.testing.assert_allclose()` is deprecated since 1.12 and will be removed in a future release. " - "Please use `torch.testing.assert_close()` instead. " - "You can find detailed upgrade instructions in https://github.com/pytorch/pytorch/issues/61844.", - category=FutureWarning, -) def assert_allclose( actual: Any, expected: Any, @@ -1544,6 +1538,14 @@ def assert_allclose( Please use :func:`torch.testing.assert_close` instead. You can find detailed upgrade instructions `here `_. """ + warnings.warn( + "`torch.testing.assert_allclose()` is deprecated since 1.12 and will be removed in a future release. " + "Please use `torch.testing.assert_close()` instead. " + "You can find detailed upgrade instructions in https://github.com/pytorch/pytorch/issues/61844.", + FutureWarning, + stacklevel=2, + ) + if not isinstance(actual, torch.Tensor): actual = torch.tensor(actual) if not isinstance(expected, torch.Tensor): diff --git a/torch/testing/_creation.py b/torch/testing/_creation.py index 2433552a0873..0b01b172a477 100644 --- a/torch/testing/_creation.py +++ b/torch/testing/_creation.py @@ -150,7 +150,7 @@ def clamp(a: float, l: float, h: float) -> float: warnings.warn( "Passing `low==high` to `torch.testing.make_tensor` for floating or complex types " "is deprecated since 2.1 and will be removed in 2.3. " - "Use `torch.full(...)` instead.", + "Use torch.full(...) instead.", FutureWarning, ) elif low >= high: diff --git a/torch/testing/_internal/optests/generate_tests.py b/torch/testing/_internal/optests/generate_tests.py index d01f91563c92..70ee48274800 100644 --- a/torch/testing/_internal/optests/generate_tests.py +++ b/torch/testing/_internal/optests/generate_tests.py @@ -569,7 +569,7 @@ def __torch_function__(self, func, types, args=(), kwargs=None): if ( torch.jit.is_tracing() or torch.jit.is_scripting() - or torch.compiler.is_compiling() + or torch._dynamo.is_compiling() ): return func(*args, **kwargs) # Pre-existing code may not use the .default overload. If we see an diff --git a/torch/utils/_config_module.py b/torch/utils/_config_module.py index 6b38645e486b..f468e2d84890 100644 --- a/torch/utils/_config_module.py +++ b/torch/utils/_config_module.py @@ -7,9 +7,9 @@ import pickle import tokenize import unittest +import warnings from types import FunctionType, ModuleType from typing import Any, Dict, Optional, Set, Union -from typing_extensions import deprecated from unittest import mock # Types saved/loaded in configs @@ -196,12 +196,12 @@ def get_hash(self) -> bytes: self._is_dirty = False return self._hash_digest - @deprecated( - "`config.to_dict()` has been deprecated. It may no longer change the underlying config." - " use `config.shallow_copy_dict()` or `config.get_config_copy()` instead", - category=FutureWarning, - ) def to_dict(self) -> Dict[str, Any]: + warnings.warn( + "config.to_dict() has been deprecated. It may no longer change the underlying config." + " use config.shallow_copy_dict() or config.get_config_copy() instead", + DeprecationWarning, + ) return self.shallow_copy_dict() def shallow_copy_dict(self) -> Dict[str, Any]: diff --git a/torch/utils/_contextlib.py b/torch/utils/_contextlib.py index 59b7d368af26..c55e69618575 100644 --- a/torch/utils/_contextlib.py +++ b/torch/utils/_contextlib.py @@ -122,14 +122,10 @@ class _DecoratorContextManager: def __call__(self, orig_func: F) -> F: if inspect.isclass(orig_func): - warnings.warn( - "Decorating classes is deprecated and will be disabled in " - "future versions. You should only decorate functions or methods. " - "To preserve the current behavior of class decoration, you can " - "directly decorate the `__init__` method and nothing else.", - FutureWarning, - stacklevel=2, - ) + warnings.warn("Decorating classes is deprecated and will be disabled in " + "future versions. You should only decorate functions or methods. " + "To preserve the current behavior of class decoration, you can " + "directly decorate the `__init__` method and nothing else.") func = cast(F, lambda *args, **kwargs: orig_func(*args, **kwargs)) else: func = orig_func diff --git a/torch/utils/_cxx_pytree.py b/torch/utils/_cxx_pytree.py index 01adf0a4f9b1..aba15f1482f2 100644 --- a/torch/utils/_cxx_pytree.py +++ b/torch/utils/_cxx_pytree.py @@ -15,6 +15,7 @@ import functools import sys import types +import warnings from typing import ( Any, Callable, @@ -27,7 +28,6 @@ TypeVar, Union, ) -from typing_extensions import deprecated import torch @@ -167,11 +167,6 @@ def register_pytree_node( ) -@deprecated( - "`torch.utils._cxx_pytree._register_pytree_node` is deprecated. " - "Please use `torch.utils._cxx_pytree.register_pytree_node` instead.", - category=FutureWarning, -) def _register_pytree_node( cls: Type[Any], flatten_fn: FlattenFunc, @@ -212,6 +207,11 @@ def _register_pytree_node( original context. This is used for json deserialization, which is being used in :mod:`torch.export` right now. """ + warnings.warn( + "torch.utils._cxx_pytree._register_pytree_node is deprecated. " + "Please use torch.utils._cxx_pytree.register_pytree_node instead.", + stacklevel=2, + ) _private_register_pytree_node( cls, diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index 989be9b2d617..2831d662d9f6 100644 --- a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -48,7 +48,6 @@ TypeVar, Union, ) -from typing_extensions import deprecated __all__ = [ @@ -252,11 +251,6 @@ def _register_namedtuple( ) -@deprecated( - "`torch.utils._pytree._register_pytree_node` is deprecated. " - "Please use `torch.utils._pytree.register_pytree_node` instead.", - category=FutureWarning, -) def _register_pytree_node( cls: Type[Any], flatten_fn: FlattenFunc, @@ -293,11 +287,16 @@ def _register_pytree_node( Like ``flatten_fn``, but in place of a List[leaf], it should return a List[(keypath, leaf)]. """ + warnings.warn( + "torch.utils._pytree._register_pytree_node is deprecated. " + "Please use torch.utils._pytree.register_pytree_node instead.", + stacklevel=2, + ) + if to_str_fn is not None or maybe_from_str_fn is not None: warnings.warn( - "`to_str_fn` and `maybe_from_str_fn` is deprecated. " - "Please use `to_dumpable_context` and `from_dumpable_context` instead.", - FutureWarning, + "to_str_fn and maybe_from_str_fn is deprecated. " + "Please use to_dumpable_context and from_dumpable_context instead." ) _private_register_pytree_node( @@ -1452,20 +1451,14 @@ def treespec_pprint(treespec: TreeSpec) -> str: # TODO(angelayi): remove this function after OSS/internal stabilize -@deprecated( - "`pytree_to_str` is deprecated. Please use `treespec_dumps` instead.", - category=FutureWarning, -) def pytree_to_str(treespec: TreeSpec) -> str: + warnings.warn("pytree_to_str is deprecated. Please use treespec_dumps") return treespec_dumps(treespec) # TODO(angelayi): remove this function after OSS/internal stabilize -@deprecated( - "`str_to_pytree` is deprecated. Please use `treespec_loads` instead.", - category=FutureWarning, -) def str_to_pytree(json: str) -> TreeSpec: + warnings.warn("str_to_pytree is deprecated. Please use treespec_loads") return treespec_loads(json) diff --git a/torch/utils/data/backward_compatibility.py b/torch/utils/data/backward_compatibility.py index f51418265f41..be97f016a091 100644 --- a/torch/utils/data/backward_compatibility.py +++ b/torch/utils/data/backward_compatibility.py @@ -1,10 +1,5 @@ -from typing_extensions import deprecated as _deprecated +import warnings - -@_deprecated( - "Usage of `backward_compatibility.worker_init_fn` is deprecated " - "as `DataLoader` automatically applies sharding in every worker", - category=FutureWarning, -) def worker_init_fn(worker_id): - pass + warnings.warn("Usage of backward_compatibility.worker_init_fn is deprecated" + " as DataLoader automatically applies sharding in every worker") diff --git a/torch/utils/data/dataset.py b/torch/utils/data/dataset.py index b3cf9d92943d..554bf90d108b 100644 --- a/torch/utils/data/dataset.py +++ b/torch/utils/data/dataset.py @@ -14,7 +14,6 @@ TypeVar, Union, ) -from typing_extensions import deprecated # No 'default_generator' in torch/__init__.pyi from torch import default_generator, randperm @@ -349,11 +348,12 @@ def __getitem__(self, idx): return self.datasets[dataset_idx][sample_idx] @property - @deprecated( - "`cummulative_sizes` attribute is renamed to `cumulative_sizes`", - category=FutureWarning, - ) def cummulative_sizes(self): + warnings.warn( + "cummulative_sizes attribute is renamed to " "cumulative_sizes", + DeprecationWarning, + stacklevel=2, + ) return self.cumulative_sizes diff --git a/torch/utils/data/graph_settings.py b/torch/utils/data/graph_settings.py index 573069279201..4b42cc6065a7 100644 --- a/torch/utils/data/graph_settings.py +++ b/torch/utils/data/graph_settings.py @@ -2,7 +2,6 @@ import warnings from typing import Any, List, Optional, Set -from typing_extensions import deprecated import torch @@ -117,12 +116,11 @@ def apply_shuffle_settings(datapipe: DataPipe, shuffle: Optional[bool] = None) - return datapipe -@deprecated( - "`apply_shuffle_seed` is deprecated since 1.12 and will be removed in the future releases. " - "Please use `apply_random_seed` instead.", - category=FutureWarning, -) def apply_shuffle_seed(datapipe: DataPipe, rng: Any) -> DataPipe: + warnings.warn( + "`apply_shuffle_seed` is deprecated since 1.12 and will be removed in the future releases." + "\nPlease use `apply_random_seed` instead." + ) return apply_random_seed(datapipe, rng) From 806e6257f307b98c28311023a9038026aeda0694 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Fri, 31 May 2024 09:30:42 -0400 Subject: [PATCH 190/706] Unconditionally assign symbolic shapes as locals (#127486) Internal xref: https://fb.workplace.com/groups/1405155842844877/posts/8493858177307906 Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/127486 Approved by: https://github.com/albanD --- test/dynamo/test_unspec.py | 2 -- torch/_inductor/codegen/wrapper.py | 30 ++++++++---------------------- 2 files changed, 8 insertions(+), 24 deletions(-) diff --git a/test/dynamo/test_unspec.py b/test/dynamo/test_unspec.py index 317fd15195ba..d5fdc006949e 100644 --- a/test/dynamo/test_unspec.py +++ b/test/dynamo/test_unspec.py @@ -332,7 +332,6 @@ def fn(x): # specialization is allowed) opt_fn(x) - @unittest.expectedFailure def test_conv1d_symint_padding(self): kernel = torch.randn(1, 1, 4) @@ -341,7 +340,6 @@ def func(x): out = F.conv1d(x, kernel, padding=padding, stride=2) return out - # TODO: NameError: name 's1' is not defined when dynamic=True opt_func = torch.compile(func) x = torch.randn(1, 1, 175) diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 19bb7bf3c25e..0bf4814f80b1 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -28,12 +28,7 @@ from torch._dynamo.utils import counters, dynamo_timed from torch._inductor.codegen.multi_kernel import MultiKernelState -from torch.fx.experimental.symbolic_shapes import ( - ConvertIntKey, - DivideByKey, - free_unbacked_symbols, - SymTypes, -) +from torch.fx.experimental.symbolic_shapes import ConvertIntKey, DivideByKey, SymTypes from torch.fx.node import _get_qualified_name from torch.utils._sympy.singleton_int import SingletonInt from torch.utils._sympy.symbol import symbol_is_type, SymT @@ -864,7 +859,7 @@ def strideof(name): return f"{name}_stride" # Assign all symbolic shapes needed to local variables - needed = V.graph.sizevars.free_symbols() + bound_vars: Set[sympy.Symbol] = set() def is_expr(x): return isinstance(x[1], sympy.Expr) @@ -874,37 +869,28 @@ def is_expr(x): filter(lambda x: not is_expr(x), graph_inputs.items()) ) - def is_unbacked_symbol(s): - return isinstance(s, sympy.Symbol) and free_unbacked_symbols(s) - for name, shape in graph_inputs_expr: - shape = V.graph.sizevars.simplify(shape) # type: ignore[arg-type] - if (b := shape in needed) or is_unbacked_symbol(shape): - if b: - needed.remove(shape) # type: ignore[arg-type] + if isinstance(shape, sympy.Symbol) and shape not in bound_vars: code.writeline(f"{self.declare}{shape} = {name}{self.ending}") + bound_vars.add(shape) for name, value in graph_inputs_tensors: shapes = value.get_size() for dim, shape in enumerate(shapes): - shape = V.graph.sizevars.simplify(shape) # type: ignore[arg-type] - if (b := shape in needed) or is_unbacked_symbol(shape): - if b: - needed.remove(shape) # type: ignore[arg-type] + if isinstance(shape, sympy.Symbol) and shape not in bound_vars: code.writeline( f"{self.declare}{shape} = {sizeof(name)}[{dim}]{self.ending}" ) + bound_vars.add(shape) for name, value in graph_inputs_tensors: shapes = value.get_stride() for dim, shape in enumerate(shapes): - shape = V.graph.sizevars.simplify(shape) # type: ignore[arg-type] - if (b := shape in needed) or is_unbacked_symbol(shape): - if b: - needed.remove(shape) # type: ignore[arg-type] + if isinstance(shape, sympy.Symbol) and shape not in bound_vars: code.writeline( f"{self.declare}{shape} = {strideof(name)}[{dim}]{self.ending}" ) + bound_vars.add(shape) def ensure_size_computed(self, sym: sympy.Symbol): if isinstance(sym, sympy.Symbol) and symbol_is_type(sym, SymT.PRECOMPUTED_SIZE): From 923edef31c7f3e98a14625724f2019b1422dcb26 Mon Sep 17 00:00:00 2001 From: drisspg Date: Fri, 31 May 2024 20:09:08 +0000 Subject: [PATCH 191/706] FP8 rowwise scaling (#125204) # Summary This pull request introduces an fp8 row-scaling kernel as an optional implementation for `scaled_mm`. The kernel selection is based on the scaling tensors of the inputs. For inputs `x` and `y` of shape `[M, K]` and `[K, N]` respectively, the following conditions must be met: - `x`'s scale should be a 1-dimensional tensor of length `M`. - `y`'s scale should be a 1-dimensional tensor of length `N`. It's important to note that this kernel is not called "rowwise, columnwise" scaling because, although the scales for `y` are semantically along its columns, this implementation only supports the TN format. This means the scaling is along the faster-moving dimension, or the "row". The following two PRs were required to enable local builds: - [PR #126185](https://github.com/pytorch/pytorch/pull/126185) - [PR #125523](https://github.com/pytorch/pytorch/pull/125523) ### Todo We still do not build our Python wheels with this architecture. @ptrblck @malfet, should we replace `sm_90` with `sm_90a`? The NVRTC TMA shadowing feels wrong, but I a not sure the right way to spoof the symbol for this compilation unit: https://github.com/pytorch/pytorch/pull/125204/files#r1586986954 #### ifdef I tried to use : `#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12000 && \ defined(__CUDA_ARCH__) && __CUDA_ARCH__ > 900` to gate the building of the kernel. I was having a hell of a time with this.. so I am not really sure the right way to do this Kernel Credit: @jwfromm Pull Request resolved: https://github.com/pytorch/pytorch/pull/125204 Approved by: https://github.com/lw --- aten/src/ATen/CMakeLists.txt | 1 + aten/src/ATen/cuda/detail/LazyNVRTC.cpp | 37 ++ aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h | 15 +- aten/src/ATen/native/cuda/Blas.cpp | 113 +++- aten/src/ATen/native/cuda/RowwiseScaledMM.cu | 535 +++++++++++++++++++ aten/src/ATen/native/cuda/RowwiseScaledMM.h | 15 + test/test_matmul_cuda.py | 149 +++++- third_party/cutlass.BUILD | 14 +- 8 files changed, 854 insertions(+), 25 deletions(-) create mode 100644 aten/src/ATen/native/cuda/RowwiseScaledMM.cu create mode 100644 aten/src/ATen/native/cuda/RowwiseScaledMM.h diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 9fa7a1f2305b..696621eeca6f 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -472,6 +472,7 @@ endif() if(USE_CUDA AND NOT USE_ROCM) list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include) + list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/tools/util/include) if($ENV{ATEN_STATIC_CUDA}) list(APPEND ATen_CUDA_DEPENDENCY_LIBS ${CUDA_LIBRARIES} diff --git a/aten/src/ATen/cuda/detail/LazyNVRTC.cpp b/aten/src/ATen/cuda/detail/LazyNVRTC.cpp index 1b85e7776e22..75c503d48d51 100644 --- a/aten/src/ATen/cuda/detail/LazyNVRTC.cpp +++ b/aten/src/ATen/cuda/detail/LazyNVRTC.cpp @@ -170,6 +170,43 @@ CUDA_STUB3(cuLinkComplete, CUlinkState, void **, size_t *); CUDA_STUB3(cuFuncSetAttribute, CUfunction, CUfunction_attribute, int); CUDA_STUB3(cuFuncGetAttribute, int*, CUfunction_attribute, CUfunction); +#if defined(CUDA_VERSION) && CUDA_VERSION >= 12000 +CUresult CUDAAPI +cuTensorMapEncodeTiled( + CUtensorMap* tensorMap, + CUtensorMapDataType tensorDataType, + cuuint32_t tensorRank, + void* globalAddress, + const cuuint64_t* globalDim, + const cuuint64_t* globalStrides, + const cuuint32_t* boxDim, + const cuuint32_t* elementStrides, + CUtensorMapInterleave interleave, + CUtensorMapSwizzle swizzle, + CUtensorMapL2promotion l2Promotion, + CUtensorMapFloatOOBfill oobFill) { + auto fn = reinterpret_cast( + getCUDALibrary().sym(__func__)); + if (!fn) + throw std::runtime_error("Can't get cuTensorMapEncodeTiled"); + lazyNVRTC.cuTensorMapEncodeTiled = fn; + return fn( + tensorMap, + tensorDataType, + tensorRank, + globalAddress, + globalDim, + globalStrides, + boxDim, + elementStrides, + interleave, + swizzle, + l2Promotion, + oobFill); +} + +#endif + // Irregularly shaped functions CUresult CUDAAPI cuLaunchKernel(CUfunction f, unsigned int gridDimX, diff --git a/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h b/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h index 574b2c41c264..cb34d10db254 100644 --- a/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h +++ b/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h @@ -59,16 +59,25 @@ namespace at { namespace cuda { _(cuLinkAddData) \ _(cuLinkComplete) \ _(cuFuncSetAttribute) \ - _(cuFuncGetAttribute) + _(cuFuncGetAttribute) \ + +#if defined(CUDA_VERSION) && CUDA_VERSION >= 12000 +#define AT_FORALL_NVRTC_EXTENDED(_) \ + AT_FORALL_NVRTC_BASE(_) \ + _(cuTensorMapEncodeTiled) +#else +#define AT_FORALL_NVRTC_EXTENDED(_) \ + AT_FORALL_NVRTC_BASE(_) +#endif #if defined(CUDA_VERSION) && CUDA_VERSION >= 11010 #define AT_FORALL_NVRTC(_) \ - AT_FORALL_NVRTC_BASE(_) \ + AT_FORALL_NVRTC_EXTENDED(_) \ _(nvrtcGetCUBINSize) \ _(nvrtcGetCUBIN) #else #define AT_FORALL_NVRTC(_) \ - AT_FORALL_NVRTC_BASE(_) + AT_FORALL_NVRTC_EXTENDED(_) #endif #else diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index 84c59a4fd0d7..ed59b47349cc 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -1,3 +1,7 @@ +#include +#include +#include +#include #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include #include @@ -10,6 +14,7 @@ #include #include #include +#include #ifndef AT_PER_OPERATOR_HEADERS #include @@ -819,24 +824,97 @@ static bool _scaled_mm_allowed_device() { #endif } +namespace{ + +enum class ScalingType { + TensorWise, + RowWise, + Error +}; + +// Validates the scale tensors to scaled_mm +// And returns the type of scaling/which kernel to use +ScalingType get_scaling_type( + const c10::optional& scale_a, + const c10::optional& scale_b, + int64_t dim_m, + int64_t dim_n) { + TORCH_CHECK( + scale_a.has_value() == scale_b.has_value(), + "Both scale_a and scale_b must be present or absent."); + + if (scale_a.has_value()) { + // Both Per-Tensor and Row-wise scaling expect fp32 tensors + TORCH_CHECK( + scale_a->scalar_type() == kFloat && scale_b->scalar_type() == kFloat, + "Both scale_a and scale_b must be float (fp32) tensors."); + + // Check the singluar scale case for per-tensor scaling + if (scale_a->numel() == 1 && scale_b->numel() == 1) { + return ScalingType::TensorWise; + } else if (scale_a->dim() == 1 && scale_a->size(0) == dim_m) { +// Check the per-row scaling case +#if !defined(USE_ROCM) && !defined(_MSC_VER) || \ + (defined(USE_ROCM) && ROCM_VERSION >= 60000) + TORCH_CHECK( + scale_a->dim() == 1 && scale_b->dim() == 1, + "Both scale_a and scale_b must be 1-dimensional tensors"); + TORCH_CHECK( + scale_b->size(0) == dim_n, + "For row-wise scaling, scale_b must have size ", + dim_n, + " but got ", + scale_b->size(0), + "."); + TORCH_CHECK( + scale_a->is_contiguous() && scale_b->is_contiguous(), + "Both scale_a and scale_b must be contiguous."); + return ScalingType::RowWise; +#else + TORCH_CHECK(false, "Per-row scaling is not supported for this platform!"); + return ScalingType::Error; +#endif // !defined(USE_ROCM) && !defined(_MSC_VER) || (defined(USE_ROCM) && + // ROCM_VERSION >= 60000) + } else { + TORCH_CHECK( + false, + "For row-wise scaling, scale_a must be size ", + dim_m, + " but got ", + scale_a->numel(), + " and scale_b must be size ", + dim_n, + " but got ", + scale_b->numel(), + "."); + // Unreachable + return ScalingType::RowWise; + } + } + return ScalingType::Error; +} + +} // namespace + // Computes matrix multiply + bias while applying scaling to input and output matrices and computes amax // Scales are only applicable when matrices are of Float8 type and assumbed to be equal to 1.0 by default. // If output matrix type is 16 or 32-bit type, neither scale_result is applied nor amax is computed. // Known limitations: // - Only works if mat1 is row-major and mat2 is column-major // - Only works if matrices sizes are divisible by 32 -// +// - If 1-dimensional tensors are used then scale_a should be size = mat1.size(0) +// and scale_b should have size = to mat2.size(1) // Arguments: // - `mat1`: the first operand of the matrix multiply, can be type `torch.float8_e4m3fn` or `torch.float8_e5m2` // - `mat2`: the second operand of the matrix multiply, can be type `torch.float8_e4m3fn` or `torch.float8_e5m2` // - `bias`: the bias, can be type `torch.float16` or `torch.bfloat16` // - `out_dtype`: the output dtype, can either be a float8 or a higher precision floating point type -// - `scale_a`: a scalar tensor with the inverse scale of `mat1`, only needed if `mat1` is a float8 type -// - `scale_b`: a scalar tensor with the inverse scale of `mat2`, only needed if `mat2` is a float8 type -// - `scale_result`: a scalar tensor with the scale of the output, only set if the output is a float8 type +// - `scale_a`: a scalar or 1-dimensional tensor with the inverse scale of `mat1`, only needed if `mat1` is a float8 type +// - `scale_b`: a scalar or 1-dimensional tensor with the inverse scale of `mat2`, only needed if `mat2` is a float8 type +// - `scale_result`: a scalar tensor with the scale of the output, only utilized if the output is a float8 type // - `use_fast_accum`: if true, enables fast float8 accumulation // - `out`: a reference to the output tensor -// - `amax`: a reference to the amax tensor of the output, only needed if the output is a float8 type and will be updated inplace +// - `amax`: a reference to the amax tensor of the output, only mutated if the output is a float8 type and will be updated inplace std::tuple _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, @@ -855,10 +933,11 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, TORCH_CHECK( mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (", mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")"); - TORCH_CHECK(!scale_a || (scale_a->numel() == 1 && scale_a->scalar_type() == kFloat), - "scale_a must be float scalar"); - TORCH_CHECK(!scale_b || (scale_b->numel() == 1 && scale_b->scalar_type() == kFloat), - "scale_b must be a float scalar"); + + // Check what type of scaling we are doing based on inputs + ScalingType scaling_choice = get_scaling_type(scale_a, scale_b, mat1.size(0), mat2.size(1)); + TORCH_INTERNAL_ASSERT(scaling_choice != ScalingType::Error, "Scaling type not supported"); + TORCH_CHECK(!scale_result || (scale_result->numel() == 1 && scale_result->scalar_type() == kFloat), "scale_result must be a float scalar"); TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1], @@ -901,12 +980,26 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, {scale_result_, "scale_result", 7}}; checkAllSameGPU(__func__, targs); } - + // Validation checks have passed lets resize the output to actual size IntArrayRef mat1_sizes = mat1.sizes(); IntArrayRef mat2_sizes = mat2.sizes(); at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]}); at::native::resize_output(amax, {}); + // We are doing row-wise scaling + if (scaling_choice == ScalingType::RowWise) { + TORCH_CHECK(out.dtype() == kBFloat16, "Only bf16 high precsion output types are supported for row-wise scaling."); + at::cuda::detail::f8f8bf16_rowwise( + mat1, + mat2, + scale_a.value(), + scale_b.value(), + bias, + use_fast_accum, + out); + return {out, amax}; + } + cublasCommonArgs args(mat1, mat2, out); const auto out_dtype_ = args.result->scalar_type(); TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt"); diff --git a/aten/src/ATen/native/cuda/RowwiseScaledMM.cu b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu new file mode 100644 index 000000000000..14eb8f5fbf80 --- /dev/null +++ b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu @@ -0,0 +1,535 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include + +// Determine if the architecture supports rowwise scaled mm +#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12000 + +#define BUILD_ROWWISE_FP8_KERNEL +#endif + +#if defined(BUILD_ROWWISE_FP8_KERNEL) + +// We are going to override the cuTensorMapEncodeTiled driver api with our lazy loader +static CUresult CUDAAPI nvrtc_cuTensorMapEncodeTiled( + CUtensorMap* tensorMap, + CUtensorMapDataType tensorDataType, + cuuint32_t tensorRank, + void* globalAddress, + const cuuint64_t* globalDim, + const cuuint64_t* globalStrides, + const cuuint32_t* boxDim, + const cuuint32_t* elementStrides, + CUtensorMapInterleave interleave, + CUtensorMapSwizzle swizzle, + CUtensorMapL2promotion l2Promotion, + CUtensorMapFloatOOBfill oobFill) { + return at::globalContext().getNVRTC().cuTensorMapEncodeTiled( + tensorMap, + tensorDataType, + tensorRank, + globalAddress, + globalDim, + globalStrides, + boxDim, + elementStrides, + interleave, + swizzle, + l2Promotion, + oobFill); +} + + +#include +#include +#include +#include +#include +#include +#include + +// Rename the global function symbol +#define cuTensorMapEncodeTiled nvrtc_cuTensorMapEncodeTiled +#include +#undef cuTensorMapEncodeTiled +// Set everything back to normal + +#include +#include +#include + +#include +#include +#include +#include + + +namespace { +// Cutlass rowwise kernel +template < + int TB_M, + int TB_N, + int TB_K, + int TBS_M, + int TBS_N, + int TBS_K, + bool PONG, + bool FAST_ACCUM, + bool USE_BIAS, + typename INPUT_DTYPE, + typename BIAS_DTYPE> +void f8f8bf16_rowwise_impl( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor x_scale, + at::Tensor w_scale, + c10::optional bias, + at::Tensor out) { + int M = XQ.size(0); + int N = WQ.size(1); + int K = XQ.size(1); + + TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous()); + TORCH_CHECK( + WQ.is_cuda() && WQ.ndimension() == 2 && WQ.stride(1) == WQ.size(0) && + WQ.stride(0) == 1); + + // auto Y = at::empty({M, N}, XQ.options().dtype(at::kBFloat16)); + + using ElementInputA = INPUT_DTYPE; + using LayoutInputA = cutlass::layout::RowMajor; + constexpr int AlignmentInputA = 16 / sizeof(ElementInputA); + + using ElementInputB = cutlass::float_e4m3_t; + using LayoutInputB = cutlass::layout::ColumnMajor; + constexpr int AlignmentInputB = 16 / sizeof(ElementInputB); + + using ElementBias = BIAS_DTYPE; + + using ElementOutput = cutlass::bfloat16_t; + using LayoutOutput = cutlass::layout::RowMajor; + constexpr int AlignmentOutput = 16 / sizeof(ElementOutput); + + using ElementAccumulator = float; + using ElementComputeEpilogue = float; + using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that + // supports the intended feature + using OperatorClass = cutlass::arch::OpClassTensorOp; + using TileShape = cute::Shape< + cute::Int, + cute::Int, + cute::Int>; // Threadblock-level + // tile size + using ClusterShape = cute::Shape< + cute::Int, + cute::Int, + cute::Int>; // Shape of the + // threadblocks in a + // cluster + using KernelSchedule = cutlass::gemm::collective:: + KernelScheduleAuto; // Kernel to launch based on the default setting in + // the Collective Builder + + // Implement rowwise scaling epilogue. + using XScale = cutlass::epilogue::fusion::Sm90ColBroadcast< + 0, + TileShape, + ElementComputeEpilogue, + cute::Stride, cute::Int<0>, cute::Int<0>>>; + + using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast< + PONG ? 2 : 1, + TileShape, + ElementComputeEpilogue, + cute::Stride, cute::Int<1>, cute::Int<0>>>; + + using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast< + PONG ? 2 : 1, + TileShape, + ElementBias, + cute::Stride, cute::Int<1>, cute::Int<0>>>; + + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, + ElementComputeEpilogue, // First stage output type. + ElementComputeEpilogue, // First stage input types. + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, + cute::conditional_t< // Second stage output type. + USE_BIAS, + ElementBias, + ElementOutput>, + ElementComputeEpilogue, // Second stage input types. + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute1 = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeBias = cutlass::epilogue::fusion::Sm90Compute< + cutlass::plus, + ElementOutput, // Final (optional) stage output type. + ElementBias, // Final stage input types. + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeBias = + cutlass::epilogue::fusion::Sm90EVT; + + using EpilogueEVT = + cute::conditional_t; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, + cutlass::arch::OpClassTensorOp, + TileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementComputeEpilogue, + ElementOutput, + LayoutOutput, + AlignmentOutput, + ElementOutput, + LayoutOutput, + AlignmentOutput, + cutlass::epilogue::TmaWarpSpecialized, + EpilogueEVT>::CollectiveOp; + + using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecialized; + using PongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; + using FastDefaultSchedule = + cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using FastPongSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using SlowAccum = cute::conditional_t; + using FastAccum = + cute::conditional_t; + using MainLoopSchedule = + cute::conditional_t; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + ElementInputA, + LayoutInputA, + AlignmentInputA, + ElementInputB, + LayoutInputB, + AlignmentInputB, + ElementAccumulator, + TileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainLoopSchedule>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using StrideInputA = typename Gemm::GemmKernel::StrideA; + using StrideInputB = typename Gemm::GemmKernel::StrideB; + using StrideOutput = typename Gemm::GemmKernel::StrideC; + + StrideInputA stride_a = cutlass::make_cute_packed_stride( + StrideInputA{}, cute::make_shape(M, K, 1)); + StrideInputB stride_b = cutlass::make_cute_packed_stride( + StrideInputB{}, cute::make_shape(N, K, 1)); + StrideOutput stride_output = cutlass::make_cute_packed_stride( + StrideOutput{}, cute::make_shape(M, N, 1)); + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K}, + {reinterpret_cast(XQ.data_ptr()), + stride_a, + reinterpret_cast(WQ.data_ptr()), + stride_b}, + {{}, // Epilogue thread we populate below. + (ElementOutput*)out.data_ptr(), + stride_output, + (ElementOutput*)out.data_ptr(), + stride_output}}; + + if constexpr (USE_BIAS) { + arguments.epilogue.thread = { + {reinterpret_cast(bias.value().data_ptr())}, // bias + // compute_1 + { + {reinterpret_cast( + x_scale.data_ptr())}, // x_scale + // compute_0 + { + {reinterpret_cast( + w_scale.data_ptr())}, // w_scale + {}, // Accumulator + {} // Multiplies + }, + {}, // Multiplies + }, + {}, // Plus + }; + } else { + arguments.epilogue.thread = { + {reinterpret_cast( + x_scale.data_ptr())}, // x_scale + // compute_0 + { + {reinterpret_cast( + w_scale.data_ptr())}, // w_scale + {}, // Accumulator + {} // Multiplies + }, + {}, // Multiplies + }; + } + + Gemm gemm; + + // Using the arguments, query for extra workspace required for matrix + // multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check the problem size is supported or not + cutlass::Status status = gemm.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot implement"); + } + + // Initialize CUTLASS kernel with arguments and workspace pointer + status = gemm.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot initialize"); + } + + status = gemm(at::cuda::getCurrentCUDAStream()); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error( + std::string("cutlass cannot run") + + cutlass::cutlassGetStatusString(status)); + } + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +// FP8 Rowwise Cutlass kernel dispatch. +enum class KernelMode { Small, Large, Default }; + +KernelMode get_kernel_mode(at::Tensor XQ, at::Tensor WQ) { + auto M = XQ.size(0); + auto K = XQ.size(1); + auto N = WQ.size(0); + // Use a large kernel if at least two shapes are large.... + bool use_large_kernel = + ((M >= 2048 && K >= 2048) || (M >= 2048 && N >= 2048) || + (K >= 2048 && N >= 2048)); + if (M <= 128 || N <= 128) { + return KernelMode::Small; + } else if (use_large_kernel) { + return KernelMode::Large; + } else { + return KernelMode::Default; + } +} + +template +void dispatch_fp8_rowwise_kernel( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + c10::optional bias, + at::Tensor out) { + KernelMode kernel = get_kernel_mode(XQ, WQ); + if (kernel == KernelMode::Small) { + return f8f8bf16_rowwise_impl< + 64, + 128, + 128, + 2, + 1, + 1, + false, + FastAccum, + UseBias, + InputDType, + BiasDType>(XQ, WQ, x_scale, w_scale, bias, out); + } else if (kernel == KernelMode::Large) { + return f8f8bf16_rowwise_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + FastAccum, + UseBias, + InputDType, + BiasDType>(XQ, WQ, x_scale, w_scale, bias, out); + } else { + return f8f8bf16_rowwise_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + FastAccum, + UseBias, + InputDType, + BiasDType>(XQ, WQ, x_scale, w_scale, bias, out); + } +} + +} // namespace + +#endif // !defined(USE_ROCM) + +namespace at::cuda::detail { +void f8f8bf16_rowwise( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor x_scale, // FP32 + at::Tensor w_scale, // FP32 + c10::optional bias, // BF16 + bool use_fast_accum, + at::Tensor& out) { +#if defined(BUILD_ROWWISE_FP8_KERNEL) + // Check datatypes. + TORCH_CHECK( + x_scale.dtype() == at::kFloat && w_scale.dtype() == at::kFloat, + "Scale tensors must be float32."); + if (bias.has_value()) { + TORCH_CHECK( + bias.value().dtype() == at::kFloat || + bias.value().dtype() == at::kBFloat16, + "Bias type must be bfloat16 or float32 if provided."); + } + // Extract problem size. + int M = XQ.size(0); + int N = WQ.size(1); + int K = XQ.size(1); + + bool use_bias = bias.has_value(); + bool bf16_bias = use_bias && bias.value().dtype() == at::kBFloat16; + + // Templatize based on input dtype. + bool use_e5m2 = XQ.dtype() == at::kFloat8_e5m2; + TORCH_CHECK(WQ.dtype() == at::kFloat8_e4m3fn, "For row-wise scaling the second input is required to be a float8_e4m3fn dtype."); + + if (use_bias) { + if (bf16_bias) { + if (use_fast_accum) { + if (use_e5m2) { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e5m2_t, + true, + true, + cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, out); + } else { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e4m3_t, + true, + true, + cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, out); + } + } else { + if (use_e5m2) { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e5m2_t, + false, + true, + cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, out); + } else { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e4m3_t, + false, + true, + cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, out); + } + } + } else { + if (use_fast_accum) { + if (use_e5m2) { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e5m2_t, + true, + true, + float>(XQ, WQ, x_scale, w_scale, bias, out); + } else { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e4m3_t, + true, + true, + float>(XQ, WQ, x_scale, w_scale, bias, out); + } + } else { + if (use_e5m2) { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e5m2_t, + false, + true, + float>(XQ, WQ, x_scale, w_scale, bias, out); + } else { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e4m3_t, + false, + true, + float>(XQ, WQ, x_scale, w_scale, bias, out); + } + } + } + } else { + if (use_fast_accum) { + if (use_e5m2) { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e5m2_t, + true, + false, + float>(XQ, WQ, x_scale, w_scale, bias, out); + } else { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e4m3_t, + true, + false, + float>(XQ, WQ, x_scale, w_scale, bias, out); + } + } else { + if (use_e5m2) { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e5m2_t, + false, + false, + float>(XQ, WQ, x_scale, w_scale, bias, out); + } else { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e4m3_t, + false, + false, + float>(XQ, WQ, x_scale, w_scale, bias, out); + } + } + } +#else // BUILD_ROWWISE_FP8_KERNEL + TORCH_CHECK(false, "Rowwise scaling is not currenlty supported on your device"); +#endif +} + +} // namespace at::cuda::detail diff --git a/aten/src/ATen/native/cuda/RowwiseScaledMM.h b/aten/src/ATen/native/cuda/RowwiseScaledMM.h new file mode 100644 index 000000000000..4d9054108c85 --- /dev/null +++ b/aten/src/ATen/native/cuda/RowwiseScaledMM.h @@ -0,0 +1,15 @@ +#pragma once +#include +#include + + +namespace at::cuda::detail { +TORCH_API void f8f8bf16_rowwise( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor x_scale, // FP32 + at::Tensor w_scale, // FP32 + c10::optional bias, // BF16 + bool use_fast_accum, + at::Tensor& out); +} // at::cuda::detail diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index a5c583580848..7793e7411e88 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -204,7 +204,6 @@ def _expand_to_batch(t: torch.Tensor): self.assertEqual(out1_gpu, out2_gpu[0]) - f8_msg = "FP8 is only supported on H100+ and sm_89 and MI300+ devices" if torch.version.hip: @@ -256,8 +255,12 @@ def amax_to_scale( scale.copy_(res) return scale -def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype): - amax = torch.max(torch.abs(x)) +def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype, dim=None): + if dim is None: + amax = torch.max(torch.abs(x)) + else: + amax = torch.max(torch.abs(x), dim=dim).values + return amax_to_scale(amax, float8_dtype, x.dtype) def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype): @@ -316,7 +319,6 @@ def mm_float8( def to_fp8_saturated( x: torch.Tensor, - x_scale: torch.tensor, fp8_dtype: torch.dtype ): """ @@ -339,8 +341,6 @@ def to_fp8_saturated( of a tensor has a maximum value of `amax1`, and the current amax value is `amax2`, where `amax1 < amax2`. """ - x_scaled = x * x_scale - if fp8_dtype == e4m3_type: x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS) elif fp8_dtype == e5m2_type: @@ -353,8 +353,6 @@ def to_fp8_saturated( @unittest.skipIf(not torch.cuda.is_available(), "CUDA not found") class TestFP8MatmulCuda(TestCase): - - @unittest.skipIf(not scaled_mm_supported_device(), f8_msg) def _test_tautological_mm(self, device: str = "cuda", x_dtype: torch.dtype = e4m3_type, @@ -418,8 +416,8 @@ def test_scaled_mm_vs_emulated(self, base_dtype): x_scale = tensor_to_scale(x, input_dtype).float() y_scale = tensor_to_scale(y, input_dtype).float() - x_fp8 = to_fp8_saturated(x, x_scale, e4m3_type) - y_fp8 = to_fp8_saturated(y, y_scale, e4m3_type) + x_fp8 = to_fp8_saturated(x * x_scale, e4m3_type) + y_fp8 = to_fp8_saturated(y * y_scale, e4m3_type) # Calculate actual F8 mm out_scaled_mm, output_amax_scaled = mm_float8( @@ -526,6 +524,137 @@ def test_float8_scale_fast_accum(self, device) -> None: out_fp8_s, amax_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, use_fast_accum=True) self.assertEqual(out_fp8, out_fp8_s) + @unittest.skipIf(not scaled_mm_supported_device(), f8_msg) + @skipIfRocm() + @parametrize("use_fast_accum", [True, False]) + def test_float8_rowwise_scaling_sanity(self, device, use_fast_accum: bool) -> None: + M, K, N = (1024, 512, 2048) + fill_value = 0.5 + x = torch.full((M, K), fill_value, device=device) + y = torch.full((N, K), fill_value, device=device) + + x_scales = torch.ones(x.shape[0], device=device, dtype=torch.float32) + y_scales = torch.ones(y.shape[0], device=device, dtype=torch.float32) + + x_fp8 = x.to(torch.float8_e4m3fn) + y_fp8 = y.to(torch.float8_e4m3fn).t() + + out_fp8, _ = torch._scaled_mm( + x_fp8, + y_fp8, + scale_a=x_scales, + scale_b=y_scales, + out_dtype=torch.bfloat16, + use_fast_accum=use_fast_accum, + ) + self.assertEqual( + out_fp8.to(torch.float32), torch.full((M, N), K * (fill_value**2), device=device) + ) + + @unittest.skipIf(not scaled_mm_supported_device(), f8_msg) + @skipIfRocm() + def test_float8_error_messages(self, device) -> None: + M, K, N = (1024, 512, 2048) + fill_value = 0.5 + x = torch.full((M, K), fill_value, device=device) + y = torch.full((N, K), fill_value, device=device) + + x_fp8 = x.to(torch.float8_e4m3fn) + y_fp8 = y.to(torch.float8_e4m3fn).t() + + with self.assertRaisesRegex( + RuntimeError, + "For row-wise scaling, scale_a must be size 1024 but got 1 and scale_b must be size 2048 but got 2", + ): + torch._scaled_mm( + x_fp8, + y_fp8, + scale_a=torch.ones((), device="cuda"), + scale_b=torch.ones((2), device="cuda"), + out_dtype=torch.bfloat16, + ) + + with self.assertRaisesRegex( + RuntimeError, + "For row-wise scaling, scale_b must have size 2048 but got 2049.", + ): + torch._scaled_mm( + x_fp8, + y_fp8, + scale_a=torch.ones((M), device="cuda"), + scale_b=torch.ones((N + 1), device="cuda"), + out_dtype=torch.bfloat16, + ) + with self.assertRaisesRegex( + RuntimeError, + "Both scale_a and scale_b must be 1-dimensional tensors", + ): + torch._scaled_mm( + x_fp8, + y_fp8, + scale_a=torch.ones((M), device="cuda"), + scale_b=torch.ones((N, N), device="cuda"), + out_dtype=torch.bfloat16, + ) + + with self.assertRaisesRegex( + RuntimeError, + "Both scale_a and scale_b must be contiguous.", + ): + torch._scaled_mm( + x_fp8, + y_fp8, + scale_a=torch.ones((M), device="cuda"), + scale_b=torch.ones((N * 2), device="cuda")[::2], + out_dtype=torch.bfloat16, + ) + + with self.assertRaisesRegex( + RuntimeError, + "For row-wise scaling the second input is required to be a float8_e4m3fn dtype.", + ): + torch._scaled_mm( + x_fp8, + y_fp8.to(torch.float8_e5m2), + scale_a=torch.ones((M), device="cuda"), + scale_b=torch.ones((N), device="cuda"), + out_dtype=torch.bfloat16, + ) + + @unittest.skipIf(not scaled_mm_supported_device(), f8_msg) + @skipIfRocm() + @parametrize("base_dtype", [torch.bfloat16]) + def test_scaled_mm_vs_emulated_row_wise(self, base_dtype): + torch.manual_seed(42) + input_dtype = e4m3_type + output_dtype = base_dtype + + x = torch.randn(16, 16, device="cuda", dtype=base_dtype) + y = torch.randn(32, 16, device="cuda", dtype=base_dtype).t() + + x_scales = tensor_to_scale(x, input_dtype, dim=1).float() + y_scales = tensor_to_scale(y, input_dtype, dim=0).float() + + x_fp8 = to_fp8_saturated(x * x_scales[:, None], e4m3_type) + y_fp8 = to_fp8_saturated(y * y_scales[None, :], e4m3_type) + + # Calculate actual F8 mm + out_scaled_mm, _ = mm_float8( + x_fp8, y_fp8, a_scale=x_scales, b_scale=y_scales, output_dtype=output_dtype + ) + + # Calculate emulated F8 mm + out_emulated, _ = mm_float8_emulated( + x_fp8, x_scales[:, None], y_fp8, y_scales[None, :], output_dtype + ) + + if base_dtype in {torch.bfloat16, torch.float16}: + atol, rtol = 7e-2, 7e-2 + else: + atol, rtol = 2e-3, 2e-3 + + torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol) + @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") @unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions") diff --git a/third_party/cutlass.BUILD b/third_party/cutlass.BUILD index e712d59597cc..e3e7b7b288e7 100644 --- a/third_party/cutlass.BUILD +++ b/third_party/cutlass.BUILD @@ -5,7 +5,17 @@ load("@rules_cc//cc:defs.bzl", "cc_library") cc_library( name = "cutlass", - hdrs = glob(["include/**/*.h", "include/**/*.hpp"]), - includes = ["include/"], + hdrs = glob([ + "include/**/*.h", + "include/**/*.hpp", + "include/**/*.inl", + "tools/util/include/**/*.h", + "tools/util/include/**/*.hpp", + "tools/util/include/**/*.inl", + ]), + includes = [ + "include/", + "tools/util/include/", + ], visibility = ["//visibility:public"], ) From 6bfc6e08759cf1fd7cf89916124285bf131b7168 Mon Sep 17 00:00:00 2001 From: albanD Date: Fri, 31 May 2024 20:48:15 +0000 Subject: [PATCH 192/706] Add back private function torch.cuda.amp.autocast_mode._cast (#127433) This is unfortunately used in a few places in the wild: https://github.com/search?q=torch.cuda.amp.autocast_mode._cast&type=code Pull Request resolved: https://github.com/pytorch/pytorch/pull/127433 Approved by: https://github.com/zou3519, https://github.com/guangyey --- torch/cuda/amp/autocast_mode.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/torch/cuda/amp/autocast_mode.py b/torch/cuda/amp/autocast_mode.py index e953d20cb2a5..79d7c3cc1344 100644 --- a/torch/cuda/amp/autocast_mode.py +++ b/torch/cuda/amp/autocast_mode.py @@ -49,6 +49,11 @@ def __call__(self, func): return super().__call__(func) +# Preserved only for BC reasons +def _cast(value, dtype): + return torch.amp.autocast_mode._cast(value, "cuda", dtype) + + def custom_fwd(fwd=None, *, cast_inputs=None): """ ``torch.cuda.amp.custom_fwd(args...)`` is deprecated. Please use From f33beb767d04ad00aecbcf16690e786eb93ebdd8 Mon Sep 17 00:00:00 2001 From: David Berard Date: Fri, 31 May 2024 00:07:06 +0000 Subject: [PATCH 193/706] [NestedTensor] Use maybe_mark_dynamic instead of mark_dynamic (#127453) Fixes #127097 **TL;DR**: dimensions marked with mark_dynamic can result in assertion failures if the marked-dynamic dimensions get specialized. In NJT, we don't care _that_ much that a dimension is marked as dynamic. So instead, mark with `maybe_mark_dynamic` which suggests that a dimension should be dynamic, but doesn't fail if the dimension gets specialized. **Background**: NJT marks the values tensor as dynamic: https://github.com/pytorch/pytorch/blob/49ad90349d57c35ab83f40c28d8b18caefb416d1/torch/nested/_internal/nested_tensor.py#L122 It does this for two reasons: 1. **Conceptual**: We know that this dimension _should_ be dynamic; it's a nested tensor, so the sequence lengths will _probably_ vary between batches in the common case. Therefore, we should compile it as dynamic to prevent needing a recompile to trigger automatic dynamic shapes. 2. **Implementation detail**: Right now we run into issues with torch.compile / tensor_unflatten / other details when the dimensions are not marked as dynamic. We have some attempts to remove this (e.g. https://github.com/pytorch/pytorch/pull/126563) but while testing this I wasn't able to get all tests to pass, so there could be potential regressions here if we removed the mark_dynamic. **Justification for this change** 1. **Conceptual**: AFAIK, we don't care enough about the dynamism of this dimension to error out if we specialize. We'd prefer that we don't have to recompile to get automatic dynamic shapes, but it's also better to not have this issue (and not to force the user to go hunt down all the other equivalent shapes to mark them as dynamic as well). This solution allows us to suggest the dynamism but not force it. 2. **Implementation detail**: This still marks the dimension as symbolic at the beginning of dynamo tracing, so we will (probably) avoid a lot of the issues we run into when we completely remove the `mark_dynamic` decorators. Differential Revision: [D57933779](https://our.internmc.facebook.com/intern/diff/D57933779) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127453 Approved by: https://github.com/soulitzer, https://github.com/YuqingJ --- test/test_nestedtensor.py | 66 +++++++++++++++++++++++++ torch/nested/_internal/nested_tensor.py | 4 +- 2 files changed, 68 insertions(+), 2 deletions(-) diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 597180129f72..d369135a6e52 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -10,6 +10,8 @@ import numpy as np import torch +import torch._dynamo +import torch._dynamo.testing import torch.nn import torch.nn.functional as F from torch.testing._internal.common_cuda import ( @@ -4008,6 +4010,70 @@ def check_size(nt1, nt2, nt3, nt4): nt1_t, nt2_t, nt3_t, nt4_t = (x.transpose(1, 2) for x in (nt1, nt2, nt3, nt4)) check_size(nt1_t, nt2_t, nt3_t, nt4_t) + @skipIfTorchDynamo("compiles internally") + @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") + @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") + def test_specialize_dynamic_shape(self, device): + values = torch.randn((18, 16), device=device) + offsets = torch.tensor([0, 2, 3, 6, 15, 18], device=device) + like_values = torch.randn_like(values) + + # this marks values as dynamic + nt = torch.nested.nested_tensor_from_jagged(values, offsets) + + def fn(values, same_size): + # here, the dynamic shape is specialized by same_size's shape + # https://github.com/pytorch/pytorch/issues/127097 + # make sure this doesn't error out in torch.compile + return values + same_size + + self.assertEqual( + fn(values, like_values), + torch.compile(fn)(values, like_values), + ) + + @skipIfTorchDynamo("compiles internally") + @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") + @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") + def test_specialize_dynamic_shape_recompile(self, device): + def generate_inp(total_len): + values = torch.randn((total_len, 16), device=device) + offsets = torch.tensor([0, 2, 3, 6, 15, total_len], device=device) + like_values = torch.randn_like(values) + return values, offsets, like_values + + def check_results(ref_fn, res_fn, args): + values, offsets, like_values = args + # this may add dynamic shape markings + # goal of this test is to make sure that whatever markings are there, + # we eventually stop recompiling as shape changes. + nt = torch.nested.nested_tensor_from_jagged(values, offsets) + + self.assertEqual( + ref_fn(values, like_values), + res_fn(values, like_values), + ) + + + def fn(values, same_size): + return values + same_size + + compile_counter = torch._dynamo.testing.CompileCounter() + + compiled_fn = torch._dynamo.optimize(compile_counter, nopython=True)(fn) + check_results(fn, compiled_fn, generate_inp(18)) + self.assertEqual(compile_counter.frame_count, 1) + + check_results(fn, compiled_fn, generate_inp(19)) + # we'll probably recompile here with dynamic shapes - it's okay if not though. + frame_count_2 = compile_counter.frame_count + self.assertIn(frame_count_2, [1, 2]) + + # make sure that by now we've already compiled with dynamic shapes, so additional + # shapes should not trigger additional recompiles. + check_results(fn, compiled_fn, generate_inp(20)) + self.assertEqual(compile_counter.frame_count, frame_count_2) + # Doesn't work until we have real views @xfailIfTorchDynamo # Note 1: Math fallback doesn't work with bfloat16 on CUDA diff --git a/torch/nested/_internal/nested_tensor.py b/torch/nested/_internal/nested_tensor.py index 5cc6b1c75d7a..5ef8983a8393 100644 --- a/torch/nested/_internal/nested_tensor.py +++ b/torch/nested/_internal/nested_tensor.py @@ -118,8 +118,8 @@ def __init__(self, values, offsets, *, lengths=None, **kwargs): self._metadata_cache = kwargs.get("_metadata_cache") or {} # collapsed ragged dim must always be dynamic - torch._dynamo.mark_dynamic(self, self._ragged_idx) - torch._dynamo.mark_dynamic(self._values, self._ragged_idx - 1) + torch._dynamo.maybe_mark_dynamic(self, self._ragged_idx) + torch._dynamo.maybe_mark_dynamic(self._values, self._ragged_idx - 1) def values(self): # dispatch to get proper view relationship From bb1468d50660a7c3c1c635925688f406e1d7bd5f Mon Sep 17 00:00:00 2001 From: Lucas Pasqualin Date: Fri, 31 May 2024 21:59:07 +0000 Subject: [PATCH 194/706] Updates state dict in state dict loader (#127617) Fixes #125096 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127617 Approved by: https://github.com/Skylion007, https://github.com/fegin --- torch/distributed/checkpoint/state_dict_loader.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch/distributed/checkpoint/state_dict_loader.py b/torch/distributed/checkpoint/state_dict_loader.py index b7e1337e6c4f..df3cc945832c 100644 --- a/torch/distributed/checkpoint/state_dict_loader.py +++ b/torch/distributed/checkpoint/state_dict_loader.py @@ -175,6 +175,7 @@ def load( elem = state_dict[key] if isinstance(elem, Stateful): elem.load_state_dict(statetful_sd[key]) + state_dict[key] = statetful_sd[key] def _load_state_dict( From 02248b73eb3c0bcd8a114b09092f2073c64a21ea Mon Sep 17 00:00:00 2001 From: Zain Rizvi Date: Fri, 31 May 2024 22:24:41 +0000 Subject: [PATCH 195/706] [EZ] Port over all test-infra scale configs to lf runners (#127645) Follow up to https://github.com/pytorch/pytorch/pull/127578 Since GPU builds seem to be working correctly, porting over all remaining scale configs from [the org-wide scale config file](https://github.com/pytorch/test-infra/blob/main/.github/scale-config.yml) The naming convention here is all temporary. We'll figure out something better before completing the migration Pull Request resolved: https://github.com/pytorch/pytorch/pull/127645 Approved by: https://github.com/malfet --- .github/lf-canary-scale-config.yml | 130 ++++++++++++++++++++++++++++- .github/lf-scale-config.yml | 130 ++++++++++++++++++++++++++++- 2 files changed, 258 insertions(+), 2 deletions(-) diff --git a/.github/lf-canary-scale-config.yml b/.github/lf-canary-scale-config.yml index 628444237e52..eb7288ea56bf 100644 --- a/.github/lf-canary-scale-config.yml +++ b/.github/lf-canary-scale-config.yml @@ -1,5 +1,56 @@ -# Defines runner types provisioned by by LF Self-hosted runners for pytorch/pytorch-canary and their labels. +# Defines runner types that will be provisioned by by LF Self-hosted +# runners for pytorch/pytorch-canary and their labels. +# +# Runners listed here will be available as self hosted runners. +# Configuration is directly pulled from the main branch. +# +# Default values: +# +# runner_types: +# runner_label: # label to specify in the Github Actions workflow +# instance_type: m4.large +# os: linux +# max_available: 20 +# disk_size: 50 +# is_ephemeral: true + runner_types: + lf.c.linux.12xlarge: + disk_size: 200 + instance_type: c5.12xlarge + is_ephemeral: false + max_available: 1000 + os: linux + lf.c.linux.24xl.spr-metal: + disk_size: 200 + instance_type: c7i.metal-24xl + is_ephemeral: false + max_available: 30 + os: linux + lf.c.linux.16xlarge.spr: + disk_size: 200 + instance_type: c7i.16xlarge + is_ephemeral: false + max_available: 30 + os: linux + lf.c.linux.12xlarge.ephemeral: + disk_size: 200 + instance_type: c5.12xlarge + is_ephemeral: true + max_available: 300 + os: linux + lf.c.linux.16xlarge.nvidia.gpu: + disk_size: 150 + instance_type: g3.16xlarge + is_ephemeral: false + max_available: 30 + os: linux + lf.c.linux.24xlarge: + disk_size: 150 + instance_type: c5.24xlarge + is_ephemeral: false + max_available: 250 + os: linux lf.c.linux.2xlarge: disk_size: 150 instance_type: c5.2xlarge @@ -24,3 +75,80 @@ runner_types: is_ephemeral: false max_available: 400 os: linux + lf.c.linux.g4dn.12xlarge.nvidia.gpu: + disk_size: 150 + instance_type: g4dn.12xlarge + is_ephemeral: false + max_available: 50 + os: linux + lf.c.linux.g4dn.metal.nvidia.gpu: + disk_size: 150 + instance_type: g4dn.metal + is_ephemeral: false + max_available: 30 + os: linux + lf.c.linux.g5.48xlarge.nvidia.gpu: + disk_size: 150 + instance_type: g5.48xlarge + is_ephemeral: false + max_available: 20 + os: linux + lf.c.linux.g5.12xlarge.nvidia.gpu: + disk_size: 150 + instance_type: g5.12xlarge + is_ephemeral: false + max_available: 150 + os: linux + lf.c.linux.g5.4xlarge.nvidia.gpu: + disk_size: 150 + instance_type: g5.4xlarge + is_ephemeral: false + max_available: 1200 + os: linux + lf.c.linux.large: + disk_size: 15 + instance_type: c5.large + is_ephemeral: false + os: linux + lf.c.linux.arm64.2xlarge: + disk_size: 256 + instance_type: t4g.2xlarge + is_ephemeral: false + max_available: 200 + os: linux + lf.c.linux.arm64.m7g.2xlarge: + disk_size: 256 + instance_type: m7g.2xlarge + is_ephemeral: false + max_available: 20 + os: linux + lf.c.windows.4xlarge: + disk_size: 256 + instance_type: c5d.4xlarge + is_ephemeral: true + max_available: 420 + os: windows + lf.c.windows.4xlarge.nonephemeral: + disk_size: 256 + instance_type: c5d.4xlarge + is_ephemeral: false + max_available: 420 + os: windows + lf.c.windows.8xlarge.nvidia.gpu: + disk_size: 256 + instance_type: p3.2xlarge + is_ephemeral: true + max_available: 150 + os: windows + lf.c.windows.8xlarge.nvidia.gpu.nonephemeral: + disk_size: 256 + instance_type: p3.2xlarge + is_ephemeral: false + max_available: 150 + os: windows + lf.c.windows.g5.4xlarge.nvidia.gpu: + disk_size: 256 + instance_type: g5.4xlarge + is_ephemeral: false + max_available: 250 + os: windows diff --git a/.github/lf-scale-config.yml b/.github/lf-scale-config.yml index cd4e6dc9f4f4..7977d7c15c2f 100644 --- a/.github/lf-scale-config.yml +++ b/.github/lf-scale-config.yml @@ -1,5 +1,56 @@ -# Defines runner types provisioned by by LF Self-hosted runners for pytorch/pytorch and their labels. +# Defines runner types that will be provisioned by by LF Self-hosted +# runners for pytorch/pytorch and their labels. +# +# Runners listed here will be available as self hosted runners. +# Configuration is directly pulled from the main branch. +# +# Default values: +# +# runner_types: +# runner_label: # label to specify in the Github Actions workflow +# instance_type: m4.large +# os: linux +# max_available: 20 +# disk_size: 50 +# is_ephemeral: true + runner_types: + lf.linux.12xlarge: + disk_size: 200 + instance_type: c5.12xlarge + is_ephemeral: false + max_available: 1000 + os: linux + lf.linux.24xl.spr-metal: + disk_size: 200 + instance_type: c7i.metal-24xl + is_ephemeral: false + max_available: 30 + os: linux + lf.linux.16xlarge.spr: + disk_size: 200 + instance_type: c7i.16xlarge + is_ephemeral: false + max_available: 30 + os: linux + lf.linux.12xlarge.ephemeral: + disk_size: 200 + instance_type: c5.12xlarge + is_ephemeral: true + max_available: 300 + os: linux + lf.linux.16xlarge.nvidia.gpu: + disk_size: 150 + instance_type: g3.16xlarge + is_ephemeral: false + max_available: 30 + os: linux + lf.linux.24xlarge: + disk_size: 150 + instance_type: c5.24xlarge + is_ephemeral: false + max_available: 250 + os: linux lf.linux.2xlarge: disk_size: 150 instance_type: c5.2xlarge @@ -24,3 +75,80 @@ runner_types: is_ephemeral: false max_available: 400 os: linux + lf.linux.g4dn.12xlarge.nvidia.gpu: + disk_size: 150 + instance_type: g4dn.12xlarge + is_ephemeral: false + max_available: 50 + os: linux + lf.linux.g4dn.metal.nvidia.gpu: + disk_size: 150 + instance_type: g4dn.metal + is_ephemeral: false + max_available: 30 + os: linux + lf.linux.g5.48xlarge.nvidia.gpu: + disk_size: 150 + instance_type: g5.48xlarge + is_ephemeral: false + max_available: 20 + os: linux + lf.linux.g5.12xlarge.nvidia.gpu: + disk_size: 150 + instance_type: g5.12xlarge + is_ephemeral: false + max_available: 150 + os: linux + lf.linux.g5.4xlarge.nvidia.gpu: + disk_size: 150 + instance_type: g5.4xlarge + is_ephemeral: false + max_available: 1200 + os: linux + lf.linux.large: + disk_size: 15 + instance_type: c5.large + is_ephemeral: false + os: linux + lf.linux.arm64.2xlarge: + disk_size: 256 + instance_type: t4g.2xlarge + is_ephemeral: false + max_available: 200 + os: linux + lf.linux.arm64.m7g.2xlarge: + disk_size: 256 + instance_type: m7g.2xlarge + is_ephemeral: false + max_available: 20 + os: linux + lf.windows.4xlarge: + disk_size: 256 + instance_type: c5d.4xlarge + is_ephemeral: true + max_available: 420 + os: windows + lf.windows.4xlarge.nonephemeral: + disk_size: 256 + instance_type: c5d.4xlarge + is_ephemeral: false + max_available: 420 + os: windows + lf.windows.8xlarge.nvidia.gpu: + disk_size: 256 + instance_type: p3.2xlarge + is_ephemeral: true + max_available: 150 + os: windows + lf.windows.8xlarge.nvidia.gpu.nonephemeral: + disk_size: 256 + instance_type: p3.2xlarge + is_ephemeral: false + max_available: 150 + os: windows + lf.windows.g5.4xlarge.nvidia.gpu: + disk_size: 256 + instance_type: g5.4xlarge + is_ephemeral: false + max_available: 250 + os: windows From 57baae9c9b43fd31199dedd3f0fd5ed67faf5769 Mon Sep 17 00:00:00 2001 From: Huy Do Date: Fri, 31 May 2024 22:30:59 +0000 Subject: [PATCH 196/706] Migrating CI/CD jobs to macOS 14 (#127582) We have half the fleet in MacoS 14 already and it has been running fine so far https://github.com/pytorch/pytorch/issues/127490. So, I'm preparing the final push to replace the rest of them. This also switches release build from 13 to 14 (GitHub runners) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127582 Approved by: https://github.com/atalman --- .github/actionlint.yaml | 4 +-- .github/scripts/generate_ci_workflows.py | 8 +++--- .github/workflows/build-ios-binaries.yml | 4 +-- ...rated-macos-arm64-binary-conda-nightly.yml | 10 +++---- ...rm64-binary-libtorch-cxx11-abi-nightly.yml | 2 +- ...rated-macos-arm64-binary-wheel-nightly.yml | 10 +++---- .github/workflows/mac-mps.yml | 17 ++++++------ .github/workflows/periodic.yml | 4 +-- .github/workflows/trunk.yml | 26 +++++++++---------- 9 files changed, 40 insertions(+), 45 deletions(-) diff --git a/.github/actionlint.yaml b/.github/actionlint.yaml index 569facc32cdf..679658dafdd8 100644 --- a/.github/actionlint.yaml +++ b/.github/actionlint.yaml @@ -23,8 +23,6 @@ self-hosted-runner: - macos-m1-stable - macos-m1-13 - macos-m1-14 - - macos-12-xl - - macos-12 - - macos12.3-m1 - macos-latest-xlarge - macos-13-xlarge + - macos-14-xlarge diff --git a/.github/scripts/generate_ci_workflows.py b/.github/scripts/generate_ci_workflows.py index 54884e3a1261..fcac02bb8fe8 100755 --- a/.github/scripts/generate_ci_workflows.py +++ b/.github/scripts/generate_ci_workflows.py @@ -60,7 +60,7 @@ class BinaryBuildWorkflow: branches: str = "nightly" # Mainly for macos cross_compile_arm64: bool = False - macos_runner: str = "macos-12-xl" + macos_runner: str = "macos-14-xlarge" def __post_init__(self) -> None: if self.abi_version: @@ -285,7 +285,7 @@ class OperatingSystem: libtorch_variants=["shared-with-deps"], ), cross_compile_arm64=False, - macos_runner="macos-13-xlarge", + macos_runner="macos-14-xlarge", ciflow_config=CIFlowConfig( labels={LABEL_CIFLOW_BINARIES, LABEL_CIFLOW_BINARIES_LIBTORCH}, isolated_workflow=True, @@ -298,7 +298,7 @@ class OperatingSystem: OperatingSystem.MACOS_ARM64 ), cross_compile_arm64=False, - macos_runner="macos-13-xlarge", + macos_runner="macos-14-xlarge", ciflow_config=CIFlowConfig( labels={LABEL_CIFLOW_BINARIES, LABEL_CIFLOW_BINARIES_WHEEL}, isolated_workflow=True, @@ -308,7 +308,7 @@ class OperatingSystem: os=OperatingSystem.MACOS_ARM64, package_type="conda", cross_compile_arm64=False, - macos_runner="macos-13-xlarge", + macos_runner="macos-14-xlarge", build_configs=generate_binary_build_matrix.generate_conda_matrix( OperatingSystem.MACOS_ARM64 ), diff --git a/.github/workflows/build-ios-binaries.yml b/.github/workflows/build-ios-binaries.yml index 3f3be84f48bd..32598f07a5c0 100644 --- a/.github/workflows/build-ios-binaries.yml +++ b/.github/workflows/build-ios-binaries.yml @@ -49,7 +49,7 @@ jobs: { config: "default", shard: 1, num_shards: 1, - runner: "macos-13-xlarge", + runner: "macos-14-xlarge", ios_platform: "SIMULATOR", ios_arch: "arm64", use_lite_interpreter: ${{ inputs.use_lite_interpreter || 1 }}, @@ -60,7 +60,7 @@ jobs: { config: "default", shard: 1, num_shards: 1, - runner: "macos-13-xlarge", + runner: "macos-14-xlarge", ios_platform: "OS", ios_arch: "arm64", use_lite_interpreter: ${{ inputs.use_lite_interpreter || 1 }}, diff --git a/.github/workflows/generated-macos-arm64-binary-conda-nightly.yml b/.github/workflows/generated-macos-arm64-binary-conda-nightly.yml index a8cbdb7cd6fe..52ccb92a1935 100644 --- a/.github/workflows/generated-macos-arm64-binary-conda-nightly.yml +++ b/.github/workflows/generated-macos-arm64-binary-conda-nightly.yml @@ -34,7 +34,7 @@ concurrency: jobs: conda-py3_8-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: macos-13-xlarge + runs-on: macos-14-xlarge timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -152,7 +152,7 @@ jobs: uses: ./.github/workflows/_binary-upload.yml conda-py3_9-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: macos-13-xlarge + runs-on: macos-14-xlarge timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -270,7 +270,7 @@ jobs: uses: ./.github/workflows/_binary-upload.yml conda-py3_10-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: macos-13-xlarge + runs-on: macos-14-xlarge timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -388,7 +388,7 @@ jobs: uses: ./.github/workflows/_binary-upload.yml conda-py3_11-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: macos-13-xlarge + runs-on: macos-14-xlarge timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -506,7 +506,7 @@ jobs: uses: ./.github/workflows/_binary-upload.yml conda-py3_12-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: macos-13-xlarge + runs-on: macos-14-xlarge timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch diff --git a/.github/workflows/generated-macos-arm64-binary-libtorch-cxx11-abi-nightly.yml b/.github/workflows/generated-macos-arm64-binary-libtorch-cxx11-abi-nightly.yml index 0ed7ba10a07d..7e2e345aefbc 100644 --- a/.github/workflows/generated-macos-arm64-binary-libtorch-cxx11-abi-nightly.yml +++ b/.github/workflows/generated-macos-arm64-binary-libtorch-cxx11-abi-nightly.yml @@ -34,7 +34,7 @@ concurrency: jobs: libtorch-cpu-shared-with-deps-cxx11-abi-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: macos-13-xlarge + runs-on: macos-14-xlarge timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch diff --git a/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml b/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml index 167161de3645..94a8fd9cd3de 100644 --- a/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml +++ b/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml @@ -34,7 +34,7 @@ concurrency: jobs: wheel-py3_8-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: macos-13-xlarge + runs-on: macos-14-xlarge timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -153,7 +153,7 @@ jobs: uses: ./.github/workflows/_binary-upload.yml wheel-py3_9-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: macos-13-xlarge + runs-on: macos-14-xlarge timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -272,7 +272,7 @@ jobs: uses: ./.github/workflows/_binary-upload.yml wheel-py3_10-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: macos-13-xlarge + runs-on: macos-14-xlarge timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -391,7 +391,7 @@ jobs: uses: ./.github/workflows/_binary-upload.yml wheel-py3_11-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: macos-13-xlarge + runs-on: macos-14-xlarge timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -510,7 +510,7 @@ jobs: uses: ./.github/workflows/_binary-upload.yml wheel-py3_12-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: macos-13-xlarge + runs-on: macos-14-xlarge timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch diff --git a/.github/workflows/mac-mps.yml b/.github/workflows/mac-mps.yml index 8554df29d1f6..53504b6133f6 100644 --- a/.github/workflows/mac-mps.yml +++ b/.github/workflows/mac-mps.yml @@ -13,29 +13,28 @@ concurrency: permissions: read-all jobs: - macos-13-py3-arm64-build: - name: macos-13-py3-arm64 + macos-py3-arm64-build: + name: macos-py3-arm64 uses: ./.github/workflows/_mac-build.yml with: sync-tag: macos-py3-arm64-build - build-environment: macos-13-py3-arm64 - runner-type: macos-m1-13 + build-environment: macos-py3-arm64 + runner-type: macos-m1-stable build-generates-artifacts: true # To match the one pre-installed in the m1 runners python-version: 3.9.12 test-matrix: | { include: [ - { config: "mps", shard: 1, num_shards: 1, runner: "macos-m1-13" }, - { config: "mps", shard: 1, num_shards: 1, runner: "macos-m2-14" }, + { config: "mps", shard: 1, num_shards: 1, runner: "macos-m1-14" }, ]} macos-py3-arm64-mps-test: name: macos-py3-arm64-mps uses: ./.github/workflows/_mac-test-mps.yml - needs: macos-13-py3-arm64-build + needs: macos-py3-arm64-build with: sync-tag: macos-py3-arm64-mps-test - build-environment: macos-13-py3-arm64 + build-environment: macos-py3-arm64 # Same as the build job python-version: 3.9.12 - test-matrix: ${{ needs.macos-13-py3-arm64-build.outputs.test-matrix }} + test-matrix: ${{ needs.macos-py3-arm64-build.outputs.test-matrix }} diff --git a/.github/workflows/periodic.yml b/.github/workflows/periodic.yml index 716a72cc6d23..925bca54c074 100644 --- a/.github/workflows/periodic.yml +++ b/.github/workflows/periodic.yml @@ -151,7 +151,7 @@ jobs: { config: "default", shard: 1, num_shards: 1, - runner: "macos-13-xlarge", + runner: "macos-14-xlarge", ios_platform: "SIMULATOR", ios_arch: "arm64", use_lite_interpreter: 1, @@ -162,7 +162,7 @@ jobs: { config: "default", shard: 1, num_shards: 1, - runner: "macos-13-xlarge", + runner: "macos-14-xlarge", ios_platform: "OS", ios_arch: "arm64", use_lite_interpreter: 1, diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index f0567393d5fa..a91238fa2c9b 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -143,13 +143,13 @@ jobs: { config: "default", shard: 1, num_shards: 1, runner: "linux.2xlarge" }, ]} - macos-13-py3-arm64-build: - name: macos-13-py3-arm64 + macos-py3-arm64-build: + name: macos-py3-arm64 uses: ./.github/workflows/_mac-build.yml with: sync-tag: macos-py3-arm64-build - build-environment: macos-13-py3-arm64 - runner-type: macos-m1-13 + build-environment: macos-py3-arm64 + runner-type: macos-m1-stable build-generates-artifacts: true # To match the one pre-installed in the m1 runners python-version: 3.9.12 @@ -163,31 +163,29 @@ jobs: macos-py3-arm64-mps-test: name: macos-py3-arm64-mps uses: ./.github/workflows/_mac-test-mps.yml - needs: macos-13-py3-arm64-build - if: needs.macos-13-py3-arm64-build.outputs.build-outcome == 'success' + needs: macos-py3-arm64-build + if: needs.macos-py3-arm64-build.outputs.build-outcome == 'success' with: sync-tag: macos-py3-arm64-mps-test - build-environment: macos-13-py3-arm64 + build-environment: macos-py3-arm64 # Same as the build job python-version: 3.9.12 test-matrix: | { include: [ - { config: "mps", shard: 1, num_shards: 1, runner: "macos-m1-13" }, { config: "mps", shard: 1, num_shards: 1, runner: "macos-m1-14" }, - ]} - macos-13-py3-arm64-test: - name: macos-13-py3-arm64 + macos-py3-arm64-test: + name: macos-py3-arm64 uses: ./.github/workflows/_mac-test.yml needs: - - macos-13-py3-arm64-build + - macos-py3-arm64-build - target-determination with: - build-environment: macos-13-py3-arm64 + build-environment: macos-py3-arm64 # Same as the build job python-version: 3.9.12 - test-matrix: ${{ needs.macos-13-py3-arm64-build.outputs.test-matrix }} + test-matrix: ${{ needs.macos-py3-arm64-build.outputs.test-matrix }} win-vs2019-cpu-py3-build: name: win-vs2019-cpu-py3 From f7171313abf14d9501a330457140b2f8a01c9985 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 31 May 2024 22:56:08 +0000 Subject: [PATCH 197/706] [Inductor] FlexAttention backward kernel optimization (#127208) BWD Speedups (before this PR): ``` | Type | Speedup | shape | score_mod | dtype | |---------|-----------|-------------------|---------------|----------------| | Average | 0.211 | | | | | Max | 0.364 | (16, 16, 512, 64) | relative_bias | torch.bfloat16 | | Min | 0.044 | (2, 16, 4096, 64) | causal_mask | torch.bfloat16 | ``` BWD Speedups (after this PR, though not optimizing block size yet): ``` | Type | Speedup | shape | score_mod | dtype | |---------|-----------|--------------------|---------------|----------------| | Average | 0.484 | | | | | Max | 0.626 | (2, 16, 512, 256) | head_bias | torch.bfloat16 | | Min | 0.355 | (8, 16, 4096, 128) | relative_bias | torch.bfloat16 | ``` There are a few things need to do as follow-ups: * Optimized default block size on A100/H100. * Support different seqlen for Q and K/V. * Support dynamic shapes for backward. * Enhance unit tests to check there is no ```nan``` value in any grad. I think we should make some changes to ```test_padded_dense_causal``` because it has invalid inputs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127208 Approved by: https://github.com/Chillee --- test/inductor/test_flex_attention.py | 4 +- torch/_inductor/kernel/flex_attention.py | 300 +++++++++++++---------- torch/_inductor/select_algorithm.py | 5 +- 3 files changed, 184 insertions(+), 125 deletions(-) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index bc688ab834cb..d4feead90301 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -144,6 +144,8 @@ def _check_equal( ): compiled_error = (golden_out - compiled_out).abs().mean() ref_error = (golden_out - ref_out).abs().mean() + if torch.isnan(compiled_error).any() and not torch.isnan(ref_error).any(): + self.assertTrue(False, "Output/Grad with NaN") if compiled_error > ref_error * fudge_factor: name = tensor_name if tensor_name is not None else "" msg = f"{name} Compiled error {compiled_error} is greater than ref error {ref_error} by more than {fudge_factor}X." @@ -195,7 +197,7 @@ def run_test( self._check_equal( k_gold.grad, k_ref.grad, k.grad, k_fudge_factor, "Grad_Key" ) - v_fudge_factor = 8 * fudge_factor + v_fudge_factor = 4 * fudge_factor self._check_equal( v_gold.grad, v_ref.grad, v.grad, v_fudge_factor, "Grad_Value" ) diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index 5a1f45e767a7..3e95dd4f65ce 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -1,7 +1,6 @@ """ Triton Implementation of the flex_attention Kernel""" import logging -import math from enum import auto, Enum from typing import Any, List, Tuple @@ -189,7 +188,7 @@ def build_subgraph_buffer( Z = {{size("Q", 0)}} H = {{size("Q", 1)}} - N_CTX = {{size("Q", 2)}} + Q_LEN = {{size("Q", 2)}} qk_scale = 1.0 MATMUL_PRECISION = Q.dtype.element_ty @@ -200,7 +199,7 @@ def build_subgraph_buffer( qkv_offset = off_hz * stride_qh Q_block_ptr = tl.make_block_ptr( base=Q + qkv_offset, - shape=(N_CTX, BLOCK_DMODEL), + shape=(Q_LEN, BLOCK_DMODEL), strides=(stride_qm, stride_qk), offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_DMODEL), @@ -208,7 +207,7 @@ def build_subgraph_buffer( ) K_block_ptr = tl.make_block_ptr( base=K + qkv_offset, - shape=(BLOCK_DMODEL, N_CTX), + shape=(BLOCK_DMODEL, Q_LEN), strides=(stride_kk, stride_kn), offsets=(0, 0), block_shape=(BLOCK_DMODEL, BLOCK_N), @@ -216,7 +215,7 @@ def build_subgraph_buffer( ) V_block_ptr = tl.make_block_ptr( base=V + qkv_offset, - shape=(N_CTX, BLOCK_DMODEL), + shape=(Q_LEN, BLOCK_DMODEL), strides=(stride_vk, stride_vn), offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_DMODEL), @@ -236,7 +235,7 @@ def build_subgraph_buffer( q = (q * qk_scale).to(MATMUL_PRECISION) # loop over k, v and update accumulator lo = 0 - hi = N_CTX + hi = Q_LEN for start_n in range(lo, hi, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- load k, v -- @@ -299,7 +298,7 @@ def build_subgraph_buffer( # TODO dont want to write this if we dont require grad if OUTPUT_LOGSUMEXP: - l_ptrs = LSE + off_hz * N_CTX + offs_m + l_ptrs = LSE + off_hz * Q_LEN + offs_m lse = m_i + tl.math.log2(l_i) tl.store(l_ptrs, lse) """, @@ -446,13 +445,22 @@ def flex_attention(*args, **kwargs): # ---------------------------- Backward HOP Implementation ---------------------------- -def flex_attention_backward_grid(batch_size, num_heads, num_key_value, d_model, meta): +def flex_attention_backward_grid(batch_size, num_heads, num_queries, d_model, meta): """How is this kernel parallelized? Currently this is only parallelizing over batch * num_heads, but we can, and want to parallelize over ceil_div(num_key_value, key_value_block_size). To do this will either require atomic updates to some grad values or to have a two pass kernel design. """ - return (batch_size * num_heads, 1, 1) + import triton + + # TODO: support different seqlen for Query and Key/Value. + num_key_value = num_queries + return ( + triton.cdiv(num_queries, meta["BLOCK_M2"]) + + triton.cdiv(num_key_value, meta["BLOCK_N1"]), + 1, + batch_size * num_heads, + ) flex_attention_backward_template = TritonTemplate( @@ -470,95 +478,83 @@ def flex_attention_backward_grid(batch_size, num_heads, num_key_value, d_model, # M: Number of queries, N: Number of keys/values, D: Model dimension # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head # (Modifiable) Config options: - # BLOCK_M - # BLOCK_N + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. # SCORE_MOD_IS_LINEAR: Is the score modifier linear? If so, we can lift the # change of base out of the loop - # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row - # is not masked out? If so, we can skip an extra safety check - # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad # Define Q Strides - stride_qz = {{stride("Q", 0)}} - stride_qh = {{stride("Q", 1)}} - stride_qm = {{stride("Q", 2)}} - stride_qk = {{stride("Q", 3)}} - # Define K Strides - stride_kz = {{stride("K", 0)}} - stride_kh = {{stride("K", 1)}} - stride_kn = {{stride("K", 2)}} - stride_kk = {{stride("K", 3)}} - # Define V Strides - stride_vz = {{stride("V", 0)}} - stride_vh = {{stride("V", 1)}} - stride_vn = {{stride("V", 2)}} - stride_vk = {{stride("V", 3)}} + stride_z = {{stride("Q", 0)}} + stride_h = {{stride("Q", 1)}} + stride_tok = {{stride("Q", 2)}} + stride_d = {{stride("Q", 3)}} Z = {{size("Q", 0)}} H = {{size("Q", 1)}} - N_CTX = {{size("Q", 2)}} + Q_LEN = {{size("Q", 2)}} + KV_LEN = {{size("K", 2)}} - qk_scale = 1.0 MATMUL_PRECISION = Q.dtype.element_ty - off_hz = tl.program_id(0) + pid = tl.program_id(0) + NUM_KV_BLOCKS = KV_LEN // BLOCK_N1 + + bhid = tl.program_id(2) + off_chz = (bhid * Q_LEN).to(tl.int64) + adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64) + + off_hz = tl.program_id(2) off_z = off_hz // H # batch idx off_h = off_hz % H # head idx # offset pointers for batch/head - Q += off_z * stride_qz + off_h * stride_qh - K += off_z * stride_kz + off_h * stride_kh - V += off_z * stride_vz + off_h * stride_vh - - # Asserting contiguous for now... - DO += off_z * stride_qz + off_h * stride_qh - DQ += off_z * stride_qz + off_h * stride_qh - DV += off_z * stride_vz + off_h * stride_vh - - # TODO I think that this should be N_CTX/BLOCK_N blocks - for start_n in range(0, NUM_Q_BLOCKS): - # We are not doing the causal optimization yet allowing us to start further down the - # kv column - offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) - offs_m = tl.arange(0, BLOCK_M) - offs_k = tl.arange(0, BLOCK_DMODEL) - - # initialize pointers to value-like data - q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk) - k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) - v_ptrs = V + (offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk) - do_ptrs = DO + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk) - dq_ptrs = DQ + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk) - - # pointer to row-wise quantities in value-like data - D_ptrs = DELTA + off_hz * N_CTX - l_ptrs = LSE + off_hz * N_CTX - - # initialize dv and dk - dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) - dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) - - # Key and Value stay in SRAM throughout - k = tl.load(k_ptrs) - v = tl.load(v_ptrs) - - for start_m in range(0, NUM_Q_BLOCKS * BLOCK_M, BLOCK_M): - offs_m_curr = start_m + offs_m - - # load q, k, v, do on-chip - q = tl.load(q_ptrs) - - if SCORE_MOD_IS_LINEAR: - qk_scale *= 1.44269504 - q = (q * qk_scale).to(MATMUL_PRECISION) - - # -- compute qk --- - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk = tl.dot(q, tl.trans(k.to(MATMUL_PRECISION)), acc=qk) - pre_mod_scores = qk + Q += adj + K += adj + V += adj + DO += adj + DQ += adj + DV += adj + LSE += off_chz + DELTA += off_chz + + offs_k = tl.arange(0, BLOCK_DMODEL) + + if pid >= NUM_KV_BLOCKS: + # THIS BLOCK DOES DQ + off_pid = pid - NUM_KV_BLOCKS + start_m2 = off_pid * BLOCK_M2 + + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + q = tl.load(Q + offs_m2[:, None] * stride_tok + offs_k[None, :] * stride_d) + dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32) + do = tl.load(DO + offs_m2[:, None] * stride_tok + offs_k[None, :] * stride_d) + + lse = tl.load(LSE + offs_m2) + lse = lse[:, None] + + start_n2 = 0 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + offs_n2 = start_n2 + tl.arange(0, BLOCK_N2) + kT_ptrs = K + offs_n2[None, :] * stride_tok + offs_k[:, None] * stride_d + vT_ptrs = V + offs_n2[None, :] * stride_tok + offs_k[:, None] * stride_d + Di = tl.load(DELTA + offs_m2) + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + curr_n = start_n2 + num_steps = KV_LEN // BLOCK_N2 + for blk_idx in range(num_steps): + offs_n2= curr_n + tl.arange(0, BLOCK_N2) + kT = tl.load(kT_ptrs) + vT = tl.load(vT_ptrs) + qk = tl.dot(q, kT) # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ - m = offs_m_curr[:, None] - n = offs_n[None, :] + pre_mod_scores = qk + m = offs_m2[:, None] + n = offs_n2[None, :] {{ modification( subgraph_number=0, output_name="post_mod_scores", @@ -569,25 +565,13 @@ def flex_attention_backward_grid(batch_size, num_heads, num_key_value, d_model, n="n", out="qk" ) | indent_except_first(3) }} - # TODO: In the case that score_mod is linear, this can be LICMed + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ if not SCORE_MOD_IS_LINEAR: post_mod_scores *= 1.44269504 - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - l_i = tl.load(l_ptrs + offs_m_curr) - p = tl.math.exp2(post_mod_scores - l_i[:, None]) - - # compute dv - do = tl.load(do_ptrs) - dv += tl.dot(tl.trans(p.to(MATMUL_PRECISION)), do) - - # compute dp = dot(v, do) - Di = tl.load(D_ptrs + offs_m_curr) # [BLOCKM, 1] - - # compute ds = p * (dp - delta[:, None]) - dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] - dp += tl.dot(do, tl.trans(v)) - ds = p * dp - + p = tl.math.exp2(post_mod_scores - lse).to(MATMUL_PRECISION) + # Compute dP and dS. + dp = tl.dot(do, vT) + ds = p * (dp - Di[:, None]) # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ {{ modification( subgraph_number=1, @@ -601,32 +585,101 @@ def flex_attention_backward_grid(batch_size, num_heads, num_key_value, d_model, ) | indent_except_first(3) }} ds = grad_scores # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # compute dk = dot(ds.T, q) - dk += tl.dot(tl.trans(ds.to(MATMUL_PRECISION)), q) - # compute dq - dq = tl.load(dq_ptrs) - dq += tl.dot(ds.to(MATMUL_PRECISION), k) - - # Store grad_query - tl.store(dq_ptrs, dq) - - # increment pointers - dq_ptrs += BLOCK_M * stride_qm - q_ptrs += BLOCK_M * stride_qm - do_ptrs += BLOCK_M * stride_qm - - # write-back - index_n = offs_n[:, None] - index_k = offs_k[None, :] + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT)) + # Increment pointers. + curr_n += BLOCK_N2 + kT_ptrs += BLOCK_N2 * stride_tok + vT_ptrs += BLOCK_N2 * stride_tok + # Write back dQ. + dq_ptrs = DQ + offs_m2[:, None] * stride_tok + offs_k[None, :] * stride_d + tl.store(dq_ptrs, dq) + else: + # THIS BLOCK DOES DK & DV + start_n1 = pid * BLOCK_N1 + start_m1 = 0 + + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) + + # load K and V: they stay in SRAM throughout the inner loop. + k = tl.load(K + offs_n1[:, None] * stride_tok + offs_k[None, :] * stride_d) + v = tl.load(V + offs_n1[:, None] * stride_tok + offs_k[None, :] * stride_d) + + offs_m1 = start_m1 + tl.arange(0, BLOCK_M1) + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + qT_ptrs = Q + offs_m1[None, :] * stride_tok + offs_k[:, None] * stride_d + do_ptrs = DO + offs_m1[:, None] * stride_tok + offs_k[None, :] * stride_d + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + curr_m = start_m1 + num_steps = Q_LEN // BLOCK_M1 + for blk_idx in range(num_steps): + qT = tl.load(qT_ptrs) + # Load LSE before computing qk to reduce pipeline stall. + offs_m1 = curr_m + tl.arange(0, BLOCK_M1) + lse = tl.load(LSE + offs_m1) + qkT = tl.dot(k, qT) + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = offs_m1[None, :] + n = offs_n1[:, None] + pre_mod_scores = qkT + {{ modification( + subgraph_number=0, + output_name="post_mod_scores", + score="qkT", + b="off_z", + h="off_h", + m="m", + n="n", + out="qkT" + ) | indent_except_first(3) }} + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not SCORE_MOD_IS_LINEAR: + post_mod_scores *= 1.44269504 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = tl.load(do_ptrs) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do) + Di = tl.load(DELTA + offs_m1) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do)) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + m = offs_m1[None, :] + n = offs_n1[:, None] + {{ modification( + subgraph_number=1, + output_name = "grad_scores", + score="pre_mod_scores", + b="off_z", + h="off_h", + m="m", + n="n", + grad_score_mod="dsT" + ) | indent_except_first(3) }} + dsT = grad_scores + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT)) + # Increment pointers. + curr_m += BLOCK_M1 + qT_ptrs += BLOCK_M1 * stride_tok + do_ptrs += BLOCK_M1 * stride_tok - # Store grad_key and grad_value - dv_ptrs = DV + (index_n * stride_vn + index_k * stride_vk) + dv_ptrs = DV + offs_n1[:, None] * stride_tok + offs_k[None, :] * stride_d tl.store(dv_ptrs, dv) + # Write back dK. + index_n = offs_n1[:, None] + index_k = offs_k[None, :] # TODO generalize and add proper mask support mask = (index_n != -1) & (index_k != -1) {{store_output(("off_z", "off_h", "index_n", "index_k"), "dk", "mask", indent_width=8)}} - """, ) @@ -722,10 +775,11 @@ def flex_attention_backward(*args, **kwargs): mutated_inputs=[grad_query, grad_value], num_stages=num_stages, num_warps=num_warps, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, + BLOCK_M1=BLOCK_M, + BLOCK_N1=BLOCK_N, + BLOCK_M2=BLOCK_N, + BLOCK_N2=BLOCK_M, BLOCK_DMODEL=query.get_size()[-1], - NUM_Q_BLOCKS=math.ceil(query.get_size()[-2] / BLOCK_M), # For now, we always assume the "sound" option SCORE_MOD_IS_LINEAR=False, ) diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 13bfdcb60fda..3ba0ff0d949b 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -310,7 +310,10 @@ def modification( Args: subgraph_number (int): The index of the subgraph in self.subgraphs """ - with self.create_subgraph_body(f"modification_{subgraph_number}"): + num = 0 + while f"mod_{subgraph_number}_{num}" in self.subgraph_bodies: + num += 1 + with self.create_subgraph_body(f"mod_{subgraph_number}_{num}"): assert isinstance(subgraph_number, int) assert isinstance(self.subgraphs, list) assert ( From a8c9b26534ebb7eb268f3aab13405284cf982792 Mon Sep 17 00:00:00 2001 From: Zain Rizvi Date: Fri, 31 May 2024 23:00:07 +0000 Subject: [PATCH 198/706] [BE] Fix dependabot security errors (#127567) Fixes https://github.com/pytorch/pytorch/security/dependabot/36 and https://github.com/pytorch/pytorch/security/dependabot/37 by deleting spurious dependency Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/127567 Approved by: https://github.com/malfet --- .github/requirements/conda-env-Linux-X64.txt | 1 - .github/requirements/conda-env-iOS.txt | 1 - 2 files changed, 2 deletions(-) diff --git a/.github/requirements/conda-env-Linux-X64.txt b/.github/requirements/conda-env-Linux-X64.txt index 16bbc57dd3be..dc44eb39f69f 100644 --- a/.github/requirements/conda-env-Linux-X64.txt +++ b/.github/requirements/conda-env-Linux-X64.txt @@ -4,6 +4,5 @@ mkl-include=2022.1.0 ninja=1.10.2 numpy=1.23.3 pyyaml=6.0 -requests=2.31.0 setuptools=68.2.2 typing-extensions=4.3.0 diff --git a/.github/requirements/conda-env-iOS.txt b/.github/requirements/conda-env-iOS.txt index 205c07925a01..3539a8a0ccf8 100644 --- a/.github/requirements/conda-env-iOS.txt +++ b/.github/requirements/conda-env-iOS.txt @@ -3,6 +3,5 @@ cmake=3.22.1 ninja=1.10.2 numpy=1.23.3 pyyaml=6.0 -requests=2.31.0 setuptools=68.2.2 typing-extensions=4.3.0 From ff8042bcfb518127c86ad5b4af4fa9171a499904 Mon Sep 17 00:00:00 2001 From: Huamin Li Date: Fri, 31 May 2024 23:56:11 +0000 Subject: [PATCH 199/706] Enable AOTI shim v2 build and add into libtorch (#125211) Summary: Follow up of https://github.com/pytorch/pytorch/pull/125087 This diff will create shim v2 header and cpp file and corresponding build Differential Revision: D56617546 Pull Request resolved: https://github.com/pytorch/pytorch/pull/125211 Approved by: https://github.com/desertfire --- buckbuild.bzl | 2 ++ build.bzl | 7 +++++++ 2 files changed, 9 insertions(+) diff --git a/buckbuild.bzl b/buckbuild.bzl index 649ebe668365..1d668117e910 100644 --- a/buckbuild.bzl +++ b/buckbuild.bzl @@ -383,6 +383,7 @@ def get_aten_generated_files(enabled_backends): "core/TensorMethods.cpp", "core/aten_interned_strings.h", "core/enum_tag.h", + "torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.cpp", ] + get_aten_derived_type_srcs(enabled_backends) # This is tiresome. A better strategy would be to unconditionally @@ -467,6 +468,7 @@ def gen_aten_files( cmd = "$(exe {}torchgen:gen) ".format(ROOT_PATH) + " ".join([ "--source-path $(location {}:aten_src_path)/aten/src/ATen".format(ROOT), "--install_dir $OUT", + "--aoti_install_dir $OUT/torch/csrc/inductor/aoti_torch/generated" ] + extra_params), visibility = visibility, compatible_with = compatible_with, diff --git a/build.bzl b/build.bzl index 5ab9f92acecc..8fd15f4e9c42 100644 --- a/build.bzl +++ b/build.bzl @@ -73,6 +73,7 @@ def define_targets(rules): "$(execpath //torchgen:gen)", "--install_dir=$(RULEDIR)", "--source-path aten/src/ATen", + "--aoti_install_dir=$(RULEDIR)/torch/csrc/inductor/aoti_torch/generated" ] + (["--static_dispatch_backend CPU"] if rules.is_cpu_static_dispatch_build() else [])) gen_aten_outs_cuda = ( @@ -83,6 +84,7 @@ def define_targets(rules): gen_aten_outs = ( GENERATED_H + GENERATED_H_CORE + GENERATED_CPP + GENERATED_CPP_CORE + + GENERATED_AOTI_CPP + aten_ufunc_generated_cpu_sources() + aten_ufunc_generated_cpu_kernel_sources() + [ "Declarations.yaml", @@ -316,3 +318,8 @@ GENERATED_AUTOGRAD_CPP = [ "torch/csrc/lazy/generated/RegisterAutogradLazy.cpp", "torch/csrc/lazy/generated/RegisterLazy.cpp", ] + _GENERATED_AUTOGRAD_CPP_HEADERS + GENERATED_LAZY_H + +GENERATED_AOTI_CPP = [ + "torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.cpp", + "torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.cpp", +] From df53cc711482153e71f5214b4a878419af0a022e Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Sat, 1 Jun 2024 01:25:10 +0000 Subject: [PATCH 200/706] [reland] "[reland] `_foreach_copy` with different src/dst dtypes" (#127186) Fixes #115171 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127186 Approved by: https://github.com/ezyang --- aten/src/ATen/native/ForeachUtils.h | 5 +- .../ATen/native/cuda/ForeachBinaryOpList.cu | 188 ++++++++++++++++-- test/test_foreach.py | 22 ++ 3 files changed, 199 insertions(+), 16 deletions(-) diff --git a/aten/src/ATen/native/ForeachUtils.h b/aten/src/ATen/native/ForeachUtils.h index 0839dd9a1560..f5c0672402f3 100644 --- a/aten/src/ATen/native/ForeachUtils.h +++ b/aten/src/ATen/native/ForeachUtils.h @@ -102,12 +102,13 @@ inline void check_foreach_api_restrictions( // corresponding tensors (aligning in index across the tensorLists) share the // same device and dtype. inline bool _check_tensors_share_device_and_dtype( - ArrayRef tensorLists) { + ArrayRef tensorLists, + const bool skip_dtype_check = false) { const auto expected_dtype = tensorLists[0][0].dtype(); const auto expected_device = tensorLists[0][0].device(); auto is_tensor_okay = [&](const Tensor& tensor) { - return tensor.dtype() == expected_dtype && + return (skip_dtype_check || tensor.dtype() == expected_dtype) && tensor.device() == expected_device && tensor.layout() == at::kStrided && tensor.is_non_overlapping_and_dense(); }; diff --git a/aten/src/ATen/native/cuda/ForeachBinaryOpList.cu b/aten/src/ATen/native/cuda/ForeachBinaryOpList.cu index cf7e40115fef..533aa38c04cf 100644 --- a/aten/src/ATen/native/cuda/ForeachBinaryOpList.cu +++ b/aten/src/ATen/native/cuda/ForeachBinaryOpList.cu @@ -1,9 +1,11 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include +#include #include #include #include #include +#include #ifndef AT_PER_OPERATOR_HEADERS #include @@ -252,20 +254,156 @@ FOREACH_BINARY_OP_LIST( power_functor, /*division_op*/ true); -template -struct Identity { - __device__ __forceinline__ T operator()(const T& x) { - return x; +template +struct Copy { + __device__ __forceinline__ dst_t operator()(const src_t& x) { + return static_cast(x); } }; +template +struct Copy> { + __device__ __forceinline__ dst_t operator()(const c10::complex& x) { + if constexpr (!(std::is_same_v> || + std::is_same_v>)) { + return static_cast(x.real()); + } else { + return static_cast(x); + } + } +}; + +template +struct Copy> { + __device__ __forceinline__ dst_t operator()(const c10::complex& x) { + if constexpr (!(std::is_same_v> || + std::is_same_v>)) { + return static_cast(x.real()); + } else { + return static_cast(x); + } + } +}; + +#define AT_DISPATCH_SOURCE_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Byte, src_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Char, src_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Long, src_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Short, src_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Int, src_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Double, src_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Float, src_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::ComplexDouble, \ + src_t, \ + __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::ComplexFloat, \ + src_t, \ + __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Half, \ + src_t, \ + __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::BFloat16, \ + src_t, \ + __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Bool, \ + src_t, \ + __VA_ARGS__)) + +namespace { + +template < + typename T, + typename src_t, + int depth, + int r_args_depth, + int res_arg_index> +struct CopyFunctor { + static_assert(depth == 2 && r_args_depth == 1 && res_arg_index == 1); + template + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListMetadata& tl, + Op op) { + const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; + const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + src_t* src_ptr = (src_t*)tl.addresses[0][tensor_loc]; + src_ptr += chunk_idx * chunk_size; + T* self_ptr = (T*)tl.addresses[1][tensor_loc]; + self_ptr += chunk_idx * chunk_size; + + const bool all_aligned{is_aligned(src_ptr) && is_aligned(self_ptr)}; + + n -= chunk_idx * chunk_size; + src_t src_args[kILP]; + T r_args[kILP]; + + // to make things simple, we put aligned case in a different code path + if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { + for (int64_t i_start = threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(src_args, src_ptr, 0, i_start); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[ii] = static_cast(op(src_args[ii])); + } + // store + load_store(self_ptr, r_args, i_start, 0); + } + } else { + for (int64_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + const auto i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + src_args[ii] = src_ptr[i]; + } + } +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[ii] = static_cast(op(src_args[ii])); + } + store_args(self_ptr, r_args, i_start, chunk_size, n); + } + } + } +}; + +} // anonymous namespace + void foreach_tensor_copy_list_kernel_cuda_( TensorList self, TensorList src, const bool non_blocking) { check_foreach_api_restrictions(self, src); - if (!can_use_fast_route( - self, src, /* does_op_promote_integer_inputs_to_float */ false)) { + if (!(_check_tensors_share_device_and_dtype( + {self, src}, /* skip_dtype_check */ true) && + std::all_of( + src.cbegin(), + src.cend(), + [&](const auto& t) -> bool { + return t.dtype() == src[0].dtype(); + }) && + _check_tensors_share_sizes_and_strides({self, src}))) { return at::native::foreach_tensor_copy_list_kernel_slow_( self, src, non_blocking); } @@ -280,16 +418,38 @@ void foreach_tensor_copy_list_kernel_cuda_( "foreach_tensor_copy", [&]() { using opmath_t = at::opmath_type; - multi_tensor_apply<2>( - tensor_lists, - UnaryOpFunctor< - scalar_t, - /* depth */ 2, - /* r_args_depth */ 1, - /* res_arg_index */ 1>(), - Identity()); + AT_DISPATCH_SOURCE_TYPES(src[0].scalar_type(), "foreach_tensor_copy", [&] { + if constexpr (std::is_same_v) { + multi_tensor_apply<2>( + tensor_lists, + UnaryOpFunctor< + scalar_t, + /* depth */ 2, + /* r_args_depth */ 1, + /* res_arg_index */ 1>(), + Copy()); + } else { + // Ref: + // https://github.com/pytorch/pytorch/blob/656134c38f4737d13c3f43fc5c59470bc23c1d2f/aten/src/ATen/native/Copy.cpp#L299-L301 + if (!self[0].is_complex() && src[0].is_complex()) { + TORCH_WARN_ONCE( + "Casting complex values to real discards the imaginary part"); + } + multi_tensor_apply<2>( + tensor_lists, + CopyFunctor< + scalar_t, + src_t, + /* depth */ 2, + /* r_args_depth */ 1, + /* res_arg_index */ 1>(), + Copy()); + } + }); }); increment_version(self); } +#undef AT_DISPATCH_SOURCE_TYPES + } // namespace at::native diff --git a/test/test_foreach.py b/test/test_foreach.py index 8465d538187c..2683b9823190 100644 --- a/test/test_foreach.py +++ b/test/test_foreach.py @@ -1248,6 +1248,28 @@ def test_foreach_copy_with_multi_device_inputs(self, device, dtype, op): copy_(t, s, non_blocking) self.assertEqual(ref_input, sample.input) + @onlyCUDA + @ops(filter(lambda op: op.name == "_foreach_copy", foreach_binary_op_db)) + def test_foreach_copy_with_multi_dtypes(self, device, dtype, op): + # check (a) multi_tensor_apply is called and (b) numerical parity with for-loop and Tensor.copy_ + foreach_copy_ = ForeachFuncWrapper(op.inplace_variant) + for sample in op.sample_inputs(device, dtype, noncontiguous=False): + for src_dtype in floating_types_and(torch.half, torch.bfloat16): + if src_dtype == dtype: + continue + self_tensors = [t.clone() for t in sample.input] + src_tensors = [t.to(src_dtype) for t in self_tensors] + out = foreach_copy_( + (self_tensors, src_tensors), is_cuda=True, expect_fastpath=True + ) + self.assertEqual( + out, + [ + torch.empty_like(t).copy_(s) + for t, s in zip(self_tensors, src_tensors) + ], + ) + # Test reverse-mode & forward-mode AD if supported. @onlyCUDA @ops( From 25447ba241b788eb942af6f93c1dac71deadee65 Mon Sep 17 00:00:00 2001 From: "Wang, Eikan" Date: Fri, 31 May 2024 15:16:51 +0000 Subject: [PATCH 201/706] Always Link libtorch and libtorch_cpu to ensure the functionality for AOT mode (#127381) Fix #126763: The root cause is that the produced library does not link any torch library because the vec ISA is invalid, and then it cannot run into another path without linking `libtorch` and `libtorch_cpu`. https://github.com/pytorch/pytorch/blob/main/torch/_inductor/codecache.py#L1637-L1642 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127381 Approved by: https://github.com/desertfire --- test/inductor/test_aot_inductor.py | 22 ++++++++++++++-------- torch/_inductor/codecache.py | 5 +++++ 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 03d18f7c3f3f..2a9a966628ad 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -12,6 +12,7 @@ import torch import torch._export import torch._inductor +import torch._inductor.config import torch.nn as nn from torch._dynamo.testing import rand_strided, same from torch._dynamo.utils import counters @@ -1313,14 +1314,19 @@ def fn(a, b, alpha=1.0): with self.assertRaises(RuntimeError): torch._export.aot_compile(fn, args=(a, b), kwargs={"alpha": 2.0}) - so_path = torch._export.aot_compile( - torch.ops.aten.add, args=(a, b), kwargs={"alpha": 2.0}, same_signature=False - ) - kernel_runner = AOTIRunnerUtil.load_runner(self.device, so_path) - res = kernel_runner.run([a, b]) - self.assertTrue(isinstance(res, list)) - self.assertTrue(len(res) == 1) - self.assertEqual(fn(a, b, alpha=2.0), res[0]) + for simdlen in [0, None]: + with torch._inductor.config.patch({"cpp.simdlen": simdlen}): + so_path = torch._export.aot_compile( + torch.ops.aten.add, + args=(a, b), + kwargs={"alpha": 2.0}, + same_signature=False, + ) + kernel_runner = AOTIRunnerUtil.load_runner(self.device, so_path) + res = kernel_runner.run([a, b]) + self.assertTrue(isinstance(res, list)) + self.assertTrue(len(res) == 1) + self.assertEqual(fn(a, b, alpha=2.0), res[0]) def test_buffer_mutation_2(self): class Model(torch.nn.Module): diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index b5aa1d1b8a61..617a7ba7e262 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1717,6 +1717,11 @@ def get_include_and_linking_paths( else: libs = ["omp"] if config.is_fbcode() else ["gomp"] + # For AOT mode, the produced library relies on torch cpu to set grad mode + # like aoti_torch_grad_mode_set_enabled + if aot_mode and sys.platform == "linux" and not config.is_fbcode(): + libs += ["torch", "torch_cpu"] + # Unconditionally import c10 for non-abi-compatible mode to use TORCH_CHECK - See PyTorch #108690 if not config.abi_compatible: libs += ["c10"] From 3c1cf03fde145bdbe1f5ffb81765d076c10b4c04 Mon Sep 17 00:00:00 2001 From: a-gardner1 Date: Sat, 1 Jun 2024 04:03:10 +0000 Subject: [PATCH 202/706] Add fake impl for aten.unique_dim (#126561) Follow-up to #113118 and #124306. Developed in coordination with the solution to https://github.com/microsoft/onnxscript/pull/1547 This PR adds the missing fake tensor implementation for `aten.unique_dim`, thus enabling tracing and compilation of `torch.unique` when `dim` is not None. Local testing has proceeded with the following simple script (provided that one has checked out the changes in https://github.com/microsoft/onnxscript/pull/1547): ```python import onnx import onnxruntime as ort import logging import numpy as np onnx_program = torch.onnx.dynamo_export( lambda x: torch.unique(x, dim=0, return_inverse=True), torch.arange(10), export_options=torch.onnx.ExportOptions( dynamic_shapes=True, diagnostic_options=torch.onnx.DiagnosticOptions( verbosity_level=logging.DEBUG))) onnx_program.save("torch_unique.onnx") onnx_inputs = onnx_program.adapt_torch_inputs_to_onnx(torch.arange(10)) onnx_outputs = onnx_program(*onnx_inputs) loaded_onnx_program = onnx.load("torch_unique.onnx") onnx.checker.check_model(loaded_onnx_program) ort_session = ort.InferenceSession("torch_unique.onnx") inputs = np.random.randint(0, 10, 10) print(f"Inputs: {inputs}") outputs = ort_session.run(None, { "l_x_": inputs }) print(f"Outputs: {outputs}") print("Success") ``` Co-authored-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/126561 Approved by: https://github.com/ezyang --- test/test_ops.py | 4 +-- test/test_proxy_tensor.py | 3 -- torch/_subclasses/fake_impls.py | 59 +++++++++++++++++++++++++-------- 3 files changed, 47 insertions(+), 19 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 44f503ae9b6e..cbec88136ed2 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -2522,8 +2522,8 @@ def map_to_fake(e): or name in sometimes_dynamic_output_op_test ) self.assertTrue( - mode.shape_env is None - or not mode.shape_env.allow_dynamic_output_shape_ops + fake_mode.shape_env is None + or not fake_mode.shape_env.allow_dynamic_output_shape_ops or name not in supported_dynamic_output_op_tests ) except torch._subclasses.fake_tensor.DataDependentOutputException: diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index d8aa8863d566..c7b2e51ced20 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -2003,7 +2003,6 @@ def f(t): xfail('nn.functional.ctc_loss'), # aten._ctc_loss.Tensor - couldn't find symbolic meta function/decomposition xfail('quantile', ''), # Could not run 'aten::equal' with arguments from the 'Meta' backend. xfail('unique_consecutive', ''), # aten.unique_consecutive.default - couldn't find symbolic meta function/decomposition - xfail('unique', ''), # aten._unique2.default - couldn't find symbolic meta function/decomposition xfail('max_pool2d_with_indices_backward', ''), # Expected a value of type 'List[int]' for argument 'kernel_size' but... @@ -2034,8 +2033,6 @@ def f(t): inplace_symbolic_tensor_failures = { # bugs xfail('float_power', ''), # base given to float_power_ has dtype Float but the operation's result requires dtype Double - # decomp not implemented - xfail('unique', ''), } out_symbolic_tensor_failures = { diff --git a/torch/_subclasses/fake_impls.py b/torch/_subclasses/fake_impls.py index 4376d24255ef..2b1cf13cc935 100644 --- a/torch/_subclasses/fake_impls.py +++ b/torch/_subclasses/fake_impls.py @@ -258,9 +258,8 @@ def dyn_shape(fake_mode, func, *args, **kwargs): raise DynamicOutputShapeException(func) -@register_op_impl(aten._unique2.default) -def unique2( - fake_mode, func, arg, sorted=True, return_inverse=False, return_counts=False +def _unique( + fake_mode, func, arg, dim, sorted=True, return_inverse=False, return_counts=False ): if ( fake_mode.shape_env is None @@ -269,7 +268,8 @@ def unique2( # Without symints/symfloats, cannot handle this raise DynamicOutputShapeException(func) - if (nnz := arg.unique_memo) is None: + # Do not use a memo for unique_dim + if dim is not None or (nnz := arg.unique_memo) is None: # Avoid importing sympy at a module level from torch.fx.experimental.symbolic_shapes import ( _constrain_range_for_size, @@ -291,28 +291,59 @@ def unique2( maxval = sys.maxsize - 1 - if not has_free_symbols(arg.numel()): - maxval = int(arg.numel()) + numel = arg.numel() if dim is None else arg.size(dim) + if not has_free_symbols(numel): + maxval = int(numel) _constrain_range_for_size(nnz, max=maxval) - arg.unique_memo = nnz + if dim is None: + arg.unique_memo = nnz - ret = [arg.new_empty((nnz,))] + if dim is None: + ret = [arg.new_empty((nnz,))] + else: + ret = [arg.new_empty(*arg.shape[:dim], nnz, *arg.shape[dim + 1 :])] - if return_inverse: - ret.append(torch.empty_like(arg)) + return_if_dim_and_cpu = dim is not None and arg.fake_device == torch.device("cpu") + if return_inverse or return_if_dim_and_cpu: + inverse = arg.new_empty(arg.shape if dim is None else (arg.shape[dim],)) else: - ret.append(arg.new_empty(0)) + inverse = arg.new_empty(0) + ret.append(inverse) - if return_counts: - ret.append(torch.empty_like(arg)) + if return_counts or return_if_dim_and_cpu: + counts = arg.new_empty(ret[0].shape if dim is None else (ret[0].shape[dim],)) else: - ret.append(arg.new_empty(0)) + counts = arg.new_empty(0) + ret.append(counts) return tuple(ret) +@register_op_impl(aten._unique2.default) +def unique2( + fake_mode, func, arg, sorted=True, return_inverse=False, return_counts=False +): + return _unique(fake_mode, func, arg, None, sorted, return_inverse, return_counts) + + +@register_op_impl(aten.unique_dim.default) +def unique_dim( + fake_mode, func, arg, dim, sorted=True, return_inverse=False, return_counts=False +): + return _unique( + fake_mode, + func, + arg, + # normalize dim to be non-negative + dim if dim >= 0 else dim % max(arg.ndim, 1), + sorted, + return_inverse, + return_counts, + ) + + @register_op_impl(aten.repeat_interleave.Tensor) def repeat_interleave_tensor(fake_mode, func, repeats, output_size=None): if output_size is None: From 7ef7c265d4361691dc4cf54152db083de3215fbf Mon Sep 17 00:00:00 2001 From: Huy Do Date: Sat, 1 Jun 2024 04:31:37 +0000 Subject: [PATCH 203/706] Ack codecvt_utf8_utf16 as a deprecated func in C++17 (#127659) https://en.cppreference.com/w/cpp/header/codecvt. This starts to fail on MacOS after migrating it to MacOS 14 with a newer toolchain. For example https://hud.pytorch.org/pytorch/pytorch/commit/57baae9c9b43fd31199dedd3f0fd5ed67faf5769. As there is no clear alternative to the deprecated function yet, I just ack the warning to fix the build and complete the migration https://github.com/pytorch/pytorch/issues/127490 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127659 Approved by: https://github.com/kit1980, https://github.com/atalman --- c10/util/StringUtil.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/c10/util/StringUtil.cpp b/c10/util/StringUtil.cpp index 084c59c7d161..1f5254a3deda 100644 --- a/c10/util/StringUtil.cpp +++ b/c10/util/StringUtil.cpp @@ -41,10 +41,15 @@ std::ostream& _strFromWide(std::ostream& ss, const std::wstring& wString); #ifndef _WIN32 +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" +// TODO (huydhn) https://en.cppreference.com/w/cpp/header/codecvt has been +// deprecated in C++17 but there is no alternative yet, so I just ack it std::ostream& _strFromWide(std::ostream& ss, const std::wstring& wString) { std::wstring_convert> converter; return _str(ss, converter.to_bytes(wString)); } +#pragma GCC diagnostic pop #else // #ifndef _WIN32 // The WIN32 implementation of wstring_convert leaks memory; see From 554265d4504108c1236035f8c957d3364f6c1123 Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Sat, 1 Jun 2024 06:15:31 +0000 Subject: [PATCH 204/706] [Inductor]: Use new device-agnostic libdevice import from triton.language (#127348) Triton refactored `libdevice` in https://github.com/triton-lang/triton/commit/5e6952d8c529770ff0321c8ded633c32af0ff9ea While both imports still appear to work under CUDA, this change is required to pull the correct libdevice variants under the Intel XPU backend. I am working on developing a test that catches this behavior. The easiest path would be to enable `test/inductor/test_triton_kernels.py` under the XPU backend, but a different group at Intel manages that test and I need to see if they already have an enabling plan. I am not sure the double `libdevice` import (see line 22 where I have the nolint flag) is really necessary but have yet to find a conclusive test case. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127348 Approved by: https://github.com/etaf, https://github.com/peterbell10 --- torch/_inductor/runtime/triton_helpers.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/torch/_inductor/runtime/triton_helpers.py b/torch/_inductor/runtime/triton_helpers.py index 9179e5cd676a..95708ada9020 100644 --- a/torch/_inductor/runtime/triton_helpers.py +++ b/torch/_inductor/runtime/triton_helpers.py @@ -16,15 +16,21 @@ class tl: # type: ignore[no-redef] # In the latest triton, math functions were shuffled around into different modules: # https://github.com/openai/triton/pull/3172 -if hasattr(tl.extra, "cuda") and hasattr(tl.extra.cuda, "libdevice"): - libdevice = tl.extra.cuda.libdevice - math = tl.math -elif hasattr(tl.extra, "intel") and hasattr(tl.extra.intel, "libdevice"): - libdevice = tl.extra.intel.libdevice +try: + from triton.language.extra import libdevice + + libdevice = tl.extra.libdevice # noqa: F811 math = tl.math -else: - libdevice = tl.math - math = tl +except ImportError: + if hasattr(tl.extra, "cuda") and hasattr(tl.extra.cuda, "libdevice"): + libdevice = tl.extra.cuda.libdevice + math = tl.math + elif hasattr(tl.extra, "intel") and hasattr(tl.extra.intel, "libdevice"): + libdevice = tl.extra.intel.libdevice + math = tl.math + else: + libdevice = tl.math + math = tl @triton.jit From e62925930f6a62f6aeeb1fe1a661a9bd3352b53d Mon Sep 17 00:00:00 2001 From: Shan19900305 Date: Sat, 1 Jun 2024 06:54:30 +0000 Subject: [PATCH 205/706] Clear dest impl extra_meta_ info when shallow_copy_from src impl to dest impl. (#127616) tensorA.data = tensorB will call shallow_copy_from function to copy tensorB metadata and storage to tensorA metadata and storage. If tensorB extra_meta_ is nullptr,then tensorA extra_meta_ still keep in tensorA. This will contaminate new meta data in tensorA. @ezyang @bdhirsh Pull Request resolved: https://github.com/pytorch/pytorch/pull/127616 Approved by: https://github.com/ezyang --- c10/core/TensorImpl.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/c10/core/TensorImpl.cpp b/c10/core/TensorImpl.cpp index 47f83c78e578..516a61f02004 100644 --- a/c10/core/TensorImpl.cpp +++ b/c10/core/TensorImpl.cpp @@ -577,6 +577,11 @@ void TensorImpl::copy_generic_tensor_metadata( dest_impl->numel_ = src_impl->numel_; if (src_impl->extra_meta_ != nullptr) { dest_impl->extra_meta_ = src_impl->extra_meta_->clone(); + } else if (dest_impl->extra_meta_ != nullptr) { + // Clean dest_impl extra meta data, cause shallow_copy_from dest impl is a + // real tensor impl, which maybe take extra meta data. This info will + // contaminate the new dest_impl metadata info. + dest_impl->extra_meta_.reset(nullptr); } // NB: symbolic sizes and strides are copied as is custom policy, but python From c3be459f26fe6050bf5041835eca64bf71202fa2 Mon Sep 17 00:00:00 2001 From: "haozhe.zhu" Date: Fri, 31 May 2024 15:45:18 +0800 Subject: [PATCH 206/706] [inductor] fix mkldnn linear binary fusion check ut (#127296) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In this PR: (1)Fix the unary fusion for bf16 conv/linear. Previously we registered same fusion pattern for `bf16. fp16`. And we do not check the dtype while matching the pattern. This results the `fp16` case matched the `bf16` pattern but in later replacement, we found that we have a float16 here which is not expected, so we do not fuse them. We fix it by checking dtypes to avoid `fp16` case matched `bf16` pattern. ``` def _is_valid_computation_unary_fusion(computation_op, lowp_dtype=None): def fn(match): matched = _is_single_computation_op(computation_op, **lowp_dtype**)(match) # previously we do not check lowp_dtype here ``` It is not exposed before because we only check the match count, and the match count is anyway correct because we matched the pattern. To address this, we add check on number of `generated_kernel`. If it is not fused, there will be an additional kernel to compute the post op. (2)Previous the ut ``` python test/inductor/test_mkldnn_pattern_matcher.py -k test_linear_binary ``` dose not check the fusion status, fix it in this PR. (3)Extend `test_conv_binary` to test with lp. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127296 Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5, https://github.com/jansel --- test/inductor/test_mkldnn_pattern_matcher.py | 93 ++++++++++++++++++-- torch/_inductor/fx_passes/mkldnn_fusion.py | 14 ++- 2 files changed, 94 insertions(+), 13 deletions(-) diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py index 756de35df84c..94fe34c64e53 100644 --- a/test/inductor/test_mkldnn_pattern_matcher.py +++ b/test/inductor/test_mkldnn_pattern_matcher.py @@ -10,7 +10,7 @@ from torch._dynamo import config as dynamo_config from torch._dynamo.utils import counters from torch._export import capture_pre_autograd_graph -from torch._inductor import config +from torch._inductor import config, metrics from torch._inductor.test_case import run_tests, TestCase from torch._inductor.utils import run_and_get_code from torch.ao.quantization.quantize_pt2e import ( @@ -83,6 +83,36 @@ def get_default_quantizer(is_qat, is_dynamic): return quantizer +def cal_conv_generated_kernel_number(mod, input, dtype): + # this function is to decide how many kernels are generated + # while testing conv2d/3d/deconv2d + # the assumption is: + # (1) There will be a to_dtype kernel for input for lp + # (2) inductor always use channe_last format, there will + # be a to_channel_last format for input + # (3) to_dtype and to_channel_last for input can be fused + # (4) inductor always get channel last format from mkldnn_conv_pointwise(binary), + # and force the output to have same stride with eager. + # So there will be a to_contiguous for output if eager output is contiguouse + mod = copy.deepcopy(mod) + input = input.clone() + if dtype == torch.float32: + maybe_autocast = contextlib.nullcontext() + else: + maybe_autocast = torch.cpu.amp.autocast(dtype=dtype) + with torch.no_grad(), maybe_autocast: + output = mod(input) + input_kernel, output_kernel = 0, 0 + if ( + input.is_contiguous(memory_format=torch.contiguous_format) + or dtype != torch.float32 + ): + input_kernel = 1 + if output.is_contiguous(memory_format=torch.contiguous_format): + output_kernel = 1 + return input_kernel + output_kernel + + @config.patch({"freezing": True}) class TestPatternMatcherBase(TestCase): def _check_unary_is_decomposed(self, unary_fn): @@ -264,6 +294,7 @@ def forward(self, x): memory_format, dtype, ) in options: + metrics.reset() if dim == 4: x_shape = (1, 3, 56, 56) else: @@ -284,10 +315,18 @@ def forward(self, x): # Has extra dtype conversion nodes for autocast. match_nodes += 2 self._test_common(mod, (v,), 2, match_nodes, check_autocast=dtype) + generated_kernel_count = cal_conv_generated_kernel_number(mod, v, dtype) + self.assertEqual(metrics.generated_kernel_count, generated_kernel_count) + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfRocm def test_conv2d_unary_cpu(self): self._test_conv_unary_cpu_base(dim=4) + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfRocm def test_conv3d_unary_cpu(self): self._test_conv_unary_cpu_base(dim=5) @@ -321,6 +360,7 @@ def forward(self, x): dtypes.append(torch.float16) options = itertools.product(unary_list, [True, False], dtypes) for unary_fn, bias, dtype in options: + metrics.reset() mod = M(unary_fn, 10, 30, bias=bias).eval() # only fuse for linear when the dtype is bf16 mod = mod @@ -335,6 +375,8 @@ def forward(self, x): self._test_common( mod, (v,), matcher_count, matcher_nodes, check_autocast=dtype ) + # only generated 1 kernel for "to" + self.assertEqual(metrics.generated_kernel_count, 1) @unittest.skipIf(not TEST_MKL, "Test requires MKL") def test_linear_fp32(self): @@ -354,6 +396,9 @@ def forward(self, x): matcher_nodes = 1 self._test_common(mod, (v,), matcher_count, matcher_nodes) + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfRocm def test_conv_transpose2d_unary(self): class M(torch.nn.Module): def __init__( @@ -386,6 +431,7 @@ def forward(self, x): ) for unary_fn, memory_format, dtype in options: + metrics.reset() x_shape = (1, 3, 28, 28) mod = M(unary_fn).eval() @@ -401,6 +447,8 @@ def forward(self, x): # Has extra dtype conversion nodes for autocast. match_nodes += 2 self._test_common(mod, (v,), 2, match_nodes, check_autocast=dtype) + generated_kernel_count = cal_conv_generated_kernel_number(mod, v, dtype) + self.assertEqual(metrics.generated_kernel_count, generated_kernel_count) def _test_conv_binary_base(self, dim=4): assert dim == 4 or dim == 5 @@ -430,19 +478,29 @@ def forward(self, x): else: return self.binary_fn(x1, x2) + dtypes = [ + torch.float, + ] + if torch.ops.mkldnn._is_mkldnn_bf16_supported(): + dtypes.append(torch.bfloat16) + if torch.ops.mkldnn._is_mkldnn_fp16_supported(): + dtypes.append(torch.float16) cl_format = torch.channels_last if dim == 4 else torch.channels_last_3d test_memory_format = [torch.contiguous_format, cl_format] options = itertools.product( binary_list, [True, False], test_memory_format, + dtypes, ) for ( binary_fn, has_relu, memory_format, + dtype, ) in options: + metrics.reset() if dim == 4: x_shape = (1, 3, 56, 56) else: @@ -457,11 +515,21 @@ def forward(self, x): match_nodes = binary_list[binary_fn][1] if has_relu: match_nodes += 1 - self._test_common(mod, (v,), match_count, match_nodes + 2) + self._test_common( + mod, (v,), match_count, match_nodes + 2, check_autocast=dtype + ) + generated_kernel_count = cal_conv_generated_kernel_number(mod, v, dtype) + self.assertEqual(metrics.generated_kernel_count, generated_kernel_count) + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfRocm def test_conv2d_binary(self): self._test_conv_binary_base(dim=4) + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfRocm def test_conv3d_binary(self): self._test_conv_binary_base(dim=5) @@ -489,7 +557,7 @@ def forward(self, x, y): ) out_feature = 30 for binary_fn, input_shape, bias, dtype in options: - torch._dynamo.reset() + metrics.reset() # addmm(mm) + (linear+add) match_count = 2 match_nodes = 3 @@ -498,13 +566,20 @@ def forward(self, x, y): # view + linear + view(joint_graph+freeze pass) match_count = match_count + 5 if is_inplace else match_count + 3 match_nodes = match_nodes + 7 if is_inplace else match_nodes + 5 - mod = M(binary_fn, input_shape[-1], out_feature, bias).to(dtype).eval() - v = torch.randn(input_shape).to(dtype) + mod = M(binary_fn, input_shape[-1], out_feature, bias).eval() + v = torch.randn(input_shape) other = torch.randn(input_shape[:-1] + [out_feature]).to(dtype) - mod_c = torch.compile(mod) - out, code = run_and_get_code(mod_c, v, other) - self.assertEqual(out, mod(v, other), rtol=1e-2, atol=1e-2) - # TODO - assert fusions work code + self._test_common( + mod, + ( + v, + other, + ), + match_count, + match_nodes, + check_autocast=dtype, + ) + self.assertEqual(metrics.generated_kernel_count, 1) def test_multi_linear_share_same_input(self): # llama pattern. diff --git a/torch/_inductor/fx_passes/mkldnn_fusion.py b/torch/_inductor/fx_passes/mkldnn_fusion.py index 3edb4a397932..5d1a723fa58a 100644 --- a/torch/_inductor/fx_passes/mkldnn_fusion.py +++ b/torch/_inductor/fx_passes/mkldnn_fusion.py @@ -197,9 +197,15 @@ def _binary_fusion_v1(computation_call, binary_fn): def _binary_fusion_v2(computation_call, binary_fn): return CallFunction(binary_fn, computation_call, KeywordArg("other")) - def _is_single_computation_op(computation_op): + def _is_single_computation_op(computation_op, lowp_dtype=None): def fn(match): computation_nodes = filter_nodes(match.nodes, computation_op) + + if lowp_dtype: + output_node_meta = match.output_node().meta.get("val") + if output_node_meta.dtype != lowp_dtype: + return False + if len(computation_nodes) < 1: return False if any(n.args[-3] != "none" for n in computation_nodes): @@ -210,7 +216,7 @@ def fn(match): def _is_valid_computation_unary_fusion(computation_op, lowp_dtype=None): def fn(match): - matched = _is_single_computation_op(computation_op)(match) + matched = _is_single_computation_op(computation_op, lowp_dtype)(match) computation_node = filter_nodes(match.nodes, computation_op)[0] if lowp_dtype: conversion_dtype_nodes = filter_nodes( @@ -249,7 +255,7 @@ def fn(match, *args, **kwargs): def _register_leaky_relu_fusion_lowering(pattern, computation_op, lowp_dtype=None): @register_lowering_pattern( - pattern, extra_check=_is_single_computation_op(computation_op) + pattern, extra_check=_is_single_computation_op(computation_op, lowp_dtype) ) def fn(match, *args, **kwargs): negative_slope = kwargs.get("negative_slope") @@ -291,7 +297,7 @@ def fn(match, *args, **kwargs): def _register_hardtanh_fusion_lowering(pattern, computation_op, lowp_dtype=None): @register_lowering_pattern( - pattern, extra_check=_is_single_computation_op(computation_op) + pattern, extra_check=_is_single_computation_op(computation_op, lowp_dtype) ) def fn(match, *args, **kwargs): min_value = kwargs.get("min_value") From 25994a7ed135f100ff36eb9fe41e9c75668329ef Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Fri, 31 May 2024 19:38:39 -0700 Subject: [PATCH 207/706] [AOTI] Fix a bug when mutated buffer meets .to (#127671) Summary: Before this change, the added unit test will trigger: `AssertionError: Can not find the original value for L__self____tensor_constant0_cuda0`. The reason is GraphLowering.constant_name could rename a constant with a device suffix but AOTI requires that new name being registered properly. Differential Revision: [D58047165](https://our.internmc.facebook.com/intern/diff/D58047165) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127671 Approved by: https://github.com/ColinPeppler, https://github.com/22quinn --- test/inductor/test_aot_inductor.py | 21 +++++++- torch/_inductor/graph.py | 86 +++++++++++++++--------------- 2 files changed, 63 insertions(+), 44 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 2a9a966628ad..15e10140d926 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -1302,7 +1302,6 @@ def forward(self, x): return self.foo + x example_inputs = (torch.rand(4, 4, device=self.device),) - torch._export.aot_compile(Model(self.device), example_inputs) self.check_model(Model(self.device), example_inputs) def test_non_tensor_input(self): @@ -1385,6 +1384,26 @@ def forward(self, inp_pos, k, v): self.check_model(model, example_inputs) self.code_check_count(model, example_inputs, "empty_strided", 2) + def test_buffer_mutation_4(self): + if self.device != "cuda": + raise unittest.SkipTest("requires CUDA") + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer( + "_tensor_constant0", + torch.randint(1, size=[38], dtype=torch.int64, device="cpu"), + ) + + def forward(self, x): + return x + self._tensor_constant0.to(torch.device(type="cuda", index=0)) + + example_inputs = ( + torch.randint(1, size=[38], dtype=torch.int64, device="cuda"), + ) + torch._export.aot_compile(Model(), example_inputs) + @requires_multigpu() def test_replicate_on_devices(self): if self.device != "cuda": diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index b7e8a1c48b74..412caf5e2242 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -802,46 +802,46 @@ def get_original_value_of_constant(self, name: str): else self.constants[name] ) - def add_tensor_constant(self, data, name=None): - def allocate(name): - if not config.aot_inductor.use_runtime_constant_folding: - for constant_name, value in self.constants.items(): - if ( - not data.is_mkldnn - and data.size() == value.size() - and data.stride() == value.stride() - and data.dtype == value.dtype - and data.device == value.device - and data.untyped_storage().data_ptr() - == value.untyped_storage().data_ptr() - and data.storage_offset() == value.storage_offset() - ): - return constant_name - - if name is None: - name = f"constant{len(self.constants)}" - if name[0].isdigit(): - name = f"constant_{name}" - name = self.qualify_name(name) - # We may generate a var name for each constant in the codegen. - # Let's only keep sane characters. - prefix = re.sub(r"[^a-zA-Z0-9_]", "_", name) - name = prefix - cnt = 0 - while name in self.constants: - name = f"{prefix}_{cnt}" - cnt += 1 - self.constants[name] = data - self.constant_reprs[name] = ( - f"{data.device!r} {data.dtype!r} " - f"{tuple(data.size())!r} {tuple(data.stride())!r} " - f"{hash(data):x}" - ) - return name - - new_name = allocate(name) - self.allocated_constant_name[new_name] = name + def allocate_non_dup_const_name(self, name, data): + orig_name = name + if not config.aot_inductor.use_runtime_constant_folding: + for constant_name, value in self.constants.items(): + if ( + not data.is_mkldnn + and data.size() == value.size() + and data.stride() == value.stride() + and data.dtype == value.dtype + and data.device == value.device + and data.untyped_storage().data_ptr() + == value.untyped_storage().data_ptr() + and data.storage_offset() == value.storage_offset() + ): + return constant_name + + if name is None: + name = f"constant{len(self.constants)}" + if name[0].isdigit(): + name = f"constant_{name}" + name = self.qualify_name(name) + # We may generate a var name for each constant in the codegen. + # Let's only keep sane characters. + prefix = re.sub(r"[^a-zA-Z0-9_]", "_", name) + name = prefix + cnt = 0 + while name in self.constants: + name = f"{prefix}_{cnt}" + cnt += 1 + self.constants[name] = data + self.constant_reprs[name] = ( + f"{data.device!r} {data.dtype!r} " + f"{tuple(data.size())!r} {tuple(data.stride())!r} " + f"{hash(data):x}" + ) + self.allocated_constant_name[name] = orig_name + return name + def add_tensor_constant(self, data, name=None): + new_name = self.allocate_non_dup_const_name(name, data) return TensorBox.create( ir.ConstantBuffer( new_name, @@ -857,10 +857,10 @@ def constant_name(self, name: str, device_override: Optional[torch.device]): """ if self.constants[name].device == device_override or device_override is None: return name - alt_name = f"{name}_{device_override.type}{device_override.index or 0}" - if alt_name not in self.constants: - self.constants[alt_name] = self.constants[name].to(device_override) - return alt_name + return self.allocate_non_dup_const_name( + f"{name}_{device_override.type}{device_override.index or 0}", + self.constants[name].to(device_override), + ) def placeholder(self, target: str, args, kwargs): example = super().placeholder(target, args, kwargs) From 4aa7a1efcfdfc36822053a5a6be54a0bcce05d7c Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Fri, 31 May 2024 09:31:11 -0700 Subject: [PATCH 208/706] [dynamo] Initial exception handling support (#126923) Pull Request resolved: https://github.com/pytorch/pytorch/pull/126923 Approved by: https://github.com/williamwen42, https://github.com/jansel --- test/dynamo/test_exceptions.py | 234 ++++++++++++++++++++++++++++ test/dynamo/test_misc.py | 2 +- test/dynamo/test_repros.py | 2 +- torch/_dynamo/exc.py | 4 + torch/_dynamo/symbolic_convert.py | 232 +++++++++++++++++++++++++-- torch/_dynamo/variables/__init__.py | 1 + torch/_dynamo/variables/builtin.py | 14 ++ torch/_dynamo/variables/misc.py | 12 ++ 8 files changed, 487 insertions(+), 14 deletions(-) create mode 100644 test/dynamo/test_exceptions.py diff --git a/test/dynamo/test_exceptions.py b/test/dynamo/test_exceptions.py new file mode 100644 index 000000000000..1cf31f9edc36 --- /dev/null +++ b/test/dynamo/test_exceptions.py @@ -0,0 +1,234 @@ +# Owner(s): ["module: dynamo"] + +import torch +import torch._dynamo.config + +import torch._dynamo.test_case +import torch._functorch.config +import torch.utils.checkpoint + + +class ExceptionTests(torch._dynamo.test_case.TestCase): + def test_exception(self): + def fn(x): + x = torch.cos(x) + try: + x = torch.sin(x) + raise NotImplementedError + except Exception: + x = torch.sigmoid(x) + + return x + + x = torch.randn(4) + ref = fn(x) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + res = opt_fn(x) + self.assertEqual(ref, res) + + def test_exception2(self): + def fn(x): + x = torch.cos(x) + try: + x = torch.sin(x) + raise NotImplementedError + except (NotImplementedError, AttributeError) as e: + x = torch.sigmoid(x) + + return x + + x = torch.randn(4) + ref = fn(x) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + res = opt_fn(x) + self.assertEqual(ref, res) + + def test_exception3(self): + def fn(x): + x = torch.cos(x) + try: + x = torch.sin(x) + raise NotImplementedError("Not implemented") + except AssertionError: + x = torch.sigmoid(x) + except NotImplementedError: + x = torch.cos(x) + finally: + x = torch.cos(x) + + return x + + x = torch.randn(4) + ref = fn(x) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + res = opt_fn(x) + self.assertEqual(ref, res) + + def test_exception_with_another_exception(self): + def fn(x): + x = torch.cos(x) + try: + x = torch.sin(x) + raise NotImplementedError("Not implemented") + except NotImplementedError as e: + x = torch.sigmoid(x) + try: + x = torch.cos(x) + raise AssertionError + except AssertionError: + x = torch.cos(x) + + x = torch.randn(4) + ref = fn(x) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + res = opt_fn(x) + self.assertEqual(ref, res) + + def test_exception_else(self): + def gn(x): + return torch.cos(x) + + def fn(x): + x = torch.cos(x) + try: + x = torch.sin(x) + x = gn(x) + except Exception: + x = torch.sigmoid(x) + else: + x = torch.cos(x) + + return x + + x = torch.randn(4) + ref = fn(x) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + res = opt_fn(x) + self.assertEqual(ref, res) + + # TODO(anijain2305) - does not work with fullgraph=True + def test_exception_with_another_exception2(self): + def gn(x): + try: + x = torch.cos(x) + raise NotImplementedError("Not implemented") + except NotImplementedError as e: + x = torch.sigmoid(x) + raise + + def fn(x): + try: + x = torch.cos(x) + gn(x) + except Exception: + pass + return x + + x = torch.randn(4) + ref = fn(x) + # Cant use fullgraph=True because RERAISE is not supported + opt_fn = torch.compile(fn, backend="eager") + res = opt_fn(x) + + # TODO(anijain2305) - does not work with fullgraph=True + def test_exception_with_ctx_manager(self): + def fn(x): + x = torch.cos(x) + try: + with torch.no_grad(): + x = torch.sin(x) + raise NotImplementedError("Not implemented") + except NotImplementedError as e: + x = torch.sigmoid(x) + return x + + x = torch.randn(4) + ref = fn(x) + # Cant use fullgraph=True because WITH_EXCEPT_START is not supported + opt_fn = torch.compile(fn, backend="eager") + res = opt_fn(x) + self.assertEqual(ref, res) + + def test_exception_raised_from_child(self): + def gn(): + raise NotImplementedError("foo") + + def fn(x): + x = torch.cos(x) + try: + x = torch.sin(x) + gn() + x = torch.sin(x) + except Exception: + x = torch.sigmoid(x) + + return x + + x = torch.randn(4) + ref = fn(x) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + res = opt_fn(x) + self.assertEqual(ref, res) + + def test_nn_module_getattr(self): + class A: + def __init__(self): + self._b = 20 + + def __getattr__(self, name): + fixed_name = "_" + name + if fixed_name in self.__dict__: + return self.__dict__[fixed_name] + raise AttributeError(f"{name} absent") + + class B(A): + def __init__(self): + self.a = 10 + + def __getattr__(self, name): + try: + return super().__getattr__(name) + except AttributeError: + return 30 + + obj = B() + + def fn(x): + return x * obj.a * obj.b * obj.c + + x = torch.ones(4) + ref = fn(x) + print(ref) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + res = opt_fn(x) + self.assertEqual(ref, res) + + @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True) + def test_custom_getattr_on_module_exception(self): + class Foo(torch.nn.Module): + def __init__(self, a=3): + super().__init__() + self.register_parameter("a", torch.nn.Parameter(torch.ones(4) * 2)) + + def __getattr__(self, name): + try: + return super().__getattr__(name) # defer to nn.Module's logic + except AttributeError: + if name == "a_copy": + return self.a + raise + + def forward(self, x): + return x * self.a * self.a_copy + + mod = Foo() + opt_mod = torch.compile(mod, backend="eager", fullgraph=True) + + x = torch.ones(4) + self.assertEqual(mod(x), opt_mod(x)) + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index d9611028f789..739dedbc8d05 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -6840,7 +6840,7 @@ def fn(): x += 1 return x - opt_fn = torch._dynamo.optimize("eager")(fn) + opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) self.assertEqual(opt_fn(), torch.tensor([2.0])) def test_nested_sequential_with(self): diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 4b151d8b093e..90e1d34e8acc 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -3051,7 +3051,7 @@ def f(x): with self.assertRaisesRegex(AssertionError, "torch.Size"): opt_f(args) self.assertEqual( - torch._dynamo.utils.counters["unimplemented"][ + torch._dynamo.utils.counters["graph_break"][ "assert with non-string message" ], 1, diff --git a/torch/_dynamo/exc.py b/torch/_dynamo/exc.py index 2ca4c311540e..d9f4c847d030 100644 --- a/torch/_dynamo/exc.py +++ b/torch/_dynamo/exc.py @@ -183,6 +183,10 @@ class IncorrectUsage(Exception): pass +class ObservedException(TorchDynamoException): + pass + + # These exceptions are ok to fallback to eager/graph_break. exceptions_allowed_to_be_fallback = ( torch._subclasses.fake_tensor.DataDependentOutputException, diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index a0014d3339c4..71ed48fbb292 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -653,6 +653,7 @@ class InstructionTranslatorBase( inconsistent_side_effects: bool current_speculation: Optional[SpeculationEntry] dispatch_table: List[Any] + exn_vt_stack: List[VariableTracker] exec_recorder: Optional[ExecutionRecorder] strict_checks_fn: Optional[Callable[[VariableTracker], bool]] @@ -802,6 +803,9 @@ def step(self): try: self.dispatch_table[inst.opcode](self, inst) return not self.output.should_exit + except exc.ObservedException: + self.exception_handler() + return True except ReturnValueOp: return False except Unsupported: @@ -991,9 +995,6 @@ def LOAD_GLOBAL(self, inst): assert name in self.f_builtins self.exec_recorder.builtins[name] = self.f_builtins[name] - if inst.argval == "AssertionError": - unimplemented("assert with non-string message") - if name in self.symbolic_globals: variable = self.output.side_effects[self.symbolic_globals[name]] self.push(self.output.side_effects.load_global(variable, name)) @@ -1130,21 +1131,24 @@ def IMPORT_FROM(self, inst): self.DUP_TOP(inst) self._load_attr(inst) - def load_builtin(self, inst): - if inst.argval not in self.f_builtins: - raise NameError(f"name '{inst.argval}' is not defined") - val = self.f_builtins[inst.argval] + def load_builtin_from_argval(self, argval): + if argval not in self.f_builtins: + raise NameError(f"name '{argval}' is not defined") + val = self.f_builtins[argval] if callable(val): builtins_source = GlobalSource( self.output.name_of_builtins_dict_key_in_fglobals ) - var_source = GetItemSource(builtins_source, inst.argval) + var_source = GetItemSource(builtins_source, argval) self.push(VariableBuilder(self, var_source)(val)) else: assert is_builtin_constant(val) self.push(ConstantVariable.create(value=val)) + def load_builtin(self, inst): + self.load_builtin_from_argval(inst.argval) + def jump(self, inst): self.instruction_pointer = self.indexof[inst.target] @@ -1236,16 +1240,213 @@ def RAISE_VARARGS(self, inst): unimplemented("re-raise") elif inst.arg == 1: val = self.pop() + + # TODO(anijain2305) - Merge StopIterationVariable to use the same exception infra. if ( isinstance(val, BuiltinVariable) and val.fn is StopIteration ) or isinstance(val, variables.StopIterationVariable): raise exc.UserStopIteration + + # User can raise exception in 2 ways + # 1) raise exception type - raise NotImplementedError + # 2) raise execption instance - raise NotImplemetedError("foo") + + # 1) when user raises exception type + if isinstance(val, variables.BuiltinVariable): + # Create the instance of the exception type + # https://github.com/python/cpython/blob/3.11/Python/ceval.c#L6547-L6549 + val = val.call_function(self, [], {}) + + # Save the exception in a global data structure + self.exn_vt_stack.append(val) + + # 2) when user raises exception instance + if isinstance(val, variables.ExceptionVariable): + raise exc.ObservedException(f"raised exception {val}") unimplemented(f"raise {exc}") else: unimplemented("raise ... from ...") + def exception_handler(self): + if sys.version_info >= (3, 11): + exn_tab_entry = self.current_instruction.exn_tab_entry + if exn_tab_entry: + # Implementation is based on https://github.com/python/cpython/blob/3.11/Objects/exception_handling_notes.txt + + # 1) pop values from the stack until it matches the stack depth + # for the handler + while len(self.stack) > exn_tab_entry.depth: + self.pop() + + # 2) if 'lasti' is true, then push the offset that the exception was raised at + if exn_tab_entry.lasti: + # This is untested. Any test that tests this end-to-end + # requires supporting more bytecodes. Therefore graph + # breaking for now. + unimplemented("lasti=True while exception handling") + self.push( + variables.ConstantVariable(self.current_instruction.offset) + ) + + # 3) push the exception to the stack + assert len(self.exn_vt_stack) + self.push(self.exn_vt_stack[-1]) + + # 4) jump to the handler + self.jump(exn_tab_entry) + else: + # No handler found. Bubble the exception to the parent + # instruction translater. We use special exception for this. + self.stack.clear() + if type(self) is InstructionTranslator: + raise Unsupported("Observed exception") + raise exc.ObservedException + else: + if len(self.block_stack): + # base implementation - https://github.com/python/cpython/blob/3.10/Python/ceval.c#L4455 + + assert len(self.exn_vt_stack) + exception_var = self.exn_vt_stack[-1] + + block_stack_entry = self.block_stack.pop() + + while block_stack_entry.inst.opname == "EXCEPT_HANDLER": + # TODO(anijain2305) - This is not tested .. unable to create a testcase + # https://github.com/python/cpython/blob/3.10/Python/ceval.c#L1456 + self.popn(3) + if len(self.block_stack) == 0: + unimplemented( + "exception is raised when block stack " "is empty" + ) + block_stack_entry = self.block_stack.pop() + + if block_stack_entry.inst.opname != "SETUP_FINALLY": + unimplemented( + "exception is raised when top of the block stack " + "is not exception handler (e.g. try .. with .. except). " + f"Current TOS is {block_stack_entry.inst}" + ) + + # Push a dummy block stack entry of EXCEPT_HANDLER + # https://github.com/python/cpython/blob/3.10/Python/ceval.c#L1456 + except_handler_inst = Instruction(1e6, "EXCEPT_HANDLER", None, 0) + self.block_stack.append(BlockStackEntry(except_handler_inst, None)) + + # Push old exception + if len(self.exn_vt_stack) >= 2: + old_exception = self.exn_vt_stack[-2] + + # Push the old exception on to stack - tb, value, type + # Traceback is currently mapped to UnknownVariable + self.push(variables.UnknownVariable()) + self.push(old_exception) + self.push(variables.BuiltinVariable(old_exception.exc_type)) + else: + # Push empty exception tb, value, type + self.push(variables.ConstantVariable(None)) + self.push(variables.ConstantVariable(None)) + self.push(variables.ConstantVariable(None)) + + # Push new exception - tb, val, type + # Traceback is currently mapped to UnknownVariable + self.push(variables.UnknownVariable()) + self.push(exception_var) + self.push(variables.BuiltinVariable(exception_var.exc_type)) + + # Jump to target + self.jump(block_stack_entry) + else: + # No handler found. Bubble the exception to the parent + # instruction translater. We use special exception for this. + self.stack.clear() + if type(self) is InstructionTranslator: + raise Unsupported("Observed exception") + raise exc.ObservedException + + def PUSH_EXC_INFO(self, inst): + val = self.pop() + assert len(self.exn_vt_stack) + self.push(self.exn_vt_stack[-1]) + self.push(val) + + def POP_EXCEPT(self, inst): + if sys.version_info >= (3, 11): + val = self.pop() + assert isinstance(val, variables.ExceptionVariable) + + # This exception is handled and therefore we can clear the error indicator + assert len(self.exn_vt_stack) + self.exn_vt_stack.pop() + else: + assert len(self.block_stack) > 0 + if self.block_stack[-1].inst.opname != "EXCEPT_HANDLER": + raise AssertionError( + "Bug in Dynamo tracing of exception handling." + "Top of the block stack is not EXCEPT_HANDLER." + ) + self.block_stack.pop() + + self.popn(3) + + # This exception is handled and therefore we can clear the error indicator + assert len(self.exn_vt_stack) + self.exn_vt_stack.pop() + + def check_if_exc_matches(self): + assert len(self.stack) >= 2 + expected_exc_types = self.pop() + exc_instance = self.stack[-1] + + # Users can check exception in 2 ways + # 1) except NotImplementedError --> BuilinVariable + # 2) except (NotImplemetedError, AttributeError) -> TupleVariable + + if not isinstance(expected_exc_types, (BuiltinVariable, TupleVariable)): + unimplemented( + f"except has an unsupported types of objects {expected_exc_types}" + ) + + if sys.version_info >= (3, 11): + if not isinstance(exc_instance, variables.ExceptionVariable): + unimplemented( + f"except expects to recieve an object of exception type but received {exc_instance}" + ) + + if isinstance(expected_exc_types, TupleVariable): + expected_types = expected_exc_types.items + else: + expected_types = [ + expected_exc_types, + ] + + for expected_type in expected_types: + if not isinstance(expected_type, BuiltinVariable): + unimplemented( + f"except has an unsupported types of object {expected_type}" + ) + if isinstance(exc_instance, variables.ExceptionVariable) and issubclass( + exc_instance.exc_type, expected_type.fn + ): + return True + elif isinstance(exc_instance, variables.BuiltinVariable) and issubclass( + exc_instance.fn, expected_type.fn + ): + return True + + return False + + def CHECK_EXC_MATCH(self, inst): + self.push(variables.ConstantVariable(self.check_if_exc_matches())) + + def JUMP_IF_NOT_EXC_MATCH(self, inst): + if not self.check_if_exc_matches(): + self.jump(inst) + def COMPARE_OP(self, inst): - self.push(compare_op_handlers[inst.argval](self, self.popn(2), {})) + if inst.argval == "exception match": + self.CHECK_EXC_MATCH(inst) + else: + self.push(compare_op_handlers[inst.argval](self, self.popn(2), {})) def GET_ITER(self, inst): self.call_function(BuiltinVariable(iter), [self.pop()], {}) @@ -1769,7 +1970,7 @@ def MATCH_KEYS(self, inst): self.push(ConstantVariable.create(False)) def LOAD_ASSERTION_ERROR(self, inst): - unimplemented("assert with non-string message") + self.load_builtin_from_argval("AssertionError") UNARY_POSITIVE = stack_op(operator.pos) UNARY_NEGATIVE = stack_op(operator.neg) @@ -2066,6 +2267,7 @@ def __init__( self.kw_names = None self.accept_prefix_inst = True self.prefix_insts = [] + self.exn_vt_stack = [] # Properties of the input/output code self.instructions: List[Instruction] = instructions @@ -2577,6 +2779,14 @@ def get_trace_call_log_str(): try: with strict_ctx: tracer.run() + except exc.ObservedException as e: + msg = f"Observed exception DURING INLING {code} : {e}" + # TODO(anijain2305) - This works but we should probably have a + # global/central data structure for the exception stack. + parent.exn_vt_stack.extend(tracer.exn_vt_stack) + log.debug(msg) + # bubble up the exception to the parent frame. + raise except exc.SkipFrame as e: msg = f"SKIPPED INLINING {code}: {e}" log.debug(msg) @@ -2757,8 +2967,6 @@ def LOAD_GLOBAL(self, inst): self.PUSH_NULL(inst) name = inst.argval - if inst.argval == "AssertionError": - unimplemented("assert with non-string message") _, fglobals_vt, global_source = self.get_globals_source_and_value(name) if self.output.side_effects.has_pending_mutation_of_attr(fglobals_vt, name): diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py index 06f634efb348..25bda3769eb4 100644 --- a/torch/_dynamo/variables/__init__.py +++ b/torch/_dynamo/variables/__init__.py @@ -63,6 +63,7 @@ AutogradFunctionVariable, ClosureVariable, DeletedVariable, + ExceptionVariable, GetAttrVariable, InspectSignatureVariable, LambdaVariable, diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 5603496193e2..ce1d4bf9a0dd 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -712,6 +712,20 @@ def _make_handler(fn, arg_types: List[type], has_kwargs: bool): tx, [v.realize() for v in args], kwargs ) + if inspect.isclass(fn) and issubclass(fn, Exception): + + def create_exception_class_object(tx, args, kwargs): + if fn is AssertionError and not all( + isinstance(x, variables.ConstantVariable) + and isinstance(x.value, str) + for x in args + ): + unimplemented("assert with non-string message") + + return variables.ExceptionVariable(fn, args, **kwargs) + + return create_exception_class_object + if obj.can_insert_in_graph() and not ( fn is operator.getitem and not issubclass(arg_types[0], variables.TensorVariable) diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index c053f04662a9..a5ac9c4d8fb4 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -179,6 +179,18 @@ def call_method( unimplemented(f"non-function or method super: {inner_fn}") +class ExceptionVariable(VariableTracker): + def __init__(self, exc_type, args, **kwargs): + super().__init__(**kwargs) + self.exc_type = exc_type + self.args = args + + def reconstruct(self, codegen): + codegen.load_import_from("builtins", self.exc_type.__name__) + codegen.foreach(self.args) + codegen.call_function(len(self.args), True) + + class UnknownVariable(VariableTracker): """ It could be anything! From ac60bdaf01e786a67f1da4c57df2772452a6f9d9 Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Fri, 31 May 2024 09:02:32 -0700 Subject: [PATCH 209/706] Allow slow foreach to run for any backend, not just CPU (#127412) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127412 Approved by: https://github.com/albanD --- aten/src/ATen/native/native_functions.yaml | 278 +++++++++--------- test/test_foreach.py | 2 +- .../_internal/common_methods_invocations.py | 44 ++- 3 files changed, 172 insertions(+), 152 deletions(-) diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 41968a72fd8e..a051f43e87eb 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -10323,14 +10323,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_add_scalar_kernel_slow + CompositeExplicitAutograd: foreach_tensor_add_scalar_kernel_slow CUDA: foreach_tensor_add_scalar_kernel_cuda - func: _foreach_add_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_add_scalar_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_add_scalar_kernel_slow_ CUDA: foreach_tensor_add_scalar_kernel_cuda_ autogen: _foreach_add.Scalar_out @@ -10338,14 +10338,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_add_list_kernel_slow + CompositeExplicitAutograd: foreach_tensor_add_list_kernel_slow CUDA: foreach_tensor_add_list_kernel_cuda - func: _foreach_add_.List(Tensor(a!)[] self, Tensor[] other, *, Scalar alpha=1) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_add_list_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_add_list_kernel_slow_ CUDA: foreach_tensor_add_list_kernel_cuda_ autogen: _foreach_add.List_out @@ -10353,14 +10353,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_add_scalarlist_kernel_slow + CompositeExplicitAutograd: foreach_tensor_add_scalarlist_kernel_slow CUDA: foreach_tensor_add_scalarlist_kernel_cuda - func: _foreach_add_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_add_scalarlist_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_add_scalarlist_kernel_slow_ CUDA: foreach_tensor_add_scalarlist_kernel_cuda_ autogen: _foreach_add.ScalarList_out @@ -10368,14 +10368,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_add_tensor_kernel_slow + CompositeExplicitAutograd: foreach_tensor_add_tensor_kernel_slow CUDA: foreach_tensor_add_tensor_kernel_cuda - func: _foreach_add_.Tensor(Tensor(a!)[] self, Tensor other, *, Scalar alpha=1) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_add_tensor_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_add_tensor_kernel_slow_ CUDA: foreach_tensor_add_tensor_kernel_cuda_ autogen: _foreach_add.Tensor_out @@ -10383,14 +10383,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_sub_scalar_kernel_slow + CompositeExplicitAutograd: foreach_tensor_sub_scalar_kernel_slow CUDA: foreach_tensor_sub_scalar_kernel_cuda - func: _foreach_sub_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_sub_scalar_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_sub_scalar_kernel_slow_ CUDA: foreach_tensor_sub_scalar_kernel_cuda_ autogen: _foreach_sub.Scalar_out @@ -10398,14 +10398,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_sub_list_kernel_slow + CompositeExplicitAutograd: foreach_tensor_sub_list_kernel_slow CUDA: foreach_tensor_sub_list_kernel_cuda - func: _foreach_sub_.List(Tensor(a!)[] self, Tensor[] other, *, Scalar alpha=1) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_sub_list_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_sub_list_kernel_slow_ CUDA: foreach_tensor_sub_list_kernel_cuda_ autogen: _foreach_sub.List_out @@ -10413,14 +10413,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_sub_scalarlist_kernel_slow + CompositeExplicitAutograd: foreach_tensor_sub_scalarlist_kernel_slow CUDA: foreach_tensor_sub_scalarlist_kernel_cuda - func: _foreach_sub_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_sub_scalarlist_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_sub_scalarlist_kernel_slow_ CUDA: foreach_tensor_sub_scalarlist_kernel_cuda_ autogen: _foreach_sub.ScalarList_out @@ -10428,14 +10428,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_mul_scalar_kernel_slow + CompositeExplicitAutograd: foreach_tensor_mul_scalar_kernel_slow CUDA: foreach_tensor_mul_scalar_kernel_cuda - func: _foreach_mul_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_mul_scalar_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_mul_scalar_kernel_slow_ CUDA: foreach_tensor_mul_scalar_kernel_cuda_ autogen: _foreach_mul.Scalar_out @@ -10443,14 +10443,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_mul_list_kernel_slow + CompositeExplicitAutograd: foreach_tensor_mul_list_kernel_slow CUDA: foreach_tensor_mul_list_kernel_cuda - func: _foreach_mul_.List(Tensor(a!)[] self, Tensor[] other) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_mul_list_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_mul_list_kernel_slow_ CUDA: foreach_tensor_mul_list_kernel_cuda_ autogen: _foreach_mul.List_out @@ -10458,14 +10458,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_mul_scalarlist_kernel_slow + CompositeExplicitAutograd: foreach_tensor_mul_scalarlist_kernel_slow CUDA: foreach_tensor_mul_scalarlist_kernel_cuda - func: _foreach_mul_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_mul_scalarlist_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_mul_scalarlist_kernel_slow_ CUDA: foreach_tensor_mul_scalarlist_kernel_cuda_ autogen: _foreach_mul.ScalarList_out @@ -10473,14 +10473,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_mul_tensor_kernel_slow + CompositeExplicitAutograd: foreach_tensor_mul_tensor_kernel_slow CUDA: foreach_tensor_mul_tensor_kernel_cuda - func: _foreach_mul_.Tensor(Tensor(a!)[] self, Tensor other) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_mul_tensor_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_mul_tensor_kernel_slow_ CUDA: foreach_tensor_mul_tensor_kernel_cuda_ autogen: _foreach_mul.Tensor_out @@ -10488,14 +10488,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_div_scalar_kernel_slow + CompositeExplicitAutograd: foreach_tensor_div_scalar_kernel_slow CUDA: foreach_tensor_div_scalar_kernel_cuda - func: _foreach_div_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_div_scalar_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_div_scalar_kernel_slow_ CUDA: foreach_tensor_div_scalar_kernel_cuda_ autogen: _foreach_div.Scalar_out @@ -10503,14 +10503,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_div_list_kernel_slow + CompositeExplicitAutograd: foreach_tensor_div_list_kernel_slow CUDA: foreach_tensor_div_list_kernel_cuda - func: _foreach_div_.List(Tensor(a!)[] self, Tensor[] other) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_div_list_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_div_list_kernel_slow_ CUDA: foreach_tensor_div_list_kernel_cuda_ autogen: _foreach_div.List_out @@ -10518,14 +10518,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_div_scalarlist_kernel_slow + CompositeExplicitAutograd: foreach_tensor_div_scalarlist_kernel_slow CUDA: foreach_tensor_div_scalarlist_kernel_cuda - func: _foreach_div_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_div_scalarlist_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_div_scalarlist_kernel_slow_ CUDA: foreach_tensor_div_scalarlist_kernel_cuda_ autogen: _foreach_div.ScalarList_out @@ -10533,14 +10533,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_div_tensor_kernel_slow + CompositeExplicitAutograd: foreach_tensor_div_tensor_kernel_slow CUDA: foreach_tensor_div_tensor_kernel_cuda - func: _foreach_div_.Tensor(Tensor(a!)[] self, Tensor other) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_div_tensor_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_div_tensor_kernel_slow_ CUDA: foreach_tensor_div_tensor_kernel_cuda_ autogen: _foreach_div.Tensor_out @@ -10548,14 +10548,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_max_scalar_kernel_slow + CompositeExplicitAutograd: foreach_tensor_clamp_max_scalar_kernel_slow CUDA: foreach_tensor_clamp_max_scalar_kernel_cuda - func: _foreach_clamp_max_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_max_scalar_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_clamp_max_scalar_kernel_slow_ CUDA: foreach_tensor_clamp_max_scalar_kernel_cuda_ autogen: _foreach_clamp_max.Scalar_out @@ -10563,14 +10563,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_max_list_kernel_slow + CompositeExplicitAutograd: foreach_tensor_clamp_max_list_kernel_slow CUDA: foreach_tensor_clamp_max_list_kernel_cuda - func: _foreach_clamp_max_.List(Tensor(a!)[] self, Tensor[] other) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_max_list_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_clamp_max_list_kernel_slow_ CUDA: foreach_tensor_clamp_max_list_kernel_cuda_ autogen: _foreach_clamp_max.List_out @@ -10578,14 +10578,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_max_scalarlist_kernel_slow + CompositeExplicitAutograd: foreach_tensor_clamp_max_scalarlist_kernel_slow CUDA: foreach_tensor_clamp_max_scalarlist_kernel_cuda - func: _foreach_clamp_max_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_max_scalarlist_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_clamp_max_scalarlist_kernel_slow_ CUDA: foreach_tensor_clamp_max_scalarlist_kernel_cuda_ autogen: _foreach_clamp_max.ScalarList_out @@ -10593,14 +10593,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_min_scalar_kernel_slow + CompositeExplicitAutograd: foreach_tensor_clamp_min_scalar_kernel_slow CUDA: foreach_tensor_clamp_min_scalar_kernel_cuda - func: _foreach_clamp_min_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_min_scalar_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_clamp_min_scalar_kernel_slow_ CUDA: foreach_tensor_clamp_min_scalar_kernel_cuda_ autogen: _foreach_clamp_min.Scalar_out @@ -10608,14 +10608,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_min_list_kernel_slow + CompositeExplicitAutograd: foreach_tensor_clamp_min_list_kernel_slow CUDA: foreach_tensor_clamp_min_list_kernel_cuda - func: _foreach_clamp_min_.List(Tensor(a!)[] self, Tensor[] other) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_min_list_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_clamp_min_list_kernel_slow_ CUDA: foreach_tensor_clamp_min_list_kernel_cuda_ autogen: _foreach_clamp_min.List_out @@ -10623,14 +10623,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_min_scalarlist_kernel_slow + CompositeExplicitAutograd: foreach_tensor_clamp_min_scalarlist_kernel_slow CUDA: foreach_tensor_clamp_min_scalarlist_kernel_cuda - func: _foreach_clamp_min_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_min_scalarlist_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_clamp_min_scalarlist_kernel_slow_ CUDA: foreach_tensor_clamp_min_scalarlist_kernel_cuda_ autogen: _foreach_clamp_min.ScalarList_out @@ -10639,14 +10639,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_min_scalar_kernel_slow + CompositeExplicitAutograd: foreach_tensor_clamp_min_scalar_kernel_slow CUDA: foreach_tensor_clamp_min_scalar_kernel_cuda - func: _foreach_maximum_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_min_scalar_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_clamp_min_scalar_kernel_slow_ CUDA: foreach_tensor_clamp_min_scalar_kernel_cuda_ autogen: _foreach_maximum.Scalar_out @@ -10655,14 +10655,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_min_list_kernel_slow + CompositeExplicitAutograd: foreach_tensor_clamp_min_list_kernel_slow CUDA: foreach_tensor_clamp_min_list_kernel_cuda - func: _foreach_maximum_.List(Tensor(a!)[] self, Tensor[] other) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_min_list_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_clamp_min_list_kernel_slow_ CUDA: foreach_tensor_clamp_min_list_kernel_cuda_ autogen: _foreach_maximum.List_out @@ -10671,14 +10671,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_min_scalarlist_kernel_slow + CompositeExplicitAutograd: foreach_tensor_clamp_min_scalarlist_kernel_slow CUDA: foreach_tensor_clamp_min_scalarlist_kernel_cuda - func: _foreach_maximum_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_min_scalarlist_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_clamp_min_scalarlist_kernel_slow_ CUDA: foreach_tensor_clamp_min_scalarlist_kernel_cuda_ autogen: _foreach_maximum.ScalarList_out @@ -10686,14 +10686,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_max_scalar_kernel_slow + CompositeExplicitAutograd: foreach_tensor_clamp_max_scalar_kernel_slow CUDA: foreach_tensor_clamp_max_scalar_kernel_cuda - func: _foreach_minimum_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_max_scalar_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_clamp_max_scalar_kernel_slow_ CUDA: foreach_tensor_clamp_max_scalar_kernel_cuda_ autogen: _foreach_minimum.Scalar_out @@ -10701,14 +10701,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_max_list_kernel_slow + CompositeExplicitAutograd: foreach_tensor_clamp_max_list_kernel_slow CUDA: foreach_tensor_clamp_max_list_kernel_cuda - func: _foreach_minimum_.List(Tensor(a!)[] self, Tensor[] other) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_max_list_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_clamp_max_list_kernel_slow_ CUDA: foreach_tensor_clamp_max_list_kernel_cuda_ autogen: _foreach_minimum.List_out @@ -10716,14 +10716,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_max_scalarlist_kernel_slow + CompositeExplicitAutograd: foreach_tensor_clamp_max_scalarlist_kernel_slow CUDA: foreach_tensor_clamp_max_scalarlist_kernel_cuda - func: _foreach_minimum_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_max_scalarlist_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_clamp_max_scalarlist_kernel_slow_ CUDA: foreach_tensor_clamp_max_scalarlist_kernel_cuda_ autogen: _foreach_minimum.ScalarList_out @@ -10731,28 +10731,28 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_addcdiv_scalar_slow + CompositeExplicitAutograd: foreach_tensor_addcdiv_scalar_slow CUDA: foreach_tensor_addcdiv_scalar_cuda - func: _foreach_addcdiv.ScalarList(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_addcdiv_scalarlist_slow + CompositeExplicitAutograd: foreach_tensor_addcdiv_scalarlist_slow CUDA: foreach_tensor_addcdiv_scalarlist_cuda - func: _foreach_addcdiv.Tensor(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_addcdiv_tensor_slow + CompositeExplicitAutograd: foreach_tensor_addcdiv_tensor_slow CUDA: foreach_tensor_addcdiv_tensor_cuda - func: _foreach_addcdiv_.Scalar(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_addcdiv_scalar_slow_ + CompositeExplicitAutograd: foreach_tensor_addcdiv_scalar_slow_ CUDA: foreach_tensor_addcdiv_scalar_cuda_ autogen: _foreach_addcdiv.Scalar_out @@ -10760,7 +10760,7 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_addcdiv_scalarlist_slow_ + CompositeExplicitAutograd: foreach_tensor_addcdiv_scalarlist_slow_ CUDA: foreach_tensor_addcdiv_scalarlist_cuda_ autogen: _foreach_addcdiv.ScalarList_out @@ -10768,7 +10768,7 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_addcdiv_tensor_slow_ + CompositeExplicitAutograd: foreach_tensor_addcdiv_tensor_slow_ CUDA: foreach_tensor_addcdiv_tensor_cuda_ autogen: _foreach_addcdiv.Tensor_out @@ -10776,28 +10776,28 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_addcmul_scalar_slow + CompositeExplicitAutograd: foreach_tensor_addcmul_scalar_slow CUDA: foreach_tensor_addcmul_scalar_cuda - func: _foreach_addcmul.ScalarList(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_addcmul_scalarlist_slow + CompositeExplicitAutograd: foreach_tensor_addcmul_scalarlist_slow CUDA: foreach_tensor_addcmul_scalarlist_cuda - func: _foreach_addcmul.Tensor(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_addcmul_tensor_slow + CompositeExplicitAutograd: foreach_tensor_addcmul_tensor_slow CUDA: foreach_tensor_addcmul_tensor_cuda - func: _foreach_addcmul_.Scalar(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_addcmul_scalar_slow_ + CompositeExplicitAutograd: foreach_tensor_addcmul_scalar_slow_ CUDA: foreach_tensor_addcmul_scalar_cuda_ autogen: _foreach_addcmul.Scalar_out @@ -10805,7 +10805,7 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_addcmul_scalarlist_slow_ + CompositeExplicitAutograd: foreach_tensor_addcmul_scalarlist_slow_ CUDA: foreach_tensor_addcmul_scalarlist_cuda_ autogen: _foreach_addcmul.ScalarList_out @@ -10813,7 +10813,7 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_addcmul_tensor_slow_ + CompositeExplicitAutograd: foreach_tensor_addcmul_tensor_slow_ CUDA: foreach_tensor_addcmul_tensor_cuda_ autogen: _foreach_addcmul.Tensor_out @@ -10821,14 +10821,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_abs_slow + CompositeExplicitAutograd: foreach_tensor_abs_slow CUDA: foreach_tensor_abs_cuda - func: _foreach_abs_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_abs_slow_ + CompositeExplicitAutograd: foreach_tensor_abs_slow_ CUDA: foreach_tensor_abs_cuda_ autogen: _foreach_abs.out @@ -10836,14 +10836,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_acos_slow + CompositeExplicitAutograd: foreach_tensor_acos_slow CUDA: foreach_tensor_acos_cuda - func: _foreach_acos_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_acos_slow_ + CompositeExplicitAutograd: foreach_tensor_acos_slow_ CUDA: foreach_tensor_acos_cuda_ autogen: _foreach_acos.out @@ -10851,14 +10851,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_asin_slow + CompositeExplicitAutograd: foreach_tensor_asin_slow CUDA: foreach_tensor_asin_cuda - func: _foreach_asin_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_asin_slow_ + CompositeExplicitAutograd: foreach_tensor_asin_slow_ CUDA: foreach_tensor_asin_cuda_ autogen: _foreach_asin.out @@ -10866,14 +10866,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_atan_slow + CompositeExplicitAutograd: foreach_tensor_atan_slow CUDA: foreach_tensor_atan_cuda - func: _foreach_atan_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_atan_slow_ + CompositeExplicitAutograd: foreach_tensor_atan_slow_ CUDA: foreach_tensor_atan_cuda_ autogen: _foreach_atan.out @@ -10881,14 +10881,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_ceil_slow + CompositeExplicitAutograd: foreach_tensor_ceil_slow CUDA: foreach_tensor_ceil_cuda - func: _foreach_ceil_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_ceil_slow_ + CompositeExplicitAutograd: foreach_tensor_ceil_slow_ CUDA: foreach_tensor_ceil_cuda_ autogen: _foreach_ceil.out @@ -10896,14 +10896,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_cos_slow + CompositeExplicitAutograd: foreach_tensor_cos_slow CUDA: foreach_tensor_cos_cuda - func: _foreach_cos_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_cos_slow_ + CompositeExplicitAutograd: foreach_tensor_cos_slow_ CUDA: foreach_tensor_cos_cuda_ autogen: _foreach_cos.out @@ -10911,14 +10911,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_cosh_slow + CompositeExplicitAutograd: foreach_tensor_cosh_slow CUDA: foreach_tensor_cosh_cuda - func: _foreach_cosh_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_cosh_slow_ + CompositeExplicitAutograd: foreach_tensor_cosh_slow_ CUDA: foreach_tensor_cosh_cuda_ autogen: _foreach_cosh.out @@ -10926,14 +10926,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_erf_slow + CompositeExplicitAutograd: foreach_tensor_erf_slow CUDA: foreach_tensor_erf_cuda - func: _foreach_erf_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_erf_slow_ + CompositeExplicitAutograd: foreach_tensor_erf_slow_ CUDA: foreach_tensor_erf_cuda_ autogen: _foreach_erf.out @@ -10941,14 +10941,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_erfc_slow + CompositeExplicitAutograd: foreach_tensor_erfc_slow CUDA: foreach_tensor_erfc_cuda - func: _foreach_erfc_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_erfc_slow_ + CompositeExplicitAutograd: foreach_tensor_erfc_slow_ CUDA: foreach_tensor_erfc_cuda_ autogen: _foreach_erfc.out @@ -10956,14 +10956,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_exp_slow + CompositeExplicitAutograd: foreach_tensor_exp_slow CUDA: foreach_tensor_exp_cuda - func: _foreach_exp_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_exp_slow_ + CompositeExplicitAutograd: foreach_tensor_exp_slow_ CUDA: foreach_tensor_exp_cuda_ autogen: _foreach_exp.out @@ -10971,14 +10971,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_expm1_slow + CompositeExplicitAutograd: foreach_tensor_expm1_slow CUDA: foreach_tensor_expm1_cuda - func: _foreach_expm1_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_expm1_slow_ + CompositeExplicitAutograd: foreach_tensor_expm1_slow_ CUDA: foreach_tensor_expm1_cuda_ autogen: _foreach_expm1.out @@ -10986,14 +10986,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_floor_slow + CompositeExplicitAutograd: foreach_tensor_floor_slow CUDA: foreach_tensor_floor_cuda - func: _foreach_floor_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_floor_slow_ + CompositeExplicitAutograd: foreach_tensor_floor_slow_ CUDA: foreach_tensor_floor_cuda_ autogen: _foreach_floor.out @@ -11001,14 +11001,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_frac_slow + CompositeExplicitAutograd: foreach_tensor_frac_slow CUDA: foreach_tensor_frac_cuda - func: _foreach_frac_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_frac_slow_ + CompositeExplicitAutograd: foreach_tensor_frac_slow_ CUDA: foreach_tensor_frac_cuda_ autogen: _foreach_frac.out @@ -11016,7 +11016,7 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensors are on different devices variants: function dispatch: - CPU: foreach_tensor_ternary_lerp_slow + CompositeExplicitAutograd: foreach_tensor_ternary_lerp_slow CUDA: foreach_tensor_lerp_ternary_cuda autogen: _foreach_lerp.List_out @@ -11024,7 +11024,7 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensors are on different devices variants: function dispatch: - CPU: foreach_tensor_ternary_lerp_slow_ + CompositeExplicitAutograd: foreach_tensor_ternary_lerp_slow_ CUDA: foreach_tensor_lerp_ternary_cuda_ autogen: _foreach_lerp.List_out @@ -11032,7 +11032,7 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensors are on different devices variants: function dispatch: - CPU: foreach_tensor_lerp_list_kernel_slow + CompositeExplicitAutograd: foreach_tensor_lerp_list_kernel_slow CUDA: foreach_tensor_lerp_list_cuda autogen: _foreach_lerp.Scalar_out @@ -11040,7 +11040,7 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensors are on different devices variants: function dispatch: - CPU: foreach_tensor_lerp_list_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_lerp_list_kernel_slow_ CUDA: foreach_tensor_lerp_list_cuda_ autogen: _foreach_lerp.Scalar_out @@ -11048,14 +11048,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_lgamma_slow + CompositeExplicitAutograd: foreach_tensor_lgamma_slow CUDA: foreach_tensor_lgamma_cuda - func: _foreach_lgamma_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_lgamma_slow_ + CompositeExplicitAutograd: foreach_tensor_lgamma_slow_ CUDA: foreach_tensor_lgamma_cuda_ autogen: _foreach_lgamma.out @@ -11063,14 +11063,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_log_slow + CompositeExplicitAutograd: foreach_tensor_log_slow CUDA: foreach_tensor_log_cuda - func: _foreach_log_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_log_slow_ + CompositeExplicitAutograd: foreach_tensor_log_slow_ CUDA: foreach_tensor_log_cuda_ autogen: _foreach_log.out @@ -11078,14 +11078,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_log10_slow + CompositeExplicitAutograd: foreach_tensor_log10_slow CUDA: foreach_tensor_log10_cuda - func: _foreach_log10_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_log10_slow_ + CompositeExplicitAutograd: foreach_tensor_log10_slow_ CUDA: foreach_tensor_log10_cuda_ autogen: _foreach_log10.out @@ -11093,14 +11093,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_log1p_slow + CompositeExplicitAutograd: foreach_tensor_log1p_slow CUDA: foreach_tensor_log1p_cuda - func: _foreach_log1p_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_log1p_slow_ + CompositeExplicitAutograd: foreach_tensor_log1p_slow_ CUDA: foreach_tensor_log1p_cuda_ autogen: _foreach_log1p.out @@ -11108,14 +11108,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_log2_slow + CompositeExplicitAutograd: foreach_tensor_log2_slow CUDA: foreach_tensor_log2_cuda - func: _foreach_log2_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_log2_slow_ + CompositeExplicitAutograd: foreach_tensor_log2_slow_ CUDA: foreach_tensor_log2_cuda_ autogen: _foreach_log2.out @@ -11123,7 +11123,7 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_max_slow + CompositeExplicitAutograd: foreach_tensor_max_slow CUDA: foreach_tensor_max_cuda autogen: _foreach_max.out @@ -11131,14 +11131,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_neg_slow + CompositeExplicitAutograd: foreach_tensor_neg_slow CUDA: foreach_tensor_neg_cuda - func: _foreach_neg_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_neg_slow_ + CompositeExplicitAutograd: foreach_tensor_neg_slow_ CUDA: foreach_tensor_neg_cuda_ autogen: _foreach_neg.out @@ -11146,7 +11146,7 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_norm_slow + CompositeExplicitAutograd: foreach_tensor_norm_slow CUDA: foreach_tensor_norm_cuda autogen: _foreach_norm.Scalar_out @@ -11154,35 +11154,35 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_pow_list_kernel_slow + CompositeExplicitAutograd: foreach_tensor_pow_list_kernel_slow CUDA: foreach_tensor_pow_list_kernel_cuda - func: _foreach_pow.Scalar(Tensor[] self, Scalar exponent) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_pow_scalar_kernel_slow + CompositeExplicitAutograd: foreach_tensor_pow_scalar_kernel_slow CUDA: foreach_tensor_pow_scalar_kernel_cuda - func: _foreach_pow.ScalarList(Tensor[] self, Scalar[] exponent) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_pow_scalarlist_kernel_slow + CompositeExplicitAutograd: foreach_tensor_pow_scalarlist_kernel_slow CUDA: foreach_tensor_pow_scalarlist_kernel_cuda - func: _foreach_pow.ScalarAndTensor(Scalar self, Tensor[] exponent) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_scalar_pow_list_kernel_slow + CompositeExplicitAutograd: foreach_scalar_pow_list_kernel_slow CUDA: foreach_scalar_pow_list_kernel_cuda - func: _foreach_pow_.List(Tensor(a!)[] self, Tensor[] exponent) -> () device_check: NoCheck variants: function dispatch: - CPU: foreach_tensor_pow_list_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_pow_list_kernel_slow_ CUDA: foreach_tensor_pow_list_kernel_cuda_ autogen: _foreach_pow.List_out @@ -11190,7 +11190,7 @@ device_check: NoCheck variants: function dispatch: - CPU: foreach_tensor_pow_scalar_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_pow_scalar_kernel_slow_ CUDA: foreach_tensor_pow_scalar_kernel_cuda_ autogen: _foreach_pow.Scalar_out @@ -11198,7 +11198,7 @@ device_check: NoCheck variants: function dispatch: - CPU: foreach_tensor_pow_scalarlist_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_pow_scalarlist_kernel_slow_ CUDA: foreach_tensor_pow_scalarlist_kernel_cuda_ autogen: _foreach_pow.ScalarList_out @@ -11206,14 +11206,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_reciprocal_slow + CompositeExplicitAutograd: foreach_tensor_reciprocal_slow CUDA: foreach_tensor_reciprocal_cuda - func: _foreach_reciprocal_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_reciprocal_slow_ + CompositeExplicitAutograd: foreach_tensor_reciprocal_slow_ CUDA: foreach_tensor_reciprocal_cuda_ autogen: _foreach_reciprocal.out @@ -11221,14 +11221,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_round_slow + CompositeExplicitAutograd: foreach_tensor_round_slow CUDA: foreach_tensor_round_cuda - func: _foreach_round_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_round_slow_ + CompositeExplicitAutograd: foreach_tensor_round_slow_ CUDA: foreach_tensor_round_cuda_ autogen: _foreach_round.out @@ -11236,14 +11236,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_sigmoid_slow + CompositeExplicitAutograd: foreach_tensor_sigmoid_slow CUDA: foreach_tensor_sigmoid_cuda - func: _foreach_sigmoid_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_sigmoid_slow_ + CompositeExplicitAutograd: foreach_tensor_sigmoid_slow_ CUDA: foreach_tensor_sigmoid_cuda_ autogen: _foreach_sigmoid.out @@ -11251,14 +11251,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_sign_slow + CompositeExplicitAutograd: foreach_tensor_sign_slow CUDA: foreach_tensor_sign_cuda - func: _foreach_sign_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_sign_slow_ + CompositeExplicitAutograd: foreach_tensor_sign_slow_ CUDA: foreach_tensor_sign_cuda_ autogen: _foreach_sign.out @@ -11266,14 +11266,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_sin_slow + CompositeExplicitAutograd: foreach_tensor_sin_slow CUDA: foreach_tensor_sin_cuda - func: _foreach_sin_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_sin_slow_ + CompositeExplicitAutograd: foreach_tensor_sin_slow_ CUDA: foreach_tensor_sin_cuda_ autogen: _foreach_sin.out @@ -11281,14 +11281,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_sinh_slow + CompositeExplicitAutograd: foreach_tensor_sinh_slow CUDA: foreach_tensor_sinh_cuda - func: _foreach_sinh_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_sinh_slow_ + CompositeExplicitAutograd: foreach_tensor_sinh_slow_ CUDA: foreach_tensor_sinh_cuda_ autogen: _foreach_sinh.out @@ -11296,14 +11296,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_sqrt_slow + CompositeExplicitAutograd: foreach_tensor_sqrt_slow CUDA: foreach_tensor_sqrt_cuda - func: _foreach_sqrt_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_sqrt_slow_ + CompositeExplicitAutograd: foreach_tensor_sqrt_slow_ CUDA: foreach_tensor_sqrt_cuda_ autogen: _foreach_sqrt.out @@ -11311,14 +11311,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_tan_slow + CompositeExplicitAutograd: foreach_tensor_tan_slow CUDA: foreach_tensor_tan_cuda - func: _foreach_tan_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_tan_slow_ + CompositeExplicitAutograd: foreach_tensor_tan_slow_ CUDA: foreach_tensor_tan_cuda_ autogen: _foreach_tan.out @@ -11326,14 +11326,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_tanh_slow + CompositeExplicitAutograd: foreach_tensor_tanh_slow CUDA: foreach_tensor_tanh_cuda - func: _foreach_tanh_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_tanh_slow_ + CompositeExplicitAutograd: foreach_tensor_tanh_slow_ CUDA: foreach_tensor_tanh_cuda_ autogen: _foreach_tanh.out @@ -11341,14 +11341,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_trunc_slow + CompositeExplicitAutograd: foreach_tensor_trunc_slow CUDA: foreach_tensor_trunc_cuda - func: _foreach_trunc_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_trunc_slow_ + CompositeExplicitAutograd: foreach_tensor_trunc_slow_ CUDA: foreach_tensor_trunc_cuda_ autogen: _foreach_trunc.out @@ -11356,7 +11356,7 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_zero_slow_ + CompositeExplicitAutograd: foreach_tensor_zero_slow_ CUDA: foreach_tensor_zero_cuda_ autogen: _foreach_zero, _foreach_zero.out @@ -11364,7 +11364,7 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_copy_list_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_copy_list_kernel_slow_ CUDA: foreach_tensor_copy_list_kernel_cuda_ autogen: _foreach_copy.out diff --git a/test/test_foreach.py b/test/test_foreach.py index 2683b9823190..61d81d18db7b 100644 --- a/test/test_foreach.py +++ b/test/test_foreach.py @@ -597,7 +597,7 @@ def test_binary_op_list_error_cases(self, device, dtype, op): # Empty lists for fop in ops_to_test: with self.assertRaisesRegex( - RuntimeError, "There were no tensor arguments to this function" + RuntimeError, "Tensor list must have at least one tensor." ): fop(tensors1, tensors2) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index f59cca11becc..cdb7d164eeb5 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -11040,11 +11040,23 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", + dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", + dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", + dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", + dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", + dtypes=(torch.half,), device_type="cpu"), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", + dtypes=(torch.half,), device_type="cpu"), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", + dtypes=(torch.half,), device_type="cpu"), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", + dtypes=(torch.half,), device_type="cpu"), DecorateInfo(unittest.skip("flaky"), "TestForeach", "test_parity", device_type="cpu", dtypes=(torch.complex64,)), DecorateInfo( unittest.expectedFailure, @@ -11075,10 +11087,14 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): supports_inplace_autograd=True, supports_forward_ad=True, decorators=( - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), # Samples have complex types and inplace only works if the dtype is complex. DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), @@ -11107,10 +11123,14 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), # fails with div_cpu is not implemented with ComplexHalf - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), ), ), ] From 4129c3e596497ac87dad58e4be0f8290b216ab97 Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Fri, 31 May 2024 11:08:17 -0700 Subject: [PATCH 210/706] Let us find out why we wrote foreach meta regs (#127623) Turns out it was for no reason!...well, after realizing that these ops are all CompositeExplicit, their meta impls come for free. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127623 Approved by: https://github.com/mikaylagawarecki ghstack dependencies: #127412 --- torch/_meta_registrations.py | 324 ------------------ .../_internal/common_methods_invocations.py | 12 +- 2 files changed, 8 insertions(+), 328 deletions(-) diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 95e3aa1eebf3..624801bf9afa 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -1,6 +1,5 @@ import math from enum import Enum -from functools import partial from typing import List, Optional, Sequence, Tuple, Union import torch @@ -19,7 +18,6 @@ corresponding_real_dtype, elementwise_dtypes, ELEMENTWISE_TYPE_PROMOTION_KIND, - FloatLike, IntLike, make_contiguous_strides_for, Number, @@ -3135,328 +3133,6 @@ def meta_addbmm(self, batch1, batch2, *, beta=1, alpha=1): return self.new_empty(self.size()) -def register_meta_foreach(ops): - def wrapper(fn): - def register(op): - op_name = str(op).split(".")[1] - scalar_op = getattr(aten, op_name.replace("_foreach_", "")) - - _add_op_to_registry( - meta_table, - op, - partial( - fn, - _scalar_op=scalar_op, - ), - ) - - pytree.tree_map_(register, ops) - return fn - - return wrapper - - -@register_meta_foreach( - [ - aten._foreach_abs, - aten._foreach_acos, - aten._foreach_asin, - aten._foreach_atan, - aten._foreach_ceil, - aten._foreach_cos, - aten._foreach_cosh, - aten._foreach_erf, - aten._foreach_erfc, - aten._foreach_exp, - aten._foreach_expm1, - aten._foreach_frac, - aten._foreach_floor, - aten._foreach_lgamma, - aten._foreach_log, - aten._foreach_log10, - aten._foreach_log1p, - aten._foreach_log2, - aten._foreach_max, - aten._foreach_neg, - aten._foreach_reciprocal, - aten._foreach_round, - aten._foreach_sigmoid, - aten._foreach_sign, - aten._foreach_sin, - aten._foreach_sinh, - aten._foreach_sqrt, - aten._foreach_tan, - aten._foreach_tanh, - aten._foreach_trunc, - aten._foreach_zero, - aten._foreach_add, - aten._foreach_sub, - aten._foreach_mul, - aten._foreach_div, - aten._foreach_clamp_min, - aten._foreach_clamp_max, - aten._foreach_lerp, - ], -) -def _meta_foreach_out_of_place(*args, _scalar_op=None, **kwargs): - torch._check( - isinstance(args[0], list), - lambda: (f"The first argument must be List[Tensor], but got {type(args[0])}."), - ) - - nelem = len(args[0]) - torch._check( - nelem > 0, - lambda: ("Tensor list must have at least one tensor."), - ) - - nlists = 1 - for iarg, arg in enumerate(args[1:]): - if isinstance(arg, list): - nlists += 1 - torch._check( - len(arg) == nelem, - lambda: ( - f"self and argument-{iarg+2} must match in length, " - f"but got {nelem} and {len(arg)}." - ), - ) - elif isinstance(arg, Tensor): - torch._check( - arg.dim() == 0 and arg.numel() == 1, - lambda: ( - "scalar tensor expected to be 0 dim but it has " - f"{arg.dim()} dimensions and {arg.numel()} elements." - ), - ) - else: - break - - result = [] - for elem in range(nelem): - each_args = [args[i][elem] for i in range(nlists)] - result.append(_scalar_op(*each_args, *args[nlists:], **kwargs)) - - return result - - -@register_meta_foreach( - [ - aten._foreach_abs_, - aten._foreach_acos_, - aten._foreach_asin_, - aten._foreach_atan_, - aten._foreach_ceil_, - aten._foreach_cos_, - aten._foreach_cosh_, - aten._foreach_erf_, - aten._foreach_erfc_, - aten._foreach_exp_, - aten._foreach_expm1_, - aten._foreach_frac_, - aten._foreach_floor_, - aten._foreach_lgamma_, - aten._foreach_log_, - aten._foreach_log10_, - aten._foreach_log1p_, - aten._foreach_log2_, - aten._foreach_neg_, - aten._foreach_reciprocal_, - aten._foreach_round_, - aten._foreach_sigmoid_, - aten._foreach_sign_, - aten._foreach_sin_, - aten._foreach_sinh_, - aten._foreach_sqrt_, - aten._foreach_tan_, - aten._foreach_tanh_, - aten._foreach_trunc_, - aten._foreach_zero_, - aten._foreach_add_, - aten._foreach_sub_, - aten._foreach_mul_, - aten._foreach_div_, - aten._foreach_clamp_min_, - aten._foreach_clamp_max_, - aten._foreach_lerp_, - aten._foreach_copy_, - ] -) -def _meta_foreach_inplace(*args, _scalar_op=None, **kwargs): - _meta_foreach_out_of_place(*args, _scalar_op=_scalar_op, **kwargs) - return - - -@register_meta([aten._foreach_pow_.Scalar]) -def meta__foreach_pow__scalar(self, exponent): - torch._check( - isinstance(exponent, FloatLike), - lambda: f"exponent must be a float but got {type(exponent)}", - ) - return - - -@register_meta([aten._foreach_pow.ScalarAndTensor]) -def meta__foreach_pow_scalar_and_tensor(self, exponent): - # Only foreach_pow has a ScalarAndTensor method and needs special - # handling because it does not work with _meta_foreach_out_of_place. - torch._check( - isinstance(exponent, List), - lambda: f"exponent must be a tensor list but got {type(exponent)}", - ) - return [torch.empty_like(e) for e in exponent] - - -@register_meta([aten._foreach_norm]) -def meta__foreach_norm(self, ord=2, dtype=None): - torch._check( - isinstance(self, list), - lambda: f"self must be a tensor list but got {type(self)}", - ) - torch._check( - isinstance(ord, Number), - lambda: f"ord must be an integer but got {type(ord)}", - ) - torch._check( - dtype is None or isinstance(dtype, torch.dtype), - lambda: f"dtype must be either None or torch.dtype but got {type(dtype)}", - ) - return [ - torch.empty( - (), - device=t.device, - dtype=t.dtype.to_real() if dtype is None else dtype.to_real(), - ) - for t in self - ] - - -def _check_foreach_binop_tensor_lists(self, other): - torch._check( - isinstance(self, List) and isinstance(other, List), - lambda: ( - "The first two arguments of must be List[Tensor], " - f"but got {type(self)} and {type(other)}." - ), - ) - torch._check( - len(self) > 0 and len(self) == len(other), - lambda: ( - "self and other must be non-empty and match in length, " - f"but got {len(self)} and {len(other)}." - ), - ) - - -@register_meta( - [ - aten._foreach_maximum, - aten._foreach_minimum, - ] -) -def meta__foreach_binop_scalar(*args): - # aten.maximum(Tensor, Scalar) does not exist. - return _meta_foreach_out_of_place(*args, _scalar_op=aten.clamp_min) - - -@register_meta( - [ - aten._foreach_maximum_, - aten._foreach_minimum_, - ] -) -def meta__foreach_binop__scalar(*args): - # aten.maximum(Tensor, Scalar) does not exist - _meta_foreach_inplace(*args, _scalar_op=aten.clamp_min_) - return - - -@register_meta( - [ - aten._foreach_addcdiv.Scalar, - aten._foreach_addcmul.Scalar, - ] -) -def meta__foreach_addcop_scalar(self, tensor1, tensor2, scalar=1): - # forach_addcdiv and addcdiv have different signatures and - # cannot use _meta_foreach_out_of_place. - torch._check( - all(isinstance(l, List) for l in [self, tensor1, tensor2]), - lambda: ( - "All arguments must be List[Tensor], " - f"but got {type(self)}, {type(tensor1)}, and {type(tensor2)}" - ), - ) - torch._check(len(self) > 0, lambda: "input tensor list must not be empty.") - torch._check( - len(self) == len(tensor1) and len(self) == len(tensor2), - lambda: "All input tensor lists must have the same length", - ) - - return [torch.empty_like(s) for s in self] - - -@register_meta([aten._foreach_addcdiv_.Tensor, aten._foreach_addcmul_.Tensor]) -def meta__foreach_addcop_tensor(self, tensor1, tensor2, scalars): - torch._check( - all(isinstance(l, List) for l in [self, tensor1, tensor2]) - and isinstance(scalars, torch.Tensor), - lambda: ( - "_foreach_addc*_ op expects arguments of type: List[Tensor], List[Tensor], List[Tensor], tensor, " - f"but got: {type(self)}, {type(tensor1)}, {type(tensor2)}, and {type(scalars)}" - ), - ) - torch._check(len(self) > 0, lambda: "input tensor list must not be empty.") - torch._check( - len(self) == len(tensor1) and len(self) == len(tensor2), - lambda: "All input tensor lists must have the same length", - ) - - -@register_meta( - [ - aten._foreach_addcdiv_.Scalar, - aten._foreach_addcmul_.Scalar, - ] -) -def meta__foreach_addcop__scalar(self, tensor1, tensor2, scalar=1): - torch._check( - all(isinstance(l, List) for l in [self, tensor1, tensor2]), - lambda: ( - "All arguments of _foreach_addc*_ must be List[Tensor], " - f"but got {type(self)}, {type(tensor1)}, and {type(tensor2)}" - ), - ) - torch._check(len(self) > 0, lambda: "input tensor list must not be empty.") - torch._check( - len(self) == len(tensor1) and len(self) == len(tensor2), - lambda: "All input tensor lists must have the same length", - ) - - -@register_meta( - [ - aten._foreach_addcdiv_.ScalarList, - aten._foreach_addcmul_.ScalarList, - ] -) -def meta__foreach_addcop__scalarlist(self, tensor1, tensor2, scalars): - torch._check( - all(isinstance(l, List) for l in [self, tensor1, tensor2, scalars]), - lambda: ( - "_foreach_addc*_ op expects arguments of type: List[Tensor], List[Tensor], List[Tensor], List[Scalar], " - f"but got {type(self)}, {type(tensor1)}, {type(tensor2)}, and {type(scalars)}" - ), - ) - torch._check(len(self) > 0, lambda: "input tensor list must not be empty.") - torch._check( - len(self) == len(tensor1) - and len(self) == len(tensor2) - and len(self) == len(scalars), - lambda: "All input tensor lists must have the same length", - ) - - @register_meta([aten._fused_adam_.default]) def meta__fused_adam_( self, diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index cdb7d164eeb5..50cfac763be5 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -11037,10 +11037,14 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): supports_inplace_autograd=True, supports_forward_ad=True, decorators=( - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", dtypes=(torch.bool,)), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", From efcea2d2fd1654311cd88ddbfa2406e351c34bcd Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Fri, 31 May 2024 09:31:11 -0700 Subject: [PATCH 211/706] [dynamo] Support __getitem__ on NNModuleVariable __dict__ (#126956) Moves further along (but still fails) for the testcase in https://github.com/pytorch/pytorch/pull/126875 Pull Request resolved: https://github.com/pytorch/pytorch/pull/126956 Approved by: https://github.com/jansel ghstack dependencies: #126923 --- test/dynamo/test_misc.py | 18 ++++++++++++++++++ torch/_dynamo/variables/misc.py | 12 ++++++++---- torch/_dynamo/variables/nn_module.py | 16 ++++++++++++++++ torch/_dynamo/variables/user_defined.py | 1 + 4 files changed, 43 insertions(+), 4 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 739dedbc8d05..83ba3936f2de 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -10407,6 +10407,24 @@ def fn(x): res = opt_fn(x) self.assertEqual(ref, res) + def test_module_dunder_dict(self): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.foo = 1 + self.bar = 2 + self.baz = 3 + + def forward(self, x): + if "foo" in self.__dict__: + return x * self.bar + return x * self.baz + + mod = MyModule() + x = torch.randn(10) + opt_mod = torch.compile(mod, backend="eager", fullgraph=True) + self.assertEqual(mod(x), opt_mod(x)) + class TestTracer(JitTestCase): def test_jit_save(self): diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index a5ac9c4d8fb4..cc0fb7096701 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -691,11 +691,13 @@ def call_method( and self.name == "__dict__" and not kwargs and args[0].is_python_constant() - and isinstance(self.obj, variables.UserDefinedObjectVariable) + and isinstance( + self.obj, + (variables.UserDefinedObjectVariable, variables.NNModuleVariable), + ) ): obj = self.obj key = args[0].as_python_constant() - obj._check_for_getattribute() if obj.has_key_in_generic_dict(tx, key): # redirect to var_getattr on the original obj return obj.var_getattr(tx, key) @@ -713,11 +715,13 @@ def call_method( and len(args) == 1 and args[0].is_python_constant() and not kwargs - and isinstance(self.obj, variables.UserDefinedObjectVariable) + and isinstance( + self.obj, + (variables.UserDefinedObjectVariable, variables.NNModuleVariable), + ) ): obj = self.obj key = args[0].as_python_constant() - obj._check_for_getattribute() if obj.has_key_in_generic_dict(tx, key): return variables.ConstantVariable(True) else: diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py index f71767a7b7cb..e1848de97935 100644 --- a/torch/_dynamo/variables/nn_module.py +++ b/torch/_dynamo/variables/nn_module.py @@ -189,6 +189,19 @@ def convert_to_unspecialized(self, tx): GenerationTracker.mark_class_dynamic(type(mod)) raise UnspecializeRestartAnalysis + def has_key_in_generic_dict(self, tx, key): + base = tx.output.get_submodule(self.module_key) + + if object_has_getattribute(base): + unimplemented("NNModuleVariable with custom __getattribute__") + + if tx.output.side_effects.has_pending_mutation_of_attr(self, key): + mutated_attr = tx.output.side_effects.load_attr(self, key, deleted_ok=True) + return not isinstance(mutated_attr, variables.DeletedVariable) + + base_dict = object.__getattribute__(base, "__dict__") + return key in base_dict + def _custom_getattr_fallback(self, base, tx, name, options): """Check for a __getattr__ and handle it specially if it is implemented""" if object_has_getattribute(base): @@ -223,6 +236,9 @@ def var_getattr(self, tx, name): if not self.source: unimplemented("GETATTR with no source") + if name == "__dict__": + return variables.GetAttrVariable(self, name, source=source) + if name in base_dict: subobj = base_dict[name] elif ( diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index ca913060abf9..5b785293911f 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -819,6 +819,7 @@ def _getattr_static(self, name): return subobj def has_key_in_generic_dict(self, tx, key): + self._check_for_getattribute() if tx.output.side_effects.has_pending_mutation_of_attr(self, key): mutated_attr = tx.output.side_effects.load_attr(self, key, deleted_ok=True) return not isinstance(mutated_attr, variables.DeletedVariable) From 114c752b14e7f1226a889cf7939f32c20df06a38 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sat, 1 Jun 2024 16:39:06 +0000 Subject: [PATCH 212/706] Revert "Improve MAGMA conditional macro in BatchLinearAlgebra.cpp (#127495)" This reverts commit ee08cf57924a4230edad3101666890d8fe050c75. Reverted https://github.com/pytorch/pytorch/pull/127495 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/127495#issuecomment-2143508218)) --- .../native/cuda/linalg/BatchLinearAlgebra.cpp | 213 ++++++++++-------- 1 file changed, 115 insertions(+), 98 deletions(-) diff --git a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp index 18a1316fb567..2122d2af5f6a 100644 --- a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp @@ -40,6 +40,8 @@ #include #include +const bool use_magma_ = true; + namespace { struct MagmaInitializer { MagmaInitializer() { @@ -59,6 +61,9 @@ struct MagmaInitializer { #error "MAGMA release minor or micro version >= 10, please correct AT_MAGMA_VERSION" #endif +#else +const bool use_magma_ = false; + #endif namespace at::native { @@ -79,9 +84,9 @@ void magmaLdlHermitian( magma_int_t ldda, magma_int_t* ipiv, magma_int_t* info) { - static_assert( - false&&sizeof(scalar_t), - "LDL decomposition is not available." + TORCH_CHECK( + false, + "LDL decomposition is not available.", "Please rebuild with MAGMA 2.5.4+."); } @@ -1029,13 +1034,18 @@ magma_trans_t to_magma(TransposeType trans) { namespace { -#if AT_MAGMA_ENABLED() template void apply_ldl_factor_magma( const Tensor& A, const Tensor& pivots, const Tensor& info, bool upper) { +#if !AT_MAGMA_ENABLED() + TORCH_CHECK( + false, + "torch.linalg.ldl_factor: MAGMA library not found in " + "compilation. Please rebuild with MAGMA."); +#else auto batch_size = batchCount(A); magma_int_t n = magma_int_cast(A.size(-2), "A.size(-2)"); magma_int_t leading_dim = magma_int_cast(A.stride(-1), "A.stride(-1)"); @@ -1066,6 +1076,7 @@ void apply_ldl_factor_magma( } pivots.copy_(pivots_cpu); info.copy_(info_cpu); +#endif } void ldl_factor_magma( @@ -1087,7 +1098,6 @@ void ldl_factor_magma( apply_ldl_factor_magma(LD, pivots, info, upper); }); } -#endif void ldl_factor_kernel( const Tensor& LD, @@ -1100,10 +1110,8 @@ void ldl_factor_kernel( case at::LinalgBackend::Cusolver: return ldl_factor_cusolver( LD, pivots, info, upper, hermitian); -#if AT_MAGMA_ENABLED() case at::LinalgBackend::Magma: return ldl_factor_magma(LD, pivots, info, upper, hermitian); -#endif default: // By default use cusolver if available and magma otherwise. // If cusolver and magma 2.5.4+ are both available and hermitian=true, @@ -1147,9 +1155,12 @@ REGISTER_CUDA_DISPATCH(ldl_solve_stub, &ldl_solve_kernel) // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ cholesky_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -#if AT_MAGMA_ENABLED() template static void apply_cholesky_solve(Tensor& b, Tensor& A, bool upper, int64_t& info) { +#if !AT_MAGMA_ENABLED() +AT_ERROR("cholesky_solve: MAGMA library not found in " + "compilation. Please rebuild with MAGMA."); +#else magma_uplo_t uplo = upper ? MagmaUpper : MagmaLower; auto A_data = A.data_ptr(); @@ -1168,8 +1179,8 @@ static void apply_cholesky_solve(Tensor& b, Tensor& A, bool upper, int64_t& info auto b_mat_stride = matrixStride(b); magma_int_t batch_size = magma_int_cast(batchCount(A), "batchCount"); - scalar_t** A_array = nullptr; - scalar_t** b_array = nullptr; + scalar_t** A_array; + scalar_t** b_array; ALLOCATE_ARRAY(A_array, scalar_t*, batch_size); ALLOCATE_ARRAY(b_array, scalar_t*, batch_size); @@ -1186,7 +1197,7 @@ static void apply_cholesky_solve(Tensor& b, Tensor& A, bool upper, int64_t& info // Compute as many batches of 65535 possible // The number of "mini"-batches are floor(batch_size / batch_limit) // and these cover floor(batch_size / batch_limit) * batch_limit matrix solves - int64_t mini_batches = batch_size / batch_limit, mini_idx = 0; + int64_t mini_batches = batch_size / batch_limit, mini_idx; for (mini_idx = 0; mini_idx < mini_batches * batch_limit; mini_idx += batch_limit) { scalar_t** A_array_cur = &A_array[mini_idx]; scalar_t** b_array_cur = &b_array[mini_idx]; @@ -1210,6 +1221,7 @@ static void apply_cholesky_solve(Tensor& b, Tensor& A, bool upper, int64_t& info info = info_tmp; } +#endif } Tensor _cholesky_solve_helper_cuda_magma(const Tensor& self, const Tensor& A, bool upper) { @@ -1222,7 +1234,6 @@ Tensor _cholesky_solve_helper_cuda_magma(const Tensor& self, const Tensor& A, bo TORCH_CHECK(info == 0, "MAGMA cholesky_solve : invalid argument: ", -info); return self_working_copy; } -#endif // Todo: cusolverDnpotrsBatched only supports nrhs == 1 and does not have good performance. // Batched cholesky_solve is dispatched to magma. @@ -1232,20 +1243,14 @@ Tensor _cholesky_solve_helper_cuda(const Tensor& self, const Tensor& A, bool upp switch (preferred_backend) { case at::LinalgBackend::Cusolver: return _cholesky_solve_helper_cuda_cusolver(self, A, upper); -#if AT_MAGMA_ENABLED() case at::LinalgBackend::Magma: return _cholesky_solve_helper_cuda_magma(self, A, upper); -#endif default: -#if !AT_MAGMA_ENABLED() - return _cholesky_solve_helper_cuda_cusolver(self, A, upper); -#else - if (batchCount(self) == 1) { + if (batchCount(self) == 1 || !use_magma_) { return _cholesky_solve_helper_cuda_cusolver(self, A, upper); } else { return _cholesky_solve_helper_cuda_magma(self, A, upper); } -#endif } #else return _cholesky_solve_helper_cuda_magma(self, A, upper); @@ -1254,9 +1259,14 @@ Tensor _cholesky_solve_helper_cuda(const Tensor& self, const Tensor& A, bool upp // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ cholesky ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -#if AT_MAGMA_ENABLED() template static void apply_cholesky(const Tensor& self, bool upper, const Tensor& info) { +#if !AT_MAGMA_ENABLED() + TORCH_CHECK( + false, + "Calling torch.linalg.cholesky on a CUDA tensor requires compiling ", + "PyTorch with MAGMA. Please use PyTorch built with MAGMA support."); +#else magma_uplo_t uplo = upper ? MagmaUpper : MagmaLower; auto self_data = self.data_ptr(); @@ -1278,7 +1288,7 @@ static void apply_cholesky(const Tensor& self, bool upper, const Tensor& info) { auto self_mat_stride = matrixStride(self); magma_int_t batch_size = magma_int_cast(batchCount(self), "batchCount"); - scalar_t** self_array = nullptr; + scalar_t** self_array; ALLOCATE_ARRAY(self_array, scalar_t*, batch_size); @@ -1304,6 +1314,7 @@ static void apply_cholesky(const Tensor& self, bool upper, const Tensor& info) { uplo, n, self_array_cur, lda, info_array_cur, nbatches, magma_queue); } } +#endif } void cholesky_helper_magma(const Tensor& input, bool upper, const Tensor& info) { @@ -1339,7 +1350,6 @@ void cholesky_helper_magma(const Tensor& input, bool upper, const Tensor& info) } } } -#endif static void cholesky_kernel(const Tensor& input, const Tensor& info, bool upper) { #if defined(USE_LINALG_SOLVER) && !defined(USE_ROCM) @@ -1348,21 +1358,15 @@ static void cholesky_kernel(const Tensor& input, const Tensor& info, bool upper) case at::LinalgBackend::Cusolver: cholesky_helper_cusolver(input, upper, info); break; -#if AT_MAGMA_ENABLED() case at::LinalgBackend::Magma: cholesky_helper_magma(input, upper, info); break; -#endif default: -#if !AT_MAGMA_ENABLED() - cholesky_helper_cusolver(input, upper, info); -#else - if (batchCount(input) == 1 || use_cusolver_potrf_batched_) { + if (batchCount(input) == 1 || !use_magma_ || use_cusolver_potrf_batched_) { cholesky_helper_cusolver(input, upper, info); } else { cholesky_helper_magma(input, upper, info); } -#endif } #else cholesky_helper_magma(input, upper, info); @@ -1380,9 +1384,11 @@ This is an in-place routine, content of 'input' is overwritten. MAGMA requires 'infos' to reside in CPU memory. For more information see MAGMA's documentation for POTRS routine. */ -#if AT_MAGMA_ENABLED() template static void apply_cholesky_inverse(Tensor& input, Tensor& infos, bool upper) { +#if !AT_MAGMA_ENABLED() + TORCH_CHECK(false, "cholesky_inverse: MAGMA library not found in compilation. Please rebuild with MAGMA."); +#else // magmaCholeskyInverse (magma_dpotri_gpu) is slow because internally // it transfers data several times between GPU and CPU and calls lapack routine on CPU // using magmaCholeskySolveBatched is a lot faster @@ -1412,6 +1418,7 @@ static void apply_cholesky_inverse(Tensor& input, Tensor& infos, bool upper) { int64_t info_tmp = 0; apply_cholesky_solve(result_u, input_u, upper, info_tmp); infos.fill_(info_tmp); +#endif } // This is a type dispatching helper function for 'apply_cholesky_inverse' @@ -1421,7 +1428,6 @@ Tensor& cholesky_inverse_kernel_impl_magma(Tensor &result, Tensor& infos, bool u }); return result; } -#endif Tensor& cholesky_inverse_kernel_impl(Tensor &result, Tensor& infos, bool upper) { // This function calculates the inverse matrix in-place @@ -1432,25 +1438,20 @@ Tensor& cholesky_inverse_kernel_impl(Tensor &result, Tensor& infos, bool upper) switch (preferred_backend) { case at::LinalgBackend::Cusolver: return cholesky_inverse_kernel_impl_cusolver(result, infos, upper); -#if AT_MAGMA_ENABLED() case at::LinalgBackend::Magma: return cholesky_inverse_kernel_impl_magma(result, infos, upper); -#endif default: -#if !AT_MAGMA_ENABLED() - return cholesky_inverse_kernel_impl_cusolver(result, infos, upper); -#else - if (batchCount(result) == 1) { + if (batchCount(result) == 1 || + !use_magma_) { return cholesky_inverse_kernel_impl_cusolver(result, infos, upper); } else { return cholesky_inverse_kernel_impl_magma(result, infos, upper); } - -#endif } #else return cholesky_inverse_kernel_impl_magma(result, infos, upper); #endif + } REGISTER_CUDA_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl); @@ -1525,9 +1526,14 @@ static void apply_lu_factor_looped_magma(const Tensor& input, const Tensor& pivo For further details, please see the MAGMA documentation for magma_dgetrf_batched. */ -#if AT_MAGMA_ENABLED() template static void apply_lu_factor_batched_magma(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) { +#if !AT_MAGMA_ENABLED() + TORCH_CHECK( + false, + "Calling linalg.lu_factor on a CUDA tensor requires compiling ", + "PyTorch with MAGMA. Please rebuild with MAGMA."); +#else // There is a bug in lu_factor_batched_magma in MAGMA < 2.5.2, see // https://bitbucket.org/icl/magma/issues/13/getrf_batched-kernel-produces-nans-on std::tuple version; @@ -1544,7 +1550,7 @@ static void apply_lu_factor_batched_magma(const Tensor& input, const Tensor& piv magma_int_t n = magma_int_cast(input.size(-1), "n"); auto leading_dimension = std::max(1, m); - scalar_t** input_array = nullptr; + scalar_t** input_array; ALLOCATE_ARRAY(input_array, scalar_t*, batch_size); // Set up array of pointers to matrices @@ -1564,7 +1570,7 @@ static void apply_lu_factor_batched_magma(const Tensor& input, const Tensor& piv // magmaLuBatched might not set the values for it // see https://github.com/pytorch/pytorch/pull/53064 pivots.fill_(1); - magma_int_t** pivots_array = nullptr; + magma_int_t** pivots_array; ALLOCATE_ARRAY(pivots_array, magma_int_t*, batch_size); for (int64_t i = 0; i < batch_size; i++) { pivots_array[i] = &pivots_data[i * pivots_stride]; @@ -1577,6 +1583,7 @@ static void apply_lu_factor_batched_magma(const Tensor& input, const Tensor& piv // block CPU until all operations on the queue are finished // this explicit sync prevents garbage results from the subsequent magmaLuSolveBatched call from a different queue magma_queue_sync(magma_queue.get_queue()); +#endif } static void lu_factor_looped_magma(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) { @@ -1590,7 +1597,6 @@ static void lu_factor_batched_magma(const Tensor& input, const Tensor& pivots, c apply_lu_factor_batched_magma(input, pivots, infos, compute_pivots); }); } -#endif static void lu_factor(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) { auto batch_size = batchCount(input); @@ -1598,7 +1604,6 @@ static void lu_factor(const Tensor& input, const Tensor& pivots, const Tensor& i auto m = input.size(-2); auto n = input.size(-1); -#if AT_MAGMA_ENABLED() const auto lu_factor_magma = [batch_size](const Tensor& input, const Tensor& pivots, const Tensor& infos, const bool compute_pivots) { if (batch_size == 1) { lu_factor_looped_magma(input, pivots, infos, compute_pivots); @@ -1606,7 +1611,6 @@ static void lu_factor(const Tensor& input, const Tensor& pivots, const Tensor& i lu_factor_batched_magma(input, pivots, infos, compute_pivots); } }; -#endif const auto preferred_backend = at::globalContext().linalgPreferredBackend(); #ifdef USE_LINALG_SOLVER @@ -1631,12 +1635,9 @@ static void lu_factor(const Tensor& input, const Tensor& pivots, const Tensor& i lu_factor_cusolver(input, pivots, infos, compute_pivots); } else #endif // ifdef USE_LINALG_SOLVER -#if AT_MAGMA_ENABLED() if (preferred_backend == at::LinalgBackend::Magma) { lu_factor_magma(input, pivots, infos, compute_pivots); - } else -#endif - { // preferred backend == default + } else { // preferred backend == default #ifdef USE_LINALG_SOLVER #if AT_MAGMA_ENABLED() // If magma batched is buggy, we use cusolver @@ -1700,8 +1701,8 @@ AT_ERROR("triangular_solve: MAGMA library not found in " auto A_mat_stride = matrixStride(A); auto b_mat_stride = matrixStride(b); - scalar_t** A_array = nullptr; - scalar_t** b_array = nullptr; + scalar_t** A_array; + scalar_t** b_array; ALLOCATE_ARRAY(A_array, scalar_t*, batch_size); ALLOCATE_ARRAY(b_array, scalar_t*, batch_size); @@ -1719,7 +1720,7 @@ AT_ERROR("triangular_solve: MAGMA library not found in " // The number of "mini"-batches are floor(batch_size / batch_limit) // and these cover floor(batch_size / batch_limit) * batch_limit matrix solves int64_t mini_batches = batch_size / batch_limit; - int64_t mini_idx = 0; // this is outside the loop because it is used for the case batch_size % batch_limit != 0 + int64_t mini_idx; // this is outside the loop because it is used for the case batch_size % batch_limit != 0 for (mini_idx = 0; mini_idx < mini_batches * batch_limit; mini_idx += batch_limit) { scalar_t** A_array_cur = &A_array[mini_idx]; scalar_t** b_array_cur = &b_array[mini_idx]; @@ -1776,7 +1777,7 @@ Tensor& orgqr_kernel_impl(Tensor& result, const Tensor& tau) { #ifdef USE_LINALG_SOLVER return orgqr_helper_cusolver(result, tau); // cusolver #else - static_assert(false, "Calling torch.orgqr on a CUDA tensor requires compiling ", + TORCH_CHECK(false, "Calling torch.orgqr on a CUDA tensor requires compiling ", "PyTorch with cuSOLVER. Please use PyTorch built with cuSOLVER support."); #endif } @@ -1787,8 +1788,8 @@ void ormqr_kernel(const Tensor& input, const Tensor& tau, const Tensor& other, b #ifdef USE_LINALG_SOLVER ormqr_cusolver(input, tau, other, left, transpose); #else - static_assert(false, - "Calling torch.ormqr on a CUDA tensor requires compiling " + TORCH_CHECK(false, + "Calling torch.ormqr on a CUDA tensor requires compiling ", "PyTorch with cuSOLVER. Please use PyTorch built with cuSOLVER support."); #endif } @@ -1797,9 +1798,15 @@ REGISTER_CUDA_DISPATCH(ormqr_stub, &ormqr_kernel); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ qr ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -#if AT_MAGMA_ENABLED() template static void apply_geqrf(const Tensor& input, const Tensor& tau) { +#if !AT_MAGMA_ENABLED() + TORCH_CHECK( + false, + "Calling torch.geqrf on a CUDA tensor requires compiling ", + "PyTorch with MAGMA. Please use PyTorch built with MAGMA support."); +#else + magma_int_t m = magma_int_cast(input.size(-2), "m"); magma_int_t n = magma_int_cast(input.size(-1), "n"); @@ -1826,6 +1833,7 @@ static void apply_geqrf(const Tensor& input, const Tensor& tau) { checkMagmaInternalError(info, "geqrf"); } tau.copy_(tau_cpu, /*non_blocking=*/true); +#endif } // This is a type dispatching helper function for 'apply_geqrf' @@ -1834,7 +1842,6 @@ void geqrf_magma(const Tensor& input, const Tensor& tau) { apply_geqrf(input, tau); }); } -#endif void geqrf_kernel(const Tensor& input, const Tensor& tau) { #ifdef USE_LINALG_SOLVER @@ -1860,10 +1867,8 @@ void geqrf_kernel(const Tensor& input, const Tensor& tau) { // - ?geqrf2_gpu gives correct R, but doesn't allow computation of Q via ?orgqr_gpu // Refer to the below link for more details: // http://icl.cs.utk.edu/magma/forum/viewtopic.php?f=2&t=1015&p=2800&hilit=geqrf_gpu#p2800 -#if AT_MAGMA_ENABLED() case at::LinalgBackend::Magma: return geqrf_magma(input, tau); -#endif case at::LinalgBackend::Cusolver: default: return geqrf_cusolver_backend(input, tau); @@ -1875,9 +1880,14 @@ void geqrf_kernel(const Tensor& input, const Tensor& tau) { REGISTER_CUDA_DISPATCH(geqrf_stub, &geqrf_kernel); -#if AT_MAGMA_ENABLED() template static void apply_magma_eigh(const Tensor& values, const Tensor& vectors, const Tensor& infos, bool upper, bool compute_eigenvectors) { +#if !AT_MAGMA_ENABLED() + TORCH_CHECK( + false, + "Calling torch.linalg.eigh/eigvalsh on a CUDA tensor requires compiling ", + "PyTorch with MAGMA. Please use PyTorch built with MAGMA support."); +#else TORCH_INTERNAL_ASSERT_DEBUG_ONLY(values.device() == kCPU); TORCH_INTERNAL_ASSERT_DEBUG_ONLY(infos.device() == kCPU); @@ -1897,7 +1907,7 @@ static void apply_magma_eigh(const Tensor& values, const Tensor& vectors, const auto values_data = values.data_ptr(); auto infos_data = infos.data_ptr(); - scalar_t* wA = nullptr; + scalar_t* wA; ALLOCATE_ARRAY(wA, scalar_t, lda * lda); // Run once, first to get the optimum work sizes. @@ -1907,14 +1917,14 @@ static void apply_magma_eigh(const Tensor& values, const Tensor& vectors, const magma_int_t lwork = -1; scalar_t wkopt; magma_int_t liwork = -1; - magma_int_t iwkopt = -1; + magma_int_t iwkopt; magma_int_t lrwork = -1; value_t rwkopt; magmaSyevd(jobz, uplo, n, vectors_data, lda, values_data, wA, lda, &wkopt, lwork, &rwkopt, lrwork, &iwkopt, liwork, infos_data); - scalar_t* work = nullptr; - magma_int_t* iwork = nullptr; + scalar_t* work; + magma_int_t* iwork; lwork = magma_int_cast(std::max(1, real_impl(wkopt)), "work_size"); liwork = magma_int_cast(std::max(1, iwkopt), "iwork_size"); ALLOCATE_ARRAY(work, scalar_t, lwork); @@ -1941,6 +1951,7 @@ static void apply_magma_eigh(const Tensor& values, const Tensor& vectors, const return; } } +#endif } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg_eigh ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -1979,17 +1990,14 @@ void linalg_eigh_magma(const Tensor& eigenvalues, const Tensor& eigenvectors, co eigenvalues.copy_(eigenvalues_cpu); } } -#endif void linalg_eigh_kernel(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors) { #if defined(USE_LINALG_SOLVER) auto preferred_backend = at::globalContext().linalgPreferredBackend(); switch (preferred_backend) { -#if AT_MAGMA_ENABLED() case at::LinalgBackend::Magma: linalg_eigh_magma(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors); break; -#endif case at::LinalgBackend::Cusolver: default: linalg_eigh_cusolver(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors); @@ -2009,9 +2017,12 @@ This is an in-place routine, content of 'input', 'values', 'vectors' is overwrit 'infos' is an int Tensor containing error codes for each matrix in the batched input. For more information see MAGMA's documentation for GEEV routine. */ -#if AT_MAGMA_ENABLED() template void apply_linalg_eig(Tensor& values, Tensor& vectors, Tensor& input, Tensor& infos, bool compute_eigenvectors) { +#if !AT_MAGMA_ENABLED() +TORCH_CHECK(false, "Calling torch.linalg.eig on a CUDA tensor requires compiling PyTorch with MAGMA. " + "Either transfer the tensor to the CPU before calling torch.linalg.eig or recompile with MAGMA."); +#else TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.device() == at::kCPU); TORCH_INTERNAL_ASSERT_DEBUG_ONLY(values.device() == at::kCPU); TORCH_INTERNAL_ASSERT_DEBUG_ONLY(infos.device() == at::kCPU); @@ -2061,6 +2072,7 @@ void apply_linalg_eig(Tensor& values, Tensor& vectors, Tensor& input, Tensor& in magmaEig(jobvl, jobvr, n, input_working_ptr, lda, values_working_ptr, lvectors_data, ldvl, rvectors_working_ptr, ldvr, work_data, lwork, rwork_data, info_working_ptr); } +#endif } // This is a type dispatching helper function for 'apply_linalg_eig' @@ -2093,6 +2105,10 @@ static void apply_svd_magma(const Tensor& A, const Tensor& S, const Tensor& Vh, const Tensor& info) { +#if !AT_MAGMA_ENABLED() +AT_ERROR("linalg.svd: MAGMA library not found in " + "compilation. Please rebuild with MAGMA."); +#else using value_t = typename c10::scalar_value_type::type; const auto A_data = A.data_ptr(); const auto U_data = compute_uv ? U.data_ptr() : nullptr; @@ -2120,7 +2136,7 @@ static void apply_svd_magma(const Tensor& A, rwork = static_cast(storage_rwork.mutable_data()); } - magma_int_t* iwork = nullptr; + magma_int_t* iwork; ALLOCATE_ARRAY(iwork, magma_int_t, 8 * std::min(m, n)); // Query svd for the optimal lwork size @@ -2135,7 +2151,7 @@ static void apply_svd_magma(const Tensor& A, &wkopt, lwork, rwork, iwork, info_data); lwork = magma_int_cast(real_impl(wkopt), "work_size"); } - scalar_t* work = nullptr; + scalar_t* work; ALLOCATE_ARRAY(work, scalar_t, lwork); for (int64_t i = 0; i < batchsize; i++) { @@ -2148,6 +2164,7 @@ static void apply_svd_magma(const Tensor& A, work, lwork, rwork, iwork, info_data + i); } +#endif } void svd_magma(const Tensor& A, @@ -2189,7 +2206,6 @@ void svd_magma(const Tensor& A, S.copy_(S_, /*non_blocking*/true); info.copy_(info, /*non_blocking*/true); } -#endif void svd_kernel(const Tensor& A, const bool full_matrices, @@ -2201,13 +2217,10 @@ void svd_kernel(const Tensor& A, const Tensor& info) { #ifdef USE_LINALG_SOLVER // We always use cuSOLVER unless the user has specified they want to use MAGMA -#if AT_MAGMA_ENABLED() bool use_magma = at::globalContext().linalgPreferredBackend() == at::LinalgBackend::Magma; if (use_magma) { svd_magma(A, full_matrices, compute_uv, U, S, Vh, info); - } else -#endif - { + } else { // svd_cusolver computes V rather than Vh, so we pass a view of Vh.mT // and then conjugate Vh in-place svd_cusolver(A, full_matrices, compute_uv, driver, U, S, compute_uv ? Vh.mT() : Vh, info); @@ -2238,9 +2251,14 @@ REGISTER_CUDA_DISPATCH(svd_stub, &svd_kernel) For further details, please see the MAGMA documentation for magma_dgetrs_gpu. */ -#if AT_MAGMA_ENABLED() template static void apply_lu_solve_looped_magma(const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType transpose) { +#if !AT_MAGMA_ENABLED() + TORCH_CHECK( + false, + "Calling linalg.lu_solve on a CUDA tensor requires compiling ", + "PyTorch with MAGMA. Please rebuild with MAGMA."); +#else auto trans = to_magma(transpose); auto b_data = B.data_ptr(); auto lu_data = LU.data_ptr(); @@ -2278,6 +2296,7 @@ static void apply_lu_solve_looped_magma(const Tensor& LU, const Tensor& pivots, // so we don't need to check it all the time TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info == 0); } +#endif } /* @@ -2296,6 +2315,12 @@ static void apply_lu_solve_looped_magma(const Tensor& LU, const Tensor& pivots, */ template static void apply_lu_solve_batched_magma(const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType transpose) { +#if !AT_MAGMA_ENABLED() + TORCH_CHECK( + false, + "Calling linalg.lu_solve on a CUDA tensor requires compiling ", + "PyTorch with MAGMA. Please rebuild with MAGMA."); +#else TORCH_INTERNAL_ASSERT(batchCount(B) == batchCount(LU), "batch_size of LU and B must be the same"); TORCH_INTERNAL_ASSERT(batchCount(LU) == batchCount(pivots.unsqueeze(-1)), "batch_size of LU and pivots must be the same"); auto trans = to_magma(transpose); @@ -2313,9 +2338,9 @@ static void apply_lu_solve_batched_magma(const Tensor& LU, const Tensor& pivots, auto pivots_stride = pivots.size(-1); magma_int_t batch_size = magma_int_cast(batchCount(B), "batchCount"); - magma_int_t** pivots_array = nullptr; - scalar_t** lu_array = nullptr; - scalar_t** b_array = nullptr; + magma_int_t** pivots_array; + scalar_t** lu_array; + scalar_t** b_array; ALLOCATE_ARRAY(pivots_array, magma_int_t*, batch_size); ALLOCATE_ARRAY(lu_array, scalar_t*, batch_size); @@ -2339,7 +2364,7 @@ static void apply_lu_solve_batched_magma(const Tensor& LU, const Tensor& pivots, scalar_t** b_array_cur = &b_array[mini_idx]; magma_int_t** pivots_array_cur = &pivots_array[mini_idx]; - int info = -1; + int info; magmaLuSolveBatched( n, nrhs, lu_array_cur, leading_dimension, pivots_array_cur, b_array_cur, leading_dimension, @@ -2349,6 +2374,7 @@ static void apply_lu_solve_batched_magma(const Tensor& LU, const Tensor& pivots, // so we don't need to check it all the time TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info == 0); } +#endif } static void lu_solve_batched_magma(const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType trans) { @@ -2364,7 +2390,6 @@ static void lu_solve_looped_magma(const Tensor& LU, const Tensor& pivots, const apply_lu_solve_looped_magma(LU, pivots, B, trans); }); } -#endif c10::MaybeOwned maybe_expand_lu(const Tensor& B, const Tensor& LU) { // B and LU have the same number of dimensions @@ -2399,11 +2424,9 @@ static void lu_solve_kernel(const Tensor& LU, const Tensor& pivots, const Tensor auto b = batchCount(B); auto n = LU.size(-2); auto k = B.size(-1); -#if AT_MAGMA_ENABLED() // magma implementation of LU solve cannot handle a b tensor with last dim > 1024 // See https://bitbucket.org/icl/magma/issues/19/dgesv_batched-dgetrs_batched-fails-for bool over_batched_magma_dim_limit = k > 1024; -#endif // heuristics determined from tests discussed in https://github.com/pytorch/pytorch/pull/72935 // Computes X = U^{-1}L^{-1}P^T B via triangular solves @@ -2418,7 +2441,7 @@ static void lu_solve_kernel(const Tensor& LU, const Tensor& pivots, const Tensor .set_check_mem_overlap(false) .check_all_same_dtype(false) .resize_outputs(false) - .declare_static_shape(pivots_->sizes(), /*squash_dims=*/pivots_->dim() - 1) + .declare_static_shape(pivots_->sizes(), /*squash_dim=*/pivots_->dim() - 1) .add_output(perm) .add_const_input(*pivots_) .build(); @@ -2434,7 +2457,7 @@ static void lu_solve_kernel(const Tensor& LU, const Tensor& pivots, const Tensor // B1 = P^T @ B (must be done out-of-place as B is both source and target) auto B1 = B.scatter(-2, inv_perm.unsqueeze(-1).expand_as(B), B); // B = L^{-1} @ B1 - at::linalg_solve_triangular_out(const_cast(B), *LU_, B1, /*upper=*/false, /*left=*/true, /*unitriangular=*/true); + at::linalg_solve_triangular_out(const_cast(B), *LU_, std::move(B1), /*upper=*/false, /*left=*/true, /*unitriangular=*/true); // B = U^{-1} @ B at::linalg_solve_triangular_out(const_cast(B), *LU_, B, /*upper=*/true); } else { @@ -2456,13 +2479,11 @@ static void lu_solve_kernel(const Tensor& LU, const Tensor& pivots, const Tensor }; #endif -#if AT_MAGMA_ENABLED() auto lu_solve_batched_magma_fn = [](const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType trans) { auto LU_ = maybe_expand_lu(B, LU); auto pivots_ = maybe_expand_pivots(B, pivots); lu_solve_batched_magma(*LU_, *pivots_, B, trans); }; -#endif // Preferred Backend @@ -2477,7 +2498,6 @@ static void lu_solve_kernel(const Tensor& LU, const Tensor& pivots, const Tensor return; } else #endif // ifdef USE_LINALG_SOLVER -#if AT_MAGMA_ENABLED() if (preferred_backend == at::LinalgBackend::Magma) { // Looped magma is very slow, but batched magma is buggy in these two cases if (!over_batched_magma_dim_limit && trans == TransposeType::NoTranspose) { @@ -2488,7 +2508,6 @@ static void lu_solve_kernel(const Tensor& LU, const Tensor& pivots, const Tensor } return; } -#endif // Heuristic //if (n == k) { @@ -2529,12 +2548,9 @@ static void lu_solve_kernel(const Tensor& LU, const Tensor& pivots, const Tensor } if (n <= 8) { -#if AT_MAGMA_ENABLED() - if (!over_batched_magma_dim_limit && trans == TransposeType::NoTranspose && k >= 256) { + if (use_magma_ && !over_batched_magma_dim_limit && trans == TransposeType::NoTranspose && k >= 256) { lu_solve_batched_magma_fn(LU, pivots, B, trans); - } else -#endif - { + } else { lu_solve_batched_cublas_fn(LU, pivots, B, trans); } } else if (n <= 64) { @@ -2567,9 +2583,12 @@ REGISTER_CUDA_DISPATCH(lu_solve_stub, &lu_solve_kernel); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lstsq ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -#if AT_MAGMA_ENABLED() template static void apply_gels(const Tensor& a, Tensor& b, Tensor& infos) { +#if !AT_MAGMA_ENABLED() + TORCH_CHECK(false, "torch.linalg.lstsq: MAGMA library not found in " + "compilation. Please rebuild with MAGMA."); +#else auto trans = MagmaNoTrans; auto m = magma_int_cast(a.size(-2), "m"); auto n = magma_int_cast(a.size(-1), "n"); @@ -2599,6 +2618,7 @@ static void apply_gels(const Tensor& a, Tensor& b, Tensor& infos) { hwork_ptr, lwork, infos_working_ptr); } ); +#endif } void gels_magma(const Tensor& a, Tensor& b, Tensor& infos) { @@ -2606,7 +2626,6 @@ void gels_magma(const Tensor& a, Tensor& b, Tensor& infos) { apply_gels(a, b, infos); }); } -#endif void linalg_lstsq_gels(const Tensor& A, const Tensor& B, const Tensor& /*infos*/) { // The steps for using the QR decomposition for solving least squares problems @@ -2695,10 +2714,8 @@ void gels_looped(const Tensor& a, Tensor& b, Tensor& infos) { #if defined(USE_LINALG_SOLVER) && !defined(USE_ROCM) auto preferred_backend = at::globalContext().linalgPreferredBackend(); switch (preferred_backend) { -#if AT_MAGMA_ENABLED() case at::LinalgBackend::Magma: return gels_magma(a, b, infos); -#endif case at::LinalgBackend::Cusolver: default: // linalg_lstsq_gels is a generic function that is implemented using From d49dc8f4b8f57d14bc2e148e69f6168272c696a5 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sat, 1 Jun 2024 17:12:46 +0000 Subject: [PATCH 213/706] Revert "Add noqa to prevent lint warnings (#127545)" This reverts commit f9937afd4f87fbb4844642ae2f587b13b5caa08c. Reverted https://github.com/pytorch/pytorch/pull/127545 on behalf of https://github.com/izaitsevfb due to reverting to unblock the revert of #127545 ([comment](https://github.com/pytorch/pytorch/pull/127545#issuecomment-2143517711)) --- test/inductor/test_halide.py | 2 +- test/inductor/test_kernel_benchmark.py | 2 +- test/inductor/test_triton_wrapper.py | 2 +- torch/_inductor/autotune_process.py | 2 +- torch/_inductor/codegen/cpp_wrapper_cpu.py | 2 +- torch/_inductor/compile_fx.py | 2 +- torch/_inductor/scheduler.py | 2 +- torch/_inductor/select_algorithm.py | 2 +- torch/testing/_internal/inductor_utils.py | 2 +- 9 files changed, 9 insertions(+), 9 deletions(-) diff --git a/test/inductor/test_halide.py b/test/inductor/test_halide.py index 9b923bd1981d..158a669dad2f 100644 --- a/test/inductor/test_halide.py +++ b/test/inductor/test_halide.py @@ -3,7 +3,7 @@ import unittest import torch -import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +import torch._inductor.async_compile from torch._inductor.codecache import HalideCodeCache from torch._inductor.runtime.hints import HalideInputSpec, HalideMeta from torch._inductor.test_case import run_tests, TestCase diff --git a/test/inductor/test_kernel_benchmark.py b/test/inductor/test_kernel_benchmark.py index ffe0300d8aad..87ddb0bec2e6 100644 --- a/test/inductor/test_kernel_benchmark.py +++ b/test/inductor/test_kernel_benchmark.py @@ -6,7 +6,7 @@ from unittest.mock import patch import torch -import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +import torch._inductor.async_compile from torch._dynamo.testing import rand_strided from torch._inductor import config from torch._inductor.codecache import PyCodeCache diff --git a/test/inductor/test_triton_wrapper.py b/test/inductor/test_triton_wrapper.py index f0d3ad829d45..7f7ded46182a 100644 --- a/test/inductor/test_triton_wrapper.py +++ b/test/inductor/test_triton_wrapper.py @@ -4,7 +4,7 @@ import sys import torch -import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +import torch._inductor.async_compile from torch._inductor.codecache import PyCodeCache from torch._inductor.test_case import run_tests, TestCase from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU diff --git a/torch/_inductor/autotune_process.py b/torch/_inductor/autotune_process.py index c9462d788e8d..5e04211639d6 100644 --- a/torch/_inductor/autotune_process.py +++ b/torch/_inductor/autotune_process.py @@ -25,7 +25,7 @@ ) import torch -import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +import torch._inductor.async_compile from torch import multiprocessing from torch._dynamo.testing import rand_strided diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index c24fa33676d4..2a436e8b5858 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -10,7 +10,7 @@ import torch -import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +import torch._inductor.async_compile import torch._ops from torch.fx.experimental.symbolic_shapes import ConvertIntKey, DivideByKey from .. import config, ir diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index c9a253709f40..26d75669a206 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -11,7 +11,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union from unittest import mock -import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +import torch._inductor.async_compile import torch.fx import torch.utils._pytree as pytree diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index a7517575d888..5a13c7f3cae4 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -28,7 +28,7 @@ import sympy import torch -import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +import torch._inductor.async_compile from torch._dynamo.utils import counters, dynamo_timed from torch._inductor.metrics import get_metric_table, is_metric_table_enabled from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 3ba0ff0d949b..90b17de519bc 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -22,7 +22,7 @@ from filelock import FileLock import torch -import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +import torch._inductor.async_compile from torch._dynamo.testing import rand_strided from torch._dynamo.utils import counters, identity, preserve_rng_state diff --git a/torch/testing/_internal/inductor_utils.py b/torch/testing/_internal/inductor_utils.py index 1078a189f69c..d441988d4bd2 100644 --- a/torch/testing/_internal/inductor_utils.py +++ b/torch/testing/_internal/inductor_utils.py @@ -5,7 +5,7 @@ import unittest import functools from subprocess import CalledProcessError -import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +import torch._inductor.async_compile from torch._inductor.codecache import CppCodeCache from torch.utils._triton import has_triton from torch.testing._internal.common_utils import ( From 22f392ba408f71a15af46847cddc6e1cf9ce86fb Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sat, 1 Jun 2024 17:16:16 +0000 Subject: [PATCH 214/706] Revert "[easy?] Move AsyncCompile to a different file (#127235)" This reverts commit f58fc16e8f059232f452a333f32e14ff681e12af. Reverted https://github.com/pytorch/pytorch/pull/127235 on behalf of https://github.com/izaitsevfb due to breaking internal tests, see [D58015187](https://www.internalfb.com/diff/D58015187) ([comment](https://github.com/pytorch/pytorch/pull/127235#issuecomment-2143518610)) --- test/inductor/test_codecache.py | 2 +- test/inductor/test_cudacodecache.py | 3 +- test/inductor/test_halide.py | 1 - test/inductor/test_kernel_benchmark.py | 1 - test/inductor/test_triton_wrapper.py | 1 - torch/_inductor/async_compile.py | 239 ------------------- torch/_inductor/autotune_process.py | 1 - torch/_inductor/codecache.py | 207 +++++++++++++++- torch/_inductor/codegen/cpp_wrapper_cpu.py | 3 +- torch/_inductor/codegen/wrapper.py | 4 +- torch/_inductor/compile_fx.py | 2 - torch/_inductor/compile_worker/__main__.py | 2 +- torch/_inductor/runtime/triton_heuristics.py | 1 + torch/_inductor/scheduler.py | 1 - torch/_inductor/select_algorithm.py | 1 - torch/testing/_internal/inductor_utils.py | 2 +- 16 files changed, 214 insertions(+), 257 deletions(-) delete mode 100644 torch/_inductor/async_compile.py diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py index a96d9aa67e82..d5ae4c4f03ca 100644 --- a/test/inductor/test_codecache.py +++ b/test/inductor/test_codecache.py @@ -12,8 +12,8 @@ from torch._dynamo import reset from torch._dynamo.utils import counters from torch._inductor import config, metrics -from torch._inductor.async_compile import AsyncCompile from torch._inductor.codecache import ( + AsyncCompile, cuda_compile_command, CUDACodeCache, FxGraphCachePickler, diff --git a/test/inductor/test_cudacodecache.py b/test/inductor/test_cudacodecache.py index ac26f6a6656c..33a179a9abc7 100644 --- a/test/inductor/test_cudacodecache.py +++ b/test/inductor/test_cudacodecache.py @@ -6,8 +6,7 @@ import torch from torch._inductor import config -from torch._inductor.async_compile import AsyncCompile -from torch._inductor.codecache import CUDACodeCache +from torch._inductor.codecache import AsyncCompile, CUDACodeCache from torch._inductor.codegen.cuda.cuda_env import nvcc_exist from torch._inductor.exc import CUDACompileError from torch._inductor.test_case import TestCase as InductorTestCase diff --git a/test/inductor/test_halide.py b/test/inductor/test_halide.py index 158a669dad2f..52227c20d1ff 100644 --- a/test/inductor/test_halide.py +++ b/test/inductor/test_halide.py @@ -3,7 +3,6 @@ import unittest import torch -import torch._inductor.async_compile from torch._inductor.codecache import HalideCodeCache from torch._inductor.runtime.hints import HalideInputSpec, HalideMeta from torch._inductor.test_case import run_tests, TestCase diff --git a/test/inductor/test_kernel_benchmark.py b/test/inductor/test_kernel_benchmark.py index 87ddb0bec2e6..23804e08f23f 100644 --- a/test/inductor/test_kernel_benchmark.py +++ b/test/inductor/test_kernel_benchmark.py @@ -6,7 +6,6 @@ from unittest.mock import patch import torch -import torch._inductor.async_compile from torch._dynamo.testing import rand_strided from torch._inductor import config from torch._inductor.codecache import PyCodeCache diff --git a/test/inductor/test_triton_wrapper.py b/test/inductor/test_triton_wrapper.py index 7f7ded46182a..24ba84ebf86a 100644 --- a/test/inductor/test_triton_wrapper.py +++ b/test/inductor/test_triton_wrapper.py @@ -4,7 +4,6 @@ import sys import torch -import torch._inductor.async_compile from torch._inductor.codecache import PyCodeCache from torch._inductor.test_case import run_tests, TestCase from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU diff --git a/torch/_inductor/async_compile.py b/torch/_inductor/async_compile.py deleted file mode 100644 index c163df9bd878..000000000000 --- a/torch/_inductor/async_compile.py +++ /dev/null @@ -1,239 +0,0 @@ -from __future__ import annotations - -import functools -import logging -import multiprocessing -import os -import sys -from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor -from functools import partial -from time import time -from typing import Any, Callable, Dict, List, Optional, Set - -import torch -from torch._dynamo.device_interface import get_registered_device_interfaces -from torch._inductor import config -from torch._inductor.codecache import ( - CodeCacheFuture, - CppCodeCache, - CppPythonBindingsCodeCache, - CUDACodeCache, - HalideCodeCache, - LambdaFuture, - TritonCodeCache, - TritonFuture, -) -from torch._inductor.compile_worker.subproc_pool import ( - _warm_process_pool, - AnyPool, - SubprocPool, -) -from torch._inductor.compile_worker.watchdog import _async_compile_initializer - -from torch._inductor.runtime.compile_tasks import ( - _set_triton_ptxas_path, - _worker_compile_triton, -) -from torch._inductor.runtime.hints import HalideMeta - -from torch.hub import _Faketqdm, tqdm - -# timing metrics for time spent in the compilation -_cumulative_compile_time = 0.0 -_t0: Optional[float] = None - -kernel_code_log = torch._logging.getArtifactLogger(__name__, "kernel_code") - - -def caching_device_properties(): - for _, device_interface in get_registered_device_interfaces(): - if device_interface.is_available(): - device_interface.Worker.get_device_properties() - - -def _compile_start() -> None: - global _t0 - if _t0 is None: - _t0 = time() - - -def _compile_end() -> None: - global _cumulative_compile_time, _t0 - if _t0 is not None: - t1 = time() - _cumulative_compile_time += t1 - _t0 - _t0 = None - # print("CUMULATIVE COMPILE TIME", _cumulative_compile_time) - - -_IS_WINDOWS = sys.platform == "win32" - -log = logging.getLogger(__name__) - - -# Used to keep track of all process pools invoked so far. -_pool_set: Set[AnyPool] = set() - - -def shutdown_compile_workers() -> None: - """Shut down all outstanding compile-worker pools.""" - for pool in _pool_set: - pool.shutdown() - after_fork() - - -def after_fork(): - """Reset pools to initial state without shutting them down""" - _pool_set.clear() - AsyncCompile.process_pool.cache_clear() - - -try: - os.register_at_fork(after_in_child=after_fork) -except AttributeError: - pass # register_at_fork does not exists on windows - - -class AsyncCompile: - def __init__(self) -> None: - pass - - @staticmethod - @functools.lru_cache(1) - def pool() -> ThreadPoolExecutor: - assert config.compile_threads > 1 - return ThreadPoolExecutor(config.compile_threads) - - @staticmethod - @functools.lru_cache(1) - def process_pool() -> AnyPool: - assert config.compile_threads > 1 - pool: AnyPool - if config.worker_start_method == "subprocess": - # Wrapper around ProcessPoolExecutor forks in a new process we control - pool = SubprocPool(config.compile_threads) - else: - # ensure properties have been calculated before processes - # are forked - caching_device_properties() - ctx = multiprocessing.get_context(config.worker_start_method) - pool = ProcessPoolExecutor( - config.compile_threads, - mp_context=ctx, - initializer=partial(_async_compile_initializer, os.getpid()), - ) - # when this pool is created in a subprocess object, the normal exit handler - # doesn't run, and we need to register our own handler. - # exitpriority has to be high, because another one of the finalizers will - # kill the worker thread that sends the shutdown message to the workers... - multiprocessing.util.Finalize(None, pool.shutdown, exitpriority=sys.maxsize) - - _pool_set.add(pool) - return pool - - @classmethod - def warm_pool(cls) -> None: - if config.compile_threads <= 1: - return - _compile_start() - _warm_process_pool(cls.process_pool(), config.compile_threads) - _compile_end() - - @classmethod - def submit(cls, task: Callable[..., Any]) -> Any: - if config.compile_threads <= 1: - return task() - return cls.pool().submit(task) - - def triton(self, kernel_name: str, source_code: str, device_str: str = "cuda"): - kernel_code_log.info("Triton Kernel:\n%s", source_code) - _compile_start() - _set_triton_ptxas_path() - - kernel = TritonCodeCache.load(kernel_name, source_code) - if config.compile_threads > 1: - return TritonFuture( - kernel, - self.process_pool().submit( - _worker_compile_triton, - kernel._reload_in_subproc, - ), - ) - else: - kernel.precompile() - return kernel - - def multi_kernel(self, *args, **kwargs) -> Any: - from torch._inductor.codegen.multi_kernel import MultiKernelCall - - # no need to call this in parallel since the sub-kernels are already parallel tasks - return MultiKernelCall(*args, **kwargs) - - def cpp(self, source_code: str): - kernel_code_log.info("CPP Kernel:\n%s", source_code) - if config.compile_threads <= 1: - return CppCodeCache.load(source_code).kernel - else: - get_result = CppCodeCache.load_async(source_code, submit_fn=self.submit) - return LambdaFuture(lambda: get_result().kernel) - - def cpp_pybinding(self, argtypes: List[str], source_code: str): - kernel_code_log.info("CPP+Bindings Kernel:\n%s", source_code) - if config.compile_threads <= 1: - return CppPythonBindingsCodeCache.load_pybinding(argtypes, source_code) - else: - get_result = CppPythonBindingsCodeCache.load_pybinding_async( - argtypes, source_code, submit_fn=self.submit - ) - return LambdaFuture(get_result) - - def cuda(self, source_code, dst_file_ext): - kernel_code_log.info("CUDA Kernel:\n%s", source_code) - - def task(): - return CUDACodeCache.load(source_code, dst_file_ext)[0] - - return self.submit(task) - - def halide(self, meta: HalideMeta, source_code: str): - kernel_code_log.info("Halide Kernel:\n%r\n%s", meta, source_code) - if config.compile_threads <= 1: - return HalideCodeCache.generate_halide(meta, source_code) - else: - get_result = HalideCodeCache.generate_halide_async( - meta, source_code, submit_fn=self.submit - ) - return LambdaFuture(get_result) - - def wait(self, scope: Dict[str, Any]) -> None: - num_kernels = len( - [ - value - for key, value in scope.items() - if isinstance(value, (Future, CodeCacheFuture)) - ] - ) - pbar = tqdm( - total=num_kernels, - desc="Inductor Compilation", - disable=config.disable_progress, - delay=0, - ) - if config.compile_threads > 1: - for key, result in scope.items(): - if config.verbose_progress and not isinstance(pbar, _Faketqdm): - pbar.set_postfix_str(key) - if isinstance(result, (Future, CodeCacheFuture)): - scope[key] = result.result() - pbar.update(1) - - _compile_end() - - -if ( - os.environ.get("TORCH_TNT_IN_USE", "0") == "1" - or os.environ.get("TORCH_WARM_POOL", "1") != "1" -): - pass -else: - AsyncCompile.warm_pool() diff --git a/torch/_inductor/autotune_process.py b/torch/_inductor/autotune_process.py index 5e04211639d6..7db4d2a3291c 100644 --- a/torch/_inductor/autotune_process.py +++ b/torch/_inductor/autotune_process.py @@ -25,7 +25,6 @@ ) import torch -import torch._inductor.async_compile from torch import multiprocessing from torch._dynamo.testing import rand_strided diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 617a7ba7e262..37375ddd4639 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -9,6 +9,7 @@ import io import json import logging +import multiprocessing import os import pickle import pkgutil @@ -25,7 +26,7 @@ import threading import warnings from bisect import bisect_right -from concurrent.futures import Future +from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor from copy import copy from ctypes import c_void_p, cdll, CDLL from functools import partial @@ -48,13 +49,22 @@ ) import torch +from torch._dynamo.device_interface import get_registered_device_interfaces from torch._dynamo.utils import counters, dynamo_timed from torch._inductor import config, exc, metrics from torch._inductor.codegen.cuda import cuda_env +from torch._inductor.compile_worker.subproc_pool import ( + _warm_process_pool, + AnyPool, + SubprocPool, +) +from torch._inductor.compile_worker.watchdog import _async_compile_initializer from torch._inductor.runtime.compile_tasks import ( _module_to_triton_kernel, _reload_python_module, _reload_python_module_in_subproc, + _set_triton_ptxas_path, + _worker_compile_triton, ) from torch._inductor.runtime.hints import HalideMeta from torch._inductor.runtime.runtime_utils import cache_dir @@ -72,6 +82,7 @@ from torch._inductor.graph import GraphLowering from torch._inductor.ir import ChoiceCaller +from torch.hub import _Faketqdm, tqdm _HERE = os.path.abspath(__file__) _TORCH_PATH = os.path.dirname(os.path.dirname(_HERE)) @@ -103,11 +114,31 @@ def use_global_cache() -> bool: output_code_log = torch._logging.getArtifactLogger(__name__, "output_code") +kernel_code_log = torch._logging.getArtifactLogger(__name__, "kernel_code") LOCK_TIMEOUT = 600 _IS_WINDOWS = sys.platform == "win32" +# timing metrics for time spent in the compilation +_cumulative_compile_time = 0.0 +_t0: Optional[float] = None + + +def _compile_start() -> None: + global _t0 + if _t0 is None: + _t0 = time() + + +def _compile_end() -> None: + global _cumulative_compile_time, _t0 + if _t0 is not None: + t1 = time() + _cumulative_compile_time += t1 - _t0 + _t0 = None + # print("CUMULATIVE COMPILE TIME", _cumulative_compile_time) + log = logging.getLogger(__name__) @@ -3179,6 +3210,12 @@ def load(cls, source_code, dst_file_ext) -> Tuple[DLLWrapper, str, str]: return (DLLWrapper(dst_file_path), hash_key, source_code_path) +def caching_device_properties(): + for _, device_interface in get_registered_device_interfaces(): + if device_interface.is_available(): + device_interface.Worker.get_device_properties() + + class CodeCacheFuture: def result(self): raise NotImplementedError @@ -3212,3 +3249,171 @@ def __init__(self, result_fn): def result(self): return self.result_fn() + + +# Used to keep track of all process pools invoked so far. +_pool_set: Set[AnyPool] = set() + + +def shutdown_compile_workers() -> None: + """Shut down all outstanding compile-worker pools.""" + for pool in _pool_set: + pool.shutdown() + after_fork() + + +def after_fork(): + """Reset pools to initial state without shutting them down""" + _pool_set.clear() + AsyncCompile.process_pool.cache_clear() + + +try: + os.register_at_fork(after_in_child=after_fork) +except AttributeError: + pass # register_at_fork does not exists on windows + + +class AsyncCompile: + def __init__(self) -> None: + pass + + @staticmethod + @functools.lru_cache(1) + def pool() -> ThreadPoolExecutor: + assert config.compile_threads > 1 + return ThreadPoolExecutor(config.compile_threads) + + @staticmethod + @functools.lru_cache(1) + def process_pool() -> AnyPool: + assert config.compile_threads > 1 + pool: AnyPool + if config.worker_start_method == "subprocess": + # Wrapper around ProcessPoolExecutor forks in a new process we control + pool = SubprocPool(config.compile_threads) + else: + # ensure properties have been calculated before processes + # are forked + caching_device_properties() + ctx = multiprocessing.get_context(config.worker_start_method) + pool = ProcessPoolExecutor( + config.compile_threads, + mp_context=ctx, + initializer=partial(_async_compile_initializer, os.getpid()), + ) + # when this pool is created in a subprocess object, the normal exit handler + # doesn't run, and we need to register our own handler. + # exitpriority has to be high, because another one of the finalizers will + # kill the worker thread that sends the shutdown message to the workers... + multiprocessing.util.Finalize(None, pool.shutdown, exitpriority=sys.maxsize) + + _pool_set.add(pool) + return pool + + @classmethod + def warm_pool(cls) -> None: + if config.compile_threads <= 1: + return + _compile_start() + _warm_process_pool(cls.process_pool(), config.compile_threads) + _compile_end() + + @classmethod + def submit(cls, task: Callable[..., Any]) -> Any: + if config.compile_threads <= 1: + return task() + return cls.pool().submit(task) + + def triton(self, kernel_name: str, source_code: str, device_str: str = "cuda"): + kernel_code_log.info("Triton Kernel:\n%s", source_code) + _compile_start() + _set_triton_ptxas_path() + + kernel = TritonCodeCache.load(kernel_name, source_code) + if config.compile_threads > 1: + return TritonFuture( + kernel, + self.process_pool().submit( + _worker_compile_triton, + kernel._reload_in_subproc, + ), + ) + else: + kernel.precompile() + return kernel + + def multi_kernel(self, *args, **kwargs) -> Any: + from torch._inductor.codegen.multi_kernel import MultiKernelCall + + # no need to call this in parallel since the sub-kernels are already parallel tasks + return MultiKernelCall(*args, **kwargs) + + def cpp(self, source_code: str): + kernel_code_log.info("CPP Kernel:\n%s", source_code) + if config.compile_threads <= 1: + return CppCodeCache.load(source_code).kernel + else: + get_result = CppCodeCache.load_async(source_code, submit_fn=self.submit) + return LambdaFuture(lambda: get_result().kernel) + + def cpp_pybinding(self, argtypes: List[str], source_code: str): + kernel_code_log.info("CPP+Bindings Kernel:\n%s", source_code) + if config.compile_threads <= 1: + return CppPythonBindingsCodeCache.load_pybinding(argtypes, source_code) + else: + get_result = CppPythonBindingsCodeCache.load_pybinding_async( + argtypes, source_code, submit_fn=self.submit + ) + return LambdaFuture(get_result) + + def cuda(self, source_code, dst_file_ext): + kernel_code_log.info("CUDA Kernel:\n%s", source_code) + + def task(): + return CUDACodeCache.load(source_code, dst_file_ext)[0] + + return self.submit(task) + + def halide(self, meta: HalideMeta, source_code: str): + kernel_code_log.info("Halide Kernel:\n%r\n%s", meta, source_code) + if config.compile_threads <= 1: + return HalideCodeCache.generate_halide(meta, source_code) + else: + get_result = HalideCodeCache.generate_halide_async( + meta, source_code, submit_fn=self.submit + ) + return LambdaFuture(get_result) + + def wait(self, scope: Dict[str, Any]) -> None: + num_kernels = len( + [ + value + for key, value in scope.items() + if isinstance(value, (Future, CodeCacheFuture)) + ] + ) + pbar = tqdm( + total=num_kernels, + desc="Inductor Compilation", + disable=config.disable_progress, + delay=0, + ) + if config.compile_threads > 1: + for key, result in scope.items(): + if config.verbose_progress and not isinstance(pbar, _Faketqdm): + pbar.set_postfix_str(key) + if isinstance(result, (Future, CodeCacheFuture)): + scope[key] = result.result() + pbar.update(1) + + _compile_end() + + +if ( + os.environ.get("TORCH_TNT_IN_USE", "0") == "1" + or os.environ.get("TORCH_WARM_POOL", "1") != "1" +): + pass +else: + AsyncCompile.warm_pool() diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 2a436e8b5858..7bda7cf255c4 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -9,11 +9,10 @@ from sympy import Expr import torch - -import torch._inductor.async_compile import torch._ops from torch.fx.experimental.symbolic_shapes import ConvertIntKey, DivideByKey from .. import config, ir + from ..codecache import CudaKernelParamCache from ..utils import cache_on_self, sympy_product from ..virtualized import V diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 0bf4814f80b1..f1028e9068d6 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -33,7 +33,7 @@ from torch.utils._sympy.singleton_int import SingletonInt from torch.utils._sympy.symbol import symbol_is_type, SymT -from .. import async_compile, config, ir +from .. import codecache, config, ir from ..ir import ReinterpretView from ..runtime import triton_heuristics from ..runtime.hints import DeviceProperties @@ -501,7 +501,7 @@ def write_header(self) -> None: from torch._inductor.codegen.memory_planning import _align as align from torch import device, empty_strided - from {async_compile.__name__} import AsyncCompile + from {codecache.__name__} import AsyncCompile from torch._inductor.select_algorithm import extern_kernels from torch._inductor.codegen.multi_kernel import MultiKernelCall diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index 26d75669a206..7eca31da87b4 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -11,8 +11,6 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union from unittest import mock -import torch._inductor.async_compile - import torch.fx import torch.utils._pytree as pytree diff --git a/torch/_inductor/compile_worker/__main__.py b/torch/_inductor/compile_worker/__main__.py index e478a5345675..6cd1d1e600ac 100644 --- a/torch/_inductor/compile_worker/__main__.py +++ b/torch/_inductor/compile_worker/__main__.py @@ -3,7 +3,7 @@ import sys import typing -from torch._inductor.async_compile import caching_device_properties +from torch._inductor.codecache import caching_device_properties from torch._inductor.compile_worker.subproc_pool import Pipe, SubprocMain from torch._inductor.compile_worker.watchdog import _async_compile_initializer from torch._inductor.runtime.compile_tasks import _set_triton_ptxas_path diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 6629e0fe5e77..078ff472461d 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -748,6 +748,7 @@ def save_cuda_kernel(self, grid, stream, launcher): # User defined triton kernels will have arbitrary kwarg names "meta": launcher.config.kwargs, } + from torch._inductor.codecache import CudaKernelParamCache binary = ( diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 5a13c7f3cae4..b8cf50bd1ba8 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -28,7 +28,6 @@ import sympy import torch -import torch._inductor.async_compile from torch._dynamo.utils import counters, dynamo_timed from torch._inductor.metrics import get_metric_table, is_metric_table_enabled from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 90b17de519bc..f5ff815c3e43 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -22,7 +22,6 @@ from filelock import FileLock import torch -import torch._inductor.async_compile from torch._dynamo.testing import rand_strided from torch._dynamo.utils import counters, identity, preserve_rng_state diff --git a/torch/testing/_internal/inductor_utils.py b/torch/testing/_internal/inductor_utils.py index d441988d4bd2..e8db1e394b96 100644 --- a/torch/testing/_internal/inductor_utils.py +++ b/torch/testing/_internal/inductor_utils.py @@ -5,7 +5,7 @@ import unittest import functools from subprocess import CalledProcessError -import torch._inductor.async_compile + from torch._inductor.codecache import CppCodeCache from torch.utils._triton import has_triton from torch.testing._internal.common_utils import ( From edffb28d398372c0a98fb2db815f16b835a1db19 Mon Sep 17 00:00:00 2001 From: Aaron Gokaslan Date: Sat, 1 Jun 2024 17:19:24 +0000 Subject: [PATCH 215/706] [BE][Ez]: Enable B019 - flags memory leaks through LRU cache on method (#127686) Flags potential mem leaks through LRUCache and will hopefully make future contributors rethink this pattern which can cause memleaks. noqas the violations we currently have (should be fixed later) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127686 Approved by: https://github.com/c-p-i-o --- pyproject.toml | 1 - torch/_inductor/codecache.py | 4 ++-- torch/_inductor/codegen/cpp_wrapper_cpu.py | 2 +- torch/_inductor/codegen/cpp_wrapper_cuda.py | 2 +- torch/_inductor/select_algorithm.py | 2 +- torch/distributed/_tensor/_sharding_prop.py | 2 +- 6 files changed, 6 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 133c86047606..8ec04f77042d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,6 @@ ignore = [ # these ignores are from flake8-bugbear; please fix! "B007", "B008", "B017", "B018", # Useless expression - "B019", "B023", "B028", # No explicit `stacklevel` keyword argument found "E402", diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 37375ddd4639..f53f159cc019 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -258,7 +258,7 @@ def set_value(self, *keys: str, value: Any) -> None: class PersistentCache(CacheBase): - @functools.lru_cache(None) + @functools.lru_cache(None) # noqa: B019 def get_global_cache(self): global_cache_path = self.get_global_cache_path() if global_cache_path is None or not global_cache_path.is_file(): @@ -1292,7 +1292,7 @@ def build_arch_flags(self) -> str: def __hash__(self) -> int: return hash(str(self)) - @functools.lru_cache(None) + @functools.lru_cache(None) # noqa: B019 def __bool__(self) -> bool: if config.cpp.vec_isa_ok is not None: return config.cpp.vec_isa_ok diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 7bda7cf255c4..d9c499de34ee 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -1533,7 +1533,7 @@ def codegen_layout(self, layout): else: return LAYOUT_TO_ATEN[layout] - @functools.lru_cache(None) + @functools.lru_cache(None) # noqa: B019 def codegen_int_array_var( self, int_array: str, diff --git a/torch/_inductor/codegen/cpp_wrapper_cuda.py b/torch/_inductor/codegen/cpp_wrapper_cuda.py index e77277d75621..2519f80b6626 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cuda.py +++ b/torch/_inductor/codegen/cpp_wrapper_cuda.py @@ -76,7 +76,7 @@ def generate(self, is_inference): self.prefix.writeline("\n") return super().generate(is_inference) - @functools.lru_cache(None) + @functools.lru_cache(None) # noqa: B019 def generate_load_kernel_once( self, name: str, diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index f5ff815c3e43..f9221d0dd49b 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -772,7 +772,7 @@ def to_callable(self): def call_name(self): return f"extern_kernels.{self.name}" - @functools.lru_cache(None) + @functools.lru_cache(None) # noqa: B019 def hash_key(self): fn = self.to_callable() parts = [ diff --git a/torch/distributed/_tensor/_sharding_prop.py b/torch/distributed/_tensor/_sharding_prop.py index 314ef87193eb..3510f80cbeba 100644 --- a/torch/distributed/_tensor/_sharding_prop.py +++ b/torch/distributed/_tensor/_sharding_prop.py @@ -88,7 +88,7 @@ def register_op_strategy( if schema_info is not None: self.op_to_schema_info[op_overload] = schema_info - @lru_cache + @lru_cache # noqa: B019 def _propagate_tensor_meta( self, op_schema: OpSchema ) -> Union[None, TensorMeta, Sequence[Optional[TensorMeta]]]: From 42312a52b3c8b6e545c60e4b34b59a3a5b57ceb3 Mon Sep 17 00:00:00 2001 From: Lucas Pasqualin Date: Sat, 1 Jun 2024 17:50:52 +0000 Subject: [PATCH 216/706] [DSD] Adds type_check param to copy state dict utils (#127417) [DSD] Adds type_check param to copy state dict utils. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127417 Approved by: https://github.com/fegin --- torch/distributed/_state_dict_utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torch/distributed/_state_dict_utils.py b/torch/distributed/_state_dict_utils.py index 48d1a6bfb9c2..fced8800047c 100644 --- a/torch/distributed/_state_dict_utils.py +++ b/torch/distributed/_state_dict_utils.py @@ -331,6 +331,7 @@ def _copy_state_dict( state_dict: Dict[str, Any], copy_state_dict: Dict[str, Any], non_blocking: bool = False, + type_check: bool = True, ) -> Dict[str, Any]: """ Copies all tensors in a given state dict into a different state_dict with the @@ -352,6 +353,9 @@ def _copy_state_dict( The state dict we are copying into. This state_dict must have exactly the same structure as the source `state_dict`. non_blocking: (bool): Whether copy ops should be performed asynchronously + type_check (bool): check if the instance data type is a supported type + that can be saved by DCP. The current supported data types are + torch.Tensor, DTensor, int, float, str, list, dict, None. Returns: State Dict copy @@ -367,7 +371,7 @@ def _copy_state_dict( cpu_offload=False, ranks_only=tuple(), companion_obj=copy_state_dict, - type_check=True, + type_check=type_check, non_blocking=non_blocking, ) From 82cd7a7dab91a0d8fe189ebab89d1fc40192077f Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sat, 1 Jun 2024 18:46:16 +0000 Subject: [PATCH 217/706] Revert "Default meta device to use swap_tensors in nn.Module._apply (.to_empty and .to('meta')) (#126819)" This reverts commit fa426b096b3635daab6ce26b44d50f3baab5a4e5. Reverted https://github.com/pytorch/pytorch/pull/126819 on behalf of https://github.com/izaitsevfb due to suspicious build instructions count regression, see [D58015016](https://www.internalfb.com/diff/D58015016) ([comment](https://github.com/pytorch/pytorch/pull/126814#issuecomment-2143545818)) --- test/test_modules.py | 27 ++++++++++----------------- torch/nn/modules/module.py | 2 -- 2 files changed, 10 insertions(+), 19 deletions(-) diff --git a/test/test_modules.py b/test/test_modules.py index 601cf5cefdf9..e854eec8add7 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -983,23 +983,16 @@ def test_to_empty(self, device, dtype, module_info, swap, training): p_ids_after = [id(p) for p in m.parameters()] p_cdatas_after = [p._cdata for p in m.parameters()] - # id same, ._cdata differs --> swapped cdata of THPVariable - # Technically, meta and device have different shallow copy types, so when swap=False it will create a new - # parameter and assign it to the module BUT we opt into swap_tensors when either one is on meta. - self.assertTrue(all(a == b for a, b in zip(p_ids_before, p_ids_after))) - self.assertTrue(all(a != b for a, b in zip(p_cdatas_before, p_cdatas_after))) - - # Test the opposite direction device --> meta - m = m.to(device="meta") - - p_ids_after_meta = [id(p) for p in m.parameters()] - p_cdatas_after_meta = [p._cdata for p in m.parameters()] - - # id same, ._cdata differs --> swapped cdata of THPVariable - # Technically, meta and device have different shallow copy types, so when swap=False it will create a new - # parameter and assign it to the module BUT we opt into swap_tensors when either one is on meta. - self.assertTrue(all(a == b for a, b in zip(p_ids_after, p_ids_after_meta))) - self.assertTrue(all(a != b for a, b in zip(p_cdatas_after, p_cdatas_after_meta))) + if swap: + # id same, ._cdata differs --> swapped cdata of THPVariable + self.assertTrue(all(a == b for a, b in zip(p_ids_before, p_ids_after))) + self.assertTrue(all(a != b for a, b in zip(p_cdatas_before, p_cdatas_after))) + else: + # id and ._cdata differ + # meta and device have different shallow copy types, so this will create a new + # parameter and assign it to the module + self.assertTrue(all(a != b for a, b in zip(p_ids_before, p_ids_after))) + self.assertTrue(all(a != b for a, b in zip(p_cdatas_before, p_cdatas_after))) instantiate_device_type_tests(TestModule, globals(), allow_mps=True) diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index f9d400cdbb35..70ff0bd7297f 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -798,8 +798,6 @@ def compute_should_use_swap_tensors(tensor, tensor_applied): return (should_use_swap_tensors # subclasses may have multiple child tensors so we need to use swap_tensors or is_traceable_wrapper_subclass(tensor_applied) - or tensor.device.type == 'meta' - or tensor_applied.device.type == 'meta' or tensor.device.type == 'xla' or tensor_applied.device.type == 'xla') From 17dea09b15cefc9dc5ee94833f2de7947b5333d3 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sat, 1 Jun 2024 18:46:16 +0000 Subject: [PATCH 218/706] Revert "Default XLA to use swap_tensors path in nn.Module._apply (#126814)" This reverts commit bfdec93395f675a0e5a59e95aef9104ac8f5081a. Reverted https://github.com/pytorch/pytorch/pull/126814 on behalf of https://github.com/izaitsevfb due to suspicious build instructions count regression, see [D58015016](https://www.internalfb.com/diff/D58015016) ([comment](https://github.com/pytorch/pytorch/pull/126814#issuecomment-2143545818)) --- test/test_nn.py | 4 ++-- torch/nn/modules/module.py | 10 ++-------- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/test/test_nn.py b/test/test_nn.py index 6bcb4017e4b5..6dfac4f7ca1b 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -8184,9 +8184,9 @@ def test_batchnorm_large_batch(self, device, dtype): @dtypes(torch.float, torch.double, torch.bfloat16, torch.complex128) def test_conv_empty_input(self, device, dtype): def help(input, conv, memory_format): - ref_out = conv(input).detach() + ref_out = conv(input) conv_cl = conv.to(memory_format=memory_format) - out_cl = conv_cl(input).detach() + out_cl = conv_cl(input) self.assertEqual(ref_out, out_cl) input_cl = input.to(memory_format=memory_format) out_cl2 = conv(input_cl) diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index 70ff0bd7297f..73420c0f32e7 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -794,13 +794,6 @@ def compute_should_use_set_data(tensor, tensor_applied): should_use_swap_tensors = torch.__future__.get_swap_module_params_on_conversion() - def compute_should_use_swap_tensors(tensor, tensor_applied): - return (should_use_swap_tensors - # subclasses may have multiple child tensors so we need to use swap_tensors - or is_traceable_wrapper_subclass(tensor_applied) - or tensor.device.type == 'xla' - or tensor_applied.device.type == 'xla') - for key, param in self._parameters.items(): if param is None: continue @@ -811,7 +804,8 @@ def compute_should_use_swap_tensors(tensor, tensor_applied): param_applied = fn(param) p_should_use_set_data = compute_should_use_set_data(param, param_applied) - p_should_use_swap_tensors = compute_should_use_swap_tensors(param, param_applied) + # subclasses may have multiple child tensors so we need to use swap_tensors + p_should_use_swap_tensors = should_use_swap_tensors or is_traceable_wrapper_subclass(param_applied) param_grad = param.grad if p_should_use_swap_tensors: From b505e8647547f029d0f7df408ee5f2968f757f89 Mon Sep 17 00:00:00 2001 From: Wei Wang <143543872+nWEIdia@users.noreply.github.com> Date: Sat, 1 Jun 2024 19:12:40 +0000 Subject: [PATCH 219/706] [Inductor][CI][CUDA 12.4] Update dynamic_inductor_timm_training.csv - change gluon_inception_v3 from fail_accuracy to pass (#127672) From the HUD, most of the time the "X" is due to "improved_accuracy" for gluon_inception_v3. ![image](https://github.com/pytorch/pytorch/assets/143543872/d4f70377-2756-4921-872d-587426f00302) https://hud.pytorch.org/hud/pytorch/pytorch/main/1?per_page=50&name_filter=inductor_timm Pull Request resolved: https://github.com/pytorch/pytorch/pull/127672 Approved by: https://github.com/eqy, https://github.com/Skylion007 --- .../cu124/dynamic_inductor_timm_training.csv | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_timm_training.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_timm_training.csv index b1a70e91cbae..9443ae8c83a8 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_timm_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_timm_training.csv @@ -86,7 +86,7 @@ ghostnet_100,pass,6 -gluon_inception_v3,fail_accuracy,7 +gluon_inception_v3,pass,7 From 6e2e09f6cc0b0be530ee1f5766134f17fe6a47e3 Mon Sep 17 00:00:00 2001 From: Sam Larsen Date: Fri, 31 May 2024 13:48:05 -0700 Subject: [PATCH 220/706] [inductor] fix redis-related env vars in remote_cache.py (#127583) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127583 Approved by: https://github.com/oulgen --- torch/_inductor/remote_cache.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/torch/_inductor/remote_cache.py b/torch/_inductor/remote_cache.py index 7c40f603c4d9..5bf3f50154e8 100644 --- a/torch/_inductor/remote_cache.py +++ b/torch/_inductor/remote_cache.py @@ -27,17 +27,14 @@ class RedisRemoteCacheBackend(RemoteCacheBackend): def __init__(self, cache_id: str): import redis - self._cache_id = cache_id - self._key_fmt = os.environ.get( - "TORCHINDUCTOR_REDIS_KEY_FORMAT", "pt2:{cache_id}:{key}" - ) + self._key_fmt = f"pt2:{cache_id}:{{key}}" self._redis = redis.Redis( - host=os.environ.get("TRITON_REDIS_HOST", "localhost"), - port=int(os.environ.get("TRITON_REDIS_PORT", 6379)), + host=os.environ.get("TORCHINDUCTOR_REDIS_HOST", "localhost"), + port=int(os.environ.get("TORCHINDUCTOR_REDIS_PORT", 6379)), ) def _get_key(self, key: str) -> str: - return self._key_fmt.format(cache_id=self._cache_id, key=key) + return self._key_fmt.format(key=key) def get(self, key: str): return self._redis.get(self._get_key(key)) From 2e779166eb3c3fca2d97214041c07ff30235b323 Mon Sep 17 00:00:00 2001 From: eqy Date: Sat, 1 Jun 2024 21:22:53 +0000 Subject: [PATCH 221/706] [Functorch][cuDNN] Bump tolerances for `test_vmapjvpvjp` (#127355) cuDNN can select a winograd kernel for this case which slightly affects tolerances... Pull Request resolved: https://github.com/pytorch/pytorch/pull/127355 Approved by: https://github.com/zou3519, https://github.com/Skylion007 --- test/functorch/test_ops.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py index bd75a8d0bb74..4766b4cddabb 100644 --- a/test/functorch/test_ops.py +++ b/test/functorch/test_ops.py @@ -2032,6 +2032,10 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents): tol2( "linalg.pinv", "hermitian", {torch.float32: tol(atol=5e-04, rtol=5e-04)} ), + tol1( + "nn.functional.conv_transpose2d", + {torch.float32: tol(atol=5e-04, rtol=5e-04)}, + ), tol1("svd", {torch.float32: tol(atol=5e-04, rtol=5e-04)}), tol1("matrix_exp", {torch.float32: tol(atol=5e-04, rtol=5e-04)}), ), From 0d9e527c4d7430decede2d165c9b6794f6cd2472 Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Sun, 2 Jun 2024 00:28:43 +0000 Subject: [PATCH 222/706] Remove tensor storage_offset/storage_bytes from the cache key (#127319) Summary: We observed differences in these fields and inductor does not specialize on them so it is safe to remove them from the key. Test Plan: CI Reviewed By: masnesral Differential Revision: D57871276 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127319 Approved by: https://github.com/masnesral --- test/inductor/test_codecache.py | 24 ++++++++--------- .../_aot_autograd/autograd_cache.py | 2 +- torch/_inductor/codecache.py | 27 +++++++++++++++---- torch/_inductor/compile_fx.py | 23 ++++++++-------- 4 files changed, 47 insertions(+), 29 deletions(-) diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py index d5ae4c4f03ca..e7f619a9cb36 100644 --- a/test/inductor/test_codecache.py +++ b/test/inductor/test_codecache.py @@ -527,7 +527,7 @@ def test_hash_fake_tensors(self): FxGraphCachePickler.dumps(torch.randn(3)[1:]), FxGraphCachePickler.dumps(torch.randn(3)[1:]), ) - self.assertNotEqual( + self.assertEqual( FxGraphCachePickler.dumps(torch.randn(3)[1:]), FxGraphCachePickler.dumps(torch.randn(2)), ) @@ -586,16 +586,16 @@ def test_hash_kwargs(self): ordering of the kwargs dict and any set arguments. """ # Dict order of the kwargs should not affect hashes. - details1 = FxGraphHashDetails(None, [], {"a": 0, "z": 1}) - details2 = FxGraphHashDetails(None, [], {"z": 1, "a": 0}) + details1 = FxGraphHashDetails(None, [], {"a": 0, "z": 1}, []) + details2 = FxGraphHashDetails(None, [], {"z": 1, "a": 0}, []) self.assertEqual( FxGraphCachePickler.dumps(details1), FxGraphCachePickler.dumps(details2), ) # Different kwarg values should affect hashes. - details1 = FxGraphHashDetails(None, [], {"a": 0}) - details2 = FxGraphHashDetails(None, [], {"a": 1}) + details1 = FxGraphHashDetails(None, [], {"a": 0}, []) + details2 = FxGraphHashDetails(None, [], {"a": 1}, []) self.assertNotEqual( FxGraphCachePickler.dumps(details1), FxGraphCachePickler.dumps(details2), @@ -605,16 +605,16 @@ def test_hash_kwargs(self): # sorting and creating a new set seems to change the order. set1 = {"a", "b", "c", "d", "e", "f", "g"} set2 = set(sorted(set1)) # noqa: C414 - details1 = FxGraphHashDetails(None, [], {"a": set1}) - details2 = FxGraphHashDetails(None, [], {"a": set2}) + details1 = FxGraphHashDetails(None, [], {"a": set1}, []) + details2 = FxGraphHashDetails(None, [], {"a": set2}, []) self.assertEqual( FxGraphCachePickler.dumps(details1), FxGraphCachePickler.dumps(details2), ) # But different set contents should affect hashes. - details1 = FxGraphHashDetails(None, [], {"a": {1, 2, 3}}) - details2 = FxGraphHashDetails(None, [], {"a": {1, 2}}) + details1 = FxGraphHashDetails(None, [], {"a": {1, 2, 3}}, []) + details2 = FxGraphHashDetails(None, [], {"a": {1, 2}}, []) self.assertNotEqual( FxGraphCachePickler.dumps(details1), FxGraphCachePickler.dumps(details2), @@ -625,11 +625,11 @@ def test_hash_config_changes(self): Test that different config settings affect hashes. """ with config.patch({"max_autotune": False}): - details1 = FxGraphHashDetails(None, [], {}) - details2 = FxGraphHashDetails(None, [], {}) + details1 = FxGraphHashDetails(None, [], {}, []) + details2 = FxGraphHashDetails(None, [], {}, []) with config.patch({"max_autotune": True}): - details3 = FxGraphHashDetails(None, [], {}) + details3 = FxGraphHashDetails(None, [], {}, []) self.assertEqual( FxGraphCachePickler.dumps(details1), diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py index 07885c136c7d..057aff8467c5 100644 --- a/torch/_functorch/_aot_autograd/autograd_cache.py +++ b/torch/_functorch/_aot_autograd/autograd_cache.py @@ -118,7 +118,7 @@ def __init__( self.code_hash = get_autograd_code_hash() self.autograd_config = config.save_config() try: - super().__init__(gm, example_inputs, {}) + super().__init__(gm, example_inputs, {}, []) except BypassFxGraphCache as e: # Sometimes inductor configs are unpickleable and can fail raise BypassAOTAutogradCache from e diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index f53f159cc019..5d648e6b2a98 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -446,11 +446,22 @@ def _ident(x: Any) -> Any: return x +def extract_tensor_metadata_for_cache_key(t): + """ + Extracts the tensor metadata and removes fields of the TensorMetadata + that are not needed for caching + """ + meta = extract_tensor_metadata(t) + if not hasattr(t, "_is_inductor_static"): + meta = dataclasses.replace(meta, storage_offset=0, storage_bytes=None) + return meta + + def _reduce_fake_tensor(t): """ See FxGraphCachePickler. Custom reducer to pickle FakeTensors. """ - metadata = extract_tensor_metadata(t) + metadata = extract_tensor_metadata_for_cache_key(t) return (_ident, (metadata,)) @@ -481,7 +492,7 @@ def _reduce_tensor(t): f"FX graph cache handling of a large constant took {elapsed:.1}s. Please file an issue." ) - metadata = extract_tensor_metadata(t) + metadata = extract_tensor_metadata_for_cache_key(t) return (_ident, (TensorMetadataAndValues(metadata, values),)) @@ -554,7 +565,7 @@ def debug_str(cls, inp: Any) -> str: def get_str(obj) -> str: if isinstance(obj, torch.Tensor): - return str(extract_tensor_metadata(obj)) + return str(extract_tensor_metadata_for_cache_key(obj)) elif isinstance(obj, bytes): return "" else: @@ -639,6 +650,7 @@ def __init__( gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor], fx_kwargs: Dict[str, Any], + inputs_to_check: Sequence[int], ): self.gm = gm self.example_inputs = example_inputs @@ -654,6 +666,9 @@ def __init__( else: self.fx_kwargs[k] = fx_kwargs[k] + # Alignment checks + self.inputs_to_check = inputs_to_check + # 'Deterministic algorithms' can affect codegen via lowering to cuda kernels. self.deterministic_algorithms_settings = ( torch.are_deterministic_algorithms_enabled(), @@ -686,11 +701,12 @@ def compiled_fx_graph_hash( gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor], fx_kwargs: Dict[str, Any], + inputs_to_check: Sequence[int], ) -> str: """ Generate a unique hash of the FX graph for caching. """ - details = FxGraphHashDetails(gm, example_inputs, fx_kwargs) + details = FxGraphHashDetails(gm, example_inputs, fx_kwargs, inputs_to_check) # The prefix distinguishes among the other kinds of objects we # cache in this module. key = "f" + FxGraphCachePickler.get_hash(details) @@ -990,6 +1006,7 @@ def load( gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor], fx_kwargs: Dict[str, Any], + inputs_to_check: Sequence[int], local: bool, remote: bool, ): @@ -1001,7 +1018,7 @@ def load( compiled_graph = None try: FxGraphCache._check_can_cache(gm) - key = compiled_fx_graph_hash(gm, example_inputs, fx_kwargs) + key = compiled_fx_graph_hash(gm, example_inputs, fx_kwargs, inputs_to_check) remote_cache = None if remote: diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index 7eca31da87b4..77b8925a79a5 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -480,16 +480,26 @@ def compile_fx_inner( start = time.time() fx_graph_remote_cache = should_use_remote_fx_graph_cache() + inputs_to_check = get_input_idxs_to_check(example_inputs, range(num_fixed)) if ( not config.force_disable_caches and (config.fx_graph_cache or fx_graph_remote_cache) and not aot_mode ): + for i, input in enumerate(example_inputs): + if ( + isinstance(input, torch.Tensor) + and input.device.type == "cuda" + and i < num_fixed + ): + input._is_inductor_static = True # type: ignore[attr-defined] + compiled_graph = FxGraphCache.load( fx_codegen_and_compile, gm, example_inputs, graph_kwargs, + inputs_to_check, local=config.fx_graph_cache, remote=fx_graph_remote_cache, ) @@ -625,8 +635,8 @@ def compiled_artifact(new_inputs): # cudagraphs does its own aligning of inputs if not cudagraphs: - new_callable = align_inputs( - compiled_graph.current_callable, example_inputs, range(num_fixed) + new_callable = align_inputs_from_check_idxs( + compiled_graph.current_callable, inputs_to_check ) if new_callable is not compiled_graph.current_callable: compiled_graph.current_callable = new_callable @@ -908,15 +918,6 @@ def run(new_inputs): return run -def align_inputs( - model: Callable[[List[torch.Tensor]], Any], - inputs: List[torch.Tensor], - static_input_idxs: Sequence[int] = (), -): - inputs_to_check = get_input_idxs_to_check(inputs, static_input_idxs) - return align_inputs_from_check_idxs(model, inputs_to_check) - - @dynamo_utils.dynamo_timed def cudagraphify( model: torch.fx.GraphModule, From 2cef2fc2b4cb0ba95ddc33768e41267e6c19a058 Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Sun, 2 Jun 2024 00:36:33 +0000 Subject: [PATCH 223/706] [ts migration] support aten::dim, aten::len, aten::__getitem__ (#127593) - Add support for aten::dim, aten::len, aten::__getitem__ for torchscript to export converter. - Add unit tests Co-authored-by: cyy Co-authored-by: Menglu Yu Co-authored-by: Animesh Jain Co-authored-by: Simon Fan Co-authored-by: Zain Rizvi Co-authored-by: Tugsbayasgalan (Tugsuu) Manlaibaatar Co-authored-by: titaiwangms Co-authored-by: Yueming Hao Co-authored-by: IvanKobzarev Co-authored-by: PyTorch MergeBot Co-authored-by: Edward Z. Yang Co-authored-by: Bin Bao Co-authored-by: Feny Patel Co-authored-by: Mikayla Gawarecki Co-authored-by: xinan.lin Co-authored-by: Zain Huda Co-authored-by: Chien-Chin Huang Co-authored-by: Wei Wang Co-authored-by: Jason Ansel Co-authored-by: Aaron Gokaslan Co-authored-by: Iris Z <31293777+wz337@users.noreply.github.com> Co-authored-by: Wang, Eikan Co-authored-by: angelayi Co-authored-by: Svetlana Karslioglu Co-authored-by: Yanbo Liang Co-authored-by: Catherine Lee Co-authored-by: Kwanghoon An Co-authored-by: Brian Hirsh Co-authored-by: Robert Mast Co-authored-by: drisspg Pull Request resolved: https://github.com/pytorch/pytorch/pull/127593 Approved by: https://github.com/SherlockNoMad, https://github.com/malfet --- test/export/test_converter.py | 40 +++++++++++++++++++++++++++++++++++ torch/_export/converter.py | 13 ++++++++++++ 2 files changed, 53 insertions(+) diff --git a/test/export/test_converter.py b/test/export/test_converter.py index 9ab6161fe4d7..cde08b7f7cd3 100644 --- a/test/export/test_converter.py +++ b/test/export/test_converter.py @@ -70,6 +70,46 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): self._check_equal_ts_ep_converter(MOutputTuple(), inp) self._check_equal_ts_ep_converter(MOutputDict(), inp) + def test_aten_dim(self): + class Module(torch.nn.Module): + def forward(self, x): + num_dim = x.dim() + return torch.ones(num_dim) + + inp = (torch.ones(1, 3),) + self._check_equal_ts_ep_converter(Module(), inp) + + def test_aten_len(self): + class Module(torch.nn.Module): + def forward(self, x): + length = len(x) + return torch.ones(length) + + inp = (torch.ones(2, 3),) + self._check_equal_ts_ep_converter(Module(), inp) + + def test_aten___getitem___list(self): + class Module(torch.nn.Module): + def forward(self, x): + y = torch.split(x, 2) + return y[0] + + inp = (torch.rand((3, 2)),) + self._check_equal_ts_ep_converter(Module(), inp) + + def test_aten___getitem___dict(self): + class Module(torch.nn.Module): + def forward(self, x): + y = torch.split(x, 2) + d_int = {0: y[0], 1: y[1]} + d_str = {"0": y[0], "1": y[1]} + d_bool = {True: y[0], False: y[1]} + d_float = {0.1: y[0], 2.3: y[1]} + return d_int[0], d_str["0"], d_bool[True], d_float[0.1] + + inp = (torch.rand((3, 2)),) + self._check_equal_ts_ep_converter(Module(), inp) + def test_prim_device(self): class Module(torch.nn.Module): def forward(self, x): diff --git a/torch/_export/converter.py b/torch/_export/converter.py index fb10ac6a1e66..88f43d46cb48 100644 --- a/torch/_export/converter.py +++ b/torch/_export/converter.py @@ -1,4 +1,5 @@ import operator + from typing import Any, Dict, List, Optional, Set, Tuple, Union import torch @@ -363,6 +364,16 @@ def convert_aten_div(self, node: torch._C.Node): self.convert_aten_op(node) + def convert_aten___getitem__(self, node: torch._C.Node): + input_container, index = tuple( + self.get_fx_value(input) for input in node.inputs() + ) + fx_node = self.fx_graph.call_function( + operator.getitem, (input_container, index) + ) + output_name = node.output().debugName() + self.name_to_node[output_name] = fx_node + def convert_prim_if(self, node: torch._C.Node): inputs = list(node.inputs()) assert len(inputs) == 1 @@ -452,6 +463,8 @@ def convert_node(self, node: torch._C.Node): # convert_aten_Int(node) elif node_kind == "aten::_convolution": self.convert_aten__convolution(node) + elif node_kind == "aten::__getitem__": + self.convert_aten___getitem__(node) elif node_kind == "aten::div": self.convert_aten_div(node) elif node_kind == "prim::If": From c19ad112f65ab88af3eeb319dda0675c076b0b67 Mon Sep 17 00:00:00 2001 From: "xinan.lin" Date: Fri, 31 May 2024 10:28:11 -0700 Subject: [PATCH 224/706] [Inductor UT][Intel GPU] Skip test case which doesn't currently work on the XPU stack but newly re-enabled by community. (#127629) The Inductor UT test/inductor/test_triton_heuristics.py:test_artificial_zgrid that previously skipped was recently enbaled by the PR https://github.com/pytorch/pytorch/pull/127448. However, the test doesn't currently work on the XPU stack, it will huang on GPU, so this PR skip the test for Intel GPU instead of expected failure. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127629 Approved by: https://github.com/EikanWang, https://github.com/peterbell10 --- test/inductor/test_triton_heuristics.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/test/inductor/test_triton_heuristics.py b/test/inductor/test_triton_heuristics.py index c0908251f85b..549903f47ce4 100644 --- a/test/inductor/test_triton_heuristics.py +++ b/test/inductor/test_triton_heuristics.py @@ -4,9 +4,8 @@ import unittest import torch -from torch.testing._internal.common_device_type import expectedFailureXPU -from torch.testing._internal.common_utils import IS_LINUX +from torch.testing._internal.common_utils import IS_LINUX, skipIfXpu from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU try: @@ -73,11 +72,11 @@ def forward(primals_1, primals_2, primals_5): ] self.assertEqual(forward(*args), foo_c(*args)) - @expectedFailureXPU + @skipIfXpu def test_artificial_zgrid(self): self._test_artificial_zgrid() - @expectedFailureXPU + @skipIfXpu @config.patch("cpp_wrapper", True) def test_artificial_grid_cpp_wrapper(self): self._test_artificial_zgrid() From 4fd777ed5979920ccc27621c17984c3c92f58206 Mon Sep 17 00:00:00 2001 From: titaiwangms Date: Sun, 2 Jun 2024 02:09:58 +0000 Subject: [PATCH 225/706] [ONNX] Add quantized layer norm op to opset 17 (#127640) Fixes #126160 Continue #126555 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127640 Approved by: https://github.com/justinchuby --- torch/onnx/symbolic_opset17.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/torch/onnx/symbolic_opset17.py b/torch/onnx/symbolic_opset17.py index 3aad249a1126..c7720b9e5c9f 100644 --- a/torch/onnx/symbolic_opset17.py +++ b/torch/onnx/symbolic_opset17.py @@ -26,7 +26,7 @@ # EDITING THIS FILE? READ THIS FIRST! # see Note [Edit Symbolic Files] in README.md -__all__ = ["layer_norm", "stft"] +__all__ = ["layer_norm", "stft", "quantized_layer_norm"] _onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=17) @@ -67,6 +67,24 @@ def layer_norm( ) +@_onnx_symbolic("quantized::layer_norm") +def quantized_layer_norm( + g: jit_utils.GraphContext, + x, + normalized_shape, + weight, + bias, + eps, + op_scale, + op_zero_point, +): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + + output = layer_norm(g, x, normalized_shape, weight, bias, eps, False) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + def _compute_edge_sizes(n_fft, window_size): """Helper function to compute the sizes of the edges (left and right) of a given window centered within an FFT size.""" From 16578e858465a7461ecc4c9068171e8cf9d8b63f Mon Sep 17 00:00:00 2001 From: Colin Peppler Date: Sun, 2 Jun 2024 02:28:40 +0000 Subject: [PATCH 226/706] [symbolic shapes] if symbol not in var_ranges default to unknown range (#127681) Purpose of this PR is to get around this error: https://github.com/pytorch/pytorch/issues/127677 Differential Revision: D58048558 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127681 Approved by: https://github.com/lezcano --- torch/fx/experimental/symbolic_shapes.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 7492009e517a..d94068b770d3 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -4411,7 +4411,11 @@ def _maybe_evaluate_static( # Skip var_ranges logic for SingletonInt which is only used # for jagged layout NestedTensors today continue - vr = var_ranges[k] + try: + vr = var_ranges[k] + except KeyError: + log.warning("%s is not in var_ranges, defaulting to unknown range.", k) + vr = self._default_unspecified_value_range() if size_oblivious and k in self.size_like: lower = max(2, vr.lower) else: From 2129903aa3e2298aed0cc1ad05aaedfafb7c5a35 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Sun, 2 Jun 2024 03:43:22 +0000 Subject: [PATCH 227/706] Properly detect nested torch function args (#127496) Dynamo was not detecting nested torch function classes in containers. This was due to pytree compatibility for variable trackers being removed. Fixes https://github.com/pytorch/pytorch/issues/127174 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127496 Approved by: https://github.com/anijain2305 --- test/dynamo/test_misc.py | 6 ++-- test/dynamo/test_structured_trace.py | 6 ++-- test/dynamo/test_subclasses.py | 37 +++++++++++++++++++++++ test/test_overrides.py | 24 ++++++++++++++- torch/_dynamo/config.py | 5 +-- torch/_dynamo/utils.py | 13 +++++--- torch/_dynamo/variables/torch_function.py | 37 +++++++++++++++++++---- 7 files changed, 109 insertions(+), 19 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 83ba3936f2de..bcb0fd18818e 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -700,7 +700,7 @@ def f(x, y, z, n): """\ def forward(self, arg0_1: "f32[3]", arg1_1: "f32[3]", arg2_1: "f32[3]", arg3_1: "f32[3]", arg4_1: "f32[3]"): # No stacktrace found for following nodes - foo_default = torch.ops.mylib.foo.default(arg0_1, [arg3_1, arg4_1], arg1_1, 2, arg2_1); arg0_1 = arg3_1 = arg4_1 = arg1_1 = arg2_1 = None + foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg4_1 = arg2_1 = arg3_1 = arg1_1 = arg0_1 = None return ()""", ) @@ -759,7 +759,7 @@ def f(x, y, z, n): """\ def forward(self, arg0_1: "f32[3]", arg1_1: "f32[3]", arg2_1: "f32[3]", arg3_1: "f32[3]", arg4_1: "f32[3]"): # No stacktrace found for following nodes - foo_default = torch.ops.mylib.foo.default(arg0_1, [arg3_1, arg4_1], arg1_1, 2, arg2_1); arg0_1 = arg3_1 = arg4_1 = arg1_1 = arg2_1 = None + foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg4_1 = arg2_1 = arg3_1 = arg1_1 = arg0_1 = None getitem_4: "f32[3]" = foo_default[0] getitem_5: "f32[3]" = foo_default[1]; foo_default = None return (getitem_4, getitem_5)""", @@ -851,7 +851,7 @@ def f(x, y, z, n): """\ def forward(self, arg0_1: "f32[3]", arg1_1: "f32[3]", arg2_1: "f32[3]", arg3_1: "f32[3]"): # No stacktrace found for following nodes - foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg0_1, 2, arg1_1); arg2_1 = arg3_1 = arg0_1 = arg1_1 = None + foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg1_1 = arg0_1 = None return ()""", ) diff --git a/test/dynamo/test_structured_trace.py b/test/dynamo/test_structured_trace.py index fb7a2c249a85..c27118c74fde 100644 --- a/test/dynamo/test_structured_trace.py +++ b/test/dynamo/test_structured_trace.py @@ -198,11 +198,11 @@ def fn(x, y): {"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "is_leaf": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} -{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['y']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"describe_storage": {"id": 1, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"describe_tensor": {"id": 1, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "is_leaf": true, "stride": [1000, 1], "storage": 1, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} -{"describe_source": {"describer_id": "ID", "id": 1, "source": "L['y']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} -{"dynamo_output_graph": {"sizes": {"l_x_": [1000, 1000], "l_y_": [1000, 1000], "add": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"describe_source": {"describer_id": "ID", "id": 1, "source": "L['x']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"dynamo_output_graph": {"sizes": {"l_y_": [1000, 1000], "l_x_": [1000, 1000], "add": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_forward_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_hash", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index 96887da09ea3..1bb571ccd0e3 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -433,6 +433,43 @@ def fn(x): res = fn(input) self.assertIsInstance(res, LocalSubclass) + def test_torch_function_list_args(self): + HANDLED_FUNCTIONS = {} + + class MyClass: + def __init__(self, foo): + self.foo = foo + + @classmethod + def __torch_function__( + cls, + func, + types, + args=(), + kwargs=None, + ): + if kwargs is None: + kwargs = {} + if func not in HANDLED_FUNCTIONS or not all( # noqa: C419 + [ # noqa: C419 + issubclass(t, (torch.Tensor, MyClass)) for t in types + ] + ): + return NotImplemented + return HANDLED_FUNCTIONS[func](*args, **kwargs) + + def _stack(input, dim=0, *, out=None): + return MyClass(sum([x.foo for x in input])) + + HANDLED_FUNCTIONS[torch.stack] = _stack + + @torch.compile(backend="eager", fullgraph=True) + def fn(v0, v1): + return torch.stack([v0, v1]) + + ret = fn(MyClass(1), MyClass(1)) + self.assertEqual(ret.foo, 2) + @parametrize( "comparison", [ diff --git a/test/test_overrides.py b/test/test_overrides.py index cb46ca6ed880..a55688b95f31 100644 --- a/test/test_overrides.py +++ b/test/test_overrides.py @@ -8,8 +8,9 @@ import pickle import collections import unittest +import contextlib -from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_CROSSREF +from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_CROSSREF, TEST_WITH_TORCHDYNAMO from torch.overrides import ( handle_torch_function, has_torch_function, @@ -377,6 +378,27 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): return HANDLED_FUNCTIONS_TENSOR_LIKE[func](*args, **kwargs) class TestTorchFunctionOverride(TestCase): + @classmethod + def setUpClass(cls): + cls._stack = contextlib.ExitStack() + if TEST_WITH_TORCHDYNAMO: + # Add classes to the wrapped tensor subclasses + @contextlib.contextmanager + def setup_subclasses(): + old = set(torch._dynamo.config.traceable_tensor_subclasses) + torch._dynamo.config.traceable_tensor_subclasses.add(DiagonalTensor) + try: + yield + finally: + torch._dynamo.config.traceable_tensor_subclasses.clear() + torch._dynamo.config.traceable_tensor_subclasses.update(old) + + cls._stack.enter_context(setup_subclasses()) + + @classmethod + def tearDownClass(cls): + cls._stack.close() + def test_mean_semantics(self): """Test that a function with one argument can be overrided""" t1 = DiagonalTensor(5, 2) diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index 212021859c46..6487a2726381 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -116,8 +116,9 @@ def is_fbcode(): # This feature doesn't really work. We offer this flag for experimental # purposes / if you want to help us build out support. # -# torchdynamo has very limited support for tensor subclasses that implement -# __torch_function__. Our current support is limited to tensor subclasses +# torchdynamo has limited support for tensor subclasses that implement +# __torch_function__ see [Note: __torch_function__] in torch_function.py. +# Our current support is limited to tensor subclasses # that DO NOT store metadata on the tensor (in general, dynamo does not # support Python code that stores extra attributes on tensors at present). # If your tensor subclass purely changes function call behavior via diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 58768957af87..9d43ba551cc1 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -2529,12 +2529,17 @@ def is_torch_function_object(value): def has_torch_function(vt: "torch._dynamo.variables.base.VariableTracker") -> bool: - from torch._dynamo.variables import UserDefinedObjectVariable + from torch._dynamo.variables import LazyVariableTracker, UserDefinedObjectVariable from torch._dynamo.variables.torch_function import TensorWithTFOverrideVariable - return isinstance(vt, TensorWithTFOverrideVariable) or ( - isinstance(vt, UserDefinedObjectVariable) - and hasattr(vt.value, "__torch_function__") + if isinstance(vt, TensorWithTFOverrideVariable): + return True + + if isinstance(vt, LazyVariableTracker): + LazyVariableTracker.realize(vt) + + return isinstance(vt, UserDefinedObjectVariable) and hasattr( + vt.value, "__torch_function__" ) diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index 0674b8cfd146..6f210d498ce0 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -48,6 +48,33 @@ ] +def _get_all_args(args, kwargs): + return _flatten_vts(pytree.arg_tree_leaves(*args, **kwargs)) + + +def _flatten_vts(vts): + from collections import deque + + from .dicts import ConstDictVariable + from .lazy import LazyVariableTracker + from .lists import ListVariable + + vts = deque(vts) + output = [] + + while vts: + vt = vts.pop() + LazyVariableTracker.realize_all(vt) + if isinstance(vt, ListVariable): + vts.extend(vt.items) + elif isinstance(vt, ConstDictVariable): + vts.extend(vt.items.values()) + else: + output.append(vt) + + return output + + def _get_subclass_type(var): assert isinstance(var, (TensorWithTFOverrideVariable, UserDefinedObjectVariable)) return var.python_type() @@ -109,17 +136,15 @@ def build_torch_function_fn(tx, value, source): def can_dispatch_torch_function(tx, args, kwargs): - if tx.output.torch_function_enabled: - all_args = pytree.arg_tree_leaves(*args, **kwargs) - return any(has_torch_function(arg) for arg in all_args) - else: - return False + return tx.output.torch_function_enabled and any( + has_torch_function(arg) for arg in _get_all_args(args, kwargs) + ) def dispatch_torch_function(tx, fn, args, kwargs): """Gathers all args that are TensorWithTFOverrideVariable and dispatches based on the ordering in _get_overloaded_args""" - all_args = pytree.arg_tree_leaves(*args, **kwargs) + all_args = _get_all_args(args, kwargs) overloaded_args = _get_overloaded_args( [arg for arg in all_args if has_torch_function(arg)], _get_subclass_type, From 4e7f497bb33626135995f372748aee5806dde821 Mon Sep 17 00:00:00 2001 From: cyy Date: Sun, 2 Jun 2024 04:40:21 +0000 Subject: [PATCH 228/706] [Submodule] Remove ios-cmake (#127694) It has not been updated for a long time and CI iOS builds don't rely on it. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127694 Approved by: https://github.com/ezyang --- .gitmodules | 4 ---- cmake/Dependencies.cmake | 7 ++----- third_party/ios-cmake | 1 - 3 files changed, 2 insertions(+), 10 deletions(-) delete mode 160000 third_party/ios-cmake diff --git a/.gitmodules b/.gitmodules index 476f11fd945c..c031c2fd5ad3 100644 --- a/.gitmodules +++ b/.gitmodules @@ -18,10 +18,6 @@ ignore = dirty path = third_party/protobuf url = https://github.com/protocolbuffers/protobuf.git -[submodule "third_party/ios-cmake"] - ignore = dirty - path = third_party/ios-cmake - url = https://github.com/Yangqing/ios-cmake.git [submodule "third_party/NNPACK"] ignore = dirty path = third_party/NNPACK diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 4d53b2f860d6..faac0117ed93 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -144,14 +144,14 @@ endif() # ---[ BLAS set(AT_MKLDNN_ACL_ENABLED 0) +set(AT_MKLDNN_ENABLED 0) +set(AT_MKL_ENABLED 0) # setting default preferred BLAS options if not already present. if(NOT DEFINED BLAS) if(NOT INTERN_BUILD_MOBILE) set(BLAS "MKL" CACHE STRING "Selected BLAS library") else() set(BLAS "Eigen" CACHE STRING "Selected BLAS library") - set(AT_MKLDNN_ENABLED 0) - set(AT_MKL_ENABLED 0) endif() elseif(NOT BLAS STREQUAL "MKL") if(USE_MKLDNN) @@ -245,7 +245,6 @@ else() endif() if(NOT INTERN_BUILD_MOBILE) - set(AT_MKL_ENABLED 0) set(AT_MKL_SEQUENTIAL 0) set(USE_BLAS 1) if(NOT (ATLAS_FOUND OR BLIS_FOUND OR GENERIC_BLAS_FOUND OR MKL_FOUND OR OpenBLAS_FOUND OR VECLIB_FOUND OR FlexiBLAS_FOUND OR NVPL_BLAS_FOUND)) @@ -1473,8 +1472,6 @@ if(NOT INTERN_BUILD_MOBILE) set(AT_ROCM_ENABLED 1) endif() - set(AT_MKLDNN_ENABLED 0) - set(AT_MKLDNN_ACL_ENABLED 0) if(USE_MKLDNN) if(NOT CMAKE_SIZEOF_VOID_P EQUAL 8) message(WARNING diff --git a/third_party/ios-cmake b/third_party/ios-cmake deleted file mode 160000 index 8abaed637d56..000000000000 --- a/third_party/ios-cmake +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 8abaed637d56f1337d6e1d2c4026e25c1eade724 From c1dd3a615f4e15edc9e4264ff6f87267462c6d12 Mon Sep 17 00:00:00 2001 From: Sheng Fu Date: Sun, 2 Jun 2024 06:49:47 +0000 Subject: [PATCH 229/706] Implement Graph Transform Observer (#127427) Summary: Implement Graph Transform Observer Differential Revision: D57887518 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127427 Approved by: https://github.com/angelayi --- docs/source/fx.rst | 1 + test/fx/test_fx_xform_observer.py | 61 +++++++++++++++ torch/_inductor/config.py | 6 ++ torch/fx/passes/graph_transform_observer.py | 87 +++++++++++++++++++++ 4 files changed, 155 insertions(+) create mode 100644 test/fx/test_fx_xform_observer.py create mode 100644 torch/fx/passes/graph_transform_observer.py diff --git a/docs/source/fx.rst b/docs/source/fx.rst index e9b7cd2d5723..0a0af6254a5d 100644 --- a/docs/source/fx.rst +++ b/docs/source/fx.rst @@ -1175,6 +1175,7 @@ API Reference .. py:module:: torch.fx.passes.fake_tensor_prop .. py:module:: torch.fx.passes.graph_drawer .. py:module:: torch.fx.passes.graph_manipulation +.. py:module:: torch.fx.passes.graph_transform_observer .. py:module:: torch.fx.passes.infra.partitioner .. py:module:: torch.fx.passes.infra.pass_base .. py:module:: torch.fx.passes.infra.pass_manager diff --git a/test/fx/test_fx_xform_observer.py b/test/fx/test_fx_xform_observer.py new file mode 100644 index 000000000000..8e1a7c5ae2cd --- /dev/null +++ b/test/fx/test_fx_xform_observer.py @@ -0,0 +1,61 @@ +# Owner(s): ["module: fx"] + +import os +import tempfile + +import torch +from torch.fx import subgraph_rewriter, symbolic_trace +from torch.fx.passes.graph_transform_observer import GraphTransformObserver + +from torch.testing._internal.common_utils import TestCase + + +if __name__ == "__main__": + raise RuntimeError( + "This test file is not meant to be run directly, use:\n\n" + "\tpython test/test_fx.py TESTNAME\n\n" + "instead." + ) + + +class TestGraphTransformObserver(TestCase): + def test_graph_transform_observer(self): + class M(torch.nn.Module): + def forward(self, x): + val = torch.neg(x) + return torch.add(val, val) + + def pattern(x): + return torch.neg(x) + + def replacement(x): + return torch.relu(x) + + traced = symbolic_trace(M()) + + log_url = tempfile.mkdtemp() + + with GraphTransformObserver(traced, "replace_neg_with_relu", log_url) as ob: + subgraph_rewriter.replace_pattern(traced, pattern, replacement) + + self.assertTrue("relu" in ob.created_nodes) + self.assertTrue("neg" in ob.erased_nodes) + + current_pass_count = GraphTransformObserver.get_current_pass_count() + + self.assertTrue( + os.path.isfile( + os.path.join( + log_url, + f"pass_{current_pass_count}_replace_neg_with_relu_input_graph.svg", + ) + ) + ) + self.assertTrue( + os.path.isfile( + os.path.join( + log_url, + f"pass_{current_pass_count}_replace_neg_with_relu_output_graph.svg", + ) + ) + ) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 960a3567f8c5..0b4cb41d9b2b 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -880,6 +880,12 @@ class trace: # to workaround the above failure. dot_graph_shape = os.environ.get("INDUCTOR_DOT_GRAPH_SHAPE_SVG", None) + # If not None, this is the URL that saves the SVG files of the input/output + # graph of each pass that changed the graph + # The nodes that are being transformed in each pass will be colored in yellow + # URL only supports local directory for now + log_url_for_graph_xform = os.environ.get("INDUCTOR_LOG_URL_FOR_GRAPH_XFORM", None) + # Store cProfile (see snakeviz to view) compile_profile = False diff --git a/torch/fx/passes/graph_transform_observer.py b/torch/fx/passes/graph_transform_observer.py new file mode 100644 index 000000000000..a2ec324f512c --- /dev/null +++ b/torch/fx/passes/graph_transform_observer.py @@ -0,0 +1,87 @@ +import os +from typing import Optional + +from torch.fx.graph_module import GraphModule + +from .graph_drawer import FxGraphDrawer + +__all__ = ["GraphTransformObserver"] + + +class GraphTransformObserver: + __pass_count = 0 + + def __init__(self, gm: GraphModule, passname: str, log_url: Optional[str] = None): + # If log_url is None, we don't log anything + self.log_url = log_url + if self.log_url is None: + return + GraphTransformObserver.__pass_count += 1 + self.gm = gm + self.passname = passname + + self.input_dot_graph = FxGraphDrawer( + self.gm, + self.passname, + ignore_getattr=True, + ignore_parameters_and_buffers=True, + ).get_dot_graph() + + @classmethod + def get_current_pass_count(cls): + return cls.__pass_count + + def __enter__(self): + if self.log_url is None or self.gm is None: + return self + + self.erased_nodes = set() + self.created_nodes = set() + self.gm._register_create_node_hook(self.on_node_creation) + self.gm._register_erase_node_hook(self.on_node_erase) + + return self + + def __exit__(self, type, value, tb): + if self.log_url is None or self.gm is None: + return + + self.gm._unregister_create_node_hook(self.on_node_creation) + self.gm._unregister_erase_node_hook(self.on_node_erase) + + if len(self.created_nodes) > 0 or len(self.erased_nodes) > 0: + for e in self.input_dot_graph.get_node_list(): + if e.get_name() in self.erased_nodes: + e.obj_dict["attributes"]["fillcolor"] = "yellow" + else: + e.obj_dict["attributes"]["fillcolor"] = "grey" + self.input_dot_graph.write_svg( + os.path.join( + self.log_url, + f"pass_{GraphTransformObserver.__pass_count}_{self.passname}_input_graph.svg", + ) + ) + + output_dot_graph = FxGraphDrawer( + self.gm, + self.passname, + ignore_getattr=True, + ignore_parameters_and_buffers=True, + ).get_dot_graph() + for e in output_dot_graph.get_node_list(): + if e.get_name() in self.created_nodes: + e.obj_dict["attributes"]["fillcolor"] = "yellow" + else: + e.obj_dict["attributes"]["fillcolor"] = "grey" + output_dot_graph.write_svg( + os.path.join( + self.log_url, + f"pass_{GraphTransformObserver.__pass_count}_{self.passname}_output_graph.svg", + ) + ) + + def on_node_creation(self, node): + self.created_nodes.add(node.name) + + def on_node_erase(self, node): + self.erased_nodes.add(node.name) From 67ef2683d970fc541b6d266d4b3f8ba9d13844ca Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Sun, 2 Jun 2024 07:31:08 +0000 Subject: [PATCH 230/706] [BE] wrap deprecated function/class with `typing_extensions.deprecated` (#127689) Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing. Note that only warnings that their messages contain `[Dd]eprecat(ed|ion)` are updated in this PR. Resolves #126888 - #126888 This PR is split from PR #126898. - #126898 ------ Pull Request resolved: https://github.com/pytorch/pytorch/pull/127689 Approved by: https://github.com/Skylion007 --- .github/requirements/conda-env-Linux-X64.txt | 2 +- .github/requirements/conda-env-iOS.txt | 2 +- .github/requirements/conda-env-macOS-ARM64 | 2 +- .github/requirements/conda-env-macOS-X64 | 2 +- test/distributed/_tensor/test_api.py | 2 +- .../distributed/fsdp/test_fsdp_optim_state.py | 2 +- test/functorch/test_eager_transforms.py | 6 +- test/nn/test_init.py | 2 +- test/nn/test_module_hooks.py | 9 +-- test/test_autocast.py | 4 +- test/test_autograd.py | 10 +-- test/test_cuda.py | 8 +-- test/test_prims.py | 2 +- test/test_pytree.py | 2 +- test/test_stateless.py | 2 +- test/test_torch.py | 4 +- torch/__init__.py | 11 --- torch/_dynamo/eval_frame.py | 10 ++- torch/_functorch/deprecated.py | 26 +++---- torch/_functorch/pytree_hacks.py | 5 +- torch/_inductor/ir.py | 3 +- torch/_library/abstract_impl.py | 9 +-- torch/_prims_common/__init__.py | 13 ++-- torch/_vmap_internals.py | 10 +-- torch/ao/nn/quantizable/modules/activation.py | 13 +++- torch/ao/nn/quantized/dynamic/modules/rnn.py | 6 +- torch/ao/quantization/fx/convert.py | 21 ++++-- torch/ao/quantization/fx/fuse.py | 14 ++-- torch/ao/quantization/fx/prepare.py | 28 +++++--- torch/ao/quantization/qconfig.py | 22 ++++-- torch/ao/quantization/quantize_fx.py | 21 ++++-- torch/autograd/__init__.py | 18 +++-- torch/autograd/_functions/tensor.py | 11 +-- torch/autograd/function.py | 16 +++-- torch/autograd/gradcheck.py | 31 +++++---- torch/autograd/profiler.py | 5 +- torch/autograd/profiler_legacy.py | 13 +++- torch/autograd/profiler_util.py | 23 ++++++- torch/backends/cuda/__init__.py | 20 +++--- torch/cpu/amp/autocast_mode.py | 11 +-- torch/cpu/amp/grad_scaler.py | 10 +-- torch/cuda/_memory_viz.py | 10 ++- torch/cuda/amp/autocast_mode.py | 32 ++++++--- torch/cuda/amp/grad_scaler.py | 10 +-- torch/cuda/memory.py | 17 ++--- torch/cuda/nccl.py | 10 ++- torch/distributed/_composable/fully_shard.py | 20 +++--- torch/distributed/_functional_collectives.py | 4 +- .../distributed/_shard/checkpoint/__init__.py | 15 ++-- .../distributed/_shard/sharded_tensor/api.py | 15 ++-- torch/distributed/_sharded_tensor/__init__.py | 14 ++-- torch/distributed/_sharding_spec/__init__.py | 13 ++-- torch/distributed/_tensor/api.py | 4 ++ .../_checkpoint/checkpoint_wrapper.py | 3 +- .../checkpoint/state_dict_loader.py | 10 +-- .../checkpoint/state_dict_saver.py | 11 +-- torch/distributed/distributed_c10d.py | 66 ++++++++++-------- torch/distributed/elastic/metrics/api.py | 12 ++-- torch/distributed/fsdp/_init_utils.py | 3 +- .../fsdp/fully_sharded_data_parallel.py | 43 +++++++++--- torch/distributed/launch.py | 22 +++--- torch/distributed/optim/__init__.py | 14 +++- torch/distributed/pipeline/__init__.py | 18 +++-- torch/distributed/tensor/parallel/_utils.py | 6 +- torch/distributions/distribution.py | 8 ++- .../multipledispatch/dispatcher.py | 30 ++++---- torch/hub.py | 10 ++- torch/jit/_script.py | 5 +- torch/jit/_trace.py | 10 ++- torch/library.py | 12 ++-- torch/multiprocessing/spawn.py | 2 +- torch/nn/functional.py | 21 ++++-- torch/nn/init.py | 6 +- torch/nn/modules/activation.py | 12 +++- torch/nn/modules/container.py | 11 +-- torch/nn/modules/conv.py | 15 ++-- torch/nn/modules/loss.py | 12 ++-- torch/nn/modules/module.py | 69 ++++++++++++------- torch/nn/modules/rnn.py | 6 +- torch/nn/parallel/__init__.py | 11 ++- torch/nn/parallel/comm.py | 5 +- torch/nn/parallel/distributed.py | 4 +- torch/nn/parallel/scatter_gather.py | 8 ++- torch/nn/utils/clip_grad.py | 9 ++- torch/nn/utils/stateless.py | 14 ++-- torch/nn/utils/weight_norm.py | 9 ++- torch/profiler/profiler.py | 6 +- torch/sparse/semi_structured.py | 9 ++- torch/testing/_comparison.py | 16 ++--- torch/testing/_creation.py | 3 +- torch/utils/_config_module.py | 12 ++-- torch/utils/_contextlib.py | 12 ++-- torch/utils/_cxx_pytree.py | 12 ++-- torch/utils/_pytree.py | 28 +++++--- torch/utils/data/backward_compatibility.py | 11 ++- torch/utils/data/dataset.py | 10 +-- torch/utils/data/graph_settings.py | 10 +-- 97 files changed, 763 insertions(+), 458 deletions(-) diff --git a/.github/requirements/conda-env-Linux-X64.txt b/.github/requirements/conda-env-Linux-X64.txt index dc44eb39f69f..e0b7177e39c4 100644 --- a/.github/requirements/conda-env-Linux-X64.txt +++ b/.github/requirements/conda-env-Linux-X64.txt @@ -5,4 +5,4 @@ ninja=1.10.2 numpy=1.23.3 pyyaml=6.0 setuptools=68.2.2 -typing-extensions=4.3.0 +typing-extensions=4.9.0 diff --git a/.github/requirements/conda-env-iOS.txt b/.github/requirements/conda-env-iOS.txt index 3539a8a0ccf8..fe67c6cbc312 100644 --- a/.github/requirements/conda-env-iOS.txt +++ b/.github/requirements/conda-env-iOS.txt @@ -4,4 +4,4 @@ ninja=1.10.2 numpy=1.23.3 pyyaml=6.0 setuptools=68.2.2 -typing-extensions=4.3.0 +typing-extensions=4.9.0 diff --git a/.github/requirements/conda-env-macOS-ARM64 b/.github/requirements/conda-env-macOS-ARM64 index 951cda496403..26b034c7d6e2 100644 --- a/.github/requirements/conda-env-macOS-ARM64 +++ b/.github/requirements/conda-env-macOS-ARM64 @@ -2,7 +2,7 @@ numpy=1.22.3 pyyaml=6.0 setuptools=61.2.0 cmake=3.22.* -typing-extensions=4.3.0 +typing-extensions=4.9.0 dataclasses=0.8 pip=22.2.2 pillow=10.0.1 diff --git a/.github/requirements/conda-env-macOS-X64 b/.github/requirements/conda-env-macOS-X64 index 95be2a082397..35da8324689a 100644 --- a/.github/requirements/conda-env-macOS-X64 +++ b/.github/requirements/conda-env-macOS-X64 @@ -4,7 +4,7 @@ numpy=1.21.2 pyyaml=5.3 setuptools=46.0.0 cmake=3.22.* -typing-extensions=4.3.0 +typing-extensions=4.9.0 dataclasses=0.8 pip=22.2.2 pillow=10.0.1 diff --git a/test/distributed/_tensor/test_api.py b/test/distributed/_tensor/test_api.py index 196bd6407b26..21763b091a54 100644 --- a/test/distributed/_tensor/test_api.py +++ b/test/distributed/_tensor/test_api.py @@ -237,7 +237,7 @@ def output_fn(outputs, device_mesh): assert isinstance(outputs, DTensor) return outputs.to_local() - with self.assertWarnsRegex(UserWarning, "Deprecating"): + with self.assertWarnsRegex(FutureWarning, "Deprecating"): replica_module = distribute_module( module_to_replicate, device_mesh, diff --git a/test/distributed/fsdp/test_fsdp_optim_state.py b/test/distributed/fsdp/test_fsdp_optim_state.py index 672b71d5290f..29925c13f2d6 100644 --- a/test/distributed/fsdp/test_fsdp_optim_state.py +++ b/test/distributed/fsdp/test_fsdp_optim_state.py @@ -1436,7 +1436,7 @@ def should_check_method(method_name: str): def get_warning_context(): warning_regex = "`optim_input` argument is deprecated" return self.assertWarnsRegex( - expected_warning=UserWarning, expected_regex=warning_regex + expected_warning=FutureWarning, expected_regex=warning_regex ) self._run_on_all_optim_state_apis( diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py index 9aae6e5451a3..c767810beb85 100644 --- a/test/functorch/test_eager_transforms.py +++ b/test/functorch/test_eager_transforms.py @@ -3258,7 +3258,7 @@ def test_deprecation_vmap(self, device): x = torch.randn(3, device=device) # functorch version of the API is deprecated - with self.assertWarnsRegex(UserWarning, "Please use torch.vmap"): + with self.assertWarnsRegex(FutureWarning, "Please use `torch.vmap`"): vmap(torch.sin) # the non-functorch version is not deprecated @@ -3276,7 +3276,9 @@ def test_deprecation_transforms(self, device, transform): new_api = getattr(torch.func, transform) # functorch version of the API is deprecated - with self.assertWarnsRegex(UserWarning, f"Please use torch.func.{transform}"): + with self.assertWarnsRegex( + FutureWarning, f"Please use `torch.func.{transform}`" + ): api(torch.sin) # the non-functorch version is not deprecated diff --git a/test/nn/test_init.py b/test/nn/test_init.py index 8826fabc263b..9ae471414474 100644 --- a/test/nn/test_init.py +++ b/test/nn/test_init.py @@ -521,7 +521,7 @@ def fn(): init.normal(x) with self.assertWarnsRegex( - UserWarning, + FutureWarning, "deprecated", msg="methods not suffixed with underscore should be deprecated", ): diff --git a/test/nn/test_module_hooks.py b/test/nn/test_module_hooks.py index 8dbd255c6c53..f76837660302 100644 --- a/test/nn/test_module_hooks.py +++ b/test/nn/test_module_hooks.py @@ -1387,7 +1387,8 @@ def forward(self, l): m.register_backward_hook(noop) with self.assertWarnsRegex( - UserWarning, "does not take as input a single Tensor or a tuple of Tensors" + FutureWarning, + "does not take as input a single Tensor or a tuple of Tensors", ): m([a, b]) @@ -1400,7 +1401,7 @@ def forward(self, a, b): m.register_backward_hook(noop) with self.assertWarnsRegex( - UserWarning, "does not return a single Tensor or a tuple of Tensors" + FutureWarning, "does not return a single Tensor or a tuple of Tensors" ): m(a, b) @@ -1413,7 +1414,7 @@ def forward(self, a, b): m.register_backward_hook(noop) with self.assertWarnsRegex( - UserWarning, "outputs are generated by different autograd Nodes" + FutureWarning, "outputs are generated by different autograd Nodes" ): m(a, b) @@ -1426,7 +1427,7 @@ def forward(self, a): m.register_backward_hook(noop) with self.assertWarnsRegex( - UserWarning, "the forward contains multiple autograd Nodes" + FutureWarning, "the forward contains multiple autograd Nodes" ): m(a) diff --git a/test/test_autocast.py b/test/test_autocast.py index ce3d94318ccd..24f87944990d 100644 --- a/test/test_autocast.py +++ b/test/test_autocast.py @@ -255,8 +255,8 @@ def test_generic_autocast(self): def test_cpu_autocast_deprecated_warning(self): with self.assertWarnsRegex( - DeprecationWarning, - r"torch.cpu.amp.autocast\(args...\) is deprecated. Please use torch.amp.autocast\('cpu', args...\) instead.", + FutureWarning, + r"`torch.cpu.amp.autocast\(args...\)` is deprecated. Please use `torch.amp.autocast\('cpu', args...\)` instead.", ): with torch.cpu.amp.autocast(): _ = torch.ones(10) diff --git a/test/test_autograd.py b/test/test_autograd.py index 0dc9aca21041..911762024930 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -154,7 +154,7 @@ def hook(*args): def test_grad_mode_class_decoration(self): # Decorating class is deprecated and should not be used - with self.assertWarnsRegex(UserWarning, "Decorating classes is deprecated"): + with self.assertWarnsRegex(FutureWarning, "Decorating classes is deprecated"): @torch.no_grad() class Foo: @@ -5937,13 +5937,13 @@ def fn(inputs): b = torch.rand(2, 2, requires_grad=True, dtype=torch.float64) with self.assertWarnsRegex( - UserWarning, "get_numerical_jacobian was part of PyTorch's private API" + FutureWarning, "`get_numerical_jacobian` was part of PyTorch's private API" ): jacobian = get_numerical_jacobian(fn, (a, b), target=a, eps=1e-6) self.assertEqual(jacobian[0], 2 * torch.eye(4, dtype=torch.double)) with self.assertWarnsRegex( - UserWarning, "get_numerical_jacobian was part of PyTorch's private API" + FutureWarning, "`get_numerical_jacobian` was part of PyTorch's private API" ): jacobian = get_numerical_jacobian(fn, (a, b), eps=1e-6) self.assertEqual(jacobian[0], 2 * torch.eye(4, dtype=torch.double)) @@ -5963,7 +5963,7 @@ def fn(x, y): outputs = fn(a, b) with self.assertWarnsRegex( - UserWarning, "get_analytical_jacobian was part of PyTorch's private API" + FutureWarning, "`get_analytical_jacobian` was part of PyTorch's private API" ): ( jacobians, @@ -5991,7 +5991,7 @@ def backward(ctx, grad_out): outputs = NonDetFunc.apply(a, 1e-6) with self.assertWarnsRegex( - UserWarning, "get_analytical_jacobian was part of PyTorch's private API" + FutureWarning, "`get_analytical_jacobian` was part of PyTorch's private API" ): ( jacobians, diff --git a/test/test_cuda.py b/test/test_cuda.py index c919158e2c4e..785f0499df05 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -1820,10 +1820,10 @@ def backward(ctx, grad): return grad, grad self.assertRegex( - str(w[0].message), r"torch.cuda.amp.custom_fwd\(args...\) is deprecated." + str(w[0].message), r"`torch.cuda.amp.custom_fwd\(args...\)` is deprecated." ) self.assertRegex( - str(w[1].message), r"torch.cuda.amp.custom_bwd\(args...\) is deprecated." + str(w[1].message), r"`torch.cuda.amp.custom_bwd\(args...\)` is deprecated." ) mymm = MyMM.apply @@ -2016,8 +2016,8 @@ def test_autocast_checkpointing(self): def test_cuda_autocast_deprecated_warning(self): with self.assertWarnsRegex( - DeprecationWarning, - r"torch.cuda.amp.autocast\(args...\) is deprecated. Please use torch.amp.autocast\('cuda', args...\) instead.", + FutureWarning, + r"`torch.cuda.amp.autocast\(args...\)` is deprecated. Please use `torch.amp.autocast\('cuda', args...\)` instead.", ): with torch.cuda.amp.autocast(): _ = torch.ones(10) diff --git a/test/test_prims.py b/test/test_prims.py index 2a1f10cc8748..f1452acd7ab3 100644 --- a/test/test_prims.py +++ b/test/test_prims.py @@ -338,7 +338,7 @@ def test_mul_complex(self): prims.mul(torch.randn(2), 1 + 1j) def test_check_deprecation_warning(self): - with self.assertWarnsRegex(DeprecationWarning, 'will be removed in the future'): + with self.assertWarnsRegex(FutureWarning, 'will be removed in the future'): torch._prims_common.check(True, lambda: 'message') diff --git a/test/test_pytree.py b/test/test_pytree.py index caaf4d0b53bd..0a1c480a8fa7 100644 --- a/test/test_pytree.py +++ b/test/test_pytree.py @@ -723,7 +723,7 @@ def __init__(self, x, y): self.y = y with self.assertWarnsRegex( - UserWarning, "torch.utils._pytree._register_pytree_node" + FutureWarning, "torch.utils._pytree._register_pytree_node" ): py_pytree._register_pytree_node( DummyType, diff --git a/test/test_stateless.py b/test/test_stateless.py index 32ec45937059..6256f2b55bf8 100644 --- a/test/test_stateless.py +++ b/test/test_stateless.py @@ -901,7 +901,7 @@ def test_stateless_functional_call_warns(self): m = torch.nn.Linear(1, 1) params = dict(m.named_parameters()) x = torch.randn(3, 1) - with self.assertWarnsRegex(UserWarning, "Please use torch.func.functional_call"): + with self.assertWarnsRegex(FutureWarning, "Please use `torch.func.functional_call`"): stateless.functional_call(m, params, x) class TestPythonOptimizeMode(TestCase): diff --git a/test/test_torch.py b/test/test_torch.py index 05070abaf669..717943f43646 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -6198,8 +6198,8 @@ def test_grad_scaler_deprecated_warning(self, device): GradScaler = torch.cuda.amp.GradScaler if "cuda" == device.type else torch.cpu.amp.GradScaler with self.assertWarnsRegex( - UserWarning, - rf"torch.{device.type}.amp.GradScaler\(args...\) is deprecated.", + FutureWarning, + rf"`torch.{device.type}.amp.GradScaler\(args...\)` is deprecated.", ): _ = GradScaler(init_scale=2.0) diff --git a/torch/__init__.py b/torch/__init__.py index 440b833fd079..a2492c40a949 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -1996,17 +1996,6 @@ def _register_device_module(device_type, module): from torch.func import vmap -# The function _sparse_coo_tensor_unsafe is removed from PyTorch -# Python API (v. 1.13), here we temporarily provide its replacement -# with a deprecation warning. -# TODO: remove the function for PyTorch v 1.15. -def _sparse_coo_tensor_unsafe(*args, **kwargs): - import warnings - warnings.warn('torch._sparse_coo_tensor_unsafe is deprecated, ' - 'use torch.sparse_coo_tensor(..., check_invariants=False) instead.') - kwargs['check_invariants'] = False - return torch.sparse_coo_tensor(*args, **kwargs) - # Register MPS specific decomps torch.backends.mps._init() diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 1e164b2f7895..318fdd265085 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -798,7 +798,9 @@ def guard_export_print(guards): warnings.warn( "explain(f, *args, **kwargs) is deprecated, use explain(f)(*args, **kwargs) instead. " "If you don't migrate, we may break your explain call in the future if your user defined kwargs " - "conflict with future kwargs added to explain(f)." + "conflict with future kwargs added to explain(f).", + FutureWarning, + stacklevel=2, ) return inner(*extra_args, **extra_kwargs) else: @@ -941,7 +943,7 @@ def check_signature_rewritable(graph): tb = "".join(traceback.format_list(stack)) extra = "" if len(user_stacks) > 1: - extra = f"(elided {len(user_stacks)-1} more accesses)" + extra = f"(elided {len(user_stacks) - 1} more accesses)" msg = f"{source.name()}, accessed at:\n{tb}{extra}" # TODO: option to print ALL of the stack traces at once input_errors.append(msg) @@ -1476,7 +1478,9 @@ def graph_with_interpreter(*args): warnings.warn( "export(f, *args, **kwargs) is deprecated, use export(f)(*args, **kwargs) instead. " "If you don't migrate, we may break your export call in the future if your user defined kwargs " - "conflict with future kwargs added to export(f)." + "conflict with future kwargs added to export(f).", + FutureWarning, + stacklevel=2, ) return inner(*extra_args, **extra_kwargs) else: diff --git a/torch/_functorch/deprecated.py b/torch/_functorch/deprecated.py index 82a34f7d41c3..058e206599c5 100644 --- a/torch/_functorch/deprecated.py +++ b/torch/_functorch/deprecated.py @@ -1,3 +1,12 @@ +""" +The APIs in this file are exposed as `functorch.*`. They are thin wrappers +around the torch.func.* APIs that have deprecation warnings -- we're trying +to move people to the torch.func.* equivalents. + +NB: We don't use *args, **kwargs in the signatures because that changes the +documentation. +""" + import textwrap import warnings from typing import Any, Callable, Optional, Tuple, Union @@ -9,25 +18,16 @@ from torch._functorch.eager_transforms import argnums_t from torch._functorch.vmap import in_dims_t, out_dims_t -""" -The APIs in this file are exposed as `functorch.*`. They are thin wrappers -around the torch.func.* APIs that have deprecation warnings -- we're trying -to move people to the torch.func.* equivalents. - -NB: We don't use *args, **kwargs in the signatures because that changes the -documentation. -""" - def get_warning(api, new_api=None, replace_newlines=False): if new_api is None: new_api = f"torch.func.{api}" warning = ( f"We've integrated functorch into PyTorch. As the final step of the \n" - f"integration, functorch.{api} is deprecated as of PyTorch \n" + f"integration, `functorch.{api}` is deprecated as of PyTorch \n" f"2.0 and will be deleted in a future version of PyTorch >= 2.3. \n" - f"Please use {new_api} instead; see the PyTorch 2.0 release notes \n" - f"and/or the torch.func migration guide for more details \n" + f"Please use `{new_api}` instead; see the PyTorch 2.0 release notes \n" + f"and/or the `torch.func` migration guide for more details \n" f"https://pytorch.org/docs/main/func.migrating.html" ) if replace_newlines: @@ -37,7 +37,7 @@ def get_warning(api, new_api=None, replace_newlines=False): def warn_deprecated(api, new_api=None): warning = get_warning(api, new_api, replace_newlines=True) - warnings.warn(warning, stacklevel=2) + warnings.warn(warning, FutureWarning, stacklevel=3) def setup_docs(functorch_api, torch_func_api=None, new_api_name=None): diff --git a/torch/_functorch/pytree_hacks.py b/torch/_functorch/pytree_hacks.py index 8c4b50bc6ad4..96dea7ad1007 100644 --- a/torch/_functorch/pytree_hacks.py +++ b/torch/_functorch/pytree_hacks.py @@ -16,7 +16,8 @@ with warnings.catch_warnings(): warnings.simplefilter("always") warnings.warn( - "torch._functorch.pytree_hacks is deprecated and will be removed in a future release. " - "Please use torch.utils._pytree instead.", + "`torch._functorch.pytree_hacks` is deprecated and will be removed in a future release. " + "Please `use torch.utils._pytree` instead.", DeprecationWarning, + stacklevel=2, ) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 704a38e99e9a..c46cad5e41e2 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -2750,7 +2750,8 @@ def make_indexer(self): """A closure containing math to read a given element""" def indexer(index): - assert len(index) == len(self.stride) == len(self.size) + assert len(index) == len(self.stride) + assert len(index) == len(self.size) result = self.offset for idx, stride, sz in zip(index, self.stride, self.size): if sz != 1: diff --git a/torch/_library/abstract_impl.py b/torch/_library/abstract_impl.py index 14d6d8c46235..2946b743ee53 100644 --- a/torch/_library/abstract_impl.py +++ b/torch/_library/abstract_impl.py @@ -1,7 +1,7 @@ import contextlib import functools -import warnings from typing import Callable, Optional +from typing_extensions import deprecated import torch from torch._library.utils import Kernel, RegistrationHandle @@ -124,10 +124,11 @@ def __init__(self, _fake_mode, _op): self._shape_env = _fake_mode.shape_env self._op = _op + @deprecated( + "`create_unbacked_symint` is deprecated, please use `new_dynamic_size` instead", + category=FutureWarning, + ) def create_unbacked_symint(self, *, min=2, max=None) -> torch.SymInt: - warnings.warn( - "create_unbacked_symint is deprecated, please use new_dynamic_size instead" - ) return self.new_dynamic_size(min=min, max=max) def new_dynamic_size(self, *, min=0, max=None) -> torch.SymInt: diff --git a/torch/_prims_common/__init__.py b/torch/_prims_common/__init__.py index 68674af0a285..10290535f930 100644 --- a/torch/_prims_common/__init__.py +++ b/torch/_prims_common/__init__.py @@ -21,7 +21,7 @@ TYPE_CHECKING, Union, ) -from typing_extensions import TypeAlias +from typing_extensions import deprecated, TypeAlias if TYPE_CHECKING: @@ -1789,6 +1789,11 @@ def check_in_bounds_for_storage( # NOTE: This function should ideally be removed, but some Meta internal models # packaged with `torch.package` are using it, so it will have to be removed # at some point in the future when those models no longer use this function. +@deprecated( + "`torch._prims_common.check` is deprecated and will be removed in the future. " + "Please use `torch._check*` functions instead.", + category=FutureWarning, +) def check( b: bool, s: Callable[[], str], exc_type: Type[Exception] = RuntimeError ) -> None: @@ -1801,12 +1806,6 @@ def check( .. note:: This function is planned for removal in the future. Please use `torch._check*` functions instead. """ - warnings.warn( - DeprecationWarning( - "'torch._prims_common.check' will be removed in the future. Please use " - "'torch._check*' functions instead" - ) - ) torch._check_with(exc_type, b, s) diff --git a/torch/_vmap_internals.py b/torch/_vmap_internals.py index 8440abccb239..465e5dbdca1b 100644 --- a/torch/_vmap_internals.py +++ b/torch/_vmap_internals.py @@ -1,6 +1,6 @@ import functools -import warnings from typing import Any, Callable, List, Optional, Tuple, Union +from typing_extensions import deprecated import torch from torch import Tensor @@ -190,14 +190,14 @@ def _get_name(func: Callable): # vmap(func)(inputs) wraps all Tensor inputs to be batched in BatchedTensors, # sends those into func, and then unwraps the output BatchedTensors. Operations # on BatchedTensors perform the batched operations that the user is asking for. +@deprecated( + "Please use `torch.vmap` instead of `torch._vmap_internals.vmap`.", + category=FutureWarning, +) def vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Callable: """ Please use torch.vmap instead of this API. """ - warnings.warn( - "Please use torch.vmap instead of torch._vmap_internals.vmap. ", - stacklevel=2, - ) return _vmap(func, in_dims, out_dims) diff --git a/torch/ao/nn/quantizable/modules/activation.py b/torch/ao/nn/quantizable/modules/activation.py index 56be29a09d62..2c1aad574158 100644 --- a/torch/ao/nn/quantizable/modules/activation.py +++ b/torch/ao/nn/quantizable/modules/activation.py @@ -224,7 +224,6 @@ def dequantize(self): return fp - @classmethod def from_observed(cls, other): # The whole flow is float -> observed -> quantized @@ -336,7 +335,11 @@ def _forward_impl(self, if attn_mask is not None: if attn_mask.dtype == torch.uint8: - warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.") + warnings.warn( + "Byte tensor for `attn_mask` in `nn.MultiheadAttention` is deprecated. " + "Use bool tensor instead.", + stacklevel=3, + ) attn_mask = attn_mask.to(torch.bool) assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \ f'Only float and bool types are supported for attn_mask, not {attn_mask.dtype}' @@ -354,7 +357,11 @@ def _forward_impl(self, # convert ByteTensor key_padding_mask to bool if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: - warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.") + warnings.warn( + "Byte tensor for `key_padding_mask` in `nn.MultiheadAttention` is deprecated. " + "Use bool tensor instead.", + stacklevel=3, + ) key_padding_mask = key_padding_mask.to(torch.bool) if self.bias_k is not None and self.bias_v is not None: if static_k is None and static_v is None: diff --git a/torch/ao/nn/quantized/dynamic/modules/rnn.py b/torch/ao/nn/quantized/dynamic/modules/rnn.py index 1cef66060719..c81771a71889 100644 --- a/torch/ao/nn/quantized/dynamic/modules/rnn.py +++ b/torch/ao/nn/quantized/dynamic/modules/rnn.py @@ -1,5 +1,6 @@ import numbers import warnings +from typing_extensions import deprecated import torch import torch.nn as nn @@ -16,8 +17,11 @@ def _apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Ten return tensor.index_select(dim, permutation) +@deprecated( + "`apply_permutation` is deprecated, please use `tensor.index_select(dim, permutation)` instead", + category=FutureWarning, +) def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor: - warnings.warn("apply_permutation is deprecated, please use tensor.index_select(dim, permutation) instead") return _apply_permutation(tensor, permutation, dim) diff --git a/torch/ao/quantization/fx/convert.py b/torch/ao/quantization/fx/convert.py index ef90f8b71ece..6ca622cc4171 100644 --- a/torch/ao/quantization/fx/convert.py +++ b/torch/ao/quantization/fx/convert.py @@ -949,24 +949,33 @@ def convert( if convert_custom_config is None: convert_custom_config = ConvertCustomConfig() - if isinstance(convert_custom_config, Dict): + if isinstance(convert_custom_config, dict): warnings.warn( "Passing a convert_custom_config_dict to convert is deprecated and will not be supported " - "in a future version. Please pass in a ConvertCustomConfig instead.") + "in a future version. Please pass in a ConvertCustomConfig instead.", + FutureWarning, + stacklevel=2, + ) convert_custom_config = ConvertCustomConfig.from_dict(convert_custom_config) - if isinstance(qconfig_mapping, Dict): + if isinstance(qconfig_mapping, dict): warnings.warn( "Passing a QConfig dictionary to convert is deprecated and will not be supported " - "in a future version. Please pass in a QConfigMapping instead.") + "in a future version. Please pass in a QConfigMapping instead.", + FutureWarning, + stacklevel=2, + ) qconfig_mapping = QConfigMapping.from_dict(qconfig_mapping) if qconfig_mapping else None qconfig_mapping = copy.deepcopy(qconfig_mapping) assert qconfig_mapping is None or isinstance(qconfig_mapping, QConfigMapping) - if isinstance(backend_config, Dict): + if isinstance(backend_config, dict): warnings.warn( "Passing a backend_config_dict to prepare is deprecated and will not be supported " - "in a future version. Please pass in a BackendConfig instead.") + "in a future version. Please pass in a BackendConfig instead.", + FutureWarning, + stacklevel=2, + ) backend_config = BackendConfig.from_dict(backend_config) if backend_config is None: diff --git a/torch/ao/quantization/fx/fuse.py b/torch/ao/quantization/fx/fuse.py index 91b876997d10..6b2b614728f8 100644 --- a/torch/ao/quantization/fx/fuse.py +++ b/torch/ao/quantization/fx/fuse.py @@ -52,16 +52,22 @@ def fuse( if fuse_custom_config is None: fuse_custom_config = FuseCustomConfig() - if isinstance(fuse_custom_config, Dict): + if isinstance(fuse_custom_config, dict): warnings.warn( "Passing a fuse_custom_config_dict to fuse is deprecated and will not be supported " - "in a future version. Please pass in a FuseCustomConfig instead.") + "in a future version. Please pass in a FuseCustomConfig instead.", + FutureWarning, + stacklevel=2, + ) fuse_custom_config = FuseCustomConfig.from_dict(fuse_custom_config) - if isinstance(backend_config, Dict): + if isinstance(backend_config, dict): warnings.warn( "Passing a backend_config_dict to prepare is deprecated and will not be supported " - "in a future version. Please pass in a BackendConfig instead.") + "in a future version. Please pass in a BackendConfig instead.", + FutureWarning, + stacklevel=2, + ) backend_config = BackendConfig.from_dict(backend_config) named_modules = dict(model.named_modules()) diff --git a/torch/ao/quantization/fx/prepare.py b/torch/ao/quantization/fx/prepare.py index 9ca91ecb4930..d8e25f1260f5 100644 --- a/torch/ao/quantization/fx/prepare.py +++ b/torch/ao/quantization/fx/prepare.py @@ -1749,28 +1749,40 @@ def prepare( if _equalization_config is None: _equalization_config = QConfigMapping() - if isinstance(qconfig_mapping, Dict): + if isinstance(qconfig_mapping, dict): warnings.warn( "Passing a QConfig dictionary to prepare is deprecated and will not be supported " - "in a future version. Please pass in a QConfigMapping instead.") + "in a future version. Please pass in a QConfigMapping instead.", + FutureWarning, + stacklevel=2, + ) qconfig_mapping = QConfigMapping.from_dict(qconfig_mapping) - if isinstance(_equalization_config, Dict): + if isinstance(_equalization_config, dict): warnings.warn( "Passing a QConfig dictionary to prepare for equalization is deprecated and will not " - "be supported in a future version. Please pass in a QConfigMapping instead.") + "be supported in a future version. Please pass in a QConfigMapping instead.", + FutureWarning, + stacklevel=2, + ) _equalization_config = QConfigMapping.from_dict(_equalization_config) - if isinstance(prepare_custom_config, Dict): + if isinstance(prepare_custom_config, dict): warnings.warn( "Passing a prepare_custom_config_dict to prepare is deprecated and will not be supported " - "in a future version. Please pass in a PrepareCustomConfig instead.") + "in a future version. Please pass in a PrepareCustomConfig instead.", + FutureWarning, + stacklevel=2, + ) prepare_custom_config = PrepareCustomConfig.from_dict(prepare_custom_config) - if isinstance(backend_config, Dict): + if isinstance(backend_config, dict): warnings.warn( "Passing a backend_config_dict to prepare is deprecated and will not be supported " - "in a future version. Please pass in a BackendConfig instead.") + "in a future version. Please pass in a BackendConfig instead.", + FutureWarning, + stacklevel=2, + ) backend_config = BackendConfig.from_dict(backend_config) assert isinstance(qconfig_mapping, QConfigMapping) diff --git a/torch/ao/quantization/qconfig.py b/torch/ao/quantization/qconfig.py index dc8353d61729..88e7b47aff2b 100644 --- a/torch/ao/quantization/qconfig.py +++ b/torch/ao/quantization/qconfig.py @@ -1,5 +1,6 @@ from collections import namedtuple from typing import Optional, Any, Union, Type +from typing_extensions import deprecated import torch import torch.nn as nn @@ -106,6 +107,10 @@ def __new__(cls, activation, weight): return super().__new__(cls, activation, weight) +@deprecated( + "`QConfigDynamic` is going to be deprecated in PyTorch 1.12, please use `QConfig` instead", + category=FutureWarning, +) class QConfigDynamic(namedtuple('QConfigDynamic', ['activation', 'weight'])): """ Describes how to dynamically quantize a layer or a part of the network by providing @@ -127,7 +132,6 @@ def __new__(cls, activation=torch.nn.Identity, weight=torch.nn.Identity): if isinstance(weight, nn.Module): raise ValueError("QConfigDynamic received observer instance, please pass observer class instead. " + "Use MyObserver.with_args(x=1) to override arguments to constructor if needed") - warnings.warn("QConfigDynamic is going to be deprecated in PyTorch 1.12, please use QConfig instead") return super().__new__(cls, activation, weight) @@ -422,16 +426,20 @@ def get_default_qat_qconfig(backend='x86', version=1): weight=None, ) +@deprecated( + "`torch.ao.quantization.get_default_qconfig_dict` is deprecated and will be removed in " + "a future version. Please use `torch.ao.quantization.get_default_qconfig_mapping` instead.", + category=FutureWarning, +) def get_default_qconfig_dict(backend='x86', version=0): - warnings.warn( - "torch.ao.quantization.get_default_qconfig_dict is deprecated and will be removed in " - "a future version. Please use torch.ao.quantization.get_default_qconfig_mapping instead.") return torch.ao.quantization.get_default_qconfig_mapping(backend, version).to_dict() +@deprecated( + "`torch.ao.quantization.get_default_qat_qconfig_dict` is deprecated and will be removed in " + "a future version. Please use `torch.ao.quantization.get_default_qat_qconfig_mapping` instead.", + category=FutureWarning, +) def get_default_qat_qconfig_dict(backend='x86', version=1): - warnings.warn( - "torch.ao.quantization.get_default_qat_qconfig_dict is deprecated and will be removed in " - "a future version. Please use torch.ao.quantization.get_default_qat_qconfig_mapping instead.") return torch.ao.quantization.get_default_qat_qconfig_mapping(backend, version).to_dict() def _assert_valid_qconfig(qconfig: Optional[QConfig], diff --git a/torch/ao/quantization/quantize_fx.py b/torch/ao/quantization/quantize_fx.py index c9a3db87552a..5767a525342e 100644 --- a/torch/ao/quantization/quantize_fx.py +++ b/torch/ao/quantization/quantize_fx.py @@ -117,10 +117,13 @@ def _prepare_fx( if _equalization_config is None: _equalization_config = QConfigMapping() - if isinstance(prepare_custom_config, Dict): + if isinstance(prepare_custom_config, dict): warnings.warn( "Passing a prepare_custom_config_dict to prepare is deprecated and will not be supported " - "in a future version. Please pass in a PrepareCustomConfig instead.") + "in a future version. Please pass in a PrepareCustomConfig instead.", + FutureWarning, + stacklevel=3, + ) prepare_custom_config = PrepareCustomConfig.from_dict(prepare_custom_config) # swap FloatFunctional with FXFloatFunctional @@ -222,10 +225,13 @@ def fuse_fx( if fuse_custom_config is None: fuse_custom_config = FuseCustomConfig() - if isinstance(fuse_custom_config, Dict): + if isinstance(fuse_custom_config, dict): warnings.warn( "Passing a fuse_custom_config_dict to fuse is deprecated and will not be supported " - "in a future version. Please pass in a FuseCustomConfig instead.") + "in a future version. Please pass in a FuseCustomConfig instead.", + FutureWarning, + stacklevel=2, + ) fuse_custom_config = FuseCustomConfig.from_dict(fuse_custom_config) torch._C._log_api_usage_once("quantization_api.quantize_fx.fuse_fx") @@ -511,10 +517,13 @@ def _convert_fx( if convert_custom_config is None: convert_custom_config = ConvertCustomConfig() - if isinstance(convert_custom_config, Dict): + if isinstance(convert_custom_config, dict): warnings.warn( "Passing a convert_custom_config_dict to convert is deprecated and will not be supported " - "in a future version. Please pass in a ConvertCustomConfig instead.") + "in a future version. Please pass in a ConvertCustomConfig instead.", + FutureWarning, + stacklevel=3, + ) convert_custom_config = ConvertCustomConfig.from_dict(convert_custom_config) _check_is_graph_module(graph_module) diff --git a/torch/autograd/__init__.py b/torch/autograd/__init__.py index 9b5788aff227..adf47ad1727d 100644 --- a/torch/autograd/__init__.py +++ b/torch/autograd/__init__.py @@ -252,17 +252,21 @@ def backward( ) if grad_variables is not None: - warnings.warn("'grad_variables' is deprecated. Use 'grad_tensors' instead.") + warnings.warn( + "`grad_variables` is deprecated. Use `grad_tensors` instead.", + FutureWarning, + stacklevel=2, + ) if grad_tensors is None: grad_tensors = grad_variables else: raise RuntimeError( - "'grad_tensors' and 'grad_variables' (deprecated) " - "arguments both passed to backward(). Please only " - "use 'grad_tensors'." + "`grad_tensors` and `grad_variables` (deprecated) " + "arguments both passed to `backward()`. Please only " + "use `grad_tensors`." ) if inputs is not None and len(inputs) == 0: - raise RuntimeError("'inputs' argument to backward() cannot be empty.") + raise RuntimeError("`inputs` argument to `backward()` cannot be empty.") tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tuple(tensors) inputs = ( @@ -395,7 +399,9 @@ def grad( warnings.warn( "only_inputs argument is deprecated and is ignored now " "(defaults to True). To accumulate gradient for other " - "parts of the graph, please use torch.autograd.backward." + "parts of the graph, please use torch.autograd.backward.", + FutureWarning, + stacklevel=2, ) grad_outputs_ = _tensor_or_tensors_to_tuple(grad_outputs, len(t_outputs)) diff --git a/torch/autograd/_functions/tensor.py b/torch/autograd/_functions/tensor.py index f091d38777fc..d2b3149bfc81 100644 --- a/torch/autograd/_functions/tensor.py +++ b/torch/autograd/_functions/tensor.py @@ -1,6 +1,6 @@ import operator -import warnings from functools import reduce +from typing_extensions import deprecated import torch import torch._utils @@ -9,11 +9,12 @@ class Type(Function): @staticmethod + @deprecated( + "`torch.autograd._functions.Type` is deprecated as of PyTorch 2.1, " + "please use `torch.tensor.to(dtype=dtype)` instead.", + category=FutureWarning, + ) def forward(ctx, i, dest_type): - warnings.warn( - "torch.autograd._functions.Type is deprecated as of PyTorch 2.1, please use " - "torch.tensor.to(dtype=dtype) instead." - ) ctx.input_type = type(i) ctx.input_device = -1 if not i.is_cuda else i.get_device() return i.type(dest_type) diff --git a/torch/autograd/function.py b/torch/autograd/function.py index 9c624ce5d14b..9aca2b2a1b32 100644 --- a/torch/autograd/function.py +++ b/torch/autograd/function.py @@ -4,6 +4,7 @@ import warnings from collections import OrderedDict from typing import Any, List, Optional, Tuple +from typing_extensions import deprecated import torch import torch._C as _C @@ -179,12 +180,14 @@ def mark_dirty(self, *args: torch.Tensor): """ self.dirty_tensors = args + @deprecated( + "`mark_shared_storage` is deprecated. " + "Tensors with shared storages are automatically tracked. " + "Note that calls to `set_()` are not tracked", + category=FutureWarning, + ) def mark_shared_storage(self, *pairs): - warnings.warn( - "mark_shared_storage is deprecated. " - "Tensors with shared storages are automatically tracked. Note " - "that calls to `set_()` are not tracked" - ) + pass def mark_non_differentiable(self, *args: torch.Tensor): r"""Mark outputs as non-differentiable. @@ -491,9 +494,8 @@ class Function(_SingleLevelFunction): """ def __init__(self, *args, **kwargs): - cls = self.__class__ warnings.warn( - f"{cls} should not be instantiated. Methods on autograd functions" + f"{self.__class__} should not be instantiated. Methods on autograd functions" "are all static, so you should invoke them on the class itself. " "Instantiating an autograd function will raise an " "error in a future version of PyTorch.", diff --git a/torch/autograd/gradcheck.py b/torch/autograd/gradcheck.py index f2e6aa22fe94..a0d874038761 100644 --- a/torch/autograd/gradcheck.py +++ b/torch/autograd/gradcheck.py @@ -3,6 +3,7 @@ import warnings from itertools import product from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing_extensions import deprecated import torch import torch.testing @@ -306,6 +307,14 @@ def _get_numerical_jacobian( return jacobians +@deprecated( + "`get_numerical_jacobian` was part of PyTorch's private API and not " + "meant to be exposed. We are deprecating it and it will be removed " + "in a future version of PyTorch. If you have a specific use for " + "this or feature request for this to be a stable API, please file " + "us an issue at https://github.com/pytorch/pytorch/issues/new", + category=FutureWarning, +) def get_numerical_jacobian(fn, inputs, target=None, eps=1e-3, grad_out=1.0): """Compute the numerical Jacobian for a given fn and its inputs. @@ -325,13 +334,6 @@ def get_numerical_jacobian(fn, inputs, target=None, eps=1e-3, grad_out=1.0): Note that `target` may not even be part of `input` to `fn`, so please be **very careful** in this to not clone `target`. """ - warnings.warn( - "get_numerical_jacobian was part of PyTorch's private API and not " - "meant to be exposed. We are deprecating it and it will be removed " - "in a future version of PyTorch. If you have a specific use for " - "this or feature request for this to be a stable API, please file " - "us an issue at https://github.com/pytorch/pytorch/issues/new" - ) if ( grad_out != 1.0 ): # grad_out param is only kept for backward compatibility reasons @@ -818,16 +820,17 @@ def _get_analytical_vJu_backward_mode( return reduced_jacobians +@deprecated( + "`get_analytical_jacobian` was part of PyTorch's private API and not " + "meant to be exposed. We are deprecating it and it will be removed " + "in a future version of PyTorch. If you have a specific use for " + "this or feature request for this to be a stable API, please file " + "us an issue at https://github.com/pytorch/pytorch/issues/new", + category=FutureWarning, +) def get_analytical_jacobian(inputs, output, nondet_tol=0.0, grad_out=1.0): # Replicates the behavior of the old get_analytical_jacobian before the refactor # This shares much of its code with _check_analytical_jacobian_attributes - warnings.warn( - "get_analytical_jacobian was part of PyTorch's private API and not " - "meant to be exposed. We are deprecating it and it will be removed " - "in a future version of PyTorch. If you have a specific use for " - "this or feature request for this to be a stable API, please file " - "us an issue at https://github.com/pytorch/pytorch/issues/new" - ) if ( grad_out != 1.0 ): # grad_out param is only kept for backward compatibility reasons diff --git a/torch/autograd/profiler.py b/torch/autograd/profiler.py index 5da75b608a82..162dfe1eeaef 100644 --- a/torch/autograd/profiler.py +++ b/torch/autograd/profiler.py @@ -213,7 +213,10 @@ def __init__( self.use_cuda = use_cuda if self.use_cuda: warn( - "The attribute `use_cuda` will be deprecated soon, please use ``use_device = 'cuda'`` instead." + "The attribute `use_cuda` will be deprecated soon, " + "please use ``use_device = 'cuda'`` instead.", + FutureWarning, + stacklevel=2, ) self.use_device: Optional[str] = "cuda" else: diff --git a/torch/autograd/profiler_legacy.py b/torch/autograd/profiler_legacy.py index f72d366a3677..cb573faf4410 100644 --- a/torch/autograd/profiler_legacy.py +++ b/torch/autograd/profiler_legacy.py @@ -1,5 +1,6 @@ import itertools -from warnings import warn +import warnings +from typing_extensions import deprecated import torch import torch.cuda @@ -23,6 +24,11 @@ __all__ = ["profile"] +@deprecated( + "`torch.autograd.profiler_legacy.profile` is deprecated and will be removed in a future release. " + "Please use `torch.profiler` instead.", + category=None, # TODO: change to `FutureWarning` +) class profile: """DEPRECATED: use torch.profiler instead.""" @@ -51,7 +57,10 @@ def __init__( self.with_modules = with_modules if self.use_cuda and not torch.cuda.is_available(): - warn("CUDA is not available, disabling CUDA profiling") + warnings.warn( + "CUDA is not available, disabling CUDA profiling", + stacklevel=2, + ) self.use_cuda = False if self.use_cuda: diff --git a/torch/autograd/profiler_util.py b/torch/autograd/profiler_util.py index 4833f989b82a..23243733aaa8 100644 --- a/torch/autograd/profiler_util.py +++ b/torch/autograd/profiler_util.py @@ -6,6 +6,7 @@ from operator import attrgetter from typing import Any, Dict, List, Optional, Tuple +from typing_extensions import deprecated import torch from torch.autograd import DeviceType @@ -415,6 +416,10 @@ def device_time(self): return 0.0 if self.count == 0 else 1.0 * self.device_time_total / self.count # type: ignore[attr-defined] @property + @deprecated( + "`cuda_time` is deprecated, please use `device_time` instead.", + category=FutureWarning, + ) def cuda_time(self): # To be deprecated return self.device_time @@ -538,8 +543,12 @@ def self_device_memory_usage(self): ) @property + @deprecated( + "`self_cuda_memory_usage` is deprecated. Use `self_device_memory_usage` instead.", + category=FutureWarning, + ) def self_cuda_memory_usage(self): # To be deprecated - self.self_device_memory_usage + return self.self_device_memory_usage @property def cpu_time_total(self): @@ -574,8 +583,12 @@ def device_time_total(self): return self.time_range.elapsed_us() @property + @deprecated( + "`cuda_time_total` is deprecated. Use `device_time_total` instead.", + category=FutureWarning, + ) def cuda_time_total(self): # To be deprecated - self.device_time_total + return self.device_time_total @property def self_device_time_total(self): @@ -590,8 +603,12 @@ def self_device_time_total(self): return self.device_time_total @property + @deprecated( + "`self_cuda_time_total` is deprecated. Use `self_device_time_total` instead.", + category=FutureWarning, + ) def self_cuda_time_total(self): # To be deprecated - self.self_device_time_total + return self.self_device_time_total @property def key(self): diff --git a/torch/backends/cuda/__init__.py b/torch/backends/cuda/__init__.py index f1b68a446225..c35a962ba693 100644 --- a/torch/backends/cuda/__init__.py +++ b/torch/backends/cuda/__init__.py @@ -1,7 +1,7 @@ import contextlib -import warnings from typing import Union +from typing_extensions import deprecated import torch @@ -377,6 +377,15 @@ def enable_cudnn_sdp(enabled: bool): @contextlib.contextmanager +@deprecated( + ( + "`torch.backends.cuda.sdp_kernel()` is deprecated. " + "In the future, this context manager will be removed. " + "Please see `torch.nn.attention.sdpa_kernel()` for the new context manager, " + "with updated signature." + ), + category=FutureWarning, +) def sdp_kernel( enable_flash: bool = True, enable_math: bool = True, @@ -389,15 +398,6 @@ def sdp_kernel( This context manager can be used to temporarily enable or disable any of the three backends for scaled dot product attention. Upon exiting the context manager, the previous state of the flags will be restored. """ - warnings.warn( - ( - "torch.backends.cuda.sdp_kernel() " - "is deprecated. In the future, this context manager will be removed. " - "Please see, torch.nn.attention.sdpa_kernel() for the new context manager, with updated " - "signature." - ), - FutureWarning, - ) from torch.nn.attention import sdpa_kernel backend_list = [] diff --git a/torch/cpu/amp/autocast_mode.py b/torch/cpu/amp/autocast_mode.py index 3f0a574f7d38..b545e91dd6f4 100644 --- a/torch/cpu/amp/autocast_mode.py +++ b/torch/cpu/amp/autocast_mode.py @@ -1,5 +1,5 @@ -import warnings from typing import Any +from typing_extensions import deprecated import torch @@ -12,6 +12,11 @@ class autocast(torch.amp.autocast_mode.autocast): ``torch.cpu.amp.autocast(args...)`` is deprecated. Please use ``torch.amp.autocast("cpu", args...)`` instead. """ + @deprecated( + "`torch.cpu.amp.autocast(args...)` is deprecated. " + "Please use `torch.amp.autocast('cpu', args...)` instead.", + category=FutureWarning, + ) def __init__( self, enabled: bool = True, @@ -23,10 +28,6 @@ def __init__( self.device = "cpu" self.fast_dtype = dtype return - warnings.warn( - "torch.cpu.amp.autocast(args...) is deprecated. Please use torch.amp.autocast('cpu', args...) instead.", - DeprecationWarning, - ) super().__init__( "cpu", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled ) diff --git a/torch/cpu/amp/grad_scaler.py b/torch/cpu/amp/grad_scaler.py index 2c93e0100f16..72b893a02a49 100644 --- a/torch/cpu/amp/grad_scaler.py +++ b/torch/cpu/amp/grad_scaler.py @@ -1,4 +1,4 @@ -import warnings +from typing_extensions import deprecated import torch @@ -11,6 +11,11 @@ class GradScaler(torch.amp.GradScaler): ``torch.cpu.amp.GradScaler(args...)`` is deprecated. Please use ``torch.amp.GradScaler("cpu", args...)`` instead. """ + @deprecated( + "`torch.cpu.amp.GradScaler(args...)` is deprecated. " + "Please use `torch.amp.GradScaler('cpu', args...)` instead.", + category=FutureWarning, + ) def __init__( self, init_scale: float = 2.0**16, @@ -19,9 +24,6 @@ def __init__( growth_interval: int = 2000, enabled: bool = True, ) -> None: - warnings.warn( - "torch.cpu.amp.GradScaler(args...) is deprecated. Please use torch.amp.GradScaler('cpu', args...) instead." - ) super().__init__( "cpu", init_scale=init_scale, diff --git a/torch/cuda/_memory_viz.py b/torch/cuda/_memory_viz.py index 587d7e9c7c5e..7d211fd3b8cb 100644 --- a/torch/cuda/_memory_viz.py +++ b/torch/cuda/_memory_viz.py @@ -145,8 +145,8 @@ def _seg_info(seg): before_segs = {_seg_key(seg) for seg in before} after_segs = {_seg_key(seg) for seg in after} - print(f'only_before = {[a for a,_ in (before_segs - after_segs)]}') - print(f'only_after = {[a for a,_ in (after_segs - before_segs)]}') + print(f'only_before = {[a for a, _ in (before_segs - after_segs)]}') + print(f'only_after = {[a for a, _ in (after_segs - before_segs)]}') for seg in before: if _seg_key(seg) not in after_segs: @@ -382,7 +382,11 @@ def find_segment(addr): def _format_viz(data, viz_kind, device): if device is not None: - warnings.warn('device argument is deprecated, plots now contain all device') + warnings.warn( + 'device argument is deprecated, plots now contain all device', + FutureWarning, + stacklevel=3, + ) buffer = pickle.dumps(data) buffer += b'\x00' * (3 - len(buffer) % 3) # Encode the buffer with base64 diff --git a/torch/cuda/amp/autocast_mode.py b/torch/cuda/amp/autocast_mode.py index 79d7c3cc1344..eb17d7a75e69 100644 --- a/torch/cuda/amp/autocast_mode.py +++ b/torch/cuda/amp/autocast_mode.py @@ -1,6 +1,6 @@ import functools -import warnings from typing import Any +from typing_extensions import deprecated import torch @@ -13,6 +13,11 @@ class autocast(torch.amp.autocast_mode.autocast): ``torch.cuda.amp.autocast(args...)`` is deprecated. Please use ``torch.amp.autocast("cuda", args...)`` instead. """ + @deprecated( + "`torch.cuda.amp.autocast(args...)` is deprecated. " + "Please use `torch.amp.autocast('cuda', args...)` instead.", + category=FutureWarning, + ) def __init__( self, enabled: bool = True, @@ -24,10 +29,6 @@ def __init__( self.device = "cuda" self.fast_dtype = dtype return - warnings.warn( - "torch.cuda.amp.autocast(args...) is deprecated. Please use torch.amp.autocast('cuda', args...) instead.", - DeprecationWarning, - ) super().__init__( "cuda", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled ) @@ -50,29 +51,38 @@ def __call__(self, func): # Preserved only for BC reasons +@deprecated( + "`torch.cuda.amp.autocast_mode._cast(value, dtype)` is deprecated. " + "Please use `torch.amp.autocast_mode._cast(value, 'cuda', dtype)` instead.", + category=FutureWarning, +) def _cast(value, dtype): return torch.amp.autocast_mode._cast(value, "cuda", dtype) +@deprecated( + "`torch.cuda.amp.custom_fwd(args...)` is deprecated. " + "Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.", + category=FutureWarning, +) def custom_fwd(fwd=None, *, cast_inputs=None): """ ``torch.cuda.amp.custom_fwd(args...)`` is deprecated. Please use ``torch.amp.custom_fwd(args..., device_type='cuda')`` instead. """ - warnings.warn( - "torch.cuda.amp.custom_fwd(args...) is deprecated. Please use torch.amp.custom_fwd(args..., device_type='cuda') instead." - ) return functools.partial(torch.amp.custom_fwd, device_type="cuda")( fwd=fwd, cast_inputs=cast_inputs ) +@deprecated( + "`torch.cuda.amp.custom_bwd(args...)` is deprecated. " + "Please use `torch.amp.custom_bwd(args..., device_type='cuda')` instead.", + category=FutureWarning, +) def custom_bwd(bwd): """ ``torch.cuda.amp.custom_bwd(args...)`` is deprecated. Please use ``torch.amp.custom_bwd(args..., device_type='cuda')`` instead. """ - warnings.warn( - "torch.cuda.amp.custom_bwd(args...) is deprecated. Please use torch.amp.custom_bwd(args..., device_type='cuda') instead." - ) return functools.partial(torch.amp.custom_bwd, device_type="cuda")(bwd) diff --git a/torch/cuda/amp/grad_scaler.py b/torch/cuda/amp/grad_scaler.py index 8263fcdb480d..367f21594f1c 100644 --- a/torch/cuda/amp/grad_scaler.py +++ b/torch/cuda/amp/grad_scaler.py @@ -1,4 +1,4 @@ -import warnings +from typing_extensions import deprecated import torch @@ -11,6 +11,11 @@ class GradScaler(torch.amp.GradScaler): ``torch.cuda.amp.GradScaler(args...)`` is deprecated. Please use ``torch.amp.GradScaler("cuda", args...)`` instead. """ + @deprecated( + "`torch.cuda.amp.GradScaler(args...)` is deprecated. " + "Please use `torch.amp.GradScaler('cuda', args...)` instead.", + category=FutureWarning, + ) def __init__( self, init_scale: float = 2.0**16, @@ -19,9 +24,6 @@ def __init__( growth_interval: int = 2000, enabled: bool = True, ) -> None: - warnings.warn( - "torch.cuda.amp.GradScaler(args...) is deprecated. Please use torch.amp.GradScaler('cuda', args...) instead." - ) super().__init__( "cuda", init_scale=init_scale, diff --git a/torch/cuda/memory.py b/torch/cuda/memory.py index a593a3810834..0f12395ac778 100644 --- a/torch/cuda/memory.py +++ b/torch/cuda/memory.py @@ -9,6 +9,7 @@ from inspect import signature from typing import Any, Dict, Optional, Tuple, Union +from typing_extensions import deprecated import torch from torch import _C @@ -446,21 +447,21 @@ def max_memory_reserved(device: Union[Device, int] = None) -> int: return memory_stats(device=device).get("reserved_bytes.all.peak", 0) +@deprecated( + "`torch.cuda.memory_cached` has been renamed to `torch.cuda.memory_reserved`", + category=FutureWarning, +) def memory_cached(device: Union[Device, int] = None) -> int: r"""Deprecated; see :func:`~torch.cuda.memory_reserved`.""" - warnings.warn( - "torch.cuda.memory_cached has been renamed to torch.cuda.memory_reserved", - FutureWarning, - ) return memory_reserved(device=device) +@deprecated( + "`torch.cuda.max_memory_cached` has been renamed to `torch.cuda.max_memory_reserved`", + category=FutureWarning, +) def max_memory_cached(device: Union[Device, int] = None) -> int: r"""Deprecated; see :func:`~torch.cuda.max_memory_reserved`.""" - warnings.warn( - "torch.cuda.max_memory_cached has been renamed to torch.cuda.max_memory_reserved", - FutureWarning, - ) return max_memory_reserved(device=device) diff --git a/torch/cuda/nccl.py b/torch/cuda/nccl.py index 05751ab5f87b..f1332c968d69 100644 --- a/torch/cuda/nccl.py +++ b/torch/cuda/nccl.py @@ -89,8 +89,10 @@ def reduce( ) else: warnings.warn( - "nccl.reduce with an output tensor list is deprecated. " - "Please specify a single output tensor with argument 'output' instead instead." + "`nccl.reduce` with an output tensor list is deprecated. " + "Please specify a single output tensor with argument 'output' instead instead.", + FutureWarning, + stacklevel=2, ) _output = outputs[root] elif not isinstance(output, torch.Tensor) and isinstance( @@ -99,7 +101,9 @@ def reduce( # User called old API with positional arguments of list of output tensors. warnings.warn( "nccl.reduce with an output tensor list is deprecated. " - "Please specify a single output tensor." + "Please specify a single output tensor.", + FutureWarning, + stacklevel=2, ) _output = output[root] else: diff --git a/torch/distributed/_composable/fully_shard.py b/torch/distributed/_composable/fully_shard.py index 37e3d1544cd1..950a034071a4 100644 --- a/torch/distributed/_composable/fully_shard.py +++ b/torch/distributed/_composable/fully_shard.py @@ -1,5 +1,5 @@ -import warnings from typing import Callable, Iterable, Optional, Union +from typing_extensions import deprecated import torch import torch.distributed as dist @@ -38,6 +38,13 @@ @contract(state_cls=_FSDPState) +@deprecated( + "`torch.distributed._composable.fully_shard` is being deprecated. " + "You can continue to use the wrapper based FSDP. " + "See usage in: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/fully_sharded_data_parallel.py. " + "`torch.distributed._composable.fully_shard` will be removed after PyTorch 2.5.", + category=FutureWarning, +) def fully_shard( module: nn.Module, *, @@ -55,16 +62,7 @@ def fully_shard( Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]] ] = None, ) -> nn.Module: - """ - Applies ``FullyShardedDataParallel` (FSDP) semantics to ``module``. - """ - warnings.warn( - "``torch.distributed._composable.fully_shard`` is being deprecated." - "You can contintue to use the wrapper based FSDP." - "See usage in: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/fully_sharded_data_parallel.py." - "``torch.distributed._composable.fully_shard`` will be removed after PyTorch 2.5." - ) - + """Applies ``FullyShardedDataParallel`` (FSDP) semantics to ``module``.""" torch._C._log_api_usage_once("torch.distributed.fully_shard") # Enforce the new auto wrap policy if policy is not None and not isinstance(policy, _Policy): diff --git a/torch/distributed/_functional_collectives.py b/torch/distributed/_functional_collectives.py index 8d598713cf50..d170410061b1 100644 --- a/torch/distributed/_functional_collectives.py +++ b/torch/distributed/_functional_collectives.py @@ -766,7 +766,9 @@ def _resolve_group_name(group: RANK_TYPES, tag: str = "") -> str: warnings.warn( "The combination of ranks + tag as process group " "identifier has been deprecated. Please switch to " - "using ProcessGroup, DeviceMesh, or group name instead." + "using ProcessGroup, DeviceMesh, or group name instead.", + FutureWarning, + stacklevel=3, ) return c10d._resolve_group_name_by_ranks_and_tag(cast(List[int], group), tag) else: diff --git a/torch/distributed/_shard/checkpoint/__init__.py b/torch/distributed/_shard/checkpoint/__init__.py index 166c6f9254cf..161a43f276d6 100644 --- a/torch/distributed/_shard/checkpoint/__init__.py +++ b/torch/distributed/_shard/checkpoint/__init__.py @@ -5,8 +5,15 @@ import warnings from torch.distributed.checkpoint import * # noqa: F403 -warnings.warn( - "torch.distributed._shard.checkpoint will be deprecated, use torch.distributed.checkpoint instead", - DeprecationWarning -) + + +with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + "`torch.distributed._shard.checkpoint` will be deprecated, " + "use `torch.distributed.checkpoint` instead", + DeprecationWarning, + stacklevel=2, + ) + sys.modules['torch.distributed._shard.checkpoint'] = torch.distributed.checkpoint diff --git a/torch/distributed/_shard/sharded_tensor/api.py b/torch/distributed/_shard/sharded_tensor/api.py index a5e961e4bb78..79944953fd40 100644 --- a/torch/distributed/_shard/sharded_tensor/api.py +++ b/torch/distributed/_shard/sharded_tensor/api.py @@ -10,6 +10,7 @@ cast, TYPE_CHECKING, ) +from typing_extensions import deprecated import copy import warnings from functools import reduce @@ -396,7 +397,11 @@ def shard_size(shard_md): return reduce(operator.mul, shard_md.shard_sizes) # type: ignore[attr-defined] if enforce_dtype: - warnings.warn("enforce_dtype is deprecated. Please use dtype instead.") + warnings.warn( + "`enforce_dtype` is deprecated. Please use `dtype` instead.", + FutureWarning, + stacklevel=2, + ) rank = dist.get_rank(self._process_group) full_size = self.metadata().size @@ -737,6 +742,7 @@ def _init_from_local_shards( return sharded_tensor @classmethod + @deprecated(DEPRECATE_MSG, category=FutureWarning) def _init_from_local_tensor( cls, local_tensor: torch.Tensor, @@ -801,8 +807,6 @@ def _init_from_local_tensor( We fully rely on the user to ensure local tensor is sharded based on the sharding spec. """ - warnings.warn(DEPRECATE_MSG) - if not local_tensor.is_contiguous(): raise ValueError('local_tensor is not a contiguous Tensor.') @@ -980,6 +984,7 @@ def sharding_spec(self) -> shard_spec.ShardingSpec: """ return self._sharding_spec + @deprecated(DEPRECATE_MSG, category=FutureWarning) def reshard(self, resharding_spec: shard_spec.ShardingSpec) -> ShardedTensor: """ Reshard a sharded tensor given the ``resharding_spec``. For now, we only support @@ -1050,8 +1055,6 @@ def reshard(self, resharding_spec: shard_spec.ShardingSpec) -> ShardedTensor: tensor([[3], [3], [5], [5], [7], [7], [9], [9]]) # Rank 2 tensor([[4], [4], [6], [6], [8], [8], [10], [10]]) # Rank 3 """ - warnings.warn(DEPRECATE_MSG) - if ( not isinstance(resharding_spec, shard_spec.ChunkShardingSpec) or not isinstance(self._sharding_spec, shard_spec.ChunkShardingSpec) @@ -1096,6 +1099,7 @@ def local_tensor(self) -> torch.Tensor: return self.local_shards()[0].tensor @classmethod + @deprecated(DEPRECATE_MSG, category=FutureWarning) def __torch_function__(cls, func, types, args=(), kwargs=None): def dispatch(st: ShardedTensor, func: Callable): # Dispatch to custom user provided op first if it exists. @@ -1120,7 +1124,6 @@ def dispatch(st: ShardedTensor, func: Callable): f"torch function '{func.__name__}', with args: {args} and " f"kwargs: {kwargs} not supported for ShardedTensor!") - warnings.warn(DEPRECATE_MSG) # Find ShardedTensor instance to get process_group and sharding_spec. st_instance = None diff --git a/torch/distributed/_sharded_tensor/__init__.py b/torch/distributed/_sharded_tensor/__init__.py index 9e6b1662589c..6c6694cfb081 100644 --- a/torch/distributed/_sharded_tensor/__init__.py +++ b/torch/distributed/_sharded_tensor/__init__.py @@ -5,8 +5,14 @@ import warnings from torch.distributed._shard.sharded_tensor import * # noqa: F403 -warnings.warn( - "torch.distributed._sharded_tensor will be deprecated, use torch.distributed._shard.sharded_tensor instead", - DeprecationWarning -) + +with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + "`torch.distributed._sharded_tensor` will be deprecated, " + "use `torch.distributed._shard.sharded_tensor` instead", + DeprecationWarning, + stacklevel=2, + ) + sys.modules['torch.distributed._sharded_tensor'] = torch.distributed._shard.sharded_tensor diff --git a/torch/distributed/_sharding_spec/__init__.py b/torch/distributed/_sharding_spec/__init__.py index f3060005dbdd..21c56d5dc849 100644 --- a/torch/distributed/_sharding_spec/__init__.py +++ b/torch/distributed/_sharding_spec/__init__.py @@ -5,10 +5,15 @@ import warnings from torch.distributed._shard.sharding_spec import * # noqa: F403 -warnings.warn( - "torch.distributed._sharding_spec will be deprecated, use torch.distributed._shard.sharding_spec instead", - DeprecationWarning -) + +with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + "`torch.distributed._sharding_spec` will be deprecated, " + "use `torch.distributed._shard.sharding_spec` instead", + DeprecationWarning, + stacklevel=2, + ) import torch.distributed._shard.sharding_spec as _sharding_spec sys.modules['torch.distributed._sharding_spec'] = _sharding_spec diff --git a/torch/distributed/_tensor/api.py b/torch/distributed/_tensor/api.py index ba25c628a83e..0a3f89af3c20 100644 --- a/torch/distributed/_tensor/api.py +++ b/torch/distributed/_tensor/api.py @@ -746,6 +746,8 @@ def replicate_module_params_buffers(m: nn.Module, mesh: DeviceMesh) -> None: warnings.warn( "Deprecating input_fn that takes two arguments (inputs, device_mesh), " "please use input_fn that takes in (module, inputs, device_mesh) instead!", + FutureWarning, + stacklevel=2, ) module.register_forward_pre_hook(lambda _, inputs: input_fn(inputs, device_mesh)) # type: ignore[call-arg] elif num_args == 3: @@ -765,6 +767,8 @@ def replicate_module_params_buffers(m: nn.Module, mesh: DeviceMesh) -> None: warnings.warn( "Deprecating output_fn that takes two arguments (inputs, device_mesh), " "please use output_fn that takes in (module, inputs, device_mesh) instead!", + FutureWarning, + stacklevel=2, ) module.register_forward_hook( lambda mod, inputs, outputs: output_fn(outputs, device_mesh) # type: ignore[call-arg] diff --git a/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py b/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py index 364648f1a7f7..24a079849df7 100644 --- a/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py +++ b/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py @@ -233,7 +233,8 @@ def checkpoint_wrapper( f"Please specify {CheckpointImpl.NO_REENTRANT} as " f"{CheckpointImpl.REENTRANT} will soon be removed as " "the default and eventually deprecated.", - stacklevel=1, + FutureWarning, + stacklevel=2, ) return CheckpointWrapper( module, diff --git a/torch/distributed/checkpoint/state_dict_loader.py b/torch/distributed/checkpoint/state_dict_loader.py index df3cc945832c..b8ad6f61da14 100644 --- a/torch/distributed/checkpoint/state_dict_loader.py +++ b/torch/distributed/checkpoint/state_dict_loader.py @@ -1,6 +1,7 @@ import os import warnings from typing import Any, cast, Dict, Optional, Set, Union +from typing_extensions import deprecated import torch import torch.distributed as dist @@ -17,6 +18,11 @@ __all__ = ["load_state_dict", "load"] +@deprecated( + "`load_state_dict` is deprecated and will be removed in future versions. " + "Please use `load` instead.", + category=FutureWarning, +) def load_state_dict( state_dict: Dict[str, Any], storage_reader: StorageReader, @@ -26,10 +32,6 @@ def load_state_dict( planner: Optional[LoadPlanner] = None, ) -> None: """This method is deprecated. Please switch to 'load'.""" - warnings.warn( - "'load_state_dict' is deprecated and will be removed in future versions. " - "Please use 'load' instead." - ) storage_reader.reset() with _profile(): # TODO: test returning `load` here instead. diff --git a/torch/distributed/checkpoint/state_dict_saver.py b/torch/distributed/checkpoint/state_dict_saver.py index 0313f1c2ab61..451603288d12 100644 --- a/torch/distributed/checkpoint/state_dict_saver.py +++ b/torch/distributed/checkpoint/state_dict_saver.py @@ -3,6 +3,7 @@ import warnings from concurrent.futures import Future, ThreadPoolExecutor from typing import cast, Optional, Union +from typing_extensions import deprecated import torch import torch.distributed as dist @@ -24,6 +25,11 @@ __all__ = ["save_state_dict", "save", "async_save"] +@deprecated( + "`save_state_dict` is deprecated and will be removed in future versions." + "Please use `save` instead.", + category=FutureWarning, +) def save_state_dict( state_dict: STATE_DICT_TYPE, storage_writer: StorageWriter, @@ -33,11 +39,6 @@ def save_state_dict( planner: Optional[SavePlanner] = None, ) -> Metadata: """This method is deprecated. Please switch to 'save'.""" - warnings.warn( - "'save_state_dict' is deprecated and will be removed in future versions." - "Please use 'save' instead." - ) - storage_writer.reset() # TODO: test returning `save` here instead. diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 6fc505c78110..17152f0a87ed 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -14,6 +14,7 @@ from collections import namedtuple from datetime import timedelta from typing import Any, Callable, Dict, Optional, Tuple, Union, List, TYPE_CHECKING +from typing_extensions import deprecated import torch from torch._C._distributed_c10d import ( @@ -364,11 +365,12 @@ def __init__(self): setattr(self, k, v) self.__members__ = ReduceOp.RedOpType.__members__ + @deprecated( + "`torch.distributed.reduce_op` is deprecated, " + "please use `torch.distributed.ReduceOp` instead", + category=FutureWarning, + ) def __getattribute__(self, key): - warnings.warn( - "torch.distributed.reduce_op is deprecated, please use " - "torch.distributed.ReduceOp instead" - ) return object.__getattribute__(self, key) @@ -675,7 +677,9 @@ def _get_pg_default_device(group: Optional[ProcessGroup] = None) -> torch.device warnings.warn( f"You are using a Backend {type(group)} as a ProcessGroup. " "This usage is deprecated since PyTorch 2.0. Please use a public API " - "of PyTorch Distributed instead." + "of PyTorch Distributed instead.", + FutureWarning, + stacklevel=3, ) # Most users create Gloo with private API for object collectives _world.pg_default_device[group] = torch.device("cpu") @@ -829,13 +833,15 @@ def get_global_rank(group: ProcessGroup, group_rank: int) -> int: return rank raise ValueError(f"Group rank {group_rank} is not part of group {group}") + # TODO: remove this once the ecosystem moves away from it. +@deprecated( + "`torch.distributed.distributed_c10d._get_global_rank` is deprecated, " + "please use `torch.distributed.distributed_c10d.get_global_rank` instead", + category=FutureWarning, +) def _get_global_rank(group, rank) -> int: """Use get_global_rank as this method is deprecated.""" - warnings.warn( - "torch.distributed.distributed_c10d._get_global_rank is deprecated " - "please use torch.distributed.distributed_c10d.get_global_rank instead" - ) return get_global_rank(group, rank) @@ -2286,6 +2292,12 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False): work.wait() @_exception_logger +@deprecated( + "`torch.distributed.all_reduce_coalesced` will be deprecated. If you must " + "use it, please revisit our documentation later at " + "https://pytorch.org/docs/main/distributed.html#collective-functions", + category=FutureWarning, +) def all_reduce_coalesced(tensors, op=ReduceOp.SUM, group=None, async_op=False): """ WARNING: at this time individual shape checking is not implemented across nodes. @@ -2320,11 +2332,6 @@ def all_reduce_coalesced(tensors, op=ReduceOp.SUM, group=None, async_op=False): None, if not async_op or if not part of the group. """ - warnings.warn( - "torch.distributed.all_reduce_coalesced will be deprecated. If you must " - "use it, please revisit our documentation later at " - "https://pytorch.org/docs/main/distributed.html#collective-functions" - ) if isinstance(tensors, torch.Tensor): tensors = [tensors] _check_tensor_list(tensors, "tensor") @@ -3198,6 +3205,11 @@ def all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=Fal @_exception_logger +@deprecated( + "`torch.distributed._all_gather_base` is a private function and will be deprecated. " + "Please use `torch.distributed.all_gather_into_tensor` instead.", + category=FutureWarning, +) def _all_gather_base(output_tensor, input_tensor, group=None, async_op=False): """ Single tensor all gather. Gathers a single tensor from all ranks, and puts them in a single output tensor. @@ -3219,15 +3231,16 @@ def _all_gather_base(output_tensor, input_tensor, group=None, async_op=False): `all_gather_into_tensor` instead. """ - warnings.warn( - "torch.distributed._all_gather_base is a private function and will be " - "deprecated. Please use torch.distributed.all_gather_into_tensor " - "instead." - ) return all_gather_into_tensor(output_tensor, input_tensor, group, async_op) @_exception_logger +@deprecated( + "`torch.distributed.all_gather_coalesced` will be deprecated. If you must use it, " + "please revisit our documentation later at " + "https://pytorch.org/docs/main/distributed.html#collective-functions", + category=FutureWarning, +) def all_gather_coalesced( output_tensor_lists, input_tensor_list, group=None, async_op=False ): @@ -3274,11 +3287,6 @@ def all_gather_coalesced( performance improvements but users of this function should take extra care to ensure that each node passes in tensors whose shapes match across nodes. """ - warnings.warn( - "torch.distributed.all_gather_coalesced will be deprecated. If you must " - "use it, please revisit our documentation later at " - "https://pytorch.org/docs/main/distributed.html#collective-functions" - ) # We only check basic compatibility with C++ params here, C++ code will # do shape and type checking. if _rank_not_in_group(group): @@ -3608,6 +3616,11 @@ def reduce_scatter_tensor(output, input, op=ReduceOp.SUM, group=None, async_op=F work.wait() +@deprecated( + "`torch.distributed._reduce_scatter_base` is a private function and will be deprecated. " + "Please use `torch.distributed.reduce_scatter_tensor` instead.", + category=FutureWarning, +) def _reduce_scatter_base(output, input, op=ReduceOp.SUM, group=None, async_op=False): """ Reduces, then scatters a flattened tensor to all processes in a group. @@ -3628,11 +3641,6 @@ def _reduce_scatter_base(output, input, op=ReduceOp.SUM, group=None, async_op=Fa `reduce_scatter_tensor` instead. """ - warnings.warn( - "torch.distributed._reduce_scatter_base is a private function and will " - "be deprecated. Please use torch.distributed.reduce_scatter_tensor " - "instead." - ) return reduce_scatter_tensor(output, input, op, group, async_op) diff --git a/torch/distributed/elastic/metrics/api.py b/torch/distributed/elastic/metrics/api.py index 1499943c78d2..11a3930acf70 100644 --- a/torch/distributed/elastic/metrics/api.py +++ b/torch/distributed/elastic/metrics/api.py @@ -8,10 +8,10 @@ import abc import time -import warnings from collections import namedtuple from functools import wraps from typing import Dict, Optional +from typing_extensions import deprecated __all__ = ['MetricsConfig', 'MetricHandler', 'ConsoleMetricHandler', 'NullMetricHandler', 'MetricStream', 'configure', 'getStream', 'prof', 'profile', 'put_metric', 'publish_metric', 'get_elapsed_time_ms', @@ -137,6 +137,7 @@ def wrapper(*args, **kwargs): return wrap +@deprecated("Deprecated, use `@prof` instead", category=FutureWarning) def profile(group=None): """ @profile decorator adds latency and success/failure metrics to any given function. @@ -148,8 +149,6 @@ def profile(group=None): @metrics.profile("my_metric_group") def some_function(): """ - warnings.warn("Deprecated, use @prof instead", DeprecationWarning) - def wrap(func): @wraps(func) def wrapper(*args, **kwargs): @@ -187,10 +186,11 @@ def put_metric(metric_name: str, metric_value: int, metric_group: str = "torchel getStream(metric_group).add_value(metric_name, metric_value) +@deprecated( + "Deprecated, use `put_metric(metric_group)(metric_name, metric_value)` instead", + category=FutureWarning, +) def publish_metric(metric_group: str, metric_name: str, metric_value: int): - warnings.warn( - "Deprecated, use put_metric(metric_group)(metric_name, metric_value) instead" - ) metric_stream = getStream(metric_group) metric_stream.add_value(metric_name, metric_value) diff --git a/torch/distributed/fsdp/_init_utils.py b/torch/distributed/fsdp/_init_utils.py index ddd48c48a0ac..2364b1871206 100644 --- a/torch/distributed/fsdp/_init_utils.py +++ b/torch/distributed/fsdp/_init_utils.py @@ -446,7 +446,8 @@ def _init_core_state( elif sharding_strategy == ShardingStrategy.NO_SHARD: warnings.warn( "The `NO_SHARD` sharding strategy is deprecated. If having issues, " - "please use DistributedDataParallel instead.", + "please use `DistributedDataParallel` instead.", + FutureWarning, # Level 1 is here, level 2 is from `FullyShardedDataParallel`, and # level 3 is from the true caller stacklevel=3, diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py index 766fb76bbd09..c798ed1818d7 100644 --- a/torch/distributed/fsdp/fully_sharded_data_parallel.py +++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py @@ -1198,11 +1198,13 @@ def clip_grad_norm_( return total_norm.to(total_norm_dtype) @staticmethod - def _warn_optim_input(optim_input): + def _warn_optim_input(optim_input, *, stacklevel: int = 1): if optim_input is not None: warnings.warn( - "The `optim_input` argument is deprecated and will be removed after PyTorch 1.13. You may remove it " - "from your code without changing its functionality." + "The `optim_input` argument is deprecated and will be removed after PyTorch 1.13. " + "You may remove it from your code without changing its functionality.", + FutureWarning, + stacklevel=stacklevel + 1, ) @staticmethod @@ -1217,11 +1219,13 @@ def _is_using_optim_input(optim_input, optim) -> bool: return False @staticmethod - def _warn_legacy_optim_state_dict(curr: str, new: str): + def _warn_legacy_optim_state_dict(curr: str, new: str, *, stacklevel: int = 1): warnings.warn( f"``FullyShardedDataParallel.{curr}``is being deprecated and is " f"replaced by ``FullyShardedDataParallel.{new}``. " - f"``FullyShardedDataParallel.{curr}`` may be removed after PyTorch 2.2." + f"``FullyShardedDataParallel.{curr}`` may be removed after PyTorch 2.2.", + FutureWarning, + stacklevel=stacklevel + 1, ) @staticmethod @@ -1239,6 +1243,8 @@ def _optim_state_dict_impl( full_state_dict: bool = True, group: Optional[dist.ProcessGroup] = None, cpu_offload: bool = True, + *, + _stacklevel: int = 1, ) -> Dict[str, Any]: """Transform the state-dict of an optimizer corresponding to a sharded model. @@ -1247,7 +1253,9 @@ def _optim_state_dict_impl( FSDP internal information and internal sharding from the optim_state_dict. """ if full_state_dict: - FullyShardedDataParallel._warn_optim_input(optim_input) + FullyShardedDataParallel._warn_optim_input( + optim_input, stacklevel=_stacklevel + 1 + ) using_optim_input = FullyShardedDataParallel._is_using_optim_input( optim_input, optim, @@ -1398,7 +1406,9 @@ def full_optim_state_dict( then nonzero ranks return an empty :class:`dict`. """ FullyShardedDataParallel._warn_legacy_optim_state_dict( - "full_optim_state_dict", "optim_state_dict" + "full_optim_state_dict", + "optim_state_dict", + stacklevel=2, ) return FullyShardedDataParallel._optim_state_dict_impl( model=model, @@ -1408,6 +1418,7 @@ def full_optim_state_dict( rank0_only=rank0_only, group=group, full_state_dict=True, + _stacklevel=2, ) @staticmethod @@ -1429,7 +1440,9 @@ def sharded_optim_state_dict( cannot be directly used by the regular ``optim.load_state_dict``. """ FullyShardedDataParallel._warn_legacy_optim_state_dict( - "sharded_optim_state_dict", "optim_state_dict" + "sharded_optim_state_dict", + "optim_state_dict", + stacklevel=2, ) return FullyShardedDataParallel._optim_state_dict_impl( model=model, @@ -1439,6 +1452,7 @@ def sharded_optim_state_dict( rank0_only=False, full_state_dict=False, group=group, + _stacklevel=2, ) @staticmethod @@ -1507,7 +1521,9 @@ def shard_full_optim_state_dict( restricted to only include this rank's part of the optimizer state. """ FullyShardedDataParallel._warn_legacy_optim_state_dict( - "shard_full_optim_state_dict", "optim_state_dict_to_load" + "shard_full_optim_state_dict", + "optim_state_dict_to_load", + stacklevel=2, ) return FullyShardedDataParallel._optim_state_dict_to_load_impl( optim_state_dict=full_optim_state_dict, @@ -1544,7 +1560,9 @@ def flatten_sharded_optim_state_dict( Refer to :meth:`shard_full_optim_state_dict`. """ FullyShardedDataParallel._warn_legacy_optim_state_dict( - "flatten_sharded_optim_state_dict", "optim_state_dict_to_load" + "flatten_sharded_optim_state_dict", + "optim_state_dict_to_load", + stacklevel=2, ) return FullyShardedDataParallel._optim_state_dict_to_load_impl( optim_state_dict=sharded_optim_state_dict, @@ -1624,7 +1642,9 @@ def scatter_full_optim_state_dict( restricted to only include this rank's part of the optimizer state. """ FullyShardedDataParallel._warn_legacy_optim_state_dict( - "scatter_full_optim_state_dict", "optim_state_dict_to_load" + "scatter_full_optim_state_dict", + "optim_state_dict_to_load", + stacklevel=2, ) return FullyShardedDataParallel._optim_state_dict_to_load_impl( optim_state_dict=full_optim_state_dict, @@ -1855,6 +1875,7 @@ def optim_state_dict( cpu_offload=getattr( state_dict_settings.optim_state_dict_config, "offload_to_cpu", True ), + _stacklevel=2, ) @staticmethod diff --git a/torch/distributed/launch.py b/torch/distributed/launch.py index c95804b8e8bb..3efb0c3cf31d 100644 --- a/torch/distributed/launch.py +++ b/torch/distributed/launch.py @@ -159,7 +159,7 @@ """ -import warnings +from typing_extensions import deprecated as _deprecated from torch.distributed.run import get_args_parser, run @@ -188,17 +188,17 @@ def launch(args): run(args) +@_deprecated( + "The module torch.distributed.launch is deprecated\n" + "and will be removed in future. Use torchrun.\n" + "Note that --use-env is set by default in torchrun.\n" + "If your script expects `--local-rank` argument to be set, please\n" + "change it to read from `os.environ['LOCAL_RANK']` instead. See \n" + "https://pytorch.org/docs/stable/distributed.html#launch-utility for \n" + "further instructions\n", + category=FutureWarning, +) def main(args=None): - warnings.warn( - "The module torch.distributed.launch is deprecated\n" - "and will be removed in future. Use torchrun.\n" - "Note that --use-env is set by default in torchrun.\n" - "If your script expects `--local-rank` argument to be set, please\n" - "change it to read from `os.environ['LOCAL_RANK']` instead. See \n" - "https://pytorch.org/docs/stable/distributed.html#launch-utility for \n" - "further instructions\n", - FutureWarning, - ) args = parse_args(args) launch(args) diff --git a/torch/distributed/optim/__init__.py b/torch/distributed/optim/__init__.py index 0b576c65afea..fe33265fd532 100644 --- a/torch/distributed/optim/__init__.py +++ b/torch/distributed/optim/__init__.py @@ -5,6 +5,8 @@ optimizer can use any of the local optimizer :ref:`optimizer-algorithms` to apply the gradients on each worker. """ +import warnings + import torch from torch import optim @@ -24,9 +26,15 @@ from .named_optimizer import _NamedOptimizer from .utils import as_functional_optim -from warnings import warn -warn("TorchScript support for functional optimizers is" - "deprecated and will be removed in a future PyTorch release. Consider using the torch.compile optimizer instead.") +with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + "`TorchScript` support for functional optimizers is deprecated " + "and will be removed in a future PyTorch release. " + "Consider using the `torch.compile` optimizer instead.", + DeprecationWarning, + stacklevel=2, + ) # DistributedOptimizer imports torch.distributed.rpc names, so gate availability # based on RPC being available. diff --git a/torch/distributed/pipeline/__init__.py b/torch/distributed/pipeline/__init__.py index 5bc82f0692c1..eacd2bc99d04 100644 --- a/torch/distributed/pipeline/__init__.py +++ b/torch/distributed/pipeline/__init__.py @@ -1,7 +1,13 @@ import warnings -warnings.warn( - "torch.distributed.pipeline is deprecated. For up-to-date pipeline parallel " - "implementation, please refer to the PiPPy library under the PyTorch " - "organization (Pipeline Parallelism for PyTorch): " - "https://github.com/pytorch/PiPPy" -) + + +with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + "`torch.distributed.pipeline` is deprecated. For up-to-date pipeline parallel " + "implementation, please refer to the PiPPy library under the PyTorch " + "organization (Pipeline Parallelism for PyTorch): " + "https://github.com/pytorch/PiPPy", + DeprecationWarning, + stacklevel=2, + ) diff --git a/torch/distributed/tensor/parallel/_utils.py b/torch/distributed/tensor/parallel/_utils.py index c31170a0cd57..e109f5e9af93 100644 --- a/torch/distributed/tensor/parallel/_utils.py +++ b/torch/distributed/tensor/parallel/_utils.py @@ -22,7 +22,11 @@ def _deprecate_warnings(func_name: str, extra_msg: str) -> None: """ # TODO: Will follow up with dynamo POC to make warnings.warn working with dynamo. if not is_torchdynamo_compiling(): - warnings.warn(f"{func_name} is deprecated and will be removed soon. {extra_msg}") + warnings.warn( + f"{func_name} is deprecated and will be removed soon. {extra_msg}", + FutureWarning, + stacklevel=3, + ) def _validate_tp_mesh_dim( diff --git a/torch/distributions/distribution.py b/torch/distributions/distribution.py index 2752d710e8fb..2fb05828a8b3 100644 --- a/torch/distributions/distribution.py +++ b/torch/distributions/distribution.py @@ -1,5 +1,6 @@ import warnings from typing import Any, Dict, Optional, Tuple +from typing_extensions import deprecated import torch from torch.distributions import constraints @@ -171,14 +172,15 @@ def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor: """ raise NotImplementedError + @deprecated( + "`sample_n(n)` will be deprecated. Use `sample((n,))` instead.", + category=FutureWarning, + ) def sample_n(self, n: int) -> torch.Tensor: """ Generates n samples or n batches of samples if the distribution parameters are batched. """ - warnings.warn( - "sample_n will be deprecated. Use .sample((n,)) instead", UserWarning - ) return self.sample(torch.Size((n,))) def log_prob(self, value: torch.Tensor) -> torch.Tensor: diff --git a/torch/fx/experimental/unification/multipledispatch/dispatcher.py b/torch/fx/experimental/unification/multipledispatch/dispatcher.py index d2a8e6bfc7ff..c46e47e5d35b 100644 --- a/torch/fx/experimental/unification/multipledispatch/dispatcher.py +++ b/torch/fx/experimental/unification/multipledispatch/dispatcher.py @@ -1,5 +1,6 @@ from warnings import warn import inspect +from typing_extensions import deprecated from .conflict import ordering, ambiguities, super_signature, AmbiguityWarning from .utils import expand_tuples from .variadic import Variadic, isvariadic @@ -27,24 +28,21 @@ def ambiguity_warn(dispatcher, ambiguities): warn(warning_text(dispatcher.name, ambiguities), AmbiguityWarning) +@deprecated( + "`halt_ordering` is deprecated, you can safely remove this call.", + category=FutureWarning, +) def halt_ordering(): - """Deprecated interface to temporarily disable ordering. - """ - warn( - 'halt_ordering is deprecated, you can safely remove this call.', - DeprecationWarning, - ) + """Deprecated interface to temporarily disable ordering.""" +@deprecated( + "`restart_ordering` is deprecated, if you would like to eagerly order the dispatchers, " + "you should call the `reorder()` method on each dispatcher.", + category=FutureWarning, +) def restart_ordering(on_ambiguity=ambiguity_warn): - """Deprecated interface to temporarily resume ordering. - """ - warn( - 'restart_ordering is deprecated, if you would like to eagerly order' - 'the dispatchers, you should call the ``reorder()`` method on each' - ' dispatcher.', - DeprecationWarning, - ) + """Deprecated interface to temporarily resume ordering.""" def variadic_signature_matches_iter(types, full_signature): @@ -316,14 +314,12 @@ def dispatch_iter(self, *types): result = self.funcs[signature] yield result + @deprecated("`resolve()` is deprecated, use `dispatch(*types)`", category=FutureWarning) def resolve(self, types): """ Determine appropriate implementation for this type signature .. deprecated:: 0.4.4 Use ``dispatch(*types)`` instead """ - warn("resolve() is deprecated, use dispatch(*types)", - DeprecationWarning) - return self.dispatch(*types) def __getstate__(self): diff --git a/torch/hub.py b/torch/hub.py index 286dfbaa59b2..4ea92ed6be82 100644 --- a/torch/hub.py +++ b/torch/hub.py @@ -13,6 +13,7 @@ import zipfile from pathlib import Path from typing import Dict, Optional, Any +from typing_extensions import deprecated from urllib.error import HTTPError, URLError from urllib.request import urlopen, Request from urllib.parse import urlparse # noqa: F401 @@ -680,10 +681,13 @@ def _is_legacy_zip_format(filename: str) -> bool: return False +@deprecated( + 'Falling back to the old format < 1.6. This support will be ' + 'deprecated in favor of default zipfile format introduced in 1.6. ' + 'Please redo torch.save() to save it in the new zipfile format.', + category=FutureWarning, +) def _legacy_zip_load(filename: str, model_dir: str, map_location: MAP_LOCATION, weights_only: bool) -> Dict[str, Any]: - warnings.warn('Falling back to the old format < 1.6. This support will be ' - 'deprecated in favor of default zipfile format introduced in 1.6. ' - 'Please redo torch.save() to save it in the new zipfile format.') # Note: extractall() defaults to overwrite file if exists. No need to clean up beforehand. # We deliberately don't handle tarfile here since our legacy serialization format was in tar. # E.g. resnet18-5c106cde.pth which is widely used. diff --git a/torch/jit/_script.py b/torch/jit/_script.py index 8c223c66318c..7327a204fccc 100644 --- a/torch/jit/_script.py +++ b/torch/jit/_script.py @@ -1094,7 +1094,10 @@ def _script_impl( if optimize is not None: warnings.warn( - "`optimize` is deprecated and has no effect. Use `with torch.jit.optimized_execution() instead" + "`optimize` is deprecated and has no effect. " + "Use `with torch.jit.optimized_execution()` instead", + FutureWarning, + stacklevel=3, ) # No-op for modules, functions, class instances that are already scripted diff --git a/torch/jit/_trace.py b/torch/jit/_trace.py index 9dbbce88db7b..17914a5a444d 100644 --- a/torch/jit/_trace.py +++ b/torch/jit/_trace.py @@ -978,7 +978,10 @@ def forward(self, x): return func if optimize is not None: warnings.warn( - "`optimize` is deprecated and has no effect. Use `with torch.jit.optimized_execution() instead" + "`optimize` is deprecated and has no effect. " + "Use `with torch.jit.optimized_execution()` instead", + FutureWarning, + stacklevel=2, ) from torch._utils_internal import ( @@ -1185,7 +1188,10 @@ def weighted_kernel_sum(self, weight): return mod if optimize is not None: warnings.warn( - "`optimize` is deprecated and has no effect. Use `with torch.jit.optimized_execution() instead" + "`optimize` is deprecated and has no effect. " + "Use `with torch.jit.optimized_execution()` instead", + FutureWarning, + stacklevel=2, ) var_lookup_fn = _create_interpreter_name_lookup_fn(0) diff --git a/torch/library.py b/torch/library.py index 68aefb84a206..a69e16950f7e 100644 --- a/torch/library.py +++ b/torch/library.py @@ -1,5 +1,6 @@ from ._ops import OpOverload from typing import Any, Optional, Set, List, Union, Callable, Tuple, Dict, Sequence +from typing_extensions import deprecated import traceback import torch import weakref @@ -8,7 +9,6 @@ import re import contextlib import sys -import warnings from torch._library.custom_ops import custom_op, _maybe_get_opdef, device_types_t, CustomOpDef import torch._library as _library @@ -451,15 +451,15 @@ def wrap(f): return wrap +@deprecated( + "`torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that " + "instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.", + category=FutureWarning, +) def impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1): r"""This API was renamed to :func:`torch.library.register_fake` in PyTorch 2.4. Please use that instead. """ - warnings.warn("torch.library.impl_abstract was renamed to " - "torch.library.register_fake. Please use that instead; " - "we will remove torch.library.impl_abstract in a future " - "version of PyTorch.", - DeprecationWarning, stacklevel=2) if func is not None: _stacklevel = _stacklevel + 1 return register_fake(qualname, func, lib=lib, _stacklevel=_stacklevel) diff --git a/torch/multiprocessing/spawn.py b/torch/multiprocessing/spawn.py index a6ddc0102ce4..88bdc5155342 100644 --- a/torch/multiprocessing/spawn.py +++ b/torch/multiprocessing/spawn.py @@ -277,5 +277,5 @@ def spawn(fn, args=(), nprocs=1, join=True, daemon=False, start_method="spawn"): "To use a different start_method use:\n\t\t" " torch.multiprocessing.start_processes(...)" ) - warnings.warn(msg) + warnings.warn(msg, FutureWarning, stacklevel=2) return start_processes(fn, args, nprocs, join, daemon, start_method="spawn") diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 38d4bd5756fd..f67e2ddee04a 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -1818,7 +1818,8 @@ def softsign(input): # noqa: D400,D402 def _get_softmax_dim(name: str, ndim: int, stacklevel: int) -> int: warnings.warn( - f"Implicit dimension choice for {name} has been deprecated. Change the call to include dim=X as an argument.", + f"Implicit dimension choice for {name} has been deprecated. " + "Change the call to include dim=X as an argument.", stacklevel=stacklevel, ) if ndim == 0 or ndim == 1 or ndim == 3: @@ -3823,7 +3824,11 @@ def upsample(input, size=None, scale_factor=None, mode="nearest", align_corners= affects the outputs. """ - warnings.warn("nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.") + warnings.warn( + "`nn.functional.upsample` is deprecated. " + "Use `nn.functional.interpolate` instead.", + stacklevel=2, + ) return interpolate(input, size, scale_factor, mode, align_corners) @@ -4143,7 +4148,11 @@ def upsample_nearest(input, size=None, scale_factor=None): # noqa: F811 {backward_reproducibility_note} """ # DeprecationWarning is ignored by default - warnings.warn("nn.functional.upsample_nearest is deprecated. Use nn.functional.interpolate instead.") + warnings.warn( + "`nn.functional.upsample_nearest` is deprecated. " + "Use `nn.functional.interpolate` instead.", + stacklevel=2, + ) return interpolate(input, size, scale_factor, mode="nearest") @@ -4199,7 +4208,11 @@ def upsample_bilinear(input, size=None, scale_factor=None): # noqa: F811 {backward_reproducibility_note} """ # DeprecationWarning is ignored by default - warnings.warn("nn.functional.upsample_bilinear is deprecated. Use nn.functional.interpolate instead.") + warnings.warn( + "`nn.functional.upsample_bilinear` is deprecated. " + "Use `nn.functional.interpolate` instead.", + stacklevel=2, + ) return interpolate(input, size, scale_factor, mode="bilinear", align_corners=True) diff --git a/torch/nn/init.py b/torch/nn/init.py index 426069d780c0..f5be081e7dd0 100644 --- a/torch/nn/init.py +++ b/torch/nn/init.py @@ -599,7 +599,11 @@ def _make_deprecate(meth): old_name = new_name[:-1] def deprecated_init(*args, **kwargs): - warnings.warn(f"nn.init.{old_name} is now deprecated in favor of nn.init.{new_name}.", stacklevel=2) + warnings.warn( + f"`nn.init.{old_name}` is now deprecated in favor of `nn.init.{new_name}`.", + FutureWarning, + stacklevel=2, + ) return meth(*args, **kwargs) deprecated_init.__doc__ = fr""" diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py index bf15c3342d1d..5dec6f9578b1 100644 --- a/torch/nn/modules/activation.py +++ b/torch/nn/modules/activation.py @@ -219,10 +219,18 @@ def __init__( ) -> None: super().__init__() if min_value is not None: - warnings.warn("keyword argument min_value is deprecated and rename to min_val") + warnings.warn( + "keyword argument `min_value` is deprecated and rename to `min_val`", + FutureWarning, + stacklevel=2, + ) min_val = min_value if max_value is not None: - warnings.warn("keyword argument max_value is deprecated and rename to max_val") + warnings.warn( + "keyword argument `max_value` is deprecated and rename to `max_val`", + FutureWarning, + stacklevel=2, + ) max_val = max_value self.min_val = min_val diff --git a/torch/nn/modules/container.py b/torch/nn/modules/container.py index 1b5659d4b7e9..775a826d69cc 100644 --- a/torch/nn/modules/container.py +++ b/torch/nn/modules/container.py @@ -1,4 +1,3 @@ -import warnings from collections import OrderedDict, abc as container_abcs from itertools import chain, islice import operator @@ -10,6 +9,7 @@ from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, overload, Tuple, TypeVar, Union from typing_extensions import Self +from typing_extensions import deprecated __all__ = ['Container', 'Sequential', 'ModuleList', 'ModuleDict', 'ParameterList', 'ParameterDict'] @@ -29,13 +29,14 @@ def _addindent(s_, numSpaces): return s +@deprecated( + "`nn.Container` is deprecated. " + "All of it's functionality is now implemented in `nn.Module`. Subclass that instead.", + category=FutureWarning, +) class Container(Module): - def __init__(self, **kwargs: Any) -> None: super().__init__() - # DeprecationWarning is ignored by default - warnings.warn("nn.Container is deprecated. All of it's functionality " - "is now implemented in nn.Module. Subclass that instead.") for key, value in kwargs.items(): self.add_module(key, value) diff --git a/torch/nn/modules/conv.py b/torch/nn/modules/conv.py index 075d5e9865e6..4ab4c8bff9fc 100644 --- a/torch/nn/modules/conv.py +++ b/torch/nn/modules/conv.py @@ -1,5 +1,4 @@ import math -import warnings import torch from torch import Tensor @@ -13,6 +12,7 @@ from ..common_types import _size_1_t, _size_2_t, _size_3_t from typing import Optional, List, Tuple, Union +from typing_extensions import deprecated __all__ = ['Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d', 'LazyConv1d', 'LazyConv2d', 'LazyConv3d', 'LazyConvTranspose1d', 'LazyConvTranspose2d', @@ -40,9 +40,6 @@ :math:`(C_\text{in}=C_\text{in}, C_\text{out}=C_\text{in} \times \text{K}, ..., \text{groups}=C_\text{in})`."""} # noqa: B950 - - - class _ConvNd(Module): __constants__ = ['stride', 'padding', 'dilation', 'groups', @@ -610,7 +607,6 @@ def forward(self, input: Tensor) -> Tensor: return self._conv_forward(input, self.weight, self.bias) - class _ConvTransposeNd(_ConvNd): def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, @@ -1121,10 +1117,13 @@ def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Ten # `_ConvTransposeNd` is really not a mixin anymore (but multiple inheritance as # above would still work). class _ConvTransposeMixin(_ConvTransposeNd): + + @deprecated( + "`_ConvTransposeMixin` is a deprecated internal class. " + "Please consider using public APIs.", + category=FutureWarning, + ) def __init__(self, *args, **kwargs): - warnings.warn( - "_ConvTransposeMixin is a deprecated internal class. " - "Please consider using public APIs.") super().__init__(*args, **kwargs) diff --git a/torch/nn/modules/loss.py b/torch/nn/modules/loss.py index ee034bf458a6..4324c1df144d 100644 --- a/torch/nn/modules/loss.py +++ b/torch/nn/modules/loss.py @@ -1,5 +1,3 @@ -import warnings - from .distance import PairwiseDistance from .module import Module from .. import functional as F @@ -7,6 +5,7 @@ from torch import Tensor from typing import Callable, Optional +from typing_extensions import deprecated __all__ = ['L1Loss', 'NLLLoss', 'NLLLoss2d', 'PoissonNLLLoss', 'GaussianNLLLoss', 'KLDivLoss', 'MSELoss', 'BCELoss', 'BCEWithLogitsLoss', 'HingeEmbeddingLoss', 'MultiLabelMarginLoss', @@ -218,12 +217,15 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor: return F.nll_loss(input, target, weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction) +@deprecated( + "`NLLLoss2d` has been deprecated. " + "Please use `NLLLoss` instead as a drop-in replacement and see " + "https://pytorch.org/docs/main/nn.html#torch.nn.NLLLoss for more details.", + category=FutureWarning, +) class NLLLoss2d(NLLLoss): def __init__(self, weight: Optional[Tensor] = None, size_average=None, ignore_index: int = -100, reduce=None, reduction: str = 'mean') -> None: - warnings.warn("NLLLoss2d has been deprecated. " - "Please use NLLLoss instead as a drop-in replacement and see " - "https://pytorch.org/docs/main/nn.html#torch.nn.NLLLoss for more details.") super().__init__(weight, size_average, ignore_index, reduce, reduction) diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index 73420c0f32e7..dd6d64b68c23 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -1335,20 +1335,28 @@ def _get_backward_pre_hooks(self): def _maybe_warn_non_full_backward_hook(self, inputs, result, grad_fn): if not isinstance(result, torch.Tensor): if not (isinstance(result, tuple) and all(isinstance(r, torch.Tensor) for r in result)): - warnings.warn("Using non-full backward hooks on a Module that does not return a " - "single Tensor or a tuple of Tensors is deprecated and will be removed " - "in future versions. This hook will be missing some of the grad_output. " - "Please use register_full_backward_hook to get the documented behavior.") + warnings.warn( + "Using non-full backward hooks on a Module that does not return a " + "single Tensor or a tuple of Tensors is deprecated and will be removed " + "in future versions. This hook will be missing some of the grad_output. " + "Please use register_full_backward_hook to get the documented behavior.", + FutureWarning, + stacklevel=2, + ) return else: result = (result,) if not isinstance(inputs, torch.Tensor): if not (isinstance(inputs, tuple) and all(isinstance(i, torch.Tensor) for i in inputs)): - warnings.warn("Using non-full backward hooks on a Module that does not take as input a " - "single Tensor or a tuple of Tensors is deprecated and will be removed " - "in future versions. This hook will be missing some of the grad_input. " - "Please use register_full_backward_hook to get the documented behavior.") + warnings.warn( + "Using non-full backward hooks on a Module that does not take as input a " + "single Tensor or a tuple of Tensors is deprecated and will be removed " + "in future versions. This hook will be missing some of the grad_input. " + "Please use register_full_backward_hook to get the documented behavior.", + FutureWarning, + stacklevel=2, + ) return else: inputs = (inputs,) @@ -1356,13 +1364,21 @@ def _maybe_warn_non_full_backward_hook(self, inputs, result, grad_fn): # At this point we are sure that inputs and result are tuple of Tensors out_grad_fn = {r.grad_fn for r in result if r.grad_fn is not None} if len(out_grad_fn) == 0 or (len(out_grad_fn) == 1 and grad_fn not in out_grad_fn): - warnings.warn("Using a non-full backward hook when outputs are nested in python data structure " - "is deprecated and will be removed in future versions. This hook will be missing " - "some grad_output.") + warnings.warn( + "Using a non-full backward hook when outputs are nested in python data structure " + "is deprecated and will be removed in future versions. This hook will be missing " + "some grad_output.", + FutureWarning, + stacklevel=2, + ) elif len(out_grad_fn) > 1: - warnings.warn("Using a non-full backward hook when outputs are generated by different autograd Nodes " - "is deprecated and will be removed in future versions. This hook will be missing " - "some grad_output. Please use register_full_backward_hook to get the documented behavior.") + warnings.warn( + "Using a non-full backward hook when outputs are generated by different autograd Nodes " + "is deprecated and will be removed in future versions. This hook will be missing " + "some grad_output. Please use register_full_backward_hook to get the documented behavior.", + FutureWarning, + stacklevel=2, + ) else: # At this point the grad_output part of the hook will most likely be correct inputs_grad_fn = {i.grad_fn for i in inputs if i.grad_fn is not None} @@ -1370,10 +1386,14 @@ def _maybe_warn_non_full_backward_hook(self, inputs, result, grad_fn): next_functions = {n[0] for n in grad_fn.next_functions} if inputs_grad_fn != next_functions: - warnings.warn("Using a non-full backward hook when the forward contains multiple autograd Nodes " - "is deprecated and will be removed in future versions. This hook will be missing " - "some grad_input. Please use register_full_backward_hook to get the documented " - "behavior.") + warnings.warn( + "Using a non-full backward hook when the forward contains multiple autograd Nodes " + "is deprecated and will be removed in future versions. This hook will be missing " + "some grad_input. Please use register_full_backward_hook to get the documented " + "behavior.", + FutureWarning, + stacklevel=2, + ) def register_forward_pre_hook( self, @@ -1887,17 +1907,20 @@ def state_dict(self, *args, destination=None, prefix='', keep_vars=False): """ # TODO: Remove `args` and the parsing logic when BC allows. if len(args) > 0: + # DeprecationWarning is ignored by default + warnings.warn( + "Positional args are being deprecated, use kwargs instead. Refer to " + "https://pytorch.org/docs/main/generated/torch.nn.Module.html#torch.nn.Module.state_dict" + " for details.", + FutureWarning, + stacklevel=2, + ) if destination is None: destination = args[0] if len(args) > 1 and prefix == '': prefix = args[1] if len(args) > 2 and keep_vars is False: keep_vars = args[2] - # DeprecationWarning is ignored by default - warnings.warn( - "Positional args are being deprecated, use kwargs instead. Refer to " - "https://pytorch.org/docs/main/generated/torch.nn.Module.html#torch.nn.Module.state_dict" - " for details.") if destination is None: destination = OrderedDict() diff --git a/torch/nn/modules/rnn.py b/torch/nn/modules/rnn.py index 742bec9ebd19..b4bdd7824474 100644 --- a/torch/nn/modules/rnn.py +++ b/torch/nn/modules/rnn.py @@ -3,6 +3,7 @@ import numbers import weakref from typing import List, Tuple, Optional, overload +from typing_extensions import deprecated import torch from torch import Tensor @@ -24,8 +25,11 @@ def _apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Ten return tensor.index_select(dim, permutation) +@deprecated( + "`apply_permutation` is deprecated, please use `tensor.index_select(dim, permutation)` instead", + category=FutureWarning, +) def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor: - warnings.warn("apply_permutation is deprecated, please use tensor.index_select(dim, permutation) instead") return _apply_permutation(tensor, permutation, dim) diff --git a/torch/nn/parallel/__init__.py b/torch/nn/parallel/__init__.py index d0708296d47a..adcd6bd838eb 100644 --- a/torch/nn/parallel/__init__.py +++ b/torch/nn/parallel/__init__.py @@ -1,3 +1,5 @@ +from typing_extensions import deprecated + from .parallel_apply import parallel_apply from .replicate import replicate from .data_parallel import DataParallel, data_parallel @@ -7,8 +9,11 @@ __all__ = ['replicate', 'scatter', 'parallel_apply', 'gather', 'data_parallel', 'DataParallel', 'DistributedDataParallel'] + +@deprecated( + "`torch.nn.parallel.DistributedDataParallelCPU` is deprecated, " + "please use `torch.nn.parallel.DistributedDataParallel` instead.", + category=FutureWarning, +) def DistributedDataParallelCPU(*args, **kwargs): - import warnings - warnings.warn("torch.nn.parallel.DistributedDataParallelCPU is deprecated, " - "please use torch.nn.parallel.DistributedDataParallel instead.") return DistributedDataParallel(*args, **kwargs) diff --git a/torch/nn/parallel/comm.py b/torch/nn/parallel/comm.py index 764775587d68..22cf80bd64e2 100644 --- a/torch/nn/parallel/comm.py +++ b/torch/nn/parallel/comm.py @@ -226,7 +226,10 @@ def gather(tensors, dim=0, destination=None, *, out=None): if destination == -1: warnings.warn( 'Using -1 to represent CPU tensor is deprecated. Please use a ' - 'device object or string instead, e.g., "cpu".') + 'device object or string instead, e.g., "cpu".', + FutureWarning, + stacklevel=2, + ) destination = _get_device_index(destination, allow_cpu=True, optional=True) return torch._C._gather(tensors, dim, destination) else: diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py index b27c960a154c..ef6034ade58e 100644 --- a/torch/nn/parallel/distributed.py +++ b/torch/nn/parallel/distributed.py @@ -771,7 +771,9 @@ def __init__( # do not receive gradients. warnings.warn( "The `check_reduction` argument in `DistributedDataParallel` " - "module is deprecated. Please avoid using it." + "module is deprecated. Please avoid using it.", + FutureWarning, + stacklevel=2, ) # Check that a module does not have Uninitialized parameters diff --git a/torch/nn/parallel/scatter_gather.py b/torch/nn/parallel/scatter_gather.py index 8daa1117bfaf..f6fb9d47ecbf 100644 --- a/torch/nn/parallel/scatter_gather.py +++ b/torch/nn/parallel/scatter_gather.py @@ -1,13 +1,17 @@ import torch from typing import Any, Dict, List, Optional, Sequence, Tuple, TypeVar, Union, overload +from typing_extensions import deprecated from ._functions import Scatter, Gather -import warnings __all__ = ['scatter', 'scatter_kwargs', 'gather'] + +@deprecated( + "`is_namedtuple` is deprecated, please use the python checks instead", + category=FutureWarning, +) def is_namedtuple(obj: Any) -> bool: # Check if type was created from collections.namedtuple or a typing.NamedTuple. - warnings.warn("is_namedtuple is deprecated, please use the python checks instead") return _is_namedtuple(obj) def _is_namedtuple(obj: Any) -> bool: diff --git a/torch/nn/utils/clip_grad.py b/torch/nn/utils/clip_grad.py index 6549a6f3e2c8..4ac8a4e7445b 100644 --- a/torch/nn/utils/clip_grad.py +++ b/torch/nn/utils/clip_grad.py @@ -1,6 +1,6 @@ -import warnings import functools from typing import Union, Iterable, List, Dict, Tuple, Optional, cast +from typing_extensions import deprecated import torch from torch import Tensor @@ -99,6 +99,11 @@ def clip_grad_norm_( return total_norm +@deprecated( + "`torch.nn.utils.clip_grad_norm` is now deprecated " + "in favor of `torch.nn.utils.clip_grad_norm_`.", + category=FutureWarning, +) def clip_grad_norm( parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2., error_if_nonfinite: bool = False, foreach: Optional[bool] = None) -> torch.Tensor: @@ -108,8 +113,6 @@ def clip_grad_norm( This method is now deprecated in favor of :func:`torch.nn.utils.clip_grad_norm_`. """ - warnings.warn("torch.nn.utils.clip_grad_norm is now deprecated in favor " - "of torch.nn.utils.clip_grad_norm_.", stacklevel=2) return clip_grad_norm_(parameters, max_norm, norm_type, error_if_nonfinite, foreach) diff --git a/torch/nn/utils/stateless.py b/torch/nn/utils/stateless.py index 2cb6c7460d4c..660a1a484ebb 100644 --- a/torch/nn/utils/stateless.py +++ b/torch/nn/utils/stateless.py @@ -1,7 +1,7 @@ import contextlib -import warnings from collections import defaultdict from typing import Any, Dict, Iterator, Optional, Set, Tuple, Union +from typing_extensions import deprecated import torch from torch import Tensor @@ -148,6 +148,12 @@ def _reparametrize_module( ) +@deprecated( + "`torch.nn.utils.stateless.functional_call` is deprecated as of PyTorch 2.0 " + "and will be removed in a future version of PyTorch. " + "Please use `torch.func.functional_call` instead which is a drop-in replacement.", + category=FutureWarning, +) def functional_call( module: "torch.nn.Module", parameters_and_buffers: Dict[str, Tensor], @@ -216,12 +222,6 @@ def functional_call( Returns: Any: the result of calling ``module``. """ - warnings.warn( - "This API is deprecated as of PyTorch 2.0 and will be removed in a future " - "version of PyTorch. Please use torch.func.functional_call instead " - "which is a drop-in replacement for this API." - ) - return _functional_call( module, parameters_and_buffers, diff --git a/torch/nn/utils/weight_norm.py b/torch/nn/utils/weight_norm.py index 942a13a4eb83..6cfe4b3e526d 100644 --- a/torch/nn/utils/weight_norm.py +++ b/torch/nn/utils/weight_norm.py @@ -2,7 +2,7 @@ from torch.nn.parameter import Parameter, UninitializedParameter from torch import _weight_norm, norm_except_dim from typing import Any, TypeVar -import warnings +from typing_extensions import deprecated from ..modules import Module __all__ = ['WeightNorm', 'weight_norm', 'remove_weight_norm'] @@ -24,9 +24,12 @@ def compute_weight(self, module: Module) -> Any: return _weight_norm(v, g, self.dim) @staticmethod + @deprecated( + "`torch.nn.utils.weight_norm` is deprecated " + "in favor of `torch.nn.utils.parametrizations.weight_norm`.", + category=FutureWarning, + ) def apply(module, name: str, dim: int) -> 'WeightNorm': - warnings.warn("torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.") - for hook in module._forward_pre_hooks.values(): if isinstance(hook, WeightNorm) and hook.name == name: raise RuntimeError(f"Cannot register two weight_norm hooks on the same parameter {name}") diff --git a/torch/profiler/profiler.py b/torch/profiler/profiler.py index 2c16c34bb3c3..5d1b50bc3020 100644 --- a/torch/profiler/profiler.py +++ b/torch/profiler/profiler.py @@ -598,7 +598,11 @@ def __init__( ): activities_set = set(activities) if activities else supported_activities() if use_cuda is not None: - warn("use_cuda is deprecated, use activities argument instead") + warn( + "`use_cuda` is deprecated, use `activities` argument instead", + FutureWarning, + stacklevel=2, + ) if use_cuda: activities_set.add(ProfilerActivity.CUDA) elif ProfilerActivity.CUDA in activities_set: diff --git a/torch/sparse/semi_structured.py b/torch/sparse/semi_structured.py index 587fcc0d72ea..d592e5ef6a62 100644 --- a/torch/sparse/semi_structured.py +++ b/torch/sparse/semi_structured.py @@ -359,9 +359,12 @@ def to_sparse_semi_structured( [-4370, -4370, -4370, ..., -4370, -4370, -4370]], device='cuda:0', dtype=torch.int16)) """ if transposed: - raise DeprecationWarning( - "Setting transpose from to_sparse_semi_structured is deprecated and will be removed in a future release." - "SparseSemiStructuredTensor only support contiguous input tensors. " + warnings.warn( + "Setting transpose from `to_sparse_semi_structured` is deprecated " + "and will be removed in a future release. " + "`SparseSemiStructuredTensor` only support contiguous input tensors.", + FutureWarning, + stacklevel=2, ) # set from _FORCE_CUTLASS flag diff --git a/torch/testing/_comparison.py b/torch/testing/_comparison.py index e2bad14e4490..85d5adb0cd3a 100644 --- a/torch/testing/_comparison.py +++ b/torch/testing/_comparison.py @@ -2,7 +2,6 @@ import cmath import collections.abc import contextlib -import warnings from typing import ( Any, Callable, @@ -16,6 +15,7 @@ Type, Union, ) +from typing_extensions import deprecated import torch @@ -1523,6 +1523,12 @@ def assert_close( raise error_metas[0].to_error(msg) +@deprecated( + "`torch.testing.assert_allclose()` is deprecated since 1.12 and will be removed in a future release. " + "Please use `torch.testing.assert_close()` instead. " + "You can find detailed upgrade instructions in https://github.com/pytorch/pytorch/issues/61844.", + category=FutureWarning, +) def assert_allclose( actual: Any, expected: Any, @@ -1538,14 +1544,6 @@ def assert_allclose( Please use :func:`torch.testing.assert_close` instead. You can find detailed upgrade instructions `here `_. """ - warnings.warn( - "`torch.testing.assert_allclose()` is deprecated since 1.12 and will be removed in a future release. " - "Please use `torch.testing.assert_close()` instead. " - "You can find detailed upgrade instructions in https://github.com/pytorch/pytorch/issues/61844.", - FutureWarning, - stacklevel=2, - ) - if not isinstance(actual, torch.Tensor): actual = torch.tensor(actual) if not isinstance(expected, torch.Tensor): diff --git a/torch/testing/_creation.py b/torch/testing/_creation.py index 0b01b172a477..d8fb2ef18b1d 100644 --- a/torch/testing/_creation.py +++ b/torch/testing/_creation.py @@ -150,8 +150,9 @@ def clamp(a: float, l: float, h: float) -> float: warnings.warn( "Passing `low==high` to `torch.testing.make_tensor` for floating or complex types " "is deprecated since 2.1 and will be removed in 2.3. " - "Use torch.full(...) instead.", + "Use `torch.full(...)` instead.", FutureWarning, + stacklevel=3, ) elif low >= high: raise ValueError(f"`low` must be less than `high`, but got {low} >= {high}") diff --git a/torch/utils/_config_module.py b/torch/utils/_config_module.py index f468e2d84890..6b38645e486b 100644 --- a/torch/utils/_config_module.py +++ b/torch/utils/_config_module.py @@ -7,9 +7,9 @@ import pickle import tokenize import unittest -import warnings from types import FunctionType, ModuleType from typing import Any, Dict, Optional, Set, Union +from typing_extensions import deprecated from unittest import mock # Types saved/loaded in configs @@ -196,12 +196,12 @@ def get_hash(self) -> bytes: self._is_dirty = False return self._hash_digest + @deprecated( + "`config.to_dict()` has been deprecated. It may no longer change the underlying config." + " use `config.shallow_copy_dict()` or `config.get_config_copy()` instead", + category=FutureWarning, + ) def to_dict(self) -> Dict[str, Any]: - warnings.warn( - "config.to_dict() has been deprecated. It may no longer change the underlying config." - " use config.shallow_copy_dict() or config.get_config_copy() instead", - DeprecationWarning, - ) return self.shallow_copy_dict() def shallow_copy_dict(self) -> Dict[str, Any]: diff --git a/torch/utils/_contextlib.py b/torch/utils/_contextlib.py index c55e69618575..59b7d368af26 100644 --- a/torch/utils/_contextlib.py +++ b/torch/utils/_contextlib.py @@ -122,10 +122,14 @@ class _DecoratorContextManager: def __call__(self, orig_func: F) -> F: if inspect.isclass(orig_func): - warnings.warn("Decorating classes is deprecated and will be disabled in " - "future versions. You should only decorate functions or methods. " - "To preserve the current behavior of class decoration, you can " - "directly decorate the `__init__` method and nothing else.") + warnings.warn( + "Decorating classes is deprecated and will be disabled in " + "future versions. You should only decorate functions or methods. " + "To preserve the current behavior of class decoration, you can " + "directly decorate the `__init__` method and nothing else.", + FutureWarning, + stacklevel=2, + ) func = cast(F, lambda *args, **kwargs: orig_func(*args, **kwargs)) else: func = orig_func diff --git a/torch/utils/_cxx_pytree.py b/torch/utils/_cxx_pytree.py index aba15f1482f2..01adf0a4f9b1 100644 --- a/torch/utils/_cxx_pytree.py +++ b/torch/utils/_cxx_pytree.py @@ -15,7 +15,6 @@ import functools import sys import types -import warnings from typing import ( Any, Callable, @@ -28,6 +27,7 @@ TypeVar, Union, ) +from typing_extensions import deprecated import torch @@ -167,6 +167,11 @@ def register_pytree_node( ) +@deprecated( + "`torch.utils._cxx_pytree._register_pytree_node` is deprecated. " + "Please use `torch.utils._cxx_pytree.register_pytree_node` instead.", + category=FutureWarning, +) def _register_pytree_node( cls: Type[Any], flatten_fn: FlattenFunc, @@ -207,11 +212,6 @@ def _register_pytree_node( original context. This is used for json deserialization, which is being used in :mod:`torch.export` right now. """ - warnings.warn( - "torch.utils._cxx_pytree._register_pytree_node is deprecated. " - "Please use torch.utils._cxx_pytree.register_pytree_node instead.", - stacklevel=2, - ) _private_register_pytree_node( cls, diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index 2831d662d9f6..b4a0db5db730 100644 --- a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -48,6 +48,7 @@ TypeVar, Union, ) +from typing_extensions import deprecated __all__ = [ @@ -251,6 +252,11 @@ def _register_namedtuple( ) +@deprecated( + "`torch.utils._pytree._register_pytree_node` is deprecated. " + "Please use `torch.utils._pytree.register_pytree_node` instead.", + category=FutureWarning, +) def _register_pytree_node( cls: Type[Any], flatten_fn: FlattenFunc, @@ -287,16 +293,12 @@ def _register_pytree_node( Like ``flatten_fn``, but in place of a List[leaf], it should return a List[(keypath, leaf)]. """ - warnings.warn( - "torch.utils._pytree._register_pytree_node is deprecated. " - "Please use torch.utils._pytree.register_pytree_node instead.", - stacklevel=2, - ) - if to_str_fn is not None or maybe_from_str_fn is not None: warnings.warn( - "to_str_fn and maybe_from_str_fn is deprecated. " - "Please use to_dumpable_context and from_dumpable_context instead." + "`to_str_fn` and `maybe_from_str_fn` is deprecated. " + "Please use `to_dumpable_context` and `from_dumpable_context` instead.", + FutureWarning, + stacklevel=2, ) _private_register_pytree_node( @@ -1451,14 +1453,20 @@ def treespec_pprint(treespec: TreeSpec) -> str: # TODO(angelayi): remove this function after OSS/internal stabilize +@deprecated( + "`pytree_to_str` is deprecated. Please use `treespec_dumps` instead.", + category=FutureWarning, +) def pytree_to_str(treespec: TreeSpec) -> str: - warnings.warn("pytree_to_str is deprecated. Please use treespec_dumps") return treespec_dumps(treespec) # TODO(angelayi): remove this function after OSS/internal stabilize +@deprecated( + "`str_to_pytree` is deprecated. Please use `treespec_loads` instead.", + category=FutureWarning, +) def str_to_pytree(json: str) -> TreeSpec: - warnings.warn("str_to_pytree is deprecated. Please use treespec_loads") return treespec_loads(json) diff --git a/torch/utils/data/backward_compatibility.py b/torch/utils/data/backward_compatibility.py index be97f016a091..f51418265f41 100644 --- a/torch/utils/data/backward_compatibility.py +++ b/torch/utils/data/backward_compatibility.py @@ -1,5 +1,10 @@ -import warnings +from typing_extensions import deprecated as _deprecated + +@_deprecated( + "Usage of `backward_compatibility.worker_init_fn` is deprecated " + "as `DataLoader` automatically applies sharding in every worker", + category=FutureWarning, +) def worker_init_fn(worker_id): - warnings.warn("Usage of backward_compatibility.worker_init_fn is deprecated" - " as DataLoader automatically applies sharding in every worker") + pass diff --git a/torch/utils/data/dataset.py b/torch/utils/data/dataset.py index 554bf90d108b..b3cf9d92943d 100644 --- a/torch/utils/data/dataset.py +++ b/torch/utils/data/dataset.py @@ -14,6 +14,7 @@ TypeVar, Union, ) +from typing_extensions import deprecated # No 'default_generator' in torch/__init__.pyi from torch import default_generator, randperm @@ -348,12 +349,11 @@ def __getitem__(self, idx): return self.datasets[dataset_idx][sample_idx] @property + @deprecated( + "`cummulative_sizes` attribute is renamed to `cumulative_sizes`", + category=FutureWarning, + ) def cummulative_sizes(self): - warnings.warn( - "cummulative_sizes attribute is renamed to " "cumulative_sizes", - DeprecationWarning, - stacklevel=2, - ) return self.cumulative_sizes diff --git a/torch/utils/data/graph_settings.py b/torch/utils/data/graph_settings.py index 4b42cc6065a7..573069279201 100644 --- a/torch/utils/data/graph_settings.py +++ b/torch/utils/data/graph_settings.py @@ -2,6 +2,7 @@ import warnings from typing import Any, List, Optional, Set +from typing_extensions import deprecated import torch @@ -116,11 +117,12 @@ def apply_shuffle_settings(datapipe: DataPipe, shuffle: Optional[bool] = None) - return datapipe +@deprecated( + "`apply_shuffle_seed` is deprecated since 1.12 and will be removed in the future releases. " + "Please use `apply_random_seed` instead.", + category=FutureWarning, +) def apply_shuffle_seed(datapipe: DataPipe, rng: Any) -> DataPipe: - warnings.warn( - "`apply_shuffle_seed` is deprecated since 1.12 and will be removed in the future releases." - "\nPlease use `apply_random_seed` instead." - ) return apply_random_seed(datapipe, rng) From c2547dfcc339f1c788b13ee3e191fb900e22207a Mon Sep 17 00:00:00 2001 From: Aaron Gokaslan Date: Sun, 2 Jun 2024 13:38:33 +0000 Subject: [PATCH 231/706] [BE][Ez]: Enable ruff PYI019 (#127684) Tells pytorch to use typing_extensions.Self when it's able to. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127684 Approved by: https://github.com/ezyang --- pyproject.toml | 1 - torch/nn/utils/rnn.pyi | 49 +++++++------------ .../_internal/diagnostics/infra/context.py | 12 +++-- 3 files changed, 26 insertions(+), 36 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8ec04f77042d..01fce4a7d6fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,7 +68,6 @@ ignore = [ "PERF401", "PERF403", # these ignores are from PYI; please fix! - "PYI019", "PYI024", "PYI036", "PYI041", diff --git a/torch/nn/utils/rnn.pyi b/torch/nn/utils/rnn.pyi index fed87febe2a6..fd033d8888be 100644 --- a/torch/nn/utils/rnn.pyi +++ b/torch/nn/utils/rnn.pyi @@ -1,14 +1,5 @@ -from typing import ( - Any, - Iterable, - NamedTuple, - Optional, - overload, - Sequence, - Tuple, - TypeVar, - Union, -) +from typing import Any, Iterable, NamedTuple, Optional, overload, Sequence, Tuple, Union + from typing_extensions import Self from torch import Tensor @@ -24,8 +15,6 @@ class PackedSequence_(NamedTuple): def bind(optional: Any, fn: Any): ... -_T = TypeVar("_T") - class PackedSequence(PackedSequence_): def __new__( cls, @@ -34,39 +23,39 @@ class PackedSequence(PackedSequence_): sorted_indices: Optional[Tensor] = ..., unsorted_indices: Optional[Tensor] = ..., ) -> Self: ... - def pin_memory(self: _T) -> _T: ... - def cuda(self: _T, *args: Any, **kwargs: Any) -> _T: ... - def cpu(self: _T) -> _T: ... - def double(self: _T) -> _T: ... - def float(self: _T) -> _T: ... - def half(self: _T) -> _T: ... - def long(self: _T) -> _T: ... - def int(self: _T) -> _T: ... - def short(self: _T) -> _T: ... - def char(self: _T) -> _T: ... - def byte(self: _T) -> _T: ... + def pin_memory(self: Self) -> Self: ... + def cuda(self: Self, *args: Any, **kwargs: Any) -> Self: ... + def cpu(self: Self) -> Self: ... + def double(self: Self) -> Self: ... + def float(self: Self) -> Self: ... + def half(self: Self) -> Self: ... + def long(self: Self) -> Self: ... + def int(self: Self) -> Self: ... + def short(self: Self) -> Self: ... + def char(self: Self) -> Self: ... + def byte(self: Self) -> Self: ... @overload def to( - self: _T, + self: Self, dtype: _dtype, non_blocking: bool = False, copy: bool = False, - ) -> _T: ... + ) -> Self: ... @overload def to( - self: _T, + self: Self, device: Optional[DeviceLikeType] = None, dtype: Optional[_dtype] = None, non_blocking: bool = False, copy: bool = False, - ) -> _T: ... + ) -> Self: ... @overload def to( - self: _T, + self: Self, other: Tensor, non_blocking: bool = False, copy: bool = False, - ) -> _T: ... + ) -> Self: ... @property def is_cuda(self) -> bool: ... def is_pinned(self) -> bool: ... diff --git a/torch/onnx/_internal/diagnostics/infra/context.py b/torch/onnx/_internal/diagnostics/infra/context.py index 22370850df86..6106a42467c1 100644 --- a/torch/onnx/_internal/diagnostics/infra/context.py +++ b/torch/onnx/_internal/diagnostics/infra/context.py @@ -21,6 +21,8 @@ TypeVar, ) +from typing_extensions import Self + from torch.onnx._internal.diagnostics import infra from torch.onnx._internal.diagnostics.infra import formatter, sarif, utils from torch.onnx._internal.diagnostics.infra.sarif import version as sarif_version @@ -92,24 +94,24 @@ def sarif(self) -> sarif.Result: ) return sarif_result - def with_location(self: _Diagnostic, location: infra.Location) -> _Diagnostic: + def with_location(self: Self, location: infra.Location) -> Self: """Adds a location to the diagnostic.""" self.locations.append(location) return self def with_thread_flow_location( - self: _Diagnostic, location: infra.ThreadFlowLocation - ) -> _Diagnostic: + self: Self, location: infra.ThreadFlowLocation + ) -> Self: """Adds a thread flow location to the diagnostic.""" self.thread_flow_locations.append(location) return self - def with_stack(self: _Diagnostic, stack: infra.Stack) -> _Diagnostic: + def with_stack(self: Self, stack: infra.Stack) -> Self: """Adds a stack to the diagnostic.""" self.stacks.append(stack) return self - def with_graph(self: _Diagnostic, graph: infra.Graph) -> _Diagnostic: + def with_graph(self: Self, graph: infra.Graph) -> Self: """Adds a graph to the diagnostic.""" self.graphs.append(graph) return self From e24a87ed8d1fb1a50dcd29a0ae1f4767418dfd50 Mon Sep 17 00:00:00 2001 From: Aaron Gokaslan Date: Sun, 2 Jun 2024 13:38:56 +0000 Subject: [PATCH 232/706] [BE][Ez]: Apply PYI059 - Generic always come last (#127685) Generic baseclass should always be last or unexpected issues can occur, especially in non-stub files (such as with MRO). Applies autofixes from the preview PYI059 rule to fix the issues in the codebase. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127685 Approved by: https://github.com/ezyang --- test/test_datapipe.py | 2 +- torch/_inductor/utils.py | 2 +- torch/onnx/_internal/registration.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_datapipe.py b/test/test_datapipe.py index b6be7eb76b97..37cf896eda24 100644 --- a/test/test_datapipe.py +++ b/test/test_datapipe.py @@ -2423,7 +2423,7 @@ def test_batch_mapdatapipe(self): _generic_namedtuple_allowed = sys.version_info >= (3, 7) and sys.version_info < (3, 9) if _generic_namedtuple_allowed: - class InvalidData(Generic[T_co], NamedTuple): + class InvalidData(NamedTuple, Generic[T_co]): name: str data: T_co diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 7a96630ef213..0915a8330c34 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -390,7 +390,7 @@ def sort_func(elem): RV = TypeVar("RV", covariant=True) -class CachedMethod(Generic[P, RV], Protocol): +class CachedMethod(Protocol, Generic[P, RV]): @staticmethod def clear_cache(self) -> None: ... diff --git a/torch/onnx/_internal/registration.py b/torch/onnx/_internal/registration.py index 017a2fb7dadf..3b2e68e1e40a 100644 --- a/torch/onnx/_internal/registration.py +++ b/torch/onnx/_internal/registration.py @@ -61,7 +61,7 @@ def _dispatch_opset_version( _V = TypeVar("_V") -class OverrideDict(Generic[_K, _V], Collection[_K]): +class OverrideDict(Collection[_K], Generic[_K, _V]): """A dictionary that merges built-in and custom symbolic functions. It supports overriding and un-overriding built-in symbolic functions with custom From 08653fe355549aa484034efa895ef2ecdb91dd11 Mon Sep 17 00:00:00 2001 From: rzou Date: Fri, 31 May 2024 12:18:50 -0700 Subject: [PATCH 233/706] Beef up the allow_in_graph docs (#127117) We make the following changes: - most of the time when someone uses allow_in_graph, they actually wanted to make a custom op. We add a link to the custom ops landing page and explain the differences between allow_in_graph and custom ops. - we warn people against using allow_in_graph footguns and document them. Test Plan: - tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/127117 Approved by: https://github.com/jansel, https://github.com/albanD --- torch/_dynamo/decorators.py | 18 ++------- torch/compiler/__init__.py | 77 +++++++++++++++++++++++++++++++------ 2 files changed, 70 insertions(+), 25 deletions(-) diff --git a/torch/_dynamo/decorators.py b/torch/_dynamo/decorators.py index 8c82bf542169..201dbd2f1453 100644 --- a/torch/_dynamo/decorators.py +++ b/torch/_dynamo/decorators.py @@ -74,22 +74,12 @@ def assume_constant_result(fn): def allow_in_graph(fn): """ - Customize which functions TorchDynamo will include in the generated - graph. Similar to `torch.fx.wrap()`. - :: - - torch._dynamo.allow_in_graph(my_custom_function) + Tells the compiler frontend (Dynamo) to skip symbolic introspection of the function + and instead directly write it to the graph when encountered. - @torch._dynamo.optimize(...) - def fn(a): - x = torch.add(x, 1) - x = my_custom_function(x) - x = torch.add(x, 1) - return x - - fn(...) + See :func:`torch.compiler.allow_in_graph`'s docstring for the full documentation - Will capture a single graph containing `my_custom_function()`. + WARNING: this API can be a footgun, please read the documentation carefully. """ if isinstance(fn, (list, tuple)): return [allow_in_graph(x) for x in fn] diff --git a/torch/compiler/__init__.py b/torch/compiler/__init__.py index cf0b544e929a..a27238c3d833 100644 --- a/torch/compiler/__init__.py +++ b/torch/compiler/__init__.py @@ -32,22 +32,77 @@ def reset() -> None: def allow_in_graph(fn): """ - Customize which functions compilation will include in the generated graph. - It bypasses all introspection of the symbolic python code in favor of - directly writing it to the graph. - If fn is a list or tuple of callables it recursively applies :func:`allow_in_graph()` - to each function and returns a new list or tuple containing the modified functions + Tells the compiler frontend (Dynamo) to skip symbolic introspection of the function + and instead directly write it to the graph when encountered. + + If you are using :func:`torch.compile` (with backend="inductor" (the default)), or + :func:`torch.export.export`, and trying to black-box a Python function throughout + all tracing, do not use this API. + Instead, please create a custom operator (see :ref:`custom-ops-landing-page`) + + .. warning:: + + If you're a typical torch.compile user (e.g. you're applying torch.compile to + a model to make it run faster), you probably don't want to use this function. + :func:`allow_in_graph` is a footgun because it skips the compiler frontend + (Dynamo) that is responsible for doing safety checks (graph breaks, handling + closures, etc). Incorrect usage will lead to difficult-to-debug silent + incorrectness issues. + + Given a Python function with no allow_in_graph decorator, regular execution + of torch.compile traces through the function. :func:`allow_in_graph` changes + it so that the frontend does not trace inside the function, but the compiler + backend still traces through it. Compare this to custom operators, which + treats a function as a black box throughout the torch.compile stack. The following + table compares these mechanisms. + + +------------------------+-----------------------+--------------------------------+ + | Mechanism | Frontend (Dynamo) | Backend (AOTAutograd+Inductor) | + +========================+=======================+================================+ + | no decorator | trace inside | trace inside | + +------------------------+-----------------------+--------------------------------+ + | allow_in_graph | opaque callable | trace inside | + +------------------------+-----------------------+--------------------------------+ + | custom op | opaque callable | opaque callable | + +------------------------+-----------------------+--------------------------------+ + + One common use case for :func:`allow_in_graph()` is as an escape hatch for the compiler + frontend: if you know the function works w.r.t. to the downstream components of the + compilation stack (AOTAutograd and Inductor) but there is a Dynamo bug that prevents it from + symbolically introspecting the function properly (or if your code is in C/C++ and + therefore cannot be introspected with Dynamo), then one can decorate said function + with :func:`allow_in_graph` to bypass Dynamo. + + We require that ``fn`` adhere to the following restrictions. Failure to adhere + results in undefined behavior: + + - The inputs to ``fn`` must be Proxy-able types in the FX graph. Valid types include: + Tensor/int/bool/float/None/List[Tensor?]/List[int?]/List[float?] + Tuple[Tensor?, ...]/Tuple[int?, ...]/Tuple[float?, ...]/torch.dtype/torch.device + - The outputs to ``fn`` must be Proxy-able types in the FX graph (see previous bullet) + - all Tensors used inside of ``fn`` must be passed directly as inputs to ``fn`` + (as opposed to being captured variables). Args: fn: A callable representing the function to be included in the graph. + If ``fn`` is a list or tuple of callables it recursively applies + :func:`allow_in_graph()` to each function and returns a new list or + tuple containing the modified functions. - .. warning:: + Example:: + + torch.compiler.allow_in_graph(my_custom_function) + + @torch.compile(...) + def fn(a): + x = torch.add(x, 1) + x = my_custom_function(x) + x = torch.add(x, 1) + return x + + fn(...) - :func:`allow_in_graph` skips TorchDynamo completely on the decorated function - skipping all TorchDynamo safety checks (graph breaks, handling closures, etc). - Therefore, one has to be very careful with :func:`allow_in_graph` since subsystems - like AOT Autograd rely on torchdynamo - If not careful, this could lead to soundness and really hard-to-debug issues. + Will capture a single graph containing ``my_custom_function()``. """ import torch._dynamo From fb53cd64973167a3f4161d5d48fb11a022bf43f0 Mon Sep 17 00:00:00 2001 From: Kiuk Chung Date: Sun, 2 Jun 2024 16:25:02 +0000 Subject: [PATCH 234/706] =?UTF-8?q?[aten=5Fcuda/flash=5Fattn]=20Add=20type?= =?UTF-8?q?name=20to=20template=20argument=20Kernel=5Ftrait=E2=80=A6=20(#1?= =?UTF-8?q?27634)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds the `typename` keyword to the template argument `Kernel_traits::TiledMma` and `Kernel_traits::TiledMmaSdP` (which are dependent type names) when calling the template function `pytorch_flash::convert_layout_acc_Aregs`. Without `typename` flash_attention kernels do not compile with Clang under C++20 since Clang compiles the entire .cu file in a single pass as opposed to NVCC which split compiles the host and device code. Adding `typename` seems to be OK under NVCC based on CI cuda builds succeeding. Below is the excerpt of the compilation error: ``` third_party/py/torch/aten/src/ATen/native/transformers/cuda/flash_attn/static_switch.h:46:24: note: expanded from macro 'ALIBI_SWITCH' 46 | #define ALIBI_SWITCH BOOL_SWITCH | ^ third_party/py/torch/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h:132:5: note: in instantiation of function template specialization 'pytorch_flash::run_flash_bwd_seqk_parallel, true>' requested here 132 | run_flash_bwd_seqk_parallel(params, stream); | ^ third_party/py/torch/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h:280:13: note: in instantiation of function template specialization 'pytorch_flash::run_flash_bwd, true>' requested here 280 | run_flash_bwd, Is_dropout>(params, stream); | ^ third_party/py/torch/aten/src/ATen/native/transformers/cuda/flash_attn/static_switch.h:36:26: note: expanded from macro 'DROPOUT_SWITCH' 36 | #define DROPOUT_SWITCH BOOL_SWITCH | ^ third_party/py/torch/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim160_fp16_sm80.cu:12:5: note: in instantiation of function template specialization 'pytorch_flash::run_mha_bwd_hdim160' request ed here 12 | run_mha_bwd_hdim160(params, stream); | ^ In file included from third_party/py/torch/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim160_fp16_sm80.cu:7: In file included from third_party/py/torch/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h:12: third_party/py/torch/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_kernel.h:543:86: error: missing 'typename' prior to dependent type name 'Flash_bwd_kernel_traits<160, 64, 64, 8, 4, 4, 4, false, true>::TiledMmaSdP' 543 | Tensor tPrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs(rP.layout())); ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/127634 Approved by: https://github.com/Skylion007 --- .../transformers/cuda/flash_attn/flash_bwd_kernel.h | 2 +- .../transformers/cuda/flash_attn/flash_fwd_kernel.h | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_kernel.h b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_kernel.h index db817a0657ff..5089fb2e294f 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_kernel.h +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_kernel.h @@ -540,7 +540,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in : pytorch_flash::convert_type_relu(acc_s); // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_N, MMA_N / 2) // if using m16n8k16 or (4, MMA_N, MMA_N) if using m16n8k8. - Tensor tPrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs(rP.layout())); + Tensor tPrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs(rP.layout())); Tensor tPaP = smem_thr_copy_PdS.retile_S(tPrP); // ((Atom,AtomNum), MMA_N, MMA_N) cute::copy(smem_tiled_copy_PdS, tPaP, tPsP); // if (cute::thread0()) { print(tPaP); } diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_fwd_kernel.h b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_fwd_kernel.h index 0386a07cc64f..9d97abb5eb90 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_fwd_kernel.h +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_fwd_kernel.h @@ -339,7 +339,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. - Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs(rP.layout())); + Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs(rP.layout())); // if (cute::thread0()) { print(tOrP); } pytorch_flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); // if (cute::thread0()) { print(scores); } @@ -402,7 +402,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. - Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs(rP.layout())); + Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs(rP.layout())); pytorch_flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); } @@ -895,7 +895,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons Tensor rP = pytorch_flash::convert_type(acc_s); // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. - Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs(rP.layout())); + Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs(rP.layout())); pytorch_flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); @@ -957,7 +957,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons Tensor rP = pytorch_flash::convert_type(acc_s); // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. - Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs(rP.layout())); + Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs(rP.layout())); pytorch_flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); } From 30213ab0a7b27277e76ea9dd707ce629a63d91ee Mon Sep 17 00:00:00 2001 From: Aaron Gokaslan Date: Sun, 2 Jun 2024 21:07:23 +0000 Subject: [PATCH 235/706] [BE]: Update mypy to 1.10.0 (#127717) Updates mypy to the latest and greatest. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127717 Approved by: https://github.com/ezyang --- .ci/docker/requirements-ci.txt | 4 ++-- .lintrunner.toml | 2 +- torch/ao/quantization/fx/fuse_handler.py | 2 +- torch/distributed/_composable/fsdp/fully_shard.py | 6 +++++- torch/distributions/utils.py | 2 +- torch/fx/experimental/rewriter.py | 2 +- 6 files changed, 11 insertions(+), 7 deletions(-) diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index 0f5f1bb12bd5..e6866a94183a 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -85,10 +85,10 @@ librosa>=0.6.2 ; python_version < "3.11" #Pinned versions: #test that import: -mypy==1.9.0 +mypy==1.10.0 # Pin MyPy version because new errors are likely to appear with each release #Description: linter -#Pinned versions: 1.9.0 +#Pinned versions: 1.10.0 #test that import: test_typing.py, test_type_hints.py networkx==2.8.8 diff --git a/.lintrunner.toml b/.lintrunner.toml index eca2af96b761..e31279004348 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -136,7 +136,7 @@ init_command = [ 'numpy==1.24.3 ; python_version == "3.8"', 'numpy==1.26.0 ; python_version >= "3.9"', 'expecttest==0.1.6', - 'mypy==1.9.0', + 'mypy==1.10.0', 'sympy==1.11.1', 'types-requests==2.27.25', 'types-PyYAML==6.0.7', diff --git a/torch/ao/quantization/fx/fuse_handler.py b/torch/ao/quantization/fx/fuse_handler.py index 718cc561bfa0..123a51f6ff87 100644 --- a/torch/ao/quantization/fx/fuse_handler.py +++ b/torch/ao/quantization/fx/fuse_handler.py @@ -44,7 +44,7 @@ class DefaultFuseHandler(FuseHandler): def __init__( self, node: Node): - super().__init__(node) + super().__init__(node) # type:ignore[safe-super] def fuse(self, load_arg: Callable, diff --git a/torch/distributed/_composable/fsdp/fully_shard.py b/torch/distributed/_composable/fsdp/fully_shard.py index 3efb8f7afd85..18967fe41468 100644 --- a/torch/distributed/_composable/fsdp/fully_shard.py +++ b/torch/distributed/_composable/fsdp/fully_shard.py @@ -342,4 +342,8 @@ def wrapped_method(self, *args, **kwargs): return fsdp_state._post_forward(self, args, out) # Use `__get__` to make `wrapped_method` an instance method - setattr(module, method_name, wrapped_method.__get__(module, type(module))) + setattr( + module, + method_name, + wrapped_method.__get__(module, type(module)), # type:ignore[attr-defined] + ) diff --git a/torch/distributions/utils.py b/torch/distributions/utils.py index 7a6d31a05722..f897e63b7891 100644 --- a/torch/distributions/utils.py +++ b/torch/distributions/utils.py @@ -117,7 +117,7 @@ class lazy_property: def __init__(self, wrapped): self.wrapped = wrapped - update_wrapper(self, wrapped) + update_wrapper(self, wrapped) # type:ignore[arg-type] def __get__(self, instance, obj_type=None): if instance is None: diff --git a/torch/fx/experimental/rewriter.py b/torch/fx/experimental/rewriter.py index 85a95895f7c9..969717e01030 100644 --- a/torch/fx/experimental/rewriter.py +++ b/torch/fx/experimental/rewriter.py @@ -60,7 +60,7 @@ def change_func_globals(f, globals): closure=f.__closure__, ) g = functools.update_wrapper(g, f) - g.__kwdefaults__ = copy.copy(f.__kwdefaults__) + g.__kwdefaults__ = copy.copy(f.__kwdefaults__) # type:ignore[attr-defined] return g # Return the correct FunctionType object return change_func_globals(fn_compiled, globals=fn.__globals__) From 139b9c6529ba1106e28c345d76e9602dd3f6a6ab Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Sun, 2 Jun 2024 07:01:03 -0700 Subject: [PATCH 236/706] Avoid reference cycle in inner closure (#127711) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/127711 Approved by: https://github.com/Skylion007, https://github.com/izaitsevfb --- torch/_subclasses/meta_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py index 08c4c9ce4e3f..3a3a6eda012f 100644 --- a/torch/_subclasses/meta_utils.py +++ b/torch/_subclasses/meta_utils.py @@ -532,7 +532,7 @@ def json(k, v): # fields (feel free to add other special cases as appropriate) if k in ["data", "autograd_meta_from"]: return None # never repr these - if k in set(self._UNSERIALIZABLE): + if k in set(MetaTensorDesc._UNSERIALIZABLE): return repr(v) if isinstance(v, (torch.device, torch.dtype, torch.layout)): return repr(v) From 8b08b0f340a7600b7945eba1d1ce64eecc9488bb Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Sun, 2 Jun 2024 23:25:26 +0000 Subject: [PATCH 237/706] [BE] enable ruff rule `Q` from flake8-quotes (#127713) Enable [ruff rule `Q`](https://docs.astral.sh/ruff/rules/#flake8-quotes-q) from flake8-quotes. Fixes: - [avoidable-escaped-quote (Q003)](https://docs.astral.sh/ruff/rules/avoidable-escaped-quote/#avoidable-escaped-quote-q003) - [unnecessary-escaped-quote (Q004)](https://docs.astral.sh/ruff/rules/unnecessary-escaped-quote/#unnecessary-escaped-quote-q004) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127713 Approved by: https://github.com/ezyang --- pyproject.toml | 2 ++ test/test_jit.py | 14 +++++++------- test/test_linalg.py | 8 ++++---- test/test_mobile_optimizer.py | 8 ++++---- test/test_static_runtime.py | 4 ++-- torch/_custom_op/impl.py | 2 +- torch/ao/quantization/fx/utils.py | 2 +- torch/distributed/tensor/parallel/_utils.py | 2 +- torch/fx/experimental/symbolic_shapes.py | 10 +++++----- torch/fx/graph.py | 2 +- torch/fx/subgraph_rewriter.py | 4 ++-- torch/hub.py | 2 +- torch/library.py | 6 +++--- torch/nn/modules/module.py | 2 +- torch/serialization.py | 2 +- torch/testing/_internal/common_utils.py | 6 +++--- .../testing/_internal/distributed/rpc/rpc_test.py | 2 +- .../utils/benchmark/examples/blas_compare_setup.py | 2 +- torch/utils/data/datapipes/gen_pyi.py | 12 ++++++------ torch/utils/hipify/hipify_python.py | 2 +- 20 files changed, 48 insertions(+), 46 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 01fce4a7d6fb..aa532c59da3c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -124,6 +124,8 @@ select = [ "PT025", "PT026", "PYI", + "Q003", # avoidable escaped quote + "Q004", # unnecessary escaped quote "RSE", "RUF008", # mutable dataclass default "RUF015", # access first ele in constant time diff --git a/test/test_jit.py b/test/test_jit.py index bb6f4e255888..0e99c3602cd6 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -4329,7 +4329,7 @@ def foobar(xyz): return torch.blargh(xyz) _, lineno = inspect.getsourcelines(foobar) - with self.assertRaisesRegex(RuntimeError, f"test_jit.py\", line {lineno + 1}"): + with self.assertRaisesRegex(RuntimeError, f'test_jit.py", line {lineno + 1}'): scripted = torch.jit.script(foobar) def test_file_line_error_class_defn(self): @@ -4338,7 +4338,7 @@ def baz(self, xyz): return torch.blargh(xyz) _, lineno = inspect.getsourcelines(FooBar) - with self.assertRaisesRegex(RuntimeError, f"test_jit.py\", line {lineno + 2}"): + with self.assertRaisesRegex(RuntimeError, f'test_jit.py", line {lineno + 2}'): torch.jit.script(FooBar) def test_file_line_graph(self): @@ -4405,7 +4405,7 @@ def forward(self, x, w): loaded = self.getExportImportCopy(ft) _, lineno = inspect.getsourcelines(FooTest) - with self.assertRaisesRegex(RuntimeError, f'test_jit.py\", line {lineno + 3}'): + with self.assertRaisesRegex(RuntimeError, f'test_jit.py", line {lineno + 3}'): loaded(torch.rand(3, 4), torch.rand(30, 40)) def test_serialized_source_ranges_graph(self): @@ -4431,7 +4431,7 @@ def forward(self): _, lineno = inspect.getsourcelines(FooTest2) - with self.assertRaisesRegex(torch.jit.Error, f'test_jit.py\", line {lineno + 3}'): + with self.assertRaisesRegex(torch.jit.Error, f'test_jit.py", line {lineno + 3}'): ft = FooTest2() loaded = self.getExportImportCopy(ft) loaded() @@ -10260,7 +10260,7 @@ def fn(x): n = next(graph.inputs()) self.assertTrue(n.type() == torch._C.TensorType.getInferred()) - with self.assertRaisesRegex(RuntimeError, "Inferred \'x\' to be of type \'Tensor"): + with self.assertRaisesRegex(RuntimeError, "Inferred 'x' to be of type 'Tensor"): fn("1") def test_script_define_order(self): @@ -12309,7 +12309,7 @@ def forward(self, x): tm = torch.jit.trace(TracedModule(), torch.rand(3, 4)) FileCheck().check_not("value=").check("aten::mm")\ - .check("prim::CallMethod[name=\"forward\"]").check("aten::add") \ + .check('prim::CallMethod[name="forward"]').check("aten::add") \ .run(str(tm.graph)) FileCheck().check("aten::mm").run(str(tm.mod.graph)) @@ -14743,7 +14743,7 @@ def forward(self): return self.hello("hi"), self.hello(.5) w = CompileOverloadError() - with self.assertRaisesRegex(Exception, "but instead found type \'str\'"): + with self.assertRaisesRegex(Exception, "but instead found type 'str'"): torch.jit.script(w) # testing overload declared first, then non-overload diff --git a/test/test_linalg.py b/test/test_linalg.py index 5963b1e50448..040b86e60d60 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -963,9 +963,9 @@ def test_eigh_errors_and_warnings(self, device, dtype): # eigh requires 'uplo' parameter to be 'U' or 'L' t = torch.randn(3, 3, device=device, dtype=dtype) for uplo in ["a", "wrong"]: - with self.assertRaisesRegex(RuntimeError, "be \'L\' or \'U\'"): + with self.assertRaisesRegex(RuntimeError, "be 'L' or 'U'"): torch.linalg.eigh(t, UPLO=uplo) - with self.assertRaisesRegex(ValueError, "be \'L\' or \'U\'"): + with self.assertRaisesRegex(ValueError, "be 'L' or 'U'"): np.linalg.eigh(t.cpu().numpy(), UPLO=uplo) # if non-empty out tensor with wrong shape is passed a warning is given @@ -1062,9 +1062,9 @@ def test_eigvalsh_errors_and_warnings(self, device, dtype): # eigvalsh requires 'uplo' parameter to be 'U' or 'L' t = torch.randn(3, 3, device=device, dtype=dtype) for uplo in ["a", "wrong"]: - with self.assertRaisesRegex(RuntimeError, "be \'L\' or \'U\'"): + with self.assertRaisesRegex(RuntimeError, "be 'L' or 'U'"): torch.linalg.eigvalsh(t, UPLO=uplo) - with self.assertRaisesRegex(ValueError, "be \'L\' or \'U\'"): + with self.assertRaisesRegex(ValueError, "be 'L' or 'U'"): np.linalg.eigvalsh(t.cpu().numpy(), UPLO=uplo) # if non-empty out tensor with wrong shape is passed a warning is given diff --git a/test/test_mobile_optimizer.py b/test/test_mobile_optimizer.py index e672d69ab5dd..28113d0bdf08 100644 --- a/test/test_mobile_optimizer.py +++ b/test/test_mobile_optimizer.py @@ -149,7 +149,7 @@ def forward(self, x): bn_scripted_module.eval() self.assertEqual(len(torch.jit.export_opnames(bn_scripted_module)), 11) - FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \ + FileCheck().check_count('prim::CallMethod[name="forward"]', 2, exactly=True) \ .run(str(get_forward(bn_scripted_module._c).graph)) optimization_blocklist_no_prepack = {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS} @@ -250,7 +250,7 @@ def foo(self, x): bn_no_forward_scripted_module.eval() self.assertEqual(len(torch.jit.export_opnames(bn_no_forward_scripted_module)), 11) - FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \ + FileCheck().check_count('prim::CallMethod[name="forward"]', 2, exactly=True) \ .run(bn_no_forward_scripted_module.foo.graph) bn_fold_no_forward_scripted_module = optimize_for_mobile(bn_no_forward_scripted_module, preserved_methods=['foo']) @@ -471,7 +471,7 @@ def _quant_script_and_optimize(model): # basic case m, m_optim = _quant_script_and_optimize(Standalone()) - FileCheck().check_not("Conv2d = prim::GetAttr[name=\"conv1\"]") \ + FileCheck().check_not('Conv2d = prim::GetAttr[name="conv1"]') \ .check_count("__torch__.torch.classes.quantized.Conv2dPackedParamsBase = prim::Constant", 2, exactly=True) \ .run(m_optim.graph) self.assertFalse(hasattr(m_optim, "conv1")) @@ -485,7 +485,7 @@ def _quant_script_and_optimize(model): # generic case m, m_optim = _quant_script_and_optimize(Parent()) - FileCheck().check_not("Conv2d = prim::GetAttr[name=\"conv1\"]") \ + FileCheck().check_not('Conv2d = prim::GetAttr[name="conv1"]') \ .check_count("__torch__.torch.classes.quantized.Conv2dPackedParamsBase = prim::Constant", 2, exactly=True) \ .run(m_optim.graph) self.assertFalse(hasattr(m_optim, "conv1")) diff --git a/test/test_static_runtime.py b/test/test_static_runtime.py index 434793508d47..863f3c37c217 100644 --- a/test/test_static_runtime.py +++ b/test/test_static_runtime.py @@ -330,7 +330,7 @@ def test_fork_wait_exception(self): raise RuntimeError( "Tried execution of add.Tensors with incompatible shape. " "Exception raised by forked runtime execution does " - f"not contain expected substring: \"{expected_error_msg}\"" + f'not contain expected substring: "{expected_error_msg}"' ) from error """ @@ -360,7 +360,7 @@ def test_fork_wait_exception_async(self): raise RuntimeError( "Tried execution of add.Tensors with incompatible shape. " "Exception raised by forked runtime execution does " - f"not contain expected substring: \"{expected_error_msg}\"" + f'not contain expected substring: "{expected_error_msg}"' ) from error def test_multihead_attention_layer(self): diff --git a/torch/_custom_op/impl.py b/torch/_custom_op/impl.py index df83c51bcfd9..d9200160057c 100644 --- a/torch/_custom_op/impl.py +++ b/torch/_custom_op/impl.py @@ -836,7 +836,7 @@ def _find_custom_op(qualname, also_check_torch_library=False): return global_registry[qualname] if not also_check_torch_library: raise RuntimeError( - f"Could not find custom op \"{qualname}\". Did you register it via " + f'Could not find custom op "{qualname}". Did you register it via ' f"the torch._custom_ops API?") overload = get_op(qualname) result = custom_op_from_existing(overload) diff --git a/torch/ao/quantization/fx/utils.py b/torch/ao/quantization/fx/utils.py index be26332b2485..5cfedde4bc24 100644 --- a/torch/ao/quantization/fx/utils.py +++ b/torch/ao/quantization/fx/utils.py @@ -837,7 +837,7 @@ def _activation_post_process_satisfies_dtype_config_constraints( suggestion_str = ( "Please use torch.ao.quantization.get_default_qconfig_mapping or " "torch.ao.quantization.get_default_qat_qconfig_mapping. Example:\n" - " qconfig_mapping = get_default_qconfig_mapping(\"fbgemm\")\n" + ' qconfig_mapping = get_default_qconfig_mapping("fbgemm")\n' " model = prepare_fx(model, qconfig_mapping, example_inputs)" ) if not isinstance(activation_post_process, FixedQParamsObserver) and \ diff --git a/torch/distributed/tensor/parallel/_utils.py b/torch/distributed/tensor/parallel/_utils.py index e109f5e9af93..3c7e269fffea 100644 --- a/torch/distributed/tensor/parallel/_utils.py +++ b/torch/distributed/tensor/parallel/_utils.py @@ -46,7 +46,7 @@ def _validate_tp_mesh_dim( """ if device_mesh.ndim > 1: raise ValueError(f"Tensor Parallel only accepts a 1D DeviceMesh, but found {device_mesh.ndim}D!" - "If you have a 2-D or N-D device_mesh, consider passing in device_mesh[\"tp\"]") + 'If you have a 2-D or N-D device_mesh, consider passing in device_mesh["tp"]') parent_mesh = _mesh_resources.get_parent_mesh(device_mesh) if parent_mesh: diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index d94068b770d3..a2abde3a861e 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -2199,7 +2199,7 @@ def relation_with_digit(expr, op, digit): buf += ( f"Specializations unexpectedly required ({', '.join(sorted(debug_names))})! " - "For more information, run with TORCH_LOGS=\"+dynamic\".\n" + 'For more information, run with TORCH_LOGS="+dynamic".\n' ) for s, val in forced_specializations.items(): buf += f" - {s} must be specialized to {val} because the guards generated for it are too complex.\n" @@ -3539,7 +3539,7 @@ def create_symbol( if not is_debug: maybe_more_info = ( ", for more info run with " - f"TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL=\"{sympy_expr}\"" + f'TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="{sympy_expr}"' ) fsummary, maybe_user_loc, maybe_extra_debug = self._get_stack_summary(is_debug) self.log.info( @@ -4156,7 +4156,7 @@ def issue_guard(guard: ShapeGuard) -> None: err = '\n'.join(error_msgs) raise ConstraintViolationError( f"Constraints violated ({debug_names})! " - "For more information, run with TORCH_LOGS=\"+dynamic\".\n" + 'For more information, run with TORCH_LOGS="+dynamic".\n' f"{err}" ) elif len(warn_msgs) > 0: @@ -4609,7 +4609,7 @@ def _make_data_dependent_error(self, expr, unhinted_expr, *, size_oblivious_resu f"{size_oblivious_result_msg}" "Potential framework code culprit (scroll up for full backtrace):\n" f"{''.join(traceback.StackSummary.from_list([fsummary]).format())}\n" - "For more information, run with TORCH_LOGS=\"dynamic\"\n" + 'For more information, run with TORCH_LOGS="dynamic"\n' "For extended logs when we create symbols, also add " f"TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL=\"{','.join(map(str, expr.free_symbols))}\"\n" "If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\n" @@ -5010,7 +5010,7 @@ def _log_guard(self, prefix: str, g, forcing_spec: bool): if not is_debug: maybe_more_info = ( ", for more info run with " - f"TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED=\"{str_g}\"" + f'TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="{str_g}"' ) self.log.info( "%s %s [guard added]%s (%s)%s%s", diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 96b186cc6c48..7c73c89473d5 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -1504,7 +1504,7 @@ def check_arg(arg : Node, n : Optional[Node] = None) -> None: if node.graph is not self: raise RuntimeError(f'Node \'{node}\' does not belong to this Graph!') if node not in self._find_nodes_lookup_table: - raise RuntimeError(f"Node \'{node}\' is not added to the side table") + raise RuntimeError(f"Node '{node}' is not added to the side table") map_arg(node.args, lambda arg: check_arg(arg, node)) map_arg(node.kwargs, lambda arg: check_arg(arg, node)) seen_values.add(node) diff --git a/torch/fx/subgraph_rewriter.py b/torch/fx/subgraph_rewriter.py index d0bb4b55a403..3106daca0b18 100644 --- a/torch/fx/subgraph_rewriter.py +++ b/torch/fx/subgraph_rewriter.py @@ -70,8 +70,8 @@ def try_get_attr(gm: torch.nn.Module, target: str) -> Optional[Any]: # CASE 3: The target doesn't exist as an attribute in `gm` # or `replacement` else: - raise RuntimeError("Attempted to create a \"", node.op, - "\" node during subgraph rewriting " + raise RuntimeError('Attempted to create a "', node.op, + '" node during subgraph rewriting ' f"with target {node.target}, but " "the referenced attribute does not " "exist in the replacement GraphModule") diff --git a/torch/hub.py b/torch/hub.py index 4ea92ed6be82..0ba9e25a2830 100644 --- a/torch/hub.py +++ b/torch/hub.py @@ -234,7 +234,7 @@ def _get_cache_or_reload(github, force_reload, trust_repo, calling_fn, verbose=T try: url = _git_archive_link(repo_owner, repo_name, ref) - sys.stderr.write(f'Downloading: \"{url}\" to {cached_file}\n') + sys.stderr.write(f'Downloading: "{url}" to {cached_file}\n') download_url_to_file(url, cached_file, progress=False) except HTTPError as err: if err.code == 300: diff --git a/torch/library.py b/torch/library.py index a69e16950f7e..3bd0a1b6bc8a 100644 --- a/torch/library.py +++ b/torch/library.py @@ -350,8 +350,8 @@ def define(qualname, schema, *, lib=None, tags=()): if not NAMELESS_SCHEMA.fullmatch(schema): raise ValueError( f"define(qualname, schema, ...): expected schema " - f"to look like e.g. \"(Tensor x) -> Tensor\" but " - f"got \"{schema}\"") + f'to look like e.g. "(Tensor x) -> Tensor" but ' + f'got "{schema}"') lib.define(name + schema, alias_analysis="", tags=tags) @@ -782,7 +782,7 @@ def inner(*args, **kwargs): raise RuntimeError( f"Operator '{qualname}' was defined in C++ and has a Python " f"fake impl. In this situation, we require there to also be a " - f"companion C++ `m.set_python_module(\"{actual_module_name}\")` " + f'companion C++ `m.set_python_module("{actual_module_name}")` ' f"call, but we could not find one. Please add that to " f"to the top of the C++ TORCH_LIBRARY({namespace}, ...) block the " f"operator was registered in ({cpp_filename})") diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index dd6d64b68c23..ffd429cc06f2 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -348,7 +348,7 @@ def _forward_unimplemented(self, *input: Any) -> None: instead of this since the former takes care of running the registered hooks while the latter silently ignores them. """ - raise NotImplementedError(f"Module [{type(self).__name__}] is missing the required \"forward\" function") + raise NotImplementedError(f'Module [{type(self).__name__}] is missing the required "forward" function') class Module: diff --git a/torch/serialization.py b/torch/serialization.py index d47a49ddf0fd..9401c775a510 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -1267,7 +1267,7 @@ def persistent_load(saved_id): if not hasattr(f, 'readinto') and (3, 8, 0) <= sys.version_info < (3, 8, 2): raise RuntimeError( "torch.load does not work with file-like objects that do not implement readinto on Python 3.8.0 and 3.8.1. " - f"Received object of type \"{type(f)}\". Please update to Python 3.8.2 or newer to restore this " + f'Received object of type "{type(f)}". Please update to Python 3.8.2 or newer to restore this ' "functionality.") magic_number = pickle_module.load(f, **pickle_load_args) diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 2237ec67c500..e748ff0388fb 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -2284,7 +2284,7 @@ def matches_test(target: str): print(f"Test {disabled_test} is disabled for some unrecognized ", f"platforms: [{invalid_plats_str}]. Please edit issue {issue_url} to fix the platforms ", - "assigned to this flaky test, changing \"Platforms: ...\" to a comma separated ", + 'assigned to this flaky test, changing "Platforms: ..." to a comma separated ', f"subset of the following (or leave it blank to match all platforms): {valid_plats}") # Sanitize the platforms list so that we continue to disable the test for any valid platforms given @@ -4401,8 +4401,8 @@ def check_test_defined_in_running_script(test_case): if running_script_path is None: return test_case_class_file = os.path.abspath(os.path.realpath(inspect.getfile(test_case.__class__))) - assert test_case_class_file == running_script_path, f"Class of loaded TestCase \"{test_case.id()}\" " \ - f"is not defined in the running script \"{running_script_path}\", but in \"{test_case_class_file}\". Did you " \ + assert test_case_class_file == running_script_path, f'Class of loaded TestCase "{test_case.id()}" ' \ + f'is not defined in the running script "{running_script_path}", but in "{test_case_class_file}". Did you ' \ "accidentally import a unittest.TestCase from another file?" def load_tests(loader, tests, pattern): diff --git a/torch/testing/_internal/distributed/rpc/rpc_test.py b/torch/testing/_internal/distributed/rpc/rpc_test.py index 5d2a67cd473a..764198338636 100644 --- a/torch/testing/_internal/distributed/rpc/rpc_test.py +++ b/torch/testing/_internal/distributed/rpc/rpc_test.py @@ -1232,7 +1232,7 @@ def test_self_remote_rref_as_self_remote_arg(self): def test_rref_proxy_non_exist(self): dst = worker_name((self.rank + 1) % self.world_size) rref = rpc.remote(dst, my_function, args=(torch.ones(2, 2), 1, 3)) - msg = "has no attribute \'non_exist\'" + msg = "has no attribute 'non_exist'" with self.assertRaisesRegex(AttributeError, msg): rref.rpc_sync().non_exist() diff --git a/torch/utils/benchmark/examples/blas_compare_setup.py b/torch/utils/benchmark/examples/blas_compare_setup.py index c08acb50950f..44038539cae0 100644 --- a/torch/utils/benchmark/examples/blas_compare_setup.py +++ b/torch/utils/benchmark/examples/blas_compare_setup.py @@ -183,7 +183,7 @@ def main(): check_run = subprocess.run( # Shameless abuse of `python -c ...` f"source activate {env_path} && " - "python -c \"" + 'python -c "' "import torch;" "from torch.utils.benchmark import Timer;" "print(torch.__config__.show());" diff --git a/torch/utils/data/datapipes/gen_pyi.py b/torch/utils/data/datapipes/gen_pyi.py index c0f8a801bd07..2729c6296c08 100644 --- a/torch/utils/data/datapipes/gen_pyi.py +++ b/torch/utils/data/datapipes/gen_pyi.py @@ -44,10 +44,10 @@ def find_file_paths(dir_paths: List[str], files_to_exclude: Set[str]) -> Set[str def extract_method_name(line: str) -> str: """Extract method name from decorator in the form of "@functional_datapipe({method_name})".""" - if "(\"" in line: - start_token, end_token = "(\"", "\")" - elif "(\'" in line: - start_token, end_token = "(\'", "\')" + if '("' in line: + start_token, end_token = '("', '")' + elif "('" in line: + start_token, end_token = "('", "')" else: raise RuntimeError(f"Unable to find appropriate method name within line:\n{line}") start, end = line.find(start_token) + len(start_token), line.find(end_token) @@ -71,9 +71,9 @@ def parse_datapipe_file(file_path: str) -> Tuple[Dict[str, str], Dict[str, str], method_name, class_name, signature = "", "", "" skip = False for line in f: - if line.count("\"\"\"") % 2 == 1: + if line.count('"""') % 2 == 1: skip = not skip - if skip or "\"\"\"" in line: # Saving docstrings + if skip or '"""' in line: # Saving docstrings doc_string_dict[method_name].append(line) continue if "@functional_datapipe" in line: diff --git a/torch/utils/hipify/hipify_python.py b/torch/utils/hipify/hipify_python.py index 39e7070144aa..59ee1b2f4743 100755 --- a/torch/utils/hipify/hipify_python.py +++ b/torch/utils/hipify/hipify_python.py @@ -973,7 +973,7 @@ def repl(m): hipify_result.current_state = CurrentState.DONE return hipify_result except PermissionError as e: - print(f"{bcolors.WARNING}Failed to save {fout_path} with \"{e.strerror}\", leaving {fin_path} unchanged.{bcolors.ENDC}", + print(f'{bcolors.WARNING}Failed to save {fout_path} with "{e.strerror}", leaving {fin_path} unchanged.{bcolors.ENDC}', file=sys.stderr) hipify_result.hipified_path = fin_path hipify_result.status = "[skipped, no permissions]" From fec8ef8c179220de8251a0abbc42239c2743f157 Mon Sep 17 00:00:00 2001 From: Yash Rathore Date: Sun, 2 Jun 2024 23:41:43 +0000 Subject: [PATCH 238/706] [Aten][BlasKernel] Add function prototype to fix compiler error (#127719) Adds a prototype for function `fp16_dot_with_fp32_arith()` in `aten/src/ATen/native/BlasKernel.cpp`. Without this patch the build fails on Apple silicon/MacOs (CPU) with the error `no previous prototype for function 'fp16_dot_with_fp32_arith' [-Werror,-Wmissing-prototypes]`. The function cannot be marked `static` because its use is not limited to this file. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127719 Approved by: https://github.com/Skylion007 --- aten/src/ATen/native/BlasKernel.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/aten/src/ATen/native/BlasKernel.cpp b/aten/src/ATen/native/BlasKernel.cpp index 79076544fc33..642467e5c1e6 100644 --- a/aten/src/ATen/native/BlasKernel.cpp +++ b/aten/src/ATen/native/BlasKernel.cpp @@ -105,6 +105,11 @@ void fp16_gemv_trans( const float beta, float16_t* y, const int incy); + +float fp16_dot_with_fp32_arith( + const float16_t* vec1, + const float16_t* vec2, + int64_t len); #endif template From e57f51b80fe1457fcc1359ce92a64ed425c3ab48 Mon Sep 17 00:00:00 2001 From: bigning Date: Mon, 3 Jun 2024 01:55:03 +0000 Subject: [PATCH 239/706] Update _dedup_save_plans.py (#126569) To resolve https://github.com/pytorch/pytorch/issues/125740, save each tensor on the lowest rank. Fixes #125740 Pull Request resolved: https://github.com/pytorch/pytorch/pull/126569 Approved by: https://github.com/LucasLLC --- torch/distributed/checkpoint/_dedup_save_plans.py | 12 ++++++++++-- torch/distributed/checkpoint/default_planner.py | 5 +++-- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/torch/distributed/checkpoint/_dedup_save_plans.py b/torch/distributed/checkpoint/_dedup_save_plans.py index 2160c7dc366d..16d46e73baff 100644 --- a/torch/distributed/checkpoint/_dedup_save_plans.py +++ b/torch/distributed/checkpoint/_dedup_save_plans.py @@ -11,7 +11,10 @@ __all__ = ["dedup_save_plans"] -def dedup_save_plans(all_plans: List[SavePlan]) -> List[SavePlan]: +def dedup_save_plans( + all_plans: List[SavePlan], + save_to_lowest_rank: bool = False, +) -> List[SavePlan]: """ Removes duplicate entries from appearing on multiple SavePlans. For each duplicate across a set of SavePlans, only the smallest SavePlan in terms of planned storage keeps the entry. @@ -29,7 +32,12 @@ def dedup_save_plans(all_plans: List[SavePlan]) -> List[SavePlan]: to_remove: List[Set] = [set() for _ in range(len(all_plans))] plan_to_size = [0] * len(all_plans) for write_item_idx, plan_indices in write_item_to_plan_indices.items(): - select_plan_idx = min(plan_indices, key=lambda plan_idx: plan_to_size[plan_idx]) + if save_to_lowest_rank: + select_plan_idx = min(plan_indices) + else: + select_plan_idx = min( + plan_indices, key=lambda plan_idx: plan_to_size[plan_idx] + ) write_item = write_item_idx_to_write_item[write_item_idx] # essentially ignores the storage size of anything that is not a tensor, since diff --git a/torch/distributed/checkpoint/default_planner.py b/torch/distributed/checkpoint/default_planner.py index c9590c38d3e6..57ca0f2a764f 100644 --- a/torch/distributed/checkpoint/default_planner.py +++ b/torch/distributed/checkpoint/default_planner.py @@ -67,11 +67,12 @@ def __init__( flatten_state_dict: bool = True, flatten_sharded_tensors: bool = True, dedup_replicated_tensors: Optional[bool] = None, + dedup_save_to_lowest_rank: bool = False, ) -> None: self.flatten_state_dict = flatten_state_dict self.flatten_sharded_tensors = flatten_sharded_tensors self.mappings = {} - + self.dedup_save_to_lowest_rank = dedup_save_to_lowest_rank if dedup_replicated_tensors is not None: logger.warning( "DefaultSavePlanner's `dedup_replicated_tensors` argument is being " @@ -103,7 +104,7 @@ def create_local_plan(self) -> SavePlan: def create_global_plan( self, all_plans: List[SavePlan] ) -> Tuple[List[SavePlan], Metadata]: - all_plans = dedup_save_plans(all_plans) + all_plans = dedup_save_plans(all_plans, self.dedup_save_to_lowest_rank) global_plan, metadata = create_default_global_save_plan(all_plans) From 84776d7597c801fd23cdf8b8c320c633914b8bd4 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 3 Jun 2024 02:52:47 +0000 Subject: [PATCH 240/706] Revert "[BE]: Update mypy to 1.10.0 (#127717)" This reverts commit 30213ab0a7b27277e76ea9dd707ce629a63d91ee. Reverted https://github.com/pytorch/pytorch/pull/127717 on behalf of https://github.com/huydhn due to I am not sure why but the failures look legit and they are showing up in trunk https://hud.pytorch.org/pytorch/pytorch/commit/30213ab0a7b27277e76ea9dd707ce629a63d91ee ([comment](https://github.com/pytorch/pytorch/pull/127717#issuecomment-2144183347)) --- .ci/docker/requirements-ci.txt | 4 ++-- .lintrunner.toml | 2 +- torch/ao/quantization/fx/fuse_handler.py | 2 +- torch/distributed/_composable/fsdp/fully_shard.py | 6 +----- torch/distributions/utils.py | 2 +- torch/fx/experimental/rewriter.py | 2 +- 6 files changed, 7 insertions(+), 11 deletions(-) diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index e6866a94183a..0f5f1bb12bd5 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -85,10 +85,10 @@ librosa>=0.6.2 ; python_version < "3.11" #Pinned versions: #test that import: -mypy==1.10.0 +mypy==1.9.0 # Pin MyPy version because new errors are likely to appear with each release #Description: linter -#Pinned versions: 1.10.0 +#Pinned versions: 1.9.0 #test that import: test_typing.py, test_type_hints.py networkx==2.8.8 diff --git a/.lintrunner.toml b/.lintrunner.toml index e31279004348..eca2af96b761 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -136,7 +136,7 @@ init_command = [ 'numpy==1.24.3 ; python_version == "3.8"', 'numpy==1.26.0 ; python_version >= "3.9"', 'expecttest==0.1.6', - 'mypy==1.10.0', + 'mypy==1.9.0', 'sympy==1.11.1', 'types-requests==2.27.25', 'types-PyYAML==6.0.7', diff --git a/torch/ao/quantization/fx/fuse_handler.py b/torch/ao/quantization/fx/fuse_handler.py index 123a51f6ff87..718cc561bfa0 100644 --- a/torch/ao/quantization/fx/fuse_handler.py +++ b/torch/ao/quantization/fx/fuse_handler.py @@ -44,7 +44,7 @@ class DefaultFuseHandler(FuseHandler): def __init__( self, node: Node): - super().__init__(node) # type:ignore[safe-super] + super().__init__(node) def fuse(self, load_arg: Callable, diff --git a/torch/distributed/_composable/fsdp/fully_shard.py b/torch/distributed/_composable/fsdp/fully_shard.py index 18967fe41468..3efb8f7afd85 100644 --- a/torch/distributed/_composable/fsdp/fully_shard.py +++ b/torch/distributed/_composable/fsdp/fully_shard.py @@ -342,8 +342,4 @@ def wrapped_method(self, *args, **kwargs): return fsdp_state._post_forward(self, args, out) # Use `__get__` to make `wrapped_method` an instance method - setattr( - module, - method_name, - wrapped_method.__get__(module, type(module)), # type:ignore[attr-defined] - ) + setattr(module, method_name, wrapped_method.__get__(module, type(module))) diff --git a/torch/distributions/utils.py b/torch/distributions/utils.py index f897e63b7891..7a6d31a05722 100644 --- a/torch/distributions/utils.py +++ b/torch/distributions/utils.py @@ -117,7 +117,7 @@ class lazy_property: def __init__(self, wrapped): self.wrapped = wrapped - update_wrapper(self, wrapped) # type:ignore[arg-type] + update_wrapper(self, wrapped) def __get__(self, instance, obj_type=None): if instance is None: diff --git a/torch/fx/experimental/rewriter.py b/torch/fx/experimental/rewriter.py index 969717e01030..85a95895f7c9 100644 --- a/torch/fx/experimental/rewriter.py +++ b/torch/fx/experimental/rewriter.py @@ -60,7 +60,7 @@ def change_func_globals(f, globals): closure=f.__closure__, ) g = functools.update_wrapper(g, f) - g.__kwdefaults__ = copy.copy(f.__kwdefaults__) # type:ignore[attr-defined] + g.__kwdefaults__ = copy.copy(f.__kwdefaults__) return g # Return the correct FunctionType object return change_func_globals(fn_compiled, globals=fn.__globals__) From 7e97b33fbbb24fa5876500be764c97f51b74ac0e Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 3 Jun 2024 03:55:33 +0000 Subject: [PATCH 241/706] [Dynamo] Log backward graph compilation metrics (#126629) Fixes #125313 Compilation metric logs for the code example at #125313: ``` %s CompilationMetrics(compile_id='0/0', frame_key='1', co_name='forward', co_filename='/data/users/ybliang/debug/debug2.py', co_firstlineno=10, cache_size=0, accumulated_cache_size=0, guard_count=11, shape_env_guard_count=0, graph_op_count=1, graph_node_count=3, graph_input_count=1, start_time=1716247236.6165977, entire_frame_compile_time_s=7.926939964294434, backend_compile_time_s=7.887059926986694, inductor_compile_time_s=4.108498811721802, code_gen_time_s=3.97833514213562, fail_type=None, fail_reason=None, fail_user_frame_filename=None, fail_user_frame_lineno=None, non_compliant_ops=set(), compliant_custom_ops=set(), restart_reasons={"'skip function graph_break in file /home/ybliang/local/pytorch/torch/_dynamo/decorators.py'"}, dynamo_time_before_restart_s=0.025330543518066406, has_guarded_code=True, is_fwd=True) %s CompilationMetrics(compile_id='1/0', frame_key='2', co_name='torch_dynamo_resume_in_forward_at_12', co_filename='/data/users/ybliang/debug/debug2.py', co_firstlineno=12, cache_size=0, accumulated_cache_size=0, guard_count=10, shape_env_guard_count=0, graph_op_count=2, graph_node_count=5, graph_input_count=1, start_time=1716247244.544928, entire_frame_compile_time_s=0.10148310661315918, backend_compile_time_s=0.08753013610839844, inductor_compile_time_s=0.03691983222961426, code_gen_time_s=0.022417306900024414, fail_type=None, fail_reason=None, fail_user_frame_filename=None, fail_user_frame_lineno=None, non_compliant_ops=set(), compliant_custom_ops=set(), restart_reasons=set(), dynamo_time_before_restart_s=0.0, has_guarded_code=True, is_fwd=True) tensor([[-0.1622, -0.0000, -0.0000, 0.5643, -0.0000, 0.0000, -0.5087, 0.0914, -0.0000, -0.0421]], grad_fn=) %s CompilationMetrics(compile_id='1/0', frame_key=None, co_name=None, co_filename=None, co_firstlineno=None, cache_size=None, accumulated_cache_size=None, guard_count=None, shape_env_guard_count=None, graph_op_count=None, graph_node_count=None, graph_input_count=None, start_time=None, entire_frame_compile_time_s=None, backend_compile_time_s=None, inductor_compile_time_s=0.026738643646240234, code_gen_time_s=0.016446352005004883, fail_type=None, fail_reason=None, fail_user_frame_filename=None, fail_user_frame_lineno=None, non_compliant_ops=None, compliant_custom_ops=None, restart_reasons=None, dynamo_time_before_restart_s=None, has_guarded_code=None, is_fwd=False) %s CompilationMetrics(compile_id='0/0', frame_key=None, co_name=None, co_filename=None, co_firstlineno=None, cache_size=None, accumulated_cache_size=None, guard_count=None, shape_env_guard_count=None, graph_op_count=None, graph_node_count=None, graph_input_count=None, start_time=None, entire_frame_compile_time_s=None, backend_compile_time_s=None, inductor_compile_time_s=0.14563536643981934, code_gen_time_s=0.08652091026306152, fail_type=None, fail_reason=None, fail_user_frame_filename=None, fail_user_frame_lineno=None, non_compliant_ops=None, compliant_custom_ops=None, restart_reasons=None, dynamo_time_before_restart_s=None, has_guarded_code=None, is_fwd=False) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/126629 Approved by: https://github.com/ezyang --- torch/_dynamo/convert_frame.py | 1 + torch/_dynamo/utils.py | 137 ++++++++++++++---- .../_aot_autograd/runtime_wrappers.py | 40 +---- torch/_inductor/compile_fx.py | 2 +- torch/_inductor/graph.py | 2 +- torch/_inductor/runtime/runtime_utils.py | 2 +- 6 files changed, 115 insertions(+), 69 deletions(-) diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 6dcb84fab8fc..e779ccef9e38 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -876,6 +876,7 @@ def format_guard_failures(): dynamo_time_before_restart = time.time() - start_time metrics = CompilationMetrics( + str(compile_id), frame_key, code.co_name, code.co_filename, diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 9d43ba551cc1..2b42c8dec63d 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -113,7 +113,9 @@ compilation_time_metrics: Dict[str, List[float]] = {} # profiling compilation time by frame phase -frame_phase_timing: Dict[str, Dict[str, float]] = {} +frame_phase_timing: Dict[str, Dict[str, float]] = collections.defaultdict( + lambda: collections.defaultdict(float) +) timer_counter = itertools.count() @@ -185,6 +187,10 @@ def print_time_report(): print(out) +def _add_time_spent(key, phase_name, time_spent): + frame_phase_timing[key][phase_name] += time_spent + + # dynamo_timed API works as a function decorator # By wrapping a function in dynamo_timed, we can store a record in compilation_time_metrics # where the key is the functions name. @@ -201,9 +207,12 @@ def print_time_report(): # phase_names record an extra record into a separate compilation timing structure, # one keyed on frame+name rather than function. # The frame is incremented outside of this function, in def increment_frame() above. +# `fwd_only` is used to identify if this phase or function is only called +# during compiling fwd graphs, e.g, `entire_frame_compile` and `backend_compile`. +# The other phases (`inductor_compile` and `code_gen`) are called for both fwd and bwd graphs. -def dynamo_timed(original_function=None, phase_name=None): +def dynamo_timed(original_function=None, phase_name=None, fwd_only=True): def dynamo_timed_inner(func): if config.cprofile: return func @@ -213,19 +222,70 @@ def time_wrapper(*args, **kwargs): key = func.__qualname__ if key not in compilation_time_metrics: compilation_time_metrics[key] = [] - with torch.profiler.record_function(f"{key} (dynamo_timed)"): - t0 = time.time() - r = func(*args, **kwargs) - time_spent = time.time() - t0 - compilation_time_metrics[key].append(time_spent) - if phase_name: - frame_key = str(curr_frame) - if frame_key not in frame_phase_timing: - frame_phase_timing[frame_key] = {} - if phase_name not in frame_phase_timing[frame_key]: - frame_phase_timing[frame_key][phase_name] = time_spent - else: - frame_phase_timing[frame_key][phase_name] += time_spent + + fail_type: Optional[str] = None + fail_reason: Optional[str] = None + time_spent = float("-inf") + try: + with torch.profiler.record_function(f"{key} (dynamo_timed)"): + t0 = time.time() + r = func(*args, **kwargs) + time_spent = time.time() - t0 + compilation_time_metrics[key].append(time_spent) + except Exception as e: + fail_type = str(type(e)) + fail_reason = str(e) + raise + finally: + # Only record backward compilation metrics if phase_name is not None! + if phase_name: + frame_key = str(curr_frame) + # fwd only compilation stages: entire_frame_compile, backend_compile. + # use frame_key as time aggregation key. + if fwd_only and fail_type is None: + _add_time_spent(frame_key, phase_name, time_spent) + else: + # fwd + bwd compilation stages: inductor_compile, code_gen. + # use frame_key as time aggregation key for fwd graphs; + # use compile_id as time aggregation key for bwd graphs. + if torch._guards.TracingContext.try_get() is not None: + aot_graph_name = str( + torch._guards.TracingContext.get().aot_graph_name + ) + if ( + "forward" in aot_graph_name + or "inference" in aot_graph_name + ) and fail_type is None: + _add_time_spent(frame_key, phase_name, time_spent) + elif "backward" in aot_graph_name: + compile_id = str( + torch._guards.CompileContext.current_compile_id() + ) + if fail_type is None: + _add_time_spent(compile_id, phase_name, time_spent) + + # log backward compilation metrics at the end of `inductor_compile` of bwd graph, + # one record for one bwd graph. + if phase_name == "inductor_compile": + if fail_type is None: + inductor_compile_time = frame_phase_timing[ + compile_id + ].get("inductor_compile", None) + code_gen_time = frame_phase_timing[ + compile_id + ].get("code_gen", None) + else: + inductor_compile_time = None + code_gen_time = None + metrics = BwdCompilationMetrics( + compile_id, + inductor_compile_time, + code_gen_time, + fail_type, + fail_reason, + ) + record_compilation_metrics(metrics) + return r return time_wrapper @@ -598,6 +658,7 @@ def proxy_args_kwargs(args, kwargs): @dataclasses.dataclass class CompilationMetrics: + compile_id: str frame_key: str co_name: str co_filename: str @@ -628,26 +689,44 @@ class CompilationMetrics: has_guarded_code: bool +@dataclasses.dataclass +class BwdCompilationMetrics: + compile_id: str + inductor_compile_time_s: Optional[float] + code_gen_time_s: Optional[float] + fail_type: Optional[str] + fail_reason: Optional[str] + + DEFAULT_COMPILATION_METRICS_LIMIT = 64 -_compilation_metrics: Deque[CompilationMetrics] = collections.deque( - maxlen=DEFAULT_COMPILATION_METRICS_LIMIT -) +_compilation_metrics: Deque[ + Union[CompilationMetrics, BwdCompilationMetrics] +] = collections.deque(maxlen=DEFAULT_COMPILATION_METRICS_LIMIT) -def record_compilation_metrics(compilation_metrics: CompilationMetrics): +def record_compilation_metrics( + compilation_metrics: Union[CompilationMetrics, BwdCompilationMetrics] +): global _compilation_metrics _compilation_metrics.append(compilation_metrics) - torch._logging.trace_structured( - "compilation_metrics", - lambda: { - k: list(v) if isinstance(v, set) else v - for k, v in dataclasses.asdict(compilation_metrics).items() - }, - ) - if config.log_compilation_metrics: - log_compilation_event(compilation_metrics) + if isinstance(compilation_metrics, CompilationMetrics): + name = "compilation_metrics" + else: + name = "bwd_compilation_metrics" + # Currently only record fwd compilation metrics, will add bwd compilation metrics + # after the internal Scuba logging changes finish. + if isinstance(compilation_metrics, CompilationMetrics): + torch._logging.trace_structured( + name, + lambda: { + k: list(v) if isinstance(v, set) else v + for k, v in dataclasses.asdict(compilation_metrics).items() + }, + ) + if config.log_compilation_metrics: + log_compilation_event(compilation_metrics) def set_compilation_metrics_limit(new_size: int) -> None: @@ -663,7 +742,7 @@ def clear_compilation_metrics() -> None: _compilation_metrics.clear() -def get_compilation_metrics() -> List[CompilationMetrics]: +def get_compilation_metrics() -> List[Union[CompilationMetrics, BwdCompilationMetrics]]: return list(_compilation_metrics) diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py index a450f401f9e2..fd188eb6a700 100644 --- a/torch/_functorch/_aot_autograd/runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py @@ -7,7 +7,6 @@ """ import collections import pprint -import time from contextlib import nullcontext from dataclasses import dataclass, field from functools import wraps @@ -24,7 +23,6 @@ tracing, TracingContext, ) -from torch._logging import trace_structured from torch._prims_common import CUDARngStateHelper from torch._subclasses import FakeTensor @@ -1801,41 +1799,9 @@ def call_compiled_backward(): with tracing(saved_context), compile_context( saved_compile_context ), context(), track_graph_compiling(aot_config, "backward"): - fail_type: Optional[str] = None - fail_reason: Optional[str] = None - start_time = time.time() - try: - CompiledFunction.compiled_bw = aot_config.bw_compiler( - bw_module, placeholder_list - ) - except Exception as e: - fail_type = str(type(e)) - fail_reason = str(e) - if saved_compile_context is not None: - e.compile_id = saved_compile_context.compile_id # type: ignore[attr-defined] - raise - finally: - # TODO: Similar to CompilationMetrics, we would - # like to report inductor_compile_time, but we - # cannot conveniently do so because these are - # keyed on utils.frame, and frame key is not - # incremented on backwards compilations. Maybe - # should just bump the frame key here too? - end_time = time.time() - # TODO: Put this in scuba? But CompilationMetrics - # is kind of not a great match, because there's no - # interaction with Dynamo, so a lot of Dynamo only - # events don't exist anymore. So we need a new - # scuba table. Lazy lazy... - trace_structured( - "aot_autograd_backward_compilation_metrics", - lambda: { - "start_time": start_time, - "elapsed_time": time.time() - start_time, - "fail_type": fail_type, - "fail_reason": fail_reason, - }, - ) + CompiledFunction.compiled_bw = aot_config.bw_compiler( + bw_module, placeholder_list + ) out = call_func_at_runtime_with_args( CompiledFunction.compiled_bw, diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index 77b8925a79a5..60db8b7ee465 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -408,7 +408,7 @@ def with_fresh_cache_if_config(f): # the backward graph as well. @_use_lazy_graph_module(dynamo_config.use_lazy_graph_module) @with_fresh_cache_if_config -@dynamo_utils.dynamo_timed(phase_name="inductor_compile") +@dynamo_utils.dynamo_timed(phase_name="inductor_compile", fwd_only=False) def compile_fx_inner( gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor], diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 412caf5e2242..743018014c4c 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -1681,7 +1681,7 @@ def count_bytes(self): node_runtimes.append((node, node.get_estimated_runtime())) return total_bytes, node_counts, node_runtimes - @dynamo_timed(phase_name="code_gen") + @dynamo_timed(phase_name="code_gen", fwd_only=False) def compile_to_module(self): from .codecache import PyCodeCache diff --git a/torch/_inductor/runtime/runtime_utils.py b/torch/_inductor/runtime/runtime_utils.py index 7d24be0ded47..bc3a3d008f3c 100644 --- a/torch/_inductor/runtime/runtime_utils.py +++ b/torch/_inductor/runtime/runtime_utils.py @@ -187,7 +187,7 @@ def get_first_attr(obj, *attrs): dynamo_timed = torch._dynamo.utils.dynamo_timed except AttributeError: # Compile workers only have a mock version of torch - def dynamo_timed(original_function=None, phase_name=None): + def dynamo_timed(original_function=None, phase_name=None, fwd_only=True): if original_function: return original_function return dynamo_timed From 3399ad8d9d4ea083ca904a9f167218fce10f3969 Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Sun, 2 Jun 2024 18:41:09 -0700 Subject: [PATCH 242/706] [Inductor][CPP] Add UT for bitwise right shift (#127731) **Summary** Per the discussion in https://github.com/pytorch/pytorch/issues/127310, `bitwise_right_shift` failed in Torch 2.1 but pass with latest PyTorch, Add the UT in this PR to ensure the correctness. **TestPlan** ``` python -u -m pytest -s -v test/inductor/test_cpu_repro.py -k test_bitwise_right_shift ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/127731 Approved by: https://github.com/Skylion007 --- test/inductor/test_cpu_repro.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py index 5a552aa15c96..d4d0e258c3e2 100644 --- a/test/inductor/test_cpu_repro.py +++ b/test/inductor/test_cpu_repro.py @@ -1884,6 +1884,14 @@ def test_ops_masked_with_bool_input(self): self.assertEqual(res_aten_eager, res) check_metrics_vec_kernel_count(1) + def test_bitwise_right_shift(self): + x = torch.randint(-1, 0, (1, 1, 1), device="cpu", dtype=torch.int64) + bit_num = 31 + res_aten_eager = torch.bitwise_right_shift(x, bit_num) + cfn = torch.compile(torch.bitwise_right_shift) + res = cfn(x, bit_num) + self.assertEqual(res_aten_eager, res) + @patch("torch.cuda.is_available", lambda: False) def test_scatter_using_atomic_add(self): def fn(a, dim, index, b): From 1b182ea0d2130cc292200e011a5a0184f77b0efa Mon Sep 17 00:00:00 2001 From: cyy Date: Mon, 3 Jun 2024 04:06:19 +0000 Subject: [PATCH 243/706] Remove c10::guts::{conjunction,disjunction} (#127726) They are not used in Pytorch OSS. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127726 Approved by: https://github.com/ezyang --- c10/util/C++17.h | 5 ----- 1 file changed, 5 deletions(-) diff --git a/c10/util/C++17.h b/c10/util/C++17.h index 1f62adb9bb00..fe2044f507d4 100644 --- a/c10/util/C++17.h +++ b/c10/util/C++17.h @@ -54,11 +54,6 @@ make_unique_base(Args&&... args) { return std::unique_ptr(new Child(std::forward(args)...)); } -template -using conjunction = std::conjunction; -template -using disjunction = std::disjunction; - #if defined(__cpp_lib_apply) && !defined(__CUDA_ARCH__) && !defined(__HIP__) template From 288df042c5f30dae54d1e26bbf2623e2fb846f4d Mon Sep 17 00:00:00 2001 From: cyy Date: Mon, 3 Jun 2024 04:34:36 +0000 Subject: [PATCH 244/706] [1/N] Change static functions in headers to inline (#127727) So that it may fix some tricky linking issues. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127727 Approved by: https://github.com/ezyang --- c10/core/Backend.h | 13 ++++++------ c10/core/DispatchKeySet.h | 2 +- c10/core/ScalarType.h | 36 ++++++++++++++++----------------- c10/core/ScalarTypeToTypeMeta.h | 14 ++++++------- c10/util/int128.h | 4 ++-- c10/util/strides.h | 2 +- 6 files changed, 35 insertions(+), 36 deletions(-) diff --git a/c10/core/Backend.h b/c10/core/Backend.h index 1cf1782fa570..8ecaa7be7377 100644 --- a/c10/core/Backend.h +++ b/c10/core/Backend.h @@ -65,7 +65,7 @@ enum class Backend { NumOptions }; -static inline Backend dispatchKeyToBackend(DispatchKey t) { +inline Backend dispatchKeyToBackend(DispatchKey t) { if (t == DispatchKey::CPU || t == DispatchKey::AutogradCPU) { return Backend::CPU; } else if (t == DispatchKey::CUDA || t == DispatchKey::AutogradCUDA) { @@ -142,7 +142,7 @@ static inline Backend dispatchKeyToBackend(DispatchKey t) { } } -static inline DispatchKey backendToDispatchKey(Backend b) { +inline DispatchKey backendToDispatchKey(Backend b) { switch (b) { case Backend::CPU: return DispatchKey::CPU; @@ -217,7 +217,7 @@ static inline DispatchKey backendToDispatchKey(Backend b) { } } -static inline DeviceType backendToDeviceType(Backend b) { +inline DeviceType backendToDeviceType(Backend b) { switch (b) { case Backend::CPU: case Backend::MkldnnCPU: @@ -281,8 +281,7 @@ static inline DeviceType backendToDeviceType(Backend b) { } } -// TODO: This probably shouldn't actually be static inline -static inline const char* toString(Backend b) { +inline const char* toString(Backend b) { switch (b) { case Backend::CPU: return "CPU"; @@ -357,7 +356,7 @@ static inline const char* toString(Backend b) { } } -static inline bool isSparse(Backend b) { +inline bool isSparse(Backend b) { switch (b) { case Backend::SparseXPU: case Backend::SparseCPU: @@ -371,7 +370,7 @@ static inline bool isSparse(Backend b) { } } -static inline bool isSparseCsr(Backend b) { +inline bool isSparseCsr(Backend b) { switch (b) { case Backend::SparseCsrXPU: case Backend::SparseCsrCPU: diff --git a/c10/core/DispatchKeySet.h b/c10/core/DispatchKeySet.h index f7461ea73a6d..4c391d60f2b0 100644 --- a/c10/core/DispatchKeySet.h +++ b/c10/core/DispatchKeySet.h @@ -901,7 +901,7 @@ C10_API bool isIncludedInAlias(DispatchKey k, DispatchKey alias); // legacy code that is still using DispatchKey for things like instanceof // checks; if at all possible, refactor the code to stop using DispatchKey in // those cases. -static inline DispatchKey legacyExtractDispatchKey(DispatchKeySet s) { +inline DispatchKey legacyExtractDispatchKey(DispatchKeySet s) { // NB: If you add any extra keys that can be stored in TensorImpl on // top of existing "backend" keys like CPU/CUDA, you need to add it // here. At the moment, autograd keys and ADInplaceOrView key need this diff --git a/c10/core/ScalarType.h b/c10/core/ScalarType.h index 590b24a7bc20..f7f059fd513d 100644 --- a/c10/core/ScalarType.h +++ b/c10/core/ScalarType.h @@ -315,7 +315,7 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType) AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CONSTANT) #undef DEFINE_CONSTANT -static inline const char* toString(ScalarType t) { +inline const char* toString(ScalarType t) { #define DEFINE_CASE(_, name) \ case ScalarType::name: \ return #name; @@ -328,7 +328,7 @@ static inline const char* toString(ScalarType t) { #undef DEFINE_CASE } -static inline size_t elementSize(ScalarType t) { +inline size_t elementSize(ScalarType t) { #define CASE_ELEMENTSIZE_CASE(ctype, name) \ case ScalarType::name: \ return sizeof(ctype); @@ -341,7 +341,7 @@ static inline size_t elementSize(ScalarType t) { #undef CASE_ELEMENTSIZE_CASE } -static inline bool isIntegralType(ScalarType t, bool includeBool) { +inline bool isIntegralType(ScalarType t, bool includeBool) { bool isIntegral = (t == ScalarType::Byte || t == ScalarType::Char || t == ScalarType::Int || t == ScalarType::Long || t == ScalarType::Short || @@ -353,44 +353,44 @@ static inline bool isIntegralType(ScalarType t, bool includeBool) { C10_DEPRECATED_MESSAGE( "isIntegralType is deprecated. Please use the overload with 'includeBool' parameter instead.") -static inline bool isIntegralType(ScalarType t) { +inline bool isIntegralType(ScalarType t) { return isIntegralType(t, /*includeBool=*/false); } -static inline bool isFloat8Type(ScalarType t) { +inline bool isFloat8Type(ScalarType t) { return t == ScalarType::Float8_e5m2 || t == ScalarType::Float8_e5m2fnuz || t == ScalarType::Float8_e4m3fn || t == ScalarType::Float8_e4m3fnuz; } -static inline bool isReducedFloatingType(ScalarType t) { +inline bool isReducedFloatingType(ScalarType t) { return t == ScalarType::Half || t == ScalarType::BFloat16 || isFloat8Type(t); } -static inline bool isFloatingType(ScalarType t) { +inline bool isFloatingType(ScalarType t) { return t == ScalarType::Double || t == ScalarType::Float || isReducedFloatingType(t); } -static inline bool isComplexType(ScalarType t) { +inline bool isComplexType(ScalarType t) { return ( t == ScalarType::ComplexHalf || t == ScalarType::ComplexFloat || t == ScalarType::ComplexDouble); } -static inline bool isQIntType(ScalarType t) { +inline bool isQIntType(ScalarType t) { // Don't forget to extend this when adding new QInt types return t == ScalarType::QInt8 || t == ScalarType::QUInt8 || t == ScalarType::QInt32 || t == ScalarType::QUInt4x2 || t == ScalarType::QUInt2x4; } -static inline bool isBitsType(ScalarType t) { +inline bool isBitsType(ScalarType t) { return t == ScalarType::Bits1x8 || t == ScalarType::Bits2x4 || t == ScalarType::Bits4x2 || t == ScalarType::Bits8 || t == ScalarType::Bits16; } -static inline bool isBarebonesUnsignedType(ScalarType t) { +inline bool isBarebonesUnsignedType(ScalarType t) { return t == ScalarType::UInt1 || t == ScalarType::UInt2 || t == ScalarType::UInt3 || t == ScalarType::UInt4 || t == ScalarType::UInt5 || t == ScalarType::UInt6 || @@ -398,7 +398,7 @@ static inline bool isBarebonesUnsignedType(ScalarType t) { t == ScalarType::UInt32 || t == ScalarType::UInt64; } -static inline ScalarType toQIntType(ScalarType t) { +inline ScalarType toQIntType(ScalarType t) { switch (t) { case ScalarType::Byte: return ScalarType::QUInt8; @@ -411,7 +411,7 @@ static inline ScalarType toQIntType(ScalarType t) { } } -static inline ScalarType toUnderlying(ScalarType t) { +inline ScalarType toUnderlying(ScalarType t) { switch (t) { case ScalarType::QUInt8: case ScalarType::QUInt4x2: @@ -427,7 +427,7 @@ static inline ScalarType toUnderlying(ScalarType t) { } } -static inline bool isSignedType(ScalarType t) { +inline bool isSignedType(ScalarType t) { #define CASE_ISSIGNED(name) \ case ScalarType::name: \ return std::numeric_limits< \ @@ -484,11 +484,11 @@ static inline bool isSignedType(ScalarType t) { #undef CASE_ISSIGNED } -static inline bool isUnderlying(ScalarType type, ScalarType qtype) { +inline bool isUnderlying(ScalarType type, ScalarType qtype) { return type == toUnderlying(qtype); } -static inline ScalarType toRealValueType(ScalarType t) { +inline ScalarType toRealValueType(ScalarType t) { switch (t) { case ScalarType::ComplexHalf: return ScalarType::Half; @@ -501,7 +501,7 @@ static inline ScalarType toRealValueType(ScalarType t) { } } -static inline ScalarType toComplexType(ScalarType t) { +inline ScalarType toComplexType(ScalarType t) { switch (t) { case ScalarType::BFloat16: // BFloat16 has range equivalent to Float, @@ -526,7 +526,7 @@ static inline ScalarType toComplexType(ScalarType t) { // see tensor_attributes.rst for detailed explanation and examples // of casting rules. -static inline bool canCast(const ScalarType from, const ScalarType to) { +inline bool canCast(const ScalarType from, const ScalarType to) { // We disallow complex -> non complex, e.g., float_tensor *= complex is // disallowed. if (isComplexType(from) && !isComplexType(to)) { diff --git a/c10/core/ScalarTypeToTypeMeta.h b/c10/core/ScalarTypeToTypeMeta.h index 910e0d24b0a3..d2694c96221e 100644 --- a/c10/core/ScalarTypeToTypeMeta.h +++ b/c10/core/ScalarTypeToTypeMeta.h @@ -13,21 +13,21 @@ namespace c10 { /** * convert ScalarType enum values to TypeMeta handles */ -static inline caffe2::TypeMeta scalarTypeToTypeMeta(ScalarType scalar_type) { +inline caffe2::TypeMeta scalarTypeToTypeMeta(ScalarType scalar_type) { return caffe2::TypeMeta::fromScalarType(scalar_type); } /** * convert TypeMeta handles to ScalarType enum values */ -static inline ScalarType typeMetaToScalarType(caffe2::TypeMeta dtype) { +inline ScalarType typeMetaToScalarType(caffe2::TypeMeta dtype) { return dtype.toScalarType(); } /** * typeMetaToScalarType(), lifted to optional */ -static inline optional optTypeMetaToScalarType( +inline optional optTypeMetaToScalarType( optional type_meta) { if (!type_meta.has_value()) { return c10::nullopt; @@ -38,19 +38,19 @@ static inline optional optTypeMetaToScalarType( /** * convenience: equality across TypeMeta/ScalarType conversion */ -static inline bool operator==(ScalarType t, caffe2::TypeMeta m) { +inline bool operator==(ScalarType t, caffe2::TypeMeta m) { return m.isScalarType(t); } -static inline bool operator==(caffe2::TypeMeta m, ScalarType t) { +inline bool operator==(caffe2::TypeMeta m, ScalarType t) { return t == m; } -static inline bool operator!=(ScalarType t, caffe2::TypeMeta m) { +inline bool operator!=(ScalarType t, caffe2::TypeMeta m) { return !(t == m); } -static inline bool operator!=(caffe2::TypeMeta m, ScalarType t) { +inline bool operator!=(caffe2::TypeMeta m, ScalarType t) { return !(t == m); } diff --git a/c10/util/int128.h b/c10/util/int128.h index b97a59446da2..7da595b79178 100644 --- a/c10/util/int128.h +++ b/c10/util/int128.h @@ -49,7 +49,7 @@ struct uint128_pod; #endif class uint128; -static inline uint128& operator<<=(uint128& self, int amount); +inline uint128& operator<<=(uint128& self, int amount); // An unsigned 128-bit integer type. Thread-compatible. class C10_API uint128 { @@ -277,7 +277,7 @@ inline uint128 operator>>(const uint128& val, int amount) { } } -static inline uint128& operator<<=(uint128& self, int amount) { +inline uint128& operator<<=(uint128& self, int amount) { // uint64_t shifts of >= 64 are undefined, so we will need some // special-casing. if (amount < 64) { diff --git a/c10/util/strides.h b/c10/util/strides.h index 980540b5b97a..d3d38fd7d011 100644 --- a/c10/util/strides.h +++ b/c10/util/strides.h @@ -6,7 +6,7 @@ namespace c10 { // Computes the contiguous strides of a tensor, given its sizes. -static inline DimVector contiguous_strides(const IntArrayRef sizes) { +inline DimVector contiguous_strides(const IntArrayRef sizes) { using Int = IntArrayRef::value_type; const Int dims = static_cast(sizes.size()); From e2e3ca94ccce1c0abbfd75ac0368793e1756c268 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 3 Jun 2024 04:35:49 +0000 Subject: [PATCH 245/706] [Inductor][Flex-attention] Support different sequence lengths for Query and Key/Value (#127678) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/127678 Approved by: https://github.com/Chillee --- test/inductor/test_flex_attention.py | 53 ++++++++++---- torch/_inductor/kernel/flex_attention.py | 93 ++++++++++++++---------- torch/_inductor/select_algorithm.py | 9 ++- torch/nn/attention/_flex_attention.py | 5 -- 4 files changed, 98 insertions(+), 62 deletions(-) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index d4feead90301..c4504e5574a1 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -155,23 +155,33 @@ def run_test( self, score_mod: Callable, dtype: torch.dtype = torch.float16, - B: int = B, - H: int = H, - S: int = S, - D: int = D, + Q_B: int = B, + Q_H: int = H, + Q_S: int = S, + Q_D: int = D, + KV_B: int = B, + KV_H: int = H, + KV_S: int = S, + KV_D: int = D, ): sdpa_partial = create_attention(score_mod) compiled_sdpa = torch.compile(sdpa_partial) - q = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) - k = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) - v = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) + q = torch.randn( + (Q_B, Q_H, Q_S, Q_D), dtype=dtype, device="cuda", requires_grad=True + ) + k = torch.randn( + (KV_B, KV_H, KV_S, KV_D), dtype=dtype, device="cuda", requires_grad=True + ) + v = torch.randn( + (KV_B, KV_H, KV_S, KV_D), dtype=dtype, device="cuda", requires_grad=True + ) q_ref, k_ref, v_ref = query_key_value_clones(q, k, v) q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64) golden_out = sdpa_partial(q_gold, k_gold, v_gold) ref_out = sdpa_partial(q_ref, k_ref, v_ref) compiled_out = compiled_sdpa(q, k, v) - backward_grad = torch.randn((B, H, S, D), dtype=dtype, device="cuda") + backward_grad = torch.randn((Q_B, Q_H, Q_S, Q_D), dtype=dtype, device="cuda") golden_out.backward(backward_grad.to(torch.float64)) ref_out.backward(backward_grad) @@ -345,6 +355,25 @@ def test_builtin_score_mods_automatic_dynamic( ): self.run_automatic_dynamic_test(score_mod, dtype) + @supported_platform + @common_utils.parametrize("dtype", test_dtypes_fast) + @common_utils.parametrize("score_mod", test_score_mods) + def test_builtin_score_mods_different_seqlen( + self, dtype: torch.dtype, score_mod: Callable + ): + self.run_test( + score_mod, + dtype, + B, + H, + S // 2, # Seqlen of Q is different from seqlen of K/V + D, + B, + H, + S, + D, + ) + @supported_platform @common_utils.parametrize("dtype", test_dtypes) def test_skip_odd_keys(self, dtype: torch.dtype): @@ -721,14 +750,6 @@ def test_mixed_dtypes_fails(self): ): _flex_attention(query, key, value, _identity) - @supported_platform - def test_different_sequence_length_fails(self): - query = torch.randn((1, 1, 2048, 64), dtype=torch.float32, device="cuda") - key = torch.randn((1, 1, 1024, 64), dtype=torch.float32, device="cuda") - value = torch.randn((1, 1, 1024, 64), dtype=torch.float32, device="cuda") - with self.assertRaisesRegex(ValueError, "NYI: The target sequence length"): - _flex_attention(query, key, value, _identity) - @supported_platform @patch.object(torch._inductor.config, "max_autotune", True) def test_max_autotune(self): diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index 3e95dd4f65ce..42fabf65591d 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -189,6 +189,7 @@ def build_subgraph_buffer( Z = {{size("Q", 0)}} H = {{size("Q", 1)}} Q_LEN = {{size("Q", 2)}} + KV_LEN = {{size("K", 2)}} qk_scale = 1.0 MATMUL_PRECISION = Q.dtype.element_ty @@ -196,9 +197,10 @@ def build_subgraph_buffer( start_m = tl.program_id(0) off_hz = tl.program_id(1) - qkv_offset = off_hz * stride_qh + q_offset = off_hz * stride_qh + kv_offset = off_hz * stride_kh Q_block_ptr = tl.make_block_ptr( - base=Q + qkv_offset, + base=Q + q_offset, shape=(Q_LEN, BLOCK_DMODEL), strides=(stride_qm, stride_qk), offsets=(start_m * BLOCK_M, 0), @@ -206,16 +208,16 @@ def build_subgraph_buffer( order=(1, 0) ) K_block_ptr = tl.make_block_ptr( - base=K + qkv_offset, - shape=(BLOCK_DMODEL, Q_LEN), + base=K + kv_offset, + shape=(BLOCK_DMODEL, KV_LEN), strides=(stride_kk, stride_kn), offsets=(0, 0), block_shape=(BLOCK_DMODEL, BLOCK_N), order=(0, 1) ) V_block_ptr = tl.make_block_ptr( - base=V + qkv_offset, - shape=(Q_LEN, BLOCK_DMODEL), + base=V + kv_offset, + shape=(KV_LEN, BLOCK_DMODEL), strides=(stride_vk, stride_vn), offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_DMODEL), @@ -235,7 +237,7 @@ def build_subgraph_buffer( q = (q * qk_scale).to(MATMUL_PRECISION) # loop over k, v and update accumulator lo = 0 - hi = Q_LEN + hi = KV_LEN for start_n in range(lo, hi, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- load k, v -- @@ -425,6 +427,7 @@ def flex_attention(*args, **kwargs): ], num_stages=num_stages, num_warps=num_warps, + call_sizes=query.get_size(), BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=query.get_size()[-1], @@ -445,7 +448,9 @@ def flex_attention(*args, **kwargs): # ---------------------------- Backward HOP Implementation ---------------------------- -def flex_attention_backward_grid(batch_size, num_heads, num_queries, d_model, meta): +def flex_attention_backward_grid( + batch_size, num_heads, num_queries, d_model, num_key_value, meta +): """How is this kernel parallelized? Currently this is only parallelizing over batch * num_heads, but we can, and want to parallelize over ceil_div(num_key_value, key_value_block_size). To do this will either require @@ -453,8 +458,6 @@ def flex_attention_backward_grid(batch_size, num_heads, num_queries, d_model, me """ import triton - # TODO: support different seqlen for Query and Key/Value. - num_key_value = num_queries return ( triton.cdiv(num_queries, meta["BLOCK_M2"]) + triton.cdiv(num_key_value, meta["BLOCK_N1"]), @@ -476,7 +479,7 @@ def flex_attention_backward_grid(batch_size, num_heads, num_queries, d_model, me # DK: Derivative of Key, is the written to via the store_output call due to some limitations with # inductor codegen # M: Number of queries, N: Number of keys/values, D: Model dimension - # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim # (Modifiable) Config options: # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. @@ -486,10 +489,20 @@ def flex_attention_backward_grid(batch_size, num_heads, num_queries, d_model, me # change of base out of the loop # Define Q Strides - stride_z = {{stride("Q", 0)}} - stride_h = {{stride("Q", 1)}} - stride_tok = {{stride("Q", 2)}} - stride_d = {{stride("Q", 3)}} + stride_qz = {{stride("Q", 0)}} + stride_qh = {{stride("Q", 1)}} + stride_qm = {{stride("Q", 2)}} + stride_qd = {{stride("Q", 3)}} + # Define K Strides + stride_kz = {{stride("K", 0)}} + stride_kh = {{stride("K", 1)}} + stride_km = {{stride("K", 2)}} + stride_kd = {{stride("K", 3)}} + # Define V Strides + stride_vz = {{stride("V", 0)}} + stride_vh = {{stride("V", 1)}} + stride_vm = {{stride("V", 2)}} + stride_vd = {{stride("V", 3)}} Z = {{size("Q", 0)}} H = {{size("Q", 1)}} @@ -501,21 +514,22 @@ def flex_attention_backward_grid(batch_size, num_heads, num_queries, d_model, me pid = tl.program_id(0) NUM_KV_BLOCKS = KV_LEN // BLOCK_N1 - bhid = tl.program_id(2) - off_chz = (bhid * Q_LEN).to(tl.int64) - adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64) - off_hz = tl.program_id(2) off_z = off_hz // H # batch idx off_h = off_hz % H # head idx + off_chz = (off_hz * Q_LEN).to(tl.int64) + q_adj = (stride_qh * (off_hz % H) + stride_qz * (off_hz // H)).to(tl.int64) + k_adj = (stride_kh * (off_hz % H) + stride_kz * (off_hz // H)).to(tl.int64) + v_adj = (stride_vh * (off_hz % H) + stride_vz * (off_hz // H)).to(tl.int64) + # offset pointers for batch/head - Q += adj - K += adj - V += adj - DO += adj - DQ += adj - DV += adj + Q += q_adj + K += k_adj + V += v_adj + DO += q_adj + DQ += q_adj + DV += v_adj LSE += off_chz DELTA += off_chz @@ -528,9 +542,9 @@ def flex_attention_backward_grid(batch_size, num_heads, num_queries, d_model, me offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) - q = tl.load(Q + offs_m2[:, None] * stride_tok + offs_k[None, :] * stride_d) + q = tl.load(Q + offs_m2[:, None] * stride_qm + offs_k[None, :] * stride_qd) dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32) - do = tl.load(DO + offs_m2[:, None] * stride_tok + offs_k[None, :] * stride_d) + do = tl.load(DO + offs_m2[:, None] * stride_qm + offs_k[None, :] * stride_qd) lse = tl.load(LSE + offs_m2) lse = lse[:, None] @@ -538,8 +552,8 @@ def flex_attention_backward_grid(batch_size, num_heads, num_queries, d_model, me start_n2 = 0 offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) offs_n2 = start_n2 + tl.arange(0, BLOCK_N2) - kT_ptrs = K + offs_n2[None, :] * stride_tok + offs_k[:, None] * stride_d - vT_ptrs = V + offs_n2[None, :] * stride_tok + offs_k[:, None] * stride_d + kT_ptrs = K + offs_n2[None, :] * stride_km + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vm + offs_k[:, None] * stride_vd Di = tl.load(DELTA + offs_m2) # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) @@ -590,10 +604,10 @@ def flex_attention_backward_grid(batch_size, num_heads, num_queries, d_model, me dq += tl.dot(ds, tl.trans(kT)) # Increment pointers. curr_n += BLOCK_N2 - kT_ptrs += BLOCK_N2 * stride_tok - vT_ptrs += BLOCK_N2 * stride_tok + kT_ptrs += BLOCK_N2 * stride_km + vT_ptrs += BLOCK_N2 * stride_km # Write back dQ. - dq_ptrs = DQ + offs_m2[:, None] * stride_tok + offs_k[None, :] * stride_d + dq_ptrs = DQ + offs_m2[:, None] * stride_qm + offs_k[None, :] * stride_qd tl.store(dq_ptrs, dq) else: # THIS BLOCK DOES DK & DV @@ -606,13 +620,13 @@ def flex_attention_backward_grid(batch_size, num_heads, num_queries, d_model, me dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) # load K and V: they stay in SRAM throughout the inner loop. - k = tl.load(K + offs_n1[:, None] * stride_tok + offs_k[None, :] * stride_d) - v = tl.load(V + offs_n1[:, None] * stride_tok + offs_k[None, :] * stride_d) + k = tl.load(K + offs_n1[:, None] * stride_km + offs_k[None, :] * stride_kd) + v = tl.load(V + offs_n1[:, None] * stride_vm + offs_k[None, :] * stride_vd) offs_m1 = start_m1 + tl.arange(0, BLOCK_M1) offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) - qT_ptrs = Q + offs_m1[None, :] * stride_tok + offs_k[:, None] * stride_d - do_ptrs = DO + offs_m1[:, None] * stride_tok + offs_k[None, :] * stride_d + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_qm + offs_k[None, :] * stride_qd # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) @@ -668,10 +682,10 @@ def flex_attention_backward_grid(batch_size, num_heads, num_queries, d_model, me dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT)) # Increment pointers. curr_m += BLOCK_M1 - qT_ptrs += BLOCK_M1 * stride_tok - do_ptrs += BLOCK_M1 * stride_tok + qT_ptrs += BLOCK_M1 * stride_qm + do_ptrs += BLOCK_M1 * stride_qm - dv_ptrs = DV + offs_n1[:, None] * stride_tok + offs_k[None, :] * stride_d + dv_ptrs = DV + offs_n1[:, None] * stride_vm + offs_k[None, :] * stride_vd tl.store(dv_ptrs, dv) # Write back dK. @@ -773,6 +787,7 @@ def flex_attention_backward(*args, **kwargs): layout=layout_k, # We use store_output only for grad_key subgraphs=[fw_subgraph_buffer, joint_subgraph_buffer], mutated_inputs=[grad_query, grad_value], + call_sizes=query.get_size() + [key.get_size()[2]], num_stages=num_stages, num_warps=num_warps, BLOCK_M1=BLOCK_M, diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index f9221d0dd49b..5b24a002082c 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -579,6 +579,7 @@ def generate( epilogue_fn=identity, subgraphs=None, mutated_inputs=None, + call_sizes=None, **kwargs, ): """This function generates a TritonTemplateCaller @@ -613,6 +614,9 @@ def generate( "64-bit indexing is not yet implemented for triton templates" ) + if call_sizes is None: + call_sizes = layout.size + kernel_options = dict( input_nodes=input_nodes, defines=defines, @@ -620,13 +624,14 @@ def generate( num_warps=num_warps, grid_fn=self.grid, meta=kwargs, - call_sizes=layout.size, + call_sizes=call_sizes, prefix_args=prefix_args, suffix_args=suffix_args, epilogue_fn=epilogue_fn, index_dtype="tl.int32", subgraphs=subgraphs, ) + with patch.object( V.graph, "get_dtype", self._fake_get_dtype(fake_out) ), TritonTemplateKernel( @@ -700,7 +705,7 @@ def make_kernel_render(out_node): assert mod.__file__ is not None grid = self.grid( *V.graph.sizevars.size_hints( - layout.size, + call_sizes, fallback=config.unbacked_symint_fallback, ), kwargs, diff --git a/torch/nn/attention/_flex_attention.py b/torch/nn/attention/_flex_attention.py index bd999ec39118..430d3280442a 100644 --- a/torch/nn/attention/_flex_attention.py +++ b/torch/nn/attention/_flex_attention.py @@ -101,11 +101,6 @@ def score_mod( # Some basic input validation _validate_sdpa_input(query, key, value) - # This will restriction will be removed in newer version of the kernel - if query.size(-2) != key.size(-2): - raise ValueError( - "NYI: The target sequence length (L) of the query tensor must match the source sequence length (S) of the key tensor." - ) if query.size(-2) % 128 != 0: raise ValueError("NYI: S and L must be a multiple of 128") From 48846cd1646d4e7afe5bf41500a93f89be105ea6 Mon Sep 17 00:00:00 2001 From: Feng Yuan Date: Mon, 3 Jun 2024 05:55:00 +0000 Subject: [PATCH 246/706] Update torch-xpu-ops pin (ATen XPU implementation) (#127730) Regular bi-weekly pin update. 1. Porting operator relative PyTorch unit tests. The existing operators in torch-xpu-ops are covered by, 1) Operator specific test, like test_binary_ufuncs.py. 2) Operator common test, like test_ops.py. 2. Bugfixing under the latest PyTorch unit test scope, https://github.com/intel/torch-xpu-ops/tree/release/2.4/test/xpu. Totally 297 ATen operators are implemented in torch-xpu-ops. https://github.com/intel/torch-xpu-ops/blob/release/2.4/yaml/xpu_functions.yaml Pull Request resolved: https://github.com/pytorch/pytorch/pull/127730 Approved by: https://github.com/EikanWang --- third_party/xpu.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xpu.txt b/third_party/xpu.txt index d3e312dadded..7131a86c765c 100644 --- a/third_party/xpu.txt +++ b/third_party/xpu.txt @@ -1 +1 @@ -aba5d332bb88d422a1256bb2ca5f60243ffc270f +bd76ae2a5a233ae57911c1de81322dcea19493c1 From 6d21685b45336b793b1172a7ce76b0bf3876eebf Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Fri, 31 May 2024 12:00:40 -0700 Subject: [PATCH 247/706] [DSD] Fixes various bugs for broadcast_from_rank0 (#127635) Fixes https://github.com/pytorch/pytorch/issues/126285 Summary: 1. Fixes https://github.com/pytorch/pytorch/issues/126285 2. Broadcasting one tensor per time to avoid OOM. 3. Add some docstring Pull Request resolved: https://github.com/pytorch/pytorch/pull/127635 Approved by: https://github.com/weifengpy --- .../distributed/checkpoint/test_state_dict.py | 16 +++++++++++++ torch/distributed/_state_dict_utils.py | 24 +++++++++++++++---- torch/distributed/checkpoint/state_dict.py | 4 +++- 3 files changed, 38 insertions(+), 6 deletions(-) diff --git a/test/distributed/checkpoint/test_state_dict.py b/test/distributed/checkpoint/test_state_dict.py index 8039c487962f..ccd1303c26db 100644 --- a/test/distributed/checkpoint/test_state_dict.py +++ b/test/distributed/checkpoint/test_state_dict.py @@ -605,9 +605,11 @@ def check(equal): # Drop the states to simulate loading from rank0 if dist.get_rank() > 0: load_states = {} + load_states2 = {} load_optim_states = {} else: load_states = copy.deepcopy(states) + load_states2 = copy.deepcopy(states) load_optim_states = copy.deepcopy(optim_states) set_model_state_dict( @@ -625,7 +627,21 @@ def check(equal): broadcast_from_rank0=True, full_state_dict=True ), ) + check(equal=True) + # Verify the `strict` flag. + load_states = load_states2 + if load_states: + key = next(iter(load_states.keys())) + load_states.pop(key) + with self.assertRaisesRegex(RuntimeError, "Missing key"): + set_model_state_dict( + fsdp_model, + model_state_dict=load_states, + options=StateDictOptions( + broadcast_from_rank0=True, full_state_dict=True + ), + ) device_mesh = init_device_mesh("cuda", (self.world_size,)) self.run_subtests( diff --git a/torch/distributed/_state_dict_utils.py b/torch/distributed/_state_dict_utils.py index fced8800047c..2ec7be89c9e0 100644 --- a/torch/distributed/_state_dict_utils.py +++ b/torch/distributed/_state_dict_utils.py @@ -513,7 +513,11 @@ def _broadcast_tensors( if pg is None: pg = dist.distributed_c10d._get_default_group() - dist._broadcast_coalesced(pg, tensors, 500, 0) + + if len(tensors) > 1: + dist._broadcast_coalesced(pg, tensors, 500, 0) + else: + dist.broadcast(tensors[0], src=0, group=pg) for key in keys: _local_state = local_state_dict.get(key, None) @@ -532,9 +536,11 @@ def _broadcast_state_dict( local_state_dict: Dict[str, Any], device: torch.device, pg: Optional[dist.ProcessGroup] = None, + strict: bool = False, ) -> None: - # Gather the full state dict keys, non tensor values, scalar tensor values, - # and tensor information. + # Broadcast from rank0's `full_state_dict` to all ranks' `local_state_dict`. + # If strict is True, any keys in `local_state_dict` but not in `full_state_dict` + # will be removed from `local_state_dict`. ret = {} if dist.get_rank() == 0: for key, value in full_state_dict.items(): @@ -551,7 +557,10 @@ def _broadcast_state_dict( # Gather values keys = [] + local_state_dict_keys = set(local_state_dict.keys()) + global_keys = set() for key, value in ret.items(): + global_keys.add(key) if not isinstance(value, _TensorInfo): if key in local_state_dict: local_state_dict[key] = value @@ -561,11 +570,16 @@ def _broadcast_state_dict( ret[key] = full_state_dict[key] keys.append(key) - # Broadcast every 10 tensors, just hardcode the number for now - if len(keys) >= 10: + # Broadcast every tensor to avoid OOM for now. + if len(keys) >= 1: _broadcast_tensors(ret, local_state_dict, keys, device, pg) keys.clear() + if strict: + if missing_keys := (local_state_dict_keys - global_keys): + for key in missing_keys: + local_state_dict.pop(key) + if keys: _broadcast_tensors(ret, local_state_dict, keys, device, pg) diff --git a/torch/distributed/checkpoint/state_dict.py b/torch/distributed/checkpoint/state_dict.py index 714210910072..0c4cc32c09a1 100644 --- a/torch/distributed/checkpoint/state_dict.py +++ b/torch/distributed/checkpoint/state_dict.py @@ -523,7 +523,9 @@ def _load_model_state_dict( else: assert device == value.device assert device is not None - _broadcast_state_dict(state_dict, local_state_dict, device=device) + _broadcast_state_dict( + state_dict, local_state_dict, device=device, strict=info.strict + ) for fqn, local_state in local_state_dict.items(): state_dict[fqn] = local_state From 10e3406ea5d115a54a7d753d33110762eb6c07ff Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 3 Jun 2024 07:15:44 +0000 Subject: [PATCH 248/706] [Inductor] Add FlexAttention backward kernel dynamic shape tests (#127728) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/127728 Approved by: https://github.com/Chillee --- test/inductor/test_flex_attention.py | 158 +++++++++++++++++++-------- 1 file changed, 110 insertions(+), 48 deletions(-) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index c4504e5574a1..59fbd31c09cf 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -151,6 +151,47 @@ def _check_equal( msg = f"{name} Compiled error {compiled_error} is greater than ref error {ref_error} by more than {fudge_factor}X." self.assertTrue(False, msg) + def _check_out_and_grad( + self, + golden_out: torch.Tensor, + ref_out: torch.Tensor, + compiled_out: torch.Tensor, + q_gold: torch.Tensor, + q_ref: torch.Tensor, + q: torch.Tensor, + k_gold: torch.Tensor, + k_ref: torch.Tensor, + k: torch.Tensor, + v_gold: torch.Tensor, + v_ref: torch.Tensor, + v: torch.Tensor, + ): + dtype = ref_out.dtype + with torch.no_grad(): + # Note, it seems like we really are less accurate than the float32 + # computation, likely due to the online softmax + if dtype == torch.float32: + fudge_factor = 10.0 + else: + fudge_factor = 1.1 + + # Checkout output + self._check_equal(golden_out, ref_out, compiled_out, fudge_factor, "Out") + + # Check gradients + q_fudge_factor = 2.5 * fudge_factor + self._check_equal( + q_gold.grad, q_ref.grad, q.grad, q_fudge_factor, "Grad_Query" + ) + k_fudge_factor = 4 * fudge_factor + self._check_equal( + k_gold.grad, k_ref.grad, k.grad, k_fudge_factor, "Grad_Key" + ) + v_fudge_factor = 4 * fudge_factor + self._check_equal( + v_gold.grad, v_ref.grad, v.grad, v_fudge_factor, "Grad_Value" + ) + def run_test( self, score_mod: Callable, @@ -187,30 +228,20 @@ def run_test( ref_out.backward(backward_grad) compiled_out.backward(backward_grad) - with torch.no_grad(): - # Note, it seems like we really are less accurate than the float32 - # computation, likely due to the online softmax - if dtype == torch.float32: - fudge_factor = 10.0 - else: - fudge_factor = 1.1 - - # Checkout output - self._check_equal(golden_out, ref_out, compiled_out, fudge_factor, "Out") - - # Check gradients - q_fudge_factor = 2.5 * fudge_factor - self._check_equal( - q_gold.grad, q_ref.grad, q.grad, q_fudge_factor, "Grad_Query" - ) - k_fudge_factor = 4 * fudge_factor - self._check_equal( - k_gold.grad, k_ref.grad, k.grad, k_fudge_factor, "Grad_Key" - ) - v_fudge_factor = 4 * fudge_factor - self._check_equal( - v_gold.grad, v_ref.grad, v.grad, v_fudge_factor, "Grad_Value" - ) + self._check_out_and_grad( + golden_out, + ref_out, + compiled_out, + q_gold, + q_ref, + q, + k_gold, + k_ref, + k, + v_gold, + v_ref, + v, + ) def run_dynamic_test( self, @@ -223,24 +254,34 @@ def run_dynamic_test( ): sdpa_partial = create_attention(score_mod) # The first eager batch, shape (B, H, S, D) - q1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") - k1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") - v1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") - golden_out1 = sdpa_partial( - q1.to(torch.float64), k1.to(torch.float64), v1.to(torch.float64) - ) - ref_out1 = sdpa_partial(q1, k1, v1) + q1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) + k1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) + v1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) + q1_ref, k1_ref, v1_ref = query_key_value_clones(q1, k1, v1) + q1_gold, k1_gold, v1_gold = query_key_value_clones(q1, k1, v1, torch.float64) + ref_out1 = sdpa_partial(q1_ref, k1_ref, v1_ref) + golden_out1 = sdpa_partial(q1_gold, k1_gold, v1_gold) + + backward_grad1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") + + golden_out1.backward(backward_grad1.to(torch.float64)) + ref_out1.backward(backward_grad1) # The second eager batch, shape (B * 2, H, S / 2, D) B = int(B * 2) S = int(S / 2) - q2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") - k2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") - v2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") - golden_out2 = sdpa_partial( - q2.to(torch.float64), k2.to(torch.float64), v2.to(torch.float64) - ) - ref_out2 = sdpa_partial(q2, k2, v2) + q2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) + k2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) + v2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) + q2_ref, k2_ref, v2_ref = query_key_value_clones(q2, k2, v2) + q2_gold, k2_gold, v2_gold = query_key_value_clones(q2, k2, v2, torch.float64) + ref_out2 = sdpa_partial(q2_ref, k2_ref, v2_ref) + golden_out2 = sdpa_partial(q2_gold, k2_gold, v2_gold) + + backward_grad2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") + + golden_out2.backward(backward_grad2.to(torch.float64)) + ref_out2.backward(backward_grad2) # Need to clear dynamo counters, since flex attention eager mode also uses dynamo tracing. # We check dynamo counters["frames"]["ok"] to ensure there is no re-compilation. @@ -248,20 +289,41 @@ def run_dynamic_test( # Compiling with dynamic shape in the first batch. compiled_sdpa = torch.compile(sdpa_partial, dynamic=True) compiled_out1 = compiled_sdpa(q1, k1, v1) - - # Note, it seems like we really are less accurate than the float32 - # computation, likely due to the online softmax - if dtype == torch.float32: - fudge_factor = 10.0 - else: - fudge_factor = 1.1 - - self._check_equal(golden_out1, ref_out1, compiled_out1, fudge_factor) + compiled_out1.backward(backward_grad1) + + self._check_out_and_grad( + golden_out1, + ref_out1, + compiled_out1, + q1_gold, + q1_ref, + q1, + k1_gold, + k1_ref, + k1, + v1_gold, + v1_ref, + v1, + ) self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1) # No re-compilation, use the compiled dynamic shape version. compiled_out2 = compiled_sdpa(q2, k2, v2) - self._check_equal(golden_out2, ref_out2, compiled_out2, fudge_factor) + compiled_out2.backward(backward_grad2) + self._check_out_and_grad( + golden_out2, + ref_out2, + compiled_out2, + q2_gold, + q2_ref, + q2, + k2_gold, + k2_ref, + k2, + v2_gold, + v2_ref, + v2, + ) self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1) def run_automatic_dynamic_test( From 2d1ad0c31a2c3021b3ac274af98c2ec8449e010c Mon Sep 17 00:00:00 2001 From: diwei sun Date: Mon, 3 Jun 2024 07:37:27 +0000 Subject: [PATCH 249/706] [CI] Add freezing for cpu inductor accuracy test in inductor CI (#124715) This PR is to enable '--freezing' when running dynamo accuracy check in CI. Backgroud: ISSUES[#124286](https://github.com/pytorch/pytorch/issues/124286) is not captured by CI since freezing is not enabled for cpu-inductor. Pull Request resolved: https://github.com/pytorch/pytorch/pull/124715 Approved by: https://github.com/chuanqi129, https://github.com/jgong5, https://github.com/atalman, https://github.com/desertfire --- .ci/pytorch/test.sh | 6 +- .github/workflows/inductor.yml | 5 + ...nductor_huggingface_freezing_inference.csv | 185 ++++++++++ .../cpu_inductor_timm_freezing_inference.csv | 245 +++++++++++++ ...inductor_torchbench_freezing_inference.csv | 341 ++++++++++++++++++ 5 files changed, 781 insertions(+), 1 deletion(-) create mode 100644 benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_huggingface_freezing_inference.csv create mode 100644 benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_timm_freezing_inference.csv create mode 100644 benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_freezing_inference.csv diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 1d185747abf8..ee4bf37fdb0b 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -565,7 +565,11 @@ test_dynamo_benchmark() { test_single_dynamo_benchmark "dashboard" "$suite" "$shard_id" "$@" else if [[ "${TEST_CONFIG}" == *cpu_inductor* ]]; then - test_single_dynamo_benchmark "inference" "$suite" "$shard_id" --inference --float32 "$@" + if [[ "${TEST_CONFIG}" == *freezing* ]]; then + test_single_dynamo_benchmark "inference" "$suite" "$shard_id" --inference --float32 --freezing "$@" + else + test_single_dynamo_benchmark "inference" "$suite" "$shard_id" --inference --float32 "$@" + fi elif [[ "${TEST_CONFIG}" == *aot_inductor* ]]; then test_single_dynamo_benchmark "inference" "$suite" "$shard_id" --inference --bfloat16 "$@" else diff --git a/.github/workflows/inductor.yml b/.github/workflows/inductor.yml index 4afd87f056f3..0f9c81104f9f 100644 --- a/.github/workflows/inductor.yml +++ b/.github/workflows/inductor.yml @@ -230,6 +230,11 @@ jobs: { config: "cpu_inductor_timm", shard: 2, num_shards: 2, runner: "linux.12xlarge" }, { config: "cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "linux.12xlarge" }, { config: "cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "linux.12xlarge" }, + { config: "cpu_inductor_huggingface_freezing", shard: 1, num_shards: 1, runner: "linux.12xlarge" }, + { config: "cpu_inductor_timm_freezing", shard: 1, num_shards: 2, runner: "linux.12xlarge" }, + { config: "cpu_inductor_timm_freezing", shard: 2, num_shards: 2, runner: "linux.12xlarge" }, + { config: "cpu_inductor_torchbench_freezing", shard: 1, num_shards: 2, runner: "linux.12xlarge" }, + { config: "cpu_inductor_torchbench_freezing", shard: 2, num_shards: 2, runner: "linux.12xlarge" }, { config: "dynamic_cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.12xlarge" }, { config: "dynamic_cpu_inductor_timm", shard: 1, num_shards: 2, runner: "linux.12xlarge" }, { config: "dynamic_cpu_inductor_timm", shard: 2, num_shards: 2, runner: "linux.12xlarge" }, diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_huggingface_freezing_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_huggingface_freezing_inference.csv new file mode 100644 index 000000000000..349239b058a7 --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_huggingface_freezing_inference.csv @@ -0,0 +1,185 @@ +name,accuracy,graph_breaks + + + +AlbertForMaskedLM,pass,0 + + + +AlbertForQuestionAnswering,pass,0 + + + +AllenaiLongformerBase,pass,4 + + + +BartForCausalLM,pass,0 + + + +BartForConditionalGeneration,pass,0 + + + +BertForMaskedLM,pass,0 + + + +BertForQuestionAnswering,pass,0 + + + +BlenderbotForCausalLM,pass_due_to_skip,0 + + + +BlenderbotSmallForCausalLM,pass,0 + + + +BlenderbotSmallForConditionalGeneration,pass,0 + + + +CamemBert,pass,0 + + + +DebertaForMaskedLM,pass,0 + + + +DebertaForQuestionAnswering,pass,0 + + + +DebertaV2ForMaskedLM,pass_due_to_skip,0 + + + +DebertaV2ForQuestionAnswering,pass,0 + + + +DistilBertForMaskedLM,pass,0 + + + +DistilBertForQuestionAnswering,pass,0 + + + +DistillGPT2,pass,0 + + + +ElectraForCausalLM,pass,0 + + + +ElectraForQuestionAnswering,pass,0 + + + +GPT2ForSequenceClassification,pass,2 + + + +GoogleFnet,pass,0 + + + +LayoutLMForMaskedLM,pass,0 + + + +LayoutLMForSequenceClassification,pass,2 + + + +M2M100ForConditionalGeneration,pass,0 + + + +MBartForCausalLM,pass,0 + + + +MBartForConditionalGeneration,pass,0 + + + +MT5ForConditionalGeneration,pass,0 + + + +MegatronBertForCausalLM,pass,0 + + + +MegatronBertForQuestionAnswering,pass,0 + + + +MobileBertForMaskedLM,pass,0 + + + +MobileBertForQuestionAnswering,pass,0 + + + +OPTForCausalLM,pass,0 + + + +PLBartForCausalLM,pass,0 + + + +PLBartForConditionalGeneration,pass,0 + + + +PegasusForCausalLM,pass,0 + + + +PegasusForConditionalGeneration,pass,0 + + + +RobertaForCausalLM,pass,0 + + + +RobertaForQuestionAnswering,pass,0 + + + +Speech2Text2ForCausalLM,pass,0 + + + +T5ForConditionalGeneration,pass,0 + + + +T5Small,pass,0 + + + +TrOCRForCausalLM,pass,0 + + + +XGLMForCausalLM,pass,0 + + + +XLNetLMHeadModel,pass,0 + + + +YituTechConvBert,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_timm_freezing_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_timm_freezing_inference.csv new file mode 100644 index 000000000000..c889ba0e8d2f --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_timm_freezing_inference.csv @@ -0,0 +1,245 @@ +name,accuracy,graph_breaks + + + +adv_inception_v3,pass,0 + + + +beit_base_patch16_224,pass,0 + + + +botnet26t_256,pass,0 + + + +cait_m36_384,pass,0 + + + +coat_lite_mini,pass,0 + + + +convit_base,pass,0 + + + +convmixer_768_32,pass,0 + + + +convnext_base,pass,0 + + + +crossvit_9_240,pass,0 + + + +cspdarknet53,pass,0 + + + +deit_base_distilled_patch16_224,pass,0 + + + +dla102,pass,0 + + + +dm_nfnet_f0,pass,0 + + + +dpn107,pass,0 + + + +eca_botnext26ts_256,pass,0 + + + +eca_halonext26ts,pass,0 + + + +ese_vovnet19b_dw,pass,0 + + + +fbnetc_100,pass,0 + + + +fbnetv3_b,pass,0 + + + +gernet_l,pass,0 + + + +ghostnet_100,pass,0 + + + +gluon_inception_v3,pass,0 + + + +gmixer_24_224,pass,0 + + + +gmlp_s16_224,pass,0 + + + +hrnet_w18,pass,0 + + + +inception_v3,pass,0 + + + +jx_nest_base,pass,0 + + + +lcnet_050,pass,0 + + + +levit_128,pass,0 + + + +mixer_b16_224,pass,0 + + + +mixnet_l,pass,0 + + + +mnasnet_100,pass,0 + + + +mobilenetv2_100,pass,0 + + + +mobilenetv3_large_100,pass,0 + + + +mobilevit_s,pass,0 + + + +nfnet_l0,pass,0 + + + +pit_b_224,pass,0 + + + +pnasnet5large,pass,0 + + + +poolformer_m36,pass,0 + + + +regnety_002,pass,0 + + + +repvgg_a2,pass,0 + + + +res2net101_26w_4s,pass,0 + + + +res2net50_14w_8s,pass,0 + + + +res2next50,pass,0 + + + +resmlp_12_224,pass,0 + + + +resnest101e,pass,0 + + + +rexnet_100,pass,0 + + + +sebotnet33ts_256,pass,0 + + + +selecsls42b,pass,0 + + + +spnasnet_100,pass,0 + + + +swin_base_patch4_window7_224,pass,0 + + + +swsl_resnext101_32x16d,pass,0 + + + +tf_efficientnet_b0,pass,0 + + + +tf_mixnet_l,pass,0 + + + +tinynet_a,pass,0 + + + +tnt_s_patch16_224,pass,0 + + + +twins_pcpvt_base,pass,0 + + + +visformer_small,pass,0 + + + +vit_base_patch16_224,pass,0 + + + +volo_d1_224,pass,0 + + + +xcit_large_24_p8_224,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_freezing_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_freezing_inference.csv new file mode 100644 index 000000000000..3942e3a2f343 --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_freezing_inference.csv @@ -0,0 +1,341 @@ +name,accuracy,graph_breaks + + + +BERT_pytorch,pass,0 + + + +Background_Matting,pass_due_to_skip,0 + + + +DALLE2_pytorch,model_fail_to_load,0 + + + +LearningToPaint,pass,0 + + + +Super_SloMo,pass,0 + + + +alexnet,pass,0 + + + +basic_gnn_edgecnn,pass,0 + + + +basic_gnn_gcn,pass,6 + + + +basic_gnn_gin,pass,0 + + + +basic_gnn_sage,pass,0 + + + +dcgan,pass,0 + + + +demucs,pass,3 + + + +densenet121,pass,0 + + + +detectron2_fasterrcnn_r_101_c4,pass,42 + + + +detectron2_fasterrcnn_r_101_dc5,pass,42 + + + +detectron2_fasterrcnn_r_101_fpn,pass,46 + + + +detectron2_fasterrcnn_r_50_c4,pass,42 + + + +detectron2_fasterrcnn_r_50_dc5,pass,42 + + + +detectron2_fasterrcnn_r_50_fpn,pass,46 + + + +detectron2_fcos_r_50_fpn,pass,23 + + + +detectron2_maskrcnn_r_101_c4,pass,57 + + + +detectron2_maskrcnn_r_101_fpn,fail_accuracy,64 + + + +detectron2_maskrcnn_r_50_c4,fail_accuracy,57 + + + +detectron2_maskrcnn_r_50_fpn,pass,64 + + + +dlrm,pass,0 + + + +doctr_det_predictor,pass,5 + + + +doctr_reco_predictor,pass,4 + + + +drq,pass,0 + + + +fastNLP_Bert,pass,4 + + + +functorch_dp_cifar10,pass,0 + + + +functorch_maml_omniglot,pass,0 + + + +hf_Albert,pass,0 + + + +hf_Bart,pass,0 + + + +hf_Bert,pass,0 + + + +hf_Bert_large,pass,0 + + + +hf_DistilBert,pass,0 + + + +hf_GPT2,pass,0 + + + +hf_GPT2_large,pass_due_to_skip,0 + + + +hf_Reformer,pass,5 + + + +hf_T5_base,pass,0 + + + +hf_T5_large,pass_due_to_skip,0 + + + +hf_distil_whisper,pass,0 + + + +lennard_jones,pass,0 + + + +llama,pass,0 + + + +maml,pass_due_to_skip,0 + + + +maml_omniglot,pass,0 + + + +mnasnet1_0,pass,0 + + + +mobilenet_v2,pass,0 + + + +mobilenet_v2_quantized_qat,pass,2 + + + +mobilenet_v3_large,pass,0 + + + +moco,model_fail_to_load,0 + + + +moondream,pass,0 + + + +nvidia_deeprecommender,pass,0 + + + +opacus_cifar10,pass,0 + + + +phlippe_densenet,pass,0 + + + +phlippe_resnet,pass,0 + + + +pyhpc_equation_of_state,pass,0 + + + +pyhpc_isoneutral_mixing,pass,0 + + + +pyhpc_turbulent_kinetic_energy,pass,0 + + + +pytorch_CycleGAN_and_pix2pix,pass,0 + + + +pytorch_stargan,pass,0 + + + +pytorch_unet,pass,0 + + + +resnet152,pass,0 + + + +resnet18,pass,0 + + + +resnet50,pass,0 + + + +resnet50_quantized_qat,pass,2 + + + +resnext50_32x4d,pass,0 + + + +shufflenet_v2_x1_0,pass,0 + + + +soft_actor_critic,pass,0 + + + +speech_transformer,pass,10 + + + +squeezenet1_1,pass,0 + + + +stable_diffusion_unet,pass_due_to_skip,0 + + + +timm_efficientdet,model_fail_to_load,0 + + + +timm_efficientnet,pass,0 + + + +timm_nfnet,pass,0 + + + +timm_regnet,pass,0 + + + +timm_resnest,pass,0 + + + +timm_vision_transformer,pass,0 + + + +timm_vision_transformer_large,pass_due_to_skip,0 + + + +timm_vovnet,pass,0 + + + +torch_multimodal_clip,pass,0 + + + +tts_angular,pass,2 + + + +vgg16,pass,0 + + + +vision_maskrcnn,pass,28 + + + +yolov3,pass,2 From e017b56c0c5e167b36f03902e16220650849fb20 Mon Sep 17 00:00:00 2001 From: Xilun Wu <12968408+XilunWu@users.noreply.github.com> Date: Fri, 31 May 2024 17:19:53 +0000 Subject: [PATCH 250/706] [dtensor] local_map UX change: keep func signature and be compatible with Tensor input (#126924) **Summary** This PR has 2 parts of change in `local_map`: 1. regulates the way user can access `DeviceMesh` inside the `func` argument of `local_map`. This means `local_map` will strictly follow the `func` signature without implicitly passing any argument to `func`. If user wants to use `DeviceMesh` inside `func`, this mesh must be explicitly passed to `func` as an argument by user. For example, ``` def user_function(device_mesh, /, *args, **kwargs): USER CODE HERE local_func = local_map(func=user_function, ...) dtensor_out = local_func(device_mesh, dtensor_input, ...) ``` Before this PR, user code was like: ``` def user_function(device_mesh, /, *args, **kwargs): USER CODE HERE local_func = local_map(func=user_function, ...) dtensor_out = local_func(dtensor_input, ...) # local_map passes mesh implicitly for user ``` 2. `local_map` now supports mix use of `torch.Tensor` and `DTensor` in argument: - Pure torch.Tensor case: no `DTensor` argument is passed in, all tensor arguments are `torch.Tensor`. Bypass the `in_placements` check and unwrapping steps. The output will not be wrapped into `DTensor` but directly returned. - Pure DTensor case: no `torch.Tensor` argument is passed in, all tensor arguments are `DTensor`. This follows the default rule: `in_placements` check, unwrapping arguments, pass into `func`, wrapping the `torch.Tensor` output into `DTensor` if the `out_placements` is not `None`. - Mix of the above two: some arguments are `torch.Tensor` while some are `DTensor`. Only perform `in_placements` check and unwrapping on `DTensor` arguments. For output processing, it's the same as Pure DTensor case. **Test** `pytest test/distributed/_tensor/experimental/test_local_map.py` Pull Request resolved: https://github.com/pytorch/pytorch/pull/126924 Approved by: https://github.com/wanchaol --- .../_tensor/experimental/test_local_map.py | 123 +++++++++++++--- .../_tensor/experimental/local_map.py | 138 +++++++++++------- 2 files changed, 191 insertions(+), 70 deletions(-) diff --git a/test/distributed/_tensor/experimental/test_local_map.py b/test/distributed/_tensor/experimental/test_local_map.py index 1035df2f5f7d..b483194d6c3a 100644 --- a/test/distributed/_tensor/experimental/test_local_map.py +++ b/test/distributed/_tensor/experimental/test_local_map.py @@ -5,6 +5,7 @@ import torch.distributed._functional_collectives as funcol from torch.distributed._tensor import ( distribute_tensor, + DTensor, init_device_mesh, Replicate, Shard, @@ -18,23 +19,30 @@ ) -def equal_forward(device_mesh, X, Y): +funcol_py = torch.ops.c10d_functional + + +def equal_allgather_forward(device_mesh, X, Y): eq = torch.tensor([torch.equal(X, Y)], device=X.device) eq_gather = funcol.all_gather_tensor(eq, 0, device_mesh) return torch.all(eq_gather).item() -def mm_forward(device_mesh, W, X): - return torch.mm(W, X) +def mm_all_gather_forward(device_mesh, A, B): + local_mm_result = torch.mm(A, B) + return funcol.all_gather_tensor(local_mm_result, 0, device_mesh).wait() + +def mm_forward(A, B): # no device mesh needed since we don't do collective + return torch.mm(A, B) -def mm_allreduce_forward(device_mesh, W, X): - partial_sum_tensor = torch.mm(W, X) - reduced_tensor = funcol.all_reduce(partial_sum_tensor, "sum", device_mesh).wait() - return reduced_tensor +def mm_allreduce_forward(device_mesh, A, B): + partial_sum_tensor = torch.mm(A, B) + return funcol.all_reduce(partial_sum_tensor, "sum", device_mesh).wait() -def mul_forward(device_mesh, X, scalar): + +def mul_forward(X, scalar): # no device mesh needed since we don't do collective return torch.mul(X, scalar) @@ -58,6 +66,7 @@ def test_local_map_correctness(self): row_wise = [Shard(0)] # row-wise sharding placements on 1-d mesh col_wise = [Shard(1)] # col-wise sharding placements on 1-d mesh + replicate = [Replicate()] W_dt = distribute_tensor( W, device_mesh, col_wise ) # col-wisely sharded W tensor @@ -70,12 +79,12 @@ def test_local_map_correctness(self): # DTensors' `_local_tensor`. local_mm_allreduce_forward = local_map( mm_allreduce_forward, - out_placements=[Replicate()], - in_placements=(col_wise, row_wise), + out_placements=replicate, + in_placements=(None, col_wise, row_wise), device_mesh=device_mesh, ) with comm_mode: - Y_dt = local_mm_allreduce_forward(W_dt, X_dt) + Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt) # output redistribution to Replicate self.assertEqual(comm_mode.get_total_counts(), 1) @@ -88,6 +97,7 @@ def test_local_map_correctness(self): # check for `out_placements` @with_comms def test_local_map_out_placements(self): + # Test 1: wrap out into DTensor w/ `out_placements` device_mesh = init_device_mesh( device_type=self.device_type, mesh_shape=(self.world_size,) ) @@ -99,14 +109,40 @@ def test_local_map_out_placements(self): row_wise = [Shard(0)] X_dt = distribute_tensor(X, device_mesh, row_wise) Y_dt = distribute_tensor(Y, device_mesh, row_wise) - local_equal_forward = local_map(equal_forward, out_placements=None) + local_equal_allgather_forward = local_map( + equal_allgather_forward, + out_placements=None, + ) with comm_mode: - equal_dt = local_equal_forward(X_dt, Y_dt) # a bool + equal_dt = local_equal_allgather_forward(device_mesh, X_dt, Y_dt) # a bool self.assertEqual(comm_mode.get_total_counts(), 1) self.assertTrue(not equal_dt) self.assertTrue(not (X.equal(Y))) + # Test 2: directly return out if no argument is DTensor + # matmul in DDP + replicate = [Replicate()] + X = torch.randn( + 4 // self.world_size, 4, device=self.device_type, requires_grad=False + ) + W = torch.randn(4, 4, device=self.device_type, requires_grad=False) + local_mm_all_gather_forward = local_map( + mm_all_gather_forward, + out_placements=row_wise, + in_placements=(None, row_wise, replicate), + ) + with comm_mode: + Y = local_mm_all_gather_forward(device_mesh, X, W) + + self.assertEqual(comm_mode.get_total_counts(), 1) + self.assertEqual( + comm_mode.get_comm_counts()[funcol_py.all_gather_into_tensor], 1 + ) + X_replicate = funcol.all_gather_tensor(X, 0, device_mesh).wait() + Y_replicate = torch.mm(X_replicate, W) + self.assertEqual(Y, Y_replicate) # Y is a torch.Tensor + # check for `in_placements` handling @with_comms def test_local_map_in_placements(self): @@ -173,6 +209,54 @@ def test_local_map_in_placements(self): self.assertTrue(placement.is_shard(dim=0)) self.assertEqual(Y_dt.full_tensor(), Y) + # Test 4: `None` placements for Tensor input argument + X = torch.randn(16, 8, device=self.device_type, requires_grad=False) + W = torch.randn(8, 12, device=self.device_type, requires_grad=False) + X_dt = distribute_tensor( + X, device_mesh, row_wise + ) # row-wisely sharded X tensor + W_dt = distribute_tensor(W, device_mesh, replicate) # replicate W tensor + local_mm_forward = local_map( + mm_forward, + out_placements=None, + in_placements=(None, None), + device_mesh=device_mesh, + ) + with comm_mode: + Y_dt_local = local_mm_forward(X_dt.to_local(), W_dt.to_local()) + + self.assertEqual(comm_mode.get_total_counts(), 0) + self.assertEqual( + DTensor.from_local(Y_dt_local, device_mesh, row_wise).full_tensor(), + torch.mm(X, W), + ) + + # Test 5: Some placements for Tensor input argument + local_mm_forward = local_map( + mm_forward, + out_placements=None, + in_placements=(replicate, row_wise), + device_mesh=device_mesh, + ) + with comm_mode: + Y_dt_local = local_mm_forward(X_dt.to_local(), W_dt.to_local()) + + self.assertEqual(comm_mode.get_total_counts(), 0) + self.assertEqual( + DTensor.from_local(Y_dt_local, device_mesh, row_wise).full_tensor(), + torch.mm(X, W), + ) + + # Test 6: expect error - `None` placements for DTensor input argument + local_mm_forward = local_map( + mm_forward, + out_placements=row_wise, + in_placements=(row_wise, None), + device_mesh=device_mesh, + ) + with self.assertRaisesRegex(AssertionError, "expects placements"): + Y_dt = local_mm_forward(X_dt, W_dt) + # check for `redistribute_inputs` handling @with_comms def test_local_map_redistribute(self): @@ -188,6 +272,7 @@ def test_local_map_redistribute(self): row_wise = [Shard(0)] # row-wise sharding placements on 1-d mesh col_wise = [Shard(1)] # col-wise sharding placements on 1-d mesh + replicate = [Replicate()] W_dt = distribute_tensor( W, device_mesh, row_wise ) # row-wisely sharded W tensor which will be redistributed @@ -198,13 +283,13 @@ def test_local_map_redistribute(self): # Test 1: allow input redistribution local_mm_allreduce_forward = local_map( mm_allreduce_forward, - out_placements=[Replicate()], - in_placements=(col_wise, row_wise), + out_placements=replicate, + in_placements=(None, col_wise, row_wise), device_mesh=device_mesh, redistribute_inputs=True, ) with comm_mode: - Y_dt = local_mm_allreduce_forward(W_dt, X_dt) + Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt) # 2 for input redistribution and 1 for output self.assertEqual(comm_mode.get_total_counts(), 3) @@ -215,13 +300,13 @@ def test_local_map_redistribute(self): # Test 2: no input redistribution is allowed local_mm_allreduce_forward = local_map( mm_allreduce_forward, - out_placements=[Replicate()], - in_placements=(col_wise, row_wise), + out_placements=replicate, + in_placements=(None, col_wise, row_wise), device_mesh=device_mesh, redistribute_inputs=False, ) with self.assertRaisesRegex(ValueError, "set redistribute_inputs=True"): - Y_dt = local_mm_allreduce_forward(W_dt, X_dt) + Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt) if __name__ == "__main__": diff --git a/torch/distributed/_tensor/experimental/local_map.py b/torch/distributed/_tensor/experimental/local_map.py index 002ff5542a11..2bf12871cc36 100644 --- a/torch/distributed/_tensor/experimental/local_map.py +++ b/torch/distributed/_tensor/experimental/local_map.py @@ -2,6 +2,7 @@ from typing import Callable, Optional, Sequence, Tuple, Union import torch +from torch.distributed._functional_collectives import AsyncCollectiveTensor from torch.distributed._tensor import DeviceMesh, DTensor from torch.distributed._tensor.placement_types import Placement @@ -12,7 +13,7 @@ PlacementType = Optional[Sequence[Placement]] -InputPlacements = Union[PlacementType, Tuple[PlacementType, ...]] +InputPlacements = Optional[Tuple[PlacementType, ...]] OutputPlacements = Union[PlacementType, Tuple[PlacementType, ...]] @@ -32,24 +33,36 @@ def local_map( func (Callable): the function to be applied on each local shard of :class:`DTensor`s. out_placements (Union[`PlacementType`, Tuple[`PlacementType`, ...]]): - the desired placements of the output :class:`DTensor`s. If the `output` of - `func` is a Python collection, the `out_placements` will be a Tuple of - `PlacementType` values 1:1 mapping to the flattened `output`. For - :class:`Tensor` output, the corresponding `PlacementType` will be its + the desired placements of the :class:`DTensor`s in `func`'s flattened output. + If the flattened `output` is a single value, the `out_placements` should be + of type `PlacementType`. Otherwise if the flattened `output` has multiple + values, the `out_placements` should be a tuple of `PlacementType` values 1:1 + mapping to the flattened `output`. + Besides, for :class:`Tensor` output, we use `PlacementType` as its placements (a `Tuple[Placement]` value). For non-:class:`Tensor` output, - the `PlacementType` will be `None`. - in_placements (Union[`PlacementType`, Tuple[`PlacementType`, ...]], optional): - the required placements of the input :class:`DTensor`s. If not specified, - the input :class:`DTensor` will not be redistributed before passing its local - tensor to `func`. Similarly to `out_placements`, `in_placements` should keep - a 1:1 mapping to the flattened input of `func`. If a redistribution is - required according to `in_placements` and `redistribute_inputs` is `False`, - an exception will be raised. + the `PlacementType` should be `None`. + Note that the only exception is when no :class:`DTensor` argument is passed + in. In this case, even if `out_placements` is not `None`, the result function + should ignore the desired placements because the application is not on + :class:`DTensors`. + in_placements (Tuple[`PlacementType`, ...], optional): + the required placements of the :class:`DTensor`s in `func`'s flattened input. + If `in_placements` is specified, `local_map` would examine whether the + placements of each :class:`DTensor` argument is the same as the required + placements or not. If the placements are not the same and + `redistribute_inputs` is `False`, an exception will be raised. Otherwise if + `redistribute_inputs` is `True`, the argument will be first redistributed to + the required sharding placements before passing its local tensor to `func`. + The only exception is when required placements are not `None` and the + argument is a :class:`torch.Tensor`. In this case, the placements examination + will be skipped and the argument will be directly passed to `func`. + If `in_placements` is `None`, no placements examination will be performed. + Default: `None` device_mesh (:class:`DeviceMesh`, optional): the device mesh that all the :class:`DTensor`s are placed on. If not specified, this will be inferred from the input :class:`DTensor`s' device mesh. `local_map` requires every :class:`DTensor`s to be placed on the same - device mesh. + device mesh. Default: `None`. redistribute_inputs (bool, optional): the bool value indicating whether to reshard the input :class:`DTensor`s when their placements are different from the required input placements. If this @@ -93,9 +106,9 @@ def local_map( >>> device_mesh=device_mesh, >>> ) >>> - >>> W_dt = distribute_tensor(W, device_mesh, col_wise) # col-wisely sharded W tensor - >>> X_dt = distribute_tensor(X, device_mesh, row_wise) # row-wisely sharded X tensor - >>> Y_dt = local_mm_allreduce_forward(W_dt, X_dt) # apply local_mm_allreduce_forward to DTensors + >>> W_dt = distribute_tensor(W, device_mesh, (col_wise)) # col-wisely sharded W tensor + >>> X_dt = distribute_tensor(X, device_mesh, (row_wise)) # row-wisely sharded X tensor + >>> Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt) # apply local_mm_allreduce_forward to DTensors NOTE: This API is currently experimental and subject to change """ @@ -103,10 +116,16 @@ def local_map( def wrapped(*args, **kwargs): # process input args flat_args, args_spec = pytree.tree_flatten(args) + if in_placements is not None: + assert len(in_placements) == len(flat_args), ( + f"in_placements length {len(in_placements)} does not match the number " + f"of input args {len(flat_args)}!" + ) # we assume every DTensor object is placed on the same device mesh flat_local_args = [] nonlocal device_mesh # access var device_mesh from the outer scope + seen_dtensor_arg = False for idx, arg in enumerate(flat_args): if isinstance(arg, DTensor): # TODO: the current code doesn't consider the uneven sharding case @@ -115,17 +134,16 @@ def wrapped(*args, **kwargs): if device_mesh is None: # infer device mesh from the DTensor arg device_mesh = arg.device_mesh + # this function is applied to at least one DTensor argument + seen_dtensor_arg = True + assert arg.device_mesh == device_mesh, ( - f"arg {arg} in local_map has a mismatched device mesh:" - f"{arg} has device mesh {arg.device_mesh} while" + f"arg {arg} in local_map has a mismatched device mesh: " + f"{arg} has device mesh {arg.device_mesh} while " f"the expected device mesh is {device_mesh}!" ) if in_placements is not None: - spec = ( - in_placements[idx] - if isinstance(in_placements, tuple) - else in_placements - ) + spec = in_placements[idx] assert ( spec is not None ), f"DTensor input {arg} expects placements but received {spec}!" @@ -139,44 +157,62 @@ def wrapped(*args, **kwargs): arg = arg.redistribute(device_mesh, spec) else: raise ValueError( - f"arg {arg} in local_map has a mismatched placements:" - f"arg placements is {arg.placements} but the input" - f"placements is {spec}!" - "If redistribute_inputs is wanted, set redistribute_inputs=True to local_map." + f"arg {arg} in local_map has a mismatched placements: " + f"arg placements is {arg.placements} but the input " + f"placements is {spec}! " + "If redistribute_inputs is wanted, set " + "redistribute_inputs=True to local_map." ) - flat_local_args.append(arg.to_local()) + local_arg = arg.to_local() + if isinstance(local_arg, AsyncCollectiveTensor): + local_arg = local_arg.wait() + + flat_local_args.append(local_arg) else: + # Non-Tensor input must have None in `in_placements` + if in_placements is not None and not isinstance(arg, torch.Tensor): + spec = in_placements[idx] + assert spec is None, ( + f"Non-Tensor input {arg} expects None placements " + f"but received {spec}!" + ) + flat_local_args.append(arg) local_args = pytree.tree_unflatten(flat_local_args, args_spec) - out = func(device_mesh, *local_args, **kwargs) + out = func(*local_args, **kwargs) - # process output - flat_out, out_spec = pytree.tree_flatten(out) - flat_dist_out = [] - for idx, out in enumerate(flat_out): - spec = ( - out_placements[idx] - if isinstance(out_placements, tuple) - else out_placements - ) - if isinstance(out, torch.Tensor): - assert not isinstance( - out, DTensor - ), f"torch.Tensor output expected but received {type(out)}: {out}" + if seen_dtensor_arg: + # process output + flat_out, out_spec = pytree.tree_flatten(out) - flat_dist_out.append( - DTensor.from_local(out, device_mesh, spec, run_check=False) + flat_dist_out = [] + for idx, out in enumerate(flat_out): + spec = ( + out_placements[idx] + if isinstance(out_placements, tuple) + else out_placements ) - else: - assert ( - spec is None - ), f"Non-tensor output {out} expects None placements but received {spec}!" - flat_dist_out.append(out) + if isinstance(out, torch.Tensor): + assert not isinstance( + out, DTensor + ), f"torch.Tensor output expected but received {type(out)}: {out}" + + flat_dist_out.append( + DTensor.from_local(out, device_mesh, spec, run_check=False) + ) + else: + assert ( + spec is None + ), f"Non-tensor output {out} expects None placements but received {spec}!" + + flat_dist_out.append(out) - return pytree.tree_unflatten(flat_dist_out, out_spec) + return pytree.tree_unflatten(flat_dist_out, out_spec) + else: + return out return wrapped From f343f98710dfa7305a873f558086c595a3c3d3d4 Mon Sep 17 00:00:00 2001 From: Daniil Kutz Date: Mon, 3 Jun 2024 08:48:12 +0000 Subject: [PATCH 251/706] [jit] Validate mobile module fields parsed by flatbuffer loader (#127437) Fixing error in `torch.jit.load` Python API function that cause crash in C-backend of PyTorch. The mobile module is succesfully parsed from flatbuffer format, but its fields are used without any validation. Fixes #127434 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127437 Approved by: https://github.com/davidberard98 --- torch/csrc/jit/mobile/flatbuffer_loader.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/torch/csrc/jit/mobile/flatbuffer_loader.cpp b/torch/csrc/jit/mobile/flatbuffer_loader.cpp index 239deb76d267..bca407358913 100644 --- a/torch/csrc/jit/mobile/flatbuffer_loader.cpp +++ b/torch/csrc/jit/mobile/flatbuffer_loader.cpp @@ -296,6 +296,11 @@ mobile::Module FlatbufferLoader::parseModule( "Parsing flatbuffer module: Corrupted ivalues/object_types field"); TORCH_CHECK( reinterpret_cast(ivalues) < end, "Corrupted ivalues field"); + TORCH_CHECK( + module->storage_data_size() >= 0, + "Parsing flatbuffer module: illegal storage_data_size: ", + module->storage_data_size(), + ", expected to be non negative"); all_ivalues_.resize(ivalues->size()); all_types_.resize(module->object_types()->size()); storages_.resize(module->storage_data_size()); From d6963e769ce08d1da9e449ae7ea011dd322bac40 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Sat, 1 Jun 2024 21:39:45 -0700 Subject: [PATCH 252/706] Force Inductor output code to be dumped even if it fails to compile (#127700) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/127700 Approved by: https://github.com/oulgen --- torch/_inductor/graph.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 743018014c4c..411ac0b45ebb 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -1688,8 +1688,25 @@ def compile_to_module(self): code, linemap = ( self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen() ) - linemap = [(line_no, node.stack_trace) for line_no, node in linemap] - key, path = PyCodeCache.write(code) + + output_code_log.debug("Output code: \n%s", code) + try: + linemap = [(line_no, node.stack_trace) for line_no, node in linemap] + key, path = PyCodeCache.write(code) + except Exception: + trace_structured( + "inductor_output_code", + # Just omit the filename, I still want the code though! + payload_fn=lambda: code, + ) + raise + else: + trace_structured( + "inductor_output_code", + lambda: {"filename": path}, + payload_fn=lambda: code, + ) + mod = PyCodeCache.load_by_key_path( key, path, @@ -1706,12 +1723,6 @@ def compile_to_module(self): log_module_code(mod.__file__) log.debug("Output code written to: %s", mod.__file__) - output_code_log.debug("Output code: \n%s", code) - trace_structured( - "inductor_output_code", - lambda: {"filename": mod.__file__}, - payload_fn=lambda: code, - ) output_code_log.info("Output code written to: %s", mod.__file__) if config.benchmark_kernel: print(f"Compiled module path: {mod.__file__}", file=sys.stderr) From f03f8bc901a6c9038308a6353e8d280f4b5628f5 Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Wed, 29 May 2024 18:20:31 +0000 Subject: [PATCH 253/706] Add aten._unsafe_masked_index (#116491) To generate masked indexing operations that would generate masked loads in triton code Pull Request resolved: https://github.com/pytorch/pytorch/pull/116491 Approved by: https://github.com/lezcano, https://github.com/peterbell10 --- .../src/ATen/functorch/BatchRulesIndexing.cpp | 23 ++++++ .../ATen/native/TensorAdvancedIndexing.cpp | 76 ++++++++++++++++++ aten/src/ATen/native/native_functions.yaml | 12 +++ test/distributed/_tensor/test_dtensor_ops.py | 2 + test/inductor/test_torchinductor.py | 32 +++++++- test/inductor/test_torchinductor_opinfo.py | 2 + test/onnx/test_fx_op_consistency.py | 5 ++ test/test_fx_experimental.py | 45 ++++++++--- test/test_mps.py | 8 ++ tools/autograd/derivatives.yaml | 11 +++ tools/autograd/gen_variable_type.py | 4 + torch/_decomp/__init__.py | 2 + torch/_decomp/decompositions.py | 55 ++++++++++++- torch/_dynamo/trace_rules.py | 2 + torch/_inductor/codegen/triton.py | 5 ++ torch/_inductor/decomposition.py | 2 + torch/_inductor/lowering.py | 70 +++++++++++++--- torch/_inductor/utils.py | 1 + .../_internal/common_methods_invocations.py | 79 +++++++++++++++++-- 19 files changed, 408 insertions(+), 28 deletions(-) create mode 100644 aten/src/ATen/functorch/BatchRulesIndexing.cpp diff --git a/aten/src/ATen/functorch/BatchRulesIndexing.cpp b/aten/src/ATen/functorch/BatchRulesIndexing.cpp new file mode 100644 index 000000000000..eb571b298078 --- /dev/null +++ b/aten/src/ATen/functorch/BatchRulesIndexing.cpp @@ -0,0 +1,23 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include + +namespace at { namespace functorch { + +#define OP_DECOMPOSE(op) m.impl(#op, static_cast(native::op)); +#define OP_DECOMPOSE2(op, overload) m.impl(#op"."#overload, static_cast(native::op)); + +TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { + OP_DECOMPOSE2(_unsafe_index, Tensor); + OP_DECOMPOSE(_unsafe_masked_index); + OP_DECOMPOSE(_unsafe_index_put); + OP_DECOMPOSE(_unsafe_masked_index_put_accumulate); +} + +}} diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp index 395af8e5ef13..041c3f2770ea 100644 --- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp +++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp @@ -656,6 +656,82 @@ Tensor _unsafe_index(const Tensor& self, const torch::List return at::index(self, indices); } +Tensor _unsafe_masked_index(const Tensor& self, const Tensor& mask, const torch::List>& indices, const Scalar& fill) { + // Unsafe masked index is equivalent to + // where(mask, self[indices], fill) + // with the main difference being that the when the `mask` is false, the tensor + // `self` is not indexed using `indices`. This allows `indices` to be out-of-bounds + // when `mask` is false. When `mask` is true, the `indices` are expected to be + // in bounds and is not checked. + // + // This function is not meant to be executed on eager mode. An unoptimized version + // is provided here. + // + // compiler backends should implement this op such that `self[indices]` is not + // loaded when `mask` is true. See inductor for a reference. + auto clamp = [](const c10::optional& index, auto size) -> c10::optional { + if (!index) { + return index; + } + // Disallow bool + auto dtype = index->scalar_type(); + TORCH_CHECK(dtype == kLong || dtype == kInt, + "_unsafe_masked_index found unexpected index type ", dtype); + return at::clamp(*index, -size, size - 1); + }; + + torch::List> clamped_indices(indices); + std::transform(indices.begin(), indices.end(), self.sizes().begin(), clamped_indices.begin(), clamp); + + if (self.numel() == 0) { + // Returns a tensor filled with `fill` value + // We use a hack here since we do not have a method to get the + // correct size of the tensor. (except with meta impl which is + // not available on mobile builds) + std::vector new_sizes(self.dim()); + auto compute_new_size = [](const c10::optional& index, auto size) -> int64_t { + if (index && size == 0) { + return 1; + } else { + return size; + } + }; + std::transform(indices.begin(), indices.end(), self.sizes().begin(), new_sizes.begin(), compute_new_size); + auto result = self.new_full(new_sizes, fill); + return at::_unsafe_index(result, clamped_indices); + } + + auto result = at::_unsafe_index(self, clamped_indices); + return result.masked_fill(at::logical_not(mask), fill); +} + +Tensor _unsafe_masked_index_put_accumulate(const Tensor& self, const Tensor& mask, const torch::List>& indices, const Tensor& values) { + // This is the backward of _unsafe_masked_index. + // This function is not meant to be executed on eager mode. + + if (self.numel() == 0) { + return self.clone(); + } + + // We recompute the clamped indices and rely on inductor to CSE the computation + auto clamp = [](const c10::optional& index, auto size) -> c10::optional { + if (!index) { + return index; + } + // Disallow bool + auto dtype = index->scalar_type(); + TORCH_CHECK(dtype == kLong || dtype == kInt, + "_unsafe_masked_index found unexpected index type ", dtype); + return at::clamp(*index, -size, size - 1); + }; + + torch::List> clamped_indices(indices); + std::transform(indices.begin(), indices.end(), self.sizes().begin(), clamped_indices.begin(), clamp); + + auto masked_value = values.masked_fill(at::logical_not(mask), 0); + return at::_unsafe_index_put(self, clamped_indices, masked_value, true); +} + Tensor & put_(Tensor & self, const Tensor& index, const Tensor & source, const bool accumulate) { // See note [Writing Nondeterministic Operations] // Nondeterministic when index contains duplicate entries and we do not accumulate diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index a051f43e87eb..5c28397a07d5 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -3061,6 +3061,18 @@ dispatch: CompositeExplicitAutograd: _unsafe_index +# Used by inductor to generate masked loads +# Note that we don't support boolean indexing, to avoid dynamic output shapes +- func: _unsafe_masked_index(Tensor self, Tensor mask, Tensor?[] indices, Scalar fill) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: _unsafe_masked_index + +- func: _unsafe_masked_index_put_accumulate(Tensor self, Tensor mask, Tensor?[] indices, Tensor values) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: _unsafe_masked_index_put_accumulate + - func: index_copy.out(Tensor self, int dim, Tensor index, Tensor source, *, Tensor(a!) out) -> Tensor(a!) structured: True variants: function diff --git a/test/distributed/_tensor/test_dtensor_ops.py b/test/distributed/_tensor/test_dtensor_ops.py index 22a56118b212..ef1ccc754c6c 100644 --- a/test/distributed/_tensor/test_dtensor_ops.py +++ b/test/distributed/_tensor/test_dtensor_ops.py @@ -478,6 +478,8 @@ def wrapped(fn): xfail("unique"), xfail("unsafe_split"), xfail("unsafe_chunk"), + xfail("_unsafe_masked_index"), + xfail("_unsafe_masked_index_put_accumulate"), xfail("var_mean"), xfail("var_mean", "unbiased"), xfail("vdot"), diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 111d0e1ef959..09a82b9d06e2 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -1430,6 +1430,33 @@ def flip(x): actual = _run_and_assert_no_indirect_indexing(self, flip_opt, x) self.assertEqual(expect, actual) + def test__unsafe_masked_index(self): + def fn(a, mask, idx): + return aten._unsafe_masked_index(a, mask, idx, 1) + + self.common( + fn, + ( + torch.randn(8, device=self.device), + torch.tensor([True, False, True], device=self.device), + [torch.tensor([3, 9, -2], device=self.device)], + ), + ) + + def test__unsafe_masked_index_put_accumulate(self): + def fn(a, mask, idx, values): + return aten._unsafe_masked_index_put_accumulate(a, mask, idx, values) + + self.common( + fn, + ( + torch.randn(8, device=self.device), + torch.tensor([True, False, True], device=self.device), + [torch.tensor([3, 9, -2], device=self.device)], + torch.randn(3, device=self.device), + ), + ) + def test_sum1(self): def fn(a, b): return ((a + b).sum(-1),) @@ -10940,7 +10967,10 @@ def fn(x, n): fn_opt = torch.compile(fn) code = run_and_get_triton_code(fn_opt, x, 8) # load should be masked - self.assertTrue("tl.load(in_ptr0 + (tmp0), xmask" in code) + self.assertTrue( + "tl.load(in_ptr0 + (tmp0), xmask" in code + or "tl.load(in_ptr0 + (tmp0), (xmask).to(tl.int1)" in code + ) self.assertEqual(fn(x, 8), fn_opt(x, 8)) def test_kernel_names_descriptive(self): diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index b66c0ce0832f..22bec1756487 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -382,6 +382,8 @@ def wrapper_noop_set_seed(op, *args, **kwargs): }, ("std_mean.unbiased", "cuda", f16): {"reference_in_float": True}, ("uniform", "cuda"): {"reference_in_float": True}, + ("_unsafe_masked_index_put_accumulate", "cuda", f16): {"atol": 1e-4, "rtol": 0.01}, + ("_unsafe_masked_index_put_accumulate", "cpu", f16): {"atol": 1e-4, "rtol": 0.01}, # Following tests are failing with strict comparision but atol=1 is acceptable due roundings errors ("nn.functional.interpolate.bilinear", "cpu", u8): {"atol": 1, "rtol": 0}, ("nn.functional.upsample_bilinear", "cpu", u8): {"atol": 1, "rtol": 0}, diff --git a/test/onnx/test_fx_op_consistency.py b/test/onnx/test_fx_op_consistency.py index 4a4171699e65..08230bbc0099 100644 --- a/test/onnx/test_fx_op_consistency.py +++ b/test/onnx/test_fx_op_consistency.py @@ -170,6 +170,11 @@ def skip_torchlib_forward_compatibility( dtypes=(torch.float16,), reason="fixme: Assertion error: result mismatch", ), + xfail( + "_unsafe_masked_index", + dtypes=onnx_test_common.BOOL_TYPES, + reason=onnx_test_common.reason_onnx_runtime_does_not_support("Where", "bool"), + ), xfail( "add", dtypes=onnx_test_common.BOOL_TYPES, reason=onnx_test_common.reason_onnx_does_not_support("Add") diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py index d3231df35313..d3ee06d889a8 100644 --- a/test/test_fx_experimental.py +++ b/test/test_fx_experimental.py @@ -1,5 +1,6 @@ # Owner(s): ["module: fx"] +import functools import math import numbers import operator @@ -51,6 +52,7 @@ from torch.testing._internal.common_nn import module_tests, new_module_tests from torch.testing._internal.common_utils import TEST_Z3, run_tests, TestCase from torch.testing._internal.jit_utils import JitTestCase +import torch.utils._pytree as pytree try: import torchvision.models @@ -1623,21 +1625,40 @@ def jit_infer_type(v): param_names = [] param_values = [] fx_args = [] - for idx, v in enumerate(arg_values): - if isinstance(v, torch.Tensor): - param_names.append(f"arg_{idx}") - param_values.append(v) - fx_args.append(param_names[-1]) + + idx = 0 + + def process_arg(arg, name): + if isinstance(arg, torch.Tensor): + param_names.append(name) + param_values.append(arg) + return name + else: + return f"{repr(arg)}" + + def process_arg_with_idx(arg): + nonlocal idx + res = process_arg(arg, f"arg_{idx}") + idx = idx + 1 + return res + + def str_arg(arg): + if isinstance(arg, tuple): + args = [f"{str_arg(v)}, " for v in arg] + return f"({' '.join(args)})" + elif isinstance(arg, list): + args = [f"{str_arg(v)}" for v in arg] + return f"[{', '.join(args)}]" else: - fx_args.append(f"{repr(v)}") + return arg + + for v in arg_values: + arg = pytree.tree_map(process_arg_with_idx, v) + fx_args.append(str_arg(arg)) for k, v in kwarg_values.items(): - if isinstance(v, torch.Tensor): - param_names.append(k) - param_values.append(v) - fx_args.append(f"{k} = {k}") - else: - fx_args.append(f"{k} = {repr(v)}") + arg = pytree.tree_map(functools.partial(process_arg, name=k), v) + fx_args.append(f"{k} = {str_arg(arg)}") code = f""" class TestModule(torch.nn.Module): diff --git a/test/test_mps.py b/test/test_mps.py index 8c3bbf4b7bcf..8b3ae97ff218 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -144,6 +144,10 @@ def mps_ops_grad_modifier(ops): # round not working properly for float16 'round': [torch.float16], + + # atomic operation in backward pass + '_unsafe_masked_index': [torch.float16], + '_unsafe_masked_index_put_accumulate': [torch.float16], } MACOS_12_3_XFAILLIST_GRAD = { @@ -351,6 +355,7 @@ def mps_ops_modifier(ops): '__rdiv__', '__rmatmul__', '_chunk_cat', + '_unsafe_masked_index', 'acos', 'acosh', 'all', @@ -905,6 +910,9 @@ def mps_ops_modifier(ops): # round not working properly for float16 'round': [torch.float16], + + # atomic operations not supported + '_unsafe_masked_index_put_accumulate': [torch.bool, torch.int8, torch.uint8, torch.float16, torch.int16, torch.int64], } if product_version < 14.0: diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 4922513f295d..cb50be54feb5 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -830,6 +830,17 @@ self: at::_unsafe_index_put(grad.new_zeros_symint(self.sym_sizes(), self.options()), indices, grad, true) result: auto_linear +- name: _unsafe_masked_index(Tensor self, Tensor mask, Tensor?[] indices, Scalar fill) -> Tensor + self: at::_unsafe_masked_index_put_accumulate(grad.new_zeros_symint(self.sym_sizes(), self.options()), mask, indices, grad) + mask: non_differentiable + result: _unsafe_masked_index(self_t, mask, indices, 0) + +- name: _unsafe_masked_index_put_accumulate(Tensor self, Tensor mask, Tensor?[] indices, Tensor values) -> Tensor + self: grad + mask: non_differentiable + values: at::_unsafe_masked_index(grad, mask, indices, 0) + result: at::_unsafe_masked_index_put_accumulate(self_t, mask, indices, values_t) + - name: index_add(Tensor self, int dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor self: grad # The case source.dim() == 0 is necessary to support scalar tensors of the form diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index b9651ea2da80..86ea05bc6103 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -375,6 +375,10 @@ "linalg_lu_solve", "_linalg_slogdet", "_linalg_solve_ex", + "_unsafe_index", + "_unsafe_index_put", + "_unsafe_masked_index", + "_unsafe_masked_index_put_accumulate", } GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX = { diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py index b277bb7eceb0..5a45b6a37a14 100644 --- a/torch/_decomp/__init__.py +++ b/torch/_decomp/__init__.py @@ -445,6 +445,8 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]: aten.unfold_backward, aten.unfold_copy, aten._unsafe_index, + aten._unsafe_masked_index, + aten._unsafe_masked_index_put_accumulate, aten.unsafe_split.Tensor, aten.unsafe_split_with_sizes, aten._unsafe_view, diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index 5bec539db06c..fc3ba94c806e 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -8,6 +8,7 @@ from typing import Any, Callable, cast, Iterable, List, Optional, Tuple, Union import torch +import torch._meta_registrations import torch._prims as prims import torch._prims_common as utils import torch.nn.functional as F @@ -3712,10 +3713,62 @@ def _reshape_alias(x, shape, *args): @register_decomposition([aten._unsafe_index]) -def _index(x, indices): +def _unsafe_index(x, indices): return aten.index(x, indices) +@register_decomposition([aten._unsafe_masked_index]) +def _unsafe_masked_index(x, mask, indices, fill): + for index in indices: + if index is not None: + torch._check( + index.dtype in [torch.long, torch.int], + lambda: "tensors used as indices must be long or int tensors", + ) + + torch._check( + mask.dtype == torch.bool, + lambda: "tensors used as masks must be bool tensors", + ) + + if x.numel() == 0: + meta_result = torch._meta_registrations.meta_index_Tensor(x, indices) + return x.new_full(meta_result.shape, fill) + + for i in range(len(indices)): + index = indices[i] + if index is not None: + indices[i] = index.clamp(min=0, max=x.size(i) - 1) + + return aten._unsafe_index(x, indices).masked_fill(~mask, fill) + + +@register_decomposition([aten._unsafe_masked_index_put_accumulate]) +def _unsafe_masked_index_put_accumulate(x, mask, indices, values): + for index in indices: + if index is not None: + torch._check( + index.dtype in [torch.long, torch.int], + lambda: "tensors used as indices must be long or int tensors", + ) + + torch._check( + mask.dtype == torch.bool, + lambda: "tensors used as masks must be bool tensors", + ) + + if x.numel() == 0: + return x.clone() + + for i in range(len(indices)): + index = indices[i] + if index is not None: + indices[i] = index.clamp(min=-x.size(i), max=x.size(i) - 1) + + masked_value = values.masked_fill(~mask, 0) + return aten._unsafe_index_put(x, indices, masked_value, accumulate=True) + + def _nll_loss_forward( self: Tensor, target: Tensor, diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 90f0667fecfb..1f7b70c29325 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -1575,6 +1575,8 @@ "torch._unpack_dual", "torch._unsafe_index_put", "torch._unsafe_index", + "torch._unsafe_masked_index_put_accumulate", + "torch._unsafe_masked_index", "torch._use_cudnn_ctc_loss", "torch._use_cudnn_rnn_flatten_weight", "torch._values_copy", diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 4b0ea92f3bf4..953689e668e9 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -871,6 +871,11 @@ def index_expr(cls, expr, dtype): @staticmethod def masked(mask, body, other): + if mask is not None and torch.version.hip is not None: + mask = V.kernel.cse.generate( + V.kernel.compute, + f"{mask}.to(tl.int1)", + ) with V.kernel.mask_loads(mask) as new_mask: result = body() diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index 960c3a42e1f1..c1269e3703c8 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -87,6 +87,8 @@ # the Inductor decomp table. decomps_to_exclude = [ aten._unsafe_index, + aten._unsafe_masked_index, + aten._unsafe_masked_index_put_accumulate, aten._scaled_dot_product_flash_attention_for_cpu.default, # See comments in torch/_decomp/decompositions.py aten._softmax_backward_data, aten.clamp_max, diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 20b0082eb1d9..d8e679c67adb 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -2927,6 +2927,17 @@ def fn(idx): def index_impl(x, indices, check): + output_size, inner_fn = index_impl_helper(x, indices, check) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=inner_fn, + ranges=output_size, + ) + + +def index_impl_helper(x, indices, check): assert isinstance(indices, (list, tuple)) x_loader = x.make_loader() indices, tensor_indices = check_and_broadcast_indices(indices, x.get_device()) @@ -2941,11 +2952,11 @@ def index_impl(x, indices, check): x_size = x.get_size() indexed_size = [x_size[i] for i in range(len(indices)) if indices[i] is not None] - if 0 in indexed_size and 0 not in tensor_size: + if check and 0 in indexed_size and 0 not in tensor_size: raise IndexError("index is out of bounds for dimension with size 0") indexed_size = [x_size[i] for i in range(len(indices))] - output_size, inner_fn = index_output_size_and_inner_fn( + return index_output_size_and_inner_fn( x_size, indices, tensor_indices, @@ -2956,13 +2967,6 @@ def index_impl(x, indices, check): check=check, ) - return Pointwise.create( - device=x.get_device(), - dtype=x.get_dtype(), - inner_fn=inner_fn, - ranges=output_size, - ) - @register_lowering(aten.index, type_promotion_kind=None) def index(x, indices): @@ -3159,6 +3163,54 @@ def load_source_val(): return view(result_flat, self.get_size()) +fallback__unsafe_masked_index = fallback_handler( + aten._unsafe_masked_index.default, add_to_fallback_set=False +) + +fallback__unsafe_masked_index_put_accumulate = fallback_handler( + aten._unsafe_masked_index_put_accumulate.default, add_to_fallback_set=False +) + + +@register_lowering(aten._unsafe_masked_index, type_promotion_kind=None) +def _unsafe_masked_index(self, mask, indices, fill): + ranges, _unsafe_index_fn = index_impl_helper(self, indices, check=False) + mask_loader = mask.make_loader() + + def inner_fn(idx): + mask_val = ops.to_dtype(mask_loader(idx), torch.bool) + return ops.masked(mask_val, lambda: _unsafe_index_fn(idx), fill) + + return Pointwise.create( + device=self.get_device(), + dtype=self.get_dtype(), + inner_fn=inner_fn, + ranges=ranges, + ) + + +@register_lowering(aten._unsafe_masked_index_put_accumulate, type_promotion_kind=None) +def _unsafe_masked_index_put_accumulate(x, mask, indices, values): + if torch.version.hip is not None: + # Avoid a triton compiler failure + return fallback__unsafe_masked_index_put_accumulate(x, mask, indices, values) + + masked_value = where(mask, values, 0) + shape = x.get_size() + clamped_indices = [ + clamp(indices[i], -shape[i], shape[i] - 1) if indices[i] else None + for i in range(len(indices)) + ] + # TODO: use a masked store for this. currently only triton + # supports masked stores and cpp backend does not. + return _unsafe_index_put(x, clamped_indices, masked_value, accumulate=True) + + +@make_pointwise +def clamp(a, min, max): + return ops.maximum(min, ops.minimum(max, a)) + + @register_lowering(aten.as_strided_scatter, type_promotion_kind=None) def as_strided_scatter(self, src, size, stride, storage_offset=None): output = clone(self) diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 0915a8330c34..5c41df4406c6 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -647,6 +647,7 @@ def get_first_incompatible_cudagraph_node(gm): forbidden_set.update( { "aten._unsafe_index_put.default", + "aten._unsafe_masked_index_put_accumulate.default", "aten.index_put.default", "aten.index_put_.default", "aten.scatter.src", diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 50cfac763be5..fe58b99e34a9 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -5267,6 +5267,64 @@ def make_idx(n, m): args=(0, idx, src, reduce), kwargs={'include_self': True}) +def sample_inputs__unsafe_masked_index(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + def make_idx(n, m, dim, d): + view_shape = [1] * dim + view_shape[d] = n + return make_tensor((n,), device=device, dtype=torch.int64, low=0, high=m).view(view_shape) + + cases = [ + ((S, S), S, M), + ((S, S), M, S), + ((S, S, S), S, M), + ] + + fill_value = make_tensor([], dtype=dtype, device="cpu").item() + + for c in cases: + self_shape, high, idx_size = c + dim = len(self_shape) + indices = [make_idx(idx_size, high, dim, d) for d in range(dim)] + masks = [torch.logical_and(idx >= 0, idx < self_shape[i]) for i, idx in enumerate(indices) if idx is not None] + mask = functools.reduce(torch.logical_and, masks) + yield SampleInput(make_arg(self_shape), mask, indices, fill_value) + + masks = [torch.logical_and(idx >= 1, idx < self_shape[i] - 1) for i, idx in enumerate(indices) if idx is not None] + mask = functools.reduce(torch.logical_and, masks) + yield SampleInput(make_arg(self_shape), mask, indices, fill_value) + +def sample_inputs__unsafe_masked_index_put_accumulate(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + def make_idx(n, m, dim, d): + view_shape = [1] * dim + view_shape[d] = n + return make_tensor((n,), device=device, dtype=torch.int64, low=0, high=m).view(view_shape) + + cases = [ + ((S, S), S, (M, M)), + ((S, S), M, (S, S + 1)), + ((S, S, S), S, (M, M - 1, M + 1)), + ] + + fill_value = make_tensor([], dtype=dtype, device="cpu").item() + + for c in cases: + self_shape, high, idx_sizes = c + dim = len(self_shape) + indices = [make_idx(idx_sizes[d], high, dim, d) for d in range(dim)] + masks = [torch.logical_and(idx >= 0, idx < self_shape[i]) for i, idx in enumerate(indices) if idx is not None] + mask = functools.reduce(torch.logical_and, masks) + values = make_arg(idx_sizes) + yield SampleInput(make_arg(self_shape), mask, indices, values) + + masks = [torch.logical_and(idx >= 1, idx < self_shape[i] - 1) for i, idx in enumerate(indices) if idx is not None] + mask = functools.reduce(torch.logical_and, masks) + yield SampleInput(make_arg(self_shape), mask, indices, values) + + def sample_inputs_mode(op_info, device, dtype, requires_grad, **kwargs): args = ( ((S, S, S), (),), @@ -18009,9 +18067,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1): supports_fwgrad_bwgrad=True, # See https://github.com/pytorch/pytorch/pull/78358 check_batched_forward_grad=False, - skips=( - # lambda impl - DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),), sample_inputs_func=sample_inputs_column_stack,), OpInfo('pinverse', op=torch.pinverse, @@ -18102,6 +18157,22 @@ def reference_flatten(input, start_dim=0, end_dim=-1): supports_out=True, sample_inputs_func=sample_inputs_index_reduce, ) for reduction_type in ('mean', 'prod', 'amin', 'amax')), + OpInfo('_unsafe_masked_index', + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16, torch.bool), + supports_out=False, + supports_inplace_autograd=False, + supports_scripting=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs__unsafe_masked_index), + OpInfo('_unsafe_masked_index_put_accumulate', + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16, torch.bool), + supports_out=False, + supports_inplace_autograd=False, + supports_scripting=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs__unsafe_masked_index_put_accumulate), OpInfo('__getitem__', dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), # Runs very slowly on slow gradcheck - alternatively reduce input sizes @@ -18128,7 +18199,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1): test_neg_view=False, sample_inputs_func=sample_inputs_index_put, skips=( - DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), DecorateInfo(unittest.skip("Skipped"), 'TestBwdGradients', 'test_fn_grad', dtypes=[torch.float64], device_type='cuda', active_if=(TEST_WITH_ROCM and TEST_WITH_TORCHINDUCTOR)), )), @@ -19028,7 +19098,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1): skips=( # Not implemented on CUDA DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_errors', device_type='cuda'), - DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), # JIT tests don't work with Tensor keyword arguments # https://github.com/pytorch/pytorch/issues/58507 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), From d05cddfe2327a92a62fba2220d5d3f735e58d40d Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 3 Jun 2024 15:00:21 +0000 Subject: [PATCH 254/706] Revert "FP8 rowwise scaling (#125204)" This reverts commit 923edef31c7f3e98a14625724f2019b1422dcb26. Reverted https://github.com/pytorch/pytorch/pull/125204 on behalf of https://github.com/atalman due to Broke nightlies and internal tests ([comment](https://github.com/pytorch/pytorch/pull/125204#issuecomment-2145422196)) --- aten/src/ATen/CMakeLists.txt | 1 - aten/src/ATen/cuda/detail/LazyNVRTC.cpp | 37 -- aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h | 15 +- aten/src/ATen/native/cuda/Blas.cpp | 113 +--- aten/src/ATen/native/cuda/RowwiseScaledMM.cu | 535 ------------------- aten/src/ATen/native/cuda/RowwiseScaledMM.h | 15 - test/test_matmul_cuda.py | 149 +----- third_party/cutlass.BUILD | 14 +- 8 files changed, 25 insertions(+), 854 deletions(-) delete mode 100644 aten/src/ATen/native/cuda/RowwiseScaledMM.cu delete mode 100644 aten/src/ATen/native/cuda/RowwiseScaledMM.h diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 696621eeca6f..9fa7a1f2305b 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -472,7 +472,6 @@ endif() if(USE_CUDA AND NOT USE_ROCM) list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include) - list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/tools/util/include) if($ENV{ATEN_STATIC_CUDA}) list(APPEND ATen_CUDA_DEPENDENCY_LIBS ${CUDA_LIBRARIES} diff --git a/aten/src/ATen/cuda/detail/LazyNVRTC.cpp b/aten/src/ATen/cuda/detail/LazyNVRTC.cpp index 75c503d48d51..1b85e7776e22 100644 --- a/aten/src/ATen/cuda/detail/LazyNVRTC.cpp +++ b/aten/src/ATen/cuda/detail/LazyNVRTC.cpp @@ -170,43 +170,6 @@ CUDA_STUB3(cuLinkComplete, CUlinkState, void **, size_t *); CUDA_STUB3(cuFuncSetAttribute, CUfunction, CUfunction_attribute, int); CUDA_STUB3(cuFuncGetAttribute, int*, CUfunction_attribute, CUfunction); -#if defined(CUDA_VERSION) && CUDA_VERSION >= 12000 -CUresult CUDAAPI -cuTensorMapEncodeTiled( - CUtensorMap* tensorMap, - CUtensorMapDataType tensorDataType, - cuuint32_t tensorRank, - void* globalAddress, - const cuuint64_t* globalDim, - const cuuint64_t* globalStrides, - const cuuint32_t* boxDim, - const cuuint32_t* elementStrides, - CUtensorMapInterleave interleave, - CUtensorMapSwizzle swizzle, - CUtensorMapL2promotion l2Promotion, - CUtensorMapFloatOOBfill oobFill) { - auto fn = reinterpret_cast( - getCUDALibrary().sym(__func__)); - if (!fn) - throw std::runtime_error("Can't get cuTensorMapEncodeTiled"); - lazyNVRTC.cuTensorMapEncodeTiled = fn; - return fn( - tensorMap, - tensorDataType, - tensorRank, - globalAddress, - globalDim, - globalStrides, - boxDim, - elementStrides, - interleave, - swizzle, - l2Promotion, - oobFill); -} - -#endif - // Irregularly shaped functions CUresult CUDAAPI cuLaunchKernel(CUfunction f, unsigned int gridDimX, diff --git a/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h b/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h index cb34d10db254..574b2c41c264 100644 --- a/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h +++ b/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h @@ -59,25 +59,16 @@ namespace at { namespace cuda { _(cuLinkAddData) \ _(cuLinkComplete) \ _(cuFuncSetAttribute) \ - _(cuFuncGetAttribute) \ - -#if defined(CUDA_VERSION) && CUDA_VERSION >= 12000 -#define AT_FORALL_NVRTC_EXTENDED(_) \ - AT_FORALL_NVRTC_BASE(_) \ - _(cuTensorMapEncodeTiled) -#else -#define AT_FORALL_NVRTC_EXTENDED(_) \ - AT_FORALL_NVRTC_BASE(_) -#endif + _(cuFuncGetAttribute) #if defined(CUDA_VERSION) && CUDA_VERSION >= 11010 #define AT_FORALL_NVRTC(_) \ - AT_FORALL_NVRTC_EXTENDED(_) \ + AT_FORALL_NVRTC_BASE(_) \ _(nvrtcGetCUBINSize) \ _(nvrtcGetCUBIN) #else #define AT_FORALL_NVRTC(_) \ - AT_FORALL_NVRTC_EXTENDED(_) + AT_FORALL_NVRTC_BASE(_) #endif #else diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index ed59b47349cc..84c59a4fd0d7 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -1,7 +1,3 @@ -#include -#include -#include -#include #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include #include @@ -14,7 +10,6 @@ #include #include #include -#include #ifndef AT_PER_OPERATOR_HEADERS #include @@ -824,97 +819,24 @@ static bool _scaled_mm_allowed_device() { #endif } -namespace{ - -enum class ScalingType { - TensorWise, - RowWise, - Error -}; - -// Validates the scale tensors to scaled_mm -// And returns the type of scaling/which kernel to use -ScalingType get_scaling_type( - const c10::optional& scale_a, - const c10::optional& scale_b, - int64_t dim_m, - int64_t dim_n) { - TORCH_CHECK( - scale_a.has_value() == scale_b.has_value(), - "Both scale_a and scale_b must be present or absent."); - - if (scale_a.has_value()) { - // Both Per-Tensor and Row-wise scaling expect fp32 tensors - TORCH_CHECK( - scale_a->scalar_type() == kFloat && scale_b->scalar_type() == kFloat, - "Both scale_a and scale_b must be float (fp32) tensors."); - - // Check the singluar scale case for per-tensor scaling - if (scale_a->numel() == 1 && scale_b->numel() == 1) { - return ScalingType::TensorWise; - } else if (scale_a->dim() == 1 && scale_a->size(0) == dim_m) { -// Check the per-row scaling case -#if !defined(USE_ROCM) && !defined(_MSC_VER) || \ - (defined(USE_ROCM) && ROCM_VERSION >= 60000) - TORCH_CHECK( - scale_a->dim() == 1 && scale_b->dim() == 1, - "Both scale_a and scale_b must be 1-dimensional tensors"); - TORCH_CHECK( - scale_b->size(0) == dim_n, - "For row-wise scaling, scale_b must have size ", - dim_n, - " but got ", - scale_b->size(0), - "."); - TORCH_CHECK( - scale_a->is_contiguous() && scale_b->is_contiguous(), - "Both scale_a and scale_b must be contiguous."); - return ScalingType::RowWise; -#else - TORCH_CHECK(false, "Per-row scaling is not supported for this platform!"); - return ScalingType::Error; -#endif // !defined(USE_ROCM) && !defined(_MSC_VER) || (defined(USE_ROCM) && - // ROCM_VERSION >= 60000) - } else { - TORCH_CHECK( - false, - "For row-wise scaling, scale_a must be size ", - dim_m, - " but got ", - scale_a->numel(), - " and scale_b must be size ", - dim_n, - " but got ", - scale_b->numel(), - "."); - // Unreachable - return ScalingType::RowWise; - } - } - return ScalingType::Error; -} - -} // namespace - // Computes matrix multiply + bias while applying scaling to input and output matrices and computes amax // Scales are only applicable when matrices are of Float8 type and assumbed to be equal to 1.0 by default. // If output matrix type is 16 or 32-bit type, neither scale_result is applied nor amax is computed. // Known limitations: // - Only works if mat1 is row-major and mat2 is column-major // - Only works if matrices sizes are divisible by 32 -// - If 1-dimensional tensors are used then scale_a should be size = mat1.size(0) -// and scale_b should have size = to mat2.size(1) +// // Arguments: // - `mat1`: the first operand of the matrix multiply, can be type `torch.float8_e4m3fn` or `torch.float8_e5m2` // - `mat2`: the second operand of the matrix multiply, can be type `torch.float8_e4m3fn` or `torch.float8_e5m2` // - `bias`: the bias, can be type `torch.float16` or `torch.bfloat16` // - `out_dtype`: the output dtype, can either be a float8 or a higher precision floating point type -// - `scale_a`: a scalar or 1-dimensional tensor with the inverse scale of `mat1`, only needed if `mat1` is a float8 type -// - `scale_b`: a scalar or 1-dimensional tensor with the inverse scale of `mat2`, only needed if `mat2` is a float8 type -// - `scale_result`: a scalar tensor with the scale of the output, only utilized if the output is a float8 type +// - `scale_a`: a scalar tensor with the inverse scale of `mat1`, only needed if `mat1` is a float8 type +// - `scale_b`: a scalar tensor with the inverse scale of `mat2`, only needed if `mat2` is a float8 type +// - `scale_result`: a scalar tensor with the scale of the output, only set if the output is a float8 type // - `use_fast_accum`: if true, enables fast float8 accumulation // - `out`: a reference to the output tensor -// - `amax`: a reference to the amax tensor of the output, only mutated if the output is a float8 type and will be updated inplace +// - `amax`: a reference to the amax tensor of the output, only needed if the output is a float8 type and will be updated inplace std::tuple _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, @@ -933,11 +855,10 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, TORCH_CHECK( mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (", mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")"); - - // Check what type of scaling we are doing based on inputs - ScalingType scaling_choice = get_scaling_type(scale_a, scale_b, mat1.size(0), mat2.size(1)); - TORCH_INTERNAL_ASSERT(scaling_choice != ScalingType::Error, "Scaling type not supported"); - + TORCH_CHECK(!scale_a || (scale_a->numel() == 1 && scale_a->scalar_type() == kFloat), + "scale_a must be float scalar"); + TORCH_CHECK(!scale_b || (scale_b->numel() == 1 && scale_b->scalar_type() == kFloat), + "scale_b must be a float scalar"); TORCH_CHECK(!scale_result || (scale_result->numel() == 1 && scale_result->scalar_type() == kFloat), "scale_result must be a float scalar"); TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1], @@ -980,26 +901,12 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, {scale_result_, "scale_result", 7}}; checkAllSameGPU(__func__, targs); } - // Validation checks have passed lets resize the output to actual size + IntArrayRef mat1_sizes = mat1.sizes(); IntArrayRef mat2_sizes = mat2.sizes(); at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]}); at::native::resize_output(amax, {}); - // We are doing row-wise scaling - if (scaling_choice == ScalingType::RowWise) { - TORCH_CHECK(out.dtype() == kBFloat16, "Only bf16 high precsion output types are supported for row-wise scaling."); - at::cuda::detail::f8f8bf16_rowwise( - mat1, - mat2, - scale_a.value(), - scale_b.value(), - bias, - use_fast_accum, - out); - return {out, amax}; - } - cublasCommonArgs args(mat1, mat2, out); const auto out_dtype_ = args.result->scalar_type(); TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt"); diff --git a/aten/src/ATen/native/cuda/RowwiseScaledMM.cu b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu deleted file mode 100644 index 14eb8f5fbf80..000000000000 --- a/aten/src/ATen/native/cuda/RowwiseScaledMM.cu +++ /dev/null @@ -1,535 +0,0 @@ -#define TORCH_ASSERT_ONLY_METHOD_OPERATORS -#include -#include -#include -#include - -// Determine if the architecture supports rowwise scaled mm -#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12000 - -#define BUILD_ROWWISE_FP8_KERNEL -#endif - -#if defined(BUILD_ROWWISE_FP8_KERNEL) - -// We are going to override the cuTensorMapEncodeTiled driver api with our lazy loader -static CUresult CUDAAPI nvrtc_cuTensorMapEncodeTiled( - CUtensorMap* tensorMap, - CUtensorMapDataType tensorDataType, - cuuint32_t tensorRank, - void* globalAddress, - const cuuint64_t* globalDim, - const cuuint64_t* globalStrides, - const cuuint32_t* boxDim, - const cuuint32_t* elementStrides, - CUtensorMapInterleave interleave, - CUtensorMapSwizzle swizzle, - CUtensorMapL2promotion l2Promotion, - CUtensorMapFloatOOBfill oobFill) { - return at::globalContext().getNVRTC().cuTensorMapEncodeTiled( - tensorMap, - tensorDataType, - tensorRank, - globalAddress, - globalDim, - globalStrides, - boxDim, - elementStrides, - interleave, - swizzle, - l2Promotion, - oobFill); -} - - -#include -#include -#include -#include -#include -#include -#include - -// Rename the global function symbol -#define cuTensorMapEncodeTiled nvrtc_cuTensorMapEncodeTiled -#include -#undef cuTensorMapEncodeTiled -// Set everything back to normal - -#include -#include -#include - -#include -#include -#include -#include - - -namespace { -// Cutlass rowwise kernel -template < - int TB_M, - int TB_N, - int TB_K, - int TBS_M, - int TBS_N, - int TBS_K, - bool PONG, - bool FAST_ACCUM, - bool USE_BIAS, - typename INPUT_DTYPE, - typename BIAS_DTYPE> -void f8f8bf16_rowwise_impl( - at::Tensor XQ, // FP8 - at::Tensor WQ, // FP8 - at::Tensor x_scale, - at::Tensor w_scale, - c10::optional bias, - at::Tensor out) { - int M = XQ.size(0); - int N = WQ.size(1); - int K = XQ.size(1); - - TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous()); - TORCH_CHECK( - WQ.is_cuda() && WQ.ndimension() == 2 && WQ.stride(1) == WQ.size(0) && - WQ.stride(0) == 1); - - // auto Y = at::empty({M, N}, XQ.options().dtype(at::kBFloat16)); - - using ElementInputA = INPUT_DTYPE; - using LayoutInputA = cutlass::layout::RowMajor; - constexpr int AlignmentInputA = 16 / sizeof(ElementInputA); - - using ElementInputB = cutlass::float_e4m3_t; - using LayoutInputB = cutlass::layout::ColumnMajor; - constexpr int AlignmentInputB = 16 / sizeof(ElementInputB); - - using ElementBias = BIAS_DTYPE; - - using ElementOutput = cutlass::bfloat16_t; - using LayoutOutput = cutlass::layout::RowMajor; - constexpr int AlignmentOutput = 16 / sizeof(ElementOutput); - - using ElementAccumulator = float; - using ElementComputeEpilogue = float; - using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that - // supports the intended feature - using OperatorClass = cutlass::arch::OpClassTensorOp; - using TileShape = cute::Shape< - cute::Int, - cute::Int, - cute::Int>; // Threadblock-level - // tile size - using ClusterShape = cute::Shape< - cute::Int, - cute::Int, - cute::Int>; // Shape of the - // threadblocks in a - // cluster - using KernelSchedule = cutlass::gemm::collective:: - KernelScheduleAuto; // Kernel to launch based on the default setting in - // the Collective Builder - - // Implement rowwise scaling epilogue. - using XScale = cutlass::epilogue::fusion::Sm90ColBroadcast< - 0, - TileShape, - ElementComputeEpilogue, - cute::Stride, cute::Int<0>, cute::Int<0>>>; - - using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast< - PONG ? 2 : 1, - TileShape, - ElementComputeEpilogue, - cute::Stride, cute::Int<1>, cute::Int<0>>>; - - using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast< - PONG ? 2 : 1, - TileShape, - ElementBias, - cute::Stride, cute::Int<1>, cute::Int<0>>>; - - using Accum = cutlass::epilogue::fusion::Sm90AccFetch; - - using Compute0 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, - ElementComputeEpilogue, // First stage output type. - ElementComputeEpilogue, // First stage input types. - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTCompute0 = - cutlass::epilogue::fusion::Sm90EVT; - - using Compute1 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, - cute::conditional_t< // Second stage output type. - USE_BIAS, - ElementBias, - ElementOutput>, - ElementComputeEpilogue, // Second stage input types. - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTCompute1 = - cutlass::epilogue::fusion::Sm90EVT; - - using ComputeBias = cutlass::epilogue::fusion::Sm90Compute< - cutlass::plus, - ElementOutput, // Final (optional) stage output type. - ElementBias, // Final stage input types. - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeBias = - cutlass::epilogue::fusion::Sm90EVT; - - using EpilogueEVT = - cute::conditional_t; - - using CollectiveEpilogue = - typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm90, - cutlass::arch::OpClassTensorOp, - TileShape, - ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, - ElementAccumulator, - ElementComputeEpilogue, - ElementOutput, - LayoutOutput, - AlignmentOutput, - ElementOutput, - LayoutOutput, - AlignmentOutput, - cutlass::epilogue::TmaWarpSpecialized, - EpilogueEVT>::CollectiveOp; - - using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecialized; - using PongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; - using FastDefaultSchedule = - cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; - using FastPongSchedule = - cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; - using SlowAccum = cute::conditional_t; - using FastAccum = - cute::conditional_t; - using MainLoopSchedule = - cute::conditional_t; - - using CollectiveMainloop = - typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, - OperatorClass, - ElementInputA, - LayoutInputA, - AlignmentInputA, - ElementInputB, - LayoutInputB, - AlignmentInputB, - ElementAccumulator, - TileShape, - ClusterShape, - cutlass::gemm::collective::StageCountAutoCarveout( - sizeof(typename CollectiveEpilogue::SharedStorage))>, - MainLoopSchedule>::CollectiveOp; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - cute::Shape, - CollectiveMainloop, - CollectiveEpilogue>; - - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - - using StrideInputA = typename Gemm::GemmKernel::StrideA; - using StrideInputB = typename Gemm::GemmKernel::StrideB; - using StrideOutput = typename Gemm::GemmKernel::StrideC; - - StrideInputA stride_a = cutlass::make_cute_packed_stride( - StrideInputA{}, cute::make_shape(M, K, 1)); - StrideInputB stride_b = cutlass::make_cute_packed_stride( - StrideInputB{}, cute::make_shape(N, K, 1)); - StrideOutput stride_output = cutlass::make_cute_packed_stride( - StrideOutput{}, cute::make_shape(M, N, 1)); - - typename Gemm::Arguments arguments{ - cutlass::gemm::GemmUniversalMode::kGemm, - {M, N, K}, - {reinterpret_cast(XQ.data_ptr()), - stride_a, - reinterpret_cast(WQ.data_ptr()), - stride_b}, - {{}, // Epilogue thread we populate below. - (ElementOutput*)out.data_ptr(), - stride_output, - (ElementOutput*)out.data_ptr(), - stride_output}}; - - if constexpr (USE_BIAS) { - arguments.epilogue.thread = { - {reinterpret_cast(bias.value().data_ptr())}, // bias - // compute_1 - { - {reinterpret_cast( - x_scale.data_ptr())}, // x_scale - // compute_0 - { - {reinterpret_cast( - w_scale.data_ptr())}, // w_scale - {}, // Accumulator - {} // Multiplies - }, - {}, // Multiplies - }, - {}, // Plus - }; - } else { - arguments.epilogue.thread = { - {reinterpret_cast( - x_scale.data_ptr())}, // x_scale - // compute_0 - { - {reinterpret_cast( - w_scale.data_ptr())}, // w_scale - {}, // Accumulator - {} // Multiplies - }, - {}, // Multiplies - }; - } - - Gemm gemm; - - // Using the arguments, query for extra workspace required for matrix - // multiplication computation - size_t workspace_size = Gemm::get_workspace_size(arguments); - - // Allocate workspace memory - cutlass::device_memory::allocation workspace(workspace_size); - - // Check the problem size is supported or not - cutlass::Status status = gemm.can_implement(arguments); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot implement"); - } - - // Initialize CUTLASS kernel with arguments and workspace pointer - status = gemm.initialize(arguments, workspace.get()); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot initialize"); - } - - status = gemm(at::cuda::getCurrentCUDAStream()); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error( - std::string("cutlass cannot run") + - cutlass::cutlassGetStatusString(status)); - } - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -// FP8 Rowwise Cutlass kernel dispatch. -enum class KernelMode { Small, Large, Default }; - -KernelMode get_kernel_mode(at::Tensor XQ, at::Tensor WQ) { - auto M = XQ.size(0); - auto K = XQ.size(1); - auto N = WQ.size(0); - // Use a large kernel if at least two shapes are large.... - bool use_large_kernel = - ((M >= 2048 && K >= 2048) || (M >= 2048 && N >= 2048) || - (K >= 2048 && N >= 2048)); - if (M <= 128 || N <= 128) { - return KernelMode::Small; - } else if (use_large_kernel) { - return KernelMode::Large; - } else { - return KernelMode::Default; - } -} - -template -void dispatch_fp8_rowwise_kernel( - at::Tensor XQ, - at::Tensor WQ, - at::Tensor x_scale, - at::Tensor w_scale, - c10::optional bias, - at::Tensor out) { - KernelMode kernel = get_kernel_mode(XQ, WQ); - if (kernel == KernelMode::Small) { - return f8f8bf16_rowwise_impl< - 64, - 128, - 128, - 2, - 1, - 1, - false, - FastAccum, - UseBias, - InputDType, - BiasDType>(XQ, WQ, x_scale, w_scale, bias, out); - } else if (kernel == KernelMode::Large) { - return f8f8bf16_rowwise_impl< - 128, - 128, - 128, - 2, - 1, - 1, - true, - FastAccum, - UseBias, - InputDType, - BiasDType>(XQ, WQ, x_scale, w_scale, bias, out); - } else { - return f8f8bf16_rowwise_impl< - 128, - 128, - 128, - 1, - 2, - 1, - false, - FastAccum, - UseBias, - InputDType, - BiasDType>(XQ, WQ, x_scale, w_scale, bias, out); - } -} - -} // namespace - -#endif // !defined(USE_ROCM) - -namespace at::cuda::detail { -void f8f8bf16_rowwise( - at::Tensor XQ, // FP8 - at::Tensor WQ, // FP8 - at::Tensor x_scale, // FP32 - at::Tensor w_scale, // FP32 - c10::optional bias, // BF16 - bool use_fast_accum, - at::Tensor& out) { -#if defined(BUILD_ROWWISE_FP8_KERNEL) - // Check datatypes. - TORCH_CHECK( - x_scale.dtype() == at::kFloat && w_scale.dtype() == at::kFloat, - "Scale tensors must be float32."); - if (bias.has_value()) { - TORCH_CHECK( - bias.value().dtype() == at::kFloat || - bias.value().dtype() == at::kBFloat16, - "Bias type must be bfloat16 or float32 if provided."); - } - // Extract problem size. - int M = XQ.size(0); - int N = WQ.size(1); - int K = XQ.size(1); - - bool use_bias = bias.has_value(); - bool bf16_bias = use_bias && bias.value().dtype() == at::kBFloat16; - - // Templatize based on input dtype. - bool use_e5m2 = XQ.dtype() == at::kFloat8_e5m2; - TORCH_CHECK(WQ.dtype() == at::kFloat8_e4m3fn, "For row-wise scaling the second input is required to be a float8_e4m3fn dtype."); - - if (use_bias) { - if (bf16_bias) { - if (use_fast_accum) { - if (use_e5m2) { - return dispatch_fp8_rowwise_kernel< - cutlass::float_e5m2_t, - true, - true, - cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, out); - } else { - return dispatch_fp8_rowwise_kernel< - cutlass::float_e4m3_t, - true, - true, - cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, out); - } - } else { - if (use_e5m2) { - return dispatch_fp8_rowwise_kernel< - cutlass::float_e5m2_t, - false, - true, - cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, out); - } else { - return dispatch_fp8_rowwise_kernel< - cutlass::float_e4m3_t, - false, - true, - cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, out); - } - } - } else { - if (use_fast_accum) { - if (use_e5m2) { - return dispatch_fp8_rowwise_kernel< - cutlass::float_e5m2_t, - true, - true, - float>(XQ, WQ, x_scale, w_scale, bias, out); - } else { - return dispatch_fp8_rowwise_kernel< - cutlass::float_e4m3_t, - true, - true, - float>(XQ, WQ, x_scale, w_scale, bias, out); - } - } else { - if (use_e5m2) { - return dispatch_fp8_rowwise_kernel< - cutlass::float_e5m2_t, - false, - true, - float>(XQ, WQ, x_scale, w_scale, bias, out); - } else { - return dispatch_fp8_rowwise_kernel< - cutlass::float_e4m3_t, - false, - true, - float>(XQ, WQ, x_scale, w_scale, bias, out); - } - } - } - } else { - if (use_fast_accum) { - if (use_e5m2) { - return dispatch_fp8_rowwise_kernel< - cutlass::float_e5m2_t, - true, - false, - float>(XQ, WQ, x_scale, w_scale, bias, out); - } else { - return dispatch_fp8_rowwise_kernel< - cutlass::float_e4m3_t, - true, - false, - float>(XQ, WQ, x_scale, w_scale, bias, out); - } - } else { - if (use_e5m2) { - return dispatch_fp8_rowwise_kernel< - cutlass::float_e5m2_t, - false, - false, - float>(XQ, WQ, x_scale, w_scale, bias, out); - } else { - return dispatch_fp8_rowwise_kernel< - cutlass::float_e4m3_t, - false, - false, - float>(XQ, WQ, x_scale, w_scale, bias, out); - } - } - } -#else // BUILD_ROWWISE_FP8_KERNEL - TORCH_CHECK(false, "Rowwise scaling is not currenlty supported on your device"); -#endif -} - -} // namespace at::cuda::detail diff --git a/aten/src/ATen/native/cuda/RowwiseScaledMM.h b/aten/src/ATen/native/cuda/RowwiseScaledMM.h deleted file mode 100644 index 4d9054108c85..000000000000 --- a/aten/src/ATen/native/cuda/RowwiseScaledMM.h +++ /dev/null @@ -1,15 +0,0 @@ -#pragma once -#include -#include - - -namespace at::cuda::detail { -TORCH_API void f8f8bf16_rowwise( - at::Tensor XQ, // FP8 - at::Tensor WQ, // FP8 - at::Tensor x_scale, // FP32 - at::Tensor w_scale, // FP32 - c10::optional bias, // BF16 - bool use_fast_accum, - at::Tensor& out); -} // at::cuda::detail diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 7793e7411e88..a5c583580848 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -204,6 +204,7 @@ def _expand_to_batch(t: torch.Tensor): self.assertEqual(out1_gpu, out2_gpu[0]) + f8_msg = "FP8 is only supported on H100+ and sm_89 and MI300+ devices" if torch.version.hip: @@ -255,12 +256,8 @@ def amax_to_scale( scale.copy_(res) return scale -def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype, dim=None): - if dim is None: - amax = torch.max(torch.abs(x)) - else: - amax = torch.max(torch.abs(x), dim=dim).values - +def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype): + amax = torch.max(torch.abs(x)) return amax_to_scale(amax, float8_dtype, x.dtype) def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype): @@ -319,6 +316,7 @@ def mm_float8( def to_fp8_saturated( x: torch.Tensor, + x_scale: torch.tensor, fp8_dtype: torch.dtype ): """ @@ -341,6 +339,8 @@ def to_fp8_saturated( of a tensor has a maximum value of `amax1`, and the current amax value is `amax2`, where `amax1 < amax2`. """ + x_scaled = x * x_scale + if fp8_dtype == e4m3_type: x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS) elif fp8_dtype == e5m2_type: @@ -353,6 +353,8 @@ def to_fp8_saturated( @unittest.skipIf(not torch.cuda.is_available(), "CUDA not found") class TestFP8MatmulCuda(TestCase): + + @unittest.skipIf(not scaled_mm_supported_device(), f8_msg) def _test_tautological_mm(self, device: str = "cuda", x_dtype: torch.dtype = e4m3_type, @@ -416,8 +418,8 @@ def test_scaled_mm_vs_emulated(self, base_dtype): x_scale = tensor_to_scale(x, input_dtype).float() y_scale = tensor_to_scale(y, input_dtype).float() - x_fp8 = to_fp8_saturated(x * x_scale, e4m3_type) - y_fp8 = to_fp8_saturated(y * y_scale, e4m3_type) + x_fp8 = to_fp8_saturated(x, x_scale, e4m3_type) + y_fp8 = to_fp8_saturated(y, y_scale, e4m3_type) # Calculate actual F8 mm out_scaled_mm, output_amax_scaled = mm_float8( @@ -524,137 +526,6 @@ def test_float8_scale_fast_accum(self, device) -> None: out_fp8_s, amax_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, use_fast_accum=True) self.assertEqual(out_fp8, out_fp8_s) - @unittest.skipIf(not scaled_mm_supported_device(), f8_msg) - @skipIfRocm() - @parametrize("use_fast_accum", [True, False]) - def test_float8_rowwise_scaling_sanity(self, device, use_fast_accum: bool) -> None: - M, K, N = (1024, 512, 2048) - fill_value = 0.5 - x = torch.full((M, K), fill_value, device=device) - y = torch.full((N, K), fill_value, device=device) - - x_scales = torch.ones(x.shape[0], device=device, dtype=torch.float32) - y_scales = torch.ones(y.shape[0], device=device, dtype=torch.float32) - - x_fp8 = x.to(torch.float8_e4m3fn) - y_fp8 = y.to(torch.float8_e4m3fn).t() - - out_fp8, _ = torch._scaled_mm( - x_fp8, - y_fp8, - scale_a=x_scales, - scale_b=y_scales, - out_dtype=torch.bfloat16, - use_fast_accum=use_fast_accum, - ) - self.assertEqual( - out_fp8.to(torch.float32), torch.full((M, N), K * (fill_value**2), device=device) - ) - - @unittest.skipIf(not scaled_mm_supported_device(), f8_msg) - @skipIfRocm() - def test_float8_error_messages(self, device) -> None: - M, K, N = (1024, 512, 2048) - fill_value = 0.5 - x = torch.full((M, K), fill_value, device=device) - y = torch.full((N, K), fill_value, device=device) - - x_fp8 = x.to(torch.float8_e4m3fn) - y_fp8 = y.to(torch.float8_e4m3fn).t() - - with self.assertRaisesRegex( - RuntimeError, - "For row-wise scaling, scale_a must be size 1024 but got 1 and scale_b must be size 2048 but got 2", - ): - torch._scaled_mm( - x_fp8, - y_fp8, - scale_a=torch.ones((), device="cuda"), - scale_b=torch.ones((2), device="cuda"), - out_dtype=torch.bfloat16, - ) - - with self.assertRaisesRegex( - RuntimeError, - "For row-wise scaling, scale_b must have size 2048 but got 2049.", - ): - torch._scaled_mm( - x_fp8, - y_fp8, - scale_a=torch.ones((M), device="cuda"), - scale_b=torch.ones((N + 1), device="cuda"), - out_dtype=torch.bfloat16, - ) - with self.assertRaisesRegex( - RuntimeError, - "Both scale_a and scale_b must be 1-dimensional tensors", - ): - torch._scaled_mm( - x_fp8, - y_fp8, - scale_a=torch.ones((M), device="cuda"), - scale_b=torch.ones((N, N), device="cuda"), - out_dtype=torch.bfloat16, - ) - - with self.assertRaisesRegex( - RuntimeError, - "Both scale_a and scale_b must be contiguous.", - ): - torch._scaled_mm( - x_fp8, - y_fp8, - scale_a=torch.ones((M), device="cuda"), - scale_b=torch.ones((N * 2), device="cuda")[::2], - out_dtype=torch.bfloat16, - ) - - with self.assertRaisesRegex( - RuntimeError, - "For row-wise scaling the second input is required to be a float8_e4m3fn dtype.", - ): - torch._scaled_mm( - x_fp8, - y_fp8.to(torch.float8_e5m2), - scale_a=torch.ones((M), device="cuda"), - scale_b=torch.ones((N), device="cuda"), - out_dtype=torch.bfloat16, - ) - - @unittest.skipIf(not scaled_mm_supported_device(), f8_msg) - @skipIfRocm() - @parametrize("base_dtype", [torch.bfloat16]) - def test_scaled_mm_vs_emulated_row_wise(self, base_dtype): - torch.manual_seed(42) - input_dtype = e4m3_type - output_dtype = base_dtype - - x = torch.randn(16, 16, device="cuda", dtype=base_dtype) - y = torch.randn(32, 16, device="cuda", dtype=base_dtype).t() - - x_scales = tensor_to_scale(x, input_dtype, dim=1).float() - y_scales = tensor_to_scale(y, input_dtype, dim=0).float() - - x_fp8 = to_fp8_saturated(x * x_scales[:, None], e4m3_type) - y_fp8 = to_fp8_saturated(y * y_scales[None, :], e4m3_type) - - # Calculate actual F8 mm - out_scaled_mm, _ = mm_float8( - x_fp8, y_fp8, a_scale=x_scales, b_scale=y_scales, output_dtype=output_dtype - ) - - # Calculate emulated F8 mm - out_emulated, _ = mm_float8_emulated( - x_fp8, x_scales[:, None], y_fp8, y_scales[None, :], output_dtype - ) - - if base_dtype in {torch.bfloat16, torch.float16}: - atol, rtol = 7e-2, 7e-2 - else: - atol, rtol = 2e-3, 2e-3 - - torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol) - @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") @unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions") diff --git a/third_party/cutlass.BUILD b/third_party/cutlass.BUILD index e3e7b7b288e7..e712d59597cc 100644 --- a/third_party/cutlass.BUILD +++ b/third_party/cutlass.BUILD @@ -5,17 +5,7 @@ load("@rules_cc//cc:defs.bzl", "cc_library") cc_library( name = "cutlass", - hdrs = glob([ - "include/**/*.h", - "include/**/*.hpp", - "include/**/*.inl", - "tools/util/include/**/*.h", - "tools/util/include/**/*.hpp", - "tools/util/include/**/*.inl", - ]), - includes = [ - "include/", - "tools/util/include/", - ], + hdrs = glob(["include/**/*.h", "include/**/*.hpp"]), + includes = ["include/"], visibility = ["//visibility:public"], ) From 3f8b8f08c8adc3d509ea49fcdb6f8558f96244b8 Mon Sep 17 00:00:00 2001 From: PaliC Date: Fri, 31 May 2024 15:06:53 -0700 Subject: [PATCH 255/706] [Split Build] Make libtorch_global_deps accessible from libtorch wheel (#127570) Title Pull Request resolved: https://github.com/pytorch/pytorch/pull/127570 Approved by: https://github.com/atalman, https://github.com/malfet --- torch/__init__.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/torch/__init__.py b/torch/__init__.py index a2492c40a949..18f1752019ec 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -211,17 +211,18 @@ def load_shared_libraries(library_path): if _running_with_deploy() or platform.system() == 'Windows': return - split_build_lib_name = LIBTORCH_PKG_NAME - library_path = find_package_path(split_build_lib_name) - if library_path: - load_shared_libraries(library_path) lib_name = 'libtorch_global_deps' + ('.dylib' if platform.system() == 'Darwin' else '.so') here = os.path.abspath(__file__) - lib_path = os.path.join(os.path.dirname(here), 'lib', lib_name) + global_deps_lib_path = os.path.join(os.path.dirname(here), 'lib', lib_name) + + split_build_lib_name = LIBTORCH_PKG_NAME + library_path = find_package_path(split_build_lib_name) + if library_path: + global_deps_lib_path = os.path.join(library_path, 'lib', lib_name) try: - ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL) + ctypes.CDLL(global_deps_lib_path, mode=ctypes.RTLD_GLOBAL) except OSError as err: # Can only happen for wheel with cuda libs as PYPI deps # As PyTorch is not purelib, but nvidia-*-cu12 is @@ -243,8 +244,11 @@ def load_shared_libraries(library_path): raise err for lib_folder, lib_name in cuda_libs.items(): _preload_cuda_deps(lib_folder, lib_name) - ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL) + ctypes.CDLL(global_deps_lib_path, mode=ctypes.RTLD_GLOBAL) + if library_path: + # loading libtorch_global_deps first due its special logic + load_shared_libraries(library_path) if (USE_RTLD_GLOBAL_WITH_LIBTORCH or os.getenv('TORCH_USE_RTLD_GLOBAL')) and \ (_running_with_deploy() or platform.system() != 'Windows'): From 63d7ffe121207b3dd3d4a7d7ce86fabe16cf4eb1 Mon Sep 17 00:00:00 2001 From: James Wu Date: Mon, 3 Jun 2024 15:29:41 +0000 Subject: [PATCH 256/706] Retry of D58015187 Move AsyncCompile to a different file (#127691) Summary: This is a retry of https://github.com/pytorch/pytorch/pull/127545/files and D58015187, fixing the internal test that also imported codecache Test Plan: Same tests as CI in github, plus sandcastle for internal unit tests should pass now Differential Revision: D58054611 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127691 Approved by: https://github.com/oulgen --- test/inductor/test_codecache.py | 2 +- test/inductor/test_cudacodecache.py | 3 +- test/inductor/test_halide.py | 1 + test/inductor/test_kernel_benchmark.py | 1 + test/inductor/test_triton_wrapper.py | 1 + torch/_inductor/async_compile.py | 239 +++++++++++++++++++ torch/_inductor/autotune_process.py | 1 + torch/_inductor/codecache.py | 207 +--------------- torch/_inductor/codegen/cpp_wrapper_cpu.py | 3 +- torch/_inductor/codegen/wrapper.py | 4 +- torch/_inductor/compile_fx.py | 2 + torch/_inductor/compile_worker/__main__.py | 2 +- torch/_inductor/runtime/triton_heuristics.py | 1 - torch/_inductor/scheduler.py | 1 + torch/_inductor/select_algorithm.py | 1 + torch/testing/_internal/inductor_utils.py | 2 +- 16 files changed, 257 insertions(+), 214 deletions(-) create mode 100644 torch/_inductor/async_compile.py diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py index e7f619a9cb36..1330d635f8db 100644 --- a/test/inductor/test_codecache.py +++ b/test/inductor/test_codecache.py @@ -12,8 +12,8 @@ from torch._dynamo import reset from torch._dynamo.utils import counters from torch._inductor import config, metrics +from torch._inductor.async_compile import AsyncCompile from torch._inductor.codecache import ( - AsyncCompile, cuda_compile_command, CUDACodeCache, FxGraphCachePickler, diff --git a/test/inductor/test_cudacodecache.py b/test/inductor/test_cudacodecache.py index 33a179a9abc7..ac26f6a6656c 100644 --- a/test/inductor/test_cudacodecache.py +++ b/test/inductor/test_cudacodecache.py @@ -6,7 +6,8 @@ import torch from torch._inductor import config -from torch._inductor.codecache import AsyncCompile, CUDACodeCache +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.codecache import CUDACodeCache from torch._inductor.codegen.cuda.cuda_env import nvcc_exist from torch._inductor.exc import CUDACompileError from torch._inductor.test_case import TestCase as InductorTestCase diff --git a/test/inductor/test_halide.py b/test/inductor/test_halide.py index 52227c20d1ff..9b923bd1981d 100644 --- a/test/inductor/test_halide.py +++ b/test/inductor/test_halide.py @@ -3,6 +3,7 @@ import unittest import torch +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools from torch._inductor.codecache import HalideCodeCache from torch._inductor.runtime.hints import HalideInputSpec, HalideMeta from torch._inductor.test_case import run_tests, TestCase diff --git a/test/inductor/test_kernel_benchmark.py b/test/inductor/test_kernel_benchmark.py index 23804e08f23f..ffe0300d8aad 100644 --- a/test/inductor/test_kernel_benchmark.py +++ b/test/inductor/test_kernel_benchmark.py @@ -6,6 +6,7 @@ from unittest.mock import patch import torch +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools from torch._dynamo.testing import rand_strided from torch._inductor import config from torch._inductor.codecache import PyCodeCache diff --git a/test/inductor/test_triton_wrapper.py b/test/inductor/test_triton_wrapper.py index 24ba84ebf86a..f0d3ad829d45 100644 --- a/test/inductor/test_triton_wrapper.py +++ b/test/inductor/test_triton_wrapper.py @@ -4,6 +4,7 @@ import sys import torch +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools from torch._inductor.codecache import PyCodeCache from torch._inductor.test_case import run_tests, TestCase from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU diff --git a/torch/_inductor/async_compile.py b/torch/_inductor/async_compile.py new file mode 100644 index 000000000000..c163df9bd878 --- /dev/null +++ b/torch/_inductor/async_compile.py @@ -0,0 +1,239 @@ +from __future__ import annotations + +import functools +import logging +import multiprocessing +import os +import sys +from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor +from functools import partial +from time import time +from typing import Any, Callable, Dict, List, Optional, Set + +import torch +from torch._dynamo.device_interface import get_registered_device_interfaces +from torch._inductor import config +from torch._inductor.codecache import ( + CodeCacheFuture, + CppCodeCache, + CppPythonBindingsCodeCache, + CUDACodeCache, + HalideCodeCache, + LambdaFuture, + TritonCodeCache, + TritonFuture, +) +from torch._inductor.compile_worker.subproc_pool import ( + _warm_process_pool, + AnyPool, + SubprocPool, +) +from torch._inductor.compile_worker.watchdog import _async_compile_initializer + +from torch._inductor.runtime.compile_tasks import ( + _set_triton_ptxas_path, + _worker_compile_triton, +) +from torch._inductor.runtime.hints import HalideMeta + +from torch.hub import _Faketqdm, tqdm + +# timing metrics for time spent in the compilation +_cumulative_compile_time = 0.0 +_t0: Optional[float] = None + +kernel_code_log = torch._logging.getArtifactLogger(__name__, "kernel_code") + + +def caching_device_properties(): + for _, device_interface in get_registered_device_interfaces(): + if device_interface.is_available(): + device_interface.Worker.get_device_properties() + + +def _compile_start() -> None: + global _t0 + if _t0 is None: + _t0 = time() + + +def _compile_end() -> None: + global _cumulative_compile_time, _t0 + if _t0 is not None: + t1 = time() + _cumulative_compile_time += t1 - _t0 + _t0 = None + # print("CUMULATIVE COMPILE TIME", _cumulative_compile_time) + + +_IS_WINDOWS = sys.platform == "win32" + +log = logging.getLogger(__name__) + + +# Used to keep track of all process pools invoked so far. +_pool_set: Set[AnyPool] = set() + + +def shutdown_compile_workers() -> None: + """Shut down all outstanding compile-worker pools.""" + for pool in _pool_set: + pool.shutdown() + after_fork() + + +def after_fork(): + """Reset pools to initial state without shutting them down""" + _pool_set.clear() + AsyncCompile.process_pool.cache_clear() + + +try: + os.register_at_fork(after_in_child=after_fork) +except AttributeError: + pass # register_at_fork does not exists on windows + + +class AsyncCompile: + def __init__(self) -> None: + pass + + @staticmethod + @functools.lru_cache(1) + def pool() -> ThreadPoolExecutor: + assert config.compile_threads > 1 + return ThreadPoolExecutor(config.compile_threads) + + @staticmethod + @functools.lru_cache(1) + def process_pool() -> AnyPool: + assert config.compile_threads > 1 + pool: AnyPool + if config.worker_start_method == "subprocess": + # Wrapper around ProcessPoolExecutor forks in a new process we control + pool = SubprocPool(config.compile_threads) + else: + # ensure properties have been calculated before processes + # are forked + caching_device_properties() + ctx = multiprocessing.get_context(config.worker_start_method) + pool = ProcessPoolExecutor( + config.compile_threads, + mp_context=ctx, + initializer=partial(_async_compile_initializer, os.getpid()), + ) + # when this pool is created in a subprocess object, the normal exit handler + # doesn't run, and we need to register our own handler. + # exitpriority has to be high, because another one of the finalizers will + # kill the worker thread that sends the shutdown message to the workers... + multiprocessing.util.Finalize(None, pool.shutdown, exitpriority=sys.maxsize) + + _pool_set.add(pool) + return pool + + @classmethod + def warm_pool(cls) -> None: + if config.compile_threads <= 1: + return + _compile_start() + _warm_process_pool(cls.process_pool(), config.compile_threads) + _compile_end() + + @classmethod + def submit(cls, task: Callable[..., Any]) -> Any: + if config.compile_threads <= 1: + return task() + return cls.pool().submit(task) + + def triton(self, kernel_name: str, source_code: str, device_str: str = "cuda"): + kernel_code_log.info("Triton Kernel:\n%s", source_code) + _compile_start() + _set_triton_ptxas_path() + + kernel = TritonCodeCache.load(kernel_name, source_code) + if config.compile_threads > 1: + return TritonFuture( + kernel, + self.process_pool().submit( + _worker_compile_triton, + kernel._reload_in_subproc, + ), + ) + else: + kernel.precompile() + return kernel + + def multi_kernel(self, *args, **kwargs) -> Any: + from torch._inductor.codegen.multi_kernel import MultiKernelCall + + # no need to call this in parallel since the sub-kernels are already parallel tasks + return MultiKernelCall(*args, **kwargs) + + def cpp(self, source_code: str): + kernel_code_log.info("CPP Kernel:\n%s", source_code) + if config.compile_threads <= 1: + return CppCodeCache.load(source_code).kernel + else: + get_result = CppCodeCache.load_async(source_code, submit_fn=self.submit) + return LambdaFuture(lambda: get_result().kernel) + + def cpp_pybinding(self, argtypes: List[str], source_code: str): + kernel_code_log.info("CPP+Bindings Kernel:\n%s", source_code) + if config.compile_threads <= 1: + return CppPythonBindingsCodeCache.load_pybinding(argtypes, source_code) + else: + get_result = CppPythonBindingsCodeCache.load_pybinding_async( + argtypes, source_code, submit_fn=self.submit + ) + return LambdaFuture(get_result) + + def cuda(self, source_code, dst_file_ext): + kernel_code_log.info("CUDA Kernel:\n%s", source_code) + + def task(): + return CUDACodeCache.load(source_code, dst_file_ext)[0] + + return self.submit(task) + + def halide(self, meta: HalideMeta, source_code: str): + kernel_code_log.info("Halide Kernel:\n%r\n%s", meta, source_code) + if config.compile_threads <= 1: + return HalideCodeCache.generate_halide(meta, source_code) + else: + get_result = HalideCodeCache.generate_halide_async( + meta, source_code, submit_fn=self.submit + ) + return LambdaFuture(get_result) + + def wait(self, scope: Dict[str, Any]) -> None: + num_kernels = len( + [ + value + for key, value in scope.items() + if isinstance(value, (Future, CodeCacheFuture)) + ] + ) + pbar = tqdm( + total=num_kernels, + desc="Inductor Compilation", + disable=config.disable_progress, + delay=0, + ) + if config.compile_threads > 1: + for key, result in scope.items(): + if config.verbose_progress and not isinstance(pbar, _Faketqdm): + pbar.set_postfix_str(key) + if isinstance(result, (Future, CodeCacheFuture)): + scope[key] = result.result() + pbar.update(1) + + _compile_end() + + +if ( + os.environ.get("TORCH_TNT_IN_USE", "0") == "1" + or os.environ.get("TORCH_WARM_POOL", "1") != "1" +): + pass +else: + AsyncCompile.warm_pool() diff --git a/torch/_inductor/autotune_process.py b/torch/_inductor/autotune_process.py index 7db4d2a3291c..c9462d788e8d 100644 --- a/torch/_inductor/autotune_process.py +++ b/torch/_inductor/autotune_process.py @@ -25,6 +25,7 @@ ) import torch +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools from torch import multiprocessing from torch._dynamo.testing import rand_strided diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 5d648e6b2a98..d338c2665484 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -9,7 +9,6 @@ import io import json import logging -import multiprocessing import os import pickle import pkgutil @@ -26,7 +25,7 @@ import threading import warnings from bisect import bisect_right -from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor +from concurrent.futures import Future from copy import copy from ctypes import c_void_p, cdll, CDLL from functools import partial @@ -49,22 +48,13 @@ ) import torch -from torch._dynamo.device_interface import get_registered_device_interfaces from torch._dynamo.utils import counters, dynamo_timed from torch._inductor import config, exc, metrics from torch._inductor.codegen.cuda import cuda_env -from torch._inductor.compile_worker.subproc_pool import ( - _warm_process_pool, - AnyPool, - SubprocPool, -) -from torch._inductor.compile_worker.watchdog import _async_compile_initializer from torch._inductor.runtime.compile_tasks import ( _module_to_triton_kernel, _reload_python_module, _reload_python_module_in_subproc, - _set_triton_ptxas_path, - _worker_compile_triton, ) from torch._inductor.runtime.hints import HalideMeta from torch._inductor.runtime.runtime_utils import cache_dir @@ -82,7 +72,6 @@ from torch._inductor.graph import GraphLowering from torch._inductor.ir import ChoiceCaller -from torch.hub import _Faketqdm, tqdm _HERE = os.path.abspath(__file__) _TORCH_PATH = os.path.dirname(os.path.dirname(_HERE)) @@ -114,31 +103,11 @@ def use_global_cache() -> bool: output_code_log = torch._logging.getArtifactLogger(__name__, "output_code") -kernel_code_log = torch._logging.getArtifactLogger(__name__, "kernel_code") LOCK_TIMEOUT = 600 _IS_WINDOWS = sys.platform == "win32" -# timing metrics for time spent in the compilation -_cumulative_compile_time = 0.0 -_t0: Optional[float] = None - - -def _compile_start() -> None: - global _t0 - if _t0 is None: - _t0 = time() - - -def _compile_end() -> None: - global _cumulative_compile_time, _t0 - if _t0 is not None: - t1 = time() - _cumulative_compile_time += t1 - _t0 - _t0 = None - # print("CUMULATIVE COMPILE TIME", _cumulative_compile_time) - log = logging.getLogger(__name__) @@ -3227,12 +3196,6 @@ def load(cls, source_code, dst_file_ext) -> Tuple[DLLWrapper, str, str]: return (DLLWrapper(dst_file_path), hash_key, source_code_path) -def caching_device_properties(): - for _, device_interface in get_registered_device_interfaces(): - if device_interface.is_available(): - device_interface.Worker.get_device_properties() - - class CodeCacheFuture: def result(self): raise NotImplementedError @@ -3266,171 +3229,3 @@ def __init__(self, result_fn): def result(self): return self.result_fn() - - -# Used to keep track of all process pools invoked so far. -_pool_set: Set[AnyPool] = set() - - -def shutdown_compile_workers() -> None: - """Shut down all outstanding compile-worker pools.""" - for pool in _pool_set: - pool.shutdown() - after_fork() - - -def after_fork(): - """Reset pools to initial state without shutting them down""" - _pool_set.clear() - AsyncCompile.process_pool.cache_clear() - - -try: - os.register_at_fork(after_in_child=after_fork) -except AttributeError: - pass # register_at_fork does not exists on windows - - -class AsyncCompile: - def __init__(self) -> None: - pass - - @staticmethod - @functools.lru_cache(1) - def pool() -> ThreadPoolExecutor: - assert config.compile_threads > 1 - return ThreadPoolExecutor(config.compile_threads) - - @staticmethod - @functools.lru_cache(1) - def process_pool() -> AnyPool: - assert config.compile_threads > 1 - pool: AnyPool - if config.worker_start_method == "subprocess": - # Wrapper around ProcessPoolExecutor forks in a new process we control - pool = SubprocPool(config.compile_threads) - else: - # ensure properties have been calculated before processes - # are forked - caching_device_properties() - ctx = multiprocessing.get_context(config.worker_start_method) - pool = ProcessPoolExecutor( - config.compile_threads, - mp_context=ctx, - initializer=partial(_async_compile_initializer, os.getpid()), - ) - # when this pool is created in a subprocess object, the normal exit handler - # doesn't run, and we need to register our own handler. - # exitpriority has to be high, because another one of the finalizers will - # kill the worker thread that sends the shutdown message to the workers... - multiprocessing.util.Finalize(None, pool.shutdown, exitpriority=sys.maxsize) - - _pool_set.add(pool) - return pool - - @classmethod - def warm_pool(cls) -> None: - if config.compile_threads <= 1: - return - _compile_start() - _warm_process_pool(cls.process_pool(), config.compile_threads) - _compile_end() - - @classmethod - def submit(cls, task: Callable[..., Any]) -> Any: - if config.compile_threads <= 1: - return task() - return cls.pool().submit(task) - - def triton(self, kernel_name: str, source_code: str, device_str: str = "cuda"): - kernel_code_log.info("Triton Kernel:\n%s", source_code) - _compile_start() - _set_triton_ptxas_path() - - kernel = TritonCodeCache.load(kernel_name, source_code) - if config.compile_threads > 1: - return TritonFuture( - kernel, - self.process_pool().submit( - _worker_compile_triton, - kernel._reload_in_subproc, - ), - ) - else: - kernel.precompile() - return kernel - - def multi_kernel(self, *args, **kwargs) -> Any: - from torch._inductor.codegen.multi_kernel import MultiKernelCall - - # no need to call this in parallel since the sub-kernels are already parallel tasks - return MultiKernelCall(*args, **kwargs) - - def cpp(self, source_code: str): - kernel_code_log.info("CPP Kernel:\n%s", source_code) - if config.compile_threads <= 1: - return CppCodeCache.load(source_code).kernel - else: - get_result = CppCodeCache.load_async(source_code, submit_fn=self.submit) - return LambdaFuture(lambda: get_result().kernel) - - def cpp_pybinding(self, argtypes: List[str], source_code: str): - kernel_code_log.info("CPP+Bindings Kernel:\n%s", source_code) - if config.compile_threads <= 1: - return CppPythonBindingsCodeCache.load_pybinding(argtypes, source_code) - else: - get_result = CppPythonBindingsCodeCache.load_pybinding_async( - argtypes, source_code, submit_fn=self.submit - ) - return LambdaFuture(get_result) - - def cuda(self, source_code, dst_file_ext): - kernel_code_log.info("CUDA Kernel:\n%s", source_code) - - def task(): - return CUDACodeCache.load(source_code, dst_file_ext)[0] - - return self.submit(task) - - def halide(self, meta: HalideMeta, source_code: str): - kernel_code_log.info("Halide Kernel:\n%r\n%s", meta, source_code) - if config.compile_threads <= 1: - return HalideCodeCache.generate_halide(meta, source_code) - else: - get_result = HalideCodeCache.generate_halide_async( - meta, source_code, submit_fn=self.submit - ) - return LambdaFuture(get_result) - - def wait(self, scope: Dict[str, Any]) -> None: - num_kernels = len( - [ - value - for key, value in scope.items() - if isinstance(value, (Future, CodeCacheFuture)) - ] - ) - pbar = tqdm( - total=num_kernels, - desc="Inductor Compilation", - disable=config.disable_progress, - delay=0, - ) - if config.compile_threads > 1: - for key, result in scope.items(): - if config.verbose_progress and not isinstance(pbar, _Faketqdm): - pbar.set_postfix_str(key) - if isinstance(result, (Future, CodeCacheFuture)): - scope[key] = result.result() - pbar.update(1) - - _compile_end() - - -if ( - os.environ.get("TORCH_TNT_IN_USE", "0") == "1" - or os.environ.get("TORCH_WARM_POOL", "1") != "1" -): - pass -else: - AsyncCompile.warm_pool() diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index d9c499de34ee..f33c5fb3136e 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -9,10 +9,11 @@ from sympy import Expr import torch + +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools import torch._ops from torch.fx.experimental.symbolic_shapes import ConvertIntKey, DivideByKey from .. import config, ir - from ..codecache import CudaKernelParamCache from ..utils import cache_on_self, sympy_product from ..virtualized import V diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index f1028e9068d6..0bf4814f80b1 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -33,7 +33,7 @@ from torch.utils._sympy.singleton_int import SingletonInt from torch.utils._sympy.symbol import symbol_is_type, SymT -from .. import codecache, config, ir +from .. import async_compile, config, ir from ..ir import ReinterpretView from ..runtime import triton_heuristics from ..runtime.hints import DeviceProperties @@ -501,7 +501,7 @@ def write_header(self) -> None: from torch._inductor.codegen.memory_planning import _align as align from torch import device, empty_strided - from {codecache.__name__} import AsyncCompile + from {async_compile.__name__} import AsyncCompile from torch._inductor.select_algorithm import extern_kernels from torch._inductor.codegen.multi_kernel import MultiKernelCall diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index 60db8b7ee465..c6eddbed19fe 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -11,6 +11,8 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union from unittest import mock +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools + import torch.fx import torch.utils._pytree as pytree diff --git a/torch/_inductor/compile_worker/__main__.py b/torch/_inductor/compile_worker/__main__.py index 6cd1d1e600ac..e478a5345675 100644 --- a/torch/_inductor/compile_worker/__main__.py +++ b/torch/_inductor/compile_worker/__main__.py @@ -3,7 +3,7 @@ import sys import typing -from torch._inductor.codecache import caching_device_properties +from torch._inductor.async_compile import caching_device_properties from torch._inductor.compile_worker.subproc_pool import Pipe, SubprocMain from torch._inductor.compile_worker.watchdog import _async_compile_initializer from torch._inductor.runtime.compile_tasks import _set_triton_ptxas_path diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 078ff472461d..6629e0fe5e77 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -748,7 +748,6 @@ def save_cuda_kernel(self, grid, stream, launcher): # User defined triton kernels will have arbitrary kwarg names "meta": launcher.config.kwargs, } - from torch._inductor.codecache import CudaKernelParamCache binary = ( diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index b8cf50bd1ba8..a7517575d888 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -28,6 +28,7 @@ import sympy import torch +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools from torch._dynamo.utils import counters, dynamo_timed from torch._inductor.metrics import get_metric_table, is_metric_table_enabled from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 5b24a002082c..bc89441e3bd8 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -22,6 +22,7 @@ from filelock import FileLock import torch +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools from torch._dynamo.testing import rand_strided from torch._dynamo.utils import counters, identity, preserve_rng_state diff --git a/torch/testing/_internal/inductor_utils.py b/torch/testing/_internal/inductor_utils.py index e8db1e394b96..1078a189f69c 100644 --- a/torch/testing/_internal/inductor_utils.py +++ b/torch/testing/_internal/inductor_utils.py @@ -5,7 +5,7 @@ import unittest import functools from subprocess import CalledProcessError - +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools from torch._inductor.codecache import CppCodeCache from torch.utils._triton import has_triton from torch.testing._internal.common_utils import ( From badf898df239a2146e3c3561e6c41aed0866624f Mon Sep 17 00:00:00 2001 From: Zain Rizvi Date: Mon, 3 Jun 2024 15:30:04 +0000 Subject: [PATCH 257/706] Remove unstable ARC jobs (#127563) Disable these jobs since we're no longer trying to enable ARC Pull Request resolved: https://github.com/pytorch/pytorch/pull/127563 Approved by: https://github.com/huydhn --- .github/workflows/unstable.yml | 171 --------------------------------- 1 file changed, 171 deletions(-) diff --git a/.github/workflows/unstable.yml b/.github/workflows/unstable.yml index ac1d49d1cce5..a2c4a45bd8b5 100644 --- a/.github/workflows/unstable.yml +++ b/.github/workflows/unstable.yml @@ -32,174 +32,3 @@ jobs: echo echo "Once the jobs are deemed stable enough (% red signal < 5% and TTS < 3h)," echo " they can graduate and move back to pull or trunk." - - # - # Experimental ARC jobs - # - llm-td: - name: before-test - uses: ./.github/workflows/llm_td_retrieval.yml - permissions: - id-token: write - contents: read - - target-determination: - name: before-test - uses: ./.github/workflows/target_determination.yml - needs: llm-td - permissions: - id-token: write - contents: read - - linux-jammy-py3_8-gcc11-build: - name: linux-jammy-py3.8-gcc11 - uses: ./.github/workflows/_linux-build-rg.yml - with: - build-environment: linux-jammy-py3.8-gcc11 - docker-image-name: pytorch-linux-jammy-py3.8-gcc11 - test-matrix: | - { include: [ - { config: "default", shard: 1, num_shards: 3, runner: "arc-lf-linux.2xlarge.avx512" }, - { config: "default", shard: 2, num_shards: 3, runner: "arc-lf-linux.2xlarge.avx512" }, - { config: "default", shard: 3, num_shards: 3, runner: "arc-lf-linux.2xlarge.avx512" }, - { config: "docs_test", shard: 1, num_shards: 1, runner: "arc-lf-linux.2xlarge.avx512" }, - { config: "jit_legacy", shard: 1, num_shards: 1, runner: "arc-lf-linux.2xlarge.avx512" }, - { config: "backwards_compat", shard: 1, num_shards: 1, runner: "arc-lf-linux.2xlarge.avx512" }, - { config: "distributed", shard: 1, num_shards: 2, runner: "arc-lf-linux.2xlarge.avx512" }, - { config: "distributed", shard: 2, num_shards: 2, runner: "arc-lf-linux.2xlarge.avx512" }, - ]} - - linux-jammy-py3_8-gcc11-test: - name: linux-jammy-py3.8-gcc11 - uses: ./.github/workflows/_linux-test-rg.yml - needs: - - linux-jammy-py3_8-gcc11-build - - target-determination - with: - build-environment: linux-jammy-py3.8-gcc11 - docker-image: ${{ needs.linux-jammy-py3_8-gcc11-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-py3_8-gcc11-build.outputs.test-matrix }} - - linux-jammy-py3_8-gcc11-no-ops: - name: linux-jammy-py3.8-gcc11-no-ops - uses: ./.github/workflows/_linux-build-rg.yml - with: - build-environment: linux-jammy-py3.8-gcc11-no-ops - docker-image-name: pytorch-linux-jammy-py3.8-gcc11 - test-matrix: | - { include: [ - { config: "default", shard: 1, num_shards: 1 }, - ]} - - linux-jammy-py3_8-gcc11-pch: - name: linux-jammy-py3.8-gcc11-pch - uses: ./.github/workflows/_linux-build-rg.yml - with: - build-environment: linux-jammy-py3.8-gcc11-pch - docker-image-name: pytorch-linux-jammy-py3.8-gcc11 - test-matrix: | - { include: [ - { config: "default", shard: 1, num_shards: 1 }, - ]} - - linux-focal-py3_8-clang10-onnx-build: - name: linux-focal-py3.8-clang10-onnx - uses: ./.github/workflows/_linux-build-rg.yml - with: - build-environment: linux-focal-py3.8-clang10-onnx - docker-image-name: pytorch-linux-focal-py3-clang10-onnx - test-matrix: | - { include: [ - { config: "default", shard: 1, num_shards: 2, runner: "arc-lf-linux.2xlarge.avx512" }, - { config: "default", shard: 2, num_shards: 2, runner: "arc-lf-linux.2xlarge.avx512" }, - ]} - - linux-focal-py3_8-clang10-onnx-test: - name: linux-focal-py3.8-clang10-onnx - uses: ./.github/workflows/_linux-test-rg.yml - needs: - - linux-focal-py3_8-clang10-onnx-build - - target-determination - with: - build-environment: linux-focal-py3.8-clang10-onnx - docker-image: ${{ needs.linux-focal-py3_8-clang10-onnx-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-py3_8-clang10-onnx-build.outputs.test-matrix }} - - linux-jammy-py3_10-clang15-asan-build: - name: linux-jammy-py3.10-clang15-asan - uses: ./.github/workflows/_linux-build-rg.yml - with: - build-environment: linux-jammy-py3.10-clang15-asan - docker-image-name: pytorch-linux-jammy-py3-clang15-asan - test-matrix: | - { include: [ - { config: "default", shard: 1, num_shards: 6, runner: "linux.4xlarge" }, - { config: "default", shard: 2, num_shards: 6, runner: "linux.4xlarge" }, - { config: "default", shard: 3, num_shards: 6, runner: "linux.4xlarge" }, - { config: "default", shard: 4, num_shards: 6, runner: "linux.4xlarge" }, - { config: "default", shard: 5, num_shards: 6, runner: "linux.4xlarge" }, - { config: "default", shard: 6, num_shards: 6, runner: "linux.4xlarge" }, - ]} - sync-tag: asan-build-arc - - linux-focal-py3_8-clang10-build: - name: linux-focal-py3.8-clang10 - uses: ./.github/workflows/_linux-build-rg.yml - with: - build-environment: linux-focal-py3.8-clang10 - docker-image-name: pytorch-linux-focal-py3.8-clang10 - test-matrix: | - { include: [ - { config: "default", shard: 1, num_shards: 3, runner: "arc-lf-linux.2xlarge.avx512" }, - { config: "default", shard: 2, num_shards: 3, runner: "arc-lf-linux.2xlarge.avx512" }, - { config: "default", shard: 3, num_shards: 3, runner: "arc-lf-linux.2xlarge.avx512" }, - { config: "crossref", shard: 1, num_shards: 2, runner: "arc-lf-linux.2xlarge.avx512" }, - { config: "crossref", shard: 2, num_shards: 2, runner: "arc-lf-linux.2xlarge.avx512" }, - { config: "dynamo", shard: 1, num_shards: 3, runner: "arc-lf-linux.2xlarge.avx512" }, - { config: "dynamo", shard: 2, num_shards: 3, runner: "arc-lf-linux.2xlarge.avx512" }, - { config: "dynamo", shard: 3, num_shards: 3, runner: "arc-lf-linux.2xlarge.avx512" }, - ]} - - linux-focal-py3_8-clang10-test: - name: linux-focal-py3.8-clang10 - uses: ./.github/workflows/_linux-test-rg.yml - needs: - - linux-focal-py3_8-clang10-build - - target-determination - with: - build-environment: linux-focal-py3.8-clang10 - docker-image: ${{ needs.linux-focal-py3_8-clang10-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-py3_8-clang10-build.outputs.test-matrix }} - - linux-focal-py3_11-clang10-build: - name: linux-focal-py3.11-clang10 - uses: ./.github/workflows/_linux-build-rg.yml - with: - build-environment: linux-focal-py3.11-clang10 - docker-image-name: pytorch-linux-focal-py3.11-clang10 - test-matrix: | - { include: [ - { config: "default", shard: 1, num_shards: 3, runner: "arc-lf-linux.2xlarge.avx512" }, - { config: "default", shard: 2, num_shards: 3, runner: "arc-lf-linux.2xlarge.avx512" }, - { config: "default", shard: 3, num_shards: 3, runner: "arc-lf-linux.2xlarge.avx512" }, - { config: "crossref", shard: 1, num_shards: 2, runner: "arc-lf-linux.2xlarge.avx512" }, - { config: "crossref", shard: 2, num_shards: 2, runner: "arc-lf-linux.2xlarge.avx512" }, - { config: "dynamo", shard: 1, num_shards: 3, runner: "arc-lf-linux.2xlarge.avx512" }, - { config: "dynamo", shard: 2, num_shards: 3, runner: "arc-lf-linux.2xlarge.avx512" }, - { config: "dynamo", shard: 3, num_shards: 3, runner: "arc-lf-linux.2xlarge.avx512" }, - ]} - - linux-focal-py3_11-clang10-test: - name: linux-focal-py3.11-clang10 - uses: ./.github/workflows/_linux-test-rg.yml - needs: - - linux-focal-py3_11-clang10-build - - target-determination - with: - build-environment: linux-focal-py3.11-clang10 - docker-image: ${{ needs.linux-focal-py3_11-clang10-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-py3_11-clang10-build.outputs.test-matrix }} - - # - # End of Experimental ARC jobs - # From 430cdfc0acea8f78ee59aca8c85499f7cc4736da Mon Sep 17 00:00:00 2001 From: Aidyn-A Date: Mon, 3 Jun 2024 15:38:27 +0000 Subject: [PATCH 258/706] [ATen][Native] fixes sparse SPMV on aarch64 (#127642) Fixes #127491 In #127491 result was allocated as `result = at::empty(...)`, which does not guarantee `result` being filled by zeros, therefore `torch.mv` was producing non-finite values. This happened mainly because the corner case (`beta = 0`) of `addmv` was not taken care of, as it should be just like in any other `addmv`/`addmm`: https://github.com/pytorch/pytorch/blob/923edef31c7f3e98a14625724f2019b1422dcb26/aten/src/ATen/native/mkl/SparseBlasImpl.cpp#L307-L311 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127642 Approved by: https://github.com/malfet --- aten/src/ATen/native/sparse/SparseBlasImpl.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/aten/src/ATen/native/sparse/SparseBlasImpl.cpp b/aten/src/ATen/native/sparse/SparseBlasImpl.cpp index fd67f0694f2d..c2e8c4439ab9 100644 --- a/aten/src/ATen/native/sparse/SparseBlasImpl.cpp +++ b/aten/src/ATen/native/sparse/SparseBlasImpl.cpp @@ -410,6 +410,9 @@ void addmv_out_sparse_csr( const Tensor& result) { #if !AT_USE_MKL_SPARSE() TORCH_CHECK(mat.layout() == kSparseBsr || mat.layout() == kSparseCsr, "Unexpected layout", mat.layout()); + if (beta.toComplexDouble() == 0.) { + result.zero_(); + } AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES( result.scalar_type(), "addmv_out_sparse_csr_impl_reference", [&] { if (mat.crow_indices().scalar_type() == kLong) { From 8677508167dae402033ca6e26aa6007c910a3a71 Mon Sep 17 00:00:00 2001 From: Shuqiang Zhang Date: Fri, 31 May 2024 14:59:32 -0700 Subject: [PATCH 259/706] [c10d] guard gpu context during abort (#127363) This is a mitigation for an internal out of MEM issues on GPU0 that happend during comms abort, this PR was tested internally to have fixed the out of MEM issue. Note This is supposed to be mitigation only, as the ideal fix should be within NCCL comm libs, which should just set the right CUDA context before any CUDA call and restore it to its exact previous state ncclCommDestroy/ncclCommAbort -> commReclaim -> commDestroySync (https://fburl.com/code/pori1tka) In commDestroySync, it thinks that "current device context" is not same as comm's device context. It tries to: 1) save the current context 2) sets the comm's device context 3) cleans up things 4) Restores "previously stored context" by another cudaSetDevice. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127363 Approved by: https://github.com/wconstab --- .../distributed/c10d/ProcessGroupNCCL.cpp | 20 ++++++++++++++++++- .../distributed/c10d/ProcessGroupNCCL.hpp | 1 + 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index ce2dc9d072b4..1c0bdc43be35 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -197,6 +197,20 @@ inline std::string getKeyFromDevice(at::Device& device) { return std::to_string(device.index()); } +inline at::DeviceIndex getIndexFromDeviceKey(const std::string& deviceKey) { + // initialize the device index to -1, which is an invalid value. + int index = -1; + try { + index = std::stoi(deviceKey); + } catch (const std::invalid_argument& e) { + LOG(WARNING) << c10::str( + "Invalid deviceKey: ", deviceKey, ",", e.what(), "."); + } catch (const std::out_of_range& e) { + LOG(ERROR) << "Out of range: " << e.what(); + } + return static_cast(index); +} + std::string getKeySendRecv(int myRank, int peer) { int lowRank = myRank < peer ? myRank : peer; int highRank = myRank < peer ? peer : myRank; @@ -1050,7 +1064,11 @@ void ProcessGroupNCCL::abortCommsFromMap( for (auto& it : ncclCommsMap) { auto& devName = it.first; auto& ncclComm = it.second; - + at::cuda::OptionalCUDAGuard gpuGuard; + at::DeviceIndex deviceIndex = getIndexFromDeviceKey(devName); + if (deviceIndex >= 0) { + gpuGuard.set_index(deviceIndex); + } LOG(INFO) << logPrefix() << "ProcessGroupNCCL destroying ncclComm_ " << ncclComm->ncclComm_ << " on CUDA device: " << devName; ncclComm->ncclCommAbort(abortReason); diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index 117c24ebfb82..8a3b7b1b5c21 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -908,6 +908,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { // communication, the key will be "1:2" on both processes. Note: this is for // the scenario where there is only 1 GPU per process. When it comes to // multiple GPUs per process, this part may need to redesigned. + // TODO: we probably need a separte map for P2P comms std::unordered_map> devNCCLCommMap_; // The NCCL communicators currently in process of being initialized. From 53f001c5993c4fd3cf99eaaae03012fb5e99c18e Mon Sep 17 00:00:00 2001 From: atalman Date: Mon, 3 Jun 2024 15:49:48 +0000 Subject: [PATCH 260/706] Revert "correct BLAS input (#126200)" (#127762) This reverts commit ea13e9a097aaa875a2b404822579b7f8b62ea291. Looks like this could have caused: https://github.com/pytorch/pytorch/actions/runs/9346105069/job/25722431775#step:17:984 Aarch64 tests failures: ``` + echo 'Checking that MKLDNN is available on aarch64' Checking that MKLDNN is available on aarch64 + pushd /tmp /tmp / + python -c 'import torch; exit(0 if torch.backends.mkldnn.is_available() else 1)' Error: Process completed with exit code 1. ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/127762 Approved by: https://github.com/PaliC, https://github.com/malfet --- cmake/Dependencies.cmake | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index faac0117ed93..8c7751f4c07b 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -147,21 +147,13 @@ set(AT_MKLDNN_ACL_ENABLED 0) set(AT_MKLDNN_ENABLED 0) set(AT_MKL_ENABLED 0) # setting default preferred BLAS options if not already present. -if(NOT DEFINED BLAS) - if(NOT INTERN_BUILD_MOBILE) - set(BLAS "MKL" CACHE STRING "Selected BLAS library") - else() - set(BLAS "Eigen" CACHE STRING "Selected BLAS library") - endif() -elseif(NOT BLAS STREQUAL "MKL") - if(USE_MKLDNN) - message(WARNING - "You explicitly chose with BLAS to not use MKL, so disabling USE_MKLDNN. Suppress this warning with " - "-DUSE_MKLDNN=OFF.") - set(USE_MKLDNN OFF) - endif() +if(NOT INTERN_BUILD_MOBILE) + set(BLAS "MKL" CACHE STRING "Selected BLAS library") +else() + set(BLAS "Eigen" CACHE STRING "Selected BLAS library") + set(AT_MKLDNN_ENABLED 0) + set(AT_MKL_ENABLED 0) endif() - set_property(CACHE BLAS PROPERTY STRINGS "ATLAS;BLIS;Eigen;FLAME;Generic;MKL;OpenBLAS;vecLib") message(STATUS "Trying to find preferred BLAS backend of choice: " ${BLAS}) From d1fad416a817b3360964f7d0ab4e448cac7ce367 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 3 Jun 2024 15:51:49 +0000 Subject: [PATCH 261/706] Revert "Add aten._unsafe_masked_index (#116491)" This reverts commit f03f8bc901a6c9038308a6353e8d280f4b5628f5. Reverted https://github.com/pytorch/pytorch/pull/116491 on behalf of https://github.com/PaliC due to breaking onnx tests ([comment](https://github.com/pytorch/pytorch/pull/116491#issuecomment-2145557724)) --- .../src/ATen/functorch/BatchRulesIndexing.cpp | 23 ------ .../ATen/native/TensorAdvancedIndexing.cpp | 76 ------------------ aten/src/ATen/native/native_functions.yaml | 12 --- test/distributed/_tensor/test_dtensor_ops.py | 2 - test/inductor/test_torchinductor.py | 32 +------- test/inductor/test_torchinductor_opinfo.py | 2 - test/onnx/test_fx_op_consistency.py | 5 -- test/test_fx_experimental.py | 45 +++-------- test/test_mps.py | 8 -- tools/autograd/derivatives.yaml | 11 --- tools/autograd/gen_variable_type.py | 4 - torch/_decomp/__init__.py | 2 - torch/_decomp/decompositions.py | 55 +------------ torch/_dynamo/trace_rules.py | 2 - torch/_inductor/codegen/triton.py | 5 -- torch/_inductor/decomposition.py | 2 - torch/_inductor/lowering.py | 70 +++------------- torch/_inductor/utils.py | 1 - .../_internal/common_methods_invocations.py | 79 ++----------------- 19 files changed, 28 insertions(+), 408 deletions(-) delete mode 100644 aten/src/ATen/functorch/BatchRulesIndexing.cpp diff --git a/aten/src/ATen/functorch/BatchRulesIndexing.cpp b/aten/src/ATen/functorch/BatchRulesIndexing.cpp deleted file mode 100644 index eb571b298078..000000000000 --- a/aten/src/ATen/functorch/BatchRulesIndexing.cpp +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright (c) Facebook, Inc. and its affiliates. -// All rights reserved. -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. - -#include -#include -#include - -namespace at { namespace functorch { - -#define OP_DECOMPOSE(op) m.impl(#op, static_cast(native::op)); -#define OP_DECOMPOSE2(op, overload) m.impl(#op"."#overload, static_cast(native::op)); - -TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { - OP_DECOMPOSE2(_unsafe_index, Tensor); - OP_DECOMPOSE(_unsafe_masked_index); - OP_DECOMPOSE(_unsafe_index_put); - OP_DECOMPOSE(_unsafe_masked_index_put_accumulate); -} - -}} diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp index 041c3f2770ea..395af8e5ef13 100644 --- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp +++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp @@ -656,82 +656,6 @@ Tensor _unsafe_index(const Tensor& self, const torch::List return at::index(self, indices); } -Tensor _unsafe_masked_index(const Tensor& self, const Tensor& mask, const torch::List>& indices, const Scalar& fill) { - // Unsafe masked index is equivalent to - // where(mask, self[indices], fill) - // with the main difference being that the when the `mask` is false, the tensor - // `self` is not indexed using `indices`. This allows `indices` to be out-of-bounds - // when `mask` is false. When `mask` is true, the `indices` are expected to be - // in bounds and is not checked. - // - // This function is not meant to be executed on eager mode. An unoptimized version - // is provided here. - // - // compiler backends should implement this op such that `self[indices]` is not - // loaded when `mask` is true. See inductor for a reference. - auto clamp = [](const c10::optional& index, auto size) -> c10::optional { - if (!index) { - return index; - } - // Disallow bool - auto dtype = index->scalar_type(); - TORCH_CHECK(dtype == kLong || dtype == kInt, - "_unsafe_masked_index found unexpected index type ", dtype); - return at::clamp(*index, -size, size - 1); - }; - - torch::List> clamped_indices(indices); - std::transform(indices.begin(), indices.end(), self.sizes().begin(), clamped_indices.begin(), clamp); - - if (self.numel() == 0) { - // Returns a tensor filled with `fill` value - // We use a hack here since we do not have a method to get the - // correct size of the tensor. (except with meta impl which is - // not available on mobile builds) - std::vector new_sizes(self.dim()); - auto compute_new_size = [](const c10::optional& index, auto size) -> int64_t { - if (index && size == 0) { - return 1; - } else { - return size; - } - }; - std::transform(indices.begin(), indices.end(), self.sizes().begin(), new_sizes.begin(), compute_new_size); - auto result = self.new_full(new_sizes, fill); - return at::_unsafe_index(result, clamped_indices); - } - - auto result = at::_unsafe_index(self, clamped_indices); - return result.masked_fill(at::logical_not(mask), fill); -} - -Tensor _unsafe_masked_index_put_accumulate(const Tensor& self, const Tensor& mask, const torch::List>& indices, const Tensor& values) { - // This is the backward of _unsafe_masked_index. - // This function is not meant to be executed on eager mode. - - if (self.numel() == 0) { - return self.clone(); - } - - // We recompute the clamped indices and rely on inductor to CSE the computation - auto clamp = [](const c10::optional& index, auto size) -> c10::optional { - if (!index) { - return index; - } - // Disallow bool - auto dtype = index->scalar_type(); - TORCH_CHECK(dtype == kLong || dtype == kInt, - "_unsafe_masked_index found unexpected index type ", dtype); - return at::clamp(*index, -size, size - 1); - }; - - torch::List> clamped_indices(indices); - std::transform(indices.begin(), indices.end(), self.sizes().begin(), clamped_indices.begin(), clamp); - - auto masked_value = values.masked_fill(at::logical_not(mask), 0); - return at::_unsafe_index_put(self, clamped_indices, masked_value, true); -} - Tensor & put_(Tensor & self, const Tensor& index, const Tensor & source, const bool accumulate) { // See note [Writing Nondeterministic Operations] // Nondeterministic when index contains duplicate entries and we do not accumulate diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 5c28397a07d5..a051f43e87eb 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -3061,18 +3061,6 @@ dispatch: CompositeExplicitAutograd: _unsafe_index -# Used by inductor to generate masked loads -# Note that we don't support boolean indexing, to avoid dynamic output shapes -- func: _unsafe_masked_index(Tensor self, Tensor mask, Tensor?[] indices, Scalar fill) -> Tensor - variants: function - dispatch: - CompositeExplicitAutograd: _unsafe_masked_index - -- func: _unsafe_masked_index_put_accumulate(Tensor self, Tensor mask, Tensor?[] indices, Tensor values) -> Tensor - variants: function - dispatch: - CompositeExplicitAutograd: _unsafe_masked_index_put_accumulate - - func: index_copy.out(Tensor self, int dim, Tensor index, Tensor source, *, Tensor(a!) out) -> Tensor(a!) structured: True variants: function diff --git a/test/distributed/_tensor/test_dtensor_ops.py b/test/distributed/_tensor/test_dtensor_ops.py index ef1ccc754c6c..22a56118b212 100644 --- a/test/distributed/_tensor/test_dtensor_ops.py +++ b/test/distributed/_tensor/test_dtensor_ops.py @@ -478,8 +478,6 @@ def wrapped(fn): xfail("unique"), xfail("unsafe_split"), xfail("unsafe_chunk"), - xfail("_unsafe_masked_index"), - xfail("_unsafe_masked_index_put_accumulate"), xfail("var_mean"), xfail("var_mean", "unbiased"), xfail("vdot"), diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 09a82b9d06e2..111d0e1ef959 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -1430,33 +1430,6 @@ def flip(x): actual = _run_and_assert_no_indirect_indexing(self, flip_opt, x) self.assertEqual(expect, actual) - def test__unsafe_masked_index(self): - def fn(a, mask, idx): - return aten._unsafe_masked_index(a, mask, idx, 1) - - self.common( - fn, - ( - torch.randn(8, device=self.device), - torch.tensor([True, False, True], device=self.device), - [torch.tensor([3, 9, -2], device=self.device)], - ), - ) - - def test__unsafe_masked_index_put_accumulate(self): - def fn(a, mask, idx, values): - return aten._unsafe_masked_index_put_accumulate(a, mask, idx, values) - - self.common( - fn, - ( - torch.randn(8, device=self.device), - torch.tensor([True, False, True], device=self.device), - [torch.tensor([3, 9, -2], device=self.device)], - torch.randn(3, device=self.device), - ), - ) - def test_sum1(self): def fn(a, b): return ((a + b).sum(-1),) @@ -10967,10 +10940,7 @@ def fn(x, n): fn_opt = torch.compile(fn) code = run_and_get_triton_code(fn_opt, x, 8) # load should be masked - self.assertTrue( - "tl.load(in_ptr0 + (tmp0), xmask" in code - or "tl.load(in_ptr0 + (tmp0), (xmask).to(tl.int1)" in code - ) + self.assertTrue("tl.load(in_ptr0 + (tmp0), xmask" in code) self.assertEqual(fn(x, 8), fn_opt(x, 8)) def test_kernel_names_descriptive(self): diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 22bec1756487..b66c0ce0832f 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -382,8 +382,6 @@ def wrapper_noop_set_seed(op, *args, **kwargs): }, ("std_mean.unbiased", "cuda", f16): {"reference_in_float": True}, ("uniform", "cuda"): {"reference_in_float": True}, - ("_unsafe_masked_index_put_accumulate", "cuda", f16): {"atol": 1e-4, "rtol": 0.01}, - ("_unsafe_masked_index_put_accumulate", "cpu", f16): {"atol": 1e-4, "rtol": 0.01}, # Following tests are failing with strict comparision but atol=1 is acceptable due roundings errors ("nn.functional.interpolate.bilinear", "cpu", u8): {"atol": 1, "rtol": 0}, ("nn.functional.upsample_bilinear", "cpu", u8): {"atol": 1, "rtol": 0}, diff --git a/test/onnx/test_fx_op_consistency.py b/test/onnx/test_fx_op_consistency.py index 08230bbc0099..4a4171699e65 100644 --- a/test/onnx/test_fx_op_consistency.py +++ b/test/onnx/test_fx_op_consistency.py @@ -170,11 +170,6 @@ def skip_torchlib_forward_compatibility( dtypes=(torch.float16,), reason="fixme: Assertion error: result mismatch", ), - xfail( - "_unsafe_masked_index", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Where", "bool"), - ), xfail( "add", dtypes=onnx_test_common.BOOL_TYPES, reason=onnx_test_common.reason_onnx_does_not_support("Add") diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py index d3ee06d889a8..d3231df35313 100644 --- a/test/test_fx_experimental.py +++ b/test/test_fx_experimental.py @@ -1,6 +1,5 @@ # Owner(s): ["module: fx"] -import functools import math import numbers import operator @@ -52,7 +51,6 @@ from torch.testing._internal.common_nn import module_tests, new_module_tests from torch.testing._internal.common_utils import TEST_Z3, run_tests, TestCase from torch.testing._internal.jit_utils import JitTestCase -import torch.utils._pytree as pytree try: import torchvision.models @@ -1625,40 +1623,21 @@ def jit_infer_type(v): param_names = [] param_values = [] fx_args = [] - - idx = 0 - - def process_arg(arg, name): - if isinstance(arg, torch.Tensor): - param_names.append(name) - param_values.append(arg) - return name - else: - return f"{repr(arg)}" - - def process_arg_with_idx(arg): - nonlocal idx - res = process_arg(arg, f"arg_{idx}") - idx = idx + 1 - return res - - def str_arg(arg): - if isinstance(arg, tuple): - args = [f"{str_arg(v)}, " for v in arg] - return f"({' '.join(args)})" - elif isinstance(arg, list): - args = [f"{str_arg(v)}" for v in arg] - return f"[{', '.join(args)}]" + for idx, v in enumerate(arg_values): + if isinstance(v, torch.Tensor): + param_names.append(f"arg_{idx}") + param_values.append(v) + fx_args.append(param_names[-1]) else: - return arg - - for v in arg_values: - arg = pytree.tree_map(process_arg_with_idx, v) - fx_args.append(str_arg(arg)) + fx_args.append(f"{repr(v)}") for k, v in kwarg_values.items(): - arg = pytree.tree_map(functools.partial(process_arg, name=k), v) - fx_args.append(f"{k} = {str_arg(arg)}") + if isinstance(v, torch.Tensor): + param_names.append(k) + param_values.append(v) + fx_args.append(f"{k} = {k}") + else: + fx_args.append(f"{k} = {repr(v)}") code = f""" class TestModule(torch.nn.Module): diff --git a/test/test_mps.py b/test/test_mps.py index 8b3ae97ff218..8c3bbf4b7bcf 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -144,10 +144,6 @@ def mps_ops_grad_modifier(ops): # round not working properly for float16 'round': [torch.float16], - - # atomic operation in backward pass - '_unsafe_masked_index': [torch.float16], - '_unsafe_masked_index_put_accumulate': [torch.float16], } MACOS_12_3_XFAILLIST_GRAD = { @@ -355,7 +351,6 @@ def mps_ops_modifier(ops): '__rdiv__', '__rmatmul__', '_chunk_cat', - '_unsafe_masked_index', 'acos', 'acosh', 'all', @@ -910,9 +905,6 @@ def mps_ops_modifier(ops): # round not working properly for float16 'round': [torch.float16], - - # atomic operations not supported - '_unsafe_masked_index_put_accumulate': [torch.bool, torch.int8, torch.uint8, torch.float16, torch.int16, torch.int64], } if product_version < 14.0: diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index cb50be54feb5..4922513f295d 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -830,17 +830,6 @@ self: at::_unsafe_index_put(grad.new_zeros_symint(self.sym_sizes(), self.options()), indices, grad, true) result: auto_linear -- name: _unsafe_masked_index(Tensor self, Tensor mask, Tensor?[] indices, Scalar fill) -> Tensor - self: at::_unsafe_masked_index_put_accumulate(grad.new_zeros_symint(self.sym_sizes(), self.options()), mask, indices, grad) - mask: non_differentiable - result: _unsafe_masked_index(self_t, mask, indices, 0) - -- name: _unsafe_masked_index_put_accumulate(Tensor self, Tensor mask, Tensor?[] indices, Tensor values) -> Tensor - self: grad - mask: non_differentiable - values: at::_unsafe_masked_index(grad, mask, indices, 0) - result: at::_unsafe_masked_index_put_accumulate(self_t, mask, indices, values_t) - - name: index_add(Tensor self, int dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor self: grad # The case source.dim() == 0 is necessary to support scalar tensors of the form diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index 86ea05bc6103..b9651ea2da80 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -375,10 +375,6 @@ "linalg_lu_solve", "_linalg_slogdet", "_linalg_solve_ex", - "_unsafe_index", - "_unsafe_index_put", - "_unsafe_masked_index", - "_unsafe_masked_index_put_accumulate", } GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX = { diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py index 5a45b6a37a14..b277bb7eceb0 100644 --- a/torch/_decomp/__init__.py +++ b/torch/_decomp/__init__.py @@ -445,8 +445,6 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]: aten.unfold_backward, aten.unfold_copy, aten._unsafe_index, - aten._unsafe_masked_index, - aten._unsafe_masked_index_put_accumulate, aten.unsafe_split.Tensor, aten.unsafe_split_with_sizes, aten._unsafe_view, diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index fc3ba94c806e..5bec539db06c 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -8,7 +8,6 @@ from typing import Any, Callable, cast, Iterable, List, Optional, Tuple, Union import torch -import torch._meta_registrations import torch._prims as prims import torch._prims_common as utils import torch.nn.functional as F @@ -3713,62 +3712,10 @@ def _reshape_alias(x, shape, *args): @register_decomposition([aten._unsafe_index]) -def _unsafe_index(x, indices): +def _index(x, indices): return aten.index(x, indices) -@register_decomposition([aten._unsafe_masked_index]) -def _unsafe_masked_index(x, mask, indices, fill): - for index in indices: - if index is not None: - torch._check( - index.dtype in [torch.long, torch.int], - lambda: "tensors used as indices must be long or int tensors", - ) - - torch._check( - mask.dtype == torch.bool, - lambda: "tensors used as masks must be bool tensors", - ) - - if x.numel() == 0: - meta_result = torch._meta_registrations.meta_index_Tensor(x, indices) - return x.new_full(meta_result.shape, fill) - - for i in range(len(indices)): - index = indices[i] - if index is not None: - indices[i] = index.clamp(min=0, max=x.size(i) - 1) - - return aten._unsafe_index(x, indices).masked_fill(~mask, fill) - - -@register_decomposition([aten._unsafe_masked_index_put_accumulate]) -def _unsafe_masked_index_put_accumulate(x, mask, indices, values): - for index in indices: - if index is not None: - torch._check( - index.dtype in [torch.long, torch.int], - lambda: "tensors used as indices must be long or int tensors", - ) - - torch._check( - mask.dtype == torch.bool, - lambda: "tensors used as masks must be bool tensors", - ) - - if x.numel() == 0: - return x.clone() - - for i in range(len(indices)): - index = indices[i] - if index is not None: - indices[i] = index.clamp(min=-x.size(i), max=x.size(i) - 1) - - masked_value = values.masked_fill(~mask, 0) - return aten._unsafe_index_put(x, indices, masked_value, accumulate=True) - - def _nll_loss_forward( self: Tensor, target: Tensor, diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 1f7b70c29325..90f0667fecfb 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -1575,8 +1575,6 @@ "torch._unpack_dual", "torch._unsafe_index_put", "torch._unsafe_index", - "torch._unsafe_masked_index_put_accumulate", - "torch._unsafe_masked_index", "torch._use_cudnn_ctc_loss", "torch._use_cudnn_rnn_flatten_weight", "torch._values_copy", diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 953689e668e9..4b0ea92f3bf4 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -871,11 +871,6 @@ def index_expr(cls, expr, dtype): @staticmethod def masked(mask, body, other): - if mask is not None and torch.version.hip is not None: - mask = V.kernel.cse.generate( - V.kernel.compute, - f"{mask}.to(tl.int1)", - ) with V.kernel.mask_loads(mask) as new_mask: result = body() diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index c1269e3703c8..960c3a42e1f1 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -87,8 +87,6 @@ # the Inductor decomp table. decomps_to_exclude = [ aten._unsafe_index, - aten._unsafe_masked_index, - aten._unsafe_masked_index_put_accumulate, aten._scaled_dot_product_flash_attention_for_cpu.default, # See comments in torch/_decomp/decompositions.py aten._softmax_backward_data, aten.clamp_max, diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index d8e679c67adb..20b0082eb1d9 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -2927,17 +2927,6 @@ def fn(idx): def index_impl(x, indices, check): - output_size, inner_fn = index_impl_helper(x, indices, check) - - return Pointwise.create( - device=x.get_device(), - dtype=x.get_dtype(), - inner_fn=inner_fn, - ranges=output_size, - ) - - -def index_impl_helper(x, indices, check): assert isinstance(indices, (list, tuple)) x_loader = x.make_loader() indices, tensor_indices = check_and_broadcast_indices(indices, x.get_device()) @@ -2952,11 +2941,11 @@ def index_impl_helper(x, indices, check): x_size = x.get_size() indexed_size = [x_size[i] for i in range(len(indices)) if indices[i] is not None] - if check and 0 in indexed_size and 0 not in tensor_size: + if 0 in indexed_size and 0 not in tensor_size: raise IndexError("index is out of bounds for dimension with size 0") indexed_size = [x_size[i] for i in range(len(indices))] - return index_output_size_and_inner_fn( + output_size, inner_fn = index_output_size_and_inner_fn( x_size, indices, tensor_indices, @@ -2967,6 +2956,13 @@ def index_impl_helper(x, indices, check): check=check, ) + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=inner_fn, + ranges=output_size, + ) + @register_lowering(aten.index, type_promotion_kind=None) def index(x, indices): @@ -3163,54 +3159,6 @@ def load_source_val(): return view(result_flat, self.get_size()) -fallback__unsafe_masked_index = fallback_handler( - aten._unsafe_masked_index.default, add_to_fallback_set=False -) - -fallback__unsafe_masked_index_put_accumulate = fallback_handler( - aten._unsafe_masked_index_put_accumulate.default, add_to_fallback_set=False -) - - -@register_lowering(aten._unsafe_masked_index, type_promotion_kind=None) -def _unsafe_masked_index(self, mask, indices, fill): - ranges, _unsafe_index_fn = index_impl_helper(self, indices, check=False) - mask_loader = mask.make_loader() - - def inner_fn(idx): - mask_val = ops.to_dtype(mask_loader(idx), torch.bool) - return ops.masked(mask_val, lambda: _unsafe_index_fn(idx), fill) - - return Pointwise.create( - device=self.get_device(), - dtype=self.get_dtype(), - inner_fn=inner_fn, - ranges=ranges, - ) - - -@register_lowering(aten._unsafe_masked_index_put_accumulate, type_promotion_kind=None) -def _unsafe_masked_index_put_accumulate(x, mask, indices, values): - if torch.version.hip is not None: - # Avoid a triton compiler failure - return fallback__unsafe_masked_index_put_accumulate(x, mask, indices, values) - - masked_value = where(mask, values, 0) - shape = x.get_size() - clamped_indices = [ - clamp(indices[i], -shape[i], shape[i] - 1) if indices[i] else None - for i in range(len(indices)) - ] - # TODO: use a masked store for this. currently only triton - # supports masked stores and cpp backend does not. - return _unsafe_index_put(x, clamped_indices, masked_value, accumulate=True) - - -@make_pointwise -def clamp(a, min, max): - return ops.maximum(min, ops.minimum(max, a)) - - @register_lowering(aten.as_strided_scatter, type_promotion_kind=None) def as_strided_scatter(self, src, size, stride, storage_offset=None): output = clone(self) diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 5c41df4406c6..0915a8330c34 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -647,7 +647,6 @@ def get_first_incompatible_cudagraph_node(gm): forbidden_set.update( { "aten._unsafe_index_put.default", - "aten._unsafe_masked_index_put_accumulate.default", "aten.index_put.default", "aten.index_put_.default", "aten.scatter.src", diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index fe58b99e34a9..50cfac763be5 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -5267,64 +5267,6 @@ def make_idx(n, m): args=(0, idx, src, reduce), kwargs={'include_self': True}) -def sample_inputs__unsafe_masked_index(op_info, device, dtype, requires_grad, **kwargs): - make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) - - def make_idx(n, m, dim, d): - view_shape = [1] * dim - view_shape[d] = n - return make_tensor((n,), device=device, dtype=torch.int64, low=0, high=m).view(view_shape) - - cases = [ - ((S, S), S, M), - ((S, S), M, S), - ((S, S, S), S, M), - ] - - fill_value = make_tensor([], dtype=dtype, device="cpu").item() - - for c in cases: - self_shape, high, idx_size = c - dim = len(self_shape) - indices = [make_idx(idx_size, high, dim, d) for d in range(dim)] - masks = [torch.logical_and(idx >= 0, idx < self_shape[i]) for i, idx in enumerate(indices) if idx is not None] - mask = functools.reduce(torch.logical_and, masks) - yield SampleInput(make_arg(self_shape), mask, indices, fill_value) - - masks = [torch.logical_and(idx >= 1, idx < self_shape[i] - 1) for i, idx in enumerate(indices) if idx is not None] - mask = functools.reduce(torch.logical_and, masks) - yield SampleInput(make_arg(self_shape), mask, indices, fill_value) - -def sample_inputs__unsafe_masked_index_put_accumulate(op_info, device, dtype, requires_grad, **kwargs): - make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) - - def make_idx(n, m, dim, d): - view_shape = [1] * dim - view_shape[d] = n - return make_tensor((n,), device=device, dtype=torch.int64, low=0, high=m).view(view_shape) - - cases = [ - ((S, S), S, (M, M)), - ((S, S), M, (S, S + 1)), - ((S, S, S), S, (M, M - 1, M + 1)), - ] - - fill_value = make_tensor([], dtype=dtype, device="cpu").item() - - for c in cases: - self_shape, high, idx_sizes = c - dim = len(self_shape) - indices = [make_idx(idx_sizes[d], high, dim, d) for d in range(dim)] - masks = [torch.logical_and(idx >= 0, idx < self_shape[i]) for i, idx in enumerate(indices) if idx is not None] - mask = functools.reduce(torch.logical_and, masks) - values = make_arg(idx_sizes) - yield SampleInput(make_arg(self_shape), mask, indices, values) - - masks = [torch.logical_and(idx >= 1, idx < self_shape[i] - 1) for i, idx in enumerate(indices) if idx is not None] - mask = functools.reduce(torch.logical_and, masks) - yield SampleInput(make_arg(self_shape), mask, indices, values) - - def sample_inputs_mode(op_info, device, dtype, requires_grad, **kwargs): args = ( ((S, S, S), (),), @@ -18067,6 +18009,9 @@ def reference_flatten(input, start_dim=0, end_dim=-1): supports_fwgrad_bwgrad=True, # See https://github.com/pytorch/pytorch/pull/78358 check_batched_forward_grad=False, + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),), sample_inputs_func=sample_inputs_column_stack,), OpInfo('pinverse', op=torch.pinverse, @@ -18157,22 +18102,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1): supports_out=True, sample_inputs_func=sample_inputs_index_reduce, ) for reduction_type in ('mean', 'prod', 'amin', 'amax')), - OpInfo('_unsafe_masked_index', - dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16, torch.bool), - supports_out=False, - supports_inplace_autograd=False, - supports_scripting=False, - supports_forward_ad=True, - supports_fwgrad_bwgrad=True, - sample_inputs_func=sample_inputs__unsafe_masked_index), - OpInfo('_unsafe_masked_index_put_accumulate', - dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16, torch.bool), - supports_out=False, - supports_inplace_autograd=False, - supports_scripting=False, - supports_forward_ad=True, - supports_fwgrad_bwgrad=True, - sample_inputs_func=sample_inputs__unsafe_masked_index_put_accumulate), OpInfo('__getitem__', dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), # Runs very slowly on slow gradcheck - alternatively reduce input sizes @@ -18199,6 +18128,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1): test_neg_view=False, sample_inputs_func=sample_inputs_index_put, skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), DecorateInfo(unittest.skip("Skipped"), 'TestBwdGradients', 'test_fn_grad', dtypes=[torch.float64], device_type='cuda', active_if=(TEST_WITH_ROCM and TEST_WITH_TORCHINDUCTOR)), )), @@ -19098,6 +19028,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1): skips=( # Not implemented on CUDA DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_errors', device_type='cuda'), + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), # JIT tests don't work with Tensor keyword arguments # https://github.com/pytorch/pytorch/issues/58507 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), From ded580a594c6d9b1397519981b0f34e8c94e17d8 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Fri, 31 May 2024 22:16:07 -0700 Subject: [PATCH 262/706] [dtensor] standardize multi mesh-dim strategy with utils (#126712) This PR standardize the multi mesh-dim strategy generation by unifying a util to expand from a single mesh dim strategy to multi mesh dim strategy, to allow strategy generation simpler Pull Request resolved: https://github.com/pytorch/pytorch/pull/126712 Approved by: https://github.com/tianyu-l --- torch/distributed/_tensor/_op_schema.py | 27 ++- .../distributed/_tensor/ops/embedding_ops.py | 174 ++++++------------ torch/distributed/_tensor/ops/math_ops.py | 55 ++---- torch/distributed/_tensor/ops/tensor_ops.py | 84 +++------ torch/distributed/_tensor/ops/utils.py | 53 +++++- 5 files changed, 167 insertions(+), 226 deletions(-) diff --git a/torch/distributed/_tensor/_op_schema.py b/torch/distributed/_tensor/_op_schema.py index 85c14746ce13..43aa065a59e0 100644 --- a/torch/distributed/_tensor/_op_schema.py +++ b/torch/distributed/_tensor/_op_schema.py @@ -238,15 +238,24 @@ def args_spec(self) -> Tuple[DTensorSpec, ...]: with NO non-DTensor positional arguments (i.e. int/float/tuple, etc) mainly used by sharding propagation to propagate the output spec """ - # filter out non-relevant values from args schema to get a clean spec list - # this would mainly be used by sharding propagation rules - if self.schema_info is not None and self.schema_info.needs_pytree: - return tuple( - item - for item in tree_leaves(self.args_schema) - if isinstance(item, DTensorSpec) - ) - return tuple(item for item in self.args_schema if isinstance(item, DTensorSpec)) + args = ( + tree_leaves(self.args_schema) + if self.schema_info is not None and self.schema_info.needs_pytree + else self.args_schema + ) + return tuple(item for item in args if isinstance(item, DTensorSpec)) + + @property + def args_strategy(self) -> Tuple[OpStrategy, ...]: + # filter out non-relevant values from args schema to get a clean OpStrategy list + # separate with args_spec for the ease of type annotation + # TODO: see if we should merge this with args_spec + args = ( + tree_leaves(self.args_schema) + if self.schema_info is not None and self.schema_info.needs_pytree + else self.args_schema + ) + return tuple(item for item in args if isinstance(item, OpStrategy)) def __repr__(self) -> str: args_schema = ", ".join([str(arg_schema) for arg_schema in self.args_schema]) diff --git a/torch/distributed/_tensor/ops/embedding_ops.py b/torch/distributed/_tensor/ops/embedding_ops.py index f861c5fcbd57..7cc8dd262638 100644 --- a/torch/distributed/_tensor/ops/embedding_ops.py +++ b/torch/distributed/_tensor/ops/embedding_ops.py @@ -1,25 +1,17 @@ # Copyright (c) Meta Platforms, Inc. and affiliates # implement matrix related ops for distributed tensor -import itertools from dataclasses import dataclass, field from typing import cast, List, Optional import torch import torch.distributed._functional_collectives as funcol -from torch.distributed._tensor._op_schema import ( - OpSchema, - OpStrategy, - PlacementStrategy, - StrategyType, -) +from torch.distributed._tensor._op_schema import OpSchema, OpStrategy, StrategyType from torch.distributed._tensor.ops.utils import ( - generate_redistribute_costs, - is_tensor_shardable, + expand_to_full_mesh_op_strategy, register_op_strategy, ) from torch.distributed._tensor.placement_types import ( - DTensorSpec, Partial, Placement, Replicate, @@ -182,64 +174,35 @@ def embedding_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: indices_shape = indices_strategy.shape output_emd_dim = len(indices_shape) - all_mesh_dim_strategies = [] - - for mesh_dim in range(mesh.ndim): - single_mesh_dim_strategies = [] - - # placement list stores placements of [output, weight, input_indices] - # first we always have replicate all for inputs and output - all_replicate: List[Placement] = [Replicate()] * 3 - single_mesh_dim_strategies.append(all_replicate) - - # colwise sharding, output shard on last dim, weight shard on dim 1, input replicate - colwise_sharding = [Shard(output_emd_dim), Shard(1), Replicate()] - single_mesh_dim_strategies.append(colwise_sharding) - - # rowwise sharding, output is embedding partial, weight shard on dim 0, input accepts embedding partial - embedding_partial_placement = _MaskPartial(logical_dim_size=weight_shape[0]) - - # NOTE we want to reuse the same mask partial placement so that we can reuse the same mask that generates - # from the input indices and use it for output reduction - rowwise_sharding = [ - embedding_partial_placement, - Shard(0), - embedding_partial_placement, - ] - single_mesh_dim_strategies.append(rowwise_sharding) - - # batch dim sharding, weight replicated, input can shard on any dim, output follows input - for input_dim in range(len(indices_shape)): - batch_sharding = [Shard(input_dim), Replicate(), Shard(input_dim)] - single_mesh_dim_strategies.append(batch_sharding) - - all_mesh_dim_strategies.append(single_mesh_dim_strategies) - - strategy_combs = itertools.product(*all_mesh_dim_strategies) - - all_strategies = [] - for strategy_comb in strategy_combs: - spec_list = [] - for specs in zip(*strategy_comb): - spec_list.append(DTensorSpec(mesh, tuple(specs))) - - if is_tensor_shardable(weight_shape, spec_list[1]) and is_tensor_shardable( - indices_shape, spec_list[2] - ): - # only add to the strategy list when both weight and indices are shardable - weight_spec, indices_spec = spec_list[1:] - redistribute_cost = [ - generate_redistribute_costs(weight_strategy, weight_spec), - generate_redistribute_costs(indices_strategy, indices_spec), - ] - strat = PlacementStrategy( - output_specs=spec_list[0], - input_specs=spec_list[1:], - redistribute_cost=redistribute_cost, - ) - all_strategies.append(strat) - - return OpStrategy(all_strategies) + single_mesh_dim_strategies = [] + + # placement list stores placements of [output, weight, input_indices] + # first we always have replicate all for inputs and output + all_replicate: List[Placement] = [Replicate()] * 3 + single_mesh_dim_strategies.append(all_replicate) + + # colwise sharding, output shard on last dim, weight shard on dim 1, input replicate + colwise_sharding = [Shard(output_emd_dim), Shard(1), Replicate()] + single_mesh_dim_strategies.append(colwise_sharding) + + # rowwise sharding, output is embedding partial, weight shard on dim 0, input accepts embedding partial + embedding_partial_placement = _MaskPartial(logical_dim_size=weight_shape[0]) + + # NOTE we want to reuse the same mask partial placement so that we can reuse the same mask that generates + # from the input indices and use it for output reduction + rowwise_sharding = [ + embedding_partial_placement, + Shard(0), + embedding_partial_placement, + ] + single_mesh_dim_strategies.append(rowwise_sharding) + + # batch dim sharding, weight replicated, input can shard on any dim, output follows input + for input_dim in range(len(indices_shape)): + batch_sharding = [Shard(input_dim), Replicate(), Shard(input_dim)] + single_mesh_dim_strategies.append(batch_sharding) + + return expand_to_full_mesh_op_strategy(mesh, op_schema, single_mesh_dim_strategies) @register_op_strategy(aten.embedding_dense_backward.default) @@ -257,55 +220,26 @@ def embedding_dense_backward_strategy( indices_shape = indices_strategy.shape grad_out_ndim = len(grad_out_shape) - all_mesh_dim_strategies = [] - - for mesh_dim in range(mesh.ndim): - single_mesh_dim_strategies = [] - - # placement list stores placements of [output, weight, input_indices] - # first we always have replicate all for inputs and output - all_replicate: List[Placement] = [Replicate()] * 3 - single_mesh_dim_strategies.append(all_replicate) - - # colwise sharding backward, grad_out shard on last dim, input replicate, - # weight grad shard colwise - colwise_sharding = [Shard(1), Shard(grad_out_ndim - 1), Replicate()] - single_mesh_dim_strategies.append(colwise_sharding) - - # batch dim sharding, weight replicated, grad_out/input have same sharding - # that can shard on any dim, weight grad partial - for input_dim in range(len(indices_shape)): - batch_sharding = [Partial(), Shard(input_dim), Shard(input_dim)] - single_mesh_dim_strategies.append(batch_sharding) - - # grad_out partial, input replicate, weight grad keep partial - partial_sharding = [Partial(), Partial(), Replicate()] - single_mesh_dim_strategies.append(partial_sharding) - - all_mesh_dim_strategies.append(single_mesh_dim_strategies) - - strategy_combs = itertools.product(*all_mesh_dim_strategies) - - all_strategies = [] - for strategy_comb in strategy_combs: - spec_list = [] - for specs in zip(*strategy_comb): - spec_list.append(DTensorSpec(mesh, tuple(specs))) - - if is_tensor_shardable(grad_out_shape, spec_list[1]) and is_tensor_shardable( - indices_shape, spec_list[2] - ): - # only add to the strategy list when both grad_out and indices are shardable - grad_out_spec, indices_spec = spec_list[1:] - redistribute_cost = [ - generate_redistribute_costs(grad_out_strategy, grad_out_spec), - generate_redistribute_costs(indices_strategy, indices_spec), - ] - strat = PlacementStrategy( - output_specs=spec_list[0], - input_specs=spec_list[1:], - redistribute_cost=redistribute_cost, - ) - all_strategies.append(strat) - - return OpStrategy(all_strategies) + single_mesh_dim_strategies = [] + + # placement list stores placements of [output, weight, input_indices] + # first we always have replicate all for inputs and output + all_replicate: List[Placement] = [Replicate()] * 3 + single_mesh_dim_strategies.append(all_replicate) + + # colwise sharding backward, grad_out shard on last dim, input replicate, + # weight grad shard colwise + colwise_sharding = [Shard(1), Shard(grad_out_ndim - 1), Replicate()] + single_mesh_dim_strategies.append(colwise_sharding) + + # batch dim sharding, weight replicated, grad_out/input have same sharding + # that can shard on any dim, weight grad partial + for input_dim in range(len(indices_shape)): + batch_sharding = [Partial(), Shard(input_dim), Shard(input_dim)] + single_mesh_dim_strategies.append(batch_sharding) + + # grad_out partial, input replicate, weight grad keep partial + partial_sharding = [Partial(), Partial(), Replicate()] + single_mesh_dim_strategies.append(partial_sharding) + + return expand_to_full_mesh_op_strategy(mesh, op_schema, single_mesh_dim_strategies) diff --git a/torch/distributed/_tensor/ops/math_ops.py b/torch/distributed/_tensor/ops/math_ops.py index 91d20a9dd8cb..029d1f803cb1 100644 --- a/torch/distributed/_tensor/ops/math_ops.py +++ b/torch/distributed/_tensor/ops/math_ops.py @@ -1,5 +1,4 @@ # Copyright (c) Meta Platforms, Inc. and affiliates -import itertools import math from dataclasses import dataclass from enum import Enum @@ -16,9 +15,9 @@ ) from torch.distributed._tensor.ops.utils import ( as_list, + expand_to_full_mesh_op_strategy, generate_redistribute_costs, is_tensor_evenly_shardable, - is_tensor_shardable, normalize_dim, normalize_dims, normalize_to_torch_size, @@ -1021,44 +1020,20 @@ def topk_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: ) topk_dim = normalize_dim(topk_dim, input_strategy.ndim) - all_mesh_dim_strategies = [] + single_mesh_dim_strategies = [] - for mesh_dim in range(mesh.ndim): - single_mesh_dim_strategies = [] + # two outputs (values, indices), 1 input + # replicate always works + all_replicate: List[Placement] = [Replicate()] * 3 + single_mesh_dim_strategies.append(all_replicate) - # two outputs (values, indices), 1 input - # replicate always works - all_replicate: List[Placement] = [Replicate()] * 3 - single_mesh_dim_strategies.append(all_replicate) + # every dim except topk dim should work + for dim in range(input_strategy.ndim): + if dim != topk_dim: + dim_shardings: List[Placement] = [Shard(dim)] * 3 + single_mesh_dim_strategies.append(dim_shardings) + # TODO: topk on sharded dim requries non-trival reduction, address it later - # every dim except topk dim should work - for dim in range(input_strategy.ndim): - if dim != topk_dim: - dim_shardings: List[Placement] = [Shard(dim)] * 3 - single_mesh_dim_strategies.append(dim_shardings) - - # TODO: topk on sharded dim requries non-trival reduction, address it later - - all_mesh_dim_strategies.append(single_mesh_dim_strategies) - - strategy_combs = itertools.product(*all_mesh_dim_strategies) - - all_strategies = [] - for strategy_comb in strategy_combs: - spec_list = [] - for specs in zip(*strategy_comb): - spec_list.append(DTensorSpec(mesh, tuple(specs))) - - input_spec = spec_list[2] - if is_tensor_shardable(input_shape, input_spec): - redistribute_cost = [ - generate_redistribute_costs(input_strategy, input_spec) - ] - strategy = PlacementStrategy( - output_specs=tuple(spec_list[:2]), - input_specs=(input_spec,), - redistribute_cost=redistribute_cost, - ) - all_strategies.append(strategy) - - return OpStrategy(all_strategies) + return expand_to_full_mesh_op_strategy( + mesh, op_schema, single_mesh_dim_strategies, input_index=2 + ) diff --git a/torch/distributed/_tensor/ops/tensor_ops.py b/torch/distributed/_tensor/ops/tensor_ops.py index b42fdcc6cc08..7aa90f2ebcd7 100644 --- a/torch/distributed/_tensor/ops/tensor_ops.py +++ b/torch/distributed/_tensor/ops/tensor_ops.py @@ -1,5 +1,4 @@ # Copyright (c) Meta Platforms, Inc. and affiliates -import itertools from typing import cast, List, Optional, Sequence, Tuple import torch @@ -16,11 +15,10 @@ from torch.distributed._tensor.ops.common_rules import pointwise_rule from torch.distributed._tensor.ops.embedding_ops import _MaskPartial from torch.distributed._tensor.ops.utils import ( - generate_redistribute_costs, + expand_to_full_mesh_op_strategy, is_tensor_dim_sharded, is_tensor_evenly_shardable, is_tensor_partial, - is_tensor_shardable, normalize_dim, register_op_strategy, register_prop_rule, @@ -370,59 +368,33 @@ def gather_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: input_shape = input_strategy.shape index_shape = index_strategy.shape - all_mesh_dim_strategies = [] - - for mesh_dim in range(mesh.ndim): - single_mesh_dim_strategies = [] - - # placement list stores placements of [output, input, index] - # first we always have replicate all for inputs and output - all_replicate: List[Placement] = [Replicate()] * 3 - single_mesh_dim_strategies.append(all_replicate) - - # input sharding, input sharded, index accepts mask partial, output follows index - # this only works when the input is sharded on the gather dimension, and - # index has size 1 on the gather dimension - if index_shape[dim] == 1: - index_partial_placement = _MaskPartial(logical_dim_size=input_shape[dim]) - input_sharding = [ - index_partial_placement, - Shard(dim), - index_partial_placement, - ] - single_mesh_dim_strategies.append(input_sharding) - - # index sharding, input replicated, index sharded, output follows index - # this only works when the sharding dimension is the gather dimension - index_sharding = [Shard(dim), Replicate(), Shard(dim)] - single_mesh_dim_strategies.append(index_sharding) - - all_mesh_dim_strategies.append(single_mesh_dim_strategies) - - strategy_combs = itertools.product(*all_mesh_dim_strategies) - - all_strategies = [] - for strategy_comb in strategy_combs: - spec_list = [] - for specs in zip(*strategy_comb): - spec_list.append(DTensorSpec(mesh, tuple(specs))) - - if is_tensor_shardable(input_shape, spec_list[1]) and is_tensor_shardable( - index_shape, spec_list[2] - ): - input_spec, index_spec = spec_list[1:] - redistribute_cost = [ - generate_redistribute_costs(input_strategy, input_spec), - generate_redistribute_costs(index_strategy, index_spec), - ] - strat = PlacementStrategy( - output_specs=spec_list[0], - input_specs=spec_list[1:], - redistribute_cost=redistribute_cost, - ) - all_strategies.append(strat) - - return OpStrategy(all_strategies) + single_mesh_dim_strategies = [] + + # placement list stores placements of [output, input, index] + # first we always have replicate all for inputs and output + all_replicate: List[Placement] = [Replicate()] * 3 + single_mesh_dim_strategies.append(all_replicate) + + # input sharding, input sharded, index accepts mask partial, output follows index + # this only works when the input is sharded on the gather dimension, and + # index has size 1 on the gather dimension + if index_shape[dim] == 1: + index_partial_placement = _MaskPartial(logical_dim_size=input_shape[dim]) + input_sharding = [ + index_partial_placement, + Shard(dim), + index_partial_placement, + ] + single_mesh_dim_strategies.append(input_sharding) + + # index sharding, input replicated, index sharded, output follows index + # this only works when the sharding dimension is the gather dimension + index_sharding = [Shard(dim), Replicate(), Shard(dim)] + single_mesh_dim_strategies.append(index_sharding) + + return expand_to_full_mesh_op_strategy( + mesh, op_schema, single_mesh_dim_strategies, input_index=1 + ) def _derive_follow_placements_from_tuple_strategy( diff --git a/torch/distributed/_tensor/ops/utils.py b/torch/distributed/_tensor/ops/utils.py index 98f65eab610b..b957842b4276 100644 --- a/torch/distributed/_tensor/ops/utils.py +++ b/torch/distributed/_tensor/ops/utils.py @@ -1,12 +1,19 @@ # Copyright (c) Meta Platforms, Inc. and affiliates import functools +import itertools import operator from typing import cast, Iterable, List, Sequence, Tuple, Union import torch from torch.distributed._tensor._collective_utils import redistribute_cost -from torch.distributed._tensor._op_schema import OpStrategy, RuntimeSchemaInfo +from torch.distributed._tensor._op_schema import ( + OpSchema, + OpStrategy, + PlacementStrategy, + RuntimeSchemaInfo, +) from torch.distributed._tensor.api import DTensor +from torch.distributed._tensor.device_mesh import DeviceMesh from torch.distributed._tensor.placement_types import ( DTensorSpec, Partial, @@ -224,3 +231,47 @@ def generate_redistribute_costs( redistribute_costs.append(redistribute_cost(strat.output_spec, dst_spec)) return redistribute_costs + + +def expand_to_full_mesh_op_strategy( + mesh: DeviceMesh, + op_schema: OpSchema, + single_mesh_dim_strategies: List[List[Placement]], + input_index: int = 1, +) -> OpStrategy: + # Expand the single_mesh_dim_strategies to full mesh dim strategies. + all_mesh_dim_strategies = [single_mesh_dim_strategies] * mesh.ndim + + strategy_combs = itertools.product(*all_mesh_dim_strategies) + + all_strategies = [] + for strategy_comb in strategy_combs: + spec_list = [] + for specs in zip(*strategy_comb): + spec_list.append(DTensorSpec(mesh, tuple(specs))) + + input_specs = spec_list[input_index:] + input_args_strategy = op_schema.args_strategy + assert len(input_specs) == len(input_args_strategy) + # check inputs shardable + inputs_shardable = all( + is_tensor_shardable(inp.shape, s) + for inp, s in zip(input_args_strategy, input_specs) + ) + + # only add to the all_strategies list when all inputs are shardable + if inputs_shardable: + redistribute_cost = [ + generate_redistribute_costs(input_strategy, input_spec) + for input_strategy, input_spec in zip(input_args_strategy, input_specs) + ] + strategy = PlacementStrategy( + output_specs=tuple(spec_list[:input_index]) + if input_index > 1 + else spec_list[0], + input_specs=input_specs, + redistribute_cost=redistribute_cost, + ) + all_strategies.append(strategy) + + return OpStrategy(all_strategies) From 21144ce5704f5d95dff8d28e3a389c798b03afe3 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Sat, 1 Jun 2024 12:09:08 -0700 Subject: [PATCH 263/706] [dtensor] implement scatter op with simple replication (#126713) as titled, implement torch.scatter op with simple replications strategy, need to follow up and see if we could actually support any sharding pattern Pull Request resolved: https://github.com/pytorch/pytorch/pull/126713 Approved by: https://github.com/tianyu-l ghstack dependencies: #126712 --- test/distributed/_tensor/test_dtensor_ops.py | 1 - test/distributed/_tensor/test_tensor_ops.py | 34 ++++++++++++++++++++ torch/distributed/_tensor/ops/tensor_ops.py | 27 ++++++++++++++++ torch/distributed/_tensor/ops/utils.py | 8 +++++ 4 files changed, 69 insertions(+), 1 deletion(-) diff --git a/test/distributed/_tensor/test_dtensor_ops.py b/test/distributed/_tensor/test_dtensor_ops.py index 22a56118b212..83f0bb875167 100644 --- a/test/distributed/_tensor/test_dtensor_ops.py +++ b/test/distributed/_tensor/test_dtensor_ops.py @@ -403,7 +403,6 @@ def wrapped(fn): xfail("rsub"), xfail("scalar_tensor"), xfail("scatter_add"), - xfail("scatter"), xfail("scatter_reduce", "amax"), xfail("scatter_reduce", "amin"), xfail("scatter_reduce", "mean"), diff --git a/test/distributed/_tensor/test_tensor_ops.py b/test/distributed/_tensor/test_tensor_ops.py index 24e527533315..e86a702855c6 100644 --- a/test/distributed/_tensor/test_tensor_ops.py +++ b/test/distributed/_tensor/test_tensor_ops.py @@ -390,6 +390,40 @@ def test_new_empty_strided(self): self.assertEqual(new_empty_strided_dt._local_tensor.size(), (12, 4)) self.assertEqual(new_empty_strided_dt._local_tensor.stride(), (4, 1)) + @with_comms + def test_scatter(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + comm_mode = CommDebugMode() + + # case 1 all replicate: input replicated, index/src replicated, output replicated + global_indexs = [ + torch.tensor([[0, 1, 2, 0]]), + torch.tensor([[0, 1, 2], [0, 1, 4]]), + ] + for scatter_dim in [0, 1]: + srcs = [torch.arange(1, 11).reshape((2, 5)), 4] + for global_src in srcs: + global_input = torch.zeros(3, 5, dtype=torch.int64) + global_index = global_indexs[scatter_dim] + + input_dt = distribute_tensor( + global_input.clone(), device_mesh, [Replicate()] + ) + index_dt = distribute_tensor(global_index, device_mesh, [Replicate()]) + if isinstance(global_src, torch.Tensor): + src_dt = distribute_tensor(global_src, device_mesh, [Replicate()]) + else: + src_dt = global_src + global_output = torch.scatter( + global_input, scatter_dim, global_index, global_src + ) + with comm_mode: + output_dt = torch.scatter(input_dt, scatter_dim, index_dt, src_dt) + + self.assertEqual(comm_mode.get_total_counts(), 0) + self.assertEqual(output_dt.placements, [Replicate()]) + self.assertEqual(output_dt.to_local(), global_output) + @with_comms def test_gather(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) diff --git a/torch/distributed/_tensor/ops/tensor_ops.py b/torch/distributed/_tensor/ops/tensor_ops.py index 7aa90f2ebcd7..40f75c151579 100644 --- a/torch/distributed/_tensor/ops/tensor_ops.py +++ b/torch/distributed/_tensor/ops/tensor_ops.py @@ -4,6 +4,7 @@ import torch from torch.distributed._tensor._op_schema import ( + _is_inplace_op, OpSchema, OpStrategy, OutputSharding, @@ -359,6 +360,32 @@ def replica_only_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType return OpStrategy([PlacementStrategy(replicate_spec)]) +@register_op_strategy( + [aten.scatter_.value, aten.scatter.value, aten.scatter_.src, aten.scatter.src], + schema_info=RuntimeSchemaInfo(1), +) +def scatter_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: + input_strategy = cast(OpStrategy, op_schema.args_schema[0]) + single_mesh_dim_strategies = [] + + # placement list stores placements of [output, input, index, src] + # first we always have replicate all for inputs and output + if len(op_schema.args_strategy) < 3: + # scatter_.src/scatter.src with src be float number instead of tensor + all_replicate: List[Placement] = [Replicate()] * 3 + else: + all_replicate = [Replicate()] * 4 + single_mesh_dim_strategies.append(all_replicate) + + # TODO: see if we can support input sharding pattern + inplace_op = _is_inplace_op(op_schema.op) + + op_strategy = expand_to_full_mesh_op_strategy( + mesh, op_schema, single_mesh_dim_strategies, inplace_op=inplace_op + ) + return op_strategy + + @register_op_strategy(aten.gather.default) def gather_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: input_strategy = cast(OpStrategy, op_schema.args_schema[0]) diff --git a/torch/distributed/_tensor/ops/utils.py b/torch/distributed/_tensor/ops/utils.py index b957842b4276..245298607c5e 100644 --- a/torch/distributed/_tensor/ops/utils.py +++ b/torch/distributed/_tensor/ops/utils.py @@ -237,7 +237,9 @@ def expand_to_full_mesh_op_strategy( mesh: DeviceMesh, op_schema: OpSchema, single_mesh_dim_strategies: List[List[Placement]], + *, input_index: int = 1, + inplace_op: bool = False, ) -> OpStrategy: # Expand the single_mesh_dim_strategies to full mesh dim strategies. all_mesh_dim_strategies = [single_mesh_dim_strategies] * mesh.ndim @@ -253,6 +255,12 @@ def expand_to_full_mesh_op_strategy( input_specs = spec_list[input_index:] input_args_strategy = op_schema.args_strategy assert len(input_specs) == len(input_args_strategy) + self_spec = input_args_strategy[0].strategies[0].output_spec + if inplace_op and self_spec.placements != input_specs[0].placements: + # if it's inplace op, we would only allow the placement strategy to be added when the + # input_spec matches the first argument's runtime sharding, otherwise we skip + continue + # check inputs shardable inputs_shardable = all( is_tensor_shardable(inp.shape, s) From 12c4a2c29762832437d65191bcb1c119ce3e8ef7 Mon Sep 17 00:00:00 2001 From: Aaron Gokaslan Date: Mon, 3 Jun 2024 17:22:10 +0000 Subject: [PATCH 264/706] [BE]: Apply PLR1736 fixes (unnecessary index lookup) (#127716) Applies the PLR1736 preview rule with some more autofixes to cut down on unnecessary accesses. Added a noqa since that test actually testing the dunder method. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127716 Approved by: https://github.com/ezyang --- test/distributed/test_device_mesh.py | 4 ++-- test/test_torch.py | 2 +- torch/_decomp/decompositions.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/distributed/test_device_mesh.py b/test/distributed/test_device_mesh.py index 03457de14b68..c024ddc98690 100644 --- a/test/distributed/test_device_mesh.py +++ b/test/distributed/test_device_mesh.py @@ -672,9 +672,9 @@ def test_all_gather_uneven(self): ) unpadded_list = [ ( - unpad_tensor(big_tensor_chunks[i], shard_dim, pad_sizes[i]) + unpad_tensor(big_tensor, shard_dim, pad_sizes[i]) if pad_sizes[i] > 0 - else big_tensor_chunks[i] + else big_tensor ) for i, big_tensor in enumerate(big_tensor_chunks) ] diff --git a/test/test_torch.py b/test/test_torch.py index 717943f43646..ff573706913f 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -8397,7 +8397,7 @@ def test_resizable(self) -> None: def test_iter(self) -> None: x = torch.randn(5, 5) for i, sub in enumerate(x): - self.assertEqual(sub, x[i]) + self.assertEqual(sub, x[i]) # noqa: PLR1736 x = torch.tensor([]) self.assertEqual(list(x), []) diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index 5bec539db06c..76599d299b29 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -2147,7 +2147,7 @@ def cudnn_batch_norm( def _broadcast_batch_norm_backward(x, broadcast_mask): for axis, mask in enumerate(broadcast_mask): - if mask == 1 and not (axis < x.ndim and x.shape[axis] == broadcast_mask[axis]): + if mask == 1 and not (axis < x.ndim and x.shape[axis] == mask): x = x.unsqueeze(axis) return x From 4d32de14b6caeb12f874b55d590d3cbda5cec6cb Mon Sep 17 00:00:00 2001 From: angelayi Date: Mon, 3 Jun 2024 17:25:51 +0000 Subject: [PATCH 265/706] [export] Handle serializing duplicate getitem nodes (#127633) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit We ran into a graph that looks something like the following, where we have 2 getitem calls to the same index (%getitem, %getitem_2 both query topk[0]): ``` graph(): %x : [num_users=1] = placeholder[target=x] %topk : [num_users=3] = call_function[target=torch.ops.aten.topk.default](args = (%x, 2), kwargs = {}) %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%topk, 0), kwargs = {}) %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%topk, 1), kwargs = {}) %getitem_2 : [num_users=1] = call_function[target=operator.getitem](args = (%topk, 0), kwargs = {}) %mul_tensor : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem, %getitem_2), kwargs = {}) %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_tensor, 2), kwargs = {}) return (mul, getitem_1) ``` The duplicate getitem call gets created during a pass.. so there are a couple of solutions: 1. Change serializer to support the case of duplicate getitem calls 2. Change the pass so that it doesn’t produce duplicate getitem calls 3. Add a pass which dedups the getitem calls As a framework, we should do 1 and 3 (through a CSE pass). This PR implements solution 1. However, the serializer currently does some special handling for getitem nodes -- instead of directly serializing the getitem nodes, we serialize the output of the node that outputting a list of tensors (the %topk node in this example) into a list nodes for each output ([%getitem, %getitem_1]). This fails when we have duplicate getitem nodes to the same index (%getitem_2), since we do not record that duplicate getitem node anywhere. So, the solution this PR takes is that the serializer will deduplicate the getitem nodes (%getitem_2 will be replaced with %getitem). This would result in a sematically correct graph, but not necessarily node-to-node identical as the original fx graph. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127633 Approved by: https://github.com/ydwu4 --- test/export/test_serialize.py | 43 ++++++++++++++++++++++++++ torch/_export/serde/serialize.py | 53 +++++++++++++++++++------------- 2 files changed, 74 insertions(+), 22 deletions(-) diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py index b8ed2ef69f53..012b35c910b5 100644 --- a/test/export/test_serialize.py +++ b/test/export/test_serialize.py @@ -806,6 +806,49 @@ def forward(self, x): dynamic_shapes = {"x": {0: Dim("dim0"), 1: Dim("dim1")}} self.check_graph(Foo(), (torch.ones(4, 5),), dynamic_shapes=dynamic_shapes) + def test_multiple_getitem(self): + class M(torch.nn.Module): + def forward(self, x): + a, b = torch.topk(x, 2) + a = a * 2 + return a, b + + ep = torch.export.export(M(), (torch.ones(3),)) + + # insert another getitem node + for node in ep.graph.nodes: + if node.op == "call_function" and node.target == torch.ops.aten.mul.Tensor: + getitem_0 = node.args[0] + with ep.graph.inserting_before(getitem_0): + getitem_copy = ep.graph.node_copy(getitem_0) + mul_node = ep.graph.call_function( + torch.ops.aten.mul.Tensor, (getitem_copy, 2) + ) + mul_node.meta = copy.copy(getitem_copy.meta) + node.args = (getitem_0, mul_node) + + deserialized_ep = deserialize(serialize(ep)) + + inp = (torch.randn(3),) + orig_res = ep.module()(*inp) + res = deserialized_ep.module()(*inp) + self.assertTrue(torch.allclose(orig_res[0], res[0])) + self.assertTrue(torch.allclose(orig_res[1], res[1])) + + # The deserialized graph should have deduped getitem calls + self.assertExpectedInline( + deserialized_ep.graph_module.code.strip("\n"), + """\ +def forward(self, x): + topk_default = torch.ops.aten.topk.default(x, 2); x = None + getitem = topk_default[0] + getitem_1 = topk_default[1]; topk_default = None + mul_tensor = torch.ops.aten.mul.Tensor(getitem, 2) + mul = torch.ops.aten.mul.Tensor(getitem, mul_tensor); getitem = mul_tensor = None + return (mul, getitem_1) + """, + ) + @parametrize( "name,case", get_filtered_export_db_tests(), diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index 38ef1da7d5d4..8d6dc939fb5c 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -387,14 +387,6 @@ def _is_single_tensor_list_return(target: Any) -> bool: return_type.getElementType(), torch.TensorType ) -def _output_node_at_index(node, index): - for user in node.users: - assert user.target is operator.getitem, f"{user} is not a getitem node" - if index == user.args[1]: - return user - return None - - @dataclass class GraphState: @@ -427,6 +419,7 @@ def __init__( self.graph_signature = graph_signature self.module_call_graph = module_call_graph self.custom_objs: Dict[str, torch._C.ScriptObject] = {} + self.duplicate_getitem_nodes: Dict[str, str] = {} @contextmanager def save_graph_state(self): @@ -552,6 +545,19 @@ def handle_call_function(self, node: torch.fx.Node): def handle_get_attr(self, node): pass + def _output_node_at_index(self, node, index): + user_node = None + for user in node.users: + assert user.target is operator.getitem, f"{user} is not a getitem node" + if index == user.args[1]: + if user_node is None: + user_node = user + else: + # We want to deduplicate getitem nodes that are trying to + # index to the same index + self.duplicate_getitem_nodes[user.name] = user_node.name + return user_node + def serialize_metadata(self, node: torch.fx.Node) -> Dict[str, str]: ret = {} if stack_trace := node.meta.get("stack_trace"): @@ -705,13 +711,16 @@ def serialize_input( return Argument.create( as_sym_bool=SymBoolArgument.create(as_name=arg.name) ) - else: - if isinstance(arg.meta["val"], ep.CustomObjArgument): - return Argument.create( - as_custom_obj=CustomObjArgument( - name=arg.name, class_fqn=arg.meta["val"].class_fqn - ) + elif isinstance(arg.meta["val"], ep.CustomObjArgument): + return Argument.create( + as_custom_obj=CustomObjArgument( + name=arg.name, class_fqn=arg.meta["val"].class_fqn ) + ) + elif arg.name in self.duplicate_getitem_nodes: + dedup_name = self.duplicate_getitem_nodes[arg.name] + return Argument.create(as_tensor=TensorArgument(name=dedup_name)) + else: return Argument.create(as_tensor=TensorArgument(name=arg.name)) elif isinstance(arg, inductor_tensor_buffers): # Other branches are for arguments in fx node. @@ -1121,7 +1130,7 @@ def serialize_outputs(self, node: torch.fx.Node) -> List[Argument]: # e.g "-> Tensor[]" tensor_args = [] for idx, meta in enumerate(meta_val): - user_node = _output_node_at_index(node, idx) + user_node = self._output_node_at_index(node, idx) name = ( user_node.name if user_node is not None @@ -1151,7 +1160,7 @@ def serialize_outputs(self, node: torch.fx.Node) -> List[Argument]: output_arguments.append(Argument.create(as_none=())) elif isinstance(meta, FakeTensor): assert isinstance(return_schema.real_type, (torch.OptionalType, torch.TensorType)) - user_node = _output_node_at_index(node, idx) + user_node = self._output_node_at_index(node, idx) name = ( user_node.name if user_node is not None @@ -1165,20 +1174,20 @@ def serialize_outputs(self, node: torch.fx.Node) -> List[Argument]: ) and isinstance( return_schema.real_type.getElementType(), torch.TensorType ) - user_node = _output_node_at_index(node, idx) + user_node = self._output_node_at_index(node, idx) assert user_node is not None args = [] for i, m in enumerate(meta): if m is None: continue - sub_user_node = _output_node_at_index(user_node, i) + sub_user_node = self._output_node_at_index(user_node, i) assert sub_user_node is not None, f"No user found at index {i}" args.append(self.serialize_tensor_output(sub_user_node.name, m)) output_arguments.append(Argument.create(as_tensors=args)) elif isinstance(meta, (int, SymInt)): - user_node = _output_node_at_index(node, idx) + user_node = self._output_node_at_index(node, idx) name = ( user_node.name if user_node is not None @@ -1208,7 +1217,7 @@ def serialize_hoo_outputs(self, node: torch.fx.Node) -> List[Argument]: if len(meta_val) == 1: assert isinstance(meta_val[0], torch.Tensor) - user_node = _output_node_at_index(node, 0) + user_node = self._output_node_at_index(node, 0) name = ( user_node.name if user_node is not None @@ -1218,7 +1227,7 @@ def serialize_hoo_outputs(self, node: torch.fx.Node) -> List[Argument]: outputs = [] for i, element_meta_val in enumerate(meta_val): - user_node = _output_node_at_index(node, i) + user_node = self._output_node_at_index(node, i) if isinstance(element_meta_val, list): # e.g "-> Tensor[]" assert user_node is not None @@ -1228,7 +1237,7 @@ def serialize_hoo_outputs(self, node: torch.fx.Node) -> List[Argument]: if not isinstance(m, torch.Tensor): raise SerializeError(f"Serialize list output with type {type(m)} nyi") - sub_user_node = _output_node_at_index(user_node, j) + sub_user_node = self._output_node_at_index(user_node, j) name = ( sub_user_node.name if sub_user_node is not None From 7c3740d3889837fb1587ca44ea2212c57e45fb52 Mon Sep 17 00:00:00 2001 From: Janani Sriram Date: Mon, 3 Jun 2024 17:46:12 +0000 Subject: [PATCH 266/706] [NestedTensor] Extend coverage for unbind when ragged_idx != 1 (#127493) Summary: Extend coverage for the `NestedTensor` `unbind` operator to cases in which `ragged_idx != 1`. Currently, the `unbind` operator in the `NestedTensor` class splits a tensor along the 0-th dimension, where the `ragged_idx` property, which controls the jagged dimension upon which `unbind` splits, is 1. This diff extends support for `ragged_idx != 1` in `NestedTensor`s, allowing `unbind` to split a tensor along a jagged dimension greater than 0 for `NestedTensor`s with and without the `lengths` property. Test Plan: Added the following unit tests: `test_unbind_ragged_idx_equals_2_cpu`, `test_unbind_ragged_idx_equals_3_cpu`, and `test_unbind_ragged_idx_equals_last_dim_cpu` verify that `unbind` works for all jagged dimensions greater than 1, for `NestedTensor`s without `lengths`. ``` test_unbind_ragged_idx_equals_2_cpu (test_nestedtensor.TestNestedTensorSubclassCPU) ... ok test_unbind_ragged_idx_equals_3_cpu (test_nestedtensor.TestNestedTensorSubclassCPU) ... ok test_unbind_ragged_idx_equals_last_dim_cpu (test_nestedtensor.TestNestedTensorSubclassCPU) ... ok ``` `test_unbind_with_lengths_cpu` and `test_unbind_with_lengths_ragged_idx_equals_1_cpu` verify that `unbind` works when the jagged dimension is 1, for `NestedTensor`s with `lengths`. ``` test_unbind_with_lengths_cpu (test_nestedtensor.TestNestedTensorSubclassCPU) ... ok test_unbind_with_lengths_ragged_idx_equals_1_cpu (test_nestedtensor.TestNestedTensorSubclassCPU) ... ok ``` `test_unbind_with_lengths_ragged_idx_equals_2_cpu` and `test_unbind_with_lengths_ragged_idx_equals_3_cpu` verify that `unbind` works when the jagged dimension is greater than 1, for `NestedTensor`s with `lengths`. ``` test_unbind_with_lengths_ragged_idx_equals_2_cpu (test_nestedtensor.TestNestedTensorSubclassCPU) ... ok test_unbind_with_lengths_ragged_idx_equals_3_cpu (test_nestedtensor.TestNestedTensorSubclassCPU) ... ok ``` `test_unbind_with_lengths_ragged_idx_equals_0_cpu` verifies that `unbind` fails when the jagged dimension is 0 (the batch dimension), for `NestedTensor`s with `lengths`. ``` test_unbind_with_lengths_ragged_idx_equals_0_cpu (test_nestedtensor.TestNestedTensorSubclassCPU) ... ok ``` `test_unbind_with_lengths_ragged_idx_equals_2_bad_dim_cpu` verifies that `unbind` fails when there is a mismatch between the offsets and the jagged dimension, for `NestedTensor`s with `lengths`. ``` test_unbind_with_lengths_ragged_idx_equals_2_bad_dim_cpu (test_nestedtensor.TestNestedTensorSubclassCPU) ... ok ``` `test_unbind_with_wrong_lengths_cpu` verifies that `unbind` fails when the lengths exceed the limitations set by offsets, for `NestedTensor`s with `lengths`. ``` test_unbind_with_wrong_lengths_cpu (test_nestedtensor.TestNestedTensorSubclassCPU) ... ok ``` Differential Revision: D57942686 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127493 Approved by: https://github.com/davidberard98 --- test/test_nestedtensor.py | 139 +++++++++++++++++++++++++++++++++- torch/nested/_internal/ops.py | 14 ++-- 2 files changed, 146 insertions(+), 7 deletions(-) diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index d369135a6e52..17b2bf5a8393 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -3048,6 +3048,15 @@ def _make_tensor(*shape, include_requires_grad=include_requires_grad, requires_g _make_tensor(5, 5, 6), _make_tensor(6, 5, 6), ], + # (B, *, D_0, D_1, D_2) with B=6 + [ + _make_tensor(2, 5, 6, 7), + _make_tensor(3, 5, 6, 7), + _make_tensor(4, 5, 6, 7, requires_grad=False), + _make_tensor(5, 5, 6, 7), + _make_tensor(6, 5, 6, 7), + _make_tensor(7, 5, 6, 7), + ], ] if include_list_of_lists: @@ -3786,12 +3795,140 @@ def test_unbind(self, device): nt = torch.nested.nested_tensor( tensor_list, layout=torch.jagged, - device=device) + device=device) # ragged_idx = 1 out = nt.unbind() self.assertEqual(len(out), len(tensor_list)) for i, t in enumerate(out): self.assertEqual(t, tensor_list[i]) + @parametrize("ragged_idx", [2, 3]) + def test_unbind_transpose(self, device, ragged_idx): + for tensor_list in self._get_example_tensor_lists(): + nt = torch.nested.nested_tensor( + tensor_list, + layout=torch.jagged, + device=device) + if ragged_idx < nt.dim(): + nt = nt.transpose(1, ragged_idx) # set ragged_idx + out = nt.unbind() + self.assertEqual(len(out), len(tensor_list)) + for i, t in enumerate(out): + self.assertEqual(t.transpose(0, ragged_idx - 1), tensor_list[i]) # transpose back each element of result + + def test_unbind_transpose_ragged_idx_last_dim(self, device): + for tensor_list in self._get_example_tensor_lists(): + nt = torch.nested.nested_tensor( + tensor_list, + layout=torch.jagged, + device=device).transpose(1, -1) # set ragged_idx = last dimension + out = nt.unbind() + self.assertEqual(len(out), len(tensor_list)) + for i, t in enumerate(out): + self.assertEqual(t.transpose(0, -1), tensor_list[i]) # transpose back each element of result + + def test_unbind_lengths(self, device): + values = torch.randn(16, 128, device=device) + offsets = torch.tensor([0, 8, 12, 13, 16], device=device) + lengths = torch.tensor([6, 2, 1, 2], device=device) + nt = torch.nested.nested_tensor_from_jagged( + values, + offsets=offsets, + lengths=lengths) # 3D nested tensor + + tensor_list = [] + for i in range(offsets.shape[0] - 1): + tensor_list.append(values[offsets[i] : (offsets[i] + lengths[i])]) + + out = nt.unbind() + self.assertEqual(len(out), len(tensor_list)) + for i, t in enumerate(out): + self.assertEqual(t, tensor_list[i]) + + def test_unbind_lengths_ragged_idx_1(self, device): + values = torch.randn(16, 8, 128, device=device) + offsets = torch.tensor([0, 8, 12, 13, 16], device=device) + lengths = torch.tensor([6, 2, 1, 2], device=device) + ragged_idx = 1 + nt = torch.nested._internal.nested_tensor.NestedTensor( + values, + offsets=offsets, + lengths=lengths, + _ragged_idx=ragged_idx) # 4D nested tensor + + tensor_list = [] + for i in range(offsets.shape[0] - 1): + tensor_list.append(values[offsets[i] : (offsets[i] + lengths[i]), :, :]) + + out = nt.unbind() + + self.assertEqual(len(out), len(tensor_list)) + for i, t in enumerate(out): + self.assertEqual(t, tensor_list[i]) + + def test_unbind_lengths_ragged_idx_2(self, device): + values = torch.randn(16, 8, 128, device=device) + offsets = torch.tensor([0, 2, 4, 8], device=device) + lengths = torch.tensor([2, 1, 3], device=device) + ragged_idx = 2 + nt = torch.nested._internal.nested_tensor.NestedTensor( + values, + offsets=offsets, + lengths=lengths, + _ragged_idx=ragged_idx) # 4D nested tensor + + tensor_list = [] + for i in range(offsets.shape[0] - 1): + tensor_list.append(values[:, offsets[i] : (offsets[i] + lengths[i]), :]) + + out = nt.unbind() + + self.assertEqual(len(out), len(tensor_list)) + for i, t in enumerate(out): + self.assertEqual(t, tensor_list[i]) + + def test_unbind_lengths_ragged_idx_3(self, device): + values = torch.randn(16, 8, 128, device=device) + offsets = torch.tensor([0, 100, 128], device=device) + lengths = torch.tensor([50, 28], device=device) + ragged_idx = 3 + nt = torch.nested._internal.nested_tensor.NestedTensor( + values, + offsets=offsets, + lengths=lengths, + _ragged_idx=ragged_idx) # 4D nested tensor + + tensor_list = [] + for i in range(offsets.shape[0] - 1): + tensor_list.append(values[:, :, offsets[i] : (offsets[i] + lengths[i])]) + + out = nt.unbind() + + self.assertEqual(len(out), len(tensor_list)) + for i, t in enumerate(out): + self.assertEqual(t, tensor_list[i]) + + @skipIfTorchDynamo("TorchDynamo raises an error for ragged_idx == 0 earlier than Torch") + def test_unbind_lengths_ragged_idx_0(self, device): + values = torch.randn(16, 8, 128, device=device) + offsets = torch.tensor([0, 100, 128], device=device) + lengths = torch.tensor([50, 28], device=device) + ragged_idx = 0 + nt = torch.nested._internal.nested_tensor.NestedTensor( + values, + offsets=offsets, + lengths=lengths, + _ragged_idx=ragged_idx) # 4D nested tensor + + tensor_list = [] + for i in range(offsets.shape[0] - 1): + tensor_list.append(values[:, :, offsets[i] : (offsets[i] + lengths[i])]) + + self.assertRaisesRegex( + RuntimeError, + r"unbind\(\): nested tensor.*out of bounds", + lambda: nt.unbind() + ) + @xfailIfTorchDynamo def test_layer_norm_2(self, device): test_tensor_list = self._get_list_for_jagged_tensor( diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py index d448628b7cad..cfbb50b395fa 100644 --- a/torch/nested/_internal/ops.py +++ b/torch/nested/_internal/ops.py @@ -616,16 +616,18 @@ def unbind_int(func, *args, **kwargs): values = inp.values() offsets = inp.offsets() lengths = inp.lengths() + ragged_idx = inp._ragged_idx - if inp._ragged_idx != 1: + if lengths is None: + return torch.split(values, offsets.diff().tolist(), dim=(ragged_idx - 1)) + + if ragged_idx <= 0: raise RuntimeError( - "unbind(): only supported for NestedTensor when jagged dimension is 1" + "unbind(): nested tensor ragged_idx out of bounds (should be >= 1)" ) - - if lengths is None: - return torch.split(values, offsets.diff().tolist()) return [ - values[offsets[i] : (offsets[i] + lengths[i])] for i in range(lengths.shape[0]) + torch.narrow(values, dim=(ragged_idx - 1), start=offsets[i], length=lengths[i]) + for i in range(lengths.shape[0]) ] From d8d0bf264a736c7fb3cd17799a1c1aba4addf8d9 Mon Sep 17 00:00:00 2001 From: Alnis Murtovi Date: Mon, 3 Jun 2024 17:53:48 +0000 Subject: [PATCH 267/706] Inductor: Allow small sizes of m for mixed mm autotuning (#127663) For mixed mm with small sizes of m, such as in the example provided in #127056, being able to set BLOCK_M to 16 leads to better performance. This PR introduces kernel configs that are specific to mixed mm by extending the mm configs with two configs that work well for the example provided in #127056. I am excluding configs with (BLOCK_M=16, BLOCK_K=16, BLOCK_N=64) because triton crashes when this config is used. For the example in #127056: - Without my changes, skip_triton is evaluated to true which disables autotuning. On my machine I achieve 146GB/s. - If autotuning is enabled, but BLOCK_M>=32, I achieve 614 GB/s. - With the changes in this PR (i.e. autotuning enabled and BLOCK_M=16), I achieve 772 GB/s. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127663 Approved by: https://github.com/Chillee --- torch/_inductor/kernel/mm.py | 15 +++++++++++--- torch/_inductor/kernel/mm_common.py | 32 ++++++++++++++++++++++++----- 2 files changed, 39 insertions(+), 8 deletions(-) diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py index a90fdbfa33d9..eba1c65702e8 100644 --- a/torch/_inductor/kernel/mm.py +++ b/torch/_inductor/kernel/mm.py @@ -26,6 +26,7 @@ from .mm_common import ( addmm_epilogue, int8_mm_configs, + mixed_mm_configs, mm_args, mm_configs, mm_grid, @@ -407,15 +408,23 @@ def tuned_mixed_mm(mat1, mat2, mat2_dtype): # can't use triton kernel unless one of these is true or if running on v100 (numerical issues) skip_triton = ( - mat1.layout.dtype != torch.float32 and not mat2.layout.is_contiguous() + mat1.layout.dtype != torch.float32 + and not (mat2.layout.is_contiguous() or mat2.layout.is_transposed()) ) or _is_sm7x_or_older_gpu(layout.device.index) if inductor_config.force_mixed_mm: choices = [] if not skip_triton: b_prologue_cast_type = f"tl.{mat2_dtype}".replace("torch.", "") - has_int8_tensor = _is_int8_mat(mat1) or _is_int8_mat(mat2) - for config in mm_configs(m, n, k, has_int8_tensor=has_int8_tensor): + for config in mixed_mm_configs(m, n, k): + # skipping this config because triton crashes on it + # See: https://github.com/triton-lang/triton/issues/2156#issuecomment-1695897424 + if ( + config.kwargs["BLOCK_M"] == 16 + and config.kwargs["BLOCK_K"] == 16 + and config.kwargs["BLOCK_N"] == 64 + ): + continue mm_template.maybe_append_choice( choices, input_nodes=(mat1, mat2), diff --git a/torch/_inductor/kernel/mm_common.py b/torch/_inductor/kernel/mm_common.py index 97741cc0f8eb..7fa403fe78f1 100644 --- a/torch/_inductor/kernel/mm_common.py +++ b/torch/_inductor/kernel/mm_common.py @@ -27,14 +27,10 @@ def filtered_configs( n: int, k: int, configs: List[Tuple[int, int, int, int, int]], - has_int8_tensor=False, ): """Heuristic to shrink configs when they are bigger than the input size""" - # According to https://github.com/openai/triton/issues/2156#issuecomment-1695897424 - # it's safer to use at least [32, 32] block size for int8/uint8 - # tensors - min_block_size = 32 if has_int8_tensor else 16 + min_block_size = 16 m = max( next_power_of_2( V.graph.sizevars.size_hint( @@ -166,6 +162,18 @@ def filtered_configs( {"config": (256, 128, 128, 3, 8), "cond": torch.version.hip is None}, ] +# Mixed precision kernel configs for small sizes of m for mm's like (16, 8192) x (8192, 8192). +mixed_mm_kernel_configs_small_m = [ + {"config": (16, 128, 256, 3, 4), "cond": True}, + {"config": (16, 128, 256, 5, 8), "cond": True}, +] + +mixed_mm_kernel_configs = ( + mm_kernel_configs + mixed_mm_kernel_configs_small_m + if inductor_config.max_autotune_gemm_search_space != "EXHAUSTIVE" + else mm_kernel_configs +) + # Create filtered list of configs based on cond evaluation @@ -179,6 +187,11 @@ def filtered_configs( for config in int8_mm_kernel_configs if config["cond"] ) +mixed_mm_platform_configs = tuple( + cast(Tuple[int, int, int, int, int], config["config"]) + for config in mixed_mm_kernel_configs + if config["cond"] +) # On ROCm convert num_stages to 0 to enable software pipelining if torch.version.hip: @@ -190,6 +203,10 @@ def filtered_configs( (config[0], config[1], config[2], 0, config[4]) for config in mm_platform_configs ) + mixed_mm_platform_configs = tuple( + (config[0], config[1], config[2], 0, config[4]) + for config in mixed_mm_platform_configs + ) mm_configs = functools.partial( filtered_configs, @@ -201,6 +218,11 @@ def filtered_configs( configs=int8_platform_configs, ) +mixed_mm_configs = functools.partial( + filtered_configs, + configs=mixed_mm_platform_configs, +) + def mm_grid(m, n, meta): """ From 3437177e2bce6929f22cb02edac563335bb1e31f Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Mon, 3 Jun 2024 18:06:28 +0000 Subject: [PATCH 268/706] Quick Fix on #126854, deepcopy `lr` and other possible `base_parameters` (#127190) * Apply `deepcopy` to every base parameters (`initial_lr`, `max_lr`) when instantiating `LRScheduler`. Fixes #126854 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127190 Approved by: https://github.com/janeyx99 --- test/optim/test_lrscheduler.py | 114 +++++++++++++++++++++++++++++++++ torch/optim/lr_scheduler.py | 65 +++++++++---------- torch/optim/swa_utils.py | 20 +----- 3 files changed, 146 insertions(+), 53 deletions(-) diff --git a/test/optim/test_lrscheduler.py b/test/optim/test_lrscheduler.py index bca8b3b6b69c..11c6a24bd161 100644 --- a/test/optim/test_lrscheduler.py +++ b/test/optim/test_lrscheduler.py @@ -1,4 +1,5 @@ # Owner(s): ["module: optimizer", "module: LrScheduler" ] +import copy import math import pickle import tempfile @@ -2403,6 +2404,119 @@ def test_lr_scheduler_state_dict_load(self, LRClass, weights_only): scheduler2.load_state_dict(state_dict_loaded) self.assertEqual(scheduler2.state_dict(), state_dict) + @parametrize( + "LRClass", + [ + partial(LambdaLR, lr_lambda=lambda e: e // 10), + partial(MultiplicativeLR, lr_lambda=lambda e: 0.95), + partial(StepLR, step_size=30), + partial(MultiStepLR, milestones=[30, 80]), + ConstantLR, + LinearLR, + partial(ExponentialLR, gamma=0.9), + PolynomialLR, + partial(CosineAnnealingLR, T_max=10), + partial(CosineAnnealingWarmRestarts, T_0=20), + ], + ) + def test_constant_initial_lr(self, LRClass): + # Test that the initial learning rate is constant + lr = torch.as_tensor(0.1) + opt = SGD([torch.nn.Parameter(torch.randn(1))], lr=lr) + sch = LRClass(opt) + + ori_param_groups = copy.deepcopy(opt.param_groups) + + for i in range(2): + opt.step() + sch.step(i) + lr.multiply_(0.1) + for group, ori_group in zip(opt.param_groups, ori_param_groups): + self.assertEqual(group["initial_lr"], ori_group["initial_lr"]) + self.assertEqual(sch.base_lrs, [0.1]) + + def test_constant_initial_params_cyclelr(self): + # Test that the initial learning rate is constant + lr = torch.as_tensor(0.1) + max_lr = torch.as_tensor(0.2) + base_momentum = torch.as_tensor(0.8) + max_momentum = torch.as_tensor(0.9) + opt = SGD([torch.nn.Parameter(torch.randn(1))], lr=lr) + sch = CyclicLR( + opt, + base_lr=lr, + max_lr=max_lr, + base_momentum=base_momentum, + max_momentum=max_momentum, + ) + ori_param_groups = copy.deepcopy(opt.param_groups) + + for i in range(2): + lr.multiply_(0.5) + max_lr.multiply_(0.5) + base_momentum.multiply_(0.5) + max_momentum.multiply_(0.5) + opt.step() + sch.step(i) + for group, ori_group in zip(opt.param_groups, ori_param_groups): + self.assertEqual(group["initial_lr"], ori_group["initial_lr"]) + self.assertEqual(group["max_momentum"], ori_group["max_momentum"]) + self.assertEqual(group["base_momentum"], ori_group["base_momentum"]) + self.assertEqual(sch.base_lrs, [0.1]) + self.assertEqual(sch.max_lrs, [0.2]) + self.assertEqual(group["max_momentum"], 0.9) + self.assertEqual(group["base_momentum"], 0.8) + + def test_constant_initial_params_onecyclelr(self): + # Test that the initial learning rate is constant + lr = torch.as_tensor(0.1) + base_momentum = torch.as_tensor(0.85) + max_momentum = torch.as_tensor(0.95) + opt = SGD([torch.nn.Parameter(torch.randn(1))], lr=lr) + sch = OneCycleLR( + opt, + max_lr=lr, + total_steps=10, + base_momentum=base_momentum, + max_momentum=max_momentum, + ) + ori_param_groups = copy.deepcopy(opt.param_groups) + + for i in range(2): + lr.multiply_(0.5) + base_momentum.multiply_(0.5) + max_momentum.multiply_(0.5) + opt.step() + sch.step(i) + + for group, ori_group in zip(opt.param_groups, ori_param_groups): + self.assertEqual(group["initial_lr"], ori_group["initial_lr"]) + self.assertEqual(group["max_lr"], ori_group["max_lr"]) + self.assertEqual(group["min_lr"], ori_group["min_lr"]) + self.assertEqual(group["max_momentum"], ori_group["max_momentum"]) + self.assertEqual(group["base_momentum"], ori_group["base_momentum"]) + self.assertEqual(group["max_momentum"], 0.95) + self.assertEqual(group["base_momentum"], 0.85) + + def test_constant_initial_params_swalr(self): + # Test that the initial learning rate is constant + lr = torch.as_tensor(0.1) + swa_lr = torch.as_tensor(0.05) + opt = SGD([torch.nn.Parameter(torch.randn(1))], lr=lr) + sch = SWALR(opt, swa_lr=swa_lr) + ori_param_groups = copy.deepcopy(opt.param_groups) + + for i in range(2): + lr.multiply_(0.5) + swa_lr.multiply_(0.5) + opt.step() + sch.step() + for group, ori_group in zip(opt.param_groups, ori_param_groups): + self.assertEqual(group["initial_lr"], ori_group["initial_lr"]) + self.assertEqual(group["swa_lr"], ori_group["swa_lr"]) + self.assertEqual(group["swa_lr"], 0.05) + self.assertEqual(sch.base_lrs, [0.1]) + instantiate_parametrized_tests(TestLRScheduler) diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py index 77bdb6b46aac..42c55db82a43 100644 --- a/torch/optim/lr_scheduler.py +++ b/torch/optim/lr_scheduler.py @@ -65,6 +65,24 @@ def _check_verbose_deprecated_warning(verbose): return False +def _format_param(name: str, optimizer: Optimizer, param): + """Return correctly formatted lr/momentum for each param group.""" + + def _copy(_param): + return _param.clone() if isinstance(_param, Tensor) else _param + + if isinstance(param, (list, tuple)): + if len(param) != len(optimizer.param_groups): + raise ValueError( + f"{name} must have the same length as optimizer.param_groups. " + f"{name} has {len(param)} values, param_groups has {len(optimizer.param_groups)}." + ) + else: + param = [param] * len(optimizer.param_groups) + + return list(map(_copy, param)) + + class LRScheduler: _get_lr_called_within_step: bool = False @@ -77,7 +95,10 @@ def __init__(self, optimizer: Optimizer, last_epoch=-1, verbose="deprecated"): # Initialize epoch and base learning rates if last_epoch == -1: for group in optimizer.param_groups: - group.setdefault("initial_lr", group["lr"]) + initial_lr = group["lr"] + if isinstance(initial_lr, Tensor): + initial_lr = initial_lr.clone() + group.setdefault("initial_lr", initial_lr) else: for i, group in enumerate(optimizer.param_groups): if "initial_lr" not in group: @@ -1491,16 +1512,16 @@ def __init__( raise TypeError(f"{type(optimizer).__name__} is not an Optimizer") self.optimizer = optimizer - base_lrs = self._format_param("base_lr", optimizer, base_lr) + base_lrs = _format_param("base_lr", optimizer, base_lr) if last_epoch == -1: for lr, group in zip(base_lrs, optimizer.param_groups): if isinstance(group["lr"], Tensor): lr_val = lr.item() if isinstance(lr, Tensor) else lr - group["lr"].fill_(lr) + group["lr"].fill_(lr_val) else: group["lr"] = lr - self.max_lrs = self._format_param("max_lr", optimizer, max_lr) + self.max_lrs = _format_param("max_lr", optimizer, max_lr) step_size_up = float(step_size_up) step_size_down = ( @@ -1531,12 +1552,10 @@ def __init__( ) self.use_beta1 = "betas" in self.optimizer.defaults - self.base_momentums = self._format_param( + self.base_momentums = _format_param( "base_momentum", optimizer, base_momentum ) - self.max_momentums = self._format_param( - "max_momentum", optimizer, max_momentum - ) + self.max_momentums = _format_param("max_momentum", optimizer, max_momentum) if last_epoch == -1: for m_momentum, b_momentum, group in zip( self.max_momentums, self.base_momentums, optimizer.param_groups @@ -1564,17 +1583,6 @@ def _init_scale_fn(self): self._scale_fn_ref = partial(self._exp_range_scale_fn, self.gamma) self.scale_mode = "iterations" - def _format_param(self, name, optimizer, param): - """Return correctly formatted lr/momentum for each param group.""" - if isinstance(param, (list, tuple)): - if len(param) != len(optimizer.param_groups): - raise ValueError( - f"expected {len(optimizer.param_groups)} values for {name}, got {len(param)}" - ) - return param - else: - return [param] * len(optimizer.param_groups) - def scale_fn(self, x) -> float: if self._scale_fn_custom is not None: return self._scale_fn_custom(x) @@ -2012,7 +2020,7 @@ def __init__( self._anneal_func_type = anneal_strategy # Initialize learning rate variables - max_lrs = self._format_param("max_lr", self.optimizer, max_lr) + max_lrs = _format_param("max_lr", self.optimizer, max_lr) if last_epoch == -1: for idx, group in enumerate(self.optimizer.param_groups): group["initial_lr"] = max_lrs[idx] / div_factor @@ -2030,10 +2038,8 @@ def __init__( "optimizer must support momentum or beta1 with `cycle_momentum` option enabled" ) self.use_beta1 = "betas" in self.optimizer.defaults - max_momentums = self._format_param("max_momentum", optimizer, max_momentum) - base_momentums = self._format_param( - "base_momentum", optimizer, base_momentum - ) + max_momentums = _format_param("max_momentum", optimizer, max_momentum) + base_momentums = _format_param("base_momentum", optimizer, base_momentum) if last_epoch == -1: for m_momentum, b_momentum, group in zip( max_momentums, base_momentums, optimizer.param_groups @@ -2047,17 +2053,6 @@ def __init__( super().__init__(optimizer, last_epoch, verbose) - def _format_param(self, name, optimizer, param): - """Return correctly formatted lr/momentum for each param group.""" - if isinstance(param, (list, tuple)): - if len(param) != len(optimizer.param_groups): - raise ValueError( - f"expected {len(optimizer.param_groups)} values for {name}, got {len(param)}" - ) - return param - else: - return [param] * len(optimizer.param_groups) - def _anneal_func(self, *args, **kwargs): if hasattr(self, "_anneal_func_type"): if self._anneal_func_type == "cos": diff --git a/torch/optim/swa_utils.py b/torch/optim/swa_utils.py index 7c2c9cdaf6f9..4cfca073af77 100644 --- a/torch/optim/swa_utils.py +++ b/torch/optim/swa_utils.py @@ -7,7 +7,7 @@ import torch from torch import Tensor from torch.nn import Module -from torch.optim.lr_scheduler import LRScheduler +from torch.optim.lr_scheduler import _format_param, LRScheduler from torch.utils._foreach_utils import _get_foreach_kernels_supported_devices from .optimizer import Optimizer @@ -390,7 +390,7 @@ def __init__( anneal_strategy: Literal["cos", "linear"] = "cos", last_epoch=-1, ): - swa_lrs = self._format_param(optimizer, swa_lr) + swa_lrs = _format_param("swa_lr", optimizer, swa_lr) for swa_lr, group in zip(swa_lrs, optimizer.param_groups): group["swa_lr"] = swa_lr if anneal_strategy not in ["cos", "linear"]: @@ -409,22 +409,6 @@ def __init__( self.anneal_epochs = anneal_epochs super().__init__(optimizer, last_epoch) - @staticmethod - def _format_param( - optimizer: Optimizer, - swa_lrs: Union[float, List[float], Tuple[float, ...]], - ) -> Union[List[float], Tuple[float, ...]]: - if isinstance(swa_lrs, (list, tuple)): - if len(swa_lrs) != len(optimizer.param_groups): - raise ValueError( - "swa_lr must have the same length as " - f"optimizer.param_groups: swa_lr has {len(swa_lrs)}, " - f"optimizer.param_groups has {len(optimizer.param_groups)}" - ) - return swa_lrs - else: - return [swa_lrs] * len(optimizer.param_groups) - @staticmethod def _linear_anneal(t): return t From c35b65715cac6b2ab9af8a3d1c4b223eca1d6f93 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 3 Jun 2024 18:07:54 +0000 Subject: [PATCH 269/706] Revert "[Inductor][Flex-attention] Support different sequence lengths for Query and Key/Value (#127678)" This reverts commit e2e3ca94ccce1c0abbfd75ac0368793e1756c268. Reverted https://github.com/pytorch/pytorch/pull/127678 on behalf of https://github.com/atalman due to Ineternal breakage of https://github.com/pytorch/pytorch/pull/127208 hence reverting ([comment](https://github.com/pytorch/pytorch/pull/127678#issuecomment-2145821489)) --- test/inductor/test_flex_attention.py | 53 ++++---------- torch/_inductor/kernel/flex_attention.py | 93 ++++++++++-------------- torch/_inductor/select_algorithm.py | 9 +-- torch/nn/attention/_flex_attention.py | 5 ++ 4 files changed, 62 insertions(+), 98 deletions(-) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index 59fbd31c09cf..16206b51bb13 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -196,33 +196,23 @@ def run_test( self, score_mod: Callable, dtype: torch.dtype = torch.float16, - Q_B: int = B, - Q_H: int = H, - Q_S: int = S, - Q_D: int = D, - KV_B: int = B, - KV_H: int = H, - KV_S: int = S, - KV_D: int = D, + B: int = B, + H: int = H, + S: int = S, + D: int = D, ): sdpa_partial = create_attention(score_mod) compiled_sdpa = torch.compile(sdpa_partial) - q = torch.randn( - (Q_B, Q_H, Q_S, Q_D), dtype=dtype, device="cuda", requires_grad=True - ) - k = torch.randn( - (KV_B, KV_H, KV_S, KV_D), dtype=dtype, device="cuda", requires_grad=True - ) - v = torch.randn( - (KV_B, KV_H, KV_S, KV_D), dtype=dtype, device="cuda", requires_grad=True - ) + q = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) + k = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) + v = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) q_ref, k_ref, v_ref = query_key_value_clones(q, k, v) q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64) golden_out = sdpa_partial(q_gold, k_gold, v_gold) ref_out = sdpa_partial(q_ref, k_ref, v_ref) compiled_out = compiled_sdpa(q, k, v) - backward_grad = torch.randn((Q_B, Q_H, Q_S, Q_D), dtype=dtype, device="cuda") + backward_grad = torch.randn((B, H, S, D), dtype=dtype, device="cuda") golden_out.backward(backward_grad.to(torch.float64)) ref_out.backward(backward_grad) @@ -417,25 +407,6 @@ def test_builtin_score_mods_automatic_dynamic( ): self.run_automatic_dynamic_test(score_mod, dtype) - @supported_platform - @common_utils.parametrize("dtype", test_dtypes_fast) - @common_utils.parametrize("score_mod", test_score_mods) - def test_builtin_score_mods_different_seqlen( - self, dtype: torch.dtype, score_mod: Callable - ): - self.run_test( - score_mod, - dtype, - B, - H, - S // 2, # Seqlen of Q is different from seqlen of K/V - D, - B, - H, - S, - D, - ) - @supported_platform @common_utils.parametrize("dtype", test_dtypes) def test_skip_odd_keys(self, dtype: torch.dtype): @@ -812,6 +783,14 @@ def test_mixed_dtypes_fails(self): ): _flex_attention(query, key, value, _identity) + @supported_platform + def test_different_sequence_length_fails(self): + query = torch.randn((1, 1, 2048, 64), dtype=torch.float32, device="cuda") + key = torch.randn((1, 1, 1024, 64), dtype=torch.float32, device="cuda") + value = torch.randn((1, 1, 1024, 64), dtype=torch.float32, device="cuda") + with self.assertRaisesRegex(ValueError, "NYI: The target sequence length"): + _flex_attention(query, key, value, _identity) + @supported_platform @patch.object(torch._inductor.config, "max_autotune", True) def test_max_autotune(self): diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index 42fabf65591d..3e95dd4f65ce 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -189,7 +189,6 @@ def build_subgraph_buffer( Z = {{size("Q", 0)}} H = {{size("Q", 1)}} Q_LEN = {{size("Q", 2)}} - KV_LEN = {{size("K", 2)}} qk_scale = 1.0 MATMUL_PRECISION = Q.dtype.element_ty @@ -197,10 +196,9 @@ def build_subgraph_buffer( start_m = tl.program_id(0) off_hz = tl.program_id(1) - q_offset = off_hz * stride_qh - kv_offset = off_hz * stride_kh + qkv_offset = off_hz * stride_qh Q_block_ptr = tl.make_block_ptr( - base=Q + q_offset, + base=Q + qkv_offset, shape=(Q_LEN, BLOCK_DMODEL), strides=(stride_qm, stride_qk), offsets=(start_m * BLOCK_M, 0), @@ -208,16 +206,16 @@ def build_subgraph_buffer( order=(1, 0) ) K_block_ptr = tl.make_block_ptr( - base=K + kv_offset, - shape=(BLOCK_DMODEL, KV_LEN), + base=K + qkv_offset, + shape=(BLOCK_DMODEL, Q_LEN), strides=(stride_kk, stride_kn), offsets=(0, 0), block_shape=(BLOCK_DMODEL, BLOCK_N), order=(0, 1) ) V_block_ptr = tl.make_block_ptr( - base=V + kv_offset, - shape=(KV_LEN, BLOCK_DMODEL), + base=V + qkv_offset, + shape=(Q_LEN, BLOCK_DMODEL), strides=(stride_vk, stride_vn), offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_DMODEL), @@ -237,7 +235,7 @@ def build_subgraph_buffer( q = (q * qk_scale).to(MATMUL_PRECISION) # loop over k, v and update accumulator lo = 0 - hi = KV_LEN + hi = Q_LEN for start_n in range(lo, hi, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- load k, v -- @@ -427,7 +425,6 @@ def flex_attention(*args, **kwargs): ], num_stages=num_stages, num_warps=num_warps, - call_sizes=query.get_size(), BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=query.get_size()[-1], @@ -448,9 +445,7 @@ def flex_attention(*args, **kwargs): # ---------------------------- Backward HOP Implementation ---------------------------- -def flex_attention_backward_grid( - batch_size, num_heads, num_queries, d_model, num_key_value, meta -): +def flex_attention_backward_grid(batch_size, num_heads, num_queries, d_model, meta): """How is this kernel parallelized? Currently this is only parallelizing over batch * num_heads, but we can, and want to parallelize over ceil_div(num_key_value, key_value_block_size). To do this will either require @@ -458,6 +453,8 @@ def flex_attention_backward_grid( """ import triton + # TODO: support different seqlen for Query and Key/Value. + num_key_value = num_queries return ( triton.cdiv(num_queries, meta["BLOCK_M2"]) + triton.cdiv(num_key_value, meta["BLOCK_N1"]), @@ -479,7 +476,7 @@ def flex_attention_backward_grid( # DK: Derivative of Key, is the written to via the store_output call due to some limitations with # inductor codegen # M: Number of queries, N: Number of keys/values, D: Model dimension - # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head # (Modifiable) Config options: # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. @@ -489,20 +486,10 @@ def flex_attention_backward_grid( # change of base out of the loop # Define Q Strides - stride_qz = {{stride("Q", 0)}} - stride_qh = {{stride("Q", 1)}} - stride_qm = {{stride("Q", 2)}} - stride_qd = {{stride("Q", 3)}} - # Define K Strides - stride_kz = {{stride("K", 0)}} - stride_kh = {{stride("K", 1)}} - stride_km = {{stride("K", 2)}} - stride_kd = {{stride("K", 3)}} - # Define V Strides - stride_vz = {{stride("V", 0)}} - stride_vh = {{stride("V", 1)}} - stride_vm = {{stride("V", 2)}} - stride_vd = {{stride("V", 3)}} + stride_z = {{stride("Q", 0)}} + stride_h = {{stride("Q", 1)}} + stride_tok = {{stride("Q", 2)}} + stride_d = {{stride("Q", 3)}} Z = {{size("Q", 0)}} H = {{size("Q", 1)}} @@ -514,22 +501,21 @@ def flex_attention_backward_grid( pid = tl.program_id(0) NUM_KV_BLOCKS = KV_LEN // BLOCK_N1 + bhid = tl.program_id(2) + off_chz = (bhid * Q_LEN).to(tl.int64) + adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64) + off_hz = tl.program_id(2) off_z = off_hz // H # batch idx off_h = off_hz % H # head idx - off_chz = (off_hz * Q_LEN).to(tl.int64) - q_adj = (stride_qh * (off_hz % H) + stride_qz * (off_hz // H)).to(tl.int64) - k_adj = (stride_kh * (off_hz % H) + stride_kz * (off_hz // H)).to(tl.int64) - v_adj = (stride_vh * (off_hz % H) + stride_vz * (off_hz // H)).to(tl.int64) - # offset pointers for batch/head - Q += q_adj - K += k_adj - V += v_adj - DO += q_adj - DQ += q_adj - DV += v_adj + Q += adj + K += adj + V += adj + DO += adj + DQ += adj + DV += adj LSE += off_chz DELTA += off_chz @@ -542,9 +528,9 @@ def flex_attention_backward_grid( offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) - q = tl.load(Q + offs_m2[:, None] * stride_qm + offs_k[None, :] * stride_qd) + q = tl.load(Q + offs_m2[:, None] * stride_tok + offs_k[None, :] * stride_d) dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32) - do = tl.load(DO + offs_m2[:, None] * stride_qm + offs_k[None, :] * stride_qd) + do = tl.load(DO + offs_m2[:, None] * stride_tok + offs_k[None, :] * stride_d) lse = tl.load(LSE + offs_m2) lse = lse[:, None] @@ -552,8 +538,8 @@ def flex_attention_backward_grid( start_n2 = 0 offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) offs_n2 = start_n2 + tl.arange(0, BLOCK_N2) - kT_ptrs = K + offs_n2[None, :] * stride_km + offs_k[:, None] * stride_kd - vT_ptrs = V + offs_n2[None, :] * stride_vm + offs_k[:, None] * stride_vd + kT_ptrs = K + offs_n2[None, :] * stride_tok + offs_k[:, None] * stride_d + vT_ptrs = V + offs_n2[None, :] * stride_tok + offs_k[:, None] * stride_d Di = tl.load(DELTA + offs_m2) # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) @@ -604,10 +590,10 @@ def flex_attention_backward_grid( dq += tl.dot(ds, tl.trans(kT)) # Increment pointers. curr_n += BLOCK_N2 - kT_ptrs += BLOCK_N2 * stride_km - vT_ptrs += BLOCK_N2 * stride_km + kT_ptrs += BLOCK_N2 * stride_tok + vT_ptrs += BLOCK_N2 * stride_tok # Write back dQ. - dq_ptrs = DQ + offs_m2[:, None] * stride_qm + offs_k[None, :] * stride_qd + dq_ptrs = DQ + offs_m2[:, None] * stride_tok + offs_k[None, :] * stride_d tl.store(dq_ptrs, dq) else: # THIS BLOCK DOES DK & DV @@ -620,13 +606,13 @@ def flex_attention_backward_grid( dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) # load K and V: they stay in SRAM throughout the inner loop. - k = tl.load(K + offs_n1[:, None] * stride_km + offs_k[None, :] * stride_kd) - v = tl.load(V + offs_n1[:, None] * stride_vm + offs_k[None, :] * stride_vd) + k = tl.load(K + offs_n1[:, None] * stride_tok + offs_k[None, :] * stride_d) + v = tl.load(V + offs_n1[:, None] * stride_tok + offs_k[None, :] * stride_d) offs_m1 = start_m1 + tl.arange(0, BLOCK_M1) offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) - qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd - do_ptrs = DO + offs_m1[:, None] * stride_qm + offs_k[None, :] * stride_qd + qT_ptrs = Q + offs_m1[None, :] * stride_tok + offs_k[:, None] * stride_d + do_ptrs = DO + offs_m1[:, None] * stride_tok + offs_k[None, :] * stride_d # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) @@ -682,10 +668,10 @@ def flex_attention_backward_grid( dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT)) # Increment pointers. curr_m += BLOCK_M1 - qT_ptrs += BLOCK_M1 * stride_qm - do_ptrs += BLOCK_M1 * stride_qm + qT_ptrs += BLOCK_M1 * stride_tok + do_ptrs += BLOCK_M1 * stride_tok - dv_ptrs = DV + offs_n1[:, None] * stride_vm + offs_k[None, :] * stride_vd + dv_ptrs = DV + offs_n1[:, None] * stride_tok + offs_k[None, :] * stride_d tl.store(dv_ptrs, dv) # Write back dK. @@ -787,7 +773,6 @@ def flex_attention_backward(*args, **kwargs): layout=layout_k, # We use store_output only for grad_key subgraphs=[fw_subgraph_buffer, joint_subgraph_buffer], mutated_inputs=[grad_query, grad_value], - call_sizes=query.get_size() + [key.get_size()[2]], num_stages=num_stages, num_warps=num_warps, BLOCK_M1=BLOCK_M, diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index bc89441e3bd8..531f3c25a31b 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -580,7 +580,6 @@ def generate( epilogue_fn=identity, subgraphs=None, mutated_inputs=None, - call_sizes=None, **kwargs, ): """This function generates a TritonTemplateCaller @@ -615,9 +614,6 @@ def generate( "64-bit indexing is not yet implemented for triton templates" ) - if call_sizes is None: - call_sizes = layout.size - kernel_options = dict( input_nodes=input_nodes, defines=defines, @@ -625,14 +621,13 @@ def generate( num_warps=num_warps, grid_fn=self.grid, meta=kwargs, - call_sizes=call_sizes, + call_sizes=layout.size, prefix_args=prefix_args, suffix_args=suffix_args, epilogue_fn=epilogue_fn, index_dtype="tl.int32", subgraphs=subgraphs, ) - with patch.object( V.graph, "get_dtype", self._fake_get_dtype(fake_out) ), TritonTemplateKernel( @@ -706,7 +701,7 @@ def make_kernel_render(out_node): assert mod.__file__ is not None grid = self.grid( *V.graph.sizevars.size_hints( - call_sizes, + layout.size, fallback=config.unbacked_symint_fallback, ), kwargs, diff --git a/torch/nn/attention/_flex_attention.py b/torch/nn/attention/_flex_attention.py index 430d3280442a..bd999ec39118 100644 --- a/torch/nn/attention/_flex_attention.py +++ b/torch/nn/attention/_flex_attention.py @@ -101,6 +101,11 @@ def score_mod( # Some basic input validation _validate_sdpa_input(query, key, value) + # This will restriction will be removed in newer version of the kernel + if query.size(-2) != key.size(-2): + raise ValueError( + "NYI: The target sequence length (L) of the query tensor must match the source sequence length (S) of the key tensor." + ) if query.size(-2) % 128 != 0: raise ValueError("NYI: S and L must be a multiple of 128") From 3f45fa63f289d465de2694a7ec78ce90cd771b37 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 3 Jun 2024 18:10:45 +0000 Subject: [PATCH 270/706] Revert "[Inductor] Add FlexAttention backward kernel dynamic shape tests (#127728)" This reverts commit 10e3406ea5d115a54a7d753d33110762eb6c07ff. Reverted https://github.com/pytorch/pytorch/pull/127728 on behalf of https://github.com/yanboliang due to Ineternal breakage of https://github.com/pytorch/pytorch/pull/127208 hence reverting ([comment](https://github.com/pytorch/pytorch/pull/127728#issuecomment-2145822667)) --- test/inductor/test_flex_attention.py | 158 ++++++++------------------- 1 file changed, 48 insertions(+), 110 deletions(-) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index 16206b51bb13..d4feead90301 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -151,47 +151,6 @@ def _check_equal( msg = f"{name} Compiled error {compiled_error} is greater than ref error {ref_error} by more than {fudge_factor}X." self.assertTrue(False, msg) - def _check_out_and_grad( - self, - golden_out: torch.Tensor, - ref_out: torch.Tensor, - compiled_out: torch.Tensor, - q_gold: torch.Tensor, - q_ref: torch.Tensor, - q: torch.Tensor, - k_gold: torch.Tensor, - k_ref: torch.Tensor, - k: torch.Tensor, - v_gold: torch.Tensor, - v_ref: torch.Tensor, - v: torch.Tensor, - ): - dtype = ref_out.dtype - with torch.no_grad(): - # Note, it seems like we really are less accurate than the float32 - # computation, likely due to the online softmax - if dtype == torch.float32: - fudge_factor = 10.0 - else: - fudge_factor = 1.1 - - # Checkout output - self._check_equal(golden_out, ref_out, compiled_out, fudge_factor, "Out") - - # Check gradients - q_fudge_factor = 2.5 * fudge_factor - self._check_equal( - q_gold.grad, q_ref.grad, q.grad, q_fudge_factor, "Grad_Query" - ) - k_fudge_factor = 4 * fudge_factor - self._check_equal( - k_gold.grad, k_ref.grad, k.grad, k_fudge_factor, "Grad_Key" - ) - v_fudge_factor = 4 * fudge_factor - self._check_equal( - v_gold.grad, v_ref.grad, v.grad, v_fudge_factor, "Grad_Value" - ) - def run_test( self, score_mod: Callable, @@ -218,20 +177,30 @@ def run_test( ref_out.backward(backward_grad) compiled_out.backward(backward_grad) - self._check_out_and_grad( - golden_out, - ref_out, - compiled_out, - q_gold, - q_ref, - q, - k_gold, - k_ref, - k, - v_gold, - v_ref, - v, - ) + with torch.no_grad(): + # Note, it seems like we really are less accurate than the float32 + # computation, likely due to the online softmax + if dtype == torch.float32: + fudge_factor = 10.0 + else: + fudge_factor = 1.1 + + # Checkout output + self._check_equal(golden_out, ref_out, compiled_out, fudge_factor, "Out") + + # Check gradients + q_fudge_factor = 2.5 * fudge_factor + self._check_equal( + q_gold.grad, q_ref.grad, q.grad, q_fudge_factor, "Grad_Query" + ) + k_fudge_factor = 4 * fudge_factor + self._check_equal( + k_gold.grad, k_ref.grad, k.grad, k_fudge_factor, "Grad_Key" + ) + v_fudge_factor = 4 * fudge_factor + self._check_equal( + v_gold.grad, v_ref.grad, v.grad, v_fudge_factor, "Grad_Value" + ) def run_dynamic_test( self, @@ -244,34 +213,24 @@ def run_dynamic_test( ): sdpa_partial = create_attention(score_mod) # The first eager batch, shape (B, H, S, D) - q1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) - k1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) - v1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) - q1_ref, k1_ref, v1_ref = query_key_value_clones(q1, k1, v1) - q1_gold, k1_gold, v1_gold = query_key_value_clones(q1, k1, v1, torch.float64) - ref_out1 = sdpa_partial(q1_ref, k1_ref, v1_ref) - golden_out1 = sdpa_partial(q1_gold, k1_gold, v1_gold) - - backward_grad1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") - - golden_out1.backward(backward_grad1.to(torch.float64)) - ref_out1.backward(backward_grad1) + q1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") + k1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") + v1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") + golden_out1 = sdpa_partial( + q1.to(torch.float64), k1.to(torch.float64), v1.to(torch.float64) + ) + ref_out1 = sdpa_partial(q1, k1, v1) # The second eager batch, shape (B * 2, H, S / 2, D) B = int(B * 2) S = int(S / 2) - q2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) - k2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) - v2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) - q2_ref, k2_ref, v2_ref = query_key_value_clones(q2, k2, v2) - q2_gold, k2_gold, v2_gold = query_key_value_clones(q2, k2, v2, torch.float64) - ref_out2 = sdpa_partial(q2_ref, k2_ref, v2_ref) - golden_out2 = sdpa_partial(q2_gold, k2_gold, v2_gold) - - backward_grad2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") - - golden_out2.backward(backward_grad2.to(torch.float64)) - ref_out2.backward(backward_grad2) + q2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") + k2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") + v2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") + golden_out2 = sdpa_partial( + q2.to(torch.float64), k2.to(torch.float64), v2.to(torch.float64) + ) + ref_out2 = sdpa_partial(q2, k2, v2) # Need to clear dynamo counters, since flex attention eager mode also uses dynamo tracing. # We check dynamo counters["frames"]["ok"] to ensure there is no re-compilation. @@ -279,41 +238,20 @@ def run_dynamic_test( # Compiling with dynamic shape in the first batch. compiled_sdpa = torch.compile(sdpa_partial, dynamic=True) compiled_out1 = compiled_sdpa(q1, k1, v1) - compiled_out1.backward(backward_grad1) - - self._check_out_and_grad( - golden_out1, - ref_out1, - compiled_out1, - q1_gold, - q1_ref, - q1, - k1_gold, - k1_ref, - k1, - v1_gold, - v1_ref, - v1, - ) + + # Note, it seems like we really are less accurate than the float32 + # computation, likely due to the online softmax + if dtype == torch.float32: + fudge_factor = 10.0 + else: + fudge_factor = 1.1 + + self._check_equal(golden_out1, ref_out1, compiled_out1, fudge_factor) self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1) # No re-compilation, use the compiled dynamic shape version. compiled_out2 = compiled_sdpa(q2, k2, v2) - compiled_out2.backward(backward_grad2) - self._check_out_and_grad( - golden_out2, - ref_out2, - compiled_out2, - q2_gold, - q2_ref, - q2, - k2_gold, - k2_ref, - k2, - v2_gold, - v2_ref, - v2, - ) + self._check_equal(golden_out2, ref_out2, compiled_out2, fudge_factor) self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1) def run_automatic_dynamic_test( From 2fc907971a0ba22c6f6f65295a9d5e7fe501aca7 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 3 Jun 2024 18:13:27 +0000 Subject: [PATCH 271/706] Revert "[Inductor] FlexAttention backward kernel optimization (#127208)" This reverts commit f7171313abf14d9501a330457140b2f8a01c9985. Reverted https://github.com/pytorch/pytorch/pull/127208 on behalf of https://github.com/yanboliang due to test_flex_attention is failing internally ([comment](https://github.com/pytorch/pytorch/pull/127208#issuecomment-2145830810)) --- test/inductor/test_flex_attention.py | 4 +- torch/_inductor/kernel/flex_attention.py | 300 ++++++++++------------- torch/_inductor/select_algorithm.py | 5 +- 3 files changed, 125 insertions(+), 184 deletions(-) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index d4feead90301..bc688ab834cb 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -144,8 +144,6 @@ def _check_equal( ): compiled_error = (golden_out - compiled_out).abs().mean() ref_error = (golden_out - ref_out).abs().mean() - if torch.isnan(compiled_error).any() and not torch.isnan(ref_error).any(): - self.assertTrue(False, "Output/Grad with NaN") if compiled_error > ref_error * fudge_factor: name = tensor_name if tensor_name is not None else "" msg = f"{name} Compiled error {compiled_error} is greater than ref error {ref_error} by more than {fudge_factor}X." @@ -197,7 +195,7 @@ def run_test( self._check_equal( k_gold.grad, k_ref.grad, k.grad, k_fudge_factor, "Grad_Key" ) - v_fudge_factor = 4 * fudge_factor + v_fudge_factor = 8 * fudge_factor self._check_equal( v_gold.grad, v_ref.grad, v.grad, v_fudge_factor, "Grad_Value" ) diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index 3e95dd4f65ce..5a1f45e767a7 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -1,6 +1,7 @@ """ Triton Implementation of the flex_attention Kernel""" import logging +import math from enum import auto, Enum from typing import Any, List, Tuple @@ -188,7 +189,7 @@ def build_subgraph_buffer( Z = {{size("Q", 0)}} H = {{size("Q", 1)}} - Q_LEN = {{size("Q", 2)}} + N_CTX = {{size("Q", 2)}} qk_scale = 1.0 MATMUL_PRECISION = Q.dtype.element_ty @@ -199,7 +200,7 @@ def build_subgraph_buffer( qkv_offset = off_hz * stride_qh Q_block_ptr = tl.make_block_ptr( base=Q + qkv_offset, - shape=(Q_LEN, BLOCK_DMODEL), + shape=(N_CTX, BLOCK_DMODEL), strides=(stride_qm, stride_qk), offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_DMODEL), @@ -207,7 +208,7 @@ def build_subgraph_buffer( ) K_block_ptr = tl.make_block_ptr( base=K + qkv_offset, - shape=(BLOCK_DMODEL, Q_LEN), + shape=(BLOCK_DMODEL, N_CTX), strides=(stride_kk, stride_kn), offsets=(0, 0), block_shape=(BLOCK_DMODEL, BLOCK_N), @@ -215,7 +216,7 @@ def build_subgraph_buffer( ) V_block_ptr = tl.make_block_ptr( base=V + qkv_offset, - shape=(Q_LEN, BLOCK_DMODEL), + shape=(N_CTX, BLOCK_DMODEL), strides=(stride_vk, stride_vn), offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_DMODEL), @@ -235,7 +236,7 @@ def build_subgraph_buffer( q = (q * qk_scale).to(MATMUL_PRECISION) # loop over k, v and update accumulator lo = 0 - hi = Q_LEN + hi = N_CTX for start_n in range(lo, hi, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- load k, v -- @@ -298,7 +299,7 @@ def build_subgraph_buffer( # TODO dont want to write this if we dont require grad if OUTPUT_LOGSUMEXP: - l_ptrs = LSE + off_hz * Q_LEN + offs_m + l_ptrs = LSE + off_hz * N_CTX + offs_m lse = m_i + tl.math.log2(l_i) tl.store(l_ptrs, lse) """, @@ -445,22 +446,13 @@ def flex_attention(*args, **kwargs): # ---------------------------- Backward HOP Implementation ---------------------------- -def flex_attention_backward_grid(batch_size, num_heads, num_queries, d_model, meta): +def flex_attention_backward_grid(batch_size, num_heads, num_key_value, d_model, meta): """How is this kernel parallelized? Currently this is only parallelizing over batch * num_heads, but we can, and want to parallelize over ceil_div(num_key_value, key_value_block_size). To do this will either require atomic updates to some grad values or to have a two pass kernel design. """ - import triton - - # TODO: support different seqlen for Query and Key/Value. - num_key_value = num_queries - return ( - triton.cdiv(num_queries, meta["BLOCK_M2"]) - + triton.cdiv(num_key_value, meta["BLOCK_N1"]), - 1, - batch_size * num_heads, - ) + return (batch_size * num_heads, 1, 1) flex_attention_backward_template = TritonTemplate( @@ -478,83 +470,95 @@ def flex_attention_backward_grid(batch_size, num_heads, num_queries, d_model, me # M: Number of queries, N: Number of keys/values, D: Model dimension # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head # (Modifiable) Config options: - # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. - # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. - # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. - # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # BLOCK_M + # BLOCK_N # SCORE_MOD_IS_LINEAR: Is the score modifier linear? If so, we can lift the # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad # Define Q Strides - stride_z = {{stride("Q", 0)}} - stride_h = {{stride("Q", 1)}} - stride_tok = {{stride("Q", 2)}} - stride_d = {{stride("Q", 3)}} + stride_qz = {{stride("Q", 0)}} + stride_qh = {{stride("Q", 1)}} + stride_qm = {{stride("Q", 2)}} + stride_qk = {{stride("Q", 3)}} + # Define K Strides + stride_kz = {{stride("K", 0)}} + stride_kh = {{stride("K", 1)}} + stride_kn = {{stride("K", 2)}} + stride_kk = {{stride("K", 3)}} + # Define V Strides + stride_vz = {{stride("V", 0)}} + stride_vh = {{stride("V", 1)}} + stride_vn = {{stride("V", 2)}} + stride_vk = {{stride("V", 3)}} Z = {{size("Q", 0)}} H = {{size("Q", 1)}} - Q_LEN = {{size("Q", 2)}} - KV_LEN = {{size("K", 2)}} + N_CTX = {{size("Q", 2)}} + qk_scale = 1.0 MATMUL_PRECISION = Q.dtype.element_ty - pid = tl.program_id(0) - NUM_KV_BLOCKS = KV_LEN // BLOCK_N1 - - bhid = tl.program_id(2) - off_chz = (bhid * Q_LEN).to(tl.int64) - adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64) - - off_hz = tl.program_id(2) + off_hz = tl.program_id(0) off_z = off_hz // H # batch idx off_h = off_hz % H # head idx # offset pointers for batch/head - Q += adj - K += adj - V += adj - DO += adj - DQ += adj - DV += adj - LSE += off_chz - DELTA += off_chz - - offs_k = tl.arange(0, BLOCK_DMODEL) - - if pid >= NUM_KV_BLOCKS: - # THIS BLOCK DOES DQ - off_pid = pid - NUM_KV_BLOCKS - start_m2 = off_pid * BLOCK_M2 - - offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) - - q = tl.load(Q + offs_m2[:, None] * stride_tok + offs_k[None, :] * stride_d) - dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32) - do = tl.load(DO + offs_m2[:, None] * stride_tok + offs_k[None, :] * stride_d) - - lse = tl.load(LSE + offs_m2) - lse = lse[:, None] - - start_n2 = 0 - offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) - offs_n2 = start_n2 + tl.arange(0, BLOCK_N2) - kT_ptrs = K + offs_n2[None, :] * stride_tok + offs_k[:, None] * stride_d - vT_ptrs = V + offs_n2[None, :] * stride_tok + offs_k[:, None] * stride_d - Di = tl.load(DELTA + offs_m2) - # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. - tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) - - curr_n = start_n2 - num_steps = KV_LEN // BLOCK_N2 - for blk_idx in range(num_steps): - offs_n2= curr_n + tl.arange(0, BLOCK_N2) - kT = tl.load(kT_ptrs) - vT = tl.load(vT_ptrs) - qk = tl.dot(q, kT) - # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + Q += off_z * stride_qz + off_h * stride_qh + K += off_z * stride_kz + off_h * stride_kh + V += off_z * stride_vz + off_h * stride_vh + + # Asserting contiguous for now... + DO += off_z * stride_qz + off_h * stride_qh + DQ += off_z * stride_qz + off_h * stride_qh + DV += off_z * stride_vz + off_h * stride_vh + + # TODO I think that this should be N_CTX/BLOCK_N blocks + for start_n in range(0, NUM_Q_BLOCKS): + # We are not doing the causal optimization yet allowing us to start further down the + # kv column + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_m = tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, BLOCK_DMODEL) + + # initialize pointers to value-like data + q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk) + k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) + v_ptrs = V + (offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk) + do_ptrs = DO + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk) + dq_ptrs = DQ + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk) + + # pointer to row-wise quantities in value-like data + D_ptrs = DELTA + off_hz * N_CTX + l_ptrs = LSE + off_hz * N_CTX + + # initialize dv and dk + dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) + dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) + + # Key and Value stay in SRAM throughout + k = tl.load(k_ptrs) + v = tl.load(v_ptrs) + + for start_m in range(0, NUM_Q_BLOCKS * BLOCK_M, BLOCK_M): + offs_m_curr = start_m + offs_m + + # load q, k, v, do on-chip + q = tl.load(q_ptrs) + + if SCORE_MOD_IS_LINEAR: + qk_scale *= 1.44269504 + q = (q * qk_scale).to(MATMUL_PRECISION) + + # -- compute qk --- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.dot(q, tl.trans(k.to(MATMUL_PRECISION)), acc=qk) pre_mod_scores = qk - m = offs_m2[:, None] - n = offs_n2[None, :] + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = offs_m_curr[:, None] + n = offs_n[None, :] {{ modification( subgraph_number=0, output_name="post_mod_scores", @@ -565,94 +569,26 @@ def flex_attention_backward_grid(batch_size, num_heads, num_queries, d_model, me n="n", out="qk" ) | indent_except_first(3) }} - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # TODO: In the case that score_mod is linear, this can be LICMed if not SCORE_MOD_IS_LINEAR: post_mod_scores *= 1.44269504 - p = tl.math.exp2(post_mod_scores - lse).to(MATMUL_PRECISION) - # Compute dP and dS. - dp = tl.dot(do, vT) - ds = p * (dp - Di[:, None]) - # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ - {{ modification( - subgraph_number=1, - output_name = "grad_scores", - score="pre_mod_scores", - b="off_z", - h="off_h", - m="m", - n="n", - grad_score_mod="ds" - ) | indent_except_first(3) }} - ds = grad_scores # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - ds = ds.to(MATMUL_PRECISION) - # Compute dQ. - dq += tl.dot(ds, tl.trans(kT)) - # Increment pointers. - curr_n += BLOCK_N2 - kT_ptrs += BLOCK_N2 * stride_tok - vT_ptrs += BLOCK_N2 * stride_tok - # Write back dQ. - dq_ptrs = DQ + offs_m2[:, None] * stride_tok + offs_k[None, :] * stride_d - tl.store(dq_ptrs, dq) - else: - # THIS BLOCK DOES DK & DV - start_n1 = pid * BLOCK_N1 - start_m1 = 0 - - offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) - - dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) - dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) - - # load K and V: they stay in SRAM throughout the inner loop. - k = tl.load(K + offs_n1[:, None] * stride_tok + offs_k[None, :] * stride_d) - v = tl.load(V + offs_n1[:, None] * stride_tok + offs_k[None, :] * stride_d) - - offs_m1 = start_m1 + tl.arange(0, BLOCK_M1) - offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) - qT_ptrs = Q + offs_m1[None, :] * stride_tok + offs_k[:, None] * stride_d - do_ptrs = DO + offs_m1[:, None] * stride_tok + offs_k[None, :] * stride_d - # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. - tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) - - curr_m = start_m1 - num_steps = Q_LEN // BLOCK_M1 - for blk_idx in range(num_steps): - qT = tl.load(qT_ptrs) - # Load LSE before computing qk to reduce pipeline stall. - offs_m1 = curr_m + tl.arange(0, BLOCK_M1) - lse = tl.load(LSE + offs_m1) - qkT = tl.dot(k, qT) - # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ - m = offs_m1[None, :] - n = offs_n1[:, None] - pre_mod_scores = qkT - {{ modification( - subgraph_number=0, - output_name="post_mod_scores", - score="qkT", - b="off_z", - h="off_h", - m="m", - n="n", - out="qkT" - ) | indent_except_first(3) }} - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - if not SCORE_MOD_IS_LINEAR: - post_mod_scores *= 1.44269504 - pT = tl.math.exp2(post_mod_scores - lse[None, :]) + l_i = tl.load(l_ptrs + offs_m_curr) + p = tl.math.exp2(post_mod_scores - l_i[:, None]) + + # compute dv do = tl.load(do_ptrs) - # Compute dV. - ppT = pT - dv += tl.dot(ppT.to(MATMUL_PRECISION), do) - Di = tl.load(DELTA + offs_m1) - # Compute dP and dS. - dpT = tl.dot(v, tl.trans(do)) - dsT = pT * (dpT - Di[None, :]) + dv += tl.dot(tl.trans(p.to(MATMUL_PRECISION)), do) + + # compute dp = dot(v, do) + Di = tl.load(D_ptrs + offs_m_curr) # [BLOCKM, 1] + + # compute ds = p * (dp - delta[:, None]) + dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] + dp += tl.dot(do, tl.trans(v)) + ds = p * dp + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ - m = offs_m1[None, :] - n = offs_n1[:, None] {{ modification( subgraph_number=1, output_name = "grad_scores", @@ -661,25 +597,36 @@ def flex_attention_backward_grid(batch_size, num_heads, num_queries, d_model, me h="off_h", m="m", n="n", - grad_score_mod="dsT" + grad_score_mod="ds" ) | indent_except_first(3) }} - dsT = grad_scores + ds = grad_scores # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT)) - # Increment pointers. - curr_m += BLOCK_M1 - qT_ptrs += BLOCK_M1 * stride_tok - do_ptrs += BLOCK_M1 * stride_tok + # compute dk = dot(ds.T, q) + dk += tl.dot(tl.trans(ds.to(MATMUL_PRECISION)), q) + # compute dq + dq = tl.load(dq_ptrs) + dq += tl.dot(ds.to(MATMUL_PRECISION), k) + + # Store grad_query + tl.store(dq_ptrs, dq) + + # increment pointers + dq_ptrs += BLOCK_M * stride_qm + q_ptrs += BLOCK_M * stride_qm + do_ptrs += BLOCK_M * stride_qm + + # write-back + index_n = offs_n[:, None] + index_k = offs_k[None, :] - dv_ptrs = DV + offs_n1[:, None] * stride_tok + offs_k[None, :] * stride_d + # Store grad_key and grad_value + dv_ptrs = DV + (index_n * stride_vn + index_k * stride_vk) tl.store(dv_ptrs, dv) - # Write back dK. - index_n = offs_n1[:, None] - index_k = offs_k[None, :] # TODO generalize and add proper mask support mask = (index_n != -1) & (index_k != -1) {{store_output(("off_z", "off_h", "index_n", "index_k"), "dk", "mask", indent_width=8)}} + """, ) @@ -775,11 +722,10 @@ def flex_attention_backward(*args, **kwargs): mutated_inputs=[grad_query, grad_value], num_stages=num_stages, num_warps=num_warps, - BLOCK_M1=BLOCK_M, - BLOCK_N1=BLOCK_N, - BLOCK_M2=BLOCK_N, - BLOCK_N2=BLOCK_M, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, BLOCK_DMODEL=query.get_size()[-1], + NUM_Q_BLOCKS=math.ceil(query.get_size()[-2] / BLOCK_M), # For now, we always assume the "sound" option SCORE_MOD_IS_LINEAR=False, ) diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 531f3c25a31b..d8ca3eefed70 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -310,10 +310,7 @@ def modification( Args: subgraph_number (int): The index of the subgraph in self.subgraphs """ - num = 0 - while f"mod_{subgraph_number}_{num}" in self.subgraph_bodies: - num += 1 - with self.create_subgraph_body(f"mod_{subgraph_number}_{num}"): + with self.create_subgraph_body(f"modification_{subgraph_number}"): assert isinstance(subgraph_number, int) assert isinstance(self.subgraphs, list) assert ( From db9d457a3f77b67c082b96c580a1356fd0f25ffd Mon Sep 17 00:00:00 2001 From: "Xiangyang (Mark) Guo" Date: Mon, 3 Jun 2024 19:33:06 +0000 Subject: [PATCH 272/706] Use sleef on macOS Apple silicon by default (#126509) Use sleef ~~for aarch64~~ on macOS Apple silicon by default. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126509 Approved by: https://github.com/digantdesai, https://github.com/malfet --- CMakeLists.txt | 8 ++++++++ torch/_inductor/codecache.py | 2 ++ 2 files changed, 10 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 335f5750648c..998073bc72b3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -892,6 +892,14 @@ endif() if(USE_SLEEF_FOR_ARM_VEC256) string(APPEND CMAKE_CXX_FLAGS " -DAT_BUILD_ARM_VEC256_WITH_SLEEF") + add_definitions(-DAT_BUILD_ARM_VEC256_WITH_SLEEF) +endif() + +# Enable sleef on macOS with Apple silicon by default +if((${CMAKE_SYSTEM_NAME} STREQUAL "Darwin") AND (${CMAKE_SYSTEM_PROCESSOR} STREQUAL "arm64")) + message(STATUS "Running on macOS with Apple silicon") + string(APPEND CMAKE_CXX_FLAGS " -DAT_BUILD_ARM_VEC256_WITH_SLEEF") + add_definitions(-DAT_BUILD_ARM_VEC256_WITH_SLEEF) endif() if(USE_XNNPACK) diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index d338c2665484..6e70adf1758b 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1327,6 +1327,8 @@ def __bool__(self) -> bool: class VecNEON(VecISA): _bit_width = 256 # This is required to leverage the compute implemented in aten/src/ATen/cpu/vec/vec256/vec256_float_neon.h _macro = "-DCPU_CAPABILITY_NEON" + if sys.platform == "darwin" and platform.processor() == "arm": + _macro += " -DAT_BUILD_ARM_VEC256_WITH_SLEEF" _arch_flags = "" # Unused _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16} From 941316f821b8f6fa580a86d5b9a24fbcc623503f Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Wed, 29 May 2024 16:13:13 -0700 Subject: [PATCH 273/706] [pipelining] Stress test schedules with multi iters (#127475) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127475 Approved by: https://github.com/wconstab --- test/distributed/pipelining/test_schedule.py | 40 ++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/test/distributed/pipelining/test_schedule.py b/test/distributed/pipelining/test_schedule.py index 232f69d8bcef..3cefc6da2322 100644 --- a/test/distributed/pipelining/test_schedule.py +++ b/test/distributed/pipelining/test_schedule.py @@ -52,6 +52,46 @@ def setUpClass(cls): dev_id = cls.rank % torch.cuda.device_count() cls.device = torch.device(f"cuda:{dev_id}") + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B]) + def test_multi_iter(self, ScheduleClass): + mod = MultiMLP(d_hid, n_layers=self.world_size) + mod.to(self.device) + + x = torch.randn(batch_size, d_hid, device=self.device) + target = torch.randn(batch_size, d_hid, device=self.device) + loss_fn = torch.nn.MSELoss(reduction="sum") + + # Create a pipeline + chunks = 4 + split_spec = mod.split_spec if hasattr(mod, "split_spec") else None + pipe = pipeline( + mod, + chunks, + example_args=(x,), + split_spec=split_spec, + ) + + stage = PipelineStage( + pipe, + self.rank, + device=self.device, + ) + + # Attach to a schedule + schedule = ScheduleClass(stage, chunks, loss_fn=loss_fn) + + # Run + for _ in range(20): + if self.rank == 0: + schedule.step(x) + elif self.rank == self.world_size - 1: + losses = [] + out = schedule.step(target=target, losses=losses) + else: + schedule.step() + @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B]) From ef9451ac8dc467d2bb436589e76a0d5cc80b45fd Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Mon, 3 Jun 2024 20:35:22 +0000 Subject: [PATCH 274/706] Move the build of AOTriton to base ROCM docker image. (#127012) Mitigates #126111 AOTrtion, as a Math library, takes long time to build. However, this library itself is not moving as fast as PyTorch itself and it is not cost-efficient to build it for every CI check. This PR moves the build of AOTriton from PyTorch to its base docker image, avoids duplicated and long build time. Pre-this-PR: * PyTorch base docker build job duration: 1.1-1.3h * PyTorch build job duration: 1.4-1.5hr (includes AOTriton build time of 1hr6min on a linux.2xlarge node) Post-this-PR: * PyTorch base docker build job duration: 1.3h (includes AOTriton build time of 20min on a linux.12xlarge node) * PyTorch build job duration: <20 min Co-authored-by: Jithun Nair <37884920+jithunnair-amd@users.noreply.github.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/127012 Approved by: https://github.com/jithunnair-amd, https://github.com/pruthvistony, https://github.com/huydhn --- .ci/docker/centos-rocm/Dockerfile | 7 ++++ .ci/docker/ci_commit_pins/aotriton.txt | 1 + .ci/docker/common/install_aotriton.sh | 24 +++++++++++++ .ci/docker/ubuntu-rocm/Dockerfile | 7 ++++ caffe2/CMakeLists.txt | 1 - cmake/External/aotriton.cmake | 47 ++++++++++++++++---------- 6 files changed, 68 insertions(+), 19 deletions(-) create mode 100644 .ci/docker/ci_commit_pins/aotriton.txt create mode 100644 .ci/docker/common/install_aotriton.sh diff --git a/.ci/docker/centos-rocm/Dockerfile b/.ci/docker/centos-rocm/Dockerfile index 6cb82a1f770c..38d2ff4ed9ab 100644 --- a/.ci/docker/centos-rocm/Dockerfile +++ b/.ci/docker/centos-rocm/Dockerfile @@ -118,6 +118,13 @@ COPY ./common/install_cache.sh install_cache.sh ENV PATH /opt/cache/bin:$PATH RUN bash ./install_cache.sh && rm install_cache.sh +# Install AOTriton +COPY ci_commit_pins/aotriton.txt aotriton.txt +COPY ./common/common_utils.sh common_utils.sh +COPY ./common/install_aotriton.sh install_aotriton.sh +RUN bash ./install_aotriton.sh /opt/rocm/aotriton && rm -rf install_aotriton.sh aotriton aotriton.txt common_utils.sh +ENV AOTRITON_INSTALLED_PREFIX /opt/rocm/aotriton + # Include BUILD_ENVIRONMENT environment variable in image ARG BUILD_ENVIRONMENT ENV BUILD_ENVIRONMENT ${BUILD_ENVIRONMENT} diff --git a/.ci/docker/ci_commit_pins/aotriton.txt b/.ci/docker/ci_commit_pins/aotriton.txt new file mode 100644 index 000000000000..adb49c304bf4 --- /dev/null +++ b/.ci/docker/ci_commit_pins/aotriton.txt @@ -0,0 +1 @@ +24a3fe9cb57e5cda3c923df29743f9767194cc27 diff --git a/.ci/docker/common/install_aotriton.sh b/.ci/docker/common/install_aotriton.sh new file mode 100644 index 000000000000..47c7a9df773f --- /dev/null +++ b/.ci/docker/common/install_aotriton.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +set -ex + +source "$(dirname "${BASH_SOURCE[0]}")/common_utils.sh" + +AOTRITON_DIR="aotriton" +AOTRITON_PINNED_NAME="aotriton" # No .txt extension +AOTRITON_PINNED_COMMIT=$(get_pinned_commit ${AOTRITON_PINNED_NAME}) +AOTRITON_INSTALL_PREFIX="$1" + +git clone https://github.com/ROCm/aotriton.git "${AOTRITON_DIR}" +cd "${AOTRITON_DIR}" +git checkout "${AOTRITON_PINNED_COMMIT}" +git submodule sync --recursive +git submodule update --init --recursive --force --depth 1 +mkdir build +cd build +cmake .. -G Ninja -DCMAKE_INSTALL_PREFIX=./install_dir -DCMAKE_BUILD_TYPE=Release -DAOTRITON_COMPRESS_KERNEL=OFF -DAOTRITON_NO_PYTHON=ON -DAOTRITON_NO_SHARED=ON +ninja install +mkdir -p "${AOTRITON_INSTALL_PREFIX}" +cp -r install_dir/* "${AOTRITON_INSTALL_PREFIX}" +find /tmp/ -mindepth 1 -delete +rm -rf ~/.triton diff --git a/.ci/docker/ubuntu-rocm/Dockerfile b/.ci/docker/ubuntu-rocm/Dockerfile index cc43d9ec2414..111a727fe5b8 100644 --- a/.ci/docker/ubuntu-rocm/Dockerfile +++ b/.ci/docker/ubuntu-rocm/Dockerfile @@ -110,6 +110,13 @@ COPY ./common/install_cache.sh install_cache.sh ENV PATH /opt/cache/bin:$PATH RUN bash ./install_cache.sh && rm install_cache.sh +# Install AOTriton +COPY ci_commit_pins/aotriton.txt aotriton.txt +COPY ./common/common_utils.sh common_utils.sh +COPY ./common/install_aotriton.sh install_aotriton.sh +RUN bash ./install_aotriton.sh /opt/rocm/aotriton && rm -rf install_aotriton.sh aotriton aotriton.txt common_utils.sh +ENV AOTRITON_INSTALLED_PREFIX /opt/rocm/aotriton + # Include BUILD_ENVIRONMENT environment variable in image ARG BUILD_ENVIRONMENT ENV BUILD_ENVIRONMENT ${BUILD_ENVIRONMENT} diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 1e29044e19fd..df1eecf929a8 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -939,7 +939,6 @@ if(USE_ROCM) hip_add_library(torch_hip ${Caffe2_HIP_SRCS}) if(USE_FLASH_ATTENTION) target_link_libraries(torch_hip PRIVATE __caffe2_aotriton) - add_dependencies(torch_hip aotriton_external) endif() set(CUDA_LINK_LIBRARIES_KEYWORD) torch_compile_options(torch_hip) # see cmake/public/utils.cmake diff --git a/cmake/External/aotriton.cmake b/cmake/External/aotriton.cmake index de64370b37a2..c95c66626837 100644 --- a/cmake/External/aotriton.cmake +++ b/cmake/External/aotriton.cmake @@ -4,25 +4,36 @@ if(NOT __AOTRITON_INCLUDED) set(__AOTRITON_SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/aotriton/src") set(__AOTRITON_BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/aotriton/build") set(__AOTRITON_INSTALL_DIR "${PROJECT_SOURCE_DIR}/torch") - ExternalProject_Add(aotriton_external - GIT_REPOSITORY https://github.com/ROCm/aotriton.git - GIT_TAG 24a3fe9cb57e5cda3c923df29743f9767194cc27 - SOURCE_DIR ${__AOTRITON_SOURCE_DIR} - BINARY_DIR ${__AOTRITON_BUILD_DIR} - PREFIX ${__AOTRITON_INSTALL_DIR} - CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${__AOTRITON_INSTALL_DIR} - -DAOTRITON_COMPRESS_KERNEL=OFF - -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} - -DAOTRITON_NO_PYTHON=ON - -DAOTRITON_NO_SHARED=ON - # CONFIGURE_COMMAND "" - # BUILD_COMMAND ${MAKE_COMMAND} - BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.a" - # INSTALL_COMMAND ${MAKE_COMMAND} install - ) - set(AOTRITON_FOUND TRUE) add_library(__caffe2_aotriton INTERFACE) - add_dependencies(__caffe2_aotriton aotriton_external) + # Note it is INSTALL"ED" + if(DEFINED ENV{AOTRITON_INSTALLED_PREFIX}) + set(__AOTRITON_INSTALL_DIR "$ENV{AOTRITON_INSTALLED_PREFIX}") + message(STATUS "Using Preinstalled AOTriton at ${__AOTRITON_INSTALL_DIR}") + else() + ExternalProject_Add(aotriton_external + GIT_REPOSITORY https://github.com/ROCm/aotriton.git + GIT_TAG 24a3fe9cb57e5cda3c923df29743f9767194cc27 + SOURCE_DIR ${__AOTRITON_SOURCE_DIR} + BINARY_DIR ${__AOTRITON_BUILD_DIR} + PREFIX ${__AOTRITON_INSTALL_DIR} + CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${__AOTRITON_INSTALL_DIR} + -DAOTRITON_COMPRESS_KERNEL=OFF + -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} + -DAOTRITON_NO_PYTHON=ON + -DAOTRITON_NO_SHARED=ON + # CONFIGURE_COMMAND "" + BUILD_COMMAND "" # No build, install command will repeat the build process due to problems in the build system. + BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.a" + USES_TERMINAL_DOWNLOAD TRUE + USES_TERMINAL_CONFIGURE TRUE + USES_TERMINAL_BUILD TRUE + USES_TERMINAL_INSTALL TRUE + # INSTALL_COMMAND ${MAKE_COMMAND} install + ) + add_dependencies(__caffe2_aotriton aotriton_external) + message(STATUS "Using AOTriton compiled from source directory ${__AOTRITON_SOURCE_DIR}") + endif() target_link_libraries(__caffe2_aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.a) target_include_directories(__caffe2_aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/include) + set(AOTRITON_FOUND TRUE) endif() # __AOTRITON_INCLUDED From a4064da8cac7345fdf1ffb1f03262f9b235f37a0 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Mon, 3 Jun 2024 07:23:21 -0700 Subject: [PATCH 275/706] Always simplify sympy expressions before printing. (#127543) This is important because if a replacement has happened during inductor lowering, we may have stale symbols in sympy expressions that we need to replace away. Do this at the very end. Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/127543 Approved by: https://github.com/lezcano --- test/inductor/test_memory_planning.py | 4 ++-- torch/_inductor/codegen/common.py | 6 ++++++ torch/_inductor/codegen/wrapper.py | 7 ++----- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/test/inductor/test_memory_planning.py b/test/inductor/test_memory_planning.py index 1ec1dd9f89e9..78c7086972eb 100644 --- a/test/inductor/test_memory_planning.py +++ b/test/inductor/test_memory_planning.py @@ -56,7 +56,7 @@ def test_python_wrapper(self): ).check_next( "buf0 = alloc_from_pool(pool1, 0, torch.float32, (s0, s0), (s0, 1))" ).check( - "buf1 = alloc_from_pool(pool1, align((4*s0) + (4*s0*((-1) + s0)))," + "buf1 = alloc_from_pool(pool1, align(4*(s0*s0))," ).run( code ) @@ -74,7 +74,7 @@ def test_cpp_wrapper(self): ).check_next( "auto buf0 = alloc_from_pool(pool1, 0, at::kFloat, {s0, s0}, {s0, 1L});" ).check( - "auto buf1 = alloc_from_pool(pool1, align((4L*s0) + (4L*s0*((-1L) + s0)))," + "auto buf1 = alloc_from_pool(pool1, align(4L*(static_cast(s0*s0)))," ).run( code ) diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 3e238203b770..f7b3e7a45d6e 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -403,6 +403,12 @@ def _print_align(self, expr): assert len(expr.args) == 1 return f"align({self._print(expr.args[0])})" + def doprint(self, expr, *, simplify: bool = True): + # TODO: why are people passing strings to the printer here :think: + if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"): + expr = V.graph.sizevars.simplify(expr) + return super().doprint(expr) + class PythonPrinter(ExprPrinter): def _print_ModularIndexing(self, expr): diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 0bf4814f80b1..c90776e4cbd2 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -439,7 +439,7 @@ def __init__(self): self.stride = "stride()" self.last_seen_device_guard_index: Optional[int] = None self.supports_intermediate_hooks = True - self.expr_printer = pexpr + self.expr_printer: Callable[[Any], str] = pexpr self.user_defined_kernel_cache: Dict[Tuple[Any, ...], Tuple[str, Any]] = {} self.unbacked_symbol_decls: Set[str] = set() # str of sympy.Symbol self.allow_stack_allocation: Optional[bool] = None @@ -906,10 +906,7 @@ def finalize_prefix(self): pass def codegen_python_sizevar(self, x: Expr, *, simplify: bool = True) -> str: - if simplify: - return pexpr(V.graph.sizevars.simplify(x)) - else: - return pexpr(x) + return pexpr(x, simplify=simplify) def codegen_sizevar(self, x: Expr) -> str: return self.codegen_python_sizevar(x) From 6d4ec9b2ecba7d26e885bdbb0faeeaa1e148cfd6 Mon Sep 17 00:00:00 2001 From: Zain Huda Date: Mon, 3 Jun 2024 21:21:55 +0000 Subject: [PATCH 276/706] [RFC] Introduce Checkpointable for DCP (#127540) (#127628) Summary: # Introduce Checkpointable interface for DCP to support arbitrary tensor subclasses for checkpointing **Authors:** * zainhuda ## **Summary** This diff adds a CheckpointableTensor interface to allow for future compatibility for any tensor subclass with DCP in a clean and maintainable way. ## **Motivation** For TorchRec sharding migration from ShardedTensor to DTensor, we create a tensor subclass that is stored by DTensor to support TorchRec's sharding schemes (ex, empty shards, multiple shards on a rank). ## **Proposed Implementation** View the CheckpointableTensor interface implementation, in which, we introduce the minimal set of methods needed to be compatible with DCP. These methods are expected to implemented by any tensor subclasses and as such are then checkpointable by DCP. ## **Drawbacks** No drawbacks, it extends functionality in a clean and maintainable way. ## **Alternatives** Alternative design was creating paths for checking for certain attributes in tensor subclasses which can get messy and hard to maintain/understand why it was there in the first place. Test Plan: Sandcastle cc mrshenli pritamdamania87 zhaojuanmao satgera gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k LucasLLC Differential Revision: D57970603 Pulled By: iamzainhuda Pull Request resolved: https://github.com/pytorch/pytorch/pull/127628 Approved by: https://github.com/wz337, https://github.com/XilunWu, https://github.com/fegin --- torch/distributed/checkpoint/planner.py | 36 +++++++++++++++++++ .../distributed/checkpoint/planner_helpers.py | 15 ++++++-- torch/distributed/checkpoint/utils.py | 8 ++++- 3 files changed, 56 insertions(+), 3 deletions(-) diff --git a/torch/distributed/checkpoint/planner.py b/torch/distributed/checkpoint/planner.py index ad2466a50ee8..5eec8bf75466 100644 --- a/torch/distributed/checkpoint/planner.py +++ b/torch/distributed/checkpoint/planner.py @@ -426,3 +426,39 @@ def commit_tensor(self, read_item: ReadItem, tensor: torch.Tensor) -> None: The contents of tensor will follow its device synchronization model. """ pass + + +class _Checkpointable: + """ + Interface for checkpointable objects. + This is to allow arbitrary objects/tensor subclasses to hook into DCP seamlessly through implementing the interface. + """ + + @abc.abstractmethod + def _create_write_items(self, fqn: str, object: Any) -> List[WriteItem]: + """ + Return a list of WriteItems based on object's contents. + """ + raise NotImplementedError( + "_Checkpointable._create_write_items is not implemented" + ) + + @abc.abstractmethod + def _create_chunk_list(self, tensor: torch.Tensor) -> List[ChunkStorageMetadata]: + """ + Return a list of `ChunkStorageMetadata` based on object's contents. + """ + raise NotImplementedError( + "_Checkpointable._create_chunk_list is not implemented" + ) + + @abc.abstractmethod + def _get_tensor_shard( + self, tensor: torch.Tensor, index: MetadataIndex + ) -> torch.Tensor: + """ + Return a 'torch.Tensor' shard based on 'MetadataIndex'. + """ + raise NotImplementedError( + "_Checkpointable._get_tensor_shard is not implemented" + ) diff --git a/torch/distributed/checkpoint/planner_helpers.py b/torch/distributed/checkpoint/planner_helpers.py index 5829ab6111e2..c4e5be89a45d 100644 --- a/torch/distributed/checkpoint/planner_helpers.py +++ b/torch/distributed/checkpoint/planner_helpers.py @@ -8,6 +8,7 @@ from torch.distributed._shard.sharded_tensor import ShardedTensor from torch.distributed._tensor import DTensor from torch.distributed._tensor._utils import compute_local_shape_and_global_offset +from torch.distributed.checkpoint.planner import _Checkpointable from torch.utils._pytree import tree_map_only @@ -217,7 +218,12 @@ def _create_default_metadata_only_plan(state_dict: STATE_DICT_TYPE) -> SavePlan: def _create_write_items(fqn: str, object: Any) -> List[WriteItem]: - if isinstance(object, DTensor): + if isinstance(object, _Checkpointable): + return object._create_write_items(fqn, object) + elif isinstance(object, DTensor): + # DTensor can contain a local tensor that is a tensor subclass + if isinstance(object.to_local(), _Checkpointable): + return object.to_local()._create_write_items(fqn, object) # type: ignore[arg-type] return [_create_write_items_for_dtensor(fqn, object)] elif isinstance(object, ShardedTensor): return [ @@ -242,7 +248,12 @@ def _create_chunk_from_dtensor(tensor: DTensor) -> ChunkStorageMetadata: def _create_chunk_list(tensor: torch.Tensor) -> List[ChunkStorageMetadata]: - if isinstance(tensor, DTensor): + if isinstance(tensor, _Checkpointable): + local_chunks = tensor._create_chunk_list(tensor) + elif isinstance(tensor, DTensor): + # DTensor can contain a local tensor that is a tensor subclass + if isinstance(tensor.to_local(), _Checkpointable): + return tensor.to_local()._create_chunk_list(tensor) # type: ignore[arg-type] local_chunks = [_create_chunk_from_dtensor(tensor)] elif isinstance(tensor, ShardedTensor): local_chunks = [ diff --git a/torch/distributed/checkpoint/utils.py b/torch/distributed/checkpoint/utils.py index d781d9839bea..a93c0bfc400a 100644 --- a/torch/distributed/checkpoint/utils.py +++ b/torch/distributed/checkpoint/utils.py @@ -14,6 +14,7 @@ from torch.distributed._shard.sharded_tensor import ShardedTensor from torch.distributed._shard.sharded_tensor.shard import Shard from torch.distributed._tensor import DTensor +from torch.distributed.checkpoint.planner import _Checkpointable from .api import ( _is_wrapped_exception, @@ -301,7 +302,12 @@ def _find_shard(tensor: ShardedTensor, index: MetadataIndex) -> Shard: def find_tensor_shard(tensor: torch.Tensor, index: MetadataIndex) -> torch.Tensor: - if isinstance(tensor, DTensor): + if isinstance(tensor, _Checkpointable): + return tensor._get_tensor_shard(tensor, index) + elif isinstance(tensor, DTensor): + # DTensor can contain a local tensor that is a tensor subclass + if isinstance(tensor.to_local(), _Checkpointable): + return tensor.to_local()._get_tensor_shard(tensor, index) # type: ignore[arg-type] return tensor.to_local() if isinstance(tensor, ShardedTensor): return _find_shard(tensor, index).tensor From c6dc6246902189fd33f800c48439568af9a02f20 Mon Sep 17 00:00:00 2001 From: Yidi Wu Date: Thu, 30 May 2024 14:23:40 -0700 Subject: [PATCH 277/706] [torchbind] remove test cases that don't fakify script objects (#127113) As titled. Differential Revision: [D57991003](https://our.internmc.facebook.com/intern/diff/D57991003) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127113 Approved by: https://github.com/zou3519 --- test/export/test_torchbind.py | 56 ++++++++--------------------------- 1 file changed, 12 insertions(+), 44 deletions(-) diff --git a/test/export/test_torchbind.py b/test/export/test_torchbind.py index 3e4a78c61769..b404e02fc370 100644 --- a/test/export/test_torchbind.py +++ b/test/export/test_torchbind.py @@ -255,15 +255,8 @@ def forward(self, token, obj_attr, x): ) @parametrize("pre_dispatch", [True, False]) - @parametrize("fakify_script_obj", [True, False]) - def test_input(self, pre_dispatch, fakify_script_obj): + def test_input(self, pre_dispatch): cc = torch.classes._TorchScriptTesting._Foo(10, 20) - if not fakify_script_obj: - qual_name = cc._type().qualified_name() # type: ignore[att-defined] - if torch._library.fake_class_registry.has_fake_class(qual_name): - torch._library.fake_class_registry.deregister_fake_class( - "_TorchScriptTesting::_Foo" - ) class MyModule(torch.nn.Module): def __init__(self): @@ -295,19 +288,11 @@ def forward(self, x, cc): # aot_export_function runs the program twice # in run_functionalized_fw_and_collect_metadata and create_aot_dispatcher_function # We also have a re-tracing test, which doubles the count. - if fakify_script_obj: - self.assertEqual(self.foo_add_tensor_counter, 4) + self.assertEqual(self.foo_add_tensor_counter, 4) @parametrize("pre_dispatch", [True, False]) - @parametrize("fakify_script_obj", [True, False]) - def test_input_as_custom_op_argument(self, pre_dispatch, fakify_script_obj): + def test_input_as_custom_op_argument(self, pre_dispatch): cc = torch.classes._TorchScriptTesting._Foo(10, 20) - if not fakify_script_obj: - qual_name = cc._type().qualified_name() # type: ignore[att-defined] - if torch._library.fake_class_registry.has_fake_class(qual_name): - torch._library.fake_class_registry.deregister_fake_class( - "_TorchScriptTesting::_Foo" - ) class MyModule(torch.nn.Module): def __init__(self): @@ -322,16 +307,13 @@ def forward(self, x, cc): torch.ops._TorchScriptTesting.takes_foo.default._dispatch_cache.clear() # Even though a C++ implementation for takes_foo.default is registered, # we still need the python implementation for takes_foo.default to trace with FakeFoo. - if fakify_script_obj: - with self.assertRaisesRegex( - RuntimeError, "no python implementation is found" - ): - self._test_export_same_as_eager( - MyModule(), - (torch.ones(2, 3), cc), - strict=False, - pre_dispatch=pre_dispatch, - ) + with self.assertRaisesRegex(RuntimeError, "no python implementation is found"): + self._test_export_same_as_eager( + MyModule(), + (torch.ones(2, 3), cc), + strict=False, + pre_dispatch=pre_dispatch, + ) torch.ops._TorchScriptTesting.takes_foo.default.py_impl( torch._C.DispatchKey.Meta @@ -364,8 +346,7 @@ def forward(self, token, x, cc): ) @parametrize("pre_dispatch", [True, False]) - @parametrize("fakify_script_obj", [True, False]) - def test_torchbind_alias(self, pre_dispatch, fakify_script_obj): + def test_torchbind_alias(self, pre_dispatch): class F2(torch.nn.Module): def __init__(self, foo): super().__init__() @@ -378,12 +359,6 @@ class F1(torch.nn.Module): def __init__(self): super().__init__() self.alpha = torch.classes._TorchScriptTesting._Foo(10, 20) - if not fakify_script_obj: - qual_name = self.alpha._type().qualified_name() - if torch._library.fake_class_registry.has_fake_class(qual_name): - torch._library.fake_class_registry.deregister_fake_class( - "_TorchScriptTesting::_Foo" - ) self.beta = self.alpha self.gamma = self.alpha self.foo = F2(self.gamma) @@ -402,8 +377,7 @@ def forward(self, x): # TODO(pianpwk): look into this @unittest.expectedFailure @parametrize("pre_dispatch", [True, False]) - @parametrize("fakify_script_obj", [True, False]) - def test_torchbind_input_and_alias(self, pre_dispatch, fakify_script_obj): + def test_torchbind_input_and_alias(self, pre_dispatch): # alias as model attribute class F3(torch.nn.Module): def forward(self, x, foo): @@ -411,12 +385,6 @@ def forward(self, x, foo): return x + self.foo.add_tensor(x) foo = torch.classes._TorchScriptTesting._Foo(10, 20) - if not fakify_script_obj: - qual_name = foo._type().qualified_name() # type: ignore[att-defined] - if torch._library.fake_class_registry.has_fake_class(qual_name): - torch._library.fake_class_registry.deregister_fake_class( - "_TorchScriptTesting::_Foo" - ) self._test_export_same_as_eager( F3(), (torch.ones(2, 3), foo), strict=False, pre_dispatch=pre_dispatch ) From 3efac92888ff473b2262b2cb06da1264cde19ae9 Mon Sep 17 00:00:00 2001 From: Yidi Wu Date: Thu, 30 May 2024 14:23:40 -0700 Subject: [PATCH 278/706] [torchbind] support torch.compile with aot_eager backend (#127114) Differential Revision: [D57991001](https://our.internmc.facebook.com/intern/diff/D57991001) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127114 Approved by: https://github.com/zou3519 ghstack dependencies: #127113 --- test/export/test_torchbind.py | 130 ++++++++++++++---------- torch/_functorch/_aot_autograd/utils.py | 1 + torch/_functorch/aot_autograd.py | 6 ++ 3 files changed, 85 insertions(+), 52 deletions(-) diff --git a/test/export/test_torchbind.py b/test/export/test_torchbind.py index b404e02fc370..2ff09598ace8 100644 --- a/test/export/test_torchbind.py +++ b/test/export/test_torchbind.py @@ -907,8 +907,10 @@ def size(self): def tearDown(self): torch._dynamo.reset() - def test_compile_script_object_input(self): - backend = EagerAndRecordGraphs() + @parametrize("backend", ["eager", "aot_eager"]) + def test_compile_script_object_input(self, backend): + if backend == "eager": + backend = EagerAndRecordGraphs() class Model(torch.nn.Module): def __init__(self): @@ -952,23 +954,25 @@ def forward(self, tq, x): # does not return L_tq_ as output. This is because it's able # to detect that L_tq_ is an input therefore don't return # it as graph output. Related logic is in dynamo/codegen.py - self.assertExpectedInline( - backend.graphs[0].code.strip(), - """\ -def forward(self, L_tq_ : torch.ScriptObject, L_x_ : torch.Tensor): - l_tq_ = L_tq_ - l_x_ = L_x_ - cos = l_x_.cos() - call_torchbind = torch.ops.higher_order.call_torchbind(l_tq_, 'push', cos); cos = None - sin = l_x_.sin(); l_x_ = None - call_torchbind_1 = torch.ops.higher_order.call_torchbind(l_tq_, 'push', sin); sin = None - call_torchbind_2 = torch.ops.higher_order.call_torchbind(l_tq_, 'pop') - call_torchbind_3 = torch.ops.higher_order.call_torchbind(l_tq_, 'size'); l_tq_ = None - x_sin = call_torchbind_2 - 1; call_torchbind_2 = None - return (x_sin,)""", - ) + if backend == "eager": + self.assertExpectedInline( + backend.graphs[0].code.strip(), + """\ + def forward(self, L_tq_ : torch.ScriptObject, L_x_ : torch.Tensor): + l_tq_ = L_tq_ + l_x_ = L_x_ + cos = l_x_.cos() + call_torchbind = torch.ops.higher_order.call_torchbind(l_tq_, 'push', cos); cos = None + sin = l_x_.sin(); l_x_ = None + call_torchbind_1 = torch.ops.higher_order.call_torchbind(l_tq_, 'push', sin); sin = None + call_torchbind_2 = torch.ops.higher_order.call_torchbind(l_tq_, 'pop') + call_torchbind_3 = torch.ops.higher_order.call_torchbind(l_tq_, 'size'); l_tq_ = None + x_sin = call_torchbind_2 - 1; call_torchbind_2 = None + return (x_sin,)""", + ) - def test_compile_script_object_input_guards(self): + @parametrize("backend", ["eager", "aot_eager"]) + def test_compile_script_object_input_guards(self, backend): class Model(torch.nn.Module): def __init__(self): super().__init__() @@ -981,7 +985,7 @@ def forward(self, tq, x): return x_sin, tq mod = Model() - cnt = torch._dynamo.testing.CompileCounter() + cnt = torch._dynamo.testing.CompileCounterWithBackend(backend) x = torch.randn(2, 3) tq1 = _empty_tensor_queue() @@ -1052,8 +1056,10 @@ def forward(self, tq, x): torch.compile(mod, backend=cnt)(tq3, x) self.assertEqual(cnt.frame_count, 2) - def test_compile_error_on_input_aliasing_contents(self): - backend = EagerAndRecordGraphs() + @parametrize("backend", ["eager", "aot_eager"]) + def test_compile_error_on_input_aliasing_contents(self, backend): + if backend == "eager": + backend = EagerAndRecordGraphs() class Model(torch.nn.Module): def __init__(self): @@ -1071,7 +1077,11 @@ def forward(self, tq, x): with self.assertRaisesRegex(RuntimeError, "is alising"): torch.compile(mod, backend=backend)(tq1, x) - def test_compile_error_on_script_obj_setattr(self): + @parametrize("backend", ["eager", "aot_eager"]) + def test_compile_error_on_script_obj_setattr(self, backend): + if backend == "eager": + backend = EagerAndRecordGraphs() + def setattr_f(tq): tq.a = 1 return tq @@ -1079,19 +1089,25 @@ def setattr_f(tq): with self.assertRaisesRegex( RuntimeError, "call method __setattr__ on script object is not safe" ): - torch.compile(setattr_f, backend="eager")(_empty_tensor_queue()) + torch.compile(setattr_f, backend=backend)(_empty_tensor_queue()) + + @parametrize("backend", ["eager", "aot_eager"]) + def test_compile_error_on_script_obj_missing_attr(self, backend): + if backend == "eager": + backend = EagerAndRecordGraphs() - def test_compile_error_on_script_obj_missing_attr(self): def setattr_f(tq): return tq._not_defined_attr with self.assertRaisesRegex( RuntimeError, "doesn't define method _not_defined_attr" ): - torch.compile(setattr_f, backend="eager")(_empty_tensor_queue()) + torch.compile(setattr_f, backend=backend)(_empty_tensor_queue()) - def test_compile_body_aliasing_contents(self): - backend = EagerAndRecordGraphs() + @parametrize("backend", ["eager", "aot_eager"]) + def test_compile_body_aliasing_contents(self, backend): + if backend == "eager": + backend = EagerAndRecordGraphs() def f(tq, x): x1 = x.view(-1) @@ -1106,7 +1122,7 @@ def f(tq, x): f(_empty_tensor_queue(), x), torch.compile(f, backend=backend)(_empty_tensor_queue(), x), ) - if not torch._dynamo.is_compiling(): + if not torch._dynamo.is_compiling() and backend == "eager": self.assertExpectedInline( backend.graphs[0].code.strip(), """\ @@ -1124,8 +1140,10 @@ def forward(self, L_x_ : torch.Tensor, L_tq_ : torch.ScriptObject): return (sub, add)""", ) - def test_compile_error_on_non_fakified_method(self): - backend = EagerAndRecordGraphs() + @parametrize("backend", ["eager", "aot_eager"]) + def test_compile_error_on_non_fakified_method(self, backend): + if backend == "eager": + backend = EagerAndRecordGraphs() def f(tq, x): x1 = x.view(-1) @@ -1143,7 +1161,8 @@ def f(tq, x): ): torch.compile(f, backend=backend)(_empty_tensor_queue(), x) - def test_compile_obj_as_hop_input(self): + @parametrize("backend", ["eager", "aot_eager"]) + def test_compile_obj_as_hop_input(self, backend): def f(tq, x): def fn(tq, x): tq.push(x) @@ -1155,10 +1174,11 @@ def fn(tq, x): _assertEqualScriptObject( self, f(_empty_tensor_queue(), x), - torch.compile(f, backend="eager")(_empty_tensor_queue(), x), + torch.compile(f, backend=backend)(_empty_tensor_queue(), x), ) - def test_compile_obj_closure(self): + @parametrize("backend", ["eager", "aot_eager"]) + def test_compile_obj_closure(self, backend): def f(x): def inner_f(x): tq.push(x.sin()) @@ -1172,7 +1192,8 @@ def inner_f(x): x = torch.randn(3, 2) _assertEqualScriptObject(self, f(x), opt_f(x)) - def test_compile_global_obj(self): + @parametrize("backend", ["eager", "aot_eager"]) + def test_compile_global_obj(self, backend): global _TENSOR_QUEUE_GLOBAL_TEST _TENSOR_QUEUE_GLOBAL_TEST = _empty_tensor_queue() @@ -1180,7 +1201,7 @@ def f(x): _TENSOR_QUEUE_GLOBAL_TEST.push(x.sin()) return _TENSOR_QUEUE_GLOBAL_TEST.pop(), _TENSOR_QUEUE_GLOBAL_TEST - opt_f = torch.compile(f, backend="eager") + opt_f = torch.compile(f, backend=backend) x = torch.randn(3, 2) eager_ret = f(x) opt_ret = opt_f(x) @@ -1207,8 +1228,10 @@ def f(tq, x): ) self.assertEqual(cnt.frame_count, 4) - def test_compile_obj_attributes(self): - backend = EagerAndRecordGraphs() + @parametrize("backend", ["eager", "aot_eager"]) + def test_compile_obj_attributes(self, backend): + if backend == "eager": + backend = EagerAndRecordGraphs() class Model(torch.nn.Module): def __init__(self): @@ -1222,21 +1245,23 @@ def forward(self, x): x = torch.randn(2, 3) opt_f = torch.compile(Model(), backend=backend) _assertEqualScriptObject(self, Model()(x), opt_f(x)) - self.assertEqual(len(backend.graphs), 1) - # lifted as input. In the future, we would want to cosolidate this - # with non-strict behavior, where they're set as attributes. - self.assertExpectedInline( - backend.graphs[0].code.strip(), - """\ -def forward(self, L_self_tq : torch.ScriptObject, L_x_ : torch.Tensor): - l_self_tq = L_self_tq - l_x_ = L_x_ - call_torchbind = torch.ops.higher_order.call_torchbind(l_self_tq, 'push', l_x_); l_x_ = None - call_torchbind_1 = torch.ops.higher_order.call_torchbind(l_self_tq, 'pop'); l_self_tq = None - return (call_torchbind_1,)""", - ) + if backend == "eager": + self.assertEqual(len(backend.graphs), 1) + # lifted as input. In the future, we would want to cosolidate this + # with non-strict behavior, where they're set as attributes. + self.assertExpectedInline( + backend.graphs[0].code.strip(), + """\ + def forward(self, L_self_tq : torch.ScriptObject, L_x_ : torch.Tensor): + l_self_tq = L_self_tq + l_x_ = L_x_ + call_torchbind = torch.ops.higher_order.call_torchbind(l_self_tq, 'push', l_x_); l_x_ = None + call_torchbind_1 = torch.ops.higher_order.call_torchbind(l_self_tq, 'pop'); l_self_tq = None + return (call_torchbind_1,)""", + ) - def test_compile_obj_torchbind_op(self): + @parametrize("backend", ["eager", "aot_eager"]) + def test_compile_obj_torchbind_op(self, backend): def f(tq, x): torch.ops._TorchScriptTesting.queue_push(tq, x.cos()) torch.ops._TorchScriptTesting.queue_push(tq, x.cos() + 1) @@ -1244,7 +1269,7 @@ def f(tq, x): torch.ops._TorchScriptTesting.queue_push(tq, x.sin()) return tq.pop(), tq.pop() + tq.size(), tq - opt_f = torch.compile(f, backend="eager") + opt_f = torch.compile(f, backend=backend) x = torch.randn(2) _assertEqualScriptObject( self, f(_empty_tensor_queue(), x), opt_f(_empty_tensor_queue(), x) @@ -1302,6 +1327,7 @@ def __obj_unflatten__(cls, flattend_foo): instantiate_parametrized_tests(TestExportTorchbind) +instantiate_parametrized_tests(TestCompileTorchbind) if __name__ == "__main__": run_tests() diff --git a/torch/_functorch/_aot_autograd/utils.py b/torch/_functorch/_aot_autograd/utils.py index e23a32f10cc4..a479dd2712a4 100644 --- a/torch/_functorch/_aot_autograd/utils.py +++ b/torch/_functorch/_aot_autograd/utils.py @@ -25,6 +25,7 @@ type(None), *py_sym_types, FakeScriptObject, + torch.ScriptObject, ] original_zip = zip diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index 1c4fff02220d..4dc854781e40 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -493,6 +493,12 @@ def convert(idx, x): return shape_env.create_symintnode( shape_env.create_symbol(x, source), hint=x, source=source ) + if isinstance( + x, torch.ScriptObject + ) and torch._library.fake_class_registry.has_fake_class( + x._type().qualified_name() + ): + return torch._library.fake_class_registry.to_fake_obj(fake_mode, x) if not isinstance(x, torch.Tensor): return x if isinstance(x, FakeTensor): From c27882ffa8c1c7e4cf8ebc6c2f879e5b6c8814ad Mon Sep 17 00:00:00 2001 From: Yidi Wu Date: Thu, 30 May 2024 14:23:41 -0700 Subject: [PATCH 279/706] [torchbind] always fakify script object by default in non-strict export (#127116) This diff can be risky for internal tests: any torchbind class that hasn't registered a fake class will fail and we should fix them. We've gained some confidence that this can work e2e by implementing FakeTensorQueue for TBE models in sigmoid with [D54210823](https://www.internalfb.com/diff/D54210823). Differential Revision: [D57991002](https://our.internmc.facebook.com/intern/diff/D57991002) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127116 Approved by: https://github.com/zou3519 ghstack dependencies: #127113, #127114 --- test/inductor/test_torchbind.py | 23 +++-------------------- torch/_export/non_strict_utils.py | 2 -- torch/_functorch/aot_autograd.py | 6 +----- 3 files changed, 4 insertions(+), 27 deletions(-) diff --git a/test/inductor/test_torchbind.py b/test/inductor/test_torchbind.py index e1bb0ad36d0b..fd8f15d8212c 100644 --- a/test/inductor/test_torchbind.py +++ b/test/inductor/test_torchbind.py @@ -1,5 +1,4 @@ # Owner(s): ["module: functorch"] -import unittest import torch import torch._dynamo @@ -8,30 +7,14 @@ import torch._inductor.decomposition from torch._higher_order_ops.torchbind import enable_torchbind_tracing from torch._inductor.test_case import run_tests, TestCase -from torch.testing._internal.common_utils import ( - find_library_location, - IS_FBCODE, - IS_MACOS, - IS_SANDCASTLE, - IS_WINDOWS, -) + +from torch.testing._internal.torchbind_impls import init_torchbind_implementations class TestTorchbind(TestCase): def setUp(self): super().setUp() - if IS_MACOS: - raise unittest.SkipTest("non-portable load_library call used in test") - elif IS_SANDCASTLE or IS_FBCODE: - torch.ops.load_library( - "//caffe2/test/cpp/jit:test_custom_class_registrations" - ) - elif IS_WINDOWS: - lib_file_path = find_library_location("torchbind_test.dll") - torch.ops.load_library(str(lib_file_path)) - else: - lib_file_path = find_library_location("libtorchbind_test.so") - torch.ops.load_library(str(lib_file_path)) + init_torchbind_implementations() def get_exported_model(self): """ diff --git a/torch/_export/non_strict_utils.py b/torch/_export/non_strict_utils.py index d15cb29f28df..638a7db7e537 100644 --- a/torch/_export/non_strict_utils.py +++ b/torch/_export/non_strict_utils.py @@ -444,8 +444,6 @@ def _fakify_script_objects( fake_to_real = {} def _maybe_fakify_obj(obj): - if not torch._library.fake_class_registry.has_fake_class(obj._type().qualified_name()): # type: ignore[attr-defined] - return obj fake_obj = torch._library.fake_class_registry.to_fake_obj(fake_mode, obj) fake_to_real[fake_obj] = obj return fake_obj diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index 4dc854781e40..a339b63a3ea8 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -493,11 +493,7 @@ def convert(idx, x): return shape_env.create_symintnode( shape_env.create_symbol(x, source), hint=x, source=source ) - if isinstance( - x, torch.ScriptObject - ) and torch._library.fake_class_registry.has_fake_class( - x._type().qualified_name() - ): + if isinstance(x, torch.ScriptObject): return torch._library.fake_class_registry.to_fake_obj(fake_mode, x) if not isinstance(x, torch.Tensor): return x From 406532f8649090f78c7e1f8dbc6b48d135ff71a7 Mon Sep 17 00:00:00 2001 From: Xiaodong Wang Date: Mon, 3 Jun 2024 21:46:50 +0000 Subject: [PATCH 280/706] [AMD] Fix power_draw api (#127729) Summary: average_socket_power only gives me NA. So we need to change it to current_socket_power Test Plan: Before `torch.cuda.power_draw` gives me NA, after it gives me the right power reading (e.g.441) Differential Revision: D58047484 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127729 Approved by: https://github.com/nmacchioni, https://github.com/eqy --- torch/cuda/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index 7cbb53012fe1..2f2784f26c7a 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -1043,7 +1043,7 @@ def _get_amdsmi_temperature(device: Optional[Union[Device, int]] = None) -> int: def _get_amdsmi_power_draw(device: Optional[Union[Device, int]] = None) -> int: handle = _get_amdsmi_handler(device) - return amdsmi.amdsmi_get_power_info(handle)["average_socket_power"] + return amdsmi.amdsmi_get_power_info(handle)["current_socket_power"] def _get_amdsmi_clock_rate(device: Optional[Union[Device, int]] = None) -> int: From 01fc22056a3d7f3c9c0852826ec5dab17c0d0060 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Mon, 3 Jun 2024 22:01:46 +0000 Subject: [PATCH 281/706] [BE] enable UFMT for `torch/masked/` (#127715) Part of #123062 - #123062 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127715 Approved by: https://github.com/cpuhrsch --- .lintrunner.toml | 12 --- torch/masked/__init__.py | 33 +++---- torch/masked/_ops.py | 104 ++++++++++----------- torch/masked/maskedtensor/_ops_refs.py | 111 +++++++++++++++-------- torch/masked/maskedtensor/binary.py | 30 +++--- torch/masked/maskedtensor/core.py | 47 +++++++--- torch/masked/maskedtensor/creation.py | 11 ++- torch/masked/maskedtensor/passthrough.py | 1 + torch/masked/maskedtensor/reductions.py | 2 + torch/masked/maskedtensor/unary.py | 21 +++-- 10 files changed, 211 insertions(+), 161 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index eca2af96b761..8f7be27ece84 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1674,18 +1674,6 @@ exclude_patterns = [ 'torch/hub.py', 'torch/library.py', 'torch/linalg/__init__.py', - # UFMT causes import cycle on masked - 'torch/masked/__init__.py', - 'torch/masked/_docs.py', - 'torch/masked/_ops.py', - 'torch/masked/maskedtensor/__init__.py', - 'torch/masked/maskedtensor/_ops_refs.py', - 'torch/masked/maskedtensor/binary.py', - 'torch/masked/maskedtensor/core.py', - 'torch/masked/maskedtensor/creation.py', - 'torch/masked/maskedtensor/passthrough.py', - 'torch/masked/maskedtensor/reductions.py', - 'torch/masked/maskedtensor/unary.py', 'torch/monitor/__init__.py', 'torch/nested/__init__.py', 'torch/nn/__init__.py', diff --git a/torch/masked/__init__.py b/torch/masked/__init__.py index e0193416ed2f..18d1b9f9e283 100644 --- a/torch/masked/__init__.py +++ b/torch/masked/__init__.py @@ -1,33 +1,34 @@ -from .maskedtensor.core import is_masked_tensor, MaskedTensor -from .maskedtensor.creation import as_masked_tensor, masked_tensor -from ._ops import ( +from torch.masked._ops import ( _canonical_dim, + _combine_input_and_mask, _generate_docstring, - _reduction_identity, - _where, _input_mask, _output_mask, - _combine_input_and_mask, - sum, - prod, - cumsum, - cumprod, + _reduction_identity, + _where, amax, amin, argmax, argmin, + cumprod, + cumsum, + log_softmax, + logaddexp, + logsumexp, mean, median, - logsumexp, - logaddexp, norm, - var, - std, + normalize, + prod, softmax, - log_softmax, softmin, - normalize, + std, + sum, + var, ) +from torch.masked.maskedtensor.core import is_masked_tensor, MaskedTensor +from torch.masked.maskedtensor.creation import as_masked_tensor, masked_tensor + __all__ = [ "as_masked_tensor", diff --git a/torch/masked/_ops.py b/torch/masked/_ops.py index b7872a6d4cf4..0c082f7cd01f 100644 --- a/torch/masked/_ops.py +++ b/torch/masked/_ops.py @@ -1,15 +1,12 @@ - import warnings - -# A workaround to support both TorchScript and MyPy: from typing import Any, List, Optional, Tuple, TYPE_CHECKING, Union import torch -from torch import Tensor -from torch.masked import as_masked_tensor, is_masked_tensor, MaskedTensor -from . import _docs +from torch import sym_float, Tensor from torch._prims_common import corresponding_real_dtype -from torch import sym_float +from torch.masked import _docs +from torch.masked.maskedtensor.core import is_masked_tensor, MaskedTensor +from torch.masked.maskedtensor.creation import as_masked_tensor if TYPE_CHECKING: from torch.types import _dtype as DType @@ -469,7 +466,7 @@ def _canonical_dim(dim: DimOrDims, ndim: int) -> Tuple[int, ...]: raise RuntimeError(f"dim={d} appears multiple times in the list of dims") if d >= ndim or d < -ndim: raise IndexError( - f"Dimension out of range (expected to be in range of [{-ndim}, {ndim-1}], but got {d})" + f"Dimension out of range (expected to be in range of [{-ndim}, {ndim - 1}], but got {d})" ) dims.append(d % ndim) return tuple(sorted(dims)) @@ -1420,7 +1417,6 @@ def median( dtype: Optional[DType] = None, mask: Optional[Tensor] = None, ) -> Tensor: - """\ {reduction_signature} {reduction_descr} @@ -1487,46 +1483,45 @@ def logaddexp( ) -> Tensor: """logaddexp(input, other, *, dtype=None, input_mask=None, other_mask=None) -> Tensor -Returns logaddexp of all the elements in the :attr:`input` and the :attr:`other` -tensor. The :attr:`input` elements are masked out according to the boolean tensor -:attr:`input_mask` and the attr:`other` elements are masked out according to the boolean tensor -:attr:`other_mask`. - -The shapes of a mask tensor and the tensor to be masked -don't need to match, but they must be :ref:`broadcastable -` and the dimensionality of the mask -tensor must not be greater than of the tensor to be masked. - -Args: - input (Tensor): the input tensor - other (Tensor): the second input tensor - -Keyword args: - dtype (:class:`torch.dtype`, optional): the desired data type - of returned tensor. If specified, the output tensor is - casted to :attr:`dtype` after the operation is - performed. Default: None. - input_mask (:class:`torch.Tensor`, optional): the boolean tensor - containing the binary mask of validity of :attr:`input` tensor elements. - Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``. - other_mask (:class:`torch.Tensor`, optional): the boolean tensor - containing the binary mask of validity of :attr:`other` tensor elements. - Default: None that is equivalent to ``torch.ones(other.shape, dtype=torch.bool)``. - -Example:: - - >>> input = torch.tensor([-100.0, -200, -300]) - >>> input - tensor([-100., -200., -300.]) - >>> other = torch.tensor([-1.0, -2, -3]) - >>> other - tensor([-1., -2., -3.]) - >>> mask = torch.tensor([True, False, True]) - >>> mask - tensor([ True, False, True]) - >>> torch.masked._ops.logaddexp(input, other, input_mask=mask, other_mask=mask) - tensor([-1., -inf, -3.]) -""" + Returns logaddexp of all the elements in the :attr:`input` and the :attr:`other` + tensor. The :attr:`input` elements are masked out according to the boolean tensor + :attr:`input_mask` and the attr:`other` elements are masked out according to the boolean tensor + :attr:`other_mask`. + + The shapes of a mask tensor and the tensor to be masked + don't need to match, but they must be :ref:`broadcastable + ` and the dimensionality of the mask + tensor must not be greater than of the tensor to be masked. + + Args: + input (Tensor): the input tensor + other (Tensor): the second input tensor + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the output tensor is + casted to :attr:`dtype` after the operation is + performed. Default: None. + input_mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of :attr:`input` tensor elements. + Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``. + other_mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of :attr:`other` tensor elements. + Default: None that is equivalent to ``torch.ones(other.shape, dtype=torch.bool)``. + + Example:: + + >>> input = torch.tensor([-100.0, -200, -300]) + >>> input + tensor([-100., -200., -300.]) + >>> other = torch.tensor([-1.0, -2, -3]) + >>> other + tensor([-1., -2., -3.]) + >>> mask = torch.tensor([True, False, True]) + >>> mask + tensor([ True, False, True]) + >>> torch.masked._ops.logaddexp(input, other, input_mask=mask, other_mask=mask) + tensor([-1., -inf, -3.])""" if dtype is None: dtype = input.dtype if input.layout == torch.strided and other.layout == torch.strided: @@ -1586,7 +1581,9 @@ def _std_var( mask: Optional[Tensor], take_sqrt: Optional[bool], ) -> Tensor: - assert (unbiased is None or correction_opt is None), "Only one of unbiased and correction may be given" + assert ( + unbiased is None or correction_opt is None + ), "Only one of unbiased and correction may be given" correction = 1.0 if unbiased is not None: correction = 1.0 if unbiased else 0.0 @@ -1632,8 +1629,11 @@ def _std_var( if not keepdim: count = count.reshape(total.shape) if correction != 0: - real_dtype = (corresponding_real_dtype(compute_dtype) - if compute_dtype.is_complex else compute_dtype) + real_dtype = ( + corresponding_real_dtype(compute_dtype) + if compute_dtype.is_complex + else compute_dtype + ) count = count.to(real_dtype) count = torch.subtract(count, correction) count = torch.maximum(count, count.new_zeros([])) diff --git a/torch/masked/maskedtensor/_ops_refs.py b/torch/masked/maskedtensor/_ops_refs.py index 81a890af5d65..69c947bc262d 100644 --- a/torch/masked/maskedtensor/_ops_refs.py +++ b/torch/masked/maskedtensor/_ops_refs.py @@ -1,43 +1,45 @@ # Copyright (c) Meta Platforms, Inc. and affiliates from functools import partial -from typing import Callable, Any, Dict, TYPE_CHECKING +from typing import Any, Callable, Dict, TYPE_CHECKING + import torch if TYPE_CHECKING: import torch._ops -from .binary import ( - _apply_native_binary, - NATIVE_BINARY_FNS, - NATIVE_INPLACE_BINARY_FNS, -) -from .core import is_masked_tensor, MaskedTensor, _get_data, _masks_match, _maybe_get_mask -from .passthrough import ( - _apply_pass_through_fn, - PASSTHROUGH_FNS +from .binary import _apply_native_binary, NATIVE_BINARY_FNS, NATIVE_INPLACE_BINARY_FNS +from .core import ( + _get_data, + _masks_match, + _maybe_get_mask, + is_masked_tensor, + MaskedTensor, ) +from .passthrough import _apply_pass_through_fn, PASSTHROUGH_FNS from .reductions import ( _apply_reduction, NATIVE_REDUCE_FNS, - TORCH_REDUCE_FNS, TENSOR_REDUCE_FNS, + TORCH_REDUCE_FNS, ) -from .unary import ( - _apply_native_unary, - NATIVE_UNARY_FNS, - NATIVE_INPLACE_UNARY_FNS, -) +from .unary import _apply_native_unary, NATIVE_INPLACE_UNARY_FNS, NATIVE_UNARY_FNS __all__ = [] # type: ignore[var-annotated] -def _check_args_kwargs_length(args, kwargs, error_prefix, len_args=None, len_kwargs=None): +def _check_args_kwargs_length( + args, kwargs, error_prefix, len_args=None, len_kwargs=None +): if len_args is not None and len_args != len(args): - raise ValueError(f"{error_prefix}: len(args) must be {len_args} but got {len(args)}") + raise ValueError( + f"{error_prefix}: len(args) must be {len_args} but got {len(args)}" + ) if len_kwargs is not None and len_kwargs != len(kwargs): - raise ValueError(f"{error_prefix}: len(kwargs) must be {len_kwargs} but got {len(kwargs)}") + raise ValueError( + f"{error_prefix}: len(kwargs) must be {len_kwargs} but got {len(kwargs)}" + ) class _MaskedContiguous(torch.autograd.Function): @@ -116,7 +118,9 @@ def forward(ctx, input): raise ValueError("MaskedToSparseCsr forward: input must be a MaskedTensor.") if input._masked_data.ndim != 2: - raise ValueError(f"Only 2D tensors can be converted to the SparseCsr layout but got shape: {input._masked_data.size()}") + raise ValueError( + f"Only 2D tensors can be converted to the SparseCsr layout but got shape: {input._masked_data.size()}" + ) if input.layout == torch.sparse_csr: return input @@ -157,7 +161,11 @@ def masked_out_like(mt): _MASKEDTENSOR_FUNCTION_TABLE = {} _function_fn_apply_map = { - (tuple(NATIVE_REDUCE_FNS), tuple(TORCH_REDUCE_FNS), tuple(TENSOR_REDUCE_FNS)): _apply_reduction, + ( + tuple(NATIVE_REDUCE_FNS), + tuple(TORCH_REDUCE_FNS), + tuple(TENSOR_REDUCE_FNS), + ): _apply_reduction, } for fn_map_list, apply_fn in _function_fn_apply_map.items(): @@ -177,9 +185,11 @@ def register_function_func(ops): def foo(func, *args, **kwargs): """ + def wrapper(func): for op in ops: _MASKEDTENSOR_FUNCTION_TABLE[op] = partial(func, op) + return wrapper @@ -190,7 +200,9 @@ def _general_function_reductions(func, *args, **kwargs): @register_function_func([torch.Tensor.where, torch.where]) def _function_where(func, *args, **kwargs): - _check_args_kwargs_length(args, kwargs, "__torch_function__, torch.where", len_args=3, len_kwargs=0) + _check_args_kwargs_length( + args, kwargs, "__torch_function__, torch.where", len_args=3, len_kwargs=0 + ) return _MaskedWhere.apply(*args) @@ -216,6 +228,7 @@ def _function_to_sparse_csr(func, *args, **kwargs): _MASKEDTENSOR_DISPATCH_TABLE: Dict["torch._ops.OpOverload", Callable[..., Any]] = {} + def register_dispatch_func(aten_ops): """ Used for registering a new __torch_dispatch__ function to MaskedTensor @@ -227,9 +240,11 @@ def register_dispatch_func(aten_ops): def foo(func, *args, **kwargs): """ + def wrapper(func): for aten_op in aten_ops: _MASKEDTENSOR_DISPATCH_TABLE[aten_op] = partial(func, aten_op) + return wrapper @@ -272,9 +287,7 @@ def layout(func, *args, **kwargs): def is_contiguous(func, *args, **kwargs): data = _get_data(args[0]) if data.is_sparse: - raise ValueError( - "MaskedTensors with sparse data do not have is_contiguous" - ) + raise ValueError("MaskedTensors with sparse data do not have is_contiguous") return func(data, *args[1:], **kwargs) @@ -301,9 +314,7 @@ def is_non_overlapping_and_dense(func, *args, **kwargs): @register_dispatch_func([torch.ops.aten.contiguous]) def contiguous(func, *args, **kwargs): if _get_data(args[0]).is_sparse: - raise ValueError( - "MaskedTensors with sparse data do not have contiguous" - ) + raise ValueError("MaskedTensors with sparse data do not have contiguous") return _MaskedContiguous.apply(args[0]) @@ -313,9 +324,13 @@ def new_empty_strided(func, *args, **kwargs): data = _get_data(args[0]) mask = _maybe_get_mask(args[0]) if tuple(args[1]) != tuple(data.size()): - raise ValueError(f"__torch_dispatch__, {func}: args[1] expected to be the same as data.size()") + raise ValueError( + f"__torch_dispatch__, {func}: args[1] expected to be the same as data.size()" + ) if tuple(args[2]) != tuple(data.stride()): - raise ValueError(f"__torch_dispatch__, {func}: args[2] expected to be the same as data.stride()") + raise ValueError( + f"__torch_dispatch__, {func}: args[2] expected to be the same as data.stride()" + ) return MaskedTensor(func(data, args[1], args[2], **kwargs), mask) @@ -339,7 +354,9 @@ def _to_copy(func, *args, **kwargs): @register_dispatch_func([torch.ops.aten._softmax]) def _softmax(func, *args, **kwargs): - _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=3, len_kwargs=0) + _check_args_kwargs_length( + args, kwargs, f"__torch_dispatch__, {func}", len_args=3, len_kwargs=0 + ) data = _get_data(args[0]) mask = _maybe_get_mask(args[0]) result_data = torch.ops.aten._masked_softmax(data, ~mask, args[1], 2) @@ -359,7 +376,9 @@ def _softmax_backward_data(func, *args, **kwargs): grad, output, dim, input_dtype = args if is_masked_tensor(grad) and is_masked_tensor(output): if not _masks_match(grad, output): - raise ValueError("__torch_dispatch__, {func}: expected the masks of grad and output to match") + raise ValueError( + "__torch_dispatch__, {func}: expected the masks of grad and output to match" + ) grad_data = _get_data(grad) new_grad_data = torch.ops.aten._masked_softmax_backward( grad_data, @@ -370,7 +389,9 @@ def _softmax_backward_data(func, *args, **kwargs): res = MaskedTensor(new_grad_data, _maybe_get_mask(grad)) return res else: - raise ValueError(f"__torch_dispatch__, {func}: grad and output must both be MaskedTensors") + raise ValueError( + f"__torch_dispatch__, {func}: grad and output must both be MaskedTensors" + ) @register_dispatch_func([torch.ops.aten.copy_]) @@ -384,7 +405,9 @@ def copy_(func, *args, **kwargs): @register_dispatch_func([torch.ops.aten.where]) def where(func, *args, **kwargs): - _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=3, len_kwargs=0) + _check_args_kwargs_length( + args, kwargs, f"__torch_dispatch__, {func}", len_args=3, len_kwargs=0 + ) if not torch.is_tensor(args[0]): raise ValueError("__torch_dispatch__, {func}: expected args[0] to be a tensor") mx = args[1] @@ -400,7 +423,9 @@ def where(func, *args, **kwargs): @register_dispatch_func([torch.ops.aten._to_sparse]) def _to_sparse(func, *args, **kwargs): - _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0) + _check_args_kwargs_length( + args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0 + ) if not torch.is_tensor(args[0]): raise TypeError("__torch_dispatch__, {func}: expected args[0] to be a tensor") mt = args[0] @@ -415,7 +440,9 @@ def _to_sparse(func, *args, **kwargs): @register_dispatch_func([torch.ops.aten._to_sparse_csr]) def _to_sparse_csr(func, *args, **kwargs): - _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0) + _check_args_kwargs_length( + args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0 + ) if not torch.is_tensor(args[0]): raise ValueError("__torch_dispatch__, {func}: expected args[0] to be a tensor") mt = args[0] @@ -430,7 +457,9 @@ def _to_sparse_csr(func, *args, **kwargs): @register_dispatch_func([torch.ops.aten._to_dense]) def _to_dense(func, *args, **kwargs): - _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0) + _check_args_kwargs_length( + args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0 + ) if not torch.is_tensor(args[0]): raise ValueError("__torch_dispatch__, {func}: expected args[0] to be a tensor") mt = args[0] @@ -444,14 +473,18 @@ def _to_dense(func, *args, **kwargs): @register_dispatch_func([torch.ops.aten._indices]) def _indices(func, *args, **kwargs): # Assumes data is sparse - _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0) + _check_args_kwargs_length( + args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0 + ) data = _get_data(args[0]).indices() return MaskedTensor(data, torch.ones_like(data).bool()) @register_dispatch_func([torch.ops.aten._values]) def _values(func, *args, **kwargs): - _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0) + _check_args_kwargs_length( + args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0 + ) data = _get_data(args[0]).values() return MaskedTensor(data, torch.ones_like(data).bool()) diff --git a/torch/masked/maskedtensor/binary.py b/torch/masked/maskedtensor/binary.py index 087ea95916e5..b035678f73a6 100644 --- a/torch/masked/maskedtensor/binary.py +++ b/torch/masked/maskedtensor/binary.py @@ -2,7 +2,14 @@ import torch -from .core import _map_mt_args_kwargs, _masks_match, _tensors_match, _wrap_result, is_masked_tensor +from .core import ( + _map_mt_args_kwargs, + _masks_match, + _tensors_match, + _wrap_result, + is_masked_tensor, +) + __all__ = [] # type: ignore[var-annotated] @@ -79,25 +86,22 @@ def _binary_helper(fn, args, kwargs, inplace): raise ValueError("len(kwargs) must equal 0") for a in args[2:]: if torch.is_tensor(a): - raise TypeError("MaskedTensor binary ops do not support Tensor arguments aside from the lhs and rhs") + raise TypeError( + "MaskedTensor binary ops do not support Tensor arguments aside from the lhs and rhs" + ) if not _masks_match(*args[:2]): raise ValueError( "Input masks must match. If you need support for this, please open an issue on Github." ) - data_args, data_kwargs = _map_mt_args_kwargs( - args, kwargs, lambda x: x.get_data() - ) - mask_args, mask_kwargs = _map_mt_args_kwargs( - args, kwargs, lambda x: x.get_mask() - ) + data_args, data_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_data()) + mask_args, mask_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_mask()) args0_layout = data_args[0].layout same_layout = ( - (torch.is_tensor(data_args[1]) or is_masked_tensor(data_args[1])) and - (args0_layout == data_args[1].layout) - ) + torch.is_tensor(data_args[1]) or is_masked_tensor(data_args[1]) + ) and (args0_layout == data_args[1].layout) if args0_layout == torch.sparse_coo: if same_layout: @@ -106,7 +110,9 @@ def _binary_helper(fn, args, kwargs, inplace): "sparse_coo indices must match. If you need support for this, please open an issue on Github." ) if data_args[0].size() != data_args[1].size(): - raise ValueError("input1 and input2 must have the same size for binary functions.") + raise ValueError( + "input1 and input2 must have the same size for binary functions." + ) data_args[1] = data_args[1].values() diff --git a/torch/masked/maskedtensor/core.py b/torch/masked/maskedtensor/core.py index d2002048edd9..4574fed9c0d6 100644 --- a/torch/masked/maskedtensor/core.py +++ b/torch/masked/maskedtensor/core.py @@ -13,7 +13,7 @@ def is_masked_tensor(a): - r""" Returns True if the input is a MaskedTensor, else False + r"""Returns True if the input is a MaskedTensor, else False Args: a: any input @@ -35,7 +35,9 @@ def _tensors_match(a, b, exact=True, rtol=1e-05, atol=1e-08): if is_masked_tensor(a) or is_masked_tensor(b): raise ValueError("Neither `a` nor `b` can be a MaskedTensor.") if a.layout != b.layout: - raise ValueError(f"`a` and `b` must have the same layout. Got {a.layout} and {b.layout}") + raise ValueError( + f"`a` and `b` must have the same layout. Got {a.layout} and {b.layout}" + ) if a.dtype != b.dtype: b = b.type(a.dtype) @@ -108,9 +110,7 @@ def _masked_tensor_str(data, mask, formatter): formatter.format(d.item()) if isinstance(d.item(), float) else str(d.item()) for d in data ] - max_len = max( - 8 if x[1] else len(x[0]) for x in zip(formatted_elements, ~mask) - ) + max_len = max(8 if x[1] else len(x[0]) for x in zip(formatted_elements, ~mask)) return ( "[" + ", ".join( @@ -153,13 +153,21 @@ def __new__(cls, data, mask, requires_grad=False): kwargs["requires_grad"] = requires_grad kwargs["dispatch_sizes_strides_policy"] = "strides" kwargs["dispatch_layout"] = True - warnings.warn(("The PyTorch API of MaskedTensors is in prototype stage " - "and will change in the near future. Please open a Github issue " - "for features requests and see our documentation on the torch.masked " - "module for further information about the project."), UserWarning) + warnings.warn( + ( + "The PyTorch API of MaskedTensors is in prototype stage " + "and will change in the near future. Please open a Github issue " + "for features requests and see our documentation on the torch.masked " + "module for further information about the project." + ), + UserWarning, + ) if data.requires_grad: - warnings.warn("It is not recommended to create a MaskedTensor with a tensor that requires_grad. " - "To avoid this, you can use data.clone().detach()", UserWarning) + warnings.warn( + "It is not recommended to create a MaskedTensor with a tensor that requires_grad. " + "To avoid this, you can use data.clone().detach()", + UserWarning, + ) return torch.Tensor._make_wrapper_subclass(cls, data.size(), **kwargs) # type: ignore[attr-defined] def _preprocess_data(self, data, mask): @@ -184,17 +192,23 @@ def _validate_members(self): data = self._masked_data mask = self.get_mask() if type(data) != type(mask): - raise TypeError(f"data and mask must have the same type. Got {type(data)} and {type(mask)}") + raise TypeError( + f"data and mask must have the same type. Got {type(data)} and {type(mask)}" + ) if data.layout not in {torch.strided, torch.sparse_coo, torch.sparse_csr}: raise TypeError(f"data layout of {data.layout} is not supported.") if data.layout == torch.sparse_coo: if not _tensors_match(data.indices(), mask.indices(), exact=True): - raise ValueError("data and mask are both sparse COO tensors but do not have the same indices.") + raise ValueError( + "data and mask are both sparse COO tensors but do not have the same indices." + ) elif data.layout == torch.sparse_csr: if not _tensors_match( data.crow_indices(), mask.crow_indices(), exact=True ) or not _tensors_match(data.col_indices(), mask.col_indices(), exact=True): - raise ValueError("data and mask are both sparse CSR tensors but do not share either crow or col indices.") + raise ValueError( + "data and mask are both sparse CSR tensors but do not share either crow or col indices." + ) if mask.dtype != torch.bool: raise TypeError("mask must have dtype bool.") if not ( @@ -219,7 +233,8 @@ def __init__(self, data, mask, requires_grad=False): @staticmethod def _from_values(data, mask): - """ Differentiable constructor for MaskedTensor """ + """Differentiable constructor for MaskedTensor""" + class Constructor(torch.autograd.Function): @staticmethod def forward(ctx, data, mask): @@ -265,6 +280,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): kwargs = kwargs or {} from ._ops_refs import _MASKEDTENSOR_FUNCTION_TABLE + if func in _MASKEDTENSOR_FUNCTION_TABLE: return _MASKEDTENSOR_FUNCTION_TABLE[func](*args, **kwargs) @@ -286,6 +302,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): func = func.overloadpacket from ._ops_refs import _MASKEDTENSOR_DISPATCH_TABLE + if func in _MASKEDTENSOR_DISPATCH_TABLE: return _MASKEDTENSOR_DISPATCH_TABLE[func](*args, **kwargs) diff --git a/torch/masked/maskedtensor/creation.py b/torch/masked/maskedtensor/creation.py index 861984a21e1c..6b490edfc058 100644 --- a/torch/masked/maskedtensor/creation.py +++ b/torch/masked/maskedtensor/creation.py @@ -2,20 +2,21 @@ from .core import MaskedTensor + __all__ = [ "as_masked_tensor", "masked_tensor", ] -"""" -These two factory functions are intended to mirror - torch.tensor - guaranteed to be a leaf node - torch.as_tensor - differentiable constructor that preserves the autograd history -""" +# These two factory functions are intended to mirror +# torch.tensor - guaranteed to be a leaf node +# torch.as_tensor - differentiable constructor that preserves the autograd history + def masked_tensor(data, mask, requires_grad=False): return MaskedTensor(data, mask, requires_grad) + def as_masked_tensor(data, mask): return MaskedTensor._from_values(data, mask) diff --git a/torch/masked/maskedtensor/passthrough.py b/torch/masked/maskedtensor/passthrough.py index 91c9e5f81830..d8c87a9c2110 100644 --- a/torch/masked/maskedtensor/passthrough.py +++ b/torch/masked/maskedtensor/passthrough.py @@ -10,6 +10,7 @@ from .core import _map_mt_args_kwargs, _wrap_result + __all__ = [] # type: ignore[var-annotated] diff --git a/torch/masked/maskedtensor/reductions.py b/torch/masked/maskedtensor/reductions.py index 737f4b240beb..d36df2715c0b 100644 --- a/torch/masked/maskedtensor/reductions.py +++ b/torch/masked/maskedtensor/reductions.py @@ -7,6 +7,7 @@ from .core import is_masked_tensor from .creation import as_masked_tensor, masked_tensor + __all__ = [] # type: ignore[var-annotated] @@ -159,6 +160,7 @@ def grad_reduce(*args, **kwargs): TORCH_REDUCE_FNS = list(TORCH_REDUCE_MAP.keys()) TENSOR_REDUCE_FNS = list(TENSOR_REDUCE_MAP.keys()) + def _is_reduction(fn): return fn in NATIVE_REDUCE_MAP or fn in TORCH_REDUCE_MAP or fn in TENSOR_REDUCE_MAP diff --git a/torch/masked/maskedtensor/unary.py b/torch/masked/maskedtensor/unary.py index b3d5c136bfd4..4bfe987ef004 100644 --- a/torch/masked/maskedtensor/unary.py +++ b/torch/masked/maskedtensor/unary.py @@ -4,6 +4,7 @@ from .core import _map_mt_args_kwargs, _wrap_result + __all__ = [] # type: ignore[var-annotated] @@ -108,18 +109,18 @@ def _unary_helper(fn, args, kwargs, inplace): if len(kwargs) != 0: - raise ValueError("MaskedTensor unary ops require that len(kwargs) == 0. " - "If you need support for this, please open an issue on Github.") + raise ValueError( + "MaskedTensor unary ops require that len(kwargs) == 0. " + "If you need support for this, please open an issue on Github." + ) for a in args[1:]: if torch.is_tensor(a): - raise TypeError("MaskedTensor unary ops do not support additional Tensor arguments") - - mask_args, mask_kwargs = _map_mt_args_kwargs( - args, kwargs, lambda x: x._masked_mask - ) - data_args, data_kwargs = _map_mt_args_kwargs( - args, kwargs, lambda x: x._masked_data - ) + raise TypeError( + "MaskedTensor unary ops do not support additional Tensor arguments" + ) + + mask_args, mask_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x._masked_mask) + data_args, data_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x._masked_data) if args[0].layout == torch.sparse_coo: data_args[0] = data_args[0].coalesce() From 6faa3d5f18bf079fc7568433302b5b145a8ea4ff Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Mon, 3 Jun 2024 10:06:02 -0700 Subject: [PATCH 282/706] Onboard ARM bfloat16 to gemm-by-dot-product-for-gemm_transa_ infrastructure (#127477) Summary: This gets us a baseline level of reasonable performance for bfloat16 matrix-vector and matrix-matrix multiplication on my Apple M1. I've intentionally left using intrinsics for future work. Test Plan: Used https://github.com/malfet/llm_experiments/blob/main/benchmarks/benchmark_torch_mm.py (modified to run larger sizes) to benchmark a range of LLM-interesting matrix-vector and matrix-matrix sizes on my Apple M1 Pro. bfloat16 performance is improved across the board (except possibly for very small cases) and now exceeds float32 performance (as it should) for the matrix-vector cases. Before: ``` Matrix-vector: m=8, n=128, k=1 ==================== trans_b torch.float32 0.75 usec trans_b torch.float16 0.71 usec trans_b torch.bfloat16 0.81 usec m=128, n=8, k=1 ==================== trans_b torch.float32 0.75 usec trans_b torch.float16 0.93 usec trans_b torch.bfloat16 0.98 usec m=4096, n=4096, k=1 ==================== trans_b torch.float32 2194.31 usec trans_b torch.float16 661.27 usec trans_b torch.bfloat16 3758.42 usec m=11008, n=4096, k=1 ==================== trans_b torch.float32 5792.04 usec trans_b torch.float16 1789.98 usec trans_b torch.bfloat16 10120.67 usec m=4096, n=11008, k=1 ==================== trans_b torch.float32 6101.22 usec trans_b torch.float16 1927.34 usec trans_b torch.bfloat16 10469.47 usec m=32000, n=4096, k=1 ==================== trans_b torch.float32 18353.20 usec trans_b torch.float16 5161.06 usec trans_b torch.bfloat16 29601.69 usec Matrix-matrix (prompt len 4: m=8, n=128, k=4 ==================== trans_b torch.float32 2.14 usec trans_b torch.float16 0.85 usec trans_b torch.bfloat16 1.19 usec m=128, n=8, k=4 ==================== trans_b torch.float32 1.47 usec trans_b torch.float16 1.85 usec trans_b torch.bfloat16 1.75 usec m=4096, n=4096, k=4 ==================== trans_b torch.float32 4416.40 usec trans_b torch.float16 2688.36 usec trans_b torch.bfloat16 14987.33 usec m=11008, n=4096, k=4 ==================== trans_b torch.float32 6140.24 usec trans_b torch.float16 7467.26 usec trans_b torch.bfloat16 40295.52 usec m=4096, n=11008, k=4 ==================== trans_b torch.float32 6143.10 usec trans_b torch.float16 7298.04 usec trans_b torch.bfloat16 41393.43 usec m=32000, n=4096, k=4 ==================== trans_b torch.float32 17650.72 usec trans_b torch.float16 21346.63 usec trans_b torch.bfloat16 116849.98 usec Matrix-matrix (prompt len 8: m=8, n=128, k=8 ==================== trans_b torch.float32 1.05 usec trans_b torch.float16 1.03 usec trans_b torch.bfloat16 1.69 usec m=128, n=8, k=8 ==================== trans_b torch.float32 2.05 usec trans_b torch.float16 3.08 usec trans_b torch.bfloat16 2.95 usec m=4096, n=4096, k=8 ==================== trans_b torch.float32 2323.99 usec trans_b torch.float16 5265.45 usec trans_b torch.bfloat16 29942.40 usec m=11008, n=4096, k=8 ==================== trans_b torch.float32 6202.01 usec trans_b torch.float16 14677.90 usec trans_b torch.bfloat16 80625.18 usec m=4096, n=11008, k=8 ==================== trans_b torch.float32 6112.05 usec trans_b torch.float16 14340.52 usec trans_b torch.bfloat16 82799.99 usec m=32000, n=4096, k=8 ==================== trans_b torch.float32 17650.65 usec trans_b torch.float16 42551.43 usec trans_b torch.bfloat16 236081.08 usec Matrix-matrix (prompt len 16: m=8, n=128, k=16 ==================== trans_b torch.float32 1.26 usec trans_b torch.float16 1.34 usec trans_b torch.bfloat16 2.69 usec m=128, n=8, k=16 ==================== trans_b torch.float32 1.60 usec trans_b torch.float16 5.81 usec trans_b torch.bfloat16 5.34 usec m=4096, n=4096, k=16 ==================== trans_b torch.float32 2328.05 usec trans_b torch.float16 10526.58 usec trans_b torch.bfloat16 60028.28 usec m=11008, n=4096, k=16 ==================== trans_b torch.float32 6243.35 usec trans_b torch.float16 28505.08 usec trans_b torch.bfloat16 163670.15 usec m=4096, n=11008, k=16 ==================== trans_b torch.float32 5870.11 usec trans_b torch.float16 28597.89 usec trans_b torch.bfloat16 165404.88 usec m=32000, n=4096, k=16 ==================== trans_b torch.float32 17746.27 usec trans_b torch.float16 83393.87 usec trans_b torch.bfloat16 472313.13 usec Matrix-matrix (prompt len 32: m=8, n=128, k=32 ==================== trans_b torch.float32 1.35 usec trans_b torch.float16 2.01 usec trans_b torch.bfloat16 4.68 usec m=128, n=8, k=32 ==================== trans_b torch.float32 1.19 usec trans_b torch.float16 10.98 usec trans_b torch.bfloat16 10.13 usec m=4096, n=4096, k=32 ==================== trans_b torch.float32 2525.29 usec trans_b torch.float16 23106.71 usec trans_b torch.bfloat16 122987.04 usec m=11008, n=4096, k=32 ==================== trans_b torch.float32 6131.34 usec trans_b torch.float16 57537.41 usec trans_b torch.bfloat16 327825.00 usec m=4096, n=11008, k=32 ==================== trans_b torch.float32 6395.01 usec trans_b torch.float16 57456.33 usec trans_b torch.bfloat16 331325.58 usec m=32000, n=4096, k=32 ==================== trans_b torch.float32 19078.68 usec trans_b torch.float16 167735.08 usec trans_b torch.bfloat16 975736.88 usec Matrix-matrix (prompt len 128: m=8, n=128, k=128 ==================== trans_b torch.float32 2.40 usec trans_b torch.float16 6.07 usec trans_b torch.bfloat16 16.83 usec m=128, n=8, k=128 ==================== trans_b torch.float32 1.78 usec trans_b torch.float16 40.35 usec trans_b torch.bfloat16 37.21 usec m=4096, n=4096, k=128 ==================== trans_b torch.float32 4827.60 usec trans_b torch.float16 84341.24 usec trans_b torch.bfloat16 478917.75 usec m=11008, n=4096, k=128 ==================== trans_b torch.float32 11879.96 usec trans_b torch.float16 226484.33 usec trans_b torch.bfloat16 1289465.50 usec m=4096, n=11008, k=128 ==================== trans_b torch.float32 10707.75 usec trans_b torch.float16 229200.58 usec trans_b torch.bfloat16 1327416.67 usec m=32000, n=4096, k=128 ==================== trans_b torch.float32 33306.32 usec trans_b torch.float16 662898.21 usec trans_b torch.bfloat16 3815866.63 usec ``` After: ``` Matrix-vector: m=8, n=128, k=1 ==================== trans_b torch.float32 0.77 usec trans_b torch.float16 0.72 usec trans_b torch.bfloat16 0.77 usec m=128, n=8, k=1 ==================== trans_b torch.float32 0.73 usec trans_b torch.float16 0.93 usec trans_b torch.bfloat16 1.56 usec m=4096, n=4096, k=1 ==================== trans_b torch.float32 2195.22 usec trans_b torch.float16 675.40 usec trans_b torch.bfloat16 1038.29 usec m=11008, n=4096, k=1 ==================== trans_b torch.float32 5980.27 usec trans_b torch.float16 1806.08 usec trans_b torch.bfloat16 2756.46 usec m=4096, n=11008, k=1 ==================== trans_b torch.float32 6339.95 usec trans_b torch.float16 1844.71 usec trans_b torch.bfloat16 2726.52 usec m=32000, n=4096, k=1 ==================== trans_b torch.float32 18137.17 usec trans_b torch.float16 6020.75 usec trans_b torch.bfloat16 8612.89 usec Matrix-matrix (prompt len 4: m=8, n=128, k=4 ==================== trans_b torch.float32 2.24 usec trans_b torch.float16 0.91 usec trans_b torch.bfloat16 1.07 usec m=128, n=8, k=4 ==================== trans_b torch.float32 1.58 usec trans_b torch.float16 1.96 usec trans_b torch.bfloat16 2.11 usec m=4096, n=4096, k=4 ==================== trans_b torch.float32 4583.43 usec trans_b torch.float16 3014.04 usec trans_b torch.bfloat16 4434.04 usec m=11008, n=4096, k=4 ==================== trans_b torch.float32 6245.55 usec trans_b torch.float16 7513.82 usec trans_b torch.bfloat16 11207.80 usec m=4096, n=11008, k=4 ==================== trans_b torch.float32 6096.22 usec trans_b torch.float16 7688.82 usec trans_b torch.bfloat16 11143.72 usec m=32000, n=4096, k=4 ==================== trans_b torch.float32 17982.88 usec trans_b torch.float16 22001.28 usec trans_b torch.bfloat16 32470.62 usec Matrix-matrix (prompt len 8: m=8, n=128, k=8 ==================== trans_b torch.float32 1.05 usec trans_b torch.float16 1.02 usec trans_b torch.bfloat16 1.44 usec m=128, n=8, k=8 ==================== trans_b torch.float32 2.07 usec trans_b torch.float16 3.10 usec trans_b torch.bfloat16 3.38 usec m=4096, n=4096, k=8 ==================== trans_b torch.float32 2245.43 usec trans_b torch.float16 5597.87 usec trans_b torch.bfloat16 8775.08 usec m=11008, n=4096, k=8 ==================== trans_b torch.float32 6227.68 usec trans_b torch.float16 15102.41 usec trans_b torch.bfloat16 22457.37 usec m=4096, n=11008, k=8 ==================== trans_b torch.float32 6082.16 usec trans_b torch.float16 15131.57 usec trans_b torch.bfloat16 21860.15 usec m=32000, n=4096, k=8 ==================== trans_b torch.float32 19659.00 usec trans_b torch.float16 45075.64 usec trans_b torch.bfloat16 67746.75 usec Matrix-matrix (prompt len 16: m=8, n=128, k=16 ==================== trans_b torch.float32 1.31 usec trans_b torch.float16 1.41 usec trans_b torch.bfloat16 2.04 usec m=128, n=8, k=16 ==================== trans_b torch.float32 1.66 usec trans_b torch.float16 5.76 usec trans_b torch.bfloat16 6.37 usec m=4096, n=4096, k=16 ==================== trans_b torch.float32 2271.34 usec trans_b torch.float16 11198.46 usec trans_b torch.bfloat16 16893.54 usec m=11008, n=4096, k=16 ==================== trans_b torch.float32 6266.85 usec trans_b torch.float16 29342.49 usec trans_b torch.bfloat16 45159.22 usec m=4096, n=11008, k=16 ==================== trans_b torch.float32 5999.16 usec trans_b torch.float16 29157.43 usec trans_b torch.bfloat16 43295.81 usec m=32000, n=4096, k=16 ==================== trans_b torch.float32 18028.83 usec trans_b torch.float16 89626.88 usec trans_b torch.bfloat16 128164.62 usec Matrix-matrix (prompt len 32: m=8, n=128, k=32 ==================== trans_b torch.float32 1.38 usec trans_b torch.float16 2.03 usec trans_b torch.bfloat16 3.29 usec m=128, n=8, k=32 ==================== trans_b torch.float32 1.24 usec trans_b torch.float16 10.58 usec trans_b torch.bfloat16 11.97 usec m=4096, n=4096, k=32 ==================== trans_b torch.float32 2591.56 usec trans_b torch.float16 21683.62 usec trans_b torch.bfloat16 32657.68 usec m=11008, n=4096, k=32 ==================== trans_b torch.float32 6468.43 usec trans_b torch.float16 57811.33 usec trans_b torch.bfloat16 89263.21 usec m=4096, n=11008, k=32 ==================== trans_b torch.float32 6034.74 usec trans_b torch.float16 59372.56 usec trans_b torch.bfloat16 88107.85 usec m=32000, n=4096, k=32 ==================== trans_b torch.float32 18609.27 usec trans_b torch.float16 167298.00 usec trans_b torch.bfloat16 255116.37 usec Matrix-matrix (prompt len 128: m=8, n=128, k=128 ==================== trans_b torch.float32 2.44 usec trans_b torch.float16 6.11 usec trans_b torch.bfloat16 10.92 usec m=128, n=8, k=128 ==================== trans_b torch.float32 1.80 usec trans_b torch.float16 40.26 usec trans_b torch.bfloat16 44.82 usec m=4096, n=4096, k=128 ==================== trans_b torch.float32 4773.29 usec trans_b torch.float16 84458.54 usec trans_b torch.bfloat16 131248.58 usec m=11008, n=4096, k=128 ==================== trans_b torch.float32 12249.16 usec trans_b torch.float16 234411.87 usec trans_b torch.bfloat16 351970.71 usec m=4096, n=11008, k=128 ==================== trans_b torch.float32 11439.24 usec trans_b torch.float16 233347.04 usec trans_b torch.bfloat16 354475.96 usec m=32000, n=4096, k=128 ==================== trans_b torch.float32 33803.03 usec trans_b torch.float16 688157.54 usec trans_b torch.bfloat16 1048221.42 usec ``` Also ran the stock configuration; it was unchanged, indicating that we need to integrate this path with torch.mv separately, which will come in a follow-up PR.l Pull Request resolved: https://github.com/pytorch/pytorch/pull/127477 Approved by: https://github.com/malfet --- aten/src/ATen/native/BlasKernel.cpp | 52 ++++++++++++++++++++++--- aten/src/ATen/native/cpu/BlasKernel.cpp | 21 ++++------ 2 files changed, 54 insertions(+), 19 deletions(-) diff --git a/aten/src/ATen/native/BlasKernel.cpp b/aten/src/ATen/native/BlasKernel.cpp index 642467e5c1e6..8f3138966be4 100644 --- a/aten/src/ATen/native/BlasKernel.cpp +++ b/aten/src/ATen/native/BlasKernel.cpp @@ -384,7 +384,7 @@ static inline double reduce(float32x4_t x[kF32RegistersPerIteration]) { return vaddvq_f32(x[0]); } -static C10_ALWAYS_INLINE void fp16_dot_with_fp32_arith_main_inner_loop( +static C10_ALWAYS_INLINE void dot_with_fp32_arith_main_inner_loop( const float16_t* vec1, const float16_t* vec2, float32x4_t sum[kF32RegistersPerIteration], @@ -397,7 +397,7 @@ static C10_ALWAYS_INLINE void fp16_dot_with_fp32_arith_main_inner_loop( sum[2 * registerPairIndex + 1] = f32_fma_high_f16(sum[2 * registerPairIndex + 1], temp_vec1, temp_vec2); } -static C10_ALWAYS_INLINE void fp16_dot_with_fp32_arith_vectorized_tail_inner_loop( +static C10_ALWAYS_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop( const float16_t* vec1, const float16_t* vec2, float32x4_t* tailSum, @@ -407,14 +407,48 @@ static C10_ALWAYS_INLINE void fp16_dot_with_fp32_arith_vectorized_tail_inner_loo *tailSum = f32_fma_f16(*tailSum, temp_vec1, temp_vec2); } -float fp16_dot_with_fp32_arith(const float16_t* vec1, const float16_t* vec2, int64_t len) { +static C10_ALWAYS_INLINE float32x4_t to_bfloat16(uint16x4_t u16) { + int32x4_t shift = vdupq_n_s32(16); + return vreinterpretq_f32_u32(vshlq_u32(vmovl_u16(u16), shift)); +} + +static C10_ALWAYS_INLINE float32x4_t f32_fma_bf16(float32x4_t a, uint16x4_t b, uint16x4_t c) { + return f32_fma(a, to_bfloat16(b), to_bfloat16(c)); +} + +static C10_ALWAYS_INLINE void dot_with_fp32_arith_main_inner_loop( + const at::BFloat16* vec1, + const at::BFloat16* vec2, + float32x4_t sum[kF32RegistersPerIteration], + int registerPairIndex) { + // TODO: detect intrinsic availability, use them if they're available. __ARM_FEATURE_BF16 + // Load a pair of f32 registers at a time. + const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast(&vec1[registerPairIndex * 2 * kF32ElementsPerRegister])); + const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast(&vec2[registerPairIndex * 2 * kF32ElementsPerRegister])); + + sum[2 * registerPairIndex] = f32_fma_bf16(sum[2 * registerPairIndex], vget_low_u16(temp_vec1), vget_low_u16(temp_vec2)); + sum[2 * registerPairIndex + 1] = f32_fma_bf16(sum[2 * registerPairIndex + 1], vget_high_u16(temp_vec1), vget_high_u16(temp_vec2)); +} + +static C10_ALWAYS_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop( + const at::BFloat16* vec1, + const at::BFloat16* vec2, + float32x4_t* tailSum, + int idx) { + const auto temp_vec1 = vld1_u16(reinterpret_cast(&vec1[idx])); + const auto temp_vec2 = vld1_u16(reinterpret_cast(&vec2[idx])); + *tailSum = f32_fma_bf16(*tailSum, temp_vec1, temp_vec2); +} + +template +float dot_with_fp32_arith(const T* vec1, const T* vec2, int64_t len) { float32x4_t sum[kF32RegistersPerIteration] = {vdupq_n_f32(0)}; const auto len_aligned = len & ~(kF32ElementsPerIteration - 1); for (int j = 0; j < len_aligned ; j += kF32ElementsPerIteration) { const auto* vec1_ = vec1 + j; const auto* vec2_ = vec2 + j; c10::ForcedUnroll{}([vec1_, vec2_, &sum](auto k) { - fp16_dot_with_fp32_arith_main_inner_loop(vec1_, vec2_, sum, k); + dot_with_fp32_arith_main_inner_loop(vec1_, vec2_, sum, k); }); } auto reducedSum = reduce(sum); @@ -425,7 +459,7 @@ float fp16_dot_with_fp32_arith(const float16_t* vec1, const float16_t* vec2, int float32x4_t tailSum = vdupq_n_f32(0); const auto len_aligned_4 = len & ~3; for (int j = len_aligned; j < len_aligned_4; j += 4) { - fp16_dot_with_fp32_arith_vectorized_tail_inner_loop(vec1, vec2, &tailSum, j); + dot_with_fp32_arith_vectorized_tail_inner_loop(vec1, vec2, &tailSum, j); } auto reducedTail = vpaddq_f32(tailSum, tailSum); reducedSum += vgetq_lane_f32(vpaddq_f32(reducedTail, reducedTail), 0); @@ -437,6 +471,14 @@ float fp16_dot_with_fp32_arith(const float16_t* vec1, const float16_t* vec2, int return reducedSum; } +float fp16_dot_with_fp32_arith(const float16_t* vec1, const float16_t* vec2, int64_t len) { + return dot_with_fp32_arith(vec1, vec2, len); +} + +float bf16_dot_with_fp32_arith(const at::BFloat16* vec1, const at::BFloat16* vec2, int64_t len) { + return dot_with_fp32_arith(vec1, vec2, len); +} + // On my Apple M1 Macbook (which is ARM v8.5 and thus has the // instructions f32_fma_{low,high}_f16 is targeting), this kernel has // equivalent performance to the fp16-native kernel. diff --git a/aten/src/ATen/native/cpu/BlasKernel.cpp b/aten/src/ATen/native/cpu/BlasKernel.cpp index 387d6840999a..b664cdf262ad 100644 --- a/aten/src/ATen/native/cpu/BlasKernel.cpp +++ b/aten/src/ATen/native/cpu/BlasKernel.cpp @@ -38,6 +38,11 @@ float fp16_dot_with_fp32_arith( const float16_t* x, const float16_t* a, int64_t len); + +float bf16_dot_with_fp32_arith( + const at::BFloat16* x, + const at::BFloat16* a, + int64_t len); } #endif @@ -326,20 +331,8 @@ static float compute_dot(const at::Half* a, const at::Half* b, int64_t len) { len); } -static float compute_dot(const at::BFloat16* a, const at::BFloat16* b, int64_t l) { - if ((l&3) != 0) { - return sum(l, [&](int64_t i) -> float { - return float(a[i]) * float(b[i]); - }); - } - float32x4_t rcv = vdupq_n_f32(0); - for (int64_t idx = 0; idx < l; idx += 4) { - float32x4_t aVec = load_as_float32x4(a + idx); - float32x4_t bVec = load_as_float32x4(b + idx); - rcv = vaddq_f32(rcv, vmulq_f32(aVec, bVec)); - } - auto sum = vpaddq_f32(rcv, rcv); - return vgetq_lane_f32(vpaddq_f32(sum, sum), 0); +static float compute_dot(const at::BFloat16* a, const at::BFloat16* b, int64_t len) { + return at::native::blas_impl::bf16_dot_with_fp32_arith(a, b, len); } template <> From f6ca822366e8ef19c2c02d33fc2a4629bd04c8ec Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Mon, 3 Jun 2024 10:06:02 -0700 Subject: [PATCH 283/706] Patch ARM Half use_gemv_fast_path gate to avoid kernel duplication (#127478) Summary: The existing code didn't gate the fast path, so the fast path had to duplicate the stock kernel. Now we gate it and delete the duplicate kernel. Test Plan: Existing tests. Flipped the TORCH_INTERNAL_ASSERT_DEBUG_ONLY to non-debug and forced to fail (locally) to make sure we had test coverage. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127478 Approved by: https://github.com/malfet ghstack dependencies: #127477 --- aten/src/ATen/native/BlasKernel.cpp | 42 ++++++++++++----------------- 1 file changed, 17 insertions(+), 25 deletions(-) diff --git a/aten/src/ATen/native/BlasKernel.cpp b/aten/src/ATen/native/BlasKernel.cpp index 8f3138966be4..efc05a2ba9d1 100644 --- a/aten/src/ATen/native/BlasKernel.cpp +++ b/aten/src/ATen/native/BlasKernel.cpp @@ -119,7 +119,9 @@ bool scal_use_fast_path(C10_UNUSED int64_t n, C10_UNUSED int64_t incx) { template bool gemv_use_fast_path(C10_UNUSED int64_t m, C10_UNUSED int64_t n, - C10_UNUSED int64_t lda, C10_UNUSED int64_t incx, C10_UNUSED int64_t incy) { + C10_UNUSED scalar_t alpha, C10_UNUSED int64_t lda, + C10_UNUSED int64_t incx, C10_UNUSED scalar_t beta, + C10_UNUSED int64_t incy) { return false; } @@ -138,7 +140,7 @@ void gemv_fast_path(C10_UNUSED const char *trans, C10_UNUSED const int *m, C10_U #define INSTANTIATE(scalar_t) \ template bool scal_use_fast_path(int64_t n, int64_t incx); \ -template bool gemv_use_fast_path(int64_t m, int64_t n, int64_t lda, int64_t incx, int64_t incy); \ +template bool gemv_use_fast_path(int64_t m, int64_t n, scalar_t alpha, int64_t lda, int64_t incx, scalar_t beta, int64_t incy); \ template void gemv_fast_path(const char *trans, const int *m, const int *n, const scalar_t *alpha, const scalar_t *a, const int *lda, const scalar_t *x, const int *incx, const scalar_t *beta, scalar_t *y, const int *incy); \ template void scal_fast_path(int *n, scalar_t *a, scalar_t *x, int *incx); @@ -165,15 +167,15 @@ void scal_fast_path(int *n, float *a, float *x, int *incx) { } template <> -bool gemv_use_fast_path(int64_t m, int64_t n, int64_t lda, int64_t incx, int64_t incy) { +bool gemv_use_fast_path(int64_t m, int64_t n, C10_UNUSED float alpha, int64_t lda, int64_t incx, C10_UNUSED float beta, int64_t incy) { auto intmax = std::numeric_limits::max(); return (m <= intmax) && (n <= intmax) && (lda <= intmax) && (incx > 0) && (incx <= intmax) && (incy > 0) && (incy <= intmax); } template <> -bool gemv_use_fast_path(int64_t m, int64_t n, int64_t lda, int64_t incx, int64_t incy) { - return gemv_use_fast_path(m, n, lda, incx, incy); +bool gemv_use_fast_path(int64_t m, int64_t n, C10_UNUSED double alpha, int64_t lda, int64_t incx, C10_UNUSED double beta, int64_t incy) { + return gemv_use_fast_path(m, n, (float)alpha, lda, incx, (float)beta, incy); } template <> @@ -206,10 +208,13 @@ template <> bool gemv_use_fast_path( C10_UNUSED int64_t m, C10_UNUSED int64_t n, + at::Half alpha, C10_UNUSED int64_t lda, C10_UNUSED int64_t incx, + at::Half beta, C10_UNUSED int64_t incy) { - return true; + return incx == 1 && c10::detail::fp16_from_bits(alpha.x) == 1.0f && + c10::detail::fp16_from_bits(beta.x) == 0.0f; } #ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC @@ -501,26 +506,13 @@ void fp16_gemv_trans( const float beta, float16_t* y, const int incy) { - if (incx == 1 && alpha == 1.0 && beta == 0.0) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(incx == 1 && alpha == 1.0 && beta == 0.0); #ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC - if (at::globalContext().allowFP16ReductionCPU()) { - return fp16_gemv_trans_fp16_arith_by_dot_products(m, n, a, lda, x, y, incy); - } -#endif - return fp16_gemv_trans_fp32_arith_by_dot_products(m, n, a, lda, x, y, incy); - } - for (const auto i : c10::irange(n)) { - float sum = 0; - const auto row_ = a + lda * i; - for (const auto j : c10::irange(m)) { - sum += x[j * incx] * row_[j]; - } - if (beta == 0.0) { - y[i * incy] = alpha * sum; - } else { - y[i * incy] = beta * y[i * incy] + alpha * sum; - } + if (at::globalContext().allowFP16ReductionCPU()) { + return fp16_gemv_trans_fp16_arith_by_dot_products(m, n, a, lda, x, y, incy); } +#endif + return fp16_gemv_trans_fp32_arith_by_dot_products(m, n, a, lda, x, y, incy); } @@ -670,7 +662,7 @@ void gemv(char trans, int64_t m, int64_t n, scalar_t alpha, const scalar_t *a, i if(n == 1) lda = m; #if AT_BUILD_WITH_BLAS() - if (blas_impl::gemv_use_fast_path(m, n, lda, incx, incy)) { + if (blas_impl::gemv_use_fast_path(m, n, alpha, lda, incx, beta, incy)) { TORCH_CHECK(lda >= std::max(1L, m), "lda should be at least max(1,", m, "), but have ", lda); int i_m = (int)m; int i_n = (int)n; From 0f1f0d3015f2f5ef0753468ed21d09e388d19c24 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Mon, 3 Jun 2024 10:06:03 -0700 Subject: [PATCH 284/706] Onboard ARM bfloat16 to gemv fast path (#127484) Summary: Used bfloat16 dot support from #127477 to write a bfloat16 transposed fast path and integrated it. Test Plan: Ran https://github.com/malfet/llm_experiments/blob/main/benchmarks/benchmark_torch_mm.py before and after on my Apple M1 Pro. Before: ``` mv_nt torch.float32 6.77 usec mv_nt torch.float16 8.24 usec mv_nt torch.bfloat16 184.74 usec mv_ta torch.float32 5.71 usec mv_ta torch.float16 27.95 usec mv_ta torch.bfloat16 98.06 usec notrans torch.float32 5.55 usec notrans torch.float16 25.11 usec notrans torch.bfloat16 63.55 usec trans_a torch.float32 5.62 usec trans_a torch.float16 74.48 usec trans_a torch.bfloat16 313.19 usec trans_b torch.float32 5.68 usec trans_b torch.float16 8.18 usec trans_b torch.bfloat16 14.96 usec ``` After: ``` mv_nt torch.float32 5.40 usec mv_nt torch.float16 8.25 usec mv_nt torch.bfloat16 12.81 usec mv_ta torch.float32 5.69 usec mv_ta torch.float16 27.94 usec mv_ta torch.bfloat16 98.18 usec notrans torch.float32 5.60 usec notrans torch.float16 25.17 usec notrans torch.bfloat16 63.22 usec trans_a torch.float32 5.61 usec trans_a torch.float16 69.32 usec trans_a torch.bfloat16 316.62 usec trans_b torch.float32 5.60 usec trans_b torch.float16 8.09 usec trans_b torch.bfloat16 14.61 usec ``` Note large improvement in mv_nt torch.bfloat16 case. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127484 Approved by: https://github.com/malfet ghstack dependencies: #127477, #127478 --- aten/src/ATen/native/BlasKernel.cpp | 86 +++++++++++++++++++++++++---- 1 file changed, 76 insertions(+), 10 deletions(-) diff --git a/aten/src/ATen/native/BlasKernel.cpp b/aten/src/ATen/native/BlasKernel.cpp index efc05a2ba9d1..66e39c218a06 100644 --- a/aten/src/ATen/native/BlasKernel.cpp +++ b/aten/src/ATen/native/BlasKernel.cpp @@ -118,8 +118,9 @@ bool scal_use_fast_path(C10_UNUSED int64_t n, C10_UNUSED int64_t incx) { } template -bool gemv_use_fast_path(C10_UNUSED int64_t m, C10_UNUSED int64_t n, - C10_UNUSED scalar_t alpha, C10_UNUSED int64_t lda, +bool gemv_use_fast_path(C10_UNUSED char trans, C10_UNUSED int64_t m, + C10_UNUSED int64_t n, C10_UNUSED scalar_t alpha, + C10_UNUSED int64_t lda, C10_UNUSED int64_t incx, C10_UNUSED scalar_t beta, C10_UNUSED int64_t incy) { return false; @@ -140,7 +141,7 @@ void gemv_fast_path(C10_UNUSED const char *trans, C10_UNUSED const int *m, C10_U #define INSTANTIATE(scalar_t) \ template bool scal_use_fast_path(int64_t n, int64_t incx); \ -template bool gemv_use_fast_path(int64_t m, int64_t n, scalar_t alpha, int64_t lda, int64_t incx, scalar_t beta, int64_t incy); \ +template bool gemv_use_fast_path(char trans, int64_t m, int64_t n, scalar_t alpha, int64_t lda, int64_t incx, scalar_t beta, int64_t incy); \ template void gemv_fast_path(const char *trans, const int *m, const int *n, const scalar_t *alpha, const scalar_t *a, const int *lda, const scalar_t *x, const int *incx, const scalar_t *beta, scalar_t *y, const int *incy); \ template void scal_fast_path(int *n, scalar_t *a, scalar_t *x, int *incx); @@ -167,15 +168,15 @@ void scal_fast_path(int *n, float *a, float *x, int *incx) { } template <> -bool gemv_use_fast_path(int64_t m, int64_t n, C10_UNUSED float alpha, int64_t lda, int64_t incx, C10_UNUSED float beta, int64_t incy) { +bool gemv_use_fast_path(C10_UNUSED char trans, int64_t m, int64_t n, C10_UNUSED float alpha, int64_t lda, int64_t incx, C10_UNUSED float beta, int64_t incy) { auto intmax = std::numeric_limits::max(); return (m <= intmax) && (n <= intmax) && (lda <= intmax) && (incx > 0) && (incx <= intmax) && (incy > 0) && (incy <= intmax); } template <> -bool gemv_use_fast_path(int64_t m, int64_t n, C10_UNUSED double alpha, int64_t lda, int64_t incx, C10_UNUSED double beta, int64_t incy) { - return gemv_use_fast_path(m, n, (float)alpha, lda, incx, (float)beta, incy); +bool gemv_use_fast_path(C10_UNUSED char trans, int64_t m, int64_t n, C10_UNUSED double alpha, int64_t lda, int64_t incx, C10_UNUSED double beta, int64_t incy) { + return gemv_use_fast_path(trans, m, n, (float)alpha, lda, incx, (float)beta, incy); } template <> @@ -197,7 +198,6 @@ INSTANTIATE(int8_t); INSTANTIATE(int16_t); INSTANTIATE(int); INSTANTIATE(int64_t); -INSTANTIATE(c10::BFloat16); #if defined(__aarch64__) && !defined(C10_MOBILE) template <> bool scal_use_fast_path(C10_UNUSED int64_t n, C10_UNUSED int64_t incx) { @@ -206,6 +206,7 @@ bool scal_use_fast_path(C10_UNUSED int64_t n, C10_UNUSED int64_t incx) template <> bool gemv_use_fast_path( + C10_UNUSED char trans, C10_UNUSED int64_t m, C10_UNUSED int64_t n, at::Half alpha, @@ -217,6 +218,20 @@ bool gemv_use_fast_path( c10::detail::fp16_from_bits(beta.x) == 0.0f; } +template <> +bool gemv_use_fast_path( + C10_UNUSED char trans, + C10_UNUSED int64_t m, + C10_UNUSED int64_t n, + at::BFloat16 alpha, + C10_UNUSED int64_t lda, + C10_UNUSED int64_t incx, + at::BFloat16 beta, + C10_UNUSED int64_t incy) { + return (trans == 'T' || trans == 't') && incx == 1 && alpha == 1.0 && beta == 0.0; +} + + #ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC static inline float16_t reduce(float16x4_t x) { auto sum = vpadd_f16(x, x); @@ -495,6 +510,14 @@ static void fp16_gemv_trans_fp32_arith_by_dot_products(const int m, const int n, }); } +static void bf16_gemv_trans_fp32_arith_by_dot_products(const int m, const int n, const at::BFloat16* a, const int lda, const at::BFloat16 *x, at::BFloat16* y, int incy) { + parallel_for(0, n, 1, [&](int begin, int end) { + for (int i = begin; i < end; ++i) { + y[i * incy] = bf16_dot_with_fp32_arith(x, a + lda * i, m); + } + }); +} + void fp16_gemv_trans( const int m, const int n, @@ -515,6 +538,21 @@ void fp16_gemv_trans( return fp16_gemv_trans_fp32_arith_by_dot_products(m, n, a, lda, x, y, incy); } +void bf16_gemv_trans( + const int m, + const int n, + const at::BFloat16 alpha, + const at::BFloat16* a, + const int lda, + const at::BFloat16* x, + const int incx, + const at::BFloat16 beta, + at::BFloat16* y, + const int incy) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(incx == 1 && alpha == 1.0 && beta == 0.0); + return bf16_gemv_trans_fp32_arith_by_dot_products(m, n, a, lda, x, y, incy); +} + #ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC static void fp16_gemv_notrans_fp16_arith(int m, int n, const float16_t* a, const int lda, const float16_t *x, float16_t *y) { @@ -629,9 +667,37 @@ void gemv_fast_path( *incy); } } -#else + +template <> +void gemv_fast_path( + const char* trans, + const int* m, + const int* n, + const at::BFloat16* alpha, + const at::BFloat16* a, + const int* lda, + const at::BFloat16* x, + const int* incx, + const at::BFloat16* beta, + at::BFloat16* y, + const int* incy) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(trans[0] == 'T' || trans[0] == 't'); + bf16_gemv_trans( + *m, + *n, + *alpha, + a, + *lda, + x, + *incx, + *beta, + y, + *incy); +} +#else // defined(__aarch64__) && !defined(C10_MOBILE) INSTANTIATE(c10::Half); -#endif +INSTANTIATE(c10::BFloat16); +#endif // defined(__aarch64__) && !defined(C10_MOBILE) #undef INSTANTIATE } // namespace blas_impl @@ -662,7 +728,7 @@ void gemv(char trans, int64_t m, int64_t n, scalar_t alpha, const scalar_t *a, i if(n == 1) lda = m; #if AT_BUILD_WITH_BLAS() - if (blas_impl::gemv_use_fast_path(m, n, alpha, lda, incx, beta, incy)) { + if (blas_impl::gemv_use_fast_path(trans, m, n, alpha, lda, incx, beta, incy)) { TORCH_CHECK(lda >= std::max(1L, m), "lda should be at least max(1,", m, "), but have ", lda); int i_m = (int)m; int i_n = (int)n; From 0e7bd7feddc7e2b5147596ab1c0e05cffb738f08 Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Mon, 3 Jun 2024 22:30:11 +0000 Subject: [PATCH 285/706] [ROCm] TunableOp improvements (#124362) - use less memory; smaller default hipblaslt workspace size - options to avoid cache effects - icache flush option - rotating buffers during tuning - python APIs - unit tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/124362 Approved by: https://github.com/xw285cornell --- aten/src/ATen/cuda/Sleep.cu | 34 +++ aten/src/ATen/cuda/Sleep.h | 3 + aten/src/ATen/cuda/tunable/GemmCommon.h | 94 +++++- aten/src/ATen/cuda/tunable/GemmHipblaslt.h | 10 +- aten/src/ATen/cuda/tunable/README.md | 188 ++++++++---- aten/src/ATen/cuda/tunable/Tunable.cpp | 191 +++++++----- aten/src/ATen/cuda/tunable/Tunable.h | 85 ++++-- aten/src/ATen/cuda/tunable/TunableGemm.h | 186 +++++------- aten/src/ATen/cuda/tunable/TunableOp.h | 118 +++++--- docs/source/cuda.rst | 20 ++ docs/source/cuda.tunable.rst | 32 ++ test/test_linalg.py | 66 +++++ torch/_C/__init__.pyi.in | 15 + torch/csrc/cuda/Module.cpp | 330 +++++++++++++++++++++ torch/cuda/__init__.py | 3 +- torch/cuda/tunable.py | 242 +++++++++++++++ 16 files changed, 1304 insertions(+), 313 deletions(-) create mode 100644 docs/source/cuda.tunable.rst create mode 100644 torch/cuda/tunable.py diff --git a/aten/src/ATen/cuda/Sleep.cu b/aten/src/ATen/cuda/Sleep.cu index 4fe857e65c26..586520e25327 100644 --- a/aten/src/ATen/cuda/Sleep.cu +++ b/aten/src/ATen/cuda/Sleep.cu @@ -1,3 +1,4 @@ +#include #include #include @@ -32,4 +33,37 @@ void sleep(int64_t cycles) { C10_CUDA_KERNEL_LAUNCH_CHECK(); } +#ifdef USE_ROCM +__global__ void flush_icache_kernel() +{ + asm __volatile__("s_icache_inv \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" :: + :); +} +#endif + +void flush_icache() { +#ifdef USE_ROCM + dim3 grid(at::cuda::getCurrentDeviceProperties()->multiProcessorCount * 60); + dim3 block(64); + flush_icache_kernel<<>>(); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +#endif +} + } // namespace at::cuda diff --git a/aten/src/ATen/cuda/Sleep.h b/aten/src/ATen/cuda/Sleep.h index d31bf68ccafb..ef5e83a832f7 100644 --- a/aten/src/ATen/cuda/Sleep.h +++ b/aten/src/ATen/cuda/Sleep.h @@ -7,4 +7,7 @@ namespace at::cuda { // enqueues a kernel that spins for the specified number of cycles TORCH_CUDA_CU_API void sleep(int64_t cycles); +// flushes instruction cache for ROCm; no-op for CUDA +TORCH_CUDA_CU_API void flush_icache(); + } // namespace at::cuda diff --git a/aten/src/ATen/cuda/tunable/GemmCommon.h b/aten/src/ATen/cuda/tunable/GemmCommon.h index a1d7d0dc2163..a2c7c734a551 100644 --- a/aten/src/ATen/cuda/tunable/GemmCommon.h +++ b/aten/src/ATen/cuda/tunable/GemmCommon.h @@ -66,7 +66,7 @@ static bool NumericalCheck(ScalarType dtype, void* c, void* other_c, int64_t siz return false; } else { - TUNABLE_LOG("├──verify numerics: atol=", last_succeed_atol, ", rtol=", last_succeed_rtol); + TUNABLE_LOG3("├──verify numerics: atol=", last_succeed_atol, ", rtol=", last_succeed_rtol); } return true; @@ -76,30 +76,54 @@ static bool NumericalCheck(ScalarType dtype, void* c, void* other_c, int64_t siz template struct GemmParams : OpParams { + GemmParams() { + duplicate_inputs_ = false; + } + std::string Signature() const override { return c10::str(transa, transb, "_", m, "_", n, "_", k); } - GemmParams* DeepCopy() const { + size_t GetSize(bool duplicate_inputs) const { + size_t size = sizeof(T) * ldc * n; + if (duplicate_inputs) { + size += sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m); + size += sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k); + } + return size; + } + + GemmParams* DeepCopy(bool duplicate_inputs) const { GemmParams* copy = new GemmParams; *copy = *this; c10::DeviceIndex device = 0; AT_CUDA_CHECK(c10::cuda::GetDevice(&device)); - size_t c_size = m * n * sizeof(T); + size_t c_size = ldc * n * sizeof(T); copy->c = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(c_size)); AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync( copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true)); + if (duplicate_inputs) { + size_t a_size = sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m); + size_t b_size = sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k); + copy->a = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(a_size)); + copy->b = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(b_size)); + copy->duplicate_inputs_ = true; + } return copy; } // only call on object returned by DeepCopy void Delete() { c10::cuda::CUDACachingAllocator::raw_delete(c); + if (duplicate_inputs_) { + c10::cuda::CUDACachingAllocator::raw_delete(const_cast(a)); + c10::cuda::CUDACachingAllocator::raw_delete(const_cast(b)); + } } TuningStatus NumericalCheck(GemmParams *other) { auto c_dtype = c10::CppTypeToScalarType::value; - return detail::NumericalCheck(c_dtype, c, other->c, m*n) ? OK : FAIL; + return detail::NumericalCheck(c_dtype, c, other->c, ldc*n) ? OK : FAIL; } char transa; @@ -115,15 +139,30 @@ struct GemmParams : OpParams { at::opmath_type beta; T* c; int64_t ldc; +private: + bool duplicate_inputs_; }; template struct GemmStridedBatchedParams : OpParams { + GemmStridedBatchedParams() { + duplicate_inputs_ = false; + } + std::string Signature() const override { return c10::str(transa, transb, "_", m, "_", n, "_", k, "_B_", batch); } - GemmStridedBatchedParams* DeepCopy() const { + size_t GetSize(bool duplicate_inputs) const { + size_t size = sizeof(T) * stride_c * batch; + if (duplicate_inputs) { + size += sizeof(T) * stride_a * batch; + size += sizeof(T) * stride_b * batch; + } + return size; + } + + GemmStridedBatchedParams* DeepCopy(bool duplicate_inputs) const { GemmStridedBatchedParams* copy = new GemmStridedBatchedParams; *copy = *this; c10::DeviceIndex device = 0; @@ -132,12 +171,23 @@ struct GemmStridedBatchedParams : OpParams { copy->c = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(c_size)); AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync( copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true)); + if (duplicate_inputs) { + size_t a_size = sizeof(T) * stride_a * batch; + size_t b_size = sizeof(T) * stride_b * batch; + copy->a = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(a_size)); + copy->b = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(b_size)); + copy->duplicate_inputs_ = true; + } return copy; } // only call on object returned by DeepCopy void Delete() { c10::cuda::CUDACachingAllocator::raw_delete(c); + if (duplicate_inputs_) { + c10::cuda::CUDACachingAllocator::raw_delete(const_cast(a)); + c10::cuda::CUDACachingAllocator::raw_delete(const_cast(b)); + } } TuningStatus NumericalCheck(GemmStridedBatchedParams *other) { @@ -162,33 +212,59 @@ struct GemmStridedBatchedParams : OpParams { int64_t ldc; int64_t stride_c; int64_t batch; +private: + bool duplicate_inputs_; }; template struct ScaledGemmParams : OpParams { + ScaledGemmParams() { + duplicate_inputs_ = false; + } + std::string Signature() const override { return c10::str(transa, transb, "_", m, "_", n, "_", k); } - ScaledGemmParams* DeepCopy() const { + size_t GetSize(bool duplicate_inputs) const { + size_t size = sizeof(T) * ldc * n; + if (duplicate_inputs) { + size += sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m); + size += sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k); + } + return size; + } + + ScaledGemmParams* DeepCopy(bool duplicate_inputs) const { ScaledGemmParams* copy = new ScaledGemmParams; *copy = *this; c10::DeviceIndex device = 0; AT_CUDA_CHECK(c10::cuda::GetDevice(&device)); - size_t c_size = m * n * sizeof(T); + size_t c_size = ldc * n * sizeof(T); copy->c = c10::cuda::CUDACachingAllocator::raw_alloc(c_size); AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync( copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true)); + if (duplicate_inputs) { + size_t a_size = sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m); + size_t b_size = sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k); + copy->a = c10::cuda::CUDACachingAllocator::raw_alloc(a_size); + copy->b = c10::cuda::CUDACachingAllocator::raw_alloc(b_size); + copy->duplicate_inputs_ = true; + } return copy; } // only call on object returned by DeepCopy void Delete() { c10::cuda::CUDACachingAllocator::raw_delete(c); + if (duplicate_inputs_) { + c10::cuda::CUDACachingAllocator::raw_delete(const_cast(a)); + c10::cuda::CUDACachingAllocator::raw_delete(const_cast(b)); + } } TuningStatus NumericalCheck(ScaledGemmParams *other) { - return detail::NumericalCheck(c_dtype, c, other->c, m*n) ? OK : FAIL; + return detail::NumericalCheck(c_dtype, c, other->c, ldc*n) ? OK : FAIL; } char transa; @@ -212,6 +288,8 @@ struct ScaledGemmParams : OpParams { ScalarType c_dtype; void* amax_ptr; bool use_fast_accum; +private: + bool duplicate_inputs_; }; } // namespace at::cuda::tunable diff --git a/aten/src/ATen/cuda/tunable/GemmHipblaslt.h b/aten/src/ATen/cuda/tunable/GemmHipblaslt.h index b26c2415af7b..a9c420700275 100644 --- a/aten/src/ATen/cuda/tunable/GemmHipblaslt.h +++ b/aten/src/ATen/cuda/tunable/GemmHipblaslt.h @@ -263,19 +263,19 @@ static size_t GetHipblasltWorkspaceSize() { // 256MB is max workspace size allowed for hipblaslt // hipblaslt-bench uses 32MB // recommendation from hipblaslt author was 76MB - size_t workspace_size = 2*128*1024*1024; // default 256MB + size_t workspace_size = 32*1024; // going with 32MB if (env) { try { workspace_size = std::stoi(env); } catch(std::invalid_argument const& e) { TORCH_WARN("invalid HIPBLASLT_WORKSPACE_SIZE,", - " using default workspace size of ", workspace_size, " bytes."); + " using default workspace size of ", workspace_size, " KiB."); } catch(std::out_of_range const& e) { TORCH_WARN("HIPBLASLT_WORKSPACE_SIZE out of range,", - " using default workspace size of ", workspace_size, " bytes."); + " using default workspace size of ", workspace_size, " KiB."); } } - return workspace_size; + return workspace_size * 1024; } template @@ -413,12 +413,10 @@ class HipblasltGemmOp : public Callable { if (status == HIPBLAS_STATUS_SUCCESS) { if (ret_workspace_size >= workspace_size) { - //TUNABLE_LOG("[hipBLASLt] Solution #", algo_index, " workspace too large"); return FAIL; } } else { - //TUNABLE_LOG("[hipBLASLt] Solution #", algo_index, " not supported"); return FAIL; } diff --git a/aten/src/ATen/cuda/tunable/README.md b/aten/src/ATen/cuda/tunable/README.md index 364e6975c6c6..e17ff71f3004 100644 --- a/aten/src/ATen/cuda/tunable/README.md +++ b/aten/src/ATen/cuda/tunable/README.md @@ -2,67 +2,30 @@ This directory implements a TunableOp interface. -Some operations, such as GEMMs, could be implemented using more than one library or more than one technique. For -example, a GEMM could be implemented for CUDA or ROCm using either the blas or blasLt libraries. Further, ROCm's -rocblas and hipblaslt libraries allow the user to query for all possible algorithms and then choose one. How does one -know which implementation is the fastest and should be chosen? That's what TunableOp provides. - -The behavior of TunableOp is currently easily manipulated through environment variables, though you could use the C++ -interface of at::cuda::tunable::getTuningContext(). A Python interface to the TuningContext does not yet exist. - -Currently only a TunableGemm for ROCm is implemented. Any call to at::cuda::blas::gemm() can optionally use the -TunableGemm. Calling gemm() for a given set of input arguments (transa, transb, m, n, k) will attempt to use the -fastest available implementation. - -## Environment Variables - -#### PYTORCH_TUNABLEOP_ENABLED -Default is 0. Set to 1 to enable. -This is the big on/off switch for all TunableOp implementations. - -#### PYTORCH_TUNABLEOP_TUNING -Default is 1. Set to 0 to disable. -When enabled, if a tuned entry isn't found, run the tuning step and record the entry. - -#### PYTORCH_TUNABLEOP_VERBOSE -Default is 0. Set to 1 to enable. -This will produce a lot of diagnostic messages but may be useful to see if TunableOp is being used at all. -Otherwise, TunableOp is completely silent unless there is a warning or error during its use. - -#### PYTORCH_TUNABLEOP_FILENAME -Default is 'tunableop_results.csv'. If you provide a filename, the TuningContext will attempt to read it the first time -the context is used. If tuning is enabled and new tunings are discovered, it will also write out to this same filename -with all tunings, both the ones it read in at startup as well as the new ones found at runtime. This can be used, for -example, to build up a tunings file across many workloads by reusing the same file. Unsetting this variable is not -recommended but can be done, in which case the tuning results will not be saved. - -#### PYTORCH_TUNABLEOP_NUMERICAL_CHECK -Default is 1. Set to 0 to disable. Compare the results of each possible solution against the default solution and reject -those with low accuracy. - -#### PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED -Default is 1. Set to 0 to disable hipblaslt being considered during tuning. - -### Tuning Iterations -By default, each possible solution for a given operator will be run for either 100 iterations or as many iterations can -be run within 30ms, whichever is smaller. Its average execution will be calculated. The fastest solution is chosen. In -addition, a set of warm up iterations can optionally be run prior to the timed iterations. The following environment -variables can be used to set either the maximum number of iterations to attempt or the maximum amount of time allowed in -milliseconds, or both, in which case the smaller of the two values used. - -#### PYTORCH_TUNABLEOP_MAX_TUNING_DURATION_MS -Default is 30. - -#### PYTORCH_TUNABLEOP_MAX_TUNING_ITERATIONS -Default is 100. - -#### PYTORCH_TUNABLEOP_MAX_WARMUP_DURATION_MS -Default is 0, meaning it is not used. - -#### PYTORCH_TUNABLEOP_MAX_WARMUP_ITERATIONS -Default is 1. - -## File Output +Some operations, such as GEMMs, could be implemented using more than one library or more than one technique. For +example, a GEMM could be implemented for CUDA or ROCm using either the blas or blasLt libraries. Further, ROCm's +rocblas and hipblaslt libraries allow the user to query for all possible algorithms and then choose one. How does one +know which implementation is the fastest and should be chosen? That's what TunableOp provides. + +## Enabling TunableOp and Tuning Separately +The TunableOp feature is enabled separately from enabling the tuning phase itself. Enabling TunableOp means that PyTorch +will replace any standard operators with their Tunable implementations. Any call to a TunableOp first checks whether it +has already been tuned for the given operator inputs. If so, it will immediately call the tuned operation; no further +tuning will take place even when the tuning setting is enabled. Instead if no tuning result is found, and tuning is +enabled, the TunableOp will benchmark every registered implementation of that operator for the given set of inputs and +select the fastest. + +## File Input and Output +The first time any TunableOp is invoked, the internal database of tuned operations will be prepared by attempting to +read the results from the given file. The default filename is 'tunableop_results.csv'. To support tuning when multiple +GPUs are used across multiple processes, the GPU device ordinal is automatically inserted into the filename to avoid +multiple processes overwriting the same file. + +If tuning is enabled and new tunings are discovered during the course of your workload, it will also write out to this +same filename with all tunings, both the ones it read in at startup as well as the new ones found at runtime. This can +be used, for example, to build up a tunings file across many workloads by reusing the same file. The output file is +automatically created when the application terminates. This behavior can be controlled by the C++ and Python APIs but +not the environment variables. Assuming you specified a filename, you'll end up with a CSV file with contents like so: @@ -75,8 +38,8 @@ GemmTunableOp_float_NT,nt_25088_4096_64,1219,1.262 GemmTunableOp_float_NT,nt_4096_4096_64,1216,0.033 ``` -Note the "Validator" lines. If you change a library verison, or rocm version, or pytorch version, TunableOp will detect -this and not load the tunings because they are likely affected by other software changes. +Note the "Validator" lines. If you change a library verison, or ROCm version, or PyTorch version, TunableOp will detect +this and reject the tunings file because the prior tunings are likely affected by other software changes. The remaining lines are the tuned solutions for each TunableOp encountered during your execution. Each line consists of 4 comma-separated fields: operator name, operator parameters, solution name, and average execution time. The execution @@ -86,3 +49,102 @@ hipBLAS or hipBLASLt libraries, if you know the specific solution index you can selected by replacing the value. The operator name and parameters (fields 1 and 2) are internally named and should not be modified. In the case of GemmTunableOp, field 1 indicates the datatype and whether the inputs are transposed (T) or not (N) and field 2 indicates the M, N, K input shapes. + +There is an option to enable verbose output but it is only recommended for debugging purposes. This will produce a lot +of diagnostic messages but may be useful to see if TunableOp is being used at all. Otherwise, TunableOp is completely +silent, besides file output, unless there is a warning or error during its use. + +## A Note on Tuning Behavior, Warmup, and Cache Effects +Tuning an operator consists of iterating through the list or registered implementations and profiling each one. The +profile is established by running a single implementation in a loop multiple times and taking the average execution +time. There is also an optional warmup phase prior to tuning that can help with reaching stable power states by the +hardware. During tuning of a workload the various hardware caches will more likely produce hits than when not tuning. +There are options for flushing the instruction cache and rotate the input tensors which might help produce a more +faithful profile of the tuned operator as if the operator were run within a larger workload instead of in a tight, +repetitive loop. + +By default, each possible solution for a given operator will be run for either 100 iterations or as many iterations that +can be run within 30ms, whichever is smaller, and its average execution will be calculated. The fastest solution among +all that were successfully profiled will be chosen. A profile might fail if the given solution doesn't achieve the same +accuracy as the default implementation or if the solution returns an error code. + +## Current Tunable Operators + +### TunableGemm for ROCm +Currently only a TunableGemm for ROCm is implemented. Note that CUDA builds of PyTorch will function correctly when +using TunableOp but the only solution available to CUDA builds is the 'Default' implementation i.e. the original cuBLAS +default, now called through TunableOp. Any call to at::cuda::blas::gemm() or ::bgemm() will be routed through TunableOp +when enabled. Calling gemm() for a given set of input arguments (transa, transb, m, n, k) will attempt to use the +fastest available implementation across both rocblas and hipblaslt. + +## Tuning Context +The behavior of TunableOp is currently manipulated through environment variables, the C++ interface of +at::cuda::tunable::getTuningContext(), or the `torch.cuda.tunable` python interfaces. The environment variables take +precedence over any setting you manipulate using the C++ or Python APIs. + +### Environment Variable Interface +Environment variables are cached the first time they are read. You cannot use the environment variable interface +programmatically since the settings become fixed. Use the C++ or Python APIs instead. + +| Environment Variable | Description | +| -------------------- | ----------- | +| PYTORCH_TUNABLEOP_ENABLED | Default is 0. Set to 1 to enable. | +| PYTORCH_TUNABLEOP_TUNING | Default is 1. Set to 0 to disable. | +| PYTORCH_TUNABLEOP_VERBOSE | Default is 0. Set to 1 to enable basic logging. 2 for basic tuning status. 3 for full trace. | +| PYTORCH_TUNABLEOP_VERBOSE_FILENAME | Default is "err" for stderr. Set to "out" for stdout or a filename for capturing verbose logging. | +| PYTORCH_TUNABLEOP_FILENAME | Default is 'tunableop_results.csv'. | +| PYTORCH_TUNABLEOP_NUMERICAL_CHECK | Default is 0. Set to 1 to enable. | +| PYTORCH_TUNABLEOP_ROCBLAS_ENABLED | Default is 1. Set to 0 to disable rocblas being considered during tuning. | +| PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED | Default is 1. Set to 0 to disable hipblaslt being considered during tuning. | +| PYTORCH_TUNABLEOP_MAX_TUNING_DURATION_MS | Default is 30. Unit is milliseconds. | +| PYTORCH_TUNABLEOP_MAX_TUNING_ITERATIONS | Default is 100. | +| PYTORCH_TUNABLEOP_MAX_WARMUP_DURATION_MS | Default is 0, meaning it is not used. Unit is milliseconds. | +| PYTORCH_TUNABLEOP_MAX_WARMUP_ITERATIONS | Default is 0, meaning it is not used. | +| PYTORCH_TUNABLEOP_ICACHE_FLUSH_ENABLED | Default is 1. Set to 0 to disable. | +| PYTORCH_TUNABLEOP_ROTATING_BUFFER_SIZE | Default is to query L2 cache size. Set to 0 to disable. Otherwise, set to the number of MiB to use for the pool of operator parameters. For example, setting this to the size of your device's memory cache will guarantee that every tuning iteration will use a cold cache. | + +### Python Interface +All python APIs exist in the `torch.cuda.tunable` module. + +| Python API | Description | +| ---------- | ----------- | +| enable(val: bool = True) -> None | | +| is_enabled() -> bool | | +| tuning_enable(val: bool = True) -> None | Default is True. | +| tuning_is_enabled() -> bool | | +| set_max_tuning_duration(duration: int) -> None | | +| get_max_tuning_duration() -> int | | +| set_max_tuning_iterations(iterations: int) -> None | | +| get_max_tuning_iterations() -> int | | +| set_filename(filename: str, insert_device_ordinal: bool = False) -> None | | +| get_filename() -> str | | +| get_results() -> Tuple[str, str, str, float] | | +| get_validators() -> Tuple[str, str] | | +| write_file_on_exit(val: bool) -> None | Default is True. | +| write_file(filename: Optional[str] = None) -> None | If filename not given, it will call get_filename(). | +| read_file(filename: Optional[str] = None) -> None | If filename not given, it will call get_filename(). | + +### C++ Interface +Example: +```C++ +#include + +at::cuda::tunable::getTuningContext()->EnableTunableOp(true); +``` + +| C++ API | Description | +| ------- | ----------- | +| void EnableTunableOp(bool value); | | +| bool IsTunableOpEnabled() const; | | +| void EnableTuning(bool value); | | +| bool IsTuningEnabled() const; | | +| void SetMaxTuningDurationMs(int max_duration_ms); | | +| int GetMaxTuningDurationMs() const; | | +| void SetMaxTuningIterations(int max_iter); | | +| int GetMaxTuningIterations() const; | | +| TuningResults GetTuningResults(); | | +| void SetFilename(const std::string& filename, bool insert_device_ordinal=false); | | +| std::string GetFilename() const; | | +| void WriteFileOnExit(bool value); | | +| bool ReadFile(const std::string& filename={}); | | +| bool WriteFile(const std::string& filename={}); | | diff --git a/aten/src/ATen/cuda/tunable/Tunable.cpp b/aten/src/ATen/cuda/tunable/Tunable.cpp index 22bde7f4c427..fc27fab77d79 100644 --- a/aten/src/ATen/cuda/tunable/Tunable.cpp +++ b/aten/src/ATen/cuda/tunable/Tunable.cpp @@ -65,14 +65,14 @@ ResultEntry TuningResultsManager::Lookup(const std::string& op_signature, const std::scoped_lock l{lock_}; auto kernel_map_it = results_.find(op_signature); if (kernel_map_it == results_.cend()) { - TUNABLE_LOG("missing op_signature, returning null ResultEntry"); + TUNABLE_LOG3("missing op_signature, returning null ResultEntry"); return ResultEntry::Null(); } const auto& km = kernel_map_it->second; auto it = km.find(params_signature); if (it == km.cend()) { - TUNABLE_LOG("missing params_signature, returning null ResultEntry"); + TUNABLE_LOG3("missing params_signature, returning null ResultEntry"); return ResultEntry::Null(); } return it->second; @@ -85,14 +85,14 @@ inline void TuningResultsManager::AddImpl(const std::string& op_signature, auto it = kernel_map.find(params_signature); if (it != kernel_map.end()) { if (it->second != best) { - TUNABLE_LOG(op_signature, "(", params_signature, ") already has a best kernel ", + TUNABLE_LOG1(op_signature, "(", params_signature, ") already has a best kernel ", "id=", it->second, " selected, want to add a different best kernel ", best, ", the new kernel id will be ignored."); } return; } - TUNABLE_LOG(op_signature, "(", params_signature, ") -> ", best); + TUNABLE_LOG2(op_signature, "(", params_signature, ") -> ", best); kernel_map.emplace(params_signature, best); } @@ -120,7 +120,7 @@ void TuningResultsManager::Delete(const std::string& op_signature, const std::st return; } - TUNABLE_LOG(op_signature, "(", params_signature, ")"); + TUNABLE_LOG2(op_signature, "(", params_signature, ")"); it->second.erase(it2); } @@ -131,7 +131,7 @@ inline void TuningResultsManager::DisjointMergeImpl( auto it = results.find(op_signature); if (it == results.end()) { for (const auto& [param_sig, kernel_id] : kernel_map) { - TUNABLE_LOG(op_signature, "(", param_sig, ") -> ", kernel_id); + TUNABLE_LOG2(op_signature, "(", param_sig, ") -> ", kernel_id); } results[op_signature] = kernel_map; return; @@ -143,7 +143,7 @@ inline void TuningResultsManager::DisjointMergeImpl( } void TuningResultsManager::Load(const std::unordered_map& results_to_load) { - TUNABLE_LOG("Loading results"); + TUNABLE_LOG1("Loading results"); std::scoped_lock l{lock_}; for (const auto& [op_signature, kernel_map] : results_to_load) { DisjointMergeImpl(op_signature, kernel_map, results_); @@ -194,12 +194,12 @@ static bool CheckMandatoryKeys( for (const auto& k : TuningResultsValidator::mandatory_keys) { if (gv_funcs.find(k) == gv_funcs.end()) { passed = false; - TUNABLE_LOG("key=\"", k, "\" is not registered for Get and Validate. "); + TUNABLE_LOG1("key=\"", k, "\" is not registered for Get and Validate. "); } if (to_check.find(k) == to_check.end()) { passed = false; - TUNABLE_LOG("key=\"", k, "\" is not provided for validation. "); + TUNABLE_LOG1("key=\"", k, "\" is not provided for validation. "); } } return passed; @@ -294,10 +294,14 @@ TuningContext::TuningContext() : enable_{false}, tuning_enable_{true}, manager_initialized_{false}, + write_file_on_exit_{true}, + numerics_check_enable_{false}, max_tuning_duration_ms_{30}, max_tuning_iterations_{100}, max_warmup_duration_ms_{0}, max_warmup_iterations_{0}, + icache_flush_{true}, + rotating_buffer_size_{-1}, filename_{}, results_count_from_input_file_{0} { @@ -311,115 +315,158 @@ TuningContext::~TuningContext() { return; } auto filename = GetFilename(); - if (IsTunableOpEnabled() && IsTuningEnabled() && !filename.empty()) { + if (IsTunableOpEnabled() && IsTuningEnabled() && !filename.empty() && write_file_on_exit_) { if (results_count_from_input_file_ < GetTuningResultsManager().GetSize()) { if (results_count_from_input_file_ > 0) { - TUNABLE_LOG("additional tuning results available, rewriting file ", filename); + TUNABLE_LOG1("additional tuning results available, rewriting file ", filename); } else { - TUNABLE_LOG("writing file ", filename); + TUNABLE_LOG1("writing file ", filename); } if (!WriteFile(filename)) { - TUNABLE_LOG("failed to write file ", filename); + TUNABLE_LOG1("failed to write file ", filename); } } } } -void TuningContext::EnableTunableOp() { - TUNABLE_LOG("Enable TunableOp"); - enable_ = true; -} - -void TuningContext::DisableTunableOp() { - TUNABLE_LOG("Disable TunableOp"); - enable_ = false; +void TuningContext::EnableTunableOp(bool value) { + enable_ = value; + if (value) { + TUNABLE_LOG1("Enable TunableOp"); + } + else { + TUNABLE_LOG1("Disable TunableOp"); + } } bool TuningContext::IsTunableOpEnabled() const { static const char *env = std::getenv("PYTORCH_TUNABLEOP_ENABLED"); if (env != nullptr && strcmp(env, "1") == 0) { - //TUNABLE_LOG("PYTORCH_TUNABLEOP_ENABLED=1"); return true; } return enable_; } -void TuningContext::EnableTuning() { - TUNABLE_LOG("Enable Tuning for TunableOp"); - tuning_enable_ = true; -} - -void TuningContext::DisableTuning() { - TUNABLE_LOG("Disable Tuning for TunableOp"); - tuning_enable_ = false; +void TuningContext::EnableTuning(bool value) { + tuning_enable_ = value; + if (value) { + TUNABLE_LOG1("Enable Tuning for TunableOp"); + } + else { + TUNABLE_LOG1("Disable Tuning for TunableOp"); + } } bool TuningContext::IsTuningEnabled() const { static const char *env = std::getenv("PYTORCH_TUNABLEOP_TUNING"); if (env != nullptr && strcmp(env, "0") == 0) { - //TUNABLE_LOG("PYTORCH_TUNABLEOP_TUNING=1"); return false; } return tuning_enable_; } +void TuningContext::WriteFileOnExit(bool value) { + write_file_on_exit_ = value; +} + +void TuningContext::EnableNumericsCheck(bool value) { + numerics_check_enable_ = value; +} + +bool TuningContext::IsNumericsCheckEnabled() const { + static const char *env = getenv("PYTORCH_TUNABLEOP_NUMERICAL_CHECK"); + if (env != nullptr && strcmp(env, "0") == 0) { + return false; + } + return numerics_check_enable_; +} + void TuningContext::SetMaxTuningDurationMs(int max_duration_ms) { - max_tuning_duration_ms_ = max_duration_ms; + max_tuning_duration_ms_ = max_duration_ms < 0 ? 0 : max_duration_ms; } int TuningContext::GetMaxTuningDurationMs() const { static const char *env = std::getenv("PYTORCH_TUNABLEOP_MAX_TUNING_DURATION_MS"); if (env != nullptr) { - return atoi(env); + int val = atoi(env); + return val < 0 ? 0 : val; } return max_tuning_duration_ms_; } void TuningContext::SetMaxTuningIterations(int max_iter) { - max_tuning_iterations_ = max_iter; + max_tuning_iterations_ = max_iter < 0 ? 0 : max_iter; } int TuningContext::GetMaxTuningIterations() const { static const char *env = std::getenv("PYTORCH_TUNABLEOP_MAX_TUNING_ITERATIONS"); if (env != nullptr) { - return atoi(env); + int val = atoi(env); + return val < 0 ? 0 : val; } return max_tuning_iterations_; } void TuningContext::SetMaxWarmupDurationMs(int max_duration_ms) { - max_warmup_duration_ms_ = max_duration_ms; + max_warmup_duration_ms_ = max_duration_ms < 0 ? 0 : max_duration_ms; } int TuningContext::GetMaxWarmupDurationMs() const { static const char *env = std::getenv("PYTORCH_TUNABLEOP_MAX_WARMUP_DURATION_MS"); if (env != nullptr) { - return atoi(env); + int val = atoi(env); + return val < 0 ? 0 : val; } return max_warmup_duration_ms_; } void TuningContext::SetMaxWarmupIterations(int max_iter) { - max_warmup_iterations_ = max_iter; + max_warmup_iterations_ = max_iter < 0 ? 0 : max_iter; } int TuningContext::GetMaxWarmupIterations() const { static const char *env = std::getenv("PYTORCH_TUNABLEOP_MAX_WARMUP_ITERATIONS"); if (env != nullptr) { - return atoi(env); + int val = atoi(env); + return val < 0 ? 0 : val; } return max_warmup_iterations_; } -void TuningContext::EnableTunableOpAndTuning() { - EnableTunableOp(); - EnableTuning(); +void TuningContext::EnableICacheFlush(bool value) { + icache_flush_ = value; } -void TuningContext::DisableTunableOpAndTuning() { - DisableTunableOp(); - DisableTuning(); +bool TuningContext::IsICacheFlushEnabled() const { + static const char *env = std::getenv("PYTORCH_TUNABLEOP_ICACHE_FLUSH_ENABLED"); + if (env != nullptr && strcmp(env, "0") == 0) { + return false; + } + return icache_flush_; +} + +void TuningContext::SetRotatingBufferSize(int size) { + rotating_buffer_size_ = size < 0 ? 0 : size; +} + +int TuningContext::GetRotatingBufferSize() const { + static const char *env = std::getenv("PYTORCH_TUNABLEOP_ROTATING_BUFFER_SIZE"); + if (env != nullptr) { + constexpr int MB = 1024 * 1024; + int val = atoi(env); + return val < 0 ? 0 : val * MB; // env var is specified as MB, returned as bytes + } + else { + if (rotating_buffer_size_ < 0) { + // negative buffer size (default) means query for L2 cache size + int l2_cache_size = at::cuda::getCurrentDeviceProperties()->l2CacheSize; + return l2_cache_size; + } + else { + return rotating_buffer_size_; + } + } } TuningResultsManager& TuningContext::GetTuningResultsManager() { @@ -429,7 +476,7 @@ TuningResultsManager& TuningContext::GetTuningResultsManager() { // if SetFilename() was not already called, call it now with the default or env var const char *env = std::getenv("PYTORCH_TUNABLEOP_FILENAME"); std::string filename = (env == nullptr) ? "tunableop_results.csv" : env; - SetFilename(filename); + SetFilename(filename, true); } auto filename = GetFilename(); if (!filename.empty()) { @@ -461,32 +508,34 @@ TuningStatus TuningContext::LoadTuningResults(const TuningResults& tr) { return OK; } -void TuningContext::SetFilename(const std::string& filename) { +void TuningContext::SetFilename(const std::string& filename, bool insert_device_ordinal) { filename_ = filename; if (filename_.empty()) { return; } - // differentiate filename based on device ordinal to avoid - // use case of one process per device writing to same file - std::string device = c10::str(int(c10::cuda::current_device())); + if (insert_device_ordinal) { + // differentiate filename based on device ordinal to avoid + // use case of one process per device writing to same file + std::string device = c10::str(int(c10::cuda::current_device())); - // does filename contain %d to insert device ordinal in specific location? - const std::string TOKEN("%d"); - std::size_t found = filename_.find(TOKEN); - if (found != std::string::npos) { - filename_.replace(found, TOKEN.length(), device); - } - else { - // no %d present, so append device ordinal before final '.' - found = filename_.rfind("."); + // does filename contain %d to insert device ordinal in specific location? + const std::string TOKEN("%d"); + std::size_t found = filename_.find(TOKEN); if (found != std::string::npos) { - filename_.insert(found, device); + filename_.replace(found, TOKEN.length(), device); } else { - // all else fails, just append - filename_.append(device); + // no %d present, so append device ordinal before final '.' + found = filename_.rfind("."); + if (found != std::string::npos) { + filename_.insert(found, device); + } + else { + // all else fails, just append + filename_.append(device); + } } } } @@ -495,14 +544,15 @@ std::string TuningContext::GetFilename() const { return filename_; } -bool TuningContext::ReadFile(const std::string& filename) { - TUNABLE_LOG("reading tuning results from ", filename); +bool TuningContext::ReadFile(const std::string& filename_) { + std::string filename = filename_.empty() ? GetFilename() : filename_; + TUNABLE_LOG1("reading tuning results from ", filename); ResultsMap results; std::unordered_map validators; std::string line; std::ifstream file(filename); if (!file) { - TUNABLE_LOG("could not open ", filename, " for reading tuning results"); + TUNABLE_LOG1("could not open ", filename, " for reading tuning results"); return false; } while (std::getline(file, line)) { @@ -517,7 +567,7 @@ bool TuningContext::ReadFile(const std::string& filename) { } if (parts[0] == "Validator" && parts.size() >= 3) { validators[parts[1]] = parts[2]; - TUNABLE_LOG("Validator ", parts[1], "=", parts[2]); + TUNABLE_LOG1("Validator ", parts[1], "=", parts[2]); } else if (parts.size() >= 4) { results[parts[0]].emplace(parts[1], ResultEntry(parts[2], atof(parts[3].c_str()))); @@ -527,7 +577,7 @@ bool TuningContext::ReadFile(const std::string& filename) { results[parts[0]].emplace(parts[1], ResultEntry(parts[2], 0)); } else { - TUNABLE_LOG("could not parse line: ", line); + TUNABLE_LOG1("could not parse line: ", line); } } if (GetTuningResultsValidator().ValidateAll(validators) != FAIL) { @@ -535,16 +585,17 @@ bool TuningContext::ReadFile(const std::string& filename) { results_count_from_input_file_ = manager_.GetSize(); } else { - TUNABLE_LOG("results validator check failed"); + TUNABLE_LOG1("results validator check failed"); return false; } return true; } -bool TuningContext::WriteFile(const std::string& filename) { +bool TuningContext::WriteFile(const std::string& filename_) { + std::string filename = filename_.empty() ? GetFilename() : filename_; std::ofstream file(filename, std::ios::out | std::ios::trunc); if (!file.good()) { - TUNABLE_LOG("error opening tuning results file for writing ", filename); + TUNABLE_LOG1("error opening tuning results file for writing ", filename); return false; } auto validators = GetTuningResultsValidator().GetAllValidators(); diff --git a/aten/src/ATen/cuda/tunable/Tunable.h b/aten/src/ATen/cuda/tunable/Tunable.h index eb849a213fe5..243031cf3da2 100644 --- a/aten/src/ATen/cuda/tunable/Tunable.h +++ b/aten/src/ATen/cuda/tunable/Tunable.h @@ -11,6 +11,7 @@ #include +#include #include #include #include @@ -23,27 +24,58 @@ namespace at::cuda::tunable { -static void TunableLog(const std::string& msg) { - static const char *env = getenv("PYTORCH_TUNABLEOP_VERBOSE"); - if (env != nullptr && strcmp(env, "1") == 0) { - std::cerr << msg << std::endl; +namespace detail { + +struct MaybeDelete { + bool owns_pointer; + void operator()(std::ostream* os) const { if (owns_pointer) delete os; } +}; + +using OstreamPtr = std::unique_ptr; + +static OstreamPtr get_stream(std::string filename) { + if (filename.compare("out") == 0) { + return OstreamPtr { &std::cout, MaybeDelete {false} }; } + else if (filename.compare("err") == 0) { + return OstreamPtr { &std::cerr, MaybeDelete {false} }; + } + else { + return OstreamPtr { new std::ofstream {filename.c_str()}, MaybeDelete {true} }; + } +} + } -#define TUNABLE_LOG(...) TunableLog(c10::str(__VA_ARGS__)) -enum TuningStatus { +static void TunableLog(int level, const std::string& msg) { + static const char *env_file = getenv("PYTORCH_TUNABLEOP_VERBOSE_FILENAME"); + static const char *env_verbose = getenv("PYTORCH_TUNABLEOP_VERBOSE"); + static int level_user = env_verbose ? atoi(env_verbose) : 0; + static auto streamptr = detail::get_stream(env_file ? env_file : "err"); + if (level_user >= level) { + (*streamptr) << msg < KernelMap; typedef std::unordered_map ResultsMap; -struct TuningResults { +struct TORCH_CUDA_CPP_API TuningResults { // Validates if these results are compatible with the libraries std::unordered_map validators; @@ -64,7 +96,7 @@ struct TuningResults { ResultsMap results; }; -class TuningResultsManager { +class TORCH_CUDA_CPP_API TuningResultsManager { public: TuningResultsManager() = default; ~TuningResultsManager() = default; @@ -102,7 +134,7 @@ class TuningResultsManager { ResultsMap results_; }; -class TuningResultsValidator { +class TORCH_CUDA_CPP_API TuningResultsValidator { public: using GetFunc = std::function; using ValidateFunc = std::function; @@ -126,7 +158,7 @@ class TuningResultsValidator { GetValidateFuncs validators_; }; -class TuningContext { +class TORCH_CUDA_CPP_API TuningContext { public: TuningContext(); ~TuningContext(); @@ -135,14 +167,15 @@ class TuningContext { TuningContext &operator=(TuningContext &) = delete; TuningContext &operator=(TuningContext &&) = delete; - void EnableTunableOp(); - void DisableTunableOp(); + void EnableTunableOp(bool value); bool IsTunableOpEnabled() const; - void EnableTuning(); - void DisableTuning(); + void EnableTuning(bool value); bool IsTuningEnabled() const; + void EnableNumericsCheck(bool value); + bool IsNumericsCheckEnabled() const; + void SetMaxTuningDurationMs(int max_duration_ms); int GetMaxTuningDurationMs() const; @@ -155,8 +188,11 @@ class TuningContext { void SetMaxWarmupIterations(int max_iter); int GetMaxWarmupIterations() const; - void EnableTunableOpAndTuning(); - void DisableTunableOpAndTuning(); + void EnableICacheFlush(bool value); + bool IsICacheFlushEnabled() const; + + void SetRotatingBufferSize(int size); + int GetRotatingBufferSize() const; TuningResultsManager& GetTuningResultsManager(); @@ -166,21 +202,26 @@ class TuningContext { TuningStatus LoadTuningResults(const TuningResults& tr); - void SetFilename(const std::string& filename); + void SetFilename(const std::string& filename, bool insert_device_ordinal=false); std::string GetFilename() const; - protected: - bool ReadFile(const std::string& filename); - bool WriteFile(const std::string& filename); + void WriteFileOnExit(bool value); + + bool ReadFile(const std::string& filename={}); + bool WriteFile(const std::string& filename={}); private: bool enable_; bool tuning_enable_; bool manager_initialized_; + bool write_file_on_exit_; + bool numerics_check_enable_; int max_tuning_duration_ms_; int max_tuning_iterations_; int max_warmup_duration_ms_; int max_warmup_iterations_; + bool icache_flush_; + int rotating_buffer_size_; mutable TuningResultsManager manager_; mutable c10::once_flag manager_init_once_; TuningResultsValidator validator_; @@ -188,7 +229,7 @@ class TuningContext { size_t results_count_from_input_file_; }; -TuningContext* getTuningContext(); +TORCH_CUDA_CPP_API TuningContext* getTuningContext(); class ITimer { public: diff --git a/aten/src/ATen/cuda/tunable/TunableGemm.h b/aten/src/ATen/cuda/tunable/TunableGemm.h index 1eaf251caad7..53e6154120c9 100644 --- a/aten/src/ATen/cuda/tunable/TunableGemm.h +++ b/aten/src/ATen/cuda/tunable/TunableGemm.h @@ -175,6 +175,56 @@ inline std::string TypeName(c10::complex v) { return "c10::complex"; } +#ifdef USE_ROCM +static void AddRocblasValidator() { + auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators(); + if (validators.find("ROCBLAS_VERSION") == validators.end()) { + std::string rocblas_version = c10::str( + XSTRINGIFY(ROCBLAS_VERSION_MAJOR), ".", + XSTRINGIFY(ROCBLAS_VERSION_MINOR), ".", + XSTRINGIFY(ROCBLAS_VERSION_PATCH), "-", + XSTRINGIFY(ROCBLAS_VERSION_TWEAK)); + getTuningContext()->GetTuningResultsValidator().RegisterValidator( + "ROCBLAS_VERSION", + [rocblas_version]() { return rocblas_version; }, + [rocblas_version](auto&& k) { return rocblas_version == k ? OK : FAIL; }); + } +} + +static void AddHipblasltValidator() { + auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators(); + if (validators.find("HIPBLASLT_VERSION") == validators.end()) { + std::string hipblaslt_version = c10::str( + XSTRINGIFY(HIPBLASLT_VERSION_MAJOR), ".", + XSTRINGIFY(HIPBLASLT_VERSION_MINOR), ".", + XSTRINGIFY(HIPBLASLT_VERSION_PATCH), "-", + XSTRINGIFY(HIPBLASLT_VERSION_TWEAK)); + getTuningContext()->GetTuningResultsValidator().RegisterValidator( + "HIPBLASLT_VERSION", + [hipblaslt_version]() { return hipblaslt_version; }, + [hipblaslt_version](auto&& k) { return hipblaslt_version == k ? OK : FAIL; }); + } +} + +static void AddRocmValidator() { + auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators(); + if (validators.find("ROCM_VERSION") == validators.end()) { + std::string rocm_version = ROCM_BUILD_INFO; + getTuningContext()->GetTuningResultsValidator().RegisterValidator( + "ROCM_VERSION", + [rocm_version]() { return rocm_version; }, + [rocm_version](auto&& k) { return rocm_version == k ? OK : FAIL; }); + } + + if (validators.find("GCN_ARCH_NAME") == validators.end()) { + std::string gcn_arch_name = at::cuda::getCurrentDeviceProperties()->gcnArchName; + getTuningContext()->GetTuningResultsValidator().RegisterValidator( + "GCN_ARCH_NAME", + [gcn_arch_name]() { return gcn_arch_name; }, + [gcn_arch_name](auto&& k) { return gcn_arch_name == k ? OK : FAIL; }); + } +} +#endif template class GemmTunableOp : public TunableOp, StreamTimer> { @@ -182,45 +232,21 @@ class GemmTunableOp : public TunableOp, StreamTimer> { GemmTunableOp() { this->RegisterOp(std::string("Default"), std::make_unique>()); - auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators(); - #ifdef USE_ROCM - for (auto&& [name, op] : GetRocBlasGemmTypeStringAndOps()) { - this->RegisterOp(std::move(name), std::move(op)); - } - - if (validators.find("ROCM_VERSION") == validators.end()) { - std::string rocm_version = ROCM_BUILD_INFO; - getTuningContext()->GetTuningResultsValidator().RegisterValidator( - "ROCM_VERSION", - [rocm_version]() { return rocm_version; }, - [rocm_version](auto&& k) { return rocm_version == k ? OK : FAIL; }); - } - - if (validators.find("GCN_ARCH_NAME") == validators.end()) { - std::string gcn_arch_name = at::cuda::getCurrentDeviceProperties()->gcnArchName; - getTuningContext()->GetTuningResultsValidator().RegisterValidator( - "GCN_ARCH_NAME", - [gcn_arch_name]() { return gcn_arch_name; }, - [gcn_arch_name](auto&& k) { return gcn_arch_name == k ? OK : FAIL; }); - } + bool rocm_validators = false; - if (validators.find("ROCBLAS_VERSION") == validators.end()) { - std::string rocblas_version = c10::str( - XSTRINGIFY(ROCBLAS_VERSION_MAJOR), ".", - XSTRINGIFY(ROCBLAS_VERSION_MINOR), ".", - XSTRINGIFY(ROCBLAS_VERSION_PATCH), "-", - XSTRINGIFY(ROCBLAS_VERSION_TWEAK)); - getTuningContext()->GetTuningResultsValidator().RegisterValidator( - "ROCBLAS_VERSION", - [rocblas_version]() { return rocblas_version; }, - [rocblas_version](auto&& k) { return rocblas_version == k ? OK : FAIL; }); + static const char *env_rocblas = std::getenv("PYTORCH_TUNABLEOP_ROCBLAS_ENABLED"); + if (env_rocblas == nullptr || strcmp(env_rocblas, "1") == 0) { + rocm_validators = true; + for (auto&& [name, op] : GetRocBlasGemmTypeStringAndOps()) { + this->RegisterOp(std::move(name), std::move(op)); + } + AddRocblasValidator(); } -#endif -#if defined(USE_ROCM) - static const char *env = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED"); - if (env == nullptr || strcmp(env, "1") == 0) { + static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED"); + if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) { + rocm_validators = true; // disallow tuning of hipblaslt with c10::complex if constexpr ( !std::is_same_v> && @@ -229,18 +255,11 @@ class GemmTunableOp : public TunableOp, StreamTimer> { this->RegisterOp(std::move(name), std::move(op)); } } + AddHipblasltValidator(); + } - if (validators.find("HIPBLASLT_VERSION") == validators.end()) { - std::string hipblaslt_version = c10::str( - XSTRINGIFY(HIPBLASLT_VERSION_MAJOR), ".", - XSTRINGIFY(HIPBLASLT_VERSION_MINOR), ".", - XSTRINGIFY(HIPBLASLT_VERSION_PATCH), "-", - XSTRINGIFY(HIPBLASLT_VERSION_TWEAK)); - getTuningContext()->GetTuningResultsValidator().RegisterValidator( - "HIPBLASLT_VERSION", - [hipblaslt_version]() { return hipblaslt_version; }, - [hipblaslt_version](auto&& k) { return hipblaslt_version == k ? OK : FAIL; }); - } + if (rocm_validators) { + AddRocmValidator(); } #endif } @@ -256,45 +275,21 @@ class GemmStridedBatchedTunableOp : public TunableOp GemmStridedBatchedTunableOp() { this->RegisterOp(std::string("Default"), std::make_unique>()); - auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators(); - #ifdef USE_ROCM - for (auto&& [name, op] : GetRocBlasGemmStridedBatchedTypeStringAndOps()) { - this->RegisterOp(std::move(name), std::move(op)); - } - - if (validators.find("ROCM_VERSION") == validators.end()) { - std::string rocm_version = ROCM_BUILD_INFO; - getTuningContext()->GetTuningResultsValidator().RegisterValidator( - "ROCM_VERSION", - [rocm_version]() { return rocm_version; }, - [rocm_version](auto&& k) { return rocm_version == k ? OK : FAIL; }); - } - - if (validators.find("GCN_ARCH_NAME") == validators.end()) { - std::string gcn_arch_name = at::cuda::getCurrentDeviceProperties()->gcnArchName; - getTuningContext()->GetTuningResultsValidator().RegisterValidator( - "GCN_ARCH_NAME", - [gcn_arch_name]() { return gcn_arch_name; }, - [gcn_arch_name](auto&& k) { return gcn_arch_name == k ? OK : FAIL; }); - } + bool rocm_validators = false; - if (validators.find("ROCBLAS_VERSION") == validators.end()) { - std::string rocblas_version = c10::str( - XSTRINGIFY(ROCBLAS_VERSION_MAJOR), ".", - XSTRINGIFY(ROCBLAS_VERSION_MINOR), ".", - XSTRINGIFY(ROCBLAS_VERSION_PATCH), "-", - XSTRINGIFY(ROCBLAS_VERSION_TWEAK)); - getTuningContext()->GetTuningResultsValidator().RegisterValidator( - "ROCBLAS_VERSION", - [rocblas_version]() { return rocblas_version; }, - [rocblas_version](auto&& k) { return rocblas_version == k ? OK : FAIL; }); + static const char *env_rocblas = std::getenv("PYTORCH_TUNABLEOP_ROCBLAS_ENABLED"); + if (env_rocblas == nullptr || strcmp(env_rocblas, "1") == 0) { + rocm_validators = true; + for (auto&& [name, op] : GetRocBlasGemmStridedBatchedTypeStringAndOps()) { + this->RegisterOp(std::move(name), std::move(op)); + } + AddRocblasValidator(); } -#endif -#if defined(USE_ROCM) - static const char *env = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED"); - if (env == nullptr || strcmp(env, "1") == 0) { + static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED"); + if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) { + rocm_validators = true; // disallow tuning of hipblaslt with c10::complex if constexpr ( !std::is_same_v> && @@ -303,18 +298,11 @@ class GemmStridedBatchedTunableOp : public TunableOp this->RegisterOp(std::move(name), std::move(op)); } } + AddHipblasltValidator(); + } - if (validators.find("HIPBLASLT_VERSION") == validators.end()) { - std::string hipblaslt_version = c10::str( - XSTRINGIFY(HIPBLASLT_VERSION_MAJOR), ".", - XSTRINGIFY(HIPBLASLT_VERSION_MINOR), ".", - XSTRINGIFY(HIPBLASLT_VERSION_PATCH), "-", - XSTRINGIFY(HIPBLASLT_VERSION_TWEAK)); - getTuningContext()->GetTuningResultsValidator().RegisterValidator( - "HIPBLASLT_VERSION", - [hipblaslt_version]() { return hipblaslt_version; }, - [hipblaslt_version](auto&& k) { return hipblaslt_version == k ? OK : FAIL; }); - } + if (rocm_validators) { + AddRocmValidator(); } #endif } @@ -336,18 +324,8 @@ class ScaledGemmTunableOp : public TunableOp, StreamTimer> for (auto&& [name, op] : GetHipBlasLtScaledGemmTypeStringAndOps()) { this->RegisterOp(std::move(name), std::move(op)); } - - if (validators.find("HIPBLASLT_VERSION") == validators.end()) { - std::string hipblaslt_version = c10::str( - XSTRINGIFY(HIPBLASLT_VERSION_MAJOR), ".", - XSTRINGIFY(HIPBLASLT_VERSION_MINOR), ".", - XSTRINGIFY(HIPBLASLT_VERSION_PATCH), "-", - XSTRINGIFY(HIPBLASLT_VERSION_TWEAK)); - getTuningContext()->GetTuningResultsValidator().RegisterValidator( - "HIPBLASLT_VERSION", - [hipblaslt_version]() { return hipblaslt_version; }, - [hipblaslt_version](auto&& k) { return hipblaslt_version == k ? OK : FAIL; }); - } + AddHipblasltValidator(); + AddRocmValidator(); #endif } diff --git a/aten/src/ATen/cuda/tunable/TunableOp.h b/aten/src/ATen/cuda/tunable/TunableOp.h index 65257974ab0c..f158e11cef0a 100644 --- a/aten/src/ATen/cuda/tunable/TunableOp.h +++ b/aten/src/ATen/cuda/tunable/TunableOp.h @@ -10,6 +10,7 @@ #pragma once #include +#include #include #ifndef _WIN32 @@ -62,7 +63,7 @@ class TunableOp { result = ResultEntry::Default(); } if (result == ResultEntry::Null()) { - TUNABLE_LOG("no result, using default"); + TUNABLE_LOG2("no result, using default"); result = ResultEntry::Default(); } auto iter = ops_.find(result); @@ -87,88 +88,120 @@ class TunableOp { } private: - static void WarmUp(Callable *op, ParamsT* param, size_t num_iter) { + static void WarmUp(Callable *op, const std::vector ¶m, size_t num_iter, size_t &offset) { + TuningContext* ctx = getTuningContext(); + bool do_flush = ctx->IsICacheFlushEnabled(); for (size_t i = 0; i < num_iter; i++) { - TORCH_CHECK(op->Call(param) == OK); + if (do_flush) { + at::cuda::flush_icache(); + } + TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK); } } - static double Profile(Callable *op, ParamsT* param, size_t num_iter) { + static double Profile(Callable *op, const std::vector ¶m, size_t num_iter, size_t &offset) { + TuningContext* ctx = getTuningContext(); + bool do_flush = ctx->IsICacheFlushEnabled(); TimerT timer{}; timer.Start(); for (size_t i = 0; i < num_iter; i++) { - TORCH_CHECK(op->Call(param) == OK); + if (do_flush) { + at::cuda::flush_icache(); + } + TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK); } timer.End(); return timer.Duration() / num_iter; } protected: - bool IsNumericsCheckEnabled() { - static const char *env = getenv("PYTORCH_TUNABLEOP_NUMERICAL_CHECK"); - if (env != nullptr && strcmp(env, "0") == 0) { - return false; - } - return true; - } - virtual ResultEntry FindFastest(const ParamsT* params) { TuningContext* ctx = getTuningContext(); auto op_sig = Signature(); auto params_sig = params->Signature(); - TUNABLE_LOG("finding fastest for ", op_sig, '(', params_sig, ')', " out of ", op_names_.size(), " candidates"); + TUNABLE_LOG2("finding fastest for ", op_sig, '(', params_sig, ')', " out of ", op_names_.size(), " candidates"); auto min_duration_ms = std::numeric_limits::infinity(); std::string id_name = "Default"; + ParamsT* reference_params = nullptr; // calcaulte a reference answer for numerical check - ParamsT* reference_params = params->DeepCopy(); - TORCH_CHECK(ops_[ResultEntry::Default()]->Call(reference_params) == OK); + if (ctx->IsNumericsCheckEnabled()) { + reference_params = params->DeepCopy(false); + TORCH_CHECK(ops_[ResultEntry::Default()]->Call(reference_params) == OK); + } + + // need copies of params to reuse + // make as many copies as will fill the requested rotating buffer size, if requested + // rotating_size guaranteed to be >= 0 even though GetRotatingBufferSize() returns int + size_t rotating_size = ctx->GetRotatingBufferSize(); + bool use_buffer_rotation = (rotating_size > 0); + size_t param_size = params->GetSize(use_buffer_rotation); + size_t param_count = (rotating_size / param_size) + 1; + constexpr size_t MB = 1024*1024; + if (use_buffer_rotation) { + TUNABLE_LOG2("Rotating buffer ", rotating_size/MB, " MiB. ", + "Needed Size: ", param_size/MB, " MiB. ", + "Needed number of param copies: ", param_count); + } + TORCH_CHECK(param_count > 0); + + std::vector reusable_params(param_count); + for (size_t i = 0; i < param_count; i++) { + reusable_params[i] = params->DeepCopy(use_buffer_rotation); + } - // need a copy of params to reuse - ParamsT* reusable_params = params->DeepCopy(); + // for rotating buffer + size_t offset = 0; for (size_t i = 0; i < op_names_.size(); i++) { auto* candidate = ops_[op_names_[i]].get(); // borrow pointer - auto status = candidate->Call(reusable_params); - if (status != OK) { - TUNABLE_LOG("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]); - continue; - } - if (IsNumericsCheckEnabled()) { - ParamsT* numerical_params = params->DeepCopy(); - WarmUp(candidate, numerical_params, 1); + if (ctx->IsNumericsCheckEnabled()) { + ParamsT* numerical_params = params->DeepCopy(false); + auto status = candidate->Call(numerical_params); + if (status != OK) { + TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]); + continue; + } status = reference_params->NumericalCheck(numerical_params); numerical_params->Delete(); if (status != OK) { - TUNABLE_LOG("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]); + TUNABLE_LOG3("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]); + continue; + } + } + else { + auto status = candidate->Call(reusable_params[0]); + if (status != OK) { + TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]); continue; } } // collect a small profile constexpr const int approx_num_iter = 3; - auto approx_duration = Profile(candidate, reusable_params, approx_num_iter); + auto approx_duration = Profile(candidate, reusable_params, approx_num_iter, offset); // bail if too slow if (approx_duration > 2 * min_duration_ms) { - TUNABLE_LOG("├──skip slow instance id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]); + TUNABLE_LOG3("├──skip slow instance id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]); continue; } // for warmup does user set max duration, max iters, or both? + // warmup is allowed to be skipped by setting either iterations or duration to 0 double max_warmup_duration = ctx->GetMaxWarmupDurationMs(); int max_warmup_iter = ctx->GetMaxWarmupIterations(); int warmup_iter = 1; // default - if (max_warmup_duration > 0) { + if (max_warmup_duration >= 0) { int duration_iters = max_warmup_duration / approx_duration; - if (max_warmup_iter > 0) { + if (max_warmup_iter >= 0) { warmup_iter = std::min(max_warmup_iter, duration_iters); } else { warmup_iter = duration_iters; } } - else if (max_warmup_iter > 0) { + else if (max_warmup_iter >= 0) { warmup_iter = max_warmup_iter; } @@ -188,27 +221,34 @@ class TunableOp { else if (max_tuning_iter > 0) { tuning_iter = max_tuning_iter; } + // tuning must run at least 1 iteration + tuning_iter = std::max(1, tuning_iter); // do the full warmup followed by tuning double warmup_ms = warmup_iter * approx_duration; double tuning_ms = tuning_iter * approx_duration; - TUNABLE_LOG("├──tuning using " + TUNABLE_LOG3("├──tuning using " "warmup iters ", warmup_iter, " [", warmup_ms, " ms] " "and tuning iters ", tuning_iter, " [", tuning_ms, " ms] ", "instance id=", i, ", ", op_sig, "(", params_sig, ") ", op_names_[i]); - WarmUp(candidate, reusable_params, warmup_iter); - auto duration_ms = Profile(candidate, reusable_params, tuning_iter); + TUNABLE_LOG3("├──offset at ", offset); + WarmUp(candidate, reusable_params, warmup_iter, offset); + auto duration_ms = Profile(candidate, reusable_params, tuning_iter, offset); if (duration_ms < min_duration_ms) { - TUNABLE_LOG("├──found better instance id=", i, ". " , duration_ms, "ms. ", op_names_[i]); + TUNABLE_LOG3("├──found better instance id=", i, ". " , duration_ms, "ms. ", op_names_[i]); min_duration_ms = duration_ms; id_name = op_names_[i]; } } - reusable_params->Delete(); - reference_params->Delete(); + for (size_t i = 0; i < reusable_params.size(); i++) { + reusable_params[i]->Delete(); + } + if (reference_params) { + reference_params->Delete(); + } - TUNABLE_LOG("└──found fastest for ", op_sig, '(', params_sig, ") ", id_name); + TUNABLE_LOG2("└──found fastest for ", op_sig, '(', params_sig, ") ", id_name); return ResultEntry(id_name, min_duration_ms); } diff --git a/docs/source/cuda.rst b/docs/source/cuda.rst index cee1ec6af2e8..7b9bf536c145 100644 --- a/docs/source/cuda.rst +++ b/docs/source/cuda.rst @@ -144,6 +144,26 @@ Jiterator (beta) jiterator._create_jit_fn jiterator._create_multi_output_jit_fn +TunableOp +--------- + +Some operations could be implemented using more than one library or more than +one technique. For example, a GEMM could be implemented for CUDA or ROCm using +either the cublas/cublasLt libraries or hipblas/hipblasLt libraries, +respectively. How does one know which implementation is the fastest and should +be chosen? That's what TunableOp provides. Certain operators have been +implemented using multiple strategies as Tunable Operators. At runtime, all +strategies are profiled and the fastest is selected for all subsequent +operations. + +See the :doc:`documentation ` for information on how to use it. + +.. toctree:: + :hidden: + + cuda.tunable + + Stream Sanitizer (prototype) ---------------------------- diff --git a/docs/source/cuda.tunable.rst b/docs/source/cuda.tunable.rst new file mode 100644 index 000000000000..52482122ec75 --- /dev/null +++ b/docs/source/cuda.tunable.rst @@ -0,0 +1,32 @@ +.. currentmodule:: torch.cuda.tunable + +TunableOp +========= + +.. note:: + This is a prototype feature, which means it is at an early stage + for feedback and testing, and its components are subject to change. + +Overview +-------- + +.. automodule:: torch.cuda.tunable + +API Reference +------------- + +.. autofunction:: enable +.. autofunction:: is_enabled +.. autofunction:: tuning_enable +.. autofunction:: tuning_is_enabled +.. autofunction:: set_max_tuning_duration +.. autofunction:: get_max_tuning_duration +.. autofunction:: set_max_tuning_iterations +.. autofunction:: get_max_tuning_iterations +.. autofunction:: set_filename +.. autofunction:: get_filename +.. autofunction:: get_results +.. autofunction:: get_validators +.. autofunction:: write_file_on_exit +.. autofunction:: write_file +.. autofunction:: read_file diff --git a/test/test_linalg.py b/test/test_linalg.py index 040b86e60d60..207290f5a6a8 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -4479,6 +4479,72 @@ def test_matmul_small_brute_force_3d_Nd(self, device, dtype): y = make_arg(size_y, noncontiguous=nctg_y) self.check_single_matmul(x, y) + @onlyCUDA + @dtypes(*floating_types_and(torch.half)) + def test_matmul_small_brute_force_tunableop(self, device, dtype): + # disable tunableop buffer rotation for all tests everywhere, it can be slow + import os + os.environ["PYTORCH_TUNABLEOP_ROTATING_BUFFER_SIZE"] = "0" + assert torch.cuda.tunable.is_enabled() is False, "TunableOp should be off by default" + assert torch.cuda.tunable.tuning_is_enabled(), "TunableOp's tuning should be enabled by default" + torch.cuda.tunable.tuning_enable(False) + assert torch.cuda.tunable.tuning_is_enabled() is False + torch.cuda.tunable.tuning_enable(True) + assert torch.cuda.tunable.tuning_is_enabled() + assert torch.cuda.tunable.get_max_tuning_duration() == 30 + assert torch.cuda.tunable.get_max_tuning_iterations() == 100 + + torch.cuda.tunable.enable() + # set these to single iterations to keep it short but still exercise the code + torch.cuda.tunable.set_max_tuning_duration(1) + torch.cuda.tunable.set_max_tuning_iterations(1) + + make_arg = partial(make_tensor, device=device, dtype=dtype) + + for (size_x, size_y), nctg_x, nctg_y in product(self.gen_sizes_matmul(1), (True, False), (True, False)): + x = make_arg(size_x, noncontiguous=nctg_x) + y = make_arg(size_y, noncontiguous=nctg_y) + self.check_single_matmul(x, y) + + filename1 = torch.cuda.tunable.get_filename() + filename2 = "tunableop_results_tmp1.csv" + filename3 = "tunableop_results_tmp2.csv" + ordinal = torch.cuda.current_device() + assert filename1 == f"tunableop_results{ordinal}.csv" + assert len(torch.cuda.tunable.get_validators()) > 0 + assert len(torch.cuda.tunable.get_results()) > 0 + + assert torch.cuda.tunable.write_file() # use default filename + assert torch.cuda.tunable.write_file(filename2) # use custom, one-time filename + torch.cuda.tunable.set_filename(filename3) + assert torch.cuda.tunable.write_file() # use previously set filename + assert torch.cuda.tunable.read_file() # use previously set filename, will ignore duplicates and return True + + with open(filename1) as file1: + file1_contents = file1.read() + with open(filename2) as file2: + file2_contents = file2.read() + with open(filename3) as file3: + file3_contents = file3.read() + assert file1_contents == file2_contents + assert file1_contents == file3_contents + + # remove the files created above to avoid error 'Build left local git repository checkout dirty', ignore errors + for filename in [filename1, filename2, filename3]: + try: + import os + os.remove(filename) + finally: + pass + + # disables TunableOp, no file will be written, restore to default values + torch.cuda.tunable.enable(False) + torch.cuda.tunable.set_filename(filename1) # reset back to default filename for next unit test + torch.cuda.tunable.set_max_tuning_duration(30) + torch.cuda.tunable.set_max_tuning_iterations(100) + assert torch.cuda.tunable.is_enabled() is False, "TunableOp should be off after resetting" + assert torch.cuda.tunable.get_max_tuning_iterations() == 100 + @dtypes(torch.float, torch.complex64) def test_matmul_out_kernel_errors_with_autograd(self, device, dtype): a = torch.empty((256, 512), device=device, dtype=dtype, requires_grad=True).unsqueeze(0) diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index d4dbee20466e..5d8a45f86523 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -1885,6 +1885,21 @@ def _nccl_reduce_scatter( comms: Optional[Sequence[object]], ) -> None: ... def _rocm_is_backward_pass() -> _bool: ... +def _cuda_tunableop_enable(val: _bool) -> None: ... +def _cuda_tunableop_is_enabled() -> _bool: ... +def _cuda_tunableop_tuning_enable(val: _bool) -> None: ... +def _cuda_tunableop_tuning_is_enabled() -> _bool: ... +def _cuda_tunableop_set_max_tuning_duration(duration: _int) -> None: ... +def _cuda_tunableop_get_max_tuning_duration() -> _int: ... +def _cuda_tunableop_set_max_tuning_iterations(iterations: _int) -> None: ... +def _cuda_tunableop_get_max_tuning_iterations() -> _int: ... +def _cuda_tunableop_set_filename(filename: str, insert_device_ordinal: Optional[_bool]) -> None: ... +def _cuda_tunableop_get_filename() -> str: ... +def _cuda_tunableop_write_file(filename: Optional[str]) -> _bool: ... +def _cuda_tunableop_read_file(filename: Optional[str]) -> _bool: ... +def _cuda_tunableop_write_file_on_exit(val: _bool) -> None: ... +def _cuda_tunableop_get_results() -> Tuple[str, str, str, _float]: ... +def _cuda_tunableop_get_validators() -> Tuple[str, str]: ... class _CudaDeviceProperties: name: str diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp index 030c5a2b5ccf..4197c2aa5e81 100644 --- a/torch/csrc/cuda/Module.cpp +++ b/torch/csrc/cuda/Module.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -1403,6 +1404,275 @@ PyObject* THCPModule_rocm_is_backward_pass( END_HANDLE_TH_ERRORS } +PyObject* THCPModule_cuda_tunableop_enable(PyObject* _unused, PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK( + THPUtils_checkBool(arg), + "cuda_tunableop_enable expects a bool, but got ", + THPUtils_typename(arg)); + at::cuda::tunable::getTuningContext()->EnableTunableOp( + THPUtils_unpackBool(arg)); + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_cuda_tunableop_is_enabled( + PyObject* _unused, + PyObject* noarg) { + HANDLE_TH_ERRORS + if (at::cuda::tunable::getTuningContext()->IsTunableOpEnabled()) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_cuda_tunableop_tuning_enable( + PyObject* _unused, + PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK( + THPUtils_checkBool(arg), + "cuda_tunableop_tuning_enable expects a bool, but got ", + THPUtils_typename(arg)); + at::cuda::tunable::getTuningContext()->EnableTuning(THPUtils_unpackBool(arg)); + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_cuda_tunableop_tuning_is_enabled( + PyObject* _unused, + PyObject* noarg) { + HANDLE_TH_ERRORS + if (at::cuda::tunable::getTuningContext()->IsTuningEnabled()) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_cuda_tunableop_write_file_on_exit( + PyObject* _unused, + PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK( + THPUtils_checkBool(arg), + "cuda_tunableop_write_file_on_exit expects a bool, but got ", + THPUtils_typename(arg)); + at::cuda::tunable::getTuningContext()->WriteFileOnExit( + THPUtils_unpackBool(arg)); + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_cuda_tunableop_set_max_tuning_duration( + PyObject* _unused, + PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK( + THPUtils_checkLong(arg), + "cuda_tunableop_set_max_tuning_duration expects an int, but got ", + THPUtils_typename(arg)); + auto duration = static_cast(THPUtils_unpackLong(arg)); + at::cuda::tunable::getTuningContext()->SetMaxTuningDurationMs(duration); + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_cuda_tunableop_get_max_tuning_duration( + PyObject* _unused, + PyObject* noargs) { + HANDLE_TH_ERRORS + return THPUtils_packInt32( + at::cuda::tunable::getTuningContext()->GetMaxTuningDurationMs()); + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_cuda_tunableop_set_max_tuning_iterations( + PyObject* _unused, + PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK( + THPUtils_checkLong(arg), + "cuda_tunableop_set_max_tuning_iterations expects an int, but got ", + THPUtils_typename(arg)); + auto iterations = static_cast(THPUtils_unpackLong(arg)); + at::cuda::tunable::getTuningContext()->SetMaxTuningIterations(iterations); + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_cuda_tunableop_get_max_tuning_iterations( + PyObject* _unused, + PyObject* noargs) { + HANDLE_TH_ERRORS + return THPUtils_packInt32( + at::cuda::tunable::getTuningContext()->GetMaxTuningIterations()); + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_cuda_tunableop_set_filename( + PyObject* _unused, + PyObject* args) { + HANDLE_TH_ERRORS + PyObject* obj_str = nullptr; + PyObject* obj_ord = nullptr; + if (!PyArg_ParseTuple(args, "O|O", &obj_str, &obj_ord)) { + } + TORCH_CHECK( + THPUtils_checkString(obj_str), + "cuda_tunableop_set_filename expects a string, but got ", + THPUtils_typename(obj_str)); + auto filename = THPUtils_unpackString(obj_str); + bool dev = false; + if (obj_ord) { + TORCH_CHECK( + THPUtils_checkBool(obj_ord), + "cuda_tunableop_set_filename expects a bool, but got ", + THPUtils_typename(obj_ord)); + dev = THPUtils_unpackBool(obj_ord); + } + at::cuda::tunable::getTuningContext()->SetFilename(filename, dev); + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_cuda_tunableop_get_filename( + PyObject* _unused, + PyObject* noargs) { + HANDLE_TH_ERRORS + return THPUtils_packString( + at::cuda::tunable::getTuningContext()->GetFilename()); + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_cuda_tunableop_write_file( + PyObject* _unused, + PyObject* args) { + HANDLE_TH_ERRORS + PyObject* str = nullptr; + bool success = false; + if (!PyArg_ParseTuple(args, "|O", &str)) { + } + if (str) { + TORCH_CHECK( + THPUtils_checkString(str), + "cuda_tunableop_write_file expects a string, but got ", + THPUtils_typename(str)); + auto filename = THPUtils_unpackString(str); + success = at::cuda::tunable::getTuningContext()->WriteFile(filename); + } else { + success = at::cuda::tunable::getTuningContext()->WriteFile(); + } + if (success) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_cuda_tunableop_read_file( + PyObject* _unused, + PyObject* args) { + HANDLE_TH_ERRORS + PyObject* str = nullptr; + bool success = false; + if (!PyArg_ParseTuple(args, "|O", &str)) { + } + if (str) { + TORCH_CHECK( + THPUtils_checkString(str), + "cuda_tunableop_read_file expects a string, but got ", + THPUtils_typename(str)); + auto filename = THPUtils_unpackString(str); + success = at::cuda::tunable::getTuningContext()->ReadFile(filename); + } else { + success = at::cuda::tunable::getTuningContext()->ReadFile(); + } + if (success) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_cuda_tunableop_get_results( + PyObject* _unused, + PyObject* noargs) { + HANDLE_TH_ERRORS + auto results = + at::cuda::tunable::getTuningContext()->GetTuningResultsManager().Dump(); + size_t result_size = 0; + for (const auto& [op_sig, kernelmap] : results) { + result_size += kernelmap.size(); + } + THPObjectPtr outer_tuple(PyTuple_New(result_size)); + if (!outer_tuple) + throw python_error(); + size_t result_index = 0; + for (const auto& [op_sig, kernelmap] : results) { + for (const auto& [param_sig, result] : kernelmap) { + THPObjectPtr inner_tuple(PyTuple_New(4)); + if (!inner_tuple) + throw python_error(); + PyObject* obj_op_sig = THPUtils_packString(op_sig); + if (!obj_op_sig) + throw python_error(); + PyObject* obj_param_sig = THPUtils_packString(param_sig); + if (!obj_param_sig) + throw python_error(); + PyObject* obj_result_key = THPUtils_packString(result.GetKey()); + if (!obj_result_key) + throw python_error(); + PyObject* obj_result_time = PyFloat_FromDouble(result.GetTime()); + if (!obj_result_time) + throw python_error(); + PyTuple_SET_ITEM(inner_tuple.get(), 0, obj_op_sig); + PyTuple_SET_ITEM(inner_tuple.get(), 1, obj_param_sig); + PyTuple_SET_ITEM(inner_tuple.get(), 2, obj_result_key); + PyTuple_SET_ITEM(inner_tuple.get(), 3, obj_result_time); + PyTuple_SET_ITEM( + outer_tuple.get(), result_index++, inner_tuple.release()); + } + } + return outer_tuple.release(); + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_cuda_tunableop_get_validators( + PyObject* _unused, + PyObject* noargs) { + HANDLE_TH_ERRORS + auto validators = at::cuda::tunable::getTuningContext() + ->GetTuningResultsValidator() + .GetAllValidators(); + THPObjectPtr outer_tuple(PyTuple_New(validators.size())); + if (!outer_tuple) + throw python_error(); + size_t validator_index = 0; + for (const auto& [key, val] : validators) { + THPObjectPtr inner_tuple(PyTuple_New(2)); + if (!inner_tuple) + throw python_error(); + PyObject* obj_key = THPUtils_packString(key); + if (!obj_key) + throw python_error(); + PyObject* obj_val = THPUtils_packString(val); + if (!obj_val) + throw python_error(); + PyTuple_SET_ITEM(inner_tuple.get(), 0, obj_key); + PyTuple_SET_ITEM(inner_tuple.get(), 1, obj_val); + PyTuple_SET_ITEM( + outer_tuple.get(), validator_index++, inner_tuple.release()); + } + return outer_tuple.release(); + END_HANDLE_TH_ERRORS +} + static PyObject* THCPModule_isCurrentStreamCapturing_wrap( PyObject* self, PyObject* noargs) { @@ -1576,6 +1846,66 @@ static struct PyMethodDef _THCPModule_methods[] = { THCPModule_rocm_is_backward_pass, METH_NOARGS, nullptr}, + {"_cuda_tunableop_enable", + THCPModule_cuda_tunableop_enable, + METH_O, + nullptr}, + {"_cuda_tunableop_is_enabled", + THCPModule_cuda_tunableop_is_enabled, + METH_NOARGS, + nullptr}, + {"_cuda_tunableop_tuning_enable", + THCPModule_cuda_tunableop_tuning_enable, + METH_O, + nullptr}, + {"_cuda_tunableop_tuning_is_enabled", + THCPModule_cuda_tunableop_tuning_is_enabled, + METH_NOARGS, + nullptr}, + {"_cuda_tunableop_write_file_on_exit", + THCPModule_cuda_tunableop_write_file_on_exit, + METH_O, + nullptr}, + {"_cuda_tunableop_set_max_tuning_duration", + THCPModule_cuda_tunableop_set_max_tuning_duration, + METH_O, + nullptr}, + {"_cuda_tunableop_get_max_tuning_duration", + THCPModule_cuda_tunableop_get_max_tuning_duration, + METH_NOARGS, + nullptr}, + {"_cuda_tunableop_set_max_tuning_iterations", + THCPModule_cuda_tunableop_set_max_tuning_iterations, + METH_O, + nullptr}, + {"_cuda_tunableop_get_max_tuning_iterations", + THCPModule_cuda_tunableop_get_max_tuning_iterations, + METH_NOARGS, + nullptr}, + {"_cuda_tunableop_set_filename", + THCPModule_cuda_tunableop_set_filename, + METH_VARARGS, + nullptr}, + {"_cuda_tunableop_get_filename", + THCPModule_cuda_tunableop_get_filename, + METH_NOARGS, + nullptr}, + {"_cuda_tunableop_write_file", + THCPModule_cuda_tunableop_write_file, + METH_VARARGS, + nullptr}, + {"_cuda_tunableop_read_file", + THCPModule_cuda_tunableop_read_file, + METH_VARARGS, + nullptr}, + {"_cuda_tunableop_get_results", + THCPModule_cuda_tunableop_get_results, + METH_NOARGS, + nullptr}, + {"_cuda_tunableop_get_validators", + THCPModule_cuda_tunableop_get_validators, + METH_NOARGS, + nullptr}, {nullptr}}; PyMethodDef* THCPModule_methods() { diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index 2f2784f26c7a..ec4c0297a4b4 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -1462,7 +1462,7 @@ def addmm_kernel_impl(*args, **kwargs): _lazy_call(_register_triton_kernels) -from . import amp, jiterator, nvtx, profiler, sparse +from . import amp, jiterator, nvtx, profiler, sparse, tunable __all__ = [ # Typed storage and tensors @@ -1575,5 +1575,6 @@ def addmm_kernel_impl(*args, **kwargs): "stream", "streams", "synchronize", + "tunable", "utilization", ] diff --git a/torch/cuda/tunable.py b/torch/cuda/tunable.py new file mode 100644 index 000000000000..0f7e0a1f3725 --- /dev/null +++ b/torch/cuda/tunable.py @@ -0,0 +1,242 @@ +r""" +This module exposes a TunableOp interface. + +Some operations, such as GEMMs, could be implemented using more than one library +or more than one technique. For example, a GEMM could be implemented for CUDA or +ROCm using either the blas or blasLt libraries. Further, ROCm's rocblas and +hipblaslt libraries allow the user to query for all possible algorithms and then +choose one. How does one know which implementation is the fastest and should be +chosen? That's what TunableOp provides. + +Enabling TunableOp and Tuning Separately +======================================== + +The TunableOp feature is enabled separately from enabling the tuning phase +itself. Enabling TunableOp means that PyTorch will replace any standard +operators with their Tunable implementations. Any call to a TunableOp first +checks whether it has already been tuned for the given operator inputs. If so, +it will immediately call the tuned operation; no further tuning will take place +even when the tuning setting is enabled. Instead if no tuning result is found, +and tuning is enabled, the TunableOp will benchmark every registered +implementation of that operator for the given set of inputs and select the +fastest. + +File Input and Output +===================== + +The first time any TunableOp is invoked, the internal database of tuned +operations will be prepared by attempting to read the results from the given +file. The default filename is 'tunableop_results.csv'. To support tuning when +multiple GPUs are used across multiple processes, the GPU device ordinal is +automatically inserted into the filename to avoid multiple processes overwriting +the same file. + +If tuning is enabled and new tunings are discovered during the course of your +workload, it will also write out to this same filename with all tunings, both +the ones it read in at startup as well as the new ones found at runtime. This +can be used, for example, to build up a tunings file across many workloads by +reusing the same file. The output file is automatically created when the +application terminates. This behavior can be controlled by the C++ and Python +APIs but not the environment variables. + +Assuming you specified a filename, you'll end up with a CSV file with contents +like so:: + + Validator,PT_VERSION,2.2.0 + Validator,ROCM_VERSION,6.0.0.0-12969-1544e39 + Validator,HIPBLASLT_VERSION,0.6.0-a9c5cc7 + Validator,ROCBLAS_VERSION,4.0.0-72e57364-dirty + GemmTunableOp_float_NT,nt_25088_4096_64,1219,1.262 + GemmTunableOp_float_NT,nt_4096_4096_64,1216,0.033 + +Note the "Validator" lines. If you change a library verison, or ROCm version, or +PyTorch version, TunableOp will detect this and reject the tunings file because +the prior tunings are likely affected by other software changes. + +The remaining lines are the tuned solutions for each TunableOp encountered +during your execution. Each line consists of 4 comma-separated fields: operator +name, operator parameters, solution name, and average execution time. The +execution time is an optional field. The CSV file can be edited, but with +caution. For example, the solution name (field 3) can be changed to "Default" +and it will fall back to the original PyTorch untuned implementation. Or, in the +case of ROCm's hipBLAS or hipBLASLt libraries, if you know the specific solution +index you can override the solution that TunableOp selected by replacing the +value. The operator name and parameters (fields 1 and 2) are internally named +and should not be modified. In the case of GemmTunableOp, field 1 indicates the +datatype and whether the inputs are transposed (T) or not (N) and field 2 +indicates the M, N, K input shapes. + +There is an option to enable verbose output but it is only recommended for +debugging purposes. This will produce a lot of diagnostic messages but may be +useful to see if TunableOp is being used at all. Otherwise, TunableOp is +completely silent, besides file output, unless there is a warning or error +during its use. The verbose option is only available by setting the environment +variable PYTORCH_TUNABLEOP_VEROBSE=1. + +A Note on Tuning Behavior +========================= + +Tuning an operator consists of iterating through the list or registered +implementations and profiling each one. The profile is established by running a +single implementation in a loop multiple times and taking the average execution +time. + +By default, each possible solution for a given operator will be run for either +100 iterations or as many iterations that can be run within 30ms, whichever is +smaller, and its average execution will be calculated. The fastest solution +among all that were successfully profiled will be chosen. A profile might fail +if the given solution doesn't achieve the same accuracy as the default +implementation or if the solution returns an error code. + +Current Tunable Operators +========================= + +TunableGemm for ROCm +-------------------- + +Currently only a TunableGemm for ROCm is implemented. Note that CUDA builds of +PyTorch will function correctly when using TunableOp but the only solution +available to CUDA builds is the 'Default' implementation i.e. the original +cuBLAS default, now called through TunableOp. Any call to at::cuda::blas::gemm() +or ::bgemm() will be routed through TunableOp when enabled. Calling gemm() for a +given set of input arguments (transa, transb, m, n, k) will attempt to use the +fastest available implementation across both rocblas and hipblaslt. + +Tuning Context +============== + +The behavior of TunableOp is currently manipulated through environment +variables, the C++ interface of at::cuda::tunable::getTuningContext(), or the +torch.cuda.tunable python interfaces that wrap the C++ TuningContext. The +environment variables take precedence over any setting you manipulate using the +C++ or Python APIs. + +""" +from typing import Optional, Tuple + +import torch + + +__all__ = [ + "enable", + "is_enabled", + "tuning_enable", + "tuning_is_enabled", + "set_max_tuning_duration", + "get_max_tuning_duration", + "set_max_tuning_iterations", + "get_max_tuning_iterations", + "set_filename", + "get_filename", + "get_results", + "get_validators", + "write_file_on_exit", + "write_file", + "read_file", +] + + +def enable(val: bool = True) -> None: + r"""This is the big on/off switch for all TunableOp implementations.""" + torch._C._cuda_tunableop_enable(val) + + +def is_enabled() -> bool: + r"""Returns whether the TunableOp feature is enabled.""" + return torch._C._cuda_tunableop_is_enabled() + + +def tuning_enable(val: bool = True) -> None: + r"""Enable tuning of TunableOp implementations. + + When enabled, if a tuned entry isn't found, run the tuning step and record + the entry. + """ + torch._C._cuda_tunableop_tuning_enable(val) + + +def tuning_is_enabled() -> bool: + r"""Returns whether TunableOp implementations can be tuned.""" + return torch._C._cuda_tunableop_tuning_is_enabled() + + +def set_max_tuning_duration(duration: int) -> None: + r"""Set max time in milliseconds to spend tuning a given solution. + + If both max tuning duration and iterations are set, the smaller of the two + will be honored. At minimum 1 tuning iteration will always be run. + """ + torch._C._cuda_tunableop_set_max_tuning_duration(duration) + + +def get_max_tuning_duration() -> int: + r"""Get max time to spend tuning a given solution.""" + return torch._C._cuda_tunableop_get_max_tuning_duration() + + +def set_max_tuning_iterations(iterations: int) -> None: + r"""Set max number of iterations to spend tuning a given solution. + + If both max tuning duration and iterations are set, the smaller of the two + will be honored. At minimum 1 tuning iteration will always be run. + """ + torch._C._cuda_tunableop_set_max_tuning_iterations(iterations) + + +def get_max_tuning_iterations() -> int: + r"""Get max iterations to spend tuning a given solution.""" + return torch._C._cuda_tunableop_get_max_tuning_iterations() + + +def set_filename(filename: str, insert_device_ordinal: bool = False) -> None: + r"""Set the filename to use for input/output of tuning results. + + If :attr:`insert_device_ordinal` is ``True`` then the current device ordinal + will be added to the given filename automatically. This can be used in a + 1-process-per-gpu cenario to ensure all processes write to a separate file. + """ + torch._C._cuda_tunableop_set_filename(filename, insert_device_ordinal) + + +def get_filename() -> str: + r"""Get the results filename.""" + return torch._C._cuda_tunableop_get_filename() + + +def get_results() -> Tuple[str, str, str, float]: + r"""Return all TunableOp results.""" + return torch._C._cuda_tunableop_get_results() + + +def get_validators() -> Tuple[str, str]: + r"""Return the TunableOp validators.""" + return torch._C._cuda_tunableop_get_validators() + + +def write_file_on_exit(val: bool) -> None: + r"""During Tuning Context destruction, write file to disk. + + This is useful as a final flush of your results to disk if your application + terminates as result of normal operation or an error. Manual flushing of + your results can be achieved by manually calling ``write_file()``.""" + torch._C._cuda_tunableop_write_file_on_exit(val) + + +def write_file(filename: Optional[str] = None) -> bool: + r"""Write results to a CSV file. + + If :attr:`filename` is not given, ``get_filename()`` is called. + """ + if filename is None: + filename = get_filename() + return torch._C._cuda_tunableop_write_file(filename) + + +def read_file(filename: Optional[str] = None) -> bool: + r"""Read results from a TunableOp CSV file. + + If :attr:`filename` is not given, ``get_filename()`` is called. + """ + if filename is None: + filename = get_filename() + return torch._C._cuda_tunableop_read_file(filename) From ac568fc0077748264538fda9ce99c856b7ca944a Mon Sep 17 00:00:00 2001 From: eqy Date: Mon, 3 Jun 2024 22:42:02 +0000 Subject: [PATCH 286/706] [CUDNN] Remove defunct cuDNN V8 API build flag (#120006) The flag basically does nothing following #95722 Let's see if the quantization tests break CC @malfet @atalmanagement Pull Request resolved: https://github.com/pytorch/pytorch/pull/120006 Approved by: https://github.com/malfet --- cmake/Summary.cmake | 1 - defs.bzl | 1 - test/quantization/core/test_quantized_op.py | 49 ++++++++++----------- 3 files changed, 23 insertions(+), 28 deletions(-) diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake index 99b6521328d6..bc15f70ad1f5 100644 --- a/cmake/Summary.cmake +++ b/cmake/Summary.cmake @@ -71,7 +71,6 @@ function(caffe2_print_configuration_summary) message(STATUS " Split CUDA : ${BUILD_SPLIT_CUDA}") message(STATUS " CUDA static link : ${CAFFE2_STATIC_LINK_CUDA}") message(STATUS " USE_CUDNN : ${USE_CUDNN}") - message(STATUS " USE_EXPERIMENTAL_CUDNN_V8_API: ${USE_EXPERIMENTAL_CUDNN_V8_API}") message(STATUS " USE_CUSPARSELT : ${USE_CUSPARSELT}") message(STATUS " CUDA version : ${CUDA_VERSION}") message(STATUS " USE_FLASH_ATTENTION : ${USE_FLASH_ATTENTION}") diff --git a/defs.bzl b/defs.bzl index 6ea4b1219325..5e8923556af0 100644 --- a/defs.bzl +++ b/defs.bzl @@ -33,7 +33,6 @@ default_compiler_flags = [ "-DTH_INDEX_BASE=0", "-DMAGMA_V2", "-DNO_CUDNN_DESTROY_HANDLE", - "-DUSE_EXPERIMENTAL_CUDNN_V8_API", # enable cudnn v8 api "-DUSE_FBGEMM", "-DUSE_PYTORCH_QNNPACK", # The dynamically loaded NVRTC trick doesn't work in fbcode, diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py index d59f1fffd926..6671b6634e00 100644 --- a/test/quantization/core/test_quantized_op.py +++ b/test/quantization/core/test_quantized_op.py @@ -21,6 +21,7 @@ import torch.testing._internal.hypothesis_utils as hu hu.assert_deadline_disabled() +from torch.testing._internal.common_cuda import SM80OrLater from torch.testing._internal.common_utils import TestCase from torch.testing._internal.common_utils import IS_PPC, TEST_WITH_UBSAN, IS_MACOS, BUILD_WITH_CAFFE2, IS_SANDCASTLE from torch.testing._internal.common_quantization import skipIfNoFBGEMM, skipIfNoQNNPACK, skipIfNoONEDNN @@ -31,10 +32,12 @@ qengine_is_onednn, ) from torch.ao.quantization import PerChannelMinMaxObserver -from torch.testing._internal.common_cuda import TEST_CUDNN, TEST_CUDA +from torch.testing._internal.common_cuda import TEST_CUDNN, TEST_CUDNN_VERSION, TEST_CUDA from torch.testing._internal.optests import opcheck import torch.backends.xnnpack +from torch.utils.cpp_extension import ROCM_HOME + from typing import Optional np_dtype = { @@ -43,6 +46,8 @@ torch.qint32 : np.int32 } +TEST_ROCM = TEST_CUDA and torch.version.hip is not None and ROCM_HOME is not None + class PointwisePostOp(NamedTuple): binary_attr : str = "none" alpha : float = 1.0 @@ -905,9 +910,8 @@ def test_qadd_relu_same_qparams(self): """Tests the correctness of the cudnn add and add_relu op (Similar to test_qadd_relu_different_qparams, will probably merge in the future)""" @unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.") - @unittest.skip("Local only - currently the test_qadd_relu_cudnn op is bulid " - "with USE_EXPERIMENTAL_CUDNN_V8_API, we can enable the test " - "after it is built by default") + @unittest.skipIf(not SM80OrLater, "requires sm80 or later.") + @unittest.skipIf(TEST_ROCM, "not supported on rocm.") def test_qadd_relu_cudnn(self): dtype = torch.qint8 add_relu = torch.ops.quantized.add_relu @@ -940,9 +944,8 @@ def test_qadd_relu_cudnn(self): """Tests the correctness of the cudnn add and add_relu op for nhwc format""" @unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.") - @unittest.skip("Local only - currently the test_qadd_relu_cudnn_nhwc op is bulid " - "with USE_EXPERIMENTAL_CUDNN_V8_API, we can enable the test " - "after it is built by default") + @unittest.skipIf(not SM80OrLater, "requires sm80 or later.") + @unittest.skipIf(TEST_ROCM, "not supported on rocm.") def test_qadd_relu_cudnn_nhwc(self): dtype = torch.qint8 add_relu = torch.ops.quantized.add_relu @@ -1379,7 +1382,7 @@ def test_max_pool1d(self, X, kernel, stride, dilation, padding, ceil_mode): self.assertEqual(a_ref, a_hat.dequantize(), msg="ops.quantized.max_pool1d results are off") - # TODO: merge this test with test_max_pool2d when USE_EXPERIMENTAL_CUDNN_V8_API flag is enabled in CI + # TODO: merge this test with test_max_pool2d """Tests 2D cudnn max pool operation on quantized tensors.""" @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=3, max_dims=4, min_side=1, max_side=10), @@ -1394,9 +1397,8 @@ def test_max_pool1d(self, X, kernel, stride, dilation, padding, ceil_mode): padding=st.integers(0, 2), ceil_mode=st.booleans()) @unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.") - @unittest.skip("Local only - currently the qconv2d_cudnn op is bulid " - "with USE_EXPERIMENTAL_CUDNN_V8_API, we can enable the test " - "after it is built by default") + @unittest.skipIf(TEST_CUDNN_VERSION <= 90100, "cuDNN maxpool2d mishandles -128 before v90100") + @unittest.skipIf(TEST_ROCM, "not supported on rocm.") def test_max_pool2d_cudnn(self, X, kernel, stride, dilation, padding, ceil_mode): X, (scale, zero_point, torch_type) = X assume(kernel // 2 >= padding) # Kernel cannot be overhanging! @@ -4050,9 +4052,8 @@ def test_qlinear_with_input_q_dq_qweight_dq_output_fp32( use_channelwise=st.sampled_from([False])) # channelwise currently not supported for qlinear cudnn @skipIfNoFBGEMM @unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.") - @unittest.skip("Local only - currently the qlinear_cudnn op is bulid " - "with USE_EXPERIMENTAL_CUDNN_V8_API, we can enable the test " - "after it is built by default") + @unittest.skipIf(not SM80OrLater, "requires sm80 or later.") + @unittest.skipIf(TEST_ROCM, "not supported on rocm.") # TODO: check with yang regarding CUDNN flags def test_qlinear_cudnn(self, batch_size, input_channels, output_channels, use_bias, use_relu, use_multi_dim_input, use_channelwise): @@ -5427,9 +5428,8 @@ def test_qconv2d_add_relu(self): use_channelwise=st.sampled_from([False])) @skipIfNoFBGEMM @unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.") - @unittest.skip("Local only - currently the qconv2d_cudnn op is bulid " - "with USE_EXPERIMENTAL_CUDNN_V8_API, we can enable the test " - "after it is built by default") + @unittest.skipIf(not SM80OrLater, "requires sm80 or later.") + @unittest.skipIf(TEST_ROCM, "not supported on rocm.") def test_qconv2d_cudnn( self, batch_size, @@ -5510,9 +5510,8 @@ def test_qconv2d_cudnn( use_channelwise=st.sampled_from([False])) @skipIfNoFBGEMM @unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.") - @unittest.skip("Local only - currently the qconv2d_cudnn op is bulid " - "with USE_EXPERIMENTAL_CUDNN_V8_API, we can enable the test " - "after it is built by default") + @unittest.skipIf(not SM80OrLater, "requires sm80 or later.") + @unittest.skipIf(TEST_ROCM, "not supported on rocm.") def test_qconv2d_relu_cudnn( self, batch_size, @@ -6245,9 +6244,8 @@ def test_qconv1d_relu( use_channelwise=st.sampled_from([False])) @skipIfNoFBGEMM @unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.") - @unittest.skip("Local only - currently the qconv1d_cudnn op is bulid " - "with USE_EXPERIMENTAL_CUDNN_V8_API, we can enable the test " - "after it is built by default") + @unittest.skipIf(not SM80OrLater, "requires sm80 or later.") + @unittest.skipIf(TEST_ROCM, "not supported on rocm.") def test_qconv1d_cudnn( self, batch_size, @@ -6319,9 +6317,8 @@ def test_qconv1d_cudnn( use_channelwise=st.sampled_from([False])) @skipIfNoFBGEMM @unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.") - @unittest.skip("Local only - currently the qconv1d_cudnn op is bulid " - "with USE_EXPERIMENTAL_CUDNN_V8_API, we can enable the test " - "after it is built by default") + @unittest.skipIf(not SM80OrLater, "requires sm80 or later.") + @unittest.skipIf(TEST_ROCM, "not supported on rocm.") def test_qconv1d_relu_cudnn( self, batch_size, From b42cfcabc4e7216c5941ebab19c5000e5ecea4eb Mon Sep 17 00:00:00 2001 From: Joel Schlosser Date: Mon, 3 Jun 2024 15:03:30 -0400 Subject: [PATCH 287/706] Lift jagged -> padded dense forward / backward kernels from fbgemm_gpu (#125946) PyTorch can't depend on `fbgemm_gpu` as a dependency because `fbgemm_gpu` already has a dependency on PyTorch. So this PR copy / pastes kernels from `fbgemm_gpu`: * `dense_to_jagged_forward()` as CUDA registration for new ATen op `_padded_dense_to_jagged_forward()` * `jagged_to_padded_dense_forward()` as CUDA registration for new ATen op `_jagged_to_padded_dense_forward()` CPU impls for these new ATen ops will be added in a follow-up PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/125946 Approved by: https://github.com/davidberard98 --- aten/src/ATen/native/native_functions.yaml | 10 + .../cuda/NestedTensorTransformerFunctions.cu | 1081 +++++++++++++++++ ...asDecompTest.test_has_decomposition.expect | 2 + test/test_nestedtensor.py | 25 + 4 files changed, 1118 insertions(+) diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index a051f43e87eb..54b12a9a0b0c 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -14644,6 +14644,16 @@ NestedTensorCUDA: NestedTensor_to_padded_tensor_cuda autogen: to_padded_tensor.out +- func: _jagged_to_padded_dense_forward(Tensor values, Tensor[] offsets, SymInt[] max_lengths, float padding_value=0.0) -> Tensor + variants: function + dispatch: + CUDA: _fbgemm_jagged_to_padded_dense_forward + +- func: _padded_dense_to_jagged_forward(Tensor dense, Tensor[] offsets, SymInt? total_L=None) -> Tensor + variants: function + dispatch: + CUDA: _fbgemm_dense_to_jagged_forward_symint + - func: _nested_tensor_softmax_with_shape(Tensor self, Tensor query) -> Tensor dispatch: NestedTensorCPU: NestedTensor_softmax_dropout diff --git a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu index 56cac2a89803..c425cf504dc9 100644 --- a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu +++ b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu @@ -1,3 +1,4 @@ +#include #include #include @@ -11,6 +12,7 @@ #include #include +#include #include #include @@ -462,5 +464,1084 @@ template void add_padding_kernelLauncher( const int batch_size, const int output_batch_size); +// NB: The following code covers jagged <-> padded dense conversions and was lifted +// from fbgemm_gpu. For more details, see +// https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu/src/jagged_tensor_ops + +// Passing lambda exp argument by value instead of by reference to avoid +// "internal compiler error: in maybe_undo_parenthesized_ref" error for specific +// compiler version. +#define JAGGED_TENSOR_DISPATCH_DIMS() \ + AT_DISPATCH_INDEX_TYPES(x_offsets[0].scalar_type(), "jagged_indices", [=] { \ + switch (num_jagged_dim) { \ + case 1: \ + INVOKE_KERNEL_WITH_DIM(1); \ + break; \ + case 2: \ + INVOKE_KERNEL_WITH_DIM(2); \ + break; \ + case 3: \ + INVOKE_KERNEL_WITH_DIM(3); \ + break; \ + case 4: \ + INVOKE_KERNEL_WITH_DIM(4); \ + break; \ + case 5: \ + INVOKE_KERNEL_WITH_DIM(5); \ + break; \ + default: \ + TORCH_CHECK( \ + false, "unsupported number of jagged dim ", num_jagged_dim); \ + } \ + }); + +inline std::string torch_tensor_device_name(const at::Tensor& ten) { + return c10::DeviceTypeName(ten.device().type()); +} + +inline std::string torch_tensor_device_name( + const c10::optional& ten) { + if (ten.has_value()) { + return torch_tensor_device_name(ten.value()); + } else { + return "N/A"; + } +} + +inline bool torch_tensor_on_cuda_gpu_check(const at::Tensor& ten) { + return ten.is_cuda(); +} + +inline bool torch_tensor_on_cuda_gpu_check( + const c10::optional& ten) { + return !ten.has_value() || torch_tensor_on_cuda_gpu_check(ten.value()); +} + +#define TENSOR_ON_CUDA_GPU(x) \ + TORCH_CHECK( \ + torch_tensor_on_cuda_gpu_check(x), \ + #x " must be a CUDA tensor; it is currently on device ", \ + torch_tensor_device_name(x)) + +// A wrapper class for passing dynamically sized dimension information (e.g. +// tensor.dims()) from the host to device. +constexpr size_t kStackArrayMaxDims = 5; + +template +struct StackArray { + T vals[kStackArrayMaxDims]; + size_t ndim; +}; + +// Warp size +#ifdef USE_ROCM +static constexpr int32_t kWarpSize = 64; +#else +static constexpr int32_t kWarpSize = 32; +#endif +// Max thread num in one thread block +static constexpr int32_t kMaxThreads = 1024; + +#define DEVICE_INLINE __device__ C10_ALWAYS_INLINE + +__host__ DEVICE_INLINE int32_t div_round_up(int32_t a, int32_t b) { + return (a + b - 1) / b; +} + +__host__ DEVICE_INLINE int32_t round_down(int32_t a, int32_t b) { + return a / b * b; +} + +inline std::tuple> check_shape_and_partition_( + const Tensor& values, + const std::vector& offsets, + const Tensor& dense_tensor) { + const int outer_dense_size = dense_tensor.size(0); + TORCH_CHECK( + outer_dense_size == offsets[0].numel() - 1, + "outer_dense_size, ", + outer_dense_size, + " != offsets[0].numel() - 1, ", + offsets[0].numel() - 1); + const int inner_dense_size = dense_tensor.size(-1); + TORCH_CHECK( + inner_dense_size == values.size(-1), + "inner_dense_size, ", + inner_dense_size, + " != values.size(-1), ", + values.size(-1)); + const int jagged_folded_size = + dense_tensor.numel() / (outer_dense_size * inner_dense_size); + + const int threads_x = + inner_dense_size >= kWarpSize / 2 ? kWarpSize : inner_dense_size; + const int threads_y = kMaxThreads / kWarpSize; + const dim3 blocks( + div_round_up(outer_dense_size * jagged_folded_size, threads_y)); + + StackArray jagged_dims_tensor; + const int num_jagged_dim = dense_tensor.dim() - 2; + TORCH_CHECK(num_jagged_dim <= kStackArrayMaxDims); + jagged_dims_tensor.ndim = num_jagged_dim; + std::memcpy( + &(jagged_dims_tensor.vals[0]), + dense_tensor.sizes().data() + 1, + num_jagged_dim * sizeof(int64_t)); + return {dim3(threads_x, threads_y), blocks, jagged_dims_tensor}; +} + +template +DEVICE_INLINE bool walk_down_tensor_storage_tree_( + int& offset, + const int flattened_jagged_idx, + const StackArray& jagged_dims, + const StackArray& x_offsets) { + // compute coorindates + int jagged_coords[NUM_JAGGED_DIM]; + int j_temp = flattened_jagged_idx; +#pragma unroll + for (int d = NUM_JAGGED_DIM - 1; d >= 0; --d) { + const int jagged_size = jagged_dims.vals[d]; + jagged_coords[d] = j_temp % jagged_size; + j_temp /= jagged_size; + } + + // walk down the tree + bool is_zero = false; +#pragma unroll + for (int d = 0; d < NUM_JAGGED_DIM; ++d) { + const int begin = x_offsets.vals[d][offset]; + const int end = x_offsets.vals[d][offset + 1]; + if (jagged_coords[d] >= end - begin) { + is_zero = true; + break; + } + offset = begin + jagged_coords[d]; + } + return is_zero; +} + +// output = f(x, y) where x is jagged, y is dense, and output is dense. +// A generic elementwise operation between a jagged tensor and a dense tensor +// This kernel assumes jagged dims are clustered together, preceded by outer +// dense dimensions and followed by inner dense dimensions. +// The outer/inner dense dimensions, and jagged dimensions in between are +// assumed to be folded so physically the dense tensor is 3D and the value of +// jagged tensor is 2D. +// To support arbitrary number of jagged dimensions, we pass a vector of +// pointers to offset tensors (this is ugly and probably we can use nested +// tensor here). +// This kernel parallelizes the (folded) inner dense dimension across +// blockDim.x so the inner dense dimension should be similar to or bigger than +// warp size. +// We rely on compiler unrolling the compiler time constant NUM_JAGGED_DIM. +template +__global__ +__launch_bounds__(kMaxThreads) void jagged_dense_elementwise_dense_output_kernel_( + const at::PackedTensorAccessor32 + x_values, + StackArray x_offsets, + const at::PackedTensorAccessor32 y, + at::PackedTensorAccessor32 output, + StackArray jagged_dims, + F f, + const scalar_t padding_value) { + const int outer_dense_size = y.size(0); + const int jagged_folded_size = y.size(1); + const int inner_dense_size = y.size(2); + + const int outer_begin = blockIdx.x * blockDim.y + threadIdx.y; + const int outer_stride = gridDim.x * blockDim.y; + for (int outer = outer_begin; outer < outer_dense_size * jagged_folded_size; + outer += outer_stride) { + const int oidx = outer / jagged_folded_size; + const int jidx = outer % jagged_folded_size; + + int offset = oidx; + const bool is_zero = walk_down_tensor_storage_tree_( + offset, jidx, jagged_dims, x_offsets); + + if (is_zero) { + int iidx; + for (iidx = threadIdx.x; iidx * 2 + 1 < inner_dense_size; + iidx += blockDim.x) { + output[oidx][jidx][2 * iidx] = + f(padding_value, y[oidx][jidx][2 * iidx]); + output[oidx][jidx][2 * iidx + 1] = + f(padding_value, y[oidx][jidx][2 * iidx + 1]); + } + if (iidx * 2 + 1 == inner_dense_size) { + output[oidx][jidx][2 * iidx] = + f(padding_value, y[oidx][jidx][2 * iidx]); + } + } else { + int iidx; + for (iidx = threadIdx.x; iidx * 2 + 1 < inner_dense_size; + iidx += blockDim.x) { + output[oidx][jidx][2 * iidx] = + f(x_values[offset][2 * iidx], y[oidx][jidx][2 * iidx]); + output[oidx][jidx][2 * iidx + 1] = + f(x_values[offset][2 * iidx + 1], y[oidx][jidx][2 * iidx + 1]); + } + if (iidx * 2 + 1 == inner_dense_size) { + output[oidx][jidx][2 * iidx] = + f(x_values[offset][2 * iidx], y[oidx][jidx][2 * iidx]); + } + } + } +} + +template +void jagged_dense_elementwise_dense_output_( + const Tensor& x_values, + const std::vector& x_offsets, + const Tensor& y, + const Tensor& output, + F f, + const scalar_t padding_value = static_cast(0)) { + TENSOR_ON_CUDA_GPU(x_values); + for (auto& x_offset : x_offsets) { + TENSOR_ON_CUDA_GPU(x_offset); + } + + const int num_jagged_dim = y.dim() - 2; + TORCH_CHECK( + x_offsets.size() == static_cast(num_jagged_dim), + "x_offsets.size(), ", + x_offsets.size(), + " != num_jagged_dim ", + num_jagged_dim); + + if (y.numel() == 0) { + return; + } + + dim3 threads, blocks; + StackArray jagged_dims_tensor; + std::tie(threads, blocks, jagged_dims_tensor) = + check_shape_and_partition_(x_values, x_offsets, y); + + // Canonicalize y and output to 3D, collapsing jagged dimensions. + const Tensor y_reshaped = y.view({y.size(0), -1, y.size(-1)}); + Tensor output_reshaped = output.view(y_reshaped.sizes()); + +#define INVOKE_KERNEL_WITH_DIM(NUM_JAGGED_DIM) \ + { \ + std::vector x_offsets_contig; \ + x_offsets_contig.resize(num_jagged_dim); \ + StackArray x_offset_ptrs; \ + x_offset_ptrs.ndim = num_jagged_dim; \ + for (int d = 0; d < num_jagged_dim; ++d) { \ + x_offsets_contig[d] = x_offsets[d].contiguous(); \ + x_offset_ptrs.vals[d] = \ + x_offsets_contig[d].template data_ptr(); \ + } \ + jagged_dense_elementwise_dense_output_kernel_ \ + <<>>( \ + x_values.packed_accessor32(), \ + x_offset_ptrs, \ + y_reshaped \ + .packed_accessor32(), \ + output_reshaped \ + .packed_accessor32(), \ + jagged_dims_tensor, \ + f, \ + padding_value); \ + } + + JAGGED_TENSOR_DISPATCH_DIMS(); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + +#undef INVOKE_KERNEL_WITH_DIM +} + +#define INVOKE_KERNEL_WITH_DIM(NUM_JAGGED_DIM) \ + { \ + dim3 threads, blocks; \ + StackArray jagged_dims_tensor; \ + std::tie(threads, blocks, jagged_dims_tensor) = \ + check_shape_and_partition_(x_values, x_offsets, y); \ + blocks.x = div_round_up(x_values.size(0), threads.y); \ + std::vector x_offsets_contig; \ + x_offsets_contig.resize(num_jagged_dim); \ + StackArray x_offset_ptrs; \ + x_offset_ptrs.ndim = num_jagged_dim; \ + StackArray x_offset_sizes; \ + x_offset_sizes.ndim = num_jagged_dim; \ + for (int d = 0; d < num_jagged_dim; ++d) { \ + x_offsets_contig[d] = x_offsets[d].contiguous(); \ + x_offset_ptrs.vals[d] = \ + x_offsets_contig[d].template data_ptr(); \ + x_offset_sizes.vals[d] = x_offsets[d].numel(); \ + } \ + jagged_dense_dense_elementwise_jagged_output_kernel_< \ + NUM_JAGGED_DIM, \ + index_t><<>>( \ + x_values.packed_accessor32(), \ + x_offset_ptrs, \ + x_offset_sizes, \ + y_reshaped.packed_accessor32(), \ + y_reshaped.packed_accessor32(), \ + output_values.packed_accessor32(), \ + jagged_dims_tensor, \ + [f] __device__(scalar_t x, scalar_t y, scalar_t /*unused*/) \ + -> scalar_t { return f(x, y); }); \ + } + +template +__global__ +__launch_bounds__(kMaxThreads) void jagged_dense_dense_elementwise_jagged_output_kernel_( + const at::PackedTensorAccessor32 + x_values, + StackArray x_offsets, + StackArray x_offsets_sizes, + const at::PackedTensorAccessor32 y_0, + const at::PackedTensorAccessor32 y_1, + at::PackedTensorAccessor32 + output_values, + StackArray jagged_dims, + F f) { + const int outer_dense_size = y_0.size(0); + const int inner_dense_size = y_0.size(2); + const int nnz = x_values.size(0); + + const int offset_begin = blockIdx.x * blockDim.y + threadIdx.y; + const int offset_stride = gridDim.x * blockDim.y; + for (int offset = offset_begin; offset < nnz; offset += offset_stride) { + int offset_temp = offset; + int jidx = 0; + bool truncated = false; + int dim_prod = 1; +#pragma unroll + for (int d = NUM_JAGGED_DIM - 1; d >= 0; --d) { + // Binary search the first that is bigger than offset + int count = x_offsets_sizes.vals[d] - 1; + int first = 1; + while (count > 0) { + int idx = first; + int step = count / 2; + idx += step; + if (x_offsets.vals[d][idx] <= offset_temp) { + first = ++idx; + count -= step + 1; + } else { + count = step; + } + } + + --first; + int coord = offset_temp - x_offsets.vals[d][first]; + if (coord >= jagged_dims.vals[d]) { + truncated = true; + break; + } + jidx += coord * dim_prod; + dim_prod *= jagged_dims.vals[d]; + offset_temp = first; + } + + if (offset_temp >= outer_dense_size) { + // This can happen when values have more elements than the last element of + // offset + truncated = true; + } + if (!truncated) { + const int oidx = offset_temp; + int iidx; + for (iidx = threadIdx.x; iidx * 2 + 1 < inner_dense_size; + iidx += blockDim.x) { + output_values[offset][2 * iidx] = + f(x_values[offset][2 * iidx], + y_0[oidx][jidx][2 * iidx], + y_1[oidx][jidx][2 * iidx]); + output_values[offset][2 * iidx + 1] = + f(x_values[offset][2 * iidx + 1], + y_0[oidx][jidx][2 * iidx + 1], + y_1[oidx][jidx][2 * iidx + 1]); + } + if (iidx * 2 + 1 == inner_dense_size) { + output_values[offset][2 * iidx] = + f(x_values[offset][2 * iidx], + y_0[oidx][jidx][2 * iidx], + y_1[oidx][jidx][2 * iidx]); + } + } else { + int iidx; + for (iidx = threadIdx.x; iidx * 2 + 1 < inner_dense_size; + iidx += blockDim.x) { + output_values[offset][2 * iidx] = f(x_values[offset][2 * iidx], 0, 0); + output_values[offset][2 * iidx + 1] = + f(x_values[offset][2 * iidx + 1], 0, 0); + } + if (iidx * 2 + 1 == inner_dense_size) { + output_values[offset][2 * iidx] = f(x_values[offset][2 * iidx], 0, 0); + } + } + } +} + +///@addtogroup jagged-tensor-ops-cuda +template +void jagged_dense_elementwise_jagged_output_( + const Tensor& x_values, + const std::vector& x_offsets, + const Tensor& y, + const Tensor& output_values, + F f) { + TENSOR_ON_CUDA_GPU(x_values); + for (auto& x_offset : x_offsets) { + TENSOR_ON_CUDA_GPU(x_offset); + } + + const int num_jagged_dim = y.dim() - 2; + TORCH_CHECK( + x_offsets.size() == static_cast(num_jagged_dim), + "x_offsets.size(), ", + x_offsets.size(), + " != num_jagged_dim, ", + num_jagged_dim); + + if (y.numel() == 0 || x_values.numel() == 0) { + return; + } + + // Canonicalize y to 3D, collapsing jagged dimensions. + const Tensor y_reshaped = y.view({y.size(0), -1, y.size(-1)}); + + JAGGED_TENSOR_DISPATCH_DIMS(); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +#undef INVOKE_KERNEL_WITH_DIM + +template +struct SharedMemory; + +template <> +struct SharedMemory { + __device__ int64_t* getPointer() { + extern __shared__ int64_t s_int64_t[]; + return s_int64_t; + } +}; + +template <> +struct SharedMemory { + __device__ int32_t* getPointer() { + extern __shared__ int32_t s_int32_t[]; + return s_int32_t; + } +}; + +template +__global__ void jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_( + const at::PackedTensorAccessor32 offsets, + at::PackedTensorAccessor32 rows, + at::PackedTensorAccessor32 cols, + int nnz, + int B) { + struct SharedMemory smem; + index_t* offsets_sh = smem.getPointer(); + + for (int i = threadIdx.x; i < B + 1; i += blockDim.x) { + offsets_sh[i] = offsets[i]; + } + __syncthreads(); + int row = threadIdx.x + blockIdx.x * blockDim.x; + if (row >= nnz) + return; + int first = -1; + int count = B - 1; + first = 1; + while (count > 0) { + int idx = first; + int step = count / 2; + idx += step; + if (offsets_sh[idx] <= row) { + first = ++idx; + count -= step + 1; + } else { + count = step; + } + } + --first; + + int dense_row = first; + int offset = offsets_sh[dense_row]; + int dense_col = row - offset; + rows[row] = dense_row; + cols[row] = dense_col; +} + +struct VecType128 { + typedef float4 TType; // Transaction Type + typedef struct __align__(16) { + __half a, b, c, d, w, x, y, z; + } + half8; + + union Data { + half8 val; + TType mask; + } data; + + __device__ VecType128() { + data.mask = make_float4(0.0f, 0.0f, 0.0f, 0.0f); + } +}; + +struct VecType64 { + typedef float2 TType; // Transaction Type + typedef struct __align__(8) { + __half a, b, c, d; + } + half4; + + union Data { + half4 val; + TType mask; + } data; + + __device__ VecType64() { + data.mask = make_float2(0.0f, 0.0f); + } +}; + +struct VecType32 { + typedef float TType; // Transaction Type + + union Data { + __half2 val; + TType mask; + } data; + + __device__ VecType32() { + data.mask = 0.0f; + } +}; + +template +__device__ void f128( + VecType128& v_out, + const VecType128& x, + const VecType128& y0, + const VecType128& y1, + F f) { + v_out.data.val.a = f(x.data.val.a, y0.data.val.a, y1.data.val.a); + v_out.data.val.b = f(x.data.val.b, y0.data.val.b, y1.data.val.b); + v_out.data.val.c = f(x.data.val.c, y0.data.val.c, y1.data.val.c); + v_out.data.val.d = f(x.data.val.d, y0.data.val.d, y1.data.val.d); + v_out.data.val.w = f(x.data.val.w, y0.data.val.w, y1.data.val.w); + v_out.data.val.x = f(x.data.val.x, y0.data.val.x, y1.data.val.x); + v_out.data.val.y = f(x.data.val.y, y0.data.val.y, y1.data.val.y); + v_out.data.val.z = f(x.data.val.z, y0.data.val.z, y1.data.val.z); +} + +template +__device__ void f64( + VecType64& v_out, + const VecType64& x, + const VecType64& y0, + const VecType64& y1, + F f) { + v_out.data.val.a = f(x.data.val.a, y0.data.val.a, y1.data.val.a); + v_out.data.val.b = f(x.data.val.b, y0.data.val.b, y1.data.val.b); + v_out.data.val.c = f(x.data.val.c, y0.data.val.c, y1.data.val.c); + v_out.data.val.d = f(x.data.val.d, y0.data.val.d, y1.data.val.d); +} + +template +__device__ void f32( + VecType32& v_out, + const VecType32& x, + const VecType32& y0, + const VecType32& y1, + F f) { + v_out.data.val = __halves2half2( + f(__low2half(x.data.val), + __low2half(y0.data.val), + __low2half(y1.data.val)), + f(__high2half(x.data.val), + __high2half(y0.data.val), + __high2half(y1.data.val))); +} + +template +__device__ void +fh(__half& v_out, const __half& x, const __half& y0, const __half& y1, F f) { + v_out = f(x, y0, y1); +} + +template +__global__ void jagged_dense_dense_elementwise_jagged_output_opt_gather_kernel_( + at::PackedTensorAccessor32 values, + const at::PackedTensorAccessor32 + x_values, + const at::PackedTensorAccessor32 y0, + const at::PackedTensorAccessor32 y1, + const at::PackedTensorAccessor32 rows, + const at::PackedTensorAccessor32 cols, + const int nnz, + const int E, + F f) { + int values_row = threadIdx.y + blockIdx.y * blockDim.y; + if (values_row >= nnz) + return; + for (int real_row = values_row; real_row < nnz; + real_row += blockDim.y * gridDim.y) { + int dense_row = rows[real_row]; + int dense_col = cols[real_row]; + __half* values_ptr = reinterpret_cast<__half*>(&values[real_row][0]); + const __half* x_ptr = + reinterpret_cast(&x_values[real_row][0]); + const __half* y0_ptr = + reinterpret_cast(&y0[dense_row][dense_col][0]); + const __half* y1_ptr = + reinterpret_cast(&y1[dense_row][dense_col][0]); + if ((dense_col < y0.size(1)) && (dense_row < y0.size(0)) && + (dense_col < y1.size(1)) && (dense_row < y1.size(0)) && + (dense_col >= 0) && (dense_row >= 0)) { + for (int tid = threadIdx.x; tid < E / 8; tid += blockDim.x) { + VecType128 v_x, v_out, v_y0, v_y1; + v_x.data.mask = + (reinterpret_cast(x_ptr))[tid]; + v_y0.data.mask = + (reinterpret_cast(y0_ptr))[tid]; + v_y1.data.mask = + (reinterpret_cast(y1_ptr))[tid]; + f128(v_out, v_x, v_y0, v_y1, f); + (reinterpret_cast(values_ptr))[tid] = + v_out.data.mask; + } + for (int tid = threadIdx.x + (E / 8) * 8; tid < E / 4; + tid += blockDim.x) { + VecType64 v_x, v_out, v_y0, v_y1; + v_x.data.mask = (reinterpret_cast(x_ptr))[tid]; + v_y0.data.mask = + (reinterpret_cast(y0_ptr))[tid]; + v_y1.data.mask = + (reinterpret_cast(y1_ptr))[tid]; + f64(v_out, v_x, v_y0, v_y1, f); + (reinterpret_cast(values_ptr))[tid] = + v_out.data.mask; + } + for (int tid = threadIdx.x + (E / 4) * 4; tid < E / 2; + tid += blockDim.x) { + VecType32 v_x, v_out, v_y0, v_y1; + v_x.data.mask = (reinterpret_cast(x_ptr))[tid]; + v_y0.data.mask = + (reinterpret_cast(y0_ptr))[tid]; + v_y1.data.mask = + (reinterpret_cast(y1_ptr))[tid]; + f32(v_out, v_x, v_y0, v_y1, f); + (reinterpret_cast(values_ptr))[tid] = + v_out.data.mask; + } + for (int tid = threadIdx.x + (E / 2) * 2; tid < E; tid += blockDim.x) { + __half v_x, v_out, v_y0, v_y1; + v_x = static_cast<__half>(x_ptr[tid]); + v_y0 = static_cast<__half>(y0_ptr[tid]); + v_y1 = static_cast<__half>(y1_ptr[tid]); + fh(v_out, v_x, v_y0, v_y1, f); + values_ptr[tid] = v_out; + } + } else { + for (int tid = threadIdx.x; tid < E / 8; tid += blockDim.x) { + VecType128 v_x, v_out, v_y0, v_y1; + v_x.data.mask = + (reinterpret_cast(x_ptr))[tid]; + f128(v_out, v_x, v_y0, v_y1, f); + (reinterpret_cast(values_ptr))[tid] = + v_out.data.mask; + } + for (int tid = threadIdx.x + (E / 8) * 8; tid < E / 4; + tid += blockDim.x) { + VecType64 v_x, v_out, v_y0, v_y1; + v_x.data.mask = (reinterpret_cast(x_ptr))[tid]; + f64(v_out, v_x, v_y0, v_y1, f); + (reinterpret_cast(values_ptr))[tid] = + v_out.data.mask; + } + for (int tid = threadIdx.x + (E / 4) * 4; tid < E / 2; + tid += blockDim.x) { + VecType32 v_x, v_out, v_y0, v_y1; + v_x.data.mask = (reinterpret_cast(x_ptr))[tid]; + f32(v_out, v_x, v_y0, v_y1, f); + (reinterpret_cast(values_ptr))[tid] = + v_out.data.mask; + } + for (int tid = threadIdx.x + (E / 2) * 2; tid < E; tid += blockDim.x) { + __half v_x, v_out, v_y0, v_y1; + v_x = static_cast<__half>(x_ptr[tid]); + fh(v_out, v_x, v_y0, v_y1, f); + values_ptr[tid] = v_out; + } + } + } +} + +// Check to see if the inputs to the op are amenable to the fast path +inline bool jagged_dense_dense_elementwise_jagged_output_matches_opt( + const int& num_jagged_dim, + const Tensor& x_values, + const std::vector& x_offsets, + const Tensor& y_0_reshaped, + const Tensor& y_1_reshaped, + const Tensor& output_values) { + bool matches = true; + matches &= (num_jagged_dim == 1); + + // Unit stride embedding dim + matches &= (x_values.stride(-1) == 1); + matches &= (output_values.stride(-1) == 1); + matches &= (y_0_reshaped.stride(-1) == 1); + matches &= (y_1_reshaped.stride(-1) == 1); + + // Each row is aligned to 128-bit + matches &= (x_values.stride(-2) % 8 == 0); + matches &= (output_values.stride(-2) % 8 == 0); + matches &= (y_0_reshaped.stride(-2) % 8 == 0); + matches &= (y_1_reshaped.stride(-2) % 8 == 0); + + // Base addresses aligned to 128-bit + matches &= (reinterpret_cast(x_values.data_ptr()) % 16 == 0); + matches &= (reinterpret_cast(output_values.data_ptr()) % 16 == 0); + matches &= (reinterpret_cast(y_0_reshaped.data_ptr()) % 16 == 0); + matches &= (reinterpret_cast(y_1_reshaped.data_ptr()) % 16 == 0); + + // Rows and col fit into int32_t + matches &= (y_0_reshaped.size(0) < INT_MAX); + matches &= (y_0_reshaped.size(1) < INT_MAX); + + int max_shared_bytes; +#ifndef USE_ROCM + C10_CUDA_CHECK(cudaDeviceGetAttribute( + &max_shared_bytes, + cudaDevAttrMaxSharedMemoryPerBlockOptin, + y_0_reshaped.get_device())); +#else + // MI100 has 64 KB local memory (shared memory) per workgroup + max_shared_bytes = 64 << 10; +#endif + int shared_kb = max_shared_bytes >> 10; +#ifndef USE_ROCM + // Use 2/3 of the available GPU shared mem; leave rooms for L1$. + int used_shared_kb = round_down(shared_kb * 2 / 3, 16); + TORCH_CHECK(used_shared_kb > 0); +#else + // MI100 has independent shared mem and L1 + int used_shared_kb = shared_kb; +#endif + int used_shared_bytes = used_shared_kb << 10; + AT_DISPATCH_INDEX_TYPES( + x_offsets[0].scalar_type(), "check_shared_memory", [&] { + auto B = y_0_reshaped.size(0); + // the default shared memory on V100/A100/H100 is 48 KB from + // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#shared-memory-8-x + if ((B + 1) * sizeof(index_t) >= used_shared_bytes) { + matches = false; + } + }); + return matches; +} + +#define INVOKE_KERNEL_WITH_DIM(NUM_JAGGED_DIM) \ + { \ + dim3 threads, blocks; \ + StackArray jagged_dims_tensor; \ + std::tie(threads, blocks, jagged_dims_tensor) = \ + check_shape_and_partition_(x_values, x_offsets, y); \ + blocks.x = div_round_up(x_values.size(0), threads.y); \ + std::vector x_offsets_contig; \ + x_offsets_contig.resize(num_jagged_dim); \ + StackArray x_offset_ptrs; \ + x_offset_ptrs.ndim = num_jagged_dim; \ + StackArray x_offset_sizes; \ + x_offset_sizes.ndim = num_jagged_dim; \ + for (int d = 0; d < num_jagged_dim; ++d) { \ + x_offsets_contig[d] = x_offsets[d].contiguous(); \ + x_offset_ptrs.vals[d] = \ + x_offsets_contig[d].template data_ptr(); \ + x_offset_sizes.vals[d] = x_offsets[d].numel(); \ + } \ + jagged_dense_dense_elementwise_jagged_output_kernel_< \ + NUM_JAGGED_DIM, \ + index_t><<>>( \ + x_values.packed_accessor32(), \ + x_offset_ptrs, \ + x_offset_sizes, \ + y_reshaped.packed_accessor32(), \ + y_reshaped.packed_accessor32(), \ + output_values.packed_accessor32(), \ + jagged_dims_tensor, \ + [f] __device__(scalar_t x, scalar_t y, scalar_t /*unused*/) \ + -> scalar_t { return f(x, y); }); \ + } + +inline int calc_used_shared_bytes(const int device) { + int max_shared_bytes; +#ifndef USE_ROCM + C10_CUDA_CHECK(cudaDeviceGetAttribute( + &max_shared_bytes, + cudaDevAttrMaxSharedMemoryPerBlockOptin, + device)); +#else + // MI100 has 64 KB local memory (shared memory) per workgroup + max_shared_bytes = 64 << 10; +#endif + int shared_kb = max_shared_bytes >> 10; +#ifndef USE_ROCM + // Use 2/3 of the available GPU shared mem; leave rooms for L1$. + int used_shared_kb = round_down(shared_kb * 2 / 3, 16); + TORCH_CHECK(used_shared_kb > 0); +#else + // MI100 has independent shared mem and L1 + int used_shared_kb = shared_kb; +#endif + int used_shared_bytes = used_shared_kb << 10; + return used_shared_bytes; +} + +template +inline void set_max_dynamic_shared_mem_size_for_opt_search_kernel(const int used_shared_bytes) { +#ifndef USE_ROCM + C10_CUDA_CHECK(cudaFuncSetAttribute( + jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_< + index_t>, + cudaFuncAttributeMaxDynamicSharedMemorySize, + used_shared_bytes)); // V100: 64 KB; A100: 96 KB; H100: 144 KB +#endif +} + +///@addtogroup jagged-tensor-ops-cuda +template +void jagged_dense_elementwise_jagged_output_opt_( + const Tensor& x_values, + const std::vector& x_offsets, + const Tensor& y, + const Tensor& output_values, + F f) { + TENSOR_ON_CUDA_GPU(x_values); + for (auto& x_offset : x_offsets) { + TENSOR_ON_CUDA_GPU(x_offset); + } + + const int num_jagged_dim = y.dim() - 2; + TORCH_CHECK( + x_offsets.size() == static_cast(num_jagged_dim), + "x_offsets.size(), ", + x_offsets.size(), + " != num_jagged_dim, ", + num_jagged_dim); + + if (y.numel() == 0 || x_values.numel() == 0) { + return; + } + + // Canonicalize y to 3D, collapsing jagged dimensions. + const Tensor y_reshaped = y.view({y.size(0), -1, y.size(-1)}); + if (jagged_dense_dense_elementwise_jagged_output_matches_opt( + num_jagged_dim, + x_values, + x_offsets, + y_reshaped, + y_reshaped, + output_values)) { + AT_DISPATCH_INDEX_TYPES( + x_offsets[0].scalar_type(), "jagged_indices_fast_path", [=] { + auto nnz = output_values.size(0); + auto B = y_reshaped.size(0); + auto E = y_reshaped.size(2); + Tensor t_rows_after_bs = at::empty( + {nnz}, + at::TensorOptions().dtype(at::kInt).device( + at::kCUDA, at::cuda::current_device())); + Tensor t_cols_after_bs = at::empty( + {nnz}, + at::TensorOptions().dtype(at::kInt).device( + at::kCUDA, at::cuda::current_device())); + + // Binary search + size_t dynamic_smem_size = (B + 1) * sizeof(index_t); + auto cur_max_shared_bytes = + at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock; + if (dynamic_smem_size > cur_max_shared_bytes) { + int used_shared_bytes = calc_used_shared_bytes(y_reshaped.get_device()); + set_max_dynamic_shared_mem_size_for_opt_search_kernel(used_shared_bytes); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + TORCH_CHECK(dynamic_smem_size <= used_shared_bytes); + } + dim3 threads_bs = dim3(1024, 1, 1); + dim3 blocks_bs = dim3(div_round_up(nnz, threads_bs.x), 1, 1); + jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_< + index_t> + <<>>( + x_offsets[0] + .packed_accessor32(), + t_rows_after_bs + .packed_accessor32(), + t_cols_after_bs + .packed_accessor32(), + nnz, + B); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + // Gather kernel + dim3 threads = dim3(16, 16, 1); + dim3 blocks = dim3(1, div_round_up(nnz, threads.y), 1); + if (blocks.y > 65535) { + blocks.y = 65535; + } + jagged_dense_dense_elementwise_jagged_output_opt_gather_kernel_< + index_t> + <<>>( + output_values + .packed_accessor32(), + x_values + .packed_accessor32(), + y_reshaped + .packed_accessor32(), + y_reshaped + .packed_accessor32(), + t_rows_after_bs + .packed_accessor32(), + t_cols_after_bs + .packed_accessor32(), + nnz, + E, + [f] __device__(__half x, __half y0, __half) -> __half { + // NB: added the static_casts here + return static_cast<__half>( + f(static_cast(x), static_cast(y0)) + ); + }); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); // AT_DISPATCH + } else { + JAGGED_TENSOR_DISPATCH_DIMS(); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } +} + +at::Tensor _fbgemm_jagged_to_padded_dense_forward( + const Tensor& values, + TensorList offsets, + c10::IntArrayRef max_lengths, + const double padding_value) { + const size_t num_jagged_dim = offsets.size(); + TORCH_CHECK( + max_lengths.size() == num_jagged_dim, + "max_lengths.size(), ", + max_lengths.size(), + " != num_jagged_dim, ", + num_jagged_dim); + at::cuda::OptionalCUDAGuard device_guard; + device_guard.set_index(values.get_device()); + + const Tensor values_canonicalized = values.view( + {values.size(0), + std::accumulate( + values.sizes().begin() + 1, + values.sizes().end(), + 1, + std::multiplies())}); + at::SymDimVector padded_values_shape({at::SymInt(offsets[0].size(0) - 1)}); + padded_values_shape.insert( + padded_values_shape.end(), max_lengths.begin(), max_lengths.end()); + + // Canonicalize padded_values by unsqueeze the last dim if the inner dense + // dimension is 1 and folded. + const bool D_folded = values.dim() == 1; + if (!D_folded) { + padded_values_shape.push_back(values.size(-1)); + } + Tensor padded_values = + at::empty_symint(padded_values_shape, values.options()); + Tensor padded_values_view = + D_folded ? padded_values.unsqueeze(-1) : padded_values; + + AT_DISPATCH_ALL_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + values.scalar_type(), + "jagged_to_padded_dense", + [&] { + jagged_dense_elementwise_dense_output_( + values_canonicalized, + offsets.vec(), + padded_values_view, // dummy not used in the lambda function + padded_values_view, + [] __device__(scalar_t x, scalar_t /*unused*/) -> scalar_t { + return x; + }, + static_cast(padding_value)); + }); + + return padded_values; +} + +#define DISPATCH_DENSE_TO_JAGGED_CASE(TYPE) \ + AT_DISPATCH_CASE(TYPE, [&] { \ + jagged_dense_elementwise_jagged_output_opt_( \ + values, \ + offsets.vec(), \ + dense, \ + output, \ + [] __device__(scalar_t /*unused*/, scalar_t y) -> scalar_t { \ + return y; \ + }); \ + }) + +Tensor _fbgemm_dense_to_jagged_forward_symint( + const Tensor& dense, + TensorList offsets, + c10::optional total_L) { + // D is the embedding dimension + auto D = dense.size(-1); + + // If total_L is not given then compute it + at::SymInt total_L_computed; + if (total_L.has_value()) { + total_L_computed = total_L.value(); + } else { + total_L_computed = (int64_t)offsets.back().max().item(); + } + auto values = at::empty_symint({total_L_computed, D}, dense.options()); + auto output = at::empty_like(values); + + at::cuda::OptionalCUDAGuard device_guard; + device_guard.set_index(dense.get_device()); + + // clang-format off + AT_DISPATCH_SWITCH( + values.scalar_type(), + "dense_to_jagged_gpu_op_forward", + DISPATCH_DENSE_TO_JAGGED_CASE(at::ScalarType::Half) + // NB: removed this to build + // DISPATCH_DENSE_TO_JAGGED_CASE(at::ScalarType::Int) + AT_DISPATCH_CASE_FLOATING_TYPES_AND2( + at::ScalarType::Long, + at::ScalarType::BFloat16, + [&] { + jagged_dense_elementwise_jagged_output_( + values, + offsets.vec(), + dense, + output, + [] __device__(scalar_t /*unused*/, scalar_t y) -> scalar_t { + return y; + }); // device lambda + } // lambda + ) // CASE_FLOATING_TYPES_AND + ); // SWITCH + // clang-format on + +#undef DISPATCH_DENSE_TO_JAGGED_CASE + + return output; +} + } // namespace native } // namespace at diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index f9bf58c5f474..ad9cf07d7550 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -388,6 +388,7 @@ aten::_int_mm aten::_int_mm.out aten::_is_all_true aten::_is_any_true +aten::_jagged_to_padded_dense_forward aten::_lazy_clone aten::_linalg_check_errors aten::_linalg_det @@ -477,6 +478,7 @@ aten::_nnpack_spatial_convolution.out aten::_nnz aten::_pack_padded_sequence aten::_pack_padded_sequence.out +aten::_padded_dense_to_jagged_forward aten::_pdist_backward aten::_pdist_backward.out aten::_pdist_forward diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 17b2bf5a8393..2382cb40b522 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -4551,6 +4551,31 @@ def forward(self, query, value, offsets): self.assertTrue(torch.allclose(attn_output_eager, attn_output)) self.assertTrue(torch.allclose(value_grad, value.grad)) + @dtypes(torch.float64, torch.float32, torch.half) + @onlyCUDA + def test_fbgemm_jagged_to_padded_dense_kernels(self, device, dtype): + values = torch.randn(10, 5, device=device, dtype=dtype) + offsets = torch.tensor([0, 1, 3, 8, 10], device=device, dtype=torch.int64) + max_length = offsets.diff().max().item() + padding_value = 1.3 + + # convert jagged -> padded dense + padded = torch.ops.aten._jagged_to_padded_dense_forward( + values, [offsets], [max_length], padding_value + ) + + batch_size = offsets.shape[0] - 1 + expected_padded_shape = (batch_size, max_length, values.shape[-1]) + self.assertEqual(padded.shape, expected_padded_shape) + + # convert padded dense -> jagged + total_L = values.shape[0] + output_jagged = torch.ops.aten._padded_dense_to_jagged_forward( + padded, [offsets], total_L + ) + + # should be equivalent to the original values + self.assertEqual(values, output_jagged) instantiate_parametrized_tests(TestNestedTensor) instantiate_device_type_tests(TestNestedTensorDeviceType, globals()) From dbf39a6e6323419c95d989949cefc82b5c2e027c Mon Sep 17 00:00:00 2001 From: "haozhe.zhu" Date: Sun, 2 Jun 2024 19:25:42 +0800 Subject: [PATCH 288/706] [inductor] fix linear_add_bias path (#127597) Previous the `linear_add_bias` path do not work. This PR is to fix it and add more ut with it. **TestPlan** ``` python test/inductor/test_mkldnn_pattern_matcher.py -k test_linear_add_bias ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/127597 Approved by: https://github.com/jgong5, https://github.com/jansel --- test/inductor/test_mkldnn_pattern_matcher.py | 33 ++++++++++++++++++++ torch/_inductor/fx_passes/mkldnn_fusion.py | 12 +++++-- 2 files changed, 43 insertions(+), 2 deletions(-) diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py index 94fe34c64e53..8932fcfc4afd 100644 --- a/test/inductor/test_mkldnn_pattern_matcher.py +++ b/test/inductor/test_mkldnn_pattern_matcher.py @@ -396,6 +396,39 @@ def forward(self, x): matcher_nodes = 1 self._test_common(mod, (v,), matcher_count, matcher_nodes) + def test_linear_add_bias(self): + class M(torch.nn.Module): + def __init__(self, dtype, unary_fn): + super().__init__() + self.linear = torch.nn.Linear(10, 64, bias=False) + self.bias = torch.randn(64).to(dtype=dtype) + self.unary_fn = unary_fn + + def forward(self, x): + x = self.linear(x) + self.bias + return self.unary_fn(x) + + dtypes = [] + if torch.ops.mkldnn._is_mkldnn_bf16_supported(): + dtypes.append(torch.bfloat16) + if torch.ops.mkldnn._is_mkldnn_fp16_supported(): + dtypes.append(torch.float16) + options = itertools.product(unary_list, dtypes) + for unary_fn, dtype in options: + metrics.reset() + mod = M(dtype, unary_fn).eval() + v = torch.randn(2, 10) + matcher_count = 3 + # Add 1 for weight packing pass, add 2 for bias folding pass. + matcher_nodes = unary_list[unary_fn] + 3 + if self._check_unary_is_decomposed(unary_fn): + # Has extra dtype conversion nodes for autocast. + matcher_nodes += 2 + self._test_common( + mod, (v,), matcher_count, matcher_nodes, check_autocast=dtype + ) + self.assertEqual(metrics.generated_kernel_count, 1) + @skipIfNoDynamoSupport @skipIfNoONEDNN @skipIfRocm diff --git a/torch/_inductor/fx_passes/mkldnn_fusion.py b/torch/_inductor/fx_passes/mkldnn_fusion.py index 5d1a723fa58a..be73a09ca648 100644 --- a/torch/_inductor/fx_passes/mkldnn_fusion.py +++ b/torch/_inductor/fx_passes/mkldnn_fusion.py @@ -788,14 +788,22 @@ def get_val(val): def is_linear_add_bias(match): add_node = match.output_node() linear_node = add_node.args[0] - weight_meta = linear_node.args[1].meta.get("val") + packed_weight_node = linear_node.args[1] + assert packed_weight_node.name == "_reorder_linear_weight" + transpose_weight_node = packed_weight_node.args[0] + assert transpose_weight_node.name == "permute_default" + weight_meta = transpose_weight_node.args[0].meta.get("val") + bias_node = add_node.args[1] + if isinstance(bias_node, int): + # we only folding bias if it is a constant + return False bias_meta = add_node.args[1].meta.get("val") if weight_meta is None or bias_meta is None: return False return ( linear_node.args[2] is None and bias_meta.dim() == 1 - and bias_meta.size(0) == weight_meta.size(0) + and bias_meta.size(0) == weight_meta.size(1) ) # convert linear+bias to a single linear for applying fusion path. From 05fa05cbae55bcae330c01d5e26cf86d05e307e6 Mon Sep 17 00:00:00 2001 From: cyy Date: Tue, 4 Jun 2024 00:49:01 +0000 Subject: [PATCH 289/706] [2/N] Change static functions in headers to inline (#127764) Follows #127727 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127764 Approved by: https://github.com/Skylion007 --- c10/util/StringUtil.h | 2 +- c10/util/TypeSafeSignMath.h | 20 ++++++++----------- torch/csrc/Exceptions.h | 2 +- torch/csrc/Storage.h | 4 ++-- torch/csrc/autograd/python_variable.h | 4 ++-- .../csrc/autograd/python_variable_indexing.h | 2 +- torch/csrc/cuda/nccl.h | 2 +- torch/csrc/utils/cuda_enabled.h | 8 +++----- torch/csrc/utils/device_lazy_init.h | 6 +++--- torch/csrc/utils/python_strings.h | 3 +-- 10 files changed, 23 insertions(+), 30 deletions(-) diff --git a/c10/util/StringUtil.h b/c10/util/StringUtil.h index 157a4f4be28d..88a91c84ef0f 100644 --- a/c10/util/StringUtil.h +++ b/c10/util/StringUtil.h @@ -142,7 +142,7 @@ struct C10_API SourceLocation { std::ostream& operator<<(std::ostream& out, const SourceLocation& loc); // unix isprint but insensitive to locale -inline static bool isPrint(char s) { +inline bool isPrint(char s) { return s > 0x1f && s < 0x7f; } diff --git a/c10/util/TypeSafeSignMath.h b/c10/util/TypeSafeSignMath.h index 7eb6d61c122e..2853ff48d183 100644 --- a/c10/util/TypeSafeSignMath.h +++ b/c10/util/TypeSafeSignMath.h @@ -16,7 +16,7 @@ namespace c10 { /// Returns false since we cannot have x < 0 if x is unsigned. template -static inline constexpr bool is_negative( +inline constexpr bool is_negative( const T& /*x*/, std::true_type /*is_unsigned*/) { return false; @@ -24,9 +24,7 @@ static inline constexpr bool is_negative( /// Returns true if a signed variable x < 0 template -static inline constexpr bool is_negative( - const T& x, - std::false_type /*is_unsigned*/) { +inline constexpr bool is_negative(const T& x, std::false_type /*is_unsigned*/) { return x < T(0); } @@ -42,15 +40,13 @@ inline constexpr bool is_negative(const T& x) { /// Returns the sign of an unsigned variable x as 0, 1 template -static inline constexpr int signum(const T& x, std::true_type /*is_unsigned*/) { +inline constexpr int signum(const T& x, std::true_type /*is_unsigned*/) { return T(0) < x; } /// Returns the sign of a signed variable x as -1, 0, 1 template -static inline constexpr int signum( - const T& x, - std::false_type /*is_unsigned*/) { +inline constexpr int signum(const T& x, std::false_type /*is_unsigned*/) { return (T(0) < x) - (x < T(0)); } @@ -92,7 +88,7 @@ inline constexpr bool greater_than_max(const T& x) { /// Returns true if x < lowest(Limit). Standard comparison template -static inline constexpr bool less_than_lowest( +inline constexpr bool less_than_lowest( const T& x, std::false_type /*limit_is_unsigned*/, std::false_type /*x_is_unsigned*/) { @@ -102,7 +98,7 @@ static inline constexpr bool less_than_lowest( /// Returns false since all the limit is signed and therefore includes /// negative values but x cannot be negative because it is unsigned template -static inline constexpr bool less_than_lowest( +inline constexpr bool less_than_lowest( const T& /*x*/, std::false_type /*limit_is_unsigned*/, std::true_type /*x_is_unsigned*/) { @@ -112,7 +108,7 @@ static inline constexpr bool less_than_lowest( /// Returns true if x < 0, where 0 is constructed from T. /// Limit is not signed, so its lower value is zero template -static inline constexpr bool less_than_lowest( +inline constexpr bool less_than_lowest( const T& x, std::true_type /*limit_is_unsigned*/, std::false_type /*x_is_unsigned*/) { @@ -121,7 +117,7 @@ static inline constexpr bool less_than_lowest( /// Returns false sign both types are unsigned template -static inline constexpr bool less_than_lowest( +inline constexpr bool less_than_lowest( const T& /*x*/, std::true_type /*limit_is_unsigned*/, std::true_type /*x_is_unsigned*/) { diff --git a/torch/csrc/Exceptions.h b/torch/csrc/Exceptions.h index 4f8d614e16dc..e4779ff984bc 100644 --- a/torch/csrc/Exceptions.h +++ b/torch/csrc/Exceptions.h @@ -19,7 +19,7 @@ #include #endif -static inline void PyErr_SetString(PyObject* type, const std::string& message) { +inline void PyErr_SetString(PyObject* type, const std::string& message) { PyErr_SetString(type, message.c_str()); } /// NOTE [ Conversion Cpp Python Warning ] diff --git a/torch/csrc/Storage.h b/torch/csrc/Storage.h index 16bf87bbcc2e..55deb18892bb 100644 --- a/torch/csrc/Storage.h +++ b/torch/csrc/Storage.h @@ -23,11 +23,11 @@ TORCH_PYTHON_API PyObject* THPStorage_NewWithStorage( bool allow_preexisting_pyobj = false); extern PyTypeObject* THPStorageClass; -static inline bool THPStorage_CheckTypeExact(PyTypeObject* tp) { +inline bool THPStorage_CheckTypeExact(PyTypeObject* tp) { return tp == THPStorageClass; } -static inline bool THPStorage_CheckExact(PyObject* obj) { +inline bool THPStorage_CheckExact(PyObject* obj) { return THPStorage_CheckTypeExact(Py_TYPE(obj)); } diff --git a/torch/csrc/autograd/python_variable.h b/torch/csrc/autograd/python_variable.h index d0cb13e9f33e..51ade77f03ec 100644 --- a/torch/csrc/autograd/python_variable.h +++ b/torch/csrc/autograd/python_variable.h @@ -39,7 +39,7 @@ TORCH_PYTHON_API extern PyObject* ParameterClass; bool THPVariable_initModule(PyObject* module); TORCH_PYTHON_API PyObject* THPVariable_Wrap(at::TensorBase var); -static inline bool THPVariable_CheckTypeExact(PyTypeObject* tp) { +inline bool THPVariable_CheckTypeExact(PyTypeObject* tp) { // Check that a python object is a `Tensor`, but not a `Tensor` subclass. // (A subclass could have different semantics.) The one exception is // Parameter, which is used for Python bookkeeping but is equivalent to @@ -49,7 +49,7 @@ static inline bool THPVariable_CheckTypeExact(PyTypeObject* tp) { tp == (PyTypeObject*)ParameterClass); } -static inline bool THPVariable_CheckExact(PyObject* obj) { +inline bool THPVariable_CheckExact(PyObject* obj) { return THPVariable_CheckTypeExact(Py_TYPE(obj)); } diff --git a/torch/csrc/autograd/python_variable_indexing.h b/torch/csrc/autograd/python_variable_indexing.h index a0e35a6e9eff..78c4a546ddbe 100644 --- a/torch/csrc/autograd/python_variable_indexing.h +++ b/torch/csrc/autograd/python_variable_indexing.h @@ -15,7 +15,7 @@ struct UnpackedSlice { }; // This mirrors Cpython's PySlice_Unpack method -static inline UnpackedSlice __PySlice_Unpack(PyObject* _r) { +inline UnpackedSlice __PySlice_Unpack(PyObject* _r) { PySliceObject* r = (PySliceObject*)_r; /* this is harder to get right than you might think */ diff --git a/torch/csrc/cuda/nccl.h b/torch/csrc/cuda/nccl.h index b118bd4600a5..37d1be15cbd7 100644 --- a/torch/csrc/cuda/nccl.h +++ b/torch/csrc/cuda/nccl.h @@ -88,7 +88,7 @@ namespace detail { TORCH_CUDA_CPP_API void throw_nccl_error(ncclResult status); -static inline void NCCL_CHECK(ncclResult status) { +inline void NCCL_CHECK(ncclResult status) { if (status != ncclResult::Success) { throw_nccl_error(status); } diff --git a/torch/csrc/utils/cuda_enabled.h b/torch/csrc/utils/cuda_enabled.h index e27c168a8ef4..0e3c2f30a83e 100644 --- a/torch/csrc/utils/cuda_enabled.h +++ b/torch/csrc/utils/cuda_enabled.h @@ -1,9 +1,8 @@ #pragma once -namespace torch { -namespace utils { +namespace torch::utils { -static inline bool cuda_enabled() { +inline constexpr bool cuda_enabled() { #ifdef USE_CUDA return true; #else @@ -11,5 +10,4 @@ static inline bool cuda_enabled() { #endif } -} // namespace utils -} // namespace torch +} // namespace torch::utils diff --git a/torch/csrc/utils/device_lazy_init.h b/torch/csrc/utils/device_lazy_init.h index 4d736898e535..79c05f3c9ada 100644 --- a/torch/csrc/utils/device_lazy_init.h +++ b/torch/csrc/utils/device_lazy_init.h @@ -26,21 +26,21 @@ namespace torch::utils { void device_lazy_init(at::DeviceType device_type); void set_requires_device_init(at::DeviceType device_type, bool value); -static inline void maybe_initialize_device(at::Device& device) { +inline void maybe_initialize_device(at::Device& device) { // Add more devices here to enable lazy initialization. if (device.is_cuda() || device.is_xpu() || device.is_privateuseone()) { device_lazy_init(device.type()); } } -static inline void maybe_initialize_device(std::optional& device) { +inline void maybe_initialize_device(std::optional& device) { if (!device.has_value()) { return; } maybe_initialize_device(device.value()); } -static inline void maybe_initialize_device(const at::TensorOptions& options) { +inline void maybe_initialize_device(const at::TensorOptions& options) { auto device = options.device(); maybe_initialize_device(device); } diff --git a/torch/csrc/utils/python_strings.h b/torch/csrc/utils/python_strings.h index a2754ef4610b..cca161399c44 100644 --- a/torch/csrc/utils/python_strings.h +++ b/torch/csrc/utils/python_strings.h @@ -100,8 +100,7 @@ inline void THPUtils_internStringInPlace(PyObject** obj) { * */ -// NOLINTNEXTLINE(clang-diagnostic-unused-function) -static py::object PyObject_FastGetAttrString(PyObject* obj, const char* name) { +inline py::object PyObject_FastGetAttrString(PyObject* obj, const char* name) { PyTypeObject* tp = Py_TYPE(obj); PyObject* res = (PyObject*)nullptr; From 41033a42742a4dee8de8aa66aef345e5e24a571b Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 4 Jun 2024 00:59:54 +0000 Subject: [PATCH 290/706] PyPI: fix link to images to be rendered (#127798) It addresses the long pending issues on PyPI. The [package description](https://pypi.org/project/torch/2.3.0/) is the repo's Readme, but compared to GitHub rendering, PyPI accepts only raw images linked via MarkDown images. ![image](https://github.com/pytorch/pytorch/assets/6035284/1d8e51d5-c8c1-4f92-b323-f7684879adb4) This minor link edit makes the image become raw images and so correctly rendered via PyPI Pull Request resolved: https://github.com/pytorch/pytorch/pull/127798 Approved by: https://github.com/albanD --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 2a469af7b166..9a4ba683d769 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -![PyTorch Logo](https://github.com/pytorch/pytorch/blob/main/docs/source/_static/img/pytorch-logo-dark.png) +![PyTorch Logo](https://github.com/pytorch/pytorch/raw/main/docs/source/_static/img/pytorch-logo-dark.png) -------------------------------------------------------------------------------- @@ -98,7 +98,7 @@ from several research papers on this topic, as well as current and past work suc While this technique is not unique to PyTorch, it's one of the fastest implementations of it to date. You get the best of speed and flexibility for your crazy research. -![Dynamic graph](https://github.com/pytorch/pytorch/blob/main/docs/source/_static/img/dynamic_graph.gif) +![Dynamic graph](https://github.com/pytorch/pytorch/raw/main/docs/source/_static/img/dynamic_graph.gif) ### Python First From 1208347d0912d1236ae43257e2914767e35c3b36 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Mon, 3 Jun 2024 13:44:22 -0700 Subject: [PATCH 291/706] [inductor][ez] fix loop ordering test (#127807) I didn't realize that the main block is not being run when inductor tests are being run in FBCode via remote GPUs. This is a quick fix. I've tested it in both OSS and FBCode. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127807 Approved by: https://github.com/eellison, https://github.com/jansel --- test/inductor/test_loop_ordering.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/inductor/test_loop_ordering.py b/test/inductor/test_loop_ordering.py index 856d849b880f..5261c2325834 100644 --- a/test/inductor/test_loop_ordering.py +++ b/test/inductor/test_loop_ordering.py @@ -7,6 +7,9 @@ from torch._inductor.test_case import run_tests, TestCase from torch.testing._internal.inductor_utils import HAS_CUDA +if HAS_CUDA: + torch.set_default_device("cuda") + @inductor_config.patch( { @@ -53,5 +56,4 @@ def f(x, y): if __name__ == "__main__": if HAS_CUDA: - torch.set_default_device("cuda") run_tests() From ddef7c350f9c63898177ef7923f90d89bd6f41e3 Mon Sep 17 00:00:00 2001 From: Nikita Shulga <2453524+malfet@users.noreply.github.com> Date: Tue, 4 Jun 2024 02:06:43 +0000 Subject: [PATCH 292/706] Add comments about runner labels (#127827) To distinguish between org-wide and repo-specific runners as well as highlight where they are hosted (by DevInfra, LF or various partners Delete unused `bm-runner` Pull Request resolved: https://github.com/pytorch/pytorch/pull/127827 Approved by: https://github.com/huydhn --- .github/actionlint.yaml | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/.github/actionlint.yaml b/.github/actionlint.yaml index 679658dafdd8..bb83775a59b2 100644 --- a/.github/actionlint.yaml +++ b/.github/actionlint.yaml @@ -1,9 +1,12 @@ self-hosted-runner: labels: + # GitHub hosted x86 Linux runners - linux.20_04.4x - linux.20_04.16x - - linux.large + # Repo-specific LF hosted ARC runners - linux.large.arc + # Organization-wide AWS Linux Runners + - linux.large - linux.2xlarge - linux.4xlarge - linux.12xlarge @@ -13,16 +16,23 @@ self-hosted-runner: - linux.8xlarge.nvidia.gpu - linux.16xlarge.nvidia.gpu - linux.g5.4xlarge.nvidia.gpu + # Repo-specific IBM hosted S390x runner - linux.s390x + # Organization wide AWS Windows runners - windows.4xlarge.nonephemeral - windows.8xlarge.nvidia.gpu - windows.8xlarge.nvidia.gpu.nonephemeral - windows.g5.4xlarge.nvidia.gpu - - bm-runner + # Organization-wide AMD hosted MI300 runners - linux.rocm.gpu + # Repo-specific Apple hosted runners + - macos-m1-ultra + - macos-m2-14 + # Org wise AWS `mac2.metal` runners (2020 Mac mini hardware powered by Apple silicon M1 processors) - macos-m1-stable - macos-m1-13 - macos-m1-14 + # GitHub-hosted MacOS runners - macos-latest-xlarge - macos-13-xlarge - macos-14-xlarge From 4d0386ce1cfe0559f9a11fd782e42f851759fe79 Mon Sep 17 00:00:00 2001 From: Kiuk Chung Date: Tue, 4 Jun 2024 02:12:14 +0000 Subject: [PATCH 293/706] =?UTF-8?q?[torch/jit-runtime]=20Add=20explicit=20?= =?UTF-8?q?include=20of=20=20to=20torch/jit/run=E2=80=A6=20(#12777?= =?UTF-8?q?9)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Added an explicit include to `` in `jit/runtime/logging.h` since `std::chrono::time_point` is directly referenced in the header. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127779 Approved by: https://github.com/albanD --- torch/csrc/jit/runtime/logging.h | 1 + 1 file changed, 1 insertion(+) diff --git a/torch/csrc/jit/runtime/logging.h b/torch/csrc/jit/runtime/logging.h index b0b67c680883..fda364e0a923 100644 --- a/torch/csrc/jit/runtime/logging.h +++ b/torch/csrc/jit/runtime/logging.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include From c32fe6b279a6aefcae6d52a9d08b8b4ca104b6ed Mon Sep 17 00:00:00 2001 From: mori360 Date: Tue, 4 Jun 2024 03:32:22 +0000 Subject: [PATCH 294/706] [FSDP] keep paras in torch.distributed.checkpoint.state_dict.set_optimizer_state_dict (#127644) This addresses Fixes https://github.com/pytorch/pytorch/issues/126948 The previous code under `_load_optim_state_dict `function with condition of `info.broadcast_from_rank0`, `optim_state_dict` holds the parameters based on `optim`. Changes here aim to synchronize the differential parameters. Unit tests are conducted under `test_state_dict.py` in `test_optim_state_dict_para_matching`, Pull Request resolved: https://github.com/pytorch/pytorch/pull/127644 Approved by: https://github.com/fegin --- .../distributed/checkpoint/test_state_dict.py | 44 +++++++++++++++++++ torch/distributed/checkpoint/state_dict.py | 9 ++++ 2 files changed, 53 insertions(+) diff --git a/test/distributed/checkpoint/test_state_dict.py b/test/distributed/checkpoint/test_state_dict.py index ccd1303c26db..329f8015dc7c 100644 --- a/test/distributed/checkpoint/test_state_dict.py +++ b/test/distributed/checkpoint/test_state_dict.py @@ -19,6 +19,7 @@ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( apply_activation_checkpointing, ) +from torch.distributed.checkpoint import state_dict as ptd_state_dict from torch.distributed.checkpoint.state_dict import ( _patch_model_state_dict, _patch_optimizer_state_dict, @@ -667,6 +668,49 @@ def test_fsdp_root_not_initialized(self) -> None: get_model_state_dict(fsdp_model) get_optimizer_state_dict(fsdp_model, fsdp_optim) + @with_comms + @skip_if_lt_x_gpu(2) + def test_optim_state_dict_param_matching(self) -> None: + # This test verifies parameters between optim and optim_state_dict + # "initial_lr" is added to optim_state_dict, but not to the new optim + # We test whether "initial_lr" appear in optim after + # set_optimizer_state_dict. + device = "cuda" + torch.manual_seed(0) + model = nn.Sequential( + *[nn.Linear(4, 4, device=device, bias=False) for _ in range(2)] + ) + for layer in model: + fully_shard(layer) + fully_shard(model) + optim = torch.optim.Adam(model.parameters(), lr=1e-2) + torch.optim.lr_scheduler.LambdaLR( + optim, lr_lambda=[lambda epoch: 0.95**epoch] + ) + opt_state_dict = ptd_state_dict.get_optimizer_state_dict( + model, + optim, + options=ptd_state_dict.StateDictOptions( + full_state_dict=True, cpu_offload=True + ), + ) + if dist.get_rank() == 0: + self.assertTrue("initial_lr" in opt_state_dict["param_groups"][0]) + + optim = torch.optim.Adam(model.parameters(), lr=1e-2) + self.assertTrue("initial_lr" not in optim.param_groups[0]) + + ptd_state_dict.set_optimizer_state_dict( + model, + optim, + optim_state_dict=opt_state_dict, + options=ptd_state_dict.StateDictOptions( + broadcast_from_rank0=True, full_state_dict=True + ), + ) + if dist.get_rank() == 0: + self.assertTrue("initial_lr" in optim.param_groups[0]) + @with_comms @skip_if_lt_x_gpu(2) def test_flattened_osd(self) -> None: diff --git a/torch/distributed/checkpoint/state_dict.py b/torch/distributed/checkpoint/state_dict.py index 0c4cc32c09a1..46701c3493d5 100644 --- a/torch/distributed/checkpoint/state_dict.py +++ b/torch/distributed/checkpoint/state_dict.py @@ -878,6 +878,15 @@ def _device(t): flatten_osd, osd_mapping = _flatten_state_dict(optim_state_dict) flatten_local_osd, local_osd_mapping = _flatten_state_dict(local_state_dict) _broadcast_state_dict(flatten_osd, flatten_local_osd, device=device) + # The modifications listed seek to address the problem where optim might possess + # dissimilar parameters in comparison to optim_state_dict. This is achieved by + # incorporating differential parameters within local, which may result in optim + # having additional parameters ultimately. + for optim_key in flatten_osd.keys(): + if optim_key not in flatten_local_osd: + assert optim_key in osd_mapping + flatten_local_osd[optim_key] = flatten_osd[optim_key] + local_osd_mapping[optim_key] = osd_mapping[optim_key] optim_state_dict = _unflatten_state_dict( flatten_local_osd, local_osd_mapping ) From 7e906ec9e575abf0c7c431e26f124666998be80c Mon Sep 17 00:00:00 2001 From: Menglu Yu Date: Tue, 4 Jun 2024 03:41:44 +0000 Subject: [PATCH 295/706] [PT2][Optimus] Improve group batch fusion with same parent/users fusion enablement (#127648) Summary: Currently, we fuse the ops in random place, we here enable the same parent/users fuse to enable follow up potential split cat elimination. Context https://docs.google.com/document/d/1MSZY23wKD2keW2Z-DfAI1DscDERHKjOJAnuB5bxa06I/edit Test Plan: # local reproduce ``` buck2 run mode/opt //scripts/jackiexu0313/pt2:local_model_with_pt2 -- --test_mode batch-split --model_type "pm_cmf" --flow_id 559694026 ``` P1386889671 Differential Revision: D58037636 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127648 Approved by: https://github.com/jackiexu1992 --- torch/_inductor/fx_passes/group_batch_fusion.py | 17 +++++++++++++++++ torch/_inductor/fx_passes/pre_grad.py | 5 ++++- torch/_inductor/fx_passes/split_cat.py | 4 ++++ 3 files changed, 25 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/fx_passes/group_batch_fusion.py b/torch/_inductor/fx_passes/group_batch_fusion.py index 92449268bcec..289fe0dbead8 100644 --- a/torch/_inductor/fx_passes/group_batch_fusion.py +++ b/torch/_inductor/fx_passes/group_batch_fusion.py @@ -600,12 +600,17 @@ def match(self, node: torch.fx.Node): input = get_arg_value(node, 0, "input") weight = get_arg_value(node, 1, "weight") bias = get_arg_value(node, 2, "bias") + if self.graph_search_options.get("fuse_nodes_with_same_users", False): + users = [user.target for user in node.users.keys()] + else: + users = "" # type: ignore[assignment] group_key = ( "batch_linear", self._getitem_args(input), str(input.meta["example_value"].shape), str(weight.meta["example_value"].shape), bias is None, + str(users), ) else: group_key = None @@ -683,6 +688,10 @@ def match(self, node: torch.fx.Node): input = get_arg_value(node, 0, "input") weight = get_arg_value(node, 2, "weight") bias = get_arg_value(node, 3, "bias") + if self.graph_search_options.get("fuse_nodes_with_same_users", False): + users = [user.target for user in node.users.keys()] + else: + users = "" # type: ignore[assignment] group_key = ( ( "batch_layernorm", @@ -693,6 +702,7 @@ def match(self, node: torch.fx.Node): str(bias.meta["example_value"].shape) if bias is not None else "", str(get_arg_value(node, 1, "normalized_shape")), str(get_arg_value(node, 4, "eps")), + str(users), ) if "example_value" in input.meta and is_node_meta_valid(weight) @@ -848,11 +858,18 @@ def __init__(self, op, **kwargs): def match(self, node: torch.fx.Node): input = get_arg_value(node, 0, "input") if CallFunctionVarArgs(self.op).match(node) and is_node_meta_valid(node): + if self.graph_search_options.get("fuse_nodes_with_same_parent", False): + # pyre-fixme[16] + parent = node.args[0] + parent = parent.target if parent is not None else "" # type: ignore[union-attr] + else: + parent = "" # for relu op, we also use the inplace to construct the key group_key = ( "batch_" + self.op.__name__.lower().split(".")[0], str(input.meta["example_value"].shape), str(node.kwargs.get("inplace", False)), + str(parent), ) else: group_key = None diff --git a/torch/_inductor/fx_passes/pre_grad.py b/torch/_inductor/fx_passes/pre_grad.py index 9af2440eb80b..1cfa104ea995 100644 --- a/torch/_inductor/fx_passes/pre_grad.py +++ b/torch/_inductor/fx_passes/pre_grad.py @@ -207,7 +207,10 @@ def shape_prop(mod) -> None: inductor_before_change = save_inductor_dict( [pattern_matcher_pass.pass_name] ) - pattern_matcher_pass.apply(gm.graph) # type: ignore[arg-type] + # we support run same pattern multiple times, the default is to run only once + counter = config.pre_grad_fusion_options[pass_name].get("counter", 1) + for _ in range(counter): + pattern_matcher_pass.apply(gm.graph) # type: ignore[arg-type] if not is_same_dict(counters["inductor"], inductor_before_change): optimus_scuba_log[ f"{pattern_matcher_pass.pass_name}_pre_grad" diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py index 563804f2471a..34757b3b5b1e 100644 --- a/torch/_inductor/fx_passes/split_cat.py +++ b/torch/_inductor/fx_passes/split_cat.py @@ -1708,6 +1708,10 @@ def merge_unbind_stack_aten(match: Match, *args, **kwargs): [get_arg_value(select_node, 2, "index") for select_node in select_nodes] ): return + # check the users of parent of select node only from unsqueeze nodes that go to the cat node + # we simply check the number of users of the parent of select node + if len(parent_of_select_node.users.keys()) != len(node.args[0]): # type: ignore[arg-type] + return node.replace_all_uses_with(parent_of_select_node) graph.erase_node(node) for unsqueeze_node in unsqueeze_nodes: From 6580a18f86076e9d1116e7586aa92ee4844f4b71 Mon Sep 17 00:00:00 2001 From: Xilun Wu <12968408+XilunWu@users.noreply.github.com> Date: Mon, 3 Jun 2024 22:43:08 +0000 Subject: [PATCH 296/706] [c10d][BE] fix test_init_pg_and_rpc_with_same_socket (#127654) **Summary** fix `test_init_pg_and_rpc_with_same_socket` in `test/distributed/test_store.py` which missed a call to destroy the created ProcessGroup before exiting test function. It lead to "init PG twice" error in the test. **Test Plan** `pytest test/distributed/test_store.py -s -k test_init_pg_and_rpc_with_same_socket` `ciflow/periodic` since this test is included in `.ci/pytorch/multigpu-test.sh` Pull Request resolved: https://github.com/pytorch/pytorch/pull/127654 Approved by: https://github.com/Skylion007, https://github.com/malfet --- test/distributed/test_store.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/distributed/test_store.py b/test/distributed/test_store.py index 8383101d2093..8de265a30cd8 100644 --- a/test/distributed/test_store.py +++ b/test/distributed/test_store.py @@ -326,6 +326,7 @@ def test_init_pg_and_rpc_with_same_socket(self): ) rpc.shutdown() + dist.destroy_process_group() @skip_if_win32() def test_take_over_listen_socket(self): From 2498ef749001a2f4688929fadb59e7d6a53a8e11 Mon Sep 17 00:00:00 2001 From: Aaron Gokaslan Date: Tue, 4 Jun 2024 04:19:04 +0000 Subject: [PATCH 297/706] Fix scheduler typehints (#127769) Fixes scheduler typehints Pull Request resolved: https://github.com/pytorch/pytorch/pull/127769 Approved by: https://github.com/jansel --- torch/_inductor/scheduler.py | 250 +++++++++++++++++------------------ 1 file changed, 125 insertions(+), 125 deletions(-) diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index a7517575d888..05e05c7d950d 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -63,101 +63,6 @@ fusion_log = torch._logging.getArtifactLogger(__name__, "fusion") -class WhyNoFuse: - # TODO when we drop support for Python < 3.10, we can use - # @dataclass(slots=True) instead of manually specifying __slots__. - __slots__ = ["node1", "node2", "reason", "args"] - reason: str - args: Tuple[Any, ...] - - def __init__(self, node1: "BaseSchedulerNode", node2: "BaseSchedulerNode"): - self.node1 = node1 - self.node2 = node2 - - def __call__(self, reason: str, *args: Any) -> None: - self.reason = reason - self.args = args - fusion_log.debug(self) - - def __str__(self) -> str: - return f"cannot fuse {self.node1.get_name()} with {self.node2.get_name()}: " + ( - self.reason % self.args - ) - - -def pformat(obj: Any) -> str: - if isinstance(obj, set): - # pformat has trouble with sets of sympy exprs - obj = sorted(obj, key=str) - result = pprint.pformat(obj, indent=4) - if "\n" in result: - return f"\n{textwrap.indent(result, ' '*4)}" - return result - - -class OutputNode: - def __init__(self, dep: StarDep) -> None: - self.unmet_dependencies = {dep} - self.inverse_users: List[BaseSchedulerNode] = [] - - def is_reduction(self) -> bool: - return False - - def get_inputs_that_alias_output(self) -> Sequence[str]: - return () - - def get_name(self) -> str: - return "OUTPUT" - - __repr__ = get_name - - -def _prune_redundant_deps( - node: "BaseSchedulerNode", name_to_fused_node: Dict[str, "BaseSchedulerNode"] -) -> None: - """ - Prunes weakdeps intended for mutation ordering - on an upstream fused node if after fusion there is another dependency - on the fused upstream node, making the weakdep redundant - - In essence this enforces an ordering on fusions. As fusions occur, weakdeps will - be incrementally removed, enabling other fusions, ensuring they are fused in order. - """ - name_to_dep_count: Counter[str] = collections.Counter() - - for dep in node.unmet_dependencies: - if not isinstance(dep, WeakDep): - name_to_dep_count[name_to_fused_node[dep.name].get_name()] += 1 - - def should_prune(dep: Dep) -> bool: - if isinstance(dep, WeakDep): - is_redundant = ( - name_to_dep_count[name_to_fused_node[dep.name].get_name()] > 0 - ) - # These can occur because fused nodes always gather deps from their snodes - # If B has a weakdep on A - # B gets fused with C, then any time BC is fused, the weakdep will reappear - is_self_dep = name_to_fused_node[dep.name] == node - return is_redundant or is_self_dep - else: - return False - - deps_to_prune = {dep for dep in node.unmet_dependencies if should_prune(dep)} - - if deps_to_prune: - node.unmet_dependencies = node.unmet_dependencies - deps_to_prune - node.set_read_writes(node.read_writes.remove_reads(deps_to_prune)) - - -# TODO(xmfan): reuse an existing mapping for this if it exists, or formalize this into ir.py:ExternKernel -kernel_name_to_op = { - "extern_kernels.convolution": torch.ops.aten.convolution, - "extern_kernels.mm": torch.ops.aten.mm, - "extern_kernels.bmm": torch.ops.aten.bmm, - "extern_kernels.addmm": torch.ops.aten.addmm, -} - - class BaseSchedulerNode: group: Tuple[torch.device, Tuple[Tuple[sympy.Expr, ...], ...]] read_writes: dependencies.ReadWrites @@ -705,6 +610,101 @@ def get_template_node(self) -> Optional[ir.TemplateBuffer]: return None +class WhyNoFuse: + # TODO when we drop support for Python < 3.10, we can use + # @dataclass(slots=True) instead of manually specifying __slots__. + __slots__ = ["node1", "node2", "reason", "args"] + reason: str + args: Tuple[Any, ...] + + def __init__(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode): + self.node1 = node1 + self.node2 = node2 + + def __call__(self, reason: str, *args: Any) -> None: + self.reason = reason + self.args = args + fusion_log.debug(self) + + def __str__(self) -> str: + return f"cannot fuse {self.node1.get_name()} with {self.node2.get_name()}: " + ( + self.reason % self.args + ) + + +def pformat(obj: Any) -> str: + if isinstance(obj, set): + # pformat has trouble with sets of sympy exprs + obj = sorted(obj, key=str) + result = pprint.pformat(obj, indent=4) + if "\n" in result: + return f"\n{textwrap.indent(result, ' '*4)}" + return result + + +class OutputNode: + def __init__(self, dep: StarDep) -> None: + self.unmet_dependencies = {dep} + self.inverse_users: List[BaseSchedulerNode] = [] + + def is_reduction(self) -> bool: + return False + + def get_inputs_that_alias_output(self) -> Sequence[str]: + return () + + def get_name(self) -> str: + return "OUTPUT" + + __repr__ = get_name + + +def _prune_redundant_deps( + node: BaseSchedulerNode, name_to_fused_node: Dict[str, BaseSchedulerNode] +) -> None: + """ + Prunes weakdeps intended for mutation ordering + on an upstream fused node if after fusion there is another dependency + on the fused upstream node, making the weakdep redundant + + In essence this enforces an ordering on fusions. As fusions occur, weakdeps will + be incrementally removed, enabling other fusions, ensuring they are fused in order. + """ + name_to_dep_count: Counter[str] = collections.Counter() + + for dep in node.unmet_dependencies: + if not isinstance(dep, WeakDep): + name_to_dep_count[name_to_fused_node[dep.name].get_name()] += 1 + + def should_prune(dep: Dep) -> bool: + if isinstance(dep, WeakDep): + is_redundant = ( + name_to_dep_count[name_to_fused_node[dep.name].get_name()] > 0 + ) + # These can occur because fused nodes always gather deps from their snodes + # If B has a weakdep on A + # B gets fused with C, then any time BC is fused, the weakdep will reappear + is_self_dep = name_to_fused_node[dep.name] == node + return is_redundant or is_self_dep + else: + return False + + deps_to_prune = {dep for dep in node.unmet_dependencies if should_prune(dep)} + + if deps_to_prune: + node.unmet_dependencies = node.unmet_dependencies - deps_to_prune + node.set_read_writes(node.read_writes.remove_reads(deps_to_prune)) + + +# TODO(xmfan): reuse an existing mapping for this if it exists, or formalize this into ir.py:ExternKernel +kernel_name_to_op = { + "extern_kernels.convolution": torch.ops.aten.convolution, + "extern_kernels.mm": torch.ops.aten.mm, + "extern_kernels.bmm": torch.ops.aten.bmm, + "extern_kernels.addmm": torch.ops.aten.addmm, +} + + class ExternKernelSchedulerNode(BaseSchedulerNode): def debug_str_extra(self) -> str: return f"{self.get_name()}.node.kernel = {getattr(self.node, 'python_kernel_name', None)}" @@ -721,36 +721,6 @@ class NopKernelSchedulerNode(BaseSchedulerNode): pass -def debug_triton_code(node: Union["SchedulerNode", "FusedSchedulerNode"]) -> List[str]: - lines = [] - multi_template = node.get_template_node() - assert multi_template is None or isinstance(multi_template, ir.MultiTemplateBuffer) - if multi_template and multi_template.make_kernel_render is None: - lines.append(f"{node.get_name()} Unfinalized multi template buffer") - else: - from torch._inductor.codegen.cuda_combined_scheduling import ( - CUDACombinedScheduling, - ) - from torch._inductor.codegen.triton import TritonScheduling - - snodes = (node,) if isinstance(node, SchedulerNode) else node.snodes - device = snodes[0].get_device() - backend = node.scheduler.get_backend(device) - assert isinstance(backend, (TritonScheduling, CUDACombinedScheduling)) - V.graph.scheduler.current_device = device - - # Don't increment kernel count when generating debug string. - # This will confuse some unit tests that check the number of - # generated kernels. - old_generated_kernel_count = metrics.generated_kernel_count - triton_code = backend.generate_kernel_code_from_nodes(snodes).strip() - metrics.generated_kernel_count = old_generated_kernel_count - - lines.append(f"{node.get_name()} Triton code:") - lines.append(textwrap.indent(triton_code, " ")) - return lines - - class SchedulerNode(BaseSchedulerNode): def __init__( self, @@ -2887,3 +2857,33 @@ def get_fusion_pair_priority( The smaller is with higher priority. """ return 0 + + +def debug_triton_code(node: Union[SchedulerNode, FusedSchedulerNode]) -> List[str]: + lines = [] + multi_template = node.get_template_node() + assert multi_template is None or isinstance(multi_template, ir.MultiTemplateBuffer) + if multi_template and multi_template.make_kernel_render is None: + lines.append(f"{node.get_name()} Unfinalized multi template buffer") + else: + from torch._inductor.codegen.cuda_combined_scheduling import ( + CUDACombinedScheduling, + ) + from torch._inductor.codegen.triton import TritonScheduling + + snodes = (node,) if isinstance(node, SchedulerNode) else node.snodes + device = snodes[0].get_device() + backend = node.scheduler.get_backend(device) + assert isinstance(backend, (TritonScheduling, CUDACombinedScheduling)) + V.graph.scheduler.current_device = device + + # Don't increment kernel count when generating debug string. + # This will confuse some unit tests that check the number of + # generated kernels. + old_generated_kernel_count = metrics.generated_kernel_count + triton_code = backend.generate_kernel_code_from_nodes(snodes).strip() + metrics.generated_kernel_count = old_generated_kernel_count + + lines.append(f"{node.get_name()} Triton code:") + lines.append(textwrap.indent(triton_code, " ")) + return lines From 69f5b66132afa9f52e83cbae1ddeb8842f75964e Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 4 Jun 2024 04:22:41 +0000 Subject: [PATCH 298/706] [Inductor] FlexAttention backward kernel optimization (#127208) BWD Speedups (before this PR): ``` | Type | Speedup | shape | score_mod | dtype | |---------|-----------|-------------------|---------------|----------------| | Average | 0.211 | | | | | Max | 0.364 | (16, 16, 512, 64) | relative_bias | torch.bfloat16 | | Min | 0.044 | (2, 16, 4096, 64) | causal_mask | torch.bfloat16 | ``` BWD Speedups (after this PR, though not optimizing block size yet): ``` | Type | Speedup | shape | score_mod | dtype | |---------|-----------|--------------------|---------------|----------------| | Average | 0.484 | | | | | Max | 0.626 | (2, 16, 512, 256) | head_bias | torch.bfloat16 | | Min | 0.355 | (8, 16, 4096, 128) | relative_bias | torch.bfloat16 | ``` There are a few things need to do as follow-ups: * Optimized default block size on A100/H100. * Support different seqlen for Q and K/V. * Support dynamic shapes for backward. * Enhance unit tests to check there is no ```nan``` value in any grad. I think we should make some changes to ```test_padded_dense_causal``` because it has invalid inputs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127208 Approved by: https://github.com/Chillee --- test/inductor/test_flex_attention.py | 9 +- torch/_inductor/kernel/flex_attention.py | 300 +++++++++++++---------- torch/_inductor/select_algorithm.py | 5 +- 3 files changed, 188 insertions(+), 126 deletions(-) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index bc688ab834cb..21e462f75f81 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -31,7 +31,10 @@ # Skip tests if Triton is not available supported_platform = skipUnless( - torch.cuda.is_available() and has_triton() and torch.version.hip is None, + torch.cuda.is_available() + and has_triton() + and torch.version.hip is None + and torch.cuda.get_device_capability() >= (8, 0), "Requires CUDA and Triton", ) @@ -144,6 +147,8 @@ def _check_equal( ): compiled_error = (golden_out - compiled_out).abs().mean() ref_error = (golden_out - ref_out).abs().mean() + if torch.isnan(compiled_error).any() and not torch.isnan(ref_error).any(): + self.assertTrue(False, "Output/Grad with NaN") if compiled_error > ref_error * fudge_factor: name = tensor_name if tensor_name is not None else "" msg = f"{name} Compiled error {compiled_error} is greater than ref error {ref_error} by more than {fudge_factor}X." @@ -195,7 +200,7 @@ def run_test( self._check_equal( k_gold.grad, k_ref.grad, k.grad, k_fudge_factor, "Grad_Key" ) - v_fudge_factor = 8 * fudge_factor + v_fudge_factor = 4 * fudge_factor self._check_equal( v_gold.grad, v_ref.grad, v.grad, v_fudge_factor, "Grad_Value" ) diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index 5a1f45e767a7..3e95dd4f65ce 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -1,7 +1,6 @@ """ Triton Implementation of the flex_attention Kernel""" import logging -import math from enum import auto, Enum from typing import Any, List, Tuple @@ -189,7 +188,7 @@ def build_subgraph_buffer( Z = {{size("Q", 0)}} H = {{size("Q", 1)}} - N_CTX = {{size("Q", 2)}} + Q_LEN = {{size("Q", 2)}} qk_scale = 1.0 MATMUL_PRECISION = Q.dtype.element_ty @@ -200,7 +199,7 @@ def build_subgraph_buffer( qkv_offset = off_hz * stride_qh Q_block_ptr = tl.make_block_ptr( base=Q + qkv_offset, - shape=(N_CTX, BLOCK_DMODEL), + shape=(Q_LEN, BLOCK_DMODEL), strides=(stride_qm, stride_qk), offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_DMODEL), @@ -208,7 +207,7 @@ def build_subgraph_buffer( ) K_block_ptr = tl.make_block_ptr( base=K + qkv_offset, - shape=(BLOCK_DMODEL, N_CTX), + shape=(BLOCK_DMODEL, Q_LEN), strides=(stride_kk, stride_kn), offsets=(0, 0), block_shape=(BLOCK_DMODEL, BLOCK_N), @@ -216,7 +215,7 @@ def build_subgraph_buffer( ) V_block_ptr = tl.make_block_ptr( base=V + qkv_offset, - shape=(N_CTX, BLOCK_DMODEL), + shape=(Q_LEN, BLOCK_DMODEL), strides=(stride_vk, stride_vn), offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_DMODEL), @@ -236,7 +235,7 @@ def build_subgraph_buffer( q = (q * qk_scale).to(MATMUL_PRECISION) # loop over k, v and update accumulator lo = 0 - hi = N_CTX + hi = Q_LEN for start_n in range(lo, hi, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- load k, v -- @@ -299,7 +298,7 @@ def build_subgraph_buffer( # TODO dont want to write this if we dont require grad if OUTPUT_LOGSUMEXP: - l_ptrs = LSE + off_hz * N_CTX + offs_m + l_ptrs = LSE + off_hz * Q_LEN + offs_m lse = m_i + tl.math.log2(l_i) tl.store(l_ptrs, lse) """, @@ -446,13 +445,22 @@ def flex_attention(*args, **kwargs): # ---------------------------- Backward HOP Implementation ---------------------------- -def flex_attention_backward_grid(batch_size, num_heads, num_key_value, d_model, meta): +def flex_attention_backward_grid(batch_size, num_heads, num_queries, d_model, meta): """How is this kernel parallelized? Currently this is only parallelizing over batch * num_heads, but we can, and want to parallelize over ceil_div(num_key_value, key_value_block_size). To do this will either require atomic updates to some grad values or to have a two pass kernel design. """ - return (batch_size * num_heads, 1, 1) + import triton + + # TODO: support different seqlen for Query and Key/Value. + num_key_value = num_queries + return ( + triton.cdiv(num_queries, meta["BLOCK_M2"]) + + triton.cdiv(num_key_value, meta["BLOCK_N1"]), + 1, + batch_size * num_heads, + ) flex_attention_backward_template = TritonTemplate( @@ -470,95 +478,83 @@ def flex_attention_backward_grid(batch_size, num_heads, num_key_value, d_model, # M: Number of queries, N: Number of keys/values, D: Model dimension # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head # (Modifiable) Config options: - # BLOCK_M - # BLOCK_N + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. # SCORE_MOD_IS_LINEAR: Is the score modifier linear? If so, we can lift the # change of base out of the loop - # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row - # is not masked out? If so, we can skip an extra safety check - # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad # Define Q Strides - stride_qz = {{stride("Q", 0)}} - stride_qh = {{stride("Q", 1)}} - stride_qm = {{stride("Q", 2)}} - stride_qk = {{stride("Q", 3)}} - # Define K Strides - stride_kz = {{stride("K", 0)}} - stride_kh = {{stride("K", 1)}} - stride_kn = {{stride("K", 2)}} - stride_kk = {{stride("K", 3)}} - # Define V Strides - stride_vz = {{stride("V", 0)}} - stride_vh = {{stride("V", 1)}} - stride_vn = {{stride("V", 2)}} - stride_vk = {{stride("V", 3)}} + stride_z = {{stride("Q", 0)}} + stride_h = {{stride("Q", 1)}} + stride_tok = {{stride("Q", 2)}} + stride_d = {{stride("Q", 3)}} Z = {{size("Q", 0)}} H = {{size("Q", 1)}} - N_CTX = {{size("Q", 2)}} + Q_LEN = {{size("Q", 2)}} + KV_LEN = {{size("K", 2)}} - qk_scale = 1.0 MATMUL_PRECISION = Q.dtype.element_ty - off_hz = tl.program_id(0) + pid = tl.program_id(0) + NUM_KV_BLOCKS = KV_LEN // BLOCK_N1 + + bhid = tl.program_id(2) + off_chz = (bhid * Q_LEN).to(tl.int64) + adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64) + + off_hz = tl.program_id(2) off_z = off_hz // H # batch idx off_h = off_hz % H # head idx # offset pointers for batch/head - Q += off_z * stride_qz + off_h * stride_qh - K += off_z * stride_kz + off_h * stride_kh - V += off_z * stride_vz + off_h * stride_vh - - # Asserting contiguous for now... - DO += off_z * stride_qz + off_h * stride_qh - DQ += off_z * stride_qz + off_h * stride_qh - DV += off_z * stride_vz + off_h * stride_vh - - # TODO I think that this should be N_CTX/BLOCK_N blocks - for start_n in range(0, NUM_Q_BLOCKS): - # We are not doing the causal optimization yet allowing us to start further down the - # kv column - offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) - offs_m = tl.arange(0, BLOCK_M) - offs_k = tl.arange(0, BLOCK_DMODEL) - - # initialize pointers to value-like data - q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk) - k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) - v_ptrs = V + (offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk) - do_ptrs = DO + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk) - dq_ptrs = DQ + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk) - - # pointer to row-wise quantities in value-like data - D_ptrs = DELTA + off_hz * N_CTX - l_ptrs = LSE + off_hz * N_CTX - - # initialize dv and dk - dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) - dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) - - # Key and Value stay in SRAM throughout - k = tl.load(k_ptrs) - v = tl.load(v_ptrs) - - for start_m in range(0, NUM_Q_BLOCKS * BLOCK_M, BLOCK_M): - offs_m_curr = start_m + offs_m - - # load q, k, v, do on-chip - q = tl.load(q_ptrs) - - if SCORE_MOD_IS_LINEAR: - qk_scale *= 1.44269504 - q = (q * qk_scale).to(MATMUL_PRECISION) - - # -- compute qk --- - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk = tl.dot(q, tl.trans(k.to(MATMUL_PRECISION)), acc=qk) - pre_mod_scores = qk + Q += adj + K += adj + V += adj + DO += adj + DQ += adj + DV += adj + LSE += off_chz + DELTA += off_chz + + offs_k = tl.arange(0, BLOCK_DMODEL) + + if pid >= NUM_KV_BLOCKS: + # THIS BLOCK DOES DQ + off_pid = pid - NUM_KV_BLOCKS + start_m2 = off_pid * BLOCK_M2 + + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + q = tl.load(Q + offs_m2[:, None] * stride_tok + offs_k[None, :] * stride_d) + dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32) + do = tl.load(DO + offs_m2[:, None] * stride_tok + offs_k[None, :] * stride_d) + + lse = tl.load(LSE + offs_m2) + lse = lse[:, None] + + start_n2 = 0 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + offs_n2 = start_n2 + tl.arange(0, BLOCK_N2) + kT_ptrs = K + offs_n2[None, :] * stride_tok + offs_k[:, None] * stride_d + vT_ptrs = V + offs_n2[None, :] * stride_tok + offs_k[:, None] * stride_d + Di = tl.load(DELTA + offs_m2) + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + curr_n = start_n2 + num_steps = KV_LEN // BLOCK_N2 + for blk_idx in range(num_steps): + offs_n2= curr_n + tl.arange(0, BLOCK_N2) + kT = tl.load(kT_ptrs) + vT = tl.load(vT_ptrs) + qk = tl.dot(q, kT) # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ - m = offs_m_curr[:, None] - n = offs_n[None, :] + pre_mod_scores = qk + m = offs_m2[:, None] + n = offs_n2[None, :] {{ modification( subgraph_number=0, output_name="post_mod_scores", @@ -569,25 +565,13 @@ def flex_attention_backward_grid(batch_size, num_heads, num_key_value, d_model, n="n", out="qk" ) | indent_except_first(3) }} - # TODO: In the case that score_mod is linear, this can be LICMed + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ if not SCORE_MOD_IS_LINEAR: post_mod_scores *= 1.44269504 - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - l_i = tl.load(l_ptrs + offs_m_curr) - p = tl.math.exp2(post_mod_scores - l_i[:, None]) - - # compute dv - do = tl.load(do_ptrs) - dv += tl.dot(tl.trans(p.to(MATMUL_PRECISION)), do) - - # compute dp = dot(v, do) - Di = tl.load(D_ptrs + offs_m_curr) # [BLOCKM, 1] - - # compute ds = p * (dp - delta[:, None]) - dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] - dp += tl.dot(do, tl.trans(v)) - ds = p * dp - + p = tl.math.exp2(post_mod_scores - lse).to(MATMUL_PRECISION) + # Compute dP and dS. + dp = tl.dot(do, vT) + ds = p * (dp - Di[:, None]) # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ {{ modification( subgraph_number=1, @@ -601,32 +585,101 @@ def flex_attention_backward_grid(batch_size, num_heads, num_key_value, d_model, ) | indent_except_first(3) }} ds = grad_scores # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # compute dk = dot(ds.T, q) - dk += tl.dot(tl.trans(ds.to(MATMUL_PRECISION)), q) - # compute dq - dq = tl.load(dq_ptrs) - dq += tl.dot(ds.to(MATMUL_PRECISION), k) - - # Store grad_query - tl.store(dq_ptrs, dq) - - # increment pointers - dq_ptrs += BLOCK_M * stride_qm - q_ptrs += BLOCK_M * stride_qm - do_ptrs += BLOCK_M * stride_qm - - # write-back - index_n = offs_n[:, None] - index_k = offs_k[None, :] + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT)) + # Increment pointers. + curr_n += BLOCK_N2 + kT_ptrs += BLOCK_N2 * stride_tok + vT_ptrs += BLOCK_N2 * stride_tok + # Write back dQ. + dq_ptrs = DQ + offs_m2[:, None] * stride_tok + offs_k[None, :] * stride_d + tl.store(dq_ptrs, dq) + else: + # THIS BLOCK DOES DK & DV + start_n1 = pid * BLOCK_N1 + start_m1 = 0 + + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) + + # load K and V: they stay in SRAM throughout the inner loop. + k = tl.load(K + offs_n1[:, None] * stride_tok + offs_k[None, :] * stride_d) + v = tl.load(V + offs_n1[:, None] * stride_tok + offs_k[None, :] * stride_d) + + offs_m1 = start_m1 + tl.arange(0, BLOCK_M1) + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + qT_ptrs = Q + offs_m1[None, :] * stride_tok + offs_k[:, None] * stride_d + do_ptrs = DO + offs_m1[:, None] * stride_tok + offs_k[None, :] * stride_d + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + curr_m = start_m1 + num_steps = Q_LEN // BLOCK_M1 + for blk_idx in range(num_steps): + qT = tl.load(qT_ptrs) + # Load LSE before computing qk to reduce pipeline stall. + offs_m1 = curr_m + tl.arange(0, BLOCK_M1) + lse = tl.load(LSE + offs_m1) + qkT = tl.dot(k, qT) + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = offs_m1[None, :] + n = offs_n1[:, None] + pre_mod_scores = qkT + {{ modification( + subgraph_number=0, + output_name="post_mod_scores", + score="qkT", + b="off_z", + h="off_h", + m="m", + n="n", + out="qkT" + ) | indent_except_first(3) }} + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not SCORE_MOD_IS_LINEAR: + post_mod_scores *= 1.44269504 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = tl.load(do_ptrs) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do) + Di = tl.load(DELTA + offs_m1) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do)) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + m = offs_m1[None, :] + n = offs_n1[:, None] + {{ modification( + subgraph_number=1, + output_name = "grad_scores", + score="pre_mod_scores", + b="off_z", + h="off_h", + m="m", + n="n", + grad_score_mod="dsT" + ) | indent_except_first(3) }} + dsT = grad_scores + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT)) + # Increment pointers. + curr_m += BLOCK_M1 + qT_ptrs += BLOCK_M1 * stride_tok + do_ptrs += BLOCK_M1 * stride_tok - # Store grad_key and grad_value - dv_ptrs = DV + (index_n * stride_vn + index_k * stride_vk) + dv_ptrs = DV + offs_n1[:, None] * stride_tok + offs_k[None, :] * stride_d tl.store(dv_ptrs, dv) + # Write back dK. + index_n = offs_n1[:, None] + index_k = offs_k[None, :] # TODO generalize and add proper mask support mask = (index_n != -1) & (index_k != -1) {{store_output(("off_z", "off_h", "index_n", "index_k"), "dk", "mask", indent_width=8)}} - """, ) @@ -722,10 +775,11 @@ def flex_attention_backward(*args, **kwargs): mutated_inputs=[grad_query, grad_value], num_stages=num_stages, num_warps=num_warps, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, + BLOCK_M1=BLOCK_M, + BLOCK_N1=BLOCK_N, + BLOCK_M2=BLOCK_N, + BLOCK_N2=BLOCK_M, BLOCK_DMODEL=query.get_size()[-1], - NUM_Q_BLOCKS=math.ceil(query.get_size()[-2] / BLOCK_M), # For now, we always assume the "sound" option SCORE_MOD_IS_LINEAR=False, ) diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index d8ca3eefed70..531f3c25a31b 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -310,7 +310,10 @@ def modification( Args: subgraph_number (int): The index of the subgraph in self.subgraphs """ - with self.create_subgraph_body(f"modification_{subgraph_number}"): + num = 0 + while f"mod_{subgraph_number}_{num}" in self.subgraph_bodies: + num += 1 + with self.create_subgraph_body(f"mod_{subgraph_number}_{num}"): assert isinstance(subgraph_number, int) assert isinstance(self.subgraphs, list) assert ( From 22368eac108029e1d513dd957b8af657e3236302 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Mon, 3 Jun 2024 12:02:03 -0700 Subject: [PATCH 299/706] [FSDP2] Fix submesh slicing to enable 3D parallelism (#127585) Ensures the submesh used to create sharded parameters are created on a submesh that excludes the Pipeline Parallelism dimension. Also cleans up the logic for storing placements to no longer consider the outer / global dims. Since we store an 'spmd' submesh, we can avoid this. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127585 Approved by: https://github.com/wanchaol --- .../fsdp/test_fully_shard_training.py | 6 +++ .../_composable/fsdp/_fsdp_param.py | 40 +++++++++---------- torch/distributed/device_mesh.py | 6 ++- 3 files changed, 28 insertions(+), 24 deletions(-) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_training.py b/test/distributed/_composable/fsdp/test_fully_shard_training.py index a7b97f8f7dd3..6634e142312b 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_training.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_training.py @@ -1182,6 +1182,12 @@ def _test_2d_mlp_with_nd_mesh( _optim.step() self.assertEqual(losses[0], losses[1]) + for n, p in model.named_parameters(): + self.assertIsInstance(p, DTensor) + self.assertEqual(p.device_mesh.ndim, 2) + self.assertEqual(len(p.placements), 2) + self.assertEqual(p.device_mesh.mesh_dim_names, ("dp", "tp")) + class TestFullyShardHSDPTraining(FSDPTest): @property diff --git a/torch/distributed/_composable/fsdp/_fsdp_param.py b/torch/distributed/_composable/fsdp/_fsdp_param.py index f0d64aa3e8f1..5fed53f4a11a 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_param.py +++ b/torch/distributed/_composable/fsdp/_fsdp_param.py @@ -127,7 +127,7 @@ class FSDPParam: _sharded_post_forward_param: Optional[nn.Parameter] # ND _unsharded_param: nn.Parameter # ND unsharded_accumulated_grad: Optional[torch.Tensor] # ND - _global_placements: Tuple[Placement, ...] + _spmd_placements: Tuple[Placement, ...] _global_size: torch.Size _global_stride: Tuple[int, ...] all_gather_outputs: List[torch.Tensor] # 1D @@ -199,37 +199,33 @@ def _init_sharded_param(self, param: nn.Parameter, device: torch.device): "FSDP requires the DP and TP mesh to have the same parent mesh but got: \n" f"DP's global mesh: {dp_global_mesh}\nTP's global mesh: {tp_global_mesh}" ) - self._global_mesh = dp_global_mesh + + name_dims_error = "FSDP requires named DeviceMesh dims for ND parallelism" + assert dp_mesh.mesh_dim_names is not None, name_dims_error + assert tp_mesh.mesh_dim_names is not None, name_dims_error + + submesh_names = dp_mesh.mesh_dim_names + tp_mesh.mesh_dim_names + self._spmd_mesh = dp_global_mesh[submesh_names] if len(self._tp_spec.placements) != 1: raise NotImplementedError( f"FSDP only supports 1D TP, not {self._tp_spec.placements}" ) - global_placements: List[Placement] = [Replicate(), Replicate()] - global_dp_mesh_dim = _mesh_resources.get_parent_mesh_dim(dp_mesh) - global_tp_mesh_dim = _mesh_resources.get_parent_mesh_dim(tp_mesh) - assert global_dp_mesh_dim is not None # mypy - assert global_tp_mesh_dim is not None # mypy - # for PP, DP, TP case, dp mesh dim would be 1, tp mesh dim would be 2 - # DP/TP would only live in the inner most 2-3 dims (HSDP + TP would be 3) - dp_tp_mesh_ndim = dp_mesh.ndim + tp_mesh.ndim - outer_mesh_ndim = self._global_mesh.ndim - dp_tp_mesh_ndim - if self._global_mesh.ndim > dp_tp_mesh_ndim: - global_dp_mesh_dim = global_dp_mesh_dim - outer_mesh_ndim - global_tp_mesh_dim = global_tp_mesh_dim - outer_mesh_ndim # TODO: Hard code FSDP + TP; need to support HSDP + TP - global_placements[global_dp_mesh_dim] = Shard(0) - global_placements[global_tp_mesh_dim] = self._tp_spec.placements[0] - self._global_placements = tuple(global_placements) + self._spmd_placements = ( + Shard(0), + self._tp_spec.placements[0], + ) + self._global_size = param.size() self._global_stride = param.stride() param_data = cast(DTensor, param)._local_tensor else: - self._global_mesh = self.mesh_info.mesh + self._spmd_mesh = self.mesh_info.mesh if isinstance(self.mesh_info, HSDPMeshInfo): - self._global_placements = (Replicate(), Shard(0)) + self._spmd_placements = (Replicate(), Shard(0)) else: - self._global_placements = (Shard(0),) + self._spmd_placements = (Shard(0),) self._global_size = param.size() self._global_stride = param.stride() param_data = param @@ -443,8 +439,8 @@ def to_sharded_dtensor(self, tensor: torch.Tensor) -> DTensor: ) return _from_local_no_grad( tensor, - self._global_mesh, - self._global_placements, + self._spmd_mesh, + self._spmd_placements, self._global_size, self._global_stride, ) diff --git a/torch/distributed/device_mesh.py b/torch/distributed/device_mesh.py index a0e7b7acddeb..2913f28f9e36 100644 --- a/torch/distributed/device_mesh.py +++ b/torch/distributed/device_mesh.py @@ -69,7 +69,7 @@ def get_current_mesh(self) -> "DeviceMesh": return self.mesh_stack[-1] def create_child_mesh( - self, parent_mesh: "DeviceMesh", submesh_dim_names: Tuple[str] + self, parent_mesh: "DeviceMesh", submesh_dim_names: Tuple[str, ...] ) -> "DeviceMesh": # submesh_dims are the mesh dimension of the submesh in the parent mesh. submesh_dims = [ @@ -382,7 +382,9 @@ def __eq__(self, other: object) -> bool: and self._thread_id == other._thread_id ) - def __getitem__(self, mesh_dim_names: Union[str, Tuple[str]]) -> "DeviceMesh": + def __getitem__( + self, mesh_dim_names: Union[str, Tuple[str, ...]] + ) -> "DeviceMesh": """ Slice the current DeviceMesh based on the mesh_dim_name given to create a child DeviceMesh. From dae757c97185bcaf0d3d2a28453ef4f3fd681713 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 4 Jun 2024 04:25:39 +0000 Subject: [PATCH 300/706] Specify supported OS matrix (#127816) Windows-10 or newer manylinux-2014 MacOS-11 or newer (but only on Apple Silicon) Fixes https://github.com/pytorch/pytorch/issues/126679 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127816 Approved by: https://github.com/kit1980, https://github.com/huydhn --- RELEASE.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/RELEASE.md b/RELEASE.md index ff8e99883e4e..3c9d68f9a6cd 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -37,6 +37,7 @@ - [TL;DR](#tldr) - [Accelerator Software](#accelerator-software) - [Special support cases](#special-support-cases) + - [Operating Systems](#operating-systems) - [Submitting Tutorials](#submitting-tutorials) - [Special Topics](#special-topics) - [Updating submodules for a release](#updating-submodules-for-a-release) @@ -426,6 +427,15 @@ the size restrictions for publishing on PyPI so the default version that is publ These special support cases will be handled on a case by case basis and support may be continued if current PyTorch maintainers feel as though there may still be a need to support these particular versions of software. +## Operating Systems +Supported OS flavors are summarized in the table below: +| Operating System family | Architectrue | Notes | +| --- | --- | --- | +| Linux | aarch64, x86_64 | Wheels are manylinux2014 compatible, i.e. they should be runnable on any Linux system with glibc-2.17 or above. | +| MacOS | arm64 | Builds should be compatible with MacOS 11 (Big Sur) or newer, but are actively tested against MacOS 14 (Sonoma). | +| MacOS | x86_64 | Requires MacOS Catalina or above, not supported after 2.2, see https://github.com/pytorch/pytorch/issues/114602 | +| Windows | x86_64 | Buils are compatible with Windows-10 or newer. | + # Submitting Tutorials Tutorials in support of a release feature must be submitted to the [pytorch/tutorials](https://github.com/pytorch/tutorials) repo at least two weeks before the release date to allow for editorial and technical review. There is no cherry-pick process for tutorials. All tutorials will be merged around the release day and published at [pytorch.org/tutorials](https://pytorch.org/tutorials/). From e793ae220ffb56d07582a93f1dd43ec085e72a50 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 4 Jun 2024 04:27:23 +0000 Subject: [PATCH 301/706] [Inductor][Flex-attention] Support different sequence lengths for Query and Key/Value (#127678) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/127678 Approved by: https://github.com/Chillee --- test/inductor/test_flex_attention.py | 53 ++++++++++---- torch/_inductor/kernel/flex_attention.py | 93 ++++++++++++++---------- torch/_inductor/select_algorithm.py | 9 ++- torch/nn/attention/_flex_attention.py | 5 -- 4 files changed, 98 insertions(+), 62 deletions(-) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index 21e462f75f81..7601922a685e 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -158,23 +158,33 @@ def run_test( self, score_mod: Callable, dtype: torch.dtype = torch.float16, - B: int = B, - H: int = H, - S: int = S, - D: int = D, + Q_B: int = B, + Q_H: int = H, + Q_S: int = S, + Q_D: int = D, + KV_B: int = B, + KV_H: int = H, + KV_S: int = S, + KV_D: int = D, ): sdpa_partial = create_attention(score_mod) compiled_sdpa = torch.compile(sdpa_partial) - q = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) - k = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) - v = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) + q = torch.randn( + (Q_B, Q_H, Q_S, Q_D), dtype=dtype, device="cuda", requires_grad=True + ) + k = torch.randn( + (KV_B, KV_H, KV_S, KV_D), dtype=dtype, device="cuda", requires_grad=True + ) + v = torch.randn( + (KV_B, KV_H, KV_S, KV_D), dtype=dtype, device="cuda", requires_grad=True + ) q_ref, k_ref, v_ref = query_key_value_clones(q, k, v) q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64) golden_out = sdpa_partial(q_gold, k_gold, v_gold) ref_out = sdpa_partial(q_ref, k_ref, v_ref) compiled_out = compiled_sdpa(q, k, v) - backward_grad = torch.randn((B, H, S, D), dtype=dtype, device="cuda") + backward_grad = torch.randn((Q_B, Q_H, Q_S, Q_D), dtype=dtype, device="cuda") golden_out.backward(backward_grad.to(torch.float64)) ref_out.backward(backward_grad) @@ -348,6 +358,25 @@ def test_builtin_score_mods_automatic_dynamic( ): self.run_automatic_dynamic_test(score_mod, dtype) + @supported_platform + @common_utils.parametrize("dtype", test_dtypes_fast) + @common_utils.parametrize("score_mod", test_score_mods) + def test_builtin_score_mods_different_seqlen( + self, dtype: torch.dtype, score_mod: Callable + ): + self.run_test( + score_mod, + dtype, + B, + H, + S // 2, # Seqlen of Q is different from seqlen of K/V + D, + B, + H, + S, + D, + ) + @supported_platform @common_utils.parametrize("dtype", test_dtypes) def test_skip_odd_keys(self, dtype: torch.dtype): @@ -724,14 +753,6 @@ def test_mixed_dtypes_fails(self): ): _flex_attention(query, key, value, _identity) - @supported_platform - def test_different_sequence_length_fails(self): - query = torch.randn((1, 1, 2048, 64), dtype=torch.float32, device="cuda") - key = torch.randn((1, 1, 1024, 64), dtype=torch.float32, device="cuda") - value = torch.randn((1, 1, 1024, 64), dtype=torch.float32, device="cuda") - with self.assertRaisesRegex(ValueError, "NYI: The target sequence length"): - _flex_attention(query, key, value, _identity) - @supported_platform @patch.object(torch._inductor.config, "max_autotune", True) def test_max_autotune(self): diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index 3e95dd4f65ce..42fabf65591d 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -189,6 +189,7 @@ def build_subgraph_buffer( Z = {{size("Q", 0)}} H = {{size("Q", 1)}} Q_LEN = {{size("Q", 2)}} + KV_LEN = {{size("K", 2)}} qk_scale = 1.0 MATMUL_PRECISION = Q.dtype.element_ty @@ -196,9 +197,10 @@ def build_subgraph_buffer( start_m = tl.program_id(0) off_hz = tl.program_id(1) - qkv_offset = off_hz * stride_qh + q_offset = off_hz * stride_qh + kv_offset = off_hz * stride_kh Q_block_ptr = tl.make_block_ptr( - base=Q + qkv_offset, + base=Q + q_offset, shape=(Q_LEN, BLOCK_DMODEL), strides=(stride_qm, stride_qk), offsets=(start_m * BLOCK_M, 0), @@ -206,16 +208,16 @@ def build_subgraph_buffer( order=(1, 0) ) K_block_ptr = tl.make_block_ptr( - base=K + qkv_offset, - shape=(BLOCK_DMODEL, Q_LEN), + base=K + kv_offset, + shape=(BLOCK_DMODEL, KV_LEN), strides=(stride_kk, stride_kn), offsets=(0, 0), block_shape=(BLOCK_DMODEL, BLOCK_N), order=(0, 1) ) V_block_ptr = tl.make_block_ptr( - base=V + qkv_offset, - shape=(Q_LEN, BLOCK_DMODEL), + base=V + kv_offset, + shape=(KV_LEN, BLOCK_DMODEL), strides=(stride_vk, stride_vn), offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_DMODEL), @@ -235,7 +237,7 @@ def build_subgraph_buffer( q = (q * qk_scale).to(MATMUL_PRECISION) # loop over k, v and update accumulator lo = 0 - hi = Q_LEN + hi = KV_LEN for start_n in range(lo, hi, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- load k, v -- @@ -425,6 +427,7 @@ def flex_attention(*args, **kwargs): ], num_stages=num_stages, num_warps=num_warps, + call_sizes=query.get_size(), BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=query.get_size()[-1], @@ -445,7 +448,9 @@ def flex_attention(*args, **kwargs): # ---------------------------- Backward HOP Implementation ---------------------------- -def flex_attention_backward_grid(batch_size, num_heads, num_queries, d_model, meta): +def flex_attention_backward_grid( + batch_size, num_heads, num_queries, d_model, num_key_value, meta +): """How is this kernel parallelized? Currently this is only parallelizing over batch * num_heads, but we can, and want to parallelize over ceil_div(num_key_value, key_value_block_size). To do this will either require @@ -453,8 +458,6 @@ def flex_attention_backward_grid(batch_size, num_heads, num_queries, d_model, me """ import triton - # TODO: support different seqlen for Query and Key/Value. - num_key_value = num_queries return ( triton.cdiv(num_queries, meta["BLOCK_M2"]) + triton.cdiv(num_key_value, meta["BLOCK_N1"]), @@ -476,7 +479,7 @@ def flex_attention_backward_grid(batch_size, num_heads, num_queries, d_model, me # DK: Derivative of Key, is the written to via the store_output call due to some limitations with # inductor codegen # M: Number of queries, N: Number of keys/values, D: Model dimension - # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim # (Modifiable) Config options: # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. @@ -486,10 +489,20 @@ def flex_attention_backward_grid(batch_size, num_heads, num_queries, d_model, me # change of base out of the loop # Define Q Strides - stride_z = {{stride("Q", 0)}} - stride_h = {{stride("Q", 1)}} - stride_tok = {{stride("Q", 2)}} - stride_d = {{stride("Q", 3)}} + stride_qz = {{stride("Q", 0)}} + stride_qh = {{stride("Q", 1)}} + stride_qm = {{stride("Q", 2)}} + stride_qd = {{stride("Q", 3)}} + # Define K Strides + stride_kz = {{stride("K", 0)}} + stride_kh = {{stride("K", 1)}} + stride_km = {{stride("K", 2)}} + stride_kd = {{stride("K", 3)}} + # Define V Strides + stride_vz = {{stride("V", 0)}} + stride_vh = {{stride("V", 1)}} + stride_vm = {{stride("V", 2)}} + stride_vd = {{stride("V", 3)}} Z = {{size("Q", 0)}} H = {{size("Q", 1)}} @@ -501,21 +514,22 @@ def flex_attention_backward_grid(batch_size, num_heads, num_queries, d_model, me pid = tl.program_id(0) NUM_KV_BLOCKS = KV_LEN // BLOCK_N1 - bhid = tl.program_id(2) - off_chz = (bhid * Q_LEN).to(tl.int64) - adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64) - off_hz = tl.program_id(2) off_z = off_hz // H # batch idx off_h = off_hz % H # head idx + off_chz = (off_hz * Q_LEN).to(tl.int64) + q_adj = (stride_qh * (off_hz % H) + stride_qz * (off_hz // H)).to(tl.int64) + k_adj = (stride_kh * (off_hz % H) + stride_kz * (off_hz // H)).to(tl.int64) + v_adj = (stride_vh * (off_hz % H) + stride_vz * (off_hz // H)).to(tl.int64) + # offset pointers for batch/head - Q += adj - K += adj - V += adj - DO += adj - DQ += adj - DV += adj + Q += q_adj + K += k_adj + V += v_adj + DO += q_adj + DQ += q_adj + DV += v_adj LSE += off_chz DELTA += off_chz @@ -528,9 +542,9 @@ def flex_attention_backward_grid(batch_size, num_heads, num_queries, d_model, me offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) - q = tl.load(Q + offs_m2[:, None] * stride_tok + offs_k[None, :] * stride_d) + q = tl.load(Q + offs_m2[:, None] * stride_qm + offs_k[None, :] * stride_qd) dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32) - do = tl.load(DO + offs_m2[:, None] * stride_tok + offs_k[None, :] * stride_d) + do = tl.load(DO + offs_m2[:, None] * stride_qm + offs_k[None, :] * stride_qd) lse = tl.load(LSE + offs_m2) lse = lse[:, None] @@ -538,8 +552,8 @@ def flex_attention_backward_grid(batch_size, num_heads, num_queries, d_model, me start_n2 = 0 offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) offs_n2 = start_n2 + tl.arange(0, BLOCK_N2) - kT_ptrs = K + offs_n2[None, :] * stride_tok + offs_k[:, None] * stride_d - vT_ptrs = V + offs_n2[None, :] * stride_tok + offs_k[:, None] * stride_d + kT_ptrs = K + offs_n2[None, :] * stride_km + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vm + offs_k[:, None] * stride_vd Di = tl.load(DELTA + offs_m2) # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) @@ -590,10 +604,10 @@ def flex_attention_backward_grid(batch_size, num_heads, num_queries, d_model, me dq += tl.dot(ds, tl.trans(kT)) # Increment pointers. curr_n += BLOCK_N2 - kT_ptrs += BLOCK_N2 * stride_tok - vT_ptrs += BLOCK_N2 * stride_tok + kT_ptrs += BLOCK_N2 * stride_km + vT_ptrs += BLOCK_N2 * stride_km # Write back dQ. - dq_ptrs = DQ + offs_m2[:, None] * stride_tok + offs_k[None, :] * stride_d + dq_ptrs = DQ + offs_m2[:, None] * stride_qm + offs_k[None, :] * stride_qd tl.store(dq_ptrs, dq) else: # THIS BLOCK DOES DK & DV @@ -606,13 +620,13 @@ def flex_attention_backward_grid(batch_size, num_heads, num_queries, d_model, me dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) # load K and V: they stay in SRAM throughout the inner loop. - k = tl.load(K + offs_n1[:, None] * stride_tok + offs_k[None, :] * stride_d) - v = tl.load(V + offs_n1[:, None] * stride_tok + offs_k[None, :] * stride_d) + k = tl.load(K + offs_n1[:, None] * stride_km + offs_k[None, :] * stride_kd) + v = tl.load(V + offs_n1[:, None] * stride_vm + offs_k[None, :] * stride_vd) offs_m1 = start_m1 + tl.arange(0, BLOCK_M1) offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) - qT_ptrs = Q + offs_m1[None, :] * stride_tok + offs_k[:, None] * stride_d - do_ptrs = DO + offs_m1[:, None] * stride_tok + offs_k[None, :] * stride_d + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_qm + offs_k[None, :] * stride_qd # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) @@ -668,10 +682,10 @@ def flex_attention_backward_grid(batch_size, num_heads, num_queries, d_model, me dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT)) # Increment pointers. curr_m += BLOCK_M1 - qT_ptrs += BLOCK_M1 * stride_tok - do_ptrs += BLOCK_M1 * stride_tok + qT_ptrs += BLOCK_M1 * stride_qm + do_ptrs += BLOCK_M1 * stride_qm - dv_ptrs = DV + offs_n1[:, None] * stride_tok + offs_k[None, :] * stride_d + dv_ptrs = DV + offs_n1[:, None] * stride_vm + offs_k[None, :] * stride_vd tl.store(dv_ptrs, dv) # Write back dK. @@ -773,6 +787,7 @@ def flex_attention_backward(*args, **kwargs): layout=layout_k, # We use store_output only for grad_key subgraphs=[fw_subgraph_buffer, joint_subgraph_buffer], mutated_inputs=[grad_query, grad_value], + call_sizes=query.get_size() + [key.get_size()[2]], num_stages=num_stages, num_warps=num_warps, BLOCK_M1=BLOCK_M, diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 531f3c25a31b..bc89441e3bd8 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -580,6 +580,7 @@ def generate( epilogue_fn=identity, subgraphs=None, mutated_inputs=None, + call_sizes=None, **kwargs, ): """This function generates a TritonTemplateCaller @@ -614,6 +615,9 @@ def generate( "64-bit indexing is not yet implemented for triton templates" ) + if call_sizes is None: + call_sizes = layout.size + kernel_options = dict( input_nodes=input_nodes, defines=defines, @@ -621,13 +625,14 @@ def generate( num_warps=num_warps, grid_fn=self.grid, meta=kwargs, - call_sizes=layout.size, + call_sizes=call_sizes, prefix_args=prefix_args, suffix_args=suffix_args, epilogue_fn=epilogue_fn, index_dtype="tl.int32", subgraphs=subgraphs, ) + with patch.object( V.graph, "get_dtype", self._fake_get_dtype(fake_out) ), TritonTemplateKernel( @@ -701,7 +706,7 @@ def make_kernel_render(out_node): assert mod.__file__ is not None grid = self.grid( *V.graph.sizevars.size_hints( - layout.size, + call_sizes, fallback=config.unbacked_symint_fallback, ), kwargs, diff --git a/torch/nn/attention/_flex_attention.py b/torch/nn/attention/_flex_attention.py index bd999ec39118..430d3280442a 100644 --- a/torch/nn/attention/_flex_attention.py +++ b/torch/nn/attention/_flex_attention.py @@ -101,11 +101,6 @@ def score_mod( # Some basic input validation _validate_sdpa_input(query, key, value) - # This will restriction will be removed in newer version of the kernel - if query.size(-2) != key.size(-2): - raise ValueError( - "NYI: The target sequence length (L) of the query tensor must match the source sequence length (S) of the key tensor." - ) if query.size(-2) % 128 != 0: raise ValueError("NYI: S and L must be a multiple of 128") From 8d153e0bab267a41f0f3c89ee998e4fa8d152d0a Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 4 Jun 2024 04:32:03 +0000 Subject: [PATCH 302/706] [Inductor] Add FlexAttention backward kernel dynamic shape tests (#127728) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/127728 Approved by: https://github.com/Chillee --- test/inductor/test_flex_attention.py | 158 +++++++++++++++++++-------- 1 file changed, 110 insertions(+), 48 deletions(-) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index 7601922a685e..d7afbe1123e7 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -154,6 +154,47 @@ def _check_equal( msg = f"{name} Compiled error {compiled_error} is greater than ref error {ref_error} by more than {fudge_factor}X." self.assertTrue(False, msg) + def _check_out_and_grad( + self, + golden_out: torch.Tensor, + ref_out: torch.Tensor, + compiled_out: torch.Tensor, + q_gold: torch.Tensor, + q_ref: torch.Tensor, + q: torch.Tensor, + k_gold: torch.Tensor, + k_ref: torch.Tensor, + k: torch.Tensor, + v_gold: torch.Tensor, + v_ref: torch.Tensor, + v: torch.Tensor, + ): + dtype = ref_out.dtype + with torch.no_grad(): + # Note, it seems like we really are less accurate than the float32 + # computation, likely due to the online softmax + if dtype == torch.float32: + fudge_factor = 10.0 + else: + fudge_factor = 1.1 + + # Checkout output + self._check_equal(golden_out, ref_out, compiled_out, fudge_factor, "Out") + + # Check gradients + q_fudge_factor = 2.5 * fudge_factor + self._check_equal( + q_gold.grad, q_ref.grad, q.grad, q_fudge_factor, "Grad_Query" + ) + k_fudge_factor = 4 * fudge_factor + self._check_equal( + k_gold.grad, k_ref.grad, k.grad, k_fudge_factor, "Grad_Key" + ) + v_fudge_factor = 4 * fudge_factor + self._check_equal( + v_gold.grad, v_ref.grad, v.grad, v_fudge_factor, "Grad_Value" + ) + def run_test( self, score_mod: Callable, @@ -190,30 +231,20 @@ def run_test( ref_out.backward(backward_grad) compiled_out.backward(backward_grad) - with torch.no_grad(): - # Note, it seems like we really are less accurate than the float32 - # computation, likely due to the online softmax - if dtype == torch.float32: - fudge_factor = 10.0 - else: - fudge_factor = 1.1 - - # Checkout output - self._check_equal(golden_out, ref_out, compiled_out, fudge_factor, "Out") - - # Check gradients - q_fudge_factor = 2.5 * fudge_factor - self._check_equal( - q_gold.grad, q_ref.grad, q.grad, q_fudge_factor, "Grad_Query" - ) - k_fudge_factor = 4 * fudge_factor - self._check_equal( - k_gold.grad, k_ref.grad, k.grad, k_fudge_factor, "Grad_Key" - ) - v_fudge_factor = 4 * fudge_factor - self._check_equal( - v_gold.grad, v_ref.grad, v.grad, v_fudge_factor, "Grad_Value" - ) + self._check_out_and_grad( + golden_out, + ref_out, + compiled_out, + q_gold, + q_ref, + q, + k_gold, + k_ref, + k, + v_gold, + v_ref, + v, + ) def run_dynamic_test( self, @@ -226,24 +257,34 @@ def run_dynamic_test( ): sdpa_partial = create_attention(score_mod) # The first eager batch, shape (B, H, S, D) - q1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") - k1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") - v1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") - golden_out1 = sdpa_partial( - q1.to(torch.float64), k1.to(torch.float64), v1.to(torch.float64) - ) - ref_out1 = sdpa_partial(q1, k1, v1) + q1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) + k1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) + v1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) + q1_ref, k1_ref, v1_ref = query_key_value_clones(q1, k1, v1) + q1_gold, k1_gold, v1_gold = query_key_value_clones(q1, k1, v1, torch.float64) + ref_out1 = sdpa_partial(q1_ref, k1_ref, v1_ref) + golden_out1 = sdpa_partial(q1_gold, k1_gold, v1_gold) + + backward_grad1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") + + golden_out1.backward(backward_grad1.to(torch.float64)) + ref_out1.backward(backward_grad1) # The second eager batch, shape (B * 2, H, S / 2, D) B = int(B * 2) S = int(S / 2) - q2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") - k2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") - v2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") - golden_out2 = sdpa_partial( - q2.to(torch.float64), k2.to(torch.float64), v2.to(torch.float64) - ) - ref_out2 = sdpa_partial(q2, k2, v2) + q2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) + k2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) + v2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) + q2_ref, k2_ref, v2_ref = query_key_value_clones(q2, k2, v2) + q2_gold, k2_gold, v2_gold = query_key_value_clones(q2, k2, v2, torch.float64) + ref_out2 = sdpa_partial(q2_ref, k2_ref, v2_ref) + golden_out2 = sdpa_partial(q2_gold, k2_gold, v2_gold) + + backward_grad2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") + + golden_out2.backward(backward_grad2.to(torch.float64)) + ref_out2.backward(backward_grad2) # Need to clear dynamo counters, since flex attention eager mode also uses dynamo tracing. # We check dynamo counters["frames"]["ok"] to ensure there is no re-compilation. @@ -251,20 +292,41 @@ def run_dynamic_test( # Compiling with dynamic shape in the first batch. compiled_sdpa = torch.compile(sdpa_partial, dynamic=True) compiled_out1 = compiled_sdpa(q1, k1, v1) - - # Note, it seems like we really are less accurate than the float32 - # computation, likely due to the online softmax - if dtype == torch.float32: - fudge_factor = 10.0 - else: - fudge_factor = 1.1 - - self._check_equal(golden_out1, ref_out1, compiled_out1, fudge_factor) + compiled_out1.backward(backward_grad1) + + self._check_out_and_grad( + golden_out1, + ref_out1, + compiled_out1, + q1_gold, + q1_ref, + q1, + k1_gold, + k1_ref, + k1, + v1_gold, + v1_ref, + v1, + ) self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1) # No re-compilation, use the compiled dynamic shape version. compiled_out2 = compiled_sdpa(q2, k2, v2) - self._check_equal(golden_out2, ref_out2, compiled_out2, fudge_factor) + compiled_out2.backward(backward_grad2) + self._check_out_and_grad( + golden_out2, + ref_out2, + compiled_out2, + q2_gold, + q2_ref, + q2, + k2_gold, + k2_ref, + k2, + v2_gold, + v2_ref, + v2, + ) self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1) def run_automatic_dynamic_test( From 2ad0e4197d1c6b2f0f732f2a6ec29de2af031c28 Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Tue, 4 Jun 2024 04:51:29 +0000 Subject: [PATCH 303/706] [ts-migration] support aten::__is__, aten::__isnot__, aten::__not__, profiler::_record_function_enter_new, profiler::_record_function_exit (#127656) Support more ops in ts converter and add unit tests. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127656 Approved by: https://github.com/SherlockNoMad --- test/export/test_converter.py | 61 +++++++++++++++++++++++++++++++++-- torch/_export/converter.py | 47 +++++++++++++++++++++------ 2 files changed, 96 insertions(+), 12 deletions(-) diff --git a/test/export/test_converter.py b/test/export/test_converter.py index cde08b7f7cd3..d59bb0ebf8f7 100644 --- a/test/export/test_converter.py +++ b/test/export/test_converter.py @@ -1,6 +1,7 @@ # Owner(s): ["oncall: export"] import unittest +from typing import Tuple import torch @@ -9,7 +10,6 @@ from torch._dynamo.test_case import TestCase from torch._export.converter import TS2EPConverter from torch.export import ExportedProgram - from torch.testing._internal.common_utils import run_tests requires_cuda = unittest.skipUnless(torch.cuda.is_available(), "requires cuda") @@ -23,8 +23,11 @@ def _check_equal_ts_ep_converter(self, mod, inp) -> ExportedProgram: orig_out, _ = pytree.tree_flatten(mod(*inp)) self.assertEqual(len(ep_out), len(orig_out)) for ep_t, orig_t in zip(ep_out, orig_out): - self.assertEqual(ep_t.shape, orig_t.shape) - self.assertTrue(torch.allclose(ep_t, orig_t)) + if isinstance(ep_t, torch.Tensor): + self.assertEqual(ep_t.shape, orig_t.shape) + self.assertTrue(torch.allclose(ep_t, orig_t)) + else: + self.assertEqual(ep_t, orig_t) return ep def test_ts2ep_converter_basic(self): @@ -192,6 +195,58 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): M()(torch.tensor(False), torch.tensor(4)), ) + def test_profiler__record_function(self): + class Module(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + handle = torch.ops.profiler._record_function_enter_new("foo", None) + y = x * 2 + 4 + torch.ops.profiler._record_function_exit(handle) + return y + + x = torch.randn(10, 10) + self._check_equal_ts_ep_converter(Module(), (x,)) + + def test_aten_floordiv(self): + class Module(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x // 2 + + x = torch.randn(10, 10) + self._check_equal_ts_ep_converter(Module(), (x,)) + + def test_aten___is__(self): + class Module(torch.nn.Module): + def forward( + self, x: torch.Tensor, y: torch.Tensor + ) -> Tuple[bool, torch.Tensor]: + z = x + 1 + return x is y, z + + inp = (torch.randn(10, 10), torch.rand(10, 10)) + self._check_equal_ts_ep_converter(Module(), inp) + + def test_aten___isnot__(self): + class Module(torch.nn.Module): + def forward( + self, x: torch.Tensor, y: torch.Tensor + ) -> Tuple[bool, torch.Tensor]: + z = x + 1 + return x is not y, z + + inp = (torch.randn(10, 10), torch.rand(10, 10)) + self._check_equal_ts_ep_converter(Module(), inp) + + def test_aten___not__(self): + class Module(torch.nn.Module): + def forward( + self, x: torch.Tensor, y: torch.Tensor + ) -> Tuple[bool, torch.Tensor]: + z = x + 1 + return not (x is not y), z + + inp = (torch.randn(10, 10), torch.rand(10, 10)) + self._check_equal_ts_ep_converter(Module(), inp) + if __name__ == "__main__": run_tests() diff --git a/torch/_export/converter.py b/torch/_export/converter.py index 88f43d46cb48..2c021377cb4f 100644 --- a/torch/_export/converter.py +++ b/torch/_export/converter.py @@ -41,6 +41,15 @@ def normalize_name(name: str) -> str: return name.replace(".", "_") +# Given a node: torch._C.Node, map from node.kind() to a standard operator +kind_to_standard_operators = { + "prim::TupleIndex": operator.getitem, + "aten::__is__": operator.is_, + "aten::__isnot__": operator.is_not, + "aten::__not__": operator.not_, +} + + def get_op_overload(node: torch._C.Node): schema_str = node.schema() schema = FunctionSchema.parse(schema_str) @@ -285,13 +294,6 @@ def convert_prim_DictConstruct(self, node: torch._C.Node): output_name = node.output().debugName() self.name_to_node[output_name] = output_dict - def convert_prim_TupleIndex(self, node: torch._C.Node): - args = tuple(self.get_fx_value(input) for input in node.inputs()) - getitem_node = self.fx_graph.call_function(operator.getitem, args) - - output_name = node.output().debugName() - self.name_to_node[output_name] = getitem_node - def convert_aten_Int(self, node: torch._C.Node): # converts aten::Int as aten._to_copy + aten::_local_scalar_dense target = torch.ops.aten._to_copy.default @@ -438,6 +440,28 @@ def convert_as_noop(self, node: torch._C.Node): output_name = node.output().debugName() self.name_to_node[output_name] = args[0] + def convert_profiler__record_function_enter_new(self, node: torch._C.Node): + target = torch.ops.profiler._record_function_enter_new + args = tuple(self.get_fx_value(input) for input in node.inputs()) + fx_node = self.fx_graph.call_function(target, args) + output_name = node.output().debugName() + self.name_to_node[output_name] = fx_node + + def convert_profiler__record_function_exit(self, node: torch._C.Node): + # _record_function_exit has side effect so we keep it in fx.graph + # currently, _record_function_enter_new and _record_function_exit are + # discarded during `retrace_as_exported_program`. + target = torch.ops.profiler._record_function_exit + args = tuple(self.get_fx_value(input) for input in node.inputs()) + self.fx_graph.call_function(target, args) + + def convert_standard_operators(self, node: torch._C.Node): + target = kind_to_standard_operators[node.kind()] + args = tuple(self.get_fx_value(input) for input in node.inputs()) + fx_node = self.fx_graph.call_function(target, args) + output_name = node.output().debugName() + self.name_to_node[output_name] = fx_node + def convert_node(self, node: torch._C.Node): node_kind = node.kind() if node_kind == "prim::CreateObject": @@ -457,8 +481,6 @@ def convert_node(self, node: torch._C.Node): self.convert_prim_dtype(node) elif node_kind == "prim::DictConstruct": self.convert_prim_DictConstruct(node) - elif node_kind == "prim::TupleIndex": - self.convert_prim_TupleIndex(node) # elif node_kind == "aten::Int": # convert_aten_Int(node) elif node_kind == "aten::_convolution": @@ -471,7 +493,14 @@ def convert_node(self, node: torch._C.Node): self.convert_prim_if(node) elif node_kind == "aten::Bool": self.convert_as_noop(node) + elif node_kind == "profiler::_record_function_enter_new": + self.convert_profiler__record_function_enter_new(node) + elif node_kind == "profiler::_record_function_exit": + self.convert_profiler__record_function_exit(node) + elif node_kind in kind_to_standard_operators: + self.convert_standard_operators(node) elif node_kind.startswith("aten::"): + # order matters! this should be handled after kind_to_standard_operators self.convert_aten_op(node) else: raise ValueError(f"Unsupported node kind: {node_kind}") From e7cb43a2d2bb1740fd6f4bc1a440004664007a3f Mon Sep 17 00:00:00 2001 From: cyy Date: Tue, 4 Jun 2024 05:35:25 +0000 Subject: [PATCH 304/706] Check unused variables in tests (#127498) Enables unused variable checks in CMake. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127498 Approved by: https://github.com/ezyang --- caffe2/CMakeLists.txt | 3 --- test/cpp/jit/CMakeLists.txt | 3 --- test/cpp/tensorexpr/CMakeLists.txt | 3 --- test/cpp/tensorexpr/test_llvm.cpp | 4 +--- 4 files changed, 1 insertion(+), 12 deletions(-) diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index df1eecf929a8..f7de195a6272 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1810,9 +1810,6 @@ if(BUILD_TEST) target_include_directories(${test_name} PRIVATE $) target_include_directories(${test_name} PRIVATE $) target_include_directories(${test_name} PRIVATE ${Caffe2_CPU_INCLUDE}) - if(NOT MSVC) - target_compile_options(${test_name} PRIVATE -Wno-unused-variable) - endif() add_test(NAME ${test_name} COMMAND $) if(INSTALL_TEST) install(TARGETS ${test_name} DESTINATION test) diff --git a/test/cpp/jit/CMakeLists.txt b/test/cpp/jit/CMakeLists.txt index 2d88d3f7172d..f0510d9c81f2 100644 --- a/test/cpp/jit/CMakeLists.txt +++ b/test/cpp/jit/CMakeLists.txt @@ -129,9 +129,6 @@ endif(MSVC) target_link_libraries(test_jit PRIVATE ${JIT_TEST_DEPENDENCIES}) target_include_directories(test_jit PRIVATE ${ATen_CPU_INCLUDE}) -if(NOT MSVC) - target_compile_options(test_jit PRIVATE $<$:-Wno-unused-variable>) -endif() if(LINUX) #Update to target_link_options when CMake version can be upgraded diff --git a/test/cpp/tensorexpr/CMakeLists.txt b/test/cpp/tensorexpr/CMakeLists.txt index 012471d0e584..179270c4a4a1 100644 --- a/test/cpp/tensorexpr/CMakeLists.txt +++ b/test/cpp/tensorexpr/CMakeLists.txt @@ -42,9 +42,6 @@ add_executable(test_tensorexpr target_link_libraries(test_tensorexpr PRIVATE torch gtest) target_include_directories(test_tensorexpr PRIVATE ${ATen_CPU_INCLUDE}) target_compile_definitions(test_tensorexpr PRIVATE USE_GTEST) -if(NOT MSVC) - target_compile_options(test_tensorexpr PRIVATE -Wno-unused-variable) -endif() add_executable(tutorial_tensorexpr ${TENSOREXPR_TEST_ROOT}/tutorial.cpp) target_link_libraries(tutorial_tensorexpr PRIVATE torch) diff --git a/test/cpp/tensorexpr/test_llvm.cpp b/test/cpp/tensorexpr/test_llvm.cpp index aa578a4956c6..f6ffc84f62c0 100644 --- a/test/cpp/tensorexpr/test_llvm.cpp +++ b/test/cpp/tensorexpr/test_llvm.cpp @@ -179,7 +179,7 @@ TEST(LLVM, CharToFloatCastTest) { } TEST(LLVM, BitCast) { - constexpr int16_t ref16 = 1337; + /* constexpr int16_t ref16 = 1337; */ constexpr int32_t ref32 = 1337; constexpr int64_t ref64 = 1337; constexpr float reff32 = 1337.0f; @@ -1395,7 +1395,6 @@ TEST(LLVM, EliminatedStmt) { TEST(LLVM, SimpleReduction) { int M = 128; int N = 64; - const int kTotalSize = M * N; BufHandle a("a", {1, M, N}, kFloat); @@ -1429,7 +1428,6 @@ TEST(LLVM, SimpleReduction) { TEST(LLVM, RFactorReduction) { int M = 128; int N = 64; - const int kTotalSize = M * N; BufHandle a("a", {1, M, N}, kFloat); From f4b77ce8e289e587077cf607eedf9ded6ff0f6e6 Mon Sep 17 00:00:00 2001 From: satheeshhab <153791691+satheeshhab@users.noreply.github.com> Date: Tue, 4 Jun 2024 06:09:17 +0000 Subject: [PATCH 305/706] Masked scale meta function registration #119984 (#127389) Fixes #119984 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127389 Approved by: https://github.com/cpuhrsch --- torch/_meta_registrations.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 624801bf9afa..7442ca9157e3 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -3622,6 +3622,14 @@ def meta_masked_fill_(self, mask, value): return self +@register_meta(aten._masked_scale.default) +def meta__masked_scale(self, mask, scale): + masked_scale = self.new_empty(self.size()).to( + memory_format=utils.suggest_memory_format(self) + ) + return masked_scale + + @register_meta(aten.masked_scatter_) def meta_masked_scatter_(self, mask, source): torch._check( From ef77f2ca4a4d714cd1501e4b184cba9eccd41ce6 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Fri, 31 May 2024 22:03:17 -0700 Subject: [PATCH 306/706] [pipelining] Simple 1F1B schedule (#127673) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ![Screenshot 2024-05-31 at 9 13 18 PM](https://github.com/pytorch/pytorch/assets/6676466/ecf3ca24-33a6-4188-9f7c-df6e96311caa) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127673 Approved by: https://github.com/wconstab --- .../pipelining/PipelineSchedule.py | 215 +++++++++--------- 1 file changed, 105 insertions(+), 110 deletions(-) diff --git a/torch/distributed/pipelining/PipelineSchedule.py b/torch/distributed/pipelining/PipelineSchedule.py index 16940674b670..f7d6d7c1b372 100644 --- a/torch/distributed/pipelining/PipelineSchedule.py +++ b/torch/distributed/pipelining/PipelineSchedule.py @@ -206,6 +206,8 @@ def _batch_p2p(p2p_ops: List[dist.P2POp], desc: Optional[str] = None): """ Simple wrapper over batch_isend_irecv from torch.distributed, which just adds a descriptive logger on top. """ + if len(p2p_ops) == 0: + return None desc_str = f"{desc}, " if desc else "" logger.debug(f"batch_p2p {desc_str}{p2p_ops}") # noqa: G004 return dist.batch_isend_irecv(p2p_ops).pop() @@ -399,121 +401,114 @@ def _step_microbatches( """ arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) - # Example, 4 GPUs, 8 microbatches - # Stage 0: 6 warmup, 2 1f1b, 6 cooldown - # Stage 1: 4 warmup, 4 1f1b, 4 cooldown - # Stage 2: 2 warmup, 6 1f1b, 2 cooldown - # Stage 3: 0 warmup, 8 1f1b, 0 cooldown - # fwd only - warmup_steps = min( + # Last stage has 1 warmup, second-to-last 2 warmups, ... + # first stage `num_stages` warmups + warmup_chunks = min( self._n_microbatches, - 2 * (self._num_stages - self._stage.stage_index - 1), - ) - # fwd + bwd - main_1f1b_steps = self._n_microbatches - warmup_steps - # bwd only - cooldown_steps = (2 * self._n_microbatches) - ( - warmup_steps + (2 * main_1f1b_steps) - ) - total_steps = warmup_steps + main_1f1b_steps + cooldown_steps - logger.debug( - f"Stage {self._stage.stage_index}: " # noqa: G004 - f"Warmup steps: {warmup_steps}, " - f"Main 1F1B steps: {main_1f1b_steps}, " - f"Cooldown steps: {cooldown_steps}, " - f"Total steps: {total_steps}" + self._num_stages - self._stage.stage_index, ) - # Delay send waits - fwd_sends_to_wait: List[dist.Work] = [] - bwd_sends_to_wait: List[dist.Work] = [] - - def step_has_forward(i): - assert i >= 0, i - return i < self._n_microbatches - - def step_has_backward(i): - assert i < total_steps, i - return i >= warmup_steps and self._has_backward - - def is_1f1b_step(i): - return step_has_forward(i) and step_has_backward(i) - - def is_warmup_step(i): - return step_has_forward(i) and not step_has_backward(i) - - def is_cooldown_step(i): - return not step_has_forward(i) and step_has_backward(i) - - def should_coalesce_fwd_send_bwd_recv(step): - return ( - is_1f1b_step(step) - or (is_warmup_step(step) and is_cooldown_step(step + 1)) - or (step >= 1 and is_warmup_step(step - 1) and is_cooldown_step(step)) - ) - - def should_coalesce_bwd_send_fwd_recv(bwd_send_step): - # The backward send to prev stage should be coalesced with the fwd recv from the previous stage - return bwd_send_step >= warmup_steps and is_1f1b_step(bwd_send_step + 1) - - # bwd chunk counter + # Chunk counters + fwd_mb_index = 0 bwd_mb_index = 0 - self._stage._configure_data_parallel_mode(last_backward=False) - for i in range(total_steps): - if step_has_forward(i): - with record_function(f"Forward {i}"): - ops = self._stage.get_fwd_recv_ops() - desc = "fwd_recv" - if should_coalesce_bwd_send_fwd_recv(i - 1): - desc += "_bwd_send" - ops.extend(self._stage.get_bwd_send_ops()) - - works = _sorted_batch_p2p(ops, desc=desc) - for work in works.values(): - work.wait() - - output = self._stage.forward_one_chunk(arg_mbs[i], kwarg_mbs[i]) # type: ignore[index] - - if not should_coalesce_fwd_send_bwd_recv(i): - ops = self._stage.get_fwd_send_ops() - works = _sorted_batch_p2p(ops, desc="fwd_send") - fwd_sends_to_wait.extend(works.values()) - - self._maybe_compute_loss(self._stage, output, target_mbs, i) - - if step_has_backward(i): - self._stage._configure_data_parallel_mode( - last_backward=(i == total_steps - 1) - ) - with record_function(f"Backward {bwd_mb_index}"): - ops = self._stage.get_bwd_recv_ops() - desc = "bwd_recv" - if should_coalesce_fwd_send_bwd_recv(i): - ops.extend(self._stage.get_fwd_send_ops()) - desc += "_fwd_send" - - works = _sorted_batch_p2p(ops, desc=desc) - for work in works.values(): - work.wait() - - loss = self._maybe_get_loss(self._stage, bwd_mb_index) - self._stage.backward_one_chunk(loss=loss) - - if not should_coalesce_bwd_send_fwd_recv(i): - # see Note: coalesced bwd-send/fwd-recv - ops = self._stage.get_bwd_send_ops() - works = _sorted_batch_p2p(ops, desc="bwd_send") - bwd_sends_to_wait.extend(works.values()) - - bwd_mb_index += 1 - # Wait for all forward sends to finish - for work in fwd_sends_to_wait: - work.wait() - - # Wait for all backward sends to finish - for work in bwd_sends_to_wait: - work.wait() + # Warmup phase + send_work = None + fwd_sends = [] + for _ in range(warmup_chunks): + # Receive activations + fwd_recvs = self._stage.get_fwd_recv_ops() + if recv_work := _batch_p2p(fwd_recvs, desc="fwd_recv"): + recv_work.wait() + + # Compute + output = self._stage.forward_one_chunk(arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]) # type: ignore[index] + + # Clear previous chunk's forward sends (hopefully they have well + # finished, otherwise, we are heavily communication bound, in which + # case it doesn't create a lot of benefit to compute next chunk + # eagerly either) + if send_work: + send_work.wait() + + # Send activations + fwd_sends = self._stage.get_fwd_send_ops() + if fwd_mb_index != warmup_chunks - 1: + # Safe to fire + send_work = _batch_p2p(fwd_sends, desc="fwd_send") + # otherwise: + # The last foward send is left for fuse with first 1B in 1B1F below + + # Compute loss + self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index) + fwd_mb_index += 1 + + # Now we should have send ops left over, to be fused with first 1B of 1B1F phase below. + + # 1B1F phase + while True: # Don't worry, we have a break inside + # We actually do 1B first as the `1B1F` name indicates, so prepare its recv ops + bwd_recvs = self._stage.get_bwd_recv_ops() + + # Now, we need to fire the fwd_sends and bwd_recvs together + if fuse_work := _batch_p2p(fwd_sends + bwd_recvs, desc="fwd_send_bwd_recv"): + fuse_work.wait() + + # Backward one chunk + loss = self._maybe_get_loss(self._stage, bwd_mb_index) + self._stage.backward_one_chunk(loss=loss) + + # Get the bwd send ops, but don't fire, to be fused with the 1F below + bwd_sends = self._stage.get_bwd_send_ops() + bwd_mb_index += 1 + + if fwd_mb_index == self._n_microbatches: + # We are done with 1B1F, so break with some left-over bwd_sends + break + + # We prepare 1F of the `1B1F` + fwd_recvs = self._stage.get_fwd_recv_ops() + + # Fuse it with bwd_sends above + if fuse_work := _batch_p2p(bwd_sends + fwd_recvs, desc="bwd_send_fwd_recv"): + fuse_work.wait() + + # Now do the fwd + output = self._stage.forward_one_chunk(arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]) # type: ignore[index] + + # Compute loss + self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index) + + # Get the fwd send ops, but don't fire, leave it for the next iter (wrap-around) + fwd_sends = self._stage.get_fwd_send_ops() + fwd_mb_index += 1 + + # Remember we still have some bwd_sends left over after the break? Now it is time to fire it + send_work = _batch_p2p(bwd_sends, desc="bwd_send") + + # Cooldown + while bwd_mb_index < self._n_microbatches: + # prepare bwd recv ops + bwd_recvs = self._stage.get_bwd_recv_ops() + if recv_work := _batch_p2p(bwd_recvs, desc="bwd_recv"): + recv_work.wait() + + # Backward one chunk + loss = self._maybe_get_loss(self._stage, bwd_mb_index) + self._stage.backward_one_chunk(loss=loss) + + # Clear previous chunk's backward sends (hopefully they have well finished) + if send_work: + send_work.wait() + + # Get the bwd send ops, fire it + bwd_sends = self._stage.get_bwd_send_ops() + send_work = _batch_p2p(bwd_sends, desc="bwd_send") + bwd_mb_index += 1 + + # Wait for the last backward send to finish + if send_work: + send_work.wait() # Return losses if there is a container passed in self._update_losses(self._stage, losses) From 2122c9e2a9ce4cb91a0938ceacc0264b2b7547b8 Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Mon, 3 Jun 2024 17:25:46 -0700 Subject: [PATCH 307/706] [BE] Enabled lintrunner on torch/distributed/utils.py (#127771) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127771 Approved by: https://github.com/wanchaol, https://github.com/Skylion007 --- .lintrunner.toml | 1 - torch/distributed/utils.py | 40 ++++++++++++++++++++++++++------------ 2 files changed, 28 insertions(+), 13 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index 8f7be27ece84..ce69059d1ead 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1582,7 +1582,6 @@ exclude_patterns = [ 'torch/distributed/tensor/parallel/input_reshard.py', 'torch/distributed/tensor/parallel/multihead_attention_tp.py', 'torch/distributed/tensor/parallel/style.py', - 'torch/distributed/utils.py', 'torch/fft/__init__.py', 'torch/func/__init__.py', 'torch/functional.py', diff --git a/torch/distributed/utils.py b/torch/distributed/utils.py index f47908d96c74..af44fee9d720 100644 --- a/torch/distributed/utils.py +++ b/torch/distributed/utils.py @@ -1,6 +1,17 @@ import dataclasses import traceback -from typing import Any, Callable, Container, Dict, List, Optional, OrderedDict, Tuple, TypeVar, overload +from typing import ( + Any, + Callable, + Container, + Dict, + List, + Optional, + OrderedDict, + overload, + Tuple, + TypeVar, +) import torch import torch.distributed as dist @@ -40,6 +51,7 @@ def _pack_kwargs(*args: Any, **kwargs: Any) -> Tuple[Tuple[Any, ...], Tuple[str, return tuple(flat_args), tuple(kwarg_keys) + def _cast_forward_inputs( dtype: Optional[torch.dtype], *args: Any, @@ -60,7 +72,10 @@ def cast_fn(x: torch.Tensor) -> torch.Tensor: return (_apply_to_tensors(cast_fn, args), _apply_to_tensors(cast_fn, kwargs)) -def _unpack_kwargs(flat_args: Tuple[Any, ...], kwarg_keys: Tuple[str, ...]) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: + +def _unpack_kwargs( + flat_args: Tuple[Any, ...], kwarg_keys: Tuple[str, ...] +) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: """See _pack_kwargs.""" assert len(kwarg_keys) <= len( flat_args @@ -77,12 +92,16 @@ def _unpack_kwargs(flat_args: Tuple[Any, ...], kwarg_keys: Tuple[str, ...]) -> T @overload -def _recursive_to(inputs: S, target_device: torch.device, use_side_stream_for_tensor_copies: bool) -> List[S]: +def _recursive_to( + inputs: S, target_device: torch.device, use_side_stream_for_tensor_copies: bool +) -> List[S]: ... @overload -def _recursive_to(inputs: T, target_device: torch.device, use_side_stream_for_tensor_copies: bool) -> Tuple[T]: +def _recursive_to( + inputs: T, target_device: torch.device, use_side_stream_for_tensor_copies: bool +) -> Tuple[T]: ... @@ -155,9 +174,7 @@ def _alloc_storage(tensor: torch.Tensor, size: torch.Size) -> None: storage was already allocated. """ with torch.no_grad(): - if ( - not torch.distributed._functional_collectives.is_torchdynamo_compiling() - ): + if not torch.distributed._functional_collectives.is_torchdynamo_compiling(): already_allocated = tensor._typed_storage()._size() == size.numel() if not already_allocated: tensor_storage_size = tensor._typed_storage()._size() @@ -177,9 +194,7 @@ def _free_storage(tensor: torch.Tensor): storage was already freed. """ with torch.no_grad(): - if ( - not torch.distributed._functional_collectives.is_torchdynamo_compiling() - ): + if not torch.distributed._functional_collectives.is_torchdynamo_compiling(): already_freed = tensor._typed_storage()._size() == 0 if not already_freed: _p_assert( @@ -192,7 +207,6 @@ def _free_storage(tensor: torch.Tensor): tensor._typed_storage()._resize_(0) - Q = TypeVar("Q") R = TypeVar("R", dict, list, tuple, set, OrderedDict, PackedSequence, Any) @@ -264,7 +278,9 @@ def _to_kwargs( def _verify_param_shape_across_processes( - process_group: dist.ProcessGroup, tensors: List[torch.Tensor], logger: Optional[dist.Logger] = None + process_group: dist.ProcessGroup, + tensors: List[torch.Tensor], + logger: Optional[dist.Logger] = None, ): return dist._verify_params_across_processes(process_group, tensors, logger) From e216df48c8dc0ec7a4b3d279444259b13f674257 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Tue, 4 Jun 2024 06:12:59 +0000 Subject: [PATCH 308/706] [Dynamo][TVM] Fix ignored `trials` argument for MetaSchedule (#127747) Fixes #127746 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127747 Approved by: https://github.com/jansel --- torch/_dynamo/backends/tvm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch/_dynamo/backends/tvm.py b/torch/_dynamo/backends/tvm.py index a0a86536c16d..84084616995e 100644 --- a/torch/_dynamo/backends/tvm.py +++ b/torch/_dynamo/backends/tvm.py @@ -97,11 +97,12 @@ def tvm( ) # TODO(shingjan): This could be replaced by tvm.contrib.torch.optimize_torch # once USE_PT_TVMDSOOP is updated and turned on by default in TVM. + assert trials > 0 database = ms.relay_integration.tune_relay( mod=mod, target=target, work_dir=work_dir, - max_trials_global=20000, + max_trials_global=trials, num_trials_per_iter=64, params=params, strategy="evolutionary", From 6abca6a5647f3200547d131ae0fef83bdd4073bb Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Mon, 3 Jun 2024 23:39:46 -0700 Subject: [PATCH 309/706] [export][unflatten] More strictly respect scope when removing inputs (#127607) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Code snippet from TorchTitan (LLaMa): ``` for layer in self.layers.values(): h = layer(h, self.freqs_cis) ``` `self.freqs_cis` is a buffer of root module (`self`). It is also an explicit arg in the call signature of original `layer` modules. If not respecting scope -- `freqs_cis`'s scope only corresponds to root -- `_sink_param` can remove `freqs_cis` from `layer`'s call signature, resulting in runtime error. There are two fixes in this PR: 1. We filter out the `inputs_to_state` corresponding to the current scope, using existing code that does prefix matching. 2. We delay the removal of param inputs from `call_module` nodes' `args`, till `_sink_param` call on that submodule returns. The return now returns information on which input is actually removed by the submodule, thus more accurate than just doing: ``` for node in call_module_nodes: node.args = tuple(filter(lambda n: n.name not in inputs_to_state, node.args)) ``` Before the PR: ![Screenshot 2024-05-31 at 1 40 24 AM](https://github.com/pytorch/pytorch/assets/6676466/a2e06b18-44d5-40ca-b242-0edab45075b7) After the PR: ![Screenshot 2024-05-31 at 1 43 41 AM](https://github.com/pytorch/pytorch/assets/6676466/b72afb94-cdfa-420d-b88b-29a92bf2a0c0) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127607 Approved by: https://github.com/pianpwk --- test/export/test_unflatten.py | 22 +++++++ torch/export/unflatten.py | 116 ++++++++++++++++++++++++---------- 2 files changed, 106 insertions(+), 32 deletions(-) diff --git a/test/export/test_unflatten.py b/test/export/test_unflatten.py index 3ca58e8fff79..3940cde45234 100644 --- a/test/export/test_unflatten.py +++ b/test/export/test_unflatten.py @@ -747,6 +747,28 @@ def forward(self, x): unep = unflatten(ep) self.assertTrue(torch.allclose(unep(*inps), m(*inps))) + def test_attr_as_submod_input(self): + class layer(torch.nn.Module): + def forward(self, x, const) -> torch.Tensor: + return x + const + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.register_buffer("const", torch.ones(4, 8)) + self.layers = torch.nn.ModuleList([layer() for _ in range(2)]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for layer in self.layers: + x = layer(x, self.const) + return x + + mod = M() + x = torch.randn(4, 8) + ep = export(mod, (x,)) + unflattened = unflatten(ep) + torch.testing.assert_close(unflattened(x), mod(x)) + if __name__ == "__main__": run_tests() diff --git a/torch/export/unflatten.py b/torch/export/unflatten.py index 2bb38fccc378..61685ed0f180 100644 --- a/torch/export/unflatten.py +++ b/torch/export/unflatten.py @@ -337,16 +337,35 @@ def add_to_consts_map(obj_id, node_name, target_name): inputs_to_state[n] = targets _sink_params(self, inputs_to_state, []) - # Check all input nodes has been processed. - for name, module in self.named_modules(): - if not hasattr(module, "graph"): - continue - for node in module.graph.nodes: - if node.op != "placeholder": - continue - assert ( - node.name not in inputs_to_state - ), f"{node.name} was not sunk into the module {name} which has the graph: {module.graph}" + + # Helper function to check input nodes of `module` has been processed. + def check_module_inputs(module, scope): + if hasattr(module, "graph"): + for node in module.graph.nodes: + # sink_params() should turn placeholders into get_attr nodes + # for attributes that are within scope of the current + # module. We allow attributes to remain as placeholders if + # they are inputs in the original module signature, meaning + # they are a parent module's attribute, and therefore out of + # scope of the current module. + if ( + node.op == "placeholder" + and node.name in inputs_to_state + and any( + fqn.split(".")[: len(scope)] == scope + for fqn in inputs_to_state[node.name] + ) # matching scope to avoid wrong assert + ): + raise AssertionError( + f"{node.name} was not sunk into the module {scope} which has the graph: {module.graph}" + ) + # Recursively check the submodules. + for name, submod in module.named_children(): + scope.append(name) + check_module_inputs(submod, scope) + + # Recurively check all input nodes have been processed. + check_module_inputs(self, []) # Cache so we don't have to compute this every time. # NOTE: this needs to be kept in sync with the placeholders in @@ -1010,14 +1029,23 @@ def _sink_params( scope: tracks where we are in the module hierarchy, so that we can emit the right `getattr(self, "foo.bar")` calls, etc. """ + # This dict records inputs removed by child modules. + # Maps the module object id to the list of placeholder node names + # in the child module that were removed. + module_id_to_inputs_removed: Dict[int, List[str]] = defaultdict(list) + # We need to use _modules here instead of named_children(), because we # explicitly want duplicate modules to show up in the traversal. for name, submodule in module._modules.items(): - _sink_params(cast(torch.nn.Module, submodule), inputs_to_state, scope + [name]) + submod_id_to_inputs_removed = _sink_params( + cast(torch.nn.Module, submodule), inputs_to_state, scope + [name] + ) + for k, v in submod_id_to_inputs_removed.items(): + module_id_to_inputs_removed[k].extend(v) if not hasattr(module, "graph"): # Not all modules have graphs defined, if they are empty modules with no operations (like ParameterList) - return + return module_id_to_inputs_removed graph = module.graph inputs = list(filter(lambda n: n.op == "placeholder", graph.nodes)) @@ -1026,32 +1054,49 @@ def _sink_params( # Also remove from call_module nodes call_module_nodes = filter(lambda n: n.op == "call_module", graph.nodes) for node in call_module_nodes: - node.args = tuple(filter(lambda n: n.name not in inputs_to_state, node.args)) + submodule = _recursive_getattr(module, node.target.split(".")) + # remove placeholder from call_module node arguments, only if we've + # erased the placeholder node in the corresponding _sink_params() call + if submodule is not None and id(submodule) in module_id_to_inputs_removed: + node.args = tuple( + filter( + lambda n: n.name not in module_id_to_inputs_removed[id(submodule)], + node.args, + ) + ) + # Filter out inputs_to_state corresponding to current scope. + inputs_to_state_of_scope: Dict[torch.fx.Node, list[str]] = {} for node in inputs: if node.name not in inputs_to_state: continue - if len(node.users) > 0: - state_name = None - for sn in inputs_to_state[node.name]: - sn_split = sn.split(".") - if sn_split[: len(scope)] == scope: - state_name = sn_split - break - - # If there's a mismatch beteewn scope name and state name, then - # there must be multuple scopes pointing to the same state name, - # meaning some modules are shared. In such case, we can simply skip - # updating the current node because another later iteration will - # take care of this input node when the unique match between scope - # and state name occurs. To make sure this always happen, we should - # enforce the invariant that no placeholder node in the unflattened - # graph appears in inputs_to_state dict, which means all the extra - # input nodes have been handled. - if state_name is None: - continue + state_name = None + for sn in inputs_to_state[node.name]: + sn_split = sn.split(".") + if sn_split[: len(scope)] == scope: + state_name = sn_split + break + + # If there's a mismatch beteewn scope name and state name, then + # there must be multuple scopes pointing to the same state name, + # meaning some modules are shared. In such case, we can simply skip + # updating the current node because another later iteration will + # take care of this input node when the unique match between scope + # and state name occurs. To make sure this always happen, we should + # enforce the invariant that no placeholder node in the unflattened + # graph appears in inputs_to_state dict, which means all the extra + # input nodes have been handled. + if state_name is None: + continue + + inputs_to_state_of_scope[node] = state_name + + # Record name of remove inputs for return purpose. + inputs_removed: List[str] = [] + for node, state_name in inputs_to_state_of_scope.items(): + if len(node.users) > 0: attr_path = state_name[len(scope) :] state_attr = _recursive_getattr(module, attr_path) assert isinstance(state_attr, (torch.Tensor, torch.ScriptObject)) @@ -1061,13 +1106,20 @@ def _sink_params( new_node = graph.create_node("get_attr", ".".join(attr_path)) node.replace_all_uses_with(new_node, propagate_meta=True) + graph.erase_node(node) + inputs_removed.append(node.name) + if isinstance(module, InterpreterModule): module.finalize() + return {id(module): inputs_removed} + def _recursive_getattr(obj, attr_path): for attr in attr_path: + if not hasattr(obj, attr): + return None obj = getattr(obj, attr) return obj From b9c058c203ee38032594f898f27cd8404f113a63 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Mon, 3 Jun 2024 23:58:35 -0700 Subject: [PATCH 310/706] Retire torch.distributed.pipeline (#127354) Actually retiring module after deprecation warning for a while. The new supported module is: torch.distributed.pipelining. Please migrate. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127354 Approved by: https://github.com/wconstab --- .lintrunner.toml | 24 - .../distributed/pipeline/benchmark_dataset.py | 58 -- benchmarks/distributed/pipeline/pipe.py | 296 ------ docs/source/conf.py | 85 -- docs/source/distributed.rst | 19 - docs/source/index.rst | 1 - docs/source/pipeline.rst | 85 -- test/allowlist_for_publicAPI.json | 28 - test/distributed/pipeline/sync/LICENSE | 27 - test/distributed/pipeline/sync/__init__.py | 8 - test/distributed/pipeline/sync/conftest.py | 61 -- .../pipeline/sync/skip/__init__.py | 6 - .../pipeline/sync/skip/test_api.py | 52 -- .../pipeline/sync/skip/test_gpipe.py | 126 --- .../sync/skip/test_inspect_skip_layout.py | 118 --- .../pipeline/sync/skip/test_leak.py | 136 --- .../pipeline/sync/skip/test_portal.py | 163 ---- .../pipeline/sync/skip/test_stash_pop.py | 144 --- .../pipeline/sync/skip/test_tracker.py | 145 --- .../sync/skip/test_verify_skippables.py | 165 ---- .../distributed/pipeline/sync/test_balance.py | 240 ----- test/distributed/pipeline/sync/test_bugs.py | 146 --- .../pipeline/sync/test_checkpoint.py | 178 ---- test/distributed/pipeline/sync/test_copy.py | 85 -- .../pipeline/sync/test_deferred_batch_norm.py | 200 ---- .../pipeline/sync/test_dependency.py | 152 ---- .../distributed/pipeline/sync/test_inplace.py | 79 -- .../pipeline/sync/test_microbatch.py | 148 --- test/distributed/pipeline/sync/test_phony.py | 57 -- test/distributed/pipeline/sync/test_pipe.py | 858 ------------------ .../pipeline/sync/test_pipeline.py | 36 - test/distributed/pipeline/sync/test_stream.py | 198 ---- .../pipeline/sync/test_transparency.py | 55 -- test/distributed/pipeline/sync/test_worker.py | 118 --- test/test_public_bindings.py | 2 - test/test_testing.py | 1 - torch/distributed/pipeline/__init__.py | 13 - torch/distributed/pipeline/sync/LICENSE | 27 - torch/distributed/pipeline/sync/__init__.py | 12 - .../pipeline/sync/_balance/__init__.py | 164 ---- .../pipeline/sync/_balance/blockpartition.py | 95 -- .../pipeline/sync/_balance/profile.py | 116 --- .../pipeline/sync/_balance/py.typed | 6 - torch/distributed/pipeline/sync/batchnorm.py | 159 ---- torch/distributed/pipeline/sync/checkpoint.py | 364 -------- torch/distributed/pipeline/sync/copy.py | 108 --- torch/distributed/pipeline/sync/dependency.py | 54 -- torch/distributed/pipeline/sync/microbatch.py | 234 ----- torch/distributed/pipeline/sync/phony.py | 50 - torch/distributed/pipeline/sync/pipe.py | 490 ---------- torch/distributed/pipeline/sync/pipeline.py | 255 ------ torch/distributed/pipeline/sync/py.typed | 6 - .../pipeline/sync/skip/__init__.py | 11 - .../distributed/pipeline/sync/skip/layout.py | 92 -- .../pipeline/sync/skip/namespace.py | 50 - .../distributed/pipeline/sync/skip/portal.py | 231 ----- .../pipeline/sync/skip/skippable.py | 431 --------- .../distributed/pipeline/sync/skip/tracker.py | 180 ---- torch/distributed/pipeline/sync/stream.py | 120 --- torch/distributed/pipeline/sync/utils.py | 38 - torch/distributed/pipeline/sync/worker.py | 132 --- .../distributed/pipe_with_ddp_test.py | 149 --- .../distributed/pipeline/__init__.py | 0 .../_internal/distributed/rpc_utils.py | 4 - 64 files changed, 7891 deletions(-) delete mode 100644 benchmarks/distributed/pipeline/benchmark_dataset.py delete mode 100644 benchmarks/distributed/pipeline/pipe.py delete mode 100644 docs/source/pipeline.rst delete mode 100644 test/distributed/pipeline/sync/LICENSE delete mode 100644 test/distributed/pipeline/sync/__init__.py delete mode 100644 test/distributed/pipeline/sync/conftest.py delete mode 100644 test/distributed/pipeline/sync/skip/__init__.py delete mode 100644 test/distributed/pipeline/sync/skip/test_api.py delete mode 100644 test/distributed/pipeline/sync/skip/test_gpipe.py delete mode 100644 test/distributed/pipeline/sync/skip/test_inspect_skip_layout.py delete mode 100644 test/distributed/pipeline/sync/skip/test_leak.py delete mode 100644 test/distributed/pipeline/sync/skip/test_portal.py delete mode 100644 test/distributed/pipeline/sync/skip/test_stash_pop.py delete mode 100644 test/distributed/pipeline/sync/skip/test_tracker.py delete mode 100644 test/distributed/pipeline/sync/skip/test_verify_skippables.py delete mode 100644 test/distributed/pipeline/sync/test_balance.py delete mode 100644 test/distributed/pipeline/sync/test_bugs.py delete mode 100644 test/distributed/pipeline/sync/test_checkpoint.py delete mode 100644 test/distributed/pipeline/sync/test_copy.py delete mode 100644 test/distributed/pipeline/sync/test_deferred_batch_norm.py delete mode 100644 test/distributed/pipeline/sync/test_dependency.py delete mode 100644 test/distributed/pipeline/sync/test_inplace.py delete mode 100644 test/distributed/pipeline/sync/test_microbatch.py delete mode 100644 test/distributed/pipeline/sync/test_phony.py delete mode 100644 test/distributed/pipeline/sync/test_pipe.py delete mode 100644 test/distributed/pipeline/sync/test_pipeline.py delete mode 100644 test/distributed/pipeline/sync/test_stream.py delete mode 100644 test/distributed/pipeline/sync/test_transparency.py delete mode 100644 test/distributed/pipeline/sync/test_worker.py delete mode 100644 torch/distributed/pipeline/__init__.py delete mode 100644 torch/distributed/pipeline/sync/LICENSE delete mode 100644 torch/distributed/pipeline/sync/__init__.py delete mode 100644 torch/distributed/pipeline/sync/_balance/__init__.py delete mode 100644 torch/distributed/pipeline/sync/_balance/blockpartition.py delete mode 100644 torch/distributed/pipeline/sync/_balance/profile.py delete mode 100644 torch/distributed/pipeline/sync/_balance/py.typed delete mode 100644 torch/distributed/pipeline/sync/batchnorm.py delete mode 100644 torch/distributed/pipeline/sync/checkpoint.py delete mode 100644 torch/distributed/pipeline/sync/copy.py delete mode 100644 torch/distributed/pipeline/sync/dependency.py delete mode 100644 torch/distributed/pipeline/sync/microbatch.py delete mode 100644 torch/distributed/pipeline/sync/phony.py delete mode 100644 torch/distributed/pipeline/sync/pipe.py delete mode 100644 torch/distributed/pipeline/sync/pipeline.py delete mode 100644 torch/distributed/pipeline/sync/py.typed delete mode 100644 torch/distributed/pipeline/sync/skip/__init__.py delete mode 100644 torch/distributed/pipeline/sync/skip/layout.py delete mode 100644 torch/distributed/pipeline/sync/skip/namespace.py delete mode 100644 torch/distributed/pipeline/sync/skip/portal.py delete mode 100644 torch/distributed/pipeline/sync/skip/skippable.py delete mode 100644 torch/distributed/pipeline/sync/skip/tracker.py delete mode 100644 torch/distributed/pipeline/sync/stream.py delete mode 100644 torch/distributed/pipeline/sync/utils.py delete mode 100644 torch/distributed/pipeline/sync/worker.py delete mode 100644 torch/testing/_internal/distributed/pipe_with_ddp_test.py delete mode 100644 torch/testing/_internal/distributed/pipeline/__init__.py diff --git a/.lintrunner.toml b/.lintrunner.toml index ce69059d1ead..abf8ed9e28dd 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1536,28 +1536,6 @@ exclude_patterns = [ 'torch/distributed/optim/post_localSGD_optimizer.py', 'torch/distributed/optim/utils.py', 'torch/distributed/optim/zero_redundancy_optimizer.py', - 'torch/distributed/pipeline/__init__.py', - 'torch/distributed/pipeline/sync/__init__.py', - 'torch/distributed/pipeline/sync/_balance/__init__.py', - 'torch/distributed/pipeline/sync/_balance/blockpartition.py', - 'torch/distributed/pipeline/sync/_balance/profile.py', - 'torch/distributed/pipeline/sync/batchnorm.py', - 'torch/distributed/pipeline/sync/checkpoint.py', - 'torch/distributed/pipeline/sync/copy.py', - 'torch/distributed/pipeline/sync/dependency.py', - 'torch/distributed/pipeline/sync/microbatch.py', - 'torch/distributed/pipeline/sync/phony.py', - 'torch/distributed/pipeline/sync/pipe.py', - 'torch/distributed/pipeline/sync/pipeline.py', - 'torch/distributed/pipeline/sync/skip/__init__.py', - 'torch/distributed/pipeline/sync/skip/layout.py', - 'torch/distributed/pipeline/sync/skip/namespace.py', - 'torch/distributed/pipeline/sync/skip/portal.py', - 'torch/distributed/pipeline/sync/skip/skippable.py', - 'torch/distributed/pipeline/sync/skip/tracker.py', - 'torch/distributed/pipeline/sync/stream.py', - 'torch/distributed/pipeline/sync/utils.py', - 'torch/distributed/pipeline/sync/worker.py', 'torch/distributed/remote_device.py', 'torch/distributed/rendezvous.py', 'torch/distributed/rpc/__init__.py', @@ -1851,8 +1829,6 @@ exclude_patterns = [ 'torch/testing/_internal/distributed/nn/__init__.py', 'torch/testing/_internal/distributed/nn/api/__init__.py', 'torch/testing/_internal/distributed/nn/api/remote_module_test.py', - 'torch/testing/_internal/distributed/pipe_with_ddp_test.py', - 'torch/testing/_internal/distributed/pipeline/__init__.py', 'torch/testing/_internal/distributed/rpc/__init__.py', 'torch/testing/_internal/distributed/rpc/dist_autograd_test.py', 'torch/testing/_internal/distributed/rpc/dist_optimizer_test.py', diff --git a/benchmarks/distributed/pipeline/benchmark_dataset.py b/benchmarks/distributed/pipeline/benchmark_dataset.py deleted file mode 100644 index 3cd22e9a468d..000000000000 --- a/benchmarks/distributed/pipeline/benchmark_dataset.py +++ /dev/null @@ -1,58 +0,0 @@ -import torch -from torch.utils.data import Dataset - - -def collate_sentences_lm(samples): - if len(samples) == 0: - return {} - - id = torch.LongTensor([s["id"] for s in samples]) - src_tokens = torch.stack([s["source"] for s in samples], 0) - tgt_tokens = torch.stack([s["target"] for s in samples], 0) - ntokens = len(samples) * len(samples[0]["target"]) - src_lengths = torch.LongTensor([len(samples[0]["source"])] * len(samples)) - - batch = { - "id": id, - "nsentences": len(samples), - "ntokens": ntokens, - "input": src_tokens, - "target": tgt_tokens, - } - return batch - - -class BenchmarkLMDataset(Dataset): - """ - Dataset to benchmark a translation like seq2seq task. - Args: - vocab_size (int, optional): size of the vocabulary (default 10000). - max_source_positions (int, optional): max number of tokens in the - source sentence (default: 1024). - total_samples (int, optional): the total number of rows in the - dataset (default: 10000). - """ - - def __init__( - self, - vocab_size=10000, - max_source_positions=1024, - total_samples=10000, - ): - self.vocab_size = vocab_size - self.max_source_positions = max_source_positions - self.total_samples = total_samples - self.sizes = [self.max_source_positions] * self.total_samples - - def __getitem__(self, index): - length = self.sizes[index] - source = torch.randint(1, self.vocab_size, (length,)) - target = source.clone() - return { - "id": index, - "source": source, - "target": target, - } - - def __len__(self): - return self.total_samples diff --git a/benchmarks/distributed/pipeline/pipe.py b/benchmarks/distributed/pipeline/pipe.py deleted file mode 100644 index c465c2488565..000000000000 --- a/benchmarks/distributed/pipeline/pipe.py +++ /dev/null @@ -1,296 +0,0 @@ -import argparse -import math -import os -import time - -from benchmark_dataset import BenchmarkLMDataset, collate_sentences_lm - -import torch -import torch.nn as nn -from torch.distributed import rpc - -from torch.distributed.pipeline.sync import Pipe -from torch.distributed.pipeline.sync.utils import partition_model -from torch.optim import Adam -from torch.utils.data import DataLoader - - -def sizeof_fmt(num, suffix="B"): - for unit in ["", "Ki", "Mi", "Gi", "Ti"]: - if abs(num) < 1024.0: - return f"{num:3.2f}{unit}B" - num /= 1024.0 - - -def init_random_seed(seed: int): - import numpy - - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - numpy.random.seed(seed) - - -iteration_count = 0 - - -class EmbeddingLayer(nn.Embedding): - def __init__(self, ntoken, ninp, initrange): - super().__init__(ntoken, ninp) - self.ninp = ninp - nn.init.uniform_(self.weight, -initrange, initrange) - - def forward(self, src): - return super().forward(src) * math.sqrt(self.ninp) - - -class PositionalEncodingLayer(nn.Module): - def __init__(self, d_model, dropout=0.1, max_len=5000): - super().__init__() - self.dropout = nn.Dropout(p=dropout) - - pe = torch.zeros(max_len, d_model) - position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) - div_term = torch.exp( - torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model) - ) - pe[:, 0::2] = torch.sin(position * div_term) - pe[:, 1::2] = torch.cos(position * div_term) - pe = pe.unsqueeze(0).transpose(0, 1) - self.register_buffer("pe", pe) - - def forward(self, x): - x = x + self.pe[: x.size(0), :] - return self.dropout(x) - - -class TransformerDecoderLayer(nn.TransformerEncoderLayer): - """Though this class inherits from torch.nn.TransformerEncoderLayer, - it functions as a decoder in this model""" - - def __init__(self, ninp, nhead, nhid, droupout): - super().__init__(ninp, nhead, nhid, droupout) - self.src_mask = None - - def forward(self, src): - global iteration_count - iteration_count += 1 - - if self.src_mask is None or self.src_mask.size(0) != len(src): - device = src.device - mask = nn.Transformer.generate_square_subsequent_mask(len(src)).to(device) - self.src_mask = mask - - return super().forward(src, self.src_mask) - - -class LinearLayer(nn.Linear): - def __init__(self, ninp, ntoken, initrange): - super().__init__(ninp, ntoken) - nn.init.zeros_(self.bias) - nn.init.uniform_(self.weight, -initrange, initrange) - - -class TransformerLMSequential(nn.Sequential): - """A small language model based on the design of GPT-2 using nn.Sequential - for compatibility with Pipe""" - - def __init__(self, ntokens, ninp, nhead, nhid, dropout, initrange, ndecoder): - layers = [ - EmbeddingLayer(ntokens, ninp, initrange), - PositionalEncodingLayer(ninp, dropout), - ] - for _ in range(ndecoder): - layers.append(TransformerDecoderLayer(ninp, nhead, nhid, dropout)) - - layers.append(LinearLayer(ninp, ntokens, initrange)) - super().__init__(*layers) - - -def make_model(args, device, ntokens): - ninp = 2048 # embedding dimension - nhid = ( - 2048 # the dimension of the feedforward network model in nn.TransformerEncoder - ) - nhead = 32 # the number of heads in the multiheadattention models - dropout = 0 - initrange = 0.1 - ndecoder = args.num_decoder_layers - - model = TransformerLMSequential( - ntokens, ninp, nhead, nhid, dropout, initrange, ndecoder - ).to(device) - - criterion = nn.CrossEntropyLoss() - lr = 0.01 # learning rate - - def make_adam(model): - return Adam(model.parameters(), lr=lr) - - optimizer = make_adam - - return model, criterion, optimizer - - -def train(lm_dataloader, model, criterion, optimizer, vocab_size, args): - model.train() - - vocab_size = 10000 - total_loss = 0.0 - start_time = time.time() - word_counter = 0 - - optimizer = optimizer(model) - - def get_first_device(model): - if model.devices: - return model.devices[0] - else: - return torch.cuda.current_device() - - def get_last_device(model): - if model.devices: - return model.devices[-1] - else: - return torch.cuda.current_device() - - print( - f"Number of parameters for model: {sum(p.numel() for p in model.parameters())}" - ) - for i, batch in enumerate(lm_dataloader): - bi = batch["input"] - if args.max_batch and i > args.max_batch: - break - optimizer.zero_grad() - try: - tmp = batch["input"].to(get_first_device(model)) - output = model(tmp).local_value() - except Exception as e: - raise RuntimeError( - f"training failed on {torch.distributed.get_rank()}" - ) from e - - target = batch["target"].to(get_last_device(model)) - output = output.to(target.device) - - loss = criterion(output.view(-1, vocab_size), target.view(-1)) - loss.backward() - del target - del output - - torch.nn.utils.clip_grad_value_(model.parameters(), 0.05) - optimizer.step() - - total_loss += loss.item() - log_interval = 1 - word_counter += batch["ntokens"] - if i % log_interval == 0 and i > 0: - cur_loss = total_loss / log_interval - elapsed = time.time() - start_time - print( - f"| batch {i:5d} | wps {word_counter / elapsed:5.2f} | loss {cur_loss:5.2f} | ppl {math.exp(cur_loss):8.2f}" - ) - word_counter = 0 - total_loss = 0 - start_time = time.time() - - print("Peak memory usage for GPUs: ", end="") - for i in range(len(model.devices)): - print( - f"cuda:{i}: {sizeof_fmt(torch.cuda.memory_stats(i)['allocated_bytes.all.peak'])}, ", - end="", - ) - print() - - -def generate_balance(num_devices, num_layers): - balance = [] - layers_assigned = 0 - for i in range(num_devices): - x = (num_layers - layers_assigned) / (num_devices - i) - if x.is_integer(): - balance.append(int(x)) - layers_assigned += x - else: - balance.append(math.ceil(x)) - layers_assigned += math.ceil(x) - return balance - - -def make_model_and_data(args, device): - device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - vocab_size = 10000 - model, criterion, optimizer = make_model(args, device, vocab_size) - lm_dataset = BenchmarkLMDataset() - lm_dataloader = DataLoader( - lm_dataset, - batch_size=args.batch_size, - shuffle=True, - num_workers=0, - collate_fn=collate_sentences_lm, - ) - return { - "model": model, - "criterion": criterion, - "optimizer": optimizer, - "data": lm_dataloader, - "vocab_size": vocab_size, - } - - -def bench_single_process(args): - os.environ.update({"MASTER_ADDR": args.host}) - os.environ.update({"MASTER_PORT": "10638"}) - - rpc.init_rpc( - "worker", - rank=0, - world_size=1, - ) - - num_devices = torch.cuda.device_count() if torch.cuda.is_available() else 1 - num_devices = min(args.num_devices, num_devices) - assert num_devices > 0 - init_random_seed(0) - device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - - blob = make_model_and_data(args, None) - model = blob["model"] - - balance = generate_balance(num_devices, len(model)) - model = partition_model(model, balance) - p = Pipe(model, chunks=args.chunks, checkpoint=args.checkpoint) - del model - del blob["model"] - - train( - blob["data"], p, blob["criterion"], blob["optimizer"], blob["vocab_size"], args - ) - - -parser = argparse.ArgumentParser(description="benchmark") -parser.add_argument("--host", "-o", type=str, default="localhost", help="hostname") -parser.add_argument( - "--chunks", type=int, default=4, help="number of microbatches per batch" -) -parser.add_argument("--batch-size", type=int, default=8, help="size of a batch") -parser.add_argument("--max-batch", type=int, default=10, help="Max number of batches") -parser.add_argument( - "--num-decoder-layers", - type=int, - default=10, - help="Number of decoder layers in the model", -) -parser.add_argument( - "--checkpoint", - default="except_last", - choices=["always", "except_last", "never"], - help="Checkpointing strategy for pipe", -) -parser.add_argument( - "--num-devices", type=int, default=4, help="Number of GPU devices to use" -) - -if __name__ == "__main__": - args = parser.parse_args() - print(f"Running benchmark with args: {args}") - bench_single_process(args) diff --git a/docs/source/conf.py b/docs/source/conf.py index fe548737b313..4f73c111cb23 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -606,45 +606,6 @@ # torch.distributed.optim.utils "as_functional_optim", "register_functional_optim", - # torch.distributed.pipeline.sync.checkpoint - "checkpoint", - "enable_checkpointing", - "enable_recomputing", - "is_checkpointing", - "is_recomputing", - "restore_rng_states", - "save_rng_states", - # torch.distributed.pipeline.sync.dependency - "fork", - "join", - # torch.distributed.pipeline.sync.microbatch - "check", - "gather", - "scatter", - # torch.distributed.pipeline.sync.phony - "get_phony", - # torch.distributed.pipeline.sync.skip.layout - "inspect_skip_layout", - # torch.distributed.pipeline.sync.skip.tracker - "current_skip_tracker", - "use_skip_tracker", - # torch.distributed.pipeline.sync.stream - "as_cuda", - "current_stream", - "default_stream", - "get_device", - "is_cuda", - "new_stream", - "record_stream", - "use_device", - "use_stream", - "wait_stream", - # torch.distributed.pipeline.sync.utils - "partition_model", - # torch.distributed.pipeline.sync.worker - "create_workers", - "spawn_workers", - "worker", # torch.distributed.rendezvous "register_rendezvous_handler", "rendezvous", @@ -2648,52 +2609,6 @@ "PostLocalSGDOptimizer", # torch.distributed.optim.zero_redundancy_optimizer "ZeroRedundancyOptimizer", - # torch.distributed.pipeline.sync.batchnorm - "DeferredBatchNorm", - # torch.distributed.pipeline.sync.checkpoint - "Checkpoint", - "Checkpointing", - "Context", - "Function", - "Recompute", - "ThreadLocal", - # torch.distributed.pipeline.sync.copy - "Context", - "Copy", - "Wait", - # torch.distributed.pipeline.sync.dependency - "Fork", - "Join", - # torch.distributed.pipeline.sync.microbatch - "Batch", - "NoChunk", - # torch.distributed.pipeline.sync.pipe - "BalanceError", - "Pipe", - "PipeSequential", - "WithDevice", - # torch.distributed.pipeline.sync.pipeline - "Pipeline", - # torch.distributed.pipeline.sync.skip.layout - "SkipLayout", - # torch.distributed.pipeline.sync.skip.namespace - "Namespace", - # torch.distributed.pipeline.sync.skip.portal - "Context", - "Portal", - "PortalBlue", - "PortalCopy", - "PortalOrange", - # torch.distributed.pipeline.sync.skip.skippable - "Skippable", - # torch.distributed.pipeline.sync.skip.tracker - "SkipTracker", - "SkipTrackerThroughPotals", - "ThreadLocal", - # torch.distributed.pipeline.sync.stream - "CPUStreamType", - # torch.distributed.pipeline.sync.worker - "Task", # torch.distributed.rpc.api "AllGatherStates", "RRef", diff --git a/docs/source/distributed.rst b/docs/source/distributed.rst index 0b091d567031..f4c73b9381e5 100644 --- a/docs/source/distributed.rst +++ b/docs/source/distributed.rst @@ -876,9 +876,6 @@ If you are running single node training, it may be convenient to interactively b .. py:module:: torch.distributed.nn.api .. py:module:: torch.distributed.nn.jit .. py:module:: torch.distributed.nn.jit.templates -.. py:module:: torch.distributed.pipeline -.. py:module:: torch.distributed.pipeline.sync -.. py:module:: torch.distributed.pipeline.sync.skip .. py:module:: torch.distributed.tensor .. py:module:: torch.distributed.algorithms.ddp_comm_hooks.ddp_zero_hook .. py:module:: torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks @@ -964,22 +961,6 @@ If you are running single node training, it may be convenient to interactively b .. py:module:: torch.distributed.optim.post_localSGD_optimizer .. py:module:: torch.distributed.optim.utils .. py:module:: torch.distributed.optim.zero_redundancy_optimizer -.. py:module:: torch.distributed.pipeline.sync.batchnorm -.. py:module:: torch.distributed.pipeline.sync.checkpoint -.. py:module:: torch.distributed.pipeline.sync.copy -.. py:module:: torch.distributed.pipeline.sync.dependency -.. py:module:: torch.distributed.pipeline.sync.microbatch -.. py:module:: torch.distributed.pipeline.sync.phony -.. py:module:: torch.distributed.pipeline.sync.pipe -.. py:module:: torch.distributed.pipeline.sync.pipeline -.. py:module:: torch.distributed.pipeline.sync.skip.layout -.. py:module:: torch.distributed.pipeline.sync.skip.namespace -.. py:module:: torch.distributed.pipeline.sync.skip.portal -.. py:module:: torch.distributed.pipeline.sync.skip.skippable -.. py:module:: torch.distributed.pipeline.sync.skip.tracker -.. py:module:: torch.distributed.pipeline.sync.stream -.. py:module:: torch.distributed.pipeline.sync.utils -.. py:module:: torch.distributed.pipeline.sync.worker .. py:module:: torch.distributed.remote_device .. py:module:: torch.distributed.rendezvous .. py:module:: torch.distributed.rpc.api diff --git a/docs/source/index.rst b/docs/source/index.rst index ea704f20c3af..dcaadcbb63ed 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -103,7 +103,6 @@ Features described in this documentation are classified by release status: optim complex_numbers ddp_comm_hooks - pipeline quantization rpc torch.random diff --git a/docs/source/pipeline.rst b/docs/source/pipeline.rst deleted file mode 100644 index 94d730ee223d..000000000000 --- a/docs/source/pipeline.rst +++ /dev/null @@ -1,85 +0,0 @@ -.. _pipeline-parallelism: - -Pipeline Parallelism -==================== - -Pipeline parallelism was original introduced in the -`Gpipe `__ paper and is an efficient -technique to train large models on multiple GPUs. - -.. warning :: - torch.distributed.pipeline is deprecated, so is this document. For - up-to-date pipeline parallel implementation, please refer to the - `PiPPy `__ library under the PyTorch - organization (Pipeline Parallelism for PyTorch). - -Model Parallelism using multiple GPUs -------------------------------------- - -Typically for large models which don't fit on a single GPU, model parallelism -is employed where certain parts of the model are placed on different GPUs. -Although, if this is done naively for sequential models, the training process -suffers from GPU under utilization since only one GPU is active at one time as -shown in the figure below: - -.. figure:: _static/img/pipeline_parallelism/no_pipe.png - - The figure represents a model with 4 layers placed on 4 different GPUs - (vertical axis). The horizontal axis represents training this model through - time demonstrating that only 1 GPU is utilized at a time - (`image source `__). - -Pipelined Execution -------------------- - -To alleviate this problem, pipeline parallelism splits the input minibatch into -multiple microbatches and pipelines the execution of these microbatches across -multiple GPUs. This is outlined in the figure below: - -.. figure:: _static/img/pipeline_parallelism/pipe.png - - The figure represents a model with 4 layers placed on 4 different GPUs - (vertical axis). The horizontal axis represents training this model through - time demonstrating that the GPUs are utilized much more efficiently. - However, there still exists a bubble (as demonstrated in the figure) where - certain GPUs are not utilized. - (`image source `__). - -Pipe APIs in PyTorch --------------------- -.. autoclass:: torch.distributed.pipeline.sync.Pipe - :members: forward - -Skip connections -^^^^^^^^^^^^^^^^ - -Certain models like `ResNeXt `__ -are not completely sequential and have skip connections between layers. -Naively implementing as part of pipeline parallelism would imply that -we need to copy outputs for certain layers through multiple GPUs till -we eventually reach the GPU where the layer for the skip connection resides. -To avoid this copy overhead, we provide APIs below to stash and pop Tensors -in different layers of the model. - -.. autofunction:: torch.distributed.pipeline.sync.skip.skippable.skippable -.. autoclass:: torch.distributed.pipeline.sync.skip.skippable.stash -.. autoclass:: torch.distributed.pipeline.sync.skip.skippable.pop -.. autofunction:: torch.distributed.pipeline.sync.skip.skippable.verify_skippables - -Tutorials ---------- - -The following tutorials give a good overview of how to use the -:class:`~torch.distributed.pipeline.sync.Pipe` API to train your models with the -rest of the components that PyTorch provides: - -- `Training Transformer models using Pipeline Parallelism `__ -- `Training Transformer models using Distributed Data Parallel and Pipeline Parallelism `__ - -Acknowledgements ----------------- - -The implementation for pipeline parallelism is based on `fairscale's pipe implementation `__ and -`torchgpipe `__. We would like to -thank both teams for their contributions and guidance towards bringing pipeline -parallelism into PyTorch. diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json index c3d3fe2f00ec..f7af925adb72 100644 --- a/test/allowlist_for_publicAPI.json +++ b/test/allowlist_for_publicAPI.json @@ -211,30 +211,6 @@ "torch.distributed.optim.utils": [ "Type" ], - "torch.distributed.pipeline.sync.pipe": [ - "Pipeline" - ], - "torch.distributed.pipeline.sync.skip.layout": [ - "SkipLayout", - "inspect_skip_layout" - ], - "torch.distributed.pipeline.sync.skip.portal": [ - "Context", - "Portal", - "PortalBlue", - "PortalCopy", - "PortalOrange" - ], - "torch.distributed.pipeline.sync.skip.skippable": [ - "Skippable" - ], - "torch.distributed.pipeline.sync.skip.tracker": [ - "SkipTracker", - "SkipTrackerThroughPotals", - "ThreadLocal", - "current_skip_tracker", - "use_skip_tracker" - ], "torch.distributed.remote_device": [ "Optional", "Union" @@ -1697,10 +1673,6 @@ "get_args_parser", "run" ], - "torch.distributed.pipeline.sync": [ - "NoChunk", - "WithDevice" - ], "torch.distributed.rpc.rref_proxy": [ "Future", "partial", diff --git a/test/distributed/pipeline/sync/LICENSE b/test/distributed/pipeline/sync/LICENSE deleted file mode 100644 index e52be240fdc9..000000000000 --- a/test/distributed/pipeline/sync/LICENSE +++ /dev/null @@ -1,27 +0,0 @@ -Copyright 2019-2020 Kakao Brain - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -1. Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - -2. Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - -3. Neither the name of the copyright holder nor the names of its - contributors may be used to endorse or promote products derived from this - software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE -LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -POSSIBILITY OF SUCH DAMAGE. diff --git a/test/distributed/pipeline/sync/__init__.py b/test/distributed/pipeline/sync/__init__.py deleted file mode 100644 index 94cd5bcb415e..000000000000 --- a/test/distributed/pipeline/sync/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -# tests/__init__.py makes pytest can import the application without custom sys.path or PYTHONPATH. -# See also: https://docs.pytest.org/en/latest/goodpractices.html diff --git a/test/distributed/pipeline/sync/conftest.py b/test/distributed/pipeline/sync/conftest.py deleted file mode 100644 index 4f2479b27b29..000000000000 --- a/test/distributed/pipeline/sync/conftest.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import tempfile - -import pytest - -import torch -import torch.distributed as dist - - -@pytest.fixture(autouse=True) -def manual_seed_zero(): - torch.manual_seed(0) - - -@pytest.fixture(scope="session") -def cuda_sleep(): - # Warm-up CUDA. - torch.empty(1, device="cuda") - - # From test/test_cuda.py in PyTorch. - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - torch.cuda._sleep(1000000) - end.record() - end.synchronize() - cycles_per_ms = 1000000 / start.elapsed_time(end) - - def cuda_sleep(seconds): - torch.cuda._sleep(int(seconds * cycles_per_ms * 1000)) - - return cuda_sleep - - -def pytest_report_header(): - return f"torch: {torch.__version__}" - - -@pytest.fixture -def setup_rpc(scope="session"): - file = tempfile.NamedTemporaryFile() - dist.rpc.init_rpc( - name="worker0", - rank=0, - world_size=1, - rpc_backend_options=dist.rpc.TensorPipeRpcBackendOptions( - init_method=f"file://{file.name}", - ), - ) - yield - dist.rpc.shutdown() - - -def pytest_ignore_collect(path, config): - "Skip this directory if distributed modules are not enabled." - return not dist.is_available() diff --git a/test/distributed/pipeline/sync/skip/__init__.py b/test/distributed/pipeline/sync/skip/__init__.py deleted file mode 100644 index ab03724cafbf..000000000000 --- a/test/distributed/pipeline/sync/skip/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. diff --git a/test/distributed/pipeline/sync/skip/test_api.py b/test/distributed/pipeline/sync/skip/test_api.py deleted file mode 100644 index be38d6d83dac..000000000000 --- a/test/distributed/pipeline/sync/skip/test_api.py +++ /dev/null @@ -1,52 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import copy - -from torch import nn - -from torch.distributed.pipeline.sync.skip import Namespace, skippable, stash -from torch.testing._internal.common_utils import run_tests - - -def test_namespace_difference(): - ns1 = Namespace() - ns2 = Namespace() - assert ns1 != ns2 - - -def test_namespace_copy(): - ns = Namespace() - assert copy.copy(ns) == ns - assert copy.copy(ns) is not ns - - -def test_skippable_repr(): - @skippable(stash=["hello"]) - class Hello(nn.Module): - def __init__(self): - super().__init__() - self.conv = nn.Conv2d(1, 1, 1) - - def forward(self, x): - yield stash("hello", x) - return self.conv(x) # noqa: B901 - - m = Hello() - assert ( - repr(m) - == """ -@skippable(Hello( - (conv): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1)) -)) -""".strip() - ) - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/skip/test_gpipe.py b/test/distributed/pipeline/sync/skip/test_gpipe.py deleted file mode 100644 index 4f433ab38941..000000000000 --- a/test/distributed/pipeline/sync/skip/test_gpipe.py +++ /dev/null @@ -1,126 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import pytest - -import torch -from torch import nn - -from torch.distributed.pipeline.sync import Pipe -from torch.distributed.pipeline.sync.skip import pop, skippable, stash -from torch.distributed.pipeline.sync.skip.portal import ( - PortalBlue, - PortalCopy, - PortalOrange, -) -from torch.distributed.pipeline.sync.utils import partition_model -from torch.testing._internal.common_utils import run_tests - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") -@pytest.mark.parametrize( - "balance", [[3], [1, 2], [2, 1], [1, 1, 1]], ids=["3", "1:2", "2:1", "1:1:1"] -) -@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) -def test_1to3(balance, checkpoint, setup_rpc): - if torch.cuda.device_count() < len(balance): - pytest.skip("at least %d cuda devices required" % len(balance)) - - @skippable(stash=["1to3"]) - class Layer1(nn.Module): - def __init__(self): - super().__init__() - self.conv = nn.Conv2d(3, 3, 1) - - def forward(self, input): - yield stash("1to3", input) - output = self.conv(input) - return output # noqa: B901 - - class Layer2(nn.Module): - def __init__(self): - super().__init__() - self.conv = nn.Conv2d(3, 3, 1) - - def forward(self, input): - output = self.conv(input) - return output - - @skippable(pop=["1to3"]) - class Layer3(nn.Module): - def __init__(self): - super().__init__() - self.conv = nn.Conv2d(3, 3, 1) - - def forward(self, input): - skip_1to3 = yield pop("1to3") - output = self.conv(input) + skip_1to3 - return output - - model = nn.Sequential(Layer1(), Layer2(), Layer3()) - model = partition_model(model, balance) - model = Pipe(model, chunks=3, checkpoint=checkpoint) - - in_device = model.devices[0] - out_device = model.devices[-1] - - input = torch.rand(30, 3, 224, 224, device=in_device, requires_grad=True) - output = model(input) - loss = output.local_value().mean() - loss.backward() - - assert torch.allclose( - output.local_value().norm(), torch.tensor(1039.0, device=out_device), atol=6e-1 - ) - assert torch.allclose( - input.grad.norm(), torch.tensor(0.0004533053, device=in_device) - ) - - -def test_none_skip(setup_rpc): - @skippable(stash=["none"]) - class Stash(nn.Module): - def forward(self, input): - yield stash("none", None) - return input # noqa: B901 - - @skippable(pop=["none"]) - class Pop(nn.Module): - def forward(self, input): - none = yield pop("none") - assert none is None - return input - - model = nn.Sequential(Stash(), Pop()) - model = Pipe(model, chunks=5) - - input = torch.rand(10, requires_grad=True) - output = model(input) - - def assert_grad_fn_is_not_portal(grad_fn, visited=None): - if visited is None: - visited = set() - if grad_fn in visited or grad_fn is None: - return - - assert not isinstance(grad_fn, PortalBlue._backward_cls) - assert not isinstance(grad_fn, PortalCopy._backward_cls) - assert not isinstance(grad_fn, PortalOrange._backward_cls) - - visited.add(grad_fn) - for next_grad_fn, _ in grad_fn.next_functions: - assert_grad_fn_is_not_portal(next_grad_fn, visited) - - assert_grad_fn_is_not_portal(output.local_value().grad_fn) - - output.local_value().sum().backward() - assert input.grad.mean().item() == 1 - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/skip/test_inspect_skip_layout.py b/test/distributed/pipeline/sync/skip/test_inspect_skip_layout.py deleted file mode 100644 index 4d542285cd5a..000000000000 --- a/test/distributed/pipeline/sync/skip/test_inspect_skip_layout.py +++ /dev/null @@ -1,118 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -from torch import nn - -from torch.distributed.pipeline.sync.skip import Namespace, pop, skippable, stash -from torch.distributed.pipeline.sync.skip.layout import inspect_skip_layout -from torch.testing._internal.common_utils import run_tests - - -class Pass(nn.Module): - def forward(self, input): - return input - - -@skippable(stash=["foo"]) -class StashFoo(nn.Module): - def forward(self, input): - yield stash("foo", input) - return input # noqa: B901 - - -@skippable(pop=["foo"]) -class PopFoo(nn.Module): - def forward(self, input): - foo = yield stash("foo") - return input + foo - - -@skippable(stash=["bar"]) -class StashBar(nn.Module): - def forward(self, input): - yield stash("bar", input) - return input # noqa: B901 - - -@skippable(pop=["bar"]) -class PopBar(nn.Module): - def forward(self, input): - bar = yield pop("bar") - return input + bar - - -def test_no_skippables(): - p1 = nn.Sequential(Pass()) - p2 = nn.Sequential(Pass()) - - layout = inspect_skip_layout([p1, p2]) - policy = [list(layout.copy_policy(i)) for i in range(2)] - - assert policy == [[], []] - - -def test_inner_partition(): - p1 = nn.Sequential(StashFoo(), PopFoo()) - p2 = nn.Sequential(Pass()) - - layout = inspect_skip_layout([p1, p2]) - policy = [list(layout.copy_policy(i)) for i in range(2)] - - assert policy == [[], []] - - -def test_adjoining_partitions(): - p1 = nn.Sequential(StashFoo()) - p2 = nn.Sequential(PopFoo()) - - layout = inspect_skip_layout([p1, p2]) - policy = [list(layout.copy_policy(i)) for i in range(2)] - - assert policy == [[], [(0, None, "foo")]] - - -def test_far_partitions(): - p1 = nn.Sequential(StashFoo()) - p2 = nn.Sequential(Pass()) - p3 = nn.Sequential(PopFoo()) - - layout = inspect_skip_layout([p1, p2, p3]) - policy = [list(layout.copy_policy(i)) for i in range(3)] - - assert policy == [[], [], [(0, None, "foo")]] - - -def test_pop_2_from_different_partitions(): - p1 = nn.Sequential(StashFoo()) - p2 = nn.Sequential(StashBar()) - p3 = nn.Sequential(PopBar(), PopFoo()) - - layout = inspect_skip_layout([p1, p2, p3]) - policy = [list(layout.copy_policy(i)) for i in range(3)] - - # p3 pops 'bar' before 'foo', but the plan is sorted by source partition index. - assert policy == [[], [], [(0, None, "foo"), (1, None, "bar")]] - - -def test_namespace(): - ns1 = Namespace() - ns2 = Namespace() - - p1 = nn.Sequential(StashFoo().isolate(ns1)) - p2 = nn.Sequential(StashFoo().isolate(ns2)) - p3 = nn.Sequential(PopFoo().isolate(ns2), PopFoo().isolate(ns1)) - - layout = inspect_skip_layout([p1, p2, p3]) - policy = [list(layout.copy_policy(i)) for i in range(3)] - - # p3 pops 'bar' before 'foo', but the plan is sorted by source partition index. - assert policy == [[], [], [(0, ns1, "foo"), (1, ns2, "foo")]] - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/skip/test_leak.py b/test/distributed/pipeline/sync/skip/test_leak.py deleted file mode 100644 index f4d1043e0549..000000000000 --- a/test/distributed/pipeline/sync/skip/test_leak.py +++ /dev/null @@ -1,136 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import pytest - -import torch -from torch import nn - -from torch.distributed.pipeline.sync import is_checkpointing, is_recomputing, Pipe -from torch.distributed.pipeline.sync.skip import pop, skippable, stash -from torch.distributed.pipeline.sync.skip.tracker import current_skip_tracker -from torch.testing._internal.common_utils import run_tests - - -@skippable(stash=["skip"]) -class Stash(nn.Module): - def forward(self, input): - yield stash("skip", input) - return input # noqa: B901 - - -@skippable(pop=["skip"]) -class Pop(nn.Module): - def forward(self, input): - skip = yield pop("skip") - return input + skip - - -@pytest.mark.parametrize("train", [True, False], ids=["train", "eval"]) -@pytest.mark.parametrize("checkpoint", ["always", "except_last", "never"]) -def test_delete_portal_tensor(train, checkpoint, setup_rpc): - # Without checkpointing: - # +- Stash --+ +--- Pop ----+ - - - layers - # | 2,blue,1 |--| 1,orange,0 | - - - tensor_life and portal function - # +----------+ +------------+ - # - # With checkpointing: - # +- Stash --+ +--- Pop ----+ +--- Pop'----+ +- Stash'--+ - # | 3,blue,2 |--| 2,orange,1 |--| 1,orange,0 |--| 1,blue,0 | - # +----------+ +------------+ +------------+ +----------+ - - def portal_tensor_life_is(tensor_life, skip_tracker=None): - if skip_tracker is None: - skip_tracker = current_skip_tracker() - - # Get the current portal. - portal = next(iter(skip_tracker.portals.values())) - - if tensor_life == 0: - return portal.tensor_life == 0 and portal.tensor is None - else: - return portal.tensor_life == tensor_life and portal.tensor is not None - - # Check the portal tensor after 'Stash'. - stash_ = Stash() - - @stash_.register_forward_hook - def check_portal_tensor_after_stash(*_): - if is_checkpointing(): - assert portal_tensor_life_is(2) - elif is_recomputing(): - assert portal_tensor_life_is(0) - else: - assert portal_tensor_life_is(1) - - pop_ = Pop() - - @pop_.register_forward_hook - def check_portal_tensor_after_pop(*_): - if is_checkpointing(): - assert portal_tensor_life_is(1) - elif is_recomputing(): - assert portal_tensor_life_is(0) - else: - assert portal_tensor_life_is(0) - - class NoPortalTensorAtBackward(nn.Module): - class F(torch.autograd.Function): - @staticmethod - def forward(ctx, input): - ctx.skip_tracker = current_skip_tracker() - return input.detach() - - @staticmethod - def backward(ctx, grad): - assert portal_tensor_life_is(0, skip_tracker=ctx.skip_tracker) - return grad - - def forward(self, input): - return self.F.apply(input) - - model = nn.Sequential(NoPortalTensorAtBackward(), stash_, pop_) - model = Pipe(model, chunks=2, checkpoint=checkpoint) - - input = torch.rand(10, requires_grad=True) - - if train: - model.train() - output = model(input).local_value() - output.norm().backward() - else: - model.eval() - with torch.no_grad(): - model(input) - - -@pytest.mark.parametrize("train", [True, False], ids=["train", "eval"]) -def test_no_portal_without_pipe(train, monkeypatch, setup_rpc): - def deny(*args, **kwargs): - raise AssertionError("tried to create Portal without Pipe") - - monkeypatch.setattr( - "torch.distributed.pipeline.sync.skip.portal.Portal.__init__", deny - ) - - model = nn.Sequential(Stash(), Pop()) - - input = torch.rand(10, requires_grad=True) - - if train: - model.train() - output = model(input) - output.norm().backward() - else: - model.eval() - with torch.no_grad(): - model(input) - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/skip/test_portal.py b/test/distributed/pipeline/sync/skip/test_portal.py deleted file mode 100644 index 5ad180b6f9c8..000000000000 --- a/test/distributed/pipeline/sync/skip/test_portal.py +++ /dev/null @@ -1,163 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import pytest - -import torch - -from torch.distributed.pipeline.sync.dependency import fork, join -from torch.distributed.pipeline.sync.skip.portal import Portal -from torch.distributed.pipeline.sync.stream import default_stream -from torch.testing._internal.common_utils import run_tests - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") -def test_copy_returns_on_next_device(): - portal = Portal(torch.rand(1), tensor_life=1) - - prev_stream = default_stream(torch.device("cpu")) - next_stream = default_stream(torch.device("cuda")) - - phony = torch.zeros(0, requires_grad=True) - assert phony.device.type == "cpu" - - phony = portal.copy(prev_stream, next_stream, phony) - assert phony.device.type == "cuda" - - -def test_blue_orange(): - tensor1 = torch.rand(1, requires_grad=True) - tensor2 = torch.rand(1, requires_grad=True) - - # Same with: output = tensor1*2 + tensor2 - # - # +----------------------+ - # | | - # tensor2 -- PortalBlue -+ +- PortalOrange -+ - # | | | - # tensor1 ------------ Join -- Fork --- Mul --- Add -- output - # - main = tensor1 - portal = Portal(tensor2, tensor_life=2) - phony = portal.blue() - main = join(main, phony) - main, phony = fork(main) - sub = portal.orange(phony) - output = main * 2 + sub - - output.backward() - - assert torch.allclose(tensor1.grad, torch.tensor([2.0])) - assert torch.allclose(tensor2.grad, torch.tensor([1.0])) - - -def test_blue_orange_not_requires_grad(): - tensor1 = torch.rand(1, requires_grad=True) - tensor2 = torch.rand(1) - - # Same with: output = tensor1*2 + tensor2 - # - # +----------------------+ - # | | - # tensor2 -- PortalBlue -+ +- PortalOrange -+ - # | | | - # tensor1 ------------ Join -- Fork --- Mul --- Add -- output - # - main = tensor1 - portal = Portal(tensor2, tensor_life=2) - phony = portal.blue() - main = join(main, phony) - main, phony = fork(main) - sub = portal.orange(phony) - output = main * 2 + sub - - output.backward() - - assert torch.allclose(tensor1.grad, torch.tensor([2.0])) - assert tensor2.grad is None - - -def test_use_grad(): - tensor = torch.rand(1, requires_grad=True) - portal = Portal(tensor, tensor_life=1) - - portal.put_grad(tensor) - assert portal.use_grad() is tensor - - # Gradient in a portal is ephemeral. - with pytest.raises(RuntimeError): - portal.use_grad() - - -class TestTensorLife: - @pytest.fixture - def new_portal(self): - portal = None - - def new_portal(tensor_life): - nonlocal portal - tensor = torch.rand(1, requires_grad=True) - portal = Portal(tensor, tensor_life) - return portal, tensor - - yield new_portal - - # A test using this fixture must exhaust the tensor in the portal. - with pytest.raises(RuntimeError): - portal.check_tensor_life() - assert portal.tensor is None - - def test_tensor_life_0(self, new_portal): - portal, tensor = new_portal(0) - assert portal.tensor is None - - def test_tensor_life_1(self, new_portal): - portal, tensor = new_portal(1) - assert portal.tensor is tensor - - portal.blue() - - def test_tensor_life_2(self, new_portal): - portal, tensor = new_portal(2) - assert portal.tensor is tensor - - phony = portal.blue() - assert portal.orange(phony).data_ptr() == tensor.data_ptr() - - def test_tensor_life_3(self, new_portal): - portal, tensor = new_portal(3) - assert portal.tensor is tensor - - phony = portal.blue() - assert portal.orange(phony).data_ptr() == tensor.data_ptr() - assert portal.orange(phony).data_ptr() == tensor.data_ptr() - - def test_tensor_life_4(self, new_portal): - portal, tensor = new_portal(4) - assert portal.tensor is tensor - - phony = portal.blue() - assert portal.orange(phony).data_ptr() == tensor.data_ptr() - assert portal.orange(phony).data_ptr() == tensor.data_ptr() - portal.blue() - - def test_tensor_life_3_plus_1(self, new_portal): - portal, tensor = new_portal(3) - assert portal.tensor is tensor - - phony = portal.blue() - assert portal.orange(phony).data_ptr() == tensor.data_ptr() - assert portal.orange(phony).data_ptr() == tensor.data_ptr() - - another_tensor = torch.rand(1, requires_grad=True) - portal.put_tensor(another_tensor, tensor_life=1) - portal.blue() - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/skip/test_stash_pop.py b/test/distributed/pipeline/sync/skip/test_stash_pop.py deleted file mode 100644 index 5d273860f6a6..000000000000 --- a/test/distributed/pipeline/sync/skip/test_stash_pop.py +++ /dev/null @@ -1,144 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import pytest - -import torch -from torch import nn - -from torch.distributed.pipeline.sync.skip import pop, skippable, stash -from torch.distributed.pipeline.sync.skip.tracker import SkipTracker, use_skip_tracker -from torch.testing._internal.common_utils import run_tests - - -@pytest.fixture(autouse=True) -def skip_tracker(): - skip_tracker = SkipTracker() - with use_skip_tracker(skip_tracker): - yield skip_tracker - - -def test_stash(skip_tracker): - @skippable(stash=["foo"]) - class Stash(nn.Module): - def forward(self, input): - yield stash("foo", input) - return input * 2 # noqa: B901 - - l1 = Stash() - - assert len(skip_tracker.tensors) == 0 - - with use_skip_tracker(skip_tracker): - l1(torch.tensor(42)) - - assert len(skip_tracker.tensors) == 1 - - -def test_pop(): - @skippable(stash=["foo"]) - class Stash(nn.Module): - def forward(self, input): - yield stash("foo", input) - return input * 2 # noqa: B901 - - @skippable(pop=["foo"]) - class Pop(nn.Module): - def forward(self, input): - foo = yield pop("foo") - return foo - - l1 = Stash() - l2 = Pop() - - output = l2(l1(torch.tensor(42))) - - assert output.item() == 42 - - -def test_declare_but_not_use(): - @skippable(stash=["foo"]) - class Stash(nn.Module): - def forward(self, input): - return input * 2 - - @skippable(pop=["foo"]) - class Pop(nn.Module): - def forward(self, input): - return input * 3 - - l1 = Stash() - l2 = Pop() - - with pytest.raises(RuntimeError): - l1(torch.tensor(42)) - - with pytest.raises(RuntimeError): - l2(torch.tensor(42)) - - -def test_stash_not_declared(): - @skippable() - class Stash(nn.Module): - def forward(self, input): - yield stash("foo", input) - return input * 2 # noqa: B901 - - l1 = Stash() - - with pytest.raises(RuntimeError): - l1(torch.tensor(42)) - - -def test_pop_not_declared(): - @skippable(stash=["foo"]) - class Stash(nn.Module): - def forward(self, input): - yield stash("foo", input) - return input * 2 # noqa: B901 - - @skippable() - class Pop(nn.Module): - def forward(self, input): - foo = yield pop("foo") - return foo - - l1 = Stash() - l2 = Pop() - - latent = l1(torch.tensor(42)) - - with pytest.raises(RuntimeError): - l2(latent) - - -def test_pop_not_stashed(): - @skippable(pop=["foo"]) - class Pop(nn.Module): - def forward(self, input): - yield pop("foo") - - l1 = Pop() - - with pytest.raises(RuntimeError): - l1(torch.tensor(42)) - - -def test_stash_none(): - @skippable(stash=["foo"]) - class Stash(nn.Module): - def forward(self, input): - yield stash("foo", None) - return input * 2 # noqa: B901 - - l1 = Stash() - l1(torch.tensor(42)) - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/skip/test_tracker.py b/test/distributed/pipeline/sync/skip/test_tracker.py deleted file mode 100644 index 9c3a970f7574..000000000000 --- a/test/distributed/pipeline/sync/skip/test_tracker.py +++ /dev/null @@ -1,145 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import threading -from queue import Queue - -import pytest - -import torch -from torch import nn - -from torch.distributed.pipeline.sync.checkpoint import ( - enable_checkpointing, - enable_recomputing, -) -from torch.distributed.pipeline.sync.microbatch import Batch -from torch.distributed.pipeline.sync.skip import pop, skippable, stash -from torch.distributed.pipeline.sync.skip.layout import SkipLayout -from torch.distributed.pipeline.sync.skip.tracker import ( - current_skip_tracker, - SkipTracker, - SkipTrackerThroughPotals, -) -from torch.testing._internal.common_utils import run_tests - - -def test_default_skip_tracker(): - q = Queue() - - def f(): - q.put(current_skip_tracker()) - - t = threading.Thread(target=f) - t.start() - t.join() - - skip_tracker = q.get() - - assert type(skip_tracker) is SkipTracker - assert type(skip_tracker) is not SkipTrackerThroughPotals - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") -def test_default_skip_tracker_by_data_parallel(): - @skippable(stash=["foo"]) - class Stash(nn.Module): - def forward(self, input): - yield stash("foo", input) - return input * 2 # noqa: B901 - - @skippable(pop=["foo"]) - class Pop(nn.Module): - def forward(self, input): - foo = yield pop("foo") - return foo - - model = nn.Sequential(Stash(), Pop()) - model = nn.DataParallel(model, device_ids=[0, 0], output_device=0) - - input = torch.rand(10, device=0) - output = model(input) - - assert torch.allclose(output, input) - - -def test_reuse_portal(): - skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "test"): (0, 1)}) - skip_tracker = SkipTrackerThroughPotals(skip_layout) - - batch = Batch(torch.tensor([1.0])) - a = torch.tensor([2.0]) - b = torch.tensor([2.0]) - - skip_tracker.save(batch, None, "test", a) - portal = skip_tracker.portals[(None, "test")] - - skip_tracker.save(batch, None, "test", b) - assert portal is skip_tracker.portals[(None, "test")] - - -def test_no_copy_no_portal(): - skip_layout = SkipLayout( - num_partitions=2, - skip_routes={(None, "copy"): (0, 1), (None, "not_copy"): (0, 0)}, - ) - skip_tracker = SkipTrackerThroughPotals(skip_layout) - - batch = Batch(torch.tensor([1.0])) - a = torch.tensor([2.0]) - b = torch.tensor([2.0]) - - skip_tracker.save(batch, None, "copy", a) - skip_tracker.save(batch, None, "not_copy", b) - - assert (None, "copy") in skip_tracker.portals - assert (None, "copy") not in skip_tracker.tensors - assert (None, "not_copy") in skip_tracker.tensors - assert (None, "not_copy") not in skip_tracker.portals - - -def test_tensor_life_without_checkpointing(): - skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "test"): (0, 1)}) - skip_tracker = SkipTrackerThroughPotals(skip_layout) - - batch = Batch(torch.tensor([1.0])) - tensor = torch.tensor([2.0]) - - skip_tracker.save(batch, None, "test", tensor) - assert skip_tracker.portals[(None, "test")].tensor_life == 1 - - skip_tracker.load(batch, None, "test") - assert skip_tracker.portals[(None, "test")].tensor_life == 0 - - -def test_tensor_life_with_checkpointing(): - skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "test"): (0, 1)}) - skip_tracker = SkipTrackerThroughPotals(skip_layout) - - batch = Batch(torch.tensor([1.0])) - tensor = torch.tensor([2.0]) - - with enable_checkpointing(): - skip_tracker.save(batch, None, "test", tensor) - assert skip_tracker.portals[(None, "test")].tensor_life == 2 - - with enable_checkpointing(): - skip_tracker.load(batch, None, "test") - assert skip_tracker.portals[(None, "test")].tensor_life == 1 - - with enable_recomputing(): - skip_tracker.load(batch, None, "test") - assert skip_tracker.portals[(None, "test")].tensor_life == 0 - - with enable_recomputing(): - skip_tracker.save(batch, None, "test", tensor) - assert skip_tracker.portals[(None, "test")].tensor_life == 0 - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/skip/test_verify_skippables.py b/test/distributed/pipeline/sync/skip/test_verify_skippables.py deleted file mode 100644 index 1d5941487da8..000000000000 --- a/test/distributed/pipeline/sync/skip/test_verify_skippables.py +++ /dev/null @@ -1,165 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import pytest - -from torch import nn - -from torch.distributed.pipeline.sync.skip import Namespace, skippable, verify_skippables -from torch.testing._internal.common_utils import run_tests - - -def test_matching(): - @skippable(stash=["foo"]) - class Layer1(nn.Module): - pass - - @skippable(pop=["foo"]) - class Layer2(nn.Module): - pass - - verify_skippables(nn.Sequential(Layer1(), Layer2())) - - -def test_stash_not_pop(): - @skippable(stash=["foo"]) - class Layer1(nn.Module): - pass - - with pytest.raises(TypeError) as e: - verify_skippables(nn.Sequential(Layer1())) - assert "no module declared 'foo' as poppable but stashed" in str(e.value) - - -def test_pop_unknown(): - @skippable(pop=["foo"]) - class Layer1(nn.Module): - pass - - with pytest.raises(TypeError) as e: - verify_skippables(nn.Sequential(Layer1())) - assert "'0' declared 'foo' as poppable but it was not stashed" in str(e.value) - - -def test_stash_again(): - @skippable(stash=["foo"]) - class Layer1(nn.Module): - pass - - @skippable(stash=["foo"]) - class Layer2(nn.Module): - pass - - @skippable(pop=["foo"]) - class Layer3(nn.Module): - pass - - with pytest.raises(TypeError) as e: - verify_skippables(nn.Sequential(Layer1(), Layer2(), Layer3())) - assert "'1' redeclared 'foo' as stashable" in str(e.value) - - -def test_pop_again(): - @skippable(stash=["foo"]) - class Layer1(nn.Module): - pass - - @skippable(pop=["foo"]) - class Layer2(nn.Module): - pass - - @skippable(pop=["foo"]) - class Layer3(nn.Module): - pass - - with pytest.raises(TypeError) as e: - verify_skippables(nn.Sequential(Layer1(), Layer2(), Layer3())) - assert "'2' redeclared 'foo' as poppable" in str(e.value) - - -def test_stash_pop_together_different_names(): - @skippable(stash=["foo"]) - class Layer1(nn.Module): - pass - - @skippable(pop=["foo"], stash=["bar"]) - class Layer2(nn.Module): - pass - - @skippable(pop=["bar"]) - class Layer3(nn.Module): - pass - - verify_skippables(nn.Sequential(Layer1(), Layer2(), Layer3())) - - -def test_stash_pop_together_same_name(): - @skippable(stash=["foo"], pop=["foo"]) - class Layer1(nn.Module): - pass - - with pytest.raises(TypeError) as e: - verify_skippables(nn.Sequential(Layer1())) - assert "'0' declared 'foo' both as stashable and as poppable" in str(e.value) - - -def test_double_stash_pop(): - @skippable(stash=["foo"]) - class Layer1(nn.Module): - pass - - @skippable(pop=["foo"]) - class Layer2(nn.Module): - pass - - @skippable(stash=["foo"]) - class Layer3(nn.Module): - pass - - @skippable(pop=["foo"]) - class Layer4(nn.Module): - pass - - with pytest.raises(TypeError) as e: - verify_skippables(nn.Sequential(Layer1(), Layer2(), Layer3(), Layer4())) - assert "'2' redeclared 'foo' as stashable" in str(e.value) - assert "'3' redeclared 'foo' as poppable" in str(e.value) - - -def test_double_stash_pop_but_isolated(): - @skippable(stash=["foo"]) - class Layer1(nn.Module): - pass - - @skippable(pop=["foo"]) - class Layer2(nn.Module): - pass - - @skippable(stash=["foo"]) - class Layer3(nn.Module): - pass - - @skippable(pop=["foo"]) - class Layer4(nn.Module): - pass - - ns1 = Namespace() - ns2 = Namespace() - - verify_skippables( - nn.Sequential( - Layer1().isolate(ns1), - Layer2().isolate(ns1), - Layer3().isolate(ns2), - Layer4().isolate(ns2), - ) - ) - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_balance.py b/test/distributed/pipeline/sync/test_balance.py deleted file mode 100644 index faf09f4581ae..000000000000 --- a/test/distributed/pipeline/sync/test_balance.py +++ /dev/null @@ -1,240 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import time - -import pytest - -import torch -from torch import nn - -from torch.distributed.pipeline.sync._balance import ( - balance_by_size, - balance_by_time, - blockpartition, -) -from torch.distributed.pipeline.sync._balance.profile import layerwise_sandbox -from torch.testing._internal.common_utils import run_tests - -skip_if_no_cuda = pytest.mark.skipif( - not torch.cuda.is_available(), reason="cuda required" -) - -devices = ["cpu"] -if torch.cuda.is_available(): - devices.append("cuda") - - -def test_blockpartition(): - assert blockpartition.solve([1, 2, 3, 4, 5, 6], partitions=2) == [ - [1, 2, 3, 4], - [5, 6], - ] - - -def test_blockpartition_zeros(): - assert blockpartition.solve([0, 0], partitions=2) == [[0], [0]] - - -def test_blockpartition_non_positive_partitions(): - with pytest.raises(ValueError): - blockpartition.solve([42], partitions=0) - with pytest.raises(ValueError): - blockpartition.solve([42], partitions=-1) - - -def test_blockpartition_short_sequence(): - with pytest.raises(ValueError): - blockpartition.solve([], partitions=1) - with pytest.raises(ValueError): - blockpartition.solve([42], partitions=2) - - -@pytest.mark.parametrize("device", devices) -@pytest.mark.skip(reason="Flaky due to time.sleep()") -def test_balance_by_time(device): - class Delay(nn.Module): - def __init__(self, seconds): - super().__init__() - self.seconds = seconds - - def forward(self, x): - time.sleep(self.seconds) - return x - - model = nn.Sequential(*[Delay(i / 10) for i in [1, 2, 3, 4, 5, 6]]) - sample = torch.rand(1) - balance = balance_by_time(2, model, sample, device=device) - assert balance == [4, 2] - - -def test_balance_by_time_loop_resets_input(): - # nn.Flatten was introduced at PyTorch 1.2.0. - class Flatten(nn.Module): - def forward(self, x): - return x.flatten(1) - - model = nn.Sequential(nn.Conv2d(3, 2, 1), Flatten(), nn.Linear(128, 10)) - sample = torch.rand(10, 3, 8, 8) - balance = balance_by_time(2, model, sample, device="cpu") - assert balance == [1, 2] - - -@skip_if_no_cuda -def test_balance_by_size_latent(): - class Expand(nn.Module): - def __init__(self, times): - super().__init__() - self.times = times - - def forward(self, x): - for i in range(self.times): - x = x + torch.rand_like(x, requires_grad=True) - return x - - sample = torch.rand(10, 100, 100) - - model = nn.Sequential(*[Expand(i) for i in [1, 2, 3, 4, 5, 6]]) - balance = balance_by_size(2, model, sample) - assert balance == [4, 2] - - model = nn.Sequential(*[Expand(i) for i in [6, 5, 4, 3, 2, 1]]) - balance = balance_by_size(2, model, sample) - assert balance == [2, 4] - - -@skip_if_no_cuda -def test_balance_by_size_param(): - model = nn.Sequential(*[nn.Linear(i + 1, i + 2) for i in range(6)]) - sample = torch.rand(7, 1) - balance = balance_by_size(2, model, sample, param_scale=100) - assert balance == [4, 2] - - model = nn.Sequential(*[nn.Linear(i + 2, i + 1) for i in reversed(range(6))]) - sample = torch.rand(1, 7) - balance = balance_by_size(2, model, sample, param_scale=100) - assert balance == [2, 4] - - -@skip_if_no_cuda -def test_balance_by_size_param_scale(): - class Tradeoff(nn.Module): - def __init__(self, param_size, latent_size): - super().__init__() - self.fc = nn.Linear(param_size, param_size) - self.latent_size = latent_size - - def forward(self, x): - for i in range(self.latent_size): - x = x + torch.rand_like(x, requires_grad=True) - return x - - model = nn.Sequential( - Tradeoff(param_size=1, latent_size=6), - Tradeoff(param_size=2, latent_size=5), - Tradeoff(param_size=3, latent_size=4), - Tradeoff(param_size=4, latent_size=3), - Tradeoff(param_size=5, latent_size=2), - Tradeoff(param_size=6, latent_size=1), - ) - - sample = torch.rand(1, requires_grad=True) - - balance = balance_by_size(2, model, sample, param_scale=0) - assert balance == [2, 4] - - balance = balance_by_size(2, model, sample, param_scale=100) - assert balance == [4, 2] - - -@pytest.mark.parametrize("device", devices) -def test_layerwise_sandbox(device): - model = nn.Sequential(nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3)) - model.eval() - - for layer in layerwise_sandbox(model, torch.device(device)): - assert layer.training - assert all(p.device.type == device for p in layer.parameters()) - - assert all(not l.training for l in model) - assert all(p.device.type == "cpu" for p in model.parameters()) - - -@pytest.mark.parametrize("device", devices) -def test_sandbox_during_profiling(device): - model = nn.Sequential(nn.BatchNorm2d(3)) - - before = {k: v.clone() for k, v in model.state_dict().items()} - - sample = torch.rand(1, 3, 10, 10) - balance_by_time(1, model, sample, device=device) - - after = model.state_dict() - - assert before.keys() == after.keys() - for key, value in before.items(): - assert torch.allclose(after[key], value), key - - -def test_not_training(): - class AssertTraining(nn.Module): - def forward(self, x): - assert self.training - return x - - model = nn.Sequential(AssertTraining()) - - model.eval() - assert not model.training - - sample = torch.rand(1) - balance_by_time(1, model, sample, device="cpu") - - assert not model.training - - -def test_balance_by_time_tuple(): - class Twin(nn.Module): - def forward(self, x): - return x, x.detach() - - class Add(nn.Module): - def forward(self, a, b): - return a + b - - model = nn.Sequential(Twin(), Add()) - sample = torch.rand(1, requires_grad=True) - balance_by_time(1, model, sample, device="cpu") - - -@skip_if_no_cuda -def test_balance_by_size_tuple(): - class Twin(nn.Module): - def forward(self, x): - return x, x.detach() - - class Add(nn.Module): - def forward(self, a, b): - return a + b - - model = nn.Sequential(Twin(), Add()) - sample = torch.rand(1, requires_grad=True) - balance_by_size(1, model, sample) - - -def test_already_has_grad(): - model = nn.Sequential(nn.Conv2d(3, 3, 1)) - sample = torch.rand(1, 3, 32, 32) - model(sample).norm().backward() - - with pytest.raises(ValueError, match="some parameter already has gradient"): - balance_by_time(1, model, sample, device="cpu") - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_bugs.py b/test/distributed/pipeline/sync/test_bugs.py deleted file mode 100644 index 928a78db6e32..000000000000 --- a/test/distributed/pipeline/sync/test_bugs.py +++ /dev/null @@ -1,146 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import pytest - -import torch -import torch.nn.functional as F -from torch import nn - -from torch.distributed.pipeline.sync import Pipe -from torch.testing._internal.common_cuda import TEST_MULTIGPU -from torch.testing._internal.common_utils import run_tests - - -def test_python_autograd_function(setup_rpc): - # A Python autograd function might fail with this error: - # - # RuntimeError: Returning Variables sharing storage with other Variables - # that require grad is not supported in Python functions. Please submit a - # feature request if you hit this error. - # - # It doesn't look like an essential restriction. But it happens on the - # current PyTorch version. To avoid it, we should detach the tensor before - # returning by identity autograd functions, such as Wait, Fork, and Join. - # - class Identity(torch.autograd.Function): - @staticmethod - def forward(ctx, input): - return input - - @staticmethod - def backward(ctx, grad): - return grad - - class M(nn.Module): - def forward(self, input): - return Identity.apply(input) - - model = nn.Sequential(M(), M()) - model = Pipe(model, checkpoint="always") - - x = torch.rand(42) - y = model(x) - assert torch.allclose(x, y.local_value()) - - -def test_exception_no_hang(setup_rpc): - # In v0.0.2, once a failed partition receives a normal message - # (non-closing) for the next micro-batch, a hang occurred. The reason was - # that a failed partition didn't call in_queue.task_done() on a normal - # message. So the former partition was blocked at out_queue.join() for the - # next of next micro-batch. - class ExpectedException(Exception): - pass - - class Pass(nn.Module): - def forward(self, x): - return x - - class Raise(nn.Module): - def forward(self, x): - raise ExpectedException - - model = nn.Sequential(Pass(), Pass(), Raise()) - model = Pipe(model, chunks=3) - - with pytest.raises(ExpectedException): - model(torch.rand(3)) - - -@pytest.mark.skipif(not TEST_MULTIGPU, reason="2 cuda devices required") -def test_tuple_wait(cuda_sleep, setup_rpc): - # In v0.0.3, Wait is applied to only the first tensor on a micro-batch. - # Under this behavior, if checkpointing was disabled, there's a possibility - # that gradient accumulations on other tensors are not synchronized - # properly to the copy stream. - class Sleep(torch.autograd.Function): - @staticmethod - def forward(ctx, x): - return x.detach() - - @staticmethod - def backward(ctx, grad): - with torch.cuda.device(grad.device): - cuda_sleep(0.05) - return grad - - class Layer1(nn.Module): - def __init__(self): - super().__init__() - self.ones = nn.Parameter(torch.ones(32, 3, 32, 32, requires_grad=True)) - - def forward(self, a, b): - a = a * self.ones - return a * 1, b * 2, b * 3 - - class Layer2(nn.Module): - def __init__(self): - super().__init__() - self.ones = nn.Parameter(torch.ones(32, 3, 32, 32, requires_grad=True)) - - def forward(self, a, b, c): - a = a * self.ones - b = Sleep.apply(b) - return a + b + c - - model = nn.Sequential(Layer1().cuda(0), Layer2().cuda(1)) - model = Pipe(model, chunks=32, checkpoint="never") - - a = torch.rand(1024, 3, 32, 32, device=0, requires_grad=True) - b = torch.rand(1024, 3, 32, 32, device=0, requires_grad=True) - - y = model(a, b) - y.local_value().norm().backward() - - torch.cuda.synchronize(0) - torch.cuda.synchronize(1) - - assert torch.isclose(b.grad.norm().cpu(), torch.tensor(5.000)) - - -def test_parallel_randoms(setup_rpc): - class Dropouts(nn.Module): - def forward(self, x): - for _ in range(100): - x = F.dropout(x, p=0.001) - return x - - model = nn.Sequential(Dropouts(), Dropouts()) - - x = torch.rand(10, 10, requires_grad=True) - model = Pipe(model, chunks=10, checkpoint="always") - y = model(x) - y = y.local_value() - y.norm().backward() - - assert y.to(torch.bool).tolist() == x.grad.to(torch.bool).tolist() - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_checkpoint.py b/test/distributed/pipeline/sync/test_checkpoint.py deleted file mode 100644 index 7be8ddefafe9..000000000000 --- a/test/distributed/pipeline/sync/test_checkpoint.py +++ /dev/null @@ -1,178 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -from functools import partial - -import pytest - -import torch -import torch.cuda -from torch import nn - -from torch.distributed.pipeline.sync.checkpoint import ( - checkpoint, - Checkpointing, - is_checkpointing, - is_recomputing, -) -from torch.distributed.pipeline.sync.dependency import fork, join -from torch.distributed.pipeline.sync.microbatch import Batch -from torch.testing._internal.common_utils import run_tests - -devices = ["cpu"] -if torch.cuda.is_available(): - devices.append("cuda") - - -@pytest.mark.parametrize("device", devices) -def test_serial_checkpoints(device): - # Copied from https://github.com/pytorch/pytorch/pull/18568. - timeline = [] - - class Log(torch.autograd.Function): - @staticmethod - def forward(ctx, name, x): - ctx.name = name - timeline.append(f"{name}:forward") - return x.detach() - - @staticmethod - def backward(ctx, grad_output): - name = ctx.name - timeline.append(f"{name}:backward") - return None, grad_output - - a = torch.rand(1, device=device, requires_grad=True) - b = torch.rand(1, device=device, requires_grad=True) - - # Increase the next function sequence number. - _ = a + 1 + 2 + 3 + 4 + 5 - - a = checkpoint(partial(Log.apply, "a"), a) - - a, phony = fork(a) - b = join(b, phony) - - b = checkpoint(partial(Log.apply, "b"), b) - - c = torch.cat((a, b)) - - out = c.sum() - - # +--> {a} --Checkpoint(Log)--> {a} - # {out} --Sum--> {c} --Cat ^-----------------------------+ - # +--> {b} --Checkpoint(Log)--> {b} --First--> {b} - out.backward() - - assert timeline == [ - "a:forward", - "b:forward", - "b:forward", - "b:backward", - "a:forward", - "a:backward", - ] - # |----------------------| |-----------------------| |-----------------------| - # forward pass Checkpoint(Log[b]) Checkpoint(Log[a]) - - -def test_not_requires_grad(): - x = Batch(torch.rand(1, requires_grad=False)) - assert not x[0].requires_grad - - def f(x): - return x * 2 - - chk = Checkpointing(f, x) - x = chk.checkpoint() - assert x[0].requires_grad - - chk.recompute(x) - assert x[0].requires_grad - - x.tensor.backward() - - -def test_not_requires_grad_with_parameter(): - x = torch.rand(1, requires_grad=False) - a = torch.rand(1, requires_grad=True) - - def f(x): - return x * a - - y = checkpoint(f, x) - y.backward() - - assert a.grad is not None - - -@pytest.mark.parametrize("device", devices) -def test_random_in_checkpoint(device): - dropout = nn.Dropout(p=0.5) - - torch.manual_seed(0) - x = torch.randn(3, 3, device=device, requires_grad=True) - y = dropout(x) - y.norm().backward() - - torch.manual_seed(0) - chk_x = torch.randn(3, 3, device=device, requires_grad=True) - chk_y = checkpoint(dropout, chk_x) - chk_y.norm().backward() - - assert torch.allclose(x.grad, chk_x.grad) - - -def test_detect_checkpointing_recomputing(): - logs = [] - - class Detect(nn.Module): - def forward(self, input): - logs.append((is_checkpointing(), is_recomputing())) - return input - - model = Detect() - input = torch.rand(1, requires_grad=True) - - output = checkpoint(model, input) - output.backward() - - assert logs == [(True, False), (False, True)] - - -def test_detect_checkpointing_recomputing_without_checkpoint(): - logs = [] - - class Detect(nn.Module): - def forward(self, input): - logs.append((is_checkpointing(), is_recomputing())) - return input - - model = Detect() - input = torch.rand(1, requires_grad=True) - - output = model(input) - output.backward() - - assert logs == [(False, False)] - - -def test_non_grad_output(): - class ForkNonGrad(nn.Module): - def forward(self, input): - return (input * 2, torch.rand(1)) - - model = ForkNonGrad() - input = torch.rand(1, requires_grad=True) - - output = checkpoint(model, input) - output[0].backward() - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_copy.py b/test/distributed/pipeline/sync/test_copy.py deleted file mode 100644 index 302c3d25d53f..000000000000 --- a/test/distributed/pipeline/sync/test_copy.py +++ /dev/null @@ -1,85 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import pytest - -import torch - -from torch.distributed.pipeline.sync.copy import Copy, Wait -from torch.distributed.pipeline.sync.stream import ( - CPUStream, - current_stream, - get_device, - is_cuda, - new_stream, - use_stream, -) -from torch.testing._internal.common_utils import run_tests - -skip_if_no_cuda = pytest.mark.skipif( - not torch.cuda.is_available(), reason="cuda required" -) - - -def _test_copy_wait(prev_stream, next_stream, cuda_sleep=None): - device = get_device(prev_stream) - - with use_stream(prev_stream): - if is_cuda(prev_stream): - cuda_sleep(0.5) - x = torch.ones(100, device=device, requires_grad=True) - - (y,) = Copy.apply(prev_stream, next_stream, x) - (y,) = Wait.apply(prev_stream, next_stream, x) - - with use_stream(next_stream): - assert torch.allclose(y.sum(), torch.tensor(100.0, device=device)) - y.norm().backward() - with use_stream(prev_stream): - assert torch.allclose(x.grad.sum(), torch.tensor(10.0, device=device)) - - -def test_copy_wait_cpu_cpu(): - prev_stream = CPUStream - next_stream = CPUStream - _test_copy_wait(prev_stream, next_stream) - - -@skip_if_no_cuda -def test_copy_wait_cpu_cuda(cuda_sleep): - prev_stream = CPUStream - next_stream = current_stream(torch.device("cuda")) - _test_copy_wait(prev_stream, next_stream, cuda_sleep) - - -@skip_if_no_cuda -def test_copy_wait_cuda_cpu(cuda_sleep): - prev_stream = current_stream(torch.device("cuda")) - next_stream = CPUStream - _test_copy_wait(prev_stream, next_stream, cuda_sleep) - - -@skip_if_no_cuda -def test_copy_wait_cuda_cuda(cuda_sleep): - prev_stream = current_stream(torch.device("cuda")) - next_stream = new_stream(torch.device("cuda")) - _test_copy_wait(prev_stream, next_stream, cuda_sleep) - - -def test_wait_multiple_tensors(): - a = torch.rand(1, requires_grad=True) - b = torch.rand(1, requires_grad=True) - - a, b = Wait.apply(CPUStream, CPUStream, a, b) - - assert a.grad_fn is b.grad_fn - assert a.grad_fn.__class__ is Wait._backward_cls - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_deferred_batch_norm.py b/test/distributed/pipeline/sync/test_deferred_batch_norm.py deleted file mode 100644 index c3807c57d612..000000000000 --- a/test/distributed/pipeline/sync/test_deferred_batch_norm.py +++ /dev/null @@ -1,200 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -from copy import deepcopy -from itertools import chain - -import pytest - -import torch -from torch import nn, optim - -from torch.distributed.pipeline.sync.batchnorm import DeferredBatchNorm -from torch.testing._internal.common_utils import run_tests - -CHUNKS = 4 - - -def tilt_dist(input): - # Tilt variance by channel. - rgb = input.transpose(0, 1) - rgb[0] *= 1 - rgb[1] *= 10 - rgb[2] *= 100 - - # Tilt mean by single batch. - for i, single in enumerate(input): - single += 2**i - - return input - - -def chunked_forward(model, input, chunks=CHUNKS): - output_chunks = [] - - for chunk in input.chunk(chunks): - output_chunks.append(model(chunk)) - - return torch.cat(output_chunks) - - -@pytest.mark.parametrize("chunks", [1, 4]) -@pytest.mark.parametrize("input_requires_grad", [True, False]) -def test_transparency(chunks, input_requires_grad): - bn = nn.BatchNorm2d(3) - dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=chunks) - - input1 = torch.rand(16, 3, 224, 224) - input1 = tilt_dist(input1) - input2 = input1.clone() - input1.requires_grad = input_requires_grad - input2.requires_grad = input_requires_grad - - output1 = chunked_forward(bn, input1, chunks=chunks) - output2 = chunked_forward(dbn, input2, chunks=chunks) - - assert torch.allclose(output1, output2, atol=1e-4) - - output1.mean().backward() - output2.mean().backward() - - assert torch.allclose(bn.weight.grad, dbn.weight.grad, atol=1e-4) - - if input_requires_grad: - assert input1.grad is not None - assert input2.grad is not None - assert torch.allclose(input1.grad, input2.grad, atol=1e-4) - - -@pytest.mark.parametrize("momentum", [0.1, None]) -def test_running_stats(momentum): - bn = nn.BatchNorm2d(3, momentum=momentum) - dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS) - - input = torch.rand(16, 3, 224, 224) - input = tilt_dist(input) - - bn(input) - chunked_forward(dbn, input) - - assert torch.allclose(bn.running_mean, dbn.running_mean, atol=1e-4) - assert torch.allclose(bn.running_var, dbn.running_var, atol=1e-4) - - -def test_convert_deferred_batch_norm(): - bn = nn.BatchNorm2d(3, track_running_stats=False) - bn = DeferredBatchNorm.convert_deferred_batch_norm(bn, chunks=CHUNKS) - assert type(bn) is nn.BatchNorm2d # because of track_running_stats=False - - dbn = DeferredBatchNorm(3, chunks=CHUNKS) - dbn_again = DeferredBatchNorm.convert_deferred_batch_norm(dbn, chunks=CHUNKS) - assert dbn is dbn_again - - dbn_again = DeferredBatchNorm.convert_deferred_batch_norm(dbn, chunks=CHUNKS + 1) - assert dbn is not dbn_again # because of different chunks - - -def test_eval(): - bn = nn.BatchNorm2d(3) - dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS) - - input = torch.rand(16, 3, 224, 224) - input = tilt_dist(input) - - bn(input) - chunked_forward(dbn, input) - - bn.eval() - dbn.eval() - - assert torch.allclose(bn(input), dbn(input), atol=1e-4) - - -def test_optimize(): - bn = nn.BatchNorm2d(3) - dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS) - - opt = optim.SGD(chain(bn.parameters(), dbn.parameters()), lr=1.0) - - for i in range(5): - input = torch.rand(16, 3, 224, 224) - input = tilt_dist(input) - - # train - y = bn(input) - a = y.sum() - a.backward() - - y = chunked_forward(dbn, input) - b = y.sum() - b.backward() - - opt.step() - - # eval - bn.eval() - dbn.eval() - - with torch.no_grad(): - assert torch.allclose(bn(input), dbn(input), atol=1e-1 * (10**i)) - - -def test_conv_bn(): - bn = nn.Sequential(nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3)) - dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS) - - input = torch.rand(16, 3, 224, 224) - input = tilt_dist(input) - - opt = optim.SGD(chain(bn.parameters(), dbn.parameters()), lr=0.1) - - # 1st step - a = bn(input) - b = chunked_forward(dbn, input) - - # Outputs are different. (per-mini-batch vs. per-micro-batch) - assert not torch.allclose(a, b) - - a.sum().backward() - b.sum().backward() - opt.step() - opt.zero_grad() - - # Conv layers are also trained differently because of their different outputs. - assert not torch.allclose(bn[0].weight, dbn[0].weight) - - # But BNs track identical running stats. - assert torch.allclose(bn[1].running_mean, dbn[1].running_mean, atol=1e-4) - assert torch.allclose(bn[1].running_var, dbn[1].running_var, atol=1e3) - - # 2nd step - a = bn(input) - b = chunked_forward(dbn, input) - a.sum().backward() - b.sum().backward() - - # BNs can't track identical running stats due to the different conv layers. - assert not torch.allclose(bn[1].running_mean, dbn[1].running_mean, atol=1e-4) - assert not torch.allclose(bn[1].running_var, dbn[1].running_var, atol=1e3) - - -def test_input_requiring_grad(): - dbn = DeferredBatchNorm(3, chunks=CHUNKS) - - input = torch.rand(16, 3, 224, 224) - input = tilt_dist(input) - input.requires_grad = True - - chunked_forward(dbn, input) - - assert not dbn.sum.requires_grad - assert dbn.sum.grad_fn is None - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_dependency.py b/test/distributed/pipeline/sync/test_dependency.py deleted file mode 100644 index e966d6541bf5..000000000000 --- a/test/distributed/pipeline/sync/test_dependency.py +++ /dev/null @@ -1,152 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import weakref - -import pytest - -import torch - -from torch.distributed.pipeline.sync.dependency import Fork, fork, Join, join -from torch.testing._internal.common_utils import run_tests - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") -def test_fork_join(): - logs = [] - - class Log(torch.autograd.Function): - @staticmethod - def forward(ctx, number, tensor): - ctx.number = number - return tensor.detach() - - @staticmethod - def backward(ctx, grad): - logs.append(ctx.number) - return None, grad - - a = torch.rand(1, device="cpu", requires_grad=True) - b = torch.rand(1, device="cuda", requires_grad=True) - - a = Log.apply(1, a) - - a, phony = fork(a) - b = join(a, phony) - - b = Log.apply(2, b) - b = b.to("cpu") - - (a + b).backward() - - assert logs == [2, 1] - - -def test_fork_join_enable_grad(): - x = torch.rand(1, requires_grad=True) - - with torch.enable_grad(): - x2, p = fork(x) - - assert p.requires_grad - assert x2 is not x - x = x2 - - assert x.requires_grad - assert p.requires_grad - assert x.grad_fn.__class__ is Fork._backward_cls - assert p.grad_fn.__class__ is Fork._backward_cls - - with torch.enable_grad(): - x2 = join(x, p) - - assert x2 is not x - x = x2 - - assert x.requires_grad - assert x.grad_fn.__class__ is Join._backward_cls - - -def test_fork_join_no_grad(monkeypatch): - def do_not_apply(*args): - raise AssertionError("Function.apply called") - - monkeypatch.setattr("torch.autograd.Function.apply", do_not_apply) - - x = torch.rand(1, requires_grad=True) - - with torch.no_grad(): - x2, p = fork(x) - - assert not p.requires_grad - assert x2 is x - x = x2 - - with torch.no_grad(): - x2 = join(x, p) - - assert x2 is x - x = x2 - - -def test_fork_leak(): - leak = None - - class F(torch.autograd.Function): - @staticmethod - def forward(ctx, input): - return input - - @staticmethod - def backward(ctx, grad): - nonlocal leak - leak = weakref.ref(ctx) - return grad - - x = torch.rand(1, requires_grad=True) - x = F.apply(x) - x, phony = fork(x) - x = join(x, phony) - - x.backward() - del x, phony - - assert leak() is None - - -def test_join_when_fork_not_requires_grad(): - x = torch.rand(2, 1) - a, b = x.chunk(2) - - assert not a.requires_grad - a, p = fork(a) - assert not a.requires_grad - assert not p.requires_grad - - assert not b.requires_grad - b = join(b, p) - assert not b.requires_grad - - -def test_join_when_fork_requires_grad(): - x = torch.rand(2, 1) - a, b = x.chunk(2) - - a.requires_grad_() - assert a.requires_grad - a, p = fork(a) - assert a.requires_grad - assert p.requires_grad - - assert not b.requires_grad - b = join(b, p) - assert b.requires_grad - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_inplace.py b/test/distributed/pipeline/sync/test_inplace.py deleted file mode 100644 index 33f31b2a52bb..000000000000 --- a/test/distributed/pipeline/sync/test_inplace.py +++ /dev/null @@ -1,79 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import pytest - -import torch -from torch import nn - -from torch.distributed.pipeline.sync import Pipe -from torch.testing._internal.common_utils import run_tests - - -def test_inplace_on_requires_grad(setup_rpc): - model = nn.Sequential(nn.Linear(1, 1), nn.ReLU(inplace=True)) - model = Pipe(model, checkpoint="always") - - x = torch.rand(1) - y = model(x).local_value() - - message = r"a leaf Variable that requires grad .* used in an in-place operation." - with pytest.raises(RuntimeError, match=message): - y.backward() - - -@pytest.mark.xfail(strict=True) -def test_inplace_on_not_requires_grad(setup_rpc): - # In-place operation on a tensor not requiring grad doesn't cause a - # RuntimeError. Currently, we cannot detect this case. - model = nn.Sequential(nn.ReLU(inplace=True)) - model = Pipe(model, [1], devices=["cpu"], checkpoint="always") - - x = torch.rand(1) - y = model(x).local_value() - del model - - message = r"a leaf Variable that requires grad .* used in an in-place operation." - with pytest.raises(RuntimeError, match=message): - y.backward() - - -@pytest.mark.xfail(strict=True) -def test_inplace_incorrect_grad(setup_rpc): - class M(nn.Module): - def forward(self, foo_bar): - # 'foo' requires grad but 'bar' does not. In-place operation on - # 'bar' won't cause a RuntimeError. - foo, bar = foo_bar - - # add_(1) is not idempotent, in contrast to relu_(). If it is - # executed multiple times, it will accumulates each difference onto - # 'bar'. - bar.add_(1) - - # 'bar' is still captured by checkpointing. 'foo' will get - # incorrect grad. - return foo * bar - - model = nn.Sequential(M()) - model = Pipe(model, [1], devices=["cpu"], checkpoint="always") - - foo = torch.tensor([1.0], requires_grad=True) - bar = torch.tensor([1.0]) - - output = model((foo, bar)).local_value() - del model - output.backward() - - # The gradient of 'foo' should be 2, but it is 3 actually because - # bar.add_(1) was executed twice due to checkpointing. - assert foo.grad.item() == 2.0 - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_microbatch.py b/test/distributed/pipeline/sync/test_microbatch.py deleted file mode 100644 index b5e44aa73a8d..000000000000 --- a/test/distributed/pipeline/sync/test_microbatch.py +++ /dev/null @@ -1,148 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import pytest - -import torch -import torch.cuda - -from torch.distributed.pipeline.sync.microbatch import Batch, check, gather, scatter -from torch.testing._internal.common_utils import run_tests - - -def test_batch_atomic(): - x = torch.tensor(42) - b = Batch(x) - - assert b.atomic - - assert b.tensor is x - with pytest.raises(AttributeError): - b.tensors - - assert list(b) == [x] - assert len(b) == 1 - assert b[0] is x - - -def test_batch_non_atomic(): - x, y = torch.tensor(42), torch.tensor(21) - b = Batch((x, y)) - - assert not b.atomic - - with pytest.raises(AttributeError): - b.tensor - - assert list(b) == [x, y] - assert len(b) == 2 - assert b[0] is x - assert b[1] is y - - -def test_batch_call(): - a = Batch(torch.tensor(42)) - b = Batch((torch.tensor(42), torch.tensor(21))) - - def f(x): - return x - - def g(x, y): - return x, y - - assert a.call(f).atomic - assert not b.call(g).atomic - - -def test_batch_setitem_by_index(): - a = Batch(torch.tensor(42)) - b = Batch((torch.tensor(42), torch.tensor(21))) - - a[0] = torch.tensor(0) - b[0] = torch.tensor(0) - - assert a.atomic - assert a[0].item() == 0 - - assert not b.atomic - assert len(b) == 2 - assert b[0].item() == 0 - assert b[1].item() == 21 - - -def test_batch_setitem_by_slice(): - a = Batch(torch.tensor(42)) - b = Batch((torch.tensor(42), torch.tensor(21))) - - a[:] = (torch.tensor(0),) - b[:] = (torch.tensor(0),) - - assert a.atomic - assert a[0].item() == 0 - - assert not b.atomic - assert len(b) == 1 - assert b[0].item() == 0 - - -def test_check(): - check(torch.device("cpu"), torch.tensor(42)) - check(torch.device("cpu"), torch.tensor(4), torch.tensor(2)) - - with pytest.raises(TypeError): - check(torch.device("cpu"), 42) - - with pytest.raises(TypeError): - check(torch.device("cpu"), "str") - - with pytest.raises(TypeError): - check(torch.device("cpu"), (torch.tensor(4), 2)) - - -def test_gather_tensors(): - a = torch.zeros(1, 1) - b = torch.zeros(1, 1) - - ab = gather([Batch(a), Batch(b)]) - - assert ab.size() == (2, 1) - - -def test_gather_tuples(): - a = (torch.zeros(1, 1), torch.zeros(2, 2)) - b = (torch.zeros(1, 1), torch.zeros(2, 2)) - - ab = gather([Batch(a), Batch(b)]) - - assert isinstance(ab, tuple) - assert ab[0].size() == (2, 1) - assert ab[1].size() == (4, 2) - - -def test_scatter_tensor(): - ab = torch.zeros(2, 1) - - a, b = scatter(ab, chunks=2) - - assert a.tensor.size() == (1, 1) - assert b.tensor.size() == (1, 1) - - -def test_scatter_multiple_tensors(): - ab = (torch.zeros(2, 1), torch.zeros(4, 2)) - - a, b = scatter(*ab, chunks=2) - - assert next(iter(a)).size() == (1, 1) - assert next(iter(b)).size() == (1, 1) - assert list(a)[1].size() == (2, 2) - assert list(b)[1].size() == (2, 2) - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_phony.py b/test/distributed/pipeline/sync/test_phony.py deleted file mode 100644 index 6aeb873b30b2..000000000000 --- a/test/distributed/pipeline/sync/test_phony.py +++ /dev/null @@ -1,57 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import torch - -from torch.distributed.pipeline.sync.phony import get_phony -from torch.testing._internal.common_utils import run_tests - - -def test_phony_size(): - p = get_phony(torch.device("cpu"), requires_grad=False) - assert p.size() == (0,) - - -def test_phony_requires_grad(): - p1 = get_phony(torch.device("cpu"), requires_grad=True) - p2 = get_phony(torch.device("cpu"), requires_grad=False) - assert p1.requires_grad - assert not p2.requires_grad - - -def test_cached_phony(): - p1 = get_phony(torch.device("cpu"), requires_grad=True) - p2 = get_phony(torch.device("cpu"), requires_grad=True) - assert p1 is p2 - - p3 = get_phony(torch.device("cpu"), requires_grad=False) - p4 = get_phony(torch.device("cpu"), requires_grad=False) - assert p3 is p4 - - assert p1 is not p3 - - -def test_phony_in_autograd_function(): - class Phonify(torch.autograd.Function): - @staticmethod - def forward(ctx, input): - phony = get_phony(input.device, requires_grad=False) - return phony.detach() - - x = torch.rand(1, requires_grad=True) - - p1 = Phonify.apply(x) - p2 = get_phony(torch.device("cpu"), requires_grad=True) - - assert p1 is not p2 - assert p1.grad_fn is not None - assert p2.grad_fn is None - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_pipe.py b/test/distributed/pipeline/sync/test_pipe.py deleted file mode 100644 index e493b1d5a03e..000000000000 --- a/test/distributed/pipeline/sync/test_pipe.py +++ /dev/null @@ -1,858 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import random -import time -from collections import OrderedDict -from copy import deepcopy - -import pytest - -import torch -from torch import nn, Tensor - -from torch.distributed.pipeline.sync import NoChunk, Pipe, WithDevice -from torch.distributed.pipeline.sync.pipe import PipeSequential -from torch.testing._internal.common_cuda import TEST_MULTIGPU -from torch.testing._internal.common_utils import run_tests, TEST_CUDA - -skip_if_no_cuda = pytest.mark.skipif(not TEST_CUDA, reason="cuda required") - - -def test_pipe_without_rpc(): - model = nn.Sequential(nn.Linear(1, 1)) - with pytest.raises(RuntimeError, match="Please initialize RPC framework"): - pipe = Pipe(model, chunks=1) - - -def test_parameters(setup_rpc): - model = nn.Sequential(nn.Linear(1, 1)) - pipe = Pipe(model, chunks=1) - assert list(pipe.parameters()) != [] - - -def test_public_attrs(setup_rpc): - class MyString: - def __init__(self, value): - self.value = value - - def __str__(self): - return self.value - - model = nn.Sequential(nn.Linear(1, 1)) - pipe = Pipe(model, chunks=42.000, checkpoint=MyString("always")) - - assert pipe.devices == [torch.device("cpu")] - assert pipe.chunks == 42 - assert isinstance(pipe.chunks, int) - assert pipe.checkpoint == "always" - assert isinstance(pipe.checkpoint, str) - - -def test_sequential_like(setup_rpc): - a = nn.Linear(1, 1) - b = nn.Linear(1, 1) - - model = nn.Sequential(a, b) - model = Pipe(model) - - assert len(model) == 2 - assert list(model) == [a, b] - - assert model[0] is a - assert model[1] is b - with pytest.raises(IndexError): - _ = model[2] - - assert model[-1] is b - assert model[-2] is a - - -def test_chunks_less_than_1(setup_rpc): - model = nn.Sequential(nn.Linear(1, 1)) - - with pytest.raises(ValueError): - Pipe(model, chunks=0) - - with pytest.raises(ValueError): - Pipe(model, chunks=-1) - - -def test_batch_size_indivisible(setup_rpc): - model = nn.Sequential(nn.Linear(1, 1)) - model = Pipe(model, chunks=4) - - with pytest.warns(None) as record: - model(torch.rand(7, 1)) - - # Indivisible batch size is legal. - assert not record - - -def test_batch_size_small(setup_rpc): - model = nn.Sequential(nn.Linear(1, 1)) - model = Pipe(model, chunks=4) - - with pytest.warns(None) as record: - model(torch.rand(2, 1)) - - # Batch size smaller than chunks is legal. - assert not record - - -def test_checkpoint_mode(setup_rpc): - def count_grad_fn(grad_fn, name, visited=None): - if visited is None: - visited = set() - if grad_fn in visited: - return 0 - visited.add(grad_fn) - - if grad_fn is None: - return 0 - if grad_fn.__class__.__name__ == name: - return 1 - - counter = 0 - for next_grad_fn, _ in grad_fn.next_functions: - counter += count_grad_fn(next_grad_fn, name, visited=visited) - return counter - - model = nn.Sequential(nn.Linear(1, 1)) - input = torch.rand(2, 1) - - always = Pipe(model, chunks=2, checkpoint="always") - except_last = Pipe(model, chunks=2, checkpoint="except_last") - never = Pipe(model, chunks=2, checkpoint="never") - - always_output = always(input) - except_last_output = except_last(input) - never_output = never(input) - - assert count_grad_fn(always_output.local_value().grad_fn, "CheckpointBackward") == 2 - assert ( - count_grad_fn(except_last_output.local_value().grad_fn, "CheckpointBackward") - == 1 - ) - assert count_grad_fn(never_output.local_value().grad_fn, "CheckpointBackward") == 0 - - -def test_checkpoint_mode_invalid(setup_rpc): - model = nn.Sequential(nn.Linear(1, 1)) - - with pytest.raises( - ValueError, match="checkpoint is not one of 'always', 'except_last', or 'never'" - ): - Pipe(model, chunks=2, checkpoint="INVALID_CHECKPOINT") - - -def test_checkpoint_mode_when_chunks_1(setup_rpc): - model = nn.Sequential(nn.Linear(1, 1)) - - # All checkpoint modes are fine. - Pipe(model, chunks=1, checkpoint="except_last") - Pipe(model, chunks=1, checkpoint="always") - Pipe(model, chunks=1, checkpoint="never") - - -def test_checkpoint_eval(setup_rpc): - model = nn.Sequential(nn.Linear(1, 1)) - model = Pipe(model, chunks=2) - input = torch.rand(2, 1) - - def find_grad_fn(grad_fn, name): - if grad_fn is None: - return False - if grad_fn.__class__.__name__ == name: - return True - for next_grad_fn, _ in grad_fn.next_functions: - if find_grad_fn(next_grad_fn, name): - return True - return False - - model.train() - train_output = model(input) - assert find_grad_fn(train_output.local_value().grad_fn, "CheckpointBackward") - assert find_grad_fn(train_output.local_value().grad_fn, "RecomputeBackward") - - model.eval() - eval_output = model(input) - assert not find_grad_fn(eval_output.local_value().grad_fn, "CheckpointBackward") - assert not find_grad_fn(eval_output.local_value().grad_fn, "RecomputeBackward") - - -def test_checkpoint_non_float_input(setup_rpc): - class ForkNonFloat(nn.Module): - def forward(self, input): - return (input * 2, torch.tensor([False])) - - class JoinNonFloat(nn.Module): - def forward(self, input, non_float): - return input * 2 - - model = nn.Sequential(ForkNonFloat(), JoinNonFloat()) - model = Pipe(model, chunks=1, checkpoint="always") - - input = torch.rand(1, requires_grad=True) - output = model(input) - output.backward() - - -def test_no_grad(setup_rpc): - model = nn.Sequential(nn.Linear(1, 1)) - model = Pipe(model, chunks=2) - input = torch.rand(2, 1) - - latent = None - - def hook(module, input, output): - _ = module - _ = input - - nonlocal latent - latent = output - - partition = model.partitions[0] - partition.register_forward_hook(hook) - - with torch.no_grad(): - model(input) - - assert latent.grad_fn is None - - -def test_exception(setup_rpc): - class ExpectedException(Exception): - pass - - class Raise(nn.Module): - def forward(self, *_): - raise ExpectedException - - model = nn.Sequential(Raise()) - model = Pipe(model, chunks=1) - - with pytest.raises(ExpectedException): - model(torch.rand(1)) - - -def test_exception_early_stop_asap(setup_rpc): - """Even the first partitions have finished to process, the partition before - the failed partition should be killed as soon as possible. - """ - - class ExpectedException(Exception): - pass - - class Pass(nn.Module): - def forward(self, x): - return x - - counter = 0 - - class Counter(nn.Module): - def forward(self, x): - time.sleep(0.1) - - nonlocal counter - counter += 1 - - return x - - class Raise(nn.Module): - def forward(self, x): - raise ExpectedException - - model = nn.Sequential(Pass(), Pass(), Counter(), Raise()) - model = Pipe(model, chunks=3) - - with pytest.raises(ExpectedException): - model(torch.rand(3)) - - # If the early stop doesn't work, it would be 3 instead. - assert counter == 2 - - -def test_nested_input(setup_rpc): - class NestedInput(nn.Module): - def __init__(self): - super().__init__() - self.fc_a = nn.Linear(1, 1) - self.fc_b = nn.Linear(1, 1) - - def forward(self, inp): - return inp - - model = nn.Sequential(NestedInput()) - model = Pipe(model, chunks=2) - - a = torch.rand(10, 1, requires_grad=True) - b = torch.rand(10, 1, requires_grad=True) - - # TypeError: expected Tensor, but got tuple - with pytest.raises(TypeError): - model((a, (a, b))).local_value() - - # TypeError: expected Tensor, but got list - with pytest.raises(TypeError): - model((a, [a, b])).local_value() - - -def test_input_pair(setup_rpc): - class Two(nn.Module): - def __init__(self): - super().__init__() - self.fc_a = nn.Linear(1, 1) - self.fc_b = nn.Linear(1, 1) - - def forward(self, a, b): - return (self.fc_a(a), self.fc_b(b)) - - model = nn.Sequential(Two()) - model = Pipe(model, chunks=2) - - a = torch.rand(10, 1, requires_grad=True) - b = torch.rand(10, 1, requires_grad=True) - - a_out, b_out = model(a, b).local_value() - loss = (a_out + b_out).mean() - loss.backward() - - assert a.grad is not None - assert b.grad is not None - - -def test_multi_sequence_input(setup_rpc): - class MultiSeq(nn.Module): - def forward(self, tup1, tup2): - return tup1, tup2 - - model = Pipe(nn.Sequential(MultiSeq())) - with pytest.raises(TypeError): - model([torch.rand(10), torch.rand(10)], [torch.rand(10), torch.rand(10)]) - - -def test_input_singleton(setup_rpc): - class One(nn.Module): - def __init__(self): - super().__init__() - self.fc = nn.Linear(1, 1) - - def forward(self, a): - return (self.fc(a),) - - model = nn.Sequential(One()) - model = Pipe(model, chunks=2) - - a = torch.rand(10, 1, requires_grad=True) - - (a_out,) = model(a).local_value() - loss = a_out.mean() - loss.backward() - - assert all(p.grad is not None for p in model.parameters()) - assert a.grad is not None - - -def test_input_varargs(setup_rpc): - model = nn.Sequential(nn.Linear(1, 1)) - model = Pipe(model) - - a = torch.rand(1) - b = torch.rand(1) - - # TypeError: forward() takes 2 positional arguments but 3 were given - with pytest.raises(TypeError): - model(a, b) - - -def test_non_tensor(setup_rpc): - class NonTensor(nn.Module): - def forward(self, _): - return "hello" - - model = nn.Sequential(NonTensor()) - model = Pipe(model) - x = torch.rand(1) - - with pytest.raises(TypeError): - model(x) - - with pytest.raises(TypeError): - model("hello") - - -def test_non_tensor_sequence(setup_rpc): - class NonTensorTuple(nn.Module): - def forward(self, x): - return (x, "hello") - - class NonTensorArgs(nn.Module): - def forward(self, x: str, y: bool): - return x, y - - model = nn.Sequential(NonTensorTuple()) - model = Pipe(model) - x = torch.rand(1) - - with pytest.raises(TypeError): - model((x, "hello")) - - with pytest.raises(TypeError): - model([x, "hello"]) - - model = nn.Sequential(NonTensorArgs()) - model = Pipe(model) - - with pytest.raises(TypeError): - # Need atleast one Tensor. - model("hello", True) - - -@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) -def test_valid_non_tensor(checkpoint, setup_rpc): - class NonTensor1(nn.Module): - def forward(self, a: int, b: Tensor, c: bool, d: Tensor): - res = b + a if c else b * a - if d is not None: - res += d - return res, c, a, b, "hello", d - - class NonTensor2(nn.Module): - def forward(self, a: Tensor, b: bool, c: int, d: Tensor, e: str, f: Tensor): - res = a * c if b else a + c - res += d - return c, res, a, d + f if f is not None else d, b, e, f - - model = Pipe( - nn.Sequential(NonTensor1(), NonTensor2()), chunks=5, checkpoint=checkpoint - ) - a = random.randint(0, 10) - b = torch.rand(10, 10) - c = random.randint(0, 1) == 0 - d = torch.rand(10, 10) - res = model(a, b, c, d).local_value() - assert 7 == len(res) - assert [a] * 5 == res[0] - if c: - assert torch.allclose(((b + a + d) * a) + b, res[1]) - assert torch.allclose(b + a + d, res[2]) - else: - assert torch.allclose(((b * a) + d + a) + b, res[1]) - assert torch.allclose(b * a + d, res[2]) - assert torch.allclose(b + d, res[3]) - assert [c] * 5 == res[4] - assert ["hello"] * 5 == res[5] - assert torch.allclose(d, res[6]) - - # Test one of the tensors can be None - res = model(a, b, c, None).local_value() - assert 7 == len(res) - assert [a] * 5 == res[0] - if c: - assert torch.allclose(((b + a) * a) + b, res[1]) - assert torch.allclose(b + a, res[2]) - else: - assert torch.allclose(((b * a) + a) + b, res[1]) - assert torch.allclose(b * a, res[2]) - assert torch.allclose(b, res[3]) - assert [c] * 5 == res[4] - assert ["hello"] * 5 == res[5] - assert [None] * 5 == res[6] - - # Need atleast one tensor. - with pytest.raises(TypeError): - model(a, None, c, None) - - -@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) -def test_no_tensor_output(checkpoint, setup_rpc): - class Model1(nn.Module): - def forward(self, a: int, b: Tensor, c: bool): - return a, c, "hello" - - class Model2(nn.Module): - def forward(self, a: int, b: bool, c: str): - return a, c, b - - model = Pipe(nn.Sequential(Model1(), Model2()), chunks=5) - a = random.randint(0, 10) - b = torch.rand(10, 10) - c = random.randint(0, 1) == 0 - - # Need atleast one tensor across partitions too. - with pytest.raises(TypeError): - res = model(a, b, c).local_value() - - -@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) -def test_uneven_batch_size(checkpoint, setup_rpc): - class Model(nn.Module): - def forward(self, a: Tensor, b: int, c: Tensor): - return a, b, c - - model = Pipe(nn.Sequential(Model()), checkpoint=checkpoint, chunks=5) - a = torch.rand(3, 10) - b = random.randint(0, 10) - c = torch.rand(6, 10) - res = model(a, b, c).local_value() - assert torch.allclose(a, res[0]) - assert [b] * 3 == res[1] # 3 chunks - assert torch.allclose(c, res[2]) - - # Two tensors producing uneven chunks would fail. - model = Pipe(nn.Sequential(Model()), checkpoint=checkpoint, chunks=5) - a = torch.rand(3, 10) - b = random.randint(0, 10) - c = torch.rand(4, 10) - - with pytest.raises(RuntimeError, match="Found different number of chunks"): - model(a, b, c) - - -@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) -def test_no_chunk(checkpoint, setup_rpc): - class Model(nn.Module): - def forward(self, a: Tensor, b: int, c: Tensor): - return a, b, c - - model = Pipe(nn.Sequential(Model()), checkpoint=checkpoint, chunks=5) - a = torch.rand(10, 10) - b = random.randint(0, 10) - c = torch.rand(10, 10) - res = model(a, b, NoChunk(c)).local_value() - assert torch.allclose(a, res[0]) - assert [b] * 5 == res[1] - # c gets replicated due to NoChunk and the same tensor gets concatenated 5 - # times in the output. - assert torch.allclose(torch.cat((c, c, c, c, c)), res[2]) - - # Test invalid type for NoChunk - with pytest.raises(TypeError, match="NoChunk only supported for tensors"): - NoChunk(b) - - -@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) -def test_deferred_batch_norm(checkpoint, setup_rpc): - bn = nn.BatchNorm2d(3) - pipe_bn = deepcopy(bn) - pipe = Pipe( - nn.Sequential(pipe_bn), - chunks=2, - checkpoint=checkpoint, - deferred_batch_norm=True, - ) - - x = torch.rand(4, 3, 10, 10) - pipe(x).local_value().mean().backward() - bn(x).mean().backward() - - assert torch.allclose(pipe[0].running_mean, bn.running_mean, atol=1e-4) - assert torch.allclose(pipe[0].running_var, bn.running_var, atol=1e-4) - - -@pytest.mark.parametrize("checkpoint", ["never", "always"]) -def test_deferred_batch_norm_params(checkpoint, setup_rpc): - bn = nn.BatchNorm2d(3) - pipe_bn = deepcopy(bn) - pipe = Pipe( - nn.Sequential(pipe_bn), - chunks=1, - checkpoint=checkpoint, - deferred_batch_norm=True, - ) - - x = torch.rand(4, 3, 10, 10) - pipe(x).local_value().mean().backward() - bn(x).mean().backward() - - assert pipe[0].weight.grad is not None - assert pipe[0].bias.grad is not None - - assert torch.allclose(pipe[0].weight.grad, bn.weight.grad, atol=1e-4) - assert torch.allclose(pipe[0].bias.grad, bn.bias.grad, atol=1e-4) - - -def test_devices(setup_rpc): - a = nn.Linear(1, 1) - b = nn.Linear(1, 1) - c = nn.Linear(1, 1) - - # There are extra two devices. - model = nn.Sequential(a, b, c) - model = Pipe(model) - - cpu = torch.device("cpu") - # Extra devices must be discarded. - assert model.devices == [cpu, cpu, cpu] - - -def test_partitions(setup_rpc): - a = nn.Linear(1, 1) - b = nn.Linear(1, 1) - - model = nn.Sequential(a, b) - model = Pipe(model) - - assert isinstance(model.partitions, nn.ModuleList) - assert isinstance(model.partitions[0], nn.Sequential) - assert isinstance(model.partitions[1], nn.Sequential) - - assert "partitions.0.0.weight" in model.state_dict() - - -@skip_if_no_cuda -def test_merged_partitions(setup_rpc): - a = nn.Linear(1, 1).to(0) - b = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 2)).to(0) - c = nn.Linear(1, 1) - d = nn.Linear(1, 2) - - model = nn.Sequential(a, b, c, d) - model = Pipe(model) - - assert isinstance(model.partitions, nn.ModuleList) - assert isinstance(model.partitions[0], PipeSequential) - assert isinstance(model.partitions[1], PipeSequential) - assert list(model.partitions[0]) == [a, b[0], b[1]] - assert list(model.partitions[1]) == [c] - assert list(model.partitions[2]) == [d] - - -def test_deny_moving(setup_rpc): - a = nn.Linear(1, 1) - b = nn.Linear(1, 1) - - model = nn.Sequential(a, b) - model = Pipe(model) - - # Moving is denied. - with pytest.raises(TypeError): - model.cuda() - - with pytest.raises(TypeError): - model.cpu() - - with pytest.raises(TypeError): - model.to(torch.device("cuda")) - - with pytest.raises(TypeError): - model.to(0) - - with pytest.raises(TypeError): - model.to("cuda") - - with pytest.raises(TypeError): - model.to(device=0) - - with pytest.raises(TypeError): - model.to(torch.rand(1)) - - with pytest.raises(TypeError): - model.to(tensor=torch.rand(1)) - - # Casting is allowed. - model.half() - model.to(torch.double) - model.to(dtype=torch.float) - - -def test_empty_module(setup_rpc): - # Empty sequential module is not illegal. - model = nn.Sequential() - model = Pipe(model) - - assert model(torch.tensor(42)).local_value() == torch.tensor(42) - - # But only tensor or tensors is legal in Pipe. - with pytest.raises(TypeError): - model(42) - - -def test_named_children(setup_rpc): - a = nn.Linear(1, 1) - b = nn.Linear(1, 1) - - model = nn.Sequential(OrderedDict([("a", a), ("b", b)])) - model = Pipe(model) - - names = {n for n, _ in model.named_modules()} - assert "partitions.0.0" in names - assert "partitions.1.0" in names - - # Pipe doesn't support __getattr__. Unlike nn.Sequential, Pipe requires - # several methods in its namespace. - with pytest.raises(AttributeError): - model.a - - -def test_verify_module_non_sequential(setup_rpc): - with pytest.raises( - TypeError, match="module must be nn.Sequential to be partitioned" - ): - Pipe(nn.Module()) - - -def test_verify_module_duplicate_children(setup_rpc): - conv = nn.Conv2d(3, 3, 1) - model = nn.Sequential(conv, conv) - - with pytest.raises( - ValueError, match="module with duplicate children is not supported" - ): - Pipe(model) - - -@skip_if_no_cuda -def test_verify_module_params_on_same_device(setup_rpc): - class Surrogate(nn.Module): - def __init__(self, param1, param2): - super().__init__() - self.param1 = param1 - self.param2 = param2 - - conv1 = nn.Conv2d(3, 3, 1) - conv2 = nn.Conv2d(3, 3, 1) - model = nn.Sequential(Surrogate(conv1, conv2.cuda())) - - with pytest.raises( - ValueError, - match=r"should have all parameters on a single device, please use .to\(\)" - " to place the module on a single device", - ): - Pipe(model) - - -@pytest.mark.skipif(not TEST_MULTIGPU, reason="Need atleast two GPUs") -def test_verify_nested_modules(setup_rpc): - model = nn.Sequential( - nn.Sequential(nn.Linear(32, 16).cuda(0), nn.Linear(16, 8).cuda(0)), - nn.Sequential(nn.Linear(8, 4).cuda(1), nn.Linear(4, 2).cuda(1)), - ) - - pipe = Pipe(model) - out = pipe(torch.rand(10, 32).cuda(0)) - assert out.local_value().device == torch.device("cuda:1") - assert out.local_value().size() == torch.Size([10, 2]) - - -def test_verify_module_duplicate_parameters_on_same_device(setup_rpc): - class Surrogate(nn.Module): - def __init__(self, module): - super().__init__() - self.module = module - - conv = nn.Conv2d(3, 3, 1) - model = nn.Sequential(Surrogate(conv), Surrogate(conv)) - - Pipe(model) - - -def test_forward_lockstep(setup_rpc): - timeline = [] - - class DelayedLog(nn.Module): - def __init__(self, j, seconds): - super().__init__() - self.i = 0 - self.j = j - self.seconds = seconds - - def forward(self, x): - time.sleep(self.seconds) - - timeline.append((self.i, self.j)) - self.i += 1 - - return x - - model = nn.Sequential(DelayedLog(0, seconds=0), DelayedLog(1, seconds=0.1)) - model = Pipe(model, chunks=3) - model(torch.rand(3, 1)) - - # Expected timeline: (Logs are recorded at !) - # - # Partition #0: 0! 1! 2! - # Partition #1: 000! 111! 222! - # - assert timeline == [(0, 0), (1, 0), (0, 1), (2, 0), (1, 1), (2, 1)] - - -@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) -@skip_if_no_cuda -def test_multiple_inputs(checkpoint, setup_rpc): - class Module1(nn.Module): - def forward(self, a, b, c): - return a + b + c, a * b * c - - class Module2(nn.Module): - def forward(self, a, b): - return a + b - - model = Pipe( - nn.Sequential(Module1().cuda(0), Module2().cuda(0)), - chunks=2, - checkpoint=checkpoint, - ) - t = torch.rand(10) - res = model(t, t, t).local_value() - assert torch.equal(res, (t + t + t) + (t * t * t)) - - -@pytest.mark.skipif(not TEST_MULTIGPU, reason="Need atleast two GPUs") -def test_inputs_wrong_device(setup_rpc): - class Module1(nn.Module): - def __init__(self): - super().__init__() - self.param = torch.nn.Parameter(torch.rand(5)) - - def forward(self, a, b): - return a + b + self.param, b - - # Start inputs on wrong device and ensure Pipe moves them correctly. - a = torch.rand(10).cuda(1) - b = torch.rand(10).cuda(1) - model = Pipe(nn.Sequential(Module1().cuda(0), Module1().cuda(1)), chunks=2) - with pytest.raises( - ValueError, - match="All inputs should be on the same device as the first partition", - ): - model(a, b) - - -@pytest.mark.skipif(not TEST_MULTIGPU, reason="Need atleast two GPUs") -def test_with_device_wrapper(setup_rpc): - fc1 = nn.Linear(16, 8).cuda(0) - fc2 = nn.Linear(8, 4).cuda(1) - dropout = nn.Dropout() - - model = nn.Sequential(fc1, fc2, WithDevice(dropout, "cuda:1")) - model = Pipe(model, chunks=8) - assert ( - torch.device("cuda:1") == model(torch.rand(16, 16).cuda(0)).local_value().device - ) - assert [torch.device("cuda:0"), torch.device("cuda:1")] == model.devices - - model = nn.Sequential(fc1, WithDevice(dropout, "cuda:1")) - model = Pipe(model, chunks=8) - assert ( - torch.device("cuda:1") == model(torch.rand(16, 16).cuda(0)).local_value().device - ) - assert [torch.device("cuda:0"), torch.device("cuda:1")] == model.devices - - model = nn.Sequential(fc1, WithDevice(fc2, "cuda:0")) - model = Pipe(model, chunks=8) - assert ( - torch.device("cuda:0") == model(torch.rand(16, 16).cuda(0)).local_value().device - ) - assert [torch.device("cuda:0")] == model.devices - assert torch.device("cuda:0") == fc2.weight.device - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_pipeline.py b/test/distributed/pipeline/sync/test_pipeline.py deleted file mode 100644 index 9548cb959db1..000000000000 --- a/test/distributed/pipeline/sync/test_pipeline.py +++ /dev/null @@ -1,36 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -from torch.distributed.pipeline.sync.pipeline import _clock_cycles -from torch.testing._internal.common_utils import run_tests - - -def test_clock_cycles(): - assert list(_clock_cycles(1, 1)) == [[(0, 0)]] - assert list(_clock_cycles(1, 3)) == [[(0, 0)], [(0, 1)], [(0, 2)]] - assert list(_clock_cycles(3, 1)) == [[(0, 0)], [(1, 0)], [(2, 0)]] - - assert list(_clock_cycles(3, 3)) == [ - [(0, 0)], - [(1, 0), (0, 1)], - [(2, 0), (1, 1), (0, 2)], - [(2, 1), (1, 2)], - [(2, 2)], - ] - - assert list(_clock_cycles(4, 2)) == [ - [(0, 0)], - [(1, 0), (0, 1)], - [(2, 0), (1, 1)], - [(3, 0), (2, 1)], - [(3, 1)], - ] - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_stream.py b/test/distributed/pipeline/sync/test_stream.py deleted file mode 100644 index f9702c8e4152..000000000000 --- a/test/distributed/pipeline/sync/test_stream.py +++ /dev/null @@ -1,198 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import pytest - -import torch - -from torch.distributed.pipeline.sync.stream import ( - CPUStream, - current_stream, - default_stream, - get_device, - is_cuda, - new_stream, - record_stream, - use_device, - use_stream, - wait_stream, -) -from torch.testing._internal.common_utils import run_tests - -skip_if_no_cuda = pytest.mark.skipif( - not torch.cuda.is_available(), reason="cuda required" -) - - -class TestNewStream: - def test_new_stream_cpu(self): - stream = new_stream(torch.device("cpu")) - assert stream is CPUStream - - @skip_if_no_cuda - def test_new_stream_cuda(self): - stream = new_stream(torch.device("cuda")) - assert isinstance(stream, torch.cuda.Stream) - assert stream != torch.cuda.default_stream() - - -class TestCurrentStream: - def test_current_stream_cpu(self): - stream = current_stream(torch.device("cpu")) - assert stream is CPUStream - - @skip_if_no_cuda - def test_current_stream_cuda(self): - stream = current_stream(torch.device("cuda")) - assert isinstance(stream, torch.cuda.Stream) - assert stream == torch.cuda.current_stream() - - -class TestDefaultStream: - def test_default_stream_cpu(self): - stream = default_stream(torch.device("cpu")) - assert stream is CPUStream - - @skip_if_no_cuda - def test_default_stream_cuda(self): - stream = default_stream(torch.device("cuda")) - assert isinstance(stream, torch.cuda.Stream) - assert stream == torch.cuda.default_stream() - - -class TestUseDevice: - def test_use_device_cpu(self): - with use_device(torch.device("cpu")): - pass - - @skip_if_no_cuda - def test_use_device_cuda(self): - with use_device(torch.device("cuda")): - pass - - -class TestUseStream: - def test_use_stream_cpu(self): - with use_stream(CPUStream): - pass - - @skip_if_no_cuda - def test_use_stream_cuda(self): - stream = new_stream(torch.device("cuda")) - with use_stream(stream): - assert current_stream(torch.device("cuda")) == stream - - -class TestGetDevice: - def test_get_device_cpu(self): - assert get_device(CPUStream).type == "cpu" - - @skip_if_no_cuda - def test_get_device_cuda(self): - stream = current_stream(torch.device("cuda")) - assert get_device(stream).type == "cuda" - - -class TestWaitStream: - def _test_wait_stream(self, source, target, cuda_sleep=None): - with use_stream(target): - if is_cuda(target): - cuda_sleep(0.5) - x = torch.ones(100, 100, device=get_device(target)) - - wait_stream(source, target) - - with use_stream(source): - assert x.sum().item() == 10000 - - def test_wait_stream_cpu_cpu(self): - source = CPUStream - target = CPUStream - self._test_wait_stream(source, target) - - @skip_if_no_cuda - def test_wait_stream_cpu_cuda(self, cuda_sleep): - source = CPUStream - target = new_stream(torch.device("cuda")) - self._test_wait_stream(source, target, cuda_sleep) - - @skip_if_no_cuda - def test_wait_stream_cuda_cpu(self, cuda_sleep): - source = new_stream(torch.device("cuda")) - target = CPUStream - self._test_wait_stream(source, target, cuda_sleep) - - @skip_if_no_cuda - def test_wait_stream_cuda_cuda(self, cuda_sleep): - source = current_stream(torch.device("cuda")) - target = new_stream(torch.device("cuda")) - self._test_wait_stream(source, target, cuda_sleep) - - -class TestRecordStream: - def test_record_stream_cpu(self): - # It should silently ignore CPU tensors. - x = torch.rand(1, device=torch.device("cpu")) - record_stream(x, CPUStream) - - @skip_if_no_cuda - def test_record_stream_cuda(self, cuda_sleep): - # This test detects unexpected block reallocation. For reliable test, - # the stream to allocate tensors is isolated. The allocator will not - # reuse free blocks which were allocated from another stream. - stream_alloc = new_stream(torch.device("cuda")) - with torch.cuda.stream(stream_alloc): - x = torch.rand(1, device=torch.device("cuda")) - - stream = new_stream(torch.device("cuda")) - record_stream(x, stream) - with use_stream(stream): - cuda_sleep(0.5) - - # 'x' is deleted at Python's perspective. But the block of 'x' is still - # required for 'stream'. 'y' shouldn't be allocated to the block. - data_ptr = x.data_ptr() - del x - stream_alloc.synchronize() - with torch.cuda.stream(stream_alloc): - y = torch.rand(1, device=torch.device("cuda")) - assert y.data_ptr() != data_ptr - - # Pause Python until 'stream' finishes tasks queued. Now the block of - # 'x' is free to be reallocated. - wait_stream(CPUStream, stream) - with torch.cuda.stream(stream_alloc): - z = torch.rand(1, device=torch.device("cuda")) - assert z.data_ptr() == data_ptr - - @skip_if_no_cuda - def test_record_stream_shifted_view(self, cuda_sleep): - # Issue: https://github.com/pytorch/pytorch/issues/27366 - stream_alloc = new_stream(torch.device("cuda")) - with torch.cuda.stream(stream_alloc): - x = torch.rand(2, device=torch.device("cuda")) - - y = x[1:] - assert y.data_ptr() > x.data_ptr() - - stream = new_stream(torch.device("cuda")) - with use_stream(stream): - cuda_sleep(0.5) - record_stream(y, stream) - - data_ptr = x.data_ptr() - del x, y - - stream_alloc.synchronize() - with torch.cuda.stream(stream_alloc): - z = torch.rand(2, device=torch.device("cuda")) - assert z.data_ptr() != data_ptr - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_transparency.py b/test/distributed/pipeline/sync/test_transparency.py deleted file mode 100644 index a87a04150fdc..000000000000 --- a/test/distributed/pipeline/sync/test_transparency.py +++ /dev/null @@ -1,55 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import torch -from torch import nn - -from torch.distributed.pipeline.sync import Pipe -from torch.testing._internal.common_utils import run_tests - - -def test_simple_linears(setup_rpc): - def sum_grad(parameters): - return sum(p.grad.sum() for p in parameters if p.grad is not None) - - def zero_grad(parameters): - for p in parameters: - p.grad = None - - inputs = torch.rand(8, 1) - model = nn.Sequential( - nn.Linear(1, 2), - nn.Linear(2, 4), - nn.Linear(4, 2), - nn.Linear(2, 1), - ) - - # Without Pipe - outputs = model(inputs) - loss = outputs.mean() - loss.backward() - - grad_without_pipe = sum_grad(model.parameters()) - - zero_grad(model.parameters()) - - # With Pipe - model = Pipe(model, chunks=4) - - outputs = model(inputs).local_value() - loss = outputs.mean() - loss.backward() - - grad_with_pipe = sum_grad(model.parameters()) - - # Both grads should be identical. - assert torch.allclose(grad_with_pipe, grad_without_pipe) - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_worker.py b/test/distributed/pipeline/sync/test_worker.py deleted file mode 100644 index f82af2ea0067..000000000000 --- a/test/distributed/pipeline/sync/test_worker.py +++ /dev/null @@ -1,118 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import threading - -import pytest - -import torch - -from torch.distributed.pipeline.sync.microbatch import Batch -from torch.distributed.pipeline.sync.stream import CPUStream -from torch.distributed.pipeline.sync.worker import spawn_workers, Task -from torch.testing._internal.common_utils import run_tests - - -class fake_device: - """A test double for :class:`torch.device`. Every fake device is different - with each other. - """ - - type = "fake" - index = None - - -def test_compute_multithreading(): - """Task.compute should be executed on multiple threads.""" - thread_ids = set() - - def log_thread_id(): - thread_id = threading.current_thread().ident - thread_ids.add(thread_id) - return Batch(()) - - with spawn_workers([fake_device() for _ in range(2)]) as (in_queues, out_queues): - for i in range(2): - t = Task(CPUStream, compute=log_thread_id, finalize=None) - in_queues[i].put(t) - for i in range(2): - out_queues[i].get() - - assert len(thread_ids) == 2 - - -def test_compute_success(): - """Task.compute returns (True, (task, batch)) on success.""" - - def _42(): - return Batch(torch.tensor(42)) - - with spawn_workers([torch.device("cpu")]) as (in_queues, out_queues): - t = Task(CPUStream, compute=_42, finalize=None) - in_queues[0].put(t) - ok, (task, batch) = out_queues[0].get() - - assert ok - assert task is t - assert isinstance(batch, Batch) - assert batch[0].item() == 42 - - -def test_compute_exception(): - """Task.compute returns (False, exc_info) on failure.""" - - def zero_div(): - 0 / 0 - - with spawn_workers([torch.device("cpu")]) as (in_queues, out_queues): - t = Task(CPUStream, compute=zero_div, finalize=None) - in_queues[0].put(t) - ok, exc_info = out_queues[0].get() - - assert not ok - assert isinstance(exc_info, tuple) - assert issubclass(exc_info[0], ZeroDivisionError) - - -@pytest.mark.parametrize("grad_mode", [True, False]) -def test_grad_mode(grad_mode): - def detect_grad_enabled(): - x = torch.rand(1, requires_grad=torch.is_grad_enabled()) - return Batch(x) - - with torch.set_grad_enabled(grad_mode): - with spawn_workers([torch.device("cpu")]) as (in_queues, out_queues): - task = Task(CPUStream, compute=detect_grad_enabled, finalize=None) - in_queues[0].put(task) - - ok, (_, batch) = out_queues[0].get() - - assert ok - assert batch[0].requires_grad == grad_mode - - -def test_worker_per_device(): - cpu = torch.device("cpu") - cpu0 = torch.device("cpu", index=0) - fake1 = fake_device() - fake2 = fake_device() - - with spawn_workers([cpu, cpu, cpu0, fake1, fake2]) as (in_queues, out_queues): - assert len(in_queues) == len(out_queues) == 5 - - # 0: cpu, 1: cpu, 2: cpu0 - assert in_queues[0] is in_queues[1] is in_queues[2] - assert out_queues[0] is out_queues[1] is out_queues[2] - - # 3: fake1, 4: fake2 - assert in_queues[3] is not in_queues[4] - assert out_queues[3] is not out_queues[4] - - -if __name__ == "__main__": - run_tests() diff --git a/test/test_public_bindings.py b/test/test_public_bindings.py index 1db0e5718ce6..8ab2ac1f511f 100644 --- a/test/test_public_bindings.py +++ b/test/test_public_bindings.py @@ -329,7 +329,6 @@ def test_modules_can_be_imported(self): "torch.testing._internal.distributed.fake_pg", "torch.testing._internal.distributed.multi_threaded_pg", "torch.testing._internal.distributed.nn.api.remote_module_test", - "torch.testing._internal.distributed.pipe_with_ddp_test", "torch.testing._internal.distributed.rpc.dist_autograd_test", "torch.testing._internal.distributed.rpc.dist_optimizer_test", "torch.testing._internal.distributed.rpc.examples.parameter_server_test", @@ -408,7 +407,6 @@ def test_modules_can_be_imported(self): "torch.distributed.nn.api.remote_module", "torch.distributed.optim", "torch.distributed.optim.optimizer", - "torch.distributed.pipeline.sync", "torch.distributed.rendezvous", "torch.distributed.rpc.api", "torch.distributed.rpc.backend_registry", diff --git a/test/test_testing.py b/test/test_testing.py index ba9558a3ddd1..1e1dce59a32e 100644 --- a/test/test_testing.py +++ b/test/test_testing.py @@ -2245,7 +2245,6 @@ def test_circular_dependencies(self) -> None: else: ignored_modules.append("torch.distributed.nn.api.") ignored_modules.append("torch.distributed.optim.") - ignored_modules.append("torch.distributed.pipeline.") ignored_modules.append("torch.distributed.rpc.") ignored_modules.append("torch.testing._internal.dist_utils") # And these both end up with transitive dependencies on distributed diff --git a/torch/distributed/pipeline/__init__.py b/torch/distributed/pipeline/__init__.py deleted file mode 100644 index eacd2bc99d04..000000000000 --- a/torch/distributed/pipeline/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -import warnings - - -with warnings.catch_warnings(): - warnings.simplefilter("always") - warnings.warn( - "`torch.distributed.pipeline` is deprecated. For up-to-date pipeline parallel " - "implementation, please refer to the PiPPy library under the PyTorch " - "organization (Pipeline Parallelism for PyTorch): " - "https://github.com/pytorch/PiPPy", - DeprecationWarning, - stacklevel=2, - ) diff --git a/torch/distributed/pipeline/sync/LICENSE b/torch/distributed/pipeline/sync/LICENSE deleted file mode 100644 index e52be240fdc9..000000000000 --- a/torch/distributed/pipeline/sync/LICENSE +++ /dev/null @@ -1,27 +0,0 @@ -Copyright 2019-2020 Kakao Brain - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -1. Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - -2. Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - -3. Neither the name of the copyright holder nor the names of its - contributors may be used to endorse or promote products derived from this - software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE -LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -POSSIBILITY OF SUCH DAMAGE. diff --git a/torch/distributed/pipeline/sync/__init__.py b/torch/distributed/pipeline/sync/__init__.py deleted file mode 100644 index 75a80c5db0f9..000000000000 --- a/torch/distributed/pipeline/sync/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""A Pipe implementation in PyTorch.""" -from .checkpoint import is_checkpointing, is_recomputing -from .pipe import Pipe, WithDevice -from .microbatch import NoChunk - -__all__ = ["Pipe", "is_checkpointing", "is_recomputing"] diff --git a/torch/distributed/pipeline/sync/_balance/__init__.py b/torch/distributed/pipeline/sync/_balance/__init__.py deleted file mode 100644 index 8ffc657896d8..000000000000 --- a/torch/distributed/pipeline/sync/_balance/__init__.py +++ /dev/null @@ -1,164 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""A helper to roughly balance a sequential module. - -Usage:: - - import torch - from torch.distributed.pipeline.sync import Pipe - from torch.distributed.pipeline.sync.balance import balance_by_time - - sample = torch.empty(128, 3, 224, 224) - balance = balance_by_time(torch.cuda.device_count(), model, sample) - - pipe = Pipe(model, balance, chunks=8) - -""" -from typing import Any, List, Union, Sequence - -import torch -from torch import Tensor -import torch.nn as nn - -from . import blockpartition -from .profile import profile_sizes, profile_times - -__all__ = ["balance_by_time", "balance_by_size"] - - -Device = Union[torch.device, int, str] - -Tensors = Sequence[Tensor] -TensorOrTensors = Union[Tensor, Tensors] - - -def balance_cost(cost: List[int], partitions: int) -> List[int]: - partitioned = blockpartition.solve(cost, partitions) - return [len(p) for p in partitioned] - - -def balance_by_time( - partitions: int, - module: nn.Sequential, - sample: Union[List[Any], Tensor], - *, - timeout: float = 1.0, - device: Device = torch.device("cuda"), -) -> List[int]: - """Naive automatic balancing by elapsed time per layer. - :: - - sample = torch.empty(128, 3, 224, 224) - balance = balance_by_time(torch.cuda.device_count(), model, sample) - pipe = Pipe(model, balance, chunks=8) - - Args: - partitions (int): - intended number of partitions - module (torch.nn.Sequential): - sequential module to be partitioned - sample (torch.Tensor): - example input with arbitrary batch size - - Keyword Args: - timeout (float): - profiling iterates again if the timeout (in second) is not exceeded - (default: ``1.0``) - device ('cpu' or 'cuda' device): - CPU or CUDA device where each layer is profiled (default: the - current CUDA device) - - Returns: - A list of number of layers in each partition. Use it for the `balance` - parameter of :class:`~torchpipe.Pipe`. - - .. note:: - `module` and `sample` must be placed on the same device. - - """ - times = profile_times(module, sample, timeout, torch.device(device)) - return balance_cost(times, partitions) - - -def balance_by_size( - partitions: int, - module: nn.Sequential, - input: Union[List[Any], Tensor], - *, - chunks: int = 1, - param_scale: float = 2.0, - device: Device = torch.device("cuda"), -) -> List[int]: - """Naive automatic balancing by CUDA memory usage per layer. - - During training, required memory for parameters depends on which optimizer - is used. Optimizers may use buffers for each parameter to track - optimization statistics internally, such as momentum buffer in SGD. - - To get more reliable size based balance, you should specify `param_scale` - with regard to your optimizer. The default `param_scale` is 2 instead of 1 - due to gradient accumulation which is necessary for every optimizer. - - Follow this guide to choose correct `param_scale` for typical optimizers: - - ========= ============= ========================================= - Optimizer `param_scale` Internal State - ========= ============= ========================================= - SGD 2--3 (momentum_buffer) - Adam 4--5 exp_avg, exp_avg_sq, (max_exp_avg_sq) - Adadelta 4 square_avg, acc_delta - Adagrad 3 sum - RMSprop 3--5 square_avg, (momentum_buffer), (grad_avg) - ========= ============= ========================================= - - Here's a simple example with the Adam optimizer:: - - balance = balance_by_size( - torch.cuda.device_count(), - model, - - # Same size with mini-batch to train - torch.empty(1024, 3, 224, 224), - - # Number of micro-batches to train with Pipe - chunks=8, - - # 4 for Adam - param_scale=4.0, - ) - - pipe = Pipe(model, balance, chunks=8) - adam = Adam(pipe.parameters()) - - Args: - partitions (int): - intended number of partitions - module (torch.nn.Sequential): - sequential module to be partitioned - input (torch.Tensor): - example mini-batch with the same size to train - - Keyword Args: - chunks (int): - number of micro-batches will be used to train (default: ``1``) - param_scale (float): - how many copies of parameters would be allocated for training. It - depends on optimizer. See the above guide. (default: ``2.0``) - device ('cuda' device): - CUDA device where each layer is profiled (default: the current CUDA - device) - - Returns: - A list of number of layers in each partition. Use it for the `balance` - parameter of :class:`~torchpipe.Pipe`. - - .. note:: - `module` and `input` must be placed on the same CUDA device. - - """ - sizes = profile_sizes(module, input, chunks, param_scale, torch.device(device)) - return balance_cost(sizes, partitions) diff --git a/torch/distributed/pipeline/sync/_balance/blockpartition.py b/torch/distributed/pipeline/sync/_balance/blockpartition.py deleted file mode 100644 index ccdf5fe4df99..000000000000 --- a/torch/distributed/pipeline/sync/_balance/blockpartition.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Implements "Block Partitions of Sequences" by Imre B\u00e1r\u00e1ny et al. - -Paper: https://arxiv.org/pdf/1308.2452.pdf - -""" -from typing import Iterator, List, Tuple - -__all__ = ["solve"] - - -def solve(sequence: List[int], partitions: int = 1) -> List[List[int]]: - """Splits a sequence into several partitions to minimize variance for each - partition. - - The result might not be optimal. However, it can be done only in O(kn\u00b3), - where k is the number of partitions and n is the length of the sequence. - - """ - if partitions < 1: - raise ValueError(f"partitions must be a positive integer ({partitions} < 1)") - - n = len(sequence) - if n < partitions: - raise ValueError(f"sequence is shorter than intended partitions ({n} < {partitions})") - - # Normalize the sequence in [0, 1]. - minimum = min(sequence) - maximum = max(sequence) - minimum - - normal_sequence: List[float] - if maximum == 0: - normal_sequence = [0 for _ in sequence] - else: - normal_sequence = [(x - minimum) / maximum for x in sequence] - - splits = [n // partitions * (x + 1) for x in range(partitions - 1)] + [n] - - def block_size(i: int) -> float: - start = splits[i - 1] if i > 0 else 0 - stop = splits[i] - return sum(normal_sequence[start:stop]) - - def leaderboard() -> Iterator[Tuple[float, int]]: - return ((block_size(i), i) for i in range(partitions)) - - while True: - """ - (1) Fix p element-of [k] with M(P) = bp. So Bp is a maximal block of P. - """ - # max_size: M(P) - max_size, p = max(leaderboard()) - - while True: - """ - (2) If M(P) <= m(P) + 1, then stop. - """ - # min_size: m(P) - min_size, q = min(leaderboard()) - - if max_size <= min_size + 1: - return [sequence[i:j] for i, j in zip([0] + splits[:-1], splits)] - - """ - (3) If M(P) > m(P) + 1, then let m(P) = bq for the q element-of [k] which is - closest to p (ties broken arbitrarily). Thus Bq is a minimal block - of P. Let Bh be the block next to Bq between Bp and Bq. (Note that - Bh is a non-empty block: if it were, then m(P) = 0 and we should - have chosen Bh instead of Bq.) - """ - if p < q: - """ - So either p < q and then h = q-1 and we define P * by moving - the last element from Bh = Bq-1 to Bq, - """ - h = q - 1 - splits[h] -= 1 - else: - """ - or q < p, and then h = q + 1 and P * is obtained by moving the - first element of Bh = Bq+1 to Bq. - """ - h = q + 1 - splits[q] += 1 - - """ - Set P = P * . If p = h, then go to (1), else go to (2). - """ - if p == h: - break diff --git a/torch/distributed/pipeline/sync/_balance/profile.py b/torch/distributed/pipeline/sync/_balance/profile.py deleted file mode 100644 index fa1a0c06a8e3..000000000000 --- a/torch/distributed/pipeline/sync/_balance/profile.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Per-layer profilers.""" -import copy -import time -from typing import Any, Generator, List, Union, Sequence - -import torch -from torch import Tensor -import torch.nn as nn - -from ..microbatch import Batch - -__all__: List[str] = [] - - -Device = Union[torch.device, int, str] - -Tensors = Sequence[Tensor] -TensorOrTensors = Union[Tensor, Tensors] - - -def layerwise_sandbox(module: nn.Sequential, device: torch.device,) -> Generator[nn.Module, None, None]: - """Copies layers for ease to profile. It doesn't modify the given - module. - """ - for layer in module: - layer_copy = copy.deepcopy(layer) - layer_copy.to(device) - layer_copy.train() - yield layer_copy - - -def detach(batch: Batch) -> None: - """Detaches from autograd graph.""" - for i, x in enumerate(batch): - batch[i] = x.detach().requires_grad_(x.requires_grad) - - -def profile_times(module: nn.Sequential, sample: Union[List[Any], Tensor], timeout: float, device: torch.device,) -> List[int]: - """Profiles elapsed times per layer.""" - if any(p.grad is not None for p in module.parameters()): - raise ValueError("some parameter already has gradient") - - _batch = Batch(sample) - for i, x in enumerate(_batch): - _batch[i] = x.detach().to(device).requires_grad_(x.requires_grad) - - time_bufs: List[List[float]] = [[] for _ in module] - begun_at = time.time() - - while time.time() - begun_at < timeout: - batch = _batch - - for i, layer in enumerate(layerwise_sandbox(module, device)): - detach(batch) - - if device.type == "cuda": - torch.cuda.synchronize(device) - tick = time.time() - - # Forward - batch = batch.call(layer) - - # Backward - backward_tensors = tuple(y for y in batch if y.requires_grad) - if backward_tensors: - torch.autograd.backward(backward_tensors, backward_tensors) - - if device.type == "cuda": - torch.cuda.synchronize(device) - tock = time.time() - - time_bufs[i].append(tock - tick) - - us = 1_000_000 - return [sum(int(t * us) for t in buf) for buf in time_bufs] - - -def profile_sizes( - module: nn.Sequential, input: Union[List[Any], Tensor], chunks: int, param_scale: float, device: torch.device, -) -> List[int]: - """Profiles CUDA memory usage per layer.""" - if device.type != "cuda": - raise ValueError("size profiler supports only CUDA device") - - batch = Batch(input) - sizes: List[int] = [] - - latent_scale = batch[0].size(0) / chunks - for i, x in enumerate(batch): - batch[i] = x[:1].detach().to(device).requires_grad_(x.requires_grad) - - for layer in layerwise_sandbox(module, device): - detach(batch) - - # Detect memory usage at forward. - torch._C._cuda_clearCublasWorkspaces() - memory_before = torch.cuda.memory_allocated(device) - batch = batch.call(layer) - torch._C._cuda_clearCublasWorkspaces() - memory_after = torch.cuda.memory_allocated(device) - latent_size = memory_after - memory_before - - # Analyze size of parameters. - param_size = sum(p._typed_storage()._nbytes() for p in layer.parameters()) - - # Combine size of parameters and activations with normalize scales. - size = latent_size * latent_scale + param_size * param_scale - sizes.append(int(size)) - - return sizes diff --git a/torch/distributed/pipeline/sync/_balance/py.typed b/torch/distributed/pipeline/sync/_balance/py.typed deleted file mode 100644 index ab03724cafbf..000000000000 --- a/torch/distributed/pipeline/sync/_balance/py.typed +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. diff --git a/torch/distributed/pipeline/sync/batchnorm.py b/torch/distributed/pipeline/sync/batchnorm.py deleted file mode 100644 index 868ad50cf3fc..000000000000 --- a/torch/distributed/pipeline/sync/batchnorm.py +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Tracks the running statistics per mini-batch instead of micro-batch.""" -from typing import TypeVar, Optional, cast - -import torch -from torch import Tensor, nn -from torch.nn.functional import batch_norm -from torch.nn.modules.batchnorm import _BatchNorm - -from .checkpoint import is_recomputing - -__all__ = ["DeferredBatchNorm"] - - -TModule = TypeVar("TModule", bound=nn.Module) - - -class DeferredBatchNorm(_BatchNorm): - """A BatchNorm layer tracks multiple micro-batches to update running statistics per mini-batch.""" - - sum: Tensor - sum_squares: Tensor - running_mean: Tensor - running_var: Tensor - num_batches_tracked: Tensor - - def __init__( - self, - num_features: int, - eps: float = 1e-5, - momentum: Optional[float] = 0.1, - affine: bool = True, - chunks: int = 1, - ) -> None: - super().__init__(num_features, eps, momentum, affine, track_running_stats=True) - - self.register_buffer("sum", torch.zeros_like(self.running_mean)) - self.register_buffer("sum_squares", torch.zeros_like(self.running_var)) - - self.counter = 0 - self.tracked = 0 - self.chunks = chunks - - def _check_input_dim(self, input: Tensor) -> None: - # It's the typical _check_input_dim() implementation in PyTorch. - if input.dim() <= 2: - raise ValueError("expected at least 3D input (got %dD input)" % input.dim()) - - def _track(self, input: Tensor) -> bool: - """Tracks statistics of a micro-batch.""" - # Dimensions except channel. For example, (0, 2, 3) is for BatchNorm2d. - dim = [0] - dim.extend(range(2, input.dim())) - - with torch.no_grad(): - self.sum += input.sum(dim) - self.sum_squares += (input ** 2).sum(dim) - - size = input.size().numel() // input.size(1) - self.counter += size - self.tracked += 1 - - return self.tracked == self.chunks - - def _commit(self) -> None: - """Update the running statistics of a mini-batch.""" - exponential_average_factor = 0.0 - self.num_batches_tracked += 1 - if self.momentum is None: # use cumulative moving average - exponential_average_factor = 1.0 / float(self.num_batches_tracked) - else: # use exponential moving average - exponential_average_factor = self.momentum - - mean = self.sum / self.counter - var = self.sum_squares / self.counter - mean ** 2 - - # Calculate the exponential moving average here. - m = exponential_average_factor - - self.running_mean *= 1 - m - self.running_mean += mean * m - - self.running_var *= 1 - m - self.running_var += var * m - - self.sum.zero_() - self.sum_squares.zero_() - self.counter = 0 - self.tracked = 0 - - def forward(self, input: Tensor) -> Tensor: - if not self.training: - # Don't train parameters on the evaluation mode. - return batch_norm( - input, - running_mean=self.running_mean, - running_var=self.running_var, - weight=self.weight, - bias=self.bias, - training=False, - momentum=0.0, - eps=self.eps, - ) - - if not is_recomputing(): - # Track a micro-batch on the training mode - # but not under a recomputation. - tracked_enough = self._track(input) - - # Update the running statistics for a mini-batch - # if it has tracked enough micro-batches. - if tracked_enough: - self._commit() - - # Normalize a micro-batch and train the parameters. - return batch_norm( - input, - running_mean=None, - running_var=None, - weight=self.weight, - bias=self.bias, - training=True, - momentum=0.0, - eps=self.eps, - ) - - @classmethod - def convert_deferred_batch_norm(cls, module: TModule, chunks: int = 1) -> TModule: - """Converts a :class:`nn.BatchNorm` or underlying :class:`nn.BatchNorm`s into :class:`DeferredBatchNorm`:: - - from torchvision.models.resnet import resnet101 - from torchpipe.batchnorm import DeferredBatchNorm - model = resnet101() - model = DeferredBatchNorm.convert_deferred_batch_norm(model) - - """ - if isinstance(module, DeferredBatchNorm) and module.chunks is chunks: - return cast(TModule, module) - - module_output: nn.Module = module - - if isinstance(module, _BatchNorm) and module.track_running_stats: - module_output = DeferredBatchNorm(module.num_features, module.eps, module.momentum, module.affine, chunks) - if module.affine: - module_output.register_parameter("weight", module.weight) - module_output.register_parameter("bias", module.bias) - module_output.register_buffer("running_mean", module.running_mean) - module_output.register_buffer("running_var", module.running_var) - module_output.register_buffer("num_batches_tracked", module.num_batches_tracked) - - for name, child in module.named_children(): - module_output.add_module(name, cls.convert_deferred_batch_norm(child, chunks)) - - return cast(TModule, module_output) diff --git a/torch/distributed/pipeline/sync/checkpoint.py b/torch/distributed/pipeline/sync/checkpoint.py deleted file mode 100644 index e67da2499d57..000000000000 --- a/torch/distributed/pipeline/sync/checkpoint.py +++ /dev/null @@ -1,364 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Checkpointing with preceding recomputation. - -PyTorch already provides the official checkpointing utilities in -:mod:`torch.utils.checkpoint`. The official checkpointing combines -recomputation and recursive backpropagation into one autograd function named -``CheckpointFunction``. Hence, the recomputation can be started only when the -gradients arrive to the function. In Pipe, the recomputation needs to precede -the gradient arrival to minimize the GPU idle time. - -We solve this problem by introducing separate autograd functions named -:class:`Recompute` and :class:`Checkpoint`. Each function represents -recomputation and recursive backpropagation, respectively. We can manipulate -the control flow in aspect of both the autograd engine and CUDA with a pair of -the functions. - -Specifically, we place CUDA stream synchronization between :class:`Recompute` -and :class:`Checkpoint` to delay only :class:`Checkpoint` until the gradient is -copied entirely. - -""" -from collections import deque -from contextlib import contextmanager -import threading -from typing import ( - Any, - Deque, - Generator, - List, - Optional, - Protocol, - Union, - Sequence, - Tuple -) - -import torch -from torch import Tensor -import torch.autograd - -from .dependency import fork, join -from .microbatch import Batch -from .phony import get_phony - -__all__ = ["Function", "checkpoint", "Checkpointing", "ThreadLocal", "enable_checkpointing", - "enable_recomputing", "is_checkpointing", "is_recomputing", "Context", "save_rng_states", - "restore_rng_states", "Checkpoint", "Recompute"] - - -Tensors = Sequence[Tensor] -TensorOrTensors = Union[Tensor, Tensors] - -# Types for shared memory between Checkpoint and Recompute. -Recomputed = Tuple[TensorOrTensors, Tensors] # (output, input_leaf) -RNGStates = Tuple[Tensor, Optional[Tensor]] # (cpu_rng_state, gpu_rng_state) - - -# Protocol with __call__ instead of Callable can be used as an attribute type. -# See: https://github.com/python/mypy/issues/708#issuecomment-561735949 -class Function(Protocol): - def __call__(self, input: TensorOrTensors) -> TensorOrTensors: - ... - - -def checkpoint(function: Function, input): - """Make a checkpoint with a simple interface like - :func:`torch.utils.checkpoint.checkpoint`. It's only used to test or debug - :class:`Checkpoint` and :class:`Recompute` without boilerplate. - """ - batch = Batch(input) - - chk = Checkpointing(function, batch) - batch = chk.checkpoint() - chk.recompute(batch) - - return batch.values - - -class Checkpointing: - """Generates a pair of :class:`Checkpoint` and :class:`Recompute`.""" - - def __init__(self, function: Function, batch: Batch) -> None: - self.function = function - self.batch = batch - - # Shared memory between Checkpoint and Recompute. 1-length deque is - # used for mutability and length limitation. - self.recomputed: Deque[Recomputed] = deque(maxlen=1) - self.rng_states: Deque[RNGStates] = deque(maxlen=1) - - def checkpoint(self) -> Batch: - """Return a batch applied by :class:`Checkpoint`.""" - input_atomic = self.batch.atomic - inputs = tuple(self.batch) - - # Use a phony which requires grad to ensure that Checkpoint can be - # tracked by the autograd engine even when none of the input tensors - # require grad. - phony = get_phony(self.batch.get_device(), requires_grad=True) - - output = Checkpoint.apply(phony, self.recomputed, self.rng_states, self.function, input_atomic, *inputs) - - # Gradients are only supported for float Tensors. - if isinstance(output, tuple): - output = tuple([x.detach() if torch.is_tensor(x) and not x.is_floating_point() else x for x in output]) - - return Batch(output) - - def recompute(self, batch: Batch) -> None: - """Apply :class:`Recompute` to the batch in place.""" - input_atomic = self.batch.atomic - inputs = tuple(self.batch) - - # Use a tensor in the batch to tie together fork-join - tensor_idx = batch.find_tensor_idx() - # batch[tensor_idx] is always requiring grad, because it has been passed - # checkpoint with a phony requiring grad. - batch[tensor_idx], phony = fork(batch[tensor_idx]) - phony = Recompute.apply(phony, self.recomputed, self.rng_states, self.function, input_atomic, *inputs) - batch[tensor_idx] = join(batch[tensor_idx], phony) - - -class ThreadLocal(threading.local): - def __init__(self) -> None: - self.is_checkpointing = False - self.is_recomputing = False - - -thread_local = ThreadLocal() - - -@contextmanager -def enable_checkpointing() -> Generator[None, None, None]: - """Make :func:`is_checkpointing` return :data:`True` within a context.""" - orig = thread_local.is_checkpointing - thread_local.is_checkpointing = True - try: - yield - finally: - thread_local.is_checkpointing = orig - - -@contextmanager -def enable_recomputing() -> Generator[None, None, None]: - """Makes :func:`is_recomputing` return :data:`True` within a context.""" - orig = thread_local.is_recomputing - thread_local.is_recomputing = True - try: - yield - finally: - thread_local.is_recomputing = orig - - -def is_checkpointing() -> bool: - """Whether the current forward propagation is under checkpointing. - - Returns: - bool: :data:`True` if it's under checkpointing. - - """ - return thread_local.is_checkpointing - - -def is_recomputing() -> bool: - """Whether the current forward propagation is under checkpoint recomputation. - - Use this to prevent duplicated side-effects at forward - propagation:: - - class Counter(nn.Module): - def __init__(self): - super().__init__() - self.counter = 0 - - def forward(self, input): - if not is_recomputing(): - self.counter += 1 - return input - - Returns: - bool: :data:`True` if it's under checkpoint recomputation. - - .. seealso:: :ref:`Detecting Recomputation` - - """ - return thread_local.is_recomputing - - -class Context: - """The common interface between the :class:`Checkpoint` and :class:`Recompute` context.""" - - recomputed: Deque[Recomputed] - rng_states: Deque[RNGStates] - function: Function - input_atomic: bool - inputs: Sequence[Any] - - saved_tensors: Tuple[Tensor, ...] - - def save_for_backward(self, *tensors: Tensor) -> None: # pragma: no cover - pass - - -def save_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> None: - """: - Capture the current random number generator states. - - meth:`Checkpoint.forward` captures the current PyTorch's random number - generator states at CPU and GPU to reuse in :meth:`Recompute.backward`. - - .. seealso:: :ref:`Referential Transparency` - - """ - cpu_rng_state = torch.get_rng_state() - - gpu_rng_state: Optional[Tensor] - if device.type == "cuda": - gpu_rng_state = torch.cuda.get_rng_state(device) - else: - gpu_rng_state = None - - rng_states.append((cpu_rng_state, gpu_rng_state)) - - -@contextmanager -def restore_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> Generator[None, None, None]: - """: - Restore the random number generator state. - - meth:`Recompute.backward` restores the random number generator states - captured by :func:`save_rng_states` within its context. - - .. seealso:: :ref:`Referential Transparency` - - """ - cpu_rng_state, gpu_rng_state = rng_states.pop() - - gpu_devices: List[torch.device] = [] - if device.type == "cuda": - gpu_devices.append(device) - - with torch.random.fork_rng(gpu_devices): - torch.set_rng_state(cpu_rng_state) - if gpu_rng_state is not None: - torch.cuda.set_rng_state(gpu_rng_state, device) - yield - - -class Checkpoint(torch.autograd.Function): - @staticmethod - # type: ignore[override] - def forward( - ctx: Context, - phony: Tensor, - recomputed: Deque[Recomputed], - rng_states: Deque[RNGStates], - function: Function, - input_atomic: bool, - *inputs, - ): - ctx.recomputed = recomputed - ctx.rng_states = rng_states - - save_rng_states(phony.device, ctx.rng_states) - - ctx.function = function - ctx.input_atomic = input_atomic - if input_atomic: - tensors = [inputs[0]] - else: - tensors = [] - for input in inputs: - if torch.is_tensor(input): - tensors.append(input) - - ctx.save_for_backward(*tensors) - - with torch.no_grad(), enable_checkpointing(): - if input_atomic: - assert len(inputs) == 1 - output = function(inputs[0]) - else: - output = function(*inputs) - return output - - @staticmethod - def backward(ctx: Context, *grad_output: Tensor,) -> Tuple[Optional[Tensor], ...]: # pragma: no cover - output, input_leaf = ctx.recomputed.pop() - - if isinstance(output, tuple): - outputs = output - else: - outputs = (output,) - if any(torch.is_tensor(y) and y.requires_grad for y in outputs): - tensors = tuple([x for x in outputs if torch.is_tensor(x) and x.requires_grad]) - torch.autograd.backward(tensors, grad_output) - - grad_input: List[Optional[Tensor]] = [None, None, None, None, None] - grad_input.extend(x.grad if torch.is_tensor(x) else None for x in input_leaf) - return tuple(grad_input) - - -class Recompute(torch.autograd.Function): - @staticmethod - # type: ignore[override] - def forward( - ctx: Context, - phony: Tensor, - recomputed: Deque[Recomputed], - rng_states: Deque[RNGStates], - function: Function, - input_atomic: bool, - *inputs, - ) -> Tensor: - ctx.recomputed = recomputed - ctx.rng_states = rng_states - - ctx.function = function - ctx.input_atomic = input_atomic - ctx.inputs = inputs - if input_atomic: - tensors = [inputs[0]] - else: - tensors = [] - for input in inputs: - if torch.is_tensor(input): - tensors.append(input) - ctx.save_for_backward(*tensors) - - return phony - - @staticmethod - def backward(ctx: Context, *grad_output: Tensor) -> Tuple[None, ...]: # pragma: no cover - inputs = ctx.inputs - inputs_leaf = tuple(x.detach().requires_grad_(x.requires_grad) if torch.is_tensor(x) else x for x in inputs) - - # Get the device for the inputs from a tensor - device = None - for input in inputs: - if torch.is_tensor(input): - device = input.device - break - - if device is None: - raise RuntimeError(f'No tensors found in {inputs}') - - with restore_rng_states(device, ctx.rng_states): - with torch.enable_grad(), enable_recomputing(): - if ctx.input_atomic: - assert len(inputs_leaf) == 1 - output = ctx.function(inputs_leaf[0]) - else: - output = ctx.function(*inputs_leaf) - - ctx.recomputed.append((output, inputs_leaf)) - - grad_input: List[None] = [None, None, None, None, None] - grad_input.extend(None for _ in ctx.inputs) - return tuple(grad_input) diff --git a/torch/distributed/pipeline/sync/copy.py b/torch/distributed/pipeline/sync/copy.py deleted file mode 100644 index b717f0c2932c..000000000000 --- a/torch/distributed/pipeline/sync/copy.py +++ /dev/null @@ -1,108 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Autograd functions for stream-aware CUDA copy. - -It is used to overlap copy and computation on the same GPU. -""" -from collections import deque -from typing import Deque, List, Optional, Tuple, Sequence - -import torch -from torch import Tensor - -from .stream import AbstractStream, current_stream, get_device, record_stream, use_stream, wait_stream - -__all__: List[str] = ["Context", "Copy", "Wait"] - - -Tensors = Sequence[Tensor] - - -# Common interface between :class:`Copy` and :class:`Wait`. -class Context: - prev_stream: AbstractStream - next_stream: AbstractStream - - -class Copy(torch.autograd.Function): - """Copies tensors on specific streams.""" - - @staticmethod - # type: ignore[override] - def forward(ctx: Context, prev_stream: AbstractStream, next_stream: AbstractStream, *input,) -> Tensors: - ctx.prev_stream = prev_stream - ctx.next_stream = next_stream - - output = [] - output_stream = current_stream(get_device(next_stream)) - - with use_stream(prev_stream), use_stream(next_stream): - for x in input: - if torch.is_tensor(x): - y = x.to(get_device(next_stream), non_blocking=True) - output.append(y) - - # 'prev_stream' is not where 'x' has been allocated. - record_stream(x, prev_stream) - # 'y' has been allocated on 'next_stream'. - # It might be used on the current stream captured as 'output_stream'. - record_stream(y, output_stream) - else: - output.append(x) - - return tuple(output) - - @staticmethod - def backward(ctx: Context, *grad_output: Tensor,) -> Tuple[Optional[Tensor], ...]: - prev_stream = ctx.prev_stream - next_stream = ctx.next_stream - - grad_input: Deque[Tensor] = deque(maxlen=len(grad_output)) - input_stream = current_stream(get_device(prev_stream)) - - with use_stream(prev_stream), use_stream(next_stream): - for x in reversed(grad_output): - y = x.to(get_device(prev_stream), non_blocking=True) - grad_input.appendleft(y) - - # 'next_stream' is not where 'x' has been allocated. - record_stream(x, next_stream) - # 'y' has been allocated on 'prev_stream'. - # It might be used on the current stream captured as 'input_stream'. - record_stream(y, input_stream) - - grad_streams: Tuple[Optional[Tensor], ...] = (None, None) - return grad_streams + tuple(grad_input) - - -class Wait(torch.autograd.Function): - """Synchronizes a stream to another stream. - - Place it just before you want to start an operation on the next stream, - provided that all operations on the previous stream are done. - - """ - - @staticmethod - # type: ignore[override] - def forward(ctx: Context, prev_stream: AbstractStream, next_stream: AbstractStream, *input) -> Tensors: - ctx.prev_stream = prev_stream - ctx.next_stream = next_stream - - wait_stream(next_stream, prev_stream) - - return tuple(x.detach() if torch.is_tensor(x) else x for x in input) - - @staticmethod - def backward(ctx: Context, *grad_input: Tensor,) -> Tuple[Optional[Tensor], ...]: - prev_stream = ctx.prev_stream - next_stream = ctx.next_stream - - wait_stream(prev_stream, next_stream) - - grad_streams: Tuple[Optional[Tensor], ...] = (None, None) - return grad_streams + grad_input diff --git a/torch/distributed/pipeline/sync/dependency.py b/torch/distributed/pipeline/sync/dependency.py deleted file mode 100644 index ca5c69e388fe..000000000000 --- a/torch/distributed/pipeline/sync/dependency.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Arbitrary dependency between two autograd lanes.""" -from typing import List, Tuple - -import torch -from torch import Tensor - -from .phony import get_phony - -__all__: List[str] = ["fork", "Fork", "join", "Join"] - - -def fork(input: Tensor) -> Tuple[Tensor, Tensor]: - """Branches out from an autograd lane of the given tensor.""" - if torch.is_grad_enabled() and input.requires_grad: - input, phony = Fork.apply(input) - else: - phony = get_phony(input.device, requires_grad=False) - - return input, phony - - -class Fork(torch.autograd.Function): - @staticmethod - def forward(ctx: "Fork", input: Tensor) -> Tuple[Tensor, Tensor]: # type: ignore[override] - phony = get_phony(input.device, requires_grad=False) - return input.detach(), phony.detach() - - @staticmethod - def backward(ctx: "Fork", grad_input: Tensor, grad_grad: Tensor) -> Tensor: # type: ignore[override] - return grad_input - - -def join(input: Tensor, phony: Tensor) -> Tensor: - """Merge two autograd lanes.""" - if torch.is_grad_enabled() and (input.requires_grad or phony.requires_grad): - input = Join.apply(input, phony) - - return input - - -class Join(torch.autograd.Function): - @staticmethod - def forward(ctx: "Join", input: Tensor, phony: Tensor) -> Tensor: # type: ignore[override] - return input.detach() - - @staticmethod - def backward(ctx: "Join", grad_input: Tensor) -> Tuple[Tensor, None]: # type: ignore[override] - return grad_input, None diff --git a/torch/distributed/pipeline/sync/microbatch.py b/torch/distributed/pipeline/sync/microbatch.py deleted file mode 100644 index 5b8aca257548..000000000000 --- a/torch/distributed/pipeline/sync/microbatch.py +++ /dev/null @@ -1,234 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Manipulation of micro-batches.""" -import typing -from typing import Any, Callable, List, Union, cast, Sequence - -import torch -from torch import Tensor -import torch.cuda.comm - -__all__: List[str] = ["NoChunk", "Batch", "check", "scatter", "gather"] - - -Tensors = Sequence[Tensor] -TensorOrTensors = Union[Tensor, Tensors] -Function = Callable[[TensorOrTensors], Union[List[Any], Tensor]] - - -class NoChunk: - """ - Wrapper for a Tensor in :meth:`Pipe.forward` indicating that the tensor - should not be chunked on the batch dimension and instead be replicated - as-is across all micro-batches. This is useful for tensors which might - not have any 'batch' semantics for the model. - """ - def __init__(self, inp: Tensor): - if not torch.is_tensor(inp): - raise TypeError(f'NoChunk only supported for tensors, found: {inp}') - self._tensor = inp - - @property - def tensor(self): - return self._tensor - - -class Batch: - """ - An abstraction representing a microbatch in the pipeline. - """ - - def __init__(self, values: Union[List[Any], Tensor]) -> None: - self._values = values - self.atomic = torch.is_tensor(values) - - # Verify at least on tensor - if not self.atomic: - if not any(torch.is_tensor(value) for value in self._values): - raise TypeError(f'No tensors found in batch: {self._values}') - - @property - def tensor(self) -> Tensor: - """Retrieves the underlying tensor.""" - if not self.atomic: - raise AttributeError("not atomic batch") - return cast(Tensor, self._values) - - @property - def values(self): - """Retrieves the underlying values for the batch""" - return self._values - - def find_tensor_idx(self): - """ - Retrieves the index of first tensor found. - """ - if self.atomic: - return 0 - for i, value in enumerate(self._values): - if torch.is_tensor(value): - return i - - raise TypeError("No tensor found!") - - def get_device(self): - """ - Retrieves the device for this microbatch. - """ - if self.atomic: - return self._values.device # type: ignore[union-attr] - - for value in self._values: - if torch.is_tensor(value): - return value.device - - def call(self, function: Function) -> "Batch": - """Calls a function on the microbatch. It also wraps - the output with :class:`Batch`. - """ - if self.atomic: - return Batch(function(self._values)) - else: - return Batch(function(*self._values)) - - def __repr__(self) -> str: - return f"Batch[atomic={self.atomic!r}]({self._values!r})" - - def __iter__(self): - if self.atomic: - yield self._values - else: - yield from self._values - - def __len__(self) -> int: - return 1 if self.atomic else len(self._values) - - def __getitem__(self, index: int): - if not self.atomic: - return self._values[index] - - if index != 0: - raise IndexError("atomic batch allows index 0 only") - - return self._values - - # NOTE(sublee): pyflakes can't detect "overload" instead of "typing.overload". - @typing.overload - def __setitem__(self, index: int, value: Tensor) -> None: - ... - - @typing.overload - def __setitem__(self, index: slice, value: Tensors) -> None: - ... - - def __setitem__(self, index: Union[int, slice], value) -> None: - if isinstance(index, int): - self._setitem_by_index(index, value) - else: - self._setitem_by_slice(index, value) - - def _setitem_by_index(self, index: int, value) -> None: - if not self.atomic: - i = index - self._values = self._values[:i] + (value,) + self._values[i + 1 :] # type: ignore[operator] - return - - if index != 0: - raise IndexError("atomic batch allows index 0 only") - - self._values = value - - def _setitem_by_slice(self, index: slice, value) -> None: - if not (index.start is index.stop is index.step is None): # noqa: E714 - raise NotImplementedError("only slice [:] supported") - - if not self.atomic: - self._values = value - return - - if len(value) != 1: - raise IndexError("atomic batch cannot be replaced with multiple tensors") - - self._values = value[0] - - -def check(first_device, *inputs) -> None: - """ - Checks whether the input contains at least one tensor and each tensor is - on the same device as the first partition. - - Raises: - ValueError: input does not contain at least one tensor - - """ - - if not any(torch.is_tensor(input) for input in inputs): - raise TypeError(f'inputs do not have any tensors: {inputs}') - if any(torch.is_tensor(input) and input.device != first_device for input in inputs): - raise ValueError('All inputs should be on the same device as the first partition') - - -def scatter(*inputs, chunks: int) -> List[Batch]: - """Splits an input mini-batch into multiple micro-batches.""" - if len(inputs) == 1 and isinstance(inputs[0], Tensor): - return [Batch(x) for x in inputs[0].chunk(chunks)] - - batches: List[Any] = [[] for _ in range(chunks)] - # Actual number of chunks produced - num_chunks = -1 - for input in inputs: - if torch.is_tensor(input): - # Chunk only tensors. - tensors = input.chunk(chunks) - - # Validate number of chunks equal across all inputs. - if num_chunks != -1 and num_chunks != len(tensors): - raise RuntimeError(f'Found different number of chunks produced for inputs: {num_chunks} and {len(tensors)}') - num_chunks = len(tensors) - - for i, tensor in enumerate(tensors): - batches[i].append(tensor) - else: - # Replicate non-tensors or tensors wrapped with 'NoChunk'. - for i in range(chunks): - if isinstance(input, NoChunk): - # Extract the tensor out. - batches[i].append(input.tensor) - else: - batches[i].append(input) - - # Truncate to actual number of chunks - batches = batches[:num_chunks] - - return [Batch(x) for x in batches] - - -def gather(outputs: List[Batch]): - """Concatenates output micro-batches into a mini-batch.""" - output: Any - - if outputs[0].atomic: - tensors = tuple(b.tensor for b in outputs) - output = torch.cat(tensors) - else: - output_buf: List[Any] = [] - for i in range(len(outputs[0])): - output_type = type(outputs[0][i]) - current_outputs = [] - for batch in outputs: - if output_type != type(batch[i]): - raise TypeError(f'Types for microbatch outputs do not match, found: {output_type} and {type(batch[i])}') - current_outputs.append(batch[i]) - - if torch.is_tensor(outputs[0][i]): - output_buf.append(torch.cat(current_outputs)) - else: - output_buf.append(current_outputs) - - output = tuple(output_buf) - - return output diff --git a/torch/distributed/pipeline/sync/phony.py b/torch/distributed/pipeline/sync/phony.py deleted file mode 100644 index 012926699cfb..000000000000 --- a/torch/distributed/pipeline/sync/phony.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Provides phony for arbitrary dependency in a autograd graph.""" -from typing import Dict, List, Tuple - -import torch -from torch import Tensor - -from .stream import default_stream, use_stream - -__all__: List[str] = ["get_phony"] - - -_phonies: Dict[Tuple[torch.device, bool], Tensor] = {} - - -def get_phony(device: torch.device, *, requires_grad: bool) -> Tensor: - """Get a phony. Phony is tensor without space. - - It is useful to make arbitrary dependency in a autograd graph because it doesn't require any - gradient accumulation. - - .. note:: - - Phonies for each device are cached. If an autograd function gets a phony - internally, the phony must be detached to be returned. Otherwise, the - autograd engine will mutate the cached phony in-place:: - - class Phonify(torch.autograd.Function): - @staticmethod - def forward(ctx, input): - phony = get_phony(input.device, requires_grad=False) - return phony.detach() # detach() is necessary. - - """ - key = (device, requires_grad) - - try: - phony = _phonies[key] - except KeyError: - with use_stream(default_stream(device)): - phony = torch.empty(0, device=device, requires_grad=requires_grad) - - _phonies[key] = phony - - return phony diff --git a/torch/distributed/pipeline/sync/pipe.py b/torch/distributed/pipeline/sync/pipe.py deleted file mode 100644 index 5e61341d9ad9..000000000000 --- a/torch/distributed/pipeline/sync/pipe.py +++ /dev/null @@ -1,490 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""The Pipe interface.""" -from collections import OrderedDict -from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Union, Sequence, Tuple, cast - -import torch -from torch import Tensor, nn -from torch.distributed.rpc import RRef -import torch.autograd -import torch.cuda - -from . import microbatch -from .batchnorm import DeferredBatchNorm -from .pipeline import Pipeline -from .skip.layout import inspect_skip_layout -from .skip.skippable import verify_skippables -from .stream import AbstractStream, new_stream - -__all__ = ["Pipe", "BalanceError", "PipeSequential", "WithDevice"] - - -Device = Union[torch.device, int, str] -Devices = Union[Iterable[Device], List[Device]] - -Tensors = Sequence[Tensor] -TensorOrTensors = Union[Tensor, Tensors] - -if TYPE_CHECKING: - # Typechecking: nn.Module is not a Generic - Module = nn.Module[TensorOrTensors] # type: ignore[type-arg] - NamedModules = OrderedDict[str, Module] -else: - Module = nn.Module - NamedModules = OrderedDict - - -def _recommend_auto_balance(message: str) -> str: - """Expands a message with recommendation to :mod:`torchpipe.balance`.""" - return f"""{message} - -If your model is still under development, its optimal balance would change -frequently. In this case, we highly recommend 'torch.distributed.pipeline.sync.balance' for -naive automatic balancing: - - from torch.distributed.pipeline.sync import Pipe - from torch.distributed.pipeline.sync.balance import balance_by_time - - partitions = torch.cuda.device_count() - sample = torch.empty(...) - balance = balance_by_time(partitions, model, sample) - - model = Pipe(model, balance, ...) -""" - - -def _verify_module(module: nn.Sequential) -> None: - if not isinstance(module, nn.Sequential): - raise TypeError("module must be nn.Sequential to be partitioned") - - named_children = list(module.named_children()) - if len(named_children) != len(module): - raise ValueError("module with duplicate children is not supported") - - -def _verify_splitting( - module: nn.Sequential, partitions: List[nn.Sequential], devices: List[torch.device] -) -> None: - num_parameters = len(list(module.parameters())) - num_child_parameters = sum(len(list(child.parameters())) for child in module.children()) - if num_parameters == num_child_parameters: - return - - for i in range(len(partitions)): - for j in range(i + 1, len(partitions)): - parti = partitions[i] - partj = partitions[j] - if devices[i] == devices[j]: - continue - for p in parti.parameters(): - for q in partj.parameters(): - if p is q: - raise ValueError("module with duplicate parameters on distinct devices is not supported") - - -class BalanceError(ValueError): - pass - - -def _retrieve_device(module: nn.Module) -> torch.device: - """Validates all parameters in the Module have the same device and returns - the appropriate device. - - Args: - An ``nn.Module`` to process. - - Returns: - ``torch.Device`` for the entire module. - - Raises: - ValueError: - If devices for ``nn.Module`` parameters are not all same. - """ - - device = None - for parameter in module.parameters(): - if device is None: - device = parameter.device - elif device != parameter.device: - raise ValueError( - f'nn.Module: {module}, should have all parameters on a single device,' - ' please use .to() to place the module on a single device') - - return device if device is not None else torch.device("cpu") - - -class PipeSequential(nn.Sequential): - """ - Pipe variant of ``nn.Sequential`` which supports multiple inputs. - """ - - def forward(self, *inputs): - for module in self: - if isinstance(inputs, Tuple): # type: ignore[arg-type] - inputs = module(*inputs) - else: - # Don't expand single variables (ex: lists/Tensor) - inputs = module(inputs) - return inputs - - -class WithDevice(nn.Module): - """ - Wraps an ``nn.Module`` which is part of ``nn.Sequential`` passed into :class:`Pipe` - that overrides the device for that module. In cases where :class:`Pipe` - can't implicitly determine the device for the module and places it on CPU, - this wrapper can be used to override the implicit behavior and explicitly - specify which device a module should run on. - - The provided module is also moved to the given device via ``.to(device)`` - by :class:`Pipe` - - Args: - module(:class:`torch.nn.Module`): The module to be wrapped. - device(:class:`torch.device`): The device to run the module on. - - Example:: - >>> # xdoctest: +SKIP("distributed") - >>> fc1 = nn.Linear(16, 8).cuda(0) - >>> fc2 = nn.Linear(8, 4).cuda(1) - >>> dropout = nn.Dropout() - >>> - >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1) - >>> # Dropout does not have any parameters/buffers, but we want to - >>> # run it on cuda:1 to avoid any GPU to CPU transfers. - >>> model = nn.Sequential(fc1, fc2, WithDevice(dropout, 'cuda:1')) - >>> # xdoctest: +SKIP("Needs RPC framework init") - >>> model = Pipe(model, chunks=8) - """ - def __init__(self, module: nn.Module, device: torch.device): - super().__init__() - self._module = module - self._device = torch.device(device) - - def forward(self, *args, **kwargs): - return self._module(*args, **kwargs) - - @property - def module(self): - return self._module - - @property - def device(self): - return self._device - - -def _assemble_partition(modules: List[nn.Module]): - modules_list: List[nn.Module] = [] - for module in modules: - if isinstance(module, nn.Sequential): - modules_list.extend(module.children()) - else: - modules_list.append(module) - return PipeSequential(*modules_list) - - -def _split_module(modules: nn.Sequential) -> Tuple[List[nn.Sequential], List[torch.device]]: - partitions = [] - devices = [] - - current_partition = [] - current_device = None - for name, module in modules.named_children(): - if isinstance(module, WithDevice): - # Process device override and move module to appropriate device. - device = module.device - module = module.module - module.to(device) - else: - device = _retrieve_device(module) - if current_device is not None and (current_device != device or device.type == 'cpu'): - partitions.append(_assemble_partition(current_partition)) - devices.append(current_device) - current_partition = [] - current_device = device - current_partition.append(module) - - if current_device is not None: - partitions.append(_assemble_partition(current_partition)) - devices.append(current_device) - - partitions = cast(List[nn.Sequential], nn.ModuleList(partitions)) - - return partitions, devices - - -MOVING_DENIED = TypeError("denied to move parameters and buffers, because Pipe should manage device placement") - - -class Pipe(Module): - """Wraps an arbitrary :class:`nn.Sequential ` module - to train on using synchronous pipeline parallelism. If the module requires - lots of memory and doesn't fit on a single GPU, pipeline parallelism is a - useful technique to employ for training. - - The implementation is based on the torchgpipe_ paper. - - .. _torchgpipe: https://arxiv.org/abs/2004.09910 - - Pipe combines pipeline parallelism with checkpointing to reduce peak - memory required to train while minimizing device under-utilization. - - You should place all the modules on the appropriate devices and wrap them - into an :class:`nn.Sequential ` module defining the - desired order of execution. If a module does not contain any - parameters/buffers, it is assumed this module should be executed on CPU - and appropriate input tensors to the module are moved to CPU before - execution. This behavior can be overridden by the :class:`WithDevice` - wrapper which can be used to explicitly specify which device a module - should run on. - - Args: - module (:class:`nn.Sequential `): - sequential module to be parallelized using pipelining. Each module - in the sequence has to have all of its parameters on a single - device. Each module in the sequence has to either be an nn.Module - or :class:`nn.Sequential ` (to combine multiple - sequential modules on a single device) - chunks (int): - number of micro-batches (default: ``1``) - checkpoint (str): - when to enable checkpointing, one of ``'always'``, - ``'except_last'``, or ``'never'`` (default: ``'except_last'``). - ``'never'`` disables checkpointing completely, ``'except_last'`` - enables checkpointing for all micro-batches except the last one - and ``'always'`` enables checkpointing for all micro-batches. - deferred_batch_norm (bool): - whether to use deferred ``BatchNorm`` moving statistics (default: - :data:`False`). If set to :data:`True`, we track statistics across - multiple micro-batches to update the running statistics per - mini-batch. - - Raises: - TypeError: - the module is not a :class:`nn.Sequential `. - ValueError: - invalid arguments - - Example:: - Pipeline of two FC layers across GPUs 0 and 1. - - >>> # Need to initialize RPC framework first. - >>> # xdoctest: +SKIP - >>> os.environ['MASTER_ADDR'] = 'localhost' - >>> os.environ['MASTER_PORT'] = '29500' - >>> torch.distributed.rpc.init_rpc('worker', rank=0, world_size=1) - >>> - >>> # Build pipe. - >>> fc1 = nn.Linear(16, 8).cuda(0) - >>> fc2 = nn.Linear(8, 4).cuda(1) - >>> model = nn.Sequential(fc1, fc2) - >>> model = Pipe(model, chunks=8) - >>> input = torch.rand(16, 16).cuda(0) - >>> output_rref = model(input) - - .. note:: - You can wrap a :class:`Pipe` model with - :class:`torch.nn.parallel.DistributedDataParallel` only when the - checkpoint parameter of :class:`Pipe` is ``'never'``. - - .. note:: - :class:`Pipe` only supports intra-node pipelining currently, but - will be expanded to support inter-node pipelining in the future. - The forward function returns an :class:`~torch.distributed.rpc.RRef` - to allow for inter-node pipelining in the future, where the output - might be on a remote host. For intra-node pipelining you can use - :meth:`~torch.distributed.rpc.RRef.local_value` to retrieve the - output locally. - - .. warning:: - :class:`Pipe` is experimental and subject to change. - """ - - def __init__( - self, - module: nn.Sequential, - chunks: int = 1, - checkpoint: str = "except_last", - deferred_batch_norm: bool = False, - ) -> None: - super().__init__() - - # Check if RPC framework is initialized. - if not torch.distributed.rpc._is_current_rpc_agent_set(): - raise RuntimeError( - 'Please initialize RPC framework for Pipe using ' - 'torch.distributed.rpc.init_rpc') - - chunks = int(chunks) - checkpoint = str(checkpoint) - - if chunks <= 0: - raise ValueError("number of chunks must be positive integer") - if checkpoint not in ["always", "except_last", "never"]: - raise ValueError("checkpoint is not one of 'always', 'except_last', or 'never'") - - _verify_module(module) - - # Verify if the underlying skippable modules satisfy integrity. The - # integrity can be verified before forward() because it is static. - verify_skippables(module) - - self.chunks = chunks - self.checkpoint = checkpoint - - if deferred_batch_norm: - module = DeferredBatchNorm.convert_deferred_batch_norm(module, chunks) - - self.partitions, self.devices = _split_module(module) - _verify_splitting(module, self.partitions, self.devices) - - self._copy_streams: List[List[AbstractStream]] = [] - self._skip_layout = inspect_skip_layout(self.partitions) - - # Separate CUDA streams for copy. - copy_streams = self._ensure_copy_streams() - - # The micro-batch index where the checkpointing stops. - checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint] - - self.pipeline = Pipeline(self.partitions, self.devices, copy_streams, self._skip_layout, checkpoint_stop) - - def __len__(self) -> int: - """Counts the length of the underlying sequential module.""" - return sum(len(p) for p in self.partitions) - - def __getitem__(self, index: int) -> nn.Module: - """Gets a layer in the underlying sequential module.""" - partitions = self.partitions - if index < 0: - partitions = partitions[::-1] - - for partition in partitions: - try: - return partition[index] - except IndexError: - pass - - shift = len(partition) - - if index < 0: - index += shift - else: - index -= shift - - raise IndexError - - def __iter__(self) -> Iterator[nn.Module]: - """Iterates over children of the underlying sequential module.""" - for partition in self.partitions: - yield from partition - - # Pipe should manage the device of each partition. - # Deny cuda(), cpu(), and to() with device, by TypeError. - def cuda(self, device: Optional[Device] = None) -> "Pipe": - raise MOVING_DENIED - - def cpu(self) -> "Pipe": - raise MOVING_DENIED - - def to(self, *args: Any, **kwargs: Any) -> "Pipe": - # Deny these usages: - # - # - to(device[, dtype, non_blocking]) - # - to(tensor[, non_blocking]) - # - # But allow this: - # - # - to(dtype[, non_blocking]) - # - if "device" in kwargs or "tensor" in kwargs: - raise MOVING_DENIED - - if args: - if isinstance(args[0], (torch.device, int, str)): - raise MOVING_DENIED - if torch.is_tensor(args[0]): - raise MOVING_DENIED - - return super().to(*args, **kwargs) - - def _ensure_copy_streams(self) -> List[List[AbstractStream]]: - """Ensures that :class:`Pipe` caches CUDA streams for copy. - - It's worth to cache CUDA streams although PyTorch already manages a - pool of pre-allocated CUDA streams, because it may reduce GPU memory - fragmentation when the number of micro-batches is small. - - """ - if not self._copy_streams: - for device in self.devices: - self._copy_streams.append([new_stream(device) for _ in range(self.chunks)]) - - return self._copy_streams - - def forward(self, *inputs) -> RRef: - """ - Processes a single input mini-batch through the pipe and returns an - :class:`~torch.distributed.rpc.RRef` pointing to the output. - :class:`Pipe` is a fairly transparent module wrapper. It doesn't - modify the input and output signature of the underlying module. But - there's type restriction. Input and output have to contain at least one - tensor. This restriction is applied at partition boundaries too. - - The sequence of inputs are fed into the first stage of the pipeline as - ``*inputs``. As a result the positional args for this function should - match the positional args for the first stage of the pipeline. The same - condition applies for output of one stage of the pipeline which is the - input for the next stage. - - The input tensor is split into multiple micro-batches based on the - ``chunks`` parameter used to initialize :class:`Pipe`. The batch size - is assumed to be the first dimension of the tensor and if the batch - size is less than ``chunks``, the number of micro-batches is equal to - the batch size. - - Only tensors are split into multiple micro-batches, non-Tensor inputs - are just replicated as-is in each micro-batch. For non-Tensor outputs - in the last stage of the pipeline, they are aggregated as a ``List`` - and returned the user. For example, if you have 2 micro-batches - returning the integer 5, the user would receive the consolidated - output of `[5, 5]` - - All the input tensors need to be on the same device as the first - partition of the pipeline. - - If a tensor is wrapped with the :class:`NoChunk` wrapper, the tensor - is not split across micro-batches and is replicated as-is similar to - non-tensors. - - Args: - inputs: input mini-batch - - Returns: - :class:`~torch.distributed.rpc.RRef` to the output of the mini-batch - - Raises: - TypeError: input doesn't contain at least one tensor - - """ - first_partition_device = self.devices[0] if len(self.devices) != 0 else torch.device("cpu") - microbatch.check(first_partition_device, *inputs) - - if not self.devices: - # Empty sequential module is not illegal. - return RRef(*inputs) - - # Divide a mini-batch into micro-batches. - batches = microbatch.scatter(*inputs, chunks=self.chunks) - - # Run pipeline parallelism. - self.pipeline.run(batches) - - # Merge the micro-batches into one mini-batch. - output = microbatch.gather(batches) - return RRef(output) diff --git a/torch/distributed/pipeline/sync/pipeline.py b/torch/distributed/pipeline/sync/pipeline.py deleted file mode 100644 index 7cd5e5831169..000000000000 --- a/torch/distributed/pipeline/sync/pipeline.py +++ /dev/null @@ -1,255 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""The pipeline parallelism of Pipe.""" -from queue import Queue -from types import TracebackType -from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Type, Union, cast, Sequence - -import torch -from torch import Tensor, nn -from torch.autograd.profiler import record_function - -from .checkpoint import Checkpointing -from .copy import Copy, Wait -from .dependency import fork, join -from .microbatch import Batch -from .skip.layout import SkipLayout -from .skip.tracker import SkipTrackerThroughPotals, use_skip_tracker -from .stream import AbstractStream, current_stream, use_device -from .worker import Task, create_workers - -__all__: List[str] = ["Pipeline"] - - -Tensors = Sequence[Tensor] -TensorOrTensors = Union[Tensor, Tensors] - -ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType] - -# Queue is generic only in stubs. -# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime -if TYPE_CHECKING: - InQueue = Queue[Optional["Task"]] - OutQueue = Queue[Tuple[bool, Union[Tuple["Task", Batch], ExcInfo, None]]] -else: - InQueue = Queue - OutQueue = Queue - - -def _depend(fork_from: Batch, join_to: Batch) -> None: - fork_from_idx = fork_from.find_tensor_idx() - join_to_idx = join_to.find_tensor_idx() - - fork_from[fork_from_idx], phony = fork(fork_from[fork_from_idx]) - join_to[join_to_idx] = join(join_to[join_to_idx], phony) - - -def _copy(batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream) -> None: - batch[:] = Copy.apply(prev_stream, next_stream, *batch) - # Gradients are only supported for float Tensors. - batch[:] = tuple([x.detach() if torch.is_tensor(x) and not x.is_floating_point() else x for x in batch]) - - -def _wait(batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream) -> None: - batch[:] = Wait.apply(prev_stream, next_stream, *batch) - # Gradients are only supported for float Tensors. - batch[:] = tuple([x.detach() if torch.is_tensor(x) and not x.is_floating_point() else x for x in batch]) - - -def _clock_cycles(m: int, n: int) -> Iterable[List[Tuple[int, int]]]: - """Generate schedules for each clock cycle.""" - # m: number of micro-batches - # n: number of partitions - # i: index of micro-batch - # j: index of partition - # k: clock number - # - # k (i,j) (i,j) (i,j) - # - ----- ----- ----- - # 0 (0,0) - # 1 (1,0) (0,1) - # 2 (2,0) (1,1) (0,2) - # 3 (2,1) (1,2) - # 4 (2,2) - for k in range(m + n - 1): - yield [(k - j, j) for j in range(max(1 + k - m, 0), min(1 + k, n))] - - -class Pipeline: - """The pipeline parallelism for Pipe.""" - - def __init__( - self, - partitions: List[nn.Sequential], - devices: List[torch.device], - copy_streams: List[List[AbstractStream]], - skip_layout: SkipLayout, - checkpoint_stop: int, - ) -> None: - self.partitions = partitions - self.devices = devices - self.copy_streams = copy_streams - self.skip_layout = skip_layout - self.checkpoint_stop = checkpoint_stop - (self.in_queues, self.out_queues) = create_workers(devices) - - def run(self, batches: List[Batch]) -> None: - """Runs pipeline parallelism. - - It modifies the given batches in place. - - """ - partitions = self.partitions - devices = self.devices - skip_layout = self.skip_layout - - m = len(batches) - n = len(partitions) - - skip_trackers = [SkipTrackerThroughPotals(skip_layout) for _ in batches] - - for schedule in _clock_cycles(m, n): - self.fence(batches, schedule, skip_trackers) - self.compute(batches, schedule, skip_trackers) - - def fence( - self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals], - ) -> None: - """Copy micro-batches after computation for the previous micro-batches.""" - copy_streams = self.copy_streams - skip_layout = self.skip_layout - - for i, j in schedule: - # Ensure that batches[i-1] is executed after batches[i] in - # backpropagation by an explicit dependency. - if i != 0 and j != 0: - _depend(batches[i - 1], batches[i]) - - next_stream = copy_streams[j][i] - - for prev_j, ns, name in skip_layout.copy_policy(j): - prev_stream = copy_streams[prev_j][i] - skip_trackers[i].copy(batches[i], prev_stream, next_stream, ns, name) - - if j != 0: - prev_stream = copy_streams[j - 1][i] - _copy(batches[i], prev_stream, next_stream) - - def compute( - self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals], - ) -> None: - """Run tasks with synchronization to copy streams.""" - partitions = self.partitions - devices = self.devices - copy_streams = self.copy_streams - checkpoint_stop = self.checkpoint_stop - - # Disable checkpointing if in eval mode. - if not self.partitions[0].training: - checkpoint_stop = 0 - - n = len(partitions) - streams = [current_stream(d) for d in devices] - exc_info: Optional[ExcInfo] = None - - # With checkpointing, the autograd graph looks like this diagram: - # +-----+------+ - # | Copy | - # +-----+------+ (fence) - # - - - + - - - - - - - - - - # | (compute) - # +-----+------+ - # | Wait | [1] Synchronize the current stream with the copy stream. - # +-----+------+ - # +-----+------+ - # | Checkpoint | [2] Compute a partition within checkpointing. - # +-----+------+ - # +-----+------+ - # | Wait | [3] Synchronize the copy stream with the current stream. - # +-----+------+ - # + - - - + - # | +-----+-----+ - # | | Recompute | [4] Schedule the recomputation at backpropagation. - # | +-----+-----+ - # + - - - + - # | - # - - - + - - - - - - - - - - # +-----+------+ (fence) - # | Copy | - # +-----+------+ - for i, j in schedule: - batch = batches[i] - partition = partitions[j] - - # Synchronize with the copied input. ([1] in the diagram) - if j != 0: - _wait(batch, copy_streams[j][i], streams[j]) - - # Determine whether checkpointing or not. - checkpoint = i < checkpoint_stop - if checkpoint: - - def function( - *inputs, - partition: nn.Module = partition, - skip_tracker: SkipTrackerThroughPotals = skip_trackers[i], - chunk_id: int = i, - part_id: int = j, - ) -> TensorOrTensors: - with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)): - return partition(*inputs) - - chk = Checkpointing(function, batch) # type: ignore[arg-type] - task = Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute) - del function, chk - - else: - - def compute( - batch: Batch = batch, - partition: nn.Module = partition, - skip_tracker: SkipTrackerThroughPotals = skip_trackers[i], - chunk_id: int = i, - part_id: int = j, - ) -> Batch: - with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)): - return batch.call(partition) - - task = Task(streams[j], compute=compute, finalize=None) - del compute - - # Compute tasks in parallel. ([2] in the diagram) - self.in_queues[j].put(task) - - for i, j in schedule: - ok, payload = self.out_queues[j].get() - - # Hold the first exception. - if exc_info is not None: - continue - elif not ok: - exc_info = cast(ExcInfo, payload) - continue - - task, batch = cast(Tuple[Task, Batch], payload) - - # The copy stream synchronizes to copy the output. ([3] in the - # diagram) - if j != n - 1: - _wait(batch, streams[j], copy_streams[j][i]) - - # Finalize tasks. If checkpointing is enabled, here the - # recomputation is scheduled at backpropagation. ([4] in the - # diagram) - with use_device(devices[j]): - task.finalize(batch) - - batches[i] = batch - - # Fail at the first exception. - if exc_info is not None: - raise exc_info[0].with_traceback(exc_info[1], exc_info[2]) diff --git a/torch/distributed/pipeline/sync/py.typed b/torch/distributed/pipeline/sync/py.typed deleted file mode 100644 index ab03724cafbf..000000000000 --- a/torch/distributed/pipeline/sync/py.typed +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. diff --git a/torch/distributed/pipeline/sync/skip/__init__.py b/torch/distributed/pipeline/sync/skip/__init__.py deleted file mode 100644 index bdcb913867a7..000000000000 --- a/torch/distributed/pipeline/sync/skip/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Supports efficiency with skip connections.""" -from .namespace import Namespace -from .skippable import pop, skippable, stash, verify_skippables - -__all__ = ["skippable", "stash", "pop", "verify_skippables", "Namespace"] diff --git a/torch/distributed/pipeline/sync/skip/layout.py b/torch/distributed/pipeline/sync/skip/layout.py deleted file mode 100644 index 04d76d34ea16..000000000000 --- a/torch/distributed/pipeline/sync/skip/layout.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Static skip connection layout of ``@skippable`` modules.""" -from typing import Dict, Iterable, List, Tuple - -from torch import nn - -from .namespace import Namespace - -__all__: List[str] = [] - - -class SkipLayout: - """Represents a skip connection layout across partitions.""" - - # Skip routes indexed by 'ns, name': {(ns, name): (prev_j, next_j), ...} - by_ns_name: Dict[Tuple[Namespace, str], Tuple[int, int]] - - # Skip routes indexed by partition number 'j': [[next_j]: [(prev_j, ns, name), ...], ...] - by_partition: List[List[Tuple[int, Namespace, str]]] - - def __init__(self, num_partitions: int, skip_routes: Dict[Tuple[Namespace, str], Tuple[int, int]],) -> None: - # The skip routes are already indexed by 'ns, name'. - self.by_ns_name = skip_routes - - # Index skip routes by partition number 'j'. - self.by_partition = [[] for _ in range(num_partitions)] - - for (ns, name), (prev_j, next_j) in skip_routes.items(): - self.by_partition[next_j].append((prev_j, ns, name)) - - for p in self.by_partition: - p.sort() - - def copy_policy(self, next_j: int) -> Iterable[Tuple[int, Namespace, str]]: - """Generates skip routes for the given destination partition number. - The skip routes are sorted by source partition number in ascending - order. - - Yields: - Each tuple of (source partition number, namespace, name). - - """ - for prev_j, ns, name in self.by_partition[next_j]: - if prev_j == next_j: - # This skip tensor will be popped at the same partition where - # it is stashed. In this case, copy is not required. - continue - - yield (prev_j, ns, name) - - def requires_copy(self, ns: Namespace, name: str) -> bool: - """Whether the given namespace and name requires partition-to-partition - copy or not. - """ - prev_j, next_j = self.by_ns_name.get((ns, name), (-1, -1)) - return prev_j != next_j - - -def inspect_skip_layout(partitions: List[nn.Sequential]) -> SkipLayout: - """Inspects the skip connection layout in the given partitions.""" - # NOTE(sublee): Hide circular import inside this subroutine. Circular - # import is not ideal but placing this logic near to SkipLayout may - # increase cohesion of code. - from .skippable import Skippable - - skip_routes: Dict[Tuple[Namespace, str], Tuple[int, int]] = {} - stashed_at: Dict[Tuple[Namespace, str], int] = {} - - for j, partition in enumerate(partitions): - def inspect_layer(layer): - if not isinstance(layer, Skippable): - return - - for ns, name in layer.stashable(): - stashed_at[(ns, name)] = j - - for ns, name in layer.poppable(): - prev_j = stashed_at.pop((ns, name)) - skip_routes[(ns, name)] = (prev_j, j) - - if isinstance(partition, nn.Sequential): - for layer in partition: - inspect_layer(layer) - else: - inspect_layer(partition) - - return SkipLayout(len(partitions), skip_routes) diff --git a/torch/distributed/pipeline/sync/skip/namespace.py b/torch/distributed/pipeline/sync/skip/namespace.py deleted file mode 100644 index 7d9c0d9b7d84..000000000000 --- a/torch/distributed/pipeline/sync/skip/namespace.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Provides isolated namespace of skip tensors.""" -import abc -from functools import total_ordering -from typing import Any -import uuid - -__all__ = ["Namespace"] - - -@total_ordering -class Namespace(metaclass=abc.ABCMeta): # noqa: B024 - """Namespace for isolating skip tensors used by :meth:`isolate() - `. - """ - - __slots__ = ("id",) - - def __init__(self) -> None: - self.id = uuid.uuid4() - - def __repr__(self) -> str: - return f"" - - def __hash__(self) -> int: - return hash(self.id) - - # Namespaces should support ordering, since SkipLayout will sort tuples - # including a namespace. But actual order between namespaces is not - # important. That's why they are ordered by version 4 UUID which generates - # random numbers. - def __lt__(self, other: Any) -> bool: - if isinstance(other, Namespace): - return self.id < other.id - return False - - def __eq__(self, other: object) -> bool: - if isinstance(other, Namespace): - return self.id == other.id - return False - - -# 'None' is the default namespace, -# which means that 'isinstance(None, Namespace)' is 'True'. -Namespace.register(type(None)) diff --git a/torch/distributed/pipeline/sync/skip/portal.py b/torch/distributed/pipeline/sync/skip/portal.py deleted file mode 100644 index 335793f4cc13..000000000000 --- a/torch/distributed/pipeline/sync/skip/portal.py +++ /dev/null @@ -1,231 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Portal keeps a tensor in the pocket plane. The tensor becomes hidden to the -autograd engine. The shared context of three functions (:class:`PortalBlue`, -:class:`PortalOrange`, and :class:`PortalCopy`) out of the computation graph is -one of the most important feature of :mod:`torchpipe.skip`. - -The metaphor is inspired by Portal(tm) from Valve. - -""" -from typing import List, Optional, Tuple - -import torch -from torch import Tensor - -from ..copy import Context as CopyContext -from ..copy import Copy -from ..phony import get_phony -from ..stream import AbstractStream, get_device - -__all__: List[str] = [] - - -class Portal: - """A portal for a tensor.""" - - def __init__(self, tensor: Optional[Tensor], tensor_life: int) -> None: - self.put_tensor(tensor, tensor_life) - self.grad: Optional[Tensor] = None - - def blue(self) -> Tensor: - """Creates a :class:`PortalBlue` which hides the underlying tensor from - the autograd engine. - - Join the returning phony to the main lane of the autograd graph to - assure the correct backpropagation:: - - PortalBlue --+ - | - ---------- Join -- - - """ - tensor = self.use_tensor() - - if tensor is None: - return get_phony(torch.device("cpu"), requires_grad=False) - - return PortalBlue.apply(self, tensor) - - def orange(self, phony: Tensor) -> Optional[Tensor]: - """Creates a :class:`PortalOrange` which retrieves the hidden tensor - without losing ability of backpropagation. - - Give a phony forked from the main lane of an autograd graph:: - - +-- PortalOrange --+ - | | - -- Fork --------- f(a, b) -- - - """ - self.check_tensor_life() - - if self.tensor is None: - return self.use_tensor() - - return PortalOrange.apply(self, phony) - - def copy(self, prev_stream: AbstractStream, next_stream: AbstractStream, phony: Tensor,) -> Tensor: - """Copies the hidden tensor by a :class:`PortalCopy`. - - Give a phony and use the returning phony to keep backpropagation:: - - +-- PortalCopy --+ - | | - -- Fork ---------- Join -- - - """ - if self.tensor is None: - return get_phony(torch.device("cpu"), requires_grad=False) - - return PortalCopy.apply(self, prev_stream, next_stream, phony) - - def check_tensor_life(self) -> None: - if self.tensor_life <= 0: - raise RuntimeError("tensor in portal has been removed") - - def put_tensor(self, tensor: Optional[Tensor], tensor_life: int) -> None: - """Stores a tensor into this portal.""" - # [Life of Tensor through Portal] - # - # The tensor can be retrieved by use_tensor() up to 'tensor_life' - # times. When the life becomes 0, the tensor will be deleted for - # deallocation in CUDA memory. - # - # The below events participate in a tensor through a portal. - # Note that [x] denotes the events which call use_tensor(): - # - # 1. [x] blue() - # 2. [ ] PortalBlue.forward - # 3. [ ] copy() - # 4. [ ] PortalCopy.forward - # 5. [ ] orange() - # 6. [x] PortalOrange.forward - # - - - - - - - - - - - - - - - - - - - - - - - - - - - - # 7. [ ] orange() (recomputed) - # 8. [x] PortalOrange.forward (recomputed) - # 9. [ ] PortalOrange.backward - # 10. [ ] PortalCopy.backward - # 11. [x] blue() (recomputed) - # 12. [ ] PortalBlue.forward (recomputed) - # 13. [ ] PortalBlue.backward - # - self.tensor_life = tensor_life - - if tensor_life > 0: - self.tensor = tensor - else: - self.tensor = None - - def use_tensor(self) -> Optional[Tensor]: - """Retrieves the underlying tensor and decreases the tensor life. When - the life becomes 0, it the tensor will be removed. - """ - self.check_tensor_life() - - tensor = self.tensor - - self.tensor_life -= 1 - - if self.tensor_life <= 0: - self.tensor = None - - return tensor - - def put_grad(self, grad: Tensor) -> None: - """Stores a gradient into this portal.""" - self.grad = grad - - def use_grad(self) -> Tensor: - """Retrieves and removes the underlying gradient. The gradient is - always ephemeral. - """ - if self.grad is None: - raise RuntimeError("grad in portal has been removed or never set") - - grad = self.grad - self.grad = None - return grad - - -# Common interface between :class:`PortalBlue`, :class:`PortalOrange`, and -# :class:`PortalCopy`. -class Context(CopyContext): - portal: Portal - - -class PortalBlue(torch.autograd.Function): - """Hides a tensor from the autograd engine by a :class:`Portal`.""" - - @staticmethod - # type: ignore[override] - def forward( - ctx: Context, - portal: Portal, - # This tensor must be retrieved by portal.use_tensor(). - tensor: Tensor, - ) -> Tensor: - ctx.portal = portal - - phony = get_phony(tensor.device, requires_grad=False) - return phony.detach() - - @staticmethod - # type: ignore[override] - def backward(ctx: Context, grad_phony: Tensor,) -> Tuple[None, Tensor]: - # The paired PortalOrange should keep the gradient. - grad = ctx.portal.use_grad() - return None, grad - - -class PortalOrange(torch.autograd.Function): - """Retrieves the hidden tensor from a :class:`Portal`.""" - - @staticmethod - # type: ignore[override] - def forward(ctx: Context, portal: Portal, phony: Tensor) -> Tensor: - ctx.portal = portal - - tensor = portal.use_tensor() - assert tensor is not None - - return tensor.detach() - - @staticmethod - def backward(ctx: Context, grad: Tensor) -> Tuple[None, None]: # type: ignore[override] - # The paired PortalBlue will use the gradient. - ctx.portal.put_grad(grad) - return None, None - - -class PortalCopy(torch.autograd.Function): - """Copies the hidden tensor in a :class:`Portal`. It replaces the hidden - tensor with copied one. - """ - - @staticmethod - # type: ignore[override] - def forward( - ctx: Context, portal: Portal, prev_stream: AbstractStream, next_stream: AbstractStream, phony: Tensor, - ) -> Tensor: - ctx.portal = portal - - assert portal.tensor is not None - (portal.tensor,) = Copy.forward(ctx, prev_stream, next_stream, portal.tensor) - - phony = get_phony(get_device(next_stream), requires_grad=False) - return phony.detach() - - @staticmethod - # type: ignore[override] - def backward(ctx: Context, grad_phony: Tensor,) -> Tuple[None, None, None, None]: - portal = ctx.portal - - assert portal.grad is not None - _, _, portal.grad = Copy.backward(ctx, portal.grad) - - return None, None, None, None diff --git a/torch/distributed/pipeline/sync/skip/skippable.py b/torch/distributed/pipeline/sync/skip/skippable.py deleted file mode 100644 index 9d4db76c6b67..000000000000 --- a/torch/distributed/pipeline/sync/skip/skippable.py +++ /dev/null @@ -1,431 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""The user interface to define skip connections.""" -from typing import ( - TYPE_CHECKING, - Any, - Callable, - ClassVar, - Dict, - FrozenSet, - Generator, - Iterable, - List, - Optional, - Set, - Sequence, - Tuple, - Type, - TypeVar, - Union, - cast, -) - -from torch import Tensor, nn - -from ..microbatch import Batch -from .namespace import Namespace -from .tracker import current_skip_tracker - -__all__ = ["skippable", "stash", "pop", "verify_skippables"] - - -Tensors = Sequence[Tensor] -TensorOrTensors = Union[Tensor, Tensors] - -StashPop = Union["stash", "pop"] -StashPopGenerator = Generator[StashPop, Optional[Tensor], TensorOrTensors] -if TYPE_CHECKING: - # Typechecking: nn.Module is not a Generic - SkippableModule = nn.Module[Union[StashPopGenerator, TensorOrTensors]] # type: ignore[type-arg] -else: - SkippableModule = nn.Module - -T = TypeVar("T", bound="Skippable") - - -class Skippable(nn.Module): - """The base class for skippable modules. - - Do not use this class directly. Define a subclass by :func:`skippable` - instead. - - """ - - module_cls: ClassVar[Type[SkippableModule]] - stashable_names: ClassVar[FrozenSet[str]] - poppable_names: ClassVar[FrozenSet[str]] - - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__() - self.module = self.module_cls(*args, **kwargs) # type: ignore[call-arg] - self.namespaces: Dict[str, Namespace] = {} - - def __repr__(self) -> str: - return f"@skippable({self.module})" - - def namespaced(self, name: str) -> Tuple[Namespace, str]: - """Prepend namespace for the given skip name.""" - ns = self.namespaces.get(name) - ns = cast(Namespace, ns) - return (ns, name) - - def stashable(self) -> Iterable[Tuple[Namespace, str]]: - """Iterate over namespaced skip names to be stashed.""" - for name in self.stashable_names: - yield self.namespaced(name) - - def poppable(self) -> Iterable[Tuple[Namespace, str]]: - """Iterate over namespaced skip names to be popped.""" - for name in self.poppable_names: - yield self.namespaced(name) - - def isolate(self: T, ns: Namespace, *, only: Optional[Iterable[str]] = None) -> T: - r"""Isolate a specified subset or the whole set of skip tensors. - - In a single sequential module, skip tensors with the same - name are not allowed unless they are isolated by different namespaces. - - Here's an example using the same name for skip tensors twice. Each pair - of ``Layer1`` and ``Layer2`` is isolated with its own namespace ``ns1`` - and ``ns2``. There is no conflict anymore:: - - ns1 = Namespace() - ns2 = Namespace() - - model = nn.Sequential( - Layer1().isolate(ns1), - Layer1().isolate(ns2), - Layer2(), - Layer3().isolate(ns2), - Layer3().isolate(ns1), - ) - - When `only` parameter is omitted, all skip tensors are isolated. You - can isolate a subset of skip tensors by passing `only` parameter:: - - ns_alice = Namespace() - ns_bob = Namespace() - - model = nn.Sequential( - ... - StashStashPop().isolate(ns_alice, only=['alice']) \ - .isolate(ns_bob, only=['bob']), - ... - ) - - Args: - ns (Namespace): - namespace for isolation - - Keyword Args: - only (iterable of strs): - names of specific skip tensors to be isolated (omit this option - to isolate all skip tensors declared in this module) - - Returns: - this module itself - - """ - names: Iterable[str] - - if only is None: - names = self.stashable_names | self.poppable_names - else: - names = set(only) - - for name in names: - self.namespaces[name] = ns - - return self - - def dispatch( - self, - input, - handle_stash: Callable[[str, Optional[Tensor]], None], - handle_pop: Callable[[str], Optional[Tensor]], - ): - """Dispatch :class:`stash` or :class:`pop` commands. - - The commands are generated by the module's ``forward()``. - """ - generator = self.module(input) - - if not isinstance(generator, Generator): - # The underlying module returned output without any yield. - output = generator - return output - - try: - op = next(generator) - - while True: - if isinstance(op, stash): - handle_stash(op.name, op.tensor) - op = next(generator) - continue - - if isinstance(op, pop): - tensor = handle_pop(op.name) - op = generator.send(tensor) - continue - - raise TypeError(f"{op!r} is not a command from @skippable") - - except StopIteration as stop: - output = stop.args[0] - return output - - def forward(self, input: Union[List[Any], Tensor]) -> TensorOrTensors: - """Perform the forward propagation. - - :class:`stash` or :class:`pop` commands will be handled by portals - silently. The portals won't be exposed to users. - - Raises: - RuntimeError: - illegal 'stash' or 'pop' is found. - - """ - skip_tracker = current_skip_tracker() - stashed_tensors: Dict[str, Optional[Tensor]] = {} - - # Load skip tensors that might be popped. - poppable_tensors = {} - batch = Batch(input) - for ns, name in self.poppable(): - try: - poppable_tensors[name] = skip_tracker.load(batch, ns, name) - except KeyError as e: - raise RuntimeError(f"'{name}' has not been stashed") from e - input = batch.values - - # Handle skip commands. - def handle_stash(name: str, tensor: Optional[Tensor]) -> None: - if name not in self.stashable_names: - raise RuntimeError(f"'{name}' has not been declared as stashable") - stashed_tensors[name] = tensor - - def handle_pop(name: str) -> Optional[Tensor]: - if name not in self.poppable_names: - raise RuntimeError(f"'{name}' has not been declared as poppable") - return poppable_tensors.pop(name) - - output = self.dispatch(input, handle_stash, handle_pop) - - # All declared skips must be stashed or popped. - not_stashed = self.stashable_names - stashed_tensors.keys() - if not_stashed: - comma_names = ", ".join(f"'{n}'" for n in not_stashed) - raise RuntimeError(f"{comma_names} must be stashed but have not") - - not_popped = poppable_tensors.keys() - if not_popped: - comma_names = ", ".join(f"'{n}'" for n in not_popped) - raise RuntimeError(f"{comma_names} must be popped but have not") - - # Save stashed skip tensors. - batch = Batch(output) - for ns, name in self.stashable(): - tensor = stashed_tensors[name] - skip_tracker.save(batch, ns, name, tensor) - output = batch.values - - return output - - -# TODO(sublee): Move to above of Skippable class for better read flow. -def skippable( - stash: Iterable[str] = (), pop: Iterable[str] = (), -) -> Callable[[Type[SkippableModule]], Type[Skippable]]: - """Define a decorator to create :class:`nn.Module ` with skip connections. - - These decorated modules are called "skippable". This functionality works perfectly - fine even when the module is not wrapped by :class:`~torch.distributed.pipeline.sync.Pipe`. - - Each skip tensor is managed by its name. Before manipulating skip tensors, - a skippable module must statically declare the names for skip tensors by - `stash` and/or `pop` parameters. Skip tensors with pre-declared name can be - stashed by ``yield stash(name, tensor)`` or popped by ``tensor = yield - pop(name)``. - - Here is an example with three layers. A skip tensor named "1to3" is stashed - and popped at the first and last layer, respectively:: - - @skippable(stash=['1to3']) - class Layer1(nn.Module): - def forward(self, input): - yield stash('1to3', input) - return f1(input) - - class Layer2(nn.Module): - def forward(self, input): - return f2(input) - - @skippable(pop=['1to3']) - class Layer3(nn.Module): - def forward(self, input): - skip_1to3 = yield pop('1to3') - return f3(input) + skip_1to3 - - model = nn.Sequential(Layer1(), Layer2(), Layer3()) - - One skippable module can stash or pop multiple skip tensors:: - - @skippable(stash=['alice', 'bob'], pop=['carol']) - class StashStashPop(nn.Module): - def forward(self, input): - yield stash('alice', f_alice(input)) - yield stash('bob', f_bob(input)) - carol = yield pop('carol') - return input + carol - - Every skip tensor must be associated with exactly one pair of `stash` and - `pop`. :class:`~torch.distributed.pipeline.sync.Pipe` checks this - restriction automatically when wrapping a module. You can also check the - restriction by :func:`verify_skippables` - without :class:`~torch.distributed.pipeline.sync.Pipe`. - - """ - stashable_names = frozenset(stash) - poppable_names = frozenset(pop) - - def extend_skippable(module_cls: Type[SkippableModule]) -> Type[Skippable]: - name = module_cls.__name__ - bases = (Skippable,) - attrs = {"module_cls": module_cls, "stashable_names": stashable_names, "poppable_names": poppable_names} - return type(name, bases, attrs) - - return extend_skippable - - -class stash: - """The command to stash a skip tensor. - - :: - - def forward(self, input): - yield stash('name', input) - return f(input) - - Args: - name (str): name of skip tensor - input (torch.Tensor or None): tensor to pass to the skip connection - - """ - - __slots__ = ("name", "tensor") - - def __init__(self, name: str, tensor: Optional[Tensor]) -> None: - self.name = name - self.tensor = tensor - - -class pop: - """The command to pop a skip tensor. - - :: - - def forward(self, input): - skip = yield pop('name') - return f(input) + skip - - Args: - name (str): name of skip tensor - - Returns: - the skip tensor previously stashed by another layer under the same name - - """ - - __slots__ = ("name",) - - def __init__(self, name: str) -> None: - self.name = name - - -def verify_skippables(module: nn.Sequential) -> None: - """Verify if the underlying skippable modules satisfy integrity. - - Every skip tensor must have only one pair of `stash` and `pop`. If there - are one or more unmatched pairs, it will raise :exc:`TypeError` with the - detailed messages. - - Here are a few failure cases. :func:`verify_skippables` will report failure - for these cases:: - - # Layer1 stashes "1to3". - # Layer3 pops "1to3". - - nn.Sequential(Layer1(), Layer2()) - # +---- ? - - nn.Sequential(Layer2(), Layer3()) - # ? ----+ - - nn.Sequential(Layer1(), Layer2(), Layer3(), Layer3()) - # +-------------------+ ^^^^^^ - - nn.Sequential(Layer1(), Layer1(), Layer2(), Layer3()) - # ^^^^^^ +-------------------+ - - To use the same name for multiple skip tensors, they must be isolated by - different namespaces. See :meth:`isolate() - `. - - Raises: - TypeError: - one or more pairs of `stash` and `pop` are not matched. - - """ - stashed: Set[Tuple[Namespace, str]] = set() - popped: Set[Tuple[Namespace, str]] = set() - msgs: List[str] = [] - - for layer_name, layer in module.named_children(): - if not isinstance(layer, Skippable): - continue - - for name in layer.stashable_names & layer.poppable_names: - msg = f"'{layer_name}' declared '{name}' both as stashable and as poppable" - msgs.append(msg) - - for ns, name in layer.stashable(): - if name in layer.poppable_names: - continue - - if (ns, name) in stashed: - msg = f"'{layer_name}' redeclared '{name}' as stashable but not isolated by namespace" - msgs.append(msg) - continue - - stashed.add((ns, name)) - - for ns, name in layer.poppable(): - if name in layer.stashable_names: - continue - - if (ns, name) in popped: - msg = f"'{layer_name}' redeclared '{name}' as poppable but not isolated by namespace" - msgs.append(msg) - continue - - if (ns, name) not in stashed: - msg = f"'{layer_name}' declared '{name}' as poppable but it was not stashed" - msgs.append(msg) - continue - - popped.add((ns, name)) - - for (_, name) in stashed - popped: - msg = f"no module declared '{name}' as poppable but stashed" - msgs.append(msg) - - if msgs: - raise TypeError( - "one or more pairs of stash and pop do not match:\n\n{}" "".format("\n".join(f"* {x}" for x in msgs)) - ) diff --git a/torch/distributed/pipeline/sync/skip/tracker.py b/torch/distributed/pipeline/sync/skip/tracker.py deleted file mode 100644 index 8ac82bc05dc9..000000000000 --- a/torch/distributed/pipeline/sync/skip/tracker.py +++ /dev/null @@ -1,180 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Tracks skip tensors on a thread.""" -from contextlib import contextmanager -import threading -from typing import Dict, Generator, List, Optional, Tuple - -from torch import Tensor - -from ..checkpoint import is_checkpointing -from ..dependency import fork, join -from ..microbatch import Batch -from ..stream import AbstractStream -from .layout import SkipLayout -from .namespace import Namespace -from .portal import Portal - -__all__: List[str] = [] - - -class SkipTracker: - """Tracks saved skip tensors. - - It will update the given micro-batch in place. This is because when it - manipulates the underlying skip tensors, the current micro-batch also has - to be connected with the skip tensors. - - One thread has one skip tracker. Call :func:`current_skip_tracker` to get - the skip tracker on the current thread. - - """ - - def __init__(self) -> None: - self.tensors: Dict[Tuple[Namespace, str], Optional[Tensor]] = {} - - def save(self, batch: Batch, ns: Namespace, name: str, tensor: Optional[Tensor]) -> None: - self.tensors[(ns, name)] = tensor - - def load(self, batch: Batch, ns: Namespace, name: str) -> Optional[Tensor]: - return self.tensors.pop((ns, name)) - - def copy( - self, batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream, ns: Namespace, name: str, - ) -> None: - raise TypeError("copy is not supported for non-portal skip tensors") - - -class SkipTrackerThroughPotals(SkipTracker): - """Tracks saved skip tensors through portals. The skip tensors will be - hidden in portals so that the autograd engine does not need to track them. - - This tracker is only used when the training or evaluating module is wrapped - with :class:`torchpipe.Pipe`. - - """ - - def __init__(self, skip_layout: SkipLayout) -> None: - super().__init__() - self.skip_layout = skip_layout - self.portals: Dict[Tuple[Namespace, str], Portal] = {} - - def save(self, batch: Batch, ns: Namespace, name: str, tensor: Optional[Tensor]) -> None: - """Saves the stashed skip tensor in a portal. The portal is then - connected to the given micro-batch with :class:`Join`. - """ - if not self.skip_layout.requires_copy(ns, name): - super().save(batch, ns, name, tensor) - return - - # See [Tensor Life of Portal] at Portal.put_tensor() to understand the - # below tensor_life values. Here are the selected events which retrieve - # the tensor in portal: - # - # 1. [x] blue() - # ... - # 6. [x] PortalOrange.forward - # ... - # 8. [x] PortalOrange.forward (recomputed) - # ... - # 11. [x] blue() (recomputed) - # - if (ns, name) not in self.portals: - if is_checkpointing(): - # Under checkpointing, the tensor used by the first - # PortalOrange should be alive in the portal. This tensor will - # be used again by the second PortalOrange during the - # recomputation. - tensor_life = 3 # Delete at [8. PortalOrange.forward (recomputed)] - else: - tensor_life = 2 # Delete at [6. PortalOrange.forward] - - portal = Portal(tensor, tensor_life) - self.portals[(ns, name)] = portal - - else: - # Under recomputation, the portal already exists. - portal = self.portals[(ns, name)] - - # The existing tensor life already became 0. It should be reset as - # 1 to delete the tensor after the second PortalBlue immediately. - tensor_life = 1 # Delete at [11. blue() (recomputed)] - - portal.put_tensor(tensor, tensor_life) - - phony = portal.blue() - tensor_idx = batch.find_tensor_idx() - batch[tensor_idx] = join(batch[tensor_idx], phony) - - def load(self, batch: Batch, ns: Namespace, name: str) -> Optional[Tensor]: - """Loads a skip tensor from the corresponding portal to pop. The given - micro-batch is connected to the portal with :class:`Fork`. - """ - if not self.skip_layout.requires_copy(ns, name): - tensor = super().load(batch, ns, name) - return tensor - - portal = self.portals[(ns, name)] - tensor_idx = batch.find_tensor_idx() - batch[tensor_idx], phony = fork(batch[tensor_idx]) - tensor = portal.orange(phony) - return tensor - - def copy( - self, batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream, ns: Namespace, name: str, - ) -> None: - """Copies the skip tensor in the corresponding portal. The given - micro-batch and the portal will be tied with :class:`Fork` and - :class:`Join`. - """ - assert self.skip_layout.requires_copy(ns, name) - - tensor_idx = batch.find_tensor_idx() - batch[tensor_idx], phony = fork(batch[tensor_idx]) - - portal = self.portals[(ns, name)] - phony = portal.copy(prev_stream, next_stream, phony) - - batch[tensor_idx] = join(batch[tensor_idx], phony) - - -class ThreadLocal(threading.local): - def __init__(self) -> None: - self.skip_tracker: Optional[SkipTracker] = None - - -thread_local = ThreadLocal() - - -@contextmanager -def use_skip_tracker(skip_tracker: SkipTracker) -> Generator[None, None, None]: - """Registers the given skip tracker on the current thread within a - context:: - - with use_skip_tracker(my_skip_tracker): - ... - - """ - orig = thread_local.skip_tracker - - thread_local.skip_tracker = skip_tracker - - try: - yield - finally: - thread_local.skip_tracker = orig - - -def current_skip_tracker() -> SkipTracker: - """Gets the skip tracker on the current thread.""" - skip_tracker = thread_local.skip_tracker - - if skip_tracker is None: - skip_tracker = SkipTracker() - thread_local.skip_tracker = skip_tracker - - return skip_tracker diff --git a/torch/distributed/pipeline/sync/stream.py b/torch/distributed/pipeline/sync/stream.py deleted file mode 100644 index 59fedf865a42..000000000000 --- a/torch/distributed/pipeline/sync/stream.py +++ /dev/null @@ -1,120 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Utilities for eliminating boilerplate code to handle abstract streams with -CPU device. -""" -from contextlib import contextmanager -from typing import Generator, List, Union, cast - -import torch - -__all__: List[str] = ["CPUStreamType", "new_stream", "current_stream", "default_stream", - "use_device", "use_stream", "get_device", "wait_stream", "record_stream", - "is_cuda", "as_cuda"] - - -class CPUStreamType: - pass - - -# The placeholder on place of streams for the CPU device instead of CUDA. -CPUStream = CPUStreamType() - -# It represents both CUDA streams and the CPU stream. -AbstractStream = Union[torch.cuda.Stream, CPUStreamType] - - -def new_stream(device: torch.device) -> AbstractStream: - """Creates a new stream for either CPU or CUDA device.""" - if device.type != "cuda": - return CPUStream - return torch.cuda.Stream(device) - - -def current_stream(device: torch.device) -> AbstractStream: - """:func:`torch.cuda.current_stream` for either CPU or CUDA device.""" - if device.type != "cuda": - return CPUStream - return torch.cuda.current_stream(device) - - -def default_stream(device: torch.device) -> AbstractStream: - """:func:`torch.cuda.default_stream` for either CPU or CUDA device.""" - if device.type != "cuda": - return CPUStream - return torch.cuda.default_stream(device) - - -@contextmanager -def use_device(device: torch.device) -> Generator[None, None, None]: - """:func:`torch.cuda.device` for either CPU or CUDA device.""" - if device.type != "cuda": - yield - return - - with torch.cuda.device(device): - yield - - -@contextmanager -def use_stream(stream: AbstractStream) -> Generator[None, None, None]: - """:func:`torch.cuda.stream` for either CPU or CUDA stream.""" - if not is_cuda(stream): - yield - return - - with torch.cuda.stream(as_cuda(stream)): - yield - - -def get_device(stream: AbstractStream) -> torch.device: - """Gets the device from CPU or CUDA stream.""" - if is_cuda(stream): - return as_cuda(stream).device - return torch.device("cpu") - - -def wait_stream(source: AbstractStream, target: AbstractStream) -> None: - """:meth:`torch.cuda.Stream.wait_stream` for either CPU or CUDA stream. It - makes the source stream wait until the target stream completes work queued. - """ - if is_cuda(target): - if is_cuda(source): - # A CUDA stream waits another CUDA stream. - as_cuda(source).wait_stream(as_cuda(target)) - else: - # CPU waits a CUDA stream. - as_cuda(target).synchronize() - - # If the target is CPU, synchronization is not required. - - -def record_stream(tensor: torch.Tensor, stream: AbstractStream) -> None: - """:meth:`torch.Tensor.record_stream` for either CPU or CUDA stream.""" - if is_cuda(stream): - # NOTE(sublee): record_stream() on a shifted view tensor throws - # RuntimeError in PyTorch 1.1.0, and does nothing in 1.2.0. To safely - # protect the tensor against unexpected reallocation, here we use a - # temporal tensor associated with the same storage without shifting as - # a workaround. - # - # Issue: https://github.com/pytorch/pytorch/issues/27366 - # - tensor = tensor.new_empty([0]).set_(tensor._typed_storage()) - - # Typechecking: torch.cuda.Stream is incompatible with torch._C.Stream - tensor.record_stream(as_cuda(stream)) # type: ignore[arg-type] - - -def is_cuda(stream: AbstractStream) -> bool: - """Returns ``True`` if the given stream is a valid CUDA stream.""" - return stream is not CPUStream - - -def as_cuda(stream: AbstractStream) -> torch.cuda.Stream: - """Casts the given stream as :class:`torch.cuda.Stream`.""" - return cast(torch.cuda.Stream, stream) diff --git a/torch/distributed/pipeline/sync/utils.py b/torch/distributed/pipeline/sync/utils.py deleted file mode 100644 index 210c475317e2..000000000000 --- a/torch/distributed/pipeline/sync/utils.py +++ /dev/null @@ -1,38 +0,0 @@ -from torch import nn -from typing import List, Optional - -__all__ = ["partition_model"] - -def partition_model( - module: nn.Sequential, - balance: List[int], - devices: Optional[List[int]] = None): - """ - Partions the model accross multiple GPU devices. - - Given an :class:`nn.Sequential ` module, partitions - the model across multiple GPU devices according the provided ``balance`` - and ``devices``. - - Args: - module (:class:`nn.Sequential `): - Sequential model representing the pipe. - balance (List[int]): - List indicating the number of layers in each partition. - devices (List[int], optional): - List indicating the device to use for each partition. Defaults to - ``range(len(balance))`` - """ - device_idx = 0 - pipe_idx = 0 - balanced_pipe = [] - for num_layers in balance: - layers = [] - for i in range(num_layers): - layers.append(module[pipe_idx]) - pipe_idx += 1 - device = device_idx if devices is None else devices[device_idx] - balanced_pipe.append(nn.Sequential(*layers).to(device)) - device_idx += 1 - - return nn.Sequential(*balanced_pipe) diff --git a/torch/distributed/pipeline/sync/worker.py b/torch/distributed/pipeline/sync/worker.py deleted file mode 100644 index 87b20c4a5551..000000000000 --- a/torch/distributed/pipeline/sync/worker.py +++ /dev/null @@ -1,132 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Multithreading in pipeline parallelism.""" -from contextlib import contextmanager -from queue import Queue -import sys -from threading import Thread -from types import TracebackType -from typing import TYPE_CHECKING, Callable, Dict, Generator, List, Optional, Tuple, Type, Union, cast - -import torch - -from .microbatch import Batch -from .stream import AbstractStream, use_device, use_stream - -__all__: List[str] = ["Task", "worker", "create_workers", "spawn_workers"] - - -ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType] - -# Queue is generic only in stubs. -# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime -if TYPE_CHECKING: - InQueue = Queue[Optional["Task"]] - OutQueue = Queue[Tuple[bool, Union[Tuple["Task", Batch], ExcInfo, None]]] -else: - InQueue = Queue - OutQueue = Queue - - -class Task: - """A task represents how to compute a micro-batch on a partition. - - It consists of two parts: :meth:`compute` and :meth:`finalize`. - :meth:`compute` should be executed in worker threads concurrently. - :meth:`finalize` should be executed after when worker threads complete to - execute :meth:`compute`. - - :meth:`compute` might be boosted by worker threads. Because it produces - several CUDA API calls by user code. In PyTorch, parallel CUDA API calls - are not serialized through GIL. So more than one CUDA API call can be - produced at the same time. - - """ - - def __init__( - self, stream: AbstractStream, *, compute: Callable[[], Batch], finalize: Optional[Callable[[Batch], None]], - ) -> None: - self.stream = stream - self._compute = compute - self._finalize = finalize - self._grad_enabled = torch.is_grad_enabled() - - def compute(self) -> Batch: - with use_stream(self.stream), torch.set_grad_enabled(self._grad_enabled): - return self._compute() - - def finalize(self, batch: Batch) -> None: - if self._finalize is None: - return - with use_stream(self.stream), torch.set_grad_enabled(self._grad_enabled): - self._finalize(batch) - - -def worker(in_queue: InQueue, out_queue: OutQueue, device: torch.device) -> None: - """Main loop of a worker thread.""" - with use_device(device): - while True: - task = in_queue.get() - - if task is None: - break - - try: - batch = task.compute() - except Exception: - exc_info = cast(ExcInfo, sys.exc_info()) - out_queue.put((False, exc_info)) - continue - - out_queue.put((True, (task, batch))) - - done = (False, None) - out_queue.put(done) - - -def create_workers(devices: List[torch.device],) -> Tuple[List[InQueue], List[OutQueue]]: - """Spawns worker threads. A worker thread is bound to a device.""" - in_queues: List[InQueue] = [] - out_queues: List[OutQueue] = [] - - # Spawn workers. - workers: Dict[torch.device, Tuple[InQueue, OutQueue]] = {} - - def normalize_device(device: torch.device) -> torch.device: - if device.type == "cuda" and device.index is None: - return torch.device("cuda", index=torch.cuda.current_device()) - - if device.type == "cpu" and device.index is not None: - return torch.device("cpu") - - return device - - for device in devices: - device = normalize_device(device) - - try: - in_queue, out_queue = workers[device] - except KeyError: - in_queue = Queue() - out_queue = Queue() - workers[device] = (in_queue, out_queue) - - t = Thread(target=worker, args=(in_queue, out_queue, device), daemon=True,) - t.start() - - in_queues.append(in_queue) - out_queues.append(out_queue) - - return (in_queues, out_queues) - -@contextmanager -def spawn_workers(devices: List[torch.device],) -> Generator[Tuple[List[InQueue], List[OutQueue]], None, None]: - try: - (in_queues, out_queues) = create_workers(devices) - yield (in_queues, out_queues) - finally: - pass diff --git a/torch/testing/_internal/distributed/pipe_with_ddp_test.py b/torch/testing/_internal/distributed/pipe_with_ddp_test.py deleted file mode 100644 index 1ed9f3cc96df..000000000000 --- a/torch/testing/_internal/distributed/pipe_with_ddp_test.py +++ /dev/null @@ -1,149 +0,0 @@ -# mypy: ignore-errors - -import torch -import torch.distributed as dist - -from torch import nn -from torch.nn.parallel import DistributedDataParallel -from torch.testing._internal.dist_utils import INIT_METHOD_TEMPLATE, dist_init -from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import ( - RpcAgentTestFixture, -) -from torch.testing._internal.common_distributed import ( - requires_gloo, - requires_nccl, - skip_if_lt_x_gpu, - skip_if_rocm, -) -from torch.distributed.pipeline.sync import Pipe - -class PipeWithDDPTest(RpcAgentTestFixture): - @property - def world_size(self) -> int: - return 2 - - @skip_if_lt_x_gpu(4) - @requires_nccl() - @dist_init - @skip_if_rocm - def test_basic_nccl_ckpt_never(self): - self._run_basic_test("nccl", "never") - - @skip_if_lt_x_gpu(4) - @requires_nccl() - @dist_init - @skip_if_rocm - def test_basic_nccl_ckpt_never_find_unused(self): - self._run_basic_test("nccl", "never", find_unused_parameters=True) - - @skip_if_lt_x_gpu(4) - @requires_nccl() - @dist_init - @skip_if_rocm - def test_basic_nccl_ckpt_always(self): - self._run_basic_test("nccl", "always", static_graph=True) - - @skip_if_lt_x_gpu(4) - @requires_nccl() - @dist_init - @skip_if_rocm - def test_basic_nccl_ckpt_except_last(self): - self._run_basic_test("nccl", "except_last", static_graph=True) - - @skip_if_lt_x_gpu(4) - @requires_gloo() - @dist_init - @skip_if_rocm - def test_basic_gloo_ckpt_never(self): - self._run_basic_test("gloo", "never") - - @skip_if_lt_x_gpu(4) - @requires_gloo() - @dist_init - @skip_if_rocm - def test_basic_gloo_ckpt_never_find_unused(self): - self._run_basic_test("gloo", "never", find_unused_parameters=True) - - @skip_if_lt_x_gpu(4) - @requires_gloo() - @dist_init - @skip_if_rocm - def test_basic_gloo_ckpt_always(self): - self._run_basic_test("gloo", "always", static_graph=True) - - @skip_if_lt_x_gpu(4) - @requires_gloo() - @dist_init - @skip_if_rocm - def test_basic_gloo_ckpt_except_last(self): - self._run_basic_test("gloo", "except_last", static_graph=True) - - def _run_basic_test(self, backend, checkpoint, find_unused_parameters=False, static_graph=False): - dist.init_process_group( - backend=backend, - init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name), - world_size=self.world_size, - rank=self.rank, - ) - - # Use 4 GPUs, two replicas of a pipe across GPU 0 and 1 and another - # pipe between GPU 2 and 3. Both replicas are replicated via DDP. - fc1 = nn.Linear(16, 8, bias=False).cuda(2 * self.rank) - - class MyModule(nn.Module): - def __init__(self, device): - super().__init__() - self.fc2 = nn.Linear(8, 4, bias=False).cuda(device) - self.fc3 = nn.Linear(4, 2, bias=False).cuda(device) - - def forward(self, inp): - if find_unused_parameters: - return self.fc2(inp) - else: - return self.fc3(self.fc2(inp)) - - layer2 = MyModule(2 * self.rank + 1) - model = nn.Sequential( - fc1, - layer2 - ) - model = Pipe(model, chunks=2, checkpoint=checkpoint) - model = DistributedDataParallel( - model, - find_unused_parameters=find_unused_parameters, - static_graph=static_graph, - ) - - # Ensure inputs are different across ranks to verify that gradient - # sync indeed occurs. - model_input = torch.rand(16, 16).cuda(2 * self.rank) * (self.rank + 1) - out = model(model_input).local_value() - out.sum().backward() - - # Run forward again for find_unused_parameters to trigger any potential errors. - if find_unused_parameters: - # Ensure inputs are different across ranks to verify that gradient - # sync indeed occurs. - unused_param_input = torch.rand(16, 16).cuda(2 * self.rank) * (self.rank + 1) - model(unused_param_input).local_value().sum().backward() - - # Run a few more iterations of fwd + bwd to ensure gradient synchronization - # occurs properly across iterations via delay_all_reduce/bucketized allreduce. - for _ in range(3): - model_input = torch.rand(16, 16).cuda(2 * self.rank) * (self.rank + 1) - out = model(model_input).local_value() - out.sum().backward() - - # Check grads - output = [torch.empty_like(fc1.weight.grad), torch.empty_like(fc1.weight.grad)] - dist.all_gather(output, fc1.weight.grad) - self.assertEqual(output[0], output[1]) - - output = [torch.empty_like(layer2.fc2.weight.grad), torch.empty_like(layer2.fc2.weight.grad)] - dist.all_gather(output, layer2.fc2.weight.grad) - self.assertEqual(output[0], output[1]) - - if not find_unused_parameters: - output = [torch.empty_like(layer2.fc3.weight.grad), torch.empty_like(layer2.fc3.weight.grad)] - dist.all_gather(output, layer2.fc3.weight.grad) - self.assertEqual(output[0], output[1]) diff --git a/torch/testing/_internal/distributed/pipeline/__init__.py b/torch/testing/_internal/distributed/pipeline/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/torch/testing/_internal/distributed/rpc_utils.py b/torch/testing/_internal/distributed/rpc_utils.py index cdbbdcfd0681..5b6e2c90770f 100644 --- a/torch/testing/_internal/distributed/rpc_utils.py +++ b/torch/testing/_internal/distributed/rpc_utils.py @@ -16,9 +16,6 @@ DdpComparisonTest, DdpUnderDistAutogradTest, ) -from torch.testing._internal.distributed.pipe_with_ddp_test import ( - PipeWithDDPTest, -) from torch.testing._internal.distributed.nn.api.remote_module_test import ( CudaRemoteModuleTest, RemoteModuleTest, @@ -121,7 +118,6 @@ def tearDown(self): CudaDistAutogradTest, CudaRemoteModuleTest, CudaDdpComparisonTest, - PipeWithDDPTest, ] From cf77e7dd9770caf65e898ac2ee82045aa0408e30 Mon Sep 17 00:00:00 2001 From: Sam Larsen Date: Sat, 1 Jun 2024 15:30:32 -0700 Subject: [PATCH 311/706] [inductor] Enable subprocess-based parallel compile as the default (#126817) Differential Revision: [D58056502](https://our.internmc.facebook.com/intern/diff/D58056502) Pull Request resolved: https://github.com/pytorch/pytorch/pull/126817 Approved by: https://github.com/eellison --- torch/_inductor/config.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 0b4cb41d9b2b..bf1e459251aa 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -387,7 +387,9 @@ def is_fbcode(): # The multiprocessing start method to use for inductor workers in the codecache. # "subprocess", "fork", or "spawn" def decide_worker_start_method(): - start_method = os.environ.get("TORCHINDUCTOR_WORKER_START", "fork") + start_method = os.environ.get( + "TORCHINDUCTOR_WORKER_START", "fork" if is_fbcode() else "subprocess" + ) assert start_method in [ "subprocess", "fork", From f325b393038c55c30e5b5ce59709b6da158f03c8 Mon Sep 17 00:00:00 2001 From: Yifu Wang Date: Mon, 3 Jun 2024 14:03:23 -0700 Subject: [PATCH 312/706] Introduce Inductor passes to micro-pipeline all-gather-matmul and matmul-reduce-scatter in certain cases (#126598) Pull Request resolved: https://github.com/pytorch/pytorch/pull/126598 Approved by: https://github.com/wanchaol --- .../tensor/parallel/test_micro_pipeline_tp.py | 147 ++++++ torch/_inductor/config.py | 2 + .../_inductor/fx_passes/micro_pipeline_tp.py | 468 ++++++++++++++++++ torch/_inductor/fx_passes/post_grad.py | 4 + torch/distributed/_cuda_p2p/__init__.py | 2 + .../distributed/_tensor/common_dtensor.py | 6 +- 6 files changed, 626 insertions(+), 3 deletions(-) create mode 100644 test/distributed/tensor/parallel/test_micro_pipeline_tp.py create mode 100644 torch/_inductor/fx_passes/micro_pipeline_tp.py diff --git a/test/distributed/tensor/parallel/test_micro_pipeline_tp.py b/test/distributed/tensor/parallel/test_micro_pipeline_tp.py new file mode 100644 index 000000000000..56ae8a14dcde --- /dev/null +++ b/test/distributed/tensor/parallel/test_micro_pipeline_tp.py @@ -0,0 +1,147 @@ +# Owner(s): ["module: c10d"] +import unittest + +import torch +import torch.distributed as dist +from torch._inductor.utils import fresh_inductor_cache, run_and_get_triton_code +from torch.distributed._cuda_p2p import test_with_non_cuda_p2p_group +from torch.distributed._functional_collectives import ( + all_gather_tensor, + reduce_scatter_tensor, +) +from torch.distributed._tensor import DeviceMesh +from torch.distributed._tensor.placement_types import Shard +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + parallelize_module, + RowwiseParallel, +) +from torch.testing._internal.common_utils import ( # type: ignore[attr-defined] + instantiate_parametrized_tests, + parametrize, + run_tests, + TestCase, +) +from torch.testing._internal.distributed._tensor.common_dtensor import MLPModule +from torch.testing._internal.distributed.fake_pg import FakeStore +from torch.utils._triton import has_triton + + +@instantiate_parametrized_tests +class MicroPipelineTPTest(TestCase): + def setUp(self): + torch._inductor.config._micro_pipeline_tp = True + + self.rank = 0 + self.world_size = 2 + torch.cuda.set_device("cuda:0") + + store = FakeStore() + dist.init_process_group( + backend="fake", + world_size=self.world_size, + rank=self.rank, + store=store, + ) + + def tearDown(self): + dist.destroy_process_group() + + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @parametrize("A_dims", [2, 3]) + @parametrize("gather_dim", [0, 1, 2]) + @fresh_inductor_cache() + def test_fuse_all_gather_matmul(self, A_dims, gather_dim): + if gather_dim >= A_dims: + return + + group = dist.group.WORLD + + def func(A_shard: torch.Tensor, B: torch.Tensor) -> torch.Tensor: + A = all_gather_tensor(A_shard, gather_dim=gather_dim, group=group) + return A @ B + + if A_dims == 2: + A_shard_shape = [64, 32] + elif A_dims == 3: + A_shard_shape = [2, 64, 32] + else: + raise AssertionError(f"Invalid A_dims: {A_dims}") + + A_shard_shape[gather_dim] //= self.world_size + A_shard = torch.rand(*A_shard_shape, device="cuda") + B = torch.rand(32, 16, device="cuda") + + with test_with_non_cuda_p2p_group(): + compiled = torch.compile(func) + code = run_and_get_triton_code(compiled, A_shard, B) + + if gather_dim == A_dims - 1: + assert "fused_all_gather_matmul" not in code + assert "all_gather_into_tensor" in code + else: + # Decomposing the matmul on the K dimension is not supported + assert "fused_all_gather_matmul" in code + assert "all_gather_into_tensor" not in code + + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @parametrize("A_dims", [2, 3]) + @parametrize("scatter_dim", [0, 1, 2]) + @fresh_inductor_cache() + def test_fuse_matmul_reduce_scatter(self, A_dims, scatter_dim): + if scatter_dim >= A_dims: + return + + group = dist.group.WORLD + + def func(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: + return reduce_scatter_tensor(A @ B, "avg", scatter_dim, group) + + if A_dims == 2: + A = torch.rand(64, 32, device="cuda") + elif A_dims == 3: + A = torch.rand(2, 64, 32, device="cuda") + else: + raise AssertionError(f"Invalid A_dims: {A_dims}") + B = torch.rand(32, 16, device="cuda") + + with test_with_non_cuda_p2p_group(): + compiled = torch.compile(func) + code = run_and_get_triton_code(compiled, A, B) + + assert "fused_matmul_reduce_scatter" in code + assert "reduce_scatter_tensor" not in code + + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @parametrize("shard_dim", [0, 1]) + @fresh_inductor_cache() + def test_dtensor_seq_par(self, shard_dim: int): + model = MLPModule(device="cuda", bias=False) + device_mesh = DeviceMesh( + "cuda", + torch.arange(0, self.world_size), + ) + parallelize_plan = { + "net1": ColwiseParallel(input_layouts=Shard(shard_dim)), + "net2": RowwiseParallel(output_layouts=Shard(shard_dim)), + } + model = parallelize_module(model, device_mesh, parallelize_plan) + if shard_dim == 0: + inp = torch.rand(8, 10, device="cuda") + elif shard_dim == 1: + inp = torch.rand(2, 8, 10, device="cuda") + else: + raise AssertionError("Invalid shard_dim") + + with test_with_non_cuda_p2p_group(): + compiled = torch.compile(model) + code = run_and_get_triton_code(compiled, inp) + + assert "fused_all_gather_matmul" in code + assert "all_gather_into_tensor" not in code + assert "fused_matmul_reduce_scatter" in code + assert "reduce_scatter_tensor" not in code + + +if __name__ == "__main__": + run_tests() diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index bf1e459251aa..dbaa528cd3e5 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -423,6 +423,8 @@ def decide_worker_start_method(): "schedule_comm_wait", ] +_micro_pipeline_tp: bool = False + def decide_compile_threads(): """ diff --git a/torch/_inductor/fx_passes/micro_pipeline_tp.py b/torch/_inductor/fx_passes/micro_pipeline_tp.py new file mode 100644 index 000000000000..20a864377787 --- /dev/null +++ b/torch/_inductor/fx_passes/micro_pipeline_tp.py @@ -0,0 +1,468 @@ +import operator +from dataclasses import dataclass +from typing import cast, List, Set, Tuple, Union + +import torch + +from ..pattern_matcher import ( + CallFunction, + Ignored, + KeywordArg, + ListOf, + MULTIPLE, + PatternMatcherPass, + register_graph_pattern, +) + +aten = torch.ops.aten +patterns = PatternMatcherPass() + + +def _is_backward(graph: torch.fx.Graph) -> bool: + placeholders = [] + for node in graph.nodes: + if node.op != "placeholder": + break + placeholders.append(node) + return not all(node.name.startswith("primal") for node in placeholders) + + +def _compute_mm_arithmetic_intensity(M: int, N: int, K: int) -> float: + return M * N * K / (M * K + N * K + M * N) + + +def _filter_nodes_by_target(nodes: List[torch.fx.Node], target) -> List[torch.fx.Node]: + return [x for x in nodes if x.target == target] + + +def _find_ancestors(node: torch.fx.Node) -> Set[torch.fx.Node]: + ancestors = set() + ancestors.add(node) + cur_nodes = [node] + while len(cur_nodes) > 0: + new_nodes = [] + for node in cur_nodes: + for inp in node.all_input_nodes: + if inp not in ancestors: + ancestors.add(inp) + new_nodes.append(inp) + cur_nodes = new_nodes + return {node for node in ancestors if node.op != "placeholder"} + + +def _can_schedule_y_before_x( + x: torch.fx.Node, y: torch.fx.Node +) -> Tuple[bool, Set[torch.fx.Node]]: + """ + Check if y can be reordered before x and return the ancestors of y + (inclusive). + """ + y_ancestors = _find_ancestors(y) + if x in y_ancestors: + return False, y_ancestors + + return True, y_ancestors + + +@dataclass +class _2DMatmul: + node: torch.fx.Node + B_node: torch.fx.Node + B_node_ancestors: Set[torch.fx.Node] + + def replace_with(self, new_node: torch.fx.Node) -> None: + """ + Replace the matmul with the new node. + """ + self.node.replace_all_uses_with(new_node) + + +@dataclass +class _NDMatmul: + nodes: List[torch.fx.Node] + B_node: torch.fx.Node + B_node_ancestors: Set[torch.fx.Node] + + def replace_with(self, new_node: torch.fx.Node) -> None: + """ + Replace the matmul with the new node. + + ND-matmul is a sequence of reshape -> mm -> reshape in the graph. The + second reshape node is replaced with `new_node`. + + In addition, we ensure that the original mm node ends up with zero + users by replacing it with a reverse reshape of `new_node`. + """ + graph = new_node.graph + assert len(self.nodes) == 3 + mm_node = self.nodes[1] + output_reshape_node = self.nodes[2] + + assert mm_node.target == aten.mm.default + assert output_reshape_node.target == aten.reshape.default + + output_reshape_node.replace_all_uses_with(new_node) + if len(mm_node.users) > 1: + with graph.inserting_after(new_node): + new_mm_node = graph.call_function( + aten.reshape.default, + args=(new_node, list(mm_node.meta["val"].shape)), + ) + mm_node.replace_all_uses_with(new_mm_node) + + +def _find_consumer_matmuls(node: torch.fx.Node) -> List[Union[_2DMatmul, _NDMatmul]]: + """ + Find the matmuls that use `node` as the lhs argument. + This function effective normalizes 2D and ND matmuls. + """ + matmuls: List[Union[_2DMatmul, _NDMatmul]] = [] + + for user in node.users: + # ND matmuls + if user.target == aten.reshape.default: + for mm_node in user.users: + if mm_node.target != aten.mm.default: + continue + + B_node = cast(torch.fx.Node, mm_node.args[1]) + can_schedule, B_node_ancestors = _can_schedule_y_before_x(user, B_node) + if not can_schedule: + continue + + for reshape_node in mm_node.users: + if reshape_node.target != aten.reshape.default: + continue + + matmul_out_shape = torch.Size( + [ + *node.meta["val"].shape[:-1], + B_node.meta["val"].shape[-1], + ] + ) + if reshape_node.meta["val"].shape != matmul_out_shape: + continue + + matmuls.append( + _NDMatmul( + nodes=[user, mm_node, reshape_node], + B_node=B_node, + B_node_ancestors=B_node_ancestors, + ) + ) + # 2D matmuls + elif user.target == aten.mm.default: + B_node = cast(torch.fx.Node, user.args[1]) + can_schedule, B_node_ancestors = _can_schedule_y_before_x(user, B_node) + if not can_schedule: + continue + + matmuls.append( + _2DMatmul( + node=user, + B_node=B_node, + B_node_ancestors=B_node_ancestors, + ), + ) + return matmuls + + +def _find_all_gather_node_from_match(match) -> Tuple[torch.fx.Node, torch.fx.Node]: + """ + Processes match for ZeroDimAllGather and NonZeroDimAllGather. Returns the + all-gather node (all_gather_into_tensor.default) and the all-gather result + node (wait_tensor.default for gather_dim == 0 and aten.cat.default for + gather_dim == 1). This function effectively normalizes zero-dim and + non-zero-dim all_gather_tensor. + """ + # gather_dim == 0 + if len(match.nodes) == 2: + return match.nodes[0], match.nodes[1] + # gather_dim == 1 + ag_node = _filter_nodes_by_target( + match.nodes, + torch.ops._c10d_functional.all_gather_into_tensor.default, + )[0] + ag_res_node = _filter_nodes_by_target( + match.nodes, + aten.cat.default, + )[0] + shard_node = ag_node.args[0] + return ag_node, ag_res_node + + +def fuse_all_gather_matmul_zero_dim(match, shard, group_name): + fuse_all_gather_matmul(match, shard, 0, group_name) + + +def fuse_all_gather_matmul(match, shard, gather_dim, group_name): + """ + Fused the pattern + + A = all_gather_tensor(A_shard, gather_dim, group_name) + C_0 = torch.matmul(A, B_0) + C_1 = torch.matmul(A, B_1) + C_2 = torch.matmul(A, B_2) + ... + + into + + A, Cs = torch.ops.cuda_p2p.fused_all_gather_matmul( + A_shard, [B_0, B_1, B_2, ...], gather_dim, group_name, + ) + """ + if ( + not torch.distributed.is_available() + or not torch.distributed.is_nccl_available() + ): + return + + c10d = torch.ops._c10d_functional + from torch.distributed._cuda_p2p import is_cuda_p2p_group + from torch.distributed.distributed_c10d import _resolve_process_group + + if gather_dim >= len(shard.meta["val"].shape) - 1: + # Decomposing the matmul on the K dimension is not supported + return + + if not is_cuda_p2p_group(_resolve_process_group(group_name)): + return + + # Normalize zero-dim and non-zero-dim all_gather_tensor + ag_node, ag_res_node = _find_all_gather_node_from_match(match) + + # Find consumer matmuls for eligible for fusion + matmuls = _find_consumer_matmuls(ag_res_node) + if len(matmuls) == 0: + return + + shard_node = ag_node.args[0] + B_nodes = [matmul.B_node for matmul in matmuls] + + # Fuse the all_gather_tensor with the eligible matmuls + graph = ag_node.graph + with graph.inserting_before(ag_node): + fused_node = graph.call_function( + torch.ops.cuda_p2p.fused_all_gather_matmul.default, + args=(shard_node, B_nodes, gather_dim, group_name), + ) + new_ag_node = graph.call_function( + operator.getitem, + args=(fused_node, 0), + ) + new_out_nodes = graph.call_function( + operator.getitem, + args=(fused_node, 1), + ) + for idx, matmul in enumerate(matmuls): + new_out_node = graph.call_function( + operator.getitem, + args=(new_out_nodes, idx), + ) + matmul.replace_with(new_out_node) + ag_res_node.replace_all_uses_with(new_ag_node) + + # Raise ancestors of B that are topologically ordered between ag_res_node + # and the matmul above fused_node. _find_consumer_matmuls guarantees that + # ag_res_node is not an ancestor of B. + order = {node: idx for idx, node in enumerate(graph.nodes)} + nodes_to_raise = sorted( + {x for matmul in matmuls for x in matmul.B_node_ancestors}, + key=lambda x: order[x], + ) + for node in nodes_to_raise: + if order[node] > order[fused_node]: + fused_node.prepend(node) + + graph.eliminate_dead_code() + return + + +def fuse_matmul_reduce_scatter_zero_dim(match, rs_input, reduce_op, group_name): + fuse_matmul_reduce_scatter(match, rs_input, reduce_op, 0, group_name) + + +def fuse_matmul_reduce_scatter(match, rs_input, reduce_op, scatter_dim, group_name): + """ + Fused the pattern + + reduce_scatter_tensor(A @ B, scatter_dim, group_name) + + into + + torch.ops.cuda_p2p.fused_matmul_reduce_scatter( + A, B, scatter_dim, group_name, + ) + """ + if ( + not torch.distributed.is_available() + or not torch.distributed.is_nccl_available() + ): + return + + c10d = torch.ops._c10d_functional + from torch.distributed._cuda_p2p import is_cuda_p2p_group + from torch.distributed.distributed_c10d import _resolve_process_group + + if not is_cuda_p2p_group(_resolve_process_group(group_name)): + return + + # Currently fused_matmul_reduce_scatter doesn't return the matmul result, + # so we can't apply the fusion if the matmul result is used by multiple + # users. This is not a fundamental limitation of the fused op and can be + # addressed if needed. + if len(rs_input.users) != 1: + return + + # 2D matmul + if rs_input.target == aten.mm.default: + A_node, B_node = rs_input.args[0], rs_input.args[1] + # ND matmul + elif rs_input.target == aten.reshape.default: + mm_node = rs_input.args[0] + if mm_node.target != aten.mm.default or len(mm_node.users) != 1: + return + + A_node, B_node = mm_node.args[0], mm_node.args[1] + if A_node.target != aten.reshape.default: + return + A_node = A_node.args[0] + # Not matmul + else: + return + + rs_res_node = _filter_nodes_by_target(match.nodes, c10d.wait_tensor.default)[0] + if not _can_schedule_y_before_x(rs_res_node, B_node): + return + + graph = rs_res_node.graph + with graph.inserting_before(rs_res_node): + fused_node = graph.call_function( + torch.ops.cuda_p2p.fused_matmul_reduce_scatter.default, + args=(A_node, B_node, reduce_op, scatter_dim, group_name), + ) + rs_res_node.replace_all_uses_with(fused_node) + + order = {node: idx for idx, node in enumerate(graph.nodes)} + nodes_to_raise = sorted( + _find_ancestors(B_node), + key=lambda x: order[x], + ) + for node in nodes_to_raise: + if order[node] > order[fused_node]: + fused_node.prepend(node) + + graph.eliminate_dead_code() + + +def _register_passes(): + if ( + not torch.distributed.is_available() + or not torch.distributed.is_nccl_available() + ): + return + + c10d = torch.ops._c10d_functional + + # Matches funcol.all_gather_tensor with gather_dim == 0 + ZeroDimAllGather = CallFunction( + c10d.wait_tensor.default, + CallFunction( + c10d.all_gather_into_tensor.default, + KeywordArg("shard"), + Ignored(), + KeywordArg("group_name"), + ), + ) + + # Matches funcol.all_gather_tensor with gather_dim > 0 + # NOTE: this pattern may need to be updated if funcol.all_gather_tensor changes + NonZeroDimAllGather = CallFunction( + aten.cat.default, + ListOf( + CallFunction( + operator.getitem, + CallFunction( + aten.split.Tensor, + CallFunction( + c10d.wait_tensor.default, + CallFunction( + c10d.all_gather_into_tensor.default, + KeywordArg("shard"), + Ignored(), + KeywordArg("group_name"), + ), + ), + Ignored(), + _users=MULTIPLE, + ), + Ignored(), + ), + ), + KeywordArg("gather_dim"), + _users=MULTIPLE, + ) + + register_graph_pattern( + ZeroDimAllGather, + pass_dict=patterns, + )(fuse_all_gather_matmul_zero_dim) + + register_graph_pattern( + NonZeroDimAllGather, + pass_dict=patterns, + )(fuse_all_gather_matmul) + + # Matches funcol.reduce_scatter_tensor with scatter_dim == 0 + ZeroDimReduceScatter = CallFunction( + c10d.wait_tensor.default, + CallFunction( + c10d.reduce_scatter_tensor.default, + KeywordArg("rs_input"), + KeywordArg("reduce_op"), + Ignored(), + KeywordArg("group_name"), + ), + ) + + # Matches funcol.reduce_scatter_tensor with scatter_dim > 0 + # NOTE: this pattern may need to be updated if funcol.reduce_scatter_tensor + # changes + NonZeroDimReduceScatter = CallFunction( + c10d.wait_tensor.default, + CallFunction( + c10d.reduce_scatter_tensor.default, + CallFunction( + aten.cat.default, + ListOf( + CallFunction( + operator.getitem, + CallFunction( + aten.split.Tensor, + KeywordArg("rs_input"), + Ignored(), + KeywordArg("scatter_dim"), + _users=MULTIPLE, + ), + Ignored(), + ) + ), + ), + KeywordArg("reduce_op"), + Ignored(), + KeywordArg("group_name"), + ), + ) + + register_graph_pattern( + ZeroDimReduceScatter, + pass_dict=patterns, + )(fuse_matmul_reduce_scatter_zero_dim) + + register_graph_pattern( + NonZeroDimReduceScatter, + pass_dict=patterns, + )(fuse_matmul_reduce_scatter) + + +_register_passes() diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index b18577a02ffc..dd1900000f7c 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -43,6 +43,7 @@ from ..virtualized import V from .ddp_fusion import fuse_ddp_communication from .group_batch_fusion import group_batch_fusion_passes, POST_GRAD_FUSIONS +from .micro_pipeline_tp import patterns as micro_pipeline_tp_patterns from .pre_grad import is_same_dict, save_inductor_dict from .reinplace import reinplace_inplaceable_ops from .split_cat import POST_GRAD_PATTERNS @@ -103,6 +104,9 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool): f"{pattern_matcher_pass.pass_name}_post_grad" ] = upload_graph(gm.graph) + if config._micro_pipeline_tp: + micro_pipeline_tp_patterns.apply(gm) + if config._fuse_ddp_communication: fuse_ddp_communication( gm.graph, diff --git a/torch/distributed/_cuda_p2p/__init__.py b/torch/distributed/_cuda_p2p/__init__.py index 4d07bfcbf067..2c77bd375f34 100644 --- a/torch/distributed/_cuda_p2p/__init__.py +++ b/torch/distributed/_cuda_p2p/__init__.py @@ -91,6 +91,8 @@ def _create_cuda_p2p_group( def is_cuda_p2p_group(group: c10d.ProcessGroup) -> bool: + if _test_with_non_cuda_p2p_group: + return True if not c10d.is_nccl_available(): return False try: diff --git a/torch/testing/_internal/distributed/_tensor/common_dtensor.py b/torch/testing/_internal/distributed/_tensor/common_dtensor.py index 1012d065e7ad..2e4a183ea338 100644 --- a/torch/testing/_internal/distributed/_tensor/common_dtensor.py +++ b/torch/testing/_internal/distributed/_tensor/common_dtensor.py @@ -62,12 +62,12 @@ def forward(self, x): class MLPModule(nn.Module): - def __init__(self, device): + def __init__(self, device, bias: bool = True): super().__init__() torch.manual_seed(5) - self.net1 = nn.Linear(10, 16, device=device) + self.net1 = nn.Linear(10, 16, bias=bias, device=device) self.relu = nn.ReLU() - self.net2 = nn.Linear(16, 10, device=device) + self.net2 = nn.Linear(16, 10, bias=bias, device=device) def forward(self, x): return self.net2(self.relu(self.net1(x))) From 49048e7f26e3149ed6e4c003dcc9b6ee3b269291 Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Mon, 3 Jun 2024 22:00:11 -0700 Subject: [PATCH 313/706] [FSDP2] Fixed variable shadowing of `module` (#127776) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127776 Approved by: https://github.com/wanchaol ghstack dependencies: #127771 --- torch/distributed/_composable/fsdp/fully_shard.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torch/distributed/_composable/fsdp/fully_shard.py b/torch/distributed/_composable/fsdp/fully_shard.py index 3efb8f7afd85..337c9a7e40b8 100644 --- a/torch/distributed/_composable/fsdp/fully_shard.py +++ b/torch/distributed/_composable/fsdp/fully_shard.py @@ -128,10 +128,10 @@ def fully_shard( offload_policy, ) - # for dynamo - for module in managed_modules: - module._is_fsdp_managed_module = True # type: ignore[assignment] - module._fsdp_use_orig_params = True # type: ignore[assignment] + # For Dynamo + for managed_module in managed_modules: + managed_module._is_fsdp_managed_module = True # type: ignore[assignment] + managed_module._fsdp_use_orig_params = True # type: ignore[assignment] # Place FSDP leftmost for highest priority in the method resolution order cls = module.__class__ From db515b6ac7c875c39e9068ba81e72574e8b79e55 Mon Sep 17 00:00:00 2001 From: Jack Taylor <108682042+jataylo@users.noreply.github.com> Date: Tue, 4 Jun 2024 11:16:02 +0000 Subject: [PATCH 314/706] [ROCm] Fix error in torch.cuda initialisation if amdsmi is not available (#127528) Reported in https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/15874 When nvml_count is set via https://github.com/pytorch/pytorch/blob/9f73c65b8f644d599ff3ff53927b738cfbb7d191/torch/cuda/__init__.py#L834 If amdsmi is not available this will throw an error ``` File "python3.10/site-packages/torch/cuda/__init__.py", line 634, in _raw_device_count_amdsmi except amdsmi.AmdSmiException as e: NameError: name 'amdsmi' is not defined ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/127528 Approved by: https://github.com/jeffdaily, https://github.com/eqy, https://github.com/pruthvistony, https://github.com/atalman --- torch/cuda/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index ec4c0297a4b4..2b2fe32154b2 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -629,6 +629,8 @@ def parse_list_with_prefix(lst: str, prefix: str) -> List[str]: def _raw_device_count_amdsmi() -> int: + if not _HAS_PYNVML: # If amdsmi is not available + return -1 try: amdsmi.amdsmi_init() except amdsmi.AmdSmiException as e: @@ -659,6 +661,8 @@ def _raw_device_count_nvml() -> int: def _raw_device_uuid_amdsmi() -> Optional[List[str]]: from ctypes import byref, c_int, c_void_p, CDLL, create_string_buffer + if not _HAS_PYNVML: # If amdsmi is not available + return None try: amdsmi.amdsmi_init() except amdsmi.AmdSmiException: From fb696ef3aa34e20c0fef1c0210a397abd3ea5885 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Tue, 4 Jun 2024 04:39:27 -0700 Subject: [PATCH 315/706] Complete revamp of float/promotion sympy handling (#126905) At a high level, the idea behind this PR is: * Make it clearer what the promotion and int/float rules for various Sympy operations are. Operators that previously were polymorphic over int/float are now split into separate operators for clarity. We never do mixed int/float addition/multiplication etc in sympy, instead, we always promote to the appropriate operator. (However, equality is currently not done correctly.) * Enforce strict typing on ValueRanges: if you have a ValueRange for a float, the lower and upper MUST be floats, and so forth for integers. The story begins in **torch/utils/_sympy/functions.py**. Here, I make some changes to how we represent certain operations in sympy expressions: * FloorDiv now only supports integer inputs; to do float floor division, do a truediv and then a trunc. Additionally, we remove the divide out addition by gcd optimization, because sympy gcd is over fields and is willing to generate rationals (but rationals are bad for ValueRange strict typing). * ModularIndexing, LShift, RShift now assert they are given integer inputs. * Mod only supports integer inputs; eventually we will support FloatMod (left for later work, when we build out Sympy support for floating operations). Unfortunately, I couldn't assert integer inputs here, because of a bad interaction with sympy's inequality solver that is used by the offline solver * TrueDiv is split into FloatTrueDiv and IntTrueDiv. This allows for us to eventually generate accurate code for Python semantics IntTrueDiv, which is written in a special way to preserve precision when the inputs are >= 2**53 beyond what first coercing the integer to floats and then doing true division. * Trunc is split to TruncToFloat and TruncToInt. * Round is updated to return a float, not an int, making it consistent with the round op handler in Inductor. To get Python-style conversion to int, we call TruncToInt on the result. * RoundDecimal updated to consistently only ever return a float * Add ToFloat for explicit coercion to float (required so we can enforce strict ValueRanges typing) In **torch/__init__.py**, we modify SymInt and SymFloat to appropriately call into new bindings that route to these refined sympy operations. Also, we modify `torch.sym_min` and `torch.sym_max` to have promotion semantics (if one argument is a float, the return result is always a float), making them inconsistent with builtins.min/max, but possible to do type analysis without runtime information. We also need to introduce some new op handlers in **torch/_inductor/ops_handler.py**: * `to_int` for truncation to int64, directly corresponding to TruncToInt; this can be implemented by trunc and dtype, but with a dedicated handler it is more convenient for roundtripping in Sympy * `int_truediv` for Python-style integer true division, which has higher precision than casting to floats and then running `truediv` These changes have consequences. First, we need to make some administrative changes: * Actually wire up these Sympy functions from SymInt/SymFloat in **torch/fx/experimental/sym_node.py**, including the new promotion rules (promote2) * Add support for new Sympy functions in **torch/utils/_sympy/interp.py**, **torch/utils/_sympy/reference.py** * In particular, in torch.utils._sympy.reference, we have a strong preference to NOT do nontrivial compute, instead, everything in ops handler should map to a singular sympy function * TODO: I chose to roundtrip mod back to our Mod function, but I think I'm going to have to deal with the C/Python inconsistency this to fix tests here * Add printer support for the Sympy functions in **torch/_inductor/codegen/common.py**, **torch/_inductor/codegen/cpp_utils.py**, **torch/_inductor/codegen/triton.py**. `int_truediv` and mixed precision equality is currently not implemented soundly, so we will lose precision in codegen for large values. TODO: The additions here are not exhaustive yet * Update ValueRanges logic to use new sympy functions in **torch/utils/_sympy/value_ranges.py**. In general, we prefer to use the new Sympy function rather than try to roll things by hand, which is what was done previously for many VR analysis functions. In **torch/fx/experimental/symbolic_shapes.py** we need to make some symbolic reasoning adjustments: * Avoid generation of rational subexpressions by removing simplification of `x // y` into `floor(x / y)`. This simplification then triggers an addition simplification rule `(x + y) / c --> x / c + y / c` which is bad because x / c is a rational number now * `_assert_bound_is_rational` is no more, we no longer generate rational bounds * Don't intersect non-int value ranges with the `int_range` * Support more sympy Functions for guard SYMPY_INTERP * Assert the type of value range is consistent with the variable type The new asserts uncovered necessary bug fixes: * **torch/_inductor/codegen/cpp.py**, **torch/_inductor/select_algorithm.py**, **torch/_inductor/sizevars.py** - Ensure Wild/Symbol manually allocated in Inductor is marked `is_integer` so it's accepted to build expressions * **torch/_inductor/utils.py** - make sure you actually pass in sympy.Expr to these functions * **torch/_inductor/ir.py** - make_contiguous_strides_for takes int/SymInt, not sympy.Expr! * **torch/export/dynamic_shapes.py** - don't use infinity to represent int ranges, instead use sys.maxsize - 1 Because of the removal of some symbolic reasoning that produced rationals, some of our symbolic reasoning has gotten worse and we are unable to simplify some guards. Check the TODO at **test/test_proxy_tensor.py** Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/126905 Approved by: https://github.com/xadupre, https://github.com/lezcano --- c10/core/SymNodeImpl.h | 18 + test/dynamo/test_dynamic_shapes.py | 7 - test/dynamo/test_export.py | 3 +- test/dynamo/test_misc.py | 17 +- test/inductor/test_indexing.py | 72 +--- test/onnx/test_fx_to_onnx_with_onnxruntime.py | 8 +- test/test_dynamic_shapes.py | 208 +++------- test/test_proxy_tensor.py | 3 +- test/test_sympy_utils.py | 122 +++--- torch/__init__.py | 162 +++++++- torch/_export/serde/serialize.py | 9 +- torch/_inductor/bounds.py | 5 + torch/_inductor/codegen/common.py | 168 ++++++-- torch/_inductor/codegen/cpp.py | 4 +- torch/_inductor/codegen/cpp_utils.py | 45 +- torch/_inductor/codegen/triton.py | 58 ++- torch/_inductor/graph.py | 5 +- torch/_inductor/ir.py | 16 +- torch/_inductor/kernel/flex_attention.py | 5 +- torch/_inductor/lowering.py | 6 +- torch/_inductor/ops_handler.py | 60 ++- torch/_inductor/select_algorithm.py | 4 +- torch/_inductor/sizevars.py | 20 +- torch/_inductor/utils.py | 2 +- torch/_subclasses/fake_tensor.py | 2 +- torch/csrc/jit/python/init.cpp | 5 + torch/csrc/utils/python_symnode.h | 20 + torch/export/dynamic_shapes.py | 9 +- torch/fx/experimental/recording.py | 8 +- torch/fx/experimental/sym_node.py | 204 +++++++-- torch/fx/experimental/symbolic_shapes.py | 80 ++-- torch/fx/experimental/validator.py | 32 +- torch/utils/_sympy/functions.py | 389 ++++++++++++++---- torch/utils/_sympy/interp.py | 67 ++- torch/utils/_sympy/reference.py | 151 ++++--- torch/utils/_sympy/solve.py | 1 + torch/utils/_sympy/value_ranges.py | 275 +++++++++---- 37 files changed, 1603 insertions(+), 667 deletions(-) diff --git a/c10/core/SymNodeImpl.h b/c10/core/SymNodeImpl.h index 9ffab5065109..bb92b09775b7 100644 --- a/c10/core/SymNodeImpl.h +++ b/c10/core/SymNodeImpl.h @@ -49,15 +49,33 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target { virtual SymNode mul(const SymNode& other) { TORCH_CHECK(false, "NYI"); } + // NB: legacy, prefer float_truediv or int_truediv virtual SymNode truediv(const SymNode& other) { TORCH_CHECK(false, "NYI"); } + virtual SymNode float_truediv(const SymNode& other) { + return truediv(other); + } + virtual SymNode int_truediv(const SymNode& other) { + return truediv(other); + } + // NB: legacy, prefer float_pow or pow_by_natural virtual SymNode pow(const SymNode& other) { TORCH_CHECK(false, "NYI"); } + virtual SymNode float_pow(const SymNode& other) { + return pow(other); + } + virtual SymNode pow_by_natural(const SymNode& other) { + return pow(other); + } + // NB: legacy, prefer int_floordiv virtual SymNode floordiv(const SymNode& other) { TORCH_CHECK(false, "NYI"); } + virtual SymNode int_floordiv(const SymNode& other) { + return floordiv(other); + } virtual SymNode mod(const SymNode& other) { TORCH_CHECK(false, "NYI"); } diff --git a/test/dynamo/test_dynamic_shapes.py b/test/dynamo/test_dynamic_shapes.py index 175ed573391b..57671e620e56 100644 --- a/test/dynamo/test_dynamic_shapes.py +++ b/test/dynamo/test_dynamic_shapes.py @@ -78,13 +78,6 @@ def make_dynamic_cls(cls): del test if TEST_Z3: - # this only fails when z3 is available - unittest.expectedFailure( - # SymPy is incorrectly transforming 's0 / 6 == 0.5' into 'False'. - # Ref: https://github.com/sympy/sympy/issues/25146 - DynamicShapesReproTests.test_dynamic_shapes_float_guard_dynamic_shapes # noqa: F821 - ) - if not config.inline_inbuilt_nn_modules: # TODO model is somehow not being freed when z3 is available unittest.expectedFailure( diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py index 9f1417e23247..7ae0f839f6ff 100644 --- a/test/dynamo/test_export.py +++ b/test/dynamo/test_export.py @@ -2385,8 +2385,7 @@ def forward(self, x): with self.assertRaisesRegex( torch._dynamo.exc.UserError, "Constraints violated .*!(.*\n)*.*" - "by dim0 = 2\\*dim1(.*\n)*.*" - "Not all values of dim1 .* satisfy the generated guard 2 <= .* and .* <= 5(.*\n)*.*", + "Not all values of dim0 .* satisfy the generated guard 4 <= .* and .* <= 10(.*\n)*.*", ): torch.export.export(foo, (t,), dynamic_shapes=dynamic_shapes) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index bcb0fd18818e..dc2b9530f0dd 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -9309,7 +9309,7 @@ def test_shape_env_equal_create_symbolic_sizes_strides_storage_offset(self): > Left: {0: 0, 1: 1, 2: s1, 3: s0} > Right: {0: 0, 1: 1} ==> var_to_range: values don't match. - > Left: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)} + > Left: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)} > Right: {} ==> var_to_sources: values don't match. > Left: {s0: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=, idx=0)], s1: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=, idx=1)]} @@ -9343,7 +9343,7 @@ def test_shape_env_equal_unbacked(self): > Left: 2 > Right: 0 ==> var_to_range: values don't match. - > Left: {u0: ValueRanges(lower=-9223372036854775808, upper=9223372036854775807, is_bool=False), u1: ValueRanges(lower=0, upper=1, is_bool=False), zuf0: ValueRanges(lower=-oo, upper=oo, is_bool=False)} + > Left: {u0: ValueRanges(lower=-9223372036854775808, upper=9223372036854775807, is_bool=False, is_int=True, is_float=False), u1: ValueRanges(lower=0, upper=1, is_bool=False, is_int=True, is_float=False), zuf0: ValueRanges(lower=-oo, upper=oo, is_bool=False, is_int=False, is_float=True)} > Right: {} """, ) @@ -9420,8 +9420,8 @@ def test_shape_env_equal_evaluate_expr_replacement(self): > Left: {s0: 3} > Right: {} ==> var_to_range: values don't match. - > Left: {s0: ValueRanges(lower=3, upper=3, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)} - > Right: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)} + > Left: {s0: ValueRanges(lower=3, upper=3, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)} + > Right: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)} """, ) self._replay_and_check(main) @@ -9458,8 +9458,8 @@ def test_shape_env_equal_evaluate_expr_refinement(self): > Left: {_assert, ge, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} > Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} ==> var_to_range: values don't match. - > Left: {s0: ValueRanges(lower=3, upper=9223372036854775806, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)} - > Right: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)} + > Left: {s0: ValueRanges(lower=3, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)} + > Right: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)} """, ) self._replay_and_check(main) @@ -9484,10 +9484,7 @@ def test_shape_env_equal_runtime_assert(self): ShapeEnv not equal: field values don't match: ==> deferred_runtime_asserts: values don't match. - > Left: {u0: [Eq(Mod(u0, 3), 0)]} - > Right: {} -==> divisible: values don't match. - > Left: {Mod(u0, 3)} + > Left: {u0: [Eq(PythonMod(u0, 3), 0)]} > Right: {} ==> name_to_node: values don't match. > Left: {_assert, eq, mod, u0} diff --git a/test/inductor/test_indexing.py b/test/inductor/test_indexing.py index 299a619f9cd6..da527cfbb1d8 100644 --- a/test/inductor/test_indexing.py +++ b/test/inductor/test_indexing.py @@ -11,7 +11,12 @@ instantiate_parametrized_tests, parametrize, ) -from torch.utils._sympy.functions import FloorDiv, ModularIndexing, Round, RoundDecimal +from torch.utils._sympy.functions import ( + FloorDiv, + ModularIndexing, + RoundDecimal, + RoundToInt, +) class TestIndexingSimplification(InductorTestCase): @@ -168,21 +173,11 @@ def test_print_pow(self): common_cases = [ # expr, result - # Test exprs. - ( - s1 / (2 * s1 - 1) - 1 / (2 * s1 - 1), - lambda c, L: f"((-1{L})*({c}/((-1{L}) + (2{L}*foo)))) + (foo*({c}/((-1{L}) + (2{L}*foo))))", - ), - (s1 / (s2 - s3), lambda c, L: f"foo*({c}/(bar + ((-1{L})*baz)))"), # Test Pow directly. ( sympy.Pow(s1 + s2, 0), lambda _, L: f"1{L}", ), # note: simplified before _print_Pow - ( - sympy.Pow(s1 + s2, -3), - lambda c, _: f"{c}/((bar + foo)*(bar + foo)*(bar + foo))", - ), ] gpu_cases = common_cases + [ @@ -231,12 +226,10 @@ def test_print_ceil(self): self.assertExpectedInline(cexpr(expr), """std::ceil((1.0/2.0)*s1)""") def test_print_round(self): - expr = Round(sympy.Symbol("x", integer=True) / 2) + expr = RoundToInt(sympy.Symbol("x", integer=True) / 2) self.assertExpectedInline(pexpr(expr), """round((1/2)*x)""") self.assertExpectedInline(cexpr(expr), """std::lrint((1.0/2.0)*x)""") - self.assertExpectedInline( - texpr(expr), """libdevice.llrint((1/2)*x).to(tl.int64)""" - ) + self.assertExpectedInline(texpr(expr), """libdevice.llrint((1/2)*x)""") @parametrize("ndigits", [-1, 0, 1]) def test_print_round_decimal(self, ndigits): @@ -251,45 +244,18 @@ def test_print_round_decimal(self, ndigits): f"libdevice.nearbyint(1e{ndigits} * ((1/2)*x)) * 1e{-ndigits}", ) - expr = RoundDecimal(sympy.Symbol("x", integer=True), ndigits) - if ndigits >= 0: - for do_print in [pexpr, cexpr, texpr]: - self.assertEqual(do_print(expr), "x") - else: - self.assertEqual(pexpr(expr), f"round(x, {ndigits})") - for do_print in [cexpr, texpr]: - with self.assertRaisesRegex( - ValueError, "only non-negative ndigits are currently supported" - ): - do_print(expr) - def test_print_floor_div(self): - for integer in [True, False]: - s1 = sympy.Symbol("s1", integer=integer) - s2 = sympy.Symbol("s2", integer=integer) - expr = FloorDiv(s1, s2) - self.assertEqual(pexpr(expr), "(s1 // s2)") - if integer: - self.assertEqual(cexpr(expr), "c10::div_floor_integer(s1, s2)") - else: - self.assertEqual( - cexpr(expr), - "c10::div_floor_floating(static_cast(s1), static_cast(s2))", - ) - - for integer in [True, False]: - s1 = sympy.Symbol("s1", integer=integer) - s2 = sympy.S(-1) - expr = FloorDiv(s1, s2) - if integer: - self.assertEqual(pexpr(expr), "(-1)*s1") - self.assertEqual(cexpr(expr), "(-1L)*s1") - else: - self.assertEqual(pexpr(expr), "(s1 // (-1))") - self.assertEqual( - cexpr(expr), - "c10::div_floor_floating(static_cast(s1), static_cast((-1L)))", - ) + s1 = sympy.Symbol("s1", integer=True) + s2 = sympy.Symbol("s2", integer=True) + expr = FloorDiv(s1, s2) + self.assertEqual(pexpr(expr), "(s1 // s2)") + self.assertEqual(cexpr(expr), "c10::div_floor_integer(s1, s2)") + + s1 = sympy.Symbol("s1", integer=True) + s2 = sympy.S(-1) + expr = FloorDiv(s1, s2) + self.assertEqual(pexpr(expr), "(-1)*s1") + self.assertEqual(cexpr(expr), "(-1L)*s1") def test_print_Min_Max(self): cases = ( diff --git a/test/onnx/test_fx_to_onnx_with_onnxruntime.py b/test/onnx/test_fx_to_onnx_with_onnxruntime.py index b70bfbf9c4a7..0f0e01bc0dc2 100644 --- a/test/onnx/test_fx_to_onnx_with_onnxruntime.py +++ b/test/onnx/test_fx_to_onnx_with_onnxruntime.py @@ -158,8 +158,12 @@ def forward(self, x, y): torch.tensor([operator.sub(x.item(), y.item())]), torch.tensor([operator.mul(x.item(), y.item())]), torch.tensor([operator.truediv(x.item(), y.item())]), - torch.tensor([operator.floordiv(x.item(), y.item())]), - torch.tensor([operator.pow(x.item(), y.item())]), + # This requires torch.sym_float, probably easy to lower to + # ONNX but I don't know where to put it + # torch.tensor([operator.floordiv(x.item(), y.item())]), + # NB: abs so that the base and exponent are provably + # non-negative, so we don't generate runtime asserts + torch.tensor([operator.pow(abs(x.item()), abs(y.item()))]), torch.tensor([operator.abs(x.item())]), torch.tensor([operator.neg(x.item())]), torch.tensor([math.ceil(x.item())]), diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index d548e9df0707..82503b5866b5 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -205,15 +205,15 @@ def create_symtype(cls, pytype, shape_env, val, duck=True): # TODO: default duck to False -def create_symint(shape_env, i: int, duck=True): +def create_symint(shape_env, i: int, duck=True) -> SymInt: return create_symtype(SymInt, int, shape_env, i, duck=duck) -def create_symbool(shape_env, b: bool): +def create_symbool(shape_env, b: bool) -> SymBool: return create_symtype(SymBool, bool, shape_env, b) -def create_symfloat(shape_env, f: float): +def create_symfloat(shape_env, f: float) -> SymFloat: return create_symtype(SymFloat, float, shape_env, f) @@ -457,14 +457,16 @@ def test_sym_int(self): r = sym_int(a1 / 2) self.assertEqual(guard_int(r), 3) self.assertIsInstance(r, torch.SymInt, msg=type(r)) - self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(Trunc(s1/2), 3)""") + self.assertExpectedInline( + str(shape_env.guards[1][0]), """Eq(TruncToInt(IntTrueDiv(s1, 2)), 3)""" + ) a3 = create_symint(shape_env, 3) r = sym_int(2.0 * torch.sym_float(a3)) self.assertEqual(guard_int(r), 6) self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertExpectedInline( - str(shape_env.guards[2][0]), """Eq(Trunc(2.0*s2), 6)""" + str(shape_env.guards[2][0]), """Eq(TruncToInt(2.0*ToFloat(s2)), 6)""" ) def test_sym_sqrt(self): @@ -474,7 +476,7 @@ def test_sym_sqrt(self): self.assertEqual(r, 2) self.assertIsInstance(r, torch.SymFloat, msg=type(r)) self.assertExpectedInline( - str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_sqrt(s0), 2)""" + str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_sqrt(s0), 2.0)""" ) def test_sym_floor(self): @@ -483,11 +485,17 @@ def test_sym_floor(self): r = math.floor(a0 / 2) self.assertEqual(r, 2) self.assertIsInstance(r, torch.SymInt, msg=type(r)) - self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(floor(s0/2), 2)""") + self.assertExpectedInline( + str(shape_env.guards[0][0]), + """Eq(floor(IntTrueDiv(s0, 2)), 2)""", + ) r = math.floor(3.0 * a0) self.assertEqual(r, 15) self.assertIsInstance(r, torch.SymInt, msg=type(r)) - self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(3*s0, 15)""") + self.assertExpectedInline( + str(shape_env.guards[1][0]), + """Eq(floor(3.0*ToFloat(s0)), 15)""", + ) def test_sym_trunc(self): shape_env = ShapeEnv() @@ -495,12 +503,14 @@ def test_sym_trunc(self): r = math.trunc(a0 / 2) self.assertEqual(r, 2) self.assertIsInstance(r, torch.SymInt, msg=type(r)) - self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(Trunc(s0/2), 2)""") + self.assertExpectedInline( + str(shape_env.guards[0][0]), """Eq(TruncToInt(IntTrueDiv(s0, 2)), 2)""" + ) r = torch.sym_int(torch.sym_sqrt(a0)) self.assertEqual(r, 2) self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertExpectedInline( - str(shape_env.guards[1][0]), """Eq(Trunc(OpaqueUnaryFn_sqrt(s0)), 2)""" + str(shape_env.guards[1][0]), """Eq(TruncToInt(OpaqueUnaryFn_sqrt(s0)), 2)""" ) def test_sym_ceil(self): @@ -510,12 +520,17 @@ def test_sym_ceil(self): self.assertEqual(r, 3) self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertExpectedInline( - str(shape_env.guards[0][0]), """Eq(ceiling(s0/2), 3)""" + str(shape_env.guards[0][0]), + """Eq(ceiling(IntTrueDiv(s0, 2)), 3)""", ) - r = math.floor(3.0 * a0) + r1 = 3.0 * a0 + r = math.floor(r1) self.assertEqual(r, 15) self.assertIsInstance(r, torch.SymInt, msg=type(r)) - self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(3*s0, 15)""") + self.assertExpectedInline( + str(shape_env.guards[1][0]), + """Eq(floor(3.0*ToFloat(s0)), 15)""", + ) def test_sym_ite(self): shape_env = ShapeEnv() @@ -962,8 +977,14 @@ def test_ephemeral_source_unified_with_non_ephemeral_source(self): ) class TestSymNumberMagicMethods(TestCase): def _do_test(self, fn, inp1, inp2, shape_env, is_unary_fn): + with self.subTest(fn=fn, inp1=inp1, inp2=inp2, is_unary_fn=is_unary_fn): + return self._do_test2(fn, inp1, inp2, shape_env, is_unary_fn) + + def _do_test2(self, fn, inp1, inp2, shape_env, is_unary_fn): # Helper function # NB: don't use one as that will get specialized + # TODO: We don't have to circuitously create the float, can just + # create a symfloat directly seed_node = (create_symint(shape_env, 2) / 2.0).node bool_seed_node = (create_symint(shape_env, 2) == 2).node @@ -976,27 +997,42 @@ def get_sym_inp(inp): else: return torch.SymFloat(to_node(seed_node, inp)) + if fn == "float_pow": + if inp1 < 0: + return + + if fn == "pow_by_natural": + if isinstance(inp1, float) or isinstance(inp2, float): + return + if inp2 < 0: + return + def maybe_xfail(inp1, inp2): if fn == "sym_sqrt" and inp1 < 0: # ValueError: math domain error return self.assertRaises((ValueError,)) - elif fn in ("truediv", "floordiv", "mod") and inp2 == 0: + elif ( + fn in ("float_truediv", "int_truediv", "int_floordiv", "mod") + and inp2 == 0 + ): # ZeroDivisionError: division by zero return self.assertRaises((ZeroDivisionError,)) - elif fn == "pow" and inp1 == 0 and inp2 < 0: + elif fn in ["float_pow", "pow_by_natural"] and inp1 == 0 and inp2 < 0: # ZeroDivisionError: 0.0 cannot be raised to a negative power return self.assertRaises((ZeroDivisionError,)) elif ( - fn == "pow" + # TODO: dear catastrophe waitress, + # this doesn't work + fn in ["float_pow", "pow_by_natural"] and inp1 < 0 - and inp2 in (2.5, -2.5) and ( - type(inp1) in (SymFloat, SymInt) or type(inp2) in (SymFloat, SymInt) + type(inp1) is (SymInt, SymFloat) or type(inp2) is (SymInt, SymFloat) ) + and (type(inp1) is (SymFloat, float) or type(inp2) is (SymFloat, float)) ): # Complex result, which we do not support: # TypeError: Cannot convert complex to float - return self.assertRaises((TypeError,)) + return self.assertRaises((RuntimeError,)) elif fn in ("lshift", "rshift") and not ( isinstance(inp1, (SymInt, int)) and isinstance(inp2, (SymInt, int)) ): @@ -1080,6 +1116,9 @@ def test_method(self, fn, first_type, second_type): ) and fn in sym_node.only_float_magic_methods: self.skipTest(f"{fn} is not an int method") + if second_type == "float" and fn in ["mod"]: + self.skipTest(f"{fn} only handles int") + is_unary_fn = fn in sym_node.unary_methods or fn == "round" # Second argument is ignored for unary function. So only run for one type if is_unary_fn and second_type == "float": @@ -1251,112 +1290,15 @@ def yield_test_cases(values, negate=True): yield (-x, -y) def test_floordiv_float_int(self): - values = ( - (2.5, 2.1), - (2.1, 2.5), - (2.0, 2.1), - (7, 2.5), - (2.1, 7), - (7, 2), - ) + values = ((7, 2),) for x, y in TestFloorDiv.yield_test_cases(values): self.assertEqual( TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y) ) - def test_floordiv_bool(self): - values = ( - (False, True), - (True, 2.5), - (2.5, True), - (False, 7), - (7, True), - ) - - for x, y in TestFloorDiv.yield_test_cases(values, negate=False): - # Compares to int since our FloorDiv has no bool support - self.assertEqual( - TestFloorDiv.python_floordiv(x, y), - TestFloorDiv.torch_floordiv(int(x), int(y)), - ) - # Tests that our impl throws - self.assertRaisesRegex( - TypeError, - ( - rf"unsupported operand type\(s\) for //: " - rf"'{type(sympy.sympify(x)).__name__}' and '{type(sympy.sympify(y)).__name__}'" - rf", expected integer or real" - ), - lambda: TestFloorDiv.torch_floordiv(x, y), - ) - - def test_floordiv_complex(self): - values = ( - (1.5 + 2.5j, 1.3 + 3.5j), - (1.5 + 2.5j, 2.5), - (2.5, 1.5 + 2.5j), - (1.5 + 2.5j, 7), - (7, 1.5 + 2.5j), - ) - - for x, y in TestFloorDiv.yield_test_cases(values): - # We don't test error messages to avoid depending on Python - # interpreter version - self.assertRaises(TypeError, lambda: TestFloorDiv.python_floordiv(x, y)) - self.assertRaisesRegex( - TypeError, - ( - rf"unsupported operand type\(s\) for //: " - rf"'{type(sympy.sympify(x)).__name__}' and '{type(sympy.sympify(y)).__name__}'" - rf", expected integer or real" - ), - lambda: TestFloorDiv.torch_floordiv(x, y), - ) - - def test_floordiv_div_by_zero(self): - values = ( - (2.5, 0), - (2.1, 0.0), - (2.3, sympy.Symbol("s", zero=True)), - ) - - for x, y in TestFloorDiv.yield_test_cases(values, negate=False): - # We don't test error messages to avoid depending on Python - # interpreter version - if type(y) is not sympy.Symbol: - self.assertRaises( - ZeroDivisionError, lambda: TestFloorDiv.python_floordiv(x, y) - ) - self.assertRaisesRegex( - ZeroDivisionError, - "division by zero", - lambda: TestFloorDiv.torch_floordiv(x, y), - ) - - def test_floordiv_zero_base(self): - values = ( - (0, 2.5), - (0.0, 2.1), - (sympy.Symbol("s", zero=True), 2.3), - ) - - for x, y in TestFloorDiv.yield_test_cases(values, negate=False): - if type(x) is not sympy.Symbol: - self.assertEqual( - TestFloorDiv.python_floordiv(x, y), - TestFloorDiv.torch_floordiv(x, y), - ) - else: - self.assertEqual(0, TestFloorDiv.torch_floordiv(x, y)) - def test_floordiv_div_by_one(self): - values = ( - (2.5, 1), - (2.1, 1.0), - (2, 1.0), - (2, 1), - ) + values = ((2, 1),) for x, y in TestFloorDiv.yield_test_cases(values): self.assertEqual( @@ -1367,12 +1309,7 @@ def test_floordiv_simplify(self): # Tests how we simplify or evaluate FloorDiv without free variables shape_env = ShapeEnv() result = 21 - exprs = ( - 7 * FloorDiv(6, 2), - 7 * FloorDiv(6.28, 2), - 7 * FloorDiv(6.28, 2.0), - 7 * FloorDiv(6.28, (FloorDiv(6.28, 3.14))), - ) + exprs = (7 * FloorDiv(6, 2),) for expr in exprs: self.assertEqual(expr, result) @@ -1382,33 +1319,10 @@ def test_floordiv_simplify(self): self.assertEqual(shape_env.simplify(expr), result) self.assertEqual(shape_env.evaluate_expr(expr), result) - def test_floordiv_simplify_rational(self): - result = 21 - - a = sympy.Symbol("a", integer=True) - b = sympy.Symbol("b") - - cases = [ - (FloorDiv(a, sympy.Rational(1, 8)), 8 * a), - (FloorDiv(b, sympy.Rational(1, 8)), sympy.floor(8 * b)), - ] - - for expr, expected in cases: - self.assertEqual(expr, expected) - def test_floordiv_assumptions(self): - # We define two Symbols (with different names) for each type to make - # sure the behavior is consistent regardless of whether both arguments - # are the same object or not. cases = ( sympy.Symbol("i1", integer=True), sympy.Symbol("i2", integer=True), - sympy.Symbol("r1", real=True), - sympy.Symbol("r2", real=True), - sympy.Symbol("c1", complex=True, real=False, integer=False), - sympy.Symbol("c2", complex=True, real=False, integer=False), - sympy.Symbol("s1"), - sympy.Symbol("s2"), ) for base, divisor in itertools.product(cases, repeat=2): diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index c7b2e51ced20..04483ffba0fc 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1618,7 +1618,8 @@ def f(a): self.assertExpectedInline(r, """\ def forward(self, a_1): sym_size_int = torch.ops.aten.sym_size.int(a_1, 0) - pow_1 = sym_size_int ** 0.5; sym_size_int = None + sym_float = torch.sym_float(sym_size_int); sym_size_int = None + pow_1 = sym_float ** 0.5; sym_float = None div = torch.ops.aten.div.Tensor(a_1, pow_1); a_1 = pow_1 = None return div""") diff --git a/test/test_sympy_utils.py b/test/test_sympy_utils.py index c5da8f7fc0da..8b16b2c620fd 100644 --- a/test/test_sympy_utils.py +++ b/test/test_sympy_utils.py @@ -36,7 +36,12 @@ "floor", "ceil", ] -BINARY_OPS = ["truediv", "div", "floordiv", "truncdiv", "add", "mul", "sub", "pow", "minimum", "maximum", "mod"] +BINARY_OPS = [ + "truediv", "floordiv", + # "truncdiv", # TODO + # NB: pow is float_pow + "add", "mul", "sub", "pow", "pow_by_natural", "minimum", "maximum", "mod" +] UNARY_BOOL_OPS = ["not_"] BINARY_BOOL_OPS = ["or_", "and_"] @@ -81,16 +86,24 @@ def valid_unary(fn, v): def valid_binary(fn, a, b): if fn == "pow" and ( + # sympy will expand to x*x*... for integral b; don't do it if it's big b > 4 - or ( # sympy will expand to x*x*... for integral b; don't do it if it's big - a <= 0 and b == -1 - ) - or (a == b == 0) # no imaginary numbers # 0**0 is undefined + # no imaginary numbers + or a <= 0 + # 0**0 is undefined + or (a == b == 0) ): return False - elif fn == "mod" and b == 0: + elif fn == "pow_by_natural" and ( + # sympy will expand to x*x*... for integral b; don't do it if it's big + b > 4 + or b < 0 + or (a == b == 0) + ): return False - elif (fn == "div" or fn == "truediv") and b == 0: + elif fn == "mod" and (a < 0 or b <= 0): + return False + elif (fn in ["div", "truediv", "floordiv"]) and b == 0: return False return True @@ -130,27 +143,26 @@ def test_pow_half(self): ValueRangeAnalysis.pow(ValueRanges.unknown(), ValueRanges.wrap(0.5)) @parametrize("fn", BINARY_OPS) - @parametrize("dtype_a", ("int", "float")) - @parametrize("dtype_b", ("int", "float")) - def test_binary_ref(self, fn, dtype_a, dtype_b): + @parametrize("dtype", ("int", "float")) + def test_binary_ref(self, fn, dtype): to_dtype = {"int": sympy.Integer, "float": sympy.Float} - dtype_a = to_dtype[dtype_a] - dtype_b = to_dtype[dtype_b] + # Don't test float on int only methods + if dtype == "float" and fn in ["pow_by_natural", "mod"]: + return + dtype = to_dtype[dtype] for a, b in itertools.product(CONSTANTS, repeat=2): if not valid_binary(fn, a, b): continue - a = dtype_a(a) - b = dtype_b(b) + a = dtype(a) + b = dtype(b) with self.subTest(a=a, b=b): r = getattr(ValueRangeAnalysis, fn)(a, b) if r == ValueRanges.unknown(): continue ref_r = getattr(ReferenceAnalysis, fn)(a, b) - # sympy.floordiv does 1.0 // 1.0 == 1 rather than 1.0. wtf - if fn != "floordiv": - self.assertEqual(r.lower.is_integer, r.upper.is_integer) - self.assertEqual(ref_r.is_integer, r.upper.is_integer) + self.assertEqual(r.lower.is_integer, r.upper.is_integer) + self.assertEqual(ref_r.is_integer, r.upper.is_integer) self.assertEqual(r.lower, r.upper) self.assertEqual(ref_r, r.lower) @@ -200,7 +212,8 @@ def test_binary_bool_ref_range(self, fn): @parametrize("fn", UNARY_OPS) def test_unary_ref_range(self, fn): - vals = [-sympy.oo, *CONSTANTS, sympy.oo] + # TODO: bring back sympy.oo testing for float unary fns + vals = CONSTANTS for a in generate_range(vals): with self.subTest(a=a): ref_r = getattr(ValueRangeAnalysis, fn)(a) @@ -216,40 +229,26 @@ def test_unary_ref_range(self, fn): # This takes about 4s for all the variants @parametrize("fn", BINARY_OPS + COMPARE_OPS) def test_binary_ref_range(self, fn): - vals = [-sympy.oo, *LESS_CONSTANTS, sympy.oo] + # TODO: bring back sympy.oo testing for float unary fns + vals = LESS_CONSTANTS for a, b in itertools.product(generate_range(vals), repeat=2): # don't attempt pow on exponents that are too large (but oo is OK) if fn == "pow" and b.upper > 4 and b.upper != sympy.oo: continue with self.subTest(a=a, b=b): - ref_r = getattr(ValueRangeAnalysis, fn)(a, b) for a0, b0 in itertools.product(LESS_CONSTANTS, repeat=2): if a0 not in a or b0 not in b: continue if not valid_binary(fn, a0, b0): continue with self.subTest(a0=a0, b0=b0): + ref_r = getattr(ValueRangeAnalysis, fn)(a, b) r = getattr(ReferenceAnalysis, fn)( sympy.Integer(a0), sympy.Integer(b0) ) if r.is_finite: self.assertIn(r, ref_r) - def test_rational_bounds(self): - # Repro from https://github.com/pytorch/pytorch/issues/105097 - from sympy import floor, Eq - shape_0 = sympy.Symbol('shape_0', positive=True, integer=True) - new_expr = ( - Eq(30 * floor(4 * ((shape_0 + 1) // 96) * - ((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646)) / 647 + - 2584 * ((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646)) / 647), - 2880 * floor(((shape_0 + 1) // 96) * - ((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646)) / 15528 + - 323 * ((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646)) / 7764))) - new_range_env = {shape_0: ValueRanges(lower=1, upper=190)} - self.assertTrue(new_expr.subs({shape_0: 95})) - self.assertIn(True, sympy_interp(ValueRangeAnalysis, new_range_env, new_expr)) - class TestSympyInterp(TestCase): @parametrize("fn", UNARY_OPS + BINARY_OPS + UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS) @@ -258,7 +257,13 @@ def test_interp(self, fn): if fn in ("div", "truncdiv", "minimum", "maximum", "mod"): return - from sympy.abc import x, y + is_integer = None + if fn == "pow_by_natural": + is_integer = True + + x = sympy.Dummy('x', integer=is_integer) + y = sympy.Dummy('y', integer=is_integer) + vals = CONSTANTS if fn in {*UNARY_BOOL_OPS, *BINARY_BOOL_OPS}: vals = [True, False] @@ -300,29 +305,17 @@ def test_python_interp_fx(self, fn): if fn in {*BINARY_OPS, *BINARY_BOOL_OPS, *COMPARE_OPS}: arity = 2 - from sympy.abc import x, y + is_integer = None + if fn == "pow_by_natural": + is_integer = True + + x = sympy.Dummy('x', integer=is_integer) + y = sympy.Dummy('y', integer=is_integer) symbols = [x] if arity == 2: symbols = [x, y] - # Workaround mpf from symbol error - if fn == "minimum": - sympy_expr = sympy.Min(x, y) - elif fn == "maximum": - sympy_expr = sympy.Max(x, y) - else: - sympy_expr = getattr(ReferenceAnalysis, fn)(*symbols) - - if arity == 1: - def trace_f(px): - return sympy_interp(PythonReferenceAnalysis, {x: px}, sympy_expr) - else: - def trace_f(px, py): - return sympy_interp(PythonReferenceAnalysis, {x: px, y: py}, sympy_expr) - - gm = fx.symbolic_trace(trace_f) - for args in itertools.product(vals, repeat=arity): if arity == 1 and not valid_unary(fn, *args): continue @@ -330,11 +323,28 @@ def trace_f(px, py): continue if fn == "truncdiv" and args[1] == 0: continue - elif fn == "pow" and (args[0] == 0 and args[1] <= 0): + elif fn in ("pow", "pow_by_natural") and (args[0] == 0 and args[1] <= 0): continue elif fn == "floordiv" and args[1] == 0: continue with self.subTest(args=args): + # Workaround mpf from symbol error + if fn == "minimum": + sympy_expr = sympy.Min(x, y) + elif fn == "maximum": + sympy_expr = sympy.Max(x, y) + else: + sympy_expr = getattr(ReferenceAnalysis, fn)(*symbols) + + if arity == 1: + def trace_f(px): + return sympy_interp(PythonReferenceAnalysis, {x: px}, sympy_expr) + else: + def trace_f(px, py): + return sympy_interp(PythonReferenceAnalysis, {x: px, y: py}, sympy_expr) + + gm = fx.symbolic_trace(trace_f) + self.assertEqual( sympy_interp(PythonReferenceAnalysis, dict(zip(symbols, args)), sympy_expr), gm(*args) diff --git a/torch/__init__.py b/torch/__init__.py index 18f1752019ec..dfb1da76739d 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -316,6 +316,75 @@ def __index__(self): # Magic methods installed by torch.fx.experimental.sym_node + def __round__(self, ndigits=None): + return self + + def __truediv__(self, other): + if isinstance(other, (builtins.float, SymFloat)): + return sym_float(self).__float_truediv__(other) + if not isinstance(other, (builtins.int, SymInt)): + return NotImplemented + return self.__int_truediv__(other) + + def __rtruediv__(self, other): + if isinstance(other, (builtins.float, SymFloat)): + return sym_float(self).__rfloat_truediv__(other) + if not isinstance(other, (builtins.int, SymInt)): + return NotImplemented + return self.__rint_truediv__(other) + + def __floordiv__(self, other): + if isinstance(other, (builtins.float, SymFloat)): + return torch.sym_float(math.floor(sym_float(self) / other)) + if not isinstance(other, (builtins.int, SymInt)): + return NotImplemented + return self.__int_floordiv__(other) + + def __rfloordiv__(self, other): + if isinstance(other, (builtins.float, SymFloat)): + return torch.sym_float(math.floor(other / sym_float(self))) + if not isinstance(other, (builtins.int, SymInt)): + return NotImplemented + return self.__rint_floordiv__(other) + + # nb: complex is impossible to handle correctly lol, with + # negative base and integral float need to diverge semantics and + # just always return complex. Neener neener pretend this problem + # doesn't exist + def __pow__(self, other): + if isinstance(other, (builtins.float, SymFloat)): + return sym_float(self).__pow__(other) + if not isinstance(other, (builtins.int, SymInt)): + return NotImplemented + # Guards! This guard is necessary because we need to know it to + # determine the output type of this operation + if other >= 0: + return self.__pow_by_natural__(other) + else: + # Mercifully, when the exponent is negative, Python just promotes + # to doubles and does a float pow: + # + # if (Py_SIZE(b) < 0 && c == NULL) { + # /* if exponent is negative and there's no modulus: + # return a float. This works because we know + # that this calls float_pow() which converts its + # arguments to double. */ + # Py_DECREF(a); + # Py_DECREF(b); + # return PyFloat_Type.tp_as_number->nb_power(v, w, x); + # } + return sym_float(self).__pow__(sym_float(other)) + + def __rpow__(self, other): + if isinstance(other, (builtins.float, SymFloat)): + return sym_float(self).__rpow__(other) + if not isinstance(other, (builtins.int, SymInt)): + return NotImplemented + if self >= 0: # self is exponent + return self.__rpow_by_natural__(other) + else: + return sym_float(self).__rpow__(sym_float(other)) + def __eq__(self, other: object) -> builtins.bool: raise AssertionError("type stub not overridden") @@ -337,6 +406,24 @@ def __add__(self, other) -> "SymInt": def __mul__(self, other) -> "SymInt": raise AssertionError("type stub not overridden") + def __pow_by_natural__(self, other) -> "SymInt": + raise AssertionError("type stub not overridden") + + def __rpow_by_natural__(self, other) -> "SymInt": + raise AssertionError("type stub not overridden") + + def __int_truediv__(self, other) -> "SymFloat": + raise AssertionError("type stub not overridden") + + def __rint_truediv__(self, other) -> "SymFloat": + raise AssertionError("type stub not overridden") + + def __int_floordiv__(self, other) -> "SymFloat": + raise AssertionError("type stub not overridden") + + def __rint_floordiv__(self, other) -> "SymFloat": + raise AssertionError("type stub not overridden") + def __sym_max__(self, other): raise AssertionError("type stub not overridden") @@ -371,9 +458,43 @@ def __init__(self, node): # class has a field named node that stores SymNode self.node = node + def __truediv__(self, other): + if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): + return NotImplemented + return self.__float_truediv__(sym_float(other)) + + def __rtruediv__(self, other): + if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): + return NotImplemented + return self.__rfloat_truediv__(sym_float(other)) + + def __floordiv__(self, other): + if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): + return NotImplemented + return torch.sym_float(math.floor(self / sym_float(other))) + + def __rfloordiv__(self, other): + if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): + return NotImplemented + return torch.sym_float(math.floor(sym_float(other) / self)) + def __bool__(self): return self.node.bool_() + # Symbolic power does NOT work with negative base, this is to avoid + # potential complex outputs + def __pow__(self, other): + if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): + return NotImplemented + torch._check(self >= 0) + return self.__float_pow__(other) + + def __rpow__(self, other): + if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): + return NotImplemented + torch._check(other >= 0) + return self.__rfloat_pow__(other) + # Magic methods installed by torch.fx.experimental.sym_node def __eq__(self, other: object) -> builtins.bool: @@ -391,6 +512,18 @@ def __le__(self, other) -> builtins.bool: def __ge__(self, other) -> builtins.bool: raise AssertionError("type stub not overridden") + def __float_pow__(self, other) -> "SymFloat": + raise AssertionError("type stub not overridden") + + def __rfloat_pow__(self, other) -> "SymFloat": + raise AssertionError("type stub not overridden") + + def __float_truediv__(self, other) -> "SymFloat": + raise AssertionError("type stub not overridden") + + def __rfloat_truediv__(self, other) -> "SymFloat": + raise AssertionError("type stub not overridden") + def __trunc__(self): raise AssertionError("type stub not overridden") @@ -524,7 +657,12 @@ def sym_int(a): return py_int(a) # type: ignore[operator] def sym_max(a, b): - """ SymInt-aware utility for max().""" + """ + SymInt-aware utility for max which avoids branching on a < b. + Unlike builtins.max(), this only works for int/float, and it always + promotes to float if any argument is float (unlike builtins.max, which + will faithfully preserve the type of the input argument). + """ from .overrides import has_torch_function, handle_torch_function if has_torch_function((a, b)): @@ -532,14 +670,19 @@ def sym_max(a, b): if isinstance(a, (SymInt, SymFloat)): return a.__sym_max__(b) elif isinstance(b, (SymInt, SymFloat)): - # NB: If you actually care about preserving output type exactly - # if you do something like max(0, 0.0), it is NOT sound to treat - # min/max as commutative + # Due to promotion semantics, this is operator is commutative: + # max(1, 1.0) === max(1.0, 1) === 1.0 return b.__sym_max__(a) - return builtins.max(a, b) # type: ignore[operator] + # TODO: Probably can make bool work too, just lazy + assert isinstance(a, (builtins.int, builtins.float)), type(a) + assert isinstance(b, (builtins.int, builtins.float)), type(b) + if isinstance(a, builtins.float) or isinstance(b, builtins.float): + return builtins.float(builtins.max(a, b)) + else: + return builtins.max(a, b) def sym_min(a, b): - """ SymInt-aware utility for max().""" + """ SymInt-aware utility for min().""" from .overrides import has_torch_function, handle_torch_function if has_torch_function((a, b)): @@ -548,7 +691,12 @@ def sym_min(a, b): return a.__sym_min__(b) elif isinstance(b, (SymInt, SymFloat)): return b.__sym_min__(a) - return builtins.min(a, b) # type: ignore[operator] + assert isinstance(a, (builtins.int, builtins.float)), type(a) + assert isinstance(b, (builtins.int, builtins.float)), type(b) + if isinstance(a, builtins.float) or isinstance(b, builtins.float): + return builtins.float(builtins.min(a, b)) + else: + return builtins.min(a, b) # Drop in replacement for math.sqrt, math.sin, math.cos etc current_module = sys.modules[__name__] diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index 8d6dc939fb5c..9a92c238f950 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -1474,10 +1474,15 @@ def deserialize_sym_int(self, s: SymInt) -> Union[int, torch.SymInt]: # Here we force symbols corresponding to SymInts to be at least integers. # Otherwise some expressions that the shape env would otherwise evaluate to False, # e.g., 2*s = 9, can have rational solutions, e.g., 9/2. + # TODO: This is HIGHLY SUSPICIOUS ezyang(May 2024) sym = sym.subs( {s: sympy.Symbol(s.name, integer=True) for s in sym.free_symbols} ) - if isinstance(sym, sympy.Symbol): + # We need to check if the symbol has already been allocated, + # self.symbol_name_to_symbol is not enough because the + # integer-ification of symbols can induce simplification; + # e.g., (2**s0 + 1) // 2 --> s0 when we know s0 is integral + if isinstance(sym, sympy.Symbol) and sym not in self.shape_env.var_to_val: self.symbol_name_to_symbol[val.expr_str] = sym if hint is not None: self.shape_env.add_var_to_val(sym, hint) @@ -1496,7 +1501,7 @@ def deserialize_sym_int(self, s: SymInt) -> Union[int, torch.SymInt]: free_symbols = sym.free_symbols for s in free_symbols: if s.name not in self.symbol_name_to_symbol: - self.symbol_name_to_symbol[s.name] = s + self.symbol_name_to_symbol[s.name] = s # type: ignore[assignment] if vr := self.symbol_name_to_range.get(s.name): self.shape_env.constrain_symbol_range( s, diff --git a/torch/_inductor/bounds.py b/torch/_inductor/bounds.py index 4640ec4dce6b..212b79e35bf9 100644 --- a/torch/_inductor/bounds.py +++ b/torch/_inductor/bounds.py @@ -1,3 +1,4 @@ +import logging import operator from functools import partial from typing import Any, Callable, Dict @@ -11,6 +12,9 @@ from .virtualized import V +log = logging.getLogger(__name__) + + class BoundVars: """ Performs Value Range Analysis on LoopBody's fx graph by calling BoundVars.run() @@ -55,6 +59,7 @@ def get_bounds(self) -> Dict[torch.fx.Node, ValueRanges[Expr]]: with V.set_ops_handler(ValueRangeAnalysis()): interpreter = InterpreterShim(self.loop_body.root_block.graph, submodules) + log.debug("get_bounds:\n%s", self.loop_body.root_block.graph) interpreter.run(V.get_ops_handler(), initial_env=self._bounds) return self._bounds diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index f7b3e7a45d6e..9f4783a8fc59 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -340,6 +340,8 @@ def propagate_scheduler_node(cls, node): DataTypePropagation.propagate_loopbody(node._body) +# This printer contains rules that are supposed to be generic for both C/C++ and +# Python class ExprPrinter(Printer): @staticmethod def paren(string): @@ -369,12 +371,6 @@ def all_in_parens(string): return string return f"({string})" - def _print_Infinity(self, expr): - return "math.inf" - - def _print_NegativeInfinity(self, expr): - return "-math.inf" - def _print_Relational(self, expr): return f" {expr.rel_op} ".join(map(self.paren, map(self._print, expr.args))) @@ -384,11 +380,14 @@ def _print_Mul(self, expr): def _print_Add(self, expr): return " + ".join(map(self.paren, map(self._print, expr.args))) + # NB: this is OK to put here, because Mod is only defined for positive + # numbers, and so across C/Python its behavior is consistent def _print_Mod(self, expr): return " % ".join(map(self.paren, map(self._print, expr.args))) - def _print_FloorDiv(self, expr): - raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}") + def _print_FloatTrueDiv(self, expr): + lhs, rhs = expr.args + return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}" def _print_CleanDiv(self, expr): return self._print_FloorDiv(expr) @@ -399,10 +398,84 @@ def _print_GreaterThan(self, expr): # Go figure... return " >= ".join(map(self.paren, map(self._print, expr.args))) + # NB: The C implementation is injected into codegen at + # torch/_inductor/codegen/wrapper.py def _print_align(self, expr): assert len(expr.args) == 1 return f"align({self._print(expr.args[0])})" + # This must be implemented because sympy will collect x * x into Pow(x, 2), without + # any explicit intervention. We print it just like x * x, notably, we + # never generate sympy.Pow with floats. + # + # NB: this pow by natural, you should never have used builtin sympy.pow + # for FloatPow, and a symbolic exponent should be PowByNatural. These + # means exp is guaranteed to be integer. + def _print_Pow(self, expr): + base, exp = expr.args + base = self._print(base) + assert exp == int(exp), exp + exp = int(exp) + assert exp >= 0 + if exp > 0: + return "*".join([self.paren(base)] * exp) + else: # exp == 0 + return "1" + + # Explicit NotImplemented functions are to prevent default sympy printing + # behavior, which will just barf out ToFloat(...) to your IR. The error + # message is better here because it tells you which printer class it needs + # to go in. + + def _print_ToFloat(self, expr): + raise NotImplementedError(f"_print_ToFloat not implemented for {type(self)}") + + def _print_Infinity(self, expr): + raise NotImplementedError(f"_print_Infinity not implemented for {type(self)}") + + def _print_NegativeInfinity(self, expr): + raise NotImplementedError( + f"_print_NegativeInfinity not implemented for {type(self)}" + ) + + def _print_FloorDiv(self, expr): + raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}") + + def _print_PythonMod(self, expr): + raise NotImplementedError(f"_print_PythonMod not implemented for {type(self)}") + + def _print_IntTrueDiv(self, expr): + raise NotImplementedError(f"_print_IntTrueDiv not implemented for {type(self)}") + + def _print_PowByNatural(self, expr): + raise NotImplementedError( + f"_print_PowByNatural not implemented for {type(self)}" + ) + + def _print_FloatPow(self, expr): + raise NotImplementedError(f"_print_FloatPow not implemented for {type(self)}") + + def _print_TruncToInt(self, expr): + raise NotImplementedError(f"_print_TruncToInt not implemented for {type(self)}") + + def _print_RoundToInt(self, expr): + raise NotImplementedError(f"_print_RoundToInt not implemented for {type(self)}") + + def _print_RoundDecimal(self, expr): + raise NotImplementedError( + f"_print_RoundDecimal not implemented for {type(self)}" + ) + + # NB: Some float operations are INTENTIONALLY not implemented for + # printers. You can implement them as a quick unblock, but it is better + # to ask yourself why we haven't done this computation in the Tensor + # universe instead + + def _print_TruncToFloat(self, expr): + raise NotImplementedError( + f"_print_TruncToFloat not implemented for {type(self)}" + ) + def doprint(self, expr, *, simplify: bool = True): # TODO: why are people passing strings to the printer here :think: if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"): @@ -411,6 +484,10 @@ def doprint(self, expr, *, simplify: bool = True): class PythonPrinter(ExprPrinter): + def _print_ToFloat(self, expr): + assert len(expr.args) == 1 + return f"float({self._print(expr.args[0])})" + def _print_ModularIndexing(self, expr): x, div, mod = expr.args x = self.paren(self.doprint(x)) @@ -420,46 +497,51 @@ def _print_ModularIndexing(self, expr): x = f"({x} // {div})" return f"{x} % {mod}" + def _print_Infinity(self, expr): + return "math.inf" + + def _print_NegativeInfinity(self, expr): + return "-math.inf" + + # WARNING: this is dangerous for Triton, which has C-style modulus + def _print_PythonMod(self, expr): + return " % ".join(map(self.paren, map(self._print, expr.args))) + + # WARNING: this is dangerous for Triton, which has C-style modulus def _print_FloorDiv(self, expr): x, div = expr.args x = self.paren(self.doprint(x)) div = self.paren(self.doprint(div)) return f"({x} // {div})" + # WARNING: this is dangerous for Triton, when lhs, rhs > 2**53, Python + # does a special algorithm + def _print_IntTrueDiv(self, expr): + lhs, rhs = expr.args + return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}" + def _helper_sqrt(self, expr): return f"math.sqrt({self._print(expr)})" def _print_OpaqueUnaryFn_sqrt(self, expr): return self._helper_sqrt(expr.args[0]) - def _print_Pow(self, expr): - # Pow() confuses triton + def _print_FloatPow(self, expr): base, exp = expr.args - # NB: Remember this is sizevar computation! You don't typically - # expect to have to do floating point computation including exponents - # in sizevar compute. Instead of adding support for floating - # point pow, you should make upstream retranslate the Sympy expression - # into Tensor expressions earlier and do that instead. - if exp == 0.5: - return self._helper_sqrt(base) - elif exp == -0.5: - return "1/" + self._helper_sqrt(base) - base = self._print(base) - assert exp == int(exp), exp - exp = int(exp) - if exp > 0: - return "*".join([self.paren(base)] * exp) - elif exp < 0: - return "1/" + self.paren("*".join([self.paren(base)] * abs(exp))) - else: # exp == 0 - return "1" + return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}" + + # TODO: Not sure this works with Triton, even when base/exp are integral + def _print_PowByNatural(self, expr): + base, exp = expr.args + return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}" def _print_floor(self, expr): assert len(expr.args) == 1 return f"math.floor({self._print(expr.args[0])})" - def _print_Trunc(self, expr): + def _print_TruncToInt(self, expr): assert len(expr.args) == 1 + # This also could have been int(), they'll do the same thing for float return f"math.trunc({self._print(expr.args[0])})" def _print_ceiling(self, expr): @@ -470,6 +552,9 @@ def _print_Abs(self, expr): assert len(expr.args) == 1 return f"abs({self._print(expr.args[0])})" + # NB: It's expected that we've made explicit any promotion in the sympy + # expression, so it doesn't matter that Python max/min doesn't perform + # promotion def _print_Max(self, expr): assert len(expr.args) >= 2 return f"max({', '.join(map(self._print, expr.args))})" @@ -514,7 +599,7 @@ def _print_OpaqueUnaryFn_atan(self, expr): assert len(expr.args) == 1 return f"math.atan({self._print(expr.args[0])})" - def _print_Round(self, expr): + def _print_RoundToInt(self, expr): assert len(expr.args) == 1 return f"round({self._print(expr.args[0])})" @@ -653,6 +738,29 @@ def remainder(a, b): ) return ops.where(cond, ops.add(r, b), r) + @staticmethod + def trunc_to_int(a, dtype): + return ops.to_dtype(ops.trunc(a), dtype) + + @staticmethod + def floor_to_int(a, dtype): + return ops.to_dtype(ops.floor(a), dtype) + + @staticmethod + def ceil_to_int(a, dtype): + return ops.to_dtype(ops.ceil(a), dtype) + + @staticmethod + def round_to_int(a, dtype): + return ops.to_dtype(ops.round(a), dtype) + + @staticmethod + def int_truediv(a, b): + # TODO: this is wrong + # TODO: an easy bandaid is to generate runtime asserts that it's + # <= 2**53, which is when this equation is correct + return ops.truediv(a, b) + @staticmethod def load_seed(name, offset): return ops.load(name, sympy.Integer(offset)) diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index eabb5bbef470..311781102c3f 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -275,11 +275,11 @@ def visit_modular_indexing(divisor, modulus): original_index = index - div = sympy.Wild("divisor") + div = sympy.Wild("divisor", integer=True) if index.has(FloorDiv): index = index.replace(FloorDiv(var, div), visit_indexing_div) - mod = sympy.Wild("modulus") + mod = sympy.Wild("modulus", integer=True) if index.has(ModularIndexing): index = index.replace(ModularIndexing(var, div, mod), visit_modular_indexing) diff --git a/torch/_inductor/codegen/cpp_utils.py b/torch/_inductor/codegen/cpp_utils.py index 4ab33a5e26dc..79884364420a 100644 --- a/torch/_inductor/codegen/cpp_utils.py +++ b/torch/_inductor/codegen/cpp_utils.py @@ -100,10 +100,48 @@ def _print_floor(self, expr): r = f"std::floor({self._print(expr.args[0])})" return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r - def _print_Trunc(self, expr): + def _print_TruncToInt(self, expr): assert len(expr.args) == 1 r = f"std::trunc({self._print(expr.args[0])})" - return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + return f"static_cast<{INDEX_TYPE}>({r})" + + def _print_TruncToFloat(self, expr): + assert len(expr.args) == 1 + return f"std::trunc({self._print(expr.args[0])})" + + def _print_ToFloat(self, expr): + assert len(expr.args) == 1 + return f"static_cast({self._print(expr.args[0])})" + + # TODO: This is wrong if one of the inputs is negative. This is hard to + # tickle though, as the inputs are typically positive (and if we can prove + # they are positive, we will have used Mod instead, for which this codegen + # is right). + def _print_PythonMod(self, expr): + return " % ".join(map(self.paren, map(self._print, expr.args))) + + def _print_CMod(self, expr): + return " % ".join(map(self.paren, map(self._print, expr.args))) + + def _print_IntTrueDiv(self, expr): + lhs, rhs = expr.args + # TODO: This is only accurate up to 2**53 + return f"static_cast({self._print(lhs)}) / static_cast({self._print(rhs)})" + + # TODO: PowByNatural: we need to implement our own int-int pow. Do NOT + # use std::pow, that operates on floats + def _print_PowByNatural(self, expr): + raise NotImplementedError( + f"_print_PowByNatural not implemented for {type(self)}" + ) + + def _print_FloatTrueDiv(self, expr): + lhs, rhs = expr.args + return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}" + + def _print_FloatPow(self, expr): + base, exp = expr.args + return f"std::pow({self._print(base)}, {self._print(exp)})" def _print_Pow(self, expr): # Uses float constants to perform FP div @@ -200,8 +238,9 @@ def _print_OpaqueUnaryFn_atan(self, expr): def _print_OpaqueUnaryFn_sqrt(self, expr): return f"std::sqrt({self._print(expr.args[0])})" - def _print_Round(self, expr): + def _print_RoundToInt(self, expr): assert len(expr.args) == 1 + # TODO: dispatch to llrint depending on index type return f"std::lrint({self._print(expr.args[0])})" def _print_RoundDecimal(self, expr): diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 4b0ea92f3bf4..066b6545a0a2 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -272,17 +272,52 @@ def triton_reshape(value: str, old_shape: List[str], new_shape: List[str]): return f"{value}[{', '.join(expand)}]" +# NB: Inheriting from PythonPrinter is somewhat dangerous, because there are a +# number of operators which Triton "implements", but in a way that is +# inconsistent with Python semantics (and consistent with C semantics). We +# must override all of these, or it is potential silent correctness problem class TritonPrinter(PythonPrinter): - def _print_floor(self, expr): + def _print_TruncToInt(self, expr): assert len(expr.args) == 1 return ( - f"libdevice.floor({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + f"libdevice.trunc({self._print(expr.args[0])}).to({V.kernel.index_dtype})" ) - def _print_Trunc(self, expr): + def _print_ToFloat(self, expr): + assert len(expr.args) == 1 + return f"{self.paren(self._print(expr.args[0]))}.to(tl.float64)" + + # TODO: This is wrong if one of the inputs is negative. This is hard to + # tickle though, as the inputs are typically positive (and if we can prove + # they are positive, we will have used Mod instead, for which this codegen + # is right). If you are trying to hit this, maybe try something like + # torch.arange(n, device="cuda") - 1 and then do a modulus on it + def _print_PythonMod(self, expr): + return " % ".join(map(self.paren, map(self._print, expr.args))) + + # TODO: This is wrong, see + # https://github.com/triton-lang/triton/issues/955 + # But for Sympy expressions, things will /mostly/ work out because we + # don't usually deal with negative numbers in the division + def _print_FloorDiv(self, expr): + assert expr.is_integer + x, div = expr.args + x = self.paren(self.doprint(x)) + div = self.paren(self.doprint(div)) + return f"({x} // {div})" + + # TODO: This is wrong, when lhs, rhs > 2**53, Python does a higher + # precision algorithm, which we would need to replicate here + def _print_IntTrueDiv(self, expr): + lhs, rhs = expr.args + return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}" + + # NB: sympy.floor/ceiling produce integers, so we have to do the + # conversion to index dtype + def _print_floor(self, expr): assert len(expr.args) == 1 return ( - f"libdevice.trunc({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + f"libdevice.floor({self._print(expr.args[0])}).to({V.kernel.index_dtype})" ) def _print_ceiling(self, expr): @@ -359,20 +394,9 @@ def _print_OpaqueUnaryFn_atan(self, expr): assert len(expr.args) == 1 return f"libdevice.atan(({self._print(expr.args[0])}).to(tl.float32))" - def _print_FloorDiv(self, expr): - if expr.is_integer: - return super()._print_FloorDiv(expr) - - x, div = expr.args - x = self.paren(self.doprint(x)) - div = self.paren(self.doprint(div)) - return f"libdevice.floor({x} / {div}).to({V.kernel.index_dtype})" - - def _print_Round(self, expr): + def _print_RoundToInt(self, expr): assert len(expr.args) == 1 - return ( - f"libdevice.llrint({self._print(expr.args[0])}).to({V.kernel.index_dtype})" - ) + return f"libdevice.llrint({self._print(expr.args[0])})" def _print_RoundDecimal(self, expr): assert len(expr.args) == 2 diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 411ac0b45ebb..a22e31baf752 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -1211,8 +1211,11 @@ def debug(msg): elif is_magic_method(n.target): # TODO: this is sus, it probably should be handled in the # lowerings themselves similarly to sym_size/sym-stride + # https://github.com/pytorch/pytorch/issues/127789 debug("is_magic_method") - if isinstance(n.meta["val"], torch.SymInt): + if isinstance( + n.meta["val"], (torch.SymInt, torch.SymFloat, torch.SymBool) + ): result = n.meta["val"].node.expr else: result = super().run_node(n) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index c46cad5e41e2..e9adfcd19a2d 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -44,7 +44,6 @@ is_boolean_dtype, is_float_dtype, make_channels_last_strides_for, - make_contiguous_strides_for, StrideType, ) from torch._subclasses.fake_tensor import get_schema_info @@ -236,7 +235,7 @@ def ir_node_to_tensor(x, guard_shape=True): if is_storage_and_layout(x): stride = [shape_fn(s) for s in x.get_layout().stride] # type: ignore[misc] else: - stride = make_contiguous_strides_for(size) # type: ignore[arg-type] + stride = FlexibleLayout.contiguous_strides(size) # type: ignore[arg-type] dtype = x.get_dtype() device = x.get_device() size = convert_shape_to_symint(size) @@ -2766,6 +2765,7 @@ class FlexibleLayout(Layout): allow_indexing = False + # WARNING! This doesn't handle zero size tensors correctly @staticmethod def contiguous_strides(sizes): if len(sizes) == 0: @@ -5915,7 +5915,7 @@ def _original_deconv_weight_size( # To align the behavior of the Conv kernel, we set the output_stride in such case to be contiguous instead of channels last. dynamic_shapes = not all(isinstance(i, int) for i in (output_size)) if dynamic_shapes and is_contiguous_storage_and_layout(x): - output_stride = make_contiguous_strides_for(output_size) + output_stride = FlexibleLayout.contiguous_strides(output_size) else: output_stride = make_channels_last_strides_for(output_size) @@ -5967,7 +5967,7 @@ def _prepare_linear_fusion_create( assert x.get_device().type == "cpu" and weight.get_device().type == "cpu" inputs = [x, weight] - output_stride = make_contiguous_strides_for(output_size) + output_stride = FlexibleLayout.contiguous_strides(output_size) kernel_layout = FixedLayout( x.get_device(), x.get_dtype(), @@ -6283,7 +6283,7 @@ def create(cls, x, packed_w, orig_w, B, batch_size): *m, _ = x.get_size() oc, _ = orig_w.get_size() output_size = list(m) + [oc] - output_stride = make_contiguous_strides_for(output_size) + output_stride = FlexibleLayout.contiguous_strides(output_size) inputs = [x, packed_w, orig_w] constant_args = [batch_size] if B is not None: @@ -6601,13 +6601,13 @@ def create( def get_strides_of_lstm_output(output_shape, batch_first): assert len(output_shape) == 3, "Expect output_shape to be 3D" - return make_contiguous_strides_for(output_shape) + return FlexibleLayout.contiguous_strides(output_shape) output_sizes = [output_shape, hy_shape, cy_shape] output_strides = [ get_strides_of_lstm_output(output_shape, batch_first), - make_contiguous_strides_for(hy_shape), - make_contiguous_strides_for(cy_shape), + FlexibleLayout.contiguous_strides(hy_shape), + FlexibleLayout.contiguous_strides(cy_shape), ] output_ir = [ MultiOutput( diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index 42fabf65591d..f3492949a84d 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -5,7 +5,6 @@ from typing import Any, List, Tuple import torch -from torch._prims_common import make_contiguous_strides_for from .. import config from ..ir import ( ComputedBuffer, @@ -389,7 +388,7 @@ def flex_attention(*args, **kwargs): query.get_device(), query.get_dtype(), query.get_size(), - make_contiguous_strides_for(query.get_size()), + FlexibleLayout.contiguous_strides(query.get_size()), ) # see NOTE:[TritonTemplates with multiple outputs] logsumexp_shape = query.get_size()[:-1] # [B, H, M] @@ -745,7 +744,7 @@ def flex_attention_backward(*args, **kwargs): key.get_device(), key.get_dtype(), key.get_size(), - make_contiguous_strides_for(key.get_size()), + FlexibleLayout.contiguous_strides(key.get_size()), ) # Create delta which will is needed for the bwd's kernel diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 20b0082eb1d9..300cf71c2934 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -34,7 +34,7 @@ Number, ) from torch.fx.experimental.sym_node import magic_methods, method_to_operator -from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing +from torch.utils._sympy.functions import CeilDiv, FloorDiv, IntTrueDiv, ModularIndexing from .._dynamo.utils import import_submodule from . import config, inductor_prims, ir, test_operators # NOQA: F401 @@ -4262,7 +4262,7 @@ def _fractional_pooling_offsets(samples, in_sz, out_sz, kernel_sz, dim): out_sz = out_sz[dim] in_sz = in_sz[dim] kernel_sz = kernel_sz[dim] - alpha = (in_sz - kernel_sz) / (out_sz - 1) + alpha = IntTrueDiv(in_sz - kernel_sz, out_sz - 1) samples_loader = samples.make_loader() def load(prefix, i): @@ -4372,7 +4372,7 @@ def upsample_nearest2d_backward( w_kernel_max = ceildiv(inp_w, out_w) def start_index(index, out_dim, inp_dim): - return CeilDiv(index * inp_dim, out_dim) + return CeilDiv(index * inp_dim, sympy.sympify(out_dim)) def end_index(index, out_dim, inp_dim): return start_index((index + 1), out_dim, inp_dim) diff --git a/torch/_inductor/ops_handler.py b/torch/_inductor/ops_handler.py index 5630061b4426..f88cd948ca4d 100644 --- a/torch/_inductor/ops_handler.py +++ b/torch/_inductor/ops_handler.py @@ -138,6 +138,38 @@ def to_dtype( """ ... + def trunc_to_int(self, x: T, dtype: torch.dtype) -> T: + """ + Convert x to dtype with truncation semantics (similar to how the int + constructor works in Python). In Inductor codegen, this just decays + to trunc and then to_dtype, but this composite operation helps + roundtrips for Sympy evaluation. + + dtype is taken as an explicit parameter because the desired output + dtype is typically the index dtype, which may vary between int32 and + int64 depending on if we've shown that all the indexing operations can + be done in int32. + """ + ... + + def ceil_to_int(self, x: T, dtype: torch.dtype) -> T: + """ + Convert x to dtype with ceiling semantics. See also trunc_to_int. + """ + ... + + def floor_to_int(self, x: T, dtype: torch.dtype) -> T: + """ + Convert x to dtype with ceiling semantics. See also trunc_to_int. + """ + ... + + def round_to_int(self, x: T, dtype: torch.dtype) -> T: + """ + Convert x to dtype with round-to-even semantics. See also trunc_to_int. + """ + ... + def to_dtype_bitcast(self, x: T, dtype: torch.dtype, src_dtype: torch.dtype) -> T: """ Reinterpret cast x to dtype (reinterpreting the bits in memory as another dtype.) @@ -398,21 +430,23 @@ def isinf(self, x0: T) -> T: def isnan(self, x0: T) -> T: ... + # NB: this returns a float, like the torch operation + # This rounds half to even to break ties def round(self, x0: T) -> T: ... + # NB: this returns a float, like the torch operation def floor(self, x0: T) -> T: ... def sign(self, x0: T) -> T: ... - def to_int(self, x0: T) -> T: - ... - + # NB: this returns a float, like the torch operation def trunc(self, x0: T) -> T: ... + # NB: this returns a float, like the torch operation def ceil(self, x0: T) -> T: ... @@ -449,6 +483,7 @@ def sub(self, x0: T, x1: T) -> T: def mul(self, x0: T, x1: T) -> T: ... + # NB: this returns a float, like the torch operation def pow(self, x0: T, x1: T) -> T: ... @@ -617,14 +652,21 @@ def truncdiv(self, x0: T, x1: T) -> T: def floordiv(self, x0: T, x1: T) -> T: """Python-style floor division between integers only. Computes the - true division of two numbers and floors the result. + true division of two numbers and floors the result. If you want + floor division for floats, do regular truediv and floor the result. """ ... def truediv(self, x0: T, x1: T) -> T: - """True division between floats. Integer inputs are NOT valid: to do - Python style (int, int) -> float division, promote the inputs to float - first.""" + """True division between floats. Integer inputs are NOT valid. To + do Python-style (int, int) -> float division, use int_truediv""" + ... + + def int_truediv(self, x0: T, x1: T) -> T: + """True division between integers. This is NOT the same as promoting + to float and doing integer division, there is a bespoke algorithm for + doing the division in higher precision than the above. + """ ... def div(self, x0: T, x1: T) -> T: @@ -640,6 +682,10 @@ def remainder(self, x0: T, x1: T) -> T: """Python-style modulus, take sign from RHS (x1).""" ... + def round_decimal(self, x0: T, x1: T) -> T: + """Python-style round with decimal argument""" + ... + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # In CUDA, optimized implementations of other mathematical operations are # offered separately via libdevice for double precision computation (in diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index bc89441e3bd8..85d0d0f1954c 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -385,7 +385,7 @@ def store_output( assert isinstance(mask, (str, type(None))) assert self.template_mask is None indices = list(map(TritonPrinter.paren, indices)) - index_symbols = [sympy.Symbol(x) for x in indices] + index_symbols = [sympy.Symbol(x, integer=True) for x in indices] lengths = [ V.graph.sizevars.simplify(s) for s in self.output_node.get_size() ] @@ -409,7 +409,7 @@ def store_output( output_index = self.output_node.get_layout().make_indexer()(index_symbols) output_index = self.rename_indexing(output_index) if output_index == contiguous_index: - output_index = sympy.Symbol("xindex") + output_index = sympy.Symbol("xindex", integer=True) epilogue_args = [val] for input_node in itertools.chain( diff --git a/torch/_inductor/sizevars.py b/torch/_inductor/sizevars.py index bc8803a5e715..fba9a66f9237 100644 --- a/torch/_inductor/sizevars.py +++ b/torch/_inductor/sizevars.py @@ -161,9 +161,9 @@ def visit_modular_indexing(base, divisor, modulus): if expr.has(ModularIndexing): expr = expr.replace( ModularIndexing( - sympy.Wild("base"), - sympy.Wild("divisor"), - sympy.Wild("modulus"), + sympy.Wild("base", integer=True), + sympy.Wild("divisor", integer=True), + sympy.Wild("modulus", integer=True), ), visit_modular_indexing, ) @@ -171,8 +171,8 @@ def visit_modular_indexing(base, divisor, modulus): if expr.has(FloorDiv): expr = expr.replace( FloorDiv( - sympy.Wild("base"), - sympy.Wild("divisor"), + sympy.Wild("base", integer=True), + sympy.Wild("divisor", integer=True), ), visit_indexing_div, ) @@ -604,11 +604,11 @@ def _join_dimensions_cached(expr: Expr) -> Expr: """ assert isinstance(expr, sympy.Add) - scale = sympy.Wild("scale", exclude=[0]) - base = sympy.Wild("base") - divisor = sympy.Wild("divisor") - mod1 = sympy.Wild("modulus") - mod2 = sympy.Wild("modulus2") + scale = sympy.Wild("scale", exclude=[0], integer=True) + base = sympy.Wild("base", integer=True) + divisor = sympy.Wild("divisor", integer=True) + mod1 = sympy.Wild("modulus", integer=True) + mod2 = sympy.Wild("modulus2", integer=True) for term1 in expr.args: m1 = term1.match(scale * ModularIndexing(base, divisor, mod1)) if m1: diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 0915a8330c34..a635c2f509c1 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -192,7 +192,7 @@ def ceildiv( numer: Union[int, sympy.Expr], denom: Union[int, sympy.Expr] ) -> Union[int, sympy.Expr]: if isinstance(numer, sympy.Expr) or isinstance(denom, sympy.Expr): - return CeilDiv(numer, denom) + return CeilDiv(sympy.sympify(numer), sympy.sympify(denom)) # TODO: There is a bug in a call to this function, to repro: # python benchmarks/dynamo/huggingface.py --inductor -d cuda --accuracy # --amp --only YituTechConvBert --dynamic-shapes diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 47d4abcf77b9..9343490de3e8 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -1727,7 +1727,7 @@ def go(t, real_t): for run_impl_check, op_impl in op_implementations_checks: if run_impl_check(func): op_impl_out = op_impl(self, func, *args, **kwargs) - if op_impl_out != NotImplemented: + if op_impl_out is not NotImplemented: return maybe_propagate_real_tensors(op_impl_out) def maybe_run_unsafe_fallback(error=None): diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index a7ce337f9ac8..2a3cb62c56d7 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -1200,8 +1200,13 @@ void initJITBindings(PyObject* module) { SYMNODE_BINARY(sub) SYMNODE_BINARY(mul) SYMNODE_BINARY(truediv) + SYMNODE_BINARY(int_truediv) + SYMNODE_BINARY(float_truediv) SYMNODE_BINARY(pow) + SYMNODE_BINARY(float_pow) + SYMNODE_BINARY(pow_by_natural) SYMNODE_BINARY(floordiv) + SYMNODE_BINARY(int_floordiv) SYMNODE_BINARY(mod) SYMNODE_BINARY(eq) SYMNODE_BINARY(ne) diff --git a/torch/csrc/utils/python_symnode.h b/torch/csrc/utils/python_symnode.h index f8c710cf6579..15738b1a67e1 100644 --- a/torch/csrc/utils/python_symnode.h +++ b/torch/csrc/utils/python_symnode.h @@ -198,14 +198,34 @@ class PythonSymNodeImpl : public c10::SymNodeImpl { return dispatch_common_(__func__, other); } + c10::SymNode float_truediv(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + + c10::SymNode int_truediv(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + c10::SymNode pow(const c10::SymNode& other) override { return dispatch_common_(__func__, other); } + c10::SymNode float_pow(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + + c10::SymNode pow_by_natural(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + c10::SymNode floordiv(const c10::SymNode& other) override { return dispatch_common_(__func__, other); } + c10::SymNode int_floordiv(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + c10::SymNode mod(const c10::SymNode& other) override { return dispatch_common_(__func__, other); } diff --git a/torch/export/dynamic_shapes.py b/torch/export/dynamic_shapes.py index a4ed16e975b8..ac2bdd60a550 100644 --- a/torch/export/dynamic_shapes.py +++ b/torch/export/dynamic_shapes.py @@ -1,7 +1,6 @@ import builtins import dataclasses import inspect -import math import sys import weakref from collections import defaultdict @@ -254,11 +253,14 @@ class _Constraint(_ConstraintTarget, metaclass=_ConstraintFactory): shared: Optional[_ConstraintTarget] = None debug_name: Optional[str] = None - def _clone_with_range(self, lower=0, upper=math.inf): + def _clone_with_range(self, lower=0, upper=None): # Import sympy locally from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint from torch.utils._sympy.value_ranges import ValueRanges + if upper is None: + upper = sys.maxsize - 1 + constraint_range = StrictMinMaxConstraint( vr=self.constraint_range.vr & ValueRanges(lower=lower, upper=upper), warn_only=False, @@ -486,7 +488,6 @@ def dynamic_dim(t: torch.Tensor, index: int, debug_name: Optional[str] = None): ) # Import sympy locally - import sympy from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint from torch.utils._sympy.value_ranges import ValueRanges @@ -496,7 +497,7 @@ def dynamic_dim(t: torch.Tensor, index: int, debug_name: Optional[str] = None): id(t), index, StrictMinMaxConstraint( - vr=ValueRanges(lower=0, upper=sympy.oo), warn_only=False + vr=ValueRanges(lower=0, upper=sys.maxsize - 1), warn_only=False ), debug_name=debug_name, ) diff --git a/torch/fx/experimental/recording.py b/torch/fx/experimental/recording.py index 4bf9ebab17b3..28df3fddab0e 100644 --- a/torch/fx/experimental/recording.py +++ b/torch/fx/experimental/recording.py @@ -277,7 +277,13 @@ def wrapper(*args, **kwargs): raise except Exception: - log.error("failed while running %s(*%s, **%s)", name, args[1:], kwargs) + log.error( # noqa: G201 + "failed while running %s(*%s, **%s)", + name, + args[1:], + kwargs, + exc_info=log.isEnabledFor(logging.INFO), + ) raise return wrapper diff --git a/torch/fx/experimental/sym_node.py b/torch/fx/experimental/sym_node.py index 98cba67a73a1..4a88d24ce3d5 100644 --- a/torch/fx/experimental/sym_node.py +++ b/torch/fx/experimental/sym_node.py @@ -267,8 +267,11 @@ def mul(self, other) -> "SymNode": def mod(self, other) -> "SymNode": return self._mod(other) # type: ignore[attr-defined] - def pow(self, other) -> "SymNode": - return self._pow(other) # type: ignore[attr-defined] + def float_pow(self, other) -> "SymNode": + return self._float_pow(other) # type: ignore[attr-defined] + + def pow_by_natural(self, other) -> "SymNode": + return self._pow_by_natural(other) # type: ignore[attr-defined] def and_(self, other) -> "SymNode": return self._and_(other) # type: ignore[attr-defined] @@ -276,11 +279,14 @@ def and_(self, other) -> "SymNode": def or_(self, other) -> "SymNode": return self._or_(other) # type: ignore[attr-defined] - def truediv(self, other) -> "SymNode": - return self._truediv(other) # type: ignore[attr-defined] + def float_truediv(self, other) -> "SymNode": + return self._float_truediv(other) # type: ignore[attr-defined] - def floordiv(self, other) -> "SymNode": - return self._floordiv(other) # type: ignore[attr-defined] + def int_truediv(self, other) -> "SymNode": + return self._int_truediv(other) # type: ignore[attr-defined] + + def int_floordiv(self, other) -> "SymNode": + return self._int_floordiv(other) # type: ignore[attr-defined] def lshift(self, other) -> "SymNode": return self._lshift(other) # type: ignore[attr-defined] @@ -361,6 +367,17 @@ def sym_or(self, other): def sym_and(self, other): return self.and_(other) + # There is no int_truediv available from C++ + def truediv(self, other): + return self.float_truediv(other) + + def floordiv(self, other) -> "SymNode": + return self.int_floordiv(other) + + # We didn't bind integer pow in C++ + def pow(self, other): + return self.float_pow(other) + def is_non_overlapping_and_dense(self, sizes, strides): return self.is_non_overlapping_and_dense_indicator(sizes, strides).eq(to_node(self, 1)) # type: ignore[attr-defined] @@ -477,7 +494,7 @@ def is_constant(self): "eq": operator.eq, "floor": math.floor, "trunc": math.trunc, - "floordiv": operator.floordiv, + "int_floordiv": operator.floordiv, "ge": operator.ge, "gt": operator.gt, "is_integer": lambda x: x.is_integer(), @@ -489,7 +506,8 @@ def is_constant(self): "ne": operator.ne, "neg": operator.neg, "or": operator.or_, - "pow": operator.pow, + "float_pow": operator.pow, + "pow_by_natural": operator.pow, "round": builtins.round, "rshift": operator.rshift, "sub": operator.sub, @@ -498,12 +516,14 @@ def is_constant(self): "sym_max": sym_max, "sym_min": sym_min, "sym_not": sym_not, - "truediv": operator.truediv, + "float_truediv": operator.truediv, + "int_truediv": operator.truediv, } unary_magic_methods = { "abs", "sym_float", + "sym_int", "ceil", "floor", "neg", @@ -559,20 +579,20 @@ def fn(self): bool_magic_methods = only_bool_magic_methods | also_bool_magic_methods # Methods that are only for float -only_float_magic_methods = {"is_integer"} +only_float_magic_methods = {"is_integer", "round", "sym_int"} magic_methods_on_operator_with_trailing_underscore = {"and", "or"} -always_float_magic_methods = {"truediv", "sym_float", "pow"} +always_float_magic_methods = {"int_truediv", "float_truediv", "sym_float", "float_pow"} for name in math_op_names: sym_name = f"sym_{name}" always_float_magic_methods.add(sym_name) -always_int_magic_methods = {"ceil", "floor", "trunc"} +always_int_magic_methods = {"ceil", "floor", "trunc", "pow_by_natural"} always_bool_magic_methods = { "eq", "ne", @@ -590,10 +610,16 @@ def fn(self): # Methods that have a `__foo__` as well as `__rfoo__` -def _sympy_truediv(a, b): - from torch.utils._sympy.functions import TrueDiv +def _sympy_float_truediv(a, b): + from torch.utils._sympy.functions import FloatTrueDiv + + return FloatTrueDiv(a, b) + - return TrueDiv(a, b) +def _sympy_int_truediv(a, b): + from torch.utils._sympy.functions import IntTrueDiv + + return IntTrueDiv(a, b) def _sympy_floordiv(a, b): @@ -603,15 +629,24 @@ def _sympy_floordiv(a, b): def _sympy_mod(a, b): - from torch.utils._sympy.functions import Mod + from torch.utils._sympy.functions import Mod, PythonMod + + if a.is_nonnegative and b.is_nonnegative: + return Mod(a, b) + else: + return PythonMod(a, b) - return Mod(a, b) +def _sympy_pow_by_natural(a, b): + from torch.utils._sympy.functions import PowByNatural -def _sympy_pow(a, b): - from torch.utils._sympy.functions import Pow + return PowByNatural(a, b) - return Pow(a, b) + +def _sympy_float_pow(a, b): + from torch.utils._sympy.functions import FloatPow + + return FloatPow(a, b) def _sympy_and(a, b): @@ -643,11 +678,13 @@ def _sympy_rshift(a, b): "sub": operator.sub, "mul": operator.mul, "mod": _sympy_mod, - "pow": _sympy_pow, + "pow_by_natural": _sympy_pow_by_natural, + "float_pow": _sympy_float_pow, "and": _sympy_and, "or": _sympy_or, - "truediv": _sympy_truediv, - "floordiv": _sympy_floordiv, + "float_truediv": _sympy_float_truediv, + "int_truediv": _sympy_int_truediv, + "int_floordiv": _sympy_floordiv, "lshift": _sympy_lshift, "rshift": _sympy_rshift, } @@ -671,18 +708,22 @@ def _floor_ceil_helper(a, fn): return fn(a) +# NB: this is Python semantics so it returns an int def _sympy_floor(a): import sympy return _floor_ceil_helper(a, sympy.floor) +# NB: this is Python trunc semantics which returns an int. Do NOT use this to +# represent torch.trunc (which is float to float) def _sympy_trunc(a): - from torch.utils._sympy.functions import Trunc + from torch.utils._sympy.functions import TruncToInt - return Trunc(a) + return TruncToInt(a) +# NB: this is Python semantics so it returns an int def _sympy_ceil(a): import sympy @@ -771,26 +812,28 @@ def _sympy_abs(a): def _sympy_round(number, ndigits=None): - from torch.utils._sympy.functions import Round, RoundDecimal + from torch.utils._sympy.functions import RoundDecimal, RoundToInt if ndigits is None: - return Round(number) + return RoundToInt(number) else: return RoundDecimal(number, ndigits) def _sympy_sym_float(a): - # Cannot use sympy.Float(a) here, coz it expects python literals - # Multiply by 1.0 to cast to float. This is needed when the input - # is a SymInt which has the assumption that it is integer and - # SymPy will otherwise assume that return value cannot be a float. - return a * 1.0 + from torch.utils._sympy.functions import ToFloat + + # NB: Cannot use a * 1.0 here, because 0 * 1.0 is 0 which incorrectly + # reports that it is an integer + return ToFloat(a) def _sympy_is_integer(a): import sympy - return sympy.Eq(sympy.floor(a), a) + from torch.utils._sympy.functions import ToFloat + + return sympy.Eq(ToFloat(sympy.floor(a)), a) magic_methods = { @@ -989,9 +1032,26 @@ def binary_magic_impl(self, other): self, handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {}) ) assert isinstance(other, SymNode) - # TODO: consider constant prop here try: - out = func(self.expr, other.expr) + if method == "mod": + from torch.utils._sympy.functions import Mod, PythonMod + + # Special handling for mod that requires access to the value + # ranges + shape_env = self.shape_env + if ( + self.expr.is_nonnegative + or shape_env.bound_sympy(self.expr).lower >= 0 + ) and ( + other.expr.is_nonnegative + or shape_env.bound_sympy(other.expr).lower >= 0 + ): + out = Mod(self.expr, other.expr) + else: + out = PythonMod(self.expr, other.expr) + else: + # TODO: consider constant prop here + out = func(self.expr, other.expr) except Exception: log.warning("failed to eval %s(%s, %s)", method, self.expr, other.expr) raise @@ -1122,9 +1182,13 @@ def round_impl(self, ndigits=None): except Exception: log.warning("failed to eval %s(%s, ndigits=%s)", method, expr, ndigits) raise + out = safe_expand(out) - pytype = int if ndigits is None else self.pytype + if ndigits is None: + pytype = int + else: + pytype = self.pytype out_hint = None if self.hint is not None: @@ -1136,6 +1200,7 @@ def round_impl(self, ndigits=None): # hack down below works, because all round function down the line all take ndigits=None as default in their # signature. # TODO: Remove the args construction below if a different sentinel is used by FX. + # ezyang(May 2024): LOL args = [self.fx_node] if ndigits is not None: args.append(ndigits) @@ -1259,6 +1324,32 @@ def is_constant(x): return x.node.is_constant() return False + # Promotion rules for binary operations. NB: we preserve PYTHON semantics + # - if args are same type, do nothing + # - if one arg is float, promote other arg to float + # - nb: this applies to floordiv, even though output is integral + # (it's still float) + # - pow is funny business + # - if both ints + # - trigger a guard on exponent >= 0 + # - if non-negative, output is int + # - otherwise, output is float + # - otherwise, promote other arg to float + # - nb: complex is impossible to handle correctly lol, with + # negative base and integral float need to diverge semantics and + # just always return complex. Neener neener pretend this problem + # doesn't exist + # - equality is pain: Python does the fancy thing where it unpacks the + # mantissa from the float and then compares that against the int. + # Which means it is able to tell that + # 9007199254740993 != 9007199254740992. (rather than if the LHS was + # promoted to float, in which case it would have truncated to the RHS + # and subsequently been equal). We'll model this exactly by having + # special mixed type equality operations. Unfortunately, we need to + # do this for all comparison operations (maybe I'll only implement + # compare) + # - sym_ite mumble mumble really shouldn't allow mixed but whatever + if method in bool_becomes_int_magic_methods: def promote(x): @@ -1272,6 +1363,41 @@ def promote(x): def promote(x): return x + def promote2(self, other): + # TODO: Remove eq and other relations from this list. + # CPython has fancy implementations for these to get as much precision + # as possible instead of just promoting to float64 and praying, so we + # need to handle them specially too. + # Also, note that int_truediv doesn't go through this path: both + # arguments are "int" so there isn't any promotion + if method not in [ + "add", + "sub", + "mul", + "mod", + "float_pow", + "float_truediv", + "int_floordiv", + "sym_min", + "sym_max", + # TODO: remove these + "eq", + "ne", + "gt", + "lt", + "le", + "ge", + ]: + return self, other + f_self = isinstance(self, (float, torch.SymFloat)) + f_other = isinstance(other, (float, torch.SymFloat)) + if f_self or f_other: + if not f_self: + self = torch.sym_float(self) + if not f_other: + other = torch.sym_float(other) + return self, other + # Before and after performing the operation, check if any operands are constant. # If so, extract out the constant values first. If `self` itself is a # constant, then "redispatch" by calling back into the operator. Sometimes @@ -1286,9 +1412,12 @@ def unary_magic_impl(self): return wrap_node(getattr(self.node, method_attr)()) def binary_magic_impl(self, other): + if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)): + return NotImplemented sym_node_log.debug("MAGIC %s %s %s", method, self, other) self = promote(self) other = promote(other) + self, other = promote2(self, other) if is_constant(self): return (method_to_operator(method))(get_constant(self), other) if is_constant(other): @@ -1300,8 +1429,11 @@ def binary_magic_impl(self, other): return get_constant(ret) if is_constant(ret) else ret def rbinary_magic_impl(self, other): + if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)): + return NotImplemented self = promote(self) other = promote(other) + self, other = promote2(self, other) if is_constant(self): return (method_to_operator(method))(get_constant(self), other) if is_constant(other): diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index a2abde3a861e..d7321f071865 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -61,7 +61,7 @@ from torch import SymBool, SymFloat, SymInt from torch._guards import ShapeGuard, Source, TracingContext from torch.utils._python_dispatch import is_traceable_wrapper_subclass -from torch.utils._sympy.functions import FloorDiv, Mod, IsNonOverlappingAndDenseIndicator +from torch.utils._sympy.functions import FloorDiv, Mod, PythonMod, IsNonOverlappingAndDenseIndicator, CleanDiv from torch.utils._sympy.solve import try_solve from torch.utils._sympy.value_ranges import bound_sympy, SymPyValueRangeAnalysis, ValueRanges, ValueRangeError from torch.utils._sympy.singleton_int import SingletonInt @@ -869,9 +869,9 @@ def constrain_range(a, *, min: Optional[int], max: Optional[int] = None): for N=1. """ if min is None: - min = -sympy.oo + min = -sys.maxsize - 1 if max is None: - max = sympy.oo + max = sys.maxsize - 1 if max < min: raise ValueError( @@ -979,16 +979,6 @@ def eval_guards(gm, *args, ignore_static=True): def bind_symbols(gm, *args): return gm.shape_env.bind_symbols(fx_placeholder_vals(gm), args) -def _assert_bound_is_rational(expr: sympy.Expr, bound: ValueRanges): - """ - We assert that the bounds are either Boolean, or not finite, or can be computed - in exact prevision via rational arithmetic. - The only exception to this is the rare case when the user calls `sqrt(s0)` - sqrt is turned into sympy.Pow so we just match for that (it matches more things, but still) - """ - assert bound.lower.is_rational or bound.lower.is_Boolean or not bound.lower.is_finite or expr.has(sympy.Pow), (bound, expr) - assert bound.upper.is_rational or bound.upper.is_Boolean or not bound.upper.is_finite or expr.has(sympy.Pow), (bound, expr) - class DimDynamic(Enum): """ Controls how to perform symbol allocation for a dimension. It is always @@ -1387,14 +1377,17 @@ def cast_symbool_to_symint_guardless(symbool: torch.SymBool) -> torch.SymInt: 'Min': min, 'Max': max, 'Mod': operator.mod, + 'PythonMod': operator.mod, 'FloorDiv': operator.floordiv, 'TrueDiv': operator.truediv, 'IsNonOverlappingAndDenseIndicator': eval_is_non_overlapping_and_dense, 'floor': math.floor, 'ceiling': math.ceil, 'cast_symbool_to_symint_guardless': cast_symbool_to_symint_guardless, - 'Round': builtins.round, + 'RoundToInt': builtins.round, 'RoundDecimal': builtins.round, + 'TruncToInt': math.trunc, + 'IntTrueDiv': operator.truediv, } @@ -1642,10 +1635,17 @@ def floor_div_handler(*args): congruence = (base - mod_reduced) % divisor if congruence != 0: self._congruences[s].add(congruence) + # NB: Must not be CleanDiv, it needs to be regular sympy division + # so inequality solver works. This is sort of problematic for + # is_integer tests though haha return (base - mod_reduced) / divisor if expr.has(Mod): expr = expr.replace(Mod, mod_handler) + # 7 // -3 is -3, 7 % -3 is -2, and 7 - (-2) / -3 is -3.0 so negative + # arguments should be OK. + if expr.has(PythonMod): + expr = expr.replace(PythonMod, mod_handler) if expr.has(FloorDiv): expr = expr.replace(FloorDiv, floor_div_handler) return expr @@ -3330,6 +3330,7 @@ def create_unbacked_symfloat(self): self.pending_fresh_unbacked_symbols.append(symbol) self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1) vr = self.var_to_range[symbol] = ValueRanges.unknown() + assert vr.is_float # Create a new FX placeholder and Z3 variable for 'symbol'. fx_node = self._create_fx_placeholder_and_z3var(symbol, float) @@ -3348,6 +3349,7 @@ def create_unbacked_symint(self): self.counter["create_unbacked_symbol"] += 1 self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1) vr = self.var_to_range[symbol] = self._default_unspecified_value_range() + assert vr.is_int # Create a new FX placeholder and Z3 variable for 'symbol'. fx_node = self._create_fx_placeholder_and_z3var(symbol, int) @@ -3371,6 +3373,7 @@ def create_unbacked_symbool(self): self.counter["create_unbacked_symbol"] += 1 self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1) vr = self.var_to_range[symbol] = ValueRanges(0, 1) + assert vr.is_int # Create a new FX placeholder and Z3 variable for 'symbol'. fx_node = self._create_fx_placeholder_and_z3var(symbol, bool) @@ -3516,6 +3519,7 @@ def create_symbol( self.var_to_range[sympy_expr] &= constraint_dim.vr vr = self.var_to_range[sympy_expr] + assert vr.is_int if val not in vr: raise ConstraintViolationError(f"{val} not in range [{vr.lower}, {vr.upper}]") @@ -3524,6 +3528,7 @@ def create_symbol( elif isinstance(val, float): self.var_to_range[sympy_expr] = vr = ValueRanges(-sympy.oo, sympy.oo) range_str = f"[{vr.lower}, {vr.upper}]" + assert vr.is_float else: # Skip var_range logic for SingletonInt # Only used for jagged layout nested tensors @@ -3573,6 +3578,7 @@ def create_symbol( def add_var_to_val(self, expr: sympy.Symbol, val: int): """ Adds a new symbol to the symbolic environment. """ + log.debug("add_var_to_val %s %s", expr, val, stack_info=True) assert expr not in self.var_to_val, f"{expr} already exists" self.var_to_val[expr] = sympy.Integer(val) @@ -4301,7 +4307,8 @@ def bound_sympy(self, expr: sympy.Expr, size_oblivious: bool = False) -> ValueRa # Clamp values of size-like variables for x in self.size_like & var_to_range.keys(): if var_to_range[x] is not None: - var_to_range[x] = ValueRanges(2, sympy.oo) + var_to_range[x] = ValueRanges(2, sys.maxsize - 1) + assert var_to_range[x].is_int return bound_sympy(expr, var_to_range) @_lru_cache @@ -4418,6 +4425,11 @@ def _maybe_evaluate_static( vr = self._default_unspecified_value_range() if size_oblivious and k in self.size_like: lower = max(2, vr.lower) + # This is a bit dodgy: what this means is that there was a + # size-like unbacked symbol whose upper bound < 2. This + # causes... problems. + if lower <= vr.upper: + vr = ValueRanges(lower, vr.upper) else: lower = vr.lower # Don't do anything if we don't have a nontrivial lower bound @@ -4425,10 +4437,17 @@ def _maybe_evaluate_static( # SymInt if ( lower < (-sys.maxsize - 1) // 2 or - (unbacked_only and k in self.var_to_val) + (unbacked_only and k in self.var_to_val) or + not vr.is_int ): new_range_env[k] = vr continue + # The goal is to take our symbols which have various lower bounds + # and reallocate them into new symbols which are exactly positive; + # e.g., if we have s0 in [2, inf], we want to turn it into ess0 in + # [1, inf], where s0 = ess0 + 1. This gives the most information + # to sympy for subsequent simplifications. + # # Positive means >= 1 # Positive - 1 means >= 0 # Positive + lower - 1 means >= lower @@ -4460,6 +4479,14 @@ def replace(expr, repl): self.counter["sympy_recursion_error"] += 1 return None + new_expr = safe_expand(new_expr) + if new_expr.is_number: + return new_expr + + # This is bad to do, the replacement with division leaves us with + # rationals when atom.args[0] is addition, e.g., sympy will happily + # turn (s0 + s1) // 2 into s0 / 2 + s1 / 2. Needless complication! + """ floor_div_replace = {} for atom in new_expr.atoms(FloorDiv): floor_div_replace[atom] = sympy.floor(atom.args[0] / atom.args[1]) @@ -4468,13 +4495,12 @@ def replace(expr, repl): # are still free symbols if new_expr.is_number: return new_expr + """ # Check if the range can solve it statically out = bound_sympy(new_expr, new_range_env) - if expect_rational: - _assert_bound_is_rational(new_expr, out) - if out.is_singleton(): - return out.lower + if out.is_singleton(): + return out.lower return new_expr if unbacked_only else None @@ -4526,7 +4552,7 @@ def simplify(self, expr: "sympy.Expr") -> "sympy.Expr": for fd in expr.atoms(FloorDiv): base, divisor = fd.args if self.replace(Mod(base, divisor)) in self.divisible: - div_replacements[fd] = base / divisor + div_replacements[fd] = CleanDiv(base, divisor) new_expr = expr.xreplace(div_replacements) new_expr = safe_expand(new_expr) new_pows = new_expr.atoms(sympy.Pow) @@ -4670,7 +4696,10 @@ def _set_replacement(self, a: "sympy.Symbol", tgt: "sympy.Expr", msg: str) -> No int_range = ValueRanges(-sys.maxsize - 1, sys.maxsize - 1) def issubset(x, y): - return (x & int_range).issubset(y & int_range) + if x.is_int and y.is_int: + return (x & int_range).issubset(y & int_range) + else: + return x.issubset(y) # First, refine the value range of a based on the computed value range # of tgt. This is always OK to do, even if we decide not to do the @@ -4688,7 +4717,7 @@ def issubset(x, y): b = next(iter(tgt.free_symbols)) # Try to invert the equality r = try_solve(sympy.Eq(a, tgt), b, floordiv_inequality=False) - if r is not None: + if r is not None and all(t.is_integer for t in sympy.preorder_traversal(r[1])): b_bound = self.bound_sympy(r[1]) self.var_to_range[b] = b_bound & self.var_to_range[b] tgt_bound = self.bound_sympy(tgt) @@ -4899,12 +4928,12 @@ def trivial_solve(lhs, rhs): ): # We have Mod(i0, q / c) == 0, which means we can # rewrite i0 as (q / gcd(q, c)) * i1 - d = q / sympy.gcd(q, c) + d = q / sympy.gcd(q, c) # TODO: CleanDiv? i1 = self.create_unbacked_symint().node.expr # Propagate the value ranges. It doesn't really # matter if we use truediv or floordiv, because we # have established divisibility. - self._update_var_to_range(i1, SymPyValueRangeAnalysis.truediv( + self._update_var_to_range(i1, SymPyValueRangeAnalysis.floordiv( self.var_to_range[i0], ValueRanges.wrap(d) )) # Propagate size-like-ness @@ -5341,7 +5370,6 @@ def _refine_ranges(self, expr: sympy.Expr) -> None: lower, upper = vr.lower, vr.upper rhs_vr = bound_sympy(rhs, self.var_to_range) - _assert_bound_is_rational(rhs, rhs_vr) # Let's suppose that we have a preexisting range for x [0, 100]. # Now, we issue a guard x > y, where the range for y is [50, 150]. diff --git a/torch/fx/experimental/validator.py b/torch/fx/experimental/validator.py index 6dcb59db7979..d06b38d60c80 100644 --- a/torch/fx/experimental/validator.py +++ b/torch/fx/experimental/validator.py @@ -216,10 +216,7 @@ def sqrt(self, number: z3.ArithRef) -> z3.ArithRef: def abs(self, number: z3.ArithRef) -> z3.ArithRef: return z3.Abs(number) - def round(self, number: z3.ArithRef, ndigits: Optional[z3.ArithRef] = None) -> z3.ArithRef: - if ndigits is not None: - raise ValueError("round(..., ndigits=) is currently not supported by shape validations.") - + def round_to_int(self, number: z3.ArithRef) -> z3.ArithRef: # Pythons builtin 'round' implements the 'round half to even' strategy # See https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even # z3 has an equivalent z3.fpRoundToIntegral(z3.RoundNearestTiesToEven(), ...), but this only applies to @@ -284,7 +281,7 @@ def wrapper(*args): operator.truediv: lift(ops.div), operator.mod: lift(ops.mod), operator.abs: lift(ops.abs), - builtins.round: lift(ops.round), + builtins.round: lift(ops.round_to_int), # Math module. math.ceil: lift(ops.ceil), @@ -350,6 +347,7 @@ def __init__( self._ops = _Z3Ops(self._validator) def constant(self, value: Any, dtype: torch.dtype) -> z3.ExprRef: + # TODO: Probably OK to relax this and allow lower precision if dtype is torch.int64: return z3.IntVal(int(value)) if dtype is torch.double: @@ -358,6 +356,20 @@ def constant(self, value: Any, dtype: torch.dtype) -> z3.ExprRef: return z3.BoolVal(bool(value)) raise ValueError(f"unsupported dtype (SympyToZ3): {dtype}") + def to_dtype(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: + if dtype == torch.float64: + return z3.ToReal(x) + raise NotImplementedError(f"to_dtype {dtype} NYI") + + def trunc_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: + return z3.ToInt(x) + + def round_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: + return self._ops.round_to_int(x) + + def int_truediv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: + return self._ops.div(numerator, denominator) + def truediv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: return self._ops.div(numerator, denominator) @@ -370,11 +382,17 @@ def div(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: def pow(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef: return self._ops.pow(base, exp) + def pow_by_natural(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef: + return self._ops.pow(base, exp) + def mod(self, p: z3.ArithRef, q: z3.ArithRef) -> z3.ArithRef: return self._ops.mod(p, q) - def round(self, number: z3.ArithRef, ndigits: Optional[z3.ArithRef] = None) -> z3.ArithRef: - return self._ops.round(number, ndigits) + def ceil_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: + return self._ops.ceil(x) + + def floor_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: + return self._ops.floor(x) def __getattr__(self, name: str) -> Any: REPLACEMENT = { diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index 1384261b4512..9b1599288949 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -1,43 +1,78 @@ +import functools import math +import sys import sympy from sympy import S -from sympy.core.logic import fuzzy_and, fuzzy_not, fuzzy_or __all__ = [ "FloorDiv", "ModularIndexing", "CleanDiv", "CeilDiv", - "Pow", - "TrueDiv", + "IntTrueDiv", + "FloatTrueDiv", "LShift", "RShift", "IsNonOverlappingAndDenseIndicator", - "Round", + "RoundToInt", "RoundDecimal", + "ToFloat", + "FloatPow", + "PowByNatural", ] +def _keep_float(f): + @functools.wraps(f) + def inner(*args): + r = f(*args) + if any(isinstance(a, sympy.Float) for a in args) and not isinstance( + r, sympy.Float + ): + r = sympy.Float(float(r)) + return r + + return inner + + def fuzzy_eq(x, y): if None in (x, y): return None return x == y +# It would be nice to have assertions on whether or not inputs is_integer +# However, with bugs like https://github.com/sympy/sympy/issues/26620 sympy +# sometimes inconsistently reports floats an integers. +# +# What we can assume from sympy is that if something is an int, it +# definitely is is_integer, but if it is a float it may or may not +# be is_integer. So we are unable to do strong asserts that things +# are NOT integers. + + +# TODO: In Triton, // rounds to zero, but in Python, it is floor division. +# When we can prove both arguments are non-negative, we should just have a +# GenericFloorDiv (name pending) which can codegen efficiently in Python/C, +# and then PythonFloorDiv and CIntDiv which have the appropriate rounding +# semantics. +# +# Right now, FloorDiv de facto changes behavior if arguments are negative or +# not, this can potentially cause correctness issues. class FloorDiv(sympy.Function): """ We maintain this so that: 1. We can use divisibility guards to simplify FloorDiv(a, b) to a / b. 2. Printing out the expression is nicer (compared to say, representing a//b as (a - a % b) / b) + + NB: This is Python-style floor division, round to -Inf """ nargs = (2,) precedence = 50 # precedence of mul # noqa: F811 - # Default return type for SymPy assumptions. - # https://docs.sympy.org/latest/guides/assumptions.html#implementing-assumptions-handlers - is_real = True + is_integer = True @property def base(self): @@ -52,29 +87,14 @@ def _sympystr(self, printer): divisor = printer.parenthesize(self.divisor, self.precedence) return f"({base}//{divisor})" - # SymPy assumptions based on argument types. - def _eval_is_real(self): - return fuzzy_or([self.base.is_real, self.divisor.is_real]) - - def _eval_is_integer(self): - return fuzzy_and([self.base.is_integer, self.divisor.is_integer]) - # Automatic evaluation. # https://docs.sympy.org/latest/guides/custom-functions.html#best-practices-for-eval @classmethod def eval(cls, base, divisor): - def check_supported_type(x): - if ( - x.is_integer is False and x.is_real is False and x.is_complex - ) or x.is_Boolean: - raise TypeError( - f"unsupported operand type(s) for //: " - f"'{type(base).__name__}' and '{type(divisor).__name__}'" - f", expected integer or real" - ) - - check_supported_type(base) - check_supported_type(divisor) + # python test/test_dynamic_shapes.py -k TestDimConstraints.test_dim_constraints_solve_full + # Assert triggered by inequality solver + # assert base.is_integer, base + # assert divisor.is_integer, divisor # We don't provide the same error message as in Python because SymPy # makes it difficult to check the types. @@ -85,26 +105,22 @@ def check_supported_type(x): return sympy.S.Zero if base.is_integer and divisor == 1: return base - if base.is_real and divisor == 1: - return sympy.floor(base) if base.is_integer and divisor == -1: return sympy.Mul(base, -1) if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer): - return base // divisor - if isinstance(base, (sympy.Integer, sympy.Float)) and isinstance( - divisor, (sympy.Integer, sympy.Float) - ): - return sympy.floor(base / divisor) + return sympy.Integer(int(base) // int(divisor)) if isinstance(base, FloorDiv): return FloorDiv(base.args[0], base.args[1] * divisor) - if isinstance(divisor, sympy.Rational) and divisor.p == 1: - return sympy.floor(base * divisor.q) + # gcd in sympy is over polynomials, so you'll end up with rationals if + # you do this. Don't. + """ if isinstance(base, sympy.Add): for a in base.args: gcd = sympy.gcd(a, divisor) if gcd == divisor: return FloorDiv(base - a, divisor) + a / gcd + """ try: gcd = sympy.gcd(base, divisor) @@ -126,6 +142,10 @@ class ModularIndexing(sympy.Function): @classmethod def eval(cls, base, divisor, modulus): + assert isinstance(base, int) or base.is_integer, base + assert isinstance(divisor, int) or divisor.is_integer, divisor + assert isinstance(modulus, int) or modulus.is_integer, modulus + if base == 0 or modulus == 1: return sympy.Integer(0) @@ -189,6 +209,19 @@ class Where(sympy.Function): nargs = (3,) + def _eval_is_integer(self): + return True if self.args[1].is_integer and self.args[2].is_integer else None # type: ignore[attr-defined] + + def _eval_is_nonnegative(self): + return ( + True + if self.args[1].is_nonnegative and self.args[2].is_nonnegative # type: ignore[attr-defined] + else None + ) + + def _eval_is_positive(self): + return True if self.args[1].is_positive and self.args[2].is_positive else None # type: ignore[attr-defined] + @classmethod def eval(cls, c, p, q): if c == sympy.true: @@ -197,28 +230,27 @@ def eval(cls, c, p, q): return q -class Mod(sympy.Function): - """ - We maintain this so that we avoid SymPy correctness issues, such as: - https://github.com/sympy/sympy/issues/25146 - """ - +# Python-style modulus: take sign from RHS +class PythonMod(sympy.Function): nargs = (2,) + is_integer = True + @classmethod def eval(cls, p, q): - # This was adapted from: sympy/core/mod.py + # python test/dynamo/test_export.py -k ExportTests.test_trivial_constraint + # Triggered by sympy.solvers.inequalities.reduce_inequalities + # assert p.is_integer, p + # assert q.is_integer, q if q.is_zero: raise ZeroDivisionError("Modulo by zero") - # If either of them is NaN or infinite. - if p is S.NaN or q is S.NaN or p.is_finite is False or q.is_finite is False: - return S.NaN + # Three cases: # 1. p == 0 # 2. p is either q or -q # 3. p is integer and q == 1 - if p is S.Zero or p in (q, -q) or (p.is_integer and q == 1): + if p is S.Zero or p in (q, -q) or q == 1: return S.Zero # Evaluate if they are both literals. @@ -247,10 +279,7 @@ def eval(cls, p, q): if sympy.Mod(p, q) == 0: return S.Zero - def _eval_is_integer(self): - p, q = self.args - return fuzzy_and([p.is_integer, q.is_integer, fuzzy_not(q.is_zero)]) # type: ignore[attr-defined] - + # NB: args[1] for PythonMod def _eval_is_nonnegative(self): return True if self.args[1].is_positive else None # type: ignore[attr-defined] @@ -258,6 +287,58 @@ def _eval_is_nonpositive(self): return True if self.args[1].is_negative else None # type: ignore[attr-defined] +# Generic modulus: only defined on non-negative arguments +class Mod(sympy.Function): + nargs = (2,) + + is_integer = True + is_nonnegative = True + + @classmethod + def eval(cls, p, q): + # This was adapted from: sympy/core/mod.py + + # Triggered by + # python test/test_dynamic_shapes.py -k TestDimConstraints.test_dim_constraints_solve_full + # assert p.is_integer, p + # assert q.is_integer, q + + if q.is_zero: + raise ZeroDivisionError("Modulo by zero") + + # Three cases: + # 1. p == 0 + # 2. p is either q or -q + # 3. p is integer and q == 1 + if p is S.Zero or p in (q, -q) or q == 1: + return S.Zero + + # Evaluate if they are both literals. + if q.is_Number and p.is_Number: + assert p >= 0, p + assert q >= 1, q + return p % q + + # If q == 2, it's a matter of whether p is odd or even. + if q.is_Number and q == 2: + if p.is_even: + return S.Zero + if p.is_odd: + return S.One + + # If p is a multiple of q. + r = p / q + if r.is_integer: + return S.Zero + + # If p < q and its ratio is positive, then: + # - floor(p / q) = 0 + # - p % q = p - floor(p / q) * q = p + less = p < q + if less.is_Boolean and bool(less) and r.is_positive: + return p + + class CleanDiv(FloorDiv): """ Div where we can assume no rounding. @@ -275,6 +356,10 @@ class CeilDiv(sympy.Function): is_integer = True def __new__(cls, base, divisor): + base = sympy.sympify(base) + divisor = sympy.sympify(divisor) + assert base.is_integer, base + assert divisor.is_integer, divisor if sympy.gcd(base, divisor) == divisor: return CleanDiv(base, divisor) else: @@ -282,43 +367,139 @@ def __new__(cls, base, divisor): class LShift(sympy.Function): + is_integer = True + @classmethod def eval(cls, base, shift): + assert base.is_integer, base + assert shift.is_integer, shift + if shift < 0: raise ValueError("negative shift count") return base * 2**shift class RShift(sympy.Function): + is_integer = True + @classmethod def eval(cls, base, shift): + assert base.is_integer, base + assert shift.is_integer, shift + if shift < 0: raise ValueError("negative shift count") return base // 2**shift -# Overloaded to be compatible with regular Python. -# https://github.com/pytorch/pytorch/issues/90900 -class Pow(sympy.Function): +def safe_pow(base, exp): + sign = 1 + if base < 0: + base = -base + sign = 1 if exp % 2 == 0 else -1 + return sign * _safe_pow(base, exp) + + +def _safe_pow(base, exponent): + if exponent < 0: + raise ValueError("Exponent must be non-negative.") + + if exponent == 0: + return 1 + + half_exp = safe_pow(base, exponent // 2) + if half_exp > sys.maxsize - 1: + return sys.maxsize - 1 + + result = half_exp * half_exp + if result > sys.maxsize - 1: + return sys.maxsize - 1 + + if exponent % 2 == 1: + result *= base + if result > sys.maxsize - 1: + return sys.maxsize - 1 + + return result + + +class PowByNatural(sympy.Function): + is_integer = True + @classmethod def eval(cls, base, exp): - if exp.is_zero: - return sympy.Integer(1) - elif base.is_zero and exp < 0: - raise ZeroDivisionError(f"{base} cannot be raised to a negative power") - else: - return base**exp + # exp can be assumed to be is_integer and is_nonnegative, but we may + # have concluded this externally from Sympy assumptions, so we can't + # assert the nonnegative + assert exp.is_integer, exp + if isinstance(base, sympy.Number) and isinstance(exp, sympy.Number): + return sympy.Integer(safe_pow(base, exp)) + if isinstance(exp, sympy.Integer): + # Translate power into iterated multiplication + r = sympy.Integer(1) + for _ in range(int(exp)): + r *= base + return r + # NB: do NOT translate into sympy.Pow, we will lose knowledge that exp + # is a natural number if we do + + +# base is assumed to be nonnegative, thereby prevent complex numbers from +# occuring +class FloatPow(sympy.Function): + is_integer = False + is_real = True + + @classmethod + def eval(cls, base, exp): + if isinstance(base, sympy.Number) and isinstance(exp, sympy.Number): + return sympy.Float(float(base) ** float(exp)) + # NB: do not do any nontrivial reasoning # Overloaded to be compatible with regular Python. # https://github.com/pytorch/pytorch/issues/90900 -class TrueDiv(sympy.Function): +# +# In particular, sympy division is willing to simplify x/x == 1 +# where 1 is an integer, but this must be a float if x was float. +class FloatTrueDiv(sympy.Function): + is_integer = False + is_real = True + + @classmethod + def eval(cls, base, divisor): + # assert base.is_integer is not True, base + # assert divisor.is_integer is not True, divisor + + if divisor.is_zero: + raise ZeroDivisionError("division by zero") + + if isinstance(base, sympy.Number) and isinstance(divisor, sympy.Number): + return sympy.Float(float(base) / float(divisor)) + + +# Overloaded to be compatible with regular Python. We distinguish this from +# FloatTrueDiv, because the code generation has to be different for this case: +# Python has a fancy algorithm for integer true division that isn't just +# "promote both arguments to float and use float division", so you need to +# codegen it differently. While technically you can work it out from the +# types of the input, this is often inconvenient to do in Inductor codegen, +# so just have a different operator +# NB: Right now, Inductor codegen doesn't implement this correctly lol +class IntTrueDiv(sympy.Function): + is_integer = False + is_real = True + @classmethod def eval(cls, base, divisor): + assert base.is_integer, base + assert divisor.is_integer, divisor + if divisor.is_zero: raise ZeroDivisionError("division by zero") - else: - return base / divisor + + if isinstance(base, sympy.Number) and isinstance(divisor, sympy.Number): + return sympy.Float(int(base) / int(divisor)) # TODO: As an indicator, this != 0 implies == 1 (and vice versa). @@ -353,45 +534,87 @@ def eval(cls, *args): return None -class Trunc(sympy.Function): +# NB: this is inconsistent with math.trunc in Python +class TruncToFloat(sympy.Function): + is_integer = False + is_real = True + + @classmethod + def eval(cls, number): + # assert number.is_integer is not True, number + if isinstance(number, sympy.Number): + # NB: It is safe to use truncation to integer, which is what + # math.trunc does, as Python integers are arbitrary precision and + # so we are guaranteed not to lose precision when we do this + return sympy.Float(math.trunc(float(number))) + + +class TruncToInt(sympy.Function): is_integer = True @classmethod def eval(cls, number): - if number.is_integer: - return number - elif isinstance(number, sympy.Number): + # assert number.is_integer is not True, number + if number == sympy.oo: + return sympy.Integer(sys.maxsize - 1) + if number == -sympy.oo: + return sympy.Integer(-sys.maxsize - 1) + if isinstance(number, sympy.Number): return sympy.Integer(math.trunc(float(number))) -class Round(sympy.Function): +# This is float -> int +class RoundToInt(sympy.Function): is_integer = True @classmethod def eval(cls, number): - if number.is_integer: - return number - elif isinstance(number, sympy.Number): - return sympy.Integer(round(float(number))) + # assert number.is_integer is not True, number + + if isinstance(number, sympy.Float): + return sympy.Integer(round(float(number), 0)) + - def __int__(self): - # This will only ever be called when computing size hints. At that point, self.args[0] should be a number and - # no longer an expression. If it were, the float call would fail and the caller would handle this further. - return round(float(self.args[0])) # type: ignore[arg-type] +# To get float -> int, Python style round semantics. +# +# x = PyFloat_AsDouble(self); +# if (o_ndigits == Py_None) { +# /* single-argument round or with None ndigits: +# * round to nearest integer */ +# rounded = round(x); +# if (fabs(x-rounded) == 0.5) +# /* halfway case: round to even */ +# rounded = 2.0*round(x/2.0); +# return PyLong_FromDouble(rounded); +# } +# NB: Like Round, this only ever returns floats. ndigits cannot be None class RoundDecimal(sympy.Function): + is_integer = False + is_real = True + @classmethod def eval(cls, number, ndigits): - if number.is_integer and ndigits >= 0: + # assert number.is_integer is not True, number + + if isinstance(number, sympy.Float) and isinstance(ndigits, sympy.Integer): + return sympy.Float(round(float(number), int(ndigits))) + + +class ToFloat(sympy.Function): + is_integer = False + is_real = True + + @classmethod + def eval(cls, number): + if number in [sympy.oo, -sympy.oo]: return number - elif isinstance(number, sympy.Number) and isinstance(ndigits, sympy.Integer): - value_type, output_type = ( - (int, sympy.Integer) - if isinstance(number, sympy.Integer) - else (float, sympy.Float) - ) - return output_type(round(value_type(number), int(ndigits))) + + assert number.is_integer, number + + if isinstance(number, sympy.Integer): + return sympy.Float(int(number)) def make_opaque_unary_fn(name): diff --git a/torch/utils/_sympy/interp.py b/torch/utils/_sympy/interp.py index 806e91cfe281..c2d9ae464125 100644 --- a/torch/utils/_sympy/interp.py +++ b/torch/utils/_sympy/interp.py @@ -16,15 +16,20 @@ import torch from .functions import ( CleanDiv, + FloatPow, + FloatTrueDiv, FloorDiv, + IntTrueDiv, IsNonOverlappingAndDenseIndicator, Mod, ModularIndexing, - Pow, - Round, + PowByNatural, + PythonMod, RoundDecimal, - TrueDiv, - Trunc, + RoundToInt, + ToFloat, + TruncToFloat, + TruncToInt, Where, ) @@ -49,30 +54,39 @@ def handlers(): sympy.Le: "le", sympy.Ge: "ge", sympy.Not: "not_", - TrueDiv: "truediv", + IntTrueDiv: "int_truediv", + FloatTrueDiv: "truediv", FloorDiv: "floordiv", - CleanDiv: "div", - Trunc: "trunc", + CleanDiv: "floordiv", # TODO: hmm? + TruncToFloat: "trunc", Where: "where", sympy.Add: "add", sympy.Mul: "mul", - Pow: "pow", - sympy.Pow: "pow", + FloatPow: "pow", + PowByNatural: "pow_by_natural", + # sympy simplifies x * x into Pow(x, 2), so we need to handle this. + # Do NOT use builtin Pow for floats + # TODO: There is a hazard here, if we have float * float it will + # also get turned into Pow(float, 2) but we don't want this because + # pow_by_natural is assumed to only be integers. Probably the fix is + # to add a FloatMul to impede this optimization + sympy.Pow: "pow_by_natural", Mod: "mod", + PythonMod: "mod", # TODO: this is wrong + # TODO: Inductor can generate these, but it's ill-specified which + # semantics were intended here. Needs to be cleaned up along with + # FloorDiv in a bigger cleanup sympy.Mod: "mod", sympy.Abs: "abs", sympy.log: "log", sympy.exp: "exp", - sympy.floor: "floor", - sympy.ceiling: "ceil", sympy.Min: "minimum", sympy.Max: "maximum", ModularIndexing: "modular_indexing", sympy.functions.elementary.piecewise.ExprCondPair: "expr_cond_pair", sympy.Piecewise: "piecewise", IsNonOverlappingAndDenseIndicator: "is_non_overlapping_and_dense_indicator", - Round: "round", - RoundDecimal: "round", + RoundDecimal: "round_decimal", } for name in ["cos", "sin", "tan", "sinh", "cosh", "tanh", "asin", "acos", "atan"]: HANDLERS[getattr(sympy, name)] = name @@ -84,7 +98,11 @@ def handlers(): def sympy_interp( - analysis, env: Dict[sympy.Symbol, Any], expr: Union[sympy.Expr, SympyBoolean] + analysis, + env: Dict[sympy.Symbol, Any], + expr: Union[sympy.Expr, SympyBoolean], + *, + index_dtype=torch.int64, ): # Handle base cases dtype = None @@ -105,9 +123,30 @@ def sympy_interp( expr.args[1], sympy.core.numbers.Half ): return analysis.sqrt(sympy_interp(analysis, env, expr.args[0])) + if isinstance(expr, ToFloat): + return analysis.to_dtype( + sympy_interp(analysis, env, expr.args[0]), torch.float64 + ) # Recursive case args = [sympy_interp(analysis, env, arg) for arg in expr.args] # type: ignore[arg-type] + + # These handlers are special because they take an extra dtype argument + # specifying what they should convert to, and we need to appropriately set + # this up when we convert from Sympy. A reasonable default when you + # are translating is to conservatively do int64, and then narrow these + # arguments later when you discover you can narrow the index range. But + # if you already know that 32-bit indexing is OK, you can directly do the + # sympy translation with index_dtype=torch.int32 + INDEX_DTYPE_HANDLERS = { + TruncToInt: "trunc_to_int", + sympy.floor: "floor_to_int", + sympy.ceiling: "ceil_to_int", + RoundToInt: "round_to_int", + } + if (handler_name := INDEX_DTYPE_HANDLERS.get(expr.func)) is not None: + return getattr(analysis, handler_name)(*args, index_dtype) + if hasattr(expr.func, "_torch_handler_name"): handler_name = expr.func._torch_handler_name else: diff --git a/torch/utils/_sympy/reference.py b/torch/utils/_sympy/reference.py index 881b9d616eb5..b54a0d0503a1 100644 --- a/torch/utils/_sympy/reference.py +++ b/torch/utils/_sympy/reference.py @@ -1,12 +1,25 @@ import math +import operator + import sympy import torch from torch.utils._sympy.functions import ( + _keep_float, + FloatPow, + FloatTrueDiv, + FloorDiv, + IntTrueDiv, + Mod, OpaqueUnaryFn_exp, OpaqueUnaryFn_log, OpaqueUnaryFn_sqrt, + PowByNatural, + RoundDecimal, + RoundToInt, + ToFloat, + TruncToInt, ) @@ -62,20 +75,41 @@ def not_(a): @staticmethod def reciprocal(x): - return 1 / x + return FloatTrueDiv(1.0, x) @staticmethod def square(x): - return x * x + return PowByNatural(x, 2) + + @staticmethod + def trunc_to_int(x, dtype): + return TruncToInt(x) + + @staticmethod + def ceil_to_int(x, dtype): + return sympy.ceiling(x) + + @staticmethod + def floor_to_int(x, dtype): + return sympy.floor(x) + + @staticmethod + def floor(x): + return _keep_float(sympy.floor)(x) + + @staticmethod + def ceil(x): + return _keep_float(sympy.ceiling)(x) + + @staticmethod + def to_dtype(x, dtype): + if dtype == torch.float64: + return ToFloat(x) + raise NotImplementedError(f"to_dtype {dtype} NYI") @staticmethod def mod(x, y): - ret = abs(x) % abs(y) - # without check: - # tracing will fail trying to go through control-flow if x is Proxy() - if isinstance(x, (int, sympy.Number)) and x < 0: - ret *= -1 - return ret + return Mod(x, y) @staticmethod def abs(x): @@ -87,37 +121,31 @@ def neg(x): @staticmethod def truediv(a, b): - return a / b + return FloatTrueDiv(a, b) @staticmethod - def div(a, b): - return ReferenceAnalysis.truediv(a, b) + def int_truediv(a, b): + return IntTrueDiv(a, b) @staticmethod def floordiv(a, b): - if b == 0: - return sympy.nan if a == 0 else sympy.zoo - return a // b + return FloorDiv(a, b) @staticmethod def truncdiv(a, b): - result = a / b - if result.is_finite: - result = sympy.Integer(result) - - return result + raise NotImplementedError("TODO: truncdiv") @staticmethod def add(a, b): - return a + b + return _keep_float(operator.add)(a, b) @staticmethod def mul(a, b): - return a * b + return _keep_float(operator.mul)(a, b) @staticmethod def sub(a, b): - return a - b + return _keep_float(operator.sub)(a, b) @staticmethod def exp(x): @@ -133,39 +161,27 @@ def sqrt(x): @staticmethod def pow(a, b): - return a**b + return _keep_float(FloatPow)(a, b) + + @staticmethod + def pow_by_natural(a, b): + return PowByNatural(a, b) @staticmethod def minimum(a, b): - # Poorman's version of upcasting in Sympy - # This won't do for sympy.Expr as the casting does nothing for those - if a.is_Float or not a.is_finite or b.is_Float or not b.is_finite: - result_type = sympy.Float - else: - assert a.is_Integer - assert b.is_Integer - result_type = sympy.Integer - return sympy.Min(result_type(a), result_type(b)) + return sympy.Min(a, b) @staticmethod def maximum(a, b): - # Poorman's version of upcasting in Sympy - # This won't do for sympy.Expr as the casting does nothing for those - if a.is_Float or not a.is_finite or b.is_Float or not b.is_finite: - result_type = sympy.Float - else: - assert a.is_Integer - assert b.is_Integer - result_type = sympy.Integer - return sympy.Max(result_type(a), result_type(b)) + return sympy.Max(a, b) @staticmethod - def floor(x): - return sympy.floor(x) + def round_to_int(a, dtype): + return RoundToInt(a) @staticmethod - def ceil(x): - return sympy.ceiling(x) + def round_decimal(a, b): + return RoundDecimal(a, b) # Unlike ReferenceAnalysis, does NOT sympyify, instead, works with plain @@ -191,10 +207,20 @@ def not_(a): def floordiv(a, b): return a // b + @staticmethod + def mod(x, y): + return x % y + @staticmethod def truncdiv(a, b): return a / b + @staticmethod + def to_dtype(x, dtype): + if dtype == torch.float64: + return float(x) + raise NotImplementedError(f"to_dtype {dtype} NYI") + @staticmethod def exp(x): raise AssertionError("exp is not valid shape sympy expr") @@ -216,9 +242,40 @@ def maximum(a, b): return torch.sym_max(a, b) @staticmethod - def floor(x): + def floor_to_int(x, dtype): return math.floor(x) @staticmethod - def ceil(x): + def ceil_to_int(x, dtype): return math.ceil(x) + + @staticmethod + def floor(x): + return float(math.floor(x)) + + @staticmethod + def ceil(x): + return float(math.ceil(x)) + + @staticmethod + def truediv(a, b): + return a / b + + @staticmethod + def pow(a, b): + return a**b + + @staticmethod + def pow_by_natural(a, b): + # Pray that safe_pow is not needed here lol. In particular, this + # never participates in VR low/high ranges, so overflow should be + # unlikely + return a**b + + @staticmethod + def round_to_int(a, dtype): + return round(a) + + @staticmethod + def round_decimal(a, b): + return round(a, ndigits=b) diff --git a/torch/utils/_sympy/solve.py b/torch/utils/_sympy/solve.py index 6276c696293c..02ddf7c34219 100644 --- a/torch/utils/_sympy/solve.py +++ b/torch/utils/_sympy/solve.py @@ -88,6 +88,7 @@ def try_solve( # Return if we were able to isolate 'thing' on the left-hand side. if isinstance(e, sympy.Rel) and e.lhs == thing: + log.debug("solved: %s ---> %s", expr, e) return e, e.rhs return None diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py index c7cc96beb980..4d364d4981b5 100644 --- a/torch/utils/_sympy/value_ranges.py +++ b/torch/utils/_sympy/value_ranges.py @@ -5,6 +5,7 @@ import logging import math import operator +import sys from typing import ( Callable, Dict, @@ -25,17 +26,20 @@ from torch._prims_common import dtype_to_type from .functions import ( - OpaqueUnaryFn_acos, - OpaqueUnaryFn_asinh, - OpaqueUnaryFn_atan, - OpaqueUnaryFn_cosh, + _keep_float, + FloatTrueDiv, + FloorDiv, + IntTrueDiv, OpaqueUnaryFn_exp, OpaqueUnaryFn_log, - OpaqueUnaryFn_sinh, OpaqueUnaryFn_sqrt, - OpaqueUnaryFn_tanh, - Round, + PowByNatural, RoundDecimal, + RoundToInt, + safe_pow, + ToFloat, + TruncToFloat, + TruncToInt, ) from .interp import sympy_interp @@ -120,6 +124,8 @@ class ValueRanges(Generic[_T]): lower: _T upper: _T is_bool: bool + is_int: bool + is_float: bool @overload def __init__(self: ValueRanges[sympy.Expr], lower: ExprIn, upper: ExprIn) -> None: @@ -142,8 +148,39 @@ def __init__(self, lower: AllIn, upper: AllIn) -> None: # Because this is a frozen class object.__setattr__(self, "lower", lower) object.__setattr__(self, "upper", upper) + # Unlike bool/int in Python, we don't report bools are ints object.__setattr__(self, "is_bool", isinstance(lower, SympyBoolean)) - assert isinstance(upper, SympyBoolean) == self.is_bool + if self.is_bool: + assert isinstance(upper, SympyBoolean), (lower, upper) + + # Warning: is_int/is_float is best effort. We do pretty well in + # Dynamo, but in Inductor these attributes are often wrong because we + # are not very rigorous in dtype analysis. This is also why we need + # the flexible analysis for is_int: sometimes a sympy.oo pops in for + # an integer bound. I would /like/ for us not to do this, but it's + # too hard to push the invariant through right now. + + object.__setattr__( + self, + "is_int", + not self.is_bool + and (isinstance(lower, sympy.Integer) or isinstance(upper, sympy.Integer)), + ) + """ + # This assert is just impossible right now, too many sympy bugs + if self.is_int: + # NB: sympy will sometimes randomly lose the float-ness of zero, + # so we also need to account for that in the assertion here. + # See also https://github.com/sympy/sympy/issues/26620 + assert isinstance(lower, sympy.Integer) or lower in [-sympy.oo, 0], ( + lower, + upper, + ) + assert isinstance(upper, sympy.Integer) or upper in [sympy.oo, 0], (lower, upper) + """ + # NB: [-oo, oo] always advertises as float! + object.__setattr__(self, "is_float", not self.is_bool and not self.is_int) + assert self.is_bool or self.is_int or self.is_float, (lower, upper) def boolify(self) -> ValueRanges[SympyBoolean]: if vr_is_bool(self): @@ -184,6 +221,8 @@ def __and__(self: AllVR, other: AllVR) -> AllVR: if self == ValueRanges.unknown(): return other assert self.is_bool == other.is_bool, (self, other) + assert self.is_int == other.is_int, (self, other) + assert self.is_float == other.is_float, (self, other) if self.is_bool: return ValueRanges( sympy.Or(self.lower, other.lower), sympy.And(self.upper, other.upper) @@ -353,7 +392,12 @@ def constant(value, dtype): # using nan makes subsequent computation throw, and for the purposes of optimization # returning -math.inf - math.inf is equivalent to giving up if isinstance(value, SupportsFloat) and math.isnan(value): - return ValueRanges.unknown() + if dtype == torch.bool: + return ValueRanges.unknown_bool() + elif dtype.is_floating_point: + return ValueRanges.unknown() + else: + return ValueRanges(-sys.maxsize - 1, sys.maxsize) if is_python: type_ = dtype_to_type(dtype) @@ -369,7 +413,18 @@ def constant(value, dtype): # dtype is intXX assert value.is_integer - return ValueRanges.wrap(value) + r = ValueRanges.wrap(value) + return r + + @staticmethod + def to_dtype(a, dtype, src_dtype=None): + if dtype == torch.float64: + return ValueRanges.increasing_map(a, ToFloat) + return ValueRanges.unknown() + + @staticmethod + def trunc_to_int(a, dtype): + return ValueRanges.increasing_map(a, TruncToInt) @staticmethod def not_(a): @@ -428,7 +483,9 @@ def ge(cls, a, b): @staticmethod def add(a, b): - return ValueRanges.coordinatewise_increasing_map(a, b, operator.add) + return ValueRanges.coordinatewise_increasing_map( + a, b, _keep_float(operator.add) + ) @classmethod def mul(cls, a, b): @@ -448,11 +505,20 @@ def safe_mul(a, b): else: return a * b - return ValueRanges.coordinatewise_monotone_map(a, b, safe_mul) + return ValueRanges.coordinatewise_monotone_map(a, b, _keep_float(safe_mul)) - @classmethod - def div(cls, a, b): - return cls.truediv(a, b) + @staticmethod + def int_truediv(a, b): + a = ValueRanges.wrap(a) + b = ValueRanges.wrap(b) + if 0 in b or ( + (-sympy.oo in a or sympy.oo in a) and (-sympy.oo in b or sympy.oo in b) + ): + return ValueRanges.unknown() + else: + return ValueRanges.coordinatewise_monotone_map( + a, b, _keep_float(IntTrueDiv) + ) @staticmethod def truediv(a, b): @@ -463,18 +529,22 @@ def truediv(a, b): ): return ValueRanges.unknown() else: - return ValueRanges.coordinatewise_monotone_map(a, b, operator.truediv) + return ValueRanges.coordinatewise_monotone_map( + a, b, _keep_float(FloatTrueDiv) + ) @staticmethod def floordiv(a, b): a = ValueRanges.wrap(a) b = ValueRanges.wrap(b) if 0 in b or ( - (-sympy.oo in a or sympy.oo in a) and (-sympy.oo in b or sympy.oo in b) + # TODO: make this more precise + (-sympy.oo in a or sympy.oo in a) + or (-sympy.oo in b or sympy.oo in b) ): return ValueRanges.unknown() else: - return ValueRanges.coordinatewise_monotone_map(a, b, operator.floordiv) + return ValueRanges.coordinatewise_monotone_map(a, b, FloorDiv) @classmethod def mod(cls, x, y): @@ -523,17 +593,51 @@ def modular_indexing(cls, a, b, c): @classmethod def is_non_overlapping_and_dense_indicator(cls, *args): - return ValueRanges.unknown() + return ValueRanges.unknown() # TODO: type here is wrong @classmethod - def pow(cls, a, b): - def is_integer(val): - return isinstance(val, int) or ( - hasattr(val, "is_integer") and val.is_integer + def pow_by_natural(cls, a, b): + a = ValueRanges.wrap(a) + b = ValueRanges.wrap(b) + if a.is_singleton() and b.is_singleton(): + return ValueRanges.wrap(safe_pow(a.lower, b.lower)) + # NB: Exclude zero, because zero is special + elif a.lower >= 1: + # We should know that b >= 0 but we may have forgotten this fact due + # to replacements, so don't assert it, but DO clamp it to prevent + # degenerate problems + return ValueRanges.coordinatewise_increasing_map( + a, b & ValueRanges(0, sys.maxsize - 1), PowByNatural + ) + elif b.is_singleton(): + if b.lower % 2 == 0: + # x^n where n is even + return ValueRanges.convex_min_zero_map( + a, lambda x: safe_pow(x, b.lower) + ) + else: + # x^n where n is odd + return ValueRanges.increasing_map(a, lambda x: safe_pow(x, b.lower)) + else: + # a is potentially negative, and we don't know if the exponent is + # even or odd. So just conservatively set the upper and lower + # bound based on what the maximum absolute value could be, in both + # directions + max_base = max(a.upper, -a.lower) + return ValueRanges( + -(safe_pow(max_base, b.upper)), safe_pow(max_base, b.upper) ) + @classmethod + def pow(cls, a, b): + return ValueRanges.unknown() + + # We could implement all this, but for floating point pow, is there + # really a point? + """ a = ValueRanges.wrap(a) b = ValueRanges.wrap(b) + # Not implemented yet. It's a bit tricky # If you want to implement it, compute the partial derivatives of a ** b # and check the ranges where the function is increasing / decreasing @@ -553,8 +657,7 @@ def is_integer(val): if b == 0: if not a.lower.is_finite: return ValueRanges.unknown() - type_ = sympy.Float if a.lower.is_real else sympy.Integer - return ValueRanges.wrap(type_(1)) + return ValueRanges.wrap(1.0) if b < 0: a = cls.reciprocal(a) @@ -563,21 +666,12 @@ def is_integer(val): if a == ValueRanges.unknown(): return ValueRanges.unknown() - # Here b > 0 - if not is_integer(b): - # If the base is positive, then we're good, otherwise nothing's defined - if a.lower >= 0: - return ValueRanges.increasing_map(a, lambda x: x**b) - else: - return ValueRanges.unknown() + # If the base is positive, then we're good, otherwise nothing's defined + if a.lower >= 0: + return ValueRanges.increasing_map(a, lambda x: x**b) else: - # b > 0 integer - if b % 2 == 0: - # x^n where n is even - return ValueRanges.convex_min_zero_map(a, lambda x: x**b) - else: - # x^n where n is odd - return ValueRanges.increasing_map(a, lambda x: x**b) + return ValueRanges.unknown() + """ @staticmethod def reciprocal(x): @@ -586,7 +680,7 @@ def reciprocal(x): if 0 in x: return ValueRanges.unknown() else: - return ValueRanges.decreasing_map(x, lambda y: 1 / y) + return ValueRanges.decreasing_map(x, lambda y: FloatTrueDiv(1.0, y)) @staticmethod def abs(x): @@ -615,45 +709,64 @@ def maximum(cls, a, b): def min_or_max(a, b, fn): a = ValueRanges.wrap(a) b = ValueRanges.wrap(b) + return ValueRanges.coordinatewise_increasing_map(a, b, fn) - # Performs upcasting first - def fn_(x: sympy.Expr, y: sympy.Expr) -> sympy.Expr: - # Poorman's version of upcasting in Sympy - # Inf is not a float... - if x.is_Integer and y.is_Integer: - result_type = sympy.Integer - elif x.is_rational and y.is_rational: - result_type = sympy.Rational - else: - assert x.is_real or not x.is_finite or y.is_real or not y.is_finite - result_type = sympy.Float - return fn(result_type(x), result_type(y)) + @classmethod + def floor_to_int(cls, x, dtype): + return ValueRanges.increasing_map(x, sympy.functions.elementary.integers.floor) - return ValueRanges.coordinatewise_increasing_map(a, b, fn_) + @classmethod + def ceil_to_int(cls, x, dtype): + return ValueRanges.increasing_map( + x, sympy.functions.elementary.integers.ceiling + ) + + # I think these implementations are sound. The hazard here is that sympy + # will carry out the floor/ceil at too high precision and then something + # bad will happen when we convert it to float. + # + # For truncation, the implementation is clearly sound, because the desired + # target float is always exactly representable, since you're just chopping + # off bits the mantissa. But what about ceil/floor? + # + # The important constraint here is that we're not defining floor on + # arbitrary real numbers, only representable float numbers. So we can + # take advantage of the fact that before we reach the first + # unrepresentable integer in floating point space, we have the range of + # numbers corresponding to exponent zero: all integers, with no fractional + # amounts. floor/ceil is an identity operation in this case. In the + # range below here, representable floating point numbers are spaced + # exactly 1/2 apart, and notably, both the floor/ceil are defined floating + # point numbers. There is no "gap" as you step up to the next exponent. @classmethod def floor(cls, x): - return ValueRanges.increasing_map(x, sympy.functions.elementary.integers.floor) + return ValueRanges.increasing_map( + x, _keep_float(sympy.functions.elementary.integers.floor) + ) @classmethod def ceil(cls, x): return ValueRanges.increasing_map( - x, sympy.functions.elementary.integers.ceiling + x, _keep_float(sympy.functions.elementary.integers.ceiling) ) @classmethod - def round(cls, number, ndigits=None): - if ndigits is None: - fn = Round - else: - assert ndigits.is_singleton() - ndigits = ndigits.lower - # We can't use functools.partial here since sympy doesn't support keyword arguments, but we have to bind - # the second parameter. - fn = lambda number: RoundDecimal(number, ndigits) # type: ignore[misc, assignment] # noqa: E731 + def round_decimal(cls, number, ndigits): + if not ndigits.is_singleton(): + return ValueRanges.unknown() + + ndigits = ndigits.lower + # We can't use functools.partial here since sympy doesn't support keyword arguments, but we have to bind + # the second parameter. + fn = lambda number: RoundDecimal(number, ndigits) # type: ignore[misc, assignment] # noqa: E731 return ValueRanges.increasing_map(number, fn) + @classmethod + def round_to_int(cls, number, dtype): + return ValueRanges.increasing_map(number, RoundToInt) + # It's used in some models on symints @staticmethod def sqrt(x): @@ -708,12 +821,15 @@ def cos(x): @staticmethod def cosh(x): + return ValueRanges(0.0, sympy.oo) + """ x = ValueRanges.wrap(x) if x.lower > 0: return ValueRanges.increasing_map(x, OpaqueUnaryFn_cosh) elif x.upper < 0: return ValueRanges.decreasing_map(x, OpaqueUnaryFn_cosh) return ValueRanges(0.0, sympy.oo) + """ @staticmethod def sin(x): @@ -723,7 +839,8 @@ def sin(x): @staticmethod def sinh(x): - return ValueRanges.increasing_map(x, OpaqueUnaryFn_sinh) + # return ValueRanges.increasing_map(x, OpaqueUnaryFn_sinh) + return ValueRanges(-sympy.oo, sympy.oo) @staticmethod def tan(x): @@ -731,32 +848,37 @@ def tan(x): @staticmethod def tanh(x): - return ValueRanges.increasing_map(x, OpaqueUnaryFn_tanh) + # return ValueRanges.increasing_map(x, OpaqueUnaryFn_tanh) + return ValueRanges(-sympy.oo, sympy.oo) @staticmethod def asin(x): + return ValueRanges(-sympy.oo, sympy.oo) + """ x = ValueRanges.wrap(x) if -1 <= x.lower and x.upper <= 1: return ValueRanges.increasing_map(x, OpaqueUnaryFn_asinh) return ValueRanges.unknown() + """ @staticmethod def acos(x): + return ValueRanges(-sympy.oo, sympy.oo) + """ x = ValueRanges.wrap(x) if -1 <= x.lower and x.upper <= 1: return ValueRanges.decreasing_map(x, OpaqueUnaryFn_acos) return ValueRanges.unknown() + """ @staticmethod def atan(x): - return ValueRanges.increasing_map(x, OpaqueUnaryFn_atan) + return ValueRanges(-sympy.oo, sympy.oo) + # return ValueRanges.increasing_map(x, OpaqueUnaryFn_atan) @staticmethod def trunc(x): - def trunc(x): - return sympy.Integer(x) if x.is_finite else x - - return ValueRanges.increasing_map(x, trunc) + return ValueRanges.increasing_map(x, TruncToFloat) class ValueRangeAnalysis(SymPyValueRangeAnalysis): @@ -791,9 +913,10 @@ def store(self, name, index, value, mode=None): def reduction(self, name, dtype, src_dtype, reduction_type, index, value): return ValueRanges.unknown() - def index_expr(self, index, dtype): + @classmethod + def index_expr(cls, index, dtype): assert isinstance(index, ValueRanges) - return index + return cls.to_dtype(index, dtype) @staticmethod def to_dtype(x, dtype: torch.dtype, src_dtype: Optional[torch.dtype] = None): @@ -830,12 +953,15 @@ def cast(x, dtype): @staticmethod def square(x): - return ValueRanges.convex_min_zero_map(x, lambda y: y * y) + return ValueRanges.convex_min_zero_map(x, lambda y: PowByNatural(y, 2)) @staticmethod def neg(x): return ValueRanges.decreasing_map(x, operator.neg) + # TODO: this is slightly inaccurate because truncdiv operates at integer + # precision, but we're going through float truediv which means we can + # potentially lose precision on the bounds @classmethod def truncdiv(cls, a, b): x = cls.truediv(a, b) @@ -856,6 +982,7 @@ def __getattr__(self, name): def bound_sympy( expr: sympy.Expr, ranges: Optional[Dict[sympy.Symbol, ValueRanges]] = None ) -> ValueRanges: + log.debug("bound_sympy(%s, %s)", expr, ranges) if isinstance(expr, sympy.Number): return ValueRanges.wrap(expr) From 4c074a9b8bd2e6d8940b40a41ce399e6c4a463a9 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 4 Jun 2024 12:53:19 +0000 Subject: [PATCH 316/706] Revert "[torchbind] always fakify script object by default in non-strict export (#127116)" This reverts commit c27882ffa8c1c7e4cf8ebc6c2f879e5b6c8814ad. Reverted https://github.com/pytorch/pytorch/pull/127116 on behalf of https://github.com/atalman due to Failing internal tests ([comment](https://github.com/pytorch/pytorch/pull/127116#issuecomment-2147459339)) --- test/inductor/test_torchbind.py | 23 ++++++++++++++++++++--- torch/_export/non_strict_utils.py | 2 ++ torch/_functorch/aot_autograd.py | 6 +++++- 3 files changed, 27 insertions(+), 4 deletions(-) diff --git a/test/inductor/test_torchbind.py b/test/inductor/test_torchbind.py index fd8f15d8212c..e1bb0ad36d0b 100644 --- a/test/inductor/test_torchbind.py +++ b/test/inductor/test_torchbind.py @@ -1,4 +1,5 @@ # Owner(s): ["module: functorch"] +import unittest import torch import torch._dynamo @@ -7,14 +8,30 @@ import torch._inductor.decomposition from torch._higher_order_ops.torchbind import enable_torchbind_tracing from torch._inductor.test_case import run_tests, TestCase - -from torch.testing._internal.torchbind_impls import init_torchbind_implementations +from torch.testing._internal.common_utils import ( + find_library_location, + IS_FBCODE, + IS_MACOS, + IS_SANDCASTLE, + IS_WINDOWS, +) class TestTorchbind(TestCase): def setUp(self): super().setUp() - init_torchbind_implementations() + if IS_MACOS: + raise unittest.SkipTest("non-portable load_library call used in test") + elif IS_SANDCASTLE or IS_FBCODE: + torch.ops.load_library( + "//caffe2/test/cpp/jit:test_custom_class_registrations" + ) + elif IS_WINDOWS: + lib_file_path = find_library_location("torchbind_test.dll") + torch.ops.load_library(str(lib_file_path)) + else: + lib_file_path = find_library_location("libtorchbind_test.so") + torch.ops.load_library(str(lib_file_path)) def get_exported_model(self): """ diff --git a/torch/_export/non_strict_utils.py b/torch/_export/non_strict_utils.py index 638a7db7e537..d15cb29f28df 100644 --- a/torch/_export/non_strict_utils.py +++ b/torch/_export/non_strict_utils.py @@ -444,6 +444,8 @@ def _fakify_script_objects( fake_to_real = {} def _maybe_fakify_obj(obj): + if not torch._library.fake_class_registry.has_fake_class(obj._type().qualified_name()): # type: ignore[attr-defined] + return obj fake_obj = torch._library.fake_class_registry.to_fake_obj(fake_mode, obj) fake_to_real[fake_obj] = obj return fake_obj diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index a339b63a3ea8..4dc854781e40 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -493,7 +493,11 @@ def convert(idx, x): return shape_env.create_symintnode( shape_env.create_symbol(x, source), hint=x, source=source ) - if isinstance(x, torch.ScriptObject): + if isinstance( + x, torch.ScriptObject + ) and torch._library.fake_class_registry.has_fake_class( + x._type().qualified_name() + ): return torch._library.fake_class_registry.to_fake_obj(fake_mode, x) if not isinstance(x, torch.Tensor): return x From 059cae617697836e605ede9903a31f8286793bfd Mon Sep 17 00:00:00 2001 From: cyy Date: Tue, 4 Jun 2024 14:22:21 +0000 Subject: [PATCH 317/706] [Caffe2] Remove Caffe2 proto and other files (#127655) Remove Caffe2 proto files altogether. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127655 Approved by: https://github.com/ezyang --- BUILD.bazel | 8 - c2_defs.bzl | 509 ------------ c2_test_defs.bzl | 20 - caffe2/CMakeLists.txt | 56 -- caffe2/core/logging.h | 3 - caffe2/core/scope_guard.h | 158 ---- caffe2/core/types.cc | 80 -- caffe2/core/types.h | 83 -- caffe2/perfkernels/CMakeLists.txt | 3 - caffe2/perfkernels/embedding_lookup.cc | 3 +- caffe2/perfkernels/embedding_lookup_idx.cc | 3 +- .../fused_8bit_rowwise_embedding_lookup.cc | 3 +- ...fused_8bit_rowwise_embedding_lookup_idx.cc | 3 +- caffe2/perfkernels/typed_axpy.cc | 3 +- caffe2/proto/BUILD.bazel | 37 - caffe2/proto/CMakeLists.txt | 19 - caffe2/proto/__init__.py | 23 - caffe2/proto/caffe2.proto | 528 ------------ caffe2/proto/caffe2_pb.h | 135 --- caffe2/proto/caffe2_pb2.pyi | 767 ------------------ caffe2/proto/gen_proto_typestubs.sh | 52 -- caffe2/proto/gen_proto_typestubs_helper.py | 15 - caffe2/proto/torch.proto | 114 --- caffe2/proto/torch_pb2.pyi | 218 ----- caffe2/utils/cpuid.cc | 83 -- caffe2/utils/cpuid.h | 146 ---- caffe2/utils/threadpool/ThreadPool.cc | 1 - caffe2/utils/threadpool/WorkersPool.h | 3 +- caffe2/utils/threadpool/pthreadpool.cc | 3 +- torch/CMakeLists.txt | 1 - torch/csrc/jit/runtime/static/impl.cpp | 1 - 31 files changed, 9 insertions(+), 3072 deletions(-) delete mode 100644 c2_defs.bzl delete mode 100644 c2_test_defs.bzl delete mode 100644 caffe2/core/logging.h delete mode 100644 caffe2/core/scope_guard.h delete mode 100644 caffe2/core/types.cc delete mode 100644 caffe2/core/types.h delete mode 100644 caffe2/proto/BUILD.bazel delete mode 100644 caffe2/proto/CMakeLists.txt delete mode 100644 caffe2/proto/__init__.py delete mode 100644 caffe2/proto/caffe2.proto delete mode 100644 caffe2/proto/caffe2_pb.h delete mode 100644 caffe2/proto/caffe2_pb2.pyi delete mode 100755 caffe2/proto/gen_proto_typestubs.sh delete mode 100644 caffe2/proto/gen_proto_typestubs_helper.py delete mode 100644 caffe2/proto/torch.proto delete mode 100644 caffe2/proto/torch_pb2.pyi delete mode 100644 caffe2/utils/cpuid.cc delete mode 100644 caffe2/utils/cpuid.h diff --git a/BUILD.bazel b/BUILD.bazel index 7a2c3a523dfc..71ebc296598c 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -455,7 +455,6 @@ filegroup( name = "caffe2_core_srcs", srcs = [ "caffe2/core/common.cc", - "caffe2/core/types.cc", ], ) @@ -488,7 +487,6 @@ filegroup( filegroup( name = "caffe2_utils_srcs", srcs = [ - "caffe2/utils/cpuid.cc", "caffe2/utils/proto_wrap.cc", "caffe2/utils/string_utils.cc", "caffe2/utils/threadpool/ThreadPool.cc", @@ -507,12 +505,9 @@ cc_library( name = "caffe2_for_aten_headers", hdrs = [ "caffe2/core/common.h", - "caffe2/core/logging.h", - "caffe2/core/types.h", "caffe2/perfkernels/common.h", "caffe2/perfkernels/embedding_lookup.h", "caffe2/perfkernels/embedding_lookup_idx.h", - "caffe2/utils/cpuid.h", "caffe2/utils/fixed_divisor.h", ] + glob([ "caffe2/utils/threadpool/*.h", @@ -522,7 +517,6 @@ cc_library( deps = [ ":caffe2_core_macros", "//c10", - "//caffe2/proto:caffe2_pb", ], ) @@ -547,7 +541,6 @@ cc_library( deps = [ ":caffe2_core_macros", ":caffe2_for_aten_headers", - "//caffe2/proto:caffe2_pb", ], ) @@ -568,7 +561,6 @@ cc_library( ":caffe2_perfkernels_avx", ":caffe2_perfkernels_avx2", ":caffe2_perfkernels_avx512", - "//caffe2/proto:caffe2_pb", "//third_party/miniz-2.1.0:miniz", "@com_google_protobuf//:protobuf", "@eigen", diff --git a/c2_defs.bzl b/c2_defs.bzl deleted file mode 100644 index 3cca448b394c..000000000000 --- a/c2_defs.bzl +++ /dev/null @@ -1,509 +0,0 @@ -load("@bazel_skylib//lib:collections.bzl", "collections") -load("@bazel_skylib//lib:paths.bzl", "paths") -load("@fbcode_macros//build_defs:native_rules.bzl", "buck_genrule") -load("@fbsource//tools/build_defs:default_platform_defs.bzl", "compose_platform_setting_list") -load("@fbsource//tools/build_defs:dict_defs.bzl", "dict_defs") -load("@fbsource//tools/build_defs:expect.bzl", "expect") -load("@fbsource//tools/build_defs:fb_xplat_cxx_library.bzl", "fb_xplat_cxx_library") -load("@fbsource//tools/build_defs:fbsource_utils.bzl", "is_arvr_mode", "is_fbcode_mode_mac") -load("@fbsource//tools/build_defs:platform_defs.bzl", "ANDROID", "APPLE", "CXX", "IOS", "MACOSX", "WINDOWS") -load("@fbsource//tools/build_defs/apple:build_mode_defs.bzl", "is_production_build") -load("@fbsource//xplat/caffe2:buckbuild.bzl", "read_bool") -load("@fbsource//xplat/pfh/Msgr/Mobile/ProductInfra:DEFS.bzl", "Msgr_Mobile_ProductInfra") - -def get_c2_expose_op_to_c10(): - c2_op_to_c10 = native.read_config("caffe2", "expose_op_to_c10", "0") - - expect( - c2_op_to_c10 in ("0", "1"), - c2_op_to_c10, - ) - - return bool(int(c2_op_to_c10)) - -def get_c2_mpscnn(): - c2_mpscnn = native.read_config("caffe2", "enable_mpscnn", "1") - - expect( - c2_mpscnn in ("0", "1"), - c2_mpscnn, - ) - - return bool(int(c2_mpscnn)) - -def get_c2_mpscnn_test(): - c2_mpscnn_test = native.read_config("caffe2", "enable_mpscnn_test", "0") - - expect( - c2_mpscnn_test in ("0", "1"), - c2_mpscnn_test, - ) - - return bool(int(c2_mpscnn_test)) - -def get_c2_qpl(): - c2_qpl = native.read_config("caffe2", "enable_qpl", "1") - - expect( - c2_qpl in ("0", "1"), - c2_qpl, - ) - - return bool(int(c2_qpl)) - -def get_c2_strip_debug_info(): - c2_strip_debug_info = native.read_config("caffe2", "strip_debug_info", "0") - - expect( - c2_strip_debug_info in ("0", "1"), - c2_strip_debug_info, - ) - - return bool(int(c2_strip_debug_info)) - -def get_c2_strip_glog(): - c2_strip_glog = native.read_config("caffe2", "strip_glog", "1") - - expect( - c2_strip_glog in ("0", "1"), - c2_strip_glog, - ) - - return bool(int(c2_strip_glog)) - -def get_c2_tvm(): - c2_tvm = native.read_config("caffe2", "enable_tvm", "1") - - expect( - c2_tvm in ("0", "1"), - c2_tvm, - ) - - return bool(int(c2_tvm)) - -_C2_XPLAT_NO_HPTT_PREPROCESSOR_FLAGS = [ - "-Icaffe2", - "-Imodules", - "-DEIGEN_NO_DEBUG", - "-DCAFFE2_USE_LITE_PROTO", - "-DCAFFE2_USE_GOOGLE_GLOG", - "-DCAFFE2_RNN_NO_TEXT_FORMAT", - "-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK=1", - "-DCAFFE2_IS_XPLAT_BUILD", - "-DSTRIP_ERROR_MESSAGES", - "-DUSE_INTERNAL_PTHREADPOOL_IMPL", -] - -def get_c2_xplat_no_hptt_preprocessor_flags(): - flags = [] - flags += _C2_XPLAT_NO_HPTT_PREPROCESSOR_FLAGS - if is_arvr_mode() and get_c2_strip_glog(): - flags += ["-UGOOGLE_STRIP_LOG", "-DGOOGLE_STRIP_LOG=1"] - if get_c2_expose_op_to_c10(): - flags += ["-DEXPOSE_C2_OPS", "-frtti"] - return flags - -C2_XPLAT_SERVER_PREPROCESSOR_FLAGS = [ - "-DCAFFE2_USE_EIGEN_FOR_BLAS", - "-DC10_DISABLE_SIGNAL_HANDLERS", - "-DCAFFE2_DISABLE_NUMA", -] - -C2_XPLAT_HPTT_PREPROCESSOR_FLAGS = [ - "-DCAFFE2_USE_HPTT", -] - -def get_c2_xplat_preprocessor_flags(): - flags = get_c2_xplat_no_hptt_preprocessor_flags() + C2_XPLAT_HPTT_PREPROCESSOR_FLAGS - return flags - -def get_c2_xplat_no_hptt_compiler_flags(): - return [ - "-Os", - "-fexceptions", - "-frtti", - "-Wno-shadow", - "-Wno-unknown-pragmas", - "-Wno-unused-variable", - "-Wno-sign-compare", - ] - -def get_c2_xplat_compiler_flags(): - return get_c2_xplat_no_hptt_compiler_flags() + C2_XPLAT_HPTT_PREPROCESSOR_FLAGS - -def get_c2_fbobjc_xplat_compiler_flags(): - flags = [] - - if is_production_build(): - flags.append("-DCAFFE2_NO_OPERATOR_SCHEMA") - - flags.append("-DCAFFE2_NO_GRADIENT_OPS") - - # For iOS production builds (and all Android builds), strip GLOG logging to - # save size. We can disable by setting caffe2.strip_glog=0 in .buckconfig.local. - if is_production_build() or get_c2_strip_glog(): - flags += ["-UGOOGLE_STRIP_LOG", "-DGOOGLE_STRIP_LOG=3"] - else: - flags.append("-UGOOGLE_STRIP_LOG") - - return flags - -def get_c2_fbandroid_xplat_compiler_flags(): - flags = [ - "-Wno-unused-but-set-variable", - "-DHAVE_MMAP", - ] - - if get_c2_strip_glog(): - flags += ["-UGOOGLE_STRIP_LOG", "-DGOOGLE_STRIP_LOG=1"] - - if get_c2_strip_debug_info(): - flags.append("-g0") - - return flags - -_C2_FBOBJC_COMPILER_FLAGS = [ - "-Wno-missing-prototypes", - "-Wno-global-constructors", - "-Wno-unknown-pragmas", - "-Wno-invalid-partial-specialization", - "-Wno-missing-braces", - "-Wno-range-loop-analysis", -] - -def get_c2_fbobjc_compiler_flags(): - flags = list(_C2_FBOBJC_COMPILER_FLAGS) - - # Avoid linking Accelerate on MacOS because we have - # inconsistent LAPACK headers (see problems in D19257077). - flags.append("-DCAFFE2_USE_ACCELERATE" if not is_arvr_mode() else "-DCAFFE2_USE_EIGEN_FOR_BLAS") - if get_c2_mpscnn(): - flags.append( - # TODO(t19120552) - fix this. MPSCNNConvolutionDescriptor.strideInPixelsX - # is marked as iOS 11+, but it's been available since iOS 10. - "-Wno-unguarded-availability", - ) - return flags - -C2_FBOBJC_MACOSX_COMPILER_FLAGS = [ - "-msse4.2", -] - -C2_FBOBJC_IPHONE_COMPILER_FLAGS = [ - "-mfpu=neon-fp16", -] - -def get_c2_fbobjc_frameworks(): - frameworks = [] - if not is_arvr_mode(): - frameworks.append( - # On iOS, presumably Accelerate is a faster BLAS - "$SDKROOT/System/Library/Frameworks/Accelerate.framework", - ) - return frameworks - -def get_c2_fbobjc_ios_frameworks(): - frameworks = [] - - if get_c2_mpscnn(): - frameworks.extend([ - "$SDKROOT/System/Library/Frameworks/Metal.framework", - "$SDKROOT/System/Library/Frameworks/MetalPerformanceShaders.framework", - ]) - - return frameworks - -def get_c2_fbobjc_exported_preprocessor_flags(): - flags = [] - - if get_c2_mpscnn(): - flags.append("-DCAFFE2_USE_MPSCNN") - - if get_c2_mpscnn_test(): - flags.append("-DCAFFE2_USE_MPSCNN_TEST") - - return flags - -def get_c2_fbandroid_exported_preprocessor_flags(): - flags = [] - - BUILD_MODE_DO_NOT_USE_WITHOUT_ASKING_SERIOUSLY = native.read_config( - "fbandroid", - "build_mode", - "dev", - ) - if BUILD_MODE_DO_NOT_USE_WITHOUT_ASKING_SERIOUSLY == "opt": - flags.append("-DCAFFE2_NO_OPERATOR_SCHEMA") - - flags.append("-DCAFFE2_NO_GRADIENT_OPS") - - return flags - -C2_FBANDROID_COMPILER_FLAGS = [ - "-DCAFFE2_USE_EIGEN_FOR_BLAS", - "-Wno-unknown-pragmas", - "-Wno-deprecated-declarations", - "-Wno-invalid-partial-specialization", - "-Wno-missing-braces", -] - -C2_FBANDROID_ARMV7_COMPILER_FLAGS = [ - "-mfpu=neon-fp16", -] - -C2_FBANDROID_X86_COMPILER_FLAGS = [ - "-mssse3", -] - -C2_FBANDROID_LINKER_FLAGS = [] - -C2_FBOBJC_EXTRA_TARGET_CONFIG = { - "MTL_LANGUAGE_REVISION": "Metal12", -} - -def get_c2_torch_vulkan_compiler_flags(): - return ["-Wno-missing-prototypes"] - -def get_c2_default_cxx_args(): - return dict( - header_namespace = "", - apple_sdks = (IOS, MACOSX), - compiler_flags = get_c2_xplat_compiler_flags(), - fbandroid_compiler_flags = C2_FBANDROID_COMPILER_FLAGS + get_c2_fbandroid_xplat_compiler_flags(), - fbandroid_exported_platform_preprocessor_flags = [ - ( - "android-armv7", - get_c2_fbandroid_exported_preprocessor_flags(), - ), - ], - fbandroid_linker_flags = C2_FBANDROID_LINKER_FLAGS, - fbandroid_platform_compiler_flags = [ - ("android-armv7", C2_FBANDROID_ARMV7_COMPILER_FLAGS), - (".*x86.*", C2_FBANDROID_X86_COMPILER_FLAGS), - ], - fbobjc_compiler_flags = get_c2_fbobjc_compiler_flags() + get_c2_fbobjc_xplat_compiler_flags(), - fbobjc_exported_platform_preprocessor_flags = [ - ( - "iphoneos", - get_c2_fbobjc_exported_preprocessor_flags(), - ), - ], - fbobjc_frameworks = get_c2_fbobjc_frameworks() + get_c2_fbobjc_ios_frameworks(), - fbobjc_platform_compiler_flags = [ - ("iphoneos", C2_FBOBJC_IPHONE_COMPILER_FLAGS), - ], - macosx_compiler_flags = C2_FBOBJC_MACOSX_COMPILER_FLAGS, - macosx_frameworks_override = get_c2_fbobjc_frameworks(), - preprocessor_flags = [ - # Use the internal pthreadpool impl for all Caffe2 targets on all - # platforms but do not export the preprocessor flag downstream. - "-DUSE_INTERNAL_PTHREADPOOL_IMPL", - ], - visibility = ["PUBLIC"], - windows_preferred_linkage = "static" if is_arvr_mode() else None, - ) - -def get_c2_aten_cpu_fbobjc_macosx_deps(): - return select({ - "DEFAULT": [], - "ovr_config//os:macos-x86_64": ["fbsource//xplat/deeplearning/fbgemm:fbgemm"], - }) if is_arvr_mode() else [] - -def build_cpukernel_avx2(): - return read_bool("caffe2", "build_cpukernel_avx2", not is_arvr_mode()) - -def get_c2_aten_cpu_fbobjc_macosx_platform_deps(): - return compose_platform_setting_list([ - { - "cpu": "x86_64", - "flags": [ - "fbsource//xplat/deeplearning/fbgemm:fbgemmAppleMac", - ] + ([ - "fbsource//xplat/caffe2:cpukernel_avx2AppleMac", - ] if build_cpukernel_avx2() else []), - "os": "macosx", - }, - { - "cpu": "arm64", - "flags": ["fbsource//xplat/third-party/XNNPACK:XNNPACKAppleMac"], - "os": "macosx", - }, - ]) - -def using_protobuf_v3(): - # Consider migrating this to `read_config("protobuf", "use_v3")` - # The `is_fbcode_mode_mac()` clause was added rather than changing to `read_config` to minimize changes in behavior - return is_arvr_mode() or is_fbcode_mode_mac() - -def get_c2_protobuf_dep(): - return "fbsource//third-party/protobuf:libprotobuf" if using_protobuf_v3() else "fbsource//xplat/third-party/protobuf:fb-protobuf-lite" - -def c2_cxx_library(fbobjc_compiler_flags = [], **kwargs): - args = get_c2_default_cxx_args() - args.update(kwargs) - args.setdefault("platforms", (ANDROID, APPLE, CXX, WINDOWS)) - - # Make sure we don't overwrite custom `fbobjc_compiler_flags` - args["fbobjc_compiler_flags"] = args.pop("fbobjc_compiler_flags", []) + fbobjc_compiler_flags - - fb_xplat_cxx_library( - labels = [ - "supermodule:android/default/caffe2", - "supermodule:ios/default/public.caffe2", - ], - feature = Msgr_Mobile_ProductInfra, - **args - ) - -def c2_protobuf_rule(protos): - cpps = [] - headers = {} - raw_headers = {} - for p in protos: - proto = paths.basename(p) - protocexe = "$(exe fbsource//third-party/protobuf:protoc-host)" if is_arvr_mode() else "$(location fbsource//xplat/third-party/protobuf:protoc.Windows)" - protocmd_exe = "powershell.exe -file $(location fbsource//xplat/caffe2/scripts:proto)\\proto.ps1 -Protoc {} -Unprocessed $SRCDIR/{} -Processed $SRCDIR/{} -out $OUT -srcdir $SRCDIR".format(protocexe, p, proto) - protocmd = ("cp $SRCDIR/{} $SRCDIR/{} && chmod +w $SRCDIR/{} && echo \"option optimize_for = LITE_RUNTIME;\" >> $SRCDIR/{} && ".format(p, proto, proto, proto) + - "cp $SRCDIR/caffe2/proto/caffe2.proto $SRCDIR/caffe2.proto && chmod +w $SRCDIR/caffe2.proto && echo \"option optimize_for = LITE_RUNTIME;\" >> $SRCDIR/caffe2.proto && " + - "sed -i -e 's/caffe2\\/proto\\/caffe2.proto/caffe2.proto/g' $SRCDIR/{} && ".format(proto) + - ("$(exe fbsource//third-party/protobuf:protoc-host) " if using_protobuf_v3() else "$(exe fbsource//xplat/third-party/protobuf:protoc) --osx $(location fbsource//xplat/third-party/protobuf:protoc.Darwin) --linux $(location fbsource//xplat/third-party/protobuf:protoc.Linux) ") + - "-I $SRCDIR --cpp_out=$OUT $SRCDIR/{}".format(proto)) - buck_genrule( - name = proto, - srcs = sorted(collections.uniq([p, "caffe2/proto/caffe2.proto"])), - cmd_exe = protocmd_exe, - bash = protocmd, - out = ".", - ) - (name, _) = paths.split_extension(proto) - cpp = name + ".pb.cc" - h = name + ".pb.h" - buck_genrule( - name = h, - cmd_exe = "@powershell -Command \" & { " + "(Get-Content $(location :{})\\{}".format(proto, h) + ") -replace \\\"caffe2.pb.h\\\", \\\"caffe2/proto/caffe2.pb.h\\\" | Set-Content $OUT } \"", - bash = "cp -f $(location :{})/{} $OUT && ".format(proto, h) + - "sed -i -e 's/caffe2.pb.h/caffe2\\/proto\\/caffe2.pb.h/g' $OUT", - out = h, - ) - headers["caffe2/proto/" + h] = ":{}".format(h) - raw_headers[h] = ":{}".format(h) - buck_genrule( - name = cpp, - cmd_exe = "@powershell -Command copy $(location :{})/{} $OUT".format(proto, cpp), - bash = "cp -f $(location :{})/{} $OUT".format(proto, cpp), - out = cpp, - ) - cpps.append(":{}".format(cpp)) - return (cpps, headers, raw_headers) - -# C2 uses lite version of protobuf while torch/jit uses some method only exists -# in full protobuf. This is a temporary workaround to enable experiment build. -# DO NOT USE IT IN PRODUCTION BUILD! -def c2_full_protobuf_rule(protos): - prefix = "full_" - cpps = [] - headers = {} - raw_headers = {} - for p in protos: - proto = paths.basename(p) - protocexe = "$(exe fbsource//third-party/protobuf:protoc-host)" if is_arvr_mode() else "$(location fbsource//xplat/third-party/protobuf:protoc.Windows)" - protocmd_exe = "powershell.exe -file $(location fbsource//xplat/caffe2/scripts:proto)\\proto.ps1 -Protoc {} -Unprocessed $SRCDIR/{} -Processed $SRCDIR/{} -out $OUT -srcdir $SRCDIR".format(protocexe, p, proto) - protocmd = ("cp $SRCDIR/{} $SRCDIR/{} && ".format(p, proto) + - "cp $SRCDIR/caffe2/proto/caffe2.proto $SRCDIR/caffe2.proto && " + - "sed -i -e 's/caffe2\\/proto\\/caffe2.proto/caffe2.proto/g' $SRCDIR/{} && ".format(proto) + - ("$(exe fbsource//third-party/protobuf:protoc-host) " if using_protobuf_v3() else "$(exe fbsource//xplat/third-party/protobuf:protoc) --osx $(location fbsource//xplat/third-party/protobuf:protoc.Darwin) --linux $(location fbsource//xplat/third-party/protobuf:protoc.Linux) ") + - "-I $SRCDIR --cpp_out=$OUT $SRCDIR/{}".format(proto)) - buck_genrule( - name = prefix + proto, - srcs = sorted(collections.uniq([p, "caffe2/proto/caffe2.proto"])), - cmd = protocmd, - cmd_exe = protocmd_exe, - out = ".", - ) - (name, _) = paths.split_extension(proto) - cpp = name + ".pb.cc" - h = name + ".pb.h" - buck_genrule( - name = prefix + h, - cmd_exe = "@powershell -Command \" & { " + "(Get-Content $(location :{})\\{}".format(prefix + proto, h) + ") -replace \\\"caffe2.pb.h\\\", \\\"caffe2/proto/caffe2.pb.h\\\" | Set-Content $OUT } \"", - bash = "cp -f $(location :{})/{} $OUT && ".format(prefix + proto, h) + - "sed -i -e 's/caffe2.pb.h/caffe2\\/proto\\/caffe2.pb.h/g' $OUT", - out = h, - ) - headers["caffe2/proto/" + h] = ":{}".format(prefix + h) - raw_headers[h] = ":{}".format(prefix + h) - buck_genrule( - name = prefix + cpp, - cmd_exe = "@powershell -Command copy $(location :{})/{} $OUT".format(prefix + proto, cpp), - bash = "cp -f $(location :{})/{} $OUT".format(prefix + proto, cpp), - out = cpp, - ) - cpps.append(":{}".format(prefix + cpp)) - return (cpps, headers, raw_headers) - -def libcaffe2_cxx_library(name, use_hptt, **kwargs): - c2_cxx_library( - name = name, - exported_deps = [ - "fbsource//xplat/caffe2/c10:c10", - get_c2_protobuf_dep(), - ":caffe2_protobuf_headers", - ":pthreadpool", - ":common_core", - ":caffe2_proto_types", - ], - compiler_flags = get_c2_xplat_compiler_flags() if use_hptt else get_c2_xplat_no_hptt_compiler_flags(), - exported_preprocessor_flags = get_c2_xplat_preprocessor_flags() if use_hptt else get_c2_xplat_no_hptt_preprocessor_flags(), - cxx_preprocessor_flags = C2_XPLAT_SERVER_PREPROCESSOR_FLAGS, - fbandroid_exported_preprocessor_flags = get_c2_fbandroid_xplat_compiler_flags(), - fbobjc_exported_preprocessor_flags = get_c2_fbobjc_xplat_compiler_flags(), - # Hack to work around lack of platform_srcs support in Xcode project generation. - macosx_extra_xcode_sources_override = [], - link_whole = True, - **kwargs - ) - -def c2_operator_library(name, **kwargs): - dict_defs.key_extend( - kwargs, - "deps", - [ - "fbsource//xplat/folly:molly", - "fbsource//third-party/glog:glog", - ":caffe2", - ] + ([":aten_cpu"] if get_c2_expose_op_to_c10() else []), - ) - - # NOTE: Currently operators can "depend" on other operators, which is used - # so that loading one will implicitly load the dependencies. So, make sure - # that no `--as-needed` flags pulled in from dependencies cause these - # operator deps to get dropped. - linker_flags = [] if (read_config("caffe2", "link_as_needed", "0") == "1") else ["-Wl,--no-as-needed"] - c2_cxx_library( - name = name, - soname = "lib" + name + ".$(ext)", - fbandroid_compiler_flags = get_c2_default_cxx_args()["fbandroid_compiler_flags"] + ["-Os"], - fbobjc_compiler_flags = get_c2_default_cxx_args()["fbobjc_compiler_flags"] + ["-Oz", "-DCOMPILING_FOR_MIN_SIZE=1"], - link_whole = True, - cxx_exported_linker_flags = linker_flags, - fbandroid_exported_linker_flags = linker_flags, - exported_deps = [ - ":caffe2", - ], - **kwargs - ) - -def c2_genrule(genrule, genfiles, prefix = "", src_path = "", header_namespace = ""): - headers = {} - srcs = [] - for generated_filename in genfiles: - buck_genrule( - name = prefix + generated_filename, - bash = "cp -f $(location :{})/{} $OUT".format(genrule, src_path + generated_filename), - cmd_exe = "@powershell -Command copy $(location :{})/{} $OUT".format(genrule, src_path + generated_filename), - out = generated_filename, - ) - rule = ":{}{}".format(prefix, generated_filename) - headers[header_namespace + generated_filename] = rule - srcs.append(rule) - return {"headers": headers, "srcs": srcs} diff --git a/c2_test_defs.bzl b/c2_test_defs.bzl deleted file mode 100644 index 8ef83073d6fa..000000000000 --- a/c2_test_defs.bzl +++ /dev/null @@ -1,20 +0,0 @@ -load("@fbsource//tools/build_defs:fb_xplat_cxx_test.bzl", "fb_xplat_cxx_test") -load("@fbsource//tools/build_defs:platform_defs.bzl", "ANDROID", "APPLE", "CXX", "IOS", "MACOSX") -load("@fbsource//xplat/caffe2:c2_defs.bzl", "get_c2_default_cxx_args") - -def c2_cxx_test(**kwargs): - args = get_c2_default_cxx_args() - args.update(kwargs) - args["fbandroid_use_instrumentation_test"] = True - for flag in [ - "macosx_compiler_flags", - "fbobjc_macosx_configs_override", - "macosx_frameworks_override", - "xcode_public_headers_symlinks", - "macosx_inherited_buck_flags_override", - ]: - args.pop(flag, None) - args["apple_sdks"] = (IOS, MACOSX) - args["platforms"] = (CXX, APPLE, ANDROID) - args["contacts"] = ["oncall+ai_infra_mobile_platform@xmail.facebook.com"] - fb_xplat_cxx_test(**args) diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index f7de195a6272..e9b2b20ce6ad 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -104,10 +104,6 @@ if(NOT USE_FBGEMM) add_subdirectory(perfkernels) endif() -if(NOT INTERN_BUILD_MOBILE) - add_subdirectory(proto) -endif() - # Advanced: if we have allow list specified, we will do intersections for all # main lib srcs. if(CAFFE2_ALLOWLISTED_FILES) @@ -218,44 +214,6 @@ if(PRINT_CMAKE_DEBUG_INFO) endif() -if(NOT INTERN_BUILD_MOBILE) - # ---[ List of libraries to link with - add_library(caffe2_protos STATIC $) - add_dependencies(caffe2_protos Caffe2_PROTO) - # If we are going to link protobuf locally inside caffe2 libraries, what we will do is - # to create a helper static library that always contains libprotobuf source files, and - # link the caffe2 related dependent libraries to it. - target_include_directories(caffe2_protos INTERFACE $) - # Reason for this public dependency is as follows: - # (1) Strictly speaking, we should not expose any Protobuf related functions. We should - # only use function interfaces wrapped with our own public API, and link protobuf - # locally. - # (2) However, currently across the Caffe2 codebase, we have extensive use of protobuf - # functionalities. For example, not only libcaffe2.so uses it, but also other - # binaries such as python extensions etc. As a result, we will have to have a - # transitive dependency to libprotobuf. - # - # Good thing is that, if we specify CAFFE2_LINK_LOCAL_PROTOBUF, then we do not need to - # separately deploy protobuf binaries - libcaffe2.so will contain all functionalities - # one needs. One can verify this via ldd. - # - # TODO item in the future includes: - # (1) Enable using lite protobuf - # (2) Properly define public API that do not directly depend on protobuf itself. - # (3) Expose the libprotobuf.a file for dependent libraries to link to. - # - # What it means for users/developers? - # (1) Users: nothing affecting the users, other than the fact that CAFFE2_LINK_LOCAL_PROTOBUF - # avoids the need to deploy protobuf. - # (2) Developers: if one simply uses core caffe2 functionality without using protobuf, - # nothing changes. If one has a dependent library that uses protobuf, then one needs to - # have the right protobuf version as well as linking to libprotobuf.a. - target_link_libraries(caffe2_protos PUBLIC protobuf::libprotobuf) - if(NOT BUILD_SHARED_LIBS) - install(TARGETS caffe2_protos ARCHIVE DESTINATION "${CMAKE_INSTALL_LIBDIR}") - endif() -endif() - # ========================================================== # formerly-libtorch # ========================================================== @@ -1117,13 +1075,6 @@ endif() # formerly-libtorch flags # ========================================================== -if(NOT INTERN_BUILD_MOBILE) - # Forces caffe2.pb.h to be generated before its dependents are compiled. - # Adding the generated header file to the ${TORCH_SRCS} list is not sufficient - # to establish the dependency, since the generation procedure is declared in a different CMake file. - # See https://samthursfield.wordpress.com/2015/11/21/cmake-dependencies-between-targets-and-files-and-custom-commands/#custom-commands-in-different-directories - add_dependencies(torch_cpu Caffe2_PROTO) -endif() # Build model tracer for tracing-based selective build if(TRACING_BASED AND NOT BUILD_LITE_INTERPRETER AND NOT INTERN_BUILD_MOBILE) @@ -1413,8 +1364,6 @@ if(USE_DISTRIBUTED) endif() if(NOT INTERN_BUILD_MOBILE) - caffe2_interface_library(caffe2_protos caffe2_protos_whole) - target_link_libraries(torch_cpu PRIVATE caffe2_protos_whole) if(${CAFFE2_LINK_LOCAL_PROTOBUF}) target_link_libraries(torch_cpu INTERFACE protobuf::libprotobuf) else() @@ -1981,11 +1930,6 @@ if(BUILD_PYTHON) set_source_files_properties(${TORCH_SRC_DIR}/../caffe2/operators/box_with_nms_limit_op.cc PROPERTIES COMPILE_FLAGS -Wno-attributes) endif() - # generated pb files are copied from build/caffe2 to caffe2 - # if we copied them back to build this would create a build cycle - # consider removing the need for globs - filter_list_exclude(PYTHON_SRCS PYTHON_SRCS "proto/.*_pb") - set(build_files) foreach(python_src ${PYTHON_SRCS}) add_custom_command(OUTPUT ${CMAKE_BINARY_DIR}/${python_src} diff --git a/caffe2/core/logging.h b/caffe2/core/logging.h deleted file mode 100644 index f47c0581b855..000000000000 --- a/caffe2/core/logging.h +++ /dev/null @@ -1,3 +0,0 @@ -#pragma once -#include "c10/util/Logging.h" -#include "caffe2/core/common.h" diff --git a/caffe2/core/scope_guard.h b/caffe2/core/scope_guard.h deleted file mode 100644 index ee412a424de4..000000000000 --- a/caffe2/core/scope_guard.h +++ /dev/null @@ -1,158 +0,0 @@ -/** - * Copyright 2016 Facebook - * @author Tudor Bosman (tudorb@fb.com) - */ - -#pragma once - -#include -#include -#include -#include -#include - -namespace caffe2 { - -// Copied from folly/ScopeGuard.h - -namespace detail { - -class ScopeGuardImplBase { - public: - void dismiss() noexcept { - dismissed_ = true; - } - - protected: - ScopeGuardImplBase() noexcept : dismissed_(false) {} - - static ScopeGuardImplBase makeEmptyScopeGuard() noexcept { - return ScopeGuardImplBase{}; - } - - template - static const T& asConst(const T& t) noexcept { - return t; - } - - bool dismissed_; -}; - -template -class ScopeGuardImpl : public ScopeGuardImplBase { - public: - explicit ScopeGuardImpl(FunctionType& fn) noexcept( - std::is_nothrow_copy_constructible::value) - : ScopeGuardImpl( - asConst(fn), - makeFailsafe(std::is_nothrow_copy_constructible{}, - &fn)) {} - - explicit ScopeGuardImpl(const FunctionType& fn) noexcept( - std::is_nothrow_copy_constructible::value) - : ScopeGuardImpl( - fn, - makeFailsafe(std::is_nothrow_copy_constructible{}, - &fn)) {} - - explicit ScopeGuardImpl(FunctionType&& fn) noexcept( - std::is_nothrow_move_constructible::value) - : ScopeGuardImpl( - std::move_if_noexcept(fn), - makeFailsafe(std::is_nothrow_move_constructible{}, - &fn)) {} - - ScopeGuardImpl(ScopeGuardImpl&& other) noexcept( - std::is_nothrow_move_constructible::value) - : function_(std::move_if_noexcept(other.function_)) { - // If the above line attempts a copy and the copy throws, other is - // left owning the cleanup action and will execute it (or not) depending - // on the value of other.dismissed_. The following lines only execute - // if the move/copy succeeded, in which case *this assumes ownership of - // the cleanup action and dismisses other. - dismissed_ = other.dismissed_; - other.dismissed_ = true; - } - - ~ScopeGuardImpl() noexcept { - if (!dismissed_) { - execute(); - } - } - - private: - static ScopeGuardImplBase makeFailsafe(std::true_type, const void*) noexcept { - return makeEmptyScopeGuard(); - } - - template - static auto makeFailsafe(std::false_type, Fn* fn) noexcept - -> ScopeGuardImpl { - return ScopeGuardImpl{std::ref(*fn)}; - } - - template - explicit ScopeGuardImpl(Fn&& fn, ScopeGuardImplBase&& failsafe) - : ScopeGuardImplBase{}, function_(std::forward(fn)) { - failsafe.dismiss(); - } - - void* operator new(std::size_t) = delete; - - void execute() noexcept { function_(); } - - FunctionType function_; -}; - -template -using ScopeGuardImplDecay = ScopeGuardImpl::type>; - -} // namespace detail - -/** - * ScopeGuard is a general implementation of the "Initialization is - * Resource Acquisition" idiom. Basically, it guarantees that a function - * is executed upon leaving the current scope unless otherwise told. - * - * The MakeGuard() function is used to create a new ScopeGuard object. - * It can be instantiated with a lambda function, a std::function, - * a functor, or a void(*)() function pointer. - * - * - * Usage example: Add a friend to memory iff it is also added to the db. - * - * void User::addFriend(User& newFriend) { - * // add the friend to memory - * friends_.push_back(&newFriend); - * - * // If the db insertion that follows fails, we should - * // remove it from memory. - * auto guard = MakeGuard([&] { friends_.pop_back(); }); - * - * // this will throw an exception upon error, which - * // makes the ScopeGuard execute UserCont::pop_back() - * // once the Guard's destructor is called. - * db_->addFriend(GetName(), newFriend.GetName()); - * - * // an exception was not thrown, so don't execute - * // the Guard. - * guard.dismiss(); - * } - * - * Examine ScopeGuardTest.cpp for some more sample usage. - * - * Stolen from: - * Andrei's and Petru Marginean's CUJ article: - * http://drdobbs.com/184403758 - * and the loki library: - * http://loki-lib.sourceforge.net/index.php?n=Idioms.ScopeGuardPointer - * and triendl.kj article: - * http://www.codeproject.com/KB/cpp/scope_guard.aspx - */ -template -detail::ScopeGuardImplDecay MakeGuard(F&& f) noexcept( - noexcept(detail::ScopeGuardImplDecay(static_cast(f)))) { - return detail::ScopeGuardImplDecay(static_cast(f)); -} - -} // namespaces diff --git a/caffe2/core/types.cc b/caffe2/core/types.cc deleted file mode 100644 index dfba94ad06ae..000000000000 --- a/caffe2/core/types.cc +++ /dev/null @@ -1,80 +0,0 @@ -#include "caffe2/core/types.h" -#include - -#include -#include -#include - -namespace caffe2 { - -TensorProto::DataType TypeMetaToDataType(const TypeMeta& meta) { - static_assert( - sizeof(int) == 4, "int in this compiler does not equal to 4 bytes."); - - // Can't use a switch because `meta_id` is not an integer type - const auto meta_id = meta.id(); - if (meta_id == TypeMeta::Id()) { - return TensorProto_DataType_FLOAT; - } else if (meta_id == TypeMeta::Id()) { - return TensorProto_DataType_INT32; - } else if (meta_id == TypeMeta::Id()) { - return TensorProto_DataType_STRING; - } else if (meta_id == TypeMeta::Id()) { - return TensorProto_DataType_BOOL; - } else if (meta_id == TypeMeta::Id()) { - return TensorProto_DataType_UINT8; - } else if (meta_id == TypeMeta::Id()) { - return TensorProto_DataType_INT8; - } else if (meta_id == TypeMeta::Id()) { - return TensorProto_DataType_UINT16; - } else if (meta_id == TypeMeta::Id()) { - return TensorProto_DataType_INT16; - } else if (meta_id == TypeMeta::Id()) { - return TensorProto_DataType_INT64; - } else if (meta_id == TypeMeta::Id()) { - return TensorProto_DataType_FLOAT16; - } else if (meta_id == TypeMeta::Id()) { - return TensorProto_DataType_DOUBLE; - } else if (meta_id == TypeMeta::Id()) { - return TensorProto_DataType_INT8; - } else if (meta_id == TypeMeta::Id()) { - return TensorProto_DataType_UINT8; - } else if (meta_id == TypeMeta::Id()) { - return TensorProto_DataType_INT32; - } else { - return TensorProto_DataType_UNDEFINED; - } -} - -const TypeMeta DataTypeToTypeMeta(const TensorProto_DataType& dt) { - switch (dt) { - case TensorProto_DataType_FLOAT: - return TypeMeta::Make(); - case TensorProto_DataType_INT32: - return TypeMeta::Make(); - case TensorProto_DataType_BYTE: - return TypeMeta::Make(); - case TensorProto_DataType_STRING: - return TypeMeta::Make(); - case TensorProto_DataType_BOOL: - return TypeMeta::Make(); - case TensorProto_DataType_UINT8: - return TypeMeta::Make(); - case TensorProto_DataType_INT8: - return TypeMeta::Make(); - case TensorProto_DataType_UINT16: - return TypeMeta::Make(); - case TensorProto_DataType_INT16: - return TypeMeta::Make(); - case TensorProto_DataType_INT64: - return TypeMeta::Make(); - case TensorProto_DataType_FLOAT16: - return TypeMeta::Make(); - case TensorProto_DataType_DOUBLE: - return TypeMeta::Make(); - default: - throw std::runtime_error("Unknown data type."); - }; -} - -} // namespace caffe2 diff --git a/caffe2/core/types.h b/caffe2/core/types.h deleted file mode 100644 index f83a58910e66..000000000000 --- a/caffe2/core/types.h +++ /dev/null @@ -1,83 +0,0 @@ -#ifndef CAFFE2_CORE_TYPES_H_ -#define CAFFE2_CORE_TYPES_H_ - -#include -#include -#include - -#include "caffe2/core/common.h" -#include "caffe2/core/logging.h" -#include -#include "caffe2/proto/caffe2_pb.h" -#include - -namespace caffe2 { - -// Storage orders that are often used in the image applications. -enum StorageOrder { - UNKNOWN = 0, - NHWC = 1, - NCHW = 2, -}; - -inline StorageOrder StringToStorageOrder(const string& str) { - if (str == "NHWC" || str == "nhwc") { - return StorageOrder::NHWC; - } else if (str == "NCHW" || str == "nchw") { - return StorageOrder::NCHW; - } else { - LOG(ERROR) << "Unknown storage order string: " << str; - return StorageOrder::UNKNOWN; - } -} - -inline int32_t GetDimFromOrderString(const std::string& str) { - auto order = StringToStorageOrder(str); - switch (order) { - case StorageOrder::NHWC: - return 3; - case StorageOrder::NCHW: - return 1; - default: - CAFFE_THROW("Unsupported storage order: ", str); - return -1; - } -} - -inline constexpr char NameScopeSeparator() { return '/'; } - -// From TypeMeta to caffe2::DataType protobuffer enum. -TORCH_API TensorProto::DataType TypeMetaToDataType(const TypeMeta& meta); - -// From caffe2::DataType protobuffer enum to TypeMeta -TORCH_API const TypeMeta DataTypeToTypeMeta(const TensorProto::DataType& dt); - -} // namespace caffe2 - -/////////////////////////////////////////////////////////////////////////////// -// at::Half is defined in c10/util/Half.h. Currently half float operators are -// mainly on CUDA gpus. -// The reason we do not directly use the cuda __half data type is because that -// requires compilation with nvcc. The float16 data type should be compatible -// with the cuda __half data type, but will allow us to refer to the data type -// without the need of cuda. -static_assert(sizeof(unsigned short) == 2, - "Short on this platform is not 16 bit."); -namespace caffe2 { -// Helpers to avoid using typeinfo with -rtti -template -inline bool fp16_type(); - -template <> -inline bool fp16_type() { - return true; -} - -template -inline bool fp16_type() { - return false; -} - -} // namespace caffe2 - -#endif // CAFFE2_CORE_TYPES_H_ diff --git a/caffe2/perfkernels/CMakeLists.txt b/caffe2/perfkernels/CMakeLists.txt index 9510ec60dfef..3d08e5c0a7bb 100644 --- a/caffe2/perfkernels/CMakeLists.txt +++ b/caffe2/perfkernels/CMakeLists.txt @@ -24,8 +24,6 @@ set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${common_srcs}) if(CXX_AVX2_FOUND) add_library(Caffe2_perfkernels_avx STATIC ${avx_srcs}) add_library(Caffe2_perfkernels_avx2 STATIC ${avx2_srcs}) - add_dependencies(Caffe2_perfkernels_avx Caffe2_PROTO) - add_dependencies(Caffe2_perfkernels_avx2 Caffe2_PROTO) target_link_libraries(Caffe2_perfkernels_avx PRIVATE c10) target_link_libraries(Caffe2_perfkernels_avx2 PRIVATE c10) install(TARGETS Caffe2_perfkernels_avx Caffe2_perfkernels_avx2 @@ -62,7 +60,6 @@ if(CXX_AVX2_FOUND) if(CAFFE2_COMPILER_SUPPORTS_AVX512_EXTENSIONS) add_library(Caffe2_perfkernels_avx512 STATIC ${avx512_srcs}) - add_dependencies(Caffe2_perfkernels_avx512 Caffe2_PROTO) target_link_libraries(Caffe2_perfkernels_avx512 PRIVATE c10) install(TARGETS Caffe2_perfkernels_avx512 ARCHIVE DESTINATION "${CMAKE_INSTALL_LIBDIR}") diff --git a/caffe2/perfkernels/embedding_lookup.cc b/caffe2/perfkernels/embedding_lookup.cc index 687d081301e4..96ae253b32c6 100644 --- a/caffe2/perfkernels/embedding_lookup.cc +++ b/caffe2/perfkernels/embedding_lookup.cc @@ -1,8 +1,9 @@ #include "caffe2/perfkernels/embedding_lookup.h" -#include "caffe2/core/types.h" #include "caffe2/perfkernels/common.h" +#include +#include #include namespace caffe2 { diff --git a/caffe2/perfkernels/embedding_lookup_idx.cc b/caffe2/perfkernels/embedding_lookup_idx.cc index 48c869ee7038..c9b91dc31b88 100644 --- a/caffe2/perfkernels/embedding_lookup_idx.cc +++ b/caffe2/perfkernels/embedding_lookup_idx.cc @@ -2,9 +2,8 @@ #include #include +#include #include -#include "caffe2/core/common.h" -#include "caffe2/core/logging.h" #include "caffe2/perfkernels/common.h" namespace caffe2 { diff --git a/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.cc b/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.cc index b1522ecda7e2..d919f22c5795 100644 --- a/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.cc +++ b/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.cc @@ -1,9 +1,8 @@ #include "caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.h" -#include "caffe2/core/types.h" #include "caffe2/perfkernels/common.h" -#include "caffe2/utils/cpuid.h" +#include #include namespace caffe2 { diff --git a/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.cc b/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.cc index 866298226af0..8f7e926c0e9c 100644 --- a/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.cc +++ b/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.cc @@ -1,9 +1,8 @@ #include "caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.h" -#include "caffe2/core/types.h" #include "caffe2/perfkernels/common.h" -#include "caffe2/utils/cpuid.h" +#include #include namespace caffe2 { diff --git a/caffe2/perfkernels/typed_axpy.cc b/caffe2/perfkernels/typed_axpy.cc index b8128ab951a4..400041766e61 100644 --- a/caffe2/perfkernels/typed_axpy.cc +++ b/caffe2/perfkernels/typed_axpy.cc @@ -1,7 +1,6 @@ +#include #include "caffe2/perfkernels/typed_axpy.h" -#include "caffe2/core/types.h" #include "caffe2/perfkernels/common.h" -#include "caffe2/utils/cpuid.h" namespace caffe2 { diff --git a/caffe2/proto/BUILD.bazel b/caffe2/proto/BUILD.bazel deleted file mode 100644 index 58766661ac67..000000000000 --- a/caffe2/proto/BUILD.bazel +++ /dev/null @@ -1,37 +0,0 @@ -load("@rules_proto//proto:defs.bzl", "proto_library") -load("@rules_cc//cc:defs.bzl", "cc_library", "cc_proto_library") - -cc_library( - name = "caffe2_pb", - hdrs = ["caffe2_pb.h"], - visibility = [ - "//:__pkg__", - ], - deps = [ - ":caffe2_cc_proto", - "//c10/core:base", - "//c10/util:base", - ], -) - -cc_proto_library( - name = "caffe2_cc_proto", - deps = [":caffe2_proto"], -) - -proto_library( - name = "caffe2_proto", - srcs = ["caffe2.proto"], -) - -cc_proto_library( - name = "torch_cc_proto", - visibility = ["//:__pkg__"], # used in torch - deps = [":torch_proto"], -) - -proto_library( - name = "torch_proto", - srcs = ["torch.proto"], - deps = [":caffe2_proto"], -) diff --git a/caffe2/proto/CMakeLists.txt b/caffe2/proto/CMakeLists.txt deleted file mode 100644 index bdbc045afb3d..000000000000 --- a/caffe2/proto/CMakeLists.txt +++ /dev/null @@ -1,19 +0,0 @@ -set(Caffe2_PROTOBUF_FILES "${CMAKE_CURRENT_SOURCE_DIR}/torch.proto;${CMAKE_CURRENT_SOURCE_DIR}/caffe2.proto") - -caffe2_protobuf_generate_cpp_py(Caffe2_PROTO_SRCS Caffe2_PROTO_HEADERS Caffe2_PROTO_PY ${Caffe2_PROTOBUF_FILES}) - -add_library(Caffe2_PROTO OBJECT ${Caffe2_PROTO_HEADERS} ${Caffe2_PROTO_SRCS}) - -if(MSVC) - if(BUILD_SHARED_LIBS) - set(TORCH_API_DEFINE "-DTORCH_API=__declspec(dllexport)") - else() - set(TORCH_API_DEFINE "-DTORCH_API=") - endif() -else() - set(TORCH_API_DEFINE "-DTORCH_API=") -endif() -target_compile_definitions( - Caffe2_PROTO PRIVATE ${TORCH_API_DEFINE}) - -install(FILES ${Caffe2_PROTO_HEADERS} DESTINATION include/caffe2/proto) diff --git a/caffe2/proto/__init__.py b/caffe2/proto/__init__.py deleted file mode 100644 index c40ca97189d1..000000000000 --- a/caffe2/proto/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -import warnings - - -# NOTE: we have to import python protobuf here **before** we load cpp extension. -# Otherwise it breaks under certain build conditions if cpp implementation of -# protobuf is used. Presumably there's some registry in protobuf library and -# python side has to initialize the dictionary first, before static -# initialization in python extension does so. Otherwise, duplicated protobuf -# descriptors will be created and it can lead to obscure errors like -# "Parameter to MergeFrom() must be instance of same class: -# expected caffe2.NetDef got caffe2.NetDef." -# -# This has to be done for all python targets, so listing them here -try: - from caffe2.proto import caffe2_pb2, metanet_pb2, torch_pb2 -except ImportError: - warnings.warn('Caffe2 support is no longer present in PyTorch.') - raise - -try: - from caffe2.caffe2.fb.session.proto import session_pb2 -except ImportError: - pass diff --git a/caffe2/proto/caffe2.proto b/caffe2/proto/caffe2.proto deleted file mode 100644 index 077e7b0ed544..000000000000 --- a/caffe2/proto/caffe2.proto +++ /dev/null @@ -1,528 +0,0 @@ -syntax = "proto2"; - -package caffe2; - -// A few notes about the Caffe2's protobuffer convention: -// (1) Most objects are registered by their types, such as operators and nets. -// For these, we have a string-type field "type" for registration purposes. -// (2) We do not use extension because that used to create quite some conflicts -// in Caffe's protobuf design. -// (3) We have not used any proto3 specific features, such as Any or Map. This -// is mainly for backward compatibility purposes but we may consider using -// those in the future. - -// TensorProto stores serialized Tensor objects. -message TensorProto { - // The dimensions in the tensor. - repeated int64 dims = 1; - - // Data type - enum DataType { - UNDEFINED = 0; - - // Basic types - FLOAT = 1; // float - INT32 = 2; // int - BYTE = 3; // byte, when deserialized, is going to be restored as uint8 - STRING = 4; // string - - // Less-commonly used data types - BOOL = 5; // bool - UINT8 = 6; // uint8_t - INT8 = 7; // int8_t - UINT16 = 8; // uint16_t - INT16 = 9; // int16_t - INT64 = 10; // int64_t - FLOAT16 = 12; // at::Half - DOUBLE = 13; // double - - ZERO_COLLISION_HASH = 14; // zero-collision hash state - REBATCHING_BUFFER = 15; // rebatching buffer - } - // The type of the deserialized tensor data - optional DataType data_type = 2 [ default = FLOAT ]; - - // The format of the serialized data. - enum SerializationFormat { - // FMT_PROTOBUF is the existing serialization format from before the - // data_format field was introduced. Most data types are serialized using - // the protobuf typed fields, although in some cases raw little endian data - // is stored in the byte_data field instead. - FMT_PROTOBUF = 0; - // bfloat16 data stored in the raw_data field. - FMT_BFLOAT16 = 1; - } - // data_format is a SerializationFormat enum value. - // However, we intentionally store it as an integer value so we can - // distinguish between old messages that do not have a data_format value vs - // new messages that have a SerializationFormat value that we don't - // understand. If we stored this as an enum then protobuf would deserialize - // both of these cases the same way. - optional uint32 data_format = 15 [ default = 0 ]; - - // For float - repeated float float_data = 3 [ packed = true ]; - // For int32, uint8, int8, uint16, int16, bool, and float16 - // Note about float16: in storage we will basically convert float16 byte-wise - // to unsigned short and then store them in the int32_data field. - // Note: storing int8 and uint8 values in this field unfortunately results in - // larger serialized data than necessary, as protobuf's varint encoding - // scheme requires 2 bytes to represent int8 and uint8 values that have the - // MSB set. - repeated int32 int32_data = 4 [ packed = true ]; - // For bytes - optional bytes byte_data = 5; - // For strings - repeated bytes string_data = 6; - // For double - repeated double double_data = 9 [ packed = true ]; - // For int64 - repeated int64 int64_data = 10 [ packed = true ]; - // store the raw data, contents are serialized as little-endian - optional bytes raw_data = 13; - - // Optionally, a name for the tensor. - optional string name = 7; - - // Optionally, a TensorProto can contain the details about the device that - // it was serialized from. This is useful in cases like snapshotting a whole - // workspace in a multi-GPU environment. - optional DeviceOption device_detail = 8; - - // When loading from chunks this is going to indicate where to put data in the - // full array. When not used full data have to be present - message Segment { - required int64 begin = 1; - required int64 end = 2; - } - optional Segment segment = 11; - - // Field numbers 12 and 14 were previously used for now-deprecated fields. - // reserved 12, 14; -} - -message QTensorProto { - repeated int64 dims = 1; - required int32 precision = 2; - required double scale = 3; - required double bias = 4; - required bool is_signed = 5; - repeated int32 data = 6 [ packed = true ]; - optional string name = 7; - optional TensorProto.DataType data_type = 8 [ default = INT32 ]; - - // Multi-group quantization params - repeated double scales = 9; - repeated double biases = 10; - - // Multi-group quantization needed, indicates in which dimension - // we do the "group wise quantization" - optional int32 axis = 11; - - // It should be true if it is a multi-group quantization proto - optional bool is_multiparam = 12 [ default = false ]; -} - -// TensorProtos stores multiple TensorProto objects in one single proto. This -// is useful for small tensors; For anything big, consider using a DB for -// storage. -message TensorProtos { - repeated TensorProto protos = 1; -} - -message TensorShape { - repeated int64 dims = 1; - optional TensorProto.DataType data_type = 2 [ default = FLOAT ]; - repeated int32 unknown_dims = 3; - optional bool unknown_shape = 4 [ default = false ]; - optional string name = 5; -} - -message TensorShapes { - repeated TensorShape shapes = 1; -} - -// TensorBoundShape is used to save bound shape inference result for a tensor. -// TensorBoundShape.shape is inferred shape for this tensor. -// TensorBoundShape.dimType contains dim_type for every dimension. -// eg: for dimension i, shape.dims[i] is the inferred shape and -// dim_type[i] is corresponding dim_type. -message TensorBoundShape { - optional TensorShape shape = 1; - enum DimType { - UNKNOWN = 0; // unknown - CONSTANT = 1; // constant - // batch, corresponding dimension is batch_size - BATCH = 2; - // batch_of_feature_max, - // corresponding shape is inferred_feature_length * batch_size - BATCH_OF_FEATURE_MAX = 3; - // batch_of_feature_max_default - // corresponding shape is default_feature_length * batch_size - BATCH_OF_FEATURE_MAX_DEFAULT = 4; - // feature_max, corresponding shape is inferred_feature_length - FEATURE_MAX = 5; - // feature_max_default, corresponding shape is default_feature_length - FEATURE_MAX_DEFAULT = 6; - } - repeated DimType dim_type = 2; // dim_type.size() == shape.dims.size() - optional string name = 3; - // a flag to indicate whether the shape is final and cannot be changed - // eg: input/output of in-place ops - optional bool shape_is_final = 4; -} - -message TensorBoundShapes { - repeated TensorBoundShape shapes = 1; - optional int64 max_batch_size = 2; - optional int64 max_feature_len = 3; -} - -message AOTConfig { - required int64 max_batch_size = 1; - required int64 max_seq_size = 2; - required bool in_batch_broadcast = 3; - optional string onnxifi_blacklist_ops = 4; - optional int32 onnxifi_min_ops = 5; -} - -// A named argument containing either singular float, integer and string -// values, or repeated float, int and string arrays. -message Argument { - optional string name = 1; - - optional float f = 2; - optional int64 i = 3; - optional bytes s = 4; - optional TensorProto t = 10; - optional NetDef n = 8; - - repeated float floats = 5; - repeated int64 ints = 6; - repeated bytes strings = 7; - repeated TensorProto tensors = 11; - repeated NetDef nets = 9; - repeated QTensorProto qtensors = 12; -} - -// DeviceType that Caffe2 currently supports. -// Note: if you add a device type, make sure you add the corresponding device -// line in the DeviceTypeName() function in caffe2/utils/proto_utils.cc -// and update c10/core/DeviceType.h -enum DeviceTypeProto { - PROTO_CPU = 0; // In default, we will use CPU. - PROTO_CUDA = 1; // CUDA. - PROTO_MKLDNN = 2; // Reserved for explicit MKLDNN - PROTO_OPENGL = 3; // OpenGL - PROTO_OPENCL = 4; // OpenCL - PROTO_IDEEP = 5; // IDEEP. - PROTO_HIP = 6; // AMD HIP - PROTO_FPGA = 7; // FPGA - PROTO_MAIA = 8; // MAIA - PROTO_XLA = 9; // XLA / TPU - PROTO_MPS = 10; // MPS - // Change the following number if you add more devices in the code. - PROTO_COMPILE_TIME_MAX_DEVICE_TYPES = 11; -} - -// Device-specific options. We do not distinguish DeviceOption protos for -// different DeviceTypes, so currently all devices share the same DeviceOption -// proto. Fields that are specific to a device type is ignored if the type does -// not match. -// Note: if you add fields to the DeviceOption, make sure you add the -// corresponding changes to IsSameDevice() function in utils/proto_utils.{h,cc}. -message DeviceOption { - // [general] Options that need to be carried out before running the execution. - // optional DeviceType device_type = 1 [ default = CPU ]; - optional int32 device_type = 1 [ default = 0 ]; // 0 is CPU. - // [general] Used together with device_type to identify the exact device - optional int32 device_id = 2; - // [general] The random seed to start the device random number generator with. - optional uint32 random_seed = 3; - // [general] What node this op should execute on. - // Used for net transformation purposes. Must be empty at execution time. - optional string node_name = 4; - // [CPU and Linux specific] NUMA node id - optional int32 numa_node_id = 5; - // [general] Extra information passed, not used at execution time currently. - repeated string extra_info = 6; -} - -// Operator Definition. -message OperatorDef { - repeated string input = 1; // the name of the input blobs - repeated string output = 2; // the name of output top blobs - optional string name = 3; // the operator name. This is optional. - // the operator type. This is needed to create the object from the operator - // registry. - optional string type = 4; - // arg is for the argument defined in operator schema - repeated Argument arg = 5; - - // The device option that the operator should run under. - optional DeviceOption device_option = 6; - - // Optionally, one can specify an engine when there are multiple - // implementations available simultaneously for one device type. - // If one specifies an engine but that engine does not exist in the compiled - // Caffe2 binary, Caffe2 will fall back to the default engine of that device - // type. - optional string engine = 7; - - // Additional 'fake' inputs used for expressing control dependencies - // in the operator graph. This can be used to ensure that an - // operator does not run until another operator is ready, for e.g. - // scheduling control. These are not passed as actual inputs to the - // Operator implementation, and are only used by the Net class for - // scheduling purposes. - repeated string control_input = 8; - - // is_gradient_op argument is only used as a hint in shape inference - // and has no runtime significance - optional bool is_gradient_op = 9 [ default = false ]; - - // debug information associated with the construction of the operator. - // This is an optional string with no assumed characteristics as - // operators can be constructed in any language. - optional string debug_info = 10; - - // the domain of the operator to help runtime distinguish which operator - // library this OperatorDef refers to. For example, both caffe2 and aten - // has `Add` operator, with domain, we can easily decide which operator - // to execute. to support multiple operator libs, we use domain to - // distinguish which operator lib we refer to: - // - "caffe2" means this uses Caffe2 operator library - // - "aten" means this uses ATen operator library - // - "c10" is for the fused library - // - if the domain is missing or empty, we use "caffe2", this is for - // legacy models, new serializer should always export an OperatorDef - // with domain and op_version - optional string domain = 11; - // each operator is has its own version number. - // operator version information - // each time, we change the API or semantics of the operator, - // we bump the version for the operator. - // the runtime system should check the op_version of each OperatorDef - // and decide it should reject or accept the model - optional int64 op_version = 12; -} - -// MapFieldEntry follows the pattern for cross-proto-version maps. -// See https://developers.google.com/protocol-buffers/docs/proto3#maps -message MapFieldEntry { - required string key = 1; - required string val = 2; -}; - -// Used to hold backend-specific options. -message BackendOptions { - // Name of the backend that the specified options apply to. - required string backend_name = 1; - // Flexible map for passing in the options. - repeated MapFieldEntry option = 2; -}; - -// Partition definition. -message PartitionInfo { - // Name of the partition. - required string name = 1; - - // A list of logic device ID, indicating which devices this partition - // can be executed on. If empty, it means the partition won't run on - // device but on host CPU instead. - repeated int32 device_id = 2; - - // Extra debug info. - optional string extra_info = 3; - - // Flexible map for passing options specific to a backend. - repeated BackendOptions backend_options = 4; -} - -// Network definition. -message NetDef { - optional string name = 1; // the network's name - // Operators that the network contains. - // Note: this is not named "operator" because that is a reserved word in C++. - repeated OperatorDef op = 2; - - // The type of network that the net should be run with. This routes the - // network instantiation to different execution modes. The default mode, - // "simple", runs the operators in a sequential way as the original Caffe - // implementation does. - optional string type = 3; - - // the number of workers, if the operators in the network is to be carried out - // in parallel. - // Note: This is to be deprecated. Using the arg field with "num_workers" as - // key. - // Note 2: The old uses of this were never actually cleaned up - optional int32 num_workers = 4; - - // The device option for the network. If a network has a specific device - // option and one of its operators does not have it set, we will copy over the - // device option to the operator. This allows us to basically avoid putting - // device options at every operator. - optional DeviceOption device_option = 5; - - repeated Argument arg = 6; - - // Two optional fields to declare external input and output of a net. - // If these two are set, when a net is created, we will sanity check for - // every op whether its input is declared (either as an external input, - // or as an intermediate blob created by one of the ops), and sanity check - // if all blobs in external_output are produced. - // - // In cases of memory optimization, declaring external_input and - // external_output also ensures that storage of these blobs are persistent: - // for any blob in external_input and external_output, after a network run - // finishes, their content are actually the right content. Any intermediate - // blobs' contents may be overwritten. - repeated string external_input = 7; - repeated string external_output = 8; - - // Partitioning info, indexed by partition names. - repeated PartitionInfo partition_info = 9; -} - -// ExecutionStep is actually a sort-of-hacky way we simulate iteration right -// now. -message ExecutionStep { - // ExecutionStep should either contain a set of substeps, or a set of - // network names to run in this execution step. They should NOT both be set - // at the same time. - optional string name = 1; - // An execution step could be recursive, in which it involves a set of - // substeps. - repeated ExecutionStep substep = 2; - // Alternatively, an execution step could involve one or more networks. - // Note that you cannot have both substeps and networks. Choose one. - // Note that an execution step refers networks by their name. The actual - // network definition of the same name should be included in the network field - // of the plan. The reason is that a network object might hold internal states - // (think of a data layer), so we want to have the same network object that - // multiple steps could ask to run. - repeated string network = 3; - // Number of iterations to run this step. The substeps or the networks - // specified will be run sequentially, and one sequential run is considered - // one iteration. If this is not set, the number of iterations is assumed to - // be 1. - optional int64 num_iter = 4; - - // Criteria network specifies a single output (TensorCPU) of - // size (1), is run on every iteration by the executor, and - // execution terminates when the output[0] is `false`. - optional string criteria_network = 5 [ deprecated = true ]; - - // DEPRECATED. Use `run_every_ms`. - optional string report_net = 7; - optional int32 report_interval = 8; - - // If provided, execute this step at every time interval (in millisecs) - // while its sibiling execution steps execute in parallel. This step is - // guaranteed to run at least once after all non-interval siblings finished. - optional int64 run_every_ms = 11; - - // If false or not set, execute sub-steps serially. - // If true, execute all substeps concurrently, each one in a separate thread. - optional bool concurrent_substeps = 6; - - // Name of a scalar boolean tensor. - // ES checks this blob AFTER every substeps/subnets. - // If specified, and the value is true, then ES will skip the rest and return - // immediately. - // This means that the report_net and the first step will always be called. - // Use cases: - // 1) the first substep stops the rest if data condition not met - // 2) the first substep decide which of the rest of the steps should be run. - // 3) external control - // - // ** It is the user's responsibility to not to put this blob in race - // conditions. - // ** For example when setting this blob in concurrent substeps - optional string should_stop_blob = 9; - - // if only_once is true, this step will only be executed once. this ONLY takes - // effect when using should_stop_blob - optional bool only_once = 10; - - // Whether to create a child workspace for this step. - // If yes, the workflow and nets are re-created every time this step is run. - optional bool create_workspace = 12; - - // How many copies of the children execution steps to run concurrently. - optional int32 num_concurrent_instances = 13; -} - -message PlanDef { - // All the networks that are used in this execution. Note that networks should - // be ordered in the way they are executed, i.e. for a layer in a network, all - // its input blobs should already have been initialized by the layers or - // networks defined before it. - optional string name = 1; - // The networks that are going to be used in this plan. - repeated NetDef network = 2; - repeated ExecutionStep execution_step = 3; -} - -// Protobuf format for blobs that are not Tensors. We use a key to store the -// type of the blob. For example for a serialized DBProto, the type should -// be "DBReader" and the content should be a serialized DBProto object. -message BlobProto { - optional string name = 1; - optional string type = 2; - optional TensorProto tensor = 3; - optional bytes content = 4; - optional QTensorProto qtensor = 5; - // If blob is not Tensor and is divided into chunks, content_num_chunks - // contains number of chunks, into which blob was divided. - optional int32 content_num_chunks = 6; - optional int32 content_chunk_id = 7; -} - -// Protobuf format to serialize DBReader. -message DBReaderProto { - // The name for the DB object in the workspace. - optional string name = 1; - // The source of the DB - optional string source = 2; - // The type of the DB - optional string db_type = 3; - // The current key of the DB if the DB supports seeking. - optional string key = 4; -} - -message BlobSerializationOptions { - // This set of options will only apply to blobs whose name matches this - // pattern. If the blob_name_pattern is empty then it will be treated as - // matching all blobs. - optional string blob_name_regex = 1; - - // Note: - // - a chunk_size of 0 means "use the default chunk size". The default chunk - // size is controlled by the --caffe2_tensor_chunk_size command line flag. - // - a chunk size of -1 means to disable chunking, and serialize the blob in - // a single chunk. - optional int64 chunk_size = 2; - - enum FloatFormat { - // Use the current default serialization format, as chosen by the - // current version of the code. (At the time of writing this is PROTOBUF) - FLOAT_DEFAULT = 0; - // Store the data in the TensorProto's float_data field - FLOAT_PROTOBUF = 1; - // Serialize float values as bfloat16. Note that this conversion is lossy. - FLOAT_BFLOAT16 = 2; - } - - // Settings for how to serialize tensors containing float values - optional FloatFormat float_format = 3; -} - -message SerializationOptions { - // A set of options to use when serialializing blobs. - // This is a list, sorted from highest to lowest precedence. When - // serializing a blob, the first entry whose blob_name_pattern matches the - // blob name will be used. - repeated BlobSerializationOptions options = 1; -} diff --git a/caffe2/proto/caffe2_pb.h b/caffe2/proto/caffe2_pb.h deleted file mode 100644 index fc82659dc51d..000000000000 --- a/caffe2/proto/caffe2_pb.h +++ /dev/null @@ -1,135 +0,0 @@ -#pragma once -#include -#include -#include - -namespace caffe2 { - -using DeviceType = at::DeviceType; -constexpr DeviceType CPU = DeviceType::CPU; -constexpr DeviceType CUDA = DeviceType::CUDA; -constexpr DeviceType OPENGL = DeviceType::OPENGL; -constexpr DeviceType OPENCL = DeviceType::OPENCL; -constexpr DeviceType MKLDNN = DeviceType::MKLDNN; -constexpr DeviceType IDEEP = DeviceType::IDEEP; -constexpr DeviceType HIP = DeviceType::HIP; -constexpr DeviceType COMPILE_TIME_MAX_DEVICE_TYPES = - DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES; - -inline TORCH_API DeviceType ProtoToType(const caffe2::DeviceTypeProto p) { - switch (p) { - case caffe2::PROTO_CPU: - return DeviceType::CPU; - case caffe2::PROTO_CUDA: - return DeviceType::CUDA; - case caffe2::PROTO_OPENGL: - return DeviceType::OPENGL; - case caffe2::PROTO_OPENCL: - return DeviceType::OPENCL; - case caffe2::PROTO_MKLDNN: - return DeviceType::MKLDNN; - case caffe2::PROTO_IDEEP: - return DeviceType::IDEEP; - case caffe2::PROTO_HIP: - return DeviceType::HIP; - case caffe2::PROTO_COMPILE_TIME_MAX_DEVICE_TYPES: - return DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES; - default: - AT_ERROR( - "Unknown device:", - static_cast(p), - ". If you have recently updated the caffe2.proto file to add a new " - "device type, did you forget to update the ProtoToType() and TypeToProto" - "function to reflect such recent changes?"); - } -} - -inline TORCH_API DeviceType ProtoToType(int p) { - return ProtoToType(static_cast(p)); -} - -inline TORCH_API DeviceTypeProto TypeToProto(const DeviceType& t) { - switch (t) { - case DeviceType::CPU: - return caffe2::PROTO_CPU; - case DeviceType::CUDA: - return caffe2::PROTO_CUDA; - case DeviceType::OPENGL: - return caffe2::PROTO_OPENGL; - case DeviceType::OPENCL: - return caffe2::PROTO_OPENCL; - case DeviceType::MKLDNN: - return caffe2::PROTO_MKLDNN; - case DeviceType::IDEEP: - return caffe2::PROTO_IDEEP; - case DeviceType::HIP: - return caffe2::PROTO_HIP; - case DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES: - return caffe2::PROTO_COMPILE_TIME_MAX_DEVICE_TYPES; - default: - AT_ERROR( - "Unknown device:", - static_cast(t), - ". If you have recently updated the caffe2.proto file to add a new " - "device type, did you forget to update the ProtoToType() and TypeToProto" - "function to reflect such recent changes?"); - } -} - -inline TORCH_API caffe2::DeviceOption DeviceToOption(const at::Device& device) { - caffe2::DeviceOption option; - auto type = device.type(); - option.set_device_type(TypeToProto(type)); - - switch (type) { - case DeviceType::CPU: - if (device.index() != -1) { - option.set_numa_node_id(device.index()); - } - break; - case DeviceType::CUDA: - case DeviceType::HIP: - option.set_device_id(device.index()); - break; - case DeviceType::OPENGL: - case DeviceType::OPENCL: - case DeviceType::MKLDNN: - case DeviceType::IDEEP: - case DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES: - break; - default: - AT_ERROR( - "Unknown device:", - static_cast(type), - ". If you have recently updated the caffe2.proto file to add a new " - "device type, did you forget to update the ProtoToType() and TypeToProto" - "function to reflect such recent changes?"); - } - return option; -} - -inline TORCH_API at::Device OptionToDevice(const caffe2::DeviceOption& option) { - auto type = option.device_type(); - c10::DeviceIndex id = -1; - switch (type) { - case caffe2::PROTO_CPU: - if (option.has_numa_node_id()) { - id = static_cast(option.numa_node_id()); - } - break; - case caffe2::PROTO_CUDA: - case caffe2::PROTO_HIP: - id = static_cast(option.device_id()); - break; - } - return at::Device(ProtoToType(type), id); -} - -inline void ExtractDeviceOption( - DeviceOption* device_option, - const at::Device& device) { - AT_ASSERT(device_option); - device_option->CopyFrom(DeviceToOption(device)); -} - -} // namespace caffe2 diff --git a/caffe2/proto/caffe2_pb2.pyi b/caffe2/proto/caffe2_pb2.pyi deleted file mode 100644 index 43249ebf75db..000000000000 --- a/caffe2/proto/caffe2_pb2.pyi +++ /dev/null @@ -1,767 +0,0 @@ -""" -@generated by mypy-protobuf. Do not edit manually! -isort:skip_file -""" -import builtins -import google.protobuf.descriptor -import google.protobuf.internal.containers -import google.protobuf.internal.enum_type_wrapper -import google.protobuf.message -import typing -import typing_extensions - -DESCRIPTOR: google.protobuf.descriptor.FileDescriptor = ... - -global___DeviceTypeProto = DeviceTypeProto -class _DeviceTypeProto(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[DeviceTypeProto], type): - DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor = ... - PROTO_CPU = DeviceTypeProto.V(0) - PROTO_CUDA = DeviceTypeProto.V(1) - PROTO_MKLDNN = DeviceTypeProto.V(2) - PROTO_OPENGL = DeviceTypeProto.V(3) - PROTO_OPENCL = DeviceTypeProto.V(4) - PROTO_IDEEP = DeviceTypeProto.V(5) - PROTO_HIP = DeviceTypeProto.V(6) - PROTO_FPGA = DeviceTypeProto.V(7) - PROTO_MAIA = DeviceTypeProto.V(8) - PROTO_XLA = DeviceTypeProto.V(9) - PROTO_MPS = DeviceTypeProto.V(10) - PROTO_COMPILE_TIME_MAX_DEVICE_TYPES = DeviceTypeProto.V(11) -class DeviceTypeProto(metaclass=_DeviceTypeProto): - V = typing.NewType('V', int) -PROTO_CPU = DeviceTypeProto.V(0) -PROTO_CUDA = DeviceTypeProto.V(1) -PROTO_MKLDNN = DeviceTypeProto.V(2) -PROTO_OPENGL = DeviceTypeProto.V(3) -PROTO_OPENCL = DeviceTypeProto.V(4) -PROTO_IDEEP = DeviceTypeProto.V(5) -PROTO_HIP = DeviceTypeProto.V(6) -PROTO_FPGA = DeviceTypeProto.V(7) -PROTO_MAIA = DeviceTypeProto.V(8) -PROTO_XLA = DeviceTypeProto.V(9) -PROTO_MPS = DeviceTypeProto.V(10) -PROTO_COMPILE_TIME_MAX_DEVICE_TYPES = DeviceTypeProto.V(11) - -class TensorProto(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - class _DataType(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[DataType], type): - DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor = ... - UNDEFINED = TensorProto.DataType.V(0) - FLOAT = TensorProto.DataType.V(1) - INT32 = TensorProto.DataType.V(2) - BYTE = TensorProto.DataType.V(3) - STRING = TensorProto.DataType.V(4) - BOOL = TensorProto.DataType.V(5) - UINT8 = TensorProto.DataType.V(6) - INT8 = TensorProto.DataType.V(7) - UINT16 = TensorProto.DataType.V(8) - INT16 = TensorProto.DataType.V(9) - INT64 = TensorProto.DataType.V(10) - FLOAT16 = TensorProto.DataType.V(12) - DOUBLE = TensorProto.DataType.V(13) - ZERO_COLLISION_HASH = TensorProto.DataType.V(14) - REBATCHING_BUFFER = TensorProto.DataType.V(15) - class DataType(metaclass=_DataType): - V = typing.NewType('V', int) - UNDEFINED = TensorProto.DataType.V(0) - FLOAT = TensorProto.DataType.V(1) - INT32 = TensorProto.DataType.V(2) - BYTE = TensorProto.DataType.V(3) - STRING = TensorProto.DataType.V(4) - BOOL = TensorProto.DataType.V(5) - UINT8 = TensorProto.DataType.V(6) - INT8 = TensorProto.DataType.V(7) - UINT16 = TensorProto.DataType.V(8) - INT16 = TensorProto.DataType.V(9) - INT64 = TensorProto.DataType.V(10) - FLOAT16 = TensorProto.DataType.V(12) - DOUBLE = TensorProto.DataType.V(13) - ZERO_COLLISION_HASH = TensorProto.DataType.V(14) - REBATCHING_BUFFER = TensorProto.DataType.V(15) - - class _SerializationFormat(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[SerializationFormat], type): - DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor = ... - FMT_PROTOBUF = TensorProto.SerializationFormat.V(0) - FMT_BFLOAT16 = TensorProto.SerializationFormat.V(1) - class SerializationFormat(metaclass=_SerializationFormat): - V = typing.NewType('V', int) - FMT_PROTOBUF = TensorProto.SerializationFormat.V(0) - FMT_BFLOAT16 = TensorProto.SerializationFormat.V(1) - - class Segment(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - BEGIN_FIELD_NUMBER: int - END_FIELD_NUMBER: int - begin: int = ... - end: int = ... - - def __init__(self, - *, - begin : typing.Optional[int] = ..., - end : typing.Optional[int] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"begin",b"begin",u"end",b"end"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"begin",b"begin",u"end",b"end"]) -> None: ... - - DIMS_FIELD_NUMBER: int - DATA_TYPE_FIELD_NUMBER: int - DATA_FORMAT_FIELD_NUMBER: int - FLOAT_DATA_FIELD_NUMBER: int - INT32_DATA_FIELD_NUMBER: int - BYTE_DATA_FIELD_NUMBER: int - STRING_DATA_FIELD_NUMBER: int - DOUBLE_DATA_FIELD_NUMBER: int - INT64_DATA_FIELD_NUMBER: int - RAW_DATA_FIELD_NUMBER: int - NAME_FIELD_NUMBER: int - DEVICE_DETAIL_FIELD_NUMBER: int - SEGMENT_FIELD_NUMBER: int - dims: google.protobuf.internal.containers.RepeatedScalarFieldContainer[int] = ... - data_type: global___TensorProto.DataType = ... - data_format: int = ... - float_data: google.protobuf.internal.containers.RepeatedScalarFieldContainer[float] = ... - int32_data: google.protobuf.internal.containers.RepeatedScalarFieldContainer[int] = ... - byte_data: bytes = ... - string_data: google.protobuf.internal.containers.RepeatedScalarFieldContainer[bytes] = ... - double_data: google.protobuf.internal.containers.RepeatedScalarFieldContainer[float] = ... - int64_data: google.protobuf.internal.containers.RepeatedScalarFieldContainer[int] = ... - raw_data: bytes = ... - name: typing.Text = ... - - @property - def device_detail(self) -> global___DeviceOption: ... - - @property - def segment(self) -> global___TensorProto.Segment: ... - - def __init__(self, - *, - dims : typing.Optional[typing.Iterable[int]] = ..., - data_type : typing.Optional[global___TensorProto.DataType] = ..., - data_format : typing.Optional[int] = ..., - float_data : typing.Optional[typing.Iterable[float]] = ..., - int32_data : typing.Optional[typing.Iterable[int]] = ..., - byte_data : typing.Optional[bytes] = ..., - string_data : typing.Optional[typing.Iterable[bytes]] = ..., - double_data : typing.Optional[typing.Iterable[float]] = ..., - int64_data : typing.Optional[typing.Iterable[int]] = ..., - raw_data : typing.Optional[bytes] = ..., - name : typing.Optional[typing.Text] = ..., - device_detail : typing.Optional[global___DeviceOption] = ..., - segment : typing.Optional[global___TensorProto.Segment] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"byte_data",b"byte_data",u"data_format",b"data_format",u"data_type",b"data_type",u"device_detail",b"device_detail",u"name",b"name",u"raw_data",b"raw_data",u"segment",b"segment"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"byte_data",b"byte_data",u"data_format",b"data_format",u"data_type",b"data_type",u"device_detail",b"device_detail",u"dims",b"dims",u"double_data",b"double_data",u"float_data",b"float_data",u"int32_data",b"int32_data",u"int64_data",b"int64_data",u"name",b"name",u"raw_data",b"raw_data",u"segment",b"segment",u"string_data",b"string_data"]) -> None: ... -global___TensorProto = TensorProto - -class QTensorProto(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - DIMS_FIELD_NUMBER: int - PRECISION_FIELD_NUMBER: int - SCALE_FIELD_NUMBER: int - BIAS_FIELD_NUMBER: int - IS_SIGNED_FIELD_NUMBER: int - DATA_FIELD_NUMBER: int - NAME_FIELD_NUMBER: int - DATA_TYPE_FIELD_NUMBER: int - SCALES_FIELD_NUMBER: int - BIASES_FIELD_NUMBER: int - AXIS_FIELD_NUMBER: int - IS_MULTIPARAM_FIELD_NUMBER: int - dims: google.protobuf.internal.containers.RepeatedScalarFieldContainer[int] = ... - precision: int = ... - scale: float = ... - bias: float = ... - is_signed: bool = ... - data: google.protobuf.internal.containers.RepeatedScalarFieldContainer[int] = ... - name: typing.Text = ... - data_type: global___TensorProto.DataType = ... - scales: google.protobuf.internal.containers.RepeatedScalarFieldContainer[float] = ... - biases: google.protobuf.internal.containers.RepeatedScalarFieldContainer[float] = ... - axis: int = ... - is_multiparam: bool = ... - - def __init__(self, - *, - dims : typing.Optional[typing.Iterable[int]] = ..., - precision : typing.Optional[int] = ..., - scale : typing.Optional[float] = ..., - bias : typing.Optional[float] = ..., - is_signed : typing.Optional[bool] = ..., - data : typing.Optional[typing.Iterable[int]] = ..., - name : typing.Optional[typing.Text] = ..., - data_type : typing.Optional[global___TensorProto.DataType] = ..., - scales : typing.Optional[typing.Iterable[float]] = ..., - biases : typing.Optional[typing.Iterable[float]] = ..., - axis : typing.Optional[int] = ..., - is_multiparam : typing.Optional[bool] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"axis",b"axis",u"bias",b"bias",u"data_type",b"data_type",u"is_multiparam",b"is_multiparam",u"is_signed",b"is_signed",u"name",b"name",u"precision",b"precision",u"scale",b"scale"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"axis",b"axis",u"bias",b"bias",u"biases",b"biases",u"data",b"data",u"data_type",b"data_type",u"dims",b"dims",u"is_multiparam",b"is_multiparam",u"is_signed",b"is_signed",u"name",b"name",u"precision",b"precision",u"scale",b"scale",u"scales",b"scales"]) -> None: ... -global___QTensorProto = QTensorProto - -class TensorProtos(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - PROTOS_FIELD_NUMBER: int - - @property - def protos(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___TensorProto]: ... - - def __init__(self, - *, - protos : typing.Optional[typing.Iterable[global___TensorProto]] = ..., - ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal[u"protos",b"protos"]) -> None: ... -global___TensorProtos = TensorProtos - -class TensorShape(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - DIMS_FIELD_NUMBER: int - DATA_TYPE_FIELD_NUMBER: int - UNKNOWN_DIMS_FIELD_NUMBER: int - UNKNOWN_SHAPE_FIELD_NUMBER: int - NAME_FIELD_NUMBER: int - dims: google.protobuf.internal.containers.RepeatedScalarFieldContainer[int] = ... - data_type: global___TensorProto.DataType = ... - unknown_dims: google.protobuf.internal.containers.RepeatedScalarFieldContainer[int] = ... - unknown_shape: bool = ... - name: typing.Text = ... - - def __init__(self, - *, - dims : typing.Optional[typing.Iterable[int]] = ..., - data_type : typing.Optional[global___TensorProto.DataType] = ..., - unknown_dims : typing.Optional[typing.Iterable[int]] = ..., - unknown_shape : typing.Optional[bool] = ..., - name : typing.Optional[typing.Text] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"data_type",b"data_type",u"name",b"name",u"unknown_shape",b"unknown_shape"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"data_type",b"data_type",u"dims",b"dims",u"name",b"name",u"unknown_dims",b"unknown_dims",u"unknown_shape",b"unknown_shape"]) -> None: ... -global___TensorShape = TensorShape - -class TensorShapes(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - SHAPES_FIELD_NUMBER: int - - @property - def shapes(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___TensorShape]: ... - - def __init__(self, - *, - shapes : typing.Optional[typing.Iterable[global___TensorShape]] = ..., - ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal[u"shapes",b"shapes"]) -> None: ... -global___TensorShapes = TensorShapes - -class TensorBoundShape(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - class _DimType(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[DimType], type): - DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor = ... - UNKNOWN = TensorBoundShape.DimType.V(0) - CONSTANT = TensorBoundShape.DimType.V(1) - BATCH = TensorBoundShape.DimType.V(2) - BATCH_OF_FEATURE_MAX = TensorBoundShape.DimType.V(3) - BATCH_OF_FEATURE_MAX_DEFAULT = TensorBoundShape.DimType.V(4) - FEATURE_MAX = TensorBoundShape.DimType.V(5) - FEATURE_MAX_DEFAULT = TensorBoundShape.DimType.V(6) - class DimType(metaclass=_DimType): - V = typing.NewType('V', int) - UNKNOWN = TensorBoundShape.DimType.V(0) - CONSTANT = TensorBoundShape.DimType.V(1) - BATCH = TensorBoundShape.DimType.V(2) - BATCH_OF_FEATURE_MAX = TensorBoundShape.DimType.V(3) - BATCH_OF_FEATURE_MAX_DEFAULT = TensorBoundShape.DimType.V(4) - FEATURE_MAX = TensorBoundShape.DimType.V(5) - FEATURE_MAX_DEFAULT = TensorBoundShape.DimType.V(6) - - SHAPE_FIELD_NUMBER: int - DIM_TYPE_FIELD_NUMBER: int - NAME_FIELD_NUMBER: int - SHAPE_IS_FINAL_FIELD_NUMBER: int - dim_type: google.protobuf.internal.containers.RepeatedScalarFieldContainer[global___TensorBoundShape.DimType] = ... - name: typing.Text = ... - shape_is_final: bool = ... - - @property - def shape(self) -> global___TensorShape: ... - - def __init__(self, - *, - shape : typing.Optional[global___TensorShape] = ..., - dim_type : typing.Optional[typing.Iterable[global___TensorBoundShape.DimType]] = ..., - name : typing.Optional[typing.Text] = ..., - shape_is_final : typing.Optional[bool] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"name",b"name",u"shape",b"shape",u"shape_is_final",b"shape_is_final"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"dim_type",b"dim_type",u"name",b"name",u"shape",b"shape",u"shape_is_final",b"shape_is_final"]) -> None: ... -global___TensorBoundShape = TensorBoundShape - -class TensorBoundShapes(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - SHAPES_FIELD_NUMBER: int - MAX_BATCH_SIZE_FIELD_NUMBER: int - MAX_FEATURE_LEN_FIELD_NUMBER: int - max_batch_size: int = ... - max_feature_len: int = ... - - @property - def shapes(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___TensorBoundShape]: ... - - def __init__(self, - *, - shapes : typing.Optional[typing.Iterable[global___TensorBoundShape]] = ..., - max_batch_size : typing.Optional[int] = ..., - max_feature_len : typing.Optional[int] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"max_batch_size",b"max_batch_size",u"max_feature_len",b"max_feature_len"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"max_batch_size",b"max_batch_size",u"max_feature_len",b"max_feature_len",u"shapes",b"shapes"]) -> None: ... -global___TensorBoundShapes = TensorBoundShapes - -class AOTConfig(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - MAX_BATCH_SIZE_FIELD_NUMBER: int - MAX_SEQ_SIZE_FIELD_NUMBER: int - IN_BATCH_BROADCAST_FIELD_NUMBER: int - ONNXIFI_BLACKLIST_OPS_FIELD_NUMBER: int - ONNXIFI_MIN_OPS_FIELD_NUMBER: int - max_batch_size: int = ... - max_seq_size: int = ... - in_batch_broadcast: bool = ... - onnxifi_blacklist_ops: typing.Text = ... - onnxifi_min_ops: int = ... - - def __init__(self, - *, - max_batch_size : typing.Optional[int] = ..., - max_seq_size : typing.Optional[int] = ..., - in_batch_broadcast : typing.Optional[bool] = ..., - onnxifi_blacklist_ops : typing.Optional[typing.Text] = ..., - onnxifi_min_ops : typing.Optional[int] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"in_batch_broadcast",b"in_batch_broadcast",u"max_batch_size",b"max_batch_size",u"max_seq_size",b"max_seq_size",u"onnxifi_blacklist_ops",b"onnxifi_blacklist_ops",u"onnxifi_min_ops",b"onnxifi_min_ops"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"in_batch_broadcast",b"in_batch_broadcast",u"max_batch_size",b"max_batch_size",u"max_seq_size",b"max_seq_size",u"onnxifi_blacklist_ops",b"onnxifi_blacklist_ops",u"onnxifi_min_ops",b"onnxifi_min_ops"]) -> None: ... -global___AOTConfig = AOTConfig - -class Argument(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - NAME_FIELD_NUMBER: int - F_FIELD_NUMBER: int - I_FIELD_NUMBER: int - S_FIELD_NUMBER: int - T_FIELD_NUMBER: int - N_FIELD_NUMBER: int - FLOATS_FIELD_NUMBER: int - INTS_FIELD_NUMBER: int - STRINGS_FIELD_NUMBER: int - TENSORS_FIELD_NUMBER: int - NETS_FIELD_NUMBER: int - QTENSORS_FIELD_NUMBER: int - name: typing.Text = ... - f: float = ... - i: int = ... - s: bytes = ... - floats: google.protobuf.internal.containers.RepeatedScalarFieldContainer[float] = ... - ints: google.protobuf.internal.containers.RepeatedScalarFieldContainer[int] = ... - strings: google.protobuf.internal.containers.RepeatedScalarFieldContainer[bytes] = ... - - @property - def t(self) -> global___TensorProto: ... - - @property - def n(self) -> global___NetDef: ... - - @property - def tensors(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___TensorProto]: ... - - @property - def nets(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___NetDef]: ... - - @property - def qtensors(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___QTensorProto]: ... - - def __init__(self, - *, - name : typing.Optional[typing.Text] = ..., - f : typing.Optional[float] = ..., - i : typing.Optional[int] = ..., - s : typing.Optional[bytes] = ..., - t : typing.Optional[global___TensorProto] = ..., - n : typing.Optional[global___NetDef] = ..., - floats : typing.Optional[typing.Iterable[float]] = ..., - ints : typing.Optional[typing.Iterable[int]] = ..., - strings : typing.Optional[typing.Iterable[bytes]] = ..., - tensors : typing.Optional[typing.Iterable[global___TensorProto]] = ..., - nets : typing.Optional[typing.Iterable[global___NetDef]] = ..., - qtensors : typing.Optional[typing.Iterable[global___QTensorProto]] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"f",b"f",u"i",b"i",u"n",b"n",u"name",b"name",u"s",b"s",u"t",b"t"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"f",b"f",u"floats",b"floats",u"i",b"i",u"ints",b"ints",u"n",b"n",u"name",b"name",u"nets",b"nets",u"qtensors",b"qtensors",u"s",b"s",u"strings",b"strings",u"t",b"t",u"tensors",b"tensors"]) -> None: ... -global___Argument = Argument - -class DeviceOption(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - DEVICE_TYPE_FIELD_NUMBER: int - DEVICE_ID_FIELD_NUMBER: int - RANDOM_SEED_FIELD_NUMBER: int - NODE_NAME_FIELD_NUMBER: int - NUMA_NODE_ID_FIELD_NUMBER: int - EXTRA_INFO_FIELD_NUMBER: int - device_type: int = ... - device_id: int = ... - random_seed: int = ... - node_name: typing.Text = ... - numa_node_id: int = ... - extra_info: google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text] = ... - - def __init__(self, - *, - device_type : typing.Optional[int] = ..., - device_id : typing.Optional[int] = ..., - random_seed : typing.Optional[int] = ..., - node_name : typing.Optional[typing.Text] = ..., - numa_node_id : typing.Optional[int] = ..., - extra_info : typing.Optional[typing.Iterable[typing.Text]] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"device_id",b"device_id",u"device_type",b"device_type",u"node_name",b"node_name",u"numa_node_id",b"numa_node_id",u"random_seed",b"random_seed"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"device_id",b"device_id",u"device_type",b"device_type",u"extra_info",b"extra_info",u"node_name",b"node_name",u"numa_node_id",b"numa_node_id",u"random_seed",b"random_seed"]) -> None: ... -global___DeviceOption = DeviceOption - -class OperatorDef(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - INPUT_FIELD_NUMBER: int - OUTPUT_FIELD_NUMBER: int - NAME_FIELD_NUMBER: int - TYPE_FIELD_NUMBER: int - ARG_FIELD_NUMBER: int - DEVICE_OPTION_FIELD_NUMBER: int - ENGINE_FIELD_NUMBER: int - CONTROL_INPUT_FIELD_NUMBER: int - IS_GRADIENT_OP_FIELD_NUMBER: int - DEBUG_INFO_FIELD_NUMBER: int - DOMAIN_FIELD_NUMBER: int - OP_VERSION_FIELD_NUMBER: int - input: google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text] = ... - output: google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text] = ... - name: typing.Text = ... - type: typing.Text = ... - engine: typing.Text = ... - control_input: google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text] = ... - is_gradient_op: bool = ... - debug_info: typing.Text = ... - domain: typing.Text = ... - op_version: int = ... - - @property - def arg(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Argument]: ... - - @property - def device_option(self) -> global___DeviceOption: ... - - def __init__(self, - *, - input : typing.Optional[typing.Iterable[typing.Text]] = ..., - output : typing.Optional[typing.Iterable[typing.Text]] = ..., - name : typing.Optional[typing.Text] = ..., - type : typing.Optional[typing.Text] = ..., - arg : typing.Optional[typing.Iterable[global___Argument]] = ..., - device_option : typing.Optional[global___DeviceOption] = ..., - engine : typing.Optional[typing.Text] = ..., - control_input : typing.Optional[typing.Iterable[typing.Text]] = ..., - is_gradient_op : typing.Optional[bool] = ..., - debug_info : typing.Optional[typing.Text] = ..., - domain : typing.Optional[typing.Text] = ..., - op_version : typing.Optional[int] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"debug_info",b"debug_info",u"device_option",b"device_option",u"domain",b"domain",u"engine",b"engine",u"is_gradient_op",b"is_gradient_op",u"name",b"name",u"op_version",b"op_version",u"type",b"type"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"arg",b"arg",u"control_input",b"control_input",u"debug_info",b"debug_info",u"device_option",b"device_option",u"domain",b"domain",u"engine",b"engine",u"input",b"input",u"is_gradient_op",b"is_gradient_op",u"name",b"name",u"op_version",b"op_version",u"output",b"output",u"type",b"type"]) -> None: ... -global___OperatorDef = OperatorDef - -class MapFieldEntry(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - KEY_FIELD_NUMBER: int - VAL_FIELD_NUMBER: int - key: typing.Text = ... - val: typing.Text = ... - - def __init__(self, - *, - key : typing.Optional[typing.Text] = ..., - val : typing.Optional[typing.Text] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"key",b"key",u"val",b"val"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"key",b"key",u"val",b"val"]) -> None: ... -global___MapFieldEntry = MapFieldEntry - -class BackendOptions(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - BACKEND_NAME_FIELD_NUMBER: int - OPTION_FIELD_NUMBER: int - backend_name: typing.Text = ... - - @property - def option(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___MapFieldEntry]: ... - - def __init__(self, - *, - backend_name : typing.Optional[typing.Text] = ..., - option : typing.Optional[typing.Iterable[global___MapFieldEntry]] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"backend_name",b"backend_name"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"backend_name",b"backend_name",u"option",b"option"]) -> None: ... -global___BackendOptions = BackendOptions - -class PartitionInfo(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - NAME_FIELD_NUMBER: int - DEVICE_ID_FIELD_NUMBER: int - EXTRA_INFO_FIELD_NUMBER: int - BACKEND_OPTIONS_FIELD_NUMBER: int - name: typing.Text = ... - device_id: google.protobuf.internal.containers.RepeatedScalarFieldContainer[int] = ... - extra_info: typing.Text = ... - - @property - def backend_options(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___BackendOptions]: ... - - def __init__(self, - *, - name : typing.Optional[typing.Text] = ..., - device_id : typing.Optional[typing.Iterable[int]] = ..., - extra_info : typing.Optional[typing.Text] = ..., - backend_options : typing.Optional[typing.Iterable[global___BackendOptions]] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"extra_info",b"extra_info",u"name",b"name"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"backend_options",b"backend_options",u"device_id",b"device_id",u"extra_info",b"extra_info",u"name",b"name"]) -> None: ... -global___PartitionInfo = PartitionInfo - -class NetDef(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - NAME_FIELD_NUMBER: int - OP_FIELD_NUMBER: int - TYPE_FIELD_NUMBER: int - NUM_WORKERS_FIELD_NUMBER: int - DEVICE_OPTION_FIELD_NUMBER: int - ARG_FIELD_NUMBER: int - EXTERNAL_INPUT_FIELD_NUMBER: int - EXTERNAL_OUTPUT_FIELD_NUMBER: int - PARTITION_INFO_FIELD_NUMBER: int - name: typing.Text = ... - type: typing.Text = ... - num_workers: int = ... - external_input: google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text] = ... - external_output: google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text] = ... - - @property - def op(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___OperatorDef]: ... - - @property - def device_option(self) -> global___DeviceOption: ... - - @property - def arg(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Argument]: ... - - @property - def partition_info(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___PartitionInfo]: ... - - def __init__(self, - *, - name : typing.Optional[typing.Text] = ..., - op : typing.Optional[typing.Iterable[global___OperatorDef]] = ..., - type : typing.Optional[typing.Text] = ..., - num_workers : typing.Optional[int] = ..., - device_option : typing.Optional[global___DeviceOption] = ..., - arg : typing.Optional[typing.Iterable[global___Argument]] = ..., - external_input : typing.Optional[typing.Iterable[typing.Text]] = ..., - external_output : typing.Optional[typing.Iterable[typing.Text]] = ..., - partition_info : typing.Optional[typing.Iterable[global___PartitionInfo]] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"device_option",b"device_option",u"name",b"name",u"num_workers",b"num_workers",u"type",b"type"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"arg",b"arg",u"device_option",b"device_option",u"external_input",b"external_input",u"external_output",b"external_output",u"name",b"name",u"num_workers",b"num_workers",u"op",b"op",u"partition_info",b"partition_info",u"type",b"type"]) -> None: ... -global___NetDef = NetDef - -class ExecutionStep(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - NAME_FIELD_NUMBER: int - SUBSTEP_FIELD_NUMBER: int - NETWORK_FIELD_NUMBER: int - NUM_ITER_FIELD_NUMBER: int - CRITERIA_NETWORK_FIELD_NUMBER: int - REPORT_NET_FIELD_NUMBER: int - REPORT_INTERVAL_FIELD_NUMBER: int - RUN_EVERY_MS_FIELD_NUMBER: int - CONCURRENT_SUBSTEPS_FIELD_NUMBER: int - SHOULD_STOP_BLOB_FIELD_NUMBER: int - ONLY_ONCE_FIELD_NUMBER: int - CREATE_WORKSPACE_FIELD_NUMBER: int - NUM_CONCURRENT_INSTANCES_FIELD_NUMBER: int - name: typing.Text = ... - network: google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text] = ... - num_iter: int = ... - criteria_network: typing.Text = ... - report_net: typing.Text = ... - report_interval: int = ... - run_every_ms: int = ... - concurrent_substeps: bool = ... - should_stop_blob: typing.Text = ... - only_once: bool = ... - create_workspace: bool = ... - num_concurrent_instances: int = ... - - @property - def substep(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ExecutionStep]: ... - - def __init__(self, - *, - name : typing.Optional[typing.Text] = ..., - substep : typing.Optional[typing.Iterable[global___ExecutionStep]] = ..., - network : typing.Optional[typing.Iterable[typing.Text]] = ..., - num_iter : typing.Optional[int] = ..., - criteria_network : typing.Optional[typing.Text] = ..., - report_net : typing.Optional[typing.Text] = ..., - report_interval : typing.Optional[int] = ..., - run_every_ms : typing.Optional[int] = ..., - concurrent_substeps : typing.Optional[bool] = ..., - should_stop_blob : typing.Optional[typing.Text] = ..., - only_once : typing.Optional[bool] = ..., - create_workspace : typing.Optional[bool] = ..., - num_concurrent_instances : typing.Optional[int] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"concurrent_substeps",b"concurrent_substeps",u"create_workspace",b"create_workspace",u"criteria_network",b"criteria_network",u"name",b"name",u"num_concurrent_instances",b"num_concurrent_instances",u"num_iter",b"num_iter",u"only_once",b"only_once",u"report_interval",b"report_interval",u"report_net",b"report_net",u"run_every_ms",b"run_every_ms",u"should_stop_blob",b"should_stop_blob"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"concurrent_substeps",b"concurrent_substeps",u"create_workspace",b"create_workspace",u"criteria_network",b"criteria_network",u"name",b"name",u"network",b"network",u"num_concurrent_instances",b"num_concurrent_instances",u"num_iter",b"num_iter",u"only_once",b"only_once",u"report_interval",b"report_interval",u"report_net",b"report_net",u"run_every_ms",b"run_every_ms",u"should_stop_blob",b"should_stop_blob",u"substep",b"substep"]) -> None: ... -global___ExecutionStep = ExecutionStep - -class PlanDef(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - NAME_FIELD_NUMBER: int - NETWORK_FIELD_NUMBER: int - EXECUTION_STEP_FIELD_NUMBER: int - name: typing.Text = ... - - @property - def network(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___NetDef]: ... - - @property - def execution_step(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ExecutionStep]: ... - - def __init__(self, - *, - name : typing.Optional[typing.Text] = ..., - network : typing.Optional[typing.Iterable[global___NetDef]] = ..., - execution_step : typing.Optional[typing.Iterable[global___ExecutionStep]] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"name",b"name"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"execution_step",b"execution_step",u"name",b"name",u"network",b"network"]) -> None: ... -global___PlanDef = PlanDef - -class BlobProto(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - NAME_FIELD_NUMBER: int - TYPE_FIELD_NUMBER: int - TENSOR_FIELD_NUMBER: int - CONTENT_FIELD_NUMBER: int - QTENSOR_FIELD_NUMBER: int - CONTENT_NUM_CHUNKS_FIELD_NUMBER: int - CONTENT_CHUNK_ID_FIELD_NUMBER: int - name: typing.Text = ... - type: typing.Text = ... - content: bytes = ... - content_num_chunks: int = ... - content_chunk_id: int = ... - - @property - def tensor(self) -> global___TensorProto: ... - - @property - def qtensor(self) -> global___QTensorProto: ... - - def __init__(self, - *, - name : typing.Optional[typing.Text] = ..., - type : typing.Optional[typing.Text] = ..., - tensor : typing.Optional[global___TensorProto] = ..., - content : typing.Optional[bytes] = ..., - qtensor : typing.Optional[global___QTensorProto] = ..., - content_num_chunks : typing.Optional[int] = ..., - content_chunk_id : typing.Optional[int] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"content",b"content",u"content_chunk_id",b"content_chunk_id",u"content_num_chunks",b"content_num_chunks",u"name",b"name",u"qtensor",b"qtensor",u"tensor",b"tensor",u"type",b"type"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"content",b"content",u"content_chunk_id",b"content_chunk_id",u"content_num_chunks",b"content_num_chunks",u"name",b"name",u"qtensor",b"qtensor",u"tensor",b"tensor",u"type",b"type"]) -> None: ... -global___BlobProto = BlobProto - -class DBReaderProto(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - NAME_FIELD_NUMBER: int - SOURCE_FIELD_NUMBER: int - DB_TYPE_FIELD_NUMBER: int - KEY_FIELD_NUMBER: int - name: typing.Text = ... - source: typing.Text = ... - db_type: typing.Text = ... - key: typing.Text = ... - - def __init__(self, - *, - name : typing.Optional[typing.Text] = ..., - source : typing.Optional[typing.Text] = ..., - db_type : typing.Optional[typing.Text] = ..., - key : typing.Optional[typing.Text] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"db_type",b"db_type",u"key",b"key",u"name",b"name",u"source",b"source"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"db_type",b"db_type",u"key",b"key",u"name",b"name",u"source",b"source"]) -> None: ... -global___DBReaderProto = DBReaderProto - -class BlobSerializationOptions(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - class _FloatFormat(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[FloatFormat], type): - DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor = ... - FLOAT_DEFAULT = BlobSerializationOptions.FloatFormat.V(0) - FLOAT_PROTOBUF = BlobSerializationOptions.FloatFormat.V(1) - FLOAT_BFLOAT16 = BlobSerializationOptions.FloatFormat.V(2) - class FloatFormat(metaclass=_FloatFormat): - V = typing.NewType('V', int) - FLOAT_DEFAULT = BlobSerializationOptions.FloatFormat.V(0) - FLOAT_PROTOBUF = BlobSerializationOptions.FloatFormat.V(1) - FLOAT_BFLOAT16 = BlobSerializationOptions.FloatFormat.V(2) - - BLOB_NAME_REGEX_FIELD_NUMBER: int - CHUNK_SIZE_FIELD_NUMBER: int - FLOAT_FORMAT_FIELD_NUMBER: int - blob_name_regex: typing.Text = ... - chunk_size: int = ... - float_format: global___BlobSerializationOptions.FloatFormat = ... - - def __init__(self, - *, - blob_name_regex : typing.Optional[typing.Text] = ..., - chunk_size : typing.Optional[int] = ..., - float_format : typing.Optional[global___BlobSerializationOptions.FloatFormat] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"blob_name_regex",b"blob_name_regex",u"chunk_size",b"chunk_size",u"float_format",b"float_format"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"blob_name_regex",b"blob_name_regex",u"chunk_size",b"chunk_size",u"float_format",b"float_format"]) -> None: ... -global___BlobSerializationOptions = BlobSerializationOptions - -class SerializationOptions(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - OPTIONS_FIELD_NUMBER: int - - @property - def options(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___BlobSerializationOptions]: ... - - def __init__(self, - *, - options : typing.Optional[typing.Iterable[global___BlobSerializationOptions]] = ..., - ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal[u"options",b"options"]) -> None: ... -global___SerializationOptions = SerializationOptions - -DeviceType = int - -# These are freedom-patched into caffe2_pb2 in caffe2/proto/__init__.py -CPU: int = DeviceType.PROTO_CPU -CUDA: int = DeviceType.PROTO_CUDA -MKLDNN: int = DeviceType.PROTO_MKLDNN -OPENGL: int = DeviceType.PROTO_OPENGL -OPENCL: int = DeviceType.PROTO_OPENCL -IDEEP: int = DeviceType.PROTO_IDEEP -HIP: int = DeviceType.PROTO_HIP -COMPILE_TIME_MAX_DEVICE_TYPES: int = DeviceType.PROTO_COMPILE_TIME_MAX_DEVICE_TYPES diff --git a/caffe2/proto/gen_proto_typestubs.sh b/caffe2/proto/gen_proto_typestubs.sh deleted file mode 100755 index 85503936ea1a..000000000000 --- a/caffe2/proto/gen_proto_typestubs.sh +++ /dev/null @@ -1,52 +0,0 @@ -#!/usr/bin/env bash - -# Generate type stubs for .proto definition files. - -# This should be run from as -# ./gen_proto_typestubs.sh -# (i.e., from inside the proto/ directory) - -# assumes mypy-protobuf installed to ~/.local; i.e. via -# pip3 install mypy-protobuf --user - -set -euxo pipefail - -MYPY_PROTOBUF_HOME="${1:-${HOME}/.local/bin}" - -pushd ../../ -buck run fbsource//third-party/protobuf:protoc -- --plugin=protoc-gen-mypy="${MYPY_PROTOBUF_HOME}"/protoc-gen-mypy --mypy_out=./ caffe2/proto/*.proto -popd - -# get rid of 'builtins.' prefix, which pyre does not like -sed -E -i 's/builtins\.//g' ./*.pyi - -# mypy-protobuf references types from other mypy-protobuf-generated stubs as -# 'type.V', but it should just be 'type', so we get rid of the '.V' suffix -# when it's not followed by parens to indicate a particular enum value. -sed -E -i 's/\.V([^(_[:alnum:]])/\1/g' ./*.pyi - -# --------------------------- -# Freedom-patched DeviceTypes -# --------------------------- -# -# In order to make DeviceTypes like CPU, CUDA, etc. directly accessible from -# the caffe2_pb2 module, they are currently freedom-patched into it in -# caffe2/python/__init__.py. This is not ideal: it would be better if these -# were autogenerated when the protobuf definitions were created by using -# allow_alias = true in the DeviceTypeProto definition in caffe2.proto. -# -# However, it is impossible to do this currently without significant effort. -# The issue is that the generated proto constants would conflict with various -# constants defined in the C++ caffe2 codebase (`caffe2_pb2.h`). We cannot -# simply remove these constants and replace them with the caffe2 -# DeviceTypeProto constants, because a huge portion of code expects -# at::DeviceType constants defined in `core/DeviceType.h` (apparently -# duplicated to avoid having to figure out how to autogenerate the protobuf -# definitions using cmake for ATen). -# -# Instead, we make a best-effort to add additional definitions in -# `caffe2_pb2.py` by looking for any freedom-patched constants in -# `caffe2/python/__init__.py` and making sure they have corresponding stubs in -# the pyi (see `gen_proto_typestubs_helper.py`). - -python3 ./gen_proto_typestubs_helper.py >> caffe2_pb2.pyi diff --git a/caffe2/proto/gen_proto_typestubs_helper.py b/caffe2/proto/gen_proto_typestubs_helper.py deleted file mode 100644 index 4ed83f55998f..000000000000 --- a/caffe2/proto/gen_proto_typestubs_helper.py +++ /dev/null @@ -1,15 +0,0 @@ -import ast - -with open("../python/__init__.py", "r") as f: - tree = ast.parse(f.read()) - -print("\nDeviceType = int\n") -print("# These are freedom-patched into caffe2_pb2 in caffe2/proto/__init__.py") -for stmt in tree.body: - if not isinstance(stmt, ast.Assign): - continue - target = stmt.targets[0] - if not isinstance(target, ast.Attribute): - continue - if isinstance(target.value, ast.Name) and target.value.id == "caffe2_pb2": - print(f"{target.attr}: int = DeviceType.PROTO_{target.attr}") diff --git a/caffe2/proto/torch.proto b/caffe2/proto/torch.proto deleted file mode 100644 index 1ac4f5443579..000000000000 --- a/caffe2/proto/torch.proto +++ /dev/null @@ -1,114 +0,0 @@ -syntax = "proto2"; - -import "caffe2/proto/caffe2.proto"; - -package torch; - -message RecordRef { - optional string key = 1; -} - -message TensorDef { - repeated int64 dims = 1; - optional int64 offset = 2; - repeated int64 strides = 3; - // whether we compute the gradient for the parameter - optional bool requires_grad = 4; - optional caffe2.TensorProto.DataType data_type = 5; - - optional RecordRef data = 6; - - // device field stores the canonical device string, and it follows the - // format below: `(cpu|cuda)[:]`, e.g., 'cuda:0' - optional string device = 7; - - optional bool is_quantized = 8; - optional double scale = 9; - optional int64 zero_point = 10; -} - -message AttributeDef { - // The mypy type of this attribute - required string type = 1; - required string name = 2; - - // Offset into attribute table - required int64 id = 3; -} - -message ParameterDef { - // whether this parameter is registered as buffer or not - optional bool is_buffer = 1; - - // the offset into the tensor table where this parameter is stored - optional int64 tensor_id = 2; - - optional string name = 3; -} - -message ModuleDef { - repeated ModuleDef submodules = 1; - - optional RecordRef torchscript_arena = 2; - - repeated caffe2.NetDef caffe2_nets = 3; - - // because the old pickle modules may not be supported by torch_script, - // have to stored as pickle_arena at this moment. - optional RecordRef pickle_arena = 4; - // should be exposed by the Class Archive, so user can save - // module specific data which cannot be store in the graph or torch_script - optional RecordRef cpp_arena = 5; - - // the parameters of this module - repeated ParameterDef parameters = 6; - - // the names of inputs and outputs of the module are inferred - // from the main method. - - optional string name = 7; - - // whether apply the optimizations to this module, only applicable to - // script modules - optional bool optimize = 8; - - repeated AttributeDef attributes = 9; - - // Used for retrieving module state from the pickled IValues table - optional int64 get_state_attribute_id = 10; - - optional RecordRef torchscript_debug_arena = 11; -} - -// Represents all non-module code that the model depends on. -// Right now it's just a straight list of classes, defined in dependency order -// (i.e. dependencies appear before their dependers) -message LibDef { - optional RecordRef torchscript_arena = 1; -} - -enum ProtoVersion { PROTO_VERSION_NEWEST = 0x0000000000000006; } - -message ModelDef { - // numbers of fields that have been removed. Do not reuse them! - reserved 9; - reserved "libs"; - // for the proto version, to keep both backward and forward - // compatibility, please bump the proto_version when we add any - // change in the proto. runtime decides whether accept the - // model based on the ir_version. - optional int64 proto_version = 1; - - // main module of the model - optional ModuleDef main_module = 2; - - // to distinguish whether exported from c2 or torch - optional string producer_name = 3; - - // put build version here - optional string producer_version = 4; - - // the table contains all the tensor information - // the tensor id is defined as TensorProto.name - repeated TensorDef tensors = 5; -} diff --git a/caffe2/proto/torch_pb2.pyi b/caffe2/proto/torch_pb2.pyi deleted file mode 100644 index 33826e2aff5d..000000000000 --- a/caffe2/proto/torch_pb2.pyi +++ /dev/null @@ -1,218 +0,0 @@ -""" -@generated by mypy-protobuf. Do not edit manually! -isort:skip_file -""" -import builtins -import caffe2.proto.caffe2_pb2 -import google.protobuf.descriptor -import google.protobuf.internal.containers -import google.protobuf.internal.enum_type_wrapper -import google.protobuf.message -import typing -import typing_extensions - -DESCRIPTOR: google.protobuf.descriptor.FileDescriptor = ... - -global___ProtoVersion = ProtoVersion -class _ProtoVersion(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[ProtoVersion], type): - DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor = ... - PROTO_VERSION_NEWEST = ProtoVersion.V(6) -class ProtoVersion(metaclass=_ProtoVersion): - V = typing.NewType('V', int) -PROTO_VERSION_NEWEST = ProtoVersion.V(6) - -class RecordRef(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - KEY_FIELD_NUMBER: int - key: typing.Text = ... - - def __init__(self, - *, - key : typing.Optional[typing.Text] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"key",b"key"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"key",b"key"]) -> None: ... -global___RecordRef = RecordRef - -class TensorDef(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - DIMS_FIELD_NUMBER: int - OFFSET_FIELD_NUMBER: int - STRIDES_FIELD_NUMBER: int - REQUIRES_GRAD_FIELD_NUMBER: int - DATA_TYPE_FIELD_NUMBER: int - DATA_FIELD_NUMBER: int - DEVICE_FIELD_NUMBER: int - IS_QUANTIZED_FIELD_NUMBER: int - SCALE_FIELD_NUMBER: int - ZERO_POINT_FIELD_NUMBER: int - dims: google.protobuf.internal.containers.RepeatedScalarFieldContainer[int] = ... - offset: int = ... - strides: google.protobuf.internal.containers.RepeatedScalarFieldContainer[int] = ... - requires_grad: bool = ... - data_type: caffe2.proto.caffe2_pb2.TensorProto.DataType = ... - device: typing.Text = ... - is_quantized: bool = ... - scale: float = ... - zero_point: int = ... - - @property - def data(self) -> global___RecordRef: ... - - def __init__(self, - *, - dims : typing.Optional[typing.Iterable[int]] = ..., - offset : typing.Optional[int] = ..., - strides : typing.Optional[typing.Iterable[int]] = ..., - requires_grad : typing.Optional[bool] = ..., - data_type : typing.Optional[caffe2.proto.caffe2_pb2.TensorProto.DataType] = ..., - data : typing.Optional[global___RecordRef] = ..., - device : typing.Optional[typing.Text] = ..., - is_quantized : typing.Optional[bool] = ..., - scale : typing.Optional[float] = ..., - zero_point : typing.Optional[int] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"data",b"data",u"data_type",b"data_type",u"device",b"device",u"is_quantized",b"is_quantized",u"offset",b"offset",u"requires_grad",b"requires_grad",u"scale",b"scale",u"zero_point",b"zero_point"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"data",b"data",u"data_type",b"data_type",u"device",b"device",u"dims",b"dims",u"is_quantized",b"is_quantized",u"offset",b"offset",u"requires_grad",b"requires_grad",u"scale",b"scale",u"strides",b"strides",u"zero_point",b"zero_point"]) -> None: ... -global___TensorDef = TensorDef - -class AttributeDef(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - TYPE_FIELD_NUMBER: int - NAME_FIELD_NUMBER: int - ID_FIELD_NUMBER: int - type: typing.Text = ... - name: typing.Text = ... - id: int = ... - - def __init__(self, - *, - type : typing.Optional[typing.Text] = ..., - name : typing.Optional[typing.Text] = ..., - id : typing.Optional[int] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"id",b"id",u"name",b"name",u"type",b"type"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"id",b"id",u"name",b"name",u"type",b"type"]) -> None: ... -global___AttributeDef = AttributeDef - -class ParameterDef(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - IS_BUFFER_FIELD_NUMBER: int - TENSOR_ID_FIELD_NUMBER: int - NAME_FIELD_NUMBER: int - is_buffer: bool = ... - tensor_id: int = ... - name: typing.Text = ... - - def __init__(self, - *, - is_buffer : typing.Optional[bool] = ..., - tensor_id : typing.Optional[int] = ..., - name : typing.Optional[typing.Text] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"is_buffer",b"is_buffer",u"name",b"name",u"tensor_id",b"tensor_id"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"is_buffer",b"is_buffer",u"name",b"name",u"tensor_id",b"tensor_id"]) -> None: ... -global___ParameterDef = ParameterDef - -class ModuleDef(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - SUBMODULES_FIELD_NUMBER: int - TORCHSCRIPT_ARENA_FIELD_NUMBER: int - CAFFE2_NETS_FIELD_NUMBER: int - PICKLE_ARENA_FIELD_NUMBER: int - CPP_ARENA_FIELD_NUMBER: int - PARAMETERS_FIELD_NUMBER: int - NAME_FIELD_NUMBER: int - OPTIMIZE_FIELD_NUMBER: int - ATTRIBUTES_FIELD_NUMBER: int - GET_STATE_ATTRIBUTE_ID_FIELD_NUMBER: int - TORCHSCRIPT_DEBUG_ARENA_FIELD_NUMBER: int - name: typing.Text = ... - optimize: bool = ... - get_state_attribute_id: int = ... - - @property - def submodules(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ModuleDef]: ... - - @property - def torchscript_arena(self) -> global___RecordRef: ... - - @property - def caffe2_nets(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[caffe2.proto.caffe2_pb2.NetDef]: ... - - @property - def pickle_arena(self) -> global___RecordRef: ... - - @property - def cpp_arena(self) -> global___RecordRef: ... - - @property - def parameters(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ParameterDef]: ... - - @property - def attributes(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___AttributeDef]: ... - - @property - def torchscript_debug_arena(self) -> global___RecordRef: ... - - def __init__(self, - *, - submodules : typing.Optional[typing.Iterable[global___ModuleDef]] = ..., - torchscript_arena : typing.Optional[global___RecordRef] = ..., - caffe2_nets : typing.Optional[typing.Iterable[caffe2.proto.caffe2_pb2.NetDef]] = ..., - pickle_arena : typing.Optional[global___RecordRef] = ..., - cpp_arena : typing.Optional[global___RecordRef] = ..., - parameters : typing.Optional[typing.Iterable[global___ParameterDef]] = ..., - name : typing.Optional[typing.Text] = ..., - optimize : typing.Optional[bool] = ..., - attributes : typing.Optional[typing.Iterable[global___AttributeDef]] = ..., - get_state_attribute_id : typing.Optional[int] = ..., - torchscript_debug_arena : typing.Optional[global___RecordRef] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"cpp_arena",b"cpp_arena",u"get_state_attribute_id",b"get_state_attribute_id",u"name",b"name",u"optimize",b"optimize",u"pickle_arena",b"pickle_arena",u"torchscript_arena",b"torchscript_arena",u"torchscript_debug_arena",b"torchscript_debug_arena"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"attributes",b"attributes",u"caffe2_nets",b"caffe2_nets",u"cpp_arena",b"cpp_arena",u"get_state_attribute_id",b"get_state_attribute_id",u"name",b"name",u"optimize",b"optimize",u"parameters",b"parameters",u"pickle_arena",b"pickle_arena",u"submodules",b"submodules",u"torchscript_arena",b"torchscript_arena",u"torchscript_debug_arena",b"torchscript_debug_arena"]) -> None: ... -global___ModuleDef = ModuleDef - -class LibDef(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - TORCHSCRIPT_ARENA_FIELD_NUMBER: int - - @property - def torchscript_arena(self) -> global___RecordRef: ... - - def __init__(self, - *, - torchscript_arena : typing.Optional[global___RecordRef] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"torchscript_arena",b"torchscript_arena"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"torchscript_arena",b"torchscript_arena"]) -> None: ... -global___LibDef = LibDef - -class ModelDef(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - PROTO_VERSION_FIELD_NUMBER: int - MAIN_MODULE_FIELD_NUMBER: int - PRODUCER_NAME_FIELD_NUMBER: int - PRODUCER_VERSION_FIELD_NUMBER: int - TENSORS_FIELD_NUMBER: int - proto_version: int = ... - producer_name: typing.Text = ... - producer_version: typing.Text = ... - - @property - def main_module(self) -> global___ModuleDef: ... - - @property - def tensors(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___TensorDef]: ... - - def __init__(self, - *, - proto_version : typing.Optional[int] = ..., - main_module : typing.Optional[global___ModuleDef] = ..., - producer_name : typing.Optional[typing.Text] = ..., - producer_version : typing.Optional[typing.Text] = ..., - tensors : typing.Optional[typing.Iterable[global___TensorDef]] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"main_module",b"main_module",u"producer_name",b"producer_name",u"producer_version",b"producer_version",u"proto_version",b"proto_version"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"main_module",b"main_module",u"producer_name",b"producer_name",u"producer_version",b"producer_version",u"proto_version",b"proto_version",u"tensors",b"tensors"]) -> None: ... -global___ModelDef = ModelDef diff --git a/caffe2/utils/cpuid.cc b/caffe2/utils/cpuid.cc deleted file mode 100644 index 2ba1d2dd8840..000000000000 --- a/caffe2/utils/cpuid.cc +++ /dev/null @@ -1,83 +0,0 @@ -#include "caffe2/utils/cpuid.h" - -namespace caffe2 { - -const CpuId& GetCpuId() { - static CpuId cpuid_singleton; - return cpuid_singleton; -} - -TORCH_API uint32_t CpuId::f1c_ = 0; -TORCH_API uint32_t CpuId::f1d_ = 0; -TORCH_API uint32_t CpuId::f7b_ = 0; -TORCH_API uint32_t CpuId::f7c_ = 0; - -CpuId::CpuId() { -#ifdef _MSC_VER - int reg[4]; - __cpuid(static_cast(reg), 0); - const int n = reg[0]; - if (n >= 1) { - __cpuid(static_cast(reg), 1); - f1c_ = uint32_t(reg[2]); - f1d_ = uint32_t(reg[3]); - } - if (n >= 7) { - __cpuidex(static_cast(reg), 7, 0); - f7b_ = uint32_t(reg[1]); - f7c_ = uint32_t(reg[2]); - } -#elif defined(__i386__) && defined(__PIC__) && !defined(__clang__) && \ - defined(__GNUC__) - // The following block like the normal cpuid branch below, but gcc - // reserves ebx for use of its pic register so we must specially - // handle the save and restore to avoid clobbering the register - uint32_t n; - __asm__( - "pushl %%ebx\n\t" - "cpuid\n\t" - "popl %%ebx\n\t" - : "=a"(n) - : "a"(0) - : "ecx", "edx"); - if (n >= 1) { - uint32_t f1a; - __asm__( - "pushl %%ebx\n\t" - "cpuid\n\t" - "popl %%ebx\n\t" - : "=a"(f1a), "=c"(f1c_), "=d"(f1d_) - : "a"(1) - :); - } - if (n >= 7) { - __asm__( - "pushl %%ebx\n\t" - "cpuid\n\t" - "movl %%ebx, %%eax\n\r" - "popl %%ebx" - : "=a"(f7b_), "=c"(f7c_) - : "a"(7), "c"(0) - : "edx"); - } -#elif defined(__x86_64__) || defined(_M_X64) || defined(__i386__) - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - uint32_t n; - __asm__("cpuid" : "=a"(n) : "a"(0) : "ebx", "ecx", "edx"); - if (n >= 1) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - uint32_t f1a; - __asm__("cpuid" : "=a"(f1a), "=c"(f1c_), "=d"(f1d_) : "a"(1) : "ebx"); - } - if (n >= 7) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - uint32_t f7a; - __asm__("cpuid" - : "=a"(f7a), "=b"(f7b_), "=c"(f7c_) - : "a"(7), "c"(0) - : "edx"); - } -#endif -} - -} // namespace caffe2 diff --git a/caffe2/utils/cpuid.h b/caffe2/utils/cpuid.h deleted file mode 100644 index 2cac7637ba32..000000000000 --- a/caffe2/utils/cpuid.h +++ /dev/null @@ -1,146 +0,0 @@ -#pragma once - -#include - -#ifdef _MSC_VER -#include -#endif - -#include - -namespace caffe2 { - -class CpuId; - -TORCH_API const CpuId& GetCpuId(); - -/////////////////////////////////////////////////////////////////////////////// -// Implementation of CpuId that is borrowed from folly. -/////////////////////////////////////////////////////////////////////////////// - -// TODO: It might be good to use cpuinfo third-party dependency instead for -// consistency sake. - -/** - * Identification of an Intel CPU. - * Supports CPUID feature flags (EAX=1) and extended features (EAX=7, ECX=0). - * Values from - * http://www.intel.com/content/www/us/en/processors/processor-identification-cpuid-instruction-note.html - */ -class CpuId { - public: - CpuId(); - -#define X(name, r, bit) \ - inline bool name() const { \ - return ((r) & (1U << bit)) != 0; \ - } - -// cpuid(1): Processor Info and Feature Bits. -#define C(name, bit) X(name, f1c_, bit) - C(sse3, 0) - C(pclmuldq, 1) - C(dtes64, 2) - C(monitor, 3) - C(dscpl, 4) - C(vmx, 5) - C(smx, 6) - C(eist, 7) - C(tm2, 8) - C(ssse3, 9) - C(cnxtid, 10) - C(fma, 12) - C(cx16, 13) - C(xtpr, 14) - C(pdcm, 15) - C(pcid, 17) - C(dca, 18) - C(sse41, 19) - C(sse42, 20) - C(x2apic, 21) - C(movbe, 22) - C(popcnt, 23) - C(tscdeadline, 24) - C(aes, 25) - C(xsave, 26) - C(osxsave, 27) - C(avx, 28) - C(f16c, 29) - C(rdrand, 30) -#undef C - -#define D(name, bit) X(name, f1d_, bit) - D(fpu, 0) - D(vme, 1) - D(de, 2) - D(pse, 3) - D(tsc, 4) - D(msr, 5) - D(pae, 6) - D(mce, 7) - D(cx8, 8) - D(apic, 9) - D(sep, 11) - D(mtrr, 12) - D(pge, 13) - D(mca, 14) - D(cmov, 15) - D(pat, 16) - D(pse36, 17) - D(psn, 18) - D(clfsh, 19) - D(ds, 21) - D(acpi, 22) - D(mmx, 23) - D(fxsr, 24) - D(sse, 25) - D(sse2, 26) - D(ss, 27) - D(htt, 28) - D(tm, 29) - D(pbe, 31) -#undef D - -// cpuid(7): Extended Features. -#define B(name, bit) X(name, f7b_, bit) - B(bmi1, 3) - B(hle, 4) - B(avx2, 5) - B(smep, 7) - B(bmi2, 8) - B(erms, 9) - B(invpcid, 10) - B(rtm, 11) - B(mpx, 14) - B(avx512f, 16) - B(avx512dq, 17) - B(rdseed, 18) - B(adx, 19) - B(smap, 20) - B(avx512ifma, 21) - B(pcommit, 22) - B(clflushopt, 23) - B(clwb, 24) - B(avx512pf, 26) - B(avx512er, 27) - B(avx512cd, 28) - B(sha, 29) - B(avx512bw, 30) - B(avx512vl, 31) -#undef B - -#define E(name, bit) X(name, f7c_, bit) - E(prefetchwt1, 0) - E(avx512vbmi, 1) -#undef E - -#undef X - - private: - TORCH_API static uint32_t f1c_; - TORCH_API static uint32_t f1d_; - TORCH_API static uint32_t f7b_; - TORCH_API static uint32_t f7c_; -}; - -} // namespace caffe2 diff --git a/caffe2/utils/threadpool/ThreadPool.cc b/caffe2/utils/threadpool/ThreadPool.cc index 27ade275672d..298fbe9ef4fa 100644 --- a/caffe2/utils/threadpool/ThreadPool.cc +++ b/caffe2/utils/threadpool/ThreadPool.cc @@ -1,6 +1,5 @@ #include "caffe2/utils/threadpool/ThreadPool.h" #include "WorkersPool.h" -#include "caffe2/core/logging.h" #if !defined(__s390x__) && !defined(__powerpc__) #include diff --git a/caffe2/utils/threadpool/WorkersPool.h b/caffe2/utils/threadpool/WorkersPool.h index b6bbc60f2099..23a72b02465e 100644 --- a/caffe2/utils/threadpool/WorkersPool.h +++ b/caffe2/utils/threadpool/WorkersPool.h @@ -5,8 +5,7 @@ #include #include "c10/util/thread_name.h" #include -#include "caffe2/core/common.h" -#include "caffe2/core/logging.h" +#include #if defined(_MSC_VER) #include diff --git a/caffe2/utils/threadpool/pthreadpool.cc b/caffe2/utils/threadpool/pthreadpool.cc index 44c758db5cb1..b8c6c7cebb8e 100644 --- a/caffe2/utils/threadpool/pthreadpool.cc +++ b/caffe2/utils/threadpool/pthreadpool.cc @@ -4,6 +4,7 @@ #include #include #include +#include #ifdef _MSC_VER #include @@ -14,10 +15,10 @@ #endif /* Library header */ -#include "caffe2/core/logging.h" #include "caffe2/utils/fixed_divisor.h" #include "caffe2/utils/threadpool/pthreadpool.h" +#include static inline size_t divide_round_up(size_t dividend, size_t divisor) { if (dividend % divisor == 0) { diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index c854baf286e8..fa62688c7e86 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -296,7 +296,6 @@ endif() add_library(torch_python SHARED ${TORCH_PYTHON_SRCS}) -add_dependencies(torch_python Caffe2_PROTO) add_dependencies(torch_python onnx_proto) # Avoid numpy for the DEPLOY build if(USE_NUMPY) diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp index 193675672f6b..a8ff315c9173 100644 --- a/torch/csrc/jit/runtime/static/impl.cpp +++ b/torch/csrc/jit/runtime/static/impl.cpp @@ -8,7 +8,6 @@ #include #include #include -#include #include #include #include From c209fbdc5390940d7da6e07f551a279b6e7d0e71 Mon Sep 17 00:00:00 2001 From: IvanKobzarev Date: Mon, 3 Jun 2024 09:34:39 -0700 Subject: [PATCH 318/706] [inductor] Fix missing unbacked def for unbacked in input expr (#127770) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127770 Approved by: https://github.com/ezyang --- torch/_inductor/scheduler.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 05e05c7d950d..f17fb1f12daa 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -1591,8 +1591,9 @@ def add_user( # generate a dependency because if we do, Inductor will start trying # to free the unbacked int but that's pointless for name, val in V.graph.graph_inputs.items(): - if isinstance(val, sympy.Symbol): - unbacked_symbol_to_origin_node[val] = None + if isinstance(val, sympy.Expr): + for fs in val.free_symbols: + unbacked_symbol_to_origin_node[fs] = None for node in self.nodes: log.debug("scheduling %s", node.node) From 91461601b6f262a60141061430b2066f42dd92ec Mon Sep 17 00:00:00 2001 From: Valeriu Date: Tue, 4 Jun 2024 14:44:43 +0000 Subject: [PATCH 319/706] [TORCH_FA2_flash_api] Update total_q to the reshaped query 0th dimension (#127524) There is a difference (&bug) between the TORCH_FA2_flash_api:**mha_varlen_fwd** and FA2_flash_api:**mha_varlen_fwd** at the query transposition (GQA) step. ``` at::Tensor temp_q = q; if (seqlenq_ngroups_swapped) { temp_q = q.reshape( ... ... } const int total_q = q.sizes()[0]; CHECK_SHAPE(temp_q, total_q, num_heads, head_size_og); ``` When doing query transposition we need to update total_q to the reshaped query 0th dimension, i.e: ``` const int total_q = temp_q.sizes()[0]; ``` In the original FA2_flash_api:**mha_varlen_fwd** they dont introduce a new variable temp_q but overwrite the q value directly. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127524 Approved by: https://github.com/drisspg --- aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp index 9eb3958bf569..24ba7e1343b1 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp @@ -602,7 +602,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q cu_seqlens_q_d = nullptr; } - const int total_q = q.sizes()[0]; + const int total_q = temp_q.sizes()[0]; TORCH_CHECK(batch_size > 0, "batch size must be positive"); TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); From f8c6d43524c01d968c92e094f7566f34d93c31fb Mon Sep 17 00:00:00 2001 From: cyy Date: Tue, 4 Jun 2024 15:12:45 +0000 Subject: [PATCH 320/706] Concat namespaces and other fixes in torch/csrc/utils (#127833) It contains formatting and other minor fixes. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127833 Approved by: https://github.com/ezyang --- torch/csrc/utils/byte_order.h | 6 ++---- torch/csrc/utils/init.h | 6 ++---- torch/csrc/utils/nested.h | 6 ++---- torch/csrc/utils/out_types.h | 4 +--- torch/csrc/utils/pybind.h | 2 +- torch/csrc/utils/python_arg_parser.h | 5 ++--- torch/csrc/utils/python_dispatch.cpp | 4 ++-- torch/csrc/utils/python_dispatch.h | 8 ++------ torch/csrc/utils/python_numbers.h | 6 ++---- torch/csrc/utils/python_raii.h | 18 ++++++++---------- torch/csrc/utils/python_scalars.h | 6 ++---- torch/csrc/utils/python_torch_function_mode.h | 6 ++---- torch/csrc/utils/schema_info.h | 6 ++---- torch/csrc/utils/structseq.h | 4 +--- torch/csrc/utils/tensor_apply.cpp | 6 ++---- torch/csrc/utils/tensor_apply.h | 6 ++---- torch/csrc/utils/tensor_flatten.cpp | 6 ++---- torch/csrc/utils/tensor_flatten.h | 6 ++---- torch/csrc/utils/tensor_layouts.cpp | 6 ++---- torch/csrc/utils/tensor_layouts.h | 4 +--- torch/csrc/utils/tensor_list.cpp | 6 ++---- torch/csrc/utils/tensor_list.h | 4 +--- torch/csrc/utils/tensor_memoryformats.cpp | 6 ++---- torch/csrc/utils/tensor_new.h | 6 ++---- torch/csrc/utils/tensor_numpy.cpp | 8 ++++---- torch/csrc/utils/tensor_qschemes.cpp | 11 ++++------- torch/csrc/utils/tensor_qschemes.h | 6 ++---- torch/csrc/utils/tensor_types.h | 6 ++---- torch/csrc/utils/throughput_benchmark-inl.h | 8 ++------ torch/csrc/utils/throughput_benchmark.h | 6 ++---- torch/csrc/utils/torch_dispatch_mode.h | 6 ++---- torch/csrc/utils/variadic.h | 2 -- 32 files changed, 67 insertions(+), 129 deletions(-) diff --git a/torch/csrc/utils/byte_order.h b/torch/csrc/utils/byte_order.h index d960b287e20f..87c1f8837239 100644 --- a/torch/csrc/utils/byte_order.h +++ b/torch/csrc/utils/byte_order.h @@ -62,8 +62,7 @@ #error Unexpected or undefined __BYTE_ORDER__ #endif -namespace torch { -namespace utils { +namespace torch::utils { enum THPByteOrder { THP_LITTLE_ENDIAN = 0, THP_BIG_ENDIAN = 1 }; @@ -223,5 +222,4 @@ TORCH_API void THP_encodeComplexDoubleBuffer( THPByteOrder order, size_t len); -} // namespace utils -} // namespace torch +} // namespace torch::utils diff --git a/torch/csrc/utils/init.h b/torch/csrc/utils/init.h index bf6dd216bbcc..31b65470c18e 100644 --- a/torch/csrc/utils/init.h +++ b/torch/csrc/utils/init.h @@ -2,10 +2,8 @@ #include -namespace torch { -namespace throughput_benchmark { +namespace torch::throughput_benchmark { void initThroughputBenchmarkBindings(PyObject* module); -} // namespace throughput_benchmark -} // namespace torch +} // namespace torch::throughput_benchmark diff --git a/torch/csrc/utils/nested.h b/torch/csrc/utils/nested.h index f3a1061e4712..7683a2412418 100644 --- a/torch/csrc/utils/nested.h +++ b/torch/csrc/utils/nested.h @@ -5,13 +5,11 @@ #include -namespace torch { -namespace utils { +namespace torch::utils { at::Tensor nested_tensor_ctor( c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PythonArgs& r); -} // namespace utils -} // namespace torch +} // namespace torch::utils diff --git a/torch/csrc/utils/out_types.h b/torch/csrc/utils/out_types.h index 68bf759f3003..63d85dc8f5a9 100644 --- a/torch/csrc/utils/out_types.h +++ b/torch/csrc/utils/out_types.h @@ -2,8 +2,7 @@ #include -namespace torch { -namespace utils { +namespace torch::utils { TORCH_API void check_out_type_matches( const at::Tensor& result, @@ -14,4 +13,3 @@ TORCH_API void check_out_type_matches( bool device_is_none); } -} // namespace torch diff --git a/torch/csrc/utils/pybind.h b/torch/csrc/utils/pybind.h index 19874d2e29b2..a222feeaa22d 100644 --- a/torch/csrc/utils/pybind.h +++ b/torch/csrc/utils/pybind.h @@ -255,7 +255,7 @@ template <> struct type_caster : public type_caster_base { using base = type_caster_base; - c10::DispatchKey tmp; + c10::DispatchKey tmp{}; public: bool load(handle src, bool convert) { diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h index a7c53bfb0ad1..8966131f9825 100644 --- a/torch/csrc/utils/python_arg_parser.h +++ b/torch/csrc/utils/python_arg_parser.h @@ -77,8 +77,6 @@ #include #include #include -#include -#include #include #include @@ -224,6 +222,7 @@ struct PythonArgs { int idx; bool traceable; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) const FunctionSignature& signature; PyObject** args; std::vector overloaded_args; // NOTE: borrowed references @@ -504,7 +503,7 @@ inline std::vector PythonArgs::intlist(int i) { return intlistWithDefault(i, signature.params[i].default_intlist); } -inline PyObject* toPyObject(c10::SymInt symint) { +inline PyObject* toPyObject(const c10::SymInt& symint) { if (symint.is_symbolic()) { auto r = py::cast(symint).release().ptr(); TORCH_INTERNAL_ASSERT(r); diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp index e370923b398d..ec0af99842d2 100644 --- a/torch/csrc/utils/python_dispatch.cpp +++ b/torch/csrc/utils/python_dispatch.cpp @@ -255,7 +255,7 @@ void initDispatchBindings(PyObject* module) { .def("debug", &c10::OperatorHandle::debug) .def( "redispatch_boxed", - [](py::object self, + [](const py::object& self, c10::DispatchKeySet keyset, py::args args, const py::kwargs& kwargs) { @@ -819,7 +819,7 @@ void initDispatchBindings(PyObject* module) { auto op_names = c10::Dispatcher::singleton().getRegistrationsForDispatchKey(k); for (auto& op : op_names) { - std::cout << op << std::endl; + std::cout << op << '\n'; } }, py::arg("dispatch_key") = static_cast("")); diff --git a/torch/csrc/utils/python_dispatch.h b/torch/csrc/utils/python_dispatch.h index 9549b817ba6a..32d436d8347e 100644 --- a/torch/csrc/utils/python_dispatch.h +++ b/torch/csrc/utils/python_dispatch.h @@ -1,9 +1,7 @@ #include #include -namespace torch { -namespace impl { -namespace dispatch { +namespace torch::impl::dispatch { void initDispatchBindings(PyObject* module); @@ -14,6 +12,4 @@ void python_op_registration_trampoline_impl( torch::jit::Stack* stack, bool with_keyset); -} // namespace dispatch -} // namespace impl -} // namespace torch +} // namespace torch::impl::dispatch diff --git a/torch/csrc/utils/python_numbers.h b/torch/csrc/utils/python_numbers.h index 2a17afdf0e18..d5b772b768e2 100644 --- a/torch/csrc/utils/python_numbers.h +++ b/torch/csrc/utils/python_numbers.h @@ -57,8 +57,7 @@ inline bool THPUtils_checkLong(PyObject* obj) { } inline int32_t THPUtils_unpackInt(PyObject* obj) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int overflow; + int overflow = 0; long value = PyLong_AsLongAndOverflow(obj, &overflow); if (value == -1 && PyErr_Occurred()) { throw python_error(); @@ -74,8 +73,7 @@ inline int32_t THPUtils_unpackInt(PyObject* obj) { } inline int64_t THPUtils_unpackLong(PyObject* obj) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int overflow; + int overflow = 0; long long value = PyLong_AsLongLongAndOverflow(obj, &overflow); if (value == -1 && PyErr_Occurred()) { throw python_error(); diff --git a/torch/csrc/utils/python_raii.h b/torch/csrc/utils/python_raii.h index 411e558715e8..bc7b5c263e0d 100644 --- a/torch/csrc/utils/python_raii.h +++ b/torch/csrc/utils/python_raii.h @@ -2,8 +2,7 @@ #include #include -namespace torch { -namespace impl { +namespace torch::impl { template struct RAIIContextManager { @@ -37,9 +36,9 @@ void py_context_manager(const py::module& m, const char* name) { .def( "__exit__", [](ContextManagerT& guard, - py::object exc_type, - py::object exc_value, - py::object traceback) { guard.exit(); }); + const py::object& exc_type, + const py::object& exc_value, + const py::object& traceback) { guard.exit(); }); } template @@ -77,10 +76,9 @@ void py_context_manager_DEPRECATED(const py::module& m, const char* name) { .def( "__exit__", [](ContextManagerT& guard, - py::object exc_type, - py::object exc_value, - py::object traceback) { guard.exit(); }); + const py::object& exc_type, + const py::object& exc_value, + const py::object& traceback) { guard.exit(); }); } -} // namespace impl -} // namespace torch +} // namespace torch::impl diff --git a/torch/csrc/utils/python_scalars.h b/torch/csrc/utils/python_scalars.h index 2819f56b6bab..997425ac7de2 100644 --- a/torch/csrc/utils/python_scalars.h +++ b/torch/csrc/utils/python_scalars.h @@ -7,8 +7,7 @@ #include #include -namespace torch { -namespace utils { +namespace torch::utils { template inline T unpackIntegral(PyObject* obj, const char* type) { @@ -159,5 +158,4 @@ inline PyObject* load_scalar(const void* data, at::ScalarType scalarType) { } } -} // namespace utils -} // namespace torch +} // namespace torch::utils diff --git a/torch/csrc/utils/python_torch_function_mode.h b/torch/csrc/utils/python_torch_function_mode.h index f6652dfd9308..f0e6bb9acbe9 100644 --- a/torch/csrc/utils/python_torch_function_mode.h +++ b/torch/csrc/utils/python_torch_function_mode.h @@ -2,8 +2,7 @@ #include -namespace torch { -namespace overrides { +namespace torch::overrides { struct StashTorchFunctionModeGuard { StashTorchFunctionModeGuard() { @@ -21,5 +20,4 @@ struct StashTorchFunctionModeGuard { std::shared_ptr cur_mode_; }; -} // namespace overrides -} // namespace torch +} // namespace torch::overrides diff --git a/torch/csrc/utils/schema_info.h b/torch/csrc/utils/schema_info.h index acda1bffc153..18aaa9bc7d35 100644 --- a/torch/csrc/utils/schema_info.h +++ b/torch/csrc/utils/schema_info.h @@ -3,8 +3,7 @@ #include #include -namespace torch { -namespace utils { +namespace torch::utils { using SchemaSpecialCasePair = std::pair>; @@ -113,5 +112,4 @@ struct TORCH_API SchemaInfo { bool has_init_; }; -} // namespace utils -} // namespace torch +} // namespace torch::utils diff --git a/torch/csrc/utils/structseq.h b/torch/csrc/utils/structseq.h index 0d91d39d34be..60e3429b50cd 100644 --- a/torch/csrc/utils/structseq.h +++ b/torch/csrc/utils/structseq.h @@ -2,10 +2,8 @@ #include -namespace torch { -namespace utils { +namespace torch::utils { PyObject* returned_structseq_repr(PyStructSequence* obj); } -} // namespace torch diff --git a/torch/csrc/utils/tensor_apply.cpp b/torch/csrc/utils/tensor_apply.cpp index ffb2c5801751..906b5422b373 100644 --- a/torch/csrc/utils/tensor_apply.cpp +++ b/torch/csrc/utils/tensor_apply.cpp @@ -10,8 +10,7 @@ using namespace at; -namespace torch { -namespace utils { +namespace torch::utils { struct StridedData { StridedData(const Tensor& tensor) @@ -129,5 +128,4 @@ const Tensor& map2_( return self; } -} // namespace utils -} // namespace torch +} // namespace torch::utils diff --git a/torch/csrc/utils/tensor_apply.h b/torch/csrc/utils/tensor_apply.h index bd06e0f3e30b..0e721542fe69 100644 --- a/torch/csrc/utils/tensor_apply.h +++ b/torch/csrc/utils/tensor_apply.h @@ -3,8 +3,7 @@ #include #include -namespace torch { -namespace utils { +namespace torch::utils { const at::Tensor& apply_(const at::Tensor& self, PyObject* fn); const at::Tensor& map_( @@ -17,5 +16,4 @@ const at::Tensor& map2_( const at::Tensor& y_, PyObject* fn); -} // namespace utils -} // namespace torch +} // namespace torch::utils diff --git a/torch/csrc/utils/tensor_flatten.cpp b/torch/csrc/utils/tensor_flatten.cpp index 396a6e8a3a8e..fb06ad884d7e 100644 --- a/torch/csrc/utils/tensor_flatten.cpp +++ b/torch/csrc/utils/tensor_flatten.cpp @@ -3,8 +3,7 @@ #include #include -namespace torch { -namespace utils { +namespace torch::utils { using namespace at; @@ -123,5 +122,4 @@ std::vector unflatten_sparse_tensors( return outputs; } -} // namespace utils -} // namespace torch +} // namespace torch::utils diff --git a/torch/csrc/utils/tensor_flatten.h b/torch/csrc/utils/tensor_flatten.h index 04a55ec7960e..2b65403fb0de 100644 --- a/torch/csrc/utils/tensor_flatten.h +++ b/torch/csrc/utils/tensor_flatten.h @@ -6,8 +6,7 @@ #include #include -namespace torch { -namespace utils { +namespace torch::utils { /// Generate an ID for a combination of tensor backend + scalar type to be used /// when ordering tensors ('like' tensors are grouped by pulling out their @@ -82,5 +81,4 @@ TORCH_API std::vector unflatten_sparse_tensors( const at::Tensor& flat_values, at::TensorList tensors); -} // namespace utils -} // namespace torch +} // namespace torch::utils diff --git a/torch/csrc/utils/tensor_layouts.cpp b/torch/csrc/utils/tensor_layouts.cpp index b403f9130bd9..be8816c8a9ab 100644 --- a/torch/csrc/utils/tensor_layouts.cpp +++ b/torch/csrc/utils/tensor_layouts.cpp @@ -7,8 +7,7 @@ #include #include -namespace torch { -namespace utils { +namespace torch::utils { #define REGISTER_LAYOUT(layout, LAYOUT) \ PyObject* layout##_layout = \ @@ -55,5 +54,4 @@ void initializeLayouts() { REGISTER_LAYOUT(jagged, Jagged); } -} // namespace utils -} // namespace torch +} // namespace torch::utils diff --git a/torch/csrc/utils/tensor_layouts.h b/torch/csrc/utils/tensor_layouts.h index 33e32b516b12..7ee7b848cadb 100644 --- a/torch/csrc/utils/tensor_layouts.h +++ b/torch/csrc/utils/tensor_layouts.h @@ -1,9 +1,7 @@ #pragma once -namespace torch { -namespace utils { +namespace torch::utils { void initializeLayouts(); } -} // namespace torch diff --git a/torch/csrc/utils/tensor_list.cpp b/torch/csrc/utils/tensor_list.cpp index c72de0b5e9e0..84f4688e0ecc 100644 --- a/torch/csrc/utils/tensor_list.cpp +++ b/torch/csrc/utils/tensor_list.cpp @@ -9,8 +9,7 @@ using namespace at; -namespace torch { -namespace utils { +namespace torch::utils { static PyObject* recursive_to_list( const char* data, @@ -66,5 +65,4 @@ PyObject* tensor_to_list(const Tensor& tensor) { tensor.numel() == 0 ? 0 : data.dtype().itemsize()); } -} // namespace utils -} // namespace torch +} // namespace torch::utils diff --git a/torch/csrc/utils/tensor_list.h b/torch/csrc/utils/tensor_list.h index 8ae77df4700a..8580631921b7 100644 --- a/torch/csrc/utils/tensor_list.h +++ b/torch/csrc/utils/tensor_list.h @@ -6,10 +6,8 @@ namespace at { class Tensor; } -namespace torch { -namespace utils { +namespace torch::utils { PyObject* tensor_to_list(const at::Tensor& tensor); } -} // namespace torch diff --git a/torch/csrc/utils/tensor_memoryformats.cpp b/torch/csrc/utils/tensor_memoryformats.cpp index 63dafaf5f5ff..28d56291bc94 100644 --- a/torch/csrc/utils/tensor_memoryformats.cpp +++ b/torch/csrc/utils/tensor_memoryformats.cpp @@ -8,8 +8,7 @@ #include #include -namespace torch { -namespace utils { +namespace torch::utils { namespace { // Intentionally leaked @@ -50,5 +49,4 @@ void initializeMemoryFormats() { add_memory_format(at::MemoryFormat::ChannelsLast3d, "channels_last_3d"); } -} // namespace utils -} // namespace torch +} // namespace torch::utils diff --git a/torch/csrc/utils/tensor_new.h b/torch/csrc/utils/tensor_new.h index 70a4fbca0bac..088f8d1927c4 100644 --- a/torch/csrc/utils/tensor_new.h +++ b/torch/csrc/utils/tensor_new.h @@ -5,8 +5,7 @@ #include -namespace torch { -namespace utils { +namespace torch::utils { // NOTE: [torch.tensor, lift_fresh, and device movement] // @@ -134,5 +133,4 @@ at::Tensor asarray( std::optional device, std::optional copy, bool requires_grad); -} // namespace utils -} // namespace torch +} // namespace torch::utils diff --git a/torch/csrc/utils/tensor_numpy.cpp b/torch/csrc/utils/tensor_numpy.cpp index 9b07b9d32f1c..6014281061bc 100644 --- a/torch/csrc/utils/tensor_numpy.cpp +++ b/torch/csrc/utils/tensor_numpy.cpp @@ -5,8 +5,8 @@ #include #ifndef USE_NUMPY -namespace torch { -namespace utils { + +namespace torch::utils { PyObject* tensor_to_numpy(const at::Tensor&, bool) { throw std::runtime_error("PyTorch was compiled without NumPy support"); } @@ -40,8 +40,8 @@ void validate_numpy_for_dlpack_deleter_bug() {} bool is_numpy_dlpack_deleter_bugged() { return false; } -} // namespace utils -} // namespace torch +} // namespace torch::utils + #else #include diff --git a/torch/csrc/utils/tensor_qschemes.cpp b/torch/csrc/utils/tensor_qschemes.cpp index 9e9d6dbdcfce..4c2e6f20557e 100644 --- a/torch/csrc/utils/tensor_qschemes.cpp +++ b/torch/csrc/utils/tensor_qschemes.cpp @@ -9,11 +9,10 @@ #include #include -namespace torch { -namespace utils { +namespace torch::utils { -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) -static PyObject* thp_qscheme_array[at::COMPILE_TIME_NUM_QSCHEMES]; +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +static std::array thp_qscheme_array; void initializeQSchemes() { auto torch_module = THPObjectPtr(PyImport_ImportModule("torch")); @@ -40,6 +39,4 @@ PyObject* getTHPQScheme(at::QScheme qscheme) { } return qscheme_; } - -} // namespace utils -} // namespace torch +} // namespace torch::utils diff --git a/torch/csrc/utils/tensor_qschemes.h b/torch/csrc/utils/tensor_qschemes.h index 71e65479047b..dc982efd1ff9 100644 --- a/torch/csrc/utils/tensor_qschemes.h +++ b/torch/csrc/utils/tensor_qschemes.h @@ -1,11 +1,9 @@ #pragma once #include -namespace torch { -namespace utils { +namespace torch::utils { PyObject* getTHPQScheme(at::QScheme qscheme); void initializeQSchemes(); -} // namespace utils -} // namespace torch +} // namespace torch::utils diff --git a/torch/csrc/utils/tensor_types.h b/torch/csrc/utils/tensor_types.h index 601cc920a2e7..a4b905604da6 100644 --- a/torch/csrc/utils/tensor_types.h +++ b/torch/csrc/utils/tensor_types.h @@ -5,8 +5,7 @@ #include #include -namespace torch { -namespace utils { +namespace torch::utils { std::string options_to_string(const at::TensorOptions& options); std::string type_to_string(const at::DeprecatedTypeProperties& type); @@ -18,5 +17,4 @@ std::vector> all_declared_types(); // return python module name of backend, like torch.cuda, torch.foo const char* backend_to_string(const at::Backend& backend); -} // namespace utils -} // namespace torch +} // namespace torch::utils diff --git a/torch/csrc/utils/throughput_benchmark-inl.h b/torch/csrc/utils/throughput_benchmark-inl.h index 4334a58683bb..ead63d585a05 100644 --- a/torch/csrc/utils/throughput_benchmark-inl.h +++ b/torch/csrc/utils/throughput_benchmark-inl.h @@ -12,9 +12,7 @@ #include #include -namespace torch { -namespace throughput_benchmark { -namespace detail { +namespace torch::throughput_benchmark::detail { template BenchmarkExecutionStats BenchmarkHelper::benchmark( @@ -156,6 +154,4 @@ BenchmarkExecutionStats BenchmarkHelper::benchmark( return stats; } -} // namespace detail -} // namespace throughput_benchmark -} // namespace torch +} // namespace torch::throughput_benchmark::detail diff --git a/torch/csrc/utils/throughput_benchmark.h b/torch/csrc/utils/throughput_benchmark.h index 2fca95ca16bf..5ec44e012631 100644 --- a/torch/csrc/utils/throughput_benchmark.h +++ b/torch/csrc/utils/throughput_benchmark.h @@ -14,8 +14,7 @@ namespace py = pybind11; -namespace torch { -namespace throughput_benchmark { +namespace torch::throughput_benchmark { /** * The struct is used to provide results of a benchmark to the caller @@ -193,7 +192,6 @@ class C10_HIDDEN ThroughputBenchmark { detail::ScriptModuleBenchmark script_module_; detail::ModuleBenchmark module_; }; -} // namespace throughput_benchmark -} // namespace torch +} // namespace torch::throughput_benchmark #include diff --git a/torch/csrc/utils/torch_dispatch_mode.h b/torch/csrc/utils/torch_dispatch_mode.h index d1c1392e37d6..8ca451143573 100644 --- a/torch/csrc/utils/torch_dispatch_mode.h +++ b/torch/csrc/utils/torch_dispatch_mode.h @@ -2,8 +2,7 @@ #include -namespace torch { -namespace torch_dispatch_mode { +namespace torch::torch_dispatch_mode { struct StashTorchDispatchModeGuard { public: @@ -54,5 +53,4 @@ struct StashTorchDispatchStackGuard { c10::impl::TorchDispatchModeTLS saved_state_; }; -} // namespace torch_dispatch_mode -} // namespace torch +} // namespace torch::torch_dispatch_mode diff --git a/torch/csrc/utils/variadic.h b/torch/csrc/utils/variadic.h index 78ffe2997142..0f3dc992c61d 100644 --- a/torch/csrc/utils/variadic.h +++ b/torch/csrc/utils/variadic.h @@ -4,8 +4,6 @@ #include #include -#include -#include #include #include From 9adfa143d705495cf829dbf9d0f1bc8710b5d3a0 Mon Sep 17 00:00:00 2001 From: Aaron Orenstein Date: Mon, 3 Jun 2024 11:53:06 -0700 Subject: [PATCH 321/706] fix post_grad pattern (#127457) The lowering pattern built by cuda_and_enabled_mixed_mm_and_not_int8() was using ListOf() incorrectly - ListOf() is meant to represent a single repeating pattern - but cuda_and_enabled_mixed_mm_and_not_int8() was passing two patterns - I think based on the comment it's trying to build a sequence which would be represented by an actual list, not ListOf(). The behavior of the existing pattern would be to pass the second pattern as the `partial` parameter of `ListOf` which is meant to be a boolean - so it's almost certainly not what was intended. I tried changing it to be what I thought was the intended behavior but then the resnet152 test failed accuracy - so I'm just preserving the existing behavior with the correct parameter types. Found when adding annotations to pattern_matcher.py (#127458) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127457 Approved by: https://github.com/oulgen --- torch/_inductor/fx_passes/post_grad.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index dd1900000f7c..3677f27e1d20 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -272,11 +272,12 @@ def cuda_and_enabled_mixed_mm_and_not_int8(match): KeywordArg("mat2"), 0xF, ), - CallFunction( - aten.__rshift__.Scalar, - KeywordArg("mat2"), - 4, - ), + # CallFunction( + # aten.__rshift__.Scalar, + # KeywordArg("mat2"), + # 4, + # ), + True, ), 1, ), From 7a60a75256a88066257e16240668ed037a5a29d9 Mon Sep 17 00:00:00 2001 From: Aaron Orenstein Date: Mon, 3 Jun 2024 11:53:10 -0700 Subject: [PATCH 322/706] Add typing annotations to pattern_matcher.py (#127458) Turn on `mypy: disallow-untyped-defs` in pattern_matcher.py and fix the fallout. There are still a bunch of `type: ignore` annotations which should eventually be ironed out. In the processs found a bug: #127457 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127458 Approved by: https://github.com/Skylion007 ghstack dependencies: #127457 --- torch/_inductor/fx_passes/pad_mm.py | 22 +- torch/_inductor/pattern_matcher.py | 518 +++++++++++++++++----------- torch/fx/node.py | 7 +- torch/storage.py | 8 +- 4 files changed, 349 insertions(+), 206 deletions(-) diff --git a/torch/_inductor/fx_passes/pad_mm.py b/torch/_inductor/fx_passes/pad_mm.py index 626897950746..b2a64df57d36 100644 --- a/torch/_inductor/fx_passes/pad_mm.py +++ b/torch/_inductor/fx_passes/pad_mm.py @@ -1,6 +1,7 @@ import functools import itertools import operator +import typing from typing import List, Optional, Union import torch @@ -11,7 +12,14 @@ from torch.utils._mode_utils import no_dispatch from ...utils._triton import has_triton -from ..pattern_matcher import fwd_only, gen_register_replacement, joint_fwd_bwd, Match +from ..pattern_matcher import ( + fwd_only, + gen_register_replacement, + joint_fwd_bwd, + Match, + ReplaceFn, + SearchFn, +) aten = torch.ops.aten @@ -636,22 +644,22 @@ def _pad_mm_init(): for pattern, replacement, args, workaround, extra_check in [ ( - mm_pattern, - mm_replace, + typing.cast(SearchFn, mm_pattern), + typing.cast(ReplaceFn, mm_replace), [dim2a(), dim2b()], {}, should_pad_mm, ), ( - bmm_pattern, - bmm_replace, + typing.cast(SearchFn, bmm_pattern), + typing.cast(ReplaceFn, bmm_replace), [dim3a(), dim3b()], {}, should_pad_bmm, ), ( - addmm_pattern, - addmm_replace, + typing.cast(SearchFn, addmm_pattern), + typing.cast(ReplaceFn, addmm_replace), [dim1a(), dim2a(), dim2b()], rep, should_pad_addmm, diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py index e91873ea933a..f3ea99f26664 100644 --- a/torch/_inductor/pattern_matcher.py +++ b/torch/_inductor/pattern_matcher.py @@ -1,3 +1,4 @@ +# mypy: disallow-untyped-defs from __future__ import annotations import contextlib @@ -13,6 +14,7 @@ import re import textwrap import typing +from abc import ABC, abstractmethod from collections import defaultdict from pathlib import Path from typing import ( @@ -20,12 +22,18 @@ Callable, DefaultDict, Dict, + Generator, Iterable, List, + Mapping, NoReturn, Optional, + Protocol, + Sequence, Set, Tuple, + Type, + TypeVar, Union, ) from typing_extensions import Self, TypeGuard @@ -50,9 +58,6 @@ from .decomposition import select_decomp_table from .lowering import fallback_node_due_to_unsupported_type -if typing.TYPE_CHECKING: - from torch.fx import Node - log = logging.getLogger(__name__) aten = torch.ops.aten prims = torch.ops.prims @@ -61,8 +66,33 @@ NodeOrConstant = Union[Constant, torch.fx.Node] +class SearchFn(Protocol): + __name__: str + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + ... + + +class ReplaceFn(Protocol): + def __call__(self, *args: Any, **kwargs: Any) -> Any: + ... + + +class TraceFn(Protocol): + def __call__( + self, fn: Union[SearchFn, ReplaceFn], *args: Any, **kwargs: Any + ) -> torch.fx.GraphModule: + ... + + +T = TypeVar("T") + +# What's a better name for this? +FnsType = Union[torch.fx.node.Target, str] + + class Multiple: - def __init__(self): + def __init__(self) -> None: # Ensure we're really a singleton. assert "MULTIPLE" not in globals() or self is MULTIPLE @@ -76,25 +106,38 @@ class Match: Represents a successfully matched pattern. """ - def __init__(self, pattern: PatternExpr, args=None, kwargs=None): + pattern: PatternExpr + args: List[Any] + kwargs: Dict[str, Any] + nodes: List[torch.fx.Node] + targets: Dict[_TargetExpr, torch.fx.node.Target] + ctx: MatchContext + replacement_graph: Optional[torch.fx.Graph] + + def __init__( + self, + ctx: MatchContext, + pattern: PatternExpr, + args: Optional[Sequence[Any]] = None, + kwargs: Optional[Dict[str, Any]] = None, + ) -> None: super().__init__() self.pattern = pattern # The input nodes that must be passed in to the result - self.args = args or [] + self.args = list(args or []) self.kwargs = kwargs or {} # The nodes matched in this expression - self.nodes: List[torch.fx.Node] = [] + self.nodes = [] # Mapping CallFunction to the node.target - self.targets: Dict[_TargetExpr, torch.fx.node.Target] = {} - self.ctx: Optional[MatchContext] = None - self.replacement_graph: Optional[torch.fx.Graph] = None + self.targets = {} + self.ctx = ctx + self.replacement_graph = None @property def graph(self) -> torch.fx.Graph: - assert self.ctx return self.ctx.graph - def extend(self, other: Match): + def extend(self, other: Match) -> None: if self.kwargs: for key in set(self.kwargs.keys()) & set(other.kwargs.keys()): if self.kwargs[key] != other.kwargs[key]: @@ -109,16 +152,15 @@ def bundle(self) -> Match: self.args = [tuple(self.args)] if self.args else [] return self - def __repr__(self): + def __repr__(self) -> str: return f"Match(..., {self.args}, {self.kwargs})" - def erase_nodes(self, graph: torch.fx.Graph): + def erase_nodes(self, graph: torch.fx.Graph) -> None: for n in reversed(self.nodes): if not n._erased: graph.erase_node(n) def output_nodes(self) -> List[Optional[torch.fx.Node]]: - assert self.ctx return [ (self.ctx.pattern_to_node[p] if p is not None else None) for p in self.ctx.outputs @@ -127,15 +169,20 @@ def output_nodes(self) -> List[Optional[torch.fx.Node]]: def output_node(self) -> torch.fx.Node: return next(p for p in self.output_nodes() if p) - def replace_with_graph(self, replacement_graph, args): - assert self.ctx + def replace_with_graph( + self, replacement_graph: torch.fx.Graph, args: Sequence[Any] + ) -> None: ReplacementPatternEntry.replace_with_graph( self, self.ctx.graph, replacement_graph, args ) - def replace_by_example(self, replacement_fn, args, trace_fn=None, run_dce=True): - assert self.ctx - + def replace_by_example( + self, + replacement_fn: ReplaceFn, + args: Sequence[Any], + trace_fn: Optional[TraceFn] = None, + run_dce: bool = True, + ) -> None: from torch._inductor.virtualized import V context = V.fake_mode if V.fake_mode is not None else contextlib.nullcontext @@ -155,7 +202,9 @@ def replace_by_example(self, replacement_fn, args, trace_fn=None, run_dce=True): class FailedMatch(RuntimeError): - def __init__(self, format_string, *args, **kwargs): + format_string: str + + def __init__(self, format_string: str, *args: Any, **kwargs: Any) -> None: self.format_string = format_string # We want to construct error messages lazily instead of eagerly, as # constructing them eagerly can significantly worsen compile times. @@ -166,14 +215,17 @@ def __init__(self, format_string, *args, **kwargs): self.args = args self.kwargs = kwargs - def __str__(self): + def __str__(self) -> str: return self.format_string.format(*self.args, **self.kwargs) - def __bool__(self): + def __bool__(self) -> bool: return False -def is_match(m: Union[Match, FailedMatch]) -> TypeGuard[Match]: +MatchResult = Union[Match, FailedMatch] + + +def is_match(m: MatchResult) -> TypeGuard[Match]: """ TypeGuards cannot act on `self`. Thus this function exists to let mypy recognize FailedMatch.__bool__ as a TypeGuard. @@ -186,32 +238,36 @@ class MatchContext: State needed while running PatternExpr._match(). """ + outputs: List[Optional[PatternExpr]] + pattern_to_node: Dict[PatternExpr, Optional[torch.fx.Node]] + graph: torch.fx.Graph + exclusive_node_set: List[NodeOrConstant] + def __init__( self, outputs: List[Optional[PatternExpr]], - pattern_to_node: Optional[Dict[PatternExpr, Node]] = None, + pattern_to_node: Optional[Dict[PatternExpr, torch.fx.Node]] = None, *, graph: torch.fx.Graph, - ): + ) -> None: self.outputs = outputs - self.pattern_to_node = {} if pattern_to_node is None else pattern_to_node + self.pattern_to_node = {} if pattern_to_node is None else dict(pattern_to_node) self.graph = graph - self.exclusive_node_set: List[NodeOrConstant] = [] + self.exclusive_node_set = [] - def match(self, pattern, node): + def match(self, pattern: PatternExpr, node: NodeOrConstant) -> MatchResult: """wrapper to check reused nodes in patterns""" if pattern in self.pattern_to_node: if self.pattern_to_node[pattern] == node: - return Match(pattern) # already checked this node + return Match(self, pattern) # already checked this node else: return FailedMatch("repeated pattern differs") m = pattern._match(node, self) assert pattern not in self.pattern_to_node self.pattern_to_node[pattern] = node if m else None - m.ctx = self return m - def filter_multi_user_patterns(self): + def filter_multi_user_patterns(self) -> Dict[PatternExpr, torch.fx.Node]: return { pattern: node for pattern, node in self.pattern_to_node.items() @@ -219,17 +275,16 @@ def filter_multi_user_patterns(self): } -class PatternExpr: +class PatternExpr(ABC): """ Base class for types of patterns """ - def _match( - self, node: torch.fx.Node, ctx: MatchContext - ) -> Union[Match, FailedMatch]: - raise NotImplementedError + @abstractmethod + def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult: + ... - def match(self, node: torch.fx.Node) -> Union[Match, FailedMatch]: + def match(self, node: torch.fx.Node) -> MatchResult: try: return MatchContext([self], graph=node.graph).match(self, node) except FailedMatch as e: @@ -238,10 +293,12 @@ def match(self, node: torch.fx.Node) -> Union[Match, FailedMatch]: def has_multiple_users(self) -> bool: return False - def __repr__(self): + def __repr__(self) -> str: return self.__class__.__name__ + "()" - def find_anchor_nodes(self, ctx: MatchContext, searched): + def find_anchor_nodes( + self, ctx: MatchContext, searched: Set[torch.fx.Node] + ) -> Generator[Optional[torch.fx.Node], None, None]: if self in ctx.pattern_to_node: yield ctx.pattern_to_node[self] @@ -260,8 +317,8 @@ class Arg(PatternExpr): passed in depth first order. """ - def _match(self, node: NodeOrConstant, ctx: MatchContext): - return Match(self, args=[node]) # matches anything + def _match(self, node: NodeOrConstant, ctx: MatchContext) -> MatchResult: + return Match(ctx, self, args=[node]) # matches anything class Ignored(PatternExpr): @@ -269,13 +326,13 @@ class Ignored(PatternExpr): Match an arg, but don't pass it to handler """ - def _match(self, node: NodeOrConstant, ctx: MatchContext): - return Match(self) # matches anything + def _match(self, node: NodeOrConstant, ctx: MatchContext) -> MatchResult: + return Match(ctx, self) # matches anything - def __repr__(self): + def __repr__(self) -> str: return "*" - def pretty_print(self, pp: PatternPrettyPrinter): + def pretty_print(self, pp: PatternPrettyPrinter) -> str: return "Ignored()" @@ -284,15 +341,15 @@ class KeywordArg(PatternExpr): Capture a kwarg which will become an input to the handler. """ - def __init__(self, name: str): + def __init__(self, name: str) -> None: super().__init__() self.name = name - def __repr__(self): + def __repr__(self) -> str: return f"KeywordArg({self.name!r})" - def _match(self, node: NodeOrConstant, ctx: MatchContext): - return Match(self, kwargs={self.name: node}) # matches anything + def _match(self, node: NodeOrConstant, ctx: MatchContext) -> MatchResult: + return Match(ctx, self, kwargs={self.name: node}) # matches anything def pattern_eq(self, other: Any) -> bool: other = typing.cast(Self, other) # super makes sure this is true @@ -304,19 +361,21 @@ class ExclusiveKeywordArg(PatternExpr): Capture a kwarg which will become an input to the handler. """ - def __init__(self, name): + name: str + + def __init__(self, name: str) -> None: super().__init__() self.name = name - def __repr__(self): + def __repr__(self) -> str: return f"ExclusiveKeywordArg({self.name!r})" - def _match(self, node: NodeOrConstant, ctx: MatchContext): + def _match(self, node: NodeOrConstant, ctx: MatchContext) -> MatchResult: if node in ctx.exclusive_node_set: return FailedMatch("exclusive arg appears twice") ctx.exclusive_node_set.append(node) - return Match(self, kwargs={self.name: node}) # matches anything + return Match(ctx, self, kwargs={self.name: node}) # matches anything def pattern_eq(self, other: Any) -> bool: other = typing.cast(Self, other) # super makes sure this is true @@ -328,21 +387,27 @@ class _TargetExpr(PatternExpr): Base class for filtering match by node.target """ - op: Optional[str] = None + fns: List[FnsType] + fns_set: Set[FnsType] - def __init__(self, fns, users: Union[Multiple, int] = 1): - if not self.op: - raise NotImplementedError("Shouldn't directly use _BaseNodeMatch") + def __init__( + self, fns: Union[FnsType, Sequence[FnsType]], users: Union[Multiple, int] = 1 + ) -> None: super().__init__() fns = [fns] if callable(fns) or isinstance(fns, str) else list(fns) - for fn in list(fns): + for fn in fns: if isinstance(fn, torch._ops.OpOverloadPacket): - fns.extend([getattr(fn, overload) for overload in fn.overloads()]) + fns.extend(getattr(fn, overload) for overload in fn.overloads()) - self.fns: List[Union[Callable[..., Any], str]] = fns - self.fns_set: Set[Union[Callable[..., Any], str]] = set(fns) + self.fns = fns + self.fns_set = set(fns) self.users = users + @property + @abstractmethod + def op(self) -> str: + ... + def fns_repr(self) -> str: first_repr = self.fns[0] if not isinstance(first_repr, str): @@ -357,7 +422,7 @@ def fns_repr(self) -> str: else: return first_repr - def __repr__(self): + def __repr__(self) -> str: if self.users is MULTIPLE: comma_users = ", MULTIPLE" elif self.users != 1: @@ -369,17 +434,19 @@ def __repr__(self): def has_multiple_users(self) -> bool: return isinstance(self.users, Multiple) or self.users > 1 - def find_anchor_nodes(self, ctx: MatchContext, searched): + def find_anchor_nodes( + self, ctx: MatchContext, searched: Set[torch.fx.Node] + ) -> Generator[Optional[torch.fx.Node], None, None]: raise NotImplementedError - def _match_fns(self, node: torch.fx.Node): + def _match_fns(self, node: torch.fx.Node) -> bool: return ( isinstance(node, torch.fx.Node) and node.op == self.op and extract_target(node) in self.fns_set ) - def _match_users(self, node: torch.fx.Node, ctx: MatchContext): + def _match_users(self, node: torch.fx.Node, ctx: MatchContext) -> bool: return ( self in ctx.outputs or self.users is MULTIPLE @@ -396,12 +463,21 @@ def pattern_eq(self, other: Any) -> bool: ) +_SimpleSpec = Tuple[Any, ...] + + class _TargetArgsExpr(_TargetExpr): """ Base class for filtering match by node.{target,args,kwargs} """ - def __init__(self, fns, *args, _users=1, **kwargs): + def __init__( + self, + fns: Union[torch.fx.node.Target, str, Sequence[Any]], + *args: Any, + _users: Union[int, Multiple] = 1, + **kwargs: Any, + ) -> None: super().__init__(fns, _users) self.args = tuple(args) self.kwargs = dict(kwargs) @@ -415,12 +491,18 @@ def __init__(self, fns, *args, _users=1, **kwargs): self.flat_args_kwargs = self.flatten(self.args, self.kwargs) @staticmethod - def simple_flatten(args, kwargs: Dict[Any, Any]): - return (*args, *kwargs.values()), (len(args), *kwargs.keys()) + def simple_flatten( + args: Sequence[Any], kwargs: Mapping[Any, Any] + ) -> Tuple[Sequence[Any], Union[_SimpleSpec, pytree.TreeSpec]]: + values = (*args, *kwargs.values()) + spec = (len(args), *kwargs.keys()) + return values, spec @staticmethod - def pytree_flatten(args, kwargs: Dict[Any, Any]): - def norm_spec(s: pytree.TreeSpec): + def pytree_flatten( + args: Sequence[Any], kwargs: Mapping[Any, Any] + ) -> Tuple[Sequence[Any], Union[_SimpleSpec, pytree.TreeSpec]]: + def norm_spec(s: pytree.TreeSpec) -> pytree.TreeSpec: if s.type is None: return s mapping = {immutable_list: list, tuple: list, immutable_dict: dict} @@ -434,7 +516,7 @@ def norm_spec(s: pytree.TreeSpec): spec = norm_spec(spec) return flat, spec - def __repr__(self): + def __repr__(self) -> str: args = [ self.fns_repr(), *map(repr, self.args), @@ -446,7 +528,7 @@ def __repr__(self): args.append(f"_users={self.users}") return f"{self.__class__.__name__}({', '.join(args)})" - def pretty_print(self, pp: PatternPrettyPrinter): + def pretty_print(self, pp: PatternPrettyPrinter) -> str: args = [ self.fns_repr(), *(pp.pretty_print(x) for x in self.args), @@ -460,7 +542,7 @@ def pretty_print(self, pp: PatternPrettyPrinter): joiner_str = ", " return f"{self.__class__.__name__}({joiner_str.join(args)})" - def _match(self, node: torch.fx.Node, ctx: MatchContext): + def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult: if not self._match_fns(node) or len(node.args) != len(self.args): return FailedMatch("function_mismatch: node={}, pattern={}", node, self) @@ -495,11 +577,11 @@ def _match(self, node: torch.fx.Node, ctx: MatchContext): return FailedMatch("args_structure {} {}", node_spec, self_spec) assert len(node_items) == len(self_items) - m = Match(self) + m = Match(ctx, self) for i, pattern, child_node in zip(itertools.count(), self_items, node_items): if isinstance(pattern, PatternExpr): child_match = ctx.match(pattern, child_node) - if not child_match: + if not is_match(child_match): return child_match m.extend(child_match) elif isinstance(child_node, torch.fx.Node) or child_node != pattern: @@ -510,7 +592,9 @@ def _match(self, node: torch.fx.Node, ctx: MatchContext): m.targets[self] = node.target return m - def find_anchor_nodes(self, ctx: MatchContext, searched): + def find_anchor_nodes( + self, ctx: MatchContext, searched: Set[torch.fx.Node] + ) -> Generator[Optional[torch.fx.Node], None, None]: """ This is used when we are matching a pattern with multiple outputs. There is a partial match (stored in ctx) and we want to walk @@ -574,14 +658,14 @@ class _TargetExprVarArgs(_TargetExpr): Matches a call_function node with any arguments which are passed into the pattern """ - def _match(self, node: torch.fx.Node, ctx: MatchContext): + def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult: if not self._match_fns(node): return FailedMatch("function_mismatch") if not self._match_users(node, ctx): return FailedMatch("multiple_users") - m = Match(self) + m = Match(ctx, self) m.nodes.append(node) m.targets[self] = node.target m.args.extend(node.args) @@ -606,19 +690,19 @@ class ListOf(PatternExpr): Matches a repeated pattern """ - def __init__(self, pattern: PatternExpr, partial=False): + def __init__(self, pattern: PatternExpr, partial: bool = False) -> None: super().__init__() assert isinstance(pattern, PatternExpr) self.pattern = pattern self.partial = partial - def __repr__(self): + def __repr__(self) -> str: return f"{self.__class__.__name__}({self.pattern})" - def _match(self, node: List[torch.fx.Node], ctx: MatchContext): # type: ignore[override] + def _match(self, node: List[torch.fx.Node], ctx: MatchContext) -> MatchResult: # type: ignore[override] if not isinstance(node, (list, tuple)) or len(node) == 0: return FailedMatch("non_list") - m = Match(self) + m = Match(ctx, self) # Propagating patterns with multiple users will ensure we don't revisit # the same nodes pattern_to_node = ctx.filter_multi_user_patterns() @@ -629,7 +713,7 @@ def _match(self, node: List[torch.fx.Node], ctx: MatchContext): # type: ignore[ ) child_match = child_ctx.match(self.pattern, child_node) pattern_to_node = child_ctx.filter_multi_user_patterns() - if not child_match: + if not is_match(child_match): if not self.partial: return FailedMatch("list[{}]: {}", i, child_match) continue @@ -649,54 +733,61 @@ def pattern_eq(self, other: Any) -> bool: class MultiOutputPattern(PatternExpr): - def __init__(self, outputs): + outputs: List[Optional[PatternExpr]] + + def __init__(self, outputs: Sequence[Optional[PatternExpr]]) -> None: super().__init__() - assert all(isinstance(x, (PatternExpr, type(None))) for x in outputs), outputs - self.outputs: List[Optional[PatternExpr]] = outputs + assert isinstance(outputs[0], _TargetExpr) + assert all(x is None or isinstance(x, PatternExpr) for x in outputs), outputs + self.outputs = list(outputs) self.op = outputs[0].op @property - def fns(self): - assert self.outputs[0] and hasattr(self.outputs[0], "fns") - return self.outputs[0].fns + def fns(self) -> Union[Callable[..., Any], str, Sequence[Any]]: + # This cast is checked above in __init__() + output = typing.cast(_TargetExpr, self.outputs[0]) + return output.fns - def __repr__(self): + def __repr__(self) -> str: return f"{self.__class__.__name__}({self.outputs})" - def pretty_print(self, pp: PatternPrettyPrinter): + def pretty_print(self, pp: PatternPrettyPrinter) -> str: args = [pp.pretty_print(x) for x in self.outputs] joiner_str = f",\n{' '}" str_out = f"{self.__class__.__name__}([{joiner_str.join(args)}" str_out = f"{str_out}\n])" return str_out - def _match(self, node: torch.fx.Node, ctx: MatchContext): - m = ctx.match(self.outputs[0], node) - if not m: + def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult: + output = typing.cast(_TargetExpr, self.outputs[0]) + m = ctx.match(output, node) + if not is_match(m): return m for pattern in self.outputs[1:]: if pattern is None: continue child_match = self._match_from_anchors(pattern, ctx) - if not child_match: + if not is_match(child_match): return child_match m.extend(child_match) return m - def _match_from_anchors(self, pattern, ctx): + def _match_from_anchors( + self, pattern: PatternExpr, ctx: MatchContext + ) -> MatchResult: prior = dict(ctx.pattern_to_node) - m = FailedMatch("no anchor found") + m: MatchResult = FailedMatch("no anchor found") for node in pattern.find_anchor_nodes(ctx, set()): m = ctx.match(pattern, node) - if m: + if is_match(m): return m # revert any partial matches ctx.pattern_to_node = dict(prior) return m - def match(self, node: torch.fx.Node) -> Union[Match, FailedMatch]: + def match(self, node: torch.fx.Node) -> MatchResult: try: return MatchContext(self.outputs, graph=node.graph).match(self, node) except FailedMatch as e: @@ -719,19 +810,18 @@ class RepeatedExpr(PatternExpr): Checks for a repeated pattern. Useful for repeated operations after a node such as `split` or `unbind` """ - def __init__(self, inner_pattern: PatternExpr): + def __init__(self, inner_pattern: _TargetExpr) -> None: super().__init__() - assert hasattr(inner_pattern, "fns") self.inner_pattern = inner_pattern - self.op = inner_pattern.op # type: ignore[attr-defined] + self.op = inner_pattern.op @property - def fns(self): + def fns(self) -> Sequence[FnsType]: return self.inner_pattern.fns - def _match(self, node: torch.fx.Node, ctx: MatchContext): + def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult: m = ctx.match(self.inner_pattern, node) - if not m: + if not is_match(m): return m ctx.pattern_to_node.pop( self.inner_pattern, @@ -741,7 +831,7 @@ def _match(self, node: torch.fx.Node, ctx: MatchContext): anchor_m = MatchContext([self], graph=node.graph).match( self.inner_pattern, anchor_node ) - if not anchor_m: + if not is_match(anchor_m): return anchor_m m.extend(anchor_m) return m @@ -760,13 +850,13 @@ class PatternPrettyPrinter: all patterns. """ - def __init__(self): + def __init__(self) -> None: self.namespace = torch.fx.graph._Namespace() self.memoized_objs_names: Dict[PatternExpr, str] = {} self.memoized_objs_pp: Dict[PatternExpr, str] = {} @staticmethod - def run(obj: PatternExpr, output_name="output"): + def run(obj: PatternExpr, output_name: str = "output") -> str: """ Serializes obj to python code with obj written out to `output_name` """ @@ -783,7 +873,7 @@ def run(obj: PatternExpr, output_name="output"): return "\n".join(output) - def pretty_print(self, obj): + def pretty_print(self, obj: Any) -> str: if isinstance(obj, _TargetArgsExpr): if memoized_name := self.memoized_objs_names.get(obj): return memoized_name @@ -794,7 +884,7 @@ def pretty_print(self, obj): return repr(obj) - def memoize(self, obj): + def memoize(self, obj: _TargetArgsExpr) -> str: obj_str = obj.pretty_print(self) obj_name = obj.fns_repr() for prefix in ("aten.", "torch.", "prims."): @@ -806,15 +896,25 @@ def memoize(self, obj): return tmp_name +class _PassDictsType(Protocol): + def __getitem__(self, k: Tuple[str, torch.fx.node.Target]) -> List[PatternEntry]: + ... + + @dataclasses.dataclass class PatternEntry: pattern: PatternExpr extra_check: Callable[[Match], bool] - def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node): + def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None: raise NotImplementedError - def register(self, pass_dicts, target=None, prepend=False): + def register( + self, + pass_dicts: Union[_PassDictsType, Sequence[_PassDictsType]], + target: Union[torch.fx.node.Target, None] = None, + prepend: bool = False, + ) -> None: if target is None: assert hasattr(self.pattern, "fns") for fn in self.pattern.fns: @@ -826,6 +926,7 @@ def register(self, pass_dicts, target=None, prepend=False): else: pass_dicts[(self.pattern.op, target)].append(self) else: + pass_dicts = typing.cast(Sequence[_PassDictsType], pass_dicts) for x in pass_dicts: self.register(x, target, prepend=prepend) @@ -834,7 +935,7 @@ def register(self, pass_dicts, target=None, prepend=False): class LoweringPatternEntry(PatternEntry): handler: Callable[..., Any] - def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node): + def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None: handler = functools.wraps(self.handler)(functools.partial(self.handler, match)) with graph.inserting_before(node): replacement = graph.call_function(handler, tuple(match.args), match.kwargs) @@ -852,7 +953,7 @@ class GraphPatternEntry(PatternEntry): handler: Callable[..., Any] - def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node): + def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None: with graph.inserting_before(node): self.handler(match, *match.args, **match.kwargs) @@ -865,9 +966,9 @@ class ReplacementPatternEntry(PatternEntry): def replace_with_graph( match: Match, graph: torch.fx.Graph, - replacement_graph: torch.fx.Graph, - args: List[Any], - ): + replacement_graph: Union[torch.fx.Graph, torch.fx.GraphModule], + args: Sequence[torch.fx.Node], + ) -> None: output_nodes = match.output_nodes() first_node = output_nodes[0] @@ -876,7 +977,7 @@ class Replacer(torch.fx.Interpreter): call_module = None # type: ignore[assignment] get_attr = None # type: ignore[assignment] - def run_node(self, node) -> Any: + def run_node(self, node: torch.fx.Node) -> Any: if node.op in ("placeholder", "output"): return super().run_node(node) if node.op == "call_function": @@ -905,7 +1006,9 @@ def run_node(self, node) -> Any: ] last_node = min(indices, key=operator.itemgetter(0))[1] - def percolate_tags(node, recompute_tag, input_stops): + def percolate_tags( + node: torch.fx.Node, recompute_tag: str, input_stops: Set[torch.fx.Node] + ) -> None: queue = [node] visited = set() @@ -925,7 +1028,7 @@ def percolate_tags(node, recompute_tag, input_stops): if isinstance(replacement, torch.fx.Node): replacement = [replacement] - def maybe_getitem(node): + def maybe_getitem(node: torch.fx.Node) -> Any: if node.op != "call_function": return None if node.target != operator.getitem: @@ -933,7 +1036,10 @@ def maybe_getitem(node): assert len(node.args) == 2 return node.args[1] - def replace(old, new): + def replace( + old: Union[torch.fx.Node, None], + new: Union[torch.fx.Node, Sequence[torch.fx.Node], None], + ) -> None: if old is None: assert new is None return @@ -955,12 +1061,13 @@ def replace(old, new): # recomputable tags. It is possible in some scenarios that we # incorrectly tag some nodes as recomputables. if "recompute" in old.meta: - percolate_tags(new, old.meta["recompute"], args) + percolate_tags(new, old.meta["recompute"], set(args)) old.replace_all_uses_with(new) graph.erase_node(old) return + new = typing.cast(Sequence[torch.fx.Node], new) # `new` is not a node: it's a list of nodes. # # This happens when we want to replace a node that has a single @@ -1005,20 +1112,21 @@ def replace(old, new): match.erase_nodes(graph) - def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node): + def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None: + assert match.replacement_graph is not None self.replace_with_graph( match, graph, - match.replacement_graph, # type: ignore[arg-type] + match.replacement_graph, self.normalize_args(*match.args, **match.kwargs), ) -def _return_true(match): +def _return_true(match: Match) -> bool: return True -def log_trace_failure(search_fn, e): +def log_trace_failure(search_fn: Callable[..., Any], e: RuntimeError) -> None: log.info( "Replacement pattern %s failed to apply due to shape mismatch: %s", search_fn.__name__, @@ -1027,16 +1135,16 @@ def log_trace_failure(search_fn, e): def register_replacement( - search_fn, - replace_fn, + search_fn: SearchFn, + replace_fn: ReplaceFn, example_inputs: Iterable[Any], - trace_fn: Callable[[Callable[..., Any], Iterable[Any]], torch.fx.GraphModule], - pass_dicts, - extra_check=_return_true, - scalar_workaround=(), - exclusive_arg_names=(), - search_fn_pattern=None, -): + trace_fn: TraceFn, + pass_dicts: Union[_PassDictsType, Sequence[_PassDictsType]], + extra_check: Callable[[Match], bool] = _return_true, + scalar_workaround: Union[Dict[str, Union[float, int]], None] = None, + exclusive_arg_names: Sequence[str] = (), + search_fn_pattern: Union[PatternExpr, None] = None, +) -> bool: """ Create a replacement rule based on example functions that get traced to create patterns. This supports both training and inference when @@ -1052,7 +1160,7 @@ def register_replacement( """ argnames_static = [*inspect.signature(search_fn).parameters.keys()] - def check_fn(match: Match): + def check_fn(match: Match) -> bool: """ Often shapes get burned into the pattern, so our initial match ran with `ignore_types=(int, ...)`. @@ -1106,7 +1214,7 @@ def check_fn(match: Match): # Later, when we actually do the replacement, the symbolic shape # sizes will get re-traced and added to the graph. - def search_fn_new(*args_new): + def search_fn_new(*args_new: Any) -> Any: return search_fn(*args_new[len(args_new) - len(args) :]) try: @@ -1148,15 +1256,17 @@ def search_fn_new(*args_new): scalar_workaround=scalar_workaround, ) - specific_pattern_match = specific_pattern.match(match.output_nodes()[0]) # type: ignore[arg-type] + node = match.output_nodes()[0] + assert node is not None + specific_pattern_match = specific_pattern.match(node) - if specific_pattern_match and extra_check(specific_pattern_match): + if is_match(specific_pattern_match) and extra_check(specific_pattern_match): # trace the pattern using the shapes from the user program match.replacement_graph = trace_fn(replace_fn, args) # type: ignore[assignment] return True return False - def normalize_args(**kwargs): + def normalize_args(**kwargs: Any) -> List[Any]: args = [] for name in argnames_static: args.append(kwargs.pop(name)) @@ -1206,11 +1316,11 @@ def normalize_args(**kwargs): def _serialize_pattern( unique_name: str, - search_fn, + search_fn: SearchFn, example_inputs: Iterable[Any], - trace_fn: Callable[[Callable[..., Any], Iterable[Any]], torch.fx.GraphModule], - scalar_workaround, -): + trace_fn: TraceFn, + scalar_workaround: Union[Dict[str, Union[float, int]], None], +) -> PatternExpr: def get_file_template() -> str: auto_generated_msg = textwrap.dedent( """\ @@ -1274,6 +1384,8 @@ def get_file_template() -> str: f.write(serialized_pattern) f.write("\n") + return pattern + SERIALIZED_PATTERN_PATH = Path(__file__).parent / "fx_passes" / "serialized_patterns" @@ -1286,22 +1398,22 @@ def get_file_template() -> str: Iterable[Any], Callable[[Callable[..., Any], Iterable[Any]], torch.fx.GraphModule], Any, - str, + PatternExpr, ] ] = [] def gen_register_replacement( unique_name: str, - search_fn, - replace_fn, + search_fn: SearchFn, + replace_fn: ReplaceFn, example_inputs: Iterable[Any], - trace_fn: Callable[[Callable[..., Any], Iterable[Any]], torch.fx.GraphModule], - pass_dicts, - extra_check=_return_true, - scalar_workaround=(), - exclusive_arg_names=(), -): + trace_fn: TraceFn, + pass_dicts: Union[_PassDictsType, Sequence[_PassDictsType]], + extra_check: Callable[[Match], bool] = _return_true, + scalar_workaround: Union[Dict[str, Union[float, int]], None] = None, + exclusive_arg_names: Sequence[str] = (), +) -> None: # Make sure the example_inputs is materialized. example_inputs = tuple(example_inputs) @@ -1316,7 +1428,7 @@ def gen_register_replacement( ) if not m or not hasattr(m, unique_name): log.warning( - "Precompiled pattern %r not found. Run torchen/fuse/gen_patterns.py.", + "Precompiled pattern %r not found. Run torchgen/fuse/gen_patterns.py.", unique_name, ) pat = getattr(m, unique_name) @@ -1347,11 +1459,15 @@ def gen_register_replacement( @functorch_config.patch(functionalize_rng_ops=False) def gen_pattern( - search_fn, example_inputs, trace_fn, scalar_workaround=(), exclusive_arg_names=() + search_fn: SearchFn, + example_inputs: Sequence[Any], + trace_fn: TraceFn, + scalar_workaround: Union[Dict[str, Union[float, int]], None] = None, + exclusive_arg_names: Sequence[str] = (), ) -> PatternExpr: argnames = [*inspect.signature(search_fn).parameters.keys()] - if scalar_workaround == (): + if scalar_workaround is None: scalar_workaround = {} flat_inputs = [] input_idx = 0 # Positional arguments index @@ -1374,34 +1490,42 @@ def gen_pattern( def register_lowering_pattern( - pattern: PatternExpr, extra_check=_return_true, *, pass_dict, prepend=False -): + pattern: PatternExpr, + extra_check: Callable[[Match], bool] = _return_true, + *, + pass_dict: _PassDictsType, + prepend: bool = False, +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: """ Register an aten to inductor IR replacement pattern. The decorated function is saved and then called a lowering time allowing direct pattern to inductor IR conversion. """ - def decorator(handler): + def decorator(handler: Callable[..., Any]) -> Callable[..., Any]: assert callable(handler) LoweringPatternEntry( pattern=pattern, extra_check=extra_check, handler=handler ).register(pass_dict, prepend=prepend) - handler._inductor_lowering_function = True + handler._inductor_lowering_function = True # type: ignore[attr-defined] return handler return decorator def register_graph_pattern( - pattern: PatternExpr, extra_check=_return_true, *, pass_dict, prepend=False -): + pattern: PatternExpr, + extra_check: Callable[[Match], bool] = _return_true, + *, + pass_dict: _PassDictsType, + prepend: bool = False, +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: """ Register a pattern that runs a function on the FX graph, allowing custom transformation code. """ - def decorator(handler): + def decorator(handler: Callable[..., Any]) -> Callable[..., Any]: assert callable(handler) GraphPatternEntry( pattern=pattern, extra_check=extra_check, handler=handler @@ -1447,7 +1571,7 @@ def should_compute_mutation_region_ids(graph: torch.fx.GraphModule) -> bool: return "mutation_region_id" not in next(iter(graph.nodes)).meta -def compute_mutation_region_ids(graph: torch.fx.GraphModule): +def compute_mutation_region_ids(graph: torch.fx.GraphModule) -> None: mutation_region_id = 0 for nd in graph.nodes: if is_mutation_op(nd): @@ -1457,8 +1581,10 @@ def compute_mutation_region_ids(graph: torch.fx.GraphModule): class PatternMatcherPass: def __init__( - self, prevent_match_across_mutations=False, pass_name: Optional[str] = None - ): + self, + prevent_match_across_mutations: bool = False, + pass_name: Optional[str] = None, + ) -> None: super().__init__() self.patterns: DefaultDict[ Tuple[str, torch.fx.node.Target], List[PatternEntry] @@ -1522,20 +1648,20 @@ def apply(self, graph: torch.fx.GraphModule) -> int: counters["inductor"]["pattern_matcher_nodes"] += len(m.nodes) return count - def clear(self): + def clear(self) -> None: self.patterns.clear() -def _not_implemented(*args, **kwargs) -> NoReturn: +def _not_implemented(*args: Any, **kwargs: Any) -> NoReturn: raise NotImplementedError def fx_to_pattern( - gm, - ignore_types=(), - argnames=(), - scalar_workaround=(), - exclusive_arg_names=(), + gm: Union[torch.fx.GraphModule, torch.fx.Graph], + ignore_types: Sequence[Type[Any]] = (), + argnames: Sequence[str] = (), + scalar_workaround: Union[Dict[str, Union[float, int]], None] = None, + exclusive_arg_names: Sequence[str] = (), ) -> PatternExpr: """ Convert an FX graph into a PatternExpr. This is useful for simple @@ -1547,7 +1673,7 @@ def fx_to_pattern( inv_scalar_workaround = {v: k for k, v in scalar_workaround.items()} assert len(inv_scalar_workaround) == len(scalar_workaround) - def process_arg(x): + def process_arg(x: T) -> Union[T, KeywordArg, Ignored]: if isinstance(x, (float, int)) and x in inv_scalar_workaround: return KeywordArg(inv_scalar_workaround[x]) if type(x) in ignore_types: @@ -1563,7 +1689,9 @@ class Converter(torch.fx.Interpreter): call_module = _not_implemented get_attr = _not_implemented - def placeholder(self, target, args, kwargs): + def placeholder( + self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any] + ) -> Union[ExclusiveKeywordArg, KeywordArg]: n = next(argnum) if n < len(argnames): name = argnames[n] @@ -1578,7 +1706,9 @@ def placeholder(self, target, args, kwargs): else: return KeywordArg(name) - def call_function(self, target, args, kwargs): + def call_function( + self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any] + ) -> PatternExpr: args, kwargs = pytree.tree_map(process_arg, (args, kwargs)) if list in ignore_types: # Handle a burned in tensor size which are now [Ignored(), Ignored(), ...] @@ -1586,11 +1716,11 @@ def call_function(self, target, args, kwargs): kwargs = {k: process_arg(a) for k, a in kwargs.items()} return CallFunction(target, *args, **kwargs) - def run_node(self, n): + def run_node(self, n: torch.fx.Node) -> Any: rv = super().run_node(n) if n.op == "output" and isinstance(rv, tuple): - assert len(rv) == len(n.args[0]) - for r, arg in zip(rv, n.args[0]): + assert len(rv) == len(n.args[0]) # type: ignore[arg-type] + for r, arg in zip(rv, n.args[0]): # type: ignore[arg-type] r.users = len(arg.users) else: rv.users = len(n.users) @@ -1603,7 +1733,9 @@ def run_node(self, n): @torch.no_grad() -def fwd_only(fn, args, *, run_dce=True) -> torch.fx.GraphModule: +def fwd_only( + fn: Callable[..., Any], args: Sequence[Any], *, run_dce: bool = True +) -> torch.fx.GraphModule: """Build a normalized inference graph, for use with fx_to_pattern""" # TODO - look into using aot autograd, asserting no mutating ops here with enable_python_dispatcher(): @@ -1615,11 +1747,13 @@ def fwd_only(fn, args, *, run_dce=True) -> torch.fx.GraphModule: @torch.enable_grad() -def joint_fwd_bwd(fn, args) -> torch.fx.GraphModule: +def joint_fwd_bwd(fn: Callable[..., Any], args: Sequence[Any]) -> torch.fx.GraphModule: """Build a normalized training graph, for use with fx_to_pattern""" gm: Optional[torch.fx.GraphModule] = None - def record_joint_graph(joint_graph, inputs, **kwargs): + def record_joint_graph( + joint_graph: torch.fx.GraphModule, inputs: Sequence[Any], **kwargs: Any + ) -> Tuple[torch.fx.GraphModule, torch.fx.GraphModule]: nonlocal gm assert not gm gm = clone_graph(joint_graph) @@ -1661,7 +1795,7 @@ def _args(n: torch.fx.Node) -> List[torch.fx.node.Argument]: return args -def stable_topological_sort(graph: torch.fx.Graph): +def stable_topological_sort(graph: torch.fx.Graph) -> None: # Nodes are in exactly one of these three collections: # - Nodes in `pending` are waiting to be processed (in reverse order): @@ -1697,12 +1831,12 @@ def stable_topological_sort(graph: torch.fx.Graph): assert not waiting and len(ready) == len(graph.nodes) -def init_once_fakemode(fn: Callable[..., Any]): +def init_once_fakemode(fn: Callable[..., Any]) -> Callable[[], Any]: """Wrapper around lazy init functions in fx_passes/""" @functools.lru_cache(None) @functools.wraps(fn) - def lazy_init(): + def lazy_init() -> Any: counters_ref = counters["inductor"].copy() with torch._guards.tracing( @@ -1718,10 +1852,10 @@ def lazy_init(): return lazy_init -def config_flag(name): +def config_flag(name: str) -> Callable[[Match], Any]: """Function for extra_check to put pass behind a flag""" - def flag_check(match): + def flag_check(match: Match) -> Any: return getattr(config, name) return flag_check @@ -1729,7 +1863,7 @@ def flag_check(match): def clone_graph(input_graph: torch.fx.GraphModule) -> torch.fx.GraphModule: class CopyGraph(Transformer): - def run_node(self, old_node): + def run_node(self, old_node: torch.fx.Node) -> torch.fx.Node: new_node = super().run_node(old_node) if isinstance(new_node, torch.fx.Proxy): new_node.node.meta.update(old_node.meta) @@ -1746,7 +1880,7 @@ def run_node(self, old_node): def get_arg_value( node: torch.fx.Node, arg_number: int, kwarg_name: Optional[str] = None -): +) -> Any: return ( node.args[arg_number] if len(node.args) > arg_number @@ -1754,7 +1888,7 @@ def get_arg_value( ) -def filter_nodes(nodes: Iterable[torch.fx.Node], fn) -> List[torch.fx.Node]: +def filter_nodes(nodes: Iterable[torch.fx.Node], fn: Any) -> List[torch.fx.Node]: fns = [fn] if isinstance(fn, torch._ops.OpOverloadPacket): fns.extend([getattr(fn, overload) for overload in fn.overloads()]) @@ -1762,7 +1896,7 @@ def filter_nodes(nodes: Iterable[torch.fx.Node], fn) -> List[torch.fx.Node]: return [node for node in nodes if node.target in fns] -def extract_target(node: Node): +def extract_target(node: torch.fx.Node) -> torch.fx.node.Target: """For call_function and call_method, we directly use the target function; For call_module, the target is string, and we treat the module class as a function. diff --git a/torch/fx/node.py b/torch/fx/node.py index d9af26c9207f..8b4768aa497a 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -60,7 +60,7 @@ @compatibility(is_backward_compatible=False) -def has_side_effect(fn: Callable) -> None: +def has_side_effect(fn: Callable) -> Callable: _side_effectful_functions.add(fn) return fn @@ -238,7 +238,7 @@ def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target', self._prev = self self._next = self self._erased = False - self._sort_key = () + self._sort_key: Any = () # If set, use this fn to print this node self._repr_fn : Optional[Callable[[Node], str]] = None @@ -295,6 +295,7 @@ def prepend(self, x: 'Node') -> None: psk = x._prev._sort_key nsk = x._next._sort_key if len(psk) > len(nsk): + idx: int *prefix, idx = psk[:len(nsk) + 1] x._sort_key = (*prefix, idx + 1) elif len(psk) < len(nsk): @@ -421,7 +422,7 @@ def insert_arg(self, idx : int, arg : Argument) -> None: self._args = args_left + (arg,) + args_right - _new_input_nodes = {} + _new_input_nodes: Dict[Node, None] = {} map_arg(arg, _new_input_nodes.setdefault) for new_use in _new_input_nodes.keys(): diff --git a/torch/storage.py b/torch/storage.py index 32070783f494..dd268cab0d2e 100644 --- a/torch/storage.py +++ b/torch/storage.py @@ -39,7 +39,7 @@ def size(self) -> int: def type(self, dtype: _Optional[str] = None, non_blocking: bool = False) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704 - def cuda(self, device=None, non_blocking=False) -> T: # type: ignore[type-var] # noqa: E704 + def cuda(self, device=None, non_blocking=False) -> T: # type: ignore[type-var, misc] # noqa: E704 """Returns a copy of this object in CUDA memory. If this object is already in CUDA memory and on the correct device, then @@ -54,7 +54,7 @@ def cuda(self, device=None, non_blocking=False) -> T: # type: ignore[type-var] device2 = torch.device('cuda', device) if device else torch.device('cuda') return self.to(device=device2, non_blocking=non_blocking) - def hpu(self, device=None, non_blocking=False) -> T: # type: ignore[type-var] # noqa: E704 + def hpu(self, device=None, non_blocking=False) -> T: # type: ignore[type-var, misc] # noqa: E704 """Returns a copy of this object in HPU memory. If this object is already in HPU memory and on the correct device, then @@ -182,7 +182,7 @@ def _to(self, dtype): storage = storage.clone() return storage - def to(self, *, device: torch.device, non_blocking: bool = False) -> T: # type: ignore[type-var] # noqa: E704 + def to(self, *, device: torch.device, non_blocking: bool = False) -> T: # type: ignore[type-var, misc] # noqa: E704 return _to(self, device, non_blocking) def double(self): @@ -856,7 +856,7 @@ def hpu(self, device=None, non_blocking=False) -> T: # type: ignore[misc, type- hpu_storage: torch.UntypedStorage = self._untyped_storage.hpu(device, non_blocking) return self._new_wrapped_storage(hpu_storage) - def to(self, *, device: torch.device, non_blocking: bool = False) -> T: # type: ignore[type-var] + def to(self, *, device: torch.device, non_blocking: bool = False) -> T: # type: ignore[type-var, misc] _warn_typed_storage_removal() if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]: raise RuntimeError(f"Cannot create {device.type.upper()} storage with quantized dtype") From 97ea2b5d8342a34c57200fc997467d1086674eee Mon Sep 17 00:00:00 2001 From: Aaron Orenstein Date: Mon, 3 Jun 2024 11:53:13 -0700 Subject: [PATCH 323/706] documentation for pattern_matcher.py (#127459) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127459 Approved by: https://github.com/oulgen ghstack dependencies: #127457, #127458 --- torch/_inductor/pattern_matcher.py | 54 ++++++++++++++++++++++++++++-- 1 file changed, 52 insertions(+), 2 deletions(-) diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py index f3ea99f26664..e3e0ddcfd547 100644 --- a/torch/_inductor/pattern_matcher.py +++ b/torch/_inductor/pattern_matcher.py @@ -1,4 +1,40 @@ +""" +# Inductor Pattern Matcher + +The pattern matcher enables search/replace within an FX graph. + +The main entrypoint to the pattern matcher is register_replacement(). Given a +search function and a replacement function this will register a replacement with +a pass (such as torch._inductor.fx_passes.joint_graph.patterns). + +Internally the pattern matcher represents patterns as a graph (a DAG). Creating +new patterns manually as a graph is cumbersome and error-prone so the standard +way to create patterns (using register_replacement()) is to provide a search +function and a replacement function which is traced and converted into a graph. + +Because the search functions are built somewhat generic (they tend to ignore +tensor sizes, for example) register_replacement() allows you to specify an +`extra_check` function which performs additional checks to verify that the +matched pattern fully matches before returning it. + +## Precompiled Patterns + +New patterns are added using register_replacement(). Patterns added in this way +can have a compile-time overhead because they need to be traced before +use. Patterns can be precompiled and added using gen_register_replacement() +instead. To do this you call gen_register_replacement() instead of +register_replacement(). The arguments are the same except for an additional +unique name which is used as a lookup key. + +## Internals + +The match DAG is represented by a graph of `PatternExpr` nodes. Each PatternExpr +implements a `_match` method which returns either a `Match` object for a +successful match or a `FailedMatch` object for a failure to match. +""" + # mypy: disallow-untyped-defs + from __future__ import annotations import contextlib @@ -104,6 +140,13 @@ def __init__(self) -> None: class Match: """ Represents a successfully matched pattern. + + The `Match` object is returned to represent a successfully matched + pattern. Included in the Match are the pattern that was matched, the graph + nodes matched, and any args that were used during the matching. + + The args and kwargs are specific to the type of pattern that was matched and + provide hints about what was matched. """ pattern: PatternExpr @@ -202,6 +245,13 @@ def replace_by_example( class FailedMatch(RuntimeError): + """ + Represents a unsuccessful match. + + The `FailedMatch` object is returned to represent a failure to match a + pattern. + """ + format_string: str def __init__(self, format_string: str, *args: Any, **kwargs: Any) -> None: @@ -235,7 +285,7 @@ def is_match(m: MatchResult) -> TypeGuard[Match]: class MatchContext: """ - State needed while running PatternExpr._match(). + Internal state needed while running PatternExpr._match(). """ outputs: List[Optional[PatternExpr]] @@ -277,7 +327,7 @@ def filter_multi_user_patterns(self) -> Dict[PatternExpr, torch.fx.Node]: class PatternExpr(ABC): """ - Base class for types of patterns + Base class for types of patterns. """ @abstractmethod From c490046693e77e254664e19d940e9b05a1da18ef Mon Sep 17 00:00:00 2001 From: Eddie Yan Date: Tue, 4 Jun 2024 16:33:06 +0000 Subject: [PATCH 324/706] [BE]: Update cudnn to 9.1.0.70 (#123475) cuDNN has managed to upload cu11 and cu12 wheels for ~~9.0.0.312~~ 9.1.0.70, so trying this out... CC @Skylion007 @malfet Co-authored-by: Wei Wang Co-authored-by: atalman Pull Request resolved: https://github.com/pytorch/pytorch/pull/123475 Approved by: https://github.com/Skylion007, https://github.com/malfet, https://github.com/nWEIdia --- .ci/docker/build.sh | 50 +++++++++---------- .ci/docker/common/install_base.sh | 2 +- .ci/docker/common/install_cudnn.sh | 17 +++---- .ci/docker/ubuntu-cuda/Dockerfile | 2 +- .../scripts/generate_binary_build_matrix.py | 8 +-- .github/workflows/docker-builds.yml | 18 +++---- ...linux-aarch64-binary-manywheel-nightly.yml | 10 ++-- .../generated-linux-binary-manywheel-main.yml | 6 +-- ...nerated-linux-binary-manywheel-nightly.yml | 30 +++++------ ...d-linux-s390x-binary-manywheel-nightly.yml | 10 ++-- ...rated-macos-arm64-binary-wheel-nightly.yml | 10 ++-- ...generated-windows-binary-wheel-nightly.yml | 40 +++++++-------- .../workflows/inductor-micro-benchmark.yml | 2 +- .github/workflows/inductor-perf-compare.yml | 2 +- .../workflows/inductor-perf-test-nightly.yml | 2 +- .github/workflows/inductor-periodic.yml | 2 +- .github/workflows/inductor.yml | 12 ++--- .github/workflows/lint.yml | 4 +- .github/workflows/periodic.yml | 4 +- .github/workflows/pull.yml | 24 ++++----- .github/workflows/slow.yml | 4 +- .../target-determination-indexer.yml | 2 +- .github/workflows/torchbench.yml | 2 +- .github/workflows/trunk.yml | 12 ++--- .../aot_eager_timm_training.csv | 2 +- .../dynamic_inductor_torchbench_training.csv | 2 +- .../cu124/inductor_torchbench_training.csv | 2 +- .../dynamic_aot_eager_timm_training.csv | 2 +- docker.Makefile | 2 +- 29 files changed, 140 insertions(+), 145 deletions(-) diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index fa4dbf2b0165..537b0b9d2ba7 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -91,9 +91,9 @@ _UCC_COMMIT=20eae37090a4ce1b32bcce6144ccad0b49943e0b # configuration, so we hardcode everything here rather than do it # from scratch case "$image" in - pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9) + pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9) CUDA_VERSION=12.4.0 - CUDNN_VERSION=8 + CUDNN_VERSION=9 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=9 PROTOBUF=yes @@ -105,9 +105,9 @@ case "$image" in CONDA_CMAKE=yes TRITON=yes ;; - pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9) + pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9) CUDA_VERSION=12.1.1 - CUDNN_VERSION=8 + CUDNN_VERSION=9 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=9 PROTOBUF=yes @@ -119,9 +119,9 @@ case "$image" in CONDA_CMAKE=yes TRITON=yes ;; - pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9-inductor-benchmarks) + pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9-inductor-benchmarks) CUDA_VERSION=12.4.0 - CUDNN_VERSION=8 + CUDNN_VERSION=9 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=9 PROTOBUF=yes @@ -134,9 +134,9 @@ case "$image" in TRITON=yes INDUCTOR_BENCHMARKS=yes ;; - pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9-inductor-benchmarks) + pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks) CUDA_VERSION=12.1.1 - CUDNN_VERSION=8 + CUDNN_VERSION=9 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=9 PROTOBUF=yes @@ -149,9 +149,9 @@ case "$image" in TRITON=yes INDUCTOR_BENCHMARKS=yes ;; - pytorch-linux-focal-cuda12.1-cudnn8-py3.12-gcc9-inductor-benchmarks) + pytorch-linux-focal-cuda12.1-cudnn9-py3.12-gcc9-inductor-benchmarks) CUDA_VERSION=12.1.1 - CUDNN_VERSION=8 + CUDNN_VERSION=9 ANACONDA_PYTHON_VERSION=3.12 GCC_VERSION=9 PROTOBUF=yes @@ -164,9 +164,9 @@ case "$image" in TRITON=yes INDUCTOR_BENCHMARKS=yes ;; - pytorch-linux-focal-cuda12.4-cudnn8-py3.12-gcc9-inductor-benchmarks) + pytorch-linux-focal-cuda12.4-cudnn9-py3.12-gcc9-inductor-benchmarks) CUDA_VERSION=12.4.0 - CUDNN_VERSION=8 + CUDNN_VERSION=9 ANACONDA_PYTHON_VERSION=3.12 GCC_VERSION=9 PROTOBUF=yes @@ -179,9 +179,9 @@ case "$image" in TRITON=yes INDUCTOR_BENCHMARKS=yes ;; - pytorch-linux-focal-cuda11.8-cudnn8-py3-gcc9) + pytorch-linux-focal-cuda11.8-cudnn9-py3-gcc9) CUDA_VERSION=11.8.0 - CUDNN_VERSION=8 + CUDNN_VERSION=9 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=9 PROTOBUF=yes @@ -193,9 +193,9 @@ case "$image" in CONDA_CMAKE=yes TRITON=yes ;; - pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9) + pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9) CUDA_VERSION=12.4.0 - CUDNN_VERSION=8 + CUDNN_VERSION=9 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=9 PROTOBUF=yes @@ -207,9 +207,9 @@ case "$image" in CONDA_CMAKE=yes TRITON=yes ;; - pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9) + pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9) CUDA_VERSION=12.1.1 - CUDNN_VERSION=8 + CUDNN_VERSION=9 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=9 PROTOBUF=yes @@ -221,9 +221,9 @@ case "$image" in CONDA_CMAKE=yes TRITON=yes ;; - pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9) + pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9) CUDA_VERSION=12.4.0 - CUDNN_VERSION=8 + CUDNN_VERSION=9 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=9 PROTOBUF=yes @@ -330,10 +330,10 @@ case "$image" in DOCS=yes INDUCTOR_BENCHMARKS=yes ;; - pytorch-linux-jammy-cuda11.8-cudnn8-py3.8-clang12) + pytorch-linux-jammy-cuda11.8-cudnn9-py3.8-clang12) ANACONDA_PYTHON_VERSION=3.8 CUDA_VERSION=11.8 - CUDNN_VERSION=8 + CUDNN_VERSION=9 CLANG_VERSION=12 PROTOBUF=yes DB=yes @@ -380,7 +380,7 @@ case "$image" in ANACONDA_PYTHON_VERSION=3.9 CONDA_CMAKE=yes ;; - pytorch-linux-jammy-cuda11.8-cudnn8-py3.9-linter) + pytorch-linux-jammy-cuda11.8-cudnn9-py3.9-linter) ANACONDA_PYTHON_VERSION=3.9 CUDA_VERSION=11.8 CONDA_CMAKE=yes @@ -447,7 +447,7 @@ tmp_tag=$(basename "$(mktemp -u)" | tr '[:upper:]' '[:lower:]') #when using cudnn version 8 install it separately from cuda if [[ "$image" == *cuda* && ${OS} == "ubuntu" ]]; then IMAGE_NAME="nvidia/cuda:${CUDA_VERSION}-cudnn${CUDNN_VERSION}-devel-ubuntu${UBUNTU_VERSION}" - if [[ ${CUDNN_VERSION} == 8 ]]; then + if [[ ${CUDNN_VERSION} == 9 ]]; then IMAGE_NAME="nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION}" fi fi @@ -499,7 +499,7 @@ docker build \ "$@" \ . -# NVIDIA dockers for RC releases use tag names like `11.0-cudnn8-devel-ubuntu18.04-rc`, +# NVIDIA dockers for RC releases use tag names like `11.0-cudnn9-devel-ubuntu18.04-rc`, # for this case we will set UBUNTU_VERSION to `18.04-rc` so that the Dockerfile could # find the correct image. As a result, here we have to replace the # "$UBUNTU_VERSION" == "18.04-rc" diff --git a/.ci/docker/common/install_base.sh b/.ci/docker/common/install_base.sh index ebaa17878ade..fd58ad8a60b8 100755 --- a/.ci/docker/common/install_base.sh +++ b/.ci/docker/common/install_base.sh @@ -3,7 +3,7 @@ set -ex install_ubuntu() { - # NVIDIA dockers for RC releases use tag names like `11.0-cudnn8-devel-ubuntu18.04-rc`, + # NVIDIA dockers for RC releases use tag names like `11.0-cudnn9-devel-ubuntu18.04-rc`, # for this case we will set UBUNTU_VERSION to `18.04-rc` so that the Dockerfile could # find the correct image. As a result, here we have to check for # "$UBUNTU_VERSION" == "18.04"* diff --git a/.ci/docker/common/install_cudnn.sh b/.ci/docker/common/install_cudnn.sh index 3afd2f28841f..60f4561d420c 100644 --- a/.ci/docker/common/install_cudnn.sh +++ b/.ci/docker/common/install_cudnn.sh @@ -1,23 +1,18 @@ #!/bin/bash -if [[ ${CUDNN_VERSION} == 8 ]]; then +if [[ -n "${CUDNN_VERSION}" ]]; then # cuDNN license: https://developer.nvidia.com/cudnn/license_agreement mkdir tmp_cudnn pushd tmp_cudnn - if [[ ${CUDA_VERSION:0:4} == "12.4" ]]; then - CUDNN_NAME="cudnn-linux-x86_64-8.9.7.29_cuda12-archive" - curl --retry 3 -OLs https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/linux-x86_64/${CUDNN_NAME}.tar.xz - elif [[ ${CUDA_VERSION:0:4} == "12.1" ]]; then - CUDNN_NAME="cudnn-linux-x86_64-8.9.2.26_cuda12-archive" - curl --retry 3 -OLs https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/linux-x86_64/${CUDNN_NAME}.tar.xz - elif [[ ${CUDA_VERSION:0:4} == "11.8" ]]; then - CUDNN_NAME="cudnn-linux-x86_64-8.7.0.84_cuda11-archive" - curl --retry 3 -OLs https://developer.download.nvidia.com/compute/redist/cudnn/v8.7.0/local_installers/11.8/${CUDNN_NAME}.tar.xz + if [[ ${CUDA_VERSION:0:2} == "12" ]]; then + CUDNN_NAME="cudnn-linux-x86_64-9.1.0.70_cuda12-archive" + elif [[ ${CUDA_VERSION:0:2} == "11" ]]; then + CUDNN_NAME="cudnn-linux-x86_64-9.1.0.70_cuda11-archive" else print "Unsupported CUDA version ${CUDA_VERSION}" exit 1 fi - + curl --retry 3 -OLs https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/linux-x86_64/${CUDNN_NAME}.tar.xz tar xf ${CUDNN_NAME}.tar.xz cp -a ${CUDNN_NAME}/include/* /usr/local/cuda/include/ cp -a ${CUDNN_NAME}/lib/* /usr/local/cuda/lib64/ diff --git a/.ci/docker/ubuntu-cuda/Dockerfile b/.ci/docker/ubuntu-cuda/Dockerfile index cb3ea502d231..3b2bbea0097a 100644 --- a/.ci/docker/ubuntu-cuda/Dockerfile +++ b/.ci/docker/ubuntu-cuda/Dockerfile @@ -139,7 +139,7 @@ COPY --from=pytorch/llvm:9.0.1 /opt/llvm /opt/llvm ARG CUDNN_VERSION ARG CUDA_VERSION COPY ./common/install_cudnn.sh install_cudnn.sh -RUN if [ "${CUDNN_VERSION}" -eq 8 ]; then bash install_cudnn.sh; fi +RUN if [ -n "${CUDNN_VERSION}" ]; then bash install_cudnn.sh; fi RUN rm install_cudnn.sh # Install CUSPARSELT diff --git a/.github/scripts/generate_binary_build_matrix.py b/.github/scripts/generate_binary_build_matrix.py index b192475f72b1..3e50cb7930fa 100644 --- a/.github/scripts/generate_binary_build_matrix.py +++ b/.github/scripts/generate_binary_build_matrix.py @@ -19,7 +19,7 @@ CUDA_ARCHES_FULL_VERSION = {"11.8": "11.8.0", "12.1": "12.1.1", "12.4": "12.4.0"} -CUDA_ARCHES_CUDNN_VERSION = {"11.8": "8", "12.1": "8", "12.4": "8"} +CUDA_ARCHES_CUDNN_VERSION = {"11.8": "9", "12.1": "9", "12.4": "9"} ROCM_ARCHES = ["6.0", "6.1"] @@ -42,7 +42,7 @@ "nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | " # noqa: B950 "nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | " - "nvidia-cudnn-cu11==8.7.0.84; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | " @@ -55,7 +55,7 @@ "nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | " # noqa: B950 "nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | " - "nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | " @@ -68,7 +68,7 @@ "nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | " - "nvidia-cudnn-cu12==8.9.7.29; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | " diff --git a/.github/workflows/docker-builds.yml b/.github/workflows/docker-builds.yml index 0eec1556bb96..f732dab42050 100644 --- a/.github/workflows/docker-builds.yml +++ b/.github/workflows/docker-builds.yml @@ -38,19 +38,19 @@ jobs: matrix: runner: [linux.12xlarge] docker-image-name: [ - pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9, - pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9-inductor-benchmarks, - pytorch-linux-focal-cuda12.4-cudnn8-py3.12-gcc9-inductor-benchmarks, - pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9, - pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9-inductor-benchmarks, - pytorch-linux-focal-cuda12.1-cudnn8-py3.12-gcc9-inductor-benchmarks, - pytorch-linux-focal-cuda11.8-cudnn8-py3-gcc9, + pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9, + pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9-inductor-benchmarks, + pytorch-linux-focal-cuda12.4-cudnn9-py3.12-gcc9-inductor-benchmarks, + pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9, + pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks, + pytorch-linux-focal-cuda12.1-cudnn9-py3.12-gcc9-inductor-benchmarks, + pytorch-linux-focal-cuda11.8-cudnn9-py3-gcc9, pytorch-linux-focal-py3.8-clang10, pytorch-linux-focal-py3.11-clang10, pytorch-linux-focal-py3.12-clang10, pytorch-linux-focal-rocm-n-1-py3, pytorch-linux-focal-rocm-n-py3, - pytorch-linux-jammy-cuda11.8-cudnn8-py3.8-clang12, + pytorch-linux-jammy-cuda11.8-cudnn9-py3.8-clang12, pytorch-linux-focal-py3-clang9-android-ndk-r21e, pytorch-linux-jammy-py3.8-gcc11, pytorch-linux-jammy-py3.8-gcc11-inductor-benchmarks, @@ -58,7 +58,7 @@ jobs: pytorch-linux-jammy-py3-clang15-asan, pytorch-linux-focal-py3-clang10-onnx, pytorch-linux-focal-linter, - pytorch-linux-jammy-cuda11.8-cudnn8-py3.9-linter, + pytorch-linux-jammy-cuda11.8-cudnn9-py3.9-linter, pytorch-linux-jammy-py3-clang12-executorch ] include: diff --git a/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml index 726dbf40f985..a1a7e6fd9537 100644 --- a/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml @@ -54,7 +54,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_8-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_8-cpu-aarch64-test: # Testing @@ -162,7 +162,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_9-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_9-cpu-aarch64-test: # Testing @@ -270,7 +270,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_10-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cpu-aarch64-test: # Testing @@ -378,7 +378,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_11-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-cpu-aarch64-test: # Testing @@ -486,7 +486,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_12-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-cpu-aarch64-test: # Testing diff --git a/.github/workflows/generated-linux-binary-manywheel-main.yml b/.github/workflows/generated-linux-binary-manywheel-main.yml index 6e7edae7b613..053877b1c90e 100644 --- a/.github/workflows/generated-linux-binary-manywheel-main.yml +++ b/.github/workflows/generated-linux-binary-manywheel-main.yml @@ -48,7 +48,7 @@ jobs: DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cuda11_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==8.7.0.84; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_8-cuda11_8-test: # Testing @@ -88,7 +88,7 @@ jobs: DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cuda12_1 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_8-cuda12_1-test: # Testing @@ -128,7 +128,7 @@ jobs: DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cuda12_4 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.7.29; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_8-cuda12_4-test: # Testing diff --git a/.github/workflows/generated-linux-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-binary-manywheel-nightly.yml index 8ad43b4c3660..03e3e3f4db20 100644 --- a/.github/workflows/generated-linux-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-binary-manywheel-nightly.yml @@ -174,7 +174,7 @@ jobs: DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cuda11_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==8.7.0.84; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_8-cuda11_8-test: # Testing @@ -237,7 +237,7 @@ jobs: DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cuda12_1 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_8-cuda12_1-test: # Testing @@ -300,7 +300,7 @@ jobs: DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cuda12_4 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.7.29; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_8-cuda12_4-test: # Testing @@ -690,7 +690,7 @@ jobs: DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda11_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==8.7.0.84; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_9-cuda11_8-test: # Testing @@ -753,7 +753,7 @@ jobs: DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_1 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_9-cuda12_1-test: # Testing @@ -816,7 +816,7 @@ jobs: DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_4 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.7.29; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_9-cuda12_4-test: # Testing @@ -1206,7 +1206,7 @@ jobs: DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda11_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==8.7.0.84; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cuda11_8-test: # Testing @@ -1269,7 +1269,7 @@ jobs: DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda12_1 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cuda12_1-test: # Testing @@ -1332,7 +1332,7 @@ jobs: DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda12_4 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.7.29; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cuda12_4-test: # Testing @@ -1722,7 +1722,7 @@ jobs: DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda11_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==8.7.0.84; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-cuda11_8-test: # Testing @@ -1785,7 +1785,7 @@ jobs: DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda12_1 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-cuda12_1-test: # Testing @@ -1848,7 +1848,7 @@ jobs: DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda12_4 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.7.29; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-cuda12_4-test: # Testing @@ -2238,7 +2238,7 @@ jobs: DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda11_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==8.7.0.84; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-cuda11_8-test: # Testing @@ -2301,7 +2301,7 @@ jobs: DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda12_1 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-cuda12_1-test: # Testing @@ -2364,7 +2364,7 @@ jobs: DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda12_4 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.7.29; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-cuda12_4-test: # Testing diff --git a/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml index 4f0569c253f2..db0748463da5 100644 --- a/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml @@ -54,7 +54,7 @@ jobs: ALPINE_IMAGE: "docker.io/s390x/alpine" build_name: manywheel-py3_8-cpu-s390x build_environment: linux-s390x-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_8-cpu-s390x-test: # Testing @@ -117,7 +117,7 @@ jobs: ALPINE_IMAGE: "docker.io/s390x/alpine" build_name: manywheel-py3_9-cpu-s390x build_environment: linux-s390x-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_9-cpu-s390x-test: # Testing @@ -180,7 +180,7 @@ jobs: ALPINE_IMAGE: "docker.io/s390x/alpine" build_name: manywheel-py3_10-cpu-s390x build_environment: linux-s390x-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cpu-s390x-test: # Testing @@ -243,7 +243,7 @@ jobs: ALPINE_IMAGE: "docker.io/s390x/alpine" build_name: manywheel-py3_11-cpu-s390x build_environment: linux-s390x-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-cpu-s390x-test: # Testing @@ -306,7 +306,7 @@ jobs: ALPINE_IMAGE: "docker.io/s390x/alpine" build_name: manywheel-py3_12-cpu-s390x build_environment: linux-s390x-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-cpu-s390x-test: # Testing diff --git a/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml b/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml index 94a8fd9cd3de..b4910d46ed5e 100644 --- a/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml +++ b/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml @@ -46,7 +46,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.8" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' # For sccache access (only on non-forked PRs) AWS_ACCESS_KEY_ID: ${{ secrets.MACOS_SCCACHE_S3_ACCESS_KEY_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets.MACOS_SCCACHE_S3_SECRET_ACCESS_KEY }} @@ -165,7 +165,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.9" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' # For sccache access (only on non-forked PRs) AWS_ACCESS_KEY_ID: ${{ secrets.MACOS_SCCACHE_S3_ACCESS_KEY_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets.MACOS_SCCACHE_S3_SECRET_ACCESS_KEY }} @@ -284,7 +284,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' # For sccache access (only on non-forked PRs) AWS_ACCESS_KEY_ID: ${{ secrets.MACOS_SCCACHE_S3_ACCESS_KEY_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets.MACOS_SCCACHE_S3_SECRET_ACCESS_KEY }} @@ -403,7 +403,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' # For sccache access (only on non-forked PRs) AWS_ACCESS_KEY_ID: ${{ secrets.MACOS_SCCACHE_S3_ACCESS_KEY_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets.MACOS_SCCACHE_S3_SECRET_ACCESS_KEY }} @@ -522,7 +522,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' # For sccache access (only on non-forked PRs) AWS_ACCESS_KEY_ID: ${{ secrets.MACOS_SCCACHE_S3_ACCESS_KEY_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets.MACOS_SCCACHE_S3_SECRET_ACCESS_KEY }} diff --git a/.github/workflows/generated-windows-binary-wheel-nightly.yml b/.github/workflows/generated-windows-binary-wheel-nightly.yml index d64c221e7895..d06f99bd9a5a 100644 --- a/.github/workflows/generated-windows-binary-wheel-nightly.yml +++ b/.github/workflows/generated-windows-binary-wheel-nightly.yml @@ -46,7 +46,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.8" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -290,7 +290,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.8" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -536,7 +536,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.8" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -782,7 +782,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.8" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -1027,7 +1027,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.9" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -1271,7 +1271,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.9" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -1517,7 +1517,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.9" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -1763,7 +1763,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.9" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -2008,7 +2008,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -2252,7 +2252,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -2498,7 +2498,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -2744,7 +2744,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -2989,7 +2989,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -3233,7 +3233,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -3479,7 +3479,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -3725,7 +3725,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -3970,7 +3970,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -4214,7 +4214,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -4460,7 +4460,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -4706,7 +4706,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash diff --git a/.github/workflows/inductor-micro-benchmark.yml b/.github/workflows/inductor-micro-benchmark.yml index 4fe0ddf50ef2..431545ea6d0d 100644 --- a/.github/workflows/inductor-micro-benchmark.yml +++ b/.github/workflows/inductor-micro-benchmark.yml @@ -21,7 +21,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm80 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9-inductor-benchmarks + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.0' test-matrix: | { include: [ diff --git a/.github/workflows/inductor-perf-compare.yml b/.github/workflows/inductor-perf-compare.yml index e485a8bfce1b..a5e4ad1781aa 100644 --- a/.github/workflows/inductor-perf-compare.yml +++ b/.github/workflows/inductor-perf-compare.yml @@ -18,7 +18,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm80 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9-inductor-benchmarks + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.0' test-matrix: | { include: [ diff --git a/.github/workflows/inductor-perf-test-nightly.yml b/.github/workflows/inductor-perf-test-nightly.yml index e77c915749f3..2f129c52fe13 100644 --- a/.github/workflows/inductor-perf-test-nightly.yml +++ b/.github/workflows/inductor-perf-test-nightly.yml @@ -71,7 +71,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm80 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9-inductor-benchmarks + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.0' test-matrix: | { include: [ diff --git a/.github/workflows/inductor-periodic.yml b/.github/workflows/inductor-periodic.yml index 6f8c06ed030b..2fe649cebb5e 100644 --- a/.github/workflows/inductor-periodic.yml +++ b/.github/workflows/inductor-periodic.yml @@ -23,7 +23,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm86 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9-inductor-benchmarks + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.6' test-matrix: | { include: [ diff --git a/.github/workflows/inductor.yml b/.github/workflows/inductor.yml index 0f9c81104f9f..7ce641761f2e 100644 --- a/.github/workflows/inductor.yml +++ b/.github/workflows/inductor.yml @@ -44,7 +44,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm86 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9-inductor-benchmarks + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.6' test-matrix: | { include: [ @@ -86,7 +86,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm80 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9-inductor-benchmarks + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.0' test-matrix: | { include: [ @@ -112,7 +112,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3.12-gcc9-sm86 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3.12-gcc9-inductor-benchmarks + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3.12-gcc9-inductor-benchmarks cuda-arch-list: '8.6' test-matrix: | { include: [ @@ -133,7 +133,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm86 - docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9-inductor-benchmarks + docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.6' test-matrix: | { include: [ @@ -175,7 +175,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm80 - docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9-inductor-benchmarks + docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.0' test-matrix: | { include: [ @@ -189,7 +189,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.4-py3.12-gcc9-sm86 - docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3.12-gcc9-inductor-benchmarks + docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3.12-gcc9-inductor-benchmarks cuda-arch-list: '8.6' test-matrix: | { include: [ diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index f1b6611d00e0..e0e4d3c20cd8 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -20,7 +20,7 @@ jobs: with: timeout: 120 runner: linux.2xlarge - docker-image: pytorch-linux-jammy-cuda11.8-cudnn8-py3.9-linter + docker-image: pytorch-linux-jammy-cuda11.8-cudnn9-py3.9-linter # NB: A shallow checkout won't work here because calculate-docker-image requires a full checkout # to run git rev-parse HEAD~:.ci/docker when a new image is needed fetch-depth: 0 @@ -36,7 +36,7 @@ jobs: with: timeout: 120 runner: linux.2xlarge - docker-image: pytorch-linux-jammy-cuda11.8-cudnn8-py3.9-linter + docker-image: pytorch-linux-jammy-cuda11.8-cudnn9-py3.9-linter # NB: A shallow checkout won't work here because calculate-docker-image requires a full checkout # to run git rev-parse HEAD~:.ci/docker when a new image is needed fetch-depth: 0 diff --git a/.github/workflows/periodic.yml b/.github/workflows/periodic.yml index 925bca54c074..b2f404da6d65 100644 --- a/.github/workflows/periodic.yml +++ b/.github/workflows/periodic.yml @@ -67,7 +67,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda11.8-py3.9-gcc9 - docker-image-name: pytorch-linux-focal-cuda11.8-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda11.8-cudnn9-py3-gcc9 cuda-arch-list: 8.6 test-matrix: | { include: [ @@ -89,7 +89,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda11.8-py3.10-gcc9-debug - docker-image-name: pytorch-linux-focal-cuda11.8-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda11.8-cudnn9-py3-gcc9 build-with-debug: true test-matrix: | { include: [ diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 2b81e998bde5..71f1e11094e2 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -237,7 +237,7 @@ jobs: uses: ./.github/workflows/_linux-build-label.yml with: build-environment: linux-focal-cuda11.8-py3.10-gcc9 - docker-image-name: pytorch-linux-focal-cuda11.8-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda11.8-cudnn9-py3-gcc9 test-matrix: | { include: [ { config: "distributed", shard: 1, num_shards: 3, runner: "linux.8xlarge.nvidia.gpu" }, @@ -262,7 +262,7 @@ jobs: uses: ./.github/workflows/_linux-build-label.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" }, @@ -290,7 +290,7 @@ jobs: uses: ./.github/workflows/_linux-build-label.yml with: build-environment: linux-focal-cuda12.4-py3.10-gcc9 - docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9 test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" }, @@ -325,12 +325,12 @@ jobs: { config: "default", shard: 1, num_shards: 1 }, ]} - linux-jammy-cuda-11_8-cudnn8-py3_8-clang12-build: - name: linux-jammy-cuda11.8-cudnn8-py3.8-clang12 + linux-jammy-cuda-11_8-cudnn9-py3_8-clang12-build: + name: linux-jammy-cuda11.8-cudnn9-py3.8-clang12 uses: ./.github/workflows/_linux-build-label.yml with: - build-environment: linux-jammy-cuda11.8-cudnn8-py3.8-clang12 - docker-image-name: pytorch-linux-jammy-cuda11.8-cudnn8-py3.8-clang12 + build-environment: linux-jammy-cuda11.8-cudnn9-py3.8-clang12 + docker-image-name: pytorch-linux-jammy-cuda11.8-cudnn9-py3.8-clang12 test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 1 }, @@ -389,7 +389,7 @@ jobs: uses: ./.github/workflows/_bazel-build-test.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-bazel-test - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 cuda-version: cpu test-matrix: | { include: [ @@ -401,7 +401,7 @@ jobs: uses: ./.github/workflows/_bazel-build-test.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-bazel-test - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 cuda-version: "12.1" test-matrix: | { include: [ @@ -413,7 +413,7 @@ jobs: uses: ./.github/workflows/_bazel-build-test.yml with: build-environment: linux-focal-cuda12.4-py3.10-gcc9-bazel-test - docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9 cuda-version: "12.4" test-matrix: | { include: [ @@ -475,7 +475,7 @@ jobs: uses: ./.github/workflows/_linux-build-label.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm86 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 cuda-arch-list: 8.6 test-matrix: | { include: [ @@ -502,7 +502,7 @@ jobs: uses: ./.github/workflows/_linux-build-label.yml with: build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm86 - docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9 cuda-arch-list: 8.6 test-matrix: | { include: [ diff --git a/.github/workflows/slow.yml b/.github/workflows/slow.yml index 31db7af8fc55..50f74b01f08c 100644 --- a/.github/workflows/slow.yml +++ b/.github/workflows/slow.yml @@ -41,7 +41,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3-gcc9-slow-gradcheck - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 cuda-arch-list: 8.6 test-matrix: | { include: [ @@ -70,7 +70,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm86 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 cuda-arch-list: 8.6 test-matrix: | { include: [ diff --git a/.github/workflows/target-determination-indexer.yml b/.github/workflows/target-determination-indexer.yml index 0ce1bae6a413..e8bf91c8d9ee 100644 --- a/.github/workflows/target-determination-indexer.yml +++ b/.github/workflows/target-determination-indexer.yml @@ -26,7 +26,7 @@ jobs: id: calculate-docker-image uses: pytorch/test-infra/.github/actions/calculate-docker-image@main with: - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 working-directory: pytorch - name: Use following to pull public copy of the image diff --git a/.github/workflows/torchbench.yml b/.github/workflows/torchbench.yml index 73befe34c078..ac5814966899 100644 --- a/.github/workflows/torchbench.yml +++ b/.github/workflows/torchbench.yml @@ -16,7 +16,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm80 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9-inductor-benchmarks + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.0' test-matrix: | { include: [ diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index a91238fa2c9b..e727445d3ada 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -39,7 +39,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 test-matrix: | { include: [ { config: "nogpu_AVX512", shard: 1, num_shards: 1, runner: "linux.2xlarge" }, @@ -63,7 +63,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: libtorch-linux-focal-cuda12.1-py3.7-gcc9 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 build-generates-artifacts: false runner: linux.4xlarge test-matrix: | @@ -77,7 +77,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-no-ops - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 1 }, @@ -88,7 +88,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.4-py3.10-gcc9 - docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9 test-matrix: | { include: [ { config: "nogpu_AVX512", shard: 1, num_shards: 1, runner: "linux.2xlarge" }, @@ -112,7 +112,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: libtorch-linux-focal-cuda12.4-py3.7-gcc9 - docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9 build-generates-artifacts: false runner: linux.4xlarge test-matrix: | @@ -126,7 +126,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.4-py3.10-gcc9-no-ops - docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9 test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 1 }, diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_timm_training.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_timm_training.csv index 1def1d99bd53..fe7efa082cea 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_timm_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_timm_training.csv @@ -218,7 +218,7 @@ tf_mixnet_l,pass,6 -tinynet_a,pass,6 +tinynet_a,fail_accuracy,6 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_torchbench_training.csv index a3c9c3915fc5..ee58808c0bb0 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_torchbench_training.csv @@ -182,7 +182,7 @@ phlippe_densenet,pass,6 -phlippe_resnet,fail_accuracy,6 +phlippe_resnet,pass,6 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_torchbench_training.csv index 02411bef6cc5..cfc524426644 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_torchbench_training.csv @@ -182,7 +182,7 @@ phlippe_densenet,pass,6 -phlippe_resnet,fail_accuracy,6 +phlippe_resnet,pass,6 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_timm_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_timm_training.csv index 1def1d99bd53..fe7efa082cea 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_timm_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_timm_training.csv @@ -218,7 +218,7 @@ tf_mixnet_l,pass,6 -tinynet_a,pass,6 +tinynet_a,fail_accuracy,6 diff --git a/docker.Makefile b/docker.Makefile index a33c411907bc..7f131707e7ab 100644 --- a/docker.Makefile +++ b/docker.Makefile @@ -10,7 +10,7 @@ endif CUDA_VERSION_SHORT ?= 12.1 CUDA_VERSION ?= 12.1.1 -CUDNN_VERSION ?= 8 +CUDNN_VERSION ?= 9 BASE_RUNTIME = ubuntu:22.04 BASE_DEVEL = nvidia/cuda:$(CUDA_VERSION)-devel-ubuntu22.04 CMAKE_VARS ?= From ff32f6c93b76b1e161876023948c40e241c23730 Mon Sep 17 00:00:00 2001 From: "Tugsbayasgalan (Tugsuu) Manlaibaatar" Date: Tue, 4 Jun 2024 16:54:23 +0000 Subject: [PATCH 325/706] Use freshly traced jit-traced module to be used in export analysis (#127577) Summary: When we export already traced module, it seems to be modifying some global state causing the traced modules to fail to run. For now, we are only logging for test cases, so it is probs ok to trace fresh copy to be used in export for now. Test Plan: CI Differential Revision: D57983518 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127577 Approved by: https://github.com/pianpwk --- torch/export/_trace.py | 17 +++++++++-------- torch/jit/_trace.py | 30 +++++++++++++++++++++++++----- 2 files changed, 34 insertions(+), 13 deletions(-) diff --git a/torch/export/_trace.py b/torch/export/_trace.py index 976fddf0c874..0ed664f43fc9 100644 --- a/torch/export/_trace.py +++ b/torch/export/_trace.py @@ -999,17 +999,17 @@ def _temp_disable_texpr_fuser(): torch._C._jit_set_texpr_fuser_enabled(original_state) -def _convert_ts_to_export_experimental(traced_callable, args, kwargs=None): - with _temp_disable_texpr_fuser(): +class _WrapperModule(torch.nn.Module): + def __init__(self, f): + super().__init__() + self.f = f - class _WrapperModule(torch.nn.Module): - def __init__(self, f): - super().__init__() - self.f = f + def forward(self, *args, **kwargs): + return self.f(*args, **kwargs) - def forward(self, *args, **kwargs): - return self.f(*args, **kwargs) +def _convert_ts_to_export_experimental(traced_callable, args, kwargs=None): + with _temp_disable_texpr_fuser(): from torch.jit._trace import TopLevelTracedModule export_args, export_kwargs = _process_jit_trace_inputs_for_export(args, kwargs) @@ -1034,6 +1034,7 @@ def forward(self, *args, **kwargs): strict=False, _is_torch_jit_trace=True, ).module() + else: return _export( _WrapperModule(traced_callable), diff --git a/torch/jit/_trace.py b/torch/jit/_trace.py index 17914a5a444d..8be700ee7711 100644 --- a/torch/jit/_trace.py +++ b/torch/jit/_trace.py @@ -646,6 +646,11 @@ def analyze_ts_result_with_export_result(export, trace): flat_trace = pytree.tree_leaves(trace) for orig, loaded in zip(flat_export, flat_trace): + if orig.layout != loaded.layout: + return False + # mkldnn is not supported for torch.allclose + if orig.layout == torch._mkldnn: # type: ignore[attr-defined] + return True if type(orig) != type(loaded): return False @@ -1013,6 +1018,21 @@ def forward(self, x): _process_jit_trace_inputs_for_export, ) + traced_func_for_export = _trace_impl( + func, + example_inputs=example_inputs, + optimize=optimize, + check_trace=False, + check_inputs=check_inputs, + check_tolerance=check_tolerance, + strict=strict, + _force_outplace=_force_outplace, + _module_class=_module_class, + _compilation_unit=_compilation_unit, + example_kwarg_inputs=example_kwarg_inputs, + _store_inputs=_store_inputs, + ) + export_args, _ = _process_jit_trace_inputs_for_export( example_inputs, example_kwarg_inputs ) @@ -1038,7 +1058,7 @@ def _log_exportability(func_to_export, export_func, export_args, export_type): return try: - traced_result = traced_func(*export_args) + traced_result = func_to_export(*export_args) except Exception as e: _ = e log_torch_jit_trace_exportability( @@ -1066,22 +1086,22 @@ def _convert_ts_to_export_source_to_source(func, export_args): return TS2EPConverter(func, export_args).convert().module() # torch.jit.trace is noop when the original module is torch.jit.ScriptModule - if not isinstance(traced_func, torch.jit.ScriptModule): + if not isinstance(traced_func_for_export, torch.jit.ScriptModule): _log_exportability( - traced_func, + traced_func_for_export, _direct_export_and_lower, export_args, _ExportType.DIRECT_EXPORT, ) _log_exportability( - traced_func, + traced_func_for_export, _convert_ts_to_export_experimental, export_args, _ExportType.TRACE_AND_EXPORT, ) _log_exportability( - traced_func, + traced_func_for_export, _convert_ts_to_export_source_to_source, export_args, _ExportType.SOURCE_TO_SOURCE, From 36e9b7161362cc89fd69eb55924b4e48efc58181 Mon Sep 17 00:00:00 2001 From: Hu Niu Date: Tue, 4 Jun 2024 16:56:01 +0000 Subject: [PATCH 326/706] Enable UFMT on test/test_jit_fuser_te.py (#127759) Part of #123062 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127759 Approved by: https://github.com/ezyang --- .lintrunner.toml | 1 - test/test_jit_fuser_te.py | 898 ++++++++++++++++++++++++-------------- 2 files changed, 572 insertions(+), 327 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index abf8ed9e28dd..033414d8bbc8 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1072,7 +1072,6 @@ exclude_patterns = [ 'test/test_jit_disabled.py', 'test/test_jit_fuser.py', 'test/test_jit_fuser_legacy.py', - 'test/test_jit_fuser_te.py', 'test/test_jit_legacy.py', 'test/test_jit_llga_fuser.py', 'test/test_jit_profiling.py', diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py index 071249192ec6..7b087d361d8b 100644 --- a/test/test_jit_fuser_te.py +++ b/test/test_jit_fuser_te.py @@ -1,15 +1,16 @@ # Owner(s): ["NNC"] +import contextlib +import math import operator import os import unittest -import contextlib -import math +import warnings +from typing import List + import torch import torch.nn.functional as F from torch.testing import FileCheck -from typing import List -import warnings # these needs to be set before `common_utils` # infers `GRAPH_EXECUTOR`. @@ -20,42 +21,79 @@ torch._C._jit_set_profiling_executor(True) torch._C._get_graph_executor_optimize(True) -from torch.testing._internal.common_utils import run_tests, ProfilingMode, GRAPH_EXECUTOR, \ - enable_profiling_mode_for_profiling_tests, slowTest, skipIfTorchDynamo, TEST_WITH_ASAN, \ - TEST_WITH_ROCM, IS_FBCODE -from torch.testing._internal.jit_utils import JitTestCase, \ - RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, warmup_backward, set_fusion_group_inlining, \ - clone_inputs, get_traced_sample_variant_pairs, TensorExprTestOptions, NoTracerWarnContextManager - -from torch.testing._internal.common_methods_invocations import op_db -from torch.testing._internal.common_device_type import ops, onlyCPU, instantiate_device_type_tests, \ - OpDTypes -from torch.testing._internal.common_jit import JitCommonTestCase -from torch.testing._internal.jit_metaprogramming_utils import create_traced_fn +from itertools import combinations, permutations, product from textwrap import dedent -from itertools import product, permutations, combinations - -from test_jit import backward_graph, get_lstm_inputs, get_milstm_inputs, \ - LSTMCellC, LSTMCellF, LSTMCellS, MiLSTMCell from jit.test_fuser_common import TestFuserCommon # noqa: F401 -FUSION_GROUP = 'prim::TensorExprGroup' +from test_jit import ( + backward_graph, + get_lstm_inputs, + get_milstm_inputs, + LSTMCellC, + LSTMCellF, + LSTMCellS, + MiLSTMCell, +) + +from torch.testing._internal.common_device_type import ( + instantiate_device_type_tests, + onlyCPU, + OpDTypes, + ops, +) +from torch.testing._internal.common_jit import JitCommonTestCase + +from torch.testing._internal.common_methods_invocations import op_db +from torch.testing._internal.common_utils import ( + enable_profiling_mode_for_profiling_tests, + GRAPH_EXECUTOR, + IS_FBCODE, + ProfilingMode, + run_tests, + skipIfTorchDynamo, + slowTest, + TEST_WITH_ASAN, + TEST_WITH_ROCM, +) +from torch.testing._internal.jit_metaprogramming_utils import create_traced_fn +from torch.testing._internal.jit_utils import ( + clone_inputs, + get_traced_sample_variant_pairs, + JitTestCase, + NoTracerWarnContextManager, + RUN_CUDA, + RUN_CUDA_HALF, + RUN_CUDA_MULTI_GPU, + set_fusion_group_inlining, + TensorExprTestOptions, + warmup_backward, +) + +FUSION_GROUP = "prim::TensorExprGroup" LLVM_ENABLED = torch._C._llvm_enabled() -autograd_check_set = {'aten::__is__', 'prim::AutogradAllNonZero', 'prim::AutogradAllZero', 'prim::ListConstruct'} +autograd_check_set = { + "aten::__is__", + "prim::AutogradAllNonZero", + "prim::AutogradAllZero", + "prim::ListConstruct", +} + def strip_profiling_nodes(nodes): - profiling_opcodes = {'prim::BailoutTemplate', 'prim::BailOut'} + profiling_opcodes = {"prim::BailoutTemplate", "prim::BailOut"} return [n for n in nodes if n.kind() not in profiling_opcodes] + def warmup_forward(f, *args, profiling_count=2): for i in range(profiling_count): results = f(*args) return results + @contextlib.contextmanager def texpr_reductions_enabled(): old = torch._C._jit_set_texpr_reductions_enabled(True) @@ -64,6 +102,7 @@ def texpr_reductions_enabled(): finally: torch._C._jit_set_texpr_reductions_enabled(old) + @contextlib.contextmanager def texpr_enable_strategy(strategy): old = torch._C._jit_set_fusion_strategy(strategy) @@ -72,6 +111,7 @@ def texpr_enable_strategy(strategy): finally: torch._C._jit_set_fusion_strategy(old) + @contextlib.contextmanager def inline_fusion_groups(): old_inlining = torch._C._debug_get_fusion_group_inlining() @@ -93,7 +133,7 @@ def setUp(self): fusion_strategy = [("DYNAMIC", 20)] if self.dynamic_shapes else [("STATIC", 20)] self.old_fusion_strategy = torch._C._jit_set_fusion_strategy(fusion_strategy) - self.devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] + self.devices = ["cpu"] if not torch.cuda.is_available() else ["cpu", "cuda"] self.int_dtypes = [ torch.int8, torch.int16, @@ -117,7 +157,11 @@ def tearDown(self): def assertAllFused(self, graph, except_for=None): except_for = except_for if except_for is not None else set() # TODO - upstream - guards = "prim::TypeCheck", "prim::RequiresGradCheck", "prim::TensorExprDynamicGuard" + guards = ( + "prim::TypeCheck", + "prim::RequiresGradCheck", + "prim::TensorExprDynamicGuard", + ) guard_found = False def autodiff_guard(node): @@ -128,7 +172,10 @@ def autodiff_guard(node): return False li_inps = list(inps[0].node().inputs()) for li_inp in li_inps: - if li_inp.node().kind() in ("prim::AutogradAllNonZero", "prim::AutogradAllZero"): + if li_inp.node().kind() in ( + "prim::AutogradAllNonZero", + "prim::AutogradAllZero", + ): return True return False @@ -151,7 +198,6 @@ def is_guard(node): self.assertTrue(guard_found) - def assertLastGraphAllFused(self): self.assertAllFused(torch.jit.last_executed_optimized_graph()) @@ -159,7 +205,7 @@ def findFusionGroups(self, graph): result = [] for n in graph.nodes(): if n.kind() == FUSION_GROUP: - result.append(n.g('Subgraph')) + result.append(n.g("Subgraph")) continue for block in n.blocks(): result += self.findFusionGroups(block) @@ -169,7 +215,7 @@ def test_typecheck(self): a = torch.ones(1) def fused_kernel(a, b): - return (a + b) * 2. + return (a + b) * 2.0 scripted = self.checkScript(fused_kernel, (a, a)) graph = scripted.graph_for(a, a) @@ -191,7 +237,7 @@ def func(x): return x2.sum() with texpr_reductions_enabled(): - a = torch.tensor(list(range(0, 15)), dtype=torch.float, device='cpu') + a = torch.tensor(list(range(0, 15)), dtype=torch.float, device="cpu") a = a.reshape(5, 3) scripted = self.checkScript(func, (a,)) self.assertLastGraphAllFused() @@ -201,13 +247,13 @@ def test_nop(self): def test_sum_dim(self): def func(x): - return x.sum((0, )) * 2 + return x.sum((0,)) * 2 def func_neg(x): - return x.sum((-2, )) * 2 + return x.sum((-2,)) * 2 with texpr_reductions_enabled(): - a = torch.tensor(list(range(0, 15)), dtype=torch.float, device='cpu') + a = torch.tensor(list(range(0, 15)), dtype=torch.float, device="cpu") a = a.reshape(5, 3) scripted = self.checkScript(func, (a,)) self.assertLastGraphAllFused() @@ -216,10 +262,10 @@ def func_neg(x): def test_sum_keepdim_cast(self): def func(x): - return x.sum((0, ), keepdim=True, dtype=torch.double) * 2 + return x.sum((0,), keepdim=True, dtype=torch.double) * 2 with texpr_reductions_enabled(): - a = torch.tensor(list(range(0, 15)), dtype=torch.float, device='cpu') + a = torch.tensor(list(range(0, 15)), dtype=torch.float, device="cpu") a = a.reshape(5, 3) self.checkScript(func, (a,)) @@ -227,6 +273,7 @@ def func(x): def test_abs(self): for device in self.devices: + def func(x): return x.abs() * 2 @@ -236,19 +283,24 @@ def func(x): def test_unsqueeze_size_calculation(self): for device in self.devices: + def foo(b, d): x = d.unsqueeze(1) - y = x * 42. + y = x * 42.0 z = b + y - r = z / 42. + r = z / 42.0 return r - inputs = (torch.rand(20, 28, device=device, requires_grad=True), torch.rand(20, device=device)) + inputs = ( + torch.rand(20, 28, device=device, requires_grad=True), + torch.rand(20, device=device), + ) scripted = self.checkScript(foo, inputs) self.assertAllFused(scripted.graph_for(*inputs)) def test_zero_element_tensors(self): for device in self.devices: + def decode(sin_t, cos_t): theta = torch.atan2(sin_t.float(), cos_t.float()) return theta @@ -267,17 +319,25 @@ def test_arg_configurations_smoke(self): # TODO: add optionally enabled debug counters to the fuser to verify # that we really can tell the difference between configurations for device in self.devices: + def f(x, y): z1, z2 = (x + y).chunk(2, dim=1) return z1 * z2 x = torch.randn(4, 4, dtype=torch.float, device=device) y = torch.randn(4, 4, dtype=torch.float, device=device) - traced_f = torch.jit.trace(f, (x, y,)) + traced_f = torch.jit.trace( + f, + ( + x, + y, + ), + ) self.assertEqual(traced_f(x.t().contiguous(), y), traced_f(x.t(), y)) def test_broadcast(self): for device in self.devices: + def scaleshift(x, scale, shift): return x * scale + shift @@ -290,16 +350,14 @@ def scaleshift(x, scale, shift): @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") @unittest.skipIf(not RUN_CUDA_HALF, "no half support") - @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on") + @unittest.skipIf( + GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on" + ) def test_cuda_half(self): - x = torch.randn(4, 4, dtype=torch.half, device='cuda') - y = torch.randn(4, 4, dtype=torch.half, device='cuda') + x = torch.randn(4, 4, dtype=torch.half, device="cuda") + y = torch.randn(4, 4, dtype=torch.half, device="cuda") - funcs = [ - self.fn_test_comparison_gt_lt, - self.fn_test_relu, - self.fn_test_exp - ] + funcs = [self.fn_test_comparison_gt_lt, self.fn_test_relu, self.fn_test_exp] # Note: Non fused inputs must be float to prevent loss of precision inputs = (x.float(), y.float()) @@ -318,9 +376,17 @@ def test_cuda_half(self): # Verifies gradients for output, fusion_output in zip(outputs_half, fusion_outputs): grads = torch.autograd.grad( - output.float().sum(), local_inputs, allow_unused=True, retain_graph=True) + output.float().sum(), + local_inputs, + allow_unused=True, + retain_graph=True, + ) fusion_grads = torch.autograd.grad( - fusion_output.sum(), local_fusion_inputs, allow_unused=True, retain_graph=True) + fusion_output.sum(), + local_fusion_inputs, + allow_unused=True, + retain_graph=True, + ) grads_half = [t.half() for t in grads] self.assertEqual(grads_half, fusion_grads) @@ -332,7 +398,7 @@ def test_checks_cat_inputs(self): # need to be checked for having the same map size, before we can # run the kernel. def f(x, y): - return torch.cat([x + 2 * x + x ** 2, y + 4 * y + y ** 3], dim=0) + return torch.cat([x + 2 * x + x**2, y + 4 * y + y**3], dim=0) # NOTE: y is broadcastable to x, but output of f(x, y) should have # shape 3x4, and not 4x4. @@ -348,6 +414,7 @@ def test_chunk(self): self.skipTest("TODO: chunk dynamic shapes") for device in self.devices: + def fn(x): a, b, c = x.chunk(3, 1) return a * b + c @@ -362,6 +429,7 @@ def test_chunk_correctness(self): self.skipTest("TODO: chunk dynamic shapes") for device in self.devices: + def chunk_4_0(x): x0, x1, x2, x3 = x.chunk(4, 0) return x0 + x1 + x2 + x3 @@ -378,12 +446,12 @@ def chunk_4_last(x): tensors = [ # splitSize = 1 torch.randn(4, 4, 4, dtype=torch.float, device=device), - # contiguous case torch.randn(12, 8, 16, dtype=torch.float, device=device), - # non-contiguous case - torch.randn(12, 8, 16, dtype=torch.float, device=device).transpose(1, 2), + torch.randn(12, 8, 16, dtype=torch.float, device=device).transpose( + 1, 2 + ), ] for tensor in tensors: @@ -399,6 +467,7 @@ def test_chunk_distributes(self): self.skipTest("TODO: chunk dynamic shapes") for device in self.devices: + def f(x, y): z1, z2 = (x + y).chunk(2, dim=1) return z1 * z2 @@ -420,6 +489,7 @@ def test_chunk_motion_deduplicates_inputs(self): self.skipTest("TODO: chunk dynamic shapes") for device in self.devices: + def func1(x): z = x * x z0, z1 = z.chunk(2) @@ -462,6 +532,7 @@ def fn(s, x, y, z): def test_minmax(self): for device in self.devices: + def tmax(a, b): return torch.max(2 * a, b) @@ -470,26 +541,26 @@ def tmin(a, b): a = torch.randn(4, 4, dtype=torch.float) b = torch.randn(4, 4, dtype=torch.float) - nan = torch.tensor(float('nan'), dtype=torch.float) + nan = torch.tensor(float("nan"), dtype=torch.float) for f, inputs, device in product( - (tmax, tmin), - ([a, b], [a, nan], [b, nan]), - self.devices): + (tmax, tmin), ([a, b], [a, nan], [b, nan]), self.devices + ): inputs = [t.to(device) for t in inputs] s = self.checkScript(f, inputs) self.assertAllFused(s.graph_for(*inputs)) def test_clamp(self): for device in self.devices: + def func2(a, b): return torch.clamp(a + b, min=0, max=2) def funcInf(a, b): - return torch.clamp(a + b, min=0, max=float('inf')) + return torch.clamp(a + b, min=0, max=float("inf")) def funcNegInf(a, b): - return torch.clamp(a + b, min=float('-inf'), max=0) + return torch.clamp(a + b, min=float("-inf"), max=0) def funcOptMin(a, b): return torch.clamp(a + b, max=2) @@ -499,31 +570,47 @@ def funcOptMax(a, b): a = torch.randn(4, 4, dtype=torch.float, device=device, requires_grad=True) b = torch.randn(4, 4, dtype=torch.float, device=device) - nan = torch.tensor(float('nan'), dtype=torch.float, device=device) + nan = torch.tensor(float("nan"), dtype=torch.float, device=device) funcs = (func2, funcInf, funcNegInf, funcOptMin, funcOptMax) for f, inputs in product(funcs, [[a, b], [a, nan]]): inp1, inp2 = inputs s = self.checkScript(f, (inp1, inp2), profiling=ProfilingMode.PROFILING) - self.assertAllFused(s.graph_for(inp1, inp2), except_for={'aten::size', 'aten::_size_if_not_equal'}) + self.assertAllFused( + s.graph_for(inp1, inp2), + except_for={"aten::size", "aten::_size_if_not_equal"}, + ) c = s(inp1, inp2) with enable_profiling_mode_for_profiling_tests(): warmup_backward(c.sum()) graph = backward_graph(s) - self.assertAllFused(graph, except_for={'aten::Float', 'aten::_grad_sum_to_size'}.union(autograd_check_set)) + self.assertAllFused( + graph, + except_for={"aten::Float", "aten::_grad_sum_to_size"}.union( + autograd_check_set + ), + ) def test_clamp_double(self): for device in self.devices: + def clamp_double(x, eta: float): return 1 - x.clamp(eta, 1 - eta) x = torch.tensor([1.0, 1.0], dtype=torch.double, device=device) eta = 1e-9 - s = self.checkScript(clamp_double, (x, eta), profiling=ProfilingMode.PROFILING, atol=1e-10, rtol=1e-5) - self.assertAllFused(s.graph_for(x, eta), except_for={'aten::sub'}) + s = self.checkScript( + clamp_double, + (x, eta), + profiling=ProfilingMode.PROFILING, + atol=1e-10, + rtol=1e-5, + ) + self.assertAllFused(s.graph_for(x, eta), except_for={"aten::sub"}) def test_clamp_int(self): for device in self.devices: + def clamp_int(x, eta: int): return x.clamp(0, eta) @@ -535,6 +622,7 @@ def clamp_int(x, eta: int): def test_add_bool(self): sizes = [(1,), (2,), (4, 4)] for device, size in product(self.devices, sizes): + def f(x, y, z): return x + y + z @@ -546,6 +634,7 @@ def f(x, y, z): def test_mul_bool(self): for device in self.devices: + def f(x, y, z): return x * y * z @@ -558,6 +647,7 @@ def f(x, y, z): def test_div_bool(self): for device in self.devices: + def f(x, y, z): return (x + y) / z @@ -605,10 +695,7 @@ def test_minmax_int_ops(self): def apply(fn): return lambda x, y, z: fn(fn(x, y), z) - binary_ops = [ - torch.min, - torch.max - ] + binary_ops = [torch.min, torch.max] devices = self.devices for dtype, op, device in product(self.int_dtypes, binary_ops, devices): try: @@ -633,6 +720,7 @@ def apply(fn): def test_comparison_eq_ne(self): for device in self.devices: + def f(x, y): mask = (x == 0).type_as(x) z = x * mask + y @@ -664,6 +752,7 @@ def test_comparison_gt_lt(self): def test_comparison_ge_le(self): for device in self.devices: + def f(x, y): mask = (x >= 0).type_as(x) z = x * mask + y @@ -678,8 +767,14 @@ def f(x, y): self.assertAllFused(ge.graph_for(x, y)) x.requires_grad_(True) y.requires_grad_(True) - self.assertAllFused(ge.graph_for(x, y), except_for=("aten::size", "prim::BroadcastSizes", - "aten::_size_if_not_equal")) + self.assertAllFused( + ge.graph_for(x, y), + except_for=( + "aten::size", + "prim::BroadcastSizes", + "aten::_size_if_not_equal", + ), + ) def test_addcmul(self): for device in self.devices: @@ -694,7 +789,9 @@ def foo(t, t1, t2): graph = ge.graph_for(t, t1, t2) fusion_groups = self.findFusionGroups(graph) self.assertEqual(len(fusion_groups), 1) - FileCheck().check("aten::add(").check("aten::addcmul(").run(str(fusion_groups[0])) + FileCheck().check("aten::add(").check("aten::addcmul(").run( + str(fusion_groups[0]) + ) # TODO: We leak CUDA memory here because the traced graph holds onto a # constant-ified tensor. Since the Python-global CompilationUnit is alive @@ -743,6 +840,7 @@ def foo(hx, cx): def test_remove_output_used_only_in_size(self): for device in self.devices: + def test_fuse(a, b): c = a + b d = c + b @@ -753,10 +851,10 @@ def test_fuse(a, b): y = torch.ones(1, requires_grad=True, device=device) warmup_forward(scripted_f, x, y, profiling_count=3) g = scripted_f.graph_for(x, y) - diff_nodes = g.findAllNodes('prim::DifferentiableGraph') + diff_nodes = g.findAllNodes("prim::DifferentiableGraph") self.assertEqual(len(diff_nodes), 1) - g = diff_nodes[0].g('Subgraph') - if_nodes = [n for n in g.nodes() if n.kind() == 'prim::If'] + g = diff_nodes[0].g("Subgraph") + if_nodes = [n for n in g.nodes() if n.kind() == "prim::If"] self.assertEqual(len(if_nodes), 1) # the if node and the fusion group inside it should only have one output @@ -777,13 +875,13 @@ def fn(x, y, z): z = torch.randn(4, 2, dtype=torch.float, device=device) ge = self.checkTrace(fn, (x, y, z)) graph = ge.graph_for(x, y, z) - self.assertAllFused(graph, except_for={'aten::add'}) + self.assertAllFused(graph, except_for={"aten::add"}) # XXX: TE fuser can handle concats inside a fusion group. # FileCheck().check("FusedConcat").check_next("return").run(str(graph)) @staticmethod def fn_test_exp(x, y): - return (x + .5 * y).exp() + return (x + 0.5 * y).exp() def test_exp(self): for device in self.devices: @@ -795,6 +893,7 @@ def test_exp(self): def test_threshold(self): for device in self.devices: + def f(x): return torch.threshold(x, 0, -10) + x + x + x @@ -804,6 +903,7 @@ def f(x): def test_scalar_arg(self): for device in self.devices: + def fn_test_scalar_arg(x: torch.Tensor, p: float) -> torch.Tensor: return p * (x * x + x) @@ -816,15 +916,23 @@ def fn_test_scalar_arg(x: torch.Tensor, p: float) -> torch.Tensor: # use another function otherwise we will bailout # and won't be able to do fused checks - def fn_test_scalar_arg_requires_grad(x: torch.Tensor, p: float) -> torch.Tensor: + def fn_test_scalar_arg_requires_grad( + x: torch.Tensor, p: float + ) -> torch.Tensor: return p * (x * x + x) scripted = torch.jit.script(fn_test_scalar_arg_requires_grad) out = scripted(x, p) out = scripted(x, p) out = scripted(x, p) - self.assertAllFused(scripted.graph_for(x, p), except_for=("aten::size", "prim::BroadcastSizes", - "aten::_size_if_not_equal")) + self.assertAllFused( + scripted.graph_for(x, p), + except_for=( + "aten::size", + "prim::BroadcastSizes", + "aten::_size_if_not_equal", + ), + ) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device") @@ -861,8 +969,8 @@ def fn(x, y, z): inputs = [ torch.randn(4, 4, dtype=torch.float), - torch.randn(4, 4, dtype=torch.float, device='cuda:0'), - torch.randn(4, 4, dtype=torch.float, device='cuda:1'), + torch.randn(4, 4, dtype=torch.float, device="cuda:0"), + torch.randn(4, 4, dtype=torch.float, device="cuda:1"), ] prev_cache_size = torch._C._jit_debug_fuser_num_cached_kernel_specs() @@ -870,8 +978,7 @@ def fn(x, y, z): # There are 3 FusionGroups. Because they have the same graph, they # should reuse the same KernelSpec in the KernelSpec cache. ge = self.checkScript(fn, inputs) - self.assertGraphContainsExactly( - ge.graph_for(*inputs), FUSION_GROUP, 3, True) + self.assertGraphContainsExactly(ge.graph_for(*inputs), FUSION_GROUP, 3, True) new_cache_size = torch._C._jit_debug_fuser_num_cached_kernel_specs() # XXX: This assumes that the same kernel isn't already used by another test # FIXME: Use the TE fuser's way of querying the cache. @@ -879,7 +986,7 @@ def fn(x, y, z): @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device") def test_nonzero_device_cuda(self): - device = 'cuda:' + str(1) + device = "cuda:" + str(1) x = torch.tensor([0.4], dtype=torch.float, device=device) y = torch.tensor([0.7], dtype=torch.float, device=device) @@ -893,7 +1000,9 @@ def test_lstm(self): for device in self.devices: inputs = get_lstm_inputs(device, training=True) module = self.checkScript(LSTMCellS, inputs) - self.assertAllFused(module.graph_for(inputs), except_for={"prim::TupleConstruct"}) + self.assertAllFused( + module.graph_for(inputs), except_for={"prim::TupleConstruct"} + ) def test_lstm_concat(self): # single fusion node causes error @@ -905,7 +1014,9 @@ def test_lstm_concat(self): except_nodes = {"prim::TupleConstruct", "aten::linear"} # TODO... Chunk if self.dynamic_shapes: - except_nodes = except_nodes.union({"aten::add", "prim::ConstantChunk"}) + except_nodes = except_nodes.union( + {"aten::add", "prim::ConstantChunk"} + ) self.assertAllFused(ge.graph_for(*inputs), except_for=except_nodes) # XXX: TE fuser can handle concats inside a fusion group. # FileCheck().check("FusedConcat").check_next("return").run(str(graph)) @@ -914,13 +1025,15 @@ def test_lstm_gates_permutations(self): for device in self.devices: # lstm has gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh. # Test that any permutation of this will still result in one FusionGroup. - choices = ['x.mm(w_ih.t())', 'hx.mm(w_hh.t())', 'b_ih', 'b_hh'] - template = dedent(''' + choices = ["x.mm(w_ih.t())", "hx.mm(w_hh.t())", "b_ih", "b_hh"] + template = dedent( + """ def cell(x, hx, cx, w_ih, w_hh, b_ih, b_hh): gates = {} + {} + {} + {} ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) return ingate * forgetgate * cellgate * outgate - ''') + """ + ) for permutation in permutations(choices, len(choices)): code = template.format(*permutation) scope = {} @@ -928,9 +1041,11 @@ def cell(x, hx, cx, w_ih, w_hh, b_ih, b_hh): cu = torch.jit.CompilationUnit(code) fusion_group_len = 2 if self.dynamic_shapes else 1 inputs = get_lstm_inputs(device, training=False) - self.assertEqual(cu.cell(*inputs), scope['cell'](*inputs)) + self.assertEqual(cu.cell(*inputs), scope["cell"](*inputs)) forward_graph = cu.cell.graph_for(*inputs) - self.assertGraphContainsExactly(forward_graph, FUSION_GROUP, fusion_group_len) + self.assertGraphContainsExactly( + forward_graph, FUSION_GROUP, fusion_group_len + ) # TODO: Fuser doesn't work at all when inputs require grad. Fix that def test_lstm_traced(self): @@ -945,7 +1060,9 @@ def test_lstm_traced(self): f = FileCheck() if not self.dynamic_shapes: f.check("Chunk") - f.check("aten::sigmoid").check("aten::tanh").run(str(fusion_groups[0 if not self.dynamic_shapes else 1])) + f.check("aten::sigmoid").check("aten::tanh").run( + str(fusion_groups[0 if not self.dynamic_shapes else 1]) + ) def test_milstm(self): if self.dynamic_shapes: @@ -958,9 +1075,11 @@ def test_milstm(self): # TODO: chunk fusion_group_len = 2 if self.dynamic_shapes else 1 self.assertGraphContainsExactly( - forward_graph, FUSION_GROUP, fusion_group_len, consider_subgraphs=True) - FileCheck().check("DifferentiableGraph").check("TupleConstruct") \ - .check_next("return").check(FUSION_GROUP).run(str(forward_graph)) + forward_graph, FUSION_GROUP, fusion_group_len, consider_subgraphs=True + ) + FileCheck().check("DifferentiableGraph").check("TupleConstruct").check_next( + "return" + ).check(FUSION_GROUP).run(str(forward_graph)) hy, cy = module(*inputs) warmup_backward((hy + cy).sum()) @@ -968,17 +1087,17 @@ def test_milstm(self): @unittest.skip("rand_like is not supported yet") def test_rand_cuda(self): class M(torch.jit.ScriptModule): - __constants__ = ['d'] + __constants__ = ["d"] def __init__(self): super().__init__() - self.d = torch.device('cuda') + self.d = torch.device("cuda") @torch.jit.script_method def create(self, x): return x * x + x + torch.rand_like(x) - x = torch.zeros([3, 4, 5], dtype=torch.float, device='cuda') + x = torch.zeros([3, 4, 5], dtype=torch.float, device="cuda") m = M() out1 = m.create(x) out2 = m.create(x) @@ -991,7 +1110,7 @@ def create(self, x): @staticmethod def fn_test_relu(x, y): - return F.relu(x + .5 * y) + return F.relu(x + 0.5 * y) def test_relu(self): for device in self.devices: @@ -1004,7 +1123,7 @@ def test_relu(self): def test_erf(self): for device in self.devices: # only enabled on gpu - if device == 'cpu': + if device == "cpu": continue def fn_test_erf(x): @@ -1015,8 +1134,14 @@ def fn_test_erf(x): self.assertAllFused(ge.graph_for(x)) x.requires_grad_(True) ge = self.checkScript(fn_test_erf, (x,), profiling=ProfilingMode.PROFILING) - self.assertAllFused(ge.graph_for(x), except_for=("aten::size", "prim::BroadcastSizes", - "aten::_size_if_not_equal")) + self.assertAllFused( + ge.graph_for(x), + except_for=( + "aten::size", + "prim::BroadcastSizes", + "aten::_size_if_not_equal", + ), + ) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") @unittest.skip("rand_like is not supported yet") @@ -1031,24 +1156,30 @@ def fn_test_rand2(x, y): r = torch.rand_like(y) return r * x * x - x = torch.randn(4, 4, dtype=torch.float, device='cuda') - y = torch.randn(4, 4, dtype=torch.float, device='cuda') + x = torch.randn(4, 4, dtype=torch.float, device="cuda") + y = torch.randn(4, 4, dtype=torch.float, device="cuda") script_f = torch.jit.script(fn_test_rand) warmup_forward(script_f, x, y) out = script_f(x, y) self.assertAllFused(script_f.graph_for(x, y)) x.requires_grad_(True) out = script_f(x, y) - self.assertAllFused(script_f.graph_for(x, y), except_for=("aten::size", "prim::BroadcastSizes", - "aten::_size_if_not_equal")) + self.assertAllFused( + script_f.graph_for(x, y), + except_for=( + "aten::size", + "prim::BroadcastSizes", + "aten::_size_if_not_equal", + ), + ) # test that broadcasting random produces correct results - x = torch.ones(4, 4, dtype=torch.float, device='cuda') - y = torch.ones(4, dtype=torch.float, device='cuda') + x = torch.ones(4, 4, dtype=torch.float, device="cuda") + y = torch.ones(4, dtype=torch.float, device="cuda") script_f = torch.jit.script(fn_test_rand2) warmup_forward(script_f, x, y) out = script_f(x, y) - self.assertEqual(out[0, :] + torch.zeros(4, 4, device='cuda'), out) + self.assertEqual(out[0, :] + torch.zeros(4, 4, device="cuda"), out) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") @unittest.skip("rand_like is not supported yet") @@ -1059,8 +1190,8 @@ def fn_test_diamond(x, y): b = y - r return a + b - x = torch.randn(4, 4, dtype=torch.float, device='cuda') - y = torch.randn(4, 4, dtype=torch.float, device='cuda') + x = torch.randn(4, 4, dtype=torch.float, device="cuda") + y = torch.randn(4, 4, dtype=torch.float, device="cuda") script_f = torch.jit.script(fn_test_diamond) warmup_forward(script_f, x, y) out = script_f(x, y) @@ -1070,8 +1201,8 @@ def test_scalar(self): def fn(x, y): return 2 * x + y - x = torch.tensor(0.1, dtype=torch.float, device='cpu') - y = torch.tensor(1, dtype=torch.float, device='cpu') + x = torch.tensor(0.1, dtype=torch.float, device="cpu") + y = torch.tensor(1, dtype=torch.float, device="cpu") ge = self.checkScript(fn, (x, y)) self.assertAllFused(ge.graph_for(x, y)) @@ -1091,7 +1222,9 @@ def foo(x): g = torch.jit.last_executed_optimized_graph() - FileCheck().check_count("prim::If", 1, exactly=True).check("prim::TensorExpr").run(g) + FileCheck().check_count("prim::If", 1, exactly=True).check( + "prim::TensorExpr" + ).run(g) torch._C._jit_pass_inline(g) f = FileCheck() for _ in range(3): @@ -1100,8 +1233,10 @@ def foo(x): def test_small_constant(self): for device in self.devices: + def fn_test_small_constant(x, y): return (1e-8 * x + 5e-9 * y) * 1e8 + x = torch.randn(4, 4, dtype=torch.float, device=device) y = torch.randn(4, 4, dtype=torch.float, device=device) @@ -1116,8 +1251,9 @@ def fn_test_small_constant(x, y): # TODO: fix that and reenable the test. def test_tensor_scalar_ops(self): for device in self.devices: + def should_fuse(x): - z = 3. + z = 3.0 y = x + z return x * y @@ -1134,22 +1270,24 @@ def should_fuse_scalar(x, z): inputs = [ torch.randn(2, 2, dtype=torch.float, device=device), - torch.tensor(3., dtype=torch.float, device=device), + torch.tensor(3.0, dtype=torch.float, device=device), ] ge = self.checkScript(should_fuse_scalar, inputs) # Check that the fused graph computes correct results when the scalar # input changes. inputs = [ torch.randn(2, 2, dtype=torch.float, device=device), - torch.tensor(7., dtype=torch.float, device=device), + torch.tensor(7.0, dtype=torch.float, device=device), ] self.assertEqual(ge(*inputs), should_fuse_scalar(*inputs)) # The TE fuser supports fusion of non-constant scalars self.assertGraphContainsExactly( - ge.graph_for(*inputs), FUSION_GROUP, 1, consider_subgraphs=True) + ge.graph_for(*inputs), FUSION_GROUP, 1, consider_subgraphs=True + ) def test_where_and_typing(self): for device in self.devices: + def f(x, y): mask = x > y res = torch.where(mask, x, y) @@ -1159,14 +1297,16 @@ def f(x, y): y = torch.randn(4, 4, dtype=torch.double, device=device) script_f = self.checkScript(f, (x, y)) - self.assertAllFused(script_f.graph_for(x, y), except_for={'prim::TupleConstruct'}) + self.assertAllFused( + script_f.graph_for(x, y), except_for={"prim::TupleConstruct"} + ) def test_disabled(self): old_cpu_fuser_state = torch._C._jit_can_fuse_on_cpu() torch._C._jit_override_can_fuse_on_cpu(False) def fn(a): - return a ** 2 + a + return a**2 + a x = torch.randn(4, dtype=torch.float, device="cpu") s = self.checkScript(fn, (x,)) @@ -1193,38 +1333,46 @@ def test_torch_to(self): def foo(x): return x.to(torch.float) - foo(torch.tensor([3.], dtype=torch.float)) - foo(torch.tensor([3.], dtype=torch.float)) - FileCheck().check_not("TensorExpr").run(torch.jit.last_executed_optimized_graph()) + foo(torch.tensor([3.0], dtype=torch.float)) + foo(torch.tensor([3.0], dtype=torch.float)) + FileCheck().check_not("TensorExpr").run( + torch.jit.last_executed_optimized_graph() + ) # test not fusing non-const inputs @torch.jit.script def foo(x, dtype: int): return x.to(dtype) - foo(torch.tensor([3.], dtype=torch.float), torch.int) - foo(torch.tensor([3.], dtype=torch.float), torch.int) - FileCheck().check_not("TensorExpr").run(torch.jit.last_executed_optimized_graph()) + foo(torch.tensor([3.0], dtype=torch.float), torch.int) + foo(torch.tensor([3.0], dtype=torch.float), torch.int) + FileCheck().check_not("TensorExpr").run( + torch.jit.last_executed_optimized_graph() + ) # test not fusing to_pinned inputs @torch.jit.script def foo(x, dtype: int): return x.to(pin_memory=True) - foo(torch.tensor([3.], dtype=torch.float), torch.int) - foo(torch.tensor([3.], dtype=torch.float), torch.int) - FileCheck().check_not("TensorExpr").run(torch.jit.last_executed_optimized_graph()) - + foo(torch.tensor([3.0], dtype=torch.float), torch.int) + foo(torch.tensor([3.0], dtype=torch.float), torch.int) + FileCheck().check_not("TensorExpr").run( + torch.jit.last_executed_optimized_graph() + ) # test across-device not supported if torch.cuda.is_available(): + @torch.jit.script def foo(x): return x.to(device="cuda") - foo(torch.tensor([3.], dtype=torch.float)) - foo(torch.tensor([3.], dtype=torch.float)) - FileCheck().check_not("TensorExpr").run(torch.jit.last_executed_optimized_graph()) + foo(torch.tensor([3.0], dtype=torch.float)) + foo(torch.tensor([3.0], dtype=torch.float)) + FileCheck().check_not("TensorExpr").run( + torch.jit.last_executed_optimized_graph() + ) sizes = [(1, 4), (4, 4)] # reuses cast impl, smaller dtype set for faster test @@ -1245,7 +1393,9 @@ def forward(self, x): return x.to(self.dtype) bad_dtypes = [] - for dtype, output_dtype, device, size in product(dtypes, dtypes, self.devices, sizes): + for dtype, output_dtype, device, size in product( + dtypes, dtypes, self.devices, sizes + ): # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed if dtype in [torch.float16, torch.bfloat16] and device == "cpu": continue @@ -1275,12 +1425,15 @@ def test_masked_fill(self): torch.bool, ] sizes = [(2,), (4, 4)] - for self_dtype, device, scalar_val, size in product(dtypes, self.devices, [0.4, 3], sizes): + for self_dtype, device, scalar_val, size in product( + dtypes, self.devices, [0.4, 3], sizes + ): input_v = self.data_for(self_dtype, device, size=size) mask = self.data_for(torch.bool, device, size=size) def fn(input_v, mask): return torch.masked_fill(input_v, mask, scalar_val) + ref = fn(input_v, mask) try: t = torch.jit.trace(fn, (input_v, mask)) @@ -1288,16 +1441,21 @@ def fn(input_v, mask): self.assertLastGraphAllFused() except Exception as e: raise RuntimeError( - " ".join(["Failed:", str(self_dtype), op.__name__, device, str(size)]) # noqa: F821 + " ".join( + [ + "Failed:", + str(self_dtype), + op.__name__, # noqa: F821 + device, + str(size), + ] + ) ) from e def test_isnan(self): x = torch.rand([4]) - x[0] = float('nan') - inputs = [ - x, - torch.tensor([float('nan'), .5]) - ] + x[0] = float("nan") + inputs = [x, torch.tensor([float("nan"), 0.5])] dtypes = [ torch.int8, torch.int16, @@ -1321,7 +1479,7 @@ def test_isnan(self): self.assertLastGraphAllFused() except Exception as e: raise RuntimeError( - " ".join(["Failed:", str(dtype), 'isnan', device]) + " ".join(["Failed:", str(dtype), "isnan", device]) ) from e def test_gelu(self): @@ -1332,7 +1490,9 @@ def apply(fn): F.gelu, ] sizes = [(1,), (2,), (4, 4)] - for dtype, op, device, size in product(self.dtypes, unary_ops, self.devices, sizes): + for dtype, op, device, size in product( + self.dtypes, unary_ops, self.devices, sizes + ): # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed if dtype in [torch.float16, torch.bfloat16] and device == "cpu": continue @@ -1357,6 +1517,7 @@ def apply(fn): def test_unary_ops(self): with torch._jit_internal._disable_emit_hooks(): + def apply(fn): return lambda x: fn(x) @@ -1411,7 +1572,9 @@ def apply(fn): ] gpu_only = {torch.erf, torch.erfc} sizes = [(1,), (2,), (4, 4)] - for dtype, op, device, size in product(self.dtypes, unary_ops, self.devices, sizes): + for dtype, op, device, size in product( + self.dtypes, unary_ops, self.devices, sizes + ): # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed if dtype in [torch.float16, torch.bfloat16] and device == "cpu": continue @@ -1435,7 +1598,9 @@ def apply(fn): self.assertAllFused(t.graph_for(x)) except Exception as e: raise RuntimeError( - " ".join(["Failed:", str(dtype), op.__name__, device, str(size)]) + " ".join( + ["Failed:", str(dtype), op.__name__, device, str(size)] + ) ) from e def test_binary_ops(self): @@ -1494,6 +1659,7 @@ def apply(fn): def test_binary_scalar_ops(self): def apply(fn): return lambda x, y: fn(x, y) + ir_template = """ graph(%x : {dtype_x}, %y : {dtype_y}): %z = {op}(%x, %y) @@ -1516,10 +1682,12 @@ def apply(fn): "aten::__lshift__", "aten::__rshift__", ] - dtypes = ['int', 'float', 'bool'] - values = {'int' : [10, 3], 'float' : [12.34, 2.78], 'bool' : [True, False]} + dtypes = ["int", "float", "bool"] + values = {"int": [10, 3], "float": [12.34, 2.78], "bool": [True, False]} devices = self.devices - for dtype_x, dtype_y, op, device in product(dtypes, dtypes, binary_ops, devices): + for dtype_x, dtype_y, op, device in product( + dtypes, dtypes, binary_ops, devices + ): code = ir_template.format(**locals()) # Interpret the graph @@ -1535,7 +1703,9 @@ def apply(fn): try: k = torch._C._te.TensorExprKernel(graph) except Exception as e: - raise RuntimeError(" ".join(["Compilation failed:", device, str(code)])) from e + raise RuntimeError( + " ".join(["Compilation failed:", device, str(code)]) + ) from e # Run the graph for x, y in product(values[dtype_x], values[dtype_y]): @@ -1544,7 +1714,11 @@ def apply(fn): res = k.run((x, y)) self.assertEqual(ref, res) except Exception as e: - raise RuntimeError(" ".join(["Failed at runtime:", device, str(x), str(y), str(code)])) from e + raise RuntimeError( + " ".join( + ["Failed at runtime:", device, str(x), str(y), str(code)] + ) + ) from e def test_matmul(self): if self.dynamic_shapes: @@ -1553,31 +1727,33 @@ def test_matmul(self): def fn(x, y): return torch.matmul(x, y) - devices = ['cpu'] # No cuda support for ext calls yet - sizes = [[[128, 128], [128, 128]], - [[10, 10], [10, 10]], - [[1, 16], [16, 128]], - [[128], [128]], - [[128], [128, 128]], - [[3], [3]], - [[3, 4], [4]], - [[10, 3, 4], [4]], - [[10, 3, 4], [10, 4, 5]], - [[10, 3, 4], [4, 5]], - ] + devices = ["cpu"] # No cuda support for ext calls yet + sizes = [ + [[128, 128], [128, 128]], + [[10, 10], [10, 10]], + [[1, 16], [16, 128]], + [[128], [128]], + [[128], [128, 128]], + [[3], [3]], + [[3, 4], [4]], + [[10, 3, 4], [4]], + [[10, 3, 4], [10, 4, 5]], + [[10, 3, 4], [4, 5]], + ] # Only 2D x 2D matrix multiply is supported. For non-supported sizes we # still want to run results verification to test that we didn't # accidentally fuse it, but we skip the 'is-fused' check. # TODO: add support for other shape combinations and make this set empty: - skip_is_fused_check_sizes = ["[[128], [128]]", - "[[128], [128, 128]]", - "[[3], [3]]", - "[[3, 4], [4]]", - "[[10, 3, 4], [4]]", - "[[10, 3, 4], [10, 4, 5]]", - "[[10, 3, 4], [4, 5]]", - ] + skip_is_fused_check_sizes = [ + "[[128], [128]]", + "[[128], [128, 128]]", + "[[3], [3]]", + "[[3, 4], [4]]", + "[[10, 3, 4], [4]]", + "[[10, 3, 4], [10, 4, 5]]", + "[[10, 3, 4], [4, 5]]", + ] for dtype, size, device in product(self.dtypes, sizes, devices): if dtype in [torch.float16, torch.bfloat16] and device == "cpu": continue @@ -1598,12 +1774,11 @@ def fn(x, y): if str(size) not in skip_is_fused_check_sizes: self.assertAllFused(t.graph_for(x, y)) except Exception as e: - raise RuntimeError( - " ".join(["Failed:", str(dtype), device]) - ) from e + raise RuntimeError(" ".join(["Failed:", str(dtype), device])) from e def test_binary_tensor_scalar_ops(self): with torch._jit_internal._disable_emit_hooks(): + def apply_with_scalar(fn, scalar): return lambda x: fn(x, scalar) @@ -1625,7 +1800,9 @@ def apply_with_scalar(fn, scalar): # Maybe we should split this into separate tests to speed it up by # only using scalar values relevant to particular ops scalars = [1.5, 3, 0, -2.0, -1] - for dtype, op, device, scalar in product(self.dtypes, binary_ops, devices, scalars): + for dtype, op, device, scalar in product( + self.dtypes, binary_ops, devices, scalars + ): if dtype in [torch.float16, torch.bfloat16] and device == "cpu": continue try: @@ -1659,7 +1836,9 @@ def apply_with_scalar(fn, scalar): # Maybe we should split this into separate tests to speed it up by # only using scalar values relevant to particular ops scalars = [1.5, 3, -2.0, -1] # skip 0 - for dtype, op, device, scalar in product(self.dtypes, binary_ops, devices, scalars): + for dtype, op, device, scalar in product( + self.dtypes, binary_ops, devices, scalars + ): if dtype in [torch.float16, torch.bfloat16] and device == "cpu": continue try: @@ -1696,7 +1875,9 @@ def apply_with_scalar(fn, scalar): # Maybe we should split this into separate tests to speed it up by # only using scalar values relevant to particular ops scalars = [1.5, 3, 0, -2.0, -1] - for dtype, op, device, scalar in product(dtypes, binary_ops, self.devices, scalars): + for dtype, op, device, scalar in product( + dtypes, binary_ops, self.devices, scalars + ): if dtype in [torch.float16, torch.bfloat16] and device == "cpu": continue try: @@ -1780,8 +1961,9 @@ def apply(fn): " ".join(["Failed:", str(dtype), op.__name__, device]) ) from e - - @unittest.skip("FIXME: fuser doesn't include ListConstruct nodes to the group causing a failure") + @unittest.skip( + "FIXME: fuser doesn't include ListConstruct nodes to the group causing a failure" + ) def test_list_ops(self): def apply(fn): return lambda x, y, z: fn([x * x, y * y, z * z]) @@ -1848,6 +2030,7 @@ def apply(fn): def test_unsupported_dtypes(self): for device in self.devices: + def fn(x): return x * x + x @@ -1904,10 +2087,13 @@ def eager(t0, t1, t2, t3, t4): for pair in zip(script(*inputs), eager(*inputs)): test, ref = pair torch.testing.assert_close(test, ref) - self.assertAllFused(script.graph_for(*inputs), except_for={"prim::TupleConstruct"}) + self.assertAllFused( + script.graph_for(*inputs), except_for={"prim::TupleConstruct"} + ) def test_sub_gt_and(self): for device in self.devices: + def eager(t1, t2, t3, t4, t: float): w = t1 - t2 h = t3 - t4 @@ -1920,6 +2106,7 @@ def eager(t1, t2, t3, t4, t: float): # careful not to create a fusion group containing it. return k + 1 return w + t = torch.rand(8, dtype=torch.float, device=device) scripted = self.checkScript(eager, (t, t, t, t, 0.1)) @@ -1929,20 +2116,24 @@ def test_chunk_mul_one(self): self.skipTest("TODO: chunk dynamic shapes") for device in self.devices: + def eager(x): z, y, w = torch.chunk(x, 3, -1) return z * 3, y, w + x = torch.rand(64, 1, 3072, dtype=torch.float, device=device) z, y, w = eager(x) script = self.checkScript(eager, (x,)) def test_eq_unsqueeze_type_as(self): for device in self.devices: + def eager(a, b): mask = b == 1 mask = torch.unsqueeze(mask, -1) x = mask.type_as(a) return x, mask + a = torch.rand(1, 64, 1024, device=device, dtype=torch.float) b = torch.randint(-2, 2, (1, 64), device=device, dtype=torch.long) script = self.checkScript(eager, (a, b)) @@ -1995,33 +2186,40 @@ def eager(input, weight, bias): bias = torch.rand((64), dtype=torch.float) script = self.checkScript(eager, (input, weight, bias)) - FileCheck().check_not("TensorExpr").run(torch.jit.last_executed_optimized_graph()) + FileCheck().check_not("TensorExpr").run( + torch.jit.last_executed_optimized_graph() + ) def test_type_as_cat(self): with inline_fusion_groups(): + def eager(x, y): return torch.cat((x, y.type_as(x)), dim=1) + dtypes = self.dtypes.copy() # CPU fuser doesn't support float16. dtypes.remove(torch.float16) dtypes.remove(torch.bfloat16) for dtype1, dtype2 in product(dtypes, dtypes): - x = torch.randint(2, (1, 13,)).to(dtype1) + x = torch.randint( + 2, + ( + 1, + 13, + ), + ).to(dtype1) zero = torch.tensor([[0]]).to(dtype2) one = torch.tensor([[1]]).to(dtype2) script = torch.jit.trace(eager, (x, zero)) for _ in range(3): - torch.testing.assert_close( - script(x, zero), - eager(x, zero)) - torch.testing.assert_close( - script(x, one), - eager(x, one)) + torch.testing.assert_close(script(x, zero), eager(x, zero)) + torch.testing.assert_close(script(x, one), eager(x, one)) self.assertAllFused(script.graph_for(x, one)) def test_to_device(self): def eager(x): return x.to(device="cpu").relu() + x = torch.rand(8) script = self.checkScript(eager, (x,)) self.assertAllFused(script.graph_for(x)) @@ -2029,7 +2227,10 @@ def eager(x): def test_dims(self): def eager(x, y): return x / (y + 0.0001) - x = torch.linspace(-1, 1, 768, dtype=torch.float32).as_strided((1, 1, 768), (768, 1, 1)) + + x = torch.linspace(-1, 1, 768, dtype=torch.float32).as_strided( + (1, 1, 768), (768, 1, 1) + ) y = torch.tensor([[[2.0]]], dtype=torch.float32) script = self.checkScript(eager, (x, y)) self.assertAllFused(script.graph_for(x, y)) @@ -2062,6 +2263,7 @@ def eager(x, y): def test_exhaust_specializations(self): with texpr_enable_strategy([("STATIC", 1)]): + @torch.jit.script def foo(x): return x + x + x @@ -2080,6 +2282,7 @@ def foo(x): def test_unsqueeze_var_dim(self): def eager(x, y, z: int): return x * torch.unsqueeze(y, dim=z) + x = torch.rand(4, 4, 64).permute(1, 0, 2) y = torch.rand(4, 4) z = 2 @@ -2107,34 +2310,43 @@ def _test_fwd_bwd(self, fn): def test_relu_fwd_bwd(self): def eager(x): return torch.relu(x * 1.01) + self._test_fwd_bwd(eager) def test_hardswish_fwd_bwd(self): def eager(x): return F.hardswish(x) * 1.01 + self._test_fwd_bwd(eager) def test_hardsigmoid_fwd_bwd(self): def eager(x): return F.hardsigmoid(x) * 1.01 + self._test_fwd_bwd(eager) def test_cat_graph_opt(self): def foo(x, y, z): return torch.log(torch.cat([x, y, z])) - self.checkScript(foo, (torch.rand([5, 5]), torch.rand([2, 5]), torch.rand([1, 5]))) + self.checkScript( + foo, (torch.rand([5, 5]), torch.rand([2, 5]), torch.rand([1, 5])) + ) # TODO: not sure why not updated graph isn't reflected in last_optimized_graph self.assertLastGraphAllFused() def test_dynamic_cat(self): with inline_fusion_groups(): + @torch.jit.script - def repro(xs: List[torch.Tensor], ys: List[torch.Tensor], zs: List[torch.Tensor]): + def repro( + xs: List[torch.Tensor], ys: List[torch.Tensor], zs: List[torch.Tensor] + ): return [ torch.cat([x, torch.cat([y, z], dim=-1)], dim=-1) for x, y, z in zip(xs, ys, zs) ] + for _ in range(3): N = 3 xs = [torch.ones(21) for _ in range(N)] @@ -2153,8 +2365,10 @@ def eager(b: float): def test_cat_2k_args(self): with inline_fusion_groups(): + def eager(x): return torch.relu(torch.cat([x for _ in range(2000)])) + x = torch.randn(1) trace = self.checkTrace(eager, (x,)) fusion_groups = self.findFusionGroups(trace.graph_for(x)) @@ -2164,6 +2378,7 @@ def test_adaptive_avg_pool2d(self): # TODO: once the adaptive_avg_pool2d is available in OpInfo DB, this # test should be moved there with inline_fusion_groups(): + def foo1(x): return torch.nn.functional.adaptive_avg_pool2d(x, (2, 2)) @@ -2179,11 +2394,13 @@ def foo2(x): def test_unrolled_cat(self): with inline_fusion_groups(): + def eager(x): ret = torch.empty(0) for i in range(x.shape[0]): ret = torch.cat([ret, x[i].relu()]) return ret + script = torch.jit.script(eager) # Warm up with size=1 tensor; since the loop iterates once the @@ -2260,6 +2477,7 @@ def foo(x): def test_dynamic_shapes(self): from functools import partial + n = 10 gen_tensor = ( @@ -2272,6 +2490,7 @@ def test_dynamic_shapes(self): ) with texpr_enable_strategy([("DYNAMIC", 20)]): + def foo(x, y, z): return torch.sigmoid(torch.tanh(x)) @@ -2311,7 +2530,9 @@ def fum(x, y, z): torch._C._jit_pass_dce(g) # We should see only one optimized kernel - FileCheck().check_count("TensorExprDynamicGuard", 1, exactly=True).run(g) + FileCheck().check_count( + "TensorExprDynamicGuard", 1, exactly=True + ).run(g) self.assertEqual(func(*inps), func_s(*inps)) gen = gen_tensor[0] @@ -2327,7 +2548,9 @@ def fum(x, y, z): g = torch.jit.last_executed_optimized_graph() torch._C._jit_pass_inline(g) torch._C._jit_pass_dce(g) - FileCheck().check_count("TensorExprDynamicGuard", len(gen_tensor), exactly=True).run(g) + FileCheck().check_count( + "TensorExprDynamicGuard", len(gen_tensor), exactly=True + ).run(g) @unittest.skipIf(not RUN_CUDA, "half-precision NNC fusion requires CUDA") def test_autocast_up(self): @@ -2382,7 +2605,6 @@ def f(x): self.assertEqual(f(bf_x), bf_scr(bf_x), atol=4e-3, rtol=4e-3) def test_with_strict_fusion(self): - def success(x): with torch.jit.strict_fusion(): return x + x + x @@ -2445,6 +2667,7 @@ def test_constant_chunk_shapes(self): self.skipTest("TODO: chunk dynamic shapes") for device in self.devices: + def f(x, y): r = torch.tensor(4) z1, z2 = (x + y + r).chunk(2, dim=1) @@ -2474,10 +2697,10 @@ def test_pow_multiple_dtype(self): # https://github.com/pytorch/pytorch/issues/75476 def fn(p: torch.Tensor, gamma: float = 2.0) -> torch.Tensor: p = torch.sigmoid(p) - result = p ** gamma + result = p**gamma return result - x = torch.rand((2, 2), dtype=torch.half, device='cuda') + x = torch.rand((2, 2), dtype=torch.half, device="cuda") ref = fn(x) @@ -2491,138 +2714,140 @@ def fn(p: torch.Tensor, gamma: float = 2.0) -> torch.Tensor: class TestTEFuserStatic(TestTEFuser): dynamic_shapes = False + class TestTEFuserDynamic(TestTEFuser): dynamic_shapes = True + del TestTEFuser works_list = [ - '__radd__', - '__rdiv__', - '__rmul__', - '__rmod__', - 'abs', - 'acos', - 'add', - 'addcmul', - 'addmm.decomposed', - 'asin', - 'atan', - 'atan2', - 'ceil', - 'clamp', - 'clamp.scalar', - 'contiguous', - 'cos', - 'cosh', - 'div.no_rounding_mode', - 'div.true_rounding', - 'div.floor_rounding', - 'div.trunc_rounding', - 'eq', - 'erf', - 'erfc', - 'exp', - 'expand', - 'expand_as', - 'expm1', - 'floor', - 'fmod', - 'fmod.autodiffed', - 'ge', - 'gt', - 'isnan', - 'le', - 'lerp', - 'lgamma', - 'log', - 'log10', - 'log1p', - 'log2', - 'lt', - 'masked_fill', - 'max.binary', - 'mean', - 'min.binary', - 'mm', - 'mul', - 'ne', - 'neg', - 'nn.functional.hardshrink', - 'nn.functional.hardsigmoid', - 'nn.functional.hardswish', - 'nn.functional.softplus', - 'nn.functional.hardtanh', - 'nn.functional.leaky_relu', - 'nn.functional.relu', - 'nn.functional.relu6', - 'nn.functional.softsign', - 'nn.functional.tanhshrink', - 'nn.functional.threshold', - 'permute', - 'pow', - 'reciprocal', - 'remainder', - 'remainder.autodiffed', - 'reshape', - 'reshape_as', - 'round', - 'rsub', - 'rsub.rsub_tensor', - 'rsqrt', - 'sigmoid', - 'sign', - 'sin', - 'sinh', - 'sqrt', - 'sub', - 'sum', - 't', - 'tan', - 'tanh', - 'transpose', - 'true_divide', - 'trunc', - 'unsqueeze', - 'view', - 'view_as', - 'where', - 'bool', - 'byte', - 'char', - 'double', - 'float', - 'half', - 'int', - 'long', - 'short', - 'bool.channels_last', - 'byte.channels_last', - 'char.channels_last', - 'double.channels_last', - 'float.channels_last', - 'half.channels_last', - 'int.channels_last', - 'long.channels_last', - 'short.channels_last', + "__radd__", + "__rdiv__", + "__rmul__", + "__rmod__", + "abs", + "acos", + "add", + "addcmul", + "addmm.decomposed", + "asin", + "atan", + "atan2", + "ceil", + "clamp", + "clamp.scalar", + "contiguous", + "cos", + "cosh", + "div.no_rounding_mode", + "div.true_rounding", + "div.floor_rounding", + "div.trunc_rounding", + "eq", + "erf", + "erfc", + "exp", + "expand", + "expand_as", + "expm1", + "floor", + "fmod", + "fmod.autodiffed", + "ge", + "gt", + "isnan", + "le", + "lerp", + "lgamma", + "log", + "log10", + "log1p", + "log2", + "lt", + "masked_fill", + "max.binary", + "mean", + "min.binary", + "mm", + "mul", + "ne", + "neg", + "nn.functional.hardshrink", + "nn.functional.hardsigmoid", + "nn.functional.hardswish", + "nn.functional.softplus", + "nn.functional.hardtanh", + "nn.functional.leaky_relu", + "nn.functional.relu", + "nn.functional.relu6", + "nn.functional.softsign", + "nn.functional.tanhshrink", + "nn.functional.threshold", + "permute", + "pow", + "reciprocal", + "remainder", + "remainder.autodiffed", + "reshape", + "reshape_as", + "round", + "rsub", + "rsub.rsub_tensor", + "rsqrt", + "sigmoid", + "sign", + "sin", + "sinh", + "sqrt", + "sub", + "sum", + "t", + "tan", + "tanh", + "transpose", + "true_divide", + "trunc", + "unsqueeze", + "view", + "view_as", + "where", + "bool", + "byte", + "char", + "double", + "float", + "half", + "int", + "long", + "short", + "bool.channels_last", + "byte.channels_last", + "char.channels_last", + "double.channels_last", + "float.channels_last", + "half.channels_last", + "int.channels_last", + "long.channels_last", + "short.channels_last", ] known_failures = [ - '__rmatmul__', - 'frac', - 'matmul', + "__rmatmul__", + "frac", + "matmul", ] # If your OpInfo test causes this test to fail, add it here -skip_ops = [ - 'conj' -] +skip_ops = ["conj"] + def get_name(op): l = [op.name] - if op.variant_test_name != '': + if op.variant_test_name != "": l.append(op.variant_test_name) - return '.'.join(l) + return ".".join(l) + # Purpose of this class is to allow super() calls. # super() [with no arguments] fails, presumably because of how instantiate_device_type_tests works. @@ -2631,6 +2856,7 @@ def get_name(op): class TestNNCOpInfoParent(JitCommonTestCase): pass + class TestNNCOpInfo(TestNNCOpInfoParent): def setUp(self): super(TestNNCOpInfoParent, self).setUp() @@ -2656,23 +2882,23 @@ def te_compile(self, device, dtype, op): param_values.append(v) fx_args.append(param_names[-1]) else: - fx_args.append(f'{repr(v)}') + fx_args.append(f"{repr(v)}") for k, v in kwarg_values.items(): if isinstance(v, torch.Tensor): param_names.append(k) param_values.append(v) - fx_args.append(f'{k} = {k}') + fx_args.append(f"{k} = {k}") else: - fx_args.append(f'{k} = {repr(v)}') + fx_args.append(f"{k} = {repr(v)}") code = f""" def f({', '.join(param_names)}): return op.op({', '.join(fx_args)})""" - g = {'torch': torch, 'inf' : math.inf, 'op': op} + g = {"torch": torch, "inf": math.inf, "op": op} exec(code, g) - f = g['f'] - f.__module__ = 'test' + f = g["f"] + f.__module__ = "test" out = f(*param_values) ts_g = torch.jit.trace(f, param_values) @@ -2683,35 +2909,48 @@ def f({', '.join(param_names)}): @onlyCPU @unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel") - @ops([op for op in op_db if get_name(op) in works_list], allowed_dtypes=(torch.float,)) + @ops( + [op for op in op_db if get_name(op) in works_list], + allowed_dtypes=(torch.float,), + ) def test_working(self, device, dtype, op): self.te_compile(device, dtype, op) @onlyCPU @unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel") - @ops([op for op in op_db if get_name(op) in known_failures], allowed_dtypes=(torch.float,)) + @ops( + [op for op in op_db if get_name(op) in known_failures], + allowed_dtypes=(torch.float,), + ) def test_failures(self, device, dtype, op): try: self.te_compile(device, dtype, op) except Exception as e: pass else: - raise RuntimeError("Expected test to fail. If it now works, move op into works_list") + raise RuntimeError( + "Expected test to fail. If it now works, move op into works_list" + ) @onlyCPU @unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel") - @ops([op for op in op_db if get_name(op) not in works_list + known_failures], allowed_dtypes=(torch.float,)) + @ops( + [op for op in op_db if get_name(op) not in works_list + known_failures], + allowed_dtypes=(torch.float,), + ) def test_unsupported(self, device, dtype, op): if get_name(op) in skip_ops: return try: with warnings.catch_warnings(): - warnings.simplefilter('ignore', TracerWarning) # noqa: F821 + warnings.simplefilter("ignore", TracerWarning) # noqa: F821 self.te_compile(device, dtype, op) except Exception as e: pass else: - raise RuntimeError("Expected test to fail. If it now works, move op into works_list") + raise RuntimeError( + "Expected test to fail. If it now works, move op into works_list" + ) @slowTest @onlyCPU @@ -2725,10 +2964,14 @@ def test_nnc_correctness(self, device, dtype, op): for variant, sample in variant_sample_pairs: trace = create_traced_fn(self, variant, cache_traced_fn=True) - ref = variant(*clone_inputs((sample.input, *sample.args)), **sample.kwargs) + ref = variant( + *clone_inputs((sample.input, *sample.args)), **sample.kwargs + ) trace(*clone_inputs((sample.input, *sample.args)), **sample.kwargs) - val = trace(*clone_inputs((sample.input, *sample.args)), **sample.kwargs) + val = trace( + *clone_inputs((sample.input, *sample.args)), **sample.kwargs + ) atol = 2e-1 if dtype == torch.bfloat16 else 1e-5 rtol = 2e-1 if dtype == torch.bfloat16 else 1e-5 @@ -2740,14 +2983,17 @@ def test_nnc_correctness(self, device, dtype, op): # if the CU is not cleared. torch.jit._state._python_cu.drop_all_functions() + # CPU fuser not currently used in fbcode only_for = ("cuda") if IS_FBCODE else ("cpu", "cuda") instantiate_device_type_tests(TestNNCOpInfo, globals(), only_for=only_for) + # Purpose of this class is to allow super() calls. (See TestNNCOpInfoParent) class TestLoopnestRandomizationParent(JitTestCase): pass + class TestLoopnestRandomization(TestLoopnestRandomizationParent): def setUp(self): super(TestLoopnestRandomizationParent, self).setUp() @@ -2812,5 +3058,5 @@ def fn_test_relu(x, y): instantiate_device_type_tests(TestLoopnestRandomization, globals(), only_for=("cpu")) -if __name__ == '__main__': +if __name__ == "__main__": run_tests() From 627d2cd87d7d916e5d9acbf95006a5fd51a52762 Mon Sep 17 00:00:00 2001 From: chuanqiw Date: Tue, 4 Jun 2024 17:15:03 +0000 Subject: [PATCH 327/706] [CI] disable td for xpu ci test by default (#127611) Due to the xpu ci test has been enabled td by default, a lot of test cases (75%) have been skipped in CI tests. It caused some ci failures escaped from the ci tests, for example issue #127539. This PR depends on PR #127595 landed. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127611 Approved by: https://github.com/etaf, https://github.com/atalman --- test/run_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/run_test.py b/test/run_test.py index 23160d01281c..d9ef52a42af3 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -1200,6 +1200,7 @@ def parse_args(): and not IS_SLOW and not TEST_WITH_ROCM and not IS_MACOS + and "xpu" not in BUILD_ENVIRONMENT and "onnx" not in BUILD_ENVIRONMENT and "debug" not in BUILD_ENVIRONMENT and "parallelnative" not in BUILD_ENVIRONMENT, From 0ff60236abfa6d60c4c9caf2f812f82f23530a49 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 4 Jun 2024 18:19:31 +0000 Subject: [PATCH 328/706] Revert "Retire torch.distributed.pipeline (#127354)" This reverts commit b9c058c203ee38032594f898f27cd8404f113a63. Reverted https://github.com/pytorch/pytorch/pull/127354 on behalf of https://github.com/huydhn due to Sorry for reverting your change but the doc build failure looks legit https://hud.pytorch.org/pytorch/pytorch/commit/b9c058c203ee38032594f898f27cd8404f113a63 ([comment](https://github.com/pytorch/pytorch/pull/127354#issuecomment-2148133982)) --- .lintrunner.toml | 24 + .../distributed/pipeline/benchmark_dataset.py | 58 ++ benchmarks/distributed/pipeline/pipe.py | 296 ++++++ docs/source/conf.py | 85 ++ docs/source/distributed.rst | 19 + docs/source/index.rst | 1 + docs/source/pipeline.rst | 85 ++ test/allowlist_for_publicAPI.json | 28 + test/distributed/pipeline/sync/LICENSE | 27 + test/distributed/pipeline/sync/__init__.py | 8 + test/distributed/pipeline/sync/conftest.py | 61 ++ .../pipeline/sync/skip/__init__.py | 6 + .../pipeline/sync/skip/test_api.py | 52 ++ .../pipeline/sync/skip/test_gpipe.py | 126 +++ .../sync/skip/test_inspect_skip_layout.py | 118 +++ .../pipeline/sync/skip/test_leak.py | 136 +++ .../pipeline/sync/skip/test_portal.py | 163 ++++ .../pipeline/sync/skip/test_stash_pop.py | 144 +++ .../pipeline/sync/skip/test_tracker.py | 145 +++ .../sync/skip/test_verify_skippables.py | 165 ++++ .../distributed/pipeline/sync/test_balance.py | 240 +++++ test/distributed/pipeline/sync/test_bugs.py | 146 +++ .../pipeline/sync/test_checkpoint.py | 178 ++++ test/distributed/pipeline/sync/test_copy.py | 85 ++ .../pipeline/sync/test_deferred_batch_norm.py | 200 ++++ .../pipeline/sync/test_dependency.py | 152 ++++ .../distributed/pipeline/sync/test_inplace.py | 79 ++ .../pipeline/sync/test_microbatch.py | 148 +++ test/distributed/pipeline/sync/test_phony.py | 57 ++ test/distributed/pipeline/sync/test_pipe.py | 858 ++++++++++++++++++ .../pipeline/sync/test_pipeline.py | 36 + test/distributed/pipeline/sync/test_stream.py | 198 ++++ .../pipeline/sync/test_transparency.py | 55 ++ test/distributed/pipeline/sync/test_worker.py | 118 +++ test/test_public_bindings.py | 2 + test/test_testing.py | 1 + torch/distributed/pipeline/__init__.py | 13 + torch/distributed/pipeline/sync/LICENSE | 27 + torch/distributed/pipeline/sync/__init__.py | 12 + .../pipeline/sync/_balance/__init__.py | 164 ++++ .../pipeline/sync/_balance/blockpartition.py | 95 ++ .../pipeline/sync/_balance/profile.py | 116 +++ .../pipeline/sync/_balance/py.typed | 6 + torch/distributed/pipeline/sync/batchnorm.py | 159 ++++ torch/distributed/pipeline/sync/checkpoint.py | 364 ++++++++ torch/distributed/pipeline/sync/copy.py | 108 +++ torch/distributed/pipeline/sync/dependency.py | 54 ++ torch/distributed/pipeline/sync/microbatch.py | 234 +++++ torch/distributed/pipeline/sync/phony.py | 50 + torch/distributed/pipeline/sync/pipe.py | 490 ++++++++++ torch/distributed/pipeline/sync/pipeline.py | 255 ++++++ torch/distributed/pipeline/sync/py.typed | 6 + .../pipeline/sync/skip/__init__.py | 11 + .../distributed/pipeline/sync/skip/layout.py | 92 ++ .../pipeline/sync/skip/namespace.py | 50 + .../distributed/pipeline/sync/skip/portal.py | 231 +++++ .../pipeline/sync/skip/skippable.py | 431 +++++++++ .../distributed/pipeline/sync/skip/tracker.py | 180 ++++ torch/distributed/pipeline/sync/stream.py | 120 +++ torch/distributed/pipeline/sync/utils.py | 38 + torch/distributed/pipeline/sync/worker.py | 132 +++ .../distributed/pipe_with_ddp_test.py | 149 +++ .../distributed/pipeline/__init__.py | 0 .../_internal/distributed/rpc_utils.py | 4 + 64 files changed, 7891 insertions(+) create mode 100644 benchmarks/distributed/pipeline/benchmark_dataset.py create mode 100644 benchmarks/distributed/pipeline/pipe.py create mode 100644 docs/source/pipeline.rst create mode 100644 test/distributed/pipeline/sync/LICENSE create mode 100644 test/distributed/pipeline/sync/__init__.py create mode 100644 test/distributed/pipeline/sync/conftest.py create mode 100644 test/distributed/pipeline/sync/skip/__init__.py create mode 100644 test/distributed/pipeline/sync/skip/test_api.py create mode 100644 test/distributed/pipeline/sync/skip/test_gpipe.py create mode 100644 test/distributed/pipeline/sync/skip/test_inspect_skip_layout.py create mode 100644 test/distributed/pipeline/sync/skip/test_leak.py create mode 100644 test/distributed/pipeline/sync/skip/test_portal.py create mode 100644 test/distributed/pipeline/sync/skip/test_stash_pop.py create mode 100644 test/distributed/pipeline/sync/skip/test_tracker.py create mode 100644 test/distributed/pipeline/sync/skip/test_verify_skippables.py create mode 100644 test/distributed/pipeline/sync/test_balance.py create mode 100644 test/distributed/pipeline/sync/test_bugs.py create mode 100644 test/distributed/pipeline/sync/test_checkpoint.py create mode 100644 test/distributed/pipeline/sync/test_copy.py create mode 100644 test/distributed/pipeline/sync/test_deferred_batch_norm.py create mode 100644 test/distributed/pipeline/sync/test_dependency.py create mode 100644 test/distributed/pipeline/sync/test_inplace.py create mode 100644 test/distributed/pipeline/sync/test_microbatch.py create mode 100644 test/distributed/pipeline/sync/test_phony.py create mode 100644 test/distributed/pipeline/sync/test_pipe.py create mode 100644 test/distributed/pipeline/sync/test_pipeline.py create mode 100644 test/distributed/pipeline/sync/test_stream.py create mode 100644 test/distributed/pipeline/sync/test_transparency.py create mode 100644 test/distributed/pipeline/sync/test_worker.py create mode 100644 torch/distributed/pipeline/__init__.py create mode 100644 torch/distributed/pipeline/sync/LICENSE create mode 100644 torch/distributed/pipeline/sync/__init__.py create mode 100644 torch/distributed/pipeline/sync/_balance/__init__.py create mode 100644 torch/distributed/pipeline/sync/_balance/blockpartition.py create mode 100644 torch/distributed/pipeline/sync/_balance/profile.py create mode 100644 torch/distributed/pipeline/sync/_balance/py.typed create mode 100644 torch/distributed/pipeline/sync/batchnorm.py create mode 100644 torch/distributed/pipeline/sync/checkpoint.py create mode 100644 torch/distributed/pipeline/sync/copy.py create mode 100644 torch/distributed/pipeline/sync/dependency.py create mode 100644 torch/distributed/pipeline/sync/microbatch.py create mode 100644 torch/distributed/pipeline/sync/phony.py create mode 100644 torch/distributed/pipeline/sync/pipe.py create mode 100644 torch/distributed/pipeline/sync/pipeline.py create mode 100644 torch/distributed/pipeline/sync/py.typed create mode 100644 torch/distributed/pipeline/sync/skip/__init__.py create mode 100644 torch/distributed/pipeline/sync/skip/layout.py create mode 100644 torch/distributed/pipeline/sync/skip/namespace.py create mode 100644 torch/distributed/pipeline/sync/skip/portal.py create mode 100644 torch/distributed/pipeline/sync/skip/skippable.py create mode 100644 torch/distributed/pipeline/sync/skip/tracker.py create mode 100644 torch/distributed/pipeline/sync/stream.py create mode 100644 torch/distributed/pipeline/sync/utils.py create mode 100644 torch/distributed/pipeline/sync/worker.py create mode 100644 torch/testing/_internal/distributed/pipe_with_ddp_test.py create mode 100644 torch/testing/_internal/distributed/pipeline/__init__.py diff --git a/.lintrunner.toml b/.lintrunner.toml index 033414d8bbc8..e4f2507de8cc 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1535,6 +1535,28 @@ exclude_patterns = [ 'torch/distributed/optim/post_localSGD_optimizer.py', 'torch/distributed/optim/utils.py', 'torch/distributed/optim/zero_redundancy_optimizer.py', + 'torch/distributed/pipeline/__init__.py', + 'torch/distributed/pipeline/sync/__init__.py', + 'torch/distributed/pipeline/sync/_balance/__init__.py', + 'torch/distributed/pipeline/sync/_balance/blockpartition.py', + 'torch/distributed/pipeline/sync/_balance/profile.py', + 'torch/distributed/pipeline/sync/batchnorm.py', + 'torch/distributed/pipeline/sync/checkpoint.py', + 'torch/distributed/pipeline/sync/copy.py', + 'torch/distributed/pipeline/sync/dependency.py', + 'torch/distributed/pipeline/sync/microbatch.py', + 'torch/distributed/pipeline/sync/phony.py', + 'torch/distributed/pipeline/sync/pipe.py', + 'torch/distributed/pipeline/sync/pipeline.py', + 'torch/distributed/pipeline/sync/skip/__init__.py', + 'torch/distributed/pipeline/sync/skip/layout.py', + 'torch/distributed/pipeline/sync/skip/namespace.py', + 'torch/distributed/pipeline/sync/skip/portal.py', + 'torch/distributed/pipeline/sync/skip/skippable.py', + 'torch/distributed/pipeline/sync/skip/tracker.py', + 'torch/distributed/pipeline/sync/stream.py', + 'torch/distributed/pipeline/sync/utils.py', + 'torch/distributed/pipeline/sync/worker.py', 'torch/distributed/remote_device.py', 'torch/distributed/rendezvous.py', 'torch/distributed/rpc/__init__.py', @@ -1828,6 +1850,8 @@ exclude_patterns = [ 'torch/testing/_internal/distributed/nn/__init__.py', 'torch/testing/_internal/distributed/nn/api/__init__.py', 'torch/testing/_internal/distributed/nn/api/remote_module_test.py', + 'torch/testing/_internal/distributed/pipe_with_ddp_test.py', + 'torch/testing/_internal/distributed/pipeline/__init__.py', 'torch/testing/_internal/distributed/rpc/__init__.py', 'torch/testing/_internal/distributed/rpc/dist_autograd_test.py', 'torch/testing/_internal/distributed/rpc/dist_optimizer_test.py', diff --git a/benchmarks/distributed/pipeline/benchmark_dataset.py b/benchmarks/distributed/pipeline/benchmark_dataset.py new file mode 100644 index 000000000000..3cd22e9a468d --- /dev/null +++ b/benchmarks/distributed/pipeline/benchmark_dataset.py @@ -0,0 +1,58 @@ +import torch +from torch.utils.data import Dataset + + +def collate_sentences_lm(samples): + if len(samples) == 0: + return {} + + id = torch.LongTensor([s["id"] for s in samples]) + src_tokens = torch.stack([s["source"] for s in samples], 0) + tgt_tokens = torch.stack([s["target"] for s in samples], 0) + ntokens = len(samples) * len(samples[0]["target"]) + src_lengths = torch.LongTensor([len(samples[0]["source"])] * len(samples)) + + batch = { + "id": id, + "nsentences": len(samples), + "ntokens": ntokens, + "input": src_tokens, + "target": tgt_tokens, + } + return batch + + +class BenchmarkLMDataset(Dataset): + """ + Dataset to benchmark a translation like seq2seq task. + Args: + vocab_size (int, optional): size of the vocabulary (default 10000). + max_source_positions (int, optional): max number of tokens in the + source sentence (default: 1024). + total_samples (int, optional): the total number of rows in the + dataset (default: 10000). + """ + + def __init__( + self, + vocab_size=10000, + max_source_positions=1024, + total_samples=10000, + ): + self.vocab_size = vocab_size + self.max_source_positions = max_source_positions + self.total_samples = total_samples + self.sizes = [self.max_source_positions] * self.total_samples + + def __getitem__(self, index): + length = self.sizes[index] + source = torch.randint(1, self.vocab_size, (length,)) + target = source.clone() + return { + "id": index, + "source": source, + "target": target, + } + + def __len__(self): + return self.total_samples diff --git a/benchmarks/distributed/pipeline/pipe.py b/benchmarks/distributed/pipeline/pipe.py new file mode 100644 index 000000000000..c465c2488565 --- /dev/null +++ b/benchmarks/distributed/pipeline/pipe.py @@ -0,0 +1,296 @@ +import argparse +import math +import os +import time + +from benchmark_dataset import BenchmarkLMDataset, collate_sentences_lm + +import torch +import torch.nn as nn +from torch.distributed import rpc + +from torch.distributed.pipeline.sync import Pipe +from torch.distributed.pipeline.sync.utils import partition_model +from torch.optim import Adam +from torch.utils.data import DataLoader + + +def sizeof_fmt(num, suffix="B"): + for unit in ["", "Ki", "Mi", "Gi", "Ti"]: + if abs(num) < 1024.0: + return f"{num:3.2f}{unit}B" + num /= 1024.0 + + +def init_random_seed(seed: int): + import numpy + + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + numpy.random.seed(seed) + + +iteration_count = 0 + + +class EmbeddingLayer(nn.Embedding): + def __init__(self, ntoken, ninp, initrange): + super().__init__(ntoken, ninp) + self.ninp = ninp + nn.init.uniform_(self.weight, -initrange, initrange) + + def forward(self, src): + return super().forward(src) * math.sqrt(self.ninp) + + +class PositionalEncodingLayer(nn.Module): + def __init__(self, d_model, dropout=0.1, max_len=5000): + super().__init__() + self.dropout = nn.Dropout(p=dropout) + + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0).transpose(0, 1) + self.register_buffer("pe", pe) + + def forward(self, x): + x = x + self.pe[: x.size(0), :] + return self.dropout(x) + + +class TransformerDecoderLayer(nn.TransformerEncoderLayer): + """Though this class inherits from torch.nn.TransformerEncoderLayer, + it functions as a decoder in this model""" + + def __init__(self, ninp, nhead, nhid, droupout): + super().__init__(ninp, nhead, nhid, droupout) + self.src_mask = None + + def forward(self, src): + global iteration_count + iteration_count += 1 + + if self.src_mask is None or self.src_mask.size(0) != len(src): + device = src.device + mask = nn.Transformer.generate_square_subsequent_mask(len(src)).to(device) + self.src_mask = mask + + return super().forward(src, self.src_mask) + + +class LinearLayer(nn.Linear): + def __init__(self, ninp, ntoken, initrange): + super().__init__(ninp, ntoken) + nn.init.zeros_(self.bias) + nn.init.uniform_(self.weight, -initrange, initrange) + + +class TransformerLMSequential(nn.Sequential): + """A small language model based on the design of GPT-2 using nn.Sequential + for compatibility with Pipe""" + + def __init__(self, ntokens, ninp, nhead, nhid, dropout, initrange, ndecoder): + layers = [ + EmbeddingLayer(ntokens, ninp, initrange), + PositionalEncodingLayer(ninp, dropout), + ] + for _ in range(ndecoder): + layers.append(TransformerDecoderLayer(ninp, nhead, nhid, dropout)) + + layers.append(LinearLayer(ninp, ntokens, initrange)) + super().__init__(*layers) + + +def make_model(args, device, ntokens): + ninp = 2048 # embedding dimension + nhid = ( + 2048 # the dimension of the feedforward network model in nn.TransformerEncoder + ) + nhead = 32 # the number of heads in the multiheadattention models + dropout = 0 + initrange = 0.1 + ndecoder = args.num_decoder_layers + + model = TransformerLMSequential( + ntokens, ninp, nhead, nhid, dropout, initrange, ndecoder + ).to(device) + + criterion = nn.CrossEntropyLoss() + lr = 0.01 # learning rate + + def make_adam(model): + return Adam(model.parameters(), lr=lr) + + optimizer = make_adam + + return model, criterion, optimizer + + +def train(lm_dataloader, model, criterion, optimizer, vocab_size, args): + model.train() + + vocab_size = 10000 + total_loss = 0.0 + start_time = time.time() + word_counter = 0 + + optimizer = optimizer(model) + + def get_first_device(model): + if model.devices: + return model.devices[0] + else: + return torch.cuda.current_device() + + def get_last_device(model): + if model.devices: + return model.devices[-1] + else: + return torch.cuda.current_device() + + print( + f"Number of parameters for model: {sum(p.numel() for p in model.parameters())}" + ) + for i, batch in enumerate(lm_dataloader): + bi = batch["input"] + if args.max_batch and i > args.max_batch: + break + optimizer.zero_grad() + try: + tmp = batch["input"].to(get_first_device(model)) + output = model(tmp).local_value() + except Exception as e: + raise RuntimeError( + f"training failed on {torch.distributed.get_rank()}" + ) from e + + target = batch["target"].to(get_last_device(model)) + output = output.to(target.device) + + loss = criterion(output.view(-1, vocab_size), target.view(-1)) + loss.backward() + del target + del output + + torch.nn.utils.clip_grad_value_(model.parameters(), 0.05) + optimizer.step() + + total_loss += loss.item() + log_interval = 1 + word_counter += batch["ntokens"] + if i % log_interval == 0 and i > 0: + cur_loss = total_loss / log_interval + elapsed = time.time() - start_time + print( + f"| batch {i:5d} | wps {word_counter / elapsed:5.2f} | loss {cur_loss:5.2f} | ppl {math.exp(cur_loss):8.2f}" + ) + word_counter = 0 + total_loss = 0 + start_time = time.time() + + print("Peak memory usage for GPUs: ", end="") + for i in range(len(model.devices)): + print( + f"cuda:{i}: {sizeof_fmt(torch.cuda.memory_stats(i)['allocated_bytes.all.peak'])}, ", + end="", + ) + print() + + +def generate_balance(num_devices, num_layers): + balance = [] + layers_assigned = 0 + for i in range(num_devices): + x = (num_layers - layers_assigned) / (num_devices - i) + if x.is_integer(): + balance.append(int(x)) + layers_assigned += x + else: + balance.append(math.ceil(x)) + layers_assigned += math.ceil(x) + return balance + + +def make_model_and_data(args, device): + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + vocab_size = 10000 + model, criterion, optimizer = make_model(args, device, vocab_size) + lm_dataset = BenchmarkLMDataset() + lm_dataloader = DataLoader( + lm_dataset, + batch_size=args.batch_size, + shuffle=True, + num_workers=0, + collate_fn=collate_sentences_lm, + ) + return { + "model": model, + "criterion": criterion, + "optimizer": optimizer, + "data": lm_dataloader, + "vocab_size": vocab_size, + } + + +def bench_single_process(args): + os.environ.update({"MASTER_ADDR": args.host}) + os.environ.update({"MASTER_PORT": "10638"}) + + rpc.init_rpc( + "worker", + rank=0, + world_size=1, + ) + + num_devices = torch.cuda.device_count() if torch.cuda.is_available() else 1 + num_devices = min(args.num_devices, num_devices) + assert num_devices > 0 + init_random_seed(0) + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + blob = make_model_and_data(args, None) + model = blob["model"] + + balance = generate_balance(num_devices, len(model)) + model = partition_model(model, balance) + p = Pipe(model, chunks=args.chunks, checkpoint=args.checkpoint) + del model + del blob["model"] + + train( + blob["data"], p, blob["criterion"], blob["optimizer"], blob["vocab_size"], args + ) + + +parser = argparse.ArgumentParser(description="benchmark") +parser.add_argument("--host", "-o", type=str, default="localhost", help="hostname") +parser.add_argument( + "--chunks", type=int, default=4, help="number of microbatches per batch" +) +parser.add_argument("--batch-size", type=int, default=8, help="size of a batch") +parser.add_argument("--max-batch", type=int, default=10, help="Max number of batches") +parser.add_argument( + "--num-decoder-layers", + type=int, + default=10, + help="Number of decoder layers in the model", +) +parser.add_argument( + "--checkpoint", + default="except_last", + choices=["always", "except_last", "never"], + help="Checkpointing strategy for pipe", +) +parser.add_argument( + "--num-devices", type=int, default=4, help="Number of GPU devices to use" +) + +if __name__ == "__main__": + args = parser.parse_args() + print(f"Running benchmark with args: {args}") + bench_single_process(args) diff --git a/docs/source/conf.py b/docs/source/conf.py index 4f73c111cb23..fe548737b313 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -606,6 +606,45 @@ # torch.distributed.optim.utils "as_functional_optim", "register_functional_optim", + # torch.distributed.pipeline.sync.checkpoint + "checkpoint", + "enable_checkpointing", + "enable_recomputing", + "is_checkpointing", + "is_recomputing", + "restore_rng_states", + "save_rng_states", + # torch.distributed.pipeline.sync.dependency + "fork", + "join", + # torch.distributed.pipeline.sync.microbatch + "check", + "gather", + "scatter", + # torch.distributed.pipeline.sync.phony + "get_phony", + # torch.distributed.pipeline.sync.skip.layout + "inspect_skip_layout", + # torch.distributed.pipeline.sync.skip.tracker + "current_skip_tracker", + "use_skip_tracker", + # torch.distributed.pipeline.sync.stream + "as_cuda", + "current_stream", + "default_stream", + "get_device", + "is_cuda", + "new_stream", + "record_stream", + "use_device", + "use_stream", + "wait_stream", + # torch.distributed.pipeline.sync.utils + "partition_model", + # torch.distributed.pipeline.sync.worker + "create_workers", + "spawn_workers", + "worker", # torch.distributed.rendezvous "register_rendezvous_handler", "rendezvous", @@ -2609,6 +2648,52 @@ "PostLocalSGDOptimizer", # torch.distributed.optim.zero_redundancy_optimizer "ZeroRedundancyOptimizer", + # torch.distributed.pipeline.sync.batchnorm + "DeferredBatchNorm", + # torch.distributed.pipeline.sync.checkpoint + "Checkpoint", + "Checkpointing", + "Context", + "Function", + "Recompute", + "ThreadLocal", + # torch.distributed.pipeline.sync.copy + "Context", + "Copy", + "Wait", + # torch.distributed.pipeline.sync.dependency + "Fork", + "Join", + # torch.distributed.pipeline.sync.microbatch + "Batch", + "NoChunk", + # torch.distributed.pipeline.sync.pipe + "BalanceError", + "Pipe", + "PipeSequential", + "WithDevice", + # torch.distributed.pipeline.sync.pipeline + "Pipeline", + # torch.distributed.pipeline.sync.skip.layout + "SkipLayout", + # torch.distributed.pipeline.sync.skip.namespace + "Namespace", + # torch.distributed.pipeline.sync.skip.portal + "Context", + "Portal", + "PortalBlue", + "PortalCopy", + "PortalOrange", + # torch.distributed.pipeline.sync.skip.skippable + "Skippable", + # torch.distributed.pipeline.sync.skip.tracker + "SkipTracker", + "SkipTrackerThroughPotals", + "ThreadLocal", + # torch.distributed.pipeline.sync.stream + "CPUStreamType", + # torch.distributed.pipeline.sync.worker + "Task", # torch.distributed.rpc.api "AllGatherStates", "RRef", diff --git a/docs/source/distributed.rst b/docs/source/distributed.rst index f4c73b9381e5..0b091d567031 100644 --- a/docs/source/distributed.rst +++ b/docs/source/distributed.rst @@ -876,6 +876,9 @@ If you are running single node training, it may be convenient to interactively b .. py:module:: torch.distributed.nn.api .. py:module:: torch.distributed.nn.jit .. py:module:: torch.distributed.nn.jit.templates +.. py:module:: torch.distributed.pipeline +.. py:module:: torch.distributed.pipeline.sync +.. py:module:: torch.distributed.pipeline.sync.skip .. py:module:: torch.distributed.tensor .. py:module:: torch.distributed.algorithms.ddp_comm_hooks.ddp_zero_hook .. py:module:: torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks @@ -961,6 +964,22 @@ If you are running single node training, it may be convenient to interactively b .. py:module:: torch.distributed.optim.post_localSGD_optimizer .. py:module:: torch.distributed.optim.utils .. py:module:: torch.distributed.optim.zero_redundancy_optimizer +.. py:module:: torch.distributed.pipeline.sync.batchnorm +.. py:module:: torch.distributed.pipeline.sync.checkpoint +.. py:module:: torch.distributed.pipeline.sync.copy +.. py:module:: torch.distributed.pipeline.sync.dependency +.. py:module:: torch.distributed.pipeline.sync.microbatch +.. py:module:: torch.distributed.pipeline.sync.phony +.. py:module:: torch.distributed.pipeline.sync.pipe +.. py:module:: torch.distributed.pipeline.sync.pipeline +.. py:module:: torch.distributed.pipeline.sync.skip.layout +.. py:module:: torch.distributed.pipeline.sync.skip.namespace +.. py:module:: torch.distributed.pipeline.sync.skip.portal +.. py:module:: torch.distributed.pipeline.sync.skip.skippable +.. py:module:: torch.distributed.pipeline.sync.skip.tracker +.. py:module:: torch.distributed.pipeline.sync.stream +.. py:module:: torch.distributed.pipeline.sync.utils +.. py:module:: torch.distributed.pipeline.sync.worker .. py:module:: torch.distributed.remote_device .. py:module:: torch.distributed.rendezvous .. py:module:: torch.distributed.rpc.api diff --git a/docs/source/index.rst b/docs/source/index.rst index dcaadcbb63ed..ea704f20c3af 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -103,6 +103,7 @@ Features described in this documentation are classified by release status: optim complex_numbers ddp_comm_hooks + pipeline quantization rpc torch.random diff --git a/docs/source/pipeline.rst b/docs/source/pipeline.rst new file mode 100644 index 000000000000..94d730ee223d --- /dev/null +++ b/docs/source/pipeline.rst @@ -0,0 +1,85 @@ +.. _pipeline-parallelism: + +Pipeline Parallelism +==================== + +Pipeline parallelism was original introduced in the +`Gpipe `__ paper and is an efficient +technique to train large models on multiple GPUs. + +.. warning :: + torch.distributed.pipeline is deprecated, so is this document. For + up-to-date pipeline parallel implementation, please refer to the + `PiPPy `__ library under the PyTorch + organization (Pipeline Parallelism for PyTorch). + +Model Parallelism using multiple GPUs +------------------------------------- + +Typically for large models which don't fit on a single GPU, model parallelism +is employed where certain parts of the model are placed on different GPUs. +Although, if this is done naively for sequential models, the training process +suffers from GPU under utilization since only one GPU is active at one time as +shown in the figure below: + +.. figure:: _static/img/pipeline_parallelism/no_pipe.png + + The figure represents a model with 4 layers placed on 4 different GPUs + (vertical axis). The horizontal axis represents training this model through + time demonstrating that only 1 GPU is utilized at a time + (`image source `__). + +Pipelined Execution +------------------- + +To alleviate this problem, pipeline parallelism splits the input minibatch into +multiple microbatches and pipelines the execution of these microbatches across +multiple GPUs. This is outlined in the figure below: + +.. figure:: _static/img/pipeline_parallelism/pipe.png + + The figure represents a model with 4 layers placed on 4 different GPUs + (vertical axis). The horizontal axis represents training this model through + time demonstrating that the GPUs are utilized much more efficiently. + However, there still exists a bubble (as demonstrated in the figure) where + certain GPUs are not utilized. + (`image source `__). + +Pipe APIs in PyTorch +-------------------- +.. autoclass:: torch.distributed.pipeline.sync.Pipe + :members: forward + +Skip connections +^^^^^^^^^^^^^^^^ + +Certain models like `ResNeXt `__ +are not completely sequential and have skip connections between layers. +Naively implementing as part of pipeline parallelism would imply that +we need to copy outputs for certain layers through multiple GPUs till +we eventually reach the GPU where the layer for the skip connection resides. +To avoid this copy overhead, we provide APIs below to stash and pop Tensors +in different layers of the model. + +.. autofunction:: torch.distributed.pipeline.sync.skip.skippable.skippable +.. autoclass:: torch.distributed.pipeline.sync.skip.skippable.stash +.. autoclass:: torch.distributed.pipeline.sync.skip.skippable.pop +.. autofunction:: torch.distributed.pipeline.sync.skip.skippable.verify_skippables + +Tutorials +--------- + +The following tutorials give a good overview of how to use the +:class:`~torch.distributed.pipeline.sync.Pipe` API to train your models with the +rest of the components that PyTorch provides: + +- `Training Transformer models using Pipeline Parallelism `__ +- `Training Transformer models using Distributed Data Parallel and Pipeline Parallelism `__ + +Acknowledgements +---------------- + +The implementation for pipeline parallelism is based on `fairscale's pipe implementation `__ and +`torchgpipe `__. We would like to +thank both teams for their contributions and guidance towards bringing pipeline +parallelism into PyTorch. diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json index f7af925adb72..c3d3fe2f00ec 100644 --- a/test/allowlist_for_publicAPI.json +++ b/test/allowlist_for_publicAPI.json @@ -211,6 +211,30 @@ "torch.distributed.optim.utils": [ "Type" ], + "torch.distributed.pipeline.sync.pipe": [ + "Pipeline" + ], + "torch.distributed.pipeline.sync.skip.layout": [ + "SkipLayout", + "inspect_skip_layout" + ], + "torch.distributed.pipeline.sync.skip.portal": [ + "Context", + "Portal", + "PortalBlue", + "PortalCopy", + "PortalOrange" + ], + "torch.distributed.pipeline.sync.skip.skippable": [ + "Skippable" + ], + "torch.distributed.pipeline.sync.skip.tracker": [ + "SkipTracker", + "SkipTrackerThroughPotals", + "ThreadLocal", + "current_skip_tracker", + "use_skip_tracker" + ], "torch.distributed.remote_device": [ "Optional", "Union" @@ -1673,6 +1697,10 @@ "get_args_parser", "run" ], + "torch.distributed.pipeline.sync": [ + "NoChunk", + "WithDevice" + ], "torch.distributed.rpc.rref_proxy": [ "Future", "partial", diff --git a/test/distributed/pipeline/sync/LICENSE b/test/distributed/pipeline/sync/LICENSE new file mode 100644 index 000000000000..e52be240fdc9 --- /dev/null +++ b/test/distributed/pipeline/sync/LICENSE @@ -0,0 +1,27 @@ +Copyright 2019-2020 Kakao Brain + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from this + software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. diff --git a/test/distributed/pipeline/sync/__init__.py b/test/distributed/pipeline/sync/__init__.py new file mode 100644 index 000000000000..94cd5bcb415e --- /dev/null +++ b/test/distributed/pipeline/sync/__init__.py @@ -0,0 +1,8 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +# tests/__init__.py makes pytest can import the application without custom sys.path or PYTHONPATH. +# See also: https://docs.pytest.org/en/latest/goodpractices.html diff --git a/test/distributed/pipeline/sync/conftest.py b/test/distributed/pipeline/sync/conftest.py new file mode 100644 index 000000000000..4f2479b27b29 --- /dev/null +++ b/test/distributed/pipeline/sync/conftest.py @@ -0,0 +1,61 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +import tempfile + +import pytest + +import torch +import torch.distributed as dist + + +@pytest.fixture(autouse=True) +def manual_seed_zero(): + torch.manual_seed(0) + + +@pytest.fixture(scope="session") +def cuda_sleep(): + # Warm-up CUDA. + torch.empty(1, device="cuda") + + # From test/test_cuda.py in PyTorch. + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + torch.cuda._sleep(1000000) + end.record() + end.synchronize() + cycles_per_ms = 1000000 / start.elapsed_time(end) + + def cuda_sleep(seconds): + torch.cuda._sleep(int(seconds * cycles_per_ms * 1000)) + + return cuda_sleep + + +def pytest_report_header(): + return f"torch: {torch.__version__}" + + +@pytest.fixture +def setup_rpc(scope="session"): + file = tempfile.NamedTemporaryFile() + dist.rpc.init_rpc( + name="worker0", + rank=0, + world_size=1, + rpc_backend_options=dist.rpc.TensorPipeRpcBackendOptions( + init_method=f"file://{file.name}", + ), + ) + yield + dist.rpc.shutdown() + + +def pytest_ignore_collect(path, config): + "Skip this directory if distributed modules are not enabled." + return not dist.is_available() diff --git a/test/distributed/pipeline/sync/skip/__init__.py b/test/distributed/pipeline/sync/skip/__init__.py new file mode 100644 index 000000000000..ab03724cafbf --- /dev/null +++ b/test/distributed/pipeline/sync/skip/__init__.py @@ -0,0 +1,6 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. diff --git a/test/distributed/pipeline/sync/skip/test_api.py b/test/distributed/pipeline/sync/skip/test_api.py new file mode 100644 index 000000000000..be38d6d83dac --- /dev/null +++ b/test/distributed/pipeline/sync/skip/test_api.py @@ -0,0 +1,52 @@ +# Owner(s): ["oncall: distributed"] + +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +import copy + +from torch import nn + +from torch.distributed.pipeline.sync.skip import Namespace, skippable, stash +from torch.testing._internal.common_utils import run_tests + + +def test_namespace_difference(): + ns1 = Namespace() + ns2 = Namespace() + assert ns1 != ns2 + + +def test_namespace_copy(): + ns = Namespace() + assert copy.copy(ns) == ns + assert copy.copy(ns) is not ns + + +def test_skippable_repr(): + @skippable(stash=["hello"]) + class Hello(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(1, 1, 1) + + def forward(self, x): + yield stash("hello", x) + return self.conv(x) # noqa: B901 + + m = Hello() + assert ( + repr(m) + == """ +@skippable(Hello( + (conv): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1)) +)) +""".strip() + ) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/pipeline/sync/skip/test_gpipe.py b/test/distributed/pipeline/sync/skip/test_gpipe.py new file mode 100644 index 000000000000..4f433ab38941 --- /dev/null +++ b/test/distributed/pipeline/sync/skip/test_gpipe.py @@ -0,0 +1,126 @@ +# Owner(s): ["oncall: distributed"] + +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +import pytest + +import torch +from torch import nn + +from torch.distributed.pipeline.sync import Pipe +from torch.distributed.pipeline.sync.skip import pop, skippable, stash +from torch.distributed.pipeline.sync.skip.portal import ( + PortalBlue, + PortalCopy, + PortalOrange, +) +from torch.distributed.pipeline.sync.utils import partition_model +from torch.testing._internal.common_utils import run_tests + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") +@pytest.mark.parametrize( + "balance", [[3], [1, 2], [2, 1], [1, 1, 1]], ids=["3", "1:2", "2:1", "1:1:1"] +) +@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) +def test_1to3(balance, checkpoint, setup_rpc): + if torch.cuda.device_count() < len(balance): + pytest.skip("at least %d cuda devices required" % len(balance)) + + @skippable(stash=["1to3"]) + class Layer1(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(3, 3, 1) + + def forward(self, input): + yield stash("1to3", input) + output = self.conv(input) + return output # noqa: B901 + + class Layer2(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(3, 3, 1) + + def forward(self, input): + output = self.conv(input) + return output + + @skippable(pop=["1to3"]) + class Layer3(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(3, 3, 1) + + def forward(self, input): + skip_1to3 = yield pop("1to3") + output = self.conv(input) + skip_1to3 + return output + + model = nn.Sequential(Layer1(), Layer2(), Layer3()) + model = partition_model(model, balance) + model = Pipe(model, chunks=3, checkpoint=checkpoint) + + in_device = model.devices[0] + out_device = model.devices[-1] + + input = torch.rand(30, 3, 224, 224, device=in_device, requires_grad=True) + output = model(input) + loss = output.local_value().mean() + loss.backward() + + assert torch.allclose( + output.local_value().norm(), torch.tensor(1039.0, device=out_device), atol=6e-1 + ) + assert torch.allclose( + input.grad.norm(), torch.tensor(0.0004533053, device=in_device) + ) + + +def test_none_skip(setup_rpc): + @skippable(stash=["none"]) + class Stash(nn.Module): + def forward(self, input): + yield stash("none", None) + return input # noqa: B901 + + @skippable(pop=["none"]) + class Pop(nn.Module): + def forward(self, input): + none = yield pop("none") + assert none is None + return input + + model = nn.Sequential(Stash(), Pop()) + model = Pipe(model, chunks=5) + + input = torch.rand(10, requires_grad=True) + output = model(input) + + def assert_grad_fn_is_not_portal(grad_fn, visited=None): + if visited is None: + visited = set() + if grad_fn in visited or grad_fn is None: + return + + assert not isinstance(grad_fn, PortalBlue._backward_cls) + assert not isinstance(grad_fn, PortalCopy._backward_cls) + assert not isinstance(grad_fn, PortalOrange._backward_cls) + + visited.add(grad_fn) + for next_grad_fn, _ in grad_fn.next_functions: + assert_grad_fn_is_not_portal(next_grad_fn, visited) + + assert_grad_fn_is_not_portal(output.local_value().grad_fn) + + output.local_value().sum().backward() + assert input.grad.mean().item() == 1 + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/pipeline/sync/skip/test_inspect_skip_layout.py b/test/distributed/pipeline/sync/skip/test_inspect_skip_layout.py new file mode 100644 index 000000000000..4d542285cd5a --- /dev/null +++ b/test/distributed/pipeline/sync/skip/test_inspect_skip_layout.py @@ -0,0 +1,118 @@ +# Owner(s): ["oncall: distributed"] + +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +from torch import nn + +from torch.distributed.pipeline.sync.skip import Namespace, pop, skippable, stash +from torch.distributed.pipeline.sync.skip.layout import inspect_skip_layout +from torch.testing._internal.common_utils import run_tests + + +class Pass(nn.Module): + def forward(self, input): + return input + + +@skippable(stash=["foo"]) +class StashFoo(nn.Module): + def forward(self, input): + yield stash("foo", input) + return input # noqa: B901 + + +@skippable(pop=["foo"]) +class PopFoo(nn.Module): + def forward(self, input): + foo = yield stash("foo") + return input + foo + + +@skippable(stash=["bar"]) +class StashBar(nn.Module): + def forward(self, input): + yield stash("bar", input) + return input # noqa: B901 + + +@skippable(pop=["bar"]) +class PopBar(nn.Module): + def forward(self, input): + bar = yield pop("bar") + return input + bar + + +def test_no_skippables(): + p1 = nn.Sequential(Pass()) + p2 = nn.Sequential(Pass()) + + layout = inspect_skip_layout([p1, p2]) + policy = [list(layout.copy_policy(i)) for i in range(2)] + + assert policy == [[], []] + + +def test_inner_partition(): + p1 = nn.Sequential(StashFoo(), PopFoo()) + p2 = nn.Sequential(Pass()) + + layout = inspect_skip_layout([p1, p2]) + policy = [list(layout.copy_policy(i)) for i in range(2)] + + assert policy == [[], []] + + +def test_adjoining_partitions(): + p1 = nn.Sequential(StashFoo()) + p2 = nn.Sequential(PopFoo()) + + layout = inspect_skip_layout([p1, p2]) + policy = [list(layout.copy_policy(i)) for i in range(2)] + + assert policy == [[], [(0, None, "foo")]] + + +def test_far_partitions(): + p1 = nn.Sequential(StashFoo()) + p2 = nn.Sequential(Pass()) + p3 = nn.Sequential(PopFoo()) + + layout = inspect_skip_layout([p1, p2, p3]) + policy = [list(layout.copy_policy(i)) for i in range(3)] + + assert policy == [[], [], [(0, None, "foo")]] + + +def test_pop_2_from_different_partitions(): + p1 = nn.Sequential(StashFoo()) + p2 = nn.Sequential(StashBar()) + p3 = nn.Sequential(PopBar(), PopFoo()) + + layout = inspect_skip_layout([p1, p2, p3]) + policy = [list(layout.copy_policy(i)) for i in range(3)] + + # p3 pops 'bar' before 'foo', but the plan is sorted by source partition index. + assert policy == [[], [], [(0, None, "foo"), (1, None, "bar")]] + + +def test_namespace(): + ns1 = Namespace() + ns2 = Namespace() + + p1 = nn.Sequential(StashFoo().isolate(ns1)) + p2 = nn.Sequential(StashFoo().isolate(ns2)) + p3 = nn.Sequential(PopFoo().isolate(ns2), PopFoo().isolate(ns1)) + + layout = inspect_skip_layout([p1, p2, p3]) + policy = [list(layout.copy_policy(i)) for i in range(3)] + + # p3 pops 'bar' before 'foo', but the plan is sorted by source partition index. + assert policy == [[], [], [(0, ns1, "foo"), (1, ns2, "foo")]] + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/pipeline/sync/skip/test_leak.py b/test/distributed/pipeline/sync/skip/test_leak.py new file mode 100644 index 000000000000..f4d1043e0549 --- /dev/null +++ b/test/distributed/pipeline/sync/skip/test_leak.py @@ -0,0 +1,136 @@ +# Owner(s): ["oncall: distributed"] + +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +import pytest + +import torch +from torch import nn + +from torch.distributed.pipeline.sync import is_checkpointing, is_recomputing, Pipe +from torch.distributed.pipeline.sync.skip import pop, skippable, stash +from torch.distributed.pipeline.sync.skip.tracker import current_skip_tracker +from torch.testing._internal.common_utils import run_tests + + +@skippable(stash=["skip"]) +class Stash(nn.Module): + def forward(self, input): + yield stash("skip", input) + return input # noqa: B901 + + +@skippable(pop=["skip"]) +class Pop(nn.Module): + def forward(self, input): + skip = yield pop("skip") + return input + skip + + +@pytest.mark.parametrize("train", [True, False], ids=["train", "eval"]) +@pytest.mark.parametrize("checkpoint", ["always", "except_last", "never"]) +def test_delete_portal_tensor(train, checkpoint, setup_rpc): + # Without checkpointing: + # +- Stash --+ +--- Pop ----+ - - - layers + # | 2,blue,1 |--| 1,orange,0 | - - - tensor_life and portal function + # +----------+ +------------+ + # + # With checkpointing: + # +- Stash --+ +--- Pop ----+ +--- Pop'----+ +- Stash'--+ + # | 3,blue,2 |--| 2,orange,1 |--| 1,orange,0 |--| 1,blue,0 | + # +----------+ +------------+ +------------+ +----------+ + + def portal_tensor_life_is(tensor_life, skip_tracker=None): + if skip_tracker is None: + skip_tracker = current_skip_tracker() + + # Get the current portal. + portal = next(iter(skip_tracker.portals.values())) + + if tensor_life == 0: + return portal.tensor_life == 0 and portal.tensor is None + else: + return portal.tensor_life == tensor_life and portal.tensor is not None + + # Check the portal tensor after 'Stash'. + stash_ = Stash() + + @stash_.register_forward_hook + def check_portal_tensor_after_stash(*_): + if is_checkpointing(): + assert portal_tensor_life_is(2) + elif is_recomputing(): + assert portal_tensor_life_is(0) + else: + assert portal_tensor_life_is(1) + + pop_ = Pop() + + @pop_.register_forward_hook + def check_portal_tensor_after_pop(*_): + if is_checkpointing(): + assert portal_tensor_life_is(1) + elif is_recomputing(): + assert portal_tensor_life_is(0) + else: + assert portal_tensor_life_is(0) + + class NoPortalTensorAtBackward(nn.Module): + class F(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + ctx.skip_tracker = current_skip_tracker() + return input.detach() + + @staticmethod + def backward(ctx, grad): + assert portal_tensor_life_is(0, skip_tracker=ctx.skip_tracker) + return grad + + def forward(self, input): + return self.F.apply(input) + + model = nn.Sequential(NoPortalTensorAtBackward(), stash_, pop_) + model = Pipe(model, chunks=2, checkpoint=checkpoint) + + input = torch.rand(10, requires_grad=True) + + if train: + model.train() + output = model(input).local_value() + output.norm().backward() + else: + model.eval() + with torch.no_grad(): + model(input) + + +@pytest.mark.parametrize("train", [True, False], ids=["train", "eval"]) +def test_no_portal_without_pipe(train, monkeypatch, setup_rpc): + def deny(*args, **kwargs): + raise AssertionError("tried to create Portal without Pipe") + + monkeypatch.setattr( + "torch.distributed.pipeline.sync.skip.portal.Portal.__init__", deny + ) + + model = nn.Sequential(Stash(), Pop()) + + input = torch.rand(10, requires_grad=True) + + if train: + model.train() + output = model(input) + output.norm().backward() + else: + model.eval() + with torch.no_grad(): + model(input) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/pipeline/sync/skip/test_portal.py b/test/distributed/pipeline/sync/skip/test_portal.py new file mode 100644 index 000000000000..5ad180b6f9c8 --- /dev/null +++ b/test/distributed/pipeline/sync/skip/test_portal.py @@ -0,0 +1,163 @@ +# Owner(s): ["oncall: distributed"] + +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +import pytest + +import torch + +from torch.distributed.pipeline.sync.dependency import fork, join +from torch.distributed.pipeline.sync.skip.portal import Portal +from torch.distributed.pipeline.sync.stream import default_stream +from torch.testing._internal.common_utils import run_tests + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") +def test_copy_returns_on_next_device(): + portal = Portal(torch.rand(1), tensor_life=1) + + prev_stream = default_stream(torch.device("cpu")) + next_stream = default_stream(torch.device("cuda")) + + phony = torch.zeros(0, requires_grad=True) + assert phony.device.type == "cpu" + + phony = portal.copy(prev_stream, next_stream, phony) + assert phony.device.type == "cuda" + + +def test_blue_orange(): + tensor1 = torch.rand(1, requires_grad=True) + tensor2 = torch.rand(1, requires_grad=True) + + # Same with: output = tensor1*2 + tensor2 + # + # +----------------------+ + # | | + # tensor2 -- PortalBlue -+ +- PortalOrange -+ + # | | | + # tensor1 ------------ Join -- Fork --- Mul --- Add -- output + # + main = tensor1 + portal = Portal(tensor2, tensor_life=2) + phony = portal.blue() + main = join(main, phony) + main, phony = fork(main) + sub = portal.orange(phony) + output = main * 2 + sub + + output.backward() + + assert torch.allclose(tensor1.grad, torch.tensor([2.0])) + assert torch.allclose(tensor2.grad, torch.tensor([1.0])) + + +def test_blue_orange_not_requires_grad(): + tensor1 = torch.rand(1, requires_grad=True) + tensor2 = torch.rand(1) + + # Same with: output = tensor1*2 + tensor2 + # + # +----------------------+ + # | | + # tensor2 -- PortalBlue -+ +- PortalOrange -+ + # | | | + # tensor1 ------------ Join -- Fork --- Mul --- Add -- output + # + main = tensor1 + portal = Portal(tensor2, tensor_life=2) + phony = portal.blue() + main = join(main, phony) + main, phony = fork(main) + sub = portal.orange(phony) + output = main * 2 + sub + + output.backward() + + assert torch.allclose(tensor1.grad, torch.tensor([2.0])) + assert tensor2.grad is None + + +def test_use_grad(): + tensor = torch.rand(1, requires_grad=True) + portal = Portal(tensor, tensor_life=1) + + portal.put_grad(tensor) + assert portal.use_grad() is tensor + + # Gradient in a portal is ephemeral. + with pytest.raises(RuntimeError): + portal.use_grad() + + +class TestTensorLife: + @pytest.fixture + def new_portal(self): + portal = None + + def new_portal(tensor_life): + nonlocal portal + tensor = torch.rand(1, requires_grad=True) + portal = Portal(tensor, tensor_life) + return portal, tensor + + yield new_portal + + # A test using this fixture must exhaust the tensor in the portal. + with pytest.raises(RuntimeError): + portal.check_tensor_life() + assert portal.tensor is None + + def test_tensor_life_0(self, new_portal): + portal, tensor = new_portal(0) + assert portal.tensor is None + + def test_tensor_life_1(self, new_portal): + portal, tensor = new_portal(1) + assert portal.tensor is tensor + + portal.blue() + + def test_tensor_life_2(self, new_portal): + portal, tensor = new_portal(2) + assert portal.tensor is tensor + + phony = portal.blue() + assert portal.orange(phony).data_ptr() == tensor.data_ptr() + + def test_tensor_life_3(self, new_portal): + portal, tensor = new_portal(3) + assert portal.tensor is tensor + + phony = portal.blue() + assert portal.orange(phony).data_ptr() == tensor.data_ptr() + assert portal.orange(phony).data_ptr() == tensor.data_ptr() + + def test_tensor_life_4(self, new_portal): + portal, tensor = new_portal(4) + assert portal.tensor is tensor + + phony = portal.blue() + assert portal.orange(phony).data_ptr() == tensor.data_ptr() + assert portal.orange(phony).data_ptr() == tensor.data_ptr() + portal.blue() + + def test_tensor_life_3_plus_1(self, new_portal): + portal, tensor = new_portal(3) + assert portal.tensor is tensor + + phony = portal.blue() + assert portal.orange(phony).data_ptr() == tensor.data_ptr() + assert portal.orange(phony).data_ptr() == tensor.data_ptr() + + another_tensor = torch.rand(1, requires_grad=True) + portal.put_tensor(another_tensor, tensor_life=1) + portal.blue() + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/pipeline/sync/skip/test_stash_pop.py b/test/distributed/pipeline/sync/skip/test_stash_pop.py new file mode 100644 index 000000000000..5d273860f6a6 --- /dev/null +++ b/test/distributed/pipeline/sync/skip/test_stash_pop.py @@ -0,0 +1,144 @@ +# Owner(s): ["oncall: distributed"] + +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +import pytest + +import torch +from torch import nn + +from torch.distributed.pipeline.sync.skip import pop, skippable, stash +from torch.distributed.pipeline.sync.skip.tracker import SkipTracker, use_skip_tracker +from torch.testing._internal.common_utils import run_tests + + +@pytest.fixture(autouse=True) +def skip_tracker(): + skip_tracker = SkipTracker() + with use_skip_tracker(skip_tracker): + yield skip_tracker + + +def test_stash(skip_tracker): + @skippable(stash=["foo"]) + class Stash(nn.Module): + def forward(self, input): + yield stash("foo", input) + return input * 2 # noqa: B901 + + l1 = Stash() + + assert len(skip_tracker.tensors) == 0 + + with use_skip_tracker(skip_tracker): + l1(torch.tensor(42)) + + assert len(skip_tracker.tensors) == 1 + + +def test_pop(): + @skippable(stash=["foo"]) + class Stash(nn.Module): + def forward(self, input): + yield stash("foo", input) + return input * 2 # noqa: B901 + + @skippable(pop=["foo"]) + class Pop(nn.Module): + def forward(self, input): + foo = yield pop("foo") + return foo + + l1 = Stash() + l2 = Pop() + + output = l2(l1(torch.tensor(42))) + + assert output.item() == 42 + + +def test_declare_but_not_use(): + @skippable(stash=["foo"]) + class Stash(nn.Module): + def forward(self, input): + return input * 2 + + @skippable(pop=["foo"]) + class Pop(nn.Module): + def forward(self, input): + return input * 3 + + l1 = Stash() + l2 = Pop() + + with pytest.raises(RuntimeError): + l1(torch.tensor(42)) + + with pytest.raises(RuntimeError): + l2(torch.tensor(42)) + + +def test_stash_not_declared(): + @skippable() + class Stash(nn.Module): + def forward(self, input): + yield stash("foo", input) + return input * 2 # noqa: B901 + + l1 = Stash() + + with pytest.raises(RuntimeError): + l1(torch.tensor(42)) + + +def test_pop_not_declared(): + @skippable(stash=["foo"]) + class Stash(nn.Module): + def forward(self, input): + yield stash("foo", input) + return input * 2 # noqa: B901 + + @skippable() + class Pop(nn.Module): + def forward(self, input): + foo = yield pop("foo") + return foo + + l1 = Stash() + l2 = Pop() + + latent = l1(torch.tensor(42)) + + with pytest.raises(RuntimeError): + l2(latent) + + +def test_pop_not_stashed(): + @skippable(pop=["foo"]) + class Pop(nn.Module): + def forward(self, input): + yield pop("foo") + + l1 = Pop() + + with pytest.raises(RuntimeError): + l1(torch.tensor(42)) + + +def test_stash_none(): + @skippable(stash=["foo"]) + class Stash(nn.Module): + def forward(self, input): + yield stash("foo", None) + return input * 2 # noqa: B901 + + l1 = Stash() + l1(torch.tensor(42)) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/pipeline/sync/skip/test_tracker.py b/test/distributed/pipeline/sync/skip/test_tracker.py new file mode 100644 index 000000000000..9c3a970f7574 --- /dev/null +++ b/test/distributed/pipeline/sync/skip/test_tracker.py @@ -0,0 +1,145 @@ +# Owner(s): ["oncall: distributed"] + +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +import threading +from queue import Queue + +import pytest + +import torch +from torch import nn + +from torch.distributed.pipeline.sync.checkpoint import ( + enable_checkpointing, + enable_recomputing, +) +from torch.distributed.pipeline.sync.microbatch import Batch +from torch.distributed.pipeline.sync.skip import pop, skippable, stash +from torch.distributed.pipeline.sync.skip.layout import SkipLayout +from torch.distributed.pipeline.sync.skip.tracker import ( + current_skip_tracker, + SkipTracker, + SkipTrackerThroughPotals, +) +from torch.testing._internal.common_utils import run_tests + + +def test_default_skip_tracker(): + q = Queue() + + def f(): + q.put(current_skip_tracker()) + + t = threading.Thread(target=f) + t.start() + t.join() + + skip_tracker = q.get() + + assert type(skip_tracker) is SkipTracker + assert type(skip_tracker) is not SkipTrackerThroughPotals + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") +def test_default_skip_tracker_by_data_parallel(): + @skippable(stash=["foo"]) + class Stash(nn.Module): + def forward(self, input): + yield stash("foo", input) + return input * 2 # noqa: B901 + + @skippable(pop=["foo"]) + class Pop(nn.Module): + def forward(self, input): + foo = yield pop("foo") + return foo + + model = nn.Sequential(Stash(), Pop()) + model = nn.DataParallel(model, device_ids=[0, 0], output_device=0) + + input = torch.rand(10, device=0) + output = model(input) + + assert torch.allclose(output, input) + + +def test_reuse_portal(): + skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "test"): (0, 1)}) + skip_tracker = SkipTrackerThroughPotals(skip_layout) + + batch = Batch(torch.tensor([1.0])) + a = torch.tensor([2.0]) + b = torch.tensor([2.0]) + + skip_tracker.save(batch, None, "test", a) + portal = skip_tracker.portals[(None, "test")] + + skip_tracker.save(batch, None, "test", b) + assert portal is skip_tracker.portals[(None, "test")] + + +def test_no_copy_no_portal(): + skip_layout = SkipLayout( + num_partitions=2, + skip_routes={(None, "copy"): (0, 1), (None, "not_copy"): (0, 0)}, + ) + skip_tracker = SkipTrackerThroughPotals(skip_layout) + + batch = Batch(torch.tensor([1.0])) + a = torch.tensor([2.0]) + b = torch.tensor([2.0]) + + skip_tracker.save(batch, None, "copy", a) + skip_tracker.save(batch, None, "not_copy", b) + + assert (None, "copy") in skip_tracker.portals + assert (None, "copy") not in skip_tracker.tensors + assert (None, "not_copy") in skip_tracker.tensors + assert (None, "not_copy") not in skip_tracker.portals + + +def test_tensor_life_without_checkpointing(): + skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "test"): (0, 1)}) + skip_tracker = SkipTrackerThroughPotals(skip_layout) + + batch = Batch(torch.tensor([1.0])) + tensor = torch.tensor([2.0]) + + skip_tracker.save(batch, None, "test", tensor) + assert skip_tracker.portals[(None, "test")].tensor_life == 1 + + skip_tracker.load(batch, None, "test") + assert skip_tracker.portals[(None, "test")].tensor_life == 0 + + +def test_tensor_life_with_checkpointing(): + skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "test"): (0, 1)}) + skip_tracker = SkipTrackerThroughPotals(skip_layout) + + batch = Batch(torch.tensor([1.0])) + tensor = torch.tensor([2.0]) + + with enable_checkpointing(): + skip_tracker.save(batch, None, "test", tensor) + assert skip_tracker.portals[(None, "test")].tensor_life == 2 + + with enable_checkpointing(): + skip_tracker.load(batch, None, "test") + assert skip_tracker.portals[(None, "test")].tensor_life == 1 + + with enable_recomputing(): + skip_tracker.load(batch, None, "test") + assert skip_tracker.portals[(None, "test")].tensor_life == 0 + + with enable_recomputing(): + skip_tracker.save(batch, None, "test", tensor) + assert skip_tracker.portals[(None, "test")].tensor_life == 0 + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/pipeline/sync/skip/test_verify_skippables.py b/test/distributed/pipeline/sync/skip/test_verify_skippables.py new file mode 100644 index 000000000000..1d5941487da8 --- /dev/null +++ b/test/distributed/pipeline/sync/skip/test_verify_skippables.py @@ -0,0 +1,165 @@ +# Owner(s): ["oncall: distributed"] + +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +import pytest + +from torch import nn + +from torch.distributed.pipeline.sync.skip import Namespace, skippable, verify_skippables +from torch.testing._internal.common_utils import run_tests + + +def test_matching(): + @skippable(stash=["foo"]) + class Layer1(nn.Module): + pass + + @skippable(pop=["foo"]) + class Layer2(nn.Module): + pass + + verify_skippables(nn.Sequential(Layer1(), Layer2())) + + +def test_stash_not_pop(): + @skippable(stash=["foo"]) + class Layer1(nn.Module): + pass + + with pytest.raises(TypeError) as e: + verify_skippables(nn.Sequential(Layer1())) + assert "no module declared 'foo' as poppable but stashed" in str(e.value) + + +def test_pop_unknown(): + @skippable(pop=["foo"]) + class Layer1(nn.Module): + pass + + with pytest.raises(TypeError) as e: + verify_skippables(nn.Sequential(Layer1())) + assert "'0' declared 'foo' as poppable but it was not stashed" in str(e.value) + + +def test_stash_again(): + @skippable(stash=["foo"]) + class Layer1(nn.Module): + pass + + @skippable(stash=["foo"]) + class Layer2(nn.Module): + pass + + @skippable(pop=["foo"]) + class Layer3(nn.Module): + pass + + with pytest.raises(TypeError) as e: + verify_skippables(nn.Sequential(Layer1(), Layer2(), Layer3())) + assert "'1' redeclared 'foo' as stashable" in str(e.value) + + +def test_pop_again(): + @skippable(stash=["foo"]) + class Layer1(nn.Module): + pass + + @skippable(pop=["foo"]) + class Layer2(nn.Module): + pass + + @skippable(pop=["foo"]) + class Layer3(nn.Module): + pass + + with pytest.raises(TypeError) as e: + verify_skippables(nn.Sequential(Layer1(), Layer2(), Layer3())) + assert "'2' redeclared 'foo' as poppable" in str(e.value) + + +def test_stash_pop_together_different_names(): + @skippable(stash=["foo"]) + class Layer1(nn.Module): + pass + + @skippable(pop=["foo"], stash=["bar"]) + class Layer2(nn.Module): + pass + + @skippable(pop=["bar"]) + class Layer3(nn.Module): + pass + + verify_skippables(nn.Sequential(Layer1(), Layer2(), Layer3())) + + +def test_stash_pop_together_same_name(): + @skippable(stash=["foo"], pop=["foo"]) + class Layer1(nn.Module): + pass + + with pytest.raises(TypeError) as e: + verify_skippables(nn.Sequential(Layer1())) + assert "'0' declared 'foo' both as stashable and as poppable" in str(e.value) + + +def test_double_stash_pop(): + @skippable(stash=["foo"]) + class Layer1(nn.Module): + pass + + @skippable(pop=["foo"]) + class Layer2(nn.Module): + pass + + @skippable(stash=["foo"]) + class Layer3(nn.Module): + pass + + @skippable(pop=["foo"]) + class Layer4(nn.Module): + pass + + with pytest.raises(TypeError) as e: + verify_skippables(nn.Sequential(Layer1(), Layer2(), Layer3(), Layer4())) + assert "'2' redeclared 'foo' as stashable" in str(e.value) + assert "'3' redeclared 'foo' as poppable" in str(e.value) + + +def test_double_stash_pop_but_isolated(): + @skippable(stash=["foo"]) + class Layer1(nn.Module): + pass + + @skippable(pop=["foo"]) + class Layer2(nn.Module): + pass + + @skippable(stash=["foo"]) + class Layer3(nn.Module): + pass + + @skippable(pop=["foo"]) + class Layer4(nn.Module): + pass + + ns1 = Namespace() + ns2 = Namespace() + + verify_skippables( + nn.Sequential( + Layer1().isolate(ns1), + Layer2().isolate(ns1), + Layer3().isolate(ns2), + Layer4().isolate(ns2), + ) + ) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/pipeline/sync/test_balance.py b/test/distributed/pipeline/sync/test_balance.py new file mode 100644 index 000000000000..faf09f4581ae --- /dev/null +++ b/test/distributed/pipeline/sync/test_balance.py @@ -0,0 +1,240 @@ +# Owner(s): ["oncall: distributed"] + +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +import time + +import pytest + +import torch +from torch import nn + +from torch.distributed.pipeline.sync._balance import ( + balance_by_size, + balance_by_time, + blockpartition, +) +from torch.distributed.pipeline.sync._balance.profile import layerwise_sandbox +from torch.testing._internal.common_utils import run_tests + +skip_if_no_cuda = pytest.mark.skipif( + not torch.cuda.is_available(), reason="cuda required" +) + +devices = ["cpu"] +if torch.cuda.is_available(): + devices.append("cuda") + + +def test_blockpartition(): + assert blockpartition.solve([1, 2, 3, 4, 5, 6], partitions=2) == [ + [1, 2, 3, 4], + [5, 6], + ] + + +def test_blockpartition_zeros(): + assert blockpartition.solve([0, 0], partitions=2) == [[0], [0]] + + +def test_blockpartition_non_positive_partitions(): + with pytest.raises(ValueError): + blockpartition.solve([42], partitions=0) + with pytest.raises(ValueError): + blockpartition.solve([42], partitions=-1) + + +def test_blockpartition_short_sequence(): + with pytest.raises(ValueError): + blockpartition.solve([], partitions=1) + with pytest.raises(ValueError): + blockpartition.solve([42], partitions=2) + + +@pytest.mark.parametrize("device", devices) +@pytest.mark.skip(reason="Flaky due to time.sleep()") +def test_balance_by_time(device): + class Delay(nn.Module): + def __init__(self, seconds): + super().__init__() + self.seconds = seconds + + def forward(self, x): + time.sleep(self.seconds) + return x + + model = nn.Sequential(*[Delay(i / 10) for i in [1, 2, 3, 4, 5, 6]]) + sample = torch.rand(1) + balance = balance_by_time(2, model, sample, device=device) + assert balance == [4, 2] + + +def test_balance_by_time_loop_resets_input(): + # nn.Flatten was introduced at PyTorch 1.2.0. + class Flatten(nn.Module): + def forward(self, x): + return x.flatten(1) + + model = nn.Sequential(nn.Conv2d(3, 2, 1), Flatten(), nn.Linear(128, 10)) + sample = torch.rand(10, 3, 8, 8) + balance = balance_by_time(2, model, sample, device="cpu") + assert balance == [1, 2] + + +@skip_if_no_cuda +def test_balance_by_size_latent(): + class Expand(nn.Module): + def __init__(self, times): + super().__init__() + self.times = times + + def forward(self, x): + for i in range(self.times): + x = x + torch.rand_like(x, requires_grad=True) + return x + + sample = torch.rand(10, 100, 100) + + model = nn.Sequential(*[Expand(i) for i in [1, 2, 3, 4, 5, 6]]) + balance = balance_by_size(2, model, sample) + assert balance == [4, 2] + + model = nn.Sequential(*[Expand(i) for i in [6, 5, 4, 3, 2, 1]]) + balance = balance_by_size(2, model, sample) + assert balance == [2, 4] + + +@skip_if_no_cuda +def test_balance_by_size_param(): + model = nn.Sequential(*[nn.Linear(i + 1, i + 2) for i in range(6)]) + sample = torch.rand(7, 1) + balance = balance_by_size(2, model, sample, param_scale=100) + assert balance == [4, 2] + + model = nn.Sequential(*[nn.Linear(i + 2, i + 1) for i in reversed(range(6))]) + sample = torch.rand(1, 7) + balance = balance_by_size(2, model, sample, param_scale=100) + assert balance == [2, 4] + + +@skip_if_no_cuda +def test_balance_by_size_param_scale(): + class Tradeoff(nn.Module): + def __init__(self, param_size, latent_size): + super().__init__() + self.fc = nn.Linear(param_size, param_size) + self.latent_size = latent_size + + def forward(self, x): + for i in range(self.latent_size): + x = x + torch.rand_like(x, requires_grad=True) + return x + + model = nn.Sequential( + Tradeoff(param_size=1, latent_size=6), + Tradeoff(param_size=2, latent_size=5), + Tradeoff(param_size=3, latent_size=4), + Tradeoff(param_size=4, latent_size=3), + Tradeoff(param_size=5, latent_size=2), + Tradeoff(param_size=6, latent_size=1), + ) + + sample = torch.rand(1, requires_grad=True) + + balance = balance_by_size(2, model, sample, param_scale=0) + assert balance == [2, 4] + + balance = balance_by_size(2, model, sample, param_scale=100) + assert balance == [4, 2] + + +@pytest.mark.parametrize("device", devices) +def test_layerwise_sandbox(device): + model = nn.Sequential(nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3)) + model.eval() + + for layer in layerwise_sandbox(model, torch.device(device)): + assert layer.training + assert all(p.device.type == device for p in layer.parameters()) + + assert all(not l.training for l in model) + assert all(p.device.type == "cpu" for p in model.parameters()) + + +@pytest.mark.parametrize("device", devices) +def test_sandbox_during_profiling(device): + model = nn.Sequential(nn.BatchNorm2d(3)) + + before = {k: v.clone() for k, v in model.state_dict().items()} + + sample = torch.rand(1, 3, 10, 10) + balance_by_time(1, model, sample, device=device) + + after = model.state_dict() + + assert before.keys() == after.keys() + for key, value in before.items(): + assert torch.allclose(after[key], value), key + + +def test_not_training(): + class AssertTraining(nn.Module): + def forward(self, x): + assert self.training + return x + + model = nn.Sequential(AssertTraining()) + + model.eval() + assert not model.training + + sample = torch.rand(1) + balance_by_time(1, model, sample, device="cpu") + + assert not model.training + + +def test_balance_by_time_tuple(): + class Twin(nn.Module): + def forward(self, x): + return x, x.detach() + + class Add(nn.Module): + def forward(self, a, b): + return a + b + + model = nn.Sequential(Twin(), Add()) + sample = torch.rand(1, requires_grad=True) + balance_by_time(1, model, sample, device="cpu") + + +@skip_if_no_cuda +def test_balance_by_size_tuple(): + class Twin(nn.Module): + def forward(self, x): + return x, x.detach() + + class Add(nn.Module): + def forward(self, a, b): + return a + b + + model = nn.Sequential(Twin(), Add()) + sample = torch.rand(1, requires_grad=True) + balance_by_size(1, model, sample) + + +def test_already_has_grad(): + model = nn.Sequential(nn.Conv2d(3, 3, 1)) + sample = torch.rand(1, 3, 32, 32) + model(sample).norm().backward() + + with pytest.raises(ValueError, match="some parameter already has gradient"): + balance_by_time(1, model, sample, device="cpu") + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/pipeline/sync/test_bugs.py b/test/distributed/pipeline/sync/test_bugs.py new file mode 100644 index 000000000000..928a78db6e32 --- /dev/null +++ b/test/distributed/pipeline/sync/test_bugs.py @@ -0,0 +1,146 @@ +# Owner(s): ["oncall: distributed"] + +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +import pytest + +import torch +import torch.nn.functional as F +from torch import nn + +from torch.distributed.pipeline.sync import Pipe +from torch.testing._internal.common_cuda import TEST_MULTIGPU +from torch.testing._internal.common_utils import run_tests + + +def test_python_autograd_function(setup_rpc): + # A Python autograd function might fail with this error: + # + # RuntimeError: Returning Variables sharing storage with other Variables + # that require grad is not supported in Python functions. Please submit a + # feature request if you hit this error. + # + # It doesn't look like an essential restriction. But it happens on the + # current PyTorch version. To avoid it, we should detach the tensor before + # returning by identity autograd functions, such as Wait, Fork, and Join. + # + class Identity(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + return input + + @staticmethod + def backward(ctx, grad): + return grad + + class M(nn.Module): + def forward(self, input): + return Identity.apply(input) + + model = nn.Sequential(M(), M()) + model = Pipe(model, checkpoint="always") + + x = torch.rand(42) + y = model(x) + assert torch.allclose(x, y.local_value()) + + +def test_exception_no_hang(setup_rpc): + # In v0.0.2, once a failed partition receives a normal message + # (non-closing) for the next micro-batch, a hang occurred. The reason was + # that a failed partition didn't call in_queue.task_done() on a normal + # message. So the former partition was blocked at out_queue.join() for the + # next of next micro-batch. + class ExpectedException(Exception): + pass + + class Pass(nn.Module): + def forward(self, x): + return x + + class Raise(nn.Module): + def forward(self, x): + raise ExpectedException + + model = nn.Sequential(Pass(), Pass(), Raise()) + model = Pipe(model, chunks=3) + + with pytest.raises(ExpectedException): + model(torch.rand(3)) + + +@pytest.mark.skipif(not TEST_MULTIGPU, reason="2 cuda devices required") +def test_tuple_wait(cuda_sleep, setup_rpc): + # In v0.0.3, Wait is applied to only the first tensor on a micro-batch. + # Under this behavior, if checkpointing was disabled, there's a possibility + # that gradient accumulations on other tensors are not synchronized + # properly to the copy stream. + class Sleep(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + return x.detach() + + @staticmethod + def backward(ctx, grad): + with torch.cuda.device(grad.device): + cuda_sleep(0.05) + return grad + + class Layer1(nn.Module): + def __init__(self): + super().__init__() + self.ones = nn.Parameter(torch.ones(32, 3, 32, 32, requires_grad=True)) + + def forward(self, a, b): + a = a * self.ones + return a * 1, b * 2, b * 3 + + class Layer2(nn.Module): + def __init__(self): + super().__init__() + self.ones = nn.Parameter(torch.ones(32, 3, 32, 32, requires_grad=True)) + + def forward(self, a, b, c): + a = a * self.ones + b = Sleep.apply(b) + return a + b + c + + model = nn.Sequential(Layer1().cuda(0), Layer2().cuda(1)) + model = Pipe(model, chunks=32, checkpoint="never") + + a = torch.rand(1024, 3, 32, 32, device=0, requires_grad=True) + b = torch.rand(1024, 3, 32, 32, device=0, requires_grad=True) + + y = model(a, b) + y.local_value().norm().backward() + + torch.cuda.synchronize(0) + torch.cuda.synchronize(1) + + assert torch.isclose(b.grad.norm().cpu(), torch.tensor(5.000)) + + +def test_parallel_randoms(setup_rpc): + class Dropouts(nn.Module): + def forward(self, x): + for _ in range(100): + x = F.dropout(x, p=0.001) + return x + + model = nn.Sequential(Dropouts(), Dropouts()) + + x = torch.rand(10, 10, requires_grad=True) + model = Pipe(model, chunks=10, checkpoint="always") + y = model(x) + y = y.local_value() + y.norm().backward() + + assert y.to(torch.bool).tolist() == x.grad.to(torch.bool).tolist() + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/pipeline/sync/test_checkpoint.py b/test/distributed/pipeline/sync/test_checkpoint.py new file mode 100644 index 000000000000..7be8ddefafe9 --- /dev/null +++ b/test/distributed/pipeline/sync/test_checkpoint.py @@ -0,0 +1,178 @@ +# Owner(s): ["oncall: distributed"] + +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +from functools import partial + +import pytest + +import torch +import torch.cuda +from torch import nn + +from torch.distributed.pipeline.sync.checkpoint import ( + checkpoint, + Checkpointing, + is_checkpointing, + is_recomputing, +) +from torch.distributed.pipeline.sync.dependency import fork, join +from torch.distributed.pipeline.sync.microbatch import Batch +from torch.testing._internal.common_utils import run_tests + +devices = ["cpu"] +if torch.cuda.is_available(): + devices.append("cuda") + + +@pytest.mark.parametrize("device", devices) +def test_serial_checkpoints(device): + # Copied from https://github.com/pytorch/pytorch/pull/18568. + timeline = [] + + class Log(torch.autograd.Function): + @staticmethod + def forward(ctx, name, x): + ctx.name = name + timeline.append(f"{name}:forward") + return x.detach() + + @staticmethod + def backward(ctx, grad_output): + name = ctx.name + timeline.append(f"{name}:backward") + return None, grad_output + + a = torch.rand(1, device=device, requires_grad=True) + b = torch.rand(1, device=device, requires_grad=True) + + # Increase the next function sequence number. + _ = a + 1 + 2 + 3 + 4 + 5 + + a = checkpoint(partial(Log.apply, "a"), a) + + a, phony = fork(a) + b = join(b, phony) + + b = checkpoint(partial(Log.apply, "b"), b) + + c = torch.cat((a, b)) + + out = c.sum() + + # +--> {a} --Checkpoint(Log)--> {a} + # {out} --Sum--> {c} --Cat ^-----------------------------+ + # +--> {b} --Checkpoint(Log)--> {b} --First--> {b} + out.backward() + + assert timeline == [ + "a:forward", + "b:forward", + "b:forward", + "b:backward", + "a:forward", + "a:backward", + ] + # |----------------------| |-----------------------| |-----------------------| + # forward pass Checkpoint(Log[b]) Checkpoint(Log[a]) + + +def test_not_requires_grad(): + x = Batch(torch.rand(1, requires_grad=False)) + assert not x[0].requires_grad + + def f(x): + return x * 2 + + chk = Checkpointing(f, x) + x = chk.checkpoint() + assert x[0].requires_grad + + chk.recompute(x) + assert x[0].requires_grad + + x.tensor.backward() + + +def test_not_requires_grad_with_parameter(): + x = torch.rand(1, requires_grad=False) + a = torch.rand(1, requires_grad=True) + + def f(x): + return x * a + + y = checkpoint(f, x) + y.backward() + + assert a.grad is not None + + +@pytest.mark.parametrize("device", devices) +def test_random_in_checkpoint(device): + dropout = nn.Dropout(p=0.5) + + torch.manual_seed(0) + x = torch.randn(3, 3, device=device, requires_grad=True) + y = dropout(x) + y.norm().backward() + + torch.manual_seed(0) + chk_x = torch.randn(3, 3, device=device, requires_grad=True) + chk_y = checkpoint(dropout, chk_x) + chk_y.norm().backward() + + assert torch.allclose(x.grad, chk_x.grad) + + +def test_detect_checkpointing_recomputing(): + logs = [] + + class Detect(nn.Module): + def forward(self, input): + logs.append((is_checkpointing(), is_recomputing())) + return input + + model = Detect() + input = torch.rand(1, requires_grad=True) + + output = checkpoint(model, input) + output.backward() + + assert logs == [(True, False), (False, True)] + + +def test_detect_checkpointing_recomputing_without_checkpoint(): + logs = [] + + class Detect(nn.Module): + def forward(self, input): + logs.append((is_checkpointing(), is_recomputing())) + return input + + model = Detect() + input = torch.rand(1, requires_grad=True) + + output = model(input) + output.backward() + + assert logs == [(False, False)] + + +def test_non_grad_output(): + class ForkNonGrad(nn.Module): + def forward(self, input): + return (input * 2, torch.rand(1)) + + model = ForkNonGrad() + input = torch.rand(1, requires_grad=True) + + output = checkpoint(model, input) + output[0].backward() + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/pipeline/sync/test_copy.py b/test/distributed/pipeline/sync/test_copy.py new file mode 100644 index 000000000000..302c3d25d53f --- /dev/null +++ b/test/distributed/pipeline/sync/test_copy.py @@ -0,0 +1,85 @@ +# Owner(s): ["oncall: distributed"] + +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +import pytest + +import torch + +from torch.distributed.pipeline.sync.copy import Copy, Wait +from torch.distributed.pipeline.sync.stream import ( + CPUStream, + current_stream, + get_device, + is_cuda, + new_stream, + use_stream, +) +from torch.testing._internal.common_utils import run_tests + +skip_if_no_cuda = pytest.mark.skipif( + not torch.cuda.is_available(), reason="cuda required" +) + + +def _test_copy_wait(prev_stream, next_stream, cuda_sleep=None): + device = get_device(prev_stream) + + with use_stream(prev_stream): + if is_cuda(prev_stream): + cuda_sleep(0.5) + x = torch.ones(100, device=device, requires_grad=True) + + (y,) = Copy.apply(prev_stream, next_stream, x) + (y,) = Wait.apply(prev_stream, next_stream, x) + + with use_stream(next_stream): + assert torch.allclose(y.sum(), torch.tensor(100.0, device=device)) + y.norm().backward() + with use_stream(prev_stream): + assert torch.allclose(x.grad.sum(), torch.tensor(10.0, device=device)) + + +def test_copy_wait_cpu_cpu(): + prev_stream = CPUStream + next_stream = CPUStream + _test_copy_wait(prev_stream, next_stream) + + +@skip_if_no_cuda +def test_copy_wait_cpu_cuda(cuda_sleep): + prev_stream = CPUStream + next_stream = current_stream(torch.device("cuda")) + _test_copy_wait(prev_stream, next_stream, cuda_sleep) + + +@skip_if_no_cuda +def test_copy_wait_cuda_cpu(cuda_sleep): + prev_stream = current_stream(torch.device("cuda")) + next_stream = CPUStream + _test_copy_wait(prev_stream, next_stream, cuda_sleep) + + +@skip_if_no_cuda +def test_copy_wait_cuda_cuda(cuda_sleep): + prev_stream = current_stream(torch.device("cuda")) + next_stream = new_stream(torch.device("cuda")) + _test_copy_wait(prev_stream, next_stream, cuda_sleep) + + +def test_wait_multiple_tensors(): + a = torch.rand(1, requires_grad=True) + b = torch.rand(1, requires_grad=True) + + a, b = Wait.apply(CPUStream, CPUStream, a, b) + + assert a.grad_fn is b.grad_fn + assert a.grad_fn.__class__ is Wait._backward_cls + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/pipeline/sync/test_deferred_batch_norm.py b/test/distributed/pipeline/sync/test_deferred_batch_norm.py new file mode 100644 index 000000000000..c3807c57d612 --- /dev/null +++ b/test/distributed/pipeline/sync/test_deferred_batch_norm.py @@ -0,0 +1,200 @@ +# Owner(s): ["oncall: distributed"] + +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +from copy import deepcopy +from itertools import chain + +import pytest + +import torch +from torch import nn, optim + +from torch.distributed.pipeline.sync.batchnorm import DeferredBatchNorm +from torch.testing._internal.common_utils import run_tests + +CHUNKS = 4 + + +def tilt_dist(input): + # Tilt variance by channel. + rgb = input.transpose(0, 1) + rgb[0] *= 1 + rgb[1] *= 10 + rgb[2] *= 100 + + # Tilt mean by single batch. + for i, single in enumerate(input): + single += 2**i + + return input + + +def chunked_forward(model, input, chunks=CHUNKS): + output_chunks = [] + + for chunk in input.chunk(chunks): + output_chunks.append(model(chunk)) + + return torch.cat(output_chunks) + + +@pytest.mark.parametrize("chunks", [1, 4]) +@pytest.mark.parametrize("input_requires_grad", [True, False]) +def test_transparency(chunks, input_requires_grad): + bn = nn.BatchNorm2d(3) + dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=chunks) + + input1 = torch.rand(16, 3, 224, 224) + input1 = tilt_dist(input1) + input2 = input1.clone() + input1.requires_grad = input_requires_grad + input2.requires_grad = input_requires_grad + + output1 = chunked_forward(bn, input1, chunks=chunks) + output2 = chunked_forward(dbn, input2, chunks=chunks) + + assert torch.allclose(output1, output2, atol=1e-4) + + output1.mean().backward() + output2.mean().backward() + + assert torch.allclose(bn.weight.grad, dbn.weight.grad, atol=1e-4) + + if input_requires_grad: + assert input1.grad is not None + assert input2.grad is not None + assert torch.allclose(input1.grad, input2.grad, atol=1e-4) + + +@pytest.mark.parametrize("momentum", [0.1, None]) +def test_running_stats(momentum): + bn = nn.BatchNorm2d(3, momentum=momentum) + dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS) + + input = torch.rand(16, 3, 224, 224) + input = tilt_dist(input) + + bn(input) + chunked_forward(dbn, input) + + assert torch.allclose(bn.running_mean, dbn.running_mean, atol=1e-4) + assert torch.allclose(bn.running_var, dbn.running_var, atol=1e-4) + + +def test_convert_deferred_batch_norm(): + bn = nn.BatchNorm2d(3, track_running_stats=False) + bn = DeferredBatchNorm.convert_deferred_batch_norm(bn, chunks=CHUNKS) + assert type(bn) is nn.BatchNorm2d # because of track_running_stats=False + + dbn = DeferredBatchNorm(3, chunks=CHUNKS) + dbn_again = DeferredBatchNorm.convert_deferred_batch_norm(dbn, chunks=CHUNKS) + assert dbn is dbn_again + + dbn_again = DeferredBatchNorm.convert_deferred_batch_norm(dbn, chunks=CHUNKS + 1) + assert dbn is not dbn_again # because of different chunks + + +def test_eval(): + bn = nn.BatchNorm2d(3) + dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS) + + input = torch.rand(16, 3, 224, 224) + input = tilt_dist(input) + + bn(input) + chunked_forward(dbn, input) + + bn.eval() + dbn.eval() + + assert torch.allclose(bn(input), dbn(input), atol=1e-4) + + +def test_optimize(): + bn = nn.BatchNorm2d(3) + dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS) + + opt = optim.SGD(chain(bn.parameters(), dbn.parameters()), lr=1.0) + + for i in range(5): + input = torch.rand(16, 3, 224, 224) + input = tilt_dist(input) + + # train + y = bn(input) + a = y.sum() + a.backward() + + y = chunked_forward(dbn, input) + b = y.sum() + b.backward() + + opt.step() + + # eval + bn.eval() + dbn.eval() + + with torch.no_grad(): + assert torch.allclose(bn(input), dbn(input), atol=1e-1 * (10**i)) + + +def test_conv_bn(): + bn = nn.Sequential(nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3)) + dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS) + + input = torch.rand(16, 3, 224, 224) + input = tilt_dist(input) + + opt = optim.SGD(chain(bn.parameters(), dbn.parameters()), lr=0.1) + + # 1st step + a = bn(input) + b = chunked_forward(dbn, input) + + # Outputs are different. (per-mini-batch vs. per-micro-batch) + assert not torch.allclose(a, b) + + a.sum().backward() + b.sum().backward() + opt.step() + opt.zero_grad() + + # Conv layers are also trained differently because of their different outputs. + assert not torch.allclose(bn[0].weight, dbn[0].weight) + + # But BNs track identical running stats. + assert torch.allclose(bn[1].running_mean, dbn[1].running_mean, atol=1e-4) + assert torch.allclose(bn[1].running_var, dbn[1].running_var, atol=1e3) + + # 2nd step + a = bn(input) + b = chunked_forward(dbn, input) + a.sum().backward() + b.sum().backward() + + # BNs can't track identical running stats due to the different conv layers. + assert not torch.allclose(bn[1].running_mean, dbn[1].running_mean, atol=1e-4) + assert not torch.allclose(bn[1].running_var, dbn[1].running_var, atol=1e3) + + +def test_input_requiring_grad(): + dbn = DeferredBatchNorm(3, chunks=CHUNKS) + + input = torch.rand(16, 3, 224, 224) + input = tilt_dist(input) + input.requires_grad = True + + chunked_forward(dbn, input) + + assert not dbn.sum.requires_grad + assert dbn.sum.grad_fn is None + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/pipeline/sync/test_dependency.py b/test/distributed/pipeline/sync/test_dependency.py new file mode 100644 index 000000000000..e966d6541bf5 --- /dev/null +++ b/test/distributed/pipeline/sync/test_dependency.py @@ -0,0 +1,152 @@ +# Owner(s): ["oncall: distributed"] + +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +import weakref + +import pytest + +import torch + +from torch.distributed.pipeline.sync.dependency import Fork, fork, Join, join +from torch.testing._internal.common_utils import run_tests + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") +def test_fork_join(): + logs = [] + + class Log(torch.autograd.Function): + @staticmethod + def forward(ctx, number, tensor): + ctx.number = number + return tensor.detach() + + @staticmethod + def backward(ctx, grad): + logs.append(ctx.number) + return None, grad + + a = torch.rand(1, device="cpu", requires_grad=True) + b = torch.rand(1, device="cuda", requires_grad=True) + + a = Log.apply(1, a) + + a, phony = fork(a) + b = join(a, phony) + + b = Log.apply(2, b) + b = b.to("cpu") + + (a + b).backward() + + assert logs == [2, 1] + + +def test_fork_join_enable_grad(): + x = torch.rand(1, requires_grad=True) + + with torch.enable_grad(): + x2, p = fork(x) + + assert p.requires_grad + assert x2 is not x + x = x2 + + assert x.requires_grad + assert p.requires_grad + assert x.grad_fn.__class__ is Fork._backward_cls + assert p.grad_fn.__class__ is Fork._backward_cls + + with torch.enable_grad(): + x2 = join(x, p) + + assert x2 is not x + x = x2 + + assert x.requires_grad + assert x.grad_fn.__class__ is Join._backward_cls + + +def test_fork_join_no_grad(monkeypatch): + def do_not_apply(*args): + raise AssertionError("Function.apply called") + + monkeypatch.setattr("torch.autograd.Function.apply", do_not_apply) + + x = torch.rand(1, requires_grad=True) + + with torch.no_grad(): + x2, p = fork(x) + + assert not p.requires_grad + assert x2 is x + x = x2 + + with torch.no_grad(): + x2 = join(x, p) + + assert x2 is x + x = x2 + + +def test_fork_leak(): + leak = None + + class F(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + return input + + @staticmethod + def backward(ctx, grad): + nonlocal leak + leak = weakref.ref(ctx) + return grad + + x = torch.rand(1, requires_grad=True) + x = F.apply(x) + x, phony = fork(x) + x = join(x, phony) + + x.backward() + del x, phony + + assert leak() is None + + +def test_join_when_fork_not_requires_grad(): + x = torch.rand(2, 1) + a, b = x.chunk(2) + + assert not a.requires_grad + a, p = fork(a) + assert not a.requires_grad + assert not p.requires_grad + + assert not b.requires_grad + b = join(b, p) + assert not b.requires_grad + + +def test_join_when_fork_requires_grad(): + x = torch.rand(2, 1) + a, b = x.chunk(2) + + a.requires_grad_() + assert a.requires_grad + a, p = fork(a) + assert a.requires_grad + assert p.requires_grad + + assert not b.requires_grad + b = join(b, p) + assert b.requires_grad + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/pipeline/sync/test_inplace.py b/test/distributed/pipeline/sync/test_inplace.py new file mode 100644 index 000000000000..33f31b2a52bb --- /dev/null +++ b/test/distributed/pipeline/sync/test_inplace.py @@ -0,0 +1,79 @@ +# Owner(s): ["oncall: distributed"] + +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +import pytest + +import torch +from torch import nn + +from torch.distributed.pipeline.sync import Pipe +from torch.testing._internal.common_utils import run_tests + + +def test_inplace_on_requires_grad(setup_rpc): + model = nn.Sequential(nn.Linear(1, 1), nn.ReLU(inplace=True)) + model = Pipe(model, checkpoint="always") + + x = torch.rand(1) + y = model(x).local_value() + + message = r"a leaf Variable that requires grad .* used in an in-place operation." + with pytest.raises(RuntimeError, match=message): + y.backward() + + +@pytest.mark.xfail(strict=True) +def test_inplace_on_not_requires_grad(setup_rpc): + # In-place operation on a tensor not requiring grad doesn't cause a + # RuntimeError. Currently, we cannot detect this case. + model = nn.Sequential(nn.ReLU(inplace=True)) + model = Pipe(model, [1], devices=["cpu"], checkpoint="always") + + x = torch.rand(1) + y = model(x).local_value() + del model + + message = r"a leaf Variable that requires grad .* used in an in-place operation." + with pytest.raises(RuntimeError, match=message): + y.backward() + + +@pytest.mark.xfail(strict=True) +def test_inplace_incorrect_grad(setup_rpc): + class M(nn.Module): + def forward(self, foo_bar): + # 'foo' requires grad but 'bar' does not. In-place operation on + # 'bar' won't cause a RuntimeError. + foo, bar = foo_bar + + # add_(1) is not idempotent, in contrast to relu_(). If it is + # executed multiple times, it will accumulates each difference onto + # 'bar'. + bar.add_(1) + + # 'bar' is still captured by checkpointing. 'foo' will get + # incorrect grad. + return foo * bar + + model = nn.Sequential(M()) + model = Pipe(model, [1], devices=["cpu"], checkpoint="always") + + foo = torch.tensor([1.0], requires_grad=True) + bar = torch.tensor([1.0]) + + output = model((foo, bar)).local_value() + del model + output.backward() + + # The gradient of 'foo' should be 2, but it is 3 actually because + # bar.add_(1) was executed twice due to checkpointing. + assert foo.grad.item() == 2.0 + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/pipeline/sync/test_microbatch.py b/test/distributed/pipeline/sync/test_microbatch.py new file mode 100644 index 000000000000..b5e44aa73a8d --- /dev/null +++ b/test/distributed/pipeline/sync/test_microbatch.py @@ -0,0 +1,148 @@ +# Owner(s): ["oncall: distributed"] + +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +import pytest + +import torch +import torch.cuda + +from torch.distributed.pipeline.sync.microbatch import Batch, check, gather, scatter +from torch.testing._internal.common_utils import run_tests + + +def test_batch_atomic(): + x = torch.tensor(42) + b = Batch(x) + + assert b.atomic + + assert b.tensor is x + with pytest.raises(AttributeError): + b.tensors + + assert list(b) == [x] + assert len(b) == 1 + assert b[0] is x + + +def test_batch_non_atomic(): + x, y = torch.tensor(42), torch.tensor(21) + b = Batch((x, y)) + + assert not b.atomic + + with pytest.raises(AttributeError): + b.tensor + + assert list(b) == [x, y] + assert len(b) == 2 + assert b[0] is x + assert b[1] is y + + +def test_batch_call(): + a = Batch(torch.tensor(42)) + b = Batch((torch.tensor(42), torch.tensor(21))) + + def f(x): + return x + + def g(x, y): + return x, y + + assert a.call(f).atomic + assert not b.call(g).atomic + + +def test_batch_setitem_by_index(): + a = Batch(torch.tensor(42)) + b = Batch((torch.tensor(42), torch.tensor(21))) + + a[0] = torch.tensor(0) + b[0] = torch.tensor(0) + + assert a.atomic + assert a[0].item() == 0 + + assert not b.atomic + assert len(b) == 2 + assert b[0].item() == 0 + assert b[1].item() == 21 + + +def test_batch_setitem_by_slice(): + a = Batch(torch.tensor(42)) + b = Batch((torch.tensor(42), torch.tensor(21))) + + a[:] = (torch.tensor(0),) + b[:] = (torch.tensor(0),) + + assert a.atomic + assert a[0].item() == 0 + + assert not b.atomic + assert len(b) == 1 + assert b[0].item() == 0 + + +def test_check(): + check(torch.device("cpu"), torch.tensor(42)) + check(torch.device("cpu"), torch.tensor(4), torch.tensor(2)) + + with pytest.raises(TypeError): + check(torch.device("cpu"), 42) + + with pytest.raises(TypeError): + check(torch.device("cpu"), "str") + + with pytest.raises(TypeError): + check(torch.device("cpu"), (torch.tensor(4), 2)) + + +def test_gather_tensors(): + a = torch.zeros(1, 1) + b = torch.zeros(1, 1) + + ab = gather([Batch(a), Batch(b)]) + + assert ab.size() == (2, 1) + + +def test_gather_tuples(): + a = (torch.zeros(1, 1), torch.zeros(2, 2)) + b = (torch.zeros(1, 1), torch.zeros(2, 2)) + + ab = gather([Batch(a), Batch(b)]) + + assert isinstance(ab, tuple) + assert ab[0].size() == (2, 1) + assert ab[1].size() == (4, 2) + + +def test_scatter_tensor(): + ab = torch.zeros(2, 1) + + a, b = scatter(ab, chunks=2) + + assert a.tensor.size() == (1, 1) + assert b.tensor.size() == (1, 1) + + +def test_scatter_multiple_tensors(): + ab = (torch.zeros(2, 1), torch.zeros(4, 2)) + + a, b = scatter(*ab, chunks=2) + + assert next(iter(a)).size() == (1, 1) + assert next(iter(b)).size() == (1, 1) + assert list(a)[1].size() == (2, 2) + assert list(b)[1].size() == (2, 2) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/pipeline/sync/test_phony.py b/test/distributed/pipeline/sync/test_phony.py new file mode 100644 index 000000000000..6aeb873b30b2 --- /dev/null +++ b/test/distributed/pipeline/sync/test_phony.py @@ -0,0 +1,57 @@ +# Owner(s): ["oncall: distributed"] + +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +import torch + +from torch.distributed.pipeline.sync.phony import get_phony +from torch.testing._internal.common_utils import run_tests + + +def test_phony_size(): + p = get_phony(torch.device("cpu"), requires_grad=False) + assert p.size() == (0,) + + +def test_phony_requires_grad(): + p1 = get_phony(torch.device("cpu"), requires_grad=True) + p2 = get_phony(torch.device("cpu"), requires_grad=False) + assert p1.requires_grad + assert not p2.requires_grad + + +def test_cached_phony(): + p1 = get_phony(torch.device("cpu"), requires_grad=True) + p2 = get_phony(torch.device("cpu"), requires_grad=True) + assert p1 is p2 + + p3 = get_phony(torch.device("cpu"), requires_grad=False) + p4 = get_phony(torch.device("cpu"), requires_grad=False) + assert p3 is p4 + + assert p1 is not p3 + + +def test_phony_in_autograd_function(): + class Phonify(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + phony = get_phony(input.device, requires_grad=False) + return phony.detach() + + x = torch.rand(1, requires_grad=True) + + p1 = Phonify.apply(x) + p2 = get_phony(torch.device("cpu"), requires_grad=True) + + assert p1 is not p2 + assert p1.grad_fn is not None + assert p2.grad_fn is None + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/pipeline/sync/test_pipe.py b/test/distributed/pipeline/sync/test_pipe.py new file mode 100644 index 000000000000..e493b1d5a03e --- /dev/null +++ b/test/distributed/pipeline/sync/test_pipe.py @@ -0,0 +1,858 @@ +# Owner(s): ["oncall: distributed"] + +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +import random +import time +from collections import OrderedDict +from copy import deepcopy + +import pytest + +import torch +from torch import nn, Tensor + +from torch.distributed.pipeline.sync import NoChunk, Pipe, WithDevice +from torch.distributed.pipeline.sync.pipe import PipeSequential +from torch.testing._internal.common_cuda import TEST_MULTIGPU +from torch.testing._internal.common_utils import run_tests, TEST_CUDA + +skip_if_no_cuda = pytest.mark.skipif(not TEST_CUDA, reason="cuda required") + + +def test_pipe_without_rpc(): + model = nn.Sequential(nn.Linear(1, 1)) + with pytest.raises(RuntimeError, match="Please initialize RPC framework"): + pipe = Pipe(model, chunks=1) + + +def test_parameters(setup_rpc): + model = nn.Sequential(nn.Linear(1, 1)) + pipe = Pipe(model, chunks=1) + assert list(pipe.parameters()) != [] + + +def test_public_attrs(setup_rpc): + class MyString: + def __init__(self, value): + self.value = value + + def __str__(self): + return self.value + + model = nn.Sequential(nn.Linear(1, 1)) + pipe = Pipe(model, chunks=42.000, checkpoint=MyString("always")) + + assert pipe.devices == [torch.device("cpu")] + assert pipe.chunks == 42 + assert isinstance(pipe.chunks, int) + assert pipe.checkpoint == "always" + assert isinstance(pipe.checkpoint, str) + + +def test_sequential_like(setup_rpc): + a = nn.Linear(1, 1) + b = nn.Linear(1, 1) + + model = nn.Sequential(a, b) + model = Pipe(model) + + assert len(model) == 2 + assert list(model) == [a, b] + + assert model[0] is a + assert model[1] is b + with pytest.raises(IndexError): + _ = model[2] + + assert model[-1] is b + assert model[-2] is a + + +def test_chunks_less_than_1(setup_rpc): + model = nn.Sequential(nn.Linear(1, 1)) + + with pytest.raises(ValueError): + Pipe(model, chunks=0) + + with pytest.raises(ValueError): + Pipe(model, chunks=-1) + + +def test_batch_size_indivisible(setup_rpc): + model = nn.Sequential(nn.Linear(1, 1)) + model = Pipe(model, chunks=4) + + with pytest.warns(None) as record: + model(torch.rand(7, 1)) + + # Indivisible batch size is legal. + assert not record + + +def test_batch_size_small(setup_rpc): + model = nn.Sequential(nn.Linear(1, 1)) + model = Pipe(model, chunks=4) + + with pytest.warns(None) as record: + model(torch.rand(2, 1)) + + # Batch size smaller than chunks is legal. + assert not record + + +def test_checkpoint_mode(setup_rpc): + def count_grad_fn(grad_fn, name, visited=None): + if visited is None: + visited = set() + if grad_fn in visited: + return 0 + visited.add(grad_fn) + + if grad_fn is None: + return 0 + if grad_fn.__class__.__name__ == name: + return 1 + + counter = 0 + for next_grad_fn, _ in grad_fn.next_functions: + counter += count_grad_fn(next_grad_fn, name, visited=visited) + return counter + + model = nn.Sequential(nn.Linear(1, 1)) + input = torch.rand(2, 1) + + always = Pipe(model, chunks=2, checkpoint="always") + except_last = Pipe(model, chunks=2, checkpoint="except_last") + never = Pipe(model, chunks=2, checkpoint="never") + + always_output = always(input) + except_last_output = except_last(input) + never_output = never(input) + + assert count_grad_fn(always_output.local_value().grad_fn, "CheckpointBackward") == 2 + assert ( + count_grad_fn(except_last_output.local_value().grad_fn, "CheckpointBackward") + == 1 + ) + assert count_grad_fn(never_output.local_value().grad_fn, "CheckpointBackward") == 0 + + +def test_checkpoint_mode_invalid(setup_rpc): + model = nn.Sequential(nn.Linear(1, 1)) + + with pytest.raises( + ValueError, match="checkpoint is not one of 'always', 'except_last', or 'never'" + ): + Pipe(model, chunks=2, checkpoint="INVALID_CHECKPOINT") + + +def test_checkpoint_mode_when_chunks_1(setup_rpc): + model = nn.Sequential(nn.Linear(1, 1)) + + # All checkpoint modes are fine. + Pipe(model, chunks=1, checkpoint="except_last") + Pipe(model, chunks=1, checkpoint="always") + Pipe(model, chunks=1, checkpoint="never") + + +def test_checkpoint_eval(setup_rpc): + model = nn.Sequential(nn.Linear(1, 1)) + model = Pipe(model, chunks=2) + input = torch.rand(2, 1) + + def find_grad_fn(grad_fn, name): + if grad_fn is None: + return False + if grad_fn.__class__.__name__ == name: + return True + for next_grad_fn, _ in grad_fn.next_functions: + if find_grad_fn(next_grad_fn, name): + return True + return False + + model.train() + train_output = model(input) + assert find_grad_fn(train_output.local_value().grad_fn, "CheckpointBackward") + assert find_grad_fn(train_output.local_value().grad_fn, "RecomputeBackward") + + model.eval() + eval_output = model(input) + assert not find_grad_fn(eval_output.local_value().grad_fn, "CheckpointBackward") + assert not find_grad_fn(eval_output.local_value().grad_fn, "RecomputeBackward") + + +def test_checkpoint_non_float_input(setup_rpc): + class ForkNonFloat(nn.Module): + def forward(self, input): + return (input * 2, torch.tensor([False])) + + class JoinNonFloat(nn.Module): + def forward(self, input, non_float): + return input * 2 + + model = nn.Sequential(ForkNonFloat(), JoinNonFloat()) + model = Pipe(model, chunks=1, checkpoint="always") + + input = torch.rand(1, requires_grad=True) + output = model(input) + output.backward() + + +def test_no_grad(setup_rpc): + model = nn.Sequential(nn.Linear(1, 1)) + model = Pipe(model, chunks=2) + input = torch.rand(2, 1) + + latent = None + + def hook(module, input, output): + _ = module + _ = input + + nonlocal latent + latent = output + + partition = model.partitions[0] + partition.register_forward_hook(hook) + + with torch.no_grad(): + model(input) + + assert latent.grad_fn is None + + +def test_exception(setup_rpc): + class ExpectedException(Exception): + pass + + class Raise(nn.Module): + def forward(self, *_): + raise ExpectedException + + model = nn.Sequential(Raise()) + model = Pipe(model, chunks=1) + + with pytest.raises(ExpectedException): + model(torch.rand(1)) + + +def test_exception_early_stop_asap(setup_rpc): + """Even the first partitions have finished to process, the partition before + the failed partition should be killed as soon as possible. + """ + + class ExpectedException(Exception): + pass + + class Pass(nn.Module): + def forward(self, x): + return x + + counter = 0 + + class Counter(nn.Module): + def forward(self, x): + time.sleep(0.1) + + nonlocal counter + counter += 1 + + return x + + class Raise(nn.Module): + def forward(self, x): + raise ExpectedException + + model = nn.Sequential(Pass(), Pass(), Counter(), Raise()) + model = Pipe(model, chunks=3) + + with pytest.raises(ExpectedException): + model(torch.rand(3)) + + # If the early stop doesn't work, it would be 3 instead. + assert counter == 2 + + +def test_nested_input(setup_rpc): + class NestedInput(nn.Module): + def __init__(self): + super().__init__() + self.fc_a = nn.Linear(1, 1) + self.fc_b = nn.Linear(1, 1) + + def forward(self, inp): + return inp + + model = nn.Sequential(NestedInput()) + model = Pipe(model, chunks=2) + + a = torch.rand(10, 1, requires_grad=True) + b = torch.rand(10, 1, requires_grad=True) + + # TypeError: expected Tensor, but got tuple + with pytest.raises(TypeError): + model((a, (a, b))).local_value() + + # TypeError: expected Tensor, but got list + with pytest.raises(TypeError): + model((a, [a, b])).local_value() + + +def test_input_pair(setup_rpc): + class Two(nn.Module): + def __init__(self): + super().__init__() + self.fc_a = nn.Linear(1, 1) + self.fc_b = nn.Linear(1, 1) + + def forward(self, a, b): + return (self.fc_a(a), self.fc_b(b)) + + model = nn.Sequential(Two()) + model = Pipe(model, chunks=2) + + a = torch.rand(10, 1, requires_grad=True) + b = torch.rand(10, 1, requires_grad=True) + + a_out, b_out = model(a, b).local_value() + loss = (a_out + b_out).mean() + loss.backward() + + assert a.grad is not None + assert b.grad is not None + + +def test_multi_sequence_input(setup_rpc): + class MultiSeq(nn.Module): + def forward(self, tup1, tup2): + return tup1, tup2 + + model = Pipe(nn.Sequential(MultiSeq())) + with pytest.raises(TypeError): + model([torch.rand(10), torch.rand(10)], [torch.rand(10), torch.rand(10)]) + + +def test_input_singleton(setup_rpc): + class One(nn.Module): + def __init__(self): + super().__init__() + self.fc = nn.Linear(1, 1) + + def forward(self, a): + return (self.fc(a),) + + model = nn.Sequential(One()) + model = Pipe(model, chunks=2) + + a = torch.rand(10, 1, requires_grad=True) + + (a_out,) = model(a).local_value() + loss = a_out.mean() + loss.backward() + + assert all(p.grad is not None for p in model.parameters()) + assert a.grad is not None + + +def test_input_varargs(setup_rpc): + model = nn.Sequential(nn.Linear(1, 1)) + model = Pipe(model) + + a = torch.rand(1) + b = torch.rand(1) + + # TypeError: forward() takes 2 positional arguments but 3 were given + with pytest.raises(TypeError): + model(a, b) + + +def test_non_tensor(setup_rpc): + class NonTensor(nn.Module): + def forward(self, _): + return "hello" + + model = nn.Sequential(NonTensor()) + model = Pipe(model) + x = torch.rand(1) + + with pytest.raises(TypeError): + model(x) + + with pytest.raises(TypeError): + model("hello") + + +def test_non_tensor_sequence(setup_rpc): + class NonTensorTuple(nn.Module): + def forward(self, x): + return (x, "hello") + + class NonTensorArgs(nn.Module): + def forward(self, x: str, y: bool): + return x, y + + model = nn.Sequential(NonTensorTuple()) + model = Pipe(model) + x = torch.rand(1) + + with pytest.raises(TypeError): + model((x, "hello")) + + with pytest.raises(TypeError): + model([x, "hello"]) + + model = nn.Sequential(NonTensorArgs()) + model = Pipe(model) + + with pytest.raises(TypeError): + # Need atleast one Tensor. + model("hello", True) + + +@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) +def test_valid_non_tensor(checkpoint, setup_rpc): + class NonTensor1(nn.Module): + def forward(self, a: int, b: Tensor, c: bool, d: Tensor): + res = b + a if c else b * a + if d is not None: + res += d + return res, c, a, b, "hello", d + + class NonTensor2(nn.Module): + def forward(self, a: Tensor, b: bool, c: int, d: Tensor, e: str, f: Tensor): + res = a * c if b else a + c + res += d + return c, res, a, d + f if f is not None else d, b, e, f + + model = Pipe( + nn.Sequential(NonTensor1(), NonTensor2()), chunks=5, checkpoint=checkpoint + ) + a = random.randint(0, 10) + b = torch.rand(10, 10) + c = random.randint(0, 1) == 0 + d = torch.rand(10, 10) + res = model(a, b, c, d).local_value() + assert 7 == len(res) + assert [a] * 5 == res[0] + if c: + assert torch.allclose(((b + a + d) * a) + b, res[1]) + assert torch.allclose(b + a + d, res[2]) + else: + assert torch.allclose(((b * a) + d + a) + b, res[1]) + assert torch.allclose(b * a + d, res[2]) + assert torch.allclose(b + d, res[3]) + assert [c] * 5 == res[4] + assert ["hello"] * 5 == res[5] + assert torch.allclose(d, res[6]) + + # Test one of the tensors can be None + res = model(a, b, c, None).local_value() + assert 7 == len(res) + assert [a] * 5 == res[0] + if c: + assert torch.allclose(((b + a) * a) + b, res[1]) + assert torch.allclose(b + a, res[2]) + else: + assert torch.allclose(((b * a) + a) + b, res[1]) + assert torch.allclose(b * a, res[2]) + assert torch.allclose(b, res[3]) + assert [c] * 5 == res[4] + assert ["hello"] * 5 == res[5] + assert [None] * 5 == res[6] + + # Need atleast one tensor. + with pytest.raises(TypeError): + model(a, None, c, None) + + +@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) +def test_no_tensor_output(checkpoint, setup_rpc): + class Model1(nn.Module): + def forward(self, a: int, b: Tensor, c: bool): + return a, c, "hello" + + class Model2(nn.Module): + def forward(self, a: int, b: bool, c: str): + return a, c, b + + model = Pipe(nn.Sequential(Model1(), Model2()), chunks=5) + a = random.randint(0, 10) + b = torch.rand(10, 10) + c = random.randint(0, 1) == 0 + + # Need atleast one tensor across partitions too. + with pytest.raises(TypeError): + res = model(a, b, c).local_value() + + +@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) +def test_uneven_batch_size(checkpoint, setup_rpc): + class Model(nn.Module): + def forward(self, a: Tensor, b: int, c: Tensor): + return a, b, c + + model = Pipe(nn.Sequential(Model()), checkpoint=checkpoint, chunks=5) + a = torch.rand(3, 10) + b = random.randint(0, 10) + c = torch.rand(6, 10) + res = model(a, b, c).local_value() + assert torch.allclose(a, res[0]) + assert [b] * 3 == res[1] # 3 chunks + assert torch.allclose(c, res[2]) + + # Two tensors producing uneven chunks would fail. + model = Pipe(nn.Sequential(Model()), checkpoint=checkpoint, chunks=5) + a = torch.rand(3, 10) + b = random.randint(0, 10) + c = torch.rand(4, 10) + + with pytest.raises(RuntimeError, match="Found different number of chunks"): + model(a, b, c) + + +@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) +def test_no_chunk(checkpoint, setup_rpc): + class Model(nn.Module): + def forward(self, a: Tensor, b: int, c: Tensor): + return a, b, c + + model = Pipe(nn.Sequential(Model()), checkpoint=checkpoint, chunks=5) + a = torch.rand(10, 10) + b = random.randint(0, 10) + c = torch.rand(10, 10) + res = model(a, b, NoChunk(c)).local_value() + assert torch.allclose(a, res[0]) + assert [b] * 5 == res[1] + # c gets replicated due to NoChunk and the same tensor gets concatenated 5 + # times in the output. + assert torch.allclose(torch.cat((c, c, c, c, c)), res[2]) + + # Test invalid type for NoChunk + with pytest.raises(TypeError, match="NoChunk only supported for tensors"): + NoChunk(b) + + +@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) +def test_deferred_batch_norm(checkpoint, setup_rpc): + bn = nn.BatchNorm2d(3) + pipe_bn = deepcopy(bn) + pipe = Pipe( + nn.Sequential(pipe_bn), + chunks=2, + checkpoint=checkpoint, + deferred_batch_norm=True, + ) + + x = torch.rand(4, 3, 10, 10) + pipe(x).local_value().mean().backward() + bn(x).mean().backward() + + assert torch.allclose(pipe[0].running_mean, bn.running_mean, atol=1e-4) + assert torch.allclose(pipe[0].running_var, bn.running_var, atol=1e-4) + + +@pytest.mark.parametrize("checkpoint", ["never", "always"]) +def test_deferred_batch_norm_params(checkpoint, setup_rpc): + bn = nn.BatchNorm2d(3) + pipe_bn = deepcopy(bn) + pipe = Pipe( + nn.Sequential(pipe_bn), + chunks=1, + checkpoint=checkpoint, + deferred_batch_norm=True, + ) + + x = torch.rand(4, 3, 10, 10) + pipe(x).local_value().mean().backward() + bn(x).mean().backward() + + assert pipe[0].weight.grad is not None + assert pipe[0].bias.grad is not None + + assert torch.allclose(pipe[0].weight.grad, bn.weight.grad, atol=1e-4) + assert torch.allclose(pipe[0].bias.grad, bn.bias.grad, atol=1e-4) + + +def test_devices(setup_rpc): + a = nn.Linear(1, 1) + b = nn.Linear(1, 1) + c = nn.Linear(1, 1) + + # There are extra two devices. + model = nn.Sequential(a, b, c) + model = Pipe(model) + + cpu = torch.device("cpu") + # Extra devices must be discarded. + assert model.devices == [cpu, cpu, cpu] + + +def test_partitions(setup_rpc): + a = nn.Linear(1, 1) + b = nn.Linear(1, 1) + + model = nn.Sequential(a, b) + model = Pipe(model) + + assert isinstance(model.partitions, nn.ModuleList) + assert isinstance(model.partitions[0], nn.Sequential) + assert isinstance(model.partitions[1], nn.Sequential) + + assert "partitions.0.0.weight" in model.state_dict() + + +@skip_if_no_cuda +def test_merged_partitions(setup_rpc): + a = nn.Linear(1, 1).to(0) + b = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 2)).to(0) + c = nn.Linear(1, 1) + d = nn.Linear(1, 2) + + model = nn.Sequential(a, b, c, d) + model = Pipe(model) + + assert isinstance(model.partitions, nn.ModuleList) + assert isinstance(model.partitions[0], PipeSequential) + assert isinstance(model.partitions[1], PipeSequential) + assert list(model.partitions[0]) == [a, b[0], b[1]] + assert list(model.partitions[1]) == [c] + assert list(model.partitions[2]) == [d] + + +def test_deny_moving(setup_rpc): + a = nn.Linear(1, 1) + b = nn.Linear(1, 1) + + model = nn.Sequential(a, b) + model = Pipe(model) + + # Moving is denied. + with pytest.raises(TypeError): + model.cuda() + + with pytest.raises(TypeError): + model.cpu() + + with pytest.raises(TypeError): + model.to(torch.device("cuda")) + + with pytest.raises(TypeError): + model.to(0) + + with pytest.raises(TypeError): + model.to("cuda") + + with pytest.raises(TypeError): + model.to(device=0) + + with pytest.raises(TypeError): + model.to(torch.rand(1)) + + with pytest.raises(TypeError): + model.to(tensor=torch.rand(1)) + + # Casting is allowed. + model.half() + model.to(torch.double) + model.to(dtype=torch.float) + + +def test_empty_module(setup_rpc): + # Empty sequential module is not illegal. + model = nn.Sequential() + model = Pipe(model) + + assert model(torch.tensor(42)).local_value() == torch.tensor(42) + + # But only tensor or tensors is legal in Pipe. + with pytest.raises(TypeError): + model(42) + + +def test_named_children(setup_rpc): + a = nn.Linear(1, 1) + b = nn.Linear(1, 1) + + model = nn.Sequential(OrderedDict([("a", a), ("b", b)])) + model = Pipe(model) + + names = {n for n, _ in model.named_modules()} + assert "partitions.0.0" in names + assert "partitions.1.0" in names + + # Pipe doesn't support __getattr__. Unlike nn.Sequential, Pipe requires + # several methods in its namespace. + with pytest.raises(AttributeError): + model.a + + +def test_verify_module_non_sequential(setup_rpc): + with pytest.raises( + TypeError, match="module must be nn.Sequential to be partitioned" + ): + Pipe(nn.Module()) + + +def test_verify_module_duplicate_children(setup_rpc): + conv = nn.Conv2d(3, 3, 1) + model = nn.Sequential(conv, conv) + + with pytest.raises( + ValueError, match="module with duplicate children is not supported" + ): + Pipe(model) + + +@skip_if_no_cuda +def test_verify_module_params_on_same_device(setup_rpc): + class Surrogate(nn.Module): + def __init__(self, param1, param2): + super().__init__() + self.param1 = param1 + self.param2 = param2 + + conv1 = nn.Conv2d(3, 3, 1) + conv2 = nn.Conv2d(3, 3, 1) + model = nn.Sequential(Surrogate(conv1, conv2.cuda())) + + with pytest.raises( + ValueError, + match=r"should have all parameters on a single device, please use .to\(\)" + " to place the module on a single device", + ): + Pipe(model) + + +@pytest.mark.skipif(not TEST_MULTIGPU, reason="Need atleast two GPUs") +def test_verify_nested_modules(setup_rpc): + model = nn.Sequential( + nn.Sequential(nn.Linear(32, 16).cuda(0), nn.Linear(16, 8).cuda(0)), + nn.Sequential(nn.Linear(8, 4).cuda(1), nn.Linear(4, 2).cuda(1)), + ) + + pipe = Pipe(model) + out = pipe(torch.rand(10, 32).cuda(0)) + assert out.local_value().device == torch.device("cuda:1") + assert out.local_value().size() == torch.Size([10, 2]) + + +def test_verify_module_duplicate_parameters_on_same_device(setup_rpc): + class Surrogate(nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + + conv = nn.Conv2d(3, 3, 1) + model = nn.Sequential(Surrogate(conv), Surrogate(conv)) + + Pipe(model) + + +def test_forward_lockstep(setup_rpc): + timeline = [] + + class DelayedLog(nn.Module): + def __init__(self, j, seconds): + super().__init__() + self.i = 0 + self.j = j + self.seconds = seconds + + def forward(self, x): + time.sleep(self.seconds) + + timeline.append((self.i, self.j)) + self.i += 1 + + return x + + model = nn.Sequential(DelayedLog(0, seconds=0), DelayedLog(1, seconds=0.1)) + model = Pipe(model, chunks=3) + model(torch.rand(3, 1)) + + # Expected timeline: (Logs are recorded at !) + # + # Partition #0: 0! 1! 2! + # Partition #1: 000! 111! 222! + # + assert timeline == [(0, 0), (1, 0), (0, 1), (2, 0), (1, 1), (2, 1)] + + +@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) +@skip_if_no_cuda +def test_multiple_inputs(checkpoint, setup_rpc): + class Module1(nn.Module): + def forward(self, a, b, c): + return a + b + c, a * b * c + + class Module2(nn.Module): + def forward(self, a, b): + return a + b + + model = Pipe( + nn.Sequential(Module1().cuda(0), Module2().cuda(0)), + chunks=2, + checkpoint=checkpoint, + ) + t = torch.rand(10) + res = model(t, t, t).local_value() + assert torch.equal(res, (t + t + t) + (t * t * t)) + + +@pytest.mark.skipif(not TEST_MULTIGPU, reason="Need atleast two GPUs") +def test_inputs_wrong_device(setup_rpc): + class Module1(nn.Module): + def __init__(self): + super().__init__() + self.param = torch.nn.Parameter(torch.rand(5)) + + def forward(self, a, b): + return a + b + self.param, b + + # Start inputs on wrong device and ensure Pipe moves them correctly. + a = torch.rand(10).cuda(1) + b = torch.rand(10).cuda(1) + model = Pipe(nn.Sequential(Module1().cuda(0), Module1().cuda(1)), chunks=2) + with pytest.raises( + ValueError, + match="All inputs should be on the same device as the first partition", + ): + model(a, b) + + +@pytest.mark.skipif(not TEST_MULTIGPU, reason="Need atleast two GPUs") +def test_with_device_wrapper(setup_rpc): + fc1 = nn.Linear(16, 8).cuda(0) + fc2 = nn.Linear(8, 4).cuda(1) + dropout = nn.Dropout() + + model = nn.Sequential(fc1, fc2, WithDevice(dropout, "cuda:1")) + model = Pipe(model, chunks=8) + assert ( + torch.device("cuda:1") == model(torch.rand(16, 16).cuda(0)).local_value().device + ) + assert [torch.device("cuda:0"), torch.device("cuda:1")] == model.devices + + model = nn.Sequential(fc1, WithDevice(dropout, "cuda:1")) + model = Pipe(model, chunks=8) + assert ( + torch.device("cuda:1") == model(torch.rand(16, 16).cuda(0)).local_value().device + ) + assert [torch.device("cuda:0"), torch.device("cuda:1")] == model.devices + + model = nn.Sequential(fc1, WithDevice(fc2, "cuda:0")) + model = Pipe(model, chunks=8) + assert ( + torch.device("cuda:0") == model(torch.rand(16, 16).cuda(0)).local_value().device + ) + assert [torch.device("cuda:0")] == model.devices + assert torch.device("cuda:0") == fc2.weight.device + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/pipeline/sync/test_pipeline.py b/test/distributed/pipeline/sync/test_pipeline.py new file mode 100644 index 000000000000..9548cb959db1 --- /dev/null +++ b/test/distributed/pipeline/sync/test_pipeline.py @@ -0,0 +1,36 @@ +# Owner(s): ["oncall: distributed"] + +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +from torch.distributed.pipeline.sync.pipeline import _clock_cycles +from torch.testing._internal.common_utils import run_tests + + +def test_clock_cycles(): + assert list(_clock_cycles(1, 1)) == [[(0, 0)]] + assert list(_clock_cycles(1, 3)) == [[(0, 0)], [(0, 1)], [(0, 2)]] + assert list(_clock_cycles(3, 1)) == [[(0, 0)], [(1, 0)], [(2, 0)]] + + assert list(_clock_cycles(3, 3)) == [ + [(0, 0)], + [(1, 0), (0, 1)], + [(2, 0), (1, 1), (0, 2)], + [(2, 1), (1, 2)], + [(2, 2)], + ] + + assert list(_clock_cycles(4, 2)) == [ + [(0, 0)], + [(1, 0), (0, 1)], + [(2, 0), (1, 1)], + [(3, 0), (2, 1)], + [(3, 1)], + ] + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/pipeline/sync/test_stream.py b/test/distributed/pipeline/sync/test_stream.py new file mode 100644 index 000000000000..f9702c8e4152 --- /dev/null +++ b/test/distributed/pipeline/sync/test_stream.py @@ -0,0 +1,198 @@ +# Owner(s): ["oncall: distributed"] + +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +import pytest + +import torch + +from torch.distributed.pipeline.sync.stream import ( + CPUStream, + current_stream, + default_stream, + get_device, + is_cuda, + new_stream, + record_stream, + use_device, + use_stream, + wait_stream, +) +from torch.testing._internal.common_utils import run_tests + +skip_if_no_cuda = pytest.mark.skipif( + not torch.cuda.is_available(), reason="cuda required" +) + + +class TestNewStream: + def test_new_stream_cpu(self): + stream = new_stream(torch.device("cpu")) + assert stream is CPUStream + + @skip_if_no_cuda + def test_new_stream_cuda(self): + stream = new_stream(torch.device("cuda")) + assert isinstance(stream, torch.cuda.Stream) + assert stream != torch.cuda.default_stream() + + +class TestCurrentStream: + def test_current_stream_cpu(self): + stream = current_stream(torch.device("cpu")) + assert stream is CPUStream + + @skip_if_no_cuda + def test_current_stream_cuda(self): + stream = current_stream(torch.device("cuda")) + assert isinstance(stream, torch.cuda.Stream) + assert stream == torch.cuda.current_stream() + + +class TestDefaultStream: + def test_default_stream_cpu(self): + stream = default_stream(torch.device("cpu")) + assert stream is CPUStream + + @skip_if_no_cuda + def test_default_stream_cuda(self): + stream = default_stream(torch.device("cuda")) + assert isinstance(stream, torch.cuda.Stream) + assert stream == torch.cuda.default_stream() + + +class TestUseDevice: + def test_use_device_cpu(self): + with use_device(torch.device("cpu")): + pass + + @skip_if_no_cuda + def test_use_device_cuda(self): + with use_device(torch.device("cuda")): + pass + + +class TestUseStream: + def test_use_stream_cpu(self): + with use_stream(CPUStream): + pass + + @skip_if_no_cuda + def test_use_stream_cuda(self): + stream = new_stream(torch.device("cuda")) + with use_stream(stream): + assert current_stream(torch.device("cuda")) == stream + + +class TestGetDevice: + def test_get_device_cpu(self): + assert get_device(CPUStream).type == "cpu" + + @skip_if_no_cuda + def test_get_device_cuda(self): + stream = current_stream(torch.device("cuda")) + assert get_device(stream).type == "cuda" + + +class TestWaitStream: + def _test_wait_stream(self, source, target, cuda_sleep=None): + with use_stream(target): + if is_cuda(target): + cuda_sleep(0.5) + x = torch.ones(100, 100, device=get_device(target)) + + wait_stream(source, target) + + with use_stream(source): + assert x.sum().item() == 10000 + + def test_wait_stream_cpu_cpu(self): + source = CPUStream + target = CPUStream + self._test_wait_stream(source, target) + + @skip_if_no_cuda + def test_wait_stream_cpu_cuda(self, cuda_sleep): + source = CPUStream + target = new_stream(torch.device("cuda")) + self._test_wait_stream(source, target, cuda_sleep) + + @skip_if_no_cuda + def test_wait_stream_cuda_cpu(self, cuda_sleep): + source = new_stream(torch.device("cuda")) + target = CPUStream + self._test_wait_stream(source, target, cuda_sleep) + + @skip_if_no_cuda + def test_wait_stream_cuda_cuda(self, cuda_sleep): + source = current_stream(torch.device("cuda")) + target = new_stream(torch.device("cuda")) + self._test_wait_stream(source, target, cuda_sleep) + + +class TestRecordStream: + def test_record_stream_cpu(self): + # It should silently ignore CPU tensors. + x = torch.rand(1, device=torch.device("cpu")) + record_stream(x, CPUStream) + + @skip_if_no_cuda + def test_record_stream_cuda(self, cuda_sleep): + # This test detects unexpected block reallocation. For reliable test, + # the stream to allocate tensors is isolated. The allocator will not + # reuse free blocks which were allocated from another stream. + stream_alloc = new_stream(torch.device("cuda")) + with torch.cuda.stream(stream_alloc): + x = torch.rand(1, device=torch.device("cuda")) + + stream = new_stream(torch.device("cuda")) + record_stream(x, stream) + with use_stream(stream): + cuda_sleep(0.5) + + # 'x' is deleted at Python's perspective. But the block of 'x' is still + # required for 'stream'. 'y' shouldn't be allocated to the block. + data_ptr = x.data_ptr() + del x + stream_alloc.synchronize() + with torch.cuda.stream(stream_alloc): + y = torch.rand(1, device=torch.device("cuda")) + assert y.data_ptr() != data_ptr + + # Pause Python until 'stream' finishes tasks queued. Now the block of + # 'x' is free to be reallocated. + wait_stream(CPUStream, stream) + with torch.cuda.stream(stream_alloc): + z = torch.rand(1, device=torch.device("cuda")) + assert z.data_ptr() == data_ptr + + @skip_if_no_cuda + def test_record_stream_shifted_view(self, cuda_sleep): + # Issue: https://github.com/pytorch/pytorch/issues/27366 + stream_alloc = new_stream(torch.device("cuda")) + with torch.cuda.stream(stream_alloc): + x = torch.rand(2, device=torch.device("cuda")) + + y = x[1:] + assert y.data_ptr() > x.data_ptr() + + stream = new_stream(torch.device("cuda")) + with use_stream(stream): + cuda_sleep(0.5) + record_stream(y, stream) + + data_ptr = x.data_ptr() + del x, y + + stream_alloc.synchronize() + with torch.cuda.stream(stream_alloc): + z = torch.rand(2, device=torch.device("cuda")) + assert z.data_ptr() != data_ptr + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/pipeline/sync/test_transparency.py b/test/distributed/pipeline/sync/test_transparency.py new file mode 100644 index 000000000000..a87a04150fdc --- /dev/null +++ b/test/distributed/pipeline/sync/test_transparency.py @@ -0,0 +1,55 @@ +# Owner(s): ["oncall: distributed"] + +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +import torch +from torch import nn + +from torch.distributed.pipeline.sync import Pipe +from torch.testing._internal.common_utils import run_tests + + +def test_simple_linears(setup_rpc): + def sum_grad(parameters): + return sum(p.grad.sum() for p in parameters if p.grad is not None) + + def zero_grad(parameters): + for p in parameters: + p.grad = None + + inputs = torch.rand(8, 1) + model = nn.Sequential( + nn.Linear(1, 2), + nn.Linear(2, 4), + nn.Linear(4, 2), + nn.Linear(2, 1), + ) + + # Without Pipe + outputs = model(inputs) + loss = outputs.mean() + loss.backward() + + grad_without_pipe = sum_grad(model.parameters()) + + zero_grad(model.parameters()) + + # With Pipe + model = Pipe(model, chunks=4) + + outputs = model(inputs).local_value() + loss = outputs.mean() + loss.backward() + + grad_with_pipe = sum_grad(model.parameters()) + + # Both grads should be identical. + assert torch.allclose(grad_with_pipe, grad_without_pipe) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/pipeline/sync/test_worker.py b/test/distributed/pipeline/sync/test_worker.py new file mode 100644 index 000000000000..f82af2ea0067 --- /dev/null +++ b/test/distributed/pipeline/sync/test_worker.py @@ -0,0 +1,118 @@ +# Owner(s): ["oncall: distributed"] + +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +import threading + +import pytest + +import torch + +from torch.distributed.pipeline.sync.microbatch import Batch +from torch.distributed.pipeline.sync.stream import CPUStream +from torch.distributed.pipeline.sync.worker import spawn_workers, Task +from torch.testing._internal.common_utils import run_tests + + +class fake_device: + """A test double for :class:`torch.device`. Every fake device is different + with each other. + """ + + type = "fake" + index = None + + +def test_compute_multithreading(): + """Task.compute should be executed on multiple threads.""" + thread_ids = set() + + def log_thread_id(): + thread_id = threading.current_thread().ident + thread_ids.add(thread_id) + return Batch(()) + + with spawn_workers([fake_device() for _ in range(2)]) as (in_queues, out_queues): + for i in range(2): + t = Task(CPUStream, compute=log_thread_id, finalize=None) + in_queues[i].put(t) + for i in range(2): + out_queues[i].get() + + assert len(thread_ids) == 2 + + +def test_compute_success(): + """Task.compute returns (True, (task, batch)) on success.""" + + def _42(): + return Batch(torch.tensor(42)) + + with spawn_workers([torch.device("cpu")]) as (in_queues, out_queues): + t = Task(CPUStream, compute=_42, finalize=None) + in_queues[0].put(t) + ok, (task, batch) = out_queues[0].get() + + assert ok + assert task is t + assert isinstance(batch, Batch) + assert batch[0].item() == 42 + + +def test_compute_exception(): + """Task.compute returns (False, exc_info) on failure.""" + + def zero_div(): + 0 / 0 + + with spawn_workers([torch.device("cpu")]) as (in_queues, out_queues): + t = Task(CPUStream, compute=zero_div, finalize=None) + in_queues[0].put(t) + ok, exc_info = out_queues[0].get() + + assert not ok + assert isinstance(exc_info, tuple) + assert issubclass(exc_info[0], ZeroDivisionError) + + +@pytest.mark.parametrize("grad_mode", [True, False]) +def test_grad_mode(grad_mode): + def detect_grad_enabled(): + x = torch.rand(1, requires_grad=torch.is_grad_enabled()) + return Batch(x) + + with torch.set_grad_enabled(grad_mode): + with spawn_workers([torch.device("cpu")]) as (in_queues, out_queues): + task = Task(CPUStream, compute=detect_grad_enabled, finalize=None) + in_queues[0].put(task) + + ok, (_, batch) = out_queues[0].get() + + assert ok + assert batch[0].requires_grad == grad_mode + + +def test_worker_per_device(): + cpu = torch.device("cpu") + cpu0 = torch.device("cpu", index=0) + fake1 = fake_device() + fake2 = fake_device() + + with spawn_workers([cpu, cpu, cpu0, fake1, fake2]) as (in_queues, out_queues): + assert len(in_queues) == len(out_queues) == 5 + + # 0: cpu, 1: cpu, 2: cpu0 + assert in_queues[0] is in_queues[1] is in_queues[2] + assert out_queues[0] is out_queues[1] is out_queues[2] + + # 3: fake1, 4: fake2 + assert in_queues[3] is not in_queues[4] + assert out_queues[3] is not out_queues[4] + + +if __name__ == "__main__": + run_tests() diff --git a/test/test_public_bindings.py b/test/test_public_bindings.py index 8ab2ac1f511f..1db0e5718ce6 100644 --- a/test/test_public_bindings.py +++ b/test/test_public_bindings.py @@ -329,6 +329,7 @@ def test_modules_can_be_imported(self): "torch.testing._internal.distributed.fake_pg", "torch.testing._internal.distributed.multi_threaded_pg", "torch.testing._internal.distributed.nn.api.remote_module_test", + "torch.testing._internal.distributed.pipe_with_ddp_test", "torch.testing._internal.distributed.rpc.dist_autograd_test", "torch.testing._internal.distributed.rpc.dist_optimizer_test", "torch.testing._internal.distributed.rpc.examples.parameter_server_test", @@ -407,6 +408,7 @@ def test_modules_can_be_imported(self): "torch.distributed.nn.api.remote_module", "torch.distributed.optim", "torch.distributed.optim.optimizer", + "torch.distributed.pipeline.sync", "torch.distributed.rendezvous", "torch.distributed.rpc.api", "torch.distributed.rpc.backend_registry", diff --git a/test/test_testing.py b/test/test_testing.py index 1e1dce59a32e..ba9558a3ddd1 100644 --- a/test/test_testing.py +++ b/test/test_testing.py @@ -2245,6 +2245,7 @@ def test_circular_dependencies(self) -> None: else: ignored_modules.append("torch.distributed.nn.api.") ignored_modules.append("torch.distributed.optim.") + ignored_modules.append("torch.distributed.pipeline.") ignored_modules.append("torch.distributed.rpc.") ignored_modules.append("torch.testing._internal.dist_utils") # And these both end up with transitive dependencies on distributed diff --git a/torch/distributed/pipeline/__init__.py b/torch/distributed/pipeline/__init__.py new file mode 100644 index 000000000000..eacd2bc99d04 --- /dev/null +++ b/torch/distributed/pipeline/__init__.py @@ -0,0 +1,13 @@ +import warnings + + +with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + "`torch.distributed.pipeline` is deprecated. For up-to-date pipeline parallel " + "implementation, please refer to the PiPPy library under the PyTorch " + "organization (Pipeline Parallelism for PyTorch): " + "https://github.com/pytorch/PiPPy", + DeprecationWarning, + stacklevel=2, + ) diff --git a/torch/distributed/pipeline/sync/LICENSE b/torch/distributed/pipeline/sync/LICENSE new file mode 100644 index 000000000000..e52be240fdc9 --- /dev/null +++ b/torch/distributed/pipeline/sync/LICENSE @@ -0,0 +1,27 @@ +Copyright 2019-2020 Kakao Brain + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from this + software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. diff --git a/torch/distributed/pipeline/sync/__init__.py b/torch/distributed/pipeline/sync/__init__.py new file mode 100644 index 000000000000..75a80c5db0f9 --- /dev/null +++ b/torch/distributed/pipeline/sync/__init__.py @@ -0,0 +1,12 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""A Pipe implementation in PyTorch.""" +from .checkpoint import is_checkpointing, is_recomputing +from .pipe import Pipe, WithDevice +from .microbatch import NoChunk + +__all__ = ["Pipe", "is_checkpointing", "is_recomputing"] diff --git a/torch/distributed/pipeline/sync/_balance/__init__.py b/torch/distributed/pipeline/sync/_balance/__init__.py new file mode 100644 index 000000000000..8ffc657896d8 --- /dev/null +++ b/torch/distributed/pipeline/sync/_balance/__init__.py @@ -0,0 +1,164 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""A helper to roughly balance a sequential module. + +Usage:: + + import torch + from torch.distributed.pipeline.sync import Pipe + from torch.distributed.pipeline.sync.balance import balance_by_time + + sample = torch.empty(128, 3, 224, 224) + balance = balance_by_time(torch.cuda.device_count(), model, sample) + + pipe = Pipe(model, balance, chunks=8) + +""" +from typing import Any, List, Union, Sequence + +import torch +from torch import Tensor +import torch.nn as nn + +from . import blockpartition +from .profile import profile_sizes, profile_times + +__all__ = ["balance_by_time", "balance_by_size"] + + +Device = Union[torch.device, int, str] + +Tensors = Sequence[Tensor] +TensorOrTensors = Union[Tensor, Tensors] + + +def balance_cost(cost: List[int], partitions: int) -> List[int]: + partitioned = blockpartition.solve(cost, partitions) + return [len(p) for p in partitioned] + + +def balance_by_time( + partitions: int, + module: nn.Sequential, + sample: Union[List[Any], Tensor], + *, + timeout: float = 1.0, + device: Device = torch.device("cuda"), +) -> List[int]: + """Naive automatic balancing by elapsed time per layer. + :: + + sample = torch.empty(128, 3, 224, 224) + balance = balance_by_time(torch.cuda.device_count(), model, sample) + pipe = Pipe(model, balance, chunks=8) + + Args: + partitions (int): + intended number of partitions + module (torch.nn.Sequential): + sequential module to be partitioned + sample (torch.Tensor): + example input with arbitrary batch size + + Keyword Args: + timeout (float): + profiling iterates again if the timeout (in second) is not exceeded + (default: ``1.0``) + device ('cpu' or 'cuda' device): + CPU or CUDA device where each layer is profiled (default: the + current CUDA device) + + Returns: + A list of number of layers in each partition. Use it for the `balance` + parameter of :class:`~torchpipe.Pipe`. + + .. note:: + `module` and `sample` must be placed on the same device. + + """ + times = profile_times(module, sample, timeout, torch.device(device)) + return balance_cost(times, partitions) + + +def balance_by_size( + partitions: int, + module: nn.Sequential, + input: Union[List[Any], Tensor], + *, + chunks: int = 1, + param_scale: float = 2.0, + device: Device = torch.device("cuda"), +) -> List[int]: + """Naive automatic balancing by CUDA memory usage per layer. + + During training, required memory for parameters depends on which optimizer + is used. Optimizers may use buffers for each parameter to track + optimization statistics internally, such as momentum buffer in SGD. + + To get more reliable size based balance, you should specify `param_scale` + with regard to your optimizer. The default `param_scale` is 2 instead of 1 + due to gradient accumulation which is necessary for every optimizer. + + Follow this guide to choose correct `param_scale` for typical optimizers: + + ========= ============= ========================================= + Optimizer `param_scale` Internal State + ========= ============= ========================================= + SGD 2--3 (momentum_buffer) + Adam 4--5 exp_avg, exp_avg_sq, (max_exp_avg_sq) + Adadelta 4 square_avg, acc_delta + Adagrad 3 sum + RMSprop 3--5 square_avg, (momentum_buffer), (grad_avg) + ========= ============= ========================================= + + Here's a simple example with the Adam optimizer:: + + balance = balance_by_size( + torch.cuda.device_count(), + model, + + # Same size with mini-batch to train + torch.empty(1024, 3, 224, 224), + + # Number of micro-batches to train with Pipe + chunks=8, + + # 4 for Adam + param_scale=4.0, + ) + + pipe = Pipe(model, balance, chunks=8) + adam = Adam(pipe.parameters()) + + Args: + partitions (int): + intended number of partitions + module (torch.nn.Sequential): + sequential module to be partitioned + input (torch.Tensor): + example mini-batch with the same size to train + + Keyword Args: + chunks (int): + number of micro-batches will be used to train (default: ``1``) + param_scale (float): + how many copies of parameters would be allocated for training. It + depends on optimizer. See the above guide. (default: ``2.0``) + device ('cuda' device): + CUDA device where each layer is profiled (default: the current CUDA + device) + + Returns: + A list of number of layers in each partition. Use it for the `balance` + parameter of :class:`~torchpipe.Pipe`. + + .. note:: + `module` and `input` must be placed on the same CUDA device. + + """ + sizes = profile_sizes(module, input, chunks, param_scale, torch.device(device)) + return balance_cost(sizes, partitions) diff --git a/torch/distributed/pipeline/sync/_balance/blockpartition.py b/torch/distributed/pipeline/sync/_balance/blockpartition.py new file mode 100644 index 000000000000..ccdf5fe4df99 --- /dev/null +++ b/torch/distributed/pipeline/sync/_balance/blockpartition.py @@ -0,0 +1,95 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""Implements "Block Partitions of Sequences" by Imre B\u00e1r\u00e1ny et al. + +Paper: https://arxiv.org/pdf/1308.2452.pdf + +""" +from typing import Iterator, List, Tuple + +__all__ = ["solve"] + + +def solve(sequence: List[int], partitions: int = 1) -> List[List[int]]: + """Splits a sequence into several partitions to minimize variance for each + partition. + + The result might not be optimal. However, it can be done only in O(kn\u00b3), + where k is the number of partitions and n is the length of the sequence. + + """ + if partitions < 1: + raise ValueError(f"partitions must be a positive integer ({partitions} < 1)") + + n = len(sequence) + if n < partitions: + raise ValueError(f"sequence is shorter than intended partitions ({n} < {partitions})") + + # Normalize the sequence in [0, 1]. + minimum = min(sequence) + maximum = max(sequence) - minimum + + normal_sequence: List[float] + if maximum == 0: + normal_sequence = [0 for _ in sequence] + else: + normal_sequence = [(x - minimum) / maximum for x in sequence] + + splits = [n // partitions * (x + 1) for x in range(partitions - 1)] + [n] + + def block_size(i: int) -> float: + start = splits[i - 1] if i > 0 else 0 + stop = splits[i] + return sum(normal_sequence[start:stop]) + + def leaderboard() -> Iterator[Tuple[float, int]]: + return ((block_size(i), i) for i in range(partitions)) + + while True: + """ + (1) Fix p element-of [k] with M(P) = bp. So Bp is a maximal block of P. + """ + # max_size: M(P) + max_size, p = max(leaderboard()) + + while True: + """ + (2) If M(P) <= m(P) + 1, then stop. + """ + # min_size: m(P) + min_size, q = min(leaderboard()) + + if max_size <= min_size + 1: + return [sequence[i:j] for i, j in zip([0] + splits[:-1], splits)] + + """ + (3) If M(P) > m(P) + 1, then let m(P) = bq for the q element-of [k] which is + closest to p (ties broken arbitrarily). Thus Bq is a minimal block + of P. Let Bh be the block next to Bq between Bp and Bq. (Note that + Bh is a non-empty block: if it were, then m(P) = 0 and we should + have chosen Bh instead of Bq.) + """ + if p < q: + """ + So either p < q and then h = q-1 and we define P * by moving + the last element from Bh = Bq-1 to Bq, + """ + h = q - 1 + splits[h] -= 1 + else: + """ + or q < p, and then h = q + 1 and P * is obtained by moving the + first element of Bh = Bq+1 to Bq. + """ + h = q + 1 + splits[q] += 1 + + """ + Set P = P * . If p = h, then go to (1), else go to (2). + """ + if p == h: + break diff --git a/torch/distributed/pipeline/sync/_balance/profile.py b/torch/distributed/pipeline/sync/_balance/profile.py new file mode 100644 index 000000000000..fa1a0c06a8e3 --- /dev/null +++ b/torch/distributed/pipeline/sync/_balance/profile.py @@ -0,0 +1,116 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""Per-layer profilers.""" +import copy +import time +from typing import Any, Generator, List, Union, Sequence + +import torch +from torch import Tensor +import torch.nn as nn + +from ..microbatch import Batch + +__all__: List[str] = [] + + +Device = Union[torch.device, int, str] + +Tensors = Sequence[Tensor] +TensorOrTensors = Union[Tensor, Tensors] + + +def layerwise_sandbox(module: nn.Sequential, device: torch.device,) -> Generator[nn.Module, None, None]: + """Copies layers for ease to profile. It doesn't modify the given + module. + """ + for layer in module: + layer_copy = copy.deepcopy(layer) + layer_copy.to(device) + layer_copy.train() + yield layer_copy + + +def detach(batch: Batch) -> None: + """Detaches from autograd graph.""" + for i, x in enumerate(batch): + batch[i] = x.detach().requires_grad_(x.requires_grad) + + +def profile_times(module: nn.Sequential, sample: Union[List[Any], Tensor], timeout: float, device: torch.device,) -> List[int]: + """Profiles elapsed times per layer.""" + if any(p.grad is not None for p in module.parameters()): + raise ValueError("some parameter already has gradient") + + _batch = Batch(sample) + for i, x in enumerate(_batch): + _batch[i] = x.detach().to(device).requires_grad_(x.requires_grad) + + time_bufs: List[List[float]] = [[] for _ in module] + begun_at = time.time() + + while time.time() - begun_at < timeout: + batch = _batch + + for i, layer in enumerate(layerwise_sandbox(module, device)): + detach(batch) + + if device.type == "cuda": + torch.cuda.synchronize(device) + tick = time.time() + + # Forward + batch = batch.call(layer) + + # Backward + backward_tensors = tuple(y for y in batch if y.requires_grad) + if backward_tensors: + torch.autograd.backward(backward_tensors, backward_tensors) + + if device.type == "cuda": + torch.cuda.synchronize(device) + tock = time.time() + + time_bufs[i].append(tock - tick) + + us = 1_000_000 + return [sum(int(t * us) for t in buf) for buf in time_bufs] + + +def profile_sizes( + module: nn.Sequential, input: Union[List[Any], Tensor], chunks: int, param_scale: float, device: torch.device, +) -> List[int]: + """Profiles CUDA memory usage per layer.""" + if device.type != "cuda": + raise ValueError("size profiler supports only CUDA device") + + batch = Batch(input) + sizes: List[int] = [] + + latent_scale = batch[0].size(0) / chunks + for i, x in enumerate(batch): + batch[i] = x[:1].detach().to(device).requires_grad_(x.requires_grad) + + for layer in layerwise_sandbox(module, device): + detach(batch) + + # Detect memory usage at forward. + torch._C._cuda_clearCublasWorkspaces() + memory_before = torch.cuda.memory_allocated(device) + batch = batch.call(layer) + torch._C._cuda_clearCublasWorkspaces() + memory_after = torch.cuda.memory_allocated(device) + latent_size = memory_after - memory_before + + # Analyze size of parameters. + param_size = sum(p._typed_storage()._nbytes() for p in layer.parameters()) + + # Combine size of parameters and activations with normalize scales. + size = latent_size * latent_scale + param_size * param_scale + sizes.append(int(size)) + + return sizes diff --git a/torch/distributed/pipeline/sync/_balance/py.typed b/torch/distributed/pipeline/sync/_balance/py.typed new file mode 100644 index 000000000000..ab03724cafbf --- /dev/null +++ b/torch/distributed/pipeline/sync/_balance/py.typed @@ -0,0 +1,6 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. diff --git a/torch/distributed/pipeline/sync/batchnorm.py b/torch/distributed/pipeline/sync/batchnorm.py new file mode 100644 index 000000000000..868ad50cf3fc --- /dev/null +++ b/torch/distributed/pipeline/sync/batchnorm.py @@ -0,0 +1,159 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""Tracks the running statistics per mini-batch instead of micro-batch.""" +from typing import TypeVar, Optional, cast + +import torch +from torch import Tensor, nn +from torch.nn.functional import batch_norm +from torch.nn.modules.batchnorm import _BatchNorm + +from .checkpoint import is_recomputing + +__all__ = ["DeferredBatchNorm"] + + +TModule = TypeVar("TModule", bound=nn.Module) + + +class DeferredBatchNorm(_BatchNorm): + """A BatchNorm layer tracks multiple micro-batches to update running statistics per mini-batch.""" + + sum: Tensor + sum_squares: Tensor + running_mean: Tensor + running_var: Tensor + num_batches_tracked: Tensor + + def __init__( + self, + num_features: int, + eps: float = 1e-5, + momentum: Optional[float] = 0.1, + affine: bool = True, + chunks: int = 1, + ) -> None: + super().__init__(num_features, eps, momentum, affine, track_running_stats=True) + + self.register_buffer("sum", torch.zeros_like(self.running_mean)) + self.register_buffer("sum_squares", torch.zeros_like(self.running_var)) + + self.counter = 0 + self.tracked = 0 + self.chunks = chunks + + def _check_input_dim(self, input: Tensor) -> None: + # It's the typical _check_input_dim() implementation in PyTorch. + if input.dim() <= 2: + raise ValueError("expected at least 3D input (got %dD input)" % input.dim()) + + def _track(self, input: Tensor) -> bool: + """Tracks statistics of a micro-batch.""" + # Dimensions except channel. For example, (0, 2, 3) is for BatchNorm2d. + dim = [0] + dim.extend(range(2, input.dim())) + + with torch.no_grad(): + self.sum += input.sum(dim) + self.sum_squares += (input ** 2).sum(dim) + + size = input.size().numel() // input.size(1) + self.counter += size + self.tracked += 1 + + return self.tracked == self.chunks + + def _commit(self) -> None: + """Update the running statistics of a mini-batch.""" + exponential_average_factor = 0.0 + self.num_batches_tracked += 1 + if self.momentum is None: # use cumulative moving average + exponential_average_factor = 1.0 / float(self.num_batches_tracked) + else: # use exponential moving average + exponential_average_factor = self.momentum + + mean = self.sum / self.counter + var = self.sum_squares / self.counter - mean ** 2 + + # Calculate the exponential moving average here. + m = exponential_average_factor + + self.running_mean *= 1 - m + self.running_mean += mean * m + + self.running_var *= 1 - m + self.running_var += var * m + + self.sum.zero_() + self.sum_squares.zero_() + self.counter = 0 + self.tracked = 0 + + def forward(self, input: Tensor) -> Tensor: + if not self.training: + # Don't train parameters on the evaluation mode. + return batch_norm( + input, + running_mean=self.running_mean, + running_var=self.running_var, + weight=self.weight, + bias=self.bias, + training=False, + momentum=0.0, + eps=self.eps, + ) + + if not is_recomputing(): + # Track a micro-batch on the training mode + # but not under a recomputation. + tracked_enough = self._track(input) + + # Update the running statistics for a mini-batch + # if it has tracked enough micro-batches. + if tracked_enough: + self._commit() + + # Normalize a micro-batch and train the parameters. + return batch_norm( + input, + running_mean=None, + running_var=None, + weight=self.weight, + bias=self.bias, + training=True, + momentum=0.0, + eps=self.eps, + ) + + @classmethod + def convert_deferred_batch_norm(cls, module: TModule, chunks: int = 1) -> TModule: + """Converts a :class:`nn.BatchNorm` or underlying :class:`nn.BatchNorm`s into :class:`DeferredBatchNorm`:: + + from torchvision.models.resnet import resnet101 + from torchpipe.batchnorm import DeferredBatchNorm + model = resnet101() + model = DeferredBatchNorm.convert_deferred_batch_norm(model) + + """ + if isinstance(module, DeferredBatchNorm) and module.chunks is chunks: + return cast(TModule, module) + + module_output: nn.Module = module + + if isinstance(module, _BatchNorm) and module.track_running_stats: + module_output = DeferredBatchNorm(module.num_features, module.eps, module.momentum, module.affine, chunks) + if module.affine: + module_output.register_parameter("weight", module.weight) + module_output.register_parameter("bias", module.bias) + module_output.register_buffer("running_mean", module.running_mean) + module_output.register_buffer("running_var", module.running_var) + module_output.register_buffer("num_batches_tracked", module.num_batches_tracked) + + for name, child in module.named_children(): + module_output.add_module(name, cls.convert_deferred_batch_norm(child, chunks)) + + return cast(TModule, module_output) diff --git a/torch/distributed/pipeline/sync/checkpoint.py b/torch/distributed/pipeline/sync/checkpoint.py new file mode 100644 index 000000000000..e67da2499d57 --- /dev/null +++ b/torch/distributed/pipeline/sync/checkpoint.py @@ -0,0 +1,364 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""Checkpointing with preceding recomputation. + +PyTorch already provides the official checkpointing utilities in +:mod:`torch.utils.checkpoint`. The official checkpointing combines +recomputation and recursive backpropagation into one autograd function named +``CheckpointFunction``. Hence, the recomputation can be started only when the +gradients arrive to the function. In Pipe, the recomputation needs to precede +the gradient arrival to minimize the GPU idle time. + +We solve this problem by introducing separate autograd functions named +:class:`Recompute` and :class:`Checkpoint`. Each function represents +recomputation and recursive backpropagation, respectively. We can manipulate +the control flow in aspect of both the autograd engine and CUDA with a pair of +the functions. + +Specifically, we place CUDA stream synchronization between :class:`Recompute` +and :class:`Checkpoint` to delay only :class:`Checkpoint` until the gradient is +copied entirely. + +""" +from collections import deque +from contextlib import contextmanager +import threading +from typing import ( + Any, + Deque, + Generator, + List, + Optional, + Protocol, + Union, + Sequence, + Tuple +) + +import torch +from torch import Tensor +import torch.autograd + +from .dependency import fork, join +from .microbatch import Batch +from .phony import get_phony + +__all__ = ["Function", "checkpoint", "Checkpointing", "ThreadLocal", "enable_checkpointing", + "enable_recomputing", "is_checkpointing", "is_recomputing", "Context", "save_rng_states", + "restore_rng_states", "Checkpoint", "Recompute"] + + +Tensors = Sequence[Tensor] +TensorOrTensors = Union[Tensor, Tensors] + +# Types for shared memory between Checkpoint and Recompute. +Recomputed = Tuple[TensorOrTensors, Tensors] # (output, input_leaf) +RNGStates = Tuple[Tensor, Optional[Tensor]] # (cpu_rng_state, gpu_rng_state) + + +# Protocol with __call__ instead of Callable can be used as an attribute type. +# See: https://github.com/python/mypy/issues/708#issuecomment-561735949 +class Function(Protocol): + def __call__(self, input: TensorOrTensors) -> TensorOrTensors: + ... + + +def checkpoint(function: Function, input): + """Make a checkpoint with a simple interface like + :func:`torch.utils.checkpoint.checkpoint`. It's only used to test or debug + :class:`Checkpoint` and :class:`Recompute` without boilerplate. + """ + batch = Batch(input) + + chk = Checkpointing(function, batch) + batch = chk.checkpoint() + chk.recompute(batch) + + return batch.values + + +class Checkpointing: + """Generates a pair of :class:`Checkpoint` and :class:`Recompute`.""" + + def __init__(self, function: Function, batch: Batch) -> None: + self.function = function + self.batch = batch + + # Shared memory between Checkpoint and Recompute. 1-length deque is + # used for mutability and length limitation. + self.recomputed: Deque[Recomputed] = deque(maxlen=1) + self.rng_states: Deque[RNGStates] = deque(maxlen=1) + + def checkpoint(self) -> Batch: + """Return a batch applied by :class:`Checkpoint`.""" + input_atomic = self.batch.atomic + inputs = tuple(self.batch) + + # Use a phony which requires grad to ensure that Checkpoint can be + # tracked by the autograd engine even when none of the input tensors + # require grad. + phony = get_phony(self.batch.get_device(), requires_grad=True) + + output = Checkpoint.apply(phony, self.recomputed, self.rng_states, self.function, input_atomic, *inputs) + + # Gradients are only supported for float Tensors. + if isinstance(output, tuple): + output = tuple([x.detach() if torch.is_tensor(x) and not x.is_floating_point() else x for x in output]) + + return Batch(output) + + def recompute(self, batch: Batch) -> None: + """Apply :class:`Recompute` to the batch in place.""" + input_atomic = self.batch.atomic + inputs = tuple(self.batch) + + # Use a tensor in the batch to tie together fork-join + tensor_idx = batch.find_tensor_idx() + # batch[tensor_idx] is always requiring grad, because it has been passed + # checkpoint with a phony requiring grad. + batch[tensor_idx], phony = fork(batch[tensor_idx]) + phony = Recompute.apply(phony, self.recomputed, self.rng_states, self.function, input_atomic, *inputs) + batch[tensor_idx] = join(batch[tensor_idx], phony) + + +class ThreadLocal(threading.local): + def __init__(self) -> None: + self.is_checkpointing = False + self.is_recomputing = False + + +thread_local = ThreadLocal() + + +@contextmanager +def enable_checkpointing() -> Generator[None, None, None]: + """Make :func:`is_checkpointing` return :data:`True` within a context.""" + orig = thread_local.is_checkpointing + thread_local.is_checkpointing = True + try: + yield + finally: + thread_local.is_checkpointing = orig + + +@contextmanager +def enable_recomputing() -> Generator[None, None, None]: + """Makes :func:`is_recomputing` return :data:`True` within a context.""" + orig = thread_local.is_recomputing + thread_local.is_recomputing = True + try: + yield + finally: + thread_local.is_recomputing = orig + + +def is_checkpointing() -> bool: + """Whether the current forward propagation is under checkpointing. + + Returns: + bool: :data:`True` if it's under checkpointing. + + """ + return thread_local.is_checkpointing + + +def is_recomputing() -> bool: + """Whether the current forward propagation is under checkpoint recomputation. + + Use this to prevent duplicated side-effects at forward + propagation:: + + class Counter(nn.Module): + def __init__(self): + super().__init__() + self.counter = 0 + + def forward(self, input): + if not is_recomputing(): + self.counter += 1 + return input + + Returns: + bool: :data:`True` if it's under checkpoint recomputation. + + .. seealso:: :ref:`Detecting Recomputation` + + """ + return thread_local.is_recomputing + + +class Context: + """The common interface between the :class:`Checkpoint` and :class:`Recompute` context.""" + + recomputed: Deque[Recomputed] + rng_states: Deque[RNGStates] + function: Function + input_atomic: bool + inputs: Sequence[Any] + + saved_tensors: Tuple[Tensor, ...] + + def save_for_backward(self, *tensors: Tensor) -> None: # pragma: no cover + pass + + +def save_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> None: + """: + Capture the current random number generator states. + + meth:`Checkpoint.forward` captures the current PyTorch's random number + generator states at CPU and GPU to reuse in :meth:`Recompute.backward`. + + .. seealso:: :ref:`Referential Transparency` + + """ + cpu_rng_state = torch.get_rng_state() + + gpu_rng_state: Optional[Tensor] + if device.type == "cuda": + gpu_rng_state = torch.cuda.get_rng_state(device) + else: + gpu_rng_state = None + + rng_states.append((cpu_rng_state, gpu_rng_state)) + + +@contextmanager +def restore_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> Generator[None, None, None]: + """: + Restore the random number generator state. + + meth:`Recompute.backward` restores the random number generator states + captured by :func:`save_rng_states` within its context. + + .. seealso:: :ref:`Referential Transparency` + + """ + cpu_rng_state, gpu_rng_state = rng_states.pop() + + gpu_devices: List[torch.device] = [] + if device.type == "cuda": + gpu_devices.append(device) + + with torch.random.fork_rng(gpu_devices): + torch.set_rng_state(cpu_rng_state) + if gpu_rng_state is not None: + torch.cuda.set_rng_state(gpu_rng_state, device) + yield + + +class Checkpoint(torch.autograd.Function): + @staticmethod + # type: ignore[override] + def forward( + ctx: Context, + phony: Tensor, + recomputed: Deque[Recomputed], + rng_states: Deque[RNGStates], + function: Function, + input_atomic: bool, + *inputs, + ): + ctx.recomputed = recomputed + ctx.rng_states = rng_states + + save_rng_states(phony.device, ctx.rng_states) + + ctx.function = function + ctx.input_atomic = input_atomic + if input_atomic: + tensors = [inputs[0]] + else: + tensors = [] + for input in inputs: + if torch.is_tensor(input): + tensors.append(input) + + ctx.save_for_backward(*tensors) + + with torch.no_grad(), enable_checkpointing(): + if input_atomic: + assert len(inputs) == 1 + output = function(inputs[0]) + else: + output = function(*inputs) + return output + + @staticmethod + def backward(ctx: Context, *grad_output: Tensor,) -> Tuple[Optional[Tensor], ...]: # pragma: no cover + output, input_leaf = ctx.recomputed.pop() + + if isinstance(output, tuple): + outputs = output + else: + outputs = (output,) + if any(torch.is_tensor(y) and y.requires_grad for y in outputs): + tensors = tuple([x for x in outputs if torch.is_tensor(x) and x.requires_grad]) + torch.autograd.backward(tensors, grad_output) + + grad_input: List[Optional[Tensor]] = [None, None, None, None, None] + grad_input.extend(x.grad if torch.is_tensor(x) else None for x in input_leaf) + return tuple(grad_input) + + +class Recompute(torch.autograd.Function): + @staticmethod + # type: ignore[override] + def forward( + ctx: Context, + phony: Tensor, + recomputed: Deque[Recomputed], + rng_states: Deque[RNGStates], + function: Function, + input_atomic: bool, + *inputs, + ) -> Tensor: + ctx.recomputed = recomputed + ctx.rng_states = rng_states + + ctx.function = function + ctx.input_atomic = input_atomic + ctx.inputs = inputs + if input_atomic: + tensors = [inputs[0]] + else: + tensors = [] + for input in inputs: + if torch.is_tensor(input): + tensors.append(input) + ctx.save_for_backward(*tensors) + + return phony + + @staticmethod + def backward(ctx: Context, *grad_output: Tensor) -> Tuple[None, ...]: # pragma: no cover + inputs = ctx.inputs + inputs_leaf = tuple(x.detach().requires_grad_(x.requires_grad) if torch.is_tensor(x) else x for x in inputs) + + # Get the device for the inputs from a tensor + device = None + for input in inputs: + if torch.is_tensor(input): + device = input.device + break + + if device is None: + raise RuntimeError(f'No tensors found in {inputs}') + + with restore_rng_states(device, ctx.rng_states): + with torch.enable_grad(), enable_recomputing(): + if ctx.input_atomic: + assert len(inputs_leaf) == 1 + output = ctx.function(inputs_leaf[0]) + else: + output = ctx.function(*inputs_leaf) + + ctx.recomputed.append((output, inputs_leaf)) + + grad_input: List[None] = [None, None, None, None, None] + grad_input.extend(None for _ in ctx.inputs) + return tuple(grad_input) diff --git a/torch/distributed/pipeline/sync/copy.py b/torch/distributed/pipeline/sync/copy.py new file mode 100644 index 000000000000..b717f0c2932c --- /dev/null +++ b/torch/distributed/pipeline/sync/copy.py @@ -0,0 +1,108 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""Autograd functions for stream-aware CUDA copy. + +It is used to overlap copy and computation on the same GPU. +""" +from collections import deque +from typing import Deque, List, Optional, Tuple, Sequence + +import torch +from torch import Tensor + +from .stream import AbstractStream, current_stream, get_device, record_stream, use_stream, wait_stream + +__all__: List[str] = ["Context", "Copy", "Wait"] + + +Tensors = Sequence[Tensor] + + +# Common interface between :class:`Copy` and :class:`Wait`. +class Context: + prev_stream: AbstractStream + next_stream: AbstractStream + + +class Copy(torch.autograd.Function): + """Copies tensors on specific streams.""" + + @staticmethod + # type: ignore[override] + def forward(ctx: Context, prev_stream: AbstractStream, next_stream: AbstractStream, *input,) -> Tensors: + ctx.prev_stream = prev_stream + ctx.next_stream = next_stream + + output = [] + output_stream = current_stream(get_device(next_stream)) + + with use_stream(prev_stream), use_stream(next_stream): + for x in input: + if torch.is_tensor(x): + y = x.to(get_device(next_stream), non_blocking=True) + output.append(y) + + # 'prev_stream' is not where 'x' has been allocated. + record_stream(x, prev_stream) + # 'y' has been allocated on 'next_stream'. + # It might be used on the current stream captured as 'output_stream'. + record_stream(y, output_stream) + else: + output.append(x) + + return tuple(output) + + @staticmethod + def backward(ctx: Context, *grad_output: Tensor,) -> Tuple[Optional[Tensor], ...]: + prev_stream = ctx.prev_stream + next_stream = ctx.next_stream + + grad_input: Deque[Tensor] = deque(maxlen=len(grad_output)) + input_stream = current_stream(get_device(prev_stream)) + + with use_stream(prev_stream), use_stream(next_stream): + for x in reversed(grad_output): + y = x.to(get_device(prev_stream), non_blocking=True) + grad_input.appendleft(y) + + # 'next_stream' is not where 'x' has been allocated. + record_stream(x, next_stream) + # 'y' has been allocated on 'prev_stream'. + # It might be used on the current stream captured as 'input_stream'. + record_stream(y, input_stream) + + grad_streams: Tuple[Optional[Tensor], ...] = (None, None) + return grad_streams + tuple(grad_input) + + +class Wait(torch.autograd.Function): + """Synchronizes a stream to another stream. + + Place it just before you want to start an operation on the next stream, + provided that all operations on the previous stream are done. + + """ + + @staticmethod + # type: ignore[override] + def forward(ctx: Context, prev_stream: AbstractStream, next_stream: AbstractStream, *input) -> Tensors: + ctx.prev_stream = prev_stream + ctx.next_stream = next_stream + + wait_stream(next_stream, prev_stream) + + return tuple(x.detach() if torch.is_tensor(x) else x for x in input) + + @staticmethod + def backward(ctx: Context, *grad_input: Tensor,) -> Tuple[Optional[Tensor], ...]: + prev_stream = ctx.prev_stream + next_stream = ctx.next_stream + + wait_stream(prev_stream, next_stream) + + grad_streams: Tuple[Optional[Tensor], ...] = (None, None) + return grad_streams + grad_input diff --git a/torch/distributed/pipeline/sync/dependency.py b/torch/distributed/pipeline/sync/dependency.py new file mode 100644 index 000000000000..ca5c69e388fe --- /dev/null +++ b/torch/distributed/pipeline/sync/dependency.py @@ -0,0 +1,54 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""Arbitrary dependency between two autograd lanes.""" +from typing import List, Tuple + +import torch +from torch import Tensor + +from .phony import get_phony + +__all__: List[str] = ["fork", "Fork", "join", "Join"] + + +def fork(input: Tensor) -> Tuple[Tensor, Tensor]: + """Branches out from an autograd lane of the given tensor.""" + if torch.is_grad_enabled() and input.requires_grad: + input, phony = Fork.apply(input) + else: + phony = get_phony(input.device, requires_grad=False) + + return input, phony + + +class Fork(torch.autograd.Function): + @staticmethod + def forward(ctx: "Fork", input: Tensor) -> Tuple[Tensor, Tensor]: # type: ignore[override] + phony = get_phony(input.device, requires_grad=False) + return input.detach(), phony.detach() + + @staticmethod + def backward(ctx: "Fork", grad_input: Tensor, grad_grad: Tensor) -> Tensor: # type: ignore[override] + return grad_input + + +def join(input: Tensor, phony: Tensor) -> Tensor: + """Merge two autograd lanes.""" + if torch.is_grad_enabled() and (input.requires_grad or phony.requires_grad): + input = Join.apply(input, phony) + + return input + + +class Join(torch.autograd.Function): + @staticmethod + def forward(ctx: "Join", input: Tensor, phony: Tensor) -> Tensor: # type: ignore[override] + return input.detach() + + @staticmethod + def backward(ctx: "Join", grad_input: Tensor) -> Tuple[Tensor, None]: # type: ignore[override] + return grad_input, None diff --git a/torch/distributed/pipeline/sync/microbatch.py b/torch/distributed/pipeline/sync/microbatch.py new file mode 100644 index 000000000000..5b8aca257548 --- /dev/null +++ b/torch/distributed/pipeline/sync/microbatch.py @@ -0,0 +1,234 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""Manipulation of micro-batches.""" +import typing +from typing import Any, Callable, List, Union, cast, Sequence + +import torch +from torch import Tensor +import torch.cuda.comm + +__all__: List[str] = ["NoChunk", "Batch", "check", "scatter", "gather"] + + +Tensors = Sequence[Tensor] +TensorOrTensors = Union[Tensor, Tensors] +Function = Callable[[TensorOrTensors], Union[List[Any], Tensor]] + + +class NoChunk: + """ + Wrapper for a Tensor in :meth:`Pipe.forward` indicating that the tensor + should not be chunked on the batch dimension and instead be replicated + as-is across all micro-batches. This is useful for tensors which might + not have any 'batch' semantics for the model. + """ + def __init__(self, inp: Tensor): + if not torch.is_tensor(inp): + raise TypeError(f'NoChunk only supported for tensors, found: {inp}') + self._tensor = inp + + @property + def tensor(self): + return self._tensor + + +class Batch: + """ + An abstraction representing a microbatch in the pipeline. + """ + + def __init__(self, values: Union[List[Any], Tensor]) -> None: + self._values = values + self.atomic = torch.is_tensor(values) + + # Verify at least on tensor + if not self.atomic: + if not any(torch.is_tensor(value) for value in self._values): + raise TypeError(f'No tensors found in batch: {self._values}') + + @property + def tensor(self) -> Tensor: + """Retrieves the underlying tensor.""" + if not self.atomic: + raise AttributeError("not atomic batch") + return cast(Tensor, self._values) + + @property + def values(self): + """Retrieves the underlying values for the batch""" + return self._values + + def find_tensor_idx(self): + """ + Retrieves the index of first tensor found. + """ + if self.atomic: + return 0 + for i, value in enumerate(self._values): + if torch.is_tensor(value): + return i + + raise TypeError("No tensor found!") + + def get_device(self): + """ + Retrieves the device for this microbatch. + """ + if self.atomic: + return self._values.device # type: ignore[union-attr] + + for value in self._values: + if torch.is_tensor(value): + return value.device + + def call(self, function: Function) -> "Batch": + """Calls a function on the microbatch. It also wraps + the output with :class:`Batch`. + """ + if self.atomic: + return Batch(function(self._values)) + else: + return Batch(function(*self._values)) + + def __repr__(self) -> str: + return f"Batch[atomic={self.atomic!r}]({self._values!r})" + + def __iter__(self): + if self.atomic: + yield self._values + else: + yield from self._values + + def __len__(self) -> int: + return 1 if self.atomic else len(self._values) + + def __getitem__(self, index: int): + if not self.atomic: + return self._values[index] + + if index != 0: + raise IndexError("atomic batch allows index 0 only") + + return self._values + + # NOTE(sublee): pyflakes can't detect "overload" instead of "typing.overload". + @typing.overload + def __setitem__(self, index: int, value: Tensor) -> None: + ... + + @typing.overload + def __setitem__(self, index: slice, value: Tensors) -> None: + ... + + def __setitem__(self, index: Union[int, slice], value) -> None: + if isinstance(index, int): + self._setitem_by_index(index, value) + else: + self._setitem_by_slice(index, value) + + def _setitem_by_index(self, index: int, value) -> None: + if not self.atomic: + i = index + self._values = self._values[:i] + (value,) + self._values[i + 1 :] # type: ignore[operator] + return + + if index != 0: + raise IndexError("atomic batch allows index 0 only") + + self._values = value + + def _setitem_by_slice(self, index: slice, value) -> None: + if not (index.start is index.stop is index.step is None): # noqa: E714 + raise NotImplementedError("only slice [:] supported") + + if not self.atomic: + self._values = value + return + + if len(value) != 1: + raise IndexError("atomic batch cannot be replaced with multiple tensors") + + self._values = value[0] + + +def check(first_device, *inputs) -> None: + """ + Checks whether the input contains at least one tensor and each tensor is + on the same device as the first partition. + + Raises: + ValueError: input does not contain at least one tensor + + """ + + if not any(torch.is_tensor(input) for input in inputs): + raise TypeError(f'inputs do not have any tensors: {inputs}') + if any(torch.is_tensor(input) and input.device != first_device for input in inputs): + raise ValueError('All inputs should be on the same device as the first partition') + + +def scatter(*inputs, chunks: int) -> List[Batch]: + """Splits an input mini-batch into multiple micro-batches.""" + if len(inputs) == 1 and isinstance(inputs[0], Tensor): + return [Batch(x) for x in inputs[0].chunk(chunks)] + + batches: List[Any] = [[] for _ in range(chunks)] + # Actual number of chunks produced + num_chunks = -1 + for input in inputs: + if torch.is_tensor(input): + # Chunk only tensors. + tensors = input.chunk(chunks) + + # Validate number of chunks equal across all inputs. + if num_chunks != -1 and num_chunks != len(tensors): + raise RuntimeError(f'Found different number of chunks produced for inputs: {num_chunks} and {len(tensors)}') + num_chunks = len(tensors) + + for i, tensor in enumerate(tensors): + batches[i].append(tensor) + else: + # Replicate non-tensors or tensors wrapped with 'NoChunk'. + for i in range(chunks): + if isinstance(input, NoChunk): + # Extract the tensor out. + batches[i].append(input.tensor) + else: + batches[i].append(input) + + # Truncate to actual number of chunks + batches = batches[:num_chunks] + + return [Batch(x) for x in batches] + + +def gather(outputs: List[Batch]): + """Concatenates output micro-batches into a mini-batch.""" + output: Any + + if outputs[0].atomic: + tensors = tuple(b.tensor for b in outputs) + output = torch.cat(tensors) + else: + output_buf: List[Any] = [] + for i in range(len(outputs[0])): + output_type = type(outputs[0][i]) + current_outputs = [] + for batch in outputs: + if output_type != type(batch[i]): + raise TypeError(f'Types for microbatch outputs do not match, found: {output_type} and {type(batch[i])}') + current_outputs.append(batch[i]) + + if torch.is_tensor(outputs[0][i]): + output_buf.append(torch.cat(current_outputs)) + else: + output_buf.append(current_outputs) + + output = tuple(output_buf) + + return output diff --git a/torch/distributed/pipeline/sync/phony.py b/torch/distributed/pipeline/sync/phony.py new file mode 100644 index 000000000000..012926699cfb --- /dev/null +++ b/torch/distributed/pipeline/sync/phony.py @@ -0,0 +1,50 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""Provides phony for arbitrary dependency in a autograd graph.""" +from typing import Dict, List, Tuple + +import torch +from torch import Tensor + +from .stream import default_stream, use_stream + +__all__: List[str] = ["get_phony"] + + +_phonies: Dict[Tuple[torch.device, bool], Tensor] = {} + + +def get_phony(device: torch.device, *, requires_grad: bool) -> Tensor: + """Get a phony. Phony is tensor without space. + + It is useful to make arbitrary dependency in a autograd graph because it doesn't require any + gradient accumulation. + + .. note:: + + Phonies for each device are cached. If an autograd function gets a phony + internally, the phony must be detached to be returned. Otherwise, the + autograd engine will mutate the cached phony in-place:: + + class Phonify(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + phony = get_phony(input.device, requires_grad=False) + return phony.detach() # detach() is necessary. + + """ + key = (device, requires_grad) + + try: + phony = _phonies[key] + except KeyError: + with use_stream(default_stream(device)): + phony = torch.empty(0, device=device, requires_grad=requires_grad) + + _phonies[key] = phony + + return phony diff --git a/torch/distributed/pipeline/sync/pipe.py b/torch/distributed/pipeline/sync/pipe.py new file mode 100644 index 000000000000..5e61341d9ad9 --- /dev/null +++ b/torch/distributed/pipeline/sync/pipe.py @@ -0,0 +1,490 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""The Pipe interface.""" +from collections import OrderedDict +from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Union, Sequence, Tuple, cast + +import torch +from torch import Tensor, nn +from torch.distributed.rpc import RRef +import torch.autograd +import torch.cuda + +from . import microbatch +from .batchnorm import DeferredBatchNorm +from .pipeline import Pipeline +from .skip.layout import inspect_skip_layout +from .skip.skippable import verify_skippables +from .stream import AbstractStream, new_stream + +__all__ = ["Pipe", "BalanceError", "PipeSequential", "WithDevice"] + + +Device = Union[torch.device, int, str] +Devices = Union[Iterable[Device], List[Device]] + +Tensors = Sequence[Tensor] +TensorOrTensors = Union[Tensor, Tensors] + +if TYPE_CHECKING: + # Typechecking: nn.Module is not a Generic + Module = nn.Module[TensorOrTensors] # type: ignore[type-arg] + NamedModules = OrderedDict[str, Module] +else: + Module = nn.Module + NamedModules = OrderedDict + + +def _recommend_auto_balance(message: str) -> str: + """Expands a message with recommendation to :mod:`torchpipe.balance`.""" + return f"""{message} + +If your model is still under development, its optimal balance would change +frequently. In this case, we highly recommend 'torch.distributed.pipeline.sync.balance' for +naive automatic balancing: + + from torch.distributed.pipeline.sync import Pipe + from torch.distributed.pipeline.sync.balance import balance_by_time + + partitions = torch.cuda.device_count() + sample = torch.empty(...) + balance = balance_by_time(partitions, model, sample) + + model = Pipe(model, balance, ...) +""" + + +def _verify_module(module: nn.Sequential) -> None: + if not isinstance(module, nn.Sequential): + raise TypeError("module must be nn.Sequential to be partitioned") + + named_children = list(module.named_children()) + if len(named_children) != len(module): + raise ValueError("module with duplicate children is not supported") + + +def _verify_splitting( + module: nn.Sequential, partitions: List[nn.Sequential], devices: List[torch.device] +) -> None: + num_parameters = len(list(module.parameters())) + num_child_parameters = sum(len(list(child.parameters())) for child in module.children()) + if num_parameters == num_child_parameters: + return + + for i in range(len(partitions)): + for j in range(i + 1, len(partitions)): + parti = partitions[i] + partj = partitions[j] + if devices[i] == devices[j]: + continue + for p in parti.parameters(): + for q in partj.parameters(): + if p is q: + raise ValueError("module with duplicate parameters on distinct devices is not supported") + + +class BalanceError(ValueError): + pass + + +def _retrieve_device(module: nn.Module) -> torch.device: + """Validates all parameters in the Module have the same device and returns + the appropriate device. + + Args: + An ``nn.Module`` to process. + + Returns: + ``torch.Device`` for the entire module. + + Raises: + ValueError: + If devices for ``nn.Module`` parameters are not all same. + """ + + device = None + for parameter in module.parameters(): + if device is None: + device = parameter.device + elif device != parameter.device: + raise ValueError( + f'nn.Module: {module}, should have all parameters on a single device,' + ' please use .to() to place the module on a single device') + + return device if device is not None else torch.device("cpu") + + +class PipeSequential(nn.Sequential): + """ + Pipe variant of ``nn.Sequential`` which supports multiple inputs. + """ + + def forward(self, *inputs): + for module in self: + if isinstance(inputs, Tuple): # type: ignore[arg-type] + inputs = module(*inputs) + else: + # Don't expand single variables (ex: lists/Tensor) + inputs = module(inputs) + return inputs + + +class WithDevice(nn.Module): + """ + Wraps an ``nn.Module`` which is part of ``nn.Sequential`` passed into :class:`Pipe` + that overrides the device for that module. In cases where :class:`Pipe` + can't implicitly determine the device for the module and places it on CPU, + this wrapper can be used to override the implicit behavior and explicitly + specify which device a module should run on. + + The provided module is also moved to the given device via ``.to(device)`` + by :class:`Pipe` + + Args: + module(:class:`torch.nn.Module`): The module to be wrapped. + device(:class:`torch.device`): The device to run the module on. + + Example:: + >>> # xdoctest: +SKIP("distributed") + >>> fc1 = nn.Linear(16, 8).cuda(0) + >>> fc2 = nn.Linear(8, 4).cuda(1) + >>> dropout = nn.Dropout() + >>> + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1) + >>> # Dropout does not have any parameters/buffers, but we want to + >>> # run it on cuda:1 to avoid any GPU to CPU transfers. + >>> model = nn.Sequential(fc1, fc2, WithDevice(dropout, 'cuda:1')) + >>> # xdoctest: +SKIP("Needs RPC framework init") + >>> model = Pipe(model, chunks=8) + """ + def __init__(self, module: nn.Module, device: torch.device): + super().__init__() + self._module = module + self._device = torch.device(device) + + def forward(self, *args, **kwargs): + return self._module(*args, **kwargs) + + @property + def module(self): + return self._module + + @property + def device(self): + return self._device + + +def _assemble_partition(modules: List[nn.Module]): + modules_list: List[nn.Module] = [] + for module in modules: + if isinstance(module, nn.Sequential): + modules_list.extend(module.children()) + else: + modules_list.append(module) + return PipeSequential(*modules_list) + + +def _split_module(modules: nn.Sequential) -> Tuple[List[nn.Sequential], List[torch.device]]: + partitions = [] + devices = [] + + current_partition = [] + current_device = None + for name, module in modules.named_children(): + if isinstance(module, WithDevice): + # Process device override and move module to appropriate device. + device = module.device + module = module.module + module.to(device) + else: + device = _retrieve_device(module) + if current_device is not None and (current_device != device or device.type == 'cpu'): + partitions.append(_assemble_partition(current_partition)) + devices.append(current_device) + current_partition = [] + current_device = device + current_partition.append(module) + + if current_device is not None: + partitions.append(_assemble_partition(current_partition)) + devices.append(current_device) + + partitions = cast(List[nn.Sequential], nn.ModuleList(partitions)) + + return partitions, devices + + +MOVING_DENIED = TypeError("denied to move parameters and buffers, because Pipe should manage device placement") + + +class Pipe(Module): + """Wraps an arbitrary :class:`nn.Sequential ` module + to train on using synchronous pipeline parallelism. If the module requires + lots of memory and doesn't fit on a single GPU, pipeline parallelism is a + useful technique to employ for training. + + The implementation is based on the torchgpipe_ paper. + + .. _torchgpipe: https://arxiv.org/abs/2004.09910 + + Pipe combines pipeline parallelism with checkpointing to reduce peak + memory required to train while minimizing device under-utilization. + + You should place all the modules on the appropriate devices and wrap them + into an :class:`nn.Sequential ` module defining the + desired order of execution. If a module does not contain any + parameters/buffers, it is assumed this module should be executed on CPU + and appropriate input tensors to the module are moved to CPU before + execution. This behavior can be overridden by the :class:`WithDevice` + wrapper which can be used to explicitly specify which device a module + should run on. + + Args: + module (:class:`nn.Sequential `): + sequential module to be parallelized using pipelining. Each module + in the sequence has to have all of its parameters on a single + device. Each module in the sequence has to either be an nn.Module + or :class:`nn.Sequential ` (to combine multiple + sequential modules on a single device) + chunks (int): + number of micro-batches (default: ``1``) + checkpoint (str): + when to enable checkpointing, one of ``'always'``, + ``'except_last'``, or ``'never'`` (default: ``'except_last'``). + ``'never'`` disables checkpointing completely, ``'except_last'`` + enables checkpointing for all micro-batches except the last one + and ``'always'`` enables checkpointing for all micro-batches. + deferred_batch_norm (bool): + whether to use deferred ``BatchNorm`` moving statistics (default: + :data:`False`). If set to :data:`True`, we track statistics across + multiple micro-batches to update the running statistics per + mini-batch. + + Raises: + TypeError: + the module is not a :class:`nn.Sequential `. + ValueError: + invalid arguments + + Example:: + Pipeline of two FC layers across GPUs 0 and 1. + + >>> # Need to initialize RPC framework first. + >>> # xdoctest: +SKIP + >>> os.environ['MASTER_ADDR'] = 'localhost' + >>> os.environ['MASTER_PORT'] = '29500' + >>> torch.distributed.rpc.init_rpc('worker', rank=0, world_size=1) + >>> + >>> # Build pipe. + >>> fc1 = nn.Linear(16, 8).cuda(0) + >>> fc2 = nn.Linear(8, 4).cuda(1) + >>> model = nn.Sequential(fc1, fc2) + >>> model = Pipe(model, chunks=8) + >>> input = torch.rand(16, 16).cuda(0) + >>> output_rref = model(input) + + .. note:: + You can wrap a :class:`Pipe` model with + :class:`torch.nn.parallel.DistributedDataParallel` only when the + checkpoint parameter of :class:`Pipe` is ``'never'``. + + .. note:: + :class:`Pipe` only supports intra-node pipelining currently, but + will be expanded to support inter-node pipelining in the future. + The forward function returns an :class:`~torch.distributed.rpc.RRef` + to allow for inter-node pipelining in the future, where the output + might be on a remote host. For intra-node pipelining you can use + :meth:`~torch.distributed.rpc.RRef.local_value` to retrieve the + output locally. + + .. warning:: + :class:`Pipe` is experimental and subject to change. + """ + + def __init__( + self, + module: nn.Sequential, + chunks: int = 1, + checkpoint: str = "except_last", + deferred_batch_norm: bool = False, + ) -> None: + super().__init__() + + # Check if RPC framework is initialized. + if not torch.distributed.rpc._is_current_rpc_agent_set(): + raise RuntimeError( + 'Please initialize RPC framework for Pipe using ' + 'torch.distributed.rpc.init_rpc') + + chunks = int(chunks) + checkpoint = str(checkpoint) + + if chunks <= 0: + raise ValueError("number of chunks must be positive integer") + if checkpoint not in ["always", "except_last", "never"]: + raise ValueError("checkpoint is not one of 'always', 'except_last', or 'never'") + + _verify_module(module) + + # Verify if the underlying skippable modules satisfy integrity. The + # integrity can be verified before forward() because it is static. + verify_skippables(module) + + self.chunks = chunks + self.checkpoint = checkpoint + + if deferred_batch_norm: + module = DeferredBatchNorm.convert_deferred_batch_norm(module, chunks) + + self.partitions, self.devices = _split_module(module) + _verify_splitting(module, self.partitions, self.devices) + + self._copy_streams: List[List[AbstractStream]] = [] + self._skip_layout = inspect_skip_layout(self.partitions) + + # Separate CUDA streams for copy. + copy_streams = self._ensure_copy_streams() + + # The micro-batch index where the checkpointing stops. + checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint] + + self.pipeline = Pipeline(self.partitions, self.devices, copy_streams, self._skip_layout, checkpoint_stop) + + def __len__(self) -> int: + """Counts the length of the underlying sequential module.""" + return sum(len(p) for p in self.partitions) + + def __getitem__(self, index: int) -> nn.Module: + """Gets a layer in the underlying sequential module.""" + partitions = self.partitions + if index < 0: + partitions = partitions[::-1] + + for partition in partitions: + try: + return partition[index] + except IndexError: + pass + + shift = len(partition) + + if index < 0: + index += shift + else: + index -= shift + + raise IndexError + + def __iter__(self) -> Iterator[nn.Module]: + """Iterates over children of the underlying sequential module.""" + for partition in self.partitions: + yield from partition + + # Pipe should manage the device of each partition. + # Deny cuda(), cpu(), and to() with device, by TypeError. + def cuda(self, device: Optional[Device] = None) -> "Pipe": + raise MOVING_DENIED + + def cpu(self) -> "Pipe": + raise MOVING_DENIED + + def to(self, *args: Any, **kwargs: Any) -> "Pipe": + # Deny these usages: + # + # - to(device[, dtype, non_blocking]) + # - to(tensor[, non_blocking]) + # + # But allow this: + # + # - to(dtype[, non_blocking]) + # + if "device" in kwargs or "tensor" in kwargs: + raise MOVING_DENIED + + if args: + if isinstance(args[0], (torch.device, int, str)): + raise MOVING_DENIED + if torch.is_tensor(args[0]): + raise MOVING_DENIED + + return super().to(*args, **kwargs) + + def _ensure_copy_streams(self) -> List[List[AbstractStream]]: + """Ensures that :class:`Pipe` caches CUDA streams for copy. + + It's worth to cache CUDA streams although PyTorch already manages a + pool of pre-allocated CUDA streams, because it may reduce GPU memory + fragmentation when the number of micro-batches is small. + + """ + if not self._copy_streams: + for device in self.devices: + self._copy_streams.append([new_stream(device) for _ in range(self.chunks)]) + + return self._copy_streams + + def forward(self, *inputs) -> RRef: + """ + Processes a single input mini-batch through the pipe and returns an + :class:`~torch.distributed.rpc.RRef` pointing to the output. + :class:`Pipe` is a fairly transparent module wrapper. It doesn't + modify the input and output signature of the underlying module. But + there's type restriction. Input and output have to contain at least one + tensor. This restriction is applied at partition boundaries too. + + The sequence of inputs are fed into the first stage of the pipeline as + ``*inputs``. As a result the positional args for this function should + match the positional args for the first stage of the pipeline. The same + condition applies for output of one stage of the pipeline which is the + input for the next stage. + + The input tensor is split into multiple micro-batches based on the + ``chunks`` parameter used to initialize :class:`Pipe`. The batch size + is assumed to be the first dimension of the tensor and if the batch + size is less than ``chunks``, the number of micro-batches is equal to + the batch size. + + Only tensors are split into multiple micro-batches, non-Tensor inputs + are just replicated as-is in each micro-batch. For non-Tensor outputs + in the last stage of the pipeline, they are aggregated as a ``List`` + and returned the user. For example, if you have 2 micro-batches + returning the integer 5, the user would receive the consolidated + output of `[5, 5]` + + All the input tensors need to be on the same device as the first + partition of the pipeline. + + If a tensor is wrapped with the :class:`NoChunk` wrapper, the tensor + is not split across micro-batches and is replicated as-is similar to + non-tensors. + + Args: + inputs: input mini-batch + + Returns: + :class:`~torch.distributed.rpc.RRef` to the output of the mini-batch + + Raises: + TypeError: input doesn't contain at least one tensor + + """ + first_partition_device = self.devices[0] if len(self.devices) != 0 else torch.device("cpu") + microbatch.check(first_partition_device, *inputs) + + if not self.devices: + # Empty sequential module is not illegal. + return RRef(*inputs) + + # Divide a mini-batch into micro-batches. + batches = microbatch.scatter(*inputs, chunks=self.chunks) + + # Run pipeline parallelism. + self.pipeline.run(batches) + + # Merge the micro-batches into one mini-batch. + output = microbatch.gather(batches) + return RRef(output) diff --git a/torch/distributed/pipeline/sync/pipeline.py b/torch/distributed/pipeline/sync/pipeline.py new file mode 100644 index 000000000000..7cd5e5831169 --- /dev/null +++ b/torch/distributed/pipeline/sync/pipeline.py @@ -0,0 +1,255 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""The pipeline parallelism of Pipe.""" +from queue import Queue +from types import TracebackType +from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Type, Union, cast, Sequence + +import torch +from torch import Tensor, nn +from torch.autograd.profiler import record_function + +from .checkpoint import Checkpointing +from .copy import Copy, Wait +from .dependency import fork, join +from .microbatch import Batch +from .skip.layout import SkipLayout +from .skip.tracker import SkipTrackerThroughPotals, use_skip_tracker +from .stream import AbstractStream, current_stream, use_device +from .worker import Task, create_workers + +__all__: List[str] = ["Pipeline"] + + +Tensors = Sequence[Tensor] +TensorOrTensors = Union[Tensor, Tensors] + +ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType] + +# Queue is generic only in stubs. +# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime +if TYPE_CHECKING: + InQueue = Queue[Optional["Task"]] + OutQueue = Queue[Tuple[bool, Union[Tuple["Task", Batch], ExcInfo, None]]] +else: + InQueue = Queue + OutQueue = Queue + + +def _depend(fork_from: Batch, join_to: Batch) -> None: + fork_from_idx = fork_from.find_tensor_idx() + join_to_idx = join_to.find_tensor_idx() + + fork_from[fork_from_idx], phony = fork(fork_from[fork_from_idx]) + join_to[join_to_idx] = join(join_to[join_to_idx], phony) + + +def _copy(batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream) -> None: + batch[:] = Copy.apply(prev_stream, next_stream, *batch) + # Gradients are only supported for float Tensors. + batch[:] = tuple([x.detach() if torch.is_tensor(x) and not x.is_floating_point() else x for x in batch]) + + +def _wait(batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream) -> None: + batch[:] = Wait.apply(prev_stream, next_stream, *batch) + # Gradients are only supported for float Tensors. + batch[:] = tuple([x.detach() if torch.is_tensor(x) and not x.is_floating_point() else x for x in batch]) + + +def _clock_cycles(m: int, n: int) -> Iterable[List[Tuple[int, int]]]: + """Generate schedules for each clock cycle.""" + # m: number of micro-batches + # n: number of partitions + # i: index of micro-batch + # j: index of partition + # k: clock number + # + # k (i,j) (i,j) (i,j) + # - ----- ----- ----- + # 0 (0,0) + # 1 (1,0) (0,1) + # 2 (2,0) (1,1) (0,2) + # 3 (2,1) (1,2) + # 4 (2,2) + for k in range(m + n - 1): + yield [(k - j, j) for j in range(max(1 + k - m, 0), min(1 + k, n))] + + +class Pipeline: + """The pipeline parallelism for Pipe.""" + + def __init__( + self, + partitions: List[nn.Sequential], + devices: List[torch.device], + copy_streams: List[List[AbstractStream]], + skip_layout: SkipLayout, + checkpoint_stop: int, + ) -> None: + self.partitions = partitions + self.devices = devices + self.copy_streams = copy_streams + self.skip_layout = skip_layout + self.checkpoint_stop = checkpoint_stop + (self.in_queues, self.out_queues) = create_workers(devices) + + def run(self, batches: List[Batch]) -> None: + """Runs pipeline parallelism. + + It modifies the given batches in place. + + """ + partitions = self.partitions + devices = self.devices + skip_layout = self.skip_layout + + m = len(batches) + n = len(partitions) + + skip_trackers = [SkipTrackerThroughPotals(skip_layout) for _ in batches] + + for schedule in _clock_cycles(m, n): + self.fence(batches, schedule, skip_trackers) + self.compute(batches, schedule, skip_trackers) + + def fence( + self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals], + ) -> None: + """Copy micro-batches after computation for the previous micro-batches.""" + copy_streams = self.copy_streams + skip_layout = self.skip_layout + + for i, j in schedule: + # Ensure that batches[i-1] is executed after batches[i] in + # backpropagation by an explicit dependency. + if i != 0 and j != 0: + _depend(batches[i - 1], batches[i]) + + next_stream = copy_streams[j][i] + + for prev_j, ns, name in skip_layout.copy_policy(j): + prev_stream = copy_streams[prev_j][i] + skip_trackers[i].copy(batches[i], prev_stream, next_stream, ns, name) + + if j != 0: + prev_stream = copy_streams[j - 1][i] + _copy(batches[i], prev_stream, next_stream) + + def compute( + self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals], + ) -> None: + """Run tasks with synchronization to copy streams.""" + partitions = self.partitions + devices = self.devices + copy_streams = self.copy_streams + checkpoint_stop = self.checkpoint_stop + + # Disable checkpointing if in eval mode. + if not self.partitions[0].training: + checkpoint_stop = 0 + + n = len(partitions) + streams = [current_stream(d) for d in devices] + exc_info: Optional[ExcInfo] = None + + # With checkpointing, the autograd graph looks like this diagram: + # +-----+------+ + # | Copy | + # +-----+------+ (fence) + # - - - + - - - - - - - - - + # | (compute) + # +-----+------+ + # | Wait | [1] Synchronize the current stream with the copy stream. + # +-----+------+ + # +-----+------+ + # | Checkpoint | [2] Compute a partition within checkpointing. + # +-----+------+ + # +-----+------+ + # | Wait | [3] Synchronize the copy stream with the current stream. + # +-----+------+ + # + - - - + + # | +-----+-----+ + # | | Recompute | [4] Schedule the recomputation at backpropagation. + # | +-----+-----+ + # + - - - + + # | + # - - - + - - - - - - - - - + # +-----+------+ (fence) + # | Copy | + # +-----+------+ + for i, j in schedule: + batch = batches[i] + partition = partitions[j] + + # Synchronize with the copied input. ([1] in the diagram) + if j != 0: + _wait(batch, copy_streams[j][i], streams[j]) + + # Determine whether checkpointing or not. + checkpoint = i < checkpoint_stop + if checkpoint: + + def function( + *inputs, + partition: nn.Module = partition, + skip_tracker: SkipTrackerThroughPotals = skip_trackers[i], + chunk_id: int = i, + part_id: int = j, + ) -> TensorOrTensors: + with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)): + return partition(*inputs) + + chk = Checkpointing(function, batch) # type: ignore[arg-type] + task = Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute) + del function, chk + + else: + + def compute( + batch: Batch = batch, + partition: nn.Module = partition, + skip_tracker: SkipTrackerThroughPotals = skip_trackers[i], + chunk_id: int = i, + part_id: int = j, + ) -> Batch: + with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)): + return batch.call(partition) + + task = Task(streams[j], compute=compute, finalize=None) + del compute + + # Compute tasks in parallel. ([2] in the diagram) + self.in_queues[j].put(task) + + for i, j in schedule: + ok, payload = self.out_queues[j].get() + + # Hold the first exception. + if exc_info is not None: + continue + elif not ok: + exc_info = cast(ExcInfo, payload) + continue + + task, batch = cast(Tuple[Task, Batch], payload) + + # The copy stream synchronizes to copy the output. ([3] in the + # diagram) + if j != n - 1: + _wait(batch, streams[j], copy_streams[j][i]) + + # Finalize tasks. If checkpointing is enabled, here the + # recomputation is scheduled at backpropagation. ([4] in the + # diagram) + with use_device(devices[j]): + task.finalize(batch) + + batches[i] = batch + + # Fail at the first exception. + if exc_info is not None: + raise exc_info[0].with_traceback(exc_info[1], exc_info[2]) diff --git a/torch/distributed/pipeline/sync/py.typed b/torch/distributed/pipeline/sync/py.typed new file mode 100644 index 000000000000..ab03724cafbf --- /dev/null +++ b/torch/distributed/pipeline/sync/py.typed @@ -0,0 +1,6 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. diff --git a/torch/distributed/pipeline/sync/skip/__init__.py b/torch/distributed/pipeline/sync/skip/__init__.py new file mode 100644 index 000000000000..bdcb913867a7 --- /dev/null +++ b/torch/distributed/pipeline/sync/skip/__init__.py @@ -0,0 +1,11 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""Supports efficiency with skip connections.""" +from .namespace import Namespace +from .skippable import pop, skippable, stash, verify_skippables + +__all__ = ["skippable", "stash", "pop", "verify_skippables", "Namespace"] diff --git a/torch/distributed/pipeline/sync/skip/layout.py b/torch/distributed/pipeline/sync/skip/layout.py new file mode 100644 index 000000000000..04d76d34ea16 --- /dev/null +++ b/torch/distributed/pipeline/sync/skip/layout.py @@ -0,0 +1,92 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""Static skip connection layout of ``@skippable`` modules.""" +from typing import Dict, Iterable, List, Tuple + +from torch import nn + +from .namespace import Namespace + +__all__: List[str] = [] + + +class SkipLayout: + """Represents a skip connection layout across partitions.""" + + # Skip routes indexed by 'ns, name': {(ns, name): (prev_j, next_j), ...} + by_ns_name: Dict[Tuple[Namespace, str], Tuple[int, int]] + + # Skip routes indexed by partition number 'j': [[next_j]: [(prev_j, ns, name), ...], ...] + by_partition: List[List[Tuple[int, Namespace, str]]] + + def __init__(self, num_partitions: int, skip_routes: Dict[Tuple[Namespace, str], Tuple[int, int]],) -> None: + # The skip routes are already indexed by 'ns, name'. + self.by_ns_name = skip_routes + + # Index skip routes by partition number 'j'. + self.by_partition = [[] for _ in range(num_partitions)] + + for (ns, name), (prev_j, next_j) in skip_routes.items(): + self.by_partition[next_j].append((prev_j, ns, name)) + + for p in self.by_partition: + p.sort() + + def copy_policy(self, next_j: int) -> Iterable[Tuple[int, Namespace, str]]: + """Generates skip routes for the given destination partition number. + The skip routes are sorted by source partition number in ascending + order. + + Yields: + Each tuple of (source partition number, namespace, name). + + """ + for prev_j, ns, name in self.by_partition[next_j]: + if prev_j == next_j: + # This skip tensor will be popped at the same partition where + # it is stashed. In this case, copy is not required. + continue + + yield (prev_j, ns, name) + + def requires_copy(self, ns: Namespace, name: str) -> bool: + """Whether the given namespace and name requires partition-to-partition + copy or not. + """ + prev_j, next_j = self.by_ns_name.get((ns, name), (-1, -1)) + return prev_j != next_j + + +def inspect_skip_layout(partitions: List[nn.Sequential]) -> SkipLayout: + """Inspects the skip connection layout in the given partitions.""" + # NOTE(sublee): Hide circular import inside this subroutine. Circular + # import is not ideal but placing this logic near to SkipLayout may + # increase cohesion of code. + from .skippable import Skippable + + skip_routes: Dict[Tuple[Namespace, str], Tuple[int, int]] = {} + stashed_at: Dict[Tuple[Namespace, str], int] = {} + + for j, partition in enumerate(partitions): + def inspect_layer(layer): + if not isinstance(layer, Skippable): + return + + for ns, name in layer.stashable(): + stashed_at[(ns, name)] = j + + for ns, name in layer.poppable(): + prev_j = stashed_at.pop((ns, name)) + skip_routes[(ns, name)] = (prev_j, j) + + if isinstance(partition, nn.Sequential): + for layer in partition: + inspect_layer(layer) + else: + inspect_layer(partition) + + return SkipLayout(len(partitions), skip_routes) diff --git a/torch/distributed/pipeline/sync/skip/namespace.py b/torch/distributed/pipeline/sync/skip/namespace.py new file mode 100644 index 000000000000..7d9c0d9b7d84 --- /dev/null +++ b/torch/distributed/pipeline/sync/skip/namespace.py @@ -0,0 +1,50 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""Provides isolated namespace of skip tensors.""" +import abc +from functools import total_ordering +from typing import Any +import uuid + +__all__ = ["Namespace"] + + +@total_ordering +class Namespace(metaclass=abc.ABCMeta): # noqa: B024 + """Namespace for isolating skip tensors used by :meth:`isolate() + `. + """ + + __slots__ = ("id",) + + def __init__(self) -> None: + self.id = uuid.uuid4() + + def __repr__(self) -> str: + return f"" + + def __hash__(self) -> int: + return hash(self.id) + + # Namespaces should support ordering, since SkipLayout will sort tuples + # including a namespace. But actual order between namespaces is not + # important. That's why they are ordered by version 4 UUID which generates + # random numbers. + def __lt__(self, other: Any) -> bool: + if isinstance(other, Namespace): + return self.id < other.id + return False + + def __eq__(self, other: object) -> bool: + if isinstance(other, Namespace): + return self.id == other.id + return False + + +# 'None' is the default namespace, +# which means that 'isinstance(None, Namespace)' is 'True'. +Namespace.register(type(None)) diff --git a/torch/distributed/pipeline/sync/skip/portal.py b/torch/distributed/pipeline/sync/skip/portal.py new file mode 100644 index 000000000000..335793f4cc13 --- /dev/null +++ b/torch/distributed/pipeline/sync/skip/portal.py @@ -0,0 +1,231 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""Portal keeps a tensor in the pocket plane. The tensor becomes hidden to the +autograd engine. The shared context of three functions (:class:`PortalBlue`, +:class:`PortalOrange`, and :class:`PortalCopy`) out of the computation graph is +one of the most important feature of :mod:`torchpipe.skip`. + +The metaphor is inspired by Portal(tm) from Valve. + +""" +from typing import List, Optional, Tuple + +import torch +from torch import Tensor + +from ..copy import Context as CopyContext +from ..copy import Copy +from ..phony import get_phony +from ..stream import AbstractStream, get_device + +__all__: List[str] = [] + + +class Portal: + """A portal for a tensor.""" + + def __init__(self, tensor: Optional[Tensor], tensor_life: int) -> None: + self.put_tensor(tensor, tensor_life) + self.grad: Optional[Tensor] = None + + def blue(self) -> Tensor: + """Creates a :class:`PortalBlue` which hides the underlying tensor from + the autograd engine. + + Join the returning phony to the main lane of the autograd graph to + assure the correct backpropagation:: + + PortalBlue --+ + | + ---------- Join -- + + """ + tensor = self.use_tensor() + + if tensor is None: + return get_phony(torch.device("cpu"), requires_grad=False) + + return PortalBlue.apply(self, tensor) + + def orange(self, phony: Tensor) -> Optional[Tensor]: + """Creates a :class:`PortalOrange` which retrieves the hidden tensor + without losing ability of backpropagation. + + Give a phony forked from the main lane of an autograd graph:: + + +-- PortalOrange --+ + | | + -- Fork --------- f(a, b) -- + + """ + self.check_tensor_life() + + if self.tensor is None: + return self.use_tensor() + + return PortalOrange.apply(self, phony) + + def copy(self, prev_stream: AbstractStream, next_stream: AbstractStream, phony: Tensor,) -> Tensor: + """Copies the hidden tensor by a :class:`PortalCopy`. + + Give a phony and use the returning phony to keep backpropagation:: + + +-- PortalCopy --+ + | | + -- Fork ---------- Join -- + + """ + if self.tensor is None: + return get_phony(torch.device("cpu"), requires_grad=False) + + return PortalCopy.apply(self, prev_stream, next_stream, phony) + + def check_tensor_life(self) -> None: + if self.tensor_life <= 0: + raise RuntimeError("tensor in portal has been removed") + + def put_tensor(self, tensor: Optional[Tensor], tensor_life: int) -> None: + """Stores a tensor into this portal.""" + # [Life of Tensor through Portal] + # + # The tensor can be retrieved by use_tensor() up to 'tensor_life' + # times. When the life becomes 0, the tensor will be deleted for + # deallocation in CUDA memory. + # + # The below events participate in a tensor through a portal. + # Note that [x] denotes the events which call use_tensor(): + # + # 1. [x] blue() + # 2. [ ] PortalBlue.forward + # 3. [ ] copy() + # 4. [ ] PortalCopy.forward + # 5. [ ] orange() + # 6. [x] PortalOrange.forward + # - - - - - - - - - - - - - - - - - - - - - - - - - - - + # 7. [ ] orange() (recomputed) + # 8. [x] PortalOrange.forward (recomputed) + # 9. [ ] PortalOrange.backward + # 10. [ ] PortalCopy.backward + # 11. [x] blue() (recomputed) + # 12. [ ] PortalBlue.forward (recomputed) + # 13. [ ] PortalBlue.backward + # + self.tensor_life = tensor_life + + if tensor_life > 0: + self.tensor = tensor + else: + self.tensor = None + + def use_tensor(self) -> Optional[Tensor]: + """Retrieves the underlying tensor and decreases the tensor life. When + the life becomes 0, it the tensor will be removed. + """ + self.check_tensor_life() + + tensor = self.tensor + + self.tensor_life -= 1 + + if self.tensor_life <= 0: + self.tensor = None + + return tensor + + def put_grad(self, grad: Tensor) -> None: + """Stores a gradient into this portal.""" + self.grad = grad + + def use_grad(self) -> Tensor: + """Retrieves and removes the underlying gradient. The gradient is + always ephemeral. + """ + if self.grad is None: + raise RuntimeError("grad in portal has been removed or never set") + + grad = self.grad + self.grad = None + return grad + + +# Common interface between :class:`PortalBlue`, :class:`PortalOrange`, and +# :class:`PortalCopy`. +class Context(CopyContext): + portal: Portal + + +class PortalBlue(torch.autograd.Function): + """Hides a tensor from the autograd engine by a :class:`Portal`.""" + + @staticmethod + # type: ignore[override] + def forward( + ctx: Context, + portal: Portal, + # This tensor must be retrieved by portal.use_tensor(). + tensor: Tensor, + ) -> Tensor: + ctx.portal = portal + + phony = get_phony(tensor.device, requires_grad=False) + return phony.detach() + + @staticmethod + # type: ignore[override] + def backward(ctx: Context, grad_phony: Tensor,) -> Tuple[None, Tensor]: + # The paired PortalOrange should keep the gradient. + grad = ctx.portal.use_grad() + return None, grad + + +class PortalOrange(torch.autograd.Function): + """Retrieves the hidden tensor from a :class:`Portal`.""" + + @staticmethod + # type: ignore[override] + def forward(ctx: Context, portal: Portal, phony: Tensor) -> Tensor: + ctx.portal = portal + + tensor = portal.use_tensor() + assert tensor is not None + + return tensor.detach() + + @staticmethod + def backward(ctx: Context, grad: Tensor) -> Tuple[None, None]: # type: ignore[override] + # The paired PortalBlue will use the gradient. + ctx.portal.put_grad(grad) + return None, None + + +class PortalCopy(torch.autograd.Function): + """Copies the hidden tensor in a :class:`Portal`. It replaces the hidden + tensor with copied one. + """ + + @staticmethod + # type: ignore[override] + def forward( + ctx: Context, portal: Portal, prev_stream: AbstractStream, next_stream: AbstractStream, phony: Tensor, + ) -> Tensor: + ctx.portal = portal + + assert portal.tensor is not None + (portal.tensor,) = Copy.forward(ctx, prev_stream, next_stream, portal.tensor) + + phony = get_phony(get_device(next_stream), requires_grad=False) + return phony.detach() + + @staticmethod + # type: ignore[override] + def backward(ctx: Context, grad_phony: Tensor,) -> Tuple[None, None, None, None]: + portal = ctx.portal + + assert portal.grad is not None + _, _, portal.grad = Copy.backward(ctx, portal.grad) + + return None, None, None, None diff --git a/torch/distributed/pipeline/sync/skip/skippable.py b/torch/distributed/pipeline/sync/skip/skippable.py new file mode 100644 index 000000000000..9d4db76c6b67 --- /dev/null +++ b/torch/distributed/pipeline/sync/skip/skippable.py @@ -0,0 +1,431 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""The user interface to define skip connections.""" +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Dict, + FrozenSet, + Generator, + Iterable, + List, + Optional, + Set, + Sequence, + Tuple, + Type, + TypeVar, + Union, + cast, +) + +from torch import Tensor, nn + +from ..microbatch import Batch +from .namespace import Namespace +from .tracker import current_skip_tracker + +__all__ = ["skippable", "stash", "pop", "verify_skippables"] + + +Tensors = Sequence[Tensor] +TensorOrTensors = Union[Tensor, Tensors] + +StashPop = Union["stash", "pop"] +StashPopGenerator = Generator[StashPop, Optional[Tensor], TensorOrTensors] +if TYPE_CHECKING: + # Typechecking: nn.Module is not a Generic + SkippableModule = nn.Module[Union[StashPopGenerator, TensorOrTensors]] # type: ignore[type-arg] +else: + SkippableModule = nn.Module + +T = TypeVar("T", bound="Skippable") + + +class Skippable(nn.Module): + """The base class for skippable modules. + + Do not use this class directly. Define a subclass by :func:`skippable` + instead. + + """ + + module_cls: ClassVar[Type[SkippableModule]] + stashable_names: ClassVar[FrozenSet[str]] + poppable_names: ClassVar[FrozenSet[str]] + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__() + self.module = self.module_cls(*args, **kwargs) # type: ignore[call-arg] + self.namespaces: Dict[str, Namespace] = {} + + def __repr__(self) -> str: + return f"@skippable({self.module})" + + def namespaced(self, name: str) -> Tuple[Namespace, str]: + """Prepend namespace for the given skip name.""" + ns = self.namespaces.get(name) + ns = cast(Namespace, ns) + return (ns, name) + + def stashable(self) -> Iterable[Tuple[Namespace, str]]: + """Iterate over namespaced skip names to be stashed.""" + for name in self.stashable_names: + yield self.namespaced(name) + + def poppable(self) -> Iterable[Tuple[Namespace, str]]: + """Iterate over namespaced skip names to be popped.""" + for name in self.poppable_names: + yield self.namespaced(name) + + def isolate(self: T, ns: Namespace, *, only: Optional[Iterable[str]] = None) -> T: + r"""Isolate a specified subset or the whole set of skip tensors. + + In a single sequential module, skip tensors with the same + name are not allowed unless they are isolated by different namespaces. + + Here's an example using the same name for skip tensors twice. Each pair + of ``Layer1`` and ``Layer2`` is isolated with its own namespace ``ns1`` + and ``ns2``. There is no conflict anymore:: + + ns1 = Namespace() + ns2 = Namespace() + + model = nn.Sequential( + Layer1().isolate(ns1), + Layer1().isolate(ns2), + Layer2(), + Layer3().isolate(ns2), + Layer3().isolate(ns1), + ) + + When `only` parameter is omitted, all skip tensors are isolated. You + can isolate a subset of skip tensors by passing `only` parameter:: + + ns_alice = Namespace() + ns_bob = Namespace() + + model = nn.Sequential( + ... + StashStashPop().isolate(ns_alice, only=['alice']) \ + .isolate(ns_bob, only=['bob']), + ... + ) + + Args: + ns (Namespace): + namespace for isolation + + Keyword Args: + only (iterable of strs): + names of specific skip tensors to be isolated (omit this option + to isolate all skip tensors declared in this module) + + Returns: + this module itself + + """ + names: Iterable[str] + + if only is None: + names = self.stashable_names | self.poppable_names + else: + names = set(only) + + for name in names: + self.namespaces[name] = ns + + return self + + def dispatch( + self, + input, + handle_stash: Callable[[str, Optional[Tensor]], None], + handle_pop: Callable[[str], Optional[Tensor]], + ): + """Dispatch :class:`stash` or :class:`pop` commands. + + The commands are generated by the module's ``forward()``. + """ + generator = self.module(input) + + if not isinstance(generator, Generator): + # The underlying module returned output without any yield. + output = generator + return output + + try: + op = next(generator) + + while True: + if isinstance(op, stash): + handle_stash(op.name, op.tensor) + op = next(generator) + continue + + if isinstance(op, pop): + tensor = handle_pop(op.name) + op = generator.send(tensor) + continue + + raise TypeError(f"{op!r} is not a command from @skippable") + + except StopIteration as stop: + output = stop.args[0] + return output + + def forward(self, input: Union[List[Any], Tensor]) -> TensorOrTensors: + """Perform the forward propagation. + + :class:`stash` or :class:`pop` commands will be handled by portals + silently. The portals won't be exposed to users. + + Raises: + RuntimeError: + illegal 'stash' or 'pop' is found. + + """ + skip_tracker = current_skip_tracker() + stashed_tensors: Dict[str, Optional[Tensor]] = {} + + # Load skip tensors that might be popped. + poppable_tensors = {} + batch = Batch(input) + for ns, name in self.poppable(): + try: + poppable_tensors[name] = skip_tracker.load(batch, ns, name) + except KeyError as e: + raise RuntimeError(f"'{name}' has not been stashed") from e + input = batch.values + + # Handle skip commands. + def handle_stash(name: str, tensor: Optional[Tensor]) -> None: + if name not in self.stashable_names: + raise RuntimeError(f"'{name}' has not been declared as stashable") + stashed_tensors[name] = tensor + + def handle_pop(name: str) -> Optional[Tensor]: + if name not in self.poppable_names: + raise RuntimeError(f"'{name}' has not been declared as poppable") + return poppable_tensors.pop(name) + + output = self.dispatch(input, handle_stash, handle_pop) + + # All declared skips must be stashed or popped. + not_stashed = self.stashable_names - stashed_tensors.keys() + if not_stashed: + comma_names = ", ".join(f"'{n}'" for n in not_stashed) + raise RuntimeError(f"{comma_names} must be stashed but have not") + + not_popped = poppable_tensors.keys() + if not_popped: + comma_names = ", ".join(f"'{n}'" for n in not_popped) + raise RuntimeError(f"{comma_names} must be popped but have not") + + # Save stashed skip tensors. + batch = Batch(output) + for ns, name in self.stashable(): + tensor = stashed_tensors[name] + skip_tracker.save(batch, ns, name, tensor) + output = batch.values + + return output + + +# TODO(sublee): Move to above of Skippable class for better read flow. +def skippable( + stash: Iterable[str] = (), pop: Iterable[str] = (), +) -> Callable[[Type[SkippableModule]], Type[Skippable]]: + """Define a decorator to create :class:`nn.Module ` with skip connections. + + These decorated modules are called "skippable". This functionality works perfectly + fine even when the module is not wrapped by :class:`~torch.distributed.pipeline.sync.Pipe`. + + Each skip tensor is managed by its name. Before manipulating skip tensors, + a skippable module must statically declare the names for skip tensors by + `stash` and/or `pop` parameters. Skip tensors with pre-declared name can be + stashed by ``yield stash(name, tensor)`` or popped by ``tensor = yield + pop(name)``. + + Here is an example with three layers. A skip tensor named "1to3" is stashed + and popped at the first and last layer, respectively:: + + @skippable(stash=['1to3']) + class Layer1(nn.Module): + def forward(self, input): + yield stash('1to3', input) + return f1(input) + + class Layer2(nn.Module): + def forward(self, input): + return f2(input) + + @skippable(pop=['1to3']) + class Layer3(nn.Module): + def forward(self, input): + skip_1to3 = yield pop('1to3') + return f3(input) + skip_1to3 + + model = nn.Sequential(Layer1(), Layer2(), Layer3()) + + One skippable module can stash or pop multiple skip tensors:: + + @skippable(stash=['alice', 'bob'], pop=['carol']) + class StashStashPop(nn.Module): + def forward(self, input): + yield stash('alice', f_alice(input)) + yield stash('bob', f_bob(input)) + carol = yield pop('carol') + return input + carol + + Every skip tensor must be associated with exactly one pair of `stash` and + `pop`. :class:`~torch.distributed.pipeline.sync.Pipe` checks this + restriction automatically when wrapping a module. You can also check the + restriction by :func:`verify_skippables` + without :class:`~torch.distributed.pipeline.sync.Pipe`. + + """ + stashable_names = frozenset(stash) + poppable_names = frozenset(pop) + + def extend_skippable(module_cls: Type[SkippableModule]) -> Type[Skippable]: + name = module_cls.__name__ + bases = (Skippable,) + attrs = {"module_cls": module_cls, "stashable_names": stashable_names, "poppable_names": poppable_names} + return type(name, bases, attrs) + + return extend_skippable + + +class stash: + """The command to stash a skip tensor. + + :: + + def forward(self, input): + yield stash('name', input) + return f(input) + + Args: + name (str): name of skip tensor + input (torch.Tensor or None): tensor to pass to the skip connection + + """ + + __slots__ = ("name", "tensor") + + def __init__(self, name: str, tensor: Optional[Tensor]) -> None: + self.name = name + self.tensor = tensor + + +class pop: + """The command to pop a skip tensor. + + :: + + def forward(self, input): + skip = yield pop('name') + return f(input) + skip + + Args: + name (str): name of skip tensor + + Returns: + the skip tensor previously stashed by another layer under the same name + + """ + + __slots__ = ("name",) + + def __init__(self, name: str) -> None: + self.name = name + + +def verify_skippables(module: nn.Sequential) -> None: + """Verify if the underlying skippable modules satisfy integrity. + + Every skip tensor must have only one pair of `stash` and `pop`. If there + are one or more unmatched pairs, it will raise :exc:`TypeError` with the + detailed messages. + + Here are a few failure cases. :func:`verify_skippables` will report failure + for these cases:: + + # Layer1 stashes "1to3". + # Layer3 pops "1to3". + + nn.Sequential(Layer1(), Layer2()) + # +---- ? + + nn.Sequential(Layer2(), Layer3()) + # ? ----+ + + nn.Sequential(Layer1(), Layer2(), Layer3(), Layer3()) + # +-------------------+ ^^^^^^ + + nn.Sequential(Layer1(), Layer1(), Layer2(), Layer3()) + # ^^^^^^ +-------------------+ + + To use the same name for multiple skip tensors, they must be isolated by + different namespaces. See :meth:`isolate() + `. + + Raises: + TypeError: + one or more pairs of `stash` and `pop` are not matched. + + """ + stashed: Set[Tuple[Namespace, str]] = set() + popped: Set[Tuple[Namespace, str]] = set() + msgs: List[str] = [] + + for layer_name, layer in module.named_children(): + if not isinstance(layer, Skippable): + continue + + for name in layer.stashable_names & layer.poppable_names: + msg = f"'{layer_name}' declared '{name}' both as stashable and as poppable" + msgs.append(msg) + + for ns, name in layer.stashable(): + if name in layer.poppable_names: + continue + + if (ns, name) in stashed: + msg = f"'{layer_name}' redeclared '{name}' as stashable but not isolated by namespace" + msgs.append(msg) + continue + + stashed.add((ns, name)) + + for ns, name in layer.poppable(): + if name in layer.stashable_names: + continue + + if (ns, name) in popped: + msg = f"'{layer_name}' redeclared '{name}' as poppable but not isolated by namespace" + msgs.append(msg) + continue + + if (ns, name) not in stashed: + msg = f"'{layer_name}' declared '{name}' as poppable but it was not stashed" + msgs.append(msg) + continue + + popped.add((ns, name)) + + for (_, name) in stashed - popped: + msg = f"no module declared '{name}' as poppable but stashed" + msgs.append(msg) + + if msgs: + raise TypeError( + "one or more pairs of stash and pop do not match:\n\n{}" "".format("\n".join(f"* {x}" for x in msgs)) + ) diff --git a/torch/distributed/pipeline/sync/skip/tracker.py b/torch/distributed/pipeline/sync/skip/tracker.py new file mode 100644 index 000000000000..8ac82bc05dc9 --- /dev/null +++ b/torch/distributed/pipeline/sync/skip/tracker.py @@ -0,0 +1,180 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""Tracks skip tensors on a thread.""" +from contextlib import contextmanager +import threading +from typing import Dict, Generator, List, Optional, Tuple + +from torch import Tensor + +from ..checkpoint import is_checkpointing +from ..dependency import fork, join +from ..microbatch import Batch +from ..stream import AbstractStream +from .layout import SkipLayout +from .namespace import Namespace +from .portal import Portal + +__all__: List[str] = [] + + +class SkipTracker: + """Tracks saved skip tensors. + + It will update the given micro-batch in place. This is because when it + manipulates the underlying skip tensors, the current micro-batch also has + to be connected with the skip tensors. + + One thread has one skip tracker. Call :func:`current_skip_tracker` to get + the skip tracker on the current thread. + + """ + + def __init__(self) -> None: + self.tensors: Dict[Tuple[Namespace, str], Optional[Tensor]] = {} + + def save(self, batch: Batch, ns: Namespace, name: str, tensor: Optional[Tensor]) -> None: + self.tensors[(ns, name)] = tensor + + def load(self, batch: Batch, ns: Namespace, name: str) -> Optional[Tensor]: + return self.tensors.pop((ns, name)) + + def copy( + self, batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream, ns: Namespace, name: str, + ) -> None: + raise TypeError("copy is not supported for non-portal skip tensors") + + +class SkipTrackerThroughPotals(SkipTracker): + """Tracks saved skip tensors through portals. The skip tensors will be + hidden in portals so that the autograd engine does not need to track them. + + This tracker is only used when the training or evaluating module is wrapped + with :class:`torchpipe.Pipe`. + + """ + + def __init__(self, skip_layout: SkipLayout) -> None: + super().__init__() + self.skip_layout = skip_layout + self.portals: Dict[Tuple[Namespace, str], Portal] = {} + + def save(self, batch: Batch, ns: Namespace, name: str, tensor: Optional[Tensor]) -> None: + """Saves the stashed skip tensor in a portal. The portal is then + connected to the given micro-batch with :class:`Join`. + """ + if not self.skip_layout.requires_copy(ns, name): + super().save(batch, ns, name, tensor) + return + + # See [Tensor Life of Portal] at Portal.put_tensor() to understand the + # below tensor_life values. Here are the selected events which retrieve + # the tensor in portal: + # + # 1. [x] blue() + # ... + # 6. [x] PortalOrange.forward + # ... + # 8. [x] PortalOrange.forward (recomputed) + # ... + # 11. [x] blue() (recomputed) + # + if (ns, name) not in self.portals: + if is_checkpointing(): + # Under checkpointing, the tensor used by the first + # PortalOrange should be alive in the portal. This tensor will + # be used again by the second PortalOrange during the + # recomputation. + tensor_life = 3 # Delete at [8. PortalOrange.forward (recomputed)] + else: + tensor_life = 2 # Delete at [6. PortalOrange.forward] + + portal = Portal(tensor, tensor_life) + self.portals[(ns, name)] = portal + + else: + # Under recomputation, the portal already exists. + portal = self.portals[(ns, name)] + + # The existing tensor life already became 0. It should be reset as + # 1 to delete the tensor after the second PortalBlue immediately. + tensor_life = 1 # Delete at [11. blue() (recomputed)] + + portal.put_tensor(tensor, tensor_life) + + phony = portal.blue() + tensor_idx = batch.find_tensor_idx() + batch[tensor_idx] = join(batch[tensor_idx], phony) + + def load(self, batch: Batch, ns: Namespace, name: str) -> Optional[Tensor]: + """Loads a skip tensor from the corresponding portal to pop. The given + micro-batch is connected to the portal with :class:`Fork`. + """ + if not self.skip_layout.requires_copy(ns, name): + tensor = super().load(batch, ns, name) + return tensor + + portal = self.portals[(ns, name)] + tensor_idx = batch.find_tensor_idx() + batch[tensor_idx], phony = fork(batch[tensor_idx]) + tensor = portal.orange(phony) + return tensor + + def copy( + self, batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream, ns: Namespace, name: str, + ) -> None: + """Copies the skip tensor in the corresponding portal. The given + micro-batch and the portal will be tied with :class:`Fork` and + :class:`Join`. + """ + assert self.skip_layout.requires_copy(ns, name) + + tensor_idx = batch.find_tensor_idx() + batch[tensor_idx], phony = fork(batch[tensor_idx]) + + portal = self.portals[(ns, name)] + phony = portal.copy(prev_stream, next_stream, phony) + + batch[tensor_idx] = join(batch[tensor_idx], phony) + + +class ThreadLocal(threading.local): + def __init__(self) -> None: + self.skip_tracker: Optional[SkipTracker] = None + + +thread_local = ThreadLocal() + + +@contextmanager +def use_skip_tracker(skip_tracker: SkipTracker) -> Generator[None, None, None]: + """Registers the given skip tracker on the current thread within a + context:: + + with use_skip_tracker(my_skip_tracker): + ... + + """ + orig = thread_local.skip_tracker + + thread_local.skip_tracker = skip_tracker + + try: + yield + finally: + thread_local.skip_tracker = orig + + +def current_skip_tracker() -> SkipTracker: + """Gets the skip tracker on the current thread.""" + skip_tracker = thread_local.skip_tracker + + if skip_tracker is None: + skip_tracker = SkipTracker() + thread_local.skip_tracker = skip_tracker + + return skip_tracker diff --git a/torch/distributed/pipeline/sync/stream.py b/torch/distributed/pipeline/sync/stream.py new file mode 100644 index 000000000000..59fedf865a42 --- /dev/null +++ b/torch/distributed/pipeline/sync/stream.py @@ -0,0 +1,120 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""Utilities for eliminating boilerplate code to handle abstract streams with +CPU device. +""" +from contextlib import contextmanager +from typing import Generator, List, Union, cast + +import torch + +__all__: List[str] = ["CPUStreamType", "new_stream", "current_stream", "default_stream", + "use_device", "use_stream", "get_device", "wait_stream", "record_stream", + "is_cuda", "as_cuda"] + + +class CPUStreamType: + pass + + +# The placeholder on place of streams for the CPU device instead of CUDA. +CPUStream = CPUStreamType() + +# It represents both CUDA streams and the CPU stream. +AbstractStream = Union[torch.cuda.Stream, CPUStreamType] + + +def new_stream(device: torch.device) -> AbstractStream: + """Creates a new stream for either CPU or CUDA device.""" + if device.type != "cuda": + return CPUStream + return torch.cuda.Stream(device) + + +def current_stream(device: torch.device) -> AbstractStream: + """:func:`torch.cuda.current_stream` for either CPU or CUDA device.""" + if device.type != "cuda": + return CPUStream + return torch.cuda.current_stream(device) + + +def default_stream(device: torch.device) -> AbstractStream: + """:func:`torch.cuda.default_stream` for either CPU or CUDA device.""" + if device.type != "cuda": + return CPUStream + return torch.cuda.default_stream(device) + + +@contextmanager +def use_device(device: torch.device) -> Generator[None, None, None]: + """:func:`torch.cuda.device` for either CPU or CUDA device.""" + if device.type != "cuda": + yield + return + + with torch.cuda.device(device): + yield + + +@contextmanager +def use_stream(stream: AbstractStream) -> Generator[None, None, None]: + """:func:`torch.cuda.stream` for either CPU or CUDA stream.""" + if not is_cuda(stream): + yield + return + + with torch.cuda.stream(as_cuda(stream)): + yield + + +def get_device(stream: AbstractStream) -> torch.device: + """Gets the device from CPU or CUDA stream.""" + if is_cuda(stream): + return as_cuda(stream).device + return torch.device("cpu") + + +def wait_stream(source: AbstractStream, target: AbstractStream) -> None: + """:meth:`torch.cuda.Stream.wait_stream` for either CPU or CUDA stream. It + makes the source stream wait until the target stream completes work queued. + """ + if is_cuda(target): + if is_cuda(source): + # A CUDA stream waits another CUDA stream. + as_cuda(source).wait_stream(as_cuda(target)) + else: + # CPU waits a CUDA stream. + as_cuda(target).synchronize() + + # If the target is CPU, synchronization is not required. + + +def record_stream(tensor: torch.Tensor, stream: AbstractStream) -> None: + """:meth:`torch.Tensor.record_stream` for either CPU or CUDA stream.""" + if is_cuda(stream): + # NOTE(sublee): record_stream() on a shifted view tensor throws + # RuntimeError in PyTorch 1.1.0, and does nothing in 1.2.0. To safely + # protect the tensor against unexpected reallocation, here we use a + # temporal tensor associated with the same storage without shifting as + # a workaround. + # + # Issue: https://github.com/pytorch/pytorch/issues/27366 + # + tensor = tensor.new_empty([0]).set_(tensor._typed_storage()) + + # Typechecking: torch.cuda.Stream is incompatible with torch._C.Stream + tensor.record_stream(as_cuda(stream)) # type: ignore[arg-type] + + +def is_cuda(stream: AbstractStream) -> bool: + """Returns ``True`` if the given stream is a valid CUDA stream.""" + return stream is not CPUStream + + +def as_cuda(stream: AbstractStream) -> torch.cuda.Stream: + """Casts the given stream as :class:`torch.cuda.Stream`.""" + return cast(torch.cuda.Stream, stream) diff --git a/torch/distributed/pipeline/sync/utils.py b/torch/distributed/pipeline/sync/utils.py new file mode 100644 index 000000000000..210c475317e2 --- /dev/null +++ b/torch/distributed/pipeline/sync/utils.py @@ -0,0 +1,38 @@ +from torch import nn +from typing import List, Optional + +__all__ = ["partition_model"] + +def partition_model( + module: nn.Sequential, + balance: List[int], + devices: Optional[List[int]] = None): + """ + Partions the model accross multiple GPU devices. + + Given an :class:`nn.Sequential ` module, partitions + the model across multiple GPU devices according the provided ``balance`` + and ``devices``. + + Args: + module (:class:`nn.Sequential `): + Sequential model representing the pipe. + balance (List[int]): + List indicating the number of layers in each partition. + devices (List[int], optional): + List indicating the device to use for each partition. Defaults to + ``range(len(balance))`` + """ + device_idx = 0 + pipe_idx = 0 + balanced_pipe = [] + for num_layers in balance: + layers = [] + for i in range(num_layers): + layers.append(module[pipe_idx]) + pipe_idx += 1 + device = device_idx if devices is None else devices[device_idx] + balanced_pipe.append(nn.Sequential(*layers).to(device)) + device_idx += 1 + + return nn.Sequential(*balanced_pipe) diff --git a/torch/distributed/pipeline/sync/worker.py b/torch/distributed/pipeline/sync/worker.py new file mode 100644 index 000000000000..87b20c4a5551 --- /dev/null +++ b/torch/distributed/pipeline/sync/worker.py @@ -0,0 +1,132 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""Multithreading in pipeline parallelism.""" +from contextlib import contextmanager +from queue import Queue +import sys +from threading import Thread +from types import TracebackType +from typing import TYPE_CHECKING, Callable, Dict, Generator, List, Optional, Tuple, Type, Union, cast + +import torch + +from .microbatch import Batch +from .stream import AbstractStream, use_device, use_stream + +__all__: List[str] = ["Task", "worker", "create_workers", "spawn_workers"] + + +ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType] + +# Queue is generic only in stubs. +# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime +if TYPE_CHECKING: + InQueue = Queue[Optional["Task"]] + OutQueue = Queue[Tuple[bool, Union[Tuple["Task", Batch], ExcInfo, None]]] +else: + InQueue = Queue + OutQueue = Queue + + +class Task: + """A task represents how to compute a micro-batch on a partition. + + It consists of two parts: :meth:`compute` and :meth:`finalize`. + :meth:`compute` should be executed in worker threads concurrently. + :meth:`finalize` should be executed after when worker threads complete to + execute :meth:`compute`. + + :meth:`compute` might be boosted by worker threads. Because it produces + several CUDA API calls by user code. In PyTorch, parallel CUDA API calls + are not serialized through GIL. So more than one CUDA API call can be + produced at the same time. + + """ + + def __init__( + self, stream: AbstractStream, *, compute: Callable[[], Batch], finalize: Optional[Callable[[Batch], None]], + ) -> None: + self.stream = stream + self._compute = compute + self._finalize = finalize + self._grad_enabled = torch.is_grad_enabled() + + def compute(self) -> Batch: + with use_stream(self.stream), torch.set_grad_enabled(self._grad_enabled): + return self._compute() + + def finalize(self, batch: Batch) -> None: + if self._finalize is None: + return + with use_stream(self.stream), torch.set_grad_enabled(self._grad_enabled): + self._finalize(batch) + + +def worker(in_queue: InQueue, out_queue: OutQueue, device: torch.device) -> None: + """Main loop of a worker thread.""" + with use_device(device): + while True: + task = in_queue.get() + + if task is None: + break + + try: + batch = task.compute() + except Exception: + exc_info = cast(ExcInfo, sys.exc_info()) + out_queue.put((False, exc_info)) + continue + + out_queue.put((True, (task, batch))) + + done = (False, None) + out_queue.put(done) + + +def create_workers(devices: List[torch.device],) -> Tuple[List[InQueue], List[OutQueue]]: + """Spawns worker threads. A worker thread is bound to a device.""" + in_queues: List[InQueue] = [] + out_queues: List[OutQueue] = [] + + # Spawn workers. + workers: Dict[torch.device, Tuple[InQueue, OutQueue]] = {} + + def normalize_device(device: torch.device) -> torch.device: + if device.type == "cuda" and device.index is None: + return torch.device("cuda", index=torch.cuda.current_device()) + + if device.type == "cpu" and device.index is not None: + return torch.device("cpu") + + return device + + for device in devices: + device = normalize_device(device) + + try: + in_queue, out_queue = workers[device] + except KeyError: + in_queue = Queue() + out_queue = Queue() + workers[device] = (in_queue, out_queue) + + t = Thread(target=worker, args=(in_queue, out_queue, device), daemon=True,) + t.start() + + in_queues.append(in_queue) + out_queues.append(out_queue) + + return (in_queues, out_queues) + +@contextmanager +def spawn_workers(devices: List[torch.device],) -> Generator[Tuple[List[InQueue], List[OutQueue]], None, None]: + try: + (in_queues, out_queues) = create_workers(devices) + yield (in_queues, out_queues) + finally: + pass diff --git a/torch/testing/_internal/distributed/pipe_with_ddp_test.py b/torch/testing/_internal/distributed/pipe_with_ddp_test.py new file mode 100644 index 000000000000..1ed9f3cc96df --- /dev/null +++ b/torch/testing/_internal/distributed/pipe_with_ddp_test.py @@ -0,0 +1,149 @@ +# mypy: ignore-errors + +import torch +import torch.distributed as dist + +from torch import nn +from torch.nn.parallel import DistributedDataParallel +from torch.testing._internal.dist_utils import INIT_METHOD_TEMPLATE, dist_init +from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import ( + RpcAgentTestFixture, +) +from torch.testing._internal.common_distributed import ( + requires_gloo, + requires_nccl, + skip_if_lt_x_gpu, + skip_if_rocm, +) +from torch.distributed.pipeline.sync import Pipe + +class PipeWithDDPTest(RpcAgentTestFixture): + @property + def world_size(self) -> int: + return 2 + + @skip_if_lt_x_gpu(4) + @requires_nccl() + @dist_init + @skip_if_rocm + def test_basic_nccl_ckpt_never(self): + self._run_basic_test("nccl", "never") + + @skip_if_lt_x_gpu(4) + @requires_nccl() + @dist_init + @skip_if_rocm + def test_basic_nccl_ckpt_never_find_unused(self): + self._run_basic_test("nccl", "never", find_unused_parameters=True) + + @skip_if_lt_x_gpu(4) + @requires_nccl() + @dist_init + @skip_if_rocm + def test_basic_nccl_ckpt_always(self): + self._run_basic_test("nccl", "always", static_graph=True) + + @skip_if_lt_x_gpu(4) + @requires_nccl() + @dist_init + @skip_if_rocm + def test_basic_nccl_ckpt_except_last(self): + self._run_basic_test("nccl", "except_last", static_graph=True) + + @skip_if_lt_x_gpu(4) + @requires_gloo() + @dist_init + @skip_if_rocm + def test_basic_gloo_ckpt_never(self): + self._run_basic_test("gloo", "never") + + @skip_if_lt_x_gpu(4) + @requires_gloo() + @dist_init + @skip_if_rocm + def test_basic_gloo_ckpt_never_find_unused(self): + self._run_basic_test("gloo", "never", find_unused_parameters=True) + + @skip_if_lt_x_gpu(4) + @requires_gloo() + @dist_init + @skip_if_rocm + def test_basic_gloo_ckpt_always(self): + self._run_basic_test("gloo", "always", static_graph=True) + + @skip_if_lt_x_gpu(4) + @requires_gloo() + @dist_init + @skip_if_rocm + def test_basic_gloo_ckpt_except_last(self): + self._run_basic_test("gloo", "except_last", static_graph=True) + + def _run_basic_test(self, backend, checkpoint, find_unused_parameters=False, static_graph=False): + dist.init_process_group( + backend=backend, + init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name), + world_size=self.world_size, + rank=self.rank, + ) + + # Use 4 GPUs, two replicas of a pipe across GPU 0 and 1 and another + # pipe between GPU 2 and 3. Both replicas are replicated via DDP. + fc1 = nn.Linear(16, 8, bias=False).cuda(2 * self.rank) + + class MyModule(nn.Module): + def __init__(self, device): + super().__init__() + self.fc2 = nn.Linear(8, 4, bias=False).cuda(device) + self.fc3 = nn.Linear(4, 2, bias=False).cuda(device) + + def forward(self, inp): + if find_unused_parameters: + return self.fc2(inp) + else: + return self.fc3(self.fc2(inp)) + + layer2 = MyModule(2 * self.rank + 1) + model = nn.Sequential( + fc1, + layer2 + ) + model = Pipe(model, chunks=2, checkpoint=checkpoint) + model = DistributedDataParallel( + model, + find_unused_parameters=find_unused_parameters, + static_graph=static_graph, + ) + + # Ensure inputs are different across ranks to verify that gradient + # sync indeed occurs. + model_input = torch.rand(16, 16).cuda(2 * self.rank) * (self.rank + 1) + out = model(model_input).local_value() + out.sum().backward() + + # Run forward again for find_unused_parameters to trigger any potential errors. + if find_unused_parameters: + # Ensure inputs are different across ranks to verify that gradient + # sync indeed occurs. + unused_param_input = torch.rand(16, 16).cuda(2 * self.rank) * (self.rank + 1) + model(unused_param_input).local_value().sum().backward() + + # Run a few more iterations of fwd + bwd to ensure gradient synchronization + # occurs properly across iterations via delay_all_reduce/bucketized allreduce. + for _ in range(3): + model_input = torch.rand(16, 16).cuda(2 * self.rank) * (self.rank + 1) + out = model(model_input).local_value() + out.sum().backward() + + # Check grads + output = [torch.empty_like(fc1.weight.grad), torch.empty_like(fc1.weight.grad)] + dist.all_gather(output, fc1.weight.grad) + self.assertEqual(output[0], output[1]) + + output = [torch.empty_like(layer2.fc2.weight.grad), torch.empty_like(layer2.fc2.weight.grad)] + dist.all_gather(output, layer2.fc2.weight.grad) + self.assertEqual(output[0], output[1]) + + if not find_unused_parameters: + output = [torch.empty_like(layer2.fc3.weight.grad), torch.empty_like(layer2.fc3.weight.grad)] + dist.all_gather(output, layer2.fc3.weight.grad) + self.assertEqual(output[0], output[1]) diff --git a/torch/testing/_internal/distributed/pipeline/__init__.py b/torch/testing/_internal/distributed/pipeline/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/torch/testing/_internal/distributed/rpc_utils.py b/torch/testing/_internal/distributed/rpc_utils.py index 5b6e2c90770f..cdbbdcfd0681 100644 --- a/torch/testing/_internal/distributed/rpc_utils.py +++ b/torch/testing/_internal/distributed/rpc_utils.py @@ -16,6 +16,9 @@ DdpComparisonTest, DdpUnderDistAutogradTest, ) +from torch.testing._internal.distributed.pipe_with_ddp_test import ( + PipeWithDDPTest, +) from torch.testing._internal.distributed.nn.api.remote_module_test import ( CudaRemoteModuleTest, RemoteModuleTest, @@ -118,6 +121,7 @@ def tearDown(self): CudaDistAutogradTest, CudaRemoteModuleTest, CudaDdpComparisonTest, + PipeWithDDPTest, ] From 3bcc3cddb580bf0f0f1958cfe27001f236eac2c1 Mon Sep 17 00:00:00 2001 From: Shan19900305 Date: Tue, 4 Jun 2024 18:19:28 +0000 Subject: [PATCH 329/706] Using scalarType instead string in function _group_tensors_by_device_and_dtype. (#127869) Now torch.dtype can pass through pybind11, so modify function _group_tensors_by_device_and_dtype to using scalar type. And without convert torch.dtype and string in python and c++ side. @ezyang @bdhirsh Pull Request resolved: https://github.com/pytorch/pytorch/pull/127869 Approved by: https://github.com/ezyang --- torch/_C/__init__.pyi.in | 2 +- torch/csrc/Module.cpp | 41 ++--------------------------------- torch/utils/_foreach_utils.py | 7 +----- 3 files changed, 4 insertions(+), 46 deletions(-) diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 5d8a45f86523..99985151c19f 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -1195,7 +1195,7 @@ def _conv_determine_backend_memory_format( def _has_storage(x: Tensor) -> _bool: ... def _construct_storage_from_data_pointer(data_ptr: _int, device: torch.device, size: _int) -> Storage: ... def _should_allow_numbers_as_tensors(func_name: str) -> _bool: ... -def _group_tensors_by_device_and_dtype(nested_tensorlists: List[List[Optional[Tensor]]], with_indices: _bool = False) -> Dict[Tuple[torch.device, str], Tuple[List[List[Optional[Tensor]]], List[_int]]]: ... +def _group_tensors_by_device_and_dtype(nested_tensorlists: List[List[Optional[Tensor]]], with_indices: _bool = False) -> Dict[Tuple[torch.device, torch.dtype], Tuple[List[List[Optional[Tensor]]], List[_int]]]: ... def _check_tp_alloc_is_default(cls: Type) -> _bool: ... # NB: There is no Capsule type in typing, see diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 5fad0f0a9541..063cdbbde0b7 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -2154,50 +2154,13 @@ Call this whenever a new thread is created in order to propagate values from return torch::should_allow_numbers_as_tensors(name); }); - // FIXME(crcrpar): Better to have `at::ScalarType` get mapped to `torch.dtype` - // Currently I see the second item of the key is displayed as - // e.g. `torch._C._te.ScalarType at 0x7fcf318adab0` - // I thought adding an appropriate type_caster of `at::ScalarType` to - // torch/csrc/pybind.h` would solve this but it caused segmentation fault in - // my environment. - using _DeviceDtypeKey = std::pair; - // Custom hasher is necessary to make unordered_map compilable for Windows - // debug targets. As `at::native::ParamsHash` only works on structs with - // standard layout, but std::string isn't one in Visual C++ debug builds, - // which one can easily verify by running something like: - // #define _DEBUG - // #include - // #include - // static_assert(std::is_standard_layout_v, "Oh noes"); - // If above condition is not met, VC++ raises a very cryptic compilation - // error. See - // https://github.com/pytorch/pytorch/pull/100007#discussion_r1227116292 for - // more detail - struct _DeviceDtypeHasher { - std::size_t operator()(const _DeviceDtypeKey& k) const noexcept { - static at::native::ParamsHash device_hasher; - static std::hash string_hasher; - return device_hasher(k.first) ^ string_hasher(k.second); - } - }; - using _FlatMap = std::unordered_map< - _DeviceDtypeKey, - at::native::TensorsAndIndicesT, - _DeviceDtypeHasher>; py_module.def( "_group_tensors_by_device_and_dtype", [](const std::vector>>& nested_tensorlist, const bool with_indices) { - _FlatMap map; - for (const auto& iter : - at::native::_group_tensors_by_first_tensors_device_and_dtype( - nested_tensorlist, with_indices)) { - const auto scalar_type_name = - torch::utils::getDtypeNames(iter.first.second).first; - map.insert({{iter.first.first, scalar_type_name}, iter.second}); - } - return map; + return at::native::_group_tensors_by_first_tensors_device_and_dtype( + nested_tensorlist, with_indices); }); py_module.def( diff --git a/torch/utils/_foreach_utils.py b/torch/utils/_foreach_utils.py index 6f8a9b5b7e23..bcc274579ad0 100644 --- a/torch/utils/_foreach_utils.py +++ b/torch/utils/_foreach_utils.py @@ -34,12 +34,7 @@ def _group_tensors_by_device_and_dtype( tensorlistlist: TensorListList, with_indices: bool = False, ) -> Dict[Tuple[torch.device, torch.dtype], Tuple[TensorListList, Indices]]: - return { - (device, getattr(torch, str_dtype)): value - for (device, str_dtype), value in - torch._C._group_tensors_by_device_and_dtype(tensorlistlist, with_indices).items() - } - + return torch._C._group_tensors_by_device_and_dtype(tensorlistlist, with_indices) def _device_has_foreach_support(device: torch.device) -> bool: return device.type in (_get_foreach_kernels_supported_devices() + ["cpu"]) and not torch.jit.is_scripting() From c7e936a56a052da2169606b1fa7e03856e409006 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Mon, 3 Jun 2024 11:47:09 -0700 Subject: [PATCH 330/706] [dynamo] Tensorvariable - track grad with _grad field (#127785) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127785 Approved by: https://github.com/jansel --- test/dynamo/test_repros.py | 10 ++++++++++ torch/_dynamo/variables/builtin.py | 3 +++ 2 files changed, 13 insertions(+) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 90e1d34e8acc..647507be076f 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -5041,6 +5041,16 @@ def func(x, m): self.assertEqual(func(x, m), opt_func(x, m)) self.assertEqual(func(x, 0), opt_func(x, 0)) + def test_grad(self): + def fn(x, y): + x._grad = y + return x.grad.data + + x = torch.randn(4, requires_grad=True) + y = torch.randn(4) + opt_fn = torch.compile(fn, backend="eager") + self.assertEqual(fn(x, y), opt_fn(x, y)) + instantiate_parametrized_tests(ReproTests) diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index ce1d4bf9a0dd..2586c8deab94 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1702,6 +1702,9 @@ def _lower_version_count_by_1(x): return out tx.output.side_effects.store_attr(obj, name, val) + if name == "_grad": + tx.output.side_effects.store_attr(obj, "grad", val) + return val elif isinstance(obj, variables.UserDefinedObjectVariable): unimplemented( From 569c5e72e7b191f80ae14a3f3eed8431ec2f8646 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Mon, 3 Jun 2024 13:20:42 -0700 Subject: [PATCH 331/706] [dynamo] Unspec nn module when global backward hooks are present (#127802) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127802 Approved by: https://github.com/jansel ghstack dependencies: #127785 --- torch/_dynamo/mutation_guard.py | 5 ++++- torch/_dynamo/utils.py | 8 ++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/torch/_dynamo/mutation_guard.py b/torch/_dynamo/mutation_guard.py index c4a588888e11..1fa24cfa25bb 100644 --- a/torch/_dynamo/mutation_guard.py +++ b/torch/_dynamo/mutation_guard.py @@ -7,7 +7,7 @@ from torch.nn import Module from . import config -from .utils import ExactWeakKeyDictionary, is_lazy_module +from .utils import ExactWeakKeyDictionary, is_lazy_module, nn_module_has_global_hooks class MutationTracker: @@ -109,6 +109,9 @@ def is_dynamic_nn_module(obj, is_export): and not is_export ): return True + + if isinstance(obj, torch.nn.Module) and nn_module_has_global_hooks(): + return True dyn = GenerationTracker.dynamic_classes.get(type(obj)) or GenerationTracker.check( obj ) diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 2b42c8dec63d..04f01c757674 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -2106,6 +2106,14 @@ def format_bytecode(prefix, name, filename, line_no, code): all_hook_names = forward_hook_names + backward_hook_names + state_dict_hook_names +def nn_module_has_global_hooks(): + # This is limited to backward hooks for now because NNModuleVariable + # supports fwd hooks underneath. + return len(torch.nn.modules.module._global_backward_hooks) or len( + torch.nn.modules.module._global_backward_pre_hooks + ) + + def nn_module_get_all_hooks( mod, check_forward_hooks=False, From f27c4dd862bf79f37019ef277957cd577d57b66f Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Mon, 3 Jun 2024 18:02:31 -0700 Subject: [PATCH 332/706] [dynamo] Bugfix for nn parameter construction (#127806) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127806 Approved by: https://github.com/jansel ghstack dependencies: #127785, #127802 --- torch/_dynamo/variables/torch.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index fbfabb5fdf06..f8444a279aeb 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -896,6 +896,13 @@ def call_nn_parameter(cls, tx, data=None, requires_grad=True): ) assert isinstance(result, variables.TensorVariable) result.class_type = torch.nn.Parameter + + # TODO(jansel/bdhirsh) - There is some issue with + # tracable_create_paramter. It does not seem to use the right + # grad_enabled. Since this is parameter, we can just override the + # has_grad_fn field to False to workaround the issue. + result.has_grad_fn = False + # In reconstruct() should use the original parameter. The one returned by the graph will be an alias. result.source = placeholder.source From 9a25ff77af932b59899f337a7a8dffbaab166ecf Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 4 Jun 2024 18:26:11 +0000 Subject: [PATCH 333/706] Revert "[inductor] Enable subprocess-based parallel compile as the default (#126817)" This reverts commit cf77e7dd9770caf65e898ac2ee82045aa0408e30. Reverted https://github.com/pytorch/pytorch/pull/126817 on behalf of https://github.com/huydhn due to There are lots of flaky inductor failure showing up in trunk after this commit https://hud.pytorch.org/pytorch/pytorch/commit/cf77e7dd9770caf65e898ac2ee82045aa0408e30, so I am trying to revert this to see if this helps ([comment](https://github.com/pytorch/pytorch/pull/126817#issuecomment-2148143502)) --- torch/_inductor/config.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index dbaa528cd3e5..b8ff5ae5a6cd 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -387,9 +387,7 @@ def is_fbcode(): # The multiprocessing start method to use for inductor workers in the codecache. # "subprocess", "fork", or "spawn" def decide_worker_start_method(): - start_method = os.environ.get( - "TORCHINDUCTOR_WORKER_START", "fork" if is_fbcode() else "subprocess" - ) + start_method = os.environ.get("TORCHINDUCTOR_WORKER_START", "fork") assert start_method in [ "subprocess", "fork", From 01e6d1cae46ff4af8d55e04237a05e430cfb3136 Mon Sep 17 00:00:00 2001 From: Anshul Sinha <50644008+sinhaanshul@users.noreply.github.com> Date: Tue, 4 Jun 2024 10:18:46 -0700 Subject: [PATCH 334/706] [dtensor][debug] added c10d reduce_scatter_ and reduce_scatter_tensor_coalesced tracing_ to CommDebugMode (#127358) **Summary** Added c10d reduce_scatter_ and reduce_scatter_tensor_coalesced tracing to CommDebugMode and edited test case in test_comm_mode to include added features. **Test Plan** pytest test/distributed/_tensor/debug/test_comm_mode.py Pull Request resolved: https://github.com/pytorch/pytorch/pull/127358 Approved by: https://github.com/wz337, https://github.com/XilunWu, https://github.com/yifuwang --- test/distributed/_tensor/debug/test_comm_mode.py | 12 ++++++++++++ torch/distributed/_tensor/debug/comm_mode.py | 2 ++ 2 files changed, 14 insertions(+) diff --git a/test/distributed/_tensor/debug/test_comm_mode.py b/test/distributed/_tensor/debug/test_comm_mode.py index 6cb94c860024..dc088f38988c 100644 --- a/test/distributed/_tensor/debug/test_comm_mode.py +++ b/test/distributed/_tensor/debug/test_comm_mode.py @@ -193,6 +193,18 @@ def test_comm_mode_with_c10d(self): self.checksAssert(comm_mode, c10d_ops.allreduce_coalesced_, 1, 1) + # tests c10d reduce_scatter_ + with comm_mode: + dist.reduce_scatter(all_gather_out, [inp]) + + self.checksAssert(comm_mode, c10d_ops.reduce_scatter_, 1, 1) + + # tests c10d reduce_scatter_tensor_coalesced + with comm_mode as A, dist._coalescing_manager() as B: + dist.reduce_scatter_tensor(all_gather_out, inp) + + self.checksAssert(comm_mode, c10d_ops.reduce_scatter_tensor_coalesced_, 1, 1) + if __name__ == "__main__": run_tests() diff --git a/torch/distributed/_tensor/debug/comm_mode.py b/torch/distributed/_tensor/debug/comm_mode.py index d566da546d21..82f7e98c07cf 100644 --- a/torch/distributed/_tensor/debug/comm_mode.py +++ b/torch/distributed/_tensor/debug/comm_mode.py @@ -36,6 +36,8 @@ c10d_ops.gather_, c10d_ops.scatter_, c10d_ops.reduce_, + c10d_ops.reduce_scatter_, + c10d_ops.reduce_scatter_tensor_coalesced_, } From e76b28c7658561de1bcaff9b5af7156e1d3f3880 Mon Sep 17 00:00:00 2001 From: Anshul Sinha <50644008+sinhaanshul@users.noreply.github.com> Date: Tue, 4 Jun 2024 10:18:47 -0700 Subject: [PATCH 335/706] [dtensor][debug] added c10d alltoall_ and alltoall_base_ to CommDebugMode (#127360) **Summary** Added c10d alltoall_ and alltoall_base tracing to CommDebugMode and edited test case in test_comm_mode to include added features. **Test Plan** pytest test/distributed/_tensor/debug/test_comm_mode.py Pull Request resolved: https://github.com/pytorch/pytorch/pull/127360 Approved by: https://github.com/wz337, https://github.com/XilunWu, https://github.com/yifuwang ghstack dependencies: #127358 --- test/distributed/_tensor/debug/test_comm_mode.py | 12 ++++++++++++ torch/distributed/_tensor/debug/comm_mode.py | 2 ++ 2 files changed, 14 insertions(+) diff --git a/test/distributed/_tensor/debug/test_comm_mode.py b/test/distributed/_tensor/debug/test_comm_mode.py index dc088f38988c..5483b3171f30 100644 --- a/test/distributed/_tensor/debug/test_comm_mode.py +++ b/test/distributed/_tensor/debug/test_comm_mode.py @@ -205,6 +205,18 @@ def test_comm_mode_with_c10d(self): self.checksAssert(comm_mode, c10d_ops.reduce_scatter_tensor_coalesced_, 1, 1) + # tests c10d alltoall_ + with comm_mode: + dist.all_to_all([inp], [inp]) + + self.checksAssert(comm_mode, c10d_ops.alltoall_, 1, 1) + + # tests c10d alltoall_base_ + with comm_mode: + dist.all_to_all_single(inp, inp) + + self.checksAssert(comm_mode, c10d_ops.alltoall_base_, 1, 1) + if __name__ == "__main__": run_tests() diff --git a/torch/distributed/_tensor/debug/comm_mode.py b/torch/distributed/_tensor/debug/comm_mode.py index 82f7e98c07cf..1ff97e4e78e1 100644 --- a/torch/distributed/_tensor/debug/comm_mode.py +++ b/torch/distributed/_tensor/debug/comm_mode.py @@ -32,6 +32,8 @@ c10d_ops.allgather_into_tensor_coalesced_, c10d_ops.allreduce_, c10d_ops.allreduce_coalesced_, + c10d_ops.alltoall_, + c10d_ops.alltoall_base_, c10d_ops.broadcast_, c10d_ops.gather_, c10d_ops.scatter_, From 597922ba21b23852b447623f62dae96f84f3fa59 Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Tue, 4 Jun 2024 19:44:30 +0000 Subject: [PATCH 336/706] Reapply "distributed debug handlers (#126601)" (#127805) This reverts commit 7646825c3eb687030c4f873b01312be0eed80174. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127805 Approved by: https://github.com/PaliC --- BUILD.bazel | 1 + WORKSPACE | 6 + build_variables.bzl | 2 + caffe2/CMakeLists.txt | 3 + cmake/Dependencies.cmake | 4 + docs/source/distributed.elastic.rst | 1 + docs/source/elastic/control_plane.rst | 10 + .../distributed/elastic/test_control_plane.py | 86 +++++++++ third_party/cpp-httplib.BUILD | 10 + torch/CMakeLists.txt | 2 + torch/_C/_distributed_c10d.pyi | 4 + .../distributed/c10d/ProcessGroupNCCL.cpp | 8 + .../c10d/control_plane/Handlers.cpp | 75 ++++++++ .../c10d/control_plane/Handlers.hpp | 67 +++++++ .../c10d/control_plane/WorkerServer.cpp | 178 ++++++++++++++++++ .../c10d/control_plane/WorkerServer.hpp | 28 +++ torch/csrc/distributed/c10d/init.cpp | 12 ++ torch/distributed/elastic/control_plane.py | 51 +++++ 18 files changed, 548 insertions(+) create mode 100644 docs/source/elastic/control_plane.rst create mode 100644 test/distributed/elastic/test_control_plane.py create mode 100644 third_party/cpp-httplib.BUILD create mode 100644 torch/csrc/distributed/c10d/control_plane/Handlers.cpp create mode 100644 torch/csrc/distributed/c10d/control_plane/Handlers.hpp create mode 100644 torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp create mode 100644 torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp create mode 100644 torch/distributed/elastic/control_plane.py diff --git a/BUILD.bazel b/BUILD.bazel index 71ebc296598c..b58fb57199f3 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -769,6 +769,7 @@ cc_library( ":caffe2", ":torch_headers", "@kineto", + "@cpp-httplib", ] + if_cuda([ "@cuda//:nvToolsExt", "@cutlass", diff --git a/WORKSPACE b/WORKSPACE index 5b4f2f2e3375..4169e0dbce1d 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -168,6 +168,12 @@ new_local_repository( path = "third_party/opentelemetry-cpp", ) +new_local_repository( + name = "cpp-httplib", + build_file = "//third_party:cpp-httplib.BUILD", + path = "third_party/cpp-httplib", +) + new_local_repository( name = "tensorpipe", build_file = "//third_party:tensorpipe.BUILD", diff --git a/build_variables.bzl b/build_variables.bzl index 8b5ac4f46d7c..20822ba95cf2 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -515,6 +515,8 @@ libtorch_distributed_base_sources = [ "torch/csrc/distributed/c10d/sequence_num.cpp", "torch/csrc/distributed/c10d/socket.cpp", "torch/csrc/distributed/c10d/Work.cpp", + "torch/csrc/distributed/c10d/control_plane/Handlers.cpp", + "torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp", ] # These files are only supported on Linux (and others) but not on Windows. diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index e9b2b20ce6ad..fe24571b66a5 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1179,6 +1179,9 @@ if(USE_KINETO) ${TORCH_ROOT}/third_party/kineto/libkineto/src) endif() +target_include_directories(torch_cpu PRIVATE + ${TORCH_ROOT}/third_party/cpp-httplib) + install(DIRECTORY "${TORCH_SRC_DIR}/csrc" DESTINATION ${TORCH_INSTALL_INCLUDE_DIR}/torch FILES_MATCHING PATTERN "*.h" PATTERN "*.hpp") diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 8c7751f4c07b..9693ac6e9fe6 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -1681,3 +1681,7 @@ endif() # Include google/FlatBuffers include(${CMAKE_CURRENT_LIST_DIR}/FlatBuffers.cmake) + +# Include cpp-httplib +add_library(httplib INTERFACE IMPORTED) +target_include_directories(httplib SYSTEM INTERFACE ${PROJECT_SOURCE_DIR}/third_party/cpp-httplib) diff --git a/docs/source/distributed.elastic.rst b/docs/source/distributed.elastic.rst index 24d33d1982df..0aabb560c9c8 100644 --- a/docs/source/distributed.elastic.rst +++ b/docs/source/distributed.elastic.rst @@ -29,6 +29,7 @@ Documentation elastic/metrics elastic/events elastic/subprocess_handler + elastic/control_plane .. toctree:: :maxdepth: 1 diff --git a/docs/source/elastic/control_plane.rst b/docs/source/elastic/control_plane.rst new file mode 100644 index 000000000000..c37454cf1b0a --- /dev/null +++ b/docs/source/elastic/control_plane.rst @@ -0,0 +1,10 @@ +Control Plane +============= + +.. automodule:: torch.distributed.elastic.control_plane +.. currentmodule:: torch.distributed.elastic.control_plane + +This module contains optional helpers that add extra debug and control handlers +into your application. + +.. autofunction:: torch.distributed.elastic.control_plane.worker_main diff --git a/test/distributed/elastic/test_control_plane.py b/test/distributed/elastic/test_control_plane.py new file mode 100644 index 000000000000..c9ae512f2718 --- /dev/null +++ b/test/distributed/elastic/test_control_plane.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python3 +# Owner(s): ["oncall: distributed"] + +import json +import os +import pickle +import socket +import tempfile +from contextlib import contextmanager + +from urllib3.connection import HTTPConnection +from urllib3.connectionpool import HTTPConnectionPool + +from torch.distributed.elastic.control_plane import ( + TORCH_WORKER_SERVER_SOCKET, + worker_main, +) +from torch.testing._internal.common_utils import requires_cuda, run_tests, TestCase + + +class UnixHTTPConnection(HTTPConnection): + def __init__(self, socket_path: str) -> None: + super().__init__("localhost") + + self.socket_path = socket_path + + def connect(self) -> None: + self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self.sock.connect(self.socket_path) + + +class UnixHTTPConnectionPool(HTTPConnectionPool): + def __init__(self, socket_path: str) -> None: + super().__init__("localhost") + + self.socket_path = socket_path + + def _new_conn(self): + return UnixHTTPConnection(self.socket_path) + + +@contextmanager +def local_worker_server() -> None: + with tempfile.TemporaryDirectory() as tmpdir: + socket_path = os.path.join(tmpdir, "socket.sock") + os.environ[TORCH_WORKER_SERVER_SOCKET] = socket_path + + with worker_main(): + pool = UnixHTTPConnectionPool(socket_path) + yield pool + + +class WorkerServerTest(TestCase): + def test_worker_server(self) -> None: + with local_worker_server() as pool: + resp = pool.request("GET", "/") + self.assertEqual(resp.status, 200) + self.assertEqual( + resp.data, + b"""

torch.distributed.WorkerServer

+Handler names +""", + ) + + resp = pool.request("POST", "/handler/ping") + self.assertEqual(resp.status, 200) + self.assertEqual(resp.data, b"pong") + + resp = pool.request("GET", "/handler/") + self.assertEqual(resp.status, 200) + self.assertIn("ping", json.loads(resp.data)) + + resp = pool.request("POST", "/handler/nonexistant") + self.assertEqual(resp.status, 404) + self.assertIn(b"Handler nonexistant not found:", resp.data) + + @requires_cuda + def test_dump_nccl_trace_pickle(self) -> None: + with local_worker_server() as pool: + resp = pool.request("POST", "/handler/dump_nccl_trace_pickle") + self.assertEqual(resp.status, 200) + out = pickle.loads(resp.data) + + +if __name__ == "__main__": + run_tests() diff --git a/third_party/cpp-httplib.BUILD b/third_party/cpp-httplib.BUILD new file mode 100644 index 000000000000..3cd0c3dbe94b --- /dev/null +++ b/third_party/cpp-httplib.BUILD @@ -0,0 +1,10 @@ +load("@rules_cc//cc:defs.bzl", "cc_library") + +cc_library( + name = "cpp-httplib", + hdrs = ["httplib.h"], + includes = [ + "/", + ], + visibility = ["//visibility:public"], +) diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index fa62688c7e86..10a44af747be 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -68,6 +68,7 @@ set(TORCH_PYTHON_INCLUDE_DIRECTORIES ${TORCH_ROOT}/third_party/onnx ${TORCH_ROOT}/third_party/flatbuffers/include ${TORCH_ROOT}/third_party/kineto/libkineto/include + ${TORCH_ROOT}/third_party/cpp-httplib ${TORCH_SRC_DIR}/csrc ${TORCH_SRC_DIR}/csrc/api/include @@ -80,6 +81,7 @@ set(TORCH_PYTHON_LINK_LIBRARIES Python::Module pybind::pybind11 opentelemetry::api + httplib shm fmt::fmt-header-only ATEN_CPU_FILES_GEN_LIB) diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index 1a3e4ea63342..dab215d396ce 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -94,6 +94,10 @@ class Logger: def _set_uneven_input_join(self) -> None: ... def _set_static_graph(self) -> None: ... +class _WorkerServer: + def __init__(self, socket_path: str) -> None: ... + def shutdown(self) -> None: ... + def get_debug_level(): ... def set_debug_level(): ... def set_debug_level_from_env(): ... diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 1c0bdc43be35..2e55bfdb6f34 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -28,6 +28,7 @@ #include #include #include +#include #include #include @@ -369,6 +370,13 @@ std::string dump_nccl_trace() { } #endif +// TODO(c-p-i-o): add a JSON endpoint. +control_plane::RegisterHandler dumpHandler{ + "dump_nccl_trace_pickle", + [](const control_plane::Request&, control_plane::Response& res) { + res.setContent(dump_nccl_trace(), "application/octet-stream"); + }}; + std::optional)>>& get_cpp_trace_dumper() { static std::optional< diff --git a/torch/csrc/distributed/c10d/control_plane/Handlers.cpp b/torch/csrc/distributed/c10d/control_plane/Handlers.cpp new file mode 100644 index 000000000000..e29f1e3a2ac3 --- /dev/null +++ b/torch/csrc/distributed/c10d/control_plane/Handlers.cpp @@ -0,0 +1,75 @@ +#include + +#include +#include +#include +#include + +namespace c10d { +namespace control_plane { + +namespace { + +class HandlerRegistry { + public: + void registerHandler(const std::string& name, HandlerFunc f) { + std::unique_lock lock(handlersMutex_); + + if (handlers_.find(name) != handlers_.end()) { + throw std::runtime_error( + fmt::format("Handler {} already registered", name)); + } + + handlers_[name] = f; + } + + HandlerFunc getHandler(const std::string& name) { + std::shared_lock lock(handlersMutex_); + + auto it = handlers_.find(name); + if (it == handlers_.end()) { + throw std::runtime_error(fmt::format("Failed to find handler {}", name)); + } + return handlers_[name]; + } + + std::vector getHandlerNames() { + std::shared_lock lock(handlersMutex_); + + std::vector names; + for (const auto& [name, _] : handlers_) { + names.push_back(name); + } + return names; + } + + private: + std::shared_mutex handlersMutex_{}; + std::unordered_map handlers_{}; +}; + +HandlerRegistry& getHandlerRegistry() { + static HandlerRegistry registry; + return registry; +} + +RegisterHandler pingHandler{"ping", [](const Request&, Response& res) { + res.setContent("pong", "text/plain"); + }}; + +} // namespace + +void registerHandler(const std::string& name, HandlerFunc f) { + return getHandlerRegistry().registerHandler(name, f); +} + +HandlerFunc getHandler(const std::string& name) { + return getHandlerRegistry().getHandler(name); +} + +std::vector getHandlerNames() { + return getHandlerRegistry().getHandlerNames(); +} + +} // namespace control_plane +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/control_plane/Handlers.hpp b/torch/csrc/distributed/c10d/control_plane/Handlers.hpp new file mode 100644 index 000000000000..0c1063054931 --- /dev/null +++ b/torch/csrc/distributed/c10d/control_plane/Handlers.hpp @@ -0,0 +1,67 @@ +#pragma once + +#include +#include + +#include + +namespace c10d { +namespace control_plane { + +// Request represents a request to the handler. This conceptually maps to an +// HTTP request but could be called via other transports. +class TORCH_API Request { + public: + virtual ~Request() = default; + + virtual const std::string& body() = 0; +}; + +// Response represents a response to the handler. This conceptually maps to an +// HTTP response but could be called via other transports. +class TORCH_API Response { + public: + virtual ~Response() = default; + + // Set the response body to the provided string. + // TODO: add support for chunked responses + virtual void setContent( + std::string&& content, + const std::string& content_type) = 0; + + // Set the response status code. + // These should match standard HTTP status codes. + virtual void setStatus(int status) = 0; +}; + +using HandlerFunc = std::function; + +// Registers a handler. The name needs to be unique and can be called by using +// getHandler directly or via WorkerServer for remote requests. +// These handlers are called from a background C++ thread concurrently with the +// main thread. These handlers need to be thread safe and not cause issues +// during Python training. +TORCH_API void registerHandler(const std::string& name, HandlerFunc f); + +// Fetches a handler by name. +TORCH_API HandlerFunc getHandler(const std::string& name); + +TORCH_API std::vector getHandlerNames(); + +// Registers a handler statically. +// See registerHandler for more details. +class TORCH_API RegisterHandler { + public: + RegisterHandler(const std::string& name, HandlerFunc f) { + registerHandler(name, f); + } + + // disable move, copy + RegisterHandler(const RegisterHandler&) = delete; + RegisterHandler(RegisterHandler&&) = delete; + RegisterHandler& operator=(const RegisterHandler&) = delete; + RegisterHandler& operator=(RegisterHandler&&) = delete; +}; + +} // namespace control_plane +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp b/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp new file mode 100644 index 000000000000..14d287e9607f --- /dev/null +++ b/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp @@ -0,0 +1,178 @@ +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace c10d { +namespace control_plane { + +namespace { +class RequestImpl : public Request { + public: + RequestImpl(const httplib::Request& req) : req_(req) {} + + const std::string& body() override { + return req_.body; + } + + private: + const httplib::Request& req_; +}; + +class ResponseImpl : public Response { + public: + ResponseImpl(httplib::Response& res) : res_(res) {} + + void setStatus(int status) override { + res_.status = status; + } + + void setContent(std::string&& content, const std::string& content_type) + override { + res_.set_content(std::move(content), content_type); + } + + private: + httplib::Response& res_; +}; + +std::string jsonStrEscape(const std::string& str) { + std::ostringstream ostream; + for (char ch : str) { + if (ch == '"') { + ostream << "\\\""; + } else if (ch == '\\') { + ostream << "\\\\"; + } else if (ch == '\b') { + ostream << "\\b"; + } else if (ch == '\f') { + ostream << "\\f"; + } else if (ch == '\n') { + ostream << "\\n"; + } else if (ch == '\r') { + ostream << "\\r"; + } else if (ch == '\t') { + ostream << "\\t"; + } else if ('\x00' <= ch && ch <= '\x1f') { + ostream << "\\u" << std::hex << std::setw(4) << std::setfill('0') + << static_cast(ch); + } else { + ostream << ch; + } + } + return ostream.str(); +} +} // namespace + +WorkerServer::WorkerServer(const std::string& socketFile) { + // using unix sockets + server_.set_address_family(AF_UNIX); + + // adjust keep alives as it stops the server from shutting down quickly + server_.set_keep_alive_timeout(1); // second, default is 5 + server_.set_keep_alive_max_count( + 30); // wait max 30 seconds before closing socket + + server_.Get("/", [](const httplib::Request& req, httplib::Response& res) { + res.set_content( + R"BODY(

torch.distributed.WorkerServer

+Handler names +)BODY", + "text/html"); + }); + server_.Get( + "/handler/", [](const httplib::Request& req, httplib::Response& res) { + std::ostringstream body; + body << "["; + bool first = true; + for (const auto& name : getHandlerNames()) { + if (!first) { + body << ","; + } + first = false; + + body << "\"" << jsonStrEscape(name) << "\""; + } + body << "]"; + + res.set_content(body.str(), "application/json"); + }); + server_.Post( + "/handler/:handler", + [](const httplib::Request& req, httplib::Response& res) { + auto handler_name = req.path_params.at("handler"); + HandlerFunc handler; + try { + handler = getHandler(handler_name); + } catch (const std::exception& e) { + res.status = 404; + res.set_content( + fmt::format("Handler {} not found: {}", handler_name, e.what()), + "text/plain"); + return; + } + RequestImpl torchReq{req}; + ResponseImpl torchRes{res}; + + try { + handler(torchReq, torchRes); + } catch (const std::exception& e) { + res.status = 500; + res.set_content( + fmt::format("Handler {} failed: {}", handler_name, e.what()), + "text/plain"); + return; + } catch (...) { + res.status = 500; + res.set_content( + fmt::format( + "Handler {} failed with unknown exception", handler_name), + "text/plain"); + return; + } + }); + + if (std::filesystem::exists(socketFile)) { + throw std::runtime_error(fmt::format("{} already exists", socketFile)); + } + + C10D_WARNING("Server listening to {}", socketFile); + if (!server_.bind_to_port(socketFile, 80)) { + throw std::runtime_error(fmt::format("Error binding to {}", socketFile)); + } + + serverThread_ = std::thread([this]() { + try { + if (!server_.listen_after_bind()) { + throw std::runtime_error("failed to listen"); + } + } catch (std::exception& e) { + C10D_ERROR("Error while running server: {}", e.what()); + throw; + } + C10D_WARNING("Server exited"); + }); +} + +void WorkerServer::shutdown() { + C10D_WARNING("Server shutting down"); + server_.stop(); + serverThread_.join(); +} + +WorkerServer::~WorkerServer() { + if (serverThread_.joinable()) { + C10D_WARNING("WorkerServer destructor called without shutdown"); + shutdown(); + } +} + +} // namespace control_plane +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp b/torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp new file mode 100644 index 000000000000..7d64038f0b01 --- /dev/null +++ b/torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp @@ -0,0 +1,28 @@ +#pragma once + +#include +#include +#include + +#include + +#include +#include + +namespace c10d { +namespace control_plane { + +class TORCH_API WorkerServer : public c10::intrusive_ptr_target { + public: + WorkerServer(const std::string& socketFile); + ~WorkerServer(); + + void shutdown(); + + private: + httplib::Server server_; + std::thread serverThread_; +}; + +} // namespace control_plane +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 6f6dae326065..c4b9a9823c84 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #ifndef _WIN32 #include @@ -3164,6 +3165,17 @@ such as `dist.all_reduce(tensor, async_op=True)`. return py::bytes(::c10d::dump_nccl_trace()); }); #endif + + intrusive_ptr_class_<::c10d::control_plane::WorkerServer>( + module, "_WorkerServer", R"( +)") + .def( + py::init([](const std::string& socketPath) { + return c10::make_intrusive<::c10d::control_plane::WorkerServer>( + socketPath); + }), + py::arg("socket_path")) + .def("shutdown", &::c10d::control_plane::WorkerServer::shutdown); Py_RETURN_TRUE; } diff --git a/torch/distributed/elastic/control_plane.py b/torch/distributed/elastic/control_plane.py new file mode 100644 index 000000000000..160383637865 --- /dev/null +++ b/torch/distributed/elastic/control_plane.py @@ -0,0 +1,51 @@ +import os +from contextlib import contextmanager, ExitStack +from typing import Generator + +from torch.distributed.elastic.multiprocessing.errors import record + +__all__ = [ + "worker_main", +] + +TORCH_WORKER_SERVER_SOCKET = "TORCH_WORKER_SERVER_SOCKET" + + +@contextmanager +def _worker_server(socket_path: str) -> Generator[None, None, None]: + from torch._C._distributed_c10d import _WorkerServer + + server = _WorkerServer(socket_path) + try: + yield + finally: + server.shutdown() + + +@contextmanager +@record +def worker_main() -> Generator[None, None, None]: + """ + This is a context manager that wraps your main entry function. This combines + the existing ``errors.record`` logic as well as a new ``_WorkerServer`` that + exposes handlers via a unix socket specified by + ``Torch_WORKER_SERVER_SOCKET``. + + Example + + :: + + @worker_main() + def main(): + pass + + if __name__=="__main__": + main() + + """ + with ExitStack() as stack: + socket_path = os.environ.get(TORCH_WORKER_SERVER_SOCKET) + if socket_path is not None: + stack.enter_context(_worker_server(socket_path)) + + yield From 6dc0a291b9bf27aa7258866591f20ed246acb81c Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 4 Jun 2024 20:51:41 +0000 Subject: [PATCH 337/706] Revert "[dynamo] Bugfix for nn parameter construction (#127806)" This reverts commit f27c4dd862bf79f37019ef277957cd577d57b66f. Reverted https://github.com/pytorch/pytorch/pull/127806 on behalf of https://github.com/PaliC due to causing nn tests to fail ([comment](https://github.com/pytorch/pytorch/pull/127806#issuecomment-2148393903)) --- torch/_dynamo/variables/torch.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index f8444a279aeb..fbfabb5fdf06 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -896,13 +896,6 @@ def call_nn_parameter(cls, tx, data=None, requires_grad=True): ) assert isinstance(result, variables.TensorVariable) result.class_type = torch.nn.Parameter - - # TODO(jansel/bdhirsh) - There is some issue with - # tracable_create_paramter. It does not seem to use the right - # grad_enabled. Since this is parameter, we can just override the - # has_grad_fn field to False to workaround the issue. - result.has_grad_fn = False - # In reconstruct() should use the original parameter. The one returned by the graph will be an alias. result.source = placeholder.source From 1b704a160f2e055b8e3d9634433fddd8bf34ff18 Mon Sep 17 00:00:00 2001 From: Ting Lu Date: Tue, 4 Jun 2024 20:51:41 +0000 Subject: [PATCH 338/706] Add linker script optimization flag to CMAKE rule for CUDA ARM wheel (#127514) Original PR - https://github.com/pytorch/pytorch/pull/127220 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127514 Approved by: https://github.com/Aidyn-A, https://github.com/atalman --- CMakeLists.txt | 2 ++ cmake/Summary.cmake | 3 +++ 2 files changed, 5 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 998073bc72b3..b36f52c93fe0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -540,6 +540,8 @@ option(BUILD_EXECUTORCH "Master flag to build Executorch" ON) if(LINUX) set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,--no-as-needed") + set(CMAKE_SHARED_LINKER_FLAGS + "${CMAKE_SHARED_LINKER_FLAGS} $ENV{LDFLAGS}") endif() if(MSVC) diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake index bc15f70ad1f5..289419c38603 100644 --- a/cmake/Summary.cmake +++ b/cmake/Summary.cmake @@ -14,6 +14,9 @@ function(caffe2_print_configuration_summary) message(STATUS " Found ccache : ${CCACHE_PROGRAM}") endif() message(STATUS " CXX flags : ${CMAKE_CXX_FLAGS}") + message(STATUS " Shared LD flags : ${CMAKE_SHARED_LINKER_FLAGS}") + message(STATUS " Static LD flags : ${CMAKE_STATIC_LINKER_FLAGS}") + message(STATUS " Module LD flags : ${CMAKE_MODULE_LINKER_FLAGS}") message(STATUS " Build type : ${CMAKE_BUILD_TYPE}") get_directory_property(tmp DIRECTORY ${PROJECT_SOURCE_DIR} COMPILE_DEFINITIONS) message(STATUS " Compile definitions : ${tmp}") From a7b1dd82ff3063894fc665ab0c424815231c10e6 Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Wed, 29 May 2024 17:14:06 -0700 Subject: [PATCH 339/706] Default XLA to use swap_tensors path in nn.Module._apply (#126814) Pull Request resolved: https://github.com/pytorch/pytorch/pull/126814 Approved by: https://github.com/JackCaoG, https://github.com/albanD ghstack dependencies: #127313 --- test/test_nn.py | 4 ++-- torch/nn/modules/module.py | 10 ++++++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/test/test_nn.py b/test/test_nn.py index 6dfac4f7ca1b..6bcb4017e4b5 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -8184,9 +8184,9 @@ def test_batchnorm_large_batch(self, device, dtype): @dtypes(torch.float, torch.double, torch.bfloat16, torch.complex128) def test_conv_empty_input(self, device, dtype): def help(input, conv, memory_format): - ref_out = conv(input) + ref_out = conv(input).detach() conv_cl = conv.to(memory_format=memory_format) - out_cl = conv_cl(input) + out_cl = conv_cl(input).detach() self.assertEqual(ref_out, out_cl) input_cl = input.to(memory_format=memory_format) out_cl2 = conv(input_cl) diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index ffd429cc06f2..3d683cb82181 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -794,6 +794,13 @@ def compute_should_use_set_data(tensor, tensor_applied): should_use_swap_tensors = torch.__future__.get_swap_module_params_on_conversion() + def compute_should_use_swap_tensors(tensor, tensor_applied): + return (should_use_swap_tensors + # subclasses may have multiple child tensors so we need to use swap_tensors + or is_traceable_wrapper_subclass(tensor_applied) + or tensor.device.type == 'xla' + or tensor_applied.device.type == 'xla') + for key, param in self._parameters.items(): if param is None: continue @@ -804,8 +811,7 @@ def compute_should_use_set_data(tensor, tensor_applied): param_applied = fn(param) p_should_use_set_data = compute_should_use_set_data(param, param_applied) - # subclasses may have multiple child tensors so we need to use swap_tensors - p_should_use_swap_tensors = should_use_swap_tensors or is_traceable_wrapper_subclass(param_applied) + p_should_use_swap_tensors = compute_should_use_swap_tensors(param, param_applied) param_grad = param.grad if p_should_use_swap_tensors: From 20f966a8e0d145b410fac5ab0e050175e5ef4786 Mon Sep 17 00:00:00 2001 From: Svetlana Karslioglu Date: Tue, 4 Jun 2024 22:11:09 +0000 Subject: [PATCH 340/706] Ignore undocumented PipelineSchedule.step (#127955) Ignore undocumented PipelineSchedule.step to fix doc build: https://github.com/pytorch/pytorch/actions/runs/9372492435/job/25805861083?pr=127938#step:11:1284 Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/127955 Approved by: https://github.com/kit1980 --- docs/source/conf.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/conf.py b/docs/source/conf.py index fe548737b313..ef492f17c506 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -645,6 +645,8 @@ "create_workers", "spawn_workers", "worker", + # torch.distributed.pipelining.PipelineSchedule + "step", # torch.distributed.rendezvous "register_rendezvous_handler", "rendezvous", From 0eb9ec958a949baf1733248cdb2ac36d22fe8c1f Mon Sep 17 00:00:00 2001 From: rzou Date: Tue, 4 Jun 2024 08:56:47 -0700 Subject: [PATCH 341/706] Revert "Inductor respects strides for custom ops by default (#126986)" (#127923) This reverts commit dd64ca2a02434944ecbc8f3e186d44ba81e3cb26. There's a silent incorrectness bug with needs_fixed_stride_order=True and mutable custom ops, so it's better to flip the default back to avoid silent incorrectness. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127923 Approved by: https://github.com/williamwen42 --- aten/src/ATen/native/tags.yaml | 10 ----- test/inductor/test_torchinductor.py | 59 +---------------------------- torch/_inductor/graph.py | 17 +-------- torch/_library/custom_ops.py | 2 +- 4 files changed, 4 insertions(+), 84 deletions(-) diff --git a/aten/src/ATen/native/tags.yaml b/aten/src/ATen/native/tags.yaml index 727534c0d347..c31721729036 100644 --- a/aten/src/ATen/native/tags.yaml +++ b/aten/src/ATen/native/tags.yaml @@ -46,16 +46,6 @@ desc: | This tag indicates that the operator should be passed Tensors following the same stride permutation as observed in eager when compiled in inductor. - The default for custom ops (i.e. not torch._library.utils.is_builtin) - is that they do need a fixed stride order; add `does_not_need_fixed_stride_order` - to change the behavior. - The default for builtin ops is that they do not need a fixed stride order; - add `needs_fixed_stride_order` to change the behavior. -- tag: does_not_need_fixed_stride_order - desc: | - This tag indicates that the operator doesn't need to be passed Tensors following - the same stride permutation as observed in eager when compiled in inductor. - See `needs_fixed_stride_order` for more details. # NOTE [Core ATen Ops] - tag: core diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 111d0e1ef959..c6928ee37a8e 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -9810,6 +9810,7 @@ def bar_meta(x): bar_cuda, bar_xpu, bar_meta, + tags=[torch._C.Tag.needs_fixed_stride_order], ) def fn(x): @@ -9872,69 +9873,13 @@ def baz_meta(x): baz_cuda, baz_xpu, baz_meta, + tags=[torch._C.Tag.needs_fixed_stride_order], ) with torch.no_grad(): net = torch.compile(model) out = net(input_t) - @requires_gpu() - @config.patch(implicit_fallbacks=True) - def test_needs_fixed_stride_order(self): - with torch.library._scoped_library("prims", "FRAGMENT") as prims_lib: - with torch.library._scoped_library("custom", "FRAGMENT") as custom_lib: - strides = [] - - def foo_impl(x): - strides.append(x.stride()) - return x.clone() - - def foo_meta(x): - return x.clone() - - all_ops = [] - for ( - needs_fixed_stride_order, - does_not_need_fixed_stride_order, - ) in itertools.product([True, False], [True, False]): - tags = [] - if needs_fixed_stride_order: - tags.append(torch.Tag.needs_fixed_stride_order) - if does_not_need_fixed_stride_order: - tags.append(torch.Tag.does_not_need_fixed_stride_order) - name = f"foo_{int(needs_fixed_stride_order)}{int(does_not_need_fixed_stride_order)}" - for ns, lib in {"custom": custom_lib, "prims": prims_lib}.items(): - all_ops.append(ns + "::" + name) - lib.define(f"{name}(Tensor x) -> Tensor", tags=tags) - lib.impl(name, foo_impl, "CompositeExplicitAutograd") - lib.impl(name, foo_meta, "Meta") - - assert len(all_ops) == 8 - expect_contig_strides = { - "custom::foo_01", - "prims::foo_00", - "prims::foo_01", - } - print(all_ops) - - for qualname in all_ops: - ns, name = qualname.split("::") - op = getattr(getattr(torch.ops, ns), name) - - @torch.compile(fullgraph=True) - def f(x): - y = x.t().contiguous().t() - y = y.sin() - return op(y) - - x = torch.randn(24, 24, device=self.device) - f(x) - stride = strides[-1] - if qualname in expect_contig_strides: - self.assertEqual(stride, (24, 1)) - else: - self.assertEqual(stride, (1, 24)) - def test_buffer_use_after_remove(self): # https://github.com/pytorch/pytorch/issues/102857 diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index a22e31baf752..abe93686ac83 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -929,22 +929,7 @@ def get_custom_op_layout_constraints(target, args, kwargs): # which run through implicit fallback must constrain their # arguments' fx strides layout_constraint = None - - def needs_fixed_stride_order(target): - if ( - torch._C.Tag.needs_fixed_stride_order in target.tags - and torch._C.Tag.does_not_need_fixed_stride_order in target.tags - ): - # If both tags were specified, pessimistically assume that we do need it. - return True - if torch._library.utils.is_builtin(target): - return torch._C.Tag.needs_fixed_stride_order in target.tags - else: - return ( - torch._C.Tag.does_not_need_fixed_stride_order not in target.tags - ) - - if needs_fixed_stride_order(target): + if torch._C.Tag.needs_fixed_stride_order in target.tags: # We have to set the current args because call_function will immediately # evaluate this lowering after creating the fallback, without evaluating # the layout constraint diff --git a/torch/_library/custom_ops.py b/torch/_library/custom_ops.py index 3272ffc1a18f..20758d24e37a 100644 --- a/torch/_library/custom_ops.py +++ b/torch/_library/custom_ops.py @@ -453,7 +453,7 @@ def _register_to_dispatcher(self) -> None: lib.define( schema_str, - tags=[_C.Tag.pt2_compliant_tag], + tags=[_C.Tag.pt2_compliant_tag, _C.Tag.needs_fixed_stride_order], ) self._opoverload = _library.utils.lookup_op(self._qualname) From f4b05ce683662f0b50185e9818f8feafa0781aaf Mon Sep 17 00:00:00 2001 From: Jiashen Cao Date: Tue, 4 Jun 2024 22:53:00 +0000 Subject: [PATCH 342/706] Add registry for TorchScript to ExportedProgram conversion (#127464) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127464 Approved by: https://github.com/ydwu4, https://github.com/angelayi --- torch/_export/converter.py | 85 ++++++++++++++++++++------------------ 1 file changed, 44 insertions(+), 41 deletions(-) diff --git a/torch/_export/converter.py b/torch/_export/converter.py index 2c021377cb4f..08decec27fea 100644 --- a/torch/_export/converter.py +++ b/torch/_export/converter.py @@ -41,7 +41,15 @@ def normalize_name(name: str) -> str: return name.replace(".", "_") -# Given a node: torch._C.Node, map from node.kind() to a standard operator +def ir_name_to_func_name(name: str) -> str: + """prim::If -> convert_prim_If""" + name_list = name.split("::") + return "convert_" + "_".join(name_list) + + +# Those operators will be automatically populated to a instance method +# of TS2FXGraphConverter with name convert__(). +# Please check __init__ for method population implementations. kind_to_standard_operators = { "prim::TupleIndex": operator.getitem, "aten::__is__": operator.is_, @@ -89,6 +97,17 @@ def __init__( self.subgraphs: Dict[str, torch.fx.GraphModule] = {} + # Populate methods for the standard operators. + for k in kind_to_standard_operators.keys(): + handler_func_name = ir_name_to_func_name(k) + # Create an indirect function call: + # convert__ --> lambda node: _convert_standard_operator(node) + setattr( + self, + handler_func_name, + lambda node: self._convert_standard_operators(node), + ) + def add_subgraph(self, subgraph) -> str: name = f"subgraph_{len(self.subgraphs)}" self.subgraphs[name] = subgraph @@ -263,7 +282,13 @@ def convert_aten_op(self, node: torch._C.Node): output_name = node.output().debugName() self.name_to_node[output_name] = fx_node + def convert_prim_TupleConstruct(self, node: torch._C.Node): + self._convert_prim_iterator(node) + def convert_prim_ListConstruct(self, node: torch._C.Node): + self._convert_prim_iterator(node) + + def _convert_prim_iterator(self, node: torch._C.Node): output_list = [] for inp in node.inputs(): output_list.append(self.get_fx_value(inp)) @@ -376,7 +401,7 @@ def convert_aten___getitem__(self, node: torch._C.Node): output_name = node.output().debugName() self.name_to_node[output_name] = fx_node - def convert_prim_if(self, node: torch._C.Node): + def convert_prim_If(self, node: torch._C.Node): inputs = list(node.inputs()) assert len(inputs) == 1 predicate = self.get_fx_value(inputs[0]) @@ -429,7 +454,10 @@ def convert_prim_if(self, node: torch._C.Node): output_name = node.output().debugName() self.name_to_node[output_name] = cond_node - def convert_as_noop(self, node: torch._C.Node): + def convert_aten_Bool(self, node: torch._C.Node): + self._convert_as_noop(node) + + def _convert_as_noop(self, node: torch._C.Node): # Converts the node as a no-op by mapping its output node as arg[0] target = get_op_overload(node) @@ -455,7 +483,7 @@ def convert_profiler__record_function_exit(self, node: torch._C.Node): args = tuple(self.get_fx_value(input) for input in node.inputs()) self.fx_graph.call_function(target, args) - def convert_standard_operators(self, node: torch._C.Node): + def _convert_standard_operators(self, node: torch._C.Node): target = kind_to_standard_operators[node.kind()] args = tuple(self.get_fx_value(input) for input in node.inputs()) fx_node = self.fx_graph.call_function(target, args) @@ -464,43 +492,18 @@ def convert_standard_operators(self, node: torch._C.Node): def convert_node(self, node: torch._C.Node): node_kind = node.kind() - if node_kind == "prim::CreateObject": - self.convert_prim_CreateObject(node) - elif node_kind == "prim::Constant": - self.convert_prim_Constant(node) - elif node_kind == "prim::GetAttr": - self.convert_prim_GetAttr(node) - elif node_kind == "prim::NumToTensor": - self.convert_prim_NumToTensor(node) - elif node_kind in {"prim::ListConstruct", "prim::TupleConstruct"}: - # Tuple is just a non-mutable List, so we can handle them together. - self.convert_prim_ListConstruct(node) - elif node_kind == "prim::device": - self.convert_prim_device(node) - elif node_kind == "prim::dtype": - self.convert_prim_dtype(node) - elif node_kind == "prim::DictConstruct": - self.convert_prim_DictConstruct(node) - # elif node_kind == "aten::Int": - # convert_aten_Int(node) - elif node_kind == "aten::_convolution": - self.convert_aten__convolution(node) - elif node_kind == "aten::__getitem__": - self.convert_aten___getitem__(node) - elif node_kind == "aten::div": - self.convert_aten_div(node) - elif node_kind == "prim::If": - self.convert_prim_if(node) - elif node_kind == "aten::Bool": - self.convert_as_noop(node) - elif node_kind == "profiler::_record_function_enter_new": - self.convert_profiler__record_function_enter_new(node) - elif node_kind == "profiler::_record_function_exit": - self.convert_profiler__record_function_exit(node) - elif node_kind in kind_to_standard_operators: - self.convert_standard_operators(node) - elif node_kind.startswith("aten::"): - # order matters! this should be handled after kind_to_standard_operators + node_kind_split = node_kind.split("::") + + # Get handler based on namespace and operator name. + # Provide a default node handler as well in case we don't find + # matching converter for that. + handler_func_name = ir_name_to_func_name(node_kind) + handler_func = getattr(self, handler_func_name, self.convert_default_node) + handler_func(node) + + def convert_default_node(self, node: torch._C.Node): + node_kind = node.kind() + if node_kind.startswith("aten::"): self.convert_aten_op(node) else: raise ValueError(f"Unsupported node kind: {node_kind}") From 907cb28f676a6d3f44d6f3a2503c56888ebecc93 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 4 Jun 2024 23:06:43 +0000 Subject: [PATCH 343/706] Revert "Inductor: Allow small sizes of m for mixed mm autotuning (#127663)" This reverts commit d8d0bf264a736c7fb3cd17799a1c1aba4addf8d9. Reverted https://github.com/pytorch/pytorch/pull/127663 on behalf of https://github.com/soulitzer due to breaks torch ao CI, see: https://github.com/pytorch/pytorch/issues/127924 ([comment](https://github.com/pytorch/pytorch/pull/127663#issuecomment-2148554128)) --- torch/_inductor/kernel/mm.py | 15 +++----------- torch/_inductor/kernel/mm_common.py | 32 +++++------------------------ 2 files changed, 8 insertions(+), 39 deletions(-) diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py index eba1c65702e8..a90fdbfa33d9 100644 --- a/torch/_inductor/kernel/mm.py +++ b/torch/_inductor/kernel/mm.py @@ -26,7 +26,6 @@ from .mm_common import ( addmm_epilogue, int8_mm_configs, - mixed_mm_configs, mm_args, mm_configs, mm_grid, @@ -408,23 +407,15 @@ def tuned_mixed_mm(mat1, mat2, mat2_dtype): # can't use triton kernel unless one of these is true or if running on v100 (numerical issues) skip_triton = ( - mat1.layout.dtype != torch.float32 - and not (mat2.layout.is_contiguous() or mat2.layout.is_transposed()) + mat1.layout.dtype != torch.float32 and not mat2.layout.is_contiguous() ) or _is_sm7x_or_older_gpu(layout.device.index) if inductor_config.force_mixed_mm: choices = [] if not skip_triton: b_prologue_cast_type = f"tl.{mat2_dtype}".replace("torch.", "") - for config in mixed_mm_configs(m, n, k): - # skipping this config because triton crashes on it - # See: https://github.com/triton-lang/triton/issues/2156#issuecomment-1695897424 - if ( - config.kwargs["BLOCK_M"] == 16 - and config.kwargs["BLOCK_K"] == 16 - and config.kwargs["BLOCK_N"] == 64 - ): - continue + has_int8_tensor = _is_int8_mat(mat1) or _is_int8_mat(mat2) + for config in mm_configs(m, n, k, has_int8_tensor=has_int8_tensor): mm_template.maybe_append_choice( choices, input_nodes=(mat1, mat2), diff --git a/torch/_inductor/kernel/mm_common.py b/torch/_inductor/kernel/mm_common.py index 7fa403fe78f1..97741cc0f8eb 100644 --- a/torch/_inductor/kernel/mm_common.py +++ b/torch/_inductor/kernel/mm_common.py @@ -27,10 +27,14 @@ def filtered_configs( n: int, k: int, configs: List[Tuple[int, int, int, int, int]], + has_int8_tensor=False, ): """Heuristic to shrink configs when they are bigger than the input size""" - min_block_size = 16 + # According to https://github.com/openai/triton/issues/2156#issuecomment-1695897424 + # it's safer to use at least [32, 32] block size for int8/uint8 + # tensors + min_block_size = 32 if has_int8_tensor else 16 m = max( next_power_of_2( V.graph.sizevars.size_hint( @@ -162,18 +166,6 @@ def filtered_configs( {"config": (256, 128, 128, 3, 8), "cond": torch.version.hip is None}, ] -# Mixed precision kernel configs for small sizes of m for mm's like (16, 8192) x (8192, 8192). -mixed_mm_kernel_configs_small_m = [ - {"config": (16, 128, 256, 3, 4), "cond": True}, - {"config": (16, 128, 256, 5, 8), "cond": True}, -] - -mixed_mm_kernel_configs = ( - mm_kernel_configs + mixed_mm_kernel_configs_small_m - if inductor_config.max_autotune_gemm_search_space != "EXHAUSTIVE" - else mm_kernel_configs -) - # Create filtered list of configs based on cond evaluation @@ -187,11 +179,6 @@ def filtered_configs( for config in int8_mm_kernel_configs if config["cond"] ) -mixed_mm_platform_configs = tuple( - cast(Tuple[int, int, int, int, int], config["config"]) - for config in mixed_mm_kernel_configs - if config["cond"] -) # On ROCm convert num_stages to 0 to enable software pipelining if torch.version.hip: @@ -203,10 +190,6 @@ def filtered_configs( (config[0], config[1], config[2], 0, config[4]) for config in mm_platform_configs ) - mixed_mm_platform_configs = tuple( - (config[0], config[1], config[2], 0, config[4]) - for config in mixed_mm_platform_configs - ) mm_configs = functools.partial( filtered_configs, @@ -218,11 +201,6 @@ def filtered_configs( configs=int8_platform_configs, ) -mixed_mm_configs = functools.partial( - filtered_configs, - configs=mixed_mm_platform_configs, -) - def mm_grid(m, n, meta): """ From 1f67cfd4372715778ece303762f3a09c8460c130 Mon Sep 17 00:00:00 2001 From: Shunting Zhang Date: Tue, 4 Jun 2024 16:17:12 -0700 Subject: [PATCH 344/706] [inductor] raise tolerance for cspdarknet (#127949) cspdarknet previously is flaky but after https://github.com/pytorch/pytorch/pull/127367 it fails quite stably. It's probably due to small numerical change from the mentioned PR. That PR will let inductor generated different code due to different loop orders. Raise tolerance to pass CI. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127949 Approved by: https://github.com/atalman, https://github.com/nWEIdia, https://github.com/eqy --- .../cu124/dynamic_inductor_timm_training.csv | 2 +- benchmarks/dynamo/timm_models.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_timm_training.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_timm_training.csv index 9443ae8c83a8..ae860db793c9 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_timm_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_timm_training.csv @@ -38,7 +38,7 @@ crossvit_9_240,pass,7 -cspdarknet53,fail_accuracy,7 +cspdarknet53,pass,7 diff --git a/benchmarks/dynamo/timm_models.py b/benchmarks/dynamo/timm_models.py index d5cdc533da43..60a7cc81c06f 100755 --- a/benchmarks/dynamo/timm_models.py +++ b/benchmarks/dynamo/timm_models.py @@ -77,6 +77,7 @@ def pip_install(package): "mobilenetv3_large_100", "sebotnet33ts_256", "selecsls42b", + "cspdarknet53", } REQUIRE_HIGHER_TOLERANCE_FOR_FREEZING = { From 7fdfb88f03215161f1c305d9b45951b0bd9e30d7 Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Mon, 3 Jun 2024 14:56:45 -0700 Subject: [PATCH 345/706] [pipelining] rewrite interleaved 1f1b (#127332) ## Context Interleaved 1F1B has multiple points in the schedule where communication is both criss-crossed across ranks leading to hangs due to 1. looped nature of schedules, 2. batched nature of forward + backward in 1f1b phase. image In the current implementation, it is difficult to fix these hangs since it requires `dist.recv` from a prior point in time, but each rank operates on its own step schedule and does not have knowledge of other ranks operations to perform the `recv` prior to their own `send`. ## New implementation The new implementation is split into 2 parts: 1. Creating the pipeline order. Each rank will create the timestep normalized ordering of all schedule actions across all ranks. This is created once during the initialization of the schedule class. The timestep between each rank is normalized as each rank can only have 1 computation action (forward or backward) during that timestep. image 3. Executing the pipeline order. Once the pipeline order is determined, execution is simple because as each rank will perform its send to its peer (based on whether they did forward and backward). Now that each rank has a global understanding of the schedule, they can check their previous and next neighbor ranks to see if they need to recv any activations/gradients from them. Therefore, during execution, each rank is aligned and executing the same time step. ## Benefits - Implementation is faster since 1f1b computation can now be split up in two time steps, 1 for forward and 1 for backward. - Debugging is easier since we can now determine which timestep each rank is hung on - Testing is easier since we can just validate the pipeline order, without running the schedule. This allows us to test on large amount of ranks without actually needing the GPUs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127332 Approved by: https://github.com/wconstab ghstack dependencies: #127084 --- .../pipelining/PipelineSchedule.py | 374 ++++++++++-------- 1 file changed, 219 insertions(+), 155 deletions(-) diff --git a/torch/distributed/pipelining/PipelineSchedule.py b/torch/distributed/pipelining/PipelineSchedule.py index f7d6d7c1b372..f63f4ed061b2 100644 --- a/torch/distributed/pipelining/PipelineSchedule.py +++ b/torch/distributed/pipelining/PipelineSchedule.py @@ -3,7 +3,8 @@ import logging from abc import ABC, abstractmethod from collections import defaultdict -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from enum import Enum +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union import torch import torch.distributed as dist @@ -26,6 +27,17 @@ logger = logging.getLogger(__name__) +class _ComputationType(Enum): + FORWARD = 1 + BACKWARD = 2 + + +class _Action(NamedTuple): + computation_type: _ComputationType + microbatch_index: int + stage_index: int + + class _PipelineSchedule(ABC): def __init__( self, @@ -649,6 +661,7 @@ def _step_microbatches( class ScheduleInterleaved1F1B(PipelineScheduleMulti): """ The Interleaved 1F1B schedule. + See https://arxiv.org/pdf/2104.04473 for details. Will perform one forward and one backward on the microbatches in steady state and supports multiple stages per rank. When microbatches are ready for multiple local stages, Interleaved 1F1B prioritizes the earlier microbatch @@ -666,9 +679,8 @@ def __init__( # TODO: is this limitation a must? if n_microbatches % self.pp_group_size != 0: raise ValueError( - "Interleaved 1F1B requires the number of microbatches to be a " - f"multiple of the number of pipeline ranks ({self.pp_group_size}), " - f"but got {n_microbatches}." + f"Interleaved 1F1B schedule requires the number of microbatches ({n_microbatches}) \ + to be a multiple of the number of pipeline ranks ({self.pp_group_size})." ) super().__init__( @@ -680,6 +692,130 @@ def __init__( self.n_local_stages = len(stages) self.rank = stages[0].group_rank + self.group = stages[0].group + + # 1. Create the pipeline_order (all ranks do this calculation) + # This will be used to keep track of the current state of the entire pipeline + # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...] + self.pipeline_order: Dict[int, List[Optional[_Action]]] = {} + # ======================================================================== + for rank in range(self.pp_group_size): + rank_ops = self._calculate_single_rank_operations(rank) + self.pipeline_order[rank] = rank_ops + + def _calculate_single_rank_operations(self, rank) -> List[Optional[_Action]]: + def get_rank_warmup_ops(rank): + # Warms up operations for last stage + warmups_ops_last_stage = (self.n_local_stages - 1) * self.pp_group_size + # Increment warmup operations by 2 for each hop away from the last stage + warmup_ops = warmups_ops_last_stage + 2 * ((self.pp_group_size - 1) - rank) + # We cannot have more warmup operations than there are number of microbatches, so cap it there + return min(warmup_ops, self._n_microbatches * self.n_local_stages) + + warmup_ops = get_rank_warmup_ops(rank) + microbatch_ops = self.n_local_stages * self._n_microbatches + # fwd_bwd_ops should encompass the remaining forwards + fwd_bwd_ops = microbatch_ops - warmup_ops + # cooldown_ops should encompass the remaining backwards + cooldown_ops = microbatch_ops - fwd_bwd_ops + # total ops encompass both forward and backward ops + total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops + # warmup_ops + fwd_bwd_ops * 2 + cooldown_ops == microbatch_ops * 2 + + logger.debug( + "rank %s, warmup_ops %s, 1f1b %s, cooldown_ops %s total_ops %s", + rank, + warmup_ops, + fwd_bwd_ops, + cooldown_ops, + total_ops, + ) + + # Calculates the stage index based on step and pp_group_size + def forward_stage_index(step): + # Get the local index from 0 to n_local_stages-1 + local_index = (step // self.pp_group_size) % self.n_local_stages + return (local_index * self.pp_group_size) + rank + + def backward_stage_index(step): + local_index = ( + self.n_local_stages + - 1 + - ((step - warmup_ops) // self.pp_group_size) % self.n_local_stages + ) + return (local_index * self.pp_group_size) + rank + + # Dictionary for tracking {stage index : current microbatch index} + # All stages start with handling microbatch 0 + fwd_stage_mb_index: Dict[int, int] = defaultdict(int) + bwd_stage_mb_index: Dict[int, int] = defaultdict(int) + + # Store the list of operations used for that rank + rank_ops: List[Optional[_Action]] = [] + # Pre-padding, rank starts with no-ops based on the warmup. + for _ in range(rank): + rank_ops.append(None) + + # These are used to calculate the number of slots to fill with no-ops, to account for the delay in warmup + # when we want to wait for the backward to trickle back up and start 1f1b to align all ranks. + # Formula: + # pre-padding + warmup_ops + post_warmup_ops = earliest time step of first backward + # post_warmup_ops = [earliest time step of first backward] - (warmup_ops + pre-padding) + # earliest time step of first backward = [local_stages * group_size + 2 * (group_size - 1 - rank)] + # warmup_ops = calculated above + post_warmup_ops = ( + self.n_local_stages * self.pp_group_size + + 2 * (self.pp_group_size - 1 - rank) + ) - (warmup_ops + rank) + + for op in range(total_ops): + # Warmup phase + if op < warmup_ops: + fwd_stage_index = forward_stage_index(op) + # This will assign the current microbatch index and update it as well + fwd_stage_mb_index[fwd_stage_index] = ( + mb_index := fwd_stage_mb_index[fwd_stage_index] + ) + 1 + rank_ops.append( + _Action(_ComputationType.FORWARD, mb_index, fwd_stage_index) + ) + if op == warmup_ops - 1: + # This is the last step in the warmup phase, so we need to wait for the backward to trickle back up + rank_ops.extend([None] * post_warmup_ops) + # 1F1B Phase (forward and backward) + elif warmup_ops <= op < warmup_ops + fwd_bwd_ops: + fwd_stage_index = forward_stage_index(op) + fwd_stage_mb_index[fwd_stage_index] = ( + fwd_mb_index := fwd_stage_mb_index[fwd_stage_index] + ) + 1 + rank_ops.append( + _Action(_ComputationType.FORWARD, fwd_mb_index, fwd_stage_index) + ) + + bwd_stage_index = backward_stage_index(op) + bwd_stage_mb_index[bwd_stage_index] = ( + bwd_mb_index := bwd_stage_mb_index[bwd_stage_index] + ) + 1 + rank_ops.append( + _Action(_ComputationType.BACKWARD, bwd_mb_index, bwd_stage_index) + ) + # Cooldown phase + else: + # During cooldown phase, we need steps to align with 1f1b happening in other ranks + # TODO: we don't need to always append, after all 1f1b are finished we can stop appending None + rank_ops.append(None) + bwd_stage_index = backward_stage_index(op) + bwd_stage_mb_index[bwd_stage_index] = ( + bwd_mb_index := bwd_stage_mb_index[bwd_stage_index] + ) + 1 + rank_ops.append( + _Action(_ComputationType.BACKWARD, bwd_mb_index, bwd_stage_index) + ) + + # Post padding + for _ in range(self.pp_group_size - rank - 1): + rank_ops.append(None) + return rank_ops def _step_microbatches( self, @@ -689,161 +825,89 @@ def _step_microbatches( losses: Optional[List] = None, ): """ - Operate on the microbatches for interleaved 1f1b schedule (https://arxiv.org/pdf/2104.04473.pdf). + Operate on the microbatches using the interleaved 1f1b schedule. - Highest rank has a warmup (fwd only) count of [len(stages) - 1] * number of PP ranks - and each rank away from highest rank adds 2 warmup steps due to: - - one happened before highest rank's warmup started, - - one waiting for backward result to trickle down from highest rank - - TODO: Interleaved 1F1B does not support using _sorted_batch_p2p() - because it requires recvs and sends from different peers - to execute in the same coalesced operation. As a result, this schedule does + TODO: Interleaved 1F1B does not use sorted_batch_isend_irecv(). As a result, this schedule does not support models with skip connections. """ arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) - # increment warmup_steps by 2 for each hop away - warmup_steps = (self.n_local_stages - 1) * self.pp_group_size - warmup_steps += 2 * ((self.pp_group_size - 1) - self.rank) - warmup_steps = min(warmup_steps, self._n_microbatches * self.n_local_stages) - fwd_bwd_steps = (self.n_local_stages * self._n_microbatches) - warmup_steps - cooldown_steps = (self.n_local_stages * self._n_microbatches) - fwd_bwd_steps - - assert ( - warmup_steps + fwd_bwd_steps * 2 + cooldown_steps - == self.n_local_stages * self._n_microbatches * 2 - ) - total_steps = warmup_steps + fwd_bwd_steps + cooldown_steps - - logger.debug( - f"rank {self.rank}, warmup_steps {warmup_steps}, " # noqa: G004 - f"1f1b {fwd_bwd_steps}, cooldown_steps {cooldown_steps}" - ) - - def forward_stage_local_index(step): - return (step // self.pp_group_size) % self.n_local_stages - - def backward_stage_local_index(step): - return ( - self.n_local_stages - - 1 - - ((step - warmup_steps) // self.pp_group_size) % self.n_local_stages - ) - - fwd_stage_mb_index: Dict[_PipelineStageBase, int] = defaultdict(int) - bwd_stage_mb_index: Dict[_PipelineStageBase, int] = defaultdict(int) - - # Delay send waits - sends_to_wait: List[dist.Work] = [] - - # Store ops (potentially across steps) - ops: List[dist.P2POp] = [] - - # Warmup Phase (forward only) - for step in range(warmup_steps): - fwd_stage = self._stages[forward_stage_local_index(step)] - - # This will assign the current microbatch index and update it for future steps - fwd_stage_mb_index[fwd_stage] = ( - mb_index := fwd_stage_mb_index[fwd_stage] - ) + 1 - - logger.debug( - f"Rank {self.rank}: {step=}, {fwd_stage.stage_index=}, {mb_index=}" # noqa: G004 - ) - - with record_function(f"Forward {step}"): - ops.extend(fwd_stage.get_fwd_recv_ops()) - if ops: - work = _batch_p2p(ops, desc="warmup_pre_fwd") - work.wait() - ops.clear() - - output = fwd_stage.forward_one_chunk(arg_mbs[mb_index], kwarg_mbs[mb_index]) # type: ignore[index] - - ops.extend(fwd_stage.get_fwd_send_ops()) - # If we are right before the fwd-bwd step, then we need to delay the send to the next step, - # This is because fwd-bwd send/recvs among ranks need to be aligned to prevent a hang. - # In the edge cases where there are no fwd_bwds and cooldown is immediate, then no delay is needed - if ops and (step != warmup_steps - 1 or fwd_bwd_steps == 0): - work = _batch_p2p(ops, desc="warmup_post_fwd") - sends_to_wait.append(work) - ops.clear() - - self._maybe_compute_loss(fwd_stage, output, target_mbs, mb_index) - - # 1F1B Phase (forward and backward) - for step in range(warmup_steps, warmup_steps + fwd_bwd_steps): - fwd_stage = self._stages[forward_stage_local_index(step)] - bwd_stage = self._stages[backward_stage_local_index(step)] - - fwd_stage_mb_index[fwd_stage] = ( - fwd_mb_index := fwd_stage_mb_index[fwd_stage] - ) + 1 - bwd_stage_mb_index[bwd_stage] = ( - bwd_mb_index := bwd_stage_mb_index[bwd_stage] - ) + 1 - - bwd_stage._configure_data_parallel_mode( - bwd_mb_index == self._n_microbatches - 1 - ) - logger.debug( - f"Rank {self.rank}: {step=}, {fwd_stage.stage_index=}, " # noqa: G004 - f"{bwd_stage.stage_index=}, {fwd_mb_index=}, {bwd_mb_index=}" - ) - desc = f"1F1B {step}" - with record_function(desc): - ops.extend(fwd_stage.get_fwd_recv_ops()) - ops.extend(bwd_stage.get_bwd_recv_ops()) - if ops: - work = _batch_p2p(ops, desc=desc) - work.wait() - ops.clear() - - # Forward - output = fwd_stage.forward_one_chunk(arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]) # type: ignore[index] - ops.extend(fwd_stage.get_fwd_send_ops()) - self._maybe_compute_loss(fwd_stage, output, target_mbs, fwd_mb_index) - - # Backward - loss = self._maybe_get_loss(bwd_stage, bwd_mb_index) - bwd_stage.backward_one_chunk(loss=loss) - ops.extend(bwd_stage.get_bwd_send_ops()) - - # Cooldown Phase (backward only) - for step in range(warmup_steps + fwd_bwd_steps, total_steps): - bwd_stage = self._stages[backward_stage_local_index(step)] - bwd_stage_mb_index[bwd_stage] = ( - bwd_mb_index := bwd_stage_mb_index[bwd_stage] - ) + 1 - bwd_stage._configure_data_parallel_mode( - bwd_mb_index == self._n_microbatches - 1 - ) - - logger.debug( - f"Rank {self.rank}: {step=}, {bwd_stage.stage_index=}, {bwd_mb_index=}" # noqa: G004 - ) - desc = f"Cooldown {step}" - with record_function(desc): - ops.extend(bwd_stage.get_bwd_recv_ops()) - if ops: - work = _batch_p2p(ops, desc=desc + " pre_bwd") - work.wait() - ops.clear() - - loss = self._maybe_get_loss(bwd_stage, bwd_mb_index) - bwd_stage.backward_one_chunk(loss=loss) - - ops.extend(bwd_stage.get_bwd_send_ops()) - if ops: - work = _batch_p2p(ops, desc=desc + " post_bwd") - sends_to_wait.append(work) - ops.clear() - - # Make sure all sends are finished - for work in sends_to_wait: - work.wait() - + # Based on the plan in Step 1 created in __init__: + # 2. Perform communication based on the pipeline_order + stage_index_to_stage: Dict[int, _PipelineStageBase] = { + stage.stage_index: stage for stage in self._stages + } + prev_rank: int = (self.rank - 1) % self.pp_group_size + next_rank: int = (self.rank + 1) % self.pp_group_size + + for time_step, action in enumerate(self.pipeline_order[self.rank]): + prev_rank_ops = self.pipeline_order[prev_rank] + next_rank_ops = self.pipeline_order[next_rank] + ops: List[dist.P2POp] = [] + if action is not None: + computation_type, mb_index, stage_index = action + if computation_type == _ComputationType.FORWARD: + # perform forward computation + stage = stage_index_to_stage[stage_index] + output = stage.forward_one_chunk( + arg_mbs[mb_index], kwarg_mbs[mb_index] + ) + self._maybe_compute_loss(stage, output, target_mbs, mb_index) + ops.extend(stage.get_fwd_send_ops()) + elif computation_type == _ComputationType.BACKWARD: + # perform backward computation + stage = stage_index_to_stage[stage_index] + stage._configure_data_parallel_mode( + mb_index == self._n_microbatches - 1 + ) + loss = self._maybe_get_loss(stage, mb_index) + stage.backward_one_chunk(loss=loss) + ops.extend(stage.get_bwd_send_ops()) + else: + raise ValueError(f"Unknown computation type {computation_type}") + + # Look at the neighboring ranks for this current timestep and determine whether + # this current rank needs to do any recv communication + prev_rank_action = None + if time_step < len(prev_rank_ops): + prev_rank_action = prev_rank_ops[time_step] + if prev_rank_action is not None: + computation_type, mb_index, stage_index = prev_rank_action + # Only handle sends for the forward from a previous rank + if computation_type == _ComputationType.FORWARD: + # If not the last stage, then receive fwd activations + if stage_index != self._num_stages - 1: + # TODO: We are assuming that stage will always receive from stage-1 + # however that is not necessarily true of get_fwd_recv_ops + stage = stage_index_to_stage[stage_index + 1] + ops.extend(stage.get_fwd_recv_ops()) + elif computation_type == _ComputationType.BACKWARD: + # Previous rank doing backward has no influence for the current rank forward recv + pass + else: + raise ValueError(f"Unknown computation type {computation_type}") + + next_rank_action = None + if time_step < len(next_rank_ops): + next_rank_action = next_rank_ops[time_step] + if next_rank_action is not None: + computation_type, mb_index, stage_index = next_rank_action + # Only handle receives for the backwards from a next rank + if computation_type == _ComputationType.FORWARD: + # Next rank doing forward has no influence for the current rank backward recv + pass + elif computation_type == _ComputationType.BACKWARD: + # If not the first stage, then receive bwd gradients + if stage_index != 0: + # TODO: We are assuming that stage will always receive from stage+1 + # however that is not necessarily true of get_bwd_recv_ops + stage = stage_index_to_stage[stage_index - 1] + ops.extend(stage.get_bwd_recv_ops()) + else: + raise ValueError(f"Unknown computation type {computation_type}") + + # do the communication + if ops: + _batch_p2p(ops).wait() # Return losses if there is a container passed in self._update_losses(self._stages, losses) From 8830b812081150be7e27641fb14be31efbf7dc1e Mon Sep 17 00:00:00 2001 From: Cory Modlin Date: Wed, 5 Jun 2024 00:19:52 +0000 Subject: [PATCH 346/706] [c10d] Add commCreateFromRanks to c10d (#127421) (#127982) This is a duplicate of: https://github.com/pytorch/pytorch/pull/127421 which we can't merge. its landed internally already Summary: `ncclCommCreateFromRanks` - described in this [document](https://docs.google.com/document/d/1QIRkAO4SAQ6eFBpxE51JmRKRAH2bwAHn8OIj69XuFqQ/edit#heading=h.5g71oqe3soez), replaces `ncclCommSplit` in NCCLX versions 2.21.5+. The difference is that `ncclCommCreateFromRanks` is given a list of active ranks and is collective only over those ranks as opposed to `ncclCommSplit` for which you give it a color for every rank including NO_COLOR for inactive ranks and the collective is over the entire world. This diff connects `ncclCommCreateFromRanks` to `c10d` `ncclCommSplit` will still be available at the NCCL API but, in this diff, is not used starting at version 2.21.5 Split the python test and implementation of `split()` for internal FB and external OSS builds. The diff defines `"USE_C10D_NCCL_FBCODE"` as a compiler option. When defined, we use the version of split in the newly created `NCCLUtils.cpp` in the `fb` directory. The `fb` directory is not *shipit*-ed to *github*. The same API is used for `split()` in both the `ncclx` and `nccl` versions adding `ranks` to the API. This argument is not used in the `nccl` version nor in the 2.18 `ncclx` version where `ncclCommSplit()` is used instead of `ncclCommCreateFromRanks()` in `ncclx` This diff was squashed with D57343946 - see D57343946 for additional review comments. Test Plan: for 2.18.3-1 and 2.21.5-1 versions: ``` buck2 run fbcode//mode/opt -c param.use_nccl=True -c fbcode.nvcc_arch=a100 -c hpc_comms.use_ncclx="$VERSION" -c fbcode.enable_gpu_sections=true fbcode//caffe2/test/distributed/fb:test_comm_split_subgroup_x ``` ``` BUILD SUCCEEDED ... ok ---------------------------------------------------------------------- Ran 1 test in 10.210s OK ~/scripts ``` OSS build: `[cmodlin@devgpu003.vll5 ~/fbsource/third-party/ncclx/v2.21.5-1 (e56338cfa)]$ ./maint/oss_build.sh` OSS build output: ``` ... ncclCommHash 197dce9b413e2775 nccl commDesc example_pg Dump from comm 0x4708aa0 rings: [[0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0]] Dump from comm 0x4708aa0 commDesc: example_pg Dump from comm 0x4708aa0 nRanks: 1 Dump from comm 0x4708aa0 nNodes: 1 Dump from comm 0x4708aa0 node: 0 Dump from comm 0x4708aa0 localRanks: 1 Dump from comm 0x4708aa0 localRank: 0 Dump from comm 0x4708aa0 rank: 0 Dump from comm 0x4708aa0 commHash: "197dce9b413e2775" 2024-05-24T09:02:54.385543 devgpu003:3040664:3040744 [0][AsyncJob]ctran/backends/ib/CtranIb.cc:143 NCCL WARN CTRAN-IB : No active device found. 2024-05-24T09:02:54.385607 devgpu003:3040664:3040744 [0][AsyncJob]ctran/mapper/CtranMapper.cc:187 NCCL WARN CTRAN: IB backend not enabled Created NCCL_SPLIT_TYPE_NODE type splitComm 0x11c76d0, rank 0 ~/fbsource/third-party/ncclx/v2.21.5-1 ``` Reviewed By: wconstab, wesbland Differential Revision: D56907877 Fixes #ISSUE_NUMBER Co-authored-by: Cory Modlin Pull Request resolved: https://github.com/pytorch/pytorch/pull/127982 Approved by: https://github.com/izaitsevfb --- test/distributed/test_c10d_nccl.py | 13 ++++++++++-- torch/csrc/distributed/c10d/NCCLUtils.cpp | 20 +++++++++++++++++++ torch/csrc/distributed/c10d/NCCLUtils.hpp | 14 ++----------- .../distributed/c10d/ProcessGroupNCCL.cpp | 16 ++++++++++++--- .../distributed/c10d/ProcessGroupNCCL.hpp | 2 +- 5 files changed, 47 insertions(+), 18 deletions(-) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index 4fea855a85b9..baf2adb1fb2d 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -599,6 +599,9 @@ def test_comm_split_optimization(self): @requires_nccl_version((2, 18), "Need NCCL 2.18+ for ncclCommSplit") @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @skip_but_pass_in_sandcastle_if( + torch.cuda.nccl.version()[-1] == "x", "NCCL test not for NCCLX" + ) def test_comm_split_subgroup(self): # Test `ncclCommSplit` for smaller subgroups of the world when # we've passed a specific device_id to init_process_group. @@ -614,12 +617,18 @@ def test_comm_split_subgroup(self): # rank 0 hasn't split yet, but rank 1 did for the # nocolor... so split count matches rank count coincidentally # in each of the proceses this test spawned! - self.assertEqual(backend.comm_split_count(), self.rank) + # when using ncclCommCreateFromRanks() in version 2.21+, + # unused ranks are not included in split + version = torch.cuda.nccl.version() + is_nccl_2_21 = version >= (2, 21) + exp_count = 0 if (is_nccl_2_21 or self.rank == 0) else 1 + self.assertEqual(backend.comm_split_count(), exp_count) if self.rank == 0: dist.broadcast(tensor, 0, group=ng) # now everyone has split because rank 0 has performed a comm - self.assertEqual(backend.comm_split_count(), 1) + exp_count = 1 if not is_nccl_2_21 else (1 if self.rank == 0 else 0) + self.assertEqual(backend.comm_split_count(), exp_count) self.assertEqual(tensor, original_tensor) @requires_nccl_version((2, 18), "Need NCCL 2.18+ for ncclCommSplit") diff --git a/torch/csrc/distributed/c10d/NCCLUtils.cpp b/torch/csrc/distributed/c10d/NCCLUtils.cpp index e26ab22f1a9f..e2771641af69 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.cpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.cpp @@ -63,6 +63,26 @@ void NCCLComm::waitUntilInitialized(int timeoutSecs) { } } +#if defined(NCCL_HAS_COMM_SPLIT) && !defined(FBCODE_CAFFE2) +// last argument to split() API is not used to support +// multiple implementations +std::shared_ptr NCCLComm::split( + NCCLComm* source, + int color_id, + int rank, + ncclConfig_t& config, + std::vector& ranks_ull) { + auto comm = std::make_shared(); + C10D_NCCL_CHECK( + ncclCommSplit( + source->ncclComm_, color_id, rank, &(comm->ncclComm_), &config), + c10::nullopt); + ++source->ncclCommSplitCounter_; + ncclCommUserRank(comm->ncclComm_, &comm->rank_); + return comm; +} +#endif + std::string getNcclVersion() { static c10::once_flag ncclGetVersionFlag; static std::string versionString; diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp index 165c514bbd27..7617f929feb3 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.hpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp @@ -286,22 +286,12 @@ class NCCLComm { } #endif -#ifdef NCCL_HAS_COMM_SPLIT static std::shared_ptr split( NCCLComm* source, int color_id, int rank, - ncclConfig_t& config) { - auto comm = std::make_shared(); - C10D_NCCL_CHECK( - ncclCommSplit( - source->ncclComm_, color_id, rank, &(comm->ncclComm_), &config), - c10::nullopt); - ++source->ncclCommSplitCounter_; - comm->rank_ = rank; - return comm; - } -#endif + ncclConfig_t& config, + std::vector& ranks_ull); #if defined(IS_NCCLX) && defined(NCCL_COMM_DUMP) std::unordered_map ncclCommDump() { diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 2e55bfdb6f34..26381207ca7d 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -933,7 +933,12 @@ void ProcessGroupNCCL::performNocolorSplit(at::Device device) { LOG(INFO) << logPrefix() << "Performing nocolor split on backend device " << device << ", key " << key << ", i am " << this; auto comm = getNCCLComm(key, device, OpType::ALLREDUCE); - NCCLComm::split(comm.get(), NCCL_SPLIT_NOCOLOR, rank_, options_->config); + NCCLComm::split( + comm.get(), + NCCL_SPLIT_NOCOLOR, + rank_, + options_->config, + options_->global_ranks_in_group); #endif } @@ -2082,6 +2087,7 @@ std::shared_ptr ProcessGroupNCCL::getNCCLComm( numRanks = 2; rank = p2pRank; } + // Get the device index auto deviceIndex = device.index(); gpuGuard.set_index(deviceIndex); @@ -2098,13 +2104,17 @@ std::shared_ptr ProcessGroupNCCL::getNCCLComm( auto& parentComm = dit->second; if (parentComm != nullptr && !parentComm->isAborted()) { ncclComm = NCCLComm::split( - parentComm.get(), options_->split_color, rank, options_->config); + parentComm.get(), + options_->split_color, + rank, + options_->config, + options_->global_ranks_in_group); } } } #endif - // To simplify conditioonal nesting, just create the ncclComms[i] + // To simplify conditional nesting, just create the ncclComms[i] // entry if it hasn't been yet rather than untangling the // conditions that might have resulted in a split above. if (!ncclComm) { diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index 8a3b7b1b5c21..f36ebdeb16e9 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -643,7 +643,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { uint64_t getSequenceNumberForGroup() override; // Return the total number of splits the communicators held by this process - // group have performed. + // group have performed. Counts ncclCommCreateFromRanks() for ncclx v2.21.5+ uint64_t getCommSplitCounter() const; void registerOnCompletionHook( From 6c07e2c9300df5988e591e2128f4efed107af5e4 Mon Sep 17 00:00:00 2001 From: Dmovic <944388576@qq.com> Date: Wed, 5 Jun 2024 02:03:02 +0000 Subject: [PATCH 347/706] fix redundant tensor (#127850) As title. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127850 Approved by: https://github.com/mikaylagawarecki --- torch/nn/modules/pooling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/nn/modules/pooling.py b/torch/nn/modules/pooling.py index 38acd9fb430a..3f02bb63a849 100644 --- a/torch/nn/modules/pooling.py +++ b/torch/nn/modules/pooling.py @@ -380,7 +380,7 @@ class MaxUnpool2d(_MaxUnpoolNd): [ 0., 0., 0., 0.], [ 0., 14., 0., 16.]]]]) >>> # Now using output_size to resolve an ambiguous size for the inverse - >>> input = torch.torch.tensor([[[[ 1., 2., 3., 4., 5.], + >>> input = torch.tensor([[[[ 1., 2., 3., 4., 5.], [ 6., 7., 8., 9., 10.], [11., 12., 13., 14., 15.], [16., 17., 18., 19., 20.]]]]) From 8e496046e5c633868e86200807a98b03f8402583 Mon Sep 17 00:00:00 2001 From: Feng Yuan Date: Wed, 5 Jun 2024 02:13:46 +0000 Subject: [PATCH 348/706] Update torch-xpu-ops pin (ATen XPU implementation) (#127879) Support AMP GradScaler. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127879 Approved by: https://github.com/EikanWang --- third_party/xpu.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xpu.txt b/third_party/xpu.txt index 7131a86c765c..07950b62467f 100644 --- a/third_party/xpu.txt +++ b/third_party/xpu.txt @@ -1 +1 @@ -bd76ae2a5a233ae57911c1de81322dcea19493c1 +97d692eb8c4b3afab17700a2fd918adcea0cba45 From a135776307500acce6c9bd955452ad76337d826d Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Tue, 4 Jun 2024 13:55:14 -0700 Subject: [PATCH 349/706] Remove tensor subclass detection logic from weights_only unpickler (#127808) Remove logic to auto-detect and allow subclasses that did not override certain methods from the weights_only unpickler from https://github.com/pytorch/pytorch/pull/124331 for 2.4 release Subclasses should be loadable using `torch.serialization.add_safe_globals` Pull Request resolved: https://github.com/pytorch/pytorch/pull/127808 Approved by: https://github.com/malfet --- test/test_serialization.py | 211 +++---------------------------- torch/_C/__init__.pyi.in | 1 - torch/_weights_only_unpickler.py | 150 +--------------------- torch/csrc/Module.cpp | 17 --- torch/serialization.py | 2 +- 5 files changed, 25 insertions(+), 356 deletions(-) diff --git a/test/test_serialization.py b/test/test_serialization.py index e83cafd3f3d8..f22331831c39 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -18,7 +18,6 @@ from collections import OrderedDict from copy import deepcopy from itertools import product -from types import ModuleType from torch._utils_internal import get_file_path_2 from torch._utils import _rebuild_tensor @@ -4111,23 +4110,6 @@ def __setstate__(self, state): class TestEmptySubclass(torch.Tensor): ... -# ONLY use SubclassSpoof subclasses for the subclass spoof tests since we modify them -# Cannot define locally in test or pickle will fail. -class TestEmptySubclassSpoof(TestEmptySubclass): - ... - -class TestWrapperSubclassSpoof(TestWrapperSubclass): - ... - -class RebuildFromTypeV2Spoof(torch.Tensor): - def __new__(cls, elem, naughty, **kwargs): - if naughty: - raise RuntimeError("naughty") - return super().__new__(cls, elem) - - def __reduce_ex__(self, protocol): - return (torch._tensor._rebuild_from_type_v2, (RebuildFromTypeV2Spoof, torch.Tensor, (True,), {})) - class TestSubclassSerialization(TestCase): def test_tensor_subclass_wrapper_serialization(self): @@ -4207,201 +4189,42 @@ def test_empty_class_serialization(self): f.seek(0) tensor2 = torch.load(f) - def _create_bad_func(self, name): - def bad_func(self, *args, **kwargs): - raise RuntimeError(f"running {name}") - return bad_func - - @parametrize("wrapper", (True, False)) - def test_tensor_subclass_method_spoofing(self, wrapper): - ''' - This tests seeks to do the following: - - determine which methods of a tensor subclass might be called during unpickling (weights_only=False) - we consider these methods "risky" for weights_only - - ensure that we ban overriding this group of methods on a tensor subclass by default (weights_only=True) - - ensure that tensor subclass that doesn't override any of these can be unpickled (weights_only=True) - - We achieve this by overriding all methods of a tensor subclass to raise a RuntimeError - when called. We then try to unpickle a tensor subclass with weights_only=False and ensure that - only the RuntimeErrors that we expect are thrown. - - We then load with weights_only and ensure that weights_only will fail unless all the risky methods - are not overriden by resetting the risky methods to the non-overriden version in a loop and calling load. - The final weights_only load call when all the risky methods are no longer overriden. - ''' - subclass = TestWrapperSubclassSpoof if wrapper else TestEmptySubclassSpoof - t = subclass(torch.randn(2, 3)) - # To trigger setattr for the non-wrapper case - if not wrapper: - t.foo = 'bar' - inp = {'weight': t} - - with TemporaryFileName() as f: - torch.save(inp, f) - loaded = torch.load(f, weights_only=True) - self.assertEqual(loaded['weight'], inp['weight']) - - restore_methods = dict() - methods = [func for func in dir(subclass) if callable(getattr(subclass, func))] - for method in methods: - if method != "__class__": - restore_methods[method] = getattr(subclass, method) - setattr(subclass, method, self._create_bad_func(method)) - # These additional methods might be called during getattr or setattr - # but are not in methods above (not defined on tensor base class) - subclass.__get__ = self._create_bad_func("__get__") - subclass.__set__ = self._create_bad_func("__set__") - subclass.__getattr__ = self._create_bad_func("__getattr__") - restore_methods["__get__"] = None - restore_methods["__getattr__"] = None - restore_methods["__set__"] = None - - try: - # Check that weights_only=False load raises the RuntimeErrors we expect - with self.assertRaisesRegex(RuntimeError, "running __getattribute__"): - torch.load(f, weights_only=False) - subclass.__getattribute__ = restore_methods['__getattribute__'] - with self.assertRaisesRegex(RuntimeError, "running __setstate__"): - torch.load(f, weights_only=False) - subclass.__setstate__ = restore_methods['__setstate__'] - with self.assertRaisesRegex(RuntimeError, "running __setattr__"): - torch.load(f, weights_only=False) - subclass.__setattr__ = restore_methods['__setattr__'] - # should finally work - torch.load(f, weights_only=False) - - # Check that weights_only=True catches that risky methods are overriden - subclass.__setstate__ = self._create_bad_func("__setstate__") - subclass.__getattribute__ = self._create_bad_func("__getattribute__") - subclass.__setattr__ = self._create_bad_func("__setattr__") - with self.assertRaisesRegex(pickle.UnpicklingError, - "methods: __getattribute__=True __getattr__=True __get__=True " - "__setattr__=True __set__=True __setstate__=True"): - torch.load(f, weights_only=True) - risky_methods = ['__get__', '__set__', '__getattr__', '__setattr__', '__getattribute__', '__setstate__'] - for i, meth in enumerate(risky_methods): - setattr(subclass, meth, restore_methods[meth]) - if i != len(risky_methods) - 1: - # When the given methods are not all back to default, load should still throw - # but reflect which methods are no longer overriden - with self.assertRaisesRegex(pickle.UnpicklingError, f"{meth}=False"): - torch.load(f, weights_only=True) - else: - # When the given methods are all back to default, weights_only load should finally work - loaded = torch.load(f, weights_only=True) - finally: - for method, func in restore_methods.items(): - setattr(subclass, method, func) - a = subclass(torch.randn(2, 3)) - @skipIfTorchDynamo("name 'SYNTHETIC_LOCAL' is not defined") def test_safe_globals_for_weights_only(self): ''' Tests import semantic for tensor subclass and the {add/get/clear}_safe_globals APIs ''' - # Needed to prevent UnboundLocalError: local variable 'TwoTensor' referenced before assignment - global TwoTensor t = TwoTensor(torch.randn(2, 3), torch.randn(2, 3)) p = torch.nn.Parameter(t) sd = OrderedDict([('t', t), ('p', p)]) with tempfile.NamedTemporaryFile() as f: torch.save(sd, f) - # unimport TwoTensor - try: - del sys.modules['torch.testing._internal.two_tensor'] - - # Loading tensor subclass with weights_only=True should fail - # if tensor subclass has not been imported - with self.assertRaisesRegex(pickle.UnpicklingError, - "expect `torch.testing._internal.two_tensor` to be present in `sys.modules`"): - f.seek(0) - sd = torch.load(f, weights_only=True) - # Loading tensor subclass with weights_only=True should work - # if target methods are not overriden and user has imported the subclass - from torch.testing._internal.two_tensor import TwoTensor + # Loading tensor subclass with weights_only=True should fail + # since tensor subclass is not in safe_globals + with self.assertRaisesRegex(pickle.UnpicklingError, + "Unsupported global: GLOBAL torch.testing._internal.two_tensor.TwoTensor"): f.seek(0) sd = torch.load(f, weights_only=True) + + # Loading tensor subclass should work if the class is marked safe + f.seek(0) + try: + torch.serialization.add_safe_globals([TwoTensor]) + self.assertTrue(torch.serialization.get_safe_globals() == [TwoTensor]) + sd = torch.load(f, weights_only=True) self.assertEqual(sd['t'], t) self.assertEqual(sd['p'], p) - # Loading tensor subclass with weights_only=True should fail - # if __setstate__ is overriden + # Should fail again when safe globals are cleared + torch.serialization.clear_safe_globals() f.seek(0) - restore_setstate = TwoTensor.__setstate__ - try: - TwoTensor.__setstate__ = lambda self, state: self.__dict__.update(state) - with self.assertRaisesRegex(pickle.UnpicklingError, "__setstate__=True"): - torch.load(f, weights_only=True) - - # Loading tensor subclass with overriden __setstate__ with weights_only=True should work - # if the class is marked safe - f.seek(0) - torch.serialization.add_safe_globals([TwoTensor]) - self.assertTrue(torch.serialization.get_safe_globals() == [TwoTensor]) - sd = torch.load(f, weights_only=True) - self.assertEqual(sd['t'], t) - self.assertEqual(sd['p'], p) - - # Should fail again when safe globals are cleared - torch.serialization.clear_safe_globals() - f.seek(0) - with self.assertRaisesRegex(pickle.UnpicklingError, "__setstate__=True"): - torch.load(f, weights_only=True) - finally: - TwoTensor.__setstate__ = restore_setstate + with self.assertRaisesRegex(pickle.UnpicklingError, + "Unsupported global: GLOBAL torch.testing._internal.two_tensor.TwoTensor"): + torch.load(f, weights_only=True) finally: - from torch.testing._internal.two_tensor import TwoTensor - - - def test_tensor_subclass_parent_module_method_spoofing(self): - ''' - Tests that weights_only load does not call any methods of the parent module - that contains the tensor subclass. - - We achieve this by overriding all methods of a module we add to sys.modules to raise a RuntimeError - when called. We then try to unpickle a tensor subclass with weights_only=True and ensure that - no RuntimeErrors are thrown. - ''' - # Simulates user doing `import spoof_mod` where `spoof_mod` contains `TestEmptySubclass` - class SpoofModule(ModuleType): - pass - - spoof_mod = SpoofModule('bla') - spoof_mod.TestEmptySubclass = TestEmptySubclass - inp = {'weight': TestEmptySubclass(torch.randn(2, 3))} - TestEmptySubclass.__module__ = 'spoof_mod' - sys.modules['spoof_mod'] = spoof_mod - - try: - with TemporaryFileName() as f: - torch.save(inp, f) - torch.load(f, weights_only=True) - restore_methods = dict() - methods = [func for func in dir(SpoofModule) if callable(getattr(SpoofModule, func))] - for method in methods: - if method != "__class__": - restore_methods[method] = getattr(SpoofModule, method) - setattr(SpoofModule, method, self._create_bad_func(method)) - SpoofModule.__get__ = self._create_bad_func("__get__") - SpoofModule.__getattr__ = self._create_bad_func("__getattr__") - loaded = torch.load(f, weights_only=True) - self.assertEqual(loaded['weight'], inp['weight']) - finally: - TestEmptySubclass.__module__ = __name__ - del sys.modules['spoof_mod'] - - def test_rebuild_from_type_v2_spoof(self): - t = RebuildFromTypeV2Spoof(torch.randn(2, 3), False) - inp = {'weight': t} - - with TemporaryFileName() as f: - torch.save(inp, f) - # subclass will be pushed onto unpickler's stack as a string - # and only gets converted to the type if it is argument 1 to _rebuild_from_type_v2 - with self.assertRaisesRegex(TypeError, "'str' object is not callable"): - loaded = torch.load(f, weights_only=True) + torch.serialization.clear_safe_globals() @unittest.skipIf(not torch.cuda.is_available(), "map_location loads to cuda") def test_tensor_subclass_map_location(self): diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 99985151c19f..bcc26350a896 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -1196,7 +1196,6 @@ def _has_storage(x: Tensor) -> _bool: ... def _construct_storage_from_data_pointer(data_ptr: _int, device: torch.device, size: _int) -> Storage: ... def _should_allow_numbers_as_tensors(func_name: str) -> _bool: ... def _group_tensors_by_device_and_dtype(nested_tensorlists: List[List[Optional[Tensor]]], with_indices: _bool = False) -> Dict[Tuple[torch.device, torch.dtype], Tuple[List[List[Optional[Tensor]]], List[_int]]]: ... -def _check_tp_alloc_is_default(cls: Type) -> _bool: ... # NB: There is no Capsule type in typing, see # https://code.activestate.com/lists/python-dev/139675/ diff --git a/torch/_weights_only_unpickler.py b/torch/_weights_only_unpickler.py index 6c9f3b61ae8b..9cc74c05e45f 100644 --- a/torch/_weights_only_unpickler.py +++ b/torch/_weights_only_unpickler.py @@ -23,7 +23,6 @@ import functools as _functools from collections import Counter, OrderedDict -from inspect import getattr_static from pickle import ( APPEND, APPENDS, @@ -64,8 +63,8 @@ UnpicklingError, ) from struct import unpack -from sys import maxsize, modules -from typing import Any, Dict, List, Type +from sys import maxsize +from typing import Any, Dict, List import torch @@ -170,11 +169,6 @@ def __init__(self, file, *, encoding: str = "bytes"): self.readline = file.readline self.read = file.read self.memo: Dict[int, Any] = {} - # tensor subclass types found from GLOBAL instructions that have passed the criteria - # to be allowed as the second argument to `torch._tensor._rebuild_from_type_v2` - # This enables rebuilding of tensor subclasses defined outside the `torch` package. - # See [Note: Criteria for allowing out-of-core tensor subclasses] for details on the criteria. - self.tensor_subclasses_found: Dict[str, Type] = {} def load(self): """Read a pickled object representation from the open file. @@ -201,121 +195,11 @@ def load(self): elif full_path in _get_user_allowed_globals(): self.append(_get_user_allowed_globals()[full_path]) else: - # The logic in this branch handles user-defined tensor subclasses. - # We can automatically allow and raise and error for anything that is not provably safe. - # [Note: Criteria for allowing out-of-core tensor subclasses] - # GLOBAL '.' instructions will get the class and - # push the string (not the actual type) while adding the type to the dictionary keyed - # by the string onto the unpickler's stack if they satisfy the following conditions: - # (1) The that defines them is in `sys.modules` - # (we will use getattr_static to access it to ensure no code execution) - # (2) They inherit from `torch.Tensor` - # (2) The class is not overriding any of the `torch.Tensor` methods listed here: - # `__getattr__`, `__get__`, `__getattribute__`, `__setstate__`, `__set__`, - # and `tp_alloc` - # The methods that we ban overriding were selected in a test-driven manner - # by overriding every callable method on a tensor subclass and determinining - # which might get called during unpickling. - # When executing REDUCE, the string will be appropriately converted back to the type only - # for `torch._tensor._rebuild_from_type_v2` as other use of the class could use methods - # we didn't audit. - if module == "__builtin__": - raise RuntimeError( - f"Unsupported global: GLOBAL {full_path} was not an allowed global by default. " - "Please use `torch.serialization.add_safe_globals` to allowlist this global " - "if you trust this class/function." - ) - elif module not in modules: - # TODO: add a link here to a doc that explains to users what we mean by trust - raise RuntimeError( - f"Found GLOBAL `{full_path}` instruction in the pickle file but `{full_path}` was " - f"not in the pre-defined list of allowed globals that are considered safe by the " - "weights_only unpickler for rebuilding state_dicts. This is the expected behavior if " - f"`{full_path}` is a class or function that is not in the list of allowed globals " - f"If `{full_path}` is NOT a tensor subclass, you might consider" - "`torch.serialization.add_safe_globals` if it is appropriate. However, if it is a " - "user-defined tensor subclass not defined in the `torch` package, this error might arise " - f"as we expect `{module}` to be present in `sys.modules` (i.e. it " - "must be imported in the current environment), but this was not the case. " - f"If you intend to unpickle a tensor subclass `{full_path}` please import `{name}` from " - f"`{module}`. Note that having this imported will *only* allow the type `{full_path}` to " - "be passed as the second argument to `torch._tensor._rebuild_from_type_v2`, which should " - "enable the tensor subclass to be unpickled without any arbitrary code execution as long " - # If the user imports and these are overridden the next error will prompt them to use - # torch.serialization.add_safe_globals. - "a sa pre-defined list of methods called when unpickling are not overridden. In " - "particular, the methods are `__getattr__`, `__get__`, `__getattribute__`, `__setstate__`, " - "`__set__`, as well as the implementation of `tp_alloc`." - ) - else: - try: - class_type = getattr_static(modules[module], name) - except AttributeError as e: - raise AttributeError( - "For safety during weights_only loading, we use inspect.getattr_state to " - f"get {name} from {module}, if {module} implements the descriptor protocol, " - "__getattr__ or __getattribute__ these will not be called." - ) from e - # None of the objects here contain any data from the pickle so this is safe - if isinstance(class_type, type) and issubclass( - class_type, torch.Tensor - ): - # getattr is called by the getattr call in `_rebuild_from_type_v2` - custom_get_attribute = ( - class_type.__getattribute__ - is not torch.Tensor.__getattribute__ - ) - custom_get = ( - getattr_static(class_type, "__get__", None) is not None - ) - custom_get_attr = ( - getattr_static(class_type, "__getattr__", None) - is not None - ) - # Tensor.__setstate__ might be called in `_rebuild_from_type_v2` - custom_set_state = ( - class_type.__setstate__ is not torch.Tensor.__setstate__ - ) - # setattr is called in `torch._utils._set_obj_state` - custom_set_attr = ( - class_type.__setattr__ is not object.__setattr__ - ) - custom_set = ( - getattr_static(class_type, "__set__", None) is not None - ) - # tp_alloc is called by `Tensor._rebuild_wrapper_subclass` and `Tensor.as_subclass` - has_custom_tp_alloc = ( - not torch._C._check_tp_alloc_is_default(class_type) - ) - custom_methods = { - "__getattribute__": custom_get_attribute, - "__getattr__": custom_get_attr, - "__get__": custom_get, - "__setattr__": custom_set_attr, - "__set__": custom_set, - "__setstate__": custom_set_state, - "tp_alloc": has_custom_tp_alloc, - } - if any(custom_methods.values()): - error = "" - for k, v in custom_methods.items(): - error += f" {k}={v}" - raise RuntimeError( - f"Trying to unpickle tensor subclass `{full_path}` that has defined a custom " - f"version for one of these methods:{error}. Please check whether you trust these " - "methods and allowlist the subclass with `torch.serialization.add_safe_globals` if so." - ) - # push the string full_path onto the stack (in REBUILD, there is special logic to - # access this from tensor_subclasses_found for rebuild_from_type_v2) - self.tensor_subclasses_found[full_path] = class_type - self.append(full_path) - else: - raise RuntimeError( - f"Unsupported global: GLOBAL {full_path} was not an allowed global by default. " - "Please use `torch.serialization.add_safe_globals` to allowlist this global " - "if you trust this class/function." - ) - + raise RuntimeError( + f"Unsupported global: GLOBAL {full_path} was not an allowed global by default. " + "Please use `torch.serialization.add_safe_globals` to allowlist this global " + "if you trust this class/function." + ) elif key[0] == NEWOBJ[0]: args = self.stack.pop() cls = self.stack.pop() @@ -332,26 +216,6 @@ def load(self): raise RuntimeError( f"Trying to call reduce for unrecognized function {func}" ) - # Special handling for tensor subclass type found in GLOBAL that is pushed - # onto stack as str to prevent it from being used anywhere except the - # second arg of _rebuild_from_type_v2 and within argument tuple for _rebuild_wrapper_subclass - # _rebuild_from_type_v2 is called with args (func, type, func_args, state) - # where both type and, when func is rebuild_wrapper_subclass, func_args[0] could be the subclass type - # Since we pushed these subclass types onto the stack as strings, convert them to the actual - # type here. - if func is torch._tensor._rebuild_from_type_v2 and type(args[1]) is str: - args_after = args[2:] - if ( - args[0] is torch._utils._rebuild_wrapper_subclass - and type(args[2][0]) is str - ): - new_arg_tuple = ( - self.tensor_subclasses_found[args[2][0]], - ) + args[2][1:] - args_after = (new_arg_tuple,) + args[3:] - args = ( - args[:1] + (self.tensor_subclasses_found[args[1]],) + args_after - ) self.stack[-1] = func(*args) elif key[0] == BUILD[0]: state = self.stack.pop() diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 063cdbbde0b7..dbd58657b951 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -414,19 +414,6 @@ PyObject* THPModule_swap_tensor_impl(PyObject* _unused, PyObject* args) { END_HANDLE_TH_ERRORS } -PyObject* THPModule_check_tp_alloc_is_default( - PyObject* _unused, - PyObject* cls) { - HANDLE_TH_ERRORS - TORCH_CHECK_TYPE( - PyType_Check(cls), - "cls must be a type (got ", - Py_TYPE(cls)->tp_name, - ")"); - return PyBool_FromLong(Py_TYPE(cls)->tp_alloc == PyType_GenericAlloc); - END_HANDLE_TH_ERRORS -} - PyObject* THPModule_addDocStr(PyObject* _unused, PyObject* args) { // adds a __doc__ string to a function, similar to numpy's arr_add_docstring static std::vector all_docs; @@ -1273,10 +1260,6 @@ static PyMethodDef TorchMethods[] = { // NOLINT {"_autograd_init", THPAutograd_initExtension, METH_NOARGS, nullptr}, {"_add_docstr", THPModule_addDocStr, METH_VARARGS, nullptr}, {"_swap_tensor_impl", THPModule_swap_tensor_impl, METH_VARARGS, nullptr}, - {"_check_tp_alloc_is_default", - THPModule_check_tp_alloc_is_default, - METH_O, - nullptr}, {"_init_names", THPModule_initNames, METH_O, nullptr}, {"_has_distributed", THPModule_hasDistributed, METH_NOARGS, nullptr}, {"_set_default_tensor_type", diff --git a/torch/serialization.py b/torch/serialization.py index 9401c775a510..a13363d037ac 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -921,7 +921,7 @@ def load( pickle_module: module used for unpickling metadata and objects (has to match the :attr:`pickle_module` used to serialize file) weights_only: Indicates whether unpickler should be restricted to - loading only tensors, tensor subclasses, primitive types, dictionaries + loading only tensors, primitive types, dictionaries and any types added via :func:`torch.serialization.add_safe_globals`. mmap: Indicates whether the file should be mmaped rather than loading all the storages into memory. Typically, tensor storages in the file will first be moved from disk to CPU memory, after which they From ce4436944c856cf080f1266f31713a9a08745ab9 Mon Sep 17 00:00:00 2001 From: cyy Date: Wed, 5 Jun 2024 02:14:41 +0000 Subject: [PATCH 350/706] Fix IOS builds (#127985) IOS builds fail these days, fix it. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127985 Approved by: https://github.com/ezyang --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index b36f52c93fe0..cd11ffdf7333 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -898,7 +898,7 @@ if(USE_SLEEF_FOR_ARM_VEC256) endif() # Enable sleef on macOS with Apple silicon by default -if((${CMAKE_SYSTEM_NAME} STREQUAL "Darwin") AND (${CMAKE_SYSTEM_PROCESSOR} STREQUAL "arm64")) +if((${CMAKE_SYSTEM_NAME} STREQUAL "Darwin") AND ("${CMAKE_SYSTEM_PROCESSOR}" STREQUAL "arm64")) message(STATUS "Running on macOS with Apple silicon") string(APPEND CMAKE_CXX_FLAGS " -DAT_BUILD_ARM_VEC256_WITH_SLEEF") add_definitions(-DAT_BUILD_ARM_VEC256_WITH_SLEEF) From 71e684bfae9d6ba5d779d2c6e68e6f23e310bc99 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Wed, 5 Jun 2024 02:16:48 +0000 Subject: [PATCH 351/706] [BE][Mac] Add missing prototypes (#127988) Really confused how CI did not catch this one, but this triggers missing prototype erros if compiled from scratch on MacOS Sonoma using clang-15 Fixes https://github.com/pytorch/pytorch/issues/127942 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127988 Approved by: https://github.com/Skylion007, https://github.com/huydhn --- aten/src/ATen/native/BlasKernel.cpp | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/aten/src/ATen/native/BlasKernel.cpp b/aten/src/ATen/native/BlasKernel.cpp index 66e39c218a06..97f04c9968c8 100644 --- a/aten/src/ATen/native/BlasKernel.cpp +++ b/aten/src/ATen/native/BlasKernel.cpp @@ -110,6 +110,23 @@ float fp16_dot_with_fp32_arith( const float16_t* vec1, const float16_t* vec2, int64_t len); + +void bf16_gemv_trans( + const int m, + const int n, + const at::BFloat16 alpha, + const at::BFloat16* a, + const int lda, + const at::BFloat16* x, + const int incx, + const at::BFloat16 beta, + at::BFloat16* y, + const int incy); + +float bf16_dot_with_fp32_arith( + const at::BFloat16* vec1, + const at::BFloat16* vec2, + int64_t len); #endif template From 55a4ef80c4b8c822411c7206dd615bb4d82143de Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Mon, 3 Jun 2024 22:18:43 -0700 Subject: [PATCH 352/706] [pipelining] test pipeline_order in schedule (#127559) Add a unittest to test validate the pipeline order for different `num_stages`, `num_microbatches`, `num_world_size` combinations. This doesn't actually run the schedule but just validates the ordering of microbatches processed is valid, therefore doesn't require GPUs / multiple processes. Will add more combinations and negative tests. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127559 Approved by: https://github.com/wconstab ghstack dependencies: #127084, #127332 --- test/distributed/pipelining/test_schedule.py | 220 ++++++++++++++++++ .../pipelining/PipelineSchedule.py | 9 + 2 files changed, 229 insertions(+) diff --git a/test/distributed/pipelining/test_schedule.py b/test/distributed/pipelining/test_schedule.py index 3cefc6da2322..462ba83da07e 100644 --- a/test/distributed/pipelining/test_schedule.py +++ b/test/distributed/pipelining/test_schedule.py @@ -1,9 +1,12 @@ # Copyright (c) Meta Platforms, Inc. and affiliates # Owner(s): ["oncall: distributed"] import copy +import logging import os import sys import tempfile +import unittest +from typing import Dict, List, Optional, Tuple from model_registry import ModelWithKwargs, MultiMLP @@ -18,6 +21,8 @@ ScheduleInterleaved1F1B, ScheduleLoopedBFS, ) +from torch.distributed.pipelining.PipelineSchedule import _Action, _ComputationType +from torch.distributed.pipelining.PipelineStage import _PipelineStageBase from torch.testing._internal.common_cuda import TEST_MULTIGPU from torch.testing._internal.common_distributed import ( MultiProcContinousTest, @@ -29,6 +34,7 @@ skip_but_pass_in_sandcastle_if, ) +logger = logging.getLogger(__name__) d_hid = 512 batch_size = 256 @@ -36,6 +42,18 @@ torch.manual_seed(0) +class MockPipelineStage(_PipelineStageBase): + def __init__(self, *args, **kwargs): + # Mock the necessary attributes + self.num_stages = kwargs.get("num_stages", 1) + self.group_size = kwargs.get("group_size", 1) + self.group_rank = kwargs.get("group_rank", 0) + self.group = kwargs.get("group", None) + + def _create_grad_recv_info(self, *args, **kwargs): + return None + + class ScheduleTest(MultiProcContinousTest): @classmethod def backend_str(cls) -> str: @@ -384,7 +402,209 @@ def test_grad_with_manual_interleaved(self, ScheduleClass): instantiate_parametrized_tests(ScheduleTest) + +def format_pipeline_order(pipeline_order: Dict[int, List[Optional[_Action]]]): + import itertools + + # Calculate the maximum number of steps across all ranks + num_steps = max(len(actions) for actions in pipeline_order.values()) + step_labels = [ + "Step " + str(i).zfill(len(str(num_steps - 1))) for i in range(num_steps) + ] + # Sorting the dictionary by keys and retrieving values in that order + rank_actions = [ + pipeline_order.get(key, [""] * num_steps) for key in sorted(pipeline_order) + ] + # Transpose the list of lists (rows to columns) + transposed_actions = list(itertools.zip_longest(*rank_actions, fillvalue="")) + # Generate column labels for ranks + num_ranks = len(pipeline_order) + rank_labels = ["Rank " + str(i) for i in range(num_ranks)] + # Calculate the maximum length of each column, considering labels + max_lengths = [ + max(len(str(item)) if item is not None else 0 for item in col) + for col in zip(step_labels, *transposed_actions) + ] + # Format the header row with rank labels + header_row = " " * (len(step_labels[0]) + 2) + " ".join( + f"{label:<{max_lengths[i]}}" for i, label in enumerate(rank_labels) + ) + # Format each row with its corresponding label + formatted_rows = [ + f"{label}: " + + " ".join(f"{str(item):<{max_lengths[i]}}" for i, item in enumerate(row)) + for label, row in zip(step_labels, transposed_actions) + ] + # Join the rows into a single string + formatted_table = ( + "=========== ALL_RANK_ACTIONS ===========\n" + + header_row + + "\n" + + "\n".join(formatted_rows) + + "\n" + ) + return formatted_table + + +class TestSchedulePlan(unittest.TestCase): + def _validate_pipeline_order( + self, + pipeline_order: Dict[int, List[Optional[_Action]]], + num_microbatches: int, + num_stages: int, + ): + """ + pipeline_order[rank] = [(computation_type, microbatch_index, stage_index), ...] + + Validating that the pipeline order follows the rules: + 1. Forward action for a microbatch must be before the Backward action for that microbatch + 2. Recv for a microbatch must be before the send for that microbatch + 3. Microbatch index is handled in sequential order for each stage + 4. A later stage cannot operate on a microbatch before any of the previous stages have operated on it + 5. Same microbatch cannot be handled in the same time step across ranks + """ + # microbatch_index: (current computation type, current stage) + error_msg = [] + microbatch_process_info: Dict[int, Tuple(_ComputationType, int)] = {} + max_timestep = max(len(rank_list) for rank_list in pipeline_order.values()) + for timestep in range(max_timestep): + error_msg = [] + current_timestep_actions = [] + for rank in range(len(pipeline_order)): + action = ( + pipeline_order[rank][timestep] + if timestep < len(pipeline_order[rank]) + else None + ) + if action is not None: + current_timestep_actions.append(action) + + # TODO: enable this + # if len(current_timestep_actions) == 0: + # error_msg.append( + # "All actions were None, there is an unnecessary gap in the schedule" + # ) + + # Ensure that no microbatch is operated on twice in current_timestep_actions + unique_microbatch_indices = { + action[1] for action in current_timestep_actions + } + if len(unique_microbatch_indices) != len(current_timestep_actions): + error_msg.append( + "Duplicate microbatch index found in current_timestep_actions" + ) + + # Add additional checks for other rules here... + for action in current_timestep_actions: + computation_type, mb_index, stage_index = action + + if mb_index >= num_microbatches: + error_msg.append(f"Microbatch index {mb_index} out of range") + + # first microbatch + if mb_index not in microbatch_process_info: + if computation_type != _ComputationType.FORWARD or stage_index != 0: + error_msg.append(f"Incorrect start for microbatch {mb_index}") + microbatch_process_info[mb_index] = (computation_type, stage_index) + else: + # if the microbatch is included, check that the current stage is right after prev + prev_computation, prev_stage = microbatch_process_info[mb_index] + if prev_computation == _ComputationType.FORWARD: + if prev_stage == num_stages - 1: + expected_stage = num_stages - 1 + expected_computation = _ComputationType.BACKWARD + else: + expected_stage = prev_stage + 1 + expected_computation = _ComputationType.FORWARD + elif prev_computation == _ComputationType.BACKWARD: + if prev_stage == 0: + error_msg.append( + f"[{mb_index=}] already finished backward computation" + ) + expected_stage = None + expected_computation = None + else: + expected_stage = prev_stage - 1 + expected_computation = _ComputationType.BACKWARD + else: + raise ValueError( + f"Computation type {prev_computation} not supported" + ) + + if expected_computation is not None: + if expected_computation != computation_type: + error_msg.append( + f"[{mb_index=}] {expected_computation=} VS. actual {computation_type=}" + ) + + if expected_stage != stage_index: + error_msg.append( + f"[{mb_index=}] {expected_stage=} VS. actual {stage_index=}" + ) + + microbatch_process_info[mb_index] = ( + expected_computation, + expected_stage, + ) + + if len(error_msg) != 0: + self.fail(f"Error at timestep {timestep}: " + ",".join(error_msg)) + + def test_pipeline_order(self): + # Define a list of test cases with varying num_local_stages, num_microbatches, and group_size + # These should succeed since num_microbatches % group_size == 0 + test_cases = [ + # small number of stages + (2, 2, 2), + (2, 4, 4), + (2, 8, 4), + (2, 8, 8), + (4, 4, 4), + (4, 8, 4), + (4, 8, 8), + # large microbatches + (4, 16, 4), + (4, 32, 4), + (4, 64, 4), + # large groups + (4, 16, 16), + (4, 32, 32), + (4, 128, 64), + # odd num pipeline stages + (3, 2, 2), + (3, 8, 2), + (3, 12, 4), + # odd group_sizes + (4, 6, 3), + (4, 10, 5), + ] + for num_local_stages, num_microbatches, group_size in test_cases: + with self.subTest( + num_local_stages=num_local_stages, + num_microbatches=num_microbatches, + group_size=group_size, + ): + print(f"{num_local_stages=} {num_microbatches=} {group_size=}") + num_stages = num_local_stages * group_size + stages = [ + MockPipelineStage(group_size=group_size, num_stages=num_stages) + for i in range(num_local_stages) + ] + + schedule = ScheduleInterleaved1F1B(stages, num_microbatches) + # print(format_pipeline_order(schedule.pipeline_order)) + self._validate_pipeline_order( + schedule.pipeline_order, num_microbatches, num_stages + ) + + if __name__ == "__main__": + # Run only the TestSchedulePlan tests (single process) + loader = unittest.TestLoader() + suite = loader.loadTestsFromTestCase(TestSchedulePlan) + runner = unittest.TextTestRunner() + runner.run(suite) + # Check if GPU and NCCL are available if not ( dist.is_available() diff --git a/torch/distributed/pipelining/PipelineSchedule.py b/torch/distributed/pipelining/PipelineSchedule.py index f63f4ed061b2..5c04ac824e69 100644 --- a/torch/distributed/pipelining/PipelineSchedule.py +++ b/torch/distributed/pipelining/PipelineSchedule.py @@ -31,12 +31,21 @@ class _ComputationType(Enum): FORWARD = 1 BACKWARD = 2 + def __str__(self): + if self == _ComputationType.FORWARD: + return "F" + else: + return "B" + class _Action(NamedTuple): computation_type: _ComputationType microbatch_index: int stage_index: int + def __repr__(self): + return f"{self.computation_type}{self.microbatch_index}_s{self.stage_index}" + class _PipelineSchedule(ABC): def __init__( From d5cb5d623aefcb0928d80f226d1a4962706ada38 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 5 Jun 2024 03:57:58 +0000 Subject: [PATCH 353/706] Revert "Complete revamp of float/promotion sympy handling (#126905)" This reverts commit fb696ef3aa34e20c0fef1c0210a397abd3ea5885. Reverted https://github.com/pytorch/pytorch/pull/126905 on behalf of https://github.com/ezyang due to internal user reported ceiling equality simplification problem, I have a plan ([comment](https://github.com/pytorch/pytorch/pull/126905#issuecomment-2148805840)) --- c10/core/SymNodeImpl.h | 18 - test/dynamo/test_dynamic_shapes.py | 7 + test/dynamo/test_export.py | 3 +- test/dynamo/test_misc.py | 17 +- test/inductor/test_indexing.py | 72 +++- test/onnx/test_fx_to_onnx_with_onnxruntime.py | 8 +- test/test_dynamic_shapes.py | 208 +++++++--- test/test_proxy_tensor.py | 3 +- test/test_sympy_utils.py | 122 +++--- torch/__init__.py | 162 +------- torch/_export/serde/serialize.py | 9 +- torch/_inductor/bounds.py | 5 - torch/_inductor/codegen/common.py | 168 ++------ torch/_inductor/codegen/cpp.py | 4 +- torch/_inductor/codegen/cpp_utils.py | 45 +- torch/_inductor/codegen/triton.py | 58 +-- torch/_inductor/graph.py | 5 +- torch/_inductor/ir.py | 16 +- torch/_inductor/kernel/flex_attention.py | 5 +- torch/_inductor/lowering.py | 6 +- torch/_inductor/ops_handler.py | 60 +-- torch/_inductor/select_algorithm.py | 4 +- torch/_inductor/sizevars.py | 20 +- torch/_inductor/utils.py | 2 +- torch/_subclasses/fake_tensor.py | 2 +- torch/csrc/jit/python/init.cpp | 5 - torch/csrc/utils/python_symnode.h | 20 - torch/export/dynamic_shapes.py | 9 +- torch/fx/experimental/recording.py | 8 +- torch/fx/experimental/sym_node.py | 204 ++------- torch/fx/experimental/symbolic_shapes.py | 80 ++-- torch/fx/experimental/validator.py | 32 +- torch/utils/_sympy/functions.py | 389 ++++-------------- torch/utils/_sympy/interp.py | 67 +-- torch/utils/_sympy/reference.py | 151 +++---- torch/utils/_sympy/solve.py | 1 - torch/utils/_sympy/value_ranges.py | 275 ++++--------- 37 files changed, 667 insertions(+), 1603 deletions(-) diff --git a/c10/core/SymNodeImpl.h b/c10/core/SymNodeImpl.h index bb92b09775b7..9ffab5065109 100644 --- a/c10/core/SymNodeImpl.h +++ b/c10/core/SymNodeImpl.h @@ -49,33 +49,15 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target { virtual SymNode mul(const SymNode& other) { TORCH_CHECK(false, "NYI"); } - // NB: legacy, prefer float_truediv or int_truediv virtual SymNode truediv(const SymNode& other) { TORCH_CHECK(false, "NYI"); } - virtual SymNode float_truediv(const SymNode& other) { - return truediv(other); - } - virtual SymNode int_truediv(const SymNode& other) { - return truediv(other); - } - // NB: legacy, prefer float_pow or pow_by_natural virtual SymNode pow(const SymNode& other) { TORCH_CHECK(false, "NYI"); } - virtual SymNode float_pow(const SymNode& other) { - return pow(other); - } - virtual SymNode pow_by_natural(const SymNode& other) { - return pow(other); - } - // NB: legacy, prefer int_floordiv virtual SymNode floordiv(const SymNode& other) { TORCH_CHECK(false, "NYI"); } - virtual SymNode int_floordiv(const SymNode& other) { - return floordiv(other); - } virtual SymNode mod(const SymNode& other) { TORCH_CHECK(false, "NYI"); } diff --git a/test/dynamo/test_dynamic_shapes.py b/test/dynamo/test_dynamic_shapes.py index 57671e620e56..175ed573391b 100644 --- a/test/dynamo/test_dynamic_shapes.py +++ b/test/dynamo/test_dynamic_shapes.py @@ -78,6 +78,13 @@ def make_dynamic_cls(cls): del test if TEST_Z3: + # this only fails when z3 is available + unittest.expectedFailure( + # SymPy is incorrectly transforming 's0 / 6 == 0.5' into 'False'. + # Ref: https://github.com/sympy/sympy/issues/25146 + DynamicShapesReproTests.test_dynamic_shapes_float_guard_dynamic_shapes # noqa: F821 + ) + if not config.inline_inbuilt_nn_modules: # TODO model is somehow not being freed when z3 is available unittest.expectedFailure( diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py index 7ae0f839f6ff..9f1417e23247 100644 --- a/test/dynamo/test_export.py +++ b/test/dynamo/test_export.py @@ -2385,7 +2385,8 @@ def forward(self, x): with self.assertRaisesRegex( torch._dynamo.exc.UserError, "Constraints violated .*!(.*\n)*.*" - "Not all values of dim0 .* satisfy the generated guard 4 <= .* and .* <= 10(.*\n)*.*", + "by dim0 = 2\\*dim1(.*\n)*.*" + "Not all values of dim1 .* satisfy the generated guard 2 <= .* and .* <= 5(.*\n)*.*", ): torch.export.export(foo, (t,), dynamic_shapes=dynamic_shapes) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index dc2b9530f0dd..bcb0fd18818e 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -9309,7 +9309,7 @@ def test_shape_env_equal_create_symbolic_sizes_strides_storage_offset(self): > Left: {0: 0, 1: 1, 2: s1, 3: s0} > Right: {0: 0, 1: 1} ==> var_to_range: values don't match. - > Left: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)} + > Left: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)} > Right: {} ==> var_to_sources: values don't match. > Left: {s0: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=, idx=0)], s1: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=, idx=1)]} @@ -9343,7 +9343,7 @@ def test_shape_env_equal_unbacked(self): > Left: 2 > Right: 0 ==> var_to_range: values don't match. - > Left: {u0: ValueRanges(lower=-9223372036854775808, upper=9223372036854775807, is_bool=False, is_int=True, is_float=False), u1: ValueRanges(lower=0, upper=1, is_bool=False, is_int=True, is_float=False), zuf0: ValueRanges(lower=-oo, upper=oo, is_bool=False, is_int=False, is_float=True)} + > Left: {u0: ValueRanges(lower=-9223372036854775808, upper=9223372036854775807, is_bool=False), u1: ValueRanges(lower=0, upper=1, is_bool=False), zuf0: ValueRanges(lower=-oo, upper=oo, is_bool=False)} > Right: {} """, ) @@ -9420,8 +9420,8 @@ def test_shape_env_equal_evaluate_expr_replacement(self): > Left: {s0: 3} > Right: {} ==> var_to_range: values don't match. - > Left: {s0: ValueRanges(lower=3, upper=3, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)} - > Right: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)} + > Left: {s0: ValueRanges(lower=3, upper=3, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)} + > Right: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)} """, ) self._replay_and_check(main) @@ -9458,8 +9458,8 @@ def test_shape_env_equal_evaluate_expr_refinement(self): > Left: {_assert, ge, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} > Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} ==> var_to_range: values don't match. - > Left: {s0: ValueRanges(lower=3, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)} - > Right: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)} + > Left: {s0: ValueRanges(lower=3, upper=9223372036854775806, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)} + > Right: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)} """, ) self._replay_and_check(main) @@ -9484,7 +9484,10 @@ def test_shape_env_equal_runtime_assert(self): ShapeEnv not equal: field values don't match: ==> deferred_runtime_asserts: values don't match. - > Left: {u0: [Eq(PythonMod(u0, 3), 0)]} + > Left: {u0: [Eq(Mod(u0, 3), 0)]} + > Right: {} +==> divisible: values don't match. + > Left: {Mod(u0, 3)} > Right: {} ==> name_to_node: values don't match. > Left: {_assert, eq, mod, u0} diff --git a/test/inductor/test_indexing.py b/test/inductor/test_indexing.py index da527cfbb1d8..299a619f9cd6 100644 --- a/test/inductor/test_indexing.py +++ b/test/inductor/test_indexing.py @@ -11,12 +11,7 @@ instantiate_parametrized_tests, parametrize, ) -from torch.utils._sympy.functions import ( - FloorDiv, - ModularIndexing, - RoundDecimal, - RoundToInt, -) +from torch.utils._sympy.functions import FloorDiv, ModularIndexing, Round, RoundDecimal class TestIndexingSimplification(InductorTestCase): @@ -173,11 +168,21 @@ def test_print_pow(self): common_cases = [ # expr, result + # Test exprs. + ( + s1 / (2 * s1 - 1) - 1 / (2 * s1 - 1), + lambda c, L: f"((-1{L})*({c}/((-1{L}) + (2{L}*foo)))) + (foo*({c}/((-1{L}) + (2{L}*foo))))", + ), + (s1 / (s2 - s3), lambda c, L: f"foo*({c}/(bar + ((-1{L})*baz)))"), # Test Pow directly. ( sympy.Pow(s1 + s2, 0), lambda _, L: f"1{L}", ), # note: simplified before _print_Pow + ( + sympy.Pow(s1 + s2, -3), + lambda c, _: f"{c}/((bar + foo)*(bar + foo)*(bar + foo))", + ), ] gpu_cases = common_cases + [ @@ -226,10 +231,12 @@ def test_print_ceil(self): self.assertExpectedInline(cexpr(expr), """std::ceil((1.0/2.0)*s1)""") def test_print_round(self): - expr = RoundToInt(sympy.Symbol("x", integer=True) / 2) + expr = Round(sympy.Symbol("x", integer=True) / 2) self.assertExpectedInline(pexpr(expr), """round((1/2)*x)""") self.assertExpectedInline(cexpr(expr), """std::lrint((1.0/2.0)*x)""") - self.assertExpectedInline(texpr(expr), """libdevice.llrint((1/2)*x)""") + self.assertExpectedInline( + texpr(expr), """libdevice.llrint((1/2)*x).to(tl.int64)""" + ) @parametrize("ndigits", [-1, 0, 1]) def test_print_round_decimal(self, ndigits): @@ -244,18 +251,45 @@ def test_print_round_decimal(self, ndigits): f"libdevice.nearbyint(1e{ndigits} * ((1/2)*x)) * 1e{-ndigits}", ) + expr = RoundDecimal(sympy.Symbol("x", integer=True), ndigits) + if ndigits >= 0: + for do_print in [pexpr, cexpr, texpr]: + self.assertEqual(do_print(expr), "x") + else: + self.assertEqual(pexpr(expr), f"round(x, {ndigits})") + for do_print in [cexpr, texpr]: + with self.assertRaisesRegex( + ValueError, "only non-negative ndigits are currently supported" + ): + do_print(expr) + def test_print_floor_div(self): - s1 = sympy.Symbol("s1", integer=True) - s2 = sympy.Symbol("s2", integer=True) - expr = FloorDiv(s1, s2) - self.assertEqual(pexpr(expr), "(s1 // s2)") - self.assertEqual(cexpr(expr), "c10::div_floor_integer(s1, s2)") - - s1 = sympy.Symbol("s1", integer=True) - s2 = sympy.S(-1) - expr = FloorDiv(s1, s2) - self.assertEqual(pexpr(expr), "(-1)*s1") - self.assertEqual(cexpr(expr), "(-1L)*s1") + for integer in [True, False]: + s1 = sympy.Symbol("s1", integer=integer) + s2 = sympy.Symbol("s2", integer=integer) + expr = FloorDiv(s1, s2) + self.assertEqual(pexpr(expr), "(s1 // s2)") + if integer: + self.assertEqual(cexpr(expr), "c10::div_floor_integer(s1, s2)") + else: + self.assertEqual( + cexpr(expr), + "c10::div_floor_floating(static_cast(s1), static_cast(s2))", + ) + + for integer in [True, False]: + s1 = sympy.Symbol("s1", integer=integer) + s2 = sympy.S(-1) + expr = FloorDiv(s1, s2) + if integer: + self.assertEqual(pexpr(expr), "(-1)*s1") + self.assertEqual(cexpr(expr), "(-1L)*s1") + else: + self.assertEqual(pexpr(expr), "(s1 // (-1))") + self.assertEqual( + cexpr(expr), + "c10::div_floor_floating(static_cast(s1), static_cast((-1L)))", + ) def test_print_Min_Max(self): cases = ( diff --git a/test/onnx/test_fx_to_onnx_with_onnxruntime.py b/test/onnx/test_fx_to_onnx_with_onnxruntime.py index 0f0e01bc0dc2..b70bfbf9c4a7 100644 --- a/test/onnx/test_fx_to_onnx_with_onnxruntime.py +++ b/test/onnx/test_fx_to_onnx_with_onnxruntime.py @@ -158,12 +158,8 @@ def forward(self, x, y): torch.tensor([operator.sub(x.item(), y.item())]), torch.tensor([operator.mul(x.item(), y.item())]), torch.tensor([operator.truediv(x.item(), y.item())]), - # This requires torch.sym_float, probably easy to lower to - # ONNX but I don't know where to put it - # torch.tensor([operator.floordiv(x.item(), y.item())]), - # NB: abs so that the base and exponent are provably - # non-negative, so we don't generate runtime asserts - torch.tensor([operator.pow(abs(x.item()), abs(y.item()))]), + torch.tensor([operator.floordiv(x.item(), y.item())]), + torch.tensor([operator.pow(x.item(), y.item())]), torch.tensor([operator.abs(x.item())]), torch.tensor([operator.neg(x.item())]), torch.tensor([math.ceil(x.item())]), diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index 82503b5866b5..d548e9df0707 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -205,15 +205,15 @@ def create_symtype(cls, pytype, shape_env, val, duck=True): # TODO: default duck to False -def create_symint(shape_env, i: int, duck=True) -> SymInt: +def create_symint(shape_env, i: int, duck=True): return create_symtype(SymInt, int, shape_env, i, duck=duck) -def create_symbool(shape_env, b: bool) -> SymBool: +def create_symbool(shape_env, b: bool): return create_symtype(SymBool, bool, shape_env, b) -def create_symfloat(shape_env, f: float) -> SymFloat: +def create_symfloat(shape_env, f: float): return create_symtype(SymFloat, float, shape_env, f) @@ -457,16 +457,14 @@ def test_sym_int(self): r = sym_int(a1 / 2) self.assertEqual(guard_int(r), 3) self.assertIsInstance(r, torch.SymInt, msg=type(r)) - self.assertExpectedInline( - str(shape_env.guards[1][0]), """Eq(TruncToInt(IntTrueDiv(s1, 2)), 3)""" - ) + self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(Trunc(s1/2), 3)""") a3 = create_symint(shape_env, 3) r = sym_int(2.0 * torch.sym_float(a3)) self.assertEqual(guard_int(r), 6) self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertExpectedInline( - str(shape_env.guards[2][0]), """Eq(TruncToInt(2.0*ToFloat(s2)), 6)""" + str(shape_env.guards[2][0]), """Eq(Trunc(2.0*s2), 6)""" ) def test_sym_sqrt(self): @@ -476,7 +474,7 @@ def test_sym_sqrt(self): self.assertEqual(r, 2) self.assertIsInstance(r, torch.SymFloat, msg=type(r)) self.assertExpectedInline( - str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_sqrt(s0), 2.0)""" + str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_sqrt(s0), 2)""" ) def test_sym_floor(self): @@ -485,17 +483,11 @@ def test_sym_floor(self): r = math.floor(a0 / 2) self.assertEqual(r, 2) self.assertIsInstance(r, torch.SymInt, msg=type(r)) - self.assertExpectedInline( - str(shape_env.guards[0][0]), - """Eq(floor(IntTrueDiv(s0, 2)), 2)""", - ) + self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(floor(s0/2), 2)""") r = math.floor(3.0 * a0) self.assertEqual(r, 15) self.assertIsInstance(r, torch.SymInt, msg=type(r)) - self.assertExpectedInline( - str(shape_env.guards[1][0]), - """Eq(floor(3.0*ToFloat(s0)), 15)""", - ) + self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(3*s0, 15)""") def test_sym_trunc(self): shape_env = ShapeEnv() @@ -503,14 +495,12 @@ def test_sym_trunc(self): r = math.trunc(a0 / 2) self.assertEqual(r, 2) self.assertIsInstance(r, torch.SymInt, msg=type(r)) - self.assertExpectedInline( - str(shape_env.guards[0][0]), """Eq(TruncToInt(IntTrueDiv(s0, 2)), 2)""" - ) + self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(Trunc(s0/2), 2)""") r = torch.sym_int(torch.sym_sqrt(a0)) self.assertEqual(r, 2) self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertExpectedInline( - str(shape_env.guards[1][0]), """Eq(TruncToInt(OpaqueUnaryFn_sqrt(s0)), 2)""" + str(shape_env.guards[1][0]), """Eq(Trunc(OpaqueUnaryFn_sqrt(s0)), 2)""" ) def test_sym_ceil(self): @@ -520,17 +510,12 @@ def test_sym_ceil(self): self.assertEqual(r, 3) self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertExpectedInline( - str(shape_env.guards[0][0]), - """Eq(ceiling(IntTrueDiv(s0, 2)), 3)""", + str(shape_env.guards[0][0]), """Eq(ceiling(s0/2), 3)""" ) - r1 = 3.0 * a0 - r = math.floor(r1) + r = math.floor(3.0 * a0) self.assertEqual(r, 15) self.assertIsInstance(r, torch.SymInt, msg=type(r)) - self.assertExpectedInline( - str(shape_env.guards[1][0]), - """Eq(floor(3.0*ToFloat(s0)), 15)""", - ) + self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(3*s0, 15)""") def test_sym_ite(self): shape_env = ShapeEnv() @@ -977,14 +962,8 @@ def test_ephemeral_source_unified_with_non_ephemeral_source(self): ) class TestSymNumberMagicMethods(TestCase): def _do_test(self, fn, inp1, inp2, shape_env, is_unary_fn): - with self.subTest(fn=fn, inp1=inp1, inp2=inp2, is_unary_fn=is_unary_fn): - return self._do_test2(fn, inp1, inp2, shape_env, is_unary_fn) - - def _do_test2(self, fn, inp1, inp2, shape_env, is_unary_fn): # Helper function # NB: don't use one as that will get specialized - # TODO: We don't have to circuitously create the float, can just - # create a symfloat directly seed_node = (create_symint(shape_env, 2) / 2.0).node bool_seed_node = (create_symint(shape_env, 2) == 2).node @@ -997,42 +976,27 @@ def get_sym_inp(inp): else: return torch.SymFloat(to_node(seed_node, inp)) - if fn == "float_pow": - if inp1 < 0: - return - - if fn == "pow_by_natural": - if isinstance(inp1, float) or isinstance(inp2, float): - return - if inp2 < 0: - return - def maybe_xfail(inp1, inp2): if fn == "sym_sqrt" and inp1 < 0: # ValueError: math domain error return self.assertRaises((ValueError,)) - elif ( - fn in ("float_truediv", "int_truediv", "int_floordiv", "mod") - and inp2 == 0 - ): + elif fn in ("truediv", "floordiv", "mod") and inp2 == 0: # ZeroDivisionError: division by zero return self.assertRaises((ZeroDivisionError,)) - elif fn in ["float_pow", "pow_by_natural"] and inp1 == 0 and inp2 < 0: + elif fn == "pow" and inp1 == 0 and inp2 < 0: # ZeroDivisionError: 0.0 cannot be raised to a negative power return self.assertRaises((ZeroDivisionError,)) elif ( - # TODO: dear catastrophe waitress, - # this doesn't work - fn in ["float_pow", "pow_by_natural"] + fn == "pow" and inp1 < 0 + and inp2 in (2.5, -2.5) and ( - type(inp1) is (SymInt, SymFloat) or type(inp2) is (SymInt, SymFloat) + type(inp1) in (SymFloat, SymInt) or type(inp2) in (SymFloat, SymInt) ) - and (type(inp1) is (SymFloat, float) or type(inp2) is (SymFloat, float)) ): # Complex result, which we do not support: # TypeError: Cannot convert complex to float - return self.assertRaises((RuntimeError,)) + return self.assertRaises((TypeError,)) elif fn in ("lshift", "rshift") and not ( isinstance(inp1, (SymInt, int)) and isinstance(inp2, (SymInt, int)) ): @@ -1116,9 +1080,6 @@ def test_method(self, fn, first_type, second_type): ) and fn in sym_node.only_float_magic_methods: self.skipTest(f"{fn} is not an int method") - if second_type == "float" and fn in ["mod"]: - self.skipTest(f"{fn} only handles int") - is_unary_fn = fn in sym_node.unary_methods or fn == "round" # Second argument is ignored for unary function. So only run for one type if is_unary_fn and second_type == "float": @@ -1290,15 +1251,112 @@ def yield_test_cases(values, negate=True): yield (-x, -y) def test_floordiv_float_int(self): - values = ((7, 2),) + values = ( + (2.5, 2.1), + (2.1, 2.5), + (2.0, 2.1), + (7, 2.5), + (2.1, 7), + (7, 2), + ) for x, y in TestFloorDiv.yield_test_cases(values): self.assertEqual( TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y) ) + def test_floordiv_bool(self): + values = ( + (False, True), + (True, 2.5), + (2.5, True), + (False, 7), + (7, True), + ) + + for x, y in TestFloorDiv.yield_test_cases(values, negate=False): + # Compares to int since our FloorDiv has no bool support + self.assertEqual( + TestFloorDiv.python_floordiv(x, y), + TestFloorDiv.torch_floordiv(int(x), int(y)), + ) + # Tests that our impl throws + self.assertRaisesRegex( + TypeError, + ( + rf"unsupported operand type\(s\) for //: " + rf"'{type(sympy.sympify(x)).__name__}' and '{type(sympy.sympify(y)).__name__}'" + rf", expected integer or real" + ), + lambda: TestFloorDiv.torch_floordiv(x, y), + ) + + def test_floordiv_complex(self): + values = ( + (1.5 + 2.5j, 1.3 + 3.5j), + (1.5 + 2.5j, 2.5), + (2.5, 1.5 + 2.5j), + (1.5 + 2.5j, 7), + (7, 1.5 + 2.5j), + ) + + for x, y in TestFloorDiv.yield_test_cases(values): + # We don't test error messages to avoid depending on Python + # interpreter version + self.assertRaises(TypeError, lambda: TestFloorDiv.python_floordiv(x, y)) + self.assertRaisesRegex( + TypeError, + ( + rf"unsupported operand type\(s\) for //: " + rf"'{type(sympy.sympify(x)).__name__}' and '{type(sympy.sympify(y)).__name__}'" + rf", expected integer or real" + ), + lambda: TestFloorDiv.torch_floordiv(x, y), + ) + + def test_floordiv_div_by_zero(self): + values = ( + (2.5, 0), + (2.1, 0.0), + (2.3, sympy.Symbol("s", zero=True)), + ) + + for x, y in TestFloorDiv.yield_test_cases(values, negate=False): + # We don't test error messages to avoid depending on Python + # interpreter version + if type(y) is not sympy.Symbol: + self.assertRaises( + ZeroDivisionError, lambda: TestFloorDiv.python_floordiv(x, y) + ) + self.assertRaisesRegex( + ZeroDivisionError, + "division by zero", + lambda: TestFloorDiv.torch_floordiv(x, y), + ) + + def test_floordiv_zero_base(self): + values = ( + (0, 2.5), + (0.0, 2.1), + (sympy.Symbol("s", zero=True), 2.3), + ) + + for x, y in TestFloorDiv.yield_test_cases(values, negate=False): + if type(x) is not sympy.Symbol: + self.assertEqual( + TestFloorDiv.python_floordiv(x, y), + TestFloorDiv.torch_floordiv(x, y), + ) + else: + self.assertEqual(0, TestFloorDiv.torch_floordiv(x, y)) + def test_floordiv_div_by_one(self): - values = ((2, 1),) + values = ( + (2.5, 1), + (2.1, 1.0), + (2, 1.0), + (2, 1), + ) for x, y in TestFloorDiv.yield_test_cases(values): self.assertEqual( @@ -1309,7 +1367,12 @@ def test_floordiv_simplify(self): # Tests how we simplify or evaluate FloorDiv without free variables shape_env = ShapeEnv() result = 21 - exprs = (7 * FloorDiv(6, 2),) + exprs = ( + 7 * FloorDiv(6, 2), + 7 * FloorDiv(6.28, 2), + 7 * FloorDiv(6.28, 2.0), + 7 * FloorDiv(6.28, (FloorDiv(6.28, 3.14))), + ) for expr in exprs: self.assertEqual(expr, result) @@ -1319,10 +1382,33 @@ def test_floordiv_simplify(self): self.assertEqual(shape_env.simplify(expr), result) self.assertEqual(shape_env.evaluate_expr(expr), result) + def test_floordiv_simplify_rational(self): + result = 21 + + a = sympy.Symbol("a", integer=True) + b = sympy.Symbol("b") + + cases = [ + (FloorDiv(a, sympy.Rational(1, 8)), 8 * a), + (FloorDiv(b, sympy.Rational(1, 8)), sympy.floor(8 * b)), + ] + + for expr, expected in cases: + self.assertEqual(expr, expected) + def test_floordiv_assumptions(self): + # We define two Symbols (with different names) for each type to make + # sure the behavior is consistent regardless of whether both arguments + # are the same object or not. cases = ( sympy.Symbol("i1", integer=True), sympy.Symbol("i2", integer=True), + sympy.Symbol("r1", real=True), + sympy.Symbol("r2", real=True), + sympy.Symbol("c1", complex=True, real=False, integer=False), + sympy.Symbol("c2", complex=True, real=False, integer=False), + sympy.Symbol("s1"), + sympy.Symbol("s2"), ) for base, divisor in itertools.product(cases, repeat=2): diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 04483ffba0fc..c7b2e51ced20 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1618,8 +1618,7 @@ def f(a): self.assertExpectedInline(r, """\ def forward(self, a_1): sym_size_int = torch.ops.aten.sym_size.int(a_1, 0) - sym_float = torch.sym_float(sym_size_int); sym_size_int = None - pow_1 = sym_float ** 0.5; sym_float = None + pow_1 = sym_size_int ** 0.5; sym_size_int = None div = torch.ops.aten.div.Tensor(a_1, pow_1); a_1 = pow_1 = None return div""") diff --git a/test/test_sympy_utils.py b/test/test_sympy_utils.py index 8b16b2c620fd..c5da8f7fc0da 100644 --- a/test/test_sympy_utils.py +++ b/test/test_sympy_utils.py @@ -36,12 +36,7 @@ "floor", "ceil", ] -BINARY_OPS = [ - "truediv", "floordiv", - # "truncdiv", # TODO - # NB: pow is float_pow - "add", "mul", "sub", "pow", "pow_by_natural", "minimum", "maximum", "mod" -] +BINARY_OPS = ["truediv", "div", "floordiv", "truncdiv", "add", "mul", "sub", "pow", "minimum", "maximum", "mod"] UNARY_BOOL_OPS = ["not_"] BINARY_BOOL_OPS = ["or_", "and_"] @@ -86,24 +81,16 @@ def valid_unary(fn, v): def valid_binary(fn, a, b): if fn == "pow" and ( - # sympy will expand to x*x*... for integral b; don't do it if it's big - b > 4 - # no imaginary numbers - or a <= 0 - # 0**0 is undefined - or (a == b == 0) - ): - return False - elif fn == "pow_by_natural" and ( - # sympy will expand to x*x*... for integral b; don't do it if it's big b > 4 - or b < 0 - or (a == b == 0) + or ( # sympy will expand to x*x*... for integral b; don't do it if it's big + a <= 0 and b == -1 + ) + or (a == b == 0) # no imaginary numbers # 0**0 is undefined ): return False - elif fn == "mod" and (a < 0 or b <= 0): + elif fn == "mod" and b == 0: return False - elif (fn in ["div", "truediv", "floordiv"]) and b == 0: + elif (fn == "div" or fn == "truediv") and b == 0: return False return True @@ -143,26 +130,27 @@ def test_pow_half(self): ValueRangeAnalysis.pow(ValueRanges.unknown(), ValueRanges.wrap(0.5)) @parametrize("fn", BINARY_OPS) - @parametrize("dtype", ("int", "float")) - def test_binary_ref(self, fn, dtype): + @parametrize("dtype_a", ("int", "float")) + @parametrize("dtype_b", ("int", "float")) + def test_binary_ref(self, fn, dtype_a, dtype_b): to_dtype = {"int": sympy.Integer, "float": sympy.Float} - # Don't test float on int only methods - if dtype == "float" and fn in ["pow_by_natural", "mod"]: - return - dtype = to_dtype[dtype] + dtype_a = to_dtype[dtype_a] + dtype_b = to_dtype[dtype_b] for a, b in itertools.product(CONSTANTS, repeat=2): if not valid_binary(fn, a, b): continue - a = dtype(a) - b = dtype(b) + a = dtype_a(a) + b = dtype_b(b) with self.subTest(a=a, b=b): r = getattr(ValueRangeAnalysis, fn)(a, b) if r == ValueRanges.unknown(): continue ref_r = getattr(ReferenceAnalysis, fn)(a, b) - self.assertEqual(r.lower.is_integer, r.upper.is_integer) - self.assertEqual(ref_r.is_integer, r.upper.is_integer) + # sympy.floordiv does 1.0 // 1.0 == 1 rather than 1.0. wtf + if fn != "floordiv": + self.assertEqual(r.lower.is_integer, r.upper.is_integer) + self.assertEqual(ref_r.is_integer, r.upper.is_integer) self.assertEqual(r.lower, r.upper) self.assertEqual(ref_r, r.lower) @@ -212,8 +200,7 @@ def test_binary_bool_ref_range(self, fn): @parametrize("fn", UNARY_OPS) def test_unary_ref_range(self, fn): - # TODO: bring back sympy.oo testing for float unary fns - vals = CONSTANTS + vals = [-sympy.oo, *CONSTANTS, sympy.oo] for a in generate_range(vals): with self.subTest(a=a): ref_r = getattr(ValueRangeAnalysis, fn)(a) @@ -229,26 +216,40 @@ def test_unary_ref_range(self, fn): # This takes about 4s for all the variants @parametrize("fn", BINARY_OPS + COMPARE_OPS) def test_binary_ref_range(self, fn): - # TODO: bring back sympy.oo testing for float unary fns - vals = LESS_CONSTANTS + vals = [-sympy.oo, *LESS_CONSTANTS, sympy.oo] for a, b in itertools.product(generate_range(vals), repeat=2): # don't attempt pow on exponents that are too large (but oo is OK) if fn == "pow" and b.upper > 4 and b.upper != sympy.oo: continue with self.subTest(a=a, b=b): + ref_r = getattr(ValueRangeAnalysis, fn)(a, b) for a0, b0 in itertools.product(LESS_CONSTANTS, repeat=2): if a0 not in a or b0 not in b: continue if not valid_binary(fn, a0, b0): continue with self.subTest(a0=a0, b0=b0): - ref_r = getattr(ValueRangeAnalysis, fn)(a, b) r = getattr(ReferenceAnalysis, fn)( sympy.Integer(a0), sympy.Integer(b0) ) if r.is_finite: self.assertIn(r, ref_r) + def test_rational_bounds(self): + # Repro from https://github.com/pytorch/pytorch/issues/105097 + from sympy import floor, Eq + shape_0 = sympy.Symbol('shape_0', positive=True, integer=True) + new_expr = ( + Eq(30 * floor(4 * ((shape_0 + 1) // 96) * + ((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646)) / 647 + + 2584 * ((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646)) / 647), + 2880 * floor(((shape_0 + 1) // 96) * + ((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646)) / 15528 + + 323 * ((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646)) / 7764))) + new_range_env = {shape_0: ValueRanges(lower=1, upper=190)} + self.assertTrue(new_expr.subs({shape_0: 95})) + self.assertIn(True, sympy_interp(ValueRangeAnalysis, new_range_env, new_expr)) + class TestSympyInterp(TestCase): @parametrize("fn", UNARY_OPS + BINARY_OPS + UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS) @@ -257,13 +258,7 @@ def test_interp(self, fn): if fn in ("div", "truncdiv", "minimum", "maximum", "mod"): return - is_integer = None - if fn == "pow_by_natural": - is_integer = True - - x = sympy.Dummy('x', integer=is_integer) - y = sympy.Dummy('y', integer=is_integer) - + from sympy.abc import x, y vals = CONSTANTS if fn in {*UNARY_BOOL_OPS, *BINARY_BOOL_OPS}: vals = [True, False] @@ -305,17 +300,29 @@ def test_python_interp_fx(self, fn): if fn in {*BINARY_OPS, *BINARY_BOOL_OPS, *COMPARE_OPS}: arity = 2 - is_integer = None - if fn == "pow_by_natural": - is_integer = True - - x = sympy.Dummy('x', integer=is_integer) - y = sympy.Dummy('y', integer=is_integer) + from sympy.abc import x, y symbols = [x] if arity == 2: symbols = [x, y] + # Workaround mpf from symbol error + if fn == "minimum": + sympy_expr = sympy.Min(x, y) + elif fn == "maximum": + sympy_expr = sympy.Max(x, y) + else: + sympy_expr = getattr(ReferenceAnalysis, fn)(*symbols) + + if arity == 1: + def trace_f(px): + return sympy_interp(PythonReferenceAnalysis, {x: px}, sympy_expr) + else: + def trace_f(px, py): + return sympy_interp(PythonReferenceAnalysis, {x: px, y: py}, sympy_expr) + + gm = fx.symbolic_trace(trace_f) + for args in itertools.product(vals, repeat=arity): if arity == 1 and not valid_unary(fn, *args): continue @@ -323,28 +330,11 @@ def test_python_interp_fx(self, fn): continue if fn == "truncdiv" and args[1] == 0: continue - elif fn in ("pow", "pow_by_natural") and (args[0] == 0 and args[1] <= 0): + elif fn == "pow" and (args[0] == 0 and args[1] <= 0): continue elif fn == "floordiv" and args[1] == 0: continue with self.subTest(args=args): - # Workaround mpf from symbol error - if fn == "minimum": - sympy_expr = sympy.Min(x, y) - elif fn == "maximum": - sympy_expr = sympy.Max(x, y) - else: - sympy_expr = getattr(ReferenceAnalysis, fn)(*symbols) - - if arity == 1: - def trace_f(px): - return sympy_interp(PythonReferenceAnalysis, {x: px}, sympy_expr) - else: - def trace_f(px, py): - return sympy_interp(PythonReferenceAnalysis, {x: px, y: py}, sympy_expr) - - gm = fx.symbolic_trace(trace_f) - self.assertEqual( sympy_interp(PythonReferenceAnalysis, dict(zip(symbols, args)), sympy_expr), gm(*args) diff --git a/torch/__init__.py b/torch/__init__.py index dfb1da76739d..18f1752019ec 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -316,75 +316,6 @@ def __index__(self): # Magic methods installed by torch.fx.experimental.sym_node - def __round__(self, ndigits=None): - return self - - def __truediv__(self, other): - if isinstance(other, (builtins.float, SymFloat)): - return sym_float(self).__float_truediv__(other) - if not isinstance(other, (builtins.int, SymInt)): - return NotImplemented - return self.__int_truediv__(other) - - def __rtruediv__(self, other): - if isinstance(other, (builtins.float, SymFloat)): - return sym_float(self).__rfloat_truediv__(other) - if not isinstance(other, (builtins.int, SymInt)): - return NotImplemented - return self.__rint_truediv__(other) - - def __floordiv__(self, other): - if isinstance(other, (builtins.float, SymFloat)): - return torch.sym_float(math.floor(sym_float(self) / other)) - if not isinstance(other, (builtins.int, SymInt)): - return NotImplemented - return self.__int_floordiv__(other) - - def __rfloordiv__(self, other): - if isinstance(other, (builtins.float, SymFloat)): - return torch.sym_float(math.floor(other / sym_float(self))) - if not isinstance(other, (builtins.int, SymInt)): - return NotImplemented - return self.__rint_floordiv__(other) - - # nb: complex is impossible to handle correctly lol, with - # negative base and integral float need to diverge semantics and - # just always return complex. Neener neener pretend this problem - # doesn't exist - def __pow__(self, other): - if isinstance(other, (builtins.float, SymFloat)): - return sym_float(self).__pow__(other) - if not isinstance(other, (builtins.int, SymInt)): - return NotImplemented - # Guards! This guard is necessary because we need to know it to - # determine the output type of this operation - if other >= 0: - return self.__pow_by_natural__(other) - else: - # Mercifully, when the exponent is negative, Python just promotes - # to doubles and does a float pow: - # - # if (Py_SIZE(b) < 0 && c == NULL) { - # /* if exponent is negative and there's no modulus: - # return a float. This works because we know - # that this calls float_pow() which converts its - # arguments to double. */ - # Py_DECREF(a); - # Py_DECREF(b); - # return PyFloat_Type.tp_as_number->nb_power(v, w, x); - # } - return sym_float(self).__pow__(sym_float(other)) - - def __rpow__(self, other): - if isinstance(other, (builtins.float, SymFloat)): - return sym_float(self).__rpow__(other) - if not isinstance(other, (builtins.int, SymInt)): - return NotImplemented - if self >= 0: # self is exponent - return self.__rpow_by_natural__(other) - else: - return sym_float(self).__rpow__(sym_float(other)) - def __eq__(self, other: object) -> builtins.bool: raise AssertionError("type stub not overridden") @@ -406,24 +337,6 @@ def __add__(self, other) -> "SymInt": def __mul__(self, other) -> "SymInt": raise AssertionError("type stub not overridden") - def __pow_by_natural__(self, other) -> "SymInt": - raise AssertionError("type stub not overridden") - - def __rpow_by_natural__(self, other) -> "SymInt": - raise AssertionError("type stub not overridden") - - def __int_truediv__(self, other) -> "SymFloat": - raise AssertionError("type stub not overridden") - - def __rint_truediv__(self, other) -> "SymFloat": - raise AssertionError("type stub not overridden") - - def __int_floordiv__(self, other) -> "SymFloat": - raise AssertionError("type stub not overridden") - - def __rint_floordiv__(self, other) -> "SymFloat": - raise AssertionError("type stub not overridden") - def __sym_max__(self, other): raise AssertionError("type stub not overridden") @@ -458,43 +371,9 @@ def __init__(self, node): # class has a field named node that stores SymNode self.node = node - def __truediv__(self, other): - if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): - return NotImplemented - return self.__float_truediv__(sym_float(other)) - - def __rtruediv__(self, other): - if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): - return NotImplemented - return self.__rfloat_truediv__(sym_float(other)) - - def __floordiv__(self, other): - if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): - return NotImplemented - return torch.sym_float(math.floor(self / sym_float(other))) - - def __rfloordiv__(self, other): - if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): - return NotImplemented - return torch.sym_float(math.floor(sym_float(other) / self)) - def __bool__(self): return self.node.bool_() - # Symbolic power does NOT work with negative base, this is to avoid - # potential complex outputs - def __pow__(self, other): - if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): - return NotImplemented - torch._check(self >= 0) - return self.__float_pow__(other) - - def __rpow__(self, other): - if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): - return NotImplemented - torch._check(other >= 0) - return self.__rfloat_pow__(other) - # Magic methods installed by torch.fx.experimental.sym_node def __eq__(self, other: object) -> builtins.bool: @@ -512,18 +391,6 @@ def __le__(self, other) -> builtins.bool: def __ge__(self, other) -> builtins.bool: raise AssertionError("type stub not overridden") - def __float_pow__(self, other) -> "SymFloat": - raise AssertionError("type stub not overridden") - - def __rfloat_pow__(self, other) -> "SymFloat": - raise AssertionError("type stub not overridden") - - def __float_truediv__(self, other) -> "SymFloat": - raise AssertionError("type stub not overridden") - - def __rfloat_truediv__(self, other) -> "SymFloat": - raise AssertionError("type stub not overridden") - def __trunc__(self): raise AssertionError("type stub not overridden") @@ -657,12 +524,7 @@ def sym_int(a): return py_int(a) # type: ignore[operator] def sym_max(a, b): - """ - SymInt-aware utility for max which avoids branching on a < b. - Unlike builtins.max(), this only works for int/float, and it always - promotes to float if any argument is float (unlike builtins.max, which - will faithfully preserve the type of the input argument). - """ + """ SymInt-aware utility for max().""" from .overrides import has_torch_function, handle_torch_function if has_torch_function((a, b)): @@ -670,19 +532,14 @@ def sym_max(a, b): if isinstance(a, (SymInt, SymFloat)): return a.__sym_max__(b) elif isinstance(b, (SymInt, SymFloat)): - # Due to promotion semantics, this is operator is commutative: - # max(1, 1.0) === max(1.0, 1) === 1.0 + # NB: If you actually care about preserving output type exactly + # if you do something like max(0, 0.0), it is NOT sound to treat + # min/max as commutative return b.__sym_max__(a) - # TODO: Probably can make bool work too, just lazy - assert isinstance(a, (builtins.int, builtins.float)), type(a) - assert isinstance(b, (builtins.int, builtins.float)), type(b) - if isinstance(a, builtins.float) or isinstance(b, builtins.float): - return builtins.float(builtins.max(a, b)) - else: - return builtins.max(a, b) + return builtins.max(a, b) # type: ignore[operator] def sym_min(a, b): - """ SymInt-aware utility for min().""" + """ SymInt-aware utility for max().""" from .overrides import has_torch_function, handle_torch_function if has_torch_function((a, b)): @@ -691,12 +548,7 @@ def sym_min(a, b): return a.__sym_min__(b) elif isinstance(b, (SymInt, SymFloat)): return b.__sym_min__(a) - assert isinstance(a, (builtins.int, builtins.float)), type(a) - assert isinstance(b, (builtins.int, builtins.float)), type(b) - if isinstance(a, builtins.float) or isinstance(b, builtins.float): - return builtins.float(builtins.min(a, b)) - else: - return builtins.min(a, b) + return builtins.min(a, b) # type: ignore[operator] # Drop in replacement for math.sqrt, math.sin, math.cos etc current_module = sys.modules[__name__] diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index 9a92c238f950..8d6dc939fb5c 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -1474,15 +1474,10 @@ def deserialize_sym_int(self, s: SymInt) -> Union[int, torch.SymInt]: # Here we force symbols corresponding to SymInts to be at least integers. # Otherwise some expressions that the shape env would otherwise evaluate to False, # e.g., 2*s = 9, can have rational solutions, e.g., 9/2. - # TODO: This is HIGHLY SUSPICIOUS ezyang(May 2024) sym = sym.subs( {s: sympy.Symbol(s.name, integer=True) for s in sym.free_symbols} ) - # We need to check if the symbol has already been allocated, - # self.symbol_name_to_symbol is not enough because the - # integer-ification of symbols can induce simplification; - # e.g., (2**s0 + 1) // 2 --> s0 when we know s0 is integral - if isinstance(sym, sympy.Symbol) and sym not in self.shape_env.var_to_val: + if isinstance(sym, sympy.Symbol): self.symbol_name_to_symbol[val.expr_str] = sym if hint is not None: self.shape_env.add_var_to_val(sym, hint) @@ -1501,7 +1496,7 @@ def deserialize_sym_int(self, s: SymInt) -> Union[int, torch.SymInt]: free_symbols = sym.free_symbols for s in free_symbols: if s.name not in self.symbol_name_to_symbol: - self.symbol_name_to_symbol[s.name] = s # type: ignore[assignment] + self.symbol_name_to_symbol[s.name] = s if vr := self.symbol_name_to_range.get(s.name): self.shape_env.constrain_symbol_range( s, diff --git a/torch/_inductor/bounds.py b/torch/_inductor/bounds.py index 212b79e35bf9..4640ec4dce6b 100644 --- a/torch/_inductor/bounds.py +++ b/torch/_inductor/bounds.py @@ -1,4 +1,3 @@ -import logging import operator from functools import partial from typing import Any, Callable, Dict @@ -12,9 +11,6 @@ from .virtualized import V -log = logging.getLogger(__name__) - - class BoundVars: """ Performs Value Range Analysis on LoopBody's fx graph by calling BoundVars.run() @@ -59,7 +55,6 @@ def get_bounds(self) -> Dict[torch.fx.Node, ValueRanges[Expr]]: with V.set_ops_handler(ValueRangeAnalysis()): interpreter = InterpreterShim(self.loop_body.root_block.graph, submodules) - log.debug("get_bounds:\n%s", self.loop_body.root_block.graph) interpreter.run(V.get_ops_handler(), initial_env=self._bounds) return self._bounds diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 9f4783a8fc59..f7b3e7a45d6e 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -340,8 +340,6 @@ def propagate_scheduler_node(cls, node): DataTypePropagation.propagate_loopbody(node._body) -# This printer contains rules that are supposed to be generic for both C/C++ and -# Python class ExprPrinter(Printer): @staticmethod def paren(string): @@ -371,6 +369,12 @@ def all_in_parens(string): return string return f"({string})" + def _print_Infinity(self, expr): + return "math.inf" + + def _print_NegativeInfinity(self, expr): + return "-math.inf" + def _print_Relational(self, expr): return f" {expr.rel_op} ".join(map(self.paren, map(self._print, expr.args))) @@ -380,14 +384,11 @@ def _print_Mul(self, expr): def _print_Add(self, expr): return " + ".join(map(self.paren, map(self._print, expr.args))) - # NB: this is OK to put here, because Mod is only defined for positive - # numbers, and so across C/Python its behavior is consistent def _print_Mod(self, expr): return " % ".join(map(self.paren, map(self._print, expr.args))) - def _print_FloatTrueDiv(self, expr): - lhs, rhs = expr.args - return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}" + def _print_FloorDiv(self, expr): + raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}") def _print_CleanDiv(self, expr): return self._print_FloorDiv(expr) @@ -398,84 +399,10 @@ def _print_GreaterThan(self, expr): # Go figure... return " >= ".join(map(self.paren, map(self._print, expr.args))) - # NB: The C implementation is injected into codegen at - # torch/_inductor/codegen/wrapper.py def _print_align(self, expr): assert len(expr.args) == 1 return f"align({self._print(expr.args[0])})" - # This must be implemented because sympy will collect x * x into Pow(x, 2), without - # any explicit intervention. We print it just like x * x, notably, we - # never generate sympy.Pow with floats. - # - # NB: this pow by natural, you should never have used builtin sympy.pow - # for FloatPow, and a symbolic exponent should be PowByNatural. These - # means exp is guaranteed to be integer. - def _print_Pow(self, expr): - base, exp = expr.args - base = self._print(base) - assert exp == int(exp), exp - exp = int(exp) - assert exp >= 0 - if exp > 0: - return "*".join([self.paren(base)] * exp) - else: # exp == 0 - return "1" - - # Explicit NotImplemented functions are to prevent default sympy printing - # behavior, which will just barf out ToFloat(...) to your IR. The error - # message is better here because it tells you which printer class it needs - # to go in. - - def _print_ToFloat(self, expr): - raise NotImplementedError(f"_print_ToFloat not implemented for {type(self)}") - - def _print_Infinity(self, expr): - raise NotImplementedError(f"_print_Infinity not implemented for {type(self)}") - - def _print_NegativeInfinity(self, expr): - raise NotImplementedError( - f"_print_NegativeInfinity not implemented for {type(self)}" - ) - - def _print_FloorDiv(self, expr): - raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}") - - def _print_PythonMod(self, expr): - raise NotImplementedError(f"_print_PythonMod not implemented for {type(self)}") - - def _print_IntTrueDiv(self, expr): - raise NotImplementedError(f"_print_IntTrueDiv not implemented for {type(self)}") - - def _print_PowByNatural(self, expr): - raise NotImplementedError( - f"_print_PowByNatural not implemented for {type(self)}" - ) - - def _print_FloatPow(self, expr): - raise NotImplementedError(f"_print_FloatPow not implemented for {type(self)}") - - def _print_TruncToInt(self, expr): - raise NotImplementedError(f"_print_TruncToInt not implemented for {type(self)}") - - def _print_RoundToInt(self, expr): - raise NotImplementedError(f"_print_RoundToInt not implemented for {type(self)}") - - def _print_RoundDecimal(self, expr): - raise NotImplementedError( - f"_print_RoundDecimal not implemented for {type(self)}" - ) - - # NB: Some float operations are INTENTIONALLY not implemented for - # printers. You can implement them as a quick unblock, but it is better - # to ask yourself why we haven't done this computation in the Tensor - # universe instead - - def _print_TruncToFloat(self, expr): - raise NotImplementedError( - f"_print_TruncToFloat not implemented for {type(self)}" - ) - def doprint(self, expr, *, simplify: bool = True): # TODO: why are people passing strings to the printer here :think: if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"): @@ -484,10 +411,6 @@ def doprint(self, expr, *, simplify: bool = True): class PythonPrinter(ExprPrinter): - def _print_ToFloat(self, expr): - assert len(expr.args) == 1 - return f"float({self._print(expr.args[0])})" - def _print_ModularIndexing(self, expr): x, div, mod = expr.args x = self.paren(self.doprint(x)) @@ -497,51 +420,46 @@ def _print_ModularIndexing(self, expr): x = f"({x} // {div})" return f"{x} % {mod}" - def _print_Infinity(self, expr): - return "math.inf" - - def _print_NegativeInfinity(self, expr): - return "-math.inf" - - # WARNING: this is dangerous for Triton, which has C-style modulus - def _print_PythonMod(self, expr): - return " % ".join(map(self.paren, map(self._print, expr.args))) - - # WARNING: this is dangerous for Triton, which has C-style modulus def _print_FloorDiv(self, expr): x, div = expr.args x = self.paren(self.doprint(x)) div = self.paren(self.doprint(div)) return f"({x} // {div})" - # WARNING: this is dangerous for Triton, when lhs, rhs > 2**53, Python - # does a special algorithm - def _print_IntTrueDiv(self, expr): - lhs, rhs = expr.args - return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}" - def _helper_sqrt(self, expr): return f"math.sqrt({self._print(expr)})" def _print_OpaqueUnaryFn_sqrt(self, expr): return self._helper_sqrt(expr.args[0]) - def _print_FloatPow(self, expr): - base, exp = expr.args - return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}" - - # TODO: Not sure this works with Triton, even when base/exp are integral - def _print_PowByNatural(self, expr): + def _print_Pow(self, expr): + # Pow() confuses triton base, exp = expr.args - return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}" + # NB: Remember this is sizevar computation! You don't typically + # expect to have to do floating point computation including exponents + # in sizevar compute. Instead of adding support for floating + # point pow, you should make upstream retranslate the Sympy expression + # into Tensor expressions earlier and do that instead. + if exp == 0.5: + return self._helper_sqrt(base) + elif exp == -0.5: + return "1/" + self._helper_sqrt(base) + base = self._print(base) + assert exp == int(exp), exp + exp = int(exp) + if exp > 0: + return "*".join([self.paren(base)] * exp) + elif exp < 0: + return "1/" + self.paren("*".join([self.paren(base)] * abs(exp))) + else: # exp == 0 + return "1" def _print_floor(self, expr): assert len(expr.args) == 1 return f"math.floor({self._print(expr.args[0])})" - def _print_TruncToInt(self, expr): + def _print_Trunc(self, expr): assert len(expr.args) == 1 - # This also could have been int(), they'll do the same thing for float return f"math.trunc({self._print(expr.args[0])})" def _print_ceiling(self, expr): @@ -552,9 +470,6 @@ def _print_Abs(self, expr): assert len(expr.args) == 1 return f"abs({self._print(expr.args[0])})" - # NB: It's expected that we've made explicit any promotion in the sympy - # expression, so it doesn't matter that Python max/min doesn't perform - # promotion def _print_Max(self, expr): assert len(expr.args) >= 2 return f"max({', '.join(map(self._print, expr.args))})" @@ -599,7 +514,7 @@ def _print_OpaqueUnaryFn_atan(self, expr): assert len(expr.args) == 1 return f"math.atan({self._print(expr.args[0])})" - def _print_RoundToInt(self, expr): + def _print_Round(self, expr): assert len(expr.args) == 1 return f"round({self._print(expr.args[0])})" @@ -738,29 +653,6 @@ def remainder(a, b): ) return ops.where(cond, ops.add(r, b), r) - @staticmethod - def trunc_to_int(a, dtype): - return ops.to_dtype(ops.trunc(a), dtype) - - @staticmethod - def floor_to_int(a, dtype): - return ops.to_dtype(ops.floor(a), dtype) - - @staticmethod - def ceil_to_int(a, dtype): - return ops.to_dtype(ops.ceil(a), dtype) - - @staticmethod - def round_to_int(a, dtype): - return ops.to_dtype(ops.round(a), dtype) - - @staticmethod - def int_truediv(a, b): - # TODO: this is wrong - # TODO: an easy bandaid is to generate runtime asserts that it's - # <= 2**53, which is when this equation is correct - return ops.truediv(a, b) - @staticmethod def load_seed(name, offset): return ops.load(name, sympy.Integer(offset)) diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 311781102c3f..eabb5bbef470 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -275,11 +275,11 @@ def visit_modular_indexing(divisor, modulus): original_index = index - div = sympy.Wild("divisor", integer=True) + div = sympy.Wild("divisor") if index.has(FloorDiv): index = index.replace(FloorDiv(var, div), visit_indexing_div) - mod = sympy.Wild("modulus", integer=True) + mod = sympy.Wild("modulus") if index.has(ModularIndexing): index = index.replace(ModularIndexing(var, div, mod), visit_modular_indexing) diff --git a/torch/_inductor/codegen/cpp_utils.py b/torch/_inductor/codegen/cpp_utils.py index 79884364420a..4ab33a5e26dc 100644 --- a/torch/_inductor/codegen/cpp_utils.py +++ b/torch/_inductor/codegen/cpp_utils.py @@ -100,48 +100,10 @@ def _print_floor(self, expr): r = f"std::floor({self._print(expr.args[0])})" return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r - def _print_TruncToInt(self, expr): + def _print_Trunc(self, expr): assert len(expr.args) == 1 r = f"std::trunc({self._print(expr.args[0])})" - return f"static_cast<{INDEX_TYPE}>({r})" - - def _print_TruncToFloat(self, expr): - assert len(expr.args) == 1 - return f"std::trunc({self._print(expr.args[0])})" - - def _print_ToFloat(self, expr): - assert len(expr.args) == 1 - return f"static_cast({self._print(expr.args[0])})" - - # TODO: This is wrong if one of the inputs is negative. This is hard to - # tickle though, as the inputs are typically positive (and if we can prove - # they are positive, we will have used Mod instead, for which this codegen - # is right). - def _print_PythonMod(self, expr): - return " % ".join(map(self.paren, map(self._print, expr.args))) - - def _print_CMod(self, expr): - return " % ".join(map(self.paren, map(self._print, expr.args))) - - def _print_IntTrueDiv(self, expr): - lhs, rhs = expr.args - # TODO: This is only accurate up to 2**53 - return f"static_cast({self._print(lhs)}) / static_cast({self._print(rhs)})" - - # TODO: PowByNatural: we need to implement our own int-int pow. Do NOT - # use std::pow, that operates on floats - def _print_PowByNatural(self, expr): - raise NotImplementedError( - f"_print_PowByNatural not implemented for {type(self)}" - ) - - def _print_FloatTrueDiv(self, expr): - lhs, rhs = expr.args - return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}" - - def _print_FloatPow(self, expr): - base, exp = expr.args - return f"std::pow({self._print(base)}, {self._print(exp)})" + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r def _print_Pow(self, expr): # Uses float constants to perform FP div @@ -238,9 +200,8 @@ def _print_OpaqueUnaryFn_atan(self, expr): def _print_OpaqueUnaryFn_sqrt(self, expr): return f"std::sqrt({self._print(expr.args[0])})" - def _print_RoundToInt(self, expr): + def _print_Round(self, expr): assert len(expr.args) == 1 - # TODO: dispatch to llrint depending on index type return f"std::lrint({self._print(expr.args[0])})" def _print_RoundDecimal(self, expr): diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 066b6545a0a2..4b0ea92f3bf4 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -272,52 +272,17 @@ def triton_reshape(value: str, old_shape: List[str], new_shape: List[str]): return f"{value}[{', '.join(expand)}]" -# NB: Inheriting from PythonPrinter is somewhat dangerous, because there are a -# number of operators which Triton "implements", but in a way that is -# inconsistent with Python semantics (and consistent with C semantics). We -# must override all of these, or it is potential silent correctness problem class TritonPrinter(PythonPrinter): - def _print_TruncToInt(self, expr): + def _print_floor(self, expr): assert len(expr.args) == 1 return ( - f"libdevice.trunc({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + f"libdevice.floor({self._print(expr.args[0])}).to({V.kernel.index_dtype})" ) - def _print_ToFloat(self, expr): - assert len(expr.args) == 1 - return f"{self.paren(self._print(expr.args[0]))}.to(tl.float64)" - - # TODO: This is wrong if one of the inputs is negative. This is hard to - # tickle though, as the inputs are typically positive (and if we can prove - # they are positive, we will have used Mod instead, for which this codegen - # is right). If you are trying to hit this, maybe try something like - # torch.arange(n, device="cuda") - 1 and then do a modulus on it - def _print_PythonMod(self, expr): - return " % ".join(map(self.paren, map(self._print, expr.args))) - - # TODO: This is wrong, see - # https://github.com/triton-lang/triton/issues/955 - # But for Sympy expressions, things will /mostly/ work out because we - # don't usually deal with negative numbers in the division - def _print_FloorDiv(self, expr): - assert expr.is_integer - x, div = expr.args - x = self.paren(self.doprint(x)) - div = self.paren(self.doprint(div)) - return f"({x} // {div})" - - # TODO: This is wrong, when lhs, rhs > 2**53, Python does a higher - # precision algorithm, which we would need to replicate here - def _print_IntTrueDiv(self, expr): - lhs, rhs = expr.args - return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}" - - # NB: sympy.floor/ceiling produce integers, so we have to do the - # conversion to index dtype - def _print_floor(self, expr): + def _print_Trunc(self, expr): assert len(expr.args) == 1 return ( - f"libdevice.floor({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + f"libdevice.trunc({self._print(expr.args[0])}).to({V.kernel.index_dtype})" ) def _print_ceiling(self, expr): @@ -394,9 +359,20 @@ def _print_OpaqueUnaryFn_atan(self, expr): assert len(expr.args) == 1 return f"libdevice.atan(({self._print(expr.args[0])}).to(tl.float32))" - def _print_RoundToInt(self, expr): + def _print_FloorDiv(self, expr): + if expr.is_integer: + return super()._print_FloorDiv(expr) + + x, div = expr.args + x = self.paren(self.doprint(x)) + div = self.paren(self.doprint(div)) + return f"libdevice.floor({x} / {div}).to({V.kernel.index_dtype})" + + def _print_Round(self, expr): assert len(expr.args) == 1 - return f"libdevice.llrint({self._print(expr.args[0])})" + return ( + f"libdevice.llrint({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + ) def _print_RoundDecimal(self, expr): assert len(expr.args) == 2 diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index abe93686ac83..337a7375afa8 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -1196,11 +1196,8 @@ def debug(msg): elif is_magic_method(n.target): # TODO: this is sus, it probably should be handled in the # lowerings themselves similarly to sym_size/sym-stride - # https://github.com/pytorch/pytorch/issues/127789 debug("is_magic_method") - if isinstance( - n.meta["val"], (torch.SymInt, torch.SymFloat, torch.SymBool) - ): + if isinstance(n.meta["val"], torch.SymInt): result = n.meta["val"].node.expr else: result = super().run_node(n) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index e9adfcd19a2d..c46cad5e41e2 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -44,6 +44,7 @@ is_boolean_dtype, is_float_dtype, make_channels_last_strides_for, + make_contiguous_strides_for, StrideType, ) from torch._subclasses.fake_tensor import get_schema_info @@ -235,7 +236,7 @@ def ir_node_to_tensor(x, guard_shape=True): if is_storage_and_layout(x): stride = [shape_fn(s) for s in x.get_layout().stride] # type: ignore[misc] else: - stride = FlexibleLayout.contiguous_strides(size) # type: ignore[arg-type] + stride = make_contiguous_strides_for(size) # type: ignore[arg-type] dtype = x.get_dtype() device = x.get_device() size = convert_shape_to_symint(size) @@ -2765,7 +2766,6 @@ class FlexibleLayout(Layout): allow_indexing = False - # WARNING! This doesn't handle zero size tensors correctly @staticmethod def contiguous_strides(sizes): if len(sizes) == 0: @@ -5915,7 +5915,7 @@ def _original_deconv_weight_size( # To align the behavior of the Conv kernel, we set the output_stride in such case to be contiguous instead of channels last. dynamic_shapes = not all(isinstance(i, int) for i in (output_size)) if dynamic_shapes and is_contiguous_storage_and_layout(x): - output_stride = FlexibleLayout.contiguous_strides(output_size) + output_stride = make_contiguous_strides_for(output_size) else: output_stride = make_channels_last_strides_for(output_size) @@ -5967,7 +5967,7 @@ def _prepare_linear_fusion_create( assert x.get_device().type == "cpu" and weight.get_device().type == "cpu" inputs = [x, weight] - output_stride = FlexibleLayout.contiguous_strides(output_size) + output_stride = make_contiguous_strides_for(output_size) kernel_layout = FixedLayout( x.get_device(), x.get_dtype(), @@ -6283,7 +6283,7 @@ def create(cls, x, packed_w, orig_w, B, batch_size): *m, _ = x.get_size() oc, _ = orig_w.get_size() output_size = list(m) + [oc] - output_stride = FlexibleLayout.contiguous_strides(output_size) + output_stride = make_contiguous_strides_for(output_size) inputs = [x, packed_w, orig_w] constant_args = [batch_size] if B is not None: @@ -6601,13 +6601,13 @@ def create( def get_strides_of_lstm_output(output_shape, batch_first): assert len(output_shape) == 3, "Expect output_shape to be 3D" - return FlexibleLayout.contiguous_strides(output_shape) + return make_contiguous_strides_for(output_shape) output_sizes = [output_shape, hy_shape, cy_shape] output_strides = [ get_strides_of_lstm_output(output_shape, batch_first), - FlexibleLayout.contiguous_strides(hy_shape), - FlexibleLayout.contiguous_strides(cy_shape), + make_contiguous_strides_for(hy_shape), + make_contiguous_strides_for(cy_shape), ] output_ir = [ MultiOutput( diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index f3492949a84d..42fabf65591d 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -5,6 +5,7 @@ from typing import Any, List, Tuple import torch +from torch._prims_common import make_contiguous_strides_for from .. import config from ..ir import ( ComputedBuffer, @@ -388,7 +389,7 @@ def flex_attention(*args, **kwargs): query.get_device(), query.get_dtype(), query.get_size(), - FlexibleLayout.contiguous_strides(query.get_size()), + make_contiguous_strides_for(query.get_size()), ) # see NOTE:[TritonTemplates with multiple outputs] logsumexp_shape = query.get_size()[:-1] # [B, H, M] @@ -744,7 +745,7 @@ def flex_attention_backward(*args, **kwargs): key.get_device(), key.get_dtype(), key.get_size(), - FlexibleLayout.contiguous_strides(key.get_size()), + make_contiguous_strides_for(key.get_size()), ) # Create delta which will is needed for the bwd's kernel diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 300cf71c2934..20b0082eb1d9 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -34,7 +34,7 @@ Number, ) from torch.fx.experimental.sym_node import magic_methods, method_to_operator -from torch.utils._sympy.functions import CeilDiv, FloorDiv, IntTrueDiv, ModularIndexing +from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing from .._dynamo.utils import import_submodule from . import config, inductor_prims, ir, test_operators # NOQA: F401 @@ -4262,7 +4262,7 @@ def _fractional_pooling_offsets(samples, in_sz, out_sz, kernel_sz, dim): out_sz = out_sz[dim] in_sz = in_sz[dim] kernel_sz = kernel_sz[dim] - alpha = IntTrueDiv(in_sz - kernel_sz, out_sz - 1) + alpha = (in_sz - kernel_sz) / (out_sz - 1) samples_loader = samples.make_loader() def load(prefix, i): @@ -4372,7 +4372,7 @@ def upsample_nearest2d_backward( w_kernel_max = ceildiv(inp_w, out_w) def start_index(index, out_dim, inp_dim): - return CeilDiv(index * inp_dim, sympy.sympify(out_dim)) + return CeilDiv(index * inp_dim, out_dim) def end_index(index, out_dim, inp_dim): return start_index((index + 1), out_dim, inp_dim) diff --git a/torch/_inductor/ops_handler.py b/torch/_inductor/ops_handler.py index f88cd948ca4d..5630061b4426 100644 --- a/torch/_inductor/ops_handler.py +++ b/torch/_inductor/ops_handler.py @@ -138,38 +138,6 @@ def to_dtype( """ ... - def trunc_to_int(self, x: T, dtype: torch.dtype) -> T: - """ - Convert x to dtype with truncation semantics (similar to how the int - constructor works in Python). In Inductor codegen, this just decays - to trunc and then to_dtype, but this composite operation helps - roundtrips for Sympy evaluation. - - dtype is taken as an explicit parameter because the desired output - dtype is typically the index dtype, which may vary between int32 and - int64 depending on if we've shown that all the indexing operations can - be done in int32. - """ - ... - - def ceil_to_int(self, x: T, dtype: torch.dtype) -> T: - """ - Convert x to dtype with ceiling semantics. See also trunc_to_int. - """ - ... - - def floor_to_int(self, x: T, dtype: torch.dtype) -> T: - """ - Convert x to dtype with ceiling semantics. See also trunc_to_int. - """ - ... - - def round_to_int(self, x: T, dtype: torch.dtype) -> T: - """ - Convert x to dtype with round-to-even semantics. See also trunc_to_int. - """ - ... - def to_dtype_bitcast(self, x: T, dtype: torch.dtype, src_dtype: torch.dtype) -> T: """ Reinterpret cast x to dtype (reinterpreting the bits in memory as another dtype.) @@ -430,23 +398,21 @@ def isinf(self, x0: T) -> T: def isnan(self, x0: T) -> T: ... - # NB: this returns a float, like the torch operation - # This rounds half to even to break ties def round(self, x0: T) -> T: ... - # NB: this returns a float, like the torch operation def floor(self, x0: T) -> T: ... def sign(self, x0: T) -> T: ... - # NB: this returns a float, like the torch operation + def to_int(self, x0: T) -> T: + ... + def trunc(self, x0: T) -> T: ... - # NB: this returns a float, like the torch operation def ceil(self, x0: T) -> T: ... @@ -483,7 +449,6 @@ def sub(self, x0: T, x1: T) -> T: def mul(self, x0: T, x1: T) -> T: ... - # NB: this returns a float, like the torch operation def pow(self, x0: T, x1: T) -> T: ... @@ -652,21 +617,14 @@ def truncdiv(self, x0: T, x1: T) -> T: def floordiv(self, x0: T, x1: T) -> T: """Python-style floor division between integers only. Computes the - true division of two numbers and floors the result. If you want - floor division for floats, do regular truediv and floor the result. + true division of two numbers and floors the result. """ ... def truediv(self, x0: T, x1: T) -> T: - """True division between floats. Integer inputs are NOT valid. To - do Python-style (int, int) -> float division, use int_truediv""" - ... - - def int_truediv(self, x0: T, x1: T) -> T: - """True division between integers. This is NOT the same as promoting - to float and doing integer division, there is a bespoke algorithm for - doing the division in higher precision than the above. - """ + """True division between floats. Integer inputs are NOT valid: to do + Python style (int, int) -> float division, promote the inputs to float + first.""" ... def div(self, x0: T, x1: T) -> T: @@ -682,10 +640,6 @@ def remainder(self, x0: T, x1: T) -> T: """Python-style modulus, take sign from RHS (x1).""" ... - def round_decimal(self, x0: T, x1: T) -> T: - """Python-style round with decimal argument""" - ... - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # In CUDA, optimized implementations of other mathematical operations are # offered separately via libdevice for double precision computation (in diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 85d0d0f1954c..bc89441e3bd8 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -385,7 +385,7 @@ def store_output( assert isinstance(mask, (str, type(None))) assert self.template_mask is None indices = list(map(TritonPrinter.paren, indices)) - index_symbols = [sympy.Symbol(x, integer=True) for x in indices] + index_symbols = [sympy.Symbol(x) for x in indices] lengths = [ V.graph.sizevars.simplify(s) for s in self.output_node.get_size() ] @@ -409,7 +409,7 @@ def store_output( output_index = self.output_node.get_layout().make_indexer()(index_symbols) output_index = self.rename_indexing(output_index) if output_index == contiguous_index: - output_index = sympy.Symbol("xindex", integer=True) + output_index = sympy.Symbol("xindex") epilogue_args = [val] for input_node in itertools.chain( diff --git a/torch/_inductor/sizevars.py b/torch/_inductor/sizevars.py index fba9a66f9237..bc8803a5e715 100644 --- a/torch/_inductor/sizevars.py +++ b/torch/_inductor/sizevars.py @@ -161,9 +161,9 @@ def visit_modular_indexing(base, divisor, modulus): if expr.has(ModularIndexing): expr = expr.replace( ModularIndexing( - sympy.Wild("base", integer=True), - sympy.Wild("divisor", integer=True), - sympy.Wild("modulus", integer=True), + sympy.Wild("base"), + sympy.Wild("divisor"), + sympy.Wild("modulus"), ), visit_modular_indexing, ) @@ -171,8 +171,8 @@ def visit_modular_indexing(base, divisor, modulus): if expr.has(FloorDiv): expr = expr.replace( FloorDiv( - sympy.Wild("base", integer=True), - sympy.Wild("divisor", integer=True), + sympy.Wild("base"), + sympy.Wild("divisor"), ), visit_indexing_div, ) @@ -604,11 +604,11 @@ def _join_dimensions_cached(expr: Expr) -> Expr: """ assert isinstance(expr, sympy.Add) - scale = sympy.Wild("scale", exclude=[0], integer=True) - base = sympy.Wild("base", integer=True) - divisor = sympy.Wild("divisor", integer=True) - mod1 = sympy.Wild("modulus", integer=True) - mod2 = sympy.Wild("modulus2", integer=True) + scale = sympy.Wild("scale", exclude=[0]) + base = sympy.Wild("base") + divisor = sympy.Wild("divisor") + mod1 = sympy.Wild("modulus") + mod2 = sympy.Wild("modulus2") for term1 in expr.args: m1 = term1.match(scale * ModularIndexing(base, divisor, mod1)) if m1: diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index a635c2f509c1..0915a8330c34 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -192,7 +192,7 @@ def ceildiv( numer: Union[int, sympy.Expr], denom: Union[int, sympy.Expr] ) -> Union[int, sympy.Expr]: if isinstance(numer, sympy.Expr) or isinstance(denom, sympy.Expr): - return CeilDiv(sympy.sympify(numer), sympy.sympify(denom)) + return CeilDiv(numer, denom) # TODO: There is a bug in a call to this function, to repro: # python benchmarks/dynamo/huggingface.py --inductor -d cuda --accuracy # --amp --only YituTechConvBert --dynamic-shapes diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 9343490de3e8..47d4abcf77b9 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -1727,7 +1727,7 @@ def go(t, real_t): for run_impl_check, op_impl in op_implementations_checks: if run_impl_check(func): op_impl_out = op_impl(self, func, *args, **kwargs) - if op_impl_out is not NotImplemented: + if op_impl_out != NotImplemented: return maybe_propagate_real_tensors(op_impl_out) def maybe_run_unsafe_fallback(error=None): diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index 2a3cb62c56d7..a7ce337f9ac8 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -1200,13 +1200,8 @@ void initJITBindings(PyObject* module) { SYMNODE_BINARY(sub) SYMNODE_BINARY(mul) SYMNODE_BINARY(truediv) - SYMNODE_BINARY(int_truediv) - SYMNODE_BINARY(float_truediv) SYMNODE_BINARY(pow) - SYMNODE_BINARY(float_pow) - SYMNODE_BINARY(pow_by_natural) SYMNODE_BINARY(floordiv) - SYMNODE_BINARY(int_floordiv) SYMNODE_BINARY(mod) SYMNODE_BINARY(eq) SYMNODE_BINARY(ne) diff --git a/torch/csrc/utils/python_symnode.h b/torch/csrc/utils/python_symnode.h index 15738b1a67e1..f8c710cf6579 100644 --- a/torch/csrc/utils/python_symnode.h +++ b/torch/csrc/utils/python_symnode.h @@ -198,34 +198,14 @@ class PythonSymNodeImpl : public c10::SymNodeImpl { return dispatch_common_(__func__, other); } - c10::SymNode float_truediv(const c10::SymNode& other) override { - return dispatch_common_(__func__, other); - } - - c10::SymNode int_truediv(const c10::SymNode& other) override { - return dispatch_common_(__func__, other); - } - c10::SymNode pow(const c10::SymNode& other) override { return dispatch_common_(__func__, other); } - c10::SymNode float_pow(const c10::SymNode& other) override { - return dispatch_common_(__func__, other); - } - - c10::SymNode pow_by_natural(const c10::SymNode& other) override { - return dispatch_common_(__func__, other); - } - c10::SymNode floordiv(const c10::SymNode& other) override { return dispatch_common_(__func__, other); } - c10::SymNode int_floordiv(const c10::SymNode& other) override { - return dispatch_common_(__func__, other); - } - c10::SymNode mod(const c10::SymNode& other) override { return dispatch_common_(__func__, other); } diff --git a/torch/export/dynamic_shapes.py b/torch/export/dynamic_shapes.py index ac2bdd60a550..a4ed16e975b8 100644 --- a/torch/export/dynamic_shapes.py +++ b/torch/export/dynamic_shapes.py @@ -1,6 +1,7 @@ import builtins import dataclasses import inspect +import math import sys import weakref from collections import defaultdict @@ -253,14 +254,11 @@ class _Constraint(_ConstraintTarget, metaclass=_ConstraintFactory): shared: Optional[_ConstraintTarget] = None debug_name: Optional[str] = None - def _clone_with_range(self, lower=0, upper=None): + def _clone_with_range(self, lower=0, upper=math.inf): # Import sympy locally from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint from torch.utils._sympy.value_ranges import ValueRanges - if upper is None: - upper = sys.maxsize - 1 - constraint_range = StrictMinMaxConstraint( vr=self.constraint_range.vr & ValueRanges(lower=lower, upper=upper), warn_only=False, @@ -488,6 +486,7 @@ def dynamic_dim(t: torch.Tensor, index: int, debug_name: Optional[str] = None): ) # Import sympy locally + import sympy from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint from torch.utils._sympy.value_ranges import ValueRanges @@ -497,7 +496,7 @@ def dynamic_dim(t: torch.Tensor, index: int, debug_name: Optional[str] = None): id(t), index, StrictMinMaxConstraint( - vr=ValueRanges(lower=0, upper=sys.maxsize - 1), warn_only=False + vr=ValueRanges(lower=0, upper=sympy.oo), warn_only=False ), debug_name=debug_name, ) diff --git a/torch/fx/experimental/recording.py b/torch/fx/experimental/recording.py index 28df3fddab0e..4bf9ebab17b3 100644 --- a/torch/fx/experimental/recording.py +++ b/torch/fx/experimental/recording.py @@ -277,13 +277,7 @@ def wrapper(*args, **kwargs): raise except Exception: - log.error( # noqa: G201 - "failed while running %s(*%s, **%s)", - name, - args[1:], - kwargs, - exc_info=log.isEnabledFor(logging.INFO), - ) + log.error("failed while running %s(*%s, **%s)", name, args[1:], kwargs) raise return wrapper diff --git a/torch/fx/experimental/sym_node.py b/torch/fx/experimental/sym_node.py index 4a88d24ce3d5..98cba67a73a1 100644 --- a/torch/fx/experimental/sym_node.py +++ b/torch/fx/experimental/sym_node.py @@ -267,11 +267,8 @@ def mul(self, other) -> "SymNode": def mod(self, other) -> "SymNode": return self._mod(other) # type: ignore[attr-defined] - def float_pow(self, other) -> "SymNode": - return self._float_pow(other) # type: ignore[attr-defined] - - def pow_by_natural(self, other) -> "SymNode": - return self._pow_by_natural(other) # type: ignore[attr-defined] + def pow(self, other) -> "SymNode": + return self._pow(other) # type: ignore[attr-defined] def and_(self, other) -> "SymNode": return self._and_(other) # type: ignore[attr-defined] @@ -279,14 +276,11 @@ def and_(self, other) -> "SymNode": def or_(self, other) -> "SymNode": return self._or_(other) # type: ignore[attr-defined] - def float_truediv(self, other) -> "SymNode": - return self._float_truediv(other) # type: ignore[attr-defined] - - def int_truediv(self, other) -> "SymNode": - return self._int_truediv(other) # type: ignore[attr-defined] + def truediv(self, other) -> "SymNode": + return self._truediv(other) # type: ignore[attr-defined] - def int_floordiv(self, other) -> "SymNode": - return self._int_floordiv(other) # type: ignore[attr-defined] + def floordiv(self, other) -> "SymNode": + return self._floordiv(other) # type: ignore[attr-defined] def lshift(self, other) -> "SymNode": return self._lshift(other) # type: ignore[attr-defined] @@ -367,17 +361,6 @@ def sym_or(self, other): def sym_and(self, other): return self.and_(other) - # There is no int_truediv available from C++ - def truediv(self, other): - return self.float_truediv(other) - - def floordiv(self, other) -> "SymNode": - return self.int_floordiv(other) - - # We didn't bind integer pow in C++ - def pow(self, other): - return self.float_pow(other) - def is_non_overlapping_and_dense(self, sizes, strides): return self.is_non_overlapping_and_dense_indicator(sizes, strides).eq(to_node(self, 1)) # type: ignore[attr-defined] @@ -494,7 +477,7 @@ def is_constant(self): "eq": operator.eq, "floor": math.floor, "trunc": math.trunc, - "int_floordiv": operator.floordiv, + "floordiv": operator.floordiv, "ge": operator.ge, "gt": operator.gt, "is_integer": lambda x: x.is_integer(), @@ -506,8 +489,7 @@ def is_constant(self): "ne": operator.ne, "neg": operator.neg, "or": operator.or_, - "float_pow": operator.pow, - "pow_by_natural": operator.pow, + "pow": operator.pow, "round": builtins.round, "rshift": operator.rshift, "sub": operator.sub, @@ -516,14 +498,12 @@ def is_constant(self): "sym_max": sym_max, "sym_min": sym_min, "sym_not": sym_not, - "float_truediv": operator.truediv, - "int_truediv": operator.truediv, + "truediv": operator.truediv, } unary_magic_methods = { "abs", "sym_float", - "sym_int", "ceil", "floor", "neg", @@ -579,20 +559,20 @@ def fn(self): bool_magic_methods = only_bool_magic_methods | also_bool_magic_methods # Methods that are only for float -only_float_magic_methods = {"is_integer", "round", "sym_int"} +only_float_magic_methods = {"is_integer"} magic_methods_on_operator_with_trailing_underscore = {"and", "or"} -always_float_magic_methods = {"int_truediv", "float_truediv", "sym_float", "float_pow"} +always_float_magic_methods = {"truediv", "sym_float", "pow"} for name in math_op_names: sym_name = f"sym_{name}" always_float_magic_methods.add(sym_name) -always_int_magic_methods = {"ceil", "floor", "trunc", "pow_by_natural"} +always_int_magic_methods = {"ceil", "floor", "trunc"} always_bool_magic_methods = { "eq", "ne", @@ -610,16 +590,10 @@ def fn(self): # Methods that have a `__foo__` as well as `__rfoo__` -def _sympy_float_truediv(a, b): - from torch.utils._sympy.functions import FloatTrueDiv - - return FloatTrueDiv(a, b) - +def _sympy_truediv(a, b): + from torch.utils._sympy.functions import TrueDiv -def _sympy_int_truediv(a, b): - from torch.utils._sympy.functions import IntTrueDiv - - return IntTrueDiv(a, b) + return TrueDiv(a, b) def _sympy_floordiv(a, b): @@ -629,24 +603,15 @@ def _sympy_floordiv(a, b): def _sympy_mod(a, b): - from torch.utils._sympy.functions import Mod, PythonMod - - if a.is_nonnegative and b.is_nonnegative: - return Mod(a, b) - else: - return PythonMod(a, b) + from torch.utils._sympy.functions import Mod + return Mod(a, b) -def _sympy_pow_by_natural(a, b): - from torch.utils._sympy.functions import PowByNatural - return PowByNatural(a, b) +def _sympy_pow(a, b): + from torch.utils._sympy.functions import Pow - -def _sympy_float_pow(a, b): - from torch.utils._sympy.functions import FloatPow - - return FloatPow(a, b) + return Pow(a, b) def _sympy_and(a, b): @@ -678,13 +643,11 @@ def _sympy_rshift(a, b): "sub": operator.sub, "mul": operator.mul, "mod": _sympy_mod, - "pow_by_natural": _sympy_pow_by_natural, - "float_pow": _sympy_float_pow, + "pow": _sympy_pow, "and": _sympy_and, "or": _sympy_or, - "float_truediv": _sympy_float_truediv, - "int_truediv": _sympy_int_truediv, - "int_floordiv": _sympy_floordiv, + "truediv": _sympy_truediv, + "floordiv": _sympy_floordiv, "lshift": _sympy_lshift, "rshift": _sympy_rshift, } @@ -708,22 +671,18 @@ def _floor_ceil_helper(a, fn): return fn(a) -# NB: this is Python semantics so it returns an int def _sympy_floor(a): import sympy return _floor_ceil_helper(a, sympy.floor) -# NB: this is Python trunc semantics which returns an int. Do NOT use this to -# represent torch.trunc (which is float to float) def _sympy_trunc(a): - from torch.utils._sympy.functions import TruncToInt + from torch.utils._sympy.functions import Trunc - return TruncToInt(a) + return Trunc(a) -# NB: this is Python semantics so it returns an int def _sympy_ceil(a): import sympy @@ -812,28 +771,26 @@ def _sympy_abs(a): def _sympy_round(number, ndigits=None): - from torch.utils._sympy.functions import RoundDecimal, RoundToInt + from torch.utils._sympy.functions import Round, RoundDecimal if ndigits is None: - return RoundToInt(number) + return Round(number) else: return RoundDecimal(number, ndigits) def _sympy_sym_float(a): - from torch.utils._sympy.functions import ToFloat - - # NB: Cannot use a * 1.0 here, because 0 * 1.0 is 0 which incorrectly - # reports that it is an integer - return ToFloat(a) + # Cannot use sympy.Float(a) here, coz it expects python literals + # Multiply by 1.0 to cast to float. This is needed when the input + # is a SymInt which has the assumption that it is integer and + # SymPy will otherwise assume that return value cannot be a float. + return a * 1.0 def _sympy_is_integer(a): import sympy - from torch.utils._sympy.functions import ToFloat - - return sympy.Eq(ToFloat(sympy.floor(a)), a) + return sympy.Eq(sympy.floor(a), a) magic_methods = { @@ -1032,26 +989,9 @@ def binary_magic_impl(self, other): self, handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {}) ) assert isinstance(other, SymNode) + # TODO: consider constant prop here try: - if method == "mod": - from torch.utils._sympy.functions import Mod, PythonMod - - # Special handling for mod that requires access to the value - # ranges - shape_env = self.shape_env - if ( - self.expr.is_nonnegative - or shape_env.bound_sympy(self.expr).lower >= 0 - ) and ( - other.expr.is_nonnegative - or shape_env.bound_sympy(other.expr).lower >= 0 - ): - out = Mod(self.expr, other.expr) - else: - out = PythonMod(self.expr, other.expr) - else: - # TODO: consider constant prop here - out = func(self.expr, other.expr) + out = func(self.expr, other.expr) except Exception: log.warning("failed to eval %s(%s, %s)", method, self.expr, other.expr) raise @@ -1182,13 +1122,9 @@ def round_impl(self, ndigits=None): except Exception: log.warning("failed to eval %s(%s, ndigits=%s)", method, expr, ndigits) raise - out = safe_expand(out) - if ndigits is None: - pytype = int - else: - pytype = self.pytype + pytype = int if ndigits is None else self.pytype out_hint = None if self.hint is not None: @@ -1200,7 +1136,6 @@ def round_impl(self, ndigits=None): # hack down below works, because all round function down the line all take ndigits=None as default in their # signature. # TODO: Remove the args construction below if a different sentinel is used by FX. - # ezyang(May 2024): LOL args = [self.fx_node] if ndigits is not None: args.append(ndigits) @@ -1324,32 +1259,6 @@ def is_constant(x): return x.node.is_constant() return False - # Promotion rules for binary operations. NB: we preserve PYTHON semantics - # - if args are same type, do nothing - # - if one arg is float, promote other arg to float - # - nb: this applies to floordiv, even though output is integral - # (it's still float) - # - pow is funny business - # - if both ints - # - trigger a guard on exponent >= 0 - # - if non-negative, output is int - # - otherwise, output is float - # - otherwise, promote other arg to float - # - nb: complex is impossible to handle correctly lol, with - # negative base and integral float need to diverge semantics and - # just always return complex. Neener neener pretend this problem - # doesn't exist - # - equality is pain: Python does the fancy thing where it unpacks the - # mantissa from the float and then compares that against the int. - # Which means it is able to tell that - # 9007199254740993 != 9007199254740992. (rather than if the LHS was - # promoted to float, in which case it would have truncated to the RHS - # and subsequently been equal). We'll model this exactly by having - # special mixed type equality operations. Unfortunately, we need to - # do this for all comparison operations (maybe I'll only implement - # compare) - # - sym_ite mumble mumble really shouldn't allow mixed but whatever - if method in bool_becomes_int_magic_methods: def promote(x): @@ -1363,41 +1272,6 @@ def promote(x): def promote(x): return x - def promote2(self, other): - # TODO: Remove eq and other relations from this list. - # CPython has fancy implementations for these to get as much precision - # as possible instead of just promoting to float64 and praying, so we - # need to handle them specially too. - # Also, note that int_truediv doesn't go through this path: both - # arguments are "int" so there isn't any promotion - if method not in [ - "add", - "sub", - "mul", - "mod", - "float_pow", - "float_truediv", - "int_floordiv", - "sym_min", - "sym_max", - # TODO: remove these - "eq", - "ne", - "gt", - "lt", - "le", - "ge", - ]: - return self, other - f_self = isinstance(self, (float, torch.SymFloat)) - f_other = isinstance(other, (float, torch.SymFloat)) - if f_self or f_other: - if not f_self: - self = torch.sym_float(self) - if not f_other: - other = torch.sym_float(other) - return self, other - # Before and after performing the operation, check if any operands are constant. # If so, extract out the constant values first. If `self` itself is a # constant, then "redispatch" by calling back into the operator. Sometimes @@ -1412,12 +1286,9 @@ def unary_magic_impl(self): return wrap_node(getattr(self.node, method_attr)()) def binary_magic_impl(self, other): - if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)): - return NotImplemented sym_node_log.debug("MAGIC %s %s %s", method, self, other) self = promote(self) other = promote(other) - self, other = promote2(self, other) if is_constant(self): return (method_to_operator(method))(get_constant(self), other) if is_constant(other): @@ -1429,11 +1300,8 @@ def binary_magic_impl(self, other): return get_constant(ret) if is_constant(ret) else ret def rbinary_magic_impl(self, other): - if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)): - return NotImplemented self = promote(self) other = promote(other) - self, other = promote2(self, other) if is_constant(self): return (method_to_operator(method))(get_constant(self), other) if is_constant(other): diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index d7321f071865..a2abde3a861e 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -61,7 +61,7 @@ from torch import SymBool, SymFloat, SymInt from torch._guards import ShapeGuard, Source, TracingContext from torch.utils._python_dispatch import is_traceable_wrapper_subclass -from torch.utils._sympy.functions import FloorDiv, Mod, PythonMod, IsNonOverlappingAndDenseIndicator, CleanDiv +from torch.utils._sympy.functions import FloorDiv, Mod, IsNonOverlappingAndDenseIndicator from torch.utils._sympy.solve import try_solve from torch.utils._sympy.value_ranges import bound_sympy, SymPyValueRangeAnalysis, ValueRanges, ValueRangeError from torch.utils._sympy.singleton_int import SingletonInt @@ -869,9 +869,9 @@ def constrain_range(a, *, min: Optional[int], max: Optional[int] = None): for N=1. """ if min is None: - min = -sys.maxsize - 1 + min = -sympy.oo if max is None: - max = sys.maxsize - 1 + max = sympy.oo if max < min: raise ValueError( @@ -979,6 +979,16 @@ def eval_guards(gm, *args, ignore_static=True): def bind_symbols(gm, *args): return gm.shape_env.bind_symbols(fx_placeholder_vals(gm), args) +def _assert_bound_is_rational(expr: sympy.Expr, bound: ValueRanges): + """ + We assert that the bounds are either Boolean, or not finite, or can be computed + in exact prevision via rational arithmetic. + The only exception to this is the rare case when the user calls `sqrt(s0)` + sqrt is turned into sympy.Pow so we just match for that (it matches more things, but still) + """ + assert bound.lower.is_rational or bound.lower.is_Boolean or not bound.lower.is_finite or expr.has(sympy.Pow), (bound, expr) + assert bound.upper.is_rational or bound.upper.is_Boolean or not bound.upper.is_finite or expr.has(sympy.Pow), (bound, expr) + class DimDynamic(Enum): """ Controls how to perform symbol allocation for a dimension. It is always @@ -1377,17 +1387,14 @@ def cast_symbool_to_symint_guardless(symbool: torch.SymBool) -> torch.SymInt: 'Min': min, 'Max': max, 'Mod': operator.mod, - 'PythonMod': operator.mod, 'FloorDiv': operator.floordiv, 'TrueDiv': operator.truediv, 'IsNonOverlappingAndDenseIndicator': eval_is_non_overlapping_and_dense, 'floor': math.floor, 'ceiling': math.ceil, 'cast_symbool_to_symint_guardless': cast_symbool_to_symint_guardless, - 'RoundToInt': builtins.round, + 'Round': builtins.round, 'RoundDecimal': builtins.round, - 'TruncToInt': math.trunc, - 'IntTrueDiv': operator.truediv, } @@ -1635,17 +1642,10 @@ def floor_div_handler(*args): congruence = (base - mod_reduced) % divisor if congruence != 0: self._congruences[s].add(congruence) - # NB: Must not be CleanDiv, it needs to be regular sympy division - # so inequality solver works. This is sort of problematic for - # is_integer tests though haha return (base - mod_reduced) / divisor if expr.has(Mod): expr = expr.replace(Mod, mod_handler) - # 7 // -3 is -3, 7 % -3 is -2, and 7 - (-2) / -3 is -3.0 so negative - # arguments should be OK. - if expr.has(PythonMod): - expr = expr.replace(PythonMod, mod_handler) if expr.has(FloorDiv): expr = expr.replace(FloorDiv, floor_div_handler) return expr @@ -3330,7 +3330,6 @@ def create_unbacked_symfloat(self): self.pending_fresh_unbacked_symbols.append(symbol) self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1) vr = self.var_to_range[symbol] = ValueRanges.unknown() - assert vr.is_float # Create a new FX placeholder and Z3 variable for 'symbol'. fx_node = self._create_fx_placeholder_and_z3var(symbol, float) @@ -3349,7 +3348,6 @@ def create_unbacked_symint(self): self.counter["create_unbacked_symbol"] += 1 self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1) vr = self.var_to_range[symbol] = self._default_unspecified_value_range() - assert vr.is_int # Create a new FX placeholder and Z3 variable for 'symbol'. fx_node = self._create_fx_placeholder_and_z3var(symbol, int) @@ -3373,7 +3371,6 @@ def create_unbacked_symbool(self): self.counter["create_unbacked_symbol"] += 1 self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1) vr = self.var_to_range[symbol] = ValueRanges(0, 1) - assert vr.is_int # Create a new FX placeholder and Z3 variable for 'symbol'. fx_node = self._create_fx_placeholder_and_z3var(symbol, bool) @@ -3519,7 +3516,6 @@ def create_symbol( self.var_to_range[sympy_expr] &= constraint_dim.vr vr = self.var_to_range[sympy_expr] - assert vr.is_int if val not in vr: raise ConstraintViolationError(f"{val} not in range [{vr.lower}, {vr.upper}]") @@ -3528,7 +3524,6 @@ def create_symbol( elif isinstance(val, float): self.var_to_range[sympy_expr] = vr = ValueRanges(-sympy.oo, sympy.oo) range_str = f"[{vr.lower}, {vr.upper}]" - assert vr.is_float else: # Skip var_range logic for SingletonInt # Only used for jagged layout nested tensors @@ -3578,7 +3573,6 @@ def create_symbol( def add_var_to_val(self, expr: sympy.Symbol, val: int): """ Adds a new symbol to the symbolic environment. """ - log.debug("add_var_to_val %s %s", expr, val, stack_info=True) assert expr not in self.var_to_val, f"{expr} already exists" self.var_to_val[expr] = sympy.Integer(val) @@ -4307,8 +4301,7 @@ def bound_sympy(self, expr: sympy.Expr, size_oblivious: bool = False) -> ValueRa # Clamp values of size-like variables for x in self.size_like & var_to_range.keys(): if var_to_range[x] is not None: - var_to_range[x] = ValueRanges(2, sys.maxsize - 1) - assert var_to_range[x].is_int + var_to_range[x] = ValueRanges(2, sympy.oo) return bound_sympy(expr, var_to_range) @_lru_cache @@ -4425,11 +4418,6 @@ def _maybe_evaluate_static( vr = self._default_unspecified_value_range() if size_oblivious and k in self.size_like: lower = max(2, vr.lower) - # This is a bit dodgy: what this means is that there was a - # size-like unbacked symbol whose upper bound < 2. This - # causes... problems. - if lower <= vr.upper: - vr = ValueRanges(lower, vr.upper) else: lower = vr.lower # Don't do anything if we don't have a nontrivial lower bound @@ -4437,17 +4425,10 @@ def _maybe_evaluate_static( # SymInt if ( lower < (-sys.maxsize - 1) // 2 or - (unbacked_only and k in self.var_to_val) or - not vr.is_int + (unbacked_only and k in self.var_to_val) ): new_range_env[k] = vr continue - # The goal is to take our symbols which have various lower bounds - # and reallocate them into new symbols which are exactly positive; - # e.g., if we have s0 in [2, inf], we want to turn it into ess0 in - # [1, inf], where s0 = ess0 + 1. This gives the most information - # to sympy for subsequent simplifications. - # # Positive means >= 1 # Positive - 1 means >= 0 # Positive + lower - 1 means >= lower @@ -4479,14 +4460,6 @@ def replace(expr, repl): self.counter["sympy_recursion_error"] += 1 return None - new_expr = safe_expand(new_expr) - if new_expr.is_number: - return new_expr - - # This is bad to do, the replacement with division leaves us with - # rationals when atom.args[0] is addition, e.g., sympy will happily - # turn (s0 + s1) // 2 into s0 / 2 + s1 / 2. Needless complication! - """ floor_div_replace = {} for atom in new_expr.atoms(FloorDiv): floor_div_replace[atom] = sympy.floor(atom.args[0] / atom.args[1]) @@ -4495,12 +4468,13 @@ def replace(expr, repl): # are still free symbols if new_expr.is_number: return new_expr - """ # Check if the range can solve it statically out = bound_sympy(new_expr, new_range_env) - if out.is_singleton(): - return out.lower + if expect_rational: + _assert_bound_is_rational(new_expr, out) + if out.is_singleton(): + return out.lower return new_expr if unbacked_only else None @@ -4552,7 +4526,7 @@ def simplify(self, expr: "sympy.Expr") -> "sympy.Expr": for fd in expr.atoms(FloorDiv): base, divisor = fd.args if self.replace(Mod(base, divisor)) in self.divisible: - div_replacements[fd] = CleanDiv(base, divisor) + div_replacements[fd] = base / divisor new_expr = expr.xreplace(div_replacements) new_expr = safe_expand(new_expr) new_pows = new_expr.atoms(sympy.Pow) @@ -4696,10 +4670,7 @@ def _set_replacement(self, a: "sympy.Symbol", tgt: "sympy.Expr", msg: str) -> No int_range = ValueRanges(-sys.maxsize - 1, sys.maxsize - 1) def issubset(x, y): - if x.is_int and y.is_int: - return (x & int_range).issubset(y & int_range) - else: - return x.issubset(y) + return (x & int_range).issubset(y & int_range) # First, refine the value range of a based on the computed value range # of tgt. This is always OK to do, even if we decide not to do the @@ -4717,7 +4688,7 @@ def issubset(x, y): b = next(iter(tgt.free_symbols)) # Try to invert the equality r = try_solve(sympy.Eq(a, tgt), b, floordiv_inequality=False) - if r is not None and all(t.is_integer for t in sympy.preorder_traversal(r[1])): + if r is not None: b_bound = self.bound_sympy(r[1]) self.var_to_range[b] = b_bound & self.var_to_range[b] tgt_bound = self.bound_sympy(tgt) @@ -4928,12 +4899,12 @@ def trivial_solve(lhs, rhs): ): # We have Mod(i0, q / c) == 0, which means we can # rewrite i0 as (q / gcd(q, c)) * i1 - d = q / sympy.gcd(q, c) # TODO: CleanDiv? + d = q / sympy.gcd(q, c) i1 = self.create_unbacked_symint().node.expr # Propagate the value ranges. It doesn't really # matter if we use truediv or floordiv, because we # have established divisibility. - self._update_var_to_range(i1, SymPyValueRangeAnalysis.floordiv( + self._update_var_to_range(i1, SymPyValueRangeAnalysis.truediv( self.var_to_range[i0], ValueRanges.wrap(d) )) # Propagate size-like-ness @@ -5370,6 +5341,7 @@ def _refine_ranges(self, expr: sympy.Expr) -> None: lower, upper = vr.lower, vr.upper rhs_vr = bound_sympy(rhs, self.var_to_range) + _assert_bound_is_rational(rhs, rhs_vr) # Let's suppose that we have a preexisting range for x [0, 100]. # Now, we issue a guard x > y, where the range for y is [50, 150]. diff --git a/torch/fx/experimental/validator.py b/torch/fx/experimental/validator.py index d06b38d60c80..6dcb59db7979 100644 --- a/torch/fx/experimental/validator.py +++ b/torch/fx/experimental/validator.py @@ -216,7 +216,10 @@ def sqrt(self, number: z3.ArithRef) -> z3.ArithRef: def abs(self, number: z3.ArithRef) -> z3.ArithRef: return z3.Abs(number) - def round_to_int(self, number: z3.ArithRef) -> z3.ArithRef: + def round(self, number: z3.ArithRef, ndigits: Optional[z3.ArithRef] = None) -> z3.ArithRef: + if ndigits is not None: + raise ValueError("round(..., ndigits=) is currently not supported by shape validations.") + # Pythons builtin 'round' implements the 'round half to even' strategy # See https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even # z3 has an equivalent z3.fpRoundToIntegral(z3.RoundNearestTiesToEven(), ...), but this only applies to @@ -281,7 +284,7 @@ def wrapper(*args): operator.truediv: lift(ops.div), operator.mod: lift(ops.mod), operator.abs: lift(ops.abs), - builtins.round: lift(ops.round_to_int), + builtins.round: lift(ops.round), # Math module. math.ceil: lift(ops.ceil), @@ -347,7 +350,6 @@ def __init__( self._ops = _Z3Ops(self._validator) def constant(self, value: Any, dtype: torch.dtype) -> z3.ExprRef: - # TODO: Probably OK to relax this and allow lower precision if dtype is torch.int64: return z3.IntVal(int(value)) if dtype is torch.double: @@ -356,20 +358,6 @@ def constant(self, value: Any, dtype: torch.dtype) -> z3.ExprRef: return z3.BoolVal(bool(value)) raise ValueError(f"unsupported dtype (SympyToZ3): {dtype}") - def to_dtype(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: - if dtype == torch.float64: - return z3.ToReal(x) - raise NotImplementedError(f"to_dtype {dtype} NYI") - - def trunc_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: - return z3.ToInt(x) - - def round_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: - return self._ops.round_to_int(x) - - def int_truediv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: - return self._ops.div(numerator, denominator) - def truediv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: return self._ops.div(numerator, denominator) @@ -382,17 +370,11 @@ def div(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: def pow(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef: return self._ops.pow(base, exp) - def pow_by_natural(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef: - return self._ops.pow(base, exp) - def mod(self, p: z3.ArithRef, q: z3.ArithRef) -> z3.ArithRef: return self._ops.mod(p, q) - def ceil_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: - return self._ops.ceil(x) - - def floor_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: - return self._ops.floor(x) + def round(self, number: z3.ArithRef, ndigits: Optional[z3.ArithRef] = None) -> z3.ArithRef: + return self._ops.round(number, ndigits) def __getattr__(self, name: str) -> Any: REPLACEMENT = { diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index 9b1599288949..1384261b4512 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -1,78 +1,43 @@ -import functools import math -import sys import sympy from sympy import S +from sympy.core.logic import fuzzy_and, fuzzy_not, fuzzy_or __all__ = [ "FloorDiv", "ModularIndexing", "CleanDiv", "CeilDiv", - "IntTrueDiv", - "FloatTrueDiv", + "Pow", + "TrueDiv", "LShift", "RShift", "IsNonOverlappingAndDenseIndicator", - "RoundToInt", + "Round", "RoundDecimal", - "ToFloat", - "FloatPow", - "PowByNatural", ] -def _keep_float(f): - @functools.wraps(f) - def inner(*args): - r = f(*args) - if any(isinstance(a, sympy.Float) for a in args) and not isinstance( - r, sympy.Float - ): - r = sympy.Float(float(r)) - return r - - return inner - - def fuzzy_eq(x, y): if None in (x, y): return None return x == y -# It would be nice to have assertions on whether or not inputs is_integer -# However, with bugs like https://github.com/sympy/sympy/issues/26620 sympy -# sometimes inconsistently reports floats an integers. -# -# What we can assume from sympy is that if something is an int, it -# definitely is is_integer, but if it is a float it may or may not -# be is_integer. So we are unable to do strong asserts that things -# are NOT integers. - - -# TODO: In Triton, // rounds to zero, but in Python, it is floor division. -# When we can prove both arguments are non-negative, we should just have a -# GenericFloorDiv (name pending) which can codegen efficiently in Python/C, -# and then PythonFloorDiv and CIntDiv which have the appropriate rounding -# semantics. -# -# Right now, FloorDiv de facto changes behavior if arguments are negative or -# not, this can potentially cause correctness issues. class FloorDiv(sympy.Function): """ We maintain this so that: 1. We can use divisibility guards to simplify FloorDiv(a, b) to a / b. 2. Printing out the expression is nicer (compared to say, representing a//b as (a - a % b) / b) - - NB: This is Python-style floor division, round to -Inf """ nargs = (2,) precedence = 50 # precedence of mul # noqa: F811 - is_integer = True + # Default return type for SymPy assumptions. + # https://docs.sympy.org/latest/guides/assumptions.html#implementing-assumptions-handlers + is_real = True @property def base(self): @@ -87,14 +52,29 @@ def _sympystr(self, printer): divisor = printer.parenthesize(self.divisor, self.precedence) return f"({base}//{divisor})" + # SymPy assumptions based on argument types. + def _eval_is_real(self): + return fuzzy_or([self.base.is_real, self.divisor.is_real]) + + def _eval_is_integer(self): + return fuzzy_and([self.base.is_integer, self.divisor.is_integer]) + # Automatic evaluation. # https://docs.sympy.org/latest/guides/custom-functions.html#best-practices-for-eval @classmethod def eval(cls, base, divisor): - # python test/test_dynamic_shapes.py -k TestDimConstraints.test_dim_constraints_solve_full - # Assert triggered by inequality solver - # assert base.is_integer, base - # assert divisor.is_integer, divisor + def check_supported_type(x): + if ( + x.is_integer is False and x.is_real is False and x.is_complex + ) or x.is_Boolean: + raise TypeError( + f"unsupported operand type(s) for //: " + f"'{type(base).__name__}' and '{type(divisor).__name__}'" + f", expected integer or real" + ) + + check_supported_type(base) + check_supported_type(divisor) # We don't provide the same error message as in Python because SymPy # makes it difficult to check the types. @@ -105,22 +85,26 @@ def eval(cls, base, divisor): return sympy.S.Zero if base.is_integer and divisor == 1: return base + if base.is_real and divisor == 1: + return sympy.floor(base) if base.is_integer and divisor == -1: return sympy.Mul(base, -1) if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer): - return sympy.Integer(int(base) // int(divisor)) + return base // divisor + if isinstance(base, (sympy.Integer, sympy.Float)) and isinstance( + divisor, (sympy.Integer, sympy.Float) + ): + return sympy.floor(base / divisor) if isinstance(base, FloorDiv): return FloorDiv(base.args[0], base.args[1] * divisor) + if isinstance(divisor, sympy.Rational) and divisor.p == 1: + return sympy.floor(base * divisor.q) - # gcd in sympy is over polynomials, so you'll end up with rationals if - # you do this. Don't. - """ if isinstance(base, sympy.Add): for a in base.args: gcd = sympy.gcd(a, divisor) if gcd == divisor: return FloorDiv(base - a, divisor) + a / gcd - """ try: gcd = sympy.gcd(base, divisor) @@ -142,10 +126,6 @@ class ModularIndexing(sympy.Function): @classmethod def eval(cls, base, divisor, modulus): - assert isinstance(base, int) or base.is_integer, base - assert isinstance(divisor, int) or divisor.is_integer, divisor - assert isinstance(modulus, int) or modulus.is_integer, modulus - if base == 0 or modulus == 1: return sympy.Integer(0) @@ -209,19 +189,6 @@ class Where(sympy.Function): nargs = (3,) - def _eval_is_integer(self): - return True if self.args[1].is_integer and self.args[2].is_integer else None # type: ignore[attr-defined] - - def _eval_is_nonnegative(self): - return ( - True - if self.args[1].is_nonnegative and self.args[2].is_nonnegative # type: ignore[attr-defined] - else None - ) - - def _eval_is_positive(self): - return True if self.args[1].is_positive and self.args[2].is_positive else None # type: ignore[attr-defined] - @classmethod def eval(cls, c, p, q): if c == sympy.true: @@ -230,27 +197,28 @@ def eval(cls, c, p, q): return q -# Python-style modulus: take sign from RHS -class PythonMod(sympy.Function): - nargs = (2,) +class Mod(sympy.Function): + """ + We maintain this so that we avoid SymPy correctness issues, such as: + https://github.com/sympy/sympy/issues/25146 + """ - is_integer = True + nargs = (2,) @classmethod def eval(cls, p, q): - # python test/dynamo/test_export.py -k ExportTests.test_trivial_constraint - # Triggered by sympy.solvers.inequalities.reduce_inequalities - # assert p.is_integer, p - # assert q.is_integer, q + # This was adapted from: sympy/core/mod.py if q.is_zero: raise ZeroDivisionError("Modulo by zero") - + # If either of them is NaN or infinite. + if p is S.NaN or q is S.NaN or p.is_finite is False or q.is_finite is False: + return S.NaN # Three cases: # 1. p == 0 # 2. p is either q or -q # 3. p is integer and q == 1 - if p is S.Zero or p in (q, -q) or q == 1: + if p is S.Zero or p in (q, -q) or (p.is_integer and q == 1): return S.Zero # Evaluate if they are both literals. @@ -279,7 +247,10 @@ def eval(cls, p, q): if sympy.Mod(p, q) == 0: return S.Zero - # NB: args[1] for PythonMod + def _eval_is_integer(self): + p, q = self.args + return fuzzy_and([p.is_integer, q.is_integer, fuzzy_not(q.is_zero)]) # type: ignore[attr-defined] + def _eval_is_nonnegative(self): return True if self.args[1].is_positive else None # type: ignore[attr-defined] @@ -287,58 +258,6 @@ def _eval_is_nonpositive(self): return True if self.args[1].is_negative else None # type: ignore[attr-defined] -# Generic modulus: only defined on non-negative arguments -class Mod(sympy.Function): - nargs = (2,) - - is_integer = True - is_nonnegative = True - - @classmethod - def eval(cls, p, q): - # This was adapted from: sympy/core/mod.py - - # Triggered by - # python test/test_dynamic_shapes.py -k TestDimConstraints.test_dim_constraints_solve_full - # assert p.is_integer, p - # assert q.is_integer, q - - if q.is_zero: - raise ZeroDivisionError("Modulo by zero") - - # Three cases: - # 1. p == 0 - # 2. p is either q or -q - # 3. p is integer and q == 1 - if p is S.Zero or p in (q, -q) or q == 1: - return S.Zero - - # Evaluate if they are both literals. - if q.is_Number and p.is_Number: - assert p >= 0, p - assert q >= 1, q - return p % q - - # If q == 2, it's a matter of whether p is odd or even. - if q.is_Number and q == 2: - if p.is_even: - return S.Zero - if p.is_odd: - return S.One - - # If p is a multiple of q. - r = p / q - if r.is_integer: - return S.Zero - - # If p < q and its ratio is positive, then: - # - floor(p / q) = 0 - # - p % q = p - floor(p / q) * q = p - less = p < q - if less.is_Boolean and bool(less) and r.is_positive: - return p - - class CleanDiv(FloorDiv): """ Div where we can assume no rounding. @@ -356,10 +275,6 @@ class CeilDiv(sympy.Function): is_integer = True def __new__(cls, base, divisor): - base = sympy.sympify(base) - divisor = sympy.sympify(divisor) - assert base.is_integer, base - assert divisor.is_integer, divisor if sympy.gcd(base, divisor) == divisor: return CleanDiv(base, divisor) else: @@ -367,139 +282,43 @@ def __new__(cls, base, divisor): class LShift(sympy.Function): - is_integer = True - @classmethod def eval(cls, base, shift): - assert base.is_integer, base - assert shift.is_integer, shift - if shift < 0: raise ValueError("negative shift count") return base * 2**shift class RShift(sympy.Function): - is_integer = True - @classmethod def eval(cls, base, shift): - assert base.is_integer, base - assert shift.is_integer, shift - if shift < 0: raise ValueError("negative shift count") return base // 2**shift -def safe_pow(base, exp): - sign = 1 - if base < 0: - base = -base - sign = 1 if exp % 2 == 0 else -1 - return sign * _safe_pow(base, exp) - - -def _safe_pow(base, exponent): - if exponent < 0: - raise ValueError("Exponent must be non-negative.") - - if exponent == 0: - return 1 - - half_exp = safe_pow(base, exponent // 2) - if half_exp > sys.maxsize - 1: - return sys.maxsize - 1 - - result = half_exp * half_exp - if result > sys.maxsize - 1: - return sys.maxsize - 1 - - if exponent % 2 == 1: - result *= base - if result > sys.maxsize - 1: - return sys.maxsize - 1 - - return result - - -class PowByNatural(sympy.Function): - is_integer = True - - @classmethod - def eval(cls, base, exp): - # exp can be assumed to be is_integer and is_nonnegative, but we may - # have concluded this externally from Sympy assumptions, so we can't - # assert the nonnegative - assert exp.is_integer, exp - if isinstance(base, sympy.Number) and isinstance(exp, sympy.Number): - return sympy.Integer(safe_pow(base, exp)) - if isinstance(exp, sympy.Integer): - # Translate power into iterated multiplication - r = sympy.Integer(1) - for _ in range(int(exp)): - r *= base - return r - # NB: do NOT translate into sympy.Pow, we will lose knowledge that exp - # is a natural number if we do - - -# base is assumed to be nonnegative, thereby prevent complex numbers from -# occuring -class FloatPow(sympy.Function): - is_integer = False - is_real = True - +# Overloaded to be compatible with regular Python. +# https://github.com/pytorch/pytorch/issues/90900 +class Pow(sympy.Function): @classmethod def eval(cls, base, exp): - if isinstance(base, sympy.Number) and isinstance(exp, sympy.Number): - return sympy.Float(float(base) ** float(exp)) - # NB: do not do any nontrivial reasoning + if exp.is_zero: + return sympy.Integer(1) + elif base.is_zero and exp < 0: + raise ZeroDivisionError(f"{base} cannot be raised to a negative power") + else: + return base**exp # Overloaded to be compatible with regular Python. # https://github.com/pytorch/pytorch/issues/90900 -# -# In particular, sympy division is willing to simplify x/x == 1 -# where 1 is an integer, but this must be a float if x was float. -class FloatTrueDiv(sympy.Function): - is_integer = False - is_real = True - - @classmethod - def eval(cls, base, divisor): - # assert base.is_integer is not True, base - # assert divisor.is_integer is not True, divisor - - if divisor.is_zero: - raise ZeroDivisionError("division by zero") - - if isinstance(base, sympy.Number) and isinstance(divisor, sympy.Number): - return sympy.Float(float(base) / float(divisor)) - - -# Overloaded to be compatible with regular Python. We distinguish this from -# FloatTrueDiv, because the code generation has to be different for this case: -# Python has a fancy algorithm for integer true division that isn't just -# "promote both arguments to float and use float division", so you need to -# codegen it differently. While technically you can work it out from the -# types of the input, this is often inconvenient to do in Inductor codegen, -# so just have a different operator -# NB: Right now, Inductor codegen doesn't implement this correctly lol -class IntTrueDiv(sympy.Function): - is_integer = False - is_real = True - +class TrueDiv(sympy.Function): @classmethod def eval(cls, base, divisor): - assert base.is_integer, base - assert divisor.is_integer, divisor - if divisor.is_zero: raise ZeroDivisionError("division by zero") - - if isinstance(base, sympy.Number) and isinstance(divisor, sympy.Number): - return sympy.Float(int(base) / int(divisor)) + else: + return base / divisor # TODO: As an indicator, this != 0 implies == 1 (and vice versa). @@ -534,87 +353,45 @@ def eval(cls, *args): return None -# NB: this is inconsistent with math.trunc in Python -class TruncToFloat(sympy.Function): - is_integer = False - is_real = True - - @classmethod - def eval(cls, number): - # assert number.is_integer is not True, number - if isinstance(number, sympy.Number): - # NB: It is safe to use truncation to integer, which is what - # math.trunc does, as Python integers are arbitrary precision and - # so we are guaranteed not to lose precision when we do this - return sympy.Float(math.trunc(float(number))) - - -class TruncToInt(sympy.Function): +class Trunc(sympy.Function): is_integer = True @classmethod def eval(cls, number): - # assert number.is_integer is not True, number - if number == sympy.oo: - return sympy.Integer(sys.maxsize - 1) - if number == -sympy.oo: - return sympy.Integer(-sys.maxsize - 1) - if isinstance(number, sympy.Number): + if number.is_integer: + return number + elif isinstance(number, sympy.Number): return sympy.Integer(math.trunc(float(number))) -# This is float -> int -class RoundToInt(sympy.Function): +class Round(sympy.Function): is_integer = True @classmethod def eval(cls, number): - # assert number.is_integer is not True, number - - if isinstance(number, sympy.Float): - return sympy.Integer(round(float(number), 0)) - + if number.is_integer: + return number + elif isinstance(number, sympy.Number): + return sympy.Integer(round(float(number))) -# To get float -> int, Python style round semantics. -# -# x = PyFloat_AsDouble(self); -# if (o_ndigits == Py_None) { -# /* single-argument round or with None ndigits: -# * round to nearest integer */ -# rounded = round(x); -# if (fabs(x-rounded) == 0.5) -# /* halfway case: round to even */ -# rounded = 2.0*round(x/2.0); -# return PyLong_FromDouble(rounded); -# } + def __int__(self): + # This will only ever be called when computing size hints. At that point, self.args[0] should be a number and + # no longer an expression. If it were, the float call would fail and the caller would handle this further. + return round(float(self.args[0])) # type: ignore[arg-type] -# NB: Like Round, this only ever returns floats. ndigits cannot be None class RoundDecimal(sympy.Function): - is_integer = False - is_real = True - @classmethod def eval(cls, number, ndigits): - # assert number.is_integer is not True, number - - if isinstance(number, sympy.Float) and isinstance(ndigits, sympy.Integer): - return sympy.Float(round(float(number), int(ndigits))) - - -class ToFloat(sympy.Function): - is_integer = False - is_real = True - - @classmethod - def eval(cls, number): - if number in [sympy.oo, -sympy.oo]: + if number.is_integer and ndigits >= 0: return number - - assert number.is_integer, number - - if isinstance(number, sympy.Integer): - return sympy.Float(int(number)) + elif isinstance(number, sympy.Number) and isinstance(ndigits, sympy.Integer): + value_type, output_type = ( + (int, sympy.Integer) + if isinstance(number, sympy.Integer) + else (float, sympy.Float) + ) + return output_type(round(value_type(number), int(ndigits))) def make_opaque_unary_fn(name): diff --git a/torch/utils/_sympy/interp.py b/torch/utils/_sympy/interp.py index c2d9ae464125..806e91cfe281 100644 --- a/torch/utils/_sympy/interp.py +++ b/torch/utils/_sympy/interp.py @@ -16,20 +16,15 @@ import torch from .functions import ( CleanDiv, - FloatPow, - FloatTrueDiv, FloorDiv, - IntTrueDiv, IsNonOverlappingAndDenseIndicator, Mod, ModularIndexing, - PowByNatural, - PythonMod, + Pow, + Round, RoundDecimal, - RoundToInt, - ToFloat, - TruncToFloat, - TruncToInt, + TrueDiv, + Trunc, Where, ) @@ -54,39 +49,30 @@ def handlers(): sympy.Le: "le", sympy.Ge: "ge", sympy.Not: "not_", - IntTrueDiv: "int_truediv", - FloatTrueDiv: "truediv", + TrueDiv: "truediv", FloorDiv: "floordiv", - CleanDiv: "floordiv", # TODO: hmm? - TruncToFloat: "trunc", + CleanDiv: "div", + Trunc: "trunc", Where: "where", sympy.Add: "add", sympy.Mul: "mul", - FloatPow: "pow", - PowByNatural: "pow_by_natural", - # sympy simplifies x * x into Pow(x, 2), so we need to handle this. - # Do NOT use builtin Pow for floats - # TODO: There is a hazard here, if we have float * float it will - # also get turned into Pow(float, 2) but we don't want this because - # pow_by_natural is assumed to only be integers. Probably the fix is - # to add a FloatMul to impede this optimization - sympy.Pow: "pow_by_natural", + Pow: "pow", + sympy.Pow: "pow", Mod: "mod", - PythonMod: "mod", # TODO: this is wrong - # TODO: Inductor can generate these, but it's ill-specified which - # semantics were intended here. Needs to be cleaned up along with - # FloorDiv in a bigger cleanup sympy.Mod: "mod", sympy.Abs: "abs", sympy.log: "log", sympy.exp: "exp", + sympy.floor: "floor", + sympy.ceiling: "ceil", sympy.Min: "minimum", sympy.Max: "maximum", ModularIndexing: "modular_indexing", sympy.functions.elementary.piecewise.ExprCondPair: "expr_cond_pair", sympy.Piecewise: "piecewise", IsNonOverlappingAndDenseIndicator: "is_non_overlapping_and_dense_indicator", - RoundDecimal: "round_decimal", + Round: "round", + RoundDecimal: "round", } for name in ["cos", "sin", "tan", "sinh", "cosh", "tanh", "asin", "acos", "atan"]: HANDLERS[getattr(sympy, name)] = name @@ -98,11 +84,7 @@ def handlers(): def sympy_interp( - analysis, - env: Dict[sympy.Symbol, Any], - expr: Union[sympy.Expr, SympyBoolean], - *, - index_dtype=torch.int64, + analysis, env: Dict[sympy.Symbol, Any], expr: Union[sympy.Expr, SympyBoolean] ): # Handle base cases dtype = None @@ -123,30 +105,9 @@ def sympy_interp( expr.args[1], sympy.core.numbers.Half ): return analysis.sqrt(sympy_interp(analysis, env, expr.args[0])) - if isinstance(expr, ToFloat): - return analysis.to_dtype( - sympy_interp(analysis, env, expr.args[0]), torch.float64 - ) # Recursive case args = [sympy_interp(analysis, env, arg) for arg in expr.args] # type: ignore[arg-type] - - # These handlers are special because they take an extra dtype argument - # specifying what they should convert to, and we need to appropriately set - # this up when we convert from Sympy. A reasonable default when you - # are translating is to conservatively do int64, and then narrow these - # arguments later when you discover you can narrow the index range. But - # if you already know that 32-bit indexing is OK, you can directly do the - # sympy translation with index_dtype=torch.int32 - INDEX_DTYPE_HANDLERS = { - TruncToInt: "trunc_to_int", - sympy.floor: "floor_to_int", - sympy.ceiling: "ceil_to_int", - RoundToInt: "round_to_int", - } - if (handler_name := INDEX_DTYPE_HANDLERS.get(expr.func)) is not None: - return getattr(analysis, handler_name)(*args, index_dtype) - if hasattr(expr.func, "_torch_handler_name"): handler_name = expr.func._torch_handler_name else: diff --git a/torch/utils/_sympy/reference.py b/torch/utils/_sympy/reference.py index b54a0d0503a1..881b9d616eb5 100644 --- a/torch/utils/_sympy/reference.py +++ b/torch/utils/_sympy/reference.py @@ -1,25 +1,12 @@ import math -import operator - import sympy import torch from torch.utils._sympy.functions import ( - _keep_float, - FloatPow, - FloatTrueDiv, - FloorDiv, - IntTrueDiv, - Mod, OpaqueUnaryFn_exp, OpaqueUnaryFn_log, OpaqueUnaryFn_sqrt, - PowByNatural, - RoundDecimal, - RoundToInt, - ToFloat, - TruncToInt, ) @@ -75,41 +62,20 @@ def not_(a): @staticmethod def reciprocal(x): - return FloatTrueDiv(1.0, x) + return 1 / x @staticmethod def square(x): - return PowByNatural(x, 2) - - @staticmethod - def trunc_to_int(x, dtype): - return TruncToInt(x) - - @staticmethod - def ceil_to_int(x, dtype): - return sympy.ceiling(x) - - @staticmethod - def floor_to_int(x, dtype): - return sympy.floor(x) - - @staticmethod - def floor(x): - return _keep_float(sympy.floor)(x) - - @staticmethod - def ceil(x): - return _keep_float(sympy.ceiling)(x) - - @staticmethod - def to_dtype(x, dtype): - if dtype == torch.float64: - return ToFloat(x) - raise NotImplementedError(f"to_dtype {dtype} NYI") + return x * x @staticmethod def mod(x, y): - return Mod(x, y) + ret = abs(x) % abs(y) + # without check: + # tracing will fail trying to go through control-flow if x is Proxy() + if isinstance(x, (int, sympy.Number)) and x < 0: + ret *= -1 + return ret @staticmethod def abs(x): @@ -121,31 +87,37 @@ def neg(x): @staticmethod def truediv(a, b): - return FloatTrueDiv(a, b) + return a / b @staticmethod - def int_truediv(a, b): - return IntTrueDiv(a, b) + def div(a, b): + return ReferenceAnalysis.truediv(a, b) @staticmethod def floordiv(a, b): - return FloorDiv(a, b) + if b == 0: + return sympy.nan if a == 0 else sympy.zoo + return a // b @staticmethod def truncdiv(a, b): - raise NotImplementedError("TODO: truncdiv") + result = a / b + if result.is_finite: + result = sympy.Integer(result) + + return result @staticmethod def add(a, b): - return _keep_float(operator.add)(a, b) + return a + b @staticmethod def mul(a, b): - return _keep_float(operator.mul)(a, b) + return a * b @staticmethod def sub(a, b): - return _keep_float(operator.sub)(a, b) + return a - b @staticmethod def exp(x): @@ -161,27 +133,39 @@ def sqrt(x): @staticmethod def pow(a, b): - return _keep_float(FloatPow)(a, b) - - @staticmethod - def pow_by_natural(a, b): - return PowByNatural(a, b) + return a**b @staticmethod def minimum(a, b): - return sympy.Min(a, b) + # Poorman's version of upcasting in Sympy + # This won't do for sympy.Expr as the casting does nothing for those + if a.is_Float or not a.is_finite or b.is_Float or not b.is_finite: + result_type = sympy.Float + else: + assert a.is_Integer + assert b.is_Integer + result_type = sympy.Integer + return sympy.Min(result_type(a), result_type(b)) @staticmethod def maximum(a, b): - return sympy.Max(a, b) + # Poorman's version of upcasting in Sympy + # This won't do for sympy.Expr as the casting does nothing for those + if a.is_Float or not a.is_finite or b.is_Float or not b.is_finite: + result_type = sympy.Float + else: + assert a.is_Integer + assert b.is_Integer + result_type = sympy.Integer + return sympy.Max(result_type(a), result_type(b)) @staticmethod - def round_to_int(a, dtype): - return RoundToInt(a) + def floor(x): + return sympy.floor(x) @staticmethod - def round_decimal(a, b): - return RoundDecimal(a, b) + def ceil(x): + return sympy.ceiling(x) # Unlike ReferenceAnalysis, does NOT sympyify, instead, works with plain @@ -207,20 +191,10 @@ def not_(a): def floordiv(a, b): return a // b - @staticmethod - def mod(x, y): - return x % y - @staticmethod def truncdiv(a, b): return a / b - @staticmethod - def to_dtype(x, dtype): - if dtype == torch.float64: - return float(x) - raise NotImplementedError(f"to_dtype {dtype} NYI") - @staticmethod def exp(x): raise AssertionError("exp is not valid shape sympy expr") @@ -241,41 +215,10 @@ def minimum(a, b): def maximum(a, b): return torch.sym_max(a, b) - @staticmethod - def floor_to_int(x, dtype): - return math.floor(x) - - @staticmethod - def ceil_to_int(x, dtype): - return math.ceil(x) - @staticmethod def floor(x): - return float(math.floor(x)) + return math.floor(x) @staticmethod def ceil(x): - return float(math.ceil(x)) - - @staticmethod - def truediv(a, b): - return a / b - - @staticmethod - def pow(a, b): - return a**b - - @staticmethod - def pow_by_natural(a, b): - # Pray that safe_pow is not needed here lol. In particular, this - # never participates in VR low/high ranges, so overflow should be - # unlikely - return a**b - - @staticmethod - def round_to_int(a, dtype): - return round(a) - - @staticmethod - def round_decimal(a, b): - return round(a, ndigits=b) + return math.ceil(x) diff --git a/torch/utils/_sympy/solve.py b/torch/utils/_sympy/solve.py index 02ddf7c34219..6276c696293c 100644 --- a/torch/utils/_sympy/solve.py +++ b/torch/utils/_sympy/solve.py @@ -88,7 +88,6 @@ def try_solve( # Return if we were able to isolate 'thing' on the left-hand side. if isinstance(e, sympy.Rel) and e.lhs == thing: - log.debug("solved: %s ---> %s", expr, e) return e, e.rhs return None diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py index 4d364d4981b5..c7cc96beb980 100644 --- a/torch/utils/_sympy/value_ranges.py +++ b/torch/utils/_sympy/value_ranges.py @@ -5,7 +5,6 @@ import logging import math import operator -import sys from typing import ( Callable, Dict, @@ -26,20 +25,17 @@ from torch._prims_common import dtype_to_type from .functions import ( - _keep_float, - FloatTrueDiv, - FloorDiv, - IntTrueDiv, + OpaqueUnaryFn_acos, + OpaqueUnaryFn_asinh, + OpaqueUnaryFn_atan, + OpaqueUnaryFn_cosh, OpaqueUnaryFn_exp, OpaqueUnaryFn_log, + OpaqueUnaryFn_sinh, OpaqueUnaryFn_sqrt, - PowByNatural, + OpaqueUnaryFn_tanh, + Round, RoundDecimal, - RoundToInt, - safe_pow, - ToFloat, - TruncToFloat, - TruncToInt, ) from .interp import sympy_interp @@ -124,8 +120,6 @@ class ValueRanges(Generic[_T]): lower: _T upper: _T is_bool: bool - is_int: bool - is_float: bool @overload def __init__(self: ValueRanges[sympy.Expr], lower: ExprIn, upper: ExprIn) -> None: @@ -148,39 +142,8 @@ def __init__(self, lower: AllIn, upper: AllIn) -> None: # Because this is a frozen class object.__setattr__(self, "lower", lower) object.__setattr__(self, "upper", upper) - # Unlike bool/int in Python, we don't report bools are ints object.__setattr__(self, "is_bool", isinstance(lower, SympyBoolean)) - if self.is_bool: - assert isinstance(upper, SympyBoolean), (lower, upper) - - # Warning: is_int/is_float is best effort. We do pretty well in - # Dynamo, but in Inductor these attributes are often wrong because we - # are not very rigorous in dtype analysis. This is also why we need - # the flexible analysis for is_int: sometimes a sympy.oo pops in for - # an integer bound. I would /like/ for us not to do this, but it's - # too hard to push the invariant through right now. - - object.__setattr__( - self, - "is_int", - not self.is_bool - and (isinstance(lower, sympy.Integer) or isinstance(upper, sympy.Integer)), - ) - """ - # This assert is just impossible right now, too many sympy bugs - if self.is_int: - # NB: sympy will sometimes randomly lose the float-ness of zero, - # so we also need to account for that in the assertion here. - # See also https://github.com/sympy/sympy/issues/26620 - assert isinstance(lower, sympy.Integer) or lower in [-sympy.oo, 0], ( - lower, - upper, - ) - assert isinstance(upper, sympy.Integer) or upper in [sympy.oo, 0], (lower, upper) - """ - # NB: [-oo, oo] always advertises as float! - object.__setattr__(self, "is_float", not self.is_bool and not self.is_int) - assert self.is_bool or self.is_int or self.is_float, (lower, upper) + assert isinstance(upper, SympyBoolean) == self.is_bool def boolify(self) -> ValueRanges[SympyBoolean]: if vr_is_bool(self): @@ -221,8 +184,6 @@ def __and__(self: AllVR, other: AllVR) -> AllVR: if self == ValueRanges.unknown(): return other assert self.is_bool == other.is_bool, (self, other) - assert self.is_int == other.is_int, (self, other) - assert self.is_float == other.is_float, (self, other) if self.is_bool: return ValueRanges( sympy.Or(self.lower, other.lower), sympy.And(self.upper, other.upper) @@ -392,12 +353,7 @@ def constant(value, dtype): # using nan makes subsequent computation throw, and for the purposes of optimization # returning -math.inf - math.inf is equivalent to giving up if isinstance(value, SupportsFloat) and math.isnan(value): - if dtype == torch.bool: - return ValueRanges.unknown_bool() - elif dtype.is_floating_point: - return ValueRanges.unknown() - else: - return ValueRanges(-sys.maxsize - 1, sys.maxsize) + return ValueRanges.unknown() if is_python: type_ = dtype_to_type(dtype) @@ -413,18 +369,7 @@ def constant(value, dtype): # dtype is intXX assert value.is_integer - r = ValueRanges.wrap(value) - return r - - @staticmethod - def to_dtype(a, dtype, src_dtype=None): - if dtype == torch.float64: - return ValueRanges.increasing_map(a, ToFloat) - return ValueRanges.unknown() - - @staticmethod - def trunc_to_int(a, dtype): - return ValueRanges.increasing_map(a, TruncToInt) + return ValueRanges.wrap(value) @staticmethod def not_(a): @@ -483,9 +428,7 @@ def ge(cls, a, b): @staticmethod def add(a, b): - return ValueRanges.coordinatewise_increasing_map( - a, b, _keep_float(operator.add) - ) + return ValueRanges.coordinatewise_increasing_map(a, b, operator.add) @classmethod def mul(cls, a, b): @@ -505,20 +448,11 @@ def safe_mul(a, b): else: return a * b - return ValueRanges.coordinatewise_monotone_map(a, b, _keep_float(safe_mul)) + return ValueRanges.coordinatewise_monotone_map(a, b, safe_mul) - @staticmethod - def int_truediv(a, b): - a = ValueRanges.wrap(a) - b = ValueRanges.wrap(b) - if 0 in b or ( - (-sympy.oo in a or sympy.oo in a) and (-sympy.oo in b or sympy.oo in b) - ): - return ValueRanges.unknown() - else: - return ValueRanges.coordinatewise_monotone_map( - a, b, _keep_float(IntTrueDiv) - ) + @classmethod + def div(cls, a, b): + return cls.truediv(a, b) @staticmethod def truediv(a, b): @@ -529,22 +463,18 @@ def truediv(a, b): ): return ValueRanges.unknown() else: - return ValueRanges.coordinatewise_monotone_map( - a, b, _keep_float(FloatTrueDiv) - ) + return ValueRanges.coordinatewise_monotone_map(a, b, operator.truediv) @staticmethod def floordiv(a, b): a = ValueRanges.wrap(a) b = ValueRanges.wrap(b) if 0 in b or ( - # TODO: make this more precise - (-sympy.oo in a or sympy.oo in a) - or (-sympy.oo in b or sympy.oo in b) + (-sympy.oo in a or sympy.oo in a) and (-sympy.oo in b or sympy.oo in b) ): return ValueRanges.unknown() else: - return ValueRanges.coordinatewise_monotone_map(a, b, FloorDiv) + return ValueRanges.coordinatewise_monotone_map(a, b, operator.floordiv) @classmethod def mod(cls, x, y): @@ -593,51 +523,17 @@ def modular_indexing(cls, a, b, c): @classmethod def is_non_overlapping_and_dense_indicator(cls, *args): - return ValueRanges.unknown() # TODO: type here is wrong - - @classmethod - def pow_by_natural(cls, a, b): - a = ValueRanges.wrap(a) - b = ValueRanges.wrap(b) - if a.is_singleton() and b.is_singleton(): - return ValueRanges.wrap(safe_pow(a.lower, b.lower)) - # NB: Exclude zero, because zero is special - elif a.lower >= 1: - # We should know that b >= 0 but we may have forgotten this fact due - # to replacements, so don't assert it, but DO clamp it to prevent - # degenerate problems - return ValueRanges.coordinatewise_increasing_map( - a, b & ValueRanges(0, sys.maxsize - 1), PowByNatural - ) - elif b.is_singleton(): - if b.lower % 2 == 0: - # x^n where n is even - return ValueRanges.convex_min_zero_map( - a, lambda x: safe_pow(x, b.lower) - ) - else: - # x^n where n is odd - return ValueRanges.increasing_map(a, lambda x: safe_pow(x, b.lower)) - else: - # a is potentially negative, and we don't know if the exponent is - # even or odd. So just conservatively set the upper and lower - # bound based on what the maximum absolute value could be, in both - # directions - max_base = max(a.upper, -a.lower) - return ValueRanges( - -(safe_pow(max_base, b.upper)), safe_pow(max_base, b.upper) - ) + return ValueRanges.unknown() @classmethod def pow(cls, a, b): - return ValueRanges.unknown() + def is_integer(val): + return isinstance(val, int) or ( + hasattr(val, "is_integer") and val.is_integer + ) - # We could implement all this, but for floating point pow, is there - # really a point? - """ a = ValueRanges.wrap(a) b = ValueRanges.wrap(b) - # Not implemented yet. It's a bit tricky # If you want to implement it, compute the partial derivatives of a ** b # and check the ranges where the function is increasing / decreasing @@ -657,7 +553,8 @@ def pow(cls, a, b): if b == 0: if not a.lower.is_finite: return ValueRanges.unknown() - return ValueRanges.wrap(1.0) + type_ = sympy.Float if a.lower.is_real else sympy.Integer + return ValueRanges.wrap(type_(1)) if b < 0: a = cls.reciprocal(a) @@ -666,12 +563,21 @@ def pow(cls, a, b): if a == ValueRanges.unknown(): return ValueRanges.unknown() - # If the base is positive, then we're good, otherwise nothing's defined - if a.lower >= 0: - return ValueRanges.increasing_map(a, lambda x: x**b) + # Here b > 0 + if not is_integer(b): + # If the base is positive, then we're good, otherwise nothing's defined + if a.lower >= 0: + return ValueRanges.increasing_map(a, lambda x: x**b) + else: + return ValueRanges.unknown() else: - return ValueRanges.unknown() - """ + # b > 0 integer + if b % 2 == 0: + # x^n where n is even + return ValueRanges.convex_min_zero_map(a, lambda x: x**b) + else: + # x^n where n is odd + return ValueRanges.increasing_map(a, lambda x: x**b) @staticmethod def reciprocal(x): @@ -680,7 +586,7 @@ def reciprocal(x): if 0 in x: return ValueRanges.unknown() else: - return ValueRanges.decreasing_map(x, lambda y: FloatTrueDiv(1.0, y)) + return ValueRanges.decreasing_map(x, lambda y: 1 / y) @staticmethod def abs(x): @@ -709,64 +615,45 @@ def maximum(cls, a, b): def min_or_max(a, b, fn): a = ValueRanges.wrap(a) b = ValueRanges.wrap(b) - return ValueRanges.coordinatewise_increasing_map(a, b, fn) - - @classmethod - def floor_to_int(cls, x, dtype): - return ValueRanges.increasing_map(x, sympy.functions.elementary.integers.floor) - @classmethod - def ceil_to_int(cls, x, dtype): - return ValueRanges.increasing_map( - x, sympy.functions.elementary.integers.ceiling - ) + # Performs upcasting first + def fn_(x: sympy.Expr, y: sympy.Expr) -> sympy.Expr: + # Poorman's version of upcasting in Sympy + # Inf is not a float... + if x.is_Integer and y.is_Integer: + result_type = sympy.Integer + elif x.is_rational and y.is_rational: + result_type = sympy.Rational + else: + assert x.is_real or not x.is_finite or y.is_real or not y.is_finite + result_type = sympy.Float + return fn(result_type(x), result_type(y)) - # I think these implementations are sound. The hazard here is that sympy - # will carry out the floor/ceil at too high precision and then something - # bad will happen when we convert it to float. - # - # For truncation, the implementation is clearly sound, because the desired - # target float is always exactly representable, since you're just chopping - # off bits the mantissa. But what about ceil/floor? - # - # The important constraint here is that we're not defining floor on - # arbitrary real numbers, only representable float numbers. So we can - # take advantage of the fact that before we reach the first - # unrepresentable integer in floating point space, we have the range of - # numbers corresponding to exponent zero: all integers, with no fractional - # amounts. floor/ceil is an identity operation in this case. In the - # range below here, representable floating point numbers are spaced - # exactly 1/2 apart, and notably, both the floor/ceil are defined floating - # point numbers. There is no "gap" as you step up to the next exponent. + return ValueRanges.coordinatewise_increasing_map(a, b, fn_) @classmethod def floor(cls, x): - return ValueRanges.increasing_map( - x, _keep_float(sympy.functions.elementary.integers.floor) - ) + return ValueRanges.increasing_map(x, sympy.functions.elementary.integers.floor) @classmethod def ceil(cls, x): return ValueRanges.increasing_map( - x, _keep_float(sympy.functions.elementary.integers.ceiling) + x, sympy.functions.elementary.integers.ceiling ) @classmethod - def round_decimal(cls, number, ndigits): - if not ndigits.is_singleton(): - return ValueRanges.unknown() - - ndigits = ndigits.lower - # We can't use functools.partial here since sympy doesn't support keyword arguments, but we have to bind - # the second parameter. - fn = lambda number: RoundDecimal(number, ndigits) # type: ignore[misc, assignment] # noqa: E731 + def round(cls, number, ndigits=None): + if ndigits is None: + fn = Round + else: + assert ndigits.is_singleton() + ndigits = ndigits.lower + # We can't use functools.partial here since sympy doesn't support keyword arguments, but we have to bind + # the second parameter. + fn = lambda number: RoundDecimal(number, ndigits) # type: ignore[misc, assignment] # noqa: E731 return ValueRanges.increasing_map(number, fn) - @classmethod - def round_to_int(cls, number, dtype): - return ValueRanges.increasing_map(number, RoundToInt) - # It's used in some models on symints @staticmethod def sqrt(x): @@ -821,15 +708,12 @@ def cos(x): @staticmethod def cosh(x): - return ValueRanges(0.0, sympy.oo) - """ x = ValueRanges.wrap(x) if x.lower > 0: return ValueRanges.increasing_map(x, OpaqueUnaryFn_cosh) elif x.upper < 0: return ValueRanges.decreasing_map(x, OpaqueUnaryFn_cosh) return ValueRanges(0.0, sympy.oo) - """ @staticmethod def sin(x): @@ -839,8 +723,7 @@ def sin(x): @staticmethod def sinh(x): - # return ValueRanges.increasing_map(x, OpaqueUnaryFn_sinh) - return ValueRanges(-sympy.oo, sympy.oo) + return ValueRanges.increasing_map(x, OpaqueUnaryFn_sinh) @staticmethod def tan(x): @@ -848,37 +731,32 @@ def tan(x): @staticmethod def tanh(x): - # return ValueRanges.increasing_map(x, OpaqueUnaryFn_tanh) - return ValueRanges(-sympy.oo, sympy.oo) + return ValueRanges.increasing_map(x, OpaqueUnaryFn_tanh) @staticmethod def asin(x): - return ValueRanges(-sympy.oo, sympy.oo) - """ x = ValueRanges.wrap(x) if -1 <= x.lower and x.upper <= 1: return ValueRanges.increasing_map(x, OpaqueUnaryFn_asinh) return ValueRanges.unknown() - """ @staticmethod def acos(x): - return ValueRanges(-sympy.oo, sympy.oo) - """ x = ValueRanges.wrap(x) if -1 <= x.lower and x.upper <= 1: return ValueRanges.decreasing_map(x, OpaqueUnaryFn_acos) return ValueRanges.unknown() - """ @staticmethod def atan(x): - return ValueRanges(-sympy.oo, sympy.oo) - # return ValueRanges.increasing_map(x, OpaqueUnaryFn_atan) + return ValueRanges.increasing_map(x, OpaqueUnaryFn_atan) @staticmethod def trunc(x): - return ValueRanges.increasing_map(x, TruncToFloat) + def trunc(x): + return sympy.Integer(x) if x.is_finite else x + + return ValueRanges.increasing_map(x, trunc) class ValueRangeAnalysis(SymPyValueRangeAnalysis): @@ -913,10 +791,9 @@ def store(self, name, index, value, mode=None): def reduction(self, name, dtype, src_dtype, reduction_type, index, value): return ValueRanges.unknown() - @classmethod - def index_expr(cls, index, dtype): + def index_expr(self, index, dtype): assert isinstance(index, ValueRanges) - return cls.to_dtype(index, dtype) + return index @staticmethod def to_dtype(x, dtype: torch.dtype, src_dtype: Optional[torch.dtype] = None): @@ -953,15 +830,12 @@ def cast(x, dtype): @staticmethod def square(x): - return ValueRanges.convex_min_zero_map(x, lambda y: PowByNatural(y, 2)) + return ValueRanges.convex_min_zero_map(x, lambda y: y * y) @staticmethod def neg(x): return ValueRanges.decreasing_map(x, operator.neg) - # TODO: this is slightly inaccurate because truncdiv operates at integer - # precision, but we're going through float truediv which means we can - # potentially lose precision on the bounds @classmethod def truncdiv(cls, a, b): x = cls.truediv(a, b) @@ -982,7 +856,6 @@ def __getattr__(self, name): def bound_sympy( expr: sympy.Expr, ranges: Optional[Dict[sympy.Symbol, ValueRanges]] = None ) -> ValueRanges: - log.debug("bound_sympy(%s, %s)", expr, ranges) if isinstance(expr, sympy.Number): return ValueRanges.wrap(expr) From e505132797219b4a31bbf07da8347062b6c8e441 Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Wed, 5 Jun 2024 04:16:54 +0000 Subject: [PATCH 354/706] [export] track TORCH_DYNAMO_DO_NOT_EMIT_RUNTIME_ASSERTS for export runtime asserts (#127554) Track TORCH_DYNAMO_DO_NOT_EMIT_RUNTIME_ASSERTS=1 in export so it doesn't omit runtime asserts. Differential Revision: D57978699 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127554 Approved by: https://github.com/tugsbayasgalan --- test/export/test_export.py | 24 ++++++++++++++++++++++++ torch/export/_trace.py | 32 +++++++++++++++----------------- torch/export/exported_program.py | 32 +++++++++++++++++--------------- 3 files changed, 56 insertions(+), 32 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index 859a46e80b93..228db50b9dc4 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -5318,6 +5318,30 @@ def forward(self, x, y): ): ep.module()(torch.randn(10), torch.randn(9)) # fail + # this should be set with command line flag TORCH_DYNAMO_DO_NOT_EMIT_RUNTIME_ASSERTS=1, + # but dynamo checks that at torch import time, so setting os.environ makes no difference + # instead, manually patch dynamo config and test. + # test that setting this flag removes runtime asserts + from torch._dynamo import config as _dynamo_config + + with _dynamo_config.patch( + do_not_emit_runtime_asserts=True, + ): + ep = torch.export._trace._export( + Foo(), + inputs, + dynamic_shapes=dynamic_shapes, + _allow_complex_guards_as_runtime_asserts=True, + ).run_decompositions() + + self.assertEqual( + [ + node.target == torch.ops.aten._assert_scalar.default + for node in ep.graph.nodes + ].count(True), + 0, + ) + def test_constant_aliasing(self): class M1(torch.nn.Module): def __init__(self, m2, foo): diff --git a/torch/export/_trace.py b/torch/export/_trace.py index 0ed664f43fc9..4fcc85f3236b 100644 --- a/torch/export/_trace.py +++ b/torch/export/_trace.py @@ -2,7 +2,6 @@ import functools import inspect import logging -import os import re import time import warnings @@ -553,10 +552,6 @@ def _export_to_aten_ir( pre_dispatch=False, _is_torch_jit_trace=False, ): - # set this to False if env variable is specified - if os.environ.get("TORCH_DYNAMO_DO_NOT_EMIT_RUNTIME_ASSERTS", "0") == "1": - should_insert_runtime_assertion = False - # [NOTE] If the user is exporting under training mode, we want to detect if there is any # state change in the autograd global state and error. If the user is exporting under inference # mode, we don't care. At predispatch level, we don't care about the state change. @@ -676,19 +671,22 @@ def make_argument_spec(i, node) -> ArgumentSpec: fake_mode = detect_fake_mode(flat_args) - stack_trace = ( - 'File "torch/fx/passes/runtime_assert.py", line 24, ' - "in insert_deferred_runtime_asserts" - ) - with _set_node_metadata_hook( - gm, functools.partial(_node_metadata_hook, stack_trace=stack_trace) - ): - insert_deferred_runtime_asserts( - gm, - fake_mode.shape_env, - f"exported program: {first_call_function_nn_module_stack(gm.graph)}", - export=True, + from torch._dynamo import config as _dynamo_config + + if not _dynamo_config.do_not_emit_runtime_asserts: + stack_trace = ( + 'File "torch/fx/passes/runtime_assert.py", line 24, ' + "in insert_deferred_runtime_asserts" ) + with _set_node_metadata_hook( + gm, functools.partial(_node_metadata_hook, stack_trace=stack_trace) + ): + insert_deferred_runtime_asserts( + gm, + fake_mode.shape_env, + f"exported program: {first_call_function_nn_module_stack(gm.graph)}", + export=True, + ) if pre_dispatch: from torch._export.passes.replace_set_grad_with_hop_pass import ( diff --git a/torch/export/exported_program.py b/torch/export/exported_program.py index bfdeb5db8e0e..048ffe2e85c9 100644 --- a/torch/export/exported_program.py +++ b/torch/export/exported_program.py @@ -663,26 +663,28 @@ def update_arg(old_arg, new_ph): _replace_sym_size_ops_pass(gm) + from torch._dynamo import config as _dynamo_config from torch._export.passes._node_metadata_hook import ( _node_metadata_hook, _set_node_metadata_hook, ) - stack_trace = ( - 'File "torch/fx/passes/runtime_assert.py", line 24, ' - "in insert_deferred_runtime_asserts" - ) - shape_env = _get_shape_env(gm) - if shape_env is not None: - with _set_node_metadata_hook( - gm, functools.partial(_node_metadata_hook, stack_trace=stack_trace) - ): - insert_deferred_runtime_asserts( - gm, - shape_env, - f"exported program: {first_call_function_nn_module_stack(gm.graph)}", - export=True, - ) + if not _dynamo_config.do_not_emit_runtime_asserts: + stack_trace = ( + 'File "torch/fx/passes/runtime_assert.py", line 24, ' + "in insert_deferred_runtime_asserts" + ) + shape_env = _get_shape_env(gm) + if shape_env is not None: + with _set_node_metadata_hook( + gm, functools.partial(_node_metadata_hook, stack_trace=stack_trace) + ): + insert_deferred_runtime_asserts( + gm, + shape_env, + f"exported program: {first_call_function_nn_module_stack(gm.graph)}", + export=True, + ) exported_program = ExportedProgram( root=gm, From 30788739f4cec672cdabc21c155bbe30ac5d66e6 Mon Sep 17 00:00:00 2001 From: Shuqiang Zhang Date: Tue, 4 Jun 2024 15:59:05 -0700 Subject: [PATCH 355/706] [c10d] add a simple test to demonstrate the user usage of collectives (#127665) Summary: Just play around the UT and think it would be good to give an simple example of user function which can be used for different subclasses of _ControlCollectives, and test the user function can be executed correctly Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/127665 Approved by: https://github.com/d4l3k --- test/distributed/test_control_collectives.py | 25 ++++++++++++++++++++ torch/_C/_distributed_c10d.pyi | 2 +- 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/test/distributed/test_control_collectives.py b/test/distributed/test_control_collectives.py index fb0067f2dd2e..594c028ae9d4 100644 --- a/test/distributed/test_control_collectives.py +++ b/test/distributed/test_control_collectives.py @@ -8,6 +8,17 @@ from torch.testing._internal.common_utils import run_tests, TestCase +# simple example of user code that takes the base class ControlCollectives +# and executes multiple different collectives +def simple_user_func(collectives: dist._ControlCollectives, rank: int) -> int: + timeout = timedelta(seconds=10) + # first a barrier + collectives.barrier("1", timeout, True) + # then an all_sum + out = collectives.all_sum("2", rank, timeout) + return out + + class TestCollectives(TestCase): def test_barrier(self) -> None: store = dist.HashStore() @@ -180,6 +191,20 @@ def test_unique(self) -> None: with self.assertRaisesRegex(Exception, "Key foo has already been used"): collectives.all_sum("foo", 2) + def test_simple_user_func(self) -> None: + store = dist.HashStore() + world_size = 4 + + def f(rank: int) -> None: + # user need to create child collectives + # but simple_user_func do not need to be changed for different child collectives + store_collectives = dist._StoreCollectives(store, rank, world_size) + out = simple_user_func(store_collectives, rank) + self.assertEqual(out, sum(range(world_size))) + + with ThreadPool(world_size) as pool: + pool.map(f, range(world_size)) + if __name__ == "__main__": assert ( diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index dab215d396ce..d6f7ae259a88 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -223,7 +223,7 @@ class _ControlCollectives: def scatter_send(self, key: str, data: str, timeout: timedelta) -> None: ... def scatter_recv(self, key: str, timeout: timedelta) -> str: ... def all_gather(self, key: str, data: str, timeout: timedelta) -> str: ... - def all_sum(self, key: str, data: str, timeout: timedelta) -> int: ... + def all_sum(self, key: str, data: int, timeout: timedelta) -> int: ... class _StoreCollectives(_ControlCollectives): def __init__(self, store: Store, rank: int, world_size: int) -> None: ... From b054470db22a6c8ecba31c44ce54b9ca48159cdd Mon Sep 17 00:00:00 2001 From: cyy Date: Wed, 5 Jun 2024 05:21:24 +0000 Subject: [PATCH 356/706] Remove unused functions (#127881) Some unused functions detected by g++ warnings can be removed. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127881 Approved by: https://github.com/zou3519 --- .../impl/kernel_function_legacy_test.cpp | 4 ---- aten/src/ATen/test/vec_test_all_types.cpp | 20 ------------------- test/cpp/lazy/test_lazy_ops_util.cpp | 5 ----- 3 files changed, 29 deletions(-) diff --git a/aten/src/ATen/core/boxing/impl/kernel_function_legacy_test.cpp b/aten/src/ATen/core/boxing/impl/kernel_function_legacy_test.cpp index 0b0df2af1ca1..a5cb61874173 100644 --- a/aten/src/ATen/core/boxing/impl/kernel_function_legacy_test.cpp +++ b/aten/src/ATen/core/boxing/impl/kernel_function_legacy_test.cpp @@ -40,10 +40,6 @@ int64_t incrementKernel(const Tensor& tensor, int64_t input) { return input + 1; } -int64_t decrementKernel(const Tensor& tensor, int64_t input) { - return input - 1; -} - void expectCallsIncrement(DispatchKey dispatch_key) { at::AutoDispatchBelowAutograd mode; diff --git a/aten/src/ATen/test/vec_test_all_types.cpp b/aten/src/ATen/test/vec_test_all_types.cpp index 4c7e3e5b2b02..f9a0557f8bdf 100644 --- a/aten/src/ATen/test/vec_test_all_types.cpp +++ b/aten/src/ATen/test/vec_test_all_types.cpp @@ -978,26 +978,6 @@ namespace { b[i] = b[i - 1] + (T)(1.0); } } - template<> - // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) - void blend_init, 4>(Complex(&a)[4], Complex(&b)[4]) { - auto add = Complex(1., 100.); - a[0] = Complex(1., 100.); - b[0] = Complex(5., 1000.); - for (const auto i : c10::irange(1, 4)) { - a[i] = a[i - 1] + add; - b[i] = b[i - 1] + add; - } - } - template<> - // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) - void blend_init, 2>(Complex(&a)[2], Complex(&b)[2]) { - auto add = Complex(1.0, 100.0); - a[0] = Complex(1.0, 100.0); - b[0] = Complex(3.0, 1000.0); - a[1] = a[0] + add; - b[1] = b[0] + add; - } TYPED_TEST(BitwiseFloatsAdditional, Blendv) { using vec = TypeParam; using VT = ValueType; diff --git a/test/cpp/lazy/test_lazy_ops_util.cpp b/test/cpp/lazy/test_lazy_ops_util.cpp index cc5287cd9b3d..c024780187c7 100644 --- a/test/cpp/lazy/test_lazy_ops_util.cpp +++ b/test/cpp/lazy/test_lazy_ops_util.cpp @@ -12,11 +12,6 @@ namespace torch { namespace lazy { namespace { -bool IsLtcTensor(const at::Tensor& tensor) { - return dynamic_cast( - tensor.unsafeGetTensorImpl()); -} - std::unordered_set* CreateIgnoredCounters() { std::unordered_set* icounters = new std::unordered_set(); From 4a384d813b0b824adbf423558419cfa298d89868 Mon Sep 17 00:00:00 2001 From: dan_the_3rd <43445237+danthe3rd@users.noreply.github.com> Date: Wed, 5 Jun 2024 07:33:27 +0000 Subject: [PATCH 357/706] [SDPA/memeff] Backport changes from xFormers to PT (#127090) Backporting a few fixes from xFormers: * Bug fixes for local attention (which is not exposed in PT at the moment) * Massively reduced memory usage on the BW pass (see also https://github.com/facebookresearch/xformers/pull/1028) Essentially this will also make xFormers build process much easier, as we will be able to use mem-eff from PyTorch (if the user has a recent enough version) rather than building it at xFormers install time The goal is to have the source of truth for these files in PT moving forward, and remove them from xFormers eventually once our users have a recent-enough version of PT. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127090 Approved by: https://github.com/drisspg --- aten/src/ATen/native/native_functions.yaml | 4 +- .../native/transformers/cuda/attention.cu | 14 +- .../transformers/cuda/attention_backward.cu | 53 +++++- .../cuda/mem_eff_attention/kernel_backward.h | 164 +++++++++++------- .../cuda/mem_eff_attention/kernel_forward.h | 61 +++---- .../check_forward_backward_compatibility.py | 3 +- tools/autograd/derivatives.yaml | 2 +- torch/_meta_registrations.py | 25 ++- torch/_subclasses/fake_impls.py | 2 +- .../aoti_torch/generated/c_shim_cuda.h | 4 +- torch/nn/attention/bias.py | 1 - .../_internal/common_methods_invocations.py | 4 - 12 files changed, 209 insertions(+), 128 deletions(-) diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 54b12a9a0b0c..aa3d9b3fb2f5 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -14751,13 +14751,13 @@ CUDA: _flash_attention_backward # Returns output, logsumexp if compute_logsumexp -- func: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt? max_seqlen_q, SymInt? max_seqlen_k, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? causal_diagonal=None, Tensor? seqlen_k=None, int? window_size=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, SymInt max_seqlen_batch_q, SymInt max_seqlen_batch_k) +- func: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt? max_seqlen_q, SymInt? max_seqlen_k, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? seqlen_k=None, int? window_size=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, SymInt max_seqlen_batch_q, SymInt max_seqlen_batch_k) variants: function dispatch: CUDA: _efficient_attention_forward tags: nondeterministic_seeded -- func: _efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor out, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt max_seqlen_q, SymInt max_seqlen_k, Tensor logsumexp, float dropout_p, Tensor philox_seed, Tensor philox_offset, int custom_mask_type, bool bias_requires_grad, *, float? scale=None, int? num_splits_key=None, int? window_size=None) -> (Tensor, Tensor, Tensor, Tensor) +- func: _efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor out, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt max_seqlen_q, SymInt max_seqlen_k, Tensor logsumexp, float dropout_p, Tensor philox_seed, Tensor philox_offset, int custom_mask_type, bool bias_requires_grad, *, float? scale=None, int? num_splits_key=None, int? window_size=None, bool shared_storage_dqdkdv=False) -> (Tensor, Tensor, Tensor, Tensor) device_check: NoCheck variants: function dispatch: diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu index a1cdb47c12b4..b3f07206ccbe 100644 --- a/aten/src/ATen/native/transformers/cuda/attention.cu +++ b/aten/src/ATen/native/transformers/cuda/attention.cu @@ -959,7 +959,6 @@ std::tuple _efficient_ int64_t custom_mask_type, bool compute_logsumexp, std::optional scale, - const std::optional& causal_diagonal, const std::optional& seqlen_k, const std::optional window_size) { #if defined(USE_MEM_EFF_ATTENTION) @@ -1147,12 +1146,6 @@ std::tuple _efficient_ p.num_keys = max_seqlen_k; p.num_batches = seqstart_q.has_value() ? seqstart_q->size(0) - 1 : B; p.custom_mask_type = custom_mask_type; - p.causal_diagonal_ptr = nullptr; - if (causal_diagonal.has_value()) { - CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(causal_diagonal.value()); - TORCH_CHECK(causal_diagonal->scalar_type() == at::ScalarType::Int); - p.causal_diagonal_ptr = (const int32_t*)causal_diagonal->const_data_ptr(); - } p.seqlen_k_ptr = nullptr; if (seqlen_k.has_value()) { @@ -1222,8 +1215,13 @@ std::tuple _efficient_ " kb)"); AT_CUDA_CHECK(err); } + auto blocks = p.getBlocksGrid(); + if (blocks.x * blocks.y * blocks.z == 0 || key.size(1) == 0) { + res.zero_(); + return; + } Kernel::check_supported(p); - kernel_fn<<>>(p); + kernel_fn<<>>(p); }; // Dispatch to the right kernel diff --git a/aten/src/ATen/native/transformers/cuda/attention_backward.cu b/aten/src/ATen/native/transformers/cuda/attention_backward.cu index 690f433aa5f2..5d9f0ce98474 100644 --- a/aten/src/ATen/native/transformers/cuda/attention_backward.cu +++ b/aten/src/ATen/native/transformers/cuda/attention_backward.cu @@ -230,7 +230,8 @@ _efficient_attention_backward( const bool bias_requires_grad, const std::optional scale, std::optional num_splits_key, - const std::optional window_size) { + const std::optional window_size, + const bool shared_storage_dqdkdv) { #if defined(USE_MEM_EFF_ATTENTION) if (!grad_out_.defined()) { return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{}); @@ -310,9 +311,33 @@ _efficient_attention_backward( int64_t Kv = value.size(3); at::Tensor grad_q, grad_k, grad_v, grad_bias; - grad_q = at::empty(query.sizes(), query.options()); - grad_k = at::empty(key.sizes(), key.options()); - grad_v = at::empty(value.sizes(), value.options()); + if (shared_storage_dqdkdv) { + // Create one big contiguous chunk + // This is because q, k and v usually come from a single + // output of a linear layer that is chunked. + // Creating the gradients with the right layout saves us + // a `torch.cat` call in the backward pass + TORCH_CHECK( + query.size(1) == key.size(1), + "`shared_storage_dqdkdv` is only supported when Q/K/V " + "have the same sequence length: got ", query.size(1), + " query tokens and ", key.size(1), " key/value tokens" + ); + TORCH_CHECK( + query.size(3) == key.size(3), + "`shared_storage_dqdkdv` is only supported when Q/K/V " + "have the same embed dim: got ", query.size(3), + " for Q, and ", key.size(3), " for K" + ); + at::Tensor chunk = at::empty({B, M, 3, nH, K}, query.options()); + grad_q = chunk.select(2, 0); + grad_k = chunk.select(2, 1); + grad_v = chunk.select(2, 2); + } else { + grad_q = at::empty(query.sizes(), query.options()); + grad_k = at::empty(key.sizes(), key.options()); + grad_v = at::empty(value.sizes(), value.options()); + } if (bias_requires_grad) { // force alignment for the last dim @@ -439,8 +464,7 @@ _efficient_attention_backward( ASSIGN_CHECK_OVERFLOW(p.gQ_strideH, grad_q.stride(2)); ASSIGN_CHECK_OVERFLOW(p.gK_strideH, grad_k.stride(2)); ASSIGN_CHECK_OVERFLOW(p.gV_strideH, grad_v.stride(2)); - // We removed the chunk/cat optimization and the multiplier is always 1 - p.gQKV_strideM_multiplier = 1; + p.gQKV_strideM_multiplier = shared_storage_dqdkdv ? 3 : 1; TORCH_INTERNAL_ASSERT(p.gQ_strideM() == grad_q.stride(1)); TORCH_INTERNAL_ASSERT(p.gK_strideM() == grad_k.stride(1)); TORCH_INTERNAL_ASSERT(p.gV_strideM() == grad_v.stride(1)); @@ -503,8 +527,12 @@ _efficient_attention_backward( auto parallelism_without_split_key = p.getBlocksGrid().x * p.getBlocksGrid().y * p.getBlocksGrid().z; p.num_splits_key = cutlass::ceil_div(p.num_keys, Kernel::kBlockSizeJ); - if (num_splits_key.has_value()) { // Skip heuristic, if user provided an explicit value - p.num_splits_key = std::max(p.num_splits_key, num_splits_key.value()); + if (num_splits_key.has_value()) { + p.num_splits_key = + std::min(p.num_splits_key, num_splits_key.value()); + } else { + // Keys splitting heuristic + // If we already have enough parallelism, split-keys can help // better use L2 cache. // This is negligible when the seqlen is too small tho @@ -545,6 +573,15 @@ _efficient_attention_backward( workspace.zero_(); } } + + // Handle the edge-cases where some tensors are empty + if (p.num_queries == 0 || p.num_keys == 0 || p.num_batches == 0 || + p.num_heads == 0) { + grad_k.zero_(); + grad_v.zero_(); + grad_q.zero_(); + return; + } Kernel::check_supported(p); if (smem_bytes > 0xc000) { diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h index 564e3f2f3522..05fa314a2bf6 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h @@ -607,7 +607,10 @@ struct AttentionBackwardKernel { using AccumTileGmem = GmemTile; }; - static constexpr bool kEnableSplitKeys = true; + // NOTE: nvcc 12.4 has correctness errors with this on M60 (sm52) + // when there is an attention bias. Let's just disable it for now. + static constexpr auto kMinSm = ArchTag::kMinComputeCapability; + static constexpr bool kEnableSplitKeys = kMinSm >= 70; static constexpr bool kNeedsAccumGradQ = kEnableSplitKeys || !cutlass::platform::is_same::value; @@ -720,11 +723,19 @@ struct AttentionBackwardKernel { int64_t gV_strideH = 0; int64_t gB_strideH = 0; - CUTLASS_DEVICE int16_t num_splits_key_device() const { + CUTLASS_HOST_DEVICE int16_t num_splits_key_device() const { +#ifdef __CUDA_ARCH__ return kEnableSplitKeys ? gridDim.x : 1; +#else + return num_splits_key; // for host-side tests +#endif } - CUTLASS_DEVICE int16_t split_key_device() const { + CUTLASS_HOST_DEVICE int16_t split_key_device() const { +#ifdef __CUDA_ARCH__ return kEnableSplitKeys ? blockIdx.x : 0; +#else + return 0; // for host-side tests +#endif } CUTLASS_DEVICE bool advance_to_block() { @@ -846,14 +857,14 @@ struct AttentionBackwardKernel { if (!kNeedsAccumGradK) { return 0; } - return num_splits_key * align_up(num_keys, (int32_t)kBlockSizeJ) * + return num_splits_key * kBlockSizeJ * align_up(head_dim, (int32_t)kBlockSizeI); } CUTLASS_HOST_DEVICE int64_t workspace_elements_gv() const { if (!kNeedsAccumGradV) { return 0; } - return num_splits_key * align_up(num_keys, (int32_t)kBlockSizeJ) * + return num_splits_key * kBlockSizeJ * align_up(head_dim_value, (int32_t)kBlockSizeI); } CUTLASS_HOST_DEVICE int64_t workspace_elements_gq() const { @@ -877,7 +888,7 @@ struct AttentionBackwardKernel { return num_batches * num_heads * workspace_strideBH() * sizeof(float); } CUTLASS_HOST_DEVICE bool should_zero_workspace() const { - return num_splits_key > 1; + return num_splits_key > 1 || window_size > 0; } }; @@ -1174,8 +1185,12 @@ struct AttentionBackwardKernel { CHECK_ALIGNED_PTR(p.output_ptr, kMinimumAlignment); CHECK_ALIGNED_PTR(p.grad_output_ptr, kMinimumAlignment); CHECK_ALIGNED_PTR(p.bias_ptr, kMinimumAlignment); - TORCH_CHECK(p.lse_strideH % 8 == 0, "LSE is not correctly aligned"); - TORCH_CHECK(p.lse_strideB % 8 == 0, "LSE is not correctly aligned"); + TORCH_CHECK( + p.num_heads <= 1 || p.lse_strideH % 8 == 0, + "LSE is not correctly aligned (strideH)"); + TORCH_CHECK( + p.num_batches <= 1 || p.lse_strideB % 8 == 0, + "LSE is not correctly aligned (strideB)"); TORCH_CHECK( p.num_heads <= 1 || p.q_strideH % kMinimumAlignment == 0, "query is not correctly aligned (strideH)"); @@ -1187,7 +1202,7 @@ struct AttentionBackwardKernel { "value is not correctly aligned (strideH)"); TORCH_CHECK( p.num_batches <= 1 || p.q_strideB % kMinimumAlignment == 0, - "query is not correctly aligned (strideB)."); + "query is not correctly aligned (strideB)"); TORCH_CHECK( p.num_batches <= 1 || p.k_strideB % kMinimumAlignment == 0, "key is not correctly aligned (strideB)"); @@ -1268,15 +1283,18 @@ struct AttentionBackwardKernel { } TORCH_CHECK( kEnableSplitKeys || p.num_splits_key == 1, "SplitKeys is disabled"); - TORCH_CHECK(p.num_splits_key > 0, "Invalid `num_splits_key` (expected >0)"); + TORCH_CHECK( + p.num_splits_key > 0, "Invalid `num_splits_key` (expected >0)"); TORCH_CHECK( p.num_splits_key <= cutlass::ceil_div(p.num_keys, kBlockSizeJ), "Invalid `num_splits_key` (", p.num_splits_key, ") - too large for `num_keys` = ", p.num_keys); - if (p.window_size > 0) { - TORCH_CHECK(p.custom_mask_type == CausalFromTopLeft); + if (p.window_size != 0) { + TORCH_CHECK( + p.custom_mask_type != NoCustomMask, + "LocalAttention only supported in causal mode"); } return true; } @@ -1338,15 +1356,15 @@ struct AttentionBackwardKernel { std::get<1>(seeds) + p.dropout_batch_head_rng_offset, &rng_state_init); } + CUTLASS_PRAGMA_UNROLL for (; key_start < p.num_keys; key_start += p.num_splits_key_device() * kBlockSizeJ) { output_frags.clear(); - CUTLASS_PRAGMA_UNROLL - for (int32_t query_start_shifted = getQueryStart(p, key_start); - query_start_shifted < getQueryStartShift(p) + getQueryEnd(p); - query_start_shifted += kBlockSizeI) { + int32_t next_key = key_start; + int32_t query_start = getQueryStart(p, key_start); + while (next_key == key_start && query_start < p.num_queries) { // This line here // vvvvvvvvvvvvvv warp_id = warp_uniform(warp_id); @@ -1357,11 +1375,6 @@ struct AttentionBackwardKernel { // from the previous iteration, which prevents MASSIVE // register spilling. - int32_t query_start = query_start_shifted; - if (query_start >= p.num_queries) { - query_start = query_start % getQueryEnd(p); - } - processBlockIJ( shared_storage, output_frags, @@ -1371,6 +1384,10 @@ struct AttentionBackwardKernel { rng_state_init, warp_id, lane_id); + + int32_t next_query; + incrIteration(p, query_start, key_start, next_query, next_key); + query_start = next_query; } if (kOutputInRF) { writeFragsToGmem( @@ -1466,13 +1483,7 @@ struct AttentionBackwardKernel { ? MatmulQK::Mma::Shape::kM : warp_uniform(cutlass::fast_min( (int32_t)MatmulQK::Mma::Shape::kM, p.num_keys - key_start)); - if (p.window_size > 0) { - if (p.custom_mask_type == CausalFromTopLeft && - key_start + num_keys_in_block <= - int32_t(query_start) - p.window_size) { - return; - } - } + auto prologueGradV = [&](int col) { typename MatmulGradV::Mma::IteratorB iterator_dO( {int32_t(p.gO_strideM)}, @@ -2119,14 +2130,20 @@ struct AttentionBackwardKernel { p.grad_query_ptr + query_start * p.gQ_strideM() + col, {problem_size.m(), problem_size.n()}, thread_id); - bool storage_contains_zeros = kNeedsAccumGradQ || key_start == 0 || + // if `direct_store` is True, we store to gmem (`*gmem = accum`) + // otherwise, we accumulate in gmem (`*gmem = *gmem + accum`) + // If we know ahead of time when we will write for the first time + // we can: + // (1) Avoid an additional memory read + // (2) Avoid the cost of initializing memory to 0 + bool direct_store = kNeedsAccumGradQ || key_start == 0 || (p.num_splits_key_device() > 1); accumulateInGmem( isLastColumn ? shared_storage.gradQ_epilogue_lastIter() : shared_storage.gradQ_epilogue(), accum, output_it, - storage_contains_zeros, + direct_store, warp_id, lane_id); } @@ -2237,12 +2254,13 @@ struct AttentionBackwardKernel { isFirstQuery || kNeedsAccumGradK, warp_id, lane_id); + __syncthreads(); } } } } - static CUTLASS_DEVICE int32_t getQueryStartShift(Params const& p) { + static CUTLASS_HOST_DEVICE int32_t getQueryStartShift(Params const& p) { if (p.custom_mask_type == NoCustomMask && p.num_splits_key_device() > 1) { return (p.split_key_device() * kBlockSizeI) % getQueryEnd(p); } @@ -2250,55 +2268,70 @@ struct AttentionBackwardKernel { } // Iteration order logic - static CUTLASS_DEVICE int32_t + static CUTLASS_HOST_DEVICE int32_t getQueryStart(Params const& p, int32_t key_start) { return getSmallestQueryForKey(p, key_start) + getQueryStartShift(p); }; - static CUTLASS_DEVICE int32_t getQueryEnd(Params const& p) { + static CUTLASS_HOST_DEVICE int32_t getQueryEnd(Params const& p) { return align_up(p.num_queries, kBlockSizeI); }; - static CUTLASS_DEVICE int32_t + static CUTLASS_HOST_DEVICE int32_t getSmallestQueryForKey(Params const& p, int32_t key_start) { - if (p.custom_mask_type == CausalFromTopLeft) { - return (key_start / kBlockSizeI) * kBlockSizeI; - } else if (p.custom_mask_type == CausalFromBottomRight) { - int first_query = - cutlass::fast_max(0, key_start - p.num_keys + p.num_queries); - return (first_query / kBlockSizeI) * kBlockSizeI; + if (p.custom_mask_type == NoCustomMask) { + return 0; } - return 0; + int32_t shift = p.custom_mask_type == CausalFromBottomRight + ? p.num_keys - p.num_queries + : 0; + int32_t window_size = + p.window_size == 0 ? p.num_queries + p.num_keys : p.window_size; + + auto last_key_for_block = + cutlass::fast_min(key_start + kBlockSizeJ, p.num_keys) - 1; + int first_query = key_start - shift; + int last_query = last_key_for_block - shift + window_size - 1; + if (last_query < 0 || first_query >= p.num_queries) { + return getQueryEnd(p); // nothing to compute in this column + } + first_query = cutlass::fast_max(0, first_query); + return (first_query / kBlockSizeI) * kBlockSizeI; }; // Returns how many kernel blocks will write to a given block in `grad_query` // This is usually equal to the number of key splits, but can be different // for instance in the causal case, or varying seqlen - static CUTLASS_DEVICE int32_t + static CUTLASS_HOST_DEVICE int32_t getNumParallelBlocksForQuery(Params const& p, int32_t query_start) { int16_t num_key_blocks = ceil_div(p.num_keys, kBlockSizeJ); - if (p.custom_mask_type == CausalFromTopLeft) { - int32_t last_key_for_block = query_start + kBlockSizeI - 1; - last_key_for_block = cutlass::fast_min(last_key_for_block, p.num_keys); + if (p.custom_mask_type != NoCustomMask) { + int32_t shift = p.custom_mask_type == CausalFromBottomRight + ? p.num_keys - p.num_queries + : 0; + int32_t last_query_for_block = + cutlass::fast_min(query_start + kBlockSizeI, p.num_queries) - 1; + int32_t last_key_for_block = + cutlass::fast_min(last_query_for_block + shift, p.num_keys - 1); + int32_t first_key_for_block = p.window_size == 0 + ? 0 + : cutlass::fast_max(query_start - p.window_size + 1 + shift, 0); + if (p.window_size == 0) { - num_key_blocks = ceil_div(last_key_for_block, kBlockSizeJ); + num_key_blocks = last_key_for_block / kBlockSizeJ + 1; } else { - int32_t first_key_for_block = - cutlass::fast_max(query_start - p.window_size + 1, 0); - int32_t first_key_block = first_key_for_block / kBlockSizeJ; - int32_t last_key_block = last_key_for_block / kBlockSizeJ; - num_key_blocks = last_key_block - first_key_block + 1; + num_key_blocks = (last_key_for_block / kBlockSizeJ) - + (first_key_for_block / kBlockSizeJ) + 1; + } + + if (last_key_for_block < 0 || first_key_for_block >= p.num_keys) { + num_key_blocks = 0; } - } else if (p.custom_mask_type == CausalFromBottomRight) { - int32_t last_key_for_block = - query_start + (kBlockSizeI - 1) + (1 + p.num_keys - p.num_queries); - last_key_for_block = cutlass::fast_min(last_key_for_block, p.num_keys); - num_key_blocks = ceil_div(last_key_for_block, kBlockSizeJ); } return cutlass::fast_min(p.num_splits_key_device(), num_key_blocks); }; // Returns the next block to process - static CUTLASS_DEVICE void incrIteration( + static CUTLASS_HOST_DEVICE void incrIteration( Params const& p, int32_t query_start, int32_t key_start, @@ -2318,14 +2351,19 @@ struct AttentionBackwardKernel { return; } } else { - if (p.window_size == 0 && next_query < p.num_queries) { - return; - } else if (p.window_size > 0) { - if (next_query < - cutlass::fast_min( - key_start + kBlockSizeJ + p.window_size, p.num_queries)) { + if (p.window_size > 0) { + int32_t shift = p.custom_mask_type == CausalFromBottomRight + ? p.num_keys - p.num_queries + : 0; + // last key that is not masked out + int last_key_for_block = + cutlass::fast_min(key_start + kBlockSizeJ, p.num_keys) - 1; + int last_query = last_key_for_block - shift + p.window_size - 1; + if (next_query <= last_query && next_query < p.num_queries) { return; } + } else if (next_query < p.num_queries) { + return; } // jump to next key } diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h index 74330ecd242a..642145f5a0da 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h @@ -138,7 +138,6 @@ struct AttentionKernel { const int32_t* seqstart_q_ptr = nullptr; const int32_t* seqstart_k_ptr = nullptr; - const int32_t* causal_diagonal_ptr = nullptr; const int32_t* seqlen_k_ptr = nullptr; uint32_t causal_diagonal_offset = 0; @@ -153,44 +152,44 @@ struct AttentionKernel { int32_t window_size = 0; // Scale - accum_t scale; + accum_t scale = 0.0; // Dimensions/strides - int32_t head_dim; - int32_t head_dim_value; - int32_t num_queries; - int32_t num_keys; - int32_t num_keys_absolute; + int32_t head_dim = 0; + int32_t head_dim_value = 0; + int32_t num_queries = 0; + int32_t num_keys = 0; + int32_t num_keys_absolute = 0; uint8_t custom_mask_type = NoCustomMask; - int32_t q_strideM; - int32_t k_strideM; - int32_t v_strideM; + int32_t q_strideM = 0; + int32_t k_strideM = 0; + int32_t v_strideM = 0; int32_t bias_strideM = 0; int32_t o_strideM = 0; // Everything below is only used in `advance_to_block` // and shouldn't use registers - int32_t q_strideH; - int32_t k_strideH; - int32_t v_strideH; + int32_t q_strideH = 0; + int32_t k_strideH = 0; + int32_t v_strideH = 0; int64_t bias_strideH = 0; - int64_t q_strideB; - int64_t k_strideB; - int64_t v_strideB; + int64_t q_strideB = 0; + int64_t k_strideB = 0; + int64_t v_strideB = 0; int64_t bias_strideB = 0; - int32_t num_batches; - int32_t num_heads; + int32_t num_batches = 0; + int32_t num_heads = 0; // dropout - bool use_dropout; - unsigned long long dropout_batch_head_rng_offset; - float dropout_prob; - at::PhiloxCudaState rng_engine_inputs; + bool use_dropout = false; + unsigned long long dropout_batch_head_rng_offset = 0; + float dropout_prob = 0.0f; + at::PhiloxCudaState rng_engine_inputs = at::PhiloxCudaState(0, 0); int64_t* extragraph_offset; int64_t* seed; @@ -209,7 +208,7 @@ struct AttentionKernel { head_id * num_queries * num_keys; } - int64_t q_start, k_start; + int64_t q_start = 0, k_start = 0; // Advance to current batch - in case of different sequence lengths if (seqstart_q_ptr != nullptr) { assert(seqstart_k_ptr != nullptr); @@ -274,11 +273,8 @@ struct AttentionKernel { } // Custom masking - if (causal_diagonal_ptr) { - causal_diagonal_offset = causal_diagonal_ptr[batch_id]; - } if (custom_mask_type == CausalFromBottomRight) { - causal_diagonal_offset += num_keys - num_queries; + causal_diagonal_offset = num_keys - num_queries; } // We use num_keys_absolute to index into the rng_state // We need this index to match between forward and backwards @@ -302,7 +298,7 @@ struct AttentionKernel { // - we only launch kernels for head_id % kQueriesPerBlock == 0 // - we iterate over heads instead of queries (strideM = strideH) if (num_queries == 1 && k_strideH == 0 && v_strideH == 0 && - logsumexp_ptr == nullptr) { + logsumexp_ptr == nullptr && window_size == 0) { if (head_id % kQueriesPerBlock != 0) { return false; } @@ -318,6 +314,7 @@ struct AttentionKernel { // Make sure the compiler knows these variables are the same on all // the threads of the warp. + // Only worth doing if they could have been modified above. query_ptr = warp_uniform(query_ptr); key_ptr = warp_uniform(key_ptr); value_ptr = warp_uniform(value_ptr); @@ -330,8 +327,6 @@ struct AttentionKernel { num_queries = warp_uniform(num_queries); num_keys = warp_uniform(num_keys); num_heads = warp_uniform(num_heads); - head_dim = warp_uniform(head_dim); - head_dim_value = warp_uniform(head_dim_value); o_strideM = warp_uniform(o_strideM); custom_mask_type = warp_uniform(custom_mask_type); return true; @@ -614,16 +609,14 @@ struct AttentionKernel { TORCH_CHECK( p.num_heads <= 1 || p.v_strideH % kAlignmentV == 0, "value is not correctly aligned (strideH)"); - TORCH_CHECK( - p.causal_diagonal_ptr == nullptr || p.custom_mask_type != NoCustomMask, - "`causal_diagonal_ptr` is only useful when `custom_mask_type` is causal"); TORCH_CHECK( p.custom_mask_type < NumCustomMaskTypes, "invalid value for `custom_mask_type`"); if (p.window_size > 0) { TORCH_CHECK( p.custom_mask_type == CausalFromTopLeft || - p.custom_mask_type == CausalFromBottomRight); + p.custom_mask_type == CausalFromBottomRight, + "custom_mask_type not supported"); } return true; } diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py index 81b85a4fe42f..88927e8bf7ce 100644 --- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py +++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py @@ -135,7 +135,8 @@ ("aten::batch_norm_backward_elemt.out", datetime.date(2023, 12, 31)), ("aten::batch_norm_backward_elemt", datetime.date(2023, 12, 31)), ("aten::sym_constrain_range", datetime.date(2023, 12, 31)), - ("aten::_efficient_attention_forward", datetime.date(2024, 1, 15)), + ("aten::_efficient_attention_forward", datetime.date(2024, 7, 1)), + ("aten::_efficient_attention_backward", datetime.date(2024, 7, 1)), ("onednn::qconv1d_pointwise", datetime.date(2024, 12, 31)), ("onednn::qconv2d_pointwise", datetime.date(2024, 12, 31)), ("onednn::qconv3d_pointwise", datetime.date(2024, 12, 31)), diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 4922513f295d..1e9b9091a20e 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -2820,7 +2820,7 @@ output_differentiability: [True, False, False, False, False] query, key, value: _flash_attention_backward_symint(grad, query, key, value, output, softmax_logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale, window_size_left, window_size_right) -- name: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt? max_seqlen_q, SymInt? max_seqlen_k, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? causal_diagonal=None, Tensor? seqlen_k=None, int? window_size=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, SymInt max_seqlen_batch_q, SymInt max_seqlen_batch_k) +- name: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt? max_seqlen_q, SymInt? max_seqlen_k, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? seqlen_k=None, int? window_size=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, SymInt max_seqlen_batch_q, SymInt max_seqlen_batch_k) output_differentiability: [True, False, False, False, False, False] query, key, value, bias: _efficient_attention_backward_symint(grad, query, key, value, bias, output, cu_seqlens_q, cu_seqlens_k, max_seqlen_batch_q, max_seqlen_batch_k, logsumexp, dropout_p, philox_seed, philox_offset, custom_mask_type, bias.requires_grad(), scale) diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 7442ca9157e3..759870b4427d 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -5236,10 +5236,29 @@ def meta__efficient_attention_backward( bias_requires_grad: bool, scale: Optional[float] = None, num_splits_key: Optional[int] = None, + shared_storage_dqdkdv: bool = False, ): - grad_query = torch.empty_like(query) - grad_key = torch.empty_like(key) - grad_value = torch.empty_like(value) + if shared_storage_dqdkdv: + torch._check( + query.shape[1] == key.shape[1], + lambda: "seqlen must match for `shared_storage_dqdkdv", + ) + torch._check( + query.shape[3] == key.shape[3], + lambda: "embedding dim must match for `shared_storage_dqdkdv", + ) + chunk = torch.empty( + (*query.shape[0:-2], 3, query.shape[-2], query.shape[-1]), + dtype=query.dtype, + device=query.device, + ) + grad_query = chunk.select(-3, 0) + grad_key = chunk.select(-3, 1) + grad_value = chunk.select(-3, 2) + else: + grad_query = torch.empty_like(query) + grad_key = torch.empty_like(key) + grad_value = torch.empty_like(value) if bias is not None: lastDim = bias.size(-1) diff --git a/torch/_subclasses/fake_impls.py b/torch/_subclasses/fake_impls.py index 2b1cf13cc935..90b2c878ab2a 100644 --- a/torch/_subclasses/fake_impls.py +++ b/torch/_subclasses/fake_impls.py @@ -903,7 +903,7 @@ def meta__efficient_attention_forward(fake_mode, func, *args, **kwargs): max_seqlen_q = kwargs["max_seqlen_q"] max_seqlen_k = kwargs["max_seqlen_k"] compute_log_sumexp = kwargs["compute_log_sumexp"] - # unused: bias, cu_seqlens_k, dropout_p, custom_mask_type, scale, causal_diagonal, seqlen_k + # unused: bias, cu_seqlens_k, dropout_p, custom_mask_type, scale, seqlen_k def convert_tensor(t, device): return FakeTensor(fake_mode, t, device) diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h index a96cdaee5eb3..c973f69cb69d 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h @@ -19,8 +19,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__addmm_activation(AtenTensorHan AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__cdist_backward(AtenTensorHandle grad, AtenTensorHandle x1, AtenTensorHandle x2, double p, AtenTensorHandle cdist, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__cdist_forward(AtenTensorHandle x1, AtenTensorHandle x2, double p, int64_t* compute_mode, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__cudnn_rnn(AtenTensorHandle input, const AtenTensorHandle* weight, int64_t weight_len_, int64_t weight_stride0, AtenTensorHandle* weight_buf, AtenTensorHandle hx, AtenTensorHandle* cx, int64_t mode, int64_t hidden_size, int64_t proj_size, int64_t num_layers, int32_t batch_first, double dropout, int32_t train, int32_t bidirectional, const int64_t* batch_sizes, int64_t batch_sizes_len_, AtenTensorHandle* dropout_state, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4); -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__efficient_attention_backward(AtenTensorHandle grad_out_, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* bias, AtenTensorHandle out, AtenTensorHandle* cu_seqlens_q, AtenTensorHandle* cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, AtenTensorHandle logsumexp, double dropout_p, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, int64_t custom_mask_type, int32_t bias_requires_grad, double* scale, int64_t* num_splits_key, int64_t* window_size, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3); -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__efficient_attention_forward(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* bias, AtenTensorHandle* cu_seqlens_q, AtenTensorHandle* cu_seqlens_k, int64_t* max_seqlen_q, int64_t* max_seqlen_k, double dropout_p, int64_t custom_mask_type, int32_t compute_log_sumexp, double* scale, AtenTensorHandle* causal_diagonal, AtenTensorHandle* seqlen_k, int64_t* window_size, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__efficient_attention_backward(AtenTensorHandle grad_out_, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* bias, AtenTensorHandle out, AtenTensorHandle* cu_seqlens_q, AtenTensorHandle* cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, AtenTensorHandle logsumexp, double dropout_p, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, int64_t custom_mask_type, int32_t bias_requires_grad, double* scale, int64_t* num_splits_key, int64_t* window_size, int32_t shared_storage_dqdkdv, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__efficient_attention_forward(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* bias, AtenTensorHandle* cu_seqlens_q, AtenTensorHandle* cu_seqlens_k, int64_t* max_seqlen_q, int64_t* max_seqlen_k, double dropout_p, int64_t custom_mask_type, int32_t compute_log_sumexp, double* scale, AtenTensorHandle* seqlen_k, int64_t* window_size, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__efficientzerotensor(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__embedding_bag(AtenTensorHandle weight, AtenTensorHandle indices, AtenTensorHandle offsets, int32_t scale_grad_by_freq, int64_t mode, int32_t sparse, AtenTensorHandle* per_sample_weights, int32_t include_last_offset, int64_t padding_idx, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__embedding_bag_dense_backward(AtenTensorHandle grad, AtenTensorHandle indices, AtenTensorHandle offset2bag, AtenTensorHandle bag_size, AtenTensorHandle maximum_indices, int64_t num_weights, int32_t scale_grad_by_freq, int64_t mode, AtenTensorHandle* per_sample_weights, int64_t padding_idx, AtenTensorHandle* ret0); diff --git a/torch/nn/attention/bias.py b/torch/nn/attention/bias.py index d54ed8915789..c7f6b41d660c 100644 --- a/torch/nn/attention/bias.py +++ b/torch/nn/attention/bias.py @@ -249,7 +249,6 @@ def _dispatch( custom_mask_type=int(attn_mask.variant), compute_log_sumexp=compute_log_sumexp, scale=scale, - causal_diagonal=None, seqlen_k=None, )[0].transpose(1, 2) else: diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 50cfac763be5..832a75e6639a 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -8688,7 +8688,6 @@ def sample_inputs_efficient_attention_forward(op_info, device, dtype, requires_g custom_mask_type=mask_type, compute_log_sumexp=requires_grad, scale=scale, - causal_diagonal=None, seqlen_k=None )) @@ -8706,7 +8705,6 @@ def sample_inputs_efficient_attention_forward(op_info, device, dtype, requires_g custom_mask_type=0, # No Mask compute_log_sumexp=requires_grad, scale=None, - causal_diagonal=None, seqlen_k=None ) @@ -8725,7 +8723,6 @@ def sample_inputs_efficient_attention_forward(op_info, device, dtype, requires_g custom_mask_type=0, # No Mask compute_log_sumexp=requires_grad, scale=None, - causal_diagonal=None, seqlen_k=None ) ) @@ -8748,7 +8745,6 @@ def sample_inputs_efficient_attention_forward(op_info, device, dtype, requires_g custom_mask_type=0, # No Mask compute_log_sumexp=requires_grad, scale=None, - causal_diagonal=None, seqlen_k=None, ) ) From bb2de3b10120f91afce8da6233094076713f673d Mon Sep 17 00:00:00 2001 From: ibartol Date: Wed, 5 Jun 2024 07:37:29 +0000 Subject: [PATCH 358/706] Fixed broken link and removed unfinished sentence from issue #126367 (#127938) Fixes #126367. ## Description Fixed a broken link in the pytorch/docs/source/torch.compiler_faq.rst doc and deleted a few words that were extra according to the issue tagged above. ## Checklist - [X] The issue that is being fixed is referred in the description - [X] Only one issue is addressed in this pull request - [X] Labels from the issue that this PR is fixing are added to this pull request - [X] No unnecesary issues are included into this pull request Pull Request resolved: https://github.com/pytorch/pytorch/pull/127938 Approved by: https://github.com/msaroufim --- docs/source/torch.compiler_faq.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/torch.compiler_faq.rst b/docs/source/torch.compiler_faq.rst index aeaf308ac090..a5883ce015be 100644 --- a/docs/source/torch.compiler_faq.rst +++ b/docs/source/torch.compiler_faq.rst @@ -37,7 +37,7 @@ backwards ops, due to how AOTAutograd compiled functions interact with dispatcher hooks. The basic strategy for optimizing DDP with Dynamo is outlined in -`distributed.py `__ +`distributed.py `__ where the main idea will be to graph break on `DDP bucket boundaries `__. @@ -186,7 +186,7 @@ The above are general principles for accelerating PyTorch code but different backends will each make different tradeoffs on what to optimize. For example Inductor first takes care of fusing whatever it can and only then generates `Triton `__ -kernels. It can also +kernels. Triton in addition offers speedups because of automatic memory coalescing, memory management and scheduling within each Streaming From 9a8ab778d34bd24c5caceb340837483decc4c311 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 5 Jun 2024 08:59:53 +0000 Subject: [PATCH 359/706] Revert "[BE]: Update cudnn to 9.1.0.70 (#123475)" This reverts commit c490046693e77e254664e19d940e9b05a1da18ef. Reverted https://github.com/pytorch/pytorch/pull/123475 on behalf of https://github.com/huydhn due to CUDA trunk jobs are pretty red after this change, and the forward fix https://github.com/pytorch/pytorch/pull/127984 does not look working ([comment](https://github.com/pytorch/pytorch/pull/123475#issuecomment-2149258430)) --- .ci/docker/build.sh | 50 +++++++++---------- .ci/docker/common/install_base.sh | 2 +- .ci/docker/common/install_cudnn.sh | 17 ++++--- .ci/docker/ubuntu-cuda/Dockerfile | 2 +- .../scripts/generate_binary_build_matrix.py | 8 +-- .github/workflows/docker-builds.yml | 18 +++---- ...linux-aarch64-binary-manywheel-nightly.yml | 10 ++-- .../generated-linux-binary-manywheel-main.yml | 6 +-- ...nerated-linux-binary-manywheel-nightly.yml | 30 +++++------ ...d-linux-s390x-binary-manywheel-nightly.yml | 10 ++-- ...rated-macos-arm64-binary-wheel-nightly.yml | 10 ++-- ...generated-windows-binary-wheel-nightly.yml | 40 +++++++-------- .../workflows/inductor-micro-benchmark.yml | 2 +- .github/workflows/inductor-perf-compare.yml | 2 +- .../workflows/inductor-perf-test-nightly.yml | 2 +- .github/workflows/inductor-periodic.yml | 2 +- .github/workflows/inductor.yml | 12 ++--- .github/workflows/lint.yml | 4 +- .github/workflows/periodic.yml | 4 +- .github/workflows/pull.yml | 24 ++++----- .github/workflows/slow.yml | 4 +- .../target-determination-indexer.yml | 2 +- .github/workflows/torchbench.yml | 2 +- .github/workflows/trunk.yml | 12 ++--- .../aot_eager_timm_training.csv | 2 +- .../dynamic_inductor_torchbench_training.csv | 2 +- .../cu124/inductor_torchbench_training.csv | 2 +- .../dynamic_aot_eager_timm_training.csv | 2 +- docker.Makefile | 2 +- 29 files changed, 145 insertions(+), 140 deletions(-) diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index 537b0b9d2ba7..fa4dbf2b0165 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -91,9 +91,9 @@ _UCC_COMMIT=20eae37090a4ce1b32bcce6144ccad0b49943e0b # configuration, so we hardcode everything here rather than do it # from scratch case "$image" in - pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9) + pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9) CUDA_VERSION=12.4.0 - CUDNN_VERSION=9 + CUDNN_VERSION=8 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=9 PROTOBUF=yes @@ -105,9 +105,9 @@ case "$image" in CONDA_CMAKE=yes TRITON=yes ;; - pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9) + pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9) CUDA_VERSION=12.1.1 - CUDNN_VERSION=9 + CUDNN_VERSION=8 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=9 PROTOBUF=yes @@ -119,9 +119,9 @@ case "$image" in CONDA_CMAKE=yes TRITON=yes ;; - pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9-inductor-benchmarks) + pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9-inductor-benchmarks) CUDA_VERSION=12.4.0 - CUDNN_VERSION=9 + CUDNN_VERSION=8 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=9 PROTOBUF=yes @@ -134,9 +134,9 @@ case "$image" in TRITON=yes INDUCTOR_BENCHMARKS=yes ;; - pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks) + pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9-inductor-benchmarks) CUDA_VERSION=12.1.1 - CUDNN_VERSION=9 + CUDNN_VERSION=8 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=9 PROTOBUF=yes @@ -149,9 +149,9 @@ case "$image" in TRITON=yes INDUCTOR_BENCHMARKS=yes ;; - pytorch-linux-focal-cuda12.1-cudnn9-py3.12-gcc9-inductor-benchmarks) + pytorch-linux-focal-cuda12.1-cudnn8-py3.12-gcc9-inductor-benchmarks) CUDA_VERSION=12.1.1 - CUDNN_VERSION=9 + CUDNN_VERSION=8 ANACONDA_PYTHON_VERSION=3.12 GCC_VERSION=9 PROTOBUF=yes @@ -164,9 +164,9 @@ case "$image" in TRITON=yes INDUCTOR_BENCHMARKS=yes ;; - pytorch-linux-focal-cuda12.4-cudnn9-py3.12-gcc9-inductor-benchmarks) + pytorch-linux-focal-cuda12.4-cudnn8-py3.12-gcc9-inductor-benchmarks) CUDA_VERSION=12.4.0 - CUDNN_VERSION=9 + CUDNN_VERSION=8 ANACONDA_PYTHON_VERSION=3.12 GCC_VERSION=9 PROTOBUF=yes @@ -179,9 +179,9 @@ case "$image" in TRITON=yes INDUCTOR_BENCHMARKS=yes ;; - pytorch-linux-focal-cuda11.8-cudnn9-py3-gcc9) + pytorch-linux-focal-cuda11.8-cudnn8-py3-gcc9) CUDA_VERSION=11.8.0 - CUDNN_VERSION=9 + CUDNN_VERSION=8 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=9 PROTOBUF=yes @@ -193,9 +193,9 @@ case "$image" in CONDA_CMAKE=yes TRITON=yes ;; - pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9) + pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9) CUDA_VERSION=12.4.0 - CUDNN_VERSION=9 + CUDNN_VERSION=8 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=9 PROTOBUF=yes @@ -207,9 +207,9 @@ case "$image" in CONDA_CMAKE=yes TRITON=yes ;; - pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9) + pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9) CUDA_VERSION=12.1.1 - CUDNN_VERSION=9 + CUDNN_VERSION=8 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=9 PROTOBUF=yes @@ -221,9 +221,9 @@ case "$image" in CONDA_CMAKE=yes TRITON=yes ;; - pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9) + pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9) CUDA_VERSION=12.4.0 - CUDNN_VERSION=9 + CUDNN_VERSION=8 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=9 PROTOBUF=yes @@ -330,10 +330,10 @@ case "$image" in DOCS=yes INDUCTOR_BENCHMARKS=yes ;; - pytorch-linux-jammy-cuda11.8-cudnn9-py3.8-clang12) + pytorch-linux-jammy-cuda11.8-cudnn8-py3.8-clang12) ANACONDA_PYTHON_VERSION=3.8 CUDA_VERSION=11.8 - CUDNN_VERSION=9 + CUDNN_VERSION=8 CLANG_VERSION=12 PROTOBUF=yes DB=yes @@ -380,7 +380,7 @@ case "$image" in ANACONDA_PYTHON_VERSION=3.9 CONDA_CMAKE=yes ;; - pytorch-linux-jammy-cuda11.8-cudnn9-py3.9-linter) + pytorch-linux-jammy-cuda11.8-cudnn8-py3.9-linter) ANACONDA_PYTHON_VERSION=3.9 CUDA_VERSION=11.8 CONDA_CMAKE=yes @@ -447,7 +447,7 @@ tmp_tag=$(basename "$(mktemp -u)" | tr '[:upper:]' '[:lower:]') #when using cudnn version 8 install it separately from cuda if [[ "$image" == *cuda* && ${OS} == "ubuntu" ]]; then IMAGE_NAME="nvidia/cuda:${CUDA_VERSION}-cudnn${CUDNN_VERSION}-devel-ubuntu${UBUNTU_VERSION}" - if [[ ${CUDNN_VERSION} == 9 ]]; then + if [[ ${CUDNN_VERSION} == 8 ]]; then IMAGE_NAME="nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION}" fi fi @@ -499,7 +499,7 @@ docker build \ "$@" \ . -# NVIDIA dockers for RC releases use tag names like `11.0-cudnn9-devel-ubuntu18.04-rc`, +# NVIDIA dockers for RC releases use tag names like `11.0-cudnn8-devel-ubuntu18.04-rc`, # for this case we will set UBUNTU_VERSION to `18.04-rc` so that the Dockerfile could # find the correct image. As a result, here we have to replace the # "$UBUNTU_VERSION" == "18.04-rc" diff --git a/.ci/docker/common/install_base.sh b/.ci/docker/common/install_base.sh index fd58ad8a60b8..ebaa17878ade 100755 --- a/.ci/docker/common/install_base.sh +++ b/.ci/docker/common/install_base.sh @@ -3,7 +3,7 @@ set -ex install_ubuntu() { - # NVIDIA dockers for RC releases use tag names like `11.0-cudnn9-devel-ubuntu18.04-rc`, + # NVIDIA dockers for RC releases use tag names like `11.0-cudnn8-devel-ubuntu18.04-rc`, # for this case we will set UBUNTU_VERSION to `18.04-rc` so that the Dockerfile could # find the correct image. As a result, here we have to check for # "$UBUNTU_VERSION" == "18.04"* diff --git a/.ci/docker/common/install_cudnn.sh b/.ci/docker/common/install_cudnn.sh index 60f4561d420c..3afd2f28841f 100644 --- a/.ci/docker/common/install_cudnn.sh +++ b/.ci/docker/common/install_cudnn.sh @@ -1,18 +1,23 @@ #!/bin/bash -if [[ -n "${CUDNN_VERSION}" ]]; then +if [[ ${CUDNN_VERSION} == 8 ]]; then # cuDNN license: https://developer.nvidia.com/cudnn/license_agreement mkdir tmp_cudnn pushd tmp_cudnn - if [[ ${CUDA_VERSION:0:2} == "12" ]]; then - CUDNN_NAME="cudnn-linux-x86_64-9.1.0.70_cuda12-archive" - elif [[ ${CUDA_VERSION:0:2} == "11" ]]; then - CUDNN_NAME="cudnn-linux-x86_64-9.1.0.70_cuda11-archive" + if [[ ${CUDA_VERSION:0:4} == "12.4" ]]; then + CUDNN_NAME="cudnn-linux-x86_64-8.9.7.29_cuda12-archive" + curl --retry 3 -OLs https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/linux-x86_64/${CUDNN_NAME}.tar.xz + elif [[ ${CUDA_VERSION:0:4} == "12.1" ]]; then + CUDNN_NAME="cudnn-linux-x86_64-8.9.2.26_cuda12-archive" + curl --retry 3 -OLs https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/linux-x86_64/${CUDNN_NAME}.tar.xz + elif [[ ${CUDA_VERSION:0:4} == "11.8" ]]; then + CUDNN_NAME="cudnn-linux-x86_64-8.7.0.84_cuda11-archive" + curl --retry 3 -OLs https://developer.download.nvidia.com/compute/redist/cudnn/v8.7.0/local_installers/11.8/${CUDNN_NAME}.tar.xz else print "Unsupported CUDA version ${CUDA_VERSION}" exit 1 fi - curl --retry 3 -OLs https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/linux-x86_64/${CUDNN_NAME}.tar.xz + tar xf ${CUDNN_NAME}.tar.xz cp -a ${CUDNN_NAME}/include/* /usr/local/cuda/include/ cp -a ${CUDNN_NAME}/lib/* /usr/local/cuda/lib64/ diff --git a/.ci/docker/ubuntu-cuda/Dockerfile b/.ci/docker/ubuntu-cuda/Dockerfile index 3b2bbea0097a..cb3ea502d231 100644 --- a/.ci/docker/ubuntu-cuda/Dockerfile +++ b/.ci/docker/ubuntu-cuda/Dockerfile @@ -139,7 +139,7 @@ COPY --from=pytorch/llvm:9.0.1 /opt/llvm /opt/llvm ARG CUDNN_VERSION ARG CUDA_VERSION COPY ./common/install_cudnn.sh install_cudnn.sh -RUN if [ -n "${CUDNN_VERSION}" ]; then bash install_cudnn.sh; fi +RUN if [ "${CUDNN_VERSION}" -eq 8 ]; then bash install_cudnn.sh; fi RUN rm install_cudnn.sh # Install CUSPARSELT diff --git a/.github/scripts/generate_binary_build_matrix.py b/.github/scripts/generate_binary_build_matrix.py index 3e50cb7930fa..b192475f72b1 100644 --- a/.github/scripts/generate_binary_build_matrix.py +++ b/.github/scripts/generate_binary_build_matrix.py @@ -19,7 +19,7 @@ CUDA_ARCHES_FULL_VERSION = {"11.8": "11.8.0", "12.1": "12.1.1", "12.4": "12.4.0"} -CUDA_ARCHES_CUDNN_VERSION = {"11.8": "9", "12.1": "9", "12.4": "9"} +CUDA_ARCHES_CUDNN_VERSION = {"11.8": "8", "12.1": "8", "12.4": "8"} ROCM_ARCHES = ["6.0", "6.1"] @@ -42,7 +42,7 @@ "nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | " # noqa: B950 "nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | " - "nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cudnn-cu11==8.7.0.84; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | " @@ -55,7 +55,7 @@ "nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | " # noqa: B950 "nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | " - "nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | " @@ -68,7 +68,7 @@ "nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | " - "nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cudnn-cu12==8.9.7.29; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | " diff --git a/.github/workflows/docker-builds.yml b/.github/workflows/docker-builds.yml index f732dab42050..0eec1556bb96 100644 --- a/.github/workflows/docker-builds.yml +++ b/.github/workflows/docker-builds.yml @@ -38,19 +38,19 @@ jobs: matrix: runner: [linux.12xlarge] docker-image-name: [ - pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9, - pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9-inductor-benchmarks, - pytorch-linux-focal-cuda12.4-cudnn9-py3.12-gcc9-inductor-benchmarks, - pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9, - pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks, - pytorch-linux-focal-cuda12.1-cudnn9-py3.12-gcc9-inductor-benchmarks, - pytorch-linux-focal-cuda11.8-cudnn9-py3-gcc9, + pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9, + pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9-inductor-benchmarks, + pytorch-linux-focal-cuda12.4-cudnn8-py3.12-gcc9-inductor-benchmarks, + pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9, + pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9-inductor-benchmarks, + pytorch-linux-focal-cuda12.1-cudnn8-py3.12-gcc9-inductor-benchmarks, + pytorch-linux-focal-cuda11.8-cudnn8-py3-gcc9, pytorch-linux-focal-py3.8-clang10, pytorch-linux-focal-py3.11-clang10, pytorch-linux-focal-py3.12-clang10, pytorch-linux-focal-rocm-n-1-py3, pytorch-linux-focal-rocm-n-py3, - pytorch-linux-jammy-cuda11.8-cudnn9-py3.8-clang12, + pytorch-linux-jammy-cuda11.8-cudnn8-py3.8-clang12, pytorch-linux-focal-py3-clang9-android-ndk-r21e, pytorch-linux-jammy-py3.8-gcc11, pytorch-linux-jammy-py3.8-gcc11-inductor-benchmarks, @@ -58,7 +58,7 @@ jobs: pytorch-linux-jammy-py3-clang15-asan, pytorch-linux-focal-py3-clang10-onnx, pytorch-linux-focal-linter, - pytorch-linux-jammy-cuda11.8-cudnn9-py3.9-linter, + pytorch-linux-jammy-cuda11.8-cudnn8-py3.9-linter, pytorch-linux-jammy-py3-clang12-executorch ] include: diff --git a/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml index a1a7e6fd9537..726dbf40f985 100644 --- a/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml @@ -54,7 +54,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_8-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_8-cpu-aarch64-test: # Testing @@ -162,7 +162,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_9-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_9-cpu-aarch64-test: # Testing @@ -270,7 +270,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_10-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cpu-aarch64-test: # Testing @@ -378,7 +378,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_11-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-cpu-aarch64-test: # Testing @@ -486,7 +486,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_12-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-cpu-aarch64-test: # Testing diff --git a/.github/workflows/generated-linux-binary-manywheel-main.yml b/.github/workflows/generated-linux-binary-manywheel-main.yml index 053877b1c90e..6e7edae7b613 100644 --- a/.github/workflows/generated-linux-binary-manywheel-main.yml +++ b/.github/workflows/generated-linux-binary-manywheel-main.yml @@ -48,7 +48,7 @@ jobs: DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cuda11_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==8.7.0.84; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_8-cuda11_8-test: # Testing @@ -88,7 +88,7 @@ jobs: DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cuda12_1 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_8-cuda12_1-test: # Testing @@ -128,7 +128,7 @@ jobs: DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cuda12_4 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.7.29; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_8-cuda12_4-test: # Testing diff --git a/.github/workflows/generated-linux-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-binary-manywheel-nightly.yml index 03e3e3f4db20..8ad43b4c3660 100644 --- a/.github/workflows/generated-linux-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-binary-manywheel-nightly.yml @@ -174,7 +174,7 @@ jobs: DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cuda11_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==8.7.0.84; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_8-cuda11_8-test: # Testing @@ -237,7 +237,7 @@ jobs: DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cuda12_1 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_8-cuda12_1-test: # Testing @@ -300,7 +300,7 @@ jobs: DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cuda12_4 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.7.29; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_8-cuda12_4-test: # Testing @@ -690,7 +690,7 @@ jobs: DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda11_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==8.7.0.84; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_9-cuda11_8-test: # Testing @@ -753,7 +753,7 @@ jobs: DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_1 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_9-cuda12_1-test: # Testing @@ -816,7 +816,7 @@ jobs: DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_4 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.7.29; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_9-cuda12_4-test: # Testing @@ -1206,7 +1206,7 @@ jobs: DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda11_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==8.7.0.84; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cuda11_8-test: # Testing @@ -1269,7 +1269,7 @@ jobs: DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda12_1 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cuda12_1-test: # Testing @@ -1332,7 +1332,7 @@ jobs: DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda12_4 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.7.29; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cuda12_4-test: # Testing @@ -1722,7 +1722,7 @@ jobs: DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda11_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==8.7.0.84; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-cuda11_8-test: # Testing @@ -1785,7 +1785,7 @@ jobs: DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda12_1 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-cuda12_1-test: # Testing @@ -1848,7 +1848,7 @@ jobs: DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda12_4 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.7.29; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-cuda12_4-test: # Testing @@ -2238,7 +2238,7 @@ jobs: DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda11_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==8.7.0.84; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-cuda11_8-test: # Testing @@ -2301,7 +2301,7 @@ jobs: DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda12_1 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-cuda12_1-test: # Testing @@ -2364,7 +2364,7 @@ jobs: DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda12_4 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.7.29; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-cuda12_4-test: # Testing diff --git a/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml index db0748463da5..4f0569c253f2 100644 --- a/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml @@ -54,7 +54,7 @@ jobs: ALPINE_IMAGE: "docker.io/s390x/alpine" build_name: manywheel-py3_8-cpu-s390x build_environment: linux-s390x-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_8-cpu-s390x-test: # Testing @@ -117,7 +117,7 @@ jobs: ALPINE_IMAGE: "docker.io/s390x/alpine" build_name: manywheel-py3_9-cpu-s390x build_environment: linux-s390x-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_9-cpu-s390x-test: # Testing @@ -180,7 +180,7 @@ jobs: ALPINE_IMAGE: "docker.io/s390x/alpine" build_name: manywheel-py3_10-cpu-s390x build_environment: linux-s390x-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cpu-s390x-test: # Testing @@ -243,7 +243,7 @@ jobs: ALPINE_IMAGE: "docker.io/s390x/alpine" build_name: manywheel-py3_11-cpu-s390x build_environment: linux-s390x-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-cpu-s390x-test: # Testing @@ -306,7 +306,7 @@ jobs: ALPINE_IMAGE: "docker.io/s390x/alpine" build_name: manywheel-py3_12-cpu-s390x build_environment: linux-s390x-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-cpu-s390x-test: # Testing diff --git a/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml b/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml index b4910d46ed5e..94a8fd9cd3de 100644 --- a/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml +++ b/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml @@ -46,7 +46,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.8" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' # For sccache access (only on non-forked PRs) AWS_ACCESS_KEY_ID: ${{ secrets.MACOS_SCCACHE_S3_ACCESS_KEY_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets.MACOS_SCCACHE_S3_SECRET_ACCESS_KEY }} @@ -165,7 +165,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.9" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' # For sccache access (only on non-forked PRs) AWS_ACCESS_KEY_ID: ${{ secrets.MACOS_SCCACHE_S3_ACCESS_KEY_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets.MACOS_SCCACHE_S3_SECRET_ACCESS_KEY }} @@ -284,7 +284,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' # For sccache access (only on non-forked PRs) AWS_ACCESS_KEY_ID: ${{ secrets.MACOS_SCCACHE_S3_ACCESS_KEY_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets.MACOS_SCCACHE_S3_SECRET_ACCESS_KEY }} @@ -403,7 +403,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' # For sccache access (only on non-forked PRs) AWS_ACCESS_KEY_ID: ${{ secrets.MACOS_SCCACHE_S3_ACCESS_KEY_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets.MACOS_SCCACHE_S3_SECRET_ACCESS_KEY }} @@ -522,7 +522,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' # For sccache access (only on non-forked PRs) AWS_ACCESS_KEY_ID: ${{ secrets.MACOS_SCCACHE_S3_ACCESS_KEY_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets.MACOS_SCCACHE_S3_SECRET_ACCESS_KEY }} diff --git a/.github/workflows/generated-windows-binary-wheel-nightly.yml b/.github/workflows/generated-windows-binary-wheel-nightly.yml index d06f99bd9a5a..d64c221e7895 100644 --- a/.github/workflows/generated-windows-binary-wheel-nightly.yml +++ b/.github/workflows/generated-windows-binary-wheel-nightly.yml @@ -46,7 +46,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.8" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -290,7 +290,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.8" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -536,7 +536,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.8" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -782,7 +782,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.8" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -1027,7 +1027,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.9" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -1271,7 +1271,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.9" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -1517,7 +1517,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.9" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -1763,7 +1763,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.9" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -2008,7 +2008,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -2252,7 +2252,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -2498,7 +2498,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -2744,7 +2744,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -2989,7 +2989,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -3233,7 +3233,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -3479,7 +3479,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -3725,7 +3725,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -3970,7 +3970,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -4214,7 +4214,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -4460,7 +4460,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -4706,7 +4706,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash diff --git a/.github/workflows/inductor-micro-benchmark.yml b/.github/workflows/inductor-micro-benchmark.yml index 431545ea6d0d..4fe0ddf50ef2 100644 --- a/.github/workflows/inductor-micro-benchmark.yml +++ b/.github/workflows/inductor-micro-benchmark.yml @@ -21,7 +21,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm80 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.0' test-matrix: | { include: [ diff --git a/.github/workflows/inductor-perf-compare.yml b/.github/workflows/inductor-perf-compare.yml index a5e4ad1781aa..e485a8bfce1b 100644 --- a/.github/workflows/inductor-perf-compare.yml +++ b/.github/workflows/inductor-perf-compare.yml @@ -18,7 +18,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm80 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.0' test-matrix: | { include: [ diff --git a/.github/workflows/inductor-perf-test-nightly.yml b/.github/workflows/inductor-perf-test-nightly.yml index 2f129c52fe13..e77c915749f3 100644 --- a/.github/workflows/inductor-perf-test-nightly.yml +++ b/.github/workflows/inductor-perf-test-nightly.yml @@ -71,7 +71,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm80 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.0' test-matrix: | { include: [ diff --git a/.github/workflows/inductor-periodic.yml b/.github/workflows/inductor-periodic.yml index 2fe649cebb5e..6f8c06ed030b 100644 --- a/.github/workflows/inductor-periodic.yml +++ b/.github/workflows/inductor-periodic.yml @@ -23,7 +23,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm86 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.6' test-matrix: | { include: [ diff --git a/.github/workflows/inductor.yml b/.github/workflows/inductor.yml index 7ce641761f2e..0f9c81104f9f 100644 --- a/.github/workflows/inductor.yml +++ b/.github/workflows/inductor.yml @@ -44,7 +44,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm86 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.6' test-matrix: | { include: [ @@ -86,7 +86,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm80 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.0' test-matrix: | { include: [ @@ -112,7 +112,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3.12-gcc9-sm86 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3.12-gcc9-inductor-benchmarks + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3.12-gcc9-inductor-benchmarks cuda-arch-list: '8.6' test-matrix: | { include: [ @@ -133,7 +133,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm86 - docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9-inductor-benchmarks + docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.6' test-matrix: | { include: [ @@ -175,7 +175,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm80 - docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9-inductor-benchmarks + docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.0' test-matrix: | { include: [ @@ -189,7 +189,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.4-py3.12-gcc9-sm86 - docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3.12-gcc9-inductor-benchmarks + docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3.12-gcc9-inductor-benchmarks cuda-arch-list: '8.6' test-matrix: | { include: [ diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index e0e4d3c20cd8..f1b6611d00e0 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -20,7 +20,7 @@ jobs: with: timeout: 120 runner: linux.2xlarge - docker-image: pytorch-linux-jammy-cuda11.8-cudnn9-py3.9-linter + docker-image: pytorch-linux-jammy-cuda11.8-cudnn8-py3.9-linter # NB: A shallow checkout won't work here because calculate-docker-image requires a full checkout # to run git rev-parse HEAD~:.ci/docker when a new image is needed fetch-depth: 0 @@ -36,7 +36,7 @@ jobs: with: timeout: 120 runner: linux.2xlarge - docker-image: pytorch-linux-jammy-cuda11.8-cudnn9-py3.9-linter + docker-image: pytorch-linux-jammy-cuda11.8-cudnn8-py3.9-linter # NB: A shallow checkout won't work here because calculate-docker-image requires a full checkout # to run git rev-parse HEAD~:.ci/docker when a new image is needed fetch-depth: 0 diff --git a/.github/workflows/periodic.yml b/.github/workflows/periodic.yml index b2f404da6d65..925bca54c074 100644 --- a/.github/workflows/periodic.yml +++ b/.github/workflows/periodic.yml @@ -67,7 +67,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda11.8-py3.9-gcc9 - docker-image-name: pytorch-linux-focal-cuda11.8-cudnn9-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda11.8-cudnn8-py3-gcc9 cuda-arch-list: 8.6 test-matrix: | { include: [ @@ -89,7 +89,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda11.8-py3.10-gcc9-debug - docker-image-name: pytorch-linux-focal-cuda11.8-cudnn9-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda11.8-cudnn8-py3-gcc9 build-with-debug: true test-matrix: | { include: [ diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 71f1e11094e2..2b81e998bde5 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -237,7 +237,7 @@ jobs: uses: ./.github/workflows/_linux-build-label.yml with: build-environment: linux-focal-cuda11.8-py3.10-gcc9 - docker-image-name: pytorch-linux-focal-cuda11.8-cudnn9-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda11.8-cudnn8-py3-gcc9 test-matrix: | { include: [ { config: "distributed", shard: 1, num_shards: 3, runner: "linux.8xlarge.nvidia.gpu" }, @@ -262,7 +262,7 @@ jobs: uses: ./.github/workflows/_linux-build-label.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9 test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" }, @@ -290,7 +290,7 @@ jobs: uses: ./.github/workflows/_linux-build-label.yml with: build-environment: linux-focal-cuda12.4-py3.10-gcc9 - docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9 test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" }, @@ -325,12 +325,12 @@ jobs: { config: "default", shard: 1, num_shards: 1 }, ]} - linux-jammy-cuda-11_8-cudnn9-py3_8-clang12-build: - name: linux-jammy-cuda11.8-cudnn9-py3.8-clang12 + linux-jammy-cuda-11_8-cudnn8-py3_8-clang12-build: + name: linux-jammy-cuda11.8-cudnn8-py3.8-clang12 uses: ./.github/workflows/_linux-build-label.yml with: - build-environment: linux-jammy-cuda11.8-cudnn9-py3.8-clang12 - docker-image-name: pytorch-linux-jammy-cuda11.8-cudnn9-py3.8-clang12 + build-environment: linux-jammy-cuda11.8-cudnn8-py3.8-clang12 + docker-image-name: pytorch-linux-jammy-cuda11.8-cudnn8-py3.8-clang12 test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 1 }, @@ -389,7 +389,7 @@ jobs: uses: ./.github/workflows/_bazel-build-test.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-bazel-test - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9 cuda-version: cpu test-matrix: | { include: [ @@ -401,7 +401,7 @@ jobs: uses: ./.github/workflows/_bazel-build-test.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-bazel-test - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9 cuda-version: "12.1" test-matrix: | { include: [ @@ -413,7 +413,7 @@ jobs: uses: ./.github/workflows/_bazel-build-test.yml with: build-environment: linux-focal-cuda12.4-py3.10-gcc9-bazel-test - docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9 cuda-version: "12.4" test-matrix: | { include: [ @@ -475,7 +475,7 @@ jobs: uses: ./.github/workflows/_linux-build-label.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm86 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9 cuda-arch-list: 8.6 test-matrix: | { include: [ @@ -502,7 +502,7 @@ jobs: uses: ./.github/workflows/_linux-build-label.yml with: build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm86 - docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9 cuda-arch-list: 8.6 test-matrix: | { include: [ diff --git a/.github/workflows/slow.yml b/.github/workflows/slow.yml index 50f74b01f08c..31db7af8fc55 100644 --- a/.github/workflows/slow.yml +++ b/.github/workflows/slow.yml @@ -41,7 +41,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3-gcc9-slow-gradcheck - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9 cuda-arch-list: 8.6 test-matrix: | { include: [ @@ -70,7 +70,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm86 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9 cuda-arch-list: 8.6 test-matrix: | { include: [ diff --git a/.github/workflows/target-determination-indexer.yml b/.github/workflows/target-determination-indexer.yml index e8bf91c8d9ee..0ce1bae6a413 100644 --- a/.github/workflows/target-determination-indexer.yml +++ b/.github/workflows/target-determination-indexer.yml @@ -26,7 +26,7 @@ jobs: id: calculate-docker-image uses: pytorch/test-infra/.github/actions/calculate-docker-image@main with: - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9 working-directory: pytorch - name: Use following to pull public copy of the image diff --git a/.github/workflows/torchbench.yml b/.github/workflows/torchbench.yml index ac5814966899..73befe34c078 100644 --- a/.github/workflows/torchbench.yml +++ b/.github/workflows/torchbench.yml @@ -16,7 +16,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm80 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.0' test-matrix: | { include: [ diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index e727445d3ada..a91238fa2c9b 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -39,7 +39,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9 test-matrix: | { include: [ { config: "nogpu_AVX512", shard: 1, num_shards: 1, runner: "linux.2xlarge" }, @@ -63,7 +63,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: libtorch-linux-focal-cuda12.1-py3.7-gcc9 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9 build-generates-artifacts: false runner: linux.4xlarge test-matrix: | @@ -77,7 +77,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-no-ops - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9 test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 1 }, @@ -88,7 +88,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.4-py3.10-gcc9 - docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9 test-matrix: | { include: [ { config: "nogpu_AVX512", shard: 1, num_shards: 1, runner: "linux.2xlarge" }, @@ -112,7 +112,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: libtorch-linux-focal-cuda12.4-py3.7-gcc9 - docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9 build-generates-artifacts: false runner: linux.4xlarge test-matrix: | @@ -126,7 +126,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.4-py3.10-gcc9-no-ops - docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9 test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 1 }, diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_timm_training.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_timm_training.csv index fe7efa082cea..1def1d99bd53 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_timm_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_timm_training.csv @@ -218,7 +218,7 @@ tf_mixnet_l,pass,6 -tinynet_a,fail_accuracy,6 +tinynet_a,pass,6 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_torchbench_training.csv index ee58808c0bb0..a3c9c3915fc5 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_torchbench_training.csv @@ -182,7 +182,7 @@ phlippe_densenet,pass,6 -phlippe_resnet,pass,6 +phlippe_resnet,fail_accuracy,6 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_torchbench_training.csv index cfc524426644..02411bef6cc5 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_torchbench_training.csv @@ -182,7 +182,7 @@ phlippe_densenet,pass,6 -phlippe_resnet,pass,6 +phlippe_resnet,fail_accuracy,6 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_timm_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_timm_training.csv index fe7efa082cea..1def1d99bd53 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_timm_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_timm_training.csv @@ -218,7 +218,7 @@ tf_mixnet_l,pass,6 -tinynet_a,fail_accuracy,6 +tinynet_a,pass,6 diff --git a/docker.Makefile b/docker.Makefile index 7f131707e7ab..a33c411907bc 100644 --- a/docker.Makefile +++ b/docker.Makefile @@ -10,7 +10,7 @@ endif CUDA_VERSION_SHORT ?= 12.1 CUDA_VERSION ?= 12.1.1 -CUDNN_VERSION ?= 9 +CUDNN_VERSION ?= 8 BASE_RUNTIME = ubuntu:22.04 BASE_DEVEL = nvidia/cuda:$(CUDA_VERSION)-devel-ubuntu22.04 CMAKE_VARS ?= From c3949b20a14875bf22934246404eb3cb497fb9cc Mon Sep 17 00:00:00 2001 From: weiyusheng Date: Wed, 5 Jun 2024 13:01:16 +0000 Subject: [PATCH 360/706] Opt model save and load (#126374) ## save&load support for OptimizedModule [Issue Description](https://github.com/pytorch/pytorch/pull/101651) English is not my native language; please excuse typing errors. This pr is based on commit b9588101c4d3411b107fdc860acfa8a72c642f91\ I'll do something with the merge conflicts later ### test result for test/dynamo Conclusion:\ It performs the same as before as far as I can see. ENV(CPU only):\ platform linux -- Python 3.10.14, pytest-7.3.2, pluggy-1.5.0\ configfile: pytest.ini\ plugins: anyio-3.7.1, cpp-2.3.0, flakefinder-1.1.0, xdist-3.3.1, xdoctest-1.1.0, metadata-3.1.1, html-4.1.1, hypothesis-5.35.1, rerunfailures-14.0 #### before this pr: [before](https://github.com/pytorch/pytorch/files/15329370/before.md) #### after this pr: [after](https://github.com/pytorch/pytorch/files/15329376/after.md) ### some changes 1. add test_save_and_load to test/dynamo/test_modules.py with & without "backend='inductor'" 2. add \_\_reduce\_\_ function to OptimizedModule and derived classes of _TorchDynamoContext for pickling & unpickling 3. change the wrappers into wrapper classes ( including convert_frame_assert, convert_frame, catch_errors_wrapper in torch/_dynamo/convert_frame.py & wrap_backend_debug in torch/_dynamo/repro/after_dynamo.py ) 4. change self.output.compiler_fn into innermost_fn(self.output.compiler_fn) in torch/_dynamo/symbolic_convert.py to get the origin compiler_fn and to avoid the "compiler_fn is not eager" condition Pull Request resolved: https://github.com/pytorch/pytorch/pull/126374 Approved by: https://github.com/msaroufim, https://github.com/jansel --- test/dynamo/test_modules.py | 46 +++++++++++ torch/_dynamo/backends/common.py | 30 ++++--- torch/_dynamo/convert_frame.py | 122 +++++++++++++++++++--------- torch/_dynamo/eval_frame.py | 30 ++++++- torch/_dynamo/repro/after_dynamo.py | 52 ++++++------ 5 files changed, 204 insertions(+), 76 deletions(-) diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py index b22f02ee2fcc..c38dc7c7b892 100644 --- a/test/dynamo/test_modules.py +++ b/test/dynamo/test_modules.py @@ -3,6 +3,8 @@ import collections import copy import itertools +import os +import tempfile import traceback import types import unittest @@ -16,6 +18,7 @@ import torch._dynamo.test_case import torch._dynamo.testing import torch.nn.functional as F +from torch._dynamo.debug_utils import same_two_models from torch._dynamo.eval_frame import unsupported from torch._dynamo.mutation_guard import GenerationTracker from torch._dynamo.testing import expectedFailureDynamic, same @@ -2739,6 +2742,49 @@ def fn(x): self.assertEqual(test_functions._variable, 1) self.assertEqual(res, 3 * torch.ones(10)) + @unittest.skipIf( + "inductor" not in torch._dynamo.list_backends(), + "inductor backend is not available", + ) + def test_save_and_load_inductor(self): + mod = MockModule() + opt_mod = torch.compile(mod, backend="inductor") + inp = torch.randn(10, 10) + opt_mod(inp) + + with tempfile.TemporaryDirectory() as tmpdirname: + torch.save(opt_mod, os.path.join(tmpdirname, "model.pt")) + loaded_model = torch.load(os.path.join(tmpdirname, "model.pt")) + loaded_model(inp) + self.assertTrue(same_two_models(loaded_model, mod, [inp])) + self.assertTrue(same_two_models(loaded_model, opt_mod, [inp])) + + torch._dynamo.reset() # force recompiles + torch._inductor.metrics.generated_kernel_count = 0 + loaded_model(inp) + self.assertGreater(torch._inductor.metrics.generated_kernel_count, 0) + + def test_save_and_load_all_backends(self): + mod = MockModule() + inp = torch.randn(10, 10) + for backend in torch._dynamo.list_backends(): + try: + opt_mod = torch.compile(mod, backend=backend) + with tempfile.TemporaryDirectory() as tmpdirname: + torch.save(opt_mod, os.path.join(tmpdirname, "model.pt")) + loaded_model = torch.load(os.path.join(tmpdirname, "model.pt")) + torch._dynamo.reset() # force recompiles + torch._inductor.metrics.generated_kernel_count = 0 + opt_mod(inp) + opt_success = torch._inductor.metrics.generated_kernel_count == 0 + torch._dynamo.reset() # force recompiles + torch._inductor.metrics.generated_kernel_count = 0 + loaded_model(inp) + loaded_success = torch._inductor.metrics.generated_kernel_count == 0 + self.assertEqual(opt_success, loaded_success) + except torch._dynamo.exc.BackendCompilerFailed: + pass + def test_monkeypatching_forward(self): class FakeModule(torch.nn.Module): def forward(self, x): diff --git a/torch/_dynamo/backends/common.py b/torch/_dynamo/backends/common.py index cf1204de1a5f..69e70198c7f5 100644 --- a/torch/_dynamo/backends/common.py +++ b/torch/_dynamo/backends/common.py @@ -14,18 +14,22 @@ log = logging.getLogger(__name__) -def aot_autograd(**kwargs): - def compiler_fn(gm: torch.fx.GraphModule, example_inputs): +class AotAutograd: + def __init__(self, **kwargs): + self.__name__ = "compiler_fn" + self.kwargs = kwargs + + def __call__(self, gm: torch.fx.GraphModule, example_inputs): if any(isinstance(x, (list, tuple, dict)) for x in example_inputs): return flatten_graph_inputs( gm, example_inputs, - compiler_fn, + self, ) # Hack to get around circular import problems with aot_eager_decomp_partition - if callable(kwargs.get("decompositions")): - kwargs["decompositions"] = kwargs["decompositions"]() + if callable(self.kwargs.get("decompositions")): + self.kwargs["decompositions"] = self.kwargs["decompositions"]() # NB: dont delete counter increment counters["aot_autograd"]["total"] += 1 @@ -42,10 +46,10 @@ def _wrapped_bw_compiler(*args, **kwargs): # stop TorchDynamo from trying to compile our generated backwards pass return disable(disable(bw_compiler)(*args, **kwargs)) - bw_compiler = kwargs.get("bw_compiler") or kwargs["fw_compiler"] - kwargs["bw_compiler"] = _wrapped_bw_compiler - kwargs["inference_compiler"] = ( - kwargs.get("inference_compiler") or kwargs["fw_compiler"] + bw_compiler = self.kwargs.get("bw_compiler") or self.kwargs["fw_compiler"] + self.kwargs["bw_compiler"] = _wrapped_bw_compiler + self.kwargs["inference_compiler"] = ( + self.kwargs.get("inference_compiler") or self.kwargs["fw_compiler"] ) from functorch.compile import nop @@ -54,7 +58,7 @@ def _wrapped_bw_compiler(*args, **kwargs): # debug asserts slow down compile time noticeably, # So only default them on when the aot_eager backend is used. - if kwargs.get("fw_compiler", None) == nop: + if self.kwargs.get("fw_compiler", None) == nop: patch_config = patch("functorch.compile.config.debug_assert", True) else: patch_config = contextlib.nullcontext() @@ -62,14 +66,16 @@ def _wrapped_bw_compiler(*args, **kwargs): try: # NB: NOT cloned! with enable_aot_logging(), patch_config: - cg = aot_module_simplified(gm, example_inputs, **kwargs) + cg = aot_module_simplified(gm, example_inputs, **self.kwargs) counters["aot_autograd"]["ok"] += 1 return disable(cg) except Exception: counters["aot_autograd"]["not_ok"] += 1 raise - return compiler_fn + +def aot_autograd(**kwargs): + return AotAutograd(**kwargs) def mem_efficient_fusion_kwargs(use_decomps): diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index e779ccef9e38..37ff5a8a299b 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -361,17 +361,34 @@ def profile_wrapper(*args, **kwargs): return profile_wrapper -def convert_frame_assert( - compiler_fn: CompilerFn, - one_graph: bool = True, - export: bool = False, - export_constraints=None, -): - """Fully convert a frame into an FX graph""" - reset_graph_break_dup_checker() +class ConvertFrameAssert: + def __init__( + self, + compiler_fn: CompilerFn, + one_graph: bool = True, + export: bool = False, + export_constraints=None, + ): + reset_graph_break_dup_checker() + self._torchdynamo_orig_callable = compiler_fn # type: ignore[attr-defined] + self._one_graph = one_graph + self._export = export + self._export_constraints = export_constraints + + @property + def _clone_with_backend(self): + return lambda backend: convert_frame_assert( + backend, self._one_graph, self._export, self._export_constraints + ) - def _convert_frame_assert( - frame: types.FrameType, cache_entry, hooks: Hooks, frame_state, *, skip: int = 0 + def __call__( + self, + frame: types.FrameType, + cache_entry, + hooks: Hooks, + frame_state, + *, + skip: int = 0, ): increment_frame() @@ -458,10 +475,10 @@ def _convert_frame_assert( frame.f_globals, frame.f_locals, frame.f_builtins, - compiler_fn, - one_graph, - export, - export_constraints, + self._torchdynamo_orig_callable, + self._one_graph, + self._export, + self._export_constraints, hooks, cache_entry, cache_size, @@ -471,13 +488,15 @@ def _convert_frame_assert( skip=skip + 1, ) - _convert_frame_assert._torchdynamo_orig_callable = compiler_fn # type: ignore[attr-defined] - def _clone_with_backend(backend): - return convert_frame_assert(backend, one_graph, export, export_constraints) - - _convert_frame_assert._clone_with_backend = _clone_with_backend # type: ignore[attr-defined] - return _convert_frame_assert +def convert_frame_assert( + compiler_fn: CompilerFn, + one_graph: bool = True, + export: bool = False, + export_constraints=None, +): + """Fully convert a frame into an FX graph""" + return ConvertFrameAssert(compiler_fn, one_graph, export, export_constraints) from collections import OrderedDict @@ -907,16 +926,27 @@ def format_guard_failures(): torch._dynamo.callback_handler.run_end_callbacks() -def convert_frame(compiler_fn: CompilerFn, hooks: Hooks): - """Try to convert a frame into an FX graph, if error leave frame unmodified""" - inner_convert = convert_frame_assert(compiler_fn, one_graph=False) +class ConvertFrame: + def __init__(self, compiler_fn: CompilerFn, hooks: Hooks): + self._torchdynamo_orig_callable = compiler_fn + self._inner_convert = convert_frame_assert(compiler_fn, one_graph=False) + self._hooks = hooks + + @property + def _clone_with_backend(self): + return lambda backend: convert_frame(backend, self._hooks) - def _convert_frame( - frame: types.FrameType, cache_entry, hooks: Hooks, frame_state, skip: int = 0 + def __call__( + self, + frame: types.FrameType, + cache_entry, + hooks: Hooks, + frame_state, + skip: int = 0, ): counters["frames"]["total"] += 1 try: - result = inner_convert( + result = self._inner_convert( frame, cache_entry, hooks, frame_state, skip=skip + 1 ) counters["frames"]["ok"] += 1 @@ -980,9 +1010,10 @@ def _convert_frame( log.warning(error_msg, exc_info=True) return None - _convert_frame._torchdynamo_orig_callable = compiler_fn # type: ignore[attr-defined] - _convert_frame._clone_with_backend = lambda backend: convert_frame(backend, hooks) # type: ignore[attr-defined] - return _convert_frame + +def convert_frame(compiler_fn: CompilerFn, hooks: Hooks): + """Try to convert a frame into an FX graph, if error leave frame unmodified""" + return ConvertFrame(compiler_fn, hooks) # TODO mlazos: add support for same args, or record them @@ -1023,9 +1054,13 @@ def first_real_inst_idx(code): raise RuntimeError("RESUME instruction not found in code") -def catch_errors_wrapper(callback, hooks: Hooks): - @functools.wraps(callback) - def catch_errors(frame, cache_entry, frame_state): +class CatchErrorsWrapper: + def __init__(self, callback, hooks): + functools.wraps(callback)(self) + self._torchdynamo_orig_callable = callback + self.hooks = hooks + + def __call__(self, frame, cache_entry, frame_state): assert frame_state is not None is_skipfile = trace_rules.check(frame.f_code) @@ -1063,19 +1098,26 @@ def catch_errors(frame, cache_entry, frame_state): ddp_optimizer = DDPOptimizer( bucket_bytes_cap=ddp_module.bucket_bytes_cap, - backend_compile_fn=callback._torchdynamo_orig_callable, + backend_compile_fn=self._torchdynamo_orig_callable._torchdynamo_orig_callable, ) assert hasattr( - callback, "_clone_with_backend" + self._torchdynamo_orig_callable, "_clone_with_backend" ), "DDPOptimizer only supports callback fns that know how to clone themselves." - hijacked_callback = callback._clone_with_backend( - ddp_optimizer.compile_fn, + hijacked_callback = ( + self._torchdynamo_orig_callable._clone_with_backend( + ddp_optimizer.compile_fn, + ) + ) + return hijacked_callback( + frame, cache_entry, self.hooks, frame_state ) - return hijacked_callback(frame, cache_entry, hooks, frame_state) with compile_lock, _disable_current_modes(): # skip=1: skip this frame - return callback(frame, cache_entry, hooks, frame_state, skip=1) + return self._torchdynamo_orig_callable( + frame, cache_entry, self.hooks, frame_state, skip=1 + ) - catch_errors._torchdynamo_orig_callable = callback # type: ignore[attr-defined] - return catch_errors + +def catch_errors_wrapper(callback, hooks: Hooks): + return CatchErrorsWrapper(callback, hooks) diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 318fdd265085..94cad71f7ef5 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -168,6 +168,9 @@ def _initialize(self): self._forward = self.forward self.forward = self._call_lazy_check + def __reduce__(self): + return (self.__class__, (self._orig_mod, self.dynamo_ctx)) + def __getstate__(self): state = dict(self.__dict__) state.pop("forward", None) @@ -273,9 +276,11 @@ def __init__( super().__init__() assert callable(callback) or callback is False or callback is None self.callback: DynamoCallback = callback + self._backend_ctx_ctor = backend_ctx_ctor self.prior: Union[Unset, DynamoCallback] = unset self.first_ctx = first_ctx self.export = export + self._dynamic = dynamic self.compiler_config = compiler_config self.cleanup_fns: List[Callable[[], Any]] = [] self.enter_exit_hooks = [] @@ -379,7 +384,13 @@ def get_compiler_config(): # call to a builtin without a frame for us to capture fn = external_utils.wrap_inline(fn) - callback = self.callback + def do_nothing(*arg, **kwargs): + pass + + if hasattr(self, "callback"): + callback = self.callback + else: + callback = do_nothing is_jit_tracing = torch._C._is_tracing is_fx_tracing = torch.fx._symbolic_trace.is_fx_tracing @@ -522,6 +533,17 @@ def call_compiled_autograd(): self.enter_exit_hooks.append(call_compiled_autograd) + def __reduce__(self): + return ( + self.__class__, + (self.callback, self._backend_ctx_ctor, self.first_ctx), + { + "export": self.export, + "dynamic": self._dynamic, + "compiler_config": self.compiler_config, + }, + ) + class RunOnlyContext(_TorchDynamoContext): def __init__(self): @@ -531,6 +553,9 @@ def on_enter(): super().__init__(callback=False, on_enter=on_enter) + def __reduce__(self): + return (self.__class__, ()) + class DisableContext(_TorchDynamoContext): def __init__(self): @@ -583,6 +608,9 @@ def _fn(*args, **kwargs): return _fn + def __reduce__(self): + return (self.__class__, ()) + def _optimize_catch_errors( compile_fn, diff --git a/torch/_dynamo/repro/after_dynamo.py b/torch/_dynamo/repro/after_dynamo.py index 76b9128e6995..43f761f84d3d 100644 --- a/torch/_dynamo/repro/after_dynamo.py +++ b/torch/_dynamo/repro/after_dynamo.py @@ -56,19 +56,20 @@ def _accuracy_fails(gm, example_inputs, compiler_fn): ) -def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str): - """ - A minifier decorator that wraps the TorchDynamo produced Fx graph modules. - As opposed to wrap_compiler_debug, this wrapper intercepts at the - TorchDynamo produced Fx Graph Module. This makes it backend-agnostic to some - level, e.g., it is useful for minifying issues related to Aot Autograd - tracing. If an error is found, we minify and save the minified repro in - repro.tar.gz. - """ - - @functools.wraps(unconfigured_compiler_fn) - def debug_wrapper(gm, example_inputs, **kwargs): - compiler_fn = functools.partial(unconfigured_compiler_fn, **kwargs) +class WrapBackendDebug: + def __init__(self, unconfigured_compiler_fn, compiler_name: str): + functools.wraps(unconfigured_compiler_fn)(self) + self._torchdynamo_orig_callable = unconfigured_compiler_fn # type: ignore[attr-defined] + self._compiler_name = compiler_name + if hasattr(unconfigured_compiler_fn, "__name__"): + self.__name__ = unconfigured_compiler_fn.__name__ + if hasattr(unconfigured_compiler_fn, "compiler_name"): + self.__name__ = unconfigured_compiler_fn.compiler_name + if hasattr(unconfigured_compiler_fn, "get_compiler_config"): + self.get_compiler_config = unconfigured_compiler_fn.get_compiler_config # type: ignore[attr-defined] + + def __call__(self, gm, example_inputs, **kwargs): + compiler_fn = functools.partial(self._torchdynamo_orig_callable, **kwargs) assert config.repro_after in ("dynamo", "aot", None) if config.repro_after == "dynamo": @@ -82,7 +83,7 @@ def add_paths(exc): ) if config.repro_level == 3: - dump_to_minify_after_dynamo(gm, example_inputs, compiler_name) + dump_to_minify_after_dynamo(gm, example_inputs, self._compiler_name) # Check for either accuracy (level 4) or other type of failures. if config.repro_level == 4: @@ -95,7 +96,7 @@ def add_paths(exc): dump_to_minify_after_dynamo( fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs, - compiler_name, + self._compiler_name, ) exc = AccuracyError("Bad accuracy detected.") add_paths(exc) @@ -110,7 +111,7 @@ def add_paths(exc): ) if config.repro_level == 1: dump_state_fn = functools.partial( - dump_backend_state, compiler_name=compiler_name + dump_backend_state, compiler_name=self._compiler_name ) dump_state_fn( fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs @@ -119,7 +120,7 @@ def add_paths(exc): dump_to_minify_after_dynamo( fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs, - compiler_name, + self._compiler_name, ) add_paths(exc) raise @@ -128,12 +129,17 @@ def add_paths(exc): return compiled_gm - debug_wrapper._torchdynamo_orig_callable = unconfigured_compiler_fn # type: ignore[attr-defined] - if hasattr(unconfigured_compiler_fn, "compiler_name"): - debug_wrapper.__name__ = unconfigured_compiler_fn.compiler_name - if hasattr(unconfigured_compiler_fn, "get_compiler_config"): - debug_wrapper.get_compiler_config = unconfigured_compiler_fn.get_compiler_config # type: ignore[attr-defined] - return debug_wrapper + +def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str): + """ + A minifier decorator that wraps the TorchDynamo produced Fx graph modules. + As opposed to wrap_compiler_debug, this wrapper intercepts at the + TorchDynamo produced Fx Graph Module. This makes it backend-agnostic to some + level, e.g., it is useful for minifying issues related to Aot Autograd + tracing. If an error is found, we minify and save the minified repro in + repro.tar.gz. + """ + return WrapBackendDebug(unconfigured_compiler_fn, compiler_name) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # From faabda4fc9e3c033aca7b7aa921fba21ebae06d5 Mon Sep 17 00:00:00 2001 From: Weizhuo Zhang Date: Wed, 5 Jun 2024 14:23:09 +0000 Subject: [PATCH 361/706] [Inductor] Skip model_fail_to_load and eager_fail_to_run models in inductor benchmarks test (#127210) Aligned with test-infra repo, we skipped `model_fail_to_load` and `eager_fail_to_run` models Refer code logic: https://github.com/pytorch/test-infra/blob/d3b79778f8e67b66b5ab5fff3bc9a60db82faca5/torchci/rockset/inductor/__sql/compilers_benchmark_performance.sql#L57-L58 ```SQL WHERE filename LIKE '%_accuracy' AND filename LIKE CONCAT( '%_', : dtypes, '_', : mode, '_', : device, '_%' ) AND _event_time >= PARSE_DATETIME_ISO8601(:startTime) AND _event_time < PARSE_DATETIME_ISO8601(:stopTime) AND (workflow_id = :workflowId OR :workflowId = 0) AND accuracy != 'model_fail_to_load' AND accuracy != 'eager_fail_to_run' ), ``` Comp Item | Compiler | suite | Before | After fix -- | -- | -- | -- | -- Pass Rate | Inductor | torchbench | 96%, 80/83 | 100%, 80/80 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127210 Approved by: https://github.com/jansel --- benchmarks/dynamo/runner.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/benchmarks/dynamo/runner.py b/benchmarks/dynamo/runner.py index 54cff5658257..bc42b6566706 100755 --- a/benchmarks/dynamo/runner.py +++ b/benchmarks/dynamo/runner.py @@ -776,12 +776,18 @@ def extract_df(self, metric, testing): if not perf_row.empty: if acc_row.empty: perf_row[compiler] = 0.0 + elif acc_row[compiler].iloc[0] in ( + "model_fail_to_load", + "eager_fail_to_run", + ): + perf_row = pd.DataFrame() elif acc_row[compiler].iloc[0] not in ( "pass", "pass_due_to_skip", ): perf_row[compiler] = 0.0 - perf_rows.append(perf_row) + if not perf_row.empty: + perf_rows.append(perf_row) df = pd.concat(perf_rows) df = df.sort_values(by=list(reversed(self.compilers)), ascending=False) From 4ce5322a1f16c8852523f7417db5ed494450b119 Mon Sep 17 00:00:00 2001 From: hippocookie Date: Wed, 5 Jun 2024 14:31:26 +0000 Subject: [PATCH 362/706] Enable UFMT on test_shape_ops.py test_show_pickle.py test_sort_and_select.py (#127165) Fixes some files in #123062 Run lintrunner on files: test_shape_ops.py test_show_pickle.py test_sort_and_select.py ```bash $ lintrunner --take UFMT --all-files ok No lint issues. Successfully applied all patches. ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/127165 Approved by: https://github.com/ezyang --- .lintrunner.toml | 3 - test/test_shape_ops.py | 244 ++++++++----- test/test_show_pickle.py | 13 +- test/test_sort_and_select.py | 645 ++++++++++++++++++++++------------- 4 files changed, 588 insertions(+), 317 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index e4f2507de8cc..1d7c00a2c772 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1114,9 +1114,6 @@ exclude_patterns = [ 'test/test_segment_reductions.py', 'test/test_serialization.py', 'test/test_set_default_mobile_cpu_allocator.py', - 'test/test_shape_ops.py', - 'test/test_show_pickle.py', - 'test/test_sort_and_select.py', 'test/test_sparse.py', 'test/test_sparse_csr.py', 'test/test_sparse_semi_structured.py', diff --git a/test/test_shape_ops.py b/test/test_shape_ops.py index 47acfff9c6d4..2353d6841bbb 100644 --- a/test/test_shape_ops.py +++ b/test/test_shape_ops.py @@ -1,22 +1,41 @@ # Owner(s): ["module: tests"] -import torch -import numpy as np - -from itertools import product, combinations, permutations, chain -from functools import partial import random -import warnings import unittest +import warnings +from functools import partial + +from itertools import chain, combinations, permutations, product + +import numpy as np + +import torch from torch import nan from torch.testing import make_tensor -from torch.testing._internal.common_utils import ( - TestCase, run_tests, skipIfTorchDynamo, torch_to_numpy_dtype_dict, IS_JETSON, TEST_PRIVATEUSE1_DEVICE_TYPE) from torch.testing._internal.common_device_type import ( - instantiate_device_type_tests, onlyCPU, onlyCUDA, dtypes, onlyNativeDeviceTypes, - dtypesIfCUDA, largeTensorTest) -from torch.testing._internal.common_dtype import all_types_and_complex_and, all_types, all_types_and + dtypes, + dtypesIfCUDA, + instantiate_device_type_tests, + largeTensorTest, + onlyCPU, + onlyCUDA, + onlyNativeDeviceTypes, +) +from torch.testing._internal.common_dtype import ( + all_types, + all_types_and, + all_types_and_complex_and, +) +from torch.testing._internal.common_utils import ( + IS_JETSON, + run_tests, + skipIfTorchDynamo, + TEST_PRIVATEUSE1_DEVICE_TYPE, + TestCase, + torch_to_numpy_dtype_dict, +) + # TODO: replace with make_tensor def _generate_input(shape, dtype, device, with_extremal): @@ -29,17 +48,19 @@ def _generate_input(shape, dtype, device, with_extremal): x = torch.randn(*shape, device=device) * random.randint(30, 100) x = x.to(torch.bfloat16) else: - x = torch.randn(*shape, dtype=dtype, device=device) * random.randint(30, 100) + x = torch.randn(*shape, dtype=dtype, device=device) * random.randint( + 30, 100 + ) x[torch.randn(*shape) > 0.5] = 0 if with_extremal and dtype.is_floating_point: # Use extremal values - x[torch.randn(*shape) > 0.5] = float('nan') - x[torch.randn(*shape) > 0.5] = float('inf') - x[torch.randn(*shape) > 0.5] = float('-inf') + x[torch.randn(*shape) > 0.5] = float("nan") + x[torch.randn(*shape) > 0.5] = float("inf") + x[torch.randn(*shape) > 0.5] = float("-inf") elif with_extremal and dtype.is_complex: - x[torch.randn(*shape) > 0.5] = complex('nan') - x[torch.randn(*shape) > 0.5] = complex('inf') - x[torch.randn(*shape) > 0.5] = complex('-inf') + x[torch.randn(*shape) > 0.5] = complex("nan") + x[torch.randn(*shape) > 0.5] = complex("inf") + x[torch.randn(*shape) > 0.5] = complex("-inf") elif dtype == torch.bool: x = torch.zeros(shape, dtype=dtype, device=device) x[torch.randn(*shape) > 0.5] = True @@ -48,8 +69,8 @@ def _generate_input(shape, dtype, device, with_extremal): return x -class TestShapeOps(TestCase): +class TestShapeOps(TestCase): # TODO: update to work on CUDA, too @onlyCPU def test_unbind(self, device): @@ -71,7 +92,7 @@ def test_tolist(self, device): tensor0D = torch.tensor(list0D) self.assertEqual(tensor0D.tolist(), list0D) - table1D = [1., 2., 3.] + table1D = [1.0, 2.0, 3.0] tensor1D = torch.tensor(table1D) storage = torch.Storage(table1D) self.assertEqual(tensor1D.tolist(), table1D) @@ -102,19 +123,29 @@ def test_movedim_invalid(self, device, dtype): fn(x, 0, 5) # Mismatch in size of `source` and `destination` - with self.assertRaisesRegex(RuntimeError, "movedim: Invalid source or destination dims:"): - fn(x, (1, 0), (0, )) - - with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `source`"): + with self.assertRaisesRegex( + RuntimeError, "movedim: Invalid source or destination dims:" + ): + fn(x, (1, 0), (0,)) + + with self.assertRaisesRegex( + RuntimeError, "movedim: repeated dim in `source`" + ): fn(x, (0, 0), (0, 1)) - with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `source`"): + with self.assertRaisesRegex( + RuntimeError, "movedim: repeated dim in `source`" + ): fn(x, (0, 1, 0), (0, 1, 2)) - with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `destination`"): + with self.assertRaisesRegex( + RuntimeError, "movedim: repeated dim in `destination`" + ): fn(x, (0, 1), (1, 1)) - with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `destination`"): + with self.assertRaisesRegex( + RuntimeError, "movedim: repeated dim in `destination`" + ): fn(x, (0, 1, 2), (1, 0, 1)) @dtypes(torch.int64, torch.float, torch.complex128) @@ -137,8 +168,12 @@ def test_movedim(self, device, dtype): # Integer `source` and `destination` torch_fn = partial(fn, source=src_dim, destination=dst_dim) - np_fn = partial(np.moveaxis, source=src_dim, destination=dst_dim) - self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None) + np_fn = partial( + np.moveaxis, source=src_dim, destination=dst_dim + ) + self.compare_with_numpy( + torch_fn, np_fn, x, device=None, dtype=None + ) if nd == 0: continue @@ -148,9 +183,13 @@ def make_index_negative(sequence, idx): sequence[random_idx] = sequence[random_idx] - nd return tuple(src_sequence) - for src_sequence in permutations(range(nd), r=random.randint(1, nd)): + for src_sequence in permutations( + range(nd), r=random.randint(1, nd) + ): # Sequence `source` and `destination` - dst_sequence = tuple(random.sample(range(nd), len(src_sequence))) + dst_sequence = tuple( + random.sample(range(nd), len(src_sequence)) + ) # Randomly change a dim to a negative dim representation of itself. random_prob = random.random() @@ -166,9 +205,15 @@ def make_index_negative(sequence, idx): random_idx = random.randint(0, len(src_sequence) - 1) src_sequence = make_index_negative(src_sequence, random_idx) - torch_fn = partial(fn, source=src_sequence, destination=dst_sequence) - np_fn = partial(np.moveaxis, source=src_sequence, destination=dst_sequence) - self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None) + torch_fn = partial( + fn, source=src_sequence, destination=dst_sequence + ) + np_fn = partial( + np.moveaxis, source=src_sequence, destination=dst_sequence + ) + self.compare_with_numpy( + torch_fn, np_fn, x, device=None, dtype=None + ) # Move dim to same position x = torch.randn(2, 3, 5, 7, 11) @@ -213,10 +258,7 @@ def test_diagonal(self, device): def test_diagonal_multidim(self, device, dtype): x = torch.randn(10, 11, 12, 13, dtype=dtype, device=device) xn = x.numpy() - for args in [(2, 2, 3), - (2,), - (-2, 1, 2), - (0, -2, -1)]: + for args in [(2, 2, 3), (2,), (-2, 1, 2), (0, -2, -1)]: result = torch.diagonal(x, *args) expected = xn.diagonal(*args) self.assertEqual(expected.shape, result.shape) @@ -270,14 +312,22 @@ def generate_clamp_baseline(self, device, dtype, *, min_vals, max_vals, with_nan max_vals = max_vals.cpu().numpy() # Use NumPy implementation as reference - X_clamped = torch.tensor(np.clip(X.cpu().numpy(), a_min=min_vals, a_max=max_vals), device=device) + X_clamped = torch.tensor( + np.clip(X.cpu().numpy(), a_min=min_vals, a_max=max_vals), device=device + ) return X, X_clamped # Tests clamp and its alias, clip @dtypes(torch.int64, torch.float32) def test_clamp(self, device, dtype): - op_list = (torch.clamp, torch.Tensor.clamp, torch.Tensor.clamp_, - torch.clip, torch.Tensor.clip, torch.Tensor.clip_) + op_list = ( + torch.clamp, + torch.Tensor.clamp, + torch.Tensor.clamp_, + torch.clip, + torch.Tensor.clip, + torch.Tensor.clip_, + ) # min/max argument product args = product((-10, None), (10, None)) @@ -287,10 +337,9 @@ def test_clamp(self, device, dtype): if min_val is None and max_val is None: continue - X, Y_expected = self.generate_clamp_baseline(device, dtype, - min_vals=min_val, - max_vals=max_val, - with_nans=False) + X, Y_expected = self.generate_clamp_baseline( + device, dtype, min_vals=min_val, max_vals=max_val, with_nans=False + ) # Test op X1 = X.clone() # So that the in-place ops do not change X @@ -304,8 +353,14 @@ def test_clamp(self, device, dtype): self.assertEqual(Y_expected, Y_out) def test_clamp_propagates_nans(self, device): - op_list = (torch.clamp, torch.Tensor.clamp, torch.Tensor.clamp_, - torch.clip, torch.Tensor.clip, torch.Tensor.clip_) + op_list = ( + torch.clamp, + torch.Tensor.clamp, + torch.Tensor.clamp_, + torch.clip, + torch.Tensor.clip, + torch.Tensor.clip_, + ) # min/max argument product args = product((-10, None), (10, None)) @@ -315,10 +370,13 @@ def test_clamp_propagates_nans(self, device): if min_val is None and max_val is None: continue - X, Y_expected = self.generate_clamp_baseline(device, torch.float, - min_vals=min_val, - max_vals=max_val, - with_nans=True) + X, Y_expected = self.generate_clamp_baseline( + device, + torch.float, + min_vals=min_val, + max_vals=max_val, + with_nans=True, + ) Y_expected = torch.isnan(Y_expected) # Test op @@ -334,7 +392,7 @@ def test_clamp_propagates_nans(self, device): def test_clamp_raises_arg_errors(self, device): X = torch.randn(100, dtype=torch.float, device=device) - error_msg = 'At least one of \'min\' or \'max\' must not be None' + error_msg = "At least one of 'min' or 'max' must not be None" with self.assertRaisesRegex(RuntimeError, error_msg): X.clamp() with self.assertRaisesRegex(RuntimeError, error_msg): @@ -369,18 +427,22 @@ def all_t(): self.assertEqual(in_t.flip(p_dims), out_t) if len(p_dims) > 0: # Wrap 1st dim - self.assertEqual(in_t.flip((-n + p_dims[0],) + p_dims[1:]), out_t) + self.assertEqual( + in_t.flip((-n + p_dims[0],) + p_dims[1:]), out_t + ) def gen_data(): # Basic tests data = make_from_data([1, 2, 3, 4, 5, 6, 7, 8]).view(2, 2, 2) nonctg = make_from_size((2, 2, 2), noncontiguous=True).copy_(data) - dims_result = ((0, make_from_data([5, 6, 7, 8, 1, 2, 3, 4]).view(2, 2, 2)), - (1, make_from_data([3, 4, 1, 2, 7, 8, 5, 6]).view(2, 2, 2)), - (2, make_from_data([2, 1, 4, 3, 6, 5, 8, 7]).view(2, 2, 2)), - ((0, 1), make_from_data([7, 8, 5, 6, 3, 4, 1, 2]).view(2, 2, 2)), - ((0, 1, 2), make_from_data([8, 7, 6, 5, 4, 3, 2, 1]).view(2, 2, 2))) + dims_result = ( + (0, make_from_data([5, 6, 7, 8, 1, 2, 3, 4]).view(2, 2, 2)), + (1, make_from_data([3, 4, 1, 2, 7, 8, 5, 6]).view(2, 2, 2)), + (2, make_from_data([2, 1, 4, 3, 6, 5, 8, 7]).view(2, 2, 2)), + ((0, 1), make_from_data([7, 8, 5, 6, 3, 4, 1, 2]).view(2, 2, 2)), + ((0, 1, 2), make_from_data([8, 7, 6, 5, 4, 3, 2, 1]).view(2, 2, 2)), + ) for in_tensor, (dims, out_tensor) in product((data, nonctg), dims_result): yield in_tensor, dims, out_tensor @@ -393,7 +455,9 @@ def gen_data(): yield in_t, 1, in_t # Transposed - in_t = make_from_data([1, 2, 3, 4, 5, 6, 7, 8]).view(2, 2, 2).transpose(0, 1) + in_t = ( + make_from_data([1, 2, 3, 4, 5, 6, 7, 8]).view(2, 2, 2).transpose(0, 1) + ) dims = (0, 1, 2) out_t = make_from_data([8, 7, 4, 3, 6, 5, 2, 1]).view(2, 2, 2) yield in_t, dims, out_t @@ -411,7 +475,9 @@ def gen_data(): if device == "cpu" and dtype != torch.bfloat16: for mf in [torch.contiguous_format, torch.channels_last]: for c in [2, 3, 8, 16]: - in_t = make_from_size((2, c, 32, 32)).contiguous(memory_format=mf) + in_t = make_from_size((2, c, 32, 32)).contiguous( + memory_format=mf + ) np_in_t = in_t.numpy() np_out_t = np_in_t[:, :, :, ::-1].copy() @@ -464,7 +530,9 @@ def gen_data(): size = [2, 3, 4] data = make_from_size(size) possible_dims = range(len(size)) - test_dims = chain(combinations(possible_dims, 1), combinations(possible_dims, 2)) + test_dims = chain( + combinations(possible_dims, 1), combinations(possible_dims, 2) + ) for dims in test_dims: self.assertEqual(size, list(data.flip(dims).size())) @@ -483,7 +551,6 @@ def test_flip_errors(self, device, dtype): self.assertRaises(IndexError, lambda: data.flip(0, 1, 2, 3)) self.assertRaises(IndexError, lambda: data.flip(3)) - def _rand_shape(self, dim, min_size, max_size): return tuple(torch.randint(min_size, max_size + 1, (dim,))) @@ -504,8 +571,10 @@ def test_flip_numpy(self, device, dtype): self.compare_with_numpy(torch_fn, np_fn, data) @onlyCUDA # CPU is too slow - @largeTensorTest('17GB') # 4 tensors of 4GB (in, out) x (torch, numpy) + 1GB - @largeTensorTest("81GB", "cpu") # even for CUDA test, sufficient system memory is required + @largeTensorTest("17GB") # 4 tensors of 4GB (in, out) x (torch, numpy) + 1GB + @largeTensorTest( + "81GB", "cpu" + ) # even for CUDA test, sufficient system memory is required @unittest.skipIf(IS_JETSON, "Too large for Jetson") def test_flip_large_tensor(self, device): t_in = torch.empty(2**32 + 1, dtype=torch.uint8).random_() @@ -569,7 +638,9 @@ def test_rot90(self, device): # test tensor with more than 2D data = torch.arange(1, 9, device=device).view(2, 2, 2) - self.assertEqual(torch.tensor([2, 4, 1, 3, 6, 8, 5, 7]).view(2, 2, 2), data.rot90(1, [1, 2])) + self.assertEqual( + torch.tensor([2, 4, 1, 3, 6, 8, 5, 7]).view(2, 2, 2), data.rot90(1, [1, 2]) + ) self.assertEqual(data.rot90(1, [1, -1]), data.rot90(1, [1, 2])) # test for errors @@ -601,7 +672,6 @@ def test_nonzero_no_warning(self, device): @dtypes(*all_types_and(torch.half, torch.bool, torch.bfloat16)) def test_nonzero(self, device, dtype): - shapes = [ torch.Size((12,)), torch.Size((12, 1)), @@ -616,7 +686,9 @@ def gen_nontrivial_input(shape, dtype, device): return torch.randint(2, shape, device=device, dtype=dtype) else: # windows does not work for bfloat16 randing - return torch.randint(2, shape, device=device, dtype=torch.float).to(dtype) + return torch.randint(2, shape, device=device, dtype=torch.float).to( + dtype + ) for shape in shapes: tensor = gen_nontrivial_input(shape, dtype, device) @@ -624,20 +696,31 @@ def gen_nontrivial_input(shape, dtype, device): dst2 = tensor.nonzero(as_tuple=False) dst3 = torch.empty([], dtype=torch.long, device=device) torch.nonzero(tensor, out=dst3) - if self.device_type != 'xla': + if self.device_type != "xla": # xla does not raise runtime error self.assertRaisesRegex( RuntimeError, "scalar type Long", - lambda: torch.nonzero(tensor, out=torch.empty([], dtype=torch.float, device=device)) + lambda: torch.nonzero( + tensor, out=torch.empty([], dtype=torch.float, device=device) + ), ) - if self.device_type == 'cuda' or self.device_type == TEST_PRIVATEUSE1_DEVICE_TYPE: + if ( + self.device_type == "cuda" + or self.device_type == TEST_PRIVATEUSE1_DEVICE_TYPE + ): self.assertRaisesRegex( RuntimeError, "on the same device", - lambda: torch.nonzero(tensor, out=torch.empty([], dtype=torch.long)) + lambda: torch.nonzero( + tensor, out=torch.empty([], dtype=torch.long) + ), ) - np_array = tensor.cpu().numpy() if dtype != torch.bfloat16 else tensor.float().cpu().numpy() + np_array = ( + tensor.cpu().numpy() + if dtype != torch.bfloat16 + else tensor.float().cpu().numpy() + ) np_result = torch.from_numpy(np.stack(np_array.nonzero())).t() self.assertEqual(dst1.cpu(), np_result, atol=0, rtol=0) self.assertEqual(dst2.cpu(), np_result, atol=0, rtol=0) @@ -656,7 +739,9 @@ def test_nonzero_astuple_out(self, device): with self.assertRaises(RuntimeError): torch.nonzero(t, as_tuple=True, out=out) - self.assertEqual(torch.nonzero(t, as_tuple=False, out=out), torch.nonzero(t, out=out)) + self.assertEqual( + torch.nonzero(t, as_tuple=False, out=out), torch.nonzero(t, out=out) + ) # Verifies that JIT script cannot handle the as_tuple kwarg # See Issue https://github.com/pytorch/pytorch/issues/45499. @@ -684,7 +769,9 @@ def _foo(t): def test_nonzero_discontiguous(self, device): shape = (4, 4) tensor = torch.randint(2, shape, device=device) - tensor_nc = torch.empty(shape[0], shape[1] * 2, device=device)[:, ::2].copy_(tensor) + tensor_nc = torch.empty(shape[0], shape[1] * 2, device=device)[:, ::2].copy_( + tensor + ) dst1 = tensor.nonzero(as_tuple=False) dst2 = tensor_nc.nonzero(as_tuple=False) self.assertEqual(dst1, dst2, atol=0, rtol=0) @@ -695,7 +782,9 @@ def test_nonzero_discontiguous(self, device): self.assertEqual(data_ptr, dst3.data_ptr()) self.assertEqual(dst1, dst3, atol=0, rtol=0) # discontiguous out - dst4 = torch.empty(dst1.size(0), dst1.size(1) * 2, dtype=torch.long, device=device)[:, ::2] + dst4 = torch.empty( + dst1.size(0), dst1.size(1) * 2, dtype=torch.long, device=device + )[:, ::2] data_ptr = dst4.data_ptr() strides = dst4.stride() torch.nonzero(tensor, out=dst4) @@ -710,7 +799,7 @@ def test_nonzero_non_diff(self, device): @dtypes(torch.int64, torch.float, torch.complex128) def test_sparse_dense_dim(self, device, dtype): - for shape in [(), (2, ), (2, 3)]: + for shape in [(), (2,), (2, 3)]: if dtype.is_complex or dtype.is_floating_point: x = torch.rand(shape, device=device, dtype=dtype) else: @@ -718,7 +807,8 @@ def test_sparse_dense_dim(self, device, dtype): self.assertEqual(x.sparse_dim(), 0) self.assertEqual(x.dense_dim(), len(shape)) + instantiate_device_type_tests(TestShapeOps, globals()) -if __name__ == '__main__': +if __name__ == "__main__": run_tests() diff --git a/test/test_show_pickle.py b/test/test_show_pickle.py index 929584943007..48b459e12eac 100644 --- a/test/test_show_pickle.py +++ b/test/test_show_pickle.py @@ -1,15 +1,16 @@ # Owner(s): ["oncall: mobile"] -import unittest import io import tempfile +import unittest + import torch import torch.utils.show_pickle -from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS +from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase -class TestShowPickle(TestCase): +class TestShowPickle(TestCase): @unittest.skipIf(IS_WINDOWS, "Can't re-open temp file on Windows") def test_scripted_model(self): class MyCoolModule(torch.nn.Module): @@ -26,11 +27,13 @@ def forward(self, x): torch.jit.save(m, tmp) tmp.flush() buf = io.StringIO() - torch.utils.show_pickle.main(["", tmp.name + "@*/data.pkl"], output_stream=buf) + torch.utils.show_pickle.main( + ["", tmp.name + "@*/data.pkl"], output_stream=buf + ) output = buf.getvalue() self.assertRegex(output, "MyCoolModule") self.assertRegex(output, "weight") -if __name__ == '__main__': +if __name__ == "__main__": run_tests() diff --git a/test/test_sort_and_select.py b/test/test_sort_and_select.py index 7709131e6102..211c7b998608 100644 --- a/test/test_sort_and_select.py +++ b/test/test_sort_and_select.py @@ -1,44 +1,69 @@ # Owner(s): ["module: tests"] -import torch +import random +from itertools import permutations, product + import numpy as np -import random +import torch from torch import nan -from itertools import permutations, product from torch.testing import make_tensor -from torch.testing._internal.common_dtype import all_types, all_types_and, floating_types_and, integral_types -from torch.testing._internal.common_utils import \ - (TestCase, run_tests, slowTest, skipIfTorchDynamo) -from torch.testing._internal.common_device_type import \ - (instantiate_device_type_tests, dtypes, onlyNativeDeviceTypes, - onlyCUDA, dtypesIfCUDA, dtypesIfCPU, onlyCPU, largeTensorTest) +from torch.testing._internal.common_device_type import ( + dtypes, + dtypesIfCPU, + dtypesIfCUDA, + instantiate_device_type_tests, + largeTensorTest, + onlyCPU, + onlyCUDA, + onlyNativeDeviceTypes, +) +from torch.testing._internal.common_dtype import ( + all_types, + all_types_and, + floating_types_and, + integral_types, +) +from torch.testing._internal.common_utils import ( + run_tests, + skipIfTorchDynamo, + slowTest, + TestCase, +) # TODO: remove this SIZE = 100 -class TestSortAndSelect(TestCase): +class TestSortAndSelect(TestCase): def assertIsOrdered(self, order, x, mxx, ixx, task): SIZE = x.size(1) - if order == 'descending': + if order == "descending": + def check_order(a, b): # `a != a` because we put NaNs # at the end of ascending sorted lists, # and the beginning of descending ones. return ((a != a) | (a >= b)).all().item() - elif order == 'ascending': + + elif order == "ascending": + def check_order(a, b): # see above return ((b != b) | (a <= b)).all().item() + else: - error(f'unknown order "{order}", must be "ascending" or "descending"') # noqa: F821 + error( # noqa: F821 + f'unknown order "{order}", must be "ascending" or "descending"' + ) are_ordered = True for k in range(1, SIZE): - self.assertTrue(check_order(mxx[:, k - 1], mxx[:, k]), - f'torch.sort ({order}) values unordered for {task}') + self.assertTrue( + check_order(mxx[:, k - 1], mxx[:, k]), + f"torch.sort ({order}) values unordered for {task}", + ) seen = set() indicesCorrect = True @@ -50,8 +75,11 @@ def check_order(a, b): for k in range(size0): seen.clear() for j in range(size): - self.assertEqual(x[k][ixx[k][j]], mxx[k][j], - msg=f'torch.sort ({order}) indices wrong for {task}') + self.assertEqual( + x[k][ixx[k][j]], + mxx[k][j], + msg=f"torch.sort ({order}) indices wrong for {task}", + ) seen.add(ixx[k][j]) self.assertEqual(len(seen), size) @@ -79,19 +107,22 @@ def test_sort(self, device): self.assertEqual(x.argsort(), res1ind) # Test sorting of random numbers - self.assertIsOrdered('ascending', x, res2val, res2ind, 'random') + self.assertIsOrdered("ascending", x, res2val, res2ind, "random") # Test simple sort self.assertEqual( torch.sort(torch.tensor((50, 40, 30, 20, 10), device=device))[0], torch.tensor((10, 20, 30, 40, 50), device=device), - atol=0, rtol=0 + atol=0, + rtol=0, ) # Test that we still have proper sorting with duplicate keys x = torch.floor(torch.rand(4, SIZE, device=device) * 10) torch.sort(x, out=(res2val, res2ind)) - self.assertIsOrdered('ascending', x, res2val, res2ind, 'random with duplicate keys') + self.assertIsOrdered( + "ascending", x, res2val, res2ind, "random with duplicate keys" + ) # DESCENDING SORT x = torch.rand(4, SIZE, device=device) @@ -107,35 +138,41 @@ def test_sort(self, device): self.assertEqual(x.argsort(x.dim() - 1, True), res1ind) # Test sorting of random numbers - self.assertIsOrdered('descending', x, res2val, res2ind, 'random') + self.assertIsOrdered("descending", x, res2val, res2ind, "random") # Test simple sort task self.assertEqual( - torch.sort(torch.tensor((10, 20, 30, 40, 50), device=device), 0, True)[0], + torch.sort(torch.tensor((10, 20, 30, 40, 50), device=device), 0, True)[ + 0 + ], torch.tensor((50, 40, 30, 20, 10), device=device), - atol=0, rtol=0 + atol=0, + rtol=0, ) # Test that we still have proper sorting with duplicate keys - self.assertIsOrdered('descending', x, res2val, res2ind, 'random with duplicate keys') + self.assertIsOrdered( + "descending", x, res2val, res2ind, "random with duplicate keys" + ) # Test argument sorting with and without stable x = torch.tensor([1, 10, 2, 2, 3, 7, 7, 8, 9, 9] * 3) - self.assertEqual(torch.argsort(x, stable=True), torch.sort(x, stable=True).indices) - self.assertEqual(torch.argsort(x, stable=False), torch.sort(x, stable=False).indices) + self.assertEqual( + torch.argsort(x, stable=True), torch.sort(x, stable=True).indices + ) + self.assertEqual( + torch.argsort(x, stable=False), torch.sort(x, stable=False).indices + ) self.assertEqual(torch.argsort(x), torch.sort(x).indices) - # Test sorting with NaNs x = torch.rand(4, SIZE, device=device) - x[1][2] = float('NaN') - x[3][0] = float('NaN') + x[1][2] = float("NaN") + x[3][0] = float("NaN") torch.sort(x, out=(res2val, res2ind)) - self.assertIsOrdered('ascending', x, res2val, res2ind, - 'random with NaNs') + self.assertIsOrdered("ascending", x, res2val, res2ind, "random with NaNs") torch.sort(x, out=(res2val, res2ind), descending=True) - self.assertIsOrdered('descending', x, res2val, res2ind, - 'random with NaNs') + self.assertIsOrdered("descending", x, res2val, res2ind, "random with NaNs") def test_sort_stable_none(self): # Called sort with stable=None used to trigger an assertion @@ -169,19 +206,19 @@ def test_stable_sort(self, device, dtype): _, idx = x.sort(stable=True) self.assertEqual( idx[:ncopies], - torch.arange(start=0, end=2 * ncopies, step=2, device=device) + torch.arange(start=0, end=2 * ncopies, step=2, device=device), ) self.assertEqual( idx[ncopies:], - torch.arange(start=1, end=2 * ncopies, step=2, device=device) + torch.arange(start=1, end=2 * ncopies, step=2, device=device), ) @onlyCUDA @dtypes(torch.uint8) - @largeTensorTest('200GB') # Unfortunately 80GB A100 is not large enough + @largeTensorTest("200GB") # Unfortunately 80GB A100 is not large enough def test_sort_large(self, device, dtype): t0 = torch.randperm(8192, device=device).to(dtype) - t = t0.view(1, 8192).expand(2 ** 18 + 1, -1).contiguous() + t = t0.view(1, 8192).expand(2**18 + 1, -1).contiguous() v, i = t.sort() del t iv, im = i.var_mean(dim=0) @@ -193,7 +230,6 @@ def test_sort_large(self, device, dtype): self.assertEqual(vm, torch.arange(255, dtype=dtype, device=device)) self.assertEqual(im, t0.sort().indices) - @dtypes(torch.float32) def test_sort_restride(self, device, dtype): # Input: non-contiguous (stride: 5) 3-element array @@ -223,14 +259,24 @@ def _test_sort_discontiguous(self, device, dtype): n = t.size(dim) # assert ordered - self.assertTrue((r1.values.narrow(dim, 1, n - 1) >= r1.values.narrow(dim, 0, n - 1)).all()) + self.assertTrue( + ( + r1.values.narrow(dim, 1, n - 1) + >= r1.values.narrow(dim, 0, n - 1) + ).all() + ) # assert that different segments does not mix, which can easily happen # if the stride is not handled correctly - self.assertTrue((t.unsqueeze(-1).transpose(dim, -1) == r1.values.unsqueeze(-1)).any(dim=dim).any(dim=-1).all()) + self.assertTrue( + (t.unsqueeze(-1).transpose(dim, -1) == r1.values.unsqueeze(-1)) + .any(dim=dim) + .any(dim=-1) + .all() + ) # assert stride is preserved - if self.device_type == 'cuda': + if self.device_type == "cuda": # FIXME: this behavior should be true for all cases, not # just the one specified in if condition self.assertEqual(r1.values.stride(), t.stride()) @@ -262,7 +308,9 @@ def test_sort_1d_output_discontiguous(self, device, dtype): @dtypes(*integral_types()) def test_sort_1d_parallel(self, device, dtype): low = 0 if dtype == torch.uint8 else -128 - tensor = torch.randint(low=low, high=127, size=(100000, ), device=device, dtype=dtype) + tensor = torch.randint( + low=low, high=127, size=(100000,), device=device, dtype=dtype + ) vals, _ = torch.sort(tensor, stable=True) self.assertEqual(True, torch.all(vals[:-1] <= vals[1:])) @@ -283,9 +331,9 @@ def test_topk_1d_output_discontiguous(self, device, dtype): @dtypes(*all_types_and(torch.half, torch.bfloat16)) def test_stable_sort_against_numpy(self, device, dtype): if dtype in floating_types_and(torch.float16, torch.bfloat16): - inf = float('inf') - neg_inf = -float('inf') - nan = float('nan') + inf = float("inf") + neg_inf = -float("inf") + nan = float("nan") else: if dtype != torch.bool: # no torch.iinfo support for torch.bool @@ -305,7 +353,7 @@ def generate_samples(): # binary strings yield (torch.tensor([0, 1] * size, dtype=dtype, device=device), 0) - if self.device_type == 'cuda': + if self.device_type == "cuda": return yield (torch.tensor([0, 1] * 100, dtype=dtype, device=device), 0) @@ -326,13 +374,21 @@ def repeated_index_fill(t, dim, idxs, vals): # for each dimension. n_fill_vals = 3 # cardinality of (inf, neg_inf, nan) for dim in range(len(sizes)): - idxs = (torch.randint(high=size, size=(size // 10,)) for i in range(n_fill_vals)) + idxs = ( + torch.randint(high=size, size=(size // 10,)) + for i in range(n_fill_vals) + ) vals = (inf, neg_inf, nan) - subsets = chain.from_iterable(combinations(list(zip(idxs, vals)), r) - for r in range(1, n_fill_vals + 1)) + subsets = chain.from_iterable( + combinations(list(zip(idxs, vals)), r) + for r in range(1, n_fill_vals + 1) + ) for subset in subsets: idxs_subset, vals_subset = zip(*subset) - yield (repeated_index_fill(x, dim, idxs_subset, vals_subset), dim) + yield ( + repeated_index_fill(x, dim, idxs_subset, vals_subset), + dim, + ) for sample, dim in generate_samples(): _, idx_torch = sample.sort(dim=dim, stable=True) @@ -340,7 +396,7 @@ def repeated_index_fill(t, dim, idxs, vals): sample_numpy = sample.float().cpu().numpy() else: sample_numpy = sample.cpu().numpy() - idx_numpy = np.argsort(sample_numpy, axis=dim, kind='stable') + idx_numpy = np.argsort(sample_numpy, axis=dim, kind="stable") self.assertEqual(idx_torch, idx_numpy) @dtypes(*all_types_and(torch.half, torch.bfloat16)) @@ -349,7 +405,9 @@ def test(shape): tensor = make_tensor(shape, dtype=dtype, device=device, low=-9, high=9) if tensor.size() != torch.Size([]): if dtype is torch.bfloat16: - expected = torch.from_numpy(np.msort(tensor.float().cpu().numpy())).bfloat16() + expected = torch.from_numpy( + np.msort(tensor.float().cpu().numpy()) + ).bfloat16() else: expected = torch.from_numpy(np.msort(tensor.cpu().numpy())) else: @@ -364,11 +422,15 @@ def test(shape): shapes = ( [], - [0, ], - [20, ], + [ + 0, + ], + [ + 20, + ], [1, 20], [30, 30], - [10, 20, 30] + [10, 20, 30], ) for shape in shapes: test(shape) @@ -414,9 +476,12 @@ def compare(t, k, dim, dir): sortKVal, sortKInd = topKViaSort(t, k, dim, dir) compareTensors(t, sortKVal, sortKInd, topKVal, topKInd, dim) - t = torch.rand(random.randint(1, SIZE), - random.randint(1, SIZE), - random.randint(1, SIZE), device=device) + t = torch.rand( + random.randint(1, SIZE), + random.randint(1, SIZE), + random.randint(1, SIZE), + device=device, + ) for _kTries in range(3): for _dimTries in range(3): @@ -457,91 +522,94 @@ def test_topk_arguments(self, device): self.assertRaises(TypeError, lambda: q.topk(4, True)) def test_unique_dim(self, device): - self.assertFalse(hasattr(torch, 'unique_dim')) + self.assertFalse(hasattr(torch, "unique_dim")) def run_test(device, dtype): - x = torch.tensor([[[1., 1.], - [0., 1.], - [2., 1.], - [0., 1.]], - [[1., 1.], - [0., 1.], - [2., 1.], - [0., 1.]]], - dtype=dtype, - device=device) + x = torch.tensor( + [ + [[1.0, 1.0], [0.0, 1.0], [2.0, 1.0], [0.0, 1.0]], + [[1.0, 1.0], [0.0, 1.0], [2.0, 1.0], [0.0, 1.0]], + ], + dtype=dtype, + device=device, + ) x_empty = torch.empty(5, 0, dtype=dtype, device=device) x_ill_formed_empty = torch.empty(5, 0, 0, dtype=dtype, device=device) - x_ill_formed_empty_another = torch.empty(5, 0, 5, dtype=dtype, device=device) + x_ill_formed_empty_another = torch.empty( + 5, 0, 5, dtype=dtype, device=device + ) if dtype in floating_types_and(torch.float16, torch.bfloat16): - x_nan = torch.tensor([float("nan"), 0, 0, float("nan"), float("nan"), 1], dtype=dtype, device=device) - expected_unique_dim0 = torch.tensor([[[1., 1.], - [0., 1.], - [2., 1.], - [0., 1.]]], - dtype=dtype, - device=device) + x_nan = torch.tensor( + [float("nan"), 0, 0, float("nan"), float("nan"), 1], + dtype=dtype, + device=device, + ) + expected_unique_dim0 = torch.tensor( + [[[1.0, 1.0], [0.0, 1.0], [2.0, 1.0], [0.0, 1.0]]], + dtype=dtype, + device=device, + ) expected_inverse_dim0 = torch.tensor([0, 0]) expected_counts_dim0 = torch.tensor([2]) - expected_unique_dim1 = torch.tensor([[[0., 1.], - [1., 1.], - [2., 1.]], - [[0., 1.], - [1., 1.], - [2., 1.]]], - dtype=dtype, - device=device) - expected_unique_dim1_bool = torch.tensor([[[False, True], [True, True]], - [[False, True], [True, True]]], - dtype=torch.bool, - device=device) + expected_unique_dim1 = torch.tensor( + [ + [[0.0, 1.0], [1.0, 1.0], [2.0, 1.0]], + [[0.0, 1.0], [1.0, 1.0], [2.0, 1.0]], + ], + dtype=dtype, + device=device, + ) + expected_unique_dim1_bool = torch.tensor( + [[[False, True], [True, True]], [[False, True], [True, True]]], + dtype=torch.bool, + device=device, + ) expected_inverse_dim1 = torch.tensor([1, 0, 2, 0]) expected_inverse_dim1_bool = torch.tensor([1, 0, 1, 0]) expected_counts_dim1 = torch.tensor([2, 1, 1]) expected_counts_dim1_bool = torch.tensor([2, 2]) - expected_unique_dim2 = torch.tensor([[[1., 1.], - [0., 1.], - [2., 1.], - [0., 1.]], - [[1., 1.], - [0., 1.], - [2., 1.], - [0., 1.]]], - dtype=dtype, - device=device) + expected_unique_dim2 = torch.tensor( + [ + [[1.0, 1.0], [0.0, 1.0], [2.0, 1.0], [0.0, 1.0]], + [[1.0, 1.0], [0.0, 1.0], [2.0, 1.0], [0.0, 1.0]], + ], + dtype=dtype, + device=device, + ) expected_inverse_dim2 = torch.tensor([0, 1]) expected_counts_dim2 = torch.tensor([1, 1]) expected_unique_empty = torch.empty(5, 0, dtype=dtype, device=device) expected_inverse_empty = torch.tensor([], dtype=torch.long, device=device) expected_counts_empty = torch.tensor([], dtype=torch.long, device=device) if dtype in floating_types_and(torch.float16, torch.bfloat16): - expected_unique_nan = torch.tensor([float("nan"), 0, float("nan"), float("nan"), 1], dtype=dtype, device=device) - expected_inverse_nan = torch.tensor([0, 1, 1, 2, 3, 4], dtype=torch.long, device=device) - expected_counts_nan = torch.tensor([1, 2, 1, 1, 1], dtype=torch.long, device=device) + expected_unique_nan = torch.tensor( + [float("nan"), 0, float("nan"), float("nan"), 1], + dtype=dtype, + device=device, + ) + expected_inverse_nan = torch.tensor( + [0, 1, 1, 2, 3, 4], dtype=torch.long, device=device + ) + expected_counts_nan = torch.tensor( + [1, 2, 1, 1, 1], dtype=torch.long, device=device + ) # dim0 x_unique = torch.unique(x, dim=0) self.assertEqual(expected_unique_dim0, x_unique) - x_unique, x_inverse = torch.unique( - x, - return_inverse=True, - dim=0) + x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=0) self.assertEqual(expected_unique_dim0, x_unique) self.assertEqual(expected_inverse_dim0, x_inverse) x_unique, x_counts = torch.unique( - x, - return_inverse=False, - return_counts=True, - dim=0) + x, return_inverse=False, return_counts=True, dim=0 + ) self.assertEqual(expected_unique_dim0, x_unique) self.assertEqual(expected_counts_dim0, x_counts) x_unique, x_inverse, x_counts = torch.unique( - x, - return_inverse=True, - return_counts=True, - dim=0) + x, return_inverse=True, return_counts=True, dim=0 + ) self.assertEqual(expected_unique_dim0, x_unique) self.assertEqual(expected_inverse_dim0, x_inverse) self.assertEqual(expected_counts_dim0, x_counts) @@ -553,10 +621,7 @@ def run_test(device, dtype): else: self.assertEqual(expected_unique_dim1, x_unique) - x_unique, x_inverse = torch.unique( - x, - return_inverse=True, - dim=1) + x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=1) if x.dtype == torch.bool: self.assertEqual(expected_unique_dim1_bool, x_unique) self.assertEqual(expected_inverse_dim1_bool, x_inverse) @@ -565,10 +630,8 @@ def run_test(device, dtype): self.assertEqual(expected_inverse_dim1, x_inverse) x_unique, x_counts = torch.unique( - x, - return_inverse=False, - return_counts=True, - dim=1) + x, return_inverse=False, return_counts=True, dim=1 + ) if x.dtype == torch.bool: self.assertEqual(expected_unique_dim1_bool, x_unique) self.assertEqual(expected_counts_dim1_bool, x_counts) @@ -577,10 +640,8 @@ def run_test(device, dtype): self.assertEqual(expected_counts_dim1, x_counts) x_unique, x_inverse, x_counts = torch.unique( - x, - return_inverse=True, - return_counts=True, - dim=1) + x, return_inverse=True, return_counts=True, dim=1 + ) if x.dtype == torch.bool: self.assertEqual(expected_unique_dim1_bool, x_unique) self.assertEqual(expected_inverse_dim1_bool, x_inverse) @@ -594,36 +655,27 @@ def run_test(device, dtype): x_unique = torch.unique(x, dim=2) self.assertEqual(expected_unique_dim2, x_unique) - x_unique, x_inverse = torch.unique( - x, - return_inverse=True, - dim=2) + x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=2) self.assertEqual(expected_unique_dim2, x_unique) self.assertEqual(expected_inverse_dim2, x_inverse) x_unique, x_counts = torch.unique( - x, - return_inverse=False, - return_counts=True, - dim=2) + x, return_inverse=False, return_counts=True, dim=2 + ) self.assertEqual(expected_unique_dim2, x_unique) self.assertEqual(expected_counts_dim2, x_counts) x_unique, x_inverse, x_counts = torch.unique( - x, - return_inverse=True, - return_counts=True, - dim=2) + x, return_inverse=True, return_counts=True, dim=2 + ) self.assertEqual(expected_unique_dim2, x_unique) self.assertEqual(expected_inverse_dim2, x_inverse) self.assertEqual(expected_counts_dim2, x_counts) # test empty tensor x_unique, x_inverse, x_counts = torch.unique( - x_empty, - return_inverse=True, - return_counts=True, - dim=1) + x_empty, return_inverse=True, return_counts=True, dim=1 + ) self.assertEqual(expected_unique_empty, x_unique) self.assertEqual(expected_inverse_empty, x_inverse) self.assertEqual(expected_counts_empty, x_counts) @@ -631,10 +683,8 @@ def run_test(device, dtype): # test tensor with nan if dtype in floating_types_and(torch.float16, torch.bfloat16): x_unique, x_inverse, x_counts = torch.unique( - x_nan, - return_inverse=True, - return_counts=True, - dim=0) + x_nan, return_inverse=True, return_counts=True, dim=0 + ) self.assertEqual(expected_unique_nan, x_unique) self.assertEqual(expected_inverse_nan, x_inverse) self.assertEqual(expected_counts_nan, x_counts) @@ -643,10 +693,8 @@ def run_test(device, dtype): # Checking for runtime error, as this is the expected behaviour with self.assertRaises(RuntimeError): torch.unique( - x_ill_formed_empty, - return_inverse=True, - return_counts=True, - dim=1) + x_ill_formed_empty, return_inverse=True, return_counts=True, dim=1 + ) # test along dim2 with self.assertRaises(RuntimeError): @@ -654,46 +702,66 @@ def run_test(device, dtype): x_ill_formed_empty_another, return_inverse=True, return_counts=True, - dim=2) + dim=2, + ) # test consecutive version y = torch.tensor( - [[0, 1], - [0, 1], - [0, 1], - [1, 2], - [1, 2], - [3, 4], - [0, 1], - [0, 1], - [3, 4], - [1, 2]], + [ + [0, 1], + [0, 1], + [0, 1], + [1, 2], + [1, 2], + [3, 4], + [0, 1], + [0, 1], + [3, 4], + [1, 2], + ], dtype=dtype, - device=device + device=device, ) # test tensor with nan if dtype in floating_types_and(torch.float16, torch.bfloat16): - y_nan = torch.tensor([float("nan"), 0, 0, float("nan"), float("nan"), 1], dtype=dtype, device=device) + y_nan = torch.tensor( + [float("nan"), 0, 0, float("nan"), float("nan"), 1], + dtype=dtype, + device=device, + ) expected_y_unique = torch.tensor( - [[0, 1], - [1, 2], - [3, 4], - [0, 1], - [3, 4], - [1, 2]], + [[0, 1], [1, 2], [3, 4], [0, 1], [3, 4], [1, 2]], dtype=dtype, - device=device + device=device, + ) + expected_y_inverse = torch.tensor( + [0, 0, 0, 1, 1, 2, 3, 3, 4, 5], dtype=torch.int64, device=device + ) + expected_y_counts = torch.tensor( + [3, 2, 1, 2, 1, 1], dtype=torch.int64, device=device + ) + expected_y_inverse_bool = torch.tensor( + [0, 0, 0, 1, 1, 1, 2, 2, 3, 3], dtype=torch.int64, device=device + ) + expected_y_counts_bool = torch.tensor( + [3, 3, 2, 2], dtype=torch.int64, device=device ) - expected_y_inverse = torch.tensor([0, 0, 0, 1, 1, 2, 3, 3, 4, 5], dtype=torch.int64, device=device) - expected_y_counts = torch.tensor([3, 2, 1, 2, 1, 1], dtype=torch.int64, device=device) - expected_y_inverse_bool = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 3, 3], dtype=torch.int64, device=device) - expected_y_counts_bool = torch.tensor([3, 3, 2, 2], dtype=torch.int64, device=device) if dtype in floating_types_and(torch.float16, torch.bfloat16): - expected_y_unique_nan = torch.tensor([float("nan"), 0, float("nan"), float("nan"), 1], dtype=dtype, device=device) - expected_y_inverse_nan = torch.tensor([0, 1, 1, 2, 3, 4], dtype=torch.long, device=device) - expected_y_counts_nan = torch.tensor([1, 2, 1, 1, 1], dtype=torch.long, device=device) - - y_unique, y_inverse, y_counts = torch.unique_consecutive(y, return_inverse=True, return_counts=True, dim=0) + expected_y_unique_nan = torch.tensor( + [float("nan"), 0, float("nan"), float("nan"), 1], + dtype=dtype, + device=device, + ) + expected_y_inverse_nan = torch.tensor( + [0, 1, 1, 2, 3, 4], dtype=torch.long, device=device + ) + expected_y_counts_nan = torch.tensor( + [1, 2, 1, 1, 1], dtype=torch.long, device=device + ) + + y_unique, y_inverse, y_counts = torch.unique_consecutive( + y, return_inverse=True, return_counts=True, dim=0 + ) if x.dtype == torch.bool: self.assertEqual(expected_y_inverse_bool, y_inverse) self.assertEqual(expected_y_counts_bool, y_counts) @@ -704,23 +772,27 @@ def run_test(device, dtype): # test tensor with nan if dtype in floating_types_and(torch.float16, torch.bfloat16): y_unique, y_inverse, y_counts = torch.unique_consecutive( - y_nan, - return_inverse=True, - return_counts=True, - dim=0) + y_nan, return_inverse=True, return_counts=True, dim=0 + ) self.assertEqual(expected_y_unique_nan, y_unique) self.assertEqual(expected_y_inverse_nan, y_inverse) self.assertEqual(expected_y_counts_nan, y_counts) # Test dim is sorted same as NumPy with dims >= 3 - x = torch.tensor([[[[1, 0, 1, 0, 1, 1], - [0, 1, 1, 0, 1, 1]], - [[0, 1, 1, 0, 0, 1], - [0, 0, 0, 1, 0, 0]]], - [[[0, 1, 0, 1, 1, 1], - [0, 1, 1, 0, 1, 1]], - [[0, 0, 1, 1, 0, 1], - [1, 1, 0, 0, 0, 0]]]], dtype=dtype, device=device) + x = torch.tensor( + [ + [ + [[1, 0, 1, 0, 1, 1], [0, 1, 1, 0, 1, 1]], + [[0, 1, 1, 0, 0, 1], [0, 0, 0, 1, 0, 0]], + ], + [ + [[0, 1, 0, 1, 1, 1], [0, 1, 1, 0, 1, 1]], + [[0, 0, 1, 1, 0, 1], [1, 1, 0, 0, 0, 0]], + ], + ], + dtype=dtype, + device=device, + ) xn = x.cpu().numpy() for d in range(x.dim()): t = torch.unique(x, dim=d) @@ -750,15 +822,20 @@ def test_topk_noncontiguous_gpu(self, device): def _test_topk_dtype(self, device, dtype, integral, size): if integral: - a = torch.randint(torch.iinfo(dtype).min, torch.iinfo(dtype).max, - size=(size,), dtype=dtype, device=device) + a = torch.randint( + torch.iinfo(dtype).min, + torch.iinfo(dtype).max, + size=(size,), + dtype=dtype, + device=device, + ) else: a = torch.randn(size=(size,), dtype=dtype, device=device) - sort_topk = a.sort()[0][-(size // 2):].flip(0) + sort_topk = a.sort()[0][-(size // 2) :].flip(0) topk = a.topk(size // 2) - self.assertEqual(sort_topk, topk[0]) # check values - self.assertEqual(sort_topk, a[topk[1]]) # check indices + self.assertEqual(sort_topk, topk[0]) # check values + self.assertEqual(sort_topk, a[topk[1]]) # check indices @dtypes(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64) def test_topk_integral(self, device, dtype): @@ -770,7 +847,6 @@ def test_topk_integral(self, device, dtype): @dtypes(torch.bfloat16, torch.half) def test_topk_lower_precision(self, device, dtype): - small = 10 large = 4096 verylarge = 8192 # multi_block topk on cuda @@ -780,14 +856,20 @@ def test_topk_lower_precision(self, device, dtype): @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16)) @dtypes(torch.float, torch.double, torch.bfloat16, torch.half) def test_topk_nonfinite(self, device, dtype): - x = torch.tensor([float('nan'), float('inf'), 1e4, 0, -1e4, -float('inf')], device=device, dtype=dtype) + x = torch.tensor( + [float("nan"), float("inf"), 1e4, 0, -1e4, -float("inf")], + device=device, + dtype=dtype, + ) val, idx = x.topk(4) - expect = torch.tensor([float('nan'), float('inf'), 1e4, 0], device=device, dtype=dtype) + expect = torch.tensor( + [float("nan"), float("inf"), 1e4, 0], device=device, dtype=dtype + ) self.assertEqual(val, expect) self.assertEqual(idx, [0, 1, 2, 3]) val, idx = x.topk(4, largest=False) - expect = torch.tensor([-float('inf'), -1e4, 0, 1e4], device=device, dtype=dtype) + expect = torch.tensor([-float("inf"), -1e4, 0, 1e4], device=device, dtype=dtype) self.assertEqual(val, expect) self.assertEqual(idx, [5, 4, 3, 2]) @@ -796,13 +878,13 @@ def test_topk_4d(self, device): large = 8192 for size in (small, large): x = torch.ones(2, size, 2, 2, device=device) - x[:, 1, :, :] *= 2. + x[:, 1, :, :] *= 2.0 x[:, 10, :, :] *= 1.5 val, ind = torch.topk(x, k=2, dim=1) expected_ind = torch.ones(2, 2, 2, 2, dtype=torch.long, device=device) expected_ind[:, 1, :, :] = 10 expected_val = torch.ones(2, 2, 2, 2, device=device) - expected_val[:, 0, :, :] *= 2. + expected_val[:, 0, :, :] *= 2.0 expected_val[:, 1, :, :] *= 1.5 self.assertEqual(val, expected_val, atol=0, rtol=0) self.assertEqual(ind, expected_ind, atol=0, rtol=0) @@ -838,7 +920,17 @@ def _test_unique_scalar_empty(self, dtype, device, f): self.assertEqual(inverse, expected_inverse) self.assertEqual(counts, expected_counts) - def _test_unique_with_expects(self, device, dtype, f, x, expected_unique, expected_inverse, expected_counts, additional_shape): + def _test_unique_with_expects( + self, + device, + dtype, + f, + x, + expected_unique, + expected_inverse, + expected_counts, + additional_shape, + ): def ensure_tuple(x): if isinstance(x, torch.Tensor): return (x,) @@ -847,7 +939,9 @@ def ensure_tuple(x): for return_inverse in [True, False]: for return_counts in [True, False]: # test with expected - ret = ensure_tuple(f(x, return_inverse=return_inverse, return_counts=return_counts)) + ret = ensure_tuple( + f(x, return_inverse=return_inverse, return_counts=return_counts) + ) self.assertEqual(len(ret), 1 + int(return_inverse) + int(return_counts)) self.assertEqual(expected_unique, ret[0]) if return_inverse: @@ -858,7 +952,9 @@ def ensure_tuple(x): # tests per-element unique on a higher rank tensor. y = x.view(additional_shape) - y_unique, y_inverse, y_counts = f(y, return_inverse=True, return_counts=True) + y_unique, y_inverse, y_counts = f( + y, return_inverse=True, return_counts=True + ) self.assertEqual(expected_unique, y_unique) self.assertEqual(expected_inverse.view(additional_shape), y_inverse) self.assertEqual(expected_counts, y_counts) @@ -872,9 +968,17 @@ def ensure_tuple(x): return x if dtype is torch.bool: - x = torch.tensor([True, False, False, False, True, False, True, False], dtype=torch.bool, device=device) - expected_unique = torch.tensor([False, True], dtype=torch.bool, device=device) - expected_inverse = torch.tensor([1, 0, 0, 0, 1, 0, 1, 0], dtype=torch.long, device=device) + x = torch.tensor( + [True, False, False, False, True, False, True, False], + dtype=torch.bool, + device=device, + ) + expected_unique = torch.tensor( + [False, True], dtype=torch.bool, device=device + ) + expected_inverse = torch.tensor( + [1, 0, 0, 0, 1, 0, 1, 0], dtype=torch.long, device=device + ) expected_counts = torch.tensor([5, 3], dtype=torch.long, device=device) else: x = torch.tensor([1, 2, 3, 2, 8, 5, 2, 3], dtype=dtype, device=device) @@ -890,18 +994,29 @@ def ensure_tuple(x): x_sliced = torch.empty(x.size(0) * 2, dtype=dtype, device=device)[::2].copy_(x) xs = (x, x_sliced) for f, x in product(fs, xs): - self._test_unique_with_expects(device, dtype, f, x, expected_unique, expected_inverse, expected_counts, (2, 2, 2)) + self._test_unique_with_expects( + device, + dtype, + f, + x, + expected_unique, + expected_inverse, + expected_counts, + (2, 2, 2), + ) self._test_unique_scalar_empty(dtype, device, f) # test unsorted unique fs = ( lambda x, **kwargs: torch.unique(x, sorted=False, **kwargs), - lambda x, **kwargs: x.unique(sorted=False, **kwargs) + lambda x, **kwargs: x.unique(sorted=False, **kwargs), ) for f, x in product(fs, xs): self._test_unique_scalar_empty(dtype, device, f) for return_inverse, return_counts in product((True, False), repeat=2): - ret = ensure_tuple(f(x, return_inverse=return_inverse, return_counts=return_counts)) + ret = ensure_tuple( + f(x, return_inverse=return_inverse, return_counts=return_counts) + ) self.assertEqual(len(ret), 1 + int(return_inverse) + int(return_counts)) x_list = x.tolist() x_unique_list = ret[0].tolist() @@ -924,18 +1039,40 @@ def ensure_tuple(x): @dtypes(*all_types_and(torch.half, torch.bool)) def test_unique_consecutive(self, device, dtype): if dtype is torch.bool: - x = torch.tensor([True, False, False, False, True, True, False, False, False], dtype=torch.bool, device=device) - expected_unique = torch.tensor([True, False, True, False], dtype=torch.bool, device=device) - expected_inverse = torch.tensor([0, 1, 1, 1, 2, 2, 3, 3, 3], dtype=torch.long, device=device) - expected_counts = torch.tensor([1, 3, 2, 3], dtype=torch.long, device=device) + x = torch.tensor( + [True, False, False, False, True, True, False, False, False], + dtype=torch.bool, + device=device, + ) + expected_unique = torch.tensor( + [True, False, True, False], dtype=torch.bool, device=device + ) + expected_inverse = torch.tensor( + [0, 1, 1, 1, 2, 2, 3, 3, 3], dtype=torch.long, device=device + ) + expected_counts = torch.tensor( + [1, 3, 2, 3], dtype=torch.long, device=device + ) else: x = torch.tensor([1, 2, 2, 2, 5, 5, 2, 2, 3], dtype=dtype, device=device) expected_unique = torch.tensor([1, 2, 5, 2, 3], dtype=dtype, device=device) expected_inverse = torch.tensor([0, 1, 1, 1, 2, 2, 3, 3, 4], device=device) expected_counts = torch.tensor([1, 3, 2, 2, 1], device=device) - for f in [torch.unique_consecutive, lambda x, **kwargs: x.unique_consecutive(**kwargs)]: - self._test_unique_with_expects(device, dtype, f, x, expected_unique, expected_inverse, expected_counts, (3, 3)) + for f in [ + torch.unique_consecutive, + lambda x, **kwargs: x.unique_consecutive(**kwargs), + ]: + self._test_unique_with_expects( + device, + dtype, + f, + x, + expected_unique, + expected_inverse, + expected_counts, + (3, 3), + ) self._test_unique_scalar_empty(dtype, device, f) @dtypes(torch.double) @@ -991,7 +1128,7 @@ def test_kthvalue(self, device, dtype): self.assertEqual(x, x0, atol=0, rtol=0) # simple test case (with repetitions) - y = torch.tensor((3., 5, 4, 1, 1, 5), dtype=dtype, device=device) + y = torch.tensor((3.0, 5, 4, 1, 1, 5), dtype=dtype, device=device) self.assertEqual(torch.kthvalue(y, 3)[0], 3, atol=0, rtol=0) self.assertEqual(torch.kthvalue(y, 2)[0], 1, atol=0, rtol=0) @@ -1007,7 +1144,7 @@ def test_kthvalue(self, device, dtype): self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], atol=0, rtol=0) @dtypes(torch.float) - @onlyNativeDeviceTypes # Fails on XLA + @onlyNativeDeviceTypes # Fails on XLA def test_kthvalue_scalar(self, device, dtype): # Test scalar input (test case from https://github.com/pytorch/pytorch/issues/30818) # Tests that passing a scalar tensor or 1D tensor with 1 element work either way @@ -1029,7 +1166,9 @@ def assert_isin_equal(a, b): # multi-dim tensor, multi-dim tensor a = torch.arange(24, device=device, dtype=dtype).reshape([2, 3, 4]) - b = torch.tensor([[10, 20, 30], [0, 1, 3], [11, 22, 33]], device=device, dtype=dtype) + b = torch.tensor( + [[10, 20, 30], [0, 1, 3], [11, 22, 33]], device=device, dtype=dtype + ) assert_isin_equal(a, b) # zero-dim tensor @@ -1073,16 +1212,56 @@ def define_expected(lst, invert=False): c = torch.isin(a, b, assume_unique=True, invert=invert) self.assertEqual(c, ec) - a = torch.tensor([5, 4, 5, 3, 4, 4, 3, 4, 3, 5, 2, 1, 5, 5], device=device, dtype=dtype) + a = torch.tensor( + [5, 4, 5, 3, 4, 4, 3, 4, 3, 5, 2, 1, 5, 5], + device=device, + dtype=dtype, + ) b = torch.tensor([2, 3, 4] * mult, device=device, dtype=dtype) - ec = define_expected([False, True, False, True, True, True, True, True, True, - False, True, False, False, False], invert=invert) + ec = define_expected( + [ + False, + True, + False, + True, + True, + True, + True, + True, + True, + False, + True, + False, + False, + False, + ], + invert=invert, + ) c = torch.isin(a, b, invert=invert) self.assertEqual(c, ec) - b = torch.tensor([2, 3, 4] * mult + [5, 5, 4] * mult, device=device, dtype=dtype) - ec = define_expected([True, True, True, True, True, True, True, True, True, True, - True, False, True, True], invert=invert) + b = torch.tensor( + [2, 3, 4] * mult + [5, 5, 4] * mult, device=device, dtype=dtype + ) + ec = define_expected( + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + True, + True, + ], + invert=invert, + ) c = torch.isin(a, b, invert=invert) self.assertEqual(c, ec) @@ -1108,12 +1287,14 @@ def define_expected(lst, invert=False): for assume_unique in [False, True]: a = torch.arange(6, device=device, dtype=dtype).reshape([2, 3]) b = torch.arange(3, 30, device=device, dtype=dtype) - ec = define_expected([[False, False, False], [True, True, True]], invert=invert) + ec = define_expected( + [[False, False, False], [True, True, True]], invert=invert + ) c = torch.isin(a, b, invert=invert, assume_unique=assume_unique) self.assertEqual(c, ec) def test_isin_different_dtypes(self, device): - supported_types = all_types() if device == 'cpu' else all_types_and(torch.half) + supported_types = all_types() if device == "cpu" else all_types_and(torch.half) for mult in [1, 10]: for assume_unique in [False, True]: for dtype1, dtype2 in product(supported_types, supported_types): @@ -1127,18 +1308,18 @@ def test_isin_different_dtypes(self, device): @dtypes(*all_types()) def test_isin_different_devices(self, device, dtype): a = torch.arange(6, device=device, dtype=dtype).reshape([2, 3]) - b = torch.arange(3, 30, device='cpu', dtype=dtype) + b = torch.arange(3, 30, device="cpu", dtype=dtype) with self.assertRaises(RuntimeError): torch.isin(a, b) - c = torch.arange(6, device='cpu', dtype=dtype).reshape([2, 3]) + c = torch.arange(6, device="cpu", dtype=dtype).reshape([2, 3]) d = torch.arange(3, 30, device=device, dtype=dtype) with self.assertRaises(RuntimeError): torch.isin(c, d) @dtypes(*integral_types()) def test_sort_overflow(self, device, dtype): - " Regression test for https://github.com/pytorch/pytorch/issues/111189 " + "Regression test for https://github.com/pytorch/pytorch/issues/111189" prev_num_threads = torch.get_num_threads() try: low = 0 if dtype == torch.uint8 else -1 @@ -1153,5 +1334,5 @@ def test_sort_overflow(self, device, dtype): instantiate_device_type_tests(TestSortAndSelect, globals()) -if __name__ == '__main__': +if __name__ == "__main__": run_tests() From 879d01afcb3a48879c0bbc6cbf0f33d4a1e4b00f Mon Sep 17 00:00:00 2001 From: "Andrew M. James" Date: Tue, 4 Jun 2024 17:56:39 +0000 Subject: [PATCH 363/706] [dynamo][numpy] Add unsigned integer dtypes (#125717) We should support these to whatever extent we can. They corresponding `torch.uint` types are defined, so I don't see an issue with generating the various casting rules and allowing them to trace. Pull Request resolved: https://github.com/pytorch/pytorch/pull/125717 Approved by: https://github.com/lezcano --- ...Histogram.test_unsigned_monotonicity_check | 0 ...omplex.test_sort_real_type_in_H_type_out_F | 0 test/test_numpy_interop.py | 7 +- .../numpy_tests/core/test_getlimits.py | 24 +- .../numpy_tests/core/test_scalarmath.py | 5 +- test/torch_np/test_dtype.py | 2 +- torch/_numpy/_casting_dicts.py | 488 +++++++++++++++++- torch/_numpy/_dtypes.py | 21 +- 8 files changed, 536 insertions(+), 11 deletions(-) delete mode 100644 test/dynamo_expected_failures/TestHistogram.test_unsigned_monotonicity_check delete mode 100644 test/dynamo_expected_failures/TestSortComplex.test_sort_real_type_in_H_type_out_F diff --git a/test/dynamo_expected_failures/TestHistogram.test_unsigned_monotonicity_check b/test/dynamo_expected_failures/TestHistogram.test_unsigned_monotonicity_check deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestSortComplex.test_sort_real_type_in_H_type_out_F b/test/dynamo_expected_failures/TestSortComplex.test_sort_real_type_in_H_type_out_F deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/test_numpy_interop.py b/test/test_numpy_interop.py index 940938c79dde..bf81701f37dc 100644 --- a/test/test_numpy_interop.py +++ b/test/test_numpy_interop.py @@ -476,13 +476,18 @@ def test_multiplication_numpy_scalar(self, device) -> None: self.assertTrue(r2.requires_grad) @onlyCPU - def test_parse_numpy_int(self, device): + @skipIfTorchDynamo() + def test_parse_numpy_int_overflow(self, device): + # assertRaises uses a try-except which dynamo has issues with # Only concrete class can be given where "Type[number[_64Bit]]" is expected self.assertRaisesRegex( RuntimeError, "(Overflow|an integer is required)", lambda: torch.mean(torch.randn(1, 1), np.uint64(-1)), ) # type: ignore[call-overload] + + @onlyCPU + def test_parse_numpy_int(self, device): # https://github.com/pytorch/pytorch/issues/29252 for nptype in [np.int16, np.int8, np.uint8, np.int32, np.int64]: scalar = 3 diff --git a/test/torch_np/numpy_tests/core/test_getlimits.py b/test/torch_np/numpy_tests/core/test_getlimits.py index ab5b08319db6..3be8bc2619ab 100644 --- a/test/torch_np/numpy_tests/core/test_getlimits.py +++ b/test/torch_np/numpy_tests/core/test_getlimits.py @@ -8,14 +8,17 @@ # from numpy.core.getlimits import _discovered_machar, _float_ma -from unittest import skipIf +from unittest import expectedFailure as xfail, skipIf import numpy from pytest import raises as assert_raises from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, run_tests, + subtest, TEST_WITH_TORCHDYNAMO, TestCase, xpassIfTorchDynamo, @@ -109,6 +112,7 @@ def test_basic_missing(self): getattr(finfo(dt), attr) +@instantiate_parametrized_tests class TestIinfo(TestCase): def test_basic(self): dts = list( @@ -129,11 +133,19 @@ def test_basic(self): with assert_raises((TypeError, ValueError)): iinfo("f4") - def test_unsigned_max(self): - types = np.sctypes["uint"] - for T in types: - max_calculated = T(0) - T(1) - assert_equal(iinfo(T).max, max_calculated) + @parametrize( + "T", + [ + np.uint8, + # xfail: unsupported add (uint[16,32,64]) + subtest(np.uint16, decorators=[xfail]), + subtest(np.uint32, decorators=[xfail]), + subtest(np.uint64, decorators=[xfail]), + ], + ) + def test_unsigned_max(self, T): + max_calculated = T(0) - T(1) + assert_equal(iinfo(T).max, max_calculated) class TestRepr(TestCase): diff --git a/test/torch_np/numpy_tests/core/test_scalarmath.py b/test/torch_np/numpy_tests/core/test_scalarmath.py index 8099ca8c4c32..d86595d9d3cc 100644 --- a/test/torch_np/numpy_tests/core/test_scalarmath.py +++ b/test/torch_np/numpy_tests/core/test_scalarmath.py @@ -732,13 +732,16 @@ def test_numpy_abs(self, dtype): @instantiate_parametrized_tests class TestBitShifts(TestCase): - @parametrize("type_code", np.typecodes["Integer"] + "B") + @parametrize("type_code", np.typecodes["AllInteger"]) @parametrize("op", [operator.rshift, operator.lshift]) def test_shift_all_bits(self, type_code, op): """Shifts where the shift amount is the width of the type or wider""" # gh-2449 dt = np.dtype(type_code) nbits = dt.itemsize * 8 + if dt in (np.dtype(np.uint64), np.dtype(np.uint32), np.dtype(np.uint16)): + raise SkipTest("NYI: bitshift uint64") + for val in [5, -5]: for shift in [nbits, nbits + 4]: val_scl = np.array(val).astype(dt)[()] diff --git a/test/torch_np/test_dtype.py b/test/torch_np/test_dtype.py index 42866adbe5c2..e288e54286e7 100644 --- a/test/torch_np/test_dtype.py +++ b/test/torch_np/test_dtype.py @@ -18,7 +18,7 @@ dtype_names = [ "bool_", *[f"int{w}" for w in [8, 16, 32, 64]], - "uint8", + *[f"uint{w}" for w in [8, 16, 32, 64]], *[f"float{w}" for w in [16, 32, 64]], *[f"complex{w}" for w in [64, 128]], ] diff --git a/torch/_numpy/_casting_dicts.py b/torch/_numpy/_casting_dicts.py index 513e73ef2efe..b30ce7c55604 100644 --- a/torch/_numpy/_casting_dicts.py +++ b/torch/_numpy/_casting_dicts.py @@ -3,7 +3,7 @@ import torch # These two dicts are autogenerated with autogen/gen_dtypes.py, -# using numpy version 1.23.5. +# using numpy version 1.24.3. _can_cast_dict = { "no": { @@ -14,6 +14,9 @@ torch.complex64: False, torch.complex128: False, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -27,6 +30,9 @@ torch.complex64: False, torch.complex128: False, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -40,6 +46,9 @@ torch.complex64: False, torch.complex128: False, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -53,6 +62,9 @@ torch.complex64: True, torch.complex128: False, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -66,6 +78,9 @@ torch.complex64: False, torch.complex128: True, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -79,6 +94,57 @@ torch.complex64: False, torch.complex128: False, torch.uint8: True, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.uint16: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.uint16: True, + torch.uint32: False, + torch.uint64: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.uint32: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.uint16: False, + torch.uint32: True, + torch.uint64: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.uint64: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: True, torch.int8: False, torch.int16: False, torch.int32: False, @@ -92,6 +158,9 @@ torch.complex64: False, torch.complex128: False, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: True, torch.int16: False, torch.int32: False, @@ -105,6 +174,9 @@ torch.complex64: False, torch.complex128: False, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: True, torch.int32: False, @@ -118,6 +190,9 @@ torch.complex64: False, torch.complex128: False, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: True, @@ -131,6 +206,9 @@ torch.complex64: False, torch.complex128: False, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -144,6 +222,9 @@ torch.complex64: False, torch.complex128: False, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -159,6 +240,9 @@ torch.complex64: False, torch.complex128: False, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -172,6 +256,9 @@ torch.complex64: False, torch.complex128: False, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -185,6 +272,9 @@ torch.complex64: False, torch.complex128: False, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -198,6 +288,9 @@ torch.complex64: True, torch.complex128: False, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -211,6 +304,9 @@ torch.complex64: False, torch.complex128: True, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -224,6 +320,57 @@ torch.complex64: False, torch.complex128: False, torch.uint8: True, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.uint16: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.uint16: True, + torch.uint32: False, + torch.uint64: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.uint32: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.uint16: False, + torch.uint32: True, + torch.uint64: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.uint64: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: True, torch.int8: False, torch.int16: False, torch.int32: False, @@ -237,6 +384,9 @@ torch.complex64: False, torch.complex128: False, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: True, torch.int16: False, torch.int32: False, @@ -250,6 +400,9 @@ torch.complex64: False, torch.complex128: False, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: True, torch.int32: False, @@ -263,6 +416,9 @@ torch.complex64: False, torch.complex128: False, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: True, @@ -276,6 +432,9 @@ torch.complex64: False, torch.complex128: False, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -289,6 +448,9 @@ torch.complex64: False, torch.complex128: False, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -304,6 +466,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -317,6 +482,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -330,6 +498,9 @@ torch.complex64: False, torch.complex128: True, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -343,6 +514,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -356,6 +530,9 @@ torch.complex64: False, torch.complex128: True, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -369,12 +546,63 @@ torch.complex64: True, torch.complex128: True, torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, torch.int8: False, torch.int16: True, torch.int32: True, torch.int64: True, torch.bool: False, }, + torch.uint16: { + torch.float16: False, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: False, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, + torch.int8: False, + torch.int16: False, + torch.int32: True, + torch.int64: True, + torch.bool: False, + }, + torch.uint32: { + torch.float16: False, + torch.float32: False, + torch.float64: True, + torch.complex64: False, + torch.complex128: True, + torch.uint8: False, + torch.uint16: False, + torch.uint32: True, + torch.uint64: True, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: True, + torch.bool: False, + }, + torch.uint64: { + torch.float16: False, + torch.float32: False, + torch.float64: True, + torch.complex64: False, + torch.complex128: True, + torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: True, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, torch.int8: { torch.float16: True, torch.float32: True, @@ -382,6 +610,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: True, torch.int16: True, torch.int32: True, @@ -395,6 +626,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: True, torch.int32: True, @@ -408,6 +642,9 @@ torch.complex64: False, torch.complex128: True, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: True, @@ -421,6 +658,9 @@ torch.complex64: False, torch.complex128: True, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -434,6 +674,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, torch.int8: True, torch.int16: True, torch.int32: True, @@ -449,6 +692,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -462,6 +708,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -475,6 +724,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -488,6 +740,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -501,6 +756,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -514,6 +772,57 @@ torch.complex64: True, torch.complex128: True, torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: False, + }, + torch.uint16: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: False, + }, + torch.uint32: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: False, + }, + torch.uint64: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, torch.int8: True, torch.int16: True, torch.int32: True, @@ -527,6 +836,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: True, torch.int16: True, torch.int32: True, @@ -540,6 +852,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: True, torch.int16: True, torch.int32: True, @@ -553,6 +868,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: True, torch.int16: True, torch.int32: True, @@ -566,6 +884,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: True, torch.int16: True, torch.int32: True, @@ -579,6 +900,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, torch.int8: True, torch.int16: True, torch.int32: True, @@ -594,6 +918,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, torch.int8: True, torch.int16: True, torch.int32: True, @@ -607,6 +934,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, torch.int8: True, torch.int16: True, torch.int32: True, @@ -620,6 +950,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, torch.int8: True, torch.int16: True, torch.int32: True, @@ -633,6 +966,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, torch.int8: True, torch.int16: True, torch.int32: True, @@ -646,6 +982,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, torch.int8: True, torch.int16: True, torch.int32: True, @@ -659,6 +998,57 @@ torch.complex64: True, torch.complex128: True, torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: True, + }, + torch.uint16: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: True, + }, + torch.uint32: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: True, + }, + torch.uint64: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, torch.int8: True, torch.int16: True, torch.int32: True, @@ -672,6 +1062,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, torch.int8: True, torch.int16: True, torch.int32: True, @@ -685,6 +1078,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, torch.int8: True, torch.int16: True, torch.int32: True, @@ -698,6 +1094,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, torch.int8: True, torch.int16: True, torch.int32: True, @@ -711,6 +1110,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, torch.int8: True, torch.int16: True, torch.int32: True, @@ -724,6 +1126,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, torch.int8: True, torch.int16: True, torch.int32: True, @@ -742,6 +1147,9 @@ torch.complex64: torch.complex64, torch.complex128: torch.complex128, torch.uint8: torch.float16, + torch.uint16: torch.float32, + torch.uint32: torch.float64, + torch.uint64: torch.float64, torch.int8: torch.float16, torch.int16: torch.float32, torch.int32: torch.float64, @@ -755,6 +1163,9 @@ torch.complex64: torch.complex64, torch.complex128: torch.complex128, torch.uint8: torch.float32, + torch.uint16: torch.float32, + torch.uint32: torch.float64, + torch.uint64: torch.float64, torch.int8: torch.float32, torch.int16: torch.float32, torch.int32: torch.float64, @@ -768,6 +1179,9 @@ torch.complex64: torch.complex128, torch.complex128: torch.complex128, torch.uint8: torch.float64, + torch.uint16: torch.float64, + torch.uint32: torch.float64, + torch.uint64: torch.float64, torch.int8: torch.float64, torch.int16: torch.float64, torch.int32: torch.float64, @@ -781,6 +1195,9 @@ torch.complex64: torch.complex64, torch.complex128: torch.complex128, torch.uint8: torch.complex64, + torch.uint16: torch.complex64, + torch.uint32: torch.complex128, + torch.uint64: torch.complex128, torch.int8: torch.complex64, torch.int16: torch.complex64, torch.int32: torch.complex128, @@ -794,6 +1211,9 @@ torch.complex64: torch.complex128, torch.complex128: torch.complex128, torch.uint8: torch.complex128, + torch.uint16: torch.complex128, + torch.uint32: torch.complex128, + torch.uint64: torch.complex128, torch.int8: torch.complex128, torch.int16: torch.complex128, torch.int32: torch.complex128, @@ -807,12 +1227,63 @@ torch.complex64: torch.complex64, torch.complex128: torch.complex128, torch.uint8: torch.uint8, + torch.uint16: torch.uint16, + torch.uint32: torch.uint32, + torch.uint64: torch.uint64, torch.int8: torch.int16, torch.int16: torch.int16, torch.int32: torch.int32, torch.int64: torch.int64, torch.bool: torch.uint8, }, + torch.uint16: { + torch.float16: torch.float32, + torch.float32: torch.float32, + torch.float64: torch.float64, + torch.complex64: torch.complex64, + torch.complex128: torch.complex128, + torch.uint8: torch.uint16, + torch.uint16: torch.uint16, + torch.uint32: torch.uint32, + torch.uint64: torch.uint64, + torch.int8: torch.int32, + torch.int16: torch.int32, + torch.int32: torch.int32, + torch.int64: torch.int64, + torch.bool: torch.uint16, + }, + torch.uint32: { + torch.float16: torch.float64, + torch.float32: torch.float64, + torch.float64: torch.float64, + torch.complex64: torch.complex128, + torch.complex128: torch.complex128, + torch.uint8: torch.uint32, + torch.uint16: torch.uint32, + torch.uint32: torch.uint32, + torch.uint64: torch.uint64, + torch.int8: torch.int64, + torch.int16: torch.int64, + torch.int32: torch.int64, + torch.int64: torch.int64, + torch.bool: torch.uint32, + }, + torch.uint64: { + torch.float16: torch.float64, + torch.float32: torch.float64, + torch.float64: torch.float64, + torch.complex64: torch.complex128, + torch.complex128: torch.complex128, + torch.uint8: torch.uint64, + torch.uint16: torch.uint64, + torch.uint32: torch.uint64, + torch.uint64: torch.uint64, + torch.int8: torch.float64, + torch.int16: torch.float64, + torch.int32: torch.float64, + torch.int64: torch.float64, + torch.bool: torch.uint64, + }, torch.int8: { torch.float16: torch.float16, torch.float32: torch.float32, @@ -820,6 +1291,9 @@ torch.complex64: torch.complex64, torch.complex128: torch.complex128, torch.uint8: torch.int16, + torch.uint16: torch.int32, + torch.uint32: torch.int64, + torch.uint64: torch.float64, torch.int8: torch.int8, torch.int16: torch.int16, torch.int32: torch.int32, @@ -833,6 +1307,9 @@ torch.complex64: torch.complex64, torch.complex128: torch.complex128, torch.uint8: torch.int16, + torch.uint16: torch.int32, + torch.uint32: torch.int64, + torch.uint64: torch.float64, torch.int8: torch.int16, torch.int16: torch.int16, torch.int32: torch.int32, @@ -846,6 +1323,9 @@ torch.complex64: torch.complex128, torch.complex128: torch.complex128, torch.uint8: torch.int32, + torch.uint16: torch.int32, + torch.uint32: torch.int64, + torch.uint64: torch.float64, torch.int8: torch.int32, torch.int16: torch.int32, torch.int32: torch.int32, @@ -859,6 +1339,9 @@ torch.complex64: torch.complex128, torch.complex128: torch.complex128, torch.uint8: torch.int64, + torch.uint16: torch.int64, + torch.uint32: torch.int64, + torch.uint64: torch.float64, torch.int8: torch.int64, torch.int16: torch.int64, torch.int32: torch.int64, @@ -872,6 +1355,9 @@ torch.complex64: torch.complex64, torch.complex128: torch.complex128, torch.uint8: torch.uint8, + torch.uint16: torch.uint16, + torch.uint32: torch.uint32, + torch.uint64: torch.uint64, torch.int8: torch.int8, torch.int16: torch.int16, torch.int32: torch.int32, diff --git a/torch/_numpy/_dtypes.py b/torch/_numpy/_dtypes.py index f8b8f4f722be..27799adaf563 100644 --- a/torch/_numpy/_dtypes.py +++ b/torch/_numpy/_dtypes.py @@ -113,6 +113,24 @@ class uint8(unsignedinteger): torch_dtype = torch.uint8 +class uint16(unsignedinteger): + name = "uint16" + typecode = "H" + torch_dtype = torch.uint16 + + +class uint32(signedinteger): + name = "uint32" + typecode = "I" + torch_dtype = torch.uint32 + + +class uint64(signedinteger): + name = "uint64" + typecode = "L" + torch_dtype = torch.uint64 + + # floating point @@ -160,6 +178,7 @@ class bool_(generic): "byte": int8, "short": int16, "longlong": int64, # XXX: is this correct? + "ulonglong": uint64, "ubyte": uint8, "half": float16, "single": float32, @@ -180,7 +199,7 @@ class bool_(generic): # cf tests/core/test_scalar_methods.py sctypes = { "int": [int8, int16, int32, int64], - "uint": [uint8], + "uint": [uint8, uint16, uint32, uint64], "float": [float16, float32, float64], "complex": [complex64, complex128], "others": [bool_], From 8992141dbadff396dbef1aa9a9b0e3550957e841 Mon Sep 17 00:00:00 2001 From: Huy Do Date: Wed, 5 Jun 2024 14:44:00 +0000 Subject: [PATCH 364/706] Restore MPS testing on MacOS 13 and m2 metal (#127853) The runners are ready now https://github.com/organizations/pytorch/settings/actions/runners?qr=label%3Amacos-m1-13, we want to keep some MacOS 13 runner for mps coverage until MacOS 15 is out. This also fixes the `macos-m2-14` mistake from https://github.com/pytorch/pytorch/pull/127582. The current `macos-m2-14` runner is on 14.2 while our `macos-m1-14` has 14.4. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127853 Approved by: https://github.com/malfet --- .github/workflows/mac-mps.yml | 5 ++++- .github/workflows/trunk.yml | 1 + test/dynamo/test_dynamic_shapes.py | 3 +++ test/test_mps.py | 16 ++++++++++++++++ .../_internal/common_methods_invocations.py | 2 -- 5 files changed, 24 insertions(+), 3 deletions(-) diff --git a/.github/workflows/mac-mps.yml b/.github/workflows/mac-mps.yml index 53504b6133f6..06521f20c49e 100644 --- a/.github/workflows/mac-mps.yml +++ b/.github/workflows/mac-mps.yml @@ -23,9 +23,12 @@ jobs: build-generates-artifacts: true # To match the one pre-installed in the m1 runners python-version: 3.9.12 + # The runner macos-m2-14 is not a typo, it's a custom runner that is different + # than our AWS macos-m1-14 runners test-matrix: | { include: [ - { config: "mps", shard: 1, num_shards: 1, runner: "macos-m1-14" }, + { config: "mps", shard: 1, num_shards: 1, runner: "macos-m1-13" }, + { config: "mps", shard: 1, num_shards: 1, runner: "macos-m2-14" }, ]} macos-py3-arm64-mps-test: diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index a91238fa2c9b..77f54f937ad0 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -172,6 +172,7 @@ jobs: python-version: 3.9.12 test-matrix: | { include: [ + { config: "mps", shard: 1, num_shards: 1, runner: "macos-m1-13" }, { config: "mps", shard: 1, num_shards: 1, runner: "macos-m1-14" }, ]} diff --git a/test/dynamo/test_dynamic_shapes.py b/test/dynamo/test_dynamic_shapes.py index 175ed573391b..0bead6e47e48 100644 --- a/test/dynamo/test_dynamic_shapes.py +++ b/test/dynamo/test_dynamic_shapes.py @@ -104,6 +104,9 @@ def make_dynamic_cls(cls): DynamicShapesExportTests.test_retracibility_dict_container_inp_out_dynamic_shapes = slowTest( # noqa: F821 DynamicShapesExportTests.test_retracibility_dict_container_inp_out_dynamic_shapes # noqa: F821 ) +DynamicShapesExportTests.test_retracibility_nested_list_out_dynamic_shapes = slowTest( # noqa: F821 + DynamicShapesExportTests.test_retracibility_nested_list_out_dynamic_shapes # noqa: F821 +) if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/test_mps.py b/test/test_mps.py index 8c3bbf4b7bcf..32e7a2a08861 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -193,6 +193,9 @@ def mps_ops_grad_modifier(ops): # Failures due to lack of implementation of downstream functions on MPS backend # TODO: remove these once downstream function 'aten::_linalg_svd.U' have been implemented 'linalg.matrix_rank': None, + + # Exception: Caused by sample input at index 3 on MPS + 'nn.functional.conv3d': [torch.float32], } def addDecorator(op, d) -> None: @@ -667,6 +670,11 @@ def mps_ops_modifier(ops): 'special.polygammaspecial_polygamma_n_0': [torch.float32, torch.int16, torch.int8], } + MACOS_BEFORE_14_4_XFAILLIST = { + # These ops work fine in 14.4 but fail in 14.2 or 13.x + 'fft.hfft2': [torch.complex64], + } + # Those ops are not expected to work UNIMPLEMENTED_XFAILLIST = { # Failures due to lack of op implementation on MPS backend @@ -1020,6 +1028,9 @@ def mps_ops_modifier(ops): # Unsupported # input types 'tensor<1x3x9x9xf16>' and 'tensor<1xf32>' are not broadcast compatible 'nn.functional.avg_pool2d': [torch.float16], + + # This doesn't work on M1, but is partially working on M2 with the exception of torch.float16 + 'nn.functional.conv3d': None, } def addDecorator(op, d) -> None: @@ -1040,6 +1051,11 @@ def addDecorator(op, d) -> None: unittest.expectedFailure, dtypes=xfaillist[key])) + if key in MACOS_BEFORE_14_4_XFAILLIST and (product_version < 14.4): + addDecorator(op, DecorateInfo( + unittest.expectedFailure, + dtypes=MACOS_BEFORE_14_4_XFAILLIST[key])) + if key in MACOS_BEFORE_13_3_XFAILLIST and (torch.backends.mps.is_macos13_or_newer() and product_version < 13.3): addDecorator(op, DecorateInfo( unittest.expectedFailure, diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 832a75e6639a..151210cf9f53 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -14987,8 +14987,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1): # RuntimeError: UNSUPPORTED DTYPE: complex DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness', dtypes=(torch.complex64, torch.complex128)), - # RuntimeError: Conv3D is not supported on MPS - DecorateInfo(unittest.expectedFailure, 'TestConsistency'), # AssertionError: Tensor-likes are not close! # break slow tests DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_compare_cpu'), From d48c25c7d1dfc351494bc00c1ab8a9e99493d95f Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Wed, 5 Jun 2024 14:58:17 +0000 Subject: [PATCH 365/706] [BE] Fix missing-prototypes errors in Metal backend (#127994) By declaring a bunch of functions static. Removed `USE_PYTORCH_METAL` from list of flags that suppress `-Werror=missing-prototypes`. This will prevent regressions like the ones reported in https://github.com/pytorch/pytorch/issues/127942 to sneak past CI, that builds PyTorch with Metal support. Use nested namespaces Remove spurious semicolon after TORCH_LIBRARY declaration. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127994 Approved by: https://github.com/Skylion007, https://github.com/ZainRizvi --- aten/src/ATen/native/metal/MetalAten.mm | 16 +++++----- aten/src/ATen/native/metal/MetalConvParams.h | 8 ++--- aten/src/ATen/native/metal/MetalDevice.h | 8 ++--- aten/src/ATen/native/metal/MetalNeuronType.h | 8 ++--- .../ATen/native/metal/MetalPrepackOpContext.h | 8 ++--- .../native/metal/MetalPrepackOpRegister.cpp | 16 ++++------ .../native/metal/MetalTensorImplStorage.h | 8 ++--- aten/src/ATen/native/metal/MetalTensorUtils.h | 8 ++--- .../ATen/native/metal/mpscnn/MPSCNNUtils.h | 10 ++----- .../ATen/native/metal/mpscnn/MPSCNNUtils.mm | 16 ++++------ aten/src/ATen/native/metal/ops/MetalAddmm.mm | 14 ++++----- .../metal/ops/MetalBinaryElementwise.mm | 30 ++++++++----------- aten/src/ATen/native/metal/ops/MetalChunk.mm | 12 +++----- aten/src/ATen/native/metal/ops/MetalClamp.mm | 16 ++++------ aten/src/ATen/native/metal/ops/MetalConcat.mm | 14 ++++----- .../ATen/native/metal/ops/MetalConvolution.h | 8 ++--- .../ATen/native/metal/ops/MetalConvolution.mm | 10 ++----- aten/src/ATen/native/metal/ops/MetalCopy.h | 8 ++--- aten/src/ATen/native/metal/ops/MetalCopy.mm | 12 +++----- .../ATen/native/metal/ops/MetalHardshrink.mm | 14 ++++----- .../ATen/native/metal/ops/MetalHardswish.mm | 14 ++++----- .../ATen/native/metal/ops/MetalLeakyReLU.mm | 14 ++++----- .../src/ATen/native/metal/ops/MetalNeurons.mm | 24 +++++++-------- .../src/ATen/native/metal/ops/MetalPadding.mm | 12 +++----- .../src/ATen/native/metal/ops/MetalPooling.mm | 12 +++----- aten/src/ATen/native/metal/ops/MetalReduce.mm | 10 ++----- .../src/ATen/native/metal/ops/MetalReshape.mm | 18 +++++------ .../src/ATen/native/metal/ops/MetalSoftmax.mm | 12 +++----- .../ATen/native/metal/ops/MetalTranspose.mm | 12 +++----- .../metal/ops/MetalUpsamplingNearest.mm | 10 ++----- caffe2/CMakeLists.txt | 2 +- 31 files changed, 131 insertions(+), 253 deletions(-) diff --git a/aten/src/ATen/native/metal/MetalAten.mm b/aten/src/ATen/native/metal/MetalAten.mm index a1ee8e6f8ded..ec6156573e06 100644 --- a/aten/src/ATen/native/metal/MetalAten.mm +++ b/aten/src/ATen/native/metal/MetalAten.mm @@ -6,10 +6,9 @@ #include namespace at { -namespace native { -namespace metal { +namespace native::metal { -at::Tensor& copy_from_metal_(at::Tensor& dst, const at::Tensor& src) { +static Tensor& copy_from_metal_(Tensor& dst, const Tensor& src) { TORCH_INTERNAL_ASSERT( src.device().type() == DeviceType::Metal, "copy_from_metal input tensor's device is not metal"); @@ -34,7 +33,7 @@ return dst; } -at::Tensor& copy_to_metal_(at::Tensor& dst, const at::Tensor& src) { +static Tensor& copy_to_metal_(Tensor& dst, const Tensor& src) { TORCH_INTERNAL_ASSERT( dst.device().type() == DeviceType::Metal, "copy_to_metal_ output tensor's device is not metal"); @@ -54,7 +53,7 @@ return dst; } -at::Tensor& metal_copy_impl_(at::Tensor& dst, const at::Tensor& src) { +static Tensor& metal_copy_impl_(Tensor& dst, const Tensor& src) { if (src.device().type() == at::kMetal && dst.device().type() == at::kCPU) { return copy_from_metal_(dst, src); } @@ -69,7 +68,7 @@ #pragma mark - ATen Ops -Tensor empty( +static Tensor empty( c10::SymIntArrayRef sym_size, optional dtype, optional layout, @@ -88,7 +87,7 @@ Tensor empty( std::move(mt), at::device(at::kMetal).dtype(dtype)); }; -at::Tensor empty_strided( +static Tensor empty_strided( IntArrayRef size, IntArrayRef stride, optional dtype, @@ -109,8 +108,7 @@ Tensor empty( m.impl(TORCH_SELECTIVE_NAME("aten::empty_strided"), TORCH_FN(empty_strided)); } -} // namespace metal -} // namespace native +} // namespace native::metal struct MetalImpl : public at::metal::MetalInterface { bool is_metal_available() const override { diff --git a/aten/src/ATen/native/metal/MetalConvParams.h b/aten/src/ATen/native/metal/MetalConvParams.h index 7b0bfc9670a1..55a8ea657e72 100644 --- a/aten/src/ATen/native/metal/MetalConvParams.h +++ b/aten/src/ATen/native/metal/MetalConvParams.h @@ -3,9 +3,7 @@ #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { struct Conv2DParams final { Conv2DParams() {} @@ -46,8 +44,6 @@ struct Conv2DParams final { int64_t OH; // output height }; -} // namespace metal -} // namespace native -} // namespace at +} // namespace at::native::metal #endif /* MetalConvParams_h */ diff --git a/aten/src/ATen/native/metal/MetalDevice.h b/aten/src/ATen/native/metal/MetalDevice.h index 29d34246cc1b..42c3ae43cd02 100644 --- a/aten/src/ATen/native/metal/MetalDevice.h +++ b/aten/src/ATen/native/metal/MetalDevice.h @@ -5,9 +5,7 @@ #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { struct MetalDeviceInfo { std::string name; @@ -42,8 +40,6 @@ static inline MetalDeviceInfo createDeviceInfo(id device) { return device_info; } -} -} -} +} // namespace at::native::metal #endif diff --git a/aten/src/ATen/native/metal/MetalNeuronType.h b/aten/src/ATen/native/metal/MetalNeuronType.h index c5cb0b99502c..e1cada24a7fd 100644 --- a/aten/src/ATen/native/metal/MetalNeuronType.h +++ b/aten/src/ATen/native/metal/MetalNeuronType.h @@ -6,9 +6,7 @@ #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { enum class NeuronType { None, @@ -66,8 +64,6 @@ static inline MPSNNNeuronDescriptor* neuronDescriptor(NeuronType type) { } } -} -} -} +} // namespace at::native::metal #endif /* MetalNeuronType_h */ diff --git a/aten/src/ATen/native/metal/MetalPrepackOpContext.h b/aten/src/ATen/native/metal/MetalPrepackOpContext.h index 4481c879eec2..a484812d6874 100644 --- a/aten/src/ATen/native/metal/MetalPrepackOpContext.h +++ b/aten/src/ATen/native/metal/MetalPrepackOpContext.h @@ -3,9 +3,7 @@ #include #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { using SerializationTypeConv2dPrePack = std::tuple< Tensor, @@ -197,6 +195,4 @@ class LinearOpContext : public torch::jit::CustomClassHolder { std::function releaseCallback_ = nullptr; }; -} // namespace metal -} // namespace native -} // namespace at +} // namespace at::native::metal diff --git a/aten/src/ATen/native/metal/MetalPrepackOpRegister.cpp b/aten/src/ATen/native/metal/MetalPrepackOpRegister.cpp index ebf9b9daf626..d4a7e463d777 100644 --- a/aten/src/ATen/native/metal/MetalPrepackOpRegister.cpp +++ b/aten/src/ATen/native/metal/MetalPrepackOpRegister.cpp @@ -3,11 +3,9 @@ #include #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { -c10::intrusive_ptr unpack( +static c10::intrusive_ptr unpack( Tensor&& weight, std::optional&& bias, std::vector&& stride, @@ -28,7 +26,7 @@ c10::intrusive_ptr unpack( output_max); } -c10::intrusive_ptr unpack( +static c10::intrusive_ptr unpack( Tensor&& weight, std::optional&& bias, const std::optional& output_min, @@ -94,7 +92,7 @@ TORCH_LIBRARY(metal_prepack, m) { TORCH_SELECTIVE_SCHEMA("metal_prepack::linear_run(Tensor X, __torch__.torch.classes.metal.LinearOpContext W_prepack) -> Tensor Y")); } -c10::intrusive_ptr conv2d_prepack( +static c10::intrusive_ptr conv2d_prepack( Tensor&& weight, std::optional&& bias, std::vector&& stride, @@ -115,7 +113,7 @@ c10::intrusive_ptr conv2d_prepack( output_max); } -c10::intrusive_ptr linear_prepack( +static c10::intrusive_ptr linear_prepack( Tensor&& weight, std::optional&& bias, const std::optional& output_min, @@ -129,6 +127,4 @@ TORCH_LIBRARY_IMPL(metal_prepack, CPU, m) { m.impl(TORCH_SELECTIVE_NAME("metal_prepack::linear_prepack"), TORCH_FN(linear_prepack)); } -} // namespace metal -} // namespace native -} // namespace at +} // namespace at::native::metal diff --git a/aten/src/ATen/native/metal/MetalTensorImplStorage.h b/aten/src/ATen/native/metal/MetalTensorImplStorage.h index 1ac7d126de95..975827aee15a 100644 --- a/aten/src/ATen/native/metal/MetalTensorImplStorage.h +++ b/aten/src/ATen/native/metal/MetalTensorImplStorage.h @@ -1,9 +1,7 @@ #include #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { class MPSImageWrapper; class MetalTensorImplStorage final { @@ -42,6 +40,4 @@ class MetalTensorImplStorage final { std::shared_ptr _impl; }; -} // namespace metal -} // namespace native -} // namespace at +} // namespace at::native::metal diff --git a/aten/src/ATen/native/metal/MetalTensorUtils.h b/aten/src/ATen/native/metal/MetalTensorUtils.h index 318da09d86b2..9663e59fb74d 100644 --- a/aten/src/ATen/native/metal/MetalTensorUtils.h +++ b/aten/src/ATen/native/metal/MetalTensorUtils.h @@ -10,9 +10,7 @@ typedef float16_t fp16_t; typedef uint16_t fp16_t; #endif -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { uint32_t batchSize(const Tensor& tensor); uint32_t channelsSize(const Tensor& tensor); @@ -70,6 +68,4 @@ static inline MetalCommandBuffer* getCommandBuffer( return cmdBuffer; } -} // namespace metal -} // namespace native -} // namespace at +} // namespace at::native::metal diff --git a/aten/src/ATen/native/metal/mpscnn/MPSCNNUtils.h b/aten/src/ATen/native/metal/mpscnn/MPSCNNUtils.h index 13264d097e92..346d58ace539 100644 --- a/aten/src/ATen/native/metal/mpscnn/MPSCNNUtils.h +++ b/aten/src/ATen/native/metal/mpscnn/MPSCNNUtils.h @@ -20,10 +20,7 @@ } \ } while (false) -namespace at { -namespace native { -namespace metal { -namespace mpscnn { +namespace at::native::metal::mpscnn { struct LaunchParams { MTLSize threadsPerThreadgroup; @@ -71,7 +68,4 @@ static inline int computeMPSAlignOffset(int kernel, int pad) { return mps_offset - pt_offset; } -} -} // namespace metal -} // namespace native -} // namespace at +} // namespace at::native::metal::mpscnn diff --git a/aten/src/ATen/native/metal/mpscnn/MPSCNNUtils.mm b/aten/src/ATen/native/metal/mpscnn/MPSCNNUtils.mm index ff8ad447dd0f..90f4ed030000 100644 --- a/aten/src/ATen/native/metal/mpscnn/MPSCNNUtils.mm +++ b/aten/src/ATen/native/metal/mpscnn/MPSCNNUtils.mm @@ -1,11 +1,8 @@ #import -namespace at { -namespace native { -namespace metal { -namespace mpscnn { +namespace at::native::metal::mpscnn { -auto divRoundUp(uint x, uint y) -> uint { +static auto divRoundUp(uint x, uint y) -> uint { return (x + y - 1) / y; } @@ -14,7 +11,7 @@ LaunchParams spatialPointwiseKernelLaunchParams( MPSImage* im) { return spatialPointwiseKernelLaunchParams( pipeline, im.numberOfImages, im.featureChannels, im.height, im.width); -}; +} LaunchParams spatialPointwiseKernelLaunchParams( id pipeline, @@ -33,9 +30,6 @@ LaunchParams spatialPointwiseKernelLaunchParams( const auto threadsPerGrid = MTLSizeMake( width, height, numberOfImages * divRoundUp(featureChannels, 4)); return {threadsPerThreadgroup, threadgroupsPerGrid, threadsPerGrid}; -}; - -} -} -} } + +} // namespace at::native::metal::mpscnn diff --git a/aten/src/ATen/native/metal/ops/MetalAddmm.mm b/aten/src/ATen/native/metal/ops/MetalAddmm.mm index e0c196ac68b3..b10b2a4b81f3 100644 --- a/aten/src/ATen/native/metal/ops/MetalAddmm.mm +++ b/aten/src/ATen/native/metal/ops/MetalAddmm.mm @@ -12,12 +12,10 @@ #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { API_AVAILABLE(ios(11.0), macos(10.13)) -Tensor addmm( +static Tensor addmm( const Tensor& bias, const Tensor& input, const Tensor& weight, @@ -63,7 +61,7 @@ Tensor addmm( namespace prepack { -Tensor linear(const Tensor& input, LinearOpContext& context) { +static Tensor linear(const Tensor& input, LinearOpContext& context) { TORCH_CHECK(input.is_metal()); TORCH_CHECK(context.get_weight().device() == kCPU); TORCH_CHECK(context.get_weight().dim() == 4); @@ -126,7 +124,7 @@ Tensor linear(const Tensor& input, LinearOpContext& context) { return output; } -Tensor linear_run( +static Tensor linear_run( const Tensor& input, const c10::intrusive_ptr& op_context) { return linear(input, *op_context); @@ -142,6 +140,4 @@ Tensor linear_run( m.impl(TORCH_SELECTIVE_NAME("metal_prepack::linear_run"), TORCH_FN(prepack::linear_run)); } -} -} -} +} // namespace at::native::metal diff --git a/aten/src/ATen/native/metal/ops/MetalBinaryElementwise.mm b/aten/src/ATen/native/metal/ops/MetalBinaryElementwise.mm index 0b5312632e1d..8505a89b9681 100644 --- a/aten/src/ATen/native/metal/ops/MetalBinaryElementwise.mm +++ b/aten/src/ATen/native/metal/ops/MetalBinaryElementwise.mm @@ -10,9 +10,7 @@ #include #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { using MetalTensorImpl = at::MetalTensorImpl; @@ -58,7 +56,7 @@ static inline void checkInputs(const Tensor& input1, const Tensor& input2) { } } -Tensor binaryElementwiseShaderKernel( +static Tensor binaryElementwiseShaderKernel( const Tensor& input1, const Tensor& input2, const std::string& arrayKernel, @@ -98,7 +96,7 @@ Tensor binaryElementwiseShaderKernel( return output; } -Tensor& binaryElementwiseShaderKernel_( +static Tensor& binaryElementwiseShaderKernel_( Tensor& input1, const Tensor& input2, const std::string& arrayKernel, @@ -208,7 +206,7 @@ Tensor binaryElementwiseMPSCNNKernel( return input1; } -Tensor add_Tensor(const Tensor& input1, const Tensor& input2, const Scalar& alpha) { +static Tensor add_Tensor(const Tensor& input1, const Tensor& input2, const Scalar& alpha) { TORCH_CHECK(input1.is_metal()); auto input2_ = input2.is_metal() ? input2 : input2.metal(); if (@available(iOS 11.3, *)) { @@ -219,7 +217,7 @@ Tensor add_Tensor(const Tensor& input1, const Tensor& input2, const Scalar& alph } } -Tensor& add__Tensor(Tensor& input1, const Tensor& input2, const Scalar& alpha) { +static Tensor& add__Tensor(Tensor& input1, const Tensor& input2, const Scalar& alpha) { TORCH_CHECK(input1.is_metal()); auto input2_ = input2.is_metal() ? input2 : input2.metal(); if (@available(iOS 11.3, *)) { @@ -230,7 +228,7 @@ Tensor add_Tensor(const Tensor& input1, const Tensor& input2, const Scalar& alph } } -Tensor sub_Tensor(const Tensor& input1, const Tensor& input2, const Scalar& alpha) { +static Tensor sub_Tensor(const Tensor& input1, const Tensor& input2, const Scalar& alpha) { TORCH_CHECK(input1.is_metal()); auto input2_ = input2.is_metal() ? input2 : input2.metal(); if (@available(iOS 11.3, *)) { @@ -241,7 +239,7 @@ Tensor sub_Tensor(const Tensor& input1, const Tensor& input2, const Scalar& alph } } -Tensor& sub__Tensor(Tensor& input1, const Tensor& input2, const Scalar& alpha) { +static Tensor& sub__Tensor(Tensor& input1, const Tensor& input2, const Scalar& alpha) { TORCH_CHECK(input1.is_metal()); auto input2_ = input2.is_metal() ? input2 : input2.metal(); if (@available(iOS 11.3, *)) { @@ -252,7 +250,7 @@ Tensor sub_Tensor(const Tensor& input1, const Tensor& input2, const Scalar& alph } } -Tensor mul_Tensor(const Tensor& input1, const Tensor& input2) { +static Tensor mul_Tensor(const Tensor& input1, const Tensor& input2) { TORCH_CHECK(input1.is_metal()); auto input2_ = input2.is_metal() ? input2 : input2.metal(); if (@available(iOS 11.3, *)) { @@ -263,7 +261,7 @@ Tensor mul_Tensor(const Tensor& input1, const Tensor& input2) { } } -Tensor& mul__Tensor(Tensor& input1, const Tensor& input2) { +static Tensor& mul__Tensor(Tensor& input1, const Tensor& input2) { TORCH_CHECK(input1.is_metal()); auto input2_ = input2.is_metal() ? input2 : input2.metal(); if (@available(iOS 11.3, *)) { @@ -274,7 +272,7 @@ Tensor mul_Tensor(const Tensor& input1, const Tensor& input2) { } } -Tensor div_Tensor(const Tensor& input1, const Tensor& input2) { +static Tensor div_Tensor(const Tensor& input1, const Tensor& input2) { TORCH_CHECK(input1.is_metal()); auto input2_ = input2.is_metal() ? input2 : input2.metal(); if (@available(iOS 11.3, *)) { @@ -285,7 +283,7 @@ Tensor div_Tensor(const Tensor& input1, const Tensor& input2) { } } -Tensor& div__Tensor(Tensor& input1, const Tensor& input2) { +static Tensor& div__Tensor(Tensor& input1, const Tensor& input2) { TORCH_CHECK(input1.is_metal()); auto input2_ = input2.is_metal() ? input2 : input2.metal(); if (@available(iOS 11.3, *)) { @@ -305,8 +303,6 @@ Tensor div_Tensor(const Tensor& input1, const Tensor& input2) { m.impl(TORCH_SELECTIVE_NAME("aten::sub_.Tensor"), TORCH_FN(sub__Tensor)); m.impl(TORCH_SELECTIVE_NAME("aten::div.Tensor"), TORCH_FN(div_Tensor)); m.impl(TORCH_SELECTIVE_NAME("aten::div_.Tensor"), TORCH_FN(div__Tensor)); -}; - -} -} } + +} // namespace at::native::metal diff --git a/aten/src/ATen/native/metal/ops/MetalChunk.mm b/aten/src/ATen/native/metal/ops/MetalChunk.mm index ee02b269a580..0011b065bf81 100644 --- a/aten/src/ATen/native/metal/ops/MetalChunk.mm +++ b/aten/src/ATen/native/metal/ops/MetalChunk.mm @@ -9,13 +9,11 @@ #import #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { // Split the input tensor into two on channel dimension // TODO: [T87567124] Fully implement chunk in Metal shader -std::vector chunk(const Tensor& input, int64_t chunks, int64_t dim) { +static std::vector chunk(const Tensor& input, int64_t chunks, int64_t dim) { TORCH_CHECK(chunks == 2 && dim == 1); TORCH_CHECK(input.dim() == 4); TORCH_CHECK(input.size(0) == 1); @@ -61,8 +59,6 @@ TORCH_LIBRARY_IMPL(aten, Metal, m) { m.impl(TORCH_SELECTIVE_NAME("aten::chunk"), TORCH_FN(chunk)); -}; - -} -} } + +} // namespace at::native::metal diff --git a/aten/src/ATen/native/metal/ops/MetalClamp.mm b/aten/src/ATen/native/metal/ops/MetalClamp.mm index b0eac2460ac3..4eedf3775028 100644 --- a/aten/src/ATen/native/metal/ops/MetalClamp.mm +++ b/aten/src/ATen/native/metal/ops/MetalClamp.mm @@ -8,11 +8,9 @@ #import #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { -Tensor& hardtanh_(Tensor& input, const Scalar& min_val, const Scalar& max_val) { +static Tensor& hardtanh_(Tensor& input, const Scalar& min_val, const Scalar& max_val) { TORCH_CHECK(input.is_metal()); MPSImage* X = imageFromTensor(input); MetalCommandBuffer* commandBuffer = getCommandBuffer(input); @@ -29,7 +27,7 @@ return input; } -Tensor hardtanh( +static Tensor hardtanh( const Tensor& input, const Scalar& min_val, const Scalar& max_val) { @@ -52,7 +50,7 @@ Tensor hardtanh( return output; } -at::Tensor clamp( +static at::Tensor clamp( const at::Tensor& input, const c10::optional& min, const c10::optional& max) { @@ -64,8 +62,6 @@ Tensor hardtanh( m.impl(TORCH_SELECTIVE_NAME("aten::hardtanh_"), TORCH_FN(hardtanh_)); m.impl(TORCH_SELECTIVE_NAME("aten::hardtanh"), TORCH_FN(hardtanh)); m.impl(TORCH_SELECTIVE_NAME("aten::clamp"), TORCH_FN(clamp)); -}; - -} -} } + +} // namespace at::native::metal diff --git a/aten/src/ATen/native/metal/ops/MetalConcat.mm b/aten/src/ATen/native/metal/ops/MetalConcat.mm index be9d87d8fe5a..5de99046f2d0 100644 --- a/aten/src/ATen/native/metal/ops/MetalConcat.mm +++ b/aten/src/ATen/native/metal/ops/MetalConcat.mm @@ -12,11 +12,9 @@ #include #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { -Tensor cat_batch(const Tensor& tensor, const ITensorListRef& tensors, MetalTensorImplStorage& mt) { +static Tensor cat_batch(const Tensor& tensor, const ITensorListRef& tensors, MetalTensorImplStorage& mt) { MetalCommandBuffer* commandBuffer = getCommandBuffer(tensor); MPSImage* Y = mt.texture()->image(); ushort cat_dim4_pointer = 0; @@ -53,7 +51,7 @@ Tensor cat_batch(const Tensor& tensor, const ITensorListRef& tensors, MetalTenso return output; } -Tensor cat_feature(const Tensor& tensor, const ITensorListRef& tensors, MetalTensorImplStorage& mt) { +static Tensor cat_feature(const Tensor& tensor, const ITensorListRef& tensors, MetalTensorImplStorage& mt) { MetalCommandBuffer* commandBuffer = getCommandBuffer(tensor); MPSImage* Y = mt.texture()->image(); ushort channel_offset = 0; @@ -162,7 +160,7 @@ Tensor cat_feature(const Tensor& tensor, const ITensorListRef& tensors, MetalTen return output; } -Tensor cat(const ITensorListRef& tensors, int64_t dim) { +static Tensor cat(const ITensorListRef& tensors, int64_t dim) { TORCH_CHECK( dim == 0 || dim == 1, "Metal cat is implemented only for batch dimension"); @@ -203,6 +201,4 @@ Tensor cat(const ITensorListRef& tensors, int64_t dim) { m.impl(TORCH_SELECTIVE_NAME("aten::cat"), TORCH_FN(cat)); } -} -} -} +} // namespace at::native::metal diff --git a/aten/src/ATen/native/metal/ops/MetalConvolution.h b/aten/src/ATen/native/metal/ops/MetalConvolution.h index 77053448cbcb..dc8192812d8c 100644 --- a/aten/src/ATen/native/metal/ops/MetalConvolution.h +++ b/aten/src/ATen/native/metal/ops/MetalConvolution.h @@ -2,9 +2,7 @@ #import #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { Tensor conv2d( const Tensor& input, @@ -19,6 +17,4 @@ namespace prepack { Tensor conv2d(const Tensor& input, Conv2dOpContext& context); } -} // namespace metal -} // namespace native -} // namespace at +} // namespace at::native::metal diff --git a/aten/src/ATen/native/metal/ops/MetalConvolution.mm b/aten/src/ATen/native/metal/ops/MetalConvolution.mm index 46295abefae9..eb5d1f16fabb 100644 --- a/aten/src/ATen/native/metal/ops/MetalConvolution.mm +++ b/aten/src/ATen/native/metal/ops/MetalConvolution.mm @@ -9,9 +9,7 @@ #import -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { using MetalTensorImpl = at::MetalTensorImpl; Tensor conv2d( @@ -97,7 +95,7 @@ Tensor conv2d(const Tensor& input, Conv2dOpContext& context) { return output; } -Tensor conv2d_prepack_run( +static Tensor conv2d_prepack_run( const Tensor& input, const c10::intrusive_ptr& op_context) { return conv2d(input, *op_context); @@ -115,6 +113,4 @@ Tensor conv2d_prepack_run( m.impl(TORCH_SELECTIVE_NAME("metal_prepack::conv2d_run"), prepack::conv2d_prepack_run); } -} -} -} +} // namespace at::native::metal diff --git a/aten/src/ATen/native/metal/ops/MetalCopy.h b/aten/src/ATen/native/metal/ops/MetalCopy.h index fdee7acad4f4..2023d3c508e2 100644 --- a/aten/src/ATen/native/metal/ops/MetalCopy.h +++ b/aten/src/ATen/native/metal/ops/MetalCopy.h @@ -3,14 +3,10 @@ #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { Tensor copy_to_host(const Tensor& input); -} -} // namespace native -} // namespace at +} // namespace at::native::metal #endif diff --git a/aten/src/ATen/native/metal/ops/MetalCopy.mm b/aten/src/ATen/native/metal/ops/MetalCopy.mm index b1df48b5c89c..c4ce058f78ed 100644 --- a/aten/src/ATen/native/metal/ops/MetalCopy.mm +++ b/aten/src/ATen/native/metal/ops/MetalCopy.mm @@ -9,11 +9,9 @@ #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { -Tensor copy_to_host(const Tensor& input) { +static Tensor copy_to_host(const Tensor& input) { TORCH_CHECK(input.is_metal()); MPSImage* X = imageFromTensor(input); if (X && !X.isTemporaryImage) { @@ -52,8 +50,6 @@ Tensor copy_to_host(const Tensor& input) { TORCH_LIBRARY_IMPL(metal, Metal, m) { m.impl(TORCH_SELECTIVE_NAME("metal::copy_to_host"), TORCH_FN(copy_to_host)); -}; - -} -} } + +} // namespace at::native::metal diff --git a/aten/src/ATen/native/metal/ops/MetalHardshrink.mm b/aten/src/ATen/native/metal/ops/MetalHardshrink.mm index 4de506cb6526..05b6b585e7f0 100644 --- a/aten/src/ATen/native/metal/ops/MetalHardshrink.mm +++ b/aten/src/ATen/native/metal/ops/MetalHardshrink.mm @@ -9,15 +9,13 @@ #import #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { using MetalTensorImpl = at::MetalTensorImpl; // NB: this is currently unused, but I've left it because in principle // it's useful -Tensor& hardshrink_(Tensor& input, const at::Scalar& lambda=0.5) { +static Tensor& hardshrink_(Tensor& input, const at::Scalar& lambda=0.5) { float l = lambda.toFloat(); MPSImage* X = imageFromTensor(input); MetalCommandBuffer* commandBuffer = getCommandBuffer(input); @@ -51,7 +49,7 @@ return input; } -Tensor hardshrink(const at::Tensor& input, const at::Scalar& lambda=0.5) { +static Tensor hardshrink(const at::Tensor& input, const at::Scalar& lambda=0.5) { float l = lambda.toFloat(); MPSImage* X = imageFromTensor(input); IntArrayRef outputSize = input.sizes(); @@ -87,8 +85,6 @@ Tensor hardshrink(const at::Tensor& input, const at::Scalar& lambda=0.5) { TORCH_LIBRARY_IMPL(aten, Metal, m) { m.impl(TORCH_SELECTIVE_NAME("aten::hardshrink"), TORCH_FN(hardshrink)); -}; - -} -} } + +} // namespace at::native::metal diff --git a/aten/src/ATen/native/metal/ops/MetalHardswish.mm b/aten/src/ATen/native/metal/ops/MetalHardswish.mm index 07706483c1ae..22d84d6c1bf0 100644 --- a/aten/src/ATen/native/metal/ops/MetalHardswish.mm +++ b/aten/src/ATen/native/metal/ops/MetalHardswish.mm @@ -9,13 +9,11 @@ #import #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { using MetalTensorImpl = at::MetalTensorImpl; -Tensor& hardswish_(Tensor& input) { +static Tensor& hardswish_(Tensor& input) { MPSImage* X = imageFromTensor(input); MetalCommandBuffer* commandBuffer = getCommandBuffer(input); IntArrayRef outputSize = input.sizes(); @@ -47,7 +45,7 @@ return input; } -Tensor hardswish(const at::Tensor& input) { +static Tensor hardswish(const at::Tensor& input) { MPSImage* X = imageFromTensor(input); IntArrayRef outputSize = input.sizes(); MetalTensorImplStorage mt{outputSize.vec()}; @@ -82,8 +80,6 @@ Tensor hardswish(const at::Tensor& input) { TORCH_LIBRARY_IMPL(aten, Metal, m) { m.impl(TORCH_SELECTIVE_NAME("aten::hardswish_"), TORCH_FN(hardswish_)); m.impl(TORCH_SELECTIVE_NAME("aten::hardswish"), TORCH_FN(hardswish)); -}; - -} -} } + +} // namespace at::native::metal diff --git a/aten/src/ATen/native/metal/ops/MetalLeakyReLU.mm b/aten/src/ATen/native/metal/ops/MetalLeakyReLU.mm index 2034a64d82d5..0bd476ffa4f5 100644 --- a/aten/src/ATen/native/metal/ops/MetalLeakyReLU.mm +++ b/aten/src/ATen/native/metal/ops/MetalLeakyReLU.mm @@ -9,13 +9,11 @@ #import #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { using MetalTensorImpl = at::MetalTensorImpl; -Tensor& leaky_relu_(Tensor& input, const Scalar& negative_slope_val) { +static Tensor& leaky_relu_(Tensor& input, const Scalar& negative_slope_val) { MPSImage* X = imageFromTensor(input); MetalCommandBuffer* commandBuffer = getCommandBuffer(input); IntArrayRef outputSize = input.sizes(); @@ -49,7 +47,7 @@ return input; } -Tensor leaky_relu(const at::Tensor& input, const Scalar& negative_slope_val) { +static Tensor leaky_relu(const at::Tensor& input, const Scalar& negative_slope_val) { MPSImage* X = imageFromTensor(input); IntArrayRef outputSize = input.sizes(); MetalTensorImplStorage mt{outputSize.vec()}; @@ -86,8 +84,6 @@ Tensor leaky_relu(const at::Tensor& input, const Scalar& negative_slope_val) { TORCH_LIBRARY_IMPL(aten, Metal, m) { m.impl(TORCH_SELECTIVE_NAME("aten::leaky_relu_"), TORCH_FN(leaky_relu_)); m.impl(TORCH_SELECTIVE_NAME("aten::leaky_relu"), TORCH_FN(leaky_relu)); -}; - -} -} } + +} // namespace at::native::metal diff --git a/aten/src/ATen/native/metal/ops/MetalNeurons.mm b/aten/src/ATen/native/metal/ops/MetalNeurons.mm index ca925d9b841b..09944092f6a1 100644 --- a/aten/src/ATen/native/metal/ops/MetalNeurons.mm +++ b/aten/src/ATen/native/metal/ops/MetalNeurons.mm @@ -9,13 +9,11 @@ #import #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { using MetalTensorImpl = at::MetalTensorImpl; -Tensor neuronKernel(const Tensor& input, MPSCNNNeuron* neuron) { +static Tensor neuronKernel(const Tensor& input, MPSCNNNeuron* neuron) { MPSImage* X = imageFromTensor(input); IntArrayRef outputSize = input.sizes(); if(input.numel() == 0){ @@ -33,7 +31,7 @@ Tensor neuronKernel(const Tensor& input, MPSCNNNeuron* neuron) { return output; } -Tensor& neuronKernel_(Tensor& input, MPSCNNNeuron* neuron) { +static Tensor& neuronKernel_(Tensor& input, MPSCNNNeuron* neuron) { MPSImage* X = imageFromTensor(input); IntArrayRef outputSize = input.sizes(); if(input.numel() == 0){ @@ -52,30 +50,30 @@ Tensor neuronKernel(const Tensor& input, MPSCNNNeuron* neuron) { } API_AVAILABLE(ios(11.0), macos(10.13)) -Tensor relu(const Tensor& input) { +static Tensor relu(const Tensor& input) { TORCH_CHECK(input.is_metal()); return neuronKernel(input, [MPSCNNNeuronOp relu]); } API_AVAILABLE(ios(11.0), macos(10.13)) -Tensor& relu_(Tensor& input) { +static Tensor& relu_(Tensor& input) { TORCH_CHECK(input.is_metal()); return neuronKernel_(input, [MPSCNNNeuronOp relu]); } API_AVAILABLE(ios(11.0), macos(10.13)) -Tensor sigmoid(const Tensor& input) { +static Tensor sigmoid(const Tensor& input) { return neuronKernel(input, [MPSCNNNeuronOp sigmoid]); } API_AVAILABLE(ios(11.0), macos(10.13)) -Tensor& hardsigmoid_(Tensor& input) { +static Tensor& hardsigmoid_(Tensor& input) { TORCH_CHECK(input.is_metal()); return neuronKernel_(input, [MPSCNNNeuronOp hardSigmoid]); } API_AVAILABLE(ios(11.0), macos(10.13)) -Tensor tanh(const Tensor& input) { +static Tensor tanh(const Tensor& input) { TORCH_CHECK(input.is_metal()); return neuronKernel(input, [MPSCNNNeuronOp tanh]); } @@ -86,8 +84,6 @@ Tensor tanh(const Tensor& input) { m.impl(TORCH_SELECTIVE_NAME("aten::relu_"), TORCH_FN(relu_)); m.impl(TORCH_SELECTIVE_NAME("aten::sigmoid"), TORCH_FN(sigmoid)); m.impl(TORCH_SELECTIVE_NAME("aten::hardsigmoid_"), TORCH_FN(hardsigmoid_)); -}; - -} -} } + +} // namepsace at::native::metal diff --git a/aten/src/ATen/native/metal/ops/MetalPadding.mm b/aten/src/ATen/native/metal/ops/MetalPadding.mm index 748fa8f4b653..c924c40cc62b 100644 --- a/aten/src/ATen/native/metal/ops/MetalPadding.mm +++ b/aten/src/ATen/native/metal/ops/MetalPadding.mm @@ -9,12 +9,10 @@ #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { API_AVAILABLE(ios(11.0), macos(10.13)) -Tensor reflection_pad2d(const Tensor& input, IntArrayRef padding) { +static Tensor reflection_pad2d(const Tensor& input, IntArrayRef padding) { TORCH_CHECK(input.is_metal()); const int pad_dim = padding.size(); @@ -87,8 +85,6 @@ Tensor reflection_pad2d(const Tensor& input, IntArrayRef padding) { TORCH_LIBRARY_IMPL(aten, Metal, m) { m.impl(TORCH_SELECTIVE_NAME("aten::reflection_pad2d"), TORCH_FN(reflection_pad2d)); -}; - -} -} } + +} // namespace at::native::metal diff --git a/aten/src/ATen/native/metal/ops/MetalPooling.mm b/aten/src/ATen/native/metal/ops/MetalPooling.mm index 5e3b9110756e..a4d5c07f39fd 100644 --- a/aten/src/ATen/native/metal/ops/MetalPooling.mm +++ b/aten/src/ATen/native/metal/ops/MetalPooling.mm @@ -11,12 +11,10 @@ #include #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { API_AVAILABLE(ios(11.0), macos(10.13)) -Tensor max_pool2d( +static Tensor max_pool2d( const Tensor& input, IntArrayRef kernel_size, IntArrayRef stride, @@ -71,7 +69,7 @@ Tensor max_pool2d( } API_AVAILABLE(ios(11.0), macos(10.13)) -Tensor adaptive_avg_pool2d(const Tensor& input, IntArrayRef output_size) { +static Tensor adaptive_avg_pool2d(const Tensor& input, IntArrayRef output_size) { // averages across the width and height, and outputs a 1x1xC image. TORCH_CHECK(output_size[0] == 1 && output_size[1] == 1); TORCH_CHECK(input.is_metal()); @@ -108,6 +106,4 @@ Tensor adaptive_avg_pool2d(const Tensor& input, IntArrayRef output_size) { m.impl(TORCH_SELECTIVE_NAME("aten::adaptive_avg_pool2d"), TORCH_FN(adaptive_avg_pool2d)); } -} -} -} +} // namespace at::native::metal diff --git a/aten/src/ATen/native/metal/ops/MetalReduce.mm b/aten/src/ATen/native/metal/ops/MetalReduce.mm index b0da375809b8..3de3104f6f93 100644 --- a/aten/src/ATen/native/metal/ops/MetalReduce.mm +++ b/aten/src/ATen/native/metal/ops/MetalReduce.mm @@ -11,9 +11,7 @@ #include #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { API_AVAILABLE(ios(11.3), macos(10.13)) static inline MPSNNReduceUnary* kernelForReducedDim(int dim) { @@ -28,7 +26,7 @@ return nil; } -Tensor wrapper_mean_dim( +static Tensor wrapper_mean_dim( const Tensor& input, OptionalIntArrayRef opt_dims, bool keepdim, @@ -82,6 +80,4 @@ Tensor wrapper_mean_dim( m.impl(TORCH_SELECTIVE_NAME("aten::mean.dim"), TORCH_FN(wrapper_mean_dim)); }; -} -} -} +} // namespace at::native::metal diff --git a/aten/src/ATen/native/metal/ops/MetalReshape.mm b/aten/src/ATen/native/metal/ops/MetalReshape.mm index a4336d1b92d4..de224018eb7c 100644 --- a/aten/src/ATen/native/metal/ops/MetalReshape.mm +++ b/aten/src/ATen/native/metal/ops/MetalReshape.mm @@ -11,12 +11,10 @@ #include #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { API_AVAILABLE(ios(11.0), macos(10.13)) -Tensor view(const Tensor& input, c10::SymIntArrayRef sym_size) { +static Tensor view(const Tensor& input, c10::SymIntArrayRef sym_size) { auto size = C10_AS_INTARRAYREF_SLOW(sym_size); TORCH_CHECK(input.is_metal()); auto inferred_size = at::infer_size(size, input.numel()); @@ -62,12 +60,12 @@ Tensor view(const Tensor& input, c10::SymIntArrayRef sym_size) { return output; } -Tensor reshape(const Tensor& input, IntArrayRef shape) { +static Tensor reshape(const Tensor& input, IntArrayRef shape) { TORCH_CHECK(input.is_metal()); return view(input, c10::fromIntArrayRefSlow(shape)); } -Tensor flatten_using_ints( +static Tensor flatten_using_ints( const Tensor& input, int64_t start_dim, int64_t end_dim) { @@ -97,7 +95,7 @@ Tensor flatten_using_ints( return input.reshape(shape); } -Tensor detach(const Tensor& input) { +static Tensor detach(const Tensor& input) { TORCH_CHECK(input.is_metal()); return input; } @@ -107,8 +105,6 @@ Tensor detach(const Tensor& input) { m.impl(TORCH_SELECTIVE_NAME("aten::view"), TORCH_FN(view)); m.impl(TORCH_SELECTIVE_NAME("aten::reshape"), TORCH_FN(reshape)); m.impl(TORCH_SELECTIVE_NAME("aten::flatten.using_ints"), TORCH_FN(flatten_using_ints)); -}; - -} -} } + +} // namespace at::native::metal diff --git a/aten/src/ATen/native/metal/ops/MetalSoftmax.mm b/aten/src/ATen/native/metal/ops/MetalSoftmax.mm index 11ebe255953f..6ec8f60f3ae7 100644 --- a/aten/src/ATen/native/metal/ops/MetalSoftmax.mm +++ b/aten/src/ATen/native/metal/ops/MetalSoftmax.mm @@ -10,9 +10,7 @@ #include #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { template Tensor mpscnn_softmax( @@ -50,14 +48,14 @@ Tensor mpscnn_softmax( return output; } -Tensor log_softmax_int( +static Tensor log_softmax_int( const Tensor& input, int64_t dim, c10::optional dtype) { return mpscnn_softmax(input, dim, dtype); } -Tensor softmax_int( +static Tensor softmax_int( const Tensor& input, int64_t dim, c10::optional dtype) { @@ -69,6 +67,4 @@ Tensor softmax_int( m.impl(TORCH_SELECTIVE_NAME("aten::softmax.int"), TORCH_FN(metal::softmax_int)); }; -} -} -} +} // namespace at::native::metal diff --git a/aten/src/ATen/native/metal/ops/MetalTranspose.mm b/aten/src/ATen/native/metal/ops/MetalTranspose.mm index e1b57a2a4019..d0df9f7596e6 100644 --- a/aten/src/ATen/native/metal/ops/MetalTranspose.mm +++ b/aten/src/ATen/native/metal/ops/MetalTranspose.mm @@ -10,9 +10,7 @@ #include #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { // TODO: Move this function to MetalContext template @@ -24,7 +22,7 @@ return buffer; } -Tensor transpose(const Tensor& input, int64_t dim0, int64_t dim1) { +static Tensor transpose(const Tensor& input, int64_t dim0, int64_t dim1) { TORCH_CHECK(input.is_metal()); auto ndims = input.dim(); // Support maximum eight channels on mobile @@ -87,7 +85,7 @@ Tensor transpose(const Tensor& input, int64_t dim0, int64_t dim1) { } } -Tensor t(const Tensor& input) { +static Tensor t(const Tensor& input) { TORCH_CHECK(input.is_metal()); TORCH_CHECK(input.dim() == 2); return metal::transpose(input, 0, input.dim() < 2 ? 0 : 1); @@ -98,6 +96,4 @@ Tensor t(const Tensor& input) { m.impl(TORCH_SELECTIVE_NAME("aten::transpose.int"), TORCH_FN(transpose)); }; -} -} -} +} // namespace at::native::metal diff --git a/aten/src/ATen/native/metal/ops/MetalUpsamplingNearest.mm b/aten/src/ATen/native/metal/ops/MetalUpsamplingNearest.mm index 39524569bae5..165e139c886d 100644 --- a/aten/src/ATen/native/metal/ops/MetalUpsamplingNearest.mm +++ b/aten/src/ATen/native/metal/ops/MetalUpsamplingNearest.mm @@ -11,11 +11,9 @@ #include #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { -Tensor upsample_nearest2d_vec( +static Tensor upsample_nearest2d_vec( const Tensor& input, at::OptionalIntArrayRef output_size, c10::optional> scale_factors) { @@ -70,6 +68,4 @@ Tensor upsample_nearest2d_vec( m.impl(TORCH_SELECTIVE_NAME("aten::upsample_nearest2d.vec"), TORCH_FN(upsample_nearest2d_vec)); }; -} -} -} +} // namespace at::native::metal diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index fe24571b66a5..69446aa34a57 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -772,7 +772,7 @@ if(NOT MSVC) set_source_files_properties(${PROJECT_SOURCE_DIR}/torch/csrc/distributed/c10d/socket.cpp PROPERTIES COMPILE_OPTIONS "-Wno-error=deprecated") endif() -if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang" AND NOT USE_VULKAN AND NOT USE_IOS AND NOT USE_PYTORCH_METAL AND NOT USE_COREML_DELEGATE) +if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang" AND NOT USE_VULKAN AND NOT USE_IOS AND NOT USE_COREML_DELEGATE) target_compile_options_if_supported(torch_cpu "-Wmissing-prototypes") target_compile_options_if_supported(torch_cpu "-Werror=missing-prototypes") get_target_property(TORCH_CPU_SOURCES torch_cpu SOURCES) From df75a9dc801d8022190c347171829bb8a26e0a2c Mon Sep 17 00:00:00 2001 From: cyy Date: Wed, 5 Jun 2024 15:10:12 +0000 Subject: [PATCH 366/706] Remove Caffe2/onnx (#127991) Remove Caffe2/onnx since it is not used. Other tiny fixes are also applied. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127991 Approved by: https://github.com/ezyang --- caffe2/BUILD_MODE.bzl | 23 ---- caffe2/CMakeLists.txt | 7 -- caffe2/__init__.py | 5 - caffe2/onnx/torch_ops/CMakeLists.txt | 5 - caffe2/onnx/torch_ops/constants.h | 7 -- caffe2/onnx/torch_ops/defs.cc | 168 -------------------------- caffe2/onnx/torch_ops/operator_sets.h | 46 ------- caffe2/onnx/torch_ops/schema.cc | 17 --- caffe2/onnx/torch_ops/schema.h | 8 -- cmake/Dependencies.cmake | 2 - 10 files changed, 288 deletions(-) delete mode 100644 caffe2/BUILD_MODE.bzl delete mode 100644 caffe2/__init__.py delete mode 100644 caffe2/onnx/torch_ops/CMakeLists.txt delete mode 100644 caffe2/onnx/torch_ops/constants.h delete mode 100644 caffe2/onnx/torch_ops/defs.cc delete mode 100644 caffe2/onnx/torch_ops/operator_sets.h delete mode 100644 caffe2/onnx/torch_ops/schema.cc delete mode 100644 caffe2/onnx/torch_ops/schema.h diff --git a/caffe2/BUILD_MODE.bzl b/caffe2/BUILD_MODE.bzl deleted file mode 100644 index 1fbd3e6f7a47..000000000000 --- a/caffe2/BUILD_MODE.bzl +++ /dev/null @@ -1,23 +0,0 @@ -""" build mode definitions for caffe2/caffe2 """ - -load("@fbcode//:BUILD_MODE.bzl", get_parent_modes = "all_modes_keep_gpu_sections_all_modes_use_lld") -load("@fbcode_macros//build_defs:create_build_mode.bzl", "extend_build_mode") - -def update_mode_struct(name, mode_struct): - if name == "dev": - return extend_build_mode( - mode_struct, - # TODO(ipbrady): Modules introduce floating point inaccuracies (T43879333) - cxx_modules = False, - ) - else: - return mode_struct - -_modes = { - mode_name: update_mode_struct(mode_name, mode_struct) - for mode_name, mode_struct in get_parent_modes().items() -} - -def get_modes(): - """ Return modes for this file """ - return _modes diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 69446aa34a57..458fa26f1b3e 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1930,7 +1930,6 @@ if(BUILD_PYTHON) if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") # Workaround for https://gcc.gnu.org/bugzilla/show_bug.cgi?id=80947 in EmbeddingBag.cpp set_source_files_properties(../aten/src/ATen/native/EmbeddingBag.cpp PROPERTIES COMPILE_FLAGS -Wno-attributes) - set_source_files_properties(${TORCH_SRC_DIR}/../caffe2/operators/box_with_nms_limit_op.cc PROPERTIES COMPILE_FLAGS -Wno-attributes) endif() set(build_files) @@ -1950,10 +1949,4 @@ if(BUILD_PYTHON) # Pick up static python files install(DIRECTORY ${CMAKE_BINARY_DIR}/caffe2 DESTINATION ${PYTHON_LIB_REL_PATH} FILES_MATCHING PATTERN "*.py") - # Caffe proto files - install(DIRECTORY ${CMAKE_BINARY_DIR}/caffe DESTINATION ${PYTHON_LIB_REL_PATH} - FILES_MATCHING PATTERN "*.py") - # Caffe2 proto files - install(DIRECTORY ${CMAKE_BINARY_DIR}/caffe2 DESTINATION ${PYTHON_LIB_REL_PATH} - FILES_MATCHING PATTERN "*.py") endif() diff --git a/caffe2/__init__.py b/caffe2/__init__.py deleted file mode 100644 index f319e8e2dc15..000000000000 --- a/caffe2/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -import warnings -from torch.onnx import _CAFFE2_ATEN_FALLBACK - -if not _CAFFE2_ATEN_FALLBACK: - warnings.warn("Caffe2 support is no longer present in PyTorch.") diff --git a/caffe2/onnx/torch_ops/CMakeLists.txt b/caffe2/onnx/torch_ops/CMakeLists.txt deleted file mode 100644 index 99443af4cc9b..000000000000 --- a/caffe2/onnx/torch_ops/CMakeLists.txt +++ /dev/null @@ -1,5 +0,0 @@ -# ---[ Extra onnx files. -file(GLOB ONNX_SRCS *.cc) - -# ---[ Send the lists to the parent scope. -set(ONNX_SRCS ${ONNX_SRCS} PARENT_SCOPE) diff --git a/caffe2/onnx/torch_ops/constants.h b/caffe2/onnx/torch_ops/constants.h deleted file mode 100644 index ebd2a2464d9b..000000000000 --- a/caffe2/onnx/torch_ops/constants.h +++ /dev/null @@ -1,7 +0,0 @@ -namespace ONNX_NAMESPACE { - -const int AI_ONNX_PYTORCH_DOMAIN_MIN_OPSET = 1; -const int AI_ONNX_PYTORCH_DOMAIN_MAX_OPSET = 1; -constexpr const char* AI_ONNX_PYTORCH_DOMAIN = "ai.onnx.pytorch"; - -} // namespace ONNX_NAMESPACE diff --git a/caffe2/onnx/torch_ops/defs.cc b/caffe2/onnx/torch_ops/defs.cc deleted file mode 100644 index a324cce6f284..000000000000 --- a/caffe2/onnx/torch_ops/defs.cc +++ /dev/null @@ -1,168 +0,0 @@ -// Copyright (c) Facebook Inc. and Microsoft Corporation. -// Licensed under the MIT license. - -#include "./schema.h" - -namespace ONNX_NAMESPACE { - -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,bugprone-branch-clone) -ONNX_PYTORCH_OPERATOR_SET_SCHEMA( - SparseLengthsSumFused8BitRowwise, - 1, - OpSchema() - .SetDoc("Mirror Caffe2 SparseLengthsSumFused8BitRowwise operator") - .Input(0, "DATA", "data tensor", "T1") - .Input(1, "INDICES", "indices tensor", "T2") - .Input(2, "LENGTHS", "lengths tensor", "T2") - .Output(0, "output", "Output tensor", "T2") - .TypeConstraint( - "T1", - {"tensor(uint8)"}, - "Constrain input data to uint8 tensors.") - .TypeConstraint( - "T2", - {"tensor(int8)", - "tensor(int16)", - "tensor(int32)", - "tensor(int64)", - "tensor(uint8)", - "tensor(uint16)", - "tensor(uint32)", - "tensor(uint64)"}, - "Constrain index and length to integral tensors.")); - -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,bugprone-branch-clone) -ONNX_PYTORCH_OPERATOR_SET_SCHEMA( - SparseLengthsSum, - 1, - OpSchema() - .SetDoc("Mirror Caffe2 SparseLengthsSum operator") - .Input(0, "DATA", "data tensor", "T1") - .Input(1, "INDICES", "indices tensor", "T2") - .Input(2, "LENGTHS", "lengths tensor", "T2") - .Output(0, "output", "Output tensor", "T1") - .TypeConstraint( - "T1", - {"tensor(float16)", "tensor(float)", "tensor(double)"}, - "Constrain input and output types to float tensors.") - .TypeConstraint( - "T2", - {"tensor(int8)", - "tensor(int16)", - "tensor(int32)", - "tensor(int64)", - "tensor(uint8)", - "tensor(uint16)", - "tensor(uint32)", - "tensor(uint64)"}, - "Constrain index and length to integral tensors.")); - -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,bugprone-branch-clone) -ONNX_PYTORCH_OPERATOR_SET_SCHEMA( - SparseLengthsWeightedSum, - 1, - OpSchema() - .SetDoc("Mirror Caffe2 SparseLengthsWeightedSum operator") - .Input(0, "DATA", "data tensor", "T1") - .Input(1, "WEIGHTS", "data tensor", "T1") - .Input(2, "INDICES", "indices tensor", "T2") - .Input(3, "LENGTHS", "lengths tensor", "T2") - .Output(0, "output", "Output tensor", "T1") - .TypeConstraint( - "T1", - {"tensor(float16)", "tensor(float)", "tensor(double)"}, - "Constrain input and output types to float tensors.") - .TypeConstraint( - "T2", - {"tensor(int8)", - "tensor(int16)", - "tensor(int32)", - "tensor(int64)", - "tensor(uint8)", - "tensor(uint16)", - "tensor(uint32)", - "tensor(uint64)"}, - "Constrain index and length to integral tensors.")); - -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,bugprone-branch-clone) -ONNX_PYTORCH_OPERATOR_SET_SCHEMA( - BatchGather, - 1, - OpSchema() - .SetDoc("Mirror Caffe2 BatchGather operator") - .Input(0, "DATA", "data tensor", "T1") - .Input(1, "INDICES", "indices tensor", "T2") - .Output(0, "output", "Output tensor", "T1") - .TypeConstraint( - "T1", - {"tensor(float16)", "tensor(float)", "tensor(double)"}, - "Constrain input and output types to float tensors.") - .TypeConstraint( - "T2", - {"tensor(int8)", - "tensor(int16)", - "tensor(int32)", - "tensor(int64)", - "tensor(uint8)", - "tensor(uint16)", - "tensor(uint32)", - "tensor(uint64)"}, - "Constrain index and length to integral tensors.")); - -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,bugprone-branch-clone) -ONNX_PYTORCH_OPERATOR_SET_SCHEMA( - DotProduct, - 1, - OpSchema() - .SetDoc("Mirror Caffe2 DotProduct operator") - .Input(0, "X", "Input 1 tensor", "T") - .Input(1, "Y", "Input 2 tensor", "T") - .Output(0, "Z", "Output tensor", "T") - .TypeConstraint( - "T", - {"tensor(float16)", "tensor(float)", "tensor(double)"}, - "Constrain input and output types to float tensors.")); - -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,bugprone-branch-clone) -ONNX_PYTORCH_OPERATOR_SET_SCHEMA( - FCTransposed, - 1, - OpSchema() - .SetDoc("Mirror Caffe2 FCTransposed operator") - .Input(0, "X", "Input tensor", "T") - .Input(1, "W", "Weight tensor", "T") - .Input(2, "B", "Bias tensor", "T") - .Output(0, "Z", "Output tensor", "T") - .TypeConstraint( - "T", - {"tensor(float16)", "tensor(float)", "tensor(double)"}, - "Constrain input and output types to float tensors.")); - -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,bugprone-branch-clone) -ONNX_PYTORCH_OPERATOR_SET_SCHEMA( - BatchMatMul, - 1, - OpSchema() - .SetDoc("Mirror Caffe2 BatchMatMul operator") - .Input(0, "X", "tensor of shape (dim0, dim1 ... M, K)", "T") - .Input(1, "Y", "tensor of shape (dim0, dim2 ... K, N)", "T") - .Output(0, "Z", "tensor of shape (dim0, dim1 ... M, N)", "T") - .TypeConstraint( - "T", - {"tensor(float16)", "tensor(float)", "tensor(double)"}, - "Constrain input and output types to float tensors.")); - -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,bugprone-branch-clone) -ONNX_PYTORCH_OPERATOR_SET_SCHEMA( - ExpandDims, - 1, - OpSchema() - .SetDoc("Mirror Caffe2 ExpandDims operator") - .Input(0, "X", "Input tensor", "T") - .Output(0, "Y", "Output tensor", "T") - .TypeConstraint( - "T", - {"tensor(float16)", "tensor(float)", "tensor(double)"}, - "Constrain input and output types to float tensors.")); - -} // namespace ONNX_NAMESPACE diff --git a/caffe2/onnx/torch_ops/operator_sets.h b/caffe2/onnx/torch_ops/operator_sets.h deleted file mode 100644 index f7380af3910f..000000000000 --- a/caffe2/onnx/torch_ops/operator_sets.h +++ /dev/null @@ -1,46 +0,0 @@ -#pragma once - -#include "onnx/defs/schema.h" - -namespace ONNX_NAMESPACE { - -class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME( - PyTorch, - 1, - SparseLengthsSumFused8BitRowwise); -class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, SparseLengthsSum); -class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, SparseLengthsWeightedSum); -class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, BatchGather); -class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, DotProduct); -class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, FCTransposed); -class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, BatchMatMul); -class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, ExpandDims); - -// Iterate over schema from ai.onnx.pytorch domain opset 1 -class OpSet_PyTorch_ver1 { - public: - static void ForEachSchema(std::function fn) { - fn(GetOpSchema()); - fn(GetOpSchema()); - fn(GetOpSchema()); - fn(GetOpSchema()); - fn(GetOpSchema()); - fn(GetOpSchema()); - fn(GetOpSchema()); - fn(GetOpSchema()); - } -}; - -inline void RegisterPyTorchOperatorSetSchema() { - RegisterOpSetSchema(); -} - -} // namespace ONNX_NAMESPACE diff --git a/caffe2/onnx/torch_ops/schema.cc b/caffe2/onnx/torch_ops/schema.cc deleted file mode 100644 index de933c2c23ab..000000000000 --- a/caffe2/onnx/torch_ops/schema.cc +++ /dev/null @@ -1,17 +0,0 @@ -#include "./schema.h" -#include "./operator_sets.h" - -namespace { -using namespace ONNX_NAMESPACE; -class PyTorchSchemasRegisterer { - public: - PyTorchSchemasRegisterer() { - OpSchemaRegistry::DomainToVersionRange::Instance().AddDomainToVersion( - AI_ONNX_PYTORCH_DOMAIN, - AI_ONNX_PYTORCH_DOMAIN_MIN_OPSET, - AI_ONNX_PYTORCH_DOMAIN_MAX_OPSET); - RegisterPyTorchOperatorSetSchema(); - } -}; -static PyTorchSchemasRegisterer registerer{}; -} // namespace diff --git a/caffe2/onnx/torch_ops/schema.h b/caffe2/onnx/torch_ops/schema.h deleted file mode 100644 index 3454e366a1ee..000000000000 --- a/caffe2/onnx/torch_ops/schema.h +++ /dev/null @@ -1,8 +0,0 @@ -#pragma once - -#include "./constants.h" -#include "onnx/defs/schema.h" - -#define ONNX_PYTORCH_OPERATOR_SET_SCHEMA(name, ver, impl) \ - ONNX_OPERATOR_SET_SCHEMA_EX( \ - name, PyTorch, AI_ONNX_PYTORCH_DOMAIN, ver, false, impl) diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 9693ac6e9fe6..3a57dd64c6af 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -1282,8 +1282,6 @@ if(CAFFE2_CMAKE_BUILDING_WITH_MAIN_REPO AND NOT INTERN_DISABLE_ONNX) add_definitions(-DONNX_ML=1) endif() add_definitions(-DONNXIFI_ENABLE_EXT=1) - # Add op schemas in "ai.onnx.pytorch" domain - add_subdirectory("${CMAKE_CURRENT_LIST_DIR}/../caffe2/onnx/torch_ops") if(NOT USE_SYSTEM_ONNX) add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/../third_party/onnx EXCLUDE_FROM_ALL) if(NOT MSVC) From 3d617333e70071f81f141c6f3aaf8935f5e2b210 Mon Sep 17 00:00:00 2001 From: cyy Date: Wed, 5 Jun 2024 15:17:31 +0000 Subject: [PATCH 367/706] Simplify CMake code (#127683) Due to the recent adoption of find(python), it is possible to further simplify some CMake code. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127683 Approved by: https://github.com/ezyang --- aten/src/ATen/CMakeLists.txt | 1 + caffe2/CMakeLists.txt | 48 +++--------------------------------- cmake/MiscCheck.cmake | 8 ------ cmake/Modules/FindAVX.cmake | 1 + 4 files changed, 6 insertions(+), 52 deletions(-) diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 9fa7a1f2305b..0087dd95d96e 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -386,6 +386,7 @@ if(UNIX AND NOT APPLE) endif(UNIX AND NOT APPLE) if(UNIX) + include(CheckFunctionExists) set(CMAKE_EXTRA_INCLUDE_FILES "sys/mman.h") CHECK_FUNCTION_EXISTS(mmap HAVE_MMAP) if(HAVE_MMAP) diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 458fa26f1b3e..2a58b15e8d5e 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -346,7 +346,7 @@ add_custom_command( OUTPUT ${TORCH_GENERATED_CODE} COMMAND - "${Python_EXECUTABLE}" tools/setup_helpers/generate_code.py + Python::Interpreter tools/setup_helpers/generate_code.py --native-functions-path "aten/src/ATen/native/native_functions.yaml" --tags-path "aten/src/ATen/native/tags.yaml" $<$:--disable-autograd> @@ -1094,7 +1094,7 @@ if(BUILD_LITE_INTERPRETER AND SELECTED_OP_LIST) add_custom_command( OUTPUT ${CMAKE_BINARY_DIR}/aten/src/ATen/selected_mobile_ops.h COMMAND - "${Python_EXECUTABLE}" + Python::Interpreter -m tools.code_analyzer.gen_oplist --model_file_list_path "${SELECTED_OP_LIST}" --output_dir "${CMAKE_BINARY_DIR}/aten/src/ATen" @@ -1109,7 +1109,7 @@ if(BUILD_LITE_INTERPRETER AND SELECTED_OP_LIST) add_custom_command( OUTPUT ${CMAKE_BINARY_DIR}/aten/src/ATen/selected_mobile_ops.h COMMAND - "${Python_EXECUTABLE}" + Python::Interpreter -m tools.lite_interpreter.gen_selected_mobile_ops_header --yaml_file_path "${SELECTED_OP_LIST}" --output_file_path "${CMAKE_BINARY_DIR}/aten/src/ATen" @@ -1887,50 +1887,10 @@ endif() # only rerun when needed. if(BUILD_PYTHON) - # Python site-packages - # Get canonical directory for python site packages (relative to install - # location). It varies from system to system. - # We should pin the path separator to the forward slash on Windows. - # More details can be seen at - # https://github.com/pytorch/pytorch/tree/main/tools/build_pytorch_libs.bat#note-backslash-munging-on-windows - pycmd(PYTHON_SITE_PACKAGES " - import os - import sysconfig - relative_site_packages = sysconfig.get_path('purelib').replace(sysconfig.get_path('data'), '').lstrip(os.path.sep) - print(relative_site_packages) - ") - file(TO_CMAKE_PATH ${PYTHON_SITE_PACKAGES} PYTHON_SITE_PACKAGES) - set(PYTHON_SITE_PACKAGES ${PYTHON_SITE_PACKAGES} PARENT_SCOPE) # for Summary # ---[ Options. - set(PYTHON_LIB_REL_PATH "${PYTHON_SITE_PACKAGES}" CACHE STRING "Python installation path (relative to CMake installation prefix)") + set(PYTHON_LIB_REL_PATH "${Python_SITELIB}" CACHE STRING "Python installation path (relative to CMake installation prefix)") message(STATUS "Using ${PYTHON_LIB_REL_PATH} as python relative installation path") - # Python extension suffix - # Try to get from python through sysconfig.get_env_var('EXT_SUFFIX') first, - # fallback to ".pyd" if windows and ".so" for all others. - pycmd(PY_EXT_SUFFIX " - def get_ext_suffix(): - import sys - import sysconfig - return sysconfig.get_config_var('EXT_SUFFIX') - - suffix = get_ext_suffix() - if suffix is not None: - print(suffix) - else: - print() - ") - if("${PY_EXT_SUFFIX}" STREQUAL "") - if(MSVC) - set(PY_EXT_SUFFIX ".pyd") - else() - set(PY_EXT_SUFFIX ".so") - endif() - endif() - if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") - # Workaround for https://gcc.gnu.org/bugzilla/show_bug.cgi?id=80947 in EmbeddingBag.cpp - set_source_files_properties(../aten/src/ATen/native/EmbeddingBag.cpp PROPERTIES COMPILE_FLAGS -Wno-attributes) - endif() set(build_files) foreach(python_src ${PYTHON_SRCS}) diff --git a/cmake/MiscCheck.cmake b/cmake/MiscCheck.cmake index 71d73866b2af..433d96ebfd23 100644 --- a/cmake/MiscCheck.cmake +++ b/cmake/MiscCheck.cmake @@ -1,11 +1,3 @@ -if(UNIX) - # prevent Unknown CMake command "check_function_exists". - include(CheckFunctionExists) -endif() -include(CheckIncludeFile) -include(CheckCSourceCompiles) -include(CheckCSourceRuns) -include(CheckCCompilerFlag) include(CheckCXXSourceCompiles) include(CheckCXXCompilerFlag) include(CMakePushCheckState) diff --git a/cmake/Modules/FindAVX.cmake b/cmake/Modules/FindAVX.cmake index 9604723e2cd3..1497f951402f 100644 --- a/cmake/Modules/FindAVX.cmake +++ b/cmake/Modules/FindAVX.cmake @@ -1,4 +1,5 @@ INCLUDE(CheckCSourceRuns) +INCLUDE(CheckCSourceCompiles) INCLUDE(CheckCXXSourceRuns) SET(AVX_CODE " From 9f2c4b9342bfedee6315e6cf9bbbee395707e287 Mon Sep 17 00:00:00 2001 From: cyy Date: Wed, 5 Jun 2024 15:22:45 +0000 Subject: [PATCH 368/706] Replace with standard type traits in torch/csrc (#127852) In preparation to clean up more type traits. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127852 Approved by: https://github.com/ezyang --- torch/csrc/api/include/torch/data/dataloader.h | 10 +++++----- torch/csrc/api/include/torch/data/datasets/map.h | 4 ++-- .../api/include/torch/nn/modules/container/any.h | 2 +- .../torch/nn/modules/container/any_module_holder.h | 4 ++-- .../include/torch/nn/modules/container/any_value.h | 3 ++- .../torch/nn/modules/container/functional.h | 2 +- .../torch/nn/modules/container/sequential.h | 10 ++++------ torch/csrc/api/include/torch/nn/pimpl-inl.h | 10 ++++++---- torch/csrc/api/include/torch/python.h | 7 +++---- torch/csrc/jit/frontend/tracer.h | 10 +++++----- torch/csrc/jit/ir/named_value.h | 14 +++++++------- 11 files changed, 38 insertions(+), 38 deletions(-) diff --git a/torch/csrc/api/include/torch/data/dataloader.h b/torch/csrc/api/include/torch/data/dataloader.h index 06ea83d8a232..a7bbdcb27d84 100644 --- a/torch/csrc/api/include/torch/data/dataloader.h +++ b/torch/csrc/api/include/torch/data/dataloader.h @@ -18,8 +18,8 @@ namespace data { /// Creates a `DataLoader` instance for a stateless `dataset`, a `sampler` and /// some `options`. template -torch::disable_if_t< - Dataset::is_stateful, +std::enable_if_t< + !Dataset::is_stateful, std::unique_ptr>> make_data_loader(Dataset dataset, Sampler sampler, DataLoaderOptions options) { return std::make_unique>( @@ -30,8 +30,8 @@ make_data_loader(Dataset dataset, Sampler sampler, DataLoaderOptions options) { /// `options`. A sampler (by default a `RandomSampler`) will be constructed from /// the size of the dataset. template -torch::disable_if_t< - Dataset::is_stateful || !std::is_constructible::value, +std::enable_if_t< + !Dataset::is_stateful && std::is_constructible_v, std::unique_ptr>> make_data_loader( Dataset dataset, @@ -46,7 +46,7 @@ make_data_loader( } /// Creates a `DataLoader` for a stateful `dataset` and some `options`. -template > +template > std::unique_ptr> make_data_loader( Dataset dataset, DataLoaderOptions options = DataLoaderOptions()) { diff --git a/torch/csrc/api/include/torch/data/datasets/map.h b/torch/csrc/api/include/torch/data/datasets/map.h index 7b8b8febd222..facd4fe28705 100644 --- a/torch/csrc/api/include/torch/data/datasets/map.h +++ b/torch/csrc/api/include/torch/data/datasets/map.h @@ -71,7 +71,7 @@ class MapDataset : public BatchDataset< /// applies the transform to the output of `get_batch()` from the dataset. template < typename D = SourceDataset, - typename = torch::disable_if_t> + typename = std::enable_if_t> OutputBatchType get_batch_impl(BatchRequestType indices) { return transform_.apply_batch(dataset_.get_batch(std::move(indices))); } @@ -82,7 +82,7 @@ class MapDataset : public BatchDataset< /// contains a value, and returns a new optional (of a different type) if the /// original optional returned by `get_batch()` was empty. template - torch::enable_if_t get_batch_impl( + std::enable_if_t get_batch_impl( BatchRequestType indices) { if (auto batch = dataset_.get_batch(std::move(indices))) { return transform_.apply_batch(std::move(*batch)); diff --git a/torch/csrc/api/include/torch/nn/modules/container/any.h b/torch/csrc/api/include/torch/nn/modules/container/any.h index 05983b1ea106..35d9c91b8ca3 100644 --- a/torch/csrc/api/include/torch/nn/modules/container/any.h +++ b/torch/csrc/api/include/torch/nn/modules/container/any.h @@ -340,7 +340,7 @@ std::unique_ptr AnyModule::make_holder( "AnyModule cannot store modules that return void " "(you can return a dummy value)."); return std::make_unique< - AnyModuleHolder, ArgumentTypes...>>( + AnyModuleHolder, ArgumentTypes...>>( std::move(module)); } diff --git a/torch/csrc/api/include/torch/nn/modules/container/any_module_holder.h b/torch/csrc/api/include/torch/nn/modules/container/any_module_holder.h index cd1dca9ff7a0..4d1e69650035 100644 --- a/torch/csrc/api/include/torch/nn/modules/container/any_module_holder.h +++ b/torch/csrc/api/include/torch/nn/modules/container/any_module_holder.h @@ -40,10 +40,10 @@ struct AnyModuleHolder : public AnyModulePlaceholder { /// \internal struct CheckedGetter { template - decay_t&& operator()(size_t index) { + std::decay_t&& operator()(size_t index) { AT_ASSERT(index < arguments_.size()); auto& value = arguments_[index]; - if (auto* maybe_value = value.template try_get>()) { + if (auto* maybe_value = value.template try_get>()) { return std::move(*maybe_value); } AT_ERROR( diff --git a/torch/csrc/api/include/torch/nn/modules/container/any_value.h b/torch/csrc/api/include/torch/nn/modules/container/any_value.h index 3e6c23ef977c..d154130618f2 100644 --- a/torch/csrc/api/include/torch/nn/modules/container/any_value.h +++ b/torch/csrc/api/include/torch/nn/modules/container/any_value.h @@ -40,7 +40,8 @@ class AnyValue { template // NOLINTNEXTLINE(bugprone-forwarding-reference-overload) explicit AnyValue(T&& value) - : content_(std::make_unique>>(std::forward(value))) { + : content_( + std::make_unique>>(std::forward(value))) { } /// Returns a pointer to the value contained in the `AnyValue` if the type diff --git a/torch/csrc/api/include/torch/nn/modules/container/functional.h b/torch/csrc/api/include/torch/nn/modules/container/functional.h index dbd2b0aaebdc..3f381a63944f 100644 --- a/torch/csrc/api/include/torch/nn/modules/container/functional.h +++ b/torch/csrc/api/include/torch/nn/modules/container/functional.h @@ -65,7 +65,7 @@ class TORCH_API FunctionalImpl : public torch::nn::Cloneable { template < typename SomeFunction, typename... Args, - typename = torch::enable_if_t<(sizeof...(Args) > 0)>> + typename = std::enable_if_t<(sizeof...(Args) > 0)>> explicit FunctionalImpl(SomeFunction original_function, Args&&... args) // NOLINTNEXTLINE(modernize-avoid-bind) : function_(std::bind( diff --git a/torch/csrc/api/include/torch/nn/modules/container/sequential.h b/torch/csrc/api/include/torch/nn/modules/container/sequential.h index 9494926eef3c..4007e2cfd801 100644 --- a/torch/csrc/api/include/torch/nn/modules/container/sequential.h +++ b/torch/csrc/api/include/torch/nn/modules/container/sequential.h @@ -219,7 +219,7 @@ class SequentialImpl : public Cloneable { /// and letting the container deal with the boxing. template > void push_back(std::string name, M&& module) { - using Type = typename std::remove_reference::type; + using Type = typename std::remove_reference_t; push_back(std::move(name), std::make_shared(std::forward(module))); } @@ -348,12 +348,10 @@ class SequentialImpl : public Cloneable { typename First, typename Second, typename... Rest, - typename = torch::disable_if_t< - std::is_same::value || + typename = std::enable_if_t< + !std::is_same_v && // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays) - std::is_same< - typename std::decay::type, - std::decay::type>::value>> + !std::is_same_v, std::decay_t>>> void push_back(First&& first, Second&& second, Rest&&... rest) { push_back(std::forward(first)); // Recursively calls this method, until the parameter pack only thas this diff --git a/torch/csrc/api/include/torch/nn/pimpl-inl.h b/torch/csrc/api/include/torch/nn/pimpl-inl.h index b38e6cf2c0ff..cea53b6562bd 100644 --- a/torch/csrc/api/include/torch/nn/pimpl-inl.h +++ b/torch/csrc/api/include/torch/nn/pimpl-inl.h @@ -6,10 +6,12 @@ struct ModuleHolderIndicator {}; // A type trait that is true for types that are `ModuleHolder`s. template -using is_module_holder = std::is_base_of>; +using is_module_holder = + std::is_base_of>; template -using disable_if_module_holder_t = disable_if_t::value>; +using disable_if_module_holder_t = + std::enable_if_t::value>; // A collection of templates that answer the question whether a type `T` is a // `ModuleHolder`, and if so whether its contained type is of type `C`. This is @@ -43,8 +45,8 @@ struct is_module_holder_of_impl template struct is_module_holder_of : is_module_holder_of_impl< is_module_holder::value, - decay_t, - decay_t> {}; + std::decay_t, + std::decay_t> {}; // A collection of templates that allow deducing the return type of the // `forward()` method, but only if a module actually has a `forward()` method, diff --git a/torch/csrc/api/include/torch/python.h b/torch/csrc/api/include/torch/python.h index 15902a026cf5..cc9d6a51a6de 100644 --- a/torch/csrc/api/include/torch/python.h +++ b/torch/csrc/api/include/torch/python.h @@ -212,8 +212,8 @@ py::class_ add_module_bindings( /// } /// \endrst template -torch::disable_if_t< - torch::detail::has_forward::value && !force_enable, +std::enable_if_t< + !torch::detail::has_forward::value || force_enable, detail::PyModuleClass> bind_module(py::module module, const char* name) { py::module cpp = module.def_submodule("cpp"); @@ -249,8 +249,7 @@ bind_module(py::module module, const char* name) { /// \endrst template < typename ModuleType, - typename = - torch::enable_if_t::value>> + typename = std::enable_if_t::value>> detail::PyModuleClass bind_module( py::module module, const char* name) { diff --git a/torch/csrc/jit/frontend/tracer.h b/torch/csrc/jit/frontend/tracer.h index a1cc856a22e1..fef018dc7388 100644 --- a/torch/csrc/jit/frontend/tracer.h +++ b/torch/csrc/jit/frontend/tracer.h @@ -381,12 +381,12 @@ TORCH_API void ensureUniqueIfOutOfPlaced( template < typename T, - typename = torch::enable_if_t< - (!std::is_convertible_v, at::TensorList> && - !std::is_convertible_v, c10::List> && - !std::is_convertible_v, at::Tensor> && + typename = std::enable_if_t< + (!std::is_convertible_v, at::TensorList> && + !std::is_convertible_v, c10::List> && + !std::is_convertible_v, at::Tensor> && !std::is_convertible_v< - torch::decay_t, + std::decay_t, c10::intrusive_ptr>)>> void addOutput(Node* node, T&&) { AT_ERROR( diff --git a/torch/csrc/jit/ir/named_value.h b/torch/csrc/jit/ir/named_value.h index 277e7f269969..a594b4d045e9 100644 --- a/torch/csrc/jit/ir/named_value.h +++ b/torch/csrc/jit/ir/named_value.h @@ -30,18 +30,18 @@ struct NamedValue { template < typename T, - typename = enable_if_t< - (!std::is_same, NamedValue>::value && - !std::is_same, Value*>::value && - !std::is_same, IValue>::value)>> + typename = std::enable_if_t< + (!std::is_same_v, NamedValue> && + !std::is_same_v, Value*> && + !std::is_same_v, IValue>)>> // NOLINTNEXTLINE(bugprone-forwarding-reference-overload) NamedValue(T&& t) : NamedValue(IValue(std::forward(t))) {} template < typename T, - typename = enable_if_t< - (!std::is_same, Value*>::value && - !std::is_same, IValue>::value)>> + typename = std::enable_if_t< + (!std::is_same_v, Value*> && + !std::is_same_v, IValue>)>> NamedValue(const std::string& name, T&& t) : NamedValue(name, IValue(std::forward(t))) {} From 4f9fcd71562abcc46a2b1b5eacb3ea1bccadc8b2 Mon Sep 17 00:00:00 2001 From: Jiashen Cao Date: Wed, 5 Jun 2024 15:27:13 +0000 Subject: [PATCH 369/706] Handle unpacking during TorchScript to ExportedProgram conversion (#127419) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127419 Approved by: https://github.com/angelayi --- test/export/test_converter.py | 17 +++++++++++++++++ torch/_export/converter.py | 14 ++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/test/export/test_converter.py b/test/export/test_converter.py index d59bb0ebf8f7..8d0b2d6f93c6 100644 --- a/test/export/test_converter.py +++ b/test/export/test_converter.py @@ -247,6 +247,23 @@ def forward( inp = (torch.randn(10, 10), torch.rand(10, 10)) self._check_equal_ts_ep_converter(Module(), inp) + def test_ts2ep_converter_unpack(self): + class MUnpackList(torch.nn.Module): + def forward(self, x): + x, y = torch.split(x, 2) + return x + y + + class MUnpackTuple(torch.nn.Module): + def forward(self, x_tuple: Tuple[torch.Tensor, torch.Tensor]): + x, y = x_tuple + x = x.cos() + return x + y + + inp = torch.ones(1, 4) + self._check_equal_ts_ep_converter(MUnpackList(), inp) + inp = ((torch.zeros(1, 4), torch.ones(1, 4)),) + self._check_equal_ts_ep_converter(MUnpackTuple(), inp) + if __name__ == "__main__": run_tests() diff --git a/torch/_export/converter.py b/torch/_export/converter.py index 08decec27fea..4cbcc3faea83 100644 --- a/torch/_export/converter.py +++ b/torch/_export/converter.py @@ -319,6 +319,20 @@ def convert_prim_DictConstruct(self, node: torch._C.Node): output_name = node.output().debugName() self.name_to_node[output_name] = output_dict + def convert_prim_ListUnpack(self, node: torch._C.Node): + self._convert_prim_unpack_iterator(node) + + def convert_prim_TupleUnpack(self, node: torch._C.Node): + self._convert_prim_unpack_iterator(node) + + def _convert_prim_unpack_iterator(self, node: torch._C.Node): + # Single input and multiple outputs for unpacking. + for i, outp in enumerate(node.outputs()): + outp_name = outp.debugName() + inp = self.get_fx_value(node.input()) + fx_node = self.fx_graph.call_function(operator.getitem, (inp, i)) + self.name_to_node[outp_name] = fx_node + def convert_aten_Int(self, node: torch._C.Node): # converts aten::Int as aten._to_copy + aten::_local_scalar_dense target = torch.ops.aten._to_copy.default From 5dc912822913b3d90f4938891c7eca722a057cf1 Mon Sep 17 00:00:00 2001 From: drisspg Date: Wed, 5 Jun 2024 15:46:40 +0000 Subject: [PATCH 370/706] FP8 rowwise scaling (#125204) # Summary This pull request introduces an fp8 row-scaling kernel as an optional implementation for `scaled_mm`. The kernel selection is based on the scaling tensors of the inputs. For inputs `x` and `y` of shape `[M, K]` and `[K, N]` respectively, the following conditions must be met: - `x`'s scale should be a 1-dimensional tensor of length `M`. - `y`'s scale should be a 1-dimensional tensor of length `N`. It's important to note that this kernel is not called "rowwise, columnwise" scaling because, although the scales for `y` are semantically along its columns, this implementation only supports the TN format. This means the scaling is along the faster-moving dimension, or the "row". The following two PRs were required to enable local builds: - [PR #126185](https://github.com/pytorch/pytorch/pull/126185) - [PR #125523](https://github.com/pytorch/pytorch/pull/125523) ### Todo We still do not build our Python wheels with this architecture. @ptrblck @malfet, should we replace `sm_90` with `sm_90a`? The NVRTC TMA shadowing feels wrong, but I a not sure the right way to spoof the symbol for this compilation unit: https://github.com/pytorch/pytorch/pull/125204/files#r1586986954 #### ifdef I tried to use : `#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12000 && \ defined(__CUDA_ARCH__) && __CUDA_ARCH__ > 900` to gate the building of the kernel. I was having a hell of a time with this.. so I am not really sure the right way to do this Kernel Credit: @jwfromm Pull Request resolved: https://github.com/pytorch/pytorch/pull/125204 Approved by: https://github.com/lw, https://github.com/malfet --- aten/src/ATen/CMakeLists.txt | 1 + aten/src/ATen/cuda/detail/LazyNVRTC.cpp | 37 ++ aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h | 15 +- aten/src/ATen/native/cuda/Blas.cpp | 113 +++- aten/src/ATen/native/cuda/RowwiseScaledMM.cu | 536 +++++++++++++++++++ aten/src/ATen/native/cuda/RowwiseScaledMM.h | 15 + test/test_matmul_cuda.py | 149 +++++- third_party/cutlass.BUILD | 14 +- 8 files changed, 855 insertions(+), 25 deletions(-) create mode 100644 aten/src/ATen/native/cuda/RowwiseScaledMM.cu create mode 100644 aten/src/ATen/native/cuda/RowwiseScaledMM.h diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 0087dd95d96e..5cd6aacf2463 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -473,6 +473,7 @@ endif() if(USE_CUDA AND NOT USE_ROCM) list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include) + list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/tools/util/include) if($ENV{ATEN_STATIC_CUDA}) list(APPEND ATen_CUDA_DEPENDENCY_LIBS ${CUDA_LIBRARIES} diff --git a/aten/src/ATen/cuda/detail/LazyNVRTC.cpp b/aten/src/ATen/cuda/detail/LazyNVRTC.cpp index 1b85e7776e22..75c503d48d51 100644 --- a/aten/src/ATen/cuda/detail/LazyNVRTC.cpp +++ b/aten/src/ATen/cuda/detail/LazyNVRTC.cpp @@ -170,6 +170,43 @@ CUDA_STUB3(cuLinkComplete, CUlinkState, void **, size_t *); CUDA_STUB3(cuFuncSetAttribute, CUfunction, CUfunction_attribute, int); CUDA_STUB3(cuFuncGetAttribute, int*, CUfunction_attribute, CUfunction); +#if defined(CUDA_VERSION) && CUDA_VERSION >= 12000 +CUresult CUDAAPI +cuTensorMapEncodeTiled( + CUtensorMap* tensorMap, + CUtensorMapDataType tensorDataType, + cuuint32_t tensorRank, + void* globalAddress, + const cuuint64_t* globalDim, + const cuuint64_t* globalStrides, + const cuuint32_t* boxDim, + const cuuint32_t* elementStrides, + CUtensorMapInterleave interleave, + CUtensorMapSwizzle swizzle, + CUtensorMapL2promotion l2Promotion, + CUtensorMapFloatOOBfill oobFill) { + auto fn = reinterpret_cast( + getCUDALibrary().sym(__func__)); + if (!fn) + throw std::runtime_error("Can't get cuTensorMapEncodeTiled"); + lazyNVRTC.cuTensorMapEncodeTiled = fn; + return fn( + tensorMap, + tensorDataType, + tensorRank, + globalAddress, + globalDim, + globalStrides, + boxDim, + elementStrides, + interleave, + swizzle, + l2Promotion, + oobFill); +} + +#endif + // Irregularly shaped functions CUresult CUDAAPI cuLaunchKernel(CUfunction f, unsigned int gridDimX, diff --git a/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h b/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h index 574b2c41c264..cb34d10db254 100644 --- a/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h +++ b/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h @@ -59,16 +59,25 @@ namespace at { namespace cuda { _(cuLinkAddData) \ _(cuLinkComplete) \ _(cuFuncSetAttribute) \ - _(cuFuncGetAttribute) + _(cuFuncGetAttribute) \ + +#if defined(CUDA_VERSION) && CUDA_VERSION >= 12000 +#define AT_FORALL_NVRTC_EXTENDED(_) \ + AT_FORALL_NVRTC_BASE(_) \ + _(cuTensorMapEncodeTiled) +#else +#define AT_FORALL_NVRTC_EXTENDED(_) \ + AT_FORALL_NVRTC_BASE(_) +#endif #if defined(CUDA_VERSION) && CUDA_VERSION >= 11010 #define AT_FORALL_NVRTC(_) \ - AT_FORALL_NVRTC_BASE(_) \ + AT_FORALL_NVRTC_EXTENDED(_) \ _(nvrtcGetCUBINSize) \ _(nvrtcGetCUBIN) #else #define AT_FORALL_NVRTC(_) \ - AT_FORALL_NVRTC_BASE(_) + AT_FORALL_NVRTC_EXTENDED(_) #endif #else diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index 84c59a4fd0d7..ed59b47349cc 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -1,3 +1,7 @@ +#include +#include +#include +#include #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include #include @@ -10,6 +14,7 @@ #include #include #include +#include #ifndef AT_PER_OPERATOR_HEADERS #include @@ -819,24 +824,97 @@ static bool _scaled_mm_allowed_device() { #endif } +namespace{ + +enum class ScalingType { + TensorWise, + RowWise, + Error +}; + +// Validates the scale tensors to scaled_mm +// And returns the type of scaling/which kernel to use +ScalingType get_scaling_type( + const c10::optional& scale_a, + const c10::optional& scale_b, + int64_t dim_m, + int64_t dim_n) { + TORCH_CHECK( + scale_a.has_value() == scale_b.has_value(), + "Both scale_a and scale_b must be present or absent."); + + if (scale_a.has_value()) { + // Both Per-Tensor and Row-wise scaling expect fp32 tensors + TORCH_CHECK( + scale_a->scalar_type() == kFloat && scale_b->scalar_type() == kFloat, + "Both scale_a and scale_b must be float (fp32) tensors."); + + // Check the singluar scale case for per-tensor scaling + if (scale_a->numel() == 1 && scale_b->numel() == 1) { + return ScalingType::TensorWise; + } else if (scale_a->dim() == 1 && scale_a->size(0) == dim_m) { +// Check the per-row scaling case +#if !defined(USE_ROCM) && !defined(_MSC_VER) || \ + (defined(USE_ROCM) && ROCM_VERSION >= 60000) + TORCH_CHECK( + scale_a->dim() == 1 && scale_b->dim() == 1, + "Both scale_a and scale_b must be 1-dimensional tensors"); + TORCH_CHECK( + scale_b->size(0) == dim_n, + "For row-wise scaling, scale_b must have size ", + dim_n, + " but got ", + scale_b->size(0), + "."); + TORCH_CHECK( + scale_a->is_contiguous() && scale_b->is_contiguous(), + "Both scale_a and scale_b must be contiguous."); + return ScalingType::RowWise; +#else + TORCH_CHECK(false, "Per-row scaling is not supported for this platform!"); + return ScalingType::Error; +#endif // !defined(USE_ROCM) && !defined(_MSC_VER) || (defined(USE_ROCM) && + // ROCM_VERSION >= 60000) + } else { + TORCH_CHECK( + false, + "For row-wise scaling, scale_a must be size ", + dim_m, + " but got ", + scale_a->numel(), + " and scale_b must be size ", + dim_n, + " but got ", + scale_b->numel(), + "."); + // Unreachable + return ScalingType::RowWise; + } + } + return ScalingType::Error; +} + +} // namespace + // Computes matrix multiply + bias while applying scaling to input and output matrices and computes amax // Scales are only applicable when matrices are of Float8 type and assumbed to be equal to 1.0 by default. // If output matrix type is 16 or 32-bit type, neither scale_result is applied nor amax is computed. // Known limitations: // - Only works if mat1 is row-major and mat2 is column-major // - Only works if matrices sizes are divisible by 32 -// +// - If 1-dimensional tensors are used then scale_a should be size = mat1.size(0) +// and scale_b should have size = to mat2.size(1) // Arguments: // - `mat1`: the first operand of the matrix multiply, can be type `torch.float8_e4m3fn` or `torch.float8_e5m2` // - `mat2`: the second operand of the matrix multiply, can be type `torch.float8_e4m3fn` or `torch.float8_e5m2` // - `bias`: the bias, can be type `torch.float16` or `torch.bfloat16` // - `out_dtype`: the output dtype, can either be a float8 or a higher precision floating point type -// - `scale_a`: a scalar tensor with the inverse scale of `mat1`, only needed if `mat1` is a float8 type -// - `scale_b`: a scalar tensor with the inverse scale of `mat2`, only needed if `mat2` is a float8 type -// - `scale_result`: a scalar tensor with the scale of the output, only set if the output is a float8 type +// - `scale_a`: a scalar or 1-dimensional tensor with the inverse scale of `mat1`, only needed if `mat1` is a float8 type +// - `scale_b`: a scalar or 1-dimensional tensor with the inverse scale of `mat2`, only needed if `mat2` is a float8 type +// - `scale_result`: a scalar tensor with the scale of the output, only utilized if the output is a float8 type // - `use_fast_accum`: if true, enables fast float8 accumulation // - `out`: a reference to the output tensor -// - `amax`: a reference to the amax tensor of the output, only needed if the output is a float8 type and will be updated inplace +// - `amax`: a reference to the amax tensor of the output, only mutated if the output is a float8 type and will be updated inplace std::tuple _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, @@ -855,10 +933,11 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, TORCH_CHECK( mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (", mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")"); - TORCH_CHECK(!scale_a || (scale_a->numel() == 1 && scale_a->scalar_type() == kFloat), - "scale_a must be float scalar"); - TORCH_CHECK(!scale_b || (scale_b->numel() == 1 && scale_b->scalar_type() == kFloat), - "scale_b must be a float scalar"); + + // Check what type of scaling we are doing based on inputs + ScalingType scaling_choice = get_scaling_type(scale_a, scale_b, mat1.size(0), mat2.size(1)); + TORCH_INTERNAL_ASSERT(scaling_choice != ScalingType::Error, "Scaling type not supported"); + TORCH_CHECK(!scale_result || (scale_result->numel() == 1 && scale_result->scalar_type() == kFloat), "scale_result must be a float scalar"); TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1], @@ -901,12 +980,26 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, {scale_result_, "scale_result", 7}}; checkAllSameGPU(__func__, targs); } - + // Validation checks have passed lets resize the output to actual size IntArrayRef mat1_sizes = mat1.sizes(); IntArrayRef mat2_sizes = mat2.sizes(); at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]}); at::native::resize_output(amax, {}); + // We are doing row-wise scaling + if (scaling_choice == ScalingType::RowWise) { + TORCH_CHECK(out.dtype() == kBFloat16, "Only bf16 high precsion output types are supported for row-wise scaling."); + at::cuda::detail::f8f8bf16_rowwise( + mat1, + mat2, + scale_a.value(), + scale_b.value(), + bias, + use_fast_accum, + out); + return {out, amax}; + } + cublasCommonArgs args(mat1, mat2, out); const auto out_dtype_ = args.result->scalar_type(); TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt"); diff --git a/aten/src/ATen/native/cuda/RowwiseScaledMM.cu b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu new file mode 100644 index 000000000000..84655d281afc --- /dev/null +++ b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu @@ -0,0 +1,536 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include + +// Determine if the architecture supports rowwise scaled mm +// Currenlty failing on windows with: https://github.com/NVIDIA/cutlass/issues/1571 +#if !defined(USE_ROCM) && !defined(_WIN32) && defined(CUDA_VERSION) && CUDA_VERSION >= 12000 + +#define BUILD_ROWWISE_FP8_KERNEL +#endif + +#if defined(BUILD_ROWWISE_FP8_KERNEL) + +// We are going to override the cuTensorMapEncodeTiled driver api with our lazy loader +static CUresult CUDAAPI nvrtc_cuTensorMapEncodeTiled( + CUtensorMap* tensorMap, + CUtensorMapDataType tensorDataType, + cuuint32_t tensorRank, + void* globalAddress, + const cuuint64_t* globalDim, + const cuuint64_t* globalStrides, + const cuuint32_t* boxDim, + const cuuint32_t* elementStrides, + CUtensorMapInterleave interleave, + CUtensorMapSwizzle swizzle, + CUtensorMapL2promotion l2Promotion, + CUtensorMapFloatOOBfill oobFill) { + return at::globalContext().getNVRTC().cuTensorMapEncodeTiled( + tensorMap, + tensorDataType, + tensorRank, + globalAddress, + globalDim, + globalStrides, + boxDim, + elementStrides, + interleave, + swizzle, + l2Promotion, + oobFill); +} + + +#include +#include +#include +#include +#include +#include +#include + +// Rename the global function symbol +#define cuTensorMapEncodeTiled nvrtc_cuTensorMapEncodeTiled +#include +#undef cuTensorMapEncodeTiled +// Set everything back to normal + +#include +#include +#include + +#include +#include +#include +#include + + +namespace { +// Cutlass rowwise kernel +template < + int TB_M, + int TB_N, + int TB_K, + int TBS_M, + int TBS_N, + int TBS_K, + bool PONG, + bool FAST_ACCUM, + bool USE_BIAS, + typename INPUT_DTYPE, + typename BIAS_DTYPE> +void f8f8bf16_rowwise_impl( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor x_scale, + at::Tensor w_scale, + c10::optional bias, + at::Tensor out) { + int M = XQ.size(0); + int N = WQ.size(1); + int K = XQ.size(1); + + TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous()); + TORCH_CHECK( + WQ.is_cuda() && WQ.ndimension() == 2 && WQ.stride(1) == WQ.size(0) && + WQ.stride(0) == 1); + + // auto Y = at::empty({M, N}, XQ.options().dtype(at::kBFloat16)); + + using ElementInputA = INPUT_DTYPE; + using LayoutInputA = cutlass::layout::RowMajor; + constexpr int AlignmentInputA = 16 / sizeof(ElementInputA); + + using ElementInputB = cutlass::float_e4m3_t; + using LayoutInputB = cutlass::layout::ColumnMajor; + constexpr int AlignmentInputB = 16 / sizeof(ElementInputB); + + using ElementBias = BIAS_DTYPE; + + using ElementOutput = cutlass::bfloat16_t; + using LayoutOutput = cutlass::layout::RowMajor; + constexpr int AlignmentOutput = 16 / sizeof(ElementOutput); + + using ElementAccumulator = float; + using ElementComputeEpilogue = float; + using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that + // supports the intended feature + using OperatorClass = cutlass::arch::OpClassTensorOp; + using TileShape = cute::Shape< + cute::Int, + cute::Int, + cute::Int>; // Threadblock-level + // tile size + using ClusterShape = cute::Shape< + cute::Int, + cute::Int, + cute::Int>; // Shape of the + // threadblocks in a + // cluster + using KernelSchedule = cutlass::gemm::collective:: + KernelScheduleAuto; // Kernel to launch based on the default setting in + // the Collective Builder + + // Implement rowwise scaling epilogue. + using XScale = cutlass::epilogue::fusion::Sm90ColBroadcast< + 0, + TileShape, + ElementComputeEpilogue, + cute::Stride, cute::Int<0>, cute::Int<0>>>; + + using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast< + PONG ? 2 : 1, + TileShape, + ElementComputeEpilogue, + cute::Stride, cute::Int<1>, cute::Int<0>>>; + + using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast< + PONG ? 2 : 1, + TileShape, + ElementBias, + cute::Stride, cute::Int<1>, cute::Int<0>>>; + + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, + ElementComputeEpilogue, // First stage output type. + ElementComputeEpilogue, // First stage input types. + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, + cute::conditional_t< // Second stage output type. + USE_BIAS, + ElementBias, + ElementOutput>, + ElementComputeEpilogue, // Second stage input types. + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute1 = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeBias = cutlass::epilogue::fusion::Sm90Compute< + cutlass::plus, + ElementOutput, // Final (optional) stage output type. + ElementBias, // Final stage input types. + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeBias = + cutlass::epilogue::fusion::Sm90EVT; + + using EpilogueEVT = + cute::conditional_t; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, + cutlass::arch::OpClassTensorOp, + TileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementComputeEpilogue, + ElementOutput, + LayoutOutput, + AlignmentOutput, + ElementOutput, + LayoutOutput, + AlignmentOutput, + cutlass::epilogue::TmaWarpSpecialized, + EpilogueEVT>::CollectiveOp; + + using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecialized; + using PongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; + using FastDefaultSchedule = + cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using FastPongSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using SlowAccum = cute::conditional_t; + using FastAccum = + cute::conditional_t; + using MainLoopSchedule = + cute::conditional_t; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + ElementInputA, + LayoutInputA, + AlignmentInputA, + ElementInputB, + LayoutInputB, + AlignmentInputB, + ElementAccumulator, + TileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainLoopSchedule>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using StrideInputA = typename Gemm::GemmKernel::StrideA; + using StrideInputB = typename Gemm::GemmKernel::StrideB; + using StrideOutput = typename Gemm::GemmKernel::StrideC; + + StrideInputA stride_a = cutlass::make_cute_packed_stride( + StrideInputA{}, cute::make_shape(M, K, 1)); + StrideInputB stride_b = cutlass::make_cute_packed_stride( + StrideInputB{}, cute::make_shape(N, K, 1)); + StrideOutput stride_output = cutlass::make_cute_packed_stride( + StrideOutput{}, cute::make_shape(M, N, 1)); + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K}, + {reinterpret_cast(XQ.data_ptr()), + stride_a, + reinterpret_cast(WQ.data_ptr()), + stride_b}, + {{}, // Epilogue thread we populate below. + (ElementOutput*)out.data_ptr(), + stride_output, + (ElementOutput*)out.data_ptr(), + stride_output}}; + + if constexpr (USE_BIAS) { + arguments.epilogue.thread = { + {reinterpret_cast(bias.value().data_ptr())}, // bias + // compute_1 + { + {reinterpret_cast( + x_scale.data_ptr())}, // x_scale + // compute_0 + { + {reinterpret_cast( + w_scale.data_ptr())}, // w_scale + {}, // Accumulator + {} // Multiplies + }, + {}, // Multiplies + }, + {}, // Plus + }; + } else { + arguments.epilogue.thread = { + {reinterpret_cast( + x_scale.data_ptr())}, // x_scale + // compute_0 + { + {reinterpret_cast( + w_scale.data_ptr())}, // w_scale + {}, // Accumulator + {} // Multiplies + }, + {}, // Multiplies + }; + } + + Gemm gemm; + + // Using the arguments, query for extra workspace required for matrix + // multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check the problem size is supported or not + cutlass::Status status = gemm.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot implement"); + } + + // Initialize CUTLASS kernel with arguments and workspace pointer + status = gemm.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot initialize"); + } + + status = gemm(at::cuda::getCurrentCUDAStream()); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error( + std::string("cutlass cannot run") + + cutlass::cutlassGetStatusString(status)); + } + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +// FP8 Rowwise Cutlass kernel dispatch. +enum class KernelMode { Small, Large, Default }; + +KernelMode get_kernel_mode(at::Tensor XQ, at::Tensor WQ) { + auto M = XQ.size(0); + auto K = XQ.size(1); + auto N = WQ.size(0); + // Use a large kernel if at least two shapes are large.... + bool use_large_kernel = + ((M >= 2048 && K >= 2048) || (M >= 2048 && N >= 2048) || + (K >= 2048 && N >= 2048)); + if (M <= 128 || N <= 128) { + return KernelMode::Small; + } else if (use_large_kernel) { + return KernelMode::Large; + } else { + return KernelMode::Default; + } +} + +template +void dispatch_fp8_rowwise_kernel( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + c10::optional bias, + at::Tensor out) { + KernelMode kernel = get_kernel_mode(XQ, WQ); + if (kernel == KernelMode::Small) { + return f8f8bf16_rowwise_impl< + 64, + 128, + 128, + 2, + 1, + 1, + false, + FastAccum, + UseBias, + InputDType, + BiasDType>(XQ, WQ, x_scale, w_scale, bias, out); + } else if (kernel == KernelMode::Large) { + return f8f8bf16_rowwise_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + FastAccum, + UseBias, + InputDType, + BiasDType>(XQ, WQ, x_scale, w_scale, bias, out); + } else { + return f8f8bf16_rowwise_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + FastAccum, + UseBias, + InputDType, + BiasDType>(XQ, WQ, x_scale, w_scale, bias, out); + } +} + +} // namespace + +#endif // !defined(USE_ROCM) + +namespace at::cuda::detail { +void f8f8bf16_rowwise( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor x_scale, // FP32 + at::Tensor w_scale, // FP32 + c10::optional bias, // BF16 + bool use_fast_accum, + at::Tensor& out) { +#if defined(BUILD_ROWWISE_FP8_KERNEL) + // Check datatypes. + TORCH_CHECK( + x_scale.dtype() == at::kFloat && w_scale.dtype() == at::kFloat, + "Scale tensors must be float32."); + if (bias.has_value()) { + TORCH_CHECK( + bias.value().dtype() == at::kFloat || + bias.value().dtype() == at::kBFloat16, + "Bias type must be bfloat16 or float32 if provided."); + } + // Extract problem size. + int M = XQ.size(0); + int N = WQ.size(1); + int K = XQ.size(1); + + bool use_bias = bias.has_value(); + bool bf16_bias = use_bias && bias.value().dtype() == at::kBFloat16; + + // Templatize based on input dtype. + bool use_e5m2 = XQ.dtype() == at::kFloat8_e5m2; + TORCH_CHECK(WQ.dtype() == at::kFloat8_e4m3fn, "For row-wise scaling the second input is required to be a float8_e4m3fn dtype."); + + if (use_bias) { + if (bf16_bias) { + if (use_fast_accum) { + if (use_e5m2) { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e5m2_t, + true, + true, + cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, out); + } else { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e4m3_t, + true, + true, + cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, out); + } + } else { + if (use_e5m2) { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e5m2_t, + false, + true, + cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, out); + } else { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e4m3_t, + false, + true, + cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, out); + } + } + } else { + if (use_fast_accum) { + if (use_e5m2) { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e5m2_t, + true, + true, + float>(XQ, WQ, x_scale, w_scale, bias, out); + } else { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e4m3_t, + true, + true, + float>(XQ, WQ, x_scale, w_scale, bias, out); + } + } else { + if (use_e5m2) { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e5m2_t, + false, + true, + float>(XQ, WQ, x_scale, w_scale, bias, out); + } else { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e4m3_t, + false, + true, + float>(XQ, WQ, x_scale, w_scale, bias, out); + } + } + } + } else { + if (use_fast_accum) { + if (use_e5m2) { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e5m2_t, + true, + false, + float>(XQ, WQ, x_scale, w_scale, bias, out); + } else { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e4m3_t, + true, + false, + float>(XQ, WQ, x_scale, w_scale, bias, out); + } + } else { + if (use_e5m2) { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e5m2_t, + false, + false, + float>(XQ, WQ, x_scale, w_scale, bias, out); + } else { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e4m3_t, + false, + false, + float>(XQ, WQ, x_scale, w_scale, bias, out); + } + } + } +#else // BUILD_ROWWISE_FP8_KERNEL + TORCH_CHECK(false, "Rowwise scaling is not currenlty supported on your device"); +#endif +} + +} // namespace at::cuda::detail diff --git a/aten/src/ATen/native/cuda/RowwiseScaledMM.h b/aten/src/ATen/native/cuda/RowwiseScaledMM.h new file mode 100644 index 000000000000..4d9054108c85 --- /dev/null +++ b/aten/src/ATen/native/cuda/RowwiseScaledMM.h @@ -0,0 +1,15 @@ +#pragma once +#include +#include + + +namespace at::cuda::detail { +TORCH_API void f8f8bf16_rowwise( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor x_scale, // FP32 + at::Tensor w_scale, // FP32 + c10::optional bias, // BF16 + bool use_fast_accum, + at::Tensor& out); +} // at::cuda::detail diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index a5c583580848..74381567a552 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -204,7 +204,6 @@ def _expand_to_batch(t: torch.Tensor): self.assertEqual(out1_gpu, out2_gpu[0]) - f8_msg = "FP8 is only supported on H100+ and sm_89 and MI300+ devices" if torch.version.hip: @@ -256,8 +255,12 @@ def amax_to_scale( scale.copy_(res) return scale -def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype): - amax = torch.max(torch.abs(x)) +def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype, dim=None): + if dim is None: + amax = torch.max(torch.abs(x)) + else: + amax = torch.max(torch.abs(x), dim=dim).values + return amax_to_scale(amax, float8_dtype, x.dtype) def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype): @@ -316,7 +319,6 @@ def mm_float8( def to_fp8_saturated( x: torch.Tensor, - x_scale: torch.tensor, fp8_dtype: torch.dtype ): """ @@ -339,8 +341,6 @@ def to_fp8_saturated( of a tensor has a maximum value of `amax1`, and the current amax value is `amax2`, where `amax1 < amax2`. """ - x_scaled = x * x_scale - if fp8_dtype == e4m3_type: x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS) elif fp8_dtype == e5m2_type: @@ -353,8 +353,6 @@ def to_fp8_saturated( @unittest.skipIf(not torch.cuda.is_available(), "CUDA not found") class TestFP8MatmulCuda(TestCase): - - @unittest.skipIf(not scaled_mm_supported_device(), f8_msg) def _test_tautological_mm(self, device: str = "cuda", x_dtype: torch.dtype = e4m3_type, @@ -418,8 +416,8 @@ def test_scaled_mm_vs_emulated(self, base_dtype): x_scale = tensor_to_scale(x, input_dtype).float() y_scale = tensor_to_scale(y, input_dtype).float() - x_fp8 = to_fp8_saturated(x, x_scale, e4m3_type) - y_fp8 = to_fp8_saturated(y, y_scale, e4m3_type) + x_fp8 = to_fp8_saturated(x * x_scale, e4m3_type) + y_fp8 = to_fp8_saturated(y * y_scale, e4m3_type) # Calculate actual F8 mm out_scaled_mm, output_amax_scaled = mm_float8( @@ -526,6 +524,137 @@ def test_float8_scale_fast_accum(self, device) -> None: out_fp8_s, amax_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, use_fast_accum=True) self.assertEqual(out_fp8, out_fp8_s) + @unittest.skipIf(not scaled_mm_supported_device() or IS_WINDOWS, f8_msg) + @skipIfRocm() + @parametrize("use_fast_accum", [True, False]) + def test_float8_rowwise_scaling_sanity(self, device, use_fast_accum: bool) -> None: + M, K, N = (1024, 512, 2048) + fill_value = 0.5 + x = torch.full((M, K), fill_value, device=device) + y = torch.full((N, K), fill_value, device=device) + + x_scales = torch.ones(x.shape[0], device=device, dtype=torch.float32) + y_scales = torch.ones(y.shape[0], device=device, dtype=torch.float32) + + x_fp8 = x.to(torch.float8_e4m3fn) + y_fp8 = y.to(torch.float8_e4m3fn).t() + + out_fp8, _ = torch._scaled_mm( + x_fp8, + y_fp8, + scale_a=x_scales, + scale_b=y_scales, + out_dtype=torch.bfloat16, + use_fast_accum=use_fast_accum, + ) + self.assertEqual( + out_fp8.to(torch.float32), torch.full((M, N), K * (fill_value**2), device=device) + ) + + @unittest.skipIf(not scaled_mm_supported_device() or IS_WINDOWS, f8_msg) + @skipIfRocm() + def test_float8_error_messages(self, device) -> None: + M, K, N = (1024, 512, 2048) + fill_value = 0.5 + x = torch.full((M, K), fill_value, device=device) + y = torch.full((N, K), fill_value, device=device) + + x_fp8 = x.to(torch.float8_e4m3fn) + y_fp8 = y.to(torch.float8_e4m3fn).t() + + with self.assertRaisesRegex( + RuntimeError, + "For row-wise scaling, scale_a must be size 1024 but got 1 and scale_b must be size 2048 but got 2", + ): + torch._scaled_mm( + x_fp8, + y_fp8, + scale_a=torch.ones((), device="cuda"), + scale_b=torch.ones((2), device="cuda"), + out_dtype=torch.bfloat16, + ) + + with self.assertRaisesRegex( + RuntimeError, + "For row-wise scaling, scale_b must have size 2048 but got 2049.", + ): + torch._scaled_mm( + x_fp8, + y_fp8, + scale_a=torch.ones((M), device="cuda"), + scale_b=torch.ones((N + 1), device="cuda"), + out_dtype=torch.bfloat16, + ) + with self.assertRaisesRegex( + RuntimeError, + "Both scale_a and scale_b must be 1-dimensional tensors", + ): + torch._scaled_mm( + x_fp8, + y_fp8, + scale_a=torch.ones((M), device="cuda"), + scale_b=torch.ones((N, N), device="cuda"), + out_dtype=torch.bfloat16, + ) + + with self.assertRaisesRegex( + RuntimeError, + "Both scale_a and scale_b must be contiguous.", + ): + torch._scaled_mm( + x_fp8, + y_fp8, + scale_a=torch.ones((M), device="cuda"), + scale_b=torch.ones((N * 2), device="cuda")[::2], + out_dtype=torch.bfloat16, + ) + + with self.assertRaisesRegex( + RuntimeError, + "For row-wise scaling the second input is required to be a float8_e4m3fn dtype.", + ): + torch._scaled_mm( + x_fp8, + y_fp8.to(torch.float8_e5m2), + scale_a=torch.ones((M), device="cuda"), + scale_b=torch.ones((N), device="cuda"), + out_dtype=torch.bfloat16, + ) + + @unittest.skipIf(not scaled_mm_supported_device() or IS_WINDOWS, f8_msg) + @skipIfRocm() + @parametrize("base_dtype", [torch.bfloat16]) + def test_scaled_mm_vs_emulated_row_wise(self, base_dtype): + torch.manual_seed(42) + input_dtype = e4m3_type + output_dtype = base_dtype + + x = torch.randn(16, 16, device="cuda", dtype=base_dtype) + y = torch.randn(32, 16, device="cuda", dtype=base_dtype).t() + + x_scales = tensor_to_scale(x, input_dtype, dim=1).float() + y_scales = tensor_to_scale(y, input_dtype, dim=0).float() + + x_fp8 = to_fp8_saturated(x * x_scales[:, None], e4m3_type) + y_fp8 = to_fp8_saturated(y * y_scales[None, :], e4m3_type) + + # Calculate actual F8 mm + out_scaled_mm, _ = mm_float8( + x_fp8, y_fp8, a_scale=x_scales, b_scale=y_scales, output_dtype=output_dtype + ) + + # Calculate emulated F8 mm + out_emulated, _ = mm_float8_emulated( + x_fp8, x_scales[:, None], y_fp8, y_scales[None, :], output_dtype + ) + + if base_dtype in {torch.bfloat16, torch.float16}: + atol, rtol = 7e-2, 7e-2 + else: + atol, rtol = 2e-3, 2e-3 + + torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol) + @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") @unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions") diff --git a/third_party/cutlass.BUILD b/third_party/cutlass.BUILD index e712d59597cc..e3e7b7b288e7 100644 --- a/third_party/cutlass.BUILD +++ b/third_party/cutlass.BUILD @@ -5,7 +5,17 @@ load("@rules_cc//cc:defs.bzl", "cc_library") cc_library( name = "cutlass", - hdrs = glob(["include/**/*.h", "include/**/*.hpp"]), - includes = ["include/"], + hdrs = glob([ + "include/**/*.h", + "include/**/*.hpp", + "include/**/*.inl", + "tools/util/include/**/*.h", + "tools/util/include/**/*.hpp", + "tools/util/include/**/*.inl", + ]), + includes = [ + "include/", + "tools/util/include/", + ], visibility = ["//visibility:public"], ) From 22964d1007d1d87964df1db609da53b35b00316f Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Mon, 3 Jun 2024 12:31:48 -0700 Subject: [PATCH 371/706] [DSD] Deprecate submodules feature for DSD (#127793) Summary: Getting a partial of the state_dict and set the state_dict with the type of Dict[nn.Module, Dict[str, Any]] is too complicated and can confuse users. The features can be achieved by simple pre-processing and post-processing by users. So this PR adds the deprecation warning to the feature. The previous PR, https://github.com/pytorch/pytorch/pull/127070, assumes no one is using the feature and remove it without the grace period. This seems to be too aggresive and causes some concerns. This PR adds the deprecation warning and tests. We will remove the support in 2.5. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127793 Approved by: https://github.com/LucasLLC --- .../distributed/checkpoint/test_state_dict.py | 55 +++++++++++++++ torch/distributed/checkpoint/state_dict.py | 67 +++++++++++++++++-- 2 files changed, 116 insertions(+), 6 deletions(-) diff --git a/test/distributed/checkpoint/test_state_dict.py b/test/distributed/checkpoint/test_state_dict.py index 329f8015dc7c..b0d62c32c6e5 100644 --- a/test/distributed/checkpoint/test_state_dict.py +++ b/test/distributed/checkpoint/test_state_dict.py @@ -738,6 +738,61 @@ def test_flattened_osd(self) -> None: ) self.assertEqual(fsdp_optim.state_dict(), fsdp_optim2.state_dict()) + @with_comms + @skip_if_lt_x_gpu(1) + def test_deprecate_partial(self) -> None: + model = CompositeParamModel(device=torch.device("cuda")) + + model_state_dict1 = get_model_state_dict(model) + model_state_dict1 = copy.deepcopy(model_state_dict1) + with self.assertWarnsRegex( + FutureWarning, + "Getting submodules only model/optim state_dict is deprecated", + ): + model_state_dict2 = get_model_state_dict(model, submodules={model.l}) + model_state_dict2 = copy.deepcopy(model_state_dict2) + with self.assertWarnsRegex( + FutureWarning, + "Getting submodules only model/optim state_dict is deprecated", + ): + model_state_dict3 = get_model_state_dict( + model, + submodules={model.l}, + options=StateDictOptions(keep_submodule_prefixes=False), + ) + model_state_dict3 = copy.deepcopy(model_state_dict3) + self.assertEqual(len(model_state_dict2), 2) + self.assertEqual(len(model_state_dict3), 2) + for key in model_state_dict3.keys(): + full_fqn = f"l.{key}" + value1 = model_state_dict1[full_fqn] + value2 = model_state_dict2[full_fqn] + value3 = model_state_dict3[key] + self.assertEqual(value1, value2) + self.assertEqual(value2, value3) + + zeros_state_dict = { + k: torch.zeros_like(v) for k, v in model_state_dict1.items() + } + model.load_state_dict(zeros_state_dict) + set_model_state_dict( + model, + model_state_dict=model_state_dict2, + options=StateDictOptions(strict=False), + ) + self.assertEqual(model.l.weight, model_state_dict1["l.weight"]) + self.assertEqual(model.l.bias, model_state_dict1["l.bias"]) + + model.load_state_dict(zeros_state_dict) + with self.assertWarnsRegex(FutureWarning, "Passing model_state_dict as a "): + set_model_state_dict( + model, + model_state_dict={model.l: model_state_dict3}, + options=StateDictOptions(strict=False), + ) + self.assertEqual(model.l.weight, model_state_dict1["l.weight"]) + self.assertEqual(model.l.bias, model_state_dict1["l.bias"]) + class TestNoComm(MultiProcessTestCase): def setUp(self) -> None: diff --git a/torch/distributed/checkpoint/state_dict.py b/torch/distributed/checkpoint/state_dict.py index 46701c3493d5..144dafd0a561 100644 --- a/torch/distributed/checkpoint/state_dict.py +++ b/torch/distributed/checkpoint/state_dict.py @@ -1,6 +1,7 @@ import contextlib import functools import gc +import warnings from dataclasses import asdict, dataclass, field from itertools import chain from typing import ( @@ -123,7 +124,7 @@ class StateDictOptions: won't contain any frozen parameters -- the ``requires_grad`` is False. The default value is False. - - ``keep_submodule_prefixes``: when ``submodules`` is not None, this option + - ``keep_submodule_prefixes`` (deprecated): when ``submodules`` is not None, this option indicates whether to keep the submodule prefixes from the state_dict keys. or example, if the submodule is ``module.pretrain`` and the full FQN of the parameter is ``pretrain.layer1.weight`` of the param. When this option @@ -275,6 +276,13 @@ def _verify_options( """ Verify the model and options passed by the user and generates _StateDictInfo. """ + if submodules: + warnings.warn( + "Getting submodules only model/optim state_dict is deprecated and " + "will be removed in 2.5. This feature can be achieved by manually " + "filtering out the state_dict returned from get_state_dict.", + FutureWarning, + ) if optim_only and not optims: raise RuntimeError( "Optimizers are not passed in but optim_only is set to True." @@ -910,7 +918,7 @@ def get_model_state_dict( Args: model (nn.Module): the nn.Module to the model. - submodules: Optional[Set[nn.Module]]: only return the model parameters + submodules (deprecated): Optional[Set[nn.Module]]: only return the model parameters that belong to the submodules. options (StateDictOptions): the options to control how model state_dict and optimizer state_dict should be returned. See @@ -950,7 +958,7 @@ def get_optimizer_state_dict( model (nn.Module): the nn.Module to the model. optimizers (Union[None, Optimizer, Iterable[Optimizer]]): The optimizers that are used to optimize ``model``. - submodules: Optional[Set[nn.Module]]: only return the model parameters + submodules (deprecated): Optional[Set[nn.Module]]: only return the model parameters that belong to the submodules. options (StateDictOptions): the options to control how model state_dict and optimizer state_dict should be returned. See @@ -1037,7 +1045,7 @@ def get_state_dict( model (nn.Module): the nn.Module to the model. optimizers (Union[None, Optimizer, Iterable[Optimizer]]): The optimizers that are used to optimize ``model``. - submodules: Optional[Set[nn.Module]]: only return the model parameters + submodules (deprecated): Optional[Set[nn.Module]]: only return the model parameters that belong to the submodules. options (StateDictOptions): the options to control how model state_dict and optimizer state_dict should be returned. See @@ -1068,6 +1076,39 @@ def get_state_dict( return model_state_dict, optim_state_dict +def _unflatten_model_state_dict( + model: nn.Module, + state_dict: Union[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]], +) -> Dict[str, ValueType]: + if not state_dict: + return {} + + if isinstance(next(iter(state_dict.keys())), nn.Module): + warnings.warn( + "Passing model_state_dict as a ``Dict[nn.Module, Dict[str, Any]]``" + "is deprecated and will be removed in 2.5. If you need this " + "feature, please preprocessing the model_state_dict to achieve the " + "same functionality.", + FutureWarning, + ) + cast_state_dict = cast(Dict[nn.Module, Dict[str, ValueType]], state_dict) + new_state_dict: Dict[str, ValueType] = {} + for submodule, sub_state_dict in cast_state_dict.items(): + for name, m in model.named_modules(): + if m != submodule: + continue + + fqns = _get_fqns(model, name) + assert len(fqns) == 1, "FQNs for a submodule should only have 1 element" + prefix = f"{next(iter(fqns))}." + new_state_dict.update( + {prefix + subfqn: value for subfqn, value in sub_state_dict.items()} + ) + return new_state_dict + else: + return cast(Dict[str, ValueType], state_dict) + + def set_model_state_dict( model: nn.Module, model_state_dict: Dict[str, ValueType], @@ -1081,7 +1122,11 @@ def set_model_state_dict( Args: model (nn.Module): the nn.Module to the model. - model_state_dict: (Dict[str, ValueType]): the model state_dict to load. + model_state_dict: (Dict[str, ValueType]): + the model state_dict to load. If the key of the ``model_state_dict`` + is nn.Module, the key is a submodule of ``model`` and the value should + be the state_dict of the submodule. When loading the state_dict, + the prefix of the submodule will be append to the state_dict. options (StateDictOptions): the options to control how model state_dict and optimizer state_dict should be loaded. See `StateDictOptions` for the details. @@ -1093,6 +1138,9 @@ def set_model_state_dict( :type model_state_dict: typing.Dict[str, ValueType] """ + model_state_dict: Dict[str, ValueType] = _unflatten_model_state_dict( + model, model_state_dict + ) with gc_context(): info = _verify_options(model, tuple(), optim_only=False, options=options) @@ -1161,7 +1209,11 @@ def set_state_dict( model (nn.Module): the nn.Module to the model. optimizers (Union[Optimizer, Iterable[Optimizer]]): The optimizers that are used to optimize ``model``. - model_state_dict: (Dict[str, ValueType]]): the model state_dict to load. + model_state_dict: (Union[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]]): + the model state_dict to load. If the key of the ``model_state_dict`` + is nn.Module, the key is a submodule of ``model`` and the value should + be the state_dict of the submodule. When loading the state_dict, + the prefix of the submodule will be append to the state_dict. optim_state_dict: OptimizerStateType: the optimizer state_dict to load. options (StateDictOptions): the options to control how @@ -1177,6 +1229,9 @@ def set_state_dict( :type optim_state_dict: typing.OptimizerStateType """ + model_state_dict: Dict[str, ValueType] = _unflatten_model_state_dict( + model, model_state_dict + ) with gc_context(): optimizers = ( (optimizers,) From 9acc19f8da260d8fd75227bd4f7f267f4e1ba209 Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Mon, 3 Jun 2024 18:49:20 +0100 Subject: [PATCH 372/706] [inductor] Take absolute value of strides when picking loop order (#127425) Fixes #126860 The stride hint is found by comparing the value of the indexing expression evaluated at `idx` set to all zeros and at `idx[dim] = 1`. This causes a problem for padded inputs where 0 and 1 are still in the padded region. In particular, for reflection padding this causes the stride to be negative. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127425 Approved by: https://github.com/lezcano --- test/inductor/test_cuda_repro.py | 34 ++++++++++++++++++++++++++++++++ torch/_inductor/scheduler.py | 6 ++++-- 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py index c1ce2769a658..8365d216f82c 100644 --- a/test/inductor/test_cuda_repro.py +++ b/test/inductor/test_cuda_repro.py @@ -1204,6 +1204,40 @@ def forward(self, x): out2 = m(input_tensor) self.assertEqual(out, out2, atol=1e-3, rtol=1e-3) + def test_reflection_pad_loop_order(self): + def fn(x, y): + a = torch.nn.functional.pad(x, (5, 5, 5, 5), mode="reflect") + b = torch.nn.functional.pad(y, (5, 5, 5, 5), mode="reflect") + return a + b + + cfn = torch.compile(fn) + a = torch.rand((10, 10, 10), device="cuda") + b = torch.rand((10, 10, 10), device="cuda") + expect = fn(a, b) + actual, code = run_and_get_code(cfn, a, b) + self.assertEqual(expect, actual) + + # Expect the code iterates in contiguous order, and is not tiled + kernel_code = "\n".join(code[0].split("\n")[50:64]) + self.assertExpectedInline( + kernel_code, + """\ +@triton.jit +def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 4000 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex % 20 + x1 = (xindex // 20) % 20 + x2 = (xindex // 400) + x3 = xindex + tmp0 = tl.load(in_ptr0 + (99 + ((-1)*(tl_math.abs((-9) + (tl_math.abs((-5) + x0))))) + ((-10)*(tl_math.abs((-9) + (tl_math.abs((-5) + x1))))) + (100*x2)), xmask, eviction_policy='evict_last') + tmp1 = tl.load(in_ptr1 + (99 + ((-1)*(tl_math.abs((-9) + (tl_math.abs((-5) + x0))))) + ((-10)*(tl_math.abs((-9) + (tl_math.abs((-5) + x1))))) + (100*x2)), xmask, eviction_policy='evict_last') + tmp2 = tmp0 + tmp1 + tl.store(out_ptr0 + (x3), tmp2, xmask)""", # noqa: B950 + ) + if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index f17fb1f12daa..88ff1714a3f6 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -1279,8 +1279,10 @@ def index_cmp(a: int, b: int) -> int: # 1-sizes don't matter, just move them to the end return cmp(sizes[a] == 1, sizes[b] == 1) - stride_len_a = [sl[a] for sl in stride_lengths] - stride_len_b = [sl[b] for sl in stride_lengths] + # Take abs, otherwise flipped dimensions are treated as smaller + # strides than contiguous dims + stride_len_a = [abs(sl[a]) for sl in stride_lengths] + stride_len_b = [abs(sl[b]) for sl in stride_lengths] # equivalent to # np.logical_or(stride_lengths[:, b] == 0, stride_lengths[:, a] < stride_lengths[:, b]).all() From a9cc147fa1f3d9d26752ed39b3da55a92d75cdcd Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Mon, 3 Jun 2024 12:31:54 -0700 Subject: [PATCH 373/706] [DSD][FSDP1] Deprecate FSDP.state_dict_type and redirect users to DSD (#127794) Summary: As title Pull Request resolved: https://github.com/pytorch/pytorch/pull/127794 Approved by: https://github.com/awgu ghstack dependencies: #127793 --- .../distributed/checkpoint/test_state_dict.py | 22 ++++++++++++++++++- torch/distributed/checkpoint/state_dict.py | 18 ++++++++++++++- .../fsdp/fully_sharded_data_parallel.py | 9 ++++++++ 3 files changed, 47 insertions(+), 2 deletions(-) diff --git a/test/distributed/checkpoint/test_state_dict.py b/test/distributed/checkpoint/test_state_dict.py index b0d62c32c6e5..58d1f20ad911 100644 --- a/test/distributed/checkpoint/test_state_dict.py +++ b/test/distributed/checkpoint/test_state_dict.py @@ -30,7 +30,7 @@ set_optimizer_state_dict, StateDictOptions, ) -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType from torch.distributed.fsdp.wrap import ModuleWrapPolicy from torch.distributed.optim import _apply_optimizer_in_backward from torch.nn.parallel import DistributedDataParallel as DDP @@ -793,6 +793,26 @@ def test_deprecate_partial(self) -> None: self.assertEqual(model.l.weight, model_state_dict1["l.weight"]) self.assertEqual(model.l.bias, model_state_dict1["l.bias"]) + @with_comms + @skip_if_lt_x_gpu(1) + def test_deprecate_fsdp_api(self) -> None: + device_mesh = init_device_mesh("cuda", (self.world_size,)) + model = CompositeParamModel(device=torch.device("cuda")) + fsdp_model = FSDP(copy.deepcopy(model), device_mesh=device_mesh) + with self.assertWarnsRegex( + FutureWarning, + r"FSDP.state_dict_type\(\) and FSDP.set_state_dict_type\(\) are being deprecated", + ): + with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT): + fsdp_model.state_dict() + + with self.assertRaisesRegex(AssertionError, "FutureWarning not triggered"): + with self.assertWarnsRegex( + FutureWarning, + r"FSDP.state_dict_type\(\) and FSDP.set_state_dict_type\(\) are being deprecated", + ): + get_model_state_dict(model) + class TestNoComm(MultiProcessTestCase): def setUp(self) -> None: diff --git a/torch/distributed/checkpoint/state_dict.py b/torch/distributed/checkpoint/state_dict.py index 144dafd0a561..0d1a3a625a25 100644 --- a/torch/distributed/checkpoint/state_dict.py +++ b/torch/distributed/checkpoint/state_dict.py @@ -341,8 +341,24 @@ def _verify_options( ) state_dict_type = StateDictType.SHARDED_STATE_DICT + @contextlib.contextmanager + def fsdp_state_dict_type_without_warning( + module, + state_dict_type, + state_dict_config, + optim_state_dict_config, + ): + with warnings.catch_warnings(): + with FSDP.state_dict_type( + module=module, + state_dict_type=state_dict_type, + state_dict_config=state_dict_config, + optim_state_dict_config=optim_state_dict_config, + ): + yield + fsdp_context = functools.partial( - FSDP.state_dict_type, + fsdp_state_dict_type_without_warning, module=model, state_dict_type=state_dict_type, state_dict_config=state_dict_config, diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py index c798ed1818d7..9edd057a8f37 100644 --- a/torch/distributed/fsdp/fully_sharded_data_parallel.py +++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py @@ -686,6 +686,15 @@ def set_state_dict_type( A StateDictSettings that include the previous state_dict type and configuration for the module. """ + warnings.warn( + "FSDP.state_dict_type() and FSDP.set_state_dict_type() are being " + "deprecated. Please use APIs, get_state_dict() and set_state_dict(), " + "which can support different parallelisms, FSDP1, FSDP2, DDP. " + "API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html" + "#torch.distributed.checkpoint.state_dict.get_state_dict ." + "Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .", + FutureWarning, + ) _state_dict_type_to_config = { StateDictType.FULL_STATE_DICT: FullStateDictConfig, StateDictType.LOCAL_STATE_DICT: LocalStateDictConfig, From 4adee71155bec4e419bac32be2cbc1763bc6c98f Mon Sep 17 00:00:00 2001 From: "Andrew M. James" Date: Tue, 4 Jun 2024 17:56:39 +0000 Subject: [PATCH 374/706] [dynamo] Support ndarray.dtype attribute access (#124490) Pull Request resolved: https://github.com/pytorch/pytorch/pull/124490 Approved by: https://github.com/lezcano ghstack dependencies: #125717 --- test/dynamo/test_functions.py | 4 ++ test/test_binary_ufuncs.py | 4 +- test/test_unary_ufuncs.py | 2 +- .../numpy_tests/core/test_multiarray.py | 15 +++++-- torch/_dynamo/variables/misc.py | 5 +++ torch/_dynamo/variables/tensor.py | 3 ++ torch/testing/_internal/common_utils.py | 42 +++++++++---------- 7 files changed, 46 insertions(+), 29 deletions(-) diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index e2baebf60321..4d0285871ced 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -1619,6 +1619,10 @@ def test_numpy_dtype_call_in_function(x): dt = np.dtype("float") return np.full_like(x, 2.4, dtype=dt) + @make_test + def test_numpy_dtype_attr(x): + return np.ones_like(x).dtype == x.dtype + @make_test def test_numpy_linalg(x): return np.linalg.norm(x.numpy(), axis=0) diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py index ffa3e5388979..f1423d0ac8cb 100644 --- a/test/test_binary_ufuncs.py +++ b/test/test_binary_ufuncs.py @@ -121,9 +121,7 @@ def _test_reference_numerics(self, dtype, op, gen, equal_nan=True): def _helper_reference_numerics( expected, actual, msg, exact_dtype, equal_nan=True ): - if not torch.can_cast( - numpy_to_torch_dtype_dict[expected.dtype.type], dtype - ): + if not torch.can_cast(numpy_to_torch_dtype_dict[expected.dtype], dtype): exact_dtype = False if dtype is torch.bfloat16 and expected.dtype == np.float32: diff --git a/test/test_unary_ufuncs.py b/test/test_unary_ufuncs.py index f47e7d36222f..b232d47260f4 100644 --- a/test/test_unary_ufuncs.py +++ b/test/test_unary_ufuncs.py @@ -184,7 +184,7 @@ def _helper_reference_numerics( expected, actual, msg, exact_dtype, equal_nan=True ): if not torch.can_cast( - numpy_to_torch_dtype_dict[expected.dtype.type], dtype + numpy_to_torch_dtype_dict[expected.dtype], dtype ): exact_dtype = False diff --git a/test/torch_np/numpy_tests/core/test_multiarray.py b/test/torch_np/numpy_tests/core/test_multiarray.py index 76af79f62084..a957c8dd86c4 100644 --- a/test/torch_np/numpy_tests/core/test_multiarray.py +++ b/test/torch_np/numpy_tests/core/test_multiarray.py @@ -1833,7 +1833,7 @@ def test_argsort_axis(self): a = np.array(["aaaaaaaaa" for i in range(100)], dtype=np.unicode_) assert_equal(a.argsort(kind="m"), r) - @xpassIfTorchDynamo # (reason="TODO: searchsorted with nans differs in pytorch") + @xfail # (reason="TODO: searchsorted with nans differs in pytorch") @parametrize( "a", [ @@ -1905,7 +1905,7 @@ def test_searchsorted_n_elements(self): b = a.searchsorted([0, 1, 2], "right") assert_equal(b, [0, 2, 2]) - @xpassIfTorchDynamo # ( + @xfail # ( # reason="RuntimeError: self.storage_offset() must be divisible by 8" # ) def test_searchsorted_unaligned_array(self): @@ -1984,7 +1984,7 @@ def test_searchsorted_with_invalid_sorter(self): # assert_raises(ValueError, np.searchsorted, a, 0, sorter=[-1, 0, 1, 2, 3]) # assert_raises(ValueError, np.searchsorted, a, 0, sorter=[4, 0, -1, 2, 3]) - @xpassIfTorchDynamo # (reason="self.storage_offset() must be divisible by 8") + @xfail # (reason="self.storage_offset() must be divisible by 8") def test_searchsorted_with_sorter(self): a = np.random.rand(300) s = a.argsort() @@ -3713,7 +3713,14 @@ def test_out_overlap(self): y = np.take(x, [1, 2, 3], out=x[2:5], mode="wrap") assert_equal(y, np.array([1, 2, 3])) - @parametrize("shape", [(1, 2), (1,), ()]) + @parametrize( + "shape", + [ + subtest((1, 2)), + subtest((1,)), + subtest((), decorators=[skip("Sensitive to np version")]), + ], + ) def test_ret_is_out(self, shape): # 0d arrays should not be an exception to this rule x = np.arange(5) diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index cc0fb7096701..06372f4d53b5 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -1189,6 +1189,11 @@ class NumpyTypeInfoVariable(ConstantLikeVariable): class NumpyDTypeVariable(ConstantLikeVariable): _error_prefix = "np.dtype[...]" + def __init__(self, value, **kwargs): + if isinstance(value, tnp.DType): + value = ConstantLikeVariable.np_dtype(value.name) + super().__init__(value, **kwargs) + def as_proxy(self): """Similar to how numpy dtype descriptors (e.g. np.float32 ) are handled by NumpyVariable: diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 0552a8e62122..30cbd556d0b2 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -1089,6 +1089,7 @@ def var_getattr(self, tx, name): from ..utils import numpy_attr_wrapper from .builder import wrap_fx_proxy + from .misc import NumpyDTypeVariable result = None @@ -1135,6 +1136,8 @@ def insert_into_graph(): if not has_free_symbols(r := example_ndarray.size): return ConstantVariable.create(int(r)) return insert_into_graph() + if name == "dtype": + return NumpyDTypeVariable(example_ndarray.dtype) elif name in ["base", "flags", "dtype"]: unimplemented(f"TODO: add support for ndarray.{name}") elif name in ["__version__"]: diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index e748ff0388fb..88cba64052c2 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -1500,31 +1500,31 @@ def wrapper(*args, **kwargs): # Dict of NumPy dtype -> torch dtype (when the correspondence exists) numpy_to_torch_dtype_dict = { - np.bool_ : torch.bool, - np.uint8 : torch.uint8, - np.uint16 : torch.uint16, - np.uint32 : torch.uint32, - np.uint64 : torch.uint64, - np.int8 : torch.int8, - np.int16 : torch.int16, - np.int32 : torch.int32, - np.int64 : torch.int64, - np.float16 : torch.float16, - np.float32 : torch.float32, - np.float64 : torch.float64, - np.complex64 : torch.complex64, - np.complex128 : torch.complex128 + np.dtype(np.bool_) : torch.bool, + np.dtype(np.uint8) : torch.uint8, + np.dtype(np.uint16) : torch.uint16, + np.dtype(np.uint32) : torch.uint32, + np.dtype(np.uint64) : torch.uint64, + np.dtype(np.int8) : torch.int8, + np.dtype(np.int16) : torch.int16, + np.dtype(np.int32) : torch.int32, + np.dtype(np.int64) : torch.int64, + np.dtype(np.float16) : torch.float16, + np.dtype(np.float32) : torch.float32, + np.dtype(np.float64) : torch.float64, + np.dtype(np.complex64) : torch.complex64, + np.dtype(np.complex128): torch.complex128 } -# numpy dtypes like np.float64 are not instances, but rather classes. This leads to rather absurd cases like -# np.float64 != np.dtype("float64") but np.float64 == np.dtype("float64").type. -# Especially when checking against a reference we can't be sure which variant we get, so we simply try both. +# numpy dtypes like np.float64 are not instances, but rather classes. This leads +# to rather absurd cases like np.float64 != np.dtype("float64") but +# np.dtype(np.float64) == np.dtype("float64") and +# np.dtype(np.dtype("float64")) == np.dtype("float64"). Especially when +# checking against a reference we can't be sure which variant we get, so we +# simply apply the conversion. def numpy_to_torch_dtype(np_dtype): - try: - return numpy_to_torch_dtype_dict[np_dtype] - except KeyError: - return numpy_to_torch_dtype_dict[np_dtype.type] + return numpy_to_torch_dtype_dict[np.dtype(np_dtype)] def has_corresponding_torch_dtype(np_dtype): From 6454e95824f7ec7374435f02ac29062edfe746f6 Mon Sep 17 00:00:00 2001 From: willfengg Date: Wed, 5 Jun 2024 01:07:28 -0700 Subject: [PATCH 375/706] [FSDP2] enable CI for torch.compile(root Transformer) (#127832) This CI showcases FSDP2 works with `torch.compile` root model, since FSDP1 can do the same compiling root Transformer without AC: `pytest test/distributed/_composable/fsdp/test_fully_shard_training.py -k test_train_parity_multi_group` compiling root Transformer with AC: `pytest test/distributed/_composable/fsdp/test_fully_shard_training.py -k test_train_parity_with_activation_checkpointing` Pull Request resolved: https://github.com/pytorch/pytorch/pull/127832 Approved by: https://github.com/awgu --- .../fsdp/test_fully_shard_training.py | 73 ++++++++----------- torch/testing/_internal/common_distributed.py | 2 + torch/testing/_internal/common_fsdp.py | 68 ++++++++++------- 3 files changed, 72 insertions(+), 71 deletions(-) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_training.py b/test/distributed/_composable/fsdp/test_fully_shard_training.py index 6634e142312b..a34a59f3cde6 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_training.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_training.py @@ -243,7 +243,6 @@ def world_size(self) -> int: return min(8, torch.cuda.device_count()) @skip_if_lt_x_gpu(2) - @test_compiled_fsdp() def test_train_parity_single_group(self): """Tests train parity with DDP for a single FSDP group.""" self.run_subtests( @@ -275,7 +274,8 @@ def _test_train_parity_single_group(self, lin_shapes: List[Tuple[int, int]]): self.assertEqual(losses[0], losses[1]) @skip_if_lt_x_gpu(2) - def test_train_parity_multi_group_eager(self): + @test_compiled_fsdp(compile_compute_on_module=Transformer) + def test_train_parity_multi_group(self): """ Tests train parity against DDP when using multiple parameter groups for communication (for communication and computation overlap plus memory @@ -294,21 +294,6 @@ def test_train_parity_multi_group_eager(self): self._test_train_parity_multi_group, ) - @skip_if_lt_x_gpu(2) - def test_train_parity_multi_group_compile(self): - self.run_subtests( - { - "reshard_after_forward": [True, False], - "device_type": ["cuda"], - "offload_policy": [OffloadPolicy()], - "delay_after_forward": [False, True], - "delay_before_all_gather": [False], - "delay_before_reduce_scatter": [False], - "delay_before_optim": [False, True], - }, - self._test_train_parity_multi_group, - ) - @skip_if_lt_x_gpu(2) def test_train_parity_multi_group_cpu_offload_eager(self): """ @@ -353,7 +338,15 @@ def _test_train_parity_multi_group( assert device_type in ("cuda", "cpu"), f"{device_type}" torch.manual_seed(42) lin_dim = 32 - model = nn.Sequential(*[MLP(lin_dim, torch.device("cpu")) for _ in range(3)]) + vocab_size = 1024 + model_args = ModelArgs( + n_layers=3, + n_heads=4, + vocab_size=vocab_size, + max_seq_len=64, + dropout_p=0, + ) + model = Transformer(model_args) ref_model = copy.deepcopy(model) if device_type == "cuda": replicate(ref_model.cuda(), device_ids=[self.rank]) @@ -368,8 +361,9 @@ def _test_train_parity_multi_group( reshard_after_forward=reshard_after_forward, offload_policy=offload_policy, ) - for mlp in model: - fully_shard_fn(mlp) + for module in model.modules(): + if isinstance(module, TransformerBlock): + fully_shard_fn(module) fully_shard_fn(model) optim = torch.optim.Adam(model.parameters(), lr=1e-2) @@ -398,7 +392,7 @@ def delayed_reduce_scatter(*args, **kwargs): ) with patch_all_gather_ctx, patch_reduce_scatter_ctx: for iter_idx in range(10): - inp = torch.randn((8, lin_dim), device=torch.device(device_type)) + inp = torch.randint(0, vocab_size, (3, 64), device=device_type) losses: List[torch.Tensor] = [] for _model, _optim in ((ref_model, ref_optim), (model, optim)): _optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) @@ -412,7 +406,6 @@ def delayed_reduce_scatter(*args, **kwargs): self.assertEqual(losses[0], losses[1]) @skip_if_lt_x_gpu(2) - @test_compiled_fsdp() def test_non_root_forward_backward(self): """ Tests running forward/backward through the root and then through a @@ -459,7 +452,6 @@ def test_non_root_forward_backward(self): self.assertEqual(ref_model(inp).sum(), model(inp).sum()) @skip_if_lt_x_gpu(2) - @test_compiled_fsdp() def test_multi_forward_module(self): """ Tests parity with DDP when running a module that participates multiple @@ -511,6 +503,7 @@ def world_size(self) -> int: return min(torch.cuda.device_count(), 2) @skip_if_lt_x_gpu(2) + @test_compiled_fsdp(compile_compute_on_module=Transformer) def test_train_parity_with_activation_checkpointing(self): """ Tests train parity against DDP when composing with activation @@ -528,6 +521,9 @@ def _test_train_parity_with_activation_checkpointing( self, reshard_after_forward: Union[bool, int], checkpoint_impl: str ): assert checkpoint_impl in ("composable", "utils", "wrapper") + testing_compile = fully_shard != torch.distributed._composable.fsdp.fully_shard + if testing_compile and checkpoint_impl == "composable": + return torch.manual_seed(42) vocab_size = 1024 with torch.device(torch.device("cuda")): @@ -536,7 +532,7 @@ def _test_train_parity_with_activation_checkpointing( n_heads=4, vocab_size=vocab_size, max_seq_len=64, - dropout_p=0.1, + dropout_p=0, checkpoint_activations=(checkpoint_impl == "utils"), ) model = Transformer(model_args) @@ -579,16 +575,18 @@ def _test_train_parity_with_activation_checkpointing( torch.manual_seed(iter_idx + 1) # for dropout determinism losses.append(_model(inp).sum()) losses[-1].backward() - check_sharded_parity( - self, ref_model, model, prefixes_to_ignore=prefixes_to_ignore - ) + if not testing_compile: + check_sharded_parity( + self, ref_model, model, prefixes_to_ignore=prefixes_to_ignore + ) self.assertEqual(losses[0], losses[1]) for _optim in (ref_optim, optim): _optim.step() _optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) - check_sharded_parity( - self, ref_model, model, prefixes_to_ignore=prefixes_to_ignore - ) + if not testing_compile: + check_sharded_parity( + self, ref_model, model, prefixes_to_ignore=prefixes_to_ignore + ) class TestFullyShardSharedParams(FSDPTest): @@ -597,22 +595,11 @@ def world_size(self) -> int: return min(4, torch.cuda.device_count()) @skip_if_lt_x_gpu(2) - @test_compiled_fsdp(compile_compute_on_module=TransformerBlock) - def test_train_parity_with_shared_params_no_ac(self): + def test_train_parity_with_shared_params(self): self.run_subtests( { "reshard_after_forward": [False, True], - "use_activation_checkpointing": [False], - }, - self._test_train_shared_params, - ) - - @skip_if_lt_x_gpu(2) - def test_train_parity_with_shared_params_ac(self): - self.run_subtests( - { - "reshard_after_forward": [False, True], - "use_activation_checkpointing": [True], + "use_activation_checkpointing": [False, True], }, self._test_train_shared_params, ) diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index b325a9601e25..60813409104c 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -867,7 +867,9 @@ def run_subtests( # Map keyword to chosen value subtest_kwargs = dict(zip(subtest_config_keys, values)) with cls_inst.subTest(**subtest_kwargs): + torch._dynamo.reset() test_fn(*test_args, **test_kwargs, **subtest_kwargs) + torch._dynamo.reset() c10d.barrier() diff --git a/torch/testing/_internal/common_fsdp.py b/torch/testing/_internal/common_fsdp.py index 4e266117c13b..da9dc2ef4e3c 100644 --- a/torch/testing/_internal/common_fsdp.py +++ b/torch/testing/_internal/common_fsdp.py @@ -9,7 +9,7 @@ from contextlib import nullcontext from copy import deepcopy from enum import auto, Enum -from functools import partial, wraps +from functools import wraps from typing import ( Any, Callable, @@ -1086,6 +1086,12 @@ def setUp(self): def run_subtests(self, *args, **kwargs): return run_subtests(self, *args, **kwargs) + def perThreadSetUp(self): + torch._dynamo.reset() + + def perThreadTearDown(self): + torch._dynamo.reset() + class FSDPTest(MultiProcessTestCase): def setUp(self): @@ -1156,7 +1162,9 @@ def _run(cls, rank, test_name, file_name, pipe): # immediately exiting due to a skip doesn't cause flakiness. dist.barrier(device_ids=device_ids) + torch._dynamo.reset() self.run_test(test_name, pipe) + torch._dynamo.reset() dist.barrier(device_ids=device_ids) @@ -1416,45 +1424,49 @@ def _test_fsdp_parity( def test_compiled_fsdp(compile_compute_on_module: Optional[type] = None): def fully_shard_with_compiled_compute(*args, **kwargs): - # compile ``module._call_impl`` - # to showcase how to include user-registered hooks + torch.distributed._composable.fsdp.fully_shard(*args, **kwargs) # type: ignore[operator] if compile_compute_on_module is None or isinstance( args[0], compile_compute_on_module ): args[0].compile() - return torch.distributed._composable.fsdp.fully_shard(*args, **kwargs) # type: ignore[operator] - class FullyShardPatch(Enum): - # apply ``partial`` in order to use ``Enum.value`` - EAGER = partial(torch.distributed._composable.fsdp.fully_shard) # type: ignore[var-annotated, arg-type] - COMPILED_COMPUTE = partial(fully_shard_with_compiled_compute) # type: ignore[arg-type] - # add FULL for tracing FSDP + class FullyShardMode(Enum): + EAGER = auto() + COMPILED_COMPUTE = auto() def decorator(func): @wraps(func) def wrapper(*args, **kwargs): original_fully_shard = torch.distributed._composable.fsdp.fully_shard - for fully_shard_patch in FullyShardPatch: - if fully_shard_patch != FullyShardPatch.EAGER and not has_triton(): + for mode in FullyShardMode: + if mode != FullyShardMode.EAGER and not has_triton(): warnings.warn("Inductor on GPU needs Triton and recent GPU arch") continue - imported_fully_shard = ( - f"{func.__module__}.{original_fully_shard.__name__}" - ) - with mock.patch( - imported_fully_shard, - fully_shard_patch.value, - ): - func(*args, **kwargs) - torch.distributed.barrier() - # mock.patch.__exit__ does not work with multi-thread - # thread 1 set {func.__module__}.fully_shard - # thread 2 read {func.__module__}.fully_shard and thought it is original - # hence we manually reset them after __exit__ - import_path, _ = mock._get_target(imported_fully_shard) # type: ignore[attr-defined] - setattr( - import_path(), original_fully_shard.__name__, original_fully_shard - ) + # barrier to ensure thread reading the same value + original_skip_fsdp_hooks = torch._dynamo.config.skip_fsdp_hooks + original_compile_threads = torch._inductor.config.compile_threads + torch.distributed.barrier() + + if mode == FullyShardMode.EAGER: + fully_shard_patch = original_fully_shard + elif mode == FullyShardMode.COMPILED_COMPUTE: + torch._dynamo.config.skip_fsdp_hooks = True + torch._inductor.config.compile_threads = 1 + fully_shard_patch = fully_shard_with_compiled_compute # type: ignore[assignment] + else: + raise NotImplementedError( + f"Need to implement FullyShardMode={mode}" + ) + + # fully_shard is imported as a global + # through `from ... import fully_shard` + func.__globals__[original_fully_shard.__name__] = fully_shard_patch + func(*args, **kwargs) + # other threads use patched func before this thread restores + torch.distributed.barrier() + func.__globals__[original_fully_shard.__name__] = original_fully_shard + torch._dynamo.config.skip_fsdp_hooks = original_skip_fsdp_hooks + torch._inductor.config.compile_threads = original_compile_threads return wrapper From 3acbfd602ee9d2e63c1fbcd40bc11982f8df5105 Mon Sep 17 00:00:00 2001 From: Arun Pa Date: Wed, 5 Jun 2024 17:44:47 +0000 Subject: [PATCH 376/706] Document torch.utils.collect_env.get_env_info function (#128021) Fixes #127911 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128021 Approved by: https://github.com/malfet --- torch/utils/collect_env.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/torch/utils/collect_env.py b/torch/utils/collect_env.py index 6cbf598156b0..b2254312109d 100644 --- a/torch/utils/collect_env.py +++ b/torch/utils/collect_env.py @@ -434,6 +434,21 @@ def is_xnnpack_available(): return "N/A" def get_env_info(): + """ + Collects environment information to aid in debugging. + + The returned environment information contains details on torch version, is debug build + or not, cuda compiled version, gcc version, clang version, cmake version, operating + system, libc version, python version, python platform, CUDA availability, CUDA + runtime version, CUDA module loading config, GPU model and configuration, Nvidia + driver version, cuDNN version, pip version and versions of relevant pip and + conda packages, HIP runtime version, MIOpen runtime version, + Caching allocator config, XNNPACK availability and CPU information. + + Returns: + SystemEnv (namedtuple): A tuple containining various environment details + and system information. + """ run_lambda = run pip_version, pip_list_output = get_pip_packages(run_lambda) From bb68b54be0072a8911653a84088044460bf8fdfd Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Tue, 4 Jun 2024 23:34:07 -0700 Subject: [PATCH 377/706] [BE][ptd_fb_test][1/N] Enable testslide (#127512) This change allows to enable Testslide, which gives us more readable output, import time, etc. The PR is previously stamped https://github.com/pytorch/pytorch/pull/126460 but the old PR has some ghexport issue. Differential Revision: [D57919583](https://our.internmc.facebook.com/intern/diff/D57919583/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127512 Approved by: https://github.com/wz337, https://github.com/Skylion007 --- torch/testing/_internal/common_distributed.py | 6 +++++- torch/testing/_internal/common_utils.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index 60813409104c..80dc47210471 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -544,7 +544,11 @@ def wrapper(self): # Constructor patches current instance test method to # assume the role of the main process and join its subprocesses, # or run the underlying test function. - def __init__(self, method_name: str = "runTest") -> None: + def __init__(self, method_name: str = "runTest", methodName: str = "runTest") -> None: + # methodName is the correct naming in unittest and testslide uses keyword arguments. + # So we need to use both to 1) not break BC and, 2) support testslide. + if methodName != "runTest": + method_name = methodName super().__init__(method_name) fn = getattr(self, method_name) setattr(self, method_name, self.join_or_run(fn)) diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 88cba64052c2..a85c44fe1e05 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -2596,7 +2596,11 @@ def rel_tol(self, prec: float) -> None: # the test, skip it instead. _ignore_not_implemented_error = False - def __init__(self, method_name='runTest'): + def __init__(self, method_name='runTest', methodName='runTest'): + # methodName is the correct naming in unittest and testslide uses keyword arguments. + # So we need to use both to 1) not break BC and, 2) support testslide. + if methodName != "runTest": + method_name = methodName super().__init__(method_name) test_method = getattr(self, method_name, None) From 6412c6060cf86839f2a8478b524ad4ad5be36623 Mon Sep 17 00:00:00 2001 From: rzou Date: Tue, 4 Jun 2024 22:50:44 -0700 Subject: [PATCH 378/706] [reland] Refresh OpOverloadPacket if a new OpOverload gets added (#128000) If a user accesses an OpOverloadPacket, then creates a new OpOverload, then uses the OpOverloadPacket, the new OpOverload never gets hit. This is because OpOverloadPacket caches OpOverloads when it is constructed. This PR fixes the problem by "refreshing" the OpOverloadPacket if a new OpOverload gets constructed and the OpOverloadPacket exists. Test Plan: - new tests This is the third land attempt. The first one was reverted for breaking internal tests, the second was reverted for being erroneously suspected of causing a perf regression. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128000 Approved by: https://github.com/albanD --- test/jit/test_list_dict.py | 31 ++++++++++++++++++++++++++++++- test/test_custom_ops.py | 24 ++++++++++++++++++++++++ torch/_ops.py | 26 +++++++++++++++++++++----- torch/library.py | 16 +++++++++++++++- 4 files changed, 90 insertions(+), 7 deletions(-) diff --git a/test/jit/test_list_dict.py b/test/jit/test_list_dict.py index f3d314dbac77..90fa24e43506 100644 --- a/test/jit/test_list_dict.py +++ b/test/jit/test_list_dict.py @@ -5,7 +5,7 @@ import sys import types import unittest -from collections import OrderedDict +from collections import defaultdict, OrderedDict from textwrap import dedent from typing import Any, Dict, List, NamedTuple, Optional, Tuple @@ -2966,3 +2966,32 @@ def test_reference_semantics(self): self.assertEqual(len(l), 3) self.assertTrue(3 in l) self.assertEqual(l[2], 3) + + def test_defaultdict(self): + def get_dict(): + test_dict = defaultdict(list) + return test_dict + + class Test(torch.nn.Module): + segments_groupby_col: Dict[str, List[str]] + + def __init__(self): + super().__init__() + self.segments_groupby_col = get_dict() + self.col1 = "a" + self.col2 = "b" + + def forward(self): + if self.col1 in self.segments_groupby_col.keys(): + return 1 + else: + return 2 + + test = Test() + test_script = torch.jit.script(test) + test_script.segments_groupby_col + + # Smoketest for flakiness. Takes around 2s. + for i in range(300): + test = Test() + test_script = torch.jit.script(test) diff --git a/test/test_custom_ops.py b/test/test_custom_ops.py index 1239ff8e0ebd..e2af3efaa98a 100644 --- a/test/test_custom_ops.py +++ b/test/test_custom_ops.py @@ -2850,6 +2850,30 @@ def f(x: Tensor) -> Tensor: y = f(x) self.assertEqual(y, x.sin()) + @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") + def test_overloading(self): + called_f = 0 + called_f1 = 0 + + @torch.library.custom_op("_torch_testing::f", mutates_args=()) + def f(x: Tensor) -> Tensor: + nonlocal called_f + called_f += 1 + return x.clone() + + x = torch.randn(2, 3) + torch.ops._torch_testing.f(x) + self.assertEqual(called_f, 1) + + @torch.library.custom_op("_torch_testing::f.overload", mutates_args=()) + def f1(x: Tensor, y: Tensor) -> Tensor: + nonlocal called_f1 + called_f1 += 1 + return x.clone() + + torch.ops._torch_testing.f(x, x) + self.assertEqual(called_f1, 1) + def test_disallows_output_aliasing(self): @torch.library.custom_op("_torch_testing::f", mutates_args=()) def f(x: Tensor) -> Tensor: diff --git a/torch/_ops.py b/torch/_ops.py index 0b19c75a51aa..83a7b6b849df 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -1161,8 +1161,10 @@ def __getattr__(self, op_name): # for overloads and raise an exception if there are more than one. namespace_name = self.name qualified_op_name = f"{namespace_name}::{op_name}" + module_name = self.__module__ + "." + namespace_name + try: - op, overload_names = torch._C._jit_get_operation(qualified_op_name) + op, overload_names = _get_packet(qualified_op_name, module_name) if op is None: raise AttributeError( f"'_OpNamespace' '{self.name}' object has no attribute '{op_name}'" @@ -1174,10 +1176,7 @@ def __getattr__(self, op_name): f"'_OpNamespace' '{self.name}' object has no attribute '{op_name}'" ) from e - # let the script frontend know that op is identical to the builtin op - # with qualified_op_name - torch.jit._builtins._register_builtin(op, qualified_op_name) - op.__module__ = self.__module__ + "." + namespace_name + op.__module__ = module_name opoverloadpacket = OpOverloadPacket( qualified_op_name, op_name, op, overload_names ) @@ -1189,6 +1188,23 @@ def __getattr__(self, op_name): return opoverloadpacket +def _get_packet(qualname, op_module): + op, overload_names = torch._C._jit_get_operation(qualname) + if op is not None: + # let the script frontend know that op is identical to the builtin op + # with qualified_op_name + torch.jit._builtins._register_builtin(op, qualname) + op.__module__ = op_module + return op, overload_names + + +def _refresh_packet(packet): + op, overload_names = _get_packet(packet._qualified_op_name, packet._op.__module__) + assert op is not None + packet._op = op + packet._overload_names = overload_names + + class _PyOpNamespace(_OpNamespace): def __init__(self, name, ops): super().__init__(name) diff --git a/torch/library.py b/torch/library.py index 3bd0a1b6bc8a..da8c5a1264a2 100644 --- a/torch/library.py +++ b/torch/library.py @@ -109,8 +109,22 @@ def define(self, schema, alias_analysis="", *, tags=()): assert self.m is not None if isinstance(tags, torch.Tag): tags = (tags,) + + name = schema.split("(")[0] + packet_name = name.split(".")[0] if "." in name else name + has_preexisting_packet = hasattr(torch.ops, self.ns) and hasattr(getattr(torch.ops, self.ns), packet_name) + result = self.m.define(schema, alias_analysis, tuple(tags)) - qualname = self.ns + "::" + schema.split("(")[0] + name = schema.split("(")[0] + qualname = self.ns + "::" + name + + # If the OpOverloadPacket exists already, then this means we're adding a + # new OpOverload for it. Refresh the packet to include the new OpOverload. + if has_preexisting_packet: + ns = getattr(torch.ops, self.ns) + packet = getattr(ns, packet_name) + torch._ops._refresh_packet(packet) + self._op_defs.add(qualname) _defs.add(qualname) return result From 6e545392cd27b9ee14a86cf6189c71211e79205d Mon Sep 17 00:00:00 2001 From: atalman Date: Wed, 5 Jun 2024 18:31:26 +0000 Subject: [PATCH 379/706] Move nongpu workflows from trunk to periodic (#128049) We don't need to run them on every PR. These are used to test for graceful degradation of GPU. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128049 Approved by: https://github.com/clee2000 --- .github/workflows/periodic.yml | 48 ++++++++++++++++++++++++++++++++++ .github/workflows/trunk.yml | 48 ---------------------------------- 2 files changed, 48 insertions(+), 48 deletions(-) diff --git a/.github/workflows/periodic.yml b/.github/workflows/periodic.yml index 925bca54c074..042a1442928f 100644 --- a/.github/workflows/periodic.yml +++ b/.github/workflows/periodic.yml @@ -38,6 +38,54 @@ jobs: id-token: write contents: read + linux-focal-cuda12_1-py3_10-gcc9-build: + name: linux-focal-cuda12.1-py3.10-gcc9 + uses: ./.github/workflows/_linux-build.yml + with: + build-environment: linux-focal-cuda12.1-py3.10-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9 + test-matrix: | + { include: [ + { config: "nogpu_AVX512", shard: 1, num_shards: 1, runner: "linux.2xlarge" }, + { config: "nogpu_NO_AVX2", shard: 1, num_shards: 1, runner: "linux.2xlarge" }, + { config: "jit_legacy", shard: 1, num_shards: 1, runner: "linux.4xlarge.nvidia.gpu" }, + ]} + + linux-focal-cuda12_1-py3_10-gcc9-test: + name: linux-focal-cuda12.1-py3.10-gcc9 + uses: ./.github/workflows/_linux-test.yml + needs: + - linux-focal-cuda12_1-py3_10-gcc9-build + - target-determination + with: + build-environment: linux-focal-cuda12.1-py3.10-gcc9 + docker-image: ${{ needs.linux-focal-cuda12_1-py3_10-gcc9-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-cuda12_1-py3_10-gcc9-build.outputs.test-matrix }} + + linux-focal-cuda12_4-py3_10-gcc9-build: + name: linux-focal-cuda12.4-py3.10-gcc9 + uses: ./.github/workflows/_linux-build.yml + with: + build-environment: linux-focal-cuda12.4-py3.10-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9 + test-matrix: | + { include: [ + { config: "nogpu_AVX512", shard: 1, num_shards: 1, runner: "linux.2xlarge" }, + { config: "nogpu_NO_AVX2", shard: 1, num_shards: 1, runner: "linux.2xlarge" }, + { config: "jit_legacy", shard: 1, num_shards: 1, runner: "linux.4xlarge.nvidia.gpu" }, + ]} + + linux-focal-cuda12_4-py3_10-gcc9-test: + name: linux-focal-cuda12.4-py3.10-gcc9 + uses: ./.github/workflows/_linux-test.yml + needs: + - linux-focal-cuda12_4-py3_10-gcc9-build + - target-determination + with: + build-environment: linux-focal-cuda12.4-py3.10-gcc9 + docker-image: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-build.outputs.test-matrix }} + parallelnative-linux-jammy-py3_8-gcc11-build: name: parallelnative-linux-jammy-py3.8-gcc11 uses: ./.github/workflows/_linux-build.yml diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index 77f54f937ad0..5efedab0cfb3 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -34,30 +34,6 @@ jobs: id-token: write contents: read - linux-focal-cuda12_1-py3_10-gcc9-build: - name: linux-focal-cuda12.1-py3.10-gcc9 - uses: ./.github/workflows/_linux-build.yml - with: - build-environment: linux-focal-cuda12.1-py3.10-gcc9 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9 - test-matrix: | - { include: [ - { config: "nogpu_AVX512", shard: 1, num_shards: 1, runner: "linux.2xlarge" }, - { config: "nogpu_NO_AVX2", shard: 1, num_shards: 1, runner: "linux.2xlarge" }, - { config: "jit_legacy", shard: 1, num_shards: 1, runner: "linux.4xlarge.nvidia.gpu" }, - ]} - - linux-focal-cuda12_1-py3_10-gcc9-test: - name: linux-focal-cuda12.1-py3.10-gcc9 - uses: ./.github/workflows/_linux-test.yml - needs: - - linux-focal-cuda12_1-py3_10-gcc9-build - - target-determination - with: - build-environment: linux-focal-cuda12.1-py3.10-gcc9 - docker-image: ${{ needs.linux-focal-cuda12_1-py3_10-gcc9-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-cuda12_1-py3_10-gcc9-build.outputs.test-matrix }} - libtorch-linux-focal-cuda12_1-py3_7-gcc9-debug-build: name: libtorch-linux-focal-cuda12.1-py3.7-gcc9-debug uses: ./.github/workflows/_linux-build.yml @@ -83,30 +59,6 @@ jobs: { config: "default", shard: 1, num_shards: 1 }, ]} - linux-focal-cuda12_4-py3_10-gcc9-build: - name: linux-focal-cuda12.4-py3.10-gcc9 - uses: ./.github/workflows/_linux-build.yml - with: - build-environment: linux-focal-cuda12.4-py3.10-gcc9 - docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9 - test-matrix: | - { include: [ - { config: "nogpu_AVX512", shard: 1, num_shards: 1, runner: "linux.2xlarge" }, - { config: "nogpu_NO_AVX2", shard: 1, num_shards: 1, runner: "linux.2xlarge" }, - { config: "jit_legacy", shard: 1, num_shards: 1, runner: "linux.4xlarge.nvidia.gpu" }, - ]} - - linux-focal-cuda12_4-py3_10-gcc9-test: - name: linux-focal-cuda12.4-py3.10-gcc9 - uses: ./.github/workflows/_linux-test.yml - needs: - - linux-focal-cuda12_4-py3_10-gcc9-build - - target-determination - with: - build-environment: linux-focal-cuda12.4-py3.10-gcc9 - docker-image: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-build.outputs.test-matrix }} - libtorch-linux-focal-cuda12_4-py3_7-gcc9-debug-build: name: libtorch-linux-focal-cuda12.4-py3.7-gcc9-debug uses: ./.github/workflows/_linux-build.yml From 72e863df27a7534021efa8b2375f45293b58c04a Mon Sep 17 00:00:00 2001 From: rk7697 <91646263+rk7697@users.noreply.github.com> Date: Wed, 5 Jun 2024 20:02:33 +0000 Subject: [PATCH 380/706] Update _learnable_fake_quantize.py (#127993) Remove sentence "For literature references, please see the class _LearnableFakeQuantizePerTensorOp." and add "s" to "support" (Possibly) Fixes #99107 (But not sure, sorry) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127993 Approved by: https://github.com/jerryzh168 --- torch/ao/quantization/_learnable_fake_quantize.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torch/ao/quantization/_learnable_fake_quantize.py b/torch/ao/quantization/_learnable_fake_quantize.py index 6827ae35533c..cdf44c5ea7b2 100644 --- a/torch/ao/quantization/_learnable_fake_quantize.py +++ b/torch/ao/quantization/_learnable_fake_quantize.py @@ -8,9 +8,8 @@ class _LearnableFakeQuantize(torch.ao.quantization.FakeQuantizeBase): r"""Generalized extension of the FakeQuantize module in fake_quantize.py. This is an extension of the FakeQuantize module in fake_quantize.py, which - supports more generalized lower-bit quantization and support learning of the scale - and zero point parameters through backpropagation. For literature references, - please see the class _LearnableFakeQuantizePerTensorOp. + supports more generalized lower-bit quantization and supports learning of the scale + and zero point parameters through backpropagation. In addition to the attributes in the original FakeQuantize module, the _LearnableFakeQuantize module also includes the following attributes to support quantization parameter learning. From 626dc934d1356180153983172ab13e827bceb6e3 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Tue, 4 Jun 2024 22:24:31 -0700 Subject: [PATCH 381/706] [dynamo][pippy] Hotfix for nn_module_stack for pippy usecase (#127972) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127972 Approved by: https://github.com/ydwu4 --- torch/_dynamo/variables/nn_module.py | 31 +++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py index e1848de97935..0a6bad4730dd 100644 --- a/torch/_dynamo/variables/nn_module.py +++ b/torch/_dynamo/variables/nn_module.py @@ -67,9 +67,38 @@ def convert_to_fake(x): mod._infer_parameters(mod, fake_args, fake_kwargs) +def cleanup_source_for_nn_module_stack(source): + # TODO(anijain2305, export-team) This is a bad hack to fix the nn module + # fully_qualified_name to work with export/unflatten. It converts + # mod._modules['net1'] to mod.net1. + + # This type of source occurs when we use UnspecializedNNModule variable + # because unspecialized nn module variable inlines module __getattr__ calls. + # For export, we rely heavily on NNModuleVariable and do not support + # UnspecializedNNModule. But there is one case where this gets exposed - + # Pippy. Pippy uses export/unflatten (an export feature) and also + # monkepatches the `forward` method of a mod that forces Dynamo to use + # UnspecializedNNModule. Therefore, we will need proper work to retain the + # nn module stack when we let export rely on UnspecializedNNModule variable. + + # This does not work if we have recursively UnspecializedNNModule variables + # e.g. mod._modules['net1']._modules['net2']. This is unlikely to happen in + # Pippy so the hotfix is enough for Pippy. + + if ( + isinstance(source, GetItemSource) + and isinstance(source.base, AttrSource) + and isinstance(source.base.base, NNModuleSource) + and source.base.member == "_modules" + ): + return AttrSource(source.base.base, source.index) + return source + + @contextmanager def record_nn_module_stack(module_key: str, source, tx, mod: torch.nn.Module): - fully_qualified_name = source.name() + source_for_nn_module_stack = cleanup_source_for_nn_module_stack(source) + fully_qualified_name = source_for_nn_module_stack.name() try: tx.nn_module_stack[module_key] = (fully_qualified_name, mod.__class__) yield From 8184cd85fcfe663019edb3c1e502e03dcbaba4f0 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Wed, 5 Jun 2024 07:53:38 -0700 Subject: [PATCH 382/706] [fake tensor] Set _is_param for base fake tensors for views (#127823) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127823 Approved by: https://github.com/eellison, https://github.com/ezyang ghstack dependencies: #127972 --- test/test_fake_tensor.py | 9 +++++++++ torch/_subclasses/meta_utils.py | 6 ++++++ 2 files changed, 15 insertions(+) diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py index 7456feb45d82..e5b36c47048b 100644 --- a/test/test_fake_tensor.py +++ b/test/test_fake_tensor.py @@ -619,6 +619,15 @@ def test_data_dependent_operator(self): self.assertRaises(DynamicOutputShapeException, lambda: torch.nonzero(x)) + def test_parameter_view(self): + x = torch.nn.Parameter(torch.randn(4)) + x_view = x.view(4) + mode = FakeTensorMode() + fake_x_view = mode.from_tensor(x_view) + fake_x = mode.from_tensor(x) + self.assertFalse(isinstance(fake_x_view, torch.nn.Parameter)) + self.assertTrue(isinstance(fake_x, torch.nn.Parameter)) + def test_tolist(self): shape_env = ShapeEnv() with FakeTensorMode(allow_fallback_kernels=False, shape_env=shape_env): diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py index 3a3a6eda012f..5aeccce2e1ee 100644 --- a/torch/_subclasses/meta_utils.py +++ b/torch/_subclasses/meta_utils.py @@ -323,6 +323,7 @@ def describe_tensor( is_view=is_view, is_conj=t.is_conj(), is_neg=t.is_neg(), + is_parameter=isinstance(t, torch.nn.Parameter), is_traceable_wrapper_subclass=is_traceable_wrapper_subclass_v, is_nested=is_nested, is_functional=is_functional, @@ -453,6 +454,7 @@ class MetaTensorDesc: is_functional: bool = False is_conj: bool = False is_neg: bool = False + is_parameter: bool = False stride: Optional[Tuple[int, ...]] = None storage_offset: int = 0 # NB: We have a choice whether or not to store the id or a direct pointer @@ -1535,6 +1537,10 @@ def is_c_of_r(complex_dtype, real_dtype): # Need to reflect this in the generated FakeTensor. if t.storage is not None and t.storage.size == 0: r.untyped_storage().resize_(0) + + if t.is_parameter: + r._is_param = True + self.set_tensor_memo(t, r) return self.get_tensor_memo(t) From 01694eaa56adb343f5d3d15b53d2962615dafe17 Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Wed, 5 Jun 2024 21:01:36 +0000 Subject: [PATCH 383/706] Move cuda 12.4 jobs to periodic for both pull and inductor (#127825) Moves 12.4 sm86/a10g jobs in pull to trunk Moves 12.4 cuda non sm86 jobs to periodic Moves 12.4 jobs in inductor to inductor-periodic, except inductor_timm which seems to give important signal There has been a lot of queueing for cuda runners due to the addition of jobs for cuda 12.4, so move those jobs to other workflows that are run less often Co-authored-by: Andrey Talman Pull Request resolved: https://github.com/pytorch/pytorch/pull/127825 Approved by: https://github.com/ZainRizvi, https://github.com/nWEIdia, https://github.com/atalman, https://github.com/malfet --- .github/workflows/inductor-periodic.yml | 90 +++++++++++++++++++++++++ .github/workflows/inductor.yml | 66 +----------------- .github/workflows/periodic.yml | 11 ++- .github/workflows/pull.yml | 55 --------------- .github/workflows/trunk.yml | 27 ++++++++ 5 files changed, 128 insertions(+), 121 deletions(-) diff --git a/.github/workflows/inductor-periodic.yml b/.github/workflows/inductor-periodic.yml index 6f8c06ed030b..34b3dc8101f2 100644 --- a/.github/workflows/inductor-periodic.yml +++ b/.github/workflows/inductor-periodic.yml @@ -56,3 +56,93 @@ jobs: test-matrix: ${{ needs.linux-focal-cuda12_1-py3_10-gcc9-periodic-dynamo-benchmarks-build.outputs.test-matrix }} secrets: HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} + + linux-focal-cuda12_4-py3_10-gcc9-inductor-build: + # Should be synced with the one in inductor.yml, but this doesn't run inductor_timm + name: cuda12.4-py3.10-gcc9-sm86 + uses: ./.github/workflows/_linux-build.yml + with: + sync-tag: linux-focal-cuda12_4-py3_10-gcc9-inductor-build + build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm86 + docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9-inductor-benchmarks + cuda-arch-list: '8.6' + test-matrix: | + { include: [ + { config: "inductor", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_distributed", shard: 1, num_shards: 1, runner: "linux.g5.12xlarge.nvidia.gpu" }, + { config: "inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_torchbench", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_torchbench", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamic_inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamic_inductor_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamic_inductor_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamic_inductor_torchbench", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamic_inductor_torchbench", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_inductor_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_inductor_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_inductor_torchbench", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_inductor_torchbench", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_cpp_wrapper_abi_compatible", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, + ]} + secrets: + HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} + + linux-focal-cuda12_4-py3_10-gcc9-inductor-test: + name: cuda12.4-py3.10-gcc9-sm86 + uses: ./.github/workflows/_linux-test.yml + needs: linux-focal-cuda12_4-py3_10-gcc9-inductor-build + with: + sync-tag: linux-focal-cuda12_4-py3_10-gcc9-inductor-test + build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm86 + docker-image: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-inductor-build.outputs.test-matrix }} + secrets: + HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} + + linux-focal-cuda12_4-py3_10-gcc9-inductor-build-gcp: + name: cuda12.4-py3.10-gcc9-sm80 + uses: ./.github/workflows/_linux-build.yml + with: + build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm80 + docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9-inductor-benchmarks + cuda-arch-list: '8.0' + test-matrix: | + { include: [ + { config: "inductor_torchbench_smoketest_perf", shard: 1, num_shards: 1, runner: "linux.gcp.a100" }, + ]} + secrets: + HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} + + linux-focal-cuda12_4-py3_10-gcc9-inductor-test-gcp: + name: cuda12.4-py3.10-gcc9-sm80 + uses: ./.github/workflows/_linux-test.yml + needs: linux-focal-cuda12_4-py3_10-gcc9-inductor-build-gcp + with: + build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm80 + docker-image: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-inductor-build-gcp.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-inductor-build-gcp.outputs.test-matrix }} + use-gha: anything-non-empty-to-use-gha + secrets: + HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} + + linux-focal-cuda12_4-py3_12-gcc9-inductor-build: + name: cuda12.4-py3.12-gcc9-sm86 + uses: ./.github/workflows/_linux-build.yml + with: + build-environment: linux-focal-cuda12.4-py3.12-gcc9-sm86 + docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3.12-gcc9-inductor-benchmarks + cuda-arch-list: '8.6' + test-matrix: | + { include: [ + { config: "inductor", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, + ]} + + linux-focal-cuda12_4-py3_12-gcc9-inductor-test: + name: cuda12.4-py3.12-gcc9-sm86 + uses: ./.github/workflows/_linux-test.yml + needs: linux-focal-cuda12_4-py3_12-gcc9-inductor-build + with: + build-environment: linux-focal-cuda12.4-py3.12-gcc9-sm86 + docker-image: ${{ needs.linux-focal-cuda12_4-py3_12-gcc9-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-cuda12_4-py3_12-gcc9-inductor-build.outputs.test-matrix }} diff --git a/.github/workflows/inductor.yml b/.github/workflows/inductor.yml index 0f9c81104f9f..08d3b9fcfb24 100644 --- a/.github/workflows/inductor.yml +++ b/.github/workflows/inductor.yml @@ -129,32 +129,18 @@ jobs: test-matrix: ${{ needs.linux-focal-cuda12_1-py3_12-gcc9-inductor-build.outputs.test-matrix }} linux-focal-cuda12_4-py3_10-gcc9-inductor-build: + # Should be synced with the one in inductor-periodic.yml but this only runs inductor_timm name: cuda12.4-py3.10-gcc9-sm86 uses: ./.github/workflows/_linux-build.yml with: + sync-tag: linux-focal-cuda12_4-py3_10-gcc9-inductor-build build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm86 docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.6' test-matrix: | { include: [ - { config: "inductor", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "inductor_distributed", shard: 1, num_shards: 1, runner: "linux.g5.12xlarge.nvidia.gpu" }, - { config: "inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, { config: "inductor_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, { config: "inductor_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "inductor_torchbench", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "inductor_torchbench", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "dynamic_inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "dynamic_inductor_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "dynamic_inductor_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "dynamic_inductor_torchbench", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "dynamic_inductor_torchbench", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "aot_inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "aot_inductor_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "aot_inductor_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "aot_inductor_torchbench", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "aot_inductor_torchbench", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "inductor_cpp_wrapper_abi_compatible", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, ]} secrets: HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} @@ -164,59 +150,13 @@ jobs: uses: ./.github/workflows/_linux-test.yml needs: linux-focal-cuda12_4-py3_10-gcc9-inductor-build with: + sync-tag: linux-focal-cuda12_4-py3_10-gcc9-inductor-test build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm86 docker-image: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-inductor-build.outputs.docker-image }} test-matrix: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-inductor-build.outputs.test-matrix }} secrets: HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} - linux-focal-cuda12_4-py3_10-gcc9-inductor-build-gcp: - name: cuda12.4-py3.10-gcc9-sm80 - uses: ./.github/workflows/_linux-build.yml - with: - build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm80 - docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9-inductor-benchmarks - cuda-arch-list: '8.0' - test-matrix: | - { include: [ - { config: "inductor_torchbench_smoketest_perf", shard: 1, num_shards: 1, runner: "linux.gcp.a100" }, - ]} - secrets: - HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} - - linux-focal-cuda12_4-py3_12-gcc9-inductor-build: - name: cuda12.4-py3.12-gcc9-sm86 - uses: ./.github/workflows/_linux-build.yml - with: - build-environment: linux-focal-cuda12.4-py3.12-gcc9-sm86 - docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3.12-gcc9-inductor-benchmarks - cuda-arch-list: '8.6' - test-matrix: | - { include: [ - { config: "inductor", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, - ]} - - linux-focal-cuda12_4-py3_10-gcc9-inductor-test-gcp: - name: cuda12.4-py3.10-gcc9-sm80 - uses: ./.github/workflows/_linux-test.yml - needs: linux-focal-cuda12_4-py3_10-gcc9-inductor-build-gcp - with: - build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm80 - docker-image: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-inductor-build-gcp.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-inductor-build-gcp.outputs.test-matrix }} - use-gha: anything-non-empty-to-use-gha - secrets: - HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} - - linux-focal-cuda12_4-py3_12-gcc9-inductor-test: - name: cuda12.4-py3.12-gcc9-sm86 - uses: ./.github/workflows/_linux-test.yml - needs: linux-focal-cuda12_4-py3_12-gcc9-inductor-build - with: - build-environment: linux-focal-cuda12.4-py3.12-gcc9-sm86 - docker-image: ${{ needs.linux-focal-cuda12_4-py3_12-gcc9-inductor-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-cuda12_4-py3_12-gcc9-inductor-build.outputs.test-matrix }} - linux-jammy-cpu-py3_8-gcc11-inductor-build: name: linux-jammy-cpu-py3.8-gcc11-inductor uses: ./.github/workflows/_linux-build.yml diff --git a/.github/workflows/periodic.yml b/.github/workflows/periodic.yml index 042a1442928f..bcbd7ad6a5b5 100644 --- a/.github/workflows/periodic.yml +++ b/.github/workflows/periodic.yml @@ -37,7 +37,6 @@ jobs: permissions: id-token: write contents: read - linux-focal-cuda12_1-py3_10-gcc9-build: name: linux-focal-cuda12.1-py3.10-gcc9 uses: ./.github/workflows/_linux-build.yml @@ -50,7 +49,6 @@ jobs: { config: "nogpu_NO_AVX2", shard: 1, num_shards: 1, runner: "linux.2xlarge" }, { config: "jit_legacy", shard: 1, num_shards: 1, runner: "linux.4xlarge.nvidia.gpu" }, ]} - linux-focal-cuda12_1-py3_10-gcc9-test: name: linux-focal-cuda12.1-py3.10-gcc9 uses: ./.github/workflows/_linux-test.yml @@ -64,12 +62,18 @@ jobs: linux-focal-cuda12_4-py3_10-gcc9-build: name: linux-focal-cuda12.4-py3.10-gcc9 - uses: ./.github/workflows/_linux-build.yml + uses: ./.github/workflows/_linux-build-label.yml with: build-environment: linux-focal-cuda12.4-py3.10-gcc9 docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9 test-matrix: | { include: [ + { config: "default", shard: 1, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 2, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 3, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 4, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 5, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" }, + { config: "deploy", shard: 1, num_shards: 1, runner: "linux.4xlarge.nvidia.gpu" }, { config: "nogpu_AVX512", shard: 1, num_shards: 1, runner: "linux.2xlarge" }, { config: "nogpu_NO_AVX2", shard: 1, num_shards: 1, runner: "linux.2xlarge" }, { config: "jit_legacy", shard: 1, num_shards: 1, runner: "linux.4xlarge.nvidia.gpu" }, @@ -82,6 +86,7 @@ jobs: - linux-focal-cuda12_4-py3_10-gcc9-build - target-determination with: + timeout-minutes: 360 build-environment: linux-focal-cuda12.4-py3.10-gcc9 docker-image: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-build.outputs.docker-image }} test-matrix: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-build.outputs.test-matrix }} diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 2b81e998bde5..808e8a3795e3 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -285,34 +285,6 @@ jobs: docker-image: ${{ needs.linux-focal-cuda12_1-py3_10-gcc9-build.outputs.docker-image }} test-matrix: ${{ needs.linux-focal-cuda12_1-py3_10-gcc9-build.outputs.test-matrix }} - linux-focal-cuda12_4-py3_10-gcc9-build: - name: linux-focal-cuda12.4-py3.10-gcc9 - uses: ./.github/workflows/_linux-build-label.yml - with: - build-environment: linux-focal-cuda12.4-py3.10-gcc9 - docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9 - test-matrix: | - { include: [ - { config: "default", shard: 1, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" }, - { config: "default", shard: 2, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" }, - { config: "default", shard: 3, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" }, - { config: "default", shard: 4, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" }, - { config: "default", shard: 5, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" }, - { config: "deploy", shard: 1, num_shards: 1, runner: "linux.4xlarge.nvidia.gpu" }, - ]} - - linux-focal-cuda12_4-py3_10-gcc9-test: - name: linux-focal-cuda12.4-py3.10-gcc9 - uses: ./.github/workflows/_linux-test.yml - needs: - - linux-focal-cuda12_4-py3_10-gcc9-build - - target-determination - with: - timeout-minutes: 360 - build-environment: linux-focal-cuda12.4-py3.10-gcc9 - docker-image: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-build.outputs.test-matrix }} - linux-jammy-py3-clang12-mobile-build: name: linux-jammy-py3-clang12-mobile-build uses: ./.github/workflows/_linux-build-label.yml @@ -497,33 +469,6 @@ jobs: docker-image: ${{ needs.linux-focal-cuda12_1-py3_10-gcc9-sm86-build.outputs.docker-image }} test-matrix: ${{ needs.linux-focal-cuda12_1-py3_10-gcc9-sm86-build.outputs.test-matrix }} - linux-focal-cuda12_4-py3_10-gcc9-sm86-build: - name: linux-focal-cuda12.4-py3.10-gcc9-sm86 - uses: ./.github/workflows/_linux-build-label.yml - with: - build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm86 - docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9 - cuda-arch-list: 8.6 - test-matrix: | - { include: [ - { config: "default", shard: 1, num_shards: 5, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "default", shard: 2, num_shards: 5, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "default", shard: 3, num_shards: 5, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "default", shard: 4, num_shards: 5, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "default", shard: 5, num_shards: 5, runner: "linux.g5.4xlarge.nvidia.gpu" }, - ]} - - linux-focal-cuda12_4-py3_10-gcc9-sm86-test: - name: linux-focal-cuda12.4-py3.10-gcc9-sm86 - uses: ./.github/workflows/_linux-test.yml - needs: - - linux-focal-cuda12_4-py3_10-gcc9-sm86-build - - target-determination - with: - build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm86 - docker-image: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-sm86-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-sm86-build.outputs.test-matrix }} - linux-jammy-py3-clang12-executorch-build: name: linux-jammy-py3-clang12-executorch uses: ./.github/workflows/_linux-build-label.yml diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index 5efedab0cfb3..4d4cb7672653 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -34,6 +34,33 @@ jobs: id-token: write contents: read + linux-focal-cuda12_4-py3_10-gcc9-sm86-build: + name: linux-focal-cuda12.4-py3.10-gcc9-sm86 + uses: ./.github/workflows/_linux-build-label.yml + with: + build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm86 + docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9 + cuda-arch-list: 8.6 + test-matrix: | + { include: [ + { config: "default", shard: 1, num_shards: 5, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 2, num_shards: 5, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 3, num_shards: 5, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 4, num_shards: 5, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 5, num_shards: 5, runner: "linux.g5.4xlarge.nvidia.gpu" }, + ]} + + linux-focal-cuda12_4-py3_10-gcc9-sm86-test: + name: linux-focal-cuda12.4-py3.10-gcc9-sm86 + uses: ./.github/workflows/_linux-test.yml + needs: + - linux-focal-cuda12_4-py3_10-gcc9-sm86-build + - target-determination + with: + build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm86 + docker-image: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-sm86-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-sm86-build.outputs.test-matrix }} + libtorch-linux-focal-cuda12_1-py3_7-gcc9-debug-build: name: libtorch-linux-focal-cuda12.1-py3.7-gcc9-debug uses: ./.github/workflows/_linux-build.yml From 4123323effaf2cdf2a9aad2cc034d1d4307494f7 Mon Sep 17 00:00:00 2001 From: titaiwangms Date: Wed, 5 Jun 2024 21:27:43 +0000 Subject: [PATCH 384/706] [ONNX] Single function for torch.onnx.export and torch.onnx.dynamo_export (#127974) Add `dynamo: bool = True` as a switch in `torch.onnx.export` to provide users an option to try `torch.onnx.dynamo_export`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127974 Approved by: https://github.com/justinchuby --- test/onnx/dynamo/test_exporter_api.py | 103 ++++++++++++++++++++++++++ torch/onnx/utils.py | 48 +++++++++++- 2 files changed, 147 insertions(+), 4 deletions(-) diff --git a/test/onnx/dynamo/test_exporter_api.py b/test/onnx/dynamo/test_exporter_api.py index cc60b975a5eb..30bfd27483b9 100644 --- a/test/onnx/dynamo/test_exporter_api.py +++ b/test/onnx/dynamo/test_exporter_api.py @@ -26,6 +26,13 @@ def forward(self, x): return (y, z) +class SampleModelTwoInputs(torch.nn.Module): + def forward(self, x, b): + y = x + b + z = y.relu() + return (y, z) + + class _LargeModel(torch.nn.Module): def __init__(self): super().__init__() @@ -221,5 +228,101 @@ def test_serialize_succeeds_when_model_greater_than_2gb(self): serializer.serialize(onnx_program, io.BytesIO()) +class TestONNXExportWithDynamo(common_utils.TestCase): + def test_args_normalization_with_no_kwargs(self): + onnx_program_from_new_exporter = torch.onnx.dynamo_export( + SampleModelTwoInputs(), torch.randn(1, 1, 2), torch.randn(1, 1, 2) + ) + onnx_program_from_old_exporter = torch.onnx.export( + SampleModelTwoInputs(), + (torch.randn(1, 1, 2), torch.randn(1, 1, 2)), + dynamo=True, + ) + self.assertEqual( + onnx_program_from_new_exporter.model_proto, + onnx_program_from_old_exporter.model_proto, + ) + + def test_args_normalization_with_kwargs(self): + onnx_program_from_new_exporter = torch.onnx.dynamo_export( + SampleModelTwoInputs(), torch.randn(1, 1, 2), b=torch.randn(1, 1, 2) + ) + onnx_program_from_old_exporter = torch.onnx.export( + SampleModelTwoInputs(), + (torch.randn(1, 1, 2), {"b": torch.randn(1, 1, 2)}), + dynamo=True, + ) + self.assertEqual( + onnx_program_from_new_exporter.model_proto, + onnx_program_from_old_exporter.model_proto, + ) + + def test_args_normalization_with_empty_dict_at_the_tail(self): + onnx_program_from_new_exporter = torch.onnx.dynamo_export( + SampleModelTwoInputs(), torch.randn(1, 1, 2), b=torch.randn(1, 1, 2) + ) + onnx_program_from_old_exporter = torch.onnx.export( + SampleModelTwoInputs(), + (torch.randn(1, 1, 2), {"b": torch.randn(1, 1, 2)}, {}), + dynamo=True, + ) + self.assertEqual( + onnx_program_from_new_exporter.model_proto, + onnx_program_from_old_exporter.model_proto, + ) + + def test_dynamic_axes_enable_dynamic_shape(self): + onnx_program_from_new_exporter = torch.onnx.dynamo_export( + SampleModelTwoInputs(), + torch.randn(1, 1, 2), + b=torch.randn(1, 1, 2), + export_options=ExportOptions(dynamic_shapes=True), + ) + onnx_program_from_old_exporter = torch.onnx.export( + SampleModelTwoInputs(), + (torch.randn(1, 1, 2), {"b": torch.randn(1, 1, 2)}, {}), + dynamic_axes={"b": [0, 1, 2]}, + dynamo=True, + ) + self.assertEqual( + onnx_program_from_new_exporter.model_proto, + onnx_program_from_old_exporter.model_proto, + ) + + def test_raises_unrelated_parameters_warning(self): + message = ( + "f, export_params, verbose, training, input_names, output_names, operator_export_type, opset_version, " + "do_constant_folding, keep_initializers_as_inputs, custom_opsets, export_modules_as_functions, and " + "autograd_inlining are not supported for dynamo export at the moment." + ) + + with self.assertWarnsOnceRegex(UserWarning, message): + _ = torch.onnx.export( + SampleModel(), + (torch.randn(1, 1, 2),), + dynamo=True, + ) + + def test_raises_unsupported_specific_dynamic_axes_warning(self): + message = ( + "Specified dynamic axes is not supported for dynamo export at the moment." + ) + + with self.assertWarnsOnceRegex(UserWarning, message): + _ = torch.onnx.export( + SampleModel(), + (torch.randn(1, 1, 2),), + dynamic_axes={"input": [0, 1, 2]}, + dynamo=True, + ) + + def test_saved_f_exists_after_export(self): + with common_utils.TemporaryFileName(suffix=".onnx") as path: + _ = torch.onnx.export( + SampleModel(), torch.randn(1, 1, 2), path, dynamo=True + ) + self.assertTrue(os.path.exists(path)) + + if __name__ == "__main__": common_utils.run_tests() diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index f5206d425b4d..191df45ac9ef 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -186,11 +186,10 @@ def exporter_context(model, mode: _C_onnx.TrainingMode, verbose: bool): yield (mode_ctx, apex_ctx, log_ctx, diagnostic_ctx) -@_beartype.beartype def export( model: Union[torch.nn.Module, torch.jit.ScriptModule, torch.jit.ScriptFunction], args: Union[Tuple[Any, ...], torch.Tensor], - f: Union[str, io.BytesIO], + f: Optional[Union[str, io.BytesIO]] = None, export_params: bool = True, verbose: bool = False, training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL, @@ -206,7 +205,8 @@ def export( custom_opsets: Optional[Mapping[str, int]] = None, export_modules_as_functions: Union[bool, Collection[Type[torch.nn.Module]]] = False, autograd_inlining: Optional[bool] = True, -) -> None: + dynamo: bool = False, +) -> Optional[torch.onnx.ONNXProgram]: r"""Exports a model into ONNX format. If ``model`` is not a :class:`torch.jit.ScriptModule` nor a @@ -500,6 +500,8 @@ def forward(self, x): autograd_inlining (bool, default True): Flag used to control whether to inline autograd functions. Refer to https://github.com/pytorch/pytorch/pull/74765 for more details. + dynamo (bool, default False): Whether to export the model with Dynamo instead of TorchScript. + Raises: :class:`torch.onnx.errors.CheckerError`: If the ONNX checker detects an invalid ONNX graph. :class:`torch.onnx.errors.UnsupportedOperatorError`: If the ONNX graph cannot be exported because it @@ -508,6 +510,43 @@ def forward(self, x): All errors are subclasses of :class:`errors.OnnxExporterError`. """ + if dynamo: + # Unsupported parameters for dynamo export + # TODO: These are not supported AT THE TIME + warnings.warn( + "f, export_params, verbose, training, input_names, output_names, operator_export_type, opset_version, " + "do_constant_folding, keep_initializers_as_inputs, custom_opsets, export_modules_as_functions, and " + "autograd_inlining are not supported for dynamo export at the moment." + ) + # TODO: check args normalization + args = _decide_input_format(model, args) + kwargs = {} + if args is not None and isinstance(args[-1], dict): + kwargs = args[-1] + args = args[:-1] + # TODO: refactor this when we have migrated ExportedProgam and + # needs users to specify dynamic_axes + if dynamic_axes is None or not isinstance(dynamic_axes, dict): + dynamic_shapes = False + else: + dynamic_shapes = True + warnings.warn( + "Specified dynamic axes is not supported for dynamo export at the moment." + ) + # TODO: expose more ExportOptions? + export_options = torch.onnx.ExportOptions(dynamic_shapes=dynamic_shapes) + onnx_program = torch.onnx.dynamo_export( + model, *args, **kwargs, export_options=export_options + ) + if f is not None: + onnx_program.save(f) + return onnx_program + + if f is None: + raise ValueError( + "Export destination must be specified for torchscript-onnx export." + ) + _export( model, args, @@ -527,6 +566,8 @@ def forward(self, x): autograd_inlining=autograd_inlining, ) + return None + @_beartype.beartype def _is_constant_tensor_list(node): @@ -870,7 +911,6 @@ def _decide_input_format(model, args): warnings.warn("No input args, skipping _decide_input_format") except Exception as e: warnings.warn(f"Skipping _decide_input_format\n {e.args[0]}") - return args From a7c596870d92f15cf723f53139bf1a63ac6da4a2 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Wed, 5 Jun 2024 21:53:49 +0000 Subject: [PATCH 385/706] [BE][Eazy] remove `torch.torch.xxx` usages (#127800) NB: `torch` is exposed in `torch/__init__.py`. So there can be `torch.torch.torch.xxx`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127800 Approved by: https://github.com/peterbell10, https://github.com/kit1980, https://github.com/malfet --- test/dynamo/test_ctx_manager.py | 4 ++-- test/test_cuda.py | 4 ++-- torch/_dynamo/convert_frame.py | 4 +--- torch/_inductor/lowering.py | 2 +- torch/nn/modules/pooling.py | 6 +++--- 5 files changed, 9 insertions(+), 11 deletions(-) diff --git a/test/dynamo/test_ctx_manager.py b/test/dynamo/test_ctx_manager.py index 651c392f5dd2..47f8e8eeb863 100644 --- a/test/dynamo/test_ctx_manager.py +++ b/test/dynamo/test_ctx_manager.py @@ -497,7 +497,7 @@ def forward(self, x): a_float32 = torch.rand((8, 8), device="cuda") b_float32 = torch.rand((8, 8), device="cuda") - with torch.cuda.amp.autocast(dtype=torch.torch.float64): + with torch.cuda.amp.autocast(dtype=torch.float64): c_float64 = torch.mm(a_float32, b_float32) return c_float64 @@ -796,7 +796,7 @@ def forward(self, x): self.assertEqual(exported.dtype, real_dtype) self.assertEqual(exported.device.index, 0) - self.assertEqual(exported.dtype, torch.torch.float16) + self.assertEqual(exported.dtype, torch.float16) @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") def test_autocast_arguments_binding(self): diff --git a/test/test_cuda.py b/test/test_cuda.py index 785f0499df05..7ec86bd6f47b 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -380,10 +380,10 @@ def test_cublas_workspace_explicit_allocation(self): def check_workspace_size(inp): torch._C._cuda_clearCublasWorkspaces() - start = torch.torch.cuda.memory_stats()["active_bytes.all.allocated"] + start = torch.cuda.memory_stats()["active_bytes.all.allocated"] with torch.no_grad(): torch.matmul(inp, inp) - finish = torch.torch.cuda.memory_stats()["active_bytes.all.allocated"] + finish = torch.cuda.memory_stats()["active_bytes.all.allocated"] return finish - start # check default diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 37ff5a8a299b..88fb2a85bca2 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -178,9 +178,7 @@ def _fn(*args, **kwargs): finally: cleanup.close() torch._C._set_grad_enabled(prior_grad_mode) - torch.torch.autograd.grad_mode._enter_inference_mode( - prior_inference_mode - ) + torch.autograd.grad_mode._enter_inference_mode(prior_inference_mode) torch.use_deterministic_algorithms( prior_deterministic, warn_only=prior_warn_only ) diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 20b0082eb1d9..0a1909890e69 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -2421,7 +2421,7 @@ def inner_fn(idx): ops.index_expr( ModularIndexing(idx[dim] - start, 1, step), torch.int64 ), - ops.constant(0, torch.torch.int64), + ops.constant(0, torch.int64), ) ) assert mask diff --git a/torch/nn/modules/pooling.py b/torch/nn/modules/pooling.py index 3f02bb63a849..61ce56390981 100644 --- a/torch/nn/modules/pooling.py +++ b/torch/nn/modules/pooling.py @@ -381,9 +381,9 @@ class MaxUnpool2d(_MaxUnpoolNd): [ 0., 14., 0., 16.]]]]) >>> # Now using output_size to resolve an ambiguous size for the inverse >>> input = torch.tensor([[[[ 1., 2., 3., 4., 5.], - [ 6., 7., 8., 9., 10.], - [11., 12., 13., 14., 15.], - [16., 17., 18., 19., 20.]]]]) + [ 6., 7., 8., 9., 10.], + [11., 12., 13., 14., 15.], + [16., 17., 18., 19., 20.]]]]) >>> output, indices = pool(input) >>> # This call will not work without specifying output_size >>> unpool(output, indices, output_size=input.size()) From ffaea656b5d8ff6518669494cc8f664b94f8e8b1 Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Wed, 5 Jun 2024 22:56:29 +0000 Subject: [PATCH 386/706] WorkerServer: add support for binding to TCP (#127986) This adds support for the WorkerServer binding to TCP as well as the existing unix socket support. ```py server = _WorkerServer("", 1234) ``` Test plan: Added unit test ``` python test/distributed/elastic/test_control_plane.py ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/127986 Approved by: https://github.com/c-p-i-o --- .../distributed/elastic/test_control_plane.py | 11 ++++++ .../c10d/control_plane/WorkerServer.cpp | 38 +++++++++++-------- .../c10d/control_plane/WorkerServer.hpp | 2 +- torch/csrc/distributed/c10d/init.cpp | 7 ++-- 4 files changed, 39 insertions(+), 19 deletions(-) diff --git a/test/distributed/elastic/test_control_plane.py b/test/distributed/elastic/test_control_plane.py index c9ae512f2718..775b062451b1 100644 --- a/test/distributed/elastic/test_control_plane.py +++ b/test/distributed/elastic/test_control_plane.py @@ -81,6 +81,17 @@ def test_dump_nccl_trace_pickle(self) -> None: self.assertEqual(resp.status, 200) out = pickle.loads(resp.data) + def test_tcp(self) -> None: + import requests + + from torch._C._distributed_c10d import _WorkerServer + + server = _WorkerServer("", 1234) + out = requests.get("http://localhost:1234/handler/") + self.assertEqual(out.status_code, 200) + + server.shutdown() + if __name__ == "__main__": run_tests() diff --git a/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp b/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp index 14d287e9607f..e4b649d888dd 100644 --- a/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp +++ b/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp @@ -71,15 +71,7 @@ std::string jsonStrEscape(const std::string& str) { } } // namespace -WorkerServer::WorkerServer(const std::string& socketFile) { - // using unix sockets - server_.set_address_family(AF_UNIX); - - // adjust keep alives as it stops the server from shutting down quickly - server_.set_keep_alive_timeout(1); // second, default is 5 - server_.set_keep_alive_max_count( - 30); // wait max 30 seconds before closing socket - +WorkerServer::WorkerServer(const std::string& hostOrFile, int port) { server_.Get("/", [](const httplib::Request& req, httplib::Response& res) { res.set_content( R"BODY(

torch.distributed.WorkerServer

@@ -139,13 +131,29 @@ WorkerServer::WorkerServer(const std::string& socketFile) { } }); - if (std::filesystem::exists(socketFile)) { - throw std::runtime_error(fmt::format("{} already exists", socketFile)); - } + // adjust keep alives as it stops the server from shutting down quickly + server_.set_keep_alive_timeout(1); // second, default is 5 + server_.set_keep_alive_max_count( + 30); // wait max 30 seconds before closing socket + + if (port == -1) { + // using unix sockets + server_.set_address_family(AF_UNIX); - C10D_WARNING("Server listening to {}", socketFile); - if (!server_.bind_to_port(socketFile, 80)) { - throw std::runtime_error(fmt::format("Error binding to {}", socketFile)); + if (std::filesystem::exists(hostOrFile)) { + throw std::runtime_error(fmt::format("{} already exists", hostOrFile)); + } + + C10D_WARNING("Server listening to UNIX {}", hostOrFile); + if (!server_.bind_to_port(hostOrFile, 80)) { + throw std::runtime_error(fmt::format("Error binding to {}", hostOrFile)); + } + } else { + C10D_WARNING("Server listening to TCP {}:{}", hostOrFile, port); + if (!server_.bind_to_port(hostOrFile, port)) { + throw std::runtime_error( + fmt::format("Error binding to {}:{}", hostOrFile, port)); + } } serverThread_ = std::thread([this]() { diff --git a/torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp b/torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp index 7d64038f0b01..a0b16ac192ba 100644 --- a/torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp +++ b/torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp @@ -14,7 +14,7 @@ namespace control_plane { class TORCH_API WorkerServer : public c10::intrusive_ptr_target { public: - WorkerServer(const std::string& socketFile); + WorkerServer(const std::string& hostOrFile, int port = -1); ~WorkerServer(); void shutdown(); diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index c4b9a9823c84..ea8e6db9290b 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -3170,11 +3170,12 @@ such as `dist.all_reduce(tensor, async_op=True)`. module, "_WorkerServer", R"( )") .def( - py::init([](const std::string& socketPath) { + py::init([](const std::string& hostOrFile, int port) { return c10::make_intrusive<::c10d::control_plane::WorkerServer>( - socketPath); + hostOrFile, port); }), - py::arg("socket_path")) + py::arg("host_or_file"), + py::arg("port") = -1) .def("shutdown", &::c10d::control_plane::WorkerServer::shutdown); Py_RETURN_TRUE; } From e98662bed99df57b7d79f9fc1cbe670afc303235 Mon Sep 17 00:00:00 2001 From: Aidyn-A Date: Wed, 5 Jun 2024 23:44:51 +0000 Subject: [PATCH 387/706] [DDP] Bucket handling: make first bucket size equal to bucket_cap_mb if it was set (#121640) The fist DDP bucket is always being created of the size of `dist._DEFAULT_FIRST_BUCKET_BYTES` (1 MiB) by default regardless of `bucket_cap_mb`. The proposal is to set `bucket_cap_mb` as the one main bucket size if it was supplied by the user. Pull Request resolved: https://github.com/pytorch/pytorch/pull/121640 Approved by: https://github.com/wanchaol --- torch/nn/parallel/distributed.py | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py index ef6034ade58e..e95a2d9ab030 100644 --- a/torch/nn/parallel/distributed.py +++ b/torch/nn/parallel/distributed.py @@ -548,7 +548,8 @@ class DistributedDataParallel(Module, Joinable): multiple buckets so that gradient reduction of each bucket can potentially overlap with backward computation. :attr:`bucket_cap_mb` controls the bucket size in - MegaBytes (MB). (default: 25) + MebiBytes (MiB). If ``None``, a default size of 25 MiB + will be used. (default: ``None``) find_unused_parameters (bool): Traverse the autograd graph from all tensors contained in the return value of the wrapped module's ``forward`` function. Parameters @@ -631,7 +632,7 @@ def __init__( dim=0, broadcast_buffers=True, process_group=None, - bucket_cap_mb=25, + bucket_cap_mb=None, find_unused_parameters=False, check_reduction=False, gradient_as_bucket_view=False, @@ -788,7 +789,14 @@ def __init__( self.broadcast_bucket_size = int(250 * 1024 * 1024) # reduction bucket size - self.bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024) + if bucket_cap_mb is None: + # default case (bucket cap is 25 MiB) + self.bucket_bytes_cap_default = True + self.bucket_bytes_cap = int(25 * 1024 * 1024) + else: + self.bucket_bytes_cap_default = False + self.bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024) + # Whether to perform input tensor CPU to GPU copies on a side-stream self.use_side_stream_for_tensor_copies = ( os.environ.get("PYTORCH_DDP_USE_SIDE_STREAM", "1") == "1" @@ -1156,10 +1164,13 @@ def _ddp_init_helper( if static_graph is True or self.find_unused_parameters is False: bucket_size_limits = [sys.maxsize] else: - bucket_size_limits = [ - dist._DEFAULT_FIRST_BUCKET_BYTES, - self.bucket_bytes_cap, - ] + if self.bucket_bytes_cap_default: + bucket_size_limits = [ + dist._DEFAULT_FIRST_BUCKET_BYTES, + self.bucket_bytes_cap, + ] + else: + bucket_size_limits = [self.bucket_bytes_cap] ( bucket_indices, per_bucket_size_limits, @@ -1195,7 +1206,9 @@ def _ddp_init_helper( param_to_name_mapping, # User can set dist._DEFAULT_FIRST_BUCKET_BYTES to tune DDP first # bucket. - dist._DEFAULT_FIRST_BUCKET_BYTES, + dist._DEFAULT_FIRST_BUCKET_BYTES + if self.bucket_bytes_cap_default + else self.bucket_bytes_cap, ) self.logger = dist.Logger(self.reducer) From 2fd75667b4f92dcf95486ba50d658fb53536cded Mon Sep 17 00:00:00 2001 From: cyy Date: Wed, 5 Jun 2024 23:46:29 +0000 Subject: [PATCH 388/706] [Caffe2]Remove Caffe2 scripts and benchmarks (#126747) Due to removal of Caffe2. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126747 Approved by: https://github.com/ezyang, https://github.com/malfet --- .../framework_overhead_benchmark/C2Module.py | 45 --- .../framework_overhead_benchmark.py | 50 +-- benchmarks/operator_benchmark/README.md | 95 +---- .../operator_benchmark/benchmark_caffe2.py | 202 ---------- .../operator_benchmark/benchmark_core.py | 16 +- .../operator_benchmark/benchmark_runner.py | 13 +- .../operator_benchmark/benchmark_utils.py | 8 - benchmarks/operator_benchmark/c2/__init__.py | 0 benchmarks/operator_benchmark/c2/add_test.py | 49 --- .../c2/batch_box_cox_test.py | 49 --- .../c2/batch_gather_test.py | 58 --- .../operator_benchmark/c2/clip_ranges_test.py | 54 --- .../operator_benchmark/c2/concat_test.py | 171 -------- .../operator_benchmark/c2/matmul_test.py | 50 --- .../operator_benchmark/c2/quantile_op_test.py | 48 --- .../operator_benchmark/c2/replace_nan_test.py | 44 --- .../tests/c2_cpu_gpu_forward_backward_test.py | 41 -- .../record_function_bench.py | 13 +- binaries/bench_gen/bench_gen.py | 118 ------ scripts/appveyor/install.bat | 10 - scripts/appveyor/install_cuda.bat | 22 -- scripts/model_zoo/update-caffe2-models.py | 175 -------- .../model_zoo/update-models-from-caffe2.py | 372 ------------------ 23 files changed, 29 insertions(+), 1674 deletions(-) delete mode 100644 benchmarks/framework_overhead_benchmark/C2Module.py delete mode 100644 benchmarks/operator_benchmark/benchmark_caffe2.py delete mode 100644 benchmarks/operator_benchmark/c2/__init__.py delete mode 100644 benchmarks/operator_benchmark/c2/add_test.py delete mode 100644 benchmarks/operator_benchmark/c2/batch_box_cox_test.py delete mode 100644 benchmarks/operator_benchmark/c2/batch_gather_test.py delete mode 100644 benchmarks/operator_benchmark/c2/clip_ranges_test.py delete mode 100644 benchmarks/operator_benchmark/c2/concat_test.py delete mode 100644 benchmarks/operator_benchmark/c2/matmul_test.py delete mode 100644 benchmarks/operator_benchmark/c2/quantile_op_test.py delete mode 100644 benchmarks/operator_benchmark/c2/replace_nan_test.py delete mode 100644 benchmarks/operator_benchmark/common/tests/c2_cpu_gpu_forward_backward_test.py delete mode 100755 binaries/bench_gen/bench_gen.py delete mode 100644 scripts/appveyor/install.bat delete mode 100644 scripts/appveyor/install_cuda.bat delete mode 100755 scripts/model_zoo/update-caffe2-models.py delete mode 100644 scripts/model_zoo/update-models-from-caffe2.py diff --git a/benchmarks/framework_overhead_benchmark/C2Module.py b/benchmarks/framework_overhead_benchmark/C2Module.py deleted file mode 100644 index 0b93836e5940..000000000000 --- a/benchmarks/framework_overhead_benchmark/C2Module.py +++ /dev/null @@ -1,45 +0,0 @@ -import numpy as np - -from utils import NUM_LOOP_ITERS - -from caffe2.python import core, workspace - -workspace.GlobalInit(["caffe2"]) - - -def add_blob(ws, blob_name, tensor_size): - blob_tensor = np.random.randn(*tensor_size).astype(np.float32) - ws.FeedBlob(blob_name, blob_tensor) - - -class C2SimpleNet: - """ - This module constructs a net with 'op_name' operator. The net consist - a series of such operator. - It initializes the workspace with input blob equal to the number of parameters - needed for the op. - Provides forward method to run the net niter times. - """ - - def __init__(self, op_name, num_inputs=1, debug=False): - self.input_names = [] - self.net = core.Net("framework_benchmark_net") - self.input_names = [f"in_{i}" for i in range(num_inputs)] - for i in range(num_inputs): - add_blob(workspace, self.input_names[i], [1]) - self.net.AddExternalInputs(self.input_names) - op_constructor = getattr(self.net, op_name) - op_constructor(self.input_names) - self.output_name = self.net._net.op[-1].output - print(f"Benchmarking op {op_name}:") - for _ in range(NUM_LOOP_ITERS): - output_name = self.net._net.op[-1].output - self.input_names[-1] = output_name[0] - assert len(self.input_names) == num_inputs - op_constructor(self.input_names) - workspace.CreateNet(self.net) - if debug: - print(self.net._net) - - def forward(self, niters): - workspace.RunNet(self.net, niters, False) diff --git a/benchmarks/framework_overhead_benchmark/framework_overhead_benchmark.py b/benchmarks/framework_overhead_benchmark/framework_overhead_benchmark.py index 8d1b52738522..826c4d283ee8 100644 --- a/benchmarks/framework_overhead_benchmark/framework_overhead_benchmark.py +++ b/benchmarks/framework_overhead_benchmark/framework_overhead_benchmark.py @@ -1,6 +1,5 @@ import argparse -from C2Module import C2SimpleNet from pt_wrapper_module import WrapperModule from SimpleAddModule import add_tensors_loop, SimpleAddModule @@ -19,9 +18,6 @@ --add-op --graph-mode --eager-mode (Runs both graph mode and eager mode) buck run @mode/opt :framework_overhead_benchmark -- --add-op --graph-mode (Runs only graph mode) -To run C2 benchmark: -buck run @mode/opt :framework_overhead_benchmark -- - --add-op --benchmark-c2-net """ SUPPORTED_OPS = {"add_op"} @@ -49,39 +45,22 @@ def benchmark_simple_fn(args, config, module_config, module_type, result): module_type: Type of the module to be wrapped. e.g. SimpleAddModule for add op. result: dictionary instance to be populated with the benchmark result (latency per iter). """ - benchmark_c2_net = args.benchmark_c2_net print(f"Benchmarking {module_type.__name__}") - if benchmark_c2_net: - op_name = module_config.c2_op - num_inputs = module_config.num_params - module = C2SimpleNet(op_name, num_inputs=num_inputs, debug=args.debug) - latency_per_iter_ms = benchmark_module(config, module) - result[op_name] = latency_per_iter_ms - else: - f_name = ( - module_config.pt_fn.__name__ - + ":Num Operands=" - + str(module_config.num_params) - ) - graph_mode_str = "Graph mode" + ":" + str(module_config.graph_mode) - result_key = ",".join((f_name, graph_mode_str)) - module = WrapperModule(module_type, module_config, args.debug, args.save) - latency_per_iter_ms = benchmark_module( - config, module, args.use_throughput_benchmark - ) - result[result_key] = latency_per_iter_ms + f_name = ( + module_config.pt_fn.__name__ + ":Num Operands=" + str(module_config.num_params) + ) + graph_mode_str = "Graph mode" + ":" + str(module_config.graph_mode) + result_key = ",".join((f_name, graph_mode_str)) + module = WrapperModule(module_type, module_config, args.debug, args.save) + latency_per_iter_ms = benchmark_module( + config, module, args.use_throughput_benchmark + ) + result[result_key] = latency_per_iter_ms def main(): parser = argparse.ArgumentParser() parser.add_argument("--op", default="add_op", dest="op", type=str) - parser.add_argument( - "--benchmark-c2-net", - "--benchmark_c2_net", - default=False, - dest="benchmark_c2_net", - action="store_true", - ) parser.add_argument( "--use-throughput-benchmark", "--use_throughput_benchmark", @@ -107,10 +86,6 @@ def main(): if args.op not in SUPPORTED_OPS: print(f"Op {args.op} is not supported: Supported ops are:{SUPPORTED_OPS}") return - assert not ( - args.benchmark_c2_net and args.use_throughput_benchmark - ), "Benchmarking of C2 net via throughput benchmarking is not yet supported" - num_warmup_iters = args.num_warmup_iters num_iters = args.num_iters config = BenchmarkConfig(num_warmup_iters, num_iters) @@ -120,10 +95,7 @@ def main(): result = {} if args.op == "add_op": num_params = 2 - if args.benchmark_c2_net: - module_config = ModuleConfig(None, "Sum", num_params, None) - else: - module_config = ModuleConfig(add_tensors_loop, None, num_params, graph_mode) + module_config = ModuleConfig(add_tensors_loop, None, num_params, graph_mode) benchmark_simple_fn(args, config, module_config, SimpleAddModule, result) print_results(result) diff --git a/benchmarks/operator_benchmark/README.md b/benchmarks/operator_benchmark/README.md index 549bb137a9d3..9bcfc5d03e19 100644 --- a/benchmarks/operator_benchmark/README.md +++ b/benchmarks/operator_benchmark/README.md @@ -1,6 +1,6 @@ -# PyTorch/Caffe2 Operator Micro-benchmarks +# PyTorch Operator Micro-benchmarks -This benchmark suite provides a systemic way to measure the performance of operators for a wide range of inputs. The generated benchmark data fully characterized the performance of an operator in terms of execution time and the efficiency of the PyTorch/Caffe2 frameworks used. +This benchmark suite provides a systemic way to measure the performance of operators for a wide range of inputs. The generated benchmark data fully characterized the performance of an operator in terms of execution time and the efficiency of the PyTorch frameworks used. ## Features @@ -8,7 +8,7 @@ Key Features: 1\. Language used: Python -2\. Supported Frameworks: PyTorch and Caffe2 +2\. Supported Frameworks: PyTorch 3\. Supported PyTorch mode: eager and JIT @@ -49,7 +49,7 @@ python -m benchmark_all_test ``` ## Code to support `torch.add` in the benchmark -The following example shows the code to support `torch.add` with 27 different tests. In the subpages of this wiki, we'll step through the complete flow of adding PyTorch and Caffe2 operators to the benchmark suite. Existing benchmarks for operators are in `pt` and `c2` directories and we highly recommend putting your new operators in those locations. +The following example shows the code to support `torch.add` with 27 different tests. In the subpages of this wiki, we'll step through the complete flow of adding PyTorch operators to the benchmark suite. Existing benchmarks for operators are in the `pt` directory and we highly recommend putting your new operators in those locations. ```python add_short_configs = op_bench.cross_product_configs( @@ -77,7 +77,7 @@ op_bench.generate_pt_test(add_short_configs, AddBenchmark) The output is intended to be a human readable format. Here is an example output for `torch.add`: ``` # ---------------------------------------- -# PyTorch/Caffe2 Operator Micro-benchmarks +# PyTorch Operator Micro-benchmarks # ---------------------------------------- # Tag : short @@ -146,7 +146,7 @@ python -m pt.add_test --tag-filter long ``` ## Adding New Operators to the Benchmark Suite -In the previous sections, we gave several examples to show how to run the already available operators in the benchmark suite. In the following sections, we'll step through the complete flow of adding PyTorch and Caffe2 operators to the benchmark suite. Existing benchmarks for operators are in `pt` and `c2` directories and we highly recommend putting your new operators in those directories as well. +In the previous sections, we gave several examples to show how to run the already available operators in the benchmark suite. In the following sections, we'll step through the complete flow of adding PyTorch operators to the benchmark suite. Existing benchmarks for operators are in the `pt` directory and we highly recommend putting your new operators in those directories as well. ### Add a New PyTorch Operator Let's say you want to measure the execution time of the following operator: @@ -260,55 +260,6 @@ if __name__ == "__main__": ``` That's it. You just added a new operator to the benchmark suite! - -### Add a New Caffe2 Operator -The steps to add a new Caffe2 operator is the same as that for a PyTorch operator. The code below shows how to add Caffe2 `Add` operator: -```python -import operator_benchmark as op_bench -from caffe2.python import core - -add_long_configs = op_bench.cross_product_configs( - M=[8, 64, 128], - N=range(2, 10, 3), - K=[2 ** x for x in range(0, 3)], - tags=["long"] -) - -add_short_configs = op_bench.config_list( - attrs=[ - [8, 16, 32], - [16, 16, 64], - [64, 64, 128], - ], - attr_names=["M", "N", "K"], - tags=["short"], -) - -class AddBenchmark(op_bench.Caffe2BenchmarkBase): - - def init(self, M, N, K): - self.input_one = self.tensor(M, N, K) - self.input_two = self.tensor(M, N, K) - self.output = self.tensor(M, N, K) - self.set_module_name("add") - - def forward(self): - op = core.CreateOperator( - "Add", [self.input_one, self.input_two], self.output, **self.args - ) - - return op - -op_bench.generate_c2_test(add_long_configs + add_short_configs, AddBenchmark) - -if __name__ == "__main__": - op_bench.benchmark_runner.main() -``` -There are two things worth mentioning in this code: -* `self.tensor` is a helper function which takes shapes and returns a Caffe2 blob. It is designed to make the tensor creation step easier compared to the standard Caffe2 way. -* `generate_c2_test` is used to register Caffe2 tests with the benchmark. - - ### Add a List of Operators In the previous sections, we introduced the steps required to add a single operator to the benchmark suite. There are scenarios where you want to extend the benchmark suite with a list of operators which can share the same inputs. For example, to benchmark `abs` and `acos` operators, you can use the same set of inputs for both. @@ -416,37 +367,3 @@ The example below shows the relevant code for that: self.input_one = torch.rand(M, N, K, requires_grad=True) generate_pt_gradient_test(long_configs + short_configs, TorchAddBenchmark) ``` -#### For Caffe2 Gradient Ops -To add Caffe2 gradient ops, we need to implement a new backward method in the benchmark class: -```python -class AddBenchmark(op_bench.Caffe2BenchmarkBase): - - def init(self, M, N, K): - self.input_one = self.tensor(M, N, K) - self.input_two = self.tensor(M, N, K) - self.input_one_grad = self.tensor(M, N, K) - self.input_two_grad = self.tensor(M, N, K) - self.output = self.tensor(M, N, K) - self.set_module_name("add") - - def forward(self): - op = core.CreateOperator( - "Add", [self.input_one, self.input_two], self.output, **self.args - ) - - return op - - def backward(self): - grad_op = core.CreateOperator( - "AddGradient", - [self.output, self.input_one, self.input_two], - [self.input_one_grad, self.input_two_grad], **self.args - ) - - return grad_op - -op_bench.generate_c2_gradient_test(long_configs + short_configs,AddBenchmark) -``` -After the class is implemented, we need to register the tests with `generate_c2_gradient_test` function. - -This concludes the overview of the operator benchmark suite. diff --git a/benchmarks/operator_benchmark/benchmark_caffe2.py b/benchmarks/operator_benchmark/benchmark_caffe2.py deleted file mode 100644 index 2d238e593fc9..000000000000 --- a/benchmarks/operator_benchmark/benchmark_caffe2.py +++ /dev/null @@ -1,202 +0,0 @@ -from collections import namedtuple - -import benchmark_utils -from benchmark_test_generator import _register_test - -from caffe2.proto import caffe2_pb2 -from caffe2.python import core, workspace - -from .benchmark_core import TestConfig - -"""Caffe2 performance microbenchmarks. - -This module contains Caffe2-specific functionalities for performance -microbenchmarks. -""" - - -class Caffe2BenchmarkBase: - """This is a base class used to create Caffe2 operator benchmark""" - - tensor_index = 0 - test_index = 0 - - def __init__(self): - self.args = {} - self.user_provided_name = None - self._num_inputs_require_grads = 0 - self._pass_count = 0 - - def _set_backward_test(self, is_backward): - pass - - def _device_option(self, device): - """This method is used to set device option.""" - if device not in ["cuda", "cpu"]: - raise ValueError("Missing attrs in configs") - - if "cuda" in device: - self.dev = core.DeviceOption(caffe2_pb2.CUDA, 0) - else: - self.dev = core.DeviceOption(caffe2_pb2.CPU) - return self.dev - - def tensor(self, shapes, dtype="float32", device="cpu"): - """A wapper function to create C2 tensor filled with random data. - The name/label of the tensor is returned and it is available - throughout the benchmark execution phase. - Args: - shapes: int or a sequence of ints to defining the shapes of the tensor - dtype: use the dtypes from numpy - (https://docs.scipy.org/doc/numpy/user/basics.types.html) - Return: - C2 tensor of dtype - """ - return self.feed_tensor(benchmark_utils.numpy_random(dtype, *shapes), device) - - def feed_tensor(self, tensor, device="cpu"): - """Similar to tensor, but can supply any data compatible with FeedBlob""" - blob_name = "blob_" + str(Caffe2BenchmarkBase.tensor_index) - dev = self._device_option(device) - with core.DeviceScope(dev): - workspace.FeedBlob(blob_name, tensor) - Caffe2BenchmarkBase.tensor_index += 1 - return blob_name - - def module_name(self): - """this is used to label the operator being benchmarked""" - if self.user_provided_name: - return self.user_provided_name - return self.__class__.__name__ - - def set_module_name(self, name): - self.user_provided_name = name - - def _value_to_str(self, value): - """if value is bool, we will convert it to 0 and 1""" - ret = value - if type(value) == bool: - ret = int(value) - return str(ret) - - def test_name(self, name_type="long", **kargs): - """this is a globally unique name which can be used to - label a specific test - """ - if name_type == "long": - test_name_str = [] - for key in kargs: - value = kargs[key] - test_name_str.append(key + self._value_to_str(value)) - name = (self.module_name() + "_" + "_".join(test_name_str)).replace(" ", "") - elif name_type == "short": - # this is used to generate test name based on unique index - name = "_".join( - [self.module_name(), "test", str(Caffe2BenchmarkBase.test_index)] - ) - Caffe2BenchmarkBase.test_index += 1 - return name - - def extract_inputs_tuple(self): - # add a dummy function here to match the interface of TorchBenchmarkBase - pass - - -class Caffe2OperatorTestCase: - """This class includes all the information needed to benchmark an operator. - op_bench: it's a user-defined class (child of Caffe2BenchmarkBase) - which includes input and operator, .etc - test_config: a namedtuple includes test_name, input_shape, tag, run_backward. - When run_backward is false, the run_forward method will be executed, otherwise - run_backward method will be executed. - """ - - def __init__(self, op_bench, test_config): - self.op_bench = op_bench - self.test_config = test_config - self.framework = "Caffe2" - - def run_forward(self, num_runs, print_per_iter=False, cuda_sync=False): - """Run the forward path of an operator in a loop""" - with core.DeviceScope(self.op_bench.dev): - op = self.op_bench.forward() - if not workspace.RunOperatorMultiple(op, num_runs): - raise ValueError(f"Unable to run operator test case: {self.test_name}") - - def run_backward(self, num_runs, print_per_iter=False): - """Run the backward path of an operator in a loop""" - with core.DeviceScope(self.op_bench.dev): - op = self.op_bench.backward() - if not workspace.RunOperatorMultiple(op, num_runs): - raise ValueError( - f"Unable to run operator gradient test case: {self.test_name}" - ) - - def _print_per_iter(self): - pass - - -def create_caffe2_op_test_case(op_bench, test_config): - test_case = Caffe2OperatorTestCase(op_bench, test_config) - test_config = test_case.test_config - op = test_case.op_bench - func_name = f"{op.module_name()}{test_case.framework}{str(test_config)}" - return (func_name, test_case) - - -OpMeta = namedtuple( - "OpMeta", - "op_type num_inputs input_dims input_types \ - output_dims num_outputs args device", -) - - -def generate_c2_test_from_ops(ops_metadata, bench_op, tags): - """ - This function is used to generate Caffe2 tests based on the metadata - of operators. The metadata includes seven fields which are 1) op_type: - the name of the operator. 2) num_inputs: the number of input blobs. - 3) input_dims: a dictionary which includes the shapes of the input blobs. - 4) input_types: a list which includes the types of input blobs. 5) - output_dims: a dictionary which includes the shapes of output blobs. - 6) num_oupts: the number of output blobs. 7) args: a dictionary which - includes the args for th operator. - Here is an example to show the metadata for the WeighedSum operator - op_type : WeightedSum - num_inputs: 4 - input_dims: {'0': [256], '1': [1], '2': [256], '3': [1]} - input_types: ['float', 'float', 'float', 'float'] - output_dims: {'0': [256]} - num_outputs: 4 - args: {} - TODO(mingzhe0908): introduce device and add it to the benchmark name - """ - for op_metadata in ops_metadata: - tmp_attrs = OpMeta( - op_metadata.op_type, - op_metadata.num_inputs, - op_metadata.input_dims, - op_metadata.input_types, - op_metadata.output_dims, - op_metadata.num_outputs, - op_metadata.args, - op_metadata.device, - ) - test_attrs = tmp_attrs._asdict() - op = bench_op() - op.init(**test_attrs) - test_name = op.test_name("short") - input_config = f"Shapes: {op_metadata.input_dims}, Type: {op_metadata.input_types}, Args: {str(op_metadata.args)}" - test_config = TestConfig(test_name, input_config, tags, run_backward=False) - if op is not None: - create_caffe2_op_test_case(op, test_config) - - -def generate_c2_test(configs, c2_bench_op): - """This function creates Caffe2 op test based on the given operator""" - return _register_test(configs, c2_bench_op, create_caffe2_op_test_case, False) - - -def generate_c2_gradient_test(configs, c2_bench_op): - """This function creates Caffe2 op test based on the given operator""" - return _register_test(configs, c2_bench_op, create_caffe2_op_test_case, True) diff --git a/benchmarks/operator_benchmark/benchmark_core.py b/benchmarks/operator_benchmark/benchmark_core.py index c315382d1538..239dddbf7231 100644 --- a/benchmarks/operator_benchmark/benchmark_core.py +++ b/benchmarks/operator_benchmark/benchmark_core.py @@ -13,6 +13,7 @@ # needs to be imported after torch import torch.utils.cpp_extension as cpp_extension # noqa: F401 + """Performance microbenchmarks. This module contains core functionalities for performance microbenchmark tests. @@ -50,7 +51,7 @@ def _create_test( """Create tests with the benchmark backend. Args: bench_op_obj: an object which instantiated from a subclass of - Caffe2BenchmarkBase/TorchBenchmarkBase which includes tensor + TorchBenchmarkBase which includes tensor creation and operator execution. orig_test_attrs: a dictionary includes test configs. tags: a attribute in test config to filter inputs @@ -75,7 +76,7 @@ def _build_test( """Generate PyTorch/Caffe2 tests of operators with different inputs. Args: configs: a dictionary that has the input shapes - bench_op: a subclass of Caffe2BenchmarkBase/TorchBenchmarkBase which includes tensor + bench_op: a subclass of TorchBenchmarkBase which includes tensor creation and operator execution OperatorTestCase: a named tuple to save the metadata of an test run_backward: a bool parameter indicating backward path @@ -233,9 +234,7 @@ def _print_perf_result(self, reported_run_time_us, test_case): ) ) else: - if test_case.framework == "PyTorch": - print(f"# Mode: {'JIT' if self.use_jit else 'Eager'}") - + print(f"# Mode: {'JIT' if self.use_jit else 'Eager'}") print( f"# Name: {test_case.test_config.test_name}\n# Input: {test_case.test_config.input_config}" ) @@ -283,8 +282,7 @@ def _launch_backward(self, test_case, iters, print_per_iter=False): and the execution time is reported """ test_case.run_forward(num_runs=1, print_per_iter=False, cuda_sync=False) - if test_case.framework == "PyTorch": - test_case._output_mean() + test_case._output_mean() backward_time = timeit.timeit( functools.partial(test_case.run_backward, iters, print_per_iter), number=1 ) @@ -357,9 +355,6 @@ def _keep_test(self, test_case): # Currently, this is a sub-string matching. op_test_config = test_case.test_config - if self.args.framework: - frameworks = benchmark_utils.process_arg_list(self.args.framework) - operators = ( benchmark_utils.process_arg_list(self.args.operators) if self.args.operators @@ -370,7 +365,6 @@ def _keep_test(self, test_case): if ( self._check_keep(op_test_config.test_name, self.args.test_name) and self._check_keep_list(test_case.op_bench.module_name(), operators) - and self._check_keep_list(test_case.framework, frameworks) and self._check_operator_first_char( test_case.op_bench.module_name(), self.operator_range ) diff --git a/benchmarks/operator_benchmark/benchmark_runner.py b/benchmarks/operator_benchmark/benchmark_runner.py index 7bb18f7d7708..6abbc566820b 100644 --- a/benchmarks/operator_benchmark/benchmark_runner.py +++ b/benchmarks/operator_benchmark/benchmark_runner.py @@ -92,7 +92,7 @@ def parse_args(): parser.add_argument( "--omp-num-threads", "--omp_num_threads", - help="Number of OpenMP threads used in PyTorch/Caffe2 runtime", + help="Number of OpenMP threads used in PyTorch runtime", default=None, type=int, ) @@ -100,7 +100,7 @@ def parse_args(): parser.add_argument( "--mkl-num-threads", "--mkl_num_threads", - help="Number of MKL threads used in PyTorch/Caffe2 runtime", + help="Number of MKL threads used in PyTorch runtime", default=None, type=int, ) @@ -135,12 +135,6 @@ def parse_args(): help="Only run the forward path of operators", ) - parser.add_argument( - "--framework", - help="Comma-delimited list of frameworks to test (Caffe2, PyTorch)", - default="Caffe2,PyTorch", - ) - parser.add_argument( "--device", help="Run tests on the provided architecture (cpu, cuda)", @@ -160,8 +154,7 @@ def parse_args(): # "Modifications to the environment variables after the program has started, # even if modified by the program itself, are ignored by the OpenMP implementation" benchmark_utils.set_omp_threads(args.omp_num_threads) - if benchmark_utils.is_pytorch_enabled(args.framework): - torch.set_num_threads(args.omp_num_threads) + torch.set_num_threads(args.omp_num_threads) if args.mkl_num_threads: benchmark_utils.set_mkl_threads(args.mkl_num_threads) diff --git a/benchmarks/operator_benchmark/benchmark_utils.py b/benchmarks/operator_benchmark/benchmark_utils.py index d7e45b7c1685..be9c62cb3c28 100644 --- a/benchmarks/operator_benchmark/benchmark_utils.py +++ b/benchmarks/operator_benchmark/benchmark_utils.py @@ -319,14 +319,6 @@ def op_list(**configs): return generated_configs -def is_caffe2_enabled(framework_arg): - return "Caffe2" in framework_arg - - -def is_pytorch_enabled(framework_arg): - return "PyTorch" in framework_arg - - def get_operator_range(chars_range): """Generates the characters from chars_range inclusive.""" if chars_range == "None" or chars_range is None: diff --git a/benchmarks/operator_benchmark/c2/__init__.py b/benchmarks/operator_benchmark/c2/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/benchmarks/operator_benchmark/c2/add_test.py b/benchmarks/operator_benchmark/c2/add_test.py deleted file mode 100644 index c3b71f3e8514..000000000000 --- a/benchmarks/operator_benchmark/c2/add_test.py +++ /dev/null @@ -1,49 +0,0 @@ -import benchmark_caffe2 as op_bench_c2 -from benchmark_caffe2 import Caffe2BenchmarkBase # noqa: F401 - -import operator_benchmark as op_bench -from caffe2.python import core - - -"""Microbenchmarks for element-wise Add operator. Supports both Caffe2/PyTorch.""" - -# Configs for C2 add operator -add_long_configs = op_bench.cross_product_configs( - M=[8, 64, 128], - N=range(2, 10, 3), - K=[2**x for x in range(0, 3)], - dtype=["int", "float"], - tags=["long"], -) - - -add_short_configs = op_bench.config_list( - attrs=[ - [8, 16, 32, "int"], - [16, 16, 64, "float"], - [64, 64, 128, "int"], - ], - attr_names=["M", "N", "K", "dtype"], - tags=["short"], -) - - -class AddBenchmark(op_bench_c2.Caffe2BenchmarkBase): - def init(self, M, N, K, dtype): - self.input_one = self.tensor([M, N, K], dtype) - self.input_two = self.tensor([M, N, K], dtype) - self.output = self.tensor([M, N, K], dtype) - self.set_module_name("add") - - def forward(self): - op = core.CreateOperator( - "Add", [self.input_one, self.input_two], self.output, **self.args - ) - return op - - -op_bench_c2.generate_c2_test(add_long_configs + add_short_configs, AddBenchmark) - - -if __name__ == "__main__": - op_bench.benchmark_runner.main() diff --git a/benchmarks/operator_benchmark/c2/batch_box_cox_test.py b/benchmarks/operator_benchmark/c2/batch_box_cox_test.py deleted file mode 100644 index 7c40f513cd6e..000000000000 --- a/benchmarks/operator_benchmark/c2/batch_box_cox_test.py +++ /dev/null @@ -1,49 +0,0 @@ -import benchmark_caffe2 as op_bench_c2 -from benchmark_caffe2 import Caffe2BenchmarkBase # noqa: F401 - -import operator_benchmark as op_bench -from caffe2.python import core - - -"""Microbenchmarks for BatchBoxCox operator.""" - -# Configs for C2 BatchBoxCox operator -batch_box_cox_long_configs = op_bench.cross_product_configs( - M=[32, 64, 128], N=range(32, 128, 32), dtype=["float", "double"], tags=["long"] -) - - -batch_box_cox_short_configs = op_bench.config_list( - attrs=[ - [16, 16, "float"], - [16, 16, "double"], - [64, 64, "float"], - [64, 64, "double"], - ], - attr_names=["M", "N", "dtype"], - tags=["short"], -) - - -class BatchBoxCoxBenchmark(op_bench_c2.Caffe2BenchmarkBase): - def init(self, M, N, dtype): - self.data = self.tensor([M, N], dtype) - self.lambda1 = self.tensor([N], dtype) - self.lambda2 = self.tensor([N], dtype) - self.output = self.tensor([1, 1], dtype) - self.set_module_name("batch_box_cox") - - def forward(self): - op = core.CreateOperator( - "BatchBoxCox", [self.data, self.lambda1, self.lambda2], self.output - ) - return op - - -op_bench_c2.generate_c2_test( - batch_box_cox_long_configs + batch_box_cox_short_configs, BatchBoxCoxBenchmark -) - - -if __name__ == "__main__": - op_bench.benchmark_runner.main() diff --git a/benchmarks/operator_benchmark/c2/batch_gather_test.py b/benchmarks/operator_benchmark/c2/batch_gather_test.py deleted file mode 100644 index c0ff2c06f061..000000000000 --- a/benchmarks/operator_benchmark/c2/batch_gather_test.py +++ /dev/null @@ -1,58 +0,0 @@ -import benchmark_caffe2 as op_bench_c2 -import numpy -from benchmark_caffe2 import Caffe2BenchmarkBase # noqa: F401 - -import operator_benchmark as op_bench -from caffe2.python import core - - -"""Microbenchmarks for element-wise BatchGather operator.""" - -# Configs for C2 BatherGather operator -batch_gather_configs_short = op_bench.config_list( - attr_names=["M", "N", "K"], - attrs=[ - [8, 8, 1], - [256, 512, 1], - [512, 512, 1], - [8, 8, 2], - [256, 512, 2], - [512, 512, 2], - ], - cross_product_configs={ - "device": ["cpu", "cuda"], - }, - tags=["short"], -) - -batch_gather_configs_long = op_bench.cross_product_configs( - M=[128, 1024], N=[128, 1024], K=[1, 2], device=["cpu", "cuda"], tags=["long"] -) - - -class BatchGatherBenchmark(op_bench_c2.Caffe2BenchmarkBase): - def init(self, M, N, K, device): - self.input_one = self.tensor([M, N, K], device=device) - max_val = N - numpy.random.seed((1 << 32) - 1) - index_dim = numpy.random.randint(0, N) - self.index = self.feed_tensor( - numpy.random.randint(0, max_val, index_dim), device=device - ) - self.output = self.tensor([M, index_dim, K], device=device) - self.set_module_name("batch_gather") - - def forward(self): - op = core.CreateOperator( - "BatchGather", [self.input_one, self.index], self.output - ) - return op - - -op_bench_c2.generate_c2_test( - batch_gather_configs_long + batch_gather_configs_short, BatchGatherBenchmark -) - - -if __name__ == "__main__": - op_bench.benchmark_runner.main() diff --git a/benchmarks/operator_benchmark/c2/clip_ranges_test.py b/benchmarks/operator_benchmark/c2/clip_ranges_test.py deleted file mode 100644 index 57bcd9858a8f..000000000000 --- a/benchmarks/operator_benchmark/c2/clip_ranges_test.py +++ /dev/null @@ -1,54 +0,0 @@ -import benchmark_caffe2 as op_bench_c2 -from benchmark_caffe2 import Caffe2BenchmarkBase # noqa: F401 - -import operator_benchmark as op_bench -from caffe2.python import core, dyndep - -dyndep.InitOpsLibrary("@/caffe2/caffe2/fb/operators:clip_ranges_op") - -"""Microbenchmarks for ClipRanges operator.""" - -# Configs for C2 ClipRanges operator -clip_ranges_long_configs = op_bench.cross_product_configs( - LENGTH=range(1, 100), - M=[1], - N=[2], - MAX_LENGTH=range(1, 100), - dtype=["int32"], - tags=["long"], -) - - -clip_ranges_short_configs = op_bench.config_list( - attrs=[ - [6, 1, 2, 1, "int32"], - [7, 1, 2, 2, "int32"], - [8, 1, 2, 3, "int32"], - [9, 1, 2, 4, "int32"], - [10, 1, 2, 5, "int32"], - ], - attr_names=["LENGTH", "M", "N", "MAX_LENGTH", "dtype"], - tags=["short"], -) - - -class ClipRangesBenchmark(op_bench_c2.Caffe2BenchmarkBase): - def init(self, LENGTH, M, N, MAX_LENGTH, dtype): - self.input = self.tensor([LENGTH, M, N], dtype) - self.max_length = MAX_LENGTH - self.set_module_name("clip_ranges") - - def forward(self): - op = core.CreateOperator( - "ClipRanges", self.input, self.input, max_length=self.max_length - ) - return op - - -op_bench_c2.generate_c2_test( - clip_ranges_long_configs + clip_ranges_short_configs, ClipRangesBenchmark -) - - -if __name__ == "__main__": - op_bench.benchmark_runner.main() diff --git a/benchmarks/operator_benchmark/c2/concat_test.py b/benchmarks/operator_benchmark/c2/concat_test.py deleted file mode 100644 index 4e91c30f2a75..000000000000 --- a/benchmarks/operator_benchmark/c2/concat_test.py +++ /dev/null @@ -1,171 +0,0 @@ -import random - -import benchmark_caffe2 as op_bench_c2 -from benchmark_caffe2 import Caffe2BenchmarkBase # noqa: F401 - -import operator_benchmark as op_bench -from caffe2.python import core - - -"""Microbenchmarks for Concat operator. Supports both Caffe2/PyTorch.""" - -cross_product_configs = { - "device": ["cpu", "cuda"], - "dtype": ["float"], - "add_axis": [0], -} - -# Configs for C2 concat operator -cat_configs_short = op_bench.config_list( - attr_names=["sizes", "N", "axis"], - attrs=[ - [(1, 1, 1), 2, 0], # noqa: E241 - [(512, 512, 2), 2, 1], # noqa: E241 - [(128, 1024, 2), 2, 1], # noqa: E241 - ], - cross_product_configs=cross_product_configs, - tags=["short"], -) - -# Configs specific to static runtime feature - a fast runtime for pared down models -cat_configs_static_runtime = op_bench.config_list( - attr_names=["sizes", "N", "axis", "add_axis"], - attrs=[ - [(1, 40), 5, 1, 1], - [[(1, 160), (1, 14)], -1, 1, 0], - [[(1, 20, 40), (1, 4, 40), (1, 5, 40)], -1, 1, 0], - [[(1, 580), (1, 174)], -1, 1, 0], - [(20, 40), 5, 1, 1], - [[(20, 160), (20, 14)], -1, 1, 0], - [[(20, 20, 40), (20, 4, 40), (20, 5, 40)], -1, 1, 0], - [[(20, 580), (20, 174)], -1, 1, 0], - ], - cross_product_configs=cross_product_configs, - tags=["static_runtime"], -) - -cat_configs_long = op_bench.config_list( - attr_names=["sizes", "N", "axis"], - attrs=[ - [(2**10, 2**10, 2), 2, 0], # noqa: E241 - [(2**10 + 1, 2**10 - 1, 2), 2, 1], # noqa: E226,E241 - [(2**10, 2**10, 2), 2, 2], # noqa: E241 - [ - [ - lambda: random.randint(2**6, 2**7), - 2**7 - 17, - 2**6 + 1, - ], # noqa: E201,E226,E241 - 5, - 0, - ], - [ - [ - 2**6 + 2**5, - lambda: random.randint(2**6, 2**7), - 2**6, - ], # noqa: E201,E226,E241,E272 - 5, - 1, - ], - [ - [ - 2**7, - 2**6, - lambda: random.randint(2**6, 2**7), - ], # noqa: E201,E241,E272 - 5, - 2, - ], - [[lambda: random.randint(2**5, 2**6), 2**5, 2**6], 50, 0], # noqa: E241 - [ - [2**5, lambda: random.randint(2**5, 2**6), 2**6], # noqa: E241,E272 - 50, - 1, - ], - [ - [ - 2**5 + 1, - 2**6 + 1, - lambda: random.randint(2**5, 2**6), - ], # noqa: E226,E241,E272 - 50, - 2, - ], - ], - cross_product_configs=cross_product_configs, - tags=["long"], -) - -# There is a different codepath on CUDA for >4 dimensions -cat_configs_multidim = op_bench.config_list( - attr_names=["sizes", "N", "axis", "dtype"], - attrs=[ - [(2**6, 2**5, 2**2, 2**4, 2**5), 2, 2], # noqa: E241 - [(2**4, 2**5, 2**2, 2**4, 2**5), 8, 2], # noqa: E241 - [ - (2**3 + 1, 2**5 - 1, 2**2 + 1, 2**4 - 1, 2**5 + 1), - 17, - 4, - ], # noqa: E226,E241 - ], - cross_product_configs=cross_product_configs, - tags=["multidim"], -) - -cat_configs_manyinputs = op_bench.config_list( - attr_names=["sizes", "N", "axis"], - attrs=[ - [[lambda: random.randint(1, 10000)], 100, 0], - [[lambda: random.randint(1, 1000)], 1000, 0], - [[lambda: random.randint(1, 500)], 2000, 0], - [[lambda: random.randint(1, 300)], 3000, 0], - ], - cross_product_configs=cross_product_configs, - tags=["manyinputs"], -) - - -class ConcatBenchmark(op_bench_c2.Caffe2BenchmarkBase): - def init(self, sizes, N, axis, add_axis, dtype, device): - random.seed(42) - self.inputs = [] - self.args = {"axis": axis, "add_axis": add_axis} - gen_sizes = [] - if type(sizes) == list and N == -1: - gen_sizes = sizes - else: - for i in range(N): - gen_sizes.append( - [ - old_size() if callable(old_size) else old_size - for old_size in sizes - ] - ) - - for s in gen_sizes: - self.inputs.append(self.tensor(s, dtype, device=device)) - - self.output = self.tensor(gen_sizes[0], dtype, device=device) - self.split_info = self.tensor(gen_sizes[0], "int") - self.set_module_name("concat") - - def forward(self): - op = core.CreateOperator( - "Concat", self.inputs, [self.output, self.split_info], **self.args - ) - return op - - -op_bench_c2.generate_c2_test( - cat_configs_short - + cat_configs_long - + cat_configs_multidim - + cat_configs_manyinputs - + cat_configs_static_runtime, - ConcatBenchmark, -) - - -if __name__ == "__main__": - op_bench.benchmark_runner.main() diff --git a/benchmarks/operator_benchmark/c2/matmul_test.py b/benchmarks/operator_benchmark/c2/matmul_test.py deleted file mode 100644 index 72bc4c78d710..000000000000 --- a/benchmarks/operator_benchmark/c2/matmul_test.py +++ /dev/null @@ -1,50 +0,0 @@ -import benchmark_caffe2 as op_bench_c2 -from benchmark_caffe2 import Caffe2BenchmarkBase # noqa: F401 - -import operator_benchmark as op_bench -from caffe2.python import core - -"""Microbenchmarks for MatMul operator""" - -# Configs for C2 Matmul operator -mm_long_configs = op_bench.cross_product_configs( - M=[8, 64, 128], - N=range(2, 10, 3), - K=[2**x for x in range(0, 3)], - trans_a=[True, False], - trans_b=[True, False], - tags=["long"], -) - - -mm_short_configs = op_bench.config_list( - attrs=[ - [128, 128, 128, False, True], - [1024, 1024, 256, True, False], - [8192, 8192, 1024, True, False], - ], - attr_names=["M", "N", "K", "trans_a", "trans_b"], - tags=["short"], -) - - -class MatMulBenchmark(op_bench_c2.Caffe2BenchmarkBase): - def init(self, M, N, K, trans_a, trans_b): - self.input_one = self.tensor([N, M]) if trans_a else self.tensor([M, N]) - self.input_two = self.tensor([K, N]) if trans_b else self.tensor([N, K]) - self.args = {"trans_a": trans_a, "trans_b": trans_b} - self.output = self.tensor([M, K]) - self.set_module_name("matmul") - - def forward(self): - op = core.CreateOperator( - "MatMul", [self.input_one, self.input_two], self.output, **self.args - ) - return op - - -op_bench_c2.generate_c2_test(mm_long_configs + mm_short_configs, MatMulBenchmark) - - -if __name__ == "__main__": - op_bench.benchmark_runner.main() diff --git a/benchmarks/operator_benchmark/c2/quantile_op_test.py b/benchmarks/operator_benchmark/c2/quantile_op_test.py deleted file mode 100644 index 296b6bf189e3..000000000000 --- a/benchmarks/operator_benchmark/c2/quantile_op_test.py +++ /dev/null @@ -1,48 +0,0 @@ -import benchmark_caffe2 as op_bench_c2 -from benchmark_caffe2 import Caffe2BenchmarkBase # noqa: F401 - -import operator_benchmark as op_bench -from caffe2.python import core - - -"""Microbenchmarks for QuantileOp operator.""" - -# Configs for C2 QuantileOp operator -quantile_op_long_configs = op_bench.cross_product_configs( - M=[32, 64, 128], N=range(32, 128, 32), dtype=["float", "double"], tags=["long"] -) - - -quantile_op_short_configs = op_bench.config_list( - attrs=[ - [16, 16, "float"], - [16, 16, "double"], - [64, 64, "float"], - [64, 64, "double"], - ], - attr_names=["M", "N", "dtype"], - tags=["short"], -) - - -class QuantileOpBenchmark(op_bench_c2.Caffe2BenchmarkBase): - def init(self, M, N, dtype): - self.data = [self.tensor([N], dtype) for _ in range(M)] - self.quantile = 0.3 - self.output = self.tensor([1], dtype) - self.set_module_name("quantile_op") - - def forward(self): - op = core.CreateOperator( - "Quantile", inputs=self.data, outputs=self.output, quantile=self.quantile - ) - return op - - -op_bench_c2.generate_c2_test( - quantile_op_long_configs + quantile_op_short_configs, QuantileOpBenchmark -) - - -if __name__ == "__main__": - op_bench.benchmark_runner.main() diff --git a/benchmarks/operator_benchmark/c2/replace_nan_test.py b/benchmarks/operator_benchmark/c2/replace_nan_test.py deleted file mode 100644 index c735a69b4ab4..000000000000 --- a/benchmarks/operator_benchmark/c2/replace_nan_test.py +++ /dev/null @@ -1,44 +0,0 @@ -import benchmark_caffe2 as op_bench_c2 -from benchmark_caffe2 import Caffe2BenchmarkBase # noqa: F401 - -import operator_benchmark as op_bench -from caffe2.python import core - - -"""Microbenchmarks for element-wise ReplaceNaN operator.""" - -# Configs for C2 ReplaceNaN operator -replace_nan_long_configs = op_bench.cross_product_configs( - M=[32, 64, 128], N=range(32, 128, 32), dtype=["float", "double"], tags=["long"] -) - - -replace_nan_short_configs = op_bench.config_list( - attrs=[ - [16, 16, "float"], - [16, 16, "double"], - [64, 64, "float"], - [64, 64, "double"], - ], - attr_names=["M", "N", "dtype"], - tags=["short"], -) - - -class ReplaceNaNBenchmark(op_bench_c2.Caffe2BenchmarkBase): - def init(self, M, N, dtype): - self.input = self.tensor([M, N], dtype) - self.set_module_name("replace_nan") - - def forward(self): - op = core.CreateOperator("ReplaceNaN", self.input, self.input, value=1.0) - return op - - -op_bench_c2.generate_c2_test( - replace_nan_long_configs + replace_nan_short_configs, ReplaceNaNBenchmark -) - - -if __name__ == "__main__": - op_bench.benchmark_runner.main() diff --git a/benchmarks/operator_benchmark/common/tests/c2_cpu_gpu_forward_backward_test.py b/benchmarks/operator_benchmark/common/tests/c2_cpu_gpu_forward_backward_test.py deleted file mode 100644 index ff34a58533f9..000000000000 --- a/benchmarks/operator_benchmark/common/tests/c2_cpu_gpu_forward_backward_test.py +++ /dev/null @@ -1,41 +0,0 @@ -import operator_benchmark as op_bench - -from caffe2.python import core - - -add_configs = op_bench.cross_product_configs( - M=[8], N=[8], K=[8], tags=["short"], device=["cuda", "cpu"] -) - - -class AddBenchmark(op_bench.Caffe2BenchmarkBase): - def init(self, M, N, K, device): - self.set_module_name("add") - self.input_one = self.tensor([M, N, K], device=device) - self.input_two = self.tensor([M, N, K], device=device) - self.input_one_grad = self.tensor([M, N, K], device=device) - self.input_two_grad = self.tensor([M, N, K], device=device) - self.output = self.tensor([M, N, K], device=device) - - def forward(self): - op = core.CreateOperator( - "Add", [self.input_one, self.input_two], self.output, **self.args - ) - return op - - def backward(self): - grad_op = core.CreateOperator( - "AddGradient", - [self.output, self.input_one, self.input_two], - [self.input_one_grad, self.input_two_grad], - **self.args, - ) - return grad_op - - -op_bench.generate_c2_test(add_configs, AddBenchmark) -op_bench.generate_c2_gradient_test(add_configs, AddBenchmark) - - -if __name__ == "__main__": - op_bench.benchmark_runner.main() diff --git a/benchmarks/record_function_benchmark/record_function_bench.py b/benchmarks/record_function_benchmark/record_function_bench.py index 348c1cae7650..f42f9b0d647f 100644 --- a/benchmarks/record_function_benchmark/record_function_bench.py +++ b/benchmarks/record_function_benchmark/record_function_bench.py @@ -1,18 +1,13 @@ import argparse import sys -import torch -import torch.utils.benchmark as benchmark_utils - - -try: - from benchmarks.fastrnns.factory import lstm_creator -except ImportError: - from caffe2.benchmarks.fastrnns.factory import lstm_creator - +from benchmarks.fastrnns.factory import lstm_creator from torchvision.models import resnet50 +import torch +import torch.utils.benchmark as benchmark_utils + def prepare_lstm_jit(bench_args): model_def = lstm_creator( diff --git a/binaries/bench_gen/bench_gen.py b/binaries/bench_gen/bench_gen.py deleted file mode 100755 index 7523e76f8b14..000000000000 --- a/binaries/bench_gen/bench_gen.py +++ /dev/null @@ -1,118 +0,0 @@ -#!/usr/bin/env python3 - -import argparse -import ast - -from caffe2.python import brew, workspace - -from caffe2.python.model_helper import ModelHelper -from caffe2.python.predictor import mobile_exporter - - -def parse_kwarg(kwarg_str): - key, value = kwarg_str.split("=") - try: - value = ast.literal_eval(value) - except ValueError: - pass - return key, value - - -def main(args): - # User defined keyword arguments - kwargs = {"order": "NCHW", "use_cudnn": False} - kwargs.update(dict(args.kwargs)) - - model = ModelHelper(name=args.benchmark_name) - - op_type = args.operator # assumes a brew type op name - input_name = args.input_name - output_name = args.output_name - - iters = int(args.instances) - for i in range(iters): - input_blob_name = input_name + (str(i) if i > 0 and args.chain else "") - output_blob_name = output_name + str(i + 1) - add_op = getattr(brew, op_type) - add_op(model, input_blob_name, output_blob_name, **kwargs) - if args.chain: - input_name, output_name = output_name, input_name - - workspace.RunNetOnce(model.param_init_net) - - init_net, predict_net = mobile_exporter.Export(workspace, model.net, model.params) - - if args.debug: - print("init_net:") - for op in init_net.op: - print(" ", op.type, op.input, "-->", op.output) - print("predict_net:") - for op in predict_net.op: - print(" ", op.type, op.input, "-->", op.output) - - with open(args.predict_net, "wb") as f: - f.write(predict_net.SerializeToString()) - with open(args.init_net, "wb") as f: - f.write(init_net.SerializeToString()) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Utility to generate Caffe2 benchmark models." - ) - parser.add_argument("operator", help="Caffe2 operator to benchmark.") - parser.add_argument( - "-b", - "--blob", - help="Instantiate a blob --blob name=dim1,dim2,dim3", - action="append", - ) - parser.add_argument("--context", help="Context to run on.", default="CPU") - parser.add_argument( - "--kwargs", - help="kwargs to pass to operator.", - nargs="*", - type=parse_kwarg, - default=[], - ) - parser.add_argument( - "--init-net", - "--init_net", - help="Output initialization net.", - default="init_net.pb", - ) - parser.add_argument( - "--predict-net", - "--predict_net", - help="Output prediction net.", - default="predict_net.pb", - ) - parser.add_argument( - "--benchmark-name", - "--benchmark_name", - help="Name of the benchmark network", - default="benchmark", - ) - parser.add_argument( - "--input-name", "--input_name", help="Name of the input blob.", default="data" - ) - parser.add_argument( - "--output-name", - "--output_name", - help="Name of the output blob.", - default="output", - ) - parser.add_argument( - "--instances", help="Number of instances to run the operator.", default="1" - ) - parser.add_argument( - "-d", "--debug", help="Print debug information.", action="store_true" - ) - parser.add_argument( - "-c", - "--chain", - help="Chain ops together (create data dependencies)", - action="store_true", - ) - args = parser.parse_args() - main(args) diff --git a/scripts/appveyor/install.bat b/scripts/appveyor/install.bat deleted file mode 100644 index cd87d6273160..000000000000 --- a/scripts/appveyor/install.bat +++ /dev/null @@ -1,10 +0,0 @@ -:: Installation scripts for appveyor. - -@echo on - -if "%USE_CUDA%" == "ON" call %~dp0%install_cuda.bat - -:: Miniconda path for appveyor -set PATH=C:\Miniconda-x64;C:\Miniconda-x64\Scripts;%PATH% -:: Install numpy -conda install -y numpy diff --git a/scripts/appveyor/install_cuda.bat b/scripts/appveyor/install_cuda.bat deleted file mode 100644 index c8c86b002e5b..000000000000 --- a/scripts/appveyor/install_cuda.bat +++ /dev/null @@ -1,22 +0,0 @@ -@echo on - -appveyor DownloadFile ^ - https://developer.nvidia.com/compute/cuda/8.0/prod/local_installers/cuda_8.0.44_windows-exe ^ - -FileName cuda_8.0.44_windows.exe -appveyor Downloadfile ^ - http://developer.download.nvidia.com/compute/redist/cudnn/v5.1/cudnn-8.0-windows10-x64-v5.1.zip ^ - -FileName cudnn-8.0-windows10-x64-v5.1.zip - -cuda_8.0.44_windows.exe -s compiler_8.0 cublas_8.0 cublas_dev_8.0 cudart_8.0 curand_8.0 curand_dev_8.0 nvrtc_8.0 nvrtc_dev_8.0 -set PATH=%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v8.0\bin;%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v8.0\libnvvp;%PATH% - -7z x cudnn-8.0-windows10-x64-v5.1.zip -copy cuda\include\cudnn.h ^ - "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v8.0\include\" -copy cuda\lib\x64\cudnn.lib ^ - "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v8.0\lib\x64\" -copy cuda\bin\cudnn64_5.dll ^ - "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v8.0\bin\" - -:: Make sure that nvcc is working correctly. -nvcc -V || exit /b diff --git a/scripts/model_zoo/update-caffe2-models.py b/scripts/model_zoo/update-caffe2-models.py deleted file mode 100755 index 1053530d05c5..000000000000 --- a/scripts/model_zoo/update-caffe2-models.py +++ /dev/null @@ -1,175 +0,0 @@ -#! /usr/bin/env python3 - -import os -import subprocess -import sys -import tarfile -import tempfile - -from urllib.request import urlretrieve - -from caffe2.python.models.download import ( - deleteDirectory, - downloadFromURLToFile, - getURLFromName, -) - - -class SomeClass: - # largely copied from - # https://github.com/onnx/onnx-caffe2/blob/master/tests/caffe2_ref_test.py - def _download(self, model): - model_dir = self._caffe2_model_dir(model) - assert not os.path.exists(model_dir) - os.makedirs(model_dir) - for f in ["predict_net.pb", "init_net.pb", "value_info.json"]: - url = getURLFromName(model, f) - dest = os.path.join(model_dir, f) - try: - try: - downloadFromURLToFile(url, dest, show_progress=False) - except TypeError: - # show_progress not supported prior to - # Caffe2 78c014e752a374d905ecfb465d44fa16e02a28f1 - # (Sep 17, 2017) - downloadFromURLToFile(url, dest) - except Exception as e: - print(f"Abort: {e}") - print("Cleaning up...") - deleteDirectory(model_dir) - sys.exit(1) - - def _caffe2_model_dir(self, model): - caffe2_home = os.path.expanduser("~/.caffe2") - models_dir = os.path.join(caffe2_home, "models") - return os.path.join(models_dir, model) - - def _onnx_model_dir(self, model): - onnx_home = os.path.expanduser("~/.onnx") - models_dir = os.path.join(onnx_home, "models") - model_dir = os.path.join(models_dir, model) - return model_dir, os.path.dirname(model_dir) - - # largely copied from - # https://github.com/onnx/onnx/blob/master/onnx/backend/test/runner/__init__.py - def _prepare_model_data(self, model): - model_dir, models_dir = self._onnx_model_dir(model) - if os.path.exists(model_dir): - return - os.makedirs(model_dir) - url = f"https://s3.amazonaws.com/download.onnx/models/{model}.tar.gz" - - # On Windows, NamedTemporaryFile cannot be opened for a - # second time - download_file = tempfile.NamedTemporaryFile(delete=False) - try: - download_file.close() - print(f"Start downloading model {model} from {url}") - urlretrieve(url, download_file.name) - print("Done") - with tarfile.open(download_file.name) as t: - t.extractall(models_dir) - except Exception as e: - print(f"Failed to prepare data for model {model}: {e}") - raise - finally: - os.remove(download_file.name) - - -models = [ - "bvlc_alexnet", - "densenet121", - "inception_v1", - "inception_v2", - "resnet50", - # TODO currently onnx can't translate squeezenet :( - # 'squeezenet', - "vgg16", - # TODO currently vgg19 doesn't work in the CI environment, - # possibly due to OOM - # 'vgg19' -] - - -def download_models(): - sc = SomeClass() - for model in models: - print("update-caffe2-models.py: downloading", model) - caffe2_model_dir = sc._caffe2_model_dir(model) - onnx_model_dir, onnx_models_dir = sc._onnx_model_dir(model) - if not os.path.exists(caffe2_model_dir): - sc._download(model) - if not os.path.exists(onnx_model_dir): - sc._prepare_model_data(model) - - -def generate_models(): - sc = SomeClass() - for model in models: - print("update-caffe2-models.py: generating", model) - caffe2_model_dir = sc._caffe2_model_dir(model) - onnx_model_dir, onnx_models_dir = sc._onnx_model_dir(model) - subprocess.check_call(["echo", model]) - with open(os.path.join(caffe2_model_dir, "value_info.json")) as f: - value_info = f.read() - subprocess.check_call( - [ - "convert-caffe2-to-onnx", - "--caffe2-net-name", - model, - "--caffe2-init-net", - os.path.join(caffe2_model_dir, "init_net.pb"), - "--value-info", - value_info, - "-o", - os.path.join(onnx_model_dir, "model.pb"), - os.path.join(caffe2_model_dir, "predict_net.pb"), - ] - ) - subprocess.check_call( - ["tar", "-czf", model + ".tar.gz", model], cwd=onnx_models_dir - ) - - -def upload_models(): - sc = SomeClass() - for model in models: - print("update-caffe2-models.py: uploading", model) - onnx_model_dir, onnx_models_dir = sc._onnx_model_dir(model) - subprocess.check_call( - [ - "aws", - "s3", - "cp", - model + ".tar.gz", - f"s3://download.onnx/models/{model}.tar.gz", - "--acl", - "public-read", - ], - cwd=onnx_models_dir, - ) - - -def cleanup(): - sc = SomeClass() - for model in models: - onnx_model_dir, onnx_models_dir = sc._onnx_model_dir(model) - os.remove(os.path.join(os.path.dirname(onnx_model_dir), model + ".tar.gz")) - - -if __name__ == "__main__": - try: - subprocess.check_call(["aws", "sts", "get-caller-identity"]) - except: - print( - "update-caffe2-models.py: please run `aws configure` manually to set up credentials" - ) - sys.exit(1) - if sys.argv[1] == "download": - download_models() - if sys.argv[1] == "generate": - generate_models() - elif sys.argv[1] == "upload": - upload_models() - elif sys.argv[1] == "cleanup": - cleanup() diff --git a/scripts/model_zoo/update-models-from-caffe2.py b/scripts/model_zoo/update-models-from-caffe2.py deleted file mode 100644 index 3d4d4d5d1c0c..000000000000 --- a/scripts/model_zoo/update-models-from-caffe2.py +++ /dev/null @@ -1,372 +0,0 @@ -#! /usr/bin/env python3 - -import argparse -import glob -import json -import os -import shutil -import tarfile -import tempfile - -from urllib.request import urlretrieve - -import boto3 -import numpy as np -import onnx -import onnx.backend -from onnx import numpy_helper - -import caffe2.python.onnx.backend -import caffe2.python.onnx.frontend -import caffe2.python.workspace as c2_workspace -from caffe2.proto import caffe2_pb2 - -from caffe2.python.models.download import ( - deleteDirectory, - downloadFromURLToFile, - getURLFromName, -) - - -"""A script converting Caffe2 models to ONNX, and updating ONNX model zoos. - -Arguments: - -v, verbose - --local-dir, where we store the ONNX and Caffe2 models - --no-cache, ignore existing models in local-dir - --clean-test-data, delete all the existing test data when updating ONNX model zoo - --add-test-data, add add-test-data sets of test data for each ONNX model - --only-local, run locally (for testing purpose) - -Examples: - # store the data in /home/username/zoo-dir, delete existing test data, ignore local cache, - # and generate 3 sets of new test data - python update-caffe2-models.py --local-dir /home/username/zoo-dir --clean-test-data --no-cache --add-test-data 3 - -""" - -# TODO: Add GPU support - - -def upload_onnx_model(model_name, zoo_dir, backup=False, only_local=False): - if only_local: - print("No uploading in local only mode.") - return - model_dir = os.path.join(zoo_dir, model_name) - suffix = "-backup" if backup else "" - if backup: - print(f"Backing up the previous version of ONNX model {model_name}...") - rel_file_name = f"{model_name}{suffix}.tar.gz" - abs_file_name = os.path.join(zoo_dir, rel_file_name) - print(f"Compressing {model_name} model to {abs_file_name}") - with tarfile.open(abs_file_name, "w:gz") as f: - f.add(model_dir, arcname=model_name) - file_size = os.stat(abs_file_name).st_size - print( - f"Uploading {abs_file_name} ({float(file_size) / 1024 / 1024} MB) to s3 cloud..." - ) - client = boto3.client("s3", "us-east-1") - transfer = boto3.s3.transfer.S3Transfer(client) - transfer.upload_file( - abs_file_name, - "download.onnx", - f"models/latest/{rel_file_name}", - extra_args={"ACL": "public-read"}, - ) - - print(f"Successfully uploaded {rel_file_name} to s3!") - - -def download_onnx_model(model_name, zoo_dir, use_cache=True, only_local=False): - model_dir = os.path.join(zoo_dir, model_name) - if os.path.exists(model_dir): - if use_cache: - upload_onnx_model(model_name, zoo_dir, backup=True, only_local=only_local) - return - else: - shutil.rmtree(model_dir) - url = f"https://s3.amazonaws.com/download.onnx/models/latest/{model_name}.tar.gz" - - download_file = tempfile.NamedTemporaryFile(delete=False) - try: - download_file.close() - print( - f"Downloading ONNX model {model_name} from {url} and save in {download_file.name} ...\n" - ) - urlretrieve(url, download_file.name) - with tarfile.open(download_file.name) as t: - print(f"Extracting ONNX model {model_name} to {zoo_dir} ...\n") - t.extractall(zoo_dir) - except Exception as e: - print(f"Failed to download/backup data for ONNX model {model_name}: {e}") - if not os.path.exists(model_dir): - os.makedirs(model_dir) - finally: - os.remove(download_file.name) - - if not only_local: - upload_onnx_model(model_name, zoo_dir, backup=True, only_local=only_local) - - -def download_caffe2_model(model_name, zoo_dir, use_cache=True): - model_dir = os.path.join(zoo_dir, model_name) - if os.path.exists(model_dir): - if use_cache: - return - else: - shutil.rmtree(model_dir) - os.makedirs(model_dir) - - for f in ["predict_net.pb", "init_net.pb", "value_info.json"]: - url = getURLFromName(model_name, f) - dest = os.path.join(model_dir, f) - try: - try: - downloadFromURLToFile(url, dest, show_progress=False) - except TypeError: - # show_progress not supported prior to - # Caffe2 78c014e752a374d905ecfb465d44fa16e02a28f1 - # (Sep 17, 2017) - downloadFromURLToFile(url, dest) - except Exception as e: - print(f"Abort: {e}") - print("Cleaning up...") - deleteDirectory(model_dir) - raise - - -def caffe2_to_onnx(caffe2_model_name, caffe2_model_dir): - caffe2_init_proto = caffe2_pb2.NetDef() - caffe2_predict_proto = caffe2_pb2.NetDef() - - with open(os.path.join(caffe2_model_dir, "init_net.pb"), "rb") as f: - caffe2_init_proto.ParseFromString(f.read()) - caffe2_init_proto.name = f"{caffe2_model_name}_init" - with open(os.path.join(caffe2_model_dir, "predict_net.pb"), "rb") as f: - caffe2_predict_proto.ParseFromString(f.read()) - caffe2_predict_proto.name = caffe2_model_name - with open(os.path.join(caffe2_model_dir, "value_info.json"), "rb") as f: - value_info = json.loads(f.read()) - - print( - f"Converting Caffe2 model {caffe2_model_name} in {caffe2_model_dir} to ONNX format" - ) - onnx_model = caffe2.python.onnx.frontend.caffe2_net_to_onnx_model( - init_net=caffe2_init_proto, - predict_net=caffe2_predict_proto, - value_info=value_info, - ) - - return onnx_model, caffe2_init_proto, caffe2_predict_proto - - -def tensortype_to_ndarray(tensor_type): - shape = [] - for dim in tensor_type.shape.dim: - shape.append(dim.dim_value) - if tensor_type.elem_type == onnx.TensorProto.FLOAT: - type = np.float32 - elif tensor_type.elem_type == onnx.TensorProto.INT: - type = np.int32 - else: - raise - array = np.random.rand(*shape).astype(type) - return array - - -def generate_test_input_data(onnx_model, scale): - real_inputs_names = list( - {input.name for input in onnx_model.graph.input} - - {init.name for init in onnx_model.graph.initializer} - ) - real_inputs = [] - for name in real_inputs_names: - for input in onnx_model.graph.input: - if name == input.name: - real_inputs.append(input) - - test_inputs = [] - for input in real_inputs: - ndarray = tensortype_to_ndarray(input.type.tensor_type) - test_inputs.append((input.name, ndarray * scale)) - - return test_inputs - - -def generate_test_output_data(caffe2_init_net, caffe2_predict_net, inputs): - p = c2_workspace.Predictor(caffe2_init_net, caffe2_predict_net) - inputs_map = {input[0]: input[1] for input in inputs} - - output = p.run(inputs_map) - c2_workspace.ResetWorkspace() - return output - - -def onnx_verify(onnx_model, inputs, ref_outputs): - prepared = caffe2.python.onnx.backend.prepare(onnx_model) - onnx_inputs = [] - for input in inputs: - if isinstance(input, tuple): - onnx_inputs.append(input[1]) - else: - onnx_inputs.append(input) - onnx_outputs = prepared.run(inputs=onnx_inputs) - np.testing.assert_almost_equal(onnx_outputs, ref_outputs, decimal=3) - - -model_mapping = { - "bvlc_alexnet": "bvlc_alexnet", - "bvlc_googlenet": "bvlc_googlenet", - "bvlc_reference_caffenet": "bvlc_reference_caffenet", - "bvlc_reference_rcnn_ilsvrc13": "bvlc_reference_rcnn_ilsvrc13", - "densenet121": "densenet121", - #'finetune_flickr_style': 'finetune_flickr_style', - "inception_v1": "inception_v1", - "inception_v2": "inception_v2", - "resnet50": "resnet50", - "shufflenet": "shufflenet", - "squeezenet": "squeezenet_old", - #'vgg16': 'vgg16', - "vgg19": "vgg19", - "zfnet512": "zfnet512", -} - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Update the ONNX models.") - parser.add_argument("-v", action="store_true", default=False, help="verbose") - parser.add_argument( - "--local-dir", - type=str, - default=os.path.expanduser("~"), - help="local dir to store Caffe2 and ONNX models", - ) - parser.add_argument( - "--no-cache", - action="store_true", - default=False, - help="whether use local ONNX models", - ) - parser.add_argument( - "--clean-test-data", - action="store_true", - default=False, - help="remove the old test data", - ) - parser.add_argument( - "--add-test-data", type=int, default=0, help="add new test data" - ) - parser.add_argument( - "--only-local", - action="store_true", - default=False, - help="no upload including backup", - ) - - args = parser.parse_args() - delete_test_data = args.clean_test_data - add_test_data = args.add_test_data - use_cache = not args.no_cache - only_local = args.only_local - - root_dir = args.local_dir - caffe2_zoo_dir = os.path.join(root_dir, ".caffe2", "models") - onnx_zoo_dir = os.path.join(root_dir, ".onnx", "models") - - for onnx_model_name in model_mapping: - c2_model_name = model_mapping[onnx_model_name] - - print( - f"####### Processing ONNX model {onnx_model_name} ({c2_model_name} in Caffe2) #######" - ) - download_caffe2_model(c2_model_name, caffe2_zoo_dir, use_cache=use_cache) - download_onnx_model( - onnx_model_name, onnx_zoo_dir, use_cache=use_cache, only_local=only_local - ) - - onnx_model_dir = os.path.join(onnx_zoo_dir, onnx_model_name) - - if delete_test_data: - print("Deleting all the existing test data...") - # NB: For now, we don't delete the npz files. - # for f in glob.glob(os.path.join(onnx_model_dir, '*.npz')): - # os.remove(f) - for f in glob.glob(os.path.join(onnx_model_dir, "test_data_set*")): - shutil.rmtree(f) - - onnx_model, c2_init_net, c2_predict_net = caffe2_to_onnx( - c2_model_name, os.path.join(caffe2_zoo_dir, c2_model_name) - ) - - print(f"Deleteing old ONNX {onnx_model_name} model...") - for f in glob.glob(os.path.join(onnx_model_dir, "model*".format())): - os.remove(f) - - print(f"Serializing generated ONNX {onnx_model_name} model ...") - with open(os.path.join(onnx_model_dir, "model.onnx"), "wb") as file: - file.write(onnx_model.SerializeToString()) - - print(f"Verifying model {onnx_model_name} with ONNX model checker...") - onnx.checker.check_model(onnx_model) - - total_existing_data_set = 0 - print(f"Verifying model {onnx_model_name} with existing test data...") - for f in glob.glob(os.path.join(onnx_model_dir, "*.npz")): - test_data = np.load(f, encoding="bytes") - inputs = list(test_data["inputs"]) - ref_outputs = list(test_data["outputs"]) - onnx_verify(onnx_model, inputs, ref_outputs) - total_existing_data_set += 1 - for f in glob.glob(os.path.join(onnx_model_dir, "test_data_set*")): - inputs = [] - inputs_num = len(glob.glob(os.path.join(f, "input_*.pb"))) - for i in range(inputs_num): - tensor = onnx.TensorProto() - with open(os.path.join(f, f"input_{i}.pb"), "rb") as pf: - tensor.ParseFromString(pf.read()) - inputs.append(numpy_helper.to_array(tensor)) - ref_outputs = [] - ref_outputs_num = len(glob.glob(os.path.join(f, "output_*.pb"))) - for i in range(ref_outputs_num): - tensor = onnx.TensorProto() - with open(os.path.join(f, f"output_{i}.pb"), "rb") as pf: - tensor.ParseFromString(pf.read()) - ref_outputs.append(numpy_helper.to_array(tensor)) - onnx_verify(onnx_model, inputs, ref_outputs) - total_existing_data_set += 1 - - starting_index = 0 - while os.path.exists( - os.path.join(onnx_model_dir, f"test_data_set_{starting_index}") - ): - starting_index += 1 - - if total_existing_data_set == 0 and add_test_data == 0: - add_test_data = 3 - total_existing_data_set = 3 - - print(f"Generating {add_test_data} sets of new test data...") - for i in range(starting_index, add_test_data + starting_index): - data_dir = os.path.join(onnx_model_dir, f"test_data_set_{i}") - os.makedirs(data_dir) - inputs = generate_test_input_data(onnx_model, 255) - ref_outputs = generate_test_output_data(c2_init_net, c2_predict_net, inputs) - onnx_verify(onnx_model, inputs, ref_outputs) - for index, input in enumerate(inputs): - tensor = numpy_helper.from_array(input[1]) - with open(os.path.join(data_dir, f"input_{index}.pb"), "wb") as file: - file.write(tensor.SerializeToString()) - for index, output in enumerate(ref_outputs): - tensor = numpy_helper.from_array(output) - with open(os.path.join(data_dir, f"output_{index}.pb"), "wb") as file: - file.write(tensor.SerializeToString()) - - del onnx_model - del c2_init_net - del c2_predict_net - - upload_onnx_model( - onnx_model_name, onnx_zoo_dir, backup=False, only_local=only_local - ) - - print("\n\n") From 8bcebc8daee3424587c6ce9071f5392eb6166b10 Mon Sep 17 00:00:00 2001 From: albanD Date: Wed, 5 Jun 2024 23:59:36 +0000 Subject: [PATCH 389/706] Add runtime dependency on setuptools for cpp_extensions (#127921) As per title since this was removed from the builtin python binary in 3.12 and we use it `torch.utils.cpp_extension.*`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127921 Approved by: https://github.com/Skylion007 --- setup.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/setup.py b/setup.py index e2529335bcc6..07d80a7e1392 100644 --- a/setup.py +++ b/setup.py @@ -1137,6 +1137,9 @@ def main(): 'mkl>=2021.1.1,<=2021.4.0; platform_system == "Windows"', ] + if sys.version_info >= (3, 12, 0): + install_requires.append("setuptools") + if BUILD_PYTHON_ONLY: install_requires.append(LIBTORCH_PKG_NAME) From d3ad84c38f5b06bd0278bec6dc6a2d42e12fa97e Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Wed, 5 Jun 2024 13:06:01 -0400 Subject: [PATCH 390/706] Use pexpr, not texpr in Triton launch codegen (#128038) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/128038 Approved by: https://github.com/Skylion007 --- torch/_inductor/select_algorithm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index bc89441e3bd8..5e5cbf35baf9 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -40,6 +40,7 @@ ) from .codegen.triton_utils import config_of, signature_to_meta +from .codegen.wrapper import pexpr from .exc import CUDACompileError from .ir import ChoiceCaller, PrimitiveInfoType from .runtime.hints import DeviceProperties @@ -537,7 +538,7 @@ def call_kernel(self, name: str, node: Optional[ir.IRNode] = None): meta = wrapper.add_meta_once(self.meta) grid_call = [ - texpr(V.graph.sizevars.simplify(s)) for s in self.call_sizes + pexpr(V.graph.sizevars.simplify(s)) for s in self.call_sizes ] + [meta] grid_call = f"{self.grid_fn.__module__}.{self.grid_fn.__name__}({', '.join(grid_call)})" wrapper.writeline( From 80d34217c6e10e13325fa29d16062dd2bba0ab1e Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Thu, 6 Jun 2024 01:03:22 +0000 Subject: [PATCH 391/706] Typo fixes: et al. (#127811) "et al." is short for _et alia_ and should be abbreviated with a period on the second word. Noticed this typo when reading through the SGD docs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127811 Approved by: https://github.com/janeyx99 --- aten/src/ATen/native/LinearAlgebra.cpp | 2 +- aten/src/ATen/native/LossCTC.cpp | 4 ++-- aten/src/ATen/native/Math.h | 2 +- aten/src/ATen/native/cuda/LossCTC.cu | 4 ++-- aten/src/ATen/native/native_functions.yaml | 2 +- .../functional_autograd_benchmark/torchaudio_models.py | 2 +- torch/_lowrank.py | 6 +++--- torch/_numpy/_funcs_impl.py | 4 ++-- torch/distributions/one_hot_categorical.py | 2 +- torch/distributions/relaxed_bernoulli.py | 4 ++-- torch/distributions/relaxed_categorical.py | 4 ++-- torch/nn/modules/pixelshuffle.py | 4 ++-- torch/optim/sgd.py | 4 ++-- torch/testing/_internal/common_nn.py | 2 +- 14 files changed, 23 insertions(+), 23 deletions(-) diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 3389033ac985..6015a3b509b0 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -856,7 +856,7 @@ namespace { /** * @brief Computes the optimal matrix chain multiplication order * - * Follows the dynamic programming algorithm from Cormen et al, + * Follows the dynamic programming algorithm from Cormen et al., * "Introduction to Algorithms, Third Edition", Chapter 15.2, * p. 370-378. Note that the book uses 1-based indexing. * diff --git a/aten/src/ATen/native/LossCTC.cpp b/aten/src/ATen/native/LossCTC.cpp index b13ed7e2ce92..6848abe70ec7 100644 --- a/aten/src/ATen/native/LossCTC.cpp +++ b/aten/src/ATen/native/LossCTC.cpp @@ -2,9 +2,9 @@ // Licensed under the BSD-3-Clause license // This is the CPU implementation of the Connectionist Temporal Loss. // We mostly follow Graves. -// 1. Graves et al: http://www.cs.toronto.edu/~graves/icml_2006.pdf +// 1. Graves et al.: http://www.cs.toronto.edu/~graves/icml_2006.pdf // We use the equations from above link, but note that [1] has 1-based indexing and we (of course) use 0-based. -// Graves et al call the probabilities y, we use log_probs (also calling them inputs) +// Graves et al. call the probabilities y, we use log_probs (also calling them inputs) #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include diff --git a/aten/src/ATen/native/Math.h b/aten/src/ATen/native/Math.h index 092ee00992e9..8296d6cf60a2 100644 --- a/aten/src/ATen/native/Math.h +++ b/aten/src/ATen/native/Math.h @@ -508,7 +508,7 @@ static inline C10_HOST_DEVICE scalar_t calc_polygamma(scalar_t x, int n) { /* References * [igam1] "The Digital Library of Mathematical Functions", dlmf.nist.gov - * [igam2] Maddock et. al., "Incomplete Gamma Functions", + * [igam2] Maddock et al., "Incomplete Gamma Functions", * https://www.boost.org/doc/libs/1_61_0/libs/math/doc/html/math_toolkit/sf_gamma/igamma.html */ diff --git a/aten/src/ATen/native/cuda/LossCTC.cu b/aten/src/ATen/native/cuda/LossCTC.cu index b451592f1944..f559625e6b0a 100644 --- a/aten/src/ATen/native/cuda/LossCTC.cu +++ b/aten/src/ATen/native/cuda/LossCTC.cu @@ -2,9 +2,9 @@ // Licensed under the BSD-3-Clause license // This is the GPU implementation of the Connectionist Temporal Loss. // We mostly follow Graves. -// 1. Graves et al: http://www.cs.toronto.edu/~graves/icml_2006.pdf +// 1. Graves et al.: http://www.cs.toronto.edu/~graves/icml_2006.pdf // We use the equations from above link, but note that [1] has 1-based indexing and we (of course) use 0-based. -// Graves et al call the probabilities y, we use log_probs (also calling them inputs) +// Graves et al. call the probabilities y, we use log_probs (also calling them inputs) // A few optimizations (similar to those here, but also some I didn't take) are described in // 2. Minmin Sun: http://on-demand.gputechconf.com/gtc/2016/presentation/s6383-minmin-sun-speech-recognition.pdf #define TORCH_ASSERT_ONLY_METHOD_OPERATORS diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index aa3d9b3fb2f5..b7314756cec5 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -5426,7 +5426,7 @@ autogen: slice_backward.out # NB: This op exists to back the implementation of reverse view_funcs for various views (chunk, -# slice.Tensor, split_with_sizes, et. al.). Currently, these are only used during fake-ification +# slice.Tensor, split_with_sizes, et al.). Currently, these are only used during fake-ification # of PT2 graph input subclass instances that are views. This means: # * This op shouldn't really show up in eager mode (so e.g. XLA shouldn't have to implement it) # * This op shouldn't show up in a PT2 graph (so a PT2 backend shouldn't have to implement it) diff --git a/benchmarks/functional_autograd_benchmark/torchaudio_models.py b/benchmarks/functional_autograd_benchmark/torchaudio_models.py index e63db1d7cc02..04dc8969e329 100644 --- a/benchmarks/functional_autograd_benchmark/torchaudio_models.py +++ b/benchmarks/functional_autograd_benchmark/torchaudio_models.py @@ -219,7 +219,7 @@ def forward(self, x, output_lengths): class Lookahead(nn.Module): - # Wang et al 2016 - Lookahead Convolution Layer for Unidirectional Recurrent Neural Networks + # Wang et al., 2016 - Lookahead Convolution Layer for Unidirectional Recurrent Neural Networks # input shape - sequence, batch, feature - TxNxH # output shape - same as input def __init__(self, n_features, context): diff --git a/torch/_lowrank.py b/torch/_lowrank.py index c739cc37178e..4641c4c4717c 100644 --- a/torch/_lowrank.py +++ b/torch/_lowrank.py @@ -21,7 +21,7 @@ def get_approximate_basis( of the size of :math:`A` or :math:`M`. .. note:: The implementation is based on the Algorithm 4.4 from - Halko et al, 2009. + Halko et al., 2009. .. note:: For an adequate approximation of a k-rank matrix :math:`A`, where k is not known in advance but could be @@ -94,7 +94,7 @@ def svd_lowrank( SVD is computed for the matrix :math:`A - M`. .. note:: The implementation is based on the Algorithm 5.1 from - Halko et al, 2009. + Halko et al., 2009. .. note:: For an adequate approximation of a k-rank matrix :math:`A`, where k is not known in advance but could be @@ -152,7 +152,7 @@ def _svd_lowrank( niter: Optional[int] = 2, M: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor, Tensor]: - # Algorithm 5.1 in Halko et al 2009 + # Algorithm 5.1 in Halko et al., 2009 q = 6 if q is None else q m, n = A.shape[-2:] diff --git a/torch/_numpy/_funcs_impl.py b/torch/_numpy/_funcs_impl.py index de165d5db768..93f8a8ab1198 100644 --- a/torch/_numpy/_funcs_impl.py +++ b/torch/_numpy/_funcs_impl.py @@ -941,7 +941,7 @@ def choose( return choices[idx_list].squeeze(0) -# ### unique et al ### +# ### unique et al. ### def unique( @@ -1021,7 +1021,7 @@ def resize(a: ArrayLike, new_shape=None): return reshape(a, new_shape) -# ### diag et al ### +# ### diag et al. ### def diagonal(a: ArrayLike, offset=0, axis1=0, axis2=1): diff --git a/torch/distributions/one_hot_categorical.py b/torch/distributions/one_hot_categorical.py index 37e62e874f5e..2fdf5ff6c0ae 100644 --- a/torch/distributions/one_hot_categorical.py +++ b/torch/distributions/one_hot_categorical.py @@ -119,7 +119,7 @@ class OneHotCategoricalStraightThrough(OneHotCategorical): through gradient estimator from [1]. [1] Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation - (Bengio et al, 2013) + (Bengio et al., 2013) """ has_rsample = True diff --git a/torch/distributions/relaxed_bernoulli.py b/torch/distributions/relaxed_bernoulli.py index 05e0995e4a33..a41e1be1f029 100644 --- a/torch/distributions/relaxed_bernoulli.py +++ b/torch/distributions/relaxed_bernoulli.py @@ -30,10 +30,10 @@ class LogitRelaxedBernoulli(Distribution): logits (Number, Tensor): the log-odds of sampling `1` [1] The Concrete Distribution: A Continuous Relaxation of Discrete Random - Variables (Maddison et al, 2017) + Variables (Maddison et al., 2017) [2] Categorical Reparametrization with Gumbel-Softmax - (Jang et al, 2017) + (Jang et al., 2017) """ arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real} support = constraints.real diff --git a/torch/distributions/relaxed_categorical.py b/torch/distributions/relaxed_categorical.py index 245ab87aa2a7..707a80d05415 100644 --- a/torch/distributions/relaxed_categorical.py +++ b/torch/distributions/relaxed_categorical.py @@ -26,10 +26,10 @@ class ExpRelaxedCategorical(Distribution): logits (Tensor): unnormalized log probability for each event [1] The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables - (Maddison et al, 2017) + (Maddison et al., 2017) [2] Categorical Reparametrization with Gumbel-Softmax - (Jang et al, 2017) + (Jang et al., 2017) """ arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector} support = ( diff --git a/torch/nn/modules/pixelshuffle.py b/torch/nn/modules/pixelshuffle.py index 6050b7eaea60..e6136350b3a4 100644 --- a/torch/nn/modules/pixelshuffle.py +++ b/torch/nn/modules/pixelshuffle.py @@ -16,7 +16,7 @@ class PixelShuffle(Module): See the paper: `Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network`_ - by Shi et. al (2016) for more details. + by Shi et al. (2016) for more details. Args: upscale_factor (int): factor to increase spatial resolution by @@ -69,7 +69,7 @@ class PixelUnshuffle(Module): See the paper: `Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network`_ - by Shi et. al (2016) for more details. + by Shi et al. (2016) for more details. Args: downscale_factor (int): factor to decrease spatial resolution by diff --git a/torch/optim/sgd.py b/torch/optim/sgd.py index c0efc2443078..a95574a65aba 100644 --- a/torch/optim/sgd.py +++ b/torch/optim/sgd.py @@ -208,7 +208,7 @@ def step(self, closure=None): .. note:: The implementation of SGD with Momentum/Nesterov subtly differs from - Sutskever et. al. and implementations in some other frameworks. + Sutskever et al. and implementations in some other frameworks. Considering the specific case of Momentum, the update can be written as @@ -221,7 +221,7 @@ def step(self, closure=None): where :math:`p`, :math:`g`, :math:`v` and :math:`\mu` denote the parameters, gradient, velocity, and momentum respectively. - This is in contrast to Sutskever et. al. and + This is in contrast to Sutskever et al. and other frameworks which employ an update of the form .. math:: diff --git a/torch/testing/_internal/common_nn.py b/torch/testing/_internal/common_nn.py index c11314721f27..0505c749a7f9 100644 --- a/torch/testing/_internal/common_nn.py +++ b/torch/testing/_internal/common_nn.py @@ -3016,7 +3016,7 @@ def marginrankingloss_reference(input1, input2, target, margin=0, reduction='mea return output -# this directly follows Graves et al's paper, in contrast to the production implementation, it does not use log-space +# this directly follows Graves et al.'s paper, in contrast to the production implementation, it does not use log-space def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean'): input_lengths = torch.as_tensor(input_lengths, dtype=torch.long) target_lengths = torch.as_tensor(target_lengths, dtype=torch.long) From bf2c05352ea398e53bdf2072b1ae3a2fd174a4cd Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Wed, 5 Jun 2024 15:24:08 -0400 Subject: [PATCH 392/706] Make length == stop size oblivious too (#128050) This doesn't do anything right now (need some other PRs to activate) but since it edits a header file it would be better to land this earlier. Context: https://github.com/pytorch/pytorch/pull/127693 Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/128050 Approved by: https://github.com/Skylion007, https://github.com/lezcano --- aten/src/ATen/TensorIndexing.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/TensorIndexing.h b/aten/src/ATen/TensorIndexing.h index b2ef33ffc058..fc951207f009 100644 --- a/aten/src/ATen/TensorIndexing.h +++ b/aten/src/ATen/TensorIndexing.h @@ -218,8 +218,8 @@ static inline Tensor applySlice( ? (*self_sizes)[dim] : self.sym_size(dim); if (!disable_slice_optimization && - TORCH_GUARD_SIZE_OBLIVIOUS(start.sym_eq(0)) && length == stop && - step == 1) { + TORCH_GUARD_SIZE_OBLIVIOUS(start.sym_eq(0)) && + TORCH_GUARD_SIZE_OBLIVIOUS(length.sym_eq(stop)) && step == 1) { return self; } } From 6adcf21b2be9217b54a6f2960751059d3bcbac20 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 6 Jun 2024 01:13:07 +0000 Subject: [PATCH 393/706] Documenting the torch.cuda.nccl.version function (#128022) Fixes #127892 This PR adds docstring to the torch.cuda.nccl.version function Pull Request resolved: https://github.com/pytorch/pytorch/pull/128022 Approved by: https://github.com/malfet --- torch/cuda/nccl.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/torch/cuda/nccl.py b/torch/cuda/nccl.py index f1332c968d69..67d528771215 100644 --- a/torch/cuda/nccl.py +++ b/torch/cuda/nccl.py @@ -32,6 +32,15 @@ def is_available(tensors): def version(): + """ + Returns the version of the NCCL. + + + This function returns a tuple containing the major, minor, and patch version numbers of the NCCL. + The suffix is also included in the tuple if a version suffix exists. + Returns: + tuple: The version information of the NCCL. + """ ver = torch._C._nccl_version() major = ver >> 32 minor = (ver >> 16) & 65535 From b4a01614499b71513b118f25293896d841aabbd9 Mon Sep 17 00:00:00 2001 From: sdp Date: Thu, 6 Jun 2024 01:41:06 +0000 Subject: [PATCH 394/706] Build SYCL kernels for ATen XPU ops on Native Windows (take 2) (#127390) Original PR https://github.com/pytorch/pytorch/pull/126725 is closed due to bad rebase. ------- As proposed in https://github.com/pytorch/pytorch/issues/126719, we are enabling PyTorch XPU on Native Windows on Intel GPU. This PR enables XPU build on Windows as the first step of #126719: - Enable `USE_XPU` build on Windows using MSVC as host compiler. The use of MSVC as host compiler seamlessly aligns with the existing PyTorch build on Windows. - Build oneDNN GPU library on Windows. Co-authored-by: Yu, Guangye Pull Request resolved: https://github.com/pytorch/pytorch/pull/127390 Approved by: https://github.com/guangyey, https://github.com/EikanWang, https://github.com/gujinghui, https://github.com/ezyang --- CMakeLists.txt | 3 +-- README.md | 2 +- .../native/mkldnn/xpu/detail/oneDNNContext.h | 6 +++--- aten/src/ATen/xpu/XPUGeneratorImpl.h | 2 +- aten/src/ATen/xpu/detail/XPUHooks.cpp | 12 ++++++++++++ c10/util/Float8_fnuz_cvt.h | 6 ++++++ c10/xpu/CMakeLists.txt | 10 ++++++++++ c10/xpu/XPUFunctions.cpp | 12 ++++++++++-- c10/xpu/XPUMacros.h | 14 ++++++++++++++ c10/xpu/impl/xpu_cmake_macros.h.in | 6 ++++++ caffe2/CMakeLists.txt | 11 +++++++++-- cmake/Dependencies.cmake | 4 ++-- cmake/Modules/FindMKLDNN.cmake | 19 ++++++++++++------- cmake/Modules/FindSYCLToolkit.cmake | 17 +++++++++++++++++ torch/csrc/xpu/Module.cpp | 6 ++++++ 15 files changed, 110 insertions(+), 20 deletions(-) create mode 100644 c10/xpu/impl/xpu_cmake_macros.h.in diff --git a/CMakeLists.txt b/CMakeLists.txt index cd11ffdf7333..1264540c6875 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -242,8 +242,7 @@ option(USE_COLORIZE_OUTPUT "Colorize output during compilation" ON) option(USE_ASAN "Use Address+Undefined Sanitizers" OFF) option(USE_TSAN "Use Thread Sanitizer" OFF) option(USE_CUDA "Use CUDA" ON) -cmake_dependent_option(USE_XPU "Use XPU. Only available on Linux." ON "LINUX" - OFF) +option(USE_XPU "Use XPU" ON) cmake_dependent_option( BUILD_LAZY_CUDA_LINALG "Build cuda linalg ops as separate library" ON "USE_CUDA AND LINUX AND BUILD_PYTHON" OFF) diff --git a/README.md b/README.md index 9a4ba683d769..9123dea20107 100644 --- a/README.md +++ b/README.md @@ -189,7 +189,7 @@ Other potentially useful environment variables may be found in `setup.py`. ##### Intel GPU Support If you want to compile with Intel GPU support, follow these - [PyTorch Prerequisites for Intel GPUs](https://www.intel.com/content/www/us/en/developer/articles/tool/pytorch-prerequisites-for-intel-gpus.html) instructions. -- Intel GPU is currently supported only for Linux systems. +- Intel GPU is supported for Linux and Windows. If you want to disable Intel GPU support, export the environment variable `USE_XPU=0`. Other potentially useful environment variables may be found in `setup.py`. diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/oneDNNContext.h b/aten/src/ATen/native/mkldnn/xpu/detail/oneDNNContext.h index c7e7a5e94b40..afef4552c153 100644 --- a/aten/src/ATen/native/mkldnn/xpu/detail/oneDNNContext.h +++ b/aten/src/ATen/native/mkldnn/xpu/detail/oneDNNContext.h @@ -12,7 +12,7 @@ namespace at::native::onednn { -TORCH_API dnnl::memory make_onednn_memory( +TORCH_XPU_API dnnl::memory make_onednn_memory( dnnl::memory::desc md, dnnl::engine& engine, void* ptr); @@ -21,7 +21,7 @@ TORCH_API dnnl::memory make_onednn_memory( bool set_onednn_verbose(int level); // GpuEngineManager singleton -struct TORCH_API GpuEngineManager { +struct TORCH_XPU_API GpuEngineManager { static GpuEngineManager& Instance(); // Singleton dnnl::engine& get_engine(const Device& device) { @@ -51,7 +51,7 @@ struct TORCH_API GpuEngineManager { }; // GpuStreamManager singleton -struct TORCH_API GpuStreamManager { +struct TORCH_XPU_API GpuStreamManager { static GpuStreamManager& Instance(); // Singleton dnnl::stream get_stream() { diff --git a/aten/src/ATen/xpu/XPUGeneratorImpl.h b/aten/src/ATen/xpu/XPUGeneratorImpl.h index ce77d2e444e6..a1f264382a36 100644 --- a/aten/src/ATen/xpu/XPUGeneratorImpl.h +++ b/aten/src/ATen/xpu/XPUGeneratorImpl.h @@ -4,7 +4,7 @@ namespace at { -struct TORCH_API XPUGeneratorImpl : public GeneratorImpl { +struct TORCH_XPU_API XPUGeneratorImpl : public GeneratorImpl { // Constructors XPUGeneratorImpl(DeviceIndex device_index = -1); ~XPUGeneratorImpl() override = default; diff --git a/aten/src/ATen/xpu/detail/XPUHooks.cpp b/aten/src/ATen/xpu/detail/XPUHooks.cpp index 22f4ff22b4bb..61bc19faa95e 100644 --- a/aten/src/ATen/xpu/detail/XPUHooks.cpp +++ b/aten/src/ATen/xpu/detail/XPUHooks.cpp @@ -25,7 +25,13 @@ std::string XPUHooks::showConfig() const { int32_t XPUHooks::getGlobalIdxFromDevice(const at::Device& device) const { TORCH_CHECK(device.is_xpu(), "Only the XPU device type is expected."); +#ifdef _WIN32 + TORCH_CHECK( + false, + "Default context is not supported on XPU on Windows. So we can NOT find its global index of the ATen device."); +#else return at::xpu::getGlobalIdxFromDevice(device.index()); +#endif } Generator XPUHooks::getXPUGenerator(DeviceIndex device_index) const { @@ -38,7 +44,13 @@ const Generator& XPUHooks::getDefaultXPUGenerator( } Device XPUHooks::getDeviceFromPtr(void* data) const { +#ifdef _WIN32 + TORCH_CHECK( + false, + "Default context is not supported on XPU on Windows. So we can NOT find the ATen device of a pointer."); +#else return at::xpu::getDeviceFromPtr(data); +#endif } c10::DeviceIndex XPUHooks::getNumGPUs() const { diff --git a/c10/util/Float8_fnuz_cvt.h b/c10/util/Float8_fnuz_cvt.h index 983063a0230f..327f90d11a71 100644 --- a/c10/util/Float8_fnuz_cvt.h +++ b/c10/util/Float8_fnuz_cvt.h @@ -4,6 +4,10 @@ #include +#if defined(SYCL_LANGUAGE_VERSION) +#include +#endif + namespace c10::detail { /* @@ -33,6 +37,8 @@ inline C10_HOST_DEVICE float fp8_fnuz_to_fp32_value(uint8_t x) { // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above #if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) uint32_t renorm_shift = __clz(mantissa); +#elif defined(__SYCL_DEVICE_ONLY__) + uint32_t renorm_shift = sycl::clz(mantissa); #elif defined(_MSC_VER) unsigned long nonsign_bsr; _BitScanReverse(&nonsign_bsr, (unsigned long)mantissa); diff --git a/c10/xpu/CMakeLists.txt b/c10/xpu/CMakeLists.txt index d06d0f0aa92a..b5c63d4f7cca 100644 --- a/c10/xpu/CMakeLists.txt +++ b/c10/xpu/CMakeLists.txt @@ -8,6 +8,12 @@ if(NOT BUILD_LIBTORCHLESS) find_library(C10_XPU_LIB c10_xpu PATHS $ENV{LIBTORCH_LIB_PATH} NO_DEFAULT_PATH) endif() +# ---[ Configure macro file. +set(C10_XPU_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}) # used in xpu_cmake_macros.h.in +configure_file( + ${CMAKE_CURRENT_LIST_DIR}/impl/xpu_cmake_macros.h.in + ${CMAKE_BINARY_DIR}/c10/xpu/impl/xpu_cmake_macros.h) + set(C10_XPU_SRCS XPUCachingAllocator.cpp XPUFunctions.cpp @@ -50,3 +56,7 @@ foreach(file ${C10_XPU_HEADERS}) get_filename_component(dir ${file} DIRECTORY) install(FILES ${file} DESTINATION include/c10/xpu/${dir}) endforeach() + +if(MSVC AND C10_XPU_BUILD_SHARED_LIBS) + install(FILES $ DESTINATION lib OPTIONAL) +endif() diff --git a/c10/xpu/XPUFunctions.cpp b/c10/xpu/XPUFunctions.cpp index 15e24d94f5dc..cc885776a916 100644 --- a/c10/xpu/XPUFunctions.cpp +++ b/c10/xpu/XPUFunctions.cpp @@ -2,8 +2,6 @@ #include #include -#include -#include #include namespace c10::xpu { @@ -53,10 +51,20 @@ inline void initGlobalDevicePoolState() { return; } +#ifdef _WIN32 + // default context feature is disabled by default on Windows. + std::vector deviceList; + for (auto it = gDevicePool.devices.begin(); it != gDevicePool.devices.end(); + ++it) { + deviceList.push_back(*(*it)); + } + gDevicePool.context = std::make_unique(deviceList); +#else // The default context is utilized for each Intel GPU device, allowing the // retrieval of the context from any GPU device. gDevicePool.context = std::make_unique( gDevicePool.devices[0]->get_platform().ext_oneapi_get_default_context()); +#endif } inline void initDevicePoolCallOnce() { diff --git a/c10/xpu/XPUMacros.h b/c10/xpu/XPUMacros.h index fc6aad92229c..d51eab989d25 100644 --- a/c10/xpu/XPUMacros.h +++ b/c10/xpu/XPUMacros.h @@ -1,15 +1,29 @@ #pragma once +#ifndef C10_USING_CUSTOM_GENERATED_MACROS +#include +#endif + // See c10/macros/Export.h for a detailed explanation of what the function // of these macros are. We need one set of macros for every separate library // we build. +#ifdef _WIN32 +#if defined(C10_XPU_BUILD_SHARED_LIBS) +#define C10_XPU_EXPORT __declspec(dllexport) +#define C10_XPU_IMPORT __declspec(dllimport) +#else +#define C10_XPU_EXPORT +#define C10_XPU_IMPORT +#endif +#else // _WIN32 #if defined(__GNUC__) #define C10_XPU_EXPORT __attribute__((__visibility__("default"))) #else // defined(__GNUC__) #define C10_XPU_EXPORT #endif // defined(__GNUC__) #define C10_XPU_IMPORT C10_XPU_EXPORT +#endif // _WIN32 // This one is being used by libc10_xpu.so #ifdef C10_XPU_BUILD_MAIN_LIB diff --git a/c10/xpu/impl/xpu_cmake_macros.h.in b/c10/xpu/impl/xpu_cmake_macros.h.in new file mode 100644 index 000000000000..48ed78c07e1d --- /dev/null +++ b/c10/xpu/impl/xpu_cmake_macros.h.in @@ -0,0 +1,6 @@ +#pragma once + +// Automatically generated header file for the C10 XPU library. Do not +// include this file directly. Instead, include c10/xpu/XPUMacros.h + +#cmakedefine C10_XPU_BUILD_SHARED_LIBS diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 2a58b15e8d5e..0d64fe75be41 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1062,8 +1062,15 @@ if(USE_XPU) message(WARNING "Failed to include ATen XPU implementation target") else() target_link_libraries(torch_xpu PRIVATE torch_xpu_ops) - target_link_libraries(torch_xpu PRIVATE - "-Wl,--whole-archive,\"$\" -Wl,--no-whole-archive") + if(MSVC) + # Windows + target_link_libraries(torch_xpu PRIVATE + "-WHOLEARCHIVE:\"$\"") + else() + # Linux + target_link_libraries(torch_xpu PRIVATE + "-Wl,--whole-archive,\"$\" -Wl,--no-whole-archive") + endif() endif() endif() diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 3a57dd64c6af..f1f2eb7cec31 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -89,8 +89,8 @@ endif() if(USE_XPU) include(${CMAKE_CURRENT_LIST_DIR}/public/xpu.cmake) if(NOT PYTORCH_FOUND_XPU) - # message(WARNING "Not compiling with XPU. Could NOT find SYCL." - # "Suppress this warning with -DUSE_XPU=OFF.") + message(WARNING "Not compiling with XPU. Could NOT find SYCL." + "Suppress this warning with -DUSE_XPU=OFF.") caffe2_update_option(USE_XPU OFF) endif() endif() diff --git a/cmake/Modules/FindMKLDNN.cmake b/cmake/Modules/FindMKLDNN.cmake index b93f9229fc23..382e71b1049b 100644 --- a/cmake/Modules/FindMKLDNN.cmake +++ b/cmake/Modules/FindMKLDNN.cmake @@ -21,10 +21,16 @@ IF(NOT MKLDNN_FOUND) if(USE_XPU) # Build oneDNN GPU library if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") - set(DNNL_HOST_COMPILER "g++") + # Linux # g++ is soft linked to /usr/bin/cxx, oneDNN would not treat it as an absolute path + set(DNNL_HOST_COMPILER "g++") + set(SYCL_CXX_DRIVER "icpx") + set(DNNL_LIB_NAME "libdnnl.a") else() - message(FATAL_ERROR "oneDNN library currently only supports GUN g++ compiler for XPU backend") + # Windows + set(DNNL_HOST_COMPILER "DEFAULT") + set(SYCL_CXX_DRIVER "icx") + set(DNNL_LIB_NAME "dnnl.lib") endif() set(DNNL_MAKE_COMMAND "cmake" "--build" ".") @@ -41,8 +47,7 @@ IF(NOT MKLDNN_FOUND) PREFIX ${XPU_MKLDNN_DIR_PREFIX} BUILD_IN_SOURCE 0 CMAKE_ARGS -DCMAKE_C_COMPILER=icx - -DCMAKE_CXX_COMPILER=icpx - -DCMAKE_CXX_COMPILER_ID=IntelLLVM + -DCMAKE_CXX_COMPILER=${SYCL_CXX_DRIVER} -DDNNL_GPU_RUNTIME=SYCL -DDNNL_CPU_RUNTIME=THREADPOOL -DDNNL_BUILD_TESTS=OFF @@ -52,20 +57,20 @@ IF(NOT MKLDNN_FOUND) -DDNNL_DPCPP_HOST_COMPILER=${DNNL_HOST_COMPILER} # Use global cxx compiler as host compiler -G ${CMAKE_GENERATOR} # Align Generator to Torch BUILD_COMMAND ${DNNL_MAKE_COMMAND} - BUILD_BYPRODUCTS "xpu_mkldnn_proj-prefix/src/xpu_mkldnn_proj-build/src/libdnnl.a" + BUILD_BYPRODUCTS "xpu_mkldnn_proj-prefix/src/xpu_mkldnn_proj-build/src/${DNNL_LIB_NAME}" INSTALL_COMMAND "" ) ExternalProject_Get_Property(xpu_mkldnn_proj BINARY_DIR) set(__XPU_MKLDNN_BUILD_DIR ${BINARY_DIR}) - set(XPU_MKLDNN_LIBRARIES ${__XPU_MKLDNN_BUILD_DIR}/src/libdnnl.a) + set(XPU_MKLDNN_LIBRARIES ${__XPU_MKLDNN_BUILD_DIR}/src/${DNNL_LIB_NAME}) set(XPU_MKLDNN_INCLUDE ${__XPU_MKLDNN_BUILD_DIR}/include) # This target would be further linked to libtorch_xpu.so. # The libtorch_xpu.so would contain Conv&GEMM operators that depend on # oneDNN primitive implementations inside libdnnl.a. add_library(xpu_mkldnn INTERFACE) add_dependencies(xpu_mkldnn xpu_mkldnn_proj) - target_link_libraries(xpu_mkldnn INTERFACE ${__XPU_MKLDNN_BUILD_DIR}/src/libdnnl.a) + target_link_libraries(xpu_mkldnn INTERFACE ${__XPU_MKLDNN_BUILD_DIR}/src/${DNNL_LIB_NAME}) target_include_directories(xpu_mkldnn INTERFACE ${XPU_MKLDNN_INCLUDE}) endif() diff --git a/cmake/Modules/FindSYCLToolkit.cmake b/cmake/Modules/FindSYCLToolkit.cmake index d9345bb2fe0d..4a4a6dfaa789 100644 --- a/cmake/Modules/FindSYCLToolkit.cmake +++ b/cmake/Modules/FindSYCLToolkit.cmake @@ -55,6 +55,23 @@ find_library( HINTS ${SYCL_LIBRARY_DIR} NO_DEFAULT_PATH ) +# On Windows, currently there's no sycl.lib. Only sycl7.lib with version suffix, +# where the current version of the SYCL runtime is 7. +# Until oneAPI adds support to sycl.lib without the version suffix, +# sycl_runtime_version needs to be hardcoded and uplifted when SYCL runtime version uplifts. +# TODO: remove this when sycl.lib is supported on Windows +if(WIN32) + set(sycl_runtime_version 7) + find_library( + SYCL_LIBRARY + NAMES "sycl${sycl_runtime_version}" + HINTS ${SYCL_LIBRARY_DIR} + NO_DEFAULT_PATH + ) + if(SYCL_LIBRARY STREQUAL "SYCL_LIBRARY-NOTFOUND") + message(FATAL_ERROR "Cannot find a SYCL library on Windows") + endif() +endif() find_library( OCL_LIBRARY diff --git a/torch/csrc/xpu/Module.cpp b/torch/csrc/xpu/Module.cpp index 7bf8abdef204..cfe7b43d19a9 100644 --- a/torch/csrc/xpu/Module.cpp +++ b/torch/csrc/xpu/Module.cpp @@ -11,24 +11,30 @@ #include #include +#ifndef WIN32 #include +#endif using namespace torch; static bool in_bad_fork = false; // True for children forked after xpu init +#ifndef WIN32 // Called in the forked child if xpu has already been initialized static void forked_child() { in_bad_fork = true; torch::utils::set_requires_device_init(at::kXPU, true); } +#endif // Should be called before the first xpu call. It is mainly called in lazy_init. // Note: This is distinct from initExtension because a stub xpu implementation // has some working functions (e.g. device_count) but cannot fully initialize. static void poison_fork() { +#ifndef WIN32 static c10::once_flag flag; c10::call_once(flag, [] { pthread_atfork(nullptr, nullptr, forked_child); }); +#endif } // XPU management methods From 9795c4224bcfb317a517b966f00ac78a2debee22 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 6 Jun 2024 01:50:18 +0000 Subject: [PATCH 395/706] Revert "[DDP] Bucket handling: make first bucket size equal to bucket_cap_mb if it was set (#121640)" This reverts commit e98662bed99df57b7d79f9fc1cbe670afc303235. Reverted https://github.com/pytorch/pytorch/pull/121640 on behalf of https://github.com/clee2000 due to Sorry but it looks like you're failing `distributed/_composable/test_replicate_with_compiler.py::ReplicateTest::test_bucketing_coalesced_op `. THe build failed so the tests didn't run, consider rebasing, there have been a couple of PRs lately related to cudnn so you probably are either based on a bad or too old of a commit https://hud.pytorch.org/pytorch/pytorch/commit/e98662bed99df57b7d79f9fc1cbe670afc303235 https://github.com/pytorch/pytorch/actions/runs/9392731942/job/25868060913 ([comment](https://github.com/pytorch/pytorch/pull/121640#issuecomment-2151258585)) --- torch/nn/parallel/distributed.py | 29 ++++++++--------------------- 1 file changed, 8 insertions(+), 21 deletions(-) diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py index e95a2d9ab030..ef6034ade58e 100644 --- a/torch/nn/parallel/distributed.py +++ b/torch/nn/parallel/distributed.py @@ -548,8 +548,7 @@ class DistributedDataParallel(Module, Joinable): multiple buckets so that gradient reduction of each bucket can potentially overlap with backward computation. :attr:`bucket_cap_mb` controls the bucket size in - MebiBytes (MiB). If ``None``, a default size of 25 MiB - will be used. (default: ``None``) + MegaBytes (MB). (default: 25) find_unused_parameters (bool): Traverse the autograd graph from all tensors contained in the return value of the wrapped module's ``forward`` function. Parameters @@ -632,7 +631,7 @@ def __init__( dim=0, broadcast_buffers=True, process_group=None, - bucket_cap_mb=None, + bucket_cap_mb=25, find_unused_parameters=False, check_reduction=False, gradient_as_bucket_view=False, @@ -789,14 +788,7 @@ def __init__( self.broadcast_bucket_size = int(250 * 1024 * 1024) # reduction bucket size - if bucket_cap_mb is None: - # default case (bucket cap is 25 MiB) - self.bucket_bytes_cap_default = True - self.bucket_bytes_cap = int(25 * 1024 * 1024) - else: - self.bucket_bytes_cap_default = False - self.bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024) - + self.bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024) # Whether to perform input tensor CPU to GPU copies on a side-stream self.use_side_stream_for_tensor_copies = ( os.environ.get("PYTORCH_DDP_USE_SIDE_STREAM", "1") == "1" @@ -1164,13 +1156,10 @@ def _ddp_init_helper( if static_graph is True or self.find_unused_parameters is False: bucket_size_limits = [sys.maxsize] else: - if self.bucket_bytes_cap_default: - bucket_size_limits = [ - dist._DEFAULT_FIRST_BUCKET_BYTES, - self.bucket_bytes_cap, - ] - else: - bucket_size_limits = [self.bucket_bytes_cap] + bucket_size_limits = [ + dist._DEFAULT_FIRST_BUCKET_BYTES, + self.bucket_bytes_cap, + ] ( bucket_indices, per_bucket_size_limits, @@ -1206,9 +1195,7 @@ def _ddp_init_helper( param_to_name_mapping, # User can set dist._DEFAULT_FIRST_BUCKET_BYTES to tune DDP first # bucket. - dist._DEFAULT_FIRST_BUCKET_BYTES - if self.bucket_bytes_cap_default - else self.bucket_bytes_cap, + dist._DEFAULT_FIRST_BUCKET_BYTES, ) self.logger = dist.Logger(self.reducer) From c1a43a69e422780dcfe5bc0171d921dc3a5b6836 Mon Sep 17 00:00:00 2001 From: Janani Sriram Date: Thu, 6 Jun 2024 01:56:12 +0000 Subject: [PATCH 396/706] [NestedTensor] Add error checks for unbind operator coverage when ragged_idx != 1 (#128058) Summary: Add the following error checks for the `unbind` operator on `NestedTensor`s when `ragged_idx != 1`: - The current implementation allows the creation of `NestedTensor` instances from the class definition with an `offsets` tensor that applies to a dimension other than the jagged dimension. This diff ensures that `unbind` fails when the `offsets` exceed the length of the jagged dimension. Test Plan: Added the following unit tests: `test_unbind_with_lengths_ragged_idx_equals_2_bad_dim_cpu` verifies that `unbind` fails when there is a mismatch between the offsets and the jagged dimension, for `NestedTensor`s with `lengths`. ``` test_unbind_with_lengths_ragged_idx_equals_2_bad_dim_cpu (test_nestedtensor.TestNestedTensorSubclassCPU) ... ok ``` Reviewed By: davidberard98 Differential Revision: D57989082 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128058 Approved by: https://github.com/davidberard98 --- test/test_nestedtensor.py | 18 ++++++++++++++++++ torch/nested/_internal/ops.py | 5 +++++ 2 files changed, 23 insertions(+) diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 2382cb40b522..ca50c93dd260 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -3865,6 +3865,24 @@ def test_unbind_lengths_ragged_idx_1(self, device): for i, t in enumerate(out): self.assertEqual(t, tensor_list[i]) + def test_unbind_lengths_ragged_idx_equals_2_bad_dim(self, device): + values = torch.randn(16, 8, 128, device=device) + offsets = torch.tensor([0, 8, 12, 13, 16], device=device) + lengths = torch.tensor([6, 2, 1, 2], device=device) + ragged_idx = 2 + nt = torch.nested._internal.nested_tensor.NestedTensor( + values, + offsets=offsets, + lengths=lengths, + _ragged_idx=ragged_idx) # 4D nested tensor + + self.assertRaisesRegex( + RuntimeError, + r"unbind\(\): nested tensor offsets and lengths.*", + lambda: nt.unbind() + ) + + def test_unbind_lengths_ragged_idx_2(self, device): values = torch.randn(16, 8, 128, device=device) offsets = torch.tensor([0, 2, 4, 8], device=device) diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py index cfbb50b395fa..85f62170595c 100644 --- a/torch/nested/_internal/ops.py +++ b/torch/nested/_internal/ops.py @@ -625,6 +625,11 @@ def unbind_int(func, *args, **kwargs): raise RuntimeError( "unbind(): nested tensor ragged_idx out of bounds (should be >= 1)" ) + for i in range(lengths.shape[0]): + if offsets[i] + lengths[i] > values.shape[ragged_idx - 1]: + raise RuntimeError( + "unbind(): nested tensor offsets and lengths do not match ragged_idx dimension" + ) return [ torch.narrow(values, dim=(ragged_idx - 1), start=offsets[i], length=lengths[i]) for i in range(lengths.shape[0]) From 2f7cfecd86009a9d396fdbdcdfb4ba7a005db16b Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Wed, 5 Jun 2024 17:22:11 -0700 Subject: [PATCH 397/706] Complete revamp of float/promotion sympy handling (#126905) At a high level, the idea behind this PR is: * Make it clearer what the promotion and int/float rules for various Sympy operations are. Operators that previously were polymorphic over int/float are now split into separate operators for clarity. We never do mixed int/float addition/multiplication etc in sympy, instead, we always promote to the appropriate operator. (However, equality is currently not done correctly.) * Enforce strict typing on ValueRanges: if you have a ValueRange for a float, the lower and upper MUST be floats, and so forth for integers. The story begins in **torch/utils/_sympy/functions.py**. Here, I make some changes to how we represent certain operations in sympy expressions: * FloorDiv now only supports integer inputs; to do float floor division, do a truediv and then a trunc. Additionally, we remove the divide out addition by gcd optimization, because sympy gcd is over fields and is willing to generate rationals (but rationals are bad for ValueRange strict typing). * ModularIndexing, LShift, RShift now assert they are given integer inputs. * Mod only supports integer inputs; eventually we will support FloatMod (left for later work, when we build out Sympy support for floating operations). Unfortunately, I couldn't assert integer inputs here, because of a bad interaction with sympy's inequality solver that is used by the offline solver * TrueDiv is split into FloatTrueDiv and IntTrueDiv. This allows for us to eventually generate accurate code for Python semantics IntTrueDiv, which is written in a special way to preserve precision when the inputs are >= 2**53 beyond what first coercing the integer to floats and then doing true division. * Trunc is split to TruncToFloat and TruncToInt. * Round is updated to return a float, not an int, making it consistent with the round op handler in Inductor. To get Python-style conversion to int, we call TruncToInt on the result. * RoundDecimal updated to consistently only ever return a float * Add ToFloat for explicit coercion to float (required so we can enforce strict ValueRanges typing) In **torch/__init__.py**, we modify SymInt and SymFloat to appropriately call into new bindings that route to these refined sympy operations. Also, we modify `torch.sym_min` and `torch.sym_max` to have promotion semantics (if one argument is a float, the return result is always a float), making them inconsistent with builtins.min/max, but possible to do type analysis without runtime information. We also need to introduce some new op handlers in **torch/_inductor/ops_handler.py**: * `to_int` for truncation to int64, directly corresponding to TruncToInt; this can be implemented by trunc and dtype, but with a dedicated handler it is more convenient for roundtripping in Sympy * `int_truediv` for Python-style integer true division, which has higher precision than casting to floats and then running `truediv` These changes have consequences. First, we need to make some administrative changes: * Actually wire up these Sympy functions from SymInt/SymFloat in **torch/fx/experimental/sym_node.py**, including the new promotion rules (promote2) * Add support for new Sympy functions in **torch/utils/_sympy/interp.py**, **torch/utils/_sympy/reference.py** * In particular, in torch.utils._sympy.reference, we have a strong preference to NOT do nontrivial compute, instead, everything in ops handler should map to a singular sympy function * TODO: I chose to roundtrip mod back to our Mod function, but I think I'm going to have to deal with the C/Python inconsistency this to fix tests here * Add printer support for the Sympy functions in **torch/_inductor/codegen/common.py**, **torch/_inductor/codegen/cpp_utils.py**, **torch/_inductor/codegen/triton.py**. `int_truediv` and mixed precision equality is currently not implemented soundly, so we will lose precision in codegen for large values. TODO: The additions here are not exhaustive yet * Update ValueRanges logic to use new sympy functions in **torch/utils/_sympy/value_ranges.py**. In general, we prefer to use the new Sympy function rather than try to roll things by hand, which is what was done previously for many VR analysis functions. In **torch/fx/experimental/symbolic_shapes.py** we need to make some symbolic reasoning adjustments: * Avoid generation of rational subexpressions by removing simplification of `x // y` into `floor(x / y)`. This simplification then triggers an addition simplification rule `(x + y) / c --> x / c + y / c` which is bad because x / c is a rational number now * `_assert_bound_is_rational` is no more, we no longer generate rational bounds * Don't intersect non-int value ranges with the `int_range` * Support more sympy Functions for guard SYMPY_INTERP * Assert the type of value range is consistent with the variable type The new asserts uncovered necessary bug fixes: * **torch/_inductor/codegen/cpp.py**, **torch/_inductor/select_algorithm.py**, **torch/_inductor/sizevars.py** - Ensure Wild/Symbol manually allocated in Inductor is marked `is_integer` so it's accepted to build expressions * **torch/_inductor/utils.py** - make sure you actually pass in sympy.Expr to these functions * **torch/_inductor/ir.py** - make_contiguous_strides_for takes int/SymInt, not sympy.Expr! * **torch/export/dynamic_shapes.py** - don't use infinity to represent int ranges, instead use sys.maxsize - 1 Because of the removal of some symbolic reasoning that produced rationals, some of our symbolic reasoning has gotten worse and we are unable to simplify some guards. Check the TODO at **test/test_proxy_tensor.py** Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/126905 Approved by: https://github.com/xadupre, https://github.com/lezcano --- c10/core/SymNodeImpl.h | 18 + test/dynamo/test_dynamic_shapes.py | 7 - test/dynamo/test_export.py | 3 +- test/dynamo/test_misc.py | 17 +- test/inductor/test_indexing.py | 72 +--- .../test_torchinductor_dynamic_shapes.py | 28 ++ test/onnx/test_fx_to_onnx_with_onnxruntime.py | 8 +- test/test_dynamic_shapes.py | 208 +++------ test/test_proxy_tensor.py | 3 +- test/test_sympy_utils.py | 122 +++--- torch/__init__.py | 162 ++++++- torch/_export/serde/serialize.py | 9 +- torch/_inductor/bounds.py | 5 + torch/_inductor/codegen/common.py | 176 ++++++-- torch/_inductor/codegen/cpp.py | 4 +- torch/_inductor/codegen/cpp_utils.py | 55 ++- torch/_inductor/codegen/triton.py | 64 ++- torch/_inductor/graph.py | 5 +- torch/_inductor/ir.py | 16 +- torch/_inductor/kernel/flex_attention.py | 5 +- torch/_inductor/lowering.py | 6 +- torch/_inductor/ops_handler.py | 60 ++- torch/_inductor/select_algorithm.py | 4 +- torch/_inductor/sizevars.py | 20 +- torch/_inductor/utils.py | 2 +- torch/_subclasses/fake_tensor.py | 2 +- torch/csrc/jit/python/init.cpp | 5 + torch/csrc/utils/python_symnode.h | 20 + torch/export/dynamic_shapes.py | 9 +- torch/fx/experimental/recording.py | 8 +- torch/fx/experimental/sym_node.py | 210 +++++++-- torch/fx/experimental/symbolic_shapes.py | 82 ++-- torch/fx/experimental/validator.py | 32 +- torch/utils/_sympy/functions.py | 398 ++++++++++++++---- torch/utils/_sympy/interp.py | 71 +++- torch/utils/_sympy/reference.py | 151 ++++--- torch/utils/_sympy/solve.py | 1 + torch/utils/_sympy/value_ranges.py | 275 ++++++++---- 38 files changed, 1674 insertions(+), 669 deletions(-) diff --git a/c10/core/SymNodeImpl.h b/c10/core/SymNodeImpl.h index 9ffab5065109..bb92b09775b7 100644 --- a/c10/core/SymNodeImpl.h +++ b/c10/core/SymNodeImpl.h @@ -49,15 +49,33 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target { virtual SymNode mul(const SymNode& other) { TORCH_CHECK(false, "NYI"); } + // NB: legacy, prefer float_truediv or int_truediv virtual SymNode truediv(const SymNode& other) { TORCH_CHECK(false, "NYI"); } + virtual SymNode float_truediv(const SymNode& other) { + return truediv(other); + } + virtual SymNode int_truediv(const SymNode& other) { + return truediv(other); + } + // NB: legacy, prefer float_pow or pow_by_natural virtual SymNode pow(const SymNode& other) { TORCH_CHECK(false, "NYI"); } + virtual SymNode float_pow(const SymNode& other) { + return pow(other); + } + virtual SymNode pow_by_natural(const SymNode& other) { + return pow(other); + } + // NB: legacy, prefer int_floordiv virtual SymNode floordiv(const SymNode& other) { TORCH_CHECK(false, "NYI"); } + virtual SymNode int_floordiv(const SymNode& other) { + return floordiv(other); + } virtual SymNode mod(const SymNode& other) { TORCH_CHECK(false, "NYI"); } diff --git a/test/dynamo/test_dynamic_shapes.py b/test/dynamo/test_dynamic_shapes.py index 0bead6e47e48..a3c63ef66152 100644 --- a/test/dynamo/test_dynamic_shapes.py +++ b/test/dynamo/test_dynamic_shapes.py @@ -78,13 +78,6 @@ def make_dynamic_cls(cls): del test if TEST_Z3: - # this only fails when z3 is available - unittest.expectedFailure( - # SymPy is incorrectly transforming 's0 / 6 == 0.5' into 'False'. - # Ref: https://github.com/sympy/sympy/issues/25146 - DynamicShapesReproTests.test_dynamic_shapes_float_guard_dynamic_shapes # noqa: F821 - ) - if not config.inline_inbuilt_nn_modules: # TODO model is somehow not being freed when z3 is available unittest.expectedFailure( diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py index 9f1417e23247..7ae0f839f6ff 100644 --- a/test/dynamo/test_export.py +++ b/test/dynamo/test_export.py @@ -2385,8 +2385,7 @@ def forward(self, x): with self.assertRaisesRegex( torch._dynamo.exc.UserError, "Constraints violated .*!(.*\n)*.*" - "by dim0 = 2\\*dim1(.*\n)*.*" - "Not all values of dim1 .* satisfy the generated guard 2 <= .* and .* <= 5(.*\n)*.*", + "Not all values of dim0 .* satisfy the generated guard 4 <= .* and .* <= 10(.*\n)*.*", ): torch.export.export(foo, (t,), dynamic_shapes=dynamic_shapes) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index bcb0fd18818e..dc2b9530f0dd 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -9309,7 +9309,7 @@ def test_shape_env_equal_create_symbolic_sizes_strides_storage_offset(self): > Left: {0: 0, 1: 1, 2: s1, 3: s0} > Right: {0: 0, 1: 1} ==> var_to_range: values don't match. - > Left: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)} + > Left: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)} > Right: {} ==> var_to_sources: values don't match. > Left: {s0: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=, idx=0)], s1: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=, idx=1)]} @@ -9343,7 +9343,7 @@ def test_shape_env_equal_unbacked(self): > Left: 2 > Right: 0 ==> var_to_range: values don't match. - > Left: {u0: ValueRanges(lower=-9223372036854775808, upper=9223372036854775807, is_bool=False), u1: ValueRanges(lower=0, upper=1, is_bool=False), zuf0: ValueRanges(lower=-oo, upper=oo, is_bool=False)} + > Left: {u0: ValueRanges(lower=-9223372036854775808, upper=9223372036854775807, is_bool=False, is_int=True, is_float=False), u1: ValueRanges(lower=0, upper=1, is_bool=False, is_int=True, is_float=False), zuf0: ValueRanges(lower=-oo, upper=oo, is_bool=False, is_int=False, is_float=True)} > Right: {} """, ) @@ -9420,8 +9420,8 @@ def test_shape_env_equal_evaluate_expr_replacement(self): > Left: {s0: 3} > Right: {} ==> var_to_range: values don't match. - > Left: {s0: ValueRanges(lower=3, upper=3, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)} - > Right: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)} + > Left: {s0: ValueRanges(lower=3, upper=3, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)} + > Right: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)} """, ) self._replay_and_check(main) @@ -9458,8 +9458,8 @@ def test_shape_env_equal_evaluate_expr_refinement(self): > Left: {_assert, ge, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} > Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} ==> var_to_range: values don't match. - > Left: {s0: ValueRanges(lower=3, upper=9223372036854775806, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)} - > Right: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)} + > Left: {s0: ValueRanges(lower=3, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)} + > Right: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)} """, ) self._replay_and_check(main) @@ -9484,10 +9484,7 @@ def test_shape_env_equal_runtime_assert(self): ShapeEnv not equal: field values don't match: ==> deferred_runtime_asserts: values don't match. - > Left: {u0: [Eq(Mod(u0, 3), 0)]} - > Right: {} -==> divisible: values don't match. - > Left: {Mod(u0, 3)} + > Left: {u0: [Eq(PythonMod(u0, 3), 0)]} > Right: {} ==> name_to_node: values don't match. > Left: {_assert, eq, mod, u0} diff --git a/test/inductor/test_indexing.py b/test/inductor/test_indexing.py index 299a619f9cd6..da527cfbb1d8 100644 --- a/test/inductor/test_indexing.py +++ b/test/inductor/test_indexing.py @@ -11,7 +11,12 @@ instantiate_parametrized_tests, parametrize, ) -from torch.utils._sympy.functions import FloorDiv, ModularIndexing, Round, RoundDecimal +from torch.utils._sympy.functions import ( + FloorDiv, + ModularIndexing, + RoundDecimal, + RoundToInt, +) class TestIndexingSimplification(InductorTestCase): @@ -168,21 +173,11 @@ def test_print_pow(self): common_cases = [ # expr, result - # Test exprs. - ( - s1 / (2 * s1 - 1) - 1 / (2 * s1 - 1), - lambda c, L: f"((-1{L})*({c}/((-1{L}) + (2{L}*foo)))) + (foo*({c}/((-1{L}) + (2{L}*foo))))", - ), - (s1 / (s2 - s3), lambda c, L: f"foo*({c}/(bar + ((-1{L})*baz)))"), # Test Pow directly. ( sympy.Pow(s1 + s2, 0), lambda _, L: f"1{L}", ), # note: simplified before _print_Pow - ( - sympy.Pow(s1 + s2, -3), - lambda c, _: f"{c}/((bar + foo)*(bar + foo)*(bar + foo))", - ), ] gpu_cases = common_cases + [ @@ -231,12 +226,10 @@ def test_print_ceil(self): self.assertExpectedInline(cexpr(expr), """std::ceil((1.0/2.0)*s1)""") def test_print_round(self): - expr = Round(sympy.Symbol("x", integer=True) / 2) + expr = RoundToInt(sympy.Symbol("x", integer=True) / 2) self.assertExpectedInline(pexpr(expr), """round((1/2)*x)""") self.assertExpectedInline(cexpr(expr), """std::lrint((1.0/2.0)*x)""") - self.assertExpectedInline( - texpr(expr), """libdevice.llrint((1/2)*x).to(tl.int64)""" - ) + self.assertExpectedInline(texpr(expr), """libdevice.llrint((1/2)*x)""") @parametrize("ndigits", [-1, 0, 1]) def test_print_round_decimal(self, ndigits): @@ -251,45 +244,18 @@ def test_print_round_decimal(self, ndigits): f"libdevice.nearbyint(1e{ndigits} * ((1/2)*x)) * 1e{-ndigits}", ) - expr = RoundDecimal(sympy.Symbol("x", integer=True), ndigits) - if ndigits >= 0: - for do_print in [pexpr, cexpr, texpr]: - self.assertEqual(do_print(expr), "x") - else: - self.assertEqual(pexpr(expr), f"round(x, {ndigits})") - for do_print in [cexpr, texpr]: - with self.assertRaisesRegex( - ValueError, "only non-negative ndigits are currently supported" - ): - do_print(expr) - def test_print_floor_div(self): - for integer in [True, False]: - s1 = sympy.Symbol("s1", integer=integer) - s2 = sympy.Symbol("s2", integer=integer) - expr = FloorDiv(s1, s2) - self.assertEqual(pexpr(expr), "(s1 // s2)") - if integer: - self.assertEqual(cexpr(expr), "c10::div_floor_integer(s1, s2)") - else: - self.assertEqual( - cexpr(expr), - "c10::div_floor_floating(static_cast(s1), static_cast(s2))", - ) - - for integer in [True, False]: - s1 = sympy.Symbol("s1", integer=integer) - s2 = sympy.S(-1) - expr = FloorDiv(s1, s2) - if integer: - self.assertEqual(pexpr(expr), "(-1)*s1") - self.assertEqual(cexpr(expr), "(-1L)*s1") - else: - self.assertEqual(pexpr(expr), "(s1 // (-1))") - self.assertEqual( - cexpr(expr), - "c10::div_floor_floating(static_cast(s1), static_cast((-1L)))", - ) + s1 = sympy.Symbol("s1", integer=True) + s2 = sympy.Symbol("s2", integer=True) + expr = FloorDiv(s1, s2) + self.assertEqual(pexpr(expr), "(s1 // s2)") + self.assertEqual(cexpr(expr), "c10::div_floor_integer(s1, s2)") + + s1 = sympy.Symbol("s1", integer=True) + s2 = sympy.S(-1) + expr = FloorDiv(s1, s2) + self.assertEqual(pexpr(expr), "(-1)*s1") + self.assertEqual(cexpr(expr), "(-1L)*s1") def test_print_Min_Max(self): cases = ( diff --git a/test/inductor/test_torchinductor_dynamic_shapes.py b/test/inductor/test_torchinductor_dynamic_shapes.py index 8513e928c412..2f9506a9d561 100644 --- a/test/inductor/test_torchinductor_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_dynamic_shapes.py @@ -3,6 +3,7 @@ import importlib import math +import operator import os import sys import unittest @@ -649,6 +650,33 @@ def fn(a): actual = cfn(5) self.assertEqual(expect, actual) + def test_interpolate_ceil_eq(self, device): + ceiling = math.ceil + IntTrueDiv = operator.truediv + + def fn(t): + s0, s2, s3 = t.size() + x = torch.zeros( + ( + s0, + 2048, + ceiling(IntTrueDiv(2 * ((s2 - 1) // 8) + 2, 1)), + ceiling(IntTrueDiv(2 * ((s3 - 1) // 8) + 2, 1)), + ), + dtype=torch.bfloat16, + ) + return torch.nn.functional.interpolate( + x, + scale_factor=2, + mode="nearest", + ) + + cfn = self.compile_fn(fn) + arg = torch.randn(4, 16, 18) + expect = fn(arg) + actual = cfn(arg) + self.assertEqual(expect, actual) + def test_full_recompiles(self, device): def fn(x): _, L = x.shape diff --git a/test/onnx/test_fx_to_onnx_with_onnxruntime.py b/test/onnx/test_fx_to_onnx_with_onnxruntime.py index b70bfbf9c4a7..0f0e01bc0dc2 100644 --- a/test/onnx/test_fx_to_onnx_with_onnxruntime.py +++ b/test/onnx/test_fx_to_onnx_with_onnxruntime.py @@ -158,8 +158,12 @@ def forward(self, x, y): torch.tensor([operator.sub(x.item(), y.item())]), torch.tensor([operator.mul(x.item(), y.item())]), torch.tensor([operator.truediv(x.item(), y.item())]), - torch.tensor([operator.floordiv(x.item(), y.item())]), - torch.tensor([operator.pow(x.item(), y.item())]), + # This requires torch.sym_float, probably easy to lower to + # ONNX but I don't know where to put it + # torch.tensor([operator.floordiv(x.item(), y.item())]), + # NB: abs so that the base and exponent are provably + # non-negative, so we don't generate runtime asserts + torch.tensor([operator.pow(abs(x.item()), abs(y.item()))]), torch.tensor([operator.abs(x.item())]), torch.tensor([operator.neg(x.item())]), torch.tensor([math.ceil(x.item())]), diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index d548e9df0707..3b47f12198d5 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -205,15 +205,15 @@ def create_symtype(cls, pytype, shape_env, val, duck=True): # TODO: default duck to False -def create_symint(shape_env, i: int, duck=True): +def create_symint(shape_env, i: int, duck=True) -> SymInt: return create_symtype(SymInt, int, shape_env, i, duck=duck) -def create_symbool(shape_env, b: bool): +def create_symbool(shape_env, b: bool) -> SymBool: return create_symtype(SymBool, bool, shape_env, b) -def create_symfloat(shape_env, f: float): +def create_symfloat(shape_env, f: float) -> SymFloat: return create_symtype(SymFloat, float, shape_env, f) @@ -457,14 +457,16 @@ def test_sym_int(self): r = sym_int(a1 / 2) self.assertEqual(guard_int(r), 3) self.assertIsInstance(r, torch.SymInt, msg=type(r)) - self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(Trunc(s1/2), 3)""") + self.assertExpectedInline( + str(shape_env.guards[1][0]), """Eq(TruncToInt(IntTrueDiv(s1, 2)), 3)""" + ) a3 = create_symint(shape_env, 3) r = sym_int(2.0 * torch.sym_float(a3)) self.assertEqual(guard_int(r), 6) self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertExpectedInline( - str(shape_env.guards[2][0]), """Eq(Trunc(2.0*s2), 6)""" + str(shape_env.guards[2][0]), """Eq(TruncToInt(2.0*ToFloat(s2)), 6)""" ) def test_sym_sqrt(self): @@ -474,7 +476,7 @@ def test_sym_sqrt(self): self.assertEqual(r, 2) self.assertIsInstance(r, torch.SymFloat, msg=type(r)) self.assertExpectedInline( - str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_sqrt(s0), 2)""" + str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_sqrt(s0), 2.0)""" ) def test_sym_floor(self): @@ -483,11 +485,17 @@ def test_sym_floor(self): r = math.floor(a0 / 2) self.assertEqual(r, 2) self.assertIsInstance(r, torch.SymInt, msg=type(r)) - self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(floor(s0/2), 2)""") + self.assertExpectedInline( + str(shape_env.guards[0][0]), + """Eq(FloorToInt(IntTrueDiv(s0, 2)), 2)""", + ) r = math.floor(3.0 * a0) self.assertEqual(r, 15) self.assertIsInstance(r, torch.SymInt, msg=type(r)) - self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(3*s0, 15)""") + self.assertExpectedInline( + str(shape_env.guards[1][0]), + """Eq(FloorToInt(3.0*ToFloat(s0)), 15)""", + ) def test_sym_trunc(self): shape_env = ShapeEnv() @@ -495,12 +503,14 @@ def test_sym_trunc(self): r = math.trunc(a0 / 2) self.assertEqual(r, 2) self.assertIsInstance(r, torch.SymInt, msg=type(r)) - self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(Trunc(s0/2), 2)""") + self.assertExpectedInline( + str(shape_env.guards[0][0]), """Eq(TruncToInt(IntTrueDiv(s0, 2)), 2)""" + ) r = torch.sym_int(torch.sym_sqrt(a0)) self.assertEqual(r, 2) self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertExpectedInline( - str(shape_env.guards[1][0]), """Eq(Trunc(OpaqueUnaryFn_sqrt(s0)), 2)""" + str(shape_env.guards[1][0]), """Eq(TruncToInt(OpaqueUnaryFn_sqrt(s0)), 2)""" ) def test_sym_ceil(self): @@ -510,12 +520,17 @@ def test_sym_ceil(self): self.assertEqual(r, 3) self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertExpectedInline( - str(shape_env.guards[0][0]), """Eq(ceiling(s0/2), 3)""" + str(shape_env.guards[0][0]), + """Eq(CeilToInt(IntTrueDiv(s0, 2)), 3)""", ) - r = math.floor(3.0 * a0) + r1 = 3.0 * a0 + r = math.floor(r1) self.assertEqual(r, 15) self.assertIsInstance(r, torch.SymInt, msg=type(r)) - self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(3*s0, 15)""") + self.assertExpectedInline( + str(shape_env.guards[1][0]), + """Eq(FloorToInt(3.0*ToFloat(s0)), 15)""", + ) def test_sym_ite(self): shape_env = ShapeEnv() @@ -962,8 +977,14 @@ def test_ephemeral_source_unified_with_non_ephemeral_source(self): ) class TestSymNumberMagicMethods(TestCase): def _do_test(self, fn, inp1, inp2, shape_env, is_unary_fn): + with self.subTest(fn=fn, inp1=inp1, inp2=inp2, is_unary_fn=is_unary_fn): + return self._do_test2(fn, inp1, inp2, shape_env, is_unary_fn) + + def _do_test2(self, fn, inp1, inp2, shape_env, is_unary_fn): # Helper function # NB: don't use one as that will get specialized + # TODO: We don't have to circuitously create the float, can just + # create a symfloat directly seed_node = (create_symint(shape_env, 2) / 2.0).node bool_seed_node = (create_symint(shape_env, 2) == 2).node @@ -976,27 +997,42 @@ def get_sym_inp(inp): else: return torch.SymFloat(to_node(seed_node, inp)) + if fn == "float_pow": + if inp1 < 0: + return + + if fn == "pow_by_natural": + if isinstance(inp1, float) or isinstance(inp2, float): + return + if inp2 < 0: + return + def maybe_xfail(inp1, inp2): if fn == "sym_sqrt" and inp1 < 0: # ValueError: math domain error return self.assertRaises((ValueError,)) - elif fn in ("truediv", "floordiv", "mod") and inp2 == 0: + elif ( + fn in ("float_truediv", "int_truediv", "int_floordiv", "mod") + and inp2 == 0 + ): # ZeroDivisionError: division by zero return self.assertRaises((ZeroDivisionError,)) - elif fn == "pow" and inp1 == 0 and inp2 < 0: + elif fn in ["float_pow", "pow_by_natural"] and inp1 == 0 and inp2 < 0: # ZeroDivisionError: 0.0 cannot be raised to a negative power return self.assertRaises((ZeroDivisionError,)) elif ( - fn == "pow" + # TODO: dear catastrophe waitress, + # this doesn't work + fn in ["float_pow", "pow_by_natural"] and inp1 < 0 - and inp2 in (2.5, -2.5) and ( - type(inp1) in (SymFloat, SymInt) or type(inp2) in (SymFloat, SymInt) + type(inp1) is (SymInt, SymFloat) or type(inp2) is (SymInt, SymFloat) ) + and (type(inp1) is (SymFloat, float) or type(inp2) is (SymFloat, float)) ): # Complex result, which we do not support: # TypeError: Cannot convert complex to float - return self.assertRaises((TypeError,)) + return self.assertRaises((RuntimeError,)) elif fn in ("lshift", "rshift") and not ( isinstance(inp1, (SymInt, int)) and isinstance(inp2, (SymInt, int)) ): @@ -1080,6 +1116,9 @@ def test_method(self, fn, first_type, second_type): ) and fn in sym_node.only_float_magic_methods: self.skipTest(f"{fn} is not an int method") + if second_type == "float" and fn in ["mod"]: + self.skipTest(f"{fn} only handles int") + is_unary_fn = fn in sym_node.unary_methods or fn == "round" # Second argument is ignored for unary function. So only run for one type if is_unary_fn and second_type == "float": @@ -1251,112 +1290,15 @@ def yield_test_cases(values, negate=True): yield (-x, -y) def test_floordiv_float_int(self): - values = ( - (2.5, 2.1), - (2.1, 2.5), - (2.0, 2.1), - (7, 2.5), - (2.1, 7), - (7, 2), - ) + values = ((7, 2),) for x, y in TestFloorDiv.yield_test_cases(values): self.assertEqual( TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y) ) - def test_floordiv_bool(self): - values = ( - (False, True), - (True, 2.5), - (2.5, True), - (False, 7), - (7, True), - ) - - for x, y in TestFloorDiv.yield_test_cases(values, negate=False): - # Compares to int since our FloorDiv has no bool support - self.assertEqual( - TestFloorDiv.python_floordiv(x, y), - TestFloorDiv.torch_floordiv(int(x), int(y)), - ) - # Tests that our impl throws - self.assertRaisesRegex( - TypeError, - ( - rf"unsupported operand type\(s\) for //: " - rf"'{type(sympy.sympify(x)).__name__}' and '{type(sympy.sympify(y)).__name__}'" - rf", expected integer or real" - ), - lambda: TestFloorDiv.torch_floordiv(x, y), - ) - - def test_floordiv_complex(self): - values = ( - (1.5 + 2.5j, 1.3 + 3.5j), - (1.5 + 2.5j, 2.5), - (2.5, 1.5 + 2.5j), - (1.5 + 2.5j, 7), - (7, 1.5 + 2.5j), - ) - - for x, y in TestFloorDiv.yield_test_cases(values): - # We don't test error messages to avoid depending on Python - # interpreter version - self.assertRaises(TypeError, lambda: TestFloorDiv.python_floordiv(x, y)) - self.assertRaisesRegex( - TypeError, - ( - rf"unsupported operand type\(s\) for //: " - rf"'{type(sympy.sympify(x)).__name__}' and '{type(sympy.sympify(y)).__name__}'" - rf", expected integer or real" - ), - lambda: TestFloorDiv.torch_floordiv(x, y), - ) - - def test_floordiv_div_by_zero(self): - values = ( - (2.5, 0), - (2.1, 0.0), - (2.3, sympy.Symbol("s", zero=True)), - ) - - for x, y in TestFloorDiv.yield_test_cases(values, negate=False): - # We don't test error messages to avoid depending on Python - # interpreter version - if type(y) is not sympy.Symbol: - self.assertRaises( - ZeroDivisionError, lambda: TestFloorDiv.python_floordiv(x, y) - ) - self.assertRaisesRegex( - ZeroDivisionError, - "division by zero", - lambda: TestFloorDiv.torch_floordiv(x, y), - ) - - def test_floordiv_zero_base(self): - values = ( - (0, 2.5), - (0.0, 2.1), - (sympy.Symbol("s", zero=True), 2.3), - ) - - for x, y in TestFloorDiv.yield_test_cases(values, negate=False): - if type(x) is not sympy.Symbol: - self.assertEqual( - TestFloorDiv.python_floordiv(x, y), - TestFloorDiv.torch_floordiv(x, y), - ) - else: - self.assertEqual(0, TestFloorDiv.torch_floordiv(x, y)) - def test_floordiv_div_by_one(self): - values = ( - (2.5, 1), - (2.1, 1.0), - (2, 1.0), - (2, 1), - ) + values = ((2, 1),) for x, y in TestFloorDiv.yield_test_cases(values): self.assertEqual( @@ -1367,12 +1309,7 @@ def test_floordiv_simplify(self): # Tests how we simplify or evaluate FloorDiv without free variables shape_env = ShapeEnv() result = 21 - exprs = ( - 7 * FloorDiv(6, 2), - 7 * FloorDiv(6.28, 2), - 7 * FloorDiv(6.28, 2.0), - 7 * FloorDiv(6.28, (FloorDiv(6.28, 3.14))), - ) + exprs = (7 * FloorDiv(6, 2),) for expr in exprs: self.assertEqual(expr, result) @@ -1382,33 +1319,10 @@ def test_floordiv_simplify(self): self.assertEqual(shape_env.simplify(expr), result) self.assertEqual(shape_env.evaluate_expr(expr), result) - def test_floordiv_simplify_rational(self): - result = 21 - - a = sympy.Symbol("a", integer=True) - b = sympy.Symbol("b") - - cases = [ - (FloorDiv(a, sympy.Rational(1, 8)), 8 * a), - (FloorDiv(b, sympy.Rational(1, 8)), sympy.floor(8 * b)), - ] - - for expr, expected in cases: - self.assertEqual(expr, expected) - def test_floordiv_assumptions(self): - # We define two Symbols (with different names) for each type to make - # sure the behavior is consistent regardless of whether both arguments - # are the same object or not. cases = ( sympy.Symbol("i1", integer=True), sympy.Symbol("i2", integer=True), - sympy.Symbol("r1", real=True), - sympy.Symbol("r2", real=True), - sympy.Symbol("c1", complex=True, real=False, integer=False), - sympy.Symbol("c2", complex=True, real=False, integer=False), - sympy.Symbol("s1"), - sympy.Symbol("s2"), ) for base, divisor in itertools.product(cases, repeat=2): diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index c7b2e51ced20..04483ffba0fc 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1618,7 +1618,8 @@ def f(a): self.assertExpectedInline(r, """\ def forward(self, a_1): sym_size_int = torch.ops.aten.sym_size.int(a_1, 0) - pow_1 = sym_size_int ** 0.5; sym_size_int = None + sym_float = torch.sym_float(sym_size_int); sym_size_int = None + pow_1 = sym_float ** 0.5; sym_float = None div = torch.ops.aten.div.Tensor(a_1, pow_1); a_1 = pow_1 = None return div""") diff --git a/test/test_sympy_utils.py b/test/test_sympy_utils.py index c5da8f7fc0da..8b16b2c620fd 100644 --- a/test/test_sympy_utils.py +++ b/test/test_sympy_utils.py @@ -36,7 +36,12 @@ "floor", "ceil", ] -BINARY_OPS = ["truediv", "div", "floordiv", "truncdiv", "add", "mul", "sub", "pow", "minimum", "maximum", "mod"] +BINARY_OPS = [ + "truediv", "floordiv", + # "truncdiv", # TODO + # NB: pow is float_pow + "add", "mul", "sub", "pow", "pow_by_natural", "minimum", "maximum", "mod" +] UNARY_BOOL_OPS = ["not_"] BINARY_BOOL_OPS = ["or_", "and_"] @@ -81,16 +86,24 @@ def valid_unary(fn, v): def valid_binary(fn, a, b): if fn == "pow" and ( + # sympy will expand to x*x*... for integral b; don't do it if it's big b > 4 - or ( # sympy will expand to x*x*... for integral b; don't do it if it's big - a <= 0 and b == -1 - ) - or (a == b == 0) # no imaginary numbers # 0**0 is undefined + # no imaginary numbers + or a <= 0 + # 0**0 is undefined + or (a == b == 0) ): return False - elif fn == "mod" and b == 0: + elif fn == "pow_by_natural" and ( + # sympy will expand to x*x*... for integral b; don't do it if it's big + b > 4 + or b < 0 + or (a == b == 0) + ): return False - elif (fn == "div" or fn == "truediv") and b == 0: + elif fn == "mod" and (a < 0 or b <= 0): + return False + elif (fn in ["div", "truediv", "floordiv"]) and b == 0: return False return True @@ -130,27 +143,26 @@ def test_pow_half(self): ValueRangeAnalysis.pow(ValueRanges.unknown(), ValueRanges.wrap(0.5)) @parametrize("fn", BINARY_OPS) - @parametrize("dtype_a", ("int", "float")) - @parametrize("dtype_b", ("int", "float")) - def test_binary_ref(self, fn, dtype_a, dtype_b): + @parametrize("dtype", ("int", "float")) + def test_binary_ref(self, fn, dtype): to_dtype = {"int": sympy.Integer, "float": sympy.Float} - dtype_a = to_dtype[dtype_a] - dtype_b = to_dtype[dtype_b] + # Don't test float on int only methods + if dtype == "float" and fn in ["pow_by_natural", "mod"]: + return + dtype = to_dtype[dtype] for a, b in itertools.product(CONSTANTS, repeat=2): if not valid_binary(fn, a, b): continue - a = dtype_a(a) - b = dtype_b(b) + a = dtype(a) + b = dtype(b) with self.subTest(a=a, b=b): r = getattr(ValueRangeAnalysis, fn)(a, b) if r == ValueRanges.unknown(): continue ref_r = getattr(ReferenceAnalysis, fn)(a, b) - # sympy.floordiv does 1.0 // 1.0 == 1 rather than 1.0. wtf - if fn != "floordiv": - self.assertEqual(r.lower.is_integer, r.upper.is_integer) - self.assertEqual(ref_r.is_integer, r.upper.is_integer) + self.assertEqual(r.lower.is_integer, r.upper.is_integer) + self.assertEqual(ref_r.is_integer, r.upper.is_integer) self.assertEqual(r.lower, r.upper) self.assertEqual(ref_r, r.lower) @@ -200,7 +212,8 @@ def test_binary_bool_ref_range(self, fn): @parametrize("fn", UNARY_OPS) def test_unary_ref_range(self, fn): - vals = [-sympy.oo, *CONSTANTS, sympy.oo] + # TODO: bring back sympy.oo testing for float unary fns + vals = CONSTANTS for a in generate_range(vals): with self.subTest(a=a): ref_r = getattr(ValueRangeAnalysis, fn)(a) @@ -216,40 +229,26 @@ def test_unary_ref_range(self, fn): # This takes about 4s for all the variants @parametrize("fn", BINARY_OPS + COMPARE_OPS) def test_binary_ref_range(self, fn): - vals = [-sympy.oo, *LESS_CONSTANTS, sympy.oo] + # TODO: bring back sympy.oo testing for float unary fns + vals = LESS_CONSTANTS for a, b in itertools.product(generate_range(vals), repeat=2): # don't attempt pow on exponents that are too large (but oo is OK) if fn == "pow" and b.upper > 4 and b.upper != sympy.oo: continue with self.subTest(a=a, b=b): - ref_r = getattr(ValueRangeAnalysis, fn)(a, b) for a0, b0 in itertools.product(LESS_CONSTANTS, repeat=2): if a0 not in a or b0 not in b: continue if not valid_binary(fn, a0, b0): continue with self.subTest(a0=a0, b0=b0): + ref_r = getattr(ValueRangeAnalysis, fn)(a, b) r = getattr(ReferenceAnalysis, fn)( sympy.Integer(a0), sympy.Integer(b0) ) if r.is_finite: self.assertIn(r, ref_r) - def test_rational_bounds(self): - # Repro from https://github.com/pytorch/pytorch/issues/105097 - from sympy import floor, Eq - shape_0 = sympy.Symbol('shape_0', positive=True, integer=True) - new_expr = ( - Eq(30 * floor(4 * ((shape_0 + 1) // 96) * - ((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646)) / 647 + - 2584 * ((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646)) / 647), - 2880 * floor(((shape_0 + 1) // 96) * - ((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646)) / 15528 + - 323 * ((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646)) / 7764))) - new_range_env = {shape_0: ValueRanges(lower=1, upper=190)} - self.assertTrue(new_expr.subs({shape_0: 95})) - self.assertIn(True, sympy_interp(ValueRangeAnalysis, new_range_env, new_expr)) - class TestSympyInterp(TestCase): @parametrize("fn", UNARY_OPS + BINARY_OPS + UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS) @@ -258,7 +257,13 @@ def test_interp(self, fn): if fn in ("div", "truncdiv", "minimum", "maximum", "mod"): return - from sympy.abc import x, y + is_integer = None + if fn == "pow_by_natural": + is_integer = True + + x = sympy.Dummy('x', integer=is_integer) + y = sympy.Dummy('y', integer=is_integer) + vals = CONSTANTS if fn in {*UNARY_BOOL_OPS, *BINARY_BOOL_OPS}: vals = [True, False] @@ -300,29 +305,17 @@ def test_python_interp_fx(self, fn): if fn in {*BINARY_OPS, *BINARY_BOOL_OPS, *COMPARE_OPS}: arity = 2 - from sympy.abc import x, y + is_integer = None + if fn == "pow_by_natural": + is_integer = True + + x = sympy.Dummy('x', integer=is_integer) + y = sympy.Dummy('y', integer=is_integer) symbols = [x] if arity == 2: symbols = [x, y] - # Workaround mpf from symbol error - if fn == "minimum": - sympy_expr = sympy.Min(x, y) - elif fn == "maximum": - sympy_expr = sympy.Max(x, y) - else: - sympy_expr = getattr(ReferenceAnalysis, fn)(*symbols) - - if arity == 1: - def trace_f(px): - return sympy_interp(PythonReferenceAnalysis, {x: px}, sympy_expr) - else: - def trace_f(px, py): - return sympy_interp(PythonReferenceAnalysis, {x: px, y: py}, sympy_expr) - - gm = fx.symbolic_trace(trace_f) - for args in itertools.product(vals, repeat=arity): if arity == 1 and not valid_unary(fn, *args): continue @@ -330,11 +323,28 @@ def trace_f(px, py): continue if fn == "truncdiv" and args[1] == 0: continue - elif fn == "pow" and (args[0] == 0 and args[1] <= 0): + elif fn in ("pow", "pow_by_natural") and (args[0] == 0 and args[1] <= 0): continue elif fn == "floordiv" and args[1] == 0: continue with self.subTest(args=args): + # Workaround mpf from symbol error + if fn == "minimum": + sympy_expr = sympy.Min(x, y) + elif fn == "maximum": + sympy_expr = sympy.Max(x, y) + else: + sympy_expr = getattr(ReferenceAnalysis, fn)(*symbols) + + if arity == 1: + def trace_f(px): + return sympy_interp(PythonReferenceAnalysis, {x: px}, sympy_expr) + else: + def trace_f(px, py): + return sympy_interp(PythonReferenceAnalysis, {x: px, y: py}, sympy_expr) + + gm = fx.symbolic_trace(trace_f) + self.assertEqual( sympy_interp(PythonReferenceAnalysis, dict(zip(symbols, args)), sympy_expr), gm(*args) diff --git a/torch/__init__.py b/torch/__init__.py index 18f1752019ec..dfb1da76739d 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -316,6 +316,75 @@ def __index__(self): # Magic methods installed by torch.fx.experimental.sym_node + def __round__(self, ndigits=None): + return self + + def __truediv__(self, other): + if isinstance(other, (builtins.float, SymFloat)): + return sym_float(self).__float_truediv__(other) + if not isinstance(other, (builtins.int, SymInt)): + return NotImplemented + return self.__int_truediv__(other) + + def __rtruediv__(self, other): + if isinstance(other, (builtins.float, SymFloat)): + return sym_float(self).__rfloat_truediv__(other) + if not isinstance(other, (builtins.int, SymInt)): + return NotImplemented + return self.__rint_truediv__(other) + + def __floordiv__(self, other): + if isinstance(other, (builtins.float, SymFloat)): + return torch.sym_float(math.floor(sym_float(self) / other)) + if not isinstance(other, (builtins.int, SymInt)): + return NotImplemented + return self.__int_floordiv__(other) + + def __rfloordiv__(self, other): + if isinstance(other, (builtins.float, SymFloat)): + return torch.sym_float(math.floor(other / sym_float(self))) + if not isinstance(other, (builtins.int, SymInt)): + return NotImplemented + return self.__rint_floordiv__(other) + + # nb: complex is impossible to handle correctly lol, with + # negative base and integral float need to diverge semantics and + # just always return complex. Neener neener pretend this problem + # doesn't exist + def __pow__(self, other): + if isinstance(other, (builtins.float, SymFloat)): + return sym_float(self).__pow__(other) + if not isinstance(other, (builtins.int, SymInt)): + return NotImplemented + # Guards! This guard is necessary because we need to know it to + # determine the output type of this operation + if other >= 0: + return self.__pow_by_natural__(other) + else: + # Mercifully, when the exponent is negative, Python just promotes + # to doubles and does a float pow: + # + # if (Py_SIZE(b) < 0 && c == NULL) { + # /* if exponent is negative and there's no modulus: + # return a float. This works because we know + # that this calls float_pow() which converts its + # arguments to double. */ + # Py_DECREF(a); + # Py_DECREF(b); + # return PyFloat_Type.tp_as_number->nb_power(v, w, x); + # } + return sym_float(self).__pow__(sym_float(other)) + + def __rpow__(self, other): + if isinstance(other, (builtins.float, SymFloat)): + return sym_float(self).__rpow__(other) + if not isinstance(other, (builtins.int, SymInt)): + return NotImplemented + if self >= 0: # self is exponent + return self.__rpow_by_natural__(other) + else: + return sym_float(self).__rpow__(sym_float(other)) + def __eq__(self, other: object) -> builtins.bool: raise AssertionError("type stub not overridden") @@ -337,6 +406,24 @@ def __add__(self, other) -> "SymInt": def __mul__(self, other) -> "SymInt": raise AssertionError("type stub not overridden") + def __pow_by_natural__(self, other) -> "SymInt": + raise AssertionError("type stub not overridden") + + def __rpow_by_natural__(self, other) -> "SymInt": + raise AssertionError("type stub not overridden") + + def __int_truediv__(self, other) -> "SymFloat": + raise AssertionError("type stub not overridden") + + def __rint_truediv__(self, other) -> "SymFloat": + raise AssertionError("type stub not overridden") + + def __int_floordiv__(self, other) -> "SymFloat": + raise AssertionError("type stub not overridden") + + def __rint_floordiv__(self, other) -> "SymFloat": + raise AssertionError("type stub not overridden") + def __sym_max__(self, other): raise AssertionError("type stub not overridden") @@ -371,9 +458,43 @@ def __init__(self, node): # class has a field named node that stores SymNode self.node = node + def __truediv__(self, other): + if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): + return NotImplemented + return self.__float_truediv__(sym_float(other)) + + def __rtruediv__(self, other): + if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): + return NotImplemented + return self.__rfloat_truediv__(sym_float(other)) + + def __floordiv__(self, other): + if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): + return NotImplemented + return torch.sym_float(math.floor(self / sym_float(other))) + + def __rfloordiv__(self, other): + if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): + return NotImplemented + return torch.sym_float(math.floor(sym_float(other) / self)) + def __bool__(self): return self.node.bool_() + # Symbolic power does NOT work with negative base, this is to avoid + # potential complex outputs + def __pow__(self, other): + if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): + return NotImplemented + torch._check(self >= 0) + return self.__float_pow__(other) + + def __rpow__(self, other): + if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): + return NotImplemented + torch._check(other >= 0) + return self.__rfloat_pow__(other) + # Magic methods installed by torch.fx.experimental.sym_node def __eq__(self, other: object) -> builtins.bool: @@ -391,6 +512,18 @@ def __le__(self, other) -> builtins.bool: def __ge__(self, other) -> builtins.bool: raise AssertionError("type stub not overridden") + def __float_pow__(self, other) -> "SymFloat": + raise AssertionError("type stub not overridden") + + def __rfloat_pow__(self, other) -> "SymFloat": + raise AssertionError("type stub not overridden") + + def __float_truediv__(self, other) -> "SymFloat": + raise AssertionError("type stub not overridden") + + def __rfloat_truediv__(self, other) -> "SymFloat": + raise AssertionError("type stub not overridden") + def __trunc__(self): raise AssertionError("type stub not overridden") @@ -524,7 +657,12 @@ def sym_int(a): return py_int(a) # type: ignore[operator] def sym_max(a, b): - """ SymInt-aware utility for max().""" + """ + SymInt-aware utility for max which avoids branching on a < b. + Unlike builtins.max(), this only works for int/float, and it always + promotes to float if any argument is float (unlike builtins.max, which + will faithfully preserve the type of the input argument). + """ from .overrides import has_torch_function, handle_torch_function if has_torch_function((a, b)): @@ -532,14 +670,19 @@ def sym_max(a, b): if isinstance(a, (SymInt, SymFloat)): return a.__sym_max__(b) elif isinstance(b, (SymInt, SymFloat)): - # NB: If you actually care about preserving output type exactly - # if you do something like max(0, 0.0), it is NOT sound to treat - # min/max as commutative + # Due to promotion semantics, this is operator is commutative: + # max(1, 1.0) === max(1.0, 1) === 1.0 return b.__sym_max__(a) - return builtins.max(a, b) # type: ignore[operator] + # TODO: Probably can make bool work too, just lazy + assert isinstance(a, (builtins.int, builtins.float)), type(a) + assert isinstance(b, (builtins.int, builtins.float)), type(b) + if isinstance(a, builtins.float) or isinstance(b, builtins.float): + return builtins.float(builtins.max(a, b)) + else: + return builtins.max(a, b) def sym_min(a, b): - """ SymInt-aware utility for max().""" + """ SymInt-aware utility for min().""" from .overrides import has_torch_function, handle_torch_function if has_torch_function((a, b)): @@ -548,7 +691,12 @@ def sym_min(a, b): return a.__sym_min__(b) elif isinstance(b, (SymInt, SymFloat)): return b.__sym_min__(a) - return builtins.min(a, b) # type: ignore[operator] + assert isinstance(a, (builtins.int, builtins.float)), type(a) + assert isinstance(b, (builtins.int, builtins.float)), type(b) + if isinstance(a, builtins.float) or isinstance(b, builtins.float): + return builtins.float(builtins.min(a, b)) + else: + return builtins.min(a, b) # Drop in replacement for math.sqrt, math.sin, math.cos etc current_module = sys.modules[__name__] diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index 8d6dc939fb5c..9a92c238f950 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -1474,10 +1474,15 @@ def deserialize_sym_int(self, s: SymInt) -> Union[int, torch.SymInt]: # Here we force symbols corresponding to SymInts to be at least integers. # Otherwise some expressions that the shape env would otherwise evaluate to False, # e.g., 2*s = 9, can have rational solutions, e.g., 9/2. + # TODO: This is HIGHLY SUSPICIOUS ezyang(May 2024) sym = sym.subs( {s: sympy.Symbol(s.name, integer=True) for s in sym.free_symbols} ) - if isinstance(sym, sympy.Symbol): + # We need to check if the symbol has already been allocated, + # self.symbol_name_to_symbol is not enough because the + # integer-ification of symbols can induce simplification; + # e.g., (2**s0 + 1) // 2 --> s0 when we know s0 is integral + if isinstance(sym, sympy.Symbol) and sym not in self.shape_env.var_to_val: self.symbol_name_to_symbol[val.expr_str] = sym if hint is not None: self.shape_env.add_var_to_val(sym, hint) @@ -1496,7 +1501,7 @@ def deserialize_sym_int(self, s: SymInt) -> Union[int, torch.SymInt]: free_symbols = sym.free_symbols for s in free_symbols: if s.name not in self.symbol_name_to_symbol: - self.symbol_name_to_symbol[s.name] = s + self.symbol_name_to_symbol[s.name] = s # type: ignore[assignment] if vr := self.symbol_name_to_range.get(s.name): self.shape_env.constrain_symbol_range( s, diff --git a/torch/_inductor/bounds.py b/torch/_inductor/bounds.py index 4640ec4dce6b..212b79e35bf9 100644 --- a/torch/_inductor/bounds.py +++ b/torch/_inductor/bounds.py @@ -1,3 +1,4 @@ +import logging import operator from functools import partial from typing import Any, Callable, Dict @@ -11,6 +12,9 @@ from .virtualized import V +log = logging.getLogger(__name__) + + class BoundVars: """ Performs Value Range Analysis on LoopBody's fx graph by calling BoundVars.run() @@ -55,6 +59,7 @@ def get_bounds(self) -> Dict[torch.fx.Node, ValueRanges[Expr]]: with V.set_ops_handler(ValueRangeAnalysis()): interpreter = InterpreterShim(self.loop_body.root_block.graph, submodules) + log.debug("get_bounds:\n%s", self.loop_body.root_block.graph) interpreter.run(V.get_ops_handler(), initial_env=self._bounds) return self._bounds diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index f7b3e7a45d6e..dae72186df00 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -340,6 +340,8 @@ def propagate_scheduler_node(cls, node): DataTypePropagation.propagate_loopbody(node._body) +# This printer contains rules that are supposed to be generic for both C/C++ and +# Python class ExprPrinter(Printer): @staticmethod def paren(string): @@ -369,12 +371,6 @@ def all_in_parens(string): return string return f"({string})" - def _print_Infinity(self, expr): - return "math.inf" - - def _print_NegativeInfinity(self, expr): - return "-math.inf" - def _print_Relational(self, expr): return f" {expr.rel_op} ".join(map(self.paren, map(self._print, expr.args))) @@ -384,11 +380,14 @@ def _print_Mul(self, expr): def _print_Add(self, expr): return " + ".join(map(self.paren, map(self._print, expr.args))) + # NB: this is OK to put here, because Mod is only defined for positive + # numbers, and so across C/Python its behavior is consistent def _print_Mod(self, expr): return " % ".join(map(self.paren, map(self._print, expr.args))) - def _print_FloorDiv(self, expr): - raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}") + def _print_FloatTrueDiv(self, expr): + lhs, rhs = expr.args + return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}" def _print_CleanDiv(self, expr): return self._print_FloorDiv(expr) @@ -399,10 +398,84 @@ def _print_GreaterThan(self, expr): # Go figure... return " >= ".join(map(self.paren, map(self._print, expr.args))) + # NB: The C implementation is injected into codegen at + # torch/_inductor/codegen/wrapper.py def _print_align(self, expr): assert len(expr.args) == 1 return f"align({self._print(expr.args[0])})" + # This must be implemented because sympy will collect x * x into Pow(x, 2), without + # any explicit intervention. We print it just like x * x, notably, we + # never generate sympy.Pow with floats. + # + # NB: this pow by natural, you should never have used builtin sympy.pow + # for FloatPow, and a symbolic exponent should be PowByNatural. These + # means exp is guaranteed to be integer. + def _print_Pow(self, expr): + base, exp = expr.args + base = self._print(base) + assert exp == int(exp), exp + exp = int(exp) + assert exp >= 0 + if exp > 0: + return "*".join([self.paren(base)] * exp) + else: # exp == 0 + return "1" + + # Explicit NotImplemented functions are to prevent default sympy printing + # behavior, which will just barf out ToFloat(...) to your IR. The error + # message is better here because it tells you which printer class it needs + # to go in. + + def _print_ToFloat(self, expr): + raise NotImplementedError(f"_print_ToFloat not implemented for {type(self)}") + + def _print_Infinity(self, expr): + raise NotImplementedError(f"_print_Infinity not implemented for {type(self)}") + + def _print_NegativeInfinity(self, expr): + raise NotImplementedError( + f"_print_NegativeInfinity not implemented for {type(self)}" + ) + + def _print_FloorDiv(self, expr): + raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}") + + def _print_PythonMod(self, expr): + raise NotImplementedError(f"_print_PythonMod not implemented for {type(self)}") + + def _print_IntTrueDiv(self, expr): + raise NotImplementedError(f"_print_IntTrueDiv not implemented for {type(self)}") + + def _print_PowByNatural(self, expr): + raise NotImplementedError( + f"_print_PowByNatural not implemented for {type(self)}" + ) + + def _print_FloatPow(self, expr): + raise NotImplementedError(f"_print_FloatPow not implemented for {type(self)}") + + def _print_TruncToInt(self, expr): + raise NotImplementedError(f"_print_TruncToInt not implemented for {type(self)}") + + def _print_RoundToInt(self, expr): + raise NotImplementedError(f"_print_RoundToInt not implemented for {type(self)}") + + def _print_RoundDecimal(self, expr): + raise NotImplementedError( + f"_print_RoundDecimal not implemented for {type(self)}" + ) + + # NB: Some float operations are INTENTIONALLY not implemented for + # printers. You can implement them as a quick unblock, but it is better + # to ask yourself why we haven't done this computation in the Tensor + # universe instead + + def _print_TruncToFloat(self, expr): + raise NotImplementedError( + f"_print_TruncToFloat not implemented for {type(self)}" + ) + def doprint(self, expr, *, simplify: bool = True): # TODO: why are people passing strings to the printer here :think: if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"): @@ -411,6 +484,10 @@ def doprint(self, expr, *, simplify: bool = True): class PythonPrinter(ExprPrinter): + def _print_ToFloat(self, expr): + assert len(expr.args) == 1 + return f"float({self._print(expr.args[0])})" + def _print_ModularIndexing(self, expr): x, div, mod = expr.args x = self.paren(self.doprint(x)) @@ -420,56 +497,72 @@ def _print_ModularIndexing(self, expr): x = f"({x} // {div})" return f"{x} % {mod}" + def _print_Infinity(self, expr): + return "math.inf" + + def _print_NegativeInfinity(self, expr): + return "-math.inf" + + # WARNING: this is dangerous for Triton, which has C-style modulus + def _print_PythonMod(self, expr): + return " % ".join(map(self.paren, map(self._print, expr.args))) + + # WARNING: this is dangerous for Triton, which has C-style modulus def _print_FloorDiv(self, expr): x, div = expr.args x = self.paren(self.doprint(x)) div = self.paren(self.doprint(div)) return f"({x} // {div})" + # WARNING: this is dangerous for Triton, when lhs, rhs > 2**53, Python + # does a special algorithm + def _print_IntTrueDiv(self, expr): + lhs, rhs = expr.args + return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}" + def _helper_sqrt(self, expr): return f"math.sqrt({self._print(expr)})" def _print_OpaqueUnaryFn_sqrt(self, expr): return self._helper_sqrt(expr.args[0]) - def _print_Pow(self, expr): - # Pow() confuses triton + def _print_FloatPow(self, expr): base, exp = expr.args - # NB: Remember this is sizevar computation! You don't typically - # expect to have to do floating point computation including exponents - # in sizevar compute. Instead of adding support for floating - # point pow, you should make upstream retranslate the Sympy expression - # into Tensor expressions earlier and do that instead. - if exp == 0.5: - return self._helper_sqrt(base) - elif exp == -0.5: - return "1/" + self._helper_sqrt(base) - base = self._print(base) - assert exp == int(exp), exp - exp = int(exp) - if exp > 0: - return "*".join([self.paren(base)] * exp) - elif exp < 0: - return "1/" + self.paren("*".join([self.paren(base)] * abs(exp))) - else: # exp == 0 - return "1" + return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}" + + # TODO: Not sure this works with Triton, even when base/exp are integral + def _print_PowByNatural(self, expr): + base, exp = expr.args + return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}" def _print_floor(self, expr): assert len(expr.args) == 1 return f"math.floor({self._print(expr.args[0])})" - def _print_Trunc(self, expr): + def _print_FloorToInt(self, expr): + assert len(expr.args) == 1 + return f"math.floor({self._print(expr.args[0])})" + + def _print_TruncToInt(self, expr): assert len(expr.args) == 1 + # This also could have been int(), they'll do the same thing for float return f"math.trunc({self._print(expr.args[0])})" def _print_ceiling(self, expr): assert len(expr.args) == 1 return f"math.ceil({self._print(expr.args[0])})" + def _print_CeilToInt(self, expr): + assert len(expr.args) == 1 + return f"math.ceil({self._print(expr.args[0])})" + def _print_Abs(self, expr): assert len(expr.args) == 1 return f"abs({self._print(expr.args[0])})" + # NB: It's expected that we've made explicit any promotion in the sympy + # expression, so it doesn't matter that Python max/min doesn't perform + # promotion def _print_Max(self, expr): assert len(expr.args) >= 2 return f"max({', '.join(map(self._print, expr.args))})" @@ -514,7 +607,7 @@ def _print_OpaqueUnaryFn_atan(self, expr): assert len(expr.args) == 1 return f"math.atan({self._print(expr.args[0])})" - def _print_Round(self, expr): + def _print_RoundToInt(self, expr): assert len(expr.args) == 1 return f"round({self._print(expr.args[0])})" @@ -653,6 +746,29 @@ def remainder(a, b): ) return ops.where(cond, ops.add(r, b), r) + @staticmethod + def trunc_to_int(a, dtype): + return ops.to_dtype(ops.trunc(a), dtype) + + @staticmethod + def floor_to_int(a, dtype): + return ops.to_dtype(ops.floor(a), dtype) + + @staticmethod + def ceil_to_int(a, dtype): + return ops.to_dtype(ops.ceil(a), dtype) + + @staticmethod + def round_to_int(a, dtype): + return ops.to_dtype(ops.round(a), dtype) + + @staticmethod + def int_truediv(a, b): + # TODO: this is wrong + # TODO: an easy bandaid is to generate runtime asserts that it's + # <= 2**53, which is when this equation is correct + return ops.truediv(a, b) + @staticmethod def load_seed(name, offset): return ops.load(name, sympy.Integer(offset)) diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index eabb5bbef470..311781102c3f 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -275,11 +275,11 @@ def visit_modular_indexing(divisor, modulus): original_index = index - div = sympy.Wild("divisor") + div = sympy.Wild("divisor", integer=True) if index.has(FloorDiv): index = index.replace(FloorDiv(var, div), visit_indexing_div) - mod = sympy.Wild("modulus") + mod = sympy.Wild("modulus", integer=True) if index.has(ModularIndexing): index = index.replace(ModularIndexing(var, div, mod), visit_modular_indexing) diff --git a/torch/_inductor/codegen/cpp_utils.py b/torch/_inductor/codegen/cpp_utils.py index 4ab33a5e26dc..aac0c20df0c6 100644 --- a/torch/_inductor/codegen/cpp_utils.py +++ b/torch/_inductor/codegen/cpp_utils.py @@ -100,11 +100,54 @@ def _print_floor(self, expr): r = f"std::floor({self._print(expr.args[0])})" return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r - def _print_Trunc(self, expr): + def _print_FloorToInt(self, expr): assert len(expr.args) == 1 - r = f"std::trunc({self._print(expr.args[0])})" + r = f"std::floor({self._print(expr.args[0])})" return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + def _print_TruncToInt(self, expr): + assert len(expr.args) == 1 + r = f"std::trunc({self._print(expr.args[0])})" + return f"static_cast<{INDEX_TYPE}>({r})" + + def _print_TruncToFloat(self, expr): + assert len(expr.args) == 1 + return f"std::trunc({self._print(expr.args[0])})" + + def _print_ToFloat(self, expr): + assert len(expr.args) == 1 + return f"static_cast({self._print(expr.args[0])})" + + # TODO: This is wrong if one of the inputs is negative. This is hard to + # tickle though, as the inputs are typically positive (and if we can prove + # they are positive, we will have used Mod instead, for which this codegen + # is right). + def _print_PythonMod(self, expr): + return " % ".join(map(self.paren, map(self._print, expr.args))) + + def _print_CMod(self, expr): + return " % ".join(map(self.paren, map(self._print, expr.args))) + + def _print_IntTrueDiv(self, expr): + lhs, rhs = expr.args + # TODO: This is only accurate up to 2**53 + return f"static_cast({self._print(lhs)}) / static_cast({self._print(rhs)})" + + # TODO: PowByNatural: we need to implement our own int-int pow. Do NOT + # use std::pow, that operates on floats + def _print_PowByNatural(self, expr): + raise NotImplementedError( + f"_print_PowByNatural not implemented for {type(self)}" + ) + + def _print_FloatTrueDiv(self, expr): + lhs, rhs = expr.args + return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}" + + def _print_FloatPow(self, expr): + base, exp = expr.args + return f"std::pow({self._print(base)}, {self._print(exp)})" + def _print_Pow(self, expr): # Uses float constants to perform FP div base, exp = expr.args @@ -139,6 +182,11 @@ def _print_ceiling(self, expr): r = f"std::ceil({self._print(expr.args[0])})" return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + def _print_CeilToInt(self, expr): + assert len(expr.args) == 1 + r = f"std::ceil({self._print(expr.args[0])})" + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + def _print_Min(self, expr): args = [self._print(a) for a in expr.args] if len(args) == 2: @@ -200,8 +248,9 @@ def _print_OpaqueUnaryFn_atan(self, expr): def _print_OpaqueUnaryFn_sqrt(self, expr): return f"std::sqrt({self._print(expr.args[0])})" - def _print_Round(self, expr): + def _print_RoundToInt(self, expr): assert len(expr.args) == 1 + # TODO: dispatch to llrint depending on index type return f"std::lrint({self._print(expr.args[0])})" def _print_RoundDecimal(self, expr): diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 4b0ea92f3bf4..f74086615c66 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -272,23 +272,68 @@ def triton_reshape(value: str, old_shape: List[str], new_shape: List[str]): return f"{value}[{', '.join(expand)}]" +# NB: Inheriting from PythonPrinter is somewhat dangerous, because there are a +# number of operators which Triton "implements", but in a way that is +# inconsistent with Python semantics (and consistent with C semantics). We +# must override all of these, or it is potential silent correctness problem class TritonPrinter(PythonPrinter): + def _print_TruncToInt(self, expr): + assert len(expr.args) == 1 + return ( + f"libdevice.trunc({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + ) + + def _print_ToFloat(self, expr): + assert len(expr.args) == 1 + return f"{self.paren(self._print(expr.args[0]))}.to(tl.float64)" + + # TODO: This is wrong if one of the inputs is negative. This is hard to + # tickle though, as the inputs are typically positive (and if we can prove + # they are positive, we will have used Mod instead, for which this codegen + # is right). If you are trying to hit this, maybe try something like + # torch.arange(n, device="cuda") - 1 and then do a modulus on it + def _print_PythonMod(self, expr): + return " % ".join(map(self.paren, map(self._print, expr.args))) + + # TODO: This is wrong, see + # https://github.com/triton-lang/triton/issues/955 + # But for Sympy expressions, things will /mostly/ work out because we + # don't usually deal with negative numbers in the division + def _print_FloorDiv(self, expr): + assert expr.is_integer + x, div = expr.args + x = self.paren(self.doprint(x)) + div = self.paren(self.doprint(div)) + return f"({x} // {div})" + + # TODO: This is wrong, when lhs, rhs > 2**53, Python does a higher + # precision algorithm, which we would need to replicate here + def _print_IntTrueDiv(self, expr): + lhs, rhs = expr.args + return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}" + + # NB: sympy.floor/ceiling produce integers, so we have to do the + # conversion to index dtype def _print_floor(self, expr): assert len(expr.args) == 1 return ( f"libdevice.floor({self._print(expr.args[0])}).to({V.kernel.index_dtype})" ) - def _print_Trunc(self, expr): + def _print_FloorToInt(self, expr): assert len(expr.args) == 1 return ( - f"libdevice.trunc({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + f"libdevice.floor({self._print(expr.args[0])}).to({V.kernel.index_dtype})" ) def _print_ceiling(self, expr): assert len(expr.args) == 1 return f"libdevice.ceil({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + def _print_CeilToInt(self, expr): + assert len(expr.args) == 1 + return f"libdevice.ceil({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + def _helper_sqrt(self, expr): return f"libdevice.sqrt({self._print(expr)}.to(tl.float32))" @@ -359,20 +404,9 @@ def _print_OpaqueUnaryFn_atan(self, expr): assert len(expr.args) == 1 return f"libdevice.atan(({self._print(expr.args[0])}).to(tl.float32))" - def _print_FloorDiv(self, expr): - if expr.is_integer: - return super()._print_FloorDiv(expr) - - x, div = expr.args - x = self.paren(self.doprint(x)) - div = self.paren(self.doprint(div)) - return f"libdevice.floor({x} / {div}).to({V.kernel.index_dtype})" - - def _print_Round(self, expr): + def _print_RoundToInt(self, expr): assert len(expr.args) == 1 - return ( - f"libdevice.llrint({self._print(expr.args[0])}).to({V.kernel.index_dtype})" - ) + return f"libdevice.llrint({self._print(expr.args[0])})" def _print_RoundDecimal(self, expr): assert len(expr.args) == 2 diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 337a7375afa8..abe93686ac83 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -1196,8 +1196,11 @@ def debug(msg): elif is_magic_method(n.target): # TODO: this is sus, it probably should be handled in the # lowerings themselves similarly to sym_size/sym-stride + # https://github.com/pytorch/pytorch/issues/127789 debug("is_magic_method") - if isinstance(n.meta["val"], torch.SymInt): + if isinstance( + n.meta["val"], (torch.SymInt, torch.SymFloat, torch.SymBool) + ): result = n.meta["val"].node.expr else: result = super().run_node(n) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index c46cad5e41e2..e9adfcd19a2d 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -44,7 +44,6 @@ is_boolean_dtype, is_float_dtype, make_channels_last_strides_for, - make_contiguous_strides_for, StrideType, ) from torch._subclasses.fake_tensor import get_schema_info @@ -236,7 +235,7 @@ def ir_node_to_tensor(x, guard_shape=True): if is_storage_and_layout(x): stride = [shape_fn(s) for s in x.get_layout().stride] # type: ignore[misc] else: - stride = make_contiguous_strides_for(size) # type: ignore[arg-type] + stride = FlexibleLayout.contiguous_strides(size) # type: ignore[arg-type] dtype = x.get_dtype() device = x.get_device() size = convert_shape_to_symint(size) @@ -2766,6 +2765,7 @@ class FlexibleLayout(Layout): allow_indexing = False + # WARNING! This doesn't handle zero size tensors correctly @staticmethod def contiguous_strides(sizes): if len(sizes) == 0: @@ -5915,7 +5915,7 @@ def _original_deconv_weight_size( # To align the behavior of the Conv kernel, we set the output_stride in such case to be contiguous instead of channels last. dynamic_shapes = not all(isinstance(i, int) for i in (output_size)) if dynamic_shapes and is_contiguous_storage_and_layout(x): - output_stride = make_contiguous_strides_for(output_size) + output_stride = FlexibleLayout.contiguous_strides(output_size) else: output_stride = make_channels_last_strides_for(output_size) @@ -5967,7 +5967,7 @@ def _prepare_linear_fusion_create( assert x.get_device().type == "cpu" and weight.get_device().type == "cpu" inputs = [x, weight] - output_stride = make_contiguous_strides_for(output_size) + output_stride = FlexibleLayout.contiguous_strides(output_size) kernel_layout = FixedLayout( x.get_device(), x.get_dtype(), @@ -6283,7 +6283,7 @@ def create(cls, x, packed_w, orig_w, B, batch_size): *m, _ = x.get_size() oc, _ = orig_w.get_size() output_size = list(m) + [oc] - output_stride = make_contiguous_strides_for(output_size) + output_stride = FlexibleLayout.contiguous_strides(output_size) inputs = [x, packed_w, orig_w] constant_args = [batch_size] if B is not None: @@ -6601,13 +6601,13 @@ def create( def get_strides_of_lstm_output(output_shape, batch_first): assert len(output_shape) == 3, "Expect output_shape to be 3D" - return make_contiguous_strides_for(output_shape) + return FlexibleLayout.contiguous_strides(output_shape) output_sizes = [output_shape, hy_shape, cy_shape] output_strides = [ get_strides_of_lstm_output(output_shape, batch_first), - make_contiguous_strides_for(hy_shape), - make_contiguous_strides_for(cy_shape), + FlexibleLayout.contiguous_strides(hy_shape), + FlexibleLayout.contiguous_strides(cy_shape), ] output_ir = [ MultiOutput( diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index 42fabf65591d..f3492949a84d 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -5,7 +5,6 @@ from typing import Any, List, Tuple import torch -from torch._prims_common import make_contiguous_strides_for from .. import config from ..ir import ( ComputedBuffer, @@ -389,7 +388,7 @@ def flex_attention(*args, **kwargs): query.get_device(), query.get_dtype(), query.get_size(), - make_contiguous_strides_for(query.get_size()), + FlexibleLayout.contiguous_strides(query.get_size()), ) # see NOTE:[TritonTemplates with multiple outputs] logsumexp_shape = query.get_size()[:-1] # [B, H, M] @@ -745,7 +744,7 @@ def flex_attention_backward(*args, **kwargs): key.get_device(), key.get_dtype(), key.get_size(), - make_contiguous_strides_for(key.get_size()), + FlexibleLayout.contiguous_strides(key.get_size()), ) # Create delta which will is needed for the bwd's kernel diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 0a1909890e69..deec9b13e566 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -34,7 +34,7 @@ Number, ) from torch.fx.experimental.sym_node import magic_methods, method_to_operator -from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing +from torch.utils._sympy.functions import CeilDiv, FloorDiv, IntTrueDiv, ModularIndexing from .._dynamo.utils import import_submodule from . import config, inductor_prims, ir, test_operators # NOQA: F401 @@ -4262,7 +4262,7 @@ def _fractional_pooling_offsets(samples, in_sz, out_sz, kernel_sz, dim): out_sz = out_sz[dim] in_sz = in_sz[dim] kernel_sz = kernel_sz[dim] - alpha = (in_sz - kernel_sz) / (out_sz - 1) + alpha = IntTrueDiv(in_sz - kernel_sz, out_sz - 1) samples_loader = samples.make_loader() def load(prefix, i): @@ -4372,7 +4372,7 @@ def upsample_nearest2d_backward( w_kernel_max = ceildiv(inp_w, out_w) def start_index(index, out_dim, inp_dim): - return CeilDiv(index * inp_dim, out_dim) + return CeilDiv(index * inp_dim, sympy.sympify(out_dim)) def end_index(index, out_dim, inp_dim): return start_index((index + 1), out_dim, inp_dim) diff --git a/torch/_inductor/ops_handler.py b/torch/_inductor/ops_handler.py index 5630061b4426..f88cd948ca4d 100644 --- a/torch/_inductor/ops_handler.py +++ b/torch/_inductor/ops_handler.py @@ -138,6 +138,38 @@ def to_dtype( """ ... + def trunc_to_int(self, x: T, dtype: torch.dtype) -> T: + """ + Convert x to dtype with truncation semantics (similar to how the int + constructor works in Python). In Inductor codegen, this just decays + to trunc and then to_dtype, but this composite operation helps + roundtrips for Sympy evaluation. + + dtype is taken as an explicit parameter because the desired output + dtype is typically the index dtype, which may vary between int32 and + int64 depending on if we've shown that all the indexing operations can + be done in int32. + """ + ... + + def ceil_to_int(self, x: T, dtype: torch.dtype) -> T: + """ + Convert x to dtype with ceiling semantics. See also trunc_to_int. + """ + ... + + def floor_to_int(self, x: T, dtype: torch.dtype) -> T: + """ + Convert x to dtype with ceiling semantics. See also trunc_to_int. + """ + ... + + def round_to_int(self, x: T, dtype: torch.dtype) -> T: + """ + Convert x to dtype with round-to-even semantics. See also trunc_to_int. + """ + ... + def to_dtype_bitcast(self, x: T, dtype: torch.dtype, src_dtype: torch.dtype) -> T: """ Reinterpret cast x to dtype (reinterpreting the bits in memory as another dtype.) @@ -398,21 +430,23 @@ def isinf(self, x0: T) -> T: def isnan(self, x0: T) -> T: ... + # NB: this returns a float, like the torch operation + # This rounds half to even to break ties def round(self, x0: T) -> T: ... + # NB: this returns a float, like the torch operation def floor(self, x0: T) -> T: ... def sign(self, x0: T) -> T: ... - def to_int(self, x0: T) -> T: - ... - + # NB: this returns a float, like the torch operation def trunc(self, x0: T) -> T: ... + # NB: this returns a float, like the torch operation def ceil(self, x0: T) -> T: ... @@ -449,6 +483,7 @@ def sub(self, x0: T, x1: T) -> T: def mul(self, x0: T, x1: T) -> T: ... + # NB: this returns a float, like the torch operation def pow(self, x0: T, x1: T) -> T: ... @@ -617,14 +652,21 @@ def truncdiv(self, x0: T, x1: T) -> T: def floordiv(self, x0: T, x1: T) -> T: """Python-style floor division between integers only. Computes the - true division of two numbers and floors the result. + true division of two numbers and floors the result. If you want + floor division for floats, do regular truediv and floor the result. """ ... def truediv(self, x0: T, x1: T) -> T: - """True division between floats. Integer inputs are NOT valid: to do - Python style (int, int) -> float division, promote the inputs to float - first.""" + """True division between floats. Integer inputs are NOT valid. To + do Python-style (int, int) -> float division, use int_truediv""" + ... + + def int_truediv(self, x0: T, x1: T) -> T: + """True division between integers. This is NOT the same as promoting + to float and doing integer division, there is a bespoke algorithm for + doing the division in higher precision than the above. + """ ... def div(self, x0: T, x1: T) -> T: @@ -640,6 +682,10 @@ def remainder(self, x0: T, x1: T) -> T: """Python-style modulus, take sign from RHS (x1).""" ... + def round_decimal(self, x0: T, x1: T) -> T: + """Python-style round with decimal argument""" + ... + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # In CUDA, optimized implementations of other mathematical operations are # offered separately via libdevice for double precision computation (in diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 5e5cbf35baf9..a1b029aa2883 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -386,7 +386,7 @@ def store_output( assert isinstance(mask, (str, type(None))) assert self.template_mask is None indices = list(map(TritonPrinter.paren, indices)) - index_symbols = [sympy.Symbol(x) for x in indices] + index_symbols = [sympy.Symbol(x, integer=True) for x in indices] lengths = [ V.graph.sizevars.simplify(s) for s in self.output_node.get_size() ] @@ -410,7 +410,7 @@ def store_output( output_index = self.output_node.get_layout().make_indexer()(index_symbols) output_index = self.rename_indexing(output_index) if output_index == contiguous_index: - output_index = sympy.Symbol("xindex") + output_index = sympy.Symbol("xindex", integer=True) epilogue_args = [val] for input_node in itertools.chain( diff --git a/torch/_inductor/sizevars.py b/torch/_inductor/sizevars.py index bc8803a5e715..fba9a66f9237 100644 --- a/torch/_inductor/sizevars.py +++ b/torch/_inductor/sizevars.py @@ -161,9 +161,9 @@ def visit_modular_indexing(base, divisor, modulus): if expr.has(ModularIndexing): expr = expr.replace( ModularIndexing( - sympy.Wild("base"), - sympy.Wild("divisor"), - sympy.Wild("modulus"), + sympy.Wild("base", integer=True), + sympy.Wild("divisor", integer=True), + sympy.Wild("modulus", integer=True), ), visit_modular_indexing, ) @@ -171,8 +171,8 @@ def visit_modular_indexing(base, divisor, modulus): if expr.has(FloorDiv): expr = expr.replace( FloorDiv( - sympy.Wild("base"), - sympy.Wild("divisor"), + sympy.Wild("base", integer=True), + sympy.Wild("divisor", integer=True), ), visit_indexing_div, ) @@ -604,11 +604,11 @@ def _join_dimensions_cached(expr: Expr) -> Expr: """ assert isinstance(expr, sympy.Add) - scale = sympy.Wild("scale", exclude=[0]) - base = sympy.Wild("base") - divisor = sympy.Wild("divisor") - mod1 = sympy.Wild("modulus") - mod2 = sympy.Wild("modulus2") + scale = sympy.Wild("scale", exclude=[0], integer=True) + base = sympy.Wild("base", integer=True) + divisor = sympy.Wild("divisor", integer=True) + mod1 = sympy.Wild("modulus", integer=True) + mod2 = sympy.Wild("modulus2", integer=True) for term1 in expr.args: m1 = term1.match(scale * ModularIndexing(base, divisor, mod1)) if m1: diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 0915a8330c34..a635c2f509c1 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -192,7 +192,7 @@ def ceildiv( numer: Union[int, sympy.Expr], denom: Union[int, sympy.Expr] ) -> Union[int, sympy.Expr]: if isinstance(numer, sympy.Expr) or isinstance(denom, sympy.Expr): - return CeilDiv(numer, denom) + return CeilDiv(sympy.sympify(numer), sympy.sympify(denom)) # TODO: There is a bug in a call to this function, to repro: # python benchmarks/dynamo/huggingface.py --inductor -d cuda --accuracy # --amp --only YituTechConvBert --dynamic-shapes diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 47d4abcf77b9..9343490de3e8 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -1727,7 +1727,7 @@ def go(t, real_t): for run_impl_check, op_impl in op_implementations_checks: if run_impl_check(func): op_impl_out = op_impl(self, func, *args, **kwargs) - if op_impl_out != NotImplemented: + if op_impl_out is not NotImplemented: return maybe_propagate_real_tensors(op_impl_out) def maybe_run_unsafe_fallback(error=None): diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index a7ce337f9ac8..2a3cb62c56d7 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -1200,8 +1200,13 @@ void initJITBindings(PyObject* module) { SYMNODE_BINARY(sub) SYMNODE_BINARY(mul) SYMNODE_BINARY(truediv) + SYMNODE_BINARY(int_truediv) + SYMNODE_BINARY(float_truediv) SYMNODE_BINARY(pow) + SYMNODE_BINARY(float_pow) + SYMNODE_BINARY(pow_by_natural) SYMNODE_BINARY(floordiv) + SYMNODE_BINARY(int_floordiv) SYMNODE_BINARY(mod) SYMNODE_BINARY(eq) SYMNODE_BINARY(ne) diff --git a/torch/csrc/utils/python_symnode.h b/torch/csrc/utils/python_symnode.h index f8c710cf6579..15738b1a67e1 100644 --- a/torch/csrc/utils/python_symnode.h +++ b/torch/csrc/utils/python_symnode.h @@ -198,14 +198,34 @@ class PythonSymNodeImpl : public c10::SymNodeImpl { return dispatch_common_(__func__, other); } + c10::SymNode float_truediv(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + + c10::SymNode int_truediv(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + c10::SymNode pow(const c10::SymNode& other) override { return dispatch_common_(__func__, other); } + c10::SymNode float_pow(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + + c10::SymNode pow_by_natural(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + c10::SymNode floordiv(const c10::SymNode& other) override { return dispatch_common_(__func__, other); } + c10::SymNode int_floordiv(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + c10::SymNode mod(const c10::SymNode& other) override { return dispatch_common_(__func__, other); } diff --git a/torch/export/dynamic_shapes.py b/torch/export/dynamic_shapes.py index a4ed16e975b8..ac2bdd60a550 100644 --- a/torch/export/dynamic_shapes.py +++ b/torch/export/dynamic_shapes.py @@ -1,7 +1,6 @@ import builtins import dataclasses import inspect -import math import sys import weakref from collections import defaultdict @@ -254,11 +253,14 @@ class _Constraint(_ConstraintTarget, metaclass=_ConstraintFactory): shared: Optional[_ConstraintTarget] = None debug_name: Optional[str] = None - def _clone_with_range(self, lower=0, upper=math.inf): + def _clone_with_range(self, lower=0, upper=None): # Import sympy locally from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint from torch.utils._sympy.value_ranges import ValueRanges + if upper is None: + upper = sys.maxsize - 1 + constraint_range = StrictMinMaxConstraint( vr=self.constraint_range.vr & ValueRanges(lower=lower, upper=upper), warn_only=False, @@ -486,7 +488,6 @@ def dynamic_dim(t: torch.Tensor, index: int, debug_name: Optional[str] = None): ) # Import sympy locally - import sympy from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint from torch.utils._sympy.value_ranges import ValueRanges @@ -496,7 +497,7 @@ def dynamic_dim(t: torch.Tensor, index: int, debug_name: Optional[str] = None): id(t), index, StrictMinMaxConstraint( - vr=ValueRanges(lower=0, upper=sympy.oo), warn_only=False + vr=ValueRanges(lower=0, upper=sys.maxsize - 1), warn_only=False ), debug_name=debug_name, ) diff --git a/torch/fx/experimental/recording.py b/torch/fx/experimental/recording.py index 4bf9ebab17b3..28df3fddab0e 100644 --- a/torch/fx/experimental/recording.py +++ b/torch/fx/experimental/recording.py @@ -277,7 +277,13 @@ def wrapper(*args, **kwargs): raise except Exception: - log.error("failed while running %s(*%s, **%s)", name, args[1:], kwargs) + log.error( # noqa: G201 + "failed while running %s(*%s, **%s)", + name, + args[1:], + kwargs, + exc_info=log.isEnabledFor(logging.INFO), + ) raise return wrapper diff --git a/torch/fx/experimental/sym_node.py b/torch/fx/experimental/sym_node.py index 98cba67a73a1..c7f0aba9fac4 100644 --- a/torch/fx/experimental/sym_node.py +++ b/torch/fx/experimental/sym_node.py @@ -267,8 +267,11 @@ def mul(self, other) -> "SymNode": def mod(self, other) -> "SymNode": return self._mod(other) # type: ignore[attr-defined] - def pow(self, other) -> "SymNode": - return self._pow(other) # type: ignore[attr-defined] + def float_pow(self, other) -> "SymNode": + return self._float_pow(other) # type: ignore[attr-defined] + + def pow_by_natural(self, other) -> "SymNode": + return self._pow_by_natural(other) # type: ignore[attr-defined] def and_(self, other) -> "SymNode": return self._and_(other) # type: ignore[attr-defined] @@ -276,11 +279,14 @@ def and_(self, other) -> "SymNode": def or_(self, other) -> "SymNode": return self._or_(other) # type: ignore[attr-defined] - def truediv(self, other) -> "SymNode": - return self._truediv(other) # type: ignore[attr-defined] + def float_truediv(self, other) -> "SymNode": + return self._float_truediv(other) # type: ignore[attr-defined] - def floordiv(self, other) -> "SymNode": - return self._floordiv(other) # type: ignore[attr-defined] + def int_truediv(self, other) -> "SymNode": + return self._int_truediv(other) # type: ignore[attr-defined] + + def int_floordiv(self, other) -> "SymNode": + return self._int_floordiv(other) # type: ignore[attr-defined] def lshift(self, other) -> "SymNode": return self._lshift(other) # type: ignore[attr-defined] @@ -361,6 +367,17 @@ def sym_or(self, other): def sym_and(self, other): return self.and_(other) + # There is no int_truediv available from C++ + def truediv(self, other): + return self.float_truediv(other) + + def floordiv(self, other) -> "SymNode": + return self.int_floordiv(other) + + # We didn't bind integer pow in C++ + def pow(self, other): + return self.float_pow(other) + def is_non_overlapping_and_dense(self, sizes, strides): return self.is_non_overlapping_and_dense_indicator(sizes, strides).eq(to_node(self, 1)) # type: ignore[attr-defined] @@ -477,7 +494,7 @@ def is_constant(self): "eq": operator.eq, "floor": math.floor, "trunc": math.trunc, - "floordiv": operator.floordiv, + "int_floordiv": operator.floordiv, "ge": operator.ge, "gt": operator.gt, "is_integer": lambda x: x.is_integer(), @@ -489,7 +506,8 @@ def is_constant(self): "ne": operator.ne, "neg": operator.neg, "or": operator.or_, - "pow": operator.pow, + "float_pow": operator.pow, + "pow_by_natural": operator.pow, "round": builtins.round, "rshift": operator.rshift, "sub": operator.sub, @@ -498,12 +516,14 @@ def is_constant(self): "sym_max": sym_max, "sym_min": sym_min, "sym_not": sym_not, - "truediv": operator.truediv, + "float_truediv": operator.truediv, + "int_truediv": operator.truediv, } unary_magic_methods = { "abs", "sym_float", + "sym_int", "ceil", "floor", "neg", @@ -559,20 +579,20 @@ def fn(self): bool_magic_methods = only_bool_magic_methods | also_bool_magic_methods # Methods that are only for float -only_float_magic_methods = {"is_integer"} +only_float_magic_methods = {"is_integer", "round", "sym_int"} magic_methods_on_operator_with_trailing_underscore = {"and", "or"} -always_float_magic_methods = {"truediv", "sym_float", "pow"} +always_float_magic_methods = {"int_truediv", "float_truediv", "sym_float", "float_pow"} for name in math_op_names: sym_name = f"sym_{name}" always_float_magic_methods.add(sym_name) -always_int_magic_methods = {"ceil", "floor", "trunc"} +always_int_magic_methods = {"ceil", "floor", "trunc", "pow_by_natural"} always_bool_magic_methods = { "eq", "ne", @@ -590,10 +610,16 @@ def fn(self): # Methods that have a `__foo__` as well as `__rfoo__` -def _sympy_truediv(a, b): - from torch.utils._sympy.functions import TrueDiv +def _sympy_float_truediv(a, b): + from torch.utils._sympy.functions import FloatTrueDiv - return TrueDiv(a, b) + return FloatTrueDiv(a, b) + + +def _sympy_int_truediv(a, b): + from torch.utils._sympy.functions import IntTrueDiv + + return IntTrueDiv(a, b) def _sympy_floordiv(a, b): @@ -603,15 +629,24 @@ def _sympy_floordiv(a, b): def _sympy_mod(a, b): - from torch.utils._sympy.functions import Mod + from torch.utils._sympy.functions import Mod, PythonMod + + if a.is_nonnegative and b.is_nonnegative: + return Mod(a, b) + else: + return PythonMod(a, b) + - return Mod(a, b) +def _sympy_pow_by_natural(a, b): + from torch.utils._sympy.functions import PowByNatural + return PowByNatural(a, b) -def _sympy_pow(a, b): - from torch.utils._sympy.functions import Pow - return Pow(a, b) +def _sympy_float_pow(a, b): + from torch.utils._sympy.functions import FloatPow + + return FloatPow(a, b) def _sympy_and(a, b): @@ -643,11 +678,13 @@ def _sympy_rshift(a, b): "sub": operator.sub, "mul": operator.mul, "mod": _sympy_mod, - "pow": _sympy_pow, + "pow_by_natural": _sympy_pow_by_natural, + "float_pow": _sympy_float_pow, "and": _sympy_and, "or": _sympy_or, - "truediv": _sympy_truediv, - "floordiv": _sympy_floordiv, + "float_truediv": _sympy_float_truediv, + "int_truediv": _sympy_int_truediv, + "int_floordiv": _sympy_floordiv, "lshift": _sympy_lshift, "rshift": _sympy_rshift, } @@ -672,21 +709,23 @@ def _floor_ceil_helper(a, fn): def _sympy_floor(a): - import sympy + from torch.utils._sympy.functions import FloorToInt - return _floor_ceil_helper(a, sympy.floor) + return FloorToInt(a) +# NB: this is Python trunc semantics which returns an int. Do NOT use this to +# represent torch.trunc (which is float to float) def _sympy_trunc(a): - from torch.utils._sympy.functions import Trunc + from torch.utils._sympy.functions import TruncToInt - return Trunc(a) + return TruncToInt(a) def _sympy_ceil(a): - import sympy + from torch.utils._sympy.functions import CeilToInt - return _floor_ceil_helper(a, sympy.ceiling) + return CeilToInt(a) def _sympy_eq(a, b): @@ -771,26 +810,28 @@ def _sympy_abs(a): def _sympy_round(number, ndigits=None): - from torch.utils._sympy.functions import Round, RoundDecimal + from torch.utils._sympy.functions import RoundDecimal, RoundToInt if ndigits is None: - return Round(number) + return RoundToInt(number) else: return RoundDecimal(number, ndigits) def _sympy_sym_float(a): - # Cannot use sympy.Float(a) here, coz it expects python literals - # Multiply by 1.0 to cast to float. This is needed when the input - # is a SymInt which has the assumption that it is integer and - # SymPy will otherwise assume that return value cannot be a float. - return a * 1.0 + from torch.utils._sympy.functions import ToFloat + + # NB: Cannot use a * 1.0 here, because 0 * 1.0 is 0 which incorrectly + # reports that it is an integer + return ToFloat(a) def _sympy_is_integer(a): import sympy - return sympy.Eq(sympy.floor(a), a) + from torch.utils._sympy.functions import ToFloat + + return sympy.Eq(ToFloat(sympy.floor(a)), a) magic_methods = { @@ -989,9 +1030,26 @@ def binary_magic_impl(self, other): self, handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {}) ) assert isinstance(other, SymNode) - # TODO: consider constant prop here try: - out = func(self.expr, other.expr) + if method == "mod": + from torch.utils._sympy.functions import Mod, PythonMod + + # Special handling for mod that requires access to the value + # ranges + shape_env = self.shape_env + if ( + self.expr.is_nonnegative + or shape_env.bound_sympy(self.expr).lower >= 0 + ) and ( + other.expr.is_nonnegative + or shape_env.bound_sympy(other.expr).lower >= 0 + ): + out = Mod(self.expr, other.expr) + else: + out = PythonMod(self.expr, other.expr) + else: + # TODO: consider constant prop here + out = func(self.expr, other.expr) except Exception: log.warning("failed to eval %s(%s, %s)", method, self.expr, other.expr) raise @@ -1122,9 +1180,13 @@ def round_impl(self, ndigits=None): except Exception: log.warning("failed to eval %s(%s, ndigits=%s)", method, expr, ndigits) raise + out = safe_expand(out) - pytype = int if ndigits is None else self.pytype + if ndigits is None: + pytype = int + else: + pytype = self.pytype out_hint = None if self.hint is not None: @@ -1136,6 +1198,7 @@ def round_impl(self, ndigits=None): # hack down below works, because all round function down the line all take ndigits=None as default in their # signature. # TODO: Remove the args construction below if a different sentinel is used by FX. + # ezyang(May 2024): LOL args = [self.fx_node] if ndigits is not None: args.append(ndigits) @@ -1259,6 +1322,32 @@ def is_constant(x): return x.node.is_constant() return False + # Promotion rules for binary operations. NB: we preserve PYTHON semantics + # - if args are same type, do nothing + # - if one arg is float, promote other arg to float + # - nb: this applies to floordiv, even though output is integral + # (it's still float) + # - pow is funny business + # - if both ints + # - trigger a guard on exponent >= 0 + # - if non-negative, output is int + # - otherwise, output is float + # - otherwise, promote other arg to float + # - nb: complex is impossible to handle correctly lol, with + # negative base and integral float need to diverge semantics and + # just always return complex. Neener neener pretend this problem + # doesn't exist + # - equality is pain: Python does the fancy thing where it unpacks the + # mantissa from the float and then compares that against the int. + # Which means it is able to tell that + # 9007199254740993 != 9007199254740992. (rather than if the LHS was + # promoted to float, in which case it would have truncated to the RHS + # and subsequently been equal). We'll model this exactly by having + # special mixed type equality operations. Unfortunately, we need to + # do this for all comparison operations (maybe I'll only implement + # compare) + # - sym_ite mumble mumble really shouldn't allow mixed but whatever + if method in bool_becomes_int_magic_methods: def promote(x): @@ -1272,6 +1361,41 @@ def promote(x): def promote(x): return x + def promote2(self, other): + # TODO: Remove eq and other relations from this list. + # CPython has fancy implementations for these to get as much precision + # as possible instead of just promoting to float64 and praying, so we + # need to handle them specially too. + # Also, note that int_truediv doesn't go through this path: both + # arguments are "int" so there isn't any promotion + if method not in [ + "add", + "sub", + "mul", + "mod", + "float_pow", + "float_truediv", + "int_floordiv", + "sym_min", + "sym_max", + # TODO: remove these + "eq", + "ne", + "gt", + "lt", + "le", + "ge", + ]: + return self, other + f_self = isinstance(self, (float, torch.SymFloat)) + f_other = isinstance(other, (float, torch.SymFloat)) + if f_self or f_other: + if not f_self: + self = torch.sym_float(self) + if not f_other: + other = torch.sym_float(other) + return self, other + # Before and after performing the operation, check if any operands are constant. # If so, extract out the constant values first. If `self` itself is a # constant, then "redispatch" by calling back into the operator. Sometimes @@ -1286,9 +1410,12 @@ def unary_magic_impl(self): return wrap_node(getattr(self.node, method_attr)()) def binary_magic_impl(self, other): + if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)): + return NotImplemented sym_node_log.debug("MAGIC %s %s %s", method, self, other) self = promote(self) other = promote(other) + self, other = promote2(self, other) if is_constant(self): return (method_to_operator(method))(get_constant(self), other) if is_constant(other): @@ -1300,8 +1427,11 @@ def binary_magic_impl(self, other): return get_constant(ret) if is_constant(ret) else ret def rbinary_magic_impl(self, other): + if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)): + return NotImplemented self = promote(self) other = promote(other) + self, other = promote2(self, other) if is_constant(self): return (method_to_operator(method))(get_constant(self), other) if is_constant(other): diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index a2abde3a861e..687d2bcbd1eb 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -61,7 +61,7 @@ from torch import SymBool, SymFloat, SymInt from torch._guards import ShapeGuard, Source, TracingContext from torch.utils._python_dispatch import is_traceable_wrapper_subclass -from torch.utils._sympy.functions import FloorDiv, Mod, IsNonOverlappingAndDenseIndicator +from torch.utils._sympy.functions import FloorDiv, Mod, PythonMod, IsNonOverlappingAndDenseIndicator, CleanDiv from torch.utils._sympy.solve import try_solve from torch.utils._sympy.value_ranges import bound_sympy, SymPyValueRangeAnalysis, ValueRanges, ValueRangeError from torch.utils._sympy.singleton_int import SingletonInt @@ -869,9 +869,9 @@ def constrain_range(a, *, min: Optional[int], max: Optional[int] = None): for N=1. """ if min is None: - min = -sympy.oo + min = -sys.maxsize - 1 if max is None: - max = sympy.oo + max = sys.maxsize - 1 if max < min: raise ValueError( @@ -979,16 +979,6 @@ def eval_guards(gm, *args, ignore_static=True): def bind_symbols(gm, *args): return gm.shape_env.bind_symbols(fx_placeholder_vals(gm), args) -def _assert_bound_is_rational(expr: sympy.Expr, bound: ValueRanges): - """ - We assert that the bounds are either Boolean, or not finite, or can be computed - in exact prevision via rational arithmetic. - The only exception to this is the rare case when the user calls `sqrt(s0)` - sqrt is turned into sympy.Pow so we just match for that (it matches more things, but still) - """ - assert bound.lower.is_rational or bound.lower.is_Boolean or not bound.lower.is_finite or expr.has(sympy.Pow), (bound, expr) - assert bound.upper.is_rational or bound.upper.is_Boolean or not bound.upper.is_finite or expr.has(sympy.Pow), (bound, expr) - class DimDynamic(Enum): """ Controls how to perform symbol allocation for a dimension. It is always @@ -1387,14 +1377,19 @@ def cast_symbool_to_symint_guardless(symbool: torch.SymBool) -> torch.SymInt: 'Min': min, 'Max': max, 'Mod': operator.mod, + 'PythonMod': operator.mod, 'FloorDiv': operator.floordiv, 'TrueDiv': operator.truediv, 'IsNonOverlappingAndDenseIndicator': eval_is_non_overlapping_and_dense, 'floor': math.floor, 'ceiling': math.ceil, + 'FloorToInt': math.floor, + 'CeilToInt': math.ceil, 'cast_symbool_to_symint_guardless': cast_symbool_to_symint_guardless, - 'Round': builtins.round, + 'RoundToInt': builtins.round, 'RoundDecimal': builtins.round, + 'TruncToInt': math.trunc, + 'IntTrueDiv': operator.truediv, } @@ -1642,10 +1637,17 @@ def floor_div_handler(*args): congruence = (base - mod_reduced) % divisor if congruence != 0: self._congruences[s].add(congruence) + # NB: Must not be CleanDiv, it needs to be regular sympy division + # so inequality solver works. This is sort of problematic for + # is_integer tests though haha return (base - mod_reduced) / divisor if expr.has(Mod): expr = expr.replace(Mod, mod_handler) + # 7 // -3 is -3, 7 % -3 is -2, and 7 - (-2) / -3 is -3.0 so negative + # arguments should be OK. + if expr.has(PythonMod): + expr = expr.replace(PythonMod, mod_handler) if expr.has(FloorDiv): expr = expr.replace(FloorDiv, floor_div_handler) return expr @@ -3330,6 +3332,7 @@ def create_unbacked_symfloat(self): self.pending_fresh_unbacked_symbols.append(symbol) self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1) vr = self.var_to_range[symbol] = ValueRanges.unknown() + assert vr.is_float # Create a new FX placeholder and Z3 variable for 'symbol'. fx_node = self._create_fx_placeholder_and_z3var(symbol, float) @@ -3348,6 +3351,7 @@ def create_unbacked_symint(self): self.counter["create_unbacked_symbol"] += 1 self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1) vr = self.var_to_range[symbol] = self._default_unspecified_value_range() + assert vr.is_int # Create a new FX placeholder and Z3 variable for 'symbol'. fx_node = self._create_fx_placeholder_and_z3var(symbol, int) @@ -3371,6 +3375,7 @@ def create_unbacked_symbool(self): self.counter["create_unbacked_symbol"] += 1 self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1) vr = self.var_to_range[symbol] = ValueRanges(0, 1) + assert vr.is_int # Create a new FX placeholder and Z3 variable for 'symbol'. fx_node = self._create_fx_placeholder_and_z3var(symbol, bool) @@ -3516,6 +3521,7 @@ def create_symbol( self.var_to_range[sympy_expr] &= constraint_dim.vr vr = self.var_to_range[sympy_expr] + assert vr.is_int if val not in vr: raise ConstraintViolationError(f"{val} not in range [{vr.lower}, {vr.upper}]") @@ -3524,6 +3530,7 @@ def create_symbol( elif isinstance(val, float): self.var_to_range[sympy_expr] = vr = ValueRanges(-sympy.oo, sympy.oo) range_str = f"[{vr.lower}, {vr.upper}]" + assert vr.is_float else: # Skip var_range logic for SingletonInt # Only used for jagged layout nested tensors @@ -3573,6 +3580,7 @@ def create_symbol( def add_var_to_val(self, expr: sympy.Symbol, val: int): """ Adds a new symbol to the symbolic environment. """ + log.debug("add_var_to_val %s %s", expr, val, stack_info=True) assert expr not in self.var_to_val, f"{expr} already exists" self.var_to_val[expr] = sympy.Integer(val) @@ -4301,7 +4309,8 @@ def bound_sympy(self, expr: sympy.Expr, size_oblivious: bool = False) -> ValueRa # Clamp values of size-like variables for x in self.size_like & var_to_range.keys(): if var_to_range[x] is not None: - var_to_range[x] = ValueRanges(2, sympy.oo) + var_to_range[x] = ValueRanges(2, sys.maxsize - 1) + assert var_to_range[x].is_int return bound_sympy(expr, var_to_range) @_lru_cache @@ -4418,6 +4427,11 @@ def _maybe_evaluate_static( vr = self._default_unspecified_value_range() if size_oblivious and k in self.size_like: lower = max(2, vr.lower) + # This is a bit dodgy: what this means is that there was a + # size-like unbacked symbol whose upper bound < 2. This + # causes... problems. + if lower <= vr.upper: + vr = ValueRanges(lower, vr.upper) else: lower = vr.lower # Don't do anything if we don't have a nontrivial lower bound @@ -4425,10 +4439,17 @@ def _maybe_evaluate_static( # SymInt if ( lower < (-sys.maxsize - 1) // 2 or - (unbacked_only and k in self.var_to_val) + (unbacked_only and k in self.var_to_val) or + not vr.is_int ): new_range_env[k] = vr continue + # The goal is to take our symbols which have various lower bounds + # and reallocate them into new symbols which are exactly positive; + # e.g., if we have s0 in [2, inf], we want to turn it into ess0 in + # [1, inf], where s0 = ess0 + 1. This gives the most information + # to sympy for subsequent simplifications. + # # Positive means >= 1 # Positive - 1 means >= 0 # Positive + lower - 1 means >= lower @@ -4460,6 +4481,14 @@ def replace(expr, repl): self.counter["sympy_recursion_error"] += 1 return None + new_expr = safe_expand(new_expr) + if new_expr.is_number: + return new_expr + + # This is bad to do, the replacement with division leaves us with + # rationals when atom.args[0] is addition, e.g., sympy will happily + # turn (s0 + s1) // 2 into s0 / 2 + s1 / 2. Needless complication! + """ floor_div_replace = {} for atom in new_expr.atoms(FloorDiv): floor_div_replace[atom] = sympy.floor(atom.args[0] / atom.args[1]) @@ -4468,13 +4497,12 @@ def replace(expr, repl): # are still free symbols if new_expr.is_number: return new_expr + """ # Check if the range can solve it statically out = bound_sympy(new_expr, new_range_env) - if expect_rational: - _assert_bound_is_rational(new_expr, out) - if out.is_singleton(): - return out.lower + if out.is_singleton(): + return out.lower return new_expr if unbacked_only else None @@ -4526,7 +4554,7 @@ def simplify(self, expr: "sympy.Expr") -> "sympy.Expr": for fd in expr.atoms(FloorDiv): base, divisor = fd.args if self.replace(Mod(base, divisor)) in self.divisible: - div_replacements[fd] = base / divisor + div_replacements[fd] = CleanDiv(base, divisor) new_expr = expr.xreplace(div_replacements) new_expr = safe_expand(new_expr) new_pows = new_expr.atoms(sympy.Pow) @@ -4670,7 +4698,10 @@ def _set_replacement(self, a: "sympy.Symbol", tgt: "sympy.Expr", msg: str) -> No int_range = ValueRanges(-sys.maxsize - 1, sys.maxsize - 1) def issubset(x, y): - return (x & int_range).issubset(y & int_range) + if x.is_int and y.is_int: + return (x & int_range).issubset(y & int_range) + else: + return x.issubset(y) # First, refine the value range of a based on the computed value range # of tgt. This is always OK to do, even if we decide not to do the @@ -4688,7 +4719,7 @@ def issubset(x, y): b = next(iter(tgt.free_symbols)) # Try to invert the equality r = try_solve(sympy.Eq(a, tgt), b, floordiv_inequality=False) - if r is not None: + if r is not None and all(t.is_integer for t in sympy.preorder_traversal(r[1])): b_bound = self.bound_sympy(r[1]) self.var_to_range[b] = b_bound & self.var_to_range[b] tgt_bound = self.bound_sympy(tgt) @@ -4899,12 +4930,12 @@ def trivial_solve(lhs, rhs): ): # We have Mod(i0, q / c) == 0, which means we can # rewrite i0 as (q / gcd(q, c)) * i1 - d = q / sympy.gcd(q, c) + d = q / sympy.gcd(q, c) # TODO: CleanDiv? i1 = self.create_unbacked_symint().node.expr # Propagate the value ranges. It doesn't really # matter if we use truediv or floordiv, because we # have established divisibility. - self._update_var_to_range(i1, SymPyValueRangeAnalysis.truediv( + self._update_var_to_range(i1, SymPyValueRangeAnalysis.floordiv( self.var_to_range[i0], ValueRanges.wrap(d) )) # Propagate size-like-ness @@ -5341,7 +5372,6 @@ def _refine_ranges(self, expr: sympy.Expr) -> None: lower, upper = vr.lower, vr.upper rhs_vr = bound_sympy(rhs, self.var_to_range) - _assert_bound_is_rational(rhs, rhs_vr) # Let's suppose that we have a preexisting range for x [0, 100]. # Now, we issue a guard x > y, where the range for y is [50, 150]. diff --git a/torch/fx/experimental/validator.py b/torch/fx/experimental/validator.py index 6dcb59db7979..d06b38d60c80 100644 --- a/torch/fx/experimental/validator.py +++ b/torch/fx/experimental/validator.py @@ -216,10 +216,7 @@ def sqrt(self, number: z3.ArithRef) -> z3.ArithRef: def abs(self, number: z3.ArithRef) -> z3.ArithRef: return z3.Abs(number) - def round(self, number: z3.ArithRef, ndigits: Optional[z3.ArithRef] = None) -> z3.ArithRef: - if ndigits is not None: - raise ValueError("round(..., ndigits=) is currently not supported by shape validations.") - + def round_to_int(self, number: z3.ArithRef) -> z3.ArithRef: # Pythons builtin 'round' implements the 'round half to even' strategy # See https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even # z3 has an equivalent z3.fpRoundToIntegral(z3.RoundNearestTiesToEven(), ...), but this only applies to @@ -284,7 +281,7 @@ def wrapper(*args): operator.truediv: lift(ops.div), operator.mod: lift(ops.mod), operator.abs: lift(ops.abs), - builtins.round: lift(ops.round), + builtins.round: lift(ops.round_to_int), # Math module. math.ceil: lift(ops.ceil), @@ -350,6 +347,7 @@ def __init__( self._ops = _Z3Ops(self._validator) def constant(self, value: Any, dtype: torch.dtype) -> z3.ExprRef: + # TODO: Probably OK to relax this and allow lower precision if dtype is torch.int64: return z3.IntVal(int(value)) if dtype is torch.double: @@ -358,6 +356,20 @@ def constant(self, value: Any, dtype: torch.dtype) -> z3.ExprRef: return z3.BoolVal(bool(value)) raise ValueError(f"unsupported dtype (SympyToZ3): {dtype}") + def to_dtype(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: + if dtype == torch.float64: + return z3.ToReal(x) + raise NotImplementedError(f"to_dtype {dtype} NYI") + + def trunc_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: + return z3.ToInt(x) + + def round_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: + return self._ops.round_to_int(x) + + def int_truediv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: + return self._ops.div(numerator, denominator) + def truediv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: return self._ops.div(numerator, denominator) @@ -370,11 +382,17 @@ def div(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: def pow(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef: return self._ops.pow(base, exp) + def pow_by_natural(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef: + return self._ops.pow(base, exp) + def mod(self, p: z3.ArithRef, q: z3.ArithRef) -> z3.ArithRef: return self._ops.mod(p, q) - def round(self, number: z3.ArithRef, ndigits: Optional[z3.ArithRef] = None) -> z3.ArithRef: - return self._ops.round(number, ndigits) + def ceil_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: + return self._ops.ceil(x) + + def floor_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: + return self._ops.floor(x) def __getattr__(self, name: str) -> Any: REPLACEMENT = { diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index 1384261b4512..128ce537c019 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -1,43 +1,78 @@ +import functools import math +import sys import sympy from sympy import S -from sympy.core.logic import fuzzy_and, fuzzy_not, fuzzy_or __all__ = [ "FloorDiv", "ModularIndexing", "CleanDiv", "CeilDiv", - "Pow", - "TrueDiv", + "IntTrueDiv", + "FloatTrueDiv", "LShift", "RShift", "IsNonOverlappingAndDenseIndicator", - "Round", + "RoundToInt", "RoundDecimal", + "ToFloat", + "FloatPow", + "PowByNatural", ] +def _keep_float(f): + @functools.wraps(f) + def inner(*args): + r = f(*args) + if any(isinstance(a, sympy.Float) for a in args) and not isinstance( + r, sympy.Float + ): + r = sympy.Float(float(r)) + return r + + return inner + + def fuzzy_eq(x, y): if None in (x, y): return None return x == y +# It would be nice to have assertions on whether or not inputs is_integer +# However, with bugs like https://github.com/sympy/sympy/issues/26620 sympy +# sometimes inconsistently reports floats an integers. +# +# What we can assume from sympy is that if something is an int, it +# definitely is is_integer, but if it is a float it may or may not +# be is_integer. So we are unable to do strong asserts that things +# are NOT integers. + + +# TODO: In Triton, // rounds to zero, but in Python, it is floor division. +# When we can prove both arguments are non-negative, we should just have a +# GenericFloorDiv (name pending) which can codegen efficiently in Python/C, +# and then PythonFloorDiv and CIntDiv which have the appropriate rounding +# semantics. +# +# Right now, FloorDiv de facto changes behavior if arguments are negative or +# not, this can potentially cause correctness issues. class FloorDiv(sympy.Function): """ We maintain this so that: 1. We can use divisibility guards to simplify FloorDiv(a, b) to a / b. 2. Printing out the expression is nicer (compared to say, representing a//b as (a - a % b) / b) + + NB: This is Python-style floor division, round to -Inf """ nargs = (2,) precedence = 50 # precedence of mul # noqa: F811 - # Default return type for SymPy assumptions. - # https://docs.sympy.org/latest/guides/assumptions.html#implementing-assumptions-handlers - is_real = True + is_integer = True @property def base(self): @@ -52,29 +87,14 @@ def _sympystr(self, printer): divisor = printer.parenthesize(self.divisor, self.precedence) return f"({base}//{divisor})" - # SymPy assumptions based on argument types. - def _eval_is_real(self): - return fuzzy_or([self.base.is_real, self.divisor.is_real]) - - def _eval_is_integer(self): - return fuzzy_and([self.base.is_integer, self.divisor.is_integer]) - # Automatic evaluation. # https://docs.sympy.org/latest/guides/custom-functions.html#best-practices-for-eval @classmethod def eval(cls, base, divisor): - def check_supported_type(x): - if ( - x.is_integer is False and x.is_real is False and x.is_complex - ) or x.is_Boolean: - raise TypeError( - f"unsupported operand type(s) for //: " - f"'{type(base).__name__}' and '{type(divisor).__name__}'" - f", expected integer or real" - ) - - check_supported_type(base) - check_supported_type(divisor) + # python test/test_dynamic_shapes.py -k TestDimConstraints.test_dim_constraints_solve_full + # Assert triggered by inequality solver + # assert base.is_integer, base + # assert divisor.is_integer, divisor # We don't provide the same error message as in Python because SymPy # makes it difficult to check the types. @@ -85,26 +105,22 @@ def check_supported_type(x): return sympy.S.Zero if base.is_integer and divisor == 1: return base - if base.is_real and divisor == 1: - return sympy.floor(base) if base.is_integer and divisor == -1: return sympy.Mul(base, -1) if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer): - return base // divisor - if isinstance(base, (sympy.Integer, sympy.Float)) and isinstance( - divisor, (sympy.Integer, sympy.Float) - ): - return sympy.floor(base / divisor) + return sympy.Integer(int(base) // int(divisor)) if isinstance(base, FloorDiv): return FloorDiv(base.args[0], base.args[1] * divisor) - if isinstance(divisor, sympy.Rational) and divisor.p == 1: - return sympy.floor(base * divisor.q) + # gcd in sympy is over polynomials, so you'll end up with rationals if + # you do this. Don't. + """ if isinstance(base, sympy.Add): for a in base.args: gcd = sympy.gcd(a, divisor) if gcd == divisor: return FloorDiv(base - a, divisor) + a / gcd + """ try: gcd = sympy.gcd(base, divisor) @@ -189,6 +205,19 @@ class Where(sympy.Function): nargs = (3,) + def _eval_is_integer(self): + return True if self.args[1].is_integer and self.args[2].is_integer else None # type: ignore[attr-defined] + + def _eval_is_nonnegative(self): + return ( + True + if self.args[1].is_nonnegative and self.args[2].is_nonnegative # type: ignore[attr-defined] + else None + ) + + def _eval_is_positive(self): + return True if self.args[1].is_positive and self.args[2].is_positive else None # type: ignore[attr-defined] + @classmethod def eval(cls, c, p, q): if c == sympy.true: @@ -197,28 +226,27 @@ def eval(cls, c, p, q): return q -class Mod(sympy.Function): - """ - We maintain this so that we avoid SymPy correctness issues, such as: - https://github.com/sympy/sympy/issues/25146 - """ - +# Python-style modulus: take sign from RHS +class PythonMod(sympy.Function): nargs = (2,) + is_integer = True + @classmethod def eval(cls, p, q): - # This was adapted from: sympy/core/mod.py + # python test/dynamo/test_export.py -k ExportTests.test_trivial_constraint + # Triggered by sympy.solvers.inequalities.reduce_inequalities + # assert p.is_integer, p + # assert q.is_integer, q if q.is_zero: raise ZeroDivisionError("Modulo by zero") - # If either of them is NaN or infinite. - if p is S.NaN or q is S.NaN or p.is_finite is False or q.is_finite is False: - return S.NaN + # Three cases: # 1. p == 0 # 2. p is either q or -q # 3. p is integer and q == 1 - if p is S.Zero or p in (q, -q) or (p.is_integer and q == 1): + if p is S.Zero or p in (q, -q) or q == 1: return S.Zero # Evaluate if they are both literals. @@ -247,10 +275,7 @@ def eval(cls, p, q): if sympy.Mod(p, q) == 0: return S.Zero - def _eval_is_integer(self): - p, q = self.args - return fuzzy_and([p.is_integer, q.is_integer, fuzzy_not(q.is_zero)]) # type: ignore[attr-defined] - + # NB: args[1] for PythonMod def _eval_is_nonnegative(self): return True if self.args[1].is_positive else None # type: ignore[attr-defined] @@ -258,6 +283,58 @@ def _eval_is_nonpositive(self): return True if self.args[1].is_negative else None # type: ignore[attr-defined] +# Generic modulus: only defined on non-negative arguments +class Mod(sympy.Function): + nargs = (2,) + + is_integer = True + is_nonnegative = True + + @classmethod + def eval(cls, p, q): + # This was adapted from: sympy/core/mod.py + + # Triggered by + # python test/test_dynamic_shapes.py -k TestDimConstraints.test_dim_constraints_solve_full + # assert p.is_integer, p + # assert q.is_integer, q + + if q.is_zero: + raise ZeroDivisionError("Modulo by zero") + + # Three cases: + # 1. p == 0 + # 2. p is either q or -q + # 3. p is integer and q == 1 + if p is S.Zero or p in (q, -q) or q == 1: + return S.Zero + + # Evaluate if they are both literals. + if q.is_Number and p.is_Number: + assert p >= 0, p + assert q >= 1, q + return p % q + + # If q == 2, it's a matter of whether p is odd or even. + if q.is_Number and q == 2: + if p.is_even: + return S.Zero + if p.is_odd: + return S.One + + # If p is a multiple of q. + r = p / q + if r.is_integer: + return S.Zero + + # If p < q and its ratio is positive, then: + # - floor(p / q) = 0 + # - p % q = p - floor(p / q) * q = p + less = p < q + if less.is_Boolean and bool(less) and r.is_positive: + return p + + class CleanDiv(FloorDiv): """ Div where we can assume no rounding. @@ -267,6 +344,36 @@ class CleanDiv(FloorDiv): pass +# Don't use sympy ceiling/floor as they will attempt simplifications involving +# frac +class CeilToInt(sympy.Function): + is_integer = True + + @classmethod + def eval(cls, number): + # assert number.is_integer is not True, number + if number == sympy.oo: + return sympy.Integer(sys.maxsize - 1) + if number == -sympy.oo: + return sympy.Integer(-sys.maxsize - 1) + if isinstance(number, sympy.Number): + return sympy.Integer(math.ceil(float(number))) + + +class FloorToInt(sympy.Function): + is_integer = True + + @classmethod + def eval(cls, number): + # assert number.is_integer is not True, number + if number == sympy.oo: + return sympy.Integer(sys.maxsize - 1) + if number == -sympy.oo: + return sympy.Integer(-sys.maxsize - 1) + if isinstance(number, sympy.Number): + return sympy.Integer(math.floor(float(number))) + + class CeilDiv(sympy.Function): """ Div used in indexing that rounds up. @@ -275,6 +382,8 @@ class CeilDiv(sympy.Function): is_integer = True def __new__(cls, base, divisor): + base = sympy.sympify(base) + divisor = sympy.sympify(divisor) if sympy.gcd(base, divisor) == divisor: return CleanDiv(base, divisor) else: @@ -282,6 +391,8 @@ def __new__(cls, base, divisor): class LShift(sympy.Function): + is_integer = True + @classmethod def eval(cls, base, shift): if shift < 0: @@ -290,6 +401,8 @@ def eval(cls, base, shift): class RShift(sympy.Function): + is_integer = True + @classmethod def eval(cls, base, shift): if shift < 0: @@ -297,28 +410,107 @@ def eval(cls, base, shift): return base // 2**shift -# Overloaded to be compatible with regular Python. -# https://github.com/pytorch/pytorch/issues/90900 -class Pow(sympy.Function): +def safe_pow(base, exp): + sign = 1 + if base < 0: + base = -base + sign = 1 if exp % 2 == 0 else -1 + return sign * _safe_pow(base, exp) + + +def _safe_pow(base, exponent): + if exponent < 0: + raise ValueError("Exponent must be non-negative.") + + if exponent == 0: + return 1 + + half_exp = safe_pow(base, exponent // 2) + if half_exp > sys.maxsize - 1: + return sys.maxsize - 1 + + result = half_exp * half_exp + if result > sys.maxsize - 1: + return sys.maxsize - 1 + + if exponent % 2 == 1: + result *= base + if result > sys.maxsize - 1: + return sys.maxsize - 1 + + return result + + +class PowByNatural(sympy.Function): + is_integer = True + @classmethod def eval(cls, base, exp): - if exp.is_zero: - return sympy.Integer(1) - elif base.is_zero and exp < 0: - raise ZeroDivisionError(f"{base} cannot be raised to a negative power") - else: - return base**exp + if isinstance(base, sympy.Number) and isinstance(exp, sympy.Number): + return sympy.Integer(safe_pow(base, exp)) + if isinstance(exp, sympy.Integer): + # Translate power into iterated multiplication + r = sympy.Integer(1) + for _ in range(int(exp)): + r *= base + return r + # NB: do NOT translate into sympy.Pow, we will lose knowledge that exp + # is a natural number if we do + + +# base is assumed to be nonnegative, thereby prevent complex numbers from +# occuring +class FloatPow(sympy.Function): + is_integer = False + is_real = True + + @classmethod + def eval(cls, base, exp): + if isinstance(base, sympy.Number) and isinstance(exp, sympy.Number): + return sympy.Float(float(base) ** float(exp)) + # NB: do not do any nontrivial reasoning # Overloaded to be compatible with regular Python. # https://github.com/pytorch/pytorch/issues/90900 -class TrueDiv(sympy.Function): +# +# In particular, sympy division is willing to simplify x/x == 1 +# where 1 is an integer, but this must be a float if x was float. +class FloatTrueDiv(sympy.Function): + is_integer = False + is_real = True + @classmethod def eval(cls, base, divisor): + # assert base.is_integer is not True, base + # assert divisor.is_integer is not True, divisor + if divisor.is_zero: raise ZeroDivisionError("division by zero") - else: - return base / divisor + + if isinstance(base, sympy.Number) and isinstance(divisor, sympy.Number): + return sympy.Float(float(base) / float(divisor)) + + +# Overloaded to be compatible with regular Python. We distinguish this from +# FloatTrueDiv, because the code generation has to be different for this case: +# Python has a fancy algorithm for integer true division that isn't just +# "promote both arguments to float and use float division", so you need to +# codegen it differently. While technically you can work it out from the +# types of the input, this is often inconvenient to do in Inductor codegen, +# so just have a different operator +# NB: Right now, Inductor codegen doesn't implement this correctly lol +class IntTrueDiv(sympy.Function): + is_integer = False + is_real = True + + @classmethod + def eval(cls, base, divisor): + if divisor.is_zero: + raise ZeroDivisionError("division by zero") + + if isinstance(base, sympy.Number) and isinstance(divisor, sympy.Number): + return sympy.Float(int(base) / int(divisor)) # TODO: As an indicator, this != 0 implies == 1 (and vice versa). @@ -353,45 +545,85 @@ def eval(cls, *args): return None -class Trunc(sympy.Function): +# NB: this is inconsistent with math.trunc in Python +class TruncToFloat(sympy.Function): + is_integer = False + is_real = True + + @classmethod + def eval(cls, number): + # assert number.is_integer is not True, number + if isinstance(number, sympy.Number): + # NB: It is safe to use truncation to integer, which is what + # math.trunc does, as Python integers are arbitrary precision and + # so we are guaranteed not to lose precision when we do this + return sympy.Float(math.trunc(float(number))) + + +class TruncToInt(sympy.Function): is_integer = True @classmethod def eval(cls, number): - if number.is_integer: - return number - elif isinstance(number, sympy.Number): + # assert number.is_integer is not True, number + if number == sympy.oo: + return sympy.Integer(sys.maxsize - 1) + if number == -sympy.oo: + return sympy.Integer(-sys.maxsize - 1) + if isinstance(number, sympy.Number): return sympy.Integer(math.trunc(float(number))) -class Round(sympy.Function): +# This is float -> int +class RoundToInt(sympy.Function): is_integer = True @classmethod def eval(cls, number): - if number.is_integer: - return number - elif isinstance(number, sympy.Number): - return sympy.Integer(round(float(number))) + # assert number.is_integer is not True, number + + if isinstance(number, sympy.Float): + return sympy.Integer(round(float(number), 0)) - def __int__(self): - # This will only ever be called when computing size hints. At that point, self.args[0] should be a number and - # no longer an expression. If it were, the float call would fail and the caller would handle this further. - return round(float(self.args[0])) # type: ignore[arg-type] +# To get float -> int, Python style round semantics. +# +# x = PyFloat_AsDouble(self); +# if (o_ndigits == Py_None) { +# /* single-argument round or with None ndigits: +# * round to nearest integer */ +# rounded = round(x); +# if (fabs(x-rounded) == 0.5) +# /* halfway case: round to even */ +# rounded = 2.0*round(x/2.0); +# return PyLong_FromDouble(rounded); +# } + +# NB: Like Round, this only ever returns floats. ndigits cannot be None class RoundDecimal(sympy.Function): + is_integer = False + is_real = True + @classmethod def eval(cls, number, ndigits): - if number.is_integer and ndigits >= 0: + # assert number.is_integer is not True, number + + if isinstance(number, sympy.Float) and isinstance(ndigits, sympy.Integer): + return sympy.Float(round(float(number), int(ndigits))) + + +class ToFloat(sympy.Function): + is_integer = False + is_real = True + + @classmethod + def eval(cls, number): + if number in [sympy.oo, -sympy.oo]: return number - elif isinstance(number, sympy.Number) and isinstance(ndigits, sympy.Integer): - value_type, output_type = ( - (int, sympy.Integer) - if isinstance(number, sympy.Integer) - else (float, sympy.Float) - ) - return output_type(round(value_type(number), int(ndigits))) + + if isinstance(number, sympy.Integer): + return sympy.Float(int(number)) def make_opaque_unary_fn(name): diff --git a/torch/utils/_sympy/interp.py b/torch/utils/_sympy/interp.py index 806e91cfe281..09a4b8384749 100644 --- a/torch/utils/_sympy/interp.py +++ b/torch/utils/_sympy/interp.py @@ -15,16 +15,23 @@ import torch from .functions import ( + CeilToInt, CleanDiv, + FloatPow, + FloatTrueDiv, FloorDiv, + FloorToInt, + IntTrueDiv, IsNonOverlappingAndDenseIndicator, Mod, ModularIndexing, - Pow, - Round, + PowByNatural, + PythonMod, RoundDecimal, - TrueDiv, - Trunc, + RoundToInt, + ToFloat, + TruncToFloat, + TruncToInt, Where, ) @@ -49,30 +56,39 @@ def handlers(): sympy.Le: "le", sympy.Ge: "ge", sympy.Not: "not_", - TrueDiv: "truediv", + IntTrueDiv: "int_truediv", + FloatTrueDiv: "truediv", FloorDiv: "floordiv", - CleanDiv: "div", - Trunc: "trunc", + CleanDiv: "floordiv", # TODO: hmm? + TruncToFloat: "trunc", Where: "where", sympy.Add: "add", sympy.Mul: "mul", - Pow: "pow", - sympy.Pow: "pow", + FloatPow: "pow", + PowByNatural: "pow_by_natural", + # sympy simplifies x * x into Pow(x, 2), so we need to handle this. + # Do NOT use builtin Pow for floats + # TODO: There is a hazard here, if we have float * float it will + # also get turned into Pow(float, 2) but we don't want this because + # pow_by_natural is assumed to only be integers. Probably the fix is + # to add a FloatMul to impede this optimization + sympy.Pow: "pow_by_natural", Mod: "mod", + PythonMod: "mod", # TODO: this is wrong + # TODO: Inductor can generate these, but it's ill-specified which + # semantics were intended here. Needs to be cleaned up along with + # FloorDiv in a bigger cleanup sympy.Mod: "mod", sympy.Abs: "abs", sympy.log: "log", sympy.exp: "exp", - sympy.floor: "floor", - sympy.ceiling: "ceil", sympy.Min: "minimum", sympy.Max: "maximum", ModularIndexing: "modular_indexing", sympy.functions.elementary.piecewise.ExprCondPair: "expr_cond_pair", sympy.Piecewise: "piecewise", IsNonOverlappingAndDenseIndicator: "is_non_overlapping_and_dense_indicator", - Round: "round", - RoundDecimal: "round", + RoundDecimal: "round_decimal", } for name in ["cos", "sin", "tan", "sinh", "cosh", "tanh", "asin", "acos", "atan"]: HANDLERS[getattr(sympy, name)] = name @@ -84,7 +100,11 @@ def handlers(): def sympy_interp( - analysis, env: Dict[sympy.Symbol, Any], expr: Union[sympy.Expr, SympyBoolean] + analysis, + env: Dict[sympy.Symbol, Any], + expr: Union[sympy.Expr, SympyBoolean], + *, + index_dtype=torch.int64, ): # Handle base cases dtype = None @@ -105,9 +125,32 @@ def sympy_interp( expr.args[1], sympy.core.numbers.Half ): return analysis.sqrt(sympy_interp(analysis, env, expr.args[0])) + if isinstance(expr, ToFloat): + return analysis.to_dtype( + sympy_interp(analysis, env, expr.args[0]), torch.float64 + ) # Recursive case args = [sympy_interp(analysis, env, arg) for arg in expr.args] # type: ignore[arg-type] + + # These handlers are special because they take an extra dtype argument + # specifying what they should convert to, and we need to appropriately set + # this up when we convert from Sympy. A reasonable default when you + # are translating is to conservatively do int64, and then narrow these + # arguments later when you discover you can narrow the index range. But + # if you already know that 32-bit indexing is OK, you can directly do the + # sympy translation with index_dtype=torch.int32 + INDEX_DTYPE_HANDLERS = { + TruncToInt: "trunc_to_int", + sympy.floor: "floor_to_int", + sympy.ceiling: "ceil_to_int", + FloorToInt: "floor_to_int", + CeilToInt: "ceil_to_int", + RoundToInt: "round_to_int", + } + if (handler_name := INDEX_DTYPE_HANDLERS.get(expr.func)) is not None: + return getattr(analysis, handler_name)(*args, index_dtype) + if hasattr(expr.func, "_torch_handler_name"): handler_name = expr.func._torch_handler_name else: diff --git a/torch/utils/_sympy/reference.py b/torch/utils/_sympy/reference.py index 881b9d616eb5..b54a0d0503a1 100644 --- a/torch/utils/_sympy/reference.py +++ b/torch/utils/_sympy/reference.py @@ -1,12 +1,25 @@ import math +import operator + import sympy import torch from torch.utils._sympy.functions import ( + _keep_float, + FloatPow, + FloatTrueDiv, + FloorDiv, + IntTrueDiv, + Mod, OpaqueUnaryFn_exp, OpaqueUnaryFn_log, OpaqueUnaryFn_sqrt, + PowByNatural, + RoundDecimal, + RoundToInt, + ToFloat, + TruncToInt, ) @@ -62,20 +75,41 @@ def not_(a): @staticmethod def reciprocal(x): - return 1 / x + return FloatTrueDiv(1.0, x) @staticmethod def square(x): - return x * x + return PowByNatural(x, 2) + + @staticmethod + def trunc_to_int(x, dtype): + return TruncToInt(x) + + @staticmethod + def ceil_to_int(x, dtype): + return sympy.ceiling(x) + + @staticmethod + def floor_to_int(x, dtype): + return sympy.floor(x) + + @staticmethod + def floor(x): + return _keep_float(sympy.floor)(x) + + @staticmethod + def ceil(x): + return _keep_float(sympy.ceiling)(x) + + @staticmethod + def to_dtype(x, dtype): + if dtype == torch.float64: + return ToFloat(x) + raise NotImplementedError(f"to_dtype {dtype} NYI") @staticmethod def mod(x, y): - ret = abs(x) % abs(y) - # without check: - # tracing will fail trying to go through control-flow if x is Proxy() - if isinstance(x, (int, sympy.Number)) and x < 0: - ret *= -1 - return ret + return Mod(x, y) @staticmethod def abs(x): @@ -87,37 +121,31 @@ def neg(x): @staticmethod def truediv(a, b): - return a / b + return FloatTrueDiv(a, b) @staticmethod - def div(a, b): - return ReferenceAnalysis.truediv(a, b) + def int_truediv(a, b): + return IntTrueDiv(a, b) @staticmethod def floordiv(a, b): - if b == 0: - return sympy.nan if a == 0 else sympy.zoo - return a // b + return FloorDiv(a, b) @staticmethod def truncdiv(a, b): - result = a / b - if result.is_finite: - result = sympy.Integer(result) - - return result + raise NotImplementedError("TODO: truncdiv") @staticmethod def add(a, b): - return a + b + return _keep_float(operator.add)(a, b) @staticmethod def mul(a, b): - return a * b + return _keep_float(operator.mul)(a, b) @staticmethod def sub(a, b): - return a - b + return _keep_float(operator.sub)(a, b) @staticmethod def exp(x): @@ -133,39 +161,27 @@ def sqrt(x): @staticmethod def pow(a, b): - return a**b + return _keep_float(FloatPow)(a, b) + + @staticmethod + def pow_by_natural(a, b): + return PowByNatural(a, b) @staticmethod def minimum(a, b): - # Poorman's version of upcasting in Sympy - # This won't do for sympy.Expr as the casting does nothing for those - if a.is_Float or not a.is_finite or b.is_Float or not b.is_finite: - result_type = sympy.Float - else: - assert a.is_Integer - assert b.is_Integer - result_type = sympy.Integer - return sympy.Min(result_type(a), result_type(b)) + return sympy.Min(a, b) @staticmethod def maximum(a, b): - # Poorman's version of upcasting in Sympy - # This won't do for sympy.Expr as the casting does nothing for those - if a.is_Float or not a.is_finite or b.is_Float or not b.is_finite: - result_type = sympy.Float - else: - assert a.is_Integer - assert b.is_Integer - result_type = sympy.Integer - return sympy.Max(result_type(a), result_type(b)) + return sympy.Max(a, b) @staticmethod - def floor(x): - return sympy.floor(x) + def round_to_int(a, dtype): + return RoundToInt(a) @staticmethod - def ceil(x): - return sympy.ceiling(x) + def round_decimal(a, b): + return RoundDecimal(a, b) # Unlike ReferenceAnalysis, does NOT sympyify, instead, works with plain @@ -191,10 +207,20 @@ def not_(a): def floordiv(a, b): return a // b + @staticmethod + def mod(x, y): + return x % y + @staticmethod def truncdiv(a, b): return a / b + @staticmethod + def to_dtype(x, dtype): + if dtype == torch.float64: + return float(x) + raise NotImplementedError(f"to_dtype {dtype} NYI") + @staticmethod def exp(x): raise AssertionError("exp is not valid shape sympy expr") @@ -216,9 +242,40 @@ def maximum(a, b): return torch.sym_max(a, b) @staticmethod - def floor(x): + def floor_to_int(x, dtype): return math.floor(x) @staticmethod - def ceil(x): + def ceil_to_int(x, dtype): return math.ceil(x) + + @staticmethod + def floor(x): + return float(math.floor(x)) + + @staticmethod + def ceil(x): + return float(math.ceil(x)) + + @staticmethod + def truediv(a, b): + return a / b + + @staticmethod + def pow(a, b): + return a**b + + @staticmethod + def pow_by_natural(a, b): + # Pray that safe_pow is not needed here lol. In particular, this + # never participates in VR low/high ranges, so overflow should be + # unlikely + return a**b + + @staticmethod + def round_to_int(a, dtype): + return round(a) + + @staticmethod + def round_decimal(a, b): + return round(a, ndigits=b) diff --git a/torch/utils/_sympy/solve.py b/torch/utils/_sympy/solve.py index 6276c696293c..02ddf7c34219 100644 --- a/torch/utils/_sympy/solve.py +++ b/torch/utils/_sympy/solve.py @@ -88,6 +88,7 @@ def try_solve( # Return if we were able to isolate 'thing' on the left-hand side. if isinstance(e, sympy.Rel) and e.lhs == thing: + log.debug("solved: %s ---> %s", expr, e) return e, e.rhs return None diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py index c7cc96beb980..4d364d4981b5 100644 --- a/torch/utils/_sympy/value_ranges.py +++ b/torch/utils/_sympy/value_ranges.py @@ -5,6 +5,7 @@ import logging import math import operator +import sys from typing import ( Callable, Dict, @@ -25,17 +26,20 @@ from torch._prims_common import dtype_to_type from .functions import ( - OpaqueUnaryFn_acos, - OpaqueUnaryFn_asinh, - OpaqueUnaryFn_atan, - OpaqueUnaryFn_cosh, + _keep_float, + FloatTrueDiv, + FloorDiv, + IntTrueDiv, OpaqueUnaryFn_exp, OpaqueUnaryFn_log, - OpaqueUnaryFn_sinh, OpaqueUnaryFn_sqrt, - OpaqueUnaryFn_tanh, - Round, + PowByNatural, RoundDecimal, + RoundToInt, + safe_pow, + ToFloat, + TruncToFloat, + TruncToInt, ) from .interp import sympy_interp @@ -120,6 +124,8 @@ class ValueRanges(Generic[_T]): lower: _T upper: _T is_bool: bool + is_int: bool + is_float: bool @overload def __init__(self: ValueRanges[sympy.Expr], lower: ExprIn, upper: ExprIn) -> None: @@ -142,8 +148,39 @@ def __init__(self, lower: AllIn, upper: AllIn) -> None: # Because this is a frozen class object.__setattr__(self, "lower", lower) object.__setattr__(self, "upper", upper) + # Unlike bool/int in Python, we don't report bools are ints object.__setattr__(self, "is_bool", isinstance(lower, SympyBoolean)) - assert isinstance(upper, SympyBoolean) == self.is_bool + if self.is_bool: + assert isinstance(upper, SympyBoolean), (lower, upper) + + # Warning: is_int/is_float is best effort. We do pretty well in + # Dynamo, but in Inductor these attributes are often wrong because we + # are not very rigorous in dtype analysis. This is also why we need + # the flexible analysis for is_int: sometimes a sympy.oo pops in for + # an integer bound. I would /like/ for us not to do this, but it's + # too hard to push the invariant through right now. + + object.__setattr__( + self, + "is_int", + not self.is_bool + and (isinstance(lower, sympy.Integer) or isinstance(upper, sympy.Integer)), + ) + """ + # This assert is just impossible right now, too many sympy bugs + if self.is_int: + # NB: sympy will sometimes randomly lose the float-ness of zero, + # so we also need to account for that in the assertion here. + # See also https://github.com/sympy/sympy/issues/26620 + assert isinstance(lower, sympy.Integer) or lower in [-sympy.oo, 0], ( + lower, + upper, + ) + assert isinstance(upper, sympy.Integer) or upper in [sympy.oo, 0], (lower, upper) + """ + # NB: [-oo, oo] always advertises as float! + object.__setattr__(self, "is_float", not self.is_bool and not self.is_int) + assert self.is_bool or self.is_int or self.is_float, (lower, upper) def boolify(self) -> ValueRanges[SympyBoolean]: if vr_is_bool(self): @@ -184,6 +221,8 @@ def __and__(self: AllVR, other: AllVR) -> AllVR: if self == ValueRanges.unknown(): return other assert self.is_bool == other.is_bool, (self, other) + assert self.is_int == other.is_int, (self, other) + assert self.is_float == other.is_float, (self, other) if self.is_bool: return ValueRanges( sympy.Or(self.lower, other.lower), sympy.And(self.upper, other.upper) @@ -353,7 +392,12 @@ def constant(value, dtype): # using nan makes subsequent computation throw, and for the purposes of optimization # returning -math.inf - math.inf is equivalent to giving up if isinstance(value, SupportsFloat) and math.isnan(value): - return ValueRanges.unknown() + if dtype == torch.bool: + return ValueRanges.unknown_bool() + elif dtype.is_floating_point: + return ValueRanges.unknown() + else: + return ValueRanges(-sys.maxsize - 1, sys.maxsize) if is_python: type_ = dtype_to_type(dtype) @@ -369,7 +413,18 @@ def constant(value, dtype): # dtype is intXX assert value.is_integer - return ValueRanges.wrap(value) + r = ValueRanges.wrap(value) + return r + + @staticmethod + def to_dtype(a, dtype, src_dtype=None): + if dtype == torch.float64: + return ValueRanges.increasing_map(a, ToFloat) + return ValueRanges.unknown() + + @staticmethod + def trunc_to_int(a, dtype): + return ValueRanges.increasing_map(a, TruncToInt) @staticmethod def not_(a): @@ -428,7 +483,9 @@ def ge(cls, a, b): @staticmethod def add(a, b): - return ValueRanges.coordinatewise_increasing_map(a, b, operator.add) + return ValueRanges.coordinatewise_increasing_map( + a, b, _keep_float(operator.add) + ) @classmethod def mul(cls, a, b): @@ -448,11 +505,20 @@ def safe_mul(a, b): else: return a * b - return ValueRanges.coordinatewise_monotone_map(a, b, safe_mul) + return ValueRanges.coordinatewise_monotone_map(a, b, _keep_float(safe_mul)) - @classmethod - def div(cls, a, b): - return cls.truediv(a, b) + @staticmethod + def int_truediv(a, b): + a = ValueRanges.wrap(a) + b = ValueRanges.wrap(b) + if 0 in b or ( + (-sympy.oo in a or sympy.oo in a) and (-sympy.oo in b or sympy.oo in b) + ): + return ValueRanges.unknown() + else: + return ValueRanges.coordinatewise_monotone_map( + a, b, _keep_float(IntTrueDiv) + ) @staticmethod def truediv(a, b): @@ -463,18 +529,22 @@ def truediv(a, b): ): return ValueRanges.unknown() else: - return ValueRanges.coordinatewise_monotone_map(a, b, operator.truediv) + return ValueRanges.coordinatewise_monotone_map( + a, b, _keep_float(FloatTrueDiv) + ) @staticmethod def floordiv(a, b): a = ValueRanges.wrap(a) b = ValueRanges.wrap(b) if 0 in b or ( - (-sympy.oo in a or sympy.oo in a) and (-sympy.oo in b or sympy.oo in b) + # TODO: make this more precise + (-sympy.oo in a or sympy.oo in a) + or (-sympy.oo in b or sympy.oo in b) ): return ValueRanges.unknown() else: - return ValueRanges.coordinatewise_monotone_map(a, b, operator.floordiv) + return ValueRanges.coordinatewise_monotone_map(a, b, FloorDiv) @classmethod def mod(cls, x, y): @@ -523,17 +593,51 @@ def modular_indexing(cls, a, b, c): @classmethod def is_non_overlapping_and_dense_indicator(cls, *args): - return ValueRanges.unknown() + return ValueRanges.unknown() # TODO: type here is wrong @classmethod - def pow(cls, a, b): - def is_integer(val): - return isinstance(val, int) or ( - hasattr(val, "is_integer") and val.is_integer + def pow_by_natural(cls, a, b): + a = ValueRanges.wrap(a) + b = ValueRanges.wrap(b) + if a.is_singleton() and b.is_singleton(): + return ValueRanges.wrap(safe_pow(a.lower, b.lower)) + # NB: Exclude zero, because zero is special + elif a.lower >= 1: + # We should know that b >= 0 but we may have forgotten this fact due + # to replacements, so don't assert it, but DO clamp it to prevent + # degenerate problems + return ValueRanges.coordinatewise_increasing_map( + a, b & ValueRanges(0, sys.maxsize - 1), PowByNatural + ) + elif b.is_singleton(): + if b.lower % 2 == 0: + # x^n where n is even + return ValueRanges.convex_min_zero_map( + a, lambda x: safe_pow(x, b.lower) + ) + else: + # x^n where n is odd + return ValueRanges.increasing_map(a, lambda x: safe_pow(x, b.lower)) + else: + # a is potentially negative, and we don't know if the exponent is + # even or odd. So just conservatively set the upper and lower + # bound based on what the maximum absolute value could be, in both + # directions + max_base = max(a.upper, -a.lower) + return ValueRanges( + -(safe_pow(max_base, b.upper)), safe_pow(max_base, b.upper) ) + @classmethod + def pow(cls, a, b): + return ValueRanges.unknown() + + # We could implement all this, but for floating point pow, is there + # really a point? + """ a = ValueRanges.wrap(a) b = ValueRanges.wrap(b) + # Not implemented yet. It's a bit tricky # If you want to implement it, compute the partial derivatives of a ** b # and check the ranges where the function is increasing / decreasing @@ -553,8 +657,7 @@ def is_integer(val): if b == 0: if not a.lower.is_finite: return ValueRanges.unknown() - type_ = sympy.Float if a.lower.is_real else sympy.Integer - return ValueRanges.wrap(type_(1)) + return ValueRanges.wrap(1.0) if b < 0: a = cls.reciprocal(a) @@ -563,21 +666,12 @@ def is_integer(val): if a == ValueRanges.unknown(): return ValueRanges.unknown() - # Here b > 0 - if not is_integer(b): - # If the base is positive, then we're good, otherwise nothing's defined - if a.lower >= 0: - return ValueRanges.increasing_map(a, lambda x: x**b) - else: - return ValueRanges.unknown() + # If the base is positive, then we're good, otherwise nothing's defined + if a.lower >= 0: + return ValueRanges.increasing_map(a, lambda x: x**b) else: - # b > 0 integer - if b % 2 == 0: - # x^n where n is even - return ValueRanges.convex_min_zero_map(a, lambda x: x**b) - else: - # x^n where n is odd - return ValueRanges.increasing_map(a, lambda x: x**b) + return ValueRanges.unknown() + """ @staticmethod def reciprocal(x): @@ -586,7 +680,7 @@ def reciprocal(x): if 0 in x: return ValueRanges.unknown() else: - return ValueRanges.decreasing_map(x, lambda y: 1 / y) + return ValueRanges.decreasing_map(x, lambda y: FloatTrueDiv(1.0, y)) @staticmethod def abs(x): @@ -615,45 +709,64 @@ def maximum(cls, a, b): def min_or_max(a, b, fn): a = ValueRanges.wrap(a) b = ValueRanges.wrap(b) + return ValueRanges.coordinatewise_increasing_map(a, b, fn) - # Performs upcasting first - def fn_(x: sympy.Expr, y: sympy.Expr) -> sympy.Expr: - # Poorman's version of upcasting in Sympy - # Inf is not a float... - if x.is_Integer and y.is_Integer: - result_type = sympy.Integer - elif x.is_rational and y.is_rational: - result_type = sympy.Rational - else: - assert x.is_real or not x.is_finite or y.is_real or not y.is_finite - result_type = sympy.Float - return fn(result_type(x), result_type(y)) + @classmethod + def floor_to_int(cls, x, dtype): + return ValueRanges.increasing_map(x, sympy.functions.elementary.integers.floor) - return ValueRanges.coordinatewise_increasing_map(a, b, fn_) + @classmethod + def ceil_to_int(cls, x, dtype): + return ValueRanges.increasing_map( + x, sympy.functions.elementary.integers.ceiling + ) + + # I think these implementations are sound. The hazard here is that sympy + # will carry out the floor/ceil at too high precision and then something + # bad will happen when we convert it to float. + # + # For truncation, the implementation is clearly sound, because the desired + # target float is always exactly representable, since you're just chopping + # off bits the mantissa. But what about ceil/floor? + # + # The important constraint here is that we're not defining floor on + # arbitrary real numbers, only representable float numbers. So we can + # take advantage of the fact that before we reach the first + # unrepresentable integer in floating point space, we have the range of + # numbers corresponding to exponent zero: all integers, with no fractional + # amounts. floor/ceil is an identity operation in this case. In the + # range below here, representable floating point numbers are spaced + # exactly 1/2 apart, and notably, both the floor/ceil are defined floating + # point numbers. There is no "gap" as you step up to the next exponent. @classmethod def floor(cls, x): - return ValueRanges.increasing_map(x, sympy.functions.elementary.integers.floor) + return ValueRanges.increasing_map( + x, _keep_float(sympy.functions.elementary.integers.floor) + ) @classmethod def ceil(cls, x): return ValueRanges.increasing_map( - x, sympy.functions.elementary.integers.ceiling + x, _keep_float(sympy.functions.elementary.integers.ceiling) ) @classmethod - def round(cls, number, ndigits=None): - if ndigits is None: - fn = Round - else: - assert ndigits.is_singleton() - ndigits = ndigits.lower - # We can't use functools.partial here since sympy doesn't support keyword arguments, but we have to bind - # the second parameter. - fn = lambda number: RoundDecimal(number, ndigits) # type: ignore[misc, assignment] # noqa: E731 + def round_decimal(cls, number, ndigits): + if not ndigits.is_singleton(): + return ValueRanges.unknown() + + ndigits = ndigits.lower + # We can't use functools.partial here since sympy doesn't support keyword arguments, but we have to bind + # the second parameter. + fn = lambda number: RoundDecimal(number, ndigits) # type: ignore[misc, assignment] # noqa: E731 return ValueRanges.increasing_map(number, fn) + @classmethod + def round_to_int(cls, number, dtype): + return ValueRanges.increasing_map(number, RoundToInt) + # It's used in some models on symints @staticmethod def sqrt(x): @@ -708,12 +821,15 @@ def cos(x): @staticmethod def cosh(x): + return ValueRanges(0.0, sympy.oo) + """ x = ValueRanges.wrap(x) if x.lower > 0: return ValueRanges.increasing_map(x, OpaqueUnaryFn_cosh) elif x.upper < 0: return ValueRanges.decreasing_map(x, OpaqueUnaryFn_cosh) return ValueRanges(0.0, sympy.oo) + """ @staticmethod def sin(x): @@ -723,7 +839,8 @@ def sin(x): @staticmethod def sinh(x): - return ValueRanges.increasing_map(x, OpaqueUnaryFn_sinh) + # return ValueRanges.increasing_map(x, OpaqueUnaryFn_sinh) + return ValueRanges(-sympy.oo, sympy.oo) @staticmethod def tan(x): @@ -731,32 +848,37 @@ def tan(x): @staticmethod def tanh(x): - return ValueRanges.increasing_map(x, OpaqueUnaryFn_tanh) + # return ValueRanges.increasing_map(x, OpaqueUnaryFn_tanh) + return ValueRanges(-sympy.oo, sympy.oo) @staticmethod def asin(x): + return ValueRanges(-sympy.oo, sympy.oo) + """ x = ValueRanges.wrap(x) if -1 <= x.lower and x.upper <= 1: return ValueRanges.increasing_map(x, OpaqueUnaryFn_asinh) return ValueRanges.unknown() + """ @staticmethod def acos(x): + return ValueRanges(-sympy.oo, sympy.oo) + """ x = ValueRanges.wrap(x) if -1 <= x.lower and x.upper <= 1: return ValueRanges.decreasing_map(x, OpaqueUnaryFn_acos) return ValueRanges.unknown() + """ @staticmethod def atan(x): - return ValueRanges.increasing_map(x, OpaqueUnaryFn_atan) + return ValueRanges(-sympy.oo, sympy.oo) + # return ValueRanges.increasing_map(x, OpaqueUnaryFn_atan) @staticmethod def trunc(x): - def trunc(x): - return sympy.Integer(x) if x.is_finite else x - - return ValueRanges.increasing_map(x, trunc) + return ValueRanges.increasing_map(x, TruncToFloat) class ValueRangeAnalysis(SymPyValueRangeAnalysis): @@ -791,9 +913,10 @@ def store(self, name, index, value, mode=None): def reduction(self, name, dtype, src_dtype, reduction_type, index, value): return ValueRanges.unknown() - def index_expr(self, index, dtype): + @classmethod + def index_expr(cls, index, dtype): assert isinstance(index, ValueRanges) - return index + return cls.to_dtype(index, dtype) @staticmethod def to_dtype(x, dtype: torch.dtype, src_dtype: Optional[torch.dtype] = None): @@ -830,12 +953,15 @@ def cast(x, dtype): @staticmethod def square(x): - return ValueRanges.convex_min_zero_map(x, lambda y: y * y) + return ValueRanges.convex_min_zero_map(x, lambda y: PowByNatural(y, 2)) @staticmethod def neg(x): return ValueRanges.decreasing_map(x, operator.neg) + # TODO: this is slightly inaccurate because truncdiv operates at integer + # precision, but we're going through float truediv which means we can + # potentially lose precision on the bounds @classmethod def truncdiv(cls, a, b): x = cls.truediv(a, b) @@ -856,6 +982,7 @@ def __getattr__(self, name): def bound_sympy( expr: sympy.Expr, ranges: Optional[Dict[sympy.Symbol, ValueRanges]] = None ) -> ValueRanges: + log.debug("bound_sympy(%s, %s)", expr, ranges) if isinstance(expr, sympy.Number): return ValueRanges.wrap(expr) From 68eb771265f74533a2025288b774164e0991b526 Mon Sep 17 00:00:00 2001 From: cyy Date: Thu, 6 Jun 2024 03:41:32 +0000 Subject: [PATCH 398/706] [2/N] Remove unused test functions (#128005) Following #127881, this PR continues to remove unused test functions. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128005 Approved by: https://github.com/ezyang --- aten/src/ATen/core/boxing/KernelFunction_test.cpp | 10 ---------- .../core/boxing/impl/kernel_function_legacy_test.cpp | 11 ----------- .../ATen/core/boxing/impl/kernel_function_test.cpp | 12 ------------ .../impl/make_boxed_from_unboxed_functor_test.cpp | 11 ----------- 4 files changed, 44 deletions(-) diff --git a/aten/src/ATen/core/boxing/KernelFunction_test.cpp b/aten/src/ATen/core/boxing/KernelFunction_test.cpp index a0f990e87aaf..cf45c709c58d 100644 --- a/aten/src/ATen/core/boxing/KernelFunction_test.cpp +++ b/aten/src/ATen/core/boxing/KernelFunction_test.cpp @@ -275,16 +275,6 @@ void expectOutOfPlaceMultiBoxedCallingWorks(const KernelFunction& func) { EXPECT_TRUE(stack[1].toTensor().is_same(t2)); } -void expectBoxedCallingFailsWith(const KernelFunction& func, const char* errorMessage) { - called_with_args = c10::nullopt; - vector stack {3, 4}; - OperatorHandle dummy = makeDummyOperatorHandle(); - - expectThrows([&] { - func.callBoxed(dummy, CPU_TEST_SET, &stack); - }, errorMessage); -} - // // unboxed calling tests: // diff --git a/aten/src/ATen/core/boxing/impl/kernel_function_legacy_test.cpp b/aten/src/ATen/core/boxing/impl/kernel_function_legacy_test.cpp index a5cb61874173..f6dc3ee356a0 100644 --- a/aten/src/ATen/core/boxing/impl/kernel_function_legacy_test.cpp +++ b/aten/src/ATen/core/boxing/impl/kernel_function_legacy_test.cpp @@ -51,17 +51,6 @@ void expectCallsIncrement(DispatchKey dispatch_key) { EXPECT_EQ(6, result[0].toInt()); } -void expectCallsDecrement(DispatchKey dispatch_key) { - at::AutoDispatchBelowAutograd mode; - - // assert that schema and cpu kernel are present - auto op = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""}); - ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(dispatch_key), 5); - EXPECT_EQ(1, result.size()); - EXPECT_EQ(4, result[0].toInt()); -} - TEST(OperatorRegistrationTestLegacyFunctionBasedKernel, givenKernel_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", &incrementKernel); expectCallsIncrement(DispatchKey::CPU); diff --git a/aten/src/ATen/core/boxing/impl/kernel_function_test.cpp b/aten/src/ATen/core/boxing/impl/kernel_function_test.cpp index 5662c0982bfb..2d6f7346eec2 100644 --- a/aten/src/ATen/core/boxing/impl/kernel_function_test.cpp +++ b/aten/src/ATen/core/boxing/impl/kernel_function_test.cpp @@ -662,18 +662,6 @@ void expectCallsConcatUnboxed(DispatchKey dispatch_key) { EXPECT_EQ("123", result); } -void expectCannotCallConcatBoxed(DispatchKey dispatch_key) { - at::AutoDispatchBelowAutograd mode; - - // assert that schema and cpu kernel are present - auto op = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""}); - ASSERT_TRUE(op.has_value()); - expectThrows( - [&] {callOp(*op, dummyTensor(dispatch_key), "1", "2", 3);}, - "Tried to call KernelFunction::callBoxed() on a KernelFunction that can only be called with KernelFunction::call()." - ); -} - TEST(OperatorRegistrationTestFunctionBasedKernel, givenKernel_whenRegistered_thenCanBeCalledUnboxed) { auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, str a, str b, int c) -> str", RegisterOperators::options().kernel(DispatchKey::CPU)); expectCallsConcatUnboxed(DispatchKey::CPU); diff --git a/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor_test.cpp b/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor_test.cpp index 3ced237702aa..345f5b11cba8 100644 --- a/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor_test.cpp +++ b/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor_test.cpp @@ -51,17 +51,6 @@ void expectCallsIncrement(DispatchKey dispatch_key) { EXPECT_EQ(6, result[0].toInt()); } -void expectCallsDecrement(DispatchKey dispatch_key) { - at::AutoDispatchBelowAutograd mode; - - // assert that schema and cpu kernel are present - auto op = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""}); - ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(dispatch_key), 5); - EXPECT_EQ(1, result.size()); - EXPECT_EQ(4, result[0].toInt()); -} - TEST(OperatorRegistrationTestFunctorBasedKernel, givenKernel_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPU)); expectCallsIncrement(DispatchKey::CPU); From cd42b95047e43fd2a13c91858ee2f5a08812d187 Mon Sep 17 00:00:00 2001 From: Jiashen Cao Date: Thu, 6 Jun 2024 05:00:13 +0000 Subject: [PATCH 399/706] Handle aten::__contains__ during TorchScript to ExportedProgram conversion (#127544) #### Description Add support for converting `prim::__contains__` from TorchScript IR to ExportedProgram, e.g., ```python class MIn(torch.nn.Module): def forward(self, x: torch.Tensor): return x.dtype in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] ``` #### Test Plan * Add test cases to cover both contains IR resulted from primitive types or Tensor. `pytest test/export/test_converter.py -s -k test_ts2ep_converter_contains` Pull Request resolved: https://github.com/pytorch/pytorch/pull/127544 Approved by: https://github.com/angelayi --- test/export/test_converter.py | 24 +++++++++++++++++++++++- torch/_export/converter.py | 2 +- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/test/export/test_converter.py b/test/export/test_converter.py index 8d0b2d6f93c6..44f23309579d 100644 --- a/test/export/test_converter.py +++ b/test/export/test_converter.py @@ -1,7 +1,7 @@ # Owner(s): ["oncall: export"] import unittest -from typing import Tuple +from typing import Dict, Tuple import torch @@ -264,6 +264,28 @@ def forward(self, x_tuple: Tuple[torch.Tensor, torch.Tensor]): inp = ((torch.zeros(1, 4), torch.ones(1, 4)),) self._check_equal_ts_ep_converter(MUnpackTuple(), inp) + def test_ts2ep_converter_contains(self): + class MIn(torch.nn.Module): + def forward(self, x: torch.Tensor): + return x.dtype in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + + class MNotIn(torch.nn.Module): + def forward(self, x: torch.Tensor): + return x.dtype in [-1] + + class MTensorIn(torch.nn.Module): + def forward(self, x: torch.Tensor, x_dict: Dict[torch.Tensor, str]): + return x in x_dict + + inp = (torch.tensor(4),) + self._check_equal_ts_ep_converter(MIn(), inp) + self._check_equal_ts_ep_converter(MNotIn(), inp) + + inp = (torch.tensor(4), {torch.tensor(4): "foo"}) + self._check_equal_ts_ep_converter(MTensorIn(), inp) + inp = (torch.tensor(1), {torch.tensor(4): "foo"}) + self._check_equal_ts_ep_converter(MTensorIn(), inp) + if __name__ == "__main__": run_tests() diff --git a/torch/_export/converter.py b/torch/_export/converter.py index 4cbcc3faea83..9e438e206984 100644 --- a/torch/_export/converter.py +++ b/torch/_export/converter.py @@ -55,6 +55,7 @@ def ir_name_to_func_name(name: str) -> str: "aten::__is__": operator.is_, "aten::__isnot__": operator.is_not, "aten::__not__": operator.not_, + "aten::__contains__": operator.contains, } @@ -506,7 +507,6 @@ def _convert_standard_operators(self, node: torch._C.Node): def convert_node(self, node: torch._C.Node): node_kind = node.kind() - node_kind_split = node_kind.split("::") # Get handler based on namespace and operator name. # Provide a default node handler as well in case we don't find From 638f543ac23f1638c804ec47d57561b830cf1731 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Thu, 6 Jun 2024 06:25:00 +0000 Subject: [PATCH 400/706] Enable single nadam test (#128087) https://github.com/pytorch/pytorch/issues/117150 has been fixed Pull Request resolved: https://github.com/pytorch/pytorch/pull/128087 Approved by: https://github.com/xmfan --- torch/testing/_internal/common_optimizers.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/torch/testing/_internal/common_optimizers.py b/torch/testing/_internal/common_optimizers.py index bb8375e35cfd..ac4a7f920cc2 100644 --- a/torch/testing/_internal/common_optimizers.py +++ b/torch/testing/_internal/common_optimizers.py @@ -1578,13 +1578,6 @@ def _get_optim_inputs_including_global_cliquey_kwargs( "TestOptimRenewed", "test_load_nontensor_step", ), - DecorateInfo( - skipIfTorchDynamo( - "Errors, see https://github.com/pytorch/pytorch/issues/117150" - ), - "TestOptimRenewed", - "test_state_dict_with_cuda_params", - ), DecorateInfo( skipIfTorchDynamo( "This test uses mocks, which dynamo does not support" From c8ff1cd38723f7089776de6b4deba913e6c5e510 Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Wed, 5 Jun 2024 19:41:33 -0700 Subject: [PATCH 401/706] [FSDP2] Changed `test_register_forward_method` to use multiprocess test (#128100) The test seems to be flaky due to multi-threaded process group. This PR converts the test to use normal multi-process `ProcessGroupNCCL` to fix the flakiness. This PR closes https://github.com/pytorch/pytorch/issues/126851. Interestingly, the original MTPG version passes for me on devgpu. Either way, the new version also passes on devgpu, so we can see in CI. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128100 Approved by: https://github.com/weifengpy --- .../_composable/fsdp/test_fully_shard_training.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_training.py b/test/distributed/_composable/fsdp/test_fully_shard_training.py index a34a59f3cde6..836013f7fb24 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_training.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_training.py @@ -1248,12 +1248,12 @@ def _test_train_parity_hsdp( check_sharded_parity(self, ref_model, model) -class TestFullyShardCustomForwardMethod(FSDPTestMultiThread): +class TestFullyShardCustomForwardMethod(FSDPTest): @property def world_size(self) -> int: - return 2 + return min(torch.cuda.device_count(), 2) - @unittest.skipIf(not TEST_CUDA, "no cuda") + @skip_if_lt_x_gpu(2) def test_register_fsdp_forward_method(self): """Based on https://github.com/pytorch/pytorch/issues/109385""" @@ -1280,8 +1280,6 @@ def forward(self, imgs: torch.Tensor) -> torch.Tensor: torch.manual_seed(42) model = Model() - for param in model.parameters(): - dist.broadcast(param.detach(), src=0) ref_model = copy.deepcopy(model).cuda() fully_shard(model.vit) fully_shard(model.projector) From 70ba6f0ab618fa0740c60ed04dd908bb892333ce Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Tue, 4 Jun 2024 17:08:55 -0700 Subject: [PATCH 402/706] Collect static parameter metadata in aot (#126820) Collect the indices of the static parameters to pass down to cudagraphs in order to re-record if necessary. This location was chosen in order to allow us to restrict this (if needed) in the future by setting metadata in dynamo. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126820 Approved by: https://github.com/bdhirsh --- .../_aot_autograd/collect_metadata_analysis.py | 10 ++++++++++ torch/_functorch/_aot_autograd/schemas.py | 3 +++ 2 files changed, 13 insertions(+) diff --git a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py index e01f6df6957d..991e12a59d4b 100644 --- a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py +++ b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py @@ -665,6 +665,15 @@ def view_avoid_dupes_with_primals(t): ) user_outs = pytree.tree_map(from_fun, f_output_tangents) + if torch._dynamo.config.inline_inbuilt_nn_modules: + static_parameter_input_indices = [ + i + for i, arg in enumerate(flat_args) + if isinstance(arg, torch.nn.Parameter) + ] + else: + static_parameter_input_indices = [] + f_mutated_inputs = [ inp for inp, info in zip(flat_f_args, input_info) @@ -716,6 +725,7 @@ def view_avoid_dupes_with_primals(t): subclass_tangent_meta=create_subclass_meta(traced_tangents), is_train=is_train, grad_enabled_mutation=grad_enabled_mutation, + static_parameter_indices=static_parameter_input_indices, tokens=mode._tokens, ) return metadata diff --git a/torch/_functorch/_aot_autograd/schemas.py b/torch/_functorch/_aot_autograd/schemas.py index 982fcb9e6464..3246f142ca43 100644 --- a/torch/_functorch/_aot_autograd/schemas.py +++ b/torch/_functorch/_aot_autograd/schemas.py @@ -304,6 +304,9 @@ class ViewAndMutationMeta: # raised deterministic: Optional[bool] = None + # Keeps track of which input indices store parameters (which we will treat as static) + static_parameter_indices: List[int] = field(default_factory=list) + # Map of effect type (ex. _EffectType.ORDERED) to token. If there are # side-effectful operators, FunctionalTensorMode will populate this # dictionary telling us how many tokens we will need during tracing. From 5a3bea1e88e5c7877393dff5eb3f0cfa7a94b85d Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Tue, 4 Jun 2024 17:08:55 -0700 Subject: [PATCH 403/706] Remove unused arg to GraphLowering (#126821) Pull Request resolved: https://github.com/pytorch/pytorch/pull/126821 Approved by: https://github.com/eellison ghstack dependencies: #126820 --- test/inductor/test_cpu_repro.py | 2 -- test/inductor/test_torchinductor.py | 2 -- torch/_inductor/compile_fx.py | 2 -- torch/_inductor/graph.py | 3 --- 4 files changed, 9 deletions(-) diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py index d4d0e258c3e2..b2ab30832e06 100644 --- a/test/inductor/test_cpu_repro.py +++ b/test/inductor/test_cpu_repro.py @@ -2232,7 +2232,6 @@ def get_index(): graph_lowering = GraphLowering( torch.fx.GraphModule(submodules, _graph), shape_env=None, - num_static_inputs=0, ) def set_opt_dtype(graph): @@ -2343,7 +2342,6 @@ def get_index(): graph_lowering = GraphLowering( torch.fx.GraphModule(submodules, _graph), shape_env=None, - num_static_inputs=0, ) with patch.object(graph_lowering, "wrapper_code", ""), V.set_graph_handler( graph_lowering diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index c6928ee37a8e..42c430866290 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -9190,7 +9190,6 @@ def func(arg0_1): graph = GraphLowering( gm, shape_env=shape_env, - num_static_inputs=0, ) with V.set_graph_handler(graph), V.set_debug_handler(DebugContext()): graph.run(*example_inputs) @@ -10417,7 +10416,6 @@ def get_kernels(self, fn, args) -> typing.List[CachingAutotuner]: cxt = TritonCodeGenTests.NoOpCompilerBackend() torch._dynamo.optimize(backend=cxt.noop_backend)(fn)(*args) graph = GraphLowering(cxt.model) - graph.num_static_inputs = 0 kernels = [] with V.set_graph_handler(graph), V.set_debug_handler(DebugContext()): graph.run(*(cxt.example_args)) diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index c6eddbed19fe..44581c29d2ac 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -749,7 +749,6 @@ def fx_codegen_and_compile( const_gm, example_inputs=[], shape_env=shape_env, - num_static_inputs=num_fixed, graph_id=graph_id, cpp_wrapper=cpp_wrapper, aot_mode=aot_mode, @@ -771,7 +770,6 @@ def fx_codegen_and_compile( # we currently use fake tensors and defake them later. example_inputs=example_inputs, shape_env=shape_env, - num_static_inputs=num_fixed, graph_id=graph_id, cpp_wrapper=cpp_wrapper, aot_mode=aot_mode, diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index abe93686ac83..d5ec55afd05e 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -296,7 +296,6 @@ def __init__( gm: torch.fx.GraphModule, example_inputs: Optional[List[torch.Tensor]] = None, shape_env=None, - num_static_inputs=None, graph_id=None, cpp_wrapper=False, aot_mode=False, @@ -311,7 +310,6 @@ def __init__( name=None, ): super().__init__(gm) - self.example_inputs = example_inputs self.layout_opt = ( layout_opt @@ -374,7 +372,6 @@ def __init__( Callable[[List[ir.ExternKernelNode]], Any] ] = extern_node_serializer self.current_node: torch.fx.Node = None # type: ignore[assignment] - self.num_static_inputs = num_static_inputs self.lists: Dict[str, List[str]] = {} self.mutated_inputs: Set[str] = set() self.mutated_input_idxs: List[int] = [] From f5328542b5365741176e71dd8a2954e0f350b9bc Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Wed, 5 Jun 2024 12:51:50 -0700 Subject: [PATCH 404/706] Allow multiple cudagraph recordings per compiled graph (#126822) ### Introduction/Problem Today when dynamo traces a builtin nn module (nn.Linear for example) it will specially handle parameters of that module by storing them as constant attributes of the graph. This requires that dynamo guard on the ID of the NNModule because if the instance of the module changes, we need to retrace and recollect the new parameters as attributes of the graph. This creates a 1:1 compiled graph to cudagraph relationship. With hierarchical compilation, dynamo will treat builtin nn modules like any other code. This reduces complexity and critically, if there are multiple identical layers in a model, we only need to compile one of those layers once, and reuse the same compiled artifact for each layer. This introduces a problem for the current approach to parameter handling. Since the parameters could now possibly change across calls to the compiled artifact, these need to be inputs to the graph instead of attributes. This introduces a problem for cudagraphs - previously cudagraphs was guaranteed that the parameters of builtin NN Modules would be constant across calls, but now since the compiled artifact needs to be agnostic to the actual instance of the NN module being used these parameter memory locations may vary. Previously cudagraphs simply copies varying inputs to cudagraph owned memory, but since the parameters are quite large, this is catastrophic for performance. ### Solution To avoid this performance cliff, this PR allows cudagraphs to re-record a new cudagraph if only parameters change. Metadata about which arguments are parameters are propagated from AOT Autograd to compile_fx, and these indices are passed to cudagraphs. If these memory locations change, a new graph is recorded vs previously where this would be an error (because this previously should not happen). This enables a 1:many compiled graph to cudagraph relationship. Across similar modules we will re-record cudagraphs and dispatch the correct graph if parameter pointers match when the cudagraph is executed. ### Next steps (if needed) It is theoretically possible that a user passes Parameters that change frequently as inputs to model code - if this is a common issue this design allows for dynamo to pass metadata indicating which parameters were created in a builtin NN Module context to only permit those parameters to have the multi-cudagraph behavior, but this PR does not implement this. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126822 Approved by: https://github.com/eellison ghstack dependencies: #126820, #126821 --- test/inductor/test_cudagraph_trees.py | 146 +++++++++++++++++++++++++- torch/_inductor/compile_fx.py | 43 ++++++-- torch/_inductor/cudagraph_trees.py | 33 +++++- torch/_inductor/cudagraph_utils.py | 5 +- 4 files changed, 209 insertions(+), 18 deletions(-) diff --git a/test/inductor/test_cudagraph_trees.py b/test/inductor/test_cudagraph_trees.py index 58c819b804ff..7e8b9fce2b3b 100644 --- a/test/inductor/test_cudagraph_trees.py +++ b/test/inductor/test_cudagraph_trees.py @@ -648,7 +648,9 @@ def get_aligned_inputs(): with mode: inps = [torch.rand([6, 5], device="cuda")[1:] for _ in range(2)] - compiled_f = compile_fx_inner(mod, inps, num_fixed=1, cudagraphs=True) + compiled_f = compile_fx_inner( + mod, inps, static_input_idxs=[0], cudagraphs=True + ) def get_unaligned_inputs(): return [torch.rand([6, 5], device="cuda")[1:] for _ in range(2)] @@ -1770,6 +1772,148 @@ def forward(self, x) -> torch.Tensor: [foo.goo.linear.weight, foo.goo.linear.bias, foo.static_tensor, inp] ) + def run_static_input_param_test(self, fn_eager, num_graphs): + with torch.device("cuda"): + fn_compiled = torch.compile(fn_eager, mode="reduce-overhead") + + def run_iter(param, fn): + fwd_output = fn(torch.ones(2, 2), param) + fwd_output.sum().backward() + grad_output = param.grad.clone().detach() + param.grad = None + return fwd_output, grad_output + + def loop(param): + exp_output, exp_grad = run_iter(param, fn_eager) + for _ in range(5): + compiled_output, compiled_grad = run_iter(param, fn_compiled) + self.assertEqual(exp_output, compiled_output) + self.assertEqual(exp_grad, compiled_grad) + + p1 = torch.nn.Parameter(torch.rand([2, 2])) + loop(p1) + + p2 = torch.nn.Parameter(torch.rand([2, 2])) + loop(p2) + + # Run p1 again to ensure we reuse the previous recording + loop(p1) + + self.assertEqual(self.get_manager().new_graph_id().id, num_graphs) + + def _module_test(self, mod): + with torch.device("cuda"): + + def fn(x, mod): + return mod(x) + + fn_compiled = torch.compile(fn, mode="reduce-overhead", fullgraph=True) + + def run_test_iter(mod, fn): + fwd_output = fn(torch.ones(2, 2), mod) + fwd_output.sum().backward() + grad_output = mod.weight.grad.clone().detach() + mod.zero_grad() + return fwd_output, grad_output + + def run_test(): + exp_output, exp_grad = run_test_iter(mod, fn) + for _ in range(5): + compiled_output, compiled_grad = run_test_iter(mod, fn_compiled) + self.assertEqual(exp_output, compiled_output) + self.assertEqual(exp_grad, compiled_grad) + + run_test() + old = mod.weight.data + mod.weight.data = torch.rand_like(mod.weight.data) + run_test() + # Run original version to verify we reuse the other recording + mod.weight.data = old + run_test() + + # Fwd + bwd graphs for each version of the function => 4 graphs + self.assertEqual(self.get_manager().new_graph_id().id, 4) + + @torch._inductor.config.patch("triton.cudagraphs", True) + @torch._dynamo.config.patch("error_on_recompile", True) + @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True) + def test_multi_dispatch_single_compile_param_inputs(self): + # Verify that we can record multiple cudagraphs for a single + # compiled function with param inputs + def fn(x, y): + return x * y + + # Fwd + bwd graphs for each version of the function => 4 graphs + self.run_static_input_param_test(fn, 4) + + @torch._inductor.config.patch("triton.cudagraphs", True) + @torch._dynamo.config.patch("error_on_recompile", True) + @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True) + def test_multi_dispatch_single_compile_builtin_module(self): + # Verify that we don't recompile when changing the param of a builtin module + # and that we record another cudagraph + # Note: Linear is a builtin module so we enable that config setting above + self._module_test(torch.nn.Linear(2, 3, device="cuda")) + + @torch._inductor.config.patch("triton.cudagraphs", True) + @torch._dynamo.config.patch("error_on_recompile", True) + @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True) + def test_multi_dispatch_custom_module(self): + # Test that we can correctly dispatch multiple graphs + # if params of a custom module change + class TestModule(torch.nn.Module): + def __init__(self, param) -> None: + super().__init__() + self.weight = param + + def forward(self, x): + return self.weight * x + + self._module_test( + TestModule(torch.nn.Parameter(torch.rand([2, 2], device="cuda"))) + ) + + @torch._inductor.config.patch("triton.cudagraphs", True) + @torch._dynamo.config.patch("error_on_recompile", True) + @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True) + def test_multi_dispatch_child_node(self): + # Test that we can correctly dispatch multiple graphs if a child node + # in the tree has stable input pointers change + def fn(x, p): + # Graph 1 + y = x * x + torch._dynamo.graph_break() + # Graph 2 + return y * p + + # We have 5 graphs here + # Graph 1 + # / \ + # Graph 2 w/ p1 Graph 2 w/ p2 + # and then two backward graphs + self.run_static_input_param_test(fn, 5) + + @torch._inductor.config.patch("triton.cudagraphs", True) + @torch._dynamo.config.patch("error_on_recompile", True) + @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True) + def test_multi_dispatch_parent_node(self): + def fn(x, p): + # Graph 1 + y = x * p + torch._dynamo.graph_break() + # Graph 2 + return y + x + + # We have 6 graphs here + # Graph 1 w/ p1 Graph 1 w/ p2 + # | | + # Graph 2 (v1) Graph 2 (v2) + # There are two versions of graph 2 because + # we re-record due to different memory state after running the + # two versions of Graph 1 + # and then two backward graphs + self.run_static_input_param_test(fn, 6) + instantiate_parametrized_tests(CudaGraphTreeTests) if __name__ == "__main__": diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index 44581c29d2ac..856335819621 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -120,6 +120,19 @@ def complex_memory_overlap(t: torch.Tensor) -> bool: return False +def get_static_input_idxs(num_fixed): + # If we are inlining NNModules, we treat all torch.nn.Parameters as static for the purposes + # of cudagraphs. Rather than copying these into cudagraph-owned memory + # like we do for normal inputs on each run, we will re-record a cudagraph if these + # parameter locations change. + context = torch._guards.TracingContext.try_get() + fixed = list(range(num_fixed)) + if not context or not context.fw_metadata: + return fixed + + return fixed + context.fw_metadata.static_parameter_indices + + @functools.lru_cache(None) def _step_logger(): return dynamo_logging.get_step_logger(log) @@ -415,7 +428,7 @@ def compile_fx_inner( gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor], cudagraphs: Optional[BoxedBool] = None, - num_fixed: int = 0, + static_input_idxs: Optional[List[int]] = None, is_backward: bool = False, graph_id: Optional[int] = None, cpp_wrapper: bool = False, @@ -440,6 +453,9 @@ def compile_fx_inner( _LazyGraphModule.force_recompile(gm) return make_boxed_func(gm.forward) + if static_input_idxs is None: + static_input_idxs = [] + assert isinstance( next(iter(reversed(gm.graph.nodes))).args[0], (tuple, list) ), f"inductor can only compile FX graphs which return a tuple/list, but got {gm.graph}" @@ -449,7 +465,7 @@ def compile_fx_inner( gm, example_inputs, cudagraphs=cudagraphs, - num_fixed=num_fixed, + static_input_idxs=static_input_idxs, is_backward=is_backward, graph_id=graph_id, cpp_wrapper=cpp_wrapper, @@ -468,7 +484,7 @@ def compile_fx_inner( # of fx_codegen_and_compile changes, the dict should be updated accordingly graph_kwargs = { "cudagraphs": cudagraphs, - "num_fixed": num_fixed, + "static_input_idxs": static_input_idxs, "is_backward": is_backward, "graph_id": graph_id, "cpp_wrapper": cpp_wrapper, @@ -482,7 +498,7 @@ def compile_fx_inner( start = time.time() fx_graph_remote_cache = should_use_remote_fx_graph_cache() - inputs_to_check = get_input_idxs_to_check(example_inputs, range(num_fixed)) + inputs_to_check = get_input_idxs_to_check(example_inputs, static_input_idxs) if ( not config.force_disable_caches and (config.fx_graph_cache or fx_graph_remote_cache) @@ -492,7 +508,7 @@ def compile_fx_inner( if ( isinstance(input, torch.Tensor) and input.device.type == "cuda" - and i < num_fixed + and i in static_input_idxs ): input._is_inductor_static = True # type: ignore[attr-defined] @@ -551,7 +567,7 @@ def compile_fx_inner( ) has_mutation_str = check_for_mutation_ignore_cuda_graph_managed_tensor( - gm, compiled_graph, num_fixed + gm, compiled_graph, static_input_idxs ) has_mutation = has_mutation_str is not None @@ -591,7 +607,7 @@ def compile_fx_inner( compiled_graph.current_callable = cudagraphify( compiled_graph.current_callable, example_inputs, - static_input_idxs=range(num_fixed), + static_input_idxs=static_input_idxs, device_index=next(iter(compiled_graph.device_idxs)), stack_traces=stack_traces, is_backward=is_backward, @@ -660,7 +676,7 @@ def fx_codegen_and_compile( gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor], cudagraphs: Optional[BoxedBool] = None, - num_fixed: int = 0, + static_input_idxs: Optional[List[int]] = None, is_backward: bool = False, graph_id: Optional[int] = None, cpp_wrapper: bool = False, @@ -1178,6 +1194,7 @@ def fw_compiler_freezing( n.name for n in model_outputs if isinstance(n, torch.fx.Node) ) + static_input_idxs = list(range(num_fixed)) # constant params will be real tensors, not fake tracing_context = torch._guards.TracingContext.try_get() if tracing_context is not None: @@ -1187,11 +1204,14 @@ def fw_compiler_freezing( if i not in preserved_arg_indices: params_flat[i] = None + if tracing_context.fw_metadata: + static_input_idxs += tracing_context.fw_metadata.static_parameter_indices + with mock.patch.object(fake_mode, "allow_non_fake_inputs", True): optimized_function = inner_compile( opt_model, aot_example_inputs, - num_fixed=num_fixed, + static_input_idxs=static_input_idxs, cudagraphs=cudagraphs, graph_id=graph_id, is_inference=True, @@ -1322,6 +1342,7 @@ def fw_compiler_base( fixed = torch._inductor.utils.num_fw_fixed_arguments( num_example_inputs, len(example_inputs) ) + user_visible_outputs = {} if config.keep_output_stride: @@ -1377,7 +1398,7 @@ def fw_compiler_base( return inner_compile( model, example_inputs, - num_fixed=fixed, + static_input_idxs=get_static_input_idxs(fixed), cudagraphs=cudagraphs, graph_id=graph_id, is_inference=is_inference, @@ -1421,7 +1442,7 @@ def bw_compiler(model: torch.fx.GraphModule, example_inputs: List[torch.Tensor]) return inner_compile( model, example_inputs, - num_fixed=fixed, + static_input_idxs=list(range(fixed)), cudagraphs=cudagraphs, is_backward=True, graph_id=graph_id, diff --git a/torch/_inductor/cudagraph_trees.py b/torch/_inductor/cudagraph_trees.py index e7a1f3364823..d49404ddafde 100644 --- a/torch/_inductor/cudagraph_trees.py +++ b/torch/_inductor/cudagraph_trees.py @@ -753,6 +753,11 @@ def __init__( self.device = device_index self.stack_traces = stack_traces self.stream = stream + # If we are inlining builtin nn modules we will re-record if static inputs change + # if not we should error because dynamo should have recompiled in this case + self.rerecord_if_static_inputs_change = ( + torch._dynamo.config.inline_inbuilt_nn_modules + ) # if this is a root parent will be None. use weakref to prevent reference cycle self._parent = weakref.ref(parent) if parent is not None else None @@ -952,8 +957,13 @@ def _copy_inputs_and_remove_from_src(self, dsts, srcs): def check_static_inputs_are_stable(self, new_inputs): # avoid checking managed tensor static points since we already checked those in check_invariants - if not torch._C._tensors_data_ptrs_at_indices_equal( - new_inputs, self.static_input_data_ptrs, self.non_managed_static_input_idxs + if ( + not self.rerecord_if_static_inputs_change + and not torch._C._tensors_data_ptrs_at_indices_equal( + new_inputs, + self.static_input_data_ptrs, + self.non_managed_static_input_idxs, + ) ): # this should error static_tensors = [new_inputs[i] for i in self.non_managed_static_input_idxs] @@ -1000,6 +1010,9 @@ def run(self, new_inputs): if config.triton.force_cudagraph_sync: torch.cuda.synchronize() + # Reset this to run the check in the future + self.static_inputs_stable = False + return outputs def reconstruct_outputs(self): @@ -1553,8 +1566,8 @@ def _allocate_and_copy_recording_inputs( def check_invariants(self, inputs: List[Tensor]) -> bool: """ - Checks if this node can be run. The same pattern of tensor liveness and tensors - managed in the cudagraph private pool must remain stable. + Checks if this node can be run. The same pattern of tensor liveness, static inputs, + and tensors managed in the cudagraph private pool must remain stable. """ # previously managed data pointers remain stable @@ -1565,6 +1578,18 @@ def check_invariants(self, inputs: List[Tensor]) -> bool: ): return False + # static input data pointers should remain stable + # if we are inlining builtin nn modules we re-record in this case + # if we are not inlining builtin nn modules, we check this in check_static_inputs_are_stable + # and error if they are not stable + if ( + self.rerecord_if_static_inputs_change + and not torch._C._tensors_data_ptrs_at_indices_equal( + inputs, self.static_input_data_ptrs, self.static_input_idxs + ) + ): + return False + if not self._check_liveness( self.expected_dead_indices_before_graph, self.path_weakrefs ): diff --git a/torch/_inductor/cudagraph_utils.py b/torch/_inductor/cudagraph_utils.py index a1ac4936f417..8556a0f751ed 100644 --- a/torch/_inductor/cudagraph_utils.py +++ b/torch/_inductor/cudagraph_utils.py @@ -143,15 +143,16 @@ def set(self, device_idx: Optional[int]): def check_for_mutation_ignore_cuda_graph_managed_tensor( - gm: torch.fx.GraphModule, compiled_graph, num_fixed: int + gm: torch.fx.GraphModule, compiled_graph, static_input_idxs: List[int] ) -> Optional[str]: default_msg = format_default_skip_message("mutated inputs") # doesnt work for non-trees because the warmup run would apply mutation twice if torch._inductor.config.triton.cudagraph_trees: + unique_idxs = set(static_input_idxs) # checking if mutation is only on parameters/static inputs mutation_indices = [ - idx for idx in compiled_graph.mutated_input_idxs if idx >= num_fixed + idx for idx in compiled_graph.mutated_input_idxs if idx not in unique_idxs ] has_mutation = len(mutation_indices) != 0 if not has_mutation: From 457df212e1c6e1aa4f1eb2ad6ee292052d7c07e1 Mon Sep 17 00:00:00 2001 From: Tom Ritchford Date: Wed, 5 Jun 2024 11:48:36 +0000 Subject: [PATCH 405/706] Add OpInfo entry for alias_copy (#127232) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127232 Approved by: https://github.com/lezcano --- .../ATen/functorch/BatchRulesDecompositions.cpp | 1 + test/distributed/_tensor/test_dtensor_ops.py | 1 + .../HasDecompTest.test_has_decomposition.expect | 2 -- test/functorch/test_vmap_registrations.py | 1 + tools/autograd/gen_variable_type.py | 1 + torch/_decomp/__init__.py | 1 + torch/_inductor/exc.py | 2 +- torch/_refs/__init__.py | 4 ++++ .../_internal/common_methods_invocations.py | 15 +++++++++++++++ 9 files changed, 25 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp index 3e064d6c39dc..a0007aa18a00 100644 --- a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp +++ b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp @@ -324,6 +324,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) { OP_DECOMPOSE(type_as); OP_DECOMPOSE(linalg_diagonal); OP_DECOMPOSE(diagonal_copy); + OP_DECOMPOSE(alias_copy); m.impl("pad", native::pad_symint); m.impl("_pad_circular", native::_pad_circular_symint); OP_DECOMPOSE(swapdims_); diff --git a/test/distributed/_tensor/test_dtensor_ops.py b/test/distributed/_tensor/test_dtensor_ops.py index 83f0bb875167..07f8bfedc615 100644 --- a/test/distributed/_tensor/test_dtensor_ops.py +++ b/test/distributed/_tensor/test_dtensor_ops.py @@ -102,6 +102,7 @@ def wrapped(fn): xfail("addr"), xfail("all"), xfail("allclose"), + xfail("alias_copy"), xfail("amax"), xfail("amin"), xfail("aminmax"), diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index ad9cf07d7550..eeee3685e1fb 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -647,8 +647,6 @@ aten::adaptive_max_pool3d_backward.grad_input aten::addbmm aten::addbmm.out aten::addr_ -aten::alias_copy -aten::alias_copy.out aten::allclose aten::angle aten::angle.out diff --git a/test/functorch/test_vmap_registrations.py b/test/functorch/test_vmap_registrations.py index 967152945af5..737927a60f80 100644 --- a/test/functorch/test_vmap_registrations.py +++ b/test/functorch/test_vmap_registrations.py @@ -25,6 +25,7 @@ } xfail_functorch_batched_decomposition = { + "aten::alias_copy", "aten::diagonal_copy", "aten::is_same_size", "aten::unfold_copy", diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index b9651ea2da80..6abb13d244e9 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -305,6 +305,7 @@ "linalg_eig", "diagonal_copy", "diagonal_scatter", + "alias_copy", "select_backward", "diagonal_backward", "slice_backward", diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py index b277bb7eceb0..74587354424f 100644 --- a/torch/_decomp/__init__.py +++ b/torch/_decomp/__init__.py @@ -260,6 +260,7 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]: aten.addcmul_, aten.addr, aten.affine_grid_generator, + aten.alias_copy, aten.all, aten.aminmax, aten.arange.default, diff --git a/torch/_inductor/exc.py b/torch/_inductor/exc.py index 9e6aa6effae2..83638e91679d 100644 --- a/torch/_inductor/exc.py +++ b/torch/_inductor/exc.py @@ -45,7 +45,7 @@ def __init__(self, target, args, kwargs): There is a decomposition available for {target} in torch._decomp.get_decompositions(). Please add this operator to the - `decompositions` list in torch._inductor.decompositions + `decompositions` list in torch._inductor.decomposition """ ) ) diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index 68675c751736..f166edc009d8 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -232,6 +232,7 @@ # View & Shape Ops # "alias", + "alias_copy", "atleast_1d", "atleast_2d", "atleast_3d", @@ -4451,6 +4452,9 @@ def alias(a: TensorLikeType) -> TensorLikeType: return prims.view_of(a) +alias_copy = _make_copy_from_view(alias) + + @register_decomposition(aten.transpose) def transpose(a: TensorLikeType, dim0: int, dim1: int) -> TensorLikeType: _dim0, _dim1 = utils.canonicalize_dims(a.ndim, (dim0, dim1)) # type: ignore[misc] diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 151210cf9f53..e3fec2129634 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -11584,6 +11584,12 @@ def reference_flatten(input, start_dim=0, end_dim=-1): out_shape = in_shape[:start_dim] + (flatten_bit_dim,) + in_shape[end_dim + 1:] return np.reshape(input, out_shape) + +def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): + yield SampleInput(make_tensor((S,), dtype=dtype, device=device, requires_grad=requires_grad)) + yield SampleInput(make_tensor((), dtype=dtype, device=device, requires_grad=requires_grad)) + + # Operator database (sorted alphabetically) op_db: List[OpInfo] = [ UnaryUfuncInfo('abs', @@ -13087,6 +13093,11 @@ def reference_flatten(input, start_dim=0, end_dim=-1): supports_forward_ad=True, supports_fwgrad_bwgrad=True, sample_inputs_func=sample_inputs_diagonal_scatter), + OpInfo('alias_copy', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + sample_inputs_func=sample_inputs_alias_copy, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True), BinaryUfuncInfo('eq', ref=np.equal, dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), @@ -23223,6 +23234,10 @@ def reference_flatten(input, start_dim=0, end_dim=-1): # # View & Shape OpInfos # + PythonRefInfo( + "_refs.alias_copy", + torch_opinfo_name="alias_copy", + ), PythonRefInfo( "_refs.atleast_1d", torch_opinfo_name="atleast_1d", From c97e3ebb96d7457075b019b94411e8c2d058e68b Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Wed, 5 Jun 2024 05:53:23 +0000 Subject: [PATCH 406/706] Fix wrongly exposed variables in `torch/__init__.py` (#127795) image This PR removes temporary variables in `torch/__init__.py`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127795 Approved by: https://github.com/albanD --- test/allowlist_for_publicAPI.json | 2 - torch/__init__.py | 88 +++++++++++++++++-------------- torch/_dynamo/trace_rules.py | 2 +- 3 files changed, 48 insertions(+), 44 deletions(-) diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json index c3d3fe2f00ec..0ead16868f2f 100644 --- a/test/allowlist_for_publicAPI.json +++ b/test/allowlist_for_publicAPI.json @@ -1321,12 +1321,10 @@ "_weight_norm_interface", "autocast", "broadcast_shapes", - "candidate", "compiled_with_cxx11_abi", "from_dlpack", "lobpcg", "lu", - "obj", "segment_reduce", "set_default_dtype", "set_grad_enabled", diff --git a/torch/__init__.py b/torch/__init__.py index dfb1da76739d..2efc457f33ba 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -699,8 +699,6 @@ def sym_min(a, b): return builtins.min(a, b) # Drop in replacement for math.sqrt, math.sin, math.cos etc -current_module = sys.modules[__name__] - def _get_sym_math_fn(name): def fn(a): from .overrides import has_torch_function_unary, handle_torch_function @@ -713,18 +711,19 @@ def fn(a): return fn -for name in ("sqrt", "cos", "cosh", "sin", "sinh", "tan", "tanh", "asin", "acos", "atan"): - sym_name = f"_sym_{name}" - fn = _get_sym_math_fn(name) - fn.__qualname__ = fn.__name__ = sym_name - setattr(current_module, sym_name, fn) +__fn, __name, __sym_name = None, '', '' +for __name in ("sqrt", "cos", "cosh", "sin", "sinh", "tan", "tanh", "asin", "acos", "atan"): + __sym_name = f"_sym_{__name}" + __fn = _get_sym_math_fn(__name) + __fn.__qualname__ = __fn.__name__ = __sym_name + globals()[__sym_name] = __fn + +del __fn, __name, __sym_name, _get_sym_math_fn # Adding temporary shortcut -sym_sqrt = current_module._sym_sqrt +sym_sqrt = globals()["_sym_sqrt"] __all__.append("sym_sqrt") -del fn, name, sym_name, current_module # type: ignore[possibly-undefined] - def sym_ite(b, t, f): from .overrides import has_torch_function, handle_torch_function @@ -760,30 +759,35 @@ def sym_ite(b, t, f): ''').strip()) from None raise # If __file__ is not None the cause is unknown, so just re-raise. -for name in dir(_C): - if name[0] != '_' and not name.endswith('Base'): - __all__.append(name) - obj = getattr(_C, name) - if (isinstance(obj, Callable) or inspect.isclass(obj)): # type: ignore[arg-type] - if (obj.__module__ != 'torch'): +__name, __obj = '', None +for __name in dir(_C): + if __name[0] != '_' and not __name.endswith('Base'): + __all__.append(__name) + __obj = getattr(_C, __name) + if callable(__obj) or inspect.isclass(__obj): + if __obj.__module__ != __name__: # TODO: fix their module from C++ side - if name not in ['DisableTorchFunctionSubclass', 'DisableTorchFunction', 'Generator']: - obj.__module__ = 'torch' - elif name == 'TensorBase': + if __name not in ['DisableTorchFunctionSubclass', 'DisableTorchFunction', 'Generator']: + __obj.__module__ = __name__ + elif __name == 'TensorBase': # issue 109438 / pr 109940. Prevent TensorBase from being copied into torch. - delattr(sys.modules[__name__], name) + delattr(sys.modules[__name__], __name) + +del __name, __obj if not TYPE_CHECKING: # issue 38137 and python issue 43367. Submodules of a C extension are # non-standard, and attributes of those submodules cannot be pickled since # pickle expect to be able to import them as "from _C.sub import attr" # which fails with "_C is not a package - for attr in dir(_C): - candidate = getattr(_C, attr) - if type(candidate) is type(_C): + __name, __candidate = '', None + for __name in dir(_C): + __candidate = getattr(_C, __name) + if type(__candidate) is type(_C): # submodule - if f'torch._C.{attr}' not in sys.modules: - sys.modules[f'torch._C.{attr}'] = candidate + sys.modules.setdefault(f"{__name__}._C.{__name}", __candidate) + + del __name, __candidate ################################################################################ @@ -1669,7 +1673,7 @@ def _dtype(self): # Initialize extension ################################################################################ -def manager_path(): +def _manager_path(): if _running_with_deploy() or platform.system() == 'Windows': return b"" path = get_file_path('torch', 'bin', 'torch_shm_manager') @@ -1686,8 +1690,8 @@ def manager_path(): py_int = int # Shared memory manager needs to know the exact location of manager executable -_C._initExtension(manager_path()) -del manager_path +_C._initExtension(_manager_path()) +del _manager_path # Appease the type checker: it can't deal with direct setting of globals(). # Note that we will see "too many" functions when reexporting this way; there @@ -1708,20 +1712,22 @@ def manager_path(): 'unique_dim', ) -for name in dir(_C._VariableFunctions): - if name.startswith('__') or name in PRIVATE_OPS: +__name, __obj = '', None +for __name in dir(_C._VariableFunctions): + if __name.startswith('__') or __name in PRIVATE_OPS: continue - obj = getattr(_C._VariableFunctions, name) - obj.__module__ = 'torch' + __obj = getattr(_C._VariableFunctions, __name) + __obj.__module__ = __name__ # Hide some APIs that should not be public - if name == "segment_reduce": + if __name == "segment_reduce": # TODO: Once the undocumented FC window is passed, remove the line bellow - globals()[name] = obj - name = "_" + name - globals()[name] = obj - if not name.startswith("_"): - __all__.append(name) + globals()[__name] = __obj + __name = "_" + __name + globals()[__name] = __obj + if not __name.startswith("_"): + __all__.append(__name) +del __name, __obj ################################################################################ # Add torch.dtype instances to the public API @@ -1729,9 +1735,9 @@ def manager_path(): import torch -for attribute in dir(torch): - if isinstance(getattr(torch, attribute), torch.dtype): - __all__.append(attribute) +__all__.extend( + name for name in dir(torch) if isinstance(getattr(torch, name), torch.dtype) +) ################################################################################ # Import TorchDynamo's lazy APIs to avoid circular dependenices diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 90f0667fecfb..73c4beb547ee 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -1995,7 +1995,6 @@ "torch.not_equal", "torch.nuclear_norm", "torch.numel", - "torch.obj", "torch.ones_like", "torch.ones", "torch.orgqr", @@ -2182,6 +2181,7 @@ "torch.xlogy", "torch.zero_", "torch.zeros", + "torch.zeros_like", "torch._fused_sgd_", "torch.slice_inverse", "torch._assert_scalar", From f08fd8e9e313adc44d1cd8cffd73edabfe231457 Mon Sep 17 00:00:00 2001 From: Hengwen Tong Date: Thu, 6 Jun 2024 13:01:39 +0000 Subject: [PATCH 407/706] Remove redundant device guard in Resize.h (#126498) In https://github.com/pytorch/pytorch/pull/113386 a device guard was [inserted](https://github.com/pytorch/pytorch/pull/113386/files#diff-2691af3a999b3a8f4a0f635aabcd8edf0ffeda501edfa9366648e8a89de12a90R30). The new inserted device guarded has a clear and more confined guarded scope. And it's hard to tell the exact purpose and scope of the [old device guard](https://github.com/kurtamohler/pytorch/blob/78ffe49a3fcd3ddc4f9f98500ccd3cbdee22a029/aten/src/ATen/native/cuda/Resize.h#L41). Removing the guard has negligible positive performance impact and make the code more understandable. Thanks Pull Request resolved: https://github.com/pytorch/pytorch/pull/126498 Approved by: https://github.com/eqy, https://github.com/lezcano --- aten/src/ATen/native/cuda/Resize.h | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/aten/src/ATen/native/cuda/Resize.h b/aten/src/ATen/native/cuda/Resize.h index 569b145fa61d..d5de128cac1d 100644 --- a/aten/src/ATen/native/cuda/Resize.h +++ b/aten/src/ATen/native/cuda/Resize.h @@ -29,18 +29,10 @@ static inline void maybe_resize_storage_cuda(TensorImpl* self, size_t new_size_b inline TensorImpl* resize_impl_cuda_( TensorImpl* self, IntArrayRef size, - at::OptionalIntArrayRef stride, - bool device_guard = true) { + at::OptionalIntArrayRef stride) { if (self->sizes() == size && (!stride || self->strides() == stride)) { return self; } - - // NB: We don't need to hold the device guard when calling from TH - cuda::OptionalCUDAGuard guard; - if (device_guard) { - guard.set_index(self->storage().device().index()); - } - const auto itemsize = self->dtype().itemsize(); const auto storage_offset = self->storage_offset(); size_t storage_size = 1; From 48a54146e78773bac268493a6b4eb9be392b1b9e Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 6 Jun 2024 14:21:29 +0000 Subject: [PATCH 408/706] Revert "[dynamo] Support ndarray.dtype attribute access (#124490)" This reverts commit 4adee71155bec4e419bac32be2cbc1763bc6c98f. Reverted https://github.com/pytorch/pytorch/pull/124490 on behalf of https://github.com/atalman due to Breaks internal builds ([comment](https://github.com/pytorch/pytorch/pull/124490#issuecomment-2152664749)) --- test/dynamo/test_functions.py | 4 -- test/test_binary_ufuncs.py | 4 +- test/test_unary_ufuncs.py | 2 +- .../numpy_tests/core/test_multiarray.py | 15 ++----- torch/_dynamo/variables/misc.py | 5 --- torch/_dynamo/variables/tensor.py | 3 -- torch/testing/_internal/common_utils.py | 42 +++++++++---------- 7 files changed, 29 insertions(+), 46 deletions(-) diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 4d0285871ced..e2baebf60321 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -1619,10 +1619,6 @@ def test_numpy_dtype_call_in_function(x): dt = np.dtype("float") return np.full_like(x, 2.4, dtype=dt) - @make_test - def test_numpy_dtype_attr(x): - return np.ones_like(x).dtype == x.dtype - @make_test def test_numpy_linalg(x): return np.linalg.norm(x.numpy(), axis=0) diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py index f1423d0ac8cb..ffa3e5388979 100644 --- a/test/test_binary_ufuncs.py +++ b/test/test_binary_ufuncs.py @@ -121,7 +121,9 @@ def _test_reference_numerics(self, dtype, op, gen, equal_nan=True): def _helper_reference_numerics( expected, actual, msg, exact_dtype, equal_nan=True ): - if not torch.can_cast(numpy_to_torch_dtype_dict[expected.dtype], dtype): + if not torch.can_cast( + numpy_to_torch_dtype_dict[expected.dtype.type], dtype + ): exact_dtype = False if dtype is torch.bfloat16 and expected.dtype == np.float32: diff --git a/test/test_unary_ufuncs.py b/test/test_unary_ufuncs.py index b232d47260f4..f47e7d36222f 100644 --- a/test/test_unary_ufuncs.py +++ b/test/test_unary_ufuncs.py @@ -184,7 +184,7 @@ def _helper_reference_numerics( expected, actual, msg, exact_dtype, equal_nan=True ): if not torch.can_cast( - numpy_to_torch_dtype_dict[expected.dtype], dtype + numpy_to_torch_dtype_dict[expected.dtype.type], dtype ): exact_dtype = False diff --git a/test/torch_np/numpy_tests/core/test_multiarray.py b/test/torch_np/numpy_tests/core/test_multiarray.py index a957c8dd86c4..76af79f62084 100644 --- a/test/torch_np/numpy_tests/core/test_multiarray.py +++ b/test/torch_np/numpy_tests/core/test_multiarray.py @@ -1833,7 +1833,7 @@ def test_argsort_axis(self): a = np.array(["aaaaaaaaa" for i in range(100)], dtype=np.unicode_) assert_equal(a.argsort(kind="m"), r) - @xfail # (reason="TODO: searchsorted with nans differs in pytorch") + @xpassIfTorchDynamo # (reason="TODO: searchsorted with nans differs in pytorch") @parametrize( "a", [ @@ -1905,7 +1905,7 @@ def test_searchsorted_n_elements(self): b = a.searchsorted([0, 1, 2], "right") assert_equal(b, [0, 2, 2]) - @xfail # ( + @xpassIfTorchDynamo # ( # reason="RuntimeError: self.storage_offset() must be divisible by 8" # ) def test_searchsorted_unaligned_array(self): @@ -1984,7 +1984,7 @@ def test_searchsorted_with_invalid_sorter(self): # assert_raises(ValueError, np.searchsorted, a, 0, sorter=[-1, 0, 1, 2, 3]) # assert_raises(ValueError, np.searchsorted, a, 0, sorter=[4, 0, -1, 2, 3]) - @xfail # (reason="self.storage_offset() must be divisible by 8") + @xpassIfTorchDynamo # (reason="self.storage_offset() must be divisible by 8") def test_searchsorted_with_sorter(self): a = np.random.rand(300) s = a.argsort() @@ -3713,14 +3713,7 @@ def test_out_overlap(self): y = np.take(x, [1, 2, 3], out=x[2:5], mode="wrap") assert_equal(y, np.array([1, 2, 3])) - @parametrize( - "shape", - [ - subtest((1, 2)), - subtest((1,)), - subtest((), decorators=[skip("Sensitive to np version")]), - ], - ) + @parametrize("shape", [(1, 2), (1,), ()]) def test_ret_is_out(self, shape): # 0d arrays should not be an exception to this rule x = np.arange(5) diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 06372f4d53b5..cc0fb7096701 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -1189,11 +1189,6 @@ class NumpyTypeInfoVariable(ConstantLikeVariable): class NumpyDTypeVariable(ConstantLikeVariable): _error_prefix = "np.dtype[...]" - def __init__(self, value, **kwargs): - if isinstance(value, tnp.DType): - value = ConstantLikeVariable.np_dtype(value.name) - super().__init__(value, **kwargs) - def as_proxy(self): """Similar to how numpy dtype descriptors (e.g. np.float32 ) are handled by NumpyVariable: diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 30cbd556d0b2..0552a8e62122 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -1089,7 +1089,6 @@ def var_getattr(self, tx, name): from ..utils import numpy_attr_wrapper from .builder import wrap_fx_proxy - from .misc import NumpyDTypeVariable result = None @@ -1136,8 +1135,6 @@ def insert_into_graph(): if not has_free_symbols(r := example_ndarray.size): return ConstantVariable.create(int(r)) return insert_into_graph() - if name == "dtype": - return NumpyDTypeVariable(example_ndarray.dtype) elif name in ["base", "flags", "dtype"]: unimplemented(f"TODO: add support for ndarray.{name}") elif name in ["__version__"]: diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index a85c44fe1e05..8e3a66c77929 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -1500,31 +1500,31 @@ def wrapper(*args, **kwargs): # Dict of NumPy dtype -> torch dtype (when the correspondence exists) numpy_to_torch_dtype_dict = { - np.dtype(np.bool_) : torch.bool, - np.dtype(np.uint8) : torch.uint8, - np.dtype(np.uint16) : torch.uint16, - np.dtype(np.uint32) : torch.uint32, - np.dtype(np.uint64) : torch.uint64, - np.dtype(np.int8) : torch.int8, - np.dtype(np.int16) : torch.int16, - np.dtype(np.int32) : torch.int32, - np.dtype(np.int64) : torch.int64, - np.dtype(np.float16) : torch.float16, - np.dtype(np.float32) : torch.float32, - np.dtype(np.float64) : torch.float64, - np.dtype(np.complex64) : torch.complex64, - np.dtype(np.complex128): torch.complex128 + np.bool_ : torch.bool, + np.uint8 : torch.uint8, + np.uint16 : torch.uint16, + np.uint32 : torch.uint32, + np.uint64 : torch.uint64, + np.int8 : torch.int8, + np.int16 : torch.int16, + np.int32 : torch.int32, + np.int64 : torch.int64, + np.float16 : torch.float16, + np.float32 : torch.float32, + np.float64 : torch.float64, + np.complex64 : torch.complex64, + np.complex128 : torch.complex128 } -# numpy dtypes like np.float64 are not instances, but rather classes. This leads -# to rather absurd cases like np.float64 != np.dtype("float64") but -# np.dtype(np.float64) == np.dtype("float64") and -# np.dtype(np.dtype("float64")) == np.dtype("float64"). Especially when -# checking against a reference we can't be sure which variant we get, so we -# simply apply the conversion. +# numpy dtypes like np.float64 are not instances, but rather classes. This leads to rather absurd cases like +# np.float64 != np.dtype("float64") but np.float64 == np.dtype("float64").type. +# Especially when checking against a reference we can't be sure which variant we get, so we simply try both. def numpy_to_torch_dtype(np_dtype): - return numpy_to_torch_dtype_dict[np.dtype(np_dtype)] + try: + return numpy_to_torch_dtype_dict[np_dtype] + except KeyError: + return numpy_to_torch_dtype_dict[np_dtype.type] def has_corresponding_torch_dtype(np_dtype): From 9d849d4312cd1e62d97b9e9d58979ec78d36c95f Mon Sep 17 00:00:00 2001 From: Jithun Nair Date: Thu, 6 Jun 2024 15:17:35 +0000 Subject: [PATCH 409/706] Disable py3.12 nightly wheel builds for ROCm (#127968) Triton commit bump PR https://github.com/pytorch/pytorch/pull/125396 reverted due to missing llnl-hatchet dependency for triton. Workaround is to disable py3.12 binary build jobs for ROCm on PyTorch CI until llnl-hatchet publishes py3.12 wheels on [PyPI](https://pypi.org/project/llnl-hatchet/#files) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127968 Approved by: https://github.com/atalman, https://github.com/pruthvistony --- .../scripts/generate_binary_build_matrix.py | 4 + ...nerated-linux-binary-manywheel-nightly.yml | 206 ------------------ 2 files changed, 4 insertions(+), 206 deletions(-) diff --git a/.github/scripts/generate_binary_build_matrix.py b/.github/scripts/generate_binary_build_matrix.py index b192475f72b1..855f37af7eda 100644 --- a/.github/scripts/generate_binary_build_matrix.py +++ b/.github/scripts/generate_binary_build_matrix.py @@ -347,6 +347,10 @@ def generate_wheels_matrix( for python_version in python_versions: for arch_version in arches: gpu_arch_type = arch_type(arch_version) + # Disable py3.12 builds for ROCm because of triton dependency + # on llnl-hatchet, which doesn't have py3.12 wheels available + if gpu_arch_type == "rocm" and python_version == "3.12": + continue gpu_arch_version = ( "" if arch_version == "cpu" diff --git a/.github/workflows/generated-linux-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-binary-manywheel-nightly.yml index 8ad43b4c3660..272e15577cc7 100644 --- a/.github/workflows/generated-linux-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-binary-manywheel-nightly.yml @@ -2410,209 +2410,3 @@ jobs: conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_12-rocm6_0-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm6.0 - GPU_ARCH_VERSION: 6.0 - GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.0-main - DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-rocm6_0 - build_environment: linux-binary-manywheel - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-rocm6_0-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: manywheel-py3_12-rocm6_0-build - runs-on: linux.rocm.gpu - timeout-minutes: 240 - env: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm6.0 - GPU_ARCH_VERSION: 6.0 - GPU_ARCH_TYPE: rocm - SKIP_ALL_TESTS: 1 - DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.0-main - DESIRED_PYTHON: "3.12" - steps: - - name: Setup ROCm - uses: ./.github/actions/setup-rocm - - uses: actions/download-artifact@v3 - name: Download Build Artifacts - with: - name: manywheel-py3_12-rocm6_0 - path: "${{ runner.temp }}/artifacts/" - - name: Checkout PyTorch - uses: malfet/checkout@silent-checkout - with: - ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} - submodules: recursive - path: pytorch - quiet-checkout: true - - name: Clean PyTorch checkout - run: | - # Remove any artifacts from the previous checkouts - git clean -fxd - working-directory: pytorch - - name: Checkout pytorch/builder - uses: malfet/checkout@silent-checkout - with: - ref: main - submodules: recursive - repository: pytorch/builder - path: builder - quiet-checkout: true - - name: Clean pytorch/builder checkout - run: | - # Remove any artifacts from the previous checkouts - git clean -fxd - working-directory: builder - - name: ROCm set GPU_FLAG - run: | - echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device=/dev/dri --group-add video --group-add daemon" >> "${GITHUB_ENV}" - - name: Pull Docker image - uses: pytorch/test-infra/.github/actions/pull-docker-image@main - with: - docker-image: pytorch/manylinux-builder:rocm6.0-main - - name: Test Pytorch binary - uses: ./pytorch/.github/actions/test-pytorch-binary - - name: Teardown ROCm - uses: ./.github/actions/teardown-rocm - manywheel-py3_12-rocm6_0-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_12-rocm6_0-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm6.0 - GPU_ARCH_VERSION: 6.0 - GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.0-main - DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-rocm6_0 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_12-rocm6_1-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm6.1 - GPU_ARCH_VERSION: 6.1 - GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main - DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-rocm6_1 - build_environment: linux-binary-manywheel - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-rocm6_1-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: manywheel-py3_12-rocm6_1-build - runs-on: linux.rocm.gpu - timeout-minutes: 240 - env: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm6.1 - GPU_ARCH_VERSION: 6.1 - GPU_ARCH_TYPE: rocm - SKIP_ALL_TESTS: 1 - DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main - DESIRED_PYTHON: "3.12" - steps: - - name: Setup ROCm - uses: ./.github/actions/setup-rocm - - uses: actions/download-artifact@v3 - name: Download Build Artifacts - with: - name: manywheel-py3_12-rocm6_1 - path: "${{ runner.temp }}/artifacts/" - - name: Checkout PyTorch - uses: malfet/checkout@silent-checkout - with: - ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} - submodules: recursive - path: pytorch - quiet-checkout: true - - name: Clean PyTorch checkout - run: | - # Remove any artifacts from the previous checkouts - git clean -fxd - working-directory: pytorch - - name: Checkout pytorch/builder - uses: malfet/checkout@silent-checkout - with: - ref: main - submodules: recursive - repository: pytorch/builder - path: builder - quiet-checkout: true - - name: Clean pytorch/builder checkout - run: | - # Remove any artifacts from the previous checkouts - git clean -fxd - working-directory: builder - - name: ROCm set GPU_FLAG - run: | - echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device=/dev/dri --group-add video --group-add daemon" >> "${GITHUB_ENV}" - - name: Pull Docker image - uses: pytorch/test-infra/.github/actions/pull-docker-image@main - with: - docker-image: pytorch/manylinux-builder:rocm6.1-main - - name: Test Pytorch binary - uses: ./pytorch/.github/actions/test-pytorch-binary - - name: Teardown ROCm - uses: ./.github/actions/teardown-rocm - manywheel-py3_12-rocm6_1-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_12-rocm6_1-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm6.1 - GPU_ARCH_VERSION: 6.1 - GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main - DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-rocm6_1 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml From c58d3af3b47dd1413c1401fe9e1d90d00d428cd0 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 6 Jun 2024 15:44:47 +0000 Subject: [PATCH 410/706] Revert "Add OpInfo entry for alias_copy (#127232)" This reverts commit 457df212e1c6e1aa4f1eb2ad6ee292052d7c07e1. Reverted https://github.com/pytorch/pytorch/pull/127232 on behalf of https://github.com/clee2000 due to broke [onnx](https://github.com/pytorch/pytorch/actions/runs/9397057801/job/25880181144) and [mps](https://github.com/pytorch/pytorch/actions/runs/9397057805/job/25879818705) tests, [hud link](https://hud.pytorch.org/pytorch/pytorch/commit/457df212e1c6e1aa4f1eb2ad6ee292052d7c07e1) , base is 15 days old, the onnx test xfailed on the pr but the xfail was removed so if you rebase itll surface, mps build failed so no mps tests were run on the pr ([comment](https://github.com/pytorch/pytorch/pull/127232#issuecomment-2152848758)) --- .../ATen/functorch/BatchRulesDecompositions.cpp | 1 - test/distributed/_tensor/test_dtensor_ops.py | 1 - .../HasDecompTest.test_has_decomposition.expect | 2 ++ test/functorch/test_vmap_registrations.py | 1 - tools/autograd/gen_variable_type.py | 1 - torch/_decomp/__init__.py | 1 - torch/_inductor/exc.py | 2 +- torch/_refs/__init__.py | 4 ---- .../_internal/common_methods_invocations.py | 15 --------------- 9 files changed, 3 insertions(+), 25 deletions(-) diff --git a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp index a0007aa18a00..3e064d6c39dc 100644 --- a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp +++ b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp @@ -324,7 +324,6 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) { OP_DECOMPOSE(type_as); OP_DECOMPOSE(linalg_diagonal); OP_DECOMPOSE(diagonal_copy); - OP_DECOMPOSE(alias_copy); m.impl("pad", native::pad_symint); m.impl("_pad_circular", native::_pad_circular_symint); OP_DECOMPOSE(swapdims_); diff --git a/test/distributed/_tensor/test_dtensor_ops.py b/test/distributed/_tensor/test_dtensor_ops.py index 07f8bfedc615..83f0bb875167 100644 --- a/test/distributed/_tensor/test_dtensor_ops.py +++ b/test/distributed/_tensor/test_dtensor_ops.py @@ -102,7 +102,6 @@ def wrapped(fn): xfail("addr"), xfail("all"), xfail("allclose"), - xfail("alias_copy"), xfail("amax"), xfail("amin"), xfail("aminmax"), diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index eeee3685e1fb..ad9cf07d7550 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -647,6 +647,8 @@ aten::adaptive_max_pool3d_backward.grad_input aten::addbmm aten::addbmm.out aten::addr_ +aten::alias_copy +aten::alias_copy.out aten::allclose aten::angle aten::angle.out diff --git a/test/functorch/test_vmap_registrations.py b/test/functorch/test_vmap_registrations.py index 737927a60f80..967152945af5 100644 --- a/test/functorch/test_vmap_registrations.py +++ b/test/functorch/test_vmap_registrations.py @@ -25,7 +25,6 @@ } xfail_functorch_batched_decomposition = { - "aten::alias_copy", "aten::diagonal_copy", "aten::is_same_size", "aten::unfold_copy", diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index 6abb13d244e9..b9651ea2da80 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -305,7 +305,6 @@ "linalg_eig", "diagonal_copy", "diagonal_scatter", - "alias_copy", "select_backward", "diagonal_backward", "slice_backward", diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py index 74587354424f..b277bb7eceb0 100644 --- a/torch/_decomp/__init__.py +++ b/torch/_decomp/__init__.py @@ -260,7 +260,6 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]: aten.addcmul_, aten.addr, aten.affine_grid_generator, - aten.alias_copy, aten.all, aten.aminmax, aten.arange.default, diff --git a/torch/_inductor/exc.py b/torch/_inductor/exc.py index 83638e91679d..9e6aa6effae2 100644 --- a/torch/_inductor/exc.py +++ b/torch/_inductor/exc.py @@ -45,7 +45,7 @@ def __init__(self, target, args, kwargs): There is a decomposition available for {target} in torch._decomp.get_decompositions(). Please add this operator to the - `decompositions` list in torch._inductor.decomposition + `decompositions` list in torch._inductor.decompositions """ ) ) diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index f166edc009d8..68675c751736 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -232,7 +232,6 @@ # View & Shape Ops # "alias", - "alias_copy", "atleast_1d", "atleast_2d", "atleast_3d", @@ -4452,9 +4451,6 @@ def alias(a: TensorLikeType) -> TensorLikeType: return prims.view_of(a) -alias_copy = _make_copy_from_view(alias) - - @register_decomposition(aten.transpose) def transpose(a: TensorLikeType, dim0: int, dim1: int) -> TensorLikeType: _dim0, _dim1 = utils.canonicalize_dims(a.ndim, (dim0, dim1)) # type: ignore[misc] diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index e3fec2129634..151210cf9f53 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -11584,12 +11584,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1): out_shape = in_shape[:start_dim] + (flatten_bit_dim,) + in_shape[end_dim + 1:] return np.reshape(input, out_shape) - -def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): - yield SampleInput(make_tensor((S,), dtype=dtype, device=device, requires_grad=requires_grad)) - yield SampleInput(make_tensor((), dtype=dtype, device=device, requires_grad=requires_grad)) - - # Operator database (sorted alphabetically) op_db: List[OpInfo] = [ UnaryUfuncInfo('abs', @@ -13093,11 +13087,6 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): supports_forward_ad=True, supports_fwgrad_bwgrad=True, sample_inputs_func=sample_inputs_diagonal_scatter), - OpInfo('alias_copy', - dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), - sample_inputs_func=sample_inputs_alias_copy, - supports_forward_ad=True, - supports_fwgrad_bwgrad=True), BinaryUfuncInfo('eq', ref=np.equal, dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), @@ -23234,10 +23223,6 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): # # View & Shape OpInfos # - PythonRefInfo( - "_refs.alias_copy", - torch_opinfo_name="alias_copy", - ), PythonRefInfo( "_refs.atleast_1d", torch_opinfo_name="atleast_1d", From a5ba9b2858c5535cd51e79cd861fe96763343d0c Mon Sep 17 00:00:00 2001 From: Joona Havukainen Date: Thu, 6 Jun 2024 16:09:18 +0000 Subject: [PATCH 411/706] Fix for addcdiv contiguous problem (#124442) Fixes issue number #118115 Co-authored-by: Siddharth Kotapati Pull Request resolved: https://github.com/pytorch/pytorch/pull/124442 Approved by: https://github.com/kulinseth --- .../native/mps/operations/PointwiseOps.mm | 27 ++++++++++++++--- test/test_mps.py | 30 +++++++++++++++++++ 2 files changed, 53 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/native/mps/operations/PointwiseOps.mm b/aten/src/ATen/native/mps/operations/PointwiseOps.mm index 137c14be6ef4..9010dd3add24 100644 --- a/aten/src/ATen/native/mps/operations/PointwiseOps.mm +++ b/aten/src/ATen/native/mps/operations/PointwiseOps.mm @@ -38,6 +38,19 @@ static void addc_mul_div_out_mps(const Tensor& self, }; @autoreleasepool { + bool executeGatherOpOnSelf = + !(self.is_contiguous(MemoryFormat::Contiguous) || self.is_contiguous(MemoryFormat::ChannelsLast) || + self.is_contiguous(MemoryFormat::ChannelsLast3d)); + Tensor output_ = at::empty_like(self, executeGatherOpOnSelf ? MemoryFormat::Contiguous : MemoryFormat::Preserve); + + bool executeGatherOpOnFirstTensor = + !(tensor1.is_contiguous(MemoryFormat::Contiguous) || tensor1.is_contiguous(MemoryFormat::ChannelsLast) || + tensor1.is_contiguous(MemoryFormat::ChannelsLast3d)); + + bool executeGatherOpOnSecondTensor = + !(tensor2.is_contiguous(MemoryFormat::Contiguous) || tensor2.is_contiguous(MemoryFormat::ChannelsLast) || + tensor2.is_contiguous(MemoryFormat::ChannelsLast3d)); + string key = op_name + getTensorsStringKey({self, tensor1, tensor2}); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { @@ -72,10 +85,12 @@ static void addc_mul_div_out_mps(const Tensor& self, }); // Inputs as placeholders - Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor, self); - Placeholder tensor1Placeholder = Placeholder(cachedGraph->firstTensor, tensor1); - Placeholder tensor2Placeholder = Placeholder(cachedGraph->secondTensor, tensor2); - Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, output); + Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor, self, nil, executeGatherOpOnSelf); + Placeholder tensor1Placeholder = Placeholder(cachedGraph->firstTensor, tensor1, nil, executeGatherOpOnFirstTensor); + Placeholder tensor2Placeholder = + Placeholder(cachedGraph->secondTensor, tensor2, nil, executeGatherOpOnSecondTensor); + Placeholder outputPlaceholder = + Placeholder(cachedGraph->outputTensor, executeGatherOpOnSelf ? output_ : output, nil, false); MPSScalar value_scalar = getMPSScalar(value_opt, self.scalar_type()); // Create dictionary of inputs and outputs @@ -87,6 +102,10 @@ static void addc_mul_div_out_mps(const Tensor& self, }; runMPSGraph(mpsStream, cachedGraph->graph(), feeds, outputPlaceholder); + + if (executeGatherOpOnSelf) { + output.copy_(output_); + } } } diff --git a/test/test_mps.py b/test/test_mps.py index 32e7a2a08861..93437fd5509d 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -3277,6 +3277,36 @@ def helper(shape, value): helper((2, 8, 4, 5), 0.2) helper((2, 3, 4, 5), 1.0) # value of 1 should be ignored internally + def test_addcdiv_transpose(self): + # Regression test for issue https://github.com/pytorch/pytorch/issues/118115 + # Testing continuity of all input tensors + + def helper(shape, value): + shape_t = shape[::-1] + for i in range(2): + for j in range(2): + for k in range(2): + x = torch.rand(shape, device="cpu") if i == 0 else torch.rand(shape_t, device="cpu").t() + y = torch.rand(shape, device="cpu") if j == 0 else torch.rand(shape_t, device="cpu").t() + z = torch.rand(shape, device="cpu") if k == 0 else torch.rand(shape_t, device="cpu").t() + + x_mps = x.detach().clone().to(device="mps") + y_mps = y.detach().clone().to(device="mps") + z_mps = z.detach().clone().to(device="mps") + + result_cpu = x.addcdiv_(y, z, value=value) + result_mps = x_mps.addcdiv(y_mps, z_mps, value=value) + result_mps_out = result_cpu.detach().clone().to('mps') + torch.addcdiv(x_mps, y_mps, z_mps, out=result_mps_out, value=value) + + self.assertEqual(result_cpu, result_mps) + self.assertEqual(result_cpu, result_mps_out) + + helper((2, 3), 1.0) + helper((2, 3), 0.2) + helper((100, 300), 1.0) + helper((100, 300), 0.2) + def test_buffer_size_match(self): # this test shouldn't cause any crash size = 16 From a5b86a1ec0150f78d88fc389df4909212ece0108 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 6 Jun 2024 16:12:34 +0000 Subject: [PATCH 412/706] Revert "FP8 rowwise scaling (#125204)" This reverts commit 5dc912822913b3d90f4938891c7eca722a057cf1. Reverted https://github.com/pytorch/pytorch/pull/125204 on behalf of https://github.com/atalman due to Sorry need to revert this failing, on internal CI. I suggest to reimport this and try to land internally resolving all issues ([comment](https://github.com/pytorch/pytorch/pull/125204#issuecomment-2152905513)) --- aten/src/ATen/CMakeLists.txt | 1 - aten/src/ATen/cuda/detail/LazyNVRTC.cpp | 37 -- aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h | 15 +- aten/src/ATen/native/cuda/Blas.cpp | 113 +--- aten/src/ATen/native/cuda/RowwiseScaledMM.cu | 536 ------------------- aten/src/ATen/native/cuda/RowwiseScaledMM.h | 15 - test/test_matmul_cuda.py | 149 +----- third_party/cutlass.BUILD | 14 +- 8 files changed, 25 insertions(+), 855 deletions(-) delete mode 100644 aten/src/ATen/native/cuda/RowwiseScaledMM.cu delete mode 100644 aten/src/ATen/native/cuda/RowwiseScaledMM.h diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 5cd6aacf2463..0087dd95d96e 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -473,7 +473,6 @@ endif() if(USE_CUDA AND NOT USE_ROCM) list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include) - list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/tools/util/include) if($ENV{ATEN_STATIC_CUDA}) list(APPEND ATen_CUDA_DEPENDENCY_LIBS ${CUDA_LIBRARIES} diff --git a/aten/src/ATen/cuda/detail/LazyNVRTC.cpp b/aten/src/ATen/cuda/detail/LazyNVRTC.cpp index 75c503d48d51..1b85e7776e22 100644 --- a/aten/src/ATen/cuda/detail/LazyNVRTC.cpp +++ b/aten/src/ATen/cuda/detail/LazyNVRTC.cpp @@ -170,43 +170,6 @@ CUDA_STUB3(cuLinkComplete, CUlinkState, void **, size_t *); CUDA_STUB3(cuFuncSetAttribute, CUfunction, CUfunction_attribute, int); CUDA_STUB3(cuFuncGetAttribute, int*, CUfunction_attribute, CUfunction); -#if defined(CUDA_VERSION) && CUDA_VERSION >= 12000 -CUresult CUDAAPI -cuTensorMapEncodeTiled( - CUtensorMap* tensorMap, - CUtensorMapDataType tensorDataType, - cuuint32_t tensorRank, - void* globalAddress, - const cuuint64_t* globalDim, - const cuuint64_t* globalStrides, - const cuuint32_t* boxDim, - const cuuint32_t* elementStrides, - CUtensorMapInterleave interleave, - CUtensorMapSwizzle swizzle, - CUtensorMapL2promotion l2Promotion, - CUtensorMapFloatOOBfill oobFill) { - auto fn = reinterpret_cast( - getCUDALibrary().sym(__func__)); - if (!fn) - throw std::runtime_error("Can't get cuTensorMapEncodeTiled"); - lazyNVRTC.cuTensorMapEncodeTiled = fn; - return fn( - tensorMap, - tensorDataType, - tensorRank, - globalAddress, - globalDim, - globalStrides, - boxDim, - elementStrides, - interleave, - swizzle, - l2Promotion, - oobFill); -} - -#endif - // Irregularly shaped functions CUresult CUDAAPI cuLaunchKernel(CUfunction f, unsigned int gridDimX, diff --git a/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h b/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h index cb34d10db254..574b2c41c264 100644 --- a/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h +++ b/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h @@ -59,25 +59,16 @@ namespace at { namespace cuda { _(cuLinkAddData) \ _(cuLinkComplete) \ _(cuFuncSetAttribute) \ - _(cuFuncGetAttribute) \ - -#if defined(CUDA_VERSION) && CUDA_VERSION >= 12000 -#define AT_FORALL_NVRTC_EXTENDED(_) \ - AT_FORALL_NVRTC_BASE(_) \ - _(cuTensorMapEncodeTiled) -#else -#define AT_FORALL_NVRTC_EXTENDED(_) \ - AT_FORALL_NVRTC_BASE(_) -#endif + _(cuFuncGetAttribute) #if defined(CUDA_VERSION) && CUDA_VERSION >= 11010 #define AT_FORALL_NVRTC(_) \ - AT_FORALL_NVRTC_EXTENDED(_) \ + AT_FORALL_NVRTC_BASE(_) \ _(nvrtcGetCUBINSize) \ _(nvrtcGetCUBIN) #else #define AT_FORALL_NVRTC(_) \ - AT_FORALL_NVRTC_EXTENDED(_) + AT_FORALL_NVRTC_BASE(_) #endif #else diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index ed59b47349cc..84c59a4fd0d7 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -1,7 +1,3 @@ -#include -#include -#include -#include #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include #include @@ -14,7 +10,6 @@ #include #include #include -#include #ifndef AT_PER_OPERATOR_HEADERS #include @@ -824,97 +819,24 @@ static bool _scaled_mm_allowed_device() { #endif } -namespace{ - -enum class ScalingType { - TensorWise, - RowWise, - Error -}; - -// Validates the scale tensors to scaled_mm -// And returns the type of scaling/which kernel to use -ScalingType get_scaling_type( - const c10::optional& scale_a, - const c10::optional& scale_b, - int64_t dim_m, - int64_t dim_n) { - TORCH_CHECK( - scale_a.has_value() == scale_b.has_value(), - "Both scale_a and scale_b must be present or absent."); - - if (scale_a.has_value()) { - // Both Per-Tensor and Row-wise scaling expect fp32 tensors - TORCH_CHECK( - scale_a->scalar_type() == kFloat && scale_b->scalar_type() == kFloat, - "Both scale_a and scale_b must be float (fp32) tensors."); - - // Check the singluar scale case for per-tensor scaling - if (scale_a->numel() == 1 && scale_b->numel() == 1) { - return ScalingType::TensorWise; - } else if (scale_a->dim() == 1 && scale_a->size(0) == dim_m) { -// Check the per-row scaling case -#if !defined(USE_ROCM) && !defined(_MSC_VER) || \ - (defined(USE_ROCM) && ROCM_VERSION >= 60000) - TORCH_CHECK( - scale_a->dim() == 1 && scale_b->dim() == 1, - "Both scale_a and scale_b must be 1-dimensional tensors"); - TORCH_CHECK( - scale_b->size(0) == dim_n, - "For row-wise scaling, scale_b must have size ", - dim_n, - " but got ", - scale_b->size(0), - "."); - TORCH_CHECK( - scale_a->is_contiguous() && scale_b->is_contiguous(), - "Both scale_a and scale_b must be contiguous."); - return ScalingType::RowWise; -#else - TORCH_CHECK(false, "Per-row scaling is not supported for this platform!"); - return ScalingType::Error; -#endif // !defined(USE_ROCM) && !defined(_MSC_VER) || (defined(USE_ROCM) && - // ROCM_VERSION >= 60000) - } else { - TORCH_CHECK( - false, - "For row-wise scaling, scale_a must be size ", - dim_m, - " but got ", - scale_a->numel(), - " and scale_b must be size ", - dim_n, - " but got ", - scale_b->numel(), - "."); - // Unreachable - return ScalingType::RowWise; - } - } - return ScalingType::Error; -} - -} // namespace - // Computes matrix multiply + bias while applying scaling to input and output matrices and computes amax // Scales are only applicable when matrices are of Float8 type and assumbed to be equal to 1.0 by default. // If output matrix type is 16 or 32-bit type, neither scale_result is applied nor amax is computed. // Known limitations: // - Only works if mat1 is row-major and mat2 is column-major // - Only works if matrices sizes are divisible by 32 -// - If 1-dimensional tensors are used then scale_a should be size = mat1.size(0) -// and scale_b should have size = to mat2.size(1) +// // Arguments: // - `mat1`: the first operand of the matrix multiply, can be type `torch.float8_e4m3fn` or `torch.float8_e5m2` // - `mat2`: the second operand of the matrix multiply, can be type `torch.float8_e4m3fn` or `torch.float8_e5m2` // - `bias`: the bias, can be type `torch.float16` or `torch.bfloat16` // - `out_dtype`: the output dtype, can either be a float8 or a higher precision floating point type -// - `scale_a`: a scalar or 1-dimensional tensor with the inverse scale of `mat1`, only needed if `mat1` is a float8 type -// - `scale_b`: a scalar or 1-dimensional tensor with the inverse scale of `mat2`, only needed if `mat2` is a float8 type -// - `scale_result`: a scalar tensor with the scale of the output, only utilized if the output is a float8 type +// - `scale_a`: a scalar tensor with the inverse scale of `mat1`, only needed if `mat1` is a float8 type +// - `scale_b`: a scalar tensor with the inverse scale of `mat2`, only needed if `mat2` is a float8 type +// - `scale_result`: a scalar tensor with the scale of the output, only set if the output is a float8 type // - `use_fast_accum`: if true, enables fast float8 accumulation // - `out`: a reference to the output tensor -// - `amax`: a reference to the amax tensor of the output, only mutated if the output is a float8 type and will be updated inplace +// - `amax`: a reference to the amax tensor of the output, only needed if the output is a float8 type and will be updated inplace std::tuple _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, @@ -933,11 +855,10 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, TORCH_CHECK( mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (", mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")"); - - // Check what type of scaling we are doing based on inputs - ScalingType scaling_choice = get_scaling_type(scale_a, scale_b, mat1.size(0), mat2.size(1)); - TORCH_INTERNAL_ASSERT(scaling_choice != ScalingType::Error, "Scaling type not supported"); - + TORCH_CHECK(!scale_a || (scale_a->numel() == 1 && scale_a->scalar_type() == kFloat), + "scale_a must be float scalar"); + TORCH_CHECK(!scale_b || (scale_b->numel() == 1 && scale_b->scalar_type() == kFloat), + "scale_b must be a float scalar"); TORCH_CHECK(!scale_result || (scale_result->numel() == 1 && scale_result->scalar_type() == kFloat), "scale_result must be a float scalar"); TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1], @@ -980,26 +901,12 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, {scale_result_, "scale_result", 7}}; checkAllSameGPU(__func__, targs); } - // Validation checks have passed lets resize the output to actual size + IntArrayRef mat1_sizes = mat1.sizes(); IntArrayRef mat2_sizes = mat2.sizes(); at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]}); at::native::resize_output(amax, {}); - // We are doing row-wise scaling - if (scaling_choice == ScalingType::RowWise) { - TORCH_CHECK(out.dtype() == kBFloat16, "Only bf16 high precsion output types are supported for row-wise scaling."); - at::cuda::detail::f8f8bf16_rowwise( - mat1, - mat2, - scale_a.value(), - scale_b.value(), - bias, - use_fast_accum, - out); - return {out, amax}; - } - cublasCommonArgs args(mat1, mat2, out); const auto out_dtype_ = args.result->scalar_type(); TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt"); diff --git a/aten/src/ATen/native/cuda/RowwiseScaledMM.cu b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu deleted file mode 100644 index 84655d281afc..000000000000 --- a/aten/src/ATen/native/cuda/RowwiseScaledMM.cu +++ /dev/null @@ -1,536 +0,0 @@ -#define TORCH_ASSERT_ONLY_METHOD_OPERATORS -#include -#include -#include -#include - -// Determine if the architecture supports rowwise scaled mm -// Currenlty failing on windows with: https://github.com/NVIDIA/cutlass/issues/1571 -#if !defined(USE_ROCM) && !defined(_WIN32) && defined(CUDA_VERSION) && CUDA_VERSION >= 12000 - -#define BUILD_ROWWISE_FP8_KERNEL -#endif - -#if defined(BUILD_ROWWISE_FP8_KERNEL) - -// We are going to override the cuTensorMapEncodeTiled driver api with our lazy loader -static CUresult CUDAAPI nvrtc_cuTensorMapEncodeTiled( - CUtensorMap* tensorMap, - CUtensorMapDataType tensorDataType, - cuuint32_t tensorRank, - void* globalAddress, - const cuuint64_t* globalDim, - const cuuint64_t* globalStrides, - const cuuint32_t* boxDim, - const cuuint32_t* elementStrides, - CUtensorMapInterleave interleave, - CUtensorMapSwizzle swizzle, - CUtensorMapL2promotion l2Promotion, - CUtensorMapFloatOOBfill oobFill) { - return at::globalContext().getNVRTC().cuTensorMapEncodeTiled( - tensorMap, - tensorDataType, - tensorRank, - globalAddress, - globalDim, - globalStrides, - boxDim, - elementStrides, - interleave, - swizzle, - l2Promotion, - oobFill); -} - - -#include -#include -#include -#include -#include -#include -#include - -// Rename the global function symbol -#define cuTensorMapEncodeTiled nvrtc_cuTensorMapEncodeTiled -#include -#undef cuTensorMapEncodeTiled -// Set everything back to normal - -#include -#include -#include - -#include -#include -#include -#include - - -namespace { -// Cutlass rowwise kernel -template < - int TB_M, - int TB_N, - int TB_K, - int TBS_M, - int TBS_N, - int TBS_K, - bool PONG, - bool FAST_ACCUM, - bool USE_BIAS, - typename INPUT_DTYPE, - typename BIAS_DTYPE> -void f8f8bf16_rowwise_impl( - at::Tensor XQ, // FP8 - at::Tensor WQ, // FP8 - at::Tensor x_scale, - at::Tensor w_scale, - c10::optional bias, - at::Tensor out) { - int M = XQ.size(0); - int N = WQ.size(1); - int K = XQ.size(1); - - TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous()); - TORCH_CHECK( - WQ.is_cuda() && WQ.ndimension() == 2 && WQ.stride(1) == WQ.size(0) && - WQ.stride(0) == 1); - - // auto Y = at::empty({M, N}, XQ.options().dtype(at::kBFloat16)); - - using ElementInputA = INPUT_DTYPE; - using LayoutInputA = cutlass::layout::RowMajor; - constexpr int AlignmentInputA = 16 / sizeof(ElementInputA); - - using ElementInputB = cutlass::float_e4m3_t; - using LayoutInputB = cutlass::layout::ColumnMajor; - constexpr int AlignmentInputB = 16 / sizeof(ElementInputB); - - using ElementBias = BIAS_DTYPE; - - using ElementOutput = cutlass::bfloat16_t; - using LayoutOutput = cutlass::layout::RowMajor; - constexpr int AlignmentOutput = 16 / sizeof(ElementOutput); - - using ElementAccumulator = float; - using ElementComputeEpilogue = float; - using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that - // supports the intended feature - using OperatorClass = cutlass::arch::OpClassTensorOp; - using TileShape = cute::Shape< - cute::Int, - cute::Int, - cute::Int>; // Threadblock-level - // tile size - using ClusterShape = cute::Shape< - cute::Int, - cute::Int, - cute::Int>; // Shape of the - // threadblocks in a - // cluster - using KernelSchedule = cutlass::gemm::collective:: - KernelScheduleAuto; // Kernel to launch based on the default setting in - // the Collective Builder - - // Implement rowwise scaling epilogue. - using XScale = cutlass::epilogue::fusion::Sm90ColBroadcast< - 0, - TileShape, - ElementComputeEpilogue, - cute::Stride, cute::Int<0>, cute::Int<0>>>; - - using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast< - PONG ? 2 : 1, - TileShape, - ElementComputeEpilogue, - cute::Stride, cute::Int<1>, cute::Int<0>>>; - - using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast< - PONG ? 2 : 1, - TileShape, - ElementBias, - cute::Stride, cute::Int<1>, cute::Int<0>>>; - - using Accum = cutlass::epilogue::fusion::Sm90AccFetch; - - using Compute0 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, - ElementComputeEpilogue, // First stage output type. - ElementComputeEpilogue, // First stage input types. - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTCompute0 = - cutlass::epilogue::fusion::Sm90EVT; - - using Compute1 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, - cute::conditional_t< // Second stage output type. - USE_BIAS, - ElementBias, - ElementOutput>, - ElementComputeEpilogue, // Second stage input types. - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTCompute1 = - cutlass::epilogue::fusion::Sm90EVT; - - using ComputeBias = cutlass::epilogue::fusion::Sm90Compute< - cutlass::plus, - ElementOutput, // Final (optional) stage output type. - ElementBias, // Final stage input types. - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeBias = - cutlass::epilogue::fusion::Sm90EVT; - - using EpilogueEVT = - cute::conditional_t; - - using CollectiveEpilogue = - typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm90, - cutlass::arch::OpClassTensorOp, - TileShape, - ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, - ElementAccumulator, - ElementComputeEpilogue, - ElementOutput, - LayoutOutput, - AlignmentOutput, - ElementOutput, - LayoutOutput, - AlignmentOutput, - cutlass::epilogue::TmaWarpSpecialized, - EpilogueEVT>::CollectiveOp; - - using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecialized; - using PongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; - using FastDefaultSchedule = - cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; - using FastPongSchedule = - cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; - using SlowAccum = cute::conditional_t; - using FastAccum = - cute::conditional_t; - using MainLoopSchedule = - cute::conditional_t; - - using CollectiveMainloop = - typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, - OperatorClass, - ElementInputA, - LayoutInputA, - AlignmentInputA, - ElementInputB, - LayoutInputB, - AlignmentInputB, - ElementAccumulator, - TileShape, - ClusterShape, - cutlass::gemm::collective::StageCountAutoCarveout( - sizeof(typename CollectiveEpilogue::SharedStorage))>, - MainLoopSchedule>::CollectiveOp; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - cute::Shape, - CollectiveMainloop, - CollectiveEpilogue>; - - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - - using StrideInputA = typename Gemm::GemmKernel::StrideA; - using StrideInputB = typename Gemm::GemmKernel::StrideB; - using StrideOutput = typename Gemm::GemmKernel::StrideC; - - StrideInputA stride_a = cutlass::make_cute_packed_stride( - StrideInputA{}, cute::make_shape(M, K, 1)); - StrideInputB stride_b = cutlass::make_cute_packed_stride( - StrideInputB{}, cute::make_shape(N, K, 1)); - StrideOutput stride_output = cutlass::make_cute_packed_stride( - StrideOutput{}, cute::make_shape(M, N, 1)); - - typename Gemm::Arguments arguments{ - cutlass::gemm::GemmUniversalMode::kGemm, - {M, N, K}, - {reinterpret_cast(XQ.data_ptr()), - stride_a, - reinterpret_cast(WQ.data_ptr()), - stride_b}, - {{}, // Epilogue thread we populate below. - (ElementOutput*)out.data_ptr(), - stride_output, - (ElementOutput*)out.data_ptr(), - stride_output}}; - - if constexpr (USE_BIAS) { - arguments.epilogue.thread = { - {reinterpret_cast(bias.value().data_ptr())}, // bias - // compute_1 - { - {reinterpret_cast( - x_scale.data_ptr())}, // x_scale - // compute_0 - { - {reinterpret_cast( - w_scale.data_ptr())}, // w_scale - {}, // Accumulator - {} // Multiplies - }, - {}, // Multiplies - }, - {}, // Plus - }; - } else { - arguments.epilogue.thread = { - {reinterpret_cast( - x_scale.data_ptr())}, // x_scale - // compute_0 - { - {reinterpret_cast( - w_scale.data_ptr())}, // w_scale - {}, // Accumulator - {} // Multiplies - }, - {}, // Multiplies - }; - } - - Gemm gemm; - - // Using the arguments, query for extra workspace required for matrix - // multiplication computation - size_t workspace_size = Gemm::get_workspace_size(arguments); - - // Allocate workspace memory - cutlass::device_memory::allocation workspace(workspace_size); - - // Check the problem size is supported or not - cutlass::Status status = gemm.can_implement(arguments); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot implement"); - } - - // Initialize CUTLASS kernel with arguments and workspace pointer - status = gemm.initialize(arguments, workspace.get()); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot initialize"); - } - - status = gemm(at::cuda::getCurrentCUDAStream()); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error( - std::string("cutlass cannot run") + - cutlass::cutlassGetStatusString(status)); - } - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -// FP8 Rowwise Cutlass kernel dispatch. -enum class KernelMode { Small, Large, Default }; - -KernelMode get_kernel_mode(at::Tensor XQ, at::Tensor WQ) { - auto M = XQ.size(0); - auto K = XQ.size(1); - auto N = WQ.size(0); - // Use a large kernel if at least two shapes are large.... - bool use_large_kernel = - ((M >= 2048 && K >= 2048) || (M >= 2048 && N >= 2048) || - (K >= 2048 && N >= 2048)); - if (M <= 128 || N <= 128) { - return KernelMode::Small; - } else if (use_large_kernel) { - return KernelMode::Large; - } else { - return KernelMode::Default; - } -} - -template -void dispatch_fp8_rowwise_kernel( - at::Tensor XQ, - at::Tensor WQ, - at::Tensor x_scale, - at::Tensor w_scale, - c10::optional bias, - at::Tensor out) { - KernelMode kernel = get_kernel_mode(XQ, WQ); - if (kernel == KernelMode::Small) { - return f8f8bf16_rowwise_impl< - 64, - 128, - 128, - 2, - 1, - 1, - false, - FastAccum, - UseBias, - InputDType, - BiasDType>(XQ, WQ, x_scale, w_scale, bias, out); - } else if (kernel == KernelMode::Large) { - return f8f8bf16_rowwise_impl< - 128, - 128, - 128, - 2, - 1, - 1, - true, - FastAccum, - UseBias, - InputDType, - BiasDType>(XQ, WQ, x_scale, w_scale, bias, out); - } else { - return f8f8bf16_rowwise_impl< - 128, - 128, - 128, - 1, - 2, - 1, - false, - FastAccum, - UseBias, - InputDType, - BiasDType>(XQ, WQ, x_scale, w_scale, bias, out); - } -} - -} // namespace - -#endif // !defined(USE_ROCM) - -namespace at::cuda::detail { -void f8f8bf16_rowwise( - at::Tensor XQ, // FP8 - at::Tensor WQ, // FP8 - at::Tensor x_scale, // FP32 - at::Tensor w_scale, // FP32 - c10::optional bias, // BF16 - bool use_fast_accum, - at::Tensor& out) { -#if defined(BUILD_ROWWISE_FP8_KERNEL) - // Check datatypes. - TORCH_CHECK( - x_scale.dtype() == at::kFloat && w_scale.dtype() == at::kFloat, - "Scale tensors must be float32."); - if (bias.has_value()) { - TORCH_CHECK( - bias.value().dtype() == at::kFloat || - bias.value().dtype() == at::kBFloat16, - "Bias type must be bfloat16 or float32 if provided."); - } - // Extract problem size. - int M = XQ.size(0); - int N = WQ.size(1); - int K = XQ.size(1); - - bool use_bias = bias.has_value(); - bool bf16_bias = use_bias && bias.value().dtype() == at::kBFloat16; - - // Templatize based on input dtype. - bool use_e5m2 = XQ.dtype() == at::kFloat8_e5m2; - TORCH_CHECK(WQ.dtype() == at::kFloat8_e4m3fn, "For row-wise scaling the second input is required to be a float8_e4m3fn dtype."); - - if (use_bias) { - if (bf16_bias) { - if (use_fast_accum) { - if (use_e5m2) { - return dispatch_fp8_rowwise_kernel< - cutlass::float_e5m2_t, - true, - true, - cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, out); - } else { - return dispatch_fp8_rowwise_kernel< - cutlass::float_e4m3_t, - true, - true, - cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, out); - } - } else { - if (use_e5m2) { - return dispatch_fp8_rowwise_kernel< - cutlass::float_e5m2_t, - false, - true, - cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, out); - } else { - return dispatch_fp8_rowwise_kernel< - cutlass::float_e4m3_t, - false, - true, - cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, out); - } - } - } else { - if (use_fast_accum) { - if (use_e5m2) { - return dispatch_fp8_rowwise_kernel< - cutlass::float_e5m2_t, - true, - true, - float>(XQ, WQ, x_scale, w_scale, bias, out); - } else { - return dispatch_fp8_rowwise_kernel< - cutlass::float_e4m3_t, - true, - true, - float>(XQ, WQ, x_scale, w_scale, bias, out); - } - } else { - if (use_e5m2) { - return dispatch_fp8_rowwise_kernel< - cutlass::float_e5m2_t, - false, - true, - float>(XQ, WQ, x_scale, w_scale, bias, out); - } else { - return dispatch_fp8_rowwise_kernel< - cutlass::float_e4m3_t, - false, - true, - float>(XQ, WQ, x_scale, w_scale, bias, out); - } - } - } - } else { - if (use_fast_accum) { - if (use_e5m2) { - return dispatch_fp8_rowwise_kernel< - cutlass::float_e5m2_t, - true, - false, - float>(XQ, WQ, x_scale, w_scale, bias, out); - } else { - return dispatch_fp8_rowwise_kernel< - cutlass::float_e4m3_t, - true, - false, - float>(XQ, WQ, x_scale, w_scale, bias, out); - } - } else { - if (use_e5m2) { - return dispatch_fp8_rowwise_kernel< - cutlass::float_e5m2_t, - false, - false, - float>(XQ, WQ, x_scale, w_scale, bias, out); - } else { - return dispatch_fp8_rowwise_kernel< - cutlass::float_e4m3_t, - false, - false, - float>(XQ, WQ, x_scale, w_scale, bias, out); - } - } - } -#else // BUILD_ROWWISE_FP8_KERNEL - TORCH_CHECK(false, "Rowwise scaling is not currenlty supported on your device"); -#endif -} - -} // namespace at::cuda::detail diff --git a/aten/src/ATen/native/cuda/RowwiseScaledMM.h b/aten/src/ATen/native/cuda/RowwiseScaledMM.h deleted file mode 100644 index 4d9054108c85..000000000000 --- a/aten/src/ATen/native/cuda/RowwiseScaledMM.h +++ /dev/null @@ -1,15 +0,0 @@ -#pragma once -#include -#include - - -namespace at::cuda::detail { -TORCH_API void f8f8bf16_rowwise( - at::Tensor XQ, // FP8 - at::Tensor WQ, // FP8 - at::Tensor x_scale, // FP32 - at::Tensor w_scale, // FP32 - c10::optional bias, // BF16 - bool use_fast_accum, - at::Tensor& out); -} // at::cuda::detail diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 74381567a552..a5c583580848 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -204,6 +204,7 @@ def _expand_to_batch(t: torch.Tensor): self.assertEqual(out1_gpu, out2_gpu[0]) + f8_msg = "FP8 is only supported on H100+ and sm_89 and MI300+ devices" if torch.version.hip: @@ -255,12 +256,8 @@ def amax_to_scale( scale.copy_(res) return scale -def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype, dim=None): - if dim is None: - amax = torch.max(torch.abs(x)) - else: - amax = torch.max(torch.abs(x), dim=dim).values - +def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype): + amax = torch.max(torch.abs(x)) return amax_to_scale(amax, float8_dtype, x.dtype) def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype): @@ -319,6 +316,7 @@ def mm_float8( def to_fp8_saturated( x: torch.Tensor, + x_scale: torch.tensor, fp8_dtype: torch.dtype ): """ @@ -341,6 +339,8 @@ def to_fp8_saturated( of a tensor has a maximum value of `amax1`, and the current amax value is `amax2`, where `amax1 < amax2`. """ + x_scaled = x * x_scale + if fp8_dtype == e4m3_type: x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS) elif fp8_dtype == e5m2_type: @@ -353,6 +353,8 @@ def to_fp8_saturated( @unittest.skipIf(not torch.cuda.is_available(), "CUDA not found") class TestFP8MatmulCuda(TestCase): + + @unittest.skipIf(not scaled_mm_supported_device(), f8_msg) def _test_tautological_mm(self, device: str = "cuda", x_dtype: torch.dtype = e4m3_type, @@ -416,8 +418,8 @@ def test_scaled_mm_vs_emulated(self, base_dtype): x_scale = tensor_to_scale(x, input_dtype).float() y_scale = tensor_to_scale(y, input_dtype).float() - x_fp8 = to_fp8_saturated(x * x_scale, e4m3_type) - y_fp8 = to_fp8_saturated(y * y_scale, e4m3_type) + x_fp8 = to_fp8_saturated(x, x_scale, e4m3_type) + y_fp8 = to_fp8_saturated(y, y_scale, e4m3_type) # Calculate actual F8 mm out_scaled_mm, output_amax_scaled = mm_float8( @@ -524,137 +526,6 @@ def test_float8_scale_fast_accum(self, device) -> None: out_fp8_s, amax_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, use_fast_accum=True) self.assertEqual(out_fp8, out_fp8_s) - @unittest.skipIf(not scaled_mm_supported_device() or IS_WINDOWS, f8_msg) - @skipIfRocm() - @parametrize("use_fast_accum", [True, False]) - def test_float8_rowwise_scaling_sanity(self, device, use_fast_accum: bool) -> None: - M, K, N = (1024, 512, 2048) - fill_value = 0.5 - x = torch.full((M, K), fill_value, device=device) - y = torch.full((N, K), fill_value, device=device) - - x_scales = torch.ones(x.shape[0], device=device, dtype=torch.float32) - y_scales = torch.ones(y.shape[0], device=device, dtype=torch.float32) - - x_fp8 = x.to(torch.float8_e4m3fn) - y_fp8 = y.to(torch.float8_e4m3fn).t() - - out_fp8, _ = torch._scaled_mm( - x_fp8, - y_fp8, - scale_a=x_scales, - scale_b=y_scales, - out_dtype=torch.bfloat16, - use_fast_accum=use_fast_accum, - ) - self.assertEqual( - out_fp8.to(torch.float32), torch.full((M, N), K * (fill_value**2), device=device) - ) - - @unittest.skipIf(not scaled_mm_supported_device() or IS_WINDOWS, f8_msg) - @skipIfRocm() - def test_float8_error_messages(self, device) -> None: - M, K, N = (1024, 512, 2048) - fill_value = 0.5 - x = torch.full((M, K), fill_value, device=device) - y = torch.full((N, K), fill_value, device=device) - - x_fp8 = x.to(torch.float8_e4m3fn) - y_fp8 = y.to(torch.float8_e4m3fn).t() - - with self.assertRaisesRegex( - RuntimeError, - "For row-wise scaling, scale_a must be size 1024 but got 1 and scale_b must be size 2048 but got 2", - ): - torch._scaled_mm( - x_fp8, - y_fp8, - scale_a=torch.ones((), device="cuda"), - scale_b=torch.ones((2), device="cuda"), - out_dtype=torch.bfloat16, - ) - - with self.assertRaisesRegex( - RuntimeError, - "For row-wise scaling, scale_b must have size 2048 but got 2049.", - ): - torch._scaled_mm( - x_fp8, - y_fp8, - scale_a=torch.ones((M), device="cuda"), - scale_b=torch.ones((N + 1), device="cuda"), - out_dtype=torch.bfloat16, - ) - with self.assertRaisesRegex( - RuntimeError, - "Both scale_a and scale_b must be 1-dimensional tensors", - ): - torch._scaled_mm( - x_fp8, - y_fp8, - scale_a=torch.ones((M), device="cuda"), - scale_b=torch.ones((N, N), device="cuda"), - out_dtype=torch.bfloat16, - ) - - with self.assertRaisesRegex( - RuntimeError, - "Both scale_a and scale_b must be contiguous.", - ): - torch._scaled_mm( - x_fp8, - y_fp8, - scale_a=torch.ones((M), device="cuda"), - scale_b=torch.ones((N * 2), device="cuda")[::2], - out_dtype=torch.bfloat16, - ) - - with self.assertRaisesRegex( - RuntimeError, - "For row-wise scaling the second input is required to be a float8_e4m3fn dtype.", - ): - torch._scaled_mm( - x_fp8, - y_fp8.to(torch.float8_e5m2), - scale_a=torch.ones((M), device="cuda"), - scale_b=torch.ones((N), device="cuda"), - out_dtype=torch.bfloat16, - ) - - @unittest.skipIf(not scaled_mm_supported_device() or IS_WINDOWS, f8_msg) - @skipIfRocm() - @parametrize("base_dtype", [torch.bfloat16]) - def test_scaled_mm_vs_emulated_row_wise(self, base_dtype): - torch.manual_seed(42) - input_dtype = e4m3_type - output_dtype = base_dtype - - x = torch.randn(16, 16, device="cuda", dtype=base_dtype) - y = torch.randn(32, 16, device="cuda", dtype=base_dtype).t() - - x_scales = tensor_to_scale(x, input_dtype, dim=1).float() - y_scales = tensor_to_scale(y, input_dtype, dim=0).float() - - x_fp8 = to_fp8_saturated(x * x_scales[:, None], e4m3_type) - y_fp8 = to_fp8_saturated(y * y_scales[None, :], e4m3_type) - - # Calculate actual F8 mm - out_scaled_mm, _ = mm_float8( - x_fp8, y_fp8, a_scale=x_scales, b_scale=y_scales, output_dtype=output_dtype - ) - - # Calculate emulated F8 mm - out_emulated, _ = mm_float8_emulated( - x_fp8, x_scales[:, None], y_fp8, y_scales[None, :], output_dtype - ) - - if base_dtype in {torch.bfloat16, torch.float16}: - atol, rtol = 7e-2, 7e-2 - else: - atol, rtol = 2e-3, 2e-3 - - torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol) - @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") @unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions") diff --git a/third_party/cutlass.BUILD b/third_party/cutlass.BUILD index e3e7b7b288e7..e712d59597cc 100644 --- a/third_party/cutlass.BUILD +++ b/third_party/cutlass.BUILD @@ -5,17 +5,7 @@ load("@rules_cc//cc:defs.bzl", "cc_library") cc_library( name = "cutlass", - hdrs = glob([ - "include/**/*.h", - "include/**/*.hpp", - "include/**/*.inl", - "tools/util/include/**/*.h", - "tools/util/include/**/*.hpp", - "tools/util/include/**/*.inl", - ]), - includes = [ - "include/", - "tools/util/include/", - ], + hdrs = glob(["include/**/*.h", "include/**/*.hpp"]), + includes = ["include/"], visibility = ["//visibility:public"], ) From 0de6d2427f42037e54ced973cabadaa85739ca55 Mon Sep 17 00:00:00 2001 From: eqy Date: Thu, 6 Jun 2024 16:17:43 +0000 Subject: [PATCH 413/706] Bump tolerances for `inductor/test_efficient_conv_bn_eval.py::EfficientConvBNEvalCudaTests::test_basic_cuda` attempt 2 (#128048) CC @nWEIdia @huydhn @Skylion007 Same thing but also bump backward tolerances... Pull Request resolved: https://github.com/pytorch/pytorch/pull/128048 Approved by: https://github.com/Skylion007 --- test/inductor/test_efficient_conv_bn_eval.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/inductor/test_efficient_conv_bn_eval.py b/test/inductor/test_efficient_conv_bn_eval.py index c65b7585f9f3..9def5230bdba 100644 --- a/test/inductor/test_efficient_conv_bn_eval.py +++ b/test/inductor/test_efficient_conv_bn_eval.py @@ -158,7 +158,7 @@ def test_conv_bn_eval( out_eager = mod_eager(inp) out_optimized = mod_optimized(inp) - self.assertEqual(out_optimized, out_eager, atol=2e-04, rtol=1e-5) + self.assertEqual(out_optimized, out_eager, atol=3e-04, rtol=1e-5) out_eager.mean().backward() out_optimized.mean().backward() @@ -170,7 +170,7 @@ def test_conv_bn_eval( out_eager_bw = mod_eager(inp_bw) out_optimized_bw = mod_optimized(inp_bw) - self.assertEqual(out_eager_bw, out_optimized_bw, atol=2e-04, rtol=1e-5) + self.assertEqual(out_eager_bw, out_optimized_bw, atol=3e-04, rtol=1e-5) current_value = counters["inductor"]["efficient_conv_bn_eval"] self.assertEqual( current_value - original_value, test_class.expected_optimization_count From f0dd11df5534ae074ad2d090e6700576a22719d6 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Thu, 6 Jun 2024 09:44:00 -0400 Subject: [PATCH 414/706] Make ValueRange repr less chatty by default (#128043) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/128043 Approved by: https://github.com/lezcano --- test/dynamo/test_misc.py | 12 ++++++------ torch/utils/_sympy/value_ranges.py | 3 +++ 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index dc2b9530f0dd..e173a4d7a69e 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -9309,7 +9309,7 @@ def test_shape_env_equal_create_symbolic_sizes_strides_storage_offset(self): > Left: {0: 0, 1: 1, 2: s1, 3: s0} > Right: {0: 0, 1: 1} ==> var_to_range: values don't match. - > Left: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)} + > Left: {s0: VR[2, 9223372036854775806], s1: VR[2, 9223372036854775806]} > Right: {} ==> var_to_sources: values don't match. > Left: {s0: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=, idx=0)], s1: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=, idx=1)]} @@ -9343,7 +9343,7 @@ def test_shape_env_equal_unbacked(self): > Left: 2 > Right: 0 ==> var_to_range: values don't match. - > Left: {u0: ValueRanges(lower=-9223372036854775808, upper=9223372036854775807, is_bool=False, is_int=True, is_float=False), u1: ValueRanges(lower=0, upper=1, is_bool=False, is_int=True, is_float=False), zuf0: ValueRanges(lower=-oo, upper=oo, is_bool=False, is_int=False, is_float=True)} + > Left: {u0: VR[-9223372036854775808, 9223372036854775807], u1: VR[0, 1], zuf0: VR[-oo, oo]} > Right: {} """, ) @@ -9420,8 +9420,8 @@ def test_shape_env_equal_evaluate_expr_replacement(self): > Left: {s0: 3} > Right: {} ==> var_to_range: values don't match. - > Left: {s0: ValueRanges(lower=3, upper=3, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)} - > Right: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)} + > Left: {s0: VR[3, 3], s1: VR[2, 9223372036854775806]} + > Right: {s0: VR[2, 9223372036854775806], s1: VR[2, 9223372036854775806]} """, ) self._replay_and_check(main) @@ -9458,8 +9458,8 @@ def test_shape_env_equal_evaluate_expr_refinement(self): > Left: {_assert, ge, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} > Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} ==> var_to_range: values don't match. - > Left: {s0: ValueRanges(lower=3, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)} - > Right: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)} + > Left: {s0: VR[3, 9223372036854775806], s1: VR[2, 9223372036854775806]} + > Right: {s0: VR[2, 9223372036854775806], s1: VR[2, 9223372036854775806]} """, ) self._replay_and_check(main) diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py index 4d364d4981b5..c7257f999b52 100644 --- a/torch/utils/_sympy/value_ranges.py +++ b/torch/utils/_sympy/value_ranges.py @@ -127,6 +127,9 @@ class ValueRanges(Generic[_T]): is_int: bool is_float: bool + def __repr__(self) -> str: + return f"VR[{self.lower}, {self.upper}]" + @overload def __init__(self: ValueRanges[sympy.Expr], lower: ExprIn, upper: ExprIn) -> None: ... From 4f87f47ea1143f45341cb32054ee3855d80d1fa9 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Thu, 6 Jun 2024 00:23:24 -0700 Subject: [PATCH 415/706] [dtensor] reuse DTensorSpec as much as possible (#128112) as titled, given that our DTensorSpec is immutable, we can always reuse the spec if the input/output have the same tensor metadata. this helps two fold: 1. We don't need to re-calculate the hash everytime we produce a DTensorSpec, reduce runtime operator overhead 2. reduce the DTensor construction overhead. Some local benchmark on a 800 parameter clip_grad_norm shows that for foreach_norm the CPU overhead reduces from 11ms -> 7.8ms (around 30% improvement) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128112 Approved by: https://github.com/awgu --- test/distributed/_tensor/test_dtensor.py | 50 +++++++----- .../_tensor/test_dtensor_compile.py | 29 ++++--- test/distributed/_tensor/test_utils.py | 23 ++++-- .../_composable/fsdp/_fsdp_common.py | 17 ++-- torch/distributed/_tensor/__init__.py | 22 ++++-- torch/distributed/_tensor/_dispatch.py | 11 +-- torch/distributed/_tensor/_redistribute.py | 24 +++--- torch/distributed/_tensor/api.py | 79 +++++++++++-------- torch/distributed/tensor/parallel/loss.py | 34 ++++---- 9 files changed, 171 insertions(+), 118 deletions(-) diff --git a/test/distributed/_tensor/test_dtensor.py b/test/distributed/_tensor/test_dtensor.py index 531245057e1f..e29eede07d87 100644 --- a/test/distributed/_tensor/test_dtensor.py +++ b/test/distributed/_tensor/test_dtensor.py @@ -14,7 +14,13 @@ init_device_mesh, ) from torch.distributed._tensor.debug import CommDebugMode -from torch.distributed._tensor.placement_types import Partial, Replicate, Shard +from torch.distributed._tensor.placement_types import ( + DTensorSpec, + Partial, + Replicate, + Shard, + TensorMeta, +) from torch.distributed.tensor.parallel import ( ColwiseParallel, parallelize_module, @@ -55,27 +61,29 @@ def test_dtensor_constructor(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) placements = [Shard(0)] local_tensor = torch.randn(3, 3, requires_grad=True) - dist_tensor_shape = torch.Size([self.world_size * 3, 3]) + + spec = DTensorSpec( + device_mesh, + tuple(placements), + tensor_meta=TensorMeta( + torch.Size([self.world_size * 3, 3]), + local_tensor.stride(), + local_tensor.dtype, + ), + ) + dist_tensor = DTensor( local_tensor, - device_mesh, - placements, - shape=dist_tensor_shape, - dtype=local_tensor.dtype, + spec, requires_grad=True, - stride=local_tensor.stride(), ) self.assertEqual(dist_tensor.size(), torch.Size((self.world_size * 3, 3))) with self.assertWarnsRegex(UserWarning, "To construct"): DTensor( local_tensor, - device_mesh, - placements, - shape=dist_tensor_shape, - dtype=local_tensor.dtype, + spec, requires_grad=False, - stride=local_tensor.stride(), ) @with_comms @@ -272,19 +280,23 @@ def test_from_local_negative_dim(self): def test_to_local(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) placements = (Shard(0),) - dist_tensor_shape = torch.Size([self.world_size * 3, 3]) local_tensor_with_grad = torch.randn( 3, 3, device=self.device_type, requires_grad=True ) - + dist_tensor_shape = torch.Size([self.world_size * 3, 3]) + spec = DTensorSpec( + mesh=device_mesh, + placements=placements, + tensor_meta=TensorMeta( + dist_tensor_shape, + local_tensor_with_grad.stride(), + local_tensor_with_grad.dtype, + ), + ) sharded_tensor = DTensor( local_tensor_with_grad, - device_mesh, - placements, - shape=dist_tensor_shape, - dtype=local_tensor_with_grad.dtype, + spec, requires_grad=True, - stride=local_tensor_with_grad.stride(), ) self.assertEqual(sharded_tensor.size(), dist_tensor_shape) self.assertEqual(sharded_tensor.to_local(), local_tensor_with_grad) diff --git a/test/distributed/_tensor/test_dtensor_compile.py b/test/distributed/_tensor/test_dtensor_compile.py index 325d18be79f3..0f097e07e92f 100644 --- a/test/distributed/_tensor/test_dtensor_compile.py +++ b/test/distributed/_tensor/test_dtensor_compile.py @@ -21,6 +21,7 @@ Replicate, Shard, ) +from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( checkpoint_wrapper, CheckpointImpl, @@ -193,41 +194,45 @@ def fn(x): def test_dtensor_constructor_w_graph_break(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + x = torch.randn(64, 32, requires_grad=True) + spec = DTensorSpec( + mesh, + (Replicate(), Shard(0)), + tensor_meta=TensorMeta( + shape=torch.Size([128, 32]), stride=(32, 1), dtype=x.dtype + ), + ) # test passing in DTensor as inputs/outputs and run some tensor computation def fn(x): print("graph break!") return DTensor( x, - mesh, - (Replicate(), Shard(0)), - shape=[128, 32], - dtype=x.dtype, + spec, requires_grad=x.requires_grad, - stride=[32, 1], ) - x = torch.randn(64, 32, requires_grad=True) out = fn(x) out2 = torch.compile(fn, backend="eager")(x) def test_dtensor_constructor_w_dynamo_disable(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + x = torch.randn(32, requires_grad=True) + spec = DTensorSpec( + mesh, + (Replicate(),), + tensor_meta=TensorMeta(shape=torch.Size([32]), stride=(1,), dtype=x.dtype), + ) @torch._dynamo.disable(recursive=False) def fn(x): print("foo") return DTensor( x, - mesh, - (Replicate(),), - shape=torch.Size([32]), - dtype=x.dtype, + spec, requires_grad=x.requires_grad, - stride=(1,), ) - x = torch.randn(32, requires_grad=True) out = fn(x) out2 = torch.compile(fn, backend="eager")(x) self.assertEqual(out, out2) diff --git a/test/distributed/_tensor/test_utils.py b/test/distributed/_tensor/test_utils.py index 3d6608a491ec..467b5e092306 100644 --- a/test/distributed/_tensor/test_utils.py +++ b/test/distributed/_tensor/test_utils.py @@ -10,7 +10,12 @@ ) from torch.distributed._tensor.debug import CommDebugMode -from torch.distributed._tensor.placement_types import Replicate, Shard +from torch.distributed._tensor.placement_types import ( + DTensorSpec, + Replicate, + Shard, + TensorMeta, +) from torch.distributed.device_mesh import DeviceMesh, init_device_mesh from torch.testing._internal.common_utils import run_tests @@ -185,14 +190,20 @@ def test_fsdp2_tp_2d_dtensor_local_shards_and_offsets(self): chunks = list(torch.chunk(dtensor_tp.to_local(), 2, dim=0)) shard_rank = 0 if self.rank // 2 == 0 else 1 sharded_param = chunks[shard_rank] + spec_2d = DTensorSpec( + mesh=mesh_2d, + placements=(Shard(0), Shard(0)), + tensor_meta=TensorMeta( + global_tensor.size(), + global_tensor.stride(), + global_tensor.dtype, + ), + ) + dtensor_2d = DTensor( sharded_param, - mesh_2d, - [Shard(0), Shard(0)], - shape=global_tensor.size(), - dtype=global_tensor.dtype, + spec_2d, requires_grad=False, - stride=global_tensor.stride(), ) self.assertEqual( diff --git a/torch/distributed/_composable/fsdp/_fsdp_common.py b/torch/distributed/_composable/fsdp/_fsdp_common.py index 1395e3487847..e7654964144b 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_common.py +++ b/torch/distributed/_composable/fsdp/_fsdp_common.py @@ -10,6 +10,7 @@ import torch.nn as nn from torch.distributed._composable.contract import _get_registry from torch.distributed._tensor import DeviceMesh, DTensor, Placement +from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta @dataclass @@ -120,17 +121,23 @@ def _from_local_no_grad( This method is similar to ``DTensor.from_local()`` except that in eager mode it avoids some CPU overhead by avoiding default args and not being differentiable. """ + if not torch._dynamo.compiled_autograd.compiled_autograd_enabled: + spec = DTensorSpec( + device_mesh, + placements, + tensor_meta=TensorMeta( + global_size, + global_stride, + local_tensor.dtype, + ), + ) return DTensor( # Use the local tensor directly instead of constructing a new tensor # variable, e.g. with `view_as()`, since this is not differentiable local_tensor, - device_mesh, - placements, - shape=global_size, - dtype=local_tensor.dtype, + spec, requires_grad=local_tensor.requires_grad, - stride=global_stride, ) else: return DTensor.from_local( diff --git a/torch/distributed/_tensor/__init__.py b/torch/distributed/_tensor/__init__.py index 6ab35e10a69f..de01187f2512 100644 --- a/torch/distributed/_tensor/__init__.py +++ b/torch/distributed/_tensor/__init__.py @@ -52,6 +52,8 @@ def _dtensor_init_helper( placements=None, **kwargs, ) -> DTensor: + from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta + # if device_mesh is None, use the one from mesh resources device_mesh = device_mesh or _mesh_resources.get_current_mesh() kwargs["device"] = device_mesh.device_type @@ -77,8 +79,6 @@ def _dtensor_init_helper( # this tensor meta is not used except `shape` dtype = kwargs.get("dtype", torch.get_default_dtype()) - from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta - tensor_meta = TensorMeta(size, (0,), dtype) spec = DTensorSpec(device_mesh, placements, tensor_meta=tensor_meta) @@ -91,13 +91,19 @@ def _dtensor_init_helper( else: local_tensor = init_op(local_shape, **kwargs) + spec = DTensorSpec( + device_mesh, + tuple(placements), + tensor_meta=TensorMeta( + size, + torch_stride, + local_tensor.dtype, + ), + ) + return DTensor( - local_tensor=local_tensor, - device_mesh=device_mesh, - placements=tuple(placements), - shape=size, - dtype=local_tensor.dtype, - stride=torch_stride, + local_tensor, + spec, requires_grad=kwargs["requires_grad"], ) diff --git a/torch/distributed/_tensor/_dispatch.py b/torch/distributed/_tensor/_dispatch.py index 17f565b6d776..1739243a5d3b 100644 --- a/torch/distributed/_tensor/_dispatch.py +++ b/torch/distributed/_tensor/_dispatch.py @@ -395,16 +395,7 @@ def wrap(res: object, spec: OutputSpecType) -> object: assert isinstance( spec, DTensorSpec ), f"output spec does not match with output! Expected DTensorSpec, got {spec}." - assert spec.tensor_meta is not None - return dtensor.DTensor( - res, - spec.mesh, - spec.placements, - shape=spec.tensor_meta.shape, - dtype=spec.tensor_meta.dtype, - requires_grad=res.requires_grad, - stride=spec.tensor_meta.stride, - ) + return dtensor.DTensor(res, spec, requires_grad=res.requires_grad) else: # if output does not have a DTensorSpec due to specific ops, it must be a scalar tensor assert res.ndim == 0, "output tensor should be scalar!" diff --git a/torch/distributed/_tensor/_redistribute.py b/torch/distributed/_tensor/_redistribute.py index b72db29157f8..c8e54a98b927 100644 --- a/torch/distributed/_tensor/_redistribute.py +++ b/torch/distributed/_tensor/_redistribute.py @@ -12,6 +12,7 @@ Placement, Replicate, Shard, + TensorMeta, ) @@ -283,15 +284,12 @@ def forward( # type: ignore[override] else: # use the same local tensor if placements are the same. output = input._local_tensor + target_spec = current_spec return dtensor.DTensor( output, - device_mesh, - placements, - shape=input.shape, - dtype=input.dtype, + target_spec, requires_grad=input.requires_grad, - stride=input.stride(), ) @staticmethod @@ -316,14 +314,20 @@ def backward(ctx, grad_output: "dtensor.DTensor"): # type: ignore[override] normalized_placements.append(Replicate()) else: normalized_placements.append(previous_placement) + + spec = DTensorSpec( + previous_spec.device_mesh, + tuple(normalized_placements), + tensor_meta=TensorMeta( + shape=grad_output.shape, + stride=grad_output.stride(), + dtype=grad_output.dtype, + ), + ) output_dtensor = dtensor.DTensor( output, - previous_spec.mesh, - tuple(normalized_placements), - shape=grad_output.shape, - dtype=grad_output.dtype, + spec, requires_grad=grad_output.requires_grad, - stride=grad_output.stride(), ) return ( diff --git a/torch/distributed/_tensor/api.py b/torch/distributed/_tensor/api.py index 0a3f89af3c20..49fe7267c634 100644 --- a/torch/distributed/_tensor/api.py +++ b/torch/distributed/_tensor/api.py @@ -86,16 +86,21 @@ def backward(ctx, grad_output: torch.Tensor): # type: ignore[override] ) tensor_stride = tuple(tensor_stride) grad_placements = grad_placements or dtensor_spec.placements + grad_spec = DTensorSpec( + mesh, + grad_placements, + tensor_meta=TensorMeta( + shape=dtensor_meta.shape, + stride=tensor_stride, + dtype=dtensor_meta.dtype, + ), + ) return ( DTensor( grad_output, - mesh, - grad_placements, - shape=dtensor_meta.shape, - dtype=dtensor_meta.dtype, + grad_spec, requires_grad=grad_output.requires_grad, - stride=tensor_stride, ), None, ) @@ -146,17 +151,23 @@ def forward( # type: ignore[override] input = input.contiguous() mesh_broadcast(input, device_mesh, mesh_dim=idx) + dist_spec = DTensorSpec( + device_mesh, + placements, + tensor_meta=TensorMeta( + tensor_shape, + tensor_stride, + input.dtype, + ), + ) + # We want a fresh Tensor object that shares memory with the input tensor dist_tensor = DTensor( input.view_as(input), - device_mesh, - placements, - shape=tensor_shape, - dtype=input.dtype, + dist_spec, # requires_grad of the dist tensor depends on if input # requires_grad or not requires_grad=input.requires_grad, - stride=tensor_stride, ) return dist_tensor @@ -202,13 +213,9 @@ class DTensor(torch.Tensor): # pyre-ignore[13]: pyre is bad at __new__ def __new__( cls, local_tensor: torch.Tensor, - device_mesh: DeviceMesh, - placements: Tuple[Placement, ...], + spec: DTensorSpec, *, - shape: torch.Size, - dtype: torch.dtype, requires_grad: bool, - stride: Tuple[int, ...], ) -> "DTensor": """ Construct a DTensor from a local tensor, device mesh, and placement and @@ -228,19 +235,18 @@ def __new__( # new method instruct wrapper tensor from local_tensor and add # placement spec, it does not do actual distribution + assert spec.tensor_meta is not None, "TensorMeta should not be None!" r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] cls, - shape, - strides=stride, - dtype=dtype, + spec.tensor_meta.shape, + strides=spec.tensor_meta.stride, + dtype=spec.tensor_meta.dtype, device=local_tensor.device, layout=local_tensor.layout, requires_grad=requires_grad, ) - tensor_meta = TensorMeta(shape, stride, dtype) - # deepcopy and set spec - r._spec = DTensorSpec(device_mesh, placements, tensor_meta=tensor_meta) + r._spec = spec r._local_tensor = local_tensor return r @@ -264,14 +270,20 @@ def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): ), "Expecting spec to be not None from `__tensor_flatten__` return value!" local_tensor = inner_tensors["_local_tensor"] spec, requires_grad = flatten_spec - return DTensor( - local_tensor, - spec.mesh, - spec.placements, + unflatten_tensor_meta = TensorMeta( shape=outer_size, + stride=outer_stride, dtype=spec.tensor_meta.dtype, + ) + unflatten_spec = DTensorSpec( + spec.mesh, + spec.placements, + tensor_meta=unflatten_tensor_meta, + ) + return DTensor( + local_tensor, + unflatten_spec, requires_grad=requires_grad, - stride=outer_stride, ) def __coerce_tangent_metadata__(self): @@ -638,14 +650,19 @@ def distribute_tensor( assert local_tensor is not None, "distributing a tensor should not be None" # detach the local tensor passed to DTensor since after the construction # of DTensor, autograd would work on top of DTensor instead of local tensor + spec = DTensorSpec( + mesh=device_mesh, + placements=placements, + tensor_meta=TensorMeta( + shape=tensor.size(), + stride=tensor.stride(), + dtype=tensor.dtype, + ), + ) return DTensor( local_tensor.requires_grad_(tensor.requires_grad), - device_mesh, - placements, - shape=tensor.size(), - dtype=tensor.dtype, + spec, requires_grad=tensor.requires_grad, - stride=tensor.stride(), ) diff --git a/torch/distributed/tensor/parallel/loss.py b/torch/distributed/tensor/parallel/loss.py index f7144a38e923..8e7b7de84e1e 100644 --- a/torch/distributed/tensor/parallel/loss.py +++ b/torch/distributed/tensor/parallel/loss.py @@ -14,7 +14,7 @@ Reduction, replicate_reduction_dims, ) -from torch.distributed._tensor.placement_types import Placement, TensorMeta +from torch.distributed._tensor.placement_types import DTensorSpec, Placement, TensorMeta from torch.distributed.device_mesh import DeviceMesh aten = torch.ops.aten @@ -164,14 +164,16 @@ def _log_softmax_handler( res = _log_softmax(x._local_tensor, dim, half_to_float, spec.mesh, mesh_dim) - return DTensor( - res, + res_spec = DTensorSpec( spec.mesh, spec.placements, - shape=output_tensor_meta.shape, - dtype=output_tensor_meta.dtype, + tensor_meta=output_tensor_meta, + ) + + return DTensor( + res, + res_spec, requires_grad=res.requires_grad, - stride=output_tensor_meta.stride, ) @@ -317,16 +319,13 @@ def _nll_loss_forward_handler( spec.mesh, mesh_dim, ) + out_spec = DTensorSpec(spec.mesh, output_placements, tensor_meta=output_tensor_meta) return ( DTensor( result, - spec.mesh, - output_placements, - shape=output_tensor_meta.shape, - dtype=output_tensor_meta.dtype, + out_spec, requires_grad=result.requires_grad, - stride=output_tensor_meta.stride, ), total_weight, ) @@ -452,16 +451,17 @@ def _nll_loss_backward_handler( spec.mesh, mesh_dim, ) + # the output sharding is the same as input sharding: Shard(channel_dim) on mesh_dim + out_spec = DTensorSpec( + spec.mesh, + spec.placements, + tensor_meta=output_tensor_meta, + ) return DTensor( result, - spec.mesh, - # the output sharding is the same as input sharding: Shard(channel_dim) on mesh_dim - spec.placements, - shape=output_tensor_meta.shape, - dtype=output_tensor_meta.dtype, + out_spec, requires_grad=result.requires_grad, - stride=output_tensor_meta.stride, ) From 2d47385f0f036ac5db78630efc4d235fdde1c024 Mon Sep 17 00:00:00 2001 From: Aaron Gokaslan Date: Thu, 6 Jun 2024 16:55:56 +0000 Subject: [PATCH 416/706] [BE]: Enable ruff TCH rules and autofixes for better imports (#127688) Automated fixes to put imports that are only used in type hints into TYPE_CHECKING imports. This also enables the RUFF TCH rules which will automatically apply autofixes to move imports in and out of TYPE_CHECKING blocks as needed in the future, this will make the initial PyTorch import faster and will reduce cyclic dependencies. Co-authored-by: Xuehai Pan Pull Request resolved: https://github.com/pytorch/pytorch/pull/127688 Approved by: https://github.com/XuehaiPan, https://github.com/ezyang, https://github.com/malfet --- pyproject.toml | 5 +++++ test/conftest.py | 6 ++++-- test/distributed/_spmd/test_data_parallel.py | 2 +- test/distributed/_spmd/test_graph_utils.py | 2 +- test/distributed/_spmd/test_tracing.py | 2 +- test/distributed/checkpoint/test_traverse.py | 5 ++++- test/higher_order_ops/test_with_effects.py | 2 +- test/onnx/internal/test_diagnostics.py | 4 +++- test/onnx/test_fx_op_consistency.py | 2 +- test/onnx/test_fx_to_onnx.py | 2 +- test/test_autograd.py | 2 +- torch/__init__.py | 2 +- torch/_dynamo/decorators.py | 1 + .../_functorch/_aot_autograd/autograd_cache.py | 5 ++++- torch/_functorch/partitioners.py | 7 ++++--- torch/_inductor/async_compile.py | 6 ++++-- torch/_inductor/autotune_process.py | 2 +- torch/_inductor/codecache.py | 5 +++-- torch/_inductor/codegen/simd.py | 3 --- torch/_inductor/codegen/triton.py | 17 +++++++++++++++-- torch/distributed/_cuda_p2p/__init__.py | 6 ++++-- .../distributed/pipelining/PipelineSchedule.py | 16 ++++++++++++++-- torch/masked/maskedtensor/_ops_refs.py | 10 +++++----- torch/nn/parallel/replicate.py | 8 ++++---- torch/onnx/_internal/fx/type_utils.py | 2 +- 25 files changed, 84 insertions(+), 40 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index aa532c59da3c..24a917b80847 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -132,6 +132,7 @@ select = [ "RUF016", # type error non-integer index "RUF017", "RUF018", # no assignment in assert + "TCH", "TRY002", # ban vanilla raise (todo fix NOQAs) "TRY302", "TRY401", # verbose-log-message @@ -175,6 +176,10 @@ select = [ # autogenerated #TODO figure out why file level noqa is ignored "torch/_inductor/fx_passes/serialized_patterns/**" = ["F401", "F501"] "torch/onnx/**" = [ + "TCH001", # beartype may need runtime types + "TCH002", + "TCH003", + "TCH004", "UP037", # ONNX does runtime type checking ] diff --git a/test/conftest.py b/test/conftest.py index 9ba728689285..5b84898df8a3 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -7,10 +7,9 @@ import xml.etree.ElementTree as ET from collections import defaultdict from types import MethodType -from typing import Any, List, Optional, Union +from typing import Any, List, Optional, TYPE_CHECKING, Union import pytest -from _pytest._code.code import ReprFileLocation from _pytest.config import Config, filename_arg from _pytest.config.argparsing import Parser from _pytest.junitxml import _NodeReporter, bin_xml_escape, LogXML @@ -20,6 +19,9 @@ from _pytest.terminal import _get_raw_skip_reason from pytest_shard_custom import pytest_addoptions as shard_addoptions, PytestShardPlugin +if TYPE_CHECKING: + from _pytest._code.code import ReprFileLocation + # a lot of this file is copied from _pytest.junitxml and modified to get rerun info xml_key = StashKey["LogXMLReruns"]() diff --git a/test/distributed/_spmd/test_data_parallel.py b/test/distributed/_spmd/test_data_parallel.py index 4940320c0724..140ed54c037c 100644 --- a/test/distributed/_spmd/test_data_parallel.py +++ b/test/distributed/_spmd/test_data_parallel.py @@ -12,7 +12,7 @@ from torch.distributed._tensor import Replicate from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing._internal.common_distributed import skip_if_lt_x_gpu -from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.common_utils import run_tests # noqa: TCH001 from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, with_comms, diff --git a/test/distributed/_spmd/test_graph_utils.py b/test/distributed/_spmd/test_graph_utils.py index 2c90159237c7..2545678e0f15 100644 --- a/test/distributed/_spmd/test_graph_utils.py +++ b/test/distributed/_spmd/test_graph_utils.py @@ -2,7 +2,7 @@ import os from torch.distributed._spmd.graph_utils import dump_graphs_to_files -from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.common_utils import run_tests # noqa: TCH001 from torch.testing._internal.distributed._tensor.common_dtensor import DTensorTestBase diff --git a/test/distributed/_spmd/test_tracing.py b/test/distributed/_spmd/test_tracing.py index b77a87a7f44d..20ad2a6e06f9 100644 --- a/test/distributed/_spmd/test_tracing.py +++ b/test/distributed/_spmd/test_tracing.py @@ -20,7 +20,7 @@ from torch.nn import functional as F from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing._internal.common_distributed import skip_if_lt_x_gpu -from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.common_utils import run_tests # noqa: TCH001 from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, with_comms as base_with_comms, diff --git a/test/distributed/checkpoint/test_traverse.py b/test/distributed/checkpoint/test_traverse.py index 22ab029a612f..95e77a5662ee 100644 --- a/test/distributed/checkpoint/test_traverse.py +++ b/test/distributed/checkpoint/test_traverse.py @@ -1,13 +1,16 @@ # Owner(s): ["oncall: distributed"] from collections import OrderedDict +from typing import TYPE_CHECKING import torch import torch.distributed.checkpoint._traverse as _traverse -from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE from torch.testing._internal.common_utils import run_tests, TestCase +if TYPE_CHECKING: + from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE + # TODO: add comments for TestTraverse class TestTraverse(TestCase): diff --git a/test/higher_order_ops/test_with_effects.py b/test/higher_order_ops/test_with_effects.py index ea53b57e8209..cd8b80e8e886 100644 --- a/test/higher_order_ops/test_with_effects.py +++ b/test/higher_order_ops/test_with_effects.py @@ -25,7 +25,7 @@ ) from torch.testing._internal.torchbind_impls import init_torchbind_implementations -from torch.utils.hooks import RemovableHandle +from torch.utils.hooks import RemovableHandle # noqa: TCH001 @unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't support") diff --git a/test/onnx/internal/test_diagnostics.py b/test/onnx/internal/test_diagnostics.py index 223ff04606db..3c5526de53f9 100644 --- a/test/onnx/internal/test_diagnostics.py +++ b/test/onnx/internal/test_diagnostics.py @@ -6,7 +6,6 @@ import io import logging import typing -import unittest from typing import AbstractSet, Protocol, Tuple import torch @@ -17,6 +16,9 @@ from torch.onnx._internal.fx import diagnostics as fx_diagnostics from torch.testing._internal import common_utils, logging_utils +if typing.TYPE_CHECKING: + import unittest + class _SarifLogBuilder(Protocol): def sarif_log(self) -> sarif.SarifLog: diff --git a/test/onnx/test_fx_op_consistency.py b/test/onnx/test_fx_op_consistency.py index 4a4171699e65..e72c4206d578 100644 --- a/test/onnx/test_fx_op_consistency.py +++ b/test/onnx/test_fx_op_consistency.py @@ -65,7 +65,7 @@ common_methods_invocations, common_utils, ) -from torch.testing._internal.opinfo import core as opinfo_core +from torch.testing._internal.opinfo import core as opinfo_core # noqa: TCH001 # NOTE: For ATen signature modifications that will break ONNX export, diff --git a/test/onnx/test_fx_to_onnx.py b/test/onnx/test_fx_to_onnx.py index 6369ff3872d4..61cb9e807f70 100644 --- a/test/onnx/test_fx_to_onnx.py +++ b/test/onnx/test_fx_to_onnx.py @@ -17,7 +17,7 @@ from torch._subclasses import fake_tensor from torch.nn import functional as F from torch.onnx import dynamo_export, ExportOptions -from torch.onnx._internal.diagnostics import infra +from torch.onnx._internal.diagnostics import infra # noqa: TCH001 from torch.onnx._internal.fx import diagnostics, registration from torch.testing._internal import common_utils diff --git a/test/test_autograd.py b/test/test_autograd.py index 911762024930..ecd267dc9f77 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -81,7 +81,7 @@ from torch.utils._python_dispatch import TorchDispatchMode from torch.utils.checkpoint import checkpoint, checkpoint_sequential from torch.utils.cpp_extension import load_inline -from torch.utils.hooks import RemovableHandle +from torch.utils.hooks import RemovableHandle # noqa: TCH001 def graph_desc(fn): diff --git a/torch/__init__.py b/torch/__init__.py index 2efc457f33ba..896a2c50c36d 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -291,7 +291,7 @@ def load_shared_libraries(library_path): # Appease the type checker; ordinarily this binding is inserted by the # torch._C module initialization code in C if TYPE_CHECKING: - from . import _C as _C + from . import _C as _C # noqa: TCH004 class SymInt: """ diff --git a/torch/_dynamo/decorators.py b/torch/_dynamo/decorators.py index 201dbd2f1453..87fdc6502436 100644 --- a/torch/_dynamo/decorators.py +++ b/torch/_dynamo/decorators.py @@ -1,3 +1,4 @@ +# ruff: noqa: TCH004 from dataclasses import dataclass from typing import TYPE_CHECKING diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py index 057aff8467c5..dd3ec09408aa 100644 --- a/torch/_functorch/_aot_autograd/autograd_cache.py +++ b/torch/_functorch/_aot_autograd/autograd_cache.py @@ -6,6 +6,7 @@ import functools import logging import os +from typing import TYPE_CHECKING import torch from torch._functorch import config @@ -16,10 +17,12 @@ FxGraphHashDetails, get_code_hash, ) -from torch.fx.node import Node from .schemas import AOTConfig # noqa: F401 +if TYPE_CHECKING: + from torch.fx.node import Node + log = logging.getLogger(__name__) diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index 0956ee7e367c..cbfb4ca17168 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -8,9 +8,7 @@ import os from collections import defaultdict from dataclasses import dataclass, replace -from typing import Callable, Dict, List, Optional, Set, Tuple, Union - -import sympy +from typing import Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union import torch import torch._inductor.inductor_prims @@ -29,6 +27,9 @@ from . import config from .compile_utils import fx_graph_cse, get_aten_target +if TYPE_CHECKING: + import sympy + AOT_PARTITIONER_DEBUG = config.debug_partitioner log = logging.getLogger(__name__) diff --git a/torch/_inductor/async_compile.py b/torch/_inductor/async_compile.py index c163df9bd878..633946bb4ed8 100644 --- a/torch/_inductor/async_compile.py +++ b/torch/_inductor/async_compile.py @@ -8,7 +8,7 @@ from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor from functools import partial from time import time -from typing import Any, Callable, Dict, List, Optional, Set +from typing import Any, Callable, Dict, List, Optional, Set, TYPE_CHECKING import torch from torch._dynamo.device_interface import get_registered_device_interfaces @@ -34,10 +34,12 @@ _set_triton_ptxas_path, _worker_compile_triton, ) -from torch._inductor.runtime.hints import HalideMeta from torch.hub import _Faketqdm, tqdm +if TYPE_CHECKING: + from torch._inductor.runtime.hints import HalideMeta + # timing metrics for time spent in the compilation _cumulative_compile_time = 0.0 _t0: Optional[float] = None diff --git a/torch/_inductor/autotune_process.py b/torch/_inductor/autotune_process.py index c9462d788e8d..e2503e1d8ca2 100644 --- a/torch/_inductor/autotune_process.py +++ b/torch/_inductor/autotune_process.py @@ -11,7 +11,6 @@ import warnings from concurrent.futures import ThreadPoolExecutor from ctypes import byref, c_size_t, c_void_p, CDLL -from types import ModuleType from typing import ( Any, Callable, @@ -41,6 +40,7 @@ if TYPE_CHECKING: from multiprocessing.process import BaseProcess from multiprocessing.queues import Queue + from types import ModuleType from torch._inductor.select_algorithm import TritonTemplateCaller diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 6e70adf1758b..a421432125b6 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -25,7 +25,6 @@ import threading import warnings from bisect import bisect_right -from concurrent.futures import Future from copy import copy from ctypes import c_void_p, cdll, CDLL from functools import partial @@ -56,7 +55,6 @@ _reload_python_module, _reload_python_module_in_subproc, ) -from torch._inductor.runtime.hints import HalideMeta from torch._inductor.runtime.runtime_utils import cache_dir from torch._inductor.utils import clear_on_fresh_inductor_cache, is_linux @@ -69,8 +67,11 @@ from torch.fx.experimental.symbolic_shapes import has_hint, hint_int, ShapeEnv if TYPE_CHECKING: + from concurrent.futures import Future + from torch._inductor.graph import GraphLowering from torch._inductor.ir import ChoiceCaller + from torch._inductor.runtime.hints import HalideMeta _HERE = os.path.abspath(__file__) diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index 9140b1887f7f..ed7261f2a3eb 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -20,7 +20,6 @@ Sequence, Set, Tuple, - TYPE_CHECKING, Union, ) @@ -54,8 +53,6 @@ from .common import CSEVariable, index_prevent_reordering, Kernel, PythonPrinter from .multi_kernel import MultiKernel -if TYPE_CHECKING: - pass log = logging.getLogger(__name__) perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints") diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index f74086615c66..104d24585de2 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -7,7 +7,18 @@ import os import textwrap from functools import lru_cache -from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Union +from typing import ( + Any, + Callable, + cast, + Dict, + List, + Optional, + Set, + Tuple, + TYPE_CHECKING, + Union, +) import sympy @@ -23,7 +34,6 @@ from .. import config, ir from ..codecache import code_hash, get_path, PyCodeCache -from ..ir import IRNode from ..metrics import is_metric_table_enabled, log_kernel_metadata from ..runtime.hints import ReductionHint, TRITON_MAX_BLOCK from ..runtime.runtime_utils import do_bench_gpu, get_max_y_grid, next_power_of_2 @@ -52,6 +62,9 @@ from .simd import constant_repr, IterationRangesEntry, pexpr, SIMDKernel, SIMDScheduling from .triton_utils import config_of, signature_of, signature_to_meta +if TYPE_CHECKING: + from ..ir import IRNode + log = logging.getLogger(__name__) perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints") schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") diff --git a/torch/distributed/_cuda_p2p/__init__.py b/torch/distributed/_cuda_p2p/__init__.py index 2c77bd375f34..84fda06265d9 100644 --- a/torch/distributed/_cuda_p2p/__init__.py +++ b/torch/distributed/_cuda_p2p/__init__.py @@ -2,13 +2,15 @@ from contextlib import contextmanager from functools import partial -from typing import Callable, cast, Dict, List, Optional, Tuple, Union +from typing import Callable, cast, Dict, List, Optional, Tuple, TYPE_CHECKING, Union import torch import torch.distributed._functional_collectives as funcol import torch.distributed.distributed_c10d as c10d -from torch._C._distributed_c10d import _DistributedBackendOptions, Backend + +if TYPE_CHECKING: + from torch._C._distributed_c10d import _DistributedBackendOptions, Backend """ diff --git a/torch/distributed/pipelining/PipelineSchedule.py b/torch/distributed/pipelining/PipelineSchedule.py index 5c04ac824e69..8d696a5aa2b9 100644 --- a/torch/distributed/pipelining/PipelineSchedule.py +++ b/torch/distributed/pipelining/PipelineSchedule.py @@ -4,16 +4,28 @@ from abc import ABC, abstractmethod from collections import defaultdict from enum import Enum -from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + List, + NamedTuple, + Optional, + Tuple, + TYPE_CHECKING, + Union, +) import torch import torch.distributed as dist from torch.profiler import record_function -from ._IR import Pipe from .microbatch import merge_chunks, split_args_kwargs_into_chunks from .PipelineStage import _PipelineStageBase +if TYPE_CHECKING: + from ._IR import Pipe + __all__ = [ "PipelineScheduleSingle", diff --git a/torch/masked/maskedtensor/_ops_refs.py b/torch/masked/maskedtensor/_ops_refs.py index 69c947bc262d..7544fc84ff9f 100644 --- a/torch/masked/maskedtensor/_ops_refs.py +++ b/torch/masked/maskedtensor/_ops_refs.py @@ -4,10 +4,6 @@ from typing import Any, Callable, Dict, TYPE_CHECKING import torch - -if TYPE_CHECKING: - import torch._ops - from .binary import _apply_native_binary, NATIVE_BINARY_FNS, NATIVE_INPLACE_BINARY_FNS from .core import ( _get_data, @@ -26,6 +22,10 @@ from .unary import _apply_native_unary, NATIVE_INPLACE_UNARY_FNS, NATIVE_UNARY_FNS +if TYPE_CHECKING: + from torch._ops import OpOverload + + __all__ = [] # type: ignore[var-annotated] @@ -226,7 +226,7 @@ def _function_to_sparse_csr(func, *args, **kwargs): return _MaskedToSparseCsr.apply(args[0]) -_MASKEDTENSOR_DISPATCH_TABLE: Dict["torch._ops.OpOverload", Callable[..., Any]] = {} +_MASKEDTENSOR_DISPATCH_TABLE: Dict["OpOverload", Callable[..., Any]] = {} def register_dispatch_func(aten_ops): diff --git a/torch/nn/parallel/replicate.py b/torch/nn/parallel/replicate.py index 016a6fbd0c40..fbe12d23ee8b 100644 --- a/torch/nn/parallel/replicate.py +++ b/torch/nn/parallel/replicate.py @@ -7,8 +7,8 @@ from collections import OrderedDict if TYPE_CHECKING: - import torch.jit - import torch.jit._state + from torch.jit import ScriptModule + from torch.jit._state import EnabledProxy __all__ = ['replicate'] @@ -22,12 +22,12 @@ def _is_script_method(module: Module) -> bool: return isinstance(module, torch._C.ScriptMethod) -def _init_script_module() -> "torch.jit.ScriptModule": +def _init_script_module() -> "ScriptModule": import torch.jit return torch.jit.ScriptModule() -def _is_jit_enabled() -> "torch.jit._state.EnabledProxy": +def _is_jit_enabled() -> "EnabledProxy": import torch.jit._state return torch.jit._state._enabled diff --git a/torch/onnx/_internal/fx/type_utils.py b/torch/onnx/_internal/fx/type_utils.py index b7f3d6cea642..90abdc244d99 100644 --- a/torch/onnx/_internal/fx/type_utils.py +++ b/torch/onnx/_internal/fx/type_utils.py @@ -22,7 +22,7 @@ from torch._subclasses import fake_tensor if TYPE_CHECKING: - import onnx.defs.OpSchema.AttrType # type: ignore[import] + import onnx.defs.OpSchema.AttrType # type: ignore[import] # noqa: TCH004 # Enable both TorchScriptTensor and torch.Tensor to be tested From 2ffdf556eabdad9ddc5bc84381139a17a9638a74 Mon Sep 17 00:00:00 2001 From: albanD Date: Thu, 6 Jun 2024 17:02:29 +0000 Subject: [PATCH 417/706] Add back API that some people rely on in torch.cuda.amp.grad_scaler namespace (#128056) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128056 Approved by: https://github.com/kit1980, https://github.com/eqy --- torch/cuda/amp/grad_scaler.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torch/cuda/amp/grad_scaler.py b/torch/cuda/amp/grad_scaler.py index 367f21594f1c..c108e7f49a01 100644 --- a/torch/cuda/amp/grad_scaler.py +++ b/torch/cuda/amp/grad_scaler.py @@ -2,6 +2,9 @@ import torch +# We need to keep this unused import for BC reasons +from torch.amp.grad_scaler import OptState # noqa: F401 + __all__ = ["GradScaler"] From e9c5144cbcff67922b5c977d5a809e79411f829c Mon Sep 17 00:00:00 2001 From: Pritam Damania Date: Thu, 6 Jun 2024 17:10:42 +0000 Subject: [PATCH 418/706] Fix bug in update_process_group DDP API (#128092) Fix bug in `_update_process_group` DDP API where we didn't correctly reset `local_used_map_` and a few other variables. This resulted in errors like `Encountered gradient which is undefined, but still allreduced by...` Added a unit test as well that reproduced the issue. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128092 Approved by: https://github.com/awgu, https://github.com/fegin --- torch/csrc/distributed/c10d/reducer.cpp | 6 +++ .../_internal/distributed/distributed_test.py | 51 +++++++++++++++++++ 2 files changed, 57 insertions(+) diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp index d600426192ce..ae4db6bd7a17 100644 --- a/torch/csrc/distributed/c10d/reducer.cpp +++ b/torch/csrc/distributed/c10d/reducer.cpp @@ -2366,10 +2366,16 @@ void Reducer::reset_state() { // Ensure forward can run despite previous backward not succeeding. expect_autograd_hooks_ = false; require_finalize_ = false; + first_autograd_hook_called_ = false; // Unset allreduce division factor, as it may change in next backwards pass // when running with DDP join mode. div_factor_ = kUnsetDivFactor; + + // Reset unused parameter accounting. + // See Note [local_used_map_ -> local_used_map_dev copying] + local_used_map_.fill_(0); + local_used_map_reduced_ = false; } } // namespace c10d diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index 92fff2623b31..77e9f1f9486f 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -9888,6 +9888,57 @@ def test_ddp_update_process_group_new_group(self): def test_ddp_update_process_group_default_group(self): self._run_ddp_update_process_group(new_pg=False) + @skip_if_lt_x_gpu(4) + @require_world_size(4) + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + def test_ddp_update_process_group_grad_undefined(self): + class SimulateError(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + return input + + @staticmethod + def backward(ctx, grad_output): + raise RuntimeError + + class MyModel(torch.nn.Module): + def __init__(self, device): + super().__init__() + self.fc1 = torch.nn.Linear(10, 10).cuda(device) + self.fc2 = torch.nn.Linear(10, 10).cuda(device) + self.fc3 = torch.nn.Linear(10, 10).cuda(device) + + def forward(self, inp, error): + if error: + return self.fc3(self.fc2(self.fc1(SimulateError.apply(inp)))) + else: + return self.fc2(self.fc1(inp)) + + + input = torch.rand(10, 10, requires_grad=True).cuda(self.rank) + ddp = torch.nn.parallel.DistributedDataParallel( + MyModel(self.rank), + device_ids=[self.rank], + find_unused_parameters=True, + bucket_cap_mb=1, + ) + + try: + ddp(input, True).sum().backward() + except RuntimeError: + ddp._update_process_group(_get_default_group()) + + # Reset grads. + for param in ddp.parameters(): + param.grad = None + + # Run ddp again. + ddp(input, False).sum().backward() + + @skip_if_lt_x_gpu(2) @skip_but_pass_in_sandcastle_if( BACKEND not in DistTestCases.backend_feature["ddp"], From 1d0c1087dd4a32906edb88690f7187902bd7ad65 Mon Sep 17 00:00:00 2001 From: Yifu Wang Date: Mon, 3 Jun 2024 14:03:24 -0700 Subject: [PATCH 419/706] Allow overriding per-dim group options via _MeshEnv.set_dim_group_options (#126599) Pull Request resolved: https://github.com/pytorch/pytorch/pull/126599 Approved by: https://github.com/wanchaol ghstack dependencies: #126598 --- test/distributed/test_device_mesh.py | 9 +++++++++ torch/distributed/device_mesh.py | 27 ++++++++++++++++++++++++++- 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/test/distributed/test_device_mesh.py b/test/distributed/test_device_mesh.py index c024ddc98690..26a9bd1e0b3d 100644 --- a/test/distributed/test_device_mesh.py +++ b/test/distributed/test_device_mesh.py @@ -208,6 +208,15 @@ def test_raises_invalid_device_type(self): "cuda:0", mesh_shape=mesh_shape, mesh_dim_names=("dp", "tp") ) + @with_comms + def test_set_mesh_dim_group_options(self): + device_type = "cuda" if torch.cuda.is_available() else "cpu" + _mesh_resources._set_mesh_dim_group_options(1, "fake", None) + + mesh_tensor = torch.arange(4).reshape(2, 2) + mesh = DeviceMesh(device_type, mesh_tensor) + self.assertEqual(mesh.get_group(1)._get_backend_name(), "fake") + class DeviceMeshTestNDim(DTensorTestBase): @property diff --git a/torch/distributed/device_mesh.py b/torch/distributed/device_mesh.py index 2913f28f9e36..5bbe4e113464 100644 --- a/torch/distributed/device_mesh.py +++ b/torch/distributed/device_mesh.py @@ -62,6 +62,9 @@ class _MeshEnv(threading.local): def __init__(self) -> None: self.mesh_stack: List[DeviceMesh] = [] self.child_to_parent_mapping: Dict[DeviceMesh, DeviceMesh] = {} + self.mesh_dim_group_options: Dict[ + int, Tuple[str, Optional[ProcessGroup.Options]] + ] = {} def get_current_mesh(self) -> "DeviceMesh": if len(self.mesh_stack) == 0: @@ -155,6 +158,14 @@ def get_mesh_dim_by_name( ) return not_none(device_mesh.mesh_dim_names.index(mesh_dim_name)) + def _set_mesh_dim_group_options( + self, + dim: int, + backend: str, + pg_options: Optional[ProcessGroup.Options] = None, + ) -> None: + self.mesh_dim_group_options[dim] = (backend, pg_options) + _mesh_resources: _MeshEnv = _MeshEnv() def _get_device_handle(device_type: str = "cuda"): @@ -312,10 +323,24 @@ def _init_process_groups(self): for dim_mesh in pg_ranks_by_dim: subgroup_ranks = dim_mesh.tolist() + # Respect dim group options specified via _MeshEnv.set_dim_group_options(). + # Inherit from the parent group if no options are specified for the group. + if dim in _mesh_resources.mesh_dim_group_options: + ( + backend, + pg_options, + ) = _mesh_resources.mesh_dim_group_options[dim] + else: + backend, pg_options = None, None + # We temporarily revert the re-use subgroup, since it breaks two internal tests. # Temporarily reverting to resolve test timeout while root-causing. # TODO: Add two tests to cover internal tests scenarios and re-enable reuse subgroup if exists. - dim_group = new_group(ranks=subgroup_ranks) + dim_group = new_group( + ranks=subgroup_ranks, + backend=backend, + pg_options=pg_options, + ) # only add to dim_groups if the current rank in the subgroup if self.get_rank() in subgroup_ranks: From 304956e1fb6f8fac2fa3702ea08ca1c0a55687d2 Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Mon, 3 Jun 2024 15:29:31 +0000 Subject: [PATCH 420/706] Switch to torch.float16 on XPU AMP mode (#127741) # Motivation Previously, the default dtype for AMP on XPU was aligned with the CPU. To align with other GPUs, we intend to change the default dtype for AMP to `torch.float16`. This change aims to save users the effort of converting models from `torch.float16` to `torch.bfloat16`, or vice versa when they want to run the model on different types of GPUs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127741 Approved by: https://github.com/EikanWang, https://github.com/albanD --- aten/src/ATen/autocast_mode.cpp | 2 +- test/test_xpu.py | 130 ++++++++++++++++++++++++++++++++ 2 files changed, 131 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp index f0c73cde2dda..10fb72796fc6 100644 --- a/aten/src/ATen/autocast_mode.cpp +++ b/aten/src/ATen/autocast_mode.cpp @@ -68,7 +68,7 @@ thread_local std::array at::kBFloat16, // XLA / TPU at::ScalarType::Undefined, // Vulkan at::ScalarType::Undefined, // Metal - at::kBFloat16, // XPU + at::kHalf, // XPU at::ScalarType::Undefined, // MPS at::ScalarType::Undefined, // Meta (tensors with no data) at::kBFloat16, // HPU / HABANA diff --git a/test/test_xpu.py b/test/test_xpu.py index a3838f1d5a05..86a0bc6fa2b9 100644 --- a/test/test_xpu.py +++ b/test/test_xpu.py @@ -1,11 +1,13 @@ # Owner(s): ["module: intel"] +import collections import sys import tempfile import unittest import torch import torch.xpu._gpu_trace as gpu_trace +from torch.testing._internal.autocast_test_lists import AutocastTestLists from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, onlyXPU, @@ -309,6 +311,134 @@ def test_serialization_array_with_empty(self): instantiate_device_type_tests(TestXpu, globals(), only_for="xpu") +class TestXpuAutocast(TestCase): + def setUp(self): + super().setUp() + self.autocast_lists = AutocastTestLists(torch.device("xpu")) + + def tearDown(self): + del self.autocast_lists + super().tearDown() + + def _run_autocast_outofplace( + self, op, args, run_as_type, out_type=None, module=torch, add_kwargs=None + ): + # helper to cast args + def cast(val, to_type): + if isinstance(val, torch.Tensor): + return val.to(to_type) if val.is_floating_point() else val + elif isinstance(val, collections.abc.Iterable): + return type(val)(cast(v, to_type) for v in val) + else: + return val + + if add_kwargs is None: + add_kwargs = {} + fast_dtype = torch.bfloat16 if run_as_type == torch.bfloat16 else torch.float16 + self.assertFalse(torch.is_autocast_enabled()) + with torch.amp.autocast("xpu", dtype=fast_dtype): + self.assertTrue(torch.is_autocast_enabled()) + + out_type = out_type if out_type is not None else run_as_type + output = output_method = None + + # Try module.* variant, if requested: + if module is not None and hasattr(module, op): + output = getattr(module, op)(*args, **add_kwargs) + if isinstance(output, torch.Tensor): + self.assertTrue( + out_type == output.dtype, + f"autocast for torch.{op} produced {output.dtype}, should produce {out_type}", + ) + + # Try Tensor.* variant: + if hasattr(torch.Tensor, op): + output_method = getattr(args[0], op)(*args[1:], **add_kwargs) + if isinstance(output_method, torch.Tensor): + self.assertTrue( + out_type == output_method.dtype, + f"autocast for torch.{op} produced {output_method.dtype}, should produce torch.{out_type}", + ) + + self.assertTrue( + (output is not None) or (output_method is not None), + f"{op} not found as an attribute on either Tensor or the requested module {module}", + ) + + # Accounts for ops that return Tensors, iterables, and other non-Tensors. + # For example, lstm_cell returns a tuple and equal returns bool. + def compare(first, second): + if isinstance(first, torch.Tensor): + return torch.equal(first, second) + elif isinstance(first, collections.abc.Iterable): + return all(compare(f, s) for f, s in zip(first, second)) + else: + return first == second + + # If both torch.* and Tensor.* variants were found, check outputs are identical + if (output is not None) and (output_method is not None): + self.assertTrue(type(output) == type(output_method)) + comparison = compare(output, output_method) + self.assertTrue( + comparison, f"torch.{op} result did not match Tensor.{op} result" + ) + + # Compare numerics to Python-side "autocasting" that (we expect) does the same thing + # as the C++-side autocasting, and should be bitwise accurate. + output_to_compare = output if output is not None else output_method + with torch.amp.autocast("xpu", enabled=False): + self.assertFalse(torch.is_autocast_enabled()) + + if module is not None and hasattr(module, op): + control = getattr(module, op)( + *cast(args, run_as_type), **add_kwargs + ) + else: + control = getattr(args[0].to(run_as_type), op)( + *cast(args[1:], run_as_type), **add_kwargs + ) + self.assertTrue(type(output_to_compare) == type(control)) + comparison = compare(output_to_compare, control) + self.assertTrue(comparison, f"torch.{op} result did not match control") + self.assertTrue(torch.is_autocast_enabled()) + self.assertFalse(torch.is_autocast_enabled()) + + def test_autocast_torch_fp16(self): + for op_with_args in self.autocast_lists.torch_fp16: + skip_test = False + op, args = op_with_args[0], op_with_args[1] + if len(op_with_args) == 3: + skip_test = True # skip cudnn op + if not skip_test: + self._run_autocast_outofplace(op, args, torch.float16) + + def test_autocast_torch_bf16(self): + for op_with_args in self.autocast_lists.torch_fp16: + skip_test = False + op, args = op_with_args[0], op_with_args[1] + if len(op_with_args) == 3: + skip_test = True # skip cudnn op + if not skip_test: + self._run_autocast_outofplace(op, args, torch.bfloat16) + + def test_autocast_torch_need_autocast_promote(self): + for op, args in self.autocast_lists.torch_need_autocast_promote: + self._run_autocast_outofplace(op, args, torch.float32) + + def test_autocast_torch_expect_builtin_promote(self): + for op, args, out_type in self.autocast_lists.torch_expect_builtin_promote: + self._run_autocast_outofplace(op, args, torch.float32, out_type=out_type) + + def test_xpu_autocast_dtype(self): + dtype = torch.get_autocast_dtype("xpu") + self.assertEqual(dtype, torch.float16) + mat0_fp32 = torch.randn((10, 10), dtype=torch.float32, device="xpu") + mat1_fp32 = torch.randn((10, 10), dtype=torch.float32, device="xpu") + with torch.amp.autocast("xpu"): + result = torch.mm(mat0_fp32, mat1_fp32) + self.assertEqual(result.dtype, torch.float16) + + class TestXpuTrace(TestCase): def setUp(self): torch._C._activate_gpu_trace() From 78a6b0c4793d93d0a9105d9c92e7b88794016e66 Mon Sep 17 00:00:00 2001 From: laithsakka Date: Wed, 5 Jun 2024 07:39:17 -0700 Subject: [PATCH 421/706] update test_reformer_train test to handle nn module inlining (#127467) number of call nodes increase due to inlining before inlining: ``` class GraphModule(torch.nn.Module): def forward(self, function_ctx, cat: "f32[1, s0, 512]"): # No stacktrace found for following nodes _set_grad_enabled = torch._C._set_grad_enabled(False) # File: /data/users/lsakka/pytorch/pytorch/test/dynamo/test_repros.py:283 in backward, code: grad_attn_output, grad_hidden_states = torch.chunk( chunk = torch.chunk(cat, 2, dim = -1); cat = None getitem: "f32[1, s0, 256]" = chunk[0] getitem_1: "f32[1, s0, 256]" = chunk[1]; chunk = None # No stacktrace found for following nodes _set_grad_enabled_1 = torch._C._set_grad_enabled(True) return (getitem_1, None) ``` after inlining: ``` class GraphModule(torch.nn.Module): def forward(self, s0: "Sym(s0)", L_hidden_states_: "f32[1, s0, 256]", L_self_layers_0_weight: "f32[256, 256]", L_self_layers_0_bias: "f32[256]", L_self_layer_norm_weight: "f32[512]", L_self_layer_norm_bias: "f32[512]", L_self_layer_norm_normalized_shape_0_: "Sym(512)"): l_hidden_states_ = L_hidden_states_ l_self_layers_0_weight = L_self_layers_0_weight l_self_layers_0_bias = L_self_layers_0_bias l_self_layer_norm_weight = L_self_layer_norm_weight l_self_layer_norm_bias = L_self_layer_norm_bias l_self_layer_norm_normalized_shape_0_ = L_self_layer_norm_normalized_shape_0_ # File: /data/users/lsakka/pytorch/pytorch/test/dynamo/test_repros.py:332 in forward, code: hidden_states = torch.cat([hidden_states, hidden_states], dim=-1) hidden_states: "f32[1, s0, 512]" = torch.cat([l_hidden_states_, l_hidden_states_], dim = -1); l_hidden_states_ = None # File: /data/users/lsakka/pytorch/pytorch/test/dynamo/test_repros.py:333 in forward, code: hidden_states = _ReversibleFunction.apply( function_ctx = torch.autograd.function.FunctionCtx() # File: /data/users/lsakka/pytorch/pytorch/test/dynamo/test_repros.py:258 in forward, code: hidden_states, attn_output = torch.chunk(hidden_states, 2, dim=-1) chunk = torch.chunk(hidden_states, 2, dim = -1); hidden_states = None hidden_states_1: "f32[1, s0, 256]" = chunk[0] attn_output: "f32[1, s0, 256]" = chunk[1]; chunk = None # File: /data/users/lsakka/pytorch/pytorch/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) attn_output_1: "f32[1, s0, 256]" = torch._C._nn.linear(attn_output, l_self_layers_0_weight, l_self_layers_0_bias); attn_output = l_self_layers_0_weight = l_self_layers_0_bias = None # File: /data/users/lsakka/pytorch/pytorch/test/dynamo/test_repros.py:272 in forward, code: ctx.save_for_backward(attn_output.detach(), hidden_states.detach()) detach: "f32[1, s0, 256]" = attn_output_1.detach() detach_1: "f32[1, s0, 256]" = hidden_states_1.detach() # File: /data/users/lsakka/pytorch/pytorch/test/dynamo/test_repros.py:279 in forward, code: return torch.cat([attn_output, hidden_states], dim=-1) hidden_states_2: "f32[1, s0, 512]" = torch.cat([attn_output_1, hidden_states_1], dim = -1); attn_output_1 = hidden_states_1 = None # File: /data/users/lsakka/pytorch/pytorch/torch/nn/modules/normalization.py:201 in forward, code: return F.layer_norm( hidden_states_3: "f32[1, s0, 512]" = torch.nn.functional.layer_norm(hidden_states_2, (l_self_layer_norm_normalized_shape_0_,), l_self_layer_norm_weight, l_self_layer_norm_bias, 1e-12); hidden_states_2 = l_self_layer_norm_normalized_shape_0_ = l_self_layer_norm_weight = l_self_layer_norm_bias = None # File: /data/users/lsakka/pytorch/pytorch/test/dynamo/test_repros.py:352 in forward, code: hidden_states = torch.nn.functional.dropout( hidden_states_4: "f32[1, s0, 512]" = torch.nn.functional.dropout(hidden_states_3, p = 0.5, training = True); hidden_states_3 = None return (hidden_states_4,) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/127467 Approved by: https://github.com/anijain2305 ghstack dependencies: #126444, #127146, #127424, #127440 --- test/dynamo/test_repros.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 647507be076f..771b9e96c88c 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -2,6 +2,7 @@ PYTEST_DONT_REWRITE (prevents pytest from rewriting assertions, which interferes with test_rewrite_assert_with_msg and test_rewrite_assert_without_msg) """ + # Owner(s): ["module: dynamo"] import collections import contextlib @@ -1150,13 +1151,12 @@ def test_reformer_eval(self): def test_reformer_train(self): with torch.enable_grad(): cnt = self._reformer(nopython=False) - # cant inline torch.autograd.Function means graph break - if torch._dynamo.config.assume_static_by_default: - self.assertExpectedInline(cnt.frame_count, """1""") - self.assertExpectedInline(cnt.op_count, """5""") - else: - self.assertExpectedInline(cnt.frame_count, """1""") - self.assertExpectedInline(cnt.op_count, """5""") + expected_op_count = ( + """11""" if torch._dynamo.config.inline_inbuilt_nn_modules else """5""" + ) + + self.assertExpectedInline(cnt.frame_count, """1""") + self.assertExpectedInline(cnt.op_count, expected_op_count) @disable_translation_validation_if_dynamic_shapes def test_longformer_chunk(self): From 32fb68960e1b44668d2c7d1ab4c8aa7458c1088a Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Thu, 6 Jun 2024 08:09:00 -0700 Subject: [PATCH 422/706] [FSDP2] Added experimental warning to `unshard` API (#128138) There is still ongoing discussion on how this API should work. Current approach: - The pre-all-gather ops run in the default stream and the all-gather is called from the default stream with `async_op=True`. - Pros: - The all-gather input and output tensors are allocated in the default stream, so there is no increased memory fragmentation across stream pools. - There is no need for additional CUDA synchronization. The API is self-contained. - Cons: - The pre-all-gather ops (e.g. cast from fp32 -> bf16 and all-gather copy-in device copies) cannot overlap with other default stream compute. The biggest concern here is for CPU offloading, the H2D copies cannot overlap. Alternative approach: - Follow the default implicit prefetching approach, where the pre-all-gather ops and all-gather run in separate streams. - Pros: - The pre-all-gather ops can overlap with default stream compute. - Cons: - We require an API that should be called after the last optimizer step (namely, last op that modified sharded parameters) and before the first `unshard` call that has the all-gather streams wait for the default stream. The API is no longer self-contained and now has a complementary API. - The all-gather input and output tensors are allocated in separate streams (not the default stream), so there can be increased memory fragmentation across pools. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128138 Approved by: https://github.com/wanchaol ghstack dependencies: #128100 --- torch/distributed/_composable/fsdp/fully_shard.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch/distributed/_composable/fsdp/fully_shard.py b/torch/distributed/_composable/fsdp/fully_shard.py index 337c9a7e40b8..ca050790cdd6 100644 --- a/torch/distributed/_composable/fsdp/fully_shard.py +++ b/torch/distributed/_composable/fsdp/fully_shard.py @@ -181,6 +181,8 @@ def unshard(self, async_op: bool = False) -> Optional["UnshardHandle"]: ``False``, then returns ``None`` and waits on the handle inside this function. + .. warning:: This method is experimental and subject to change. + .. note:: If ``async_op=True``, then the user does not have to call :meth:`wait` on the returned handle if waiting on the unshard op in the module's pre-forward is tolerable. FSDP will wait on the From 936225d7b2fd46c1252a9df9f9278207fa47b7d6 Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Thu, 6 Jun 2024 18:22:20 +0000 Subject: [PATCH 423/706] [mergebot] Fix pending unstable jobs being viewed as failed (#128080) https://github.com/pytorch/pytorch/pull/128038#issuecomment-2150802030 In the above, pending unstable jobs get put into the ok_failed_checks list, and because there are a lot of unstable jobs, it exceeds the threshold and merge fails. I don't think unstable jobs should be considered in the ok failed checks threshold, only flaky and broken trunk jobs should be considered there. Change looks big, but main thing is that unstable jobs don't get included in the check for how many flaky failures there are. The other changes are mostly renames so things are clearer Pull Request resolved: https://github.com/pytorch/pytorch/pull/128080 Approved by: https://github.com/huydhn --- .github/scripts/test_trymerge.py | 12 ++++++------ .github/scripts/trymerge.py | 30 +++++++++++++++--------------- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/.github/scripts/test_trymerge.py b/.github/scripts/test_trymerge.py index 2641fd30f348..ec3e69b706f8 100755 --- a/.github/scripts/test_trymerge.py +++ b/.github/scripts/test_trymerge.py @@ -773,13 +773,13 @@ def test_get_classifications_broken_trunk(self, *args: Any) -> None: # than the one on the base commit. This should still count as broken trunk "pr_num": 104214, "related_failure_count": 0, - "unrelated_failure_count": 1, + "flaky_or_broken_trunk": 1, }, { # This PR had one broken trunk failure and it used ghstack "pr_num": 105145, "related_failure_count": 0, - "unrelated_failure_count": 1, + "flaky_or_broken_trunk": 1, }, { # The failure on the merge base was retried successfully and @@ -788,20 +788,20 @@ def test_get_classifications_broken_trunk(self, *args: Any) -> None: # be used to detect broken trunk "pr_num": 107160, "related_failure_count": 0, - "unrelated_failure_count": 4, + "flaky_or_broken_trunk": 1, }, { # This PR used Dr.CI broken trunk classification "pr_num": 111253, "related_failure_count": 1, - "unrelated_failure_count": 2, + "flaky_or_broken_trunk": 1, }, ] for case in test_cases: pr_num = case["pr_num"] related_failure_count = case["related_failure_count"] - unrelated_failure_count = case["unrelated_failure_count"] + flaky_or_broken_trunk = case["flaky_or_broken_trunk"] pr = GitHubPR("pytorch", "pytorch", pr_num) checks = pr.get_checkrun_conclusions() @@ -823,7 +823,7 @@ def test_get_classifications_broken_trunk(self, *args: Any) -> None: ) self.assertTrue(len(pending) == 0) self.assertTrue( - len(failed) == unrelated_failure_count + related_failure_count + len(failed) == flaky_or_broken_trunk + related_failure_count ) def test_ignore_current(self, *args: Any) -> None: diff --git a/.github/scripts/trymerge.py b/.github/scripts/trymerge.py index 95311d2d9b83..6a6d080a9b3a 100755 --- a/.github/scripts/trymerge.py +++ b/.github/scripts/trymerge.py @@ -2027,10 +2027,8 @@ def categorize_checks( pending_checks: List[Tuple[str, Optional[str], Optional[int]]] = [] failed_checks: List[Tuple[str, Optional[str], Optional[int]]] = [] - # ok_failed_checks is used with ok_failed_checks_threshold while ignorable_failed_checks - # is used to keep track of all ignorable failures when saving the merge record on Rockset - ok_failed_checks: List[Tuple[str, Optional[str], Optional[int]]] = [] - ignorable_failed_checks: Dict[str, List[Any]] = defaultdict(list) + # failed_checks_categorization is used to keep track of all ignorable failures when saving the merge record on Rockset + failed_checks_categorization: Dict[str, List[Any]] = defaultdict(list) # If required_checks is not set or empty, consider all names are relevant relevant_checknames = [ @@ -2058,36 +2056,38 @@ def categorize_checks( continue elif not is_passing_status(check_runs[checkname].status): target = ( - ignorable_failed_checks[classification] + failed_checks_categorization[classification] if classification in ("IGNORE_CURRENT_CHECK", "BROKEN_TRUNK", "FLAKY", "UNSTABLE") else failed_checks ) target.append((checkname, url, job_id)) - if classification in ("BROKEN_TRUNK", "FLAKY", "UNSTABLE"): - ok_failed_checks.append((checkname, url, job_id)) + flaky_or_broken_trunk = ( + failed_checks_categorization["BROKEN_TRUNK"] + + failed_checks_categorization["FLAKY"] + ) - if ok_failed_checks: + if flaky_or_broken_trunk: warn( - f"The following {len(ok_failed_checks)} checks failed but were likely due flakiness or broken trunk: " - + ", ".join([x[0] for x in ok_failed_checks]) + f"The following {len(flaky_or_broken_trunk)} checks failed but were likely due flakiness or broken trunk: " + + ", ".join([x[0] for x in flaky_or_broken_trunk]) + ( f" but this is greater than the threshold of {ok_failed_checks_threshold} so merge will fail" if ok_failed_checks_threshold is not None - and len(ok_failed_checks) > ok_failed_checks_threshold + and len(flaky_or_broken_trunk) > ok_failed_checks_threshold else "" ) ) if ( ok_failed_checks_threshold is not None - and len(ok_failed_checks) > ok_failed_checks_threshold + and len(flaky_or_broken_trunk) > ok_failed_checks_threshold ): - failed_checks = failed_checks + ok_failed_checks + failed_checks = failed_checks + flaky_or_broken_trunk - # The list of ignorable_failed_checks is returned so that it can be saved into the Rockset merge record - return (pending_checks, failed_checks, ignorable_failed_checks) + # The list of failed_checks_categorization is returned so that it can be saved into the Rockset merge record + return (pending_checks, failed_checks, failed_checks_categorization) def merge( From fba21edf5b9aa14babb9c0bc860dc9c597eb8010 Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Thu, 6 Jun 2024 18:23:50 +0000 Subject: [PATCH 424/706] [CI] Ensure inductor/test_cpu_cpp_wrapper is actually run in inductor_cpp_wrapper_abi_compatible (#126717) `inductor/test_cpu_cpp_wrapper` is not actually being run in `inductor_cpp_wrapper_abi_compatible` test config The cpu device type gets removed in https://github.com/pytorch/pytorch/blob/d28868c7e8bcd41c9219f099aa5f7a5332c912fd/torch/testing/_internal/common_device_type.py#L733 so https://github.com/pytorch/pytorch/blob/d28868c7e8bcd41c9219f099aa5f7a5332c912fd/test/inductor/test_cpu_cpp_wrapper.py#L396 returns false. Feel free to make a PR with a different way to do this (a better RUN_CPU check?) Add a skip for a failing test. I am not equipped to fix it Pull Request resolved: https://github.com/pytorch/pytorch/pull/126717 Approved by: https://github.com/ZainRizvi --- .ci/pytorch/test.sh | 2 +- test/inductor/test_cpu_cpp_wrapper.py | 12 +++++++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index ee4bf37fdb0b..d8eb45ee1d95 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -368,7 +368,7 @@ test_inductor_cpp_wrapper_abi_compatible() { echo "Testing Inductor cpp wrapper mode with TORCHINDUCTOR_ABI_COMPATIBLE=1" # cpu stack allocation causes segfault and needs more investigation - python test/run_test.py --include inductor/test_cpu_cpp_wrapper + PYTORCH_TESTING_DEVICE_ONLY_FOR="" python test/run_test.py --include inductor/test_cpu_cpp_wrapper python test/run_test.py --include inductor/test_cuda_cpp_wrapper TORCHINDUCTOR_CPP_WRAPPER=1 python benchmarks/dynamo/timm_models.py --device cuda --accuracy --amp \ diff --git a/test/inductor/test_cpu_cpp_wrapper.py b/test/inductor/test_cpu_cpp_wrapper.py index 477193664431..e77c6a5a8208 100644 --- a/test/inductor/test_cpu_cpp_wrapper.py +++ b/test/inductor/test_cpu_cpp_wrapper.py @@ -115,6 +115,7 @@ def make_test_case( slow=False, func_inputs=None, code_string_count=None, + skip=None, ): test_name = f"{name}_{device}" if device else name if code_string_count is None: @@ -123,6 +124,8 @@ def make_test_case( func = getattr(tests, test_name) assert callable(func), "not a callable" func = slowTest(func) if slow else func + if skip: + func = unittest.skip(skip)(func) @config.patch(cpp_wrapper=True, search_autotune_cache=False) def fn(self): @@ -170,6 +173,7 @@ class BaseTest(NamedTuple): slow: bool = False func_inputs: list = None code_string_count: dict = {} + skip: str = None for item in [ BaseTest("test_add_complex"), @@ -228,7 +232,9 @@ class BaseTest(NamedTuple): torch.backends.mkldnn.is_available() and torch.ops.mkldnn._is_mkldnn_bf16_supported(), ), - BaseTest("test_linear_packed", "", test_cpu_repro.CPUReproTests()), + BaseTest( + "test_linear_packed", "", test_cpu_repro.CPUReproTests(), skip="Failing" + ), BaseTest( "test_lstm_packed_change_input_sizes", "cpu", @@ -302,18 +308,21 @@ class BaseTest(NamedTuple): "cpu", test_mkldnn_pattern_matcher.TestPatternMatcher(), condition=torch.backends.mkldnn.is_available(), + skip="Failing", ), BaseTest( "test_qlinear_add", "cpu", test_mkldnn_pattern_matcher.TestPatternMatcher(), condition=torch.backends.mkldnn.is_available(), + skip="Failing", ), BaseTest( "test_qlinear_add_relu", "cpu", test_mkldnn_pattern_matcher.TestPatternMatcher(), condition=torch.backends.mkldnn.is_available(), + skip="Failing", ), BaseTest( "test_qlinear_dequant_promotion", @@ -369,6 +378,7 @@ class BaseTest(NamedTuple): item.slow, item.func_inputs, item.code_string_count, + skip=item.skip, ) test_torchinductor.copy_tests( From de4f8b99469da0c7d56362441cb25d4e62cfd2db Mon Sep 17 00:00:00 2001 From: Eddie Yan Date: Thu, 6 Jun 2024 18:45:22 +0000 Subject: [PATCH 425/706] [BE]: Update cudnn to 9.1.0.70 (#123475) cuDNN has managed to upload cu11 and cu12 wheels for ~~9.0.0.312~~ 9.1.0.70, so trying this out... CC @Skylion007 @malfet Co-authored-by: Wei Wang Co-authored-by: atalman Pull Request resolved: https://github.com/pytorch/pytorch/pull/123475 Approved by: https://github.com/Skylion007, https://github.com/malfet, https://github.com/nWEIdia, https://github.com/atalman --- .ci/docker/build.sh | 50 +++++++++---------- .ci/docker/common/install_base.sh | 2 +- .ci/docker/common/install_cudnn.sh | 17 +++---- .ci/docker/ubuntu-cuda/Dockerfile | 2 +- .../scripts/generate_binary_build_matrix.py | 8 +-- .github/workflows/docker-builds.yml | 18 +++---- ...linux-aarch64-binary-manywheel-nightly.yml | 10 ++-- .../generated-linux-binary-manywheel-main.yml | 6 +-- ...nerated-linux-binary-manywheel-nightly.yml | 30 +++++------ ...d-linux-s390x-binary-manywheel-nightly.yml | 10 ++-- ...rated-macos-arm64-binary-wheel-nightly.yml | 10 ++-- ...generated-windows-binary-wheel-nightly.yml | 40 +++++++-------- .../workflows/inductor-micro-benchmark.yml | 2 +- .github/workflows/inductor-perf-compare.yml | 2 +- .../workflows/inductor-perf-test-nightly.yml | 2 +- .github/workflows/inductor-periodic.yml | 8 +-- .github/workflows/inductor.yml | 8 +-- .github/workflows/lint.yml | 4 +- .github/workflows/periodic.yml | 8 +-- .github/workflows/pull.yml | 20 ++++---- .github/workflows/slow.yml | 4 +- .../target-determination-indexer.yml | 2 +- .github/workflows/torchbench.yml | 2 +- .github/workflows/trunk.yml | 10 ++-- .../aot_eager_timm_training.csv | 2 +- .../dynamic_inductor_torchbench_training.csv | 2 +- .../cu124/inductor_torchbench_training.csv | 2 +- .../dynamic_aot_eager_timm_training.csv | 2 +- .../dynamic_inductor_timm_training.csv | 2 +- .../inductor_timm_training.csv | 2 +- docker.Makefile | 2 +- 31 files changed, 142 insertions(+), 147 deletions(-) diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index fa4dbf2b0165..537b0b9d2ba7 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -91,9 +91,9 @@ _UCC_COMMIT=20eae37090a4ce1b32bcce6144ccad0b49943e0b # configuration, so we hardcode everything here rather than do it # from scratch case "$image" in - pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9) + pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9) CUDA_VERSION=12.4.0 - CUDNN_VERSION=8 + CUDNN_VERSION=9 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=9 PROTOBUF=yes @@ -105,9 +105,9 @@ case "$image" in CONDA_CMAKE=yes TRITON=yes ;; - pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9) + pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9) CUDA_VERSION=12.1.1 - CUDNN_VERSION=8 + CUDNN_VERSION=9 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=9 PROTOBUF=yes @@ -119,9 +119,9 @@ case "$image" in CONDA_CMAKE=yes TRITON=yes ;; - pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9-inductor-benchmarks) + pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9-inductor-benchmarks) CUDA_VERSION=12.4.0 - CUDNN_VERSION=8 + CUDNN_VERSION=9 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=9 PROTOBUF=yes @@ -134,9 +134,9 @@ case "$image" in TRITON=yes INDUCTOR_BENCHMARKS=yes ;; - pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9-inductor-benchmarks) + pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks) CUDA_VERSION=12.1.1 - CUDNN_VERSION=8 + CUDNN_VERSION=9 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=9 PROTOBUF=yes @@ -149,9 +149,9 @@ case "$image" in TRITON=yes INDUCTOR_BENCHMARKS=yes ;; - pytorch-linux-focal-cuda12.1-cudnn8-py3.12-gcc9-inductor-benchmarks) + pytorch-linux-focal-cuda12.1-cudnn9-py3.12-gcc9-inductor-benchmarks) CUDA_VERSION=12.1.1 - CUDNN_VERSION=8 + CUDNN_VERSION=9 ANACONDA_PYTHON_VERSION=3.12 GCC_VERSION=9 PROTOBUF=yes @@ -164,9 +164,9 @@ case "$image" in TRITON=yes INDUCTOR_BENCHMARKS=yes ;; - pytorch-linux-focal-cuda12.4-cudnn8-py3.12-gcc9-inductor-benchmarks) + pytorch-linux-focal-cuda12.4-cudnn9-py3.12-gcc9-inductor-benchmarks) CUDA_VERSION=12.4.0 - CUDNN_VERSION=8 + CUDNN_VERSION=9 ANACONDA_PYTHON_VERSION=3.12 GCC_VERSION=9 PROTOBUF=yes @@ -179,9 +179,9 @@ case "$image" in TRITON=yes INDUCTOR_BENCHMARKS=yes ;; - pytorch-linux-focal-cuda11.8-cudnn8-py3-gcc9) + pytorch-linux-focal-cuda11.8-cudnn9-py3-gcc9) CUDA_VERSION=11.8.0 - CUDNN_VERSION=8 + CUDNN_VERSION=9 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=9 PROTOBUF=yes @@ -193,9 +193,9 @@ case "$image" in CONDA_CMAKE=yes TRITON=yes ;; - pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9) + pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9) CUDA_VERSION=12.4.0 - CUDNN_VERSION=8 + CUDNN_VERSION=9 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=9 PROTOBUF=yes @@ -207,9 +207,9 @@ case "$image" in CONDA_CMAKE=yes TRITON=yes ;; - pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9) + pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9) CUDA_VERSION=12.1.1 - CUDNN_VERSION=8 + CUDNN_VERSION=9 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=9 PROTOBUF=yes @@ -221,9 +221,9 @@ case "$image" in CONDA_CMAKE=yes TRITON=yes ;; - pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9) + pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9) CUDA_VERSION=12.4.0 - CUDNN_VERSION=8 + CUDNN_VERSION=9 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=9 PROTOBUF=yes @@ -330,10 +330,10 @@ case "$image" in DOCS=yes INDUCTOR_BENCHMARKS=yes ;; - pytorch-linux-jammy-cuda11.8-cudnn8-py3.8-clang12) + pytorch-linux-jammy-cuda11.8-cudnn9-py3.8-clang12) ANACONDA_PYTHON_VERSION=3.8 CUDA_VERSION=11.8 - CUDNN_VERSION=8 + CUDNN_VERSION=9 CLANG_VERSION=12 PROTOBUF=yes DB=yes @@ -380,7 +380,7 @@ case "$image" in ANACONDA_PYTHON_VERSION=3.9 CONDA_CMAKE=yes ;; - pytorch-linux-jammy-cuda11.8-cudnn8-py3.9-linter) + pytorch-linux-jammy-cuda11.8-cudnn9-py3.9-linter) ANACONDA_PYTHON_VERSION=3.9 CUDA_VERSION=11.8 CONDA_CMAKE=yes @@ -447,7 +447,7 @@ tmp_tag=$(basename "$(mktemp -u)" | tr '[:upper:]' '[:lower:]') #when using cudnn version 8 install it separately from cuda if [[ "$image" == *cuda* && ${OS} == "ubuntu" ]]; then IMAGE_NAME="nvidia/cuda:${CUDA_VERSION}-cudnn${CUDNN_VERSION}-devel-ubuntu${UBUNTU_VERSION}" - if [[ ${CUDNN_VERSION} == 8 ]]; then + if [[ ${CUDNN_VERSION} == 9 ]]; then IMAGE_NAME="nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION}" fi fi @@ -499,7 +499,7 @@ docker build \ "$@" \ . -# NVIDIA dockers for RC releases use tag names like `11.0-cudnn8-devel-ubuntu18.04-rc`, +# NVIDIA dockers for RC releases use tag names like `11.0-cudnn9-devel-ubuntu18.04-rc`, # for this case we will set UBUNTU_VERSION to `18.04-rc` so that the Dockerfile could # find the correct image. As a result, here we have to replace the # "$UBUNTU_VERSION" == "18.04-rc" diff --git a/.ci/docker/common/install_base.sh b/.ci/docker/common/install_base.sh index ebaa17878ade..fd58ad8a60b8 100755 --- a/.ci/docker/common/install_base.sh +++ b/.ci/docker/common/install_base.sh @@ -3,7 +3,7 @@ set -ex install_ubuntu() { - # NVIDIA dockers for RC releases use tag names like `11.0-cudnn8-devel-ubuntu18.04-rc`, + # NVIDIA dockers for RC releases use tag names like `11.0-cudnn9-devel-ubuntu18.04-rc`, # for this case we will set UBUNTU_VERSION to `18.04-rc` so that the Dockerfile could # find the correct image. As a result, here we have to check for # "$UBUNTU_VERSION" == "18.04"* diff --git a/.ci/docker/common/install_cudnn.sh b/.ci/docker/common/install_cudnn.sh index 3afd2f28841f..60f4561d420c 100644 --- a/.ci/docker/common/install_cudnn.sh +++ b/.ci/docker/common/install_cudnn.sh @@ -1,23 +1,18 @@ #!/bin/bash -if [[ ${CUDNN_VERSION} == 8 ]]; then +if [[ -n "${CUDNN_VERSION}" ]]; then # cuDNN license: https://developer.nvidia.com/cudnn/license_agreement mkdir tmp_cudnn pushd tmp_cudnn - if [[ ${CUDA_VERSION:0:4} == "12.4" ]]; then - CUDNN_NAME="cudnn-linux-x86_64-8.9.7.29_cuda12-archive" - curl --retry 3 -OLs https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/linux-x86_64/${CUDNN_NAME}.tar.xz - elif [[ ${CUDA_VERSION:0:4} == "12.1" ]]; then - CUDNN_NAME="cudnn-linux-x86_64-8.9.2.26_cuda12-archive" - curl --retry 3 -OLs https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/linux-x86_64/${CUDNN_NAME}.tar.xz - elif [[ ${CUDA_VERSION:0:4} == "11.8" ]]; then - CUDNN_NAME="cudnn-linux-x86_64-8.7.0.84_cuda11-archive" - curl --retry 3 -OLs https://developer.download.nvidia.com/compute/redist/cudnn/v8.7.0/local_installers/11.8/${CUDNN_NAME}.tar.xz + if [[ ${CUDA_VERSION:0:2} == "12" ]]; then + CUDNN_NAME="cudnn-linux-x86_64-9.1.0.70_cuda12-archive" + elif [[ ${CUDA_VERSION:0:2} == "11" ]]; then + CUDNN_NAME="cudnn-linux-x86_64-9.1.0.70_cuda11-archive" else print "Unsupported CUDA version ${CUDA_VERSION}" exit 1 fi - + curl --retry 3 -OLs https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/linux-x86_64/${CUDNN_NAME}.tar.xz tar xf ${CUDNN_NAME}.tar.xz cp -a ${CUDNN_NAME}/include/* /usr/local/cuda/include/ cp -a ${CUDNN_NAME}/lib/* /usr/local/cuda/lib64/ diff --git a/.ci/docker/ubuntu-cuda/Dockerfile b/.ci/docker/ubuntu-cuda/Dockerfile index cb3ea502d231..3b2bbea0097a 100644 --- a/.ci/docker/ubuntu-cuda/Dockerfile +++ b/.ci/docker/ubuntu-cuda/Dockerfile @@ -139,7 +139,7 @@ COPY --from=pytorch/llvm:9.0.1 /opt/llvm /opt/llvm ARG CUDNN_VERSION ARG CUDA_VERSION COPY ./common/install_cudnn.sh install_cudnn.sh -RUN if [ "${CUDNN_VERSION}" -eq 8 ]; then bash install_cudnn.sh; fi +RUN if [ -n "${CUDNN_VERSION}" ]; then bash install_cudnn.sh; fi RUN rm install_cudnn.sh # Install CUSPARSELT diff --git a/.github/scripts/generate_binary_build_matrix.py b/.github/scripts/generate_binary_build_matrix.py index 855f37af7eda..920ca65fbf52 100644 --- a/.github/scripts/generate_binary_build_matrix.py +++ b/.github/scripts/generate_binary_build_matrix.py @@ -19,7 +19,7 @@ CUDA_ARCHES_FULL_VERSION = {"11.8": "11.8.0", "12.1": "12.1.1", "12.4": "12.4.0"} -CUDA_ARCHES_CUDNN_VERSION = {"11.8": "8", "12.1": "8", "12.4": "8"} +CUDA_ARCHES_CUDNN_VERSION = {"11.8": "9", "12.1": "9", "12.4": "9"} ROCM_ARCHES = ["6.0", "6.1"] @@ -42,7 +42,7 @@ "nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | " # noqa: B950 "nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | " - "nvidia-cudnn-cu11==8.7.0.84; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | " @@ -55,7 +55,7 @@ "nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | " # noqa: B950 "nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | " - "nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | " @@ -68,7 +68,7 @@ "nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | " - "nvidia-cudnn-cu12==8.9.7.29; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | " diff --git a/.github/workflows/docker-builds.yml b/.github/workflows/docker-builds.yml index 0eec1556bb96..f732dab42050 100644 --- a/.github/workflows/docker-builds.yml +++ b/.github/workflows/docker-builds.yml @@ -38,19 +38,19 @@ jobs: matrix: runner: [linux.12xlarge] docker-image-name: [ - pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9, - pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9-inductor-benchmarks, - pytorch-linux-focal-cuda12.4-cudnn8-py3.12-gcc9-inductor-benchmarks, - pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9, - pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9-inductor-benchmarks, - pytorch-linux-focal-cuda12.1-cudnn8-py3.12-gcc9-inductor-benchmarks, - pytorch-linux-focal-cuda11.8-cudnn8-py3-gcc9, + pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9, + pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9-inductor-benchmarks, + pytorch-linux-focal-cuda12.4-cudnn9-py3.12-gcc9-inductor-benchmarks, + pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9, + pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks, + pytorch-linux-focal-cuda12.1-cudnn9-py3.12-gcc9-inductor-benchmarks, + pytorch-linux-focal-cuda11.8-cudnn9-py3-gcc9, pytorch-linux-focal-py3.8-clang10, pytorch-linux-focal-py3.11-clang10, pytorch-linux-focal-py3.12-clang10, pytorch-linux-focal-rocm-n-1-py3, pytorch-linux-focal-rocm-n-py3, - pytorch-linux-jammy-cuda11.8-cudnn8-py3.8-clang12, + pytorch-linux-jammy-cuda11.8-cudnn9-py3.8-clang12, pytorch-linux-focal-py3-clang9-android-ndk-r21e, pytorch-linux-jammy-py3.8-gcc11, pytorch-linux-jammy-py3.8-gcc11-inductor-benchmarks, @@ -58,7 +58,7 @@ jobs: pytorch-linux-jammy-py3-clang15-asan, pytorch-linux-focal-py3-clang10-onnx, pytorch-linux-focal-linter, - pytorch-linux-jammy-cuda11.8-cudnn8-py3.9-linter, + pytorch-linux-jammy-cuda11.8-cudnn9-py3.9-linter, pytorch-linux-jammy-py3-clang12-executorch ] include: diff --git a/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml index 726dbf40f985..a1a7e6fd9537 100644 --- a/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml @@ -54,7 +54,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_8-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_8-cpu-aarch64-test: # Testing @@ -162,7 +162,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_9-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_9-cpu-aarch64-test: # Testing @@ -270,7 +270,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_10-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cpu-aarch64-test: # Testing @@ -378,7 +378,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_11-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-cpu-aarch64-test: # Testing @@ -486,7 +486,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_12-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-cpu-aarch64-test: # Testing diff --git a/.github/workflows/generated-linux-binary-manywheel-main.yml b/.github/workflows/generated-linux-binary-manywheel-main.yml index 6e7edae7b613..053877b1c90e 100644 --- a/.github/workflows/generated-linux-binary-manywheel-main.yml +++ b/.github/workflows/generated-linux-binary-manywheel-main.yml @@ -48,7 +48,7 @@ jobs: DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cuda11_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==8.7.0.84; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_8-cuda11_8-test: # Testing @@ -88,7 +88,7 @@ jobs: DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cuda12_1 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_8-cuda12_1-test: # Testing @@ -128,7 +128,7 @@ jobs: DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cuda12_4 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.7.29; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_8-cuda12_4-test: # Testing diff --git a/.github/workflows/generated-linux-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-binary-manywheel-nightly.yml index 272e15577cc7..9d59728bbbbb 100644 --- a/.github/workflows/generated-linux-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-binary-manywheel-nightly.yml @@ -174,7 +174,7 @@ jobs: DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cuda11_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==8.7.0.84; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_8-cuda11_8-test: # Testing @@ -237,7 +237,7 @@ jobs: DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cuda12_1 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_8-cuda12_1-test: # Testing @@ -300,7 +300,7 @@ jobs: DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cuda12_4 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.7.29; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_8-cuda12_4-test: # Testing @@ -690,7 +690,7 @@ jobs: DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda11_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==8.7.0.84; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_9-cuda11_8-test: # Testing @@ -753,7 +753,7 @@ jobs: DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_1 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_9-cuda12_1-test: # Testing @@ -816,7 +816,7 @@ jobs: DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_4 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.7.29; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_9-cuda12_4-test: # Testing @@ -1206,7 +1206,7 @@ jobs: DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda11_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==8.7.0.84; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cuda11_8-test: # Testing @@ -1269,7 +1269,7 @@ jobs: DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda12_1 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cuda12_1-test: # Testing @@ -1332,7 +1332,7 @@ jobs: DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda12_4 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.7.29; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cuda12_4-test: # Testing @@ -1722,7 +1722,7 @@ jobs: DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda11_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==8.7.0.84; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-cuda11_8-test: # Testing @@ -1785,7 +1785,7 @@ jobs: DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda12_1 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-cuda12_1-test: # Testing @@ -1848,7 +1848,7 @@ jobs: DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda12_4 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.7.29; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-cuda12_4-test: # Testing @@ -2238,7 +2238,7 @@ jobs: DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda11_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==8.7.0.84; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-cuda11_8-test: # Testing @@ -2301,7 +2301,7 @@ jobs: DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda12_1 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-cuda12_1-test: # Testing @@ -2364,7 +2364,7 @@ jobs: DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda12_4 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.7.29; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-cuda12_4-test: # Testing diff --git a/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml index 4f0569c253f2..db0748463da5 100644 --- a/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml @@ -54,7 +54,7 @@ jobs: ALPINE_IMAGE: "docker.io/s390x/alpine" build_name: manywheel-py3_8-cpu-s390x build_environment: linux-s390x-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_8-cpu-s390x-test: # Testing @@ -117,7 +117,7 @@ jobs: ALPINE_IMAGE: "docker.io/s390x/alpine" build_name: manywheel-py3_9-cpu-s390x build_environment: linux-s390x-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_9-cpu-s390x-test: # Testing @@ -180,7 +180,7 @@ jobs: ALPINE_IMAGE: "docker.io/s390x/alpine" build_name: manywheel-py3_10-cpu-s390x build_environment: linux-s390x-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cpu-s390x-test: # Testing @@ -243,7 +243,7 @@ jobs: ALPINE_IMAGE: "docker.io/s390x/alpine" build_name: manywheel-py3_11-cpu-s390x build_environment: linux-s390x-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-cpu-s390x-test: # Testing @@ -306,7 +306,7 @@ jobs: ALPINE_IMAGE: "docker.io/s390x/alpine" build_name: manywheel-py3_12-cpu-s390x build_environment: linux-s390x-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-cpu-s390x-test: # Testing diff --git a/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml b/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml index 94a8fd9cd3de..b4910d46ed5e 100644 --- a/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml +++ b/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml @@ -46,7 +46,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.8" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' # For sccache access (only on non-forked PRs) AWS_ACCESS_KEY_ID: ${{ secrets.MACOS_SCCACHE_S3_ACCESS_KEY_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets.MACOS_SCCACHE_S3_SECRET_ACCESS_KEY }} @@ -165,7 +165,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.9" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' # For sccache access (only on non-forked PRs) AWS_ACCESS_KEY_ID: ${{ secrets.MACOS_SCCACHE_S3_ACCESS_KEY_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets.MACOS_SCCACHE_S3_SECRET_ACCESS_KEY }} @@ -284,7 +284,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' # For sccache access (only on non-forked PRs) AWS_ACCESS_KEY_ID: ${{ secrets.MACOS_SCCACHE_S3_ACCESS_KEY_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets.MACOS_SCCACHE_S3_SECRET_ACCESS_KEY }} @@ -403,7 +403,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' # For sccache access (only on non-forked PRs) AWS_ACCESS_KEY_ID: ${{ secrets.MACOS_SCCACHE_S3_ACCESS_KEY_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets.MACOS_SCCACHE_S3_SECRET_ACCESS_KEY }} @@ -522,7 +522,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' # For sccache access (only on non-forked PRs) AWS_ACCESS_KEY_ID: ${{ secrets.MACOS_SCCACHE_S3_ACCESS_KEY_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets.MACOS_SCCACHE_S3_SECRET_ACCESS_KEY }} diff --git a/.github/workflows/generated-windows-binary-wheel-nightly.yml b/.github/workflows/generated-windows-binary-wheel-nightly.yml index d64c221e7895..d06f99bd9a5a 100644 --- a/.github/workflows/generated-windows-binary-wheel-nightly.yml +++ b/.github/workflows/generated-windows-binary-wheel-nightly.yml @@ -46,7 +46,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.8" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -290,7 +290,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.8" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -536,7 +536,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.8" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -782,7 +782,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.8" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -1027,7 +1027,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.9" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -1271,7 +1271,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.9" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -1517,7 +1517,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.9" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -1763,7 +1763,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.9" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -2008,7 +2008,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -2252,7 +2252,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -2498,7 +2498,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -2744,7 +2744,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -2989,7 +2989,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -3233,7 +3233,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -3479,7 +3479,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -3725,7 +3725,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -3970,7 +3970,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -4214,7 +4214,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -4460,7 +4460,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -4706,7 +4706,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash diff --git a/.github/workflows/inductor-micro-benchmark.yml b/.github/workflows/inductor-micro-benchmark.yml index 4fe0ddf50ef2..431545ea6d0d 100644 --- a/.github/workflows/inductor-micro-benchmark.yml +++ b/.github/workflows/inductor-micro-benchmark.yml @@ -21,7 +21,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm80 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9-inductor-benchmarks + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.0' test-matrix: | { include: [ diff --git a/.github/workflows/inductor-perf-compare.yml b/.github/workflows/inductor-perf-compare.yml index e485a8bfce1b..a5e4ad1781aa 100644 --- a/.github/workflows/inductor-perf-compare.yml +++ b/.github/workflows/inductor-perf-compare.yml @@ -18,7 +18,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm80 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9-inductor-benchmarks + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.0' test-matrix: | { include: [ diff --git a/.github/workflows/inductor-perf-test-nightly.yml b/.github/workflows/inductor-perf-test-nightly.yml index e77c915749f3..2f129c52fe13 100644 --- a/.github/workflows/inductor-perf-test-nightly.yml +++ b/.github/workflows/inductor-perf-test-nightly.yml @@ -71,7 +71,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm80 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9-inductor-benchmarks + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.0' test-matrix: | { include: [ diff --git a/.github/workflows/inductor-periodic.yml b/.github/workflows/inductor-periodic.yml index 34b3dc8101f2..731291697cef 100644 --- a/.github/workflows/inductor-periodic.yml +++ b/.github/workflows/inductor-periodic.yml @@ -23,7 +23,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm86 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9-inductor-benchmarks + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.6' test-matrix: | { include: [ @@ -64,7 +64,7 @@ jobs: with: sync-tag: linux-focal-cuda12_4-py3_10-gcc9-inductor-build build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm86 - docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9-inductor-benchmarks + docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.6' test-matrix: | { include: [ @@ -105,7 +105,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm80 - docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9-inductor-benchmarks + docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.0' test-matrix: | { include: [ @@ -131,7 +131,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.4-py3.12-gcc9-sm86 - docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3.12-gcc9-inductor-benchmarks + docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3.12-gcc9-inductor-benchmarks cuda-arch-list: '8.6' test-matrix: | { include: [ diff --git a/.github/workflows/inductor.yml b/.github/workflows/inductor.yml index 08d3b9fcfb24..2030ff5aee3b 100644 --- a/.github/workflows/inductor.yml +++ b/.github/workflows/inductor.yml @@ -44,7 +44,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm86 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9-inductor-benchmarks + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.6' test-matrix: | { include: [ @@ -86,7 +86,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm80 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9-inductor-benchmarks + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.0' test-matrix: | { include: [ @@ -112,7 +112,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3.12-gcc9-sm86 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3.12-gcc9-inductor-benchmarks + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3.12-gcc9-inductor-benchmarks cuda-arch-list: '8.6' test-matrix: | { include: [ @@ -135,7 +135,7 @@ jobs: with: sync-tag: linux-focal-cuda12_4-py3_10-gcc9-inductor-build build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm86 - docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9-inductor-benchmarks + docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.6' test-matrix: | { include: [ diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index f1b6611d00e0..e0e4d3c20cd8 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -20,7 +20,7 @@ jobs: with: timeout: 120 runner: linux.2xlarge - docker-image: pytorch-linux-jammy-cuda11.8-cudnn8-py3.9-linter + docker-image: pytorch-linux-jammy-cuda11.8-cudnn9-py3.9-linter # NB: A shallow checkout won't work here because calculate-docker-image requires a full checkout # to run git rev-parse HEAD~:.ci/docker when a new image is needed fetch-depth: 0 @@ -36,7 +36,7 @@ jobs: with: timeout: 120 runner: linux.2xlarge - docker-image: pytorch-linux-jammy-cuda11.8-cudnn8-py3.9-linter + docker-image: pytorch-linux-jammy-cuda11.8-cudnn9-py3.9-linter # NB: A shallow checkout won't work here because calculate-docker-image requires a full checkout # to run git rev-parse HEAD~:.ci/docker when a new image is needed fetch-depth: 0 diff --git a/.github/workflows/periodic.yml b/.github/workflows/periodic.yml index bcbd7ad6a5b5..bae31f44d742 100644 --- a/.github/workflows/periodic.yml +++ b/.github/workflows/periodic.yml @@ -42,7 +42,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 test-matrix: | { include: [ { config: "nogpu_AVX512", shard: 1, num_shards: 1, runner: "linux.2xlarge" }, @@ -65,7 +65,7 @@ jobs: uses: ./.github/workflows/_linux-build-label.yml with: build-environment: linux-focal-cuda12.4-py3.10-gcc9 - docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9 test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" }, @@ -120,7 +120,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda11.8-py3.9-gcc9 - docker-image-name: pytorch-linux-focal-cuda11.8-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda11.8-cudnn9-py3-gcc9 cuda-arch-list: 8.6 test-matrix: | { include: [ @@ -142,7 +142,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda11.8-py3.10-gcc9-debug - docker-image-name: pytorch-linux-focal-cuda11.8-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda11.8-cudnn9-py3-gcc9 build-with-debug: true test-matrix: | { include: [ diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 808e8a3795e3..b435f1fe0791 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -237,7 +237,7 @@ jobs: uses: ./.github/workflows/_linux-build-label.yml with: build-environment: linux-focal-cuda11.8-py3.10-gcc9 - docker-image-name: pytorch-linux-focal-cuda11.8-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda11.8-cudnn9-py3-gcc9 test-matrix: | { include: [ { config: "distributed", shard: 1, num_shards: 3, runner: "linux.8xlarge.nvidia.gpu" }, @@ -262,7 +262,7 @@ jobs: uses: ./.github/workflows/_linux-build-label.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" }, @@ -297,12 +297,12 @@ jobs: { config: "default", shard: 1, num_shards: 1 }, ]} - linux-jammy-cuda-11_8-cudnn8-py3_8-clang12-build: - name: linux-jammy-cuda11.8-cudnn8-py3.8-clang12 + linux-jammy-cuda-11_8-cudnn9-py3_8-clang12-build: + name: linux-jammy-cuda11.8-cudnn9-py3.8-clang12 uses: ./.github/workflows/_linux-build-label.yml with: - build-environment: linux-jammy-cuda11.8-cudnn8-py3.8-clang12 - docker-image-name: pytorch-linux-jammy-cuda11.8-cudnn8-py3.8-clang12 + build-environment: linux-jammy-cuda11.8-cudnn9-py3.8-clang12 + docker-image-name: pytorch-linux-jammy-cuda11.8-cudnn9-py3.8-clang12 test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 1 }, @@ -361,7 +361,7 @@ jobs: uses: ./.github/workflows/_bazel-build-test.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-bazel-test - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 cuda-version: cpu test-matrix: | { include: [ @@ -373,7 +373,7 @@ jobs: uses: ./.github/workflows/_bazel-build-test.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-bazel-test - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 cuda-version: "12.1" test-matrix: | { include: [ @@ -385,7 +385,7 @@ jobs: uses: ./.github/workflows/_bazel-build-test.yml with: build-environment: linux-focal-cuda12.4-py3.10-gcc9-bazel-test - docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9 cuda-version: "12.4" test-matrix: | { include: [ @@ -447,7 +447,7 @@ jobs: uses: ./.github/workflows/_linux-build-label.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm86 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 cuda-arch-list: 8.6 test-matrix: | { include: [ diff --git a/.github/workflows/slow.yml b/.github/workflows/slow.yml index 31db7af8fc55..50f74b01f08c 100644 --- a/.github/workflows/slow.yml +++ b/.github/workflows/slow.yml @@ -41,7 +41,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3-gcc9-slow-gradcheck - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 cuda-arch-list: 8.6 test-matrix: | { include: [ @@ -70,7 +70,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm86 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 cuda-arch-list: 8.6 test-matrix: | { include: [ diff --git a/.github/workflows/target-determination-indexer.yml b/.github/workflows/target-determination-indexer.yml index 0ce1bae6a413..e8bf91c8d9ee 100644 --- a/.github/workflows/target-determination-indexer.yml +++ b/.github/workflows/target-determination-indexer.yml @@ -26,7 +26,7 @@ jobs: id: calculate-docker-image uses: pytorch/test-infra/.github/actions/calculate-docker-image@main with: - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 working-directory: pytorch - name: Use following to pull public copy of the image diff --git a/.github/workflows/torchbench.yml b/.github/workflows/torchbench.yml index 73befe34c078..ac5814966899 100644 --- a/.github/workflows/torchbench.yml +++ b/.github/workflows/torchbench.yml @@ -16,7 +16,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm80 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9-inductor-benchmarks + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.0' test-matrix: | { include: [ diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index 4d4cb7672653..6897d4b1fa6d 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -39,7 +39,7 @@ jobs: uses: ./.github/workflows/_linux-build-label.yml with: build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm86 - docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9 cuda-arch-list: 8.6 test-matrix: | { include: [ @@ -66,7 +66,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: libtorch-linux-focal-cuda12.1-py3.7-gcc9 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 build-generates-artifacts: false runner: linux.4xlarge test-matrix: | @@ -80,7 +80,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-no-ops - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 1 }, @@ -91,7 +91,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: libtorch-linux-focal-cuda12.4-py3.7-gcc9 - docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9 build-generates-artifacts: false runner: linux.4xlarge test-matrix: | @@ -105,7 +105,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.4-py3.10-gcc9-no-ops - docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9 test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 1 }, diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_timm_training.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_timm_training.csv index 1def1d99bd53..fe7efa082cea 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_timm_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_timm_training.csv @@ -218,7 +218,7 @@ tf_mixnet_l,pass,6 -tinynet_a,pass,6 +tinynet_a,fail_accuracy,6 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_torchbench_training.csv index a3c9c3915fc5..ee58808c0bb0 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_torchbench_training.csv @@ -182,7 +182,7 @@ phlippe_densenet,pass,6 -phlippe_resnet,fail_accuracy,6 +phlippe_resnet,pass,6 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_torchbench_training.csv index 02411bef6cc5..cfc524426644 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_torchbench_training.csv @@ -182,7 +182,7 @@ phlippe_densenet,pass,6 -phlippe_resnet,fail_accuracy,6 +phlippe_resnet,pass,6 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_timm_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_timm_training.csv index 1def1d99bd53..fe7efa082cea 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_timm_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_timm_training.csv @@ -218,7 +218,7 @@ tf_mixnet_l,pass,6 -tinynet_a,pass,6 +tinynet_a,fail_accuracy,6 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_timm_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_timm_training.csv index e5464160d32f..ae860db793c9 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_timm_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_timm_training.csv @@ -6,7 +6,7 @@ adv_inception_v3,pass,6 -beit_base_patch16_224,pass,7 +beit_base_patch16_224,fail_accuracy,7 diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_timm_training.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_timm_training.csv index e5464160d32f..ae860db793c9 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_timm_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_timm_training.csv @@ -6,7 +6,7 @@ adv_inception_v3,pass,6 -beit_base_patch16_224,pass,7 +beit_base_patch16_224,fail_accuracy,7 diff --git a/docker.Makefile b/docker.Makefile index a33c411907bc..7f131707e7ab 100644 --- a/docker.Makefile +++ b/docker.Makefile @@ -10,7 +10,7 @@ endif CUDA_VERSION_SHORT ?= 12.1 CUDA_VERSION ?= 12.1.1 -CUDNN_VERSION ?= 8 +CUDNN_VERSION ?= 9 BASE_RUNTIME = ubuntu:22.04 BASE_DEVEL = nvidia/cuda:$(CUDA_VERSION)-devel-ubuntu22.04 CMAKE_VARS ?= From e8670f6aeaeded83b7d9f66d1e71c31f81592214 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Thu, 6 Jun 2024 19:47:09 +0000 Subject: [PATCH 426/706] [Dynamo][TVM] Support macOS and Linux/aarch64 platforms (#128124) Fixes #128122 With this fix, I've confirmed that the repro works on the platforms below. - macOS 14.5 (arm64) - Ubuntu 20.04.6 LTS (GNU/Linux 5.10.120-tegra aarch64) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128124 Approved by: https://github.com/malfet --- torch/_dynamo/backends/tvm.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/torch/_dynamo/backends/tvm.py b/torch/_dynamo/backends/tvm.py index 84084616995e..bf4413690a1c 100644 --- a/torch/_dynamo/backends/tvm.py +++ b/torch/_dynamo/backends/tvm.py @@ -4,6 +4,7 @@ import importlib import logging import os +import sys import tempfile from types import MappingProxyType from typing import Optional @@ -180,6 +181,10 @@ def has_tvm(): @functools.lru_cache(None) def llvm_target(): - if "avx512" in open("/proc/cpuinfo").read(): - return "llvm -mcpu=skylake-avx512" - return "llvm -mcpu=core-avx2" + if sys.platform == "linux": + cpuinfo = open("/proc/cpuinfo").read() + if "avx512" in cpuinfo: + return "llvm -mcpu=skylake-avx512" + elif "avx2" in cpuinfo: + return "llvm -mcpu=core-avx2" + return "llvm" From 7e059b3c9597c77adff1bffb1ea77e9d3ca7c677 Mon Sep 17 00:00:00 2001 From: atalman Date: Thu, 6 Jun 2024 20:25:39 +0000 Subject: [PATCH 427/706] Add a call to validate docker images after build step is complete (#127768) Adds validation to docker images. As discussed here: https://github.com/pytorch/pytorch/issues/125879 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127768 Approved by: https://github.com/huydhn, https://github.com/Skylion007 --- .github/workflows/docker-release.yml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.github/workflows/docker-release.yml b/.github/workflows/docker-release.yml index 9f5221a88f9c..351497bee753 100644 --- a/.github/workflows/docker-release.yml +++ b/.github/workflows/docker-release.yml @@ -149,3 +149,10 @@ jobs: - name: Teardown Linux uses: pytorch/test-infra/.github/actions/teardown-linux@main if: always() + + validate: + needs: build + uses: pytorch/builder/.github/workflows/validate-docker-images.yml@main + with: + channel: nightly + ref: main From 2184cdd29128a924583e4702489177f83fb8270a Mon Sep 17 00:00:00 2001 From: chilli Date: Wed, 5 Jun 2024 09:50:22 -0700 Subject: [PATCH 428/706] Added memory budget to partitioner (#126320) Pull Request resolved: https://github.com/pytorch/pytorch/pull/126320 Approved by: https://github.com/shunting314 --- test/functorch/test_ac.py | 301 +++++++++++++++++++++++ torch/_functorch/config.py | 33 +++ torch/_functorch/partitioners.py | 393 +++++++++++++++++++++++++++++-- 3 files changed, 702 insertions(+), 25 deletions(-) create mode 100644 test/functorch/test_ac.py diff --git a/test/functorch/test_ac.py b/test/functorch/test_ac.py new file mode 100644 index 000000000000..ee3a2c545183 --- /dev/null +++ b/test/functorch/test_ac.py @@ -0,0 +1,301 @@ +# Owner(s): ["oncall: pt2"] +import random + +import torch +import torch._functorch.config as config +from torch.testing._internal.common_utils import run_tests, TestCase +from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.utils.flop_counter import FlopCounterMode + + +def compile_with_ac(f, memory_budget): + return torch.compile(f, backend="aot_eager_decomp_partition") + + +def get_act_mem(f): + out = f() + out.backward() + start_mem = torch.cuda.memory_stats()["requested_bytes.all.current"] + out = f() + cur_mem = torch.cuda.memory_stats()["requested_bytes.all.current"] + act_mem = (cur_mem - start_mem) / (1024 * 1024) + out.backward() + return act_mem + + +def get_bw_flops(f): + # Normalized so that a 512 square matmul returns 1 + f().backward() + out = f() + with FlopCounterMode(display=False) as mode: + out.backward() + return mode.get_total_flops() / (512**3 * 2) + + +def create_pair(B_I, O): + # results in B_I * O memory, requires B_I * B_I * O flops + # arithmetic intensity of B_I + x = torch.randn(B_I * 512, B_I * 512, requires_grad=True) + w = torch.randn(B_I * 512, O * 512, requires_grad=True) + return x, w + + +def get_mem_and_flops(f, memory_budget=None): + # Returns megabytes rounded to 1 decimal point and FLOPs + # Note that each value of size (512, 512, torch.float32) is 1 MiB + torch._dynamo.reset() + with config.patch(activation_memory_budget=memory_budget): + if memory_budget is not None: + f = torch.compile(f, backend="aot_eager_decomp_partition") + + # We round this to nearest 10th of a megabyte. + return round(get_act_mem(f), 1), get_bw_flops(f) + + +class MemoryBudgetTest(TestCase): + def setUp(self): + super().setUp() + torch.set_default_device("cuda") + + def test_rematerializes_cheap(self): + def f(x, w): + x = x.cos() + x = torch.mm(x, w) + return x.sum() + + x = torch.randn(512, 512, requires_grad=True) + w = torch.randn(512, 512, requires_grad=True) + + def call(): + return f(x, w) + + eager_mem, eager_flops = get_mem_and_flops(call) + self.assertEqual(eager_mem, 1.0) + mem_10, flops_10 = get_mem_and_flops(call, memory_budget=1.0) + # Recomputing `.cos()` is not free here. + self.assertEqual(mem_10, 1.0) + self.assertEqual(eager_flops, flops_10) + mem_5, flops_5 = get_mem_and_flops(call, memory_budget=0.5) + # We can just recompute `x.cos()` here to only depend on the inputs + self.assertEqual(mem_5, 0.0) + self.assertEqual(flops_5, eager_flops) + + def test_matmul_even_chain(self): + def f(x, ws): + x = x.cos() + for w in ws: + x = torch.mm(x, w).cos() + return x.sum() + + x = torch.randn(512, 512, requires_grad=True) + ws = [torch.randn(512, 512, requires_grad=True) for _ in range(5)] + + def call(): + return f(x, ws) + + eager_mem, eager_flops = get_mem_and_flops(call) + for budget in range(0, 11): + mem, flops = get_mem_and_flops(call, memory_budget=budget / 10) + if budget <= 5: + # We start saving the matmuls + self.assertEqual(mem, budget) + self.assertEqual(flops, eager_flops + (5 - budget)) + elif budget < 10: + # We're only recomputing the `cos` operations + self.assertEqual(mem, 5.0) + self.assertEqual(flops, eager_flops) + elif budget == 10: + self.assertEqual(mem, 10.0) + self.assertEqual(flops, eager_flops) + + def test_matmul_uneven_chain(self): + # This function is constructed so that we are saving one input of size + # [512, in_dim] for each w + # In addition, every matmul has a same ratio of compute to "memory + # saved", so this test is essentially testing our knapsack solving + + def f(x, ws): + xs = [torch.mm(x, w).cos() for w in ws] + return sum([x.sum() for x in xs]) + + x = torch.randn(512, 512, requires_grad=True) + + def make_weights(w_shapes): + ws = [] + for idx, dim in enumerate(w_shapes): + ws.append(torch.randn(512, dim * 512, requires_grad=True)) + return ws + + def make_weights_chain(w_shapes): + ws = [] + for idx, _ in enumerate(w_shapes): + old_dim = 512 if idx == 0 else w_shapes[idx - 1] * 512 + new_dim = w_shapes[idx] * 512 + ws.append(torch.randn(old_dim, new_dim, requires_grad=True)) + return ws + + weight_configs = [ + ( + [11, 3, 4, 2], + [ + 18, # 11 + 4 + 3 + 17, # 11 + 4 + 2 + 16, # 11 + 3 + 2 + 15, # 11 + 4 + 14, # 11 + 3 + 13, # 11 + 2 + 11, # 11 + 2 + 7, # 4 + 3 + 6, # 4 + 2 + 5, # 3 + 2 + ], + ), + ( + [3, 5, 11, 17, 14], + [ + 42, # 17 + 14 + 9 + 30, # 11 + 15 + 5 + 19, # 11 + 5 + 3 + 8, # 5 + 3 + 3, # 3 + ], + ), + ] + random.seed(0) + random_arr = [random.randint(0, 50) for _ in range(10)] + exact_sums = [] + for i in range(10): + random.shuffle(random_arr) + exact_sums.append(sum(random_arr[:i])) + weight_configs.append((random_arr, exact_sums)) + + for weight_shapes, exact_solves in weight_configs: + ws = make_weights(weight_shapes) + + def call(): + return f(x, ws) + + eager_mem, eager_flops = get_mem_and_flops(call) + total_mem = sum(weight_shapes) + self.assertEqual(eager_mem, sum(weight_shapes)) + for mem_achieved in exact_solves: + mem, _ = get_mem_and_flops(call, memory_budget=mem_achieved / total_mem) + self.assertEqual(mem, mem_achieved) + + def test_prioritize_cheaper_matmul(self): + def f(xs, ws): + xs = [torch.mm(x, w).cos() for x, w in zip(xs, ws)] + return sum([x.sum() for x in xs]) + + x1, w1 = create_pair(1, 4) + x2, w2 = create_pair(2, 2) + + def call(): + return f([x1, x2], [w1, w2]) + + eager_mem, eager_flops = get_mem_and_flops(call) + self.assertEqual(eager_mem, 8) + self.assertEqual(eager_flops, 24) + comp_mem, comp_flops = get_mem_and_flops(call, memory_budget=0.5) + self.assertEqual(comp_mem, 4) + # We are recomputing x1 @ w1 here! + self.assertEqual(comp_flops, eager_flops + 4) + + @config.patch(activation_memory_budget_runtime_estimator="profile") + def test_profile(self): + def f(x, ws): + x = x.cos() + for w in ws: + x = torch.mm(x, w).cos() + return x.sum() + + x = torch.randn(512, 512, requires_grad=True) + ws = [torch.randn(512, 512, requires_grad=True) for _ in range(5)] + + def call(): + return f(x, ws) + + eager_mem, eager_flops = get_mem_and_flops(call) + mem, flops = get_mem_and_flops(call, memory_budget=0.2) + # We start saving the matmuls + self.assertEqual(mem, 2) + self.assertEqual(flops, eager_flops + 3) + + def test_prioritize_cheaper_matmul2(self): + def f(xs, ws): + xs = [torch.mm(x, w).cos() for x, w in zip(xs, ws)] + return sum([x.sum() for x in xs]) + + data = [(4, 4), (6, 2), (2, 6)] + xs, ws = zip(*[create_pair(a, b) for a, b in data]) + + def call(): + return f(xs, ws) + + eager_mem, eager_flops = get_mem_and_flops(call) + self.assertEqual(eager_mem, 40) + self.assertEqual(eager_flops, 320) + mem, flops = get_mem_and_flops(call, memory_budget=28 / eager_mem) + # Save w1 and w2 + self.assertEqual(mem, 28) + # We're recomputing w3 (the cheap one!) + self.assertEqual(flops - eager_flops, 2 * 2 * 6) + mem, flops = get_mem_and_flops(call, memory_budget=16 / eager_mem) + # Save w2. Note that even though saving w1 gets us closer to our memory + # limit, w2 is actually *more* FLOPs than w1! + self.assertEqual(mem, 12) + self.assertEqual(flops - eager_flops, 2 * 2 * 6 + 4 * 4 * 4) + + def test_attention_vs_linear(self): + def f(x, w): + orig_shape = x.shape + x = x.reshape(1, 1, x.shape[0], x.shape[1]) + # I know this isn't technically right lol + x = torch.nn.functional.scaled_dot_product_attention( + x, x, x, is_causal=False + ).reshape(*orig_shape) + x = torch.mm(x, w) + x = x.cos() + return x.sum() + + def try_seq_length(S, D, expected_recompute): + x = torch.randn(S * 512, D * 512, requires_grad=True) + w = torch.randn(D * 512, D * 512, requires_grad=True) + + def call(): + return f(x, w) + + with FlopCounterMode(display=False) as mode: + call() + mm_flops = mode.get_flop_counts()["Global"][torch.ops.aten.mm] + attn_flops = mode.get_total_flops() - mm_flops + mm_flops /= 512**3 * 2 + attn_flops /= 512**3 * 2 + + eager_mem, eager_flops = get_mem_and_flops(call) + self.assertEqual(eager_mem, S * D * 2) + + mem, flops = get_mem_and_flops( + call, memory_budget=0.6 + ) # Force it to recompute one of mm or attn + self.assertEqual(mem, S * D) + if expected_recompute == "attn": + expected_flops = attn_flops + else: + expected_flops = mm_flops + self.assertEqual(flops - eager_flops, expected_flops) + + # General behind this test is that if sequence length * 2 > D, then + # attention is more expensive than the linear. + try_seq_length(1, 1, "mm") + try_seq_length(1, 3, "attn") + try_seq_length(2, 2, "mm") + try_seq_length(2, 1, "mm") + try_seq_length(2, 5, "attn") + try_seq_length(4, 7, "mm") + try_seq_length(4, 9, "attn") + + +if __name__ == "__main__": + if HAS_CUDA: + run_tests() diff --git a/torch/_functorch/config.py b/torch/_functorch/config.py index c559951f3809..60bbf1f21c66 100644 --- a/torch/_functorch/config.py +++ b/torch/_functorch/config.py @@ -88,6 +88,39 @@ # a fusion can be expensive. ban_recompute_reductions = True +# By default, the partitioner is purely trying to optimize for runtime (although +# it should always use less memory than eager) +# This knob controls the partitioner to make that tradeoff for you, choosing the +# fastest option that saves less activations than the memory budget. +# Specifically, 0.0 corresponds to the activation memory from applying +# activation checkpointing to the full compiled region, and 1.0 corresponds to +# the activation memory from the default runtime-optimized strategy. So, 0.4 +# would result in a strategy that saves 40% of the activations compared to the +# default strategy. +# It solves a 0-1 knapsack to find the minimum recompute necessary to stay below +# the activation memory budget. +# NOTE: This *cannot* be treated as +activation_memory_budget = 1.0 + +# This controls how we estimate the runtime when deciding what the cheapest +# operators to recompute are. The 3 options are +# "flops": Bases it off of the flop count provided by torch.utils.flop_counter +# "profile": Benchmarks each operator to come up with a runtime +# "testing": Returns 1 for everything +activation_memory_budget_runtime_estimator = "flops" + +# This controls the solver used for the 0-1 knapsack. By default we use a +# quantized DP solution ("dp"). The other approaches are a "greedy" and a "ilp" +# (which has a scipy dependency). +activation_memory_budget_solver = "dp" + +# This dumps out a png visualization of the expected runtime vs. activation +# memory tradeoffs for all memory budget values from 0 to 1 in increments of +# 0.5. See an example here: +# https://github.com/pytorch/pytorch/pull/126320#discussion_r1625104015 +visualize_memory_budget_pareto = ( + os.environ.get("PARTITIONER_MEMORY_BUDGET_PARETO", "0") == "1" +) # Sets all of the ban_recompute heuristics to False except ban_recompute_reductions # Generally, this will probably result in some memory improvement, but at the diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index cbfb4ca17168..fc1c995e5907 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -25,6 +25,7 @@ ) from torch.fx.passes import graph_drawer from . import config +from ._aot_autograd.logging_utils import get_aot_graph_name from .compile_utils import fx_graph_cse, get_aten_target if TYPE_CHECKING: @@ -451,14 +452,16 @@ def _size_of(node: fx.Node) -> int: # layering violation) elif isinstance(val, (list, tuple)): return sum( - _tensor_nbytes(hint_int(n.numel(), fallback=4098), n.dtype) + _tensor_nbytes(hint_int(n.numel(), fallback=4096), n.dtype) for n in val if isinstance(n, torch.Tensor) ) elif isinstance(val, torch.Tensor): - return _tensor_nbytes(hint_int(val.numel(), fallback=4098), val.dtype) + return _tensor_nbytes(hint_int(val.numel(), fallback=4096), val.dtype) raise RuntimeError(f"Unknown metadata type {type(val)}") + if node.op == "get_attr": + return 0 raise RuntimeError("We should always have `val` metadata on the nodes") @@ -532,25 +535,22 @@ def reordering_to_mimic_autograd_engine(gm: fx.GraphModule) -> fx.GraphModule: for idx, node in enumerate(gm.graph.nodes): order[node] = idx - # Populate depth for the nodes. Depth is the distance from the inputs. - depths = {} - output_node = next(iter(gm.graph.find_nodes(op="output"))) - for node in gm.graph.nodes: - if node.op == "placeholder": - depths[node] = 0 - else: - depths[node] = max([depths[arg] for arg in node.all_input_nodes], default=0) - def insert_node_in_graph(node): - if node in env: - return env[node] + cur_nodes = [node] + insertable_nodes = set() + while len(cur_nodes) > 0: + node = cur_nodes.pop() + if node in insertable_nodes or node in env: + continue + insertable_nodes.add(node) - # Bias traversal towards the nodes that have higher depth - prioritizes - # critical path first. - for arg, _ in sort_depths(node.all_input_nodes, depths): - env[arg] = insert_node_in_graph(arg) - env[node] = new_graph.node_copy(node, lambda x: env[x]) - return env[node] + # Bias traversal towards the nodes that have higher depth - prioritizes + # critical path first. + cur_nodes += node.all_input_nodes + + insertable_nodes = sorted(insertable_nodes, key=lambda n: order[n]) + for node in insertable_nodes: + env[node] = new_graph.node_copy(node, lambda x: env[x]) # Find first bwd node in the graph tangent_inputs = list(filter(_is_tangent, gm.graph.nodes)) @@ -750,7 +750,7 @@ def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule: return joint_module -def get_saved_values( +def solve_min_cut( joint_graph: fx.Graph, node_info: NodeInfo, min_cut_options: MinCutOptions, @@ -877,7 +877,6 @@ def ban_recomputation_if_allowed(node): return False if node in dont_ban: return False - # breakpoint() # This bans recomputation of the node unless we've been forced not to by # user annotation # NB: "recompute" > 0 means that user annotation has asked us to @@ -1268,9 +1267,197 @@ def get_name_to_node(graph: fx.Graph): return name_to_node +def greedy_knapsack( + memory: List[float], runtimes: List[float], max_memory: float +) -> Tuple[float, List[int], List[int]]: + n = len(runtimes) + items = list(range(n)) + + # Sort items based on the ratio of runtime to memory in descending order + items = sorted(items, key=lambda i: runtimes[i] / memory[i], reverse=True) + + total_memory = 0.0 + total_runtime = 0.0 + items_to_save = [] + items_to_allow_recomputing = [] + + for i in items: + if total_memory + memory[i] <= max_memory: + total_memory += memory[i] + total_runtime += runtimes[i] + items_to_save.append(i) + else: + items_to_allow_recomputing.append(i) + return total_runtime, items_to_save, items_to_allow_recomputing + + +def ilp_knapsack( + memory: List[float], runtimes: List[float], max_memory: float +) -> Tuple[float, List[int], List[int]]: + import numpy as np + + try: + from scipy.optimize import Bounds, LinearConstraint, milp + except ImportError: + raise RuntimeError( + "To use the ILP for memory budget checkpointing you need to install scipy" + ) from None + + np_memory = np.array(memory) + np_runtimes = np.array(runtimes) + c = -np_runtimes # type: ignore[operator] + + memory_constraint = LinearConstraint(A=np_memory, ub=np.array(max_memory)) + constraints = [memory_constraint] + + integrality = np.ones_like(c) + res = milp( + c=c, constraints=constraints, integrality=integrality, bounds=Bounds(0, 1) + ) + if not res.success: + raise RuntimeError("Somehow scipy solving failed") + + items_to_save = [] + items_to_allow_recomputing = [] + for idx, i in enumerate(res.x): + if i == 1: + items_to_save.append(idx) + else: + items_to_allow_recomputing.append(idx) + return -res.fun, items_to_save, items_to_allow_recomputing + + +def dp_knapsack( + memory: List[float], runtimes: List[float], max_memory: float +) -> Tuple[float, List[int], List[int]]: + # Scaling factor to convert floating point weights to integers + S = 10000 + + # Quantize the memory weights + quantized_memory = torch.tensor( + [int(round(m * S)) for m in memory], dtype=torch.long, device="cpu" + ) + runtimes = torch.tensor(runtimes, dtype=torch.float32, device="cpu") + + # Quantized pseudopolynomial DP for 0-1 Knapsack + quantized_max_memory = int(round(max_memory * S)) + + n = len(memory) + + # Initialize the DP table + # TODO(chilli): I think if needed, this memory can be optimized with sliding + # window trick + Hirschberg trick: + # https://codeforces.com/blog/entry/47247?#comment-316200 + dp = torch.zeros( + (n + 1, quantized_max_memory + 1), dtype=torch.float32, device="cpu" + ) + + for i in range(1, n + 1): + current_memory = quantized_memory[i - 1] + current_runtime = runtimes[i - 1] + + # Copy the previous row + dp[i, :] = dp[i - 1, :] + + # Update dp[i, j] for all j >= current_memory + if current_memory == 0: + dp[i, :] = dp[i - 1, :] + current_runtime + else: + dp[i, current_memory:] = torch.maximum( + dp[i - 1, current_memory:], + dp[i - 1, :-current_memory] + current_runtime, + ) + + # Backtrack to find the items included in the knapsack + saved_items = [] + recomputable_items = [] + j: int = quantized_max_memory + for i in range(n, 0, -1): + if dp[i][j] != dp[i - 1][j]: + saved_items.append(i - 1) # Include this item (indexing from 0) + j -= int(quantized_memory[i - 1].item()) + else: + recomputable_items.append(i - 1) + + saved_items.reverse() # To get items in the order they were added + + # The maximum runtime that can be achieved within the max_memory constraint + max_runtime = dp[n][quantized_max_memory].item() + + return max_runtime, saved_items, recomputable_items + + +def _optimize_runtime_with_given_memory( + memory: List[float], + runtimes: List[float], + max_memory: float, +) -> Tuple[float, List[int], List[int]]: + SOLVER = config.activation_memory_budget_solver + if SOLVER == "greedy": + return greedy_knapsack(memory, runtimes, max_memory) + elif SOLVER == "ilp": + return ilp_knapsack(memory, runtimes, max_memory) + elif SOLVER == "dp": + return dp_knapsack(memory, runtimes, max_memory) + else: + raise RuntimeError(f"Not aware of memory budget knapsack solver: {SOLVER}") + + +from torch.utils._mode_utils import no_dispatch + + +def estimate_runtime(node): + RUNTIME_MODE = config.activation_memory_budget_runtime_estimator + + def materialize_arg(x): + if isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.Tensor): + shape = list(x.meta["val"].shape) + + def realize_symbol(d): + return hint_int(d, fallback=4096) + + shape = [realize_symbol(s) for s in shape] + return x.meta["val"].new_zeros(shape) + elif isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.SymInt): + return hint_int(x.meta["val"], fallback=4096) + elif isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.SymFloat): + return 1.0 + elif isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.SymBool): + return True + else: + return x + + if RUNTIME_MODE == "testing": + return 1 + + elif RUNTIME_MODE == "profile": + from triton.testing import do_bench + + with no_dispatch(): + args, kwargs = pytree.tree_map(materialize_arg, (node.args, node.kwargs)) + ms = do_bench(lambda: node.target(*args, **kwargs)) + return ms + + elif RUNTIME_MODE == "flops": + # todo(chilli): Normalize this to also return ms + from torch.utils.flop_counter import FlopCounterMode + + args, kwargs = pytree.tree_map(materialize_arg, (node.args, node.kwargs)) + with FlopCounterMode(display=False) as mode: + node.target(*args, **kwargs) + counted_flops = mode.get_total_flops() + return max(counted_flops, 1) + else: + raise RuntimeError(f"Not aware of runtime estimator: {RUNTIME_MODE}") + + def choose_saved_values_set( joint_graph: fx.Graph, node_info: NodeInfo, memory_budget=1 ) -> List[fx.Node]: + if memory_budget > 1 or memory_budget < 0: + raise RuntimeError( + f"The valid ranges for memory budget are 0 <= m <= 1. The provided value is {memory_budget}" + ) min_cut_options = MinCutOptions( ban_if_used_far_apart=config.ban_recompute_used_far_apart, ban_if_long_fusible_chains=config.ban_recompute_long_fusible_chains, @@ -1287,16 +1474,164 @@ def choose_saved_values_set( ban_if_materialized_backward=False, ban_if_not_in_allowlist=False, ) - if memory_budget == 0: return node_info.inputs - runtime_optimized_saved_values, _ = get_saved_values( + runtime_optimized_saved_values, _ = solve_min_cut( joint_graph, node_info, min_cut_options, ) - return runtime_optimized_saved_values + # return runtime_optimized_saved_values + if memory_budget == 1: + return runtime_optimized_saved_values + + def estimate_activations_size(saved_values: List[fx.Node]) -> float: + return sum([_size_of(i) for i in saved_values]) / 1e9 + + min_act_size = estimate_activations_size(node_info.inputs) + max_act_size = estimate_activations_size(runtime_optimized_saved_values) + # The optimized choice is smaller than the inputs anyways + if max_act_size <= min_act_size: + return runtime_optimized_saved_values + + def get_normalized_size(sz): + return (sz / 1e9) / (max_act_size - min_act_size) + + def get_mem_ratio(activations: List[fx.Node]): + return (estimate_activations_size(activations) - min_act_size) / ( + max_act_size - min_act_size + ) + + more_aggressive_options = replace( + min_cut_options, + ban_if_used_far_apart=False, + ban_if_long_fusible_chains=False, + ban_if_materialized_backward=False, + ) + more_aggressive_saved_values, _ = solve_min_cut( + joint_graph, node_info, more_aggressive_options + ) + if get_mem_ratio(more_aggressive_saved_values) < memory_budget: + return more_aggressive_saved_values + + aggressive_options = replace( + more_aggressive_options, + ban_if_not_in_allowlist=False, + ) + aggressive_recomputation_saved_values, banned_nodes = solve_min_cut( + joint_graph, node_info, aggressive_options + ) + + if get_mem_ratio(aggressive_recomputation_saved_values) < memory_budget: + return aggressive_recomputation_saved_values + + from torch._inductor.fx_utils import get_node_storage + + input_storages = {get_node_storage(node) for node in node_info.inputs} + + def get_recomputable_banned_nodes(banned_nodes: List[fx.Node]) -> List[fx.Node]: + return [ + i + for i in banned_nodes + if ( + # Only allow recomputing nodes that are actually required for BW + i.dist_from_bw < int(1e9) # type: ignore[attr-defined] + and get_node_storage(i) not in input_storages + ) + ] + + recomputable_banned_nodes = get_recomputable_banned_nodes(banned_nodes) + + # default: runtime_optimized_saved_values + # more aggressive: more_aggressive_saved_values + # full aggressive: aggressive_recomputation_saved_values + + all_recomputable_banned_nodes = sorted( + recomputable_banned_nodes, key=_size_of, reverse=True + ) + if len(all_recomputable_banned_nodes) == 0: + return node_info.inputs + memories_banned_nodes = [ + get_normalized_size(_size_of(i)) for i in all_recomputable_banned_nodes + ] + runtimes_banned_nodes = [ + estimate_runtime(node) for node in all_recomputable_banned_nodes + ] + from torch.utils._mode_utils import no_dispatch + + def get_saved_values_knapsack(memory_budget): + with no_dispatch(): + ( + expected_runtime, + saved_node_idxs, + recomputable_node_idxs, + ) = _optimize_runtime_with_given_memory( + memories_banned_nodes, runtimes_banned_nodes, max(memory_budget, 0) + ) + dont_ban = set() + for idx in recomputable_node_idxs: + dont_ban.add(all_recomputable_banned_nodes[idx]) + assert dont_ban.issubset(all_recomputable_banned_nodes) + + saved_values, _ = solve_min_cut( + joint_graph, + node_info, + aggressive_options, + dont_ban, + ) + return saved_values, expected_runtime + + if config.visualize_memory_budget_pareto: + options = [] + for sweep_memory_budget in range(100, -1, -5): + saved_values, expected_runtime = get_saved_values_knapsack( + sweep_memory_budget / 100 + ) + options.append( + ( + sweep_memory_budget, + sum(runtimes_banned_nodes) - expected_runtime, + get_mem_ratio(saved_values), + ) + ) + + import matplotlib.pyplot as plt + + x_values = [item[2] for item in options] + y_values = [item[1] for item in options] + + # Plotting the values with updated axis labels and chart title + plt.figure(figsize=(10, 6)) + plt.plot(x_values, y_values, marker="o") + + # Adding labels for each point + for i, txt in enumerate(x_values): + plt.annotate( + f"{txt:.2f}", + (x_values[i], y_values[i]), + textcoords="offset points", + xytext=(0, 10), + ha="center", + ) + + plt.xlabel("Memory Budget") + plt.ylabel("Runtime of Recomputed Components") + plt.title("Pareto Frontier of Memory Budget vs. Recomputation Runtime") + plt.grid(True) + fig = plt.gcf() + plt.show() + fig_name = f"memory_budget_pareto_{get_aot_graph_name()}.png" + fig.savefig(fig_name) + log.warning("Generated Pareto frontier curve at %s", fig_name) + + # todo(chilli): Estimated doesn't align exactly with actual - actual is + # usually less memory than estimated. i'm guessing (actually quite + # unsure about this) that's because estimated is just only including + # tensors we actually banned from recompute, but there may be other + # tensors that we choose to save. + + return get_saved_values_knapsack(memory_budget=memory_budget)[0] def min_cut_rematerialization_partition( @@ -1412,7 +1747,15 @@ def classify_nodes(joint_module): for user in node.users: node.dist_from_bw = min(node.dist_from_bw, user.dist_from_bw + 1) - saved_values = choose_saved_values_set(joint_graph, node_info, memory_budget=1) + memory_budget = config.activation_memory_budget + for node in joint_graph.nodes: + if isinstance(node.meta.get("memory_budget", None), float): + memory_budget = node.meta["memory_budget"] + break + # print("Memory Budget: ", memory_budget) + saved_values = choose_saved_values_set( + joint_graph, node_info, memory_budget=memory_budget + ) # save_for_backward on tensors and stashes symints in autograd .ctx saved_sym_nodes = list(filter(is_sym_node, saved_values)) saved_values = list(filter(lambda n: not is_sym_node(n), saved_values)) From 95543004362e3d1de20003cbb2e0c99e8b6606c4 Mon Sep 17 00:00:00 2001 From: "Andrew M. James" Date: Wed, 29 May 2024 19:29:02 +0000 Subject: [PATCH 429/706] [inductor][codegen] Codegen constexpr globals and constexpr annotated globals correctly. (#126195) [Triton #3762](https://github.com/triton-lang/triton/pull/3762) disallows access to globals which are not `tl.constexpr` Triton has always treated captured globals this way, but they now require it be explicit in user code. Updated codegen to make sure these variables are defined before writing the kernel source when compiling a user defined triton kernel. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126195 Approved by: https://github.com/alexbaden, https://github.com/bertmaher --- test/inductor/test_triton_kernels.py | 11 +++++------ torch/_inductor/codegen/wrapper.py | 23 +++++++++++++++++++++-- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/test/inductor/test_triton_kernels.py b/test/inductor/test_triton_kernels.py index accab8beae6b..113f1daea0f2 100644 --- a/test/inductor/test_triton_kernels.py +++ b/test/inductor/test_triton_kernels.py @@ -31,11 +31,10 @@ fast_dividef as my_fast_dividef, ) - -# Define shared triton constants here. -CONSTANT_C = 4 -STRING_CONSTANT_C = "CONSTANT_C" -BOOL_CONSTANT_C = True + # Define shared triton constants here. + CONSTANT_C: tl.constexpr = 4 + STRING_CONSTANT_C: tl.constexpr = "CONSTANT_C" + BOOL_CONSTANT_C: tl.constexpr = True class KernelTests(torch._inductor.test_case.TestCase): @@ -600,7 +599,7 @@ def mulC_kernel( offsets = block_start + tl.arange(0, BLOCK_SIZE) mask = offsets < n_elements x = tl.load(in_ptr0 + offsets, mask=mask) - if CONSTANT_NAME.value == STRING_CONSTANT_C: + if CONSTANT_NAME == STRING_CONSTANT_C: output = CONSTANT_C * x if BOOL_CONSTANT_C: output *= CONSTANT_C diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index c90776e4cbd2..fa1bb3463cb6 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -1207,7 +1207,9 @@ def define_user_defined_triton_kernel(self, kernel, configs, kwargs): # Also include any possible kernel being called indirectly from triton import JITFunction + from triton.language import constexpr + # global constexpr vars handled above symbols_included = {original_name} def traverse(cur_kernel): @@ -1220,6 +1222,7 @@ def traverse(cur_kernel): for inst in dis.Bytecode(cur_kernel.fn) if inst.opname == "LOAD_GLOBAL" } + global_annotations = cur_kernel.fn.__globals__.get("__annotations__", {}) for symbol_name in cur_kernel.fn.__code__.co_names: if symbol_name in symbols_included: continue @@ -1231,9 +1234,25 @@ def traverse(cur_kernel): compile_wrapper.splice(symbol.src, strip=True) symbols_included.add(symbol_name) traverse(symbol) - elif isinstance(symbol, (int, str, bool)): + elif isinstance(symbol, (int, str, bool, constexpr)): compile_wrapper.newline() - compile_wrapper.writeline(f"{symbol_name} = {symbol!r}") + if isinstance(symbol, constexpr): + symbol_str = f"tl.constexpr({symbol.value!r})" + else: + symbol_str = f"{symbol!r}" + if annotation := global_annotations.get(symbol_name): + annotion_code = "" + if isinstance(annotation, type): + annotation_code = ( + f": {annotation.__module__}.{annotation.__name__}" + ) + else: + annotation_code = f": {annotation!r}" + compile_wrapper.writeline( + f"{symbol_name}{annotation_code} = {symbol_str}" + ) + else: + compile_wrapper.writeline(f"{symbol_name} = {symbol!r}") symbols_included.add(symbol_name) elif ( symbol_name in unqualified_loads From baaa914bf7de37608448aba7195c6688e5604ebd Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Thu, 6 Jun 2024 10:34:07 -0700 Subject: [PATCH 430/706] [small] test clean up (#128079) remove unnecessary line: https://github.com/pytorch/pytorch/issues/123733 add main so test can be run `python ...`: https://github.com/pytorch/pytorch/issues/124906 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128079 Approved by: https://github.com/awgu --- .../distributed/elastic/multiprocessing/redirects_test.py | 4 ++++ test/run_test.py | 8 -------- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/test/distributed/elastic/multiprocessing/redirects_test.py b/test/distributed/elastic/multiprocessing/redirects_test.py index 0d8c14310f87..2fa507a15a36 100644 --- a/test/distributed/elastic/multiprocessing/redirects_test.py +++ b/test/distributed/elastic/multiprocessing/redirects_test.py @@ -138,3 +138,7 @@ def c_print(i): libc.printf(bytes(f"c:{i}\n", "utf-8")) self._redirect_large_buffer(c_print) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/run_test.py b/test/run_test.py index d9ef52a42af3..065e24f90801 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -24,7 +24,6 @@ import torch.distributed as dist from torch.multiprocessing import current_process, get_context from torch.testing._internal.common_utils import ( - FILE_SCHEMA, get_report_path, IS_CI, IS_MACOS, @@ -745,14 +744,7 @@ def test_distributed(test_module, test_directory, options): old_environ = dict(os.environ) os.environ["TEMP_DIR"] = tmp_dir os.environ["BACKEND"] = backend - os.environ["INIT_METHOD"] = "env://" os.environ.update(env_vars) - if with_init_file: - if test_module.name == "test_distributed_spawn": - init_method = f"{FILE_SCHEMA}{tmp_dir}/" - else: - init_method = f"{FILE_SCHEMA}{tmp_dir}/shared_init_file" - os.environ["INIT_METHOD"] = init_method try: os.mkdir(os.path.join(tmp_dir, "barrier")) os.mkdir(os.path.join(tmp_dir, "test_dir")) From 04272a0e129a7314275d21f871e5fcadc37c796d Mon Sep 17 00:00:00 2001 From: Andrea Frittoli Date: Thu, 6 Jun 2024 21:22:07 +0000 Subject: [PATCH 431/706] Add docstring for the torch.ao.quantization.utils.get_combined_dict function (#128127) Fixes: #127906 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128127 Approved by: https://github.com/jerryzh168 --- torch/ao/quantization/utils.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/torch/ao/quantization/utils.py b/torch/ao/quantization/utils.py index d0de50bbeb57..5ce1d1109e72 100644 --- a/torch/ao/quantization/utils.py +++ b/torch/ao/quantization/utils.py @@ -121,6 +121,25 @@ def check_node(node, modules): return is_call_function, is_call_method, is_call_module def get_combined_dict(default_dict, additional_dict): + """ + Combines two dictionaries. + + This function takes two dictionaries as input and returns a new dictionary + that contains all the key-value pairs from both input dictionaries. + If there are any duplicate keys in the `additional_dict`, the values + from the `additional_dict` will overwrite those in the `default_dict`. + Args: + default_dict (dict): The main dictionary that will be used as the base + additional_dict (dict): The dictionary used to update `default_dict` + + Returns: + dict: The resulting dictionary + Example: + >>> x = dict(a=1, b=1) + >>> y = dict(b=2, c=3) + >>> get_combined_dict(x, y) + {'a': 1, 'b': 2, 'c': 3} + """ d = default_dict.copy() d.update(additional_dict) return d From 54fe2d0e89e1d7c64c1fb2ab120e966a750aff4d Mon Sep 17 00:00:00 2001 From: Eddie Yan Date: Thu, 6 Jun 2024 21:43:29 +0000 Subject: [PATCH 432/706] [cuDNN][quantization] skip qlinear test in cuDNN v9.1.0 (#128166) #120006 only very recently unskipped this test 3 days ago so we don't consider it a blocker for cuDNNv9 for now CC @atalman Pull Request resolved: https://github.com/pytorch/pytorch/pull/128166 Approved by: https://github.com/atalman, https://github.com/nWEIdia --- test/quantization/core/test_quantized_op.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py index 6671b6634e00..5b86693e11c1 100644 --- a/test/quantization/core/test_quantized_op.py +++ b/test/quantization/core/test_quantized_op.py @@ -4052,6 +4052,7 @@ def test_qlinear_with_input_q_dq_qweight_dq_output_fp32( use_channelwise=st.sampled_from([False])) # channelwise currently not supported for qlinear cudnn @skipIfNoFBGEMM @unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.") + @unittest.skipIf(TEST_CUDNN and torch.backends.cudnn.version() == 90100, "expected failure on cuDNN 9.1.0") @unittest.skipIf(not SM80OrLater, "requires sm80 or later.") @unittest.skipIf(TEST_ROCM, "not supported on rocm.") # TODO: check with yang regarding CUDNN flags From 0a761f0627130e739f0e2748e3f71a0c347552c4 Mon Sep 17 00:00:00 2001 From: Chirag Pandya Date: Thu, 6 Jun 2024 10:25:05 -0700 Subject: [PATCH 433/706] [RFC] Provide optional switches to _dump_nccl_trace (#127651) Summary: Data from PyTorch distributed is mostly useful during initial stages of model development. Provide options to reduce data sent/dumped. `_dump_nccl_trace` takes 3 optional switches. Default as before returns everything - `includeCollectives`: option to also include collectives: Default is True. - `includeStacktraces`: option to include stack traces in collectives. Default is True. - `onlyActive`: option to only send active collective work - i.e. not completed. Default is False (i.e. send everything) Test Plan: Unit tests Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/127651 Approved by: https://github.com/wconstab --- test/distributed/test_c10d_nccl.py | 73 ++++++++------- .../distributed/c10d/ProcessGroupNCCL.cpp | 25 +++-- .../distributed/c10d/ProcessGroupNCCL.hpp | 15 ++- torch/csrc/distributed/c10d/TraceUtils.h | 92 ++++++++++++------- torch/csrc/distributed/c10d/init.cpp | 25 ++++- 5 files changed, 152 insertions(+), 78 deletions(-) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index baf2adb1fb2d..21a8a632bade 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -3523,7 +3523,8 @@ class NCCLTraceTest(NCCLTraceTestBase): @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize("timing_enabled", [True, False]) - def test_short(self, timing_enabled): + @parametrize("include_collectives", [True, False]) + def test_short(self, timing_enabled, include_collectives): if self.rank == self.MAIN_PROCESS_RANK: return pg = self._create_process_group_nccl() @@ -3538,8 +3539,14 @@ def test_short(self, timing_enabled): # gah ok so now the duration_ms is populated best-effort since it can only happen outside "dump()" api time.sleep(1) - - t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace()) + if include_collectives: + t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace()) + else: + t = pickle.loads( + torch._C._distributed_c10d._dump_nccl_trace( + includeCollectives=False, includeStackTraces=None, onlyActive=None + ) + ) ver = t["version"] self.assertEqual(ver, "2.1") pg_config = t["pg_config"] @@ -3550,35 +3557,39 @@ def test_short(self, timing_enabled): self.assertIn("ranks", default_pg_info) global_ranks = pg_config["0"]["ranks"] self.assertEqual(len(json.loads(global_ranks)), self.world_size) - t = t["entries"] - self.assertEqual(len(t), 2) - last = t[-1] - self.assertEqual(last["process_group"], ("0", "default_pg")) - self.assertEqual(last["state"], "completed") - s = last["time_discovered_started_ns"] - f = last["time_discovered_completed_ns"] - self.assertEqual(last["record_id"], 1) - self.assertIsNotNone(f) - if timing_enabled: - self.assertIsNotNone(s) - self.assertTrue(s <= f) - self.assertIn("test_c10d_nccl.py", str(last["frames"])) - self.assertEqual(last["input_sizes"], ((3, 4),)) - self.assertEqual(last["input_dtypes"], ["Float"]) - self.assertEqual(last["output_sizes"], ((3, 4),)) - self.assertEqual(last["output_dtypes"], ["Float"]) - self.assertEqual(last["collective_seq_id"], 2) - now = datetime.now() - event_created_time = datetime.fromtimestamp( - last["time_created_ns"] / 1000000000 - ) - before_test = now - timedelta(minutes=1) - self.assertTrue(before_test < event_created_time < now) - if timing_enabled: - # very loose bounds, measured 0.036 ms on devgpu - self.assertTrue(0 < last["duration_ms"] < 100) + if include_collectives: + self.assertEqual(len(t["entries"]), 2) + t = t["entries"] + self.assertEqual(len(t), 2) + last = t[-1] + self.assertEqual(last["process_group"], ("0", "default_pg")) + self.assertEqual(last["state"], "completed") + s = last["time_discovered_started_ns"] + f = last["time_discovered_completed_ns"] + self.assertEqual(last["record_id"], 1) + self.assertIsNotNone(f) + if timing_enabled: + self.assertIsNotNone(s) + self.assertTrue(s <= f) + self.assertIn("test_c10d_nccl.py", str(last["frames"])) + self.assertEqual(last["input_sizes"], ((3, 4),)) + self.assertEqual(last["input_dtypes"], ["Float"]) + self.assertEqual(last["output_sizes"], ((3, 4),)) + self.assertEqual(last["output_dtypes"], ["Float"]) + self.assertEqual(last["collective_seq_id"], 2) + now = datetime.now() + event_created_time = datetime.fromtimestamp( + last["time_created_ns"] / 1000000000 + ) + before_test = now - timedelta(minutes=1) + self.assertTrue(before_test < event_created_time < now) + if timing_enabled: + # very loose bounds, measured 0.036 ms on devgpu + self.assertTrue(0 < last["duration_ms"] < 100) + else: + self.assertTrue("duration_ms" not in last) else: - self.assertTrue("duration_ms" not in last) + self.assertTrue("entries" not in t) @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 26381207ca7d..8adf1e02c1a0 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -342,7 +342,10 @@ void cacheAllocatorDeregisterHook( } #if defined(IS_NCCLX) && defined(NCCL_COMM_DUMP) -std::string dump_nccl_trace() { +std::string dump_nccl_trace( + bool includeCollectives, + bool includeStackTraces, + bool onlyActive) { std::unordered_map< std::string /* ncclUniqueID */, std::unordered_map /* dump from this comm */> @@ -362,19 +365,27 @@ std::string dump_nccl_trace() { std::string ncclUniqueIDStr = buildNcclUniqueIdStr(ncclComm->getNcclId()); ncclDumpMap[ncclUniqueIDStr] = ncclComm->ncclCommDump(); } - return NCCLTraceBuffer::get()->dump(ncclDumpMap); + return NCCLTraceBuffer::get()->dump( + ncclDumpMap, includeCollectives, includeStackTraces, onlyActive); } + #else -std::string dump_nccl_trace() { - return NCCLTraceBuffer::get()->dump(c10::nullopt); +std::string dump_nccl_trace( + bool includeCollectives, + bool includeStackTraces, + bool onlyActive) { + return NCCLTraceBuffer::get()->dump( + c10::nullopt, includeCollectives, includeStackTraces, onlyActive); } #endif // TODO(c-p-i-o): add a JSON endpoint. control_plane::RegisterHandler dumpHandler{ "dump_nccl_trace_pickle", - [](const control_plane::Request&, control_plane::Response& res) { - res.setContent(dump_nccl_trace(), "application/octet-stream"); + [](const control_plane::Request& req, control_plane::Response& res) { + // TODO: c-p-i-o: params from the request need to go to dump_nccl_trace. + res.setContent( + dump_nccl_trace(true, true, false), "application/octet-stream"); }}; std::optional)>>& @@ -1197,7 +1208,7 @@ bool ProcessGroupNCCL::dumpDebuggingInfo() { // We dump nccl trace into local disk by default and users can register // their customized writer by inheriting `DebugInfoWriter` via // `registerDebugInfoWriter`. - auto ncclTrace = dump_nccl_trace(); + auto ncclTrace = dump_nccl_trace(true, true, false); DebugInfoWriter& writer = DebugInfoWriter::getWriter(globalRank()); LOG(INFO) << logPrefix() << "ProcessGroupNCCL dumping nccl trace to " << writer.getWriterTarget(); diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index f36ebdeb16e9..faaabe411bfc 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -1114,11 +1114,16 @@ class TORCH_API ProcessGroupNCCL : public Backend { ProcessGroupStatus pgStatus_; }; -TORCH_API std::string dump_nccl_trace(); - -// Gets a mutable reference to a global optional function. Heartbeat Monitor -// will use this function to dump traces, if available. Inside fbcode, we store -// a function here that uses an internal tool for process tracing +// Dumps the NCCL comm traces and additional information about the Process +// Group. +TORCH_API std::string dump_nccl_trace( + bool includeCollectives, + bool includeStackTraces, + bool onlyActive); + +// Gets a mutable reference to a global optional function.Heartbeat Monitor +// will use this function to dump traces, if available. Inside fbcode, we +// store a function here that uses an internal tool for process tracing TORCH_API std::optional< std::function)>>& get_cpp_trace_dumper(); diff --git a/torch/csrc/distributed/c10d/TraceUtils.h b/torch/csrc/distributed/c10d/TraceUtils.h index e8dadb6537e0..c3b0464cf992 100644 --- a/torch/csrc/distributed/c10d/TraceUtils.h +++ b/torch/csrc/distributed/c10d/TraceUtils.h @@ -655,31 +655,44 @@ struct NCCLTraceBuffer { entry->start_ = entry->end_ = nullptr; } - std::string dump( - const std::optional>>& ncclDumpMap) { - auto result = dump_entries(); + const c10::List getCollectiveTrace( + bool includeStacktraces, + bool onlyActive) { auto entries = new_list(); - + auto result = dump_entries(); std::vector tracebacks; - for (auto& e : result) { - tracebacks.push_back(e.traceback_.get()); - } - torch::SymbolizedTracebacks stracebacks = torch::symbolize(tracebacks); + torch::SymbolizedTracebacks stracebacks; std::vector all_frames; - for (const auto& f : stracebacks.all_frames) { - auto d = new_dict(); - d.insert(name_key, f.funcname); - d.insert(filename_key, f.filename); - d.insert(line_key, int64_t(f.lineno)); - all_frames.emplace_back(std::move(d)); + if (includeStacktraces) { + for (auto& e : result) { + tracebacks.push_back(e.traceback_.get()); + } + stracebacks = torch::symbolize(tracebacks); + for (const auto& f : stracebacks.all_frames) { + auto d = new_dict(); + d.insert(name_key, f.funcname); + d.insert(filename_key, f.filename); + d.insert(line_key, int64_t(f.lineno)); + all_frames.emplace_back(std::move(d)); + } } - for (auto i : c10::irange(result.size())) { - auto& e = result.at(i); - auto& tb = stracebacks.tracebacks.at(i); auto dict = new_dict(); + auto& e = result.at(i); + // Skip completed events + if (onlyActive && e.time_discovered_completed_.has_value()) { + continue; + } + + if (includeStacktraces) { + auto& tb = stracebacks.tracebacks.at(i); + auto frames = new_list(); + for (int64_t frame : tb) { + frames.push_back(all_frames.at(frame)); + } + dict.insert(frames_key, frames); + } + dict.insert(record_id_key, int64_t(e.id_)); dict.insert(pg_id_key, int64_t(e.pg_id_)); dict.insert(pg_name_key, e.pg_name_); @@ -741,13 +754,13 @@ struct NCCLTraceBuffer { dict.insert(retired_key, e.retired_); dict.insert(is_p2p_key, e.isP2P_); - auto frames = new_list(); - for (int64_t frame : tb) { - frames.push_back(all_frames.at(frame)); - } - dict.insert(frames_key, frames); entries.push_back(dict); } + return entries; + } + + // dump pg_entries + const c10::Dict getPgConfig() { auto pg_config = new_dict(); for (const auto& [pg_name, ranks] : pg_name_to_ranks_) { auto pg_info = new_dict(); @@ -756,6 +769,27 @@ struct NCCLTraceBuffer { pg_info.insert("ranks", ranks_str(ranks)); pg_config.insert(std::get<0>(pg_name), pg_info); } + return pg_config; + } + + // dump all collectives + ncclDumpMap + std::string dump( + const std::optional>>& ncclDumpMap, + bool includeCollectives, + bool includeStackTraces, + bool onlyActive) { + auto result = new_dict(); + // common values + result.insert(version_key, version_val); + result.insert(pg_config_key, getPgConfig()); + + // collective trace + if (includeCollectives) { + result.insert( + entries_key, getCollectiveTrace(includeStackTraces, onlyActive)); + } // convert ncclDumpMap into a dictionary auto per_comm_dict = new_dict(); @@ -768,16 +802,10 @@ struct NCCLTraceBuffer { per_comm_dict.insert(ncclId, inner_dict); } } - - auto dict = new_dict(); - dict.insert(entries_key, entries); - dict.insert(version_key, version_val); if (per_comm_dict.size() > 0) { - dict.insert(nccl_comm_key, per_comm_dict); + result.insert(nccl_comm_key, per_comm_dict); } - dict.insert(pg_config_key, pg_config); - - return pickle_str(dict); + return pickle_str(result); } }; diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index ea8e6db9290b..027e87efee56 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -3161,9 +3161,28 @@ such as `dist.all_reduce(tensor, async_op=True)`. Arguments: tensors(List[torch.Tensor]): List of tensors we want to hash. )"); - module.def("_dump_nccl_trace", []() { - return py::bytes(::c10d::dump_nccl_trace()); - }); + module.def( + "_dump_nccl_trace", + [](std::optional includeCollectives, + std::optional includeStackTraces, + std::optional onlyActive) { + return py::bytes(::c10d::dump_nccl_trace( + includeCollectives.value_or(true), + includeStackTraces.value_or(true), + onlyActive.value_or(false))); + }, + py::arg("includeCollectives") = std::optional(), + py::arg("includeStackTraces") = std::optional(), + py::arg("onlyActive") = std::optional(), + R"( + Arguments: + includeCollectives(bool, optional): Whether to include collective work traces. Default is True. + includeStackTraces(bool, optional): Whether to include stacktraces in the collective work traces. Default is True. + onlyActive (bool, optional): Whether to only include active collective work traces. Default is False. + Returns: + Stringified pickle work traces. + Default settings return everything - i.e. contains NCCL comm dumps and collective traces. + )"); #endif intrusive_ptr_class_<::c10d::control_plane::WorkerServer>( From 80fa2778ed12c0dc6f2d16901f6fafd528015b88 Mon Sep 17 00:00:00 2001 From: joncrall Date: Thu, 6 Jun 2024 21:59:18 +0000 Subject: [PATCH 434/706] Update types for verbose in lr_scheduler (#127943) I'm currently locked into jsonargparse version 4.19.0, and it complains when used in combination with LightningCLI (v2.0.8). This is because it cares about the types declared in google style docstrings. This causes a problem when it tries to parse how it should cast arguments to construct an instance of an LRScheduler class because the docstrings declare the "verbose" parameter as a bool, but the defaults recently changed to a string "deprecated". This means the type should really be `bool | str`. This PR adds a `| str` to the docstring type in each learning rate scheduler class. This will prevent jsonargparse from complaining. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127943 Approved by: https://github.com/janeyx99 --- torch/optim/lr_scheduler.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py index 42c55db82a43..cb7d9738df5a 100644 --- a/torch/optim/lr_scheduler.py +++ b/torch/optim/lr_scheduler.py @@ -286,7 +286,7 @@ class LambdaLR(LRScheduler): factor given an integer parameter epoch, or a list of such functions, one for each group in optimizer.param_groups. last_epoch (int): The index of last epoch. Default: -1. - verbose (bool): If ``True``, prints a message to stdout for + verbose (bool | str): If ``True``, prints a message to stdout for each update. Default: ``False``. .. deprecated:: 2.2 @@ -388,7 +388,7 @@ class MultiplicativeLR(LRScheduler): factor given an integer parameter epoch, or a list of such functions, one for each group in optimizer.param_groups. last_epoch (int): The index of last epoch. Default: -1. - verbose (bool): If ``True``, prints a message to stdout for + verbose (bool | str): If ``True``, prints a message to stdout for each update. Default: ``False``. .. deprecated:: 2.2 @@ -487,7 +487,7 @@ class StepLR(LRScheduler): gamma (float): Multiplicative factor of learning rate decay. Default: 0.1. last_epoch (int): The index of last epoch. Default: -1. - verbose (bool): If ``True``, prints a message to stdout for + verbose (bool | str): If ``True``, prints a message to stdout for each update. Default: ``False``. .. deprecated:: 2.2 @@ -546,7 +546,7 @@ class MultiStepLR(LRScheduler): gamma (float): Multiplicative factor of learning rate decay. Default: 0.1. last_epoch (int): The index of last epoch. Default: -1. - verbose (bool): If ``True``, prints a message to stdout for + verbose (bool | str): If ``True``, prints a message to stdout for each update. Default: ``False``. .. deprecated:: 2.2 @@ -609,7 +609,7 @@ class ConstantLR(LRScheduler): total_iters (int): The number of steps that the scheduler multiplies the learning rate by the factor. Default: 5. last_epoch (int): The index of the last epoch. Default: -1. - verbose (bool): If ``True``, prints a message to stdout for + verbose (bool | str): If ``True``, prints a message to stdout for each update. Default: ``False``. .. deprecated:: 2.2 @@ -685,7 +685,7 @@ class LinearLR(LRScheduler): total_iters (int): The number of iterations that multiplicative factor reaches to 1. Default: 5. last_epoch (int): The index of the last epoch. Default: -1. - verbose (bool): If ``True``, prints a message to stdout for + verbose (bool | str): If ``True``, prints a message to stdout for each update. Default: ``False``. .. deprecated:: 2.2 @@ -776,7 +776,7 @@ class ExponentialLR(LRScheduler): optimizer (Optimizer): Wrapped optimizer. gamma (float): Multiplicative factor of learning rate decay. last_epoch (int): The index of last epoch. Default: -1. - verbose (bool): If ``True``, prints a message to stdout for + verbose (bool | str): If ``True``, prints a message to stdout for each update. Default: ``False``. .. deprecated:: 2.2 @@ -811,7 +811,7 @@ class SequentialLR(LRScheduler): schedulers (list): List of chained schedulers. milestones (list): List of integers that reflects milestone points. last_epoch (int): The index of last epoch. Default: -1. - verbose (bool): Does nothing. + verbose (bool | str): Does nothing. .. deprecated:: 2.2 ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the @@ -945,7 +945,7 @@ class PolynomialLR(LRScheduler): optimizer (Optimizer): Wrapped optimizer. total_iters (int): The number of steps that the scheduler decays the learning rate. Default: 5. power (float): The power of the polynomial. Default: 1.0. - verbose (bool): If ``True``, prints a message to stdout for + verbose (bool | str): If ``True``, prints a message to stdout for each update. Default: ``False``. .. deprecated:: 2.2 @@ -1035,7 +1035,7 @@ class CosineAnnealingLR(LRScheduler): T_max (int): Maximum number of iterations. eta_min (float): Minimum learning rate. Default: 0. last_epoch (int): The index of last epoch. Default: -1. - verbose (bool): If ``True``, prints a message to stdout for + verbose (bool | str): If ``True``, prints a message to stdout for each update. Default: ``False``. .. deprecated:: 2.2 @@ -1238,7 +1238,7 @@ class ReduceLROnPlateau(LRScheduler): eps (float): Minimal decay applied to lr. If the difference between new and old lr is smaller than eps, the update is ignored. Default: 1e-8. - verbose (bool): If ``True``, prints a message to stdout for + verbose (bool | str): If ``True``, prints a message to stdout for each update. Default: ``False``. .. deprecated:: 2.2 @@ -1468,7 +1468,7 @@ class CyclicLR(LRScheduler): number of *batches* computed, not the total number of epochs computed. When last_epoch=-1, the schedule is started from the beginning. Default: -1 - verbose (bool): If ``True``, prints a message to stdout for + verbose (bool | str): If ``True``, prints a message to stdout for each update. Default: ``False``. .. deprecated:: 2.2 @@ -1692,7 +1692,7 @@ class CosineAnnealingWarmRestarts(LRScheduler): T_mult (int, optional): A factor by which :math:`T_{i}` increases after a restart. Default: 1. eta_min (float, optional): Minimum learning rate. Default: 0. last_epoch (int, optional): The index of the last epoch. Default: -1. - verbose (bool): If ``True``, prints a message to stdout for + verbose (bool | str): If ``True``, prints a message to stdout for each update. Default: ``False``. .. deprecated:: 2.2 @@ -1896,7 +1896,7 @@ class OneCycleLR(LRScheduler): number of *batches* computed, not the total number of epochs computed. When last_epoch=-1, the schedule is started from the beginning. Default: -1 - verbose (bool): If ``True``, prints a message to stdout for + verbose (bool | str): If ``True``, prints a message to stdout for each update. Default: ``False``. .. deprecated:: 2.2 From 56a3d276fefa256ec354a1b94d756161605535e2 Mon Sep 17 00:00:00 2001 From: Jiashen Cao Date: Thu, 6 Jun 2024 22:06:51 +0000 Subject: [PATCH 435/706] Handle custom op during TorchScript to ExportedProgram conversion (#127580) #### Description Handle custom ops during TorchScript to ExportedProgram covnersion ```python torch.library.define( "mylib::foo", "(Tensor x) -> Tensor", lib=lib, ) # PyTorch custorm op implementation @torch.library.impl( "mylib::foo", "CompositeExplicitAutograd", lib=lib, ) def foo_impl(x): return x + x # Meta function of the custom op. @torch.library.impl_abstract( "mylib::foo", lib=lib, ) def foo_meta(x): return x + x class M(torch.nn.Module): def forward(self, x): return torch.ops.mylib.foo(x) ``` #### Test Plan * Add a test case where custom op is called and converted. `pytest test/export/test_converter.py -s -k test_ts2ep_converter_custom_op` Pull Request resolved: https://github.com/pytorch/pytorch/pull/127580 Approved by: https://github.com/angelayi --- test/export/test_converter.py | 36 +++++++++++++++++++++++++++++++++++ torch/_export/converter.py | 34 +++++++++++++++------------------ 2 files changed, 51 insertions(+), 19 deletions(-) diff --git a/test/export/test_converter.py b/test/export/test_converter.py index 44f23309579d..90e92f183746 100644 --- a/test/export/test_converter.py +++ b/test/export/test_converter.py @@ -286,6 +286,42 @@ def forward(self, x: torch.Tensor, x_dict: Dict[torch.Tensor, str]): inp = (torch.tensor(1), {torch.tensor(4): "foo"}) self._check_equal_ts_ep_converter(MTensorIn(), inp) + def test_ts2ep_converter_custom_op(self): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + torch._dynamo.config.capture_scalar_outputs = True + torch._dynamo.config.capture_dynamic_output_shape_ops = True + + torch.library.define( + "mylib::foo", + "(Tensor x) -> Tensor", + lib=lib, + ) + + # PyTorch custorm op implementation + @torch.library.impl( + "mylib::foo", + "CompositeExplicitAutograd", + lib=lib, + ) + def foo_impl(x): + return x + x + + # Meta function of the custom op. + @torch.library.impl_abstract( + "mylib::foo", + lib=lib, + ) + def foo_meta(x): + return x + x + + class M(torch.nn.Module): + def forward(self, x): + return torch.ops.mylib.foo(x) + + inp = (torch.randn(3, 3),) + m = M() + self._check_equal_ts_ep_converter(m, inp) + if __name__ == "__main__": run_tests() diff --git a/torch/_export/converter.py b/torch/_export/converter.py index 9e438e206984..777249c24a2a 100644 --- a/torch/_export/converter.py +++ b/torch/_export/converter.py @@ -65,11 +65,17 @@ def get_op_overload(node: torch._C.Node): ns, op_name = str(schema.name.name).split("::") override = schema.name.overload_name - op_overload_packet = getattr(torch.ops.aten, op_name) - if override: - op_overload = getattr(op_overload_packet, override) - else: - op_overload = op_overload_packet.default + try: + op_overload_mod = getattr(torch.ops, ns) + op_overload_packet = getattr(op_overload_mod, op_name) + if override: + op_overload = getattr(op_overload_packet, override) + else: + op_overload = op_overload_packet.default + except Exception as e: + raise RuntimeError( + f"Unable to find operator {node.kind()} with schema {node.schema}" + ) from e return op_overload @@ -264,11 +270,8 @@ def get_attr(name: str): f"{root_attr_name}.{attr_name}" if root_attr_name else attr_name ) - def convert_aten_op(self, node: torch._C.Node): - try: - target = get_op_overload(node) - except Exception as e: - raise RuntimeError(f"Unsupported node {node.kind()}") from e + def convert_call_function_op(self, node: torch._C.Node): + target = get_op_overload(node) if target is torch.ops.aten.size.int: target = torch.ops.aten.sym_size.int @@ -404,7 +407,7 @@ def convert_aten_div(self, node: torch._C.Node): self.name_to_node[output_name] = fx_node return - self.convert_aten_op(node) + self.convert_call_function_op(node) def convert_aten___getitem__(self, node: torch._C.Node): input_container, index = tuple( @@ -512,16 +515,9 @@ def convert_node(self, node: torch._C.Node): # Provide a default node handler as well in case we don't find # matching converter for that. handler_func_name = ir_name_to_func_name(node_kind) - handler_func = getattr(self, handler_func_name, self.convert_default_node) + handler_func = getattr(self, handler_func_name, self.convert_call_function_op) handler_func(node) - def convert_default_node(self, node: torch._C.Node): - node_kind = node.kind() - if node_kind.startswith("aten::"): - self.convert_aten_op(node) - else: - raise ValueError(f"Unsupported node kind: {node_kind}") - def convert_graph_outputs(self): args = [] for graph_output in self.ts_graph.outputs(): From 6dfdce92ba7ec7da8f1b018aea19f40223bac035 Mon Sep 17 00:00:00 2001 From: brightonanc Date: Thu, 6 Jun 2024 22:47:02 +0000 Subject: [PATCH 436/706] Fixed typos in the complex numbers portion of the autograd docs (#127948) This PR fixes several typos in the complex numbers section of the docs for autograd. Only documentation was altered. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127948 Approved by: https://github.com/soulitzer --- docs/source/notes/autograd.rst | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/notes/autograd.rst b/docs/source/notes/autograd.rst index 81ebf64bc43a..f070f2204183 100644 --- a/docs/source/notes/autograd.rst +++ b/docs/source/notes/autograd.rst @@ -463,7 +463,7 @@ functions are used in the research community since complex numbers are not part ordered field and so having complex valued loss does not make much sense. It also turns out that no interesting real-valued objective fulfill the -Cauchy-Riemann equations. So the theory with homomorphic function cannot be +Cauchy-Riemann equations. So the theory with holomorphic function cannot be used for optimization and most people therefore use the Wirtinger calculus. Wirtinger Calculus comes into the picture ... @@ -602,7 +602,7 @@ Solving the above equations for :math:`\frac{\partial L}{\partial u}` and :math: .. math:: \begin{aligned} \frac{\partial L}{\partial u} = \frac{\partial L}{\partial s} + \frac{\partial L}{\partial s^*} \\ - \frac{\partial L}{\partial v} = -1j * \left(\frac{\partial L}{\partial s} - \frac{\partial L}{\partial s^*}\right) + \frac{\partial L}{\partial v} = 1j * \left(\frac{\partial L}{\partial s} - \frac{\partial L}{\partial s^*}\right) \end{aligned} :label: [3] @@ -610,9 +610,9 @@ Substituting :eq:`[3]` in :eq:`[1]`, we get: .. math:: \begin{aligned} - \frac{\partial L}{\partial z^*} &= \left(\frac{\partial L}{\partial s} + \frac{\partial L}{\partial s^*}\right) * \frac{\partial u}{\partial z^*} - 1j * \left(\frac{\partial L}{\partial s} - \frac{\partial L}{\partial s^*}\right) * \frac{\partial v}{\partial z^*} \\ + \frac{\partial L}{\partial z^*} &= \left(\frac{\partial L}{\partial s} + \frac{\partial L}{\partial s^*}\right) * \frac{\partial u}{\partial z^*} + 1j * \left(\frac{\partial L}{\partial s} - \frac{\partial L}{\partial s^*}\right) * \frac{\partial v}{\partial z^*} \\ &= \frac{\partial L}{\partial s} * \left(\frac{\partial u}{\partial z^*} + \frac{\partial v}{\partial z^*} j\right) + \frac{\partial L}{\partial s^*} * \left(\frac{\partial u}{\partial z^*} - \frac{\partial v}{\partial z^*} j\right) \\ - &= \frac{\partial L}{\partial s^*} * \frac{\partial (u + vj)}{\partial z^*} + \frac{\partial L}{\partial s} * \frac{\partial (u + vj)^*}{\partial z^*} \\ + &= \frac{\partial L}{\partial s} * \frac{\partial (u + vj)}{\partial z^*} + \frac{\partial L}{\partial s^*} * \frac{\partial (u + vj)^*}{\partial z^*} \\ &= \frac{\partial L}{\partial s} * \frac{\partial s}{\partial z^*} + \frac{\partial L}{\partial s^*} * \frac{\partial s^*}{\partial z^*} \\ \end{aligned} From e5b3387166b0c0969c7a254b7401d9b508509731 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Thu, 6 Jun 2024 10:00:12 -0700 Subject: [PATCH 437/706] [dynamo] Bugfix for nn parameter construction (#128001) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128001 Approved by: https://github.com/jansel --- torch/_dynamo/variables/torch.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index fbfabb5fdf06..0b3e28860aaf 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -896,6 +896,13 @@ def call_nn_parameter(cls, tx, data=None, requires_grad=True): ) assert isinstance(result, variables.TensorVariable) result.class_type = torch.nn.Parameter + + # TODO(jansel/bdhirsh) - There is some issue with + # tracable_create_paramter. It does not seem to use the right + # grad_enabled. Since this is parameter, we can just override the + # has_grad_fn field to False to workaround the issue. + result.has_grad_fn = False + # In reconstruct() should use the original parameter. The one returned by the graph will be an alias. result.source = placeholder.source @@ -919,6 +926,12 @@ def _nn_param_via_prefix_insert(tx, data, requires_grad): cg.store(varname) tx.output.pregraph_bytecode.extend(cg.get_instructions()) + data_node = data.as_proxy().node + if data_node.op not in ("placeholder", "get_attr"): + unimplemented( + "Unexpected type of data placeholder op for parameter construction" + ) + # add the newly constructed nn.Parameter as a graph input source = SyntheticLocalSource(varname) example_value = torch.nn.Parameter( From 7ede78f9f5d7e6c993faa1a70a5f0b0eaec5640d Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Thu, 6 Jun 2024 10:00:13 -0700 Subject: [PATCH 438/706] [dynamo][nn-modules] Trace through nn.Module dunder methods for UnspecializedNNModule (#126578) Tracing through `__init__` is important because it initializes (calls STORE_ATTR) on members. By doing that, we kick in the mutation tracking for these objects. So, things like mutating `_modules` etc is tracked automatically. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126578 Approved by: https://github.com/jansel ghstack dependencies: #128001 --- test/distributed/test_dynamo_distributed.py | 10 +-- test/dynamo/test_higher_order_ops.py | 16 ++--- ...ddingNN.test_embedding_sparse_empty_tensor | 0 ...ngNN.test_embeddingbag_include_last_offset | 0 ....test_profiler_pattern_matcher_json_report | 0 .../TestJitGeneratedModule.test_nn_Bilinear | 0 .../TestJitGeneratedModule.test_nn_Embedding | 0 ...dModule.test_nn_EmbeddingBag_discontiguous | 0 ...itGeneratedModule.test_nn_EmbeddingBag_max | 0 ...odule.test_nn_EmbeddingBag_max_padding_idx | 0 ...tGeneratedModule.test_nn_EmbeddingBag_mean | 0 ...dule.test_nn_EmbeddingBag_mean_padding_idx | 0 ...eneratedModule.test_nn_EmbeddingBag_sparse | 0 ...itGeneratedModule.test_nn_EmbeddingBag_sum | 0 ...odule.test_nn_EmbeddingBag_sum_padding_idx | 0 ...atedModule.test_nn_Embedding_discontiguous | 0 ...itGeneratedModule.test_nn_Embedding_sparse | 0 .../TestJitGeneratedModule.test_nn_Linear | 0 ...eneratedModule.test_nn_Linear_no_batch_dim | 0 ...GeneratedModule.test_nn_PReLU_no_batch_dim | 0 .../TestNN.test_ParameterDict | 0 .../TestNN.test_Sequential_iadd | 0 .../TestNN.test_bilinear_broadcasting | 0 ...st_layer_norm_grads_with_create_graph_flag | 0 ..._linear_autograd_device_cpu_bias_weightCOO | 0 ..._linear_autograd_device_cpu_bias_weightCSC | 0 ..._linear_autograd_device_cpu_bias_weightCSR | 0 .../TestNN.test_linear_broadcasting | 0 .../TestNN.test_module_apply_inplace_op | 0 ...est_overwrite_module_params_on_conversion} | 0 ...metrized_tensor_parametrization_swap_False | 0 ....test_new_spectral_norm_forward_swap_True} | 0 ...rization.test_new_spectral_norm_swap_True} | 0 ...weight_norm_parametrization_swap_False_cpu | 0 ..._weight_norm_parametrization_swap_True_cpu | 0 ...sorDeviceTypeCPU.test_embedding_jagged_cpu | 0 .../TestPruningNN.test_identity_pruning | 0 ...TestPruningNN.test_pruning_id_consistency} | 0 .../TestPruningNN.test_random_pruning_0perc | 0 test/profiler/test_profiler.py | 1 + torch/_dynamo/create_parameter_op.py | 20 ++++++ torch/_dynamo/mutation_guard.py | 3 + torch/_dynamo/side_effects.py | 32 ++++++---- torch/_dynamo/symbolic_convert.py | 11 +++- torch/_dynamo/utils.py | 4 +- torch/_dynamo/variables/dicts.py | 6 +- torch/_dynamo/variables/misc.py | 26 +++++--- torch/_dynamo/variables/nn_module.py | 40 ++++++++---- torch/_dynamo/variables/torch.py | 9 ++- torch/_dynamo/variables/user_defined.py | 63 ++++++++++++------- 50 files changed, 169 insertions(+), 72 deletions(-) delete mode 100644 test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_sparse_empty_tensor delete mode 100644 test/dynamo_expected_failures/TestEmbeddingNN.test_embeddingbag_include_last_offset delete mode 100644 test/dynamo_expected_failures/TestExperimentalUtils.test_profiler_pattern_matcher_json_report delete mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Bilinear delete mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding delete mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_discontiguous delete mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_max delete mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_max_padding_idx delete mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_mean delete mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_mean_padding_idx delete mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sparse delete mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sum delete mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sum_padding_idx delete mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding_discontiguous delete mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding_sparse delete mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Linear delete mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Linear_no_batch_dim delete mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_PReLU_no_batch_dim delete mode 100644 test/dynamo_expected_failures/TestNN.test_ParameterDict delete mode 100644 test/dynamo_expected_failures/TestNN.test_Sequential_iadd delete mode 100644 test/dynamo_expected_failures/TestNN.test_bilinear_broadcasting delete mode 100644 test/dynamo_expected_failures/TestNN.test_layer_norm_grads_with_create_graph_flag delete mode 100644 test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCOO delete mode 100644 test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCSC delete mode 100644 test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCSR delete mode 100644 test/dynamo_expected_failures/TestNN.test_linear_broadcasting delete mode 100644 test/dynamo_expected_failures/TestNN.test_module_apply_inplace_op rename test/dynamo_expected_failures/{FakeTensorTest.test_embedding_bag_meta => TestNN.test_overwrite_module_params_on_conversion} (100%) delete mode 100644 test/dynamo_expected_failures/TestNNParametrization.test_errors_unparametrized_tensor_parametrization_swap_False rename test/dynamo_expected_failures/{TestCompileTransformsCPU.test_compile_vmap_hessian_cpu => TestNNParametrization.test_new_spectral_norm_forward_swap_True} (100%) rename test/dynamo_expected_failures/{TestEmbeddingNN.test_embedding_max_norm => TestNNParametrization.test_new_spectral_norm_swap_True} (100%) delete mode 100644 test/dynamo_expected_failures/TestNNParametrizationDeviceCPU.test_weight_norm_parametrization_swap_False_cpu delete mode 100644 test/dynamo_expected_failures/TestNNParametrizationDeviceCPU.test_weight_norm_parametrization_swap_True_cpu delete mode 100644 test/dynamo_expected_failures/TestNestedTensorDeviceTypeCPU.test_embedding_jagged_cpu delete mode 100644 test/dynamo_expected_failures/TestPruningNN.test_identity_pruning rename test/dynamo_expected_failures/{TestEmbeddingNN.test_embedding_sparse_basic => TestPruningNN.test_pruning_id_consistency} (100%) delete mode 100644 test/dynamo_expected_failures/TestPruningNN.test_random_pruning_0perc diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py index b31a2f717537..db44f1ce915d 100644 --- a/test/distributed/test_dynamo_distributed.py +++ b/test/distributed/test_dynamo_distributed.py @@ -1084,12 +1084,14 @@ def _(ctx): # far from an exhaustive check of all the expected guards, just check a couple of them. FileCheck().check("""local "L['self']" TYPE_MATCH""").check( """local "L['self']" ID_MATCH""" - ).check(f"""{expected_guard_source} "L['self'].net" TYPE_MATCH""").check( - f"""{expected_guard_source} "L['self'].net" ID_MATCH""" ).check( - f"""{expected_guard_source} "L['self'].net[0]" TYPE_MATCH""" + f"""{expected_guard_source} "L['self']._modules['net']" TYPE_MATCH""" ).check( - f"""{expected_guard_source} "L['self'].net[0]" ID_MATCH""" + f"""{expected_guard_source} "L['self']._modules['net']" ID_MATCH""" + ).check( + f"""{expected_guard_source} "L['self']._modules['net']._modules['0']" TYPE_MATCH""" + ).check( + f"""{expected_guard_source} "L['self']._modules['net']._modules['1']" ID_MATCH""" ).run( GUARDS_FILE.getvalue() ) diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index 9b86a90b02f3..43bc69ea403b 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -5118,10 +5118,10 @@ def wrapper_fn(x): actual, """\ class GraphModule(torch.nn.Module): - def forward(self, L_self_tensor_constant0: "f32[3, 3, 3]"): - l_self_tensor_constant0 = L_self_tensor_constant0 + def forward(self, L_self_buffers_tensor_constant0_: "f32[3, 3, 3]"): + l_self_buffers_tensor_constant0_ = L_self_buffers_tensor_constant0_ - alias_default: "f32[3, 3, 3]" = torch.ops.aten.alias.default(l_self_tensor_constant0); l_self_tensor_constant0 = None + alias_default: "f32[3, 3, 3]" = torch.ops.aten.alias.default(l_self_buffers_tensor_constant0_); l_self_buffers_tensor_constant0_ = None sin_default: "f32[3, 3, 3]" = torch.ops.aten.sin.default(alias_default) @@ -5140,16 +5140,16 @@ def forward(self, L_self_tensor_constant0: "f32[3, 3, 3]"): actual, """\ class GraphModule(torch.nn.Module): - def forward(self, getattr_L_self_FX_CONST_FOLDED_ATTRS_0_: "f32[3, 3, 3]", getattr_L_self_FX_CONST_FOLDED_ATTRS_1_: "f32[3, 3, 3]", L_flat_tangents_1_: "f32[3, 3, 3]"): - getattr_l_self_fx_const_folded_attrs_0_ = getattr_L_self_FX_CONST_FOLDED_ATTRS_0_ - getattr_l_self_fx_const_folded_attrs_1_ = getattr_L_self_FX_CONST_FOLDED_ATTRS_1_ + def forward(self, L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_0_: "f32[3, 3, 3]", L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_1_: "f32[3, 3, 3]", L_flat_tangents_1_: "f32[3, 3, 3]"): + l_self_modules_fx_const_folded_attrs_parameters_0_ = L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_0_ + l_self_modules_fx_const_folded_attrs_parameters_1_ = L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_1_ l_flat_tangents_1_ = L_flat_tangents_1_ - _new_zeros_with_same_feature_meta_default: "f32[3, 3, 3]" = torch.ops.aten._new_zeros_with_same_feature_meta.default(l_flat_tangents_1_, getattr_l_self_fx_const_folded_attrs_0_); getattr_l_self_fx_const_folded_attrs_0_ = None + _new_zeros_with_same_feature_meta_default: "f32[3, 3, 3]" = torch.ops.aten._new_zeros_with_same_feature_meta.default(l_flat_tangents_1_, l_self_modules_fx_const_folded_attrs_parameters_0_); l_self_modules_fx_const_folded_attrs_parameters_0_ = None copy__default: "f32[3, 3, 3]" = torch.ops.aten.copy_.default(_new_zeros_with_same_feature_meta_default, l_flat_tangents_1_); _new_zeros_with_same_feature_meta_default = l_flat_tangents_1_ = None - mul_tensor: "f32[3, 3, 3]" = torch.ops.aten.mul.Tensor(copy__default, getattr_l_self_fx_const_folded_attrs_1_); copy__default = getattr_l_self_fx_const_folded_attrs_1_ = None + mul_tensor: "f32[3, 3, 3]" = torch.ops.aten.mul.Tensor(copy__default, l_self_modules_fx_const_folded_attrs_parameters_1_); copy__default = l_self_modules_fx_const_folded_attrs_parameters_1_ = None return (mul_tensor,) """, ) diff --git a/test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_sparse_empty_tensor b/test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_sparse_empty_tensor deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestEmbeddingNN.test_embeddingbag_include_last_offset b/test/dynamo_expected_failures/TestEmbeddingNN.test_embeddingbag_include_last_offset deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestExperimentalUtils.test_profiler_pattern_matcher_json_report b/test/dynamo_expected_failures/TestExperimentalUtils.test_profiler_pattern_matcher_json_report deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Bilinear b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Bilinear deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_discontiguous b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_discontiguous deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_max b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_max deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_max_padding_idx b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_max_padding_idx deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_mean b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_mean deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_mean_padding_idx b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_mean_padding_idx deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sparse b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sparse deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sum b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sum deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sum_padding_idx b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sum_padding_idx deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding_discontiguous b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding_discontiguous deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding_sparse b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding_sparse deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Linear b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Linear deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Linear_no_batch_dim b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Linear_no_batch_dim deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_PReLU_no_batch_dim b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_PReLU_no_batch_dim deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNN.test_ParameterDict b/test/dynamo_expected_failures/TestNN.test_ParameterDict deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNN.test_Sequential_iadd b/test/dynamo_expected_failures/TestNN.test_Sequential_iadd deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNN.test_bilinear_broadcasting b/test/dynamo_expected_failures/TestNN.test_bilinear_broadcasting deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNN.test_layer_norm_grads_with_create_graph_flag b/test/dynamo_expected_failures/TestNN.test_layer_norm_grads_with_create_graph_flag deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCOO b/test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCOO deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCSC b/test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCSC deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCSR b/test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCSR deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNN.test_linear_broadcasting b/test/dynamo_expected_failures/TestNN.test_linear_broadcasting deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNN.test_module_apply_inplace_op b/test/dynamo_expected_failures/TestNN.test_module_apply_inplace_op deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/FakeTensorTest.test_embedding_bag_meta b/test/dynamo_expected_failures/TestNN.test_overwrite_module_params_on_conversion similarity index 100% rename from test/dynamo_expected_failures/FakeTensorTest.test_embedding_bag_meta rename to test/dynamo_expected_failures/TestNN.test_overwrite_module_params_on_conversion diff --git a/test/dynamo_expected_failures/TestNNParametrization.test_errors_unparametrized_tensor_parametrization_swap_False b/test/dynamo_expected_failures/TestNNParametrization.test_errors_unparametrized_tensor_parametrization_swap_False deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestCompileTransformsCPU.test_compile_vmap_hessian_cpu b/test/dynamo_expected_failures/TestNNParametrization.test_new_spectral_norm_forward_swap_True similarity index 100% rename from test/dynamo_expected_failures/TestCompileTransformsCPU.test_compile_vmap_hessian_cpu rename to test/dynamo_expected_failures/TestNNParametrization.test_new_spectral_norm_forward_swap_True diff --git a/test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_max_norm b/test/dynamo_expected_failures/TestNNParametrization.test_new_spectral_norm_swap_True similarity index 100% rename from test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_max_norm rename to test/dynamo_expected_failures/TestNNParametrization.test_new_spectral_norm_swap_True diff --git a/test/dynamo_expected_failures/TestNNParametrizationDeviceCPU.test_weight_norm_parametrization_swap_False_cpu b/test/dynamo_expected_failures/TestNNParametrizationDeviceCPU.test_weight_norm_parametrization_swap_False_cpu deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNNParametrizationDeviceCPU.test_weight_norm_parametrization_swap_True_cpu b/test/dynamo_expected_failures/TestNNParametrizationDeviceCPU.test_weight_norm_parametrization_swap_True_cpu deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNestedTensorDeviceTypeCPU.test_embedding_jagged_cpu b/test/dynamo_expected_failures/TestNestedTensorDeviceTypeCPU.test_embedding_jagged_cpu deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestPruningNN.test_identity_pruning b/test/dynamo_expected_failures/TestPruningNN.test_identity_pruning deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_sparse_basic b/test/dynamo_expected_failures/TestPruningNN.test_pruning_id_consistency similarity index 100% rename from test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_sparse_basic rename to test/dynamo_expected_failures/TestPruningNN.test_pruning_id_consistency diff --git a/test/dynamo_expected_failures/TestPruningNN.test_random_pruning_0perc b/test/dynamo_expected_failures/TestPruningNN.test_random_pruning_0perc deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/profiler/test_profiler.py b/test/profiler/test_profiler.py index 38e83d448fdd..2844663ca6e0 100644 --- a/test/profiler/test_profiler.py +++ b/test/profiler/test_profiler.py @@ -2408,6 +2408,7 @@ def test_profiler_matmul_dim_fp16_pattern(self): num_matched.append(len(pattern.matched_events())) self.assertEqual(num_matched, [i for i, _ in cases]) + @skipIfTorchDynamo("profiler gets ignored if dynamo activated") def test_profiler_pattern_matcher_json_report(self): x = torch.ones((100, 100)) model = nn.Sequential( diff --git a/torch/_dynamo/create_parameter_op.py b/torch/_dynamo/create_parameter_op.py index 42981fcf1015..601d3c94bdc1 100644 --- a/torch/_dynamo/create_parameter_op.py +++ b/torch/_dynamo/create_parameter_op.py @@ -1,3 +1,6 @@ +import threading +from contextlib import contextmanager + import torch doc = """ @@ -36,3 +39,20 @@ def new_parameter_placeholder(size, dtype, device, requires_grad): # Allocating a zero tensor would causes assert failures in autograd. result.untyped_storage().resize_(0) return result + + +_TLS = threading.local() + + +@contextmanager +def do_not_convert_to_tracable_parameter(): + old_flag = getattr(_TLS, "convert_tracable_parameter", True) + _TLS.convert_tracable_parameter = False + try: + yield False + finally: + _TLS.convert_tracable_parameter = old_flag + + +def can_convert_to_tracable_parameter(): + return getattr(_TLS, "convert_tracable_parameter", True) diff --git a/torch/_dynamo/mutation_guard.py b/torch/_dynamo/mutation_guard.py index 1fa24cfa25bb..00347a012676 100644 --- a/torch/_dynamo/mutation_guard.py +++ b/torch/_dynamo/mutation_guard.py @@ -10,6 +10,9 @@ from .utils import ExactWeakKeyDictionary, is_lazy_module, nn_module_has_global_hooks +unpatched_nn_module_init = torch.nn.Module.__init__ + + class MutationTracker: db = ExactWeakKeyDictionary() diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index 647fae379c54..1fa1c004e01a 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -346,13 +346,7 @@ def codegen_save_tempvars(self, cg: PyCodegen): elif isinstance(var.mutable_local, AttributeMutationNew): if isinstance(var, variables.AutogradFunctionContextVariable): unimplemented("AutogradFunctionContextVariable escaped") - if "__call_nn_module_init" in self.store_attr_mutations.get( - var.mutable_local, {} - ): - assert isinstance(var, variables.UnspecializedNNModuleVariable) - cg.load_import_from(utils.__name__, "nn_module_new") - else: - cg.load_import_from(utils.__name__, "object_new") + cg.load_import_from(utils.__name__, "object_new") cg(var.mutable_local.cls_source) cg.extend_output(create_call_function(1, True)) cg.add_cache(var) @@ -479,9 +473,25 @@ def codegen_update_mutated(self, cg: PyCodegen): ] ) elif self.is_attribute_mutation(var): - for name, value in self.store_attr_mutations.get( - var.mutable_local, {} - ).items(): + # Applying mutations involves two steps: 1) Push all + # reconstructed objects onto the stack. 2) Call STORE_ATTR to + # apply the mutations. + # + # Dynamo must ensure that mutations are applied in the same + # order as in the original program. Therefore, two reverse + # operations occur below. + # + # The first reverse operation concerns `suffixes`. We apply + # suffixes in reverse order due to the way Python handles the + # stack. In Step 1, we push all reconstructed objects onto the + # stack, but the item at the top of the stack refers to the last + # attribute in the mutation order. If not fixed, this will apply + # the mutations of attributes in the reverse order. To account + # for this reversal, we iterate through the mutable attributes + # in reverse order. + for name, value in reversed( + self.store_attr_mutations.get(var.mutable_local, {}).items() + ): if isinstance(var, variables.NewGlobalVariable): cg.tx.output.update_co_names(name) cg(value) @@ -489,8 +499,6 @@ def codegen_update_mutated(self, cg: PyCodegen): suffixes.append( [create_instruction("STORE_GLOBAL", argval=name)] ) - elif name == "__call_nn_module_init": - pass # handled in codegen_save_tempvars elif isinstance(value, variables.DeletedVariable): if isinstance( var.mutable_local, AttributeMutationExisting diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 71ed48fbb292..30f28e2ab265 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -415,10 +415,15 @@ def inner(self: "InstructionTranslatorBase", inst: Instruction): self.push(value) self.jump(inst) elif isinstance(value, UserDefinedObjectVariable): - x = value.var_getattr(self, "__bool__") - # if __bool__ is missing, trying __len__ to infer a truth value. - if isinstance(x, GetAttrVariable): + try: + x = value.var_getattr(self, "__bool__") + except exc.ObservedException: + # if __bool__ is missing, trying __len__ to infer a truth value. x = value.var_getattr(self, "__len__") + else: + if isinstance(x, GetAttrVariable): + # if __bool__ is missing, trying __len__ to infer a truth value. + x = value.var_getattr(self, "__len__") # __bool__ or __len__ is function if isinstance(x, UserMethodVariable): diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 04f01c757674..54c497be3781 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -2018,12 +2018,12 @@ def object_has_getattribute(value: Any): return False -def get_custom_getattr(value: Any): +def get_custom_getattr(value: Any, ignore_nn_module_getattr: bool = False): try: getattr_fn = inspect.getattr_static(type(value), "__getattr__") except AttributeError: getattr_fn = None - if getattr_fn is torch.nn.Module.__getattr__: + if ignore_nn_module_getattr and getattr_fn is torch.nn.Module.__getattr__: # ignore this case of getattr getattr_fn = None return getattr_fn diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 0724a80621f7..8391563c8e76 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -174,7 +174,11 @@ def python_type(self): def __contains__(self, vt): assert isinstance(vt, VariableTracker) Hashable = ConstDictVariable._HashableTracker - return is_hashable(vt) and Hashable(vt) in self.items + return ( + is_hashable(vt) + and Hashable(vt) in self.items + and not isinstance(self.items[Hashable(vt)], variables.DeletedVariable) + ) def reconstruct(self, codegen): # instructions to load collections.OrderedDict if necessary diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index cc0fb7096701..9ef36eb7f29f 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -14,8 +14,10 @@ import torch.utils._pytree as pytree from .. import config, variables from ..bytecode_transformation import create_call_function, create_instruction +from ..create_parameter_op import do_not_convert_to_tracable_parameter from ..exc import unimplemented from ..guards import GuardBuilder, install_guard +from ..mutation_guard import unpatched_nn_module_init from ..source import AttrSource, GetItemSource, ODictGetItemSource, TypeSource from ..utils import ( check_unspec_or_constant_args, @@ -121,7 +123,6 @@ def call_method( kwargs: "Dict[str, VariableTracker]", ) -> "VariableTracker": inner_fn, source = self._resolved_getattr_and_source(self, name) - if inner_fn is object.__init__: return LambdaVariable(identity) elif inner_fn is torch.nn.Module.__init__: @@ -133,12 +134,10 @@ def call_method( and isinstance(objvar.mutable_local, AttributeMutationNew) and not (args or kwargs) ): - tx.output.side_effects.store_attr( - objvar, - "__call_nn_module_init", - variables.ConstantVariable.create(True), - ) - return variables.ConstantVariable.create(None) + with do_not_convert_to_tracable_parameter(): + return variables.UserFunctionVariable( + unpatched_nn_module_init, source=source + ).call_function(tx, [self.objvar] + args, kwargs) else: unimplemented("super() nn.Module.__init__") elif isinstance(inner_fn, types.FunctionType): @@ -175,6 +174,19 @@ def call_method( self.objvar, UserDefinedObjectVariable ): return self.objvar.method_setattr_standard(tx, *args, **kwargs) + elif inner_fn is object.__delattr__: + attr = args[0] + try: + attr = attr.as_python_constant() + except NotImplementedError: + unimplemented(f"non-const delattr attr: {attr}") + if not tx.output.side_effects.is_attribute_mutation(self.objvar): + unimplemented(f"delattr({self.objvar}, {attr}, ...)") + + tx.output.side_effects.store_attr( + self.objvar, attr, variables.DeletedVariable() + ) + return variables.ConstantVariable(None) unimplemented(f"non-function or method super: {inner_fn}") diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py index 0a6bad4730dd..5699d7341429 100644 --- a/torch/_dynamo/variables/nn_module.py +++ b/torch/_dynamo/variables/nn_module.py @@ -236,7 +236,7 @@ def _custom_getattr_fallback(self, base, tx, name, options): if object_has_getattribute(base): unimplemented("torch.nn.Module with a custom __getattribute__ defined") - getattr_fn = get_custom_getattr(base) + getattr_fn = get_custom_getattr(base, ignore_nn_module_getattr=True) if getattr_fn is None: return None @@ -672,7 +672,6 @@ def gen_source(source, name): if isinstance(args[0], SliceVariable): # Build a TupleVariable of NNModules result = [] - submods = [] # Turn the slice into the list of integers keys = list(range(len(module)))[args[0].as_python_constant()] @@ -686,9 +685,8 @@ def gen_source(source, name): source=src, ) ) - submods.append(submod) - new_module = torch.nn.Sequential(*submods) + new_module = module[args[0].as_python_constant()] new_module_variable = tx.output.register_attr_or_module( new_module, f"{self}.__getitem__(slice)", @@ -702,8 +700,10 @@ def gen_source(source, name): if isinstance(args[0], SymNodeVariable): key = args[0].evaluate_expr(tx.output) - else: + elif args[0].is_python_constant(): key = args[0].as_python_constant() + else: + unimplemented(f"getitem on NNModuleVariable with key {args[0]}") submod = module[key] return tx.output.register_attr_or_module( @@ -783,7 +783,7 @@ def __init__(self, value, **kwargs): @functools.lru_cache(None) def _nn_module_method_ids(): # Allow __setattr__ to fall through to base class handler - supported = {torch.nn.Module.__setattr__} + supported = {torch.nn.Module.__setattr__, torch.nn.Module.__init__} return { id(x.__code__) for x in torch.nn.Module.__dict__.values() @@ -791,8 +791,6 @@ def _nn_module_method_ids(): } def unpack_var_sequence(self, tx): - from .builder import VariableBuilder - try: fn = inspect.getattr_static(self.value_type, "__iter__") except AttributeError as e: @@ -803,11 +801,16 @@ def unpack_var_sequence(self, tx): torch.nn.ParameterList.__iter__, torch.nn.Sequential.__iter__, ): - assert self.source - return [ - VariableBuilder(tx, source=GetItemSource(self.source, idx))(item) - for idx, item in enumerate(self.value) - ] + # The program can mutate the nn module object but the saved `value` + # will not reflect the mutations. So, trace through the `__iter__` + # function to reflect any tracked mutations. + return tx.inline_user_function_return( + variables.UserFunctionVariable(fn), + [ + self, + ], + {}, + ).unpack_var_sequence(tx) return super().unpack_var_sequence(tx) @@ -934,6 +937,17 @@ def call_method( # Handle submodules self.is_state_mutated = True + if method is torch.nn.Module.__setattr__ and isinstance( + args[1], variables.DeletedVariable + ): + # Trace through __delattr__ to track mutations on the module + # members like `_modules``. + return tx.inline_user_function_return( + variables.UserFunctionVariable(torch.nn.Module.__delattr__), + [self, args[0]], + kwargs, + ) + return super().call_method(tx, name, args, kwargs) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 0b3e28860aaf..36fa0a697032 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -17,7 +17,11 @@ from ..._guards import TracingContext from .. import config, polyfill, variables from ..codegen import PyCodegen -from ..create_parameter_op import new_parameter_placeholder, tracable_create_parameter +from ..create_parameter_op import ( + can_convert_to_tracable_parameter, + new_parameter_placeholder, + tracable_create_parameter, +) from ..device_interface import get_registered_device_interfaces from ..exc import unimplemented from ..guards import GuardBuilder, install_guard @@ -870,6 +874,9 @@ def call_nn_parameter(cls, tx, data=None, requires_grad=True): if data.source: return cls._nn_param_via_prefix_insert(tx, data, requires_grad) + if not can_convert_to_tracable_parameter(): + unimplemented("Workaround for issues with nn_parameter construction") + try: shape = tuple(data.var_getattr(tx, "shape").as_python_constant()) dtype = data.var_getattr(tx, "dtype").as_python_constant() diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 5b785293911f..d5faafcffbed 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -34,7 +34,8 @@ from torch._guards import TracingContext from .. import variables -from ..exc import unimplemented +from ..create_parameter_op import do_not_convert_to_tracable_parameter +from ..exc import ObservedException, unimplemented from ..guards import GuardBuilder, install_guard from ..source import AttrSource, GetItemSource, ODictGetItemSource, RandomValueSource from ..utils import ( @@ -57,10 +58,7 @@ def is_standard_setattr(val): - return val in ( - object.__setattr__, - torch.nn.Module.__setattr__, - ) + return val in (object.__setattr__,) class UserDefinedVariable(VariableTracker): @@ -378,17 +376,7 @@ def call_function( else UserDefinedObjectVariable, {}, ) - if ( - inspect.getattr_static(self.value, "__init__", None) - is torch.nn.Module.__init__ - ): - tx.output.side_effects.store_attr( - var, - "__call_nn_module_init", - variables.ConstantVariable.create(True), - ) - return var - else: + with do_not_convert_to_tracable_parameter(): var.call_method(tx, "__init__", args, kwargs) return var elif variables.CustomizedDictVariable.is_matching_cls(self.value): @@ -638,6 +626,10 @@ def call_method( else AttrSource(AttrSource(self.source, "__class__"), name) ) # TODO(jansel): add a guard to check for monkey patching? + from ..mutation_guard import unpatched_nn_module_init + + if method is torch.nn.Module.__init__: + method = unpatched_nn_module_init return UserMethodVariable(method, self, source=source).call_function( tx, args, kwargs ) @@ -799,7 +791,7 @@ def _check_for_getattr(self): def _getattr_static(self, name): if ( - isinstance(self.value, (torch.nn.Module, PyTreeSpec)) + isinstance(self.value, PyTreeSpec) or "__slots__" in self.value.__class__.__dict__ or type(self.value) == threading.local ): @@ -812,7 +804,6 @@ def _getattr_static(self, name): return cls_var except AttributeError: pass # __slots__ - # this might call torch.nn.Module.__getattr__ subobj = getattr(self.value, name) else: subobj = inspect.getattr_static(self.value, name) @@ -1001,14 +992,35 @@ def call_hasattr(self, tx, name: str) -> "VariableTracker": install_guard( AttrSource(self.source, name).make_guard(GuardBuilder.HASATTR) ) - if self._check_for_getattribute() or self._check_for_getattr(): - unimplemented("hasattr with custom __getattr__") + if self._check_for_getattribute(): + unimplemented("hasattr with custom __getattribute__") try: self._getattr_static(name) return variables.ConstantVariable.create(True) except AttributeError: - return variables.ConstantVariable.create(False) + # Now check in __getattr__ function + getattr_fn = self._check_for_getattr() + if isinstance(getattr_fn, types.FunctionType): + # Dynamo is going to trace the __getattr__ function with + # args=name. Set the source accordingly. + new_source = None + if self.source: + new_source = AttrSource(self.source, "__getattr__") + try: + result = variables.UserMethodVariable( + getattr_fn, self, source=new_source + ).call_function(tx, [variables.ConstantVariable.create(name)], {}) + + return variables.ConstantVariable.create( + not isinstance(result, variables.DeletedVariable) + ) + except ObservedException: + return variables.ConstantVariable.create(False) + elif getattr_fn is None: + return variables.ConstantVariable.create(False) + else: + unimplemented("UserDefined with non-function __getattr__") def odict_getitem(self, tx, key): from .builder import VariableBuilder @@ -1075,6 +1087,12 @@ def var_getattr(self, tx, name): return super().var_getattr(tx, name) +class RemovableHandleClass: + # Dummy class to pass to python_type of RemovableHandleVariable + # Useful for isinstance check on hooks + pass + + class RemovableHandleVariable(VariableTracker): REMOVED = -1 @@ -1105,3 +1123,6 @@ def reconstruct(self, codegen): return # unreachable due to codegen.add_cache() when the hook is installed super().reconstruct(codegen) + + def python_type(self): + return RemovableHandleClass From 901226ae837bd4629b34735c84a3481c4988bb5b Mon Sep 17 00:00:00 2001 From: Shunting Zhang Date: Wed, 5 Jun 2024 23:23:21 -0700 Subject: [PATCH 439/706] [inductor] simplify indexing (#127661) This is a short term fix for: https://github.com/pytorch/pytorch/issues/124002 We found the cause of bad perf for the int8_unpack kernel is due to sub-optimal indexing. In this PR we introduce 2 indexing optimizations: 1. expand FloorDiv to the entire expression when feasible. E.g. `x1 * 1024 + x2 // 2` will be transformed to `(x1 * 2048 + x2) // 2`. The motivation is that we have more chance to simplify loops for `x1 * 2048 + x2`. 2. merge ModularIndexing pairs: `ModularIndexing(ModularIndex(x, 1, a), 1, b)`, can be simplified to `ModularIndexing(x, 1, b)` if a is a multiple of b. With both indexing optimizations, we improve int8_unpack perf by 1.54x (183us -> 119us). Pull Request resolved: https://github.com/pytorch/pytorch/pull/127661 Approved by: https://github.com/jansel --- test/inductor/test_indexing.py | 78 ++++++++++++++++++- torch/_inductor/codegen/simd.py | 19 ++++- torch/_inductor/sizevars.py | 131 ++++++++++++++++++++++++++++++++ 3 files changed, 226 insertions(+), 2 deletions(-) diff --git a/test/inductor/test_indexing.py b/test/inductor/test_indexing.py index da527cfbb1d8..19a736160908 100644 --- a/test/inductor/test_indexing.py +++ b/test/inductor/test_indexing.py @@ -1,16 +1,24 @@ # Owner(s): ["module: inductor"] +import os +import unittest + import sympy +import torch + from torch._inductor.codegen.cpp import cexpr from torch._inductor.codegen.triton import texpr from torch._inductor.codegen.wrapper import pexpr +from torch._inductor.runtime.runtime_utils import do_bench_gpu from torch._inductor.sizevars import SizeVarAllocator from torch._inductor.test_case import TestCase as InductorTestCase +from torch._inductor.utils import run_and_get_triton_code from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, ) +from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA from torch.utils._sympy.functions import ( FloorDiv, ModularIndexing, @@ -18,6 +26,8 @@ RoundToInt, ) +DO_PERF_TEST = os.environ.get("DO_PERF_TEST") == "1" + class TestIndexingSimplification(InductorTestCase): def test_indexing_simplification(self): @@ -164,6 +174,73 @@ def test_indexing_join(self): self.assertEqual(simplified, FloorDiv(i0, 3)) self.assertEqual(expr6.subs({i0: 39485}), simplified.subs({i0: 39485})) + def test_modular_indexing_pairs_merged(self): + sizevars = SizeVarAllocator() + x = sympy.Symbol("x", integer=True, positive=True) + a = 1024 + b = 32 + expr1 = ModularIndexing(x, 1, a) + expr2 = ModularIndexing(expr1, 1, b) + expected = ModularIndexing(x, 1, b) + + actual = sizevars.combine_modular_indexing_pairs(expr2) + self.assertEqual(expected, actual) + self.assertNotEqual(expr2, actual) + + def test_modular_indexing_pairs_not_merged(self): + sizevars = SizeVarAllocator() + x = sympy.Symbol("x", integer=True, positive=True) + a = 1024 + b = 3 # pick a 'b' that we can not merge + expr1 = ModularIndexing(x, 1, a) + expr2 = ModularIndexing(expr1, 1, b) + + actual = sizevars.combine_modular_indexing_pairs(expr2) + self.assertEqual(expr2, actual) + self.assertNotEqual(ModularIndexing(x, 1, b), actual) + + def test_expand_floor_div_skipped(self): + sizevars = SizeVarAllocator() + x = sympy.Symbol("x", integer=True, positive=True) + y = sympy.Symbol("y", integer=True, positive=True) + + expr = FloorDiv(x, 2) + FloorDiv(y, 3) + # The expression can not be simplified since there are multiple + # FloorDiv. We return False in that case + self.assertFalse(sizevars.expand_floor_div(expr)) + + def test_expand_floor_div_applied(self): + sizevars = SizeVarAllocator() + x = sympy.Symbol("x", integer=True, positive=True) + y = sympy.Symbol("y", integer=True, positive=True) + + expr = x * 5 + FloorDiv(y, 3) + actual, denominator = sizevars.expand_floor_div(expr) + self.assertNotEqual(expr, actual) + expected = FloorDiv(x * 15 + y, 3) + self.assertEqual(expected, FloorDiv(actual, denominator)) + + @unittest.skipUnless(HAS_CUDA, "Need GPU for this test") + def test_int8_unpack(self): + @torch.compile + def f(x): + first_elements = x >> 4 + second_elements = x & 15 + unpacked = torch.stack([first_elements, second_elements], dim=-1).view( + *x.size()[:-1], -1 + ) + return unpacked * 2 + + x = torch.randint(0, 255, (2, 4096, 5504), dtype=torch.uint8, device="cuda") + + triton_code = run_and_get_triton_code(f, x) + # Make sure the 2 load uses simpified indexing rather than something like + # tl.load(in_ptr0 + ((5504*x1) + (x0 // 2)), + self.assertEqual(2, triton_code.count("tl.load(in_ptr0 + ((x2 // 2)),")) + if DO_PERF_TEST: + ms = do_bench_gpu(lambda: f(x)) + print(f"{ms=:.03f}") + class ExprPrinterTests(InductorTestCase): def test_print_pow(self): @@ -281,7 +358,6 @@ def test_print_Min_Max(self): if __name__ == "__main__": from torch._inductor.test_case import run_tests - from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA if HAS_CPU or HAS_CUDA: run_tests("sympy") diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index ed7261f2a3eb..c5fc2747bee7 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -338,7 +338,8 @@ def simplify_indexing(index: sympy.Expr): index = V.graph.sizevars.simplify_with_ranges(index, self.var_ranges()) for tree in self.range_trees: index = self.combine_contiguous_dims(index, tree) - return index + + return self.combine_modular_indexing_pairs(index) self.simplify_indexing = simplify_indexing self.initialize_range_tree(pid_cache) @@ -422,7 +423,23 @@ def dense_size_str(self): sizes = self.dense_size_list() return f"[{', '.join(sizes)}]" + def combine_modular_indexing_pairs(self, index): + if not isinstance(index, ModularIndexing): + return index + x = index.args[0] + if (tree_node := self.range_tree_nodes.get(x)) is None: + return index + new_index = sympy_subs(index, {x: tree_node.expr}) + return V.graph.sizevars.combine_modular_indexing_pairs(new_index) + def combine_contiguous_dims(self, index: sympy.Expr, tree: IterationRangesRoot): + if expand_res := V.graph.sizevars.expand_floor_div(index): + new_index, denominator = expand_res # type: ignore[misc] + return FloorDiv(self._combine_contiguous_dims(new_index, tree), denominator) + else: + return self._combine_contiguous_dims(index, tree) + + def _combine_contiguous_dims(self, index: sympy.Expr, tree: IterationRangesRoot): """ More aggressive simplification to merge contiguous dims """ diff --git a/torch/_inductor/sizevars.py b/torch/_inductor/sizevars.py index fba9a66f9237..910e85e79906 100644 --- a/torch/_inductor/sizevars.py +++ b/torch/_inductor/sizevars.py @@ -583,6 +583,137 @@ def lookup_precomputed_size(self, expr: Expr) -> Expr: def free_symbols(self) -> Set[sympy.Symbol]: return set(self.var_to_val.keys()) - set(self.replacements.keys()) + def combine_modular_indexing_pairs(self, index: sympy.Expr) -> sympy.Expr: + """ + A pair of special ModularIndexing can be combined. + + E.g. ModularIndexing(ModularIndexing(x, 1, a), 1, b) + We can simplify this to ModuleIndexing(x, 1, b), if + 1. x is non negative integer + 2. a and b are positive integers + 3. a is a multiple of b. + """ + + def _check_args(x, div, mod, is_first): + if not isinstance(div, sympy.Integer) or not isinstance(mod, sympy.Integer): + return False + if div != 1: + return False + if mod <= 0: + return False + + if is_first: + # first ModularIndexing should conatins a nested ModularIndex + if not isinstance(x, ModularIndexing): + return False + else: + # second ModularIndexing should constains a non-negative + # symbol + if not isinstance(x, sympy.Symbol) or not self.statically_known_geq( + x, 0 + ): + return False + return True + + if isinstance(index, ModularIndexing): + x, div, mod = index.args + + if not _check_args(x, div, mod, True): + return index + + x2, div2, mod2 = x.args + + if not _check_args(x2, div2, mod2, False): + return index + + if mod2 % mod != 0: + return index + + return ModularIndexing(x2, 1, mod) + + return index + + def expand_floor_div( + self, index: sympy.Expr + ) -> Union[bool, Tuple[sympy.Expr, sympy.Expr]]: + """ + Expand the FloorDiv to the entire expression so that the expression may + be simplfied. + + E.g., for a 2D contiguous tensor with shape [a, 2 * b], and index variables + x1, x2, index expression 'x1 * 2b + x2' can be easily combined. + But index expression 'x1 * b + x2 // 2' can not. + By expanding the FloorDiv to the entire expression, we get + '(x1 * 2b + x2) // 2'. This transformation allows us to merge loops + for the numerator! + + Return false if this optimization can be applied; + Return the new expression and the denominator otherwise. + The original expression will be equivalent to 'new_expression // denominator' + """ + if not isinstance(index, sympy.Add): + return False + terms = index.args + + if len(terms) < 2: + return False + floor_div_index = -1 + varlist = [] + factorlist = [] + for idx, term in enumerate(terms): + if isinstance(term, sympy.Mul): + # For dynamic shape, term like '2*s1*x1' has 3 child nodes. + # - A integer for 2 + # - A symbol for s1 + # - A symbol for x1 + # Skip for now. + if len(term.args) != 2: + return False + factor, var = term.args + varlist.append(var) + factorlist.append(factor) + if not isinstance(factor, sympy.Integer) or not isinstance( + var, sympy.Symbol + ): + return False + # It's easier to reason about the correceness of the transformation + # for non-negative integers. + if not self.statically_known_geq(var, 0): + return False + elif isinstance(term, FloorDiv): + var, factor = term.args + if not isinstance(factor, sympy.Integer) or not isinstance( + var, sympy.Symbol + ): + return False + if not self.statically_known_geq(var, 0): + return False + if floor_div_index >= 0: + # can not handle multi FloorDiv yet + return False + + floor_div_index = idx + varlist.append(var) + # this factor is denominator + factorlist.append(factor) + else: + return False + + if floor_div_index < 0: + return False + + # Construct the new expression and remember the denominator + denominator = factorlist[floor_div_index] + new_index = sympy.Integer(0) + + for var, factor, idx in zip(varlist, factorlist, itertools.count()): + if idx == floor_div_index: + new_index += var + else: + new_index += (factor * denominator) * var + + return new_index, denominator + def join_dimensions(expr: Expr) -> Expr: if not isinstance(expr, sympy.Add) or not expr.has(ModularIndexing): From 117ab34891c26fff223af5d9b602da79a3a18ce8 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 7 Jun 2024 00:43:18 +0000 Subject: [PATCH 440/706] Documenting the torch.utils.collect_env.get_pretty_env_info function (#128123) Fixes #127888 This PR adds docstring to the `torch.utils.collect_env.get_pretty_env_info` function Pull Request resolved: https://github.com/pytorch/pytorch/pull/128123 Approved by: https://github.com/ezyang, https://github.com/malfet --- torch/utils/collect_env.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/torch/utils/collect_env.py b/torch/utils/collect_env.py index b2254312109d..039bc012226c 100644 --- a/torch/utils/collect_env.py +++ b/torch/utils/collect_env.py @@ -614,6 +614,17 @@ def maybe_start_on_next_line(string): def get_pretty_env_info(): + """ + Returns a pretty string of environment information. + + This function retrieves environment information by calling the `get_env_info` function + and then formats the information into a human-readable string. The retrieved environment + information is listed in the document of `get_env_info`. + This function is used in `python collect_env.py` that should be executed when reporting a bug. + + Returns: + str: A pretty string of the environment information. + """ return pretty_str(get_env_info()) From 740cd0559f3b5926b3d0e6879cd682db9c40ce87 Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Fri, 7 Jun 2024 00:48:49 +0000 Subject: [PATCH 441/706] Filter non input symexprs from codecache guards (#128052) Summary: Dynamo lifts all symexprs that appear in the inputs to top level which means that we do not need to look at guards that contain symexprs that do not appear in the inputs. Prune them. Test Plan: added two new tests Differential Revision: D58200476 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128052 Approved by: https://github.com/ezyang, https://github.com/masnesral --- test/inductor/test_codecache.py | 75 ++++++++++++++++++++++++ torch/_inductor/codecache.py | 5 +- torch/fx/experimental/symbolic_shapes.py | 38 ++++++++++-- 3 files changed, 111 insertions(+), 7 deletions(-) diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py index 1330d635f8db..21d70d90d290 100644 --- a/test/inductor/test_codecache.py +++ b/test/inductor/test_codecache.py @@ -465,6 +465,81 @@ def fn(x, y): self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1) self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) + @config.patch({"fx_graph_cache": True}) + def test_cache_with_nt(self): + def gen_nt(r): + values = torch.randn(r, 16) + offsets = torch.tensor([0, 2, 3, 6, 13, r]) + return torch.nested.nested_tensor_from_jagged(values, offsets) + + def fn(nt): + if nt.values().size(0) % 16 == 0: + return nt.sin() + return nt.cos() + + inp1 = gen_nt(19) + inp2 = gen_nt(20) + + counters.clear() + torch.compile(fn)(inp1) + torch.compile(fn)(inp2) + self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1) + self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) + + self.reset() + counters.clear() + torch.compile(fn)(inp1) + torch.compile(fn)(inp2) + # TODO(oulgen): This doesnt actually produce a cache hit. + # Despite pickling the exact same object, pickle produces different + # results. + # self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0) + # self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1) + + @config.patch({"fx_graph_cache": True}) + def test_cache_with_symint_non_arg_guard(self): + def fn(x, ref_id): + self_id = 22 + if self_id == ref_id: + x = torch.mul(x, 1.0) + else: + x = torch.mul(x, 0) + return x + + x = torch.ones(2) + + counters.clear() + torch.compile(fn, fullgraph=True, dynamic=True)(x, 2) + self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1) + self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) + + self.reset() + counters.clear() + torch.compile(fn, fullgraph=True, dynamic=True)(x, 2) + self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0) + self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1) + + @config.patch({"fx_graph_cache": True}) + def test_cache_guard(self): + def f(x, val): + if val > 5: + return x.sin() + else: + return x.cos() + + x = torch.ones(2) + a = torch.compile(f, dynamic=True)(x, 6) + self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1) + self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) + + self.reset() + counters.clear() + b = torch.compile(f, dynamic=True)(x, 4) + self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1) + self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) + + self.assertNotEqual(a, b) + class TestFxGraphCacheHashing(TestCase): def test_tensor_constants(self): diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index a421432125b6..335cec4d4056 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -905,7 +905,10 @@ def _save_graph( shape_env = FxGraphCache._get_shape_env() assert shape_env is not None symints = FxGraphCache._filter_backed_symints(example_inputs) - disk_compiled_graph.guards_expr = shape_env.produce_guards_expression(symints) + guards = shape_env.get_pruned_guards(symints) + disk_compiled_graph.guards_expr = shape_env.produce_guards_expression( + placeholders=symints, guards=guards + ) try: content = pickle.dumps(disk_compiled_graph) diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 687d2bcbd1eb..e2573bcd3ef9 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -3612,6 +3612,7 @@ def produce_guards( sources, source_ref=lambda n: n.name(), *, + guards: List[ShapeGuard] = None, input_contexts: Optional[DimList[SymbolicContext]] = None, # Encodes user-specified input shape equations of the form s = s' and s = fn(s'). # (See docs on EqualityConstraint for details of the encoding.) @@ -4080,7 +4081,7 @@ def issue_guard(guard: ShapeGuard) -> None: # First, issue all guards. # This removes all the checks that follow from bounds # We could simply emit those and also the bounds 2 <= size when necessary - for guard in self.guards: + for guard in (guards if guards is not None else self.guards): if self._maybe_evaluate_static(guard.expr, axioms=()) is not None: continue issue_guard(guard) @@ -4208,10 +4209,18 @@ def issue_guard(guard: ShapeGuard) -> None: with fx_traceback.preserve_node_meta(): PopulateValidator(self.graph, self.validator).run() - self._check_translation_validate() + # Only run translation validation when we are not passing custom guards + if guards is None: + self._check_translation_validate() return exprs - def produce_guards_expression(self, placeholders, ignore_static=True): + def produce_guards_expression( + self, + placeholders, + *, + guards: Optional[List[ShapeGuard]] = None, + ignore_static=True + ): """ Expected to be used with evaluate_guards_expression(). Produces the guards for the given placeholders and returns a string expression to be evaluated @@ -4219,9 +4228,14 @@ def produce_guards_expression(self, placeholders, ignore_static=True): """ from torch._dynamo.source import LocalSource arg_names = [f"t{i}" for i in range(len(placeholders))] - guards = self.produce_guards(placeholders, [LocalSource(a) for a in arg_names], ignore_static=ignore_static) - if guards: - return " and ".join(guards) + produced_guards = self.produce_guards( + placeholders, + [LocalSource(a) for a in arg_names], + guards=guards, + ignore_static=ignore_static, + ) + if produced_guards: + return " and ".join(produced_guards) return None def evaluate_guards_expression(self, code, args): @@ -4240,6 +4254,18 @@ def evaluate_guards_for_args(self, placeholders, args, *, ignore_static=True): return self.evaluate_guards_expression(code, args) return True + def get_pruned_guards(self, symints): + """ + Get a list of guards, but pruned so it only provides guards that + reference symints from the passed in input + """ + symints = {s.node.expr for s in symints if isinstance(s.node.expr, sympy.Symbol)} + guards = [] + for g in self.guards: + if all(s in symints for s in g.expr.free_symbols): + guards.append(g) + return guards + def bind_symbols(self, placeholders, args): """ Given a paired list of placeholders (fake tensors with From f99409903cadd0a2160c81e2e3beea1bd38f6929 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Fri, 7 Jun 2024 00:49:40 +0000 Subject: [PATCH 442/706] Documenting `torch.distributions.utils.clamp_probs` (#128136) Fixes https://github.com/pytorch/pytorch/issues/127889 This PR adds docstring to the `torch.distributions.utils.clamp_probs` function. Co-authored-by: Svetlana Karslioglu Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/128136 Approved by: https://github.com/janeyx99, https://github.com/svekars, https://github.com/malfet --- torch/distributions/utils.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/torch/distributions/utils.py b/torch/distributions/utils.py index 7a6d31a05722..91e4345e983c 100644 --- a/torch/distributions/utils.py +++ b/torch/distributions/utils.py @@ -90,6 +90,27 @@ def logits_to_probs(logits, is_binary=False): def clamp_probs(probs): + """Clamps the probabilities to be in the open interval `(0, 1)`. + + The probabilities would be clamped between `eps` and `1 - eps`, + and `eps` would be the smallest representable positive number for the input data type. + + Args: + probs (Tensor): A tensor of probabilities. + + Returns: + Tensor: The clamped probabilities. + + Examples: + >>> probs = torch.tensor([0.0, 0.5, 1.0]) + >>> clamp_probs(probs) + tensor([1.1921e-07, 5.0000e-01, 1.0000e+00]) + + >>> probs = torch.tensor([0.0, 0.5, 1.0], dtype=torch.float64) + >>> clamp_probs(probs) + tensor([2.2204e-16, 5.0000e-01, 1.0000e+00], dtype=torch.float64) + + """ eps = torch.finfo(probs.dtype).eps return probs.clamp(min=eps, max=1 - eps) From 65aa16f968af2cd18ff8c25cc657e7abda594bfc Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Thu, 6 Jun 2024 14:59:49 -0700 Subject: [PATCH 443/706] Revert "Default XLA to use swap_tensors path in nn.Module._apply (#126814)" (#128170) https://github.com/pytorch/pytorch/issues/128165 :( This reverts commit a7b1dd82ff3063894fc665ab0c424815231c10e6. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128170 Approved by: https://github.com/drisspg, https://github.com/albanD --- test/test_nn.py | 4 ++-- torch/nn/modules/module.py | 10 ++-------- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/test/test_nn.py b/test/test_nn.py index 6bcb4017e4b5..6dfac4f7ca1b 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -8184,9 +8184,9 @@ def test_batchnorm_large_batch(self, device, dtype): @dtypes(torch.float, torch.double, torch.bfloat16, torch.complex128) def test_conv_empty_input(self, device, dtype): def help(input, conv, memory_format): - ref_out = conv(input).detach() + ref_out = conv(input) conv_cl = conv.to(memory_format=memory_format) - out_cl = conv_cl(input).detach() + out_cl = conv_cl(input) self.assertEqual(ref_out, out_cl) input_cl = input.to(memory_format=memory_format) out_cl2 = conv(input_cl) diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index 3d683cb82181..ffd429cc06f2 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -794,13 +794,6 @@ def compute_should_use_set_data(tensor, tensor_applied): should_use_swap_tensors = torch.__future__.get_swap_module_params_on_conversion() - def compute_should_use_swap_tensors(tensor, tensor_applied): - return (should_use_swap_tensors - # subclasses may have multiple child tensors so we need to use swap_tensors - or is_traceable_wrapper_subclass(tensor_applied) - or tensor.device.type == 'xla' - or tensor_applied.device.type == 'xla') - for key, param in self._parameters.items(): if param is None: continue @@ -811,7 +804,8 @@ def compute_should_use_swap_tensors(tensor, tensor_applied): param_applied = fn(param) p_should_use_set_data = compute_should_use_set_data(param, param_applied) - p_should_use_swap_tensors = compute_should_use_swap_tensors(param, param_applied) + # subclasses may have multiple child tensors so we need to use swap_tensors + p_should_use_swap_tensors = should_use_swap_tensors or is_traceable_wrapper_subclass(param_applied) param_grad = param.grad if p_should_use_swap_tensors: From 50155e825be4a8655920451060e5ab5ddea1ca86 Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Fri, 7 Jun 2024 03:29:06 +0000 Subject: [PATCH 444/706] [export] provide refine function for automatically accepting dynamic shapes suggested fixes (#127436) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Part of the work helping export's automatic dynamic shapes / dynamic shapes refining based on suggested fixes. Introduces a util function refine_dynamic_shapes_from_suggested_fixes() that takes the error message from a ConstraintViolationError message containing suggested dynamic shapes fixes, along with the original dynamic shapes spec, and returns the new spec. Written so that the suggested fixes from export can be directly parsed and used. Example usage for the automatic dynamic shapes workflow: ``` # export, fail, parse & refine suggested fixes, re-export try: export(model, inps, dynamic_shapes=dynamic_shapes) except torch._dynamo.exc.UserError as exc: new_shapes = refine_dynamic_shapes_from_suggested_fixes(exc.msg, dynamic_shapes) export(model, inps, dynamic_shapes=new_shapes) ``` For examples of behavior, see the added test and docstring. Will take suggestions for renaming the function to something else 😅 Test Plan: test_export tests Differential Revision: D57409142 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127436 Approved by: https://github.com/avikchaudhuri --- docs/source/export.rst | 1 + test/export/test_export.py | 131 +++++++++++++++++ torch/export/dynamic_shapes.py | 171 ++++++++++++++++++++++- torch/fx/experimental/symbolic_shapes.py | 17 +-- 4 files changed, 306 insertions(+), 14 deletions(-) diff --git a/docs/source/export.rst b/docs/source/export.rst index c6134d187b66..29069d3228e4 100644 --- a/docs/source/export.rst +++ b/docs/source/export.rst @@ -683,6 +683,7 @@ API Reference .. automethod:: dynamic_shapes +.. autofunction:: torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes .. autoclass:: Constraint .. autoclass:: ExportedProgram diff --git a/test/export/test_export.py b/test/export/test_export.py index 228db50b9dc4..5b0c93135ba7 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -1969,6 +1969,137 @@ def forward(self, x, y, z): dynamic_shapes = {"x": (3 * _dx - 1,), "y": (3 * _dx,), "z": (3 * _dx + 2,)} export(Foo(), inputs, dynamic_shapes=dynamic_shapes) + def test_refine_dynamic_shapes_from_suggested_fixes(self): + from torch.export.dynamic_shapes import ( + refine_dynamic_shapes_from_suggested_fixes, + ) + + def helper(model, inputs, dynamic_shapes): + # export, fail, parse & refine suggested fixes, re-export + try: + export(Foo(), inps, dynamic_shapes=dynamic_shapes) + raise Exception("should have raised constraint violation error") + except torch._dynamo.exc.UserError as exc: + new_shapes = refine_dynamic_shapes_from_suggested_fixes( + exc.msg, dynamic_shapes + ) + export(Foo(), inps, dynamic_shapes=new_shapes) + return new_shapes + + # specialize dims + derived dims + class Foo(torch.nn.Module): + def forward(self, x, y, z): + x0 = x + y[1:] + z[2:] + x1 = x @ torch.randn(4, 4) + return x0, x1 + + inps = ( + torch.randn( + 4, + ), + torch.randn( + 5, + ), + torch.randn( + 6, + ), + ) + dx = Dim("dx", max=16) + dynamic_shapes = {"x": (dx,), "y": (dx + 1,), "z": (dx + 2,)} + new_shapes = helper(Foo(), inps, dynamic_shapes) + self.assertEqual(new_shapes["x"][0], 4) + self.assertEqual(new_shapes["z"][0], 6) + + # refine lower, upper bound + class Foo(torch.nn.Module): + def forward(self, x, y): + if x.shape[0] >= 6 and y.shape[0] <= 16: + return x * 2.0, y + 1 + + inps = (torch.randn(16), torch.randn(12)) + dynamic_shapes = {"x": (Dim("dx"),), "y": (Dim("dy"),)} + new_shapes = helper(Foo(), inps, dynamic_shapes) + self.assertEqual(new_shapes["x"][0].min, 6) + self.assertEqual(new_shapes["y"][0].max, 16) + + # divisiblity, will introduce new root + class Foo(torch.nn.Module): + def forward(self, x): + if x.shape[0] >= 9: + return x.reshape([-1, 3]) + + inps = ( + torch.randn( + 15, + ), + ) + dynamic_shapes = ((Dim("dx"),),) + new_shapes = helper(Foo(), inps, dynamic_shapes) + dim = new_shapes[0][0] + root = dim.root + self.assertEqual(dim.fn(2), 6) + self.assertEqual(root.min, 3) + + # turn dim into derived dim/relation + class Foo(torch.nn.Module): + def forward(self, x, y): + return x + y[4:] + + inps = (torch.randn(6, 4), torch.randn(10, 4)) + dynamic_shapes = { + "x": (Dim("dx0"), Dim("dx1")), + "y": (Dim("dy0"), Dim("dy1")), + } + new_shapes = helper(Foo(), inps, dynamic_shapes) + self.assertEqual(new_shapes["x"][0], new_shapes["y"][0].root) # dy0 = dx0 + 4 + self.assertEqual(new_shapes["y"][0].fn(5), 9) + self.assertEqual(new_shapes["x"][1], new_shapes["y"][1]) # dx1 = dy1 + + # nested dynamic shapes spec + class Foo(torch.nn.Module): + def forward(self, x, y): + x0 = x[0]["data"] + x[1] + x[2][2:] + x1 = y["a"] @ torch.randn(4, 4) + x2 = y["b"] @ torch.randn(6, 6) + return x0, x1, x2 + + inps = ( + [ + {"data": torch.randn(4, 4)}, + torch.randn(4, 4), + torch.randn(6, 4), + ], + { + "a": torch.randn(8, 4), + "b": torch.randn(9, 6), + }, + ) + dynamic_shapes = { + "x": [ + {"data": (Dim("dx00"), Dim("dx01"))}, + (Dim("dx10"), Dim("dx11")), + (Dim("dx20"), Dim("dx21")), + ], + "y": { + "a": (Dim("dya0"), Dim("dya1")), + "b": (Dim("dyb0"), Dim("dyb1")), + }, + } + new_shapes = helper(Foo(), inps, dynamic_shapes) + self.assertEqual( + new_shapes["x"][0]["data"][0], new_shapes["x"][1][0] + ) # dx10 = dx00 + self.assertEqual( + new_shapes["x"][2][0].root, new_shapes["x"][0]["data"][0] + ) # dx20 = dx00 + 2 + self.assertEqual(new_shapes["x"][2][0].fn(10), 12) + self.assertEqual( + new_shapes["x"][0]["data"][1], new_shapes["x"][1][1] + ) # dx11 = dx01 + self.assertEqual(new_shapes["y"]["a"][1], 4) + self.assertEqual(new_shapes["y"]["b"][1], 6) + self.assertEqual(new_shapes["y"]["b"][0].__name__, "dyb0") # unchanged + def test_dynamic_shapes_spec_with_pytree(self): from torch.export import Dim, export from torch.utils._pytree import tree_map diff --git a/torch/export/dynamic_shapes.py b/torch/export/dynamic_shapes.py index ac2bdd60a550..e98e83af340f 100644 --- a/torch/export/dynamic_shapes.py +++ b/torch/export/dynamic_shapes.py @@ -4,10 +4,16 @@ import sys import weakref from collections import defaultdict -from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union import torch -from torch.utils._pytree import _get_node_type, BUILTIN_TYPES, SUPPORTED_NODES, tree_map +from torch.utils._pytree import ( + _get_node_type, + BUILTIN_TYPES, + SUPPORTED_NODES, + tree_flatten, + tree_map, +) from .exported_program import ExportedProgram @@ -18,7 +24,13 @@ from ..fx.experimental.symbolic_shapes import ShapeEnv, StrictMinMaxConstraint -__all__ = ["Constraint", "Dim", "dims", "dynamic_dim"] +__all__ = [ + "Constraint", + "Dim", + "dims", + "dynamic_dim", + "refine_dynamic_shapes_from_suggested_fixes", +] class _Dim(type): @@ -897,3 +909,156 @@ def assoc_shape(t, dynamic_shape): constraints.append(primary) return constraints # type: ignore[return-value] + + +def _get_dim_name_mapping( + dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None] +): + name_to_dim = {} + for dim in tree_flatten( + dynamic_shapes, + is_leaf=lambda x: isinstance(x, _Dim), + )[0]: + if dim is None or isinstance(dim, int): + continue + name_to_dim[dim.__name__] = dim + if isinstance(dim, _DerivedDim): + name_to_dim[dim.root.__name__] = dim.root # type: ignore[attr-defined] + return name_to_dim + + +def refine_dynamic_shapes_from_suggested_fixes( + msg: str, + dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any]], +) -> Union[Dict[str, Any], Tuple[Any], List[Any]]: + """ + For working with export's dynamic shapes suggested fixes, and/or automatic dynamic shapes. + Refines the given dynamic shapes spec, given a ConstraintViolation error message and the original dynamic shapes. + + For most cases behavior is straightforward - i.e. for suggested fixes that specialize or refine a Dim's range, + or fixes that suggest a derived relation, the new dynamic shapes spec will be updated as such. + + e.g. + Suggested fixes: + + dim = Dim('dim', min=3, max=6) -> this just refines the dim's range + dim = 4 -> this specializes to a constant + dy = dx + 1 -> dy was specified as an independent dim, but is actually tied to dx with this relation + + However, suggested fixes associated with derived dims can be more complicated. + For example, if a suggested fix is provided for a root dim, the new derived dim value is evaluated based on the root. + + e.g. + dx = Dim('dx') + dy = dx + 2 + dynamic_shapes = {"x": (dx,), "y": (dy,)} + + Suggested fixes: + + dx = 4 # specialization will lead to dy also specializing = 6 + dx = Dim('dx', max=6) # dy now has max = 8 + + Derived dims suggested fixes can also be used to express divisibility constraints. + This involves creating new root dims that aren't tied to a particular input shape. + In this case the root dims won't appear directly in the new spec, but as a root of + one of the dims. + + e.g. + Suggested fixes: + + _dx = Dim('_dx', max=1024) # this won't appear in the return result, but dx will + dx = 4*_dx # dx is now divisible by 4, with a max value of 4096 + """ + + import re + + import sympy + + from torch._dynamo.exc import UserError, UserErrorType + from torch.fx.experimental.symbolic_shapes import _is_supported_equivalence + + try: + shape_fixes_msg = msg.split("Suggested fixes:")[1].strip() + except Exception as exc: + raise UserError( + UserErrorType.INVALID_INPUT, + "Suggested fixes not found in error message given to refine_dynamic_shapes_from_suggested_fixes()", + ) from exc + + # build shape_fixes dictionary + shape_fixes = {} + for fix in shape_fixes_msg.split("\n"): + fix = fix.strip() + if match := re.match(r"(.*) = Dim\('(.*)'.*\)", fix): + name = match.group(1) + _min, _max = None, None + if match_min := re.match(r".* = Dim\('.*', min\=([0-9]+).*\)", fix): + _min = int(match_min.group(1)) + if match_max := re.match(r".* = Dim\('.*'.*max\=([0-9]+)\)", fix): + _max = int(match_max.group(1)) + shape_fixes[name] = Dim(name, min=_min, max=_max) + else: + name, expr = fix.split(" = ") + expr = sympy.sympify(expr) + if isinstance(expr, sympy.Number): + shape_fixes[name] = int(expr) # static, integer + else: + shape_fixes[name] = expr # relation or derived dim + + name_to_dim = _get_dim_name_mapping(dynamic_shapes) + + # track derived dim roots + roots: Set[str] = set() + for k, c in shape_fixes.items(): + assert isinstance(c, (int, _Dim, _DerivedDim, sympy.Expr)) + if isinstance(c, sympy.Expr): # check dim/derived dim expression + assert _is_supported_equivalence(c) + shape_fixes[k] = c + roots.add(str(next(iter(c.free_symbols)))) + if isinstance(c, _DerivedDim): + roots.add(c.root.__name__) # type: ignore[attr-defined] + + # check keys are existing dims or new roots + for k, c in shape_fixes.items(): + assert k in name_to_dim or k in roots + + # cache so we don't produce multiple derived dim objects + derived_dim_cache: Dict[str, _DerivedDim] = {} + + def apply_fixes(dim, dummy): + if dim is None or isinstance(dim, int): # not dynamic + return dim + elif dim.__name__ in shape_fixes: # directly fix + fix = shape_fixes[dim.__name__] + if isinstance(fix, sympy.Expr): # now derived or related + if str(fix) in derived_dim_cache: + return derived_dim_cache[str(fix)] + else: + symbol = next(iter(fix.free_symbols)) + # try to locate symbol + if symbol.name in shape_fixes: # type: ignore[attr-defined] + root = shape_fixes[symbol.name] # type: ignore[attr-defined] + else: + assert symbol.name in name_to_dim # type: ignore[attr-defined] + root = name_to_dim[symbol.name] # type: ignore[attr-defined] + # figure out value of fix + modulus, remainder = sympy.polys.polytools.div(fix, symbol) + dim = root + if modulus != 1: + dim = int(modulus) * dim + if remainder != 0: + dim = dim + int(remainder) + derived_dim_cache[str(fix)] = dim + return dim + else: + return fix + elif isinstance(dim, _DerivedDim) and dim.root.__name__ in shape_fixes: # type: ignore[attr-defined] + if dim.__name__ in derived_dim_cache: + return derived_dim_cache[dim.__name__] + else: # evaluate new derived value based on root + _dim = dim.fn(shape_fixes[dim.root.__name__]) # type: ignore[attr-defined] + derived_dim_cache[dim.__name__] = _dim + return _dim + return dim # unchanged dim + + return _tree_map(apply_fixes, dynamic_shapes, dynamic_shapes) diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index e2573bcd3ef9..42ab606e7827 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -1988,7 +1988,10 @@ def _check_same_range(c, dim): return ( self._is_dim(dim) and ("min" in c or "max" in c) - and (dim.min < 2 or dim.min == c.get("min", 2)) # let pass if min < 2 + and ( + (dim.min < 2 and c.get("min", 2) == 2) + or dim.min == c.get("min", 2) + ) # let pass if analysis min = 2 and specified min = 0/1 and dim.max == c.get("max", sys.maxsize - 1) ) @@ -2116,6 +2119,7 @@ def prettify_results( forced_specializations=None, ): """Format a message for constraint violation erros""" + from torch.export.dynamic_shapes import _get_dim_name_mapping if self._dcp.source_name_to_debug_name: def transform(s, inverse=False): @@ -2153,16 +2157,7 @@ def relation_with_digit(expr, op, digit): results[expr]["eq"] = digit # retrieve dynamic shapes - name_to_dim = {} - for dim in pytree.tree_flatten( - dynamic_shapes, - is_leaf=lambda x: self._is_derived_dim(x) or self._is_dim(x), - )[0]: - if dim is None or isinstance(dim, int): - continue - name_to_dim[dim.__name__] = dim - if self._is_derived_dim(dim): - name_to_dim[dim.root.__name__] = dim.root + name_to_dim = _get_dim_name_mapping(dynamic_shapes) for s in self._static_results.union(self._dynamic_results): t = transform(s) From 476bfe6cce545955276224ead4b9a4003e46abf4 Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Wed, 5 Jun 2024 17:59:04 -0700 Subject: [PATCH 445/706] fix torch.compile with triton kernels under inference_mode (#124489) Pull Request resolved: https://github.com/pytorch/pytorch/pull/124489 Approved by: https://github.com/albanD --- test/inductor/test_triton_kernels.py | 17 +++++++++++++++++ .../_aot_autograd/traced_function_transforms.py | 11 ++++++++--- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/test/inductor/test_triton_kernels.py b/test/inductor/test_triton_kernels.py index 113f1daea0f2..58ef3d4e84bc 100644 --- a/test/inductor/test_triton_kernels.py +++ b/test/inductor/test_triton_kernels.py @@ -1549,6 +1549,23 @@ def argmax_kernel(a_ptr, c_ptr, stride_am, stride_an): expected, ) + @requires_cuda + @skipIfRocm + def test_triton_kernel_inference_mode(self): + def f(x, y, out): + n_elements = x.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=4) + + with torch.inference_mode(): + x = torch.ones(32, device="cuda") + y = torch.ones(32, device="cuda") + out_ref = torch.zeros_like(x) + out_test = torch.zeros_like(x) + f(x, y, out_ref) + torch.compile(f)(x, y, out_test) + self.assertEqual(out_ref, out_test) + @make_mutation_test def test_cumsum(): @triton.jit diff --git a/torch/_functorch/_aot_autograd/traced_function_transforms.py b/torch/_functorch/_aot_autograd/traced_function_transforms.py index c673acdabe12..27d3f2c9ad99 100644 --- a/torch/_functorch/_aot_autograd/traced_function_transforms.py +++ b/torch/_functorch/_aot_autograd/traced_function_transforms.py @@ -546,9 +546,14 @@ def _functionalized_f_helper(*args): and meta.input_info[i].mutations_hidden_from_autograd ): # Hidden from autograd = run under no_grad, **and** don't bump VC - with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter( - inpt_old - ): + # (although if the tensor was created in inference mode, it has no VC) + if inpt_old.is_inference(): + maybe_preserve_vc = nullcontext() + else: + maybe_preserve_vc = torch.autograd._unsafe_preserve_version_counter( + inpt_old # type: ignore[assignment] + ) + with torch.no_grad(), maybe_preserve_vc: inpt_old.copy_(inpt_new) elif ( meta.input_info[i].mutates_data From 4d0ece81963e0f7b7714b1218441c0ee29748bd8 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Thu, 6 Jun 2024 16:18:55 -0700 Subject: [PATCH 446/706] [pipelining] Consolidate chunk counting between stage and schedule (#127935) We used to have two backward chunk id counting systems, one at schedule level, the other at stage level. (Which makes safety dependent on the two advancing hand-in-hand.) This PR consolidates the counting system to the schedule side only, which would pass `mb_index` to the following stage calls: `forward_one_chunk` `backward_one_chunk` `get_bwd_send_ops` ... Pull Request resolved: https://github.com/pytorch/pytorch/pull/127935 Approved by: https://github.com/H-Huang --- .../pipelining/PipelineSchedule.py | 62 +++++++------- torch/distributed/pipelining/PipelineStage.py | 80 ++++++++----------- 2 files changed, 66 insertions(+), 76 deletions(-) diff --git a/torch/distributed/pipelining/PipelineSchedule.py b/torch/distributed/pipelining/PipelineSchedule.py index 8d696a5aa2b9..2e6856e25151 100644 --- a/torch/distributed/pipelining/PipelineSchedule.py +++ b/torch/distributed/pipelining/PipelineSchedule.py @@ -353,14 +353,14 @@ def _step_microbatches( # Run microbatches for i in range(self._n_microbatches): with record_function(f"Forward {i}"): - ops = self._stage.get_fwd_recv_ops() + ops = self._stage.get_fwd_recv_ops(i) works = _sorted_batch_p2p(ops, desc="fwd_recv") for work in works.values(): work.wait() - output = self._stage.forward_one_chunk(arg_mbs[i], kwarg_mbs[i]) # type: ignore[index] + output = self._stage.forward_one_chunk(i, arg_mbs[i], kwarg_mbs[i]) # type: ignore[index] - ops = self._stage.get_fwd_send_ops() + ops = self._stage.get_fwd_send_ops(i) works = _sorted_batch_p2p(ops, desc="fwd_send") fwd_sends_to_wait.extend(works.values()) @@ -388,15 +388,15 @@ def _step_microbatches( self._stage._configure_data_parallel_mode(i == self._n_microbatches - 1) with record_function(f"Backward {i}"): - ops = self._stage.get_bwd_recv_ops() + ops = self._stage.get_bwd_recv_ops(i) works = _sorted_batch_p2p(ops, desc="bwd_recv") for work in works.values(): work.wait() loss = self._maybe_get_loss(self._stage, i) - self._stage.backward_one_chunk(loss=loss) + self._stage.backward_one_chunk(i, loss=loss) - ops = self._stage.get_bwd_send_ops() + ops = self._stage.get_bwd_send_ops(i) works = _sorted_batch_p2p(ops, desc="bwd_send") bwd_sends_to_wait.extend(works.values()) @@ -450,12 +450,12 @@ def _step_microbatches( fwd_sends = [] for _ in range(warmup_chunks): # Receive activations - fwd_recvs = self._stage.get_fwd_recv_ops() + fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index) if recv_work := _batch_p2p(fwd_recvs, desc="fwd_recv"): recv_work.wait() # Compute - output = self._stage.forward_one_chunk(arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]) # type: ignore[index] + output = self._stage.forward_one_chunk(fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]) # type: ignore[index] # Clear previous chunk's forward sends (hopefully they have well # finished, otherwise, we are heavily communication bound, in which @@ -465,7 +465,7 @@ def _step_microbatches( send_work.wait() # Send activations - fwd_sends = self._stage.get_fwd_send_ops() + fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index) if fwd_mb_index != warmup_chunks - 1: # Safe to fire send_work = _batch_p2p(fwd_sends, desc="fwd_send") @@ -481,7 +481,7 @@ def _step_microbatches( # 1B1F phase while True: # Don't worry, we have a break inside # We actually do 1B first as the `1B1F` name indicates, so prepare its recv ops - bwd_recvs = self._stage.get_bwd_recv_ops() + bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index) # Now, we need to fire the fwd_sends and bwd_recvs together if fuse_work := _batch_p2p(fwd_sends + bwd_recvs, desc="fwd_send_bwd_recv"): @@ -489,10 +489,10 @@ def _step_microbatches( # Backward one chunk loss = self._maybe_get_loss(self._stage, bwd_mb_index) - self._stage.backward_one_chunk(loss=loss) + self._stage.backward_one_chunk(bwd_mb_index, loss=loss) # Get the bwd send ops, but don't fire, to be fused with the 1F below - bwd_sends = self._stage.get_bwd_send_ops() + bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index) bwd_mb_index += 1 if fwd_mb_index == self._n_microbatches: @@ -500,20 +500,20 @@ def _step_microbatches( break # We prepare 1F of the `1B1F` - fwd_recvs = self._stage.get_fwd_recv_ops() + fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index) # Fuse it with bwd_sends above if fuse_work := _batch_p2p(bwd_sends + fwd_recvs, desc="bwd_send_fwd_recv"): fuse_work.wait() # Now do the fwd - output = self._stage.forward_one_chunk(arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]) # type: ignore[index] + output = self._stage.forward_one_chunk(fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]) # type: ignore[index] # Compute loss self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index) # Get the fwd send ops, but don't fire, leave it for the next iter (wrap-around) - fwd_sends = self._stage.get_fwd_send_ops() + fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index) fwd_mb_index += 1 # Remember we still have some bwd_sends left over after the break? Now it is time to fire it @@ -522,20 +522,20 @@ def _step_microbatches( # Cooldown while bwd_mb_index < self._n_microbatches: # prepare bwd recv ops - bwd_recvs = self._stage.get_bwd_recv_ops() + bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index) if recv_work := _batch_p2p(bwd_recvs, desc="bwd_recv"): recv_work.wait() # Backward one chunk loss = self._maybe_get_loss(self._stage, bwd_mb_index) - self._stage.backward_one_chunk(loss=loss) + self._stage.backward_one_chunk(bwd_mb_index, loss=loss) # Clear previous chunk's backward sends (hopefully they have well finished) if send_work: send_work.wait() # Get the bwd send ops, fire it - bwd_sends = self._stage.get_bwd_send_ops() + bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index) send_work = _batch_p2p(bwd_sends, desc="bwd_send") bwd_mb_index += 1 @@ -650,14 +650,14 @@ def _step_microbatches( for stage in self._stages: for i in range(self._n_microbatches): with record_function(f"Stage {stage.stage_index} Forward"): - ops = stage.get_fwd_recv_ops() + ops = stage.get_fwd_recv_ops(i) if ops: _batch_p2p(ops, desc="fwd_recv").wait() - output = stage.forward_one_chunk(arg_mbs[i], kwarg_mbs[i]) + output = stage.forward_one_chunk(i, arg_mbs[i], kwarg_mbs[i]) self._maybe_compute_loss(stage, output, target_mbs, i) - ops = stage.get_fwd_send_ops() + ops = stage.get_fwd_send_ops(i) if ops: _batch_p2p(ops, desc="fwd_send") @@ -665,14 +665,14 @@ def _step_microbatches( for i in range(self._n_microbatches): stage._configure_data_parallel_mode(i == self._n_microbatches - 1) with record_function(f"Stage {stage.stage_index} Backward"): - ops = stage.get_bwd_recv_ops() + ops = stage.get_bwd_recv_ops(i) if ops: _batch_p2p(ops, desc="bwd_recv").wait() loss = self._maybe_get_loss(stage, i) - stage.backward_one_chunk(loss=loss) + stage.backward_one_chunk(i, loss=loss) - ops = stage.get_bwd_send_ops() + ops = stage.get_bwd_send_ops(i) if ops: _batch_p2p(ops, desc="bwd_send") @@ -719,7 +719,7 @@ def __init__( # This will be used to keep track of the current state of the entire pipeline # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...] self.pipeline_order: Dict[int, List[Optional[_Action]]] = {} - # ======================================================================== + for rank in range(self.pp_group_size): rank_ops = self._calculate_single_rank_operations(rank) self.pipeline_order[rank] = rank_ops @@ -871,10 +871,10 @@ def _step_microbatches( # perform forward computation stage = stage_index_to_stage[stage_index] output = stage.forward_one_chunk( - arg_mbs[mb_index], kwarg_mbs[mb_index] + mb_index, arg_mbs[mb_index], kwarg_mbs[mb_index] ) self._maybe_compute_loss(stage, output, target_mbs, mb_index) - ops.extend(stage.get_fwd_send_ops()) + ops.extend(stage.get_fwd_send_ops(mb_index)) elif computation_type == _ComputationType.BACKWARD: # perform backward computation stage = stage_index_to_stage[stage_index] @@ -882,8 +882,8 @@ def _step_microbatches( mb_index == self._n_microbatches - 1 ) loss = self._maybe_get_loss(stage, mb_index) - stage.backward_one_chunk(loss=loss) - ops.extend(stage.get_bwd_send_ops()) + stage.backward_one_chunk(mb_index, loss=loss) + ops.extend(stage.get_bwd_send_ops(mb_index)) else: raise ValueError(f"Unknown computation type {computation_type}") @@ -901,7 +901,7 @@ def _step_microbatches( # TODO: We are assuming that stage will always receive from stage-1 # however that is not necessarily true of get_fwd_recv_ops stage = stage_index_to_stage[stage_index + 1] - ops.extend(stage.get_fwd_recv_ops()) + ops.extend(stage.get_fwd_recv_ops(mb_index)) elif computation_type == _ComputationType.BACKWARD: # Previous rank doing backward has no influence for the current rank forward recv pass @@ -923,7 +923,7 @@ def _step_microbatches( # TODO: We are assuming that stage will always receive from stage+1 # however that is not necessarily true of get_bwd_recv_ops stage = stage_index_to_stage[stage_index - 1] - ops.extend(stage.get_bwd_recv_ops()) + ops.extend(stage.get_bwd_recv_ops(mb_index)) else: raise ValueError(f"Unknown computation type {computation_type}") diff --git a/torch/distributed/pipelining/PipelineStage.py b/torch/distributed/pipelining/PipelineStage.py index f5aac602faba..50e31ce7d471 100644 --- a/torch/distributed/pipelining/PipelineStage.py +++ b/torch/distributed/pipelining/PipelineStage.py @@ -128,18 +128,9 @@ def __init__( self._outputs_meta: Optional[Tuple[torch.Tensor, ...]] = None # map microbatch ID to list of forward tensor args self.fwd_cache: Dict[int, Tuple[Any, List[torch.Tensor]]] = {} - # Current forward chunk id to be used in computation - self.fwd_chunk_id: int = 0 - # Current backward chunk id to be used in computation - self.bwd_chunk_id: int = 0 # Caching chunk outputs for final output merge or reduction self.output_chunks: List[Any] = [] - # Current forward chunk id to be used in recv - self.recv_fwd_chunk_id: int = 0 - # Current backward chunk id to be used in recv - self.recv_bwd_chunk_id: int = 0 - # Create stage id to group rank mapping # In interleaved case, `group_rank` is stage index % group size. self.stage_index_to_group_rank: Dict[int, int] = {} @@ -189,6 +180,12 @@ def is_last(self): """ return self.stage_index == self.num_stages - 1 + def _check_chunk_id(self, chunk_id: int): + if chunk_id >= self.chunks: + raise RuntimeError( + f"Chunk id {chunk_id} is out of range [0, {self.chunks})" + ) + def _configure_outputs_meta(self, outputs_meta: Tuple[torch.Tensor, ...]): """ Track the output shapes/dtype of this stage since they determine the send operation(s) which must match @@ -267,24 +264,23 @@ def _get_recv_ops( return ops - def get_fwd_recv_ops(self) -> List[dist.P2POp]: + def get_fwd_recv_ops(self, fwd_chunk_id: int) -> List[dist.P2POp]: """ Returns a list of ops that are needed to receive the input arguments for this stage. """ - recv_infos: Tuple[InputInfo, ...] = self.args_recv_info[self.recv_fwd_chunk_id] + recv_infos: Tuple[InputInfo, ...] = self.args_recv_info[fwd_chunk_id] # In case there is backward pass, set requires_grad for receive buffers # before first forward - if self.has_backward and not self.set_requires_grad[self.recv_fwd_chunk_id]: + if self.has_backward and not self.set_requires_grad[fwd_chunk_id]: for a in recv_infos: if isinstance(a, _RecvInfo): a.buffer.requires_grad_(True) - self.recv_fwd_chunk_id += 1 return self._get_recv_ops(recv_infos) - def get_bwd_recv_ops(self) -> List[dist.P2POp]: + def get_bwd_recv_ops(self, bwd_chunk_id: int) -> List[dist.P2POp]: """ Returns a list of ops that are needed to receive the gradients for this stage. @@ -294,20 +290,18 @@ def get_bwd_recv_ops(self) -> List[dist.P2POp]: # Create bwd recv infra lazily recv_infos = self.grad_recv_info.setdefault( - self.recv_bwd_chunk_id, + bwd_chunk_id, # `grad_recv_info` is a mirror of `act_send_info` self._create_grad_recv_info(self.act_send_info), ) - self.recv_bwd_chunk_id += 1 return self._get_recv_ops(recv_infos) - def get_fwd_send_ops(self) -> List[dist.P2POp]: + def get_fwd_send_ops(self, fwd_chunk_id: int) -> List[dist.P2POp]: """ Get the activation send ops for current stage's forward. """ - # Use "-1" to get the outputs created by the last chunk - output = self.output_chunks[-1] + output = self.output_chunks[fwd_chunk_id] # Unify output form to tuple for easy correspondance with # `act_send_info` output_tuple = output if type(output) is tuple else (output,) @@ -333,10 +327,12 @@ def get_fwd_send_ops(self) -> List[dist.P2POp]: return ops - def get_bwd_send_ops(self) -> List[dist.P2POp]: + def get_bwd_send_ops(self, bwd_chunk_id: int) -> List[dist.P2POp]: """ Get the gradient send ops for current stage's backward. """ + self._check_chunk_id(bwd_chunk_id) + if not self.has_backward or self.is_first: return [] @@ -365,7 +361,7 @@ def get_bwd_send_ops(self) -> List[dist.P2POp]: else: if not (grad is None and grad_recv_stage is None): raise RuntimeError( - f"[{self.stage_index}] for chunk {self.bwd_chunk_id - 1} has gradients {grad} " + f"[{self.stage_index}] for chunk {bwd_chunk_id - 1} has gradients {grad} " f"and is expecting to send gradients to stage {grad_recv_stage}" ) return ops @@ -374,11 +370,6 @@ def clear_runtime_states(self) -> None: """ Clear runtime states of the stage. """ - # Reset pointers - self.fwd_chunk_id = 0 - self.bwd_chunk_id = 0 - self.recv_fwd_chunk_id = 0 - self.recv_bwd_chunk_id = 0 # map microbatch ID to list of forward tensor args self.fwd_cache.clear() # Caching chunk outputs for final output merge or reduction @@ -416,23 +407,22 @@ def get_recv_tensor(info): return tensors - def _retrieve_recv_activations( - self, - ): + def _retrieve_recv_activations(self, fwd_chunk_id: int): """ Retrieve the activations received for the current stage during forward. """ - recv_infos = self.args_recv_info[self.fwd_chunk_id] + recv_infos = self.args_recv_info[fwd_chunk_id] activations = self._map_tensor_from_recv_info(recv_infos) return activations def _retrieve_recv_grads( self, + bwd_chunk_id: int, ): """ Retrieve the gradients received for the current stage during backward. """ - recv_infos = self.grad_recv_info[self.bwd_chunk_id] + recv_infos = self.grad_recv_info[bwd_chunk_id] grads = self._map_tensor_from_recv_info(recv_infos) return grads @@ -481,6 +471,7 @@ def backward_maybe_with_nosync(self, bwd_kwargs: Dict, bwd_chunk_id: int): def forward_one_chunk( self, + fwd_chunk_id: int, args: Tuple[Any, ...], kwargs: Optional[Dict[str, Any]] = None, ): @@ -497,7 +488,7 @@ def forward_one_chunk( else: # Receive activations for this chunk # Activations only come in args form - composite_args = self._retrieve_recv_activations() + composite_args = self._retrieve_recv_activations(fwd_chunk_id) composite_kwargs = {} self._validate_fwd_input(args, kwargs) @@ -529,30 +520,32 @@ def forward_one_chunk( flat_args = flatten_args(composite_args) flat_kwargs = flatten_args(composite_kwargs) flatten_input_tensors = flat_args + flat_kwargs - self.fwd_cache[self.fwd_chunk_id] = ( + self.fwd_cache[fwd_chunk_id] = ( output_tuple, # stage_output flatten_input_tensors, # input_values ) logger.debug( - f"{self.log_prefix} Forwarded chunk {self.fwd_chunk_id}, outputs: {map_debug_info(output)}" # noqa: G004 + f"{self.log_prefix} Forwarded chunk {fwd_chunk_id}, outputs: {map_debug_info(output)}" # noqa: G004 ) self._validate_fwd_outputs(output_tuple) - self.fwd_chunk_id += 1 return output def backward_one_chunk( self, + bwd_chunk_id: int, loss=None, ): """ Perform backward pass on the module. This should only be called once per microbatch. """ + self._check_chunk_id(bwd_chunk_id) + ( stage_output, input_values, - ) = self.fwd_cache.pop(self.bwd_chunk_id) + ) = self.fwd_cache.pop(bwd_chunk_id) # Compute backward if self.is_last: @@ -565,7 +558,7 @@ def backward_one_chunk( } else: # Otherwise, receive gradients from next stage - grads_output = self._retrieve_recv_grads() + grads_output = self._retrieve_recv_grads(bwd_chunk_id) # If an input to the pipeline requires gradient, # `torch.autograd.backward` will accumulate the gradient into the # `.grad` field of such input @@ -575,20 +568,17 @@ def backward_one_chunk( "input_values": input_values, } - self.grads_input = self.backward_maybe_with_nosync( - bwd_kwargs, self.bwd_chunk_id - ) - logger.debug( - f"{self.log_prefix} Backwarded chunk {self.bwd_chunk_id}" # noqa: G004 - ) - self.bwd_chunk_id += 1 + self.grads_input = self.backward_maybe_with_nosync(bwd_kwargs, bwd_chunk_id) + logger.debug(f"{self.log_prefix} Backwarded chunk {bwd_chunk_id}") # noqa: G004 def _validate_fwd_input(self, args, kwargs): """Raises a RuntimeError if shapes of input args/kwargs do not match the shapes configured for this stage.""" if self.is_first: # TODO why is there a separate recv_info for each pipeline chunk? - expected_args = self.args_recv_info[self.fwd_chunk_id] + # kwen2501: to avoid passing a `fwd_chunk_id` to this function, we + # check all chunks against args_recv_info[0] + expected_args = self.args_recv_info[0] else: # We don't check inputs for non-0 stages assuming they don't accept # user inputs in canonical pipeline scenarios From 5e5bbdb35ea73daa3b63458c62ae394fb6dc7315 Mon Sep 17 00:00:00 2001 From: Aidyn-A Date: Fri, 7 Jun 2024 03:33:31 +0000 Subject: [PATCH 447/706] [DDP] Bucket handling: make first bucket size equal to bucket_cap_mb if it was set (#121640) The fist DDP bucket is always being created of the size of `dist._DEFAULT_FIRST_BUCKET_BYTES` (1 MiB) by default regardless of `bucket_cap_mb`. The proposal is to set `bucket_cap_mb` as the one main bucket size if it was supplied by the user. Pull Request resolved: https://github.com/pytorch/pytorch/pull/121640 Approved by: https://github.com/wanchaol --- torch/nn/parallel/distributed.py | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py index ef6034ade58e..069be22991cd 100644 --- a/torch/nn/parallel/distributed.py +++ b/torch/nn/parallel/distributed.py @@ -548,7 +548,8 @@ class DistributedDataParallel(Module, Joinable): multiple buckets so that gradient reduction of each bucket can potentially overlap with backward computation. :attr:`bucket_cap_mb` controls the bucket size in - MegaBytes (MB). (default: 25) + MebiBytes (MiB). If ``None``, a default size of 25 MiB + will be used. (default: ``None``) find_unused_parameters (bool): Traverse the autograd graph from all tensors contained in the return value of the wrapped module's ``forward`` function. Parameters @@ -631,7 +632,7 @@ def __init__( dim=0, broadcast_buffers=True, process_group=None, - bucket_cap_mb=25, + bucket_cap_mb=None, find_unused_parameters=False, check_reduction=False, gradient_as_bucket_view=False, @@ -788,7 +789,14 @@ def __init__( self.broadcast_bucket_size = int(250 * 1024 * 1024) # reduction bucket size + if bucket_cap_mb is None: + # default case (bucket cap is 25 MiB) + bucket_cap_mb = 25 + self.bucket_bytes_cap_default = True + else: + self.bucket_bytes_cap_default = False self.bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024) + # Whether to perform input tensor CPU to GPU copies on a side-stream self.use_side_stream_for_tensor_copies = ( os.environ.get("PYTORCH_DDP_USE_SIDE_STREAM", "1") == "1" @@ -1156,10 +1164,13 @@ def _ddp_init_helper( if static_graph is True or self.find_unused_parameters is False: bucket_size_limits = [sys.maxsize] else: - bucket_size_limits = [ - dist._DEFAULT_FIRST_BUCKET_BYTES, - self.bucket_bytes_cap, - ] + if self.bucket_bytes_cap_default: + bucket_size_limits = [ + dist._DEFAULT_FIRST_BUCKET_BYTES, + self.bucket_bytes_cap, + ] + else: + bucket_size_limits = [self.bucket_bytes_cap] ( bucket_indices, per_bucket_size_limits, @@ -1195,7 +1206,9 @@ def _ddp_init_helper( param_to_name_mapping, # User can set dist._DEFAULT_FIRST_BUCKET_BYTES to tune DDP first # bucket. - dist._DEFAULT_FIRST_BUCKET_BYTES, + dist._DEFAULT_FIRST_BUCKET_BYTES + if self.bucket_bytes_cap_default + else self.bucket_bytes_cap, ) self.logger = dist.Logger(self.reducer) From 747fc35ff54154ddec2a5ab5661f57c28d65c591 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Thu, 6 Jun 2024 13:22:35 -0700 Subject: [PATCH 448/706] [dynamo] Support if cond on UnspecializedNNModuleVariable and add inline tests (#128158) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128158 Approved by: https://github.com/jansel ghstack dependencies: #128001, #126578 --- test/dynamo/test_inline_inbuilt_nn_modules.py | 62 +++++++++++++++++++ torch/_dynamo/symbolic_convert.py | 8 ++- 2 files changed, 69 insertions(+), 1 deletion(-) create mode 100644 test/dynamo/test_inline_inbuilt_nn_modules.py diff --git a/test/dynamo/test_inline_inbuilt_nn_modules.py b/test/dynamo/test_inline_inbuilt_nn_modules.py new file mode 100644 index 000000000000..f7ba32bc15f3 --- /dev/null +++ b/test/dynamo/test_inline_inbuilt_nn_modules.py @@ -0,0 +1,62 @@ +# Owner(s): ["module: dynamo"] + +from torch._dynamo import config +from torch._dynamo.testing import make_test_cls_with_patches + +try: + from . import ( + test_aot_autograd, + test_functions, + test_higher_order_ops, + test_misc, + test_modules, + # test_repros, + ) +except ImportError: + import test_aot_autograd + import test_functions + import test_higher_order_ops + import test_misc + import test_modules + + +test_classes = {} + + +def make_inline_inbuilt_nn_modules_cls(cls): + suffix = "_inline_inbuilt_nn_modules" + + cls_prefix = "InlineInbuiltNNModules" + + test_class = make_test_cls_with_patches( + cls, + cls_prefix, + suffix, + (config, "inline_inbuilt_nn_modules", True), + xfail_prop="_expected_failure_inline_inbuilt_nn_modules", + ) + + test_classes[test_class.__name__] = test_class + # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING + globals()[test_class.__name__] = test_class + test_class.__module__ = __name__ + return test_class + + +tests = [ + test_misc.MiscTests, + test_functions.FunctionTests, + test_modules.NNModuleTests, + test_higher_order_ops.HigherOrderOpTests, + test_higher_order_ops.FuncTorchHigherOrderOpTests, + test_aot_autograd.AotAutogradFallbackTests, + # test_repros.ReproTests, +] +for test in tests: + make_inline_inbuilt_nn_modules_cls(test) +del test + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 30f28e2ab265..da04fdfa8584 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -101,7 +101,7 @@ PythonModuleVariable, UnknownVariable, ) -from .variables.nn_module import NNModuleVariable +from .variables.nn_module import NNModuleVariable, UnspecializedNNModuleVariable from .variables.tensor import supported_comparison_ops, SymNodeVariable, TensorVariable from .variables.user_defined import ( RemovableHandleVariable, @@ -414,6 +414,12 @@ def inner(self: "InstructionTranslatorBase", inst: Instruction): if push: self.push(value) self.jump(inst) + elif isinstance(value, UnspecializedNNModuleVariable): + mod = value.value + if truth_fn(mod): + if push: + self.push(value) + self.jump(inst) elif isinstance(value, UserDefinedObjectVariable): try: x = value.var_getattr(self, "__bool__") From 3df53c2a8f53f7846ce14e74cf333bbf2ea94296 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Thu, 6 Jun 2024 09:43:38 -0700 Subject: [PATCH 449/706] [dtensor] directly return local_tensor under no_grad (#128145) as titled, skip the autograd function and directly return the local_tensor if it's under no_grad context, this would avoid creating views Pull Request resolved: https://github.com/pytorch/pytorch/pull/128145 Approved by: https://github.com/awgu ghstack dependencies: #128112 --- test/distributed/_tensor/test_dtensor.py | 5 +++++ torch/distributed/_tensor/api.py | 3 +++ 2 files changed, 8 insertions(+) diff --git a/test/distributed/_tensor/test_dtensor.py b/test/distributed/_tensor/test_dtensor.py index e29eede07d87..17a6aebd8f93 100644 --- a/test/distributed/_tensor/test_dtensor.py +++ b/test/distributed/_tensor/test_dtensor.py @@ -331,6 +331,11 @@ def test_to_local(self): except RuntimeError: self.assertEqual(sharded_tensor.grad.stride(), [1, 3 * self.world_size]) + # test the case under no-grad we directly return the local tensor + with torch.no_grad(): + local_no_grad = sharded_tensor.to_local() + assert local_no_grad is sharded_tensor._local_tensor + @with_comms def test_to_local_grad_hint(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) diff --git a/torch/distributed/_tensor/api.py b/torch/distributed/_tensor/api.py index 49fe7267c634..be887f3ce6ca 100644 --- a/torch/distributed/_tensor/api.py +++ b/torch/distributed/_tensor/api.py @@ -418,6 +418,9 @@ def to_local( .. note:: `to_local` is differentiable, the `requires_grad` of the local tensor returned will depend on if the `DTensor` requires_grad or not. """ + if not torch.is_grad_enabled(): + return self._local_tensor + if grad_placements is not None and not isinstance(grad_placements, tuple): grad_placements = tuple(grad_placements) return _ToTorchTensor.apply( From 96806b177702dec48ea63837ff890cb2479674de Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Thu, 6 Jun 2024 21:06:05 -0700 Subject: [PATCH 450/706] [pipelining][doc] Add frontend description and change tracer example (#128070) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128070 Approved by: https://github.com/wconstab, https://github.com/H-Huang --- docs/source/distributed.pipelining.rst | 362 ++++++++++++++++--------- torch/distributed/pipelining/_IR.py | 15 +- 2 files changed, 245 insertions(+), 132 deletions(-) diff --git a/docs/source/distributed.pipelining.rst b/docs/source/distributed.pipelining.rst index 2f4218a0d980..48f66b5d3276 100644 --- a/docs/source/distributed.pipelining.rst +++ b/docs/source/distributed.pipelining.rst @@ -4,184 +4,314 @@ Pipeline Parallelism #################### -.. note:: ``torch.distributed.pipelining`` is a package migrated from the `PiPPy `_ project. It is currently in alpha state and under extensive development. For examples that work with our APIs, please refer to PiPPy's `examples `_ directory. +.. note:: + ``torch.distributed.pipelining`` is currently in alpha state and under + development. API changes may be possible. It was migrated from the `PiPPy + `_ project. + Why Pipeline Parallel? ********************** -One of the most important techniques for advancing the state of the art in deep learning is scaling. Common techniques for scaling neural networks include *data parallelism*, *tensor/operation parallelism*, and *pipeline parallelism* (or *pipelining*). Pipelining is a technique in which the *code* of the model is partitioned and multiple *micro-batches* execute different parts of the model code concurrently. In many cases, pipeline parallelism can be an effective technique for scaling, in particular for large-scale jobs or bandwidth-limited interconnects. To learn more about pipeline parallelism in deep learning, see `this article `_. - -What is ``torch.distributed.pipelining``? -***************************************** +Pipeline Parallelism is one of the **primitive** parallelism for deep learning. +It allows the **execution** of a model to be partitioned such that multiple +**micro-batches** can execute different parts of the model code concurrently. +Pipeline parallelism can be an effective technique for: -.. automodule:: torch.distributed.pipelining +* large-scale training +* bandwidth-limited clusters +* large model inference. -.. currentmodule:: torch.distributed.pipelining +The above scenarios share a commonality that the computation per device cannot +hide the communication of conventional parallelism, for example, the weight +all-gather of FSDP. -While promising for scaling, pipelining is often difficult to implement, requiring intrusive code changes to model code and difficult-to-implement runtime orchestration code. ``torch.distributed.pipelining`` aims to provide **a toolkit that does said things automatically to allow high-productivity scaling of models.** It consists of a **compiler** and a **runtime** stack for easy pipelining of PyTorch models. In particular, it provides the following features: -* Splitting of model code based on your specification. The goal is for the user to provide model code as-is to the system for parallelization, without having to make heavyweight modifications to make parallelism work. The specification is also simple. -* Support for rich pipeline scheduling paradigms, including GPipe, 1F1B, Interleaved 1F1B and Looped BFS. It will be also easy to customize your own schedule under this framework. -* First-class support for cross-host pipeline parallelism, as this is where PP is typically used (over slower interconnects). -* Composability with other PyTorch parallel schemes such as data parallelism (DDP, FSDP) or tensor parallelism (overall, known as "3d parallelism"). +What is ``torch.distributed.pipelining``? +***************************************** -Examples -******** +While promising for scaling, pipelining is often difficult to implement because +it needs to **partition the execution** of a model in addition to model weights. +The partitioning of execution often requires intrusive code changes to your +model. Another aspect of complexity comes from **scheduling micro-batches in a +distributed environment**, with **data flow dependency** considered. -In the `PiPPy `_ repo where this package is migrated from, we provide rich examples based on realistic models. In particular, we show how to apply pipelining without any model code change. You can refer to the `HuggingFace examples directory `_. Popular examples include: `GPT2 `_, and `LLaMA `_. +The ``pipelining`` package provides a toolkit that does said things +**automatically** which allows easy implementation of pipeline parallelism +on **general** models. -Techniques Explained -******************** +It consists of two parts: a +**splitting frontend** and a **distributed runtime**. +The splitting frontend takes your model code as-is, splits it up into "model +partitions", and capture the data-flow relationship. The distributed runtime +executes the pipeline stages on different devices in parallel, handling things +like micro-batch splitting, scheduling, communication, and gradient propagation, +etc. -``torch.distributed.pipelining`` consists of two parts: a *compiler* and a *runtime*. The compiler takes your model code, splits it up, and transforms it into a ``Pipe``, which is a wrapper that describes the model at each pipeline stage and their data-flow relationship. The runtime executes the ``PipelineStage`` in parallel, handling things like micro-batch splitting, scheduling, communication, and gradient propagation, etc. We will cover the APIs for these concepts in this section. +Overall, the ``pipelining`` package provides the following features: -Splitting a Model with ``pipeline`` -=================================== +* Splitting of model code based on simple specification. The goal is to make + parallelism work for your model with **zero model code change**. +* Rich support for pipeline schedules, including GPipe, 1F1B, + Interleaved 1F1B and Looped BFS, and provide the infrastruture for writing + customized schedules. +* First-class support for cross-host pipeline parallelism, as this is where PP + is typically used (over slower interconnects). +* Composability with other PyTorch parallel techniques such as data parallel + (DDP, FSDP) or tensor parallel. The `TorchTitan + `_ project demonstrates a "3D parallel" + application on the Llama model. -To see how we can split a model into a pipeline, let's first take an example trivial neural network: -.. code-block:: python +Step 1: choose the frontend that fits your need +*********************************************** - import torch +The ``pipelining`` package provides two frontends for two different use cases. +You can make your choice based on whether you have: - class MyNetworkBlock(torch.nn.Module): - def __init__(self, in_dim, out_dim): - super().__init__() - self.lin = torch.nn.Linear(in_dim, out_dim) +* a full model, or +* module constructor for each stage. - def forward(self, x): - x = self.lin(x) - x = torch.relu(x) - return x +Frontend 1: the ``pipeline`` API -- if you have a full model +============================================================ - class MyNetwork(torch.nn.Module): - def __init__(self, in_dim, layer_dims): - super().__init__() +If you have a full model and do not want to spend time on modifying it into a +sequence of "model partitions", the ``pipeline`` API is here to help. +Here is a brief example: - prev_dim = in_dim - for i, dim in enumerate(layer_dims): - setattr(self, f'layer{i}', MyNetworkBlock(prev_dim, dim)) - prev_dim = dim - - self.num_layers = len(layer_dims) - # 10 output classes - self.output_proj = torch.nn.Linear(layer_dims[-1], 10) - - def forward(self, x): - for i in range(self.num_layers): - x = getattr(self, f'layer{i}')(x) +.. code-block:: python - return self.output_proj(x) + class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.emb = torch.nn.Embedding(10, 3) + self.layers = torch.nn.ModuleList( + Layer() for _ in range(2) + ) + self.lm = LMHead() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.emb(x) + for layer in self.layers: + x = layer(x) + x = self.lm(x) + return x - in_dim = 512 - layer_dims = [512, 1024, 256] - mn = MyNetwork(in_dim, layer_dims).to(device) +If we print the model, we can see multiple hierarchies, which makes it hard to split by hand:: -This network is written as free-form Python code; it has not been modified for any specific parallelism technique. + Model( + (emb): Embedding(10, 3) + (layers): ModuleList( + (0-1): 2 x Layer( + (lin): Linear(in_features=3, out_features=3, bias=True) + ) + ) + (lm): LMHead( + (proj): Linear(in_features=3, out_features=3, bias=True) + ) + ) -Let us see our usage of the ``pipeline`` interface: +Let us see how the ``pipeline`` API works: .. code-block:: python - from torch.distributed.pipelining import annotate_split_points, pipeline, Pipe, SplitPoint + from torch.distributed.pipelining import pipeline, SplitPoint - annotate_split_points(mn, {'layer0': SplitPoint.END, - 'layer1': SplitPoint.END}) - - batch_size = 32 - example_input = torch.randn(batch_size, in_dim, device=device) - chunks = 4 + x = torch.LongTensor([1, 2, 4, 5]) + pipe = pipeline( + module=mod, + num_chunks=1, + example_args=(x,), + split_spec={ + "layers.1": SplitPoint.BEGINNING, + } + ) - pipe = pipeline(mn, chunks, example_args=(example_input,)) - print(pipe) +The ``pipeline`` API splits your model given a ``split_spec``, where +``SplitPoint.BEGINNING`` stands for adding a split point +*before* execution of certain submodule in the ``forward`` function, and +similarly, ``SplitPoint.END`` for split point *after* such. -:: +If we ``print(pipe)``, we can see:: - ************************************* pipe ************************************* GraphModule( (submod_0): GraphModule( - (layer0): InterpreterModule( - (lin): InterpreterModule() + (emb): InterpreterModule() + (layers): Module( + (0): InterpreterModule( + (lin): InterpreterModule() + ) ) ) (submod_1): GraphModule( - (layer1): InterpreterModule( - (lin): InterpreterModule() + (layers): Module( + (1): InterpreterModule( + (lin): InterpreterModule() + ) ) - ) - (submod_2): GraphModule( - (layer2): InterpreterModule( - (lin): InterpreterModule() + (lm): InterpreterModule( + (proj): InterpreterModule() ) - (output_proj): InterpreterModule() ) ) - def forward(self, arg8_1): - submod_0 = self.submod_0(arg8_1); arg8_1 = None + def forward(self, x): + submod_0 = self.submod_0(x); x = None submod_1 = self.submod_1(submod_0); submod_0 = None - submod_2 = self.submod_2(submod_1); submod_1 = None - return (submod_2,) - -So what's going on here? First, ``pipeline`` turns our model into a directed acyclic graph (DAG) by tracing the model. Then, it groups together the operations and parameters into *pipeline stages*. Stages are represented as ``submod_N`` submodules, where ``N`` is a natural number. + return (submod_1,) -We used ``annotate_split_points`` to specify that the code should be split and the end of ``layer0`` and ``layer1``. Our code has thus been split into *three* pipeline stages. Our library also provides ``SplitPoint.BEGINNING`` if a user wants to split before certain annotation point. -While the ``annotate_split_points`` API gives users a way to specify the split points without modifying the model, our library also provides an API for in-model annotation: ``pipe_split()``. For details, you can read `this example `_. +The "model partitions" are represented by submodules (``submod_0``, +``submod_1``), each of which is reconstructed with original model operations +and hierarchies. In addition, a "root-level" ``forward`` function is +reconstructed to capture the data flow between those partitions. Such data flow +will be replayed by the pipeline runtime later, in a distributed fashion. -This covers the basic usage of the ``Pipe`` API. For more information, please see the documentation. +The ``Pipe`` object provides a method for retrieving the "model partitions": -Using ``PipelineSchedule`` for Execution -======================================== +.. code-block:: python -After transforming the model into a ``Pipe`` representation, we can run its stages in a distributed *runtime*. This can be done in two steps: -* instantiate a ``PipelineStage`` from a stage module of ``Pipe``; -* run the ``PipelineStage`` according to a ``PipelineSchedule``. + stage_mod : nn.Module = pipe.get_stage_module(stage_idx) -First off, let us instantiate a ``PipelineStage`` instance: +You can also create a distributed stage runtime on a device using ``Pipe``: .. code-block:: python - # We are using `torchrun` to run this example with multiple processes. - # `torchrun` defines two environment variables: `RANK` and `WORLD_SIZE`. - rank = int(os.environ["RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) + from torch.distributed.pipelining import PipelineStage - # Initialize distributed environment - import torch.distributed as dist - dist.init_process_group(rank=rank, world_size=world_size) + stage = PipelineStage(pipe, stage_idx, device) - # Pipeline stage is our main pipeline runtime. It takes in the pipe object, - # the rank of this process, and the device. - from torch.distributed.pipelining import PipelineStage - stage = PipelineStage(pipe, rank, device) +.. note:: + The ``pipeline`` frontend uses a tracer (``torch.export``) to capture your + model into a single graph. If your model is not full-graph'able, you can use + our manual frontend below. + +Frontend 2: ``ManualPipelineStage`` -- if you already have module for each stage +================================================================================ -We can now attach the ``PipelineStage`` to a pipeline schedule, GPipe for example, and run with data: +If you already have the module for each stage, you can skip the pipeline split +step above and directly connect to our runtime offering: ``ManualPipelineStage``. +The ``ManualPipelineStage`` wraps your stage module given a distributed context, +i.e. a ``ProcessGroup`` along the pipeline dimension. + +TODO: manual example here + + +Step 2: use ``PipelineSchedule`` for execution +********************************************** + +We can now attach the ``PipelineStage`` to a pipeline schedule, and run the +schedule with input data. Here is a GPipe example: .. code-block:: python from torch.distributed.pipelining import ScheduleGPipe - schedule = ScheduleGPipe(stage, chunks) - # Input data + # Create a schedule + schedule = ScheduleGPipe(stage, n_microbatches) + + # Input data (whole batch) x = torch.randn(batch_size, in_dim, device=device) - # Run the pipeline with input `x`. Divide the batch into 4 micro-batches - # and run them in parallel on the pipeline + # Run the pipeline with input `x` + # `x` will be divided into microbatches automatically if rank == 0: schedule.step(x) else: output = schedule.step() -Note that since we split our model into three stages, we must run this script with three workers. For this example, we will use ``torchrun`` to run multiple processes within a single machine for demonstration purposes. We can collect up all of the code blocks above into a file named `example.py `_ and then run it with ``torchrun`` like so: +Note that the above code needs to be launched for each worker, thus we use a +launcher service to launch multiple processes: .. code-block:: bash - torchrun --nproc_per_node=3 example.py + torchrun --nproc_per_node=2 example.py + + +Hugging Face Examples +********************* + +In the `PiPPy `_ repo where this package was +original created, we kept examples based on unmodified Hugging Face models. +See the `examples/huggingface +`_ directory. + +Examples include: + +* `GPT2 `_ +* `Llama `_ + + +Technical Deep Dive +******************* + +How does the ``pipeline`` API split a model? +============================================ + +First, the ``pipeline`` API turns our model into a directed acyclic graph (DAG) +by tracing the model. It traces the model using ``torch.export`` -- a PyTorch 2 +full-graph capturing tool. + +Then, it groups together the **operations and parameters** needed by a stage +into a reconstructed submodule: ``submod_0``, ``submod_1``, ... + +Different from conventional submodule access methods like ``Module.children()``, +the ``pipeline`` API does not only cut the module structure of your model, but +also the **forward** function of your model. + +This is necessary because model structure like ``Module.children()`` merely +captures information during ``Module.__init__()``, and does not capture any +information about ``Module.forward()``. Said differently, ``Module.children()`` +lacks information about the following aspects key to pipelininig: -Pipeline Transformation APIs +* Exectuion order of child modules in ``forward`` +* Activation flows between child modules +* Whether there are any functional operators between child modules (for example, + ``relu`` or ``add`` operations will not be captured by ``Module.children()``). + +The ``pipeline`` API, on the contrary, makes sure that the ``forward`` behavior +is truly preserved. It also captures the activation flow between the partitions, +helping the distributed runtime to make correct send/receive calls without human +intervention. + +Another flexibility of the ``pipeline`` API is that split points can be at +arbitrary hierarchy of your model. In the split partitions, the original model +hierarchy related to that partition will be reconstructed at no cost of yours. +At a result, fully-qualified names (FQNs) pointing to a submodule or parameter +would be still valid, and services that relies on FQNs (such as FSDP, TP or +checkpointing) can still run with your partitioned modules with almost zero code +change. + + +Implementing Your Own Schedule +****************************** + +You can implement your own pipeline schedule by extending one of the following two class: + +* ``PipelineScheduleSingle`` +* ``PipelineScheduleMulti`` + +``PipelineScheduleSingle`` is for schedules that assigns *only one* stage per rank. +``PipelineScheduleMulti`` is for schedules that assigns multiple stages per rank. + +For example, ``ScheduleGPipe`` and ``Schedule1F1B`` are subclasses of ``PipelineScheduleSingle``. +Whereas, ``ScheduleInterleaved1F1B`` and ``ScheduleLoopedBFS`` are subclasses of ``PipelineScheduleMulti``. + +.. currentmodule:: torch.distributed.pipelining.PipelineSchedule + +.. autoclass:: PipelineScheduleSingle + +.. autoclass:: PipelineScheduleMulti + + +API Reference +************* + +.. automodule:: torch.distributed.pipelining + +Model Split APIs ============================ The following set of APIs transform your model into a pipeline representation. @@ -240,23 +370,3 @@ Pipeline Schedules .. autoclass:: ScheduleInterleaved1F1B .. autoclass:: ScheduleLoopedBFS - -Implementing Your Own Schedule -============================== - -You can implement your own pipeline schedule by extending one of the following two class: - -* ``PipelineScheduleSingle`` -* ``PipelineScheduleMulti`` - -``PipelineScheduleSingle`` is for schedules that assigns *only one* stage per rank. -``PipelineScheduleMulti`` is for schedules that assigns multiple stages per rank. - -For example, ``ScheduleGPipe`` and ``Schedule1F1B`` are subclasses of ``PipelineScheduleSingle``. -Whereas, ``ScheduleInterleaved1F1B`` and ``ScheduleLoopedBFS`` are subclasses of ``PipelineScheduleMulti``. - -.. currentmodule:: torch.distributed.pipelining.PipelineSchedule - -.. autoclass:: PipelineScheduleSingle - -.. autoclass:: PipelineScheduleMulti diff --git a/torch/distributed/pipelining/_IR.py b/torch/distributed/pipelining/_IR.py index c7ea787f98b5..0a45c4459f30 100644 --- a/torch/distributed/pipelining/_IR.py +++ b/torch/distributed/pipelining/_IR.py @@ -487,17 +487,17 @@ def _direct_serialization_reduce(self): class Pipe(torch.nn.Module): # Class variables - """ - args_chunk_spec: - Chunking specification for positional inputs. (default: `None`) - kwargs_chunk_spec: - Chunking specification for keyword inputs. (default: `None`) - """ # args_chunk_spec and kwargs_chunk_spec are used to specify how to chunk # inputs. They are used to create microbatched examples before tracing. # See context managers `ArgsChunkSpec` and `KwargsChunkSpec`. # TODO: Do we need to support `_Replicate`? It's unclear, dropping for now. + + # args_chunk_spec: + # Chunking specification for positional inputs. (default: `None`) args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None + + # kwargs_chunk_spec: + # Chunking specification for keyword inputs. (default: `None`) kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None @dataclass @@ -622,6 +622,9 @@ def forward(self, *args, **kwargs): return res def get_stage_module(self, stage_idx: int) -> torch.nn.Module: + """ + Return a stage module corresponding to `stage_idx` of the `pipe`. + """ if stage_idx < 0 or stage_idx >= self.num_stages: raise ValueError(f"Invalid stage index {stage_idx}!") return getattr(self.split_gm, f"submod_{stage_idx}") From e8e0bdf541ef8a688a71fce2c3df7a06b7cc6274 Mon Sep 17 00:00:00 2001 From: Sam Larsen Date: Wed, 5 Jun 2024 20:51:56 +0000 Subject: [PATCH 451/706] [inductor] parallel-compile: call triton_key() before forking (#127639) Summary: A user reported severe slowdown on a workload when using parallel compile. The issue is that in some environments, the process affinity changes after forking such that all forked subprocesses use a single logical processor. Described here: https://github.com/pytorch/pytorch/issues/99625. That requires a separate fix, but during debuging we noticed that we can at least optimize the expensive call to triton_key() before forking. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127639 Approved by: https://github.com/eellison, https://github.com/anijain2305 --- torch/_inductor/async_compile.py | 23 +++++++++++++++++++--- torch/_inductor/compile_worker/__main__.py | 6 ++---- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/torch/_inductor/async_compile.py b/torch/_inductor/async_compile.py index 633946bb4ed8..b8e3d338dd9b 100644 --- a/torch/_inductor/async_compile.py +++ b/torch/_inductor/async_compile.py @@ -47,6 +47,25 @@ kernel_code_log = torch._logging.getArtifactLogger(__name__, "kernel_code") +def pre_fork_setup(): + """ + Setup that must be done prior to forking with a process pool. + """ + # ensure properties have been calculated before processes + # are forked + caching_device_properties() + + # Computing the triton key can be slow. If we call it before fork, + # it will be cached for the forked subprocesses. + try: + from triton.compiler.compiler import triton_key + + triton_key() + except ModuleNotFoundError: + # Might not be installed. + pass + + def caching_device_properties(): for _, device_interface in get_registered_device_interfaces(): if device_interface.is_available(): @@ -115,9 +134,7 @@ def process_pool() -> AnyPool: # Wrapper around ProcessPoolExecutor forks in a new process we control pool = SubprocPool(config.compile_threads) else: - # ensure properties have been calculated before processes - # are forked - caching_device_properties() + pre_fork_setup() ctx = multiprocessing.get_context(config.worker_start_method) pool = ProcessPoolExecutor( config.compile_threads, diff --git a/torch/_inductor/compile_worker/__main__.py b/torch/_inductor/compile_worker/__main__.py index e478a5345675..fc8148f20c5f 100644 --- a/torch/_inductor/compile_worker/__main__.py +++ b/torch/_inductor/compile_worker/__main__.py @@ -3,7 +3,7 @@ import sys import typing -from torch._inductor.async_compile import caching_device_properties +from torch._inductor.async_compile import pre_fork_setup from torch._inductor.compile_worker.subproc_pool import Pipe, SubprocMain from torch._inductor.compile_worker.watchdog import _async_compile_initializer from torch._inductor.runtime.compile_tasks import _set_triton_ptxas_path @@ -34,9 +34,7 @@ def main(): # redirect output of workers to stderr os.dup2(sys.stderr.fileno(), sys.stdout.fileno()) - # ensure properties have been calculated before processes - # are forked - caching_device_properties() + pre_fork_setup() _async_compile_initializer(args.parent) SubprocMain(args.workers, read_fd, write_fd).main() From 6a2bf48cfaef6edc5d23ba0dd83ca6af88bcd67f Mon Sep 17 00:00:00 2001 From: Sam Larsen Date: Wed, 5 Jun 2024 17:04:16 -0700 Subject: [PATCH 452/706] [inductor] subproc parallel-compile: start thread last in init (#128037) Summary: Observed on an internal workload: the helper thread started and attempted to access member variables before they were initialized. Differential Revision: [D58239827](https://our.internmc.facebook.com/intern/diff/D58239827) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128037 Approved by: https://github.com/Skylion007, https://github.com/eellison --- torch/_inductor/compile_worker/subproc_pool.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/compile_worker/subproc_pool.py b/torch/_inductor/compile_worker/subproc_pool.py index f3f8e7b3b3ef..4260ae80e2ab 100644 --- a/torch/_inductor/compile_worker/subproc_pool.py +++ b/torch/_inductor/compile_worker/subproc_pool.py @@ -90,7 +90,6 @@ def __init__(self, nprocs: int): self.write_lock = threading.Lock() self.read_pipe: Pipe = typing.cast(Pipe, self.process.stdout) self.read_thread = threading.Thread(target=self._read_thread, daemon=True) - self.read_thread.start() self.futures_lock = threading.Lock() self.pending_futures: Dict[int, Future[Any]] = {} @@ -98,6 +97,10 @@ def __init__(self, nprocs: int): self.running = True + # Start thread last to ensure all member variables are initialized + # before any access. + self.read_thread.start() + def submit(self, job_fn: Callable[..., Any], *args): if args: job_fn = functools.partial(job_fn, *args) From dc8e3c2e904f22f05646e1d87987459af8845d7b Mon Sep 17 00:00:00 2001 From: Sam Larsen Date: Wed, 5 Jun 2024 17:04:20 -0700 Subject: [PATCH 453/706] [inductor] subproc parallel compile: initialize future before sending work to the pool (#128086) Summary: I got reports of intermittent failures in CI and the logs show errors like this: ``` CRITICAL:concurrent.futures:Future 139789013754560 in unexpected state: FINISHED ``` I can't repro locally, but seems clear that we should initialize the future _before_ sending work to the subprocess pool since it could finish before we call set_running_or_notify_cancel() Differential Revision: [D58239829](https://our.internmc.facebook.com/intern/diff/D58239829) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128086 Approved by: https://github.com/jansel ghstack dependencies: #128037 --- torch/_inductor/compile_worker/subproc_pool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_inductor/compile_worker/subproc_pool.py b/torch/_inductor/compile_worker/subproc_pool.py index 4260ae80e2ab..fbed608b851b 100644 --- a/torch/_inductor/compile_worker/subproc_pool.py +++ b/torch/_inductor/compile_worker/subproc_pool.py @@ -109,11 +109,11 @@ def submit(self, job_fn: Callable[..., Any], *args): with self.futures_lock: job_id = next(self.job_id_count) self.pending_futures[job_id] = future = Future() + future.set_running_or_notify_cancel() with self.write_lock: if not self.running: raise RuntimeError("submit() on closed pool") _send_msg(self.write_pipe, job_id, job_data) - future.set_running_or_notify_cancel() return future def _read_thread(self): From 7e48d6a49719c22d81f024236f8a7f7457c0694b Mon Sep 17 00:00:00 2001 From: laithsakka Date: Thu, 6 Jun 2024 14:58:35 -0700 Subject: [PATCH 454/706] reset dynamo in test_do_not_skip_side_effects unit test loop to avoid dynamo cache limit hit (#127487) fix https://github.com/pytorch/pytorch/issues/127483 When nn module inlining is enabled, all recompilations are considered for the same frame hence we hit the cache limit for test_do_not_skip_side_effects, but without inlining things are different , each time we hit a new Object Model we do not consider that a re-compilation, as explained in https://github.com/pytorch/pytorch/issues/127483 For that test we do not really care about cache size hence i reset dynamo in the main loop. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127487 Approved by: https://github.com/anijain2305 --- test/dynamo/test_skip_non_tensor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/dynamo/test_skip_non_tensor.py b/test/dynamo/test_skip_non_tensor.py index 43fe1ba1ece4..3ced7859cd7e 100644 --- a/test/dynamo/test_skip_non_tensor.py +++ b/test/dynamo/test_skip_non_tensor.py @@ -170,6 +170,8 @@ def test_do_not_skip_side_effects(self): global _variable, _variable_2 for mode in range(1, 7): + torch._dynamo.reset() + _variable = 0 _variable_2 = 0 From 68cc63ae278616641c6346e173147b68fa04bf1a Mon Sep 17 00:00:00 2001 From: laithsakka Date: Thu, 6 Jun 2024 14:58:35 -0700 Subject: [PATCH 455/706] introduce skipIfNNModuleInlined and skip test_cpu_cuda_module_after_dynamo (#128023) see the issue https://github.com/pytorch/pytorch/issues/127636 to for details about the issue, TLDR is that when inlining is enabled, we create a fake tensor while tracing in dynamo and try to perform aten.add.Tensor between two tensor of different types, with out inlining we do not hit that operation during tracing. ``` Failed running call_function (*(FakeTensor(..., size=(20, 20), grad_fn=), FakeTensor(..., device='cuda:0', size=(20, 20))), **{}): Unhandled FakeTensor Device Propagation for aten.add.Tensor, found two different devices cpu, cuda:0 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/128023 Approved by: https://github.com/anijain2305 ghstack dependencies: #127487, #127553 --- test/dynamo/test_minifier.py | 2 ++ torch/testing/_internal/common_utils.py | 25 +++++++++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/test/dynamo/test_minifier.py b/test/dynamo/test_minifier.py index 6412d015d56b..9014be6f7557 100644 --- a/test/dynamo/test_minifier.py +++ b/test/dynamo/test_minifier.py @@ -3,6 +3,7 @@ import torch._dynamo from torch._dynamo.test_minifier_common import MinifierTestBase +from torch.testing._internal.common_utils import skipIfNNModuleInlined requires_cuda = unittest.skipUnless(torch.cuda.is_available(), "requires cuda") @@ -111,6 +112,7 @@ def test_after_dynamo_cuda_accuracy_backend_passes(self): ) # Test that a module with mixed cpu/cuda parts with an error after dynamo can be repro'd + @skipIfNNModuleInlined() @requires_cuda def test_cpu_cuda_module_after_dynamo(self): backend_name = "relu_compile_error_TESTING_ONLY" diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 8e3a66c77929..fbfb5cdfa02b 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -1549,6 +1549,31 @@ def has_corresponding_torch_dtype(np_dtype): torch.complex32: np.complex64 }) +def skipIfNNModuleInlined( + msg="test doesn't currently work with nn module inlining", + condition=torch._dynamo.config.inline_inbuilt_nn_modules, +): # noqa: F821 + def decorator(fn): + if not isinstance(fn, type): + + @wraps(fn) + def wrapper(*args, **kwargs): + if condition: + raise unittest.SkipTest(msg) + else: + fn(*args, **kwargs) + + return wrapper + + assert isinstance(fn, type) + if condition: + fn.__unittest_skip__ = True + fn.__unittest_skip_why__ = msg + + return fn + + return decorator + def skipIfRocm(func=None, *, msg="test doesn't currently work on the ROCm stack"): def dec_fn(fn): reason = f"skipIfRocm: {msg}" From d943357a2165616a729d588a8c75f88232d462ab Mon Sep 17 00:00:00 2001 From: Stonepia Date: Fri, 7 Jun 2024 06:25:44 +0000 Subject: [PATCH 456/706] [XPU] Add xpu support of `make triton` (#126513) This PR is to add XPU support for `make triton`. If a user wishes to use Triton with XPU support, the user needs to install the [intel-xpu-backend-for-triton](https://github.com/intel/intel-xpu-backend-for-triton). This PR allows the user to easily install Triton for xpu backend support: ``` # clone the pytorch repo export USE_XPU=1 make triton ``` The XPU version of triton will always be built from the source. It will cat the commit id from `.ci/docker/ci_commit_pins/triton-xpu.txt`, for example, `b8c64f64c18d8cac598b3adb355c21e7439c21de`. So the final call would be like: ``` pip install --force-reinstall "git+https://github.com/intel/intel-xpu-backend-for-triton@b8c64f64c18d8cac598b3adb355c21e7439c21de#subdirectory=python" ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/126513 Approved by: https://github.com/EikanWang, https://github.com/atalman --- README.md | 1 + scripts/install_triton_wheel.sh | 22 +++++++++++++++++----- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 9123dea20107..aa4638f9ece6 100644 --- a/README.md +++ b/README.md @@ -213,6 +213,7 @@ conda install -c pytorch magma-cuda121 # or the magma-cuda* that matches your C # (optional) If using torch.compile with inductor/triton, install the matching version of triton # Run from the pytorch directory after cloning +# For Intel GPU support, please explicitly `export USE_XPU=1` before running command. make triton ``` diff --git a/scripts/install_triton_wheel.sh b/scripts/install_triton_wheel.sh index 269b80d07599..793c9a604edf 100755 --- a/scripts/install_triton_wheel.sh +++ b/scripts/install_triton_wheel.sh @@ -1,11 +1,23 @@ #!/bin/bash # Updates Triton to the pinned version for this copy of PyTorch BRANCH=$(git rev-parse --abbrev-ref HEAD) -TRITON_VERSION="pytorch-triton==$(cat .ci/docker/triton_version.txt)" -DOWNLOAD_PYTORCH_ORG="https://download.pytorch.org/whl" -if [[ "$BRANCH" =~ .*release.* ]]; then - pip install --index-url ${DOWNLOAD_PYTORCH_ORG}/test/ $TRITON_VERSION +if [[ -z "${USE_XPU}" ]]; then + # Default install from PyTorch source + + TRITON_VERSION="pytorch-triton==$(cat .ci/docker/triton_version.txt)" + DOWNLOAD_PYTORCH_ORG="https://download.pytorch.org/whl" + if [[ "$BRANCH" =~ .*release.* ]]; then + pip install --index-url ${DOWNLOAD_PYTORCH_ORG}/test/ $TRITON_VERSION + else + pip install --index-url ${DOWNLOAD_PYTORCH_ORG}/nightly/ $TRITON_VERSION+$(head -c 10 .ci/docker/ci_commit_pins/triton.txt) + fi else - pip install --index-url ${DOWNLOAD_PYTORCH_ORG}/nightly/ $TRITON_VERSION+$(head -c 10 .ci/docker/ci_commit_pins/triton.txt) + # Always install Triton for XPU from source + + TRITON_XPU_REPO="https://github.com/intel/intel-xpu-backend-for-triton" + TRITON_XPU_COMMIT_ID="$(cat .ci/docker/ci_commit_pins/triton-xpu.txt)" + + # force-reinstall to ensure the pinned version is installed + pip install --force-reinstall "git+${TRITON_XPU_REPO}@${TRITON_XPU_COMMIT_ID}#subdirectory=python" fi From 2ff312359ce392c5220388b3355b53ffb4f3e138 Mon Sep 17 00:00:00 2001 From: "Sun, Jiayi" Date: Thu, 6 Jun 2024 05:36:56 -0700 Subject: [PATCH 457/706] skip hf_T5_generate in dynamic shape test (#121129) As reported in https://github.com/pytorch/pytorch/issues/119434, `hf_T5_generate` failed with dynamic shape testing, we propose to skip the dynamic batch size testing of this model in this PR. * Error msg is ``` File "/home/jiayisun/pytorch/torch/_dynamo/guards.py", line 705, in SHAPE_ENV guards = output_graph.shape_env.produce_guards( File "/home/jiayisun/pytorch/torch/fx/experimental/symbolic_shapes.py", line 3253, in produce_guards raise ConstraintViolationError( torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['inputs_tensor'].size()[0])! For more information, run with TORCH_LOGS="+dynamic". - Not all values of RelaxedUnspecConstraint(L['inputs_tensor'].size()[0]) are valid because L['inputs_tensor'].size()[0] was inferred to be a constant (4). ``` * Root Cause is This error happens while creating guard for this [model script line](https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L561): `scores += position_bias_masked` I run it with TORCH_LOGS="+dynamic" and got the key line : `I0305 00:21:00.849974 140376923287424 torch/fx/experimental/symbolic_shapes.py:3963] [6/0_1] eval Eq(s0, 4) [guard added] at miniconda3/envs/pt2/lib/python3.9/site-packages/transformers/models/t5/modeling_t5.py:561 in forward (_refs/__init__.py:403 in _broadcast_shapes)` The reason for this error is that the batch dimension of `inputs_tensor` in the dynamic batch size test is marked as dynamic shape `s0`, so the batch dimension of `scores` generated by a series of operations with `inputs_tensor` is also `s0`. However, because the function of creating `attention_mask` is not in Dynamo but in python. The batch dimension of `attention_mask` is the real shape `4`, and the batch dimension of `position_bias_masked` generated by a series of operations with `attention_mask` is also the real shape `4`, not the dynamic shape `s0`. The current line of `scores += position_bias_masked` requires creating a guard and check whether the batch dimension of `scores` is always equal to the batch dimension of `position_bias_masked`, Eq(s0, 4), the error happens. So the root cause of this error is that the function of creating `attention_mask` not in Dynamo but in python. The reason why the function of `attention_mask` not in Dynamo is that Dynamo has a graph break on this function (happened in the [model script line](https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py#L476): `is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)`) due to the following error: `torch._dynamo.exc.Unsupported: Tensor.item` Pull Request resolved: https://github.com/pytorch/pytorch/pull/121129 Approved by: https://github.com/leslie-fang-intel, https://github.com/ezyang --- .../cu124/dynamic_aot_eager_torchbench_inference.csv | 2 +- .../cu124/dynamic_inductor_torchbench_inference.csv | 2 +- .../dynamic_aot_eager_torchbench_inference.csv | 2 +- .../dynamic_inductor_torchbench_inference.csv | 2 +- benchmarks/dynamo/common.py | 1 + 5 files changed, 5 insertions(+), 4 deletions(-) diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_torchbench_inference.csv index 431a91d10669..bcdf06917b64 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_torchbench_inference.csv @@ -178,7 +178,7 @@ hf_T5_base,eager_fail_to_run,0 -hf_T5_generate,fail_to_run,5 +hf_T5_generate,pass,5 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_torchbench_inference.csv index f652e5ffa91a..3f60be5afd97 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_torchbench_inference.csv @@ -178,7 +178,7 @@ hf_T5_base,eager_fail_to_run,0 -hf_T5_generate,fail_to_run,5 +hf_T5_generate,pass,5 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv index 431a91d10669..bcdf06917b64 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv @@ -178,7 +178,7 @@ hf_T5_base,eager_fail_to_run,0 -hf_T5_generate,fail_to_run,5 +hf_T5_generate,pass,5 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv index f652e5ffa91a..3f60be5afd97 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv @@ -178,7 +178,7 @@ hf_T5_base,eager_fail_to_run,0 -hf_T5_generate,fail_to_run,5 +hf_T5_generate,pass,5 diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index 466e6b30d0b1..2b685b8926b3 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -143,6 +143,7 @@ class CI(NamedTuple): "pyhpc_equation_of_state", "pyhpc_turbulent_kinetic_energy", "detectron2_fcos_r_50_fpn", + "hf_T5_generate", } # These models currently fail accuracy with eager Adam optimizer From a448b3ae9537c0ae233fb9199a4a221fdffbbeab Mon Sep 17 00:00:00 2001 From: Will Feng Date: Thu, 6 Jun 2024 13:06:26 -0700 Subject: [PATCH 458/706] [Traceable FSDP2] Check hasattr('fsdp_pre_all_gather') only when not compile (#127855) Dynamo doesn't support `hasattr(inner_tensor, "fsdp_post_all_gather")` yet. We will work on this support in Q3. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127855 Approved by: https://github.com/awgu --- torch/distributed/_composable/fsdp/_fsdp_common.py | 3 ++- torch/distributed/_composable/fsdp/_fsdp_param.py | 13 ++++++++++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/torch/distributed/_composable/fsdp/_fsdp_common.py b/torch/distributed/_composable/fsdp/_fsdp_common.py index e7654964144b..f372fcd2e073 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_common.py +++ b/torch/distributed/_composable/fsdp/_fsdp_common.py @@ -6,6 +6,7 @@ from typing import Any, cast, List, Optional, Tuple import torch +import torch._dynamo.compiled_autograd as ca import torch.distributed as dist import torch.nn as nn from torch.distributed._composable.contract import _get_registry @@ -122,7 +123,7 @@ def _from_local_no_grad( it avoids some CPU overhead by avoiding default args and not being differentiable. """ - if not torch._dynamo.compiled_autograd.compiled_autograd_enabled: + if not ca.compiled_autograd_enabled: spec = DTensorSpec( device_mesh, placements, diff --git a/torch/distributed/_composable/fsdp/_fsdp_param.py b/torch/distributed/_composable/fsdp/_fsdp_param.py index 5fed53f4a11a..cf28a8e4fe13 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_param.py +++ b/torch/distributed/_composable/fsdp/_fsdp_param.py @@ -4,6 +4,7 @@ from typing import Any, cast, List, Optional, Sequence, Tuple import torch +import torch._dynamo.compiled_autograd as ca import torch.nn as nn from torch._prims_common import make_contiguous_strides_for @@ -326,7 +327,9 @@ def init_unsharded_param(self): self._extensions_data.clear() return inner_tensor = self._sharded_local_tensor - if hasattr(inner_tensor, "fsdp_post_all_gather"): + if not ca.compiled_autograd_enabled and hasattr( + inner_tensor, "fsdp_post_all_gather" + ): all_gather_outputs = self._unflatten_all_gather_outputs() ( unsharded_tensor, @@ -496,7 +499,9 @@ def free_unsharded_param(self) -> None: def all_gather_inputs(self) -> List[torch.Tensor]: # 1D self._assert_in_states(ShardedState.SHARDED, ShardedState.SHARDED_POST_FORWARD) if self.sharded_state == ShardedState.SHARDED: - if hasattr(self._sharded_local_tensor, "fsdp_pre_all_gather"): + if not ca.compiled_autograd_enabled and hasattr( + self._sharded_local_tensor, "fsdp_pre_all_gather" + ): sharded_local_tensor = self._sharded_local_tensor if self.offload_to_cpu: sharded_local_tensor = sharded_local_tensor.to( @@ -517,7 +522,9 @@ def all_gather_inputs(self) -> List[torch.Tensor]: # 1D ) return [_to_dtype_if_needed(sharded_param_data, self.param_dtype)] elif self.sharded_state == ShardedState.SHARDED_POST_FORWARD: - if hasattr(self._sharded_local_tensor, "fsdp_pre_all_gather"): + if not ca.compiled_autograd_enabled and hasattr( + self._sharded_local_tensor, "fsdp_pre_all_gather" + ): raise NotImplementedError all_gather_input = _to_dtype_if_needed( cast(torch.Tensor, self._sharded_post_forward_param_data), From 190f06d468832662761eedf49deb5ab1e0e0c28e Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Thu, 6 Jun 2024 16:18:56 -0700 Subject: [PATCH 459/706] [pipelining] Lower _configure_data_parallel_mode to stage (#127946) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127946 Approved by: https://github.com/wconstab ghstack dependencies: #127935 --- .../pipelining/PipelineSchedule.py | 7 ---- torch/distributed/pipelining/PipelineStage.py | 40 ++++++++++++------- 2 files changed, 25 insertions(+), 22 deletions(-) diff --git a/torch/distributed/pipelining/PipelineSchedule.py b/torch/distributed/pipelining/PipelineSchedule.py index 2e6856e25151..fabc9377277a 100644 --- a/torch/distributed/pipelining/PipelineSchedule.py +++ b/torch/distributed/pipelining/PipelineSchedule.py @@ -384,9 +384,6 @@ def _step_microbatches( # Delay send waits bwd_sends_to_wait: List[dist.Work] = [] for i in range(self._n_microbatches): - # set library-specific data-parallel config flags to ensure gradient accumulation across microbatches - self._stage._configure_data_parallel_mode(i == self._n_microbatches - 1) - with record_function(f"Backward {i}"): ops = self._stage.get_bwd_recv_ops(i) works = _sorted_batch_p2p(ops, desc="bwd_recv") @@ -663,7 +660,6 @@ def _step_microbatches( for stage in reversed(self._stages): for i in range(self._n_microbatches): - stage._configure_data_parallel_mode(i == self._n_microbatches - 1) with record_function(f"Stage {stage.stage_index} Backward"): ops = stage.get_bwd_recv_ops(i) if ops: @@ -878,9 +874,6 @@ def _step_microbatches( elif computation_type == _ComputationType.BACKWARD: # perform backward computation stage = stage_index_to_stage[stage_index] - stage._configure_data_parallel_mode( - mb_index == self._n_microbatches - 1 - ) loss = self._maybe_get_loss(stage, mb_index) stage.backward_one_chunk(mb_index, loss=loss) ops.extend(stage.get_bwd_send_ops(mb_index)) diff --git a/torch/distributed/pipelining/PipelineStage.py b/torch/distributed/pipelining/PipelineStage.py index 50e31ce7d471..b301c2e6e1ec 100644 --- a/torch/distributed/pipelining/PipelineStage.py +++ b/torch/distributed/pipelining/PipelineStage.py @@ -155,6 +155,10 @@ def __init__( self.grad_recv_info: Dict = {} self.grad_send_info: Optional[List] = None + # Number of backward chunks seen. This is used to determine when to do + # grad reduction in DDP or FSDP. + self._seen_bwd_chunks = 0 + @property def has_backward(self) -> bool: """ @@ -374,6 +378,8 @@ def clear_runtime_states(self) -> None: self.fwd_cache.clear() # Caching chunk outputs for final output merge or reduction self.output_chunks.clear() + # Reset bwd chunk counter + self._seen_bwd_chunks = 0 # Clear grad of input buffers in between schedule steps. This is because # `torch.autograd.backward()` will accumulate gradients into leaf @@ -426,17 +432,6 @@ def _retrieve_recv_grads( grads = self._map_tensor_from_recv_info(recv_infos) return grads - def _configure_data_parallel_mode(self, last_backward: bool): - """ - Whether using PP with FSDP or DDP, there are some runtime differences between the last backward step and the - other steps. Namely, we need to accumulate gradients on previous steps and reduce them on the last step, but - there are additional state-variables and performance considerations depending on the data parallelism used. - This helper should adapt any pipeline parallel schedule to work with common/supported data parallel libraries. - """ - if isinstance(self.submod, FSDPModule): - self.submod.set_is_last_backward(last_backward) - self.submod.set_requires_gradient_sync(last_backward) - def forward_maybe_with_nosync(self, *args, **kwargs): # If submod is wrapped with DDP, we use the `no_sync` context manager to # avoid gradient all-reduce per microbatch @@ -447,9 +442,18 @@ def forward_maybe_with_nosync(self, *args, **kwargs): out_val = self.submod(*args, **kwargs) return out_val - def backward_maybe_with_nosync(self, bwd_kwargs: Dict, bwd_chunk_id: int): + def backward_maybe_with_nosync(self, bwd_kwargs: Dict): + """ + Whether using PP with FSDP or DDP, there are some runtime differences between the last backward step and the + other steps. Namely, we need to accumulate gradients on previous steps and reduce them on the last step, but + there are additional state-variables and performance considerations depending on the data parallelism used. + This helper should adapt any pipeline parallel schedule to work with common/supported data parallel libraries. + """ + last_backward = self._seen_bwd_chunks == self.chunks - 1 + + # If submod is wrapped by DDP if isinstance(self.submod, DistributedDataParallel): - if bwd_chunk_id == self.chunks - 1: + if last_backward: # Last chunk, prepare for gradient reduction # HACK: reaching into DDP implementation details here. Is there a better way? self.submod.reducer.prepare_for_backward( # type: ignore[union-attr, operator] @@ -463,10 +467,16 @@ def backward_maybe_with_nosync(self, bwd_kwargs: Dict, bwd_chunk_id: int): else: with self.submod.no_sync(): # type: ignore[operator] grads_input = stage_backward(**bwd_kwargs) + # If submod is a FSDP module + elif isinstance(self.submod, FSDPModule): + self.submod.set_is_last_backward(last_backward) + self.submod.set_requires_gradient_sync(last_backward) + grads_input = stage_backward(**bwd_kwargs) else: - # Non-DDP submodule, regular backward + # Non-DP submodule, regular backward grads_input = stage_backward(**bwd_kwargs) + self._seen_bwd_chunks += 1 return grads_input def forward_one_chunk( @@ -568,7 +578,7 @@ def backward_one_chunk( "input_values": input_values, } - self.grads_input = self.backward_maybe_with_nosync(bwd_kwargs, bwd_chunk_id) + self.grads_input = self.backward_maybe_with_nosync(bwd_kwargs) logger.debug(f"{self.log_prefix} Backwarded chunk {bwd_chunk_id}") # noqa: G004 def _validate_fwd_input(self, args, kwargs): From 00c6ca44598cfe47e218be8613f05a9fef834fbf Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Wed, 5 Jun 2024 22:27:25 -0700 Subject: [PATCH 460/706] [compiled autograd][cudagraphs] Inputs runtime wrapper to move cpu scalars to cuda (#125382) Most commonly CPU scalars used for philox random seed. Right now, any cpu input will skip cudagraphing the entire graph. We need both the traced graph and the runtime inputs to be cudaified. Pull Request resolved: https://github.com/pytorch/pytorch/pull/125382 Approved by: https://github.com/jansel --- test/inductor/test_compiled_autograd.py | 150 +++++++++++++++++- test/test_autograd.py | 7 + torch/_dynamo/compiled_autograd.py | 70 +++++++- torch/_inductor/compile_fx.py | 3 + .../csrc/dynamo/python_compiled_autograd.cpp | 32 +++- torch/fx/_utils.py | 7 +- 6 files changed, 256 insertions(+), 13 deletions(-) diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index 2daacc308071..776496f9331f 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -1,5 +1,6 @@ # Owner(s): ["module: inductor"] import functools +import io import logging import re import sys @@ -13,6 +14,7 @@ from torch import _inductor as inductor from torch._dynamo import compiled_autograd, config from torch._dynamo.utils import counters +from torch._inductor import config as inductor_config from torch._inductor.test_case import run_tests, TestCase from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA from torch.testing._internal.logging_utils import logs_to_string @@ -457,7 +459,7 @@ def test_inputs_aliasing_bytecode_attr_mutations(self): param_proxy, activ_proxy = proxies buf = activ_proxy * 2 torch.ops.inductor.accumulate_grad_.default(param_proxy, buf) - compiled_fn = compiler.end_capture(buf) + runtime_wrapper, compiled_fn = compiler.end_capture(buf) def bytecode_hook(code, out_code): import dis @@ -494,7 +496,9 @@ def bytecode_hook(code, out_code): torch._dynamo.reset() handle = torch._dynamo.convert_frame.register_bytecode_hook(bytecode_hook) try: - compiled_fn(inputs=[param, activ], sizes=(), hooks=()) + runtime_wrapper( + compiled_fn=compiled_fn, inputs=[param, activ], sizes=(), hooks=() + ) finally: handle.remove() @@ -1658,6 +1662,147 @@ def fn(inputs): out = compiled_fn(activations) self.assertTrue(len(activations) == 0) + @unittest.skipIf(not HAS_CUDA, "requires cuda") + def test_cudagraphs_cpu_division(self): + from torch._dynamo.testing import reduce_to_scalar_loss + + model = torch.nn.Linear(10, 10, dtype=torch.float16).cuda() + inputs = torch.randn(10, 10, dtype=torch.float16).cuda() + out = model(inputs) + loss = reduce_to_scalar_loss(out) + + stderr_msgs = io.StringIO() + with mock.patch("sys.stderr", stderr_msgs), compiled_autograd.enable( + compiler_fn + ): + torch._inductor.config.triton.cudagraphs = True + loss.backward() + torch._inductor.config.triton.cudagraphs = False + + self.assertFalse("skipping cudagraphs" in stderr_msgs.getvalue()) + + def test_cudagraphs_cpu_graph(self): + from torch._dynamo.testing import reduce_to_scalar_loss + + model = torch.nn.Linear(10, 10, dtype=torch.float16) + inputs = torch.randn(10, 10, dtype=torch.float16) + out = model(inputs) + loss = reduce_to_scalar_loss(out) + + with compiled_autograd.enable(compiler_fn): + torch._inductor.config.triton.cudagraphs = True + loss.backward() + torch._inductor.config.triton.cudagraphs = False + + self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) + + @unittest.skipIf(not HAS_CUDA, "requires cuda") + def test_cudagraphs_sdpa(self): + query = torch.rand( + 32, 8, 128, 64, dtype=torch.float16, device="cuda", requires_grad=True + ) + key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") + value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") + out = torch.nn.functional.scaled_dot_product_attention(query, key, value) + + with config.patch(compiled_autograd=True), inductor_config.patch( + "triton.cudagraphs", True + ): + opt_bwd = torch.compile(lambda: out.sum().backward()) + opt_bwd() + + self.assertEqual(counters["compiled_autograd"]["captures"], 1) + self.assertEqual(counters["inductor"]["cudagraph_skips"], 0) + + @unittest.skipIf(not HAS_CUDA, "requires cuda") + def test_cudagraphs_cpu_scalar_used_in_python_custom_op(self): + class MyFn(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + cpu_tensor = torch.tensor(5) + ctx.save_for_backward(x, cpu_tensor) # visible to c++/autograd + ctx.cpu_scalar = 5 # opaque to c++/autograd + return x.sum() + + @staticmethod + def backward(ctx, gO): + x, cpu_tensor = ctx.saved_tensors + expand = gO * torch.ones_like(x) + return expand * cpu_tensor * ctx.cpu_scalar + + x = torch.randn(10, requires_grad=True, device="cuda") + out = MyFn.apply(x) + with config.patch(compiled_autograd=True), inductor_config.patch( + "triton.cudagraphs", True + ): + opt_bwd = torch.compile(lambda: out.backward()) + opt_bwd() + + self.assertEqual(counters["compiled_autograd"]["captures"], 1) + # Compiled autograd lifts custom autograd.Function bwd instead of tracing it. + # Must skip since we do not know if the cpu scalar will be used only in ATen/prim ops. + self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) + + @unittest.skipIf(not HAS_CUDA, "requires cuda") + def test_cudagraphs_cpu_scalar_used_in_cpp_custom_op(self): + cpp_source = """ +struct CustomOpAutogradFunction : public torch::autograd::Function { + static constexpr bool is_traceable = true; + + static torch::Tensor forward( + torch::autograd::AutogradContext* ctx, + const torch::Tensor& x) { + const auto& cpu_tensor = torch::tensor(1); + ctx->save_for_backward({x, cpu_tensor}); + ctx->saved_data["cpu_scalar"] = 1; + return x; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext *ctx, + torch::autograd::variable_list grad_output) { + const auto& saved_variables = ctx->get_saved_variables(); + assert(saved_variables.size() == 2); + torch::Tensor x = saved_variables[0]; + torch::Tensor cpu_tensor = saved_variables[1]; + int cpu_scalar = ctx->saved_data["cpu_scalar"].toInt(); + auto expand = grad_output[0] * torch::ones_like(x); + torch::autograd::variable_list grad_inputs(1); + grad_inputs[0] = expand * cpu_tensor * cpu_scalar; // autograd engine asserts that tensors are on same device + return grad_inputs; + } +}; + +torch::Tensor custom_op_backed_by_autograd_fn(const torch::Tensor& x) { + return CustomOpAutogradFunction::apply(x); +} + +TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) { + m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn); +} + """ + + module = torch.utils.cpp_extension.load_inline( + name="test_cudagraphs_cpu_scalar_used_in_cpp_custom_op", + cpp_sources=cpp_source, + functions="custom_op_backed_by_autograd_fn", + verbose=True, + ) + + x = torch.randn(2, 2, requires_grad=True, device="cuda") + with config.patch(compiled_autograd=True), inductor_config.patch( + "triton.cudagraphs", True + ): + out = torch.ops.test_cudagraphs_cpu_scalar_used_in_cpp_custom_op.custom_op_backed_by_autograd_fn( + x + ) + opt_bwd = torch.compile(lambda: out.sum().backward()) + opt_bwd() + + self.assertEqual(counters["compiled_autograd"]["captures"], 1) + # always safe to move, since we trace into the autograd::function bwd and can see if it's only used by aten ops + self.assertEqual(counters["inductor"]["cudagraph_skips"], 0) + def test_verbose_logs_graph(self): torch._logging.set_logs(compiled_autograd_verbose=True) @@ -1978,6 +2123,7 @@ def wrap_test_class(orig_cls): "test_autograd_function_backed_op", # RuntimeError: compiled_args not implemented "test_setitem", # AssertionError: Tensor-likes are not close! "test_grad_nonleaf_register_hook", # IndexError: list index out of range (NB: x.grad = y where both x and y are input tensors) + "test_scalar_grad_mixed_device", # Fake Tensors aren't propagating device properly for 0-dim grads } if not HAS_CUDA: diff --git a/test/test_autograd.py b/test/test_autograd.py index ecd267dc9f77..c032319fa160 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -9508,6 +9508,13 @@ def f(x): memory_with_hooks = torch.cuda.memory_allocated() self.assertEqual(memory_with_hooks, memory_without_grad) + @unittest.skipIf(not TEST_CUDA, "test requires CUDA") + def test_scalar_grad_mixed_device(self): + x = torch.tensor(1.0, requires_grad=True) + y = torch.randn(2, 2, device="cuda") + out = x * y + out.sum().backward() + def test_multi_grad_all_hooks(self): t1 = torch.rand(2, requires_grad=True) t2 = torch.rand(2, requires_grad=True) diff --git a/torch/_dynamo/compiled_autograd.py b/torch/_dynamo/compiled_autograd.py index 7a87a2c7d575..bbc8d722b7e2 100644 --- a/torch/_dynamo/compiled_autograd.py +++ b/torch/_dynamo/compiled_autograd.py @@ -1,6 +1,6 @@ import contextlib import functools -from typing import List, Optional, TYPE_CHECKING +from typing import Dict, List, Optional, TYPE_CHECKING import torch from torch._dynamo.external_utils import call_backward, call_hook @@ -41,6 +41,10 @@ def cpp_verbose_log_fn(msg: str) -> None: verbose_log.debug(msg) +def snapshot_cudagraph_enabled(): + return torch._inductor.config.triton.cudagraphs + + def maybe_clone(x): if x is not None: return clone_preserve_strides(x) @@ -203,6 +207,52 @@ def post_acc_grad_hook(self, input, hook_id): self.bind_tensors_to_proxies(input, proxies) return input + # Note: [Compiled autograd and cudagraphs] + # Eager autograd backward implements scalars as 0-dim tensors, see DivBackward0::other_. + # When compiled autograd traces those nodes, it lifts the scalar tensors, resulting in a graph + # with some cpu 0-dim tensor inputs. To prevent the entire graph from skipping cudagraph, we move the + # scalars tensors to cuda. This works because ATen/prims ops will accept cuda 0-dim tensors too. + def move_graph_nodes_to_cuda(self, graph) -> List[int]: + to_move: Dict[int, torch.fx.Node] = {} + has_cuda_inputs = False + nodes = list(graph.nodes) + assert nodes[0].target == "inputs" + inputs = nodes[0] + inputs_users = list(inputs.users.keys()) + # the ordering of the nodes should always [inputs, sizes, hooks, getitem, getitem1, ...] + # where getitemi accesses inputs[i] + first_getitem_idx = 3 + assert nodes[first_getitem_idx] == inputs_users[0] + last_getitem_idx = first_getitem_idx + len(inputs_users) - 1 + assert nodes[last_getitem_idx] == inputs_users[-1] + for i, node in enumerate(inputs_users): + if not has_cuda_inputs and node.meta["val"].device.type == "cuda": + has_cuda_inputs = True + continue + + is_cpu = node.meta["val"].device.type == "cpu" + is_scalar = len(node.meta["val"].size()) == 0 + if is_cpu and is_scalar: + node_users = list(node.users.keys()) + if all( + isinstance(user.target, torch._ops.OpOverload) + and user.target.namespace in ("prims", "aten") + for user in node_users + ): + # all users are prims/aten, can move safely + to_move[i] = node + + # only move cpu scalars to cuda if there were cuda activations in this graph, + # this is to handle the case where cudagraphs is enabled on a cpu-only graph + if has_cuda_inputs: + for node in to_move.values(): + node.meta["val"] = node.meta["val"].cuda() + + # return runtime indices we need to move to cuda + return list(to_move.keys()) + + return [] + def end_capture(self, outputs): self.stack.close() self.fx_tracer.create_node( @@ -212,6 +262,10 @@ def end_capture(self, outputs): {}, ) self.reorder_accumulate_grad_nodes() + runtime_inputs_to_move: List[int] = [] + if snapshot_cudagraph_enabled(): + runtime_inputs_to_move = self.move_graph_nodes_to_cuda(self.fx_tracer.graph) + graph = GraphModule( self.fx_tracer.root, self.fx_tracer.graph, "CompiledAutograd" ) @@ -220,13 +274,23 @@ def end_capture(self, outputs): "%s", lazy_format_graph_code("Compiled autograd graph", graph) ) verbose_log.debug( - "%s", lazy_format_graph_code("Compiled autograd graph", graph) + "%s", + lazy_format_graph_code( + "Compiled autograd graph", graph, include_device=True + ), ) trace_structured( "compiled_autograd_graph", payload_fn=lambda: graph.print_readable(print_output=False), ) - return self.compiler_fn(graph) + + def runtime_wrapper(compiled_fn, inputs, sizes, hooks): + for i in runtime_inputs_to_move: + inputs[i] = inputs[i].cuda() + + return compiled_fn(inputs, sizes, hooks) + + return runtime_wrapper, self.compiler_fn(graph) def reorder_accumulate_grad_nodes(self): """ diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index 856335819621..ce7d8f6e9b14 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -24,6 +24,7 @@ utils as dynamo_utils, ) from torch._dynamo.utils import ( + counters, detect_fake_mode, flatten_graph_inputs, lazy_format_graph_code, @@ -534,6 +535,8 @@ def compile_fx_inner( log_cudagraph_skip_and_bump_counter( f"skipping cudagraphs due to {compiled_graph.disabled_cudagraphs_reason}" ) + else: + counters["inductor"]["cudagraph_skips"] += 1 BoxedBool.disable(cudagraphs) # Return the output strides to the caller via TracingContext diff --git a/torch/csrc/dynamo/python_compiled_autograd.cpp b/torch/csrc/dynamo/python_compiled_autograd.cpp index 3a79a7bc6372..6cdce255d7df 100644 --- a/torch/csrc/dynamo/python_compiled_autograd.cpp +++ b/torch/csrc/dynamo/python_compiled_autograd.cpp @@ -186,6 +186,7 @@ struct CacheNode { next.clear(); key_storage.clear(); expected_sizes.clear(); + runtime_wrapper = nullptr; compiled_fn = nullptr; } @@ -193,10 +194,12 @@ struct CacheNode { return next.empty() && !compiled_fn; } - CacheNode() : compiled_fn(nullptr) {} + CacheNode() : runtime_wrapper(nullptr), compiled_fn(nullptr) {} ~CacheNode() { if (!Py_IsInitialized()) { - compiled_fn.release(); // leak on shutdown + // leak on shutdown + runtime_wrapper.release(); + compiled_fn.release(); } } CacheNode(CacheNode&&) = delete; @@ -250,6 +253,7 @@ struct CacheNode { if (!cache_hit) { // we missed cache because static size inputs didn't match; force // recompilation with the varying size input as dynamic + runtime_wrapper = nullptr; compiled_fn = nullptr; } return cache_hit; @@ -298,6 +302,7 @@ struct CacheNode { std::vector key_storage; std::vector expected_sizes; + THPObjectPtr runtime_wrapper; THPObjectPtr compiled_fn; }; @@ -591,12 +596,22 @@ CacheNode* _compiled_autograd_impl( } } - cache->compiled_fn = check(call_end_capture(py_compiler, state.outputs)); + PyObject* res = check(call_end_capture(py_compiler, state.outputs)); + TORCH_CHECK(PyTuple_Check(res), "Expected end_capture to return tuple"); + TORCH_CHECK( + PyTuple_Size(res) == 2, + "Expected end_capture to return tuple of size 2"); + cache->runtime_wrapper = Py_NewRef(PyTuple_GetItem(res, 0)); + TORCH_CHECK( + PyCallable_Check(cache->runtime_wrapper), + "Expected end_capture to return runtime_wrapper"); + cache->compiled_fn = Py_NewRef(PyTuple_GetItem(res, 1)); + TORCH_CHECK( + PyCallable_Check(cache->compiled_fn), + "Expected end_capture to return compiled_fn"); state.debug_asserts(); } // End cache miss region - // TODO(jansel): we should release all the variables and then use a - // boxed calling convention so activation memory can be freed // TODO(jansel): clear grads we will overwrite below if (!graph_task.keep_graph_) { for (auto& call : calls) { @@ -639,7 +654,12 @@ variable_list compiled_autograd( &hooks); THPObjectPtr pyresult(check(PyObject_CallFunctionObjArgs( - cache->compiled_fn.get(), inputs.get(), sizes.get(), hooks.get(), NULL))); + cache->runtime_wrapper.get(), + cache->compiled_fn.get(), + inputs.get(), + sizes.get(), + hooks.get(), + NULL))); variable_list outputs = THPVariable_UnpackList(pyresult); TORCH_INTERNAL_ASSERT(outputs.size() == output_edges.size()); return outputs; diff --git a/torch/fx/_utils.py b/torch/fx/_utils.py index 5f99d698586c..598aeafee2d9 100644 --- a/torch/fx/_utils.py +++ b/torch/fx/_utils.py @@ -5,7 +5,7 @@ from torch._logging import LazyString -def lazy_format_graph_code(name, gm, maybe_id=None): +def lazy_format_graph_code(name, gm, maybe_id=None, **kwargs): """ Returns a LazyString that formats the graph code. """ @@ -16,11 +16,14 @@ def format_name(): else: return name + if "print_output" not in kwargs: + kwargs["print_output"] = False + return LazyString( lambda: _format_graph_code( f"===== {format_name()} =====\n", gm.forward.__code__.co_filename, - gm.print_readable(print_output=False), + gm.print_readable(**kwargs), ) ) From 70724bdbfee0eef6b37a24a7201cff38e38f30db Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Thu, 6 Jun 2024 08:13:19 -0700 Subject: [PATCH 461/706] Bugfix for nondeterminstic torch_key (#128111) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128111 Approved by: https://github.com/oulgen --- torch/_inductor/codecache.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 335cec4d4056..6251513f0119 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -565,9 +565,13 @@ def get_code_hash(roots): module = spec.origin assert module is not None with open(module, "rb") as f: - contents[module] = f.read() - - return hashlib.sha256(pickle.dumps(contents)).digest() + contents[spec.name] = f.read() + hasher = hashlib.sha256() + # Iterate over dict in sorted order since iter_modules may not be deterministic + for name, value in sorted(contents.items()): + hasher.update(name.encode("utf-8")) + hasher.update(value) + return hasher.digest() @functools.lru_cache(None) From 01601ebd4169ed90ce1ac6b70863c3028b2e9755 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Fri, 7 Jun 2024 01:06:55 -0700 Subject: [PATCH 462/706] Retire torch.distributed.pipeline (#127354) Actually retiring module after deprecation warning for a while. The new supported module is: torch.distributed.pipelining. Please migrate. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127354 Approved by: https://github.com/wconstab --- .lintrunner.toml | 24 - .../distributed/pipeline/benchmark_dataset.py | 58 -- benchmarks/distributed/pipeline/pipe.py | 296 ------ docs/source/conf.py | 87 -- docs/source/distributed.pipelining.rst | 12 +- docs/source/distributed.rst | 19 - docs/source/index.rst | 1 - docs/source/pipeline.rst | 85 -- test/allowlist_for_publicAPI.json | 28 - test/distributed/pipeline/sync/LICENSE | 27 - test/distributed/pipeline/sync/__init__.py | 8 - test/distributed/pipeline/sync/conftest.py | 61 -- .../pipeline/sync/skip/__init__.py | 6 - .../pipeline/sync/skip/test_api.py | 52 -- .../pipeline/sync/skip/test_gpipe.py | 126 --- .../sync/skip/test_inspect_skip_layout.py | 118 --- .../pipeline/sync/skip/test_leak.py | 136 --- .../pipeline/sync/skip/test_portal.py | 163 ---- .../pipeline/sync/skip/test_stash_pop.py | 144 --- .../pipeline/sync/skip/test_tracker.py | 145 --- .../sync/skip/test_verify_skippables.py | 165 ---- .../distributed/pipeline/sync/test_balance.py | 240 ----- test/distributed/pipeline/sync/test_bugs.py | 146 --- .../pipeline/sync/test_checkpoint.py | 178 ---- test/distributed/pipeline/sync/test_copy.py | 85 -- .../pipeline/sync/test_deferred_batch_norm.py | 200 ---- .../pipeline/sync/test_dependency.py | 152 ---- .../distributed/pipeline/sync/test_inplace.py | 79 -- .../pipeline/sync/test_microbatch.py | 148 --- test/distributed/pipeline/sync/test_phony.py | 57 -- test/distributed/pipeline/sync/test_pipe.py | 858 ------------------ .../pipeline/sync/test_pipeline.py | 36 - test/distributed/pipeline/sync/test_stream.py | 198 ---- .../pipeline/sync/test_transparency.py | 55 -- test/distributed/pipeline/sync/test_worker.py | 118 --- test/test_public_bindings.py | 2 - test/test_testing.py | 1 - torch/distributed/pipeline/__init__.py | 13 - torch/distributed/pipeline/sync/LICENSE | 27 - torch/distributed/pipeline/sync/__init__.py | 12 - .../pipeline/sync/_balance/__init__.py | 164 ---- .../pipeline/sync/_balance/blockpartition.py | 95 -- .../pipeline/sync/_balance/profile.py | 116 --- .../pipeline/sync/_balance/py.typed | 6 - torch/distributed/pipeline/sync/batchnorm.py | 159 ---- torch/distributed/pipeline/sync/checkpoint.py | 364 -------- torch/distributed/pipeline/sync/copy.py | 108 --- torch/distributed/pipeline/sync/dependency.py | 54 -- torch/distributed/pipeline/sync/microbatch.py | 234 ----- torch/distributed/pipeline/sync/phony.py | 50 - torch/distributed/pipeline/sync/pipe.py | 490 ---------- torch/distributed/pipeline/sync/pipeline.py | 255 ------ torch/distributed/pipeline/sync/py.typed | 6 - .../pipeline/sync/skip/__init__.py | 11 - .../distributed/pipeline/sync/skip/layout.py | 92 -- .../pipeline/sync/skip/namespace.py | 50 - .../distributed/pipeline/sync/skip/portal.py | 231 ----- .../pipeline/sync/skip/skippable.py | 431 --------- .../distributed/pipeline/sync/skip/tracker.py | 180 ---- torch/distributed/pipeline/sync/stream.py | 120 --- torch/distributed/pipeline/sync/utils.py | 38 - torch/distributed/pipeline/sync/worker.py | 132 --- .../pipelining/PipelineSchedule.py | 22 + .../distributed/pipe_with_ddp_test.py | 149 --- .../distributed/pipeline/__init__.py | 0 .../_internal/distributed/rpc_utils.py | 4 - 66 files changed, 28 insertions(+), 7899 deletions(-) delete mode 100644 benchmarks/distributed/pipeline/benchmark_dataset.py delete mode 100644 benchmarks/distributed/pipeline/pipe.py delete mode 100644 docs/source/pipeline.rst delete mode 100644 test/distributed/pipeline/sync/LICENSE delete mode 100644 test/distributed/pipeline/sync/__init__.py delete mode 100644 test/distributed/pipeline/sync/conftest.py delete mode 100644 test/distributed/pipeline/sync/skip/__init__.py delete mode 100644 test/distributed/pipeline/sync/skip/test_api.py delete mode 100644 test/distributed/pipeline/sync/skip/test_gpipe.py delete mode 100644 test/distributed/pipeline/sync/skip/test_inspect_skip_layout.py delete mode 100644 test/distributed/pipeline/sync/skip/test_leak.py delete mode 100644 test/distributed/pipeline/sync/skip/test_portal.py delete mode 100644 test/distributed/pipeline/sync/skip/test_stash_pop.py delete mode 100644 test/distributed/pipeline/sync/skip/test_tracker.py delete mode 100644 test/distributed/pipeline/sync/skip/test_verify_skippables.py delete mode 100644 test/distributed/pipeline/sync/test_balance.py delete mode 100644 test/distributed/pipeline/sync/test_bugs.py delete mode 100644 test/distributed/pipeline/sync/test_checkpoint.py delete mode 100644 test/distributed/pipeline/sync/test_copy.py delete mode 100644 test/distributed/pipeline/sync/test_deferred_batch_norm.py delete mode 100644 test/distributed/pipeline/sync/test_dependency.py delete mode 100644 test/distributed/pipeline/sync/test_inplace.py delete mode 100644 test/distributed/pipeline/sync/test_microbatch.py delete mode 100644 test/distributed/pipeline/sync/test_phony.py delete mode 100644 test/distributed/pipeline/sync/test_pipe.py delete mode 100644 test/distributed/pipeline/sync/test_pipeline.py delete mode 100644 test/distributed/pipeline/sync/test_stream.py delete mode 100644 test/distributed/pipeline/sync/test_transparency.py delete mode 100644 test/distributed/pipeline/sync/test_worker.py delete mode 100644 torch/distributed/pipeline/__init__.py delete mode 100644 torch/distributed/pipeline/sync/LICENSE delete mode 100644 torch/distributed/pipeline/sync/__init__.py delete mode 100644 torch/distributed/pipeline/sync/_balance/__init__.py delete mode 100644 torch/distributed/pipeline/sync/_balance/blockpartition.py delete mode 100644 torch/distributed/pipeline/sync/_balance/profile.py delete mode 100644 torch/distributed/pipeline/sync/_balance/py.typed delete mode 100644 torch/distributed/pipeline/sync/batchnorm.py delete mode 100644 torch/distributed/pipeline/sync/checkpoint.py delete mode 100644 torch/distributed/pipeline/sync/copy.py delete mode 100644 torch/distributed/pipeline/sync/dependency.py delete mode 100644 torch/distributed/pipeline/sync/microbatch.py delete mode 100644 torch/distributed/pipeline/sync/phony.py delete mode 100644 torch/distributed/pipeline/sync/pipe.py delete mode 100644 torch/distributed/pipeline/sync/pipeline.py delete mode 100644 torch/distributed/pipeline/sync/py.typed delete mode 100644 torch/distributed/pipeline/sync/skip/__init__.py delete mode 100644 torch/distributed/pipeline/sync/skip/layout.py delete mode 100644 torch/distributed/pipeline/sync/skip/namespace.py delete mode 100644 torch/distributed/pipeline/sync/skip/portal.py delete mode 100644 torch/distributed/pipeline/sync/skip/skippable.py delete mode 100644 torch/distributed/pipeline/sync/skip/tracker.py delete mode 100644 torch/distributed/pipeline/sync/stream.py delete mode 100644 torch/distributed/pipeline/sync/utils.py delete mode 100644 torch/distributed/pipeline/sync/worker.py delete mode 100644 torch/testing/_internal/distributed/pipe_with_ddp_test.py delete mode 100644 torch/testing/_internal/distributed/pipeline/__init__.py diff --git a/.lintrunner.toml b/.lintrunner.toml index 1d7c00a2c772..874a553ee9bc 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1532,28 +1532,6 @@ exclude_patterns = [ 'torch/distributed/optim/post_localSGD_optimizer.py', 'torch/distributed/optim/utils.py', 'torch/distributed/optim/zero_redundancy_optimizer.py', - 'torch/distributed/pipeline/__init__.py', - 'torch/distributed/pipeline/sync/__init__.py', - 'torch/distributed/pipeline/sync/_balance/__init__.py', - 'torch/distributed/pipeline/sync/_balance/blockpartition.py', - 'torch/distributed/pipeline/sync/_balance/profile.py', - 'torch/distributed/pipeline/sync/batchnorm.py', - 'torch/distributed/pipeline/sync/checkpoint.py', - 'torch/distributed/pipeline/sync/copy.py', - 'torch/distributed/pipeline/sync/dependency.py', - 'torch/distributed/pipeline/sync/microbatch.py', - 'torch/distributed/pipeline/sync/phony.py', - 'torch/distributed/pipeline/sync/pipe.py', - 'torch/distributed/pipeline/sync/pipeline.py', - 'torch/distributed/pipeline/sync/skip/__init__.py', - 'torch/distributed/pipeline/sync/skip/layout.py', - 'torch/distributed/pipeline/sync/skip/namespace.py', - 'torch/distributed/pipeline/sync/skip/portal.py', - 'torch/distributed/pipeline/sync/skip/skippable.py', - 'torch/distributed/pipeline/sync/skip/tracker.py', - 'torch/distributed/pipeline/sync/stream.py', - 'torch/distributed/pipeline/sync/utils.py', - 'torch/distributed/pipeline/sync/worker.py', 'torch/distributed/remote_device.py', 'torch/distributed/rendezvous.py', 'torch/distributed/rpc/__init__.py', @@ -1847,8 +1825,6 @@ exclude_patterns = [ 'torch/testing/_internal/distributed/nn/__init__.py', 'torch/testing/_internal/distributed/nn/api/__init__.py', 'torch/testing/_internal/distributed/nn/api/remote_module_test.py', - 'torch/testing/_internal/distributed/pipe_with_ddp_test.py', - 'torch/testing/_internal/distributed/pipeline/__init__.py', 'torch/testing/_internal/distributed/rpc/__init__.py', 'torch/testing/_internal/distributed/rpc/dist_autograd_test.py', 'torch/testing/_internal/distributed/rpc/dist_optimizer_test.py', diff --git a/benchmarks/distributed/pipeline/benchmark_dataset.py b/benchmarks/distributed/pipeline/benchmark_dataset.py deleted file mode 100644 index 3cd22e9a468d..000000000000 --- a/benchmarks/distributed/pipeline/benchmark_dataset.py +++ /dev/null @@ -1,58 +0,0 @@ -import torch -from torch.utils.data import Dataset - - -def collate_sentences_lm(samples): - if len(samples) == 0: - return {} - - id = torch.LongTensor([s["id"] for s in samples]) - src_tokens = torch.stack([s["source"] for s in samples], 0) - tgt_tokens = torch.stack([s["target"] for s in samples], 0) - ntokens = len(samples) * len(samples[0]["target"]) - src_lengths = torch.LongTensor([len(samples[0]["source"])] * len(samples)) - - batch = { - "id": id, - "nsentences": len(samples), - "ntokens": ntokens, - "input": src_tokens, - "target": tgt_tokens, - } - return batch - - -class BenchmarkLMDataset(Dataset): - """ - Dataset to benchmark a translation like seq2seq task. - Args: - vocab_size (int, optional): size of the vocabulary (default 10000). - max_source_positions (int, optional): max number of tokens in the - source sentence (default: 1024). - total_samples (int, optional): the total number of rows in the - dataset (default: 10000). - """ - - def __init__( - self, - vocab_size=10000, - max_source_positions=1024, - total_samples=10000, - ): - self.vocab_size = vocab_size - self.max_source_positions = max_source_positions - self.total_samples = total_samples - self.sizes = [self.max_source_positions] * self.total_samples - - def __getitem__(self, index): - length = self.sizes[index] - source = torch.randint(1, self.vocab_size, (length,)) - target = source.clone() - return { - "id": index, - "source": source, - "target": target, - } - - def __len__(self): - return self.total_samples diff --git a/benchmarks/distributed/pipeline/pipe.py b/benchmarks/distributed/pipeline/pipe.py deleted file mode 100644 index c465c2488565..000000000000 --- a/benchmarks/distributed/pipeline/pipe.py +++ /dev/null @@ -1,296 +0,0 @@ -import argparse -import math -import os -import time - -from benchmark_dataset import BenchmarkLMDataset, collate_sentences_lm - -import torch -import torch.nn as nn -from torch.distributed import rpc - -from torch.distributed.pipeline.sync import Pipe -from torch.distributed.pipeline.sync.utils import partition_model -from torch.optim import Adam -from torch.utils.data import DataLoader - - -def sizeof_fmt(num, suffix="B"): - for unit in ["", "Ki", "Mi", "Gi", "Ti"]: - if abs(num) < 1024.0: - return f"{num:3.2f}{unit}B" - num /= 1024.0 - - -def init_random_seed(seed: int): - import numpy - - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - numpy.random.seed(seed) - - -iteration_count = 0 - - -class EmbeddingLayer(nn.Embedding): - def __init__(self, ntoken, ninp, initrange): - super().__init__(ntoken, ninp) - self.ninp = ninp - nn.init.uniform_(self.weight, -initrange, initrange) - - def forward(self, src): - return super().forward(src) * math.sqrt(self.ninp) - - -class PositionalEncodingLayer(nn.Module): - def __init__(self, d_model, dropout=0.1, max_len=5000): - super().__init__() - self.dropout = nn.Dropout(p=dropout) - - pe = torch.zeros(max_len, d_model) - position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) - div_term = torch.exp( - torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model) - ) - pe[:, 0::2] = torch.sin(position * div_term) - pe[:, 1::2] = torch.cos(position * div_term) - pe = pe.unsqueeze(0).transpose(0, 1) - self.register_buffer("pe", pe) - - def forward(self, x): - x = x + self.pe[: x.size(0), :] - return self.dropout(x) - - -class TransformerDecoderLayer(nn.TransformerEncoderLayer): - """Though this class inherits from torch.nn.TransformerEncoderLayer, - it functions as a decoder in this model""" - - def __init__(self, ninp, nhead, nhid, droupout): - super().__init__(ninp, nhead, nhid, droupout) - self.src_mask = None - - def forward(self, src): - global iteration_count - iteration_count += 1 - - if self.src_mask is None or self.src_mask.size(0) != len(src): - device = src.device - mask = nn.Transformer.generate_square_subsequent_mask(len(src)).to(device) - self.src_mask = mask - - return super().forward(src, self.src_mask) - - -class LinearLayer(nn.Linear): - def __init__(self, ninp, ntoken, initrange): - super().__init__(ninp, ntoken) - nn.init.zeros_(self.bias) - nn.init.uniform_(self.weight, -initrange, initrange) - - -class TransformerLMSequential(nn.Sequential): - """A small language model based on the design of GPT-2 using nn.Sequential - for compatibility with Pipe""" - - def __init__(self, ntokens, ninp, nhead, nhid, dropout, initrange, ndecoder): - layers = [ - EmbeddingLayer(ntokens, ninp, initrange), - PositionalEncodingLayer(ninp, dropout), - ] - for _ in range(ndecoder): - layers.append(TransformerDecoderLayer(ninp, nhead, nhid, dropout)) - - layers.append(LinearLayer(ninp, ntokens, initrange)) - super().__init__(*layers) - - -def make_model(args, device, ntokens): - ninp = 2048 # embedding dimension - nhid = ( - 2048 # the dimension of the feedforward network model in nn.TransformerEncoder - ) - nhead = 32 # the number of heads in the multiheadattention models - dropout = 0 - initrange = 0.1 - ndecoder = args.num_decoder_layers - - model = TransformerLMSequential( - ntokens, ninp, nhead, nhid, dropout, initrange, ndecoder - ).to(device) - - criterion = nn.CrossEntropyLoss() - lr = 0.01 # learning rate - - def make_adam(model): - return Adam(model.parameters(), lr=lr) - - optimizer = make_adam - - return model, criterion, optimizer - - -def train(lm_dataloader, model, criterion, optimizer, vocab_size, args): - model.train() - - vocab_size = 10000 - total_loss = 0.0 - start_time = time.time() - word_counter = 0 - - optimizer = optimizer(model) - - def get_first_device(model): - if model.devices: - return model.devices[0] - else: - return torch.cuda.current_device() - - def get_last_device(model): - if model.devices: - return model.devices[-1] - else: - return torch.cuda.current_device() - - print( - f"Number of parameters for model: {sum(p.numel() for p in model.parameters())}" - ) - for i, batch in enumerate(lm_dataloader): - bi = batch["input"] - if args.max_batch and i > args.max_batch: - break - optimizer.zero_grad() - try: - tmp = batch["input"].to(get_first_device(model)) - output = model(tmp).local_value() - except Exception as e: - raise RuntimeError( - f"training failed on {torch.distributed.get_rank()}" - ) from e - - target = batch["target"].to(get_last_device(model)) - output = output.to(target.device) - - loss = criterion(output.view(-1, vocab_size), target.view(-1)) - loss.backward() - del target - del output - - torch.nn.utils.clip_grad_value_(model.parameters(), 0.05) - optimizer.step() - - total_loss += loss.item() - log_interval = 1 - word_counter += batch["ntokens"] - if i % log_interval == 0 and i > 0: - cur_loss = total_loss / log_interval - elapsed = time.time() - start_time - print( - f"| batch {i:5d} | wps {word_counter / elapsed:5.2f} | loss {cur_loss:5.2f} | ppl {math.exp(cur_loss):8.2f}" - ) - word_counter = 0 - total_loss = 0 - start_time = time.time() - - print("Peak memory usage for GPUs: ", end="") - for i in range(len(model.devices)): - print( - f"cuda:{i}: {sizeof_fmt(torch.cuda.memory_stats(i)['allocated_bytes.all.peak'])}, ", - end="", - ) - print() - - -def generate_balance(num_devices, num_layers): - balance = [] - layers_assigned = 0 - for i in range(num_devices): - x = (num_layers - layers_assigned) / (num_devices - i) - if x.is_integer(): - balance.append(int(x)) - layers_assigned += x - else: - balance.append(math.ceil(x)) - layers_assigned += math.ceil(x) - return balance - - -def make_model_and_data(args, device): - device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - vocab_size = 10000 - model, criterion, optimizer = make_model(args, device, vocab_size) - lm_dataset = BenchmarkLMDataset() - lm_dataloader = DataLoader( - lm_dataset, - batch_size=args.batch_size, - shuffle=True, - num_workers=0, - collate_fn=collate_sentences_lm, - ) - return { - "model": model, - "criterion": criterion, - "optimizer": optimizer, - "data": lm_dataloader, - "vocab_size": vocab_size, - } - - -def bench_single_process(args): - os.environ.update({"MASTER_ADDR": args.host}) - os.environ.update({"MASTER_PORT": "10638"}) - - rpc.init_rpc( - "worker", - rank=0, - world_size=1, - ) - - num_devices = torch.cuda.device_count() if torch.cuda.is_available() else 1 - num_devices = min(args.num_devices, num_devices) - assert num_devices > 0 - init_random_seed(0) - device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - - blob = make_model_and_data(args, None) - model = blob["model"] - - balance = generate_balance(num_devices, len(model)) - model = partition_model(model, balance) - p = Pipe(model, chunks=args.chunks, checkpoint=args.checkpoint) - del model - del blob["model"] - - train( - blob["data"], p, blob["criterion"], blob["optimizer"], blob["vocab_size"], args - ) - - -parser = argparse.ArgumentParser(description="benchmark") -parser.add_argument("--host", "-o", type=str, default="localhost", help="hostname") -parser.add_argument( - "--chunks", type=int, default=4, help="number of microbatches per batch" -) -parser.add_argument("--batch-size", type=int, default=8, help="size of a batch") -parser.add_argument("--max-batch", type=int, default=10, help="Max number of batches") -parser.add_argument( - "--num-decoder-layers", - type=int, - default=10, - help="Number of decoder layers in the model", -) -parser.add_argument( - "--checkpoint", - default="except_last", - choices=["always", "except_last", "never"], - help="Checkpointing strategy for pipe", -) -parser.add_argument( - "--num-devices", type=int, default=4, help="Number of GPU devices to use" -) - -if __name__ == "__main__": - args = parser.parse_args() - print(f"Running benchmark with args: {args}") - bench_single_process(args) diff --git a/docs/source/conf.py b/docs/source/conf.py index ef492f17c506..4f73c111cb23 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -606,47 +606,6 @@ # torch.distributed.optim.utils "as_functional_optim", "register_functional_optim", - # torch.distributed.pipeline.sync.checkpoint - "checkpoint", - "enable_checkpointing", - "enable_recomputing", - "is_checkpointing", - "is_recomputing", - "restore_rng_states", - "save_rng_states", - # torch.distributed.pipeline.sync.dependency - "fork", - "join", - # torch.distributed.pipeline.sync.microbatch - "check", - "gather", - "scatter", - # torch.distributed.pipeline.sync.phony - "get_phony", - # torch.distributed.pipeline.sync.skip.layout - "inspect_skip_layout", - # torch.distributed.pipeline.sync.skip.tracker - "current_skip_tracker", - "use_skip_tracker", - # torch.distributed.pipeline.sync.stream - "as_cuda", - "current_stream", - "default_stream", - "get_device", - "is_cuda", - "new_stream", - "record_stream", - "use_device", - "use_stream", - "wait_stream", - # torch.distributed.pipeline.sync.utils - "partition_model", - # torch.distributed.pipeline.sync.worker - "create_workers", - "spawn_workers", - "worker", - # torch.distributed.pipelining.PipelineSchedule - "step", # torch.distributed.rendezvous "register_rendezvous_handler", "rendezvous", @@ -2650,52 +2609,6 @@ "PostLocalSGDOptimizer", # torch.distributed.optim.zero_redundancy_optimizer "ZeroRedundancyOptimizer", - # torch.distributed.pipeline.sync.batchnorm - "DeferredBatchNorm", - # torch.distributed.pipeline.sync.checkpoint - "Checkpoint", - "Checkpointing", - "Context", - "Function", - "Recompute", - "ThreadLocal", - # torch.distributed.pipeline.sync.copy - "Context", - "Copy", - "Wait", - # torch.distributed.pipeline.sync.dependency - "Fork", - "Join", - # torch.distributed.pipeline.sync.microbatch - "Batch", - "NoChunk", - # torch.distributed.pipeline.sync.pipe - "BalanceError", - "Pipe", - "PipeSequential", - "WithDevice", - # torch.distributed.pipeline.sync.pipeline - "Pipeline", - # torch.distributed.pipeline.sync.skip.layout - "SkipLayout", - # torch.distributed.pipeline.sync.skip.namespace - "Namespace", - # torch.distributed.pipeline.sync.skip.portal - "Context", - "Portal", - "PortalBlue", - "PortalCopy", - "PortalOrange", - # torch.distributed.pipeline.sync.skip.skippable - "Skippable", - # torch.distributed.pipeline.sync.skip.tracker - "SkipTracker", - "SkipTrackerThroughPotals", - "ThreadLocal", - # torch.distributed.pipeline.sync.stream - "CPUStreamType", - # torch.distributed.pipeline.sync.worker - "Task", # torch.distributed.rpc.api "AllGatherStates", "RRef", diff --git a/docs/source/distributed.pipelining.rst b/docs/source/distributed.pipelining.rst index 48f66b5d3276..a8203a5f3b2c 100644 --- a/docs/source/distributed.pipelining.rst +++ b/docs/source/distributed.pipelining.rst @@ -299,12 +299,6 @@ You can implement your own pipeline schedule by extending one of the following t For example, ``ScheduleGPipe`` and ``Schedule1F1B`` are subclasses of ``PipelineScheduleSingle``. Whereas, ``ScheduleInterleaved1F1B`` and ``ScheduleLoopedBFS`` are subclasses of ``PipelineScheduleMulti``. -.. currentmodule:: torch.distributed.pipelining.PipelineSchedule - -.. autoclass:: PipelineScheduleSingle - -.. autoclass:: PipelineScheduleMulti - API Reference ************* @@ -370,3 +364,9 @@ Pipeline Schedules .. autoclass:: ScheduleInterleaved1F1B .. autoclass:: ScheduleLoopedBFS + +.. autoclass:: PipelineScheduleSingle + :members: + +.. autoclass:: PipelineScheduleMulti + :members: diff --git a/docs/source/distributed.rst b/docs/source/distributed.rst index 0b091d567031..f4c73b9381e5 100644 --- a/docs/source/distributed.rst +++ b/docs/source/distributed.rst @@ -876,9 +876,6 @@ If you are running single node training, it may be convenient to interactively b .. py:module:: torch.distributed.nn.api .. py:module:: torch.distributed.nn.jit .. py:module:: torch.distributed.nn.jit.templates -.. py:module:: torch.distributed.pipeline -.. py:module:: torch.distributed.pipeline.sync -.. py:module:: torch.distributed.pipeline.sync.skip .. py:module:: torch.distributed.tensor .. py:module:: torch.distributed.algorithms.ddp_comm_hooks.ddp_zero_hook .. py:module:: torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks @@ -964,22 +961,6 @@ If you are running single node training, it may be convenient to interactively b .. py:module:: torch.distributed.optim.post_localSGD_optimizer .. py:module:: torch.distributed.optim.utils .. py:module:: torch.distributed.optim.zero_redundancy_optimizer -.. py:module:: torch.distributed.pipeline.sync.batchnorm -.. py:module:: torch.distributed.pipeline.sync.checkpoint -.. py:module:: torch.distributed.pipeline.sync.copy -.. py:module:: torch.distributed.pipeline.sync.dependency -.. py:module:: torch.distributed.pipeline.sync.microbatch -.. py:module:: torch.distributed.pipeline.sync.phony -.. py:module:: torch.distributed.pipeline.sync.pipe -.. py:module:: torch.distributed.pipeline.sync.pipeline -.. py:module:: torch.distributed.pipeline.sync.skip.layout -.. py:module:: torch.distributed.pipeline.sync.skip.namespace -.. py:module:: torch.distributed.pipeline.sync.skip.portal -.. py:module:: torch.distributed.pipeline.sync.skip.skippable -.. py:module:: torch.distributed.pipeline.sync.skip.tracker -.. py:module:: torch.distributed.pipeline.sync.stream -.. py:module:: torch.distributed.pipeline.sync.utils -.. py:module:: torch.distributed.pipeline.sync.worker .. py:module:: torch.distributed.remote_device .. py:module:: torch.distributed.rendezvous .. py:module:: torch.distributed.rpc.api diff --git a/docs/source/index.rst b/docs/source/index.rst index ea704f20c3af..dcaadcbb63ed 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -103,7 +103,6 @@ Features described in this documentation are classified by release status: optim complex_numbers ddp_comm_hooks - pipeline quantization rpc torch.random diff --git a/docs/source/pipeline.rst b/docs/source/pipeline.rst deleted file mode 100644 index 94d730ee223d..000000000000 --- a/docs/source/pipeline.rst +++ /dev/null @@ -1,85 +0,0 @@ -.. _pipeline-parallelism: - -Pipeline Parallelism -==================== - -Pipeline parallelism was original introduced in the -`Gpipe `__ paper and is an efficient -technique to train large models on multiple GPUs. - -.. warning :: - torch.distributed.pipeline is deprecated, so is this document. For - up-to-date pipeline parallel implementation, please refer to the - `PiPPy `__ library under the PyTorch - organization (Pipeline Parallelism for PyTorch). - -Model Parallelism using multiple GPUs -------------------------------------- - -Typically for large models which don't fit on a single GPU, model parallelism -is employed where certain parts of the model are placed on different GPUs. -Although, if this is done naively for sequential models, the training process -suffers from GPU under utilization since only one GPU is active at one time as -shown in the figure below: - -.. figure:: _static/img/pipeline_parallelism/no_pipe.png - - The figure represents a model with 4 layers placed on 4 different GPUs - (vertical axis). The horizontal axis represents training this model through - time demonstrating that only 1 GPU is utilized at a time - (`image source `__). - -Pipelined Execution -------------------- - -To alleviate this problem, pipeline parallelism splits the input minibatch into -multiple microbatches and pipelines the execution of these microbatches across -multiple GPUs. This is outlined in the figure below: - -.. figure:: _static/img/pipeline_parallelism/pipe.png - - The figure represents a model with 4 layers placed on 4 different GPUs - (vertical axis). The horizontal axis represents training this model through - time demonstrating that the GPUs are utilized much more efficiently. - However, there still exists a bubble (as demonstrated in the figure) where - certain GPUs are not utilized. - (`image source `__). - -Pipe APIs in PyTorch --------------------- -.. autoclass:: torch.distributed.pipeline.sync.Pipe - :members: forward - -Skip connections -^^^^^^^^^^^^^^^^ - -Certain models like `ResNeXt `__ -are not completely sequential and have skip connections between layers. -Naively implementing as part of pipeline parallelism would imply that -we need to copy outputs for certain layers through multiple GPUs till -we eventually reach the GPU where the layer for the skip connection resides. -To avoid this copy overhead, we provide APIs below to stash and pop Tensors -in different layers of the model. - -.. autofunction:: torch.distributed.pipeline.sync.skip.skippable.skippable -.. autoclass:: torch.distributed.pipeline.sync.skip.skippable.stash -.. autoclass:: torch.distributed.pipeline.sync.skip.skippable.pop -.. autofunction:: torch.distributed.pipeline.sync.skip.skippable.verify_skippables - -Tutorials ---------- - -The following tutorials give a good overview of how to use the -:class:`~torch.distributed.pipeline.sync.Pipe` API to train your models with the -rest of the components that PyTorch provides: - -- `Training Transformer models using Pipeline Parallelism `__ -- `Training Transformer models using Distributed Data Parallel and Pipeline Parallelism `__ - -Acknowledgements ----------------- - -The implementation for pipeline parallelism is based on `fairscale's pipe implementation `__ and -`torchgpipe `__. We would like to -thank both teams for their contributions and guidance towards bringing pipeline -parallelism into PyTorch. diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json index 0ead16868f2f..8bedc0072300 100644 --- a/test/allowlist_for_publicAPI.json +++ b/test/allowlist_for_publicAPI.json @@ -211,30 +211,6 @@ "torch.distributed.optim.utils": [ "Type" ], - "torch.distributed.pipeline.sync.pipe": [ - "Pipeline" - ], - "torch.distributed.pipeline.sync.skip.layout": [ - "SkipLayout", - "inspect_skip_layout" - ], - "torch.distributed.pipeline.sync.skip.portal": [ - "Context", - "Portal", - "PortalBlue", - "PortalCopy", - "PortalOrange" - ], - "torch.distributed.pipeline.sync.skip.skippable": [ - "Skippable" - ], - "torch.distributed.pipeline.sync.skip.tracker": [ - "SkipTracker", - "SkipTrackerThroughPotals", - "ThreadLocal", - "current_skip_tracker", - "use_skip_tracker" - ], "torch.distributed.remote_device": [ "Optional", "Union" @@ -1695,10 +1671,6 @@ "get_args_parser", "run" ], - "torch.distributed.pipeline.sync": [ - "NoChunk", - "WithDevice" - ], "torch.distributed.rpc.rref_proxy": [ "Future", "partial", diff --git a/test/distributed/pipeline/sync/LICENSE b/test/distributed/pipeline/sync/LICENSE deleted file mode 100644 index e52be240fdc9..000000000000 --- a/test/distributed/pipeline/sync/LICENSE +++ /dev/null @@ -1,27 +0,0 @@ -Copyright 2019-2020 Kakao Brain - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -1. Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - -2. Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - -3. Neither the name of the copyright holder nor the names of its - contributors may be used to endorse or promote products derived from this - software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE -LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -POSSIBILITY OF SUCH DAMAGE. diff --git a/test/distributed/pipeline/sync/__init__.py b/test/distributed/pipeline/sync/__init__.py deleted file mode 100644 index 94cd5bcb415e..000000000000 --- a/test/distributed/pipeline/sync/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -# tests/__init__.py makes pytest can import the application without custom sys.path or PYTHONPATH. -# See also: https://docs.pytest.org/en/latest/goodpractices.html diff --git a/test/distributed/pipeline/sync/conftest.py b/test/distributed/pipeline/sync/conftest.py deleted file mode 100644 index 4f2479b27b29..000000000000 --- a/test/distributed/pipeline/sync/conftest.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import tempfile - -import pytest - -import torch -import torch.distributed as dist - - -@pytest.fixture(autouse=True) -def manual_seed_zero(): - torch.manual_seed(0) - - -@pytest.fixture(scope="session") -def cuda_sleep(): - # Warm-up CUDA. - torch.empty(1, device="cuda") - - # From test/test_cuda.py in PyTorch. - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - torch.cuda._sleep(1000000) - end.record() - end.synchronize() - cycles_per_ms = 1000000 / start.elapsed_time(end) - - def cuda_sleep(seconds): - torch.cuda._sleep(int(seconds * cycles_per_ms * 1000)) - - return cuda_sleep - - -def pytest_report_header(): - return f"torch: {torch.__version__}" - - -@pytest.fixture -def setup_rpc(scope="session"): - file = tempfile.NamedTemporaryFile() - dist.rpc.init_rpc( - name="worker0", - rank=0, - world_size=1, - rpc_backend_options=dist.rpc.TensorPipeRpcBackendOptions( - init_method=f"file://{file.name}", - ), - ) - yield - dist.rpc.shutdown() - - -def pytest_ignore_collect(path, config): - "Skip this directory if distributed modules are not enabled." - return not dist.is_available() diff --git a/test/distributed/pipeline/sync/skip/__init__.py b/test/distributed/pipeline/sync/skip/__init__.py deleted file mode 100644 index ab03724cafbf..000000000000 --- a/test/distributed/pipeline/sync/skip/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. diff --git a/test/distributed/pipeline/sync/skip/test_api.py b/test/distributed/pipeline/sync/skip/test_api.py deleted file mode 100644 index be38d6d83dac..000000000000 --- a/test/distributed/pipeline/sync/skip/test_api.py +++ /dev/null @@ -1,52 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import copy - -from torch import nn - -from torch.distributed.pipeline.sync.skip import Namespace, skippable, stash -from torch.testing._internal.common_utils import run_tests - - -def test_namespace_difference(): - ns1 = Namespace() - ns2 = Namespace() - assert ns1 != ns2 - - -def test_namespace_copy(): - ns = Namespace() - assert copy.copy(ns) == ns - assert copy.copy(ns) is not ns - - -def test_skippable_repr(): - @skippable(stash=["hello"]) - class Hello(nn.Module): - def __init__(self): - super().__init__() - self.conv = nn.Conv2d(1, 1, 1) - - def forward(self, x): - yield stash("hello", x) - return self.conv(x) # noqa: B901 - - m = Hello() - assert ( - repr(m) - == """ -@skippable(Hello( - (conv): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1)) -)) -""".strip() - ) - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/skip/test_gpipe.py b/test/distributed/pipeline/sync/skip/test_gpipe.py deleted file mode 100644 index 4f433ab38941..000000000000 --- a/test/distributed/pipeline/sync/skip/test_gpipe.py +++ /dev/null @@ -1,126 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import pytest - -import torch -from torch import nn - -from torch.distributed.pipeline.sync import Pipe -from torch.distributed.pipeline.sync.skip import pop, skippable, stash -from torch.distributed.pipeline.sync.skip.portal import ( - PortalBlue, - PortalCopy, - PortalOrange, -) -from torch.distributed.pipeline.sync.utils import partition_model -from torch.testing._internal.common_utils import run_tests - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") -@pytest.mark.parametrize( - "balance", [[3], [1, 2], [2, 1], [1, 1, 1]], ids=["3", "1:2", "2:1", "1:1:1"] -) -@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) -def test_1to3(balance, checkpoint, setup_rpc): - if torch.cuda.device_count() < len(balance): - pytest.skip("at least %d cuda devices required" % len(balance)) - - @skippable(stash=["1to3"]) - class Layer1(nn.Module): - def __init__(self): - super().__init__() - self.conv = nn.Conv2d(3, 3, 1) - - def forward(self, input): - yield stash("1to3", input) - output = self.conv(input) - return output # noqa: B901 - - class Layer2(nn.Module): - def __init__(self): - super().__init__() - self.conv = nn.Conv2d(3, 3, 1) - - def forward(self, input): - output = self.conv(input) - return output - - @skippable(pop=["1to3"]) - class Layer3(nn.Module): - def __init__(self): - super().__init__() - self.conv = nn.Conv2d(3, 3, 1) - - def forward(self, input): - skip_1to3 = yield pop("1to3") - output = self.conv(input) + skip_1to3 - return output - - model = nn.Sequential(Layer1(), Layer2(), Layer3()) - model = partition_model(model, balance) - model = Pipe(model, chunks=3, checkpoint=checkpoint) - - in_device = model.devices[0] - out_device = model.devices[-1] - - input = torch.rand(30, 3, 224, 224, device=in_device, requires_grad=True) - output = model(input) - loss = output.local_value().mean() - loss.backward() - - assert torch.allclose( - output.local_value().norm(), torch.tensor(1039.0, device=out_device), atol=6e-1 - ) - assert torch.allclose( - input.grad.norm(), torch.tensor(0.0004533053, device=in_device) - ) - - -def test_none_skip(setup_rpc): - @skippable(stash=["none"]) - class Stash(nn.Module): - def forward(self, input): - yield stash("none", None) - return input # noqa: B901 - - @skippable(pop=["none"]) - class Pop(nn.Module): - def forward(self, input): - none = yield pop("none") - assert none is None - return input - - model = nn.Sequential(Stash(), Pop()) - model = Pipe(model, chunks=5) - - input = torch.rand(10, requires_grad=True) - output = model(input) - - def assert_grad_fn_is_not_portal(grad_fn, visited=None): - if visited is None: - visited = set() - if grad_fn in visited or grad_fn is None: - return - - assert not isinstance(grad_fn, PortalBlue._backward_cls) - assert not isinstance(grad_fn, PortalCopy._backward_cls) - assert not isinstance(grad_fn, PortalOrange._backward_cls) - - visited.add(grad_fn) - for next_grad_fn, _ in grad_fn.next_functions: - assert_grad_fn_is_not_portal(next_grad_fn, visited) - - assert_grad_fn_is_not_portal(output.local_value().grad_fn) - - output.local_value().sum().backward() - assert input.grad.mean().item() == 1 - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/skip/test_inspect_skip_layout.py b/test/distributed/pipeline/sync/skip/test_inspect_skip_layout.py deleted file mode 100644 index 4d542285cd5a..000000000000 --- a/test/distributed/pipeline/sync/skip/test_inspect_skip_layout.py +++ /dev/null @@ -1,118 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -from torch import nn - -from torch.distributed.pipeline.sync.skip import Namespace, pop, skippable, stash -from torch.distributed.pipeline.sync.skip.layout import inspect_skip_layout -from torch.testing._internal.common_utils import run_tests - - -class Pass(nn.Module): - def forward(self, input): - return input - - -@skippable(stash=["foo"]) -class StashFoo(nn.Module): - def forward(self, input): - yield stash("foo", input) - return input # noqa: B901 - - -@skippable(pop=["foo"]) -class PopFoo(nn.Module): - def forward(self, input): - foo = yield stash("foo") - return input + foo - - -@skippable(stash=["bar"]) -class StashBar(nn.Module): - def forward(self, input): - yield stash("bar", input) - return input # noqa: B901 - - -@skippable(pop=["bar"]) -class PopBar(nn.Module): - def forward(self, input): - bar = yield pop("bar") - return input + bar - - -def test_no_skippables(): - p1 = nn.Sequential(Pass()) - p2 = nn.Sequential(Pass()) - - layout = inspect_skip_layout([p1, p2]) - policy = [list(layout.copy_policy(i)) for i in range(2)] - - assert policy == [[], []] - - -def test_inner_partition(): - p1 = nn.Sequential(StashFoo(), PopFoo()) - p2 = nn.Sequential(Pass()) - - layout = inspect_skip_layout([p1, p2]) - policy = [list(layout.copy_policy(i)) for i in range(2)] - - assert policy == [[], []] - - -def test_adjoining_partitions(): - p1 = nn.Sequential(StashFoo()) - p2 = nn.Sequential(PopFoo()) - - layout = inspect_skip_layout([p1, p2]) - policy = [list(layout.copy_policy(i)) for i in range(2)] - - assert policy == [[], [(0, None, "foo")]] - - -def test_far_partitions(): - p1 = nn.Sequential(StashFoo()) - p2 = nn.Sequential(Pass()) - p3 = nn.Sequential(PopFoo()) - - layout = inspect_skip_layout([p1, p2, p3]) - policy = [list(layout.copy_policy(i)) for i in range(3)] - - assert policy == [[], [], [(0, None, "foo")]] - - -def test_pop_2_from_different_partitions(): - p1 = nn.Sequential(StashFoo()) - p2 = nn.Sequential(StashBar()) - p3 = nn.Sequential(PopBar(), PopFoo()) - - layout = inspect_skip_layout([p1, p2, p3]) - policy = [list(layout.copy_policy(i)) for i in range(3)] - - # p3 pops 'bar' before 'foo', but the plan is sorted by source partition index. - assert policy == [[], [], [(0, None, "foo"), (1, None, "bar")]] - - -def test_namespace(): - ns1 = Namespace() - ns2 = Namespace() - - p1 = nn.Sequential(StashFoo().isolate(ns1)) - p2 = nn.Sequential(StashFoo().isolate(ns2)) - p3 = nn.Sequential(PopFoo().isolate(ns2), PopFoo().isolate(ns1)) - - layout = inspect_skip_layout([p1, p2, p3]) - policy = [list(layout.copy_policy(i)) for i in range(3)] - - # p3 pops 'bar' before 'foo', but the plan is sorted by source partition index. - assert policy == [[], [], [(0, ns1, "foo"), (1, ns2, "foo")]] - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/skip/test_leak.py b/test/distributed/pipeline/sync/skip/test_leak.py deleted file mode 100644 index f4d1043e0549..000000000000 --- a/test/distributed/pipeline/sync/skip/test_leak.py +++ /dev/null @@ -1,136 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import pytest - -import torch -from torch import nn - -from torch.distributed.pipeline.sync import is_checkpointing, is_recomputing, Pipe -from torch.distributed.pipeline.sync.skip import pop, skippable, stash -from torch.distributed.pipeline.sync.skip.tracker import current_skip_tracker -from torch.testing._internal.common_utils import run_tests - - -@skippable(stash=["skip"]) -class Stash(nn.Module): - def forward(self, input): - yield stash("skip", input) - return input # noqa: B901 - - -@skippable(pop=["skip"]) -class Pop(nn.Module): - def forward(self, input): - skip = yield pop("skip") - return input + skip - - -@pytest.mark.parametrize("train", [True, False], ids=["train", "eval"]) -@pytest.mark.parametrize("checkpoint", ["always", "except_last", "never"]) -def test_delete_portal_tensor(train, checkpoint, setup_rpc): - # Without checkpointing: - # +- Stash --+ +--- Pop ----+ - - - layers - # | 2,blue,1 |--| 1,orange,0 | - - - tensor_life and portal function - # +----------+ +------------+ - # - # With checkpointing: - # +- Stash --+ +--- Pop ----+ +--- Pop'----+ +- Stash'--+ - # | 3,blue,2 |--| 2,orange,1 |--| 1,orange,0 |--| 1,blue,0 | - # +----------+ +------------+ +------------+ +----------+ - - def portal_tensor_life_is(tensor_life, skip_tracker=None): - if skip_tracker is None: - skip_tracker = current_skip_tracker() - - # Get the current portal. - portal = next(iter(skip_tracker.portals.values())) - - if tensor_life == 0: - return portal.tensor_life == 0 and portal.tensor is None - else: - return portal.tensor_life == tensor_life and portal.tensor is not None - - # Check the portal tensor after 'Stash'. - stash_ = Stash() - - @stash_.register_forward_hook - def check_portal_tensor_after_stash(*_): - if is_checkpointing(): - assert portal_tensor_life_is(2) - elif is_recomputing(): - assert portal_tensor_life_is(0) - else: - assert portal_tensor_life_is(1) - - pop_ = Pop() - - @pop_.register_forward_hook - def check_portal_tensor_after_pop(*_): - if is_checkpointing(): - assert portal_tensor_life_is(1) - elif is_recomputing(): - assert portal_tensor_life_is(0) - else: - assert portal_tensor_life_is(0) - - class NoPortalTensorAtBackward(nn.Module): - class F(torch.autograd.Function): - @staticmethod - def forward(ctx, input): - ctx.skip_tracker = current_skip_tracker() - return input.detach() - - @staticmethod - def backward(ctx, grad): - assert portal_tensor_life_is(0, skip_tracker=ctx.skip_tracker) - return grad - - def forward(self, input): - return self.F.apply(input) - - model = nn.Sequential(NoPortalTensorAtBackward(), stash_, pop_) - model = Pipe(model, chunks=2, checkpoint=checkpoint) - - input = torch.rand(10, requires_grad=True) - - if train: - model.train() - output = model(input).local_value() - output.norm().backward() - else: - model.eval() - with torch.no_grad(): - model(input) - - -@pytest.mark.parametrize("train", [True, False], ids=["train", "eval"]) -def test_no_portal_without_pipe(train, monkeypatch, setup_rpc): - def deny(*args, **kwargs): - raise AssertionError("tried to create Portal without Pipe") - - monkeypatch.setattr( - "torch.distributed.pipeline.sync.skip.portal.Portal.__init__", deny - ) - - model = nn.Sequential(Stash(), Pop()) - - input = torch.rand(10, requires_grad=True) - - if train: - model.train() - output = model(input) - output.norm().backward() - else: - model.eval() - with torch.no_grad(): - model(input) - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/skip/test_portal.py b/test/distributed/pipeline/sync/skip/test_portal.py deleted file mode 100644 index 5ad180b6f9c8..000000000000 --- a/test/distributed/pipeline/sync/skip/test_portal.py +++ /dev/null @@ -1,163 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import pytest - -import torch - -from torch.distributed.pipeline.sync.dependency import fork, join -from torch.distributed.pipeline.sync.skip.portal import Portal -from torch.distributed.pipeline.sync.stream import default_stream -from torch.testing._internal.common_utils import run_tests - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") -def test_copy_returns_on_next_device(): - portal = Portal(torch.rand(1), tensor_life=1) - - prev_stream = default_stream(torch.device("cpu")) - next_stream = default_stream(torch.device("cuda")) - - phony = torch.zeros(0, requires_grad=True) - assert phony.device.type == "cpu" - - phony = portal.copy(prev_stream, next_stream, phony) - assert phony.device.type == "cuda" - - -def test_blue_orange(): - tensor1 = torch.rand(1, requires_grad=True) - tensor2 = torch.rand(1, requires_grad=True) - - # Same with: output = tensor1*2 + tensor2 - # - # +----------------------+ - # | | - # tensor2 -- PortalBlue -+ +- PortalOrange -+ - # | | | - # tensor1 ------------ Join -- Fork --- Mul --- Add -- output - # - main = tensor1 - portal = Portal(tensor2, tensor_life=2) - phony = portal.blue() - main = join(main, phony) - main, phony = fork(main) - sub = portal.orange(phony) - output = main * 2 + sub - - output.backward() - - assert torch.allclose(tensor1.grad, torch.tensor([2.0])) - assert torch.allclose(tensor2.grad, torch.tensor([1.0])) - - -def test_blue_orange_not_requires_grad(): - tensor1 = torch.rand(1, requires_grad=True) - tensor2 = torch.rand(1) - - # Same with: output = tensor1*2 + tensor2 - # - # +----------------------+ - # | | - # tensor2 -- PortalBlue -+ +- PortalOrange -+ - # | | | - # tensor1 ------------ Join -- Fork --- Mul --- Add -- output - # - main = tensor1 - portal = Portal(tensor2, tensor_life=2) - phony = portal.blue() - main = join(main, phony) - main, phony = fork(main) - sub = portal.orange(phony) - output = main * 2 + sub - - output.backward() - - assert torch.allclose(tensor1.grad, torch.tensor([2.0])) - assert tensor2.grad is None - - -def test_use_grad(): - tensor = torch.rand(1, requires_grad=True) - portal = Portal(tensor, tensor_life=1) - - portal.put_grad(tensor) - assert portal.use_grad() is tensor - - # Gradient in a portal is ephemeral. - with pytest.raises(RuntimeError): - portal.use_grad() - - -class TestTensorLife: - @pytest.fixture - def new_portal(self): - portal = None - - def new_portal(tensor_life): - nonlocal portal - tensor = torch.rand(1, requires_grad=True) - portal = Portal(tensor, tensor_life) - return portal, tensor - - yield new_portal - - # A test using this fixture must exhaust the tensor in the portal. - with pytest.raises(RuntimeError): - portal.check_tensor_life() - assert portal.tensor is None - - def test_tensor_life_0(self, new_portal): - portal, tensor = new_portal(0) - assert portal.tensor is None - - def test_tensor_life_1(self, new_portal): - portal, tensor = new_portal(1) - assert portal.tensor is tensor - - portal.blue() - - def test_tensor_life_2(self, new_portal): - portal, tensor = new_portal(2) - assert portal.tensor is tensor - - phony = portal.blue() - assert portal.orange(phony).data_ptr() == tensor.data_ptr() - - def test_tensor_life_3(self, new_portal): - portal, tensor = new_portal(3) - assert portal.tensor is tensor - - phony = portal.blue() - assert portal.orange(phony).data_ptr() == tensor.data_ptr() - assert portal.orange(phony).data_ptr() == tensor.data_ptr() - - def test_tensor_life_4(self, new_portal): - portal, tensor = new_portal(4) - assert portal.tensor is tensor - - phony = portal.blue() - assert portal.orange(phony).data_ptr() == tensor.data_ptr() - assert portal.orange(phony).data_ptr() == tensor.data_ptr() - portal.blue() - - def test_tensor_life_3_plus_1(self, new_portal): - portal, tensor = new_portal(3) - assert portal.tensor is tensor - - phony = portal.blue() - assert portal.orange(phony).data_ptr() == tensor.data_ptr() - assert portal.orange(phony).data_ptr() == tensor.data_ptr() - - another_tensor = torch.rand(1, requires_grad=True) - portal.put_tensor(another_tensor, tensor_life=1) - portal.blue() - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/skip/test_stash_pop.py b/test/distributed/pipeline/sync/skip/test_stash_pop.py deleted file mode 100644 index 5d273860f6a6..000000000000 --- a/test/distributed/pipeline/sync/skip/test_stash_pop.py +++ /dev/null @@ -1,144 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import pytest - -import torch -from torch import nn - -from torch.distributed.pipeline.sync.skip import pop, skippable, stash -from torch.distributed.pipeline.sync.skip.tracker import SkipTracker, use_skip_tracker -from torch.testing._internal.common_utils import run_tests - - -@pytest.fixture(autouse=True) -def skip_tracker(): - skip_tracker = SkipTracker() - with use_skip_tracker(skip_tracker): - yield skip_tracker - - -def test_stash(skip_tracker): - @skippable(stash=["foo"]) - class Stash(nn.Module): - def forward(self, input): - yield stash("foo", input) - return input * 2 # noqa: B901 - - l1 = Stash() - - assert len(skip_tracker.tensors) == 0 - - with use_skip_tracker(skip_tracker): - l1(torch.tensor(42)) - - assert len(skip_tracker.tensors) == 1 - - -def test_pop(): - @skippable(stash=["foo"]) - class Stash(nn.Module): - def forward(self, input): - yield stash("foo", input) - return input * 2 # noqa: B901 - - @skippable(pop=["foo"]) - class Pop(nn.Module): - def forward(self, input): - foo = yield pop("foo") - return foo - - l1 = Stash() - l2 = Pop() - - output = l2(l1(torch.tensor(42))) - - assert output.item() == 42 - - -def test_declare_but_not_use(): - @skippable(stash=["foo"]) - class Stash(nn.Module): - def forward(self, input): - return input * 2 - - @skippable(pop=["foo"]) - class Pop(nn.Module): - def forward(self, input): - return input * 3 - - l1 = Stash() - l2 = Pop() - - with pytest.raises(RuntimeError): - l1(torch.tensor(42)) - - with pytest.raises(RuntimeError): - l2(torch.tensor(42)) - - -def test_stash_not_declared(): - @skippable() - class Stash(nn.Module): - def forward(self, input): - yield stash("foo", input) - return input * 2 # noqa: B901 - - l1 = Stash() - - with pytest.raises(RuntimeError): - l1(torch.tensor(42)) - - -def test_pop_not_declared(): - @skippable(stash=["foo"]) - class Stash(nn.Module): - def forward(self, input): - yield stash("foo", input) - return input * 2 # noqa: B901 - - @skippable() - class Pop(nn.Module): - def forward(self, input): - foo = yield pop("foo") - return foo - - l1 = Stash() - l2 = Pop() - - latent = l1(torch.tensor(42)) - - with pytest.raises(RuntimeError): - l2(latent) - - -def test_pop_not_stashed(): - @skippable(pop=["foo"]) - class Pop(nn.Module): - def forward(self, input): - yield pop("foo") - - l1 = Pop() - - with pytest.raises(RuntimeError): - l1(torch.tensor(42)) - - -def test_stash_none(): - @skippable(stash=["foo"]) - class Stash(nn.Module): - def forward(self, input): - yield stash("foo", None) - return input * 2 # noqa: B901 - - l1 = Stash() - l1(torch.tensor(42)) - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/skip/test_tracker.py b/test/distributed/pipeline/sync/skip/test_tracker.py deleted file mode 100644 index 9c3a970f7574..000000000000 --- a/test/distributed/pipeline/sync/skip/test_tracker.py +++ /dev/null @@ -1,145 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import threading -from queue import Queue - -import pytest - -import torch -from torch import nn - -from torch.distributed.pipeline.sync.checkpoint import ( - enable_checkpointing, - enable_recomputing, -) -from torch.distributed.pipeline.sync.microbatch import Batch -from torch.distributed.pipeline.sync.skip import pop, skippable, stash -from torch.distributed.pipeline.sync.skip.layout import SkipLayout -from torch.distributed.pipeline.sync.skip.tracker import ( - current_skip_tracker, - SkipTracker, - SkipTrackerThroughPotals, -) -from torch.testing._internal.common_utils import run_tests - - -def test_default_skip_tracker(): - q = Queue() - - def f(): - q.put(current_skip_tracker()) - - t = threading.Thread(target=f) - t.start() - t.join() - - skip_tracker = q.get() - - assert type(skip_tracker) is SkipTracker - assert type(skip_tracker) is not SkipTrackerThroughPotals - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") -def test_default_skip_tracker_by_data_parallel(): - @skippable(stash=["foo"]) - class Stash(nn.Module): - def forward(self, input): - yield stash("foo", input) - return input * 2 # noqa: B901 - - @skippable(pop=["foo"]) - class Pop(nn.Module): - def forward(self, input): - foo = yield pop("foo") - return foo - - model = nn.Sequential(Stash(), Pop()) - model = nn.DataParallel(model, device_ids=[0, 0], output_device=0) - - input = torch.rand(10, device=0) - output = model(input) - - assert torch.allclose(output, input) - - -def test_reuse_portal(): - skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "test"): (0, 1)}) - skip_tracker = SkipTrackerThroughPotals(skip_layout) - - batch = Batch(torch.tensor([1.0])) - a = torch.tensor([2.0]) - b = torch.tensor([2.0]) - - skip_tracker.save(batch, None, "test", a) - portal = skip_tracker.portals[(None, "test")] - - skip_tracker.save(batch, None, "test", b) - assert portal is skip_tracker.portals[(None, "test")] - - -def test_no_copy_no_portal(): - skip_layout = SkipLayout( - num_partitions=2, - skip_routes={(None, "copy"): (0, 1), (None, "not_copy"): (0, 0)}, - ) - skip_tracker = SkipTrackerThroughPotals(skip_layout) - - batch = Batch(torch.tensor([1.0])) - a = torch.tensor([2.0]) - b = torch.tensor([2.0]) - - skip_tracker.save(batch, None, "copy", a) - skip_tracker.save(batch, None, "not_copy", b) - - assert (None, "copy") in skip_tracker.portals - assert (None, "copy") not in skip_tracker.tensors - assert (None, "not_copy") in skip_tracker.tensors - assert (None, "not_copy") not in skip_tracker.portals - - -def test_tensor_life_without_checkpointing(): - skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "test"): (0, 1)}) - skip_tracker = SkipTrackerThroughPotals(skip_layout) - - batch = Batch(torch.tensor([1.0])) - tensor = torch.tensor([2.0]) - - skip_tracker.save(batch, None, "test", tensor) - assert skip_tracker.portals[(None, "test")].tensor_life == 1 - - skip_tracker.load(batch, None, "test") - assert skip_tracker.portals[(None, "test")].tensor_life == 0 - - -def test_tensor_life_with_checkpointing(): - skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "test"): (0, 1)}) - skip_tracker = SkipTrackerThroughPotals(skip_layout) - - batch = Batch(torch.tensor([1.0])) - tensor = torch.tensor([2.0]) - - with enable_checkpointing(): - skip_tracker.save(batch, None, "test", tensor) - assert skip_tracker.portals[(None, "test")].tensor_life == 2 - - with enable_checkpointing(): - skip_tracker.load(batch, None, "test") - assert skip_tracker.portals[(None, "test")].tensor_life == 1 - - with enable_recomputing(): - skip_tracker.load(batch, None, "test") - assert skip_tracker.portals[(None, "test")].tensor_life == 0 - - with enable_recomputing(): - skip_tracker.save(batch, None, "test", tensor) - assert skip_tracker.portals[(None, "test")].tensor_life == 0 - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/skip/test_verify_skippables.py b/test/distributed/pipeline/sync/skip/test_verify_skippables.py deleted file mode 100644 index 1d5941487da8..000000000000 --- a/test/distributed/pipeline/sync/skip/test_verify_skippables.py +++ /dev/null @@ -1,165 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import pytest - -from torch import nn - -from torch.distributed.pipeline.sync.skip import Namespace, skippable, verify_skippables -from torch.testing._internal.common_utils import run_tests - - -def test_matching(): - @skippable(stash=["foo"]) - class Layer1(nn.Module): - pass - - @skippable(pop=["foo"]) - class Layer2(nn.Module): - pass - - verify_skippables(nn.Sequential(Layer1(), Layer2())) - - -def test_stash_not_pop(): - @skippable(stash=["foo"]) - class Layer1(nn.Module): - pass - - with pytest.raises(TypeError) as e: - verify_skippables(nn.Sequential(Layer1())) - assert "no module declared 'foo' as poppable but stashed" in str(e.value) - - -def test_pop_unknown(): - @skippable(pop=["foo"]) - class Layer1(nn.Module): - pass - - with pytest.raises(TypeError) as e: - verify_skippables(nn.Sequential(Layer1())) - assert "'0' declared 'foo' as poppable but it was not stashed" in str(e.value) - - -def test_stash_again(): - @skippable(stash=["foo"]) - class Layer1(nn.Module): - pass - - @skippable(stash=["foo"]) - class Layer2(nn.Module): - pass - - @skippable(pop=["foo"]) - class Layer3(nn.Module): - pass - - with pytest.raises(TypeError) as e: - verify_skippables(nn.Sequential(Layer1(), Layer2(), Layer3())) - assert "'1' redeclared 'foo' as stashable" in str(e.value) - - -def test_pop_again(): - @skippable(stash=["foo"]) - class Layer1(nn.Module): - pass - - @skippable(pop=["foo"]) - class Layer2(nn.Module): - pass - - @skippable(pop=["foo"]) - class Layer3(nn.Module): - pass - - with pytest.raises(TypeError) as e: - verify_skippables(nn.Sequential(Layer1(), Layer2(), Layer3())) - assert "'2' redeclared 'foo' as poppable" in str(e.value) - - -def test_stash_pop_together_different_names(): - @skippable(stash=["foo"]) - class Layer1(nn.Module): - pass - - @skippable(pop=["foo"], stash=["bar"]) - class Layer2(nn.Module): - pass - - @skippable(pop=["bar"]) - class Layer3(nn.Module): - pass - - verify_skippables(nn.Sequential(Layer1(), Layer2(), Layer3())) - - -def test_stash_pop_together_same_name(): - @skippable(stash=["foo"], pop=["foo"]) - class Layer1(nn.Module): - pass - - with pytest.raises(TypeError) as e: - verify_skippables(nn.Sequential(Layer1())) - assert "'0' declared 'foo' both as stashable and as poppable" in str(e.value) - - -def test_double_stash_pop(): - @skippable(stash=["foo"]) - class Layer1(nn.Module): - pass - - @skippable(pop=["foo"]) - class Layer2(nn.Module): - pass - - @skippable(stash=["foo"]) - class Layer3(nn.Module): - pass - - @skippable(pop=["foo"]) - class Layer4(nn.Module): - pass - - with pytest.raises(TypeError) as e: - verify_skippables(nn.Sequential(Layer1(), Layer2(), Layer3(), Layer4())) - assert "'2' redeclared 'foo' as stashable" in str(e.value) - assert "'3' redeclared 'foo' as poppable" in str(e.value) - - -def test_double_stash_pop_but_isolated(): - @skippable(stash=["foo"]) - class Layer1(nn.Module): - pass - - @skippable(pop=["foo"]) - class Layer2(nn.Module): - pass - - @skippable(stash=["foo"]) - class Layer3(nn.Module): - pass - - @skippable(pop=["foo"]) - class Layer4(nn.Module): - pass - - ns1 = Namespace() - ns2 = Namespace() - - verify_skippables( - nn.Sequential( - Layer1().isolate(ns1), - Layer2().isolate(ns1), - Layer3().isolate(ns2), - Layer4().isolate(ns2), - ) - ) - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_balance.py b/test/distributed/pipeline/sync/test_balance.py deleted file mode 100644 index faf09f4581ae..000000000000 --- a/test/distributed/pipeline/sync/test_balance.py +++ /dev/null @@ -1,240 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import time - -import pytest - -import torch -from torch import nn - -from torch.distributed.pipeline.sync._balance import ( - balance_by_size, - balance_by_time, - blockpartition, -) -from torch.distributed.pipeline.sync._balance.profile import layerwise_sandbox -from torch.testing._internal.common_utils import run_tests - -skip_if_no_cuda = pytest.mark.skipif( - not torch.cuda.is_available(), reason="cuda required" -) - -devices = ["cpu"] -if torch.cuda.is_available(): - devices.append("cuda") - - -def test_blockpartition(): - assert blockpartition.solve([1, 2, 3, 4, 5, 6], partitions=2) == [ - [1, 2, 3, 4], - [5, 6], - ] - - -def test_blockpartition_zeros(): - assert blockpartition.solve([0, 0], partitions=2) == [[0], [0]] - - -def test_blockpartition_non_positive_partitions(): - with pytest.raises(ValueError): - blockpartition.solve([42], partitions=0) - with pytest.raises(ValueError): - blockpartition.solve([42], partitions=-1) - - -def test_blockpartition_short_sequence(): - with pytest.raises(ValueError): - blockpartition.solve([], partitions=1) - with pytest.raises(ValueError): - blockpartition.solve([42], partitions=2) - - -@pytest.mark.parametrize("device", devices) -@pytest.mark.skip(reason="Flaky due to time.sleep()") -def test_balance_by_time(device): - class Delay(nn.Module): - def __init__(self, seconds): - super().__init__() - self.seconds = seconds - - def forward(self, x): - time.sleep(self.seconds) - return x - - model = nn.Sequential(*[Delay(i / 10) for i in [1, 2, 3, 4, 5, 6]]) - sample = torch.rand(1) - balance = balance_by_time(2, model, sample, device=device) - assert balance == [4, 2] - - -def test_balance_by_time_loop_resets_input(): - # nn.Flatten was introduced at PyTorch 1.2.0. - class Flatten(nn.Module): - def forward(self, x): - return x.flatten(1) - - model = nn.Sequential(nn.Conv2d(3, 2, 1), Flatten(), nn.Linear(128, 10)) - sample = torch.rand(10, 3, 8, 8) - balance = balance_by_time(2, model, sample, device="cpu") - assert balance == [1, 2] - - -@skip_if_no_cuda -def test_balance_by_size_latent(): - class Expand(nn.Module): - def __init__(self, times): - super().__init__() - self.times = times - - def forward(self, x): - for i in range(self.times): - x = x + torch.rand_like(x, requires_grad=True) - return x - - sample = torch.rand(10, 100, 100) - - model = nn.Sequential(*[Expand(i) for i in [1, 2, 3, 4, 5, 6]]) - balance = balance_by_size(2, model, sample) - assert balance == [4, 2] - - model = nn.Sequential(*[Expand(i) for i in [6, 5, 4, 3, 2, 1]]) - balance = balance_by_size(2, model, sample) - assert balance == [2, 4] - - -@skip_if_no_cuda -def test_balance_by_size_param(): - model = nn.Sequential(*[nn.Linear(i + 1, i + 2) for i in range(6)]) - sample = torch.rand(7, 1) - balance = balance_by_size(2, model, sample, param_scale=100) - assert balance == [4, 2] - - model = nn.Sequential(*[nn.Linear(i + 2, i + 1) for i in reversed(range(6))]) - sample = torch.rand(1, 7) - balance = balance_by_size(2, model, sample, param_scale=100) - assert balance == [2, 4] - - -@skip_if_no_cuda -def test_balance_by_size_param_scale(): - class Tradeoff(nn.Module): - def __init__(self, param_size, latent_size): - super().__init__() - self.fc = nn.Linear(param_size, param_size) - self.latent_size = latent_size - - def forward(self, x): - for i in range(self.latent_size): - x = x + torch.rand_like(x, requires_grad=True) - return x - - model = nn.Sequential( - Tradeoff(param_size=1, latent_size=6), - Tradeoff(param_size=2, latent_size=5), - Tradeoff(param_size=3, latent_size=4), - Tradeoff(param_size=4, latent_size=3), - Tradeoff(param_size=5, latent_size=2), - Tradeoff(param_size=6, latent_size=1), - ) - - sample = torch.rand(1, requires_grad=True) - - balance = balance_by_size(2, model, sample, param_scale=0) - assert balance == [2, 4] - - balance = balance_by_size(2, model, sample, param_scale=100) - assert balance == [4, 2] - - -@pytest.mark.parametrize("device", devices) -def test_layerwise_sandbox(device): - model = nn.Sequential(nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3)) - model.eval() - - for layer in layerwise_sandbox(model, torch.device(device)): - assert layer.training - assert all(p.device.type == device for p in layer.parameters()) - - assert all(not l.training for l in model) - assert all(p.device.type == "cpu" for p in model.parameters()) - - -@pytest.mark.parametrize("device", devices) -def test_sandbox_during_profiling(device): - model = nn.Sequential(nn.BatchNorm2d(3)) - - before = {k: v.clone() for k, v in model.state_dict().items()} - - sample = torch.rand(1, 3, 10, 10) - balance_by_time(1, model, sample, device=device) - - after = model.state_dict() - - assert before.keys() == after.keys() - for key, value in before.items(): - assert torch.allclose(after[key], value), key - - -def test_not_training(): - class AssertTraining(nn.Module): - def forward(self, x): - assert self.training - return x - - model = nn.Sequential(AssertTraining()) - - model.eval() - assert not model.training - - sample = torch.rand(1) - balance_by_time(1, model, sample, device="cpu") - - assert not model.training - - -def test_balance_by_time_tuple(): - class Twin(nn.Module): - def forward(self, x): - return x, x.detach() - - class Add(nn.Module): - def forward(self, a, b): - return a + b - - model = nn.Sequential(Twin(), Add()) - sample = torch.rand(1, requires_grad=True) - balance_by_time(1, model, sample, device="cpu") - - -@skip_if_no_cuda -def test_balance_by_size_tuple(): - class Twin(nn.Module): - def forward(self, x): - return x, x.detach() - - class Add(nn.Module): - def forward(self, a, b): - return a + b - - model = nn.Sequential(Twin(), Add()) - sample = torch.rand(1, requires_grad=True) - balance_by_size(1, model, sample) - - -def test_already_has_grad(): - model = nn.Sequential(nn.Conv2d(3, 3, 1)) - sample = torch.rand(1, 3, 32, 32) - model(sample).norm().backward() - - with pytest.raises(ValueError, match="some parameter already has gradient"): - balance_by_time(1, model, sample, device="cpu") - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_bugs.py b/test/distributed/pipeline/sync/test_bugs.py deleted file mode 100644 index 928a78db6e32..000000000000 --- a/test/distributed/pipeline/sync/test_bugs.py +++ /dev/null @@ -1,146 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import pytest - -import torch -import torch.nn.functional as F -from torch import nn - -from torch.distributed.pipeline.sync import Pipe -from torch.testing._internal.common_cuda import TEST_MULTIGPU -from torch.testing._internal.common_utils import run_tests - - -def test_python_autograd_function(setup_rpc): - # A Python autograd function might fail with this error: - # - # RuntimeError: Returning Variables sharing storage with other Variables - # that require grad is not supported in Python functions. Please submit a - # feature request if you hit this error. - # - # It doesn't look like an essential restriction. But it happens on the - # current PyTorch version. To avoid it, we should detach the tensor before - # returning by identity autograd functions, such as Wait, Fork, and Join. - # - class Identity(torch.autograd.Function): - @staticmethod - def forward(ctx, input): - return input - - @staticmethod - def backward(ctx, grad): - return grad - - class M(nn.Module): - def forward(self, input): - return Identity.apply(input) - - model = nn.Sequential(M(), M()) - model = Pipe(model, checkpoint="always") - - x = torch.rand(42) - y = model(x) - assert torch.allclose(x, y.local_value()) - - -def test_exception_no_hang(setup_rpc): - # In v0.0.2, once a failed partition receives a normal message - # (non-closing) for the next micro-batch, a hang occurred. The reason was - # that a failed partition didn't call in_queue.task_done() on a normal - # message. So the former partition was blocked at out_queue.join() for the - # next of next micro-batch. - class ExpectedException(Exception): - pass - - class Pass(nn.Module): - def forward(self, x): - return x - - class Raise(nn.Module): - def forward(self, x): - raise ExpectedException - - model = nn.Sequential(Pass(), Pass(), Raise()) - model = Pipe(model, chunks=3) - - with pytest.raises(ExpectedException): - model(torch.rand(3)) - - -@pytest.mark.skipif(not TEST_MULTIGPU, reason="2 cuda devices required") -def test_tuple_wait(cuda_sleep, setup_rpc): - # In v0.0.3, Wait is applied to only the first tensor on a micro-batch. - # Under this behavior, if checkpointing was disabled, there's a possibility - # that gradient accumulations on other tensors are not synchronized - # properly to the copy stream. - class Sleep(torch.autograd.Function): - @staticmethod - def forward(ctx, x): - return x.detach() - - @staticmethod - def backward(ctx, grad): - with torch.cuda.device(grad.device): - cuda_sleep(0.05) - return grad - - class Layer1(nn.Module): - def __init__(self): - super().__init__() - self.ones = nn.Parameter(torch.ones(32, 3, 32, 32, requires_grad=True)) - - def forward(self, a, b): - a = a * self.ones - return a * 1, b * 2, b * 3 - - class Layer2(nn.Module): - def __init__(self): - super().__init__() - self.ones = nn.Parameter(torch.ones(32, 3, 32, 32, requires_grad=True)) - - def forward(self, a, b, c): - a = a * self.ones - b = Sleep.apply(b) - return a + b + c - - model = nn.Sequential(Layer1().cuda(0), Layer2().cuda(1)) - model = Pipe(model, chunks=32, checkpoint="never") - - a = torch.rand(1024, 3, 32, 32, device=0, requires_grad=True) - b = torch.rand(1024, 3, 32, 32, device=0, requires_grad=True) - - y = model(a, b) - y.local_value().norm().backward() - - torch.cuda.synchronize(0) - torch.cuda.synchronize(1) - - assert torch.isclose(b.grad.norm().cpu(), torch.tensor(5.000)) - - -def test_parallel_randoms(setup_rpc): - class Dropouts(nn.Module): - def forward(self, x): - for _ in range(100): - x = F.dropout(x, p=0.001) - return x - - model = nn.Sequential(Dropouts(), Dropouts()) - - x = torch.rand(10, 10, requires_grad=True) - model = Pipe(model, chunks=10, checkpoint="always") - y = model(x) - y = y.local_value() - y.norm().backward() - - assert y.to(torch.bool).tolist() == x.grad.to(torch.bool).tolist() - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_checkpoint.py b/test/distributed/pipeline/sync/test_checkpoint.py deleted file mode 100644 index 7be8ddefafe9..000000000000 --- a/test/distributed/pipeline/sync/test_checkpoint.py +++ /dev/null @@ -1,178 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -from functools import partial - -import pytest - -import torch -import torch.cuda -from torch import nn - -from torch.distributed.pipeline.sync.checkpoint import ( - checkpoint, - Checkpointing, - is_checkpointing, - is_recomputing, -) -from torch.distributed.pipeline.sync.dependency import fork, join -from torch.distributed.pipeline.sync.microbatch import Batch -from torch.testing._internal.common_utils import run_tests - -devices = ["cpu"] -if torch.cuda.is_available(): - devices.append("cuda") - - -@pytest.mark.parametrize("device", devices) -def test_serial_checkpoints(device): - # Copied from https://github.com/pytorch/pytorch/pull/18568. - timeline = [] - - class Log(torch.autograd.Function): - @staticmethod - def forward(ctx, name, x): - ctx.name = name - timeline.append(f"{name}:forward") - return x.detach() - - @staticmethod - def backward(ctx, grad_output): - name = ctx.name - timeline.append(f"{name}:backward") - return None, grad_output - - a = torch.rand(1, device=device, requires_grad=True) - b = torch.rand(1, device=device, requires_grad=True) - - # Increase the next function sequence number. - _ = a + 1 + 2 + 3 + 4 + 5 - - a = checkpoint(partial(Log.apply, "a"), a) - - a, phony = fork(a) - b = join(b, phony) - - b = checkpoint(partial(Log.apply, "b"), b) - - c = torch.cat((a, b)) - - out = c.sum() - - # +--> {a} --Checkpoint(Log)--> {a} - # {out} --Sum--> {c} --Cat ^-----------------------------+ - # +--> {b} --Checkpoint(Log)--> {b} --First--> {b} - out.backward() - - assert timeline == [ - "a:forward", - "b:forward", - "b:forward", - "b:backward", - "a:forward", - "a:backward", - ] - # |----------------------| |-----------------------| |-----------------------| - # forward pass Checkpoint(Log[b]) Checkpoint(Log[a]) - - -def test_not_requires_grad(): - x = Batch(torch.rand(1, requires_grad=False)) - assert not x[0].requires_grad - - def f(x): - return x * 2 - - chk = Checkpointing(f, x) - x = chk.checkpoint() - assert x[0].requires_grad - - chk.recompute(x) - assert x[0].requires_grad - - x.tensor.backward() - - -def test_not_requires_grad_with_parameter(): - x = torch.rand(1, requires_grad=False) - a = torch.rand(1, requires_grad=True) - - def f(x): - return x * a - - y = checkpoint(f, x) - y.backward() - - assert a.grad is not None - - -@pytest.mark.parametrize("device", devices) -def test_random_in_checkpoint(device): - dropout = nn.Dropout(p=0.5) - - torch.manual_seed(0) - x = torch.randn(3, 3, device=device, requires_grad=True) - y = dropout(x) - y.norm().backward() - - torch.manual_seed(0) - chk_x = torch.randn(3, 3, device=device, requires_grad=True) - chk_y = checkpoint(dropout, chk_x) - chk_y.norm().backward() - - assert torch.allclose(x.grad, chk_x.grad) - - -def test_detect_checkpointing_recomputing(): - logs = [] - - class Detect(nn.Module): - def forward(self, input): - logs.append((is_checkpointing(), is_recomputing())) - return input - - model = Detect() - input = torch.rand(1, requires_grad=True) - - output = checkpoint(model, input) - output.backward() - - assert logs == [(True, False), (False, True)] - - -def test_detect_checkpointing_recomputing_without_checkpoint(): - logs = [] - - class Detect(nn.Module): - def forward(self, input): - logs.append((is_checkpointing(), is_recomputing())) - return input - - model = Detect() - input = torch.rand(1, requires_grad=True) - - output = model(input) - output.backward() - - assert logs == [(False, False)] - - -def test_non_grad_output(): - class ForkNonGrad(nn.Module): - def forward(self, input): - return (input * 2, torch.rand(1)) - - model = ForkNonGrad() - input = torch.rand(1, requires_grad=True) - - output = checkpoint(model, input) - output[0].backward() - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_copy.py b/test/distributed/pipeline/sync/test_copy.py deleted file mode 100644 index 302c3d25d53f..000000000000 --- a/test/distributed/pipeline/sync/test_copy.py +++ /dev/null @@ -1,85 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import pytest - -import torch - -from torch.distributed.pipeline.sync.copy import Copy, Wait -from torch.distributed.pipeline.sync.stream import ( - CPUStream, - current_stream, - get_device, - is_cuda, - new_stream, - use_stream, -) -from torch.testing._internal.common_utils import run_tests - -skip_if_no_cuda = pytest.mark.skipif( - not torch.cuda.is_available(), reason="cuda required" -) - - -def _test_copy_wait(prev_stream, next_stream, cuda_sleep=None): - device = get_device(prev_stream) - - with use_stream(prev_stream): - if is_cuda(prev_stream): - cuda_sleep(0.5) - x = torch.ones(100, device=device, requires_grad=True) - - (y,) = Copy.apply(prev_stream, next_stream, x) - (y,) = Wait.apply(prev_stream, next_stream, x) - - with use_stream(next_stream): - assert torch.allclose(y.sum(), torch.tensor(100.0, device=device)) - y.norm().backward() - with use_stream(prev_stream): - assert torch.allclose(x.grad.sum(), torch.tensor(10.0, device=device)) - - -def test_copy_wait_cpu_cpu(): - prev_stream = CPUStream - next_stream = CPUStream - _test_copy_wait(prev_stream, next_stream) - - -@skip_if_no_cuda -def test_copy_wait_cpu_cuda(cuda_sleep): - prev_stream = CPUStream - next_stream = current_stream(torch.device("cuda")) - _test_copy_wait(prev_stream, next_stream, cuda_sleep) - - -@skip_if_no_cuda -def test_copy_wait_cuda_cpu(cuda_sleep): - prev_stream = current_stream(torch.device("cuda")) - next_stream = CPUStream - _test_copy_wait(prev_stream, next_stream, cuda_sleep) - - -@skip_if_no_cuda -def test_copy_wait_cuda_cuda(cuda_sleep): - prev_stream = current_stream(torch.device("cuda")) - next_stream = new_stream(torch.device("cuda")) - _test_copy_wait(prev_stream, next_stream, cuda_sleep) - - -def test_wait_multiple_tensors(): - a = torch.rand(1, requires_grad=True) - b = torch.rand(1, requires_grad=True) - - a, b = Wait.apply(CPUStream, CPUStream, a, b) - - assert a.grad_fn is b.grad_fn - assert a.grad_fn.__class__ is Wait._backward_cls - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_deferred_batch_norm.py b/test/distributed/pipeline/sync/test_deferred_batch_norm.py deleted file mode 100644 index c3807c57d612..000000000000 --- a/test/distributed/pipeline/sync/test_deferred_batch_norm.py +++ /dev/null @@ -1,200 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -from copy import deepcopy -from itertools import chain - -import pytest - -import torch -from torch import nn, optim - -from torch.distributed.pipeline.sync.batchnorm import DeferredBatchNorm -from torch.testing._internal.common_utils import run_tests - -CHUNKS = 4 - - -def tilt_dist(input): - # Tilt variance by channel. - rgb = input.transpose(0, 1) - rgb[0] *= 1 - rgb[1] *= 10 - rgb[2] *= 100 - - # Tilt mean by single batch. - for i, single in enumerate(input): - single += 2**i - - return input - - -def chunked_forward(model, input, chunks=CHUNKS): - output_chunks = [] - - for chunk in input.chunk(chunks): - output_chunks.append(model(chunk)) - - return torch.cat(output_chunks) - - -@pytest.mark.parametrize("chunks", [1, 4]) -@pytest.mark.parametrize("input_requires_grad", [True, False]) -def test_transparency(chunks, input_requires_grad): - bn = nn.BatchNorm2d(3) - dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=chunks) - - input1 = torch.rand(16, 3, 224, 224) - input1 = tilt_dist(input1) - input2 = input1.clone() - input1.requires_grad = input_requires_grad - input2.requires_grad = input_requires_grad - - output1 = chunked_forward(bn, input1, chunks=chunks) - output2 = chunked_forward(dbn, input2, chunks=chunks) - - assert torch.allclose(output1, output2, atol=1e-4) - - output1.mean().backward() - output2.mean().backward() - - assert torch.allclose(bn.weight.grad, dbn.weight.grad, atol=1e-4) - - if input_requires_grad: - assert input1.grad is not None - assert input2.grad is not None - assert torch.allclose(input1.grad, input2.grad, atol=1e-4) - - -@pytest.mark.parametrize("momentum", [0.1, None]) -def test_running_stats(momentum): - bn = nn.BatchNorm2d(3, momentum=momentum) - dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS) - - input = torch.rand(16, 3, 224, 224) - input = tilt_dist(input) - - bn(input) - chunked_forward(dbn, input) - - assert torch.allclose(bn.running_mean, dbn.running_mean, atol=1e-4) - assert torch.allclose(bn.running_var, dbn.running_var, atol=1e-4) - - -def test_convert_deferred_batch_norm(): - bn = nn.BatchNorm2d(3, track_running_stats=False) - bn = DeferredBatchNorm.convert_deferred_batch_norm(bn, chunks=CHUNKS) - assert type(bn) is nn.BatchNorm2d # because of track_running_stats=False - - dbn = DeferredBatchNorm(3, chunks=CHUNKS) - dbn_again = DeferredBatchNorm.convert_deferred_batch_norm(dbn, chunks=CHUNKS) - assert dbn is dbn_again - - dbn_again = DeferredBatchNorm.convert_deferred_batch_norm(dbn, chunks=CHUNKS + 1) - assert dbn is not dbn_again # because of different chunks - - -def test_eval(): - bn = nn.BatchNorm2d(3) - dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS) - - input = torch.rand(16, 3, 224, 224) - input = tilt_dist(input) - - bn(input) - chunked_forward(dbn, input) - - bn.eval() - dbn.eval() - - assert torch.allclose(bn(input), dbn(input), atol=1e-4) - - -def test_optimize(): - bn = nn.BatchNorm2d(3) - dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS) - - opt = optim.SGD(chain(bn.parameters(), dbn.parameters()), lr=1.0) - - for i in range(5): - input = torch.rand(16, 3, 224, 224) - input = tilt_dist(input) - - # train - y = bn(input) - a = y.sum() - a.backward() - - y = chunked_forward(dbn, input) - b = y.sum() - b.backward() - - opt.step() - - # eval - bn.eval() - dbn.eval() - - with torch.no_grad(): - assert torch.allclose(bn(input), dbn(input), atol=1e-1 * (10**i)) - - -def test_conv_bn(): - bn = nn.Sequential(nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3)) - dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS) - - input = torch.rand(16, 3, 224, 224) - input = tilt_dist(input) - - opt = optim.SGD(chain(bn.parameters(), dbn.parameters()), lr=0.1) - - # 1st step - a = bn(input) - b = chunked_forward(dbn, input) - - # Outputs are different. (per-mini-batch vs. per-micro-batch) - assert not torch.allclose(a, b) - - a.sum().backward() - b.sum().backward() - opt.step() - opt.zero_grad() - - # Conv layers are also trained differently because of their different outputs. - assert not torch.allclose(bn[0].weight, dbn[0].weight) - - # But BNs track identical running stats. - assert torch.allclose(bn[1].running_mean, dbn[1].running_mean, atol=1e-4) - assert torch.allclose(bn[1].running_var, dbn[1].running_var, atol=1e3) - - # 2nd step - a = bn(input) - b = chunked_forward(dbn, input) - a.sum().backward() - b.sum().backward() - - # BNs can't track identical running stats due to the different conv layers. - assert not torch.allclose(bn[1].running_mean, dbn[1].running_mean, atol=1e-4) - assert not torch.allclose(bn[1].running_var, dbn[1].running_var, atol=1e3) - - -def test_input_requiring_grad(): - dbn = DeferredBatchNorm(3, chunks=CHUNKS) - - input = torch.rand(16, 3, 224, 224) - input = tilt_dist(input) - input.requires_grad = True - - chunked_forward(dbn, input) - - assert not dbn.sum.requires_grad - assert dbn.sum.grad_fn is None - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_dependency.py b/test/distributed/pipeline/sync/test_dependency.py deleted file mode 100644 index e966d6541bf5..000000000000 --- a/test/distributed/pipeline/sync/test_dependency.py +++ /dev/null @@ -1,152 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import weakref - -import pytest - -import torch - -from torch.distributed.pipeline.sync.dependency import Fork, fork, Join, join -from torch.testing._internal.common_utils import run_tests - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") -def test_fork_join(): - logs = [] - - class Log(torch.autograd.Function): - @staticmethod - def forward(ctx, number, tensor): - ctx.number = number - return tensor.detach() - - @staticmethod - def backward(ctx, grad): - logs.append(ctx.number) - return None, grad - - a = torch.rand(1, device="cpu", requires_grad=True) - b = torch.rand(1, device="cuda", requires_grad=True) - - a = Log.apply(1, a) - - a, phony = fork(a) - b = join(a, phony) - - b = Log.apply(2, b) - b = b.to("cpu") - - (a + b).backward() - - assert logs == [2, 1] - - -def test_fork_join_enable_grad(): - x = torch.rand(1, requires_grad=True) - - with torch.enable_grad(): - x2, p = fork(x) - - assert p.requires_grad - assert x2 is not x - x = x2 - - assert x.requires_grad - assert p.requires_grad - assert x.grad_fn.__class__ is Fork._backward_cls - assert p.grad_fn.__class__ is Fork._backward_cls - - with torch.enable_grad(): - x2 = join(x, p) - - assert x2 is not x - x = x2 - - assert x.requires_grad - assert x.grad_fn.__class__ is Join._backward_cls - - -def test_fork_join_no_grad(monkeypatch): - def do_not_apply(*args): - raise AssertionError("Function.apply called") - - monkeypatch.setattr("torch.autograd.Function.apply", do_not_apply) - - x = torch.rand(1, requires_grad=True) - - with torch.no_grad(): - x2, p = fork(x) - - assert not p.requires_grad - assert x2 is x - x = x2 - - with torch.no_grad(): - x2 = join(x, p) - - assert x2 is x - x = x2 - - -def test_fork_leak(): - leak = None - - class F(torch.autograd.Function): - @staticmethod - def forward(ctx, input): - return input - - @staticmethod - def backward(ctx, grad): - nonlocal leak - leak = weakref.ref(ctx) - return grad - - x = torch.rand(1, requires_grad=True) - x = F.apply(x) - x, phony = fork(x) - x = join(x, phony) - - x.backward() - del x, phony - - assert leak() is None - - -def test_join_when_fork_not_requires_grad(): - x = torch.rand(2, 1) - a, b = x.chunk(2) - - assert not a.requires_grad - a, p = fork(a) - assert not a.requires_grad - assert not p.requires_grad - - assert not b.requires_grad - b = join(b, p) - assert not b.requires_grad - - -def test_join_when_fork_requires_grad(): - x = torch.rand(2, 1) - a, b = x.chunk(2) - - a.requires_grad_() - assert a.requires_grad - a, p = fork(a) - assert a.requires_grad - assert p.requires_grad - - assert not b.requires_grad - b = join(b, p) - assert b.requires_grad - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_inplace.py b/test/distributed/pipeline/sync/test_inplace.py deleted file mode 100644 index 33f31b2a52bb..000000000000 --- a/test/distributed/pipeline/sync/test_inplace.py +++ /dev/null @@ -1,79 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import pytest - -import torch -from torch import nn - -from torch.distributed.pipeline.sync import Pipe -from torch.testing._internal.common_utils import run_tests - - -def test_inplace_on_requires_grad(setup_rpc): - model = nn.Sequential(nn.Linear(1, 1), nn.ReLU(inplace=True)) - model = Pipe(model, checkpoint="always") - - x = torch.rand(1) - y = model(x).local_value() - - message = r"a leaf Variable that requires grad .* used in an in-place operation." - with pytest.raises(RuntimeError, match=message): - y.backward() - - -@pytest.mark.xfail(strict=True) -def test_inplace_on_not_requires_grad(setup_rpc): - # In-place operation on a tensor not requiring grad doesn't cause a - # RuntimeError. Currently, we cannot detect this case. - model = nn.Sequential(nn.ReLU(inplace=True)) - model = Pipe(model, [1], devices=["cpu"], checkpoint="always") - - x = torch.rand(1) - y = model(x).local_value() - del model - - message = r"a leaf Variable that requires grad .* used in an in-place operation." - with pytest.raises(RuntimeError, match=message): - y.backward() - - -@pytest.mark.xfail(strict=True) -def test_inplace_incorrect_grad(setup_rpc): - class M(nn.Module): - def forward(self, foo_bar): - # 'foo' requires grad but 'bar' does not. In-place operation on - # 'bar' won't cause a RuntimeError. - foo, bar = foo_bar - - # add_(1) is not idempotent, in contrast to relu_(). If it is - # executed multiple times, it will accumulates each difference onto - # 'bar'. - bar.add_(1) - - # 'bar' is still captured by checkpointing. 'foo' will get - # incorrect grad. - return foo * bar - - model = nn.Sequential(M()) - model = Pipe(model, [1], devices=["cpu"], checkpoint="always") - - foo = torch.tensor([1.0], requires_grad=True) - bar = torch.tensor([1.0]) - - output = model((foo, bar)).local_value() - del model - output.backward() - - # The gradient of 'foo' should be 2, but it is 3 actually because - # bar.add_(1) was executed twice due to checkpointing. - assert foo.grad.item() == 2.0 - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_microbatch.py b/test/distributed/pipeline/sync/test_microbatch.py deleted file mode 100644 index b5e44aa73a8d..000000000000 --- a/test/distributed/pipeline/sync/test_microbatch.py +++ /dev/null @@ -1,148 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import pytest - -import torch -import torch.cuda - -from torch.distributed.pipeline.sync.microbatch import Batch, check, gather, scatter -from torch.testing._internal.common_utils import run_tests - - -def test_batch_atomic(): - x = torch.tensor(42) - b = Batch(x) - - assert b.atomic - - assert b.tensor is x - with pytest.raises(AttributeError): - b.tensors - - assert list(b) == [x] - assert len(b) == 1 - assert b[0] is x - - -def test_batch_non_atomic(): - x, y = torch.tensor(42), torch.tensor(21) - b = Batch((x, y)) - - assert not b.atomic - - with pytest.raises(AttributeError): - b.tensor - - assert list(b) == [x, y] - assert len(b) == 2 - assert b[0] is x - assert b[1] is y - - -def test_batch_call(): - a = Batch(torch.tensor(42)) - b = Batch((torch.tensor(42), torch.tensor(21))) - - def f(x): - return x - - def g(x, y): - return x, y - - assert a.call(f).atomic - assert not b.call(g).atomic - - -def test_batch_setitem_by_index(): - a = Batch(torch.tensor(42)) - b = Batch((torch.tensor(42), torch.tensor(21))) - - a[0] = torch.tensor(0) - b[0] = torch.tensor(0) - - assert a.atomic - assert a[0].item() == 0 - - assert not b.atomic - assert len(b) == 2 - assert b[0].item() == 0 - assert b[1].item() == 21 - - -def test_batch_setitem_by_slice(): - a = Batch(torch.tensor(42)) - b = Batch((torch.tensor(42), torch.tensor(21))) - - a[:] = (torch.tensor(0),) - b[:] = (torch.tensor(0),) - - assert a.atomic - assert a[0].item() == 0 - - assert not b.atomic - assert len(b) == 1 - assert b[0].item() == 0 - - -def test_check(): - check(torch.device("cpu"), torch.tensor(42)) - check(torch.device("cpu"), torch.tensor(4), torch.tensor(2)) - - with pytest.raises(TypeError): - check(torch.device("cpu"), 42) - - with pytest.raises(TypeError): - check(torch.device("cpu"), "str") - - with pytest.raises(TypeError): - check(torch.device("cpu"), (torch.tensor(4), 2)) - - -def test_gather_tensors(): - a = torch.zeros(1, 1) - b = torch.zeros(1, 1) - - ab = gather([Batch(a), Batch(b)]) - - assert ab.size() == (2, 1) - - -def test_gather_tuples(): - a = (torch.zeros(1, 1), torch.zeros(2, 2)) - b = (torch.zeros(1, 1), torch.zeros(2, 2)) - - ab = gather([Batch(a), Batch(b)]) - - assert isinstance(ab, tuple) - assert ab[0].size() == (2, 1) - assert ab[1].size() == (4, 2) - - -def test_scatter_tensor(): - ab = torch.zeros(2, 1) - - a, b = scatter(ab, chunks=2) - - assert a.tensor.size() == (1, 1) - assert b.tensor.size() == (1, 1) - - -def test_scatter_multiple_tensors(): - ab = (torch.zeros(2, 1), torch.zeros(4, 2)) - - a, b = scatter(*ab, chunks=2) - - assert next(iter(a)).size() == (1, 1) - assert next(iter(b)).size() == (1, 1) - assert list(a)[1].size() == (2, 2) - assert list(b)[1].size() == (2, 2) - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_phony.py b/test/distributed/pipeline/sync/test_phony.py deleted file mode 100644 index 6aeb873b30b2..000000000000 --- a/test/distributed/pipeline/sync/test_phony.py +++ /dev/null @@ -1,57 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import torch - -from torch.distributed.pipeline.sync.phony import get_phony -from torch.testing._internal.common_utils import run_tests - - -def test_phony_size(): - p = get_phony(torch.device("cpu"), requires_grad=False) - assert p.size() == (0,) - - -def test_phony_requires_grad(): - p1 = get_phony(torch.device("cpu"), requires_grad=True) - p2 = get_phony(torch.device("cpu"), requires_grad=False) - assert p1.requires_grad - assert not p2.requires_grad - - -def test_cached_phony(): - p1 = get_phony(torch.device("cpu"), requires_grad=True) - p2 = get_phony(torch.device("cpu"), requires_grad=True) - assert p1 is p2 - - p3 = get_phony(torch.device("cpu"), requires_grad=False) - p4 = get_phony(torch.device("cpu"), requires_grad=False) - assert p3 is p4 - - assert p1 is not p3 - - -def test_phony_in_autograd_function(): - class Phonify(torch.autograd.Function): - @staticmethod - def forward(ctx, input): - phony = get_phony(input.device, requires_grad=False) - return phony.detach() - - x = torch.rand(1, requires_grad=True) - - p1 = Phonify.apply(x) - p2 = get_phony(torch.device("cpu"), requires_grad=True) - - assert p1 is not p2 - assert p1.grad_fn is not None - assert p2.grad_fn is None - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_pipe.py b/test/distributed/pipeline/sync/test_pipe.py deleted file mode 100644 index e493b1d5a03e..000000000000 --- a/test/distributed/pipeline/sync/test_pipe.py +++ /dev/null @@ -1,858 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import random -import time -from collections import OrderedDict -from copy import deepcopy - -import pytest - -import torch -from torch import nn, Tensor - -from torch.distributed.pipeline.sync import NoChunk, Pipe, WithDevice -from torch.distributed.pipeline.sync.pipe import PipeSequential -from torch.testing._internal.common_cuda import TEST_MULTIGPU -from torch.testing._internal.common_utils import run_tests, TEST_CUDA - -skip_if_no_cuda = pytest.mark.skipif(not TEST_CUDA, reason="cuda required") - - -def test_pipe_without_rpc(): - model = nn.Sequential(nn.Linear(1, 1)) - with pytest.raises(RuntimeError, match="Please initialize RPC framework"): - pipe = Pipe(model, chunks=1) - - -def test_parameters(setup_rpc): - model = nn.Sequential(nn.Linear(1, 1)) - pipe = Pipe(model, chunks=1) - assert list(pipe.parameters()) != [] - - -def test_public_attrs(setup_rpc): - class MyString: - def __init__(self, value): - self.value = value - - def __str__(self): - return self.value - - model = nn.Sequential(nn.Linear(1, 1)) - pipe = Pipe(model, chunks=42.000, checkpoint=MyString("always")) - - assert pipe.devices == [torch.device("cpu")] - assert pipe.chunks == 42 - assert isinstance(pipe.chunks, int) - assert pipe.checkpoint == "always" - assert isinstance(pipe.checkpoint, str) - - -def test_sequential_like(setup_rpc): - a = nn.Linear(1, 1) - b = nn.Linear(1, 1) - - model = nn.Sequential(a, b) - model = Pipe(model) - - assert len(model) == 2 - assert list(model) == [a, b] - - assert model[0] is a - assert model[1] is b - with pytest.raises(IndexError): - _ = model[2] - - assert model[-1] is b - assert model[-2] is a - - -def test_chunks_less_than_1(setup_rpc): - model = nn.Sequential(nn.Linear(1, 1)) - - with pytest.raises(ValueError): - Pipe(model, chunks=0) - - with pytest.raises(ValueError): - Pipe(model, chunks=-1) - - -def test_batch_size_indivisible(setup_rpc): - model = nn.Sequential(nn.Linear(1, 1)) - model = Pipe(model, chunks=4) - - with pytest.warns(None) as record: - model(torch.rand(7, 1)) - - # Indivisible batch size is legal. - assert not record - - -def test_batch_size_small(setup_rpc): - model = nn.Sequential(nn.Linear(1, 1)) - model = Pipe(model, chunks=4) - - with pytest.warns(None) as record: - model(torch.rand(2, 1)) - - # Batch size smaller than chunks is legal. - assert not record - - -def test_checkpoint_mode(setup_rpc): - def count_grad_fn(grad_fn, name, visited=None): - if visited is None: - visited = set() - if grad_fn in visited: - return 0 - visited.add(grad_fn) - - if grad_fn is None: - return 0 - if grad_fn.__class__.__name__ == name: - return 1 - - counter = 0 - for next_grad_fn, _ in grad_fn.next_functions: - counter += count_grad_fn(next_grad_fn, name, visited=visited) - return counter - - model = nn.Sequential(nn.Linear(1, 1)) - input = torch.rand(2, 1) - - always = Pipe(model, chunks=2, checkpoint="always") - except_last = Pipe(model, chunks=2, checkpoint="except_last") - never = Pipe(model, chunks=2, checkpoint="never") - - always_output = always(input) - except_last_output = except_last(input) - never_output = never(input) - - assert count_grad_fn(always_output.local_value().grad_fn, "CheckpointBackward") == 2 - assert ( - count_grad_fn(except_last_output.local_value().grad_fn, "CheckpointBackward") - == 1 - ) - assert count_grad_fn(never_output.local_value().grad_fn, "CheckpointBackward") == 0 - - -def test_checkpoint_mode_invalid(setup_rpc): - model = nn.Sequential(nn.Linear(1, 1)) - - with pytest.raises( - ValueError, match="checkpoint is not one of 'always', 'except_last', or 'never'" - ): - Pipe(model, chunks=2, checkpoint="INVALID_CHECKPOINT") - - -def test_checkpoint_mode_when_chunks_1(setup_rpc): - model = nn.Sequential(nn.Linear(1, 1)) - - # All checkpoint modes are fine. - Pipe(model, chunks=1, checkpoint="except_last") - Pipe(model, chunks=1, checkpoint="always") - Pipe(model, chunks=1, checkpoint="never") - - -def test_checkpoint_eval(setup_rpc): - model = nn.Sequential(nn.Linear(1, 1)) - model = Pipe(model, chunks=2) - input = torch.rand(2, 1) - - def find_grad_fn(grad_fn, name): - if grad_fn is None: - return False - if grad_fn.__class__.__name__ == name: - return True - for next_grad_fn, _ in grad_fn.next_functions: - if find_grad_fn(next_grad_fn, name): - return True - return False - - model.train() - train_output = model(input) - assert find_grad_fn(train_output.local_value().grad_fn, "CheckpointBackward") - assert find_grad_fn(train_output.local_value().grad_fn, "RecomputeBackward") - - model.eval() - eval_output = model(input) - assert not find_grad_fn(eval_output.local_value().grad_fn, "CheckpointBackward") - assert not find_grad_fn(eval_output.local_value().grad_fn, "RecomputeBackward") - - -def test_checkpoint_non_float_input(setup_rpc): - class ForkNonFloat(nn.Module): - def forward(self, input): - return (input * 2, torch.tensor([False])) - - class JoinNonFloat(nn.Module): - def forward(self, input, non_float): - return input * 2 - - model = nn.Sequential(ForkNonFloat(), JoinNonFloat()) - model = Pipe(model, chunks=1, checkpoint="always") - - input = torch.rand(1, requires_grad=True) - output = model(input) - output.backward() - - -def test_no_grad(setup_rpc): - model = nn.Sequential(nn.Linear(1, 1)) - model = Pipe(model, chunks=2) - input = torch.rand(2, 1) - - latent = None - - def hook(module, input, output): - _ = module - _ = input - - nonlocal latent - latent = output - - partition = model.partitions[0] - partition.register_forward_hook(hook) - - with torch.no_grad(): - model(input) - - assert latent.grad_fn is None - - -def test_exception(setup_rpc): - class ExpectedException(Exception): - pass - - class Raise(nn.Module): - def forward(self, *_): - raise ExpectedException - - model = nn.Sequential(Raise()) - model = Pipe(model, chunks=1) - - with pytest.raises(ExpectedException): - model(torch.rand(1)) - - -def test_exception_early_stop_asap(setup_rpc): - """Even the first partitions have finished to process, the partition before - the failed partition should be killed as soon as possible. - """ - - class ExpectedException(Exception): - pass - - class Pass(nn.Module): - def forward(self, x): - return x - - counter = 0 - - class Counter(nn.Module): - def forward(self, x): - time.sleep(0.1) - - nonlocal counter - counter += 1 - - return x - - class Raise(nn.Module): - def forward(self, x): - raise ExpectedException - - model = nn.Sequential(Pass(), Pass(), Counter(), Raise()) - model = Pipe(model, chunks=3) - - with pytest.raises(ExpectedException): - model(torch.rand(3)) - - # If the early stop doesn't work, it would be 3 instead. - assert counter == 2 - - -def test_nested_input(setup_rpc): - class NestedInput(nn.Module): - def __init__(self): - super().__init__() - self.fc_a = nn.Linear(1, 1) - self.fc_b = nn.Linear(1, 1) - - def forward(self, inp): - return inp - - model = nn.Sequential(NestedInput()) - model = Pipe(model, chunks=2) - - a = torch.rand(10, 1, requires_grad=True) - b = torch.rand(10, 1, requires_grad=True) - - # TypeError: expected Tensor, but got tuple - with pytest.raises(TypeError): - model((a, (a, b))).local_value() - - # TypeError: expected Tensor, but got list - with pytest.raises(TypeError): - model((a, [a, b])).local_value() - - -def test_input_pair(setup_rpc): - class Two(nn.Module): - def __init__(self): - super().__init__() - self.fc_a = nn.Linear(1, 1) - self.fc_b = nn.Linear(1, 1) - - def forward(self, a, b): - return (self.fc_a(a), self.fc_b(b)) - - model = nn.Sequential(Two()) - model = Pipe(model, chunks=2) - - a = torch.rand(10, 1, requires_grad=True) - b = torch.rand(10, 1, requires_grad=True) - - a_out, b_out = model(a, b).local_value() - loss = (a_out + b_out).mean() - loss.backward() - - assert a.grad is not None - assert b.grad is not None - - -def test_multi_sequence_input(setup_rpc): - class MultiSeq(nn.Module): - def forward(self, tup1, tup2): - return tup1, tup2 - - model = Pipe(nn.Sequential(MultiSeq())) - with pytest.raises(TypeError): - model([torch.rand(10), torch.rand(10)], [torch.rand(10), torch.rand(10)]) - - -def test_input_singleton(setup_rpc): - class One(nn.Module): - def __init__(self): - super().__init__() - self.fc = nn.Linear(1, 1) - - def forward(self, a): - return (self.fc(a),) - - model = nn.Sequential(One()) - model = Pipe(model, chunks=2) - - a = torch.rand(10, 1, requires_grad=True) - - (a_out,) = model(a).local_value() - loss = a_out.mean() - loss.backward() - - assert all(p.grad is not None for p in model.parameters()) - assert a.grad is not None - - -def test_input_varargs(setup_rpc): - model = nn.Sequential(nn.Linear(1, 1)) - model = Pipe(model) - - a = torch.rand(1) - b = torch.rand(1) - - # TypeError: forward() takes 2 positional arguments but 3 were given - with pytest.raises(TypeError): - model(a, b) - - -def test_non_tensor(setup_rpc): - class NonTensor(nn.Module): - def forward(self, _): - return "hello" - - model = nn.Sequential(NonTensor()) - model = Pipe(model) - x = torch.rand(1) - - with pytest.raises(TypeError): - model(x) - - with pytest.raises(TypeError): - model("hello") - - -def test_non_tensor_sequence(setup_rpc): - class NonTensorTuple(nn.Module): - def forward(self, x): - return (x, "hello") - - class NonTensorArgs(nn.Module): - def forward(self, x: str, y: bool): - return x, y - - model = nn.Sequential(NonTensorTuple()) - model = Pipe(model) - x = torch.rand(1) - - with pytest.raises(TypeError): - model((x, "hello")) - - with pytest.raises(TypeError): - model([x, "hello"]) - - model = nn.Sequential(NonTensorArgs()) - model = Pipe(model) - - with pytest.raises(TypeError): - # Need atleast one Tensor. - model("hello", True) - - -@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) -def test_valid_non_tensor(checkpoint, setup_rpc): - class NonTensor1(nn.Module): - def forward(self, a: int, b: Tensor, c: bool, d: Tensor): - res = b + a if c else b * a - if d is not None: - res += d - return res, c, a, b, "hello", d - - class NonTensor2(nn.Module): - def forward(self, a: Tensor, b: bool, c: int, d: Tensor, e: str, f: Tensor): - res = a * c if b else a + c - res += d - return c, res, a, d + f if f is not None else d, b, e, f - - model = Pipe( - nn.Sequential(NonTensor1(), NonTensor2()), chunks=5, checkpoint=checkpoint - ) - a = random.randint(0, 10) - b = torch.rand(10, 10) - c = random.randint(0, 1) == 0 - d = torch.rand(10, 10) - res = model(a, b, c, d).local_value() - assert 7 == len(res) - assert [a] * 5 == res[0] - if c: - assert torch.allclose(((b + a + d) * a) + b, res[1]) - assert torch.allclose(b + a + d, res[2]) - else: - assert torch.allclose(((b * a) + d + a) + b, res[1]) - assert torch.allclose(b * a + d, res[2]) - assert torch.allclose(b + d, res[3]) - assert [c] * 5 == res[4] - assert ["hello"] * 5 == res[5] - assert torch.allclose(d, res[6]) - - # Test one of the tensors can be None - res = model(a, b, c, None).local_value() - assert 7 == len(res) - assert [a] * 5 == res[0] - if c: - assert torch.allclose(((b + a) * a) + b, res[1]) - assert torch.allclose(b + a, res[2]) - else: - assert torch.allclose(((b * a) + a) + b, res[1]) - assert torch.allclose(b * a, res[2]) - assert torch.allclose(b, res[3]) - assert [c] * 5 == res[4] - assert ["hello"] * 5 == res[5] - assert [None] * 5 == res[6] - - # Need atleast one tensor. - with pytest.raises(TypeError): - model(a, None, c, None) - - -@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) -def test_no_tensor_output(checkpoint, setup_rpc): - class Model1(nn.Module): - def forward(self, a: int, b: Tensor, c: bool): - return a, c, "hello" - - class Model2(nn.Module): - def forward(self, a: int, b: bool, c: str): - return a, c, b - - model = Pipe(nn.Sequential(Model1(), Model2()), chunks=5) - a = random.randint(0, 10) - b = torch.rand(10, 10) - c = random.randint(0, 1) == 0 - - # Need atleast one tensor across partitions too. - with pytest.raises(TypeError): - res = model(a, b, c).local_value() - - -@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) -def test_uneven_batch_size(checkpoint, setup_rpc): - class Model(nn.Module): - def forward(self, a: Tensor, b: int, c: Tensor): - return a, b, c - - model = Pipe(nn.Sequential(Model()), checkpoint=checkpoint, chunks=5) - a = torch.rand(3, 10) - b = random.randint(0, 10) - c = torch.rand(6, 10) - res = model(a, b, c).local_value() - assert torch.allclose(a, res[0]) - assert [b] * 3 == res[1] # 3 chunks - assert torch.allclose(c, res[2]) - - # Two tensors producing uneven chunks would fail. - model = Pipe(nn.Sequential(Model()), checkpoint=checkpoint, chunks=5) - a = torch.rand(3, 10) - b = random.randint(0, 10) - c = torch.rand(4, 10) - - with pytest.raises(RuntimeError, match="Found different number of chunks"): - model(a, b, c) - - -@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) -def test_no_chunk(checkpoint, setup_rpc): - class Model(nn.Module): - def forward(self, a: Tensor, b: int, c: Tensor): - return a, b, c - - model = Pipe(nn.Sequential(Model()), checkpoint=checkpoint, chunks=5) - a = torch.rand(10, 10) - b = random.randint(0, 10) - c = torch.rand(10, 10) - res = model(a, b, NoChunk(c)).local_value() - assert torch.allclose(a, res[0]) - assert [b] * 5 == res[1] - # c gets replicated due to NoChunk and the same tensor gets concatenated 5 - # times in the output. - assert torch.allclose(torch.cat((c, c, c, c, c)), res[2]) - - # Test invalid type for NoChunk - with pytest.raises(TypeError, match="NoChunk only supported for tensors"): - NoChunk(b) - - -@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) -def test_deferred_batch_norm(checkpoint, setup_rpc): - bn = nn.BatchNorm2d(3) - pipe_bn = deepcopy(bn) - pipe = Pipe( - nn.Sequential(pipe_bn), - chunks=2, - checkpoint=checkpoint, - deferred_batch_norm=True, - ) - - x = torch.rand(4, 3, 10, 10) - pipe(x).local_value().mean().backward() - bn(x).mean().backward() - - assert torch.allclose(pipe[0].running_mean, bn.running_mean, atol=1e-4) - assert torch.allclose(pipe[0].running_var, bn.running_var, atol=1e-4) - - -@pytest.mark.parametrize("checkpoint", ["never", "always"]) -def test_deferred_batch_norm_params(checkpoint, setup_rpc): - bn = nn.BatchNorm2d(3) - pipe_bn = deepcopy(bn) - pipe = Pipe( - nn.Sequential(pipe_bn), - chunks=1, - checkpoint=checkpoint, - deferred_batch_norm=True, - ) - - x = torch.rand(4, 3, 10, 10) - pipe(x).local_value().mean().backward() - bn(x).mean().backward() - - assert pipe[0].weight.grad is not None - assert pipe[0].bias.grad is not None - - assert torch.allclose(pipe[0].weight.grad, bn.weight.grad, atol=1e-4) - assert torch.allclose(pipe[0].bias.grad, bn.bias.grad, atol=1e-4) - - -def test_devices(setup_rpc): - a = nn.Linear(1, 1) - b = nn.Linear(1, 1) - c = nn.Linear(1, 1) - - # There are extra two devices. - model = nn.Sequential(a, b, c) - model = Pipe(model) - - cpu = torch.device("cpu") - # Extra devices must be discarded. - assert model.devices == [cpu, cpu, cpu] - - -def test_partitions(setup_rpc): - a = nn.Linear(1, 1) - b = nn.Linear(1, 1) - - model = nn.Sequential(a, b) - model = Pipe(model) - - assert isinstance(model.partitions, nn.ModuleList) - assert isinstance(model.partitions[0], nn.Sequential) - assert isinstance(model.partitions[1], nn.Sequential) - - assert "partitions.0.0.weight" in model.state_dict() - - -@skip_if_no_cuda -def test_merged_partitions(setup_rpc): - a = nn.Linear(1, 1).to(0) - b = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 2)).to(0) - c = nn.Linear(1, 1) - d = nn.Linear(1, 2) - - model = nn.Sequential(a, b, c, d) - model = Pipe(model) - - assert isinstance(model.partitions, nn.ModuleList) - assert isinstance(model.partitions[0], PipeSequential) - assert isinstance(model.partitions[1], PipeSequential) - assert list(model.partitions[0]) == [a, b[0], b[1]] - assert list(model.partitions[1]) == [c] - assert list(model.partitions[2]) == [d] - - -def test_deny_moving(setup_rpc): - a = nn.Linear(1, 1) - b = nn.Linear(1, 1) - - model = nn.Sequential(a, b) - model = Pipe(model) - - # Moving is denied. - with pytest.raises(TypeError): - model.cuda() - - with pytest.raises(TypeError): - model.cpu() - - with pytest.raises(TypeError): - model.to(torch.device("cuda")) - - with pytest.raises(TypeError): - model.to(0) - - with pytest.raises(TypeError): - model.to("cuda") - - with pytest.raises(TypeError): - model.to(device=0) - - with pytest.raises(TypeError): - model.to(torch.rand(1)) - - with pytest.raises(TypeError): - model.to(tensor=torch.rand(1)) - - # Casting is allowed. - model.half() - model.to(torch.double) - model.to(dtype=torch.float) - - -def test_empty_module(setup_rpc): - # Empty sequential module is not illegal. - model = nn.Sequential() - model = Pipe(model) - - assert model(torch.tensor(42)).local_value() == torch.tensor(42) - - # But only tensor or tensors is legal in Pipe. - with pytest.raises(TypeError): - model(42) - - -def test_named_children(setup_rpc): - a = nn.Linear(1, 1) - b = nn.Linear(1, 1) - - model = nn.Sequential(OrderedDict([("a", a), ("b", b)])) - model = Pipe(model) - - names = {n for n, _ in model.named_modules()} - assert "partitions.0.0" in names - assert "partitions.1.0" in names - - # Pipe doesn't support __getattr__. Unlike nn.Sequential, Pipe requires - # several methods in its namespace. - with pytest.raises(AttributeError): - model.a - - -def test_verify_module_non_sequential(setup_rpc): - with pytest.raises( - TypeError, match="module must be nn.Sequential to be partitioned" - ): - Pipe(nn.Module()) - - -def test_verify_module_duplicate_children(setup_rpc): - conv = nn.Conv2d(3, 3, 1) - model = nn.Sequential(conv, conv) - - with pytest.raises( - ValueError, match="module with duplicate children is not supported" - ): - Pipe(model) - - -@skip_if_no_cuda -def test_verify_module_params_on_same_device(setup_rpc): - class Surrogate(nn.Module): - def __init__(self, param1, param2): - super().__init__() - self.param1 = param1 - self.param2 = param2 - - conv1 = nn.Conv2d(3, 3, 1) - conv2 = nn.Conv2d(3, 3, 1) - model = nn.Sequential(Surrogate(conv1, conv2.cuda())) - - with pytest.raises( - ValueError, - match=r"should have all parameters on a single device, please use .to\(\)" - " to place the module on a single device", - ): - Pipe(model) - - -@pytest.mark.skipif(not TEST_MULTIGPU, reason="Need atleast two GPUs") -def test_verify_nested_modules(setup_rpc): - model = nn.Sequential( - nn.Sequential(nn.Linear(32, 16).cuda(0), nn.Linear(16, 8).cuda(0)), - nn.Sequential(nn.Linear(8, 4).cuda(1), nn.Linear(4, 2).cuda(1)), - ) - - pipe = Pipe(model) - out = pipe(torch.rand(10, 32).cuda(0)) - assert out.local_value().device == torch.device("cuda:1") - assert out.local_value().size() == torch.Size([10, 2]) - - -def test_verify_module_duplicate_parameters_on_same_device(setup_rpc): - class Surrogate(nn.Module): - def __init__(self, module): - super().__init__() - self.module = module - - conv = nn.Conv2d(3, 3, 1) - model = nn.Sequential(Surrogate(conv), Surrogate(conv)) - - Pipe(model) - - -def test_forward_lockstep(setup_rpc): - timeline = [] - - class DelayedLog(nn.Module): - def __init__(self, j, seconds): - super().__init__() - self.i = 0 - self.j = j - self.seconds = seconds - - def forward(self, x): - time.sleep(self.seconds) - - timeline.append((self.i, self.j)) - self.i += 1 - - return x - - model = nn.Sequential(DelayedLog(0, seconds=0), DelayedLog(1, seconds=0.1)) - model = Pipe(model, chunks=3) - model(torch.rand(3, 1)) - - # Expected timeline: (Logs are recorded at !) - # - # Partition #0: 0! 1! 2! - # Partition #1: 000! 111! 222! - # - assert timeline == [(0, 0), (1, 0), (0, 1), (2, 0), (1, 1), (2, 1)] - - -@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) -@skip_if_no_cuda -def test_multiple_inputs(checkpoint, setup_rpc): - class Module1(nn.Module): - def forward(self, a, b, c): - return a + b + c, a * b * c - - class Module2(nn.Module): - def forward(self, a, b): - return a + b - - model = Pipe( - nn.Sequential(Module1().cuda(0), Module2().cuda(0)), - chunks=2, - checkpoint=checkpoint, - ) - t = torch.rand(10) - res = model(t, t, t).local_value() - assert torch.equal(res, (t + t + t) + (t * t * t)) - - -@pytest.mark.skipif(not TEST_MULTIGPU, reason="Need atleast two GPUs") -def test_inputs_wrong_device(setup_rpc): - class Module1(nn.Module): - def __init__(self): - super().__init__() - self.param = torch.nn.Parameter(torch.rand(5)) - - def forward(self, a, b): - return a + b + self.param, b - - # Start inputs on wrong device and ensure Pipe moves them correctly. - a = torch.rand(10).cuda(1) - b = torch.rand(10).cuda(1) - model = Pipe(nn.Sequential(Module1().cuda(0), Module1().cuda(1)), chunks=2) - with pytest.raises( - ValueError, - match="All inputs should be on the same device as the first partition", - ): - model(a, b) - - -@pytest.mark.skipif(not TEST_MULTIGPU, reason="Need atleast two GPUs") -def test_with_device_wrapper(setup_rpc): - fc1 = nn.Linear(16, 8).cuda(0) - fc2 = nn.Linear(8, 4).cuda(1) - dropout = nn.Dropout() - - model = nn.Sequential(fc1, fc2, WithDevice(dropout, "cuda:1")) - model = Pipe(model, chunks=8) - assert ( - torch.device("cuda:1") == model(torch.rand(16, 16).cuda(0)).local_value().device - ) - assert [torch.device("cuda:0"), torch.device("cuda:1")] == model.devices - - model = nn.Sequential(fc1, WithDevice(dropout, "cuda:1")) - model = Pipe(model, chunks=8) - assert ( - torch.device("cuda:1") == model(torch.rand(16, 16).cuda(0)).local_value().device - ) - assert [torch.device("cuda:0"), torch.device("cuda:1")] == model.devices - - model = nn.Sequential(fc1, WithDevice(fc2, "cuda:0")) - model = Pipe(model, chunks=8) - assert ( - torch.device("cuda:0") == model(torch.rand(16, 16).cuda(0)).local_value().device - ) - assert [torch.device("cuda:0")] == model.devices - assert torch.device("cuda:0") == fc2.weight.device - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_pipeline.py b/test/distributed/pipeline/sync/test_pipeline.py deleted file mode 100644 index 9548cb959db1..000000000000 --- a/test/distributed/pipeline/sync/test_pipeline.py +++ /dev/null @@ -1,36 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -from torch.distributed.pipeline.sync.pipeline import _clock_cycles -from torch.testing._internal.common_utils import run_tests - - -def test_clock_cycles(): - assert list(_clock_cycles(1, 1)) == [[(0, 0)]] - assert list(_clock_cycles(1, 3)) == [[(0, 0)], [(0, 1)], [(0, 2)]] - assert list(_clock_cycles(3, 1)) == [[(0, 0)], [(1, 0)], [(2, 0)]] - - assert list(_clock_cycles(3, 3)) == [ - [(0, 0)], - [(1, 0), (0, 1)], - [(2, 0), (1, 1), (0, 2)], - [(2, 1), (1, 2)], - [(2, 2)], - ] - - assert list(_clock_cycles(4, 2)) == [ - [(0, 0)], - [(1, 0), (0, 1)], - [(2, 0), (1, 1)], - [(3, 0), (2, 1)], - [(3, 1)], - ] - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_stream.py b/test/distributed/pipeline/sync/test_stream.py deleted file mode 100644 index f9702c8e4152..000000000000 --- a/test/distributed/pipeline/sync/test_stream.py +++ /dev/null @@ -1,198 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import pytest - -import torch - -from torch.distributed.pipeline.sync.stream import ( - CPUStream, - current_stream, - default_stream, - get_device, - is_cuda, - new_stream, - record_stream, - use_device, - use_stream, - wait_stream, -) -from torch.testing._internal.common_utils import run_tests - -skip_if_no_cuda = pytest.mark.skipif( - not torch.cuda.is_available(), reason="cuda required" -) - - -class TestNewStream: - def test_new_stream_cpu(self): - stream = new_stream(torch.device("cpu")) - assert stream is CPUStream - - @skip_if_no_cuda - def test_new_stream_cuda(self): - stream = new_stream(torch.device("cuda")) - assert isinstance(stream, torch.cuda.Stream) - assert stream != torch.cuda.default_stream() - - -class TestCurrentStream: - def test_current_stream_cpu(self): - stream = current_stream(torch.device("cpu")) - assert stream is CPUStream - - @skip_if_no_cuda - def test_current_stream_cuda(self): - stream = current_stream(torch.device("cuda")) - assert isinstance(stream, torch.cuda.Stream) - assert stream == torch.cuda.current_stream() - - -class TestDefaultStream: - def test_default_stream_cpu(self): - stream = default_stream(torch.device("cpu")) - assert stream is CPUStream - - @skip_if_no_cuda - def test_default_stream_cuda(self): - stream = default_stream(torch.device("cuda")) - assert isinstance(stream, torch.cuda.Stream) - assert stream == torch.cuda.default_stream() - - -class TestUseDevice: - def test_use_device_cpu(self): - with use_device(torch.device("cpu")): - pass - - @skip_if_no_cuda - def test_use_device_cuda(self): - with use_device(torch.device("cuda")): - pass - - -class TestUseStream: - def test_use_stream_cpu(self): - with use_stream(CPUStream): - pass - - @skip_if_no_cuda - def test_use_stream_cuda(self): - stream = new_stream(torch.device("cuda")) - with use_stream(stream): - assert current_stream(torch.device("cuda")) == stream - - -class TestGetDevice: - def test_get_device_cpu(self): - assert get_device(CPUStream).type == "cpu" - - @skip_if_no_cuda - def test_get_device_cuda(self): - stream = current_stream(torch.device("cuda")) - assert get_device(stream).type == "cuda" - - -class TestWaitStream: - def _test_wait_stream(self, source, target, cuda_sleep=None): - with use_stream(target): - if is_cuda(target): - cuda_sleep(0.5) - x = torch.ones(100, 100, device=get_device(target)) - - wait_stream(source, target) - - with use_stream(source): - assert x.sum().item() == 10000 - - def test_wait_stream_cpu_cpu(self): - source = CPUStream - target = CPUStream - self._test_wait_stream(source, target) - - @skip_if_no_cuda - def test_wait_stream_cpu_cuda(self, cuda_sleep): - source = CPUStream - target = new_stream(torch.device("cuda")) - self._test_wait_stream(source, target, cuda_sleep) - - @skip_if_no_cuda - def test_wait_stream_cuda_cpu(self, cuda_sleep): - source = new_stream(torch.device("cuda")) - target = CPUStream - self._test_wait_stream(source, target, cuda_sleep) - - @skip_if_no_cuda - def test_wait_stream_cuda_cuda(self, cuda_sleep): - source = current_stream(torch.device("cuda")) - target = new_stream(torch.device("cuda")) - self._test_wait_stream(source, target, cuda_sleep) - - -class TestRecordStream: - def test_record_stream_cpu(self): - # It should silently ignore CPU tensors. - x = torch.rand(1, device=torch.device("cpu")) - record_stream(x, CPUStream) - - @skip_if_no_cuda - def test_record_stream_cuda(self, cuda_sleep): - # This test detects unexpected block reallocation. For reliable test, - # the stream to allocate tensors is isolated. The allocator will not - # reuse free blocks which were allocated from another stream. - stream_alloc = new_stream(torch.device("cuda")) - with torch.cuda.stream(stream_alloc): - x = torch.rand(1, device=torch.device("cuda")) - - stream = new_stream(torch.device("cuda")) - record_stream(x, stream) - with use_stream(stream): - cuda_sleep(0.5) - - # 'x' is deleted at Python's perspective. But the block of 'x' is still - # required for 'stream'. 'y' shouldn't be allocated to the block. - data_ptr = x.data_ptr() - del x - stream_alloc.synchronize() - with torch.cuda.stream(stream_alloc): - y = torch.rand(1, device=torch.device("cuda")) - assert y.data_ptr() != data_ptr - - # Pause Python until 'stream' finishes tasks queued. Now the block of - # 'x' is free to be reallocated. - wait_stream(CPUStream, stream) - with torch.cuda.stream(stream_alloc): - z = torch.rand(1, device=torch.device("cuda")) - assert z.data_ptr() == data_ptr - - @skip_if_no_cuda - def test_record_stream_shifted_view(self, cuda_sleep): - # Issue: https://github.com/pytorch/pytorch/issues/27366 - stream_alloc = new_stream(torch.device("cuda")) - with torch.cuda.stream(stream_alloc): - x = torch.rand(2, device=torch.device("cuda")) - - y = x[1:] - assert y.data_ptr() > x.data_ptr() - - stream = new_stream(torch.device("cuda")) - with use_stream(stream): - cuda_sleep(0.5) - record_stream(y, stream) - - data_ptr = x.data_ptr() - del x, y - - stream_alloc.synchronize() - with torch.cuda.stream(stream_alloc): - z = torch.rand(2, device=torch.device("cuda")) - assert z.data_ptr() != data_ptr - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_transparency.py b/test/distributed/pipeline/sync/test_transparency.py deleted file mode 100644 index a87a04150fdc..000000000000 --- a/test/distributed/pipeline/sync/test_transparency.py +++ /dev/null @@ -1,55 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import torch -from torch import nn - -from torch.distributed.pipeline.sync import Pipe -from torch.testing._internal.common_utils import run_tests - - -def test_simple_linears(setup_rpc): - def sum_grad(parameters): - return sum(p.grad.sum() for p in parameters if p.grad is not None) - - def zero_grad(parameters): - for p in parameters: - p.grad = None - - inputs = torch.rand(8, 1) - model = nn.Sequential( - nn.Linear(1, 2), - nn.Linear(2, 4), - nn.Linear(4, 2), - nn.Linear(2, 1), - ) - - # Without Pipe - outputs = model(inputs) - loss = outputs.mean() - loss.backward() - - grad_without_pipe = sum_grad(model.parameters()) - - zero_grad(model.parameters()) - - # With Pipe - model = Pipe(model, chunks=4) - - outputs = model(inputs).local_value() - loss = outputs.mean() - loss.backward() - - grad_with_pipe = sum_grad(model.parameters()) - - # Both grads should be identical. - assert torch.allclose(grad_with_pipe, grad_without_pipe) - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_worker.py b/test/distributed/pipeline/sync/test_worker.py deleted file mode 100644 index f82af2ea0067..000000000000 --- a/test/distributed/pipeline/sync/test_worker.py +++ /dev/null @@ -1,118 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import threading - -import pytest - -import torch - -from torch.distributed.pipeline.sync.microbatch import Batch -from torch.distributed.pipeline.sync.stream import CPUStream -from torch.distributed.pipeline.sync.worker import spawn_workers, Task -from torch.testing._internal.common_utils import run_tests - - -class fake_device: - """A test double for :class:`torch.device`. Every fake device is different - with each other. - """ - - type = "fake" - index = None - - -def test_compute_multithreading(): - """Task.compute should be executed on multiple threads.""" - thread_ids = set() - - def log_thread_id(): - thread_id = threading.current_thread().ident - thread_ids.add(thread_id) - return Batch(()) - - with spawn_workers([fake_device() for _ in range(2)]) as (in_queues, out_queues): - for i in range(2): - t = Task(CPUStream, compute=log_thread_id, finalize=None) - in_queues[i].put(t) - for i in range(2): - out_queues[i].get() - - assert len(thread_ids) == 2 - - -def test_compute_success(): - """Task.compute returns (True, (task, batch)) on success.""" - - def _42(): - return Batch(torch.tensor(42)) - - with spawn_workers([torch.device("cpu")]) as (in_queues, out_queues): - t = Task(CPUStream, compute=_42, finalize=None) - in_queues[0].put(t) - ok, (task, batch) = out_queues[0].get() - - assert ok - assert task is t - assert isinstance(batch, Batch) - assert batch[0].item() == 42 - - -def test_compute_exception(): - """Task.compute returns (False, exc_info) on failure.""" - - def zero_div(): - 0 / 0 - - with spawn_workers([torch.device("cpu")]) as (in_queues, out_queues): - t = Task(CPUStream, compute=zero_div, finalize=None) - in_queues[0].put(t) - ok, exc_info = out_queues[0].get() - - assert not ok - assert isinstance(exc_info, tuple) - assert issubclass(exc_info[0], ZeroDivisionError) - - -@pytest.mark.parametrize("grad_mode", [True, False]) -def test_grad_mode(grad_mode): - def detect_grad_enabled(): - x = torch.rand(1, requires_grad=torch.is_grad_enabled()) - return Batch(x) - - with torch.set_grad_enabled(grad_mode): - with spawn_workers([torch.device("cpu")]) as (in_queues, out_queues): - task = Task(CPUStream, compute=detect_grad_enabled, finalize=None) - in_queues[0].put(task) - - ok, (_, batch) = out_queues[0].get() - - assert ok - assert batch[0].requires_grad == grad_mode - - -def test_worker_per_device(): - cpu = torch.device("cpu") - cpu0 = torch.device("cpu", index=0) - fake1 = fake_device() - fake2 = fake_device() - - with spawn_workers([cpu, cpu, cpu0, fake1, fake2]) as (in_queues, out_queues): - assert len(in_queues) == len(out_queues) == 5 - - # 0: cpu, 1: cpu, 2: cpu0 - assert in_queues[0] is in_queues[1] is in_queues[2] - assert out_queues[0] is out_queues[1] is out_queues[2] - - # 3: fake1, 4: fake2 - assert in_queues[3] is not in_queues[4] - assert out_queues[3] is not out_queues[4] - - -if __name__ == "__main__": - run_tests() diff --git a/test/test_public_bindings.py b/test/test_public_bindings.py index 1db0e5718ce6..8ab2ac1f511f 100644 --- a/test/test_public_bindings.py +++ b/test/test_public_bindings.py @@ -329,7 +329,6 @@ def test_modules_can_be_imported(self): "torch.testing._internal.distributed.fake_pg", "torch.testing._internal.distributed.multi_threaded_pg", "torch.testing._internal.distributed.nn.api.remote_module_test", - "torch.testing._internal.distributed.pipe_with_ddp_test", "torch.testing._internal.distributed.rpc.dist_autograd_test", "torch.testing._internal.distributed.rpc.dist_optimizer_test", "torch.testing._internal.distributed.rpc.examples.parameter_server_test", @@ -408,7 +407,6 @@ def test_modules_can_be_imported(self): "torch.distributed.nn.api.remote_module", "torch.distributed.optim", "torch.distributed.optim.optimizer", - "torch.distributed.pipeline.sync", "torch.distributed.rendezvous", "torch.distributed.rpc.api", "torch.distributed.rpc.backend_registry", diff --git a/test/test_testing.py b/test/test_testing.py index ba9558a3ddd1..1e1dce59a32e 100644 --- a/test/test_testing.py +++ b/test/test_testing.py @@ -2245,7 +2245,6 @@ def test_circular_dependencies(self) -> None: else: ignored_modules.append("torch.distributed.nn.api.") ignored_modules.append("torch.distributed.optim.") - ignored_modules.append("torch.distributed.pipeline.") ignored_modules.append("torch.distributed.rpc.") ignored_modules.append("torch.testing._internal.dist_utils") # And these both end up with transitive dependencies on distributed diff --git a/torch/distributed/pipeline/__init__.py b/torch/distributed/pipeline/__init__.py deleted file mode 100644 index eacd2bc99d04..000000000000 --- a/torch/distributed/pipeline/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -import warnings - - -with warnings.catch_warnings(): - warnings.simplefilter("always") - warnings.warn( - "`torch.distributed.pipeline` is deprecated. For up-to-date pipeline parallel " - "implementation, please refer to the PiPPy library under the PyTorch " - "organization (Pipeline Parallelism for PyTorch): " - "https://github.com/pytorch/PiPPy", - DeprecationWarning, - stacklevel=2, - ) diff --git a/torch/distributed/pipeline/sync/LICENSE b/torch/distributed/pipeline/sync/LICENSE deleted file mode 100644 index e52be240fdc9..000000000000 --- a/torch/distributed/pipeline/sync/LICENSE +++ /dev/null @@ -1,27 +0,0 @@ -Copyright 2019-2020 Kakao Brain - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -1. Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - -2. Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - -3. Neither the name of the copyright holder nor the names of its - contributors may be used to endorse or promote products derived from this - software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE -LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -POSSIBILITY OF SUCH DAMAGE. diff --git a/torch/distributed/pipeline/sync/__init__.py b/torch/distributed/pipeline/sync/__init__.py deleted file mode 100644 index 75a80c5db0f9..000000000000 --- a/torch/distributed/pipeline/sync/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""A Pipe implementation in PyTorch.""" -from .checkpoint import is_checkpointing, is_recomputing -from .pipe import Pipe, WithDevice -from .microbatch import NoChunk - -__all__ = ["Pipe", "is_checkpointing", "is_recomputing"] diff --git a/torch/distributed/pipeline/sync/_balance/__init__.py b/torch/distributed/pipeline/sync/_balance/__init__.py deleted file mode 100644 index 8ffc657896d8..000000000000 --- a/torch/distributed/pipeline/sync/_balance/__init__.py +++ /dev/null @@ -1,164 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""A helper to roughly balance a sequential module. - -Usage:: - - import torch - from torch.distributed.pipeline.sync import Pipe - from torch.distributed.pipeline.sync.balance import balance_by_time - - sample = torch.empty(128, 3, 224, 224) - balance = balance_by_time(torch.cuda.device_count(), model, sample) - - pipe = Pipe(model, balance, chunks=8) - -""" -from typing import Any, List, Union, Sequence - -import torch -from torch import Tensor -import torch.nn as nn - -from . import blockpartition -from .profile import profile_sizes, profile_times - -__all__ = ["balance_by_time", "balance_by_size"] - - -Device = Union[torch.device, int, str] - -Tensors = Sequence[Tensor] -TensorOrTensors = Union[Tensor, Tensors] - - -def balance_cost(cost: List[int], partitions: int) -> List[int]: - partitioned = blockpartition.solve(cost, partitions) - return [len(p) for p in partitioned] - - -def balance_by_time( - partitions: int, - module: nn.Sequential, - sample: Union[List[Any], Tensor], - *, - timeout: float = 1.0, - device: Device = torch.device("cuda"), -) -> List[int]: - """Naive automatic balancing by elapsed time per layer. - :: - - sample = torch.empty(128, 3, 224, 224) - balance = balance_by_time(torch.cuda.device_count(), model, sample) - pipe = Pipe(model, balance, chunks=8) - - Args: - partitions (int): - intended number of partitions - module (torch.nn.Sequential): - sequential module to be partitioned - sample (torch.Tensor): - example input with arbitrary batch size - - Keyword Args: - timeout (float): - profiling iterates again if the timeout (in second) is not exceeded - (default: ``1.0``) - device ('cpu' or 'cuda' device): - CPU or CUDA device where each layer is profiled (default: the - current CUDA device) - - Returns: - A list of number of layers in each partition. Use it for the `balance` - parameter of :class:`~torchpipe.Pipe`. - - .. note:: - `module` and `sample` must be placed on the same device. - - """ - times = profile_times(module, sample, timeout, torch.device(device)) - return balance_cost(times, partitions) - - -def balance_by_size( - partitions: int, - module: nn.Sequential, - input: Union[List[Any], Tensor], - *, - chunks: int = 1, - param_scale: float = 2.0, - device: Device = torch.device("cuda"), -) -> List[int]: - """Naive automatic balancing by CUDA memory usage per layer. - - During training, required memory for parameters depends on which optimizer - is used. Optimizers may use buffers for each parameter to track - optimization statistics internally, such as momentum buffer in SGD. - - To get more reliable size based balance, you should specify `param_scale` - with regard to your optimizer. The default `param_scale` is 2 instead of 1 - due to gradient accumulation which is necessary for every optimizer. - - Follow this guide to choose correct `param_scale` for typical optimizers: - - ========= ============= ========================================= - Optimizer `param_scale` Internal State - ========= ============= ========================================= - SGD 2--3 (momentum_buffer) - Adam 4--5 exp_avg, exp_avg_sq, (max_exp_avg_sq) - Adadelta 4 square_avg, acc_delta - Adagrad 3 sum - RMSprop 3--5 square_avg, (momentum_buffer), (grad_avg) - ========= ============= ========================================= - - Here's a simple example with the Adam optimizer:: - - balance = balance_by_size( - torch.cuda.device_count(), - model, - - # Same size with mini-batch to train - torch.empty(1024, 3, 224, 224), - - # Number of micro-batches to train with Pipe - chunks=8, - - # 4 for Adam - param_scale=4.0, - ) - - pipe = Pipe(model, balance, chunks=8) - adam = Adam(pipe.parameters()) - - Args: - partitions (int): - intended number of partitions - module (torch.nn.Sequential): - sequential module to be partitioned - input (torch.Tensor): - example mini-batch with the same size to train - - Keyword Args: - chunks (int): - number of micro-batches will be used to train (default: ``1``) - param_scale (float): - how many copies of parameters would be allocated for training. It - depends on optimizer. See the above guide. (default: ``2.0``) - device ('cuda' device): - CUDA device where each layer is profiled (default: the current CUDA - device) - - Returns: - A list of number of layers in each partition. Use it for the `balance` - parameter of :class:`~torchpipe.Pipe`. - - .. note:: - `module` and `input` must be placed on the same CUDA device. - - """ - sizes = profile_sizes(module, input, chunks, param_scale, torch.device(device)) - return balance_cost(sizes, partitions) diff --git a/torch/distributed/pipeline/sync/_balance/blockpartition.py b/torch/distributed/pipeline/sync/_balance/blockpartition.py deleted file mode 100644 index ccdf5fe4df99..000000000000 --- a/torch/distributed/pipeline/sync/_balance/blockpartition.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Implements "Block Partitions of Sequences" by Imre B\u00e1r\u00e1ny et al. - -Paper: https://arxiv.org/pdf/1308.2452.pdf - -""" -from typing import Iterator, List, Tuple - -__all__ = ["solve"] - - -def solve(sequence: List[int], partitions: int = 1) -> List[List[int]]: - """Splits a sequence into several partitions to minimize variance for each - partition. - - The result might not be optimal. However, it can be done only in O(kn\u00b3), - where k is the number of partitions and n is the length of the sequence. - - """ - if partitions < 1: - raise ValueError(f"partitions must be a positive integer ({partitions} < 1)") - - n = len(sequence) - if n < partitions: - raise ValueError(f"sequence is shorter than intended partitions ({n} < {partitions})") - - # Normalize the sequence in [0, 1]. - minimum = min(sequence) - maximum = max(sequence) - minimum - - normal_sequence: List[float] - if maximum == 0: - normal_sequence = [0 for _ in sequence] - else: - normal_sequence = [(x - minimum) / maximum for x in sequence] - - splits = [n // partitions * (x + 1) for x in range(partitions - 1)] + [n] - - def block_size(i: int) -> float: - start = splits[i - 1] if i > 0 else 0 - stop = splits[i] - return sum(normal_sequence[start:stop]) - - def leaderboard() -> Iterator[Tuple[float, int]]: - return ((block_size(i), i) for i in range(partitions)) - - while True: - """ - (1) Fix p element-of [k] with M(P) = bp. So Bp is a maximal block of P. - """ - # max_size: M(P) - max_size, p = max(leaderboard()) - - while True: - """ - (2) If M(P) <= m(P) + 1, then stop. - """ - # min_size: m(P) - min_size, q = min(leaderboard()) - - if max_size <= min_size + 1: - return [sequence[i:j] for i, j in zip([0] + splits[:-1], splits)] - - """ - (3) If M(P) > m(P) + 1, then let m(P) = bq for the q element-of [k] which is - closest to p (ties broken arbitrarily). Thus Bq is a minimal block - of P. Let Bh be the block next to Bq between Bp and Bq. (Note that - Bh is a non-empty block: if it were, then m(P) = 0 and we should - have chosen Bh instead of Bq.) - """ - if p < q: - """ - So either p < q and then h = q-1 and we define P * by moving - the last element from Bh = Bq-1 to Bq, - """ - h = q - 1 - splits[h] -= 1 - else: - """ - or q < p, and then h = q + 1 and P * is obtained by moving the - first element of Bh = Bq+1 to Bq. - """ - h = q + 1 - splits[q] += 1 - - """ - Set P = P * . If p = h, then go to (1), else go to (2). - """ - if p == h: - break diff --git a/torch/distributed/pipeline/sync/_balance/profile.py b/torch/distributed/pipeline/sync/_balance/profile.py deleted file mode 100644 index fa1a0c06a8e3..000000000000 --- a/torch/distributed/pipeline/sync/_balance/profile.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Per-layer profilers.""" -import copy -import time -from typing import Any, Generator, List, Union, Sequence - -import torch -from torch import Tensor -import torch.nn as nn - -from ..microbatch import Batch - -__all__: List[str] = [] - - -Device = Union[torch.device, int, str] - -Tensors = Sequence[Tensor] -TensorOrTensors = Union[Tensor, Tensors] - - -def layerwise_sandbox(module: nn.Sequential, device: torch.device,) -> Generator[nn.Module, None, None]: - """Copies layers for ease to profile. It doesn't modify the given - module. - """ - for layer in module: - layer_copy = copy.deepcopy(layer) - layer_copy.to(device) - layer_copy.train() - yield layer_copy - - -def detach(batch: Batch) -> None: - """Detaches from autograd graph.""" - for i, x in enumerate(batch): - batch[i] = x.detach().requires_grad_(x.requires_grad) - - -def profile_times(module: nn.Sequential, sample: Union[List[Any], Tensor], timeout: float, device: torch.device,) -> List[int]: - """Profiles elapsed times per layer.""" - if any(p.grad is not None for p in module.parameters()): - raise ValueError("some parameter already has gradient") - - _batch = Batch(sample) - for i, x in enumerate(_batch): - _batch[i] = x.detach().to(device).requires_grad_(x.requires_grad) - - time_bufs: List[List[float]] = [[] for _ in module] - begun_at = time.time() - - while time.time() - begun_at < timeout: - batch = _batch - - for i, layer in enumerate(layerwise_sandbox(module, device)): - detach(batch) - - if device.type == "cuda": - torch.cuda.synchronize(device) - tick = time.time() - - # Forward - batch = batch.call(layer) - - # Backward - backward_tensors = tuple(y for y in batch if y.requires_grad) - if backward_tensors: - torch.autograd.backward(backward_tensors, backward_tensors) - - if device.type == "cuda": - torch.cuda.synchronize(device) - tock = time.time() - - time_bufs[i].append(tock - tick) - - us = 1_000_000 - return [sum(int(t * us) for t in buf) for buf in time_bufs] - - -def profile_sizes( - module: nn.Sequential, input: Union[List[Any], Tensor], chunks: int, param_scale: float, device: torch.device, -) -> List[int]: - """Profiles CUDA memory usage per layer.""" - if device.type != "cuda": - raise ValueError("size profiler supports only CUDA device") - - batch = Batch(input) - sizes: List[int] = [] - - latent_scale = batch[0].size(0) / chunks - for i, x in enumerate(batch): - batch[i] = x[:1].detach().to(device).requires_grad_(x.requires_grad) - - for layer in layerwise_sandbox(module, device): - detach(batch) - - # Detect memory usage at forward. - torch._C._cuda_clearCublasWorkspaces() - memory_before = torch.cuda.memory_allocated(device) - batch = batch.call(layer) - torch._C._cuda_clearCublasWorkspaces() - memory_after = torch.cuda.memory_allocated(device) - latent_size = memory_after - memory_before - - # Analyze size of parameters. - param_size = sum(p._typed_storage()._nbytes() for p in layer.parameters()) - - # Combine size of parameters and activations with normalize scales. - size = latent_size * latent_scale + param_size * param_scale - sizes.append(int(size)) - - return sizes diff --git a/torch/distributed/pipeline/sync/_balance/py.typed b/torch/distributed/pipeline/sync/_balance/py.typed deleted file mode 100644 index ab03724cafbf..000000000000 --- a/torch/distributed/pipeline/sync/_balance/py.typed +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. diff --git a/torch/distributed/pipeline/sync/batchnorm.py b/torch/distributed/pipeline/sync/batchnorm.py deleted file mode 100644 index 868ad50cf3fc..000000000000 --- a/torch/distributed/pipeline/sync/batchnorm.py +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Tracks the running statistics per mini-batch instead of micro-batch.""" -from typing import TypeVar, Optional, cast - -import torch -from torch import Tensor, nn -from torch.nn.functional import batch_norm -from torch.nn.modules.batchnorm import _BatchNorm - -from .checkpoint import is_recomputing - -__all__ = ["DeferredBatchNorm"] - - -TModule = TypeVar("TModule", bound=nn.Module) - - -class DeferredBatchNorm(_BatchNorm): - """A BatchNorm layer tracks multiple micro-batches to update running statistics per mini-batch.""" - - sum: Tensor - sum_squares: Tensor - running_mean: Tensor - running_var: Tensor - num_batches_tracked: Tensor - - def __init__( - self, - num_features: int, - eps: float = 1e-5, - momentum: Optional[float] = 0.1, - affine: bool = True, - chunks: int = 1, - ) -> None: - super().__init__(num_features, eps, momentum, affine, track_running_stats=True) - - self.register_buffer("sum", torch.zeros_like(self.running_mean)) - self.register_buffer("sum_squares", torch.zeros_like(self.running_var)) - - self.counter = 0 - self.tracked = 0 - self.chunks = chunks - - def _check_input_dim(self, input: Tensor) -> None: - # It's the typical _check_input_dim() implementation in PyTorch. - if input.dim() <= 2: - raise ValueError("expected at least 3D input (got %dD input)" % input.dim()) - - def _track(self, input: Tensor) -> bool: - """Tracks statistics of a micro-batch.""" - # Dimensions except channel. For example, (0, 2, 3) is for BatchNorm2d. - dim = [0] - dim.extend(range(2, input.dim())) - - with torch.no_grad(): - self.sum += input.sum(dim) - self.sum_squares += (input ** 2).sum(dim) - - size = input.size().numel() // input.size(1) - self.counter += size - self.tracked += 1 - - return self.tracked == self.chunks - - def _commit(self) -> None: - """Update the running statistics of a mini-batch.""" - exponential_average_factor = 0.0 - self.num_batches_tracked += 1 - if self.momentum is None: # use cumulative moving average - exponential_average_factor = 1.0 / float(self.num_batches_tracked) - else: # use exponential moving average - exponential_average_factor = self.momentum - - mean = self.sum / self.counter - var = self.sum_squares / self.counter - mean ** 2 - - # Calculate the exponential moving average here. - m = exponential_average_factor - - self.running_mean *= 1 - m - self.running_mean += mean * m - - self.running_var *= 1 - m - self.running_var += var * m - - self.sum.zero_() - self.sum_squares.zero_() - self.counter = 0 - self.tracked = 0 - - def forward(self, input: Tensor) -> Tensor: - if not self.training: - # Don't train parameters on the evaluation mode. - return batch_norm( - input, - running_mean=self.running_mean, - running_var=self.running_var, - weight=self.weight, - bias=self.bias, - training=False, - momentum=0.0, - eps=self.eps, - ) - - if not is_recomputing(): - # Track a micro-batch on the training mode - # but not under a recomputation. - tracked_enough = self._track(input) - - # Update the running statistics for a mini-batch - # if it has tracked enough micro-batches. - if tracked_enough: - self._commit() - - # Normalize a micro-batch and train the parameters. - return batch_norm( - input, - running_mean=None, - running_var=None, - weight=self.weight, - bias=self.bias, - training=True, - momentum=0.0, - eps=self.eps, - ) - - @classmethod - def convert_deferred_batch_norm(cls, module: TModule, chunks: int = 1) -> TModule: - """Converts a :class:`nn.BatchNorm` or underlying :class:`nn.BatchNorm`s into :class:`DeferredBatchNorm`:: - - from torchvision.models.resnet import resnet101 - from torchpipe.batchnorm import DeferredBatchNorm - model = resnet101() - model = DeferredBatchNorm.convert_deferred_batch_norm(model) - - """ - if isinstance(module, DeferredBatchNorm) and module.chunks is chunks: - return cast(TModule, module) - - module_output: nn.Module = module - - if isinstance(module, _BatchNorm) and module.track_running_stats: - module_output = DeferredBatchNorm(module.num_features, module.eps, module.momentum, module.affine, chunks) - if module.affine: - module_output.register_parameter("weight", module.weight) - module_output.register_parameter("bias", module.bias) - module_output.register_buffer("running_mean", module.running_mean) - module_output.register_buffer("running_var", module.running_var) - module_output.register_buffer("num_batches_tracked", module.num_batches_tracked) - - for name, child in module.named_children(): - module_output.add_module(name, cls.convert_deferred_batch_norm(child, chunks)) - - return cast(TModule, module_output) diff --git a/torch/distributed/pipeline/sync/checkpoint.py b/torch/distributed/pipeline/sync/checkpoint.py deleted file mode 100644 index e67da2499d57..000000000000 --- a/torch/distributed/pipeline/sync/checkpoint.py +++ /dev/null @@ -1,364 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Checkpointing with preceding recomputation. - -PyTorch already provides the official checkpointing utilities in -:mod:`torch.utils.checkpoint`. The official checkpointing combines -recomputation and recursive backpropagation into one autograd function named -``CheckpointFunction``. Hence, the recomputation can be started only when the -gradients arrive to the function. In Pipe, the recomputation needs to precede -the gradient arrival to minimize the GPU idle time. - -We solve this problem by introducing separate autograd functions named -:class:`Recompute` and :class:`Checkpoint`. Each function represents -recomputation and recursive backpropagation, respectively. We can manipulate -the control flow in aspect of both the autograd engine and CUDA with a pair of -the functions. - -Specifically, we place CUDA stream synchronization between :class:`Recompute` -and :class:`Checkpoint` to delay only :class:`Checkpoint` until the gradient is -copied entirely. - -""" -from collections import deque -from contextlib import contextmanager -import threading -from typing import ( - Any, - Deque, - Generator, - List, - Optional, - Protocol, - Union, - Sequence, - Tuple -) - -import torch -from torch import Tensor -import torch.autograd - -from .dependency import fork, join -from .microbatch import Batch -from .phony import get_phony - -__all__ = ["Function", "checkpoint", "Checkpointing", "ThreadLocal", "enable_checkpointing", - "enable_recomputing", "is_checkpointing", "is_recomputing", "Context", "save_rng_states", - "restore_rng_states", "Checkpoint", "Recompute"] - - -Tensors = Sequence[Tensor] -TensorOrTensors = Union[Tensor, Tensors] - -# Types for shared memory between Checkpoint and Recompute. -Recomputed = Tuple[TensorOrTensors, Tensors] # (output, input_leaf) -RNGStates = Tuple[Tensor, Optional[Tensor]] # (cpu_rng_state, gpu_rng_state) - - -# Protocol with __call__ instead of Callable can be used as an attribute type. -# See: https://github.com/python/mypy/issues/708#issuecomment-561735949 -class Function(Protocol): - def __call__(self, input: TensorOrTensors) -> TensorOrTensors: - ... - - -def checkpoint(function: Function, input): - """Make a checkpoint with a simple interface like - :func:`torch.utils.checkpoint.checkpoint`. It's only used to test or debug - :class:`Checkpoint` and :class:`Recompute` without boilerplate. - """ - batch = Batch(input) - - chk = Checkpointing(function, batch) - batch = chk.checkpoint() - chk.recompute(batch) - - return batch.values - - -class Checkpointing: - """Generates a pair of :class:`Checkpoint` and :class:`Recompute`.""" - - def __init__(self, function: Function, batch: Batch) -> None: - self.function = function - self.batch = batch - - # Shared memory between Checkpoint and Recompute. 1-length deque is - # used for mutability and length limitation. - self.recomputed: Deque[Recomputed] = deque(maxlen=1) - self.rng_states: Deque[RNGStates] = deque(maxlen=1) - - def checkpoint(self) -> Batch: - """Return a batch applied by :class:`Checkpoint`.""" - input_atomic = self.batch.atomic - inputs = tuple(self.batch) - - # Use a phony which requires grad to ensure that Checkpoint can be - # tracked by the autograd engine even when none of the input tensors - # require grad. - phony = get_phony(self.batch.get_device(), requires_grad=True) - - output = Checkpoint.apply(phony, self.recomputed, self.rng_states, self.function, input_atomic, *inputs) - - # Gradients are only supported for float Tensors. - if isinstance(output, tuple): - output = tuple([x.detach() if torch.is_tensor(x) and not x.is_floating_point() else x for x in output]) - - return Batch(output) - - def recompute(self, batch: Batch) -> None: - """Apply :class:`Recompute` to the batch in place.""" - input_atomic = self.batch.atomic - inputs = tuple(self.batch) - - # Use a tensor in the batch to tie together fork-join - tensor_idx = batch.find_tensor_idx() - # batch[tensor_idx] is always requiring grad, because it has been passed - # checkpoint with a phony requiring grad. - batch[tensor_idx], phony = fork(batch[tensor_idx]) - phony = Recompute.apply(phony, self.recomputed, self.rng_states, self.function, input_atomic, *inputs) - batch[tensor_idx] = join(batch[tensor_idx], phony) - - -class ThreadLocal(threading.local): - def __init__(self) -> None: - self.is_checkpointing = False - self.is_recomputing = False - - -thread_local = ThreadLocal() - - -@contextmanager -def enable_checkpointing() -> Generator[None, None, None]: - """Make :func:`is_checkpointing` return :data:`True` within a context.""" - orig = thread_local.is_checkpointing - thread_local.is_checkpointing = True - try: - yield - finally: - thread_local.is_checkpointing = orig - - -@contextmanager -def enable_recomputing() -> Generator[None, None, None]: - """Makes :func:`is_recomputing` return :data:`True` within a context.""" - orig = thread_local.is_recomputing - thread_local.is_recomputing = True - try: - yield - finally: - thread_local.is_recomputing = orig - - -def is_checkpointing() -> bool: - """Whether the current forward propagation is under checkpointing. - - Returns: - bool: :data:`True` if it's under checkpointing. - - """ - return thread_local.is_checkpointing - - -def is_recomputing() -> bool: - """Whether the current forward propagation is under checkpoint recomputation. - - Use this to prevent duplicated side-effects at forward - propagation:: - - class Counter(nn.Module): - def __init__(self): - super().__init__() - self.counter = 0 - - def forward(self, input): - if not is_recomputing(): - self.counter += 1 - return input - - Returns: - bool: :data:`True` if it's under checkpoint recomputation. - - .. seealso:: :ref:`Detecting Recomputation` - - """ - return thread_local.is_recomputing - - -class Context: - """The common interface between the :class:`Checkpoint` and :class:`Recompute` context.""" - - recomputed: Deque[Recomputed] - rng_states: Deque[RNGStates] - function: Function - input_atomic: bool - inputs: Sequence[Any] - - saved_tensors: Tuple[Tensor, ...] - - def save_for_backward(self, *tensors: Tensor) -> None: # pragma: no cover - pass - - -def save_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> None: - """: - Capture the current random number generator states. - - meth:`Checkpoint.forward` captures the current PyTorch's random number - generator states at CPU and GPU to reuse in :meth:`Recompute.backward`. - - .. seealso:: :ref:`Referential Transparency` - - """ - cpu_rng_state = torch.get_rng_state() - - gpu_rng_state: Optional[Tensor] - if device.type == "cuda": - gpu_rng_state = torch.cuda.get_rng_state(device) - else: - gpu_rng_state = None - - rng_states.append((cpu_rng_state, gpu_rng_state)) - - -@contextmanager -def restore_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> Generator[None, None, None]: - """: - Restore the random number generator state. - - meth:`Recompute.backward` restores the random number generator states - captured by :func:`save_rng_states` within its context. - - .. seealso:: :ref:`Referential Transparency` - - """ - cpu_rng_state, gpu_rng_state = rng_states.pop() - - gpu_devices: List[torch.device] = [] - if device.type == "cuda": - gpu_devices.append(device) - - with torch.random.fork_rng(gpu_devices): - torch.set_rng_state(cpu_rng_state) - if gpu_rng_state is not None: - torch.cuda.set_rng_state(gpu_rng_state, device) - yield - - -class Checkpoint(torch.autograd.Function): - @staticmethod - # type: ignore[override] - def forward( - ctx: Context, - phony: Tensor, - recomputed: Deque[Recomputed], - rng_states: Deque[RNGStates], - function: Function, - input_atomic: bool, - *inputs, - ): - ctx.recomputed = recomputed - ctx.rng_states = rng_states - - save_rng_states(phony.device, ctx.rng_states) - - ctx.function = function - ctx.input_atomic = input_atomic - if input_atomic: - tensors = [inputs[0]] - else: - tensors = [] - for input in inputs: - if torch.is_tensor(input): - tensors.append(input) - - ctx.save_for_backward(*tensors) - - with torch.no_grad(), enable_checkpointing(): - if input_atomic: - assert len(inputs) == 1 - output = function(inputs[0]) - else: - output = function(*inputs) - return output - - @staticmethod - def backward(ctx: Context, *grad_output: Tensor,) -> Tuple[Optional[Tensor], ...]: # pragma: no cover - output, input_leaf = ctx.recomputed.pop() - - if isinstance(output, tuple): - outputs = output - else: - outputs = (output,) - if any(torch.is_tensor(y) and y.requires_grad for y in outputs): - tensors = tuple([x for x in outputs if torch.is_tensor(x) and x.requires_grad]) - torch.autograd.backward(tensors, grad_output) - - grad_input: List[Optional[Tensor]] = [None, None, None, None, None] - grad_input.extend(x.grad if torch.is_tensor(x) else None for x in input_leaf) - return tuple(grad_input) - - -class Recompute(torch.autograd.Function): - @staticmethod - # type: ignore[override] - def forward( - ctx: Context, - phony: Tensor, - recomputed: Deque[Recomputed], - rng_states: Deque[RNGStates], - function: Function, - input_atomic: bool, - *inputs, - ) -> Tensor: - ctx.recomputed = recomputed - ctx.rng_states = rng_states - - ctx.function = function - ctx.input_atomic = input_atomic - ctx.inputs = inputs - if input_atomic: - tensors = [inputs[0]] - else: - tensors = [] - for input in inputs: - if torch.is_tensor(input): - tensors.append(input) - ctx.save_for_backward(*tensors) - - return phony - - @staticmethod - def backward(ctx: Context, *grad_output: Tensor) -> Tuple[None, ...]: # pragma: no cover - inputs = ctx.inputs - inputs_leaf = tuple(x.detach().requires_grad_(x.requires_grad) if torch.is_tensor(x) else x for x in inputs) - - # Get the device for the inputs from a tensor - device = None - for input in inputs: - if torch.is_tensor(input): - device = input.device - break - - if device is None: - raise RuntimeError(f'No tensors found in {inputs}') - - with restore_rng_states(device, ctx.rng_states): - with torch.enable_grad(), enable_recomputing(): - if ctx.input_atomic: - assert len(inputs_leaf) == 1 - output = ctx.function(inputs_leaf[0]) - else: - output = ctx.function(*inputs_leaf) - - ctx.recomputed.append((output, inputs_leaf)) - - grad_input: List[None] = [None, None, None, None, None] - grad_input.extend(None for _ in ctx.inputs) - return tuple(grad_input) diff --git a/torch/distributed/pipeline/sync/copy.py b/torch/distributed/pipeline/sync/copy.py deleted file mode 100644 index b717f0c2932c..000000000000 --- a/torch/distributed/pipeline/sync/copy.py +++ /dev/null @@ -1,108 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Autograd functions for stream-aware CUDA copy. - -It is used to overlap copy and computation on the same GPU. -""" -from collections import deque -from typing import Deque, List, Optional, Tuple, Sequence - -import torch -from torch import Tensor - -from .stream import AbstractStream, current_stream, get_device, record_stream, use_stream, wait_stream - -__all__: List[str] = ["Context", "Copy", "Wait"] - - -Tensors = Sequence[Tensor] - - -# Common interface between :class:`Copy` and :class:`Wait`. -class Context: - prev_stream: AbstractStream - next_stream: AbstractStream - - -class Copy(torch.autograd.Function): - """Copies tensors on specific streams.""" - - @staticmethod - # type: ignore[override] - def forward(ctx: Context, prev_stream: AbstractStream, next_stream: AbstractStream, *input,) -> Tensors: - ctx.prev_stream = prev_stream - ctx.next_stream = next_stream - - output = [] - output_stream = current_stream(get_device(next_stream)) - - with use_stream(prev_stream), use_stream(next_stream): - for x in input: - if torch.is_tensor(x): - y = x.to(get_device(next_stream), non_blocking=True) - output.append(y) - - # 'prev_stream' is not where 'x' has been allocated. - record_stream(x, prev_stream) - # 'y' has been allocated on 'next_stream'. - # It might be used on the current stream captured as 'output_stream'. - record_stream(y, output_stream) - else: - output.append(x) - - return tuple(output) - - @staticmethod - def backward(ctx: Context, *grad_output: Tensor,) -> Tuple[Optional[Tensor], ...]: - prev_stream = ctx.prev_stream - next_stream = ctx.next_stream - - grad_input: Deque[Tensor] = deque(maxlen=len(grad_output)) - input_stream = current_stream(get_device(prev_stream)) - - with use_stream(prev_stream), use_stream(next_stream): - for x in reversed(grad_output): - y = x.to(get_device(prev_stream), non_blocking=True) - grad_input.appendleft(y) - - # 'next_stream' is not where 'x' has been allocated. - record_stream(x, next_stream) - # 'y' has been allocated on 'prev_stream'. - # It might be used on the current stream captured as 'input_stream'. - record_stream(y, input_stream) - - grad_streams: Tuple[Optional[Tensor], ...] = (None, None) - return grad_streams + tuple(grad_input) - - -class Wait(torch.autograd.Function): - """Synchronizes a stream to another stream. - - Place it just before you want to start an operation on the next stream, - provided that all operations on the previous stream are done. - - """ - - @staticmethod - # type: ignore[override] - def forward(ctx: Context, prev_stream: AbstractStream, next_stream: AbstractStream, *input) -> Tensors: - ctx.prev_stream = prev_stream - ctx.next_stream = next_stream - - wait_stream(next_stream, prev_stream) - - return tuple(x.detach() if torch.is_tensor(x) else x for x in input) - - @staticmethod - def backward(ctx: Context, *grad_input: Tensor,) -> Tuple[Optional[Tensor], ...]: - prev_stream = ctx.prev_stream - next_stream = ctx.next_stream - - wait_stream(prev_stream, next_stream) - - grad_streams: Tuple[Optional[Tensor], ...] = (None, None) - return grad_streams + grad_input diff --git a/torch/distributed/pipeline/sync/dependency.py b/torch/distributed/pipeline/sync/dependency.py deleted file mode 100644 index ca5c69e388fe..000000000000 --- a/torch/distributed/pipeline/sync/dependency.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Arbitrary dependency between two autograd lanes.""" -from typing import List, Tuple - -import torch -from torch import Tensor - -from .phony import get_phony - -__all__: List[str] = ["fork", "Fork", "join", "Join"] - - -def fork(input: Tensor) -> Tuple[Tensor, Tensor]: - """Branches out from an autograd lane of the given tensor.""" - if torch.is_grad_enabled() and input.requires_grad: - input, phony = Fork.apply(input) - else: - phony = get_phony(input.device, requires_grad=False) - - return input, phony - - -class Fork(torch.autograd.Function): - @staticmethod - def forward(ctx: "Fork", input: Tensor) -> Tuple[Tensor, Tensor]: # type: ignore[override] - phony = get_phony(input.device, requires_grad=False) - return input.detach(), phony.detach() - - @staticmethod - def backward(ctx: "Fork", grad_input: Tensor, grad_grad: Tensor) -> Tensor: # type: ignore[override] - return grad_input - - -def join(input: Tensor, phony: Tensor) -> Tensor: - """Merge two autograd lanes.""" - if torch.is_grad_enabled() and (input.requires_grad or phony.requires_grad): - input = Join.apply(input, phony) - - return input - - -class Join(torch.autograd.Function): - @staticmethod - def forward(ctx: "Join", input: Tensor, phony: Tensor) -> Tensor: # type: ignore[override] - return input.detach() - - @staticmethod - def backward(ctx: "Join", grad_input: Tensor) -> Tuple[Tensor, None]: # type: ignore[override] - return grad_input, None diff --git a/torch/distributed/pipeline/sync/microbatch.py b/torch/distributed/pipeline/sync/microbatch.py deleted file mode 100644 index 5b8aca257548..000000000000 --- a/torch/distributed/pipeline/sync/microbatch.py +++ /dev/null @@ -1,234 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Manipulation of micro-batches.""" -import typing -from typing import Any, Callable, List, Union, cast, Sequence - -import torch -from torch import Tensor -import torch.cuda.comm - -__all__: List[str] = ["NoChunk", "Batch", "check", "scatter", "gather"] - - -Tensors = Sequence[Tensor] -TensorOrTensors = Union[Tensor, Tensors] -Function = Callable[[TensorOrTensors], Union[List[Any], Tensor]] - - -class NoChunk: - """ - Wrapper for a Tensor in :meth:`Pipe.forward` indicating that the tensor - should not be chunked on the batch dimension and instead be replicated - as-is across all micro-batches. This is useful for tensors which might - not have any 'batch' semantics for the model. - """ - def __init__(self, inp: Tensor): - if not torch.is_tensor(inp): - raise TypeError(f'NoChunk only supported for tensors, found: {inp}') - self._tensor = inp - - @property - def tensor(self): - return self._tensor - - -class Batch: - """ - An abstraction representing a microbatch in the pipeline. - """ - - def __init__(self, values: Union[List[Any], Tensor]) -> None: - self._values = values - self.atomic = torch.is_tensor(values) - - # Verify at least on tensor - if not self.atomic: - if not any(torch.is_tensor(value) for value in self._values): - raise TypeError(f'No tensors found in batch: {self._values}') - - @property - def tensor(self) -> Tensor: - """Retrieves the underlying tensor.""" - if not self.atomic: - raise AttributeError("not atomic batch") - return cast(Tensor, self._values) - - @property - def values(self): - """Retrieves the underlying values for the batch""" - return self._values - - def find_tensor_idx(self): - """ - Retrieves the index of first tensor found. - """ - if self.atomic: - return 0 - for i, value in enumerate(self._values): - if torch.is_tensor(value): - return i - - raise TypeError("No tensor found!") - - def get_device(self): - """ - Retrieves the device for this microbatch. - """ - if self.atomic: - return self._values.device # type: ignore[union-attr] - - for value in self._values: - if torch.is_tensor(value): - return value.device - - def call(self, function: Function) -> "Batch": - """Calls a function on the microbatch. It also wraps - the output with :class:`Batch`. - """ - if self.atomic: - return Batch(function(self._values)) - else: - return Batch(function(*self._values)) - - def __repr__(self) -> str: - return f"Batch[atomic={self.atomic!r}]({self._values!r})" - - def __iter__(self): - if self.atomic: - yield self._values - else: - yield from self._values - - def __len__(self) -> int: - return 1 if self.atomic else len(self._values) - - def __getitem__(self, index: int): - if not self.atomic: - return self._values[index] - - if index != 0: - raise IndexError("atomic batch allows index 0 only") - - return self._values - - # NOTE(sublee): pyflakes can't detect "overload" instead of "typing.overload". - @typing.overload - def __setitem__(self, index: int, value: Tensor) -> None: - ... - - @typing.overload - def __setitem__(self, index: slice, value: Tensors) -> None: - ... - - def __setitem__(self, index: Union[int, slice], value) -> None: - if isinstance(index, int): - self._setitem_by_index(index, value) - else: - self._setitem_by_slice(index, value) - - def _setitem_by_index(self, index: int, value) -> None: - if not self.atomic: - i = index - self._values = self._values[:i] + (value,) + self._values[i + 1 :] # type: ignore[operator] - return - - if index != 0: - raise IndexError("atomic batch allows index 0 only") - - self._values = value - - def _setitem_by_slice(self, index: slice, value) -> None: - if not (index.start is index.stop is index.step is None): # noqa: E714 - raise NotImplementedError("only slice [:] supported") - - if not self.atomic: - self._values = value - return - - if len(value) != 1: - raise IndexError("atomic batch cannot be replaced with multiple tensors") - - self._values = value[0] - - -def check(first_device, *inputs) -> None: - """ - Checks whether the input contains at least one tensor and each tensor is - on the same device as the first partition. - - Raises: - ValueError: input does not contain at least one tensor - - """ - - if not any(torch.is_tensor(input) for input in inputs): - raise TypeError(f'inputs do not have any tensors: {inputs}') - if any(torch.is_tensor(input) and input.device != first_device for input in inputs): - raise ValueError('All inputs should be on the same device as the first partition') - - -def scatter(*inputs, chunks: int) -> List[Batch]: - """Splits an input mini-batch into multiple micro-batches.""" - if len(inputs) == 1 and isinstance(inputs[0], Tensor): - return [Batch(x) for x in inputs[0].chunk(chunks)] - - batches: List[Any] = [[] for _ in range(chunks)] - # Actual number of chunks produced - num_chunks = -1 - for input in inputs: - if torch.is_tensor(input): - # Chunk only tensors. - tensors = input.chunk(chunks) - - # Validate number of chunks equal across all inputs. - if num_chunks != -1 and num_chunks != len(tensors): - raise RuntimeError(f'Found different number of chunks produced for inputs: {num_chunks} and {len(tensors)}') - num_chunks = len(tensors) - - for i, tensor in enumerate(tensors): - batches[i].append(tensor) - else: - # Replicate non-tensors or tensors wrapped with 'NoChunk'. - for i in range(chunks): - if isinstance(input, NoChunk): - # Extract the tensor out. - batches[i].append(input.tensor) - else: - batches[i].append(input) - - # Truncate to actual number of chunks - batches = batches[:num_chunks] - - return [Batch(x) for x in batches] - - -def gather(outputs: List[Batch]): - """Concatenates output micro-batches into a mini-batch.""" - output: Any - - if outputs[0].atomic: - tensors = tuple(b.tensor for b in outputs) - output = torch.cat(tensors) - else: - output_buf: List[Any] = [] - for i in range(len(outputs[0])): - output_type = type(outputs[0][i]) - current_outputs = [] - for batch in outputs: - if output_type != type(batch[i]): - raise TypeError(f'Types for microbatch outputs do not match, found: {output_type} and {type(batch[i])}') - current_outputs.append(batch[i]) - - if torch.is_tensor(outputs[0][i]): - output_buf.append(torch.cat(current_outputs)) - else: - output_buf.append(current_outputs) - - output = tuple(output_buf) - - return output diff --git a/torch/distributed/pipeline/sync/phony.py b/torch/distributed/pipeline/sync/phony.py deleted file mode 100644 index 012926699cfb..000000000000 --- a/torch/distributed/pipeline/sync/phony.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Provides phony for arbitrary dependency in a autograd graph.""" -from typing import Dict, List, Tuple - -import torch -from torch import Tensor - -from .stream import default_stream, use_stream - -__all__: List[str] = ["get_phony"] - - -_phonies: Dict[Tuple[torch.device, bool], Tensor] = {} - - -def get_phony(device: torch.device, *, requires_grad: bool) -> Tensor: - """Get a phony. Phony is tensor without space. - - It is useful to make arbitrary dependency in a autograd graph because it doesn't require any - gradient accumulation. - - .. note:: - - Phonies for each device are cached. If an autograd function gets a phony - internally, the phony must be detached to be returned. Otherwise, the - autograd engine will mutate the cached phony in-place:: - - class Phonify(torch.autograd.Function): - @staticmethod - def forward(ctx, input): - phony = get_phony(input.device, requires_grad=False) - return phony.detach() # detach() is necessary. - - """ - key = (device, requires_grad) - - try: - phony = _phonies[key] - except KeyError: - with use_stream(default_stream(device)): - phony = torch.empty(0, device=device, requires_grad=requires_grad) - - _phonies[key] = phony - - return phony diff --git a/torch/distributed/pipeline/sync/pipe.py b/torch/distributed/pipeline/sync/pipe.py deleted file mode 100644 index 5e61341d9ad9..000000000000 --- a/torch/distributed/pipeline/sync/pipe.py +++ /dev/null @@ -1,490 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""The Pipe interface.""" -from collections import OrderedDict -from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Union, Sequence, Tuple, cast - -import torch -from torch import Tensor, nn -from torch.distributed.rpc import RRef -import torch.autograd -import torch.cuda - -from . import microbatch -from .batchnorm import DeferredBatchNorm -from .pipeline import Pipeline -from .skip.layout import inspect_skip_layout -from .skip.skippable import verify_skippables -from .stream import AbstractStream, new_stream - -__all__ = ["Pipe", "BalanceError", "PipeSequential", "WithDevice"] - - -Device = Union[torch.device, int, str] -Devices = Union[Iterable[Device], List[Device]] - -Tensors = Sequence[Tensor] -TensorOrTensors = Union[Tensor, Tensors] - -if TYPE_CHECKING: - # Typechecking: nn.Module is not a Generic - Module = nn.Module[TensorOrTensors] # type: ignore[type-arg] - NamedModules = OrderedDict[str, Module] -else: - Module = nn.Module - NamedModules = OrderedDict - - -def _recommend_auto_balance(message: str) -> str: - """Expands a message with recommendation to :mod:`torchpipe.balance`.""" - return f"""{message} - -If your model is still under development, its optimal balance would change -frequently. In this case, we highly recommend 'torch.distributed.pipeline.sync.balance' for -naive automatic balancing: - - from torch.distributed.pipeline.sync import Pipe - from torch.distributed.pipeline.sync.balance import balance_by_time - - partitions = torch.cuda.device_count() - sample = torch.empty(...) - balance = balance_by_time(partitions, model, sample) - - model = Pipe(model, balance, ...) -""" - - -def _verify_module(module: nn.Sequential) -> None: - if not isinstance(module, nn.Sequential): - raise TypeError("module must be nn.Sequential to be partitioned") - - named_children = list(module.named_children()) - if len(named_children) != len(module): - raise ValueError("module with duplicate children is not supported") - - -def _verify_splitting( - module: nn.Sequential, partitions: List[nn.Sequential], devices: List[torch.device] -) -> None: - num_parameters = len(list(module.parameters())) - num_child_parameters = sum(len(list(child.parameters())) for child in module.children()) - if num_parameters == num_child_parameters: - return - - for i in range(len(partitions)): - for j in range(i + 1, len(partitions)): - parti = partitions[i] - partj = partitions[j] - if devices[i] == devices[j]: - continue - for p in parti.parameters(): - for q in partj.parameters(): - if p is q: - raise ValueError("module with duplicate parameters on distinct devices is not supported") - - -class BalanceError(ValueError): - pass - - -def _retrieve_device(module: nn.Module) -> torch.device: - """Validates all parameters in the Module have the same device and returns - the appropriate device. - - Args: - An ``nn.Module`` to process. - - Returns: - ``torch.Device`` for the entire module. - - Raises: - ValueError: - If devices for ``nn.Module`` parameters are not all same. - """ - - device = None - for parameter in module.parameters(): - if device is None: - device = parameter.device - elif device != parameter.device: - raise ValueError( - f'nn.Module: {module}, should have all parameters on a single device,' - ' please use .to() to place the module on a single device') - - return device if device is not None else torch.device("cpu") - - -class PipeSequential(nn.Sequential): - """ - Pipe variant of ``nn.Sequential`` which supports multiple inputs. - """ - - def forward(self, *inputs): - for module in self: - if isinstance(inputs, Tuple): # type: ignore[arg-type] - inputs = module(*inputs) - else: - # Don't expand single variables (ex: lists/Tensor) - inputs = module(inputs) - return inputs - - -class WithDevice(nn.Module): - """ - Wraps an ``nn.Module`` which is part of ``nn.Sequential`` passed into :class:`Pipe` - that overrides the device for that module. In cases where :class:`Pipe` - can't implicitly determine the device for the module and places it on CPU, - this wrapper can be used to override the implicit behavior and explicitly - specify which device a module should run on. - - The provided module is also moved to the given device via ``.to(device)`` - by :class:`Pipe` - - Args: - module(:class:`torch.nn.Module`): The module to be wrapped. - device(:class:`torch.device`): The device to run the module on. - - Example:: - >>> # xdoctest: +SKIP("distributed") - >>> fc1 = nn.Linear(16, 8).cuda(0) - >>> fc2 = nn.Linear(8, 4).cuda(1) - >>> dropout = nn.Dropout() - >>> - >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1) - >>> # Dropout does not have any parameters/buffers, but we want to - >>> # run it on cuda:1 to avoid any GPU to CPU transfers. - >>> model = nn.Sequential(fc1, fc2, WithDevice(dropout, 'cuda:1')) - >>> # xdoctest: +SKIP("Needs RPC framework init") - >>> model = Pipe(model, chunks=8) - """ - def __init__(self, module: nn.Module, device: torch.device): - super().__init__() - self._module = module - self._device = torch.device(device) - - def forward(self, *args, **kwargs): - return self._module(*args, **kwargs) - - @property - def module(self): - return self._module - - @property - def device(self): - return self._device - - -def _assemble_partition(modules: List[nn.Module]): - modules_list: List[nn.Module] = [] - for module in modules: - if isinstance(module, nn.Sequential): - modules_list.extend(module.children()) - else: - modules_list.append(module) - return PipeSequential(*modules_list) - - -def _split_module(modules: nn.Sequential) -> Tuple[List[nn.Sequential], List[torch.device]]: - partitions = [] - devices = [] - - current_partition = [] - current_device = None - for name, module in modules.named_children(): - if isinstance(module, WithDevice): - # Process device override and move module to appropriate device. - device = module.device - module = module.module - module.to(device) - else: - device = _retrieve_device(module) - if current_device is not None and (current_device != device or device.type == 'cpu'): - partitions.append(_assemble_partition(current_partition)) - devices.append(current_device) - current_partition = [] - current_device = device - current_partition.append(module) - - if current_device is not None: - partitions.append(_assemble_partition(current_partition)) - devices.append(current_device) - - partitions = cast(List[nn.Sequential], nn.ModuleList(partitions)) - - return partitions, devices - - -MOVING_DENIED = TypeError("denied to move parameters and buffers, because Pipe should manage device placement") - - -class Pipe(Module): - """Wraps an arbitrary :class:`nn.Sequential ` module - to train on using synchronous pipeline parallelism. If the module requires - lots of memory and doesn't fit on a single GPU, pipeline parallelism is a - useful technique to employ for training. - - The implementation is based on the torchgpipe_ paper. - - .. _torchgpipe: https://arxiv.org/abs/2004.09910 - - Pipe combines pipeline parallelism with checkpointing to reduce peak - memory required to train while minimizing device under-utilization. - - You should place all the modules on the appropriate devices and wrap them - into an :class:`nn.Sequential ` module defining the - desired order of execution. If a module does not contain any - parameters/buffers, it is assumed this module should be executed on CPU - and appropriate input tensors to the module are moved to CPU before - execution. This behavior can be overridden by the :class:`WithDevice` - wrapper which can be used to explicitly specify which device a module - should run on. - - Args: - module (:class:`nn.Sequential `): - sequential module to be parallelized using pipelining. Each module - in the sequence has to have all of its parameters on a single - device. Each module in the sequence has to either be an nn.Module - or :class:`nn.Sequential ` (to combine multiple - sequential modules on a single device) - chunks (int): - number of micro-batches (default: ``1``) - checkpoint (str): - when to enable checkpointing, one of ``'always'``, - ``'except_last'``, or ``'never'`` (default: ``'except_last'``). - ``'never'`` disables checkpointing completely, ``'except_last'`` - enables checkpointing for all micro-batches except the last one - and ``'always'`` enables checkpointing for all micro-batches. - deferred_batch_norm (bool): - whether to use deferred ``BatchNorm`` moving statistics (default: - :data:`False`). If set to :data:`True`, we track statistics across - multiple micro-batches to update the running statistics per - mini-batch. - - Raises: - TypeError: - the module is not a :class:`nn.Sequential `. - ValueError: - invalid arguments - - Example:: - Pipeline of two FC layers across GPUs 0 and 1. - - >>> # Need to initialize RPC framework first. - >>> # xdoctest: +SKIP - >>> os.environ['MASTER_ADDR'] = 'localhost' - >>> os.environ['MASTER_PORT'] = '29500' - >>> torch.distributed.rpc.init_rpc('worker', rank=0, world_size=1) - >>> - >>> # Build pipe. - >>> fc1 = nn.Linear(16, 8).cuda(0) - >>> fc2 = nn.Linear(8, 4).cuda(1) - >>> model = nn.Sequential(fc1, fc2) - >>> model = Pipe(model, chunks=8) - >>> input = torch.rand(16, 16).cuda(0) - >>> output_rref = model(input) - - .. note:: - You can wrap a :class:`Pipe` model with - :class:`torch.nn.parallel.DistributedDataParallel` only when the - checkpoint parameter of :class:`Pipe` is ``'never'``. - - .. note:: - :class:`Pipe` only supports intra-node pipelining currently, but - will be expanded to support inter-node pipelining in the future. - The forward function returns an :class:`~torch.distributed.rpc.RRef` - to allow for inter-node pipelining in the future, where the output - might be on a remote host. For intra-node pipelining you can use - :meth:`~torch.distributed.rpc.RRef.local_value` to retrieve the - output locally. - - .. warning:: - :class:`Pipe` is experimental and subject to change. - """ - - def __init__( - self, - module: nn.Sequential, - chunks: int = 1, - checkpoint: str = "except_last", - deferred_batch_norm: bool = False, - ) -> None: - super().__init__() - - # Check if RPC framework is initialized. - if not torch.distributed.rpc._is_current_rpc_agent_set(): - raise RuntimeError( - 'Please initialize RPC framework for Pipe using ' - 'torch.distributed.rpc.init_rpc') - - chunks = int(chunks) - checkpoint = str(checkpoint) - - if chunks <= 0: - raise ValueError("number of chunks must be positive integer") - if checkpoint not in ["always", "except_last", "never"]: - raise ValueError("checkpoint is not one of 'always', 'except_last', or 'never'") - - _verify_module(module) - - # Verify if the underlying skippable modules satisfy integrity. The - # integrity can be verified before forward() because it is static. - verify_skippables(module) - - self.chunks = chunks - self.checkpoint = checkpoint - - if deferred_batch_norm: - module = DeferredBatchNorm.convert_deferred_batch_norm(module, chunks) - - self.partitions, self.devices = _split_module(module) - _verify_splitting(module, self.partitions, self.devices) - - self._copy_streams: List[List[AbstractStream]] = [] - self._skip_layout = inspect_skip_layout(self.partitions) - - # Separate CUDA streams for copy. - copy_streams = self._ensure_copy_streams() - - # The micro-batch index where the checkpointing stops. - checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint] - - self.pipeline = Pipeline(self.partitions, self.devices, copy_streams, self._skip_layout, checkpoint_stop) - - def __len__(self) -> int: - """Counts the length of the underlying sequential module.""" - return sum(len(p) for p in self.partitions) - - def __getitem__(self, index: int) -> nn.Module: - """Gets a layer in the underlying sequential module.""" - partitions = self.partitions - if index < 0: - partitions = partitions[::-1] - - for partition in partitions: - try: - return partition[index] - except IndexError: - pass - - shift = len(partition) - - if index < 0: - index += shift - else: - index -= shift - - raise IndexError - - def __iter__(self) -> Iterator[nn.Module]: - """Iterates over children of the underlying sequential module.""" - for partition in self.partitions: - yield from partition - - # Pipe should manage the device of each partition. - # Deny cuda(), cpu(), and to() with device, by TypeError. - def cuda(self, device: Optional[Device] = None) -> "Pipe": - raise MOVING_DENIED - - def cpu(self) -> "Pipe": - raise MOVING_DENIED - - def to(self, *args: Any, **kwargs: Any) -> "Pipe": - # Deny these usages: - # - # - to(device[, dtype, non_blocking]) - # - to(tensor[, non_blocking]) - # - # But allow this: - # - # - to(dtype[, non_blocking]) - # - if "device" in kwargs or "tensor" in kwargs: - raise MOVING_DENIED - - if args: - if isinstance(args[0], (torch.device, int, str)): - raise MOVING_DENIED - if torch.is_tensor(args[0]): - raise MOVING_DENIED - - return super().to(*args, **kwargs) - - def _ensure_copy_streams(self) -> List[List[AbstractStream]]: - """Ensures that :class:`Pipe` caches CUDA streams for copy. - - It's worth to cache CUDA streams although PyTorch already manages a - pool of pre-allocated CUDA streams, because it may reduce GPU memory - fragmentation when the number of micro-batches is small. - - """ - if not self._copy_streams: - for device in self.devices: - self._copy_streams.append([new_stream(device) for _ in range(self.chunks)]) - - return self._copy_streams - - def forward(self, *inputs) -> RRef: - """ - Processes a single input mini-batch through the pipe and returns an - :class:`~torch.distributed.rpc.RRef` pointing to the output. - :class:`Pipe` is a fairly transparent module wrapper. It doesn't - modify the input and output signature of the underlying module. But - there's type restriction. Input and output have to contain at least one - tensor. This restriction is applied at partition boundaries too. - - The sequence of inputs are fed into the first stage of the pipeline as - ``*inputs``. As a result the positional args for this function should - match the positional args for the first stage of the pipeline. The same - condition applies for output of one stage of the pipeline which is the - input for the next stage. - - The input tensor is split into multiple micro-batches based on the - ``chunks`` parameter used to initialize :class:`Pipe`. The batch size - is assumed to be the first dimension of the tensor and if the batch - size is less than ``chunks``, the number of micro-batches is equal to - the batch size. - - Only tensors are split into multiple micro-batches, non-Tensor inputs - are just replicated as-is in each micro-batch. For non-Tensor outputs - in the last stage of the pipeline, they are aggregated as a ``List`` - and returned the user. For example, if you have 2 micro-batches - returning the integer 5, the user would receive the consolidated - output of `[5, 5]` - - All the input tensors need to be on the same device as the first - partition of the pipeline. - - If a tensor is wrapped with the :class:`NoChunk` wrapper, the tensor - is not split across micro-batches and is replicated as-is similar to - non-tensors. - - Args: - inputs: input mini-batch - - Returns: - :class:`~torch.distributed.rpc.RRef` to the output of the mini-batch - - Raises: - TypeError: input doesn't contain at least one tensor - - """ - first_partition_device = self.devices[0] if len(self.devices) != 0 else torch.device("cpu") - microbatch.check(first_partition_device, *inputs) - - if not self.devices: - # Empty sequential module is not illegal. - return RRef(*inputs) - - # Divide a mini-batch into micro-batches. - batches = microbatch.scatter(*inputs, chunks=self.chunks) - - # Run pipeline parallelism. - self.pipeline.run(batches) - - # Merge the micro-batches into one mini-batch. - output = microbatch.gather(batches) - return RRef(output) diff --git a/torch/distributed/pipeline/sync/pipeline.py b/torch/distributed/pipeline/sync/pipeline.py deleted file mode 100644 index 7cd5e5831169..000000000000 --- a/torch/distributed/pipeline/sync/pipeline.py +++ /dev/null @@ -1,255 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""The pipeline parallelism of Pipe.""" -from queue import Queue -from types import TracebackType -from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Type, Union, cast, Sequence - -import torch -from torch import Tensor, nn -from torch.autograd.profiler import record_function - -from .checkpoint import Checkpointing -from .copy import Copy, Wait -from .dependency import fork, join -from .microbatch import Batch -from .skip.layout import SkipLayout -from .skip.tracker import SkipTrackerThroughPotals, use_skip_tracker -from .stream import AbstractStream, current_stream, use_device -from .worker import Task, create_workers - -__all__: List[str] = ["Pipeline"] - - -Tensors = Sequence[Tensor] -TensorOrTensors = Union[Tensor, Tensors] - -ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType] - -# Queue is generic only in stubs. -# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime -if TYPE_CHECKING: - InQueue = Queue[Optional["Task"]] - OutQueue = Queue[Tuple[bool, Union[Tuple["Task", Batch], ExcInfo, None]]] -else: - InQueue = Queue - OutQueue = Queue - - -def _depend(fork_from: Batch, join_to: Batch) -> None: - fork_from_idx = fork_from.find_tensor_idx() - join_to_idx = join_to.find_tensor_idx() - - fork_from[fork_from_idx], phony = fork(fork_from[fork_from_idx]) - join_to[join_to_idx] = join(join_to[join_to_idx], phony) - - -def _copy(batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream) -> None: - batch[:] = Copy.apply(prev_stream, next_stream, *batch) - # Gradients are only supported for float Tensors. - batch[:] = tuple([x.detach() if torch.is_tensor(x) and not x.is_floating_point() else x for x in batch]) - - -def _wait(batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream) -> None: - batch[:] = Wait.apply(prev_stream, next_stream, *batch) - # Gradients are only supported for float Tensors. - batch[:] = tuple([x.detach() if torch.is_tensor(x) and not x.is_floating_point() else x for x in batch]) - - -def _clock_cycles(m: int, n: int) -> Iterable[List[Tuple[int, int]]]: - """Generate schedules for each clock cycle.""" - # m: number of micro-batches - # n: number of partitions - # i: index of micro-batch - # j: index of partition - # k: clock number - # - # k (i,j) (i,j) (i,j) - # - ----- ----- ----- - # 0 (0,0) - # 1 (1,0) (0,1) - # 2 (2,0) (1,1) (0,2) - # 3 (2,1) (1,2) - # 4 (2,2) - for k in range(m + n - 1): - yield [(k - j, j) for j in range(max(1 + k - m, 0), min(1 + k, n))] - - -class Pipeline: - """The pipeline parallelism for Pipe.""" - - def __init__( - self, - partitions: List[nn.Sequential], - devices: List[torch.device], - copy_streams: List[List[AbstractStream]], - skip_layout: SkipLayout, - checkpoint_stop: int, - ) -> None: - self.partitions = partitions - self.devices = devices - self.copy_streams = copy_streams - self.skip_layout = skip_layout - self.checkpoint_stop = checkpoint_stop - (self.in_queues, self.out_queues) = create_workers(devices) - - def run(self, batches: List[Batch]) -> None: - """Runs pipeline parallelism. - - It modifies the given batches in place. - - """ - partitions = self.partitions - devices = self.devices - skip_layout = self.skip_layout - - m = len(batches) - n = len(partitions) - - skip_trackers = [SkipTrackerThroughPotals(skip_layout) for _ in batches] - - for schedule in _clock_cycles(m, n): - self.fence(batches, schedule, skip_trackers) - self.compute(batches, schedule, skip_trackers) - - def fence( - self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals], - ) -> None: - """Copy micro-batches after computation for the previous micro-batches.""" - copy_streams = self.copy_streams - skip_layout = self.skip_layout - - for i, j in schedule: - # Ensure that batches[i-1] is executed after batches[i] in - # backpropagation by an explicit dependency. - if i != 0 and j != 0: - _depend(batches[i - 1], batches[i]) - - next_stream = copy_streams[j][i] - - for prev_j, ns, name in skip_layout.copy_policy(j): - prev_stream = copy_streams[prev_j][i] - skip_trackers[i].copy(batches[i], prev_stream, next_stream, ns, name) - - if j != 0: - prev_stream = copy_streams[j - 1][i] - _copy(batches[i], prev_stream, next_stream) - - def compute( - self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals], - ) -> None: - """Run tasks with synchronization to copy streams.""" - partitions = self.partitions - devices = self.devices - copy_streams = self.copy_streams - checkpoint_stop = self.checkpoint_stop - - # Disable checkpointing if in eval mode. - if not self.partitions[0].training: - checkpoint_stop = 0 - - n = len(partitions) - streams = [current_stream(d) for d in devices] - exc_info: Optional[ExcInfo] = None - - # With checkpointing, the autograd graph looks like this diagram: - # +-----+------+ - # | Copy | - # +-----+------+ (fence) - # - - - + - - - - - - - - - - # | (compute) - # +-----+------+ - # | Wait | [1] Synchronize the current stream with the copy stream. - # +-----+------+ - # +-----+------+ - # | Checkpoint | [2] Compute a partition within checkpointing. - # +-----+------+ - # +-----+------+ - # | Wait | [3] Synchronize the copy stream with the current stream. - # +-----+------+ - # + - - - + - # | +-----+-----+ - # | | Recompute | [4] Schedule the recomputation at backpropagation. - # | +-----+-----+ - # + - - - + - # | - # - - - + - - - - - - - - - - # +-----+------+ (fence) - # | Copy | - # +-----+------+ - for i, j in schedule: - batch = batches[i] - partition = partitions[j] - - # Synchronize with the copied input. ([1] in the diagram) - if j != 0: - _wait(batch, copy_streams[j][i], streams[j]) - - # Determine whether checkpointing or not. - checkpoint = i < checkpoint_stop - if checkpoint: - - def function( - *inputs, - partition: nn.Module = partition, - skip_tracker: SkipTrackerThroughPotals = skip_trackers[i], - chunk_id: int = i, - part_id: int = j, - ) -> TensorOrTensors: - with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)): - return partition(*inputs) - - chk = Checkpointing(function, batch) # type: ignore[arg-type] - task = Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute) - del function, chk - - else: - - def compute( - batch: Batch = batch, - partition: nn.Module = partition, - skip_tracker: SkipTrackerThroughPotals = skip_trackers[i], - chunk_id: int = i, - part_id: int = j, - ) -> Batch: - with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)): - return batch.call(partition) - - task = Task(streams[j], compute=compute, finalize=None) - del compute - - # Compute tasks in parallel. ([2] in the diagram) - self.in_queues[j].put(task) - - for i, j in schedule: - ok, payload = self.out_queues[j].get() - - # Hold the first exception. - if exc_info is not None: - continue - elif not ok: - exc_info = cast(ExcInfo, payload) - continue - - task, batch = cast(Tuple[Task, Batch], payload) - - # The copy stream synchronizes to copy the output. ([3] in the - # diagram) - if j != n - 1: - _wait(batch, streams[j], copy_streams[j][i]) - - # Finalize tasks. If checkpointing is enabled, here the - # recomputation is scheduled at backpropagation. ([4] in the - # diagram) - with use_device(devices[j]): - task.finalize(batch) - - batches[i] = batch - - # Fail at the first exception. - if exc_info is not None: - raise exc_info[0].with_traceback(exc_info[1], exc_info[2]) diff --git a/torch/distributed/pipeline/sync/py.typed b/torch/distributed/pipeline/sync/py.typed deleted file mode 100644 index ab03724cafbf..000000000000 --- a/torch/distributed/pipeline/sync/py.typed +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. diff --git a/torch/distributed/pipeline/sync/skip/__init__.py b/torch/distributed/pipeline/sync/skip/__init__.py deleted file mode 100644 index bdcb913867a7..000000000000 --- a/torch/distributed/pipeline/sync/skip/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Supports efficiency with skip connections.""" -from .namespace import Namespace -from .skippable import pop, skippable, stash, verify_skippables - -__all__ = ["skippable", "stash", "pop", "verify_skippables", "Namespace"] diff --git a/torch/distributed/pipeline/sync/skip/layout.py b/torch/distributed/pipeline/sync/skip/layout.py deleted file mode 100644 index 04d76d34ea16..000000000000 --- a/torch/distributed/pipeline/sync/skip/layout.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Static skip connection layout of ``@skippable`` modules.""" -from typing import Dict, Iterable, List, Tuple - -from torch import nn - -from .namespace import Namespace - -__all__: List[str] = [] - - -class SkipLayout: - """Represents a skip connection layout across partitions.""" - - # Skip routes indexed by 'ns, name': {(ns, name): (prev_j, next_j), ...} - by_ns_name: Dict[Tuple[Namespace, str], Tuple[int, int]] - - # Skip routes indexed by partition number 'j': [[next_j]: [(prev_j, ns, name), ...], ...] - by_partition: List[List[Tuple[int, Namespace, str]]] - - def __init__(self, num_partitions: int, skip_routes: Dict[Tuple[Namespace, str], Tuple[int, int]],) -> None: - # The skip routes are already indexed by 'ns, name'. - self.by_ns_name = skip_routes - - # Index skip routes by partition number 'j'. - self.by_partition = [[] for _ in range(num_partitions)] - - for (ns, name), (prev_j, next_j) in skip_routes.items(): - self.by_partition[next_j].append((prev_j, ns, name)) - - for p in self.by_partition: - p.sort() - - def copy_policy(self, next_j: int) -> Iterable[Tuple[int, Namespace, str]]: - """Generates skip routes for the given destination partition number. - The skip routes are sorted by source partition number in ascending - order. - - Yields: - Each tuple of (source partition number, namespace, name). - - """ - for prev_j, ns, name in self.by_partition[next_j]: - if prev_j == next_j: - # This skip tensor will be popped at the same partition where - # it is stashed. In this case, copy is not required. - continue - - yield (prev_j, ns, name) - - def requires_copy(self, ns: Namespace, name: str) -> bool: - """Whether the given namespace and name requires partition-to-partition - copy or not. - """ - prev_j, next_j = self.by_ns_name.get((ns, name), (-1, -1)) - return prev_j != next_j - - -def inspect_skip_layout(partitions: List[nn.Sequential]) -> SkipLayout: - """Inspects the skip connection layout in the given partitions.""" - # NOTE(sublee): Hide circular import inside this subroutine. Circular - # import is not ideal but placing this logic near to SkipLayout may - # increase cohesion of code. - from .skippable import Skippable - - skip_routes: Dict[Tuple[Namespace, str], Tuple[int, int]] = {} - stashed_at: Dict[Tuple[Namespace, str], int] = {} - - for j, partition in enumerate(partitions): - def inspect_layer(layer): - if not isinstance(layer, Skippable): - return - - for ns, name in layer.stashable(): - stashed_at[(ns, name)] = j - - for ns, name in layer.poppable(): - prev_j = stashed_at.pop((ns, name)) - skip_routes[(ns, name)] = (prev_j, j) - - if isinstance(partition, nn.Sequential): - for layer in partition: - inspect_layer(layer) - else: - inspect_layer(partition) - - return SkipLayout(len(partitions), skip_routes) diff --git a/torch/distributed/pipeline/sync/skip/namespace.py b/torch/distributed/pipeline/sync/skip/namespace.py deleted file mode 100644 index 7d9c0d9b7d84..000000000000 --- a/torch/distributed/pipeline/sync/skip/namespace.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Provides isolated namespace of skip tensors.""" -import abc -from functools import total_ordering -from typing import Any -import uuid - -__all__ = ["Namespace"] - - -@total_ordering -class Namespace(metaclass=abc.ABCMeta): # noqa: B024 - """Namespace for isolating skip tensors used by :meth:`isolate() - `. - """ - - __slots__ = ("id",) - - def __init__(self) -> None: - self.id = uuid.uuid4() - - def __repr__(self) -> str: - return f"" - - def __hash__(self) -> int: - return hash(self.id) - - # Namespaces should support ordering, since SkipLayout will sort tuples - # including a namespace. But actual order between namespaces is not - # important. That's why they are ordered by version 4 UUID which generates - # random numbers. - def __lt__(self, other: Any) -> bool: - if isinstance(other, Namespace): - return self.id < other.id - return False - - def __eq__(self, other: object) -> bool: - if isinstance(other, Namespace): - return self.id == other.id - return False - - -# 'None' is the default namespace, -# which means that 'isinstance(None, Namespace)' is 'True'. -Namespace.register(type(None)) diff --git a/torch/distributed/pipeline/sync/skip/portal.py b/torch/distributed/pipeline/sync/skip/portal.py deleted file mode 100644 index 335793f4cc13..000000000000 --- a/torch/distributed/pipeline/sync/skip/portal.py +++ /dev/null @@ -1,231 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Portal keeps a tensor in the pocket plane. The tensor becomes hidden to the -autograd engine. The shared context of three functions (:class:`PortalBlue`, -:class:`PortalOrange`, and :class:`PortalCopy`) out of the computation graph is -one of the most important feature of :mod:`torchpipe.skip`. - -The metaphor is inspired by Portal(tm) from Valve. - -""" -from typing import List, Optional, Tuple - -import torch -from torch import Tensor - -from ..copy import Context as CopyContext -from ..copy import Copy -from ..phony import get_phony -from ..stream import AbstractStream, get_device - -__all__: List[str] = [] - - -class Portal: - """A portal for a tensor.""" - - def __init__(self, tensor: Optional[Tensor], tensor_life: int) -> None: - self.put_tensor(tensor, tensor_life) - self.grad: Optional[Tensor] = None - - def blue(self) -> Tensor: - """Creates a :class:`PortalBlue` which hides the underlying tensor from - the autograd engine. - - Join the returning phony to the main lane of the autograd graph to - assure the correct backpropagation:: - - PortalBlue --+ - | - ---------- Join -- - - """ - tensor = self.use_tensor() - - if tensor is None: - return get_phony(torch.device("cpu"), requires_grad=False) - - return PortalBlue.apply(self, tensor) - - def orange(self, phony: Tensor) -> Optional[Tensor]: - """Creates a :class:`PortalOrange` which retrieves the hidden tensor - without losing ability of backpropagation. - - Give a phony forked from the main lane of an autograd graph:: - - +-- PortalOrange --+ - | | - -- Fork --------- f(a, b) -- - - """ - self.check_tensor_life() - - if self.tensor is None: - return self.use_tensor() - - return PortalOrange.apply(self, phony) - - def copy(self, prev_stream: AbstractStream, next_stream: AbstractStream, phony: Tensor,) -> Tensor: - """Copies the hidden tensor by a :class:`PortalCopy`. - - Give a phony and use the returning phony to keep backpropagation:: - - +-- PortalCopy --+ - | | - -- Fork ---------- Join -- - - """ - if self.tensor is None: - return get_phony(torch.device("cpu"), requires_grad=False) - - return PortalCopy.apply(self, prev_stream, next_stream, phony) - - def check_tensor_life(self) -> None: - if self.tensor_life <= 0: - raise RuntimeError("tensor in portal has been removed") - - def put_tensor(self, tensor: Optional[Tensor], tensor_life: int) -> None: - """Stores a tensor into this portal.""" - # [Life of Tensor through Portal] - # - # The tensor can be retrieved by use_tensor() up to 'tensor_life' - # times. When the life becomes 0, the tensor will be deleted for - # deallocation in CUDA memory. - # - # The below events participate in a tensor through a portal. - # Note that [x] denotes the events which call use_tensor(): - # - # 1. [x] blue() - # 2. [ ] PortalBlue.forward - # 3. [ ] copy() - # 4. [ ] PortalCopy.forward - # 5. [ ] orange() - # 6. [x] PortalOrange.forward - # - - - - - - - - - - - - - - - - - - - - - - - - - - - - # 7. [ ] orange() (recomputed) - # 8. [x] PortalOrange.forward (recomputed) - # 9. [ ] PortalOrange.backward - # 10. [ ] PortalCopy.backward - # 11. [x] blue() (recomputed) - # 12. [ ] PortalBlue.forward (recomputed) - # 13. [ ] PortalBlue.backward - # - self.tensor_life = tensor_life - - if tensor_life > 0: - self.tensor = tensor - else: - self.tensor = None - - def use_tensor(self) -> Optional[Tensor]: - """Retrieves the underlying tensor and decreases the tensor life. When - the life becomes 0, it the tensor will be removed. - """ - self.check_tensor_life() - - tensor = self.tensor - - self.tensor_life -= 1 - - if self.tensor_life <= 0: - self.tensor = None - - return tensor - - def put_grad(self, grad: Tensor) -> None: - """Stores a gradient into this portal.""" - self.grad = grad - - def use_grad(self) -> Tensor: - """Retrieves and removes the underlying gradient. The gradient is - always ephemeral. - """ - if self.grad is None: - raise RuntimeError("grad in portal has been removed or never set") - - grad = self.grad - self.grad = None - return grad - - -# Common interface between :class:`PortalBlue`, :class:`PortalOrange`, and -# :class:`PortalCopy`. -class Context(CopyContext): - portal: Portal - - -class PortalBlue(torch.autograd.Function): - """Hides a tensor from the autograd engine by a :class:`Portal`.""" - - @staticmethod - # type: ignore[override] - def forward( - ctx: Context, - portal: Portal, - # This tensor must be retrieved by portal.use_tensor(). - tensor: Tensor, - ) -> Tensor: - ctx.portal = portal - - phony = get_phony(tensor.device, requires_grad=False) - return phony.detach() - - @staticmethod - # type: ignore[override] - def backward(ctx: Context, grad_phony: Tensor,) -> Tuple[None, Tensor]: - # The paired PortalOrange should keep the gradient. - grad = ctx.portal.use_grad() - return None, grad - - -class PortalOrange(torch.autograd.Function): - """Retrieves the hidden tensor from a :class:`Portal`.""" - - @staticmethod - # type: ignore[override] - def forward(ctx: Context, portal: Portal, phony: Tensor) -> Tensor: - ctx.portal = portal - - tensor = portal.use_tensor() - assert tensor is not None - - return tensor.detach() - - @staticmethod - def backward(ctx: Context, grad: Tensor) -> Tuple[None, None]: # type: ignore[override] - # The paired PortalBlue will use the gradient. - ctx.portal.put_grad(grad) - return None, None - - -class PortalCopy(torch.autograd.Function): - """Copies the hidden tensor in a :class:`Portal`. It replaces the hidden - tensor with copied one. - """ - - @staticmethod - # type: ignore[override] - def forward( - ctx: Context, portal: Portal, prev_stream: AbstractStream, next_stream: AbstractStream, phony: Tensor, - ) -> Tensor: - ctx.portal = portal - - assert portal.tensor is not None - (portal.tensor,) = Copy.forward(ctx, prev_stream, next_stream, portal.tensor) - - phony = get_phony(get_device(next_stream), requires_grad=False) - return phony.detach() - - @staticmethod - # type: ignore[override] - def backward(ctx: Context, grad_phony: Tensor,) -> Tuple[None, None, None, None]: - portal = ctx.portal - - assert portal.grad is not None - _, _, portal.grad = Copy.backward(ctx, portal.grad) - - return None, None, None, None diff --git a/torch/distributed/pipeline/sync/skip/skippable.py b/torch/distributed/pipeline/sync/skip/skippable.py deleted file mode 100644 index 9d4db76c6b67..000000000000 --- a/torch/distributed/pipeline/sync/skip/skippable.py +++ /dev/null @@ -1,431 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""The user interface to define skip connections.""" -from typing import ( - TYPE_CHECKING, - Any, - Callable, - ClassVar, - Dict, - FrozenSet, - Generator, - Iterable, - List, - Optional, - Set, - Sequence, - Tuple, - Type, - TypeVar, - Union, - cast, -) - -from torch import Tensor, nn - -from ..microbatch import Batch -from .namespace import Namespace -from .tracker import current_skip_tracker - -__all__ = ["skippable", "stash", "pop", "verify_skippables"] - - -Tensors = Sequence[Tensor] -TensorOrTensors = Union[Tensor, Tensors] - -StashPop = Union["stash", "pop"] -StashPopGenerator = Generator[StashPop, Optional[Tensor], TensorOrTensors] -if TYPE_CHECKING: - # Typechecking: nn.Module is not a Generic - SkippableModule = nn.Module[Union[StashPopGenerator, TensorOrTensors]] # type: ignore[type-arg] -else: - SkippableModule = nn.Module - -T = TypeVar("T", bound="Skippable") - - -class Skippable(nn.Module): - """The base class for skippable modules. - - Do not use this class directly. Define a subclass by :func:`skippable` - instead. - - """ - - module_cls: ClassVar[Type[SkippableModule]] - stashable_names: ClassVar[FrozenSet[str]] - poppable_names: ClassVar[FrozenSet[str]] - - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__() - self.module = self.module_cls(*args, **kwargs) # type: ignore[call-arg] - self.namespaces: Dict[str, Namespace] = {} - - def __repr__(self) -> str: - return f"@skippable({self.module})" - - def namespaced(self, name: str) -> Tuple[Namespace, str]: - """Prepend namespace for the given skip name.""" - ns = self.namespaces.get(name) - ns = cast(Namespace, ns) - return (ns, name) - - def stashable(self) -> Iterable[Tuple[Namespace, str]]: - """Iterate over namespaced skip names to be stashed.""" - for name in self.stashable_names: - yield self.namespaced(name) - - def poppable(self) -> Iterable[Tuple[Namespace, str]]: - """Iterate over namespaced skip names to be popped.""" - for name in self.poppable_names: - yield self.namespaced(name) - - def isolate(self: T, ns: Namespace, *, only: Optional[Iterable[str]] = None) -> T: - r"""Isolate a specified subset or the whole set of skip tensors. - - In a single sequential module, skip tensors with the same - name are not allowed unless they are isolated by different namespaces. - - Here's an example using the same name for skip tensors twice. Each pair - of ``Layer1`` and ``Layer2`` is isolated with its own namespace ``ns1`` - and ``ns2``. There is no conflict anymore:: - - ns1 = Namespace() - ns2 = Namespace() - - model = nn.Sequential( - Layer1().isolate(ns1), - Layer1().isolate(ns2), - Layer2(), - Layer3().isolate(ns2), - Layer3().isolate(ns1), - ) - - When `only` parameter is omitted, all skip tensors are isolated. You - can isolate a subset of skip tensors by passing `only` parameter:: - - ns_alice = Namespace() - ns_bob = Namespace() - - model = nn.Sequential( - ... - StashStashPop().isolate(ns_alice, only=['alice']) \ - .isolate(ns_bob, only=['bob']), - ... - ) - - Args: - ns (Namespace): - namespace for isolation - - Keyword Args: - only (iterable of strs): - names of specific skip tensors to be isolated (omit this option - to isolate all skip tensors declared in this module) - - Returns: - this module itself - - """ - names: Iterable[str] - - if only is None: - names = self.stashable_names | self.poppable_names - else: - names = set(only) - - for name in names: - self.namespaces[name] = ns - - return self - - def dispatch( - self, - input, - handle_stash: Callable[[str, Optional[Tensor]], None], - handle_pop: Callable[[str], Optional[Tensor]], - ): - """Dispatch :class:`stash` or :class:`pop` commands. - - The commands are generated by the module's ``forward()``. - """ - generator = self.module(input) - - if not isinstance(generator, Generator): - # The underlying module returned output without any yield. - output = generator - return output - - try: - op = next(generator) - - while True: - if isinstance(op, stash): - handle_stash(op.name, op.tensor) - op = next(generator) - continue - - if isinstance(op, pop): - tensor = handle_pop(op.name) - op = generator.send(tensor) - continue - - raise TypeError(f"{op!r} is not a command from @skippable") - - except StopIteration as stop: - output = stop.args[0] - return output - - def forward(self, input: Union[List[Any], Tensor]) -> TensorOrTensors: - """Perform the forward propagation. - - :class:`stash` or :class:`pop` commands will be handled by portals - silently. The portals won't be exposed to users. - - Raises: - RuntimeError: - illegal 'stash' or 'pop' is found. - - """ - skip_tracker = current_skip_tracker() - stashed_tensors: Dict[str, Optional[Tensor]] = {} - - # Load skip tensors that might be popped. - poppable_tensors = {} - batch = Batch(input) - for ns, name in self.poppable(): - try: - poppable_tensors[name] = skip_tracker.load(batch, ns, name) - except KeyError as e: - raise RuntimeError(f"'{name}' has not been stashed") from e - input = batch.values - - # Handle skip commands. - def handle_stash(name: str, tensor: Optional[Tensor]) -> None: - if name not in self.stashable_names: - raise RuntimeError(f"'{name}' has not been declared as stashable") - stashed_tensors[name] = tensor - - def handle_pop(name: str) -> Optional[Tensor]: - if name not in self.poppable_names: - raise RuntimeError(f"'{name}' has not been declared as poppable") - return poppable_tensors.pop(name) - - output = self.dispatch(input, handle_stash, handle_pop) - - # All declared skips must be stashed or popped. - not_stashed = self.stashable_names - stashed_tensors.keys() - if not_stashed: - comma_names = ", ".join(f"'{n}'" for n in not_stashed) - raise RuntimeError(f"{comma_names} must be stashed but have not") - - not_popped = poppable_tensors.keys() - if not_popped: - comma_names = ", ".join(f"'{n}'" for n in not_popped) - raise RuntimeError(f"{comma_names} must be popped but have not") - - # Save stashed skip tensors. - batch = Batch(output) - for ns, name in self.stashable(): - tensor = stashed_tensors[name] - skip_tracker.save(batch, ns, name, tensor) - output = batch.values - - return output - - -# TODO(sublee): Move to above of Skippable class for better read flow. -def skippable( - stash: Iterable[str] = (), pop: Iterable[str] = (), -) -> Callable[[Type[SkippableModule]], Type[Skippable]]: - """Define a decorator to create :class:`nn.Module ` with skip connections. - - These decorated modules are called "skippable". This functionality works perfectly - fine even when the module is not wrapped by :class:`~torch.distributed.pipeline.sync.Pipe`. - - Each skip tensor is managed by its name. Before manipulating skip tensors, - a skippable module must statically declare the names for skip tensors by - `stash` and/or `pop` parameters. Skip tensors with pre-declared name can be - stashed by ``yield stash(name, tensor)`` or popped by ``tensor = yield - pop(name)``. - - Here is an example with three layers. A skip tensor named "1to3" is stashed - and popped at the first and last layer, respectively:: - - @skippable(stash=['1to3']) - class Layer1(nn.Module): - def forward(self, input): - yield stash('1to3', input) - return f1(input) - - class Layer2(nn.Module): - def forward(self, input): - return f2(input) - - @skippable(pop=['1to3']) - class Layer3(nn.Module): - def forward(self, input): - skip_1to3 = yield pop('1to3') - return f3(input) + skip_1to3 - - model = nn.Sequential(Layer1(), Layer2(), Layer3()) - - One skippable module can stash or pop multiple skip tensors:: - - @skippable(stash=['alice', 'bob'], pop=['carol']) - class StashStashPop(nn.Module): - def forward(self, input): - yield stash('alice', f_alice(input)) - yield stash('bob', f_bob(input)) - carol = yield pop('carol') - return input + carol - - Every skip tensor must be associated with exactly one pair of `stash` and - `pop`. :class:`~torch.distributed.pipeline.sync.Pipe` checks this - restriction automatically when wrapping a module. You can also check the - restriction by :func:`verify_skippables` - without :class:`~torch.distributed.pipeline.sync.Pipe`. - - """ - stashable_names = frozenset(stash) - poppable_names = frozenset(pop) - - def extend_skippable(module_cls: Type[SkippableModule]) -> Type[Skippable]: - name = module_cls.__name__ - bases = (Skippable,) - attrs = {"module_cls": module_cls, "stashable_names": stashable_names, "poppable_names": poppable_names} - return type(name, bases, attrs) - - return extend_skippable - - -class stash: - """The command to stash a skip tensor. - - :: - - def forward(self, input): - yield stash('name', input) - return f(input) - - Args: - name (str): name of skip tensor - input (torch.Tensor or None): tensor to pass to the skip connection - - """ - - __slots__ = ("name", "tensor") - - def __init__(self, name: str, tensor: Optional[Tensor]) -> None: - self.name = name - self.tensor = tensor - - -class pop: - """The command to pop a skip tensor. - - :: - - def forward(self, input): - skip = yield pop('name') - return f(input) + skip - - Args: - name (str): name of skip tensor - - Returns: - the skip tensor previously stashed by another layer under the same name - - """ - - __slots__ = ("name",) - - def __init__(self, name: str) -> None: - self.name = name - - -def verify_skippables(module: nn.Sequential) -> None: - """Verify if the underlying skippable modules satisfy integrity. - - Every skip tensor must have only one pair of `stash` and `pop`. If there - are one or more unmatched pairs, it will raise :exc:`TypeError` with the - detailed messages. - - Here are a few failure cases. :func:`verify_skippables` will report failure - for these cases:: - - # Layer1 stashes "1to3". - # Layer3 pops "1to3". - - nn.Sequential(Layer1(), Layer2()) - # +---- ? - - nn.Sequential(Layer2(), Layer3()) - # ? ----+ - - nn.Sequential(Layer1(), Layer2(), Layer3(), Layer3()) - # +-------------------+ ^^^^^^ - - nn.Sequential(Layer1(), Layer1(), Layer2(), Layer3()) - # ^^^^^^ +-------------------+ - - To use the same name for multiple skip tensors, they must be isolated by - different namespaces. See :meth:`isolate() - `. - - Raises: - TypeError: - one or more pairs of `stash` and `pop` are not matched. - - """ - stashed: Set[Tuple[Namespace, str]] = set() - popped: Set[Tuple[Namespace, str]] = set() - msgs: List[str] = [] - - for layer_name, layer in module.named_children(): - if not isinstance(layer, Skippable): - continue - - for name in layer.stashable_names & layer.poppable_names: - msg = f"'{layer_name}' declared '{name}' both as stashable and as poppable" - msgs.append(msg) - - for ns, name in layer.stashable(): - if name in layer.poppable_names: - continue - - if (ns, name) in stashed: - msg = f"'{layer_name}' redeclared '{name}' as stashable but not isolated by namespace" - msgs.append(msg) - continue - - stashed.add((ns, name)) - - for ns, name in layer.poppable(): - if name in layer.stashable_names: - continue - - if (ns, name) in popped: - msg = f"'{layer_name}' redeclared '{name}' as poppable but not isolated by namespace" - msgs.append(msg) - continue - - if (ns, name) not in stashed: - msg = f"'{layer_name}' declared '{name}' as poppable but it was not stashed" - msgs.append(msg) - continue - - popped.add((ns, name)) - - for (_, name) in stashed - popped: - msg = f"no module declared '{name}' as poppable but stashed" - msgs.append(msg) - - if msgs: - raise TypeError( - "one or more pairs of stash and pop do not match:\n\n{}" "".format("\n".join(f"* {x}" for x in msgs)) - ) diff --git a/torch/distributed/pipeline/sync/skip/tracker.py b/torch/distributed/pipeline/sync/skip/tracker.py deleted file mode 100644 index 8ac82bc05dc9..000000000000 --- a/torch/distributed/pipeline/sync/skip/tracker.py +++ /dev/null @@ -1,180 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Tracks skip tensors on a thread.""" -from contextlib import contextmanager -import threading -from typing import Dict, Generator, List, Optional, Tuple - -from torch import Tensor - -from ..checkpoint import is_checkpointing -from ..dependency import fork, join -from ..microbatch import Batch -from ..stream import AbstractStream -from .layout import SkipLayout -from .namespace import Namespace -from .portal import Portal - -__all__: List[str] = [] - - -class SkipTracker: - """Tracks saved skip tensors. - - It will update the given micro-batch in place. This is because when it - manipulates the underlying skip tensors, the current micro-batch also has - to be connected with the skip tensors. - - One thread has one skip tracker. Call :func:`current_skip_tracker` to get - the skip tracker on the current thread. - - """ - - def __init__(self) -> None: - self.tensors: Dict[Tuple[Namespace, str], Optional[Tensor]] = {} - - def save(self, batch: Batch, ns: Namespace, name: str, tensor: Optional[Tensor]) -> None: - self.tensors[(ns, name)] = tensor - - def load(self, batch: Batch, ns: Namespace, name: str) -> Optional[Tensor]: - return self.tensors.pop((ns, name)) - - def copy( - self, batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream, ns: Namespace, name: str, - ) -> None: - raise TypeError("copy is not supported for non-portal skip tensors") - - -class SkipTrackerThroughPotals(SkipTracker): - """Tracks saved skip tensors through portals. The skip tensors will be - hidden in portals so that the autograd engine does not need to track them. - - This tracker is only used when the training or evaluating module is wrapped - with :class:`torchpipe.Pipe`. - - """ - - def __init__(self, skip_layout: SkipLayout) -> None: - super().__init__() - self.skip_layout = skip_layout - self.portals: Dict[Tuple[Namespace, str], Portal] = {} - - def save(self, batch: Batch, ns: Namespace, name: str, tensor: Optional[Tensor]) -> None: - """Saves the stashed skip tensor in a portal. The portal is then - connected to the given micro-batch with :class:`Join`. - """ - if not self.skip_layout.requires_copy(ns, name): - super().save(batch, ns, name, tensor) - return - - # See [Tensor Life of Portal] at Portal.put_tensor() to understand the - # below tensor_life values. Here are the selected events which retrieve - # the tensor in portal: - # - # 1. [x] blue() - # ... - # 6. [x] PortalOrange.forward - # ... - # 8. [x] PortalOrange.forward (recomputed) - # ... - # 11. [x] blue() (recomputed) - # - if (ns, name) not in self.portals: - if is_checkpointing(): - # Under checkpointing, the tensor used by the first - # PortalOrange should be alive in the portal. This tensor will - # be used again by the second PortalOrange during the - # recomputation. - tensor_life = 3 # Delete at [8. PortalOrange.forward (recomputed)] - else: - tensor_life = 2 # Delete at [6. PortalOrange.forward] - - portal = Portal(tensor, tensor_life) - self.portals[(ns, name)] = portal - - else: - # Under recomputation, the portal already exists. - portal = self.portals[(ns, name)] - - # The existing tensor life already became 0. It should be reset as - # 1 to delete the tensor after the second PortalBlue immediately. - tensor_life = 1 # Delete at [11. blue() (recomputed)] - - portal.put_tensor(tensor, tensor_life) - - phony = portal.blue() - tensor_idx = batch.find_tensor_idx() - batch[tensor_idx] = join(batch[tensor_idx], phony) - - def load(self, batch: Batch, ns: Namespace, name: str) -> Optional[Tensor]: - """Loads a skip tensor from the corresponding portal to pop. The given - micro-batch is connected to the portal with :class:`Fork`. - """ - if not self.skip_layout.requires_copy(ns, name): - tensor = super().load(batch, ns, name) - return tensor - - portal = self.portals[(ns, name)] - tensor_idx = batch.find_tensor_idx() - batch[tensor_idx], phony = fork(batch[tensor_idx]) - tensor = portal.orange(phony) - return tensor - - def copy( - self, batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream, ns: Namespace, name: str, - ) -> None: - """Copies the skip tensor in the corresponding portal. The given - micro-batch and the portal will be tied with :class:`Fork` and - :class:`Join`. - """ - assert self.skip_layout.requires_copy(ns, name) - - tensor_idx = batch.find_tensor_idx() - batch[tensor_idx], phony = fork(batch[tensor_idx]) - - portal = self.portals[(ns, name)] - phony = portal.copy(prev_stream, next_stream, phony) - - batch[tensor_idx] = join(batch[tensor_idx], phony) - - -class ThreadLocal(threading.local): - def __init__(self) -> None: - self.skip_tracker: Optional[SkipTracker] = None - - -thread_local = ThreadLocal() - - -@contextmanager -def use_skip_tracker(skip_tracker: SkipTracker) -> Generator[None, None, None]: - """Registers the given skip tracker on the current thread within a - context:: - - with use_skip_tracker(my_skip_tracker): - ... - - """ - orig = thread_local.skip_tracker - - thread_local.skip_tracker = skip_tracker - - try: - yield - finally: - thread_local.skip_tracker = orig - - -def current_skip_tracker() -> SkipTracker: - """Gets the skip tracker on the current thread.""" - skip_tracker = thread_local.skip_tracker - - if skip_tracker is None: - skip_tracker = SkipTracker() - thread_local.skip_tracker = skip_tracker - - return skip_tracker diff --git a/torch/distributed/pipeline/sync/stream.py b/torch/distributed/pipeline/sync/stream.py deleted file mode 100644 index 59fedf865a42..000000000000 --- a/torch/distributed/pipeline/sync/stream.py +++ /dev/null @@ -1,120 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Utilities for eliminating boilerplate code to handle abstract streams with -CPU device. -""" -from contextlib import contextmanager -from typing import Generator, List, Union, cast - -import torch - -__all__: List[str] = ["CPUStreamType", "new_stream", "current_stream", "default_stream", - "use_device", "use_stream", "get_device", "wait_stream", "record_stream", - "is_cuda", "as_cuda"] - - -class CPUStreamType: - pass - - -# The placeholder on place of streams for the CPU device instead of CUDA. -CPUStream = CPUStreamType() - -# It represents both CUDA streams and the CPU stream. -AbstractStream = Union[torch.cuda.Stream, CPUStreamType] - - -def new_stream(device: torch.device) -> AbstractStream: - """Creates a new stream for either CPU or CUDA device.""" - if device.type != "cuda": - return CPUStream - return torch.cuda.Stream(device) - - -def current_stream(device: torch.device) -> AbstractStream: - """:func:`torch.cuda.current_stream` for either CPU or CUDA device.""" - if device.type != "cuda": - return CPUStream - return torch.cuda.current_stream(device) - - -def default_stream(device: torch.device) -> AbstractStream: - """:func:`torch.cuda.default_stream` for either CPU or CUDA device.""" - if device.type != "cuda": - return CPUStream - return torch.cuda.default_stream(device) - - -@contextmanager -def use_device(device: torch.device) -> Generator[None, None, None]: - """:func:`torch.cuda.device` for either CPU or CUDA device.""" - if device.type != "cuda": - yield - return - - with torch.cuda.device(device): - yield - - -@contextmanager -def use_stream(stream: AbstractStream) -> Generator[None, None, None]: - """:func:`torch.cuda.stream` for either CPU or CUDA stream.""" - if not is_cuda(stream): - yield - return - - with torch.cuda.stream(as_cuda(stream)): - yield - - -def get_device(stream: AbstractStream) -> torch.device: - """Gets the device from CPU or CUDA stream.""" - if is_cuda(stream): - return as_cuda(stream).device - return torch.device("cpu") - - -def wait_stream(source: AbstractStream, target: AbstractStream) -> None: - """:meth:`torch.cuda.Stream.wait_stream` for either CPU or CUDA stream. It - makes the source stream wait until the target stream completes work queued. - """ - if is_cuda(target): - if is_cuda(source): - # A CUDA stream waits another CUDA stream. - as_cuda(source).wait_stream(as_cuda(target)) - else: - # CPU waits a CUDA stream. - as_cuda(target).synchronize() - - # If the target is CPU, synchronization is not required. - - -def record_stream(tensor: torch.Tensor, stream: AbstractStream) -> None: - """:meth:`torch.Tensor.record_stream` for either CPU or CUDA stream.""" - if is_cuda(stream): - # NOTE(sublee): record_stream() on a shifted view tensor throws - # RuntimeError in PyTorch 1.1.0, and does nothing in 1.2.0. To safely - # protect the tensor against unexpected reallocation, here we use a - # temporal tensor associated with the same storage without shifting as - # a workaround. - # - # Issue: https://github.com/pytorch/pytorch/issues/27366 - # - tensor = tensor.new_empty([0]).set_(tensor._typed_storage()) - - # Typechecking: torch.cuda.Stream is incompatible with torch._C.Stream - tensor.record_stream(as_cuda(stream)) # type: ignore[arg-type] - - -def is_cuda(stream: AbstractStream) -> bool: - """Returns ``True`` if the given stream is a valid CUDA stream.""" - return stream is not CPUStream - - -def as_cuda(stream: AbstractStream) -> torch.cuda.Stream: - """Casts the given stream as :class:`torch.cuda.Stream`.""" - return cast(torch.cuda.Stream, stream) diff --git a/torch/distributed/pipeline/sync/utils.py b/torch/distributed/pipeline/sync/utils.py deleted file mode 100644 index 210c475317e2..000000000000 --- a/torch/distributed/pipeline/sync/utils.py +++ /dev/null @@ -1,38 +0,0 @@ -from torch import nn -from typing import List, Optional - -__all__ = ["partition_model"] - -def partition_model( - module: nn.Sequential, - balance: List[int], - devices: Optional[List[int]] = None): - """ - Partions the model accross multiple GPU devices. - - Given an :class:`nn.Sequential ` module, partitions - the model across multiple GPU devices according the provided ``balance`` - and ``devices``. - - Args: - module (:class:`nn.Sequential `): - Sequential model representing the pipe. - balance (List[int]): - List indicating the number of layers in each partition. - devices (List[int], optional): - List indicating the device to use for each partition. Defaults to - ``range(len(balance))`` - """ - device_idx = 0 - pipe_idx = 0 - balanced_pipe = [] - for num_layers in balance: - layers = [] - for i in range(num_layers): - layers.append(module[pipe_idx]) - pipe_idx += 1 - device = device_idx if devices is None else devices[device_idx] - balanced_pipe.append(nn.Sequential(*layers).to(device)) - device_idx += 1 - - return nn.Sequential(*balanced_pipe) diff --git a/torch/distributed/pipeline/sync/worker.py b/torch/distributed/pipeline/sync/worker.py deleted file mode 100644 index 87b20c4a5551..000000000000 --- a/torch/distributed/pipeline/sync/worker.py +++ /dev/null @@ -1,132 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Multithreading in pipeline parallelism.""" -from contextlib import contextmanager -from queue import Queue -import sys -from threading import Thread -from types import TracebackType -from typing import TYPE_CHECKING, Callable, Dict, Generator, List, Optional, Tuple, Type, Union, cast - -import torch - -from .microbatch import Batch -from .stream import AbstractStream, use_device, use_stream - -__all__: List[str] = ["Task", "worker", "create_workers", "spawn_workers"] - - -ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType] - -# Queue is generic only in stubs. -# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime -if TYPE_CHECKING: - InQueue = Queue[Optional["Task"]] - OutQueue = Queue[Tuple[bool, Union[Tuple["Task", Batch], ExcInfo, None]]] -else: - InQueue = Queue - OutQueue = Queue - - -class Task: - """A task represents how to compute a micro-batch on a partition. - - It consists of two parts: :meth:`compute` and :meth:`finalize`. - :meth:`compute` should be executed in worker threads concurrently. - :meth:`finalize` should be executed after when worker threads complete to - execute :meth:`compute`. - - :meth:`compute` might be boosted by worker threads. Because it produces - several CUDA API calls by user code. In PyTorch, parallel CUDA API calls - are not serialized through GIL. So more than one CUDA API call can be - produced at the same time. - - """ - - def __init__( - self, stream: AbstractStream, *, compute: Callable[[], Batch], finalize: Optional[Callable[[Batch], None]], - ) -> None: - self.stream = stream - self._compute = compute - self._finalize = finalize - self._grad_enabled = torch.is_grad_enabled() - - def compute(self) -> Batch: - with use_stream(self.stream), torch.set_grad_enabled(self._grad_enabled): - return self._compute() - - def finalize(self, batch: Batch) -> None: - if self._finalize is None: - return - with use_stream(self.stream), torch.set_grad_enabled(self._grad_enabled): - self._finalize(batch) - - -def worker(in_queue: InQueue, out_queue: OutQueue, device: torch.device) -> None: - """Main loop of a worker thread.""" - with use_device(device): - while True: - task = in_queue.get() - - if task is None: - break - - try: - batch = task.compute() - except Exception: - exc_info = cast(ExcInfo, sys.exc_info()) - out_queue.put((False, exc_info)) - continue - - out_queue.put((True, (task, batch))) - - done = (False, None) - out_queue.put(done) - - -def create_workers(devices: List[torch.device],) -> Tuple[List[InQueue], List[OutQueue]]: - """Spawns worker threads. A worker thread is bound to a device.""" - in_queues: List[InQueue] = [] - out_queues: List[OutQueue] = [] - - # Spawn workers. - workers: Dict[torch.device, Tuple[InQueue, OutQueue]] = {} - - def normalize_device(device: torch.device) -> torch.device: - if device.type == "cuda" and device.index is None: - return torch.device("cuda", index=torch.cuda.current_device()) - - if device.type == "cpu" and device.index is not None: - return torch.device("cpu") - - return device - - for device in devices: - device = normalize_device(device) - - try: - in_queue, out_queue = workers[device] - except KeyError: - in_queue = Queue() - out_queue = Queue() - workers[device] = (in_queue, out_queue) - - t = Thread(target=worker, args=(in_queue, out_queue, device), daemon=True,) - t.start() - - in_queues.append(in_queue) - out_queues.append(out_queue) - - return (in_queues, out_queues) - -@contextmanager -def spawn_workers(devices: List[torch.device],) -> Generator[Tuple[List[InQueue], List[OutQueue]], None, None]: - try: - (in_queues, out_queues) = create_workers(devices) - yield (in_queues, out_queues) - finally: - pass diff --git a/torch/distributed/pipelining/PipelineSchedule.py b/torch/distributed/pipelining/PipelineSchedule.py index fabc9377277a..28b7514ab16f 100644 --- a/torch/distributed/pipelining/PipelineSchedule.py +++ b/torch/distributed/pipelining/PipelineSchedule.py @@ -303,6 +303,17 @@ def __init__( self._stage.has_backward = self._has_backward def step(self, *args, target=None, losses: Optional[List] = None, **kwargs): + """ + Run one iteration of the pipeline schedule with *whole-batch* input. + Will chunk the input into microbatches automatically, and go through the + microbatches according to the schedule implementation. + + args: positional arguments to the model (as in non-pipeline case). + kwargs: keyword arguments to the model (as in non-pipeline case). + target: target for the loss function. + losses: a list to store the losses for each microbatch. + """ + # Clean per iteration self._stage.clear_runtime_states() @@ -583,6 +594,17 @@ def __init__( ) def step(self, *args, target=None, losses: Optional[List] = None, **kwargs): + """ + Run one iteration of the pipeline schedule with *whole-batch* input. + Will chunk the input into microbatches automatically, and go through the + microbatches according to the schedule implementation. + + args: positional arguments to the model (as in non-pipeline case). + kwargs: keyword arguments to the model (as in non-pipeline case). + target: target for the loss function. + losses: a list to store the losses for each microbatch. + """ + # Clean per iteration for stage in self._stages: stage.clear_runtime_states() diff --git a/torch/testing/_internal/distributed/pipe_with_ddp_test.py b/torch/testing/_internal/distributed/pipe_with_ddp_test.py deleted file mode 100644 index 1ed9f3cc96df..000000000000 --- a/torch/testing/_internal/distributed/pipe_with_ddp_test.py +++ /dev/null @@ -1,149 +0,0 @@ -# mypy: ignore-errors - -import torch -import torch.distributed as dist - -from torch import nn -from torch.nn.parallel import DistributedDataParallel -from torch.testing._internal.dist_utils import INIT_METHOD_TEMPLATE, dist_init -from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import ( - RpcAgentTestFixture, -) -from torch.testing._internal.common_distributed import ( - requires_gloo, - requires_nccl, - skip_if_lt_x_gpu, - skip_if_rocm, -) -from torch.distributed.pipeline.sync import Pipe - -class PipeWithDDPTest(RpcAgentTestFixture): - @property - def world_size(self) -> int: - return 2 - - @skip_if_lt_x_gpu(4) - @requires_nccl() - @dist_init - @skip_if_rocm - def test_basic_nccl_ckpt_never(self): - self._run_basic_test("nccl", "never") - - @skip_if_lt_x_gpu(4) - @requires_nccl() - @dist_init - @skip_if_rocm - def test_basic_nccl_ckpt_never_find_unused(self): - self._run_basic_test("nccl", "never", find_unused_parameters=True) - - @skip_if_lt_x_gpu(4) - @requires_nccl() - @dist_init - @skip_if_rocm - def test_basic_nccl_ckpt_always(self): - self._run_basic_test("nccl", "always", static_graph=True) - - @skip_if_lt_x_gpu(4) - @requires_nccl() - @dist_init - @skip_if_rocm - def test_basic_nccl_ckpt_except_last(self): - self._run_basic_test("nccl", "except_last", static_graph=True) - - @skip_if_lt_x_gpu(4) - @requires_gloo() - @dist_init - @skip_if_rocm - def test_basic_gloo_ckpt_never(self): - self._run_basic_test("gloo", "never") - - @skip_if_lt_x_gpu(4) - @requires_gloo() - @dist_init - @skip_if_rocm - def test_basic_gloo_ckpt_never_find_unused(self): - self._run_basic_test("gloo", "never", find_unused_parameters=True) - - @skip_if_lt_x_gpu(4) - @requires_gloo() - @dist_init - @skip_if_rocm - def test_basic_gloo_ckpt_always(self): - self._run_basic_test("gloo", "always", static_graph=True) - - @skip_if_lt_x_gpu(4) - @requires_gloo() - @dist_init - @skip_if_rocm - def test_basic_gloo_ckpt_except_last(self): - self._run_basic_test("gloo", "except_last", static_graph=True) - - def _run_basic_test(self, backend, checkpoint, find_unused_parameters=False, static_graph=False): - dist.init_process_group( - backend=backend, - init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name), - world_size=self.world_size, - rank=self.rank, - ) - - # Use 4 GPUs, two replicas of a pipe across GPU 0 and 1 and another - # pipe between GPU 2 and 3. Both replicas are replicated via DDP. - fc1 = nn.Linear(16, 8, bias=False).cuda(2 * self.rank) - - class MyModule(nn.Module): - def __init__(self, device): - super().__init__() - self.fc2 = nn.Linear(8, 4, bias=False).cuda(device) - self.fc3 = nn.Linear(4, 2, bias=False).cuda(device) - - def forward(self, inp): - if find_unused_parameters: - return self.fc2(inp) - else: - return self.fc3(self.fc2(inp)) - - layer2 = MyModule(2 * self.rank + 1) - model = nn.Sequential( - fc1, - layer2 - ) - model = Pipe(model, chunks=2, checkpoint=checkpoint) - model = DistributedDataParallel( - model, - find_unused_parameters=find_unused_parameters, - static_graph=static_graph, - ) - - # Ensure inputs are different across ranks to verify that gradient - # sync indeed occurs. - model_input = torch.rand(16, 16).cuda(2 * self.rank) * (self.rank + 1) - out = model(model_input).local_value() - out.sum().backward() - - # Run forward again for find_unused_parameters to trigger any potential errors. - if find_unused_parameters: - # Ensure inputs are different across ranks to verify that gradient - # sync indeed occurs. - unused_param_input = torch.rand(16, 16).cuda(2 * self.rank) * (self.rank + 1) - model(unused_param_input).local_value().sum().backward() - - # Run a few more iterations of fwd + bwd to ensure gradient synchronization - # occurs properly across iterations via delay_all_reduce/bucketized allreduce. - for _ in range(3): - model_input = torch.rand(16, 16).cuda(2 * self.rank) * (self.rank + 1) - out = model(model_input).local_value() - out.sum().backward() - - # Check grads - output = [torch.empty_like(fc1.weight.grad), torch.empty_like(fc1.weight.grad)] - dist.all_gather(output, fc1.weight.grad) - self.assertEqual(output[0], output[1]) - - output = [torch.empty_like(layer2.fc2.weight.grad), torch.empty_like(layer2.fc2.weight.grad)] - dist.all_gather(output, layer2.fc2.weight.grad) - self.assertEqual(output[0], output[1]) - - if not find_unused_parameters: - output = [torch.empty_like(layer2.fc3.weight.grad), torch.empty_like(layer2.fc3.weight.grad)] - dist.all_gather(output, layer2.fc3.weight.grad) - self.assertEqual(output[0], output[1]) diff --git a/torch/testing/_internal/distributed/pipeline/__init__.py b/torch/testing/_internal/distributed/pipeline/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/torch/testing/_internal/distributed/rpc_utils.py b/torch/testing/_internal/distributed/rpc_utils.py index cdbbdcfd0681..5b6e2c90770f 100644 --- a/torch/testing/_internal/distributed/rpc_utils.py +++ b/torch/testing/_internal/distributed/rpc_utils.py @@ -16,9 +16,6 @@ DdpComparisonTest, DdpUnderDistAutogradTest, ) -from torch.testing._internal.distributed.pipe_with_ddp_test import ( - PipeWithDDPTest, -) from torch.testing._internal.distributed.nn.api.remote_module_test import ( CudaRemoteModuleTest, RemoteModuleTest, @@ -121,7 +118,6 @@ def tearDown(self): CudaDistAutogradTest, CudaRemoteModuleTest, CudaDdpComparisonTest, - PipeWithDDPTest, ] From 0c16800b4a26a0748ba62bde1f2837d484fa52fe Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Fri, 7 Jun 2024 08:40:51 +0000 Subject: [PATCH 463/706] [pipelining] include lifted constants in input_to_state (#128173) Previous PR only looked at state dict to determine inputs to state, missing out on lifted tensors Pull Request resolved: https://github.com/pytorch/pytorch/pull/128173 Approved by: https://github.com/kwen2501 --- torch/distributed/pipelining/_IR.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torch/distributed/pipelining/_IR.py b/torch/distributed/pipelining/_IR.py index 0a45c4459f30..ed2ef32d4255 100644 --- a/torch/distributed/pipelining/_IR.py +++ b/torch/distributed/pipelining/_IR.py @@ -868,7 +868,11 @@ def move_param_to_callee( inputs_to_state: Dict[str, List[str]] = {} for attr in attr_nodes: _, tensor = _recursive_getattr_with_parent(mod, attr.target) - inputs_to_state[attr.name] = list(id_to_fqns[id(tensor)]) + fqns = list(id_to_fqns[id(tensor)]) + if fqns: + inputs_to_state[attr.name] = fqns + elif attr.target in exported_program.constants: # lifted constants + inputs_to_state[attr.name] = [attr.target] # [aliasing] for each submodule split, assign attributes on FQNs that may be used. # We determine this based on whether or not the FQN attribute parent exists. From 7efaeb1494c56a08254bd2238713cf5339af95f3 Mon Sep 17 00:00:00 2001 From: chunyuan Date: Fri, 7 Jun 2024 05:29:44 +0000 Subject: [PATCH 464/706] [AOTI] docs: add suggestion to turn on freezing on CPU (#128010) With https://github.com/pytorch/pytorch/pull/124350 landed, it is now suggested in AOTI to turn on freezing on CPU to get better performance. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128010 Approved by: https://github.com/desertfire --- docs/source/torch.compiler_aot_inductor.rst | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/source/torch.compiler_aot_inductor.rst b/docs/source/torch.compiler_aot_inductor.rst index 0ebd03bbcecf..257f16f40cc0 100644 --- a/docs/source/torch.compiler_aot_inductor.rst +++ b/docs/source/torch.compiler_aot_inductor.rst @@ -37,7 +37,9 @@ For more details on ``torch.export``, you can refer to the :ref:`torch.export do If you have a CUDA-enabled device on your machine and you installed PyTorch with CUDA support, the following code will compile the model into a shared library for CUDA execution. - Otherwise, the compiled artifact will run on CPU. + Otherwise, the compiled artifact will run on CPU. For better performance during CPU inference, + it is suggested to enable freezing by setting `export TORCHINDUCTOR_FREEZING=1` + before running the Python script below. .. code-block:: python From 5f81265572151ed2b486afbbbc27d88d345b80f3 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Thu, 6 Jun 2024 13:58:57 -0700 Subject: [PATCH 465/706] [Traceable FSDP2] Return early from _register_post_backward_hook when compile (#127864) Dynamo doesn't support `RegisterPostBackwardFunction` very well yet. This PR skips it and rely on `root_post_backward_callback` under compile. We will improve `RegisterPostBackwardFunction` support in Q3. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127864 Approved by: https://github.com/awgu --- torch/distributed/_composable/fsdp/_fsdp_param_group.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torch/distributed/_composable/fsdp/_fsdp_param_group.py b/torch/distributed/_composable/fsdp/_fsdp_param_group.py index ea2307222ce1..bb66977848a3 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_param_group.py +++ b/torch/distributed/_composable/fsdp/_fsdp_param_group.py @@ -3,6 +3,7 @@ from typing import Any, cast, Dict, List, NamedTuple, Optional, Set, Tuple import torch +import torch._dynamo.compiled_autograd as ca import torch.distributed as dist import torch.nn as nn from torch.distributed.fsdp._common_utils import _named_parameters_with_duplicates @@ -402,6 +403,9 @@ def use_training_state(self, training_state: TrainingState): def _register_post_backward_hook( self, args: Tuple[Any, ...], kwargs: Dict[str, Any] ) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: + # Compile relies on `root_post_backward_callback` to call each `FSDPParamGroup.post_backward` + if ca.compiled_autograd_enabled: + return args, kwargs if not torch.is_grad_enabled(): return args, kwargs args_list, args_spec = tree_flatten(args) From 543a870943120484db547382ed9ca9538a40f284 Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Thu, 6 Jun 2024 23:19:30 -0700 Subject: [PATCH 466/706] [pipelining] Rename ManualPipelineStage -> PipelineStage (#128157) Renaming ManualPipelineStage to remove the "Manual" part. I needed to replace the existing `PipelineStage` which takes in the `pipe` argument, so I have renamed that to `TracerPipelineStage`. @kwen2501 will remove this entirely in favor of adding a util to `Pipe` to just create the stage directly. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128157 Approved by: https://github.com/wconstab --- docs/source/distributed.pipelining.rst | 5 ++-- .../pipelining/test_composability.py | 4 ++-- test/distributed/pipelining/test_schedule.py | 12 +++++----- test/distributed/pipelining/test_stage.py | 8 +++---- torch/distributed/pipelining/PipelineStage.py | 24 +++++++++---------- torch/distributed/pipelining/README.md | 2 +- torch/distributed/pipelining/__init__.py | 4 ++-- 7 files changed, 29 insertions(+), 30 deletions(-) diff --git a/docs/source/distributed.pipelining.rst b/docs/source/distributed.pipelining.rst index a8203a5f3b2c..c1b82b8b1bb9 100644 --- a/docs/source/distributed.pipelining.rst +++ b/docs/source/distributed.pipelining.rst @@ -180,8 +180,7 @@ You can also create a distributed stage runtime on a device using ``Pipe``: .. code-block:: python from torch.distributed.pipelining import PipelineStage - - stage = PipelineStage(pipe, stage_idx, device) + stage = TracerPipelineStage(pipe, stage_idx, device) .. note:: The ``pipeline`` frontend uses a tracer (``torch.export``) to capture your @@ -348,7 +347,7 @@ Pipeline Stages .. autoclass:: PipelineStage -.. autoclass:: ManualPipelineStage +.. autoclass:: TracerPipelineStage Pipeline Schedules ================== diff --git a/test/distributed/pipelining/test_composability.py b/test/distributed/pipelining/test_composability.py index bbf3f1929fbc..3503001ba49e 100644 --- a/test/distributed/pipelining/test_composability.py +++ b/test/distributed/pipelining/test_composability.py @@ -16,7 +16,7 @@ ) from torch.distributed._tensor import DTensor from torch.distributed.device_mesh import init_device_mesh -from torch.distributed.pipelining import ManualPipelineStage +from torch.distributed.pipelining import PipelineStage from torch.distributed.pipelining.PipelineSchedule import ( PipelineScheduleSingle, Schedule1F1B, @@ -127,7 +127,7 @@ def apply_dp(partial_model, dp_type): def build_stage(stage_idx, num_stages): partial_model, offset = get_stage_module(stage_idx, num_stages) dp_model = apply_dp(partial_model, dp_type) - stage = ManualPipelineStage( + stage = PipelineStage( dp_model, stage_idx, num_stages, diff --git a/test/distributed/pipelining/test_schedule.py b/test/distributed/pipelining/test_schedule.py index 462ba83da07e..81b4e1c7ae07 100644 --- a/test/distributed/pipelining/test_schedule.py +++ b/test/distributed/pipelining/test_schedule.py @@ -13,13 +13,13 @@ import torch import torch.distributed as dist from torch.distributed.pipelining import ( - ManualPipelineStage, pipeline, PipelineStage, Schedule1F1B, ScheduleGPipe, ScheduleInterleaved1F1B, ScheduleLoopedBFS, + TracerPipelineStage, ) from torch.distributed.pipelining.PipelineSchedule import _Action, _ComputationType from torch.distributed.pipelining.PipelineStage import _PipelineStageBase @@ -91,7 +91,7 @@ def test_multi_iter(self, ScheduleClass): split_spec=split_spec, ) - stage = PipelineStage( + stage = TracerPipelineStage( pipe, self.rank, device=self.device, @@ -130,7 +130,7 @@ def test_kwargs_with_tracer(self, ScheduleClass): example_kwargs={"y": y}, ) - stage = PipelineStage( + stage = TracerPipelineStage( pipe, self.rank, device=self.device, @@ -192,7 +192,7 @@ def test_grad_with_tracer(self, ScheduleClass, ModelClass): split_spec=split_spec, ) - stage = PipelineStage( + stage = TracerPipelineStage( pipe, self.rank, device=self.device, @@ -263,7 +263,7 @@ def test_grad_with_manual(self, ScheduleClass): stage_module = full_mod.get_submodule(submod_name) chunks = 4 # Create a pipeline stage to wrap that submodule - stage = ManualPipelineStage( + stage = PipelineStage( stage_module, self.rank, self.world_size, @@ -347,7 +347,7 @@ def test_grad_with_manual_interleaved(self, ScheduleClass): chunks = 8 input_args = x.chunk(chunks)[0] stages = [ - ManualPipelineStage( + PipelineStage( stage_module, stage_idx, n_stages, diff --git a/test/distributed/pipelining/test_stage.py b/test/distributed/pipelining/test_stage.py index 97a147cb357a..45f4b0b01a9c 100644 --- a/test/distributed/pipelining/test_stage.py +++ b/test/distributed/pipelining/test_stage.py @@ -9,10 +9,10 @@ import torch import torch.distributed as dist from torch.distributed.pipelining import ( - ManualPipelineStage, pipeline, PipelineStage, ScheduleGPipe, + TracerPipelineStage, ) from torch.distributed.pipelining._utils import PipeliningShapeError from torch.testing._internal.common_cuda import TEST_MULTIGPU @@ -91,7 +91,7 @@ def test_tracer(self, ModelClass): split_spec=split_spec, ) - stage = PipelineStage( + stage = TracerPipelineStage( pipe, self.rank, device=self.device, @@ -157,7 +157,7 @@ def test_tracer_kwargs(self, ModelClass): example_kwargs={"y": y}, ) - stage = PipelineStage( + stage = TracerPipelineStage( pipe, self.rank, device=self.device, @@ -211,7 +211,7 @@ def test_manual(self): x = torch.randn(batch_size, d_hid, device=self.device) - stage = ManualPipelineStage( + stage = PipelineStage( stage_mod, self.rank, self.world_size, diff --git a/torch/distributed/pipelining/PipelineStage.py b/torch/distributed/pipelining/PipelineStage.py index b301c2e6e1ec..58ffdb9717e3 100644 --- a/torch/distributed/pipelining/PipelineStage.py +++ b/torch/distributed/pipelining/PipelineStage.py @@ -2,7 +2,7 @@ import logging import operator from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import torch import torch.distributed as dist @@ -21,7 +21,7 @@ __all__ = [ "PipelineStage", - "ManualPipelineStage", + "TracerPipelineStage", ] logger = logging.getLogger(__name__) @@ -80,7 +80,7 @@ def _make_tensor_from_meta( class _PipelineStageBase(ABC): """ Base class for pipeline stages. - Implements common methods used by both the `PipelineStage` used by the tracing frontend and `ManualPipelineStage`. + Implements common methods used by both the `TracerPipelineStage` used by the tracing frontend and `PipelineStage`. """ def __init__( @@ -894,7 +894,8 @@ def _create_grad_recv_info( return grad_recv_info_tuple -class PipelineStage(_PipelineStage): +# TODO: Update this to be returned by helper method under Pipe (kwen) +class TracerPipelineStage(_PipelineStage): def __init__( self, pipe: Pipe, @@ -919,7 +920,7 @@ def __init__( def _create_empty_tensors( - tensor: Union[torch.Tensor, List[torch.Tensor]], device: torch.device + tensor: Union[torch.Tensor, Iterable[torch.Tensor]], device: torch.device ) -> List[torch.Tensor]: """ Creates a list of empty tensors with the same properties (like shape and dtype) as the input tensor(s), @@ -1069,7 +1070,7 @@ def _get_stage_shapes( return stage_id_to_shapes -class ManualPipelineStage(_PipelineStageBase): +class PipelineStage(_PipelineStageBase): """ A class representing a pipeline stage in a pipeline parallelism setup. This class is created manually by providing a example input (and optionally output) @@ -1083,8 +1084,8 @@ class ManualPipelineStage(_PipelineStageBase): num_stages (int): The total number of stages. device (torch.device): The device where this stage is located. num_microbatches (int): The number of microbatches to use. - input_args (Union[torch.Tensor, List[torch.tensor]], optional): The input arguments for the submodule. - output_args (Union[torch.Tensor, List[torch.tensor]], optional): The output arguments for the submodule. + input_args (Union[torch.Tensor, Tuple[torch.tensor]], optional): The input arguments for the submodule. + output_args (Union[torch.Tensor, Tuple[torch.tensor]], optional): The output arguments for the submodule. group (dist.ProcessGroup, optional): The process group for distributed training. If None, default group. """ @@ -1095,8 +1096,8 @@ def __init__( num_stages: int, device: torch.device, num_microbatches: int, - input_args: Union[torch.Tensor, List[torch.Tensor]], - output_args: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, + input_args: Union[torch.Tensor, Tuple[torch.Tensor, ...]], + output_args: Optional[Union[torch.Tensor, Tuple[torch.Tensor, ...]]] = None, group: Optional[dist.ProcessGroup] = None, ): super().__init__( @@ -1104,7 +1105,6 @@ def __init__( ) self.submod.to(self.device) # When we materialize the model partition on cuda, we call reset_parameters() if it is available - # logger.info(f"input args {input_args=}") self.inputs: List[torch.Tensor] = [] self.outputs: List[torch.Tensor] = [] @@ -1219,7 +1219,7 @@ def _init_p2p_neighbors(self): return True -def _validate_stage_shapes(pipeline_stages: List[ManualPipelineStage]): +def _validate_stage_shapes(pipeline_stages: List[PipelineStage]): """ Check that the buffer shapes match between stages was expected by performing an all_gather between all stages. diff --git a/torch/distributed/pipelining/README.md b/torch/distributed/pipelining/README.md index 46a05a22c8ce..556814c29b37 100644 --- a/torch/distributed/pipelining/README.md +++ b/torch/distributed/pipelining/README.md @@ -151,7 +151,7 @@ dist.init_process_group(rank=rank, world_size=world_size) # Pipeline stage is our main pipeline runtime. It takes in the pipe object, # the rank of this process, and the device. from torch.distributed.pipelining import PipelineStage -stage = PipelineStage(pipe, rank, device) +stage = TracerPipelineStage(pipe, rank, device) ``` We can now run the pipeline by attaching the `PipelineStage` to a pipeline schedule, GPipe for example: diff --git a/torch/distributed/pipelining/__init__.py b/torch/distributed/pipelining/__init__.py index c192c314e802..eca6e451bdc3 100644 --- a/torch/distributed/pipelining/__init__.py +++ b/torch/distributed/pipelining/__init__.py @@ -14,7 +14,7 @@ ScheduleInterleaved1F1B, ScheduleLoopedBFS, ) -from .PipelineStage import ManualPipelineStage, PipelineStage +from .PipelineStage import PipelineStage, TracerPipelineStage __all__ = [ "Pipe", @@ -24,7 +24,7 @@ "pipeline", "ArgsChunkSpec", "KwargsChunkSpec", - "ManualPipelineStage", + "TracerPipelineStage", "PipelineStage", "Schedule1F1B", "ScheduleGPipe", From 3f9798a4fd267e32e6ca96adbd127168f8bb8992 Mon Sep 17 00:00:00 2001 From: zabboud Date: Fri, 7 Jun 2024 15:17:22 +0000 Subject: [PATCH 467/706] add docstring to masked_fill, expand, select, unsqueeze, cat fns (#128055) Fixes #127891 Fixes #127893 Fixes #127894 Fixes #127907 Fixes #127910 ## Description Add docstring to `masked_fill`, `expand`, `select`, `unsqueeze`, and `cat` functions in torch.onnx.symbolic_opset9.py remaining pydocstyle errors: 257 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128055 Approved by: https://github.com/xadupre --- torch/onnx/symbolic_opset9.py | 57 +++++++++++++++++++++++++++-------- 1 file changed, 44 insertions(+), 13 deletions(-) diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py index 95e8fcef391f..de9b616103e7 100644 --- a/torch/onnx/symbolic_opset9.py +++ b/torch/onnx/symbolic_opset9.py @@ -3,6 +3,7 @@ Opset 9 is supported by ONNX release 1.4.1 release on 01/23/19 """ + from __future__ import annotations import builtins @@ -521,6 +522,16 @@ def reciprocal(g: jit_utils.GraphContext, self): @symbolic_helper.parse_args("v", "i") @_beartype.beartype def cat(g: jit_utils.GraphContext, tensor_list, dim): + """Implement concatenation of pytorch tensors in ONNX along the specified `dim` dimension. + + Parameters: + g (jit_utils.GraphContext): Graph context. + tensor_list (List[torch.Tensor]): List of tensors to concatenate. + dim (int): Dimension along which to concatenate the tensors. + + Returns: + ONNX graph node representing the concatenated tensor. + """ tensors = symbolic_helper._unpack_list(tensor_list) # torch.cat ignores empty tensors such as `torch.Tensor([])` # These needs to be removed as input from ONNX's concat too, otherwise shape inference @@ -849,6 +860,7 @@ def numpy_T(g: jit_utils.GraphContext, input): @symbolic_helper.quantized_args(True) @_beartype.beartype def expand(g: jit_utils.GraphContext, self, size, implicit): + """Implement the expand function for a pytorch tensor in ONNX according to specified `size`""" size = symbolic_helper._maybe_get_const(size, "is") if not symbolic_helper._is_value(size): size = g.op("Constant", value_t=torch.LongTensor(size)) @@ -1132,6 +1144,10 @@ def unbind(g: jit_utils.GraphContext, self, dim=0, _outputs=None): @symbolic_helper.parse_args("v", "i", "v") @_beartype.beartype def select(g: jit_utils.GraphContext, self, dim, index): + """Implement the select functionality for a pytorch tensor in ONNX. + + Selects elements from the input tensor along the specified `dim` dimension based on the `index` tensor. + """ index = symbolic_helper._maybe_get_scalar(index) if (not symbolic_helper._is_value(index)) and (index < 0): if index == -1: @@ -1417,29 +1433,39 @@ def get_pool_ceil_padding(input, kernel_size, stride, padding): ] # ensure last pooling starts inside ceiled_output_dim = [ - ceiled_output_dim[i] - 1 - if (((ceiled_output_dim[i] - 1) * stride[i]) >= (dim[i] + padding[i])) - else ceiled_output_dim[i] + ( + ceiled_output_dim[i] - 1 + if (((ceiled_output_dim[i] - 1) * stride[i]) >= (dim[i] + padding[i])) + else ceiled_output_dim[i] + ) for i in range(0, len(ceiled_output_dim)) ] padding_ceil = [ - 0 - if (stride[i] == 1) - else ( - kernel_size[i] - - (dim[i] + 2 * padding[i] - ((ceiled_output_dim[i] - 1) * stride[i] + 1)) + ( + 0 + if (stride[i] == 1) + else ( + kernel_size[i] + - ( + dim[i] + + 2 * padding[i] + - ((ceiled_output_dim[i] - 1) * stride[i] + 1) + ) + ) ) for i in range(0, len(padding)) ] # ensure padding is not > kernel_size padding_ceil = [ ( - int(padding_ceil[i]) - if padding_ceil[i] < kernel_size[i] - 1 - else int(kernel_size[i] - 1) + ( + int(padding_ceil[i]) + if padding_ceil[i] < kernel_size[i] - 1 + else int(kernel_size[i] - 1) + ) + if ((padding_ceil[i] + 2 * padding[i]) >= (kernel_size[i])) + else int(padding_ceil[i]) ) - if ((padding_ceil[i] + 2 * padding[i]) >= (kernel_size[i])) - else int(padding_ceil[i]) for i in range(0, len(padding_ceil)) ] return padding_ceil @@ -4081,6 +4107,7 @@ def alias(g: jit_utils.GraphContext, self): @symbolic_helper.parse_args("v", "i") @_beartype.beartype def unsqueeze(g: jit_utils.GraphContext, self, dim): + """Implement unsqueezing a pytorch tensor in ONNX by inserting a new dimension at the specified `dim`""" # Handle negative dim if dim < 0: rank = symbolic_helper._get_tensor_rank(self) @@ -5580,6 +5607,10 @@ def lift(g: jit_utils.GraphContext, self): @_onnx_symbolic("aten::masked_fill") @_beartype.beartype def masked_fill(g: jit_utils.GraphContext, self, mask, value): + """Implement the masked_fill functionality available for a pytorch tensor in ONNX. + + Fills elements of the input tensor with `value` where `mask` is True. + """ mask = g.op("Cast", mask, to_i=_C_onnx.TensorProtoDataType.BOOL) value = symbolic_helper._maybe_get_scalar(value) return g.op("Where", mask, symbolic_helper._if_scalar_type_as(value, self), self) From 771be55bb088c80766c690d48526e11f8b98a81e Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Fri, 7 Jun 2024 15:20:18 +0000 Subject: [PATCH 468/706] Documenting `torch.onnx.operator.shape_as_tensor` (#128051) Fixes #127890 This PR adds docstring to the `torch.onnx.operator.shape_as_tensor` function. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128051 Approved by: https://github.com/xadupre --- torch/onnx/operators.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/torch/onnx/operators.py b/torch/onnx/operators.py index e5f12444c355..489010519980 100644 --- a/torch/onnx/operators.py +++ b/torch/onnx/operators.py @@ -13,6 +13,20 @@ def shape_as_tensor(x): + """Get the shape of a tensor as a tensor. + + Args: + x (Tensor): The input tensor. + + Returns: + Tensor: A tensor of shape [len(x.shape)] containing the size of each dimension of x. + + Example: + >>> x = torch.randn(2, 3) + >>> shape_as_tensor(x) + tensor([2, 3]) + + """ return torch._shape_as_tensor(x) From 6e75024ff0673f2ebf55da3e7739f0265f7fa318 Mon Sep 17 00:00:00 2001 From: James Wu Date: Thu, 6 Jun 2024 09:30:24 -0700 Subject: [PATCH 469/706] Run TestAOTAutograd with dynamo (#128047) My goal is to run these tests with the autograd cache on, but first I want them running with dynamo. These tests already caught an interesting issue so I thought it would be helpful to just have them. Next up I'll have a second subclass of these tests, run them twice, and expect a cache hit the second time from autograd. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128047 Approved by: https://github.com/ezyang --- test/functorch/test_aotdispatch.py | 159 +++++++++++++++++++++++------ 1 file changed, 129 insertions(+), 30 deletions(-) diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index a3ebb9eb08a5..bbb3ad5908f2 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -11,7 +11,7 @@ import unittest import warnings from contextlib import nullcontext -from functools import partial +from functools import partial, wraps from typing import Any, Callable, Dict, List, Optional, Union from unittest.mock import patch @@ -26,6 +26,7 @@ from functorch.compile import ( aot_function, aot_module, + aot_module_simplified, compiled_function, compiled_module, default_decompositions, @@ -39,11 +40,7 @@ ) from functorch.experimental import control_flow from torch._decomp import decomposition_table -from torch._functorch.aot_autograd import ( - aot_export_joint_simple, - aot_export_module, - aot_module_simplified, -) +from torch._functorch.aot_autograd import aot_export_joint_simple, aot_export_module from torch._higher_order_ops.out_dtype import out_dtype from torch._subclasses.fake_tensor import DynamicOutputShapeException, FakeTensorMode from torch.fx.experimental.proxy_tensor import is_sym_node @@ -288,7 +285,66 @@ def is_in_base(t, maybe_tensors): return False +def skipIfDynamoInput(reason, xfail=False): + """ + Skip TestAOTAutograd if running with dynamo input + """ + + def decorator(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + fn = func + if isinstance(self, TestAOTAutogradWithDynamo): + if xfail: + fn = unittest.expectedFailure(fn) + else: + self.skipTest( + f"Skipping {self._testMethodName} in TestAOTAutogradWithDynamo because {reason}" + ) + else: + fn(self, *args, **kwargs) + + return wrapper + + return decorator + + class TestAOTAutograd(AOTTestCase): + def run_autograd( + self, + f: Callable, + fw_graph_cell: List[Optional[Callable]], + decompositions: Optional[Dict], + keep_input_mutations: bool, + dynamic: bool, + ): + """ + Runs aot_autograd with the specified settings on f. + """ + if isinstance(f, nn.Module): + compiled_f = aot_module( + f, + fw_compiler=make_boxed_compiler( + partial(extract_graph, graph_cell=fw_graph_cell) + ), + bw_compiler=nop, + decompositions=decompositions, + keep_inference_input_mutations=keep_input_mutations, + dynamic=dynamic, + ) + else: + compiled_f = aot_function( + f, + fw_compiler=make_boxed_compiler( + partial(extract_graph, graph_cell=fw_graph_cell) + ), + bw_compiler=nop, + decompositions=decompositions, + keep_inference_input_mutations=keep_input_mutations, + dynamic=dynamic, + ) + return compiled_f + # test_mutation will: # - Ensure that inputs are non-leaves, so our graphs can mutate them # - try to mutate outputs of the graph (to ensure that autograd meta is set properly on outputs) @@ -349,28 +405,9 @@ def verify_aot_autograd( graph_inps = inp graph_inps_copy = inp_copy fw_graph_cell = [None] - if isinstance(f, nn.Module): - compiled_f = aot_module( - f, - fw_compiler=make_boxed_compiler( - partial(extract_graph, graph_cell=fw_graph_cell) - ), - bw_compiler=nop, - decompositions=decompositions, - keep_inference_input_mutations=keep_input_mutations, - dynamic=dynamic, - ) - else: - compiled_f = aot_function( - f, - fw_compiler=make_boxed_compiler( - partial(extract_graph, graph_cell=fw_graph_cell) - ), - bw_compiler=nop, - decompositions=decompositions, - keep_inference_input_mutations=keep_input_mutations, - dynamic=dynamic, - ) + compiled_f = self.run_autograd( + f, fw_graph_cell, decompositions, keep_input_mutations, dynamic + ) ref_out, ref_grad = outs_and_grads(f, graph_inps, inp) test_out, test_grad = outs_and_grads(compiled_f, graph_inps_copy, inp_copy) self.assertEqual(ref_grad, test_grad) @@ -537,6 +574,9 @@ def f(a, b): ] self.verify_aot_autograd(f, inp, keep_inp_mutations=True) + @skipIfDynamoInput( + "Test doesn't make sense with dynamo, which changes order of mutations" + ) def test_set__and_data_mutation_good(self): def f(a, b): # The data mutation happens *after* the set_(). This is ok (see the graph below) @@ -580,6 +620,7 @@ def forward(self, primals_1, primals_2): # https://github.com/pytorch/pytorch/issues/126236 # https://github.com/pytorch/pytorch/pull/126113 @xfailIfTorchDynamo + @skipIfDynamoInput("Not supported by dynamo", xfail=True) def test_set__and_data_mutation_bad(self): def f(a): a_view = a.view(-1) @@ -601,6 +642,9 @@ def f(a): f, inp, test_mutation=True, keep_inp_mutations=True ) + @skipIfDynamoInput( + "Test doesn't make sense with dynamo, which changes order of mutations" + ) def test_set__not_allowed(self): def f(a, b): with torch.no_grad(): @@ -678,8 +722,6 @@ def f(a): out_ref = f(ref_view) out_test = f_compiled(test_view) - print(ref) - print(test) self.assertEqual(ref, test) def test_input_mutation_modifies_autograd_meta_of_aliases(self): @@ -1809,6 +1851,7 @@ def forward(self, primals_1): ) @parametrize("req_grad", [False, True]) + @skipIfDynamoInput("Runtime error not raised with dynamo", xfail=True) def test_subclass_metadata_mutation(self, req_grad): def f(a): a.transpose_(1, 0) @@ -1882,6 +1925,7 @@ def forward(self, primals_1, primals_2): return [t, view_1, view_2]""", ) + @skipIfDynamoInput("https://github.com/pytorch/pytorch/issues/128035", xfail=True) def test_view_detach(self): def f(a): tmp = a.detach() @@ -1919,6 +1963,7 @@ def forward(self, primals_1, primals_2): # One gets a data mutation, the other gets a metadata mutation. # We need to make sure that the metadata mutation gets propagated # back to the original input. + @skipIfDynamoInput("Dynamo removes runtime error") def test_input_data_and_metadata_mutation_aliases_other_input(self): # a and b are aliased def f(a, b): @@ -2524,6 +2569,7 @@ def forward(self, primals_1, primals_2): return [as_strided_scatter, add, add_1]""", ) # noqa: B950 + @skipIfDynamoInput("Fails with dynamo") def test_input_mutation_aliases_bases_out_of_order(self): # This tests our calling convention: if b and d are aliased, then the outer calling convention # that we send to the compiled forward becomes: @@ -2598,6 +2644,7 @@ def inp_callable(): self.verify_aot_autograd(f, inp_callable, test_mutation=True) + @skipIfDynamoInput("https://github.com/pytorch/pytorch/issues/128035", xfail=True) def test_input_mutation_alias_everything(self): # Mondo test that tests a combination of: # input is mutated, that aliases another input (so we make a synthetic base) @@ -5821,5 +5868,57 @@ def test_aot_autograd_symbolic_module_exhaustive( instantiate_device_type_tests(TestEagerFusionModuleInfo, globals(), only_for=only_for) +@skipIfTorchDynamo("This test suite already uses dynamo") +class TestAOTAutogradWithDynamo(TestAOTAutograd): + """ + These are the same as TestAOTAutograd tests, but we run dynamo first to get a graph module. + """ + + def assertExpectedInline(self, *args, **kwargs): + # These will have different outputs because dynamo returns a different graph module + # But we don't really care about that assertion when testing with dynamo, + # only that the outputs match, etc. + pass + + # Compiler to passes to dynamo + def run_autograd( + self, + f: Callable, + fw_graph_cell: List[Optional[Callable]], + decompositions: Optional[Dict], + keep_input_mutations: bool, + dynamic: bool, + ): + """ + Runs dynamo and aot_autograd with the specified settings + """ + + def dynamo_compiler(gm, inputs, **kwargs): + result = aot_module_simplified( + gm, + inputs, + fw_compiler=make_boxed_compiler( + partial(extract_graph, graph_cell=fw_graph_cell) + ), + bw_compiler=nop, + decompositions=decompositions, + keep_inference_input_mutations=keep_input_mutations, + # Dynamic is calculated from whether the inputs have fake tensors + ) + return result + + def torch_compile_wrapper(*args, **kwargs): + torch._dynamo.reset() + fn = torch.compile(f, backend=dynamo_compiler) + try: + result = fn(*args, **kwargs) + except torch._dynamo.exc.BackendCompilerFailed as e: + # So that assertRaises works properly + raise e.inner_exception from e + return result + + return torch_compile_wrapper + + if __name__ == "__main__": run_tests() From 224b4339e590a6390e3e23fb05f11efbd4b3238a Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 7 Jun 2024 15:43:37 +0000 Subject: [PATCH 470/706] Revert "Make ValueRange repr less chatty by default (#128043)" This reverts commit f0dd11df5534ae074ad2d090e6700576a22719d6. Reverted https://github.com/pytorch/pytorch/pull/128043 on behalf of https://github.com/atalman due to Sorry reverting because in conflict with [#126905](https://github.com/pytorch/pytorch/pull/126905) which needs to be reverted ([comment](https://github.com/pytorch/pytorch/pull/128043#issuecomment-2155091732)) --- test/dynamo/test_misc.py | 12 ++++++------ torch/utils/_sympy/value_ranges.py | 3 --- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index e173a4d7a69e..dc2b9530f0dd 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -9309,7 +9309,7 @@ def test_shape_env_equal_create_symbolic_sizes_strides_storage_offset(self): > Left: {0: 0, 1: 1, 2: s1, 3: s0} > Right: {0: 0, 1: 1} ==> var_to_range: values don't match. - > Left: {s0: VR[2, 9223372036854775806], s1: VR[2, 9223372036854775806]} + > Left: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)} > Right: {} ==> var_to_sources: values don't match. > Left: {s0: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=, idx=0)], s1: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=, idx=1)]} @@ -9343,7 +9343,7 @@ def test_shape_env_equal_unbacked(self): > Left: 2 > Right: 0 ==> var_to_range: values don't match. - > Left: {u0: VR[-9223372036854775808, 9223372036854775807], u1: VR[0, 1], zuf0: VR[-oo, oo]} + > Left: {u0: ValueRanges(lower=-9223372036854775808, upper=9223372036854775807, is_bool=False, is_int=True, is_float=False), u1: ValueRanges(lower=0, upper=1, is_bool=False, is_int=True, is_float=False), zuf0: ValueRanges(lower=-oo, upper=oo, is_bool=False, is_int=False, is_float=True)} > Right: {} """, ) @@ -9420,8 +9420,8 @@ def test_shape_env_equal_evaluate_expr_replacement(self): > Left: {s0: 3} > Right: {} ==> var_to_range: values don't match. - > Left: {s0: VR[3, 3], s1: VR[2, 9223372036854775806]} - > Right: {s0: VR[2, 9223372036854775806], s1: VR[2, 9223372036854775806]} + > Left: {s0: ValueRanges(lower=3, upper=3, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)} + > Right: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)} """, ) self._replay_and_check(main) @@ -9458,8 +9458,8 @@ def test_shape_env_equal_evaluate_expr_refinement(self): > Left: {_assert, ge, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} > Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} ==> var_to_range: values don't match. - > Left: {s0: VR[3, 9223372036854775806], s1: VR[2, 9223372036854775806]} - > Right: {s0: VR[2, 9223372036854775806], s1: VR[2, 9223372036854775806]} + > Left: {s0: ValueRanges(lower=3, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)} + > Right: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)} """, ) self._replay_and_check(main) diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py index c7257f999b52..4d364d4981b5 100644 --- a/torch/utils/_sympy/value_ranges.py +++ b/torch/utils/_sympy/value_ranges.py @@ -127,9 +127,6 @@ class ValueRanges(Generic[_T]): is_int: bool is_float: bool - def __repr__(self) -> str: - return f"VR[{self.lower}, {self.upper}]" - @overload def __init__(self: ValueRanges[sympy.Expr], lower: ExprIn, upper: ExprIn) -> None: ... From 3090667cf9c3119ba5a5dbc4c1f093b80892b10f Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Fri, 7 Jun 2024 08:42:25 -0700 Subject: [PATCH 471/706] [pipelining] pipeline() taking microbatch as example input (#128163) Changed the API of `pipeline()` to take microbatch instead of full batch as example args. Main purpose is to: - make this API more atomic; - decouple tracing frontend from runtime info like `num_chunks`. Side effects: - Creates opportunity for varying `num_chunks` of schedules with the same `pipe` object. - User has to create example microbatch input. - Chunk spec stuff are now all moved to runtime side. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128163 Approved by: https://github.com/H-Huang --- docs/source/distributed.pipelining.rst | 6 - test/allowlist_for_publicAPI.json | 3 - test/distributed/pipelining/test_chunkspec.py | 72 ---------- .../distributed/pipelining/test_microbatch.py | 36 ++++- test/distributed/pipelining/test_pipe.py | 9 +- test/distributed/pipelining/test_schedule.py | 28 ++-- test/distributed/pipelining/test_stage.py | 18 ++- .../pipelining/test_transformer.py | 11 +- test/distributed/pipelining/test_unflatten.py | 1 - .../pipelining/PipelineSchedule.py | 39 ++++-- torch/distributed/pipelining/PipelineStage.py | 8 +- torch/distributed/pipelining/_IR.py | 126 +++--------------- torch/distributed/pipelining/__init__.py | 13 +- torch/distributed/pipelining/microbatch.py | 49 ++++++- 14 files changed, 168 insertions(+), 251 deletions(-) delete mode 100644 test/distributed/pipelining/test_chunkspec.py diff --git a/docs/source/distributed.pipelining.rst b/docs/source/distributed.pipelining.rst index c1b82b8b1bb9..4f816bc3b843 100644 --- a/docs/source/distributed.pipelining.rst +++ b/docs/source/distributed.pipelining.rst @@ -317,14 +317,8 @@ The following set of APIs transform your model into a pipeline representation. .. autoclass:: Pipe -.. autofunction:: annotate_split_points - .. autofunction:: pipe_split -.. autoclass:: ArgsChunkSpec - -.. autoclass:: KwargsChunkSpec - Microbatch Utilities ==================== diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json index 8bedc0072300..947b8d79077a 100644 --- a/test/allowlist_for_publicAPI.json +++ b/test/allowlist_for_publicAPI.json @@ -2604,12 +2604,9 @@ "TensorPipeRpcBackendOptions" ], "torch.distributed.pipelining": [ - "ArgsChunkSpec", - "KwargsChunkSpec", "Pipe", "PipelineStage", "SplitPoint", - "annotate_split_points", "pipe_split", "pipeline" ], diff --git a/test/distributed/pipelining/test_chunkspec.py b/test/distributed/pipelining/test_chunkspec.py deleted file mode 100644 index 1b104e59ec77..000000000000 --- a/test/distributed/pipelining/test_chunkspec.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -# Owner(s): ["oncall: distributed"] -import torch -from torch.distributed.pipelining import ( - ArgsChunkSpec, - KwargsChunkSpec, - pipe_split, - pipeline, -) -from torch.testing._internal.common_utils import run_tests, TestCase - - -d_hid = 512 -batch_size = 256 - -torch.manual_seed(0) - - -class ModelWithKwargs(torch.nn.Module): - def __init__(self): - super().__init__() - self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) - self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) - self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) - self.lin1 = torch.nn.Linear(d_hid, d_hid) - self.lin2 = torch.nn.Linear(d_hid, d_hid) - - def forward(self, x, y, z=torch.zeros(batch_size, d_hid)): - x = torch.mm(x, self.mm_param0) - x = x + y - x = torch.relu(x) - x = x + z - pipe_split() - x = torch.mm(x, self.mm_param1) - x = self.lin1(x) - pipe_split() - x = torch.relu(x) - x = torch.mm(x, self.mm_param2) - pipe_split() - x = self.lin2(x) - x = torch.relu(x) - return x - - -class ChunkSpecTests(TestCase): - def test_chunk_spec(self): - mod = ModelWithKwargs() - - x = torch.randn(batch_size, d_hid) - y = torch.randn(batch_size, d_hid) - z = torch.randn(batch_size, d_hid) - - chunks = 4 - - with ArgsChunkSpec((0, 0)), KwargsChunkSpec({"z": 0}): - pipe = pipeline( - mod, - chunks, - example_args=(x, y), - example_kwargs={"z": z}, - ) - - assert pipe.num_stages == 4 - - ref = mod(x, y, z) - out = pipe(x, y, z)[0] - torch.testing.assert_close(out, ref) - print(f"equivalence test passed {torch.sum(out)} ref {torch.sum(ref)}") - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipelining/test_microbatch.py b/test/distributed/pipelining/test_microbatch.py index c526c6ff7b91..9f67c2c37ea4 100644 --- a/test/distributed/pipelining/test_microbatch.py +++ b/test/distributed/pipelining/test_microbatch.py @@ -1,6 +1,9 @@ # Copyright (c) Meta Platforms, Inc. and affiliates # Owner(s): ["oncall: distributed"] +from model_registry import ModelWithKwargs + import torch +from torch.distributed.pipelining import pipeline from torch.distributed.pipelining.microbatch import ( merge_chunks, split_args_kwargs_into_chunks, @@ -10,6 +13,7 @@ d_hid = 512 +torch.manual_seed(0) class MicrobatchTests(TestCase): @@ -49,9 +53,39 @@ def test_split_and_merge(self): }, ) torch.testing.assert_close(merged_kwargs, kwargs) - print("Microbatch test passed") + def test_chunk_spec(self): + mod = ModelWithKwargs() + batch_size = ModelWithKwargs.DEFAULT_BATCH_SIZE + + x = torch.randn(batch_size, d_hid) + y = torch.randn(batch_size, d_hid) + + num_chunks = 4 + + args_chunk_spec = TensorChunkSpec.from_tuple((0,)) + kwargs_chunk_spec = TensorChunkSpec.from_dict({"y": 0}) + + args_split, kwargs_split = split_args_kwargs_into_chunks( + (x,), + {"y": y}, + num_chunks, + args_chunk_spec, + kwargs_chunk_spec, + ) + + pipe = pipeline( + mod, + mb_args=args_split[0], + mb_kwargs=kwargs_split[0], + ) + + ref = mod(x, y) + out = pipe(x, y)[0] + torch.testing.assert_close(out, ref) + print(f"equivalence test passed {torch.sum(out)} ref {torch.sum(ref)}") + if __name__ == "__main__": run_tests() diff --git a/test/distributed/pipelining/test_pipe.py b/test/distributed/pipelining/test_pipe.py index df053bd6c249..d4d158bc9d5f 100644 --- a/test/distributed/pipelining/test_pipe.py +++ b/test/distributed/pipelining/test_pipe.py @@ -13,7 +13,7 @@ d_hid = 512 -batch_size = 256 +microbatch_size = 16 torch.manual_seed(0) @@ -81,13 +81,12 @@ class PipeTests(TestCase): @parametrize("ModelClass", [ExampleCode, MultiMLP, ModelWithParamAlias]) def test_model_split(self, ModelClass): mod = ModelClass() - x = torch.randn(batch_size, d_hid) - y = torch.randn(batch_size, d_hid) + x = torch.randn(microbatch_size, d_hid) + y = torch.randn(microbatch_size, d_hid) pipe = pipeline( mod, - num_chunks=4, - example_args=(x, y), + mb_args=(x, y), ) assert ( diff --git a/test/distributed/pipelining/test_schedule.py b/test/distributed/pipelining/test_schedule.py index 81b4e1c7ae07..d040efdc7522 100644 --- a/test/distributed/pipelining/test_schedule.py +++ b/test/distributed/pipelining/test_schedule.py @@ -81,20 +81,22 @@ def test_multi_iter(self, ScheduleClass): target = torch.randn(batch_size, d_hid, device=self.device) loss_fn = torch.nn.MSELoss(reduction="sum") - # Create a pipeline chunks = 4 + x_mb = x.chunk(chunks)[0] + + # Create a pipeline split_spec = mod.split_spec if hasattr(mod, "split_spec") else None pipe = pipeline( mod, - chunks, - example_args=(x,), + mb_args=(x_mb,), split_spec=split_spec, ) stage = TracerPipelineStage( pipe, self.rank, - device=self.device, + self.device, + chunks, # to be cleaned ) # Attach to a schedule @@ -123,17 +125,20 @@ def test_kwargs_with_tracer(self, ScheduleClass): loss_fn = torch.nn.MSELoss(reduction="sum") chunks = 4 + x_mb = x.chunk(chunks)[0] + y_mb = y.chunk(chunks)[0] + pipe = pipeline( mod, - chunks, - example_args=(x,), - example_kwargs={"y": y}, + mb_args=(x_mb,), + mb_kwargs={"y": y_mb}, ) stage = TracerPipelineStage( pipe, self.rank, - device=self.device, + self.device, + chunks, # to be cleaned ) # Attach to a schedule @@ -184,18 +189,19 @@ def test_grad_with_tracer(self, ScheduleClass, ModelClass): # Create a pipeline chunks = 4 + x_mb = x.chunk(chunks)[0] split_spec = mod.split_spec if hasattr(mod, "split_spec") else None pipe = pipeline( mod, - chunks, - example_args=(x,), + mb_args=(x_mb,), split_spec=split_spec, ) stage = TracerPipelineStage( pipe, self.rank, - device=self.device, + self.device, + chunks, # to be cleaned ) # Attach to a schedule diff --git a/test/distributed/pipelining/test_stage.py b/test/distributed/pipelining/test_stage.py index 45f4b0b01a9c..ec459af7a596 100644 --- a/test/distributed/pipelining/test_stage.py +++ b/test/distributed/pipelining/test_stage.py @@ -82,19 +82,20 @@ def test_tracer(self, ModelClass): mod.to(self.device) x = torch.randn(batch_size, d_hid, device=self.device) + x_mb = x.chunk(chunks)[0] split_spec = mod.split_spec if hasattr(mod, "split_spec") else None pipe = pipeline( mod, - chunks, - example_args=(x,), + mb_args=(x_mb,), split_spec=split_spec, ) stage = TracerPipelineStage( pipe, self.rank, - device=self.device, + self.device, + chunks, # to be cleaned ) # Attach to a schedule @@ -150,17 +151,20 @@ def test_tracer_kwargs(self, ModelClass): x = torch.randn(batch_size, d_hid, device=self.device) y = torch.randn(batch_size, d_hid, device=self.device) + x_mb = x.chunk(chunks)[0] + y_mb = y.chunk(chunks)[0] + pipe = pipeline( mod, - chunks, - example_args=(x,), - example_kwargs={"y": y}, + mb_args=(x_mb,), + mb_kwargs={"y": y_mb}, ) stage = TracerPipelineStage( pipe, self.rank, - device=self.device, + self.device, + chunks, ) # Attach to a schedule diff --git a/test/distributed/pipelining/test_transformer.py b/test/distributed/pipelining/test_transformer.py index 9742c77b606a..070a62d11638 100644 --- a/test/distributed/pipelining/test_transformer.py +++ b/test/distributed/pipelining/test_transformer.py @@ -7,7 +7,7 @@ d_hid = 16 n_layers = 8 -batch_size = 4 +microbatch_size = 4 class MLPModule(torch.nn.Module): @@ -36,8 +36,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class TransformerTests(TestCase): def test_ir(self): transformer = TransformerLike() - print("Original model:\n", transformer) - x = torch.randn(batch_size, d_hid) + x = torch.randn(microbatch_size, d_hid) # Split into 2 stages num_stages = 2 @@ -45,7 +44,6 @@ def test_ir(self): pipe = pipeline( transformer, - 1, (x,), split_spec=split_spec, ) @@ -59,19 +57,18 @@ def get_layers(module): layers = [] for stage_idx in range(pipe.num_stages): stage_mod = pipe.get_stage_module(stage_idx) - print(f"\nStage {stage_idx}: \n", stage_mod) layers += get_layers(stage_mod) # Check layer completeness orig_layers = get_layers(transformer) assert sorted(layers) == sorted(orig_layers), f"{layers} != {orig_layers}" - print("Layers matched! ", layers) + print("Layers matched!") # Check equivalence ref = transformer(x) out = pipe(x)[0] torch.testing.assert_close(out, ref) - print(f"\nEquivalence test passed {torch.sum(out)} ref {torch.sum(ref)}") + print(f"Equivalence test passed {torch.sum(out)} ref {torch.sum(ref)}") if __name__ == "__main__": diff --git a/test/distributed/pipelining/test_unflatten.py b/test/distributed/pipelining/test_unflatten.py index 37eaf599e4d8..ef2e48d8ee9f 100644 --- a/test/distributed/pipelining/test_unflatten.py +++ b/test/distributed/pipelining/test_unflatten.py @@ -48,7 +48,6 @@ def test_unflatten(self): pipe = pipeline( mod, - 1, (x,), {"constant": constant}, ) diff --git a/torch/distributed/pipelining/PipelineSchedule.py b/torch/distributed/pipelining/PipelineSchedule.py index 28b7514ab16f..f3d64189fe0e 100644 --- a/torch/distributed/pipelining/PipelineSchedule.py +++ b/torch/distributed/pipelining/PipelineSchedule.py @@ -20,7 +20,7 @@ import torch.distributed as dist from torch.profiler import record_function -from .microbatch import merge_chunks, split_args_kwargs_into_chunks +from .microbatch import merge_chunks, split_args_kwargs_into_chunks, TensorChunkSpec from .PipelineStage import _PipelineStageBase if TYPE_CHECKING: @@ -64,12 +64,24 @@ def __init__( self, n_microbatches: int, loss_fn: Optional[Callable[..., torch.Tensor]] = None, + args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None, + kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None, output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, ): # From arguments self._n_microbatches = n_microbatches self._loss_fn = loss_fn + # Chunking specification for positional inputs. (default: `None`) + self._args_chunk_spec = args_chunk_spec + # Chunking specification for keyword inputs. (default: `None`) + self._kwargs_chunk_spec = kwargs_chunk_spec self._output_merge_spec = output_merge_spec + """ + # args_chunk_spec and kwargs_chunk_spec specify how to chunk inputs. + # They are used to convert batch to microbatches in `step(x)`. See + # `TensorChunkSpec` for helper methods for creating them. + """ + # Derived self._has_backward = self._loss_fn is not None # To be filled by subclasses @@ -201,22 +213,13 @@ def _split_inputs( Splits a full-batch input into chunks (i.e. microbatches) and returns the chunks """ - if self._pipe_info is not None: - # Use spec from `pipe_info` - args_chunk_spec = self._pipe_info.args_chunk_spec - kwargs_chunk_spec = self._pipe_info.kwargs_chunk_spec - else: - # Use default spec from `microbatch.py` (i.e. chunk dim 0 for each arg/kwarg) - args_chunk_spec = None - kwargs_chunk_spec = None - if args or kwargs: args_split, kwargs_split = split_args_kwargs_into_chunks( args, kwargs, self._n_microbatches, - args_chunk_spec, - kwargs_chunk_spec, + self._args_chunk_spec, + self._kwargs_chunk_spec, ) return args_split, kwargs_split else: @@ -285,12 +288,16 @@ def __init__( stage: _PipelineStageBase, n_microbatches: int, loss_fn: Optional[Callable] = None, + args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None, + kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None, output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, ): # Init parent super().__init__( n_microbatches=n_microbatches, loss_fn=loss_fn, + args_chunk_spec=args_chunk_spec, + kwargs_chunk_spec=kwargs_chunk_spec, output_merge_spec=output_merge_spec, ) self._pipe_info = ( @@ -567,6 +574,8 @@ def __init__( stages: List[_PipelineStageBase], n_microbatches: int, loss_fn: Optional[Callable] = None, + args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None, + kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None, output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, ): if len(stages) <= 1: @@ -577,6 +586,8 @@ def __init__( super().__init__( n_microbatches=n_microbatches, loss_fn=loss_fn, + args_chunk_spec=args_chunk_spec, + kwargs_chunk_spec=kwargs_chunk_spec, output_merge_spec=output_merge_spec, ) self._pipe_info = ( @@ -712,6 +723,8 @@ def __init__( stages: List[_PipelineStageBase], n_microbatches: int, loss_fn: Optional[Callable] = None, + args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None, + kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None, output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, ): self.pp_group_size = stages[0].group_size @@ -726,6 +739,8 @@ def __init__( stages=stages, n_microbatches=n_microbatches, loss_fn=loss_fn, + args_chunk_spec=args_chunk_spec, + kwargs_chunk_spec=kwargs_chunk_spec, output_merge_spec=output_merge_spec, ) diff --git a/torch/distributed/pipelining/PipelineStage.py b/torch/distributed/pipelining/PipelineStage.py index 58ffdb9717e3..5761e03d689a 100644 --- a/torch/distributed/pipelining/PipelineStage.py +++ b/torch/distributed/pipelining/PipelineStage.py @@ -630,6 +630,7 @@ def __init__( stage_index: int, pipe_info: Pipe.PipeInfo, device: torch.device, + num_chunks: int, group: Optional[dist.ProcessGroup] = None, ): """ @@ -642,7 +643,7 @@ def __init__( stage_index, pipe_info.num_stages, device, - pipe_info.num_chunks, + num_chunks, group, ) self.pipe_info = pipe_info @@ -901,6 +902,7 @@ def __init__( pipe: Pipe, stage_index: int, device: torch.device, + num_chunks: int, # To be cleaned group: Optional[dist.ProcessGroup] = None, ): """ @@ -910,7 +912,9 @@ def __init__( stage_module = pipe.get_stage_module(stage_index) # Get my pipe info pipe_info = pipe.info() - super().__init__(stage_module, stage_index, pipe_info, device, group) + super().__init__( + stage_module, stage_index, pipe_info, device, num_chunks, group + ) # Manual PipelineStage functions and definition diff --git a/torch/distributed/pipelining/_IR.py b/torch/distributed/pipelining/_IR.py index ed2ef32d4255..9c3e21ba70ea 100644 --- a/torch/distributed/pipelining/_IR.py +++ b/torch/distributed/pipelining/_IR.py @@ -23,7 +23,6 @@ from ._backward import _null_coalesce_accumulate, stage_backward from ._unflatten import _outline_submodules -from .microbatch import split_args_kwargs_into_chunks, TensorChunkSpec logger = logging.getLogger(__name__) @@ -486,28 +485,11 @@ def _direct_serialization_reduce(self): class Pipe(torch.nn.Module): - # Class variables - # args_chunk_spec and kwargs_chunk_spec are used to specify how to chunk - # inputs. They are used to create microbatched examples before tracing. - # See context managers `ArgsChunkSpec` and `KwargsChunkSpec`. - # TODO: Do we need to support `_Replicate`? It's unclear, dropping for now. - - # args_chunk_spec: - # Chunking specification for positional inputs. (default: `None`) - args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None - - # kwargs_chunk_spec: - # Chunking specification for keyword inputs. (default: `None`) - kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None - @dataclass class PipeInfo: graph: fx.Graph num_stages: int - num_chunks: int has_loss_and_backward: bool - args_chunk_spec: Optional[Tuple[Any, ...]] = None - kwargs_chunk_spec: Optional[Dict[str, Any]] = None def __init__( self, @@ -1000,7 +982,6 @@ def _trace_with_export( @staticmethod def from_tracing( mod: torch.nn.Module, - num_chunks: int, example_args: Tuple[Any, ...], example_kwargs: Optional[Dict[str, Any]] = None, split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None, @@ -1019,19 +1000,11 @@ def from_tracing( ) """ - args_split, kwargs_split = split_args_kwargs_into_chunks( - example_args, - example_kwargs, - num_chunks, - Pipe.args_chunk_spec, - Pipe.kwargs_chunk_spec, - ) - # Trace with export exported_program = Pipe._trace_with_export( mod, - example_args=args_split[0], - example_kwargs=kwargs_split[0], + example_args, + example_kwargs, ) pipe = Pipe._from_traced( @@ -1075,10 +1048,7 @@ def from_tracing( pipe.pipe_info = Pipe.PipeInfo( graph=pipe.split_gm.graph, num_stages=pipe.num_stages, - num_chunks=num_chunks, has_loss_and_backward=pipe.has_loss_and_backward, - args_chunk_spec=Pipe.args_chunk_spec, - kwargs_chunk_spec=Pipe.kwargs_chunk_spec, ) return pipe @@ -1145,29 +1115,26 @@ def annotate_split_points(mod: torch.nn.Module, spec: Dict[str, SplitPoint]): def pipeline( module: torch.nn.Module, - num_chunks: int, - example_args: Tuple[Any, ...], - example_kwargs: Optional[Dict[str, Any]] = None, + mb_args: Tuple[Any, ...], + mb_kwargs: Optional[Dict[str, Any]] = None, split_spec: Optional[Dict[str, SplitPoint]] = None, split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None, ) -> Pipe: """ - Creates a pipeline representation for the provided module. + Split a module based on a specification. See `Pipe` for more details. Arguments --------- module: - The module to be transformed into a `Pipe`. - num_chunks: - The number of microbatches to be run with this pipeline. - example_args: - Example positional inputs to be used with this pipeline. - example_kwargs: - Example keyword inputs to be used with this pipeline. (default: `None`) + The module to be splitted. + mb_args: + Example positional inputs, in micro-batch form. + mb_kwargs: + Example keyword inputs, in micro-batch form. (default: `None`) split_spec: - A dictionary mapping module names to `SplitPoint`s. (default: `None`) + A dictionary using submodule names as split marker. (default: `None`) split_policy: The policy to use for splitting the module. (default: `None`) @@ -1185,77 +1152,14 @@ def pipeline( annotate_split_points(module, split_spec) return Pipe.from_tracing( mod=module, - num_chunks=num_chunks, - example_args=example_args, - example_kwargs=example_kwargs, + example_args=mb_args, + example_kwargs=mb_kwargs, ) else: # Use split policy return Pipe.from_tracing( mod=module, - num_chunks=num_chunks, - example_args=example_args, - example_kwargs=example_kwargs, + example_args=mb_args, + example_kwargs=mb_kwargs, split_policy=split_policy, ) - - -class ArgsChunkSpec: - """ - Context manager for setting `args_chunk_spec` during creation of Pipe - - Example: - >>> # xdoctest: +SKIP - >>> # There are three positional arguments to the model, and - >>> # we are chunking them along dimension 0, 0 and 1, respectively - >>> with ArgsChunkSpec((0, 0, 1)): - >>> pipe = pipeline(model, num_chunks, example_args) - """ - - def __init__( - self, - chunk_dims: Tuple[int, ...], - ): - self.args_chunk_spec = map_aggregate( - chunk_dims, - lambda dim: TensorChunkSpec(dim), - ) - - def __enter__(self): - # Inject into the Pipe class - Pipe.args_chunk_spec = self.args_chunk_spec - return self.args_chunk_spec - - def __exit__(self, exc_type, exc_val, traceback): - # Remove from the Pipe class - Pipe.args_chunk_spec = None - - -class KwargsChunkSpec: - """ - Context manager for setting `kwargs_chunk_spec` during creation of Pipe - - Example: - >>> # xdoctest: +SKIP - >>> # Chunk dimension 0 for the "id" argument, 1 for the "mask" argument - >>> with KwargsChunkSpec({"id": 0, "mask": 1}): - >>> pipe = pipeline(model, num_chunks, (), example_kwargs) - """ - - def __init__( - self, - chunk_dims: Dict[str, int], - ): - self.kwargs_chunk_spec = map_aggregate( - chunk_dims, - lambda dim: TensorChunkSpec(dim), - ) - - def __enter__(self): - # Inject into the Pipe class - Pipe.kwargs_chunk_spec = self.kwargs_chunk_spec - return self.kwargs_chunk_spec - - def __exit__(self, exc_type, exc_val, traceback): - # Remove from the Pipe class - Pipe.kwargs_chunk_spec = None diff --git a/torch/distributed/pipelining/__init__.py b/torch/distributed/pipelining/__init__.py index eca6e451bdc3..d9fd8feaf6e5 100644 --- a/torch/distributed/pipelining/__init__.py +++ b/torch/distributed/pipelining/__init__.py @@ -1,13 +1,5 @@ # Copyright (c) Meta Platforms, Inc. and affiliates -from ._IR import ( - annotate_split_points, - ArgsChunkSpec, - KwargsChunkSpec, - Pipe, - pipe_split, - pipeline, - SplitPoint, -) +from ._IR import Pipe, pipe_split, pipeline, SplitPoint from .PipelineSchedule import ( Schedule1F1B, ScheduleGPipe, @@ -20,10 +12,7 @@ "Pipe", "pipe_split", "SplitPoint", - "annotate_split_points", "pipeline", - "ArgsChunkSpec", - "KwargsChunkSpec", "TracerPipelineStage", "PipelineStage", "Schedule1F1B", diff --git a/torch/distributed/pipelining/microbatch.py b/torch/distributed/pipelining/microbatch.py index 1201e235d036..6358a1293edb 100644 --- a/torch/distributed/pipelining/microbatch.py +++ b/torch/distributed/pipelining/microbatch.py @@ -3,9 +3,16 @@ from typing import Any, Dict, List, Optional, Tuple import torch +from torch.fx.node import map_aggregate from torch.utils._pytree import tree_flatten, tree_unflatten +__all__ = [ + "TensorChunkSpec", + "split_args_kwargs_into_chunks", + "merge_chunks", +] + logger = logging.getLogger(__name__) """ @@ -45,8 +52,11 @@ class _LossReducer(_CustomReducer): DEFAULT_CHUNK_DIM = 0 -# Class used to specify chunking of inputs class TensorChunkSpec: + """ + Class used to specify chunking of inputs + """ + def __init__(self, split_dim): self.split_dim = split_dim @@ -60,6 +70,43 @@ def __repr__(self): def __str__(self): return f"TensorChunkSpec({self.split_dim})" + @staticmethod + def from_tuple( + chunk_dims: Tuple[int, ...], + ): + """ + A helper for creating a tuple of `TensorChunkSpec` from a tuple of chunk + dimensions (int's). + Example: + >>> # xdoctest: +SKIP + >>> # There are three positional arguments to the model, and + >>> # we are chunking them along dimension 0, 0 and 1, respectively + >>> args_chunk_spec = TensorChunkSpec.from_tuple((0, 0, 1)) + """ + args_chunk_spec = map_aggregate( + chunk_dims, + lambda dim: TensorChunkSpec(dim), + ) + return args_chunk_spec + + @staticmethod + def from_dict( + chunk_dims: Dict[str, int], + ): + """ + A helper for creating a dictionary of `TensorChunkSpec` from a + dictionary of chunk dimensions (int's). + Example: + >>> # xdoctest: +SKIP + >>> # Chunk dimension 0 for the "id" argument, 1 for the "mask" argument + >>> kwargs_chunk_spec = TensorChunkSpec.from_dict({"id": 0, "mask": 1}) + """ + kwargs_chunk_spec = map_aggregate( + chunk_dims, + lambda dim: TensorChunkSpec(dim), + ) + return kwargs_chunk_spec + # Class used to specify replication of inputs class _Replicate: From a1b664adeb5739b3c28a6c48aedb5fda29bf92e3 Mon Sep 17 00:00:00 2001 From: cyy Date: Fri, 7 Jun 2024 15:54:07 +0000 Subject: [PATCH 472/706] Add default values to PyTorchMemEffAttention::AttentionKernel::Params members (#112215) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Default values were added to Params in order to eliminate CUDA warnings like ``` and the implicitly-defined constructor does not initialize ‘PyTorchMemEffAttention::AttentionKernel::accum_t PyTorchMemEffAttention::AttentionKernel::Params::scale’ ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/112215 Approved by: https://github.com/eqy, https://github.com/ezyang --- aten/src/ATen/cuda/detail/PhiloxCudaStateRaw.cuh | 4 ++-- .../transformers/cuda/mem_eff_attention/kernel_forward.h | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/cuda/detail/PhiloxCudaStateRaw.cuh b/aten/src/ATen/cuda/detail/PhiloxCudaStateRaw.cuh index c9eeeadd542d..231cd167cacb 100644 --- a/aten/src/ATen/cuda/detail/PhiloxCudaStateRaw.cuh +++ b/aten/src/ATen/cuda/detail/PhiloxCudaStateRaw.cuh @@ -34,8 +34,8 @@ struct PhiloxCudaState { int64_t* ptr; }; - Payload seed_; - Payload offset_; + Payload seed_{}; + Payload offset_{}; uint32_t offset_intragraph_ = 0; bool captured_ = false; }; diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h index 642145f5a0da..a10e5a9c44a0 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h @@ -190,8 +190,8 @@ struct AttentionKernel { unsigned long long dropout_batch_head_rng_offset = 0; float dropout_prob = 0.0f; at::PhiloxCudaState rng_engine_inputs = at::PhiloxCudaState(0, 0); - int64_t* extragraph_offset; - int64_t* seed; + int64_t* extragraph_offset = nullptr; + int64_t* seed = nullptr; // Moves pointers to what we should process // Returns "false" if there is no work to do From 23c156cd2d699ea1f67deae2bf4353e327daf16b Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 7 Jun 2024 15:58:36 +0000 Subject: [PATCH 473/706] Revert "[inductor] simplify indexing (#127661)" This reverts commit 901226ae837bd4629b34735c84a3481c4988bb5b. Reverted https://github.com/pytorch/pytorch/pull/127661 on behalf of https://github.com/atalman due to Sorry reverting because in conflict with https://github.com/pytorch/pytorch/pull/126905 which needs to be reverted, will be relanding it ([comment](https://github.com/pytorch/pytorch/pull/127661#issuecomment-2155115388)) --- test/inductor/test_indexing.py | 78 +------------------ torch/_inductor/codegen/simd.py | 19 +---- torch/_inductor/sizevars.py | 131 -------------------------------- 3 files changed, 2 insertions(+), 226 deletions(-) diff --git a/test/inductor/test_indexing.py b/test/inductor/test_indexing.py index 19a736160908..da527cfbb1d8 100644 --- a/test/inductor/test_indexing.py +++ b/test/inductor/test_indexing.py @@ -1,24 +1,16 @@ # Owner(s): ["module: inductor"] -import os -import unittest - import sympy -import torch - from torch._inductor.codegen.cpp import cexpr from torch._inductor.codegen.triton import texpr from torch._inductor.codegen.wrapper import pexpr -from torch._inductor.runtime.runtime_utils import do_bench_gpu from torch._inductor.sizevars import SizeVarAllocator from torch._inductor.test_case import TestCase as InductorTestCase -from torch._inductor.utils import run_and_get_triton_code from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, ) -from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA from torch.utils._sympy.functions import ( FloorDiv, ModularIndexing, @@ -26,8 +18,6 @@ RoundToInt, ) -DO_PERF_TEST = os.environ.get("DO_PERF_TEST") == "1" - class TestIndexingSimplification(InductorTestCase): def test_indexing_simplification(self): @@ -174,73 +164,6 @@ def test_indexing_join(self): self.assertEqual(simplified, FloorDiv(i0, 3)) self.assertEqual(expr6.subs({i0: 39485}), simplified.subs({i0: 39485})) - def test_modular_indexing_pairs_merged(self): - sizevars = SizeVarAllocator() - x = sympy.Symbol("x", integer=True, positive=True) - a = 1024 - b = 32 - expr1 = ModularIndexing(x, 1, a) - expr2 = ModularIndexing(expr1, 1, b) - expected = ModularIndexing(x, 1, b) - - actual = sizevars.combine_modular_indexing_pairs(expr2) - self.assertEqual(expected, actual) - self.assertNotEqual(expr2, actual) - - def test_modular_indexing_pairs_not_merged(self): - sizevars = SizeVarAllocator() - x = sympy.Symbol("x", integer=True, positive=True) - a = 1024 - b = 3 # pick a 'b' that we can not merge - expr1 = ModularIndexing(x, 1, a) - expr2 = ModularIndexing(expr1, 1, b) - - actual = sizevars.combine_modular_indexing_pairs(expr2) - self.assertEqual(expr2, actual) - self.assertNotEqual(ModularIndexing(x, 1, b), actual) - - def test_expand_floor_div_skipped(self): - sizevars = SizeVarAllocator() - x = sympy.Symbol("x", integer=True, positive=True) - y = sympy.Symbol("y", integer=True, positive=True) - - expr = FloorDiv(x, 2) + FloorDiv(y, 3) - # The expression can not be simplified since there are multiple - # FloorDiv. We return False in that case - self.assertFalse(sizevars.expand_floor_div(expr)) - - def test_expand_floor_div_applied(self): - sizevars = SizeVarAllocator() - x = sympy.Symbol("x", integer=True, positive=True) - y = sympy.Symbol("y", integer=True, positive=True) - - expr = x * 5 + FloorDiv(y, 3) - actual, denominator = sizevars.expand_floor_div(expr) - self.assertNotEqual(expr, actual) - expected = FloorDiv(x * 15 + y, 3) - self.assertEqual(expected, FloorDiv(actual, denominator)) - - @unittest.skipUnless(HAS_CUDA, "Need GPU for this test") - def test_int8_unpack(self): - @torch.compile - def f(x): - first_elements = x >> 4 - second_elements = x & 15 - unpacked = torch.stack([first_elements, second_elements], dim=-1).view( - *x.size()[:-1], -1 - ) - return unpacked * 2 - - x = torch.randint(0, 255, (2, 4096, 5504), dtype=torch.uint8, device="cuda") - - triton_code = run_and_get_triton_code(f, x) - # Make sure the 2 load uses simpified indexing rather than something like - # tl.load(in_ptr0 + ((5504*x1) + (x0 // 2)), - self.assertEqual(2, triton_code.count("tl.load(in_ptr0 + ((x2 // 2)),")) - if DO_PERF_TEST: - ms = do_bench_gpu(lambda: f(x)) - print(f"{ms=:.03f}") - class ExprPrinterTests(InductorTestCase): def test_print_pow(self): @@ -358,6 +281,7 @@ def test_print_Min_Max(self): if __name__ == "__main__": from torch._inductor.test_case import run_tests + from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA if HAS_CPU or HAS_CUDA: run_tests("sympy") diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index c5fc2747bee7..ed7261f2a3eb 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -338,8 +338,7 @@ def simplify_indexing(index: sympy.Expr): index = V.graph.sizevars.simplify_with_ranges(index, self.var_ranges()) for tree in self.range_trees: index = self.combine_contiguous_dims(index, tree) - - return self.combine_modular_indexing_pairs(index) + return index self.simplify_indexing = simplify_indexing self.initialize_range_tree(pid_cache) @@ -423,23 +422,7 @@ def dense_size_str(self): sizes = self.dense_size_list() return f"[{', '.join(sizes)}]" - def combine_modular_indexing_pairs(self, index): - if not isinstance(index, ModularIndexing): - return index - x = index.args[0] - if (tree_node := self.range_tree_nodes.get(x)) is None: - return index - new_index = sympy_subs(index, {x: tree_node.expr}) - return V.graph.sizevars.combine_modular_indexing_pairs(new_index) - def combine_contiguous_dims(self, index: sympy.Expr, tree: IterationRangesRoot): - if expand_res := V.graph.sizevars.expand_floor_div(index): - new_index, denominator = expand_res # type: ignore[misc] - return FloorDiv(self._combine_contiguous_dims(new_index, tree), denominator) - else: - return self._combine_contiguous_dims(index, tree) - - def _combine_contiguous_dims(self, index: sympy.Expr, tree: IterationRangesRoot): """ More aggressive simplification to merge contiguous dims """ diff --git a/torch/_inductor/sizevars.py b/torch/_inductor/sizevars.py index 910e85e79906..fba9a66f9237 100644 --- a/torch/_inductor/sizevars.py +++ b/torch/_inductor/sizevars.py @@ -583,137 +583,6 @@ def lookup_precomputed_size(self, expr: Expr) -> Expr: def free_symbols(self) -> Set[sympy.Symbol]: return set(self.var_to_val.keys()) - set(self.replacements.keys()) - def combine_modular_indexing_pairs(self, index: sympy.Expr) -> sympy.Expr: - """ - A pair of special ModularIndexing can be combined. - - E.g. ModularIndexing(ModularIndexing(x, 1, a), 1, b) - We can simplify this to ModuleIndexing(x, 1, b), if - 1. x is non negative integer - 2. a and b are positive integers - 3. a is a multiple of b. - """ - - def _check_args(x, div, mod, is_first): - if not isinstance(div, sympy.Integer) or not isinstance(mod, sympy.Integer): - return False - if div != 1: - return False - if mod <= 0: - return False - - if is_first: - # first ModularIndexing should conatins a nested ModularIndex - if not isinstance(x, ModularIndexing): - return False - else: - # second ModularIndexing should constains a non-negative - # symbol - if not isinstance(x, sympy.Symbol) or not self.statically_known_geq( - x, 0 - ): - return False - return True - - if isinstance(index, ModularIndexing): - x, div, mod = index.args - - if not _check_args(x, div, mod, True): - return index - - x2, div2, mod2 = x.args - - if not _check_args(x2, div2, mod2, False): - return index - - if mod2 % mod != 0: - return index - - return ModularIndexing(x2, 1, mod) - - return index - - def expand_floor_div( - self, index: sympy.Expr - ) -> Union[bool, Tuple[sympy.Expr, sympy.Expr]]: - """ - Expand the FloorDiv to the entire expression so that the expression may - be simplfied. - - E.g., for a 2D contiguous tensor with shape [a, 2 * b], and index variables - x1, x2, index expression 'x1 * 2b + x2' can be easily combined. - But index expression 'x1 * b + x2 // 2' can not. - By expanding the FloorDiv to the entire expression, we get - '(x1 * 2b + x2) // 2'. This transformation allows us to merge loops - for the numerator! - - Return false if this optimization can be applied; - Return the new expression and the denominator otherwise. - The original expression will be equivalent to 'new_expression // denominator' - """ - if not isinstance(index, sympy.Add): - return False - terms = index.args - - if len(terms) < 2: - return False - floor_div_index = -1 - varlist = [] - factorlist = [] - for idx, term in enumerate(terms): - if isinstance(term, sympy.Mul): - # For dynamic shape, term like '2*s1*x1' has 3 child nodes. - # - A integer for 2 - # - A symbol for s1 - # - A symbol for x1 - # Skip for now. - if len(term.args) != 2: - return False - factor, var = term.args - varlist.append(var) - factorlist.append(factor) - if not isinstance(factor, sympy.Integer) or not isinstance( - var, sympy.Symbol - ): - return False - # It's easier to reason about the correceness of the transformation - # for non-negative integers. - if not self.statically_known_geq(var, 0): - return False - elif isinstance(term, FloorDiv): - var, factor = term.args - if not isinstance(factor, sympy.Integer) or not isinstance( - var, sympy.Symbol - ): - return False - if not self.statically_known_geq(var, 0): - return False - if floor_div_index >= 0: - # can not handle multi FloorDiv yet - return False - - floor_div_index = idx - varlist.append(var) - # this factor is denominator - factorlist.append(factor) - else: - return False - - if floor_div_index < 0: - return False - - # Construct the new expression and remember the denominator - denominator = factorlist[floor_div_index] - new_index = sympy.Integer(0) - - for var, factor, idx in zip(varlist, factorlist, itertools.count()): - if idx == floor_div_index: - new_index += var - else: - new_index += (factor * denominator) * var - - return new_index, denominator - def join_dimensions(expr: Expr) -> Expr: if not isinstance(expr, sympy.Add) or not expr.has(ModularIndexing): From ac51f782fe012af58af57bd5e8aab781ed07c90c Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 7 Jun 2024 16:01:46 +0000 Subject: [PATCH 474/706] Revert "Complete revamp of float/promotion sympy handling (#126905)" This reverts commit 2f7cfecd86009a9d396fdbdcdfb4ba7a005db16b. Reverted https://github.com/pytorch/pytorch/pull/126905 on behalf of https://github.com/atalman due to Sorry need to revert - failing internally ([comment](https://github.com/pytorch/pytorch/pull/126905#issuecomment-2155118778)) --- c10/core/SymNodeImpl.h | 18 - test/dynamo/test_dynamic_shapes.py | 7 + test/dynamo/test_export.py | 3 +- test/dynamo/test_misc.py | 17 +- test/inductor/test_indexing.py | 72 +++- .../test_torchinductor_dynamic_shapes.py | 28 -- test/onnx/test_fx_to_onnx_with_onnxruntime.py | 8 +- test/test_dynamic_shapes.py | 208 ++++++--- test/test_proxy_tensor.py | 3 +- test/test_sympy_utils.py | 122 +++--- torch/__init__.py | 162 +------ torch/_export/serde/serialize.py | 9 +- torch/_inductor/bounds.py | 5 - torch/_inductor/codegen/common.py | 176 ++------ torch/_inductor/codegen/cpp.py | 4 +- torch/_inductor/codegen/cpp_utils.py | 55 +-- torch/_inductor/codegen/triton.py | 64 +-- torch/_inductor/graph.py | 5 +- torch/_inductor/ir.py | 16 +- torch/_inductor/kernel/flex_attention.py | 5 +- torch/_inductor/lowering.py | 6 +- torch/_inductor/ops_handler.py | 60 +-- torch/_inductor/select_algorithm.py | 4 +- torch/_inductor/sizevars.py | 20 +- torch/_inductor/utils.py | 2 +- torch/_subclasses/fake_tensor.py | 2 +- torch/csrc/jit/python/init.cpp | 5 - torch/csrc/utils/python_symnode.h | 20 - torch/export/dynamic_shapes.py | 9 +- torch/fx/experimental/recording.py | 8 +- torch/fx/experimental/sym_node.py | 210 ++------- torch/fx/experimental/symbolic_shapes.py | 82 ++-- torch/fx/experimental/validator.py | 32 +- torch/utils/_sympy/functions.py | 398 ++++-------------- torch/utils/_sympy/interp.py | 71 +--- torch/utils/_sympy/reference.py | 151 +++---- torch/utils/_sympy/solve.py | 1 - torch/utils/_sympy/value_ranges.py | 275 ++++-------- 38 files changed, 669 insertions(+), 1674 deletions(-) diff --git a/c10/core/SymNodeImpl.h b/c10/core/SymNodeImpl.h index bb92b09775b7..9ffab5065109 100644 --- a/c10/core/SymNodeImpl.h +++ b/c10/core/SymNodeImpl.h @@ -49,33 +49,15 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target { virtual SymNode mul(const SymNode& other) { TORCH_CHECK(false, "NYI"); } - // NB: legacy, prefer float_truediv or int_truediv virtual SymNode truediv(const SymNode& other) { TORCH_CHECK(false, "NYI"); } - virtual SymNode float_truediv(const SymNode& other) { - return truediv(other); - } - virtual SymNode int_truediv(const SymNode& other) { - return truediv(other); - } - // NB: legacy, prefer float_pow or pow_by_natural virtual SymNode pow(const SymNode& other) { TORCH_CHECK(false, "NYI"); } - virtual SymNode float_pow(const SymNode& other) { - return pow(other); - } - virtual SymNode pow_by_natural(const SymNode& other) { - return pow(other); - } - // NB: legacy, prefer int_floordiv virtual SymNode floordiv(const SymNode& other) { TORCH_CHECK(false, "NYI"); } - virtual SymNode int_floordiv(const SymNode& other) { - return floordiv(other); - } virtual SymNode mod(const SymNode& other) { TORCH_CHECK(false, "NYI"); } diff --git a/test/dynamo/test_dynamic_shapes.py b/test/dynamo/test_dynamic_shapes.py index a3c63ef66152..0bead6e47e48 100644 --- a/test/dynamo/test_dynamic_shapes.py +++ b/test/dynamo/test_dynamic_shapes.py @@ -78,6 +78,13 @@ def make_dynamic_cls(cls): del test if TEST_Z3: + # this only fails when z3 is available + unittest.expectedFailure( + # SymPy is incorrectly transforming 's0 / 6 == 0.5' into 'False'. + # Ref: https://github.com/sympy/sympy/issues/25146 + DynamicShapesReproTests.test_dynamic_shapes_float_guard_dynamic_shapes # noqa: F821 + ) + if not config.inline_inbuilt_nn_modules: # TODO model is somehow not being freed when z3 is available unittest.expectedFailure( diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py index 7ae0f839f6ff..9f1417e23247 100644 --- a/test/dynamo/test_export.py +++ b/test/dynamo/test_export.py @@ -2385,7 +2385,8 @@ def forward(self, x): with self.assertRaisesRegex( torch._dynamo.exc.UserError, "Constraints violated .*!(.*\n)*.*" - "Not all values of dim0 .* satisfy the generated guard 4 <= .* and .* <= 10(.*\n)*.*", + "by dim0 = 2\\*dim1(.*\n)*.*" + "Not all values of dim1 .* satisfy the generated guard 2 <= .* and .* <= 5(.*\n)*.*", ): torch.export.export(foo, (t,), dynamic_shapes=dynamic_shapes) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index dc2b9530f0dd..bcb0fd18818e 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -9309,7 +9309,7 @@ def test_shape_env_equal_create_symbolic_sizes_strides_storage_offset(self): > Left: {0: 0, 1: 1, 2: s1, 3: s0} > Right: {0: 0, 1: 1} ==> var_to_range: values don't match. - > Left: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)} + > Left: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)} > Right: {} ==> var_to_sources: values don't match. > Left: {s0: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=, idx=0)], s1: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=, idx=1)]} @@ -9343,7 +9343,7 @@ def test_shape_env_equal_unbacked(self): > Left: 2 > Right: 0 ==> var_to_range: values don't match. - > Left: {u0: ValueRanges(lower=-9223372036854775808, upper=9223372036854775807, is_bool=False, is_int=True, is_float=False), u1: ValueRanges(lower=0, upper=1, is_bool=False, is_int=True, is_float=False), zuf0: ValueRanges(lower=-oo, upper=oo, is_bool=False, is_int=False, is_float=True)} + > Left: {u0: ValueRanges(lower=-9223372036854775808, upper=9223372036854775807, is_bool=False), u1: ValueRanges(lower=0, upper=1, is_bool=False), zuf0: ValueRanges(lower=-oo, upper=oo, is_bool=False)} > Right: {} """, ) @@ -9420,8 +9420,8 @@ def test_shape_env_equal_evaluate_expr_replacement(self): > Left: {s0: 3} > Right: {} ==> var_to_range: values don't match. - > Left: {s0: ValueRanges(lower=3, upper=3, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)} - > Right: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)} + > Left: {s0: ValueRanges(lower=3, upper=3, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)} + > Right: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)} """, ) self._replay_and_check(main) @@ -9458,8 +9458,8 @@ def test_shape_env_equal_evaluate_expr_refinement(self): > Left: {_assert, ge, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} > Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} ==> var_to_range: values don't match. - > Left: {s0: ValueRanges(lower=3, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)} - > Right: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)} + > Left: {s0: ValueRanges(lower=3, upper=9223372036854775806, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)} + > Right: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)} """, ) self._replay_and_check(main) @@ -9484,7 +9484,10 @@ def test_shape_env_equal_runtime_assert(self): ShapeEnv not equal: field values don't match: ==> deferred_runtime_asserts: values don't match. - > Left: {u0: [Eq(PythonMod(u0, 3), 0)]} + > Left: {u0: [Eq(Mod(u0, 3), 0)]} + > Right: {} +==> divisible: values don't match. + > Left: {Mod(u0, 3)} > Right: {} ==> name_to_node: values don't match. > Left: {_assert, eq, mod, u0} diff --git a/test/inductor/test_indexing.py b/test/inductor/test_indexing.py index da527cfbb1d8..299a619f9cd6 100644 --- a/test/inductor/test_indexing.py +++ b/test/inductor/test_indexing.py @@ -11,12 +11,7 @@ instantiate_parametrized_tests, parametrize, ) -from torch.utils._sympy.functions import ( - FloorDiv, - ModularIndexing, - RoundDecimal, - RoundToInt, -) +from torch.utils._sympy.functions import FloorDiv, ModularIndexing, Round, RoundDecimal class TestIndexingSimplification(InductorTestCase): @@ -173,11 +168,21 @@ def test_print_pow(self): common_cases = [ # expr, result + # Test exprs. + ( + s1 / (2 * s1 - 1) - 1 / (2 * s1 - 1), + lambda c, L: f"((-1{L})*({c}/((-1{L}) + (2{L}*foo)))) + (foo*({c}/((-1{L}) + (2{L}*foo))))", + ), + (s1 / (s2 - s3), lambda c, L: f"foo*({c}/(bar + ((-1{L})*baz)))"), # Test Pow directly. ( sympy.Pow(s1 + s2, 0), lambda _, L: f"1{L}", ), # note: simplified before _print_Pow + ( + sympy.Pow(s1 + s2, -3), + lambda c, _: f"{c}/((bar + foo)*(bar + foo)*(bar + foo))", + ), ] gpu_cases = common_cases + [ @@ -226,10 +231,12 @@ def test_print_ceil(self): self.assertExpectedInline(cexpr(expr), """std::ceil((1.0/2.0)*s1)""") def test_print_round(self): - expr = RoundToInt(sympy.Symbol("x", integer=True) / 2) + expr = Round(sympy.Symbol("x", integer=True) / 2) self.assertExpectedInline(pexpr(expr), """round((1/2)*x)""") self.assertExpectedInline(cexpr(expr), """std::lrint((1.0/2.0)*x)""") - self.assertExpectedInline(texpr(expr), """libdevice.llrint((1/2)*x)""") + self.assertExpectedInline( + texpr(expr), """libdevice.llrint((1/2)*x).to(tl.int64)""" + ) @parametrize("ndigits", [-1, 0, 1]) def test_print_round_decimal(self, ndigits): @@ -244,18 +251,45 @@ def test_print_round_decimal(self, ndigits): f"libdevice.nearbyint(1e{ndigits} * ((1/2)*x)) * 1e{-ndigits}", ) + expr = RoundDecimal(sympy.Symbol("x", integer=True), ndigits) + if ndigits >= 0: + for do_print in [pexpr, cexpr, texpr]: + self.assertEqual(do_print(expr), "x") + else: + self.assertEqual(pexpr(expr), f"round(x, {ndigits})") + for do_print in [cexpr, texpr]: + with self.assertRaisesRegex( + ValueError, "only non-negative ndigits are currently supported" + ): + do_print(expr) + def test_print_floor_div(self): - s1 = sympy.Symbol("s1", integer=True) - s2 = sympy.Symbol("s2", integer=True) - expr = FloorDiv(s1, s2) - self.assertEqual(pexpr(expr), "(s1 // s2)") - self.assertEqual(cexpr(expr), "c10::div_floor_integer(s1, s2)") - - s1 = sympy.Symbol("s1", integer=True) - s2 = sympy.S(-1) - expr = FloorDiv(s1, s2) - self.assertEqual(pexpr(expr), "(-1)*s1") - self.assertEqual(cexpr(expr), "(-1L)*s1") + for integer in [True, False]: + s1 = sympy.Symbol("s1", integer=integer) + s2 = sympy.Symbol("s2", integer=integer) + expr = FloorDiv(s1, s2) + self.assertEqual(pexpr(expr), "(s1 // s2)") + if integer: + self.assertEqual(cexpr(expr), "c10::div_floor_integer(s1, s2)") + else: + self.assertEqual( + cexpr(expr), + "c10::div_floor_floating(static_cast(s1), static_cast(s2))", + ) + + for integer in [True, False]: + s1 = sympy.Symbol("s1", integer=integer) + s2 = sympy.S(-1) + expr = FloorDiv(s1, s2) + if integer: + self.assertEqual(pexpr(expr), "(-1)*s1") + self.assertEqual(cexpr(expr), "(-1L)*s1") + else: + self.assertEqual(pexpr(expr), "(s1 // (-1))") + self.assertEqual( + cexpr(expr), + "c10::div_floor_floating(static_cast(s1), static_cast((-1L)))", + ) def test_print_Min_Max(self): cases = ( diff --git a/test/inductor/test_torchinductor_dynamic_shapes.py b/test/inductor/test_torchinductor_dynamic_shapes.py index 2f9506a9d561..8513e928c412 100644 --- a/test/inductor/test_torchinductor_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_dynamic_shapes.py @@ -3,7 +3,6 @@ import importlib import math -import operator import os import sys import unittest @@ -650,33 +649,6 @@ def fn(a): actual = cfn(5) self.assertEqual(expect, actual) - def test_interpolate_ceil_eq(self, device): - ceiling = math.ceil - IntTrueDiv = operator.truediv - - def fn(t): - s0, s2, s3 = t.size() - x = torch.zeros( - ( - s0, - 2048, - ceiling(IntTrueDiv(2 * ((s2 - 1) // 8) + 2, 1)), - ceiling(IntTrueDiv(2 * ((s3 - 1) // 8) + 2, 1)), - ), - dtype=torch.bfloat16, - ) - return torch.nn.functional.interpolate( - x, - scale_factor=2, - mode="nearest", - ) - - cfn = self.compile_fn(fn) - arg = torch.randn(4, 16, 18) - expect = fn(arg) - actual = cfn(arg) - self.assertEqual(expect, actual) - def test_full_recompiles(self, device): def fn(x): _, L = x.shape diff --git a/test/onnx/test_fx_to_onnx_with_onnxruntime.py b/test/onnx/test_fx_to_onnx_with_onnxruntime.py index 0f0e01bc0dc2..b70bfbf9c4a7 100644 --- a/test/onnx/test_fx_to_onnx_with_onnxruntime.py +++ b/test/onnx/test_fx_to_onnx_with_onnxruntime.py @@ -158,12 +158,8 @@ def forward(self, x, y): torch.tensor([operator.sub(x.item(), y.item())]), torch.tensor([operator.mul(x.item(), y.item())]), torch.tensor([operator.truediv(x.item(), y.item())]), - # This requires torch.sym_float, probably easy to lower to - # ONNX but I don't know where to put it - # torch.tensor([operator.floordiv(x.item(), y.item())]), - # NB: abs so that the base and exponent are provably - # non-negative, so we don't generate runtime asserts - torch.tensor([operator.pow(abs(x.item()), abs(y.item()))]), + torch.tensor([operator.floordiv(x.item(), y.item())]), + torch.tensor([operator.pow(x.item(), y.item())]), torch.tensor([operator.abs(x.item())]), torch.tensor([operator.neg(x.item())]), torch.tensor([math.ceil(x.item())]), diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index 3b47f12198d5..d548e9df0707 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -205,15 +205,15 @@ def create_symtype(cls, pytype, shape_env, val, duck=True): # TODO: default duck to False -def create_symint(shape_env, i: int, duck=True) -> SymInt: +def create_symint(shape_env, i: int, duck=True): return create_symtype(SymInt, int, shape_env, i, duck=duck) -def create_symbool(shape_env, b: bool) -> SymBool: +def create_symbool(shape_env, b: bool): return create_symtype(SymBool, bool, shape_env, b) -def create_symfloat(shape_env, f: float) -> SymFloat: +def create_symfloat(shape_env, f: float): return create_symtype(SymFloat, float, shape_env, f) @@ -457,16 +457,14 @@ def test_sym_int(self): r = sym_int(a1 / 2) self.assertEqual(guard_int(r), 3) self.assertIsInstance(r, torch.SymInt, msg=type(r)) - self.assertExpectedInline( - str(shape_env.guards[1][0]), """Eq(TruncToInt(IntTrueDiv(s1, 2)), 3)""" - ) + self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(Trunc(s1/2), 3)""") a3 = create_symint(shape_env, 3) r = sym_int(2.0 * torch.sym_float(a3)) self.assertEqual(guard_int(r), 6) self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertExpectedInline( - str(shape_env.guards[2][0]), """Eq(TruncToInt(2.0*ToFloat(s2)), 6)""" + str(shape_env.guards[2][0]), """Eq(Trunc(2.0*s2), 6)""" ) def test_sym_sqrt(self): @@ -476,7 +474,7 @@ def test_sym_sqrt(self): self.assertEqual(r, 2) self.assertIsInstance(r, torch.SymFloat, msg=type(r)) self.assertExpectedInline( - str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_sqrt(s0), 2.0)""" + str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_sqrt(s0), 2)""" ) def test_sym_floor(self): @@ -485,17 +483,11 @@ def test_sym_floor(self): r = math.floor(a0 / 2) self.assertEqual(r, 2) self.assertIsInstance(r, torch.SymInt, msg=type(r)) - self.assertExpectedInline( - str(shape_env.guards[0][0]), - """Eq(FloorToInt(IntTrueDiv(s0, 2)), 2)""", - ) + self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(floor(s0/2), 2)""") r = math.floor(3.0 * a0) self.assertEqual(r, 15) self.assertIsInstance(r, torch.SymInt, msg=type(r)) - self.assertExpectedInline( - str(shape_env.guards[1][0]), - """Eq(FloorToInt(3.0*ToFloat(s0)), 15)""", - ) + self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(3*s0, 15)""") def test_sym_trunc(self): shape_env = ShapeEnv() @@ -503,14 +495,12 @@ def test_sym_trunc(self): r = math.trunc(a0 / 2) self.assertEqual(r, 2) self.assertIsInstance(r, torch.SymInt, msg=type(r)) - self.assertExpectedInline( - str(shape_env.guards[0][0]), """Eq(TruncToInt(IntTrueDiv(s0, 2)), 2)""" - ) + self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(Trunc(s0/2), 2)""") r = torch.sym_int(torch.sym_sqrt(a0)) self.assertEqual(r, 2) self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertExpectedInline( - str(shape_env.guards[1][0]), """Eq(TruncToInt(OpaqueUnaryFn_sqrt(s0)), 2)""" + str(shape_env.guards[1][0]), """Eq(Trunc(OpaqueUnaryFn_sqrt(s0)), 2)""" ) def test_sym_ceil(self): @@ -520,17 +510,12 @@ def test_sym_ceil(self): self.assertEqual(r, 3) self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertExpectedInline( - str(shape_env.guards[0][0]), - """Eq(CeilToInt(IntTrueDiv(s0, 2)), 3)""", + str(shape_env.guards[0][0]), """Eq(ceiling(s0/2), 3)""" ) - r1 = 3.0 * a0 - r = math.floor(r1) + r = math.floor(3.0 * a0) self.assertEqual(r, 15) self.assertIsInstance(r, torch.SymInt, msg=type(r)) - self.assertExpectedInline( - str(shape_env.guards[1][0]), - """Eq(FloorToInt(3.0*ToFloat(s0)), 15)""", - ) + self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(3*s0, 15)""") def test_sym_ite(self): shape_env = ShapeEnv() @@ -977,14 +962,8 @@ def test_ephemeral_source_unified_with_non_ephemeral_source(self): ) class TestSymNumberMagicMethods(TestCase): def _do_test(self, fn, inp1, inp2, shape_env, is_unary_fn): - with self.subTest(fn=fn, inp1=inp1, inp2=inp2, is_unary_fn=is_unary_fn): - return self._do_test2(fn, inp1, inp2, shape_env, is_unary_fn) - - def _do_test2(self, fn, inp1, inp2, shape_env, is_unary_fn): # Helper function # NB: don't use one as that will get specialized - # TODO: We don't have to circuitously create the float, can just - # create a symfloat directly seed_node = (create_symint(shape_env, 2) / 2.0).node bool_seed_node = (create_symint(shape_env, 2) == 2).node @@ -997,42 +976,27 @@ def get_sym_inp(inp): else: return torch.SymFloat(to_node(seed_node, inp)) - if fn == "float_pow": - if inp1 < 0: - return - - if fn == "pow_by_natural": - if isinstance(inp1, float) or isinstance(inp2, float): - return - if inp2 < 0: - return - def maybe_xfail(inp1, inp2): if fn == "sym_sqrt" and inp1 < 0: # ValueError: math domain error return self.assertRaises((ValueError,)) - elif ( - fn in ("float_truediv", "int_truediv", "int_floordiv", "mod") - and inp2 == 0 - ): + elif fn in ("truediv", "floordiv", "mod") and inp2 == 0: # ZeroDivisionError: division by zero return self.assertRaises((ZeroDivisionError,)) - elif fn in ["float_pow", "pow_by_natural"] and inp1 == 0 and inp2 < 0: + elif fn == "pow" and inp1 == 0 and inp2 < 0: # ZeroDivisionError: 0.0 cannot be raised to a negative power return self.assertRaises((ZeroDivisionError,)) elif ( - # TODO: dear catastrophe waitress, - # this doesn't work - fn in ["float_pow", "pow_by_natural"] + fn == "pow" and inp1 < 0 + and inp2 in (2.5, -2.5) and ( - type(inp1) is (SymInt, SymFloat) or type(inp2) is (SymInt, SymFloat) + type(inp1) in (SymFloat, SymInt) or type(inp2) in (SymFloat, SymInt) ) - and (type(inp1) is (SymFloat, float) or type(inp2) is (SymFloat, float)) ): # Complex result, which we do not support: # TypeError: Cannot convert complex to float - return self.assertRaises((RuntimeError,)) + return self.assertRaises((TypeError,)) elif fn in ("lshift", "rshift") and not ( isinstance(inp1, (SymInt, int)) and isinstance(inp2, (SymInt, int)) ): @@ -1116,9 +1080,6 @@ def test_method(self, fn, first_type, second_type): ) and fn in sym_node.only_float_magic_methods: self.skipTest(f"{fn} is not an int method") - if second_type == "float" and fn in ["mod"]: - self.skipTest(f"{fn} only handles int") - is_unary_fn = fn in sym_node.unary_methods or fn == "round" # Second argument is ignored for unary function. So only run for one type if is_unary_fn and second_type == "float": @@ -1290,15 +1251,112 @@ def yield_test_cases(values, negate=True): yield (-x, -y) def test_floordiv_float_int(self): - values = ((7, 2),) + values = ( + (2.5, 2.1), + (2.1, 2.5), + (2.0, 2.1), + (7, 2.5), + (2.1, 7), + (7, 2), + ) for x, y in TestFloorDiv.yield_test_cases(values): self.assertEqual( TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y) ) + def test_floordiv_bool(self): + values = ( + (False, True), + (True, 2.5), + (2.5, True), + (False, 7), + (7, True), + ) + + for x, y in TestFloorDiv.yield_test_cases(values, negate=False): + # Compares to int since our FloorDiv has no bool support + self.assertEqual( + TestFloorDiv.python_floordiv(x, y), + TestFloorDiv.torch_floordiv(int(x), int(y)), + ) + # Tests that our impl throws + self.assertRaisesRegex( + TypeError, + ( + rf"unsupported operand type\(s\) for //: " + rf"'{type(sympy.sympify(x)).__name__}' and '{type(sympy.sympify(y)).__name__}'" + rf", expected integer or real" + ), + lambda: TestFloorDiv.torch_floordiv(x, y), + ) + + def test_floordiv_complex(self): + values = ( + (1.5 + 2.5j, 1.3 + 3.5j), + (1.5 + 2.5j, 2.5), + (2.5, 1.5 + 2.5j), + (1.5 + 2.5j, 7), + (7, 1.5 + 2.5j), + ) + + for x, y in TestFloorDiv.yield_test_cases(values): + # We don't test error messages to avoid depending on Python + # interpreter version + self.assertRaises(TypeError, lambda: TestFloorDiv.python_floordiv(x, y)) + self.assertRaisesRegex( + TypeError, + ( + rf"unsupported operand type\(s\) for //: " + rf"'{type(sympy.sympify(x)).__name__}' and '{type(sympy.sympify(y)).__name__}'" + rf", expected integer or real" + ), + lambda: TestFloorDiv.torch_floordiv(x, y), + ) + + def test_floordiv_div_by_zero(self): + values = ( + (2.5, 0), + (2.1, 0.0), + (2.3, sympy.Symbol("s", zero=True)), + ) + + for x, y in TestFloorDiv.yield_test_cases(values, negate=False): + # We don't test error messages to avoid depending on Python + # interpreter version + if type(y) is not sympy.Symbol: + self.assertRaises( + ZeroDivisionError, lambda: TestFloorDiv.python_floordiv(x, y) + ) + self.assertRaisesRegex( + ZeroDivisionError, + "division by zero", + lambda: TestFloorDiv.torch_floordiv(x, y), + ) + + def test_floordiv_zero_base(self): + values = ( + (0, 2.5), + (0.0, 2.1), + (sympy.Symbol("s", zero=True), 2.3), + ) + + for x, y in TestFloorDiv.yield_test_cases(values, negate=False): + if type(x) is not sympy.Symbol: + self.assertEqual( + TestFloorDiv.python_floordiv(x, y), + TestFloorDiv.torch_floordiv(x, y), + ) + else: + self.assertEqual(0, TestFloorDiv.torch_floordiv(x, y)) + def test_floordiv_div_by_one(self): - values = ((2, 1),) + values = ( + (2.5, 1), + (2.1, 1.0), + (2, 1.0), + (2, 1), + ) for x, y in TestFloorDiv.yield_test_cases(values): self.assertEqual( @@ -1309,7 +1367,12 @@ def test_floordiv_simplify(self): # Tests how we simplify or evaluate FloorDiv without free variables shape_env = ShapeEnv() result = 21 - exprs = (7 * FloorDiv(6, 2),) + exprs = ( + 7 * FloorDiv(6, 2), + 7 * FloorDiv(6.28, 2), + 7 * FloorDiv(6.28, 2.0), + 7 * FloorDiv(6.28, (FloorDiv(6.28, 3.14))), + ) for expr in exprs: self.assertEqual(expr, result) @@ -1319,10 +1382,33 @@ def test_floordiv_simplify(self): self.assertEqual(shape_env.simplify(expr), result) self.assertEqual(shape_env.evaluate_expr(expr), result) + def test_floordiv_simplify_rational(self): + result = 21 + + a = sympy.Symbol("a", integer=True) + b = sympy.Symbol("b") + + cases = [ + (FloorDiv(a, sympy.Rational(1, 8)), 8 * a), + (FloorDiv(b, sympy.Rational(1, 8)), sympy.floor(8 * b)), + ] + + for expr, expected in cases: + self.assertEqual(expr, expected) + def test_floordiv_assumptions(self): + # We define two Symbols (with different names) for each type to make + # sure the behavior is consistent regardless of whether both arguments + # are the same object or not. cases = ( sympy.Symbol("i1", integer=True), sympy.Symbol("i2", integer=True), + sympy.Symbol("r1", real=True), + sympy.Symbol("r2", real=True), + sympy.Symbol("c1", complex=True, real=False, integer=False), + sympy.Symbol("c2", complex=True, real=False, integer=False), + sympy.Symbol("s1"), + sympy.Symbol("s2"), ) for base, divisor in itertools.product(cases, repeat=2): diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 04483ffba0fc..c7b2e51ced20 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1618,8 +1618,7 @@ def f(a): self.assertExpectedInline(r, """\ def forward(self, a_1): sym_size_int = torch.ops.aten.sym_size.int(a_1, 0) - sym_float = torch.sym_float(sym_size_int); sym_size_int = None - pow_1 = sym_float ** 0.5; sym_float = None + pow_1 = sym_size_int ** 0.5; sym_size_int = None div = torch.ops.aten.div.Tensor(a_1, pow_1); a_1 = pow_1 = None return div""") diff --git a/test/test_sympy_utils.py b/test/test_sympy_utils.py index 8b16b2c620fd..c5da8f7fc0da 100644 --- a/test/test_sympy_utils.py +++ b/test/test_sympy_utils.py @@ -36,12 +36,7 @@ "floor", "ceil", ] -BINARY_OPS = [ - "truediv", "floordiv", - # "truncdiv", # TODO - # NB: pow is float_pow - "add", "mul", "sub", "pow", "pow_by_natural", "minimum", "maximum", "mod" -] +BINARY_OPS = ["truediv", "div", "floordiv", "truncdiv", "add", "mul", "sub", "pow", "minimum", "maximum", "mod"] UNARY_BOOL_OPS = ["not_"] BINARY_BOOL_OPS = ["or_", "and_"] @@ -86,24 +81,16 @@ def valid_unary(fn, v): def valid_binary(fn, a, b): if fn == "pow" and ( - # sympy will expand to x*x*... for integral b; don't do it if it's big - b > 4 - # no imaginary numbers - or a <= 0 - # 0**0 is undefined - or (a == b == 0) - ): - return False - elif fn == "pow_by_natural" and ( - # sympy will expand to x*x*... for integral b; don't do it if it's big b > 4 - or b < 0 - or (a == b == 0) + or ( # sympy will expand to x*x*... for integral b; don't do it if it's big + a <= 0 and b == -1 + ) + or (a == b == 0) # no imaginary numbers # 0**0 is undefined ): return False - elif fn == "mod" and (a < 0 or b <= 0): + elif fn == "mod" and b == 0: return False - elif (fn in ["div", "truediv", "floordiv"]) and b == 0: + elif (fn == "div" or fn == "truediv") and b == 0: return False return True @@ -143,26 +130,27 @@ def test_pow_half(self): ValueRangeAnalysis.pow(ValueRanges.unknown(), ValueRanges.wrap(0.5)) @parametrize("fn", BINARY_OPS) - @parametrize("dtype", ("int", "float")) - def test_binary_ref(self, fn, dtype): + @parametrize("dtype_a", ("int", "float")) + @parametrize("dtype_b", ("int", "float")) + def test_binary_ref(self, fn, dtype_a, dtype_b): to_dtype = {"int": sympy.Integer, "float": sympy.Float} - # Don't test float on int only methods - if dtype == "float" and fn in ["pow_by_natural", "mod"]: - return - dtype = to_dtype[dtype] + dtype_a = to_dtype[dtype_a] + dtype_b = to_dtype[dtype_b] for a, b in itertools.product(CONSTANTS, repeat=2): if not valid_binary(fn, a, b): continue - a = dtype(a) - b = dtype(b) + a = dtype_a(a) + b = dtype_b(b) with self.subTest(a=a, b=b): r = getattr(ValueRangeAnalysis, fn)(a, b) if r == ValueRanges.unknown(): continue ref_r = getattr(ReferenceAnalysis, fn)(a, b) - self.assertEqual(r.lower.is_integer, r.upper.is_integer) - self.assertEqual(ref_r.is_integer, r.upper.is_integer) + # sympy.floordiv does 1.0 // 1.0 == 1 rather than 1.0. wtf + if fn != "floordiv": + self.assertEqual(r.lower.is_integer, r.upper.is_integer) + self.assertEqual(ref_r.is_integer, r.upper.is_integer) self.assertEqual(r.lower, r.upper) self.assertEqual(ref_r, r.lower) @@ -212,8 +200,7 @@ def test_binary_bool_ref_range(self, fn): @parametrize("fn", UNARY_OPS) def test_unary_ref_range(self, fn): - # TODO: bring back sympy.oo testing for float unary fns - vals = CONSTANTS + vals = [-sympy.oo, *CONSTANTS, sympy.oo] for a in generate_range(vals): with self.subTest(a=a): ref_r = getattr(ValueRangeAnalysis, fn)(a) @@ -229,26 +216,40 @@ def test_unary_ref_range(self, fn): # This takes about 4s for all the variants @parametrize("fn", BINARY_OPS + COMPARE_OPS) def test_binary_ref_range(self, fn): - # TODO: bring back sympy.oo testing for float unary fns - vals = LESS_CONSTANTS + vals = [-sympy.oo, *LESS_CONSTANTS, sympy.oo] for a, b in itertools.product(generate_range(vals), repeat=2): # don't attempt pow on exponents that are too large (but oo is OK) if fn == "pow" and b.upper > 4 and b.upper != sympy.oo: continue with self.subTest(a=a, b=b): + ref_r = getattr(ValueRangeAnalysis, fn)(a, b) for a0, b0 in itertools.product(LESS_CONSTANTS, repeat=2): if a0 not in a or b0 not in b: continue if not valid_binary(fn, a0, b0): continue with self.subTest(a0=a0, b0=b0): - ref_r = getattr(ValueRangeAnalysis, fn)(a, b) r = getattr(ReferenceAnalysis, fn)( sympy.Integer(a0), sympy.Integer(b0) ) if r.is_finite: self.assertIn(r, ref_r) + def test_rational_bounds(self): + # Repro from https://github.com/pytorch/pytorch/issues/105097 + from sympy import floor, Eq + shape_0 = sympy.Symbol('shape_0', positive=True, integer=True) + new_expr = ( + Eq(30 * floor(4 * ((shape_0 + 1) // 96) * + ((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646)) / 647 + + 2584 * ((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646)) / 647), + 2880 * floor(((shape_0 + 1) // 96) * + ((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646)) / 15528 + + 323 * ((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646)) / 7764))) + new_range_env = {shape_0: ValueRanges(lower=1, upper=190)} + self.assertTrue(new_expr.subs({shape_0: 95})) + self.assertIn(True, sympy_interp(ValueRangeAnalysis, new_range_env, new_expr)) + class TestSympyInterp(TestCase): @parametrize("fn", UNARY_OPS + BINARY_OPS + UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS) @@ -257,13 +258,7 @@ def test_interp(self, fn): if fn in ("div", "truncdiv", "minimum", "maximum", "mod"): return - is_integer = None - if fn == "pow_by_natural": - is_integer = True - - x = sympy.Dummy('x', integer=is_integer) - y = sympy.Dummy('y', integer=is_integer) - + from sympy.abc import x, y vals = CONSTANTS if fn in {*UNARY_BOOL_OPS, *BINARY_BOOL_OPS}: vals = [True, False] @@ -305,17 +300,29 @@ def test_python_interp_fx(self, fn): if fn in {*BINARY_OPS, *BINARY_BOOL_OPS, *COMPARE_OPS}: arity = 2 - is_integer = None - if fn == "pow_by_natural": - is_integer = True - - x = sympy.Dummy('x', integer=is_integer) - y = sympy.Dummy('y', integer=is_integer) + from sympy.abc import x, y symbols = [x] if arity == 2: symbols = [x, y] + # Workaround mpf from symbol error + if fn == "minimum": + sympy_expr = sympy.Min(x, y) + elif fn == "maximum": + sympy_expr = sympy.Max(x, y) + else: + sympy_expr = getattr(ReferenceAnalysis, fn)(*symbols) + + if arity == 1: + def trace_f(px): + return sympy_interp(PythonReferenceAnalysis, {x: px}, sympy_expr) + else: + def trace_f(px, py): + return sympy_interp(PythonReferenceAnalysis, {x: px, y: py}, sympy_expr) + + gm = fx.symbolic_trace(trace_f) + for args in itertools.product(vals, repeat=arity): if arity == 1 and not valid_unary(fn, *args): continue @@ -323,28 +330,11 @@ def test_python_interp_fx(self, fn): continue if fn == "truncdiv" and args[1] == 0: continue - elif fn in ("pow", "pow_by_natural") and (args[0] == 0 and args[1] <= 0): + elif fn == "pow" and (args[0] == 0 and args[1] <= 0): continue elif fn == "floordiv" and args[1] == 0: continue with self.subTest(args=args): - # Workaround mpf from symbol error - if fn == "minimum": - sympy_expr = sympy.Min(x, y) - elif fn == "maximum": - sympy_expr = sympy.Max(x, y) - else: - sympy_expr = getattr(ReferenceAnalysis, fn)(*symbols) - - if arity == 1: - def trace_f(px): - return sympy_interp(PythonReferenceAnalysis, {x: px}, sympy_expr) - else: - def trace_f(px, py): - return sympy_interp(PythonReferenceAnalysis, {x: px, y: py}, sympy_expr) - - gm = fx.symbolic_trace(trace_f) - self.assertEqual( sympy_interp(PythonReferenceAnalysis, dict(zip(symbols, args)), sympy_expr), gm(*args) diff --git a/torch/__init__.py b/torch/__init__.py index 896a2c50c36d..16804ff75898 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -316,75 +316,6 @@ def __index__(self): # Magic methods installed by torch.fx.experimental.sym_node - def __round__(self, ndigits=None): - return self - - def __truediv__(self, other): - if isinstance(other, (builtins.float, SymFloat)): - return sym_float(self).__float_truediv__(other) - if not isinstance(other, (builtins.int, SymInt)): - return NotImplemented - return self.__int_truediv__(other) - - def __rtruediv__(self, other): - if isinstance(other, (builtins.float, SymFloat)): - return sym_float(self).__rfloat_truediv__(other) - if not isinstance(other, (builtins.int, SymInt)): - return NotImplemented - return self.__rint_truediv__(other) - - def __floordiv__(self, other): - if isinstance(other, (builtins.float, SymFloat)): - return torch.sym_float(math.floor(sym_float(self) / other)) - if not isinstance(other, (builtins.int, SymInt)): - return NotImplemented - return self.__int_floordiv__(other) - - def __rfloordiv__(self, other): - if isinstance(other, (builtins.float, SymFloat)): - return torch.sym_float(math.floor(other / sym_float(self))) - if not isinstance(other, (builtins.int, SymInt)): - return NotImplemented - return self.__rint_floordiv__(other) - - # nb: complex is impossible to handle correctly lol, with - # negative base and integral float need to diverge semantics and - # just always return complex. Neener neener pretend this problem - # doesn't exist - def __pow__(self, other): - if isinstance(other, (builtins.float, SymFloat)): - return sym_float(self).__pow__(other) - if not isinstance(other, (builtins.int, SymInt)): - return NotImplemented - # Guards! This guard is necessary because we need to know it to - # determine the output type of this operation - if other >= 0: - return self.__pow_by_natural__(other) - else: - # Mercifully, when the exponent is negative, Python just promotes - # to doubles and does a float pow: - # - # if (Py_SIZE(b) < 0 && c == NULL) { - # /* if exponent is negative and there's no modulus: - # return a float. This works because we know - # that this calls float_pow() which converts its - # arguments to double. */ - # Py_DECREF(a); - # Py_DECREF(b); - # return PyFloat_Type.tp_as_number->nb_power(v, w, x); - # } - return sym_float(self).__pow__(sym_float(other)) - - def __rpow__(self, other): - if isinstance(other, (builtins.float, SymFloat)): - return sym_float(self).__rpow__(other) - if not isinstance(other, (builtins.int, SymInt)): - return NotImplemented - if self >= 0: # self is exponent - return self.__rpow_by_natural__(other) - else: - return sym_float(self).__rpow__(sym_float(other)) - def __eq__(self, other: object) -> builtins.bool: raise AssertionError("type stub not overridden") @@ -406,24 +337,6 @@ def __add__(self, other) -> "SymInt": def __mul__(self, other) -> "SymInt": raise AssertionError("type stub not overridden") - def __pow_by_natural__(self, other) -> "SymInt": - raise AssertionError("type stub not overridden") - - def __rpow_by_natural__(self, other) -> "SymInt": - raise AssertionError("type stub not overridden") - - def __int_truediv__(self, other) -> "SymFloat": - raise AssertionError("type stub not overridden") - - def __rint_truediv__(self, other) -> "SymFloat": - raise AssertionError("type stub not overridden") - - def __int_floordiv__(self, other) -> "SymFloat": - raise AssertionError("type stub not overridden") - - def __rint_floordiv__(self, other) -> "SymFloat": - raise AssertionError("type stub not overridden") - def __sym_max__(self, other): raise AssertionError("type stub not overridden") @@ -458,43 +371,9 @@ def __init__(self, node): # class has a field named node that stores SymNode self.node = node - def __truediv__(self, other): - if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): - return NotImplemented - return self.__float_truediv__(sym_float(other)) - - def __rtruediv__(self, other): - if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): - return NotImplemented - return self.__rfloat_truediv__(sym_float(other)) - - def __floordiv__(self, other): - if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): - return NotImplemented - return torch.sym_float(math.floor(self / sym_float(other))) - - def __rfloordiv__(self, other): - if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): - return NotImplemented - return torch.sym_float(math.floor(sym_float(other) / self)) - def __bool__(self): return self.node.bool_() - # Symbolic power does NOT work with negative base, this is to avoid - # potential complex outputs - def __pow__(self, other): - if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): - return NotImplemented - torch._check(self >= 0) - return self.__float_pow__(other) - - def __rpow__(self, other): - if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): - return NotImplemented - torch._check(other >= 0) - return self.__rfloat_pow__(other) - # Magic methods installed by torch.fx.experimental.sym_node def __eq__(self, other: object) -> builtins.bool: @@ -512,18 +391,6 @@ def __le__(self, other) -> builtins.bool: def __ge__(self, other) -> builtins.bool: raise AssertionError("type stub not overridden") - def __float_pow__(self, other) -> "SymFloat": - raise AssertionError("type stub not overridden") - - def __rfloat_pow__(self, other) -> "SymFloat": - raise AssertionError("type stub not overridden") - - def __float_truediv__(self, other) -> "SymFloat": - raise AssertionError("type stub not overridden") - - def __rfloat_truediv__(self, other) -> "SymFloat": - raise AssertionError("type stub not overridden") - def __trunc__(self): raise AssertionError("type stub not overridden") @@ -657,12 +524,7 @@ def sym_int(a): return py_int(a) # type: ignore[operator] def sym_max(a, b): - """ - SymInt-aware utility for max which avoids branching on a < b. - Unlike builtins.max(), this only works for int/float, and it always - promotes to float if any argument is float (unlike builtins.max, which - will faithfully preserve the type of the input argument). - """ + """ SymInt-aware utility for max().""" from .overrides import has_torch_function, handle_torch_function if has_torch_function((a, b)): @@ -670,19 +532,14 @@ def sym_max(a, b): if isinstance(a, (SymInt, SymFloat)): return a.__sym_max__(b) elif isinstance(b, (SymInt, SymFloat)): - # Due to promotion semantics, this is operator is commutative: - # max(1, 1.0) === max(1.0, 1) === 1.0 + # NB: If you actually care about preserving output type exactly + # if you do something like max(0, 0.0), it is NOT sound to treat + # min/max as commutative return b.__sym_max__(a) - # TODO: Probably can make bool work too, just lazy - assert isinstance(a, (builtins.int, builtins.float)), type(a) - assert isinstance(b, (builtins.int, builtins.float)), type(b) - if isinstance(a, builtins.float) or isinstance(b, builtins.float): - return builtins.float(builtins.max(a, b)) - else: - return builtins.max(a, b) + return builtins.max(a, b) # type: ignore[operator] def sym_min(a, b): - """ SymInt-aware utility for min().""" + """ SymInt-aware utility for max().""" from .overrides import has_torch_function, handle_torch_function if has_torch_function((a, b)): @@ -691,12 +548,7 @@ def sym_min(a, b): return a.__sym_min__(b) elif isinstance(b, (SymInt, SymFloat)): return b.__sym_min__(a) - assert isinstance(a, (builtins.int, builtins.float)), type(a) - assert isinstance(b, (builtins.int, builtins.float)), type(b) - if isinstance(a, builtins.float) or isinstance(b, builtins.float): - return builtins.float(builtins.min(a, b)) - else: - return builtins.min(a, b) + return builtins.min(a, b) # type: ignore[operator] # Drop in replacement for math.sqrt, math.sin, math.cos etc def _get_sym_math_fn(name): diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index 9a92c238f950..8d6dc939fb5c 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -1474,15 +1474,10 @@ def deserialize_sym_int(self, s: SymInt) -> Union[int, torch.SymInt]: # Here we force symbols corresponding to SymInts to be at least integers. # Otherwise some expressions that the shape env would otherwise evaluate to False, # e.g., 2*s = 9, can have rational solutions, e.g., 9/2. - # TODO: This is HIGHLY SUSPICIOUS ezyang(May 2024) sym = sym.subs( {s: sympy.Symbol(s.name, integer=True) for s in sym.free_symbols} ) - # We need to check if the symbol has already been allocated, - # self.symbol_name_to_symbol is not enough because the - # integer-ification of symbols can induce simplification; - # e.g., (2**s0 + 1) // 2 --> s0 when we know s0 is integral - if isinstance(sym, sympy.Symbol) and sym not in self.shape_env.var_to_val: + if isinstance(sym, sympy.Symbol): self.symbol_name_to_symbol[val.expr_str] = sym if hint is not None: self.shape_env.add_var_to_val(sym, hint) @@ -1501,7 +1496,7 @@ def deserialize_sym_int(self, s: SymInt) -> Union[int, torch.SymInt]: free_symbols = sym.free_symbols for s in free_symbols: if s.name not in self.symbol_name_to_symbol: - self.symbol_name_to_symbol[s.name] = s # type: ignore[assignment] + self.symbol_name_to_symbol[s.name] = s if vr := self.symbol_name_to_range.get(s.name): self.shape_env.constrain_symbol_range( s, diff --git a/torch/_inductor/bounds.py b/torch/_inductor/bounds.py index 212b79e35bf9..4640ec4dce6b 100644 --- a/torch/_inductor/bounds.py +++ b/torch/_inductor/bounds.py @@ -1,4 +1,3 @@ -import logging import operator from functools import partial from typing import Any, Callable, Dict @@ -12,9 +11,6 @@ from .virtualized import V -log = logging.getLogger(__name__) - - class BoundVars: """ Performs Value Range Analysis on LoopBody's fx graph by calling BoundVars.run() @@ -59,7 +55,6 @@ def get_bounds(self) -> Dict[torch.fx.Node, ValueRanges[Expr]]: with V.set_ops_handler(ValueRangeAnalysis()): interpreter = InterpreterShim(self.loop_body.root_block.graph, submodules) - log.debug("get_bounds:\n%s", self.loop_body.root_block.graph) interpreter.run(V.get_ops_handler(), initial_env=self._bounds) return self._bounds diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index dae72186df00..f7b3e7a45d6e 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -340,8 +340,6 @@ def propagate_scheduler_node(cls, node): DataTypePropagation.propagate_loopbody(node._body) -# This printer contains rules that are supposed to be generic for both C/C++ and -# Python class ExprPrinter(Printer): @staticmethod def paren(string): @@ -371,6 +369,12 @@ def all_in_parens(string): return string return f"({string})" + def _print_Infinity(self, expr): + return "math.inf" + + def _print_NegativeInfinity(self, expr): + return "-math.inf" + def _print_Relational(self, expr): return f" {expr.rel_op} ".join(map(self.paren, map(self._print, expr.args))) @@ -380,14 +384,11 @@ def _print_Mul(self, expr): def _print_Add(self, expr): return " + ".join(map(self.paren, map(self._print, expr.args))) - # NB: this is OK to put here, because Mod is only defined for positive - # numbers, and so across C/Python its behavior is consistent def _print_Mod(self, expr): return " % ".join(map(self.paren, map(self._print, expr.args))) - def _print_FloatTrueDiv(self, expr): - lhs, rhs = expr.args - return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}" + def _print_FloorDiv(self, expr): + raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}") def _print_CleanDiv(self, expr): return self._print_FloorDiv(expr) @@ -398,84 +399,10 @@ def _print_GreaterThan(self, expr): # Go figure... return " >= ".join(map(self.paren, map(self._print, expr.args))) - # NB: The C implementation is injected into codegen at - # torch/_inductor/codegen/wrapper.py def _print_align(self, expr): assert len(expr.args) == 1 return f"align({self._print(expr.args[0])})" - # This must be implemented because sympy will collect x * x into Pow(x, 2), without - # any explicit intervention. We print it just like x * x, notably, we - # never generate sympy.Pow with floats. - # - # NB: this pow by natural, you should never have used builtin sympy.pow - # for FloatPow, and a symbolic exponent should be PowByNatural. These - # means exp is guaranteed to be integer. - def _print_Pow(self, expr): - base, exp = expr.args - base = self._print(base) - assert exp == int(exp), exp - exp = int(exp) - assert exp >= 0 - if exp > 0: - return "*".join([self.paren(base)] * exp) - else: # exp == 0 - return "1" - - # Explicit NotImplemented functions are to prevent default sympy printing - # behavior, which will just barf out ToFloat(...) to your IR. The error - # message is better here because it tells you which printer class it needs - # to go in. - - def _print_ToFloat(self, expr): - raise NotImplementedError(f"_print_ToFloat not implemented for {type(self)}") - - def _print_Infinity(self, expr): - raise NotImplementedError(f"_print_Infinity not implemented for {type(self)}") - - def _print_NegativeInfinity(self, expr): - raise NotImplementedError( - f"_print_NegativeInfinity not implemented for {type(self)}" - ) - - def _print_FloorDiv(self, expr): - raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}") - - def _print_PythonMod(self, expr): - raise NotImplementedError(f"_print_PythonMod not implemented for {type(self)}") - - def _print_IntTrueDiv(self, expr): - raise NotImplementedError(f"_print_IntTrueDiv not implemented for {type(self)}") - - def _print_PowByNatural(self, expr): - raise NotImplementedError( - f"_print_PowByNatural not implemented for {type(self)}" - ) - - def _print_FloatPow(self, expr): - raise NotImplementedError(f"_print_FloatPow not implemented for {type(self)}") - - def _print_TruncToInt(self, expr): - raise NotImplementedError(f"_print_TruncToInt not implemented for {type(self)}") - - def _print_RoundToInt(self, expr): - raise NotImplementedError(f"_print_RoundToInt not implemented for {type(self)}") - - def _print_RoundDecimal(self, expr): - raise NotImplementedError( - f"_print_RoundDecimal not implemented for {type(self)}" - ) - - # NB: Some float operations are INTENTIONALLY not implemented for - # printers. You can implement them as a quick unblock, but it is better - # to ask yourself why we haven't done this computation in the Tensor - # universe instead - - def _print_TruncToFloat(self, expr): - raise NotImplementedError( - f"_print_TruncToFloat not implemented for {type(self)}" - ) - def doprint(self, expr, *, simplify: bool = True): # TODO: why are people passing strings to the printer here :think: if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"): @@ -484,10 +411,6 @@ def doprint(self, expr, *, simplify: bool = True): class PythonPrinter(ExprPrinter): - def _print_ToFloat(self, expr): - assert len(expr.args) == 1 - return f"float({self._print(expr.args[0])})" - def _print_ModularIndexing(self, expr): x, div, mod = expr.args x = self.paren(self.doprint(x)) @@ -497,72 +420,56 @@ def _print_ModularIndexing(self, expr): x = f"({x} // {div})" return f"{x} % {mod}" - def _print_Infinity(self, expr): - return "math.inf" - - def _print_NegativeInfinity(self, expr): - return "-math.inf" - - # WARNING: this is dangerous for Triton, which has C-style modulus - def _print_PythonMod(self, expr): - return " % ".join(map(self.paren, map(self._print, expr.args))) - - # WARNING: this is dangerous for Triton, which has C-style modulus def _print_FloorDiv(self, expr): x, div = expr.args x = self.paren(self.doprint(x)) div = self.paren(self.doprint(div)) return f"({x} // {div})" - # WARNING: this is dangerous for Triton, when lhs, rhs > 2**53, Python - # does a special algorithm - def _print_IntTrueDiv(self, expr): - lhs, rhs = expr.args - return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}" - def _helper_sqrt(self, expr): return f"math.sqrt({self._print(expr)})" def _print_OpaqueUnaryFn_sqrt(self, expr): return self._helper_sqrt(expr.args[0]) - def _print_FloatPow(self, expr): - base, exp = expr.args - return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}" - - # TODO: Not sure this works with Triton, even when base/exp are integral - def _print_PowByNatural(self, expr): + def _print_Pow(self, expr): + # Pow() confuses triton base, exp = expr.args - return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}" + # NB: Remember this is sizevar computation! You don't typically + # expect to have to do floating point computation including exponents + # in sizevar compute. Instead of adding support for floating + # point pow, you should make upstream retranslate the Sympy expression + # into Tensor expressions earlier and do that instead. + if exp == 0.5: + return self._helper_sqrt(base) + elif exp == -0.5: + return "1/" + self._helper_sqrt(base) + base = self._print(base) + assert exp == int(exp), exp + exp = int(exp) + if exp > 0: + return "*".join([self.paren(base)] * exp) + elif exp < 0: + return "1/" + self.paren("*".join([self.paren(base)] * abs(exp))) + else: # exp == 0 + return "1" def _print_floor(self, expr): assert len(expr.args) == 1 return f"math.floor({self._print(expr.args[0])})" - def _print_FloorToInt(self, expr): - assert len(expr.args) == 1 - return f"math.floor({self._print(expr.args[0])})" - - def _print_TruncToInt(self, expr): + def _print_Trunc(self, expr): assert len(expr.args) == 1 - # This also could have been int(), they'll do the same thing for float return f"math.trunc({self._print(expr.args[0])})" def _print_ceiling(self, expr): assert len(expr.args) == 1 return f"math.ceil({self._print(expr.args[0])})" - def _print_CeilToInt(self, expr): - assert len(expr.args) == 1 - return f"math.ceil({self._print(expr.args[0])})" - def _print_Abs(self, expr): assert len(expr.args) == 1 return f"abs({self._print(expr.args[0])})" - # NB: It's expected that we've made explicit any promotion in the sympy - # expression, so it doesn't matter that Python max/min doesn't perform - # promotion def _print_Max(self, expr): assert len(expr.args) >= 2 return f"max({', '.join(map(self._print, expr.args))})" @@ -607,7 +514,7 @@ def _print_OpaqueUnaryFn_atan(self, expr): assert len(expr.args) == 1 return f"math.atan({self._print(expr.args[0])})" - def _print_RoundToInt(self, expr): + def _print_Round(self, expr): assert len(expr.args) == 1 return f"round({self._print(expr.args[0])})" @@ -746,29 +653,6 @@ def remainder(a, b): ) return ops.where(cond, ops.add(r, b), r) - @staticmethod - def trunc_to_int(a, dtype): - return ops.to_dtype(ops.trunc(a), dtype) - - @staticmethod - def floor_to_int(a, dtype): - return ops.to_dtype(ops.floor(a), dtype) - - @staticmethod - def ceil_to_int(a, dtype): - return ops.to_dtype(ops.ceil(a), dtype) - - @staticmethod - def round_to_int(a, dtype): - return ops.to_dtype(ops.round(a), dtype) - - @staticmethod - def int_truediv(a, b): - # TODO: this is wrong - # TODO: an easy bandaid is to generate runtime asserts that it's - # <= 2**53, which is when this equation is correct - return ops.truediv(a, b) - @staticmethod def load_seed(name, offset): return ops.load(name, sympy.Integer(offset)) diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 311781102c3f..eabb5bbef470 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -275,11 +275,11 @@ def visit_modular_indexing(divisor, modulus): original_index = index - div = sympy.Wild("divisor", integer=True) + div = sympy.Wild("divisor") if index.has(FloorDiv): index = index.replace(FloorDiv(var, div), visit_indexing_div) - mod = sympy.Wild("modulus", integer=True) + mod = sympy.Wild("modulus") if index.has(ModularIndexing): index = index.replace(ModularIndexing(var, div, mod), visit_modular_indexing) diff --git a/torch/_inductor/codegen/cpp_utils.py b/torch/_inductor/codegen/cpp_utils.py index aac0c20df0c6..4ab33a5e26dc 100644 --- a/torch/_inductor/codegen/cpp_utils.py +++ b/torch/_inductor/codegen/cpp_utils.py @@ -100,53 +100,10 @@ def _print_floor(self, expr): r = f"std::floor({self._print(expr.args[0])})" return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r - def _print_FloorToInt(self, expr): - assert len(expr.args) == 1 - r = f"std::floor({self._print(expr.args[0])})" - return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r - - def _print_TruncToInt(self, expr): + def _print_Trunc(self, expr): assert len(expr.args) == 1 r = f"std::trunc({self._print(expr.args[0])})" - return f"static_cast<{INDEX_TYPE}>({r})" - - def _print_TruncToFloat(self, expr): - assert len(expr.args) == 1 - return f"std::trunc({self._print(expr.args[0])})" - - def _print_ToFloat(self, expr): - assert len(expr.args) == 1 - return f"static_cast({self._print(expr.args[0])})" - - # TODO: This is wrong if one of the inputs is negative. This is hard to - # tickle though, as the inputs are typically positive (and if we can prove - # they are positive, we will have used Mod instead, for which this codegen - # is right). - def _print_PythonMod(self, expr): - return " % ".join(map(self.paren, map(self._print, expr.args))) - - def _print_CMod(self, expr): - return " % ".join(map(self.paren, map(self._print, expr.args))) - - def _print_IntTrueDiv(self, expr): - lhs, rhs = expr.args - # TODO: This is only accurate up to 2**53 - return f"static_cast({self._print(lhs)}) / static_cast({self._print(rhs)})" - - # TODO: PowByNatural: we need to implement our own int-int pow. Do NOT - # use std::pow, that operates on floats - def _print_PowByNatural(self, expr): - raise NotImplementedError( - f"_print_PowByNatural not implemented for {type(self)}" - ) - - def _print_FloatTrueDiv(self, expr): - lhs, rhs = expr.args - return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}" - - def _print_FloatPow(self, expr): - base, exp = expr.args - return f"std::pow({self._print(base)}, {self._print(exp)})" + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r def _print_Pow(self, expr): # Uses float constants to perform FP div @@ -182,11 +139,6 @@ def _print_ceiling(self, expr): r = f"std::ceil({self._print(expr.args[0])})" return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r - def _print_CeilToInt(self, expr): - assert len(expr.args) == 1 - r = f"std::ceil({self._print(expr.args[0])})" - return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r - def _print_Min(self, expr): args = [self._print(a) for a in expr.args] if len(args) == 2: @@ -248,9 +200,8 @@ def _print_OpaqueUnaryFn_atan(self, expr): def _print_OpaqueUnaryFn_sqrt(self, expr): return f"std::sqrt({self._print(expr.args[0])})" - def _print_RoundToInt(self, expr): + def _print_Round(self, expr): assert len(expr.args) == 1 - # TODO: dispatch to llrint depending on index type return f"std::lrint({self._print(expr.args[0])})" def _print_RoundDecimal(self, expr): diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 104d24585de2..daf329ee9b80 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -285,68 +285,23 @@ def triton_reshape(value: str, old_shape: List[str], new_shape: List[str]): return f"{value}[{', '.join(expand)}]" -# NB: Inheriting from PythonPrinter is somewhat dangerous, because there are a -# number of operators which Triton "implements", but in a way that is -# inconsistent with Python semantics (and consistent with C semantics). We -# must override all of these, or it is potential silent correctness problem class TritonPrinter(PythonPrinter): - def _print_TruncToInt(self, expr): - assert len(expr.args) == 1 - return ( - f"libdevice.trunc({self._print(expr.args[0])}).to({V.kernel.index_dtype})" - ) - - def _print_ToFloat(self, expr): - assert len(expr.args) == 1 - return f"{self.paren(self._print(expr.args[0]))}.to(tl.float64)" - - # TODO: This is wrong if one of the inputs is negative. This is hard to - # tickle though, as the inputs are typically positive (and if we can prove - # they are positive, we will have used Mod instead, for which this codegen - # is right). If you are trying to hit this, maybe try something like - # torch.arange(n, device="cuda") - 1 and then do a modulus on it - def _print_PythonMod(self, expr): - return " % ".join(map(self.paren, map(self._print, expr.args))) - - # TODO: This is wrong, see - # https://github.com/triton-lang/triton/issues/955 - # But for Sympy expressions, things will /mostly/ work out because we - # don't usually deal with negative numbers in the division - def _print_FloorDiv(self, expr): - assert expr.is_integer - x, div = expr.args - x = self.paren(self.doprint(x)) - div = self.paren(self.doprint(div)) - return f"({x} // {div})" - - # TODO: This is wrong, when lhs, rhs > 2**53, Python does a higher - # precision algorithm, which we would need to replicate here - def _print_IntTrueDiv(self, expr): - lhs, rhs = expr.args - return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}" - - # NB: sympy.floor/ceiling produce integers, so we have to do the - # conversion to index dtype def _print_floor(self, expr): assert len(expr.args) == 1 return ( f"libdevice.floor({self._print(expr.args[0])}).to({V.kernel.index_dtype})" ) - def _print_FloorToInt(self, expr): + def _print_Trunc(self, expr): assert len(expr.args) == 1 return ( - f"libdevice.floor({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + f"libdevice.trunc({self._print(expr.args[0])}).to({V.kernel.index_dtype})" ) def _print_ceiling(self, expr): assert len(expr.args) == 1 return f"libdevice.ceil({self._print(expr.args[0])}).to({V.kernel.index_dtype})" - def _print_CeilToInt(self, expr): - assert len(expr.args) == 1 - return f"libdevice.ceil({self._print(expr.args[0])}).to({V.kernel.index_dtype})" - def _helper_sqrt(self, expr): return f"libdevice.sqrt({self._print(expr)}.to(tl.float32))" @@ -417,9 +372,20 @@ def _print_OpaqueUnaryFn_atan(self, expr): assert len(expr.args) == 1 return f"libdevice.atan(({self._print(expr.args[0])}).to(tl.float32))" - def _print_RoundToInt(self, expr): + def _print_FloorDiv(self, expr): + if expr.is_integer: + return super()._print_FloorDiv(expr) + + x, div = expr.args + x = self.paren(self.doprint(x)) + div = self.paren(self.doprint(div)) + return f"libdevice.floor({x} / {div}).to({V.kernel.index_dtype})" + + def _print_Round(self, expr): assert len(expr.args) == 1 - return f"libdevice.llrint({self._print(expr.args[0])})" + return ( + f"libdevice.llrint({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + ) def _print_RoundDecimal(self, expr): assert len(expr.args) == 2 diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index d5ec55afd05e..ca739eecb196 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -1193,11 +1193,8 @@ def debug(msg): elif is_magic_method(n.target): # TODO: this is sus, it probably should be handled in the # lowerings themselves similarly to sym_size/sym-stride - # https://github.com/pytorch/pytorch/issues/127789 debug("is_magic_method") - if isinstance( - n.meta["val"], (torch.SymInt, torch.SymFloat, torch.SymBool) - ): + if isinstance(n.meta["val"], torch.SymInt): result = n.meta["val"].node.expr else: result = super().run_node(n) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index e9adfcd19a2d..c46cad5e41e2 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -44,6 +44,7 @@ is_boolean_dtype, is_float_dtype, make_channels_last_strides_for, + make_contiguous_strides_for, StrideType, ) from torch._subclasses.fake_tensor import get_schema_info @@ -235,7 +236,7 @@ def ir_node_to_tensor(x, guard_shape=True): if is_storage_and_layout(x): stride = [shape_fn(s) for s in x.get_layout().stride] # type: ignore[misc] else: - stride = FlexibleLayout.contiguous_strides(size) # type: ignore[arg-type] + stride = make_contiguous_strides_for(size) # type: ignore[arg-type] dtype = x.get_dtype() device = x.get_device() size = convert_shape_to_symint(size) @@ -2765,7 +2766,6 @@ class FlexibleLayout(Layout): allow_indexing = False - # WARNING! This doesn't handle zero size tensors correctly @staticmethod def contiguous_strides(sizes): if len(sizes) == 0: @@ -5915,7 +5915,7 @@ def _original_deconv_weight_size( # To align the behavior of the Conv kernel, we set the output_stride in such case to be contiguous instead of channels last. dynamic_shapes = not all(isinstance(i, int) for i in (output_size)) if dynamic_shapes and is_contiguous_storage_and_layout(x): - output_stride = FlexibleLayout.contiguous_strides(output_size) + output_stride = make_contiguous_strides_for(output_size) else: output_stride = make_channels_last_strides_for(output_size) @@ -5967,7 +5967,7 @@ def _prepare_linear_fusion_create( assert x.get_device().type == "cpu" and weight.get_device().type == "cpu" inputs = [x, weight] - output_stride = FlexibleLayout.contiguous_strides(output_size) + output_stride = make_contiguous_strides_for(output_size) kernel_layout = FixedLayout( x.get_device(), x.get_dtype(), @@ -6283,7 +6283,7 @@ def create(cls, x, packed_w, orig_w, B, batch_size): *m, _ = x.get_size() oc, _ = orig_w.get_size() output_size = list(m) + [oc] - output_stride = FlexibleLayout.contiguous_strides(output_size) + output_stride = make_contiguous_strides_for(output_size) inputs = [x, packed_w, orig_w] constant_args = [batch_size] if B is not None: @@ -6601,13 +6601,13 @@ def create( def get_strides_of_lstm_output(output_shape, batch_first): assert len(output_shape) == 3, "Expect output_shape to be 3D" - return FlexibleLayout.contiguous_strides(output_shape) + return make_contiguous_strides_for(output_shape) output_sizes = [output_shape, hy_shape, cy_shape] output_strides = [ get_strides_of_lstm_output(output_shape, batch_first), - FlexibleLayout.contiguous_strides(hy_shape), - FlexibleLayout.contiguous_strides(cy_shape), + make_contiguous_strides_for(hy_shape), + make_contiguous_strides_for(cy_shape), ] output_ir = [ MultiOutput( diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index f3492949a84d..42fabf65591d 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -5,6 +5,7 @@ from typing import Any, List, Tuple import torch +from torch._prims_common import make_contiguous_strides_for from .. import config from ..ir import ( ComputedBuffer, @@ -388,7 +389,7 @@ def flex_attention(*args, **kwargs): query.get_device(), query.get_dtype(), query.get_size(), - FlexibleLayout.contiguous_strides(query.get_size()), + make_contiguous_strides_for(query.get_size()), ) # see NOTE:[TritonTemplates with multiple outputs] logsumexp_shape = query.get_size()[:-1] # [B, H, M] @@ -744,7 +745,7 @@ def flex_attention_backward(*args, **kwargs): key.get_device(), key.get_dtype(), key.get_size(), - FlexibleLayout.contiguous_strides(key.get_size()), + make_contiguous_strides_for(key.get_size()), ) # Create delta which will is needed for the bwd's kernel diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index deec9b13e566..0a1909890e69 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -34,7 +34,7 @@ Number, ) from torch.fx.experimental.sym_node import magic_methods, method_to_operator -from torch.utils._sympy.functions import CeilDiv, FloorDiv, IntTrueDiv, ModularIndexing +from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing from .._dynamo.utils import import_submodule from . import config, inductor_prims, ir, test_operators # NOQA: F401 @@ -4262,7 +4262,7 @@ def _fractional_pooling_offsets(samples, in_sz, out_sz, kernel_sz, dim): out_sz = out_sz[dim] in_sz = in_sz[dim] kernel_sz = kernel_sz[dim] - alpha = IntTrueDiv(in_sz - kernel_sz, out_sz - 1) + alpha = (in_sz - kernel_sz) / (out_sz - 1) samples_loader = samples.make_loader() def load(prefix, i): @@ -4372,7 +4372,7 @@ def upsample_nearest2d_backward( w_kernel_max = ceildiv(inp_w, out_w) def start_index(index, out_dim, inp_dim): - return CeilDiv(index * inp_dim, sympy.sympify(out_dim)) + return CeilDiv(index * inp_dim, out_dim) def end_index(index, out_dim, inp_dim): return start_index((index + 1), out_dim, inp_dim) diff --git a/torch/_inductor/ops_handler.py b/torch/_inductor/ops_handler.py index f88cd948ca4d..5630061b4426 100644 --- a/torch/_inductor/ops_handler.py +++ b/torch/_inductor/ops_handler.py @@ -138,38 +138,6 @@ def to_dtype( """ ... - def trunc_to_int(self, x: T, dtype: torch.dtype) -> T: - """ - Convert x to dtype with truncation semantics (similar to how the int - constructor works in Python). In Inductor codegen, this just decays - to trunc and then to_dtype, but this composite operation helps - roundtrips for Sympy evaluation. - - dtype is taken as an explicit parameter because the desired output - dtype is typically the index dtype, which may vary between int32 and - int64 depending on if we've shown that all the indexing operations can - be done in int32. - """ - ... - - def ceil_to_int(self, x: T, dtype: torch.dtype) -> T: - """ - Convert x to dtype with ceiling semantics. See also trunc_to_int. - """ - ... - - def floor_to_int(self, x: T, dtype: torch.dtype) -> T: - """ - Convert x to dtype with ceiling semantics. See also trunc_to_int. - """ - ... - - def round_to_int(self, x: T, dtype: torch.dtype) -> T: - """ - Convert x to dtype with round-to-even semantics. See also trunc_to_int. - """ - ... - def to_dtype_bitcast(self, x: T, dtype: torch.dtype, src_dtype: torch.dtype) -> T: """ Reinterpret cast x to dtype (reinterpreting the bits in memory as another dtype.) @@ -430,23 +398,21 @@ def isinf(self, x0: T) -> T: def isnan(self, x0: T) -> T: ... - # NB: this returns a float, like the torch operation - # This rounds half to even to break ties def round(self, x0: T) -> T: ... - # NB: this returns a float, like the torch operation def floor(self, x0: T) -> T: ... def sign(self, x0: T) -> T: ... - # NB: this returns a float, like the torch operation + def to_int(self, x0: T) -> T: + ... + def trunc(self, x0: T) -> T: ... - # NB: this returns a float, like the torch operation def ceil(self, x0: T) -> T: ... @@ -483,7 +449,6 @@ def sub(self, x0: T, x1: T) -> T: def mul(self, x0: T, x1: T) -> T: ... - # NB: this returns a float, like the torch operation def pow(self, x0: T, x1: T) -> T: ... @@ -652,21 +617,14 @@ def truncdiv(self, x0: T, x1: T) -> T: def floordiv(self, x0: T, x1: T) -> T: """Python-style floor division between integers only. Computes the - true division of two numbers and floors the result. If you want - floor division for floats, do regular truediv and floor the result. + true division of two numbers and floors the result. """ ... def truediv(self, x0: T, x1: T) -> T: - """True division between floats. Integer inputs are NOT valid. To - do Python-style (int, int) -> float division, use int_truediv""" - ... - - def int_truediv(self, x0: T, x1: T) -> T: - """True division between integers. This is NOT the same as promoting - to float and doing integer division, there is a bespoke algorithm for - doing the division in higher precision than the above. - """ + """True division between floats. Integer inputs are NOT valid: to do + Python style (int, int) -> float division, promote the inputs to float + first.""" ... def div(self, x0: T, x1: T) -> T: @@ -682,10 +640,6 @@ def remainder(self, x0: T, x1: T) -> T: """Python-style modulus, take sign from RHS (x1).""" ... - def round_decimal(self, x0: T, x1: T) -> T: - """Python-style round with decimal argument""" - ... - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # In CUDA, optimized implementations of other mathematical operations are # offered separately via libdevice for double precision computation (in diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index a1b029aa2883..5e5cbf35baf9 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -386,7 +386,7 @@ def store_output( assert isinstance(mask, (str, type(None))) assert self.template_mask is None indices = list(map(TritonPrinter.paren, indices)) - index_symbols = [sympy.Symbol(x, integer=True) for x in indices] + index_symbols = [sympy.Symbol(x) for x in indices] lengths = [ V.graph.sizevars.simplify(s) for s in self.output_node.get_size() ] @@ -410,7 +410,7 @@ def store_output( output_index = self.output_node.get_layout().make_indexer()(index_symbols) output_index = self.rename_indexing(output_index) if output_index == contiguous_index: - output_index = sympy.Symbol("xindex", integer=True) + output_index = sympy.Symbol("xindex") epilogue_args = [val] for input_node in itertools.chain( diff --git a/torch/_inductor/sizevars.py b/torch/_inductor/sizevars.py index fba9a66f9237..bc8803a5e715 100644 --- a/torch/_inductor/sizevars.py +++ b/torch/_inductor/sizevars.py @@ -161,9 +161,9 @@ def visit_modular_indexing(base, divisor, modulus): if expr.has(ModularIndexing): expr = expr.replace( ModularIndexing( - sympy.Wild("base", integer=True), - sympy.Wild("divisor", integer=True), - sympy.Wild("modulus", integer=True), + sympy.Wild("base"), + sympy.Wild("divisor"), + sympy.Wild("modulus"), ), visit_modular_indexing, ) @@ -171,8 +171,8 @@ def visit_modular_indexing(base, divisor, modulus): if expr.has(FloorDiv): expr = expr.replace( FloorDiv( - sympy.Wild("base", integer=True), - sympy.Wild("divisor", integer=True), + sympy.Wild("base"), + sympy.Wild("divisor"), ), visit_indexing_div, ) @@ -604,11 +604,11 @@ def _join_dimensions_cached(expr: Expr) -> Expr: """ assert isinstance(expr, sympy.Add) - scale = sympy.Wild("scale", exclude=[0], integer=True) - base = sympy.Wild("base", integer=True) - divisor = sympy.Wild("divisor", integer=True) - mod1 = sympy.Wild("modulus", integer=True) - mod2 = sympy.Wild("modulus2", integer=True) + scale = sympy.Wild("scale", exclude=[0]) + base = sympy.Wild("base") + divisor = sympy.Wild("divisor") + mod1 = sympy.Wild("modulus") + mod2 = sympy.Wild("modulus2") for term1 in expr.args: m1 = term1.match(scale * ModularIndexing(base, divisor, mod1)) if m1: diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index a635c2f509c1..0915a8330c34 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -192,7 +192,7 @@ def ceildiv( numer: Union[int, sympy.Expr], denom: Union[int, sympy.Expr] ) -> Union[int, sympy.Expr]: if isinstance(numer, sympy.Expr) or isinstance(denom, sympy.Expr): - return CeilDiv(sympy.sympify(numer), sympy.sympify(denom)) + return CeilDiv(numer, denom) # TODO: There is a bug in a call to this function, to repro: # python benchmarks/dynamo/huggingface.py --inductor -d cuda --accuracy # --amp --only YituTechConvBert --dynamic-shapes diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 9343490de3e8..47d4abcf77b9 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -1727,7 +1727,7 @@ def go(t, real_t): for run_impl_check, op_impl in op_implementations_checks: if run_impl_check(func): op_impl_out = op_impl(self, func, *args, **kwargs) - if op_impl_out is not NotImplemented: + if op_impl_out != NotImplemented: return maybe_propagate_real_tensors(op_impl_out) def maybe_run_unsafe_fallback(error=None): diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index 2a3cb62c56d7..a7ce337f9ac8 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -1200,13 +1200,8 @@ void initJITBindings(PyObject* module) { SYMNODE_BINARY(sub) SYMNODE_BINARY(mul) SYMNODE_BINARY(truediv) - SYMNODE_BINARY(int_truediv) - SYMNODE_BINARY(float_truediv) SYMNODE_BINARY(pow) - SYMNODE_BINARY(float_pow) - SYMNODE_BINARY(pow_by_natural) SYMNODE_BINARY(floordiv) - SYMNODE_BINARY(int_floordiv) SYMNODE_BINARY(mod) SYMNODE_BINARY(eq) SYMNODE_BINARY(ne) diff --git a/torch/csrc/utils/python_symnode.h b/torch/csrc/utils/python_symnode.h index 15738b1a67e1..f8c710cf6579 100644 --- a/torch/csrc/utils/python_symnode.h +++ b/torch/csrc/utils/python_symnode.h @@ -198,34 +198,14 @@ class PythonSymNodeImpl : public c10::SymNodeImpl { return dispatch_common_(__func__, other); } - c10::SymNode float_truediv(const c10::SymNode& other) override { - return dispatch_common_(__func__, other); - } - - c10::SymNode int_truediv(const c10::SymNode& other) override { - return dispatch_common_(__func__, other); - } - c10::SymNode pow(const c10::SymNode& other) override { return dispatch_common_(__func__, other); } - c10::SymNode float_pow(const c10::SymNode& other) override { - return dispatch_common_(__func__, other); - } - - c10::SymNode pow_by_natural(const c10::SymNode& other) override { - return dispatch_common_(__func__, other); - } - c10::SymNode floordiv(const c10::SymNode& other) override { return dispatch_common_(__func__, other); } - c10::SymNode int_floordiv(const c10::SymNode& other) override { - return dispatch_common_(__func__, other); - } - c10::SymNode mod(const c10::SymNode& other) override { return dispatch_common_(__func__, other); } diff --git a/torch/export/dynamic_shapes.py b/torch/export/dynamic_shapes.py index e98e83af340f..43ab56c10501 100644 --- a/torch/export/dynamic_shapes.py +++ b/torch/export/dynamic_shapes.py @@ -1,6 +1,7 @@ import builtins import dataclasses import inspect +import math import sys import weakref from collections import defaultdict @@ -265,14 +266,11 @@ class _Constraint(_ConstraintTarget, metaclass=_ConstraintFactory): shared: Optional[_ConstraintTarget] = None debug_name: Optional[str] = None - def _clone_with_range(self, lower=0, upper=None): + def _clone_with_range(self, lower=0, upper=math.inf): # Import sympy locally from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint from torch.utils._sympy.value_ranges import ValueRanges - if upper is None: - upper = sys.maxsize - 1 - constraint_range = StrictMinMaxConstraint( vr=self.constraint_range.vr & ValueRanges(lower=lower, upper=upper), warn_only=False, @@ -500,6 +498,7 @@ def dynamic_dim(t: torch.Tensor, index: int, debug_name: Optional[str] = None): ) # Import sympy locally + import sympy from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint from torch.utils._sympy.value_ranges import ValueRanges @@ -509,7 +508,7 @@ def dynamic_dim(t: torch.Tensor, index: int, debug_name: Optional[str] = None): id(t), index, StrictMinMaxConstraint( - vr=ValueRanges(lower=0, upper=sys.maxsize - 1), warn_only=False + vr=ValueRanges(lower=0, upper=sympy.oo), warn_only=False ), debug_name=debug_name, ) diff --git a/torch/fx/experimental/recording.py b/torch/fx/experimental/recording.py index 28df3fddab0e..4bf9ebab17b3 100644 --- a/torch/fx/experimental/recording.py +++ b/torch/fx/experimental/recording.py @@ -277,13 +277,7 @@ def wrapper(*args, **kwargs): raise except Exception: - log.error( # noqa: G201 - "failed while running %s(*%s, **%s)", - name, - args[1:], - kwargs, - exc_info=log.isEnabledFor(logging.INFO), - ) + log.error("failed while running %s(*%s, **%s)", name, args[1:], kwargs) raise return wrapper diff --git a/torch/fx/experimental/sym_node.py b/torch/fx/experimental/sym_node.py index c7f0aba9fac4..98cba67a73a1 100644 --- a/torch/fx/experimental/sym_node.py +++ b/torch/fx/experimental/sym_node.py @@ -267,11 +267,8 @@ def mul(self, other) -> "SymNode": def mod(self, other) -> "SymNode": return self._mod(other) # type: ignore[attr-defined] - def float_pow(self, other) -> "SymNode": - return self._float_pow(other) # type: ignore[attr-defined] - - def pow_by_natural(self, other) -> "SymNode": - return self._pow_by_natural(other) # type: ignore[attr-defined] + def pow(self, other) -> "SymNode": + return self._pow(other) # type: ignore[attr-defined] def and_(self, other) -> "SymNode": return self._and_(other) # type: ignore[attr-defined] @@ -279,14 +276,11 @@ def and_(self, other) -> "SymNode": def or_(self, other) -> "SymNode": return self._or_(other) # type: ignore[attr-defined] - def float_truediv(self, other) -> "SymNode": - return self._float_truediv(other) # type: ignore[attr-defined] - - def int_truediv(self, other) -> "SymNode": - return self._int_truediv(other) # type: ignore[attr-defined] + def truediv(self, other) -> "SymNode": + return self._truediv(other) # type: ignore[attr-defined] - def int_floordiv(self, other) -> "SymNode": - return self._int_floordiv(other) # type: ignore[attr-defined] + def floordiv(self, other) -> "SymNode": + return self._floordiv(other) # type: ignore[attr-defined] def lshift(self, other) -> "SymNode": return self._lshift(other) # type: ignore[attr-defined] @@ -367,17 +361,6 @@ def sym_or(self, other): def sym_and(self, other): return self.and_(other) - # There is no int_truediv available from C++ - def truediv(self, other): - return self.float_truediv(other) - - def floordiv(self, other) -> "SymNode": - return self.int_floordiv(other) - - # We didn't bind integer pow in C++ - def pow(self, other): - return self.float_pow(other) - def is_non_overlapping_and_dense(self, sizes, strides): return self.is_non_overlapping_and_dense_indicator(sizes, strides).eq(to_node(self, 1)) # type: ignore[attr-defined] @@ -494,7 +477,7 @@ def is_constant(self): "eq": operator.eq, "floor": math.floor, "trunc": math.trunc, - "int_floordiv": operator.floordiv, + "floordiv": operator.floordiv, "ge": operator.ge, "gt": operator.gt, "is_integer": lambda x: x.is_integer(), @@ -506,8 +489,7 @@ def is_constant(self): "ne": operator.ne, "neg": operator.neg, "or": operator.or_, - "float_pow": operator.pow, - "pow_by_natural": operator.pow, + "pow": operator.pow, "round": builtins.round, "rshift": operator.rshift, "sub": operator.sub, @@ -516,14 +498,12 @@ def is_constant(self): "sym_max": sym_max, "sym_min": sym_min, "sym_not": sym_not, - "float_truediv": operator.truediv, - "int_truediv": operator.truediv, + "truediv": operator.truediv, } unary_magic_methods = { "abs", "sym_float", - "sym_int", "ceil", "floor", "neg", @@ -579,20 +559,20 @@ def fn(self): bool_magic_methods = only_bool_magic_methods | also_bool_magic_methods # Methods that are only for float -only_float_magic_methods = {"is_integer", "round", "sym_int"} +only_float_magic_methods = {"is_integer"} magic_methods_on_operator_with_trailing_underscore = {"and", "or"} -always_float_magic_methods = {"int_truediv", "float_truediv", "sym_float", "float_pow"} +always_float_magic_methods = {"truediv", "sym_float", "pow"} for name in math_op_names: sym_name = f"sym_{name}" always_float_magic_methods.add(sym_name) -always_int_magic_methods = {"ceil", "floor", "trunc", "pow_by_natural"} +always_int_magic_methods = {"ceil", "floor", "trunc"} always_bool_magic_methods = { "eq", "ne", @@ -610,16 +590,10 @@ def fn(self): # Methods that have a `__foo__` as well as `__rfoo__` -def _sympy_float_truediv(a, b): - from torch.utils._sympy.functions import FloatTrueDiv +def _sympy_truediv(a, b): + from torch.utils._sympy.functions import TrueDiv - return FloatTrueDiv(a, b) - - -def _sympy_int_truediv(a, b): - from torch.utils._sympy.functions import IntTrueDiv - - return IntTrueDiv(a, b) + return TrueDiv(a, b) def _sympy_floordiv(a, b): @@ -629,24 +603,15 @@ def _sympy_floordiv(a, b): def _sympy_mod(a, b): - from torch.utils._sympy.functions import Mod, PythonMod - - if a.is_nonnegative and b.is_nonnegative: - return Mod(a, b) - else: - return PythonMod(a, b) - + from torch.utils._sympy.functions import Mod -def _sympy_pow_by_natural(a, b): - from torch.utils._sympy.functions import PowByNatural + return Mod(a, b) - return PowByNatural(a, b) +def _sympy_pow(a, b): + from torch.utils._sympy.functions import Pow -def _sympy_float_pow(a, b): - from torch.utils._sympy.functions import FloatPow - - return FloatPow(a, b) + return Pow(a, b) def _sympy_and(a, b): @@ -678,13 +643,11 @@ def _sympy_rshift(a, b): "sub": operator.sub, "mul": operator.mul, "mod": _sympy_mod, - "pow_by_natural": _sympy_pow_by_natural, - "float_pow": _sympy_float_pow, + "pow": _sympy_pow, "and": _sympy_and, "or": _sympy_or, - "float_truediv": _sympy_float_truediv, - "int_truediv": _sympy_int_truediv, - "int_floordiv": _sympy_floordiv, + "truediv": _sympy_truediv, + "floordiv": _sympy_floordiv, "lshift": _sympy_lshift, "rshift": _sympy_rshift, } @@ -709,23 +672,21 @@ def _floor_ceil_helper(a, fn): def _sympy_floor(a): - from torch.utils._sympy.functions import FloorToInt + import sympy - return FloorToInt(a) + return _floor_ceil_helper(a, sympy.floor) -# NB: this is Python trunc semantics which returns an int. Do NOT use this to -# represent torch.trunc (which is float to float) def _sympy_trunc(a): - from torch.utils._sympy.functions import TruncToInt + from torch.utils._sympy.functions import Trunc - return TruncToInt(a) + return Trunc(a) def _sympy_ceil(a): - from torch.utils._sympy.functions import CeilToInt + import sympy - return CeilToInt(a) + return _floor_ceil_helper(a, sympy.ceiling) def _sympy_eq(a, b): @@ -810,28 +771,26 @@ def _sympy_abs(a): def _sympy_round(number, ndigits=None): - from torch.utils._sympy.functions import RoundDecimal, RoundToInt + from torch.utils._sympy.functions import Round, RoundDecimal if ndigits is None: - return RoundToInt(number) + return Round(number) else: return RoundDecimal(number, ndigits) def _sympy_sym_float(a): - from torch.utils._sympy.functions import ToFloat - - # NB: Cannot use a * 1.0 here, because 0 * 1.0 is 0 which incorrectly - # reports that it is an integer - return ToFloat(a) + # Cannot use sympy.Float(a) here, coz it expects python literals + # Multiply by 1.0 to cast to float. This is needed when the input + # is a SymInt which has the assumption that it is integer and + # SymPy will otherwise assume that return value cannot be a float. + return a * 1.0 def _sympy_is_integer(a): import sympy - from torch.utils._sympy.functions import ToFloat - - return sympy.Eq(ToFloat(sympy.floor(a)), a) + return sympy.Eq(sympy.floor(a), a) magic_methods = { @@ -1030,26 +989,9 @@ def binary_magic_impl(self, other): self, handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {}) ) assert isinstance(other, SymNode) + # TODO: consider constant prop here try: - if method == "mod": - from torch.utils._sympy.functions import Mod, PythonMod - - # Special handling for mod that requires access to the value - # ranges - shape_env = self.shape_env - if ( - self.expr.is_nonnegative - or shape_env.bound_sympy(self.expr).lower >= 0 - ) and ( - other.expr.is_nonnegative - or shape_env.bound_sympy(other.expr).lower >= 0 - ): - out = Mod(self.expr, other.expr) - else: - out = PythonMod(self.expr, other.expr) - else: - # TODO: consider constant prop here - out = func(self.expr, other.expr) + out = func(self.expr, other.expr) except Exception: log.warning("failed to eval %s(%s, %s)", method, self.expr, other.expr) raise @@ -1180,13 +1122,9 @@ def round_impl(self, ndigits=None): except Exception: log.warning("failed to eval %s(%s, ndigits=%s)", method, expr, ndigits) raise - out = safe_expand(out) - if ndigits is None: - pytype = int - else: - pytype = self.pytype + pytype = int if ndigits is None else self.pytype out_hint = None if self.hint is not None: @@ -1198,7 +1136,6 @@ def round_impl(self, ndigits=None): # hack down below works, because all round function down the line all take ndigits=None as default in their # signature. # TODO: Remove the args construction below if a different sentinel is used by FX. - # ezyang(May 2024): LOL args = [self.fx_node] if ndigits is not None: args.append(ndigits) @@ -1322,32 +1259,6 @@ def is_constant(x): return x.node.is_constant() return False - # Promotion rules for binary operations. NB: we preserve PYTHON semantics - # - if args are same type, do nothing - # - if one arg is float, promote other arg to float - # - nb: this applies to floordiv, even though output is integral - # (it's still float) - # - pow is funny business - # - if both ints - # - trigger a guard on exponent >= 0 - # - if non-negative, output is int - # - otherwise, output is float - # - otherwise, promote other arg to float - # - nb: complex is impossible to handle correctly lol, with - # negative base and integral float need to diverge semantics and - # just always return complex. Neener neener pretend this problem - # doesn't exist - # - equality is pain: Python does the fancy thing where it unpacks the - # mantissa from the float and then compares that against the int. - # Which means it is able to tell that - # 9007199254740993 != 9007199254740992. (rather than if the LHS was - # promoted to float, in which case it would have truncated to the RHS - # and subsequently been equal). We'll model this exactly by having - # special mixed type equality operations. Unfortunately, we need to - # do this for all comparison operations (maybe I'll only implement - # compare) - # - sym_ite mumble mumble really shouldn't allow mixed but whatever - if method in bool_becomes_int_magic_methods: def promote(x): @@ -1361,41 +1272,6 @@ def promote(x): def promote(x): return x - def promote2(self, other): - # TODO: Remove eq and other relations from this list. - # CPython has fancy implementations for these to get as much precision - # as possible instead of just promoting to float64 and praying, so we - # need to handle them specially too. - # Also, note that int_truediv doesn't go through this path: both - # arguments are "int" so there isn't any promotion - if method not in [ - "add", - "sub", - "mul", - "mod", - "float_pow", - "float_truediv", - "int_floordiv", - "sym_min", - "sym_max", - # TODO: remove these - "eq", - "ne", - "gt", - "lt", - "le", - "ge", - ]: - return self, other - f_self = isinstance(self, (float, torch.SymFloat)) - f_other = isinstance(other, (float, torch.SymFloat)) - if f_self or f_other: - if not f_self: - self = torch.sym_float(self) - if not f_other: - other = torch.sym_float(other) - return self, other - # Before and after performing the operation, check if any operands are constant. # If so, extract out the constant values first. If `self` itself is a # constant, then "redispatch" by calling back into the operator. Sometimes @@ -1410,12 +1286,9 @@ def unary_magic_impl(self): return wrap_node(getattr(self.node, method_attr)()) def binary_magic_impl(self, other): - if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)): - return NotImplemented sym_node_log.debug("MAGIC %s %s %s", method, self, other) self = promote(self) other = promote(other) - self, other = promote2(self, other) if is_constant(self): return (method_to_operator(method))(get_constant(self), other) if is_constant(other): @@ -1427,11 +1300,8 @@ def binary_magic_impl(self, other): return get_constant(ret) if is_constant(ret) else ret def rbinary_magic_impl(self, other): - if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)): - return NotImplemented self = promote(self) other = promote(other) - self, other = promote2(self, other) if is_constant(self): return (method_to_operator(method))(get_constant(self), other) if is_constant(other): diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 42ab606e7827..544950298861 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -61,7 +61,7 @@ from torch import SymBool, SymFloat, SymInt from torch._guards import ShapeGuard, Source, TracingContext from torch.utils._python_dispatch import is_traceable_wrapper_subclass -from torch.utils._sympy.functions import FloorDiv, Mod, PythonMod, IsNonOverlappingAndDenseIndicator, CleanDiv +from torch.utils._sympy.functions import FloorDiv, Mod, IsNonOverlappingAndDenseIndicator from torch.utils._sympy.solve import try_solve from torch.utils._sympy.value_ranges import bound_sympy, SymPyValueRangeAnalysis, ValueRanges, ValueRangeError from torch.utils._sympy.singleton_int import SingletonInt @@ -869,9 +869,9 @@ def constrain_range(a, *, min: Optional[int], max: Optional[int] = None): for N=1. """ if min is None: - min = -sys.maxsize - 1 + min = -sympy.oo if max is None: - max = sys.maxsize - 1 + max = sympy.oo if max < min: raise ValueError( @@ -979,6 +979,16 @@ def eval_guards(gm, *args, ignore_static=True): def bind_symbols(gm, *args): return gm.shape_env.bind_symbols(fx_placeholder_vals(gm), args) +def _assert_bound_is_rational(expr: sympy.Expr, bound: ValueRanges): + """ + We assert that the bounds are either Boolean, or not finite, or can be computed + in exact prevision via rational arithmetic. + The only exception to this is the rare case when the user calls `sqrt(s0)` + sqrt is turned into sympy.Pow so we just match for that (it matches more things, but still) + """ + assert bound.lower.is_rational or bound.lower.is_Boolean or not bound.lower.is_finite or expr.has(sympy.Pow), (bound, expr) + assert bound.upper.is_rational or bound.upper.is_Boolean or not bound.upper.is_finite or expr.has(sympy.Pow), (bound, expr) + class DimDynamic(Enum): """ Controls how to perform symbol allocation for a dimension. It is always @@ -1377,19 +1387,14 @@ def cast_symbool_to_symint_guardless(symbool: torch.SymBool) -> torch.SymInt: 'Min': min, 'Max': max, 'Mod': operator.mod, - 'PythonMod': operator.mod, 'FloorDiv': operator.floordiv, 'TrueDiv': operator.truediv, 'IsNonOverlappingAndDenseIndicator': eval_is_non_overlapping_and_dense, 'floor': math.floor, 'ceiling': math.ceil, - 'FloorToInt': math.floor, - 'CeilToInt': math.ceil, 'cast_symbool_to_symint_guardless': cast_symbool_to_symint_guardless, - 'RoundToInt': builtins.round, + 'Round': builtins.round, 'RoundDecimal': builtins.round, - 'TruncToInt': math.trunc, - 'IntTrueDiv': operator.truediv, } @@ -1637,17 +1642,10 @@ def floor_div_handler(*args): congruence = (base - mod_reduced) % divisor if congruence != 0: self._congruences[s].add(congruence) - # NB: Must not be CleanDiv, it needs to be regular sympy division - # so inequality solver works. This is sort of problematic for - # is_integer tests though haha return (base - mod_reduced) / divisor if expr.has(Mod): expr = expr.replace(Mod, mod_handler) - # 7 // -3 is -3, 7 % -3 is -2, and 7 - (-2) / -3 is -3.0 so negative - # arguments should be OK. - if expr.has(PythonMod): - expr = expr.replace(PythonMod, mod_handler) if expr.has(FloorDiv): expr = expr.replace(FloorDiv, floor_div_handler) return expr @@ -3327,7 +3325,6 @@ def create_unbacked_symfloat(self): self.pending_fresh_unbacked_symbols.append(symbol) self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1) vr = self.var_to_range[symbol] = ValueRanges.unknown() - assert vr.is_float # Create a new FX placeholder and Z3 variable for 'symbol'. fx_node = self._create_fx_placeholder_and_z3var(symbol, float) @@ -3346,7 +3343,6 @@ def create_unbacked_symint(self): self.counter["create_unbacked_symbol"] += 1 self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1) vr = self.var_to_range[symbol] = self._default_unspecified_value_range() - assert vr.is_int # Create a new FX placeholder and Z3 variable for 'symbol'. fx_node = self._create_fx_placeholder_and_z3var(symbol, int) @@ -3370,7 +3366,6 @@ def create_unbacked_symbool(self): self.counter["create_unbacked_symbol"] += 1 self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1) vr = self.var_to_range[symbol] = ValueRanges(0, 1) - assert vr.is_int # Create a new FX placeholder and Z3 variable for 'symbol'. fx_node = self._create_fx_placeholder_and_z3var(symbol, bool) @@ -3516,7 +3511,6 @@ def create_symbol( self.var_to_range[sympy_expr] &= constraint_dim.vr vr = self.var_to_range[sympy_expr] - assert vr.is_int if val not in vr: raise ConstraintViolationError(f"{val} not in range [{vr.lower}, {vr.upper}]") @@ -3525,7 +3519,6 @@ def create_symbol( elif isinstance(val, float): self.var_to_range[sympy_expr] = vr = ValueRanges(-sympy.oo, sympy.oo) range_str = f"[{vr.lower}, {vr.upper}]" - assert vr.is_float else: # Skip var_range logic for SingletonInt # Only used for jagged layout nested tensors @@ -3575,7 +3568,6 @@ def create_symbol( def add_var_to_val(self, expr: sympy.Symbol, val: int): """ Adds a new symbol to the symbolic environment. """ - log.debug("add_var_to_val %s %s", expr, val, stack_info=True) assert expr not in self.var_to_val, f"{expr} already exists" self.var_to_val[expr] = sympy.Integer(val) @@ -4330,8 +4322,7 @@ def bound_sympy(self, expr: sympy.Expr, size_oblivious: bool = False) -> ValueRa # Clamp values of size-like variables for x in self.size_like & var_to_range.keys(): if var_to_range[x] is not None: - var_to_range[x] = ValueRanges(2, sys.maxsize - 1) - assert var_to_range[x].is_int + var_to_range[x] = ValueRanges(2, sympy.oo) return bound_sympy(expr, var_to_range) @_lru_cache @@ -4448,11 +4439,6 @@ def _maybe_evaluate_static( vr = self._default_unspecified_value_range() if size_oblivious and k in self.size_like: lower = max(2, vr.lower) - # This is a bit dodgy: what this means is that there was a - # size-like unbacked symbol whose upper bound < 2. This - # causes... problems. - if lower <= vr.upper: - vr = ValueRanges(lower, vr.upper) else: lower = vr.lower # Don't do anything if we don't have a nontrivial lower bound @@ -4460,17 +4446,10 @@ def _maybe_evaluate_static( # SymInt if ( lower < (-sys.maxsize - 1) // 2 or - (unbacked_only and k in self.var_to_val) or - not vr.is_int + (unbacked_only and k in self.var_to_val) ): new_range_env[k] = vr continue - # The goal is to take our symbols which have various lower bounds - # and reallocate them into new symbols which are exactly positive; - # e.g., if we have s0 in [2, inf], we want to turn it into ess0 in - # [1, inf], where s0 = ess0 + 1. This gives the most information - # to sympy for subsequent simplifications. - # # Positive means >= 1 # Positive - 1 means >= 0 # Positive + lower - 1 means >= lower @@ -4502,14 +4481,6 @@ def replace(expr, repl): self.counter["sympy_recursion_error"] += 1 return None - new_expr = safe_expand(new_expr) - if new_expr.is_number: - return new_expr - - # This is bad to do, the replacement with division leaves us with - # rationals when atom.args[0] is addition, e.g., sympy will happily - # turn (s0 + s1) // 2 into s0 / 2 + s1 / 2. Needless complication! - """ floor_div_replace = {} for atom in new_expr.atoms(FloorDiv): floor_div_replace[atom] = sympy.floor(atom.args[0] / atom.args[1]) @@ -4518,12 +4489,13 @@ def replace(expr, repl): # are still free symbols if new_expr.is_number: return new_expr - """ # Check if the range can solve it statically out = bound_sympy(new_expr, new_range_env) - if out.is_singleton(): - return out.lower + if expect_rational: + _assert_bound_is_rational(new_expr, out) + if out.is_singleton(): + return out.lower return new_expr if unbacked_only else None @@ -4575,7 +4547,7 @@ def simplify(self, expr: "sympy.Expr") -> "sympy.Expr": for fd in expr.atoms(FloorDiv): base, divisor = fd.args if self.replace(Mod(base, divisor)) in self.divisible: - div_replacements[fd] = CleanDiv(base, divisor) + div_replacements[fd] = base / divisor new_expr = expr.xreplace(div_replacements) new_expr = safe_expand(new_expr) new_pows = new_expr.atoms(sympy.Pow) @@ -4719,10 +4691,7 @@ def _set_replacement(self, a: "sympy.Symbol", tgt: "sympy.Expr", msg: str) -> No int_range = ValueRanges(-sys.maxsize - 1, sys.maxsize - 1) def issubset(x, y): - if x.is_int and y.is_int: - return (x & int_range).issubset(y & int_range) - else: - return x.issubset(y) + return (x & int_range).issubset(y & int_range) # First, refine the value range of a based on the computed value range # of tgt. This is always OK to do, even if we decide not to do the @@ -4740,7 +4709,7 @@ def issubset(x, y): b = next(iter(tgt.free_symbols)) # Try to invert the equality r = try_solve(sympy.Eq(a, tgt), b, floordiv_inequality=False) - if r is not None and all(t.is_integer for t in sympy.preorder_traversal(r[1])): + if r is not None: b_bound = self.bound_sympy(r[1]) self.var_to_range[b] = b_bound & self.var_to_range[b] tgt_bound = self.bound_sympy(tgt) @@ -4951,12 +4920,12 @@ def trivial_solve(lhs, rhs): ): # We have Mod(i0, q / c) == 0, which means we can # rewrite i0 as (q / gcd(q, c)) * i1 - d = q / sympy.gcd(q, c) # TODO: CleanDiv? + d = q / sympy.gcd(q, c) i1 = self.create_unbacked_symint().node.expr # Propagate the value ranges. It doesn't really # matter if we use truediv or floordiv, because we # have established divisibility. - self._update_var_to_range(i1, SymPyValueRangeAnalysis.floordiv( + self._update_var_to_range(i1, SymPyValueRangeAnalysis.truediv( self.var_to_range[i0], ValueRanges.wrap(d) )) # Propagate size-like-ness @@ -5393,6 +5362,7 @@ def _refine_ranges(self, expr: sympy.Expr) -> None: lower, upper = vr.lower, vr.upper rhs_vr = bound_sympy(rhs, self.var_to_range) + _assert_bound_is_rational(rhs, rhs_vr) # Let's suppose that we have a preexisting range for x [0, 100]. # Now, we issue a guard x > y, where the range for y is [50, 150]. diff --git a/torch/fx/experimental/validator.py b/torch/fx/experimental/validator.py index d06b38d60c80..6dcb59db7979 100644 --- a/torch/fx/experimental/validator.py +++ b/torch/fx/experimental/validator.py @@ -216,7 +216,10 @@ def sqrt(self, number: z3.ArithRef) -> z3.ArithRef: def abs(self, number: z3.ArithRef) -> z3.ArithRef: return z3.Abs(number) - def round_to_int(self, number: z3.ArithRef) -> z3.ArithRef: + def round(self, number: z3.ArithRef, ndigits: Optional[z3.ArithRef] = None) -> z3.ArithRef: + if ndigits is not None: + raise ValueError("round(..., ndigits=) is currently not supported by shape validations.") + # Pythons builtin 'round' implements the 'round half to even' strategy # See https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even # z3 has an equivalent z3.fpRoundToIntegral(z3.RoundNearestTiesToEven(), ...), but this only applies to @@ -281,7 +284,7 @@ def wrapper(*args): operator.truediv: lift(ops.div), operator.mod: lift(ops.mod), operator.abs: lift(ops.abs), - builtins.round: lift(ops.round_to_int), + builtins.round: lift(ops.round), # Math module. math.ceil: lift(ops.ceil), @@ -347,7 +350,6 @@ def __init__( self._ops = _Z3Ops(self._validator) def constant(self, value: Any, dtype: torch.dtype) -> z3.ExprRef: - # TODO: Probably OK to relax this and allow lower precision if dtype is torch.int64: return z3.IntVal(int(value)) if dtype is torch.double: @@ -356,20 +358,6 @@ def constant(self, value: Any, dtype: torch.dtype) -> z3.ExprRef: return z3.BoolVal(bool(value)) raise ValueError(f"unsupported dtype (SympyToZ3): {dtype}") - def to_dtype(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: - if dtype == torch.float64: - return z3.ToReal(x) - raise NotImplementedError(f"to_dtype {dtype} NYI") - - def trunc_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: - return z3.ToInt(x) - - def round_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: - return self._ops.round_to_int(x) - - def int_truediv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: - return self._ops.div(numerator, denominator) - def truediv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: return self._ops.div(numerator, denominator) @@ -382,17 +370,11 @@ def div(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: def pow(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef: return self._ops.pow(base, exp) - def pow_by_natural(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef: - return self._ops.pow(base, exp) - def mod(self, p: z3.ArithRef, q: z3.ArithRef) -> z3.ArithRef: return self._ops.mod(p, q) - def ceil_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: - return self._ops.ceil(x) - - def floor_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: - return self._ops.floor(x) + def round(self, number: z3.ArithRef, ndigits: Optional[z3.ArithRef] = None) -> z3.ArithRef: + return self._ops.round(number, ndigits) def __getattr__(self, name: str) -> Any: REPLACEMENT = { diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index 128ce537c019..1384261b4512 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -1,78 +1,43 @@ -import functools import math -import sys import sympy from sympy import S +from sympy.core.logic import fuzzy_and, fuzzy_not, fuzzy_or __all__ = [ "FloorDiv", "ModularIndexing", "CleanDiv", "CeilDiv", - "IntTrueDiv", - "FloatTrueDiv", + "Pow", + "TrueDiv", "LShift", "RShift", "IsNonOverlappingAndDenseIndicator", - "RoundToInt", + "Round", "RoundDecimal", - "ToFloat", - "FloatPow", - "PowByNatural", ] -def _keep_float(f): - @functools.wraps(f) - def inner(*args): - r = f(*args) - if any(isinstance(a, sympy.Float) for a in args) and not isinstance( - r, sympy.Float - ): - r = sympy.Float(float(r)) - return r - - return inner - - def fuzzy_eq(x, y): if None in (x, y): return None return x == y -# It would be nice to have assertions on whether or not inputs is_integer -# However, with bugs like https://github.com/sympy/sympy/issues/26620 sympy -# sometimes inconsistently reports floats an integers. -# -# What we can assume from sympy is that if something is an int, it -# definitely is is_integer, but if it is a float it may or may not -# be is_integer. So we are unable to do strong asserts that things -# are NOT integers. - - -# TODO: In Triton, // rounds to zero, but in Python, it is floor division. -# When we can prove both arguments are non-negative, we should just have a -# GenericFloorDiv (name pending) which can codegen efficiently in Python/C, -# and then PythonFloorDiv and CIntDiv which have the appropriate rounding -# semantics. -# -# Right now, FloorDiv de facto changes behavior if arguments are negative or -# not, this can potentially cause correctness issues. class FloorDiv(sympy.Function): """ We maintain this so that: 1. We can use divisibility guards to simplify FloorDiv(a, b) to a / b. 2. Printing out the expression is nicer (compared to say, representing a//b as (a - a % b) / b) - - NB: This is Python-style floor division, round to -Inf """ nargs = (2,) precedence = 50 # precedence of mul # noqa: F811 - is_integer = True + # Default return type for SymPy assumptions. + # https://docs.sympy.org/latest/guides/assumptions.html#implementing-assumptions-handlers + is_real = True @property def base(self): @@ -87,14 +52,29 @@ def _sympystr(self, printer): divisor = printer.parenthesize(self.divisor, self.precedence) return f"({base}//{divisor})" + # SymPy assumptions based on argument types. + def _eval_is_real(self): + return fuzzy_or([self.base.is_real, self.divisor.is_real]) + + def _eval_is_integer(self): + return fuzzy_and([self.base.is_integer, self.divisor.is_integer]) + # Automatic evaluation. # https://docs.sympy.org/latest/guides/custom-functions.html#best-practices-for-eval @classmethod def eval(cls, base, divisor): - # python test/test_dynamic_shapes.py -k TestDimConstraints.test_dim_constraints_solve_full - # Assert triggered by inequality solver - # assert base.is_integer, base - # assert divisor.is_integer, divisor + def check_supported_type(x): + if ( + x.is_integer is False and x.is_real is False and x.is_complex + ) or x.is_Boolean: + raise TypeError( + f"unsupported operand type(s) for //: " + f"'{type(base).__name__}' and '{type(divisor).__name__}'" + f", expected integer or real" + ) + + check_supported_type(base) + check_supported_type(divisor) # We don't provide the same error message as in Python because SymPy # makes it difficult to check the types. @@ -105,22 +85,26 @@ def eval(cls, base, divisor): return sympy.S.Zero if base.is_integer and divisor == 1: return base + if base.is_real and divisor == 1: + return sympy.floor(base) if base.is_integer and divisor == -1: return sympy.Mul(base, -1) if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer): - return sympy.Integer(int(base) // int(divisor)) + return base // divisor + if isinstance(base, (sympy.Integer, sympy.Float)) and isinstance( + divisor, (sympy.Integer, sympy.Float) + ): + return sympy.floor(base / divisor) if isinstance(base, FloorDiv): return FloorDiv(base.args[0], base.args[1] * divisor) + if isinstance(divisor, sympy.Rational) and divisor.p == 1: + return sympy.floor(base * divisor.q) - # gcd in sympy is over polynomials, so you'll end up with rationals if - # you do this. Don't. - """ if isinstance(base, sympy.Add): for a in base.args: gcd = sympy.gcd(a, divisor) if gcd == divisor: return FloorDiv(base - a, divisor) + a / gcd - """ try: gcd = sympy.gcd(base, divisor) @@ -205,19 +189,6 @@ class Where(sympy.Function): nargs = (3,) - def _eval_is_integer(self): - return True if self.args[1].is_integer and self.args[2].is_integer else None # type: ignore[attr-defined] - - def _eval_is_nonnegative(self): - return ( - True - if self.args[1].is_nonnegative and self.args[2].is_nonnegative # type: ignore[attr-defined] - else None - ) - - def _eval_is_positive(self): - return True if self.args[1].is_positive and self.args[2].is_positive else None # type: ignore[attr-defined] - @classmethod def eval(cls, c, p, q): if c == sympy.true: @@ -226,27 +197,28 @@ def eval(cls, c, p, q): return q -# Python-style modulus: take sign from RHS -class PythonMod(sympy.Function): - nargs = (2,) +class Mod(sympy.Function): + """ + We maintain this so that we avoid SymPy correctness issues, such as: + https://github.com/sympy/sympy/issues/25146 + """ - is_integer = True + nargs = (2,) @classmethod def eval(cls, p, q): - # python test/dynamo/test_export.py -k ExportTests.test_trivial_constraint - # Triggered by sympy.solvers.inequalities.reduce_inequalities - # assert p.is_integer, p - # assert q.is_integer, q + # This was adapted from: sympy/core/mod.py if q.is_zero: raise ZeroDivisionError("Modulo by zero") - + # If either of them is NaN or infinite. + if p is S.NaN or q is S.NaN or p.is_finite is False or q.is_finite is False: + return S.NaN # Three cases: # 1. p == 0 # 2. p is either q or -q # 3. p is integer and q == 1 - if p is S.Zero or p in (q, -q) or q == 1: + if p is S.Zero or p in (q, -q) or (p.is_integer and q == 1): return S.Zero # Evaluate if they are both literals. @@ -275,7 +247,10 @@ def eval(cls, p, q): if sympy.Mod(p, q) == 0: return S.Zero - # NB: args[1] for PythonMod + def _eval_is_integer(self): + p, q = self.args + return fuzzy_and([p.is_integer, q.is_integer, fuzzy_not(q.is_zero)]) # type: ignore[attr-defined] + def _eval_is_nonnegative(self): return True if self.args[1].is_positive else None # type: ignore[attr-defined] @@ -283,58 +258,6 @@ def _eval_is_nonpositive(self): return True if self.args[1].is_negative else None # type: ignore[attr-defined] -# Generic modulus: only defined on non-negative arguments -class Mod(sympy.Function): - nargs = (2,) - - is_integer = True - is_nonnegative = True - - @classmethod - def eval(cls, p, q): - # This was adapted from: sympy/core/mod.py - - # Triggered by - # python test/test_dynamic_shapes.py -k TestDimConstraints.test_dim_constraints_solve_full - # assert p.is_integer, p - # assert q.is_integer, q - - if q.is_zero: - raise ZeroDivisionError("Modulo by zero") - - # Three cases: - # 1. p == 0 - # 2. p is either q or -q - # 3. p is integer and q == 1 - if p is S.Zero or p in (q, -q) or q == 1: - return S.Zero - - # Evaluate if they are both literals. - if q.is_Number and p.is_Number: - assert p >= 0, p - assert q >= 1, q - return p % q - - # If q == 2, it's a matter of whether p is odd or even. - if q.is_Number and q == 2: - if p.is_even: - return S.Zero - if p.is_odd: - return S.One - - # If p is a multiple of q. - r = p / q - if r.is_integer: - return S.Zero - - # If p < q and its ratio is positive, then: - # - floor(p / q) = 0 - # - p % q = p - floor(p / q) * q = p - less = p < q - if less.is_Boolean and bool(less) and r.is_positive: - return p - - class CleanDiv(FloorDiv): """ Div where we can assume no rounding. @@ -344,36 +267,6 @@ class CleanDiv(FloorDiv): pass -# Don't use sympy ceiling/floor as they will attempt simplifications involving -# frac -class CeilToInt(sympy.Function): - is_integer = True - - @classmethod - def eval(cls, number): - # assert number.is_integer is not True, number - if number == sympy.oo: - return sympy.Integer(sys.maxsize - 1) - if number == -sympy.oo: - return sympy.Integer(-sys.maxsize - 1) - if isinstance(number, sympy.Number): - return sympy.Integer(math.ceil(float(number))) - - -class FloorToInt(sympy.Function): - is_integer = True - - @classmethod - def eval(cls, number): - # assert number.is_integer is not True, number - if number == sympy.oo: - return sympy.Integer(sys.maxsize - 1) - if number == -sympy.oo: - return sympy.Integer(-sys.maxsize - 1) - if isinstance(number, sympy.Number): - return sympy.Integer(math.floor(float(number))) - - class CeilDiv(sympy.Function): """ Div used in indexing that rounds up. @@ -382,8 +275,6 @@ class CeilDiv(sympy.Function): is_integer = True def __new__(cls, base, divisor): - base = sympy.sympify(base) - divisor = sympy.sympify(divisor) if sympy.gcd(base, divisor) == divisor: return CleanDiv(base, divisor) else: @@ -391,8 +282,6 @@ def __new__(cls, base, divisor): class LShift(sympy.Function): - is_integer = True - @classmethod def eval(cls, base, shift): if shift < 0: @@ -401,8 +290,6 @@ def eval(cls, base, shift): class RShift(sympy.Function): - is_integer = True - @classmethod def eval(cls, base, shift): if shift < 0: @@ -410,107 +297,28 @@ def eval(cls, base, shift): return base // 2**shift -def safe_pow(base, exp): - sign = 1 - if base < 0: - base = -base - sign = 1 if exp % 2 == 0 else -1 - return sign * _safe_pow(base, exp) - - -def _safe_pow(base, exponent): - if exponent < 0: - raise ValueError("Exponent must be non-negative.") - - if exponent == 0: - return 1 - - half_exp = safe_pow(base, exponent // 2) - if half_exp > sys.maxsize - 1: - return sys.maxsize - 1 - - result = half_exp * half_exp - if result > sys.maxsize - 1: - return sys.maxsize - 1 - - if exponent % 2 == 1: - result *= base - if result > sys.maxsize - 1: - return sys.maxsize - 1 - - return result - - -class PowByNatural(sympy.Function): - is_integer = True - - @classmethod - def eval(cls, base, exp): - if isinstance(base, sympy.Number) and isinstance(exp, sympy.Number): - return sympy.Integer(safe_pow(base, exp)) - if isinstance(exp, sympy.Integer): - # Translate power into iterated multiplication - r = sympy.Integer(1) - for _ in range(int(exp)): - r *= base - return r - # NB: do NOT translate into sympy.Pow, we will lose knowledge that exp - # is a natural number if we do - - -# base is assumed to be nonnegative, thereby prevent complex numbers from -# occuring -class FloatPow(sympy.Function): - is_integer = False - is_real = True - +# Overloaded to be compatible with regular Python. +# https://github.com/pytorch/pytorch/issues/90900 +class Pow(sympy.Function): @classmethod def eval(cls, base, exp): - if isinstance(base, sympy.Number) and isinstance(exp, sympy.Number): - return sympy.Float(float(base) ** float(exp)) - # NB: do not do any nontrivial reasoning + if exp.is_zero: + return sympy.Integer(1) + elif base.is_zero and exp < 0: + raise ZeroDivisionError(f"{base} cannot be raised to a negative power") + else: + return base**exp # Overloaded to be compatible with regular Python. # https://github.com/pytorch/pytorch/issues/90900 -# -# In particular, sympy division is willing to simplify x/x == 1 -# where 1 is an integer, but this must be a float if x was float. -class FloatTrueDiv(sympy.Function): - is_integer = False - is_real = True - - @classmethod - def eval(cls, base, divisor): - # assert base.is_integer is not True, base - # assert divisor.is_integer is not True, divisor - - if divisor.is_zero: - raise ZeroDivisionError("division by zero") - - if isinstance(base, sympy.Number) and isinstance(divisor, sympy.Number): - return sympy.Float(float(base) / float(divisor)) - - -# Overloaded to be compatible with regular Python. We distinguish this from -# FloatTrueDiv, because the code generation has to be different for this case: -# Python has a fancy algorithm for integer true division that isn't just -# "promote both arguments to float and use float division", so you need to -# codegen it differently. While technically you can work it out from the -# types of the input, this is often inconvenient to do in Inductor codegen, -# so just have a different operator -# NB: Right now, Inductor codegen doesn't implement this correctly lol -class IntTrueDiv(sympy.Function): - is_integer = False - is_real = True - +class TrueDiv(sympy.Function): @classmethod def eval(cls, base, divisor): if divisor.is_zero: raise ZeroDivisionError("division by zero") - - if isinstance(base, sympy.Number) and isinstance(divisor, sympy.Number): - return sympy.Float(int(base) / int(divisor)) + else: + return base / divisor # TODO: As an indicator, this != 0 implies == 1 (and vice versa). @@ -545,85 +353,45 @@ def eval(cls, *args): return None -# NB: this is inconsistent with math.trunc in Python -class TruncToFloat(sympy.Function): - is_integer = False - is_real = True - - @classmethod - def eval(cls, number): - # assert number.is_integer is not True, number - if isinstance(number, sympy.Number): - # NB: It is safe to use truncation to integer, which is what - # math.trunc does, as Python integers are arbitrary precision and - # so we are guaranteed not to lose precision when we do this - return sympy.Float(math.trunc(float(number))) - - -class TruncToInt(sympy.Function): +class Trunc(sympy.Function): is_integer = True @classmethod def eval(cls, number): - # assert number.is_integer is not True, number - if number == sympy.oo: - return sympy.Integer(sys.maxsize - 1) - if number == -sympy.oo: - return sympy.Integer(-sys.maxsize - 1) - if isinstance(number, sympy.Number): + if number.is_integer: + return number + elif isinstance(number, sympy.Number): return sympy.Integer(math.trunc(float(number))) -# This is float -> int -class RoundToInt(sympy.Function): +class Round(sympy.Function): is_integer = True @classmethod def eval(cls, number): - # assert number.is_integer is not True, number - - if isinstance(number, sympy.Float): - return sympy.Integer(round(float(number), 0)) - + if number.is_integer: + return number + elif isinstance(number, sympy.Number): + return sympy.Integer(round(float(number))) -# To get float -> int, Python style round semantics. -# -# x = PyFloat_AsDouble(self); -# if (o_ndigits == Py_None) { -# /* single-argument round or with None ndigits: -# * round to nearest integer */ -# rounded = round(x); -# if (fabs(x-rounded) == 0.5) -# /* halfway case: round to even */ -# rounded = 2.0*round(x/2.0); -# return PyLong_FromDouble(rounded); -# } + def __int__(self): + # This will only ever be called when computing size hints. At that point, self.args[0] should be a number and + # no longer an expression. If it were, the float call would fail and the caller would handle this further. + return round(float(self.args[0])) # type: ignore[arg-type] -# NB: Like Round, this only ever returns floats. ndigits cannot be None class RoundDecimal(sympy.Function): - is_integer = False - is_real = True - @classmethod def eval(cls, number, ndigits): - # assert number.is_integer is not True, number - - if isinstance(number, sympy.Float) and isinstance(ndigits, sympy.Integer): - return sympy.Float(round(float(number), int(ndigits))) - - -class ToFloat(sympy.Function): - is_integer = False - is_real = True - - @classmethod - def eval(cls, number): - if number in [sympy.oo, -sympy.oo]: + if number.is_integer and ndigits >= 0: return number - - if isinstance(number, sympy.Integer): - return sympy.Float(int(number)) + elif isinstance(number, sympy.Number) and isinstance(ndigits, sympy.Integer): + value_type, output_type = ( + (int, sympy.Integer) + if isinstance(number, sympy.Integer) + else (float, sympy.Float) + ) + return output_type(round(value_type(number), int(ndigits))) def make_opaque_unary_fn(name): diff --git a/torch/utils/_sympy/interp.py b/torch/utils/_sympy/interp.py index 09a4b8384749..806e91cfe281 100644 --- a/torch/utils/_sympy/interp.py +++ b/torch/utils/_sympy/interp.py @@ -15,23 +15,16 @@ import torch from .functions import ( - CeilToInt, CleanDiv, - FloatPow, - FloatTrueDiv, FloorDiv, - FloorToInt, - IntTrueDiv, IsNonOverlappingAndDenseIndicator, Mod, ModularIndexing, - PowByNatural, - PythonMod, + Pow, + Round, RoundDecimal, - RoundToInt, - ToFloat, - TruncToFloat, - TruncToInt, + TrueDiv, + Trunc, Where, ) @@ -56,39 +49,30 @@ def handlers(): sympy.Le: "le", sympy.Ge: "ge", sympy.Not: "not_", - IntTrueDiv: "int_truediv", - FloatTrueDiv: "truediv", + TrueDiv: "truediv", FloorDiv: "floordiv", - CleanDiv: "floordiv", # TODO: hmm? - TruncToFloat: "trunc", + CleanDiv: "div", + Trunc: "trunc", Where: "where", sympy.Add: "add", sympy.Mul: "mul", - FloatPow: "pow", - PowByNatural: "pow_by_natural", - # sympy simplifies x * x into Pow(x, 2), so we need to handle this. - # Do NOT use builtin Pow for floats - # TODO: There is a hazard here, if we have float * float it will - # also get turned into Pow(float, 2) but we don't want this because - # pow_by_natural is assumed to only be integers. Probably the fix is - # to add a FloatMul to impede this optimization - sympy.Pow: "pow_by_natural", + Pow: "pow", + sympy.Pow: "pow", Mod: "mod", - PythonMod: "mod", # TODO: this is wrong - # TODO: Inductor can generate these, but it's ill-specified which - # semantics were intended here. Needs to be cleaned up along with - # FloorDiv in a bigger cleanup sympy.Mod: "mod", sympy.Abs: "abs", sympy.log: "log", sympy.exp: "exp", + sympy.floor: "floor", + sympy.ceiling: "ceil", sympy.Min: "minimum", sympy.Max: "maximum", ModularIndexing: "modular_indexing", sympy.functions.elementary.piecewise.ExprCondPair: "expr_cond_pair", sympy.Piecewise: "piecewise", IsNonOverlappingAndDenseIndicator: "is_non_overlapping_and_dense_indicator", - RoundDecimal: "round_decimal", + Round: "round", + RoundDecimal: "round", } for name in ["cos", "sin", "tan", "sinh", "cosh", "tanh", "asin", "acos", "atan"]: HANDLERS[getattr(sympy, name)] = name @@ -100,11 +84,7 @@ def handlers(): def sympy_interp( - analysis, - env: Dict[sympy.Symbol, Any], - expr: Union[sympy.Expr, SympyBoolean], - *, - index_dtype=torch.int64, + analysis, env: Dict[sympy.Symbol, Any], expr: Union[sympy.Expr, SympyBoolean] ): # Handle base cases dtype = None @@ -125,32 +105,9 @@ def sympy_interp( expr.args[1], sympy.core.numbers.Half ): return analysis.sqrt(sympy_interp(analysis, env, expr.args[0])) - if isinstance(expr, ToFloat): - return analysis.to_dtype( - sympy_interp(analysis, env, expr.args[0]), torch.float64 - ) # Recursive case args = [sympy_interp(analysis, env, arg) for arg in expr.args] # type: ignore[arg-type] - - # These handlers are special because they take an extra dtype argument - # specifying what they should convert to, and we need to appropriately set - # this up when we convert from Sympy. A reasonable default when you - # are translating is to conservatively do int64, and then narrow these - # arguments later when you discover you can narrow the index range. But - # if you already know that 32-bit indexing is OK, you can directly do the - # sympy translation with index_dtype=torch.int32 - INDEX_DTYPE_HANDLERS = { - TruncToInt: "trunc_to_int", - sympy.floor: "floor_to_int", - sympy.ceiling: "ceil_to_int", - FloorToInt: "floor_to_int", - CeilToInt: "ceil_to_int", - RoundToInt: "round_to_int", - } - if (handler_name := INDEX_DTYPE_HANDLERS.get(expr.func)) is not None: - return getattr(analysis, handler_name)(*args, index_dtype) - if hasattr(expr.func, "_torch_handler_name"): handler_name = expr.func._torch_handler_name else: diff --git a/torch/utils/_sympy/reference.py b/torch/utils/_sympy/reference.py index b54a0d0503a1..881b9d616eb5 100644 --- a/torch/utils/_sympy/reference.py +++ b/torch/utils/_sympy/reference.py @@ -1,25 +1,12 @@ import math -import operator - import sympy import torch from torch.utils._sympy.functions import ( - _keep_float, - FloatPow, - FloatTrueDiv, - FloorDiv, - IntTrueDiv, - Mod, OpaqueUnaryFn_exp, OpaqueUnaryFn_log, OpaqueUnaryFn_sqrt, - PowByNatural, - RoundDecimal, - RoundToInt, - ToFloat, - TruncToInt, ) @@ -75,41 +62,20 @@ def not_(a): @staticmethod def reciprocal(x): - return FloatTrueDiv(1.0, x) + return 1 / x @staticmethod def square(x): - return PowByNatural(x, 2) - - @staticmethod - def trunc_to_int(x, dtype): - return TruncToInt(x) - - @staticmethod - def ceil_to_int(x, dtype): - return sympy.ceiling(x) - - @staticmethod - def floor_to_int(x, dtype): - return sympy.floor(x) - - @staticmethod - def floor(x): - return _keep_float(sympy.floor)(x) - - @staticmethod - def ceil(x): - return _keep_float(sympy.ceiling)(x) - - @staticmethod - def to_dtype(x, dtype): - if dtype == torch.float64: - return ToFloat(x) - raise NotImplementedError(f"to_dtype {dtype} NYI") + return x * x @staticmethod def mod(x, y): - return Mod(x, y) + ret = abs(x) % abs(y) + # without check: + # tracing will fail trying to go through control-flow if x is Proxy() + if isinstance(x, (int, sympy.Number)) and x < 0: + ret *= -1 + return ret @staticmethod def abs(x): @@ -121,31 +87,37 @@ def neg(x): @staticmethod def truediv(a, b): - return FloatTrueDiv(a, b) + return a / b @staticmethod - def int_truediv(a, b): - return IntTrueDiv(a, b) + def div(a, b): + return ReferenceAnalysis.truediv(a, b) @staticmethod def floordiv(a, b): - return FloorDiv(a, b) + if b == 0: + return sympy.nan if a == 0 else sympy.zoo + return a // b @staticmethod def truncdiv(a, b): - raise NotImplementedError("TODO: truncdiv") + result = a / b + if result.is_finite: + result = sympy.Integer(result) + + return result @staticmethod def add(a, b): - return _keep_float(operator.add)(a, b) + return a + b @staticmethod def mul(a, b): - return _keep_float(operator.mul)(a, b) + return a * b @staticmethod def sub(a, b): - return _keep_float(operator.sub)(a, b) + return a - b @staticmethod def exp(x): @@ -161,27 +133,39 @@ def sqrt(x): @staticmethod def pow(a, b): - return _keep_float(FloatPow)(a, b) - - @staticmethod - def pow_by_natural(a, b): - return PowByNatural(a, b) + return a**b @staticmethod def minimum(a, b): - return sympy.Min(a, b) + # Poorman's version of upcasting in Sympy + # This won't do for sympy.Expr as the casting does nothing for those + if a.is_Float or not a.is_finite or b.is_Float or not b.is_finite: + result_type = sympy.Float + else: + assert a.is_Integer + assert b.is_Integer + result_type = sympy.Integer + return sympy.Min(result_type(a), result_type(b)) @staticmethod def maximum(a, b): - return sympy.Max(a, b) + # Poorman's version of upcasting in Sympy + # This won't do for sympy.Expr as the casting does nothing for those + if a.is_Float or not a.is_finite or b.is_Float or not b.is_finite: + result_type = sympy.Float + else: + assert a.is_Integer + assert b.is_Integer + result_type = sympy.Integer + return sympy.Max(result_type(a), result_type(b)) @staticmethod - def round_to_int(a, dtype): - return RoundToInt(a) + def floor(x): + return sympy.floor(x) @staticmethod - def round_decimal(a, b): - return RoundDecimal(a, b) + def ceil(x): + return sympy.ceiling(x) # Unlike ReferenceAnalysis, does NOT sympyify, instead, works with plain @@ -207,20 +191,10 @@ def not_(a): def floordiv(a, b): return a // b - @staticmethod - def mod(x, y): - return x % y - @staticmethod def truncdiv(a, b): return a / b - @staticmethod - def to_dtype(x, dtype): - if dtype == torch.float64: - return float(x) - raise NotImplementedError(f"to_dtype {dtype} NYI") - @staticmethod def exp(x): raise AssertionError("exp is not valid shape sympy expr") @@ -241,41 +215,10 @@ def minimum(a, b): def maximum(a, b): return torch.sym_max(a, b) - @staticmethod - def floor_to_int(x, dtype): - return math.floor(x) - - @staticmethod - def ceil_to_int(x, dtype): - return math.ceil(x) - @staticmethod def floor(x): - return float(math.floor(x)) + return math.floor(x) @staticmethod def ceil(x): - return float(math.ceil(x)) - - @staticmethod - def truediv(a, b): - return a / b - - @staticmethod - def pow(a, b): - return a**b - - @staticmethod - def pow_by_natural(a, b): - # Pray that safe_pow is not needed here lol. In particular, this - # never participates in VR low/high ranges, so overflow should be - # unlikely - return a**b - - @staticmethod - def round_to_int(a, dtype): - return round(a) - - @staticmethod - def round_decimal(a, b): - return round(a, ndigits=b) + return math.ceil(x) diff --git a/torch/utils/_sympy/solve.py b/torch/utils/_sympy/solve.py index 02ddf7c34219..6276c696293c 100644 --- a/torch/utils/_sympy/solve.py +++ b/torch/utils/_sympy/solve.py @@ -88,7 +88,6 @@ def try_solve( # Return if we were able to isolate 'thing' on the left-hand side. if isinstance(e, sympy.Rel) and e.lhs == thing: - log.debug("solved: %s ---> %s", expr, e) return e, e.rhs return None diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py index 4d364d4981b5..c7cc96beb980 100644 --- a/torch/utils/_sympy/value_ranges.py +++ b/torch/utils/_sympy/value_ranges.py @@ -5,7 +5,6 @@ import logging import math import operator -import sys from typing import ( Callable, Dict, @@ -26,20 +25,17 @@ from torch._prims_common import dtype_to_type from .functions import ( - _keep_float, - FloatTrueDiv, - FloorDiv, - IntTrueDiv, + OpaqueUnaryFn_acos, + OpaqueUnaryFn_asinh, + OpaqueUnaryFn_atan, + OpaqueUnaryFn_cosh, OpaqueUnaryFn_exp, OpaqueUnaryFn_log, + OpaqueUnaryFn_sinh, OpaqueUnaryFn_sqrt, - PowByNatural, + OpaqueUnaryFn_tanh, + Round, RoundDecimal, - RoundToInt, - safe_pow, - ToFloat, - TruncToFloat, - TruncToInt, ) from .interp import sympy_interp @@ -124,8 +120,6 @@ class ValueRanges(Generic[_T]): lower: _T upper: _T is_bool: bool - is_int: bool - is_float: bool @overload def __init__(self: ValueRanges[sympy.Expr], lower: ExprIn, upper: ExprIn) -> None: @@ -148,39 +142,8 @@ def __init__(self, lower: AllIn, upper: AllIn) -> None: # Because this is a frozen class object.__setattr__(self, "lower", lower) object.__setattr__(self, "upper", upper) - # Unlike bool/int in Python, we don't report bools are ints object.__setattr__(self, "is_bool", isinstance(lower, SympyBoolean)) - if self.is_bool: - assert isinstance(upper, SympyBoolean), (lower, upper) - - # Warning: is_int/is_float is best effort. We do pretty well in - # Dynamo, but in Inductor these attributes are often wrong because we - # are not very rigorous in dtype analysis. This is also why we need - # the flexible analysis for is_int: sometimes a sympy.oo pops in for - # an integer bound. I would /like/ for us not to do this, but it's - # too hard to push the invariant through right now. - - object.__setattr__( - self, - "is_int", - not self.is_bool - and (isinstance(lower, sympy.Integer) or isinstance(upper, sympy.Integer)), - ) - """ - # This assert is just impossible right now, too many sympy bugs - if self.is_int: - # NB: sympy will sometimes randomly lose the float-ness of zero, - # so we also need to account for that in the assertion here. - # See also https://github.com/sympy/sympy/issues/26620 - assert isinstance(lower, sympy.Integer) or lower in [-sympy.oo, 0], ( - lower, - upper, - ) - assert isinstance(upper, sympy.Integer) or upper in [sympy.oo, 0], (lower, upper) - """ - # NB: [-oo, oo] always advertises as float! - object.__setattr__(self, "is_float", not self.is_bool and not self.is_int) - assert self.is_bool or self.is_int or self.is_float, (lower, upper) + assert isinstance(upper, SympyBoolean) == self.is_bool def boolify(self) -> ValueRanges[SympyBoolean]: if vr_is_bool(self): @@ -221,8 +184,6 @@ def __and__(self: AllVR, other: AllVR) -> AllVR: if self == ValueRanges.unknown(): return other assert self.is_bool == other.is_bool, (self, other) - assert self.is_int == other.is_int, (self, other) - assert self.is_float == other.is_float, (self, other) if self.is_bool: return ValueRanges( sympy.Or(self.lower, other.lower), sympy.And(self.upper, other.upper) @@ -392,12 +353,7 @@ def constant(value, dtype): # using nan makes subsequent computation throw, and for the purposes of optimization # returning -math.inf - math.inf is equivalent to giving up if isinstance(value, SupportsFloat) and math.isnan(value): - if dtype == torch.bool: - return ValueRanges.unknown_bool() - elif dtype.is_floating_point: - return ValueRanges.unknown() - else: - return ValueRanges(-sys.maxsize - 1, sys.maxsize) + return ValueRanges.unknown() if is_python: type_ = dtype_to_type(dtype) @@ -413,18 +369,7 @@ def constant(value, dtype): # dtype is intXX assert value.is_integer - r = ValueRanges.wrap(value) - return r - - @staticmethod - def to_dtype(a, dtype, src_dtype=None): - if dtype == torch.float64: - return ValueRanges.increasing_map(a, ToFloat) - return ValueRanges.unknown() - - @staticmethod - def trunc_to_int(a, dtype): - return ValueRanges.increasing_map(a, TruncToInt) + return ValueRanges.wrap(value) @staticmethod def not_(a): @@ -483,9 +428,7 @@ def ge(cls, a, b): @staticmethod def add(a, b): - return ValueRanges.coordinatewise_increasing_map( - a, b, _keep_float(operator.add) - ) + return ValueRanges.coordinatewise_increasing_map(a, b, operator.add) @classmethod def mul(cls, a, b): @@ -505,20 +448,11 @@ def safe_mul(a, b): else: return a * b - return ValueRanges.coordinatewise_monotone_map(a, b, _keep_float(safe_mul)) + return ValueRanges.coordinatewise_monotone_map(a, b, safe_mul) - @staticmethod - def int_truediv(a, b): - a = ValueRanges.wrap(a) - b = ValueRanges.wrap(b) - if 0 in b or ( - (-sympy.oo in a or sympy.oo in a) and (-sympy.oo in b or sympy.oo in b) - ): - return ValueRanges.unknown() - else: - return ValueRanges.coordinatewise_monotone_map( - a, b, _keep_float(IntTrueDiv) - ) + @classmethod + def div(cls, a, b): + return cls.truediv(a, b) @staticmethod def truediv(a, b): @@ -529,22 +463,18 @@ def truediv(a, b): ): return ValueRanges.unknown() else: - return ValueRanges.coordinatewise_monotone_map( - a, b, _keep_float(FloatTrueDiv) - ) + return ValueRanges.coordinatewise_monotone_map(a, b, operator.truediv) @staticmethod def floordiv(a, b): a = ValueRanges.wrap(a) b = ValueRanges.wrap(b) if 0 in b or ( - # TODO: make this more precise - (-sympy.oo in a or sympy.oo in a) - or (-sympy.oo in b or sympy.oo in b) + (-sympy.oo in a or sympy.oo in a) and (-sympy.oo in b or sympy.oo in b) ): return ValueRanges.unknown() else: - return ValueRanges.coordinatewise_monotone_map(a, b, FloorDiv) + return ValueRanges.coordinatewise_monotone_map(a, b, operator.floordiv) @classmethod def mod(cls, x, y): @@ -593,51 +523,17 @@ def modular_indexing(cls, a, b, c): @classmethod def is_non_overlapping_and_dense_indicator(cls, *args): - return ValueRanges.unknown() # TODO: type here is wrong - - @classmethod - def pow_by_natural(cls, a, b): - a = ValueRanges.wrap(a) - b = ValueRanges.wrap(b) - if a.is_singleton() and b.is_singleton(): - return ValueRanges.wrap(safe_pow(a.lower, b.lower)) - # NB: Exclude zero, because zero is special - elif a.lower >= 1: - # We should know that b >= 0 but we may have forgotten this fact due - # to replacements, so don't assert it, but DO clamp it to prevent - # degenerate problems - return ValueRanges.coordinatewise_increasing_map( - a, b & ValueRanges(0, sys.maxsize - 1), PowByNatural - ) - elif b.is_singleton(): - if b.lower % 2 == 0: - # x^n where n is even - return ValueRanges.convex_min_zero_map( - a, lambda x: safe_pow(x, b.lower) - ) - else: - # x^n where n is odd - return ValueRanges.increasing_map(a, lambda x: safe_pow(x, b.lower)) - else: - # a is potentially negative, and we don't know if the exponent is - # even or odd. So just conservatively set the upper and lower - # bound based on what the maximum absolute value could be, in both - # directions - max_base = max(a.upper, -a.lower) - return ValueRanges( - -(safe_pow(max_base, b.upper)), safe_pow(max_base, b.upper) - ) + return ValueRanges.unknown() @classmethod def pow(cls, a, b): - return ValueRanges.unknown() + def is_integer(val): + return isinstance(val, int) or ( + hasattr(val, "is_integer") and val.is_integer + ) - # We could implement all this, but for floating point pow, is there - # really a point? - """ a = ValueRanges.wrap(a) b = ValueRanges.wrap(b) - # Not implemented yet. It's a bit tricky # If you want to implement it, compute the partial derivatives of a ** b # and check the ranges where the function is increasing / decreasing @@ -657,7 +553,8 @@ def pow(cls, a, b): if b == 0: if not a.lower.is_finite: return ValueRanges.unknown() - return ValueRanges.wrap(1.0) + type_ = sympy.Float if a.lower.is_real else sympy.Integer + return ValueRanges.wrap(type_(1)) if b < 0: a = cls.reciprocal(a) @@ -666,12 +563,21 @@ def pow(cls, a, b): if a == ValueRanges.unknown(): return ValueRanges.unknown() - # If the base is positive, then we're good, otherwise nothing's defined - if a.lower >= 0: - return ValueRanges.increasing_map(a, lambda x: x**b) + # Here b > 0 + if not is_integer(b): + # If the base is positive, then we're good, otherwise nothing's defined + if a.lower >= 0: + return ValueRanges.increasing_map(a, lambda x: x**b) + else: + return ValueRanges.unknown() else: - return ValueRanges.unknown() - """ + # b > 0 integer + if b % 2 == 0: + # x^n where n is even + return ValueRanges.convex_min_zero_map(a, lambda x: x**b) + else: + # x^n where n is odd + return ValueRanges.increasing_map(a, lambda x: x**b) @staticmethod def reciprocal(x): @@ -680,7 +586,7 @@ def reciprocal(x): if 0 in x: return ValueRanges.unknown() else: - return ValueRanges.decreasing_map(x, lambda y: FloatTrueDiv(1.0, y)) + return ValueRanges.decreasing_map(x, lambda y: 1 / y) @staticmethod def abs(x): @@ -709,64 +615,45 @@ def maximum(cls, a, b): def min_or_max(a, b, fn): a = ValueRanges.wrap(a) b = ValueRanges.wrap(b) - return ValueRanges.coordinatewise_increasing_map(a, b, fn) - - @classmethod - def floor_to_int(cls, x, dtype): - return ValueRanges.increasing_map(x, sympy.functions.elementary.integers.floor) - @classmethod - def ceil_to_int(cls, x, dtype): - return ValueRanges.increasing_map( - x, sympy.functions.elementary.integers.ceiling - ) + # Performs upcasting first + def fn_(x: sympy.Expr, y: sympy.Expr) -> sympy.Expr: + # Poorman's version of upcasting in Sympy + # Inf is not a float... + if x.is_Integer and y.is_Integer: + result_type = sympy.Integer + elif x.is_rational and y.is_rational: + result_type = sympy.Rational + else: + assert x.is_real or not x.is_finite or y.is_real or not y.is_finite + result_type = sympy.Float + return fn(result_type(x), result_type(y)) - # I think these implementations are sound. The hazard here is that sympy - # will carry out the floor/ceil at too high precision and then something - # bad will happen when we convert it to float. - # - # For truncation, the implementation is clearly sound, because the desired - # target float is always exactly representable, since you're just chopping - # off bits the mantissa. But what about ceil/floor? - # - # The important constraint here is that we're not defining floor on - # arbitrary real numbers, only representable float numbers. So we can - # take advantage of the fact that before we reach the first - # unrepresentable integer in floating point space, we have the range of - # numbers corresponding to exponent zero: all integers, with no fractional - # amounts. floor/ceil is an identity operation in this case. In the - # range below here, representable floating point numbers are spaced - # exactly 1/2 apart, and notably, both the floor/ceil are defined floating - # point numbers. There is no "gap" as you step up to the next exponent. + return ValueRanges.coordinatewise_increasing_map(a, b, fn_) @classmethod def floor(cls, x): - return ValueRanges.increasing_map( - x, _keep_float(sympy.functions.elementary.integers.floor) - ) + return ValueRanges.increasing_map(x, sympy.functions.elementary.integers.floor) @classmethod def ceil(cls, x): return ValueRanges.increasing_map( - x, _keep_float(sympy.functions.elementary.integers.ceiling) + x, sympy.functions.elementary.integers.ceiling ) @classmethod - def round_decimal(cls, number, ndigits): - if not ndigits.is_singleton(): - return ValueRanges.unknown() - - ndigits = ndigits.lower - # We can't use functools.partial here since sympy doesn't support keyword arguments, but we have to bind - # the second parameter. - fn = lambda number: RoundDecimal(number, ndigits) # type: ignore[misc, assignment] # noqa: E731 + def round(cls, number, ndigits=None): + if ndigits is None: + fn = Round + else: + assert ndigits.is_singleton() + ndigits = ndigits.lower + # We can't use functools.partial here since sympy doesn't support keyword arguments, but we have to bind + # the second parameter. + fn = lambda number: RoundDecimal(number, ndigits) # type: ignore[misc, assignment] # noqa: E731 return ValueRanges.increasing_map(number, fn) - @classmethod - def round_to_int(cls, number, dtype): - return ValueRanges.increasing_map(number, RoundToInt) - # It's used in some models on symints @staticmethod def sqrt(x): @@ -821,15 +708,12 @@ def cos(x): @staticmethod def cosh(x): - return ValueRanges(0.0, sympy.oo) - """ x = ValueRanges.wrap(x) if x.lower > 0: return ValueRanges.increasing_map(x, OpaqueUnaryFn_cosh) elif x.upper < 0: return ValueRanges.decreasing_map(x, OpaqueUnaryFn_cosh) return ValueRanges(0.0, sympy.oo) - """ @staticmethod def sin(x): @@ -839,8 +723,7 @@ def sin(x): @staticmethod def sinh(x): - # return ValueRanges.increasing_map(x, OpaqueUnaryFn_sinh) - return ValueRanges(-sympy.oo, sympy.oo) + return ValueRanges.increasing_map(x, OpaqueUnaryFn_sinh) @staticmethod def tan(x): @@ -848,37 +731,32 @@ def tan(x): @staticmethod def tanh(x): - # return ValueRanges.increasing_map(x, OpaqueUnaryFn_tanh) - return ValueRanges(-sympy.oo, sympy.oo) + return ValueRanges.increasing_map(x, OpaqueUnaryFn_tanh) @staticmethod def asin(x): - return ValueRanges(-sympy.oo, sympy.oo) - """ x = ValueRanges.wrap(x) if -1 <= x.lower and x.upper <= 1: return ValueRanges.increasing_map(x, OpaqueUnaryFn_asinh) return ValueRanges.unknown() - """ @staticmethod def acos(x): - return ValueRanges(-sympy.oo, sympy.oo) - """ x = ValueRanges.wrap(x) if -1 <= x.lower and x.upper <= 1: return ValueRanges.decreasing_map(x, OpaqueUnaryFn_acos) return ValueRanges.unknown() - """ @staticmethod def atan(x): - return ValueRanges(-sympy.oo, sympy.oo) - # return ValueRanges.increasing_map(x, OpaqueUnaryFn_atan) + return ValueRanges.increasing_map(x, OpaqueUnaryFn_atan) @staticmethod def trunc(x): - return ValueRanges.increasing_map(x, TruncToFloat) + def trunc(x): + return sympy.Integer(x) if x.is_finite else x + + return ValueRanges.increasing_map(x, trunc) class ValueRangeAnalysis(SymPyValueRangeAnalysis): @@ -913,10 +791,9 @@ def store(self, name, index, value, mode=None): def reduction(self, name, dtype, src_dtype, reduction_type, index, value): return ValueRanges.unknown() - @classmethod - def index_expr(cls, index, dtype): + def index_expr(self, index, dtype): assert isinstance(index, ValueRanges) - return cls.to_dtype(index, dtype) + return index @staticmethod def to_dtype(x, dtype: torch.dtype, src_dtype: Optional[torch.dtype] = None): @@ -953,15 +830,12 @@ def cast(x, dtype): @staticmethod def square(x): - return ValueRanges.convex_min_zero_map(x, lambda y: PowByNatural(y, 2)) + return ValueRanges.convex_min_zero_map(x, lambda y: y * y) @staticmethod def neg(x): return ValueRanges.decreasing_map(x, operator.neg) - # TODO: this is slightly inaccurate because truncdiv operates at integer - # precision, but we're going through float truediv which means we can - # potentially lose precision on the bounds @classmethod def truncdiv(cls, a, b): x = cls.truediv(a, b) @@ -982,7 +856,6 @@ def __getattr__(self, name): def bound_sympy( expr: sympy.Expr, ranges: Optional[Dict[sympy.Symbol, ValueRanges]] = None ) -> ValueRanges: - log.debug("bound_sympy(%s, %s)", expr, ranges) if isinstance(expr, sympy.Number): return ValueRanges.wrap(expr) From 852b7b4c995148239bafa21398bd9dae711bae1d Mon Sep 17 00:00:00 2001 From: Sam Larsen Date: Wed, 5 Jun 2024 17:04:21 -0700 Subject: [PATCH 475/706] [inductor] Enable subprocess-based parallel compile as the default (#126817) Differential Revision: [D58239826](https://our.internmc.facebook.com/intern/diff/D58239826) Pull Request resolved: https://github.com/pytorch/pytorch/pull/126817 Approved by: https://github.com/eellison ghstack dependencies: #128037, #128086 --- torch/_inductor/config.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index b8ff5ae5a6cd..dbaa528cd3e5 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -387,7 +387,9 @@ def is_fbcode(): # The multiprocessing start method to use for inductor workers in the codecache. # "subprocess", "fork", or "spawn" def decide_worker_start_method(): - start_method = os.environ.get("TORCHINDUCTOR_WORKER_START", "fork") + start_method = os.environ.get( + "TORCHINDUCTOR_WORKER_START", "fork" if is_fbcode() else "subprocess" + ) assert start_method in [ "subprocess", "fork", From 8d16a73f0f6470133c7351fd1eead0d04da8ed6f Mon Sep 17 00:00:00 2001 From: Sam Larsen Date: Thu, 6 Jun 2024 13:55:05 -0700 Subject: [PATCH 476/706] Manipulate triton_hash_with_backend so that it doesn't contain any keywords (#128159) Summary: See https://github.com/pytorch/pytorch/issues/127637 where "def" appears in the backend_hash and causes a problem. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128159 Approved by: https://github.com/jansel --- torch/utils/_triton.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torch/utils/_triton.py b/torch/utils/_triton.py index 9184f782cc73..ab93a398287c 100644 --- a/torch/utils/_triton.py +++ b/torch/utils/_triton.py @@ -61,7 +61,9 @@ def triton_hash_with_backend(): backend = triton_backend() key = f"{triton_key()}-{backend.hash()}" - return hashlib.sha256(key.encode("utf-8")).hexdigest() + + # Hash is upper case so that it can't contain any Python keywords. + return hashlib.sha256(key.encode("utf-8")).hexdigest().upper() def dtype_to_string(dtype): From c219fa5eb94d271b9fdf21720fc49ea07de9f92d Mon Sep 17 00:00:00 2001 From: cyy Date: Fri, 7 Jun 2024 16:13:16 +0000 Subject: [PATCH 477/706] [3/N] Remove unused functions (#128179) Following https://github.com/pytorch/pytorch/pull/128005, this PR continues to remove unused functions. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128179 Approved by: https://github.com/ezyang --- .../ATen/functorch/PyTorchOperatorHacks.cpp | 40 ------------------- aten/src/ATen/native/MetaTensor.cpp | 12 ------ aten/src/ATen/native/TypeProperties.cpp | 16 -------- functorch/csrc/dim/dim.cpp | 10 ----- functorch/csrc/dim/minpybind.h | 4 -- torch/csrc/cuda/Stream.cpp | 6 --- torch/csrc/jit/ir/ir.cpp | 6 --- 7 files changed, 94 deletions(-) diff --git a/aten/src/ATen/functorch/PyTorchOperatorHacks.cpp b/aten/src/ATen/functorch/PyTorchOperatorHacks.cpp index ce3f20ef97ef..e9e7b2a99553 100644 --- a/aten/src/ATen/functorch/PyTorchOperatorHacks.cpp +++ b/aten/src/ATen/functorch/PyTorchOperatorHacks.cpp @@ -31,46 +31,6 @@ Tensor index_select_backward_hack(const Tensor& grad, IntArrayRef self_sizes, in return at::zeros(self_sizes, grad.options()).index_add(dim, index, grad); } -static optional> unwrap(const Tensor& tensor) { - auto* wrapped = maybeGetTensorWrapper(tensor); - if (wrapped) { - if (wrapped->level().has_value()) { - return std::make_tuple(wrapped->value(), *wrapped->level()); - } - return unwrap(wrapped->value()); - } - auto* batched = maybeGetBatchedImpl(tensor); - if (batched) { - return std::make_tuple(batched->value(), batched->level()); - } - return nullopt; -} - -static bool can_perform_inplace(const Tensor& a, const Tensor& b) { - // TODO: generalize this to more transforms - auto a_ = unwrap(a); - auto b_ = unwrap(b); - if (!a_.has_value() && b_.has_value()) { - return false; - } - if (!a_.has_value() && !b_.has_value()) { - return true; - } - if (a_.has_value() && !b_.has_value()) { - return true; - } - TORCH_INTERNAL_ASSERT(a_.has_value() && b_.has_value()); - - // If b has any wrapper that a does not, then we cannot do a.inplace_(b) - if (std::get<1>(*a_) < std::get<1>(*b_)) { - return false; - } - if (std::get<1>(*a_) > std::get<1>(*b_)) { - return can_perform_inplace(std::get<0>(*a_), b); - } - return can_perform_inplace(std::get<0>(*a_), std::get<0>(*b_)); -} - // TODO: linear is pretty important for performance, but I'm not sure how to work // around the in-place. Tensor linear_hack(const Tensor& input, const Tensor& weight, const std::optional& bias_opt) { diff --git a/aten/src/ATen/native/MetaTensor.cpp b/aten/src/ATen/native/MetaTensor.cpp index 518466df84ce..302a3f45bdf4 100644 --- a/aten/src/ATen/native/MetaTensor.cpp +++ b/aten/src/ATen/native/MetaTensor.cpp @@ -28,18 +28,6 @@ Tensor empty_meta_symint( size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt); } -// Kept only for BC with XLA -static Tensor empty_strided_meta( - IntArrayRef size, - IntArrayRef stride, - std::optional dtype_opt, - std::optional layout_opt, - std::optional device_opt, - std::optional pin_memory_opt -) { - return empty_strided_meta_symint(c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), dtype_opt, layout_opt, device_opt, pin_memory_opt); -} - Tensor empty_strided_meta_symint( SymIntArrayRef size, SymIntArrayRef stride, diff --git a/aten/src/ATen/native/TypeProperties.cpp b/aten/src/ATen/native/TypeProperties.cpp index 4afc7619c2eb..6e694109a21f 100644 --- a/aten/src/ATen/native/TypeProperties.cpp +++ b/aten/src/ATen/native/TypeProperties.cpp @@ -24,10 +24,6 @@ namespace at::native { -static bool is_cuda(const Tensor& self) { - return self.is_cuda(); -} - bool is_distributed(const Tensor& self) { return false; } @@ -60,18 +56,6 @@ bool is_neg(const Tensor& self) { return self.is_neg(); } -static bool is_sparse(const Tensor& self) { - return self.is_sparse(); -} - -static bool is_sparse_csr(const Tensor& self) { - return self.is_sparse_csr(); -} - -static bool is_quantized(const Tensor& self) { - return self.is_quantized(); -} - // True if `self` and `from` have compatible tensor type so that `from`'s // TensorImpl can be copied to `self`. bool _has_compatible_shallow_copy_type(const Tensor& self, const Tensor& from) { diff --git a/functorch/csrc/dim/dim.cpp b/functorch/csrc/dim/dim.cpp index 066f9517acef..7f5564c13664 100644 --- a/functorch/csrc/dim/dim.cpp +++ b/functorch/csrc/dim/dim.cpp @@ -1640,16 +1640,6 @@ static PyObject* _dims(PyObject *self, PY_END(nullptr) } -static int64_t dim_index(const std::vector>& dims, mpy::hdl dim) { - for (int64_t i = 0, N = dims.size(); i < N; ++i) { - if (dims[i].ptr() == dim.ptr()) { - return i; - } - } - return -1; -} - - struct DotPart { Slice dims; size_t total_size = 1; diff --git a/functorch/csrc/dim/minpybind.h b/functorch/csrc/dim/minpybind.h index de82b5af95a4..f1eb87265372 100644 --- a/functorch/csrc/dim/minpybind.h +++ b/functorch/csrc/dim/minpybind.h @@ -385,10 +385,6 @@ bool is_int(handle h) { return PyLong_Check(h.ptr()); } -bool is_float(handle h) { - return PyFloat_Check(h.ptr()); -} - bool is_none(handle h) { return h.ptr() == Py_None; } diff --git a/torch/csrc/cuda/Stream.cpp b/torch/csrc/cuda/Stream.cpp index 65ea8a600b57..cbfa64af2523 100644 --- a/torch/csrc/cuda/Stream.cpp +++ b/torch/csrc/cuda/Stream.cpp @@ -84,12 +84,6 @@ static void THCPStream_dealloc(THCPStream* self) { Py_TYPE(self)->tp_free((PyObject*)self); } -static PyObject* THCPStream_get_device(THCPStream* self, void* unused) { - HANDLE_TH_ERRORS - return THPDevice_New(self->cuda_stream.device()); - END_HANDLE_TH_ERRORS -} - static PyObject* THCPStream_get_cuda_stream(THCPStream* self, void* unused) { HANDLE_TH_ERRORS return PyLong_FromVoidPtr(self->cuda_stream.stream()); diff --git a/torch/csrc/jit/ir/ir.cpp b/torch/csrc/jit/ir/ir.cpp index c39ceb7e91f9..a6b0116d7fb6 100644 --- a/torch/csrc/jit/ir/ir.cpp +++ b/torch/csrc/jit/ir/ir.cpp @@ -128,12 +128,6 @@ static std::ostream& operator<<( return printValueRefs(out, nodes); } -static std::ostream& operator<<( - std::ostream& out, - const at::ArrayRef nodes) { - return printValueRefs(out, nodes); -} - struct const_value_list_with_types { const ArrayRef values; std::string delim; From 128952625beb5bcce3601ecab79626d5fac914c3 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 7 Jun 2024 16:15:03 +0000 Subject: [PATCH 478/706] Revert "Added memory budget to partitioner (#126320)" This reverts commit 2184cdd29128a924583e4702489177f83fb8270a. Reverted https://github.com/pytorch/pytorch/pull/126320 on behalf of https://github.com/ZainRizvi due to The new test_ac.py fails on ROCm machines ([comment](https://github.com/pytorch/pytorch/pull/126320#issuecomment-2155141886)) --- test/functorch/test_ac.py | 301 ----------------------- torch/_functorch/config.py | 33 --- torch/_functorch/partitioners.py | 393 ++----------------------------- 3 files changed, 25 insertions(+), 702 deletions(-) delete mode 100644 test/functorch/test_ac.py diff --git a/test/functorch/test_ac.py b/test/functorch/test_ac.py deleted file mode 100644 index ee3a2c545183..000000000000 --- a/test/functorch/test_ac.py +++ /dev/null @@ -1,301 +0,0 @@ -# Owner(s): ["oncall: pt2"] -import random - -import torch -import torch._functorch.config as config -from torch.testing._internal.common_utils import run_tests, TestCase -from torch.testing._internal.inductor_utils import HAS_CUDA -from torch.utils.flop_counter import FlopCounterMode - - -def compile_with_ac(f, memory_budget): - return torch.compile(f, backend="aot_eager_decomp_partition") - - -def get_act_mem(f): - out = f() - out.backward() - start_mem = torch.cuda.memory_stats()["requested_bytes.all.current"] - out = f() - cur_mem = torch.cuda.memory_stats()["requested_bytes.all.current"] - act_mem = (cur_mem - start_mem) / (1024 * 1024) - out.backward() - return act_mem - - -def get_bw_flops(f): - # Normalized so that a 512 square matmul returns 1 - f().backward() - out = f() - with FlopCounterMode(display=False) as mode: - out.backward() - return mode.get_total_flops() / (512**3 * 2) - - -def create_pair(B_I, O): - # results in B_I * O memory, requires B_I * B_I * O flops - # arithmetic intensity of B_I - x = torch.randn(B_I * 512, B_I * 512, requires_grad=True) - w = torch.randn(B_I * 512, O * 512, requires_grad=True) - return x, w - - -def get_mem_and_flops(f, memory_budget=None): - # Returns megabytes rounded to 1 decimal point and FLOPs - # Note that each value of size (512, 512, torch.float32) is 1 MiB - torch._dynamo.reset() - with config.patch(activation_memory_budget=memory_budget): - if memory_budget is not None: - f = torch.compile(f, backend="aot_eager_decomp_partition") - - # We round this to nearest 10th of a megabyte. - return round(get_act_mem(f), 1), get_bw_flops(f) - - -class MemoryBudgetTest(TestCase): - def setUp(self): - super().setUp() - torch.set_default_device("cuda") - - def test_rematerializes_cheap(self): - def f(x, w): - x = x.cos() - x = torch.mm(x, w) - return x.sum() - - x = torch.randn(512, 512, requires_grad=True) - w = torch.randn(512, 512, requires_grad=True) - - def call(): - return f(x, w) - - eager_mem, eager_flops = get_mem_and_flops(call) - self.assertEqual(eager_mem, 1.0) - mem_10, flops_10 = get_mem_and_flops(call, memory_budget=1.0) - # Recomputing `.cos()` is not free here. - self.assertEqual(mem_10, 1.0) - self.assertEqual(eager_flops, flops_10) - mem_5, flops_5 = get_mem_and_flops(call, memory_budget=0.5) - # We can just recompute `x.cos()` here to only depend on the inputs - self.assertEqual(mem_5, 0.0) - self.assertEqual(flops_5, eager_flops) - - def test_matmul_even_chain(self): - def f(x, ws): - x = x.cos() - for w in ws: - x = torch.mm(x, w).cos() - return x.sum() - - x = torch.randn(512, 512, requires_grad=True) - ws = [torch.randn(512, 512, requires_grad=True) for _ in range(5)] - - def call(): - return f(x, ws) - - eager_mem, eager_flops = get_mem_and_flops(call) - for budget in range(0, 11): - mem, flops = get_mem_and_flops(call, memory_budget=budget / 10) - if budget <= 5: - # We start saving the matmuls - self.assertEqual(mem, budget) - self.assertEqual(flops, eager_flops + (5 - budget)) - elif budget < 10: - # We're only recomputing the `cos` operations - self.assertEqual(mem, 5.0) - self.assertEqual(flops, eager_flops) - elif budget == 10: - self.assertEqual(mem, 10.0) - self.assertEqual(flops, eager_flops) - - def test_matmul_uneven_chain(self): - # This function is constructed so that we are saving one input of size - # [512, in_dim] for each w - # In addition, every matmul has a same ratio of compute to "memory - # saved", so this test is essentially testing our knapsack solving - - def f(x, ws): - xs = [torch.mm(x, w).cos() for w in ws] - return sum([x.sum() for x in xs]) - - x = torch.randn(512, 512, requires_grad=True) - - def make_weights(w_shapes): - ws = [] - for idx, dim in enumerate(w_shapes): - ws.append(torch.randn(512, dim * 512, requires_grad=True)) - return ws - - def make_weights_chain(w_shapes): - ws = [] - for idx, _ in enumerate(w_shapes): - old_dim = 512 if idx == 0 else w_shapes[idx - 1] * 512 - new_dim = w_shapes[idx] * 512 - ws.append(torch.randn(old_dim, new_dim, requires_grad=True)) - return ws - - weight_configs = [ - ( - [11, 3, 4, 2], - [ - 18, # 11 + 4 + 3 - 17, # 11 + 4 + 2 - 16, # 11 + 3 + 2 - 15, # 11 + 4 - 14, # 11 + 3 - 13, # 11 + 2 - 11, # 11 + 2 - 7, # 4 + 3 - 6, # 4 + 2 - 5, # 3 + 2 - ], - ), - ( - [3, 5, 11, 17, 14], - [ - 42, # 17 + 14 + 9 - 30, # 11 + 15 + 5 - 19, # 11 + 5 + 3 - 8, # 5 + 3 - 3, # 3 - ], - ), - ] - random.seed(0) - random_arr = [random.randint(0, 50) for _ in range(10)] - exact_sums = [] - for i in range(10): - random.shuffle(random_arr) - exact_sums.append(sum(random_arr[:i])) - weight_configs.append((random_arr, exact_sums)) - - for weight_shapes, exact_solves in weight_configs: - ws = make_weights(weight_shapes) - - def call(): - return f(x, ws) - - eager_mem, eager_flops = get_mem_and_flops(call) - total_mem = sum(weight_shapes) - self.assertEqual(eager_mem, sum(weight_shapes)) - for mem_achieved in exact_solves: - mem, _ = get_mem_and_flops(call, memory_budget=mem_achieved / total_mem) - self.assertEqual(mem, mem_achieved) - - def test_prioritize_cheaper_matmul(self): - def f(xs, ws): - xs = [torch.mm(x, w).cos() for x, w in zip(xs, ws)] - return sum([x.sum() for x in xs]) - - x1, w1 = create_pair(1, 4) - x2, w2 = create_pair(2, 2) - - def call(): - return f([x1, x2], [w1, w2]) - - eager_mem, eager_flops = get_mem_and_flops(call) - self.assertEqual(eager_mem, 8) - self.assertEqual(eager_flops, 24) - comp_mem, comp_flops = get_mem_and_flops(call, memory_budget=0.5) - self.assertEqual(comp_mem, 4) - # We are recomputing x1 @ w1 here! - self.assertEqual(comp_flops, eager_flops + 4) - - @config.patch(activation_memory_budget_runtime_estimator="profile") - def test_profile(self): - def f(x, ws): - x = x.cos() - for w in ws: - x = torch.mm(x, w).cos() - return x.sum() - - x = torch.randn(512, 512, requires_grad=True) - ws = [torch.randn(512, 512, requires_grad=True) for _ in range(5)] - - def call(): - return f(x, ws) - - eager_mem, eager_flops = get_mem_and_flops(call) - mem, flops = get_mem_and_flops(call, memory_budget=0.2) - # We start saving the matmuls - self.assertEqual(mem, 2) - self.assertEqual(flops, eager_flops + 3) - - def test_prioritize_cheaper_matmul2(self): - def f(xs, ws): - xs = [torch.mm(x, w).cos() for x, w in zip(xs, ws)] - return sum([x.sum() for x in xs]) - - data = [(4, 4), (6, 2), (2, 6)] - xs, ws = zip(*[create_pair(a, b) for a, b in data]) - - def call(): - return f(xs, ws) - - eager_mem, eager_flops = get_mem_and_flops(call) - self.assertEqual(eager_mem, 40) - self.assertEqual(eager_flops, 320) - mem, flops = get_mem_and_flops(call, memory_budget=28 / eager_mem) - # Save w1 and w2 - self.assertEqual(mem, 28) - # We're recomputing w3 (the cheap one!) - self.assertEqual(flops - eager_flops, 2 * 2 * 6) - mem, flops = get_mem_and_flops(call, memory_budget=16 / eager_mem) - # Save w2. Note that even though saving w1 gets us closer to our memory - # limit, w2 is actually *more* FLOPs than w1! - self.assertEqual(mem, 12) - self.assertEqual(flops - eager_flops, 2 * 2 * 6 + 4 * 4 * 4) - - def test_attention_vs_linear(self): - def f(x, w): - orig_shape = x.shape - x = x.reshape(1, 1, x.shape[0], x.shape[1]) - # I know this isn't technically right lol - x = torch.nn.functional.scaled_dot_product_attention( - x, x, x, is_causal=False - ).reshape(*orig_shape) - x = torch.mm(x, w) - x = x.cos() - return x.sum() - - def try_seq_length(S, D, expected_recompute): - x = torch.randn(S * 512, D * 512, requires_grad=True) - w = torch.randn(D * 512, D * 512, requires_grad=True) - - def call(): - return f(x, w) - - with FlopCounterMode(display=False) as mode: - call() - mm_flops = mode.get_flop_counts()["Global"][torch.ops.aten.mm] - attn_flops = mode.get_total_flops() - mm_flops - mm_flops /= 512**3 * 2 - attn_flops /= 512**3 * 2 - - eager_mem, eager_flops = get_mem_and_flops(call) - self.assertEqual(eager_mem, S * D * 2) - - mem, flops = get_mem_and_flops( - call, memory_budget=0.6 - ) # Force it to recompute one of mm or attn - self.assertEqual(mem, S * D) - if expected_recompute == "attn": - expected_flops = attn_flops - else: - expected_flops = mm_flops - self.assertEqual(flops - eager_flops, expected_flops) - - # General behind this test is that if sequence length * 2 > D, then - # attention is more expensive than the linear. - try_seq_length(1, 1, "mm") - try_seq_length(1, 3, "attn") - try_seq_length(2, 2, "mm") - try_seq_length(2, 1, "mm") - try_seq_length(2, 5, "attn") - try_seq_length(4, 7, "mm") - try_seq_length(4, 9, "attn") - - -if __name__ == "__main__": - if HAS_CUDA: - run_tests() diff --git a/torch/_functorch/config.py b/torch/_functorch/config.py index 60bbf1f21c66..c559951f3809 100644 --- a/torch/_functorch/config.py +++ b/torch/_functorch/config.py @@ -88,39 +88,6 @@ # a fusion can be expensive. ban_recompute_reductions = True -# By default, the partitioner is purely trying to optimize for runtime (although -# it should always use less memory than eager) -# This knob controls the partitioner to make that tradeoff for you, choosing the -# fastest option that saves less activations than the memory budget. -# Specifically, 0.0 corresponds to the activation memory from applying -# activation checkpointing to the full compiled region, and 1.0 corresponds to -# the activation memory from the default runtime-optimized strategy. So, 0.4 -# would result in a strategy that saves 40% of the activations compared to the -# default strategy. -# It solves a 0-1 knapsack to find the minimum recompute necessary to stay below -# the activation memory budget. -# NOTE: This *cannot* be treated as -activation_memory_budget = 1.0 - -# This controls how we estimate the runtime when deciding what the cheapest -# operators to recompute are. The 3 options are -# "flops": Bases it off of the flop count provided by torch.utils.flop_counter -# "profile": Benchmarks each operator to come up with a runtime -# "testing": Returns 1 for everything -activation_memory_budget_runtime_estimator = "flops" - -# This controls the solver used for the 0-1 knapsack. By default we use a -# quantized DP solution ("dp"). The other approaches are a "greedy" and a "ilp" -# (which has a scipy dependency). -activation_memory_budget_solver = "dp" - -# This dumps out a png visualization of the expected runtime vs. activation -# memory tradeoffs for all memory budget values from 0 to 1 in increments of -# 0.5. See an example here: -# https://github.com/pytorch/pytorch/pull/126320#discussion_r1625104015 -visualize_memory_budget_pareto = ( - os.environ.get("PARTITIONER_MEMORY_BUDGET_PARETO", "0") == "1" -) # Sets all of the ban_recompute heuristics to False except ban_recompute_reductions # Generally, this will probably result in some memory improvement, but at the diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index fc1c995e5907..cbfb4ca17168 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -25,7 +25,6 @@ ) from torch.fx.passes import graph_drawer from . import config -from ._aot_autograd.logging_utils import get_aot_graph_name from .compile_utils import fx_graph_cse, get_aten_target if TYPE_CHECKING: @@ -452,16 +451,14 @@ def _size_of(node: fx.Node) -> int: # layering violation) elif isinstance(val, (list, tuple)): return sum( - _tensor_nbytes(hint_int(n.numel(), fallback=4096), n.dtype) + _tensor_nbytes(hint_int(n.numel(), fallback=4098), n.dtype) for n in val if isinstance(n, torch.Tensor) ) elif isinstance(val, torch.Tensor): - return _tensor_nbytes(hint_int(val.numel(), fallback=4096), val.dtype) + return _tensor_nbytes(hint_int(val.numel(), fallback=4098), val.dtype) raise RuntimeError(f"Unknown metadata type {type(val)}") - if node.op == "get_attr": - return 0 raise RuntimeError("We should always have `val` metadata on the nodes") @@ -535,22 +532,25 @@ def reordering_to_mimic_autograd_engine(gm: fx.GraphModule) -> fx.GraphModule: for idx, node in enumerate(gm.graph.nodes): order[node] = idx - def insert_node_in_graph(node): - cur_nodes = [node] - insertable_nodes = set() - while len(cur_nodes) > 0: - node = cur_nodes.pop() - if node in insertable_nodes or node in env: - continue - insertable_nodes.add(node) + # Populate depth for the nodes. Depth is the distance from the inputs. + depths = {} + output_node = next(iter(gm.graph.find_nodes(op="output"))) + for node in gm.graph.nodes: + if node.op == "placeholder": + depths[node] = 0 + else: + depths[node] = max([depths[arg] for arg in node.all_input_nodes], default=0) - # Bias traversal towards the nodes that have higher depth - prioritizes - # critical path first. - cur_nodes += node.all_input_nodes + def insert_node_in_graph(node): + if node in env: + return env[node] - insertable_nodes = sorted(insertable_nodes, key=lambda n: order[n]) - for node in insertable_nodes: - env[node] = new_graph.node_copy(node, lambda x: env[x]) + # Bias traversal towards the nodes that have higher depth - prioritizes + # critical path first. + for arg, _ in sort_depths(node.all_input_nodes, depths): + env[arg] = insert_node_in_graph(arg) + env[node] = new_graph.node_copy(node, lambda x: env[x]) + return env[node] # Find first bwd node in the graph tangent_inputs = list(filter(_is_tangent, gm.graph.nodes)) @@ -750,7 +750,7 @@ def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule: return joint_module -def solve_min_cut( +def get_saved_values( joint_graph: fx.Graph, node_info: NodeInfo, min_cut_options: MinCutOptions, @@ -877,6 +877,7 @@ def ban_recomputation_if_allowed(node): return False if node in dont_ban: return False + # breakpoint() # This bans recomputation of the node unless we've been forced not to by # user annotation # NB: "recompute" > 0 means that user annotation has asked us to @@ -1267,197 +1268,9 @@ def get_name_to_node(graph: fx.Graph): return name_to_node -def greedy_knapsack( - memory: List[float], runtimes: List[float], max_memory: float -) -> Tuple[float, List[int], List[int]]: - n = len(runtimes) - items = list(range(n)) - - # Sort items based on the ratio of runtime to memory in descending order - items = sorted(items, key=lambda i: runtimes[i] / memory[i], reverse=True) - - total_memory = 0.0 - total_runtime = 0.0 - items_to_save = [] - items_to_allow_recomputing = [] - - for i in items: - if total_memory + memory[i] <= max_memory: - total_memory += memory[i] - total_runtime += runtimes[i] - items_to_save.append(i) - else: - items_to_allow_recomputing.append(i) - return total_runtime, items_to_save, items_to_allow_recomputing - - -def ilp_knapsack( - memory: List[float], runtimes: List[float], max_memory: float -) -> Tuple[float, List[int], List[int]]: - import numpy as np - - try: - from scipy.optimize import Bounds, LinearConstraint, milp - except ImportError: - raise RuntimeError( - "To use the ILP for memory budget checkpointing you need to install scipy" - ) from None - - np_memory = np.array(memory) - np_runtimes = np.array(runtimes) - c = -np_runtimes # type: ignore[operator] - - memory_constraint = LinearConstraint(A=np_memory, ub=np.array(max_memory)) - constraints = [memory_constraint] - - integrality = np.ones_like(c) - res = milp( - c=c, constraints=constraints, integrality=integrality, bounds=Bounds(0, 1) - ) - if not res.success: - raise RuntimeError("Somehow scipy solving failed") - - items_to_save = [] - items_to_allow_recomputing = [] - for idx, i in enumerate(res.x): - if i == 1: - items_to_save.append(idx) - else: - items_to_allow_recomputing.append(idx) - return -res.fun, items_to_save, items_to_allow_recomputing - - -def dp_knapsack( - memory: List[float], runtimes: List[float], max_memory: float -) -> Tuple[float, List[int], List[int]]: - # Scaling factor to convert floating point weights to integers - S = 10000 - - # Quantize the memory weights - quantized_memory = torch.tensor( - [int(round(m * S)) for m in memory], dtype=torch.long, device="cpu" - ) - runtimes = torch.tensor(runtimes, dtype=torch.float32, device="cpu") - - # Quantized pseudopolynomial DP for 0-1 Knapsack - quantized_max_memory = int(round(max_memory * S)) - - n = len(memory) - - # Initialize the DP table - # TODO(chilli): I think if needed, this memory can be optimized with sliding - # window trick + Hirschberg trick: - # https://codeforces.com/blog/entry/47247?#comment-316200 - dp = torch.zeros( - (n + 1, quantized_max_memory + 1), dtype=torch.float32, device="cpu" - ) - - for i in range(1, n + 1): - current_memory = quantized_memory[i - 1] - current_runtime = runtimes[i - 1] - - # Copy the previous row - dp[i, :] = dp[i - 1, :] - - # Update dp[i, j] for all j >= current_memory - if current_memory == 0: - dp[i, :] = dp[i - 1, :] + current_runtime - else: - dp[i, current_memory:] = torch.maximum( - dp[i - 1, current_memory:], - dp[i - 1, :-current_memory] + current_runtime, - ) - - # Backtrack to find the items included in the knapsack - saved_items = [] - recomputable_items = [] - j: int = quantized_max_memory - for i in range(n, 0, -1): - if dp[i][j] != dp[i - 1][j]: - saved_items.append(i - 1) # Include this item (indexing from 0) - j -= int(quantized_memory[i - 1].item()) - else: - recomputable_items.append(i - 1) - - saved_items.reverse() # To get items in the order they were added - - # The maximum runtime that can be achieved within the max_memory constraint - max_runtime = dp[n][quantized_max_memory].item() - - return max_runtime, saved_items, recomputable_items - - -def _optimize_runtime_with_given_memory( - memory: List[float], - runtimes: List[float], - max_memory: float, -) -> Tuple[float, List[int], List[int]]: - SOLVER = config.activation_memory_budget_solver - if SOLVER == "greedy": - return greedy_knapsack(memory, runtimes, max_memory) - elif SOLVER == "ilp": - return ilp_knapsack(memory, runtimes, max_memory) - elif SOLVER == "dp": - return dp_knapsack(memory, runtimes, max_memory) - else: - raise RuntimeError(f"Not aware of memory budget knapsack solver: {SOLVER}") - - -from torch.utils._mode_utils import no_dispatch - - -def estimate_runtime(node): - RUNTIME_MODE = config.activation_memory_budget_runtime_estimator - - def materialize_arg(x): - if isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.Tensor): - shape = list(x.meta["val"].shape) - - def realize_symbol(d): - return hint_int(d, fallback=4096) - - shape = [realize_symbol(s) for s in shape] - return x.meta["val"].new_zeros(shape) - elif isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.SymInt): - return hint_int(x.meta["val"], fallback=4096) - elif isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.SymFloat): - return 1.0 - elif isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.SymBool): - return True - else: - return x - - if RUNTIME_MODE == "testing": - return 1 - - elif RUNTIME_MODE == "profile": - from triton.testing import do_bench - - with no_dispatch(): - args, kwargs = pytree.tree_map(materialize_arg, (node.args, node.kwargs)) - ms = do_bench(lambda: node.target(*args, **kwargs)) - return ms - - elif RUNTIME_MODE == "flops": - # todo(chilli): Normalize this to also return ms - from torch.utils.flop_counter import FlopCounterMode - - args, kwargs = pytree.tree_map(materialize_arg, (node.args, node.kwargs)) - with FlopCounterMode(display=False) as mode: - node.target(*args, **kwargs) - counted_flops = mode.get_total_flops() - return max(counted_flops, 1) - else: - raise RuntimeError(f"Not aware of runtime estimator: {RUNTIME_MODE}") - - def choose_saved_values_set( joint_graph: fx.Graph, node_info: NodeInfo, memory_budget=1 ) -> List[fx.Node]: - if memory_budget > 1 or memory_budget < 0: - raise RuntimeError( - f"The valid ranges for memory budget are 0 <= m <= 1. The provided value is {memory_budget}" - ) min_cut_options = MinCutOptions( ban_if_used_far_apart=config.ban_recompute_used_far_apart, ban_if_long_fusible_chains=config.ban_recompute_long_fusible_chains, @@ -1474,164 +1287,16 @@ def choose_saved_values_set( ban_if_materialized_backward=False, ban_if_not_in_allowlist=False, ) + if memory_budget == 0: return node_info.inputs - runtime_optimized_saved_values, _ = solve_min_cut( + runtime_optimized_saved_values, _ = get_saved_values( joint_graph, node_info, min_cut_options, ) - # return runtime_optimized_saved_values - if memory_budget == 1: - return runtime_optimized_saved_values - - def estimate_activations_size(saved_values: List[fx.Node]) -> float: - return sum([_size_of(i) for i in saved_values]) / 1e9 - - min_act_size = estimate_activations_size(node_info.inputs) - max_act_size = estimate_activations_size(runtime_optimized_saved_values) - # The optimized choice is smaller than the inputs anyways - if max_act_size <= min_act_size: - return runtime_optimized_saved_values - - def get_normalized_size(sz): - return (sz / 1e9) / (max_act_size - min_act_size) - - def get_mem_ratio(activations: List[fx.Node]): - return (estimate_activations_size(activations) - min_act_size) / ( - max_act_size - min_act_size - ) - - more_aggressive_options = replace( - min_cut_options, - ban_if_used_far_apart=False, - ban_if_long_fusible_chains=False, - ban_if_materialized_backward=False, - ) - more_aggressive_saved_values, _ = solve_min_cut( - joint_graph, node_info, more_aggressive_options - ) - if get_mem_ratio(more_aggressive_saved_values) < memory_budget: - return more_aggressive_saved_values - - aggressive_options = replace( - more_aggressive_options, - ban_if_not_in_allowlist=False, - ) - aggressive_recomputation_saved_values, banned_nodes = solve_min_cut( - joint_graph, node_info, aggressive_options - ) - - if get_mem_ratio(aggressive_recomputation_saved_values) < memory_budget: - return aggressive_recomputation_saved_values - - from torch._inductor.fx_utils import get_node_storage - - input_storages = {get_node_storage(node) for node in node_info.inputs} - - def get_recomputable_banned_nodes(banned_nodes: List[fx.Node]) -> List[fx.Node]: - return [ - i - for i in banned_nodes - if ( - # Only allow recomputing nodes that are actually required for BW - i.dist_from_bw < int(1e9) # type: ignore[attr-defined] - and get_node_storage(i) not in input_storages - ) - ] - - recomputable_banned_nodes = get_recomputable_banned_nodes(banned_nodes) - - # default: runtime_optimized_saved_values - # more aggressive: more_aggressive_saved_values - # full aggressive: aggressive_recomputation_saved_values - - all_recomputable_banned_nodes = sorted( - recomputable_banned_nodes, key=_size_of, reverse=True - ) - if len(all_recomputable_banned_nodes) == 0: - return node_info.inputs - memories_banned_nodes = [ - get_normalized_size(_size_of(i)) for i in all_recomputable_banned_nodes - ] - runtimes_banned_nodes = [ - estimate_runtime(node) for node in all_recomputable_banned_nodes - ] - from torch.utils._mode_utils import no_dispatch - - def get_saved_values_knapsack(memory_budget): - with no_dispatch(): - ( - expected_runtime, - saved_node_idxs, - recomputable_node_idxs, - ) = _optimize_runtime_with_given_memory( - memories_banned_nodes, runtimes_banned_nodes, max(memory_budget, 0) - ) - dont_ban = set() - for idx in recomputable_node_idxs: - dont_ban.add(all_recomputable_banned_nodes[idx]) - assert dont_ban.issubset(all_recomputable_banned_nodes) - - saved_values, _ = solve_min_cut( - joint_graph, - node_info, - aggressive_options, - dont_ban, - ) - return saved_values, expected_runtime - - if config.visualize_memory_budget_pareto: - options = [] - for sweep_memory_budget in range(100, -1, -5): - saved_values, expected_runtime = get_saved_values_knapsack( - sweep_memory_budget / 100 - ) - options.append( - ( - sweep_memory_budget, - sum(runtimes_banned_nodes) - expected_runtime, - get_mem_ratio(saved_values), - ) - ) - - import matplotlib.pyplot as plt - - x_values = [item[2] for item in options] - y_values = [item[1] for item in options] - - # Plotting the values with updated axis labels and chart title - plt.figure(figsize=(10, 6)) - plt.plot(x_values, y_values, marker="o") - - # Adding labels for each point - for i, txt in enumerate(x_values): - plt.annotate( - f"{txt:.2f}", - (x_values[i], y_values[i]), - textcoords="offset points", - xytext=(0, 10), - ha="center", - ) - - plt.xlabel("Memory Budget") - plt.ylabel("Runtime of Recomputed Components") - plt.title("Pareto Frontier of Memory Budget vs. Recomputation Runtime") - plt.grid(True) - fig = plt.gcf() - plt.show() - fig_name = f"memory_budget_pareto_{get_aot_graph_name()}.png" - fig.savefig(fig_name) - log.warning("Generated Pareto frontier curve at %s", fig_name) - - # todo(chilli): Estimated doesn't align exactly with actual - actual is - # usually less memory than estimated. i'm guessing (actually quite - # unsure about this) that's because estimated is just only including - # tensors we actually banned from recompute, but there may be other - # tensors that we choose to save. - - return get_saved_values_knapsack(memory_budget=memory_budget)[0] + return runtime_optimized_saved_values def min_cut_rematerialization_partition( @@ -1747,15 +1412,7 @@ def classify_nodes(joint_module): for user in node.users: node.dist_from_bw = min(node.dist_from_bw, user.dist_from_bw + 1) - memory_budget = config.activation_memory_budget - for node in joint_graph.nodes: - if isinstance(node.meta.get("memory_budget", None), float): - memory_budget = node.meta["memory_budget"] - break - # print("Memory Budget: ", memory_budget) - saved_values = choose_saved_values_set( - joint_graph, node_info, memory_budget=memory_budget - ) + saved_values = choose_saved_values_set(joint_graph, node_info, memory_budget=1) # save_for_backward on tensors and stashes symints in autograd .ctx saved_sym_nodes = list(filter(is_sym_node, saved_values)) saved_values = list(filter(lambda n: not is_sym_node(n), saved_values)) From fc6e3ff96d4613dacf4e762de2c3841ed333f5c5 Mon Sep 17 00:00:00 2001 From: Prachi Gupta Date: Fri, 7 Jun 2024 16:23:04 +0000 Subject: [PATCH 479/706] [ROCm] Update triton pin to fix libtanh issue (#125396) There were some internal build issues related to tanh when we moved to upstream triton in ROCm. These issues were fixed by the following triton commit: https://github.com/triton-lang/triton/pull/3810 . This PR moves the triton pin to incorporate that change. Added some skips for unit tests that regressed due to the triton commit bump in this PR. Needs https://github.com/pytorch/pytorch/pull/127968 since this PR introduces a triton dependency on llnl-hatchet, which doesn't have py3.12 wheels available currently. Pull Request resolved: https://github.com/pytorch/pytorch/pull/125396 Approved by: https://github.com/pruthvistony, https://github.com/malfet --- .ci/docker/ci_commit_pins/triton-rocm.txt | 2 +- test/inductor/test_cpu_cpp_wrapper.py | 14 ++++++++++++-- test/inductor/test_triton_kernels.py | 1 + 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/.ci/docker/ci_commit_pins/triton-rocm.txt b/.ci/docker/ci_commit_pins/triton-rocm.txt index 2df035af1fdd..15f681977a12 100644 --- a/.ci/docker/ci_commit_pins/triton-rocm.txt +++ b/.ci/docker/ci_commit_pins/triton-rocm.txt @@ -1 +1 @@ -bbe6246e37d8aa791c67daaf9d9d61b26c9ccfdc +01cbe5045a6898c9a925f01435c8277b2fe6afcc diff --git a/test/inductor/test_cpu_cpp_wrapper.py b/test/inductor/test_cpu_cpp_wrapper.py index e77c6a5a8208..8bf9b1e6a61f 100644 --- a/test/inductor/test_cpu_cpp_wrapper.py +++ b/test/inductor/test_cpu_cpp_wrapper.py @@ -9,7 +9,7 @@ from torch.testing._internal.common_device_type import ( get_desired_device_type_test_bases, ) -from torch.testing._internal.common_utils import IS_MACOS, slowTest +from torch.testing._internal.common_utils import IS_MACOS, slowTest, TEST_WITH_ROCM from torch.testing._internal.inductor_utils import HAS_CPU @@ -68,7 +68,17 @@ class DynamicShapesCppWrapperCpuTests(InductorTestCase): ("cpp_wrapper",), is_skip=True ), } - +if TEST_WITH_ROCM: + test_failures_cpp_wrapper.update( + { + "test_linear_packed": test_torchinductor.TestFailure( + ("cpp_wrapper"), is_skip=True + ), + "test_linear_packed_dynamic_shapes": test_torchinductor.TestFailure( + ("cpp_wrapper"), is_skip=True + ), + } + ) if config.abi_compatible: xfail_list = [ "test_conv2d_binary_inplace_fusion_failed_cpu", diff --git a/test/inductor/test_triton_kernels.py b/test/inductor/test_triton_kernels.py index 58ef3d4e84bc..af788de0ab0c 100644 --- a/test/inductor/test_triton_kernels.py +++ b/test/inductor/test_triton_kernels.py @@ -585,6 +585,7 @@ def call_triton( self.assertEqual(int_result, resulti) @requires_cuda + @skipIfRocm def test_triton_kernel_constants(self): @triton.jit def mulC_kernel( From d9696ea62482c15f565de3315db7ec40da3cbdc7 Mon Sep 17 00:00:00 2001 From: Mu-Chu Lee Date: Fri, 7 Jun 2024 16:46:26 +0000 Subject: [PATCH 480/706] [AOTInductor] [Tooling] Update NaN and INF Checker for AOTInductor (#127574) Summary: 1. Integrate NaN and INF checker with existing config, controllable by env var. 2. Move inject point of NaN & INF checker earlier, this could prevent buffer freeing before check. 3. Inject debugging code in Kernel level, which prevents us trying to read buffers that are fused inplace and into a single kernel. Test Plan: Debugging utility. Test and check by existing tests with env var: ``` TORCHINDUCTOR_NAN_ASSERTS=1 TORCHINDUCTOR_MAX_AUTOTUNE=0 python test/inductor/test_aot_inductor.py -k AOTInductorTestNonABICompatibleCuda.test_seq_non_abi_compatible_cuda ``` Reviewed By: ColinPeppler Differential Revision: D57989176 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127574 Approved by: https://github.com/chenyang78, https://github.com/desertfire --- torch/_inductor/codegen/cpp_wrapper_cpu.py | 2 +- torch/_inductor/codegen/triton.py | 16 ++++++++++++---- torch/_inductor/codegen/wrapper.py | 4 ---- torch/_inductor/config.py | 3 --- torch/_inductor/scheduler.py | 3 --- torch/csrc/inductor/aoti_torch/c/shim.h | 2 +- torch/csrc/inductor/aoti_torch/shim_common.cpp | 12 ++++-------- torch/csrc/inductor/aoti_torch/utils.h | 15 +++++++++++++++ 8 files changed, 33 insertions(+), 24 deletions(-) diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index f33c5fb3136e..7d38b6ed1acb 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -1504,7 +1504,7 @@ def generate_inf_and_nan_checker(self, nodes): for buf in nodes.get_names(): # TODO: Add buf name directly into check_inf_and_nan. self.writeline( - f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_check_inf_and_nan({buf}));" + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_check_inf_and_nan({buf}));" ) def codegen_device(self, device): diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index daf329ee9b80..215d4d866980 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -2315,10 +2315,18 @@ def codegen_nan_check(self): _, call_args, arg_types, _ = self.args.python_argdefs() for arg, arg_type in zip(call_args, arg_types): if isinstance(arg_type, TensorArg): - line = f"assert not {arg}.isnan().any().item()" - wrapper.writeline(line) - line = f"assert not {arg}.isinf().any().item()" - wrapper.writeline(line) + if V.graph.cpp_wrapper: + if config.abi_compatible: + wrapper.writeline( + f'AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_check_inf_and_nan("{arg}", {arg}));' + ) + else: + wrapper.writeline(f'assert_inf_and_nan("{arg}", {arg});') + else: + line = f"assert not {arg}.isnan().any().item()" + wrapper.writeline(line) + line = f"assert not {arg}.isinf().any().item()" + wrapper.writeline(line) def create_cse_var(self, *args, **kwargs): return TritonCSEVariable(*args, **kwargs) diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index fa1bb3463cb6..1daaa534ee4e 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -725,10 +725,6 @@ def generate_extern_kernel_alloc_and_find_schema_if_needed( ): self.writeline(f"{buf_name} = {python_kernel_name}({', '.join(codegen_args)})") - def generate_inf_and_nan_checker(self, node): - # TODO: Add check for python too. - pass - @dynamo_timed def generate(self, is_inference): if config.profile_bandwidth: diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index dbaa528cd3e5..148cf1684875 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -12,9 +12,6 @@ def is_fbcode(): # add some debug printouts debug = False -# add inf and NaN checkers -debug_check_inf_and_nan = False - # Whether to disable a progress bar for autotuning disable_progress = True diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 88ff1714a3f6..46d80569125f 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -2743,9 +2743,6 @@ def codegen(self) -> None: assert isinstance(node, NopKernelSchedulerNode) node.allocate() - if config.debug_check_inf_and_nan: - V.graph.wrapper_code.generate_inf_and_nan_checker(node) - if config.triton.debug_sync_kernel: self.get_backend(device).codegen_sync() diff --git a/torch/csrc/inductor/aoti_torch/c/shim.h b/torch/csrc/inductor/aoti_torch/c/shim.h index ba716e213a0f..65fbbd9fc23d 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim.h +++ b/torch/csrc/inductor/aoti_torch/c/shim.h @@ -475,7 +475,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_repeat_interleave_Tensor( AtenTensorHandle* out); AOTI_TORCH_EXPORT AOTITorchError -aoti_check_inf_and_nan(AtenTensorHandle tensor); +aoti_torch_check_inf_and_nan(const char* tensor_name, AtenTensorHandle tensor); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scatter_out( AtenTensorHandle out, diff --git a/torch/csrc/inductor/aoti_torch/shim_common.cpp b/torch/csrc/inductor/aoti_torch/shim_common.cpp index 6f93407aa467..1306c006ba94 100644 --- a/torch/csrc/inductor/aoti_torch/shim_common.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_common.cpp @@ -770,17 +770,13 @@ AOTITorchError aoti_torch_repeat_interleave_Tensor( } // Function to check existence of inf and NaN -AOTITorchError aoti_check_inf_and_nan(AtenTensorHandle tensor) { +AOTITorchError aoti_torch_check_inf_and_nan( + const char* tensor_name, + AtenTensorHandle tensor) { AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ at::Tensor* check_tensor = tensor_handle_to_tensor_pointer(tensor); - auto flattened = check_tensor->view({-1}); - for (int64_t i = 0; i < flattened.numel(); i++) { - auto value = flattened[i].item(); - if (std::isinf(value) || std::isnan(value)) { - assert(false); - } - } + assert_inf_and_nan(tensor_name, *check_tensor); }); } diff --git a/torch/csrc/inductor/aoti_torch/utils.h b/torch/csrc/inductor/aoti_torch/utils.h index 44ca34b1c6e8..6e7bd355c57c 100644 --- a/torch/csrc/inductor/aoti_torch/utils.h +++ b/torch/csrc/inductor/aoti_torch/utils.h @@ -48,6 +48,21 @@ inline AtenTensorHandle new_tensor_handle(at::Tensor&& tensor) { return tensor_pointer_to_tensor_handle(new_tensor); } +inline void assert_inf_and_nan( + const std::string& tensor_name, + at::Tensor& check_tensor) { + auto flattened = check_tensor.view({-1}); + + for (int64_t i = 0; i < flattened.numel(); i++) { + auto value = flattened[i].item(); + if (std::isinf(value)) { + throw std::runtime_error("At least one INF in " + tensor_name); + } else if (std::isnan(value)) { + throw std::runtime_error("At least one NaN in " + tensor_name); + } + } +} + // utility functions to convert a pointer to an optional value template inline std::optional pointer_to_optional(T* ptr) { From b9b89ed638d8cb5eb31e6e153219d10d77198eb7 Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Fri, 7 Jun 2024 06:38:34 -0700 Subject: [PATCH 481/706] [pipelining] fix LoopedBFS (#127796) # Issues Currently two issues need to be fixed with LoopedBFS: 1. The wrap around send operation to the looped around stage blocks will cause a hang. For some reason this doesn't surface on single node, but on multihost this surfaces in a hang. image 2. When microbatches are popped off in `backward_one_chunk` will automatically use the `bwd_chunk_id` starting from 0. This works for interleaved 1f1b and 1f1b, but for loopedBFS we want to pop from starting at `num_microbatches - 1`. Same needs to be fixed for gpipe? # Changes - Update LoopedBFS implementation to share `_step_microbatches` with `Interleaved1F1B` - Also share the tests between the two schedules for varying num_microbatches, local_stages, and world_sizes - Update `backward_one_chunk` to optionally take a `bwd_chunk_id` argument. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127796 Approved by: https://github.com/wconstab --- test/distributed/pipelining/test_schedule.py | 8 +- .../pipelining/PipelineSchedule.py | 282 +++++++++--------- 2 files changed, 147 insertions(+), 143 deletions(-) diff --git a/test/distributed/pipelining/test_schedule.py b/test/distributed/pipelining/test_schedule.py index d040efdc7522..7bd5825ce9ed 100644 --- a/test/distributed/pipelining/test_schedule.py +++ b/test/distributed/pipelining/test_schedule.py @@ -556,13 +556,15 @@ def _validate_pipeline_order( if len(error_msg) != 0: self.fail(f"Error at timestep {timestep}: " + ",".join(error_msg)) - def test_pipeline_order(self): + @parametrize("ScheduleClass", [ScheduleInterleaved1F1B, ScheduleLoopedBFS]) + def test_pipeline_order(self, ScheduleClass): # Define a list of test cases with varying num_local_stages, num_microbatches, and group_size # These should succeed since num_microbatches % group_size == 0 test_cases = [ # small number of stages (2, 2, 2), (2, 4, 4), + (2, 8, 2), (2, 8, 4), (2, 8, 8), (4, 4, 4), @@ -597,13 +599,15 @@ def test_pipeline_order(self): for i in range(num_local_stages) ] - schedule = ScheduleInterleaved1F1B(stages, num_microbatches) + schedule = ScheduleClass(stages, num_microbatches) # print(format_pipeline_order(schedule.pipeline_order)) self._validate_pipeline_order( schedule.pipeline_order, num_microbatches, num_stages ) +instantiate_parametrized_tests(TestSchedulePlan) + if __name__ == "__main__": # Run only the TestSchedulePlan tests (single process) loader = unittest.TestLoader() diff --git a/torch/distributed/pipelining/PipelineSchedule.py b/torch/distributed/pipelining/PipelineSchedule.py index f3d64189fe0e..964d4e88bece 100644 --- a/torch/distributed/pipelining/PipelineSchedule.py +++ b/torch/distributed/pipelining/PipelineSchedule.py @@ -566,7 +566,6 @@ class PipelineScheduleMulti(_PipelineSchedule): """ Base class for multi-stage schedules. Implements the `step` method. - Derived classes should implement `_step_microbatches`. """ def __init__( @@ -596,6 +595,8 @@ def __init__( # Self attributes self._stages = stages self._num_stages = stages[0].num_stages + self.pp_group_size = stages[0].group_size + self.rank = stages[0].group_rank # Set the same has_backward flag for stage object for stage in self._stages: stage.has_backward = self._has_backward @@ -604,6 +605,9 @@ def __init__( lambda stage: stage.is_last and self._loss_fn is not None ) + # This will be set during init of derived schedules + self.pipeline_order: Dict[int, List[Optional[_Action]]] = {} + def step(self, *args, target=None, losses: Optional[List] = None, **kwargs): """ Run one iteration of the pipeline schedule with *whole-batch* input. @@ -639,6 +643,98 @@ def step(self, *args, target=None, losses: Optional[List] = None, **kwargs): # Does not contain the last stage return None + def _step_microbatches( + self, + arg_mbs: Optional[List] = None, + kwarg_mbs: Optional[List] = None, + target_mbs: Optional[List] = None, + losses: Optional[List] = None, + ): + """ + Operate on the microbatches for looped schedules (multiple stages on each rank). + + TODO: Does not use sorted_batch_isend_irecv(). As a result, this schedule does + not support models with skip connections. + """ + arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) + + # Based on the plan in Step 1 created in __init__: + # 2. Perform communication based on the pipeline_order + stage_index_to_stage: Dict[int, _PipelineStageBase] = { + stage.stage_index: stage for stage in self._stages + } + prev_rank: int = (self.rank - 1) % self.pp_group_size + next_rank: int = (self.rank + 1) % self.pp_group_size + + for time_step, action in enumerate(self.pipeline_order[self.rank]): + prev_rank_ops = self.pipeline_order[prev_rank] + next_rank_ops = self.pipeline_order[next_rank] + ops: List[dist.P2POp] = [] + if action is not None: + computation_type, mb_index, stage_index = action + if computation_type == _ComputationType.FORWARD: + # perform forward computation + stage = stage_index_to_stage[stage_index] + output = stage.forward_one_chunk( + mb_index, arg_mbs[mb_index], kwarg_mbs[mb_index] + ) + self._maybe_compute_loss(stage, output, target_mbs, mb_index) + ops.extend(stage.get_fwd_send_ops(mb_index)) + elif computation_type == _ComputationType.BACKWARD: + # perform backward computation + stage = stage_index_to_stage[stage_index] + loss = self._maybe_get_loss(stage, mb_index) + stage.backward_one_chunk(mb_index, loss=loss) + ops.extend(stage.get_bwd_send_ops(mb_index)) + else: + raise ValueError(f"Unknown computation type {computation_type}") + + # Look at the neighboring ranks for this current timestep and determine whether + # this current rank needs to do any recv communication + prev_rank_action = None + if time_step < len(prev_rank_ops): + prev_rank_action = prev_rank_ops[time_step] + if prev_rank_action is not None: + computation_type, mb_index, stage_index = prev_rank_action + # Only handle sends for the forward from a previous rank + if computation_type == _ComputationType.FORWARD: + # If not the last stage, then receive fwd activations + if stage_index != self._num_stages - 1: + # TODO: We are assuming that stage will always receive from stage-1 + # however that is not necessarily true of get_fwd_recv_ops + stage = stage_index_to_stage[stage_index + 1] + ops.extend(stage.get_fwd_recv_ops(mb_index)) + elif computation_type == _ComputationType.BACKWARD: + # Previous rank doing backward has no influence for the current rank forward recv + pass + else: + raise ValueError(f"Unknown computation type {computation_type}") + + next_rank_action = None + if time_step < len(next_rank_ops): + next_rank_action = next_rank_ops[time_step] + if next_rank_action is not None: + computation_type, mb_index, stage_index = next_rank_action + # Only handle receives for the backwards from a next rank + if computation_type == _ComputationType.FORWARD: + # Next rank doing forward has no influence for the current rank backward recv + pass + elif computation_type == _ComputationType.BACKWARD: + # If not the first stage, then receive bwd gradients + if stage_index != 0: + # TODO: We are assuming that stage will always receive from stage+1 + # however that is not necessarily true of get_bwd_recv_ops + stage = stage_index_to_stage[stage_index - 1] + ops.extend(stage.get_bwd_recv_ops(mb_index)) + else: + raise ValueError(f"Unknown computation type {computation_type}") + + # do the communication + if ops: + _batch_p2p(ops).wait() + # Return losses if there is a container passed in + self._update_losses(self._stages, losses) + class ScheduleLoopedBFS(PipelineScheduleMulti): """ @@ -650,62 +746,58 @@ class ScheduleLoopedBFS(PipelineScheduleMulti): microbatches at once. """ - def _step_microbatches( + def __init__( self, - arg_mbs: Optional[List] = None, - kwarg_mbs: Optional[List] = None, - target_mbs: Optional[List] = None, # TODO - losses: Optional[List] = None, # TODO + stages: List[_PipelineStageBase], + n_microbatches: int, + loss_fn: Optional[Callable] = None, + output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, ): - """ - Run one iteration of the pipeline schedule with list of microbatches. - Will go through all the microbatches according to the Looped BFS schedule. - - Args: - microbatches: list of microbatch args. - """ - # Pre-process inputs - if arg_mbs is not None: - # TODO: fix this so it is preset - self._n_microbatches = len(arg_mbs) - assert len(arg_mbs) == self._n_microbatches - else: - arg_mbs = [()] * self._n_microbatches - - if kwarg_mbs is not None: - assert len(kwarg_mbs) == self._n_microbatches - else: - kwarg_mbs = [{}] * self._n_microbatches - - for stage in self._stages: - for i in range(self._n_microbatches): - with record_function(f"Stage {stage.stage_index} Forward"): - ops = stage.get_fwd_recv_ops(i) - if ops: - _batch_p2p(ops, desc="fwd_recv").wait() + super().__init__( + stages=stages, + n_microbatches=n_microbatches, + loss_fn=loss_fn, + output_merge_spec=output_merge_spec, + ) - output = stage.forward_one_chunk(i, arg_mbs[i], kwarg_mbs[i]) - self._maybe_compute_loss(stage, output, target_mbs, i) + # 1. Create the pipeline_order (all ranks do this calculation) + # This will be used to keep track of the current state of the entire pipeline + # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...] + self.pipeline_order: Dict[int, List[Optional[_Action]]] = {} + # ======================================================================== + for rank in range(self.pp_group_size): + rank_ops = self._calculate_single_rank_operations(rank) + self.pipeline_order[rank] = rank_ops - ops = stage.get_fwd_send_ops(i) - if ops: - _batch_p2p(ops, desc="fwd_send") + def _calculate_single_rank_operations(self, rank): + n_local_stages = len(self._stages) + stage_indices = range( + rank, self.pp_group_size * n_local_stages, self.pp_group_size + ) - for stage in reversed(self._stages): - for i in range(self._n_microbatches): - with record_function(f"Stage {stage.stage_index} Backward"): - ops = stage.get_bwd_recv_ops(i) - if ops: - _batch_p2p(ops, desc="bwd_recv").wait() + # Store the list of operations used for that rank + rank_ops: List[Optional[_Action]] = [] + # Pre-padding, rank starts with no-ops based on the warmup. + for _ in range(rank): + rank_ops.append(None) - loss = self._maybe_get_loss(stage, i) - stage.backward_one_chunk(i, loss=loss) + for stage_index in stage_indices: + for mb_index in range(self._n_microbatches): + rank_ops.append( + _Action(_ComputationType.FORWARD, mb_index, stage_index) + ) - ops = stage.get_bwd_send_ops(i) - if ops: - _batch_p2p(ops, desc="bwd_send") + # wait for the first backward to trickle up + # which is 2 for every hop away + post_warmup_ops = 2 * (self.pp_group_size - 1 - rank) + rank_ops.extend([None] * post_warmup_ops) - self._update_losses(self._stages, losses) + for stage_index in reversed(stage_indices): + for mb_index in reversed(range(self._n_microbatches)): + rank_ops.append( + _Action(_ComputationType.BACKWARD, mb_index, stage_index) + ) + return rank_ops class ScheduleInterleaved1F1B(PipelineScheduleMulti): @@ -870,95 +962,3 @@ def backward_stage_index(step): for _ in range(self.pp_group_size - rank - 1): rank_ops.append(None) return rank_ops - - def _step_microbatches( - self, - arg_mbs: Optional[List] = None, - kwarg_mbs: Optional[List] = None, - target_mbs: Optional[List] = None, - losses: Optional[List] = None, - ): - """ - Operate on the microbatches using the interleaved 1f1b schedule. - - TODO: Interleaved 1F1B does not use sorted_batch_isend_irecv(). As a result, this schedule does - not support models with skip connections. - """ - arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) - - # Based on the plan in Step 1 created in __init__: - # 2. Perform communication based on the pipeline_order - stage_index_to_stage: Dict[int, _PipelineStageBase] = { - stage.stage_index: stage for stage in self._stages - } - prev_rank: int = (self.rank - 1) % self.pp_group_size - next_rank: int = (self.rank + 1) % self.pp_group_size - - for time_step, action in enumerate(self.pipeline_order[self.rank]): - prev_rank_ops = self.pipeline_order[prev_rank] - next_rank_ops = self.pipeline_order[next_rank] - ops: List[dist.P2POp] = [] - if action is not None: - computation_type, mb_index, stage_index = action - if computation_type == _ComputationType.FORWARD: - # perform forward computation - stage = stage_index_to_stage[stage_index] - output = stage.forward_one_chunk( - mb_index, arg_mbs[mb_index], kwarg_mbs[mb_index] - ) - self._maybe_compute_loss(stage, output, target_mbs, mb_index) - ops.extend(stage.get_fwd_send_ops(mb_index)) - elif computation_type == _ComputationType.BACKWARD: - # perform backward computation - stage = stage_index_to_stage[stage_index] - loss = self._maybe_get_loss(stage, mb_index) - stage.backward_one_chunk(mb_index, loss=loss) - ops.extend(stage.get_bwd_send_ops(mb_index)) - else: - raise ValueError(f"Unknown computation type {computation_type}") - - # Look at the neighboring ranks for this current timestep and determine whether - # this current rank needs to do any recv communication - prev_rank_action = None - if time_step < len(prev_rank_ops): - prev_rank_action = prev_rank_ops[time_step] - if prev_rank_action is not None: - computation_type, mb_index, stage_index = prev_rank_action - # Only handle sends for the forward from a previous rank - if computation_type == _ComputationType.FORWARD: - # If not the last stage, then receive fwd activations - if stage_index != self._num_stages - 1: - # TODO: We are assuming that stage will always receive from stage-1 - # however that is not necessarily true of get_fwd_recv_ops - stage = stage_index_to_stage[stage_index + 1] - ops.extend(stage.get_fwd_recv_ops(mb_index)) - elif computation_type == _ComputationType.BACKWARD: - # Previous rank doing backward has no influence for the current rank forward recv - pass - else: - raise ValueError(f"Unknown computation type {computation_type}") - - next_rank_action = None - if time_step < len(next_rank_ops): - next_rank_action = next_rank_ops[time_step] - if next_rank_action is not None: - computation_type, mb_index, stage_index = next_rank_action - # Only handle receives for the backwards from a next rank - if computation_type == _ComputationType.FORWARD: - # Next rank doing forward has no influence for the current rank backward recv - pass - elif computation_type == _ComputationType.BACKWARD: - # If not the first stage, then receive bwd gradients - if stage_index != 0: - # TODO: We are assuming that stage will always receive from stage+1 - # however that is not necessarily true of get_bwd_recv_ops - stage = stage_index_to_stage[stage_index - 1] - ops.extend(stage.get_bwd_recv_ops(mb_index)) - else: - raise ValueError(f"Unknown computation type {computation_type}") - - # do the communication - if ops: - _batch_p2p(ops).wait() - # Return losses if there is a container passed in - self._update_losses(self._stages, losses) From 6c824cd9fbad51ba44aab5936c69a146a3f68d3b Mon Sep 17 00:00:00 2001 From: Xilun Wu <12968408+XilunWu@users.noreply.github.com> Date: Tue, 4 Jun 2024 13:45:57 -0700 Subject: [PATCH 482/706] [BE][c10d] fix use of TORCH_ERROR in TCPStore libuv backend (#127956) **Summary** The use of TORCH_ERROR in TCPStore libuv backend code needs update. Differential Revision: [D58259589](https://our.internmc.facebook.com/intern/diff/D58259589) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127956 Approved by: https://github.com/shuqiangzhang, https://github.com/cyyever --- .../distributed/c10d/TCPStoreLibUvBackend.cpp | 62 +++++++++++++++---- 1 file changed, 49 insertions(+), 13 deletions(-) diff --git a/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp b/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp index 845803c5e17e..c70b8e7c6e87 100644 --- a/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp +++ b/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp @@ -186,10 +186,14 @@ class UvTcpServer : public UvTcpSocket { int uv_res = uv_tcp_open((uv_tcp_t*)res->unsafeGetStream(), socket); TORCH_CHECK( uv_res == 0, - "Failed to open existing socket. socket:{} code:{} name:{} message:{}", + "Failed to open existing socket. ", + "socket: ", socket, + ", code: ", uv_res, + ", name: ", uv_err_name(uv_res), + ", message: ", uv_strerror(uv_res)); res->cacheSocketPort(); @@ -221,30 +225,42 @@ class UvTcpServer : public UvTcpSocket { } TORCH_CHECK( uv_res == 0, - "UV Store addr parsing failure. useIpv6:{} code:{} name:{} message:{}", + "UV Store addr parsing failure. ", + "useIpv6: ", useIpv6, + ", code: ", uv_res, + ", name: ", uv_err_name(uv_res), + ", message: ", uv_strerror(uv_res)); uv_res = uv_tcp_bind(res->unsafeGetSocket(), (const struct sockaddr*)&addr, 0); TORCH_CHECK( uv_res == 0, - "UV Store bind failed. useIpv6:{} code:{} name:{} message:{}", + "The server socket has failed to bind. ", + "useIpv6: ", useIpv6, + ", code: ", uv_res, + ", name: ", uv_err_name(uv_res), + ", message: ", uv_strerror(uv_res)); uv_res = uv_listen(res->unsafeGetStream(), DEFAULT_BACKLOG, on_new_connection); TORCH_CHECK( uv_res == 0, - "UV Store listen failed. useIpv6:{} code:{} name:{} message:{}", + "The server socket has failed to listen on any local network address. ", + "useIpv6: ", useIpv6, + ", code: ", uv_res, + ", name: ", uv_err_name(uv_res), + ", message: ", uv_strerror(uv_res)); res->cacheSocketPort(); @@ -265,9 +281,12 @@ class UvTcpServer : public UvTcpSocket { uv_accept(unsafeGetStream(), (uv_stream_t*)socket->unsafeGetHandle()); TORCH_CHECK( res == 0, - "Failed to accept socket. code:{} name:{} desc:{}.", + "Failed to accept socket. ", + "code: ", res, + ", name: ", uv_err_name(res), + ", message: ", uv_strerror(res)); } @@ -458,9 +477,12 @@ class ChunkedStream { if (buff_idx >= buffers.size() && remaining > 0) { TORCH_CHECK( false, - "Trying to read past end of buffer buffer_idx:{} available:{} remaining:{}", + "Trying to read past end of buffer. ", + "buffer_idx: ", buff_idx, + ", available: ", buffers.size(), + ", remaining: ", remaining); } } @@ -498,8 +520,10 @@ class ChunkedStream { return false; TORCH_CHECK( size <= MAX_STRING_LEN, - "Invalid string size. size:{} max:{}", + "Invalid string size. ", + "size: ", size, + ", max: ", MAX_STRING_LEN); if (available() < size) @@ -515,8 +539,10 @@ class ChunkedStream { auto size_in_bytes = size * sizeof(uint8_t); TORCH_CHECK( size_in_bytes <= MAX_PAYLOAD_LEN, - "Invalid payload size. size: {} max:{}", + "Invalid payload size. ", + "size: ", size_in_bytes, + ", max: ", MAX_PAYLOAD_LEN); if (available() < size_in_bytes) @@ -782,8 +808,10 @@ class UvClient : public UvTcpSocket { return false; TORCH_CHECK( key_count <= MAX_KEY_COUNT, - "Too many keys being waited. keys:{} max:{}", + "Too many keys being waited. ", + "keys: ", key_count, + ", max: ", MAX_KEY_COUNT); std::vector keys(key_count); @@ -810,8 +838,10 @@ class UvClient : public UvTcpSocket { } TORCH_CHECK( key_count <= MAX_KEY_COUNT, - "Too many keys being waited. keys:{} max:{}", + "Too many keys being waited. ", + "keys: ", key_count, + ", max: ", MAX_KEY_COUNT); std::vector keys(key_count); @@ -872,8 +902,10 @@ class UvClient : public UvTcpSocket { } TORCH_CHECK( key_count <= MAX_KEY_COUNT, - "Too many keys with multi_get. keys:{} max:{}", + "Too many keys with multi_get. ", + "keys: ", key_count, + ", max: ", MAX_KEY_COUNT); StreamWriter sw(iptr()); @@ -898,8 +930,10 @@ class UvClient : public UvTcpSocket { } TORCH_CHECK( key_count <= MAX_KEY_COUNT, - "Too many keys with multi_get. keys:{} max:{}", + "Too many keys with multi_get. ", + "keys: ", key_count, + ", max: ", MAX_KEY_COUNT); for (const auto _ : c10::irange(key_count)) { @@ -988,9 +1022,11 @@ void LibUVStoreDaemon::init(const TCPStoreOptions& opts) { port_ = tcpServer->port(); TORCH_CHECK( port_ == opts.port || opts.port == 0, // zero means use any port - "listen fd {} is bound to port {}, expected to be bound to port {}", + "listen fd ", *opts.masterListenFd, + " is bound to port ", port_, + ", expected to be bound to port ", opts.port); } From 85758fa5ae4d232902e7da6eaa2bcc33cc96b921 Mon Sep 17 00:00:00 2001 From: Xilun Wu <12968408+XilunWu@users.noreply.github.com> Date: Thu, 6 Jun 2024 22:59:24 -0700 Subject: [PATCH 483/706] [c10d][TCPStore] make TCPStore server use libuv by default (#127957) **Summary** This PR switches the default TCPStore server backend to a new implementation that utilizes [`libuv`](https://github.com/libuv/libuv) for significantly lower initialization time and better scalability: image We hope this improvement would benefit users from a much shorter startup time in large-scale jobs. Eventually, we hope to fully replace the old TCPStore backend implementation with the libuv one. **What it changes** This PR changes the underlying TCPStore server backend to `libuv` if users don't explicitly specify to use the old TCPStore server. This change is not supposed to cause any user notice except significant faster TCPStore startup for large-scale jobs. One thing to note is, we do not support the initialization approach where user passes in a socket for libuv backend. We plan to support it as a next step but we choose to disable it before fully testing. If you are initializing TCPStore in this approach, you can see the next section to remain using the old TCPStore server. **Fallback/Remain using the old TCPStore server** For users who want to stay with the old TCPStore backend, there're 3 ways: 1. If user is directly instantiating TCPStore object, user can pass in argument `use_libuv=False` to use the old TCPStore server backend e.g. `store = torch.distributed.TCPStore(..., use_libuv=False)`. 2. Or, specify the TCPStore backend option in `init_method` when calling default ProcessGroup init, e.g. `torch.distributed.init_process_group(..., init_method="{YOUR_RENDEZVOUS_METHOD}://{YOUR_HOSTNAME}:{YOUR_PORT}?use_libuv=0")` 3. Or, user can set environment variable `USE_LIBUV` to `"0"` when launching. These 3 approach are in order of precedence. That being said, if user specifies `use_libuv=0` in `init_method` and also sets environment var `USE_LIBUV="1"`, the former will take effect and the TCPStore backend instantiated will be the old one instead of the one using libuv. **Operating Systems Compatibility** From the CI signals, we believe the new implementation has the same behavior as the old TCPStore server on all supported platforms. If you notice any behavior discrepancy, please file an issue with `oncall: distributed` label. **Test Plan** `pytest test/distributed/test_store.py` image note: `TestMultiThreadedWait::test_wait` is a broken test that has been there for some time. `test/distributed/elastic/utils/distributed_test.py` image **TODO** 1. Update the doc at - https://pytorch.org/docs/stable/distributed.html#distributed-key-value-store - https://pytorch.org/docs/stable/distributed.html#tcp-initialization 2. Make torch elastic rendezvous to use libuv TCPStore as well. See `torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py` cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k @kurman 3. Test if libuv backend is okay with initialization with socket. Change `LibUvTCPStoreTest::test_take_over_listen_socket`. **Test Plan** `pytest test/distributed/test_store.py` image note: `TestMultiThreadedWait::test_wait` is a broken test that has been there for some time. `test/distributed/elastic/utils/distributed_test.py` image Differential Revision: [D58259591](https://our.internmc.facebook.com/intern/diff/D58259591) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127957 Approved by: https://github.com/kurman ghstack dependencies: #127956 --- test/distributed/test_store.py | 109 +++++++++++++++--- torch/csrc/distributed/c10d/TCPStore.cpp | 11 ++ torch/csrc/distributed/c10d/TCPStore.hpp | 4 +- .../distributed/c10d/TCPStoreLibUvBackend.cpp | 1 + torch/csrc/distributed/c10d/init.cpp | 3 +- torch/distributed/rendezvous.py | 16 ++- torch/testing/_internal/common_distributed.py | 2 +- 7 files changed, 123 insertions(+), 23 deletions(-) diff --git a/test/distributed/test_store.py b/test/distributed/test_store.py index 8de265a30cd8..cd126cc0d358 100644 --- a/test/distributed/test_store.py +++ b/test/distributed/test_store.py @@ -265,13 +265,17 @@ def num_keys_total(self): class TCPStoreTest(TestCase, StoreTestBase): + _use_libuv = False + def _create_store(self): - store = create_tcp_store() + store = create_tcp_store(use_libuv=self._use_libuv) store.set_timeout(timedelta(seconds=300)) return store def _create_store_with_ws(self, addr, world_size): - return create_tcp_store(addr, world_size, wait_for_workers=False) + return create_tcp_store( + addr, world_size, wait_for_workers=False, use_libuv=self._use_libuv + ) def test_address_already_in_use(self): err_msg_reg = "^The server socket has failed to listen on any local " @@ -282,8 +286,14 @@ def test_address_already_in_use(self): # Use noqa to silence flake8. # Need to store in an unused variable here to ensure the first # object is not destroyed before the second object is created. - store1 = dist.TCPStore(addr, port, 1, True) # noqa: F841 - store2 = dist.TCPStore(addr, port, 1, True) # noqa: F841 + store1 = dist.TCPStore( + addr, port, 1, True, use_libuv=self._use_libuv + ) # noqa: F841 + store2 = dist.TCPStore( + addr, port, 1, True, use_libuv=self._use_libuv + ) # noqa: F841 + self.assertEqual(store1.libuvBackend, self._use_libuv) + self.assertEqual(store2.libuvBackend, self._use_libuv) @retry_on_connect_failures def test_multitenancy(self): @@ -293,8 +303,14 @@ def test_multitenancy(self): # Use noqa to silence flake8. # Need to store in an unused variable here to ensure the first # object is not destroyed before the second object is created. - store1 = dist.TCPStore(addr, port, 1, True, multi_tenant=True) # type: ignore[call-arg] # noqa: F841 - store2 = dist.TCPStore(addr, port, 1, True, multi_tenant=True) # type: ignore[call-arg] # noqa: F841 + store1 = dist.TCPStore( + addr, port, 1, True, multi_tenant=True, use_libuv=self._use_libuv + ) # type: ignore[call-arg] # noqa: F841 + store2 = dist.TCPStore( + addr, port, 1, True, multi_tenant=True, use_libuv=self._use_libuv + ) # type: ignore[call-arg] # noqa: F841 + self.assertEqual(store1.libuvBackend, self._use_libuv) + self.assertEqual(store2.libuvBackend, self._use_libuv) @skip_if_win32() @retry_on_connect_failures @@ -308,6 +324,7 @@ def test_init_pg_and_rpc_with_same_socket(self): # We internally use a multi-tenant TCP store. Both PG and RPC should successfully # initialize even when using the same socket address. + os.environ["USE_LIBUV"] = "1" if self._use_libuv else "0" dist.init_process_group( backend="gloo", init_method="env://", @@ -325,6 +342,8 @@ def test_init_pg_and_rpc_with_same_socket(self): rpc_backend_options=backend_opts, ) + del os.environ["USE_LIBUV"] + assert "USE_LIBUV" not in os.environ rpc.shutdown() dist.destroy_process_group() @@ -335,8 +354,16 @@ def test_take_over_listen_socket(self): addr, port, *_ = listen_sock.getsockname() listen_fd = listen_sock.detach() - store = dist.TCPStore(addr, port, 1, is_master=True, master_listen_fd=listen_fd) + store = dist.TCPStore( + addr, + port, + 1, + is_master=True, + master_listen_fd=listen_fd, + use_libuv=self._use_libuv, + ) + self.assertEqual(store.libuvBackend, self._use_libuv) store.set("key", "value") self.assertEqual(b"value", store.get("key")) @@ -374,7 +401,11 @@ def test_numkeys_delkeys(self): def _create_client(self, index, addr, port, world_size): client_store = dist.TCPStore( - addr, port, world_size=world_size, timeout=timedelta(seconds=10) + addr, + port, + world_size=world_size, + timeout=timedelta(seconds=10), + use_libuv=self._use_libuv, ) self.assertEqual(b"value", client_store.get("key")) client_store.set(f"new_key{index}", f"new_value{index}") @@ -388,6 +419,7 @@ def _create_client(self, index, addr, port, world_size): def _multi_worker_helper(self, world_size): addr = DEFAULT_HOSTNAME server_store = self._create_store_with_ws(addr, world_size) + self.assertEqual(server_store.libuvBackend, self._use_libuv) server_store.set("key", "value") port = server_store.port @@ -403,6 +435,7 @@ def test_multi_worker_with_nonfixed_world_size(self): def test_append(self): store = self._create_store() + self.assertEqual(store.libuvBackend, self._use_libuv) store.set("foo", "po") store.append("foo", "tato") store.append("bar", "po") @@ -412,12 +445,14 @@ def test_append(self): def test_multi_set(self): store = self._create_store() + self.assertEqual(store.libuvBackend, self._use_libuv) store.multi_set(["foo", "bar"], ["po", "tato"]) self.assertEqual(b"po", store.get("foo")) self.assertEqual(b"tato", store.get("bar")) def test_multi_get(self): store = self._create_store() + self.assertEqual(store.libuvBackend, self._use_libuv) store.set("foo", "po") store.set("bar", "tato") v0, v1 = store.multi_get(["foo", "bar"]) @@ -430,7 +465,14 @@ def test_store_timeout_on_missing_clients(self): r"Timed out after \d+ seconds waiting for clients. \d+/\d+ clients joined.", ): # world_size is 2 so it should timeout - dist.TCPStore("localhost", 0, 2, True, timeout=timedelta(seconds=2)) + dist.TCPStore( + "localhost", + 0, + 2, + True, + timeout=timedelta(seconds=2), + use_libuv=self._use_libuv, + ) # when wait_for_workers is not set, then there should be no exception raised dist.TCPStore( @@ -440,10 +482,13 @@ def test_store_timeout_on_missing_clients(self): True, timeout=timedelta(seconds=2), wait_for_workers=False, + use_libuv=self._use_libuv, ) class LibUvTCPStoreTest(TCPStoreTest): + _use_libuv = True + def _create_store(self): store = create_tcp_store(use_libuv=True) store.set_timeout(timedelta(seconds=300)) @@ -454,6 +499,33 @@ def _create_store_with_ws(self, addr, world_size): addr, world_size, wait_for_workers=False, use_libuv=True ) + def test_take_over_listen_socket(self): + """ + override the take_over_listen_socket test in TCPStoreTest. + Reason: we have not thoroughly tested libuv TCPStore initialization using + open Socket so we decide to not support this use for now. + TODO (xilunwu): enable this use case + """ + listen_sock: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + listen_sock.bind(("localhost", 0)) + addr, port, *_ = listen_sock.getsockname() + listen_fd = listen_sock.detach() + + err_msg_reg = ( + "^The libuv TCPStore backend does not support " + "initialization with an listen fd" + ) + + with self.assertRaisesRegex(NotImplementedError, err_msg_reg): + store = dist.TCPStore( + addr, + port, + 1, + is_master=True, + master_listen_fd=listen_fd, + use_libuv=self._use_libuv, + ) + class PrefixTCPStoreTest(TestCase, StoreTestBase): def setUp(self): @@ -769,7 +841,7 @@ def test_extended_methods_fallbacks(self): class TestMultiThreadedWait(MultiThreadedTestCase): - # TODO: Use less hacky means of instantiating stores. + # TODO (xilunwu): Use less hacky means of instantiating stores. # Note, stores accumulate values per test. stores = [ dist.FileStore(tempfile.NamedTemporaryFile(delete=False).name, 1), @@ -777,9 +849,9 @@ class TestMultiThreadedWait(MultiThreadedTestCase): dist.PrefixStore( "pre", dist.FileStore(tempfile.NamedTemporaryFile(delete=False).name, 1) ), - create_tcp_store(), + create_tcp_store(use_libuv=False), create_tcp_store(use_libuv=True), - dist.PrefixStore("pre", create_tcp_store()), + dist.PrefixStore("pre", create_tcp_store(use_libuv=False)), dist.PrefixStore("pre", create_tcp_store(use_libuv=True)), ] @@ -872,7 +944,12 @@ def handler(a, b): self.assertTrue(rank_res[1], "rank1") -class InitPgWithUvStore(TestCase): +class InitPgWithNonUvStore(TestCase): + """ + This test shows how to use the legacy TCPStore (non-libuv) backend since libuv is now + the default backend. + """ + def tearDown(self): super().tearDown() os.environ.pop("USE_LIBUV", None) @@ -885,13 +962,13 @@ def test_with_url_param(self): "gloo", rank=0, world_size=1, - init_method=f"tcp://{DEFAULT_HOSTNAME}:{port}?use_libuv=1", + init_method=f"tcp://{DEFAULT_HOSTNAME}:{port}?use_libuv=0", ) self._run_test() def test_with_env_var(self): port = common.find_free_port() - os.environ["USE_LIBUV"] = "1" + os.environ["USE_LIBUV"] = "0" os.environ["MASTER_ADDR"] = DEFAULT_HOSTNAME os.environ["MASTER_PORT"] = str(port) dist.init_process_group("gloo", rank=0, world_size=1, init_method="env://") @@ -905,7 +982,7 @@ def _run_test(self): while isinstance(store, dist.PrefixStore): store = store.underlying_store self.assertTrue(isinstance(store, dist.TCPStore)) - self.assertTrue(store.libuvBackend) + self.assertFalse(store.libuvBackend) dist.destroy_process_group() diff --git a/torch/csrc/distributed/c10d/TCPStore.cpp b/torch/csrc/distributed/c10d/TCPStore.cpp index aee1d7677dc4..a716bf666755 100644 --- a/torch/csrc/distributed/c10d/TCPStore.cpp +++ b/torch/csrc/distributed/c10d/TCPStore.cpp @@ -291,6 +291,17 @@ TCPStore::TCPStore(std::string host, const TCPStoreOptions& opts) TORCH_CHECK( ::c10d::detail::is_libuv_tcpstore_backend_available(), "use_libuv was requested but PyTorch was build without libuv support"); + + if (opts.masterListenFd.has_value()) { + // TODO(xilunwu): support this init method after testing + constexpr auto* msg = + "The libuv TCPStore backend does not support initialization with an listen fd. " + "Please switch to the legacy TCPStore by setting environment variable USE_LIBUV " + "to \"0\"."; + C10D_ERROR(msg); + C10_THROW_ERROR(NotImplementedError, msg); + return; + } } Socket::initialize(); diff --git a/torch/csrc/distributed/c10d/TCPStore.hpp b/torch/csrc/distributed/c10d/TCPStore.hpp index 7080d50136e9..25783f2d2ace 100644 --- a/torch/csrc/distributed/c10d/TCPStore.hpp +++ b/torch/csrc/distributed/c10d/TCPStore.hpp @@ -63,7 +63,7 @@ struct TCPStoreOptions { std::optional masterListenFd = c10::nullopt; // A boolean value indicating whether to use the experimental libUV backend. - bool useLibUV = false; + bool useLibUV = true; }; class TORCH_API TCPStore : public Store { @@ -158,7 +158,7 @@ class TORCH_API TCPStore : public Store { const std::string keyPrefix_ = "/"; std::mutex activeOpLock_; std::unordered_map clientCounters_; - bool usingLibUv_ = false; + bool usingLibUv_ = true; }; } // namespace c10d diff --git a/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp b/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp index c70b8e7c6e87..d162149ed3a4 100644 --- a/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp +++ b/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 027e87efee56..50521c6ffa21 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -1391,6 +1391,7 @@ the server to establish a connection. wait_for_workers (bool, optional): Whether to wait for all the workers to connect with the server store. This is only applicable when world_size is a fixed value. Default is True. multi_tenant (bool, optional): If True, all ``TCPStore`` instances in the current process with the same host/port will use the same underlying ``TCPServer``. Default is False. master_listen_fd (int, optional): If specified, the underlying ``TCPServer`` will listen on this file descriptor, which must be a socket already bound to ``port``. Useful to avoid port assignment races in some scenarios. Default is None (meaning the server creates a new socket and attempts to bind it to ``port``). + use_libuv (bool, optional): If True, use libuv for ``TCPServer`` backend. Default is True. Example:: >>> import torch.distributed as dist >>> from datetime import timedelta @@ -1440,7 +1441,7 @@ Example:: py::arg("wait_for_workers") = true, py::arg("multi_tenant") = false, py::arg("master_listen_fd") = py::none(), - py::arg("use_libuv") = false, + py::arg("use_libuv") = true, py::call_guard()) .def( "collect_client_counters", diff --git a/torch/distributed/rendezvous.py b/torch/distributed/rendezvous.py index 19936f910b8a..8bef92275edd 100644 --- a/torch/distributed/rendezvous.py +++ b/torch/distributed/rendezvous.py @@ -58,6 +58,12 @@ def _query_to_dict(query: str) -> Dict[str, str]: return {pair[0]: pair[1] for pair in (pair.split("=") for pair in filter(None, query.split("&")))} +def _get_use_libuv_from_query_dict(query_dict: Dict[str, str]) -> bool: + # libuv is the default backend for TCPStore. To enable the non-libuv backend, + # user can explicitly specify ``use_libuv=0`` in the URL parameter. + return query_dict.get("use_libuv", os.environ.get("USE_LIBUV", "1")) == "1" + + def _rendezvous_helper(url: str, rank: int, world_size_opt: Optional[int], **kwargs): result = urlparse(url) if world_size_opt is None: @@ -145,13 +151,16 @@ def _torchelastic_use_agent_store() -> bool: return os.environ.get("TORCHELASTIC_USE_AGENT_STORE", None) == str(True) -def _create_c10d_store(hostname, port, rank, world_size, timeout, use_libuv=False) -> Store: +def _create_c10d_store(hostname, port, rank, world_size, timeout, use_libuv=True) -> Store: """ Smartly creates a c10d Store object on ``rank`` based on whether we need to re-use agent store. The TCPStore server is assumed to be hosted on ``hostname:port``. + By default, the TCPStore server uses the asynchronous implementation + ``LibUVStoreDaemon`` which utilizes libuv. + If ``torchelastic_use_agent_store()`` is ``True``, then it is assumed that the agent leader (node rank 0) hosts the TCPStore server (for which the endpoint is specified by the given ``hostname:port``). Hence @@ -194,7 +203,8 @@ def _error(msg): rank = int(query_dict["rank"]) world_size = int(query_dict["world_size"]) - use_libuv = query_dict.get("use_libuv", "0") == "1" + use_libuv = _get_use_libuv_from_query_dict(query_dict) + assert result.hostname is not None store = _create_c10d_store(result.hostname, result.port, rank, world_size, timeout, use_libuv) @@ -242,7 +252,7 @@ def _get_env_or_raise(env_var: str) -> str: master_addr = _get_env_or_raise("MASTER_ADDR") master_port = int(_get_env_or_raise("MASTER_PORT")) - use_libuv = query_dict.get("use_libuv", os.environ.get("USE_LIBUV", "0")) == "1" + use_libuv = _get_use_libuv_from_query_dict(query_dict) store = _create_c10d_store(master_addr, master_port, rank, world_size, timeout, use_libuv) diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index 80dc47210471..473e5c35e07a 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -357,7 +357,7 @@ def create_tcp_store( timeout=timedelta(minutes=5), wait_for_workers=True, jit_class=False, - use_libuv=False + use_libuv=True, ): """ Creates a TCP store. Retries if the chosen port is already in use. From 754e6d4ad0d3936641b9bb91aaead8fb5d29d44b Mon Sep 17 00:00:00 2001 From: Zain Rizvi Date: Fri, 7 Jun 2024 17:13:01 +0000 Subject: [PATCH 484/706] Make jobs with LF runners still pass lint (#128175) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128175 Approved by: https://github.com/huydhn --- .github/actionlint.yaml | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/.github/actionlint.yaml b/.github/actionlint.yaml index bb83775a59b2..f41a70ada6af 100644 --- a/.github/actionlint.yaml +++ b/.github/actionlint.yaml @@ -16,6 +16,17 @@ self-hosted-runner: - linux.8xlarge.nvidia.gpu - linux.16xlarge.nvidia.gpu - linux.g5.4xlarge.nvidia.gpu + # Organization-wide AWS Linux Runners on Linux Foundation account + - lf.linux.large + - lf.linux.2xlarge + - lf.linux.4xlarge + - lf.linux.12xlarge + - lf.linux.24xlarge + - lf.linux.arm64.2xlarge + - lf.linux.4xlarge.nvidia.gpu + - lf.linux.8xlarge.nvidia.gpu + - lf.linux.16xlarge.nvidia.gpu + - lf.linux.g5.4xlarge.nvidia.gpu # Repo-specific IBM hosted S390x runner - linux.s390x # Organization wide AWS Windows runners From 3aa623d407d3c031d8c2c337e43a752cda751467 Mon Sep 17 00:00:00 2001 From: BowenBao Date: Thu, 6 Jun 2024 18:52:14 -0700 Subject: [PATCH 485/706] Fix assume_constant_result for UnspecializedNNModuleVariable methods (#127695) Fixes #127509 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127695 Approved by: https://github.com/jansel --- test/dynamo/test_export.py | 24 ++++++++++++++++++++++++ torch/_dynamo/variables/functions.py | 3 +++ 2 files changed, 27 insertions(+) diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py index 9f1417e23247..dbf983faabb7 100644 --- a/test/dynamo/test_export.py +++ b/test/dynamo/test_export.py @@ -1509,6 +1509,30 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: graph, guards = torch._dynamo.export(model)(inp) self.assertEqual(model(inp), graph(inp)) + def test_export_with_constant_in_unspecialized_nn_module(self): + class Module(torch.nn.Module): + def __init__(self, y): + super().__init__() + self.y = y + + @torch._dynamo.assume_constant_result + def check(self): + return self.y[0].item() == 1 + + def forward(self, x): + # This line leads to module obj being tracked as UnspecializedNNModuleVariable in dynamo + self.device = x.device + + if self.check(): + return x + 1 + else: + return x + 2 + + model = Module(torch.tensor([1])) + inp = torch.ones(3, 4) + graph, _ = torch._dynamo.export(model)(inp) + self.assertEqual(model(inp), graph(inp)) + def test_export_decomp(self): def f(x): return x.t() + x.t() diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 745e29af4929..88bc94165349 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -339,6 +339,9 @@ def call_function( return self.obj.call_method( tx, self.fn.__name__, args, kwargs, constant=self.is_constant ) + if self.is_constant: + fn = getattr(self.obj.value, self.fn.__name__) + return invoke_and_store_as_constant(tx, fn, self.get_name(), args, kwargs) return super().call_function(tx, args, kwargs) def inspect_parameter_names(self): From b741819b0580204e6a6b60c62ce44dacaf7787c8 Mon Sep 17 00:00:00 2001 From: BowenBao Date: Thu, 6 Jun 2024 18:52:15 -0700 Subject: [PATCH 486/706] Fix 'get_attr' call in dynamo 'run_node' (#127696) Fixes #124858 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127696 Approved by: https://github.com/jansel ghstack dependencies: #127695 --- test/dynamo/test_decorators.py | 22 ++++++++++++++++++++++ torch/_dynamo/utils.py | 2 +- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_decorators.py b/test/dynamo/test_decorators.py index 890edca40ccc..94b0d5bf3bb6 100644 --- a/test/dynamo/test_decorators.py +++ b/test/dynamo/test_decorators.py @@ -465,6 +465,28 @@ def fn(a, b, c): self.assertEqual(cnt.frame_count, 1) + def test_assume_constant_result_on_user_defined_fn(self): + @torch._dynamo.assume_constant_result + def const_fn(n, s): + return torch.full([n], s) + + def fn(B): + B = const_fn(B.size(0), 13) + X = B * 2 + return X.tolist() + + B_list = [8] * 32 + + B = torch.tensor(B_list, dtype=torch.int32) + torch._dynamo.decorators.mark_static(B, 0) + + torch._dynamo.config.capture_scalar_outputs = True + torch._dynamo.config.capture_dynamic_output_shape_ops = True + + self.assertEqual( + fn(B), torch.compile(fn, backend="eager", fullgraph=True, dynamic=True)(B) + ) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 54c497be3781..d98c090543b1 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -1908,7 +1908,7 @@ def make_error_message(e): assert nnmodule is not None return nnmodule(*args, **kwargs) elif op == "get_attr": - return tracer.get_submodule(node.target) + return tracer.output_graph.get_submodule(node.target) elif op == "placeholder": assert "example_value" in node.meta return node.meta["example_value"] From 19b31d899a78a6806314bcc73b88172dabf0c26e Mon Sep 17 00:00:00 2001 From: BowenBao Date: Thu, 6 Jun 2024 18:52:15 -0700 Subject: [PATCH 487/706] Fix 'get_real_value' on placeholder nodes (#127698) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127698 Approved by: https://github.com/jansel ghstack dependencies: #127695, #127696 --- test/dynamo/test_decorators.py | 16 ++++++++++++++++ torch/_dynamo/utils.py | 3 +++ 2 files changed, 19 insertions(+) diff --git a/test/dynamo/test_decorators.py b/test/dynamo/test_decorators.py index 94b0d5bf3bb6..440872ecc7bc 100644 --- a/test/dynamo/test_decorators.py +++ b/test/dynamo/test_decorators.py @@ -487,6 +487,22 @@ def fn(B): fn(B), torch.compile(fn, backend="eager", fullgraph=True, dynamic=True)(B) ) + def test_assume_constant_result_on_computation_with_graph_input(self): + @torch._dynamo.assume_constant_result + def check(y): + return y[0].item() == 1 + + def fn(x, y): + if check(y): + return x + 2 + else: + return x + 1 + + y = torch.tensor([1]) + x = torch.tensor(1) + + self.assertEqual(fn(x, y), torch.compile(fn)(x, y)) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index d98c090543b1..1ae882832710 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -1943,6 +1943,9 @@ def get_real_value(node, tracer): lambda n: get_real_value(n, tracer), ) + if op == "placeholder" and "grapharg" in node.meta: + return node.meta["grapharg"].example + if op == "call_module": nn_module = tracer.output_graph.nn_modules[node.target] if not is_lazy_module(nn_module): From 662a78f957fb89e53ebeba7deb880561e10ecaf6 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Thu, 6 Jun 2024 16:01:59 -0700 Subject: [PATCH 488/706] [dynamo] Inline the getattr of fx graph and proxy graph (#128172) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128172 Approved by: https://github.com/yanboliang ghstack dependencies: #128001, #126578, #128158 --- test/dynamo/test_inline_inbuilt_nn_modules.py | 3 +++ torch/_dynamo/trace_rules.py | 2 ++ 2 files changed, 5 insertions(+) diff --git a/test/dynamo/test_inline_inbuilt_nn_modules.py b/test/dynamo/test_inline_inbuilt_nn_modules.py index f7ba32bc15f3..0bd7f573e6e9 100644 --- a/test/dynamo/test_inline_inbuilt_nn_modules.py +++ b/test/dynamo/test_inline_inbuilt_nn_modules.py @@ -6,6 +6,7 @@ try: from . import ( test_aot_autograd, + test_export, test_functions, test_higher_order_ops, test_misc, @@ -14,6 +15,7 @@ ) except ImportError: import test_aot_autograd + import test_export import test_functions import test_higher_order_ops import test_misc @@ -50,6 +52,7 @@ def make_inline_inbuilt_nn_modules_cls(cls): test_higher_order_ops.HigherOrderOpTests, test_higher_order_ops.FuncTorchHigherOrderOpTests, test_aot_autograd.AotAutogradFallbackTests, + test_export.ExportTests, # test_repros.ReproTests, ] for test in tests: diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 73c4beb547ee..5fc8398e3005 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -3231,6 +3231,8 @@ def _module_dir(m: types.ModuleType): "torch.cuda.amp.autocast_mode", "torch.distributions", "torch.fx._pytree", + "torch.fx._symbolic_trace", + "torch.fx.experimental.proxy_tensor", "torch.fx.passes.shape_prop", "torch.nn", "torch.random", From 0c7f4353e50e5bfa9ba42c14e6890b30ac91bbba Mon Sep 17 00:00:00 2001 From: Shunting Zhang Date: Fri, 7 Jun 2024 10:25:10 -0700 Subject: [PATCH 489/706] [inductor] simplify indexing (#127661) This is a short term fix for: https://github.com/pytorch/pytorch/issues/124002 We found the cause of bad perf for the int8_unpack kernel is due to sub-optimal indexing. In this PR we introduce 2 indexing optimizations: 1. expand FloorDiv to the entire expression when feasible. E.g. `x1 * 1024 + x2 // 2` will be transformed to `(x1 * 2048 + x2) // 2`. The motivation is that we have more chance to simplify loops for `x1 * 2048 + x2`. 2. merge ModularIndexing pairs: `ModularIndexing(ModularIndex(x, 1, a), 1, b)`, can be simplified to `ModularIndexing(x, 1, b)` if a is a multiple of b. With both indexing optimizations, we improve int8_unpack perf by 1.54x (183us -> 119us). Pull Request resolved: https://github.com/pytorch/pytorch/pull/127661 Approved by: https://github.com/jansel --- test/inductor/test_indexing.py | 78 ++++++++++++++++++- torch/_inductor/codegen/simd.py | 19 ++++- torch/_inductor/sizevars.py | 131 ++++++++++++++++++++++++++++++++ 3 files changed, 226 insertions(+), 2 deletions(-) diff --git a/test/inductor/test_indexing.py b/test/inductor/test_indexing.py index 299a619f9cd6..a3a7bf4b83ab 100644 --- a/test/inductor/test_indexing.py +++ b/test/inductor/test_indexing.py @@ -1,18 +1,28 @@ # Owner(s): ["module: inductor"] +import os +import unittest + import sympy +import torch + from torch._inductor.codegen.cpp import cexpr from torch._inductor.codegen.triton import texpr from torch._inductor.codegen.wrapper import pexpr +from torch._inductor.runtime.runtime_utils import do_bench_gpu from torch._inductor.sizevars import SizeVarAllocator from torch._inductor.test_case import TestCase as InductorTestCase +from torch._inductor.utils import run_and_get_triton_code from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, ) +from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA from torch.utils._sympy.functions import FloorDiv, ModularIndexing, Round, RoundDecimal +DO_PERF_TEST = os.environ.get("DO_PERF_TEST") == "1" + class TestIndexingSimplification(InductorTestCase): def test_indexing_simplification(self): @@ -159,6 +169,73 @@ def test_indexing_join(self): self.assertEqual(simplified, FloorDiv(i0, 3)) self.assertEqual(expr6.subs({i0: 39485}), simplified.subs({i0: 39485})) + def test_modular_indexing_pairs_merged(self): + sizevars = SizeVarAllocator() + x = sympy.Symbol("x", integer=True, positive=True) + a = 1024 + b = 32 + expr1 = ModularIndexing(x, 1, a) + expr2 = ModularIndexing(expr1, 1, b) + expected = ModularIndexing(x, 1, b) + + actual = sizevars.combine_modular_indexing_pairs(expr2) + self.assertEqual(expected, actual) + self.assertNotEqual(expr2, actual) + + def test_modular_indexing_pairs_not_merged(self): + sizevars = SizeVarAllocator() + x = sympy.Symbol("x", integer=True, positive=True) + a = 1024 + b = 3 # pick a 'b' that we can not merge + expr1 = ModularIndexing(x, 1, a) + expr2 = ModularIndexing(expr1, 1, b) + + actual = sizevars.combine_modular_indexing_pairs(expr2) + self.assertEqual(expr2, actual) + self.assertNotEqual(ModularIndexing(x, 1, b), actual) + + def test_expand_floor_div_skipped(self): + sizevars = SizeVarAllocator() + x = sympy.Symbol("x", integer=True, positive=True) + y = sympy.Symbol("y", integer=True, positive=True) + + expr = FloorDiv(x, 2) + FloorDiv(y, 3) + # The expression can not be simplified since there are multiple + # FloorDiv. We return False in that case + self.assertFalse(sizevars.expand_floor_div(expr)) + + def test_expand_floor_div_applied(self): + sizevars = SizeVarAllocator() + x = sympy.Symbol("x", integer=True, positive=True) + y = sympy.Symbol("y", integer=True, positive=True) + + expr = x * 5 + FloorDiv(y, 3) + actual, denominator = sizevars.expand_floor_div(expr) + self.assertNotEqual(expr, actual) + expected = FloorDiv(x * 15 + y, 3) + self.assertEqual(expected, FloorDiv(actual, denominator)) + + @unittest.skipUnless(HAS_CUDA, "Need GPU for this test") + def test_int8_unpack(self): + @torch.compile + def f(x): + first_elements = x >> 4 + second_elements = x & 15 + unpacked = torch.stack([first_elements, second_elements], dim=-1).view( + *x.size()[:-1], -1 + ) + return unpacked * 2 + + x = torch.randint(0, 255, (2, 4096, 5504), dtype=torch.uint8, device="cuda") + + triton_code = run_and_get_triton_code(f, x) + # Make sure the 2 load uses simpified indexing rather than something like + # tl.load(in_ptr0 + ((5504*x1) + (x0 // 2)), + self.assertEqual(2, triton_code.count("tl.load(in_ptr0 + ((x2 // 2)),")) + if DO_PERF_TEST: + ms = do_bench_gpu(lambda: f(x)) + print(f"{ms=:.03f}") + class ExprPrinterTests(InductorTestCase): def test_print_pow(self): @@ -315,7 +392,6 @@ def test_print_Min_Max(self): if __name__ == "__main__": from torch._inductor.test_case import run_tests - from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA if HAS_CPU or HAS_CUDA: run_tests("sympy") diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index ed7261f2a3eb..c5fc2747bee7 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -338,7 +338,8 @@ def simplify_indexing(index: sympy.Expr): index = V.graph.sizevars.simplify_with_ranges(index, self.var_ranges()) for tree in self.range_trees: index = self.combine_contiguous_dims(index, tree) - return index + + return self.combine_modular_indexing_pairs(index) self.simplify_indexing = simplify_indexing self.initialize_range_tree(pid_cache) @@ -422,7 +423,23 @@ def dense_size_str(self): sizes = self.dense_size_list() return f"[{', '.join(sizes)}]" + def combine_modular_indexing_pairs(self, index): + if not isinstance(index, ModularIndexing): + return index + x = index.args[0] + if (tree_node := self.range_tree_nodes.get(x)) is None: + return index + new_index = sympy_subs(index, {x: tree_node.expr}) + return V.graph.sizevars.combine_modular_indexing_pairs(new_index) + def combine_contiguous_dims(self, index: sympy.Expr, tree: IterationRangesRoot): + if expand_res := V.graph.sizevars.expand_floor_div(index): + new_index, denominator = expand_res # type: ignore[misc] + return FloorDiv(self._combine_contiguous_dims(new_index, tree), denominator) + else: + return self._combine_contiguous_dims(index, tree) + + def _combine_contiguous_dims(self, index: sympy.Expr, tree: IterationRangesRoot): """ More aggressive simplification to merge contiguous dims """ diff --git a/torch/_inductor/sizevars.py b/torch/_inductor/sizevars.py index bc8803a5e715..abb95503e8e2 100644 --- a/torch/_inductor/sizevars.py +++ b/torch/_inductor/sizevars.py @@ -583,6 +583,137 @@ def lookup_precomputed_size(self, expr: Expr) -> Expr: def free_symbols(self) -> Set[sympy.Symbol]: return set(self.var_to_val.keys()) - set(self.replacements.keys()) + def combine_modular_indexing_pairs(self, index: sympy.Expr) -> sympy.Expr: + """ + A pair of special ModularIndexing can be combined. + + E.g. ModularIndexing(ModularIndexing(x, 1, a), 1, b) + We can simplify this to ModuleIndexing(x, 1, b), if + 1. x is non negative integer + 2. a and b are positive integers + 3. a is a multiple of b. + """ + + def _check_args(x, div, mod, is_first): + if not isinstance(div, sympy.Integer) or not isinstance(mod, sympy.Integer): + return False + if div != 1: + return False + if mod <= 0: + return False + + if is_first: + # first ModularIndexing should conatins a nested ModularIndex + if not isinstance(x, ModularIndexing): + return False + else: + # second ModularIndexing should constains a non-negative + # symbol + if not isinstance(x, sympy.Symbol) or not self.statically_known_geq( + x, 0 + ): + return False + return True + + if isinstance(index, ModularIndexing): + x, div, mod = index.args + + if not _check_args(x, div, mod, True): + return index + + x2, div2, mod2 = x.args + + if not _check_args(x2, div2, mod2, False): + return index + + if mod2 % mod != 0: + return index + + return ModularIndexing(x2, 1, mod) + + return index + + def expand_floor_div( + self, index: sympy.Expr + ) -> Union[bool, Tuple[sympy.Expr, sympy.Expr]]: + """ + Expand the FloorDiv to the entire expression so that the expression may + be simplfied. + + E.g., for a 2D contiguous tensor with shape [a, 2 * b], and index variables + x1, x2, index expression 'x1 * 2b + x2' can be easily combined. + But index expression 'x1 * b + x2 // 2' can not. + By expanding the FloorDiv to the entire expression, we get + '(x1 * 2b + x2) // 2'. This transformation allows us to merge loops + for the numerator! + + Return false if this optimization can be applied; + Return the new expression and the denominator otherwise. + The original expression will be equivalent to 'new_expression // denominator' + """ + if not isinstance(index, sympy.Add): + return False + terms = index.args + + if len(terms) < 2: + return False + floor_div_index = -1 + varlist = [] + factorlist = [] + for idx, term in enumerate(terms): + if isinstance(term, sympy.Mul): + # For dynamic shape, term like '2*s1*x1' has 3 child nodes. + # - A integer for 2 + # - A symbol for s1 + # - A symbol for x1 + # Skip for now. + if len(term.args) != 2: + return False + factor, var = term.args + varlist.append(var) + factorlist.append(factor) + if not isinstance(factor, sympy.Integer) or not isinstance( + var, sympy.Symbol + ): + return False + # It's easier to reason about the correceness of the transformation + # for non-negative integers. + if not self.statically_known_geq(var, 0): + return False + elif isinstance(term, FloorDiv): + var, factor = term.args + if not isinstance(factor, sympy.Integer) or not isinstance( + var, sympy.Symbol + ): + return False + if not self.statically_known_geq(var, 0): + return False + if floor_div_index >= 0: + # can not handle multi FloorDiv yet + return False + + floor_div_index = idx + varlist.append(var) + # this factor is denominator + factorlist.append(factor) + else: + return False + + if floor_div_index < 0: + return False + + # Construct the new expression and remember the denominator + denominator = factorlist[floor_div_index] + new_index = sympy.Integer(0) + + for var, factor, idx in zip(varlist, factorlist, itertools.count()): + if idx == floor_div_index: + new_index += var + else: + new_index += (factor * denominator) * var + + return new_index, denominator + def join_dimensions(expr: Expr) -> Expr: if not isinstance(expr, sympy.Add) or not expr.has(ModularIndexing): From 82d7a36a27a26a4904f258fcdb79e37e91b510b6 Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Fri, 7 Jun 2024 17:52:13 +0000 Subject: [PATCH 490/706] Added torchao nightly workflow (#128152) Summary: Add torchao benchmark workflow, upload the artifacts to GHA. X-link: https://github.com/pytorch/benchmark/pull/2273 Test Plan: ``` python run_benchmark.py torchao --ci ``` Differential Revision: D58140479 Pulled By: xuzhao9 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128152 Approved by: https://github.com/jerryzh168 --- benchmarks/dynamo/common.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index 2b685b8926b3..39c3a3cda3e3 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -3975,9 +3975,12 @@ def run(runner, args, original_dir=None): assert "cuda" in args.devices, "Quantization requires CUDA device." assert args.bfloat16, "Quantization requires dtype bfloat16." try: - from .torchao_backend import setup_baseline, torchao_optimize_ctx - except ImportError: from torchao_backend import setup_baseline, torchao_optimize_ctx + except ImportError: + from userbenchmark.dynamo.dynamobench.torchao_backend import ( + setup_baseline, + torchao_optimize_ctx, + ) setup_baseline() baseline_ctx = functools.partial( From 0a6df4fca67423fb000f0568c104198b3865ab7c Mon Sep 17 00:00:00 2001 From: _daohang_ Date: Fri, 7 Jun 2024 18:05:46 +0000 Subject: [PATCH 491/706] delete inductor config.trace.compile_profile (#127143) Fixes #ISSUE_NUMBER https://fb.workplace.com/groups/257735836456307/posts/687858786777341/?comment_id=687861123443774&reply_comment_id=687865486776671 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127143 Approved by: https://github.com/Chillee --- torch/_inductor/debug.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/torch/_inductor/debug.py b/torch/_inductor/debug.py index ef1beb7c15a4..dcc2b3ab3e4c 100644 --- a/torch/_inductor/debug.py +++ b/torch/_inductor/debug.py @@ -1,6 +1,5 @@ import collections import contextlib -import cProfile import dataclasses import functools import itertools @@ -388,9 +387,6 @@ def reset_log_level(level): self._setup_log_capture("debug.log", logging.DEBUG) if config.trace.info_log: self._setup_log_capture("info.log", logging.INFO) - if config.trace.compile_profile: - self._prof = cProfile.Profile() - self._prof.enable() def _setup_log_capture(self, filename: str, level: int): log = logging.getLogger("torch._inductor") From 8ca4cefc7da745d7ec766d5f0336bbcc51f17a15 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Fri, 7 Jun 2024 07:57:58 -0700 Subject: [PATCH 492/706] [C10D] Ensure gil is not released when calling toPyBytes (#128212) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128212 Approved by: https://github.com/Skylion007, https://github.com/XilunWu --- torch/csrc/distributed/c10d/init.cpp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 50521c6ffa21..6f1b28886b98 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -1016,11 +1016,13 @@ Example:: const std::string& key, const std::string& expected_value, const std::string& desired_value) -> py::bytes { - auto value = store.compareSet( - key, toVec8(expected_value), toVec8(desired_value)); + auto value = [&]() { + py::gil_scoped_release guard; + return store.compareSet( + key, toVec8(expected_value), toVec8(desired_value)); + }(); return toPyBytes(value); }, - py::call_guard(), R"( Inserts the key-value pair into the store based on the supplied ``key`` and performs comparison between ``expected_value`` and ``desired_value`` before inserting. ``desired_value`` From cafbcb63762e13d463fc173be411be4daa0c769d Mon Sep 17 00:00:00 2001 From: Aaron Gokaslan Date: Fri, 7 Jun 2024 18:41:32 +0000 Subject: [PATCH 493/706] [BE]: Update ruff to 0.4.8 (#128214) Updates ruff to 0.4.8. Some minor fixes, but noticably is 10% faster on microbenchmark and should further reduce local and CI runtime of the linter. Also includes a few bugfixes. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128214 Approved by: https://github.com/ezyang --- .lintrunner.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index 874a553ee9bc..7d9cfa39916e 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -2079,7 +2079,7 @@ init_command = [ 'python3', 'tools/linter/adapters/pip_init.py', '--dry-run={{DRYRUN}}', - 'ruff==0.4.6', + 'ruff==0.4.8', ] is_formatter = true From dcb63fcedb062f2346642a047da944b08879b04d Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Fri, 7 Jun 2024 11:47:15 -0700 Subject: [PATCH 494/706] [pipelining] Remove num_microbatches from stage (#128201) This is similar to https://github.com/pytorch/pytorch/pull/127979, but instead of removing `num_microbatches` from schedule, we remove it from `PipelineStage`. This also means that during `PipelineSchedule` init we need to setup the buffers for the stage(s). Pull Request resolved: https://github.com/pytorch/pytorch/pull/128201 Approved by: https://github.com/kwen2501 --- .../pipelining/test_composability.py | 1 - test/distributed/pipelining/test_schedule.py | 8 ++- test/distributed/pipelining/test_stage.py | 1 - .../pipelining/PipelineSchedule.py | 13 ++++ torch/distributed/pipelining/PipelineStage.py | 63 ++++++++++--------- 5 files changed, 54 insertions(+), 32 deletions(-) diff --git a/test/distributed/pipelining/test_composability.py b/test/distributed/pipelining/test_composability.py index 3503001ba49e..0ef42fe90ddd 100644 --- a/test/distributed/pipelining/test_composability.py +++ b/test/distributed/pipelining/test_composability.py @@ -134,7 +134,6 @@ def build_stage(stage_idx, num_stages): self.device, group=pp_group, input_args=input_mb[0], - num_microbatches=num_microbatches, ) return stage, offset diff --git a/test/distributed/pipelining/test_schedule.py b/test/distributed/pipelining/test_schedule.py index 7bd5825ce9ed..ef9a63688230 100644 --- a/test/distributed/pipelining/test_schedule.py +++ b/test/distributed/pipelining/test_schedule.py @@ -53,6 +53,12 @@ def __init__(self, *args, **kwargs): def _create_grad_recv_info(self, *args, **kwargs): return None + def _prepare_forward_infra(self, n_microbatches): + pass + + def _prepare_backward_infra(self, n_microbatches): + pass + class ScheduleTest(MultiProcContinousTest): @classmethod @@ -274,7 +280,6 @@ def test_grad_with_manual(self, ScheduleClass): self.rank, self.world_size, self.device, - chunks, input_args=x.chunk(chunks)[0], ) @@ -358,7 +363,6 @@ def test_grad_with_manual_interleaved(self, ScheduleClass): stage_idx, n_stages, self.device, - chunks, input_args=input_args, ) for stage_module, stage_idx in zip(stage_modules, stage_indices) diff --git a/test/distributed/pipelining/test_stage.py b/test/distributed/pipelining/test_stage.py index ec459af7a596..b11a6037f604 100644 --- a/test/distributed/pipelining/test_stage.py +++ b/test/distributed/pipelining/test_stage.py @@ -220,7 +220,6 @@ def test_manual(self): self.rank, self.world_size, self.device, - chunks, input_args=x.chunk(chunks)[0], ) diff --git a/torch/distributed/pipelining/PipelineSchedule.py b/torch/distributed/pipelining/PipelineSchedule.py index 964d4e88bece..2de3c4eef85d 100644 --- a/torch/distributed/pipelining/PipelineSchedule.py +++ b/torch/distributed/pipelining/PipelineSchedule.py @@ -309,6 +309,12 @@ def __init__( # Set the same has_backward flag for stage object self._stage.has_backward = self._has_backward + # TODO: later replace this with lazy shape inference during forward + # Prepare forward send/recv infrastructure for stage + stage._prepare_forward_infra(n_microbatches) + if self._has_backward: + stage._prepare_backward_infra(n_microbatches) + def step(self, *args, target=None, losses: Optional[List] = None, **kwargs): """ Run one iteration of the pipeline schedule with *whole-batch* input. @@ -608,6 +614,13 @@ def __init__( # This will be set during init of derived schedules self.pipeline_order: Dict[int, List[Optional[_Action]]] = {} + # TODO: later replace this with lazy shape inference during forward + # Prepare forward send/recv infrastructure for stage + for stage in self._stages: + stage._prepare_forward_infra(n_microbatches) + if self._has_backward: + stage._prepare_backward_infra(n_microbatches) + def step(self, *args, target=None, losses: Optional[List] = None, **kwargs): """ Run one iteration of the pipeline schedule with *whole-batch* input. diff --git a/torch/distributed/pipelining/PipelineStage.py b/torch/distributed/pipelining/PipelineStage.py index 5761e03d689a..fbfca518df4d 100644 --- a/torch/distributed/pipelining/PipelineStage.py +++ b/torch/distributed/pipelining/PipelineStage.py @@ -89,7 +89,6 @@ def __init__( stage_index: int, num_stages: int, device: torch.device, - num_microbatches: int, group: Optional[dist.ProcessGroup] = None, ): """ @@ -113,7 +112,6 @@ def __init__( self.stage_index = stage_index self.num_stages = num_stages self.device = device - self.chunks = num_microbatches self.group = group # `group_rank` is rank in process group `group`. @@ -159,6 +157,9 @@ def __init__( # grad reduction in DDP or FSDP. self._seen_bwd_chunks = 0 + # To be populated later + self.chunks: Optional[int] = None + @property def has_backward(self) -> bool: """ @@ -185,6 +186,10 @@ def is_last(self): return self.stage_index == self.num_stages - 1 def _check_chunk_id(self, chunk_id: int): + if self.chunks is None: + raise RuntimeError( + "Attempted to access chunk_id before chunks have been configured." + ) if chunk_id >= self.chunks: raise RuntimeError( f"Chunk id {chunk_id} is out of range [0, {self.chunks})" @@ -236,6 +241,20 @@ def map_recv_to_send(a): ) return grad_send_info + @abstractmethod + def _prepare_forward_infra(self, num_microbatches: int): + raise NotImplementedError + + def _prepare_backward_infra(self, num_microbatches: int): + # TODO: this is needed for backward_maybe_with_nosync + self.chunks = num_microbatches + + for mb_index in range(num_microbatches): + # `grad_recv_info` is a mirror of `act_send_info` + self.grad_recv_info[mb_index] = self._create_grad_recv_info( + self.act_send_info + ) + @abstractmethod def _create_grad_recv_info( self, @@ -292,13 +311,7 @@ def get_bwd_recv_ops(self, bwd_chunk_id: int) -> List[dist.P2POp]: if not self.has_backward or self.is_last: return [] - # Create bwd recv infra lazily - recv_infos = self.grad_recv_info.setdefault( - bwd_chunk_id, - # `grad_recv_info` is a mirror of `act_send_info` - self._create_grad_recv_info(self.act_send_info), - ) - + recv_infos = self.grad_recv_info[bwd_chunk_id] return self._get_recv_ops(recv_infos) def get_fwd_send_ops(self, fwd_chunk_id: int) -> List[dist.P2POp]: @@ -449,7 +462,7 @@ def backward_maybe_with_nosync(self, bwd_kwargs: Dict): there are additional state-variables and performance considerations depending on the data parallelism used. This helper should adapt any pipeline parallel schedule to work with common/supported data parallel libraries. """ - last_backward = self._seen_bwd_chunks == self.chunks - 1 + last_backward = self._seen_bwd_chunks == self.chunks - 1 # type: ignore[operator] # If submod is wrapped by DDP if isinstance(self.submod, DistributedDataParallel): @@ -643,7 +656,6 @@ def __init__( stage_index, pipe_info.num_stages, device, - num_chunks, group, ) self.pipe_info = pipe_info @@ -670,9 +682,6 @@ def __init__( for i, node in enumerate(submod_nodes): self.submod_to_stage_index.setdefault(node.name, i) - # Prepare forward send/recv infrastructure - self._prepare_forward_infra() - # Cast submodule to device self._move_submod_to_device() # Move ops argument to device @@ -700,13 +709,13 @@ def _move_ops_to_device(self): if isinstance(self.submod, torch.fx.GraphModule): modify_graph_op_device(self.submod, self.device) - def _prepare_forward_infra(self): + def _prepare_forward_infra(self, num_microbatches: int): """ Create send/recv infrastructures for activations (during forward) """ # Flag per chunk to keep track of whether we have set `requires_grad` # for receive buffers. Format: {chunk : Boolean} - for chunk in range(self.chunks): + for chunk in range(num_microbatches): self.args_recv_info[chunk] = self._create_act_recv_info() self.set_requires_grad[chunk] = False @@ -1099,14 +1108,11 @@ def __init__( stage_index: int, num_stages: int, device: torch.device, - num_microbatches: int, input_args: Union[torch.Tensor, Tuple[torch.Tensor, ...]], output_args: Optional[Union[torch.Tensor, Tuple[torch.Tensor, ...]]] = None, group: Optional[dist.ProcessGroup] = None, ): - super().__init__( - submodule, stage_index, num_stages, device, num_microbatches, group - ) + super().__init__(submodule, stage_index, num_stages, device, group) self.submod.to(self.device) # When we materialize the model partition on cuda, we call reset_parameters() if it is available self.inputs: List[torch.Tensor] = [] @@ -1138,9 +1144,17 @@ def stage_global_rank(peer_rank): self.prev_stage = stage_global_rank((self.group_rank - 1) % self.group_size) self.next_stage = stage_global_rank((self.group_rank + 1) % self.group_size) + logger.debug( + f"finished pipeline stage init, {self.stage_index=}, {self.is_first=}, " # noqa: G004 + f"{self.is_last=}, {self.num_stages=}, " + f"inputs: {[inp.shape for inp in self.inputs]}, " + f"output: {[output.shape for output in self.outputs]}" + ) + + def _prepare_forward_infra(self, num_microbatches: int) -> None: # Receive info during forward # TODO: create args_recv_info lazily? (same needed for PipelineStage) - for chunk_id in range(self.chunks): + for chunk_id in range(num_microbatches): self.set_requires_grad[chunk_id] = False if not self.is_first: # We assume that we always receive from stage - 1 @@ -1171,13 +1185,6 @@ def stage_global_rank(peer_rank): else: self.act_send_info[idx] = [] - logger.debug( - f"finished pipeline stage init, {self.stage_index=}, {self.is_first=}, " # noqa: G004 - f"{self.is_last=}, {self.num_stages=}, " - f"inputs: {[inp.shape for inp in self.inputs]}, " - f"output: {[output.shape for output in self.outputs]}" - ) - def _create_grad_recv_info( self, act_send_info: Dict, From e647ea55a3b6be5900bbc5f71a75d256e7a2d43b Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Fri, 7 Jun 2024 12:32:28 -0700 Subject: [PATCH 495/706] [pipelining] redirect README to document (#128205) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128205 Approved by: https://github.com/wconstab, https://github.com/H-Huang --- torch/distributed/pipelining/README.md | 175 +------------------------ 1 file changed, 2 insertions(+), 173 deletions(-) diff --git a/torch/distributed/pipelining/README.md b/torch/distributed/pipelining/README.md index 556814c29b37..d4c9aaafa5b3 100644 --- a/torch/distributed/pipelining/README.md +++ b/torch/distributed/pipelining/README.md @@ -1,178 +1,7 @@ # Pipeline Parallelism for PyTorch -> [!NOTE] -> `torch.distributed.pipelining` is a package migrated from the [PiPPy](https://github.com/pytorch/PiPPy) project. It is currently in alpha state and under extensive development. If you need examples that work with our APIs, please refer to PiPPy's [examples](https://github.com/pytorch/PiPPy/tree/main/examples) directory. +`torch.distributed.pipelining` is a package for implementing pipeline parallelism on your model. -[**Why Pipeline Parallel?**](#why-pipeline-parallel) -| [**What is `torch.distributed.pipelining`?**](#what-is-torchdistributedpipelining) -| [**Examples**](#examples) -| [**Techniques Explained**](#techniques-explained) - -# Why Pipeline Parallel? - -One of the most important techniques for advancing the state of the art in deep learning is scaling. Common techniques for scaling neural networks include _data parallelism_, _tensor/operation parallelism_, and _pipeline parallelism_. In many cases, pipeline parallelism in particular can be an effective technique for scaling, however it is often difficult to implement, requiring intrusive code changes to model code and difficult-to-implement runtime orchestration code. `torch.distributed.pipelining` aims to provide a toolkit that does said things automatically to allow high-productivity scaling of models. - -# What is `torch.distributed.pipelining`? - -`torch.distributed.pipelining` consists of a compiler and runtime stack for automated pipelining of PyTorch models. Pipelining, or _pipeline parallelism_, is a technique in which the _code_ of the model is partitioned and multiple _micro-batches_ execute different parts of the model code concurrently. To learn more about pipeline parallelism, see [this article](https://www.deepspeed.ai/tutorials/pipeline/). +Our documentation is available [here](https://pytorch.org/docs/main/distributed.pipelining.html). ![pipeline_diagram_web](https://github.com/pytorch/PiPPy/assets/6676466/c93e2fe7-1cd4-49a2-9fd8-231ec9905e0c) - -Figure: Pipeline parallel. "F", "B" and "U" denote forward, backward and weight update, respectively. Different colors represent different micro-batches. - -`torch.distributed.pipelining` provides the following features that make pipeline parallelism easier: - -* Automatic splitting of model code based on your specification. The goal is for the user to provide model code as-is to the system for parallelization, without having to make heavyweight modifications to make parallelism work. The specification is also simple. -* Support for rich pipeline scheduling paradigms, including GPipe, 1F1B, Interleaved 1F1B and Looped BFS. More schedules will be added and it will be easy to customize your own schedule under `torch.distributed.pipelining`'s framework. -* First-class support for cross-host pipeline parallelism, as this is where PP is typically used (over slower interconnects). -* Composability with other PyTorch parallel schemes such as data parallelism (DDP, FSDP) or tensor parallelism (overall, known as "3d parallelism"). - -# Examples - -In the [PiPPy](https://github.com/pytorch/PiPPy) repo where this package is migrated from, we provide rich examples based on realistic models. In particular, we show how to apply pipelining without any model code change. You can refer to the [HuggingFace examples directory](https://github.com/pytorch/PiPPy/tree/main/examples/huggingface). Popular examples include: [GPT2](https://github.com/pytorch/PiPPy/tree/main/examples/huggingface/pippy_gpt2.py), and [LLaMA](https://github.com/pytorch/PiPPy/tree/main/examples/llama). - -# Techniques Explained - -`torch.distributed.pipelining` consists of two parts: a _compiler_ and a _runtime_. The compiler takes your model code, splits it up, and transforms it into a `Pipe`, which is a wrapper that describes the model at each pipeline stage and their data-flow relationship. The runtime executes the `PipelineStage`s in parallel, handling things like micro-batch splitting, scheduling, communication, and gradient propagation, etc. We will cover the APIs for these concepts in this section. - -## Splitting a Model with `pipeline` - -To see how we can split a model into a pipeline, let's first take an example trivial neural network: - -```python -import torch - -class MyNetworkBlock(torch.nn.Module): - def __init__(self, in_dim, out_dim): - super().__init__() - self.lin = torch.nn.Linear(in_dim, out_dim) - - def forward(self, x): - x = self.lin(x) - x = torch.relu(x) - return x - - -class MyNetwork(torch.nn.Module): - def __init__(self, in_dim, layer_dims): - super().__init__() - - prev_dim = in_dim - for i, dim in enumerate(layer_dims): - setattr(self, f'layer{i}', MyNetworkBlock(prev_dim, dim)) - prev_dim = dim - - self.num_layers = len(layer_dims) - # 10 output classes - self.output_proj = torch.nn.Linear(layer_dims[-1], 10) - - def forward(self, x): - for i in range(self.num_layers): - x = getattr(self, f'layer{i}')(x) - - return self.output_proj(x) - - -in_dim = 512 -layer_dims = [512, 1024, 256] -mn = MyNetwork(in_dim, layer_dims).to(device) -``` - -This network is written as free-form Python code; it has not been modified for any specific parallelism technique. - -Let us see our first usage of the `torch.distributed.pipelining` interfaces: - -```python -from torch.distributed.pipelining import annotate_split_points, pipeline, Pipe, SplitPoint - -annotate_split_points(mn, {'layer0': SplitPoint.END, - 'layer1': SplitPoint.END}) - -batch_size = 32 -example_input = torch.randn(batch_size, in_dim, device=device) -chunks = 4 - -pipe = pipeline(mn, chunks, example_args=(example_input,)) -print(pipe) - -""" -************************************* pipe ************************************* -GraphModule( - (submod_0): GraphModule( - (layer0): InterpreterModule( - (lin): InterpreterModule() - ) - ) - (submod_1): GraphModule( - (layer1): InterpreterModule( - (lin): InterpreterModule() - ) - ) - (submod_2): GraphModule( - (layer2): InterpreterModule( - (lin): InterpreterModule() - ) - (output_proj): InterpreterModule() - ) -) - -def forward(self, arg8_1): - submod_0 = self.submod_0(arg8_1); arg8_1 = None - submod_1 = self.submod_1(submod_0); submod_0 = None - submod_2 = self.submod_2(submod_1); submod_1 = None - return (submod_2,) -""" -``` - -So what's going on here? First, `pipeline` turns our model into a directed acyclic graph (DAG) by tracing the model. Then, it groups together the operations and parameters into _pipeline stages_. Stages are represented as `submod_N` submodules, where `N` is a natural number. - -We used `annotate_split_points` to specify that the code should be split and the end of `layer0` and `layer1`. Our code has thus been split into _three_ pipeline stages. Our library also provides `SplitPoint.BEGINNING` if a user wants to split before certain annotation point. - -While the `annotate_split_points` API gives users a way to specify the split points without modifying the model, our library also provides an API for in-model annotation: `pipe_split()`. For details, you can read [this example](https://github.com/pytorch/PiPPy/blob/main/test/test_pipe.py). - -This covers the basic usage of the `Pipe` API. For more information, please see the documentation. - - - -## Using PipelineStage for Pipelined Execution - -Given the above `Pipe` object, we can use one of the `PipelineStage` classes to execute our model in a pipelined fashion. First off, let us instantiate a `PipelineStage` instance: - -```python -# We are using `torchrun` to run this example with multiple processes. -# `torchrun` defines two environment variables: `RANK` and `WORLD_SIZE`. -rank = int(os.environ["RANK"]) -world_size = int(os.environ["WORLD_SIZE"]) - -# Initialize distributed environment -import torch.distributed as dist -dist.init_process_group(rank=rank, world_size=world_size) - -# Pipeline stage is our main pipeline runtime. It takes in the pipe object, -# the rank of this process, and the device. -from torch.distributed.pipelining import PipelineStage -stage = TracerPipelineStage(pipe, rank, device) -``` - -We can now run the pipeline by attaching the `PipelineStage` to a pipeline schedule, GPipe for example: - -```python -from torch.distributed.pipelining import ScheduleGPipe -schedule = ScheduleGPipe(stage, chunks) - -# Input data -x = torch.randn(batch_size, in_dim, device=device) - -# Run the pipeline with input `x`. Divide the batch into 4 micro-batches -# and run them in parallel on the pipeline -if rank == 0: - schedule.step(x) -else: - output = schedule.step() -``` - -Note that since we split our model into three stages, we must run this script with three workers. For this example, we will use `torchrun` to run multiple processes within a single machine for demonstration purposes. We can collect up all of the code blocks above into a file named [example.py](https://github.com/pytorch/PiPPy/tree/main/examples/basic) and then run it with `torchrun` like so: - -``` -torchrun --nproc_per_node=3 example.py -``` From fdf1666b20f63e4acf01798f009e478d997a7f7f Mon Sep 17 00:00:00 2001 From: angelayi Date: Fri, 7 Jun 2024 20:12:49 +0000 Subject: [PATCH 496/706] Change lerp decomp to use aten.as_strided_copy instead of prims.copy_strided (#128030) aten.lerp decomposition causes prims::copy_strided to appear in the graph, which is not core aten. Internal ref: https://fb.workplace.com/groups/pytorch.edge.users/permalink/1525644288305859/ Pull Request resolved: https://github.com/pytorch/pytorch/pull/128030 Approved by: https://github.com/Skylion007, https://github.com/zou3519 --- torch/_refs/__init__.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index 68675c751736..4e00a125434f 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -5014,7 +5014,11 @@ def lerp(start: Tensor, end: Tensor, weight: Union[Tensor, NumberType]): # make sure the decomposition output's stride is same as non-decomposition path. stride = utils.compute_elementwise_output_strides(*_maybe_broadcast(*inputs)) if output.stride() != stride: - output = prims.copy_strided(output, stride) + output = torch.ops.aten.as_strided_copy( + output, + output.size(), + stride, + ) return handle_noncontiguous_outputs(inputs, output) From 8892ddaaccf7f07d64e2e819d868c0e95bc53e74 Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Fri, 7 Jun 2024 20:19:18 +0000 Subject: [PATCH 497/706] [TD] Test removal on sm86 (#127131) Yolo I'm excited to break CI :') Pull Request resolved: https://github.com/pytorch/pytorch/pull/127131 Approved by: https://github.com/huydhn, https://github.com/ZainRizvi --- test/run_test.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/test/run_test.py b/test/run_test.py index 065e24f90801..57e69c0d979c 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -1180,22 +1180,15 @@ def parse_args(): or (IS_WINDOWS and not TEST_CUDA) or TEST_CONFIG == "nogpu_AVX512" or TEST_CONFIG == "nogpu_NO_AVX2" - or ( - "sm86" not in BUILD_ENVIRONMENT - and TEST_CONFIG == "default" - and TEST_CUDA - ) - or (not TEST_CUDA and TEST_CONFIG == "default") + or TEST_CONFIG == "default" ) and get_pr_number() is not None and not strtobool(os.environ.get("NO_TD", "False")) - and not IS_SLOW and not TEST_WITH_ROCM and not IS_MACOS and "xpu" not in BUILD_ENVIRONMENT and "onnx" not in BUILD_ENVIRONMENT - and "debug" not in BUILD_ENVIRONMENT - and "parallelnative" not in BUILD_ENVIRONMENT, + and os.environ.get("GITHUB_WORKFLOW", "slow") in ("trunk", "pull"), ) parser.add_argument( "--shard", From 3a620a0f653f26c638a72ce97abe557659b3a8c3 Mon Sep 17 00:00:00 2001 From: dshi7 Date: Fri, 7 Jun 2024 20:47:25 +0000 Subject: [PATCH 498/706] bug fix of dynamo_timed in cprofile (#128203) Fixes #ISSUE_NUMBER fb-only: "Entire Frame" was missing before this change. Before: https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/f565966006-TrainingApplication/20240527/rank_0/5_0_1/compilation_metrics_23.html After: https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/f569854578-TrainingApplication/20240606/rank_0/0_0_0/compilation_metrics_16.html Pull Request resolved: https://github.com/pytorch/pytorch/pull/128203 Approved by: https://github.com/Chillee --- torch/_dynamo/utils.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 1ae882832710..238e2b8227cd 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -214,9 +214,6 @@ def _add_time_spent(key, phase_name, time_spent): def dynamo_timed(original_function=None, phase_name=None, fwd_only=True): def dynamo_timed_inner(func): - if config.cprofile: - return func - @wraps(func) def time_wrapper(*args, **kwargs): key = func.__qualname__ From ba81c3c2909e587a67b838eda3438bb7d262b533 Mon Sep 17 00:00:00 2001 From: Xu Han Date: Fri, 7 Jun 2024 20:49:56 +0000 Subject: [PATCH 499/706] [inductor] add cpp builder code. (take 2) (#125849) Fully manual rebase the code of PR: https://github.com/pytorch/pytorch/pull/124045 The old PR seems crashed due to too many commits, and too many times rebase. Please reference: https://github.com/pytorch/pytorch/pull/124045#issuecomment-2103744588 ------- It is the first step of RFC https://github.com/pytorch/pytorch/issues/124245. Changes: 1. Add cpp builder code, the new cpp_builder support Windows OS. 2. Add CPU ISA checker which is cross OS and exported from backend cpuinfo. 3. Switch compiler ISA checker to new cpp builder. 4. CppCodeCache use the new ISA checker. 5. Add temprary `test_new_cpp_build_logical` UT to help on transfer to new code. Image Pull Request resolved: https://github.com/pytorch/pytorch/pull/125849 Approved by: https://github.com/jgong5, https://github.com/desertfire --- aten/src/ATen/cpu/Utils.cpp | 15 + aten/src/ATen/cpu/Utils.h | 3 + test/inductor/test_torchinductor.py | 5 + torch/_C/_cpu.pyi | 2 + torch/_dynamo/trace_rules.py | 4 + torch/_inductor/codecache.py | 215 ++++- torch/_inductor/cpp_builder.py | 1178 +++++++++++++++++++++++++++ torch/cpu/__init__.py | 10 + torch/csrc/cpu/Module.cpp | 8 +- 9 files changed, 1408 insertions(+), 32 deletions(-) create mode 100644 torch/_inductor/cpp_builder.py diff --git a/aten/src/ATen/cpu/Utils.cpp b/aten/src/ATen/cpu/Utils.cpp index ddb9b34eceb9..21b6f33877ed 100644 --- a/aten/src/ATen/cpu/Utils.cpp +++ b/aten/src/ATen/cpu/Utils.cpp @@ -4,6 +4,21 @@ #endif namespace at::cpu { +bool is_cpu_support_avx2() { +#if !defined(__s390x__) && !defined(__powerpc__) + return cpuinfo_initialize() && cpuinfo_has_x86_avx2(); +#else + return false; +#endif +} + +bool is_cpu_support_avx512() { +#if !defined(__s390x__) && !defined(__powerpc__) + return cpuinfo_initialize() && cpuinfo_has_x86_avx512f() && cpuinfo_has_x86_avx512vl() && cpuinfo_has_x86_avx512bw() && cpuinfo_has_x86_avx512dq(); +#else + return false; +#endif +} bool is_cpu_support_vnni() { #if !defined(__s390x__) && !defined(__powerpc__) diff --git a/aten/src/ATen/cpu/Utils.h b/aten/src/ATen/cpu/Utils.h index ece13c70bce3..805c7c64a21b 100644 --- a/aten/src/ATen/cpu/Utils.h +++ b/aten/src/ATen/cpu/Utils.h @@ -4,6 +4,9 @@ namespace at::cpu { +TORCH_API bool is_cpu_support_avx2(); +TORCH_API bool is_cpu_support_avx512(); + // Detect if CPU support Vector Neural Network Instruction. TORCH_API bool is_cpu_support_vnni(); diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 42c430866290..4aa97b058271 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -6643,6 +6643,11 @@ def fn(x): self.common(fn, [torch.randn(64, 64)]) + def test_new_cpp_build_logical(self): + from torch._inductor.codecache import validate_new_cpp_commands + + validate_new_cpp_commands() + def test_as_strided(self): def fn(x): return ( diff --git a/torch/_C/_cpu.pyi b/torch/_C/_cpu.pyi index 075fecf45d5a..641ba00312e0 100644 --- a/torch/_C/_cpu.pyi +++ b/torch/_C/_cpu.pyi @@ -2,4 +2,6 @@ from torch.types import _bool # Defined in torch/csrc/cpu/Module.cpp +def _is_cpu_support_avx2() -> _bool: ... +def _is_cpu_support_avx512() -> _bool: ... def _is_cpu_support_vnni() -> _bool: ... diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 5fc8398e3005..94487cb8551c 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -405,6 +405,8 @@ "torch._C._construct_CUDA_Tensor_From_Storage_And_Metadata", "torch._C._construct_storage_from_data_pointer", "torch._C._conv_determine_backend_memory_format", + "torch._C._cpu._is_cpu_support_avx2", + "torch._C._cpu._is_cpu_support_avx512", "torch._C._cpu._is_cpu_support_vnni", "torch._C._crash_if_aten_asan", "torch._C._crash_if_csrc_asan", @@ -2416,6 +2418,8 @@ "torch.chain_matmul", "torch.compile", "torch.compiled_with_cxx11_abi", + "torch.cpu._is_cpu_support_avx2", + "torch.cpu._is_cpu_support_avx512", "torch.cpu._is_cpu_support_vnni", "torch.cpu.current_device", "torch.cpu.current_stream", diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 6251513f0119..6ef07ed90692 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -78,6 +78,8 @@ _TORCH_PATH = os.path.dirname(os.path.dirname(_HERE)) _LINKER_SCRIPT = os.path.join(_TORCH_PATH, "_inductor/script.ld") +_IS_WINDOWS = sys.platform == "win32" + if config.is_fbcode(): from triton.fb import build_paths from triton.fb.build import _run_build_command @@ -1231,7 +1233,7 @@ def _get_isa_dry_compile_fingerprint(isa_flags: str) -> str: class VecISA: _bit_width: int - _macro: str + _macro: List[str] _arch_flags: str _dtype_nelements: Dict[torch.dtype, int] @@ -1277,7 +1279,7 @@ def bit_width(self) -> int: def nelements(self, dtype: torch.dtype = torch.float) -> int: return self._dtype_nelements[dtype] - def build_macro(self) -> str: + def build_macro(self) -> List[str]: return self._macro def build_arch_flags(self) -> str: @@ -1288,6 +1290,8 @@ def __hash__(self) -> int: @functools.lru_cache(None) # noqa: B019 def __bool__(self) -> bool: + from torch._inductor.cpp_builder import CppBuilder, CppTorchOptions + if config.cpp.vec_isa_ok is not None: return config.cpp.vec_isa_ok @@ -1304,16 +1308,21 @@ def __bool__(self) -> bool: lock_dir = get_lock_dir() lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT) with lock: - output_path = input_path[:-3] + "so" - build_cmd = shlex.split( - cpp_compile_command( - input_path, output_path, warning_all=False, vec_isa=self - ) + output_dir = os.path.dirname(input_path) + buid_options = CppTorchOptions(vec_isa=self, warning_all=False) + x86_isa_help_builder = CppBuilder( + key, + [input_path], + buid_options, + output_dir, ) try: # Check if the output file exist, and compile when not. + output_path = x86_isa_help_builder.get_target_file_path() if not os.path.isfile(output_path): - compile_file(input_path, output_path, build_cmd) + status, target_file = x86_isa_help_builder.build() + if status: + return False # Check build result subprocess.check_call( @@ -1334,9 +1343,9 @@ def __bool__(self) -> bool: @dataclasses.dataclass class VecNEON(VecISA): _bit_width = 256 # This is required to leverage the compute implemented in aten/src/ATen/cpu/vec/vec256/vec256_float_neon.h - _macro = "-DCPU_CAPABILITY_NEON" + _macro = ["CPU_CAPABILITY_NEON"] if sys.platform == "darwin" and platform.processor() == "arm": - _macro += " -DAT_BUILD_ARM_VEC256_WITH_SLEEF" + _macro.append("AT_BUILD_ARM_VEC256_WITH_SLEEF") _arch_flags = "" # Unused _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16} @@ -1349,8 +1358,12 @@ def __str__(self) -> str: @dataclasses.dataclass class VecAVX512(VecISA): _bit_width = 512 - _macro = "-DCPU_CAPABILITY_AVX512" - _arch_flags = "-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma" + _macro = ["CPU_CAPABILITY_AVX512"] + _arch_flags = ( + "-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma" + if not _IS_WINDOWS + else "/arch:AVX512" + ) # TODO: use cflags _dtype_nelements = {torch.float: 16, torch.bfloat16: 32, torch.float16: 32} def __str__(self) -> str: @@ -1362,8 +1375,10 @@ def __str__(self) -> str: @dataclasses.dataclass class VecAVX2(VecISA): _bit_width = 256 - _macro = "-DCPU_CAPABILITY_AVX2" - _arch_flags = "-mavx2 -mfma" + _macro = ["CPU_CAPABILITY_AVX2"] + _arch_flags = ( + "-mavx2 -mfma" if not _IS_WINDOWS else "/arch:AVX2" + ) # TODO: use cflags _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16} def __str__(self) -> str: @@ -1375,7 +1390,11 @@ def __str__(self) -> str: @dataclasses.dataclass class VecZVECTOR(VecISA): _bit_width = 256 - _macro = "-DCPU_CAPABILITY_ZVECTOR -DCPU_CAPABILITY=ZVECTOR -DHAVE_ZVECTOR_CPU_DEFINITION" + _macro = [ + "CPU_CAPABILITY_ZVECTOR", + "CPU_CAPABILITY=ZVECTOR", + "HAVE_ZVECTOR_CPU_DEFINITION", + ] _arch_flags = "-mvx -mzvector" _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16} @@ -1387,7 +1406,7 @@ def __str__(self) -> str: class InvalidVecISA(VecISA): _bit_width = 0 - _macro = "" + _macro = [""] _arch_flags = "" _dtype_nelements = {} @@ -1400,6 +1419,31 @@ def __bool__(self) -> bool: # type: ignore[override] __hash__: Callable[[VecISA], Any] = VecISA.__hash__ +def x86_isa_checker() -> List[str]: + supported_isa: List[str] = [] + + def _check_and_append_supported_isa( + dest: List[str], isa_supported: bool, isa_name: str + ): + if isa_supported: + dest.append(isa_name) + + Arch = platform.machine() + """ + Arch value is x86_64 on Linux, and the value is AMD64 on Windows. + """ + if Arch != "x86_64" and Arch != "AMD64": + return supported_isa + + avx2 = torch.cpu._is_cpu_support_avx2() + avx512 = torch.cpu._is_cpu_support_avx512() + + _check_and_append_supported_isa(supported_isa, avx2, "avx2") + _check_and_append_supported_isa(supported_isa, avx512, "avx512") + + return supported_isa + + invalid_vec_isa = InvalidVecISA() supported_vec_isa_list = [VecAVX512(), VecAVX2(), VecNEON()] @@ -1412,7 +1456,8 @@ def valid_vec_isa_list() -> List[VecISA]: if sys.platform == "darwin" and platform.processor() == "arm": return [VecNEON()] - if sys.platform != "linux": + cur_os = sys.platform + if cur_os != "linux" and cur_os != "win32": return [] if platform.machine() == "s390x": @@ -1430,12 +1475,11 @@ def valid_vec_isa_list() -> List[VecISA]: return [] isa_list = [] - with open("/proc/cpuinfo") as _cpu_info: - _cpu_info_content = _cpu_info.read() - for isa in supported_vec_isa_list: - if str(isa) in _cpu_info_content and isa: - isa_list.append(isa) - return isa_list + _cpu_supported_isa = x86_isa_checker() + for isa in supported_vec_isa_list: + if str(isa) in _cpu_supported_isa: + isa_list.append(isa) + return isa_list def pick_vec_isa() -> VecISA: @@ -1490,7 +1534,7 @@ def cpp_flags() -> str: def cpp_wrapper_flags() -> str: - return "-DTORCH_INDUCTOR_CPP_WRAPPER" + return "-D TORCH_INDUCTOR_CPP_WRAPPER" def optimization_flags() -> str: @@ -1632,7 +1676,14 @@ def get_include_and_linking_paths( _set_gpu_runtime_env() from torch.utils import cpp_extension - macros = vec_isa.build_macro() if vec_isa != invalid_vec_isa else "" + # Remove below in the further + # macros = "-D {}".format(vec_isa.build_macro()) if vec_isa != invalid_vec_isa else "" + macros = "" + if vec_isa != invalid_vec_isa: + for x in vec_isa.build_macro(): + macros_def = f"-D {x} " + macros += macros_def + build_arch_flags = "" if sys.platform == "linux" and ( include_pytorch @@ -1849,7 +1900,7 @@ def cpp_compile_command( {get_glibcxx_abi_build_flags()} {ipaths_str} {lpaths} {libs} {build_arch_flags} {macros} {linker_paths} {clang_flags} - {optimization_flags()} + {optimization_flags()} {cpp_wrapper_flags()} {use_custom_generated_macros()} {use_fb_internal_macros()} {use_standard_sys_dir_headers()} @@ -2354,8 +2405,21 @@ def load_async(cls, source_code: str, cuda=False, submit_fn=None, extra_flags=() "vec_isa": pick_vec_isa(), "extra_flags": extra_flags, } - cpp_command = repr(cpp_compile_command("i", "o", **compile_command)) - key, input_path = write(source_code, "cpp", extra=cpp_command) + + _set_gpu_runtime_env() # cpp_extension consults the env + + from torch._inductor.cpp_builder import CppBuilder, CppTorchCudaOptions + + dummy_builder = CppBuilder( + name="o", sources="i", BuildOption=CppTorchCudaOptions(**compile_command) + ) + # write function will calc source_code hash, the same source code with different + # ISA level should be generate different hash. + # So we need get a command_line which contains isa related parameter as a part of hash key. + # And then pass the command_line to below write function as extra parameter to + # guarantee the source code hash contains ISA difference. + dummy_cmd = repr(dummy_builder.get_command_line()) + key, input_path = write(source_code, "cpp", extra=dummy_cmd) if key not in cls.cache: from filelock import FileLock @@ -2628,6 +2692,101 @@ class CppWrapperCodeCache(CppPythonBindingsCodeCache): ) +# TODO: Will remove the temp code after switch to new cpp_builder +def _temp_validate_new_and_old_command(new_cmd: List[str], old_cmd: List[str]): + new_diff: List[str] = [x for x in new_cmd if x not in old_cmd] + old_diff: List[str] = [y for y in old_cmd if y not in new_cmd] + + if new_diff or old_diff: + print("!!! new_cmd: ", new_cmd) + print("!!! old_cmd: ", old_cmd) + print("!!! new_diff: ", new_diff) + print("!!! old_diff: ", old_diff) + raise RuntimeError("Error in new and old command different.") + + +def _do_validate_cpp_commands( + include_pytorch: bool, + cuda: bool, + compile_only: bool, + mmap_weights: bool, + use_absolute_path: bool, +): + # PreCI will failed if test machine can't run cuda. + temp_dir = tempfile.TemporaryDirectory() + test_dir_path = temp_dir.name + test_cuda = torch.cuda.is_available() and cuda + input_path = os.path.join(test_dir_path, "dummy_input.cpp") + output_path = os.path.join(test_dir_path, "dummy_output.so") + extra_flags = ["-D TEST_EXTRA_FLAGS"] + if compile_only: + output_path = os.path.join(test_dir_path, "dummy_output.o") + picked_isa = pick_vec_isa() + + old_cmd = cpp_compile_command( + input=input_path, + output=output_path, + include_pytorch=include_pytorch, + vec_isa=picked_isa, + cuda=test_cuda, + aot_mode=False, + compile_only=compile_only, + use_absolute_path=use_absolute_path, + use_mmap_weights=mmap_weights, + extra_flags=extra_flags, + ).split(" ") + + from torch._inductor.cpp_builder import CppBuilder, CppTorchCudaOptions + + dummy_build_option = CppTorchCudaOptions( + vec_isa=picked_isa, + include_pytorch=include_pytorch, + cuda=test_cuda, + compile_only=compile_only, + use_absolute_path=use_absolute_path, + use_mmap_weights=mmap_weights, + extra_flags=extra_flags, + ) + + dummy_builder = CppBuilder( + name="dummy_output", + sources=input_path, + BuildOption=dummy_build_option, + output_dir=test_dir_path, + ) + new_cmd = dummy_builder.get_command_line().split(" ") + + _temp_validate_new_and_old_command(new_cmd, old_cmd) + + temp_dir.cleanup() + + +# TODO: Will remove the temp code after switch to new cpp_builder +# It could help on sync new cpp_builder generate same command line as the old one. +def validate_new_cpp_commands(): + cuda = [True, False] + use_mmap_weights = [True, False] + compile_only = [True, False] + include_pytorch = [True, False] + use_absolute_path = [True, False] + + for x in cuda: + for y in use_mmap_weights: + for z in compile_only: + for m in include_pytorch: + for n in use_absolute_path: + print( + f"!!! cuda:{x}, use_mmap_weights:{y}, compile_only:{z}, include_pytorch:{m}, use_absolute_path:{n}" + ) + _do_validate_cpp_commands( + include_pytorch=m, + cuda=x, + mmap_weights=y, + compile_only=z, + use_absolute_path=n, + ) + + @clear_on_fresh_inductor_cache class HalideCodeCache(CppPythonBindingsCodeCache): cache: Dict[str, Callable[[], Union[ModuleType, CDLL]]] = {} diff --git a/torch/_inductor/cpp_builder.py b/torch/_inductor/cpp_builder.py new file mode 100644 index 000000000000..8c9d2c18cabd --- /dev/null +++ b/torch/_inductor/cpp_builder.py @@ -0,0 +1,1178 @@ +# This CPP JIT builder is designed to support both Windows and Linux OS. +# The design document please check this RFC: https://github.com/pytorch/pytorch/issues/124245 + +import copy +import errno +import functools +import logging +import os +import platform +import re +import shlex +import shutil +import subprocess +import sys +import sysconfig +import warnings +from pathlib import Path +from typing import List, Sequence, Tuple, Union + +import torch +from torch._inductor import config, exc +from torch._inductor.codecache import ( + _get_python_include_dirs, + _LINKER_SCRIPT, + _transform_cuda_paths, + get_lock_dir, + invalid_vec_isa, + LOCK_TIMEOUT, + VecISA, +) +from torch._inductor.runtime.runtime_utils import cache_dir + +if config.is_fbcode(): + from triton.fb import build_paths # noqa: F401 + + from torch._inductor.fb.utils import ( + log_global_cache_errors, + log_global_cache_stats, + log_global_cache_vals, + use_global_cache, + ) +else: + + def log_global_cache_errors(*args, **kwargs): + pass + + def log_global_cache_stats(*args, **kwargs): + pass + + def log_global_cache_vals(*args, **kwargs): + pass + + def use_global_cache() -> bool: + return False + + +# Windows need setup a temp dir to store .obj files. +_BUILD_TEMP_DIR = "CxxBuild" + +# initialize variables for compilation +_IS_LINUX = sys.platform.startswith("linux") +_IS_MACOS = sys.platform.startswith("darwin") +_IS_WINDOWS = sys.platform == "win32" + + +log = logging.getLogger(__name__) + + +@functools.lru_cache(1) +def cpp_compiler_search(search: str) -> str: + for cxx in search: + try: + if cxx is None: + # gxx package is only available for Linux + # according to https://anaconda.org/conda-forge/gxx/ + if sys.platform != "linux": + continue + # Do not install GXX by default + if not os.getenv("TORCH_INDUCTOR_INSTALL_GXX"): + continue + from filelock import FileLock + + lock_dir = get_lock_dir() + lock = FileLock( + os.path.join(lock_dir, "g++.lock"), timeout=LOCK_TIMEOUT + ) + with lock: + cxx = install_gcc_via_conda() + subprocess.check_output([cxx, "--version"]) + return cxx + except (subprocess.SubprocessError, FileNotFoundError, ImportError): + continue + raise exc.InvalidCxxCompiler() # noqa: RSE102 + + +def install_gcc_via_conda() -> str: + """On older systems, this is a quick way to get a modern compiler""" + prefix = os.path.join(cache_dir(), "gcc") + cxx_path = os.path.join(prefix, "bin", "g++") + if not os.path.exists(cxx_path): + log.info("Downloading GCC via conda") + conda = os.environ.get("CONDA_EXE", "conda") + if conda is None: + conda = shutil.which("conda") + if conda is not None: + subprocess.check_call( + [ + conda, + "create", + f"--prefix={prefix}", + "--channel=conda-forge", + "--quiet", + "-y", + "python=3.8", + "gxx", + ], + stdout=subprocess.PIPE, + ) + return cxx_path + + +def _get_cpp_compiler() -> str: + if _IS_WINDOWS: + compiler = os.environ.get("CXX", "cl") + else: + if config.is_fbcode(): + return build_paths.cc() + if isinstance(config.cpp.cxx, (list, tuple)): + search = tuple(config.cpp.cxx) + else: + search = (config.cpp.cxx,) + compiler = cpp_compiler_search(search) + return compiler + + +def _is_gcc(cpp_compiler) -> bool: + return bool(re.search(r"(gcc|g\+\+)", cpp_compiler)) + + +def is_gcc() -> bool: + return _is_gcc(_get_cpp_compiler()) + + +def _is_clang(cpp_compiler) -> bool: + # Mac OS apple clang maybe named as gcc, need check compiler info. + if sys.platform == "darwin": + return is_apple_clang(cpp_compiler) + return bool(re.search(r"(clang|clang\+\+)", cpp_compiler)) + + +def is_clang() -> bool: + compiler = _get_cpp_compiler() + return _is_clang(compiler) + + +@functools.lru_cache(None) +def is_apple_clang(cpp_compiler) -> bool: + version_string = subprocess.check_output([cpp_compiler, "--version"]).decode("utf8") + return "Apple" in version_string.splitlines()[0] + + +def _append_list(dest_list: List[str], src_list: List[str]): + for item in src_list: + dest_list.append(copy.deepcopy(item)) + + +def _remove_duplication_in_list(orig_list: List[str]) -> List[str]: + new_list: List[str] = [] + for item in orig_list: + if item not in new_list: + new_list.append(item) + return new_list + + +def _create_if_dir_not_exist(path_dir): + if not os.path.exists(path_dir): + try: + Path(path_dir).mkdir(parents=True, exist_ok=True) + except OSError as exc: # Guard against race condition + if exc.errno != errno.EEXIST: + raise RuntimeError( # noqa: TRY200 (Use `raise from`) + f"Fail to create path {path_dir}" + ) + + +def _remove_dir(path_dir): + if os.path.exists(path_dir): + for root, dirs, files in os.walk(path_dir, topdown=False): + for name in files: + file_path = os.path.join(root, name) + os.remove(file_path) + for name in dirs: + dir_path = os.path.join(root, name) + os.rmdir(dir_path) + os.rmdir(path_dir) + + +def run_command_line(cmd_line, cwd=None): + cmd = shlex.split(cmd_line) + try: + status = subprocess.check_output(args=cmd, cwd=cwd, stderr=subprocess.STDOUT) + except subprocess.CalledProcessError as e: + output = e.output.decode("utf-8") + openmp_problem = "'omp.h' file not found" in output or "libomp" in output + if openmp_problem and sys.platform == "darwin": + instruction = ( + "\n\nOpenMP support not found. Please try one of the following solutions:\n" + "(1) Set the `CXX` environment variable to a compiler other than Apple clang++/g++ " + "that has builtin OpenMP support;\n" + "(2) install OpenMP via conda: `conda install llvm-openmp`;\n" + "(3) install libomp via brew: `brew install libomp`;\n" + "(4) manually setup OpenMP and set the `OMP_PREFIX` environment variable to point to a path" + " with `include/omp.h` under it." + ) + output += instruction + raise exc.CppCompileError(cmd, output) from e + return status + + +class BuildOptionsBase: + """ + This is the Base class for store cxx build options, as a template. + Acturally, to build a cxx shared library. We just need to select a compiler + and maintains the suitable args. + """ + + def __init__(self) -> None: + self._compiler = "" + self._definations: List[str] = [] + self._include_dirs: List[str] = [] + self._cflags: List[str] = [] + self._ldflags: List[str] = [] + self._libraries_dirs: List[str] = [] + self._libraries: List[str] = [] + # Some args is hard to abstract to OS compatable, passthough it directly. + self._passthough_args: List[str] = [] + + self._aot_mode: bool = False + self._use_absolute_path: bool = False + self._compile_only: bool = False + + def _remove_duplicate_options(self): + self._definations = _remove_duplication_in_list(self._definations) + self._include_dirs = _remove_duplication_in_list(self._include_dirs) + self._cflags = _remove_duplication_in_list(self._cflags) + self._ldflags = _remove_duplication_in_list(self._ldflags) + self._libraries_dirs = _remove_duplication_in_list(self._libraries_dirs) + self._libraries = _remove_duplication_in_list(self._libraries) + self._passthough_args = _remove_duplication_in_list(self._passthough_args) + + def get_compiler(self) -> str: + return self._compiler + + def get_definations(self) -> List[str]: + return self._definations + + def get_include_dirs(self) -> List[str]: + return self._include_dirs + + def get_cflags(self) -> List[str]: + return self._cflags + + def get_ldflags(self) -> List[str]: + return self._ldflags + + def get_libraries_dirs(self) -> List[str]: + return self._libraries_dirs + + def get_libraries(self) -> List[str]: + return self._libraries + + def get_passthough_args(self) -> List[str]: + return self._passthough_args + + def get_aot_mode(self) -> bool: + return self._aot_mode + + def get_use_absolute_path(self) -> bool: + return self._use_absolute_path + + def get_compile_only(self) -> bool: + return self._compile_only + + +def _get_warning_all_cflag(warning_all: bool = True) -> List[str]: + if not _IS_WINDOWS: + return ["Wall"] if warning_all else [] + else: + return [] + + +def _get_cpp_std_cflag(std_num: str = "c++17") -> List[str]: + if _IS_WINDOWS: + return [f"std:{std_num}"] + else: + return [f"std={std_num}"] + + +def _get_linux_cpp_cflags(cpp_compiler) -> List[str]: + if not _IS_WINDOWS: + cflags = ["Wno-unused-variable", "Wno-unknown-pragmas"] + if _is_clang(cpp_compiler): + cflags.append("Werror=ignored-optimization-argument") + return cflags + else: + return [] + + +def _get_optimization_cflags() -> List[str]: + if _IS_WINDOWS: + return ["O2"] + else: + cflags = ["O0", "g"] if config.aot_inductor.debug_compile else ["O3", "DNDEBUG"] + cflags.append("ffast-math") + cflags.append("fno-finite-math-only") + + if not config.cpp.enable_unsafe_math_opt_flag: + cflags.append("fno-unsafe-math-optimizations") + if not config.cpp.enable_floating_point_contract_flag: + cflags.append("ffp-contract=off") + + if config.is_fbcode(): + # FIXME: passing `-fopenmp` adds libgomp.so to the generated shared library's dependencies. + # This causes `ldopen` to fail in fbcode, because libgomp does not exist in the default paths. + # We will fix it later by exposing the lib path. + return cflags + + if sys.platform == "darwin": + # Per https://mac.r-project.org/openmp/ right way to pass `openmp` flags to MacOS is via `-Xclang` + # Also, `-march=native` is unrecognized option on M1 + cflags.append("Xclang") + else: + if platform.machine() == "ppc64le": + cflags.append("mcpu=native") + else: + cflags.append("march=native") + + # Internal cannot find libgomp.so + if not config.is_fbcode(): + cflags.append("fopenmp") + + return cflags + + +def _get_shared_cflag(compile_only: bool) -> List[str]: + if _IS_WINDOWS: + SHARED_FLAG = ["DLL"] + else: + if compile_only: + return ["fPIC"] + if platform.system() == "Darwin" and "clang" in _get_cpp_compiler(): + # This causes undefined symbols to behave the same as linux + return ["shared", "fPIC", "undefined dynamic_lookup"] + else: + return ["shared", "fPIC"] + + return SHARED_FLAG + + +def get_cpp_options( + cpp_compiler, + compile_only: bool, + warning_all: bool = True, + extra_flags: Sequence[str] = (), +): + definations: List[str] = [] + include_dirs: List[str] = [] + cflags: List[str] = [] + ldflags: List[str] = [] + libraries_dirs: List[str] = [] + libraries: List[str] = [] + passthough_args: List[str] = [] + + cflags = ( + _get_shared_cflag(compile_only) + + _get_optimization_cflags() + + _get_warning_all_cflag(warning_all) + + _get_cpp_std_cflag() + + _get_linux_cpp_cflags(cpp_compiler) + ) + + passthough_args.append(" ".join(extra_flags)) + + return ( + definations, + include_dirs, + cflags, + ldflags, + libraries_dirs, + libraries, + passthough_args, + ) + + +class CppOptions(BuildOptionsBase): + """ + This class is inherited from BuildOptionsBase, and as cxx build options. + This option need contains basic cxx build option, which contains: + 1. OS related args. + 2. Toolchains related args. + 3. Cxx standard related args. + Note: + 1. This Options is good for assist modules build, such as x86_isa_help. + """ + + def __init__( + self, + compile_only: bool, + warning_all: bool = True, + extra_flags: Sequence[str] = (), + use_absolute_path: bool = False, + ) -> None: + super().__init__() + self._compiler = _get_cpp_compiler() + self._use_absolute_path = use_absolute_path + self._compile_only = compile_only + + ( + definations, + include_dirs, + cflags, + ldflags, + libraries_dirs, + libraries, + passthough_args, + ) = get_cpp_options( + cpp_compiler=self._compiler, + compile_only=compile_only, + extra_flags=extra_flags, + warning_all=warning_all, + ) + + _append_list(self._definations, definations) + _append_list(self._include_dirs, include_dirs) + _append_list(self._cflags, cflags) + _append_list(self._ldflags, ldflags) + _append_list(self._libraries_dirs, libraries_dirs) + _append_list(self._libraries, libraries) + _append_list(self._passthough_args, passthough_args) + self._remove_duplicate_options() + + +def _get_glibcxx_abi_build_flags() -> List[str]: + if not _IS_WINDOWS: + return ["-D_GLIBCXX_USE_CXX11_ABI=" + str(int(torch._C._GLIBCXX_USE_CXX11_ABI))] + else: + return [] + + +def _get_torch_cpp_wrapper_defination() -> List[str]: + return ["TORCH_INDUCTOR_CPP_WRAPPER"] + + +def _use_custom_generated_macros() -> List[str]: + return [" C10_USING_CUSTOM_GENERATED_MACROS"] + + +def _use_fb_internal_macros() -> List[str]: + if not _IS_WINDOWS: + if config.is_fbcode(): + fb_internal_macros = [ + "C10_USE_GLOG", + "C10_USE_MINIMAL_GLOG", + "C10_DISABLE_TENSORIMPL_EXTENSIBILITY", + ] + # TODO: this is to avoid FC breakage for fbcode. When using newly + # generated model.so on an older verion of PyTorch, need to use + # the v1 version for aoti_torch_create_tensor_from_blob + create_tensor_from_blob_v1 = "AOTI_USE_CREATE_TENSOR_FROM_BLOB_V1" + + fb_internal_macros.append(create_tensor_from_blob_v1) + + # TODO: remove comments later: + # Moved to _get_openmp_args + # openmp_lib = build_paths.openmp_lib() + # return [f"-Wp,-fopenmp {openmp_lib} {preprocessor_flags}"] + return fb_internal_macros + else: + return [] + else: + return [] + + +def _setup_standard_sys_libs( + cpp_compiler, + aot_mode: bool, + use_absolute_path: bool, +): + cflags: List[str] = [] + include_dirs: List[str] = [] + passthough_args: List[str] = [] + if _IS_WINDOWS: + return cflags, include_dirs, passthough_args + + if config.is_fbcode(): + cflags.append("nostdinc") + include_dirs.append(build_paths.sleef()) + include_dirs.append(build_paths.cc_include()) + include_dirs.append(build_paths.libgcc()) + include_dirs.append(build_paths.libgcc_arch()) + include_dirs.append(build_paths.libgcc_backward()) + include_dirs.append(build_paths.glibc()) + include_dirs.append(build_paths.linux_kernel()) + include_dirs.append("include") + + if aot_mode and not use_absolute_path: + linker_script = _LINKER_SCRIPT + else: + linker_script = os.path.basename(_LINKER_SCRIPT) + + if _is_clang(cpp_compiler): + passthough_args.append(" --rtlib=compiler-rt") + passthough_args.append(" -fuse-ld=lld") + passthough_args.append(f" -Wl,--script={linker_script}") + passthough_args.append(" -B" + build_paths.glibc_lib()) + passthough_args.append(" -L" + build_paths.glibc_lib()) + + return cflags, include_dirs, passthough_args + + +@functools.lru_cache +def _cpp_prefix_path() -> str: + from torch._inductor.codecache import write # TODO + + path = Path(Path(__file__).parent).parent / "codegen/cpp_prefix.h" + with path.open() as f: + content = f.read() + _, filename = write( + content, + "h", + ) + return filename + + +def _get_build_args_of_chosen_isa(vec_isa: VecISA): + macros = [] + build_flags = [] + if vec_isa != invalid_vec_isa: + # Add Windows support later. + for x in vec_isa.build_macro(): + macros.append(copy.deepcopy(x)) + + build_flags = [vec_isa.build_arch_flags()] + + if config.is_fbcode() and vec_isa != invalid_vec_isa: + cap = str(vec_isa).upper() + macros = [ + f"CPU_CAPABILITY={cap}", + f"CPU_CAPABILITY_{cap}", + f"HAVE_{cap}_CPU_DEFINITION", + ] + + return macros, build_flags + + +def _get_torch_related_args(include_pytorch: bool, aot_mode: bool): + from torch.utils.cpp_extension import _TORCH_PATH, TORCH_LIB_PATH + + include_dirs = [ + os.path.join(_TORCH_PATH, "include"), + os.path.join(_TORCH_PATH, "include", "torch", "csrc", "api", "include"), + # Some internal (old) Torch headers don't properly prefix their includes, + # so we need to pass -Itorch/lib/include/TH as well. + os.path.join(_TORCH_PATH, "include", "TH"), + os.path.join(_TORCH_PATH, "include", "THC"), + ] + libraries_dirs = [TORCH_LIB_PATH] + libraries = [] + if sys.platform == "linux" and not config.is_fbcode(): + libraries = ["torch", "torch_cpu"] + if not aot_mode: + libraries.append("torch_python") + + # Unconditionally import c10 for non-abi-compatible mode to use TORCH_CHECK - See PyTorch #108690 + if not config.abi_compatible: + libraries.append("c10") + libraries_dirs.append(TORCH_LIB_PATH) + + return include_dirs, libraries_dirs, libraries + + +def _get_python_related_args(): + python_include_dirs = _get_python_include_dirs() + python_include_path = sysconfig.get_path( + "include", scheme="nt" if _IS_WINDOWS else "posix_prefix" + ) + if python_include_path is not None: + python_include_dirs.append(python_include_path) + + if _IS_WINDOWS: + python_path = os.path.dirname(sys.executable) + python_lib_path = [os.path.join(python_path, "libs")] + else: + python_lib_path = [sysconfig.get_config_var("LIBDIR")] + + if config.is_fbcode(): + python_include_dirs.append(build_paths.python()) + + return python_include_dirs, python_lib_path + + +def _get_openmp_args(cpp_compiler): + cflags: List[str] = [] + ldflags: List[str] = [] + include_dir_paths: List[str] = [] + lib_dir_paths: List[str] = [] + libs: List[str] = [] + passthough_args: List[str] = [] + if _IS_MACOS: + from torch._inductor.codecache import ( + homebrew_libomp, + is_conda_llvm_openmp_installed, + ) + + # only Apple builtin compilers (Apple Clang++) require openmp + omp_available = not is_apple_clang(cpp_compiler) + + # check the `OMP_PREFIX` environment first + omp_prefix = os.getenv("OMP_PREFIX") + if omp_prefix is not None: + header_path = os.path.join(omp_prefix, "include", "omp.h") + valid_env = os.path.exists(header_path) + if valid_env: + include_dir_paths.append(os.path.join(omp_prefix, "include")) + lib_dir_paths.append(os.path.join(omp_prefix, "lib")) + else: + warnings.warn("environment variable `OMP_PREFIX` is invalid.") + omp_available = omp_available or valid_env + + if not omp_available: + libs.append("omp") + + # prefer to use openmp from `conda install llvm-openmp` + conda_prefix = os.getenv("CONDA_PREFIX") + if not omp_available and conda_prefix is not None: + omp_available = is_conda_llvm_openmp_installed() + if omp_available: + conda_lib_path = os.path.join(conda_prefix, "lib") + include_dir_paths.append(os.path.join(conda_prefix, "include")) + lib_dir_paths.append(conda_lib_path) + # Prefer Intel OpenMP on x86 machine + if os.uname().machine == "x86_64" and os.path.exists( + os.path.join(conda_lib_path, "libiomp5.dylib") + ): + libs.append("iomp5") + + # next, try to use openmp from `brew install libomp` + if not omp_available: + omp_available, libomp_path = homebrew_libomp() + if omp_available: + include_dir_paths.append(os.path.join(libomp_path, "include")) + lib_dir_paths.append(os.path.join(libomp_path, "lib")) + + # if openmp is still not available, we let the compiler to have a try, + # and raise error together with instructions at compilation error later + elif _IS_WINDOWS: + # /openmp, /openmp:llvm + # llvm on Windows, new openmp: https://devblogs.microsoft.com/cppblog/msvc-openmp-update/ + # msvc openmp: https://learn.microsoft.com/zh-cn/cpp/build/reference/openmp-enable-openmp-2-0-support?view=msvc-170 + + cflags.append("openmp") + libs = [] + else: + if config.is_fbcode(): + include_dir_paths.append(build_paths.openmp()) + + openmp_lib = build_paths.openmp_lib() + fb_openmp_extra_flags = f"-Wp,-fopenmp {openmp_lib}" + passthough_args.append(fb_openmp_extra_flags) + + libs.append("omp") + else: + if _is_clang(cpp_compiler): + # TODO: fix issue, can't find omp.h + cflags.append("fopenmp") + libs.append("gomp") + else: + cflags.append("fopenmp") + libs.append("gomp") + + return cflags, ldflags, include_dir_paths, lib_dir_paths, libs, passthough_args + + +def get_mmap_self_macro(use_mmap_weights: bool) -> List[str]: + macros = [] + if use_mmap_weights: + macros.append(" USE_MMAP_SELF") + return macros + + +def get_cpp_torch_options( + cpp_compiler, + vec_isa: VecISA, + include_pytorch: bool, + aot_mode: bool, + compile_only: bool, + use_absolute_path: bool, + use_mmap_weights: bool, +): + definations: List[str] = [] + include_dirs: List[str] = [] + cflags: List[str] = [] + ldflags: List[str] = [] + libraries_dirs: List[str] = [] + libraries: List[str] = [] + passthough_args: List[str] = [] + + torch_cpp_wrapper_definations = _get_torch_cpp_wrapper_defination() + use_custom_generated_macros_definations = _use_custom_generated_macros() + + ( + sys_libs_cflags, + sys_libs_include_dirs, + sys_libs_passthough_args, + ) = _setup_standard_sys_libs(cpp_compiler, aot_mode, use_absolute_path) + + isa_macros, isa_ps_args_build_flags = _get_build_args_of_chosen_isa(vec_isa) + + ( + torch_include_dirs, + torch_libraries_dirs, + torch_libraries, + ) = _get_torch_related_args(include_pytorch=include_pytorch, aot_mode=aot_mode) + + python_include_dirs, python_libraries_dirs = _get_python_related_args() + + ( + omp_cflags, + omp_ldflags, + omp_include_dir_paths, + omp_lib_dir_paths, + omp_lib, + omp_passthough_args, + ) = _get_openmp_args(cpp_compiler) + + cxx_abi_passthough_args = _get_glibcxx_abi_build_flags() + fb_macro_passthough_args = _use_fb_internal_macros() + + mmap_self_macros = get_mmap_self_macro(use_mmap_weights) + + definations = ( + torch_cpp_wrapper_definations + + use_custom_generated_macros_definations + + isa_macros + + fb_macro_passthough_args + + mmap_self_macros + ) + include_dirs = ( + sys_libs_include_dirs + + python_include_dirs + + torch_include_dirs + + omp_include_dir_paths + ) + cflags = sys_libs_cflags + omp_cflags + ldflags = omp_ldflags + libraries_dirs = python_libraries_dirs + torch_libraries_dirs + omp_lib_dir_paths + libraries = torch_libraries + omp_lib + passthough_args = ( + sys_libs_passthough_args + + isa_ps_args_build_flags + + cxx_abi_passthough_args + + omp_passthough_args + ) + + return ( + definations, + include_dirs, + cflags, + ldflags, + libraries_dirs, + libraries, + passthough_args, + ) + + +class CppTorchOptions(CppOptions): + """ + This class is inherited from CppTorchOptions, which automatic contains + base cxx build options. And then it will maintains torch related build + args. + 1. Torch include_directories, libraries, libraries_directories. + 2. Python include_directories, libraries, libraries_directories. + 3. OpenMP related. + 4. Torch MACROs. + 5. MISC + """ + + def __init__( + self, + vec_isa: VecISA, + include_pytorch: bool = False, + warning_all: bool = True, + aot_mode: bool = False, + compile_only: bool = False, + use_absolute_path: bool = False, + use_mmap_weights: bool = False, + shared: bool = True, + extra_flags: Sequence[str] = (), + ) -> None: + super().__init__( + compile_only=compile_only, + warning_all=warning_all, + extra_flags=extra_flags, + use_absolute_path=use_absolute_path, + ) + + self._aot_mode = aot_mode + + ( + torch_definations, + torch_include_dirs, + torch_cflags, + torch_ldflags, + torch_libraries_dirs, + torch_libraries, + torch_passthough_args, + ) = get_cpp_torch_options( + cpp_compiler=self._compiler, + vec_isa=vec_isa, + include_pytorch=include_pytorch, + aot_mode=aot_mode, + compile_only=compile_only, + use_absolute_path=use_absolute_path, + use_mmap_weights=use_mmap_weights, + ) + + if compile_only: + torch_libraries_dirs = [] + torch_libraries = [] + + _append_list(self._definations, torch_definations) + _append_list(self._include_dirs, torch_include_dirs) + _append_list(self._cflags, torch_cflags) + _append_list(self._ldflags, torch_ldflags) + _append_list(self._libraries_dirs, torch_libraries_dirs) + _append_list(self._libraries, torch_libraries) + _append_list(self._passthough_args, torch_passthough_args) + self._remove_duplicate_options() + + +def get_cpp_torch_cuda_options(cuda: bool, aot_mode: bool = False): + definations: List[str] = [] + include_dirs: List[str] = [] + cflags: List[str] = [] + ldflags: List[str] = [] + libraries_dirs: List[str] = [] + libraries: List[str] = [] + passthough_args: List[str] = [] + + if ( + config.is_fbcode() + and "CUDA_HOME" not in os.environ + and "CUDA_PATH" not in os.environ + ): + os.environ["CUDA_HOME"] = build_paths.cuda() + + from torch.utils import cpp_extension + + include_dirs = cpp_extension.include_paths(cuda) + libraries_dirs = cpp_extension.library_paths(cuda) + + if cuda: + definations.append(" USE_ROCM" if torch.version.hip else " USE_CUDA") + + if torch.version.hip is not None: + if config.is_fbcode(): + libraries += ["amdhip64"] + else: + libraries += ["c10_hip", "torch_hip"] + definations.append(" __HIP_PLATFORM_AMD__") + else: + if config.is_fbcode(): + libraries += ["cuda"] + else: + if config.is_fbcode(): + libraries += ["cuda"] + else: + libraries += ["c10_cuda", "cuda", "torch_cuda"] + + if aot_mode: + cpp_prefix_include_dir = [f"{os.path.dirname(_cpp_prefix_path())}"] + include_dirs += cpp_prefix_include_dir + + if cuda and torch.version.hip is None: + _transform_cuda_paths(libraries_dirs) + + if config.is_fbcode(): + if torch.version.hip is not None: + include_dirs.append(os.path.join(build_paths.rocm(), "include")) + else: + include_dirs.append(os.path.join(build_paths.cuda(), "include")) + + if aot_mode and cuda and config.is_fbcode(): + if torch.version.hip is None: + # TODO: make static link better on Linux. + passthough_args = ["-Wl,-Bstatic -lcudart_static -Wl,-Bdynamic"] + + return ( + definations, + include_dirs, + cflags, + ldflags, + libraries_dirs, + libraries, + passthough_args, + ) + + +class CppTorchCudaOptions(CppTorchOptions): + """ + This class is inherited from CppTorchOptions, which automatic contains + base cxx build options and torch common build options. And then it will + maintains cuda device related build args. + """ + + def __init__( + self, + vec_isa: VecISA, + include_pytorch: bool = False, + cuda: bool = True, + aot_mode: bool = False, + compile_only: bool = False, + use_absolute_path: bool = False, + use_mmap_weights: bool = False, + shared: bool = True, + extra_flags: Sequence[str] = (), + ) -> None: + super().__init__( + vec_isa=vec_isa, + include_pytorch=include_pytorch, + aot_mode=aot_mode, + compile_only=compile_only, + use_absolute_path=use_absolute_path, + use_mmap_weights=use_mmap_weights, + extra_flags=extra_flags, + ) + + cuda_definations: List[str] = [] + cuda_include_dirs: List[str] = [] + cuda_cflags: List[str] = [] + cuda_ldflags: List[str] = [] + cuda_libraries_dirs: List[str] = [] + cuda_libraries: List[str] = [] + cuda_passthough_args: List[str] = [] + + ( + cuda_definations, + cuda_include_dirs, + cuda_cflags, + cuda_ldflags, + cuda_libraries_dirs, + cuda_libraries, + cuda_passthough_args, + ) = get_cpp_torch_cuda_options(cuda=cuda, aot_mode=aot_mode) + + if compile_only: + cuda_libraries_dirs = [] + cuda_libraries = [] + + _append_list(self._definations, cuda_definations) + _append_list(self._include_dirs, cuda_include_dirs) + _append_list(self._cflags, cuda_cflags) + _append_list(self._ldflags, cuda_ldflags) + _append_list(self._libraries_dirs, cuda_libraries_dirs) + _append_list(self._libraries, cuda_libraries) + _append_list(self._passthough_args, cuda_passthough_args) + self._remove_duplicate_options() + + +def get_name_and_dir_from_output_file_path( + aot_mode: bool, use_absolute_path: bool, file_path: str +): + name_and_ext = os.path.basename(file_path) + name, ext = os.path.splitext(name_and_ext) + dir = os.path.dirname(file_path) + + if config.is_fbcode(): + if not (aot_mode and not use_absolute_path): + dir = "." + return name, dir + + +class CppBuilder: + """ + CppBuilder is a cpp jit builder, and it supports both Windows, Linux and MacOS. + Args: + name: + 1. Build target name, the final target file will append extension type automatically. + 2. Due to the CppBuilder is supports mutliple OS, it will maintains ext for OS difference. + sources: + Source code file list to be built. + BuildOption: + Build options to the builder. + output_dir: + 1. The output_dir the taget file will output to. + 2. The default value is empty string, and then the use current dir as output dir. + 3. Final target file: output_dir/name.ext + """ + + def get_shared_lib_ext(self) -> str: + SHARED_LIB_EXT = ".dll" if _IS_WINDOWS else ".so" + return SHARED_LIB_EXT + + def get_object_ext(self) -> str: + EXT = ".obj" if _IS_WINDOWS else ".o" + return EXT + + def __init__( + self, + name: str, + sources: Union[str, List[str]], + BuildOption: BuildOptionsBase, + output_dir: str = "", + ) -> None: + self._compiler = "" + self._cflags_args = "" + self._definations_args = "" + self._include_dirs_args = "" + self._ldflags_args = "" + self._libraries_dirs_args = "" + self._libraries_args = "" + self._passthough_parameters_args = "" + + self._output_dir = "" + self._target_file = "" + + self._use_absolute_path: bool = False + + self._name = name + + # Code start here, initial self internal veriables firstly. + self._compiler = BuildOption.get_compiler() + self._use_absolute_path = BuildOption.get_use_absolute_path() + + if len(output_dir) == 0: + self._output_dir = os.path.dirname(os.path.abspath(__file__)) + else: + self._output_dir = output_dir + + self._compile_only = BuildOption.get_compile_only() + file_ext = ( + self.get_object_ext() if self._compile_only else self.get_shared_lib_ext() + ) + self._target_file = os.path.join(self._output_dir, f"{self._name}{file_ext}") + + if isinstance(sources, str): + sources = [sources] + + if config.is_fbcode(): + if BuildOption.get_aot_mode() and not self._use_absolute_path: + inp_name = sources + # output process @ get_name_and_dir_from_output_file_path + else: + # We need to copy any absolute-path torch includes + inp_name = [os.path.basename(i) for i in sources] + self._target_file = os.path.basename(self._target_file) + + self._sources_args = " ".join(inp_name) + else: + self._sources_args = " ".join(sources) + + for cflag in BuildOption.get_cflags(): + if _IS_WINDOWS: + self._cflags_args += f"/{cflag} " + else: + self._cflags_args += f"-{cflag} " + + for defination in BuildOption.get_definations(): + if _IS_WINDOWS: + self._definations_args += f"/D {defination} " + else: + self._definations_args += f"-D {defination} " + + for inc_dir in BuildOption.get_include_dirs(): + if _IS_WINDOWS: + self._include_dirs_args += f"/I {inc_dir} " + else: + self._include_dirs_args += f"-I{inc_dir} " + + for ldflag in BuildOption.get_ldflags(): + if _IS_WINDOWS: + self._ldflags_args += f"/{ldflag} " + else: + self._ldflags_args += f"-{ldflag} " + + for lib_dir in BuildOption.get_libraries_dirs(): + if _IS_WINDOWS: + self._libraries_dirs_args += f'/LIBPATH:"{lib_dir}" ' + else: + self._libraries_dirs_args += f"-L{lib_dir} " + + for lib in BuildOption.get_libraries(): + if _IS_WINDOWS: + self._libraries_args += f'"{lib}.lib" ' + else: + self._libraries_args += f"-l{lib} " + + for passthough_arg in BuildOption.get_passthough_args(): + self._passthough_parameters_args += f"{passthough_arg} " + + def get_command_line(self) -> str: + def format_build_command( + compiler, + sources, + include_dirs_args, + definations_args, + cflags_args, + ldflags_args, + libraries_args, + libraries_dirs_args, + passthougn_args, + target_file, + ): + if _IS_WINDOWS: + # https://learn.microsoft.com/en-us/cpp/build/walkthrough-compile-a-c-program-on-the-command-line?view=msvc-1704 + # https://stackoverflow.com/a/31566153 + cmd = ( + f"{compiler} {include_dirs_args} {definations_args} {cflags_args} {sources} " + f"{passthougn_args} /LD /Fe{target_file} /link {libraries_dirs_args} {libraries_args} {ldflags_args} " + ) + cmd = cmd.replace("\\", "/") + else: + compile_only_arg = "-c" if self._compile_only else "" + cmd = re.sub( + r"[ \n]+", + " ", + f""" + {compiler} {sources} {definations_args} {cflags_args} {include_dirs_args} + {passthougn_args} {ldflags_args} {libraries_args} {libraries_dirs_args} {compile_only_arg} -o {target_file} + """, + ).strip() + return cmd + + command_line = format_build_command( + compiler=self._compiler, + sources=self._sources_args, + include_dirs_args=self._include_dirs_args, + definations_args=self._definations_args, + cflags_args=self._cflags_args, + ldflags_args=self._ldflags_args, + libraries_args=self._libraries_args, + libraries_dirs_args=self._libraries_dirs_args, + passthougn_args=self._passthough_parameters_args, + target_file=self._target_file, + ) + return command_line + + def get_target_file_path(self): + return self._target_file + + def convert_to_cpp_extension_args(self): + include_dirs = self._include_dirs_args + cflags = ( + self._cflags_args + + self._definations_args + + self._passthough_parameters_args + ) + ldflags = self._ldflags_args + self._libraries_args + self._libraries_dirs_args + + return include_dirs, cflags, ldflags + + def build(self) -> Tuple[int, str]: + """ + It is must need a temperary directory to store object files in Windows. + After build completed, delete the temperary directory to save disk space. + """ + _create_if_dir_not_exist(self._output_dir) + _build_tmp_dir = os.path.join( + self._output_dir, f"{self._name}_{_BUILD_TEMP_DIR}" + ) + _create_if_dir_not_exist(_build_tmp_dir) + + build_cmd = self.get_command_line() + + status = run_command_line(build_cmd, cwd=_build_tmp_dir) + + _remove_dir(_build_tmp_dir) + return status, self._target_file diff --git a/torch/cpu/__init__.py b/torch/cpu/__init__.py index 2f2561b69c1c..a36594a3cb15 100644 --- a/torch/cpu/__init__.py +++ b/torch/cpu/__init__.py @@ -28,6 +28,16 @@ _device_t = Union[_device, str, int, None] +def _is_cpu_support_avx2() -> bool: + r"""Returns a bool indicating if CPU supports AVX2.""" + return torch._C._cpu._is_cpu_support_avx2() + + +def _is_cpu_support_avx512() -> bool: + r"""Returns a bool indicating if CPU supports AVX512.""" + return torch._C._cpu._is_cpu_support_avx512() + + def _is_cpu_support_vnni() -> bool: r"""Returns a bool indicating if CPU supports VNNI.""" return torch._C._cpu._is_cpu_support_vnni() diff --git a/torch/csrc/cpu/Module.cpp b/torch/csrc/cpu/Module.cpp index f577c0c0dae1..b6c931eae0fe 100644 --- a/torch/csrc/cpu/Module.cpp +++ b/torch/csrc/cpu/Module.cpp @@ -2,15 +2,15 @@ #include #include -namespace torch { -namespace cpu { +namespace torch::cpu { void initModule(PyObject* module) { auto m = py::handle(module).cast(); auto cpu = m.def_submodule("_cpu", "cpu related pybind."); + cpu.def("_is_cpu_support_avx2", at::cpu::is_cpu_support_avx2); + cpu.def("_is_cpu_support_avx512", at::cpu::is_cpu_support_avx512); cpu.def("_is_cpu_support_vnni", at::cpu::is_cpu_support_vnni); } -} // namespace cpu -} // namespace torch +} // namespace torch::cpu From 5b3624117ac85a8dbd5486db27d23f2f5652289b Mon Sep 17 00:00:00 2001 From: laithsakka Date: Fri, 7 Jun 2024 08:47:43 -0700 Subject: [PATCH 500/706] update test_issue175 to handle inline_inbuilt_nn_modules (#128026) with inlining the output graph have more function calls reflecting those on the test that count number of function calls. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128026 Approved by: https://github.com/anijain2305 ghstack dependencies: #127553 --- test/dynamo/test_repros.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 771b9e96c88c..8515b6a7f735 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -1609,7 +1609,10 @@ def test_issue175(self): opt_model(inp) opt_model(inp) self.assertEqual(cnt.frame_count, 1) - self.assertEqual(cnt.op_count, 12) + + self.assertEqual( + 18 if torch._dynamo.config.inline_inbuilt_nn_modules else 12, cnt.op_count + ) def test_exec_import(self): def fn1(): From 11f2d8e823efa8508d1f2198429b95b1e9007222 Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Fri, 7 Jun 2024 23:01:52 +0000 Subject: [PATCH 501/706] Move inductor cuda 124 jobs to a separate workflow that is not triggered by ciflow/inductor (#128250) https://github.com/pytorch/pytorch/pull/127825 The majority of the g5 runner usage comes from inductor (its something like 2x everything else) in the past week, inductor ran 1300 ish times on PRs and 300 times on main. Inductor-periodic ran 50 times on main, so the previous move from inductor -> inductor-periodic only results in 250 fewer runs. I was under the impression that cu124 is experimental currently and eventually we'll need to switch to it, so this will stay until we switch or inductor uses much fewer runners Are we expected to be able to handle two versions of cuda in CI? Because currently we cannot, at least not comfortably Pull Request resolved: https://github.com/pytorch/pytorch/pull/128250 Approved by: https://github.com/huydhn --- .github/pytorch-probot.yml | 1 + .github/workflows/inductor-cu124.yml | 108 ++++++++++++++++++++++++ .github/workflows/inductor-periodic.yml | 90 -------------------- 3 files changed, 109 insertions(+), 90 deletions(-) create mode 100644 .github/workflows/inductor-cu124.yml diff --git a/.github/pytorch-probot.yml b/.github/pytorch-probot.yml index d54346f81650..0d624788fc61 100644 --- a/.github/pytorch-probot.yml +++ b/.github/pytorch-probot.yml @@ -8,6 +8,7 @@ ciflow_push_tags: - ciflow/inductor - ciflow/inductor-perf-compare - ciflow/inductor-micro-benchmark +- ciflow/inductor-cu124 - ciflow/linux-aarch64 - ciflow/mps - ciflow/nightly diff --git a/.github/workflows/inductor-cu124.yml b/.github/workflows/inductor-cu124.yml new file mode 100644 index 000000000000..d7ab5665bed6 --- /dev/null +++ b/.github/workflows/inductor-cu124.yml @@ -0,0 +1,108 @@ +name: inductor-cu124 + +on: + push: + tags: + - ciflow/inductor-cu124/* + workflow_dispatch: + schedule: + # Run every 4 hours during the week and every 12 hours on the weekend + - cron: 45 0,4,8,12,16,20 * * 1-5 + - cron: 45 4,12 * * 0,6 + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }} + cancel-in-progress: true + +permissions: read-all + +jobs: + linux-focal-cuda12_4-py3_10-gcc9-inductor-build: + # Should be synced with the one in inductor.yml, but this doesn't run inductor_timm + name: cuda12.4-py3.10-gcc9-sm86 + uses: ./.github/workflows/_linux-build.yml + with: + sync-tag: linux-focal-cuda12_4-py3_10-gcc9-inductor-build + build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm86 + docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9-inductor-benchmarks + cuda-arch-list: '8.6' + test-matrix: | + { include: [ + { config: "inductor", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_distributed", shard: 1, num_shards: 1, runner: "linux.g5.12xlarge.nvidia.gpu" }, + { config: "inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_torchbench", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_torchbench", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamic_inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamic_inductor_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamic_inductor_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamic_inductor_torchbench", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamic_inductor_torchbench", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_inductor_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_inductor_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_inductor_torchbench", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_inductor_torchbench", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_cpp_wrapper_abi_compatible", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, + ]} + secrets: + HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} + + linux-focal-cuda12_4-py3_10-gcc9-inductor-test: + name: cuda12.4-py3.10-gcc9-sm86 + uses: ./.github/workflows/_linux-test.yml + needs: linux-focal-cuda12_4-py3_10-gcc9-inductor-build + with: + sync-tag: linux-focal-cuda12_4-py3_10-gcc9-inductor-test + build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm86 + docker-image: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-inductor-build.outputs.test-matrix }} + secrets: + HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} + + linux-focal-cuda12_4-py3_10-gcc9-inductor-build-gcp: + name: cuda12.4-py3.10-gcc9-sm80 + uses: ./.github/workflows/_linux-build.yml + with: + build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm80 + docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9-inductor-benchmarks + cuda-arch-list: '8.0' + test-matrix: | + { include: [ + { config: "inductor_torchbench_smoketest_perf", shard: 1, num_shards: 1, runner: "linux.gcp.a100" }, + ]} + secrets: + HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} + + linux-focal-cuda12_4-py3_10-gcc9-inductor-test-gcp: + name: cuda12.4-py3.10-gcc9-sm80 + uses: ./.github/workflows/_linux-test.yml + needs: linux-focal-cuda12_4-py3_10-gcc9-inductor-build-gcp + with: + build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm80 + docker-image: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-inductor-build-gcp.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-inductor-build-gcp.outputs.test-matrix }} + use-gha: anything-non-empty-to-use-gha + secrets: + HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} + + linux-focal-cuda12_4-py3_12-gcc9-inductor-build: + name: cuda12.4-py3.12-gcc9-sm86 + uses: ./.github/workflows/_linux-build.yml + with: + build-environment: linux-focal-cuda12.4-py3.12-gcc9-sm86 + docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3.12-gcc9-inductor-benchmarks + cuda-arch-list: '8.6' + test-matrix: | + { include: [ + { config: "inductor", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, + ]} + + linux-focal-cuda12_4-py3_12-gcc9-inductor-test: + name: cuda12.4-py3.12-gcc9-sm86 + uses: ./.github/workflows/_linux-test.yml + needs: linux-focal-cuda12_4-py3_12-gcc9-inductor-build + with: + build-environment: linux-focal-cuda12.4-py3.12-gcc9-sm86 + docker-image: ${{ needs.linux-focal-cuda12_4-py3_12-gcc9-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-cuda12_4-py3_12-gcc9-inductor-build.outputs.test-matrix }} diff --git a/.github/workflows/inductor-periodic.yml b/.github/workflows/inductor-periodic.yml index 731291697cef..2fe649cebb5e 100644 --- a/.github/workflows/inductor-periodic.yml +++ b/.github/workflows/inductor-periodic.yml @@ -56,93 +56,3 @@ jobs: test-matrix: ${{ needs.linux-focal-cuda12_1-py3_10-gcc9-periodic-dynamo-benchmarks-build.outputs.test-matrix }} secrets: HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} - - linux-focal-cuda12_4-py3_10-gcc9-inductor-build: - # Should be synced with the one in inductor.yml, but this doesn't run inductor_timm - name: cuda12.4-py3.10-gcc9-sm86 - uses: ./.github/workflows/_linux-build.yml - with: - sync-tag: linux-focal-cuda12_4-py3_10-gcc9-inductor-build - build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm86 - docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9-inductor-benchmarks - cuda-arch-list: '8.6' - test-matrix: | - { include: [ - { config: "inductor", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "inductor_distributed", shard: 1, num_shards: 1, runner: "linux.g5.12xlarge.nvidia.gpu" }, - { config: "inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "inductor_torchbench", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "inductor_torchbench", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "dynamic_inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "dynamic_inductor_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "dynamic_inductor_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "dynamic_inductor_torchbench", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "dynamic_inductor_torchbench", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "aot_inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "aot_inductor_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "aot_inductor_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "aot_inductor_torchbench", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "aot_inductor_torchbench", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "inductor_cpp_wrapper_abi_compatible", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, - ]} - secrets: - HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} - - linux-focal-cuda12_4-py3_10-gcc9-inductor-test: - name: cuda12.4-py3.10-gcc9-sm86 - uses: ./.github/workflows/_linux-test.yml - needs: linux-focal-cuda12_4-py3_10-gcc9-inductor-build - with: - sync-tag: linux-focal-cuda12_4-py3_10-gcc9-inductor-test - build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm86 - docker-image: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-inductor-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-inductor-build.outputs.test-matrix }} - secrets: - HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} - - linux-focal-cuda12_4-py3_10-gcc9-inductor-build-gcp: - name: cuda12.4-py3.10-gcc9-sm80 - uses: ./.github/workflows/_linux-build.yml - with: - build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm80 - docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9-inductor-benchmarks - cuda-arch-list: '8.0' - test-matrix: | - { include: [ - { config: "inductor_torchbench_smoketest_perf", shard: 1, num_shards: 1, runner: "linux.gcp.a100" }, - ]} - secrets: - HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} - - linux-focal-cuda12_4-py3_10-gcc9-inductor-test-gcp: - name: cuda12.4-py3.10-gcc9-sm80 - uses: ./.github/workflows/_linux-test.yml - needs: linux-focal-cuda12_4-py3_10-gcc9-inductor-build-gcp - with: - build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm80 - docker-image: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-inductor-build-gcp.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-inductor-build-gcp.outputs.test-matrix }} - use-gha: anything-non-empty-to-use-gha - secrets: - HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} - - linux-focal-cuda12_4-py3_12-gcc9-inductor-build: - name: cuda12.4-py3.12-gcc9-sm86 - uses: ./.github/workflows/_linux-build.yml - with: - build-environment: linux-focal-cuda12.4-py3.12-gcc9-sm86 - docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3.12-gcc9-inductor-benchmarks - cuda-arch-list: '8.6' - test-matrix: | - { include: [ - { config: "inductor", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, - ]} - - linux-focal-cuda12_4-py3_12-gcc9-inductor-test: - name: cuda12.4-py3.12-gcc9-sm86 - uses: ./.github/workflows/_linux-test.yml - needs: linux-focal-cuda12_4-py3_12-gcc9-inductor-build - with: - build-environment: linux-focal-cuda12.4-py3.12-gcc9-sm86 - docker-image: ${{ needs.linux-focal-cuda12_4-py3_12-gcc9-inductor-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-cuda12_4-py3_12-gcc9-inductor-build.outputs.test-matrix }} From 09cccbc1c74c9d1157c1caca5526e79ee9b7ea01 Mon Sep 17 00:00:00 2001 From: Chirag Pandya Date: Fri, 7 Jun 2024 10:27:36 -0700 Subject: [PATCH 502/706] [RFC] add per-collective timeout value in flight recorder (#128190) Summary: Add timeout value field on every collected record. Test Plan: Unit tests Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/128190 Approved by: https://github.com/wconstab --- test/distributed/test_c10d_nccl.py | 5 ++++- torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp | 3 +++ torch/csrc/distributed/c10d/TraceUtils.h | 10 +++++++++- 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index 21a8a632bade..f45600c5d17d 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -3548,7 +3548,7 @@ def test_short(self, timing_enabled, include_collectives): ) ) ver = t["version"] - self.assertEqual(ver, "2.1") + self.assertEqual(ver, "2.2") pg_config = t["pg_config"] self.assertEqual(len(pg_config), 1) default_pg_info = pg_config["0"] @@ -3577,6 +3577,7 @@ def test_short(self, timing_enabled, include_collectives): self.assertEqual(last["output_sizes"], ((3, 4),)) self.assertEqual(last["output_dtypes"], ["Float"]) self.assertEqual(last["collective_seq_id"], 2) + self.assertEqual(last["timeout_ms"], 600000) now = datetime.now() event_created_time = datetime.fromtimestamp( last["time_created_ns"] / 1000000000 @@ -3661,6 +3662,7 @@ def test_long(self): self.assertEqual(last["input_dtypes"], ["Float"]) self.assertEqual(last["output_sizes"], ((3, 4),)) self.assertEqual(last["output_dtypes"], ["Float"]) + self.assertEqual(last["timeout_ms"], 600000) self.assertEqual(last["collective_seq_id"] - first["collective_seq_id"], 9) @requires_nccl() @@ -3865,6 +3867,7 @@ def test_batched_send_recv(self, op_sizes_per_coalesce, timing_enabled): self.assertTrue(0.001 < duration < 10000, duration) else: self.assertTrue("duration_ms" not in t["entries"][coalesced_op]) + self.assertEqual(t["entries"][coalesced_op]["timeout_ms"], 600000) @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 8adf1e02c1a0..07bbcd5a0af4 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -2356,6 +2356,7 @@ c10::intrusive_ptr ProcessGroupNCCL::initWork( outputs, r->ncclStartEvent_.get(), r->ncclEndEvent_.get(), + options_->timeout, isP2P); } return r; @@ -2966,6 +2967,7 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( {tensor}, nullptr, nullptr, + options_->timeout, /*isP2P=*/true); // TODO(whc) if we want to make the per-p2p-op flightrecorder entries get // their timings/states updated by proxy when the Work obj representing the @@ -2999,6 +3001,7 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( {tensor}, work->ncclStartEvent_.get(), work->ncclEndEvent_.get(), + options_->timeout, /*isP2P=*/true); } diff --git a/torch/csrc/distributed/c10d/TraceUtils.h b/torch/csrc/distributed/c10d/TraceUtils.h index c3b0464cf992..de623d77fe9e 100644 --- a/torch/csrc/distributed/c10d/TraceUtils.h +++ b/torch/csrc/distributed/c10d/TraceUtils.h @@ -8,6 +8,7 @@ #include #include #include +#include #ifdef USE_C10D_NCCL #include @@ -28,7 +29,7 @@ static c10::IValue nccl_comm_key = "nccl_comm_state"; static c10::IValue version_key = "version"; // Update whenever changing contents or formatting of the dump // (minor when adding fields, major when changing existing fields) -static c10::IValue version_val = "2.1"; +static c10::IValue version_val = "2.2"; static c10::IValue pg_config_key = "pg_config"; static c10::IValue record_id_key = "record_id"; static c10::IValue pg_id_key = "pg_id"; @@ -44,6 +45,7 @@ static c10::IValue output_sizes_key = "output_sizes"; static c10::IValue output_dtypes_key = "output_dtypes"; static c10::IValue time_created_key = "time_created_ns"; static c10::IValue duration_key = "duration_ms"; +static c10::IValue timeout_key = "timeout_ms"; static c10::IValue frames_key = "frames"; static c10::IValue state_key = "state"; @@ -461,6 +463,9 @@ struct NCCLTraceBuffer { // was 'enqueued'- not necessarily started c10::time_t time_created_; + // configured timeout for this entry + c10::time_t timeout_ms_; + // Is this a P2P event? bool isP2P_; @@ -508,6 +513,7 @@ struct NCCLTraceBuffer { const std::vector& outputs, Event* start, Event* end, + std::chrono::milliseconds timeout_ms, bool isP2P) { if (!enabled_) { return c10::nullopt; @@ -528,6 +534,7 @@ struct NCCLTraceBuffer { std::move(start), std::move(end), c10::getTime(), + timeout_ms.count(), isP2P}; for (const auto& input : inputs) { @@ -752,6 +759,7 @@ struct NCCLTraceBuffer { ? int64_t(*e.time_discovered_completed_) : c10::IValue()); dict.insert(retired_key, e.retired_); + dict.insert(timeout_key, e.timeout_ms_); dict.insert(is_p2p_key, e.isP2P_); entries.push_back(dict); From bef586111a4f5707f7d5f04c50c6757fe5ef3072 Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Fri, 7 Jun 2024 14:53:46 -0700 Subject: [PATCH 503/706] [pipelining] pipelining.rst updates (#128228) fix some nits and add `PipelineStage` (manual) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128228 Approved by: https://github.com/wconstab ghstack dependencies: #128201 --- docs/source/distributed.pipelining.rst | 31 +++++++++++++------ torch/distributed/pipelining/PipelineStage.py | 1 - 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/docs/source/distributed.pipelining.rst b/docs/source/distributed.pipelining.rst index 4f816bc3b843..05cefb220b90 100644 --- a/docs/source/distributed.pipelining.rst +++ b/docs/source/distributed.pipelining.rst @@ -53,7 +53,7 @@ Overall, the ``pipelining`` package provides the following features: * Splitting of model code based on simple specification. The goal is to make parallelism work for your model with **zero model code change**. * Rich support for pipeline schedules, including GPipe, 1F1B, - Interleaved 1F1B and Looped BFS, and provide the infrastruture for writing + Interleaved 1F1B and Looped BFS, and providing the infrastruture for writing customized schedules. * First-class support for cross-host pipeline parallelism, as this is where PP is typically used (over slower interconnects). @@ -179,7 +179,7 @@ You can also create a distributed stage runtime on a device using ``Pipe``: .. code-block:: python - from torch.distributed.pipelining import PipelineStage + from torch.distributed.pipelining import TracerPipelineStage stage = TracerPipelineStage(pipe, stage_idx, device) .. note:: @@ -187,15 +187,28 @@ You can also create a distributed stage runtime on a device using ``Pipe``: model into a single graph. If your model is not full-graph'able, you can use our manual frontend below. -Frontend 2: ``ManualPipelineStage`` -- if you already have module for each stage +Frontend 2: ``PipelineStage`` -- if you already have module for each stage ================================================================================ If you already have the module for each stage, you can skip the pipeline split -step above and directly connect to our runtime offering: ``ManualPipelineStage``. -The ``ManualPipelineStage`` wraps your stage module given a distributed context, +step above and directly connect to our runtime offering: ``PipelineStage``. +The ``PipelineStage`` wraps your stage module given a distributed context, i.e. a ``ProcessGroup`` along the pipeline dimension. -TODO: manual example here +.. code-block:: python + + from torch.distributed.pipelining import PipelineStage + stage = PipelineStage( + stage_mod, + stage_idx, + num_stages, + device, + input_args=x.chunk(num_microbatches)[0], + ) + +The ``PipelineStage`` requires an example argument (similar to ``example_args`` used in ``pipeline``). +This argument is passed through the forward method of the stage module to determine the +input and output shapes required for communication. Step 2: use ``PipelineSchedule`` for execution @@ -265,7 +278,7 @@ captures information during ``Module.__init__()``, and does not capture any information about ``Module.forward()``. Said differently, ``Module.children()`` lacks information about the following aspects key to pipelininig: -* Exectuion order of child modules in ``forward`` +* Execution order of child modules in ``forward`` * Activation flows between child modules * Whether there are any functional operators between child modules (for example, ``relu`` or ``add`` operations will not be captured by ``Module.children()``). @@ -276,8 +289,8 @@ helping the distributed runtime to make correct send/receive calls without human intervention. Another flexibility of the ``pipeline`` API is that split points can be at -arbitrary hierarchy of your model. In the split partitions, the original model -hierarchy related to that partition will be reconstructed at no cost of yours. +arbitrary levels within your model hierarchy. In the split partitions, the original model +hierarchy related to that partition will be reconstructed at no cost to you. At a result, fully-qualified names (FQNs) pointing to a submodule or parameter would be still valid, and services that relies on FQNs (such as FSDP, TP or checkpointing) can still run with your partitioned modules with almost zero code diff --git a/torch/distributed/pipelining/PipelineStage.py b/torch/distributed/pipelining/PipelineStage.py index fbfca518df4d..f59e3e9dae65 100644 --- a/torch/distributed/pipelining/PipelineStage.py +++ b/torch/distributed/pipelining/PipelineStage.py @@ -1096,7 +1096,6 @@ class PipelineStage(_PipelineStageBase): stage_index (int): The ID of this stage. num_stages (int): The total number of stages. device (torch.device): The device where this stage is located. - num_microbatches (int): The number of microbatches to use. input_args (Union[torch.Tensor, Tuple[torch.tensor]], optional): The input arguments for the submodule. output_args (Union[torch.Tensor, Tuple[torch.tensor]], optional): The output arguments for the submodule. group (dist.ProcessGroup, optional): The process group for distributed training. If None, default group. From 39dd4740e6804dfccd82095a79957cd0235d552b Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Thu, 6 Jun 2024 23:02:06 -0700 Subject: [PATCH 504/706] [inductor][dynamo-inline-nn-modules] Fix test with inlining flag (#128200) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128200 Approved by: https://github.com/Skylion007 ghstack dependencies: #128001, #126578, #128158, #128172 --- test/inductor/test_torchinductor.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 4aa97b058271..30b846d5842c 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -10934,7 +10934,9 @@ def fn3(x): ), ( fn3, - "triton_poi_fused_LayerNorm_ReLU", + "triton_poi_fused_layer_norm_relu" + if torch._dynamo.config.inline_inbuilt_nn_modules + else "triton_poi_fused_LayerNorm_ReLU", (torch.randn(4, 4, device=GPU_TYPE),), ), ] From ef2b5ed500cba0b8b2bf04e6006a0d64c910f440 Mon Sep 17 00:00:00 2001 From: cyy Date: Sat, 8 Jun 2024 00:09:26 +0000 Subject: [PATCH 505/706] [4/N] Remove unused functions (#128193) Follows #128179 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128193 Approved by: https://github.com/ezyang --- aten/src/ATen/native/BinaryOps.cpp | 9 ---- aten/src/ATen/native/DispatchStub.h | 2 +- aten/src/ATen/native/LossNLL2d.cpp | 9 ---- .../native/NaiveConvolutionTranspose2d.cpp | 49 ----------------- .../native/NaiveConvolutionTranspose3d.cpp | 52 ------------------- aten/src/ATen/native/NamedTensor.cpp | 15 ------ aten/src/ATen/native/ReduceOps.cpp | 5 -- aten/src/ATen/native/ReflectionPad.cpp | 8 --- aten/src/ATen/native/Resize.cpp | 8 --- aten/src/ATen/native/TensorCompare.cpp | 6 --- .../ATen/native/quantized/cpu/ReduceOps.cpp | 19 ------- aten/src/ATen/native/sparse/SoftMax.cpp | 9 ---- 12 files changed, 1 insertion(+), 190 deletions(-) diff --git a/aten/src/ATen/native/BinaryOps.cpp b/aten/src/ATen/native/BinaryOps.cpp index 19c70672fb93..3fe3ac2b4a25 100644 --- a/aten/src/ATen/native/BinaryOps.cpp +++ b/aten/src/ATen/native/BinaryOps.cpp @@ -1480,23 +1480,14 @@ Tensor& not_equal_(Tensor& self, const Scalar& other) { return self.ne_(other); Tensor& logical_and_out(const Tensor& self, const Tensor& other, Tensor& result) { return comparison_op_out(result, self, other, logical_and_stub); } Tensor logical_and(const Tensor& self, const Tensor& other) { return comparison_op(self, other, static_cast(at::logical_and_out)); } Tensor& logical_and_(Tensor& self, const Tensor& other) { return comparison_op_(self, other, static_cast(at::logical_and_out)); } -static Tensor& logical_and_out(Tensor& result, const Tensor& self, const Scalar& other) { return comparison_op_out(result, self, other, static_cast(at::logical_and_out)); } -static Tensor logical_and(const Tensor& self, const Scalar& other) { return comparison_op(self, other, static_cast(at::logical_and_out)); } -static Tensor& logical_and_(Tensor& self, const Scalar& other) { return comparison_op_(self, other, static_cast(at::logical_and_out)); } Tensor& logical_or_out(const Tensor& self, const Tensor& other, Tensor& result) { return comparison_op_out(result, self, other, logical_or_stub); } Tensor logical_or(const Tensor& self, const Tensor& other) { return comparison_op(self, other, static_cast(at::logical_or_out)); } Tensor& logical_or_(Tensor& self, const Tensor& other) { return comparison_op_(self, other, static_cast(at::logical_or_out)); } -static Tensor& logical_or_out(Tensor& result, const Tensor& self, const Scalar& other) { return comparison_op_out(result, self, other, static_cast(at::logical_or_out)); } -static Tensor logical_or(const Tensor& self, const Scalar& other) { return comparison_op(self, other, static_cast(at::logical_or_out)); } -static Tensor& logical_or_(Tensor& self, const Scalar& other) { return comparison_op_(self, other, static_cast(at::logical_or_out)); } Tensor& logical_xor_out(const Tensor& self, const Tensor& other, Tensor& result) { return comparison_op_out(result, self, other, logical_xor_stub); } Tensor logical_xor(const Tensor& self, const Tensor& other) { return comparison_op(self, other, static_cast(at::logical_xor_out)); } Tensor& logical_xor_(Tensor& self, const Tensor& other) { return comparison_op_(self, other, static_cast(at::logical_xor_out)); } -static Tensor& logical_xor_out(Tensor& result, const Tensor& self, const Scalar& other) { return comparison_op_out(result, self, other, static_cast(at::logical_xor_out)); } -static Tensor logical_xor(const Tensor& self, const Scalar& other) { return comparison_op(self, other, static_cast(at::logical_xor_out)); } -static Tensor& logical_xor_(Tensor& self, const Scalar& other) { return comparison_op_(self, other, static_cast(at::logical_xor_out)); } // binary max, alias for maximum Tensor& max_out(const Tensor& self, const Tensor& other, Tensor& result) { diff --git a/aten/src/ATen/native/DispatchStub.h b/aten/src/ATen/native/DispatchStub.h index e1952795843c..b35ad072d0cf 100644 --- a/aten/src/ATen/native/DispatchStub.h +++ b/aten/src/ATen/native/DispatchStub.h @@ -393,7 +393,7 @@ struct RegisterPRIVATEUSE1Dispatch { // REGISTER_DISPATCH now dispatches an AVX512 kernel to nullptr but registers other dispatches. // ALSO_REGISTER_AVX512_DISPATCH should be used for ensuring AVX512 dispatch, among others. #ifdef CPU_CAPABILITY_AVX512 -#define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, nullptr) +#define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, ((void*)(fn) ? nullptr : nullptr)) #else #define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn) #endif diff --git a/aten/src/ATen/native/LossNLL2d.cpp b/aten/src/ATen/native/LossNLL2d.cpp index 6f27884b8f24..13c575a1a7bb 100644 --- a/aten/src/ATen/native/LossNLL2d.cpp +++ b/aten/src/ATen/native/LossNLL2d.cpp @@ -499,13 +499,4 @@ Tensor nll_loss2d_symint(const Tensor & self, const Tensor & target, const std:: return std::get<0>(at::nll_loss2d_forward_symint(self, target, weight, reduction, std::move(ignore_index))); } -// Duplicate of above code for non-symbolic ints. Kept for BC purposes and to minimize breakages. -static Tensor nll_loss2d(const Tensor & self, const Tensor & target, const std::optional& weight_opt, int64_t reduction, int64_t ignore_index) { - // See [Note: hacky wrapper removal for optional tensor] - c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); - const Tensor& weight = *weight_maybe_owned; - - return std::get<0>(at::nll_loss2d_forward_symint(self, target, weight, reduction, ignore_index)); -} - } // namespace at::native diff --git a/aten/src/ATen/native/NaiveConvolutionTranspose2d.cpp b/aten/src/ATen/native/NaiveConvolutionTranspose2d.cpp index fbac5d4cc72c..7da1ec9b1998 100644 --- a/aten/src/ATen/native/NaiveConvolutionTranspose2d.cpp +++ b/aten/src/ATen/native/NaiveConvolutionTranspose2d.cpp @@ -802,55 +802,6 @@ TORCH_IMPL_FUNC(slow_conv_transpose2d_structured_cpu) dilation); } -static std::tuple slow_conv_transpose2d_backward_out_cpu(const Tensor& grad_output, - const Tensor& input, - const Tensor& weight, - IntArrayRef kernel_size, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef output_padding, - IntArrayRef dilation, - Tensor& grad_input, - Tensor& grad_weight, - Tensor& grad_bias) { - if (grad_input.defined()) { - slow_conv_transpose2d_backward_out_cpu_template( - input, - grad_output, - grad_input, - weight, - kernel_size, - stride, - padding, - output_padding, - dilation); - } - - if (grad_bias.defined()) { - at::sum_out(grad_bias, grad_output, IntArrayRef{0, 2, 3}); - } - - if (grad_weight.defined()) { - grad_weight.resize_(weight.sizes(), weight.suggest_memory_format()); - grad_weight.zero_(); - slow_conv_transpose2d_acc_grad_parameters_cpu( - input, - weight, - grad_output, - grad_weight, - grad_bias, - kernel_size, - stride, - padding, - output_padding, - dilation, - 1); - } - - return std::tuple( - grad_input, grad_weight, grad_bias); -} - static std::tuple slow_conv_transpose2d_backward_cpu( const Tensor& grad_output, const Tensor& input, diff --git a/aten/src/ATen/native/NaiveConvolutionTranspose3d.cpp b/aten/src/ATen/native/NaiveConvolutionTranspose3d.cpp index f82354ace3b8..9ef236d4dab9 100644 --- a/aten/src/ATen/native/NaiveConvolutionTranspose3d.cpp +++ b/aten/src/ATen/native/NaiveConvolutionTranspose3d.cpp @@ -871,58 +871,6 @@ Tensor slow_conv_transpose3d_cpu( return output; } -static std::tuple slow_conv_transpose3d_backward_out_cpu(const Tensor& grad_output, - const Tensor& input, - const Tensor& weight, - IntArrayRef kernel_size, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef output_padding, - IntArrayRef dilation, - Tensor& grad_input, - Tensor& grad_weight, - Tensor& grad_bias) { - if (grad_input.defined()) { - slow_conv_transpose3d_backward_out_cpu_template( - input, - grad_output, - grad_input, - weight, - kernel_size, - stride, - padding, - output_padding, - dilation); - } - - if (grad_weight.defined()) { - grad_weight.resize_(weight.sizes()); - grad_weight.zero_(); - } - - if (grad_bias.defined()) { - grad_bias.resize_({weight.size(1)}); - grad_bias.zero_(); - } - - if (grad_weight.defined() || grad_bias.defined()) { - slow_conv_transpose3d_acc_grad_parameters_cpu( - input, - grad_output, - grad_weight, - grad_bias, - kernel_size, - stride, - padding, - output_padding, - dilation, - 1); - } - - return std::tuple( - grad_input, grad_weight, grad_bias); -} - static std::tuple slow_conv_transpose3d_backward_cpu( const Tensor& grad_output, const Tensor& input, diff --git a/aten/src/ATen/native/NamedTensor.cpp b/aten/src/ATen/native/NamedTensor.cpp index 709d63bae636..70fb94cc6f45 100644 --- a/aten/src/ATen/native/NamedTensor.cpp +++ b/aten/src/ATen/native/NamedTensor.cpp @@ -339,12 +339,6 @@ Tensor& gather_out(const Tensor& self, Dimname dim, const Tensor& index, bool sp Tensor index_add(const Tensor& self, Dimname dim, const Tensor& index, const Tensor& source, const Scalar &alpha) { reportNYIDimnameOverload("index_add"); } -static Tensor& index_add_(Tensor& self, Dimname dim, const Tensor& index, const Tensor& source, const Scalar &alpha) { - reportNYIDimnameOverload("index_add"); -} -static Tensor& index_add_out(const Tensor& self, Dimname dim, const Tensor& index, const Tensor& source, const Scalar& alpha, Tensor& result) { - reportNYIDimnameOverload("index_add"); -} Tensor index_fill(const Tensor& self, Dimname dim, const Tensor& index, const Scalar& source) { return at::index_fill(self, dimname_to_position(self, dim), index, source); } @@ -372,21 +366,12 @@ Tensor index_select(const Tensor& self, Dimname dim, const Tensor& index) { Tensor scatter(const Tensor& self, Dimname dim, const Tensor& index, const Tensor& source) { reportNYIDimnameOverload("scatter"); } -static Tensor& scatter_(Tensor& self, Dimname dim, const Tensor& index, const Tensor& source) { - reportNYIDimnameOverload("scatter"); -} Tensor scatter(const Tensor& self, Dimname dim, const Tensor& index, const Scalar& source) { reportNYIDimnameOverload("scatter"); } -static Tensor& scatter_(Tensor& self, Dimname dim, const Tensor& index, const Scalar& source) { - reportNYIDimnameOverload("scatter"); -} Tensor scatter_add(const Tensor& self, Dimname dim, const Tensor& index, const Tensor& source) { reportNYIDimnameOverload("scatter_add"); } -static Tensor& scatter_add_(Tensor& self, Dimname dim, const Tensor& index, const Tensor& source) { - reportNYIDimnameOverload("scatter_add"); -} std::tuple sort_out(const Tensor& self, std::optional stable, Dimname dim, bool keepdim, Tensor& values, Tensor& indices) { reportNYIDimnameOverload("sort"); } diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp index 2e870dc83ee1..4718e824fad8 100644 --- a/aten/src/ATen/native/ReduceOps.cpp +++ b/aten/src/ATen/native/ReduceOps.cpp @@ -2276,11 +2276,6 @@ bool cpu_equal(const Tensor& self, const Tensor& other) { return result.load(); } -static Tensor value_selecting_reduction_backward(const Tensor& grad, int64_t dim, const Tensor& indices, at::IntArrayRef sizes, bool keepdim) { - return at::native::value_selecting_reduction_backward_symint(grad, dim, indices, c10::fromIntArrayRefSlow(sizes), keepdim); -} - - // max(dim), min(dim), topk(dim), mode(dim), are examples of reduction // functions that select values. value_selecting_reduction_backward is the // backward function for those operators; it propagates the grad to the diff --git a/aten/src/ATen/native/ReflectionPad.cpp b/aten/src/ATen/native/ReflectionPad.cpp index ac5702935442..61c60c17428c 100644 --- a/aten/src/ATen/native/ReflectionPad.cpp +++ b/aten/src/ATen/native/ReflectionPad.cpp @@ -301,14 +301,6 @@ void reflection_pad2d_backward_out_template( } // namespace -// TODO: I tihnk this function should be removed since we implement it with -// TORCH_IMPL_FUNC below -static Tensor& reflection_pad1d_out_cpu(const Tensor& input, IntArrayRef padding, - Tensor& output) { - reflection_pad1d_kernel(kCPU, output, input, padding); - return output; -} - Tensor& reflection_pad1d_out_quantized_cpu(const Tensor& input, IntArrayRef padding, Tensor& output) { TORCH_CHECK(input.qscheme() == kPerTensorAffine, "Only per tensor quantization is supported"); diff --git a/aten/src/ATen/native/Resize.cpp b/aten/src/ATen/native/Resize.cpp index fd06627b7027..95676ad11772 100644 --- a/aten/src/ATen/native/Resize.cpp +++ b/aten/src/ATen/native/Resize.cpp @@ -231,14 +231,6 @@ TensorImpl* resize_impl_cpu_( return _resize_impl_(self, size, stride, resize_storage); } -static TensorImpl* resize_impl_meta_( - TensorImpl* self, - c10::SymIntArrayRef size, - at::OptionalSymIntArrayRef stride, - bool resize_storage = true) { - return _resize_impl_(self, size, stride, resize_storage); -} - template const Tensor& _resize_( const Tensor& self, diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp index 6f132a6ea814..6d6db1477f1f 100644 --- a/aten/src/ATen/native/TensorCompare.cpp +++ b/aten/src/ATen/native/TensorCompare.cpp @@ -792,12 +792,6 @@ std::tuple max(const Tensor& self, Dimname dim, bool keepdim) { std::tuple max_out(const Tensor& self, Dimname dim, bool keepdim, Tensor& max, Tensor& max_indices) { return at::max_out(max, max_indices, self, dimname_to_position(self, dim), keepdim); } -static Tensor argmax(const Tensor& /*self*/, Dimname /*dim*/, bool /*keepdim*/) { - reportNYIDimnameOverload("argmax"); -} -static Tensor argmin(const Tensor& /*self*/, Dimname /*dim*/, bool /*keepdim*/) { - reportNYIDimnameOverload("argmin"); -} Tensor argsort(const Tensor& /*self*/, Dimname /*dim*/, bool /*keepdim*/) { reportNYIDimnameOverload("argsort"); } diff --git a/aten/src/ATen/native/quantized/cpu/ReduceOps.cpp b/aten/src/ATen/native/quantized/cpu/ReduceOps.cpp index 113c57f2cc35..573b2ffff4b4 100644 --- a/aten/src/ATen/native/quantized/cpu/ReduceOps.cpp +++ b/aten/src/ATen/native/quantized/cpu/ReduceOps.cpp @@ -237,24 +237,5 @@ Tensor std_quantized_cpu( return result; } -static Tensor std_quantized_cpu( - const Tensor& self, - DimnameList dim, - const std::optional& correction, - bool keepdim) { - return std_quantized_cpu( - self, dimnames_to_positions(self, dim), correction, keepdim); -} - -static Tensor& std_out_quantized_cpu( - Tensor& result, - const Tensor& self, - DimnameList dim, - const std::optional& correction, - bool keepdim) { - return std_out_quantized_cpu( - self, dimnames_to_positions(self, dim), correction, keepdim, result); -} - } // namespace native } // namespace at diff --git a/aten/src/ATen/native/sparse/SoftMax.cpp b/aten/src/ATen/native/sparse/SoftMax.cpp index 179db48beacc..668032cb588e 100644 --- a/aten/src/ATen/native/sparse/SoftMax.cpp +++ b/aten/src/ATen/native/sparse/SoftMax.cpp @@ -606,15 +606,6 @@ Tensor log_softmax_backward_sparse_cpu( return grad_input; } -static Tensor _sparse_softmax(const Tensor& input_, const int64_t dim_) { - auto result = [&]() { - NoNamesGuard guard; - return at::_sparse_softmax(input_, dim_, false); - }(); - namedinference::propagate_names(result, input_); - return result; -} - Tensor _sparse_softmax(const Tensor& input_, const int64_t dim_, std::optional dtype) { auto result = [&]() { NoNamesGuard guard; From 647815049ec28a72dc1bb6a977791927bba058d5 Mon Sep 17 00:00:00 2001 From: Alnis Murtovi Date: Sat, 8 Jun 2024 00:46:16 +0000 Subject: [PATCH 506/706] Inductor: Allow small sizes of m for mixed mm autotuning (#127663) For mixed mm with small sizes of m, such as in the example provided in #127056, being able to set BLOCK_M to 16 leads to better performance. This PR introduces kernel configs that are specific to mixed mm by extending the mm configs with two configs that work well for the example provided in #127056. I am excluding configs with (BLOCK_M=16, BLOCK_K=16, BLOCK_N=64) because triton crashes when this config is used. For the example in #127056: - Without my changes, skip_triton is evaluated to true which disables autotuning. On my machine I achieve 146GB/s. - If autotuning is enabled, but BLOCK_M>=32, I achieve 614 GB/s. - With the changes in this PR (i.e. autotuning enabled and BLOCK_M=16), I achieve 772 GB/s. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127663 Approved by: https://github.com/Chillee --- test/inductor/test_torchinductor.py | 18 ++++++++++++++ torch/_inductor/kernel/mm.py | 6 +++-- torch/_inductor/kernel/mm_common.py | 38 ++++++++++++++++++++++++----- 3 files changed, 54 insertions(+), 8 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 30b846d5842c..53167a83ecd8 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -2900,6 +2900,24 @@ def fn(a, b, scale, bias): check_lowp=True, ) + @skipIfPy312 # segfaults + @config.patch(force_mixed_mm=True) + def test_mixed_mm3(self): + def fn(a, b): + return torch.mm(a, b.to(a.dtype)) + + # (256, 256) @ (256, 256) so different block sizes are tried out during autotuning + self.common( + fn, + ( + torch.randn(256, 256), + torch.randint(-128, 127, (256, 256), dtype=torch.int8), + ), + check_lowp=True, + rtol=0.01, + atol=0.1, + ) + @with_tf32_off @config.patch(use_mixed_mm=True) def test_uint4x2_mixed_mm(self): diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py index a90fdbfa33d9..2f30aa941837 100644 --- a/torch/_inductor/kernel/mm.py +++ b/torch/_inductor/kernel/mm.py @@ -26,6 +26,7 @@ from .mm_common import ( addmm_epilogue, int8_mm_configs, + mixed_mm_configs, mm_args, mm_configs, mm_grid, @@ -407,7 +408,8 @@ def tuned_mixed_mm(mat1, mat2, mat2_dtype): # can't use triton kernel unless one of these is true or if running on v100 (numerical issues) skip_triton = ( - mat1.layout.dtype != torch.float32 and not mat2.layout.is_contiguous() + mat1.layout.dtype != torch.float32 + and not (mat2.layout.is_contiguous() or mat2.layout.is_transposed()) ) or _is_sm7x_or_older_gpu(layout.device.index) if inductor_config.force_mixed_mm: @@ -415,7 +417,7 @@ def tuned_mixed_mm(mat1, mat2, mat2_dtype): if not skip_triton: b_prologue_cast_type = f"tl.{mat2_dtype}".replace("torch.", "") has_int8_tensor = _is_int8_mat(mat1) or _is_int8_mat(mat2) - for config in mm_configs(m, n, k, has_int8_tensor=has_int8_tensor): + for config in mixed_mm_configs(m, n, k, has_int8_tensor=has_int8_tensor): mm_template.maybe_append_choice( choices, input_nodes=(mat1, mat2), diff --git a/torch/_inductor/kernel/mm_common.py b/torch/_inductor/kernel/mm_common.py index 97741cc0f8eb..1ca0558d19c0 100644 --- a/torch/_inductor/kernel/mm_common.py +++ b/torch/_inductor/kernel/mm_common.py @@ -31,10 +31,10 @@ def filtered_configs( ): """Heuristic to shrink configs when they are bigger than the input size""" - # According to https://github.com/openai/triton/issues/2156#issuecomment-1695897424 - # it's safer to use at least [32, 32] block size for int8/uint8 - # tensors - min_block_size = 32 if has_int8_tensor else 16 + min_block_size = 16 + # block_k=16 seems to be causing issues + # see: https://github.com/triton-lang/triton/issues/2156#issuecomment-1695897424 + min_block_size_k = 32 if has_int8_tensor else 16 m = max( next_power_of_2( V.graph.sizevars.size_hint( @@ -57,14 +57,14 @@ def filtered_configs( k, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type] ) ), - min_block_size, + min_block_size_k, ) used = set() for block_m, block_n, block_k, num_stages, num_warps in configs: # shrink configs for small sizes block_m = max(min(block_m, m), min_block_size) block_n = max(min(block_n, n), min_block_size) - block_k = max(min(block_k, k), min_block_size) + block_k = max(min(block_k, k), min_block_size_k) # each warp computes 16x16 tile = 256 num_warps = min(num_warps, block_m * block_n // 256) if torch.version.hip: @@ -166,6 +166,18 @@ def filtered_configs( {"config": (256, 128, 128, 3, 8), "cond": torch.version.hip is None}, ] +# Mixed precision kernel configs for small sizes of m for mm's like (16, 8192) x (8192, 8192). +mixed_mm_kernel_configs_small_m = [ + {"config": (16, 128, 256, 3, 4), "cond": True}, + {"config": (16, 128, 256, 5, 8), "cond": True}, +] + +mixed_mm_kernel_configs = ( + mm_kernel_configs + mixed_mm_kernel_configs_small_m + if inductor_config.max_autotune_gemm_search_space != "EXHAUSTIVE" + else mm_kernel_configs +) + # Create filtered list of configs based on cond evaluation @@ -179,6 +191,11 @@ def filtered_configs( for config in int8_mm_kernel_configs if config["cond"] ) +mixed_mm_platform_configs = tuple( + cast(Tuple[int, int, int, int, int], config["config"]) + for config in mixed_mm_kernel_configs + if config["cond"] +) # On ROCm convert num_stages to 0 to enable software pipelining if torch.version.hip: @@ -190,6 +207,10 @@ def filtered_configs( (config[0], config[1], config[2], 0, config[4]) for config in mm_platform_configs ) + mixed_mm_platform_configs = tuple( + (config[0], config[1], config[2], 0, config[4]) + for config in mixed_mm_platform_configs + ) mm_configs = functools.partial( filtered_configs, @@ -201,6 +222,11 @@ def filtered_configs( configs=int8_platform_configs, ) +mixed_mm_configs = functools.partial( + filtered_configs, + configs=mixed_mm_platform_configs, +) + def mm_grid(m, n, meta): """ From 5ef081031e1dfa5902c43214a3533a40af397459 Mon Sep 17 00:00:00 2001 From: "Li-Huai (Allan) Lin" Date: Fri, 7 Jun 2024 14:23:59 -0700 Subject: [PATCH 507/706] [MPS] Include MPSGraphVenturaOps.h for complex types on macOS 12 (#127859) Fixes this on macOS 12: ``` /Users/qqaatw/Forks/pytorch/aten/src/ATen/native/mps/operations/FastFourierTransform.mm:108:60: error: use of undeclared identifier 'MPSDataTypeComplexFloat16'; did you mean 'MPSDataTypeFloat16'? (inputTensor.dataType == MPSDataTypeFloat16) ? MPSDataTypeComplexFloat16 : MPSDataTypeComplexFloat32; ^~~~~~~~~~~~~~~~~~~~~~~~~ MPSDataTypeFloat16 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/127859 Approved by: https://github.com/kulinseth --- aten/src/ATen/native/mps/operations/FastFourierTransform.mm | 1 + 1 file changed, 1 insertion(+) diff --git a/aten/src/ATen/native/mps/operations/FastFourierTransform.mm b/aten/src/ATen/native/mps/operations/FastFourierTransform.mm index 1b6e650f51d4..a9ac70110617 100644 --- a/aten/src/ATen/native/mps/operations/FastFourierTransform.mm +++ b/aten/src/ATen/native/mps/operations/FastFourierTransform.mm @@ -1,5 +1,6 @@ #include #include +#include #include #ifndef AT_PER_OPERATOR_HEADERS From ad96f991a5d6a26181a98c95adab8bc4b2dca669 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Fri, 7 Jun 2024 18:23:03 -0700 Subject: [PATCH 508/706] [pipelining] Add pipe.build_stage() (#128240) Given `PipelineStage` name to manual side. Thus adding a method under `Pipe` to create PipelineStage. Moved `PipeInfo` to utils.py to avoid circular dependency between `_IR` and `PipelineStage`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128240 Approved by: https://github.com/wconstab, https://github.com/H-Huang --- docs/source/distributed.pipelining.rst | 5 +- test/distributed/pipelining/test_schedule.py | 13 ++--- test/distributed/pipelining/test_stage.py | 15 ++---- .../pipelining/PipelineSchedule.py | 23 +-------- torch/distributed/pipelining/PipelineStage.py | 38 ++++---------- torch/distributed/pipelining/_IR.py | 50 +++++++++++-------- torch/distributed/pipelining/__init__.py | 3 +- torch/distributed/pipelining/_utils.py | 12 +++++ 8 files changed, 59 insertions(+), 100 deletions(-) diff --git a/docs/source/distributed.pipelining.rst b/docs/source/distributed.pipelining.rst index 05cefb220b90..32efef67cde8 100644 --- a/docs/source/distributed.pipelining.rst +++ b/docs/source/distributed.pipelining.rst @@ -179,8 +179,7 @@ You can also create a distributed stage runtime on a device using ``Pipe``: .. code-block:: python - from torch.distributed.pipelining import TracerPipelineStage - stage = TracerPipelineStage(pipe, stage_idx, device) + stage = pipe.build_stage(stage_idx, device, group) .. note:: The ``pipeline`` frontend uses a tracer (``torch.export``) to capture your @@ -354,8 +353,6 @@ Pipeline Stages .. autoclass:: PipelineStage -.. autoclass:: TracerPipelineStage - Pipeline Schedules ================== diff --git a/test/distributed/pipelining/test_schedule.py b/test/distributed/pipelining/test_schedule.py index ef9a63688230..22d1167908aa 100644 --- a/test/distributed/pipelining/test_schedule.py +++ b/test/distributed/pipelining/test_schedule.py @@ -19,7 +19,6 @@ ScheduleGPipe, ScheduleInterleaved1F1B, ScheduleLoopedBFS, - TracerPipelineStage, ) from torch.distributed.pipelining.PipelineSchedule import _Action, _ComputationType from torch.distributed.pipelining.PipelineStage import _PipelineStageBase @@ -98,11 +97,9 @@ def test_multi_iter(self, ScheduleClass): split_spec=split_spec, ) - stage = TracerPipelineStage( - pipe, + stage = pipe.build_stage( self.rank, self.device, - chunks, # to be cleaned ) # Attach to a schedule @@ -140,11 +137,9 @@ def test_kwargs_with_tracer(self, ScheduleClass): mb_kwargs={"y": y_mb}, ) - stage = TracerPipelineStage( - pipe, + stage = pipe.build_stage( self.rank, self.device, - chunks, # to be cleaned ) # Attach to a schedule @@ -203,11 +198,9 @@ def test_grad_with_tracer(self, ScheduleClass, ModelClass): split_spec=split_spec, ) - stage = TracerPipelineStage( - pipe, + stage = pipe.build_stage( self.rank, self.device, - chunks, # to be cleaned ) # Attach to a schedule diff --git a/test/distributed/pipelining/test_stage.py b/test/distributed/pipelining/test_stage.py index b11a6037f604..959dd2e526fa 100644 --- a/test/distributed/pipelining/test_stage.py +++ b/test/distributed/pipelining/test_stage.py @@ -8,12 +8,7 @@ import torch import torch.distributed as dist -from torch.distributed.pipelining import ( - pipeline, - PipelineStage, - ScheduleGPipe, - TracerPipelineStage, -) +from torch.distributed.pipelining import pipeline, PipelineStage, ScheduleGPipe from torch.distributed.pipelining._utils import PipeliningShapeError from torch.testing._internal.common_cuda import TEST_MULTIGPU from torch.testing._internal.common_distributed import ( @@ -91,11 +86,9 @@ def test_tracer(self, ModelClass): split_spec=split_spec, ) - stage = TracerPipelineStage( - pipe, + stage = pipe.build_stage( self.rank, self.device, - chunks, # to be cleaned ) # Attach to a schedule @@ -160,11 +153,9 @@ def test_tracer_kwargs(self, ModelClass): mb_kwargs={"y": y_mb}, ) - stage = TracerPipelineStage( - pipe, + stage = pipe.build_stage( self.rank, self.device, - chunks, ) # Attach to a schedule diff --git a/torch/distributed/pipelining/PipelineSchedule.py b/torch/distributed/pipelining/PipelineSchedule.py index 2de3c4eef85d..31632a8aaee5 100644 --- a/torch/distributed/pipelining/PipelineSchedule.py +++ b/torch/distributed/pipelining/PipelineSchedule.py @@ -4,17 +4,7 @@ from abc import ABC, abstractmethod from collections import defaultdict from enum import Enum -from typing import ( - Any, - Callable, - Dict, - List, - NamedTuple, - Optional, - Tuple, - TYPE_CHECKING, - Union, -) +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union import torch import torch.distributed as dist @@ -23,9 +13,6 @@ from .microbatch import merge_chunks, split_args_kwargs_into_chunks, TensorChunkSpec from .PipelineStage import _PipelineStageBase -if TYPE_CHECKING: - from ._IR import Pipe - __all__ = [ "PipelineScheduleSingle", @@ -84,8 +71,6 @@ def __init__( # Derived self._has_backward = self._loss_fn is not None - # To be filled by subclasses - self._pipe_info: Optional[Pipe.PipeInfo] = None # Holds the losses for each microbatch. self._internal_losses: List[torch.Tensor] = [] @@ -300,9 +285,6 @@ def __init__( kwargs_chunk_spec=kwargs_chunk_spec, output_merge_spec=output_merge_spec, ) - self._pipe_info = ( - stage.pipe_info if hasattr(stage, "pipe_info") else None # type: ignore[attr-defined] - ) # Self attributes self._stage = stage self._num_stages = stage.num_stages @@ -595,9 +577,6 @@ def __init__( kwargs_chunk_spec=kwargs_chunk_spec, output_merge_spec=output_merge_spec, ) - self._pipe_info = ( - stages[0].pipe_info if hasattr(stages[0], "pipe_info") else None # type: ignore[attr-defined] - ) # Self attributes self._stages = stages self._num_stages = stages[0].num_stages diff --git a/torch/distributed/pipelining/PipelineStage.py b/torch/distributed/pipelining/PipelineStage.py index f59e3e9dae65..c18f91e1d9d8 100644 --- a/torch/distributed/pipelining/PipelineStage.py +++ b/torch/distributed/pipelining/PipelineStage.py @@ -15,13 +15,16 @@ from ._backward import stage_backward from ._debug import map_debug_info -from ._IR import Pipe -from ._utils import flatten_args, modify_graph_op_device, validate_tensors_metadata +from ._utils import ( + flatten_args, + modify_graph_op_device, + PipeInfo, + validate_tensors_metadata, +) __all__ = [ "PipelineStage", - "TracerPipelineStage", ] logger = logging.getLogger(__name__) @@ -80,7 +83,8 @@ def _make_tensor_from_meta( class _PipelineStageBase(ABC): """ Base class for pipeline stages. - Implements common methods used by both the `TracerPipelineStage` used by the tracing frontend and `PipelineStage`. + Defines or implements common methods used by the `_PipelineStage` used by + the tracing frontend and `PipelineStage` used by manual frontend. """ def __init__( @@ -97,7 +101,6 @@ def __init__( stage_index (int): The index of this stage. num_stages (int): The total number of stages in this pipeline. device (torch.device): The device to run this stage on. - num_microbatches (int): The number of microbatches to be run with this stage. group (Optional[dist.ProcessGroup]): The process group to use for communication. If `None`, the default process group will be used. Default: `None`. @@ -641,9 +644,8 @@ def __init__( self, stage_module: torch.nn.Module, stage_index: int, - pipe_info: Pipe.PipeInfo, + pipe_info: PipeInfo, device: torch.device, - num_chunks: int, group: Optional[dist.ProcessGroup] = None, ): """ @@ -904,28 +906,6 @@ def _create_grad_recv_info( return grad_recv_info_tuple -# TODO: Update this to be returned by helper method under Pipe (kwen) -class TracerPipelineStage(_PipelineStage): - def __init__( - self, - pipe: Pipe, - stage_index: int, - device: torch.device, - num_chunks: int, # To be cleaned - group: Optional[dist.ProcessGroup] = None, - ): - """ - Create a pipeline stage given a `Pipe` (representing the whole pipeline) and a stage index. - """ - # Find my stage module - stage_module = pipe.get_stage_module(stage_index) - # Get my pipe info - pipe_info = pipe.info() - super().__init__( - stage_module, stage_index, pipe_info, device, num_chunks, group - ) - - # Manual PipelineStage functions and definition METADATA_TENSOR_LEN = 100 diff --git a/torch/distributed/pipelining/_IR.py b/torch/distributed/pipelining/_IR.py index 9c3e21ba70ea..ba6b042e94bf 100644 --- a/torch/distributed/pipelining/_IR.py +++ b/torch/distributed/pipelining/_IR.py @@ -3,7 +3,6 @@ import logging import operator from collections import defaultdict -from dataclasses import dataclass from enum import Enum from inspect import Parameter, signature, Signature from types import MethodType @@ -11,6 +10,7 @@ import torch import torch.fx as fx +from torch.distributed import ProcessGroup from torch.export import ExportedProgram from torch.export.unflatten import ( _assign_attr, @@ -20,9 +20,11 @@ ) from torch.fx.node import map_aggregate from torch.fx.passes.split_module import split_module - from ._backward import _null_coalesce_accumulate, stage_backward from ._unflatten import _outline_submodules +from ._utils import PipeInfo + +from .PipelineStage import _PipelineStage logger = logging.getLogger(__name__) @@ -485,12 +487,6 @@ def _direct_serialization_reduce(self): class Pipe(torch.nn.Module): - @dataclass - class PipeInfo: - graph: fx.Graph - num_stages: int - has_loss_and_backward: bool - def __init__( self, split_gm: fx.GraphModule, @@ -505,7 +501,6 @@ def __init__( self.num_stages: int = num_stages self.has_loss_and_backward = has_loss_and_backward self.loss_spec = loss_spec - self.pipe_info: Optional[Pipe.PipeInfo] = None for node in split_gm.graph.nodes: assert ( @@ -1044,12 +1039,6 @@ def from_tracing( ) submod0.recompile() - # Create pipe info - pipe.pipe_info = Pipe.PipeInfo( - graph=pipe.split_gm.graph, - num_stages=pipe.num_stages, - has_loss_and_backward=pipe.has_loss_and_backward, - ) return pipe def __str__(self): @@ -1058,12 +1047,31 @@ def __str__(self): def __repr__(self): return self.split_gm.__repr__() - def info(self) -> PipeInfo: - if self.pipe_info is None: - raise RuntimeError( - "Pipe info is not available. Please use the `pipeline` method to create the `Pipe` object." - ) - return self.pipe_info + def _info(self) -> PipeInfo: + return PipeInfo( + graph=self.split_gm.graph, + num_stages=self.num_stages, + has_loss_and_backward=self.has_loss_and_backward, + ) + + def build_stage( + self, + stage_index: int, + device: torch.device, + group: Optional[ProcessGroup] = None, + ) -> _PipelineStage: + """ + Create a pipeline stage given a stage index and distributed context. + """ + # Find stage module + stage_module = self.get_stage_module(stage_index) + # Detach pipe info + # Note: be careful what's included in `pipe_info`. We don't want to keep + # a reference to `Pipe` or `Pipe.split_gm` which stops python from + # recycling them. When python recycles them, other stage modules (which + # are irrelevant to current rank) can be automatically freed. + pipe_info = self._info() + return _PipelineStage(stage_module, stage_index, pipe_info, device, group) class SplitPoint(Enum): diff --git a/torch/distributed/pipelining/__init__.py b/torch/distributed/pipelining/__init__.py index d9fd8feaf6e5..69e455e41992 100644 --- a/torch/distributed/pipelining/__init__.py +++ b/torch/distributed/pipelining/__init__.py @@ -6,14 +6,13 @@ ScheduleInterleaved1F1B, ScheduleLoopedBFS, ) -from .PipelineStage import PipelineStage, TracerPipelineStage +from .PipelineStage import PipelineStage __all__ = [ "Pipe", "pipe_split", "SplitPoint", "pipeline", - "TracerPipelineStage", "PipelineStage", "Schedule1F1B", "ScheduleGPipe", diff --git a/torch/distributed/pipelining/_utils.py b/torch/distributed/pipelining/_utils.py index f4680530d29f..72e96a34e3ac 100644 --- a/torch/distributed/pipelining/_utils.py +++ b/torch/distributed/pipelining/_utils.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates import logging +from dataclasses import dataclass from typing import List, Tuple, Union import torch @@ -120,3 +121,14 @@ def validate_tensors_metadata( validate_tensor_metadata( f"{desc}: value {i}", expected_tensors[i], actual_tensors[i] ) + + +@dataclass +class PipeInfo: + """ + Captures information for a pipeline (`Pipe` object). + """ + + graph: fx.Graph + num_stages: int + has_loss_and_backward: bool From 921aa194c77f5279b15415eaa213813ddcdb3b29 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Fri, 7 Jun 2024 18:23:03 -0700 Subject: [PATCH 509/706] [pipelining] Move modify_graph_op_device to _IR.py (#128241) This part is more IR related. Thus moving from `PipelineStage` constructor to `pipe.build_stage(..., device, ...)`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128241 Approved by: https://github.com/wconstab ghstack dependencies: #128240 --- torch/distributed/pipelining/PipelineStage.py | 18 +------ torch/distributed/pipelining/_IR.py | 51 ++++++++++++++++++- torch/distributed/pipelining/_utils.py | 36 ------------- 3 files changed, 51 insertions(+), 54 deletions(-) diff --git a/torch/distributed/pipelining/PipelineStage.py b/torch/distributed/pipelining/PipelineStage.py index c18f91e1d9d8..4f8902b4320f 100644 --- a/torch/distributed/pipelining/PipelineStage.py +++ b/torch/distributed/pipelining/PipelineStage.py @@ -15,12 +15,7 @@ from ._backward import stage_backward from ._debug import map_debug_info -from ._utils import ( - flatten_args, - modify_graph_op_device, - PipeInfo, - validate_tensors_metadata, -) +from ._utils import flatten_args, PipeInfo, validate_tensors_metadata __all__ = [ @@ -686,8 +681,6 @@ def __init__( # Cast submodule to device self._move_submod_to_device() - # Move ops argument to device - self._move_ops_to_device() def _move_submod_to_device(self): # Move submodule to indicated device if possible @@ -702,15 +695,6 @@ def _move_submod_to_device(self): else: self.submod.to(self.device) - def _move_ops_to_device(self): - # Today PT2 tracer does not treat `x.device` as a symbolic device; - # instead, the device of tracing time got burned into the generated - # code. Here we provide a workaround for users to manually modify the - # "device" kwarg of operations. Such operation may include: - # `torch.ones`, `torch.zeros`, `torch.rand`, etc. - if isinstance(self.submod, torch.fx.GraphModule): - modify_graph_op_device(self.submod, self.device) - def _prepare_forward_infra(self, num_microbatches: int): """ Create send/recv infrastructures for activations (during forward) diff --git a/torch/distributed/pipelining/_IR.py b/torch/distributed/pipelining/_IR.py index ba6b042e94bf..e9aeeb06c870 100644 --- a/torch/distributed/pipelining/_IR.py +++ b/torch/distributed/pipelining/_IR.py @@ -486,6 +486,41 @@ def _direct_serialization_reduce(self): ) +def _modify_graph_op_device( + gm: torch.fx.GraphModule, + new_device: torch.device, +): + """ + Modify the device argument of all "call_function" nodes in the graph. This + is useful for moving the graph to a different device. In particular for + generator ops, like torch.ones. + """ + modified = False + for node in gm.graph.nodes: + if node.op == "call_function": + if "device" in node.kwargs and node.kwargs["device"] != new_device: + logger.debug( + f"Changing device of Node {node.name} from {node.kwargs['device']} to {new_device}" # noqa: G004 + ) + node.update_kwarg("device", new_device) + modified = True + elif node.op == "call_module": + # Recursively modify "device" in submodules + submod = gm.get_submodule(node.target) + if isinstance(submod, torch.fx.GraphModule): + _modify_graph_op_device(submod, new_device) + elif isinstance(submod, InterpreterModule): + # If unflattening has been performed, we need to access its graph module by `.graph_module` + _modify_graph_op_device(submod.graph_module, new_device) + else: + logger.warning( + f"Skipping device modification for submodule {node.target} because it is a {type(submod)}" # noqa: G004 + ) + + if modified: + gm.recompile() + + class Pipe(torch.nn.Module): def __init__( self, @@ -1061,10 +1096,24 @@ def build_stage( group: Optional[ProcessGroup] = None, ) -> _PipelineStage: """ - Create a pipeline stage given a stage index and distributed context. + Create a `PipelineStage` given a stage index and distributed context. """ # Find stage module stage_module = self.get_stage_module(stage_index) + + # Move ops argument to device + # Today PT2 tracer does not treat `x.device` as a symbolic device; + # instead, the device of tracing time got burned into the generated + # code. Here we provide a workaround for users to manually modify the + # "device" kwarg of operations. Such operation may include: + # `torch.ones`, `torch.zeros`, `torch.rand`, etc. + if isinstance(stage_module, torch.fx.GraphModule): + _modify_graph_op_device(stage_module, device) + else: + logger.warning( + f"Expected a `torch.fx.GraphModule` but got {type(stage_module)}" # noqa: G004 + ) + # Detach pipe info # Note: be careful what's included in `pipe_info`. We don't want to keep # a reference to `Pipe` or `Pipe.split_gm` which stops python from diff --git a/torch/distributed/pipelining/_utils.py b/torch/distributed/pipelining/_utils.py index 72e96a34e3ac..31caf3427424 100644 --- a/torch/distributed/pipelining/_utils.py +++ b/torch/distributed/pipelining/_utils.py @@ -5,7 +5,6 @@ import torch from torch import fx -from torch.export.unflatten import InterpreterModule logger = logging.getLogger(__name__) @@ -54,41 +53,6 @@ def extract_tensor_args(a): return flat_args -def modify_graph_op_device( - gm: torch.fx.GraphModule, - new_device: torch.device, -): - """ - Modify the device argument of all "call_function" nodes in the graph. This - is useful for moving the graph to a different device. In particular for - generator ops, like torch.ones. - """ - modified = False - for node in gm.graph.nodes: - if node.op == "call_function": - if "device" in node.kwargs and node.kwargs["device"] != new_device: - logger.debug( - f"Changing device of Node {node.name} from {node.kwargs['device']} to {new_device}" # noqa: G004 - ) - node.update_kwarg("device", new_device) - modified = True - elif node.op == "call_module": - # Recursively modify "device" in submodules - submod = gm.get_submodule(node.target) - if isinstance(submod, torch.fx.GraphModule): - modify_graph_op_device(submod, new_device) - elif isinstance(submod, InterpreterModule): - # If unflattening has been performed, we need to access its graph module by `.graph_module` - modify_graph_op_device(submod.graph_module, new_device) - else: - logger.warning( - f"Skipping device modification for submodule {node.target} because it is a {type(submod)}" # noqa: G004 - ) - - if modified: - gm.recompile() - - class PipeliningShapeError(RuntimeError): """Shape mismatch between configured and runtime values.""" From fe74bbd6f0fac14085c0f2dd6dff1eca79eb0dca Mon Sep 17 00:00:00 2001 From: Andrew Hoblitzell Date: Sat, 8 Jun 2024 01:47:57 +0000 Subject: [PATCH 510/706] init sigmoid comments (#127983) Fixes #127913 ### Description Add docstring to `torch/onnx/symbolic_opset9.py`:`sigmoid` function ### Checklist - [x] The issue that is being fixed is referred in the description - [x] Only one issue is addressed in this pull request - [x] Labels from the issue that this PR is fixing are added to this pull request - [x] No unnecessary issues are included into this pull request Pull Request resolved: https://github.com/pytorch/pytorch/pull/127983 Approved by: https://github.com/xadupre --- torch/onnx/symbolic_opset9.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py index de9b616103e7..adfa538c8f08 100644 --- a/torch/onnx/symbolic_opset9.py +++ b/torch/onnx/symbolic_opset9.py @@ -757,6 +757,16 @@ def atan2(g: jit_utils.GraphContext, self, other): @symbolic_helper.quantized_args(True, scale=1.0 / 256.0, zero_point=0) @_beartype.beartype def sigmoid(g: jit_utils.GraphContext, self): + """Converts the corresponding PyTorch function into ONNX operators. + + It is not meant to be called directly by a user. + + Args: + g (jit_utils.GraphContext): Graph context. + self (Tensor): the input tensor. + Returns: + ONNX operator + """ return g.op("Sigmoid", self) From f9508b4c1f15e9239379449b3db743038dec90d3 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Fri, 7 Jun 2024 14:53:47 -0700 Subject: [PATCH 511/706] [pipelining] Update Pipelining Docs (#128236) ---- - Bring PipelineStage/Schedule more front-and-center - provide details on how to manually construct PipelineStage - move tracer example and manual example below so the high-level flow (e2e) is closer to the top Pull Request resolved: https://github.com/pytorch/pytorch/pull/128236 Approved by: https://github.com/H-Huang ghstack dependencies: #128201, #128228 --- docs/source/distributed.pipelining.rst | 179 ++++++++++++++++--------- 1 file changed, 114 insertions(+), 65 deletions(-) diff --git a/docs/source/distributed.pipelining.rst b/docs/source/distributed.pipelining.rst index 32efef67cde8..3d860fba976d 100644 --- a/docs/source/distributed.pipelining.rst +++ b/docs/source/distributed.pipelining.rst @@ -43,15 +43,14 @@ on **general** models. It consists of two parts: a **splitting frontend** and a **distributed runtime**. The splitting frontend takes your model code as-is, splits it up into "model -partitions", and capture the data-flow relationship. The distributed runtime +partitions", and captures the data-flow relationship. The distributed runtime executes the pipeline stages on different devices in parallel, handling things like micro-batch splitting, scheduling, communication, and gradient propagation, etc. Overall, the ``pipelining`` package provides the following features: -* Splitting of model code based on simple specification. The goal is to make - parallelism work for your model with **zero model code change**. +* Splitting of model code based on simple specification. * Rich support for pipeline schedules, including GPipe, 1F1B, Interleaved 1F1B and Looped BFS, and providing the infrastruture for writing customized schedules. @@ -63,18 +62,122 @@ Overall, the ``pipelining`` package provides the following features: application on the Llama model. -Step 1: choose the frontend that fits your need -*********************************************** +Step 1: build ``PipelineStage`` objects for Execution +******************************************************** -The ``pipelining`` package provides two frontends for two different use cases. -You can make your choice based on whether you have: +Before we can use a PipelineSchedule, we need to create PipelineStage objects that wrap the part of the model running in that stage. The `PipelineStage` is responsible for allocating communication buffers and creating send/recv ops to communicate with its peers. It manages intermediate buffers e.g. for the outputs of forward that have not been consumed yet, and it provides a utility for running the backwards for the stage model. -* a full model, or -* module constructor for each stage. +A `PipelineStage` needs to know the input and output shapes for the stage model, so that it can correctly allocate communication buffers. The shapes must be static, e.g. at runtime the shapes can not change from step to step. A class `PipeliningShapeError` will be raised if runtime shapes do not match the expected shapes. When composing with other paralleisms or applying mixed precision, these techniques must be taken into account so the `PipelineStage` knows the correct shape (and dtype) for the output of the stage module at runtime. +Users may construct a `PipelineStage` instance directly, by passing in an `nn.Module` representing the portion of the model that should run on the stage. This may require changes to the original model code. See the example below "Preparing a model for pipeline splitting". -Frontend 1: the ``pipeline`` API -- if you have a full model -============================================================ +Alternatively, the tracing frontend can use graph-partitioning to construct a `GraphModule` that represents the desired subset of the model automatically. This technique requires the model is traceable with torch.Export in non-strict mode. Composability of the resulting `GraphModule` with other parallelism techniques and torch.compile is experimental, and may require some workarounds. Usage of this frontend may be more appealing if the user cannot easily change the model code. See "Splitting a Model with the ``pipeline`` tracing frontend" for more information. + + +Step 2: use ``PipelineSchedule`` for execution +********************************************** + +We can now attach the ``PipelineStage`` to a pipeline schedule, and run the +schedule with input data. Here is a GPipe example: + +.. code-block:: python + + from torch.distributed.pipelining import ScheduleGPipe + + # Create a schedule + schedule = ScheduleGPipe(stage, n_microbatches) + + # Input data (whole batch) + x = torch.randn(batch_size, in_dim, device=device) + + # Run the pipeline with input `x` + # `x` will be divided into microbatches automatically + if rank == 0: + schedule.step(x) + else: + output = schedule.step() + +Note that the above code needs to be launched for each worker, thus we use a +launcher service to launch multiple processes: + +.. code-block:: bash + + torchrun --nproc_per_node=2 example.py + + +Preparing a model for pipeline splitting +======================================== + +To directly construct a `PipelineStage`, the user is responsible for providing a single nn.Module instance that owns the relevant nn.Parameters and nn.Buffers, and defines a .forward() method that executes the operations relevant for that stage. For example, a condensed version of the Transformer class defined in Torchtitan shows a pattern of building an easily partitionable model. + +.. code-block:: python + + class Transformer(nn.Module): + def __init__(self, model_args: ModelArgs): + super().__init__() + + self.tok_embeddings = nn.Embedding(...) + + # Using a ModuleDict lets us delete layers witout affecting names, + # ensuring checkpoints will correctly save and load. + self.layers = torch.nn.ModuleDict() + for layer_id in range(model_args.n_layers): + self.layers[str(layer_id)] = TransformerBlock(...) + + self.output = nn.Linear(...) + + def forward(self, tokens: torch.Tensor): + # Handling layers being 'None' at runtime enables easy pipeline splitting + h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens + + for layer in self.layers.values(): + h = layer(h, self.freqs_cis) + + h = self.norm(h) if self.norm else h + output = self.output(h).float() if self.output else h + return output + +A model defined in this manner can be easily configured per stage by first initializing the whole model (using meta-device to avoid OOM errors), deleting undesired layers for that stage, and then creating a PipelineStage that wraps the model. For example: + +.. code-block:: python + + with torch.device("meta"): + assert num_stages == 2, "This is a simple 2-stage example" + + # we construct the entire model, then delete the parts we do not need for this stage + # in practice, this can be done using a helper function that automatically divides up layers across stages. + model = Transformer() + + if stage_index == 0: + # prepare the first stage model + del model.layers["1"] + model.norm = None + model.output = None + + elif stage_index == 1: + # prepare the second stage model + model.tok_embeddings = None + del model.layers["0"] + + from torch.distributed.pipelining import PipelineStage + stage = PipelineStage( + model, + stage_index, + num_stages, + device, + input_args=example_input_microbatch, + ) + + +The ``PipelineStage`` requires an example argument `input_args` representing the runtime input to the stage, which would be one microbatch worth of input data. +This argument is passed through the forward method of the stage module to determine the +input and output shapes required for communication. + +When composing with other Data or Model parallelism techniques, `output_args` may also be required, if the output shape/dtype of the model chunk will be affected. + + +Splitting a Model with the ``pipeline`` tracing frontend +======================================================== If you have a full model and do not want to spend time on modifying it into a sequence of "model partitions", the ``pipeline`` API is here to help. @@ -186,60 +289,6 @@ You can also create a distributed stage runtime on a device using ``Pipe``: model into a single graph. If your model is not full-graph'able, you can use our manual frontend below. -Frontend 2: ``PipelineStage`` -- if you already have module for each stage -================================================================================ - -If you already have the module for each stage, you can skip the pipeline split -step above and directly connect to our runtime offering: ``PipelineStage``. -The ``PipelineStage`` wraps your stage module given a distributed context, -i.e. a ``ProcessGroup`` along the pipeline dimension. - -.. code-block:: python - - from torch.distributed.pipelining import PipelineStage - stage = PipelineStage( - stage_mod, - stage_idx, - num_stages, - device, - input_args=x.chunk(num_microbatches)[0], - ) - -The ``PipelineStage`` requires an example argument (similar to ``example_args`` used in ``pipeline``). -This argument is passed through the forward method of the stage module to determine the -input and output shapes required for communication. - - -Step 2: use ``PipelineSchedule`` for execution -********************************************** - -We can now attach the ``PipelineStage`` to a pipeline schedule, and run the -schedule with input data. Here is a GPipe example: - -.. code-block:: python - - from torch.distributed.pipelining import ScheduleGPipe - - # Create a schedule - schedule = ScheduleGPipe(stage, n_microbatches) - - # Input data (whole batch) - x = torch.randn(batch_size, in_dim, device=device) - - # Run the pipeline with input `x` - # `x` will be divided into microbatches automatically - if rank == 0: - schedule.step(x) - else: - output = schedule.step() - -Note that the above code needs to be launched for each worker, thus we use a -launcher service to launch multiple processes: - -.. code-block:: bash - - torchrun --nproc_per_node=2 example.py - Hugging Face Examples ********************* From 0ef522956943e3f0398b6d2bf9ee1ac0a5a3130d Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sat, 8 Jun 2024 02:34:06 +0000 Subject: [PATCH 512/706] Revert "Change lerp decomp to use aten.as_strided_copy instead of prims.copy_strided (#128030)" This reverts commit fdf1666b20f63e4acf01798f009e478d997a7f7f. Reverted https://github.com/pytorch/pytorch/pull/128030 on behalf of https://github.com/nWEIdia due to breaking cuda12.1 test_cuda, see HUD https://hud.pytorch.org/hud/pytorch/pytorch/main/1?per_page=50&name_filter=inductor ([comment](https://github.com/pytorch/pytorch/pull/128030#issuecomment-2155764546)) --- torch/_refs/__init__.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index 4e00a125434f..68675c751736 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -5014,11 +5014,7 @@ def lerp(start: Tensor, end: Tensor, weight: Union[Tensor, NumberType]): # make sure the decomposition output's stride is same as non-decomposition path. stride = utils.compute_elementwise_output_strides(*_maybe_broadcast(*inputs)) if output.stride() != stride: - output = torch.ops.aten.as_strided_copy( - output, - output.size(), - stride, - ) + output = prims.copy_strided(output, stride) return handle_noncontiguous_outputs(inputs, output) From 622060294386372d6a4e6330904403017686bcbb Mon Sep 17 00:00:00 2001 From: Yidi Wu Date: Fri, 7 Jun 2024 16:57:05 -0700 Subject: [PATCH 513/706] [torchbind] support query schema of methods (#128267) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128267 Approved by: https://github.com/angelayi --- test/export/test_torchbind.py | 13 ++++++++++ torch/_library/fake_class_registry.py | 35 ++++++++++++++++++++------- 2 files changed, 39 insertions(+), 9 deletions(-) diff --git a/test/export/test_torchbind.py b/test/export/test_torchbind.py index 2ff09598ace8..42c87bf4c10e 100644 --- a/test/export/test_torchbind.py +++ b/test/export/test_torchbind.py @@ -188,6 +188,19 @@ def forward(self, obj_attr, x, n): return (add,)""", ) + def test_method_schema(self): + tq = _empty_tensor_queue() + fake_mode = torch._subclasses.fake_tensor.FakeTensorMode() + fake_obj = torch._library.fake_class_registry.to_fake_obj(fake_mode, tq) + self.assertExpectedInline( + str(fake_obj.push.schema), + """push(__torch__.torch.classes._TorchScriptTesting._TensorQueue _0, Tensor _1) -> NoneType _0""", + ) + self.assertExpectedInline( + str(fake_obj.pop.schema), + """pop(__torch__.torch.classes._TorchScriptTesting._TensorQueue _0) -> Tensor _0""", + ) + @parametrize("pre_dispatch", [True, False]) def test_attribute(self, pre_dispatch): class MyModule(torch.nn.Module): diff --git a/torch/_library/fake_class_registry.py b/torch/_library/fake_class_registry.py index d77989cd829b..aaa57d79e283 100644 --- a/torch/_library/fake_class_registry.py +++ b/torch/_library/fake_class_registry.py @@ -16,6 +16,23 @@ def __init__(self, wrapped_obj: Any, script_class_name: str): self.script_class_name = script_class_name +class FakeScriptMethod: + def __init__( + self, + self_fake_obj: FakeScriptObject, + method_name: str, + schema: Optional[torch.FunctionSchema], + ): + self.self_fake_obj = self_fake_obj + self.method_name = method_name + self.schema = schema + + def __call__(self, *args, **kwargs): + from torch._higher_order_ops.torchbind import call_torchbind + + return call_torchbind(self.self_fake_obj, self.method_name, *args, **kwargs) + + class HasStaticMethodFromReal(Protocol): @classmethod def from_real(cls, real_obj: torch.ScriptObject): @@ -95,25 +112,25 @@ def to_fake_obj(fake_mode, x: torch.ScriptObject) -> FakeScriptObject: fake_x = _find_fake_class_for_script_object(x).__obj_unflatten__(fake_flattened) - def _call_torchbind(method_name): - from torch._higher_order_ops.torchbind import call_torchbind - - def wrapped(self_, *args, **kwargs): - return call_torchbind(self_, method_name, *args, **kwargs) - - return wrapped - fake_x_wrapped = FakeScriptObject(fake_x, x._type().qualified_name()) # type: ignore[attr-defined] + for name in x._method_names(): # type: ignore[attr-defined] attr = getattr(fake_x, name, None) if attr: if not callable(attr): raise RuntimeError(f"Expect {name} to be a callable but got {attr}.") + real_attr = getattr(x, name) # type: ignore[attr-defined] + + # real attr sometimes is not torch.ScriptMethod thus doesn't have schema e.g. __init___ or __eq__ + method_schema: Optional[torch.FunctionSchema] = None + if isinstance(real_attr, torch.ScriptMethod): + method_schema = real_attr.schema # type: ignore[attr-defined] + setattr( fake_x_wrapped, name, - _call_torchbind(name).__get__(fake_x_wrapped), + FakeScriptMethod(fake_x_wrapped, name, method_schema), ) else: log.warning("fake object of %s doesn't implement method %s.", x, name) From 6e5c2a1a3bc9507ec459f3e01f5e492d8bef122a Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Fri, 7 Jun 2024 15:03:00 -0700 Subject: [PATCH 514/706] [inductor] Add missing files to torch_key (#128230) Previosly all subdirs (like torch.inductor.codegen) were not hashed. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128230 Approved by: https://github.com/oulgen --- torch/_inductor/codecache.py | 34 +++++++++++++++++++++++++--------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 6ef07ed90692..6077747d95da 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -559,20 +559,28 @@ def get_str(obj) -> str: return "\n".join(lines) -def get_code_hash(roots): - contents: Dict[str, bytes] = {torch.__version__: b""} - for lib in pkgutil.iter_modules(roots): +def build_code_hash(roots, prefix, hasher): + for lib in sorted(pkgutil.iter_modules(roots, prefix), key=lambda x: x.name): spec = lib.module_finder.find_spec(lib.name, None) assert spec is not None module = spec.origin assert module is not None with open(module, "rb") as f: - contents[spec.name] = f.read() + hasher.update(spec.name.encode("utf-8")) + hasher.update(f.read()) + if lib.ispkg: + # need to also hash submodules + build_code_hash(spec.submodule_search_locations, f"{spec.name}.", hasher) + + +def get_code_hash(roots, extra_files=()): hasher = hashlib.sha256() - # Iterate over dict in sorted order since iter_modules may not be deterministic - for name, value in sorted(contents.items()): - hasher.update(name.encode("utf-8")) - hasher.update(value) + hasher.update(torch.__version__.encode("utf-8")) + build_code_hash(roots, "", hasher) + for path in extra_files: + if os.path.exists(path): + with open(path, "rb") as f: + hasher.update(f.read()) return hasher.digest() @@ -583,7 +591,15 @@ def torch_key(): """ if not config.is_fbcode(): inductor_root = os.path.dirname(__file__) - return get_code_hash([inductor_root]) + extra_files = ( + "codegen/aoti_runtime/interface.cpp", + "codegen/aoti_runtime/implementation.cpp", + "codegen/cpp_prefix.h", + "script.ld", + ) + return get_code_hash( + [inductor_root], [os.path.join(inductor_root, x) for x in extra_files] + ) from libfb.py import parutil From 1d84c7e1002b5bf2c1e2970ac35b924dcd14116b Mon Sep 17 00:00:00 2001 From: Iris Z <31293777+wz337@users.noreply.github.com> Date: Sat, 8 Jun 2024 04:28:56 +0000 Subject: [PATCH 515/706] [DeviceMesh] Update get_group and add get_all_groups (#128097) Fixes #121984 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128097 Approved by: https://github.com/wconstab, https://github.com/wanchaol --- test/distributed/_spmd/test_tracing.py | 2 +- test/distributed/test_device_mesh.py | 21 +++---- .../_composable/fsdp/_fsdp_common.py | 8 +-- torch/distributed/device_mesh.py | 58 ++++++++++--------- torch/distributed/fsdp/_init_utils.py | 3 +- torch/distributed/tensor/parallel/fsdp.py | 4 +- 6 files changed, 46 insertions(+), 50 deletions(-) diff --git a/test/distributed/_spmd/test_tracing.py b/test/distributed/_spmd/test_tracing.py index 20ad2a6e06f9..77445aac7419 100644 --- a/test/distributed/_spmd/test_tracing.py +++ b/test/distributed/_spmd/test_tracing.py @@ -46,7 +46,7 @@ def _test_tracing_all_reduce_nd(self, mesh_tensor): local_tensor = torch.ones(3, 3, device=self.device_type) * self.rank # check all dim groups - dim_to_subgroups = mesh.get_group() + dim_to_subgroups = mesh.get_all_groups() for dim, dim_group in enumerate(dim_to_subgroups): dim_group_size = get_world_size(dim_group) global_ranks = [ diff --git a/test/distributed/test_device_mesh.py b/test/distributed/test_device_mesh.py index 26a9bd1e0b3d..22d8b0fbbdce 100644 --- a/test/distributed/test_device_mesh.py +++ b/test/distributed/test_device_mesh.py @@ -73,7 +73,7 @@ def test_assert_invalid_mesh_tensor(self): device_mesh = DeviceMesh(self.device_type, mesh) @with_comms - def test_get_group(self): + def test_get_group_and_get_all_groups(self): mesh_shape = (2, self.world_size // 2) mesh_2d = init_device_mesh( self.device_type, mesh_shape, mesh_dim_names=("dp", "tp") @@ -82,16 +82,17 @@ def test_get_group(self): tp_mesh = mesh_2d["tp"] dp_mesh = mesh_2d["dp"] - self.assertEqual(len(mesh_2d.get_group()), 2) - self.assertEqual(mesh_2d.get_group()[0], mesh_2d.get_group("dp")) - self.assertEqual(mesh_2d.get_group()[1], mesh_2d.get_group("tp")) - self.assertEqual(mesh_2d.get_group(0), mesh_2d.get_group("dp")) self.assertEqual(mesh_2d.get_group(1), mesh_2d.get_group("tp")) self.assertEqual(mesh_2d.get_group("dp"), dp_mesh.get_group()) self.assertEqual(mesh_2d.get_group("tp"), tp_mesh.get_group()) + groups = mesh_2d.get_all_groups() + self.assertEqual(len(groups), 2) + self.assertTrue(tp_mesh.get_group() in groups) + self.assertTrue(dp_mesh.get_group() in groups) + @with_comms def test_get_local_rank_raises_exception(self): mesh_shape = (2, self.world_size // 2) @@ -126,7 +127,7 @@ def test_device_mesh_2d(self): mesh = DeviceMesh(self.device_type, mesh_tensor) # check all dim groups - dim_to_subgroups = mesh.get_group() + dim_to_subgroups = mesh.get_all_groups() expected_ranks_by_dim = [[[0, 2], [1, 3]], [[0, 1], [2, 3]]] for dim, dim_group in enumerate(dim_to_subgroups): @@ -191,7 +192,7 @@ def test_from_group_with_invalid_mesh(self): DeviceMesh.from_group(global_pg, "cuda", invalid_mesh) device_mesh = init_device_mesh(self.device_type, (2, 2)) - groups = device_mesh.get_group() + groups = device_mesh.get_all_groups() invalid_mesh = (0, 1, 2, 3) # 1D mesh when we need 2D regex = r"Expects mesh with ndim equal to number of ProcessGroups but got mesh \[0, 1, 2, 3\] and 2 ProcessGroups" with self.assertRaisesRegex(ValueError, regex): @@ -230,7 +231,7 @@ def test_device_mesh_nd(self): mesh = DeviceMesh(self.device_type, mesh_tensor) # check all dim groups - dim_to_subgroups = mesh.get_group() + dim_to_subgroups = mesh.get_all_groups() for dim, dim_group in enumerate(dim_to_subgroups): self.assertTrue(dim < mesh_tensor.ndim) @@ -803,7 +804,7 @@ def test_broadcast_nd(self): local_tensor = torch.ones(3, 3, device=self.device_type) * self.rank # check all dim groups - dim_to_subgroups = mesh.get_group() + dim_to_subgroups = mesh.get_all_groups() for dim, dim_group in enumerate(dim_to_subgroups): dim_group_size = get_world_size(dim_group) global_ranks = [ @@ -820,7 +821,7 @@ def test_scatter_nd(self): mesh = DeviceMesh(self.device_type, mesh_tensor) # check all dim groups - dim_to_subgroups = mesh.get_group() + dim_to_subgroups = mesh.get_all_groups() for dim, dim_group in enumerate(dim_to_subgroups): dim_group_size = get_world_size(dim_group) global_ranks = [ diff --git a/torch/distributed/_composable/fsdp/_fsdp_common.py b/torch/distributed/_composable/fsdp/_fsdp_common.py index f372fcd2e073..3cb06174703a 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_common.py +++ b/torch/distributed/_composable/fsdp/_fsdp_common.py @@ -34,9 +34,7 @@ def __post_init__(self): if self.shard_mesh_dim is None: raise AssertionError("Expects non-None shard_mesh_dim") self.shard_mesh_size: int = self.mesh.size(self.shard_mesh_dim) - self.shard_process_group = cast( - dist.ProcessGroup, self.mesh.get_group(self.shard_mesh_dim) - ) + self.shard_process_group = self.mesh.get_group(self.shard_mesh_dim) self.shard_mesh_rank: int = self.shard_process_group.rank() @@ -47,9 +45,7 @@ def __post_init__(self): if self.replicate_mesh_dim is None: raise AssertionError("Expects non-None replicate_mesh_dim") self.replicate_mesh_size: int = self.mesh.size(self.replicate_mesh_dim) - self.replicate_process_group = cast( - dist.ProcessGroup, self.mesh.get_group(self.replicate_mesh_dim) - ) + self.replicate_process_group = self.mesh.get_group(self.replicate_mesh_dim) self.replicate_mesh_rank: int = self.replicate_process_group.rank() diff --git a/torch/distributed/device_mesh.py b/torch/distributed/device_mesh.py index 5bbe4e113464..f25a5b91da4b 100644 --- a/torch/distributed/device_mesh.py +++ b/torch/distributed/device_mesh.py @@ -470,48 +470,50 @@ def __getitem__( submesh = _mesh_resources.create_child_mesh(self, mesh_dim_names) return submesh - def get_group( - self, mesh_dim: Optional[Union[int, str]] = None - ) -> Union[ProcessGroup, List[ProcessGroup]]: + def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> ProcessGroup: """ - Returns a list of ProcessGroups corresponding to the mesh dimensions, or - returns a single ProcessGroup if mesh_dim is specified or the given mesh has - only one mesh dimension. + Returns the single ProcessGroup specified by mesh_dim, or, if mesh_dim is not specified and the + DeviceMesh is 1-dimensional, returns the only ProcessGroup in the mesh. Args: mesh_dim (str/int, optional): it can be the name of the mesh dimension or the index of the mesh dimension. Default is None. Returns: - A list of :class:`ProcessGroup` object when `mesh_dim` is not specified for - a DeviceMesh with more than 1 dimension; otherwise, returns a single - :class:`ProcessGroup` object. + A :class:`ProcessGroup` object. """ if not hasattr(self, "_dim_group_infos"): raise RuntimeError("DeviceMesh process groups not initialized!") - if self.mesh.ndim == 1: - return not_none( - _find_pg_by_ranks_and_tag(*self._dim_group_infos[0][:2]) + if self.mesh.ndim > 1 and mesh_dim is None: + raise RuntimeError( + f"Found the DeviceMesh have {self.mesh.ndim} dimensions", + "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.", + "If you want to get the list of all the ProcessGroups in the DeviceMesh," + "please use `get_all_groups()` instead.", ) - if mesh_dim is not None: - if isinstance(mesh_dim, str): - mesh_dim = _mesh_resources.get_mesh_dim_by_name(self, mesh_dim) - return not_none( - _find_pg_by_ranks_and_tag(*self._dim_group_infos[mesh_dim][:2]) - ) + if self.mesh.ndim == 1 and mesh_dim is None: + mesh_dim = 0 else: - dim_groups = [] - for ith_dim in range(self.mesh.ndim): - dim_groups.append( - not_none( - _find_pg_by_ranks_and_tag( - *self._dim_group_infos[ith_dim][:2] - ) - ) - ) - return dim_groups + mesh_dim = ( + _mesh_resources.get_mesh_dim_by_name(self, mesh_dim) + if isinstance(mesh_dim, str) + else mesh_dim + ) + + return not_none( + _find_pg_by_ranks_and_tag(*self._dim_group_infos[mesh_dim][:2]) # type: ignore[index] + ) + + def get_all_groups(self) -> List[ProcessGroup]: + """ + Returns a list of ProcessGroups for all mesh dimensions. + + Returns: + A list of :class:`ProcessGroup` object. + """ + return [self.get_group(i) for i in range(self.mesh.ndim)] @staticmethod def from_group( diff --git a/torch/distributed/fsdp/_init_utils.py b/torch/distributed/fsdp/_init_utils.py index 2364b1871206..64685013d9a4 100644 --- a/torch/distributed/fsdp/_init_utils.py +++ b/torch/distributed/fsdp/_init_utils.py @@ -166,8 +166,7 @@ def _init_process_group_state_for_hybrid_shard( state.process_group = device_mesh.get_group(mesh_dim=1) else: raise ValueError( - "Expected device_mesh to have ndim=2 " - f"but got {len(device_mesh.get_group())}" + f"Expected device_mesh to have ndim=2 but got {device_mesh.ndim}" ) elif process_group is None: default_group = _get_default_group() diff --git a/torch/distributed/tensor/parallel/fsdp.py b/torch/distributed/tensor/parallel/fsdp.py index d7eae93a7258..888631e67777 100644 --- a/torch/distributed/tensor/parallel/fsdp.py +++ b/torch/distributed/tensor/parallel/fsdp.py @@ -112,9 +112,7 @@ def _create_sharded_tensor_md_from_dt( def _get_dt_pg(dt: DTensor) -> c10d.ProcessGroup: mesh = dt.device_mesh assert mesh.ndim == 1, "Only 1D DeviceMeshes currently handled" - dim_groups = mesh.get_group() - assert isinstance(dim_groups, list) - return dim_groups[0] + return mesh.get_group() def _rewrite_spec_if_needed( From 8a45cf4c64c13859d36ea5f8d16f4e1145a1f231 Mon Sep 17 00:00:00 2001 From: "Wu, Chunyuan" Date: Fri, 7 Jun 2024 16:42:19 +0000 Subject: [PATCH 516/706] [AOTI] align data_size of the constants (#127610) https://github.com/pytorch/pytorch/pull/124272 set the alignment to the `consts_o` but if there're `data_size` of tensor in the `consts_o` non divisible by the alignment, the following tensors are not aligned anymore, resulting in poor performance on CPU. We align the `data_size` as well in this PR and pad the serialized bytes. Since `size` of the tensor instead of the `data_size` is used when creating tensor from the serialized bytes ([link](https://github.com/pytorch/pytorch/blob/f4d7cdc5e63c786b1f6588eafa53bbc6d33c3826/torch/csrc/inductor/aoti_runtime/model.h#L236-L259)), there won't be correctness issue. `data_size` is only used to record the [bytes_read](https://github.com/pytorch/pytorch/blob/f4d7cdc5e63c786b1f6588eafa53bbc6d33c3826/torch/csrc/inductor/aoti_runtime/model.h#L217). This PR will improve the performance on CPU for 4 models in HF, 7 models in TIMM and 1 model in Torchbench. For the unit test, I add a bias value the original `data_size` of which is not divisible by the alignment to test the correctness: ``` constants_info_[0].dtype = static_cast(at::kFloat); constants_info_[0].data_size = 64; # was 40 before this PR constants_info_[0].shape = {10}; constants_info_[1].dtype = static_cast(at::kFloat); ...... ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/127610 Approved by: https://github.com/jgong5, https://github.com/desertfire --- test/inductor/test_aot_inductor.py | 3 +- torch/_inductor/codecache.py | 44 +++++++++++++++------- torch/_inductor/codegen/cpp_wrapper_cpu.py | 30 +++++++++------ torch/_inductor/codegen/memory_planning.py | 32 +--------------- torch/_inductor/utils.py | 35 ++++++++++++++++- 5 files changed, 85 insertions(+), 59 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 15e10140d926..fb15fa01d318 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -422,9 +422,10 @@ class LinearModel(torch.nn.Module): def __init__(self, device): super().__init__() self.weight = torch.randn(10, 10, device=device).to(dtype) + self.bias = torch.randn(10, device=device).to(dtype) def forward(self, y): - return torch.nn.functional.linear(y, self.weight) + return torch.nn.functional.linear(y, self.weight, self.bias) example_inputs = (torch.randn(10, 10, device=self.device).to(dtype),) diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 6077747d95da..ca6cacaa213b 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -56,7 +56,7 @@ _reload_python_module_in_subproc, ) from torch._inductor.runtime.runtime_utils import cache_dir -from torch._inductor.utils import clear_on_fresh_inductor_cache, is_linux +from torch._inductor.utils import ALIGN_BYTES, clear_on_fresh_inductor_cache, is_linux from torch._logging import trace_structured from torch._subclasses.fake_tensor import ( @@ -2059,10 +2059,14 @@ def _compile_consts_linux(consts: bytes) -> str: # as read-only (i.e. .lrodata) which could accomodate larger size of data # to be linked. rename_data = " .data=.lrodata,alloc,load,readonly,data,contents" + + assert ( + ALIGN_BYTES & (ALIGN_BYTES - 1) + ) == 0 and ALIGN_BYTES >= 64, "must be power of 2 and >= 64" cmd = ( f"{objcopy_command} --rename-section" f"{rename_data}" - " --set-section-alignment .data=64" # following the gAlignment of CPU in c10/core/alignment.h + f" --set-section-alignment .data={ALIGN_BYTES}" # following the gAlignment of CPU in c10/core/alignment.h f" {consts_o} {consts_o}" ) log.debug("aot constant rename section command: %s", cmd) @@ -2186,7 +2190,14 @@ def _compile_consts_darwin(consts: bytes) -> str: else: run_command_and_check(compile_cmd) - def _to_bytes(t: torch.Tensor) -> bytes: + def _to_bytes(t: torch.Tensor, all_cuda: bool) -> bytes: + def _pad_to_alignment(raw_bytes): + padded_bytes = raw_bytes.ljust( + (len(raw_bytes) + ALIGN_BYTES - 1) // ALIGN_BYTES * ALIGN_BYTES, + b"\x00", + ) + return padded_bytes + # This serializes the tensor's untyped_storage to bytes by accessing # the raw data of the underlying structure. import ctypes @@ -2195,22 +2206,27 @@ def _to_bytes(t: torch.Tensor) -> bytes: return b"" if t.is_mkldnn: - raw_array = ctypes.cast( - torch.ops.mkldnn.data_ptr(t), - ctypes.POINTER(ctypes.c_ubyte * torch.ops.mkldnn._nbytes(t)), - ) - return bytes(raw_array.contents) + data_ptr = torch.ops.mkldnn.data_ptr(t) + nbytes = torch.ops.mkldnn._nbytes(t) + else: + t_cpu = t.untyped_storage().cpu() + data_ptr = t_cpu.data_ptr() + nbytes = t_cpu.nbytes() - t_cpu = t.untyped_storage().cpu() raw_array = ctypes.cast( - t_cpu.data_ptr(), - ctypes.POINTER(ctypes.c_ubyte * t_cpu.nbytes()), + data_ptr, + ctypes.POINTER(ctypes.c_ubyte * nbytes), ) + raw_bytes = bytes(raw_array.contents) + return raw_bytes if all_cuda else _pad_to_alignment(raw_bytes) - return bytes(raw_array.contents) - + all_cuda = all( + graph.get_original_value_of_constant(name).is_cuda + for name in graph.constants.keys() + if name not in graph.folded_constants + ) serialized_weights = b"".join( - _to_bytes(graph.get_original_value_of_constant(name)) + _to_bytes(graph.get_original_value_of_constant(name), all_cuda) for name in graph.constants.keys() if name not in graph.folded_constants ) diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 7d38b6ed1acb..3dc397a84b8d 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -15,7 +15,7 @@ from torch.fx.experimental.symbolic_shapes import ConvertIntKey, DivideByKey from .. import config, ir from ..codecache import CudaKernelParamCache -from ..utils import cache_on_self, sympy_product +from ..utils import _align, ALIGN_BYTES, cache_on_self, sympy_product from ..virtualized import V from .aoti_hipify_utils import maybe_hipify_code_wrapper from .common import IndentedBuffer @@ -239,8 +239,6 @@ class RAIIPyObject { """ ) - from .memory_planning import ALIGN_BYTES - # Round up to the nearest multiple of ALIGN_BYTES # ALIGN_BYTES must be a power of 2 self.header.splice( @@ -721,6 +719,11 @@ def codegen_model_constructor(self): ), f"input {name=} cannot be symbolic" self.write_input_output_info("inputs_info_", idx, name) + all_cuda = all( + V.graph.get_original_value_of_constant(name).is_cuda + for name in V.graph.constants.keys() + if name not in V.graph.folded_constants + ) for idx, name in enumerate(V.graph.constants.keys()): tensor = V.graph.get_original_value_of_constant(name) assert isinstance(tensor, torch.Tensor) @@ -731,14 +734,19 @@ def codegen_model_constructor(self): self.prefix.writeline( f"constants_info_[{idx}].offset = {tensor.storage_offset()};" ) - if tensor.is_mkldnn: - self.prefix.writeline( - f"constants_info_[{idx}].data_size = {torch.ops.mkldnn._nbytes(tensor)};" - ) - else: - self.prefix.writeline( - f"constants_info_[{idx}].data_size = {tensor.untyped_storage().nbytes()};" - ) + + # If constants to serialize contain cpu tensors, we always align data_size it to 64. + # When loading the constants, the valid data will depends on the size + # not the data_size so there won't be correctness issue. + data_size = ( + torch.ops.mkldnn._nbytes(tensor) + if tensor.is_mkldnn + else tensor.untyped_storage().nbytes() + ) + self.prefix.writeline( + f"constants_info_[{idx}].data_size = {data_size if all_cuda else _align(data_size)};" + ) + from_folded = "true" if name in V.graph.folded_constants else "false" self.prefix.writeline( f"constants_info_[{idx}].from_folded = {from_folded};" diff --git a/torch/_inductor/codegen/memory_planning.py b/torch/_inductor/codegen/memory_planning.py index 2aade2a297df..3489a61f2d86 100644 --- a/torch/_inductor/codegen/memory_planning.py +++ b/torch/_inductor/codegen/memory_planning.py @@ -10,7 +10,7 @@ import torch from .. import config, ir -from ..utils import cache_on_self, CachedMethod, IndentedBuffer +from ..utils import _align, align, cache_on_self, CachedMethod, IndentedBuffer from ..virtualized import V from .wrapper import ( @@ -22,36 +22,6 @@ ) -ALIGN_BYTES = 64 -assert (ALIGN_BYTES & (ALIGN_BYTES - 1)) == 0 and ALIGN_BYTES >= 8, "must be power of 2" - - -def _align(nbytes): - """Round up to the nearest multiple of ALIGN_BYTES""" - return (nbytes + ALIGN_BYTES - 1) & -ALIGN_BYTES - - -def _is_aligned(v: sympy.Expr): - """v can be statically proven to be a multiple of ALIGN_BYTES""" - if isinstance(v, (sympy.Add, sympy.Max)): - return all(map(_is_aligned, v.args)) - return isinstance(v, align) or sympy.gcd(v, ALIGN_BYTES) == ALIGN_BYTES - - -class align(sympy.Function): - """Symbolically round up to the nearest multiple of ALIGN_BYTES""" - - nargs = (1,) - is_integer = True - - @classmethod - def eval(cls, value): - if isinstance(value, (int, sympy.Integer)): - return _align(int(value)) - if _is_aligned(value): - return value - - @dataclasses.dataclass class LiveRange: """ diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 0915a8330c34..d19ef0cd3004 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -63,7 +63,36 @@ _T = TypeVar("_T") VarRanges = Dict[sympy.Expr, sympy.Expr] -ALIGNMENT = 16 +GPU_ALIGN_BYTES = 16 + +ALIGN_BYTES = 64 +assert (ALIGN_BYTES & (ALIGN_BYTES - 1)) == 0 and ALIGN_BYTES >= 8, "must be power of 2" + + +def _align(nbytes): + """Round up to the nearest multiple of ALIGN_BYTES""" + return (nbytes + ALIGN_BYTES - 1) & -ALIGN_BYTES + + +def _is_aligned(v: sympy.Expr): + """v can be statically proven to be a multiple of ALIGN_BYTES""" + if isinstance(v, (sympy.Add, sympy.Max)): + return all(map(_is_aligned, v.args)) + return isinstance(v, align) or sympy.gcd(v, ALIGN_BYTES) == ALIGN_BYTES + + +class align(sympy.Function): + """Symbolically round up to the nearest multiple of ALIGN_BYTES""" + + nargs = (1,) + is_integer = True + + @classmethod + def eval(cls, value): + if isinstance(value, (int, sympy.Integer)): + return _align(int(value)) + if _is_aligned(value): + return value def do_bench_using_profiling(fn: Callable[[], Any], warmup=25, rep=100) -> float: @@ -1548,7 +1577,9 @@ def tensor_is_aligned(tensor: torch.Tensor): # but symbolic storage_offsets are. For consistency, we suppress guard creation # upon performing this check: that ensures that we don't add recompiles when we # add this logic. - return (tensor.storage_offset() * get_dtype_size(tensor.dtype)) % ALIGNMENT == 0 + return ( + tensor.storage_offset() * get_dtype_size(tensor.dtype) + ) % GPU_ALIGN_BYTES == 0 def should_assume_input_aligned(example_input: torch.Tensor): From 0e3fe694d160caf4f6ba1e8eb5402edb8aec23e8 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Fri, 7 Jun 2024 19:43:16 -0700 Subject: [PATCH 517/706] [pipelining] Restore a stage constructor for tracer path (#128273) In case user modified stage module out of place, such as mod = DDP(mod) mod = torch.compile(mod) They need a stage builder else than `pipe.build_stage()`. This PR provides an API to do so: ``` def build_stage( stage_module, stage_index, pipe.info(), ... ) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/128273 Approved by: https://github.com/wconstab --- docs/source/distributed.pipelining.rst | 2 + test/distributed/pipelining/test_stage.py | 14 ++++++- torch/distributed/pipelining/PipelineStage.py | 39 +++++++++++++++++++ torch/distributed/pipelining/_IR.py | 15 +++++-- torch/distributed/pipelining/__init__.py | 3 +- 5 files changed, 67 insertions(+), 6 deletions(-) diff --git a/docs/source/distributed.pipelining.rst b/docs/source/distributed.pipelining.rst index 3d860fba976d..3f2c0f9cb98f 100644 --- a/docs/source/distributed.pipelining.rst +++ b/docs/source/distributed.pipelining.rst @@ -402,6 +402,8 @@ Pipeline Stages .. autoclass:: PipelineStage +.. autofunction:: build_stage + Pipeline Schedules ================== diff --git a/test/distributed/pipelining/test_stage.py b/test/distributed/pipelining/test_stage.py index 959dd2e526fa..fac2be495ce0 100644 --- a/test/distributed/pipelining/test_stage.py +++ b/test/distributed/pipelining/test_stage.py @@ -8,7 +8,12 @@ import torch import torch.distributed as dist -from torch.distributed.pipelining import pipeline, PipelineStage, ScheduleGPipe +from torch.distributed.pipelining import ( + build_stage, + pipeline, + PipelineStage, + ScheduleGPipe, +) from torch.distributed.pipelining._utils import PipeliningShapeError from torch.testing._internal.common_cuda import TEST_MULTIGPU from torch.testing._internal.common_distributed import ( @@ -153,8 +158,13 @@ def test_tracer_kwargs(self, ModelClass): mb_kwargs={"y": y_mb}, ) - stage = pipe.build_stage( + stage_mod = pipe.get_stage_module(self.rank) + + # Test build_stage + stage = build_stage( + stage_mod, self.rank, + pipe.info(), self.device, ) diff --git a/torch/distributed/pipelining/PipelineStage.py b/torch/distributed/pipelining/PipelineStage.py index 4f8902b4320f..a114b629344d 100644 --- a/torch/distributed/pipelining/PipelineStage.py +++ b/torch/distributed/pipelining/PipelineStage.py @@ -20,6 +20,7 @@ __all__ = [ "PipelineStage", + "build_stage", ] logger = logging.getLogger(__name__) @@ -646,6 +647,13 @@ def __init__( """ Create a pipeline stage given a stage_module to be wrapped by this stage and a `pipe_info` describing the stage relationship of the pipeline. + + Args: + stage_module (torch.nn.Module): the module to be wrapped by this stage + stage_index (int): the index of this stage in the pipeline + pipe_info (PipeInfo): information about the pipeline, can be retrieved by `pipe.info()` + device (torch.device): the device to be used by this stage + group (Optional[dist.ProcessGroup]): the process group to be used by this stage """ _PipelineStageBase.__init__( self, @@ -890,6 +898,37 @@ def _create_grad_recv_info( return grad_recv_info_tuple +# A helper function to create a pipeline stage based on traced pipeline information +def build_stage( + stage_module: torch.nn.Module, + stage_index: int, + pipe_info: PipeInfo, + device: torch.device, + group: Optional[dist.ProcessGroup] = None, +) -> _PipelineStage: + """ + Create a pipeline stage given a stage_module to be wrapped by this stage + and pipeline information. + + Args: + stage_module (torch.nn.Module): the module to be wrapped by this stage + stage_index (int): the index of this stage in the pipeline + pipe_info (PipeInfo): information about the pipeline, can be retrieved by `pipe.info()` + device (torch.device): the device to be used by this stage + group (Optional[dist.ProcessGroup]): the process group to be used by this stage + + Returns: + _PipelineStage: a pipeline stage that can run with `PipelineSchedules`. + """ + return _PipelineStage( + stage_module, + stage_index, + pipe_info, + device, + group, + ) + + # Manual PipelineStage functions and definition METADATA_TENSOR_LEN = 100 diff --git a/torch/distributed/pipelining/_IR.py b/torch/distributed/pipelining/_IR.py index e9aeeb06c870..4b82ab8bab66 100644 --- a/torch/distributed/pipelining/_IR.py +++ b/torch/distributed/pipelining/_IR.py @@ -1082,7 +1082,15 @@ def __str__(self): def __repr__(self): return self.split_gm.__repr__() - def _info(self) -> PipeInfo: + def info(self) -> PipeInfo: + """ + Get information about the pipe. + + Returns + ------- + PipeInfo + A dataclass containing information about the pipe. + """ return PipeInfo( graph=self.split_gm.graph, num_stages=self.num_stages, @@ -1096,7 +1104,8 @@ def build_stage( group: Optional[ProcessGroup] = None, ) -> _PipelineStage: """ - Create a `PipelineStage` given a stage index and distributed context. + Create a `PipelineStage` given a stage index and distributed group. + The `PipelineStage` can run with `PipelineSchedule`s. """ # Find stage module stage_module = self.get_stage_module(stage_index) @@ -1119,7 +1128,7 @@ def build_stage( # a reference to `Pipe` or `Pipe.split_gm` which stops python from # recycling them. When python recycles them, other stage modules (which # are irrelevant to current rank) can be automatically freed. - pipe_info = self._info() + pipe_info = self.info() return _PipelineStage(stage_module, stage_index, pipe_info, device, group) diff --git a/torch/distributed/pipelining/__init__.py b/torch/distributed/pipelining/__init__.py index 69e455e41992..fe487ad8505b 100644 --- a/torch/distributed/pipelining/__init__.py +++ b/torch/distributed/pipelining/__init__.py @@ -6,7 +6,7 @@ ScheduleInterleaved1F1B, ScheduleLoopedBFS, ) -from .PipelineStage import PipelineStage +from .PipelineStage import build_stage, PipelineStage __all__ = [ "Pipe", @@ -14,6 +14,7 @@ "SplitPoint", "pipeline", "PipelineStage", + "build_stage", "Schedule1F1B", "ScheduleGPipe", "ScheduleInterleaved1F1B", From 2e42671619604490fc8b3e7ef90ed9eab5fb5ee6 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Fri, 7 Jun 2024 19:43:20 -0700 Subject: [PATCH 518/706] [pipelining] Rename to stage.py and schedules.py (#128278) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128278 Approved by: https://github.com/H-Huang ghstack dependencies: #128273 --- docs/source/distributed.pipelining.rst | 10 +++++----- test/allowlist_for_publicAPI.json | 17 ----------------- .../pipelining/test_composability.py | 2 +- test/distributed/pipelining/test_schedule.py | 4 ++-- torch/distributed/pipelining/_IR.py | 3 +-- torch/distributed/pipelining/__init__.py | 4 ++-- .../{PipelineSchedule.py => schedules.py} | 2 +- .../pipelining/{PipelineStage.py => stage.py} | 0 8 files changed, 12 insertions(+), 30 deletions(-) rename torch/distributed/pipelining/{PipelineSchedule.py => schedules.py} (99%) rename torch/distributed/pipelining/{PipelineStage.py => stage.py} (100%) diff --git a/docs/source/distributed.pipelining.rst b/docs/source/distributed.pipelining.rst index 3f2c0f9cb98f..f8273a4aa372 100644 --- a/docs/source/distributed.pipelining.rst +++ b/docs/source/distributed.pipelining.rst @@ -63,7 +63,7 @@ Overall, the ``pipelining`` package provides the following features: Step 1: build ``PipelineStage`` objects for Execution -******************************************************** +***************************************************** Before we can use a PipelineSchedule, we need to create PipelineStage objects that wrap the part of the model running in that stage. The `PipelineStage` is responsible for allocating communication buffers and creating send/recv ops to communicate with its peers. It manages intermediate buffers e.g. for the outputs of forward that have not been consumed yet, and it provides a utility for running the backwards for the stage model. @@ -396,9 +396,9 @@ Microbatch Utilities Pipeline Stages =============== -.. automodule:: torch.distributed.pipelining.PipelineStage +.. automodule:: torch.distributed.pipelining.stage -.. currentmodule:: torch.distributed.pipelining.PipelineStage +.. currentmodule:: torch.distributed.pipelining.stage .. autoclass:: PipelineStage @@ -407,9 +407,9 @@ Pipeline Stages Pipeline Schedules ================== -.. automodule:: torch.distributed.pipelining.PipelineSchedule +.. automodule:: torch.distributed.pipelining.schedules -.. currentmodule:: torch.distributed.pipelining.PipelineSchedule +.. currentmodule:: torch.distributed.pipelining.schedules .. autoclass:: ScheduleGPipe diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json index 947b8d79077a..44de9e809615 100644 --- a/test/allowlist_for_publicAPI.json +++ b/test/allowlist_for_publicAPI.json @@ -2610,23 +2610,6 @@ "pipe_split", "pipeline" ], - "torch.distributed.pipelining.PipelineSchedule": [ - "ABC", - "Any", - "Callable", - "Dict", - "List", - "Optional", - "Pipe", - "PipelineStageBase", - "Tuple", - "Union", - "abstractmethod", - "defaultdict", - "merge_chunks", - "record_function", - "split_args_kwargs_into_chunks" - ], "torch.distributed.pipelining.microbatch": [ "Any", "Dict", diff --git a/test/distributed/pipelining/test_composability.py b/test/distributed/pipelining/test_composability.py index 0ef42fe90ddd..a2a37a6e0740 100644 --- a/test/distributed/pipelining/test_composability.py +++ b/test/distributed/pipelining/test_composability.py @@ -17,7 +17,7 @@ from torch.distributed._tensor import DTensor from torch.distributed.device_mesh import init_device_mesh from torch.distributed.pipelining import PipelineStage -from torch.distributed.pipelining.PipelineSchedule import ( +from torch.distributed.pipelining.schedules import ( PipelineScheduleSingle, Schedule1F1B, ScheduleGPipe, diff --git a/test/distributed/pipelining/test_schedule.py b/test/distributed/pipelining/test_schedule.py index 22d1167908aa..e67459d5b44b 100644 --- a/test/distributed/pipelining/test_schedule.py +++ b/test/distributed/pipelining/test_schedule.py @@ -20,8 +20,8 @@ ScheduleInterleaved1F1B, ScheduleLoopedBFS, ) -from torch.distributed.pipelining.PipelineSchedule import _Action, _ComputationType -from torch.distributed.pipelining.PipelineStage import _PipelineStageBase +from torch.distributed.pipelining.schedules import _Action, _ComputationType +from torch.distributed.pipelining.stage import _PipelineStageBase from torch.testing._internal.common_cuda import TEST_MULTIGPU from torch.testing._internal.common_distributed import ( MultiProcContinousTest, diff --git a/torch/distributed/pipelining/_IR.py b/torch/distributed/pipelining/_IR.py index 4b82ab8bab66..33243108f468 100644 --- a/torch/distributed/pipelining/_IR.py +++ b/torch/distributed/pipelining/_IR.py @@ -23,8 +23,7 @@ from ._backward import _null_coalesce_accumulate, stage_backward from ._unflatten import _outline_submodules from ._utils import PipeInfo - -from .PipelineStage import _PipelineStage +from .stage import _PipelineStage logger = logging.getLogger(__name__) diff --git a/torch/distributed/pipelining/__init__.py b/torch/distributed/pipelining/__init__.py index fe487ad8505b..18b3191add5b 100644 --- a/torch/distributed/pipelining/__init__.py +++ b/torch/distributed/pipelining/__init__.py @@ -1,12 +1,12 @@ # Copyright (c) Meta Platforms, Inc. and affiliates from ._IR import Pipe, pipe_split, pipeline, SplitPoint -from .PipelineSchedule import ( +from .schedules import ( Schedule1F1B, ScheduleGPipe, ScheduleInterleaved1F1B, ScheduleLoopedBFS, ) -from .PipelineStage import build_stage, PipelineStage +from .stage import build_stage, PipelineStage __all__ = [ "Pipe", diff --git a/torch/distributed/pipelining/PipelineSchedule.py b/torch/distributed/pipelining/schedules.py similarity index 99% rename from torch/distributed/pipelining/PipelineSchedule.py rename to torch/distributed/pipelining/schedules.py index 31632a8aaee5..dfd5752c2e45 100644 --- a/torch/distributed/pipelining/PipelineSchedule.py +++ b/torch/distributed/pipelining/schedules.py @@ -11,7 +11,7 @@ from torch.profiler import record_function from .microbatch import merge_chunks, split_args_kwargs_into_chunks, TensorChunkSpec -from .PipelineStage import _PipelineStageBase +from .stage import _PipelineStageBase __all__ = [ diff --git a/torch/distributed/pipelining/PipelineStage.py b/torch/distributed/pipelining/stage.py similarity index 100% rename from torch/distributed/pipelining/PipelineStage.py rename to torch/distributed/pipelining/stage.py From 613c7d270d809ff01590dfcac0a192f35bfdb553 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Fri, 7 Jun 2024 21:56:10 -0700 Subject: [PATCH 519/706] [pipelining] Format doc (#128279) - Should use two dots around `var` - Wrap lines - Add section cross ref Pull Request resolved: https://github.com/pytorch/pytorch/pull/128279 Approved by: https://github.com/H-Huang ghstack dependencies: #128273, #128278 --- docs/source/distributed.pipelining.rst | 79 +++++++++++++++++++------- 1 file changed, 59 insertions(+), 20 deletions(-) diff --git a/docs/source/distributed.pipelining.rst b/docs/source/distributed.pipelining.rst index f8273a4aa372..2e5e9f74c662 100644 --- a/docs/source/distributed.pipelining.rst +++ b/docs/source/distributed.pipelining.rst @@ -62,16 +62,37 @@ Overall, the ``pipelining`` package provides the following features: application on the Llama model. -Step 1: build ``PipelineStage`` objects for Execution -***************************************************** - -Before we can use a PipelineSchedule, we need to create PipelineStage objects that wrap the part of the model running in that stage. The `PipelineStage` is responsible for allocating communication buffers and creating send/recv ops to communicate with its peers. It manages intermediate buffers e.g. for the outputs of forward that have not been consumed yet, and it provides a utility for running the backwards for the stage model. - -A `PipelineStage` needs to know the input and output shapes for the stage model, so that it can correctly allocate communication buffers. The shapes must be static, e.g. at runtime the shapes can not change from step to step. A class `PipeliningShapeError` will be raised if runtime shapes do not match the expected shapes. When composing with other paralleisms or applying mixed precision, these techniques must be taken into account so the `PipelineStage` knows the correct shape (and dtype) for the output of the stage module at runtime. - -Users may construct a `PipelineStage` instance directly, by passing in an `nn.Module` representing the portion of the model that should run on the stage. This may require changes to the original model code. See the example below "Preparing a model for pipeline splitting". - -Alternatively, the tracing frontend can use graph-partitioning to construct a `GraphModule` that represents the desired subset of the model automatically. This technique requires the model is traceable with torch.Export in non-strict mode. Composability of the resulting `GraphModule` with other parallelism techniques and torch.compile is experimental, and may require some workarounds. Usage of this frontend may be more appealing if the user cannot easily change the model code. See "Splitting a Model with the ``pipeline`` tracing frontend" for more information. +Step 1: build ``PipelineStage`` for execution +********************************************* + +Before we can use a ``PipelineSchedule``, we need to create ``PipelineStage`` +objects that wrap the part of the model running in that stage. The +``PipelineStage`` is responsible for allocating communication buffers and +creating send/recv ops to communicate with its peers. It manages intermediate +buffers e.g. for the outputs of forward that have not been consumed yet, and it +provides a utility for running the backwards for the stage model. + +A ``PipelineStage`` needs to know the input and output shapes for the stage +model, so that it can correctly allocate communication buffers. The shapes must +be static, e.g. at runtime the shapes can not change from step to step. A class +``PipeliningShapeError`` will be raised if runtime shapes do not match the +expected shapes. When composing with other paralleisms or applying mixed +precision, these techniques must be taken into account so the ``PipelineStage`` +knows the correct shape (and dtype) for the output of the stage module at +runtime. + +Users may construct a ``PipelineStage`` instance directly, by passing in an +``nn.Module`` representing the portion of the model that should run on the +stage. This may require changes to the original model code. See the example +in :ref:`option_1_manual`. + +Alternatively, the splitting frontend can use graph partitioning to split your +model into a series of ``nn.Module`` automatically. This technique requires the +model is traceable with ``torch.Export``. Composability of the resulting +``nn.Module`` with other parallelism techniques is experimental, and may require +some workarounds. Usage of this frontend may be more appealing if the user +cannot easily change the model code. See :ref:`option_2_tracer` for more +information. Step 2: use ``PipelineSchedule`` for execution @@ -105,10 +126,20 @@ launcher service to launch multiple processes: torchrun --nproc_per_node=2 example.py -Preparing a model for pipeline splitting -======================================== +Options for Splitting a Model +***************************** -To directly construct a `PipelineStage`, the user is responsible for providing a single nn.Module instance that owns the relevant nn.Parameters and nn.Buffers, and defines a .forward() method that executes the operations relevant for that stage. For example, a condensed version of the Transformer class defined in Torchtitan shows a pattern of building an easily partitionable model. +.. _option_1_manual: + +Option 1: splitting a model manually +==================================== + +To directly construct a ``PipelineStage``, the user is responsible for providing +a single ``nn.Module`` instance that owns the relevant ``nn.Parameters`` and +``nn.Buffers``, and defines a ``forward()`` method that executes the operations +relevant for that stage. For example, a condensed version of the Transformer +class defined in Torchtitan shows a pattern of building an easily partitionable +model. .. code-block:: python @@ -137,7 +168,10 @@ To directly construct a `PipelineStage`, the user is responsible for providing a output = self.output(h).float() if self.output else h return output -A model defined in this manner can be easily configured per stage by first initializing the whole model (using meta-device to avoid OOM errors), deleting undesired layers for that stage, and then creating a PipelineStage that wraps the model. For example: +A model defined in this manner can be easily configured per stage by first +initializing the whole model (using meta-device to avoid OOM errors), deleting +undesired layers for that stage, and then creating a PipelineStage that wraps +the model. For example: .. code-block:: python @@ -169,15 +203,20 @@ A model defined in this manner can be easily configured per stage by first initi ) -The ``PipelineStage`` requires an example argument `input_args` representing the runtime input to the stage, which would be one microbatch worth of input data. -This argument is passed through the forward method of the stage module to determine the -input and output shapes required for communication. +The ``PipelineStage`` requires an example argument ``input_args`` representing +the runtime input to the stage, which would be one microbatch worth of input +data. This argument is passed through the forward method of the stage module to +determine the input and output shapes required for communication. + +When composing with other Data or Model parallelism techniques, ``output_args`` +may also be required, if the output shape/dtype of the model chunk will be +affected. -When composing with other Data or Model parallelism techniques, `output_args` may also be required, if the output shape/dtype of the model chunk will be affected. +.. _option_2_tracer: -Splitting a Model with the ``pipeline`` tracing frontend -======================================================== +Option 2: splitting a model automatically +========================================= If you have a full model and do not want to spend time on modifying it into a sequence of "model partitions", the ``pipeline`` API is here to help. From c4468518293f524d7ff1b2514d59753870dc232c Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Fri, 7 Jun 2024 14:30:35 -0700 Subject: [PATCH 520/706] [fsdp2] update foreach_reduce accumulate_grad (#128117) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128117 Approved by: https://github.com/awgu --- torch/distributed/_composable/fsdp/_fsdp_collectives.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/torch/distributed/_composable/fsdp/_fsdp_collectives.py b/torch/distributed/_composable/fsdp/_fsdp_collectives.py index b7264cb34d6d..99b69cd82e4b 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_collectives.py +++ b/torch/distributed/_composable/fsdp/_fsdp_collectives.py @@ -2,6 +2,7 @@ import torch import torch.distributed as dist +from torch.distributed._tensor import DTensor from torch.distributed.distributed_c10d import ReduceOp from ._fsdp_common import ( _get_dim0_padded_size, @@ -222,10 +223,13 @@ def foreach_reduce( # Record an event on which to block the CPU thread to # ensure that the D2H copy finishes before the optimizer fsdp_param.grad_offload_event = reduce_scatter_stream.record_event() - new_sharded_dtensor_grad = fsdp_param.to_sharded_dtensor(new_sharded_grad) if to_accumulate_grad: - fsdp_param.sharded_param.grad += new_sharded_dtensor_grad + assert isinstance(fsdp_param.sharded_param.grad, DTensor) + fsdp_param.sharded_param.grad._local_tensor += new_sharded_grad else: + new_sharded_dtensor_grad = fsdp_param.to_sharded_dtensor( + new_sharded_grad + ) fsdp_param.sharded_param.grad = new_sharded_dtensor_grad padded_sharded_numel = padded_unsharded_size.numel() // world_size flat_grad_offset += padded_sharded_numel From ffc202a1b91def8c81a6eb9a39777bc7e149e1ee Mon Sep 17 00:00:00 2001 From: chilli Date: Fri, 7 Jun 2024 16:08:14 -0700 Subject: [PATCH 521/706] Added remove_noop_ops to joint_graph_passes (#124451) Pull Request resolved: https://github.com/pytorch/pytorch/pull/124451 Approved by: https://github.com/ezyang, https://github.com/fmassa --- test/inductor/test_flex_attention.py | 4 +- torch/_inductor/fx_passes/fuse_attention.py | 3 + torch/_inductor/fx_passes/joint_graph.py | 4 + .../serialized_patterns/_sfdp_pattern_1.py | 16 +--- .../serialized_patterns/_sfdp_pattern_10.py | 20 ++-- .../serialized_patterns/_sfdp_pattern_11.py | 16 +--- .../serialized_patterns/_sfdp_pattern_12.py | 34 +++---- .../serialized_patterns/_sfdp_pattern_13.py | 26 ++---- .../serialized_patterns/_sfdp_pattern_14.py | 16 +--- .../serialized_patterns/_sfdp_pattern_15.py | 16 +--- .../serialized_patterns/_sfdp_pattern_16.py | 92 ++++++------------- .../serialized_patterns/_sfdp_pattern_17.py | 34 +++---- .../serialized_patterns/_sfdp_pattern_18.py | 60 ++++-------- .../serialized_patterns/_sfdp_pattern_19.py | 30 ++---- .../serialized_patterns/_sfdp_pattern_2.py | 16 +--- .../serialized_patterns/_sfdp_pattern_3.py | 26 ++---- .../serialized_patterns/_sfdp_pattern_4.py | 26 ++---- .../serialized_patterns/_sfdp_pattern_5.py | 16 +--- .../serialized_patterns/_sfdp_pattern_6.py | 26 ++---- .../serialized_patterns/_sfdp_pattern_7.py | 36 +++----- .../serialized_patterns/_sfdp_pattern_8.py | 20 ++-- .../serialized_patterns/_sfdp_pattern_9.py | 36 +++----- torch/_inductor/pattern_matcher.py | 13 +++ 23 files changed, 187 insertions(+), 399 deletions(-) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index d7afbe1123e7..4e8eecef0f41 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -1048,10 +1048,10 @@ def debug_compile_fx_inner(graph, example_inputs, *args, **kwargs): joint_graph, """\ class GraphModule(torch.nn.Module): - def forward(self, primals_1: "f64[2, 2, 8, 4]", primals_2: "f64[2, 2, 8, 4]", primals_3: "f64[2, 2, 8, 4]", alias_3: "f64[2, 2, 8, 4]", alias_5: "f32[2, 2, 8]", tangents_1: "f64[2, 2, 8, 4]"): + def forward(self, primals_1: "f64[2, 2, 8, 4]", primals_2: "f64[2, 2, 8, 4]", primals_3: "f64[2, 2, 8, 4]", getitem: "f64[2, 2, 8, 4]", getitem_1: "f32[2, 2, 8]", tangents_1: "f64[2, 2, 8, 4]"): fw_graph = self.fw_graph joint_graph = self.joint_graph - flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, alias_3, alias_5, tangents_1, fw_graph, joint_graph); primals_1 = primals_2 = primals_3 = alias_3 = alias_5 = tangents_1 = fw_graph = joint_graph = None + flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem, getitem_1, tangents_1, fw_graph, joint_graph); primals_1 = primals_2 = primals_3 = getitem = getitem_1 = tangents_1 = fw_graph = joint_graph = None getitem_2: "f64[2, 2, 8, 4]" = flex_attention_backward[0] getitem_3: "f64[2, 2, 8, 4]" = flex_attention_backward[1] getitem_4: "f64[2, 2, 8, 4]" = flex_attention_backward[2]; flex_attention_backward = None diff --git a/torch/_inductor/fx_passes/fuse_attention.py b/torch/_inductor/fx_passes/fuse_attention.py index 3fbb67cb2776..a6c1d11bd78a 100644 --- a/torch/_inductor/fx_passes/fuse_attention.py +++ b/torch/_inductor/fx_passes/fuse_attention.py @@ -883,6 +883,9 @@ def _get_sfdp_patterns(): "pass_dicts": patterns, "extra_check": extra_check, "scalar_workaround": workaround, + # with dropout turned into clone, we end up with a number of + # semantically identical graphs + "skip_duplicates": True, } diff --git a/torch/_inductor/fx_passes/joint_graph.py b/torch/_inductor/fx_passes/joint_graph.py index bf282ee72ba8..477bfe670cec 100644 --- a/torch/_inductor/fx_passes/joint_graph.py +++ b/torch/_inductor/fx_passes/joint_graph.py @@ -313,6 +313,10 @@ def joint_graph_passes(graph: torch.fx.GraphModule): config.joint_custom_pre_pass(graph.graph) count += 1 + from .post_grad import remove_noop_ops + + remove_noop_ops(graph.graph) + if config.joint_graph_constant_folding: constant_fold_uniform_value(graph) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py index ce678d28833b..55d2216b4e1f 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py @@ -42,23 +42,19 @@ sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) -mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, alias_default_3, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale')) @@ -123,11 +119,7 @@ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2) -convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py index a9c38dd92fd0..860ef1c8551f 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py @@ -46,7 +46,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) @@ -56,18 +56,14 @@ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored()) view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) -mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, alias_default_3, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) view_default_8 = CallFunction(aten.view.default, fma_default, Ignored(), _users=2) @@ -137,7 +133,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) @@ -147,17 +143,13 @@ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) -mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, alias_default_3, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_11.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_11.py index e324c7943e21..d8119c33ed93 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_11.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_11.py @@ -46,7 +46,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) @@ -55,16 +55,12 @@ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) -mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, alias_default_3, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale')) @@ -144,11 +140,7 @@ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2) -convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py index 09220864f13e..40834960904a 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py @@ -48,7 +48,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) @@ -59,11 +59,7 @@ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) @@ -71,8 +67,7 @@ convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) -clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale_factor')) @@ -116,13 +111,12 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) -clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1) -expand_default_2 = CallFunction(aten.expand.default, clone_default_2, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) -clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) -view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) _sfdp_pattern_12_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) @@ -158,11 +152,7 @@ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2) -convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) @@ -171,8 +161,7 @@ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) -clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, clone_default_3, Ignored()) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) @@ -220,12 +209,11 @@ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) -clone_default_2 = CallFunction(aten.clone.default, convert_element_type_default_1) -expand_default_2 = CallFunction(aten.expand.default, clone_default_2, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) -clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) -view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) _sfdp_pattern_12_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_13.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_13.py index ad05c6ed4014..bef5eab2bee9 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_13.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_13.py @@ -38,22 +38,17 @@ sub_Tensor = CallFunction(aten.sub.Tensor, bmm_default, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor) mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, mul_Tensor_1, KeywordArg('value')) -alias_default = CallFunction(aten.alias.default, div_Tensor) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor) permute_default_1 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, KeywordArg('tangents_1'), permute_default_1) convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, bmm_default_2, mul_Tensor_2) -clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4, _users=2) permute_default_2 = CallFunction(aten.permute.default, permute_default, Ignored()) @@ -78,8 +73,7 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) -clone_default = CallFunction(aten.clone.default, div_Tensor) -_sfdp_pattern_13_inference = CallFunction(aten.bmm.default, clone_default, KeywordArg('value'), _users=0) +_sfdp_pattern_13_inference = CallFunction(aten.bmm.default, div_Tensor, KeywordArg('value'), _users=0) rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) @@ -96,19 +90,14 @@ mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, mul_Tensor_1, KeywordArg('value')) -alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2) -convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) permute_default_1 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, KeywordArg('tangents_1'), permute_default_1) convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, bmm_default_2, mul_Tensor_2) -clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, clone_default, Ignored()) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) @@ -137,5 +126,4 @@ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) -clone_default = CallFunction(aten.clone.default, convert_element_type_default_1) -_sfdp_pattern_13_half_inference = CallFunction(aten.bmm.default, clone_default, KeywordArg('value'), _users=0) +_sfdp_pattern_13_half_inference = CallFunction(aten.bmm.default, convert_element_type_default_1, KeywordArg('value'), _users=0) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_14.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_14.py index a25976ad6672..a1e87c009fcc 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_14.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_14.py @@ -47,7 +47,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) @@ -56,16 +56,12 @@ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) -mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, alias_default_3, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale')) @@ -148,11 +144,7 @@ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2) -convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_15.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_15.py index e5cc2e1cfb61..289585111a54 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_15.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_15.py @@ -50,7 +50,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) expand_default_3 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) @@ -60,16 +60,12 @@ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) -mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, alias_default_3, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, fma_default) @@ -161,11 +157,7 @@ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) -alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2) -convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_16.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_16.py index 8895782436b4..e3c1b5c60235 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_16.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_16.py @@ -49,7 +49,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) @@ -60,11 +60,7 @@ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) @@ -72,8 +68,7 @@ convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) -clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale')) @@ -119,13 +114,12 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) -clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1) -expand_default_2 = CallFunction(aten.expand.default, clone_default_2, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) -clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) -view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) _sfdp_pattern_16_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) @@ -147,7 +141,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) @@ -157,11 +151,7 @@ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) @@ -169,8 +159,7 @@ convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) -clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale')) @@ -214,8 +203,7 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) -clone_default = CallFunction(aten.clone.default, div_Tensor_1) -expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) @@ -256,11 +244,7 @@ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2) -convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) @@ -269,8 +253,7 @@ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) -clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, clone_default_3, Ignored()) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) @@ -320,13 +303,12 @@ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) -clone_default_2 = CallFunction(aten.clone.default, convert_element_type_default_1) -expand_default_2 = CallFunction(aten.expand.default, clone_default_2, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) -clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) -view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) _sfdp_pattern_16_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) @@ -360,11 +342,7 @@ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2) -convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) @@ -373,8 +351,7 @@ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) -clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, clone_default, Ignored()) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) @@ -422,8 +399,7 @@ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) -clone_default = CallFunction(aten.clone.default, convert_element_type_default_1) -expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) @@ -451,7 +427,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) @@ -463,11 +439,7 @@ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) @@ -476,8 +448,7 @@ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_1, mul_Tensor_2) -clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) @@ -524,14 +495,13 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) -clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1) -convert_element_type_default = CallFunction(prims.convert_element_type.default, clone_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) -clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) -view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) _sfdp_pattern_16_half_mask_fp32_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) @@ -553,7 +523,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) @@ -564,11 +534,7 @@ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) @@ -577,8 +543,7 @@ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_1, mul_Tensor_2) -clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) @@ -623,8 +588,7 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) -clone_default = CallFunction(aten.clone.default, div_Tensor_1) -convert_element_type_default = CallFunction(prims.convert_element_type.default, clone_default, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_17.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_17.py index 225dce51a19a..f741b23c0dd3 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_17.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_17.py @@ -52,7 +52,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) expand_default_3 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) @@ -64,11 +64,7 @@ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) @@ -76,8 +72,7 @@ convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) -clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, fma_default) @@ -128,13 +123,12 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) -clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1) -expand_default_3 = CallFunction(aten.expand.default, clone_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) -clone_default_3 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) -view_default_5 = CallFunction(aten.view.default, clone_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored()) bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) _sfdp_pattern_17_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) @@ -175,11 +169,7 @@ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) -alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2) -convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) @@ -188,8 +178,7 @@ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) -clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, clone_default_3, Ignored()) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) @@ -244,12 +233,11 @@ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) -clone_default_2 = CallFunction(aten.clone.default, convert_element_type_default_1) -expand_default_3 = CallFunction(aten.expand.default, clone_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) -clone_default_3 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) -view_default_5 = CallFunction(aten.view.default, clone_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored()) bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) _sfdp_pattern_17_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_18.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_18.py index cf3fe7cff4a2..25c482876a99 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_18.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_18.py @@ -51,7 +51,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) @@ -62,11 +62,7 @@ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) @@ -74,8 +70,7 @@ convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) -clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) @@ -126,13 +121,12 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) -clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1) -expand_default_2 = CallFunction(aten.expand.default, clone_default_2, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) -clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) -view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) _sfdp_pattern_18_inference = MultiOutputPattern([view_default_5, @@ -160,7 +154,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) @@ -170,11 +164,7 @@ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) @@ -182,8 +172,7 @@ convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) -clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) @@ -232,8 +221,7 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) -clone_default = CallFunction(aten.clone.default, div_Tensor_1) -expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) @@ -280,11 +268,7 @@ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2) -convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) @@ -293,8 +277,7 @@ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) -clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, clone_default_3, Ignored()) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) @@ -349,13 +332,12 @@ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) -clone_default_2 = CallFunction(aten.clone.default, convert_element_type_default_1) -expand_default_2 = CallFunction(aten.expand.default, clone_default_2, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) -clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) -view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) _sfdp_pattern_18_half_inference = MultiOutputPattern([view_default_5, @@ -395,11 +377,7 @@ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2) -convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) @@ -408,8 +386,7 @@ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) -clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, clone_default, Ignored()) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) @@ -462,8 +439,7 @@ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) -clone_default = CallFunction(aten.clone.default, convert_element_type_default_1) -expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_19.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_19.py index c2b71b521b2b..3cba2215bc76 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_19.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_19.py @@ -48,7 +48,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) @@ -57,11 +57,7 @@ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) @@ -69,8 +65,7 @@ convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) -clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) @@ -114,8 +109,7 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) -clone_default = CallFunction(aten.clone.default, div_Tensor_1) -expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) @@ -141,7 +135,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default) mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) @@ -151,11 +145,7 @@ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) @@ -163,9 +153,8 @@ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_1, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) -clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, clone_default, Ignored()) -mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, alias_default_3, _users=2) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) @@ -211,8 +200,7 @@ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) -clone_default = CallFunction(aten.clone.default, convert_element_type_default) -expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_2.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_2.py index cdaa975bcfc0..f573cb373491 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_2.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_2.py @@ -42,23 +42,19 @@ sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) -mul_Tensor_1 = CallFunction(aten.mul.Tensor, view_default_7, alias_default_3, _users=2) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_1, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_1) mul_Tensor_2 = CallFunction(aten.mul.Tensor, fma_default, KeywordArg('scale_factor')) @@ -123,11 +119,7 @@ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2) -convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py index 481c704f709e..d7eb251ba52d 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py @@ -44,7 +44,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) @@ -53,11 +53,7 @@ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) @@ -65,8 +61,7 @@ convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) -clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale_factor')) @@ -103,8 +98,7 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) -clone_default = CallFunction(aten.clone.default, div_Tensor_1) -expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) @@ -137,11 +131,7 @@ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2) -convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) @@ -150,8 +140,7 @@ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) -clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, clone_default, Ignored()) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) @@ -192,8 +181,7 @@ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) -clone_default = CallFunction(aten.clone.default, convert_element_type_default_1) -expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py index d9f8bf2ebc99..773b2be31bde 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py @@ -44,7 +44,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) mul_Tensor_1 = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor) mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, Ignored()) expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_2, Ignored()) @@ -53,11 +53,7 @@ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) @@ -65,8 +61,7 @@ convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) mul_Tensor_4 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_3) -clone_default = CallFunction(aten.clone.default, mul_Tensor_4, memory_format=torch.contiguous_format) -mul_Tensor_5 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, mul_Tensor_4, div_Tensor, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_5, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_5) mul_Tensor_6 = CallFunction(aten.mul.Tensor, fma_default, KeywordArg('scale_factor')) @@ -103,8 +98,7 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) -clone_default = CallFunction(aten.clone.default, div_Tensor) -expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) @@ -137,11 +131,7 @@ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2) -convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) @@ -150,8 +140,7 @@ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) mul_Tensor_4 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_3) -clone_default = CallFunction(aten.clone.default, mul_Tensor_4, memory_format=torch.contiguous_format) -convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, clone_default, Ignored()) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_4, Ignored()) mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_5, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_5) @@ -192,8 +181,7 @@ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) -clone_default = CallFunction(aten.clone.default, convert_element_type_default_1) -expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py index 64f99e2ac21e..fe481c8293be 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py @@ -43,23 +43,19 @@ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) -mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, alias_default_3, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored()) @@ -126,11 +122,7 @@ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2) -convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py index 9836142aade5..7de8b8229ea8 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py @@ -45,7 +45,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) @@ -54,11 +54,7 @@ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) @@ -66,8 +62,7 @@ convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) -clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored()) @@ -105,8 +100,7 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) -clone_default = CallFunction(aten.clone.default, div_Tensor_1) -expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) @@ -140,11 +134,7 @@ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2) -convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) @@ -153,8 +143,7 @@ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) -clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, clone_default, Ignored()) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) @@ -196,8 +185,7 @@ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) -clone_default = CallFunction(aten.clone.default, convert_element_type_default_1) -expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py index 87c233a2ae18..ff198232b5e6 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py @@ -48,7 +48,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) @@ -60,11 +60,7 @@ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) @@ -74,8 +70,7 @@ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2) -clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored()) @@ -118,14 +113,13 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) -clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1) -convert_element_type_default = CallFunction(prims.convert_element_type.default, clone_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) -clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) -view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) _sfdp_pattern_7_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) @@ -149,7 +143,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) @@ -161,11 +155,7 @@ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) @@ -174,8 +164,7 @@ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2) -clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) @@ -220,13 +209,12 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) -clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1) -convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, clone_default_2, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) -clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) -view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) _sfdp_pattern_7_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py index eb6ffee4614c..8c4b27c8a6fb 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py @@ -46,7 +46,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) @@ -56,18 +56,14 @@ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored()) view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) -mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, alias_default_3, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored()) @@ -137,7 +133,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) @@ -147,17 +143,13 @@ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) -mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, alias_default_3, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py index f2456fbef495..78380c1bb341 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py @@ -48,7 +48,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) @@ -60,11 +60,7 @@ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) @@ -74,8 +70,7 @@ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2) -clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) view_default_8 = CallFunction(aten.view.default, fma_default, Ignored(), _users=2) @@ -118,14 +113,13 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) -clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1) -convert_element_type_default = CallFunction(prims.convert_element_type.default, clone_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) -clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) -view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) _sfdp_pattern_9_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) @@ -149,7 +143,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) @@ -161,11 +155,7 @@ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) @@ -174,8 +164,7 @@ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2) -clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) @@ -220,13 +209,12 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) -clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1) -convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, clone_default_2, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) -clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) -view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) _sfdp_pattern_9_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py index e3e0ddcfd547..b9f4e1e18c93 100644 --- a/torch/_inductor/pattern_matcher.py +++ b/torch/_inductor/pattern_matcher.py @@ -906,6 +906,7 @@ def __init__(self) -> None: self.memoized_objs_pp: Dict[PatternExpr, str] = {} @staticmethod + @functools.lru_cache(None) def run(obj: PatternExpr, output_name: str = "output") -> str: """ Serializes obj to python code with obj written out to `output_name` @@ -1463,6 +1464,7 @@ def gen_register_replacement( extra_check: Callable[[Match], bool] = _return_true, scalar_workaround: Union[Dict[str, Union[float, int]], None] = None, exclusive_arg_names: Sequence[str] = (), + skip_duplicates: bool = False, ) -> None: # Make sure the example_inputs is materialized. example_inputs = tuple(example_inputs) @@ -1491,6 +1493,8 @@ def gen_register_replacement( # Since this is just an optimization we can clear it out. arg.constant = None + if PatternPrettyPrinter.run(pat) in _seen_patterns and skip_duplicates: + return _known_precompiled_patterns.append( (search_fn, example_inputs, trace_fn, scalar_workaround, pat) ) @@ -1790,6 +1794,11 @@ def fwd_only( # TODO - look into using aot autograd, asserting no mutating ops here with enable_python_dispatcher(): gm = make_fx(fn, select_decomp_table(), tracing_mode="real")(*args) + + from .fx_passes.post_grad import remove_noop_ops + + remove_noop_ops(gm.graph) + if run_dce: gm.graph.eliminate_dead_code() gm.recompile() @@ -1820,6 +1829,10 @@ def record_joint_graph( )(*args) assert gm + from .fx_passes.post_grad import remove_noop_ops + + remove_noop_ops(gm.graph) + from .fx_passes.joint_graph import pointless_view matcher_pass = PatternMatcherPass() From 310f80995b750ec3c0650c72a83c6d146c9f6b76 Mon Sep 17 00:00:00 2001 From: chilli Date: Fri, 7 Jun 2024 16:10:45 -0700 Subject: [PATCH 522/706] Added memory budget to partitioner (#126320) Pull Request resolved: https://github.com/pytorch/pytorch/pull/126320 Approved by: https://github.com/shunting314 --- test/functorch/test_ac.py | 302 ++++++++++++++++++++++++ torch/_functorch/config.py | 33 +++ torch/_functorch/partitioners.py | 393 +++++++++++++++++++++++++++++-- 3 files changed, 703 insertions(+), 25 deletions(-) create mode 100644 test/functorch/test_ac.py diff --git a/test/functorch/test_ac.py b/test/functorch/test_ac.py new file mode 100644 index 000000000000..a9b1d00b9929 --- /dev/null +++ b/test/functorch/test_ac.py @@ -0,0 +1,302 @@ +# Owner(s): ["oncall: pt2"] +import random + +import torch +import torch._functorch.config as config +from torch.testing._internal.common_utils import run_tests, TEST_WITH_ROCM, TestCase +from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.utils.flop_counter import FlopCounterMode + + +def compile_with_ac(f, memory_budget): + return torch.compile(f, backend="aot_eager_decomp_partition") + + +def get_act_mem(f): + out = f() + out.backward() + start_mem = torch.cuda.memory_stats()["requested_bytes.all.current"] + out = f() + cur_mem = torch.cuda.memory_stats()["requested_bytes.all.current"] + act_mem = (cur_mem - start_mem) / (1024 * 1024) + out.backward() + return act_mem + + +def get_bw_flops(f): + # Normalized so that a 512 square matmul returns 1 + f().backward() + out = f() + with FlopCounterMode(display=False) as mode: + out.backward() + return mode.get_total_flops() / (512**3 * 2) + + +def create_pair(B_I, O): + # results in B_I * O memory, requires B_I * B_I * O flops + # arithmetic intensity of B_I + x = torch.randn(B_I * 512, B_I * 512, requires_grad=True) + w = torch.randn(B_I * 512, O * 512, requires_grad=True) + return x, w + + +def get_mem_and_flops(f, memory_budget=None): + # Returns megabytes rounded to 1 decimal point and FLOPs + # Note that each value of size (512, 512, torch.float32) is 1 MiB + torch._dynamo.reset() + with config.patch(activation_memory_budget=memory_budget): + if memory_budget is not None: + f = torch.compile(f, backend="aot_eager_decomp_partition") + + # We round this to nearest 10th of a megabyte. + return round(get_act_mem(f), 1), get_bw_flops(f) + + +class MemoryBudgetTest(TestCase): + def setUp(self): + super().setUp() + torch.set_default_device("cuda") + + def test_rematerializes_cheap(self): + def f(x, w): + x = x.cos() + x = torch.mm(x, w) + return x.sum() + + x = torch.randn(512, 512, requires_grad=True) + w = torch.randn(512, 512, requires_grad=True) + + def call(): + return f(x, w) + + eager_mem, eager_flops = get_mem_and_flops(call) + self.assertEqual(eager_mem, 1.0) + mem_10, flops_10 = get_mem_and_flops(call, memory_budget=1.0) + # Recomputing `.cos()` is not free here. + self.assertEqual(mem_10, 1.0) + self.assertEqual(eager_flops, flops_10) + mem_5, flops_5 = get_mem_and_flops(call, memory_budget=0.5) + # We can just recompute `x.cos()` here to only depend on the inputs + self.assertEqual(mem_5, 0.0) + self.assertEqual(flops_5, eager_flops) + + def test_matmul_even_chain(self): + def f(x, ws): + x = x.cos() + for w in ws: + x = torch.mm(x, w).cos() + return x.sum() + + x = torch.randn(512, 512, requires_grad=True) + ws = [torch.randn(512, 512, requires_grad=True) for _ in range(5)] + + def call(): + return f(x, ws) + + eager_mem, eager_flops = get_mem_and_flops(call) + for budget in range(0, 11): + mem, flops = get_mem_and_flops(call, memory_budget=budget / 10) + if budget <= 5: + # We start saving the matmuls + self.assertEqual(mem, budget) + self.assertEqual(flops, eager_flops + (5 - budget)) + elif budget < 10: + # We're only recomputing the `cos` operations + self.assertEqual(mem, 5.0) + self.assertEqual(flops, eager_flops) + elif budget == 10: + self.assertEqual(mem, 10.0) + self.assertEqual(flops, eager_flops) + + def test_matmul_uneven_chain(self): + # This function is constructed so that we are saving one input of size + # [512, in_dim] for each w + # In addition, every matmul has a same ratio of compute to "memory + # saved", so this test is essentially testing our knapsack solving + + def f(x, ws): + xs = [torch.mm(x, w).cos() for w in ws] + return sum([x.sum() for x in xs]) + + x = torch.randn(512, 512, requires_grad=True) + + def make_weights(w_shapes): + ws = [] + for idx, dim in enumerate(w_shapes): + ws.append(torch.randn(512, dim * 512, requires_grad=True)) + return ws + + def make_weights_chain(w_shapes): + ws = [] + for idx, _ in enumerate(w_shapes): + old_dim = 512 if idx == 0 else w_shapes[idx - 1] * 512 + new_dim = w_shapes[idx] * 512 + ws.append(torch.randn(old_dim, new_dim, requires_grad=True)) + return ws + + weight_configs = [ + ( + [11, 3, 4, 2], + [ + 18, # 11 + 4 + 3 + 17, # 11 + 4 + 2 + 16, # 11 + 3 + 2 + 15, # 11 + 4 + 14, # 11 + 3 + 13, # 11 + 2 + 11, # 11 + 2 + 7, # 4 + 3 + 6, # 4 + 2 + 5, # 3 + 2 + ], + ), + ( + [3, 5, 11, 17, 14], + [ + 42, # 17 + 14 + 9 + 30, # 11 + 15 + 5 + 19, # 11 + 5 + 3 + 8, # 5 + 3 + 3, # 3 + ], + ), + ] + random.seed(0) + random_arr = [random.randint(0, 50) for _ in range(10)] + exact_sums = [] + for i in range(10): + random.shuffle(random_arr) + exact_sums.append(sum(random_arr[:i])) + weight_configs.append((random_arr, exact_sums)) + + for weight_shapes, exact_solves in weight_configs: + ws = make_weights(weight_shapes) + + def call(): + return f(x, ws) + + eager_mem, eager_flops = get_mem_and_flops(call) + total_mem = sum(weight_shapes) + self.assertEqual(eager_mem, sum(weight_shapes)) + for mem_achieved in exact_solves: + mem, _ = get_mem_and_flops(call, memory_budget=mem_achieved / total_mem) + self.assertEqual(mem, mem_achieved) + + def test_prioritize_cheaper_matmul(self): + def f(xs, ws): + xs = [torch.mm(x, w).cos() for x, w in zip(xs, ws)] + return sum([x.sum() for x in xs]) + + x1, w1 = create_pair(1, 4) + x2, w2 = create_pair(2, 2) + + def call(): + return f([x1, x2], [w1, w2]) + + eager_mem, eager_flops = get_mem_and_flops(call) + self.assertEqual(eager_mem, 8) + self.assertEqual(eager_flops, 24) + comp_mem, comp_flops = get_mem_and_flops(call, memory_budget=0.5) + self.assertEqual(comp_mem, 4) + # We are recomputing x1 @ w1 here! + self.assertEqual(comp_flops, eager_flops + 4) + + @config.patch(activation_memory_budget_runtime_estimator="profile") + def test_profile(self): + def f(x, ws): + x = x.cos() + for w in ws: + x = torch.mm(x, w).cos() + return x.sum() + + x = torch.randn(512, 512, requires_grad=True) + ws = [torch.randn(512, 512, requires_grad=True) for _ in range(5)] + + def call(): + return f(x, ws) + + eager_mem, eager_flops = get_mem_and_flops(call) + mem, flops = get_mem_and_flops(call, memory_budget=0.2) + # We start saving the matmuls + self.assertEqual(mem, 2) + self.assertEqual(flops, eager_flops + 3) + + def test_prioritize_cheaper_matmul2(self): + def f(xs, ws): + xs = [torch.mm(x, w).cos() for x, w in zip(xs, ws)] + return sum([x.sum() for x in xs]) + + data = [(4, 4), (6, 2), (2, 6)] + xs, ws = zip(*[create_pair(a, b) for a, b in data]) + + def call(): + return f(xs, ws) + + eager_mem, eager_flops = get_mem_and_flops(call) + self.assertEqual(eager_mem, 40) + self.assertEqual(eager_flops, 320) + mem, flops = get_mem_and_flops(call, memory_budget=28 / eager_mem) + # Save w1 and w2 + self.assertEqual(mem, 28) + # We're recomputing w3 (the cheap one!) + self.assertEqual(flops - eager_flops, 2 * 2 * 6) + mem, flops = get_mem_and_flops(call, memory_budget=16 / eager_mem) + # Save w2. Note that even though saving w1 gets us closer to our memory + # limit, w2 is actually *more* FLOPs than w1! + self.assertEqual(mem, 12) + self.assertEqual(flops - eager_flops, 2 * 2 * 6 + 4 * 4 * 4) + + def test_attention_vs_linear(self): + def f(x, w): + orig_shape = x.shape + x = x.reshape(1, 1, x.shape[0], x.shape[1]) + # I know this isn't technically right lol + x = torch.nn.functional.scaled_dot_product_attention( + x, x, x, is_causal=False + ).reshape(*orig_shape) + x = torch.mm(x, w) + x = x.cos() + return x.sum() + + def try_seq_length(S, D, expected_recompute): + x = torch.randn(S * 512, D * 512, requires_grad=True) + w = torch.randn(D * 512, D * 512, requires_grad=True) + + def call(): + return f(x, w) + + with FlopCounterMode(display=False) as mode: + call() + mm_flops = mode.get_flop_counts()["Global"][torch.ops.aten.mm] + attn_flops = mode.get_total_flops() - mm_flops + mm_flops /= 512**3 * 2 + attn_flops /= 512**3 * 2 + + eager_mem, eager_flops = get_mem_and_flops(call) + self.assertEqual(eager_mem, S * D * 2) + + mem, flops = get_mem_and_flops( + call, memory_budget=0.6 + ) # Force it to recompute one of mm or attn + self.assertEqual(mem, S * D) + if expected_recompute == "attn": + expected_flops = attn_flops + else: + expected_flops = mm_flops + self.assertEqual(flops - eager_flops, expected_flops) + + # General behind this test is that if sequence length * 2 > D, then + # attention is more expensive than the linear. + try_seq_length(1, 1, "mm") + try_seq_length(1, 3, "attn") + try_seq_length(2, 2, "mm") + try_seq_length(2, 1, "mm") + try_seq_length(2, 5, "attn") + try_seq_length(4, 7, "mm") + try_seq_length(4, 9, "attn") + + +if __name__ == "__main__": + # I'm using the cuda memory allocator to verify memory allocations + if HAS_CUDA and not TEST_WITH_ROCM: + run_tests() diff --git a/torch/_functorch/config.py b/torch/_functorch/config.py index c559951f3809..60bbf1f21c66 100644 --- a/torch/_functorch/config.py +++ b/torch/_functorch/config.py @@ -88,6 +88,39 @@ # a fusion can be expensive. ban_recompute_reductions = True +# By default, the partitioner is purely trying to optimize for runtime (although +# it should always use less memory than eager) +# This knob controls the partitioner to make that tradeoff for you, choosing the +# fastest option that saves less activations than the memory budget. +# Specifically, 0.0 corresponds to the activation memory from applying +# activation checkpointing to the full compiled region, and 1.0 corresponds to +# the activation memory from the default runtime-optimized strategy. So, 0.4 +# would result in a strategy that saves 40% of the activations compared to the +# default strategy. +# It solves a 0-1 knapsack to find the minimum recompute necessary to stay below +# the activation memory budget. +# NOTE: This *cannot* be treated as +activation_memory_budget = 1.0 + +# This controls how we estimate the runtime when deciding what the cheapest +# operators to recompute are. The 3 options are +# "flops": Bases it off of the flop count provided by torch.utils.flop_counter +# "profile": Benchmarks each operator to come up with a runtime +# "testing": Returns 1 for everything +activation_memory_budget_runtime_estimator = "flops" + +# This controls the solver used for the 0-1 knapsack. By default we use a +# quantized DP solution ("dp"). The other approaches are a "greedy" and a "ilp" +# (which has a scipy dependency). +activation_memory_budget_solver = "dp" + +# This dumps out a png visualization of the expected runtime vs. activation +# memory tradeoffs for all memory budget values from 0 to 1 in increments of +# 0.5. See an example here: +# https://github.com/pytorch/pytorch/pull/126320#discussion_r1625104015 +visualize_memory_budget_pareto = ( + os.environ.get("PARTITIONER_MEMORY_BUDGET_PARETO", "0") == "1" +) # Sets all of the ban_recompute heuristics to False except ban_recompute_reductions # Generally, this will probably result in some memory improvement, but at the diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index cbfb4ca17168..fc1c995e5907 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -25,6 +25,7 @@ ) from torch.fx.passes import graph_drawer from . import config +from ._aot_autograd.logging_utils import get_aot_graph_name from .compile_utils import fx_graph_cse, get_aten_target if TYPE_CHECKING: @@ -451,14 +452,16 @@ def _size_of(node: fx.Node) -> int: # layering violation) elif isinstance(val, (list, tuple)): return sum( - _tensor_nbytes(hint_int(n.numel(), fallback=4098), n.dtype) + _tensor_nbytes(hint_int(n.numel(), fallback=4096), n.dtype) for n in val if isinstance(n, torch.Tensor) ) elif isinstance(val, torch.Tensor): - return _tensor_nbytes(hint_int(val.numel(), fallback=4098), val.dtype) + return _tensor_nbytes(hint_int(val.numel(), fallback=4096), val.dtype) raise RuntimeError(f"Unknown metadata type {type(val)}") + if node.op == "get_attr": + return 0 raise RuntimeError("We should always have `val` metadata on the nodes") @@ -532,25 +535,22 @@ def reordering_to_mimic_autograd_engine(gm: fx.GraphModule) -> fx.GraphModule: for idx, node in enumerate(gm.graph.nodes): order[node] = idx - # Populate depth for the nodes. Depth is the distance from the inputs. - depths = {} - output_node = next(iter(gm.graph.find_nodes(op="output"))) - for node in gm.graph.nodes: - if node.op == "placeholder": - depths[node] = 0 - else: - depths[node] = max([depths[arg] for arg in node.all_input_nodes], default=0) - def insert_node_in_graph(node): - if node in env: - return env[node] + cur_nodes = [node] + insertable_nodes = set() + while len(cur_nodes) > 0: + node = cur_nodes.pop() + if node in insertable_nodes or node in env: + continue + insertable_nodes.add(node) - # Bias traversal towards the nodes that have higher depth - prioritizes - # critical path first. - for arg, _ in sort_depths(node.all_input_nodes, depths): - env[arg] = insert_node_in_graph(arg) - env[node] = new_graph.node_copy(node, lambda x: env[x]) - return env[node] + # Bias traversal towards the nodes that have higher depth - prioritizes + # critical path first. + cur_nodes += node.all_input_nodes + + insertable_nodes = sorted(insertable_nodes, key=lambda n: order[n]) + for node in insertable_nodes: + env[node] = new_graph.node_copy(node, lambda x: env[x]) # Find first bwd node in the graph tangent_inputs = list(filter(_is_tangent, gm.graph.nodes)) @@ -750,7 +750,7 @@ def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule: return joint_module -def get_saved_values( +def solve_min_cut( joint_graph: fx.Graph, node_info: NodeInfo, min_cut_options: MinCutOptions, @@ -877,7 +877,6 @@ def ban_recomputation_if_allowed(node): return False if node in dont_ban: return False - # breakpoint() # This bans recomputation of the node unless we've been forced not to by # user annotation # NB: "recompute" > 0 means that user annotation has asked us to @@ -1268,9 +1267,197 @@ def get_name_to_node(graph: fx.Graph): return name_to_node +def greedy_knapsack( + memory: List[float], runtimes: List[float], max_memory: float +) -> Tuple[float, List[int], List[int]]: + n = len(runtimes) + items = list(range(n)) + + # Sort items based on the ratio of runtime to memory in descending order + items = sorted(items, key=lambda i: runtimes[i] / memory[i], reverse=True) + + total_memory = 0.0 + total_runtime = 0.0 + items_to_save = [] + items_to_allow_recomputing = [] + + for i in items: + if total_memory + memory[i] <= max_memory: + total_memory += memory[i] + total_runtime += runtimes[i] + items_to_save.append(i) + else: + items_to_allow_recomputing.append(i) + return total_runtime, items_to_save, items_to_allow_recomputing + + +def ilp_knapsack( + memory: List[float], runtimes: List[float], max_memory: float +) -> Tuple[float, List[int], List[int]]: + import numpy as np + + try: + from scipy.optimize import Bounds, LinearConstraint, milp + except ImportError: + raise RuntimeError( + "To use the ILP for memory budget checkpointing you need to install scipy" + ) from None + + np_memory = np.array(memory) + np_runtimes = np.array(runtimes) + c = -np_runtimes # type: ignore[operator] + + memory_constraint = LinearConstraint(A=np_memory, ub=np.array(max_memory)) + constraints = [memory_constraint] + + integrality = np.ones_like(c) + res = milp( + c=c, constraints=constraints, integrality=integrality, bounds=Bounds(0, 1) + ) + if not res.success: + raise RuntimeError("Somehow scipy solving failed") + + items_to_save = [] + items_to_allow_recomputing = [] + for idx, i in enumerate(res.x): + if i == 1: + items_to_save.append(idx) + else: + items_to_allow_recomputing.append(idx) + return -res.fun, items_to_save, items_to_allow_recomputing + + +def dp_knapsack( + memory: List[float], runtimes: List[float], max_memory: float +) -> Tuple[float, List[int], List[int]]: + # Scaling factor to convert floating point weights to integers + S = 10000 + + # Quantize the memory weights + quantized_memory = torch.tensor( + [int(round(m * S)) for m in memory], dtype=torch.long, device="cpu" + ) + runtimes = torch.tensor(runtimes, dtype=torch.float32, device="cpu") + + # Quantized pseudopolynomial DP for 0-1 Knapsack + quantized_max_memory = int(round(max_memory * S)) + + n = len(memory) + + # Initialize the DP table + # TODO(chilli): I think if needed, this memory can be optimized with sliding + # window trick + Hirschberg trick: + # https://codeforces.com/blog/entry/47247?#comment-316200 + dp = torch.zeros( + (n + 1, quantized_max_memory + 1), dtype=torch.float32, device="cpu" + ) + + for i in range(1, n + 1): + current_memory = quantized_memory[i - 1] + current_runtime = runtimes[i - 1] + + # Copy the previous row + dp[i, :] = dp[i - 1, :] + + # Update dp[i, j] for all j >= current_memory + if current_memory == 0: + dp[i, :] = dp[i - 1, :] + current_runtime + else: + dp[i, current_memory:] = torch.maximum( + dp[i - 1, current_memory:], + dp[i - 1, :-current_memory] + current_runtime, + ) + + # Backtrack to find the items included in the knapsack + saved_items = [] + recomputable_items = [] + j: int = quantized_max_memory + for i in range(n, 0, -1): + if dp[i][j] != dp[i - 1][j]: + saved_items.append(i - 1) # Include this item (indexing from 0) + j -= int(quantized_memory[i - 1].item()) + else: + recomputable_items.append(i - 1) + + saved_items.reverse() # To get items in the order they were added + + # The maximum runtime that can be achieved within the max_memory constraint + max_runtime = dp[n][quantized_max_memory].item() + + return max_runtime, saved_items, recomputable_items + + +def _optimize_runtime_with_given_memory( + memory: List[float], + runtimes: List[float], + max_memory: float, +) -> Tuple[float, List[int], List[int]]: + SOLVER = config.activation_memory_budget_solver + if SOLVER == "greedy": + return greedy_knapsack(memory, runtimes, max_memory) + elif SOLVER == "ilp": + return ilp_knapsack(memory, runtimes, max_memory) + elif SOLVER == "dp": + return dp_knapsack(memory, runtimes, max_memory) + else: + raise RuntimeError(f"Not aware of memory budget knapsack solver: {SOLVER}") + + +from torch.utils._mode_utils import no_dispatch + + +def estimate_runtime(node): + RUNTIME_MODE = config.activation_memory_budget_runtime_estimator + + def materialize_arg(x): + if isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.Tensor): + shape = list(x.meta["val"].shape) + + def realize_symbol(d): + return hint_int(d, fallback=4096) + + shape = [realize_symbol(s) for s in shape] + return x.meta["val"].new_zeros(shape) + elif isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.SymInt): + return hint_int(x.meta["val"], fallback=4096) + elif isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.SymFloat): + return 1.0 + elif isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.SymBool): + return True + else: + return x + + if RUNTIME_MODE == "testing": + return 1 + + elif RUNTIME_MODE == "profile": + from triton.testing import do_bench + + with no_dispatch(): + args, kwargs = pytree.tree_map(materialize_arg, (node.args, node.kwargs)) + ms = do_bench(lambda: node.target(*args, **kwargs)) + return ms + + elif RUNTIME_MODE == "flops": + # todo(chilli): Normalize this to also return ms + from torch.utils.flop_counter import FlopCounterMode + + args, kwargs = pytree.tree_map(materialize_arg, (node.args, node.kwargs)) + with FlopCounterMode(display=False) as mode: + node.target(*args, **kwargs) + counted_flops = mode.get_total_flops() + return max(counted_flops, 1) + else: + raise RuntimeError(f"Not aware of runtime estimator: {RUNTIME_MODE}") + + def choose_saved_values_set( joint_graph: fx.Graph, node_info: NodeInfo, memory_budget=1 ) -> List[fx.Node]: + if memory_budget > 1 or memory_budget < 0: + raise RuntimeError( + f"The valid ranges for memory budget are 0 <= m <= 1. The provided value is {memory_budget}" + ) min_cut_options = MinCutOptions( ban_if_used_far_apart=config.ban_recompute_used_far_apart, ban_if_long_fusible_chains=config.ban_recompute_long_fusible_chains, @@ -1287,16 +1474,164 @@ def choose_saved_values_set( ban_if_materialized_backward=False, ban_if_not_in_allowlist=False, ) - if memory_budget == 0: return node_info.inputs - runtime_optimized_saved_values, _ = get_saved_values( + runtime_optimized_saved_values, _ = solve_min_cut( joint_graph, node_info, min_cut_options, ) - return runtime_optimized_saved_values + # return runtime_optimized_saved_values + if memory_budget == 1: + return runtime_optimized_saved_values + + def estimate_activations_size(saved_values: List[fx.Node]) -> float: + return sum([_size_of(i) for i in saved_values]) / 1e9 + + min_act_size = estimate_activations_size(node_info.inputs) + max_act_size = estimate_activations_size(runtime_optimized_saved_values) + # The optimized choice is smaller than the inputs anyways + if max_act_size <= min_act_size: + return runtime_optimized_saved_values + + def get_normalized_size(sz): + return (sz / 1e9) / (max_act_size - min_act_size) + + def get_mem_ratio(activations: List[fx.Node]): + return (estimate_activations_size(activations) - min_act_size) / ( + max_act_size - min_act_size + ) + + more_aggressive_options = replace( + min_cut_options, + ban_if_used_far_apart=False, + ban_if_long_fusible_chains=False, + ban_if_materialized_backward=False, + ) + more_aggressive_saved_values, _ = solve_min_cut( + joint_graph, node_info, more_aggressive_options + ) + if get_mem_ratio(more_aggressive_saved_values) < memory_budget: + return more_aggressive_saved_values + + aggressive_options = replace( + more_aggressive_options, + ban_if_not_in_allowlist=False, + ) + aggressive_recomputation_saved_values, banned_nodes = solve_min_cut( + joint_graph, node_info, aggressive_options + ) + + if get_mem_ratio(aggressive_recomputation_saved_values) < memory_budget: + return aggressive_recomputation_saved_values + + from torch._inductor.fx_utils import get_node_storage + + input_storages = {get_node_storage(node) for node in node_info.inputs} + + def get_recomputable_banned_nodes(banned_nodes: List[fx.Node]) -> List[fx.Node]: + return [ + i + for i in banned_nodes + if ( + # Only allow recomputing nodes that are actually required for BW + i.dist_from_bw < int(1e9) # type: ignore[attr-defined] + and get_node_storage(i) not in input_storages + ) + ] + + recomputable_banned_nodes = get_recomputable_banned_nodes(banned_nodes) + + # default: runtime_optimized_saved_values + # more aggressive: more_aggressive_saved_values + # full aggressive: aggressive_recomputation_saved_values + + all_recomputable_banned_nodes = sorted( + recomputable_banned_nodes, key=_size_of, reverse=True + ) + if len(all_recomputable_banned_nodes) == 0: + return node_info.inputs + memories_banned_nodes = [ + get_normalized_size(_size_of(i)) for i in all_recomputable_banned_nodes + ] + runtimes_banned_nodes = [ + estimate_runtime(node) for node in all_recomputable_banned_nodes + ] + from torch.utils._mode_utils import no_dispatch + + def get_saved_values_knapsack(memory_budget): + with no_dispatch(): + ( + expected_runtime, + saved_node_idxs, + recomputable_node_idxs, + ) = _optimize_runtime_with_given_memory( + memories_banned_nodes, runtimes_banned_nodes, max(memory_budget, 0) + ) + dont_ban = set() + for idx in recomputable_node_idxs: + dont_ban.add(all_recomputable_banned_nodes[idx]) + assert dont_ban.issubset(all_recomputable_banned_nodes) + + saved_values, _ = solve_min_cut( + joint_graph, + node_info, + aggressive_options, + dont_ban, + ) + return saved_values, expected_runtime + + if config.visualize_memory_budget_pareto: + options = [] + for sweep_memory_budget in range(100, -1, -5): + saved_values, expected_runtime = get_saved_values_knapsack( + sweep_memory_budget / 100 + ) + options.append( + ( + sweep_memory_budget, + sum(runtimes_banned_nodes) - expected_runtime, + get_mem_ratio(saved_values), + ) + ) + + import matplotlib.pyplot as plt + + x_values = [item[2] for item in options] + y_values = [item[1] for item in options] + + # Plotting the values with updated axis labels and chart title + plt.figure(figsize=(10, 6)) + plt.plot(x_values, y_values, marker="o") + + # Adding labels for each point + for i, txt in enumerate(x_values): + plt.annotate( + f"{txt:.2f}", + (x_values[i], y_values[i]), + textcoords="offset points", + xytext=(0, 10), + ha="center", + ) + + plt.xlabel("Memory Budget") + plt.ylabel("Runtime of Recomputed Components") + plt.title("Pareto Frontier of Memory Budget vs. Recomputation Runtime") + plt.grid(True) + fig = plt.gcf() + plt.show() + fig_name = f"memory_budget_pareto_{get_aot_graph_name()}.png" + fig.savefig(fig_name) + log.warning("Generated Pareto frontier curve at %s", fig_name) + + # todo(chilli): Estimated doesn't align exactly with actual - actual is + # usually less memory than estimated. i'm guessing (actually quite + # unsure about this) that's because estimated is just only including + # tensors we actually banned from recompute, but there may be other + # tensors that we choose to save. + + return get_saved_values_knapsack(memory_budget=memory_budget)[0] def min_cut_rematerialization_partition( @@ -1412,7 +1747,15 @@ def classify_nodes(joint_module): for user in node.users: node.dist_from_bw = min(node.dist_from_bw, user.dist_from_bw + 1) - saved_values = choose_saved_values_set(joint_graph, node_info, memory_budget=1) + memory_budget = config.activation_memory_budget + for node in joint_graph.nodes: + if isinstance(node.meta.get("memory_budget", None), float): + memory_budget = node.meta["memory_budget"] + break + # print("Memory Budget: ", memory_budget) + saved_values = choose_saved_values_set( + joint_graph, node_info, memory_budget=memory_budget + ) # save_for_backward on tensors and stashes symints in autograd .ctx saved_sym_nodes = list(filter(is_sym_node, saved_values)) saved_values = list(filter(lambda n: not is_sym_node(n), saved_values)) From cbb7e3053fdf69d270cfe0ab7bddcac926959b57 Mon Sep 17 00:00:00 2001 From: Shaz Qadeer Date: Sat, 8 Jun 2024 05:52:50 +0000 Subject: [PATCH 523/706] View specialization (#127641) This PR adds specialization shortcuts for converting n-d to 1-d and 1-d to 2-d views. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127641 Approved by: https://github.com/ezyang --- test/dynamo/test_view.py | 41 ++++++++++++++++++++++++++++++++++++++ test/export/test_export.py | 5 +++++ test/export/test_serdes.py | 4 +--- torch/_refs/__init__.py | 10 ++++++++++ 4 files changed, 57 insertions(+), 3 deletions(-) create mode 100644 test/dynamo/test_view.py diff --git a/test/dynamo/test_view.py b/test/dynamo/test_view.py new file mode 100644 index 000000000000..2d63e86af162 --- /dev/null +++ b/test/dynamo/test_view.py @@ -0,0 +1,41 @@ +# Owner(s): ["module: dynamo"] +import torch + +import torch._dynamo +import torch._dynamo.test_case + + +@torch._dynamo.config.patch("capture_scalar_outputs", True) +class ViewTests(torch._dynamo.test_case.TestCase): + def test_view_to_2d(self): + @torch.compile(fullgraph=True, backend="eager") + def f(t, _u0): + u0 = t[0].item() + u1 = t[1].item() + torch._check_is_size(u0) + torch._check_is_size(u1) + n = u0 * u1 + a = torch.randn(n) + return a.view(-1, _u0) + + t = torch.tensor([2, 4], dtype=torch.int32) + f(t, 2) + + def test_view_to_1d(self): + @torch.compile(fullgraph=True, backend="eager") + def f(t, _n): + u0 = t[0].item() + u1 = t[1].item() + torch._check_is_size(u0) + torch._check_is_size(u1) + a = torch.randn(u0, u1) + return a.view(_n) + + t = torch.tensor([2, 4], dtype=torch.int32) + f(t, 8) + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/test/export/test_export.py b/test/export/test_export.py index 5b0c93135ba7..19acbbca39f1 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -149,6 +149,7 @@ class Inp: NON_STRICT_SUFFIX = "_non_strict" RETRACEABILITY_SUFFIX = "_retraceability" +SERDES_SUFFIX = "_serdes" PREDISPATCH_SUFFIX = "_pre_dispatch" @@ -160,6 +161,10 @@ def is_retracebility_test(test_name): return test_name.endswith(RETRACEABILITY_SUFFIX) +def is_serdes_test(test_name): + return test_name.endswith(SERDES_SUFFIX) + + @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support") class TestDynamismExpression(TestCase): def test_export_inline_constraints(self): diff --git a/test/export/test_serdes.py b/test/export/test_serdes.py index bd11cd7f8366..52848134721f 100644 --- a/test/export/test_serdes.py +++ b/test/export/test_serdes.py @@ -23,14 +23,12 @@ def mocked_serder_export(*args, **kwargs): def make_dynamic_cls(cls): - suffix = "_serdes" - cls_prefix = "SerDesExport" test_class = testing.make_test_cls_with_mocked_export( cls, cls_prefix, - suffix, + test_export.SERDES_SUFFIX, mocked_serder_export, xfail_prop="_expected_failure_serdes", ) diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index 68675c751736..ca941f41f07f 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -3649,6 +3649,16 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorL else: return _a + if a.is_contiguous(): + # Special-cases for nd_to_1d + if len(shape) == 1 and a.ndim > 1: + return torch.as_strided(a, [a.numel()], [1]) + # Special-cases for 1d_to_2d + if len(shape) == 2 and a.ndim == 1: + dim0 = shape[0] + dim1 = shape[1] + return torch.as_strided(a, [dim0, dim1], [dim1, 1]) + # Handles general case: a 1+D tensor reshaped into a distinct 1+D shape # NOTE [Reshape Algorithm] From 8a0bc8c9ee5cdcba16d7caa50f7e663037a239af Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Fri, 7 Jun 2024 18:49:24 -0700 Subject: [PATCH 524/706] [fsdp2] simplify fsdp_param logic with DTensorSpec (#128242) as titled, we can use a single DTensorSpec to save the SPMD sharding spec, plus the global shape/stride to simplify the FSDPParam logic Pull Request resolved: https://github.com/pytorch/pytorch/pull/128242 Approved by: https://github.com/awgu --- .../_composable/fsdp/_fsdp_common.py | 30 ++++------- .../_composable/fsdp/_fsdp_param.py | 50 +++++++++---------- torch/distributed/_tensor/api.py | 2 +- 3 files changed, 34 insertions(+), 48 deletions(-) diff --git a/torch/distributed/_composable/fsdp/_fsdp_common.py b/torch/distributed/_composable/fsdp/_fsdp_common.py index 3cb06174703a..85b0192b0f50 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_common.py +++ b/torch/distributed/_composable/fsdp/_fsdp_common.py @@ -3,15 +3,15 @@ from dataclasses import dataclass from enum import auto, Enum -from typing import Any, cast, List, Optional, Tuple +from typing import Any, cast, List, Optional import torch import torch._dynamo.compiled_autograd as ca import torch.distributed as dist import torch.nn as nn from torch.distributed._composable.contract import _get_registry -from torch.distributed._tensor import DeviceMesh, DTensor, Placement -from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta +from torch.distributed._tensor import DeviceMesh, DTensor +from torch.distributed._tensor.placement_types import DTensorSpec @dataclass @@ -109,10 +109,7 @@ def _get_dim0_chunked_size( def _from_local_no_grad( local_tensor: torch.Tensor, - device_mesh: DeviceMesh, - placements: Tuple[Placement, ...], - global_size: torch.Size, - global_stride: Tuple[int, ...], + sharding_spec: DTensorSpec, ) -> DTensor: """ This method is similar to ``DTensor.from_local()`` except that in eager mode @@ -120,29 +117,20 @@ def _from_local_no_grad( """ if not ca.compiled_autograd_enabled: - spec = DTensorSpec( - device_mesh, - placements, - tensor_meta=TensorMeta( - global_size, - global_stride, - local_tensor.dtype, - ), - ) return DTensor( # Use the local tensor directly instead of constructing a new tensor # variable, e.g. with `view_as()`, since this is not differentiable local_tensor, - spec, + sharding_spec, requires_grad=local_tensor.requires_grad, ) else: return DTensor.from_local( local_tensor, - device_mesh, - placements, - shape=global_size, - stride=global_stride, + sharding_spec.mesh, + sharding_spec.placements, + shape=sharding_spec.shape, + stride=sharding_spec.stride, ) diff --git a/torch/distributed/_composable/fsdp/_fsdp_param.py b/torch/distributed/_composable/fsdp/_fsdp_param.py index cf28a8e4fe13..ca12ea74b230 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_param.py +++ b/torch/distributed/_composable/fsdp/_fsdp_param.py @@ -9,9 +9,9 @@ from torch._prims_common import make_contiguous_strides_for from torch.distributed._functional_collectives import AsyncCollectiveTensor -from torch.distributed._tensor import DTensor, Placement, Replicate, Shard +from torch.distributed._tensor import DTensor, Replicate, Shard from torch.distributed._tensor.device_mesh import _mesh_resources -from torch.distributed._tensor.placement_types import DTensorSpec +from torch.distributed._tensor.placement_types import DTensorSpec, Placement, TensorMeta from ._fsdp_api import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy from ._fsdp_common import ( _chunk_with_empty, @@ -128,12 +128,10 @@ class FSDPParam: _sharded_post_forward_param: Optional[nn.Parameter] # ND _unsharded_param: nn.Parameter # ND unsharded_accumulated_grad: Optional[torch.Tensor] # ND - _spmd_placements: Tuple[Placement, ...] - _global_size: torch.Size - _global_stride: Tuple[int, ...] - all_gather_outputs: List[torch.Tensor] # 1D + _sharding_spec: DTensorSpec # DTensor attributes (only defined for DTensor `param`): _tp_spec: DTensorSpec + all_gather_outputs: List[torch.Tensor] # 1D # All-gather extension attributes _extensions_data: ExtensionsData _unsharded_inner_tensors: List[torch.Tensor] @@ -213,13 +211,16 @@ def _init_sharded_param(self, param: nn.Parameter, device: torch.device): ) # TODO: Hard code FSDP + TP; need to support HSDP + TP - self._spmd_placements = ( + self._spmd_placements: Tuple[Placement, ...] = ( Shard(0), self._tp_spec.placements[0], ) - self._global_size = param.size() - self._global_stride = param.stride() + self._sharding_spec = DTensorSpec( + self._spmd_mesh, + self._spmd_placements, + tensor_meta=self._tp_spec.tensor_meta, + ) param_data = cast(DTensor, param)._local_tensor else: self._spmd_mesh = self.mesh_info.mesh @@ -227,8 +228,15 @@ def _init_sharded_param(self, param: nn.Parameter, device: torch.device): self._spmd_placements = (Replicate(), Shard(0)) else: self._spmd_placements = (Shard(0),) - self._global_size = param.size() - self._global_stride = param.stride() + self._sharding_spec = DTensorSpec( + self._spmd_mesh, + self._spmd_placements, + tensor_meta=TensorMeta( + param.size(), + param.stride(), + param.dtype, + ), + ) param_data = param self._orig_size = param_data.size() shard_rank = self.mesh_info.shard_mesh_rank @@ -352,13 +360,7 @@ def init_unsharded_param(self): storage_offset=0, ) if self.is_dtensor: - unsharded_param = _from_local_no_grad( - unsharded_param, - self._tp_spec.mesh, - self._tp_spec.placements, - self._global_size, - self._global_stride, - ) + unsharded_param = _from_local_no_grad(unsharded_param, self._tp_spec) self._unsharded_param = nn.Parameter(unsharded_param) self._unsharded_param.requires_grad_(self.sharded_param.requires_grad) @@ -442,10 +444,7 @@ def to_sharded_dtensor(self, tensor: torch.Tensor) -> DTensor: ) return _from_local_no_grad( tensor, - self._spmd_mesh, - self._spmd_placements, - self._global_size, - self._global_stride, + self._sharding_spec, ) def to_sharded_post_forward_dtensor(self, tensor: torch.Tensor) -> DTensor: @@ -456,13 +455,12 @@ def to_sharded_post_forward_dtensor(self, tensor: torch.Tensor) -> DTensor: assert isinstance(self.post_forward_mesh_info, HSDPMeshInfo) # TODO: Prefer this DTensor to be read-only and generalize the # placement once we support TP. - return _from_local_no_grad( - tensor, + post_forward_sharding_spec = DTensorSpec( self.post_forward_mesh_info.mesh, (Replicate(), Shard(0)), - self._global_size, - self._global_stride, + tensor_meta=self._sharding_spec.tensor_meta, ) + return _from_local_no_grad(tensor, post_forward_sharding_spec) def to_accumulated_grad_if_needed(self) -> None: # Access `_unsharded_param` to bypass the sharded state check since we diff --git a/torch/distributed/_tensor/api.py b/torch/distributed/_tensor/api.py index be887f3ce6ca..5bcd6b033c8f 100644 --- a/torch/distributed/_tensor/api.py +++ b/torch/distributed/_tensor/api.py @@ -240,7 +240,7 @@ def __new__( cls, spec.tensor_meta.shape, strides=spec.tensor_meta.stride, - dtype=spec.tensor_meta.dtype, + dtype=local_tensor.dtype, device=local_tensor.device, layout=local_tensor.layout, requires_grad=requires_grad, From 94165dba7b96f5ac0de5c95560915b7bce3af21e Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sat, 8 Jun 2024 06:29:36 +0000 Subject: [PATCH 525/706] Revert "[dynamo] Inline the getattr of fx graph and proxy graph (#128172)" This reverts commit 662a78f957fb89e53ebeba7deb880561e10ecaf6. Reverted https://github.com/pytorch/pytorch/pull/128172 on behalf of https://github.com/anijain2305 due to pippy tests fail ([comment](https://github.com/pytorch/pytorch/pull/128172#issuecomment-2155835201)) --- test/dynamo/test_inline_inbuilt_nn_modules.py | 3 --- torch/_dynamo/trace_rules.py | 2 -- 2 files changed, 5 deletions(-) diff --git a/test/dynamo/test_inline_inbuilt_nn_modules.py b/test/dynamo/test_inline_inbuilt_nn_modules.py index 0bd7f573e6e9..f7ba32bc15f3 100644 --- a/test/dynamo/test_inline_inbuilt_nn_modules.py +++ b/test/dynamo/test_inline_inbuilt_nn_modules.py @@ -6,7 +6,6 @@ try: from . import ( test_aot_autograd, - test_export, test_functions, test_higher_order_ops, test_misc, @@ -15,7 +14,6 @@ ) except ImportError: import test_aot_autograd - import test_export import test_functions import test_higher_order_ops import test_misc @@ -52,7 +50,6 @@ def make_inline_inbuilt_nn_modules_cls(cls): test_higher_order_ops.HigherOrderOpTests, test_higher_order_ops.FuncTorchHigherOrderOpTests, test_aot_autograd.AotAutogradFallbackTests, - test_export.ExportTests, # test_repros.ReproTests, ] for test in tests: diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 94487cb8551c..4f2bee755ae7 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -3235,8 +3235,6 @@ def _module_dir(m: types.ModuleType): "torch.cuda.amp.autocast_mode", "torch.distributions", "torch.fx._pytree", - "torch.fx._symbolic_trace", - "torch.fx.experimental.proxy_tensor", "torch.fx.passes.shape_prop", "torch.nn", "torch.random", From 6e13c7e8745d80e14d77dc2c5cb1fd666959fbba Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sat, 8 Jun 2024 06:32:28 +0000 Subject: [PATCH 526/706] Revert "[dynamo] Support if cond on UnspecializedNNModuleVariable and add inline tests (#128158)" This reverts commit 747fc35ff54154ddec2a5ab5661f57c28d65c591. Reverted https://github.com/pytorch/pytorch/pull/128158 on behalf of https://github.com/anijain2305 due to pippy tests fail ([comment](https://github.com/pytorch/pytorch/pull/128158#issuecomment-2155835787)) --- test/dynamo/test_inline_inbuilt_nn_modules.py | 62 ------------------- torch/_dynamo/symbolic_convert.py | 8 +-- 2 files changed, 1 insertion(+), 69 deletions(-) delete mode 100644 test/dynamo/test_inline_inbuilt_nn_modules.py diff --git a/test/dynamo/test_inline_inbuilt_nn_modules.py b/test/dynamo/test_inline_inbuilt_nn_modules.py deleted file mode 100644 index f7ba32bc15f3..000000000000 --- a/test/dynamo/test_inline_inbuilt_nn_modules.py +++ /dev/null @@ -1,62 +0,0 @@ -# Owner(s): ["module: dynamo"] - -from torch._dynamo import config -from torch._dynamo.testing import make_test_cls_with_patches - -try: - from . import ( - test_aot_autograd, - test_functions, - test_higher_order_ops, - test_misc, - test_modules, - # test_repros, - ) -except ImportError: - import test_aot_autograd - import test_functions - import test_higher_order_ops - import test_misc - import test_modules - - -test_classes = {} - - -def make_inline_inbuilt_nn_modules_cls(cls): - suffix = "_inline_inbuilt_nn_modules" - - cls_prefix = "InlineInbuiltNNModules" - - test_class = make_test_cls_with_patches( - cls, - cls_prefix, - suffix, - (config, "inline_inbuilt_nn_modules", True), - xfail_prop="_expected_failure_inline_inbuilt_nn_modules", - ) - - test_classes[test_class.__name__] = test_class - # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING - globals()[test_class.__name__] = test_class - test_class.__module__ = __name__ - return test_class - - -tests = [ - test_misc.MiscTests, - test_functions.FunctionTests, - test_modules.NNModuleTests, - test_higher_order_ops.HigherOrderOpTests, - test_higher_order_ops.FuncTorchHigherOrderOpTests, - test_aot_autograd.AotAutogradFallbackTests, - # test_repros.ReproTests, -] -for test in tests: - make_inline_inbuilt_nn_modules_cls(test) -del test - -if __name__ == "__main__": - from torch._dynamo.test_case import run_tests - - run_tests() diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index da04fdfa8584..30f28e2ab265 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -101,7 +101,7 @@ PythonModuleVariable, UnknownVariable, ) -from .variables.nn_module import NNModuleVariable, UnspecializedNNModuleVariable +from .variables.nn_module import NNModuleVariable from .variables.tensor import supported_comparison_ops, SymNodeVariable, TensorVariable from .variables.user_defined import ( RemovableHandleVariable, @@ -414,12 +414,6 @@ def inner(self: "InstructionTranslatorBase", inst: Instruction): if push: self.push(value) self.jump(inst) - elif isinstance(value, UnspecializedNNModuleVariable): - mod = value.value - if truth_fn(mod): - if push: - self.push(value) - self.jump(inst) elif isinstance(value, UserDefinedObjectVariable): try: x = value.var_getattr(self, "__bool__") From 44371bd43276b27aa5da0e223fb7daaf52558767 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sat, 8 Jun 2024 06:35:34 +0000 Subject: [PATCH 527/706] Revert "[dynamo][nn-modules] Trace through nn.Module dunder methods for UnspecializedNNModule (#126578)" This reverts commit 7ede78f9f5d7e6c993faa1a70a5f0b0eaec5640d. Reverted https://github.com/pytorch/pytorch/pull/126578 on behalf of https://github.com/anijain2305 due to pippy tests fail ([comment](https://github.com/pytorch/pytorch/pull/126578#issuecomment-2155836555)) --- test/distributed/test_dynamo_distributed.py | 10 ++- test/dynamo/test_higher_order_ops.py | 16 ++--- ...=> FakeTensorTest.test_embedding_bag_meta} | 0 ...ansformsCPU.test_compile_vmap_hessian_cpu} | 0 ...> TestEmbeddingNN.test_embedding_max_norm} | 0 ...stEmbeddingNN.test_embedding_sparse_basic} | 0 ...ddingNN.test_embedding_sparse_empty_tensor | 0 ...ngNN.test_embeddingbag_include_last_offset | 0 ....test_profiler_pattern_matcher_json_report | 0 .../TestJitGeneratedModule.test_nn_Bilinear | 0 .../TestJitGeneratedModule.test_nn_Embedding | 0 ...dModule.test_nn_EmbeddingBag_discontiguous | 0 ...itGeneratedModule.test_nn_EmbeddingBag_max | 0 ...odule.test_nn_EmbeddingBag_max_padding_idx | 0 ...tGeneratedModule.test_nn_EmbeddingBag_mean | 0 ...dule.test_nn_EmbeddingBag_mean_padding_idx | 0 ...eneratedModule.test_nn_EmbeddingBag_sparse | 0 ...itGeneratedModule.test_nn_EmbeddingBag_sum | 0 ...odule.test_nn_EmbeddingBag_sum_padding_idx | 0 ...atedModule.test_nn_Embedding_discontiguous | 0 ...itGeneratedModule.test_nn_Embedding_sparse | 0 .../TestJitGeneratedModule.test_nn_Linear | 0 ...eneratedModule.test_nn_Linear_no_batch_dim | 0 ...GeneratedModule.test_nn_PReLU_no_batch_dim | 0 .../TestNN.test_ParameterDict | 0 .../TestNN.test_Sequential_iadd | 0 .../TestNN.test_bilinear_broadcasting | 0 ...st_layer_norm_grads_with_create_graph_flag | 0 ..._linear_autograd_device_cpu_bias_weightCOO | 0 ..._linear_autograd_device_cpu_bias_weightCSC | 0 ..._linear_autograd_device_cpu_bias_weightCSR | 0 .../TestNN.test_linear_broadcasting | 0 .../TestNN.test_module_apply_inplace_op | 0 ...metrized_tensor_parametrization_swap_False | 0 ...weight_norm_parametrization_swap_False_cpu | 0 ..._weight_norm_parametrization_swap_True_cpu | 0 ...sorDeviceTypeCPU.test_embedding_jagged_cpu | 0 .../TestPruningNN.test_identity_pruning | 0 .../TestPruningNN.test_random_pruning_0perc | 0 test/profiler/test_profiler.py | 1 - torch/_dynamo/create_parameter_op.py | 20 ------ torch/_dynamo/mutation_guard.py | 3 - torch/_dynamo/side_effects.py | 32 ++++------ torch/_dynamo/symbolic_convert.py | 11 +--- torch/_dynamo/utils.py | 4 +- torch/_dynamo/variables/dicts.py | 6 +- torch/_dynamo/variables/misc.py | 26 +++----- torch/_dynamo/variables/nn_module.py | 40 ++++-------- torch/_dynamo/variables/torch.py | 9 +-- torch/_dynamo/variables/user_defined.py | 63 +++++++------------ 50 files changed, 72 insertions(+), 169 deletions(-) rename test/dynamo_expected_failures/{TestNN.test_overwrite_module_params_on_conversion => FakeTensorTest.test_embedding_bag_meta} (100%) rename test/dynamo_expected_failures/{TestNNParametrization.test_new_spectral_norm_forward_swap_True => TestCompileTransformsCPU.test_compile_vmap_hessian_cpu} (100%) rename test/dynamo_expected_failures/{TestNNParametrization.test_new_spectral_norm_swap_True => TestEmbeddingNN.test_embedding_max_norm} (100%) rename test/dynamo_expected_failures/{TestPruningNN.test_pruning_id_consistency => TestEmbeddingNN.test_embedding_sparse_basic} (100%) create mode 100644 test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_sparse_empty_tensor create mode 100644 test/dynamo_expected_failures/TestEmbeddingNN.test_embeddingbag_include_last_offset create mode 100644 test/dynamo_expected_failures/TestExperimentalUtils.test_profiler_pattern_matcher_json_report create mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Bilinear create mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding create mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_discontiguous create mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_max create mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_max_padding_idx create mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_mean create mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_mean_padding_idx create mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sparse create mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sum create mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sum_padding_idx create mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding_discontiguous create mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding_sparse create mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Linear create mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Linear_no_batch_dim create mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_PReLU_no_batch_dim create mode 100644 test/dynamo_expected_failures/TestNN.test_ParameterDict create mode 100644 test/dynamo_expected_failures/TestNN.test_Sequential_iadd create mode 100644 test/dynamo_expected_failures/TestNN.test_bilinear_broadcasting create mode 100644 test/dynamo_expected_failures/TestNN.test_layer_norm_grads_with_create_graph_flag create mode 100644 test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCOO create mode 100644 test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCSC create mode 100644 test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCSR create mode 100644 test/dynamo_expected_failures/TestNN.test_linear_broadcasting create mode 100644 test/dynamo_expected_failures/TestNN.test_module_apply_inplace_op create mode 100644 test/dynamo_expected_failures/TestNNParametrization.test_errors_unparametrized_tensor_parametrization_swap_False create mode 100644 test/dynamo_expected_failures/TestNNParametrizationDeviceCPU.test_weight_norm_parametrization_swap_False_cpu create mode 100644 test/dynamo_expected_failures/TestNNParametrizationDeviceCPU.test_weight_norm_parametrization_swap_True_cpu create mode 100644 test/dynamo_expected_failures/TestNestedTensorDeviceTypeCPU.test_embedding_jagged_cpu create mode 100644 test/dynamo_expected_failures/TestPruningNN.test_identity_pruning create mode 100644 test/dynamo_expected_failures/TestPruningNN.test_random_pruning_0perc diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py index db44f1ce915d..b31a2f717537 100644 --- a/test/distributed/test_dynamo_distributed.py +++ b/test/distributed/test_dynamo_distributed.py @@ -1084,14 +1084,12 @@ def _(ctx): # far from an exhaustive check of all the expected guards, just check a couple of them. FileCheck().check("""local "L['self']" TYPE_MATCH""").check( """local "L['self']" ID_MATCH""" + ).check(f"""{expected_guard_source} "L['self'].net" TYPE_MATCH""").check( + f"""{expected_guard_source} "L['self'].net" ID_MATCH""" ).check( - f"""{expected_guard_source} "L['self']._modules['net']" TYPE_MATCH""" + f"""{expected_guard_source} "L['self'].net[0]" TYPE_MATCH""" ).check( - f"""{expected_guard_source} "L['self']._modules['net']" ID_MATCH""" - ).check( - f"""{expected_guard_source} "L['self']._modules['net']._modules['0']" TYPE_MATCH""" - ).check( - f"""{expected_guard_source} "L['self']._modules['net']._modules['1']" ID_MATCH""" + f"""{expected_guard_source} "L['self'].net[0]" ID_MATCH""" ).run( GUARDS_FILE.getvalue() ) diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index 43bc69ea403b..9b86a90b02f3 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -5118,10 +5118,10 @@ def wrapper_fn(x): actual, """\ class GraphModule(torch.nn.Module): - def forward(self, L_self_buffers_tensor_constant0_: "f32[3, 3, 3]"): - l_self_buffers_tensor_constant0_ = L_self_buffers_tensor_constant0_ + def forward(self, L_self_tensor_constant0: "f32[3, 3, 3]"): + l_self_tensor_constant0 = L_self_tensor_constant0 - alias_default: "f32[3, 3, 3]" = torch.ops.aten.alias.default(l_self_buffers_tensor_constant0_); l_self_buffers_tensor_constant0_ = None + alias_default: "f32[3, 3, 3]" = torch.ops.aten.alias.default(l_self_tensor_constant0); l_self_tensor_constant0 = None sin_default: "f32[3, 3, 3]" = torch.ops.aten.sin.default(alias_default) @@ -5140,16 +5140,16 @@ def forward(self, L_self_buffers_tensor_constant0_: "f32[3, 3, 3]"): actual, """\ class GraphModule(torch.nn.Module): - def forward(self, L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_0_: "f32[3, 3, 3]", L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_1_: "f32[3, 3, 3]", L_flat_tangents_1_: "f32[3, 3, 3]"): - l_self_modules_fx_const_folded_attrs_parameters_0_ = L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_0_ - l_self_modules_fx_const_folded_attrs_parameters_1_ = L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_1_ + def forward(self, getattr_L_self_FX_CONST_FOLDED_ATTRS_0_: "f32[3, 3, 3]", getattr_L_self_FX_CONST_FOLDED_ATTRS_1_: "f32[3, 3, 3]", L_flat_tangents_1_: "f32[3, 3, 3]"): + getattr_l_self_fx_const_folded_attrs_0_ = getattr_L_self_FX_CONST_FOLDED_ATTRS_0_ + getattr_l_self_fx_const_folded_attrs_1_ = getattr_L_self_FX_CONST_FOLDED_ATTRS_1_ l_flat_tangents_1_ = L_flat_tangents_1_ - _new_zeros_with_same_feature_meta_default: "f32[3, 3, 3]" = torch.ops.aten._new_zeros_with_same_feature_meta.default(l_flat_tangents_1_, l_self_modules_fx_const_folded_attrs_parameters_0_); l_self_modules_fx_const_folded_attrs_parameters_0_ = None + _new_zeros_with_same_feature_meta_default: "f32[3, 3, 3]" = torch.ops.aten._new_zeros_with_same_feature_meta.default(l_flat_tangents_1_, getattr_l_self_fx_const_folded_attrs_0_); getattr_l_self_fx_const_folded_attrs_0_ = None copy__default: "f32[3, 3, 3]" = torch.ops.aten.copy_.default(_new_zeros_with_same_feature_meta_default, l_flat_tangents_1_); _new_zeros_with_same_feature_meta_default = l_flat_tangents_1_ = None - mul_tensor: "f32[3, 3, 3]" = torch.ops.aten.mul.Tensor(copy__default, l_self_modules_fx_const_folded_attrs_parameters_1_); copy__default = l_self_modules_fx_const_folded_attrs_parameters_1_ = None + mul_tensor: "f32[3, 3, 3]" = torch.ops.aten.mul.Tensor(copy__default, getattr_l_self_fx_const_folded_attrs_1_); copy__default = getattr_l_self_fx_const_folded_attrs_1_ = None return (mul_tensor,) """, ) diff --git a/test/dynamo_expected_failures/TestNN.test_overwrite_module_params_on_conversion b/test/dynamo_expected_failures/FakeTensorTest.test_embedding_bag_meta similarity index 100% rename from test/dynamo_expected_failures/TestNN.test_overwrite_module_params_on_conversion rename to test/dynamo_expected_failures/FakeTensorTest.test_embedding_bag_meta diff --git a/test/dynamo_expected_failures/TestNNParametrization.test_new_spectral_norm_forward_swap_True b/test/dynamo_expected_failures/TestCompileTransformsCPU.test_compile_vmap_hessian_cpu similarity index 100% rename from test/dynamo_expected_failures/TestNNParametrization.test_new_spectral_norm_forward_swap_True rename to test/dynamo_expected_failures/TestCompileTransformsCPU.test_compile_vmap_hessian_cpu diff --git a/test/dynamo_expected_failures/TestNNParametrization.test_new_spectral_norm_swap_True b/test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_max_norm similarity index 100% rename from test/dynamo_expected_failures/TestNNParametrization.test_new_spectral_norm_swap_True rename to test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_max_norm diff --git a/test/dynamo_expected_failures/TestPruningNN.test_pruning_id_consistency b/test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_sparse_basic similarity index 100% rename from test/dynamo_expected_failures/TestPruningNN.test_pruning_id_consistency rename to test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_sparse_basic diff --git a/test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_sparse_empty_tensor b/test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_sparse_empty_tensor new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestEmbeddingNN.test_embeddingbag_include_last_offset b/test/dynamo_expected_failures/TestEmbeddingNN.test_embeddingbag_include_last_offset new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestExperimentalUtils.test_profiler_pattern_matcher_json_report b/test/dynamo_expected_failures/TestExperimentalUtils.test_profiler_pattern_matcher_json_report new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Bilinear b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Bilinear new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_discontiguous b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_discontiguous new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_max b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_max new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_max_padding_idx b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_max_padding_idx new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_mean b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_mean new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_mean_padding_idx b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_mean_padding_idx new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sparse b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sparse new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sum b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sum new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sum_padding_idx b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sum_padding_idx new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding_discontiguous b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding_discontiguous new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding_sparse b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding_sparse new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Linear b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Linear new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Linear_no_batch_dim b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Linear_no_batch_dim new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_PReLU_no_batch_dim b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_PReLU_no_batch_dim new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestNN.test_ParameterDict b/test/dynamo_expected_failures/TestNN.test_ParameterDict new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestNN.test_Sequential_iadd b/test/dynamo_expected_failures/TestNN.test_Sequential_iadd new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestNN.test_bilinear_broadcasting b/test/dynamo_expected_failures/TestNN.test_bilinear_broadcasting new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestNN.test_layer_norm_grads_with_create_graph_flag b/test/dynamo_expected_failures/TestNN.test_layer_norm_grads_with_create_graph_flag new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCOO b/test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCOO new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCSC b/test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCSC new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCSR b/test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCSR new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestNN.test_linear_broadcasting b/test/dynamo_expected_failures/TestNN.test_linear_broadcasting new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestNN.test_module_apply_inplace_op b/test/dynamo_expected_failures/TestNN.test_module_apply_inplace_op new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestNNParametrization.test_errors_unparametrized_tensor_parametrization_swap_False b/test/dynamo_expected_failures/TestNNParametrization.test_errors_unparametrized_tensor_parametrization_swap_False new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestNNParametrizationDeviceCPU.test_weight_norm_parametrization_swap_False_cpu b/test/dynamo_expected_failures/TestNNParametrizationDeviceCPU.test_weight_norm_parametrization_swap_False_cpu new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestNNParametrizationDeviceCPU.test_weight_norm_parametrization_swap_True_cpu b/test/dynamo_expected_failures/TestNNParametrizationDeviceCPU.test_weight_norm_parametrization_swap_True_cpu new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestNestedTensorDeviceTypeCPU.test_embedding_jagged_cpu b/test/dynamo_expected_failures/TestNestedTensorDeviceTypeCPU.test_embedding_jagged_cpu new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestPruningNN.test_identity_pruning b/test/dynamo_expected_failures/TestPruningNN.test_identity_pruning new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestPruningNN.test_random_pruning_0perc b/test/dynamo_expected_failures/TestPruningNN.test_random_pruning_0perc new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/profiler/test_profiler.py b/test/profiler/test_profiler.py index 2844663ca6e0..38e83d448fdd 100644 --- a/test/profiler/test_profiler.py +++ b/test/profiler/test_profiler.py @@ -2408,7 +2408,6 @@ def test_profiler_matmul_dim_fp16_pattern(self): num_matched.append(len(pattern.matched_events())) self.assertEqual(num_matched, [i for i, _ in cases]) - @skipIfTorchDynamo("profiler gets ignored if dynamo activated") def test_profiler_pattern_matcher_json_report(self): x = torch.ones((100, 100)) model = nn.Sequential( diff --git a/torch/_dynamo/create_parameter_op.py b/torch/_dynamo/create_parameter_op.py index 601d3c94bdc1..42981fcf1015 100644 --- a/torch/_dynamo/create_parameter_op.py +++ b/torch/_dynamo/create_parameter_op.py @@ -1,6 +1,3 @@ -import threading -from contextlib import contextmanager - import torch doc = """ @@ -39,20 +36,3 @@ def new_parameter_placeholder(size, dtype, device, requires_grad): # Allocating a zero tensor would causes assert failures in autograd. result.untyped_storage().resize_(0) return result - - -_TLS = threading.local() - - -@contextmanager -def do_not_convert_to_tracable_parameter(): - old_flag = getattr(_TLS, "convert_tracable_parameter", True) - _TLS.convert_tracable_parameter = False - try: - yield False - finally: - _TLS.convert_tracable_parameter = old_flag - - -def can_convert_to_tracable_parameter(): - return getattr(_TLS, "convert_tracable_parameter", True) diff --git a/torch/_dynamo/mutation_guard.py b/torch/_dynamo/mutation_guard.py index 00347a012676..1fa24cfa25bb 100644 --- a/torch/_dynamo/mutation_guard.py +++ b/torch/_dynamo/mutation_guard.py @@ -10,9 +10,6 @@ from .utils import ExactWeakKeyDictionary, is_lazy_module, nn_module_has_global_hooks -unpatched_nn_module_init = torch.nn.Module.__init__ - - class MutationTracker: db = ExactWeakKeyDictionary() diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index 1fa1c004e01a..647fae379c54 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -346,7 +346,13 @@ def codegen_save_tempvars(self, cg: PyCodegen): elif isinstance(var.mutable_local, AttributeMutationNew): if isinstance(var, variables.AutogradFunctionContextVariable): unimplemented("AutogradFunctionContextVariable escaped") - cg.load_import_from(utils.__name__, "object_new") + if "__call_nn_module_init" in self.store_attr_mutations.get( + var.mutable_local, {} + ): + assert isinstance(var, variables.UnspecializedNNModuleVariable) + cg.load_import_from(utils.__name__, "nn_module_new") + else: + cg.load_import_from(utils.__name__, "object_new") cg(var.mutable_local.cls_source) cg.extend_output(create_call_function(1, True)) cg.add_cache(var) @@ -473,25 +479,9 @@ def codegen_update_mutated(self, cg: PyCodegen): ] ) elif self.is_attribute_mutation(var): - # Applying mutations involves two steps: 1) Push all - # reconstructed objects onto the stack. 2) Call STORE_ATTR to - # apply the mutations. - # - # Dynamo must ensure that mutations are applied in the same - # order as in the original program. Therefore, two reverse - # operations occur below. - # - # The first reverse operation concerns `suffixes`. We apply - # suffixes in reverse order due to the way Python handles the - # stack. In Step 1, we push all reconstructed objects onto the - # stack, but the item at the top of the stack refers to the last - # attribute in the mutation order. If not fixed, this will apply - # the mutations of attributes in the reverse order. To account - # for this reversal, we iterate through the mutable attributes - # in reverse order. - for name, value in reversed( - self.store_attr_mutations.get(var.mutable_local, {}).items() - ): + for name, value in self.store_attr_mutations.get( + var.mutable_local, {} + ).items(): if isinstance(var, variables.NewGlobalVariable): cg.tx.output.update_co_names(name) cg(value) @@ -499,6 +489,8 @@ def codegen_update_mutated(self, cg: PyCodegen): suffixes.append( [create_instruction("STORE_GLOBAL", argval=name)] ) + elif name == "__call_nn_module_init": + pass # handled in codegen_save_tempvars elif isinstance(value, variables.DeletedVariable): if isinstance( var.mutable_local, AttributeMutationExisting diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 30f28e2ab265..71ed48fbb292 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -415,15 +415,10 @@ def inner(self: "InstructionTranslatorBase", inst: Instruction): self.push(value) self.jump(inst) elif isinstance(value, UserDefinedObjectVariable): - try: - x = value.var_getattr(self, "__bool__") - except exc.ObservedException: - # if __bool__ is missing, trying __len__ to infer a truth value. + x = value.var_getattr(self, "__bool__") + # if __bool__ is missing, trying __len__ to infer a truth value. + if isinstance(x, GetAttrVariable): x = value.var_getattr(self, "__len__") - else: - if isinstance(x, GetAttrVariable): - # if __bool__ is missing, trying __len__ to infer a truth value. - x = value.var_getattr(self, "__len__") # __bool__ or __len__ is function if isinstance(x, UserMethodVariable): diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 238e2b8227cd..60be7898d929 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -2018,12 +2018,12 @@ def object_has_getattribute(value: Any): return False -def get_custom_getattr(value: Any, ignore_nn_module_getattr: bool = False): +def get_custom_getattr(value: Any): try: getattr_fn = inspect.getattr_static(type(value), "__getattr__") except AttributeError: getattr_fn = None - if ignore_nn_module_getattr and getattr_fn is torch.nn.Module.__getattr__: + if getattr_fn is torch.nn.Module.__getattr__: # ignore this case of getattr getattr_fn = None return getattr_fn diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 8391563c8e76..0724a80621f7 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -174,11 +174,7 @@ def python_type(self): def __contains__(self, vt): assert isinstance(vt, VariableTracker) Hashable = ConstDictVariable._HashableTracker - return ( - is_hashable(vt) - and Hashable(vt) in self.items - and not isinstance(self.items[Hashable(vt)], variables.DeletedVariable) - ) + return is_hashable(vt) and Hashable(vt) in self.items def reconstruct(self, codegen): # instructions to load collections.OrderedDict if necessary diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 9ef36eb7f29f..cc0fb7096701 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -14,10 +14,8 @@ import torch.utils._pytree as pytree from .. import config, variables from ..bytecode_transformation import create_call_function, create_instruction -from ..create_parameter_op import do_not_convert_to_tracable_parameter from ..exc import unimplemented from ..guards import GuardBuilder, install_guard -from ..mutation_guard import unpatched_nn_module_init from ..source import AttrSource, GetItemSource, ODictGetItemSource, TypeSource from ..utils import ( check_unspec_or_constant_args, @@ -123,6 +121,7 @@ def call_method( kwargs: "Dict[str, VariableTracker]", ) -> "VariableTracker": inner_fn, source = self._resolved_getattr_and_source(self, name) + if inner_fn is object.__init__: return LambdaVariable(identity) elif inner_fn is torch.nn.Module.__init__: @@ -134,10 +133,12 @@ def call_method( and isinstance(objvar.mutable_local, AttributeMutationNew) and not (args or kwargs) ): - with do_not_convert_to_tracable_parameter(): - return variables.UserFunctionVariable( - unpatched_nn_module_init, source=source - ).call_function(tx, [self.objvar] + args, kwargs) + tx.output.side_effects.store_attr( + objvar, + "__call_nn_module_init", + variables.ConstantVariable.create(True), + ) + return variables.ConstantVariable.create(None) else: unimplemented("super() nn.Module.__init__") elif isinstance(inner_fn, types.FunctionType): @@ -174,19 +175,6 @@ def call_method( self.objvar, UserDefinedObjectVariable ): return self.objvar.method_setattr_standard(tx, *args, **kwargs) - elif inner_fn is object.__delattr__: - attr = args[0] - try: - attr = attr.as_python_constant() - except NotImplementedError: - unimplemented(f"non-const delattr attr: {attr}") - if not tx.output.side_effects.is_attribute_mutation(self.objvar): - unimplemented(f"delattr({self.objvar}, {attr}, ...)") - - tx.output.side_effects.store_attr( - self.objvar, attr, variables.DeletedVariable() - ) - return variables.ConstantVariable(None) unimplemented(f"non-function or method super: {inner_fn}") diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py index 5699d7341429..0a6bad4730dd 100644 --- a/torch/_dynamo/variables/nn_module.py +++ b/torch/_dynamo/variables/nn_module.py @@ -236,7 +236,7 @@ def _custom_getattr_fallback(self, base, tx, name, options): if object_has_getattribute(base): unimplemented("torch.nn.Module with a custom __getattribute__ defined") - getattr_fn = get_custom_getattr(base, ignore_nn_module_getattr=True) + getattr_fn = get_custom_getattr(base) if getattr_fn is None: return None @@ -672,6 +672,7 @@ def gen_source(source, name): if isinstance(args[0], SliceVariable): # Build a TupleVariable of NNModules result = [] + submods = [] # Turn the slice into the list of integers keys = list(range(len(module)))[args[0].as_python_constant()] @@ -685,8 +686,9 @@ def gen_source(source, name): source=src, ) ) + submods.append(submod) - new_module = module[args[0].as_python_constant()] + new_module = torch.nn.Sequential(*submods) new_module_variable = tx.output.register_attr_or_module( new_module, f"{self}.__getitem__(slice)", @@ -700,10 +702,8 @@ def gen_source(source, name): if isinstance(args[0], SymNodeVariable): key = args[0].evaluate_expr(tx.output) - elif args[0].is_python_constant(): - key = args[0].as_python_constant() else: - unimplemented(f"getitem on NNModuleVariable with key {args[0]}") + key = args[0].as_python_constant() submod = module[key] return tx.output.register_attr_or_module( @@ -783,7 +783,7 @@ def __init__(self, value, **kwargs): @functools.lru_cache(None) def _nn_module_method_ids(): # Allow __setattr__ to fall through to base class handler - supported = {torch.nn.Module.__setattr__, torch.nn.Module.__init__} + supported = {torch.nn.Module.__setattr__} return { id(x.__code__) for x in torch.nn.Module.__dict__.values() @@ -791,6 +791,8 @@ def _nn_module_method_ids(): } def unpack_var_sequence(self, tx): + from .builder import VariableBuilder + try: fn = inspect.getattr_static(self.value_type, "__iter__") except AttributeError as e: @@ -801,16 +803,11 @@ def unpack_var_sequence(self, tx): torch.nn.ParameterList.__iter__, torch.nn.Sequential.__iter__, ): - # The program can mutate the nn module object but the saved `value` - # will not reflect the mutations. So, trace through the `__iter__` - # function to reflect any tracked mutations. - return tx.inline_user_function_return( - variables.UserFunctionVariable(fn), - [ - self, - ], - {}, - ).unpack_var_sequence(tx) + assert self.source + return [ + VariableBuilder(tx, source=GetItemSource(self.source, idx))(item) + for idx, item in enumerate(self.value) + ] return super().unpack_var_sequence(tx) @@ -937,17 +934,6 @@ def call_method( # Handle submodules self.is_state_mutated = True - if method is torch.nn.Module.__setattr__ and isinstance( - args[1], variables.DeletedVariable - ): - # Trace through __delattr__ to track mutations on the module - # members like `_modules``. - return tx.inline_user_function_return( - variables.UserFunctionVariable(torch.nn.Module.__delattr__), - [self, args[0]], - kwargs, - ) - return super().call_method(tx, name, args, kwargs) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 36fa0a697032..0b3e28860aaf 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -17,11 +17,7 @@ from ..._guards import TracingContext from .. import config, polyfill, variables from ..codegen import PyCodegen -from ..create_parameter_op import ( - can_convert_to_tracable_parameter, - new_parameter_placeholder, - tracable_create_parameter, -) +from ..create_parameter_op import new_parameter_placeholder, tracable_create_parameter from ..device_interface import get_registered_device_interfaces from ..exc import unimplemented from ..guards import GuardBuilder, install_guard @@ -874,9 +870,6 @@ def call_nn_parameter(cls, tx, data=None, requires_grad=True): if data.source: return cls._nn_param_via_prefix_insert(tx, data, requires_grad) - if not can_convert_to_tracable_parameter(): - unimplemented("Workaround for issues with nn_parameter construction") - try: shape = tuple(data.var_getattr(tx, "shape").as_python_constant()) dtype = data.var_getattr(tx, "dtype").as_python_constant() diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index d5faafcffbed..5b785293911f 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -34,8 +34,7 @@ from torch._guards import TracingContext from .. import variables -from ..create_parameter_op import do_not_convert_to_tracable_parameter -from ..exc import ObservedException, unimplemented +from ..exc import unimplemented from ..guards import GuardBuilder, install_guard from ..source import AttrSource, GetItemSource, ODictGetItemSource, RandomValueSource from ..utils import ( @@ -58,7 +57,10 @@ def is_standard_setattr(val): - return val in (object.__setattr__,) + return val in ( + object.__setattr__, + torch.nn.Module.__setattr__, + ) class UserDefinedVariable(VariableTracker): @@ -376,7 +378,17 @@ def call_function( else UserDefinedObjectVariable, {}, ) - with do_not_convert_to_tracable_parameter(): + if ( + inspect.getattr_static(self.value, "__init__", None) + is torch.nn.Module.__init__ + ): + tx.output.side_effects.store_attr( + var, + "__call_nn_module_init", + variables.ConstantVariable.create(True), + ) + return var + else: var.call_method(tx, "__init__", args, kwargs) return var elif variables.CustomizedDictVariable.is_matching_cls(self.value): @@ -626,10 +638,6 @@ def call_method( else AttrSource(AttrSource(self.source, "__class__"), name) ) # TODO(jansel): add a guard to check for monkey patching? - from ..mutation_guard import unpatched_nn_module_init - - if method is torch.nn.Module.__init__: - method = unpatched_nn_module_init return UserMethodVariable(method, self, source=source).call_function( tx, args, kwargs ) @@ -791,7 +799,7 @@ def _check_for_getattr(self): def _getattr_static(self, name): if ( - isinstance(self.value, PyTreeSpec) + isinstance(self.value, (torch.nn.Module, PyTreeSpec)) or "__slots__" in self.value.__class__.__dict__ or type(self.value) == threading.local ): @@ -804,6 +812,7 @@ def _getattr_static(self, name): return cls_var except AttributeError: pass # __slots__ + # this might call torch.nn.Module.__getattr__ subobj = getattr(self.value, name) else: subobj = inspect.getattr_static(self.value, name) @@ -992,35 +1001,14 @@ def call_hasattr(self, tx, name: str) -> "VariableTracker": install_guard( AttrSource(self.source, name).make_guard(GuardBuilder.HASATTR) ) - if self._check_for_getattribute(): - unimplemented("hasattr with custom __getattribute__") + if self._check_for_getattribute() or self._check_for_getattr(): + unimplemented("hasattr with custom __getattr__") try: self._getattr_static(name) return variables.ConstantVariable.create(True) except AttributeError: - # Now check in __getattr__ function - getattr_fn = self._check_for_getattr() - if isinstance(getattr_fn, types.FunctionType): - # Dynamo is going to trace the __getattr__ function with - # args=name. Set the source accordingly. - new_source = None - if self.source: - new_source = AttrSource(self.source, "__getattr__") - try: - result = variables.UserMethodVariable( - getattr_fn, self, source=new_source - ).call_function(tx, [variables.ConstantVariable.create(name)], {}) - - return variables.ConstantVariable.create( - not isinstance(result, variables.DeletedVariable) - ) - except ObservedException: - return variables.ConstantVariable.create(False) - elif getattr_fn is None: - return variables.ConstantVariable.create(False) - else: - unimplemented("UserDefined with non-function __getattr__") + return variables.ConstantVariable.create(False) def odict_getitem(self, tx, key): from .builder import VariableBuilder @@ -1087,12 +1075,6 @@ def var_getattr(self, tx, name): return super().var_getattr(tx, name) -class RemovableHandleClass: - # Dummy class to pass to python_type of RemovableHandleVariable - # Useful for isinstance check on hooks - pass - - class RemovableHandleVariable(VariableTracker): REMOVED = -1 @@ -1123,6 +1105,3 @@ def reconstruct(self, codegen): return # unreachable due to codegen.add_cache() when the hook is installed super().reconstruct(codegen) - - def python_type(self): - return RemovableHandleClass From 0e6c204642a571d5a7cd60be0caeb9b50faca030 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Fri, 7 Jun 2024 19:14:44 -0700 Subject: [PATCH 528/706] [pipelining] Friendly error message when not traceable (#128276) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128276 Approved by: https://github.com/H-Huang --- torch/distributed/pipelining/_IR.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/torch/distributed/pipelining/_IR.py b/torch/distributed/pipelining/_IR.py index 33243108f468..e58749c581cb 100644 --- a/torch/distributed/pipelining/_IR.py +++ b/torch/distributed/pipelining/_IR.py @@ -1001,11 +1001,21 @@ def _trace_with_export( example_kwargs: Optional[Dict[str, Any]] = None, ) -> ExportedProgram: logger.info("Tracing model ...") - ep = torch.export.export( - mod, - example_args, - example_kwargs, - ) + try: + ep = torch.export.export( + mod, + example_args, + example_kwargs, + ) + except Exception as e: + raise RuntimeError( + "It seems that we cannot capture your model as a full graph. " + "Typical reasons include graph breaks, data/shape-dependent " + "control flow, or missing meta kernels for custom operators. " + "You can use our manual pipeline interfaces, or try to fix the " + "graph breaks, see https://pytorch.org/docs/stable/export.html" + ) from e + return ep @staticmethod From 73d6ec2db6f591d4d45b76e2754a02a15aa67f81 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Fri, 7 Jun 2024 20:56:58 -0700 Subject: [PATCH 529/706] Increase verbosity of FX graph dumps (#128042) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/128042 Approved by: https://github.com/aorenste --- test/dynamo/test_misc.py | 10 ++--- torch/_dynamo/output_graph.py | 5 ++- .../dispatch_and_compile_graph.py | 13 ++++++- .../jit_compile_runtime_wrappers.py | 37 ++++++++++++++++--- torch/_inductor/compile_fx.py | 11 +++++- 5 files changed, 60 insertions(+), 16 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index bcb0fd18818e..ce423eab7d8a 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -698,7 +698,7 @@ def f(x, y, z, n): self.assertExpectedInline( post_grad_graphs, """\ -def forward(self, arg0_1: "f32[3]", arg1_1: "f32[3]", arg2_1: "f32[3]", arg3_1: "f32[3]", arg4_1: "f32[3]"): +def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"): # No stacktrace found for following nodes foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg4_1 = arg2_1 = arg3_1 = arg1_1 = arg0_1 = None return ()""", @@ -757,11 +757,11 @@ def f(x, y, z, n): self.assertExpectedInline( post_grad_graphs, """\ -def forward(self, arg0_1: "f32[3]", arg1_1: "f32[3]", arg2_1: "f32[3]", arg3_1: "f32[3]", arg4_1: "f32[3]"): +def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"): # No stacktrace found for following nodes foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg4_1 = arg2_1 = arg3_1 = arg1_1 = arg0_1 = None - getitem_4: "f32[3]" = foo_default[0] - getitem_5: "f32[3]" = foo_default[1]; foo_default = None + getitem_4: "f32[3][1]cpu" = foo_default[0] + getitem_5: "f32[3][1]cpu" = foo_default[1]; foo_default = None return (getitem_4, getitem_5)""", ) @@ -849,7 +849,7 @@ def f(x, y, z, n): self.assertExpectedInline( post_grad_graphs, """\ -def forward(self, arg0_1: "f32[3]", arg1_1: "f32[3]", arg2_1: "f32[3]", arg3_1: "f32[3]"): +def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu"): # No stacktrace found for following nodes foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg1_1 = arg0_1 = None return ()""", diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 21cd2e889e90..03e7f844cbd4 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -1287,7 +1287,10 @@ def compile_and_call_fx_graph(self, tx, rv, root): "dynamo_flat_name_to_original_fqn" ] = self.dynamo_flat_name_to_original_fqn.copy() - graph_code_log.debug("%s", lazy_format_graph_code(name, gm)) + graph_code_log.debug( + "%s", + lazy_format_graph_code(name, gm, include_stride=True, include_device=True), + ) torch._logging.trace_structured( "dynamo_output_graph", lambda: {"sizes": self.get_graph_sizes_structured()}, diff --git a/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py b/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py index 1a6f1c7dce1e..c956d58b645e 100644 --- a/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py +++ b/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py @@ -186,11 +186,20 @@ def _map_assigned_buffer_to_proxy(_mod, name, buffer): if aot_config.enable_log: aot_graphs_log.info( - "%s", lazy_format_graph_code("Forward graph", fw_module, aot_config.aot_id) + "%s", + lazy_format_graph_code( + "Forward graph", + fw_module, + aot_config.aot_id, + include_stride=True, + include_device=True, + ), ) trace_structured( "aot_forward_graph", - payload_fn=lambda: fw_module.print_readable(print_output=False), + payload_fn=lambda: fw_module.print_readable( + print_output=False, include_stride=True, include_device=True + ), ) # TODO: should factor this into a separate function for export that always only returns just the graph. diff --git a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py index 9eff7b20c04b..b8093b3cdc98 100644 --- a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py @@ -247,11 +247,20 @@ def aot_dispatch_autograd( if aot_config.enable_log: aot_joint_log.info( - "%s", lazy_format_graph_code("Joint graph", fx_g, aot_config.aot_id) + "%s", + lazy_format_graph_code( + "Joint graph", + fx_g, + aot_config.aot_id, + include_stride=True, + include_device=True, + ), ) trace_structured( "aot_joint_graph", - payload_fn=lambda: fx_g.print_readable(print_output=False), + payload_fn=lambda: fx_g.print_readable( + print_output=False, include_stride=True, include_device=True + ), ) with torch.no_grad(): @@ -389,19 +398,35 @@ def aot_dispatch_autograd( if aot_config.enable_log: aot_graphs_log.info( "%s", - lazy_format_graph_code("Forward graph", fw_module, aot_config.aot_id), + lazy_format_graph_code( + "Forward graph", + fw_module, + aot_config.aot_id, + include_stride=True, + include_device=True, + ), ) aot_graphs_log.info( "%s", - lazy_format_graph_code("Backward graph", bw_module, aot_config.aot_id), + lazy_format_graph_code( + "Backward graph", + bw_module, + aot_config.aot_id, + include_stride=True, + include_device=True, + ), ) trace_structured( "aot_forward_graph", - payload_fn=lambda: fw_module.print_readable(print_output=False), + payload_fn=lambda: fw_module.print_readable( + print_output=False, include_stride=True, include_device=True + ), ) trace_structured( "aot_backward_graph", - payload_fn=lambda: bw_module.print_readable(print_output=False), + payload_fn=lambda: bw_module.print_readable( + print_output=False, include_stride=True, include_device=True + ), ) with track_graph_compiling(aot_config, "forward"): diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index ce7d8f6e9b14..618df9fd8ff6 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -744,10 +744,17 @@ def fx_codegen_and_compile( # has some issues with memory in training _recursive_post_grad_passes(gm, is_inference=is_inference) V.debug.fx_graph_transformed(gm, example_inputs) - post_grad_graphs_log.debug("%s", lazy_format_graph_code("AFTER POST GRAD", gm)) + post_grad_graphs_log.debug( + "%s", + lazy_format_graph_code( + "AFTER POST GRAD", gm, include_stride=True, include_device=True + ), + ) trace_structured( "inductor_post_grad_graph", - payload_fn=lambda: gm.print_readable(print_output=False), + payload_fn=lambda: gm.print_readable( + print_output=False, include_stride=True, include_device=True + ), ) if config.is_fbcode(): log_optimus_to_scuba( From 695502ca653d25be23be61a909587f110067e859 Mon Sep 17 00:00:00 2001 From: cyy Date: Sat, 8 Jun 2024 08:06:31 +0000 Subject: [PATCH 530/706] [3/N] Change static functions in headers to inline (#128194) Follows #127764 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128194 Approved by: https://github.com/ezyang, https://github.com/Skylion007 --- aten/src/ATen/TensorIndexing.h | 38 +++++++++---------- aten/src/ATen/core/Generator.h | 6 +-- aten/src/ATen/core/TensorBase.h | 2 +- .../ATen/core/dispatch/DispatchKeyExtractor.h | 2 +- aten/src/ATen/core/stack.h | 38 +++++++++---------- aten/src/ATen/native/ConvUtils.h | 32 ++++++++-------- aten/src/ATen/native/cpu/Loops.h | 14 +++---- aten/src/ATen/native/cpu/Reduce.h | 18 ++++----- 8 files changed, 74 insertions(+), 76 deletions(-) diff --git a/aten/src/ATen/TensorIndexing.h b/aten/src/ATen/TensorIndexing.h index fc951207f009..eb36c0e02fa4 100644 --- a/aten/src/ATen/TensorIndexing.h +++ b/aten/src/ATen/TensorIndexing.h @@ -197,7 +197,7 @@ TORCH_API std::ostream& operator<<( const std::vector& tensor_indices); namespace impl { -static inline Tensor applySlice( +inline Tensor applySlice( const Tensor& self, int64_t dim, c10::SymInt start, @@ -227,7 +227,7 @@ static inline Tensor applySlice( dim, std::move(start), std::move(stop), std::move(step)); } -static inline Tensor applySelect( +inline Tensor applySelect( const Tensor& self, int64_t dim, SymInt index, @@ -266,9 +266,7 @@ static inline Tensor applySelect( return self.select_symint(dim, std::move(index)); } -static inline Tensor boolToIndexingTensorCPUOrCUDA( - const Tensor& self, - bool value) { +inline Tensor boolToIndexingTensorCPUOrCUDA(const Tensor& self, bool value) { // booleans add a dimension of size 1. true indexes this dimension as if 0:, // false as empty. if (value) { @@ -278,7 +276,7 @@ static inline Tensor boolToIndexingTensorCPUOrCUDA( } } -static inline Tensor boolToIndexingTensorNonNativeDeviceType( +inline Tensor boolToIndexingTensorNonNativeDeviceType( const Tensor& self, bool value) { // booleans add a dimension of size 1. true indexes this dimension as if 0:, @@ -290,7 +288,7 @@ static inline Tensor boolToIndexingTensorNonNativeDeviceType( } } -static inline Tensor boolToIndexingTensor( +inline Tensor boolToIndexingTensor( const Tensor& self, bool value, const at::Device& self_device) { @@ -301,13 +299,13 @@ static inline Tensor boolToIndexingTensor( } } -static inline Tensor scalarToTensorNonNativeDeviceType( +inline Tensor scalarToTensorNonNativeDeviceType( const Scalar& v, const TensorOptions& options) { return at::scalar_tensor(v, options); } -static inline void recordTensorIndex( +inline void recordTensorIndex( const Tensor& tensor, std::vector& outIndices, int64_t* dim_ptr) { @@ -317,7 +315,7 @@ static inline void recordTensorIndex( (*dim_ptr)++; }; -static inline c10::List<::std::optional> typeConvertIndices( +inline c10::List<::std::optional> typeConvertIndices( const Tensor& /*self*/, std::vector&& indices) { c10::List<::std::optional> converted_inds; @@ -338,7 +336,7 @@ static inline c10::List<::std::optional> typeConvertIndices( // construct a `std::vector` container to be consumed by the C++ // `count_specified_dimensions` function, which adds 100s of nanoseconds // overhead and is undesirable. -static inline int64_t count_specified_dimensions( +inline int64_t count_specified_dimensions( const ArrayRef& indices) { // Count the number of indexed dimensions (everything but ellipsis and None) int64_t count = 0; @@ -372,7 +370,7 @@ static inline int64_t count_specified_dimensions( // // The rest of the functions are in `at::indexing::impl` namespace, signifying // that they shouldn't be used from Python indexing implementation. -static inline Tensor scalarToTensor( +inline Tensor scalarToTensor( const Scalar& v, const TensorOptions& options, const at::Device& self_device) { @@ -387,7 +385,7 @@ static inline Tensor scalarToTensor( // To match numpy semantics: // As a special case for backwards compatibility, // strip away unit dimensions from the left of 'src' -static inline SymIntArrayRef slicePrefix1sSize(const SymIntArrayRef& sizes) { +inline SymIntArrayRef slicePrefix1sSize(const SymIntArrayRef& sizes) { size_t first_non1_src = sizes.size(); for (const auto i : c10::irange(sizes.size())) { // Unbacked SymInt has different behavior, but this is sound because @@ -402,7 +400,7 @@ static inline SymIntArrayRef slicePrefix1sSize(const SymIntArrayRef& sizes) { return sizes.slice(first_non1_src); } -static inline void copy_to(const Tensor& dst, const Tensor& src) { +inline void copy_to(const Tensor& dst, const Tensor& src) { if (dst.sym_sizes().equals(src.sym_sizes())) { // A shortcut to avoid generating hard-coded constant sizes during tracing. // This is not a perfect solution: when src & dst have different shapes, @@ -421,7 +419,7 @@ static inline void copy_to(const Tensor& dst, const Tensor& src) { // See NOTE [ Setting `disable_slice_optimization` when calling C++ tensor // indexing functions from Python ] -static inline Tensor handleDimInMultiDimIndexing( +inline Tensor handleDimInMultiDimIndexing( const Tensor& prev_dim_result, const Tensor& original_tensor, const TensorIndex& index, @@ -509,7 +507,7 @@ static inline Tensor handleDimInMultiDimIndexing( namespace impl { // This mirrors `applySlicing` in // torch/csrc/autograd/python_variable_indexing.cpp -static inline Tensor applySlicing( +inline Tensor applySlicing( const Tensor& self, const ArrayRef& indices, std::vector& outIndices, @@ -550,13 +548,13 @@ static inline Tensor applySlicing( } } // namespace impl -static inline Tensor dispatch_index( +inline Tensor dispatch_index( const Tensor& self, std::vector&& indices) { return self.index(impl::typeConvertIndices(self, std::move(indices))); } -static inline Tensor dispatch_index_put_( +inline Tensor dispatch_index_put_( Tensor& self, std::vector&& indices, const Tensor& value) { @@ -598,7 +596,7 @@ static inline Tensor dispatch_index_put_( // torch/csrc/autograd/python_variable_indexing.cpp See NOTE [ Setting // `disable_slice_optimization` when calling C++ tensor indexing functions from // Python ] -static inline Tensor get_item( +inline Tensor get_item( const Tensor& self, const ArrayRef& indices, bool disable_slice_optimization = false) { @@ -664,7 +662,7 @@ static inline Tensor get_item( // torch/csrc/autograd/python_variable_indexing.cpp for "the assigned value is a // Tensor" case See NOTE [ Setting `disable_slice_optimization` when calling C++ // tensor indexing functions from Python ] -static inline void set_item( +inline void set_item( const Tensor& self, const ArrayRef& indices, const Tensor& value, diff --git a/aten/src/ATen/core/Generator.h b/aten/src/ATen/core/Generator.h index 6b76db5d0686..297b805f407b 100644 --- a/aten/src/ATen/core/Generator.h +++ b/aten/src/ATen/core/Generator.h @@ -150,7 +150,7 @@ Generator make_generator(Args&&... args) { * the backend generator type (CPU/CUDAGeneratorImpl etc.) */ template -static inline T * check_generator(std::optional gen) { +inline T * check_generator(std::optional gen) { TORCH_CHECK(gen.has_value(), "Expected Generator but received nullopt"); TORCH_CHECK(gen->defined(), "Generator with undefined implementation is not allowed"); TORCH_CHECK(T::device_type() == gen->device().type(), "Expected a '", T::device_type(), "' device type for generator but found '", gen->device().type(), "'"); @@ -164,7 +164,7 @@ static inline T * check_generator(std::optional gen) { * the backend generator type (CPU/CUDAGeneratorImpl etc.) */ template -static inline T* get_generator_or_default(const std::optional& gen, const Generator& default_gen) { +inline T* get_generator_or_default(const std::optional& gen, const Generator& default_gen) { return gen.has_value() && gen->defined() ? check_generator(gen) : check_generator(default_gen); } @@ -177,7 +177,7 @@ namespace detail { * - The new state tensor must be a torch.ByteTensor * - Data of the new state tensor must be contiguous */ -static inline void check_rng_state(const c10::TensorImpl& new_state) { +inline void check_rng_state(const c10::TensorImpl& new_state) { TORCH_CHECK_TYPE( new_state.layout() == kStrided && new_state.device().type() == kCPU && new_state.dtype() == kByte, "RNG state must be a torch.ByteTensor" diff --git a/aten/src/ATen/core/TensorBase.h b/aten/src/ATen/core/TensorBase.h index 0188e546179b..7218ee56689c 100644 --- a/aten/src/ATen/core/TensorBase.h +++ b/aten/src/ATen/core/TensorBase.h @@ -953,7 +953,7 @@ TensorBase make_tensor_base(Args&&... args) { } // namespace detail -static inline DispatchKey legacyExtractDispatchKey(const TensorBase& t) { +inline DispatchKey legacyExtractDispatchKey(const TensorBase& t) { return legacyExtractDispatchKey(t.key_set()); } diff --git a/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h b/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h index 46c291bada30..4a345facaa94 100644 --- a/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h +++ b/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h @@ -21,7 +21,7 @@ namespace impl { // on TLS. // // NB: If there is no valid dispatch key, this will return Undefined -static inline DispatchKeySet computeDispatchKeySet( +inline DispatchKeySet computeDispatchKeySet( DispatchKeySet ks, // The key mask lets us eliminate (by zero entries) keys which should not // be considered for dispatch. There are two cases when we use this: diff --git a/aten/src/ATen/core/stack.h b/aten/src/ATen/core/stack.h index 5dc89da6c562..6372a3ccb556 100644 --- a/aten/src/ATen/core/stack.h +++ b/aten/src/ATen/core/stack.h @@ -66,51 +66,51 @@ class Operation { // treat the last N elements of the stack as a list, looking up // element i -static inline IValue& peek(Stack& stack, size_t i, size_t N) { +inline IValue& peek(Stack& stack, size_t i, size_t N) { // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions) return *(stack.end() - N + i); } -static inline IValue& peek(Stack* stack, size_t i, size_t N) { +inline IValue& peek(Stack* stack, size_t i, size_t N) { return peek(*stack, i, N); } -static inline const IValue& peek(const Stack& stack, size_t i, size_t N) { +inline const IValue& peek(const Stack& stack, size_t i, size_t N) { // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions) return *(stack.end() - N + i); } -static inline const IValue& peek(const Stack* stack, size_t i, size_t N) { +inline const IValue& peek(const Stack* stack, size_t i, size_t N) { return peek(*stack, i, N); } // treat the last N elements of the stack as a list, looking up the // slice starting at index i and having length len -static inline at::ArrayRef peekSlice( +inline at::ArrayRef peekSlice( const Stack& stack, size_t i, size_t len, size_t N) { return at::ArrayRef(stack).slice(stack.size() - N + i, len); } -static inline at::ArrayRef last(const Stack& stack, size_t N) { +inline at::ArrayRef last(const Stack& stack, size_t N) { return peekSlice(stack, 0, N, N); } -static inline at::ArrayRef last(const Stack* stack, size_t N) { +inline at::ArrayRef last(const Stack* stack, size_t N) { return last(*stack, N); } -static inline void drop(Stack& stack, size_t n) { +inline void drop(Stack& stack, size_t n) { // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions) stack.erase(stack.end() - n, stack.end()); } -static inline void drop(Stack* stack, size_t n) { +inline void drop(Stack* stack, size_t n) { drop(*stack, n); } -static inline IValue pop(Stack& stack) { +inline IValue pop(Stack& stack) { auto r = std::move(stack.back()); stack.pop_back(); return r; } -static inline IValue pop(Stack* stack) { +inline IValue pop(Stack* stack) { return pop(*stack); } -static inline std::vector pop(Stack& stack, size_t n) { +inline std::vector pop(Stack& stack, size_t n) { std::vector result; result.reserve(n); for (const auto i : c10::irange(n)) { @@ -127,7 +127,7 @@ static inline std::vector pop(Stack& stack, size_t n) { // b = pop(stack).toTensor(); // a = pop(stack).toInt(); template -static inline void pop(Stack& stack, Types&... args) { +inline void pop(Stack& stack, Types&... args) { size_t i = 0; constexpr size_t N = sizeof...(args); (void)std::initializer_list{ @@ -135,15 +135,15 @@ static inline void pop(Stack& stack, Types&... args) { drop(stack, N); } template -static inline void pop(Stack* stack, Types&... args) { +inline void pop(Stack* stack, Types&... args) { pop(*stack, args...); } template -static inline void push_one(Stack& stack, Type&& arg) { +inline void push_one(Stack& stack, Type&& arg) { stack.emplace_back(std::forward(arg)); } -static inline void push_one(Stack& stack, c10::TensorOptions options) { +inline void push_one(Stack& stack, c10::TensorOptions options) { stack.emplace_back(c10::typeMetaToScalarType(options.dtype())); stack.emplace_back(options.layout()); stack.emplace_back(options.device()); @@ -151,15 +151,15 @@ static inline void push_one(Stack& stack, c10::TensorOptions options) { } template -static inline void push(Stack& stack, Types&&... args) { +inline void push(Stack& stack, Types&&... args) { (void)std::initializer_list{(push_one(stack, std::forward(args)), 0)...}; } template -static inline void push(Stack* stack, Types&&... args) { +inline void push(Stack* stack, Types&&... args) { return push(*stack, std::forward(args)...); } template -static inline void push_list_elements(Stack& stack, const c10::List& elements) { +inline void push_list_elements(Stack& stack, const c10::List& elements) { for (T elem : elements) { stack.push_back(std::move(elem)); } diff --git a/aten/src/ATen/native/ConvUtils.h b/aten/src/ATen/native/ConvUtils.h index 446bbeccc223..4c77c983c295 100644 --- a/aten/src/ATen/native/ConvUtils.h +++ b/aten/src/ATen/native/ConvUtils.h @@ -75,7 +75,7 @@ namespace { } } -static inline bool cudnnv8_enabled_check_debug() { +inline bool cudnnv8_enabled_check_debug() { static bool cudnnv8_flag = c10::utils::check_env("TORCH_CUDNN_V8_API_DISABLED") != true; static bool cudnnv8_debug = c10::utils::check_env("TORCH_CUDNN_V8_API_DEBUG") == true; static uint8_t cudnnv8_debugcount = 0; @@ -86,7 +86,7 @@ static inline bool cudnnv8_enabled_check_debug() { return cudnnv8_flag == 1; } -static inline bool cudnnv8_use_heur_mode_b() { +inline bool cudnnv8_use_heur_mode_b() { return is_cudnnv8_heuristic_mode_b(); } @@ -186,7 +186,7 @@ static void check_args(CheckedFrom c, IntArrayRef args, size_t expected_size, co // (which the user can change) and computed inputs (which the user can // only indirectly affect). It would be an interesting exercise to // come up with a general framework to handle such situations.) -static void convolution_shape_check( +inline void convolution_shape_check( CheckedFrom c, const TensorGeometryArg& input, const TensorGeometryArg& weight, const TensorGeometryArg& output, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups) @@ -212,7 +212,7 @@ static void convolution_shape_check( // takes an extra output_padding argument to resolve the ambiguity. template -static inline std::vector _conv_output_size( +inline std::vector _conv_output_size( ArrayRef input_size, ArrayRef weight_size, ArrayRef padding, ArrayRef stride, ArrayRef dilation = ArrayRef() ) { @@ -231,14 +231,14 @@ static inline std::vector _conv_output_size( return output_size; } -static inline std::vector conv_output_size( +inline std::vector conv_output_size( IntArrayRef input_size, IntArrayRef weight_size, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation = IntArrayRef() ) { return _conv_output_size(input_size, weight_size, padding, stride, dilation); } -static inline std::vector conv_output_size( +inline std::vector conv_output_size( SymIntArrayRef input_size, SymIntArrayRef weight_size, SymIntArrayRef padding, SymIntArrayRef stride, SymIntArrayRef dilation = SymIntArrayRef() ) { @@ -264,14 +264,14 @@ std::vector _conv_input_size( return input_size; } -static inline std::vector conv_input_size( +inline std::vector conv_input_size( SymIntArrayRef output_size, SymIntArrayRef weight_size, SymIntArrayRef padding, SymIntArrayRef output_padding, SymIntArrayRef stride, SymIntArrayRef dilation, c10::SymInt groups ) { return _conv_input_size(output_size, weight_size, padding, output_padding, stride, dilation, groups); } -static inline std::vector conv_input_size( +inline std::vector conv_input_size( IntArrayRef output_size, IntArrayRef weight_size, IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups ) { @@ -295,27 +295,27 @@ std::vector _conv_weight_size( return weight_size; } -static inline std::vector conv_weight_size( +inline std::vector conv_weight_size( SymIntArrayRef input_size, SymIntArrayRef output_size, SymIntArrayRef padding, SymIntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups ) { return _conv_weight_size(input_size, output_size, padding, output_padding, stride, dilation, groups); } -static inline std::vector conv_weight_size( +inline std::vector conv_weight_size( IntArrayRef input_size, IntArrayRef output_size, IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups ) { return _conv_weight_size(input_size, output_size, padding, output_padding, stride, dilation, groups); } -static inline Tensor reshape_bias(int64_t dim, const Tensor& bias) { +inline Tensor reshape_bias(int64_t dim, const Tensor& bias) { std::vector shape(dim, 1); shape[1] = -1; return bias.reshape(shape); } -static inline at::MemoryFormat cudnn_conv_suggest_memory_format(const at::Tensor& input, const at::Tensor& weight) { +inline at::MemoryFormat cudnn_conv_suggest_memory_format(const at::Tensor& input, const at::Tensor& weight) { // disable NHWC for float64 input. if (!at::detail::getCUDAHooks().compiledWithCuDNN() || input.scalar_type() == at::kDouble || @@ -351,7 +351,7 @@ TORCH_API void _cudnn_set_conv_benchmark_empty_cache(bool enable); TORCH_API bool _cudnn_get_conv_benchmark_empty_cache(); -static inline bool miopen_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) { +inline bool miopen_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) { // disable NHWC for float64 input. if (!at::detail::getCUDAHooks().compiledWithMIOpen() || @@ -378,7 +378,7 @@ static inline bool miopen_conv_use_channels_last(const at::Tensor& input, const return can_use_miopen_channels_last_2d || can_use_miopen_channels_last_3d; } -static inline bool mkldnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) { +inline bool mkldnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) { // disable NHWC for float64 input. if (input.scalar_type() == at::kDouble || @@ -405,7 +405,7 @@ static inline bool mkldnn_conv_use_channels_last(const at::Tensor& input, const return can_use_mkldnn_channels_last_2d || can_use_mkldnn_channels_last_3d; } -static inline bool thnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) { +inline bool thnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) { auto input_memory_format = input.suggest_memory_format(); auto weight_memory_format = weight.suggest_memory_format(); @@ -417,7 +417,7 @@ static inline bool thnn_conv_use_channels_last(const at::Tensor& input, const at return can_use_thnn_channels_last_2d; } -static inline bool xpu_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) { +inline bool xpu_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) { // check layout only for xpu tensor. if (!input.is_xpu() || !weight.is_xpu()) { diff --git a/aten/src/ATen/native/cpu/Loops.h b/aten/src/ATen/native/cpu/Loops.h index 08c3bbe43500..7d87df45c1c5 100644 --- a/aten/src/ATen/native/cpu/Loops.h +++ b/aten/src/ATen/native/cpu/Loops.h @@ -82,7 +82,7 @@ dereference_vec(char* C10_RESTRICT data[], const typename traits::result_type& o template ::result_type>::value>::type* = nullptr> -static inline void +inline void execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t n, func_t&& op) { using traits = function_traits; using result_type = typename traits::result_type; @@ -97,7 +97,7 @@ execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t template ::result_type>::value>::type* = nullptr> -static inline void +inline void execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t n, func_t&& op) { using traits = function_traits; for (; i < n; i++) { @@ -111,7 +111,7 @@ execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t // Basic loop operation (one output, N inputs). May be auto-vectorized // by the compiler. Supports inputs and outputs of different types. template -static inline void +inline void basic_loop(char* C10_RESTRICT data[], const int64_t* strides_, int64_t i, int64_t n, func_t&& op) { using traits = function_traits; constexpr int ntensors = traits::arity + 1; @@ -166,7 +166,7 @@ void handle_tuple_outputs(char* C10_RESTRICT data[], // 2. Iterate over the members of the returned tuple, set the corresponding // output tensor by the tuple member in `handle_tuple_outputs` function. template -static inline void +inline void multiple_outputs_loop(char* C10_RESTRICT data[], const int64_t* strides_, int64_t i, int64_t n, func_t&& op) { using traits = function_traits; @@ -195,7 +195,7 @@ multiple_outputs_loop(char* C10_RESTRICT data[], const int64_t* strides_, int64_ // a scalar (stride 0). It's position is indicated by the argument `S`. If `S` // is 0, then there are no scalar inputs. template -static inline void +inline void vectorized_loop(char** C10_RESTRICT data_, int64_t n, int64_t S, func_t&& op, vec_func_t&& vop) { using traits = function_traits; using scalar_t = typename function_traits::result_type; @@ -228,7 +228,7 @@ vectorized_loop(char** C10_RESTRICT data_, int64_t n, int64_t S, func_t&& op, ve template -static inline void unroll_contiguous_scalar_checks( +inline void unroll_contiguous_scalar_checks( const int64_t* /*strides*/, std::index_sequence<>, cb_t&& cb) { @@ -236,7 +236,7 @@ static inline void unroll_contiguous_scalar_checks( } template -static inline void unroll_contiguous_scalar_checks( +inline void unroll_contiguous_scalar_checks( const int64_t* strides, std::index_sequence, cb_t&& cb) { diff --git a/aten/src/ATen/native/cpu/Reduce.h b/aten/src/ATen/native/cpu/Reduce.h index 26155373be58..37bd32d1c4c1 100644 --- a/aten/src/ATen/native/cpu/Reduce.h +++ b/aten/src/ATen/native/cpu/Reduce.h @@ -21,21 +21,21 @@ using namespace vec; // reduction that is contiguous over the input in dim 0 template -static inline bool is_contiguous_reduction(const int64_t* strides) { +inline bool is_contiguous_reduction(const int64_t* strides) { return strides[0] == 0 && strides[1] == sizeof(typename traits::arg2_t); } // reduction that is contiguous over the input in dim 1 template -static inline bool is_outer_reduction(const int64_t* strides) { +inline bool is_outer_reduction(const int64_t* strides) { return strides[0] == 0 && strides[2] == sizeof(typename traits::result_type) && strides[3] == sizeof(typename traits::arg2_t); } template -static inline void vectorized_reduction(char** data, int64_t n, int64_t stride, +inline void vectorized_reduction(char** data, int64_t n, int64_t stride, func_t op, vec_func_t vop, bool reduce) { VEC_LOOP_HEADER(func_t, data) const char* in1_ptr = data[1]; @@ -69,7 +69,7 @@ static inline void vectorized_reduction(char** data, int64_t n, int64_t stride, } template -static inline void UNARY_OUTER_LOOP(char* data[2], const int64_t strides[2], int64_t n, F f) { +inline void UNARY_OUTER_LOOP(char* data[2], const int64_t strides[2], int64_t n, F f) { for (const auto j C10_UNUSED : c10::irange(n)) { f(); data[0] += strides[0]; @@ -79,7 +79,7 @@ static inline void UNARY_OUTER_LOOP(char* data[2], const int64_t strides[2], int // computes the reduction out = op(out, in) template -static inline void vectorized_inner_reduction(char** data, int64_t n, func_t op, vec_func_t vop) { +inline void vectorized_inner_reduction(char** data, int64_t n, func_t op, vec_func_t vop) { VEC_LOOP_HEADER(func_t, data) int64_t vector_stride = 4 * Vec::size() * sizeof(scalar_t); int64_t count = n / (4 * Vec::size()); @@ -93,7 +93,7 @@ static inline void vectorized_inner_reduction(char** data, int64_t n, func_t op, // computes the reduction out = op(out, in) template -static inline void vectorized_outer_reduction(char** data, int64_t inner_stride, int64_t size0, int64_t size1, func_t op, vec_func_t vop) { +inline void vectorized_outer_reduction(char** data, int64_t inner_stride, int64_t size0, int64_t size1, func_t op, vec_func_t vop) { VEC_LOOP_HEADER(func_t, data) // reduce down each column of 4 * Vec::size() elements (128 or 256 bytes) @@ -132,13 +132,13 @@ static void set_results(const res_t result, const TensorIteratorBase &iter, cons } template -static inline typename std::enable_if::type +inline typename std::enable_if::type for_each_in_tuple(const std::tuple& /*t*/, const TensorIteratorBase& /*iter*/, const int /*num_outputs*/) { return i; } template -static inline typename std::enable_if::type +inline typename std::enable_if::type for_each_in_tuple(const std::tuple& t, const TensorIteratorBase &iter, const int num_outputs) { if (i < (size_t)num_outputs) { set_result(i, std::get(t), iter, num_outputs); @@ -286,7 +286,7 @@ void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vo // when reduction is on most inner dimension (dim 0 in TensorIterator) // and input has contiguous most inner dimension, `binary_kernel_reduce_lastdim` // can be used. -static inline bool is_reduce_lastdim(TensorIteratorBase& iter) { +inline bool is_reduce_lastdim(TensorIteratorBase& iter) { return iter.num_reduce_dims() == 1 && iter.is_dim_reduced(0) && iter.ninputs() == 1 && iter.strides(1)[0] == iter.element_size(1); } From 917387f66d59a198b0f8313ddded4913d94d22d6 Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Sat, 8 Jun 2024 13:23:49 +0000 Subject: [PATCH 531/706] [AOTI] fix a constant tensor device move issue (#128265) Summary: When copying a constant tensor to another device, `.to` returns a fake tensor and causes a problem when a real tensor is expected. Test Plan: CI Differential Revision: D58313034 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128265 Approved by: https://github.com/chenyang78 --- torch/_inductor/graph.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index ca739eecb196..b7b032236907 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -854,10 +854,13 @@ def constant_name(self, name: str, device_override: Optional[torch.device]): """ if self.constants[name].device == device_override or device_override is None: return name - return self.allocate_non_dup_const_name( - f"{name}_{device_override.type}{device_override.index or 0}", - self.constants[name].to(device_override), - ) + with torch.utils._python_dispatch._disable_current_modes(): + # caller might have set fake tensor mode which will create a fake tensor + # when calling .to, so unset modes here + return self.allocate_non_dup_const_name( + f"{name}_{device_override.type}{device_override.index or 0}", + self.constants[name].to(device_override), + ) def placeholder(self, target: str, args, kwargs): example = super().placeholder(target, args, kwargs) From 348b181a97abc2e636a6c18e5880a78e5d1dab94 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Sat, 8 Jun 2024 11:57:11 +0000 Subject: [PATCH 532/706] Deprecate `torch._utils.is_compiling()` and `torch._dynamo.external_utils.is_compiling()` (#127690) This PR is split from PR #126898. - #126898 ------ Pull Request resolved: https://github.com/pytorch/pytorch/pull/127690 Approved by: https://github.com/Skylion007 --- test/test_optim.py | 2 +- torch/_dynamo/decorators.py | 3 +-- torch/_dynamo/external_utils.py | 5 +++++ torch/_functorch/apis.py | 6 +++--- torch/_functorch/eager_transforms.py | 4 ++-- torch/_higher_order_ops/associative_scan.py | 2 +- torch/_utils.py | 6 +++++- .../algorithms/ddp_comm_hooks/default_hooks.py | 4 ++-- torch/distributed/tensor/parallel/_utils.py | 2 +- torch/nn/parallel/distributed.py | 4 ++-- torch/optim/adadelta.py | 6 +++--- torch/optim/adam.py | 6 +++--- torch/optim/adamax.py | 6 +++--- torch/optim/adamw.py | 6 +++--- torch/optim/asgd.py | 4 ++-- torch/optim/nadam.py | 4 ++-- torch/optim/optimizer.py | 11 +++++------ torch/optim/radam.py | 4 ++-- torch/optim/rmsprop.py | 6 +++--- torch/optim/rprop.py | 6 +++--- torch/optim/sgd.py | 2 +- torch/testing/_internal/optests/generate_tests.py | 2 +- 22 files changed, 54 insertions(+), 47 deletions(-) diff --git a/test/test_optim.py b/test/test_optim.py index d61c33e2adce..3ab57fecd833 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -287,7 +287,7 @@ def test_param_group_with_lrscheduler_goes_right_direction( inpt = torch.randn(5, device=device, dtype=dtype) # avoid endless recompiles by wrapping LR in a tensor if we're compiling - lr = torch.tensor(0.01) if torch._utils.is_compiling() else 0.01 + lr = torch.tensor(0.01) if torch.compiler.is_compiling() else 0.01 optimizer = optim_cls([{"params": [weight]}, {"params": [bias], "lr": lr}]) schedulers = [scheduler_c(optimizer) for scheduler_c in schedulers_c] diff --git a/torch/_dynamo/decorators.py b/torch/_dynamo/decorators.py index 87fdc6502436..01c629709bd8 100644 --- a/torch/_dynamo/decorators.py +++ b/torch/_dynamo/decorators.py @@ -8,7 +8,6 @@ from .comptime import comptime from .eval_frame import DisableContext, innermost_fn, RunOnlyContext from .exc import IncorrectUsage -from .external_utils import is_compiling if TYPE_CHECKING: from torch._C._dynamo.eval_frame import ( # noqa: F401 @@ -264,7 +263,7 @@ def mark_static(t, index=None): Unlike mark_dynamic, this can be done inside a graph, in which case it induces specialization on the tensor. """ - if is_compiling(): + if torch.compiler.is_compiling(): if index is None: for s in t.size(): comptime.force_static(s) diff --git a/torch/_dynamo/external_utils.py b/torch/_dynamo/external_utils.py index 3ba10d34b771..1aea186bb679 100644 --- a/torch/_dynamo/external_utils.py +++ b/torch/_dynamo/external_utils.py @@ -2,6 +2,7 @@ import functools from typing import List +from typing_extensions import deprecated import torch import torch.utils._pytree as pytree @@ -12,6 +13,10 @@ np = None # type: ignore[assignment] +@deprecated( + "`torch._dynamo.external_utils.is_compiling` is deprecated. Use `torch.compiler.is_compiling` instead.", + category=FutureWarning, +) def is_compiling() -> bool: """ Indicates whether we are tracing/compiling with torch.compile() or torch.export(). diff --git a/torch/_functorch/apis.py b/torch/_functorch/apis.py index ee0c0a1984e4..477a01583b3d 100644 --- a/torch/_functorch/apis.py +++ b/torch/_functorch/apis.py @@ -188,7 +188,7 @@ def vmap( vmap does not provide general autobatching or handle variable-length sequences out of the box. """ - from torch._dynamo import is_compiling + from torch.compiler import is_compiling _check_randomness_arg(randomness) if not (chunk_size is None or chunk_size > 0): @@ -390,7 +390,7 @@ def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Calla """ # To avoid cyclical dependency. import torch._functorch.eager_transforms as eager_transforms - from torch._dynamo import is_compiling + from torch.compiler import is_compiling def wrapper(*args, **kwargs): return eager_transforms.grad_impl(func, argnums, has_aux, args, kwargs) @@ -432,8 +432,8 @@ def grad_and_value( See :func:`grad` for examples """ - from torch._dynamo import is_compiling from torch._functorch import eager_transforms + from torch.compiler import is_compiling def wrapper(*args, **kwargs): return eager_transforms.grad_and_value_impl( diff --git a/torch/_functorch/eager_transforms.py b/torch/_functorch/eager_transforms.py index fff6bd67838f..80751c9694fd 100644 --- a/torch/_functorch/eager_transforms.py +++ b/torch/_functorch/eager_transforms.py @@ -765,7 +765,7 @@ def compute_jacobian_preallocate_and_copy(): # Dynamo does not support HOP composition if their inner function is # annotated with @functools.wraps(...). We circumvent this issue by applying # wraps only if we're not tracing with dynamo. - if not torch._dynamo.is_compiling(): + if not torch.compiler.is_compiling(): wrapper_fn = wraps(func)(wrapper_fn) return wrapper_fn @@ -1346,7 +1346,7 @@ def push_jvp(basis): # Dynamo does not support HOP composition if their inner function is # annotated with @functools.wraps(...). We circumvent this issue by applying # wraps only if we're not tracing with dynamo. - if not torch._dynamo.is_compiling(): + if not torch.compiler.is_compiling(): wrapper_fn = wraps(func)(wrapper_fn) return wrapper_fn diff --git a/torch/_higher_order_ops/associative_scan.py b/torch/_higher_order_ops/associative_scan.py index 8b406f39a64d..e0e22eb4202f 100644 --- a/torch/_higher_order_ops/associative_scan.py +++ b/torch/_higher_order_ops/associative_scan.py @@ -76,7 +76,7 @@ def add(x: torch.Tensor, y: torch.Tensor): assert callable(combine_fn), "combine_fn must be a callable, but got {combine_fn}" assert isinstance(dim, int), "dim must be an int, but got {type(dim)}" - if not torch._dynamo.is_compiling(): + if not torch.compiler.is_compiling(): with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(): return torch.compile(associative_scan, fullgraph=True)( combine_fn, input, dim diff --git a/torch/_utils.py b/torch/_utils.py index eec2d8231d1a..e6ddb96bfa40 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -7,7 +7,7 @@ import warnings from collections import defaultdict from typing import Any, Callable, DefaultDict, Generic, List, Optional -from typing_extensions import ParamSpec +from typing_extensions import deprecated, ParamSpec import torch @@ -852,6 +852,10 @@ def classproperty(func): return _ClassPropertyDescriptor(func) +@deprecated( + "`torch._utils.is_compiling` is deprecated. Use `torch.compiler.is_compiling` instead.", + category=FutureWarning, +) def is_compiling() -> bool: """ Indicates whether we are tracing/compiling with torch.compile() or torch.export(). diff --git a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py index bff55327e847..791061e34f90 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py @@ -85,7 +85,7 @@ def decompress(fut): decompressed_tensor.copy_(value) return decompressed_tensor - if torch._utils.is_compiling(): + if torch.compiler.is_compiling(): grad = dist._functional_collectives.all_reduce( compressed_tensor, "sum", group_to_use ) @@ -134,7 +134,7 @@ def decompress(fut): decompressed_tensor.copy_(value) return decompressed_tensor - if torch._utils.is_compiling(): + if torch.compiler.is_compiling(): grad = dist._functional_collectives.all_reduce( compressed_tensor, "sum", group_to_use ) diff --git a/torch/distributed/tensor/parallel/_utils.py b/torch/distributed/tensor/parallel/_utils.py index 3c7e269fffea..876e97f70c5b 100644 --- a/torch/distributed/tensor/parallel/_utils.py +++ b/torch/distributed/tensor/parallel/_utils.py @@ -5,7 +5,7 @@ from torch.distributed._tensor.placement_types import Placement from torch.distributed.device_mesh import _mesh_resources try: - from torch._dynamo.external_utils import is_compiling as is_torchdynamo_compiling + from torch.compiler import is_compiling as is_torchdynamo_compiling except Exception: def is_torchdynamo_compiling(): # type: ignore[misc] return False diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py index 069be22991cd..37b6501d2c9c 100644 --- a/torch/nn/parallel/distributed.py +++ b/torch/nn/parallel/distributed.py @@ -1481,7 +1481,7 @@ def _lazy_init(self): def _should_disable_cpp_reducer(self) -> bool: return self._use_python_reducer and ( - torch._utils.is_compiling() or self._force_to_disable_cpp_reducer + torch.compiler.is_compiling() or self._force_to_disable_cpp_reducer ) def _pre_forward(self, *inputs, **kwargs): @@ -1494,7 +1494,7 @@ def _pre_forward(self, *inputs, **kwargs): h.remove() self._accum_grad_hooks.clear() - if not self._lazy_init_ran and not torch._utils.is_compiling(): + if not self._lazy_init_ran and not torch.compiler.is_compiling(): self._lazy_init() if self._delay_all_reduce_all_params: diff --git a/torch/optim/adadelta.py b/torch/optim/adadelta.py index 097c8040b63e..4d1a4e25319c 100644 --- a/torch/optim/adadelta.py +++ b/torch/optim/adadelta.py @@ -254,7 +254,7 @@ def _single_tensor_adadelta( has_complex: bool, ): # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch._utils.is_compiling() and capturable: + if not torch.compiler.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices( supports_xla=False ) @@ -310,7 +310,7 @@ def _multi_tensor_adadelta( assert not differentiable, "_foreach ops don't support autograd" # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch._utils.is_compiling() and capturable: + if not torch.compiler.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices( supports_xla=False ) @@ -413,7 +413,7 @@ def adadelta( # this check is slow during compilation, so we skip it # if it's strictly needed we can add this check back in dynamo - if not torch._utils.is_compiling() and not all( + if not torch.compiler.is_compiling() and not all( isinstance(t, torch.Tensor) for t in state_steps ): raise RuntimeError( diff --git a/torch/optim/adam.py b/torch/optim/adam.py index fba4b2027b05..1c625682fc34 100644 --- a/torch/optim/adam.py +++ b/torch/optim/adam.py @@ -353,7 +353,7 @@ def _single_tensor_adam( step_t = state_steps[i] # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch._utils.is_compiling() and capturable: + if not torch.compiler.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices() assert ( param.device.type == step_t.device.type @@ -466,7 +466,7 @@ def _multi_tensor_adam( ) # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch._utils.is_compiling() and capturable: + if not torch.compiler.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices( supports_xla=False ) @@ -743,7 +743,7 @@ def adam( # this check is slow during compilation, so we skip it # if it's strictly needed we can add this check back in dynamo - if not torch._utils.is_compiling() and not all( + if not torch.compiler.is_compiling() and not all( isinstance(t, torch.Tensor) for t in state_steps ): raise RuntimeError( diff --git a/torch/optim/adamax.py b/torch/optim/adamax.py index 8af468ba8386..005327d8bb88 100644 --- a/torch/optim/adamax.py +++ b/torch/optim/adamax.py @@ -243,7 +243,7 @@ def _single_tensor_adamax( step_t = state_steps[i] # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch._utils.is_compiling() and capturable: + if not torch.compiler.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices() assert ( param.device.type == step_t.device.type @@ -315,7 +315,7 @@ def _multi_tensor_adamax( return # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch._utils.is_compiling() and capturable: + if not torch.compiler.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices( supports_xla=False ) @@ -424,7 +424,7 @@ def adamax( See :class:`~torch.optim.Adamax` for details. """ - if not torch._utils.is_compiling() and not all( + if not torch.compiler.is_compiling() and not all( isinstance(t, torch.Tensor) for t in state_steps ): raise RuntimeError( diff --git a/torch/optim/adamw.py b/torch/optim/adamw.py index e58b28244083..707ac17c361c 100644 --- a/torch/optim/adamw.py +++ b/torch/optim/adamw.py @@ -354,7 +354,7 @@ def _single_tensor_adamw( step_t = state_steps[i] # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch._utils.is_compiling() and capturable: + if not torch.compiler.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices() assert ( param.device.type == step_t.device.type @@ -467,7 +467,7 @@ def _multi_tensor_adamw( ) # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch._utils.is_compiling() and capturable: + if not torch.compiler.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices( supports_xla=False ) @@ -728,7 +728,7 @@ def adamw( See :class:`~torch.optim.AdamW` for details. """ - if not torch._utils.is_compiling() and not all( + if not torch.compiler.is_compiling() and not all( isinstance(t, torch.Tensor) for t in state_steps ): raise RuntimeError( diff --git a/torch/optim/asgd.py b/torch/optim/asgd.py index f53f8b427e9f..633a14832282 100644 --- a/torch/optim/asgd.py +++ b/torch/optim/asgd.py @@ -214,7 +214,7 @@ def _single_tensor_asgd( step_t = state_steps[i] # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch._utils.is_compiling() and capturable: + if not torch.compiler.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices() assert ( param.device.type @@ -287,7 +287,7 @@ def _multi_tensor_asgd( assert not differentiable, "_foreach ops don't support autograd" # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch._utils.is_compiling() and capturable: + if not torch.compiler.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices( supports_xla=False ) diff --git a/torch/optim/nadam.py b/torch/optim/nadam.py index b860ed3ddda3..fd1f8ab0e718 100644 --- a/torch/optim/nadam.py +++ b/torch/optim/nadam.py @@ -304,7 +304,7 @@ def _single_tensor_nadam( exp_avg_sq = torch.view_as_real(exp_avg_sq) # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch._utils.is_compiling() and capturable: + if not torch.compiler.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices() assert ( param.device.type == mu_product.device.type == step_t.device.type @@ -390,7 +390,7 @@ def _multi_tensor_nadam( assert not differentiable, "_foreach ops don't support autograd" # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch._utils.is_compiling() and capturable: + if not torch.compiler.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices( supports_xla=False ) diff --git a/torch/optim/optimizer.py b/torch/optim/optimizer.py index 0352b7976579..fc091e273c36 100644 --- a/torch/optim/optimizer.py +++ b/torch/optim/optimizer.py @@ -24,7 +24,6 @@ import torch import torch.utils.hooks as hooks -from torch._utils import is_compiling from torch.utils._foreach_utils import ( _get_foreach_kernels_supported_devices, _get_fused_kernels_supported_devices, @@ -97,14 +96,14 @@ def _use_grad(self, *args, **kwargs): def _get_value(x): # item is significantly faster than a cpu tensor in eager mode - if not torch.jit.is_scripting() and is_compiling(): + if not torch.jit.is_scripting() and torch.compiler.is_compiling(): return x else: return x.item() if isinstance(x, torch.Tensor) else x def _stack_if_compiling(x): - if not torch.jit.is_scripting() and is_compiling(): + if not torch.jit.is_scripting() and torch.compiler.is_compiling(): return torch.stack(x) else: return x @@ -145,7 +144,7 @@ def wrapper(func): # the capturable flag. If capturable=True, this is not a problem. @functools.wraps(func) def maybe_fallback(*args, **kwargs): - if is_compiling() and ( + if torch.compiler.is_compiling() and ( not kwargs.get("capturable", False) and has_state_steps and (args[state_steps_ind] and args[state_steps_ind][0].is_cuda) @@ -418,7 +417,7 @@ def _cuda_graph_capture_health_check(self) -> None: # Thus, when compiling, inductor will determine if cudagraphs # can be enabled based on whether there is input mutation or CPU tensors. if ( - not is_compiling() + not torch.compiler.is_compiling() and torch.backends.cuda.is_built() and torch.cuda.is_available() ): @@ -505,7 +504,7 @@ def _group_tensors_by_device_and_dtype( """Groups a list of lists of tensors by device and dtype. Skips this step if we are compiling since this will occur during inductor lowering. """ - if is_compiling(): + if torch.compiler.is_compiling(): return {(None, None): (tensorlistlist, list(range(len(tensorlistlist[0]))))} else: return _group_tensors_by_device_and_dtype(tensorlistlist, with_indices) # type: ignore[return-value, arg-type] diff --git a/torch/optim/radam.py b/torch/optim/radam.py index ea592185c887..619f10493587 100644 --- a/torch/optim/radam.py +++ b/torch/optim/radam.py @@ -271,7 +271,7 @@ def _single_tensor_radam( step_t = state_steps[i] # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch._utils.is_compiling() and capturable: + if not torch.compiler.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices() assert ( param.device.type == step_t.device.type @@ -369,7 +369,7 @@ def _multi_tensor_radam( assert not differentiable, "_foreach ops don't support autograd" # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch._utils.is_compiling() and capturable: + if not torch.compiler.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices( supports_xla=False ) diff --git a/torch/optim/rmsprop.py b/torch/optim/rmsprop.py index b3375c338b40..bdc3ec0b8b3f 100644 --- a/torch/optim/rmsprop.py +++ b/torch/optim/rmsprop.py @@ -276,7 +276,7 @@ def _single_tensor_rmsprop( step = state_steps[i] # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch._utils.is_compiling() and capturable: + if not torch.compiler.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices() assert ( param.device.type == step.device.type @@ -349,7 +349,7 @@ def _multi_tensor_rmsprop( assert not differentiable, "_foreach ops don't support autograd" # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch._utils.is_compiling() and capturable: + if not torch.compiler.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices() assert all( p.device.type == step.device.type @@ -467,7 +467,7 @@ def rmsprop( """ # this check is slow during compilation, so we skip it # if it's strictly needed we can add this check back in dynamo - if not torch._utils.is_compiling() and not all( + if not torch.compiler.is_compiling() and not all( isinstance(t, torch.Tensor) for t in state_steps ): raise RuntimeError( diff --git a/torch/optim/rprop.py b/torch/optim/rprop.py index ec40aae5c90a..af1854cc518a 100644 --- a/torch/optim/rprop.py +++ b/torch/optim/rprop.py @@ -236,7 +236,7 @@ def _single_tensor_rprop( step = state_steps[i] # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch._utils.is_compiling() and capturable: + if not torch.compiler.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices() assert ( param.device.type == step.device.type @@ -302,7 +302,7 @@ def _multi_tensor_rprop( assert not differentiable, "_foreach ops don't support autograd" # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch._utils.is_compiling() and capturable: + if not torch.compiler.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices() assert all( p.device.type == step.device.type @@ -414,7 +414,7 @@ def rprop( """ # this check is slow during compilation, so we skip it # if it's strictly needed we can add this check back in dynamo - if not torch._utils.is_compiling() and not all( + if not torch.compiler.is_compiling() and not all( isinstance(t, torch.Tensor) for t in state_steps ): raise RuntimeError( diff --git a/torch/optim/sgd.py b/torch/optim/sgd.py index a95574a65aba..291b4068dd4c 100644 --- a/torch/optim/sgd.py +++ b/torch/optim/sgd.py @@ -429,7 +429,7 @@ def _multi_tensor_sgd( if not device_has_sparse_grad: # handle internal item() call if lr is a tensor - if isinstance(lr, torch.Tensor) and torch._utils.is_compiling(): + if isinstance(lr, torch.Tensor) and torch.compiler.is_compiling(): grads_x_lr = torch._foreach_mul(device_grads, -lr) torch._foreach_add_(device_params, grads_x_lr) else: diff --git a/torch/testing/_internal/optests/generate_tests.py b/torch/testing/_internal/optests/generate_tests.py index 70ee48274800..d01f91563c92 100644 --- a/torch/testing/_internal/optests/generate_tests.py +++ b/torch/testing/_internal/optests/generate_tests.py @@ -569,7 +569,7 @@ def __torch_function__(self, func, types, args=(), kwargs=None): if ( torch.jit.is_tracing() or torch.jit.is_scripting() - or torch._dynamo.is_compiling() + or torch.compiler.is_compiling() ): return func(*args, **kwargs) # Pre-existing code may not use the .default overload. If we see an From 57a24c4fdb58b82724b4f3d55d3af105a660bf39 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sat, 8 Jun 2024 15:25:25 +0000 Subject: [PATCH 533/706] Revert "[RFC] add per-collective timeout value in flight recorder (#128190)" This reverts commit 09cccbc1c74c9d1157c1caca5526e79ee9b7ea01. Reverted https://github.com/pytorch/pytorch/pull/128190 on behalf of https://github.com/atalman due to Sorry need to revert this, in conflict with https://github.com/pytorch/pytorch/pull/127651 that needs reverting ([comment](https://github.com/pytorch/pytorch/pull/128190#issuecomment-2156075318)) --- test/distributed/test_c10d_nccl.py | 5 +---- torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp | 3 --- torch/csrc/distributed/c10d/TraceUtils.h | 10 +--------- 3 files changed, 2 insertions(+), 16 deletions(-) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index f45600c5d17d..21a8a632bade 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -3548,7 +3548,7 @@ def test_short(self, timing_enabled, include_collectives): ) ) ver = t["version"] - self.assertEqual(ver, "2.2") + self.assertEqual(ver, "2.1") pg_config = t["pg_config"] self.assertEqual(len(pg_config), 1) default_pg_info = pg_config["0"] @@ -3577,7 +3577,6 @@ def test_short(self, timing_enabled, include_collectives): self.assertEqual(last["output_sizes"], ((3, 4),)) self.assertEqual(last["output_dtypes"], ["Float"]) self.assertEqual(last["collective_seq_id"], 2) - self.assertEqual(last["timeout_ms"], 600000) now = datetime.now() event_created_time = datetime.fromtimestamp( last["time_created_ns"] / 1000000000 @@ -3662,7 +3661,6 @@ def test_long(self): self.assertEqual(last["input_dtypes"], ["Float"]) self.assertEqual(last["output_sizes"], ((3, 4),)) self.assertEqual(last["output_dtypes"], ["Float"]) - self.assertEqual(last["timeout_ms"], 600000) self.assertEqual(last["collective_seq_id"] - first["collective_seq_id"], 9) @requires_nccl() @@ -3867,7 +3865,6 @@ def test_batched_send_recv(self, op_sizes_per_coalesce, timing_enabled): self.assertTrue(0.001 < duration < 10000, duration) else: self.assertTrue("duration_ms" not in t["entries"][coalesced_op]) - self.assertEqual(t["entries"][coalesced_op]["timeout_ms"], 600000) @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 07bbcd5a0af4..8adf1e02c1a0 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -2356,7 +2356,6 @@ c10::intrusive_ptr ProcessGroupNCCL::initWork( outputs, r->ncclStartEvent_.get(), r->ncclEndEvent_.get(), - options_->timeout, isP2P); } return r; @@ -2967,7 +2966,6 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( {tensor}, nullptr, nullptr, - options_->timeout, /*isP2P=*/true); // TODO(whc) if we want to make the per-p2p-op flightrecorder entries get // their timings/states updated by proxy when the Work obj representing the @@ -3001,7 +2999,6 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( {tensor}, work->ncclStartEvent_.get(), work->ncclEndEvent_.get(), - options_->timeout, /*isP2P=*/true); } diff --git a/torch/csrc/distributed/c10d/TraceUtils.h b/torch/csrc/distributed/c10d/TraceUtils.h index de623d77fe9e..c3b0464cf992 100644 --- a/torch/csrc/distributed/c10d/TraceUtils.h +++ b/torch/csrc/distributed/c10d/TraceUtils.h @@ -8,7 +8,6 @@ #include #include #include -#include #ifdef USE_C10D_NCCL #include @@ -29,7 +28,7 @@ static c10::IValue nccl_comm_key = "nccl_comm_state"; static c10::IValue version_key = "version"; // Update whenever changing contents or formatting of the dump // (minor when adding fields, major when changing existing fields) -static c10::IValue version_val = "2.2"; +static c10::IValue version_val = "2.1"; static c10::IValue pg_config_key = "pg_config"; static c10::IValue record_id_key = "record_id"; static c10::IValue pg_id_key = "pg_id"; @@ -45,7 +44,6 @@ static c10::IValue output_sizes_key = "output_sizes"; static c10::IValue output_dtypes_key = "output_dtypes"; static c10::IValue time_created_key = "time_created_ns"; static c10::IValue duration_key = "duration_ms"; -static c10::IValue timeout_key = "timeout_ms"; static c10::IValue frames_key = "frames"; static c10::IValue state_key = "state"; @@ -463,9 +461,6 @@ struct NCCLTraceBuffer { // was 'enqueued'- not necessarily started c10::time_t time_created_; - // configured timeout for this entry - c10::time_t timeout_ms_; - // Is this a P2P event? bool isP2P_; @@ -513,7 +508,6 @@ struct NCCLTraceBuffer { const std::vector& outputs, Event* start, Event* end, - std::chrono::milliseconds timeout_ms, bool isP2P) { if (!enabled_) { return c10::nullopt; @@ -534,7 +528,6 @@ struct NCCLTraceBuffer { std::move(start), std::move(end), c10::getTime(), - timeout_ms.count(), isP2P}; for (const auto& input : inputs) { @@ -759,7 +752,6 @@ struct NCCLTraceBuffer { ? int64_t(*e.time_discovered_completed_) : c10::IValue()); dict.insert(retired_key, e.retired_); - dict.insert(timeout_key, e.timeout_ms_); dict.insert(is_p2p_key, e.isP2P_); entries.push_back(dict); From 02a901f1e9136f45d8993c5b4ec23031b3bf0bcf Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sat, 8 Jun 2024 15:30:04 +0000 Subject: [PATCH 534/706] Revert "[RFC] Provide optional switches to _dump_nccl_trace (#127651)" This reverts commit 0a761f0627130e739f0e2748e3f71a0c347552c4. Reverted https://github.com/pytorch/pytorch/pull/127651 on behalf of https://github.com/atalman due to Breaks internal CI ([comment](https://github.com/pytorch/pytorch/pull/127651#issuecomment-2156076838)) --- test/distributed/test_c10d_nccl.py | 73 +++++++-------- .../distributed/c10d/ProcessGroupNCCL.cpp | 25 ++--- .../distributed/c10d/ProcessGroupNCCL.hpp | 15 +-- torch/csrc/distributed/c10d/TraceUtils.h | 92 +++++++------------ torch/csrc/distributed/c10d/init.cpp | 25 +---- 5 files changed, 78 insertions(+), 152 deletions(-) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index 21a8a632bade..baf2adb1fb2d 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -3523,8 +3523,7 @@ class NCCLTraceTest(NCCLTraceTestBase): @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize("timing_enabled", [True, False]) - @parametrize("include_collectives", [True, False]) - def test_short(self, timing_enabled, include_collectives): + def test_short(self, timing_enabled): if self.rank == self.MAIN_PROCESS_RANK: return pg = self._create_process_group_nccl() @@ -3539,14 +3538,8 @@ def test_short(self, timing_enabled, include_collectives): # gah ok so now the duration_ms is populated best-effort since it can only happen outside "dump()" api time.sleep(1) - if include_collectives: - t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace()) - else: - t = pickle.loads( - torch._C._distributed_c10d._dump_nccl_trace( - includeCollectives=False, includeStackTraces=None, onlyActive=None - ) - ) + + t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace()) ver = t["version"] self.assertEqual(ver, "2.1") pg_config = t["pg_config"] @@ -3557,39 +3550,35 @@ def test_short(self, timing_enabled, include_collectives): self.assertIn("ranks", default_pg_info) global_ranks = pg_config["0"]["ranks"] self.assertEqual(len(json.loads(global_ranks)), self.world_size) - if include_collectives: - self.assertEqual(len(t["entries"]), 2) - t = t["entries"] - self.assertEqual(len(t), 2) - last = t[-1] - self.assertEqual(last["process_group"], ("0", "default_pg")) - self.assertEqual(last["state"], "completed") - s = last["time_discovered_started_ns"] - f = last["time_discovered_completed_ns"] - self.assertEqual(last["record_id"], 1) - self.assertIsNotNone(f) - if timing_enabled: - self.assertIsNotNone(s) - self.assertTrue(s <= f) - self.assertIn("test_c10d_nccl.py", str(last["frames"])) - self.assertEqual(last["input_sizes"], ((3, 4),)) - self.assertEqual(last["input_dtypes"], ["Float"]) - self.assertEqual(last["output_sizes"], ((3, 4),)) - self.assertEqual(last["output_dtypes"], ["Float"]) - self.assertEqual(last["collective_seq_id"], 2) - now = datetime.now() - event_created_time = datetime.fromtimestamp( - last["time_created_ns"] / 1000000000 - ) - before_test = now - timedelta(minutes=1) - self.assertTrue(before_test < event_created_time < now) - if timing_enabled: - # very loose bounds, measured 0.036 ms on devgpu - self.assertTrue(0 < last["duration_ms"] < 100) - else: - self.assertTrue("duration_ms" not in last) + t = t["entries"] + self.assertEqual(len(t), 2) + last = t[-1] + self.assertEqual(last["process_group"], ("0", "default_pg")) + self.assertEqual(last["state"], "completed") + s = last["time_discovered_started_ns"] + f = last["time_discovered_completed_ns"] + self.assertEqual(last["record_id"], 1) + self.assertIsNotNone(f) + if timing_enabled: + self.assertIsNotNone(s) + self.assertTrue(s <= f) + self.assertIn("test_c10d_nccl.py", str(last["frames"])) + self.assertEqual(last["input_sizes"], ((3, 4),)) + self.assertEqual(last["input_dtypes"], ["Float"]) + self.assertEqual(last["output_sizes"], ((3, 4),)) + self.assertEqual(last["output_dtypes"], ["Float"]) + self.assertEqual(last["collective_seq_id"], 2) + now = datetime.now() + event_created_time = datetime.fromtimestamp( + last["time_created_ns"] / 1000000000 + ) + before_test = now - timedelta(minutes=1) + self.assertTrue(before_test < event_created_time < now) + if timing_enabled: + # very loose bounds, measured 0.036 ms on devgpu + self.assertTrue(0 < last["duration_ms"] < 100) else: - self.assertTrue("entries" not in t) + self.assertTrue("duration_ms" not in last) @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 8adf1e02c1a0..26381207ca7d 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -342,10 +342,7 @@ void cacheAllocatorDeregisterHook( } #if defined(IS_NCCLX) && defined(NCCL_COMM_DUMP) -std::string dump_nccl_trace( - bool includeCollectives, - bool includeStackTraces, - bool onlyActive) { +std::string dump_nccl_trace() { std::unordered_map< std::string /* ncclUniqueID */, std::unordered_map /* dump from this comm */> @@ -365,27 +362,19 @@ std::string dump_nccl_trace( std::string ncclUniqueIDStr = buildNcclUniqueIdStr(ncclComm->getNcclId()); ncclDumpMap[ncclUniqueIDStr] = ncclComm->ncclCommDump(); } - return NCCLTraceBuffer::get()->dump( - ncclDumpMap, includeCollectives, includeStackTraces, onlyActive); + return NCCLTraceBuffer::get()->dump(ncclDumpMap); } - #else -std::string dump_nccl_trace( - bool includeCollectives, - bool includeStackTraces, - bool onlyActive) { - return NCCLTraceBuffer::get()->dump( - c10::nullopt, includeCollectives, includeStackTraces, onlyActive); +std::string dump_nccl_trace() { + return NCCLTraceBuffer::get()->dump(c10::nullopt); } #endif // TODO(c-p-i-o): add a JSON endpoint. control_plane::RegisterHandler dumpHandler{ "dump_nccl_trace_pickle", - [](const control_plane::Request& req, control_plane::Response& res) { - // TODO: c-p-i-o: params from the request need to go to dump_nccl_trace. - res.setContent( - dump_nccl_trace(true, true, false), "application/octet-stream"); + [](const control_plane::Request&, control_plane::Response& res) { + res.setContent(dump_nccl_trace(), "application/octet-stream"); }}; std::optional)>>& @@ -1208,7 +1197,7 @@ bool ProcessGroupNCCL::dumpDebuggingInfo() { // We dump nccl trace into local disk by default and users can register // their customized writer by inheriting `DebugInfoWriter` via // `registerDebugInfoWriter`. - auto ncclTrace = dump_nccl_trace(true, true, false); + auto ncclTrace = dump_nccl_trace(); DebugInfoWriter& writer = DebugInfoWriter::getWriter(globalRank()); LOG(INFO) << logPrefix() << "ProcessGroupNCCL dumping nccl trace to " << writer.getWriterTarget(); diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index faaabe411bfc..f36ebdeb16e9 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -1114,16 +1114,11 @@ class TORCH_API ProcessGroupNCCL : public Backend { ProcessGroupStatus pgStatus_; }; -// Dumps the NCCL comm traces and additional information about the Process -// Group. -TORCH_API std::string dump_nccl_trace( - bool includeCollectives, - bool includeStackTraces, - bool onlyActive); - -// Gets a mutable reference to a global optional function.Heartbeat Monitor -// will use this function to dump traces, if available. Inside fbcode, we -// store a function here that uses an internal tool for process tracing +TORCH_API std::string dump_nccl_trace(); + +// Gets a mutable reference to a global optional function. Heartbeat Monitor +// will use this function to dump traces, if available. Inside fbcode, we store +// a function here that uses an internal tool for process tracing TORCH_API std::optional< std::function)>>& get_cpp_trace_dumper(); diff --git a/torch/csrc/distributed/c10d/TraceUtils.h b/torch/csrc/distributed/c10d/TraceUtils.h index c3b0464cf992..e8dadb6537e0 100644 --- a/torch/csrc/distributed/c10d/TraceUtils.h +++ b/torch/csrc/distributed/c10d/TraceUtils.h @@ -655,44 +655,31 @@ struct NCCLTraceBuffer { entry->start_ = entry->end_ = nullptr; } - const c10::List getCollectiveTrace( - bool includeStacktraces, - bool onlyActive) { - auto entries = new_list(); + std::string dump( + const std::optional>>& ncclDumpMap) { auto result = dump_entries(); + auto entries = new_list(); + std::vector tracebacks; - torch::SymbolizedTracebacks stracebacks; + for (auto& e : result) { + tracebacks.push_back(e.traceback_.get()); + } + torch::SymbolizedTracebacks stracebacks = torch::symbolize(tracebacks); std::vector all_frames; - if (includeStacktraces) { - for (auto& e : result) { - tracebacks.push_back(e.traceback_.get()); - } - stracebacks = torch::symbolize(tracebacks); - for (const auto& f : stracebacks.all_frames) { - auto d = new_dict(); - d.insert(name_key, f.funcname); - d.insert(filename_key, f.filename); - d.insert(line_key, int64_t(f.lineno)); - all_frames.emplace_back(std::move(d)); - } + for (const auto& f : stracebacks.all_frames) { + auto d = new_dict(); + d.insert(name_key, f.funcname); + d.insert(filename_key, f.filename); + d.insert(line_key, int64_t(f.lineno)); + all_frames.emplace_back(std::move(d)); } + for (auto i : c10::irange(result.size())) { - auto dict = new_dict(); auto& e = result.at(i); - // Skip completed events - if (onlyActive && e.time_discovered_completed_.has_value()) { - continue; - } - - if (includeStacktraces) { - auto& tb = stracebacks.tracebacks.at(i); - auto frames = new_list(); - for (int64_t frame : tb) { - frames.push_back(all_frames.at(frame)); - } - dict.insert(frames_key, frames); - } - + auto& tb = stracebacks.tracebacks.at(i); + auto dict = new_dict(); dict.insert(record_id_key, int64_t(e.id_)); dict.insert(pg_id_key, int64_t(e.pg_id_)); dict.insert(pg_name_key, e.pg_name_); @@ -754,13 +741,13 @@ struct NCCLTraceBuffer { dict.insert(retired_key, e.retired_); dict.insert(is_p2p_key, e.isP2P_); + auto frames = new_list(); + for (int64_t frame : tb) { + frames.push_back(all_frames.at(frame)); + } + dict.insert(frames_key, frames); entries.push_back(dict); } - return entries; - } - - // dump pg_entries - const c10::Dict getPgConfig() { auto pg_config = new_dict(); for (const auto& [pg_name, ranks] : pg_name_to_ranks_) { auto pg_info = new_dict(); @@ -769,27 +756,6 @@ struct NCCLTraceBuffer { pg_info.insert("ranks", ranks_str(ranks)); pg_config.insert(std::get<0>(pg_name), pg_info); } - return pg_config; - } - - // dump all collectives + ncclDumpMap - std::string dump( - const std::optional>>& ncclDumpMap, - bool includeCollectives, - bool includeStackTraces, - bool onlyActive) { - auto result = new_dict(); - // common values - result.insert(version_key, version_val); - result.insert(pg_config_key, getPgConfig()); - - // collective trace - if (includeCollectives) { - result.insert( - entries_key, getCollectiveTrace(includeStackTraces, onlyActive)); - } // convert ncclDumpMap into a dictionary auto per_comm_dict = new_dict(); @@ -802,10 +768,16 @@ struct NCCLTraceBuffer { per_comm_dict.insert(ncclId, inner_dict); } } + + auto dict = new_dict(); + dict.insert(entries_key, entries); + dict.insert(version_key, version_val); if (per_comm_dict.size() > 0) { - result.insert(nccl_comm_key, per_comm_dict); + dict.insert(nccl_comm_key, per_comm_dict); } - return pickle_str(result); + dict.insert(pg_config_key, pg_config); + + return pickle_str(dict); } }; diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 6f1b28886b98..f0284c0a3bb7 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -3164,28 +3164,9 @@ such as `dist.all_reduce(tensor, async_op=True)`. Arguments: tensors(List[torch.Tensor]): List of tensors we want to hash. )"); - module.def( - "_dump_nccl_trace", - [](std::optional includeCollectives, - std::optional includeStackTraces, - std::optional onlyActive) { - return py::bytes(::c10d::dump_nccl_trace( - includeCollectives.value_or(true), - includeStackTraces.value_or(true), - onlyActive.value_or(false))); - }, - py::arg("includeCollectives") = std::optional(), - py::arg("includeStackTraces") = std::optional(), - py::arg("onlyActive") = std::optional(), - R"( - Arguments: - includeCollectives(bool, optional): Whether to include collective work traces. Default is True. - includeStackTraces(bool, optional): Whether to include stacktraces in the collective work traces. Default is True. - onlyActive (bool, optional): Whether to only include active collective work traces. Default is False. - Returns: - Stringified pickle work traces. - Default settings return everything - i.e. contains NCCL comm dumps and collective traces. - )"); + module.def("_dump_nccl_trace", []() { + return py::bytes(::c10d::dump_nccl_trace()); + }); #endif intrusive_ptr_class_<::c10d::control_plane::WorkerServer>( From 2369c719d485af0787d95668947125a5605bed88 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Fri, 7 Jun 2024 13:52:34 -0700 Subject: [PATCH 535/706] [DSD][BE] Cleanup unused variables and rename variables to avoid exposure to the users (#128249) These APIs and variables should not be exposed to users as they are designed to be used internally. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128249 Approved by: https://github.com/wz337 --- torch/distributed/checkpoint/state_dict.py | 143 ++++++++---------- .../distributed/common_state_dict.py | 28 ++-- 2 files changed, 80 insertions(+), 91 deletions(-) diff --git a/torch/distributed/checkpoint/state_dict.py b/torch/distributed/checkpoint/state_dict.py index 0d1a3a625a25..95fc57faf8e8 100644 --- a/torch/distributed/checkpoint/state_dict.py +++ b/torch/distributed/checkpoint/state_dict.py @@ -53,19 +53,12 @@ from torch.utils._pytree import tree_map_only __all__ = [ - "FLAT_PARAM", - "PG", - "PG_PREFIX", - "STATE", - "STATE_PREFIX", - "PARAMS", "FQNS_T", "PrimitiveType", "ValueType", "DictValueType", "ListDictValueType", "OptimizerStateType", - "gc_context", "StateDictOptions", "get_model_state_dict", "get_optimizer_state_dict", @@ -75,17 +68,13 @@ "set_state_dict", ] -FLAT_PARAM = "_flat_param" -PG = "param_groups" -PG_PREFIX = f"{PG}." -STATE = "state" -STATE_PREFIX = f"{STATE}." -PARAMS = "params" -FQNS_T = Set[str] - -_patched_state_dict: Set[Callable] = set() +_FLAT_PARAM = "_flat_param" +_PG = "param_groups" +_PARAMS = "params" +_STATE = "state" +FQNS_T = Set[str] PrimitiveType = Union[DTensor, ShardedTensor, torch.Tensor, int, float, str] ValueType = Union[ PrimitiveType, List[PrimitiveType], Tuple[PrimitiveType], Dict[str, "ValueType"] @@ -95,14 +84,16 @@ OptimizerStateType = Dict[str, Union[DictValueType, ListDictValueType]] +_patched_state_dict: Set[Callable] = set() + + @contextlib.contextmanager -def gc_context(): +def _gc_context(): is_enabled = gc.isenabled() gc.disable() try: yield finally: - # TODO: add logging for the gc details/time if is_enabled: gc.enable() @@ -159,7 +150,6 @@ class _StateDictInfo(StateDictOptions): fqn_param_mapping: Dict[ Union[str, torch.Tensor], Union[FQNS_T, torch.Tensor] ] = field(default_factory=dict) - all_fqns: Set[str] = field(default_factory=set) submodule_prefixes: Set[str] = field(default_factory=set) handle_model: bool = True handle_optim: bool = True @@ -204,9 +194,9 @@ def _get_fqns( if not skip_ddp_prefix: fqn_obj_names.append(curr_obj_name) elif isinstance(curr_obj, FSDP): - if i < len(obj_names) - 1 and obj_names[i + 1] == FLAT_PARAM: + if i < len(obj_names) - 1 and obj_names[i + 1] == _FLAT_PARAM: prefix = ".".join(fqn_obj_names) - flat_param = getattr(curr_obj, FLAT_PARAM) + flat_param = getattr(curr_obj, _FLAT_PARAM) if prefix: prefix = f"{prefix}." return {f"{prefix}{fqn}" for fqn in flat_param._fqns} @@ -293,7 +283,6 @@ def _verify_options( fqn_param_mapping: Dict[ Union[str, torch.Tensor], Union[Set[str], torch.Tensor] ] = {} - all_fqns = set() for name, param in _iterate_valid_model_state(model): fqns = _get_fqns(model, name) if not isinstance(param, _EXTRA_STATE): @@ -301,7 +290,6 @@ def _verify_options( for fqn in fqns: if not isinstance(param, _EXTRA_STATE): fqn_param_mapping[fqn] = param - all_fqns.add(fqn) submodule_prefixes: Set[str] = set() if submodules: @@ -370,7 +358,6 @@ def fsdp_state_dict_type_without_warning( return _StateDictInfo( **asdict(options), fqn_param_mapping=fqn_param_mapping, - all_fqns=all_fqns, submodule_prefixes=submodule_prefixes, fsdp_context=fsdp_context, fsdp_modules=cast(List[nn.Module], fsdp_modules), @@ -417,9 +404,9 @@ def _verify_state_dict( ) for key in model_state_dict.keys(): - if FLAT_PARAM in key: + if _FLAT_PARAM in key: raise RuntimeError( - f"{key} contains {FLAT_PARAM}. This can happen if the model " + f"{key} contains {_FLAT_PARAM}. This can happen if the model " "is not the root module." ) @@ -571,7 +558,7 @@ def _init_optim_state(optim: torch.optim.Optimizer) -> None: return for param_group in optim.param_groups: - for param in param_group[PARAMS]: + for param in param_group[_PARAMS]: if param.grad is not None: raise RuntimeError( "state_dict can only be used if the optimizer " @@ -651,16 +638,16 @@ def _raise_if_type_not_supported(v): ) ret: Dict[str, ValueType] = {} - for fqn, state in cast(DictValueType, state_dict[STATE]).items(): + for fqn, state in cast(DictValueType, state_dict[_STATE]).items(): for k, v in cast(DictValueType, state).items(): _raise_if_type_not_supported(v) - ret[f"{STATE}.{fqn}.{k}"] = v + ret[f"{_STATE}.{fqn}.{k}"] = v - for param_group in cast(ListDictValueType, state_dict[PG]): - fqns = param_group.pop(PARAMS) + for param_group in cast(ListDictValueType, state_dict[_PG]): + fqns = param_group.pop(_PARAMS) for fqn in cast(List[str], fqns): for k, v in param_group.items(): - ret[f"{PG}.{fqn}.{k}"] = v + ret[f"{_PG}.{fqn}.{k}"] = v return ret @@ -675,13 +662,13 @@ def _unflatten_optim_state_dict( """ state: DictValueType = {} pg_state: ListDictValueType = [] - return_osd: OptimizerStateType = {STATE: state, PG: pg_state} + return_osd: OptimizerStateType = {_STATE: state, _PG: pg_state} for param_group in optim.param_groups: - pg_state.append({PARAMS: []}) - for param in param_group[PARAMS]: + pg_state.append({_PARAMS: []}) + for param in param_group[_PARAMS]: for fqn in info.fqn_param_mapping[param]: - params = pg_state[-1][PARAMS] + params = pg_state[-1][_PARAMS] assert isinstance(params, list) # typing params.append(fqn) if not param.requires_grad: @@ -689,14 +676,14 @@ def _unflatten_optim_state_dict( state[fqn] = {} for state_name in optim.state[param].keys(): cast(DictValueType, state[fqn])[state_name] = state_dict[ - f"{STATE}.{fqn}.{state_name}" + f"{_STATE}.{fqn}.{state_name}" ] - first_param_fqn = cast(List[str], pg_state[-1][PARAMS])[0] + first_param_fqn = cast(List[str], pg_state[-1][_PARAMS])[0] for k in param_group.keys(): - if k == PARAMS: + if k == _PARAMS: continue - value = state_dict[f"{PG}.{first_param_fqn}.{k}"] + value = state_dict[f"{_PG}.{first_param_fqn}.{k}"] if k not in pg_state[-1]: pg_state[-1][k] = value elif pg_state[-1][k] != value: @@ -717,7 +704,7 @@ def _get_optim_state_dict( if not info.handle_optim: return {} - optim_state_dict: OptimizerStateType = {STATE: {}, PG: []} + optim_state_dict: OptimizerStateType = {_STATE: {}, _PG: []} for optim in optimizers: _init_optim_state(optim) osd = _state_dict_fn(optim, "state_dict")() @@ -731,14 +718,14 @@ def _get_optim_state_dict( # We can only use a string replacment without correctness check. if not osd: continue - for k in list(osd[STATE].keys()): + for k in list(osd[_STATE].keys()): if "_orig_mod" in k: - osd[STATE][k.replace("_orig_mod.", "")] = osd[STATE].pop(k) - for g in osd[PG]: - params = [k.replace("_orig_mod.", "") for k in g[PARAMS]] - g[PARAMS] = params + osd[_STATE][k.replace("_orig_mod.", "")] = osd[_STATE].pop(k) + for g in osd[_PG]: + params = [k.replace("_orig_mod.", "") for k in g[_PARAMS]] + g[_PARAMS] = params else: - params = list(chain.from_iterable(g[PARAMS] for g in optim.param_groups)) + params = list(chain.from_iterable(g[_PARAMS] for g in optim.param_groups)) param_pid_mapping = dict(zip(params, range(len(params)))) fqn_pid_mapping = {} for key, param in model.named_parameters(): @@ -751,18 +738,18 @@ def _get_optim_state_dict( fqn_pid_mapping[fqn] = pid fqn_pid_mapping[pid] = fqn - for key in list(osd[STATE].keys()): + for key in list(osd[_STATE].keys()): fqn = fqn_pid_mapping[key] - osd[STATE][fqn] = osd[STATE].pop(key) + osd[_STATE][fqn] = osd[_STATE].pop(key) - for group in osd[PG]: - group[PARAMS] = [fqn_pid_mapping[pid] for pid in group[PARAMS]] + for group in osd[_PG]: + group[_PARAMS] = [fqn_pid_mapping[pid] for pid in group[_PARAMS]] if not osd: continue - cast(DictValueType, optim_state_dict[STATE]).update(osd[STATE]) - cast(ListDictValueType, optim_state_dict[PG]).extend(osd[PG]) + cast(DictValueType, optim_state_dict[_STATE]).update(osd[_STATE]) + cast(ListDictValueType, optim_state_dict[_PG]).extend(osd[_PG]) if info.flatten_optimizer_state_dict: optim_state_dict = cast( @@ -795,35 +782,37 @@ def _split_optim_state_dict( state: DictValueType = {} pg_state: ListDictValueType = [] - return_osd: OptimizerStateType = {STATE: state, PG: pg_state} + return_osd: OptimizerStateType = {_STATE: state, _PG: pg_state} pg_mapping: Dict[int, int] = {} if all( - isinstance(k, int) for k in cast(DictValueType, optim_state_dict[STATE]).keys() + isinstance(k, int) for k in cast(DictValueType, optim_state_dict[_STATE]).keys() ): return optim_state_dict for param_group in optim.param_groups: - pg_state.append({PARAMS: []}) - for param in param_group[PARAMS]: + pg_state.append({_PARAMS: []}) + for param in param_group[_PARAMS]: for fqn in info.fqn_param_mapping[param]: - params = pg_state[-1][PARAMS] + params = pg_state[-1][_PARAMS] assert isinstance(params, list) params.append(fqn) if param.requires_grad: - state[fqn] = cast(DictValueType, optim_state_dict[STATE])[fqn] - for loaded_param_group in cast(ListDictValueType, optim_state_dict[PG]): - params = loaded_param_group[PARAMS] + state[fqn] = cast(DictValueType, optim_state_dict[_STATE])[fqn] + for loaded_param_group in cast( + ListDictValueType, optim_state_dict[_PG] + ): + params = loaded_param_group[_PARAMS] assert isinstance(params, list) if fqn in params: - pg_mapping[id(loaded_param_group)] = len(return_osd[PG]) - 1 + pg_mapping[id(loaded_param_group)] = len(return_osd[_PG]) - 1 - for param_group in cast(ListDictValueType, optim_state_dict[PG]): + for param_group in cast(ListDictValueType, optim_state_dict[_PG]): idx = pg_mapping.get(id(param_group), -1) if idx == -1: continue for key, value in param_group.items(): - if key == PARAMS: + if key == _PARAMS: continue # TODO: check if value is the same if exists. pg_state[idx][key] = value @@ -843,7 +832,7 @@ def _load_optim_state_dict( for optim in optimizers: _init_optim_state(optim) if state_dict: - if STATE in state_dict: + if _STATE in state_dict: optim_state_dict = _split_optim_state_dict( model, optim, state_dict, info ) @@ -867,13 +856,13 @@ def _load_optim_state_dict( assert len(fqns) == 1 fqn = fqns.pop() fqn_with_compiler = fqns_with_compiler.pop() - for g in optim_state_dict[PG]: + for g in optim_state_dict[_PG]: val = cast(Dict[str, Any], g) params = [ - key.replace(fqn, fqn_with_compiler) for key in val[PARAMS] + key.replace(fqn, fqn_with_compiler) for key in val[_PARAMS] ] - val[PARAMS] = params - osd_state = cast(DictValueType, optim_state_dict[STATE]) + val[_PARAMS] = params + osd_state = cast(DictValueType, optim_state_dict[_STATE]) for k in list(osd_state.keys()): if fqn in k: osd_state[k.replace(fqn, fqn_with_compiler)] = osd_state.pop(k) @@ -916,8 +905,8 @@ def _device(t): ) # Note that we do not have to convert the FQN back to param id here if - # order in optim.param_groups[idx][PARAMS] is the same as the one in - # optim_state_dict[PG][idx][PARAMS]. + # order in optim.param_groups[idx][_PARAMS] is the same as the one in + # optim_state_dict[_PG][idx][_PARAMS]. _state_dict_fn(optim, "load_state_dict")(state_dict=optim_state_dict) @@ -945,7 +934,7 @@ def get_model_state_dict( :rtype: typing.Dict[str, ValueType] """ - with gc_context(): + with _gc_context(): info = _verify_options( model, tuple(), @@ -985,7 +974,7 @@ def get_optimizer_state_dict( :rtype: OptimizerStateType """ - with gc_context(): + with _gc_context(): optimizers = ( (optimizers,) if isinstance(optimizers, torch.optim.Optimizer) @@ -1073,7 +1062,7 @@ def get_state_dict( :rtype: typing.Tuple[typing.Dict[str, ValueType], OptimizerStateType] """ - with gc_context(): + with _gc_context(): optimizers = ( (optimizers,) if isinstance(optimizers, torch.optim.Optimizer) @@ -1157,7 +1146,7 @@ def set_model_state_dict( model_state_dict: Dict[str, ValueType] = _unflatten_model_state_dict( model, model_state_dict ) - with gc_context(): + with _gc_context(): info = _verify_options(model, tuple(), optim_only=False, options=options) _verify_state_dict(model_state_dict, {}, info) @@ -1191,7 +1180,7 @@ def set_optimizer_state_dict( :type optim_state_dict: typing.OptimizerStateType """ - with gc_context(): + with _gc_context(): optimizers = ( (optimizers,) if isinstance(optimizers, torch.optim.Optimizer) @@ -1248,7 +1237,7 @@ def set_state_dict( model_state_dict: Dict[str, ValueType] = _unflatten_model_state_dict( model, model_state_dict ) - with gc_context(): + with _gc_context(): optimizers = ( (optimizers,) if isinstance(optimizers, torch.optim.Optimizer) diff --git a/torch/testing/_internal/distributed/common_state_dict.py b/torch/testing/_internal/distributed/common_state_dict.py index b1cfda3e1cc3..68783b9856c1 100644 --- a/torch/testing/_internal/distributed/common_state_dict.py +++ b/torch/testing/_internal/distributed/common_state_dict.py @@ -13,9 +13,9 @@ from torch.distributed._state_dict_utils import _gather_state_dict from torch.distributed._tensor import DTensor from torch.distributed.checkpoint.state_dict import ( - PG, + _PG, + _STATE, set_state_dict, - STATE, StateDictOptions, ) @@ -64,10 +64,10 @@ def _verify_osd( fqn_pid_mapping[pid] = fqn # Check optimizer_state_dict state - self.assertEqual(len(osd[STATE]), len(dist_osd[STATE])) - for pid, states in osd[STATE].items(): + self.assertEqual(len(osd[_STATE]), len(dist_osd[_STATE])) + for pid, states in osd[_STATE].items(): fqn = fqn_pid_mapping[pid] - dist_states = dist_osd[STATE].get(fqn, None) + dist_states = dist_osd[_STATE].get(fqn, None) self.assertIsNotNone(dist_states, fqn) self.assertEqual(len(states), len(dist_states)) for key, state in states.items(): @@ -76,17 +76,17 @@ def _verify_osd( self._compare_tensor(state, dist_state) # Check optimizer_state_dict param_group - old_dist_osd_pg = dist_osd[PG] - if len(osd[PG]) != len(dist_osd[PG]): - self.assertTrue(len(dist_osd[PG]) > len(osd[PG])) - new_pg = copy.deepcopy(dist_osd[PG][0]) + old_dist_osd_pg = dist_osd[_PG] + if len(osd[_PG]) != len(dist_osd[_PG]): + self.assertTrue(len(dist_osd[_PG]) > len(osd[_PG])) + new_pg = copy.deepcopy(dist_osd[_PG][0]) new_pg["params"] = [] - for dist_group in dist_osd[PG]: + for dist_group in dist_osd[_PG]: new_pg["params"].extend(dist_group["params"]) - dist_osd[PG] = [new_pg] + dist_osd[_PG] = [new_pg] - self.assertEqual(len(osd[PG]), len(dist_osd[PG])) - for group, dist_group in zip(osd[PG], dist_osd[PG]): + self.assertEqual(len(osd[_PG]), len(dist_osd[_PG])) + for group, dist_group in zip(osd[_PG], dist_osd[_PG]): self.assertEqual(len(group), len(dist_group)) for key, value in group.items(): # Below doesn't work because param_groups can have None @@ -99,7 +99,7 @@ def _verify_osd( self.assertEqual(sorted(fqns), sorted(dist_value)) else: self.assertEqual(value, dist_value) - dist_osd[PG] = old_dist_osd_pg + dist_osd[_PG] = old_dist_osd_pg def _verify_osd_by_load( self, From dcfa7702c3ecd8754e8a66bc49142de00c8474ee Mon Sep 17 00:00:00 2001 From: Aaron Orenstein Date: Sat, 8 Jun 2024 11:08:21 -0700 Subject: [PATCH 536/706] Flip default value for mypy disallow_untyped_defs [1/11] (#127838) See #127836 for details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127838 Approved by: https://github.com/oulgen --- caffe2/perfkernels/hp_emblookup_codegen.py | 1 + test/test_bundled_images.py | 1 + test/test_bundled_inputs.py | 1 + test/test_complex.py | 1 + test/test_futures.py | 1 + test/test_torch.py | 1 + test/test_type_hints.py | 1 + test/test_type_info.py | 1 + test/test_utils.py | 1 + torch/_C/_VariableFunctions.pyi.in | 1 + torch/_C/__init__.pyi.in | 1 + torch/_C/_autograd.pyi | 1 + torch/_C/_distributed_autograd.pyi | 1 + torch/_C/_distributed_c10d.pyi | 1 + torch/_C/_distributed_rpc.pyi | 1 + torch/_C/_dynamo/eval_frame.pyi | 1 + torch/_C/_dynamo/guards.pyi | 1 + torch/_C/_functorch.pyi | 1 + torch/_C/_lazy.pyi | 1 + torch/_C/_lazy_ts_backend.pyi | 1 + torch/_C/_nvtx.pyi | 1 + torch/_C_flatbuffer/__init__.pyi | 1 + torch/__config__.py | 1 + torch/__init__.py | 1 + torch/_classes.py | 1 + torch/_compile.py | 1 + torch/_custom_op/autograd.py | 1 + torch/_custom_op/functional.py | 1 + torch/_custom_op/impl.py | 1 + torch/_custom_ops.py | 1 + torch/_decomp/__init__.py | 1 + torch/_decomp/decompositions.py | 1 + torch/_decomp/decompositions_for_jvp.py | 1 + torch/_decomp/decompositions_for_rng.py | 1 + torch/_deploy.py | 1 + torch/_dispatch/python.py | 1 + torch/_dynamo/_trace_wrapped_higher_order_op.py | 1 + torch/_dynamo/bytecode_analysis.py | 1 + torch/_dynamo/bytecode_transformation.py | 1 + torch/_dynamo/cache_size.py | 1 + torch/_dynamo/callback.py | 1 + torch/_dynamo/code_context.py | 1 + torch/_dynamo/codegen.py | 1 + torch/_dynamo/compiled_autograd.py | 1 + torch/_dynamo/comptime.py | 1 + torch/_dynamo/config.py | 1 + torch/_dynamo/convert_frame.py | 1 + torch/_dynamo/create_parameter_op.py | 1 + torch/_dynamo/current_scope_id.py | 1 + torch/_dynamo/debug_utils.py | 1 + torch/_dynamo/decorators.py | 1 + torch/_dynamo/device_interface.py | 1 + torch/_dynamo/eval_frame.py | 1 + torch/_dynamo/exc.py | 1 + torch/_dynamo/external_utils.py | 1 + torch/_dynamo/guards.py | 1 + torch/_dynamo/logging.py | 1 + torch/_dynamo/mutation_guard.py | 1 + torch/_dynamo/output_graph.py | 1 + torch/_dynamo/profiler.py | 1 + torch/_dynamo/replay_record.py | 1 + torch/_dynamo/repro/after_aot.py | 1 + torch/_dynamo/repro/after_dynamo.py | 1 + torch/_dynamo/resume_execution.py | 1 + torch/_dynamo/side_effects.py | 1 + torch/_dynamo/source.py | 1 + torch/_dynamo/symbolic_convert.py | 1 + torch/_dynamo/tensor_version_op.py | 1 + torch/_dynamo/test_case.py | 1 + torch/_dynamo/test_minifier_common.py | 1 + torch/_dynamo/testing.py | 1 + torch/_dynamo/trace_rules.py | 1 + torch/_dynamo/utils.py | 1 + torch/_dynamo/variables/script_object.py | 1 + torch/_dynamo/variables/torch.py | 1 + torch/_export/__init__.py | 1 + torch/_export/converter.py | 1 + torch/_export/db/case.py | 1 + torch/_export/db/examples/__init__.py | 1 + torch/_export/db/examples/assume_constant_result.py | 1 + torch/_export/db/examples/autograd_function.py | 1 + torch/_export/db/examples/class_method.py | 1 + torch/_export/db/examples/cond_branch_class_method.py | 1 + torch/_export/db/examples/cond_branch_nested_function.py | 1 + torch/_export/db/examples/cond_branch_nonlocal_variables.py | 1 + torch/_export/db/examples/cond_closed_over_variable.py | 1 + torch/_export/db/examples/cond_operands.py | 1 + torch/_export/db/examples/cond_predicate.py | 1 + torch/_export/db/examples/constrain_as_size_example.py | 1 + torch/_export/db/examples/constrain_as_value_example.py | 1 + torch/_export/db/examples/decorator.py | 1 + torch/_export/db/examples/dictionary.py | 1 + torch/_export/db/examples/dynamic_shape_assert.py | 1 + torch/_export/db/examples/dynamic_shape_constructor.py | 1 + torch/_export/db/examples/dynamic_shape_if_guard.py | 1 + torch/_export/db/examples/dynamic_shape_map.py | 1 + torch/_export/db/examples/dynamic_shape_round.py | 1 + torch/_export/db/examples/dynamic_shape_slicing.py | 1 + torch/_export/db/examples/dynamic_shape_view.py | 1 + torch/_export/db/examples/fn_with_kwargs.py | 1 + torch/_export/db/examples/list_contains.py | 1 + torch/_export/db/examples/list_unpack.py | 1 + torch/_export/db/examples/model_attr_mutation.py | 1 + torch/_export/db/examples/nested_function.py | 1 + torch/_export/db/examples/null_context_manager.py | 1 + torch/_export/db/examples/optional_input.py | 1 + torch/_export/db/examples/pytree_flatten.py | 1 + torch/_export/db/examples/scalar_output.py | 1 + torch/_export/db/examples/specialized_attribute.py | 1 + torch/_export/db/examples/static_for_loop.py | 1 + torch/_export/db/examples/static_if.py | 1 + torch/_export/db/examples/tensor_setattr.py | 1 + torch/_export/db/examples/torch_sym_min.py | 1 + torch/_export/db/examples/type_reflection_method.py | 1 + torch/_export/db/examples/user_input_mutation.py | 1 + torch/_export/db/logging.py | 1 + torch/_export/exported_program.py | 1 + torch/_export/non_strict_utils.py | 1 + torch/_export/pass_base.py | 1 + torch/_export/pass_infra/proxy_value.py | 1 + torch/_export/passes/_node_metadata_hook.py | 1 + .../passes/add_runtime_assertions_for_constraints_pass.py | 1 + torch/_export/passes/collect_tracepoints_pass.py | 1 + torch/_export/passes/lift_constants_pass.py | 1 + torch/_export/passes/remove_runtime_assertions.py | 1 + torch/_export/passes/replace_set_grad_with_hop_pass.py | 1 + torch/_export/passes/replace_sym_size_ops_pass.py | 1 + torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py | 1 + 128 files changed, 128 insertions(+) diff --git a/caffe2/perfkernels/hp_emblookup_codegen.py b/caffe2/perfkernels/hp_emblookup_codegen.py index 7e4208caf655..26018c2c002c 100644 --- a/caffe2/perfkernels/hp_emblookup_codegen.py +++ b/caffe2/perfkernels/hp_emblookup_codegen.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import argparse diff --git a/test/test_bundled_images.py b/test/test_bundled_images.py index 73f51d008bb1..c6ed9efe9f64 100644 --- a/test/test_bundled_images.py +++ b/test/test_bundled_images.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 # Owner(s): ["oncall: mobile"] +# mypy: allow-untyped-defs import io diff --git a/test/test_bundled_inputs.py b/test/test_bundled_inputs.py index 2ba1ee847e8b..007fbd32dde4 100644 --- a/test/test_bundled_inputs.py +++ b/test/test_bundled_inputs.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 # Owner(s): ["oncall: mobile"] +# mypy: allow-untyped-defs import io import textwrap diff --git a/test/test_complex.py b/test/test_complex.py index 04fa566bf94f..67e8732dcbe1 100644 --- a/test/test_complex.py +++ b/test/test_complex.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Owner(s): ["module: complex"] import torch diff --git a/test/test_futures.py b/test/test_futures.py index 33814eda41ea..dd1e79ff83b3 100644 --- a/test/test_futures.py +++ b/test/test_futures.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Owner(s): ["module: unknown"] import threading diff --git a/test/test_torch.py b/test/test_torch.py index ff573706913f..f252ddf4a574 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Owner(s): ["module: tests"] import torch diff --git a/test/test_type_hints.py b/test/test_type_hints.py index a4ae1768cd2a..2fba1ba2f9e4 100644 --- a/test/test_type_hints.py +++ b/test/test_type_hints.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Owner(s): ["module: typing"] import doctest diff --git a/test/test_type_info.py b/test/test_type_info.py index 97bb23e89c99..9160c31b4fb8 100644 --- a/test/test_type_info.py +++ b/test/test_type_info.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Owner(s): ["module: typing"] from torch.testing._internal.common_utils import ( diff --git a/test/test_utils.py b/test/test_utils.py index 66d66b8874f1..b0435e548311 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Owner(s): ["module: unknown"] import os diff --git a/torch/_C/_VariableFunctions.pyi.in b/torch/_C/_VariableFunctions.pyi.in index 24f9f0f9e9fb..9476acb75791 100644 --- a/torch/_C/_VariableFunctions.pyi.in +++ b/torch/_C/_VariableFunctions.pyi.in @@ -1,5 +1,6 @@ # ${generated_comment} # mypy: disable-error-code="type-arg" +# mypy: allow-untyped-defs import builtins from typing import ( diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index bcc26350a896..4326cd3c71da 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -1,5 +1,6 @@ # ${generated_comment} # mypy: disable-error-code="type-arg" +# mypy: allow-untyped-defs import builtins from enum import Enum, IntEnum diff --git a/torch/_C/_autograd.pyi b/torch/_C/_autograd.pyi index 118d913f6815..05a791725608 100644 --- a/torch/_C/_autograd.pyi +++ b/torch/_C/_autograd.pyi @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from enum import Enum from typing import Any, Callable, List, Optional, Set diff --git a/torch/_C/_distributed_autograd.pyi b/torch/_C/_distributed_autograd.pyi index f4c91304a1b1..dc2a9e9488a9 100644 --- a/torch/_C/_distributed_autograd.pyi +++ b/torch/_C/_distributed_autograd.pyi @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any, Dict, List, Set import torch diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index d6f7ae259a88..cffbf22219c8 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # mypy: disable-error-code="type-arg" from datetime import timedelta from enum import Enum diff --git a/torch/_C/_distributed_rpc.pyi b/torch/_C/_distributed_rpc.pyi index 7909e0b8e33c..ded7061bbd49 100644 --- a/torch/_C/_distributed_rpc.pyi +++ b/torch/_C/_distributed_rpc.pyi @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # mypy: disable-error-code="type-arg" from datetime import timedelta from typing import Any, Dict, Generic, List, Optional, overload, Tuple, Type, TypeVar diff --git a/torch/_C/_dynamo/eval_frame.pyi b/torch/_C/_dynamo/eval_frame.pyi index f3ad6f722827..14321b2f946f 100644 --- a/torch/_C/_dynamo/eval_frame.pyi +++ b/torch/_C/_dynamo/eval_frame.pyi @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import types from typing import List, NewType, Optional diff --git a/torch/_C/_dynamo/guards.pyi b/torch/_C/_dynamo/guards.pyi index 2de2f10cd328..6b1cf00bce41 100644 --- a/torch/_C/_dynamo/guards.pyi +++ b/torch/_C/_dynamo/guards.pyi @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any, Dict, List, Optional, Union import torch diff --git a/torch/_C/_functorch.pyi b/torch/_C/_functorch.pyi index 111113221a0c..0180586d0bc3 100644 --- a/torch/_C/_functorch.pyi +++ b/torch/_C/_functorch.pyi @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from enum import Enum from typing import Optional, Tuple diff --git a/torch/_C/_lazy.pyi b/torch/_C/_lazy.pyi index ceaaedee2102..f4f57ee56b34 100644 --- a/torch/_C/_lazy.pyi +++ b/torch/_C/_lazy.pyi @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List from torch import Tensor diff --git a/torch/_C/_lazy_ts_backend.pyi b/torch/_C/_lazy_ts_backend.pyi index ce833c5ec2e4..b5e69583dbb9 100644 --- a/torch/_C/_lazy_ts_backend.pyi +++ b/torch/_C/_lazy_ts_backend.pyi @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # defined in torch/csrc/lazy/python/init.cpp from typing import Any, List, Tuple diff --git a/torch/_C/_nvtx.pyi b/torch/_C/_nvtx.pyi index f7ff779d8ad7..79c9cc2c4b9b 100644 --- a/torch/_C/_nvtx.pyi +++ b/torch/_C/_nvtx.pyi @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Defined in torch/csrc/cuda/shared/nvtx.cpp def rangePushA(message: str) -> int: ... def rangePop() -> int: ... diff --git a/torch/_C_flatbuffer/__init__.pyi b/torch/_C_flatbuffer/__init__.pyi index 3a2ff059b0ed..38750ed26aa2 100644 --- a/torch/_C_flatbuffer/__init__.pyi +++ b/torch/_C_flatbuffer/__init__.pyi @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch._C import LiteScriptModule, ScriptModule def _load_mobile_module_from_file(filename: str): ... diff --git a/torch/__config__.py b/torch/__config__.py index f7e3e209654a..fdb091032759 100644 --- a/torch/__config__.py +++ b/torch/__config__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch diff --git a/torch/__init__.py b/torch/__init__.py index 16804ff75898..b07d4ea1c180 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r""" The torch package contains data structures for multi-dimensional diff --git a/torch/_classes.py b/torch/_classes.py index 870073fea6ea..58b347453524 100644 --- a/torch/_classes.py +++ b/torch/_classes.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import types import torch._C diff --git a/torch/_compile.py b/torch/_compile.py index 2b00415e0eba..0f0f51a3509a 100644 --- a/torch/_compile.py +++ b/torch/_compile.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ APIs related to torch.compile which lazily import torch._dynamo to avoid circular dependencies. diff --git a/torch/_custom_op/autograd.py b/torch/_custom_op/autograd.py index 116a4612a45e..35727197d03c 100644 --- a/torch/_custom_op/autograd.py +++ b/torch/_custom_op/autograd.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.utils._pytree as pytree from collections import namedtuple diff --git a/torch/_custom_op/functional.py b/torch/_custom_op/functional.py index 26ef5b307bd5..57ff351e2e2d 100644 --- a/torch/_custom_op/functional.py +++ b/torch/_custom_op/functional.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import weakref import torch diff --git a/torch/_custom_op/impl.py b/torch/_custom_op/impl.py index d9200160057c..2f3efce60a81 100644 --- a/torch/_custom_op/impl.py +++ b/torch/_custom_op/impl.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import dataclasses import functools import inspect diff --git a/torch/_custom_ops.py b/torch/_custom_ops.py index c09a8ae68543..b8231a186c0a 100644 --- a/torch/_custom_ops.py +++ b/torch/_custom_ops.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import inspect from torch._custom_op.impl import ( diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py index b277bb7eceb0..e0c7e5b6f49d 100644 --- a/torch/_decomp/__init__.py +++ b/torch/_decomp/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import inspect from collections import defaultdict from functools import wraps diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index 76599d299b29..7c9d342ea0f0 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import numbers import operator diff --git a/torch/_decomp/decompositions_for_jvp.py b/torch/_decomp/decompositions_for_jvp.py index d430386ff360..ce47ac43d372 100644 --- a/torch/_decomp/decompositions_for_jvp.py +++ b/torch/_decomp/decompositions_for_jvp.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import inspect from typing import Callable, Dict, List, Optional, Tuple diff --git a/torch/_decomp/decompositions_for_rng.py b/torch/_decomp/decompositions_for_rng.py index 1aa762351171..74eb9b9240ae 100644 --- a/torch/_decomp/decompositions_for_rng.py +++ b/torch/_decomp/decompositions_for_rng.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools from collections import defaultdict from typing import Callable, Dict diff --git a/torch/_deploy.py b/torch/_deploy.py index 35e8d4976940..3f8adc420672 100644 --- a/torch/_deploy.py +++ b/torch/_deploy.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import io import torch diff --git a/torch/_dispatch/python.py b/torch/_dispatch/python.py index d80839dc7e47..1d36623ba861 100644 --- a/torch/_dispatch/python.py +++ b/torch/_dispatch/python.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import itertools import unittest.mock from contextlib import contextmanager diff --git a/torch/_dynamo/_trace_wrapped_higher_order_op.py b/torch/_dynamo/_trace_wrapped_higher_order_op.py index 6e22cafcc6dd..30e44b000fd5 100644 --- a/torch/_dynamo/_trace_wrapped_higher_order_op.py +++ b/torch/_dynamo/_trace_wrapped_higher_order_op.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._C import DispatchKey from torch._higher_order_ops.utils import autograd_not_implemented diff --git a/torch/_dynamo/bytecode_analysis.py b/torch/_dynamo/bytecode_analysis.py index 340378e7266b..541c3e0cc882 100644 --- a/torch/_dynamo/bytecode_analysis.py +++ b/torch/_dynamo/bytecode_analysis.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import bisect import dataclasses import dis diff --git a/torch/_dynamo/bytecode_transformation.py b/torch/_dynamo/bytecode_transformation.py index f07fe1c7a0e0..63dbdf048f6c 100644 --- a/torch/_dynamo/bytecode_transformation.py +++ b/torch/_dynamo/bytecode_transformation.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import dataclasses import dis diff --git a/torch/_dynamo/cache_size.py b/torch/_dynamo/cache_size.py index 340f227a9956..ea5e2ae0ce10 100644 --- a/torch/_dynamo/cache_size.py +++ b/torch/_dynamo/cache_size.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging import types import weakref diff --git a/torch/_dynamo/callback.py b/torch/_dynamo/callback.py index a65e2844f215..35f447a80349 100644 --- a/torch/_dynamo/callback.py +++ b/torch/_dynamo/callback.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs class CompilationCallbackHandler: def __init__(self): self.start_callbacks = [] diff --git a/torch/_dynamo/code_context.py b/torch/_dynamo/code_context.py index 0fe19016ca13..59c912bd30f7 100644 --- a/torch/_dynamo/code_context.py +++ b/torch/_dynamo/code_context.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import types from .utils import ExactWeakKeyDictionary diff --git a/torch/_dynamo/codegen.py b/torch/_dynamo/codegen.py index 6dbd7f36b0b5..ac0d06d9f428 100644 --- a/torch/_dynamo/codegen.py +++ b/torch/_dynamo/codegen.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections import dataclasses import re diff --git a/torch/_dynamo/compiled_autograd.py b/torch/_dynamo/compiled_autograd.py index bbc8d722b7e2..f13b53e7ed5f 100644 --- a/torch/_dynamo/compiled_autograd.py +++ b/torch/_dynamo/compiled_autograd.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import functools from typing import Dict, List, Optional, TYPE_CHECKING diff --git a/torch/_dynamo/comptime.py b/torch/_dynamo/comptime.py index 80880588b54e..ffb9fbc47cca 100644 --- a/torch/_dynamo/comptime.py +++ b/torch/_dynamo/comptime.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # This file establishes the public comptime interface to Dynamo. # This allows Dynamo users to execute arbitrary Python code while # Dynamo is symbolically evaluating their original programs. diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index 6487a2726381..bf3d35c334aa 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import getpass import inspect import os diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 88fb2a85bca2..a1d7e7e6e130 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections import cProfile import dis diff --git a/torch/_dynamo/create_parameter_op.py b/torch/_dynamo/create_parameter_op.py index 42981fcf1015..f6cd12de2021 100644 --- a/torch/_dynamo/create_parameter_op.py +++ b/torch/_dynamo/create_parameter_op.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch doc = """ diff --git a/torch/_dynamo/current_scope_id.py b/torch/_dynamo/current_scope_id.py index 1289bdcdffe4..ad079875b58a 100644 --- a/torch/_dynamo/current_scope_id.py +++ b/torch/_dynamo/current_scope_id.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import threading diff --git a/torch/_dynamo/debug_utils.py b/torch/_dynamo/debug_utils.py index 4b4b37a34da9..e262f8cbdb71 100644 --- a/torch/_dynamo/debug_utils.py +++ b/torch/_dynamo/debug_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # mypy: disable-error-code="method-assign" import copy diff --git a/torch/_dynamo/decorators.py b/torch/_dynamo/decorators.py index 01c629709bd8..557b9a72dde1 100644 --- a/torch/_dynamo/decorators.py +++ b/torch/_dynamo/decorators.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # ruff: noqa: TCH004 from dataclasses import dataclass from typing import TYPE_CHECKING diff --git a/torch/_dynamo/device_interface.py b/torch/_dynamo/device_interface.py index d93a26546683..aa8848014b34 100644 --- a/torch/_dynamo/device_interface.py +++ b/torch/_dynamo/device_interface.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import inspect from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, Union diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 94cad71f7ef5..2fc451cf3d17 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # mypy: disable-error-code="method-assign" """ diff --git a/torch/_dynamo/exc.py b/torch/_dynamo/exc.py index d9f4c847d030..f3cc073b8a30 100644 --- a/torch/_dynamo/exc.py +++ b/torch/_dynamo/exc.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import os import textwrap from enum import auto, Enum diff --git a/torch/_dynamo/external_utils.py b/torch/_dynamo/external_utils.py index 1aea186bb679..7982e77f6feb 100644 --- a/torch/_dynamo/external_utils.py +++ b/torch/_dynamo/external_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # This module contains functions that *will be allowed* by dynamo import functools diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index ac46b4df0f38..fc3f12847a75 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import ast diff --git a/torch/_dynamo/logging.py b/torch/_dynamo/logging.py index 1e9a820785be..316b3ec817cb 100644 --- a/torch/_dynamo/logging.py +++ b/torch/_dynamo/logging.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import itertools import logging diff --git a/torch/_dynamo/mutation_guard.py b/torch/_dynamo/mutation_guard.py index 1fa24cfa25bb..22e2b9999e03 100644 --- a/torch/_dynamo/mutation_guard.py +++ b/torch/_dynamo/mutation_guard.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # mypy: disable-error-code="method-assign" import functools diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 03e7f844cbd4..ee43db8524a7 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections import contextlib import copy diff --git a/torch/_dynamo/profiler.py b/torch/_dynamo/profiler.py index b52551c67137..b7e9553ce219 100644 --- a/torch/_dynamo/profiler.py +++ b/torch/_dynamo/profiler.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import dataclasses import os from typing import Any, List diff --git a/torch/_dynamo/replay_record.py b/torch/_dynamo/replay_record.py index 7a312e5d58a9..0049dfe7d3ef 100644 --- a/torch/_dynamo/replay_record.py +++ b/torch/_dynamo/replay_record.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import dataclasses from dataclasses import field from types import CodeType, ModuleType diff --git a/torch/_dynamo/repro/after_aot.py b/torch/_dynamo/repro/after_aot.py index 0dbf3cd5c0e4..98149c72c02c 100644 --- a/torch/_dynamo/repro/after_aot.py +++ b/torch/_dynamo/repro/after_aot.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import argparse import copy import functools diff --git a/torch/_dynamo/repro/after_dynamo.py b/torch/_dynamo/repro/after_dynamo.py index 43f761f84d3d..254f293951ee 100644 --- a/torch/_dynamo/repro/after_dynamo.py +++ b/torch/_dynamo/repro/after_dynamo.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import argparse import copy import functools diff --git a/torch/_dynamo/resume_execution.py b/torch/_dynamo/resume_execution.py index 387adc06272a..3dae1b3b9b10 100644 --- a/torch/_dynamo/resume_execution.py +++ b/torch/_dynamo/resume_execution.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import dataclasses import sys diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index 647fae379c54..229282f709cb 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import inspect import warnings from typing import Any, Dict, List, Optional, Union diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index ded62ba97d8a..69423712c53c 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections import dataclasses import enum diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 71ed48fbb292..41ceaa615916 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections import collections.abc import contextlib diff --git a/torch/_dynamo/tensor_version_op.py b/torch/_dynamo/tensor_version_op.py index 4c4246474c1d..290f03ad0c6e 100644 --- a/torch/_dynamo/tensor_version_op.py +++ b/torch/_dynamo/tensor_version_op.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._prims import _make_prim, RETURN_TYPE from torch._subclasses import FakeTensorMode diff --git a/torch/_dynamo/test_case.py b/torch/_dynamo/test_case.py index 297ea6e2bc2a..0489b6acc963 100644 --- a/torch/_dynamo/test_case.py +++ b/torch/_dynamo/test_case.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import importlib import logging diff --git a/torch/_dynamo/test_minifier_common.py b/torch/_dynamo/test_minifier_common.py index d12e5a92315a..4736c75785cc 100644 --- a/torch/_dynamo/test_minifier_common.py +++ b/torch/_dynamo/test_minifier_common.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import dataclasses import io import logging diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py index 99b6607afead..527e0138fc25 100644 --- a/torch/_dynamo/testing.py +++ b/torch/_dynamo/testing.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import dis import functools diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 4f2bee755ae7..585ddb04eda1 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import _collections_abc import _weakrefset import abc diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 60be7898d929..6da8b514f16b 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import atexit import collections import contextlib diff --git a/torch/_dynamo/variables/script_object.py b/torch/_dynamo/variables/script_object.py index 70354e28bb3d..923437193640 100644 --- a/torch/_dynamo/variables/script_object.py +++ b/torch/_dynamo/variables/script_object.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools from typing import Dict diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 0b3e28860aaf..4d7b96b6a320 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import inspect import logging diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py index d41ff4b53af0..d9a514232569 100644 --- a/torch/_export/__init__.py +++ b/torch/_export/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import dataclasses import functools diff --git a/torch/_export/converter.py b/torch/_export/converter.py index 777249c24a2a..20b6101948de 100644 --- a/torch/_export/converter.py +++ b/torch/_export/converter.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import operator from typing import Any, Dict, List, Optional, Set, Tuple, Union diff --git a/torch/_export/db/case.py b/torch/_export/db/case.py index 6c4c03572e3a..21b456fbe029 100644 --- a/torch/_export/db/case.py +++ b/torch/_export/db/case.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import inspect import re import string diff --git a/torch/_export/db/examples/__init__.py b/torch/_export/db/examples/__init__.py index d737548c3d48..2e93d4b80824 100644 --- a/torch/_export/db/examples/__init__.py +++ b/torch/_export/db/examples/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import glob import importlib from os.path import basename, dirname, isfile, join diff --git a/torch/_export/db/examples/assume_constant_result.py b/torch/_export/db/examples/assume_constant_result.py index 0078200bc0f0..1503e0c91134 100644 --- a/torch/_export/db/examples/assume_constant_result.py +++ b/torch/_export/db/examples/assume_constant_result.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch._dynamo as torchdynamo diff --git a/torch/_export/db/examples/autograd_function.py b/torch/_export/db/examples/autograd_function.py index 9c8aeadc45ae..3c9099b0cdb8 100644 --- a/torch/_export/db/examples/autograd_function.py +++ b/torch/_export/db/examples/autograd_function.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case diff --git a/torch/_export/db/examples/class_method.py b/torch/_export/db/examples/class_method.py index 838a0a1cdb67..831339372274 100644 --- a/torch/_export/db/examples/class_method.py +++ b/torch/_export/db/examples/class_method.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case diff --git a/torch/_export/db/examples/cond_branch_class_method.py b/torch/_export/db/examples/cond_branch_class_method.py index 40430d23c0f2..21fe1d25516a 100644 --- a/torch/_export/db/examples/cond_branch_class_method.py +++ b/torch/_export/db/examples/cond_branch_class_method.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case diff --git a/torch/_export/db/examples/cond_branch_nested_function.py b/torch/_export/db/examples/cond_branch_nested_function.py index 00bce0b580a1..03639c0a207d 100644 --- a/torch/_export/db/examples/cond_branch_nested_function.py +++ b/torch/_export/db/examples/cond_branch_nested_function.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case diff --git a/torch/_export/db/examples/cond_branch_nonlocal_variables.py b/torch/_export/db/examples/cond_branch_nonlocal_variables.py index 2db6192117df..676e7d21ffd2 100644 --- a/torch/_export/db/examples/cond_branch_nonlocal_variables.py +++ b/torch/_export/db/examples/cond_branch_nonlocal_variables.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case diff --git a/torch/_export/db/examples/cond_closed_over_variable.py b/torch/_export/db/examples/cond_closed_over_variable.py index 226576cc83f7..cf4787f481c4 100644 --- a/torch/_export/db/examples/cond_closed_over_variable.py +++ b/torch/_export/db/examples/cond_closed_over_variable.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case diff --git a/torch/_export/db/examples/cond_operands.py b/torch/_export/db/examples/cond_operands.py index 1a0db6a110d3..03fd467959a2 100644 --- a/torch/_export/db/examples/cond_operands.py +++ b/torch/_export/db/examples/cond_operands.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case diff --git a/torch/_export/db/examples/cond_predicate.py b/torch/_export/db/examples/cond_predicate.py index c72c11e32f57..fa3cdeaf3b05 100644 --- a/torch/_export/db/examples/cond_predicate.py +++ b/torch/_export/db/examples/cond_predicate.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case diff --git a/torch/_export/db/examples/constrain_as_size_example.py b/torch/_export/db/examples/constrain_as_size_example.py index 16d646252414..a3664b7e80f1 100644 --- a/torch/_export/db/examples/constrain_as_size_example.py +++ b/torch/_export/db/examples/constrain_as_size_example.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case diff --git a/torch/_export/db/examples/constrain_as_value_example.py b/torch/_export/db/examples/constrain_as_value_example.py index 1de266c689c4..b1b412d41391 100644 --- a/torch/_export/db/examples/constrain_as_value_example.py +++ b/torch/_export/db/examples/constrain_as_value_example.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case diff --git a/torch/_export/db/examples/decorator.py b/torch/_export/db/examples/decorator.py index fbc95182e60e..da963ce7da01 100644 --- a/torch/_export/db/examples/decorator.py +++ b/torch/_export/db/examples/decorator.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import torch diff --git a/torch/_export/db/examples/dictionary.py b/torch/_export/db/examples/dictionary.py index 5a210906e680..19f138e6f4d1 100644 --- a/torch/_export/db/examples/dictionary.py +++ b/torch/_export/db/examples/dictionary.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case diff --git a/torch/_export/db/examples/dynamic_shape_assert.py b/torch/_export/db/examples/dynamic_shape_assert.py index 52cc43a21049..57ba98552e0c 100644 --- a/torch/_export/db/examples/dynamic_shape_assert.py +++ b/torch/_export/db/examples/dynamic_shape_assert.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case diff --git a/torch/_export/db/examples/dynamic_shape_constructor.py b/torch/_export/db/examples/dynamic_shape_constructor.py index 599747f7968a..5ce7fdda2877 100644 --- a/torch/_export/db/examples/dynamic_shape_constructor.py +++ b/torch/_export/db/examples/dynamic_shape_constructor.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case diff --git a/torch/_export/db/examples/dynamic_shape_if_guard.py b/torch/_export/db/examples/dynamic_shape_if_guard.py index 2120ec0145fe..9350c6d992f5 100644 --- a/torch/_export/db/examples/dynamic_shape_if_guard.py +++ b/torch/_export/db/examples/dynamic_shape_if_guard.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case diff --git a/torch/_export/db/examples/dynamic_shape_map.py b/torch/_export/db/examples/dynamic_shape_map.py index 5607c2796d68..421d4b355efb 100644 --- a/torch/_export/db/examples/dynamic_shape_map.py +++ b/torch/_export/db/examples/dynamic_shape_map.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case diff --git a/torch/_export/db/examples/dynamic_shape_round.py b/torch/_export/db/examples/dynamic_shape_round.py index d581d6d839bc..57a1e07dab97 100644 --- a/torch/_export/db/examples/dynamic_shape_round.py +++ b/torch/_export/db/examples/dynamic_shape_round.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case, SupportLevel diff --git a/torch/_export/db/examples/dynamic_shape_slicing.py b/torch/_export/db/examples/dynamic_shape_slicing.py index eb237876f4e6..ddc2f86f774c 100644 --- a/torch/_export/db/examples/dynamic_shape_slicing.py +++ b/torch/_export/db/examples/dynamic_shape_slicing.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case diff --git a/torch/_export/db/examples/dynamic_shape_view.py b/torch/_export/db/examples/dynamic_shape_view.py index bcedd04cf36f..666da36ad2a8 100644 --- a/torch/_export/db/examples/dynamic_shape_view.py +++ b/torch/_export/db/examples/dynamic_shape_view.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case diff --git a/torch/_export/db/examples/fn_with_kwargs.py b/torch/_export/db/examples/fn_with_kwargs.py index 6182a7479555..d5a9a23415d9 100644 --- a/torch/_export/db/examples/fn_with_kwargs.py +++ b/torch/_export/db/examples/fn_with_kwargs.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case, ExportArgs, SupportLevel diff --git a/torch/_export/db/examples/list_contains.py b/torch/_export/db/examples/list_contains.py index d25d815cde1a..6105220c09b9 100644 --- a/torch/_export/db/examples/list_contains.py +++ b/torch/_export/db/examples/list_contains.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case diff --git a/torch/_export/db/examples/list_unpack.py b/torch/_export/db/examples/list_unpack.py index 2251c6eb360d..66b4fe456a0d 100644 --- a/torch/_export/db/examples/list_unpack.py +++ b/torch/_export/db/examples/list_unpack.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List import torch diff --git a/torch/_export/db/examples/model_attr_mutation.py b/torch/_export/db/examples/model_attr_mutation.py index 409a0c0f6c03..4c2a03d4e77b 100644 --- a/torch/_export/db/examples/model_attr_mutation.py +++ b/torch/_export/db/examples/model_attr_mutation.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case, SupportLevel diff --git a/torch/_export/db/examples/nested_function.py b/torch/_export/db/examples/nested_function.py index 608ef39d5187..cc668ee561a6 100644 --- a/torch/_export/db/examples/nested_function.py +++ b/torch/_export/db/examples/nested_function.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case diff --git a/torch/_export/db/examples/null_context_manager.py b/torch/_export/db/examples/null_context_manager.py index da759b0980fa..ff4b94e6bf44 100644 --- a/torch/_export/db/examples/null_context_manager.py +++ b/torch/_export/db/examples/null_context_manager.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import torch diff --git a/torch/_export/db/examples/optional_input.py b/torch/_export/db/examples/optional_input.py index 47bb5e1bab8d..dfc256d6a5ce 100644 --- a/torch/_export/db/examples/optional_input.py +++ b/torch/_export/db/examples/optional_input.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case, SupportLevel diff --git a/torch/_export/db/examples/pytree_flatten.py b/torch/_export/db/examples/pytree_flatten.py index 0d799b2a609a..9c91cc21df3c 100644 --- a/torch/_export/db/examples/pytree_flatten.py +++ b/torch/_export/db/examples/pytree_flatten.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case, SupportLevel diff --git a/torch/_export/db/examples/scalar_output.py b/torch/_export/db/examples/scalar_output.py index 86217847bff8..46e03c1f7e94 100644 --- a/torch/_export/db/examples/scalar_output.py +++ b/torch/_export/db/examples/scalar_output.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case diff --git a/torch/_export/db/examples/specialized_attribute.py b/torch/_export/db/examples/specialized_attribute.py index 3f8f09c4128d..a53ad213c63f 100644 --- a/torch/_export/db/examples/specialized_attribute.py +++ b/torch/_export/db/examples/specialized_attribute.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from enum import Enum import torch diff --git a/torch/_export/db/examples/static_for_loop.py b/torch/_export/db/examples/static_for_loop.py index af14f6fe8ae1..4ad60737ff5d 100644 --- a/torch/_export/db/examples/static_for_loop.py +++ b/torch/_export/db/examples/static_for_loop.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case diff --git a/torch/_export/db/examples/static_if.py b/torch/_export/db/examples/static_if.py index 048bf20ce8bf..bc5dce9f0667 100644 --- a/torch/_export/db/examples/static_if.py +++ b/torch/_export/db/examples/static_if.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case diff --git a/torch/_export/db/examples/tensor_setattr.py b/torch/_export/db/examples/tensor_setattr.py index fae18fb1cf93..201dca37c81a 100644 --- a/torch/_export/db/examples/tensor_setattr.py +++ b/torch/_export/db/examples/tensor_setattr.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case, SupportLevel diff --git a/torch/_export/db/examples/torch_sym_min.py b/torch/_export/db/examples/torch_sym_min.py index f7edc7003f14..a8fe560773a4 100644 --- a/torch/_export/db/examples/torch_sym_min.py +++ b/torch/_export/db/examples/torch_sym_min.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case, SupportLevel diff --git a/torch/_export/db/examples/type_reflection_method.py b/torch/_export/db/examples/type_reflection_method.py index 869fb4cadd65..5d6570ca0cb9 100644 --- a/torch/_export/db/examples/type_reflection_method.py +++ b/torch/_export/db/examples/type_reflection_method.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case, SupportLevel, export_rewrite_case diff --git a/torch/_export/db/examples/user_input_mutation.py b/torch/_export/db/examples/user_input_mutation.py index 01c5d775a264..b60036257617 100644 --- a/torch/_export/db/examples/user_input_mutation.py +++ b/torch/_export/db/examples/user_input_mutation.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case, SupportLevel diff --git a/torch/_export/db/logging.py b/torch/_export/db/logging.py index fc412b8c5082..8cd0827d3893 100644 --- a/torch/_export/db/logging.py +++ b/torch/_export/db/logging.py @@ -1,2 +1,3 @@ +# mypy: allow-untyped-defs def exportdb_error_message(case_name: str): return "" diff --git a/torch/_export/exported_program.py b/torch/_export/exported_program.py index 5d28ea315490..49dfd0cf996e 100644 --- a/torch/_export/exported_program.py +++ b/torch/_export/exported_program.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import warnings diff --git a/torch/_export/non_strict_utils.py b/torch/_export/non_strict_utils.py index d15cb29f28df..9db3653de1e2 100644 --- a/torch/_export/non_strict_utils.py +++ b/torch/_export/non_strict_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import inspect from collections import defaultdict diff --git a/torch/_export/pass_base.py b/torch/_export/pass_base.py index 1cf7e75ad5f9..2200193e78a5 100644 --- a/torch/_export/pass_base.py +++ b/torch/_export/pass_base.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import operator import traceback import typing diff --git a/torch/_export/pass_infra/proxy_value.py b/torch/_export/pass_infra/proxy_value.py index 66592d48a45e..07d888b30656 100644 --- a/torch/_export/pass_infra/proxy_value.py +++ b/torch/_export/pass_infra/proxy_value.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # pyre-strict from typing import Union diff --git a/torch/_export/passes/_node_metadata_hook.py b/torch/_export/passes/_node_metadata_hook.py index e04059a9114a..3dd87b546da8 100644 --- a/torch/_export/passes/_node_metadata_hook.py +++ b/torch/_export/passes/_node_metadata_hook.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import torch diff --git a/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py b/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py index 5a2a8b5874bf..44f0ea270212 100644 --- a/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py +++ b/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math import operator import traceback diff --git a/torch/_export/passes/collect_tracepoints_pass.py b/torch/_export/passes/collect_tracepoints_pass.py index ca8eaf30be59..8d65a720b9d7 100644 --- a/torch/_export/passes/collect_tracepoints_pass.py +++ b/torch/_export/passes/collect_tracepoints_pass.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import operator import torch diff --git a/torch/_export/passes/lift_constants_pass.py b/torch/_export/passes/lift_constants_pass.py index 83914fb828c5..d9cd62ffc928 100644 --- a/torch/_export/passes/lift_constants_pass.py +++ b/torch/_export/passes/lift_constants_pass.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections from typing import Any, Dict, List, Union diff --git a/torch/_export/passes/remove_runtime_assertions.py b/torch/_export/passes/remove_runtime_assertions.py index adcc708e5548..a80b62d2765a 100644 --- a/torch/_export/passes/remove_runtime_assertions.py +++ b/torch/_export/passes/remove_runtime_assertions.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.fx.passes.infra.pass_base import PassBase, PassResult diff --git a/torch/_export/passes/replace_set_grad_with_hop_pass.py b/torch/_export/passes/replace_set_grad_with_hop_pass.py index 91104c17c38d..0b0bef582e45 100644 --- a/torch/_export/passes/replace_set_grad_with_hop_pass.py +++ b/torch/_export/passes/replace_set_grad_with_hop_pass.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import copy diff --git a/torch/_export/passes/replace_sym_size_ops_pass.py b/torch/_export/passes/replace_sym_size_ops_pass.py index 109a96d7b4bd..29d594d41f06 100644 --- a/torch/_export/passes/replace_sym_size_ops_pass.py +++ b/torch/_export/passes/replace_sym_size_ops_pass.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Dict import torch diff --git a/torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py b/torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py index f32b442733eb..edc249b572b5 100644 --- a/torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py +++ b/torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Dict, Optional, Set import torch From ea614fb2b1c43f85c5a10ee1a227f90251b889d9 Mon Sep 17 00:00:00 2001 From: Aaron Orenstein Date: Sat, 8 Jun 2024 11:19:36 -0700 Subject: [PATCH 537/706] Flip default value for mypy disallow_untyped_defs [2/11] (#127839) See #127836 for details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127839 Approved by: https://github.com/oulgen --- torch/_export/serde/schema_check.py | 1 + torch/_export/serde/serialize.py | 1 + torch/_export/serde/union.py | 1 + torch/_export/serde/upgrade.py | 1 + torch/_export/tools.py | 1 + torch/_export/utils.py | 1 + torch/_export/verifier.py | 1 + torch/_export/wrappers.py | 1 + torch/_functorch/_aot_autograd/autograd_cache.py | 1 + torch/_functorch/_aot_autograd/collect_metadata_analysis.py | 1 + torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py | 1 + torch/_functorch/_aot_autograd/functional_utils.py | 1 + torch/_functorch/_aot_autograd/input_output_analysis.py | 1 + .../_functorch/_aot_autograd/jit_compile_runtime_wrappers.py | 1 + torch/_functorch/_aot_autograd/logging_utils.py | 1 + torch/_functorch/_aot_autograd/runtime_wrappers.py | 1 + torch/_functorch/_aot_autograd/schemas.py | 1 + torch/_functorch/_aot_autograd/subclass_utils.py | 1 + torch/_functorch/_aot_autograd/traced_function_transforms.py | 1 + torch/_functorch/_aot_autograd/utils.py | 1 + torch/_functorch/apis.py | 1 + torch/_functorch/autograd_function.py | 1 + torch/_functorch/batch_norm_replacement.py | 1 + torch/_functorch/deprecated.py | 1 + torch/_functorch/functional_call.py | 1 + torch/_functorch/make_functional.py | 1 + torch/_functorch/partitioners.py | 1 + torch/_functorch/pyfunctorch.py | 1 + torch/_functorch/utils.py | 1 + torch/_guards.py | 1 + torch/_higher_order_ops/associative_scan.py | 1 + torch/_higher_order_ops/auto_functionalize.py | 1 + torch/_higher_order_ops/cond.py | 1 + torch/_higher_order_ops/effects.py | 1 + torch/_higher_order_ops/flex_attention.py | 1 + torch/_higher_order_ops/map.py | 1 + torch/_higher_order_ops/out_dtype.py | 1 + torch/_higher_order_ops/strict_mode.py | 1 + torch/_higher_order_ops/torchbind.py | 1 + torch/_higher_order_ops/triton_kernel_wrap.py | 1 + torch/_higher_order_ops/utils.py | 1 + torch/_higher_order_ops/while_loop.py | 1 + torch/_higher_order_ops/wrap.py | 1 + torch/_inductor/__init__.py | 1 + torch/_inductor/async_compile.py | 1 + torch/_inductor/autotune_process.py | 1 + torch/_inductor/bounds.py | 1 + torch/_inductor/codecache.py | 1 + torch/_inductor/codegen/aoti_hipify_utils.py | 1 + torch/_inductor/codegen/common.py | 1 + torch/_inductor/codegen/cpp.py | 1 + torch/_inductor/codegen/cpp_gemm_template.py | 1 + torch/_inductor/codegen/cpp_micro_gemm.py | 1 + torch/_inductor/codegen/cpp_template.py | 1 + torch/_inductor/codegen/cpp_template_kernel.py | 1 + torch/_inductor/codegen/cpp_utils.py | 1 + torch/_inductor/codegen/cpp_wrapper_cpu.py | 1 + torch/_inductor/codegen/cpp_wrapper_cuda.py | 1 + torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py | 1 + torch/_inductor/codegen/cuda/cuda_kernel.py | 1 + torch/_inductor/codegen/cuda/cuda_template.py | 1 + torch/_inductor/codegen/cuda/cutlass_epilogue_gen.py | 1 + .../cuda/cutlass_lib_extensions/gemm_operation_extensions.py | 1 + torch/_inductor/codegen/cuda/cutlass_utils.py | 1 + torch/_inductor/codegen/cuda/device_op_overrides.py | 1 + torch/_inductor/codegen/cuda/gemm_template.py | 1 + torch/_inductor/codegen/cuda_combined_scheduling.py | 1 + torch/_inductor/codegen/memory_planning.py | 1 + torch/_inductor/codegen/multi_kernel.py | 1 + torch/_inductor/codegen/simd.py | 1 + torch/_inductor/codegen/triton.py | 1 + torch/_inductor/codegen/triton_foreach.py | 1 + torch/_inductor/codegen/triton_split_scan.py | 1 + torch/_inductor/codegen/triton_utils.py | 1 + torch/_inductor/codegen/wrapper.py | 1 + torch/_inductor/codegen/xpu/device_op_overrides.py | 1 + torch/_inductor/comms.py | 1 + torch/_inductor/compile_fx.py | 1 + torch/_inductor/compile_worker/__main__.py | 1 + torch/_inductor/compile_worker/subproc_pool.py | 1 + torch/_inductor/compile_worker/watchdog.py | 1 + torch/_inductor/config.py | 5 +++-- torch/_inductor/constant_folding.py | 1 + torch/_inductor/cudagraph_trees.py | 1 + torch/_inductor/cudagraph_utils.py | 1 + torch/_inductor/debug.py | 1 + torch/_inductor/decomposition.py | 1 + torch/_inductor/dependencies.py | 1 + torch/_inductor/exc.py | 1 + torch/_inductor/freezing.py | 1 + torch/_inductor/fx_passes/binary_folding.py | 1 + torch/_inductor/fx_passes/decompose_mem_bound_mm.py | 1 + torch/_inductor/fx_passes/dedupe_symint_uses.py | 1 + torch/_inductor/fx_passes/efficient_conv_bn_eval.py | 1 + torch/_inductor/fx_passes/freezing_patterns.py | 1 + torch/_inductor/fx_passes/fuse_attention.py | 1 + torch/_inductor/fx_passes/group_batch_fusion.py | 1 + torch/_inductor/fx_passes/joint_graph.py | 1 + torch/_inductor/fx_passes/misc_patterns.py | 1 + torch/_inductor/fx_passes/mkldnn_fusion.py | 1 + torch/_inductor/fx_passes/numeric_utils.py | 1 + torch/_inductor/fx_passes/pad_mm.py | 1 + torch/_inductor/fx_passes/post_grad.py | 1 + torch/_inductor/fx_passes/pre_grad.py | 1 + torch/_inductor/fx_passes/quantization.py | 1 + torch/_inductor/fx_passes/reinplace.py | 1 + torch/_inductor/fx_passes/replace_random.py | 1 + torch/_inductor/fx_passes/split_cat.py | 1 + torch/_inductor/fx_utils.py | 1 + torch/_inductor/graph.py | 1 + torch/_inductor/hooks.py | 1 + torch/_inductor/index_propagation.py | 1 + torch/_inductor/inductor_prims.py | 1 + torch/_inductor/ir.py | 1 + torch/_inductor/kernel/bmm.py | 1 + torch/_inductor/kernel/conv.py | 1 + torch/_inductor/kernel/flex_attention.py | 1 + torch/_inductor/kernel/mm.py | 1 + torch/_inductor/kernel/mm_common.py | 1 + torch/_inductor/kernel/mm_plus_mm.py | 1 + torch/_inductor/kernel/unpack_mixed_mm.py | 1 + torch/_inductor/lowering.py | 1 + torch/_inductor/metrics.py | 1 + torch/_inductor/mkldnn_lowerings.py | 1 + torch/_inductor/ops_handler.py | 1 + torch/_inductor/optimize_indexing.py | 1 + torch/_inductor/quantized_lowerings.py | 1 + 127 files changed, 129 insertions(+), 2 deletions(-) diff --git a/torch/_export/serde/schema_check.py b/torch/_export/serde/schema_check.py index cde4cf1ada27..b22b9778819e 100644 --- a/torch/_export/serde/schema_check.py +++ b/torch/_export/serde/schema_check.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import dataclasses import hashlib import re diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index 8d6dc939fb5c..51dbc435deaf 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import base64 import copy import copyreg diff --git a/torch/_export/serde/union.py b/torch/_export/serde/union.py index 8dfce61f0ab2..b129e8dd9a89 100644 --- a/torch/_export/serde/union.py +++ b/torch/_export/serde/union.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools from dataclasses import fields from typing import Hashable, Set diff --git a/torch/_export/serde/upgrade.py b/torch/_export/serde/upgrade.py index d35fe7e1586c..c427a4030c9c 100644 --- a/torch/_export/serde/upgrade.py +++ b/torch/_export/serde/upgrade.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs class GraphModuleOpUpgrader: diff --git a/torch/_export/tools.py b/torch/_export/tools.py index d76392993bd2..23fae4a9196c 100644 --- a/torch/_export/tools.py +++ b/torch/_export/tools.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging import warnings from typing import Any, Dict, Iterable, Optional, Tuple diff --git a/torch/_export/utils.py b/torch/_export/utils.py index 772bd3e124b7..1cec59aaaa0c 100644 --- a/torch/_export/utils.py +++ b/torch/_export/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import ast import dataclasses import inspect diff --git a/torch/_export/verifier.py b/torch/_export/verifier.py index 3f89324642eb..07b5ca097400 100644 --- a/torch/_export/verifier.py +++ b/torch/_export/verifier.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import inspect import math import operator diff --git a/torch/_export/wrappers.py b/torch/_export/wrappers.py index 5ca2375ec124..c18ed34a395c 100644 --- a/torch/_export/wrappers.py +++ b/torch/_export/wrappers.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from contextlib import contextmanager import torch diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py index dd3ec09408aa..8144a47f057a 100644 --- a/torch/_functorch/_aot_autograd/autograd_cache.py +++ b/torch/_functorch/_aot_autograd/autograd_cache.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ Utils for caching the outputs of AOTAutograd """ diff --git a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py index 991e12a59d4b..44301291a91f 100644 --- a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py +++ b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ This module is one of the analysis modules - it takes as input a function or graph and some preexisting properties, and returns some data that is useful for deciding diff --git a/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py b/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py index c956d58b645e..c38a98366cb3 100644 --- a/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py +++ b/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ This module dispatches the graphs to either the forward-only or joint compilation pathways, taking into account the AOTConfig and the collected ViewAndMutationMetadata. diff --git a/torch/_functorch/_aot_autograd/functional_utils.py b/torch/_functorch/_aot_autograd/functional_utils.py index 2e0a7d322f6f..a8af6f0366cc 100644 --- a/torch/_functorch/_aot_autograd/functional_utils.py +++ b/torch/_functorch/_aot_autograd/functional_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ This file contains utilities related to functionalization in AOTAutograd: 1. converting to/from functional tensors diff --git a/torch/_functorch/_aot_autograd/input_output_analysis.py b/torch/_functorch/_aot_autograd/input_output_analysis.py index 9a02dffb3d1b..29a32ee03078 100644 --- a/torch/_functorch/_aot_autograd/input_output_analysis.py +++ b/torch/_functorch/_aot_autograd/input_output_analysis.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ This module is one of the analysis modules - it takes as input a function or graph and some preexisting properties, and returns some data that is useful for deciding diff --git a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py index b8093b3cdc98..5eb681889d8a 100644 --- a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ Functions in this module do most of the "work" of AOTAutograd. An aot_dispatch_* function: diff --git a/torch/_functorch/_aot_autograd/logging_utils.py b/torch/_functorch/_aot_autograd/logging_utils.py index 414166cbdd2f..c961f74dc6c1 100644 --- a/torch/_functorch/_aot_autograd/logging_utils.py +++ b/torch/_functorch/_aot_autograd/logging_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ Contains utils for logging in AOTAutograd, including managing the names of the graphs under compilation, capturing user-friendly tracebacks, and debug messages. diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py index fd188eb6a700..3293db8f8a93 100644 --- a/torch/_functorch/_aot_autograd/runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ This module defines runtime wrappers, which, based on previous analysis attempts to: 1. process the inputs and outputs diff --git a/torch/_functorch/_aot_autograd/schemas.py b/torch/_functorch/_aot_autograd/schemas.py index 3246f142ca43..338bff655b66 100644 --- a/torch/_functorch/_aot_autograd/schemas.py +++ b/torch/_functorch/_aot_autograd/schemas.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ The various dataclasses, Enums, namedtuples etc used in AOTAutograd. This includes input/output types, metadata, config, function signatures etc. diff --git a/torch/_functorch/_aot_autograd/subclass_utils.py b/torch/_functorch/_aot_autograd/subclass_utils.py index cee3cf6e4eda..98f08bb786c4 100644 --- a/torch/_functorch/_aot_autograd/subclass_utils.py +++ b/torch/_functorch/_aot_autograd/subclass_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ This file contains utilities for tracing through __torch_dispatch__ based tensor subclasses and modes. AOTAutograd's responsibility is to trace through all pytorch capabilities that live in the pytorch dispatcher, diff --git a/torch/_functorch/_aot_autograd/traced_function_transforms.py b/torch/_functorch/_aot_autograd/traced_function_transforms.py index 27d3f2c9ad99..fa33d9fd79c4 100644 --- a/torch/_functorch/_aot_autograd/traced_function_transforms.py +++ b/torch/_functorch/_aot_autograd/traced_function_transforms.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ This module is responsible for transforming functions to be traced into a form that is easier for the downstream infra (e.g. Autograd, FX, AOTAutograd analysis) diff --git a/torch/_functorch/_aot_autograd/utils.py b/torch/_functorch/_aot_autograd/utils.py index a479dd2712a4..3d577d2b37b5 100644 --- a/torch/_functorch/_aot_autograd/utils.py +++ b/torch/_functorch/_aot_autograd/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ Contains various utils for AOTAutograd, including those for handling collections. """ diff --git a/torch/_functorch/apis.py b/torch/_functorch/apis.py index 477a01583b3d..8d4a77457867 100644 --- a/torch/_functorch/apis.py +++ b/torch/_functorch/apis.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # NOTE: We allow Dynamo to see this file (via torch/_dynamo/trace_rules.py) so that it can # trace through functorch transforms. # Currently, we can't allow Dynamo to see `eager_transforms.py`/`vmap.py` as that break a lot of thing diff --git a/torch/_functorch/autograd_function.py b/torch/_functorch/autograd_function.py index 03bfd710ae34..b827fb20424c 100644 --- a/torch/_functorch/autograd_function.py +++ b/torch/_functorch/autograd_function.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any, NamedTuple, Tuple import torch diff --git a/torch/_functorch/batch_norm_replacement.py b/torch/_functorch/batch_norm_replacement.py index a2df284138e7..672a8ce76955 100644 --- a/torch/_functorch/batch_norm_replacement.py +++ b/torch/_functorch/batch_norm_replacement.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch.nn as nn from torch._functorch.utils import exposed_in diff --git a/torch/_functorch/deprecated.py b/torch/_functorch/deprecated.py index 058e206599c5..ebb930e8ecb7 100644 --- a/torch/_functorch/deprecated.py +++ b/torch/_functorch/deprecated.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ The APIs in this file are exposed as `functorch.*`. They are thin wrappers around the torch.func.* APIs that have deprecation warnings -- we're trying diff --git a/torch/_functorch/functional_call.py b/torch/_functorch/functional_call.py index 7533811ed235..5552036e8ddf 100644 --- a/torch/_functorch/functional_call.py +++ b/torch/_functorch/functional_call.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from collections import Counter from typing import Any, Dict, List, Optional, Sequence, Tuple, Union diff --git a/torch/_functorch/make_functional.py b/torch/_functorch/make_functional.py index 711be174d827..8932f750551c 100644 --- a/torch/_functorch/make_functional.py +++ b/torch/_functorch/make_functional.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. # diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index fc1c995e5907..8e954f910ba4 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import functools import heapq diff --git a/torch/_functorch/pyfunctorch.py b/torch/_functorch/pyfunctorch.py index 5a78facf08c0..fb2aae84c0b9 100644 --- a/torch/_functorch/pyfunctorch.py +++ b/torch/_functorch/pyfunctorch.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib from abc import ABC, abstractmethod from typing import Any, List, Tuple diff --git a/torch/_functorch/utils.py b/torch/_functorch/utils.py index 303ebbc45d63..5e88b8462c5f 100644 --- a/torch/_functorch/utils.py +++ b/torch/_functorch/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib from typing import Tuple, Union diff --git a/torch/_guards.py b/torch/_guards.py index 4dccd4aa84e6..92041700f0b0 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import contextlib diff --git a/torch/_higher_order_ops/associative_scan.py b/torch/_higher_order_ops/associative_scan.py index e0e22eb4202f..540d5c1a77f9 100644 --- a/torch/_higher_order_ops/associative_scan.py +++ b/torch/_higher_order_ops/associative_scan.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import itertools from typing import Callable, List diff --git a/torch/_higher_order_ops/auto_functionalize.py b/torch/_higher_order_ops/auto_functionalize.py index 89263bd65e7a..189f746b77a0 100644 --- a/torch/_higher_order_ops/auto_functionalize.py +++ b/torch/_higher_order_ops/auto_functionalize.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any, Dict, List, Optional, Tuple, Union import torch diff --git a/torch/_higher_order_ops/cond.py b/torch/_higher_order_ops/cond.py index 359feb192ae5..f4fe64d67f0b 100644 --- a/torch/_higher_order_ops/cond.py +++ b/torch/_higher_order_ops/cond.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import torch diff --git a/torch/_higher_order_ops/effects.py b/torch/_higher_order_ops/effects.py index f76596a3c6f3..a8da01fe06ec 100644 --- a/torch/_higher_order_ops/effects.py +++ b/torch/_higher_order_ops/effects.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from enum import Enum from typing import Any, Dict, Optional, Tuple diff --git a/torch/_higher_order_ops/flex_attention.py b/torch/_higher_order_ops/flex_attention.py index f4586a0a57b0..c2efa3b48b7f 100644 --- a/torch/_higher_order_ops/flex_attention.py +++ b/torch/_higher_order_ops/flex_attention.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any, Callable, Tuple, Union import torch diff --git a/torch/_higher_order_ops/map.py b/torch/_higher_order_ops/map.py index 2bf88ea19565..f5bf1d43c19f 100644 --- a/torch/_higher_order_ops/map.py +++ b/torch/_higher_order_ops/map.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.utils._pytree as pytree from torch._C import DispatchKey diff --git a/torch/_higher_order_ops/out_dtype.py b/torch/_higher_order_ops/out_dtype.py index f675519ee182..a3f5e2115aee 100644 --- a/torch/_higher_order_ops/out_dtype.py +++ b/torch/_higher_order_ops/out_dtype.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.utils._pytree as pytree diff --git a/torch/_higher_order_ops/strict_mode.py b/torch/_higher_order_ops/strict_mode.py index 81c20bc3462b..d781248a19c9 100644 --- a/torch/_higher_order_ops/strict_mode.py +++ b/torch/_higher_order_ops/strict_mode.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch._subclasses.functional_tensor diff --git a/torch/_higher_order_ops/torchbind.py b/torch/_higher_order_ops/torchbind.py index 235dfe6ec416..744e559e65d0 100644 --- a/torch/_higher_order_ops/torchbind.py +++ b/torch/_higher_order_ops/torchbind.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging from contextlib import contextmanager diff --git a/torch/_higher_order_ops/triton_kernel_wrap.py b/torch/_higher_order_ops/triton_kernel_wrap.py index a99afaaa9547..5552ef1ff8b2 100644 --- a/torch/_higher_order_ops/triton_kernel_wrap.py +++ b/torch/_higher_order_ops/triton_kernel_wrap.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import dataclasses import inspect import logging diff --git a/torch/_higher_order_ops/utils.py b/torch/_higher_order_ops/utils.py index 0fcf22bcc338..84c029084ae3 100644 --- a/torch/_higher_order_ops/utils.py +++ b/torch/_higher_order_ops/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools from contextlib import contextmanager from dataclasses import dataclass diff --git a/torch/_higher_order_ops/while_loop.py b/torch/_higher_order_ops/while_loop.py index b0ab00bdfac4..4577036b731f 100644 --- a/torch/_higher_order_ops/while_loop.py +++ b/torch/_higher_order_ops/while_loop.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Callable, Tuple, Union import torch diff --git a/torch/_higher_order_ops/wrap.py b/torch/_higher_order_ops/wrap.py index f288c350f0ee..6d83a44e752a 100644 --- a/torch/_higher_order_ops/wrap.py +++ b/torch/_higher_order_ops/wrap.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import inspect import logging diff --git a/torch/_inductor/__init__.py b/torch/_inductor/__init__.py index 0d7cd8cece49..9d9445c5de3f 100644 --- a/torch/_inductor/__init__.py +++ b/torch/_inductor/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any, Dict, List, Optional, Tuple import torch.fx diff --git a/torch/_inductor/async_compile.py b/torch/_inductor/async_compile.py index b8e3d338dd9b..496a7a5ad841 100644 --- a/torch/_inductor/async_compile.py +++ b/torch/_inductor/async_compile.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import functools diff --git a/torch/_inductor/autotune_process.py b/torch/_inductor/autotune_process.py index e2503e1d8ca2..71171b3a4c32 100644 --- a/torch/_inductor/autotune_process.py +++ b/torch/_inductor/autotune_process.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import contextlib diff --git a/torch/_inductor/bounds.py b/torch/_inductor/bounds.py index 4640ec4dce6b..a1412adb505d 100644 --- a/torch/_inductor/bounds.py +++ b/torch/_inductor/bounds.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import operator from functools import partial from typing import Any, Callable, Dict diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index ca6cacaa213b..71815a31718e 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import base64 diff --git a/torch/_inductor/codegen/aoti_hipify_utils.py b/torch/_inductor/codegen/aoti_hipify_utils.py index a86ef2d29761..9edfe839946d 100644 --- a/torch/_inductor/codegen/aoti_hipify_utils.py +++ b/torch/_inductor/codegen/aoti_hipify_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.utils.hipify.hipify_python import PYTORCH_MAP, RE_PYTORCH_PREPROCESSOR diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index f7b3e7a45d6e..29d6db791672 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import dataclasses import functools diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index eabb5bbef470..35e604a35e48 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import dataclasses import functools diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py index e0a4c0993549..ce45ada78eba 100644 --- a/torch/_inductor/codegen/cpp_gemm_template.py +++ b/torch/_inductor/codegen/cpp_gemm_template.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import cast, List, Optional import torch diff --git a/torch/_inductor/codegen/cpp_micro_gemm.py b/torch/_inductor/codegen/cpp_micro_gemm.py index 649782ff158d..65b270285f47 100644 --- a/torch/_inductor/codegen/cpp_micro_gemm.py +++ b/torch/_inductor/codegen/cpp_micro_gemm.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from collections import namedtuple from typing import Dict, List, Optional, Type diff --git a/torch/_inductor/codegen/cpp_template.py b/torch/_inductor/codegen/cpp_template.py index aeebd2698aa5..e46465178840 100644 --- a/torch/_inductor/codegen/cpp_template.py +++ b/torch/_inductor/codegen/cpp_template.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import itertools import logging diff --git a/torch/_inductor/codegen/cpp_template_kernel.py b/torch/_inductor/codegen/cpp_template_kernel.py index 5a6c6969b20c..34065e412f84 100644 --- a/torch/_inductor/codegen/cpp_template_kernel.py +++ b/torch/_inductor/codegen/cpp_template_kernel.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import itertools from typing import Any, Callable, Dict, List, Optional, Tuple, Union diff --git a/torch/_inductor/codegen/cpp_utils.py b/torch/_inductor/codegen/cpp_utils.py index 4ab33a5e26dc..438ac908486a 100644 --- a/torch/_inductor/codegen/cpp_utils.py +++ b/torch/_inductor/codegen/cpp_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math from collections import namedtuple diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 3dc397a84b8d..65ff4ebf4e69 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import math import os diff --git a/torch/_inductor/codegen/cpp_wrapper_cuda.py b/torch/_inductor/codegen/cpp_wrapper_cuda.py index 2519f80b6626..ad8c8eafbbd1 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cuda.py +++ b/torch/_inductor/codegen/cpp_wrapper_cuda.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import os from itertools import chain, count diff --git a/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py b/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py index 5c91736e9abd..0b91219d8f03 100644 --- a/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py +++ b/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging from typing import cast, Sequence diff --git a/torch/_inductor/codegen/cuda/cuda_kernel.py b/torch/_inductor/codegen/cuda/cuda_kernel.py index 8cad41082d64..12b7b21de61e 100644 --- a/torch/_inductor/codegen/cuda/cuda_kernel.py +++ b/torch/_inductor/codegen/cuda/cuda_kernel.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, Union diff --git a/torch/_inductor/codegen/cuda/cuda_template.py b/torch/_inductor/codegen/cuda/cuda_template.py index 871c8b388494..24a02efe3805 100644 --- a/torch/_inductor/codegen/cuda/cuda_template.py +++ b/torch/_inductor/codegen/cuda/cuda_template.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import itertools import logging diff --git a/torch/_inductor/codegen/cuda/cutlass_epilogue_gen.py b/torch/_inductor/codegen/cuda/cutlass_epilogue_gen.py index d8bf408dc28a..11258382ad21 100644 --- a/torch/_inductor/codegen/cuda/cutlass_epilogue_gen.py +++ b/torch/_inductor/codegen/cuda/cutlass_epilogue_gen.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Dict, List from unittest.mock import patch diff --git a/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py b/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py index 2a386a114e86..4ee8af3949ae 100644 --- a/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py +++ b/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from ..cutlass_utils import try_import_cutlass if try_import_cutlass(): diff --git a/torch/_inductor/codegen/cuda/cutlass_utils.py b/torch/_inductor/codegen/cuda/cutlass_utils.py index 789a2e44152c..04866fe4deb1 100644 --- a/torch/_inductor/codegen/cuda/cutlass_utils.py +++ b/torch/_inductor/codegen/cuda/cutlass_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import logging import os diff --git a/torch/_inductor/codegen/cuda/device_op_overrides.py b/torch/_inductor/codegen/cuda/device_op_overrides.py index 93a8c08b6a0f..7ff99b871c82 100644 --- a/torch/_inductor/codegen/cuda/device_op_overrides.py +++ b/torch/_inductor/codegen/cuda/device_op_overrides.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from ..common import DeviceOpOverrides, register_device_op_overrides diff --git a/torch/_inductor/codegen/cuda/gemm_template.py b/torch/_inductor/codegen/cuda/gemm_template.py index 89c326cef546..3a7dccf7442b 100644 --- a/torch/_inductor/codegen/cuda/gemm_template.py +++ b/torch/_inductor/codegen/cuda/gemm_template.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import enum import logging diff --git a/torch/_inductor/codegen/cuda_combined_scheduling.py b/torch/_inductor/codegen/cuda_combined_scheduling.py index f7be73c247fd..0b5b9d795202 100644 --- a/torch/_inductor/codegen/cuda_combined_scheduling.py +++ b/torch/_inductor/codegen/cuda_combined_scheduling.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Sequence, Union from ..scheduler import ( diff --git a/torch/_inductor/codegen/memory_planning.py b/torch/_inductor/codegen/memory_planning.py index 3489a61f2d86..435bd2d895ce 100644 --- a/torch/_inductor/codegen/memory_planning.py +++ b/torch/_inductor/codegen/memory_planning.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import collections diff --git a/torch/_inductor/codegen/multi_kernel.py b/torch/_inductor/codegen/multi_kernel.py index 8b4dbb179016..84279191ceac 100644 --- a/torch/_inductor/codegen/multi_kernel.py +++ b/torch/_inductor/codegen/multi_kernel.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging import os from typing import Any, List diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index c5fc2747bee7..2063a183385b 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import collections diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 215d4d866980..f366533e3b94 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import dataclasses diff --git a/torch/_inductor/codegen/triton_foreach.py b/torch/_inductor/codegen/triton_foreach.py index 8ed909ec823a..4a909a6025d5 100644 --- a/torch/_inductor/codegen/triton_foreach.py +++ b/torch/_inductor/codegen/triton_foreach.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import itertools from collections import defaultdict from dataclasses import dataclass diff --git a/torch/_inductor/codegen/triton_split_scan.py b/torch/_inductor/codegen/triton_split_scan.py index 6df3f39a9724..1e0475ffd0f9 100644 --- a/torch/_inductor/codegen/triton_split_scan.py +++ b/torch/_inductor/codegen/triton_split_scan.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools from typing import Optional, Set diff --git a/torch/_inductor/codegen/triton_utils.py b/torch/_inductor/codegen/triton_utils.py index ea6f25ae2c0a..2e4107f85916 100644 --- a/torch/_inductor/codegen/triton_utils.py +++ b/torch/_inductor/codegen/triton_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any, Dict, List, Optional import sympy diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 1daaa534ee4e..41b9fdc180bc 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections import contextlib import dataclasses diff --git a/torch/_inductor/codegen/xpu/device_op_overrides.py b/torch/_inductor/codegen/xpu/device_op_overrides.py index 1f1258898290..6eec71344ae8 100644 --- a/torch/_inductor/codegen/xpu/device_op_overrides.py +++ b/torch/_inductor/codegen/xpu/device_op_overrides.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from ..common import DeviceOpOverrides, register_device_op_overrides diff --git a/torch/_inductor/comms.py b/torch/_inductor/comms.py index a1fe0e1cdceb..9f95f7354437 100644 --- a/torch/_inductor/comms.py +++ b/torch/_inductor/comms.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # pyre-strict from typing import List diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index 618df9fd8ff6..d0069c5cc219 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import functools import itertools diff --git a/torch/_inductor/compile_worker/__main__.py b/torch/_inductor/compile_worker/__main__.py index fc8148f20c5f..7f0965415bbf 100644 --- a/torch/_inductor/compile_worker/__main__.py +++ b/torch/_inductor/compile_worker/__main__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import argparse import os import sys diff --git a/torch/_inductor/compile_worker/subproc_pool.py b/torch/_inductor/compile_worker/subproc_pool.py index fbed608b851b..03bfe6c3f203 100644 --- a/torch/_inductor/compile_worker/subproc_pool.py +++ b/torch/_inductor/compile_worker/subproc_pool.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import itertools import logging diff --git a/torch/_inductor/compile_worker/watchdog.py b/torch/_inductor/compile_worker/watchdog.py index c91c9efb492c..f3956e1272e9 100644 --- a/torch/_inductor/compile_worker/watchdog.py +++ b/torch/_inductor/compile_worker/watchdog.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import os import signal from threading import Thread diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 148cf1684875..2ea60000d265 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import os # noqa: C101 import sys from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, Union @@ -805,8 +806,8 @@ class cuda: # Path to CUDA NVCC. # NVCC search order: # 1) cuda_cxx set in this config - # 2)CUDACXX environment variable - # 3)CUDA_HOME environment variable + # 2) CUDACXX environment variable + # 3) CUDA_HOME environment variable # 4) default system search PATH. cuda_cxx: Optional[str] = None diff --git a/torch/_inductor/constant_folding.py b/torch/_inductor/constant_folding.py index 5f5cc12be872..523aac95d354 100644 --- a/torch/_inductor/constant_folding.py +++ b/torch/_inductor/constant_folding.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections from typing import Any, Callable, Dict, Optional diff --git a/torch/_inductor/cudagraph_trees.py b/torch/_inductor/cudagraph_trees.py index d49404ddafde..2b6a9dab45da 100644 --- a/torch/_inductor/cudagraph_trees.py +++ b/torch/_inductor/cudagraph_trees.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ CUDA graph trees are a safety abstraction over CUDAGraphs, similar to make_graph_callables, which share the same memory pool. Sharing a memory pool is an extremely diff --git a/torch/_inductor/cudagraph_utils.py b/torch/_inductor/cudagraph_utils.py index 8556a0f751ed..188c91ba65f0 100644 --- a/torch/_inductor/cudagraph_utils.py +++ b/torch/_inductor/cudagraph_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import dataclasses from typing import Any, Callable, Dict, List, Optional, Tuple diff --git a/torch/_inductor/debug.py b/torch/_inductor/debug.py index dcc2b3ab3e4c..b0ad369c4316 100644 --- a/torch/_inductor/debug.py +++ b/torch/_inductor/debug.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections import contextlib import dataclasses diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index 960c3a42e1f1..c9c3eb579e6c 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import logging import math diff --git a/torch/_inductor/dependencies.py b/torch/_inductor/dependencies.py index d7cd3ce64f4f..d5abfaa49696 100644 --- a/torch/_inductor/dependencies.py +++ b/torch/_inductor/dependencies.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import abc import collections import dataclasses diff --git a/torch/_inductor/exc.py b/torch/_inductor/exc.py index 9e6aa6effae2..27dcc6d8ef2d 100644 --- a/torch/_inductor/exc.py +++ b/torch/_inductor/exc.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import os diff --git a/torch/_inductor/freezing.py b/torch/_inductor/freezing.py index 7d7cbed25193..9a5f12820a2b 100644 --- a/torch/_inductor/freezing.py +++ b/torch/_inductor/freezing.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import itertools diff --git a/torch/_inductor/fx_passes/binary_folding.py b/torch/_inductor/fx_passes/binary_folding.py index 5cfabf9b7707..7453cde1ce9d 100644 --- a/torch/_inductor/fx_passes/binary_folding.py +++ b/torch/_inductor/fx_passes/binary_folding.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import itertools diff --git a/torch/_inductor/fx_passes/decompose_mem_bound_mm.py b/torch/_inductor/fx_passes/decompose_mem_bound_mm.py index 793d29383f56..66f0afed9e7d 100644 --- a/torch/_inductor/fx_passes/decompose_mem_bound_mm.py +++ b/torch/_inductor/fx_passes/decompose_mem_bound_mm.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging from typing import List diff --git a/torch/_inductor/fx_passes/dedupe_symint_uses.py b/torch/_inductor/fx_passes/dedupe_symint_uses.py index 7145508a3ae2..646e8d16f4d2 100644 --- a/torch/_inductor/fx_passes/dedupe_symint_uses.py +++ b/torch/_inductor/fx_passes/dedupe_symint_uses.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from dataclasses import dataclass from typing import Union diff --git a/torch/_inductor/fx_passes/efficient_conv_bn_eval.py b/torch/_inductor/fx_passes/efficient_conv_bn_eval.py index 7ab01e0abbb2..7aecc3f15f33 100644 --- a/torch/_inductor/fx_passes/efficient_conv_bn_eval.py +++ b/torch/_inductor/fx_passes/efficient_conv_bn_eval.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.nn as nn diff --git a/torch/_inductor/fx_passes/freezing_patterns.py b/torch/_inductor/fx_passes/freezing_patterns.py index fe39b13033a7..039fea2dcca2 100644 --- a/torch/_inductor/fx_passes/freezing_patterns.py +++ b/torch/_inductor/fx_passes/freezing_patterns.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import torch diff --git a/torch/_inductor/fx_passes/fuse_attention.py b/torch/_inductor/fx_passes/fuse_attention.py index a6c1d11bd78a..2a646bc4c4ce 100644 --- a/torch/_inductor/fx_passes/fuse_attention.py +++ b/torch/_inductor/fx_passes/fuse_attention.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import inspect import logging diff --git a/torch/_inductor/fx_passes/group_batch_fusion.py b/torch/_inductor/fx_passes/group_batch_fusion.py index 289fe0dbead8..7c095841140d 100644 --- a/torch/_inductor/fx_passes/group_batch_fusion.py +++ b/torch/_inductor/fx_passes/group_batch_fusion.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections import logging import operator diff --git a/torch/_inductor/fx_passes/joint_graph.py b/torch/_inductor/fx_passes/joint_graph.py index 477bfe670cec..8358fdd0bd31 100644 --- a/torch/_inductor/fx_passes/joint_graph.py +++ b/torch/_inductor/fx_passes/joint_graph.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import itertools import logging import typing diff --git a/torch/_inductor/fx_passes/misc_patterns.py b/torch/_inductor/fx_passes/misc_patterns.py index 76c641e3e8eb..f2d943cab241 100644 --- a/torch/_inductor/fx_passes/misc_patterns.py +++ b/torch/_inductor/fx_passes/misc_patterns.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools from typing import Dict, Set, Tuple diff --git a/torch/_inductor/fx_passes/mkldnn_fusion.py b/torch/_inductor/fx_passes/mkldnn_fusion.py index be73a09ca648..97d45ae4f5f2 100644 --- a/torch/_inductor/fx_passes/mkldnn_fusion.py +++ b/torch/_inductor/fx_passes/mkldnn_fusion.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import operator from functools import reduce diff --git a/torch/_inductor/fx_passes/numeric_utils.py b/torch/_inductor/fx_passes/numeric_utils.py index 44d0564fe3ea..5bad4ed9489c 100644 --- a/torch/_inductor/fx_passes/numeric_utils.py +++ b/torch/_inductor/fx_passes/numeric_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import gc import logging import os diff --git a/torch/_inductor/fx_passes/pad_mm.py b/torch/_inductor/fx_passes/pad_mm.py index b2a64df57d36..f7b7977bffc1 100644 --- a/torch/_inductor/fx_passes/pad_mm.py +++ b/torch/_inductor/fx_passes/pad_mm.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import itertools import operator diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index 3677f27e1d20..3f36c2e7918f 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import itertools import logging diff --git a/torch/_inductor/fx_passes/pre_grad.py b/torch/_inductor/fx_passes/pre_grad.py index 1cfa104ea995..717a46811802 100644 --- a/torch/_inductor/fx_passes/pre_grad.py +++ b/torch/_inductor/fx_passes/pre_grad.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import itertools import logging diff --git a/torch/_inductor/fx_passes/quantization.py b/torch/_inductor/fx_passes/quantization.py index 4476a9ccd512..5d2a087face4 100644 --- a/torch/_inductor/fx_passes/quantization.py +++ b/torch/_inductor/fx_passes/quantization.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import functools import itertools diff --git a/torch/_inductor/fx_passes/reinplace.py b/torch/_inductor/fx_passes/reinplace.py index 27730ea17905..bae75aae249d 100644 --- a/torch/_inductor/fx_passes/reinplace.py +++ b/torch/_inductor/fx_passes/reinplace.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import itertools import operator from collections import defaultdict diff --git a/torch/_inductor/fx_passes/replace_random.py b/torch/_inductor/fx_passes/replace_random.py index 59d4c3891226..4265bf7f26bd 100644 --- a/torch/_inductor/fx_passes/replace_random.py +++ b/torch/_inductor/fx_passes/replace_random.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections import logging diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py index 34757b3b5b1e..b5014a8780f1 100644 --- a/torch/_inductor/fx_passes/split_cat.py +++ b/torch/_inductor/fx_passes/split_cat.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import itertools import logging import operator diff --git a/torch/_inductor/fx_utils.py b/torch/_inductor/fx_utils.py index 5ccff50c1d45..8f3ed2e9177c 100644 --- a/torch/_inductor/fx_utils.py +++ b/torch/_inductor/fx_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import operator from collections import defaultdict from typing import Any, Callable, DefaultDict, Dict, Optional, Tuple, Type diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index b7b032236907..4c5ea746f3f5 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import itertools import logging import operator diff --git a/torch/_inductor/hooks.py b/torch/_inductor/hooks.py index 2b558f4350a7..bf4a8bb090aa 100644 --- a/torch/_inductor/hooks.py +++ b/torch/_inductor/hooks.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib from typing import Callable, List, TYPE_CHECKING diff --git a/torch/_inductor/index_propagation.py b/torch/_inductor/index_propagation.py index 77b73ffd6842..2ec43bce36f0 100644 --- a/torch/_inductor/index_propagation.py +++ b/torch/_inductor/index_propagation.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """This file implements the IndexPropagation ops handler, which wraps an underlying handler to add a limited form of constant propagation, as well as propagation of sympy expressions downstream of ops.index_expr calls. diff --git a/torch/_inductor/inductor_prims.py b/torch/_inductor/inductor_prims.py index 0a00650b1c38..c50686d9ee61 100644 --- a/torch/_inductor/inductor_prims.py +++ b/torch/_inductor/inductor_prims.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import logging diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index c46cad5e41e2..8bf3bb22f93f 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections import contextlib import dataclasses diff --git a/torch/_inductor/kernel/bmm.py b/torch/_inductor/kernel/bmm.py index a8650cd32c3f..7d1fbc0b35e8 100644 --- a/torch/_inductor/kernel/bmm.py +++ b/torch/_inductor/kernel/bmm.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging import torch diff --git a/torch/_inductor/kernel/conv.py b/torch/_inductor/kernel/conv.py index 205919b48723..f3b4cd8ac430 100644 --- a/torch/_inductor/kernel/conv.py +++ b/torch/_inductor/kernel/conv.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import functools diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index 42fabf65591d..b2c5bb271501 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ Triton Implementation of the flex_attention Kernel""" import logging diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py index 2f30aa941837..de811fd41c0f 100644 --- a/torch/_inductor/kernel/mm.py +++ b/torch/_inductor/kernel/mm.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import logging from typing import Any, Dict, List, Optional diff --git a/torch/_inductor/kernel/mm_common.py b/torch/_inductor/kernel/mm_common.py index 1ca0558d19c0..9ffaba040e7f 100644 --- a/torch/_inductor/kernel/mm_common.py +++ b/torch/_inductor/kernel/mm_common.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import itertools import logging diff --git a/torch/_inductor/kernel/mm_plus_mm.py b/torch/_inductor/kernel/mm_plus_mm.py index 931aa592556b..f2f810d1fe02 100644 --- a/torch/_inductor/kernel/mm_plus_mm.py +++ b/torch/_inductor/kernel/mm_plus_mm.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import torch diff --git a/torch/_inductor/kernel/unpack_mixed_mm.py b/torch/_inductor/kernel/unpack_mixed_mm.py index c0053b15c16a..c483dbff2b85 100644 --- a/torch/_inductor/kernel/unpack_mixed_mm.py +++ b/torch/_inductor/kernel/unpack_mixed_mm.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging from typing import List, TYPE_CHECKING diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 0a1909890e69..0519211c01aa 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import itertools import logging diff --git a/torch/_inductor/metrics.py b/torch/_inductor/metrics.py index 76f15243c5ba..3d8de535542e 100644 --- a/torch/_inductor/metrics.py +++ b/torch/_inductor/metrics.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import csv diff --git a/torch/_inductor/mkldnn_lowerings.py b/torch/_inductor/mkldnn_lowerings.py index 1f64574d589b..f1d82dcf7d60 100644 --- a/torch/_inductor/mkldnn_lowerings.py +++ b/torch/_inductor/mkldnn_lowerings.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List, Optional import torch diff --git a/torch/_inductor/ops_handler.py b/torch/_inductor/ops_handler.py index 5630061b4426..20d652019372 100644 --- a/torch/_inductor/ops_handler.py +++ b/torch/_inductor/ops_handler.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import itertools from typing import ( Any, diff --git a/torch/_inductor/optimize_indexing.py b/torch/_inductor/optimize_indexing.py index 0d5f2d0b2db7..63887b347364 100644 --- a/torch/_inductor/optimize_indexing.py +++ b/torch/_inductor/optimize_indexing.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math import sympy diff --git a/torch/_inductor/quantized_lowerings.py b/torch/_inductor/quantized_lowerings.py index 7b4edf0627dd..954a85abe52e 100644 --- a/torch/_inductor/quantized_lowerings.py +++ b/torch/_inductor/quantized_lowerings.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from . import lowering From afe15d2d2fe86b812d7c71777ab0f78d85ea9903 Mon Sep 17 00:00:00 2001 From: Aaron Orenstein Date: Sat, 8 Jun 2024 11:24:41 -0700 Subject: [PATCH 538/706] Flip default value for mypy disallow_untyped_defs [3/11] (#127840) See #127836 for details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127840 Approved by: https://github.com/oulgen --- torch/_inductor/remote_cache.py | 1 + torch/_inductor/runtime/compile_tasks.py | 1 + torch/_inductor/runtime/coordinate_descent_tuner.py | 1 + torch/_inductor/runtime/hints.py | 1 + torch/_inductor/runtime/runtime_utils.py | 1 + torch/_inductor/runtime/triton_helpers.py | 1 + torch/_inductor/runtime/triton_heuristics.py | 1 + torch/_inductor/select_algorithm.py | 1 + torch/_inductor/sizevars.py | 1 + torch/_inductor/subgraph_lowering.py | 1 + torch/_inductor/test_case.py | 1 + torch/_inductor/test_operators.py | 1 + torch/_inductor/utils.py | 1 + torch/_inductor/virtualized.py | 1 + torch/_inductor/wrapper_benchmark.py | 1 + torch/_jit_internal.py | 1 + torch/_lazy/__init__.py | 1 + torch/_lazy/closure.py | 1 + torch/_lazy/computation.py | 1 + torch/_lazy/config.py | 1 + torch/_lazy/debug.py | 1 + torch/_lazy/device_context.py | 1 + torch/_lazy/extract_compiled_graph.py | 1 + torch/_lazy/ir_cache.py | 1 + torch/_lazy/metrics.py | 1 + torch/_lazy/ts_backend.py | 1 + torch/_library/abstract_impl.py | 1 + torch/_library/autograd.py | 1 + torch/_library/custom_ops.py | 1 + torch/_library/fake_class_registry.py | 1 + torch/_library/infer_schema.py | 1 + torch/_library/simple_registry.py | 1 + torch/_library/utils.py | 1 + torch/_linalg_utils.py | 1 + torch/_lobpcg.py | 1 + torch/_logging/_internal.py | 1 + torch/_meta_registrations.py | 1 + torch/_namedtensor_internals.py | 1 + torch/_ops.py | 1 + torch/_prims/__init__.py | 1 + torch/_prims/context.py | 1 + torch/_prims/debug_prims.py | 1 + torch/_prims/executor.py | 1 + torch/_prims/rng_prims.py | 1 + torch/_prims_common/__init__.py | 1 + torch/_prims_common/wrappers.py | 1 + torch/_python_dispatcher.py | 1 + torch/_refs/__init__.py | 1 + torch/_refs/_conversions.py | 1 + torch/_refs/linalg/__init__.py | 1 + torch/_refs/nn/functional/__init__.py | 1 + torch/_refs/special/__init__.py | 1 + torch/_size_docs.py | 1 + torch/_sources.py | 1 + torch/_storage_docs.py | 1 + torch/_streambase.py | 1 + torch/_strobelight/examples/cli_function_profiler_example.py | 1 + torch/_strobelight/examples/compile_time_profile_example.py | 1 + torch/_subclasses/fake_tensor.py | 1 + torch/_subclasses/functional_tensor.py | 1 + torch/_subclasses/meta_utils.py | 1 + torch/_tensor.py | 1 + torch/_tensor_docs.py | 1 + torch/_tensor_str.py | 1 + torch/_torch_docs.py | 1 + torch/_utils.py | 1 + torch/_utils_internal.py | 1 + torch/_vmap_internals.py | 1 + torch/_weights_only_unpickler.py | 1 + torch/amp/autocast_mode.py | 1 + torch/amp/grad_scaler.py | 1 + torch/ao/__init__.py | 1 + torch/ao/nn/__init__.py | 1 + torch/ao/nn/intrinsic/__init__.py | 1 + torch/ao/nn/intrinsic/modules/fused.py | 1 + torch/ao/nn/intrinsic/qat/modules/conv_fused.py | 1 + torch/ao/nn/intrinsic/qat/modules/linear_fused.py | 1 + torch/ao/nn/intrinsic/qat/modules/linear_relu.py | 1 + torch/ao/nn/intrinsic/quantized/dynamic/modules/linear_relu.py | 1 + torch/ao/nn/intrinsic/quantized/modules/bn_relu.py | 1 + torch/ao/nn/intrinsic/quantized/modules/conv_add.py | 1 + torch/ao/nn/intrinsic/quantized/modules/conv_relu.py | 1 + torch/ao/nn/intrinsic/quantized/modules/linear_relu.py | 1 + torch/ao/nn/qat/dynamic/modules/linear.py | 1 + torch/ao/nn/qat/modules/conv.py | 1 + torch/ao/nn/qat/modules/embedding_ops.py | 1 + torch/ao/nn/qat/modules/linear.py | 1 + torch/ao/nn/quantizable/modules/activation.py | 1 + torch/ao/nn/quantizable/modules/rnn.py | 1 + torch/ao/nn/quantized/dynamic/modules/conv.py | 1 + torch/ao/nn/quantized/dynamic/modules/linear.py | 1 + torch/ao/nn/quantized/dynamic/modules/rnn.py | 1 + torch/ao/nn/quantized/functional.py | 1 + torch/ao/nn/quantized/modules/__init__.py | 1 + torch/ao/nn/quantized/modules/activation.py | 1 + torch/ao/nn/quantized/modules/batchnorm.py | 1 + torch/ao/nn/quantized/modules/conv.py | 1 + torch/ao/nn/quantized/modules/dropout.py | 1 + torch/ao/nn/quantized/modules/embedding_ops.py | 1 + torch/ao/nn/quantized/modules/functional_modules.py | 1 + torch/ao/nn/quantized/modules/linear.py | 1 + torch/ao/nn/quantized/modules/normalization.py | 1 + torch/ao/nn/quantized/modules/rnn.py | 1 + torch/ao/nn/quantized/modules/utils.py | 1 + torch/ao/nn/quantized/reference/modules/conv.py | 1 + torch/ao/nn/quantized/reference/modules/linear.py | 1 + torch/ao/nn/quantized/reference/modules/rnn.py | 1 + torch/ao/nn/quantized/reference/modules/sparse.py | 1 + torch/ao/nn/quantized/reference/modules/utils.py | 1 + torch/ao/nn/sparse/quantized/dynamic/linear.py | 1 + torch/ao/nn/sparse/quantized/linear.py | 1 + torch/ao/nn/sparse/quantized/utils.py | 1 + torch/ao/ns/_numeric_suite.py | 1 + torch/ao/ns/_numeric_suite_fx.py | 1 + torch/ao/ns/fx/graph_matcher.py | 1 + torch/ao/ns/fx/graph_passes.py | 1 + torch/ao/ns/fx/n_shadows_utils.py | 1 + torch/ao/ns/fx/qconfig_multi_mapping.py | 1 + torch/ao/ns/fx/utils.py | 1 + .../_experimental/activation_sparsifier/activation_sparsifier.py | 1 + .../pruning/_experimental/data_scheduler/base_data_scheduler.py | 1 + .../_experimental/data_sparsifier/base_data_sparsifier.py | 1 + .../_experimental/data_sparsifier/benchmarks/dlrm_utils.py | 1 + .../data_sparsifier/benchmarks/evaluate_disk_savings.py | 1 + .../data_sparsifier/benchmarks/evaluate_forward_time.py | 1 + .../data_sparsifier/benchmarks/evaluate_model_metrics.py | 1 + .../_experimental/data_sparsifier/data_norm_sparsifier.py | 1 + .../data_sparsifier/lightning/callbacks/_data_sparstity_utils.py | 1 + 128 files changed, 128 insertions(+) diff --git a/torch/_inductor/remote_cache.py b/torch/_inductor/remote_cache.py index 5bf3f50154e8..91b69b3bf6f5 100644 --- a/torch/_inductor/remote_cache.py +++ b/torch/_inductor/remote_cache.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import os from abc import abstractmethod diff --git a/torch/_inductor/runtime/compile_tasks.py b/torch/_inductor/runtime/compile_tasks.py index 878125c9fabc..da30bd46b112 100644 --- a/torch/_inductor/runtime/compile_tasks.py +++ b/torch/_inductor/runtime/compile_tasks.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import functools diff --git a/torch/_inductor/runtime/coordinate_descent_tuner.py b/torch/_inductor/runtime/coordinate_descent_tuner.py index b5d10478a03c..31ff94774613 100644 --- a/torch/_inductor/runtime/coordinate_descent_tuner.py +++ b/torch/_inductor/runtime/coordinate_descent_tuner.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import itertools import logging diff --git a/torch/_inductor/runtime/hints.py b/torch/_inductor/runtime/hints.py index 46acd83c7377..ba36f40a2263 100644 --- a/torch/_inductor/runtime/hints.py +++ b/torch/_inductor/runtime/hints.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections import typing from dataclasses import fields diff --git a/torch/_inductor/runtime/runtime_utils.py b/torch/_inductor/runtime/runtime_utils.py index bc3a3d008f3c..51a6c22644b8 100644 --- a/torch/_inductor/runtime/runtime_utils.py +++ b/torch/_inductor/runtime/runtime_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import functools diff --git a/torch/_inductor/runtime/triton_helpers.py b/torch/_inductor/runtime/triton_helpers.py index 95708ada9020..845bec583f6d 100644 --- a/torch/_inductor/runtime/triton_helpers.py +++ b/torch/_inductor/runtime/triton_helpers.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs try: import triton import triton.language as tl diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 6629e0fe5e77..5e05368e0a11 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import builtins import copy import functools diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 5e5cbf35baf9..4f9d12b13c64 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import builtins import contextlib import functools diff --git a/torch/_inductor/sizevars.py b/torch/_inductor/sizevars.py index abb95503e8e2..1a863bc5485d 100644 --- a/torch/_inductor/sizevars.py +++ b/torch/_inductor/sizevars.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import itertools import logging diff --git a/torch/_inductor/subgraph_lowering.py b/torch/_inductor/subgraph_lowering.py index 9413ac1b2659..4f7eec8ff50c 100644 --- a/torch/_inductor/subgraph_lowering.py +++ b/torch/_inductor/subgraph_lowering.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Utilities for lowering subgraphs used by higher order operators """ diff --git a/torch/_inductor/test_case.py b/torch/_inductor/test_case.py index 3933c9dbc004..3acc68ff22a5 100644 --- a/torch/_inductor/test_case.py +++ b/torch/_inductor/test_case.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import os diff --git a/torch/_inductor/test_operators.py b/torch/_inductor/test_operators.py index 8e85f8bebbdb..3c105ba7db2d 100644 --- a/torch/_inductor/test_operators.py +++ b/torch/_inductor/test_operators.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch.library from torch import Tensor from torch.autograd import Function diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index d19ef0cd3004..9a83b3d10d40 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import collections diff --git a/torch/_inductor/virtualized.py b/torch/_inductor/virtualized.py index 07c6ea8190a6..ac8d3c640141 100644 --- a/torch/_inductor/virtualized.py +++ b/torch/_inductor/virtualized.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ This file provides a number of "global" variables/handlers that are actually thread local and dynamically scoped, with Inductor patching them to various diff --git a/torch/_inductor/wrapper_benchmark.py b/torch/_inductor/wrapper_benchmark.py index 3e952765695f..976d0c7458e7 100644 --- a/torch/_inductor/wrapper_benchmark.py +++ b/torch/_inductor/wrapper_benchmark.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import dataclasses import tempfile from collections import defaultdict diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py index 57458a0801ab..4ed425f0435a 100644 --- a/torch/_jit_internal.py +++ b/torch/_jit_internal.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ The weak_script annotation needs to be here instead of inside torch/jit/ so it can be used in other places in torch/ (namely torch.nn) without running into diff --git a/torch/_lazy/__init__.py b/torch/_lazy/__init__.py index 249ce9b11578..c074abd14372 100644 --- a/torch/_lazy/__init__.py +++ b/torch/_lazy/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import threading import torch._C._lazy diff --git a/torch/_lazy/closure.py b/torch/_lazy/closure.py index 07f1055ee827..32b2c58ba2b8 100644 --- a/torch/_lazy/closure.py +++ b/torch/_lazy/closure.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import os import threading from queue import Empty as EmptyQueue, Queue diff --git a/torch/_lazy/computation.py b/torch/_lazy/computation.py index 27b73c42e5c0..17a61e36cb9f 100644 --- a/torch/_lazy/computation.py +++ b/torch/_lazy/computation.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch._C._lazy import torch._C._lazy_ts_backend diff --git a/torch/_lazy/config.py b/torch/_lazy/config.py index e7a4d1dd24f8..f7ebca12de7f 100644 --- a/torch/_lazy/config.py +++ b/torch/_lazy/config.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch._C._lazy diff --git a/torch/_lazy/debug.py b/torch/_lazy/debug.py index 286aa049280c..84534fb23250 100644 --- a/torch/_lazy/debug.py +++ b/torch/_lazy/debug.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch._C._lazy diff --git a/torch/_lazy/device_context.py b/torch/_lazy/device_context.py index 840c7f8e50d0..bc47835fd912 100644 --- a/torch/_lazy/device_context.py +++ b/torch/_lazy/device_context.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import threading from typing import Any, Dict diff --git a/torch/_lazy/extract_compiled_graph.py b/torch/_lazy/extract_compiled_graph.py index 033d000c69d8..7c1cb95855b9 100644 --- a/torch/_lazy/extract_compiled_graph.py +++ b/torch/_lazy/extract_compiled_graph.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import dataclasses import itertools diff --git a/torch/_lazy/ir_cache.py b/torch/_lazy/ir_cache.py index 4270684d2943..a6e654566f29 100644 --- a/torch/_lazy/ir_cache.py +++ b/torch/_lazy/ir_cache.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch._C._lazy diff --git a/torch/_lazy/metrics.py b/torch/_lazy/metrics.py index 2d7db7305567..a77981feb90d 100644 --- a/torch/_lazy/metrics.py +++ b/torch/_lazy/metrics.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch._C._lazy diff --git a/torch/_lazy/ts_backend.py b/torch/_lazy/ts_backend.py index 184223771932..5c6ce13746e9 100644 --- a/torch/_lazy/ts_backend.py +++ b/torch/_lazy/ts_backend.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch._C._lazy_ts_backend diff --git a/torch/_library/abstract_impl.py b/torch/_library/abstract_impl.py index 2946b743ee53..1f0f4c87bab7 100644 --- a/torch/_library/abstract_impl.py +++ b/torch/_library/abstract_impl.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import functools from typing import Callable, Optional diff --git a/torch/_library/autograd.py b/torch/_library/autograd.py index ebd35361a940..1ff5696417f3 100644 --- a/torch/_library/autograd.py +++ b/torch/_library/autograd.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import dataclasses from typing import Any, Callable, Optional, Protocol diff --git a/torch/_library/custom_ops.py b/torch/_library/custom_ops.py index 20758d24e37a..ce692f16a097 100644 --- a/torch/_library/custom_ops.py +++ b/torch/_library/custom_ops.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import inspect import weakref from typing import ( diff --git a/torch/_library/fake_class_registry.py b/torch/_library/fake_class_registry.py index aaa57d79e283..f206b68fc3be 100644 --- a/torch/_library/fake_class_registry.py +++ b/torch/_library/fake_class_registry.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging from typing import Any, Dict, Optional, Protocol, Tuple diff --git a/torch/_library/infer_schema.py b/torch/_library/infer_schema.py index fd03f9182434..6305375e4433 100644 --- a/torch/_library/infer_schema.py +++ b/torch/_library/infer_schema.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import inspect import typing diff --git a/torch/_library/simple_registry.py b/torch/_library/simple_registry.py index 64a543e99b0b..65ecf8ef0d75 100644 --- a/torch/_library/simple_registry.py +++ b/torch/_library/simple_registry.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from .abstract_impl import AbstractImplHolder __all__ = ["SimpleLibraryRegistry", "SimpleOperatorEntry", "singleton"] diff --git a/torch/_library/utils.py b/torch/_library/utils.py index d3577dbbf9d1..27d1ef92b5b3 100644 --- a/torch/_library/utils.py +++ b/torch/_library/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import dataclasses import inspect import sys diff --git a/torch/_linalg_utils.py b/torch/_linalg_utils.py index 198decab4826..fd5f574ad7eb 100644 --- a/torch/_linalg_utils.py +++ b/torch/_linalg_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Various linear algebra utility methods for internal use. """ diff --git a/torch/_lobpcg.py b/torch/_lobpcg.py index 864b5dc6245f..3f7bdf456c39 100644 --- a/torch/_lobpcg.py +++ b/torch/_lobpcg.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Locally Optimal Block Preconditioned Conjugate Gradient methods. """ # Author: Pearu Peterson diff --git a/torch/_logging/_internal.py b/torch/_logging/_internal.py index 798eeabc5d6b..bfc071b0d53a 100644 --- a/torch/_logging/_internal.py +++ b/torch/_logging/_internal.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import hashlib import itertools diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 759870b4427d..3afe3a98d102 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math from enum import Enum from typing import List, Optional, Sequence, Tuple, Union diff --git a/torch/_namedtensor_internals.py b/torch/_namedtensor_internals.py index cbc9de2de091..3791d17c2e42 100644 --- a/torch/_namedtensor_internals.py +++ b/torch/_namedtensor_internals.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from collections import OrderedDict """ diff --git a/torch/_ops.py b/torch/_ops.py index 83a7b6b849df..ed8c788b8af6 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import ctypes import importlib diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py index 498b6fa9a2cb..603658ea6151 100644 --- a/torch/_prims/__init__.py +++ b/torch/_prims/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import itertools import operator diff --git a/torch/_prims/context.py b/torch/_prims/context.py index 2c7a030b3509..81cc47dc86e5 100644 --- a/torch/_prims/context.py +++ b/torch/_prims/context.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools from contextlib import nullcontext from typing import Any, Callable, Dict, Optional, Sequence diff --git a/torch/_prims/debug_prims.py b/torch/_prims/debug_prims.py index ea3854d04bbd..9683c163827d 100644 --- a/torch/_prims/debug_prims.py +++ b/torch/_prims/debug_prims.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib from typing import Optional diff --git a/torch/_prims/executor.py b/torch/_prims/executor.py index bb2fafce8726..8d80af720e79 100644 --- a/torch/_prims/executor.py +++ b/torch/_prims/executor.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Callable, Optional from torch._prims.context import TorchRefsMode diff --git a/torch/_prims/rng_prims.py b/torch/_prims/rng_prims.py index 616940d57036..1345ff0334f5 100644 --- a/torch/_prims/rng_prims.py +++ b/torch/_prims/rng_prims.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Optional, Tuple import torch diff --git a/torch/_prims_common/__init__.py b/torch/_prims_common/__init__.py index 10290535f930..11b97403f308 100644 --- a/torch/_prims_common/__init__.py +++ b/torch/_prims_common/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import operator diff --git a/torch/_prims_common/wrappers.py b/torch/_prims_common/wrappers.py index 9057edc87594..89088aaaf049 100644 --- a/torch/_prims_common/wrappers.py +++ b/torch/_prims_common/wrappers.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import inspect import warnings from functools import wraps diff --git a/torch/_python_dispatcher.py b/torch/_python_dispatcher.py index bfd208eddb9e..644cf92fda2b 100644 --- a/torch/_python_dispatcher.py +++ b/torch/_python_dispatcher.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import re import torch._C as C diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index ca941f41f07f..db1f2a99d3d4 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import builtins import collections import inspect diff --git a/torch/_refs/_conversions.py b/torch/_refs/_conversions.py index fa1ca2428255..b312f8f6eada 100644 --- a/torch/_refs/_conversions.py +++ b/torch/_refs/_conversions.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch._prims_common as utils diff --git a/torch/_refs/linalg/__init__.py b/torch/_refs/linalg/__init__.py index bffc9a3df2c8..411087b773ea 100644 --- a/torch/_refs/linalg/__init__.py +++ b/torch/_refs/linalg/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from functools import partial from typing import List, Optional, Tuple, Union diff --git a/torch/_refs/nn/functional/__init__.py b/torch/_refs/nn/functional/__init__.py index dd06febbcd6c..8383d888bbe8 100644 --- a/torch/_refs/nn/functional/__init__.py +++ b/torch/_refs/nn/functional/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math from functools import wraps from typing import Callable, Optional, Union diff --git a/torch/_refs/special/__init__.py b/torch/_refs/special/__init__.py index 14ec33cf208f..1e98deaeb16d 100644 --- a/torch/_refs/special/__init__.py +++ b/torch/_refs/special/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math from typing import Optional, Union diff --git a/torch/_size_docs.py b/torch/_size_docs.py index 58587be32f1d..b678e3dfd12a 100644 --- a/torch/_size_docs.py +++ b/torch/_size_docs.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Adds docstrings to torch.Size functions""" import torch._C diff --git a/torch/_sources.py b/torch/_sources.py index 3f56bd8ef247..dd2a863bfc7e 100644 --- a/torch/_sources.py +++ b/torch/_sources.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import ast import functools import inspect diff --git a/torch/_storage_docs.py b/torch/_storage_docs.py index 5d6df58d2b6b..edf5d696ad89 100644 --- a/torch/_storage_docs.py +++ b/torch/_storage_docs.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Adds docstrings to Storage functions""" import torch._C diff --git a/torch/_streambase.py b/torch/_streambase.py index b06946523fa3..85e203a3d993 100644 --- a/torch/_streambase.py +++ b/torch/_streambase.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from abc import ABC, abstractmethod diff --git a/torch/_strobelight/examples/cli_function_profiler_example.py b/torch/_strobelight/examples/cli_function_profiler_example.py index 8142ef1bdc77..2ddf62f065f5 100644 --- a/torch/_strobelight/examples/cli_function_profiler_example.py +++ b/torch/_strobelight/examples/cli_function_profiler_example.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._strobelight.cli_function_profiler import ( diff --git a/torch/_strobelight/examples/compile_time_profile_example.py b/torch/_strobelight/examples/compile_time_profile_example.py index 338727206076..93fffa4ad01a 100644 --- a/torch/_strobelight/examples/compile_time_profile_example.py +++ b/torch/_strobelight/examples/compile_time_profile_example.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._strobelight.compile_time_profiler import StrobelightCompileTimeProfiler diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 47d4abcf77b9..f9075c603f11 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import functools import logging diff --git a/torch/_subclasses/functional_tensor.py b/torch/_subclasses/functional_tensor.py index dfef5951ab26..4040774fe225 100644 --- a/torch/_subclasses/functional_tensor.py +++ b/torch/_subclasses/functional_tensor.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import warnings from abc import ABC, abstractmethod diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py index 5aeccce2e1ee..4ea0db56aae2 100644 --- a/torch/_subclasses/meta_utils.py +++ b/torch/_subclasses/meta_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import contextlib diff --git a/torch/_tensor.py b/torch/_tensor.py index 712cbc3863d8..5ea2985c2d3f 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copyreg import enum import functools diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index 88cae5b27aa3..07d94b57f791 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Adds docstrings to Tensor functions""" import torch._C diff --git a/torch/_tensor_str.py b/torch/_tensor_str.py index eddbe4d8b729..461f3a26b58a 100644 --- a/torch/_tensor_str.py +++ b/torch/_tensor_str.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import dataclasses import math diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index ab244dab2635..ad44998d92dc 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Adds docstrings to functions defined in the torch._C module.""" import re diff --git a/torch/_utils.py b/torch/_utils.py index e6ddb96bfa40..b0dcb448092a 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copyreg import functools import logging diff --git a/torch/_utils_internal.py b/torch/_utils_internal.py index 91b7a3722f55..0001888f18ed 100644 --- a/torch/_utils_internal.py +++ b/torch/_utils_internal.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import logging import os diff --git a/torch/_vmap_internals.py b/torch/_vmap_internals.py index 465e5dbdca1b..cc23d7851eb5 100644 --- a/torch/_vmap_internals.py +++ b/torch/_vmap_internals.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools from typing import Any, Callable, List, Optional, Tuple, Union from typing_extensions import deprecated diff --git a/torch/_weights_only_unpickler.py b/torch/_weights_only_unpickler.py index 9cc74c05e45f..2ca07d15136c 100644 --- a/torch/_weights_only_unpickler.py +++ b/torch/_weights_only_unpickler.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Unpickler restricted to loading only state dicts # Restrict constructing types to a list defined in _get_allowed_globals() # Restrict BUILD operation to `Tensor`, `Parameter` and `OrderedDict` types only diff --git a/torch/amp/autocast_mode.py b/torch/amp/autocast_mode.py index e33533d2c833..ad8892a3099d 100644 --- a/torch/amp/autocast_mode.py +++ b/torch/amp/autocast_mode.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections import functools import warnings diff --git a/torch/amp/grad_scaler.py b/torch/amp/grad_scaler.py index a72c6246c99e..bb5cf8204c08 100644 --- a/torch/amp/grad_scaler.py +++ b/torch/amp/grad_scaler.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import inspect diff --git a/torch/ao/__init__.py b/torch/ao/__init__.py index fe6f3a460316..32b1048ad35d 100644 --- a/torch/ao/__init__.py +++ b/torch/ao/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # torch.ao is a package with a lot of interdependencies. # We will use lazy import to avoid cyclic dependencies here. diff --git a/torch/ao/nn/__init__.py b/torch/ao/nn/__init__.py index 88a5a03af1cc..4041508e0b9b 100644 --- a/torch/ao/nn/__init__.py +++ b/torch/ao/nn/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # We are exposing all subpackages to the end-user. # Because of possible inter-dependency, we want to avoid # the cyclic imports, thus implementing lazy version diff --git a/torch/ao/nn/intrinsic/__init__.py b/torch/ao/nn/intrinsic/__init__.py index a18bae3eaa38..ca446141106f 100644 --- a/torch/ao/nn/intrinsic/__init__.py +++ b/torch/ao/nn/intrinsic/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from .modules import * # noqa: F403 from .modules.fused import _FusedModule # noqa: F403 diff --git a/torch/ao/nn/intrinsic/modules/fused.py b/torch/ao/nn/intrinsic/modules/fused.py index 4fff70cd76b2..a02365318104 100644 --- a/torch/ao/nn/intrinsic/modules/fused.py +++ b/torch/ao/nn/intrinsic/modules/fused.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.nn import Conv1d, Conv2d, Conv3d, ReLU, Linear, BatchNorm1d, BatchNorm2d, BatchNorm3d from torch.nn.utils.parametrize import type_before_parametrizations diff --git a/torch/ao/nn/intrinsic/qat/modules/conv_fused.py b/torch/ao/nn/intrinsic/qat/modules/conv_fused.py index 3aa068e382d7..91a25a11d50b 100644 --- a/torch/ao/nn/intrinsic/qat/modules/conv_fused.py +++ b/torch/ao/nn/intrinsic/qat/modules/conv_fused.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math import torch import torch.nn as nn diff --git a/torch/ao/nn/intrinsic/qat/modules/linear_fused.py b/torch/ao/nn/intrinsic/qat/modules/linear_fused.py index fb7ac4545bb3..89b3a55ff7d2 100644 --- a/torch/ao/nn/intrinsic/qat/modules/linear_fused.py +++ b/torch/ao/nn/intrinsic/qat/modules/linear_fused.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.nn as nn import torch.ao.nn.intrinsic as nni diff --git a/torch/ao/nn/intrinsic/qat/modules/linear_relu.py b/torch/ao/nn/intrinsic/qat/modules/linear_relu.py index 7319c882b0aa..49cea103982f 100644 --- a/torch/ao/nn/intrinsic/qat/modules/linear_relu.py +++ b/torch/ao/nn/intrinsic/qat/modules/linear_relu.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.ao.nn.qat as nnqat import torch.ao.nn.intrinsic as nni diff --git a/torch/ao/nn/intrinsic/quantized/dynamic/modules/linear_relu.py b/torch/ao/nn/intrinsic/quantized/dynamic/modules/linear_relu.py index 9d0467c4cd57..b8bff1f5e3a9 100644 --- a/torch/ao/nn/intrinsic/quantized/dynamic/modules/linear_relu.py +++ b/torch/ao/nn/intrinsic/quantized/dynamic/modules/linear_relu.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.ao.nn.quantized.dynamic as nnqd import torch.ao.nn.intrinsic as nni diff --git a/torch/ao/nn/intrinsic/quantized/modules/bn_relu.py b/torch/ao/nn/intrinsic/quantized/modules/bn_relu.py index 32c1d0eeb351..eb5104d8c409 100644 --- a/torch/ao/nn/intrinsic/quantized/modules/bn_relu.py +++ b/torch/ao/nn/intrinsic/quantized/modules/bn_relu.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.ao.nn.intrinsic diff --git a/torch/ao/nn/intrinsic/quantized/modules/conv_add.py b/torch/ao/nn/intrinsic/quantized/modules/conv_add.py index a369d2b7cec7..e7df10597331 100644 --- a/torch/ao/nn/intrinsic/quantized/modules/conv_add.py +++ b/torch/ao/nn/intrinsic/quantized/modules/conv_add.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.ao.nn.intrinsic import torch.ao.nn.intrinsic.qat diff --git a/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py b/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py index 10011e52b3ef..1ff34f9f5f20 100644 --- a/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py +++ b/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.ao.nn.intrinsic diff --git a/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py b/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py index ed64cba253b2..38cb543f4001 100644 --- a/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py +++ b/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.ao.nn.quantized as nnq import torch.ao.nn.intrinsic as nni diff --git a/torch/ao/nn/qat/dynamic/modules/linear.py b/torch/ao/nn/qat/dynamic/modules/linear.py index c93dfab1f15b..dd3c06953597 100644 --- a/torch/ao/nn/qat/dynamic/modules/linear.py +++ b/torch/ao/nn/qat/dynamic/modules/linear.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch __all__ = ["Linear"] diff --git a/torch/ao/nn/qat/modules/conv.py b/torch/ao/nn/qat/modules/conv.py index 0f56708fb84a..896bb2d243bd 100644 --- a/torch/ao/nn/qat/modules/conv.py +++ b/torch/ao/nn/qat/modules/conv.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.nn as nn from torch.nn.modules.utils import _single, _pair, _triple diff --git a/torch/ao/nn/qat/modules/embedding_ops.py b/torch/ao/nn/qat/modules/embedding_ops.py index 499d872ba049..4269db4abed5 100644 --- a/torch/ao/nn/qat/modules/embedding_ops.py +++ b/torch/ao/nn/qat/modules/embedding_ops.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch import Tensor import torch.nn as nn diff --git a/torch/ao/nn/qat/modules/linear.py b/torch/ao/nn/qat/modules/linear.py index a7083401cb21..67573a427bae 100644 --- a/torch/ao/nn/qat/modules/linear.py +++ b/torch/ao/nn/qat/modules/linear.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.nn as nn import torch.nn.functional as F diff --git a/torch/ao/nn/quantizable/modules/activation.py b/torch/ao/nn/quantizable/modules/activation.py index 2c1aad574158..8a45499fd80f 100644 --- a/torch/ao/nn/quantizable/modules/activation.py +++ b/torch/ao/nn/quantizable/modules/activation.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.jit # this is needed to avoid a circular import from torch import nn diff --git a/torch/ao/nn/quantizable/modules/rnn.py b/torch/ao/nn/quantizable/modules/rnn.py index 7c4eebafefbb..a311587bd984 100644 --- a/torch/ao/nn/quantizable/modules/rnn.py +++ b/torch/ao/nn/quantizable/modules/rnn.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import numbers from typing import Optional, Tuple import warnings diff --git a/torch/ao/nn/quantized/dynamic/modules/conv.py b/torch/ao/nn/quantized/dynamic/modules/conv.py index 54d2b7e83fed..d47c898efa6a 100644 --- a/torch/ao/nn/quantized/dynamic/modules/conv.py +++ b/torch/ao/nn/quantized/dynamic/modules/conv.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r"""Dynamically quantized convolution modules.""" import torch diff --git a/torch/ao/nn/quantized/dynamic/modules/linear.py b/torch/ao/nn/quantized/dynamic/modules/linear.py index 85b89b75fe58..0b8bf245af43 100644 --- a/torch/ao/nn/quantized/dynamic/modules/linear.py +++ b/torch/ao/nn/quantized/dynamic/modules/linear.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.ao.nn.quantized as nnq from torch.ao.nn.quantized.modules.utils import _quantize_weight diff --git a/torch/ao/nn/quantized/dynamic/modules/rnn.py b/torch/ao/nn/quantized/dynamic/modules/rnn.py index c81771a71889..9afab93d1a55 100644 --- a/torch/ao/nn/quantized/dynamic/modules/rnn.py +++ b/torch/ao/nn/quantized/dynamic/modules/rnn.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import numbers import warnings from typing_extensions import deprecated diff --git a/torch/ao/nn/quantized/functional.py b/torch/ao/nn/quantized/functional.py index 72218184fcfa..ccb450bdd834 100644 --- a/torch/ao/nn/quantized/functional.py +++ b/torch/ao/nn/quantized/functional.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r""" Functional interface (quantized).""" from typing import List, Optional import warnings diff --git a/torch/ao/nn/quantized/modules/__init__.py b/torch/ao/nn/quantized/modules/__init__.py index f539db753a47..2b87be71fd73 100644 --- a/torch/ao/nn/quantized/modules/__init__.py +++ b/torch/ao/nn/quantized/modules/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch # The quantized modules use `torch.nn` and `torch.ao.nn.quantizable` diff --git a/torch/ao/nn/quantized/modules/activation.py b/torch/ao/nn/quantized/modules/activation.py index 094ac63fb0af..3288c84555c4 100644 --- a/torch/ao/nn/quantized/modules/activation.py +++ b/torch/ao/nn/quantized/modules/activation.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from warnings import warn __all__ = [ diff --git a/torch/ao/nn/quantized/modules/batchnorm.py b/torch/ao/nn/quantized/modules/batchnorm.py index 3644a314e9e8..975697936d1e 100644 --- a/torch/ao/nn/quantized/modules/batchnorm.py +++ b/torch/ao/nn/quantized/modules/batchnorm.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.ao.nn.intrinsic as nni diff --git a/torch/ao/nn/quantized/modules/conv.py b/torch/ao/nn/quantized/modules/conv.py index 5e41aa5bfdaf..ee0bceb336b7 100644 --- a/torch/ao/nn/quantized/modules/conv.py +++ b/torch/ao/nn/quantized/modules/conv.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r"""Quantized convolution modules.""" from typing import Optional, List, TypeVar diff --git a/torch/ao/nn/quantized/modules/dropout.py b/torch/ao/nn/quantized/modules/dropout.py index 759113bdbf25..ac934111c7f6 100644 --- a/torch/ao/nn/quantized/modules/dropout.py +++ b/torch/ao/nn/quantized/modules/dropout.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch __all__ = ['Dropout'] diff --git a/torch/ao/nn/quantized/modules/embedding_ops.py b/torch/ao/nn/quantized/modules/embedding_ops.py index dc6f66a0d4eb..43b8d65063a4 100644 --- a/torch/ao/nn/quantized/modules/embedding_ops.py +++ b/torch/ao/nn/quantized/modules/embedding_ops.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.nn as nn from torch import Tensor # noqa: F401 diff --git a/torch/ao/nn/quantized/modules/functional_modules.py b/torch/ao/nn/quantized/modules/functional_modules.py index 4cb135dee0ec..77b366c1f6d0 100644 --- a/torch/ao/nn/quantized/modules/functional_modules.py +++ b/torch/ao/nn/quantized/modules/functional_modules.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List import torch diff --git a/torch/ao/nn/quantized/modules/linear.py b/torch/ao/nn/quantized/modules/linear.py index cbc01b092f3a..52b0a80a1c90 100644 --- a/torch/ao/nn/quantized/modules/linear.py +++ b/torch/ao/nn/quantized/modules/linear.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from collections.abc import Iterable import torch diff --git a/torch/ao/nn/quantized/modules/normalization.py b/torch/ao/nn/quantized/modules/normalization.py index e7c5c85a4527..46a18c4e2853 100644 --- a/torch/ao/nn/quantized/modules/normalization.py +++ b/torch/ao/nn/quantized/modules/normalization.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch __all__ = ['LayerNorm', 'GroupNorm', 'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d'] diff --git a/torch/ao/nn/quantized/modules/rnn.py b/torch/ao/nn/quantized/modules/rnn.py index deb14856a9ef..b75ad0e6b34d 100644 --- a/torch/ao/nn/quantized/modules/rnn.py +++ b/torch/ao/nn/quantized/modules/rnn.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch __all__ = [ diff --git a/torch/ao/nn/quantized/modules/utils.py b/torch/ao/nn/quantized/modules/utils.py index 7c24c0ca31dc..83f478b57ff3 100644 --- a/torch/ao/nn/quantized/modules/utils.py +++ b/torch/ao/nn/quantized/modules/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import abc import torch import itertools diff --git a/torch/ao/nn/quantized/reference/modules/conv.py b/torch/ao/nn/quantized/reference/modules/conv.py index 910223056fba..a7c285bc7f67 100644 --- a/torch/ao/nn/quantized/reference/modules/conv.py +++ b/torch/ao/nn/quantized/reference/modules/conv.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.nn as nn import torch.nn.functional as F diff --git a/torch/ao/nn/quantized/reference/modules/linear.py b/torch/ao/nn/quantized/reference/modules/linear.py index 378fe0eb6eee..9dcba1f4bacd 100644 --- a/torch/ao/nn/quantized/reference/modules/linear.py +++ b/torch/ao/nn/quantized/reference/modules/linear.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.nn as nn import torch.nn.functional as F diff --git a/torch/ao/nn/quantized/reference/modules/rnn.py b/torch/ao/nn/quantized/reference/modules/rnn.py index 978c1d69f30a..f5a53d0ceb3e 100644 --- a/torch/ao/nn/quantized/reference/modules/rnn.py +++ b/torch/ao/nn/quantized/reference/modules/rnn.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.nn as nn from torch import Tensor diff --git a/torch/ao/nn/quantized/reference/modules/sparse.py b/torch/ao/nn/quantized/reference/modules/sparse.py index 973eb05bd3b3..8db3f14b08ce 100644 --- a/torch/ao/nn/quantized/reference/modules/sparse.py +++ b/torch/ao/nn/quantized/reference/modules/sparse.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch.nn as nn import torch.nn.functional as F from torch import Tensor diff --git a/torch/ao/nn/quantized/reference/modules/utils.py b/torch/ao/nn/quantized/reference/modules/utils.py index c4f4d0b46efd..87acd1901f0c 100644 --- a/torch/ao/nn/quantized/reference/modules/utils.py +++ b/torch/ao/nn/quantized/reference/modules/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import typing diff --git a/torch/ao/nn/sparse/quantized/dynamic/linear.py b/torch/ao/nn/sparse/quantized/dynamic/linear.py index bc5cb99fced2..7a28142e4b0d 100644 --- a/torch/ao/nn/sparse/quantized/dynamic/linear.py +++ b/torch/ao/nn/sparse/quantized/dynamic/linear.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Optional import torch diff --git a/torch/ao/nn/sparse/quantized/linear.py b/torch/ao/nn/sparse/quantized/linear.py index 9d1c8f332172..26388e2e2c7b 100644 --- a/torch/ao/nn/sparse/quantized/linear.py +++ b/torch/ao/nn/sparse/quantized/linear.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Optional import torch diff --git a/torch/ao/nn/sparse/quantized/utils.py b/torch/ao/nn/sparse/quantized/utils.py index 3d934f578574..46b1cb1e5b71 100644 --- a/torch/ao/nn/sparse/quantized/utils.py +++ b/torch/ao/nn/sparse/quantized/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import threading __all__ = [ diff --git a/torch/ao/ns/_numeric_suite.py b/torch/ao/ns/_numeric_suite.py index 3f0df31dfd2a..d6df04bbb5e6 100644 --- a/torch/ao/ns/_numeric_suite.py +++ b/torch/ao/ns/_numeric_suite.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.nn as nn import torch.ao.nn.quantized as nnq diff --git a/torch/ao/ns/_numeric_suite_fx.py b/torch/ao/ns/_numeric_suite_fx.py index ec5fdaede073..bd827ea16368 100644 --- a/torch/ao/ns/_numeric_suite_fx.py +++ b/torch/ao/ns/_numeric_suite_fx.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ This module contains tooling to compare weights and activations across models. Example usage:: diff --git a/torch/ao/ns/fx/graph_matcher.py b/torch/ao/ns/fx/graph_matcher.py index 8db946ec707a..8b542a3a0b81 100644 --- a/torch/ao/ns/fx/graph_matcher.py +++ b/torch/ao/ns/fx/graph_matcher.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections import enum diff --git a/torch/ao/ns/fx/graph_passes.py b/torch/ao/ns/fx/graph_passes.py index fbd03426790d..ba977eed9962 100644 --- a/torch/ao/ns/fx/graph_passes.py +++ b/torch/ao/ns/fx/graph_passes.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.fx import GraphModule, map_arg from torch.fx.graph import Graph, Node diff --git a/torch/ao/ns/fx/n_shadows_utils.py b/torch/ao/ns/fx/n_shadows_utils.py index 1fd6f069ac83..fc96a0da5a2b 100644 --- a/torch/ao/ns/fx/n_shadows_utils.py +++ b/torch/ao/ns/fx/n_shadows_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.fx from torch.fx import ( diff --git a/torch/ao/ns/fx/qconfig_multi_mapping.py b/torch/ao/ns/fx/qconfig_multi_mapping.py index 33efe21e3fe0..915fdb3e7830 100644 --- a/torch/ao/ns/fx/qconfig_multi_mapping.py +++ b/torch/ao/ns/fx/qconfig_multi_mapping.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import copy diff --git a/torch/ao/ns/fx/utils.py b/torch/ao/ns/fx/utils.py index bf35a7e531e1..16ac0c9c1504 100644 --- a/torch/ao/ns/fx/utils.py +++ b/torch/ao/ns/fx/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import enum import operator diff --git a/torch/ao/pruning/_experimental/activation_sparsifier/activation_sparsifier.py b/torch/ao/pruning/_experimental/activation_sparsifier/activation_sparsifier.py index 7c03a9f6e36a..0f4ace3de206 100644 --- a/torch/ao/pruning/_experimental/activation_sparsifier/activation_sparsifier.py +++ b/torch/ao/pruning/_experimental/activation_sparsifier/activation_sparsifier.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any, Dict, List, Optional import torch from collections import defaultdict diff --git a/torch/ao/pruning/_experimental/data_scheduler/base_data_scheduler.py b/torch/ao/pruning/_experimental/data_scheduler/base_data_scheduler.py index ad4df426c8e1..76514b19f93c 100644 --- a/torch/ao/pruning/_experimental/data_scheduler/base_data_scheduler.py +++ b/torch/ao/pruning/_experimental/data_scheduler/base_data_scheduler.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from functools import wraps import weakref import abc diff --git a/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py b/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py index 7f4fcb461e22..f56fa511f991 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import abc import torch from typing import Optional, Tuple, List, Any, Dict diff --git a/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/dlrm_utils.py b/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/dlrm_utils.py index 20919c140a4d..a90ed9bae523 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/dlrm_utils.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/dlrm_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from dlrm_s_pytorch import DLRM_Net # type: ignore[import] import numpy as np # type: ignore[import] diff --git a/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/evaluate_disk_savings.py b/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/evaluate_disk_savings.py index 3813f01c0975..1780b68540aa 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/evaluate_disk_savings.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/evaluate_disk_savings.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Dict, List import torch import time diff --git a/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/evaluate_forward_time.py b/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/evaluate_forward_time.py index 4f205312e181..69ddce634237 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/evaluate_forward_time.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/evaluate_forward_time.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Dict, List import torch from dlrm_s_pytorch import unpack_batch # type: ignore[import] diff --git a/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/evaluate_model_metrics.py b/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/evaluate_model_metrics.py index 31600118f662..79d5093d5098 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/evaluate_model_metrics.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/evaluate_model_metrics.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Dict, List import torch from dlrm_s_pytorch import unpack_batch # type: ignore[import] diff --git a/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py b/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py index 448c9377cc55..f1281729a74b 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.nn import functional as F from functools import reduce diff --git a/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/_data_sparstity_utils.py b/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/_data_sparstity_utils.py index 922c81322cfe..704391268985 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/_data_sparstity_utils.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/_data_sparstity_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging from torch.ao.pruning._experimental.data_sparsifier.base_data_sparsifier import SUPPORTED_TYPES From 62bcdc0ac9942cb3b0006a70efaab1e873f50538 Mon Sep 17 00:00:00 2001 From: Aaron Orenstein Date: Sat, 8 Jun 2024 11:33:07 -0700 Subject: [PATCH 539/706] Flip default value for mypy disallow_untyped_defs [4/11] (#127841) See #127836 for details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127841 Approved by: https://github.com/oulgen --- .../lightning/callbacks/data_sparsity.py | 1 + .../lightning/tests/test_callbacks.py | 1 + .../data_sparsifier/quantization_utils.py | 1 + .../_experimental/pruner/FPGM_pruner.py | 1 + .../pruner/base_structured_sparsifier.py | 1 + .../pruner/lstm_saliency_pruner.py | 1 + .../_experimental/pruner/parametrization.py | 1 + .../_experimental/pruner/prune_functions.py | 1 + .../_experimental/pruner/saliency_pruner.py | 1 + torch/ao/pruning/_mappings.py | 1 + torch/ao/pruning/scheduler/base_scheduler.py | 1 + torch/ao/pruning/scheduler/cubic_scheduler.py | 1 + .../ao/pruning/scheduler/lambda_scheduler.py | 1 + .../ao/pruning/sparsifier/base_sparsifier.py | 1 + .../sparsifier/nearly_diagonal_sparsifier.py | 1 + torch/ao/pruning/sparsifier/utils.py | 1 + .../sparsifier/weight_norm_sparsifier.py | 1 + torch/ao/quantization/__init__.py | 1 + torch/ao/quantization/_correct_bias.py | 1 + torch/ao/quantization/_equalize.py | 1 + .../quantization/_learnable_fake_quantize.py | 1 + .../_common_operator_config_utils.py | 1 + .../backend_config/_qnnpack_pt2e.py | 1 + .../backend_config/backend_config.py | 1 + .../ao/quantization/backend_config/native.py | 1 + .../ao/quantization/backend_config/onednn.py | 1 + .../quantization/backend_config/tensorrt.py | 1 + torch/ao/quantization/backend_config/utils.py | 1 + .../quantization/experimental/APoT_tensor.py | 1 + .../experimental/adaround_fake_quantize.py | 1 + .../experimental/adaround_optimization.py | 1 + .../quantization/experimental/apot_utils.py | 1 + .../experimental/fake_quantize.py | 1 + .../experimental/fake_quantize_function.py | 1 + torch/ao/quantization/experimental/linear.py | 1 + .../ao/quantization/experimental/observer.py | 1 + .../ao/quantization/experimental/quantizer.py | 1 + torch/ao/quantization/fake_quantize.py | 1 + torch/ao/quantization/fuse_modules.py | 1 + .../ao/quantization/fuser_method_mappings.py | 1 + torch/ao/quantization/fx/_decomposed.py | 1 + torch/ao/quantization/fx/_equalize.py | 1 + .../fx/_lower_to_native_backend.py | 1 + .../quantization/fx/_model_report/detector.py | 1 + .../fx/_model_report/model_report.py | 1 + .../fx/_model_report/model_report_observer.py | 1 + .../_model_report/model_report_visualizer.py | 1 + torch/ao/quantization/fx/custom_config.py | 1 + torch/ao/quantization/fx/fuse.py | 1 + torch/ao/quantization/fx/fuse_handler.py | 1 + torch/ao/quantization/fx/graph_module.py | 1 + torch/ao/quantization/fx/match_utils.py | 1 + torch/ao/quantization/fx/pattern_utils.py | 1 + torch/ao/quantization/fx/prepare.py | 1 + .../quantization/fx/qconfig_mapping_utils.py | 1 + torch/ao/quantization/fx/quantize_handler.py | 1 + torch/ao/quantization/fx/utils.py | 1 + torch/ao/quantization/observer.py | 1 + .../ao/quantization/pt2e/duplicate_dq_pass.py | 1 + torch/ao/quantization/pt2e/export_utils.py | 1 + torch/ao/quantization/pt2e/graph_utils.py | 1 + .../quantization/pt2e/port_metadata_pass.py | 1 + torch/ao/quantization/pt2e/prepare.py | 1 + torch/ao/quantization/pt2e/qat_utils.py | 1 + .../pt2e/representation/rewrite.py | 1 + torch/ao/quantization/pt2e/utils.py | 1 + torch/ao/quantization/qconfig.py | 1 + torch/ao/quantization/qconfig_mapping.py | 1 + torch/ao/quantization/quantize.py | 1 + torch/ao/quantization/quantize_jit.py | 1 + .../quantizer/embedding_quantizer.py | 1 + torch/ao/quantization/quantizer/quantizer.py | 1 + torch/ao/quantization/quantizer/utils.py | 1 + .../quantizer/x86_inductor_quantizer.py | 1 + .../quantizer/xnnpack_quantizer.py | 1 + .../quantizer/xnnpack_quantizer_utils.py | 1 + torch/ao/quantization/stubs.py | 1 + torch/ao/quantization/utils.py | 1 + torch/autograd/__init__.py | 1 + torch/autograd/_functions/tensor.py | 1 + torch/autograd/_functions/utils.py | 1 + torch/autograd/anomaly_mode.py | 1 + torch/autograd/forward_ad.py | 1 + torch/autograd/function.py | 1 + torch/autograd/functional.py | 1 + torch/autograd/grad_mode.py | 1 + torch/autograd/gradcheck.py | 1 + torch/autograd/graph.py | 1 + torch/autograd/profiler.py | 1 + torch/autograd/profiler_legacy.py | 1 + torch/autograd/profiler_util.py | 1 + torch/autograd/variable.py | 1 + torch/backends/__init__.py | 1 + torch/backends/_coreml/preprocess.py | 1 + torch/backends/_nnapi/prepare.py | 1 + torch/backends/_nnapi/serializer.py | 1 + torch/backends/cuda/__init__.py | 1 + torch/backends/cudnn/__init__.py | 1 + torch/backends/cudnn/rnn.py | 1 + torch/backends/mkl/__init__.py | 1 + torch/backends/mkldnn/__init__.py | 1 + torch/backends/mps/__init__.py | 1 + torch/backends/nnpack/__init__.py | 1 + torch/backends/openmp/__init__.py | 1 + torch/backends/opt_einsum/__init__.py | 1 + torch/backends/quantized/__init__.py | 1 + torch/backends/xeon/run_cpu.py | 1 + torch/backends/xnnpack/__init__.py | 1 + torch/compiler/__init__.py | 1 + torch/contrib/_tensorboard_vis.py | 1 + torch/cpu/__init__.py | 1 + torch/cpu/amp/autocast_mode.py | 1 + torch/cuda/__init__.py | 1 + torch/cuda/_memory_viz.py | 1 + torch/cuda/_sanitizer.py | 1 + torch/cuda/amp/autocast_mode.py | 1 + torch/cuda/amp/common.py | 1 + torch/cuda/graphs.py | 1 + torch/cuda/jiterator.py | 1 + torch/cuda/memory.py | 1 + torch/cuda/nccl.py | 1 + torch/cuda/nvtx.py | 1 + torch/cuda/profiler.py | 1 + torch/cuda/random.py | 1 + torch/cuda/streams.py | 1 + torch/cuda/tunable.py | 30 +++++++++---------- torch/distributed/__init__.py | 1 + .../_composable/checkpoint_activation.py | 1 + 128 files changed, 142 insertions(+), 15 deletions(-) diff --git a/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/data_sparsity.py b/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/data_sparsity.py index 77ca61d599cb..554ad27dd357 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/data_sparsity.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/data_sparsity.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from collections import defaultdict from copy import deepcopy from typing import Any, Optional, Dict, TYPE_CHECKING diff --git a/torch/ao/pruning/_experimental/data_sparsifier/lightning/tests/test_callbacks.py b/torch/ao/pruning/_experimental/data_sparsifier/lightning/tests/test_callbacks.py index 252405de4968..957254284215 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/lightning/tests/test_callbacks.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/lightning/tests/test_callbacks.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch.ao.pruning._experimental.data_sparsifier.data_norm_sparsifier import DataNormSparsifier from torch.ao.pruning._experimental.data_scheduler.base_data_scheduler import BaseDataScheduler import torch diff --git a/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py b/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py index 1e76cfc345ac..0e907f42d3bf 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.nn as nn from torch.ao.pruning.sparsifier.utils import module_to_fqn, fqn_to_module diff --git a/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py b/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py index d8c3d20052ba..fe874c6effc7 100644 --- a/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py +++ b/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Callable, Optional, Union import torch diff --git a/torch/ao/pruning/_experimental/pruner/base_structured_sparsifier.py b/torch/ao/pruning/_experimental/pruner/base_structured_sparsifier.py index 357421fb5529..b380ae00adce 100644 --- a/torch/ao/pruning/_experimental/pruner/base_structured_sparsifier.py +++ b/torch/ao/pruning/_experimental/pruner/base_structured_sparsifier.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from itertools import chain from operator import getitem import torch diff --git a/torch/ao/pruning/_experimental/pruner/lstm_saliency_pruner.py b/torch/ao/pruning/_experimental/pruner/lstm_saliency_pruner.py index 9e569c14a6c8..3b65ce59fecc 100644 --- a/torch/ao/pruning/_experimental/pruner/lstm_saliency_pruner.py +++ b/torch/ao/pruning/_experimental/pruner/lstm_saliency_pruner.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import cast import torch diff --git a/torch/ao/pruning/_experimental/pruner/parametrization.py b/torch/ao/pruning/_experimental/pruner/parametrization.py index df94f7093b53..c5aa74e3bc52 100644 --- a/torch/ao/pruning/_experimental/pruner/parametrization.py +++ b/torch/ao/pruning/_experimental/pruner/parametrization.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch import nn from torch.nn.utils.parametrize import is_parametrized diff --git a/torch/ao/pruning/_experimental/pruner/prune_functions.py b/torch/ao/pruning/_experimental/pruner/prune_functions.py index 2b16d4b327a0..f7dcf120f9c3 100644 --- a/torch/ao/pruning/_experimental/pruner/prune_functions.py +++ b/torch/ao/pruning/_experimental/pruner/prune_functions.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ Collection of conversion functions for linear / conv2d structured pruning Also contains utilities for bias propagation diff --git a/torch/ao/pruning/_experimental/pruner/saliency_pruner.py b/torch/ao/pruning/_experimental/pruner/saliency_pruner.py index 7f96f0865d30..cf932c272005 100644 --- a/torch/ao/pruning/_experimental/pruner/saliency_pruner.py +++ b/torch/ao/pruning/_experimental/pruner/saliency_pruner.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from .base_structured_sparsifier import BaseStructuredSparsifier diff --git a/torch/ao/pruning/_mappings.py b/torch/ao/pruning/_mappings.py index 726cbc6b0fc8..70a0c785190f 100644 --- a/torch/ao/pruning/_mappings.py +++ b/torch/ao/pruning/_mappings.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs __all__ = [ "get_static_sparse_quantized_mapping", "get_dynamic_sparse_quantized_mapping", diff --git a/torch/ao/pruning/scheduler/base_scheduler.py b/torch/ao/pruning/scheduler/base_scheduler.py index 3391d3e73cd6..82f02399b7ec 100644 --- a/torch/ao/pruning/scheduler/base_scheduler.py +++ b/torch/ao/pruning/scheduler/base_scheduler.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch.ao.pruning import BaseSparsifier diff --git a/torch/ao/pruning/scheduler/cubic_scheduler.py b/torch/ao/pruning/scheduler/cubic_scheduler.py index 76fc61daa288..1a883059f569 100644 --- a/torch/ao/pruning/scheduler/cubic_scheduler.py +++ b/torch/ao/pruning/scheduler/cubic_scheduler.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import warnings from .base_scheduler import BaseScheduler diff --git a/torch/ao/pruning/scheduler/lambda_scheduler.py b/torch/ao/pruning/scheduler/lambda_scheduler.py index a88d99a1f83b..5236ebc33a26 100644 --- a/torch/ao/pruning/scheduler/lambda_scheduler.py +++ b/torch/ao/pruning/scheduler/lambda_scheduler.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import warnings from .base_scheduler import BaseScheduler diff --git a/torch/ao/pruning/sparsifier/base_sparsifier.py b/torch/ao/pruning/sparsifier/base_sparsifier.py index 1c210ace344d..8afed4d68945 100644 --- a/torch/ao/pruning/sparsifier/base_sparsifier.py +++ b/torch/ao/pruning/sparsifier/base_sparsifier.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import abc import copy from collections import defaultdict diff --git a/torch/ao/pruning/sparsifier/nearly_diagonal_sparsifier.py b/torch/ao/pruning/sparsifier/nearly_diagonal_sparsifier.py index 4f44e81485df..419323e68f93 100644 --- a/torch/ao/pruning/sparsifier/nearly_diagonal_sparsifier.py +++ b/torch/ao/pruning/sparsifier/nearly_diagonal_sparsifier.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from . import base_sparsifier diff --git a/torch/ao/pruning/sparsifier/utils.py b/torch/ao/pruning/sparsifier/utils.py index 98f489904cc4..7fd93e4d9da7 100644 --- a/torch/ao/pruning/sparsifier/utils.py +++ b/torch/ao/pruning/sparsifier/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any, Dict, Optional, Type from torch.nn.utils.parametrize import type_before_parametrizations, is_parametrized from itertools import chain diff --git a/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py b/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py index 2b24ca3d82e3..2f50d51f2a38 100644 --- a/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py +++ b/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from functools import reduce from typing import Callable, Optional, Tuple, Union diff --git a/torch/ao/quantization/__init__.py b/torch/ao/quantization/__init__.py index e2b8ee5c810a..f77969b32149 100644 --- a/torch/ao/quantization/__init__.py +++ b/torch/ao/quantization/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # flake8: noqa: F403 from .fake_quantize import * # noqa: F403 diff --git a/torch/ao/quantization/_correct_bias.py b/torch/ao/quantization/_correct_bias.py index 83cc81bb6b00..bf6b42a4a0dc 100644 --- a/torch/ao/quantization/_correct_bias.py +++ b/torch/ao/quantization/_correct_bias.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.nn as nn import torch.ao.nn.quantized as nnq diff --git a/torch/ao/quantization/_equalize.py b/torch/ao/quantization/_equalize.py index 7d39dbcf1ca8..4fed532c56f0 100644 --- a/torch/ao/quantization/_equalize.py +++ b/torch/ao/quantization/_equalize.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import copy from typing import Dict, Any diff --git a/torch/ao/quantization/_learnable_fake_quantize.py b/torch/ao/quantization/_learnable_fake_quantize.py index cdf44c5ea7b2..ce23e80de150 100644 --- a/torch/ao/quantization/_learnable_fake_quantize.py +++ b/torch/ao/quantization/_learnable_fake_quantize.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.nn.parameter import Parameter from typing import List diff --git a/torch/ao/quantization/backend_config/_common_operator_config_utils.py b/torch/ao/quantization/backend_config/_common_operator_config_utils.py index 4e946a25ffbb..d76bdfddddaf 100644 --- a/torch/ao/quantization/backend_config/_common_operator_config_utils.py +++ b/torch/ao/quantization/backend_config/_common_operator_config_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import operator import torch diff --git a/torch/ao/quantization/backend_config/_qnnpack_pt2e.py b/torch/ao/quantization/backend_config/_qnnpack_pt2e.py index 01e112b688c0..871d26dd9ff7 100644 --- a/torch/ao/quantization/backend_config/_qnnpack_pt2e.py +++ b/torch/ao/quantization/backend_config/_qnnpack_pt2e.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import operator import torch from torch.ao.quantization.backend_config import ( diff --git a/torch/ao/quantization/backend_config/backend_config.py b/torch/ao/quantization/backend_config/backend_config.py index 2288aced0995..96fb66662d6f 100644 --- a/torch/ao/quantization/backend_config/backend_config.py +++ b/torch/ao/quantization/backend_config/backend_config.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Type, Union, TYPE_CHECKING diff --git a/torch/ao/quantization/backend_config/native.py b/torch/ao/quantization/backend_config/native.py index 81cfc928adb5..84e0fbc45c62 100644 --- a/torch/ao/quantization/backend_config/native.py +++ b/torch/ao/quantization/backend_config/native.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from ._common_operator_config_utils import ( _get_binary_op_configs, diff --git a/torch/ao/quantization/backend_config/onednn.py b/torch/ao/quantization/backend_config/onednn.py index 6eab945f7d74..88dffedfd81b 100644 --- a/torch/ao/quantization/backend_config/onednn.py +++ b/torch/ao/quantization/backend_config/onednn.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.nn as nn import torch.ao.nn.intrinsic as nni diff --git a/torch/ao/quantization/backend_config/tensorrt.py b/torch/ao/quantization/backend_config/tensorrt.py index 1c5f761508bb..7a80d1883cfd 100644 --- a/torch/ao/quantization/backend_config/tensorrt.py +++ b/torch/ao/quantization/backend_config/tensorrt.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from .backend_config import ( BackendConfig, diff --git a/torch/ao/quantization/backend_config/utils.py b/torch/ao/quantization/backend_config/utils.py index 2e7382274079..13bf632e251a 100644 --- a/torch/ao/quantization/backend_config/utils.py +++ b/torch/ao/quantization/backend_config/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Dict, Any, List, Callable, Union, Tuple, Type import torch diff --git a/torch/ao/quantization/experimental/APoT_tensor.py b/torch/ao/quantization/experimental/APoT_tensor.py index debda7aea8c0..6caa2334be07 100644 --- a/torch/ao/quantization/experimental/APoT_tensor.py +++ b/torch/ao/quantization/experimental/APoT_tensor.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.ao.quantization.experimental.quantizer import APoTQuantizer diff --git a/torch/ao/quantization/experimental/adaround_fake_quantize.py b/torch/ao/quantization/experimental/adaround_fake_quantize.py index 4d988bbb25bb..d035a02b047a 100644 --- a/torch/ao/quantization/experimental/adaround_fake_quantize.py +++ b/torch/ao/quantization/experimental/adaround_fake_quantize.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Tuple import torch diff --git a/torch/ao/quantization/experimental/adaround_optimization.py b/torch/ao/quantization/experimental/adaround_optimization.py index 808b7abe2c78..f7eedd9fef12 100644 --- a/torch/ao/quantization/experimental/adaround_optimization.py +++ b/torch/ao/quantization/experimental/adaround_optimization.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy from typing import Any, Callable, List, Optional, Tuple, Type, Union diff --git a/torch/ao/quantization/experimental/apot_utils.py b/torch/ao/quantization/experimental/apot_utils.py index ad7a7bed1fbe..c2f2f0746ca5 100644 --- a/torch/ao/quantization/experimental/apot_utils.py +++ b/torch/ao/quantization/experimental/apot_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r""" This file contains utility functions to convert values using APoT nonuniform quantization methods. diff --git a/torch/ao/quantization/experimental/fake_quantize.py b/torch/ao/quantization/experimental/fake_quantize.py index 7541106a61c8..6b4da74541f2 100644 --- a/torch/ao/quantization/experimental/fake_quantize.py +++ b/torch/ao/quantization/experimental/fake_quantize.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch import Tensor from torch.ao.quantization.experimental.observer import APoTObserver diff --git a/torch/ao/quantization/experimental/fake_quantize_function.py b/torch/ao/quantization/experimental/fake_quantize_function.py index cac01fd8c002..924c81fc08df 100644 --- a/torch/ao/quantization/experimental/fake_quantize_function.py +++ b/torch/ao/quantization/experimental/fake_quantize_function.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch import Tensor from torch.ao.quantization.experimental.quantizer import quantize_APoT, dequantize_APoT diff --git a/torch/ao/quantization/experimental/linear.py b/torch/ao/quantization/experimental/linear.py index 154023b16183..cb46c99b01af 100644 --- a/torch/ao/quantization/experimental/linear.py +++ b/torch/ao/quantization/experimental/linear.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import numpy as np diff --git a/torch/ao/quantization/experimental/observer.py b/torch/ao/quantization/experimental/observer.py index 76a63815bdc6..8474f69c26a2 100644 --- a/torch/ao/quantization/experimental/observer.py +++ b/torch/ao/quantization/experimental/observer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ This module implements nonuniform observers used to collect statistics about the values observed during calibration (PTQ) or training (QAT). diff --git a/torch/ao/quantization/experimental/quantizer.py b/torch/ao/quantization/experimental/quantizer.py index df9c0f27847e..b386ce20bbd3 100644 --- a/torch/ao/quantization/experimental/quantizer.py +++ b/torch/ao/quantization/experimental/quantizer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch import Tensor import numpy as np diff --git a/torch/ao/quantization/fake_quantize.py b/torch/ao/quantization/fake_quantize.py index 9f0503cf06a5..b921df39217a 100644 --- a/torch/ao/quantization/fake_quantize.py +++ b/torch/ao/quantization/fake_quantize.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Implements modules used to perform fake quantization.""" import torch diff --git a/torch/ao/quantization/fuse_modules.py b/torch/ao/quantization/fuse_modules.py index 2caa0a2b7f2d..b9447ff37e39 100644 --- a/torch/ao/quantization/fuse_modules.py +++ b/torch/ao/quantization/fuse_modules.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import torch.nn as nn diff --git a/torch/ao/quantization/fuser_method_mappings.py b/torch/ao/quantization/fuser_method_mappings.py index 16c0c3a85b8f..a989ae298825 100644 --- a/torch/ao/quantization/fuser_method_mappings.py +++ b/torch/ao/quantization/fuser_method_mappings.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch.nn as nn import torch.ao.nn.intrinsic as nni diff --git a/torch/ao/quantization/fx/_decomposed.py b/torch/ao/quantization/fx/_decomposed.py index f2e774590be3..72ce4b2471f5 100644 --- a/torch/ao/quantization/fx/_decomposed.py +++ b/torch/ao/quantization/fx/_decomposed.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math from typing import Optional, Tuple diff --git a/torch/ao/quantization/fx/_equalize.py b/torch/ao/quantization/fx/_equalize.py index b0965b9a7051..40a7e7bbff3b 100644 --- a/torch/ao/quantization/fx/_equalize.py +++ b/torch/ao/quantization/fx/_equalize.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import warnings from collections import namedtuple diff --git a/torch/ao/quantization/fx/_lower_to_native_backend.py b/torch/ao/quantization/fx/_lower_to_native_backend.py index 049f4e3135d9..92620a169383 100644 --- a/torch/ao/quantization/fx/_lower_to_native_backend.py +++ b/torch/ao/quantization/fx/_lower_to_native_backend.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.fx import map_arg, Node from torch.fx.graph import Graph diff --git a/torch/ao/quantization/fx/_model_report/detector.py b/torch/ao/quantization/fx/_model_report/detector.py index b5c7f9fd2976..8e59df51c6ff 100644 --- a/torch/ao/quantization/fx/_model_report/detector.py +++ b/torch/ao/quantization/fx/_model_report/detector.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any, Dict, Set, Tuple, Callable, List import torch diff --git a/torch/ao/quantization/fx/_model_report/model_report.py b/torch/ao/quantization/fx/_model_report/model_report.py index 724e76ad576f..3370d8c9baf6 100644 --- a/torch/ao/quantization/fx/_model_report/model_report.py +++ b/torch/ao/quantization/fx/_model_report/model_report.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any, Dict, Set, Tuple, Callable from collections import OrderedDict import torch diff --git a/torch/ao/quantization/fx/_model_report/model_report_observer.py b/torch/ao/quantization/fx/_model_report/model_report_observer.py index eaa45264be7e..f04d6da8a054 100644 --- a/torch/ao/quantization/fx/_model_report/model_report_observer.py +++ b/torch/ao/quantization/fx/_model_report/model_report_observer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.ao.quantization.observer import ObserverBase diff --git a/torch/ao/quantization/fx/_model_report/model_report_visualizer.py b/torch/ao/quantization/fx/_model_report/model_report_visualizer.py index 5463862aa1cd..e6288c6f71d9 100644 --- a/torch/ao/quantization/fx/_model_report/model_report_visualizer.py +++ b/torch/ao/quantization/fx/_model_report/model_report_visualizer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from typing import Any, Set, Dict, List, Tuple, OrderedDict from collections import OrderedDict as OrdDict diff --git a/torch/ao/quantization/fx/custom_config.py b/torch/ao/quantization/fx/custom_config.py index 4fb2c3a28cb0..72f28ddbc777 100644 --- a/torch/ao/quantization/fx/custom_config.py +++ b/torch/ao/quantization/fx/custom_config.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Type diff --git a/torch/ao/quantization/fx/fuse.py b/torch/ao/quantization/fx/fuse.py index 6b2b614728f8..b555789f673a 100644 --- a/torch/ao/quantization/fx/fuse.py +++ b/torch/ao/quantization/fx/fuse.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch.fx import ( GraphModule, Node, diff --git a/torch/ao/quantization/fx/fuse_handler.py b/torch/ao/quantization/fx/fuse_handler.py index 718cc561bfa0..2766211e8e1b 100644 --- a/torch/ao/quantization/fx/fuse_handler.py +++ b/torch/ao/quantization/fx/fuse_handler.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.ao.quantization.backend_config import BackendConfig from torch.fx.graph import Node, Graph diff --git a/torch/ao/quantization/fx/graph_module.py b/torch/ao/quantization/fx/graph_module.py index cc9187285ae6..224f71745157 100644 --- a/torch/ao/quantization/fx/graph_module.py +++ b/torch/ao/quantization/fx/graph_module.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import copy from torch.fx import GraphModule diff --git a/torch/ao/quantization/fx/match_utils.py b/torch/ao/quantization/fx/match_utils.py index cf287db8c524..b5a6657103fc 100644 --- a/torch/ao/quantization/fx/match_utils.py +++ b/torch/ao/quantization/fx/match_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import sys import torch from torch.fx.graph import ( diff --git a/torch/ao/quantization/fx/pattern_utils.py b/torch/ao/quantization/fx/pattern_utils.py index d8648a0aed5e..3665f75f7567 100644 --- a/torch/ao/quantization/fx/pattern_utils.py +++ b/torch/ao/quantization/fx/pattern_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from collections import OrderedDict from typing import Dict, Any from torch.ao.quantization.utils import Pattern diff --git a/torch/ao/quantization/fx/prepare.py b/torch/ao/quantization/fx/prepare.py index d8e25f1260f5..80f50581cc72 100644 --- a/torch/ao/quantization/fx/prepare.py +++ b/torch/ao/quantization/fx/prepare.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import torch import warnings diff --git a/torch/ao/quantization/fx/qconfig_mapping_utils.py b/torch/ao/quantization/fx/qconfig_mapping_utils.py index 0b906a1777de..378c51b6805d 100644 --- a/torch/ao/quantization/fx/qconfig_mapping_utils.py +++ b/torch/ao/quantization/fx/qconfig_mapping_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import re from collections import defaultdict, OrderedDict diff --git a/torch/ao/quantization/fx/quantize_handler.py b/torch/ao/quantization/fx/quantize_handler.py index e70040f7e649..83fee8efcd99 100644 --- a/torch/ao/quantization/fx/quantize_handler.py +++ b/torch/ao/quantization/fx/quantize_handler.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from abc import ABC from typing import Callable, Dict, List, Optional, Type diff --git a/torch/ao/quantization/fx/utils.py b/torch/ao/quantization/fx/utils.py index 5cfedde4bc24..5029db47961f 100644 --- a/torch/ao/quantization/fx/utils.py +++ b/torch/ao/quantization/fx/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import torch import torch.nn as nn diff --git a/torch/ao/quantization/observer.py b/torch/ao/quantization/observer.py index 5f075df1cd83..656372d37555 100644 --- a/torch/ao/quantization/observer.py +++ b/torch/ao/quantization/observer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ This module implements observers which are used to collect statistics about the values observed during calibration (PTQ) or training (QAT). diff --git a/torch/ao/quantization/pt2e/duplicate_dq_pass.py b/torch/ao/quantization/pt2e/duplicate_dq_pass.py index 48c7d7247b99..a6cfbce611fa 100644 --- a/torch/ao/quantization/pt2e/duplicate_dq_pass.py +++ b/torch/ao/quantization/pt2e/duplicate_dq_pass.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging import operator diff --git a/torch/ao/quantization/pt2e/export_utils.py b/torch/ao/quantization/pt2e/export_utils.py index 139042c326b8..78c69b718d7d 100644 --- a/torch/ao/quantization/pt2e/export_utils.py +++ b/torch/ao/quantization/pt2e/export_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import types import torch diff --git a/torch/ao/quantization/pt2e/graph_utils.py b/torch/ao/quantization/pt2e/graph_utils.py index bacb4d8a28f1..6ae93ba1d260 100644 --- a/torch/ao/quantization/pt2e/graph_utils.py +++ b/torch/ao/quantization/pt2e/graph_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import itertools from typing import Any, List, OrderedDict, Set, Optional, Callable import operator diff --git a/torch/ao/quantization/pt2e/port_metadata_pass.py b/torch/ao/quantization/pt2e/port_metadata_pass.py index 5ea1f939a3b6..313b420e7a22 100644 --- a/torch/ao/quantization/pt2e/port_metadata_pass.py +++ b/torch/ao/quantization/pt2e/port_metadata_pass.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging from typing import Optional diff --git a/torch/ao/quantization/pt2e/prepare.py b/torch/ao/quantization/pt2e/prepare.py index 169a982f62ce..162ee45623ee 100644 --- a/torch/ao/quantization/pt2e/prepare.py +++ b/torch/ao/quantization/pt2e/prepare.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._subclasses import FakeTensor from torch.ao.quantization.fx.prepare import ( diff --git a/torch/ao/quantization/pt2e/qat_utils.py b/torch/ao/quantization/pt2e/qat_utils.py index 45f5c265d2cb..c4c1f804d41c 100644 --- a/torch/ao/quantization/pt2e/qat_utils.py +++ b/torch/ao/quantization/pt2e/qat_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import dataclasses import itertools import operator diff --git a/torch/ao/quantization/pt2e/representation/rewrite.py b/torch/ao/quantization/pt2e/representation/rewrite.py index 7f5cb2eeb13b..40801344740b 100644 --- a/torch/ao/quantization/pt2e/representation/rewrite.py +++ b/torch/ao/quantization/pt2e/representation/rewrite.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.fx import GraphModule from ..export_utils import _WrapperModule diff --git a/torch/ao/quantization/pt2e/utils.py b/torch/ao/quantization/pt2e/utils.py index 25f82f04e4e3..cde22426ae5b 100644 --- a/torch/ao/quantization/pt2e/utils.py +++ b/torch/ao/quantization/pt2e/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import operator import types diff --git a/torch/ao/quantization/qconfig.py b/torch/ao/quantization/qconfig.py index 88e7b47aff2b..dc93d7938f0c 100644 --- a/torch/ao/quantization/qconfig.py +++ b/torch/ao/quantization/qconfig.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from collections import namedtuple from typing import Optional, Any, Union, Type from typing_extensions import deprecated diff --git a/torch/ao/quantization/qconfig_mapping.py b/torch/ao/quantization/qconfig_mapping.py index 6bf4b41c724a..37f71465afea 100644 --- a/torch/ao/quantization/qconfig_mapping.py +++ b/torch/ao/quantization/qconfig_mapping.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations from collections import OrderedDict from typing import Any, Callable, Dict, Tuple, Union, List diff --git a/torch/ao/quantization/quantize.py b/torch/ao/quantization/quantize.py index 534def354573..be00be0e295b 100644 --- a/torch/ao/quantization/quantize.py +++ b/torch/ao/quantization/quantize.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import itertools import warnings diff --git a/torch/ao/quantization/quantize_jit.py b/torch/ao/quantization/quantize_jit.py index 632fc1db2327..3001deb6ab9c 100644 --- a/torch/ao/quantization/quantize_jit.py +++ b/torch/ao/quantization/quantize_jit.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.ao.quantization.qconfig import QConfig diff --git a/torch/ao/quantization/quantizer/embedding_quantizer.py b/torch/ao/quantization/quantizer/embedding_quantizer.py index 81306943264b..bd3d2773e628 100644 --- a/torch/ao/quantization/quantizer/embedding_quantizer.py +++ b/torch/ao/quantization/quantizer/embedding_quantizer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import copy diff --git a/torch/ao/quantization/quantizer/quantizer.py b/torch/ao/quantization/quantizer/quantizer.py index a521ff56c34c..4cecfee28f2b 100644 --- a/torch/ao/quantization/quantizer/quantizer.py +++ b/torch/ao/quantization/quantizer/quantizer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import Callable, Dict, List, Optional, Tuple, Union diff --git a/torch/ao/quantization/quantizer/utils.py b/torch/ao/quantization/quantizer/utils.py index f25d0916018b..f948dbb112dc 100644 --- a/torch/ao/quantization/quantizer/utils.py +++ b/torch/ao/quantization/quantizer/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List from torch.ao.quantization.pt2e.utils import _is_sym_size_node diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index 4cc05e46c6a7..89e4966bf4eb 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import functools import itertools diff --git a/torch/ao/quantization/quantizer/xnnpack_quantizer.py b/torch/ao/quantization/quantizer/xnnpack_quantizer.py index f3d1b6ca8b39..ae9ae60b8a3b 100644 --- a/torch/ao/quantization/quantizer/xnnpack_quantizer.py +++ b/torch/ao/quantization/quantizer/xnnpack_quantizer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import copy diff --git a/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py b/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py index 9f1732e57370..928ee0d3ac45 100644 --- a/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py +++ b/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import itertools import operator from dataclasses import dataclass diff --git a/torch/ao/quantization/stubs.py b/torch/ao/quantization/stubs.py index 10a63fb8f0ee..f62a227f1d77 100644 --- a/torch/ao/quantization/stubs.py +++ b/torch/ao/quantization/stubs.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch import nn diff --git a/torch/ao/quantization/utils.py b/torch/ao/quantization/utils.py index 5ce1d1109e72..fadbf33a70b6 100644 --- a/torch/ao/quantization/utils.py +++ b/torch/ao/quantization/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ Utils shared by different modes of quantization (eager/graph) """ diff --git a/torch/autograd/__init__.py b/torch/autograd/__init__.py index adf47ad1727d..aca9abb24070 100644 --- a/torch/autograd/__init__.py +++ b/torch/autograd/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ ``torch.autograd`` provides classes and functions implementing automatic differentiation of arbitrary scalar valued functions. It requires minimal diff --git a/torch/autograd/_functions/tensor.py b/torch/autograd/_functions/tensor.py index d2b3149bfc81..9c982b074b65 100644 --- a/torch/autograd/_functions/tensor.py +++ b/torch/autograd/_functions/tensor.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import operator from functools import reduce from typing_extensions import deprecated diff --git a/torch/autograd/_functions/utils.py b/torch/autograd/_functions/utils.py index 7111d893400f..56baae4aae3b 100644 --- a/torch/autograd/_functions/utils.py +++ b/torch/autograd/_functions/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import operator from functools import reduce diff --git a/torch/autograd/anomaly_mode.py b/torch/autograd/anomaly_mode.py index 80a2526a81de..7e73ad4ef2c3 100644 --- a/torch/autograd/anomaly_mode.py +++ b/torch/autograd/anomaly_mode.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import warnings import torch diff --git a/torch/autograd/forward_ad.py b/torch/autograd/forward_ad.py index 747b18f0f369..4187e220ceab 100644 --- a/torch/autograd/forward_ad.py +++ b/torch/autograd/forward_ad.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import os from collections import namedtuple diff --git a/torch/autograd/function.py b/torch/autograd/function.py index 9aca2b2a1b32..62ec1183a365 100644 --- a/torch/autograd/function.py +++ b/torch/autograd/function.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import inspect import itertools diff --git a/torch/autograd/functional.py b/torch/autograd/functional.py index 6701efbedac1..8cf3955a6927 100644 --- a/torch/autograd/functional.py +++ b/torch/autograd/functional.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List, Tuple import torch diff --git a/torch/autograd/grad_mode.py b/torch/autograd/grad_mode.py index be173c9b9de0..1c97ab58298b 100644 --- a/torch/autograd/grad_mode.py +++ b/torch/autograd/grad_mode.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any import torch diff --git a/torch/autograd/gradcheck.py b/torch/autograd/gradcheck.py index a0d874038761..5bf74afacb66 100644 --- a/torch/autograd/gradcheck.py +++ b/torch/autograd/gradcheck.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections import functools import warnings diff --git a/torch/autograd/graph.py b/torch/autograd/graph.py index 19938c183557..cde56a6f26c7 100644 --- a/torch/autograd/graph.py +++ b/torch/autograd/graph.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import abc import collections import contextlib diff --git a/torch/autograd/profiler.py b/torch/autograd/profiler.py index 162dfe1eeaef..0392a8769846 100644 --- a/torch/autograd/profiler.py +++ b/torch/autograd/profiler.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from collections import defaultdict from dataclasses import dataclass from time import perf_counter_ns diff --git a/torch/autograd/profiler_legacy.py b/torch/autograd/profiler_legacy.py index cb573faf4410..40baafd441ae 100644 --- a/torch/autograd/profiler_legacy.py +++ b/torch/autograd/profiler_legacy.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import itertools import warnings from typing_extensions import deprecated diff --git a/torch/autograd/profiler_util.py b/torch/autograd/profiler_util.py index 23243733aaa8..a5cff1ea12a8 100644 --- a/torch/autograd/profiler_util.py +++ b/torch/autograd/profiler_util.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import bisect import itertools import math diff --git a/torch/autograd/variable.py b/torch/autograd/variable.py index ed841d4da7d4..84b504a9c82c 100644 --- a/torch/autograd/variable.py +++ b/torch/autograd/variable.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._C import _ImperativeEngine as ImperativeEngine diff --git a/torch/backends/__init__.py b/torch/backends/__init__.py index 2236230e8c6d..086147b87a81 100644 --- a/torch/backends/__init__.py +++ b/torch/backends/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import types from contextlib import contextmanager diff --git a/torch/backends/_coreml/preprocess.py b/torch/backends/_coreml/preprocess.py index f393929bb7c2..18cb8229db9a 100644 --- a/torch/backends/_coreml/preprocess.py +++ b/torch/backends/_coreml/preprocess.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import hashlib import json from typing import Dict, Tuple diff --git a/torch/backends/_nnapi/prepare.py b/torch/backends/_nnapi/prepare.py index 8b07c3d6e0c6..6ba389902c9f 100644 --- a/torch/backends/_nnapi/prepare.py +++ b/torch/backends/_nnapi/prepare.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List, Optional import torch diff --git a/torch/backends/_nnapi/serializer.py b/torch/backends/_nnapi/serializer.py index 551fa821df68..34bcc42f8927 100644 --- a/torch/backends/_nnapi/serializer.py +++ b/torch/backends/_nnapi/serializer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import array import enum import functools diff --git a/torch/backends/cuda/__init__.py b/torch/backends/cuda/__init__.py index c35a962ba693..cb5f511bc5db 100644 --- a/torch/backends/cuda/__init__.py +++ b/torch/backends/cuda/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib from typing import Union diff --git a/torch/backends/cudnn/__init__.py b/torch/backends/cudnn/__init__.py index e00d92f44b28..e528ac68552d 100644 --- a/torch/backends/cudnn/__init__.py +++ b/torch/backends/cudnn/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import os import sys import warnings diff --git a/torch/backends/cudnn/rnn.py b/torch/backends/cudnn/rnn.py index aaf0bd02e8af..f2e9d4321a02 100644 --- a/torch/backends/cudnn/rnn.py +++ b/torch/backends/cudnn/rnn.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch.cuda try: diff --git a/torch/backends/mkl/__init__.py b/torch/backends/mkl/__init__.py index 261ee764485b..9f96d692ae02 100644 --- a/torch/backends/mkl/__init__.py +++ b/torch/backends/mkl/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch diff --git a/torch/backends/mkldnn/__init__.py b/torch/backends/mkldnn/__init__.py index 9cdee1cbd565..669ed59a1132 100644 --- a/torch/backends/mkldnn/__init__.py +++ b/torch/backends/mkldnn/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import sys from contextlib import contextmanager diff --git a/torch/backends/mps/__init__.py b/torch/backends/mps/__init__.py index 8d5e70f06a0a..06eda58e82f9 100644 --- a/torch/backends/mps/__init__.py +++ b/torch/backends/mps/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from functools import lru_cache as _lru_cache from typing import Optional diff --git a/torch/backends/nnpack/__init__.py b/torch/backends/nnpack/__init__.py index 892dfa022cfc..1a30e977cab3 100644 --- a/torch/backends/nnpack/__init__.py +++ b/torch/backends/nnpack/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from contextlib import contextmanager import torch diff --git a/torch/backends/openmp/__init__.py b/torch/backends/openmp/__init__.py index 4a7fcca12d0c..aff8d46cd4ac 100644 --- a/torch/backends/openmp/__init__.py +++ b/torch/backends/openmp/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch diff --git a/torch/backends/opt_einsum/__init__.py b/torch/backends/opt_einsum/__init__.py index 2e66cd37542d..993a219fa9aa 100644 --- a/torch/backends/opt_einsum/__init__.py +++ b/torch/backends/opt_einsum/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import sys import warnings from contextlib import contextmanager diff --git a/torch/backends/quantized/__init__.py b/torch/backends/quantized/__init__.py index 85009753e0ae..3cb795dd39fc 100644 --- a/torch/backends/quantized/__init__.py +++ b/torch/backends/quantized/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import sys import types from typing import List diff --git a/torch/backends/xeon/run_cpu.py b/torch/backends/xeon/run_cpu.py index 0344631ee6b4..bdf07e286174 100644 --- a/torch/backends/xeon/run_cpu.py +++ b/torch/backends/xeon/run_cpu.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ This is a script for launching PyTorch inference on Intel(R) Xeon(R) Scalable Processors with optimal configurations. diff --git a/torch/backends/xnnpack/__init__.py b/torch/backends/xnnpack/__init__.py index c26dc11deb47..31e69876927d 100644 --- a/torch/backends/xnnpack/__init__.py +++ b/torch/backends/xnnpack/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import sys import types diff --git a/torch/compiler/__init__.py b/torch/compiler/__init__.py index a27238c3d833..812bbaa4c660 100644 --- a/torch/compiler/__init__.py +++ b/torch/compiler/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from typing import List diff --git a/torch/contrib/_tensorboard_vis.py b/torch/contrib/_tensorboard_vis.py index 87c325948a8b..ed1445dd7bce 100644 --- a/torch/contrib/_tensorboard_vis.py +++ b/torch/contrib/_tensorboard_vis.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import time from collections import defaultdict from functools import partial diff --git a/torch/cpu/__init__.py b/torch/cpu/__init__.py index a36594a3cb15..d2b8069048cc 100644 --- a/torch/cpu/__init__.py +++ b/torch/cpu/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r""" This package implements abstractions found in ``torch.cuda`` to facilitate writing device-agnostic code. diff --git a/torch/cpu/amp/autocast_mode.py b/torch/cpu/amp/autocast_mode.py index b545e91dd6f4..b61e9b542dba 100644 --- a/torch/cpu/amp/autocast_mode.py +++ b/torch/cpu/amp/autocast_mode.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any from typing_extensions import deprecated diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index 2b2fe32154b2..6722114e295b 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r""" This package adds support for CUDA tensor types. diff --git a/torch/cuda/_memory_viz.py b/torch/cuda/_memory_viz.py index 7d211fd3b8cb..2047ec4efb28 100644 --- a/torch/cuda/_memory_viz.py +++ b/torch/cuda/_memory_viz.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import pickle import sys import os diff --git a/torch/cuda/_sanitizer.py b/torch/cuda/_sanitizer.py index 89766ba8c1a4..bf72f277dd8a 100644 --- a/torch/cuda/_sanitizer.py +++ b/torch/cuda/_sanitizer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r""" This module introduces CUDA Sanitizer, a tool for detecting synchronization errors between kernels ran on different streams. diff --git a/torch/cuda/amp/autocast_mode.py b/torch/cuda/amp/autocast_mode.py index eb17d7a75e69..049ff41c590f 100644 --- a/torch/cuda/amp/autocast_mode.py +++ b/torch/cuda/amp/autocast_mode.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools from typing import Any from typing_extensions import deprecated diff --git a/torch/cuda/amp/common.py b/torch/cuda/amp/common.py index c4e8c1cc99b0..30ccaeede8d9 100644 --- a/torch/cuda/amp/common.py +++ b/torch/cuda/amp/common.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from importlib.util import find_spec import torch diff --git a/torch/cuda/graphs.py b/torch/cuda/graphs.py index 9d9df283ced6..78c572a1822d 100644 --- a/torch/cuda/graphs.py +++ b/torch/cuda/graphs.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import gc import typing diff --git a/torch/cuda/jiterator.py b/torch/cuda/jiterator.py index 1be552555945..294670f8819e 100644 --- a/torch/cuda/jiterator.py +++ b/torch/cuda/jiterator.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import re from typing import Callable, List diff --git a/torch/cuda/memory.py b/torch/cuda/memory.py index 0f12395ac778..9634d1c0d80b 100644 --- a/torch/cuda/memory.py +++ b/torch/cuda/memory.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r"""This package adds support for device memory management implemented in CUDA.""" import collections diff --git a/torch/cuda/nccl.py b/torch/cuda/nccl.py index 67d528771215..4c28443c9e29 100644 --- a/torch/cuda/nccl.py +++ b/torch/cuda/nccl.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections import warnings from typing import Optional, Sequence, Union diff --git a/torch/cuda/nvtx.py b/torch/cuda/nvtx.py index 4b902c0c6d4d..195509687905 100644 --- a/torch/cuda/nvtx.py +++ b/torch/cuda/nvtx.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r"""This package adds support for NVIDIA Tools Extension (NVTX) used in profiling.""" from contextlib import contextmanager diff --git a/torch/cuda/profiler.py b/torch/cuda/profiler.py index 51c8aa46f714..7e5dc9bab8de 100644 --- a/torch/cuda/profiler.py +++ b/torch/cuda/profiler.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import tempfile diff --git a/torch/cuda/random.py b/torch/cuda/random.py index 1cf33114d17b..b736c9d959d8 100644 --- a/torch/cuda/random.py +++ b/torch/cuda/random.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Iterable, List, Union import torch diff --git a/torch/cuda/streams.py b/torch/cuda/streams.py index d36121381586..89271b588711 100644 --- a/torch/cuda/streams.py +++ b/torch/cuda/streams.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import ctypes import torch diff --git a/torch/cuda/tunable.py b/torch/cuda/tunable.py index 0f7e0a1f3725..8b387102b43d 100644 --- a/torch/cuda/tunable.py +++ b/torch/cuda/tunable.py @@ -138,12 +138,12 @@ def enable(val: bool = True) -> None: r"""This is the big on/off switch for all TunableOp implementations.""" - torch._C._cuda_tunableop_enable(val) + torch._C._cuda_tunableop_enable(val) # type: ignore[attr-defined] def is_enabled() -> bool: r"""Returns whether the TunableOp feature is enabled.""" - return torch._C._cuda_tunableop_is_enabled() + return torch._C._cuda_tunableop_is_enabled() # type: ignore[attr-defined] def tuning_enable(val: bool = True) -> None: @@ -152,12 +152,12 @@ def tuning_enable(val: bool = True) -> None: When enabled, if a tuned entry isn't found, run the tuning step and record the entry. """ - torch._C._cuda_tunableop_tuning_enable(val) + torch._C._cuda_tunableop_tuning_enable(val) # type: ignore[attr-defined] def tuning_is_enabled() -> bool: r"""Returns whether TunableOp implementations can be tuned.""" - return torch._C._cuda_tunableop_tuning_is_enabled() + return torch._C._cuda_tunableop_tuning_is_enabled() # type: ignore[attr-defined] def set_max_tuning_duration(duration: int) -> None: @@ -166,12 +166,12 @@ def set_max_tuning_duration(duration: int) -> None: If both max tuning duration and iterations are set, the smaller of the two will be honored. At minimum 1 tuning iteration will always be run. """ - torch._C._cuda_tunableop_set_max_tuning_duration(duration) + torch._C._cuda_tunableop_set_max_tuning_duration(duration) # type: ignore[attr-defined] def get_max_tuning_duration() -> int: r"""Get max time to spend tuning a given solution.""" - return torch._C._cuda_tunableop_get_max_tuning_duration() + return torch._C._cuda_tunableop_get_max_tuning_duration() # type: ignore[attr-defined] def set_max_tuning_iterations(iterations: int) -> None: @@ -180,12 +180,12 @@ def set_max_tuning_iterations(iterations: int) -> None: If both max tuning duration and iterations are set, the smaller of the two will be honored. At minimum 1 tuning iteration will always be run. """ - torch._C._cuda_tunableop_set_max_tuning_iterations(iterations) + torch._C._cuda_tunableop_set_max_tuning_iterations(iterations) # type: ignore[attr-defined] def get_max_tuning_iterations() -> int: r"""Get max iterations to spend tuning a given solution.""" - return torch._C._cuda_tunableop_get_max_tuning_iterations() + return torch._C._cuda_tunableop_get_max_tuning_iterations() # type: ignore[attr-defined] def set_filename(filename: str, insert_device_ordinal: bool = False) -> None: @@ -195,22 +195,22 @@ def set_filename(filename: str, insert_device_ordinal: bool = False) -> None: will be added to the given filename automatically. This can be used in a 1-process-per-gpu cenario to ensure all processes write to a separate file. """ - torch._C._cuda_tunableop_set_filename(filename, insert_device_ordinal) + torch._C._cuda_tunableop_set_filename(filename, insert_device_ordinal) # type: ignore[attr-defined] def get_filename() -> str: r"""Get the results filename.""" - return torch._C._cuda_tunableop_get_filename() + return torch._C._cuda_tunableop_get_filename() # type: ignore[attr-defined] def get_results() -> Tuple[str, str, str, float]: r"""Return all TunableOp results.""" - return torch._C._cuda_tunableop_get_results() + return torch._C._cuda_tunableop_get_results() # type: ignore[attr-defined] def get_validators() -> Tuple[str, str]: r"""Return the TunableOp validators.""" - return torch._C._cuda_tunableop_get_validators() + return torch._C._cuda_tunableop_get_validators() # type: ignore[attr-defined] def write_file_on_exit(val: bool) -> None: @@ -219,7 +219,7 @@ def write_file_on_exit(val: bool) -> None: This is useful as a final flush of your results to disk if your application terminates as result of normal operation or an error. Manual flushing of your results can be achieved by manually calling ``write_file()``.""" - torch._C._cuda_tunableop_write_file_on_exit(val) + torch._C._cuda_tunableop_write_file_on_exit(val) # type: ignore[attr-defined] def write_file(filename: Optional[str] = None) -> bool: @@ -229,7 +229,7 @@ def write_file(filename: Optional[str] = None) -> bool: """ if filename is None: filename = get_filename() - return torch._C._cuda_tunableop_write_file(filename) + return torch._C._cuda_tunableop_write_file(filename) # type: ignore[attr-defined] def read_file(filename: Optional[str] = None) -> bool: @@ -239,4 +239,4 @@ def read_file(filename: Optional[str] = None) -> bool: """ if filename is None: filename = get_filename() - return torch._C._cuda_tunableop_read_file(filename) + return torch._C._cuda_tunableop_read_file(filename) # type: ignore[attr-defined] diff --git a/torch/distributed/__init__.py b/torch/distributed/__init__.py index 3e7dce97b54c..b8e911c8738c 100644 --- a/torch/distributed/__init__.py +++ b/torch/distributed/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import os import sys from enum import Enum diff --git a/torch/distributed/_composable/checkpoint_activation.py b/torch/distributed/_composable/checkpoint_activation.py index 8accef6afc34..6716f43a74a0 100644 --- a/torch/distributed/_composable/checkpoint_activation.py +++ b/torch/distributed/_composable/checkpoint_activation.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from contextlib import contextmanager, nullcontext from typing import Any, Tuple From 3a0d0885171376ed610c8175a19ba40411fc6f3f Mon Sep 17 00:00:00 2001 From: Aaron Orenstein Date: Sat, 8 Jun 2024 11:41:11 -0700 Subject: [PATCH 540/706] Flip default value for mypy disallow_untyped_defs [5/11] (#127842) See #127836 for details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127842 Approved by: https://github.com/oulgen --- torch/distributed/_composable/contract.py | 1 + torch/distributed/_composable/fsdp/_fsdp_api.py | 1 + torch/distributed/_composable/fsdp/_fsdp_common.py | 1 + torch/distributed/_composable/fsdp/_fsdp_param.py | 1 + torch/distributed/_composable/fsdp/_fsdp_param_group.py | 1 + torch/distributed/_composable/fsdp/_fsdp_state.py | 1 + torch/distributed/_composable/fsdp/fully_shard.py | 1 + torch/distributed/_composable/replicate.py | 1 + torch/distributed/_cuda_p2p/__init__.py | 1 + torch/distributed/_functional_collectives.py | 1 + torch/distributed/_functional_collectives_impl.py | 1 + torch/distributed/_shard/api.py | 1 + torch/distributed/_shard/common_op_utils.py | 1 + torch/distributed/_shard/metadata.py | 1 + torch/distributed/_shard/op_registry_utils.py | 1 + torch/distributed/_shard/sharded_optim/api.py | 1 + torch/distributed/_shard/sharded_tensor/__init__.py | 1 + torch/distributed/_shard/sharded_tensor/_ops/_common.py | 1 + torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py | 1 + torch/distributed/_shard/sharded_tensor/_ops/init.py | 1 + torch/distributed/_shard/sharded_tensor/_ops/misc_ops.py | 1 + torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py | 1 + torch/distributed/_shard/sharded_tensor/api.py | 1 + torch/distributed/_shard/sharded_tensor/metadata.py | 1 + torch/distributed/_shard/sharded_tensor/reshard.py | 1 + torch/distributed/_shard/sharded_tensor/shard.py | 1 + torch/distributed/_shard/sharded_tensor/utils.py | 1 + torch/distributed/_shard/sharding_spec/_internals.py | 1 + torch/distributed/_shard/sharding_spec/api.py | 1 + torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py | 1 + .../_shard/sharding_spec/chunk_sharding_spec_ops/_common.py | 1 + .../_shard/sharding_spec/chunk_sharding_spec_ops/embedding.py | 1 + .../sharding_spec/chunk_sharding_spec_ops/embedding_bag.py | 1 + torch/distributed/_spmd/api.py | 1 + torch/distributed/_spmd/batch_dim_utils.py | 1 + torch/distributed/_spmd/comm_tensor.py | 1 + torch/distributed/_spmd/config.py | 1 + torch/distributed/_spmd/data_parallel.py | 1 + torch/distributed/_spmd/distribute.py | 1 + torch/distributed/_spmd/experimental_ops.py | 1 + torch/distributed/_spmd/graph_optimization.py | 1 + torch/distributed/_spmd/iter_graph_module.py | 1 + torch/distributed/_state_dict_utils.py | 1 + torch/distributed/_tensor/__init__.py | 1 + torch/distributed/_tensor/_collective_utils.py | 1 + torch/distributed/_tensor/_op_schema.py | 1 + torch/distributed/_tensor/_redistribute.py | 1 + torch/distributed/_tensor/_sharding_prop.py | 1 + torch/distributed/_tensor/_tp_conv.py | 1 + torch/distributed/_tensor/api.py | 1 + torch/distributed/_tensor/debug/__init__.py | 1 + torch/distributed/_tensor/debug/_op_coverage.py | 1 + torch/distributed/_tensor/debug/comm_mode.py | 1 + torch/distributed/_tensor/debug/visualize_sharding.py | 1 + torch/distributed/_tensor/examples/checkpoint_example.py | 1 + torch/distributed/_tensor/examples/convnext_example.py | 1 + torch/distributed/_tensor/examples/torchrec_sharding_example.py | 1 + torch/distributed/_tensor/experimental/__init__.py | 1 + torch/distributed/_tensor/experimental/local_map.py | 1 + torch/distributed/_tensor/experimental/tp_transform.py | 1 + torch/distributed/_tensor/ops/embedding_ops.py | 1 + torch/distributed/_tensor/ops/math_ops.py | 1 + torch/distributed/_tensor/ops/tensor_ops.py | 1 + torch/distributed/_tensor/ops/utils.py | 1 + torch/distributed/_tensor/ops/view_ops.py | 1 + torch/distributed/_tensor/placement_types.py | 1 + torch/distributed/_tensor/random.py | 1 + torch/distributed/_tools/memory_tracker.py | 1 + torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py | 1 + torch/distributed/algorithms/_comm_hooks/default_hooks.py | 1 + .../algorithms/_optimizer_overlap/optimizer_overlap.py | 1 + torch/distributed/algorithms/_quantization/quantization.py | 1 + torch/distributed/algorithms/ddp_comm_hooks/__init__.py | 1 + torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py | 1 + torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py | 1 + .../algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py | 1 + .../distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py | 1 + torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py | 1 + .../distributed/algorithms/ddp_comm_hooks/quantization_hooks.py | 1 + torch/distributed/algorithms/join.py | 1 + torch/distributed/algorithms/model_averaging/averagers.py | 1 + .../algorithms/model_averaging/hierarchical_model_averager.py | 1 + torch/distributed/algorithms/model_averaging/utils.py | 1 + torch/distributed/argparse_util.py | 1 + torch/distributed/autograd/__init__.py | 1 + torch/distributed/benchmarks/benchmark_ddp_rpc.py | 1 + torch/distributed/c10d_logger.py | 1 + torch/distributed/checkpoint/api.py | 1 + torch/distributed/checkpoint/default_planner.py | 1 + .../checkpoint/examples/async_checkpointing_example.py | 1 + torch/distributed/checkpoint/examples/fsdp_checkpoint_example.py | 1 + torch/distributed/checkpoint/examples/stateful_example.py | 1 + torch/distributed/checkpoint/filesystem.py | 1 + torch/distributed/checkpoint/format_utils.py | 1 + torch/distributed/checkpoint/logger.py | 1 + torch/distributed/checkpoint/metadata.py | 1 + torch/distributed/checkpoint/planner_helpers.py | 1 + torch/distributed/checkpoint/resharding.py | 1 + torch/distributed/checkpoint/state_dict.py | 1 + torch/distributed/checkpoint/state_dict_loader.py | 1 + torch/distributed/checkpoint/state_dict_saver.py | 1 + torch/distributed/checkpoint/utils.py | 1 + torch/distributed/device_mesh.py | 1 + torch/distributed/distributed_c10d.py | 1 + torch/distributed/elastic/agent/server/local_elastic_agent.py | 1 + torch/distributed/elastic/events/api.py | 1 + torch/distributed/elastic/metrics/__init__.py | 1 + torch/distributed/elastic/metrics/api.py | 1 + torch/distributed/elastic/multiprocessing/api.py | 1 + torch/distributed/elastic/multiprocessing/errors/__init__.py | 1 + .../distributed/elastic/multiprocessing/errors/error_handler.py | 1 + torch/distributed/elastic/multiprocessing/errors/handlers.py | 1 + torch/distributed/elastic/multiprocessing/redirects.py | 1 + .../elastic/multiprocessing/subprocess_handler/handlers.py | 1 + torch/distributed/elastic/multiprocessing/tail_log.py | 1 + torch/distributed/elastic/rendezvous/api.py | 1 + torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py | 1 + torch/distributed/elastic/rendezvous/dynamic_rendezvous.py | 1 + torch/distributed/elastic/rendezvous/etcd_rendezvous.py | 1 + torch/distributed/elastic/rendezvous/etcd_rendezvous_backend.py | 1 + torch/distributed/elastic/rendezvous/etcd_server.py | 1 + torch/distributed/elastic/rendezvous/etcd_store.py | 1 + torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py | 1 + torch/distributed/elastic/rendezvous/utils.py | 1 + torch/distributed/elastic/timer/api.py | 1 + torch/distributed/elastic/timer/debug_info_logging.py | 1 + torch/distributed/elastic/timer/file_based_local_timer.py | 1 + torch/distributed/elastic/timer/local_timer.py | 1 + 128 files changed, 128 insertions(+) diff --git a/torch/distributed/_composable/contract.py b/torch/distributed/_composable/contract.py index 2a6983023f76..6693fa9608df 100644 --- a/torch/distributed/_composable/contract.py +++ b/torch/distributed/_composable/contract.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import uuid from collections import OrderedDict from functools import wraps diff --git a/torch/distributed/_composable/fsdp/_fsdp_api.py b/torch/distributed/_composable/fsdp/_fsdp_api.py index 2bf0278ed488..aa6b5e803b80 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_api.py +++ b/torch/distributed/_composable/fsdp/_fsdp_api.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from dataclasses import dataclass from typing import Optional diff --git a/torch/distributed/_composable/fsdp/_fsdp_common.py b/torch/distributed/_composable/fsdp/_fsdp_common.py index 85b0192b0f50..594ec483bd3b 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_common.py +++ b/torch/distributed/_composable/fsdp/_fsdp_common.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math import traceback diff --git a/torch/distributed/_composable/fsdp/_fsdp_param.py b/torch/distributed/_composable/fsdp/_fsdp_param.py index ca12ea74b230..81596fe05f6b 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_param.py +++ b/torch/distributed/_composable/fsdp/_fsdp_param.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import itertools from dataclasses import dataclass, field from enum import auto, Enum diff --git a/torch/distributed/_composable/fsdp/_fsdp_param_group.py b/torch/distributed/_composable/fsdp/_fsdp_param_group.py index bb66977848a3..2361b7ba7c7e 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_param_group.py +++ b/torch/distributed/_composable/fsdp/_fsdp_param_group.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib from typing import Any, cast, Dict, List, NamedTuple, Optional, Set, Tuple diff --git a/torch/distributed/_composable/fsdp/_fsdp_state.py b/torch/distributed/_composable/fsdp/_fsdp_state.py index 15a00e83f086..f080e7550338 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_state.py +++ b/torch/distributed/_composable/fsdp/_fsdp_state.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING diff --git a/torch/distributed/_composable/fsdp/fully_shard.py b/torch/distributed/_composable/fsdp/fully_shard.py index ca050790cdd6..018333a65886 100644 --- a/torch/distributed/_composable/fsdp/fully_shard.py +++ b/torch/distributed/_composable/fsdp/fully_shard.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools from typing import Any, cast, NoReturn, Optional, Union diff --git a/torch/distributed/_composable/replicate.py b/torch/distributed/_composable/replicate.py index 45e1b9d8ab7f..0cb4ea79bc7d 100644 --- a/torch/distributed/_composable/replicate.py +++ b/torch/distributed/_composable/replicate.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import weakref from typing import Any, cast, Dict, Iterable, List, NoReturn, Optional, Set, Tuple diff --git a/torch/distributed/_cuda_p2p/__init__.py b/torch/distributed/_cuda_p2p/__init__.py index 84fda06265d9..c77902c0d3a7 100644 --- a/torch/distributed/_cuda_p2p/__init__.py +++ b/torch/distributed/_cuda_p2p/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from collections import defaultdict from contextlib import contextmanager diff --git a/torch/distributed/_functional_collectives.py b/torch/distributed/_functional_collectives.py index d170410061b1..9ac89166b25f 100644 --- a/torch/distributed/_functional_collectives.py +++ b/torch/distributed/_functional_collectives.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import sys import warnings from typing import cast, List, Optional, Tuple, TYPE_CHECKING, Union diff --git a/torch/distributed/_functional_collectives_impl.py b/torch/distributed/_functional_collectives_impl.py index 7abd33e42afa..c39cb4a9d50d 100644 --- a/torch/distributed/_functional_collectives_impl.py +++ b/torch/distributed/_functional_collectives_impl.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List, Optional import torch diff --git a/torch/distributed/_shard/api.py b/torch/distributed/_shard/api.py index 9afa7d9e793a..441bb421b195 100644 --- a/torch/distributed/_shard/api.py +++ b/torch/distributed/_shard/api.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from contextlib import contextmanager from typing import Optional import torch diff --git a/torch/distributed/_shard/common_op_utils.py b/torch/distributed/_shard/common_op_utils.py index c426503161c7..7506f17b046d 100644 --- a/torch/distributed/_shard/common_op_utils.py +++ b/torch/distributed/_shard/common_op_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.utils import _pytree as pytree from typing import Optional diff --git a/torch/distributed/_shard/metadata.py b/torch/distributed/_shard/metadata.py index b7bae9e6664a..850b065e4dab 100644 --- a/torch/distributed/_shard/metadata.py +++ b/torch/distributed/_shard/metadata.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from dataclasses import dataclass from typing import List, Union, Optional from functools import reduce diff --git a/torch/distributed/_shard/op_registry_utils.py b/torch/distributed/_shard/op_registry_utils.py index 4febe841186a..033dc7c58e0a 100644 --- a/torch/distributed/_shard/op_registry_utils.py +++ b/torch/distributed/_shard/op_registry_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools from inspect import signature from .common_op_utils import _basic_validation diff --git a/torch/distributed/_shard/sharded_optim/api.py b/torch/distributed/_shard/sharded_optim/api.py index 54d8a94ad3fe..e1acf7dc17a8 100644 --- a/torch/distributed/_shard/sharded_optim/api.py +++ b/torch/distributed/_shard/sharded_optim/api.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List, Union, Mapping, Dict, Any import torch.optim as optim diff --git a/torch/distributed/_shard/sharded_tensor/__init__.py b/torch/distributed/_shard/sharded_tensor/__init__.py index 602f75163782..1b846a8dabb4 100644 --- a/torch/distributed/_shard/sharded_tensor/__init__.py +++ b/torch/distributed/_shard/sharded_tensor/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools from typing import List, TYPE_CHECKING diff --git a/torch/distributed/_shard/sharded_tensor/_ops/_common.py b/torch/distributed/_shard/sharded_tensor/_ops/_common.py index e672c54927db..4d35d24ecafc 100644 --- a/torch/distributed/_shard/sharded_tensor/_ops/_common.py +++ b/torch/distributed/_shard/sharded_tensor/_ops/_common.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools from torch.distributed._shard.sharded_tensor import ( _sharded_op_impl, diff --git a/torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py b/torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py index 0a7999a4c263..034f91498161 100644 --- a/torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py +++ b/torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.distributed as dist import torch.distributed.distributed_c10d as distributed_c10d diff --git a/torch/distributed/_shard/sharded_tensor/_ops/init.py b/torch/distributed/_shard/sharded_tensor/_ops/init.py index dfb661653e71..736190d491e1 100644 --- a/torch/distributed/_shard/sharded_tensor/_ops/init.py +++ b/torch/distributed/_shard/sharded_tensor/_ops/init.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.distributed._shard.sharded_tensor as sharded_tensor from torch.distributed._shard.sharded_tensor import ( diff --git a/torch/distributed/_shard/sharded_tensor/_ops/misc_ops.py b/torch/distributed/_shard/sharded_tensor/_ops/misc_ops.py index 0e0911bb1d18..82737f82de53 100644 --- a/torch/distributed/_shard/sharded_tensor/_ops/misc_ops.py +++ b/torch/distributed/_shard/sharded_tensor/_ops/misc_ops.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.distributed._shard.sharded_tensor import ( _sharded_op_impl, diff --git a/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py b/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py index f96eded95f31..7de78bf61f3f 100644 --- a/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py +++ b/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import torch from torch.distributed._shard.sharded_tensor import ( diff --git a/torch/distributed/_shard/sharded_tensor/api.py b/torch/distributed/_shard/sharded_tensor/api.py index 79944953fd40..bf5db21b9a16 100644 --- a/torch/distributed/_shard/sharded_tensor/api.py +++ b/torch/distributed/_shard/sharded_tensor/api.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations # type: ignore[attr-defined] from dataclasses import dataclass from typing import ( diff --git a/torch/distributed/_shard/sharded_tensor/metadata.py b/torch/distributed/_shard/sharded_tensor/metadata.py index cb112da5686b..8b3257240e38 100644 --- a/torch/distributed/_shard/sharded_tensor/metadata.py +++ b/torch/distributed/_shard/sharded_tensor/metadata.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from dataclasses import dataclass, field from enum import Enum from typing import List diff --git a/torch/distributed/_shard/sharded_tensor/reshard.py b/torch/distributed/_shard/sharded_tensor/reshard.py index de7a44bb8200..549dde38cdf8 100644 --- a/torch/distributed/_shard/sharded_tensor/reshard.py +++ b/torch/distributed/_shard/sharded_tensor/reshard.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy from typing import List, Tuple diff --git a/torch/distributed/_shard/sharded_tensor/shard.py b/torch/distributed/_shard/sharded_tensor/shard.py index d448cc6321b1..ac1e881370e8 100644 --- a/torch/distributed/_shard/sharded_tensor/shard.py +++ b/torch/distributed/_shard/sharded_tensor/shard.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from dataclasses import dataclass from typing import List diff --git a/torch/distributed/_shard/sharded_tensor/utils.py b/torch/distributed/_shard/sharded_tensor/utils.py index d904137ba6f0..782def0e4d4c 100644 --- a/torch/distributed/_shard/sharded_tensor/utils.py +++ b/torch/distributed/_shard/sharded_tensor/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections.abc import copy from typing import Optional, List, Sequence, TYPE_CHECKING diff --git a/torch/distributed/_shard/sharding_spec/_internals.py b/torch/distributed/_shard/sharding_spec/_internals.py index e8275063e038..07d3c2e19bc0 100644 --- a/torch/distributed/_shard/sharding_spec/_internals.py +++ b/torch/distributed/_shard/sharding_spec/_internals.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List, Optional, Tuple from torch.distributed._shard.metadata import ShardMetadata diff --git a/torch/distributed/_shard/sharding_spec/api.py b/torch/distributed/_shard/sharding_spec/api.py index 1824b66a8194..7493eccdf015 100644 --- a/torch/distributed/_shard/sharding_spec/api.py +++ b/torch/distributed/_shard/sharding_spec/api.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from abc import ABC, abstractmethod from dataclasses import dataclass import functools diff --git a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py index 2775dbd9dd8d..bd2c960f7f60 100644 --- a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py +++ b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from dataclasses import dataclass import torch import torch.distributed._shard.sharded_tensor.metadata as sharded_tensor_meta diff --git a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/_common.py b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/_common.py index c869b71d69e7..83d3371c7f90 100644 --- a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/_common.py +++ b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/_common.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.distributed as dist diff --git a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding.py b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding.py index c9cfcba1fe1a..117aed79520d 100644 --- a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding.py +++ b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.distributed as dist diff --git a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py index 2f954398f988..01a148b5a9a9 100644 --- a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py +++ b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import cast, List diff --git a/torch/distributed/_spmd/api.py b/torch/distributed/_spmd/api.py index 2848060bf28d..ce9984efac6e 100644 --- a/torch/distributed/_spmd/api.py +++ b/torch/distributed/_spmd/api.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from abc import ABC, abstractmethod from contextlib import contextmanager, nullcontext from copy import copy diff --git a/torch/distributed/_spmd/batch_dim_utils.py b/torch/distributed/_spmd/batch_dim_utils.py index 6d36b2e38118..d3c39295c0e6 100644 --- a/torch/distributed/_spmd/batch_dim_utils.py +++ b/torch/distributed/_spmd/batch_dim_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Callable, Dict, List, Set import torch diff --git a/torch/distributed/_spmd/comm_tensor.py b/torch/distributed/_spmd/comm_tensor.py index 292f5b250861..a54ed2f46d21 100644 --- a/torch/distributed/_spmd/comm_tensor.py +++ b/torch/distributed/_spmd/comm_tensor.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from dataclasses import dataclass from functools import partial from typing import Any, List, Optional, Tuple diff --git a/torch/distributed/_spmd/config.py b/torch/distributed/_spmd/config.py index 54f0cc4dc5c8..73ee19e803dc 100644 --- a/torch/distributed/_spmd/config.py +++ b/torch/distributed/_spmd/config.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging import sys from types import ModuleType diff --git a/torch/distributed/_spmd/data_parallel.py b/torch/distributed/_spmd/data_parallel.py index 5e376d9f0c4a..8b18c6c86763 100644 --- a/torch/distributed/_spmd/data_parallel.py +++ b/torch/distributed/_spmd/data_parallel.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import operator from contextlib import contextmanager from enum import Enum diff --git a/torch/distributed/_spmd/distribute.py b/torch/distributed/_spmd/distribute.py index 0ed2bcabb907..5fb5ff766799 100644 --- a/torch/distributed/_spmd/distribute.py +++ b/torch/distributed/_spmd/distribute.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging import operator from dataclasses import dataclass diff --git a/torch/distributed/_spmd/experimental_ops.py b/torch/distributed/_spmd/experimental_ops.py index e108061e5d74..94a0da822449 100644 --- a/torch/distributed/_spmd/experimental_ops.py +++ b/torch/distributed/_spmd/experimental_ops.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates from typing import cast, List, Optional, Sequence, Tuple diff --git a/torch/distributed/_spmd/graph_optimization.py b/torch/distributed/_spmd/graph_optimization.py index 10423fb55cd4..4a5cad7917d8 100644 --- a/torch/distributed/_spmd/graph_optimization.py +++ b/torch/distributed/_spmd/graph_optimization.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Owner(s): ["oncall: distributed"] import collections import itertools diff --git a/torch/distributed/_spmd/iter_graph_module.py b/torch/distributed/_spmd/iter_graph_module.py index f1e8e960f361..cd5f934c5c7f 100644 --- a/torch/distributed/_spmd/iter_graph_module.py +++ b/torch/distributed/_spmd/iter_graph_module.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import inspect import logging diff --git a/torch/distributed/_state_dict_utils.py b/torch/distributed/_state_dict_utils.py index 2ec7be89c9e0..4d7a7b086509 100644 --- a/torch/distributed/_state_dict_utils.py +++ b/torch/distributed/_state_dict_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import io import math diff --git a/torch/distributed/_tensor/__init__.py b/torch/distributed/_tensor/__init__.py index de01187f2512..85de716d8439 100644 --- a/torch/distributed/_tensor/__init__.py +++ b/torch/distributed/_tensor/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates from typing import Optional, Sequence diff --git a/torch/distributed/_tensor/_collective_utils.py b/torch/distributed/_tensor/_collective_utils.py index 93052d6ddd62..4c1d18403666 100644 --- a/torch/distributed/_tensor/_collective_utils.py +++ b/torch/distributed/_tensor/_collective_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging import math from dataclasses import dataclass diff --git a/torch/distributed/_tensor/_op_schema.py b/torch/distributed/_tensor/_op_schema.py index 43aa065a59e0..071c2ac4748f 100644 --- a/torch/distributed/_tensor/_op_schema.py +++ b/torch/distributed/_tensor/_op_schema.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from dataclasses import dataclass from functools import cached_property from typing import Any, Dict, List, Optional, Sequence, Tuple, Union diff --git a/torch/distributed/_tensor/_redistribute.py b/torch/distributed/_tensor/_redistribute.py index c8e54a98b927..2653423a257f 100644 --- a/torch/distributed/_tensor/_redistribute.py +++ b/torch/distributed/_tensor/_redistribute.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates from functools import lru_cache from typing import cast, Dict, List, NamedTuple, Tuple diff --git a/torch/distributed/_tensor/_sharding_prop.py b/torch/distributed/_tensor/_sharding_prop.py index 3510f80cbeba..449cf6c23775 100644 --- a/torch/distributed/_tensor/_sharding_prop.py +++ b/torch/distributed/_tensor/_sharding_prop.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from functools import lru_cache from itertools import chain from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, Union diff --git a/torch/distributed/_tensor/_tp_conv.py b/torch/distributed/_tensor/_tp_conv.py index ebcc981d2c93..d480e9d7f79e 100644 --- a/torch/distributed/_tensor/_tp_conv.py +++ b/torch/distributed/_tensor/_tp_conv.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates # implement matrix related ops for distributed tensor from typing import cast, Dict, List, Tuple diff --git a/torch/distributed/_tensor/api.py b/torch/distributed/_tensor/api.py index 5bcd6b033c8f..7da5f4e3dfcb 100644 --- a/torch/distributed/_tensor/api.py +++ b/torch/distributed/_tensor/api.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates import inspect import warnings diff --git a/torch/distributed/_tensor/debug/__init__.py b/torch/distributed/_tensor/debug/__init__.py index 2cd388cf93e4..b7bde685fd1e 100644 --- a/torch/distributed/_tensor/debug/__init__.py +++ b/torch/distributed/_tensor/debug/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch.distributed._tensor.api import DTensor from torch.distributed._tensor.debug.comm_mode import CommDebugMode diff --git a/torch/distributed/_tensor/debug/_op_coverage.py b/torch/distributed/_tensor/debug/_op_coverage.py index a722136e2baf..4f5424633235 100644 --- a/torch/distributed/_tensor/debug/_op_coverage.py +++ b/torch/distributed/_tensor/debug/_op_coverage.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from operator import itemgetter from typing import List diff --git a/torch/distributed/_tensor/debug/comm_mode.py b/torch/distributed/_tensor/debug/comm_mode.py index 1ff97e4e78e1..150ef9250c2d 100644 --- a/torch/distributed/_tensor/debug/comm_mode.py +++ b/torch/distributed/_tensor/debug/comm_mode.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from collections import defaultdict from typing import Any, Dict diff --git a/torch/distributed/_tensor/debug/visualize_sharding.py b/torch/distributed/_tensor/debug/visualize_sharding.py index 91bc9c2a382c..76cd8f3e9208 100644 --- a/torch/distributed/_tensor/debug/visualize_sharding.py +++ b/torch/distributed/_tensor/debug/visualize_sharding.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List, Sequence, Tuple import numpy as np diff --git a/torch/distributed/_tensor/examples/checkpoint_example.py b/torch/distributed/_tensor/examples/checkpoint_example.py index 9bccc07d9625..1cb292f12c41 100644 --- a/torch/distributed/_tensor/examples/checkpoint_example.py +++ b/torch/distributed/_tensor/examples/checkpoint_example.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ The following example contains a simple MLP model that uses different DTensor layouts, and use the checkpointing API to diff --git a/torch/distributed/_tensor/examples/convnext_example.py b/torch/distributed/_tensor/examples/convnext_example.py index df6b7d3d71fd..61f8d0234938 100644 --- a/torch/distributed/_tensor/examples/convnext_example.py +++ b/torch/distributed/_tensor/examples/convnext_example.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ The following example demonstrates how to train a ConvNeXt model with intermediate activations sharded across mutliple GPUs via DTensor diff --git a/torch/distributed/_tensor/examples/torchrec_sharding_example.py b/torch/distributed/_tensor/examples/torchrec_sharding_example.py index 8edbad13301f..3e6c63dd18eb 100644 --- a/torch/distributed/_tensor/examples/torchrec_sharding_example.py +++ b/torch/distributed/_tensor/examples/torchrec_sharding_example.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ The following example demonstrates how to represent torchrec's embedding sharding with the DTensor API. diff --git a/torch/distributed/_tensor/experimental/__init__.py b/torch/distributed/_tensor/experimental/__init__.py index 587eef3011ba..2dd21605ffcc 100644 --- a/torch/distributed/_tensor/experimental/__init__.py +++ b/torch/distributed/_tensor/experimental/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates from contextlib import contextmanager diff --git a/torch/distributed/_tensor/experimental/local_map.py b/torch/distributed/_tensor/experimental/local_map.py index 2bf12871cc36..0fc6ce96e6e0 100644 --- a/torch/distributed/_tensor/experimental/local_map.py +++ b/torch/distributed/_tensor/experimental/local_map.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates from typing import Callable, Optional, Sequence, Tuple, Union diff --git a/torch/distributed/_tensor/experimental/tp_transform.py b/torch/distributed/_tensor/experimental/tp_transform.py index b36f3d87e3d8..4a18d36bbc64 100644 --- a/torch/distributed/_tensor/experimental/tp_transform.py +++ b/torch/distributed/_tensor/experimental/tp_transform.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import operator from typing import Any, cast, Dict, List, Optional, Sequence, Tuple diff --git a/torch/distributed/_tensor/ops/embedding_ops.py b/torch/distributed/_tensor/ops/embedding_ops.py index 7cc8dd262638..6f8cc8c67851 100644 --- a/torch/distributed/_tensor/ops/embedding_ops.py +++ b/torch/distributed/_tensor/ops/embedding_ops.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates # implement matrix related ops for distributed tensor from dataclasses import dataclass, field diff --git a/torch/distributed/_tensor/ops/math_ops.py b/torch/distributed/_tensor/ops/math_ops.py index 029d1f803cb1..377c50dffa13 100644 --- a/torch/distributed/_tensor/ops/math_ops.py +++ b/torch/distributed/_tensor/ops/math_ops.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates import math from dataclasses import dataclass diff --git a/torch/distributed/_tensor/ops/tensor_ops.py b/torch/distributed/_tensor/ops/tensor_ops.py index 40f75c151579..d2feb19ba2f9 100644 --- a/torch/distributed/_tensor/ops/tensor_ops.py +++ b/torch/distributed/_tensor/ops/tensor_ops.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates from typing import cast, List, Optional, Sequence, Tuple diff --git a/torch/distributed/_tensor/ops/utils.py b/torch/distributed/_tensor/ops/utils.py index 245298607c5e..ecc3c5d06bee 100644 --- a/torch/distributed/_tensor/ops/utils.py +++ b/torch/distributed/_tensor/ops/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates import functools import itertools diff --git a/torch/distributed/_tensor/ops/view_ops.py b/torch/distributed/_tensor/ops/view_ops.py index 303d802bc7bc..7161988adf25 100644 --- a/torch/distributed/_tensor/ops/view_ops.py +++ b/torch/distributed/_tensor/ops/view_ops.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates from dataclasses import dataclass from typing import ( diff --git a/torch/distributed/_tensor/placement_types.py b/torch/distributed/_tensor/placement_types.py index 5cb5aaf55fa9..31e280c2f5b8 100644 --- a/torch/distributed/_tensor/placement_types.py +++ b/torch/distributed/_tensor/placement_types.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates from dataclasses import dataclass diff --git a/torch/distributed/_tensor/random.py b/torch/distributed/_tensor/random.py index f2eff6bb5ec3..ed331736c5ce 100644 --- a/torch/distributed/_tensor/random.py +++ b/torch/distributed/_tensor/random.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates import contextlib import warnings diff --git a/torch/distributed/_tools/memory_tracker.py b/torch/distributed/_tools/memory_tracker.py index d8b6765230a1..10f70c9ce18e 100644 --- a/torch/distributed/_tools/memory_tracker.py +++ b/torch/distributed/_tools/memory_tracker.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from collections import defaultdict from itertools import chain diff --git a/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py b/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py index 24a079849df7..86ab1de003db 100644 --- a/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py +++ b/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import warnings from enum import auto, Enum from functools import partial diff --git a/torch/distributed/algorithms/_comm_hooks/default_hooks.py b/torch/distributed/algorithms/_comm_hooks/default_hooks.py index 53c8eb7e163f..d370fabafc37 100644 --- a/torch/distributed/algorithms/_comm_hooks/default_hooks.py +++ b/torch/distributed/algorithms/_comm_hooks/default_hooks.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import torch import torch.distributed as dist diff --git a/torch/distributed/algorithms/_optimizer_overlap/optimizer_overlap.py b/torch/distributed/algorithms/_optimizer_overlap/optimizer_overlap.py index 8044557e71dc..1afbb8d7967f 100644 --- a/torch/distributed/algorithms/_optimizer_overlap/optimizer_overlap.py +++ b/torch/distributed/algorithms/_optimizer_overlap/optimizer_overlap.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from abc import ABC, abstractmethod import inspect from typing import Dict, Type diff --git a/torch/distributed/algorithms/_quantization/quantization.py b/torch/distributed/algorithms/_quantization/quantization.py index 911cc8255ee5..c421076bde3e 100644 --- a/torch/distributed/algorithms/_quantization/quantization.py +++ b/torch/distributed/algorithms/_quantization/quantization.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import torch import torch.distributed as dist diff --git a/torch/distributed/algorithms/ddp_comm_hooks/__init__.py b/torch/distributed/algorithms/ddp_comm_hooks/__init__.py index 570aa34cf02e..2366a9d28c13 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/__init__.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from enum import Enum from functools import partial diff --git a/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py b/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py index 52f9b419ab14..8ab58cb58442 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import weakref from typing import Any, Callable, List, Optional diff --git a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py index 791061e34f90..6ad4280e95ae 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any, Callable, cast, Tuple import torch diff --git a/torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py index dc7e5ee2fdc5..76d4cd6de2bd 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any, Callable, List, no_type_check import torch diff --git a/torch/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py b/torch/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py index 218ee08dbd46..3528f3987479 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging import torch diff --git a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py index 9d2d5649f745..fbc3b9e8739e 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from collections import defaultdict import logging import math diff --git a/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py index 9d5cd573eed6..cbc1290e76e4 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.distributed as dist from torch import nn diff --git a/torch/distributed/algorithms/join.py b/torch/distributed/algorithms/join.py index 7c1aa3cac5ac..2936747a1c6e 100644 --- a/torch/distributed/algorithms/join.py +++ b/torch/distributed/algorithms/join.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import warnings from abc import ABC, abstractmethod from types import TracebackType diff --git a/torch/distributed/algorithms/model_averaging/averagers.py b/torch/distributed/algorithms/model_averaging/averagers.py index e1f8c0800c50..178efd1dbad9 100644 --- a/torch/distributed/algorithms/model_averaging/averagers.py +++ b/torch/distributed/algorithms/model_averaging/averagers.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import warnings from abc import ABC, abstractmethod from typing import Union, Iterable, Dict diff --git a/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py b/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py index 637ae144b379..02802466ab62 100644 --- a/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py +++ b/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright 2022 Cruise LLC import logging import warnings diff --git a/torch/distributed/algorithms/model_averaging/utils.py b/torch/distributed/algorithms/model_averaging/utils.py index eaa1cd2e968d..de1977959d21 100644 --- a/torch/distributed/algorithms/model_averaging/utils.py +++ b/torch/distributed/algorithms/model_averaging/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # flake8: noqa C101 import itertools from typing import Union, Iterable, Dict, Iterator diff --git a/torch/distributed/argparse_util.py b/torch/distributed/argparse_util.py index a214dadd312a..c475eebf2127 100644 --- a/torch/distributed/argparse_util.py +++ b/torch/distributed/argparse_util.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. diff --git a/torch/distributed/autograd/__init__.py b/torch/distributed/autograd/__init__.py index e94ab1bb9d63..6546c38a37b9 100644 --- a/torch/distributed/autograd/__init__.py +++ b/torch/distributed/autograd/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import sys import torch diff --git a/torch/distributed/benchmarks/benchmark_ddp_rpc.py b/torch/distributed/benchmarks/benchmark_ddp_rpc.py index 7294fce61ff3..60f71e12213b 100644 --- a/torch/distributed/benchmarks/benchmark_ddp_rpc.py +++ b/torch/distributed/benchmarks/benchmark_ddp_rpc.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import argparse import io import os diff --git a/torch/distributed/c10d_logger.py b/torch/distributed/c10d_logger.py index 5d2aa9b62991..c1cc67b40681 100644 --- a/torch/distributed/c10d_logger.py +++ b/torch/distributed/c10d_logger.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. diff --git a/torch/distributed/checkpoint/api.py b/torch/distributed/checkpoint/api.py index 828685103261..660196bc28de 100644 --- a/torch/distributed/checkpoint/api.py +++ b/torch/distributed/checkpoint/api.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import traceback as tb from typing import Any, Dict, Tuple diff --git a/torch/distributed/checkpoint/default_planner.py b/torch/distributed/checkpoint/default_planner.py index 57ca0f2a764f..83b76718a6b7 100644 --- a/torch/distributed/checkpoint/default_planner.py +++ b/torch/distributed/checkpoint/default_planner.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates import dataclasses diff --git a/torch/distributed/checkpoint/examples/async_checkpointing_example.py b/torch/distributed/checkpoint/examples/async_checkpointing_example.py index d4e2b5268de7..5eaba9a67227 100644 --- a/torch/distributed/checkpoint/examples/async_checkpointing_example.py +++ b/torch/distributed/checkpoint/examples/async_checkpointing_example.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Owner(s): ["oncall: distributed"] import os diff --git a/torch/distributed/checkpoint/examples/fsdp_checkpoint_example.py b/torch/distributed/checkpoint/examples/fsdp_checkpoint_example.py index 9e2438c47bb8..38c637d3a4fd 100644 --- a/torch/distributed/checkpoint/examples/fsdp_checkpoint_example.py +++ b/torch/distributed/checkpoint/examples/fsdp_checkpoint_example.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates """ diff --git a/torch/distributed/checkpoint/examples/stateful_example.py b/torch/distributed/checkpoint/examples/stateful_example.py index 6c23dc3e298f..6c76ec436364 100644 --- a/torch/distributed/checkpoint/examples/stateful_example.py +++ b/torch/distributed/checkpoint/examples/stateful_example.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Owner(s): ["oncall: distributed"] diff --git a/torch/distributed/checkpoint/filesystem.py b/torch/distributed/checkpoint/filesystem.py index aa25d1fb5369..4d512891f122 100644 --- a/torch/distributed/checkpoint/filesystem.py +++ b/torch/distributed/checkpoint/filesystem.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections import dataclasses import io diff --git a/torch/distributed/checkpoint/format_utils.py b/torch/distributed/checkpoint/format_utils.py index 41ebaf8be61b..e82284704565 100644 --- a/torch/distributed/checkpoint/format_utils.py +++ b/torch/distributed/checkpoint/format_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import argparse import os from enum import Enum diff --git a/torch/distributed/checkpoint/logger.py b/torch/distributed/checkpoint/logger.py index 08e2bee2a78b..270240490c99 100644 --- a/torch/distributed/checkpoint/logger.py +++ b/torch/distributed/checkpoint/logger.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import time from typing import Any, Callable, Dict, List, TypeVar diff --git a/torch/distributed/checkpoint/metadata.py b/torch/distributed/checkpoint/metadata.py index bbcfcbc01e17..b3bc7a580dad 100644 --- a/torch/distributed/checkpoint/metadata.py +++ b/torch/distributed/checkpoint/metadata.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import os from dataclasses import dataclass, field from enum import Enum diff --git a/torch/distributed/checkpoint/planner_helpers.py b/torch/distributed/checkpoint/planner_helpers.py index c4e5be89a45d..4bbe26876c88 100644 --- a/torch/distributed/checkpoint/planner_helpers.py +++ b/torch/distributed/checkpoint/planner_helpers.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any, cast, List import torch diff --git a/torch/distributed/checkpoint/resharding.py b/torch/distributed/checkpoint/resharding.py index 1ebb0ba57d73..a1bf112f1795 100644 --- a/torch/distributed/checkpoint/resharding.py +++ b/torch/distributed/checkpoint/resharding.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List, Tuple from torch.distributed.checkpoint.metadata import ChunkStorageMetadata diff --git a/torch/distributed/checkpoint/state_dict.py b/torch/distributed/checkpoint/state_dict.py index 95fc57faf8e8..cc55b1a5b42c 100644 --- a/torch/distributed/checkpoint/state_dict.py +++ b/torch/distributed/checkpoint/state_dict.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import functools import gc diff --git a/torch/distributed/checkpoint/state_dict_loader.py b/torch/distributed/checkpoint/state_dict_loader.py index b8ad6f61da14..f443f73f02d6 100644 --- a/torch/distributed/checkpoint/state_dict_loader.py +++ b/torch/distributed/checkpoint/state_dict_loader.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import os import warnings from typing import Any, cast, Dict, Optional, Set, Union diff --git a/torch/distributed/checkpoint/state_dict_saver.py b/torch/distributed/checkpoint/state_dict_saver.py index 451603288d12..6d04044391ab 100644 --- a/torch/distributed/checkpoint/state_dict_saver.py +++ b/torch/distributed/checkpoint/state_dict_saver.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import inspect import os import warnings diff --git a/torch/distributed/checkpoint/utils.py b/torch/distributed/checkpoint/utils.py index a93c0bfc400a..0efba34a551b 100644 --- a/torch/distributed/checkpoint/utils.py +++ b/torch/distributed/checkpoint/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import cProfile import inspect import io diff --git a/torch/distributed/device_mesh.py b/torch/distributed/device_mesh.py index f25a5b91da4b..e46356a36894 100644 --- a/torch/distributed/device_mesh.py +++ b/torch/distributed/device_mesh.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates import logging import math diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 17152f0a87ed..bd81fd61b02f 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Distributed Collective Communication (c10d).""" import itertools diff --git a/torch/distributed/elastic/agent/server/local_elastic_agent.py b/torch/distributed/elastic/agent/server/local_elastic_agent.py index 95369ecb61e1..232f28234e65 100644 --- a/torch/distributed/elastic/agent/server/local_elastic_agent.py +++ b/torch/distributed/elastic/agent/server/local_elastic_agent.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. diff --git a/torch/distributed/elastic/events/api.py b/torch/distributed/elastic/events/api.py index 62f5d7500922..082499b3af63 100644 --- a/torch/distributed/elastic/events/api.py +++ b/torch/distributed/elastic/events/api.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. diff --git a/torch/distributed/elastic/metrics/__init__.py b/torch/distributed/elastic/metrics/__init__.py index 767abcc1d60b..d8bea0b3c079 100644 --- a/torch/distributed/elastic/metrics/__init__.py +++ b/torch/distributed/elastic/metrics/__init__.py @@ -1,4 +1,5 @@ #!/usr/bin/env/python3 +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. diff --git a/torch/distributed/elastic/metrics/api.py b/torch/distributed/elastic/metrics/api.py index 11a3930acf70..7b6d8295ef05 100644 --- a/torch/distributed/elastic/metrics/api.py +++ b/torch/distributed/elastic/metrics/api.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. diff --git a/torch/distributed/elastic/multiprocessing/api.py b/torch/distributed/elastic/multiprocessing/api.py index eb0b110f25ee..5d294a7d0802 100644 --- a/torch/distributed/elastic/multiprocessing/api.py +++ b/torch/distributed/elastic/multiprocessing/api.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. diff --git a/torch/distributed/elastic/multiprocessing/errors/__init__.py b/torch/distributed/elastic/multiprocessing/errors/__init__.py index 95d6a6192245..d63c283b4c35 100644 --- a/torch/distributed/elastic/multiprocessing/errors/__init__.py +++ b/torch/distributed/elastic/multiprocessing/errors/__init__.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. diff --git a/torch/distributed/elastic/multiprocessing/errors/error_handler.py b/torch/distributed/elastic/multiprocessing/errors/error_handler.py index 903731a6a2ab..34d6229dda3b 100644 --- a/torch/distributed/elastic/multiprocessing/errors/error_handler.py +++ b/torch/distributed/elastic/multiprocessing/errors/error_handler.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. diff --git a/torch/distributed/elastic/multiprocessing/errors/handlers.py b/torch/distributed/elastic/multiprocessing/errors/handlers.py index 3071aef17117..09b2aca55f16 100644 --- a/torch/distributed/elastic/multiprocessing/errors/handlers.py +++ b/torch/distributed/elastic/multiprocessing/errors/handlers.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. diff --git a/torch/distributed/elastic/multiprocessing/redirects.py b/torch/distributed/elastic/multiprocessing/redirects.py index e63255819383..8ad3e2edf1c1 100644 --- a/torch/distributed/elastic/multiprocessing/redirects.py +++ b/torch/distributed/elastic/multiprocessing/redirects.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # !/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. diff --git a/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py b/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py index 8d4477452a20..e122f89a94f7 100644 --- a/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py +++ b/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. diff --git a/torch/distributed/elastic/multiprocessing/tail_log.py b/torch/distributed/elastic/multiprocessing/tail_log.py index 17b0d216e954..804e2e5a6323 100644 --- a/torch/distributed/elastic/multiprocessing/tail_log.py +++ b/torch/distributed/elastic/multiprocessing/tail_log.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. diff --git a/torch/distributed/elastic/rendezvous/api.py b/torch/distributed/elastic/rendezvous/api.py index 09b19be479dc..7ddcd7c70b9a 100644 --- a/torch/distributed/elastic/rendezvous/api.py +++ b/torch/distributed/elastic/rendezvous/api.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. # diff --git a/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py b/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py index 62413df02aae..7fb894bd2247 100644 --- a/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py +++ b/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. # diff --git a/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py b/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py index a80fa9e97894..0bc92d845d19 100644 --- a/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py +++ b/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. # diff --git a/torch/distributed/elastic/rendezvous/etcd_rendezvous.py b/torch/distributed/elastic/rendezvous/etcd_rendezvous.py index b642d6201200..1a371b74275a 100644 --- a/torch/distributed/elastic/rendezvous/etcd_rendezvous.py +++ b/torch/distributed/elastic/rendezvous/etcd_rendezvous.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. diff --git a/torch/distributed/elastic/rendezvous/etcd_rendezvous_backend.py b/torch/distributed/elastic/rendezvous/etcd_rendezvous_backend.py index cacb888590f8..c9d60abdc236 100644 --- a/torch/distributed/elastic/rendezvous/etcd_rendezvous_backend.py +++ b/torch/distributed/elastic/rendezvous/etcd_rendezvous_backend.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. # diff --git a/torch/distributed/elastic/rendezvous/etcd_server.py b/torch/distributed/elastic/rendezvous/etcd_server.py index a28f7cc31839..891858534c56 100644 --- a/torch/distributed/elastic/rendezvous/etcd_server.py +++ b/torch/distributed/elastic/rendezvous/etcd_server.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. diff --git a/torch/distributed/elastic/rendezvous/etcd_store.py b/torch/distributed/elastic/rendezvous/etcd_store.py index 7690439237ad..605596475686 100644 --- a/torch/distributed/elastic/rendezvous/etcd_store.py +++ b/torch/distributed/elastic/rendezvous/etcd_store.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. # diff --git a/torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py b/torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py index 2e53034a9d6e..ace82d0a2226 100644 --- a/torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py +++ b/torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. diff --git a/torch/distributed/elastic/rendezvous/utils.py b/torch/distributed/elastic/rendezvous/utils.py index 326bc604a914..8419051d29f8 100644 --- a/torch/distributed/elastic/rendezvous/utils.py +++ b/torch/distributed/elastic/rendezvous/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. # diff --git a/torch/distributed/elastic/timer/api.py b/torch/distributed/elastic/timer/api.py index 0121c98d56d1..77fcaaceed4f 100644 --- a/torch/distributed/elastic/timer/api.py +++ b/torch/distributed/elastic/timer/api.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. # diff --git a/torch/distributed/elastic/timer/debug_info_logging.py b/torch/distributed/elastic/timer/debug_info_logging.py index 2ac2dc5318be..55a1a9e9bcdf 100644 --- a/torch/distributed/elastic/timer/debug_info_logging.py +++ b/torch/distributed/elastic/timer/debug_info_logging.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. diff --git a/torch/distributed/elastic/timer/file_based_local_timer.py b/torch/distributed/elastic/timer/file_based_local_timer.py index f2ded8ba84dd..fce46f053a7e 100644 --- a/torch/distributed/elastic/timer/file_based_local_timer.py +++ b/torch/distributed/elastic/timer/file_based_local_timer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and its affiliates. # All rights reserved. # diff --git a/torch/distributed/elastic/timer/local_timer.py b/torch/distributed/elastic/timer/local_timer.py index 7c87413aef19..b6a54896fc5e 100644 --- a/torch/distributed/elastic/timer/local_timer.py +++ b/torch/distributed/elastic/timer/local_timer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. # From 7c12cc7ce4b0070aac22484e13106b98e8602170 Mon Sep 17 00:00:00 2001 From: Aaron Orenstein Date: Sat, 8 Jun 2024 11:41:12 -0700 Subject: [PATCH 541/706] Flip default value for mypy disallow_untyped_defs [6/11] (#127843) See #127836 for details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127843 Approved by: https://github.com/oulgen ghstack dependencies: #127842 --- torch/distributed/elastic/utils/data/cycling_iterator.py | 1 + .../elastic/utils/data/elastic_distributed_sampler.py | 1 + torch/distributed/elastic/utils/distributed.py | 1 + torch/distributed/elastic/utils/logging.py | 1 + torch/distributed/elastic/utils/store.py | 1 + torch/distributed/examples/memory_tracker_example.py | 1 + torch/distributed/fsdp/_common_utils.py | 1 + torch/distributed/fsdp/_debug_utils.py | 1 + torch/distributed/fsdp/_dynamo_utils.py | 1 + torch/distributed/fsdp/_exec_order_utils.py | 1 + torch/distributed/fsdp/_flat_param.py | 1 + torch/distributed/fsdp/_init_utils.py | 1 + torch/distributed/fsdp/_optim_utils.py | 1 + torch/distributed/fsdp/_runtime_utils.py | 1 + torch/distributed/fsdp/_shard_utils.py | 1 + torch/distributed/fsdp/_state_dict_utils.py | 1 + torch/distributed/fsdp/_trace_utils.py | 1 + torch/distributed/fsdp/_unshard_param_utils.py | 1 + torch/distributed/fsdp/_wrap_utils.py | 1 + torch/distributed/fsdp/sharded_grad_scaler.py | 1 + torch/distributed/fsdp/wrap.py | 1 + torch/distributed/launch.py | 1 + torch/distributed/launcher/api.py | 1 + torch/distributed/nn/api/remote_module.py | 1 + torch/distributed/nn/functional.py | 1 + torch/distributed/nn/jit/instantiator.py | 1 + torch/distributed/nn/jit/templates/remote_module_template.py | 1 + torch/distributed/optim/functional_adadelta.py | 1 + torch/distributed/optim/functional_adagrad.py | 1 + torch/distributed/optim/functional_adam.py | 1 + torch/distributed/optim/functional_adamax.py | 1 + torch/distributed/optim/functional_adamw.py | 1 + torch/distributed/optim/functional_rmsprop.py | 1 + torch/distributed/optim/functional_rprop.py | 1 + torch/distributed/optim/functional_sgd.py | 1 + torch/distributed/optim/named_optimizer.py | 1 + torch/distributed/optim/optimizer.py | 1 + torch/distributed/optim/post_localSGD_optimizer.py | 1 + torch/distributed/optim/utils.py | 1 + torch/distributed/optim/zero_redundancy_optimizer.pyi | 1 + torch/distributed/pipelining/_IR.py | 1 + torch/distributed/pipelining/_backward.py | 1 + torch/distributed/pipelining/_debug.py | 1 + torch/distributed/pipelining/_unflatten.py | 1 + torch/distributed/pipelining/_utils.py | 1 + torch/distributed/pipelining/microbatch.py | 1 + torch/distributed/pipelining/schedules.py | 1 + torch/distributed/pipelining/stage.py | 1 + torch/distributed/remote_device.py | 1 + torch/distributed/rendezvous.py | 1 + torch/distributed/rpc/__init__.py | 1 + torch/distributed/rpc/_testing/__init__.py | 1 + torch/distributed/rpc/_testing/faulty_agent_backend_registry.py | 1 + torch/distributed/rpc/_utils.py | 1 + torch/distributed/rpc/api.py | 1 + torch/distributed/rpc/backend_registry.py | 1 + torch/distributed/rpc/functions.py | 1 + torch/distributed/rpc/internal.py | 1 + torch/distributed/rpc/options.py | 1 + torch/distributed/rpc/rref_proxy.py | 1 + torch/distributed/rpc/server_process_global_profiler.py | 1 + torch/distributed/run.py | 1 + torch/distributed/tensor/parallel/_utils.py | 1 + torch/distributed/tensor/parallel/ddp.py | 1 + torch/distributed/tensor/parallel/fsdp.py | 1 + torch/distributed/tensor/parallel/loss.py | 1 + torch/distributed/tensor/parallel/style.py | 1 + torch/distributed/utils.py | 1 + torch/distributions/bernoulli.py | 1 + torch/distributions/beta.py | 1 + torch/distributions/binomial.py | 1 + torch/distributions/categorical.py | 1 + torch/distributions/cauchy.py | 1 + torch/distributions/chi2.py | 1 + torch/distributions/constraint_registry.py | 1 + torch/distributions/constraints.py | 1 + torch/distributions/continuous_bernoulli.py | 1 + torch/distributions/dirichlet.py | 1 + torch/distributions/distribution.py | 1 + torch/distributions/exp_family.py | 1 + torch/distributions/exponential.py | 1 + torch/distributions/fishersnedecor.py | 1 + torch/distributions/gamma.py | 1 + torch/distributions/geometric.py | 1 + torch/distributions/gumbel.py | 1 + torch/distributions/half_cauchy.py | 1 + torch/distributions/half_normal.py | 1 + torch/distributions/independent.py | 1 + torch/distributions/inverse_gamma.py | 1 + torch/distributions/kl.py | 1 + torch/distributions/kumaraswamy.py | 1 + torch/distributions/laplace.py | 1 + torch/distributions/lkj_cholesky.py | 1 + torch/distributions/log_normal.py | 1 + torch/distributions/logistic_normal.py | 1 + torch/distributions/lowrank_multivariate_normal.py | 1 + torch/distributions/mixture_same_family.py | 1 + torch/distributions/multinomial.py | 1 + torch/distributions/multivariate_normal.py | 1 + torch/distributions/negative_binomial.py | 1 + torch/distributions/normal.py | 1 + torch/distributions/one_hot_categorical.py | 1 + torch/distributions/pareto.py | 1 + torch/distributions/poisson.py | 1 + torch/distributions/relaxed_bernoulli.py | 1 + torch/distributions/relaxed_categorical.py | 1 + torch/distributions/studentT.py | 1 + torch/distributions/transformed_distribution.py | 1 + torch/distributions/transforms.py | 1 + torch/distributions/uniform.py | 1 + torch/distributions/utils.py | 1 + torch/distributions/von_mises.py | 1 + torch/distributions/weibull.py | 1 + torch/distributions/wishart.py | 1 + torch/export/_remove_auto_functionalized_pass.py | 1 + torch/export/_remove_effect_tokens_pass.py | 1 + torch/export/_safeguard.py | 1 + torch/export/_trace.py | 1 + torch/export/_unlift.py | 1 + torch/export/dynamic_shapes.py | 1 + 120 files changed, 120 insertions(+) diff --git a/torch/distributed/elastic/utils/data/cycling_iterator.py b/torch/distributed/elastic/utils/data/cycling_iterator.py index 60a5861f7bef..b5dadb96bda4 100644 --- a/torch/distributed/elastic/utils/data/cycling_iterator.py +++ b/torch/distributed/elastic/utils/data/cycling_iterator.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. diff --git a/torch/distributed/elastic/utils/data/elastic_distributed_sampler.py b/torch/distributed/elastic/utils/data/elastic_distributed_sampler.py index a66803fa8c09..8e378c6a1be1 100644 --- a/torch/distributed/elastic/utils/data/elastic_distributed_sampler.py +++ b/torch/distributed/elastic/utils/data/elastic_distributed_sampler.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. diff --git a/torch/distributed/elastic/utils/distributed.py b/torch/distributed/elastic/utils/distributed.py index 1dc4680abc16..a1ad1acca796 100644 --- a/torch/distributed/elastic/utils/distributed.py +++ b/torch/distributed/elastic/utils/distributed.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. diff --git a/torch/distributed/elastic/utils/logging.py b/torch/distributed/elastic/utils/logging.py index e305d16400cb..d87504d255d6 100644 --- a/torch/distributed/elastic/utils/logging.py +++ b/torch/distributed/elastic/utils/logging.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. diff --git a/torch/distributed/elastic/utils/store.py b/torch/distributed/elastic/utils/store.py index 080e92eae91e..6d2e1f046502 100644 --- a/torch/distributed/elastic/utils/store.py +++ b/torch/distributed/elastic/utils/store.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. diff --git a/torch/distributed/examples/memory_tracker_example.py b/torch/distributed/examples/memory_tracker_example.py index d4946513098c..cb2ba03777d8 100644 --- a/torch/distributed/examples/memory_tracker_example.py +++ b/torch/distributed/examples/memory_tracker_example.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torchvision diff --git a/torch/distributed/fsdp/_common_utils.py b/torch/distributed/fsdp/_common_utils.py index c1d77bf410b5..aae2405d0bb5 100644 --- a/torch/distributed/fsdp/_common_utils.py +++ b/torch/distributed/fsdp/_common_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ This file includes private common utilities for FSDP. """ diff --git a/torch/distributed/fsdp/_debug_utils.py b/torch/distributed/fsdp/_debug_utils.py index a41a817724e5..523330e5580d 100644 --- a/torch/distributed/fsdp/_debug_utils.py +++ b/torch/distributed/fsdp/_debug_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging import time from collections import defaultdict diff --git a/torch/distributed/fsdp/_dynamo_utils.py b/torch/distributed/fsdp/_dynamo_utils.py index 3a6c63dc5af8..e58c91a5807b 100644 --- a/torch/distributed/fsdp/_dynamo_utils.py +++ b/torch/distributed/fsdp/_dynamo_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Set import torch.nn as nn diff --git a/torch/distributed/fsdp/_exec_order_utils.py b/torch/distributed/fsdp/_exec_order_utils.py index 3ba2a43c0596..ad5fdc1fde5f 100644 --- a/torch/distributed/fsdp/_exec_order_utils.py +++ b/torch/distributed/fsdp/_exec_order_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import itertools import warnings from enum import auto, Enum diff --git a/torch/distributed/fsdp/_flat_param.py b/torch/distributed/fsdp/_flat_param.py index ed141465155c..f3e918349af7 100644 --- a/torch/distributed/fsdp/_flat_param.py +++ b/torch/distributed/fsdp/_flat_param.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import functools import logging diff --git a/torch/distributed/fsdp/_init_utils.py b/torch/distributed/fsdp/_init_utils.py index 64685013d9a4..c8b58091bf89 100644 --- a/torch/distributed/fsdp/_init_utils.py +++ b/torch/distributed/fsdp/_init_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections import itertools import os diff --git a/torch/distributed/fsdp/_optim_utils.py b/torch/distributed/fsdp/_optim_utils.py index b066f930ebaf..d4aa344c1114 100644 --- a/torch/distributed/fsdp/_optim_utils.py +++ b/torch/distributed/fsdp/_optim_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import functools import logging diff --git a/torch/distributed/fsdp/_runtime_utils.py b/torch/distributed/fsdp/_runtime_utils.py index f1e579adae00..833c1d45697a 100644 --- a/torch/distributed/fsdp/_runtime_utils.py +++ b/torch/distributed/fsdp/_runtime_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import logging from enum import auto, Enum diff --git a/torch/distributed/fsdp/_shard_utils.py b/torch/distributed/fsdp/_shard_utils.py index 8af94b78209b..da243e6aa130 100644 --- a/torch/distributed/fsdp/_shard_utils.py +++ b/torch/distributed/fsdp/_shard_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import itertools import math diff --git a/torch/distributed/fsdp/_state_dict_utils.py b/torch/distributed/fsdp/_state_dict_utils.py index 9489994a3bb4..797a0116587b 100644 --- a/torch/distributed/fsdp/_state_dict_utils.py +++ b/torch/distributed/fsdp/_state_dict_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import logging import math diff --git a/torch/distributed/fsdp/_trace_utils.py b/torch/distributed/fsdp/_trace_utils.py index c768b73b8f95..49039e337ea2 100644 --- a/torch/distributed/fsdp/_trace_utils.py +++ b/torch/distributed/fsdp/_trace_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools from contextlib import contextmanager from dataclasses import dataclass, field diff --git a/torch/distributed/fsdp/_unshard_param_utils.py b/torch/distributed/fsdp/_unshard_param_utils.py index 7700d631d73e..435193a88703 100644 --- a/torch/distributed/fsdp/_unshard_param_utils.py +++ b/torch/distributed/fsdp/_unshard_param_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import warnings from typing import cast, Generator diff --git a/torch/distributed/fsdp/_wrap_utils.py b/torch/distributed/fsdp/_wrap_utils.py index 16f521f65b8d..84cdf250d8ae 100644 --- a/torch/distributed/fsdp/_wrap_utils.py +++ b/torch/distributed/fsdp/_wrap_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections import functools import inspect diff --git a/torch/distributed/fsdp/sharded_grad_scaler.py b/torch/distributed/fsdp/sharded_grad_scaler.py index 47bfe041cdc2..3487e01263c7 100644 --- a/torch/distributed/fsdp/sharded_grad_scaler.py +++ b/torch/distributed/fsdp/sharded_grad_scaler.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging from collections import abc, defaultdict from typing import Any, Dict, Iterable, List, Optional, overload, Sequence, Tuple, Union diff --git a/torch/distributed/fsdp/wrap.py b/torch/distributed/fsdp/wrap.py index 90796269de46..acb5a6f1f642 100644 --- a/torch/distributed/fsdp/wrap.py +++ b/torch/distributed/fsdp/wrap.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the BSD license found in the diff --git a/torch/distributed/launch.py b/torch/distributed/launch.py index 3efb0c3cf31d..a9e35c36db7f 100644 --- a/torch/distributed/launch.py +++ b/torch/distributed/launch.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r""" Module ``torch.distributed.launch``. diff --git a/torch/distributed/launcher/api.py b/torch/distributed/launcher/api.py index 20de0a032713..937647f77828 100644 --- a/torch/distributed/launcher/api.py +++ b/torch/distributed/launcher/api.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. diff --git a/torch/distributed/nn/api/remote_module.py b/torch/distributed/nn/api/remote_module.py index 16e38b32712d..de8a15dd65da 100644 --- a/torch/distributed/nn/api/remote_module.py +++ b/torch/distributed/nn/api/remote_module.py @@ -1,4 +1,5 @@ #!/usr/bin/python3 +# mypy: allow-untyped-defs import collections import io import sys diff --git a/torch/distributed/nn/functional.py b/torch/distributed/nn/functional.py index 857d090dedbe..e90a78a69324 100644 --- a/torch/distributed/nn/functional.py +++ b/torch/distributed/nn/functional.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.distributed as dist from torch.autograd import Function diff --git a/torch/distributed/nn/jit/instantiator.py b/torch/distributed/nn/jit/instantiator.py index 24f53c4f1a60..d529fc740945 100644 --- a/torch/distributed/nn/jit/instantiator.py +++ b/torch/distributed/nn/jit/instantiator.py @@ -1,4 +1,5 @@ #!/usr/bin/python3 +# mypy: allow-untyped-defs import importlib import logging import os diff --git a/torch/distributed/nn/jit/templates/remote_module_template.py b/torch/distributed/nn/jit/templates/remote_module_template.py index ac731b434243..07b055774b36 100644 --- a/torch/distributed/nn/jit/templates/remote_module_template.py +++ b/torch/distributed/nn/jit/templates/remote_module_template.py @@ -1,4 +1,5 @@ #!/usr/bin/python3 +# mypy: allow-untyped-defs def get_remote_module_template(enable_moving_cpu_tensors_to_cuda: bool): diff --git a/torch/distributed/optim/functional_adadelta.py b/torch/distributed/optim/functional_adadelta.py index e3e44d4667ae..bc5f7c63dd17 100644 --- a/torch/distributed/optim/functional_adadelta.py +++ b/torch/distributed/optim/functional_adadelta.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Dict, List, Optional import torch diff --git a/torch/distributed/optim/functional_adagrad.py b/torch/distributed/optim/functional_adagrad.py index dfd50db17591..93a1fe2b2240 100644 --- a/torch/distributed/optim/functional_adagrad.py +++ b/torch/distributed/optim/functional_adagrad.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Dict, List, Optional import torch diff --git a/torch/distributed/optim/functional_adam.py b/torch/distributed/optim/functional_adam.py index 5335df17e089..34868d23d8a5 100644 --- a/torch/distributed/optim/functional_adam.py +++ b/torch/distributed/optim/functional_adam.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Dict, List, Optional, Tuple import torch diff --git a/torch/distributed/optim/functional_adamax.py b/torch/distributed/optim/functional_adamax.py index f3acd4d271ef..32bce65dfe1f 100644 --- a/torch/distributed/optim/functional_adamax.py +++ b/torch/distributed/optim/functional_adamax.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Dict, List, Optional, Tuple import torch diff --git a/torch/distributed/optim/functional_adamw.py b/torch/distributed/optim/functional_adamw.py index 40aabafb0ca7..43addd050822 100644 --- a/torch/distributed/optim/functional_adamw.py +++ b/torch/distributed/optim/functional_adamw.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Dict, List, Optional, Tuple import torch diff --git a/torch/distributed/optim/functional_rmsprop.py b/torch/distributed/optim/functional_rmsprop.py index fc4d7750973c..851119c8600c 100644 --- a/torch/distributed/optim/functional_rmsprop.py +++ b/torch/distributed/optim/functional_rmsprop.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Dict, List, Optional import torch diff --git a/torch/distributed/optim/functional_rprop.py b/torch/distributed/optim/functional_rprop.py index 6018ce943b40..60742bc68896 100644 --- a/torch/distributed/optim/functional_rprop.py +++ b/torch/distributed/optim/functional_rprop.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Dict, List, Optional, Tuple import torch diff --git a/torch/distributed/optim/functional_sgd.py b/torch/distributed/optim/functional_sgd.py index 4a807a605571..3a8176e87705 100644 --- a/torch/distributed/optim/functional_sgd.py +++ b/torch/distributed/optim/functional_sgd.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Dict, List, Optional import torch diff --git a/torch/distributed/optim/named_optimizer.py b/torch/distributed/optim/named_optimizer.py index 28edbe39d80e..9e1e5377873d 100644 --- a/torch/distributed/optim/named_optimizer.py +++ b/torch/distributed/optim/named_optimizer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging import warnings diff --git a/torch/distributed/optim/optimizer.py b/torch/distributed/optim/optimizer.py index 8246c667509d..f2eca606c026 100644 --- a/torch/distributed/optim/optimizer.py +++ b/torch/distributed/optim/optimizer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging from collections import defaultdict diff --git a/torch/distributed/optim/post_localSGD_optimizer.py b/torch/distributed/optim/post_localSGD_optimizer.py index f1717685966a..db65856e32ad 100644 --- a/torch/distributed/optim/post_localSGD_optimizer.py +++ b/torch/distributed/optim/post_localSGD_optimizer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import warnings import torch diff --git a/torch/distributed/optim/utils.py b/torch/distributed/optim/utils.py index 5fb197e2d1dd..af2220ca5574 100644 --- a/torch/distributed/optim/utils.py +++ b/torch/distributed/optim/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Type from torch import optim diff --git a/torch/distributed/optim/zero_redundancy_optimizer.pyi b/torch/distributed/optim/zero_redundancy_optimizer.pyi index c341e00e3ee3..21f3cc5e3fc2 100644 --- a/torch/distributed/optim/zero_redundancy_optimizer.pyi +++ b/torch/distributed/optim/zero_redundancy_optimizer.pyi @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import enum from typing import Any, Callable, Dict, List, Optional, overload, Set, Type diff --git a/torch/distributed/pipelining/_IR.py b/torch/distributed/pipelining/_IR.py index e58749c581cb..7d0aede8943e 100644 --- a/torch/distributed/pipelining/_IR.py +++ b/torch/distributed/pipelining/_IR.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates import copy import logging diff --git a/torch/distributed/pipelining/_backward.py b/torch/distributed/pipelining/_backward.py index c3aa9060502b..6ba12899e838 100644 --- a/torch/distributed/pipelining/_backward.py +++ b/torch/distributed/pipelining/_backward.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates from typing import List, Optional diff --git a/torch/distributed/pipelining/_debug.py b/torch/distributed/pipelining/_debug.py index 7067a39b39d1..6b153ec78d89 100644 --- a/torch/distributed/pipelining/_debug.py +++ b/torch/distributed/pipelining/_debug.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates import torch diff --git a/torch/distributed/pipelining/_unflatten.py b/torch/distributed/pipelining/_unflatten.py index 27241d17874c..659c9804a966 100644 --- a/torch/distributed/pipelining/_unflatten.py +++ b/torch/distributed/pipelining/_unflatten.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates from typing import Dict diff --git a/torch/distributed/pipelining/_utils.py b/torch/distributed/pipelining/_utils.py index 31caf3427424..cf7097795868 100644 --- a/torch/distributed/pipelining/_utils.py +++ b/torch/distributed/pipelining/_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates import logging from dataclasses import dataclass diff --git a/torch/distributed/pipelining/microbatch.py b/torch/distributed/pipelining/microbatch.py index 6358a1293edb..8360951b43eb 100644 --- a/torch/distributed/pipelining/microbatch.py +++ b/torch/distributed/pipelining/microbatch.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates import logging from typing import Any, Dict, List, Optional, Tuple diff --git a/torch/distributed/pipelining/schedules.py b/torch/distributed/pipelining/schedules.py index dfd5752c2e45..6990ea983edb 100644 --- a/torch/distributed/pipelining/schedules.py +++ b/torch/distributed/pipelining/schedules.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates import logging diff --git a/torch/distributed/pipelining/stage.py b/torch/distributed/pipelining/stage.py index a114b629344d..c2c5582d6854 100644 --- a/torch/distributed/pipelining/stage.py +++ b/torch/distributed/pipelining/stage.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates import logging import operator diff --git a/torch/distributed/remote_device.py b/torch/distributed/remote_device.py index e26d398bf786..da664f7408bb 100644 --- a/torch/distributed/remote_device.py +++ b/torch/distributed/remote_device.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Optional, Union import torch diff --git a/torch/distributed/rendezvous.py b/torch/distributed/rendezvous.py index 8bef92275edd..e3266cb238ac 100644 --- a/torch/distributed/rendezvous.py +++ b/torch/distributed/rendezvous.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs try: from urllib.parse import urlparse, urlunparse except ImportError as e: diff --git a/torch/distributed/rpc/__init__.py b/torch/distributed/rpc/__init__.py index de8153e19c01..581433d220c6 100644 --- a/torch/distributed/rpc/__init__.py +++ b/torch/distributed/rpc/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from datetime import timedelta import logging import os diff --git a/torch/distributed/rpc/_testing/__init__.py b/torch/distributed/rpc/_testing/__init__.py index 5755b99c7571..640c4d09f062 100644 --- a/torch/distributed/rpc/_testing/__init__.py +++ b/torch/distributed/rpc/_testing/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch diff --git a/torch/distributed/rpc/_testing/faulty_agent_backend_registry.py b/torch/distributed/rpc/_testing/faulty_agent_backend_registry.py index b02a6a2ff8ac..9e8660989e5a 100644 --- a/torch/distributed/rpc/_testing/faulty_agent_backend_registry.py +++ b/torch/distributed/rpc/_testing/faulty_agent_backend_registry.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs import torch.distributed as dist import torch.distributed.rpc as rpc diff --git a/torch/distributed/rpc/_utils.py b/torch/distributed/rpc/_utils.py index a532897969d4..6499a80e0e17 100644 --- a/torch/distributed/rpc/_utils.py +++ b/torch/distributed/rpc/_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from contextlib import contextmanager from typing import cast import logging diff --git a/torch/distributed/rpc/api.py b/torch/distributed/rpc/api.py index 0f317829b207..a33358eb0dc6 100644 --- a/torch/distributed/rpc/api.py +++ b/torch/distributed/rpc/api.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs __all__ = ["shutdown", "get_worker_info", "remote", "rpc_sync", "rpc_async", "RRef", "AllGatherStates", "method_factory", "new_method"] diff --git a/torch/distributed/rpc/backend_registry.py b/torch/distributed/rpc/backend_registry.py index d09ec399e390..6290f9e8e205 100644 --- a/torch/distributed/rpc/backend_registry.py +++ b/torch/distributed/rpc/backend_registry.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs __all__ = ["init_backend", "backend_registered", "construct_rpc_backend_options", "register_backend", "BackendType", "BackendValue"] import collections diff --git a/torch/distributed/rpc/functions.py b/torch/distributed/rpc/functions.py index b1c85c47853d..c9e92980cf56 100644 --- a/torch/distributed/rpc/functions.py +++ b/torch/distributed/rpc/functions.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools diff --git a/torch/distributed/rpc/internal.py b/torch/distributed/rpc/internal.py index 6e00a4d18521..2fc647c414d9 100644 --- a/torch/distributed/rpc/internal.py +++ b/torch/distributed/rpc/internal.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections import copyreg import io diff --git a/torch/distributed/rpc/options.py b/torch/distributed/rpc/options.py index 67892d14e075..70328f345969 100644 --- a/torch/distributed/rpc/options.py +++ b/torch/distributed/rpc/options.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Dict, List, Optional, Union import torch diff --git a/torch/distributed/rpc/rref_proxy.py b/torch/distributed/rpc/rref_proxy.py index 89986be8b928..cdb0a5d22b74 100644 --- a/torch/distributed/rpc/rref_proxy.py +++ b/torch/distributed/rpc/rref_proxy.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from functools import partial from . import functions diff --git a/torch/distributed/rpc/server_process_global_profiler.py b/torch/distributed/rpc/server_process_global_profiler.py index dc3f4c19ef1e..0543ab56a877 100644 --- a/torch/distributed/rpc/server_process_global_profiler.py +++ b/torch/distributed/rpc/server_process_global_profiler.py @@ -1,4 +1,5 @@ #!/usr/bin/python3 +# mypy: allow-untyped-defs import itertools diff --git a/torch/distributed/run.py b/torch/distributed/run.py index 399c9c39ec61..9e418c708f03 100644 --- a/torch/distributed/run.py +++ b/torch/distributed/run.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. diff --git a/torch/distributed/tensor/parallel/_utils.py b/torch/distributed/tensor/parallel/_utils.py index 876e97f70c5b..013a2a9d1723 100644 --- a/torch/distributed/tensor/parallel/_utils.py +++ b/torch/distributed/tensor/parallel/_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import warnings from typing import Tuple, Union diff --git a/torch/distributed/tensor/parallel/ddp.py b/torch/distributed/tensor/parallel/ddp.py index 474e542551ae..baa9d638037d 100644 --- a/torch/distributed/tensor/parallel/ddp.py +++ b/torch/distributed/tensor/parallel/ddp.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any, List, Tuple import torch.nn as nn diff --git a/torch/distributed/tensor/parallel/fsdp.py b/torch/distributed/tensor/parallel/fsdp.py index 888631e67777..c38771ae86e2 100644 --- a/torch/distributed/tensor/parallel/fsdp.py +++ b/torch/distributed/tensor/parallel/fsdp.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy from typing import Any, cast, List, Optional, Tuple diff --git a/torch/distributed/tensor/parallel/loss.py b/torch/distributed/tensor/parallel/loss.py index 8e7b7de84e1e..f2776c5123b4 100644 --- a/torch/distributed/tensor/parallel/loss.py +++ b/torch/distributed/tensor/parallel/loss.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates import contextlib from typing import cast, Dict, Optional, Tuple diff --git a/torch/distributed/tensor/parallel/style.py b/torch/distributed/tensor/parallel/style.py index 2720f9dca7d0..f532b97e97d0 100644 --- a/torch/distributed/tensor/parallel/style.py +++ b/torch/distributed/tensor/parallel/style.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates from abc import ABC, abstractmethod from typing import Optional, Union, Tuple, Dict diff --git a/torch/distributed/utils.py b/torch/distributed/utils.py index af44fee9d720..7c135cbbacf8 100644 --- a/torch/distributed/utils.py +++ b/torch/distributed/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import dataclasses import traceback from typing import ( diff --git a/torch/distributions/bernoulli.py b/torch/distributions/bernoulli.py index 75c2882dbc15..701d24ecd68c 100644 --- a/torch/distributions/bernoulli.py +++ b/torch/distributions/bernoulli.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from numbers import Number import torch diff --git a/torch/distributions/beta.py b/torch/distributions/beta.py index a802301a47ed..79b2f5e79ae0 100644 --- a/torch/distributions/beta.py +++ b/torch/distributions/beta.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from numbers import Number, Real import torch diff --git a/torch/distributions/binomial.py b/torch/distributions/binomial.py index 9243da7b6bf4..95e7baeb906e 100644 --- a/torch/distributions/binomial.py +++ b/torch/distributions/binomial.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.distributions import constraints from torch.distributions.distribution import Distribution diff --git a/torch/distributions/categorical.py b/torch/distributions/categorical.py index 08d2fb3ac8e8..cc35689bee99 100644 --- a/torch/distributions/categorical.py +++ b/torch/distributions/categorical.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch import nan from torch.distributions import constraints diff --git a/torch/distributions/cauchy.py b/torch/distributions/cauchy.py index 1a95dfe0d762..ed42d183a7fd 100644 --- a/torch/distributions/cauchy.py +++ b/torch/distributions/cauchy.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math from numbers import Number diff --git a/torch/distributions/chi2.py b/torch/distributions/chi2.py index 16d0d6d60fbe..11f8127169a3 100644 --- a/torch/distributions/chi2.py +++ b/torch/distributions/chi2.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch.distributions import constraints from torch.distributions.gamma import Gamma diff --git a/torch/distributions/constraint_registry.py b/torch/distributions/constraint_registry.py index 83192f69547f..ae30348dd2d7 100644 --- a/torch/distributions/constraint_registry.py +++ b/torch/distributions/constraint_registry.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r""" PyTorch provides two global :class:`ConstraintRegistry` objects that link :class:`~torch.distributions.constraints.Constraint` objects to diff --git a/torch/distributions/constraints.py b/torch/distributions/constraints.py index df94bbd7b14f..5dc9b46519a3 100644 --- a/torch/distributions/constraints.py +++ b/torch/distributions/constraints.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r""" The following constraints are implemented: diff --git a/torch/distributions/continuous_bernoulli.py b/torch/distributions/continuous_bernoulli.py index 3e7f1a53a47f..34eb75b9b6f8 100644 --- a/torch/distributions/continuous_bernoulli.py +++ b/torch/distributions/continuous_bernoulli.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math from numbers import Number diff --git a/torch/distributions/dirichlet.py b/torch/distributions/dirichlet.py index b7175aa61628..c8a5ec485b1a 100644 --- a/torch/distributions/dirichlet.py +++ b/torch/distributions/dirichlet.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.autograd import Function from torch.autograd.function import once_differentiable diff --git a/torch/distributions/distribution.py b/torch/distributions/distribution.py index 2fb05828a8b3..b329a277174d 100644 --- a/torch/distributions/distribution.py +++ b/torch/distributions/distribution.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import warnings from typing import Any, Dict, Optional, Tuple from typing_extensions import deprecated diff --git a/torch/distributions/exp_family.py b/torch/distributions/exp_family.py index e60f6489d5bf..6d422aeacf08 100644 --- a/torch/distributions/exp_family.py +++ b/torch/distributions/exp_family.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.distributions.distribution import Distribution diff --git a/torch/distributions/exponential.py b/torch/distributions/exponential.py index 020b5215bbdb..e557f6a6bccc 100644 --- a/torch/distributions/exponential.py +++ b/torch/distributions/exponential.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from numbers import Number import torch diff --git a/torch/distributions/fishersnedecor.py b/torch/distributions/fishersnedecor.py index 788f74b58556..3e70aa7f5c70 100644 --- a/torch/distributions/fishersnedecor.py +++ b/torch/distributions/fishersnedecor.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from numbers import Number import torch diff --git a/torch/distributions/gamma.py b/torch/distributions/gamma.py index c189fb24e070..c115a8d71bf9 100644 --- a/torch/distributions/gamma.py +++ b/torch/distributions/gamma.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from numbers import Number import torch diff --git a/torch/distributions/geometric.py b/torch/distributions/geometric.py index 0bf2f3dbacc6..918d97885738 100644 --- a/torch/distributions/geometric.py +++ b/torch/distributions/geometric.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from numbers import Number import torch diff --git a/torch/distributions/gumbel.py b/torch/distributions/gumbel.py index e0ed5d8f8690..af886f65e833 100644 --- a/torch/distributions/gumbel.py +++ b/torch/distributions/gumbel.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math from numbers import Number diff --git a/torch/distributions/half_cauchy.py b/torch/distributions/half_cauchy.py index ef0edc6f0fe8..0afedbc9d5d7 100644 --- a/torch/distributions/half_cauchy.py +++ b/torch/distributions/half_cauchy.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math import torch diff --git a/torch/distributions/half_normal.py b/torch/distributions/half_normal.py index 6526170b24ee..4cf977376ea3 100644 --- a/torch/distributions/half_normal.py +++ b/torch/distributions/half_normal.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math import torch diff --git a/torch/distributions/independent.py b/torch/distributions/independent.py index 35b705fd0f29..36946e798f6b 100644 --- a/torch/distributions/independent.py +++ b/torch/distributions/independent.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Dict import torch diff --git a/torch/distributions/inverse_gamma.py b/torch/distributions/inverse_gamma.py index 5a66138b6f04..cff64d0a9e49 100644 --- a/torch/distributions/inverse_gamma.py +++ b/torch/distributions/inverse_gamma.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.distributions import constraints from torch.distributions.gamma import Gamma diff --git a/torch/distributions/kl.py b/torch/distributions/kl.py index 923f1edcdf41..20adf1cdad2a 100644 --- a/torch/distributions/kl.py +++ b/torch/distributions/kl.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math import warnings from functools import total_ordering diff --git a/torch/distributions/kumaraswamy.py b/torch/distributions/kumaraswamy.py index 9de3c422dc4c..25393f7177c5 100644 --- a/torch/distributions/kumaraswamy.py +++ b/torch/distributions/kumaraswamy.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch import nan from torch.distributions import constraints diff --git a/torch/distributions/laplace.py b/torch/distributions/laplace.py index 7b830cc76f9b..8069a41ab6fb 100644 --- a/torch/distributions/laplace.py +++ b/torch/distributions/laplace.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from numbers import Number import torch diff --git a/torch/distributions/lkj_cholesky.py b/torch/distributions/lkj_cholesky.py index c1cb46f02fc2..38f5235ed278 100644 --- a/torch/distributions/lkj_cholesky.py +++ b/torch/distributions/lkj_cholesky.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ This closely follows the implementation in NumPyro (https://github.com/pyro-ppl/numpyro). diff --git a/torch/distributions/log_normal.py b/torch/distributions/log_normal.py index f6694cf9507f..bde09b88ecb4 100644 --- a/torch/distributions/log_normal.py +++ b/torch/distributions/log_normal.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch.distributions import constraints from torch.distributions.normal import Normal from torch.distributions.transformed_distribution import TransformedDistribution diff --git a/torch/distributions/logistic_normal.py b/torch/distributions/logistic_normal.py index a9ef4dd26564..6cdd4f8db515 100644 --- a/torch/distributions/logistic_normal.py +++ b/torch/distributions/logistic_normal.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch.distributions import constraints from torch.distributions.normal import Normal from torch.distributions.transformed_distribution import TransformedDistribution diff --git a/torch/distributions/lowrank_multivariate_normal.py b/torch/distributions/lowrank_multivariate_normal.py index a3acaa990966..6f09de1f5177 100644 --- a/torch/distributions/lowrank_multivariate_normal.py +++ b/torch/distributions/lowrank_multivariate_normal.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math import torch diff --git a/torch/distributions/mixture_same_family.py b/torch/distributions/mixture_same_family.py index 8db242e33253..ab507f9f60a2 100644 --- a/torch/distributions/mixture_same_family.py +++ b/torch/distributions/mixture_same_family.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Dict import torch diff --git a/torch/distributions/multinomial.py b/torch/distributions/multinomial.py index 3f316e823a79..50699a592a31 100644 --- a/torch/distributions/multinomial.py +++ b/torch/distributions/multinomial.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch import inf from torch.distributions import Categorical, constraints diff --git a/torch/distributions/multivariate_normal.py b/torch/distributions/multivariate_normal.py index 2784eeb214d5..4edff9c69b57 100644 --- a/torch/distributions/multivariate_normal.py +++ b/torch/distributions/multivariate_normal.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math import torch diff --git a/torch/distributions/negative_binomial.py b/torch/distributions/negative_binomial.py index 59edee589f9a..230b404c3fb0 100644 --- a/torch/distributions/negative_binomial.py +++ b/torch/distributions/negative_binomial.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.nn.functional as F from torch.distributions import constraints diff --git a/torch/distributions/normal.py b/torch/distributions/normal.py index 3364474ba68f..0f73c8facf29 100644 --- a/torch/distributions/normal.py +++ b/torch/distributions/normal.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math from numbers import Number, Real diff --git a/torch/distributions/one_hot_categorical.py b/torch/distributions/one_hot_categorical.py index 2fdf5ff6c0ae..957a7d6bdf7f 100644 --- a/torch/distributions/one_hot_categorical.py +++ b/torch/distributions/one_hot_categorical.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.distributions import constraints from torch.distributions.categorical import Categorical diff --git a/torch/distributions/pareto.py b/torch/distributions/pareto.py index 07cfb417a814..76dbe29b67b6 100644 --- a/torch/distributions/pareto.py +++ b/torch/distributions/pareto.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch.distributions import constraints from torch.distributions.exponential import Exponential from torch.distributions.transformed_distribution import TransformedDistribution diff --git a/torch/distributions/poisson.py b/torch/distributions/poisson.py index 81c0898a577b..4ecf85dc825b 100644 --- a/torch/distributions/poisson.py +++ b/torch/distributions/poisson.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from numbers import Number import torch diff --git a/torch/distributions/relaxed_bernoulli.py b/torch/distributions/relaxed_bernoulli.py index a41e1be1f029..ca5b6fd46b5b 100644 --- a/torch/distributions/relaxed_bernoulli.py +++ b/torch/distributions/relaxed_bernoulli.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from numbers import Number import torch diff --git a/torch/distributions/relaxed_categorical.py b/torch/distributions/relaxed_categorical.py index 707a80d05415..719c0c15d38e 100644 --- a/torch/distributions/relaxed_categorical.py +++ b/torch/distributions/relaxed_categorical.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.distributions import constraints from torch.distributions.categorical import Categorical diff --git a/torch/distributions/studentT.py b/torch/distributions/studentT.py index 553144e2643b..b49e56c2e313 100644 --- a/torch/distributions/studentT.py +++ b/torch/distributions/studentT.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math import torch diff --git a/torch/distributions/transformed_distribution.py b/torch/distributions/transformed_distribution.py index b2201278ea8d..8c7cba61fb14 100644 --- a/torch/distributions/transformed_distribution.py +++ b/torch/distributions/transformed_distribution.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Dict import torch diff --git a/torch/distributions/transforms.py b/torch/distributions/transforms.py index f2907caa6018..b81b19441335 100644 --- a/torch/distributions/transforms.py +++ b/torch/distributions/transforms.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import math import numbers diff --git a/torch/distributions/uniform.py b/torch/distributions/uniform.py index e939bb4aae39..8b3497b4e313 100644 --- a/torch/distributions/uniform.py +++ b/torch/distributions/uniform.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from numbers import Number import torch diff --git a/torch/distributions/utils.py b/torch/distributions/utils.py index 91e4345e983c..c6a10088fdd8 100644 --- a/torch/distributions/utils.py +++ b/torch/distributions/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from functools import update_wrapper from numbers import Number from typing import Any, Dict diff --git a/torch/distributions/von_mises.py b/torch/distributions/von_mises.py index 17f52fad25b3..8be9ffb7778c 100644 --- a/torch/distributions/von_mises.py +++ b/torch/distributions/von_mises.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math import torch diff --git a/torch/distributions/weibull.py b/torch/distributions/weibull.py index 39e07d580bc5..607190df1e1e 100644 --- a/torch/distributions/weibull.py +++ b/torch/distributions/weibull.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.distributions import constraints from torch.distributions.exponential import Exponential diff --git a/torch/distributions/wishart.py b/torch/distributions/wishart.py index 733efbbeb95f..3ec13c25017f 100644 --- a/torch/distributions/wishart.py +++ b/torch/distributions/wishart.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math import warnings from numbers import Number diff --git a/torch/export/_remove_auto_functionalized_pass.py b/torch/export/_remove_auto_functionalized_pass.py index c1cea8ec005f..930915f96f9b 100644 --- a/torch/export/_remove_auto_functionalized_pass.py +++ b/torch/export/_remove_auto_functionalized_pass.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # diff --git a/torch/export/_remove_effect_tokens_pass.py b/torch/export/_remove_effect_tokens_pass.py index 235b43b969aa..20411dc87cce 100644 --- a/torch/export/_remove_effect_tokens_pass.py +++ b/torch/export/_remove_effect_tokens_pass.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import operator from typing import List diff --git a/torch/export/_safeguard.py b/torch/export/_safeguard.py index 92fb9b434041..76f22f369c56 100644 --- a/torch/export/_safeguard.py +++ b/torch/export/_safeguard.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode from torch.overrides import TorchFunctionMode diff --git a/torch/export/_trace.py b/torch/export/_trace.py index 4fcc85f3236b..ee25dbc2e1ea 100644 --- a/torch/export/_trace.py +++ b/torch/export/_trace.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import dataclasses import functools import inspect diff --git a/torch/export/_unlift.py b/torch/export/_unlift.py index 2fdb7916eeeb..97df0562caa7 100644 --- a/torch/export/_unlift.py +++ b/torch/export/_unlift.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy from itertools import chain from typing import Any, Dict, List, Optional, Tuple diff --git a/torch/export/dynamic_shapes.py b/torch/export/dynamic_shapes.py index 43ab56c10501..e351df8d622c 100644 --- a/torch/export/dynamic_shapes.py +++ b/torch/export/dynamic_shapes.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import builtins import dataclasses import inspect From 038b927590669f9368e78dd6d91911d24635c27e Mon Sep 17 00:00:00 2001 From: Aaron Orenstein Date: Sat, 8 Jun 2024 11:41:14 -0700 Subject: [PATCH 542/706] Flip default value for mypy disallow_untyped_defs [7/11] (#127844) See #127836 for details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127844 Approved by: https://github.com/oulgen ghstack dependencies: #127842, #127843 --- torch/export/exported_program.py | 1 + torch/export/graph_signature.py | 1 + torch/export/unflatten.py | 1 + torch/functional.py | 1 + torch/futures/__init__.py | 1 + torch/fx/_compatibility.py | 1 + torch/fx/_lazy_graph_module.py | 1 + torch/fx/_pytree.py | 1 + torch/fx/_symbolic_trace.py | 1 + torch/fx/_utils.py | 1 + torch/fx/annotate.py | 1 + torch/fx/experimental/_sym_dispatch_mode.py | 1 + torch/fx/experimental/accelerator_partitioner.py | 1 + torch/fx/experimental/const_fold.py | 1 + torch/fx/experimental/debug.py | 1 + torch/fx/experimental/graph_gradual_typechecker.py | 1 + torch/fx/experimental/merge_matmul.py | 1 + torch/fx/experimental/meta_tracer.py | 1 + torch/fx/experimental/migrate_gradual_types/constraint.py | 1 + .../experimental/migrate_gradual_types/constraint_generator.py | 1 + torch/fx/experimental/migrate_gradual_types/transform_to_z3.py | 1 + torch/fx/experimental/migrate_gradual_types/util.py | 1 + torch/fx/experimental/normalize.py | 1 + torch/fx/experimental/optimization.py | 1 + torch/fx/experimental/partitioner_utils.py | 1 + torch/fx/experimental/recording.py | 1 + torch/fx/experimental/refinement_types.py | 1 + torch/fx/experimental/rewriter.py | 1 + torch/fx/experimental/schema_type_annotation.py | 1 + torch/fx/experimental/shape_inference/infer_shape.py | 1 + torch/fx/experimental/sym_node.py | 1 + torch/fx/experimental/unification/core.py | 1 + torch/fx/experimental/unification/match.py | 1 + torch/fx/experimental/unification/more.py | 1 + torch/fx/experimental/unification/multipledispatch/conflict.py | 1 + torch/fx/experimental/unification/multipledispatch/core.py | 1 + torch/fx/experimental/unification/multipledispatch/dispatcher.py | 1 + torch/fx/experimental/unification/multipledispatch/utils.py | 1 + torch/fx/experimental/unification/multipledispatch/variadic.py | 1 + torch/fx/experimental/unification/unification_tools.py | 1 + torch/fx/experimental/unification/utils.py | 1 + torch/fx/experimental/unification/variable.py | 1 + torch/fx/experimental/unify_refinements.py | 1 + torch/fx/experimental/validator.py | 1 + torch/fx/graph.py | 1 + torch/fx/graph_module.py | 1 + torch/fx/immutable_collections.py | 1 + torch/fx/interpreter.py | 1 + torch/fx/operator_schemas.py | 1 + torch/fx/passes/backends/cudagraphs.py | 1 + torch/fx/passes/dialect/common/cse_pass.py | 1 + torch/fx/passes/fake_tensor_prop.py | 1 + torch/fx/passes/graph_drawer.py | 1 + torch/fx/passes/graph_manipulation.py | 1 + torch/fx/passes/graph_transform_observer.py | 1 + torch/fx/passes/infra/partitioner.py | 1 + torch/fx/passes/infra/pass_base.py | 1 + torch/fx/passes/infra/pass_manager.py | 1 + torch/fx/passes/net_min_base.py | 1 + torch/fx/passes/operator_support.py | 1 + torch/fx/passes/pass_manager.py | 1 + torch/fx/passes/reinplace.py | 1 + torch/fx/passes/runtime_assert.py | 1 + torch/fx/passes/split_module.py | 1 + torch/fx/passes/split_utils.py | 1 + torch/fx/passes/splitter_base.py | 1 + torch/fx/passes/tools_common.py | 1 + torch/fx/passes/utils/common.py | 1 + torch/fx/passes/utils/fuser_utils.py | 1 + torch/fx/passes/utils/matcher_utils.py | 1 + torch/fx/passes/utils/source_matcher_utils.py | 1 + torch/fx/tensor_type.py | 1 + torch/fx/traceback.py | 1 + torch/hub.py | 1 + torch/jit/__init__.py | 1 + torch/jit/_async.py | 1 + torch/jit/_await.py | 1 + torch/jit/_builtins.py | 1 + torch/jit/_check.py | 1 + torch/jit/_dataclass_impls.py | 1 + torch/jit/_decomposition_utils.py | 1 + torch/jit/_decompositions.py | 1 + torch/jit/_freeze.py | 1 + torch/jit/_fuser.py | 1 + torch/jit/_ir_utils.py | 1 + torch/jit/_monkeytype_config.py | 1 + torch/jit/_passes/_property_propagation.py | 1 + torch/jit/_pickle.py | 1 + torch/jit/_recursive.py | 1 + torch/jit/_script.pyi | 1 + torch/jit/_serialization.py | 1 + torch/jit/_shape_functions.py | 1 + torch/jit/_state.py | 1 + torch/jit/_trace.py | 1 + torch/jit/annotations.py | 1 + torch/jit/frontend.py | 1 + torch/jit/generate_bytecode.py | 1 + torch/jit/mobile/__init__.py | 1 + torch/jit/quantized.py | 1 + torch/jit/supported_ops.py | 1 + torch/jit/unsupported_tensor_ops.py | 1 + torch/library.py | 1 + torch/masked/_ops.py | 1 + torch/masked/maskedtensor/_ops_refs.py | 1 + torch/masked/maskedtensor/binary.py | 1 + torch/masked/maskedtensor/core.py | 1 + torch/masked/maskedtensor/creation.py | 1 + torch/masked/maskedtensor/passthrough.py | 1 + torch/masked/maskedtensor/reductions.py | 1 + torch/masked/maskedtensor/unary.py | 1 + torch/mps/__init__.py | 1 + torch/mps/event.py | 1 + torch/mps/profiler.py | 1 + torch/mtia/__init__.py | 1 + torch/multiprocessing/__init__.py | 1 + torch/multiprocessing/_atfork.py | 1 + torch/multiprocessing/queue.py | 1 + torch/multiprocessing/reductions.py | 1 + torch/multiprocessing/spawn.py | 1 + torch/nested/__init__.py | 1 + torch/nested/_internal/nested_tensor.py | 1 + torch/nested/_internal/ops.py | 1 + torch/nested/_internal/sdpa.py | 1 + torch/nn/__init__.py | 1 + torch/nn/attention/__init__.py | 1 + torch/nn/attention/_flex_attention.py | 1 + torch/nn/attention/_utils.py | 1 + torch/nn/attention/bias.py | 1 + 128 files changed, 128 insertions(+) diff --git a/torch/export/exported_program.py b/torch/export/exported_program.py index 048ffe2e85c9..7b29251ca4ae 100644 --- a/torch/export/exported_program.py +++ b/torch/export/exported_program.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import dataclasses import functools diff --git a/torch/export/graph_signature.py b/torch/export/graph_signature.py index ecfd7853400d..ce62e8793941 100644 --- a/torch/export/graph_signature.py +++ b/torch/export/graph_signature.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import dataclasses from enum import auto, Enum from typing import Collection, Dict, List, Mapping, Optional, Set, Tuple, Union diff --git a/torch/export/unflatten.py b/torch/export/unflatten.py index 61685ed0f180..4de95dad2c8d 100644 --- a/torch/export/unflatten.py +++ b/torch/export/unflatten.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import abc import copy import operator diff --git a/torch/functional.py b/torch/functional.py index 7c07ae348631..a836c06f028d 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import ( List, Tuple, Optional, Union, Any, Sequence, TYPE_CHECKING ) diff --git a/torch/futures/__init__.py b/torch/futures/__init__.py index 6a398bebb599..e1623c44f193 100644 --- a/torch/futures/__init__.py +++ b/torch/futures/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations from typing import cast, Callable, Generic, List, Optional, Type, TypeVar, Union diff --git a/torch/fx/_compatibility.py b/torch/fx/_compatibility.py index 14588fad9a09..4258979eb3e7 100644 --- a/torch/fx/_compatibility.py +++ b/torch/fx/_compatibility.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any, Dict import textwrap diff --git a/torch/fx/_lazy_graph_module.py b/torch/fx/_lazy_graph_module.py index a4b4bc0d69d7..79a18de12f31 100644 --- a/torch/fx/_lazy_graph_module.py +++ b/torch/fx/_lazy_graph_module.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from contextlib import contextmanager from torch.fx import GraphModule diff --git a/torch/fx/_pytree.py b/torch/fx/_pytree.py index 29ab0c867911..da02e21528de 100644 --- a/torch/fx/_pytree.py +++ b/torch/fx/_pytree.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from collections import namedtuple from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Type diff --git a/torch/fx/_symbolic_trace.py b/torch/fx/_symbolic_trace.py index 5725c4c6a05c..25a342f064c8 100644 --- a/torch/fx/_symbolic_trace.py +++ b/torch/fx/_symbolic_trace.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import builtins import copy import functools diff --git a/torch/fx/_utils.py b/torch/fx/_utils.py index 598aeafee2d9..36c831dfdee0 100644 --- a/torch/fx/_utils.py +++ b/torch/fx/_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Dict, Optional import torch diff --git a/torch/fx/annotate.py b/torch/fx/annotate.py index 032ce14b6ec7..ab5c6d0acd61 100644 --- a/torch/fx/annotate.py +++ b/torch/fx/annotate.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch.fx.proxy import Proxy from ._compatibility import compatibility diff --git a/torch/fx/experimental/_sym_dispatch_mode.py b/torch/fx/experimental/_sym_dispatch_mode.py index c3385de61683..6e48a8ca18f4 100644 --- a/torch/fx/experimental/_sym_dispatch_mode.py +++ b/torch/fx/experimental/_sym_dispatch_mode.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List, Optional, Type __all__ = ["SymDispatchMode", "handle_sym_dispatch", "sym_function_mode"] diff --git a/torch/fx/experimental/accelerator_partitioner.py b/torch/fx/experimental/accelerator_partitioner.py index fc28f112323f..9b347762dedb 100644 --- a/torch/fx/experimental/accelerator_partitioner.py +++ b/torch/fx/experimental/accelerator_partitioner.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import operator from collections import deque from typing import Dict, List, Set, NamedTuple, Tuple, Deque diff --git a/torch/fx/experimental/const_fold.py b/torch/fx/experimental/const_fold.py index 8176ccb562fa..cb94ed3930ed 100644 --- a/torch/fx/experimental/const_fold.py +++ b/torch/fx/experimental/const_fold.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import re from typing import Callable, Dict, Optional, Set, Union diff --git a/torch/fx/experimental/debug.py b/torch/fx/experimental/debug.py index bd6fed690914..d3c482319f2e 100644 --- a/torch/fx/experimental/debug.py +++ b/torch/fx/experimental/debug.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch.fx as fx def set_trace(gm: fx.GraphModule) -> fx.GraphModule: diff --git a/torch/fx/experimental/graph_gradual_typechecker.py b/torch/fx/experimental/graph_gradual_typechecker.py index e44a75ddad08..a6ac80fd72fb 100644 --- a/torch/fx/experimental/graph_gradual_typechecker.py +++ b/torch/fx/experimental/graph_gradual_typechecker.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from functools import reduce import torch import operator diff --git a/torch/fx/experimental/merge_matmul.py b/torch/fx/experimental/merge_matmul.py index bd56694773e9..c1a634b2602a 100644 --- a/torch/fx/experimental/merge_matmul.py +++ b/torch/fx/experimental/merge_matmul.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.fx.node import Node diff --git a/torch/fx/experimental/meta_tracer.py b/torch/fx/experimental/meta_tracer.py index be19e7b93ac8..b09e221f6b36 100644 --- a/torch/fx/experimental/meta_tracer.py +++ b/torch/fx/experimental/meta_tracer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.fx import warnings diff --git a/torch/fx/experimental/migrate_gradual_types/constraint.py b/torch/fx/experimental/migrate_gradual_types/constraint.py index 3c1f724d26a5..45038837cae6 100644 --- a/torch/fx/experimental/migrate_gradual_types/constraint.py +++ b/torch/fx/experimental/migrate_gradual_types/constraint.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch.fx.experimental.migrate_gradual_types.operation import op_add, op_sub, op_mul, op_div, \ op_mod, op_gt, op_lt, op_neq, op_eq from torch.fx.tensor_type import TensorType, Dyn diff --git a/torch/fx/experimental/migrate_gradual_types/constraint_generator.py b/torch/fx/experimental/migrate_gradual_types/constraint_generator.py index 031562393edc..e04fc26b408e 100644 --- a/torch/fx/experimental/migrate_gradual_types/constraint_generator.py +++ b/torch/fx/experimental/migrate_gradual_types/constraint_generator.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import operator import warnings diff --git a/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py b/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py index 15af0241ec5b..c8cf70006cd8 100644 --- a/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py +++ b/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch.fx.experimental.migrate_gradual_types.constraint import Conj, Disj, T, F, BinConstraintT, BVar, is_bool_expr from torch.fx.experimental.migrate_gradual_types.constraint import BinConstraintD, TVar, DVar from torch.fx.experimental.migrate_gradual_types.constraint import Prod, is_algebraic_expression, is_dim diff --git a/torch/fx/experimental/migrate_gradual_types/util.py b/torch/fx/experimental/migrate_gradual_types/util.py index a43d8f3ebbe0..99f94609f265 100644 --- a/torch/fx/experimental/migrate_gradual_types/util.py +++ b/torch/fx/experimental/migrate_gradual_types/util.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch.fx.experimental.migrate_gradual_types.constraint import TVar, DVar, BinConstraintD, \ BVar from torch.fx.experimental.migrate_gradual_types.operation import op_leq diff --git a/torch/fx/experimental/normalize.py b/torch/fx/experimental/normalize.py index 06bc2309975c..30b076a72bee 100644 --- a/torch/fx/experimental/normalize.py +++ b/torch/fx/experimental/normalize.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import operator from typing import Any, Callable, Dict, Tuple, Optional diff --git a/torch/fx/experimental/optimization.py b/torch/fx/experimental/optimization.py index be411d9b6eff..8362c0cb88ac 100644 --- a/torch/fx/experimental/optimization.py +++ b/torch/fx/experimental/optimization.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch.fx as fx from torch.fx.node import Argument, Target from torch.nn.utils.fusion import fuse_conv_bn_eval diff --git a/torch/fx/experimental/partitioner_utils.py b/torch/fx/experimental/partitioner_utils.py index d96c6b40667f..796c65a43022 100644 --- a/torch/fx/experimental/partitioner_utils.py +++ b/torch/fx/experimental/partitioner_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from enum import Enum from typing import NamedTuple, Dict, List, Set diff --git a/torch/fx/experimental/recording.py b/torch/fx/experimental/recording.py index 4bf9ebab17b3..3eeb7ad02602 100644 --- a/torch/fx/experimental/recording.py +++ b/torch/fx/experimental/recording.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import inspect import itertools diff --git a/torch/fx/experimental/refinement_types.py b/torch/fx/experimental/refinement_types.py index 762e4340f12b..a33ddf3710a4 100644 --- a/torch/fx/experimental/refinement_types.py +++ b/torch/fx/experimental/refinement_types.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs class Equality: def __init__(self, lhs, rhs): self.lhs = lhs diff --git a/torch/fx/experimental/rewriter.py b/torch/fx/experimental/rewriter.py index 85a95895f7c9..8cfb030b9f77 100644 --- a/torch/fx/experimental/rewriter.py +++ b/torch/fx/experimental/rewriter.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import ast import inspect import textwrap diff --git a/torch/fx/experimental/schema_type_annotation.py b/torch/fx/experimental/schema_type_annotation.py index a2a840408618..5c7ab78706cb 100644 --- a/torch/fx/experimental/schema_type_annotation.py +++ b/torch/fx/experimental/schema_type_annotation.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.fx import inspect diff --git a/torch/fx/experimental/shape_inference/infer_shape.py b/torch/fx/experimental/shape_inference/infer_shape.py index 3c2e0c22bd89..10f5d53712ae 100644 --- a/torch/fx/experimental/shape_inference/infer_shape.py +++ b/torch/fx/experimental/shape_inference/infer_shape.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy from collections import defaultdict diff --git a/torch/fx/experimental/sym_node.py b/torch/fx/experimental/sym_node.py index 98cba67a73a1..559c3f8ed4cd 100644 --- a/torch/fx/experimental/sym_node.py +++ b/torch/fx/experimental/sym_node.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ This file does three things: - Contains the definition of SymNode diff --git a/torch/fx/experimental/unification/core.py b/torch/fx/experimental/unification/core.py index 560ceb588924..0893c385bbc9 100644 --- a/torch/fx/experimental/unification/core.py +++ b/torch/fx/experimental/unification/core.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from collections.abc import Iterator # type: ignore[import] from functools import partial diff --git a/torch/fx/experimental/unification/match.py b/torch/fx/experimental/unification/match.py index dd459726917f..96583ef324de 100644 --- a/torch/fx/experimental/unification/match.py +++ b/torch/fx/experimental/unification/match.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from .core import unify, reify # type: ignore[attr-defined] from .variable import isvar from .utils import _toposort, freeze diff --git a/torch/fx/experimental/unification/more.py b/torch/fx/experimental/unification/more.py index 2b074235f14a..2228448a71a1 100644 --- a/torch/fx/experimental/unification/more.py +++ b/torch/fx/experimental/unification/more.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from .core import unify, reify # type: ignore[attr-defined] from .dispatch import dispatch diff --git a/torch/fx/experimental/unification/multipledispatch/conflict.py b/torch/fx/experimental/unification/multipledispatch/conflict.py index 6c247bd98111..7187330ead25 100644 --- a/torch/fx/experimental/unification/multipledispatch/conflict.py +++ b/torch/fx/experimental/unification/multipledispatch/conflict.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from .utils import _toposort, groupby from .variadic import isvariadic import operator diff --git a/torch/fx/experimental/unification/multipledispatch/core.py b/torch/fx/experimental/unification/multipledispatch/core.py index 2a8ed78e52e3..5b5bdbc96301 100644 --- a/torch/fx/experimental/unification/multipledispatch/core.py +++ b/torch/fx/experimental/unification/multipledispatch/core.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import inspect import sys diff --git a/torch/fx/experimental/unification/multipledispatch/dispatcher.py b/torch/fx/experimental/unification/multipledispatch/dispatcher.py index c46e47e5d35b..a1d28201d041 100644 --- a/torch/fx/experimental/unification/multipledispatch/dispatcher.py +++ b/torch/fx/experimental/unification/multipledispatch/dispatcher.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from warnings import warn import inspect from typing_extensions import deprecated diff --git a/torch/fx/experimental/unification/multipledispatch/utils.py b/torch/fx/experimental/unification/multipledispatch/utils.py index 4b5ec2ed6315..0e90241cf69c 100644 --- a/torch/fx/experimental/unification/multipledispatch/utils.py +++ b/torch/fx/experimental/unification/multipledispatch/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from collections import OrderedDict __all__ = ["raises", "expand_tuples", "reverse_dict", "groupby", "typename"] diff --git a/torch/fx/experimental/unification/multipledispatch/variadic.py b/torch/fx/experimental/unification/multipledispatch/variadic.py index 0f046ba55bd3..49e546e1ea26 100644 --- a/torch/fx/experimental/unification/multipledispatch/variadic.py +++ b/torch/fx/experimental/unification/multipledispatch/variadic.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from .utils import typename __all__ = ["VariadicSignatureType", "isvariadic", "VariadicSignatureMeta", "Variadic"] diff --git a/torch/fx/experimental/unification/unification_tools.py b/torch/fx/experimental/unification/unification_tools.py index ae159b937ec0..472cd487f62f 100644 --- a/torch/fx/experimental/unification/unification_tools.py +++ b/torch/fx/experimental/unification/unification_tools.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections import operator from functools import reduce diff --git a/torch/fx/experimental/unification/utils.py b/torch/fx/experimental/unification/utils.py index 56cde39319e3..2147d6175136 100644 --- a/torch/fx/experimental/unification/utils.py +++ b/torch/fx/experimental/unification/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs __all__ = ["hashable", "transitive_get", "raises", "reverse_dict", "xfail", "freeze"] def hashable(x): try: diff --git a/torch/fx/experimental/unification/variable.py b/torch/fx/experimental/unification/variable.py index 8f7efda3328b..66e97a3a7663 100644 --- a/torch/fx/experimental/unification/variable.py +++ b/torch/fx/experimental/unification/variable.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from contextlib import contextmanager from .utils import hashable from .dispatch import dispatch diff --git a/torch/fx/experimental/unify_refinements.py b/torch/fx/experimental/unify_refinements.py index 532d2784fb49..cad0a33425bf 100644 --- a/torch/fx/experimental/unify_refinements.py +++ b/torch/fx/experimental/unify_refinements.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch.fx.experimental.graph_gradual_typechecker import Refine from torch.fx.tensor_type import TensorType from torch.fx.experimental.unification import Var, unify # type: ignore[attr-defined] diff --git a/torch/fx/experimental/validator.py b/torch/fx/experimental/validator.py index 6dcb59db7979..f9219fa4d551 100644 --- a/torch/fx/experimental/validator.py +++ b/torch/fx/experimental/validator.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import logging import math diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 7c73c89473d5..dea8265f134d 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from collections import defaultdict from .node import Node, Argument, Target, map_arg, _type_repr, _get_qualified_name import torch.utils._pytree as pytree diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index c5d0df29b903..5fb6691dda7c 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import copy import itertools diff --git a/torch/fx/immutable_collections.py b/torch/fx/immutable_collections.py index 7ad3807f23bb..2ff29cba474d 100644 --- a/torch/fx/immutable_collections.py +++ b/torch/fx/immutable_collections.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any, Dict, Iterable, List, Tuple from torch.utils._pytree import ( diff --git a/torch/fx/interpreter.py b/torch/fx/interpreter.py index 23c006fbbd5f..61f3a6919015 100644 --- a/torch/fx/interpreter.py +++ b/torch/fx/interpreter.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from .graph_module import GraphModule from ._lazy_graph_module import _make_graph_module from .graph import Graph diff --git a/torch/fx/operator_schemas.py b/torch/fx/operator_schemas.py index 142740a322bc..becd1ffcd6f4 100644 --- a/torch/fx/operator_schemas.py +++ b/torch/fx/operator_schemas.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import inspect import numbers diff --git a/torch/fx/passes/backends/cudagraphs.py b/torch/fx/passes/backends/cudagraphs.py index d423de930dc7..0f48165b7dab 100644 --- a/torch/fx/passes/backends/cudagraphs.py +++ b/torch/fx/passes/backends/cudagraphs.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner from torch.fx.passes.operator_support import OperatorSupport diff --git a/torch/fx/passes/dialect/common/cse_pass.py b/torch/fx/passes/dialect/common/cse_pass.py index dc95a70a22a7..577f445e7b31 100644 --- a/torch/fx/passes/dialect/common/cse_pass.py +++ b/torch/fx/passes/dialect/common/cse_pass.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Dict, Tuple, Any import torch diff --git a/torch/fx/passes/fake_tensor_prop.py b/torch/fx/passes/fake_tensor_prop.py index 58ee61f10089..04aadbbdc9b9 100644 --- a/torch/fx/passes/fake_tensor_prop.py +++ b/torch/fx/passes/fake_tensor_prop.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Optional import torch.fx diff --git a/torch/fx/passes/graph_drawer.py b/torch/fx/passes/graph_drawer.py index 7256c41dcdec..ec2336dbdeab 100644 --- a/torch/fx/passes/graph_drawer.py +++ b/torch/fx/passes/graph_drawer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import hashlib import torch diff --git a/torch/fx/passes/graph_manipulation.py b/torch/fx/passes/graph_manipulation.py index f6e53f0e969a..36c59cb31af0 100644 --- a/torch/fx/passes/graph_manipulation.py +++ b/torch/fx/passes/graph_manipulation.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any, Dict, List, NamedTuple, Optional import torch diff --git a/torch/fx/passes/graph_transform_observer.py b/torch/fx/passes/graph_transform_observer.py index a2ec324f512c..83975a930115 100644 --- a/torch/fx/passes/graph_transform_observer.py +++ b/torch/fx/passes/graph_transform_observer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import os from typing import Optional diff --git a/torch/fx/passes/infra/partitioner.py b/torch/fx/passes/infra/partitioner.py index 3952bb652517..095be545eb54 100644 --- a/torch/fx/passes/infra/partitioner.py +++ b/torch/fx/passes/infra/partitioner.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch.fx.passes.utils.fuser_utils import fuse_by_partitions import collections import itertools diff --git a/torch/fx/passes/infra/pass_base.py b/torch/fx/passes/infra/pass_base.py index dd699ea86cde..488450ab24ec 100644 --- a/torch/fx/passes/infra/pass_base.py +++ b/torch/fx/passes/infra/pass_base.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import abc from collections import namedtuple from typing import Optional diff --git a/torch/fx/passes/infra/pass_manager.py b/torch/fx/passes/infra/pass_manager.py index 44de7fcc0b1b..fcf0499b9dd1 100644 --- a/torch/fx/passes/infra/pass_manager.py +++ b/torch/fx/passes/infra/pass_manager.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import inspect import logging from queue import Queue diff --git a/torch/fx/passes/net_min_base.py b/torch/fx/passes/net_min_base.py index 6d050c78f754..e250dd09a121 100644 --- a/torch/fx/passes/net_min_base.py +++ b/torch/fx/passes/net_min_base.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Tuple diff --git a/torch/fx/passes/operator_support.py b/torch/fx/passes/operator_support.py index ce050f046eea..8edd3c746dbb 100644 --- a/torch/fx/passes/operator_support.py +++ b/torch/fx/passes/operator_support.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import abc import typing as t diff --git a/torch/fx/passes/pass_manager.py b/torch/fx/passes/pass_manager.py index 55d5ea0af54d..b90f338f303d 100644 --- a/torch/fx/passes/pass_manager.py +++ b/torch/fx/passes/pass_manager.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from functools import wraps from inspect import unwrap from typing import Callable, List, Optional diff --git a/torch/fx/passes/reinplace.py b/torch/fx/passes/reinplace.py index 6f6014b1c2af..535c63aa1bad 100644 --- a/torch/fx/passes/reinplace.py +++ b/torch/fx/passes/reinplace.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.fx import Node from torch.fx._compatibility import compatibility diff --git a/torch/fx/passes/runtime_assert.py b/torch/fx/passes/runtime_assert.py index 05e7f31ffb4e..66b8fbe29d9f 100644 --- a/torch/fx/passes/runtime_assert.py +++ b/torch/fx/passes/runtime_assert.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging import operator from typing import Any, Dict, Optional, Set, TYPE_CHECKING diff --git a/torch/fx/passes/split_module.py b/torch/fx/passes/split_module.py index 977741cfe62d..093d7e4071d0 100644 --- a/torch/fx/passes/split_module.py +++ b/torch/fx/passes/split_module.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import inspect from typing import Any, Callable, Dict, List, Optional, Set from collections import OrderedDict diff --git a/torch/fx/passes/split_utils.py b/torch/fx/passes/split_utils.py index 1282081af67b..38aa56064db6 100644 --- a/torch/fx/passes/split_utils.py +++ b/torch/fx/passes/split_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy from dataclasses import dataclass, field from typing import Dict, List, Optional, Tuple, Type, Union diff --git a/torch/fx/passes/splitter_base.py b/torch/fx/passes/splitter_base.py index b37f8ecf1d0c..f4aa439b409d 100644 --- a/torch/fx/passes/splitter_base.py +++ b/torch/fx/passes/splitter_base.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import argparse import copy from collections import defaultdict diff --git a/torch/fx/passes/tools_common.py b/torch/fx/passes/tools_common.py index 7dc757a9c0e5..aac071ace8c2 100644 --- a/torch/fx/passes/tools_common.py +++ b/torch/fx/passes/tools_common.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List, Tuple, Union, Dict, Any, Set, Mapping, Optional import collections from dataclasses import dataclass diff --git a/torch/fx/passes/utils/common.py b/torch/fx/passes/utils/common.py index 3bd030337df4..ba2ae45aabf5 100644 --- a/torch/fx/passes/utils/common.py +++ b/torch/fx/passes/utils/common.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Dict, Tuple from torch.fx._compatibility import compatibility diff --git a/torch/fx/passes/utils/fuser_utils.py b/torch/fx/passes/utils/fuser_utils.py index 3423ea3dad5a..cc26dea3cc44 100644 --- a/torch/fx/passes/utils/fuser_utils.py +++ b/torch/fx/passes/utils/fuser_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy from queue import SimpleQueue from typing import List, Dict, Tuple diff --git a/torch/fx/passes/utils/matcher_utils.py b/torch/fx/passes/utils/matcher_utils.py index 00415d10fee7..a69806829875 100644 --- a/torch/fx/passes/utils/matcher_utils.py +++ b/torch/fx/passes/utils/matcher_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from dataclasses import dataclass, field from collections import defaultdict import copy diff --git a/torch/fx/passes/utils/source_matcher_utils.py b/torch/fx/passes/utils/source_matcher_utils.py index 2830f60d5eab..0f2650ea8d49 100644 --- a/torch/fx/passes/utils/source_matcher_utils.py +++ b/torch/fx/passes/utils/source_matcher_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from dataclasses import dataclass, field from torch.fx.graph import Graph from torch.fx.node import Node diff --git a/torch/fx/tensor_type.py b/torch/fx/tensor_type.py index c822a38ec78e..f59ed2d45baa 100644 --- a/torch/fx/tensor_type.py +++ b/torch/fx/tensor_type.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch.fx.experimental.unification import Var # type: ignore[attr-defined] from ._compatibility import compatibility diff --git a/torch/fx/traceback.py b/torch/fx/traceback.py index a582e03979c4..4e72a8011f63 100644 --- a/torch/fx/traceback.py +++ b/torch/fx/traceback.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import traceback from contextlib import contextmanager from typing import List, Any, Dict diff --git a/torch/hub.py b/torch/hub.py index 0ba9e25a2830..213a1290bebd 100644 --- a/torch/hub.py +++ b/torch/hub.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import errno import hashlib diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py index a5b9f5627ea7..6d1760fb9f4f 100644 --- a/torch/jit/__init__.py +++ b/torch/jit/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import warnings from contextlib import contextmanager diff --git a/torch/jit/_async.py b/torch/jit/_async.py index 2134975bb953..bdde55adf14f 100644 --- a/torch/jit/_async.py +++ b/torch/jit/_async.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Async API. This module contains the API for parallelism in TorchScript, notably: diff --git a/torch/jit/_await.py b/torch/jit/_await.py index a79952bf3e2d..e86493512e59 100644 --- a/torch/jit/_await.py +++ b/torch/jit/_await.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._jit_internal import _Await from torch.jit._builtins import _register_builtin diff --git a/torch/jit/_builtins.py b/torch/jit/_builtins.py index f50e1bbfedb5..ecf0223cebe6 100644 --- a/torch/jit/_builtins.py +++ b/torch/jit/_builtins.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import cmath import math import warnings diff --git a/torch/jit/_check.py b/torch/jit/_check.py index 0dc2cb6d37ba..8db5bb82ce3d 100644 --- a/torch/jit/_check.py +++ b/torch/jit/_check.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import ast import inspect import textwrap diff --git a/torch/jit/_dataclass_impls.py b/torch/jit/_dataclass_impls.py index 52056ce46bea..2dc1dfba076f 100644 --- a/torch/jit/_dataclass_impls.py +++ b/torch/jit/_dataclass_impls.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Functions for synthesizing magic methods for JIT-compiled dataclasses import ast import dataclasses diff --git a/torch/jit/_decomposition_utils.py b/torch/jit/_decomposition_utils.py index fb4448e2b900..795f9da8e073 100644 --- a/torch/jit/_decomposition_utils.py +++ b/torch/jit/_decomposition_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._ops import OpOverload, OpOverloadPacket diff --git a/torch/jit/_decompositions.py b/torch/jit/_decompositions.py index babb70eaf7cb..8ac456be482b 100644 --- a/torch/jit/_decompositions.py +++ b/torch/jit/_decompositions.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch import Tensor diff --git a/torch/jit/_freeze.py b/torch/jit/_freeze.py index 731f28305628..8f35fc471e68 100644 --- a/torch/jit/_freeze.py +++ b/torch/jit/_freeze.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Freezing. This is not intended to be imported directly; please use the exposed diff --git a/torch/jit/_fuser.py b/torch/jit/_fuser.py index 253682736034..7466800402d2 100644 --- a/torch/jit/_fuser.py +++ b/torch/jit/_fuser.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib from typing import List, Tuple diff --git a/torch/jit/_ir_utils.py b/torch/jit/_ir_utils.py index 028247f54011..52b953624a3a 100644 --- a/torch/jit/_ir_utils.py +++ b/torch/jit/_ir_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Union import torch diff --git a/torch/jit/_monkeytype_config.py b/torch/jit/_monkeytype_config.py index 3b19e8438d4e..4662869e3683 100644 --- a/torch/jit/_monkeytype_config.py +++ b/torch/jit/_monkeytype_config.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import inspect import pathlib import sys diff --git a/torch/jit/_passes/_property_propagation.py b/torch/jit/_passes/_property_propagation.py index 8ebd21e4bc10..1537f7bc4147 100644 --- a/torch/jit/_passes/_property_propagation.py +++ b/torch/jit/_passes/_property_propagation.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ Tools to help with tensor property propagation. diff --git a/torch/jit/_pickle.py b/torch/jit/_pickle.py index 1cb4a0a93efd..5517499e9260 100644 --- a/torch/jit/_pickle.py +++ b/torch/jit/_pickle.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # These functions are referenced from the pickle archives produced by # ScriptModule.save() diff --git a/torch/jit/_recursive.py b/torch/jit/_recursive.py index a76a0c4a2cb0..fc37237edd30 100644 --- a/torch/jit/_recursive.py +++ b/torch/jit/_recursive.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections import functools import inspect diff --git a/torch/jit/_script.pyi b/torch/jit/_script.pyi index b43a8bc7089e..b1f39b2bc706 100644 --- a/torch/jit/_script.pyi +++ b/torch/jit/_script.pyi @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # mypy: disable-error-code="type-arg" from typing import ( Any, diff --git a/torch/jit/_serialization.py b/torch/jit/_serialization.py index 514f23cb76d3..b9b9691401d3 100644 --- a/torch/jit/_serialization.py +++ b/torch/jit/_serialization.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Serialization. This module contains functionality for serializing TorchScript modules, notably: diff --git a/torch/jit/_shape_functions.py b/torch/jit/_shape_functions.py index bef34e28239b..18b69acddc09 100644 --- a/torch/jit/_shape_functions.py +++ b/torch/jit/_shape_functions.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math from typing import Any, Callable, Dict, List, Optional, Tuple, Union diff --git a/torch/jit/_state.py b/torch/jit/_state.py index 1d75415ef80e..63df2acfdf09 100644 --- a/torch/jit/_state.py +++ b/torch/jit/_state.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """JIT-related state. This module stores various pieces of Python-global state relating to the JIT. diff --git a/torch/jit/_trace.py b/torch/jit/_trace.py index 8be700ee7711..5bdd71f94381 100644 --- a/torch/jit/_trace.py +++ b/torch/jit/_trace.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Tracing. This module contains functionality to support the JIT's tracing frontend, notably: diff --git a/torch/jit/annotations.py b/torch/jit/annotations.py index a24fad838353..76d5ce5805b6 100644 --- a/torch/jit/annotations.py +++ b/torch/jit/annotations.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import ast import builtins import dis diff --git a/torch/jit/frontend.py b/torch/jit/frontend.py index ea834f664f4f..775120a67ccb 100644 --- a/torch/jit/frontend.py +++ b/torch/jit/frontend.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import ast import dataclasses import inspect diff --git a/torch/jit/generate_bytecode.py b/torch/jit/generate_bytecode.py index 8e56c7665d1c..f66bf7bfc4c1 100644 --- a/torch/jit/generate_bytecode.py +++ b/torch/jit/generate_bytecode.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List from torch._C import _compile_graph_to_code_table, _generate_upgraders_graph diff --git a/torch/jit/mobile/__init__.py b/torch/jit/mobile/__init__.py index 63632de23d3f..ba29b31bccc5 100644 --- a/torch/jit/mobile/__init__.py +++ b/torch/jit/mobile/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import os import torch diff --git a/torch/jit/quantized.py b/torch/jit/quantized.py index c7c679c79456..a2500c1f1b9f 100644 --- a/torch/jit/quantized.py +++ b/torch/jit/quantized.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch diff --git a/torch/jit/supported_ops.py b/torch/jit/supported_ops.py index c06664a6cff2..3bfec99feb17 100644 --- a/torch/jit/supported_ops.py +++ b/torch/jit/supported_ops.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import inspect import textwrap diff --git a/torch/jit/unsupported_tensor_ops.py b/torch/jit/unsupported_tensor_ops.py index 4e553757eab4..f8c9be4f5b06 100644 --- a/torch/jit/unsupported_tensor_ops.py +++ b/torch/jit/unsupported_tensor_ops.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from textwrap import dedent from typing import Any, Dict diff --git a/torch/library.py b/torch/library.py index da8c5a1264a2..d0a4cf24f088 100644 --- a/torch/library.py +++ b/torch/library.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from ._ops import OpOverload from typing import Any, Optional, Set, List, Union, Callable, Tuple, Dict, Sequence from typing_extensions import deprecated diff --git a/torch/masked/_ops.py b/torch/masked/_ops.py index 0c082f7cd01f..26094459c171 100644 --- a/torch/masked/_ops.py +++ b/torch/masked/_ops.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import warnings from typing import Any, List, Optional, Tuple, TYPE_CHECKING, Union diff --git a/torch/masked/maskedtensor/_ops_refs.py b/torch/masked/maskedtensor/_ops_refs.py index 7544fc84ff9f..802c52aecafd 100644 --- a/torch/masked/maskedtensor/_ops_refs.py +++ b/torch/masked/maskedtensor/_ops_refs.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates from functools import partial diff --git a/torch/masked/maskedtensor/binary.py b/torch/masked/maskedtensor/binary.py index b035678f73a6..7b64cfa0fbd9 100644 --- a/torch/masked/maskedtensor/binary.py +++ b/torch/masked/maskedtensor/binary.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates import torch diff --git a/torch/masked/maskedtensor/core.py b/torch/masked/maskedtensor/core.py index 4574fed9c0d6..0933a804fcc7 100644 --- a/torch/masked/maskedtensor/core.py +++ b/torch/masked/maskedtensor/core.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates import warnings diff --git a/torch/masked/maskedtensor/creation.py b/torch/masked/maskedtensor/creation.py index 6b490edfc058..a013ef1beb66 100644 --- a/torch/masked/maskedtensor/creation.py +++ b/torch/masked/maskedtensor/creation.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates from .core import MaskedTensor diff --git a/torch/masked/maskedtensor/passthrough.py b/torch/masked/maskedtensor/passthrough.py index d8c87a9c2110..4a2e79456c86 100644 --- a/torch/masked/maskedtensor/passthrough.py +++ b/torch/masked/maskedtensor/passthrough.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates """ These are functions that should simply be applied to both mask and data. diff --git a/torch/masked/maskedtensor/reductions.py b/torch/masked/maskedtensor/reductions.py index d36df2715c0b..fedab1c12a63 100644 --- a/torch/masked/maskedtensor/reductions.py +++ b/torch/masked/maskedtensor/reductions.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates import warnings diff --git a/torch/masked/maskedtensor/unary.py b/torch/masked/maskedtensor/unary.py index 4bfe987ef004..790d86ef92e4 100644 --- a/torch/masked/maskedtensor/unary.py +++ b/torch/masked/maskedtensor/unary.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates import torch diff --git a/torch/mps/__init__.py b/torch/mps/__init__.py index 6118c2b05686..0538ae50d1ad 100644 --- a/torch/mps/__init__.py +++ b/torch/mps/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r""" This package enables an interface for accessing MPS (Metal Performance Shaders) backend in Python. Metal is Apple's API for programming metal GPU (graphics processor unit). Using MPS means that increased diff --git a/torch/mps/event.py b/torch/mps/event.py index a206b640ef4a..d619c027480c 100644 --- a/torch/mps/event.py +++ b/torch/mps/event.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch diff --git a/torch/mps/profiler.py b/torch/mps/profiler.py index 9094a275136c..d9ca3f55c5e6 100644 --- a/torch/mps/profiler.py +++ b/torch/mps/profiler.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import torch diff --git a/torch/mtia/__init__.py b/torch/mtia/__init__.py index 4007f0e584f2..b68a25bdb61b 100644 --- a/torch/mtia/__init__.py +++ b/torch/mtia/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r""" This package enables an interface for accessing MTIA backend in python """ diff --git a/torch/multiprocessing/__init__.py b/torch/multiprocessing/__init__.py index 8cbb1fb07ff8..5d69bc7daa1a 100644 --- a/torch/multiprocessing/__init__.py +++ b/torch/multiprocessing/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """torch.multiprocessing is a wrapper around the native :mod:`multiprocessing` module. It registers custom reducers, that use shared memory to provide shared diff --git a/torch/multiprocessing/_atfork.py b/torch/multiprocessing/_atfork.py index 92a3280fee78..37ebe377838d 100644 --- a/torch/multiprocessing/_atfork.py +++ b/torch/multiprocessing/_atfork.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import sys __all__ = ["register_after_fork"] diff --git a/torch/multiprocessing/queue.py b/torch/multiprocessing/queue.py index 99da145e75f1..876bf8d0e745 100644 --- a/torch/multiprocessing/queue.py +++ b/torch/multiprocessing/queue.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import io import multiprocessing.queues import pickle diff --git a/torch/multiprocessing/reductions.py b/torch/multiprocessing/reductions.py index f5eb0a6abd86..9de36c39d7b5 100644 --- a/torch/multiprocessing/reductions.py +++ b/torch/multiprocessing/reductions.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import multiprocessing import os import threading diff --git a/torch/multiprocessing/spawn.py b/torch/multiprocessing/spawn.py index 88bdc5155342..408a3908cf45 100644 --- a/torch/multiprocessing/spawn.py +++ b/torch/multiprocessing/spawn.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging import multiprocessing import multiprocessing.connection diff --git a/torch/nested/__init__.py b/torch/nested/__init__.py index ea1cce595011..0a12e14e1aff 100644 --- a/torch/nested/__init__.py +++ b/torch/nested/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List, Optional, Tuple, Union import torch diff --git a/torch/nested/_internal/nested_tensor.py b/torch/nested/_internal/nested_tensor.py index 5ef8983a8393..66d25eacc7ad 100644 --- a/torch/nested/_internal/nested_tensor.py +++ b/torch/nested/_internal/nested_tensor.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Tuple import torch diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py index 85f62170595c..f900a9a9ab01 100644 --- a/torch/nested/_internal/ops.py +++ b/torch/nested/_internal/ops.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import math import operator diff --git a/torch/nested/_internal/sdpa.py b/torch/nested/_internal/sdpa.py index c393fb1bf357..b7c69c905e9a 100644 --- a/torch/nested/_internal/sdpa.py +++ b/torch/nested/_internal/sdpa.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging from typing import Optional, Tuple diff --git a/torch/nn/__init__.py b/torch/nn/__init__.py index 3d317b7c09f2..23447d484409 100644 --- a/torch/nn/__init__.py +++ b/torch/nn/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from .modules import * # noqa: F403 from .parameter import ( Parameter as Parameter, diff --git a/torch/nn/attention/__init__.py b/torch/nn/attention/__init__.py index 039d76a32f4b..6bf1ffb68e69 100644 --- a/torch/nn/attention/__init__.py +++ b/torch/nn/attention/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ This module contains functions and classes that alter the behavior of torch.nn.functional.scaled_dot_product_attention """ import contextlib from typing import List, Union diff --git a/torch/nn/attention/_flex_attention.py b/torch/nn/attention/_flex_attention.py index 430d3280442a..06ddd7c3dc2f 100644 --- a/torch/nn/attention/_flex_attention.py +++ b/torch/nn/attention/_flex_attention.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """This module implements the user facing API for flex_attention in PyTorch.""" import functools from typing import Callable diff --git a/torch/nn/attention/_utils.py b/torch/nn/attention/_utils.py index 6662eb58f361..9785f74c6683 100644 --- a/torch/nn/attention/_utils.py +++ b/torch/nn/attention/_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Defines utilities for interacting with scaled_dot_product_attention""" import math from typing import List, Optional diff --git a/torch/nn/attention/bias.py b/torch/nn/attention/bias.py index c7f6b41d660c..773ed38f82e8 100644 --- a/torch/nn/attention/bias.py +++ b/torch/nn/attention/bias.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Defines bias subclasses that work with scaled_dot_product_attention""" from enum import auto, IntEnum from typing import Optional From 27f9d3b0a17289500f1a2d24e4901a1e0fa9ea95 Mon Sep 17 00:00:00 2001 From: Aaron Orenstein Date: Sat, 8 Jun 2024 11:41:15 -0700 Subject: [PATCH 543/706] Flip default value for mypy disallow_untyped_defs [8/11] (#127845) See #127836 for details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127845 Approved by: https://github.com/oulgen ghstack dependencies: #127842, #127843, #127844 --- torch/nn/backends/thnn.py | 1 + torch/nn/cpp.py | 1 + torch/nn/grad.py | 1 + torch/nn/init.py | 1 + torch/nn/modules/_functions.py | 1 + torch/nn/modules/activation.py | 1 + torch/nn/modules/adaptive.py | 1 + torch/nn/modules/batchnorm.py | 1 + torch/nn/modules/container.py | 1 + torch/nn/modules/conv.py | 1 + torch/nn/modules/flatten.py | 1 + torch/nn/modules/instancenorm.py | 1 + torch/nn/modules/lazy.py | 1 + torch/nn/modules/linear.py | 1 + torch/nn/modules/loss.py | 1 + torch/nn/modules/module.py | 1 + torch/nn/modules/normalization.py | 1 + torch/nn/modules/padding.py | 1 + torch/nn/modules/rnn.py | 1 + torch/nn/modules/sparse.py | 1 + torch/nn/modules/transformer.py | 1 + torch/nn/modules/upsampling.py | 1 + torch/nn/modules/utils.py | 1 + torch/nn/parallel/__init__.py | 1 + torch/nn/parallel/comm.py | 1 + torch/nn/parallel/data_parallel.py | 1 + torch/nn/parallel/distributed.py | 1 + torch/nn/parallel/scatter_gather.py | 1 + torch/nn/parameter.pyi | 1 + torch/nn/utils/_deprecation_utils.py | 1 + torch/nn/utils/_expanded_weights/conv_expanded_weights.py | 1 + torch/nn/utils/_expanded_weights/conv_utils.py | 1 + torch/nn/utils/_expanded_weights/embedding_expanded_weights.py | 1 + torch/nn/utils/_expanded_weights/expanded_weights_impl.py | 1 + torch/nn/utils/_expanded_weights/expanded_weights_utils.py | 1 + torch/nn/utils/_expanded_weights/group_norm_expanded_weights.py | 1 + .../nn/utils/_expanded_weights/instance_norm_expanded_weights.py | 1 + torch/nn/utils/_expanded_weights/layer_norm_expanded_weights.py | 1 + torch/nn/utils/_expanded_weights/linear_expanded_weights.py | 1 + torch/nn/utils/_per_sample_grad.py | 1 + torch/nn/utils/clip_grad.py | 1 + torch/nn/utils/init.py | 1 + torch/nn/utils/memory_format.py | 1 + torch/nn/utils/parametrizations.py | 1 + torch/nn/utils/parametrize.py | 1 + torch/nn/utils/prune.py | 1 + torch/nn/utils/rnn.pyi | 1 + torch/nn/utils/spectral_norm.py | 1 + torch/nn/utils/stateless.py | 1 + torch/nn/utils/weight_norm.py | 1 + torch/onnx/__init__.py | 1 + torch/onnx/_deprecation.py | 1 + torch/onnx/_globals.py | 1 + torch/onnx/_internal/_beartype.py | 1 + torch/onnx/_internal/diagnostics/_diagnostic.py | 1 + torch/onnx/_internal/diagnostics/_rules.py | 1 + torch/onnx/_internal/diagnostics/infra/_infra.py | 1 + torch/onnx/_internal/diagnostics/infra/context.py | 1 + torch/onnx/_internal/diagnostics/infra/decorator.py | 1 + torch/onnx/_internal/exporter.py | 1 + torch/onnx/_internal/fx/_pass.py | 1 + torch/onnx/_internal/fx/analysis/unsupported_nodes.py | 1 + torch/onnx/_internal/fx/decomposition_skip.py | 1 + torch/onnx/_internal/fx/decomposition_table.py | 1 + torch/onnx/_internal/fx/diagnostics.py | 1 + torch/onnx/_internal/fx/dynamo_graph_extractor.py | 1 + torch/onnx/_internal/fx/fx_onnx_interpreter.py | 1 + torch/onnx/_internal/fx/fx_symbolic_graph_extractor.py | 1 + torch/onnx/_internal/fx/onnxfunction_dispatcher.py | 1 + torch/onnx/_internal/fx/op_validation.py | 1 + torch/onnx/_internal/fx/passes/_utils.py | 1 + torch/onnx/_internal/fx/passes/decomp.py | 1 + torch/onnx/_internal/fx/passes/functionalization.py | 1 + torch/onnx/_internal/fx/passes/modularization.py | 1 + torch/onnx/_internal/fx/passes/readability.py | 1 + torch/onnx/_internal/fx/passes/type_promotion.py | 1 + torch/onnx/_internal/fx/passes/virtualization.py | 1 + torch/onnx/_internal/fx/patcher.py | 1 + torch/onnx/_internal/fx/serialization.py | 1 + torch/onnx/_internal/fx/torch_export_graph_extractor.py | 1 + torch/onnx/_internal/fx/type_utils.py | 1 + torch/onnx/_internal/io_adapter.py | 1 + torch/onnx/_internal/jit_utils.py | 1 + torch/onnx/_internal/onnx_proto_utils.py | 1 + torch/onnx/_internal/onnxruntime.py | 1 + torch/onnx/_internal/registration.py | 1 + torch/onnx/_onnx_supported_ops.py | 1 + torch/onnx/_type_utils.py | 1 + torch/onnx/operators.py | 1 + torch/onnx/symbolic_caffe2.py | 1 + torch/onnx/symbolic_helper.py | 1 + torch/onnx/symbolic_opset10.py | 1 + torch/onnx/symbolic_opset11.py | 1 + torch/onnx/symbolic_opset12.py | 1 + torch/onnx/symbolic_opset13.py | 1 + torch/onnx/symbolic_opset14.py | 1 + torch/onnx/symbolic_opset15.py | 1 + torch/onnx/symbolic_opset16.py | 1 + torch/onnx/symbolic_opset17.py | 1 + torch/onnx/symbolic_opset18.py | 1 + torch/onnx/symbolic_opset20.py | 1 + torch/onnx/symbolic_opset7.py | 1 + torch/onnx/symbolic_opset8.py | 1 + torch/onnx/symbolic_opset9.py | 1 + torch/onnx/utils.py | 1 + torch/onnx/verification.py | 1 + torch/optim/_functional.py | 1 + torch/optim/adadelta.py | 1 + torch/optim/adagrad.py | 1 + torch/optim/adam.py | 1 + torch/optim/adamax.py | 1 + torch/optim/adamw.py | 1 + torch/optim/asgd.py | 1 + torch/optim/lbfgs.py | 1 + torch/optim/lr_scheduler.py | 1 + torch/optim/nadam.py | 1 + torch/optim/optimizer.py | 1 + torch/optim/radam.py | 1 + torch/optim/rmsprop.py | 1 + torch/optim/rprop.py | 1 + torch/optim/sgd.py | 1 + torch/optim/sparse_adam.py | 1 + torch/optim/swa_utils.py | 1 + torch/package/_digraph.py | 1 + torch/package/_directory_reader.py | 1 + torch/package/_importlib.py | 1 + torch/package/_mangling.py | 1 + torch/package/_mock.py | 1 + 128 files changed, 128 insertions(+) diff --git a/torch/nn/backends/thnn.py b/torch/nn/backends/thnn.py index 5250b4bff167..3cb0f3ff57e2 100644 --- a/torch/nn/backends/thnn.py +++ b/torch/nn/backends/thnn.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # this is for historical pickle deserialization, it is not used otherwise def _get_thnn_function_backend(): diff --git a/torch/nn/cpp.py b/torch/nn/cpp.py index a08c7b314100..98a61bfb7c42 100644 --- a/torch/nn/cpp.py +++ b/torch/nn/cpp.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Functionality for Python <-> C++ frontend inter-op.""" from torch import nn diff --git a/torch/nn/grad.py b/torch/nn/grad.py index 660c87fb4133..dbd38fcdd38c 100644 --- a/torch/nn/grad.py +++ b/torch/nn/grad.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Gradient interface.""" import torch diff --git a/torch/nn/init.py b/torch/nn/init.py index f5be081e7dd0..b3179abb4937 100644 --- a/torch/nn/init.py +++ b/torch/nn/init.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """This file contains utilities for initializing neural network parameters.""" import math import warnings diff --git a/torch/nn/modules/_functions.py b/torch/nn/modules/_functions.py index 669448ce4fda..0e19faa99e5c 100644 --- a/torch/nn/modules/_functions.py +++ b/torch/nn/modules/_functions.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.distributed as dist diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py index 5dec6f9578b1..3d8b65175956 100644 --- a/torch/nn/modules/activation.py +++ b/torch/nn/modules/activation.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import warnings from typing import Optional, Tuple diff --git a/torch/nn/modules/adaptive.py b/torch/nn/modules/adaptive.py index 83b37696c8a7..a6c2da5f596f 100644 --- a/torch/nn/modules/adaptive.py +++ b/torch/nn/modules/adaptive.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from collections import namedtuple diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py index 3c48e56d5e6e..75c8b5504d46 100644 --- a/torch/nn/modules/batchnorm.py +++ b/torch/nn/modules/batchnorm.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Optional, Any import torch diff --git a/torch/nn/modules/container.py b/torch/nn/modules/container.py index 775a826d69cc..c82d8d7d3037 100644 --- a/torch/nn/modules/container.py +++ b/torch/nn/modules/container.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from collections import OrderedDict, abc as container_abcs from itertools import chain, islice import operator diff --git a/torch/nn/modules/conv.py b/torch/nn/modules/conv.py index 4ab4c8bff9fc..fb6a1557aa71 100644 --- a/torch/nn/modules/conv.py +++ b/torch/nn/modules/conv.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math import torch diff --git a/torch/nn/modules/flatten.py b/torch/nn/modules/flatten.py index eaf62d5bbeea..f1c44fd350d1 100644 --- a/torch/nn/modules/flatten.py +++ b/torch/nn/modules/flatten.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from .module import Module from typing import Tuple, Union diff --git a/torch/nn/modules/instancenorm.py b/torch/nn/modules/instancenorm.py index ae187e98b7e6..e6a3e1c0a3a1 100644 --- a/torch/nn/modules/instancenorm.py +++ b/torch/nn/modules/instancenorm.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import warnings from torch import Tensor diff --git a/torch/nn/modules/lazy.py b/torch/nn/modules/lazy.py index c4b7459c4acd..f4be1b7db706 100644 --- a/torch/nn/modules/lazy.py +++ b/torch/nn/modules/lazy.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import itertools from typing import Protocol, Optional, Type, Any diff --git a/torch/nn/modules/linear.py b/torch/nn/modules/linear.py index 54981596f7ee..be2739462399 100644 --- a/torch/nn/modules/linear.py +++ b/torch/nn/modules/linear.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math from typing import Any diff --git a/torch/nn/modules/loss.py b/torch/nn/modules/loss.py index 4324c1df144d..497da8218506 100644 --- a/torch/nn/modules/loss.py +++ b/torch/nn/modules/loss.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from .distance import PairwiseDistance from .module import Module from .. import functional as F diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index ffd429cc06f2..f803d3f02a17 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from collections import OrderedDict, namedtuple import itertools import warnings diff --git a/torch/nn/modules/normalization.py b/torch/nn/modules/normalization.py index 97c9c307c5d9..d503409d53a1 100644 --- a/torch/nn/modules/normalization.py +++ b/torch/nn/modules/normalization.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import numbers from torch.nn.parameter import Parameter diff --git a/torch/nn/modules/padding.py b/torch/nn/modules/padding.py index 0aecca58c305..4b29fbf1c8f4 100644 --- a/torch/nn/modules/padding.py +++ b/torch/nn/modules/padding.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from .module import Module from .utils import _pair, _quadruple, _ntuple from .. import functional as F diff --git a/torch/nn/modules/rnn.py b/torch/nn/modules/rnn.py index b4bdd7824474..8ba4f9f08319 100644 --- a/torch/nn/modules/rnn.py +++ b/torch/nn/modules/rnn.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math import warnings import numbers diff --git a/torch/nn/modules/sparse.py b/torch/nn/modules/sparse.py index f053a0c8f3c2..512b17d03222 100644 --- a/torch/nn/modules/sparse.py +++ b/torch/nn/modules/sparse.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Optional import torch diff --git a/torch/nn/modules/transformer.py b/torch/nn/modules/transformer.py index 3c9a8547df32..f5980cd6b1e8 100644 --- a/torch/nn/modules/transformer.py +++ b/torch/nn/modules/transformer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy from typing import Optional, Any, Union, Callable diff --git a/torch/nn/modules/upsampling.py b/torch/nn/modules/upsampling.py index da9b23add18d..7d674da0d5c3 100644 --- a/torch/nn/modules/upsampling.py +++ b/torch/nn/modules/upsampling.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from .module import Module from .. import functional as F diff --git a/torch/nn/modules/utils.py b/torch/nn/modules/utils.py index 019dabe3e533..4a051ed1eba5 100644 --- a/torch/nn/modules/utils.py +++ b/torch/nn/modules/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections from itertools import repeat from typing import List, Dict, Any diff --git a/torch/nn/parallel/__init__.py b/torch/nn/parallel/__init__.py index adcd6bd838eb..8f08e5099d8b 100644 --- a/torch/nn/parallel/__init__.py +++ b/torch/nn/parallel/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing_extensions import deprecated from .parallel_apply import parallel_apply diff --git a/torch/nn/parallel/comm.py b/torch/nn/parallel/comm.py index 22cf80bd64e2..b907de4004b1 100644 --- a/torch/nn/parallel/comm.py +++ b/torch/nn/parallel/comm.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import warnings import torch from torch.cuda import nccl diff --git a/torch/nn/parallel/data_parallel.py b/torch/nn/parallel/data_parallel.py index 4471cee6f379..3980706a932a 100644 --- a/torch/nn/parallel/data_parallel.py +++ b/torch/nn/parallel/data_parallel.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import operator import torch import warnings diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py index 37b6501d2c9c..c71c838cfb85 100644 --- a/torch/nn/parallel/distributed.py +++ b/torch/nn/parallel/distributed.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import functools import inspect diff --git a/torch/nn/parallel/scatter_gather.py b/torch/nn/parallel/scatter_gather.py index f6fb9d47ecbf..73e753760e72 100644 --- a/torch/nn/parallel/scatter_gather.py +++ b/torch/nn/parallel/scatter_gather.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from typing import Any, Dict, List, Optional, Sequence, Tuple, TypeVar, Union, overload from typing_extensions import deprecated diff --git a/torch/nn/parameter.pyi b/torch/nn/parameter.pyi index 219bb6d4efa2..221ffacc3520 100644 --- a/torch/nn/parameter.pyi +++ b/torch/nn/parameter.pyi @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import builtins from typing import Optional, Tuple diff --git a/torch/nn/utils/_deprecation_utils.py b/torch/nn/utils/_deprecation_utils.py index 1b2a9b6e29f2..9910db96e66c 100644 --- a/torch/nn/utils/_deprecation_utils.py +++ b/torch/nn/utils/_deprecation_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List, Callable import importlib import warnings diff --git a/torch/nn/utils/_expanded_weights/conv_expanded_weights.py b/torch/nn/utils/_expanded_weights/conv_expanded_weights.py index c10ccb90ae92..147346796d1f 100644 --- a/torch/nn/utils/_expanded_weights/conv_expanded_weights.py +++ b/torch/nn/utils/_expanded_weights/conv_expanded_weights.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.nn.functional as F diff --git a/torch/nn/utils/_expanded_weights/conv_utils.py b/torch/nn/utils/_expanded_weights/conv_utils.py index b675e3b892bd..2836809d40be 100644 --- a/torch/nn/utils/_expanded_weights/conv_utils.py +++ b/torch/nn/utils/_expanded_weights/conv_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.nn.functional as F diff --git a/torch/nn/utils/_expanded_weights/embedding_expanded_weights.py b/torch/nn/utils/_expanded_weights/embedding_expanded_weights.py index c7956a3a1b1f..593fa9e5eed7 100644 --- a/torch/nn/utils/_expanded_weights/embedding_expanded_weights.py +++ b/torch/nn/utils/_expanded_weights/embedding_expanded_weights.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.nn.functional as F from .expanded_weights_impl import implements_per_sample_grads diff --git a/torch/nn/utils/_expanded_weights/expanded_weights_impl.py b/torch/nn/utils/_expanded_weights/expanded_weights_impl.py index 94e6041c6de5..664e65cc7d90 100644 --- a/torch/nn/utils/_expanded_weights/expanded_weights_impl.py +++ b/torch/nn/utils/_expanded_weights/expanded_weights_impl.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from contextlib import contextmanager import torch diff --git a/torch/nn/utils/_expanded_weights/expanded_weights_utils.py b/torch/nn/utils/_expanded_weights/expanded_weights_utils.py index 249dbe591204..840be6a163f5 100644 --- a/torch/nn/utils/_expanded_weights/expanded_weights_utils.py +++ b/torch/nn/utils/_expanded_weights/expanded_weights_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Optional import torch diff --git a/torch/nn/utils/_expanded_weights/group_norm_expanded_weights.py b/torch/nn/utils/_expanded_weights/group_norm_expanded_weights.py index fe29b1eafbe2..6e2919803e4f 100644 --- a/torch/nn/utils/_expanded_weights/group_norm_expanded_weights.py +++ b/torch/nn/utils/_expanded_weights/group_norm_expanded_weights.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from functools import reduce import operator import torch diff --git a/torch/nn/utils/_expanded_weights/instance_norm_expanded_weights.py b/torch/nn/utils/_expanded_weights/instance_norm_expanded_weights.py index f3e68b940660..1d0f40c54081 100644 --- a/torch/nn/utils/_expanded_weights/instance_norm_expanded_weights.py +++ b/torch/nn/utils/_expanded_weights/instance_norm_expanded_weights.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from functools import partial import torch import torch.nn.functional as F diff --git a/torch/nn/utils/_expanded_weights/layer_norm_expanded_weights.py b/torch/nn/utils/_expanded_weights/layer_norm_expanded_weights.py index f2ead2d4c08f..b18c284cd7cf 100644 --- a/torch/nn/utils/_expanded_weights/layer_norm_expanded_weights.py +++ b/torch/nn/utils/_expanded_weights/layer_norm_expanded_weights.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.nn.functional as F diff --git a/torch/nn/utils/_expanded_weights/linear_expanded_weights.py b/torch/nn/utils/_expanded_weights/linear_expanded_weights.py index c2cbae63f336..6a80c1dc9219 100644 --- a/torch/nn/utils/_expanded_weights/linear_expanded_weights.py +++ b/torch/nn/utils/_expanded_weights/linear_expanded_weights.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.nn.functional as F from .expanded_weights_impl import implements_per_sample_grads diff --git a/torch/nn/utils/_per_sample_grad.py b/torch/nn/utils/_per_sample_grad.py index 0644ab5d2535..a64942083f0c 100644 --- a/torch/nn/utils/_per_sample_grad.py +++ b/torch/nn/utils/_per_sample_grad.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import torch diff --git a/torch/nn/utils/clip_grad.py b/torch/nn/utils/clip_grad.py index 4ac8a4e7445b..cc83353909f9 100644 --- a/torch/nn/utils/clip_grad.py +++ b/torch/nn/utils/clip_grad.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools from typing import Union, Iterable, List, Dict, Tuple, Optional, cast from typing_extensions import deprecated diff --git a/torch/nn/utils/init.py b/torch/nn/utils/init.py index 416ad0db8ef7..4768d3009005 100644 --- a/torch/nn/utils/init.py +++ b/torch/nn/utils/init.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import inspect import torch diff --git a/torch/nn/utils/memory_format.py b/torch/nn/utils/memory_format.py index c8fc22bea51c..aaa2b6bfb198 100644 --- a/torch/nn/utils/memory_format.py +++ b/torch/nn/utils/memory_format.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch diff --git a/torch/nn/utils/parametrizations.py b/torch/nn/utils/parametrizations.py index f9b25bcac0cb..cf686504072f 100644 --- a/torch/nn/utils/parametrizations.py +++ b/torch/nn/utils/parametrizations.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from enum import Enum, auto import torch diff --git a/torch/nn/utils/parametrize.py b/torch/nn/utils/parametrize.py index f512b7c3b22a..b828c1d230f1 100644 --- a/torch/nn/utils/parametrize.py +++ b/torch/nn/utils/parametrize.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.__future__ import get_swap_module_params_on_conversion from torch.nn.modules.container import ModuleList, ModuleDict, Module diff --git a/torch/nn/utils/prune.py b/torch/nn/utils/prune.py index 0375106d69e0..b0e1f99a6c1f 100644 --- a/torch/nn/utils/prune.py +++ b/torch/nn/utils/prune.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r"""Pruning methods.""" import numbers from abc import ABC, abstractmethod diff --git a/torch/nn/utils/rnn.pyi b/torch/nn/utils/rnn.pyi index fd033d8888be..9ffc650714ff 100644 --- a/torch/nn/utils/rnn.pyi +++ b/torch/nn/utils/rnn.pyi @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any, Iterable, NamedTuple, Optional, overload, Sequence, Tuple, Union from typing_extensions import Self diff --git a/torch/nn/utils/spectral_norm.py b/torch/nn/utils/spectral_norm.py index bda54b9a1222..fcc4bbf5fe29 100644 --- a/torch/nn/utils/spectral_norm.py +++ b/torch/nn/utils/spectral_norm.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Spectral Normalization from https://arxiv.org/abs/1802.05957.""" import torch from torch.nn.functional import normalize diff --git a/torch/nn/utils/stateless.py b/torch/nn/utils/stateless.py index 660a1a484ebb..07b03c04a120 100644 --- a/torch/nn/utils/stateless.py +++ b/torch/nn/utils/stateless.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib from collections import defaultdict from typing import Any, Dict, Iterator, Optional, Set, Tuple, Union diff --git a/torch/nn/utils/weight_norm.py b/torch/nn/utils/weight_norm.py index 6cfe4b3e526d..abb21a7b4672 100644 --- a/torch/nn/utils/weight_norm.py +++ b/torch/nn/utils/weight_norm.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r"""Weight Normalization from https://arxiv.org/abs/1602.07868.""" from torch.nn.parameter import Parameter, UninitializedParameter from torch import _weight_norm, norm_except_dim diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py index 4d16ef09c8b3..2b2f2bdae0de 100644 --- a/torch/onnx/__init__.py +++ b/torch/onnx/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch import _C from torch._C import _onnx as _C_onnx from torch._C._onnx import ( diff --git a/torch/onnx/_deprecation.py b/torch/onnx/_deprecation.py index 0fd2cd764fc9..1f78dd55bd5d 100644 --- a/torch/onnx/_deprecation.py +++ b/torch/onnx/_deprecation.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Utility for deprecating functions.""" import functools diff --git a/torch/onnx/_globals.py b/torch/onnx/_globals.py index f827d12be7fb..22c05075dba8 100644 --- a/torch/onnx/_globals.py +++ b/torch/onnx/_globals.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Globals used internally by the ONNX exporter. Do not use this module outside of `torch.onnx` and its tests. diff --git a/torch/onnx/_internal/_beartype.py b/torch/onnx/_internal/_beartype.py index 25e1c1cb7299..1e5006fb56c1 100644 --- a/torch/onnx/_internal/_beartype.py +++ b/torch/onnx/_internal/_beartype.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """An internal wrapper for the beartype library. The module returns a no-op decorator when the beartype library is not installed. diff --git a/torch/onnx/_internal/diagnostics/_diagnostic.py b/torch/onnx/_internal/diagnostics/_diagnostic.py index 09079d5e9c4a..e5b22b07539c 100644 --- a/torch/onnx/_internal/diagnostics/_diagnostic.py +++ b/torch/onnx/_internal/diagnostics/_diagnostic.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Diagnostic components for TorchScript based ONNX export, i.e. `torch.onnx.export`.""" from __future__ import annotations diff --git a/torch/onnx/_internal/diagnostics/_rules.py b/torch/onnx/_internal/diagnostics/_rules.py index 0bfda96c5bce..3b2ca727d0d1 100644 --- a/torch/onnx/_internal/diagnostics/_rules.py +++ b/torch/onnx/_internal/diagnostics/_rules.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ GENERATED CODE - DO NOT EDIT DIRECTLY This file is generated by gen_diagnostics.py. diff --git a/torch/onnx/_internal/diagnostics/infra/_infra.py b/torch/onnx/_internal/diagnostics/infra/_infra.py index c118f3e5ae14..e51c99a3151b 100644 --- a/torch/onnx/_internal/diagnostics/infra/_infra.py +++ b/torch/onnx/_internal/diagnostics/infra/_infra.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """This file defines an additional layer of abstraction on top of the SARIF OM.""" from __future__ import annotations diff --git a/torch/onnx/_internal/diagnostics/infra/context.py b/torch/onnx/_internal/diagnostics/infra/context.py index 6106a42467c1..f670adc2cae2 100644 --- a/torch/onnx/_internal/diagnostics/infra/context.py +++ b/torch/onnx/_internal/diagnostics/infra/context.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """A diagnostic context based on SARIF.""" from __future__ import annotations diff --git a/torch/onnx/_internal/diagnostics/infra/decorator.py b/torch/onnx/_internal/diagnostics/infra/decorator.py index 0ac803815703..67066f5da500 100644 --- a/torch/onnx/_internal/diagnostics/infra/decorator.py +++ b/torch/onnx/_internal/diagnostics/infra/decorator.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import functools diff --git a/torch/onnx/_internal/exporter.py b/torch/onnx/_internal/exporter.py index cf9f1cd747e5..ac62de0214dd 100644 --- a/torch/onnx/_internal/exporter.py +++ b/torch/onnx/_internal/exporter.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import ( # for onnx.ModelProto (ONNXProgram) and onnxruntime (ONNXRuntimeOptions) annotations, ) diff --git a/torch/onnx/_internal/fx/_pass.py b/torch/onnx/_internal/fx/_pass.py index 69fa023b9add..cef8e045f7fb 100644 --- a/torch/onnx/_internal/fx/_pass.py +++ b/torch/onnx/_internal/fx/_pass.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import abc diff --git a/torch/onnx/_internal/fx/analysis/unsupported_nodes.py b/torch/onnx/_internal/fx/analysis/unsupported_nodes.py index 5da0dbed3d91..deec2a85e1da 100644 --- a/torch/onnx/_internal/fx/analysis/unsupported_nodes.py +++ b/torch/onnx/_internal/fx/analysis/unsupported_nodes.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import dataclasses diff --git a/torch/onnx/_internal/fx/decomposition_skip.py b/torch/onnx/_internal/fx/decomposition_skip.py index 7fb971a3307a..646e0765f190 100644 --- a/torch/onnx/_internal/fx/decomposition_skip.py +++ b/torch/onnx/_internal/fx/decomposition_skip.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """A context manager that disables the decomposition of certain ops during dynamo tracing. The approach is to temporarily hijack the operator callable with PT2 custom operator. diff --git a/torch/onnx/_internal/fx/decomposition_table.py b/torch/onnx/_internal/fx/decomposition_table.py index 5cb9be6da79d..027d580717af 100644 --- a/torch/onnx/_internal/fx/decomposition_table.py +++ b/torch/onnx/_internal/fx/decomposition_table.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Dispatcher for AtenLib functions from onnx-script.""" from __future__ import annotations diff --git a/torch/onnx/_internal/fx/diagnostics.py b/torch/onnx/_internal/fx/diagnostics.py index 11e4c79f2e1a..0be358751c11 100644 --- a/torch/onnx/_internal/fx/diagnostics.py +++ b/torch/onnx/_internal/fx/diagnostics.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import dataclasses diff --git a/torch/onnx/_internal/fx/dynamo_graph_extractor.py b/torch/onnx/_internal/fx/dynamo_graph_extractor.py index fbc7d92e043f..1379a0613895 100644 --- a/torch/onnx/_internal/fx/dynamo_graph_extractor.py +++ b/torch/onnx/_internal/fx/dynamo_graph_extractor.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # NOTE: This file is referenced by name at # /opt/pytorch/torch/_dynamo/eval_frame.py::DONT_WRAP_FILES. # introduced by https://github.com/pytorch/pytorch/pull/98894. diff --git a/torch/onnx/_internal/fx/fx_onnx_interpreter.py b/torch/onnx/_internal/fx/fx_onnx_interpreter.py index 50ead7556f37..a0be86e11d6b 100644 --- a/torch/onnx/_internal/fx/fx_onnx_interpreter.py +++ b/torch/onnx/_internal/fx/fx_onnx_interpreter.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import inspect diff --git a/torch/onnx/_internal/fx/fx_symbolic_graph_extractor.py b/torch/onnx/_internal/fx/fx_symbolic_graph_extractor.py index 18dc84e19585..1d7d191cbd25 100644 --- a/torch/onnx/_internal/fx/fx_symbolic_graph_extractor.py +++ b/torch/onnx/_internal/fx/fx_symbolic_graph_extractor.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import functools diff --git a/torch/onnx/_internal/fx/onnxfunction_dispatcher.py b/torch/onnx/_internal/fx/onnxfunction_dispatcher.py index 2986ac279ec3..3886733093a3 100644 --- a/torch/onnx/_internal/fx/onnxfunction_dispatcher.py +++ b/torch/onnx/_internal/fx/onnxfunction_dispatcher.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Dispatcher for AtenLib functions from onnx-script.""" from __future__ import annotations diff --git a/torch/onnx/_internal/fx/op_validation.py b/torch/onnx/_internal/fx/op_validation.py index b306bc2141de..01161aee25ea 100644 --- a/torch/onnx/_internal/fx/op_validation.py +++ b/torch/onnx/_internal/fx/op_validation.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Module for handling op-level validation during exporting.""" from __future__ import annotations diff --git a/torch/onnx/_internal/fx/passes/_utils.py b/torch/onnx/_internal/fx/passes/_utils.py index 92a883469a52..6e49bccfcfaf 100644 --- a/torch/onnx/_internal/fx/passes/_utils.py +++ b/torch/onnx/_internal/fx/passes/_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Common utility functions for FX passes. These functions should NOT be directly invoked outside of `passes` package. diff --git a/torch/onnx/_internal/fx/passes/decomp.py b/torch/onnx/_internal/fx/passes/decomp.py index b9a131b97466..5185b1152485 100644 --- a/torch/onnx/_internal/fx/passes/decomp.py +++ b/torch/onnx/_internal/fx/passes/decomp.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import contextlib diff --git a/torch/onnx/_internal/fx/passes/functionalization.py b/torch/onnx/_internal/fx/passes/functionalization.py index 21f2691cbb8e..dfdee6e88c85 100644 --- a/torch/onnx/_internal/fx/passes/functionalization.py +++ b/torch/onnx/_internal/fx/passes/functionalization.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import contextlib diff --git a/torch/onnx/_internal/fx/passes/modularization.py b/torch/onnx/_internal/fx/passes/modularization.py index b7c3b90cab66..6e1352f73046 100644 --- a/torch/onnx/_internal/fx/passes/modularization.py +++ b/torch/onnx/_internal/fx/passes/modularization.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import abc diff --git a/torch/onnx/_internal/fx/passes/readability.py b/torch/onnx/_internal/fx/passes/readability.py index 64887ad2ee6e..2b3518b79ea6 100644 --- a/torch/onnx/_internal/fx/passes/readability.py +++ b/torch/onnx/_internal/fx/passes/readability.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations from typing import Dict, List, Sequence, Tuple, Union diff --git a/torch/onnx/_internal/fx/passes/type_promotion.py b/torch/onnx/_internal/fx/passes/type_promotion.py index 944cad4acf1c..bc584ff32925 100644 --- a/torch/onnx/_internal/fx/passes/type_promotion.py +++ b/torch/onnx/_internal/fx/passes/type_promotion.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Owner(s): ["module: onnx"] from __future__ import annotations diff --git a/torch/onnx/_internal/fx/passes/virtualization.py b/torch/onnx/_internal/fx/passes/virtualization.py index 66ca69d7a70f..cd77b6eec18b 100644 --- a/torch/onnx/_internal/fx/passes/virtualization.py +++ b/torch/onnx/_internal/fx/passes/virtualization.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations from typing import List, Optional, Tuple diff --git a/torch/onnx/_internal/fx/patcher.py b/torch/onnx/_internal/fx/patcher.py index ee919eae00d1..dbd8fb591126 100644 --- a/torch/onnx/_internal/fx/patcher.py +++ b/torch/onnx/_internal/fx/patcher.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import functools import io diff --git a/torch/onnx/_internal/fx/serialization.py b/torch/onnx/_internal/fx/serialization.py index 726bf4219330..5739442163a3 100644 --- a/torch/onnx/_internal/fx/serialization.py +++ b/torch/onnx/_internal/fx/serialization.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import io diff --git a/torch/onnx/_internal/fx/torch_export_graph_extractor.py b/torch/onnx/_internal/fx/torch_export_graph_extractor.py index fb3f0e99a6d6..a825e466f1aa 100644 --- a/torch/onnx/_internal/fx/torch_export_graph_extractor.py +++ b/torch/onnx/_internal/fx/torch_export_graph_extractor.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # NOTE: This file is referenced by name at # /opt/pytorch/torch/_dynamo/eval_frame.py::DONT_WRAP_FILES. # introduced by https://github.com/pytorch/pytorch/pull/98894. diff --git a/torch/onnx/_internal/fx/type_utils.py b/torch/onnx/_internal/fx/type_utils.py index 90abdc244d99..3aac02a51214 100644 --- a/torch/onnx/_internal/fx/type_utils.py +++ b/torch/onnx/_internal/fx/type_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Utilities for converting and operating on ONNX, JIT and torch types.""" from __future__ import annotations diff --git a/torch/onnx/_internal/io_adapter.py b/torch/onnx/_internal/io_adapter.py index 2f8c9202d7bb..12100d0f489c 100644 --- a/torch/onnx/_internal/io_adapter.py +++ b/torch/onnx/_internal/io_adapter.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import inspect diff --git a/torch/onnx/_internal/jit_utils.py b/torch/onnx/_internal/jit_utils.py index 719f4b0c16e8..13ae4209da5d 100644 --- a/torch/onnx/_internal/jit_utils.py +++ b/torch/onnx/_internal/jit_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Utilities for manipulating the torch.Graph object and the torchscript.""" from __future__ import annotations diff --git a/torch/onnx/_internal/onnx_proto_utils.py b/torch/onnx/_internal/onnx_proto_utils.py index 278af3feacc6..40eb1bd8d64e 100644 --- a/torch/onnx/_internal/onnx_proto_utils.py +++ b/torch/onnx/_internal/onnx_proto_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Utilities for manipulating the onnx and onnx-script dependencies and ONNX proto.""" from __future__ import annotations diff --git a/torch/onnx/_internal/onnxruntime.py b/torch/onnx/_internal/onnxruntime.py index aa3495ee5ac5..d8a7e55e8f9e 100644 --- a/torch/onnx/_internal/onnxruntime.py +++ b/torch/onnx/_internal/onnxruntime.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import dataclasses import importlib import logging diff --git a/torch/onnx/_internal/registration.py b/torch/onnx/_internal/registration.py index 3b2e68e1e40a..f051708f864d 100644 --- a/torch/onnx/_internal/registration.py +++ b/torch/onnx/_internal/registration.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Module for handling symbolic function registration.""" import warnings diff --git a/torch/onnx/_onnx_supported_ops.py b/torch/onnx/_onnx_supported_ops.py index 2611b0d81e9b..e2707298d6d9 100644 --- a/torch/onnx/_onnx_supported_ops.py +++ b/torch/onnx/_onnx_supported_ops.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import inspect from typing import Dict, List, Union diff --git a/torch/onnx/_type_utils.py b/torch/onnx/_type_utils.py index d13232507317..d9b647c807f3 100644 --- a/torch/onnx/_type_utils.py +++ b/torch/onnx/_type_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Utilities for converting and operating on ONNX, JIT and torch types.""" from __future__ import annotations diff --git a/torch/onnx/operators.py b/torch/onnx/operators.py index 489010519980..1e7532e8451d 100644 --- a/torch/onnx/operators.py +++ b/torch/onnx/operators.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r"""This file provides a location for operators that help exporting models via onnx. E.g. `shape_as_tensor` and `reshape_from_tensor_shape` diff --git a/torch/onnx/symbolic_caffe2.py b/torch/onnx/symbolic_caffe2.py index 3398fcd2fe10..ed2dc6cd9fdb 100644 --- a/torch/onnx/symbolic_caffe2.py +++ b/torch/onnx/symbolic_caffe2.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import importlib import inspect diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py index 4430babaef00..676c3d68048b 100644 --- a/torch/onnx/symbolic_helper.py +++ b/torch/onnx/symbolic_helper.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import functools diff --git a/torch/onnx/symbolic_opset10.py b/torch/onnx/symbolic_opset10.py index 6fd576822e2c..e9ba8b4015f2 100644 --- a/torch/onnx/symbolic_opset10.py +++ b/torch/onnx/symbolic_opset10.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import functools diff --git a/torch/onnx/symbolic_opset11.py b/torch/onnx/symbolic_opset11.py index 99d5064ad7a0..e562d5a47567 100644 --- a/torch/onnx/symbolic_opset11.py +++ b/torch/onnx/symbolic_opset11.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """This file exports ONNX ops for opset 11.""" from __future__ import annotations diff --git a/torch/onnx/symbolic_opset12.py b/torch/onnx/symbolic_opset12.py index 130b02a889b0..5a6bf720df36 100644 --- a/torch/onnx/symbolic_opset12.py +++ b/torch/onnx/symbolic_opset12.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import functools diff --git a/torch/onnx/symbolic_opset13.py b/torch/onnx/symbolic_opset13.py index 5bba817bbce0..bb7045c0f58b 100644 --- a/torch/onnx/symbolic_opset13.py +++ b/torch/onnx/symbolic_opset13.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # EDITING THIS FILE? READ THIS FIRST! # see Note [Edit Symbolic Files] in README.md diff --git a/torch/onnx/symbolic_opset14.py b/torch/onnx/symbolic_opset14.py index 1b4b8ee7917c..62e05910dd72 100644 --- a/torch/onnx/symbolic_opset14.py +++ b/torch/onnx/symbolic_opset14.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """This file exports ONNX ops for opset 14. Note [ONNX operators that are added/updated in opset 14] diff --git a/torch/onnx/symbolic_opset15.py b/torch/onnx/symbolic_opset15.py index 4f316a77f62e..793c1cad8fb9 100644 --- a/torch/onnx/symbolic_opset15.py +++ b/torch/onnx/symbolic_opset15.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """This file exports ONNX ops for opset 15. Note [ONNX operators that are added/updated in opset 15] diff --git a/torch/onnx/symbolic_opset16.py b/torch/onnx/symbolic_opset16.py index 24306b475366..cd5829ada850 100644 --- a/torch/onnx/symbolic_opset16.py +++ b/torch/onnx/symbolic_opset16.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """This file exports ONNX ops for opset 16. Note [ONNX Operators that are added/updated in opset 16] diff --git a/torch/onnx/symbolic_opset17.py b/torch/onnx/symbolic_opset17.py index c7720b9e5c9f..44c789017d75 100644 --- a/torch/onnx/symbolic_opset17.py +++ b/torch/onnx/symbolic_opset17.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """This file exports ONNX ops for opset 17. Note [ONNX Operators that are added/updated in opset 17] diff --git a/torch/onnx/symbolic_opset18.py b/torch/onnx/symbolic_opset18.py index d80361dd417f..68e14c987731 100644 --- a/torch/onnx/symbolic_opset18.py +++ b/torch/onnx/symbolic_opset18.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """This file exports ONNX ops for opset 18. Note [ONNX Operators that are added/updated in opset 18] diff --git a/torch/onnx/symbolic_opset20.py b/torch/onnx/symbolic_opset20.py index 9c81bc3e3c49..9557b5f2828e 100644 --- a/torch/onnx/symbolic_opset20.py +++ b/torch/onnx/symbolic_opset20.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """This file exports ONNX ops for opset 20. Note [ONNX Operators that are added/updated in opset 20] diff --git a/torch/onnx/symbolic_opset7.py b/torch/onnx/symbolic_opset7.py index 0537e8a92888..c647ead4e297 100644 --- a/torch/onnx/symbolic_opset7.py +++ b/torch/onnx/symbolic_opset7.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ Note [ONNX operators that are added/updated from opset 7 to opset 8] ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/torch/onnx/symbolic_opset8.py b/torch/onnx/symbolic_opset8.py index b2fbee3b9784..87b4be230e78 100644 --- a/torch/onnx/symbolic_opset8.py +++ b/torch/onnx/symbolic_opset8.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ Note [ONNX operators that are added/updated from opset 8 to opset 9] ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py index adfa538c8f08..f71ef713636a 100644 --- a/torch/onnx/symbolic_opset9.py +++ b/torch/onnx/symbolic_opset9.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """This file exports ONNX ops for opset 9. Opset 9 is supported by ONNX release 1.4.1 diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index 191df45ac9ef..94a57786a4bd 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Functions to export models into the ONNX IR format. These models can be loaded with the ONNX library and then diff --git a/torch/onnx/verification.py b/torch/onnx/verification.py index 6b49e7fc72b9..95ed873bf633 100644 --- a/torch/onnx/verification.py +++ b/torch/onnx/verification.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Functions to verify exported ONNX model is functionally equivalent to original PyTorch model. ONNX Runtime is required, and is used as the ONNX backend for export verification. diff --git a/torch/optim/_functional.py b/torch/optim/_functional.py index 4a6198956fb8..a307cc76846d 100644 --- a/torch/optim/_functional.py +++ b/torch/optim/_functional.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r"""Functional interface.""" import math from typing import List diff --git a/torch/optim/adadelta.py b/torch/optim/adadelta.py index 4d1a4e25319c..eff24f159213 100644 --- a/torch/optim/adadelta.py +++ b/torch/optim/adadelta.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any, Dict, List, Optional import torch diff --git a/torch/optim/adagrad.py b/torch/optim/adagrad.py index a95e985b49eb..0b6dfe852d08 100644 --- a/torch/optim/adagrad.py +++ b/torch/optim/adagrad.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List, Optional import torch diff --git a/torch/optim/adam.py b/torch/optim/adam.py index 1c625682fc34..bff29613175a 100644 --- a/torch/optim/adam.py +++ b/torch/optim/adam.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List, Optional, Tuple, Union import torch diff --git a/torch/optim/adamax.py b/torch/optim/adamax.py index 005327d8bb88..c2e39b788014 100644 --- a/torch/optim/adamax.py +++ b/torch/optim/adamax.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List, Optional, Tuple, Union import torch diff --git a/torch/optim/adamw.py b/torch/optim/adamw.py index 707ac17c361c..2292c17f9e0f 100644 --- a/torch/optim/adamw.py +++ b/torch/optim/adamw.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import cast, List, Optional, Tuple, Union import torch diff --git a/torch/optim/asgd.py b/torch/optim/asgd.py index 633a14832282..454772670904 100644 --- a/torch/optim/asgd.py +++ b/torch/optim/asgd.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List, Optional, Tuple, Union import torch diff --git a/torch/optim/lbfgs.py b/torch/optim/lbfgs.py index e8818cca538c..480b45c84d72 100644 --- a/torch/optim/lbfgs.py +++ b/torch/optim/lbfgs.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Optional import torch diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py index cb7d9738df5a..4a5f162a0b20 100644 --- a/torch/optim/lr_scheduler.py +++ b/torch/optim/lr_scheduler.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math import types import warnings diff --git a/torch/optim/nadam.py b/torch/optim/nadam.py index fd1f8ab0e718..a4a6d07b2ca8 100644 --- a/torch/optim/nadam.py +++ b/torch/optim/nadam.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import cast, List, Optional, Tuple, Union import torch diff --git a/torch/optim/optimizer.py b/torch/optim/optimizer.py index fc091e273c36..498669e65fbd 100644 --- a/torch/optim/optimizer.py +++ b/torch/optim/optimizer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import math import warnings diff --git a/torch/optim/radam.py b/torch/optim/radam.py index 619f10493587..d21973caeceb 100644 --- a/torch/optim/radam.py +++ b/torch/optim/radam.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import cast, List, Optional, Tuple, Union import torch diff --git a/torch/optim/rmsprop.py b/torch/optim/rmsprop.py index bdc3ec0b8b3f..30b56779fc75 100644 --- a/torch/optim/rmsprop.py +++ b/torch/optim/rmsprop.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List, Optional import torch diff --git a/torch/optim/rprop.py b/torch/optim/rprop.py index af1854cc518a..69043b48673e 100644 --- a/torch/optim/rprop.py +++ b/torch/optim/rprop.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List, Optional, Tuple import torch diff --git a/torch/optim/sgd.py b/torch/optim/sgd.py index 291b4068dd4c..e682d83701d5 100644 --- a/torch/optim/sgd.py +++ b/torch/optim/sgd.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List, Optional import torch diff --git a/torch/optim/sparse_adam.py b/torch/optim/sparse_adam.py index 88643d1a5646..adb7c17629c2 100644 --- a/torch/optim/sparse_adam.py +++ b/torch/optim/sparse_adam.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List, Tuple import torch diff --git a/torch/optim/swa_utils.py b/torch/optim/swa_utils.py index 4cfca073af77..440897e6041e 100644 --- a/torch/optim/swa_utils.py +++ b/torch/optim/swa_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import itertools import math import warnings diff --git a/torch/package/_digraph.py b/torch/package/_digraph.py index f84a51398f00..8b753f7ebdc4 100644 --- a/torch/package/_digraph.py +++ b/torch/package/_digraph.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from collections import deque from typing import List, Set diff --git a/torch/package/_directory_reader.py b/torch/package/_directory_reader.py index cec5333c3e3f..77d629cccce2 100644 --- a/torch/package/_directory_reader.py +++ b/torch/package/_directory_reader.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import os.path from glob import glob from typing import cast diff --git a/torch/package/_importlib.py b/torch/package/_importlib.py index fd303b6141e7..9741925315e5 100644 --- a/torch/package/_importlib.py +++ b/torch/package/_importlib.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import _warnings import os.path diff --git a/torch/package/_mangling.py b/torch/package/_mangling.py index 0876d64664a2..7dcf3538631f 100644 --- a/torch/package/_mangling.py +++ b/torch/package/_mangling.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Import mangling. See mangling.md for details. """ diff --git a/torch/package/_mock.py b/torch/package/_mock.py index b0bdb95cc48c..44876b1a1d3f 100644 --- a/torch/package/_mock.py +++ b/torch/package/_mock.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs _magic_methods = [ "__subclasscheck__", "__hex__", From 8db9dfa2d79b903ac5937f1be0643d279f9bd48d Mon Sep 17 00:00:00 2001 From: Aaron Orenstein Date: Sat, 8 Jun 2024 11:41:16 -0700 Subject: [PATCH 544/706] Flip default value for mypy disallow_untyped_defs [9/11] (#127846) See #127836 for details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127846 Approved by: https://github.com/ezyang ghstack dependencies: #127842, #127843, #127844, #127845 --- torch/package/_package_pickler.py | 1 + torch/package/_package_unpickler.py | 1 + torch/package/_stdlib.py | 1 + torch/package/analyze/trace_dependencies.py | 1 + torch/package/file_structure_representation.py | 1 + torch/package/find_file_dependencies.py | 1 + torch/package/glob_group.py | 1 + torch/package/importer.py | 1 + torch/package/package_exporter.py | 1 + torch/package/package_importer.py | 1 + torch/profiler/__init__.py | 1 + torch/profiler/_memory_profiler.py | 1 + torch/profiler/_pattern_matcher.py | 1 + torch/profiler/_utils.py | 1 + torch/profiler/itt.py | 1 + torch/profiler/profiler.py | 1 + torch/quantization/__init__.py | 1 + torch/quantization/_quantized_conversions.py | 1 + torch/quasirandom.py | 1 + torch/random.py | 1 + torch/serialization.py | 1 + torch/signal/windows/windows.py | 1 + torch/sparse/__init__.py | 1 + torch/sparse/_semi_structured_conversions.py | 1 + torch/sparse/_semi_structured_ops.py | 1 + torch/sparse/_triton_ops.py | 1 + torch/sparse/_triton_ops_meta.py | 1 + torch/sparse/semi_structured.py | 1 + torch/storage.py | 1 + torch/testing/_comparison.py | 1 + torch/testing/_internal/common_fsdp.py | 1 + torch/testing/_internal/custom_op_db.py | 1 + torch/testing/_internal/dynamo_test_failures.py | 1 + torch/testing/_internal/static_module.py | 1 + torch/testing/_internal/torchbind_impls.py | 1 + torch/testing/_utils.py | 1 + torch/types.py | 1 + torch/utils/__init__.py | 1 + torch/utils/_config_module.py | 1 + torch/utils/_config_typing.pyi | 1 + torch/utils/_content_store.py | 1 + torch/utils/_contextlib.py | 1 + torch/utils/_cpp_extension_versioner.py | 1 + torch/utils/_device.py | 1 + torch/utils/_exposed_in.py | 1 + torch/utils/_freeze.py | 1 + torch/utils/_get_clean_triton.py | 1 + torch/utils/_import_utils.py | 1 + torch/utils/_mode_utils.py | 1 + torch/utils/_python_dispatch.py | 1 + torch/utils/_stats.py | 1 + .../_strobelight/examples/cli_function_profiler_example.py | 1 + torch/utils/_sympy/functions.py | 1 + torch/utils/_sympy/interp.py | 1 + torch/utils/_sympy/reference.py | 1 + torch/utils/_sympy/singleton_int.py | 1 + torch/utils/_sympy/symbol.py | 1 + torch/utils/_sympy/value_ranges.py | 3 ++- torch/utils/_traceback.py | 1 + torch/utils/_triton.py | 1 + torch/utils/_zip.py | 1 + torch/utils/backcompat/__init__.py | 1 + torch/utils/backend_registration.py | 1 + torch/utils/benchmark/examples/blas_compare_setup.py | 1 + torch/utils/benchmark/examples/compare.py | 1 + torch/utils/benchmark/examples/fuzzer.py | 1 + torch/utils/benchmark/examples/op_benchmark.py | 1 + torch/utils/benchmark/examples/simple_timeit.py | 1 + torch/utils/benchmark/examples/sparse/compare.py | 1 + torch/utils/benchmark/examples/sparse/fuzzer.py | 1 + torch/utils/benchmark/examples/sparse/op_benchmark.py | 1 + torch/utils/benchmark/examples/spectral_ops_fuzz_test.py | 1 + torch/utils/benchmark/op_fuzzers/binary.py | 1 + torch/utils/benchmark/op_fuzzers/sparse_binary.py | 1 + torch/utils/benchmark/op_fuzzers/sparse_unary.py | 1 + torch/utils/benchmark/op_fuzzers/spectral.py | 1 + torch/utils/benchmark/op_fuzzers/unary.py | 1 + torch/utils/benchmark/utils/compare.py | 1 + torch/utils/benchmark/utils/compile.py | 1 + torch/utils/benchmark/utils/fuzzer.py | 1 + torch/utils/benchmark/utils/sparse_fuzzer.py | 1 + torch/utils/bottleneck/__main__.py | 1 + torch/utils/bundled_inputs.py | 1 + torch/utils/checkpoint.py | 1 + torch/utils/collect_env.py | 1 + torch/utils/cpp_backtrace.py | 1 + torch/utils/cpp_extension.py | 1 + torch/utils/data/_utils/__init__.py | 1 + torch/utils/data/_utils/collate.py | 1 + torch/utils/data/_utils/fetch.py | 1 + torch/utils/data/_utils/pin_memory.py | 1 + torch/utils/data/_utils/signal_handling.py | 1 + torch/utils/data/_utils/worker.py | 1 + torch/utils/data/backward_compatibility.py | 1 + torch/utils/data/dataloader.py | 1 + torch/utils/data/datapipes/_decorator.py | 1 + torch/utils/data/datapipes/_hook_iterator.py | 1 + torch/utils/data/datapipes/_typing.py | 1 + torch/utils/data/datapipes/dataframe/dataframe_wrapper.py | 1 + torch/utils/data/datapipes/dataframe/dataframes.py | 1 + torch/utils/data/datapipes/dataframe/datapipes.py | 1 + torch/utils/data/datapipes/dataframe/structures.py | 1 + torch/utils/data/datapipes/gen_pyi.py | 1 + torch/utils/data/datapipes/iter/callable.py | 1 + torch/utils/data/datapipes/iter/combinatorics.py | 1 + torch/utils/data/datapipes/iter/combining.py | 1 + torch/utils/data/datapipes/iter/filelister.py | 1 + torch/utils/data/datapipes/iter/fileopener.py | 1 + torch/utils/data/datapipes/iter/grouping.py | 1 + torch/utils/data/datapipes/iter/selecting.py | 1 + torch/utils/data/datapipes/iter/sharding.py | 1 + torch/utils/data/datapipes/iter/streamreader.py | 1 + torch/utils/data/datapipes/iter/utils.py | 1 + torch/utils/data/datapipes/map/callable.py | 1 + torch/utils/data/datapipes/map/combinatorics.py | 1 + torch/utils/data/datapipes/map/combining.py | 1 + torch/utils/data/datapipes/map/grouping.py | 1 + torch/utils/data/datapipes/map/utils.py | 1 + torch/utils/data/datapipes/utils/common.py | 1 + torch/utils/data/datapipes/utils/decoder.py | 1 + torch/utils/data/datapipes/utils/snapshot.py | 1 + torch/utils/data/dataset.py | 1 + torch/utils/data/graph.py | 1 + torch/utils/data/graph_settings.py | 1 + torch/utils/data/sampler.py | 1 + torch/utils/deterministic.py | 1 + torch/utils/file_baton.py | 1 + torch/utils/flop_counter.py | 1 + 128 files changed, 129 insertions(+), 1 deletion(-) diff --git a/torch/package/_package_pickler.py b/torch/package/_package_pickler.py index cabc6a82164f..2ac59395b73b 100644 --- a/torch/package/_package_pickler.py +++ b/torch/package/_package_pickler.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """isort:skip_file""" from pickle import ( # type: ignore[attr-defined] _compat_pickle, diff --git a/torch/package/_package_unpickler.py b/torch/package/_package_unpickler.py index b00210e3c191..890e6b4e03ba 100644 --- a/torch/package/_package_unpickler.py +++ b/torch/package/_package_unpickler.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import _compat_pickle import pickle diff --git a/torch/package/_stdlib.py b/torch/package/_stdlib.py index a810d50661cb..2d5145b40aa7 100644 --- a/torch/package/_stdlib.py +++ b/torch/package/_stdlib.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """List of Python standard library modules. Sadly, there is no reliable way to tell whether a module is part of the diff --git a/torch/package/analyze/trace_dependencies.py b/torch/package/analyze/trace_dependencies.py index 9f882fb33481..405fcf2f9bc2 100644 --- a/torch/package/analyze/trace_dependencies.py +++ b/torch/package/analyze/trace_dependencies.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import sys from typing import Any, Callable, Iterable, List, Tuple diff --git a/torch/package/file_structure_representation.py b/torch/package/file_structure_representation.py index 1453ad3a5ded..44e07978640f 100644 --- a/torch/package/file_structure_representation.py +++ b/torch/package/file_structure_representation.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Dict, List from .glob_group import GlobGroup, GlobPattern diff --git a/torch/package/find_file_dependencies.py b/torch/package/find_file_dependencies.py index af8cd9fec84d..80cfccbec50a 100644 --- a/torch/package/find_file_dependencies.py +++ b/torch/package/find_file_dependencies.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import ast from typing import List, Optional, Tuple diff --git a/torch/package/glob_group.py b/torch/package/glob_group.py index a8434788d016..974364400502 100644 --- a/torch/package/glob_group.py +++ b/torch/package/glob_group.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import re from typing import Iterable, Union diff --git a/torch/package/importer.py b/torch/package/importer.py index dd01d09209a8..513847513910 100644 --- a/torch/package/importer.py +++ b/torch/package/importer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import importlib from abc import ABC, abstractmethod from pickle import ( # type: ignore[attr-defined] # type: ignore[attr-defined] diff --git a/torch/package/package_exporter.py b/torch/package/package_exporter.py index 493c017ccf99..bfa00278fa4b 100644 --- a/torch/package/package_exporter.py +++ b/torch/package/package_exporter.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections import importlib.machinery import io diff --git a/torch/package/package_importer.py b/torch/package/package_importer.py index 9e2f74354db5..1a103ab6c5c9 100644 --- a/torch/package/package_importer.py +++ b/torch/package/package_importer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import builtins import importlib import importlib.machinery diff --git a/torch/profiler/__init__.py b/torch/profiler/__init__.py index e3c4145fd91f..4a681daf788e 100644 --- a/torch/profiler/__init__.py +++ b/torch/profiler/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r""" PyTorch Profiler is a tool that allows the collection of performance metrics during training and inference. Profiler's context manager API can be used to better understand what model operators are the most expensive, diff --git a/torch/profiler/_memory_profiler.py b/torch/profiler/_memory_profiler.py index b719df2a56ee..1834f0494e02 100644 --- a/torch/profiler/_memory_profiler.py +++ b/torch/profiler/_memory_profiler.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections import dataclasses import enum diff --git a/torch/profiler/_pattern_matcher.py b/torch/profiler/_pattern_matcher.py index 02e9b014d308..a7ec5d05dd68 100644 --- a/torch/profiler/_pattern_matcher.py +++ b/torch/profiler/_pattern_matcher.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import json import math import os diff --git a/torch/profiler/_utils.py b/torch/profiler/_utils.py index 35f6e71de558..d69fa4630595 100644 --- a/torch/profiler/_utils.py +++ b/torch/profiler/_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import operator import re diff --git a/torch/profiler/itt.py b/torch/profiler/itt.py index 4d072957d6fe..4666bba515a3 100644 --- a/torch/profiler/itt.py +++ b/torch/profiler/itt.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from contextlib import contextmanager try: diff --git a/torch/profiler/profiler.py b/torch/profiler/profiler.py index 5d1b50bc3020..f43dcc06de20 100644 --- a/torch/profiler/profiler.py +++ b/torch/profiler/profiler.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import gzip import json import os diff --git a/torch/quantization/__init__.py b/torch/quantization/__init__.py index fd83d88a3e3e..a82518db6084 100644 --- a/torch/quantization/__init__.py +++ b/torch/quantization/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from .quantize import * # noqa: F403 from .observer import * # noqa: F403 from .qconfig import * # noqa: F403 diff --git a/torch/quantization/_quantized_conversions.py b/torch/quantization/_quantized_conversions.py index 2b7670ea4802..8d930c366c0d 100644 --- a/torch/quantization/_quantized_conversions.py +++ b/torch/quantization/_quantized_conversions.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch diff --git a/torch/quasirandom.py b/torch/quasirandom.py index 884d1d17e77c..a1218012ceb6 100644 --- a/torch/quasirandom.py +++ b/torch/quasirandom.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from typing import Optional diff --git a/torch/random.py b/torch/random.py index 74d448488042..0916fe115a92 100644 --- a/torch/random.py +++ b/torch/random.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib from typing import Generator import warnings diff --git a/torch/serialization.py b/torch/serialization.py index a13363d037ac..1cab9b92c550 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import difflib import functools import os diff --git a/torch/signal/windows/windows.py b/torch/signal/windows/windows.py index d86a1245dc27..f9f73b2dca07 100644 --- a/torch/signal/windows/windows.py +++ b/torch/signal/windows/windows.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Optional, Iterable import torch diff --git a/torch/sparse/__init__.py b/torch/sparse/__init__.py index 8ca4aed7d71a..5b86e068096f 100644 --- a/torch/sparse/__init__.py +++ b/torch/sparse/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # The Tensor classes are added to this module by python_tensor.cpp from typing import Optional, Tuple, List, Union, Any diff --git a/torch/sparse/_semi_structured_conversions.py b/torch/sparse/_semi_structured_conversions.py index 5203ad245b28..141464f7dc76 100644 --- a/torch/sparse/_semi_structured_conversions.py +++ b/torch/sparse/_semi_structured_conversions.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch diff --git a/torch/sparse/_semi_structured_ops.py b/torch/sparse/_semi_structured_ops.py index 551111b429a5..bcaa889ba1ee 100644 --- a/torch/sparse/_semi_structured_ops.py +++ b/torch/sparse/_semi_structured_ops.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import torch diff --git a/torch/sparse/_triton_ops.py b/torch/sparse/_triton_ops.py index a22b5c8077e3..e11bdf59c882 100644 --- a/torch/sparse/_triton_ops.py +++ b/torch/sparse/_triton_ops.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math import os import torch diff --git a/torch/sparse/_triton_ops_meta.py b/torch/sparse/_triton_ops_meta.py index e6fc1329e812..eedfa03b756a 100644 --- a/torch/sparse/_triton_ops_meta.py +++ b/torch/sparse/_triton_ops_meta.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Provides optimal triton kernel parameters. Aim diff --git a/torch/sparse/semi_structured.py b/torch/sparse/semi_structured.py index d592e5ef6a62..6105038e4df7 100644 --- a/torch/sparse/semi_structured.py +++ b/torch/sparse/semi_structured.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import warnings from collections import namedtuple from typing import Any, Optional, Tuple, List, Callable, Dict diff --git a/torch/storage.py b/torch/storage.py index dd268cab0d2e..c094ba5ac3e9 100644 --- a/torch/storage.py +++ b/torch/storage.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import io import torch diff --git a/torch/testing/_comparison.py b/torch/testing/_comparison.py index 85d5adb0cd3a..9815cc2a8807 100644 --- a/torch/testing/_comparison.py +++ b/torch/testing/_comparison.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import abc import cmath import collections.abc diff --git a/torch/testing/_internal/common_fsdp.py b/torch/testing/_internal/common_fsdp.py index da9dc2ef4e3c..2b5fdc613c2e 100644 --- a/torch/testing/_internal/common_fsdp.py +++ b/torch/testing/_internal/common_fsdp.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Owner(s): ["oncall: distributed"] import contextlib diff --git a/torch/testing/_internal/custom_op_db.py b/torch/testing/_internal/custom_op_db.py index ee170cc36058..71a2a8f10651 100644 --- a/torch/testing/_internal/custom_op_db.py +++ b/torch/testing/_internal/custom_op_db.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import functools from torch.testing import make_tensor diff --git a/torch/testing/_internal/dynamo_test_failures.py b/torch/testing/_internal/dynamo_test_failures.py index eb626b552ce6..3b5c291bc41f 100644 --- a/torch/testing/_internal/dynamo_test_failures.py +++ b/torch/testing/_internal/dynamo_test_failures.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging import os import sys diff --git a/torch/testing/_internal/static_module.py b/torch/testing/_internal/static_module.py index b39daa380d9d..0a031b0d8f6e 100644 --- a/torch/testing/_internal/static_module.py +++ b/torch/testing/_internal/static_module.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Owner(s): ["module: unknown"] import torch diff --git a/torch/testing/_internal/torchbind_impls.py b/torch/testing/_internal/torchbind_impls.py index 4ae765c206f7..5d127a9a50c4 100644 --- a/torch/testing/_internal/torchbind_impls.py +++ b/torch/testing/_internal/torchbind_impls.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib from typing import Optional diff --git a/torch/testing/_utils.py b/torch/testing/_utils.py index b85860eeff03..50d077cb1649 100644 --- a/torch/testing/_utils.py +++ b/torch/testing/_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import torch diff --git a/torch/types.py b/torch/types.py index 10f091a4b24e..a522d622bcc7 100644 --- a/torch/types.py +++ b/torch/types.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import builtins from typing import Any, List, Optional, Sequence, Tuple, Union diff --git a/torch/utils/__init__.py b/torch/utils/__init__.py index a5ca0329a794..24e426a46187 100644 --- a/torch/utils/__init__.py +++ b/torch/utils/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import os.path as _osp import torch diff --git a/torch/utils/_config_module.py b/torch/utils/_config_module.py index 6b38645e486b..0e548aa7f741 100644 --- a/torch/utils/_config_module.py +++ b/torch/utils/_config_module.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import copy diff --git a/torch/utils/_config_typing.pyi b/torch/utils/_config_typing.pyi index b2d99e67fabb..2ebb4c09e33e 100644 --- a/torch/utils/_config_typing.pyi +++ b/torch/utils/_config_typing.pyi @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any, Dict, Optional, TYPE_CHECKING, Union """ diff --git a/torch/utils/_content_store.py b/torch/utils/_content_store.py index f36837ed674e..dec70d90b7d3 100644 --- a/torch/utils/_content_store.py +++ b/torch/utils/_content_store.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # This module provides a FAST (on GPU) content addressable store for storages # (and tensors on top of them) with VERY WEAK portability guarantees (e.g., # don't expect CPU/CUDA to address to the same hash, don't expect it to be diff --git a/torch/utils/_contextlib.py b/torch/utils/_contextlib.py index 59b7d368af26..4f1b991438c0 100644 --- a/torch/utils/_contextlib.py +++ b/torch/utils/_contextlib.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Extra utilities for working with context managers that should have been # in the standard library but are not diff --git a/torch/utils/_cpp_extension_versioner.py b/torch/utils/_cpp_extension_versioner.py index 0c09a82413fe..0686e826007d 100644 --- a/torch/utils/_cpp_extension_versioner.py +++ b/torch/utils/_cpp_extension_versioner.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections diff --git a/torch/utils/_device.py b/torch/utils/_device.py index d4909e54c267..c852cd30c775 100644 --- a/torch/utils/_device.py +++ b/torch/utils/_device.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Optional import torch from torch.overrides import TorchFunctionMode diff --git a/torch/utils/_exposed_in.py b/torch/utils/_exposed_in.py index ddd845349916..54faf279ecfc 100644 --- a/torch/utils/_exposed_in.py +++ b/torch/utils/_exposed_in.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Allows one to expose an API in a private submodule publicly as per the definition # in PyTorch's public api policy. # diff --git a/torch/utils/_freeze.py b/torch/utils/_freeze.py index c7be90a4baee..f813ca28b81c 100644 --- a/torch/utils/_freeze.py +++ b/torch/utils/_freeze.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ Freeze Python packages. diff --git a/torch/utils/_get_clean_triton.py b/torch/utils/_get_clean_triton.py index ea0e27cf7d5c..70faa6a8e79d 100644 --- a/torch/utils/_get_clean_triton.py +++ b/torch/utils/_get_clean_triton.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import argparse import os import re diff --git a/torch/utils/_import_utils.py b/torch/utils/_import_utils.py index b7756a6fa62f..1102fa8a019d 100644 --- a/torch/utils/_import_utils.py +++ b/torch/utils/_import_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import importlib.util diff --git a/torch/utils/_mode_utils.py b/torch/utils/_mode_utils.py index c6e3cbb5e940..91c0e07b3d93 100644 --- a/torch/utils/_mode_utils.py +++ b/torch/utils/_mode_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from typing import TypeVar diff --git a/torch/utils/_python_dispatch.py b/torch/utils/_python_dispatch.py index c417f1d9d72a..36a4ff65af6f 100644 --- a/torch/utils/_python_dispatch.py +++ b/torch/utils/_python_dispatch.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import warnings diff --git a/torch/utils/_stats.py b/torch/utils/_stats.py index 5b33f7b8cb02..c11cbd5df270 100644 --- a/torch/utils/_stats.py +++ b/torch/utils/_stats.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # NOTE! PLEASE KEEP THIS FILE *FREE* OF TORCH DEPS! IT SHOULD BE IMPORTABLE ANYWHERE. # IF YOU FEEL AN OVERWHELMING URGE TO ADD A TORCH DEP, MAKE A TRAMPOLINE FILE A LA torch._dynamo.utils # AND SCRUB AWAY TORCH NOTIONS THERE. diff --git a/torch/utils/_strobelight/examples/cli_function_profiler_example.py b/torch/utils/_strobelight/examples/cli_function_profiler_example.py index d97f339ba081..222a70c9fe2d 100644 --- a/torch/utils/_strobelight/examples/cli_function_profiler_example.py +++ b/torch/utils/_strobelight/examples/cli_function_profiler_example.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.utils._strobelight.cli_function_profiler import ( diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index 1384261b4512..5109bc38ffcf 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math import sympy diff --git a/torch/utils/_sympy/interp.py b/torch/utils/_sympy/interp.py index 806e91cfe281..e5c3c1aa43a7 100644 --- a/torch/utils/_sympy/interp.py +++ b/torch/utils/_sympy/interp.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ This is a simple interpreter for Sympy expressions that dispatches to classes following the torch._inductor.virtualized calling convention. diff --git a/torch/utils/_sympy/reference.py b/torch/utils/_sympy/reference.py index 881b9d616eb5..eea543a30943 100644 --- a/torch/utils/_sympy/reference.py +++ b/torch/utils/_sympy/reference.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math import sympy diff --git a/torch/utils/_sympy/singleton_int.py b/torch/utils/_sympy/singleton_int.py index 870bda554e74..1b5e8a96104f 100644 --- a/torch/utils/_sympy/singleton_int.py +++ b/torch/utils/_sympy/singleton_int.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import sympy from sympy.multipledispatch import dispatch diff --git a/torch/utils/_sympy/symbol.py b/torch/utils/_sympy/symbol.py index 89908a09e197..bd853faee6d2 100644 --- a/torch/utils/_sympy/symbol.py +++ b/torch/utils/_sympy/symbol.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ This file contains canonical definitions for our symbol naming conventions, across torch.fx.experimental.symbolic_shapes and torch._inductor. The diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py index c7cc96beb980..619a9046796d 100644 --- a/torch/utils/_sympy/value_ranges.py +++ b/torch/utils/_sympy/value_ranges.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import dataclasses @@ -586,7 +587,7 @@ def reciprocal(x): if 0 in x: return ValueRanges.unknown() else: - return ValueRanges.decreasing_map(x, lambda y: 1 / y) + return ValueRanges.decreasing_map(x, lambda y: 1 / y) # type: ignore[operator] @staticmethod def abs(x): diff --git a/torch/utils/_traceback.py b/torch/utils/_traceback.py index 9f4d04c55105..aa3944d41708 100644 --- a/torch/utils/_traceback.py +++ b/torch/utils/_traceback.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from types import TracebackType from typing import List, Optional import tempfile diff --git a/torch/utils/_triton.py b/torch/utils/_triton.py index ab93a398287c..ff8a5fc73b64 100644 --- a/torch/utils/_triton.py +++ b/torch/utils/_triton.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import hashlib diff --git a/torch/utils/_zip.py b/torch/utils/_zip.py index f37ddb449878..c7dd6445fabe 100644 --- a/torch/utils/_zip.py +++ b/torch/utils/_zip.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import argparse import glob import os diff --git a/torch/utils/backcompat/__init__.py b/torch/utils/backcompat/__init__.py index fdd16eec5aca..6a53076c90a6 100644 --- a/torch/utils/backcompat/__init__.py +++ b/torch/utils/backcompat/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch._C import _set_backcompat_broadcast_warn from torch._C import _get_backcompat_broadcast_warn from torch._C import _set_backcompat_keepdim_warn diff --git a/torch/utils/backend_registration.py b/torch/utils/backend_registration.py index 6a4cbcb8436b..6f3444116f3a 100644 --- a/torch/utils/backend_registration.py +++ b/torch/utils/backend_registration.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.overrides import ( handle_torch_function, diff --git a/torch/utils/benchmark/examples/blas_compare_setup.py b/torch/utils/benchmark/examples/blas_compare_setup.py index 44038539cae0..323138d19ddd 100644 --- a/torch/utils/benchmark/examples/blas_compare_setup.py +++ b/torch/utils/benchmark/examples/blas_compare_setup.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections import os import shutil diff --git a/torch/utils/benchmark/examples/compare.py b/torch/utils/benchmark/examples/compare.py index 6f99d9d06ad5..5d797a5b0a2b 100644 --- a/torch/utils/benchmark/examples/compare.py +++ b/torch/utils/benchmark/examples/compare.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Example of Timer and Compare APIs: $ python -m examples.compare diff --git a/torch/utils/benchmark/examples/fuzzer.py b/torch/utils/benchmark/examples/fuzzer.py index 9728bf3d26c9..ee2c9f9c04ed 100644 --- a/torch/utils/benchmark/examples/fuzzer.py +++ b/torch/utils/benchmark/examples/fuzzer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Example of the Timer and Fuzzer APIs: $ python -m examples.fuzzer diff --git a/torch/utils/benchmark/examples/op_benchmark.py b/torch/utils/benchmark/examples/op_benchmark.py index e2f0861d20ac..cdf3a7853d73 100644 --- a/torch/utils/benchmark/examples/op_benchmark.py +++ b/torch/utils/benchmark/examples/op_benchmark.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Example use of Timer and op fuzzers to measure kernel performance. $ python -m examples.op_benchmark diff --git a/torch/utils/benchmark/examples/simple_timeit.py b/torch/utils/benchmark/examples/simple_timeit.py index 81aaa6dee981..390b88f59e70 100644 --- a/torch/utils/benchmark/examples/simple_timeit.py +++ b/torch/utils/benchmark/examples/simple_timeit.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Trivial use of Timer API: $ python -m examples.simple_timeit diff --git a/torch/utils/benchmark/examples/sparse/compare.py b/torch/utils/benchmark/examples/sparse/compare.py index 4adbd6d2b35e..640912e0167e 100644 --- a/torch/utils/benchmark/examples/sparse/compare.py +++ b/torch/utils/benchmark/examples/sparse/compare.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Example of Timer and Compare APIs: $ python -m examples.sparse.compare diff --git a/torch/utils/benchmark/examples/sparse/fuzzer.py b/torch/utils/benchmark/examples/sparse/fuzzer.py index 38421474ccf8..8f3885839d3f 100644 --- a/torch/utils/benchmark/examples/sparse/fuzzer.py +++ b/torch/utils/benchmark/examples/sparse/fuzzer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Example of the Timer and Sparse Fuzzer APIs: $ python -m examples.sparse.fuzzer diff --git a/torch/utils/benchmark/examples/sparse/op_benchmark.py b/torch/utils/benchmark/examples/sparse/op_benchmark.py index f998f6d5db47..3efb75e8ea13 100644 --- a/torch/utils/benchmark/examples/sparse/op_benchmark.py +++ b/torch/utils/benchmark/examples/sparse/op_benchmark.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Example use of Timer and sparse op fuzzers to measure kernel performance. $ python -m examples.sparse.op_benchmark diff --git a/torch/utils/benchmark/examples/spectral_ops_fuzz_test.py b/torch/utils/benchmark/examples/spectral_ops_fuzz_test.py index 3ac54059416c..a3c8cbe5b12c 100644 --- a/torch/utils/benchmark/examples/spectral_ops_fuzz_test.py +++ b/torch/utils/benchmark/examples/spectral_ops_fuzz_test.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Microbenchmarks for the torch.fft module""" from argparse import ArgumentParser from collections import namedtuple diff --git a/torch/utils/benchmark/op_fuzzers/binary.py b/torch/utils/benchmark/op_fuzzers/binary.py index 91289d88db8a..75f394179b3e 100644 --- a/torch/utils/benchmark/op_fuzzers/binary.py +++ b/torch/utils/benchmark/op_fuzzers/binary.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import numpy as np import torch diff --git a/torch/utils/benchmark/op_fuzzers/sparse_binary.py b/torch/utils/benchmark/op_fuzzers/sparse_binary.py index 984493fe4a71..014361877dea 100644 --- a/torch/utils/benchmark/op_fuzzers/sparse_binary.py +++ b/torch/utils/benchmark/op_fuzzers/sparse_binary.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import numpy as np import torch diff --git a/torch/utils/benchmark/op_fuzzers/sparse_unary.py b/torch/utils/benchmark/op_fuzzers/sparse_unary.py index 70b5ae3cd3a5..f6fe622183f6 100644 --- a/torch/utils/benchmark/op_fuzzers/sparse_unary.py +++ b/torch/utils/benchmark/op_fuzzers/sparse_unary.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import numpy as np import torch diff --git a/torch/utils/benchmark/op_fuzzers/spectral.py b/torch/utils/benchmark/op_fuzzers/spectral.py index 29359ba3edb6..2b9e92d7a2c7 100644 --- a/torch/utils/benchmark/op_fuzzers/spectral.py +++ b/torch/utils/benchmark/op_fuzzers/spectral.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math import torch diff --git a/torch/utils/benchmark/op_fuzzers/unary.py b/torch/utils/benchmark/op_fuzzers/unary.py index a0f810d0b9fa..e780b421f24c 100644 --- a/torch/utils/benchmark/op_fuzzers/unary.py +++ b/torch/utils/benchmark/op_fuzzers/unary.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import numpy as np import torch diff --git a/torch/utils/benchmark/utils/compare.py b/torch/utils/benchmark/utils/compare.py index 20122df66718..36c5a77cd1eb 100644 --- a/torch/utils/benchmark/utils/compare.py +++ b/torch/utils/benchmark/utils/compare.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Display class to aggregate and print the results of many measurements.""" import collections import enum diff --git a/torch/utils/benchmark/utils/compile.py b/torch/utils/benchmark/utils/compile.py index dcee32ace403..fa8f6b63b437 100644 --- a/torch/utils/benchmark/utils/compile.py +++ b/torch/utils/benchmark/utils/compile.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch __all__ = ["bench_all", "benchmark_compile"] diff --git a/torch/utils/benchmark/utils/fuzzer.py b/torch/utils/benchmark/utils/fuzzer.py index 7d1ee8ebb8f8..08206efce377 100644 --- a/torch/utils/benchmark/utils/fuzzer.py +++ b/torch/utils/benchmark/utils/fuzzer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import itertools as it from typing import Any, Callable, Dict, List, Optional, Tuple, Union diff --git a/torch/utils/benchmark/utils/sparse_fuzzer.py b/torch/utils/benchmark/utils/sparse_fuzzer.py index eac6a6baf910..5d3cd051e1de 100644 --- a/torch/utils/benchmark/utils/sparse_fuzzer.py +++ b/torch/utils/benchmark/utils/sparse_fuzzer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Optional, Tuple, Union from numbers import Number import torch diff --git a/torch/utils/bottleneck/__main__.py b/torch/utils/bottleneck/__main__.py index 4444211a0f87..9b23b1483fe0 100644 --- a/torch/utils/bottleneck/__main__.py +++ b/torch/utils/bottleneck/__main__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import argparse import cProfile import pstats diff --git a/torch/utils/bundled_inputs.py b/torch/utils/bundled_inputs.py index 201a000b3006..21fa4e50396d 100644 --- a/torch/utils/bundled_inputs.py +++ b/torch/utils/bundled_inputs.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs from typing import Any, TypeVar, Optional, Tuple, List, NamedTuple, Union, Sequence, Dict, Callable import textwrap import torch diff --git a/torch/utils/checkpoint.py b/torch/utils/checkpoint.py index a98c9b2059b8..38b747d8cd69 100644 --- a/torch/utils/checkpoint.py +++ b/torch/utils/checkpoint.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import platform import uuid diff --git a/torch/utils/collect_env.py b/torch/utils/collect_env.py index 039bc012226c..ed0e02c4c1b9 100644 --- a/torch/utils/collect_env.py +++ b/torch/utils/collect_env.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Unlike the rest of the PyTorch this file must be python2 compliant. # This script outputs relevant system environment info diff --git a/torch/utils/cpp_backtrace.py b/torch/utils/cpp_backtrace.py index 40dbbb5b913a..af4a7fcb63e2 100644 --- a/torch/utils/cpp_backtrace.py +++ b/torch/utils/cpp_backtrace.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch._C import _get_cpp_backtrace def get_cpp_backtrace(frames_to_skip=0, maximum_number_of_frames=64) -> str: diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py index 913947ea84c7..1904f8c3ecae 100644 --- a/torch/utils/cpp_extension.py +++ b/torch/utils/cpp_extension.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import glob import importlib diff --git a/torch/utils/data/_utils/__init__.py b/torch/utils/data/_utils/__init__.py index 62cfdf91f1ea..7c2b452c15cb 100644 --- a/torch/utils/data/_utils/__init__.py +++ b/torch/utils/data/_utils/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r"""Utility classes & functions for data loading. Code in this folder is mostly used by ../dataloder.py. A lot of multiprocessing is used in data loading, which only supports running diff --git a/torch/utils/data/_utils/collate.py b/torch/utils/data/_utils/collate.py index 4c17597bd6f1..1f705c09f0f4 100644 --- a/torch/utils/data/_utils/collate.py +++ b/torch/utils/data/_utils/collate.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r"""Contains definitions of the methods used by the _BaseDataLoaderIter workers. These methods are used to collate samples fetched from dataset into Tensor(s). diff --git a/torch/utils/data/_utils/fetch.py b/torch/utils/data/_utils/fetch.py index 553c516ff3ce..3fa6c49404f6 100644 --- a/torch/utils/data/_utils/fetch.py +++ b/torch/utils/data/_utils/fetch.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r"""Contains definitions of the methods used by the _BaseDataLoaderIter to fetch data from an iterable-style or map-style dataset. This logic is shared in both single- and multi-processing data loading. diff --git a/torch/utils/data/_utils/pin_memory.py b/torch/utils/data/_utils/pin_memory.py index 9de645cd7ee7..ecb7f8875f23 100644 --- a/torch/utils/data/_utils/pin_memory.py +++ b/torch/utils/data/_utils/pin_memory.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r"""Contains definitions of the methods used by the _BaseDataLoaderIter to put fetched tensors into pinned memory. These **needs** to be in global scope since Py2 doesn't support serializing diff --git a/torch/utils/data/_utils/signal_handling.py b/torch/utils/data/_utils/signal_handling.py index da8f3780bed2..6f0219e91c27 100644 --- a/torch/utils/data/_utils/signal_handling.py +++ b/torch/utils/data/_utils/signal_handling.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r"""Signal handling for multiprocessing data loading. NOTE [ Signal handling in multiprocessing data loading ] diff --git a/torch/utils/data/_utils/worker.py b/torch/utils/data/_utils/worker.py index 137791c4c436..849f4b9300fe 100644 --- a/torch/utils/data/_utils/worker.py +++ b/torch/utils/data/_utils/worker.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers. These **needs** to be in global scope since Py2 doesn't support serializing diff --git a/torch/utils/data/backward_compatibility.py b/torch/utils/data/backward_compatibility.py index f51418265f41..e8f1c4e30ef7 100644 --- a/torch/utils/data/backward_compatibility.py +++ b/torch/utils/data/backward_compatibility.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing_extensions import deprecated as _deprecated diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py index 80784f2ec362..9ad0db898a04 100644 --- a/torch/utils/data/dataloader.py +++ b/torch/utils/data/dataloader.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r"""Definition of the DataLoader and associated iterators that subclass _BaseDataLoaderIter. To support these two classes, in `./_utils` we define many utility methods and diff --git a/torch/utils/data/datapipes/_decorator.py b/torch/utils/data/datapipes/_decorator.py index 93ef42076c21..9c5b25d7f22d 100644 --- a/torch/utils/data/datapipes/_decorator.py +++ b/torch/utils/data/datapipes/_decorator.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import inspect from functools import wraps from typing import Any, Callable, Optional, Type, Union, get_type_hints diff --git a/torch/utils/data/datapipes/_hook_iterator.py b/torch/utils/data/datapipes/_hook_iterator.py index 49e17438d60e..00b44cbede61 100644 --- a/torch/utils/data/datapipes/_hook_iterator.py +++ b/torch/utils/data/datapipes/_hook_iterator.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import inspect import functools from enum import Enum diff --git a/torch/utils/data/datapipes/_typing.py b/torch/utils/data/datapipes/_typing.py index 08d54bfb31ad..f3fe402690b6 100644 --- a/torch/utils/data/datapipes/_typing.py +++ b/torch/utils/data/datapipes/_typing.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Taking reference from official Python typing # https://github.com/python/cpython/blob/master/Lib/typing.py diff --git a/torch/utils/data/datapipes/dataframe/dataframe_wrapper.py b/torch/utils/data/datapipes/dataframe/dataframe_wrapper.py index 9a03a8f00efc..67c5b5408b50 100644 --- a/torch/utils/data/datapipes/dataframe/dataframe_wrapper.py +++ b/torch/utils/data/datapipes/dataframe/dataframe_wrapper.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any, Optional _pandas: Any = None diff --git a/torch/utils/data/datapipes/dataframe/dataframes.py b/torch/utils/data/datapipes/dataframe/dataframes.py index a93ea6ba2d82..677104538b23 100644 --- a/torch/utils/data/datapipes/dataframe/dataframes.py +++ b/torch/utils/data/datapipes/dataframe/dataframes.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any, Dict, List, Optional from torch.utils.data.datapipes._decorator import functional_datapipe diff --git a/torch/utils/data/datapipes/dataframe/datapipes.py b/torch/utils/data/datapipes/dataframe/datapipes.py index a75cc5c7a7c2..de0bb8246fb5 100644 --- a/torch/utils/data/datapipes/dataframe/datapipes.py +++ b/torch/utils/data/datapipes/dataframe/datapipes.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import random from torch.utils.data.datapipes._decorator import functional_datapipe diff --git a/torch/utils/data/datapipes/dataframe/structures.py b/torch/utils/data/datapipes/dataframe/structures.py index 507a04e491d3..ad5f6f6d588e 100644 --- a/torch/utils/data/datapipes/dataframe/structures.py +++ b/torch/utils/data/datapipes/dataframe/structures.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch.utils.data.datapipes.datapipe import DataChunk from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper diff --git a/torch/utils/data/datapipes/gen_pyi.py b/torch/utils/data/datapipes/gen_pyi.py index 2729c6296c08..e2b3ad966a21 100644 --- a/torch/utils/data/datapipes/gen_pyi.py +++ b/torch/utils/data/datapipes/gen_pyi.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import os import pathlib from collections import defaultdict diff --git a/torch/utils/data/datapipes/iter/callable.py b/torch/utils/data/datapipes/iter/callable.py index 9a67cc0592ff..f29c96e886e6 100644 --- a/torch/utils/data/datapipes/iter/callable.py +++ b/torch/utils/data/datapipes/iter/callable.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools from collections import namedtuple diff --git a/torch/utils/data/datapipes/iter/combinatorics.py b/torch/utils/data/datapipes/iter/combinatorics.py index 16d2f5444dcd..b86b28f9d7e1 100644 --- a/torch/utils/data/datapipes/iter/combinatorics.py +++ b/torch/utils/data/datapipes/iter/combinatorics.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import random import torch diff --git a/torch/utils/data/datapipes/iter/combining.py b/torch/utils/data/datapipes/iter/combining.py index 9a4365516a33..878d885c2042 100644 --- a/torch/utils/data/datapipes/iter/combining.py +++ b/torch/utils/data/datapipes/iter/combining.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import warnings from abc import ABC, abstractmethod diff --git a/torch/utils/data/datapipes/iter/filelister.py b/torch/utils/data/datapipes/iter/filelister.py index bb10fe4c4965..7384a3a26cb8 100644 --- a/torch/utils/data/datapipes/iter/filelister.py +++ b/torch/utils/data/datapipes/iter/filelister.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Iterator, List, Sequence, Union diff --git a/torch/utils/data/datapipes/iter/fileopener.py b/torch/utils/data/datapipes/iter/fileopener.py index 67e9797fe335..b58ee14a4378 100644 --- a/torch/utils/data/datapipes/iter/fileopener.py +++ b/torch/utils/data/datapipes/iter/fileopener.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from io import IOBase from typing import Iterable, Tuple, Optional diff --git a/torch/utils/data/datapipes/iter/grouping.py b/torch/utils/data/datapipes/iter/grouping.py index c11804ea2cc0..31aa90af5451 100644 --- a/torch/utils/data/datapipes/iter/grouping.py +++ b/torch/utils/data/datapipes/iter/grouping.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import warnings from collections import defaultdict from typing import Any, Callable, DefaultDict, Iterator, List, Optional, Sized, TypeVar diff --git a/torch/utils/data/datapipes/iter/selecting.py b/torch/utils/data/datapipes/iter/selecting.py index fee74582e61b..5910ab0da2ec 100644 --- a/torch/utils/data/datapipes/iter/selecting.py +++ b/torch/utils/data/datapipes/iter/selecting.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Callable, Iterator, Tuple, TypeVar from torch.utils.data.datapipes._decorator import functional_datapipe diff --git a/torch/utils/data/datapipes/iter/sharding.py b/torch/utils/data/datapipes/iter/sharding.py index f5bd3261fc1b..f493af685fb4 100644 --- a/torch/utils/data/datapipes/iter/sharding.py +++ b/torch/utils/data/datapipes/iter/sharding.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import ( Dict, Sized, diff --git a/torch/utils/data/datapipes/iter/streamreader.py b/torch/utils/data/datapipes/iter/streamreader.py index 9fd80e94e509..4e379db92bc5 100644 --- a/torch/utils/data/datapipes/iter/streamreader.py +++ b/torch/utils/data/datapipes/iter/streamreader.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Tuple from torch.utils.data.datapipes._decorator import functional_datapipe from torch.utils.data.datapipes.datapipe import IterDataPipe diff --git a/torch/utils/data/datapipes/iter/utils.py b/torch/utils/data/datapipes/iter/utils.py index 3794f7f0e778..096188b1369e 100644 --- a/torch/utils/data/datapipes/iter/utils.py +++ b/torch/utils/data/datapipes/iter/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import warnings from torch.utils.data.datapipes.datapipe import IterDataPipe diff --git a/torch/utils/data/datapipes/map/callable.py b/torch/utils/data/datapipes/map/callable.py index c9202bb1eefb..9ddd51ba9bb1 100644 --- a/torch/utils/data/datapipes/map/callable.py +++ b/torch/utils/data/datapipes/map/callable.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch.utils.data.datapipes.utils.common import _check_unpickable_fn from typing import Callable, TypeVar from torch.utils.data.datapipes._decorator import functional_datapipe diff --git a/torch/utils/data/datapipes/map/combinatorics.py b/torch/utils/data/datapipes/map/combinatorics.py index c21d532d4925..7b435ce7c130 100644 --- a/torch/utils/data/datapipes/map/combinatorics.py +++ b/torch/utils/data/datapipes/map/combinatorics.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import random import torch diff --git a/torch/utils/data/datapipes/map/combining.py b/torch/utils/data/datapipes/map/combining.py index 809b44dc96cd..731418239ba0 100644 --- a/torch/utils/data/datapipes/map/combining.py +++ b/torch/utils/data/datapipes/map/combining.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch.utils.data.datapipes._decorator import functional_datapipe from torch.utils.data.datapipes.datapipe import MapDataPipe from typing import Sized, Tuple, TypeVar diff --git a/torch/utils/data/datapipes/map/grouping.py b/torch/utils/data/datapipes/map/grouping.py index a94cc7b5679e..d5d216158acd 100644 --- a/torch/utils/data/datapipes/map/grouping.py +++ b/torch/utils/data/datapipes/map/grouping.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch.utils.data.datapipes._decorator import functional_datapipe from torch.utils.data.datapipes.datapipe import MapDataPipe, DataChunk from typing import List, Sized, TypeVar diff --git a/torch/utils/data/datapipes/map/utils.py b/torch/utils/data/datapipes/map/utils.py index 18d4fd18a193..d22e708c1538 100644 --- a/torch/utils/data/datapipes/map/utils.py +++ b/torch/utils/data/datapipes/map/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import warnings from torch.utils.data.datapipes.datapipe import MapDataPipe diff --git a/torch/utils/data/datapipes/utils/common.py b/torch/utils/data/datapipes/utils/common.py index 3c466d3392ad..3e8e99c4b154 100644 --- a/torch/utils/data/datapipes/utils/common.py +++ b/torch/utils/data/datapipes/utils/common.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import fnmatch import functools import inspect diff --git a/torch/utils/data/datapipes/utils/decoder.py b/torch/utils/data/datapipes/utils/decoder.py index 0211a8fe4ba4..7c055c567295 100644 --- a/torch/utils/data/datapipes/utils/decoder.py +++ b/torch/utils/data/datapipes/utils/decoder.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # This file takes partial of the implementation from NVIDIA's webdataset at here: # https://github.com/tmbdev/webdataset/blob/master/webdataset/autodecode.py diff --git a/torch/utils/data/datapipes/utils/snapshot.py b/torch/utils/data/datapipes/utils/snapshot.py index 02487d0da573..8b2266d15d62 100644 --- a/torch/utils/data/datapipes/utils/snapshot.py +++ b/torch/utils/data/datapipes/utils/snapshot.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch.utils.data.datapipes._hook_iterator import _SnapshotState from torch.utils.data.datapipes.datapipe import IterDataPipe from torch.utils.data.graph_settings import apply_random_seed diff --git a/torch/utils/data/dataset.py b/torch/utils/data/dataset.py index b3cf9d92943d..6ce4b67bfb06 100644 --- a/torch/utils/data/dataset.py +++ b/torch/utils/data/dataset.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import bisect import itertools import math diff --git a/torch/utils/data/graph.py b/torch/utils/data/graph.py index cd78db474d5e..d3a882e58595 100644 --- a/torch/utils/data/graph.py +++ b/torch/utils/data/graph.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import io import pickle import warnings diff --git a/torch/utils/data/graph_settings.py b/torch/utils/data/graph_settings.py index 573069279201..f9de29df288e 100644 --- a/torch/utils/data/graph_settings.py +++ b/torch/utils/data/graph_settings.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import inspect import warnings diff --git a/torch/utils/data/sampler.py b/torch/utils/data/sampler.py index 4c4c967ef9a9..476d8dfadd41 100644 --- a/torch/utils/data/sampler.py +++ b/torch/utils/data/sampler.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch import Tensor diff --git a/torch/utils/deterministic.py b/torch/utils/deterministic.py index 98a6d30b067b..a055c43be531 100644 --- a/torch/utils/deterministic.py +++ b/torch/utils/deterministic.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import sys import types diff --git a/torch/utils/file_baton.py b/torch/utils/file_baton.py index b55db82b8532..77ee5091b3f7 100644 --- a/torch/utils/file_baton.py +++ b/torch/utils/file_baton.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import os import time diff --git a/torch/utils/flop_counter.py b/torch/utils/flop_counter.py index d7080c9e4e38..93c1cf78e710 100644 --- a/torch/utils/flop_counter.py +++ b/torch/utils/flop_counter.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten from .module_tracker import ModuleTracker From 57536286e2ecaa281d5510ced4fe70db64acaf0e Mon Sep 17 00:00:00 2001 From: Aaron Orenstein Date: Sat, 8 Jun 2024 11:41:18 -0700 Subject: [PATCH 545/706] Flip default value for mypy disallow_untyped_defs [10/11] (#127847) See #127836 for details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127847 Approved by: https://github.com/oulgen ghstack dependencies: #127842, #127843, #127844, #127845, #127846 --- torch/utils/hipify/hipify_python.py | 1 + torch/utils/hooks.py | 1 + torch/utils/jit/log_extract.py | 1 + torch/utils/mkldnn.py | 1 + torch/utils/mobile_optimizer.py | 1 + torch/utils/model_dump/__init__.py | 1 + torch/utils/module_tracker.py | 1 + torch/utils/show_pickle.py | 1 + torch/utils/tensorboard/_convert_np.py | 1 + torch/utils/tensorboard/_embedding.py | 1 + torch/utils/tensorboard/_onnx_graph.py | 1 + torch/utils/tensorboard/_proto_graph.py | 1 + torch/utils/tensorboard/_pytorch_graph.py | 1 + torch/utils/tensorboard/_utils.py | 1 + torch/utils/tensorboard/summary.py | 1 + torch/utils/tensorboard/writer.py | 1 + torch/utils/throughput_benchmark.py | 1 + torch/utils/viz/_cycles.py | 1 + torch/utils/weak.py | 1 + torch/xpu/__init__.py | 1 + torch/xpu/random.py | 1 + torch/xpu/streams.py | 1 + 22 files changed, 22 insertions(+) diff --git a/torch/utils/hipify/hipify_python.py b/torch/utils/hipify/hipify_python.py index 59ee1b2f4743..755a50404055 100755 --- a/torch/utils/hipify/hipify_python.py +++ b/torch/utils/hipify/hipify_python.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs """ The Python Hipify script. ## # Copyright (c) 2015-2016 Advanced Micro Devices, Inc. All rights reserved. diff --git a/torch/utils/hooks.py b/torch/utils/hooks.py index f70a43ad6857..ee828034bdf6 100644 --- a/torch/utils/hooks.py +++ b/torch/utils/hooks.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from collections import OrderedDict import weakref diff --git a/torch/utils/jit/log_extract.py b/torch/utils/jit/log_extract.py index 2e89a769eff0..51894f495e8e 100644 --- a/torch/utils/jit/log_extract.py +++ b/torch/utils/jit/log_extract.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from contextlib import contextmanager from typing import Any, List, Tuple, cast import random diff --git a/torch/utils/mkldnn.py b/torch/utils/mkldnn.py index 2d1d8cd89ff5..06ca96d2de9a 100644 --- a/torch/utils/mkldnn.py +++ b/torch/utils/mkldnn.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch diff --git a/torch/utils/mobile_optimizer.py b/torch/utils/mobile_optimizer.py index 038572806f41..6d2230da8ae1 100644 --- a/torch/utils/mobile_optimizer.py +++ b/torch/utils/mobile_optimizer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """This module contains utility method for mobile model optimization and lint.""" import torch diff --git a/torch/utils/model_dump/__init__.py b/torch/utils/model_dump/__init__.py index a8d491ed6b3a..7e2bc36d2e71 100644 --- a/torch/utils/model_dump/__init__.py +++ b/torch/utils/model_dump/__init__.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs """ model_dump: a one-stop shop for TorchScript model inspection. diff --git a/torch/utils/module_tracker.py b/torch/utils/module_tracker.py index f2d83fb36f92..9feef40ca4da 100644 --- a/torch/utils/module_tracker.py +++ b/torch/utils/module_tracker.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import weakref from typing import Set diff --git a/torch/utils/show_pickle.py b/torch/utils/show_pickle.py index 24ea1eb4e1e9..66549fac2673 100644 --- a/torch/utils/show_pickle.py +++ b/torch/utils/show_pickle.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs import sys import pickle import struct diff --git a/torch/utils/tensorboard/_convert_np.py b/torch/utils/tensorboard/_convert_np.py index 9368464c2491..80a3c684579d 100644 --- a/torch/utils/tensorboard/_convert_np.py +++ b/torch/utils/tensorboard/_convert_np.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """This module converts objects into numpy array.""" import numpy as np diff --git a/torch/utils/tensorboard/_embedding.py b/torch/utils/tensorboard/_embedding.py index afbe68191aa9..44cb6c41b017 100644 --- a/torch/utils/tensorboard/_embedding.py +++ b/torch/utils/tensorboard/_embedding.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math import numpy as np from ._convert_np import make_np diff --git a/torch/utils/tensorboard/_onnx_graph.py b/torch/utils/tensorboard/_onnx_graph.py index 5c923fcb0ee5..c744ca8719f3 100644 --- a/torch/utils/tensorboard/_onnx_graph.py +++ b/torch/utils/tensorboard/_onnx_graph.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from tensorboard.compat.proto.graph_pb2 import GraphDef from tensorboard.compat.proto.node_def_pb2 import NodeDef from tensorboard.compat.proto.versions_pb2 import VersionDef diff --git a/torch/utils/tensorboard/_proto_graph.py b/torch/utils/tensorboard/_proto_graph.py index 3c0d15723d24..30140a22cff6 100644 --- a/torch/utils/tensorboard/_proto_graph.py +++ b/torch/utils/tensorboard/_proto_graph.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Optional from tensorboard.compat.proto.node_def_pb2 import NodeDef from tensorboard.compat.proto.attr_value_pb2 import AttrValue diff --git a/torch/utils/tensorboard/_pytorch_graph.py b/torch/utils/tensorboard/_pytorch_graph.py index f4274199ffd3..d3d2f37cad74 100644 --- a/torch/utils/tensorboard/_pytorch_graph.py +++ b/torch/utils/tensorboard/_pytorch_graph.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from collections import OrderedDict import contextlib from typing import Dict, Any diff --git a/torch/utils/tensorboard/_utils.py b/torch/utils/tensorboard/_utils.py index f79f59749f53..30984cfadf17 100644 --- a/torch/utils/tensorboard/_utils.py +++ b/torch/utils/tensorboard/_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import numpy as np diff --git a/torch/utils/tensorboard/summary.py b/torch/utils/tensorboard/summary.py index 4d94c3e6158b..55a74f3f8771 100644 --- a/torch/utils/tensorboard/summary.py +++ b/torch/utils/tensorboard/summary.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import json import logging import os diff --git a/torch/utils/tensorboard/writer.py b/torch/utils/tensorboard/writer.py index c646ce0c0c11..cdc4c565734a 100644 --- a/torch/utils/tensorboard/writer.py +++ b/torch/utils/tensorboard/writer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Provide an API for writing protocol buffers to event files to be consumed by TensorBoard for visualization.""" import os diff --git a/torch/utils/throughput_benchmark.py b/torch/utils/throughput_benchmark.py index 5607fadee9e9..2778b37b5a78 100644 --- a/torch/utils/throughput_benchmark.py +++ b/torch/utils/throughput_benchmark.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch._C diff --git a/torch/utils/viz/_cycles.py b/torch/utils/viz/_cycles.py index f17348e401c3..8c1b9da7a6ad 100644 --- a/torch/utils/viz/_cycles.py +++ b/torch/utils/viz/_cycles.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import gc import sys from typing import Any, Dict, List, NamedTuple, Optional, Tuple diff --git a/torch/utils/weak.py b/torch/utils/weak.py index a5e33a34d7aa..cc272a7f2637 100644 --- a/torch/utils/weak.py +++ b/torch/utils/weak.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import weakref diff --git a/torch/xpu/__init__.py b/torch/xpu/__init__.py index 3e7f43b87d4a..6049a11861d2 100644 --- a/torch/xpu/__init__.py +++ b/torch/xpu/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r""" This package introduces support for the XPU backend, specifically tailored for Intel GPU optimization. diff --git a/torch/xpu/random.py b/torch/xpu/random.py index 733c55b658cd..1ebdd476ed8c 100644 --- a/torch/xpu/random.py +++ b/torch/xpu/random.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Iterable, List, Union import torch diff --git a/torch/xpu/streams.py b/torch/xpu/streams.py index f4e35a376e7c..19a7cda162f4 100644 --- a/torch/xpu/streams.py +++ b/torch/xpu/streams.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import ctypes import torch From 33972dfd581a59b8e430a2546d98dd43b51538c2 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Sat, 8 Jun 2024 09:26:15 -0700 Subject: [PATCH 546/706] [easy][inline-inbuilt-nn-modules] Fix expected graph for control flow test (#128246) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128246 Approved by: https://github.com/ydwu4 --- test/functorch/test_control_flow.py | 58 +++++++++++++++++++++++------ 1 file changed, 46 insertions(+), 12 deletions(-) diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index e8664cb1e98d..f538c5af78ce 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -681,9 +681,43 @@ def test_while_loop_simple_with_linear_compile_check_graph(self): torch.compile(fn, backend=backend)(*inp) self.assertEqual(len(backend.graphs), 1) gm = backend.graphs[0] - self.assertExpectedInline( - gm.code.strip(), - """\ + if torch._dynamo.config.inline_inbuilt_nn_modules: + self.assertExpectedInline( + gm.code.strip(), + """\ +def forward(self, L_iter_ : torch.Tensor, L_x_ : torch.Tensor, L_self_buffers_dec_ : torch.Tensor, L_self_modules_linear_parameters_weight_ : torch.nn.parameter.Parameter, L_self_modules_linear_parameters_bias_ : torch.nn.parameter.Parameter): + l_iter_ = L_iter_ + l_x_ = L_x_ + l_self_buffers_dec_ = L_self_buffers_dec_ + l_self_modules_linear_parameters_weight_ = L_self_modules_linear_parameters_weight_ + l_self_modules_linear_parameters_bias_ = L_self_modules_linear_parameters_bias_ + cond_fn_0 = self.cond_fn_0 + body_fn_0 = self.body_fn_0 + while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (l_iter_, l_x_), (l_self_buffers_dec_, l_self_modules_linear_parameters_bias_, l_self_modules_linear_parameters_weight_)); cond_fn_0 = body_fn_0 = l_iter_ = l_x_ = l_self_buffers_dec_ = l_self_modules_linear_parameters_bias_ = l_self_modules_linear_parameters_weight_ = None + getitem = while_loop[0] + getitem_1 = while_loop[1]; while_loop = None + return (getitem, getitem_1)""", # noqa: B950 + ) + self.assertExpectedInline( + gm.cond_fn_0.code.strip(), + """\ +def forward(self, l_iter_, l_x_, l_self_buffers_dec__cond_fn, l_self_modules_linear_parameters_bias__body_fn, l_self_modules_linear_parameters_weight__body_fn): + sub = l_iter_ - l_self_buffers_dec__cond_fn; l_iter_ = l_self_buffers_dec__cond_fn = None + gt = sub > 0; sub = None + return gt""", # noqa: B950 + ) + self.assertExpectedInline( + gm.body_fn_0.code.strip(), + """\ +def forward(self, l_iter_, l_x_, l_self_buffers_dec__cond_fn, l_self_modules_linear_parameters_bias__body_fn, l_self_modules_linear_parameters_weight__body_fn): + sub = l_iter_ - 1; l_iter_ = None + linear = torch._C._nn.linear(l_x_, l_self_modules_linear_parameters_weight__body_fn, l_self_modules_linear_parameters_bias__body_fn); l_x_ = l_self_modules_linear_parameters_weight__body_fn = l_self_modules_linear_parameters_bias__body_fn = None + return (sub, linear)""", # noqa: B950 + ) + else: + self.assertExpectedInline( + gm.code.strip(), + """\ def forward(self, L_iter_ : torch.Tensor, L_x_ : torch.Tensor): l_iter_ = L_iter_ l_x_ = L_x_ @@ -696,23 +730,23 @@ def forward(self, L_iter_ : torch.Tensor, L_x_ : torch.Tensor): getitem = while_loop[0] getitem_1 = while_loop[1]; while_loop = None return (getitem, getitem_1)""", # noqa: B950 - ) - self.assertExpectedInline( - gm.cond_fn_0.code.strip(), - """\ + ) + self.assertExpectedInline( + gm.cond_fn_0.code.strip(), + """\ def forward(self, l_iter_, l_x_, l__self___dec_cond_fn, l__self___linear_bias_body_fn, l__self___linear_weight_body_fn): sub = l_iter_ - l__self___dec_cond_fn; l_iter_ = l__self___dec_cond_fn = None gt = sub > 0; sub = None return gt""", # noqa: B950 - ) - self.assertExpectedInline( - gm.body_fn_0.code.strip(), - """\ + ) + self.assertExpectedInline( + gm.body_fn_0.code.strip(), + """\ def forward(self, l_iter_, l_x_, l__self___dec_cond_fn, l__self___linear_bias_body_fn, l__self___linear_weight_body_fn): sub = l_iter_ - 1; l_iter_ = None linear = torch._C._nn.linear(l_x_, l__self___linear_weight_body_fn, l__self___linear_bias_body_fn); l_x_ = l__self___linear_weight_body_fn = l__self___linear_bias_body_fn = None return (sub, linear)""", # noqa: B950 - ) + ) def test_while_loop_nested2_traced(self): fn, inp = WHILE_LOOP_TESTS["nested2"] From 3494f3f9917e665ea1ccf37d33fa52415e8f5c67 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Sat, 8 Jun 2024 09:26:15 -0700 Subject: [PATCH 547/706] [dynamo] Skip inlining builtin nn modules for torch.compile inside cond (#128247) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128247 Approved by: https://github.com/ydwu4 ghstack dependencies: #128246 --- torch/_higher_order_ops/utils.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/torch/_higher_order_ops/utils.py b/torch/_higher_order_ops/utils.py index 84c029084ae3..f4b393e7c234 100644 --- a/torch/_higher_order_ops/utils.py +++ b/torch/_higher_order_ops/utils.py @@ -96,13 +96,21 @@ def wrapped(*args): @contextmanager def _set_compilation_env(): _old_is_tracing = torch.fx._symbolic_trace._is_fx_tracing_flag + _old_is_inlining = torch._dynamo.config.inline_inbuilt_nn_modules try: # We need to turn off the is_fx_tracing_flag. Remove this flag check from dyanmo # once we are confident fx tracing works with dynamo. torch.fx._symbolic_trace._is_fx_tracing_flag = False + + # TODO(anijain2305, export-team) For non-strict export with module + # stack info, the codepatch forces the nn module __getattr__ to + # ProxyAttr __getattr__ downstream. To circumvent the issue for now, + # skip inlining inbuilt nn modules for cond. + torch._dynamo.config.inline_inbuilt_nn_modules = False yield finally: torch.fx._symbolic_trace._is_fx_tracing_flag = _old_is_tracing + torch._dynamo.config.inline_inbuilt_nn_modules = _old_is_inlining def _has_potential_branch_input_mutation(branch, inputs, pre_dispatch=False): From 0dd55ee159ae5aa847d10ca13935e0873fe8d241 Mon Sep 17 00:00:00 2001 From: Pritam Damania Date: Sat, 8 Jun 2024 19:52:21 +0000 Subject: [PATCH 548/706] Fix bug in _update_process_group API (#128262) `local_used_map_` was undefined in case of `find_unused_parameters=False`, this resulted in an error when we ran `local_used_map_.fill_(0);` Added a unit test as well Pull Request resolved: https://github.com/pytorch/pytorch/pull/128262 Approved by: https://github.com/awgu --- torch/csrc/distributed/c10d/reducer.cpp | 6 ++++-- .../_internal/distributed/distributed_test.py | 14 ++++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp index ae4db6bd7a17..6a2812ab24b9 100644 --- a/torch/csrc/distributed/c10d/reducer.cpp +++ b/torch/csrc/distributed/c10d/reducer.cpp @@ -2374,8 +2374,10 @@ void Reducer::reset_state() { // Reset unused parameter accounting. // See Note [local_used_map_ -> local_used_map_dev copying] - local_used_map_.fill_(0); - local_used_map_reduced_ = false; + if (find_unused_parameters_) { + local_used_map_.zero_(); + local_used_map_reduced_ = false; + } } } // namespace c10d diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index 77e9f1f9486f..0ec5dd222444 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -9938,6 +9938,20 @@ def forward(self, inp, error): # Run ddp again. ddp(input, False).sum().backward() + @skip_if_lt_x_gpu(4) + @require_world_size(4) + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + def test_ddp_update_process_group_no_find_unused(self): + ddp = torch.nn.parallel.DistributedDataParallel( + torch.nn.Linear(10, 10).cuda(self.rank), + device_ids=[self.rank], + find_unused_parameters=False, + ) + ddp._update_process_group(_get_default_group()) + @skip_if_lt_x_gpu(2) @skip_but_pass_in_sandcastle_if( From aee154edbe2031c0537aae80e6f4766d20818911 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Sat, 8 Jun 2024 01:08:57 -0700 Subject: [PATCH 549/706] [Traceable FSDP2] Make FSDPParam._unsharded_param creation traceable (#127245) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127245 Approved by: https://github.com/awgu --- .../_composable/fsdp/_fsdp_collectives.py | 20 ++++++++-- .../_composable/fsdp/_fsdp_param.py | 39 ++++++++++++++++--- 2 files changed, 50 insertions(+), 9 deletions(-) diff --git a/torch/distributed/_composable/fsdp/_fsdp_collectives.py b/torch/distributed/_composable/fsdp/_fsdp_collectives.py index 99b69cd82e4b..ac5084813ee1 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_collectives.py +++ b/torch/distributed/_composable/fsdp/_fsdp_collectives.py @@ -1,6 +1,7 @@ from typing import List, NamedTuple, Optional, Tuple, Union import torch +import torch._dynamo.compiled_autograd as ca import torch.distributed as dist from torch.distributed._tensor import DTensor from torch.distributed.distributed_c10d import ReduceOp @@ -102,10 +103,21 @@ def foreach_all_gather_copy_out( for all_gather_input_numels, all_gather_input_dtypes, fsdp_param in zip( param_all_gather_input_numels, param_all_gather_input_dtypes, fsdp_params ): - fsdp_param.init_all_gather_outputs( - all_gather_input_numels, all_gather_input_dtypes, world_size, device - ) # no-op after 1st call - fsdp_param.alloc_all_gather_outputs() + if ca.compiled_autograd_enabled: + fsdp_param.init_all_gather_outputs( + all_gather_input_numels, + all_gather_input_dtypes, + world_size, + device, + # NOTE: Under compile, make sure we always recreate all_gather_outputs + # per AllGather. See [Note: Invariants for torch.compile Traceable FSDP2]. + force_recreate=True, + ) + else: + fsdp_param.init_all_gather_outputs( + all_gather_input_numels, all_gather_input_dtypes, world_size, device + ) # no-op after 1st call + fsdp_param.alloc_all_gather_outputs() all_gather_output = all_gather_output.view(world_size, -1) gen = (t for fsdp_param in fsdp_params for t in fsdp_param.all_gather_outputs) if all_gather_output.dtype == torch.uint8: diff --git a/torch/distributed/_composable/fsdp/_fsdp_param.py b/torch/distributed/_composable/fsdp/_fsdp_param.py index 81596fe05f6b..c56dc79e266b 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_param.py +++ b/torch/distributed/_composable/fsdp/_fsdp_param.py @@ -240,6 +240,7 @@ def _init_sharded_param(self, param: nn.Parameter, device: torch.device): ) param_data = param self._orig_size = param_data.size() + self._contiguous_orig_stride = make_contiguous_strides_for(self._orig_size) shard_rank = self.mesh_info.shard_mesh_rank shard_world_size = self.mesh_info.shard_mesh_size chunks = _chunk_with_empty(param_data, shard_world_size, dim=0) @@ -311,8 +312,9 @@ def init_all_gather_outputs( all_gather_input_dtypes: List[torch.dtype], world_size: int, device: torch.device, + force_recreate: bool = False, ): - if self.all_gather_outputs: + if not force_recreate and len(self.all_gather_outputs) > 0: return # already initialized self.all_gather_outputs = [ torch.empty(torch.Size([numel * world_size]), dtype=dtype, device=device) @@ -320,7 +322,24 @@ def init_all_gather_outputs( ] def init_unsharded_param(self): - if hasattr(self, "_unsharded_param"): # after the 1st all-gather + """ + [Note: Invariants for torch.compile Traceable FSDP2] + 1. Under compile, we always re-populate the content of `self._unsharded_param` + per AllGather using the slow path. + 2. Under compile, we always recreate `self.all_gather_outputs` per AllGather. + This is to ensure the buffer creation is internal to the graph and + avoid `self.all_gather_outputs` being captured as a graph input. + 3. Under compile, at the end of `free_unsharded_param()`, we always clean up + `self.all_gather_outputs` and `self._unsharded_inner_tensors`, + to avoid them being captured as graph output. + + With these invariants, only these tensors will be inputs to the graph: + - Sharded parameters + - Placeholders for the `self._unsharded_param` nn.Parameter + """ + if not ca.compiled_autograd_enabled and hasattr( + self, "_unsharded_param" + ): # after the 1st all-gather inner_tensor = self._sharded_local_tensor if not hasattr(inner_tensor, "fsdp_post_all_gather"): return # already initialized @@ -357,13 +376,20 @@ def init_unsharded_param(self): unsharded_param = torch.as_strided( unsharded_tensor, self._orig_size, - make_contiguous_strides_for(self._orig_size), + self._contiguous_orig_stride, storage_offset=0, ) if self.is_dtensor: unsharded_param = _from_local_no_grad(unsharded_param, self._tp_spec) - self._unsharded_param = nn.Parameter(unsharded_param) - self._unsharded_param.requires_grad_(self.sharded_param.requires_grad) + if hasattr(self, "_unsharded_param"): + assert ca.compiled_autograd_enabled + with torch.no_grad(): + alloc_storage(self._unsharded_param) + self._unsharded_param.copy_(unsharded_param) + else: + self._unsharded_param = nn.Parameter( + unsharded_param, requires_grad=self.sharded_param.requires_grad + ) def _unflatten_all_gather_outputs(self) -> Tuple[torch.Tensor, ...]: return tuple( @@ -493,6 +519,9 @@ def free_unsharded_param(self) -> None: self.all_gather_outputs, self._unsharded_inner_tensors ): free_storage(tensor) + if ca.compiled_autograd_enabled: + self.all_gather_outputs = [] + self._unsharded_inner_tensors = [] @property def all_gather_inputs(self) -> List[torch.Tensor]: # 1D From 6e7a23475d603b6d1a971c4c385cfdd7fd407474 Mon Sep 17 00:00:00 2001 From: James Wu Date: Sat, 8 Jun 2024 10:08:10 -0700 Subject: [PATCH 550/706] [easy] Run autograd if any mutations on inputs that require grad (#128229) If any inputs are mutated that require grad, even if all the outputs don't require grad, we should still run autograd with a backwards graph. This fixes two tests: test_input_mutation_alias_everything and test_view_detach. Fixes #128035 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128229 Approved by: https://github.com/aorenste --- test/functorch/test_aotdispatch.py | 26 +++++++++++++------------- torch/_functorch/aot_autograd.py | 13 +++++++++++-- 2 files changed, 24 insertions(+), 15 deletions(-) diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index bbb3ad5908f2..5046347c8d0c 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -68,6 +68,7 @@ skipIfRocm, skipIfTorchDynamo, TestCase, + xfail_inherited_tests, xfailIfTorchDynamo, ) from torch.testing._internal.hop_db import hop_db @@ -285,7 +286,7 @@ def is_in_base(t, maybe_tensors): return False -def skipIfDynamoInput(reason, xfail=False): +def skipIfDynamoInput(reason): """ Skip TestAOTAutograd if running with dynamo input """ @@ -293,16 +294,12 @@ def skipIfDynamoInput(reason, xfail=False): def decorator(func): @wraps(func) def wrapper(self, *args, **kwargs): - fn = func if isinstance(self, TestAOTAutogradWithDynamo): - if xfail: - fn = unittest.expectedFailure(fn) - else: - self.skipTest( - f"Skipping {self._testMethodName} in TestAOTAutogradWithDynamo because {reason}" - ) + self.skipTest( + f"Skipping {self._testMethodName} in TestAOTAutogradWithDynamo because {reason}" + ) else: - fn(self, *args, **kwargs) + func(self, *args, **kwargs) return wrapper @@ -620,7 +617,6 @@ def forward(self, primals_1, primals_2): # https://github.com/pytorch/pytorch/issues/126236 # https://github.com/pytorch/pytorch/pull/126113 @xfailIfTorchDynamo - @skipIfDynamoInput("Not supported by dynamo", xfail=True) def test_set__and_data_mutation_bad(self): def f(a): a_view = a.view(-1) @@ -1851,7 +1847,6 @@ def forward(self, primals_1): ) @parametrize("req_grad", [False, True]) - @skipIfDynamoInput("Runtime error not raised with dynamo", xfail=True) def test_subclass_metadata_mutation(self, req_grad): def f(a): a.transpose_(1, 0) @@ -1925,7 +1920,6 @@ def forward(self, primals_1, primals_2): return [t, view_1, view_2]""", ) - @skipIfDynamoInput("https://github.com/pytorch/pytorch/issues/128035", xfail=True) def test_view_detach(self): def f(a): tmp = a.detach() @@ -2644,7 +2638,6 @@ def inp_callable(): self.verify_aot_autograd(f, inp_callable, test_mutation=True) - @skipIfDynamoInput("https://github.com/pytorch/pytorch/issues/128035", xfail=True) def test_input_mutation_alias_everything(self): # Mondo test that tests a combination of: # input is mutated, that aliases another input (so we make a synthetic base) @@ -5868,6 +5861,13 @@ def test_aot_autograd_symbolic_module_exhaustive( instantiate_device_type_tests(TestEagerFusionModuleInfo, globals(), only_for=only_for) +@xfail_inherited_tests( + [ + "test_set__and_data_mutation_bad", + "test_subclass_metadata_mutation_req_grad_True", + "test_subclass_metadata_mutation_req_grad_False", + ] +) @skipIfTorchDynamo("This test suite already uses dynamo") class TestAOTAutogradWithDynamo(TestAOTAutograd): """ diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index 4dc854781e40..d6b084537567 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -570,10 +570,19 @@ def convert(idx, x): fake_flat_args, fw_metadata ) - if needs_autograd and not any( + output_and_mutation_safe = not any( x.requires_grad for x in fw_metadata.output_info - ): + ) and not any( + x.requires_grad + and x.mutates_data + and not x.mutations_under_no_grad_or_inference_mode + and not x.mutations_hidden_from_autograd + for x in fw_metadata.input_info + ) + + if needs_autograd and output_and_mutation_safe: # We realized that none of the outputs require grad, + # and none of the inputs that require grad are mutated. # so we actually have an inference graph. needs_autograd = False # A bit silly: right now in the subclass codepath, our ViewAndMutationMeta From d34075e0bd3bfb036adf0a6c996d8843aab27f84 Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Sat, 8 Jun 2024 22:41:05 +0000 Subject: [PATCH 551/706] Add Efficient Attention support on ROCM (#124885) This patch implements `with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):` by reusing AOTriton's accelerated SDPA implementation Known limitations: - Only supports MI200/MI300X GPUs - Does not support varlen - Does not support `CausalVariant` - Optional arguments `causal_diagonal` and `seqlen_k` in `_efficient_attention_forward/backward` must be null - Does not work well with inductor's SDPA rewriter. The rewriter has been updated to only use math and flash attention on ROCM. This PR also uses a different approach of installing AOTriton binary instead of building it from source in the base docker image. More details on motivation: https://github.com/pytorch/pytorch/pull/124885#issuecomment-2153229129 `PYTORCH_TEST_WITH_ROCM=1 PYTORCH_TESTING_DEVICE_ONLY_FOR="cuda" python test/test_transformers.py` yields "55028 passed, 20784 skipped" results with this change. [Previous result](https://hud.pytorch.org/pr/127528) of `test_transformers.py` was 0 error, 0 failure, 55229 skipped out of 75517 tests in total (the XML report does not contain total number of passed tests). Pull Request resolved: https://github.com/pytorch/pytorch/pull/124885 Approved by: https://github.com/malfet --- .ci/docker/aotriton_version.txt | 5 + .ci/docker/centos-rocm/Dockerfile | 14 +- .ci/docker/ci_commit_pins/aotriton.txt | 1 - .ci/docker/common/install_aotriton.sh | 31 ++--- .ci/docker/ubuntu-rocm/Dockerfile | 14 +- CMakeLists.txt | 7 +- .../native/transformers/cuda/attention.cu | 84 ++++++++++- .../transformers/cuda/attention_backward.cu | 69 +++++++++- .../native/transformers/cuda/sdp_utils.cpp | 24 ++++ .../transformers/hip/aotriton_adapter.h | 130 ++++++++++++++++++ .../transformers/hip/flash_attn/flash_api.hip | 94 ++----------- caffe2/CMakeLists.txt | 3 + cmake/External/aotriton.cmake | 4 +- cmake/Summary.cmake | 1 + test/distributed/_tensor/test_attention.py | 27 +++- test/inductor/test_fused_attention.py | 2 + test/test_flop_counter.py | 2 + test/test_transformers.py | 50 +++++-- torch/_inductor/fx_passes/fuse_attention.py | 53 ++++--- torch/testing/_internal/common_cuda.py | 9 +- .../_internal/common_methods_invocations.py | 35 ++--- 21 files changed, 481 insertions(+), 178 deletions(-) create mode 100644 .ci/docker/aotriton_version.txt delete mode 100644 .ci/docker/ci_commit_pins/aotriton.txt mode change 100644 => 100755 .ci/docker/common/install_aotriton.sh create mode 100644 aten/src/ATen/native/transformers/hip/aotriton_adapter.h diff --git a/.ci/docker/aotriton_version.txt b/.ci/docker/aotriton_version.txt new file mode 100644 index 000000000000..d13e9d756c95 --- /dev/null +++ b/.ci/docker/aotriton_version.txt @@ -0,0 +1,5 @@ +0.6b +manylinux_2_17 +rocm6 +04b5df8c8123f90cba3ede7e971e6fbc6040d506 +3db6ecbc915893ff967abd6e1b43bd5f54949868873be60dc802086c3863e648 diff --git a/.ci/docker/centos-rocm/Dockerfile b/.ci/docker/centos-rocm/Dockerfile index 38d2ff4ed9ab..bfac9ddd8590 100644 --- a/.ci/docker/centos-rocm/Dockerfile +++ b/.ci/docker/centos-rocm/Dockerfile @@ -113,18 +113,18 @@ COPY triton_version.txt triton_version.txt RUN if [ -n "${TRITON}" ]; then bash ./install_triton.sh; fi RUN rm install_triton.sh common_utils.sh triton-rocm.txt triton_version.txt +# Install AOTriton (Early fail) +COPY ./aotriton_version.txt aotriton_version.txt +COPY ./common/common_utils.sh common_utils.sh +COPY ./common/install_aotriton.sh install_aotriton.sh +RUN ["/bin/bash", "-c", "./install_aotriton.sh /opt/rocm && rm -rf install_aotriton.sh aotriton_version.txt common_utils.sh"] +ENV AOTRITON_INSTALLED_PREFIX /opt/rocm/aotriton + # Install ccache/sccache (do this last, so we get priority in PATH) COPY ./common/install_cache.sh install_cache.sh ENV PATH /opt/cache/bin:$PATH RUN bash ./install_cache.sh && rm install_cache.sh -# Install AOTriton -COPY ci_commit_pins/aotriton.txt aotriton.txt -COPY ./common/common_utils.sh common_utils.sh -COPY ./common/install_aotriton.sh install_aotriton.sh -RUN bash ./install_aotriton.sh /opt/rocm/aotriton && rm -rf install_aotriton.sh aotriton aotriton.txt common_utils.sh -ENV AOTRITON_INSTALLED_PREFIX /opt/rocm/aotriton - # Include BUILD_ENVIRONMENT environment variable in image ARG BUILD_ENVIRONMENT ENV BUILD_ENVIRONMENT ${BUILD_ENVIRONMENT} diff --git a/.ci/docker/ci_commit_pins/aotriton.txt b/.ci/docker/ci_commit_pins/aotriton.txt deleted file mode 100644 index adb49c304bf4..000000000000 --- a/.ci/docker/ci_commit_pins/aotriton.txt +++ /dev/null @@ -1 +0,0 @@ -24a3fe9cb57e5cda3c923df29743f9767194cc27 diff --git a/.ci/docker/common/install_aotriton.sh b/.ci/docker/common/install_aotriton.sh old mode 100644 new mode 100755 index 47c7a9df773f..da3fe468d3e8 --- a/.ci/docker/common/install_aotriton.sh +++ b/.ci/docker/common/install_aotriton.sh @@ -4,21 +4,20 @@ set -ex source "$(dirname "${BASH_SOURCE[0]}")/common_utils.sh" -AOTRITON_DIR="aotriton" -AOTRITON_PINNED_NAME="aotriton" # No .txt extension -AOTRITON_PINNED_COMMIT=$(get_pinned_commit ${AOTRITON_PINNED_NAME}) +TARBALL='aotriton.tar.bz2' +# This read command alwasy returns with exit code 1 +read -d "\n" VER MANYLINUX ROCMBASE PINNED_COMMIT SHA256 < aotriton_version.txt || true +ARCH=$(uname -m) AOTRITON_INSTALL_PREFIX="$1" +AOTRITON_URL="https://github.com/ROCm/aotriton/releases/download/${VER}/aotriton-${VER}-${MANYLINUX}_${ARCH}-${ROCMBASE}.tar.bz2" -git clone https://github.com/ROCm/aotriton.git "${AOTRITON_DIR}" -cd "${AOTRITON_DIR}" -git checkout "${AOTRITON_PINNED_COMMIT}" -git submodule sync --recursive -git submodule update --init --recursive --force --depth 1 -mkdir build -cd build -cmake .. -G Ninja -DCMAKE_INSTALL_PREFIX=./install_dir -DCMAKE_BUILD_TYPE=Release -DAOTRITON_COMPRESS_KERNEL=OFF -DAOTRITON_NO_PYTHON=ON -DAOTRITON_NO_SHARED=ON -ninja install -mkdir -p "${AOTRITON_INSTALL_PREFIX}" -cp -r install_dir/* "${AOTRITON_INSTALL_PREFIX}" -find /tmp/ -mindepth 1 -delete -rm -rf ~/.triton +cd "${AOTRITON_INSTALL_PREFIX}" +# Must use -L to follow redirects +curl -L --retry 3 -o "${TARBALL}" "${AOTRITON_URL}" +ACTUAL_SHA256=$(sha256sum "${TARBALL}" | cut -d " " -f 1) +if [ "${SHA256}" != "${ACTUAL_SHA256}" ]; then + echo -n "Error: The SHA256 of downloaded tarball is ${ACTUAL_SHA256}," + echo " which does not match the expected value ${SHA256}." + exit +fi +tar xf "${TARBALL}" && rm -rf "${TARBALL}" diff --git a/.ci/docker/ubuntu-rocm/Dockerfile b/.ci/docker/ubuntu-rocm/Dockerfile index 111a727fe5b8..ee9ede8ba611 100644 --- a/.ci/docker/ubuntu-rocm/Dockerfile +++ b/.ci/docker/ubuntu-rocm/Dockerfile @@ -105,18 +105,18 @@ COPY triton_version.txt triton_version.txt RUN if [ -n "${TRITON}" ]; then bash ./install_triton.sh; fi RUN rm install_triton.sh common_utils.sh triton-rocm.txt triton_version.txt -# Install ccache/sccache (do this last, so we get priority in PATH) -COPY ./common/install_cache.sh install_cache.sh -ENV PATH /opt/cache/bin:$PATH -RUN bash ./install_cache.sh && rm install_cache.sh - # Install AOTriton -COPY ci_commit_pins/aotriton.txt aotriton.txt +COPY ./aotriton_version.txt aotriton_version.txt COPY ./common/common_utils.sh common_utils.sh COPY ./common/install_aotriton.sh install_aotriton.sh -RUN bash ./install_aotriton.sh /opt/rocm/aotriton && rm -rf install_aotriton.sh aotriton aotriton.txt common_utils.sh +RUN ["/bin/bash", "-c", "./install_aotriton.sh /opt/rocm && rm -rf install_aotriton.sh aotriton_version.txt common_utils.sh"] ENV AOTRITON_INSTALLED_PREFIX /opt/rocm/aotriton +# Install ccache/sccache (do this last, so we get priority in PATH) +COPY ./common/install_cache.sh install_cache.sh +ENV PATH /opt/cache/bin:$PATH +RUN bash ./install_cache.sh && rm install_cache.sh + # Include BUILD_ENVIRONMENT environment variable in image ARG BUILD_ENVIRONMENT ENV BUILD_ENVIRONMENT ${BUILD_ENVIRONMENT} diff --git a/CMakeLists.txt b/CMakeLists.txt index 1264540c6875..c4cd4b2c2a98 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -865,12 +865,13 @@ cmake_dependent_option( # Suspect users building from source will need this add_definitions(-DFLASHATTENTION_DISABLE_ALIBI) -# CAVEAT: Again, do not check USE_ROCM here Flash Attention2 will error while -# building for sm52 while Mem Eff Attention won't +# CAVEAT: Again, Flash Attention2 will error while building for sm52 while Mem +# Eff Attention won't cmake_dependent_option( USE_MEM_EFF_ATTENTION "Enable memory-efficient attention for scaled dot product attention.\ - Will be disabled if not supported by the platform" ON "USE_CUDA" OFF) + Will be disabled if not supported by the platform" ON + "USE_CUDA OR USE_ROCM" OFF) if(DEBUG_CUDA) string(APPEND CMAKE_CUDA_FLAGS_DEBUG " -lineinfo") diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu index b3f07206ccbe..1a5dbe3a6911 100644 --- a/aten/src/ATen/native/transformers/cuda/attention.cu +++ b/aten/src/ATen/native/transformers/cuda/attention.cu @@ -72,10 +72,17 @@ #include #endif #ifdef USE_MEM_EFF_ATTENTION -// MemoryEfficient Attention Specific Imports +#ifndef USE_ROCM +// MemoryEfficient Attention Specific Imports for CUDA #include #include #include +#else +// MemoryEfficient Attention Specific Imports for ROCM +#include +#include +#include +#endif #endif namespace at { @@ -1062,6 +1069,64 @@ std::tuple _efficient_ offset_t = at::empty({}, at::dtype(at::kLong).device(device)); } +#ifdef USE_ROCM + // ROCM Implementation + auto ret = aotriton::v2::flash::check_gpu(stream); + if (hipSuccess != ret) { + TORCH_CHECK(false, + "[AOTriton] Accelerated SDPA only supports MI200/MI300X GPUs (gfx90a:sramecc+:xnack- or gfx94a:sramecc+:xnack-)") + } + + // AOTriton may accept aligned on logsumexp tensor in the future for better + // performance, but for now it requires compact logsumexp tensor, even if + // compute_logsumexp is false + constexpr int kAlignLSE = 1; + res = at::empty({B, M, num_heads, Kv}, query.options()); + logsumexp = at::empty( + { B, num_heads, max_seqlen_q }, + query.options().dtype(at::ScalarType::Float)); + at::Tensor softmax_lse = logsumexp.view({B * num_heads, max_seqlen_q}); + at::Tensor q_t = query.transpose(1, 2); + at::Tensor k_t = key.transpose(1, 2); + at::Tensor v_t = value.transpose(1, 2); + at::Tensor output_t = res.transpose(1, 2); + bool is_causal; + if (static_cast(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) { + is_causal = true; + } else if (static_cast(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) { + is_causal = false; + } else { + TORCH_CHECK(false, "[_efficient_attention_forward] Unsupported mask type on ROCM, for now"); + } + + const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked(); + + using aotriton::v2::flash::attn_fwd; + using sdp::aotriton_adapter::mk_aotensor; + aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, aotriton::DType::kFloat16); + at::Tensor softmax_fa_t = at::empty({ 0, 0, 0, 0 }, query.options()); + hipError_t err; // TODO: Error handling + err = attn_fwd(mk_aotensor(q_t, "q"), + mk_aotensor(k_t, "k"), + mk_aotensor(v_t, "v"), + bias.has_value() ? mk_aotensor(bias.value(), "bias"): empty_t4, + softmax_scale, + mk_aotensor<2>(softmax_lse, "M"), + mk_aotensor(output_t, "Out"), + dropout_p, + use_dropout ? *seed_t.data_ptr() : 0, + use_dropout ? *offset_t.data_ptr() : 0, + mk_aotensor(softmax_fa_t, "encoded_softmax"), + is_causal, + stream); + if (!compute_logsumexp) { + // Set the tensor to empty when compute_logsumexp is false + logsumexp = at::empty( + { B * num_heads, max_seqlen_q, 0 }, + query.options().dtype(at::ScalarType::Float)); + } +#else + // CUDA Implementation cudaDeviceProp* p = at::cuda::getDeviceProperties(query.device().index()); const int computeCapability = p->major * 10 + p->minor; @@ -1231,6 +1296,7 @@ std::tuple _efficient_ TORCH_CHECK(kernel_launched, "cutlassF: no kernel found to launch!"); AT_CUDA_CHECK(cudaGetLastError()); +#endif // USE_ROCM return std::make_tuple( std::move(res), std::move(logsumexp), @@ -1251,7 +1317,7 @@ Tensor triton_scaled_dot_attention(const Tensor& q, const Tensor& k, const Tenso REGISTER_CUDA_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cuda); -#ifdef USE_MEM_EFF_ATTENTION +#if defined(USE_MEM_EFF_ATTENTION) and !defined(USE_ROCM) namespace { /** * simple kernel that populates a tensor with rand uniform values. @@ -1301,7 +1367,7 @@ __global__ void rand_uniform_kernel( } } } // namespace -#endif +#endif // defined(USE_MEM_EFF_ATTENTION) and !defined(USE_ROCM) /** * fill tensor with random uniform values. only used for testing, not much * attention is paid to performance @@ -1319,6 +1385,17 @@ at::Tensor& _fill_mem_eff_dropout_mask_( const int64_t n_keys = self.size(3); #if defined(USE_MEM_EFF_ATTENTION) +#ifdef USE_ROCM + using aotriton::v2::flash::debug_fill_dropout_rng; + using sdp::aotriton_adapter::mk_aotensor; + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + hipError_t err; // TODO: Error handling + + err = debug_fill_dropout_rng(mk_aotensor(self, "r"), + static_cast(seed), + static_cast(offset), + stream); +#else at::PhiloxCudaState rng_engine_inputs; rng_engine_inputs = at::PhiloxCudaState(seed, offset); at::cuda::CUDAGuard device_guard(self.device()); @@ -1332,6 +1409,7 @@ at::Tensor& _fill_mem_eff_dropout_mask_( rng_engine_inputs, reinterpret_cast(self.data_ptr()), self.numel()); +#endif return self; #endif diff --git a/aten/src/ATen/native/transformers/cuda/attention_backward.cu b/aten/src/ATen/native/transformers/cuda/attention_backward.cu index 5d9f0ce98474..af9da7b8835b 100644 --- a/aten/src/ATen/native/transformers/cuda/attention_backward.cu +++ b/aten/src/ATen/native/transformers/cuda/attention_backward.cu @@ -36,11 +36,18 @@ #include #endif #ifdef USE_MEM_EFF_ATTENTION -// MemoryEfficient Attention Specific Imports +#ifndef USE_ROCM +// MemoryEfficient Attention Specific Imports for CUDA #include #include #include #include +#else +// MemoryEfficient Attention Specific Imports for ROCM +#include +#include +#include +#endif #endif #ifdef __HIP_PLATFORM_AMD__ @@ -348,7 +355,6 @@ _efficient_attention_backward( grad_bias = at::empty(sz, bias->options()) .slice(/*dim=*/-1, /*start=*/0, /*end=*/lastDim); } - at::Tensor workspace; const bool use_dropout = std::fpclassify(dropout_p) != FP_ZERO; @@ -368,6 +374,62 @@ _efficient_attention_backward( } } +#ifdef USE_ROCM + // ROCM Implementation + TORCH_CHECK(!num_splits_key.has_value(), + "ROCM does not support num_split_keys in _efficient_attention_forward"); + TORCH_CHECK(!window_size.has_value(), + "ROCM does not support window_size in _efficient_attention_forward"); + auto ret = aotriton::v2::flash::check_gpu(stream); + if (hipSuccess != ret) { + TORCH_CHECK(false, + "[AOTriton] Accelerated SDPA only supports MI200/MI300X GPUs (gfx90a:sramecc+:xnack- or gfx942:sramecc+:xnack-)") + } + const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked(); + bool is_causal; + if (static_cast(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) { + is_causal = true; + } else if (static_cast(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) { + is_causal = false; + } else { + TORCH_CHECK(false, "[_efficient_attention_backward] Unsupported mask type in AOTriton, for now"); + } + at::Tensor q_t = query.permute({0,2,1,3}); + at::Tensor k_t = key.permute({0,2,1,3}); + at::Tensor v_t = value.permute({0,2,1,3}); + at::Tensor out_t = out.permute({0,2,1,3}); + at::Tensor dq_t = grad_q.permute({0,2,1,3}); + at::Tensor dk_t = grad_k.permute({0,2,1,3}); + at::Tensor dv_t = grad_v.permute({0,2,1,3}); + at::Tensor dout_t = grad_out.permute({0,2,1,3}); + at::Tensor softmax_lse = logsumexp.view({B * nH, max_seqlen_q}); + at::Tensor delta = at::empty_like(softmax_lse).contiguous(); + + hipError_t err; + using aotriton::v2::flash::attn_bwd; + using sdp::aotriton_adapter::mk_aotensor; + using sdp::aotriton_adapter::cast_dtype; + aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, cast_dtype(query.dtype())); + err = attn_bwd(mk_aotensor(q_t, "q"), + mk_aotensor(k_t, "k"), + mk_aotensor(v_t, "v"), + bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4, + softmax_scale, + mk_aotensor(out_t, "out"), + mk_aotensor(dout_t, "dout"), + mk_aotensor(dq_t, "dq"), + mk_aotensor(dk_t, "dk"), + mk_aotensor(dv_t, "dv"), + bias_requires_grad ? mk_aotensor(grad_bias, "db") : empty_t4, + mk_aotensor<2>(softmax_lse, "L"), + mk_aotensor<2>(delta, "delta"), + float(dropout_p), + rng_engine_inputs.seed_.val, + rng_engine_inputs.offset_.val, + is_causal, + stream); +#else + at::Tensor workspace; cudaDeviceProp* p = at::cuda::getDeviceProperties(query.device().index()); const int computeCapability = p->major * 10 + p->minor; @@ -624,8 +686,9 @@ _efficient_attention_backward( })); TORCH_CHECK(kernel_launched, "cutlassB: no kernel found to launch!"); AT_CUDA_CHECK(cudaGetLastError()); +#endif // USE_ROCM return std::make_tuple(std::move(grad_q), std::move(grad_k), std::move(grad_v), std::move(grad_bias)); - #endif + #endif // defined(USE_MEM_EFF_ATTENTION) TORCH_CHECK(false, "USE_MEM_EFF_ATTENTION was not enabled for build.") return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{}); } diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp index 372377e1eca6..b474e4ee2a7c 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp @@ -215,6 +215,17 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug) // Mem Efficient attention supports hardware in the range [sm_50, sm_90] using sm50 = SMVersion<5, 0>; using sm90 = SMVersion<9, 0>; +#if USE_ROCM + auto stream = at::cuda::getCurrentCUDAStream().stream(); + if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) { + auto dprops = at::cuda::getCurrentDeviceProperties(); + if (debug) { + TORCH_WARN( + "Mem Efficient attention was not compiled for current AMD GPU architecture. Attempting to run on architecture ", dprops->gcnArchName); + } + return false; + } +#else auto dprops = at::cuda::getCurrentDeviceProperties(); if (!check_sm_version(dprops)) { if (debug) { @@ -227,6 +238,7 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug) } return false; } +#endif return true; } @@ -597,6 +609,10 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) { array_of(at::kHalf, at::kFloat, at::kBFloat16); constexpr auto less_than_sm80_mem_efficient_dtypes = array_of(at::kHalf, at::kFloat); +#ifdef USE_ROCM + constexpr auto aotriton_mem_efficient_dtypes = + array_of(at::kHalf, at::kFloat, at::kBFloat16); +#endif // Define gate functions that determine if a mem efficient kernel can be ran constexpr auto general_constraints = array_of( @@ -612,6 +628,10 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) { } if (has_for_nested_inputs(params)) { +#ifdef USE_ROCM + TORCH_WARN_ONCE(false, "[ROCM] no support for nested tensors in memory efficient attention."); + return false; +#endif constexpr auto nested_constraints = array_of( check_requires_grad_and_nested, check_batch_size_nested, @@ -634,10 +654,14 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) { } } +#ifdef USE_ROCM + return check_tensor_dtype(params, aotriton_mem_efficient_dtypes, debug); +#else auto dprop = at::cuda::getCurrentDeviceProperties(); if (dprop->major >= 8) { return check_tensor_dtype(params, greater_than_or_equal_sm80_mem_efficient_dtypes, debug); } +#endif return check_tensor_dtype(params, less_than_sm80_mem_efficient_dtypes, debug); } diff --git a/aten/src/ATen/native/transformers/hip/aotriton_adapter.h b/aten/src/ATen/native/transformers/hip/aotriton_adapter.h new file mode 100644 index 000000000000..1c238c751a05 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/aotriton_adapter.h @@ -0,0 +1,130 @@ +#pragma once + +#ifdef USE_ROCM + +#include +#include + +//////////////////////////////////////////////////////////////////////////////// +// Common macros copied from cuda/mem_eff_attention/gemm_kernel_utils.h +//////////////////////////////////////////////////////////////////////////////// + +#define CHECK_NOSPARSE_CONTIGUOUS_CUDA(TENSOR) \ + TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \ + TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ + TORCH_CHECK(TENSOR.is_contiguous()); + +#define CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(TENSOR) \ + TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \ + TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ + TORCH_CHECK( \ + TENSOR.stride(-1) == 1, #TENSOR ": last dimension must be contiguous"); + +#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ + TORCH_CHECK( \ + uint64_t(PTR) % ALIGNMENT == 0, #PTR " is not correctly aligned") + +#define ASSIGN_CHECK_OVERFLOW(A, B) \ + { \ + A = B; \ + TORCH_CHECK( \ + B < std::numeric_limits::max(), #B " overflows"); \ + } + +namespace sdp { + +namespace aotriton_adapter { + +inline aotriton::DType cast_dtype(caffe2::TypeMeta t_dtype) +{ +#define CAST_TYPE(aname, dtname) if (t_dtype == at::aname) return aotriton::DType::dtname + CAST_TYPE(kByte, kUInt8); + CAST_TYPE(kUInt16, kUInt16); + CAST_TYPE(kUInt32, kUInt32); + CAST_TYPE(kUInt64, kUInt64); + CAST_TYPE(kChar, kInt8); + CAST_TYPE(kShort, kInt16); + CAST_TYPE(kInt, kInt32); + CAST_TYPE(kLong, kInt64); + CAST_TYPE(kHalf, kFloat16); + CAST_TYPE(kFloat, kFloat32); + CAST_TYPE(kBFloat16, kBFloat16); + return aotriton::DType::kUnknown; +#undef CAST_TYPE +} + +template +struct IntArrayRefCaster { + // std::array cast(IntArrayRef); +}; + +template +struct IntArrayRefCaster { + static auto cast(at::IntArrayRef ref) { + return std::array{{ static_cast(ref.at(0)) }}; + } +}; + +template +struct IntArrayRefCaster { + static auto cast(at::IntArrayRef ref) { + return std::array{{ + static_cast(ref.at(0)), + static_cast(ref.at(1)) + }}; + } +}; + +template +struct IntArrayRefCaster { + static auto cast(at::IntArrayRef ref) { + return std::array{{ + static_cast(ref.at(0)), + static_cast(ref.at(1)), + static_cast(ref.at(2)) + }}; + } +}; + +template +struct IntArrayRefCaster { + static auto cast(at::IntArrayRef ref) { + return std::array{{ + static_cast(ref.at(0)), + static_cast(ref.at(1)), + static_cast(ref.at(2)), + static_cast(ref.at(3)) + }}; + } +}; + + +template +aotriton::TensorView mk_aotensor(const at::Tensor& q, c10::string_view tensor_name) +{ + const auto strides = q.strides(); + int real_rank = strides.size(); + if (real_rank != Rank) { // Lazy convertion of tensor_name + TORCH_CHECK(false, + std::string(tensor_name) + "'s rank should be " + std::to_string(Rank) + + " but is " + std::to_string(real_rank)); + } + return aotriton::TensorView(reinterpret_cast(q.data_ptr()), + IntArrayRefCaster::cast(q.sizes()), + IntArrayRefCaster::cast(strides), + cast_dtype(q.dtype())); +} + +} // namespace aotriton_adapter + +} // namespace sdp + +namespace at::native { + +inline int64_t ceil_div(int64_t numerator, int64_t denominator) { + return (numerator + (denominator - 1)) / denominator; +} + +} + +#endif // USE_ROCM diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip index e110e4ae1c64..7af480a7ae49 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip @@ -54,16 +54,15 @@ #include #endif +#include #include #include #include // AOTriton headers -#include #include #include -#include namespace pytorch_flash { @@ -73,90 +72,10 @@ void check_gpu_arch(hipStream_t stream) { auto ret = aotriton::v2::flash::check_gpu(stream); if (hipSuccess != ret) { TORCH_CHECK(false, - "FlashAttention only supports MI200/MI300X GPUs (gfx90a:sramecc+:xnack- or gfx94a:sramecc+:xnack-)") + "FlashAttention only supports MI200/MI300X GPUs (gfx90a:sramecc+:xnack- or gfx942:sramecc+:xnack-)") } } -aotriton::DType cast_dtype(caffe2::TypeMeta t_dtype) -{ -#define CAST_TYPE(aname, dtname) if (t_dtype == at::aname) return aotriton::DType::dtname - CAST_TYPE(kByte, kUInt8); - CAST_TYPE(kUInt16, kUInt16); - CAST_TYPE(kUInt32, kUInt32); - CAST_TYPE(kUInt64, kUInt64); - CAST_TYPE(kChar, kInt8); - CAST_TYPE(kShort, kInt16); - CAST_TYPE(kInt, kInt32); - CAST_TYPE(kLong, kInt64); - CAST_TYPE(kHalf, kFloat16); - CAST_TYPE(kFloat, kFloat32); - CAST_TYPE(kBFloat16, kBFloat16); - return aotriton::DType::kUnknown; -#undef CAST_TYPE -} - -template -struct IntArrayRefCaster { - // std::array cast(IntArrayRef); -}; - -template -struct IntArrayRefCaster { - static auto cast(at::IntArrayRef ref) { - return std::array{{ static_cast(ref.at(0)) }}; - } -}; - -template -struct IntArrayRefCaster { - static auto cast(at::IntArrayRef ref) { - return std::array{{ - static_cast(ref.at(0)), - static_cast(ref.at(1)) - }}; - } -}; - -template -struct IntArrayRefCaster { - static auto cast(at::IntArrayRef ref) { - return std::array{{ - static_cast(ref.at(0)), - static_cast(ref.at(1)), - static_cast(ref.at(2)) - }}; - } -}; - -template -struct IntArrayRefCaster { - static auto cast(at::IntArrayRef ref) { - return std::array{{ - static_cast(ref.at(0)), - static_cast(ref.at(1)), - static_cast(ref.at(2)), - static_cast(ref.at(3)) - }}; - } -}; - - -template -aotriton::TensorView mk_aotensor(const at::Tensor& q, c10::string_view tensor_name) -{ - const auto strides = q.strides(); - int real_rank = strides.size(); - if (real_rank != Rank) { // Lazy convertion of tensor_name - TORCH_CHECK(false, - std::string(tensor_name) + "'s rank should be " + std::to_string(Rank) - + " but is " + std::to_string(real_rank)); - } - return aotriton::TensorView(reinterpret_cast(q.data_ptr()), - IntArrayRefCaster::cast(q.sizes()), - IntArrayRefCaster::cast(strides), - cast_dtype(q.dtype())); -} - } #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") @@ -300,9 +219,13 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head hipError_t err; // TODO: Error handling using aotriton::v2::flash::attn_fwd; + using sdp::aotriton_adapter::mk_aotensor; + using sdp::aotriton_adapter::cast_dtype; + aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype())); err = attn_fwd(mk_aotensor(q_t, "q"), mk_aotensor(k_t, "k"), mk_aotensor(v_t, "v"), + empty_bias, softmax_scale, mk_aotensor<2>(M, "M"), mk_aotensor(output_t, "Out"), @@ -495,15 +418,20 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si hipError_t err; // TODO: Error handling { using aotriton::v2::flash::attn_bwd; + using sdp::aotriton_adapter::mk_aotensor; + using sdp::aotriton_adapter::cast_dtype; + aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype())); err = attn_bwd(mk_aotensor(q_t, "q"), mk_aotensor(k_t, "k"), mk_aotensor(v_t, "v"), + empty_bias, softmax_scale, mk_aotensor(out_t, "out"), mk_aotensor(dout_t, "dout"), mk_aotensor(dq_t, "dq"), mk_aotensor(dk_t, "dk"), mk_aotensor(dv_t, "dv"), + empty_bias, mk_aotensor<2>(softmax_lse_cont, "L"), mk_aotensor<2>(delta, "delta"), p_dropout, diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 0d64fe75be41..89c31fab1134 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1321,6 +1321,9 @@ if(USE_ROCM) if(USE_FLASH_ATTENTION) target_compile_definitions(torch_hip PRIVATE USE_FLASH_ATTENTION) endif() + if(USE_MEM_EFF_ATTENTION) + target_compile_definitions(torch_hip PRIVATE USE_MEM_EFF_ATTENTION) + endif() endif() if(BUILD_LITE_INTERPRETER) diff --git a/cmake/External/aotriton.cmake b/cmake/External/aotriton.cmake index c95c66626837..ec6f09b60533 100644 --- a/cmake/External/aotriton.cmake +++ b/cmake/External/aotriton.cmake @@ -10,9 +10,11 @@ if(NOT __AOTRITON_INCLUDED) set(__AOTRITON_INSTALL_DIR "$ENV{AOTRITON_INSTALLED_PREFIX}") message(STATUS "Using Preinstalled AOTriton at ${__AOTRITON_INSTALL_DIR}") else() + file(STRINGS "${CMAKE_CURRENT_SOURCE_DIR}/.ci/docker/aotriton_version.txt" __AOTRITON_CI_INFO) + list(GET __AOTRITON_CI_INFO 3 __AOTRITON_CI_COMMIT) ExternalProject_Add(aotriton_external GIT_REPOSITORY https://github.com/ROCm/aotriton.git - GIT_TAG 24a3fe9cb57e5cda3c923df29743f9767194cc27 + GIT_TAG ${__AOTRITON_CI_COMMIT} SOURCE_DIR ${__AOTRITON_SOURCE_DIR} BINARY_DIR ${__AOTRITON_BUILD_DIR} PREFIX ${__AOTRITON_INSTALL_DIR} diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake index 289419c38603..aeb367690d3f 100644 --- a/cmake/Summary.cmake +++ b/cmake/Summary.cmake @@ -121,6 +121,7 @@ function(caffe2_print_configuration_summary) if(${USE_ROCM}) message(STATUS " ROCM_VERSION : ${ROCM_VERSION}") message(STATUS " USE_FLASH_ATTENTION : ${USE_FLASH_ATTENTION}") + message(STATUS " USE_MEM_EFF_ATTENTION : ${USE_MEM_EFF_ATTENTION}") endif() message(STATUS " BUILD_NVFUSER : ${BUILD_NVFUSER}") message(STATUS " USE_EIGEN_FOR_BLAS : ${CAFFE2_USE_EIGEN_FOR_BLAS}") diff --git a/test/distributed/_tensor/test_attention.py b/test/distributed/_tensor/test_attention.py index db5a26d43850..3979dd4ad546 100644 --- a/test/distributed/_tensor/test_attention.py +++ b/test/distributed/_tensor/test_attention.py @@ -17,12 +17,19 @@ ) from torch.distributed.tensor.parallel import parallelize_module from torch.nn.attention import sdpa_kernel, SDPBackend -from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION +from torch.testing._internal.common_cuda import ( + PLATFORM_SUPPORTS_FLASH_ATTENTION, + PLATFORM_SUPPORTS_FUSED_ATTENTION, + PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, + TEST_CUDA, +) from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, run_tests, + skipIfRocm, + TEST_WITH_ROCM, ) from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, @@ -41,6 +48,7 @@ def world_size(self) -> int: return 2 @skip_if_lt_x_gpu(2) + @skipIfRocm # Missing _c10d_functional_autograd::all_to_all_single @unittest.skipIf( not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention" ) @@ -299,18 +307,29 @@ def test_ring_attention_custom_transformer(self) -> None: @skip_if_lt_x_gpu(2) @unittest.skipIf( - not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention" + not PLATFORM_SUPPORTS_FUSED_ATTENTION, + "Does not support flash nor efficient attention", ) + @unittest.skipIf( + TEST_CUDA and not TEST_WITH_ROCM and not PLATFORM_SUPPORTS_FLASH_ATTENTION, + "Does not support flash attention", + ) # On CUDA (not ROCM) platform, the UT is skipped if no FA support (even if ME may get supported) @with_comms @parametrize( "attention_fn", [ - _scaled_dot_product_ring_flash_attention, - _scaled_dot_product_ring_efficient_attention, + _scaled_dot_product_ring_flash_attention + if PLATFORM_SUPPORTS_FLASH_ATTENTION + else None, + _scaled_dot_product_ring_efficient_attention + if PLATFORM_SUPPORTS_MEM_EFF_ATTENTION + else None, # _scaled_dot_product_ring_cudnn_attention, # TODO: not built by default ], ) def test_ring_attention_compile(self, attention_fn: object) -> None: + if attention_fn is None: + self.skipTest("Unsupported on current platform") device_mesh = DeviceMesh( self.device_type, torch.arange(0, self.world_size), diff --git a/test/inductor/test_fused_attention.py b/test/inductor/test_fused_attention.py index e53ab76036d6..c17d78f628a3 100644 --- a/test/inductor/test_fused_attention.py +++ b/test/inductor/test_fused_attention.py @@ -280,6 +280,7 @@ def dot_prod_attention( self._check_common(dot_prod_attention) self._check_common(checkpoint_wrapper(dot_prod_attention)) + @skipIfRocm # AssertionError: expected size 4==4, stride 32==64 at dim=0 def _test_sdpa_rewriter_3(self): def dot_prod_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, training: bool @@ -296,6 +297,7 @@ def dot_prod_attention( checkpoint_wrapper(dot_prod_attention), contains=False, has_dropout=True ) + @skipIfRocm # AssertionError: expected size 4==4, stride 32==64 at dim=0 def _test_sdpa_rewriter_4(self): def dot_prod_attention( query: torch.Tensor, diff --git a/test/test_flop_counter.py b/test/test_flop_counter.py index 4f9c7020c0e6..3f09a85a6d97 100644 --- a/test/test_flop_counter.py +++ b/test/test_flop_counter.py @@ -14,6 +14,7 @@ run_tests, TEST_WITH_TORCHDYNAMO, TestCase, + skipIfRocm, ) try: @@ -434,6 +435,7 @@ def get_flops( self.assertExpectedInline(str(flops_fw_bw_math), """805306368""") self.assertExpectedInline(str(flops_fw_bw_efficient), """939524096""") + @skipIfRocm # Nested tensor @unittest.skipIf(not HAS_CUDA, "CUDA not available") @unittest.skipIf( not PLATFORM_SUPPORTS_FLASH_ATTENTION diff --git a/test/test_transformers.py b/test/test_transformers.py index 73f838143dd5..774cb60ee94d 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -132,6 +132,10 @@ def get_platform_specific_sdpa(): return ret PLATFORM_SPECIFIC_SDPA = get_platform_specific_sdpa() +# Indicate the Efficient attention backend can support: +# 1. sequence longher than 512 +# 2. head dimsion larger than 64 +MEM_EFF_CAPABILITY_MATCHES_SM80 = SM80OrLater or TEST_WITH_ROCM def rand_sdpa_tensor(shape: SdpaShape, device: str, dtype: torch.dtype, type: str, requires_grad: bool = False, packed: bool = False) -> torch.Tensor: @@ -2255,6 +2259,8 @@ def test_singelton_head_dim_stride_ne_1(self, device): @parametrize("type", ["dense", "nested"]) @parametrize("is_contiguous", [True, False]) def test_scaled_dot_product_attention_fused_kernels_packed(self, device, type: str, is_contiguous: bool): + if TEST_WITH_ROCM and type == 'nested': + self.skipTest("ROCM does not support efficient attention on nested tensors, for now") make_tensor = partial(rand_sdpa_tensor, type=type, device=device, dtype=torch.float16, packed=True) batch_size, seq_len, num_heads, head_dim = 32, 64, 16, 64 @@ -2349,7 +2355,7 @@ def rand_tensor(shape): self.assertEqual(math_ref_test, math_ref_lp_test, atol=7e-3, rtol=7e-3) self.assertEqual(actual_test, math_ref_test, atol=5e-3, rtol=5e-3) - @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Flash Attention was not built for this system") + @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Efficient Attention was not built for this system") @parametrize("contiguous_inputs", [True, False]) @parametrize("is_causal", [True, False]) def test_sdp_mem_efficient_grad_against_math(self, device, contiguous_inputs: bool, is_causal: bool): @@ -2482,6 +2488,7 @@ def test_fused_sdp_choice(self, device, type: str): assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION.value + @skipIfRocm # Missing triton.float32 ("triton" prefix is to locate skipped UTs), and deterministic algo @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Platform does not support fused SDPA") @parametrize("warn_only", [True, False]) def test_sdp_choice_with_determinism(self, device, warn_only): @@ -2494,6 +2501,7 @@ def test_sdp_choice_with_determinism(self, device, warn_only): with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]): assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION.value + @skipIfRocm # Missing deterministic algo @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system") @parametrize("fused_kernel", PLATFORM_SPECIFIC_SDPA) @parametrize("warn_only", [True, False]) @@ -2572,13 +2580,16 @@ def test_mem_eff_backwards_determinism(self, device): @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA") @unittest.skipIf(IS_JETSON, "causing sigkill on Jetson") @parametrize("batch_size", [1, 8]) - @parametrize("seq_len_q", [4, 8, 64, 128, 256, 512, 1024, 2048] if SM80OrLater else [4, 8, 64, 128, 256, 512]) - @parametrize("seq_len_k", [4, 8, 64, 128, 256, 512, 1024, 2048] if SM80OrLater else [4, 8, 64, 128, 256, 512]) - @parametrize("head_dim", [8, 16, 32, 64, 72, 96, 128] if SM80OrLater else [8, 16, 32, 64]) + @parametrize("seq_len_q", [4, 8, 64, 128, 256, 512, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80 + else [4, 8, 64, 128, 256, 512]) + @parametrize("seq_len_k", [4, 8, 64, 128, 256, 512, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80 + else [4, 8, 64, 128, 256, 512]) + @parametrize("head_dim", [8, 16, 32, 64, 72, 96, 128] if MEM_EFF_CAPABILITY_MATCHES_SM80 + else [8, 16, 32, 64]) @parametrize("is_causal", [False, True]) @parametrize("dropout_p", [0.0, 0.22]) - @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32] if - SM80OrLater else [torch.float16, torch.float32]) + @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32] if MEM_EFF_CAPABILITY_MATCHES_SM80 + else [torch.float16, torch.float32]) @parametrize("scale", [None, "l1"]) def test_mem_efficient_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int, head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype, @@ -2591,6 +2602,8 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, if max(seq_len_q, seq_len_k) >= 2048 and torch.cuda.get_device_properties('cuda').total_memory < 40 * 2**30: unittest.skip("Reference implementation OOM") return + if TEST_WITH_ROCM and seq_len_q * seq_len_k * head_dim * batch_size > 1024 * 1024 * 128: + torch.cuda.empty_cache() # Prevent memory fragmentation seed = 42 scale = scale if scale is None else (1 / head_dim) n_heads = 4 @@ -2660,6 +2673,8 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(key_ref.grad, key_ref_lp.grad, key_fudge_factor) value_fudge_factor = 7 if not SM80OrLater and dtype == torch.float16 else 1.0 + if TEST_WITH_ROCM: + value_fudge_factor = max(2.0, value_fudge_factor) grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(value_ref.grad, value_ref_lp.grad, value_fudge_factor) self.assertEqual(out, out_ref.to(out.dtype), atol=output_ref_atol, rtol=output_ref_rtol) @@ -2674,13 +2689,16 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA") @unittest.skipIf(IS_JETSON, "causing sigkill on Jetson") @parametrize("batch_size", [1, 8]) - @parametrize("seq_len_q", [4, 8, 64, 128, 256, 312, 512, 1024, 2048] if SM80OrLater else [4, 8, 64, 128, 152, 256, 512]) - @parametrize("seq_len_k", [4, 8, 64, 65, 128, 256, 408, 512, 1024, 2048] if SM80OrLater else [4, 8, 37, 64, 128, 256, 512]) - @parametrize("head_dim", [8, 16, 32, 64, 72, 96, 128] if SM80OrLater else [8, 16, 32, 64]) + @parametrize("seq_len_q", [4, 8, 64, 128, 256, 312, 512, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80 + else [4, 8, 64, 128, 152, 256, 512]) + @parametrize("seq_len_k", [4, 8, 64, 65, 128, 256, 408, 512, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80 + else [4, 8, 37, 64, 128, 256, 512]) + @parametrize("head_dim", [8, 16, 32, 64, 72, 96, 128] if MEM_EFF_CAPABILITY_MATCHES_SM80 + else [8, 16, 32, 64]) @parametrize("is_causal", [False]) @parametrize("dropout_p", [0.0, 0.22]) - @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32] if - SM80OrLater else [torch.float16, torch.float32]) + @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32] if MEM_EFF_CAPABILITY_MATCHES_SM80 + else [torch.float16, torch.float32]) @parametrize("scale", [None, "l1"]) def test_mem_efficient_attention_attn_mask_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int, head_dim: int, is_causal: bool, @@ -2694,6 +2712,11 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, if max(seq_len_q, seq_len_k) >= 2048 and torch.cuda.get_device_properties('cuda').total_memory < 40 * 2**30: unittest.skip("Reference implementation OOM") return + if TEST_WITH_ROCM and dtype == torch.float32: + unittest.skip("Skip fp32 attn_mask gradients on ROCM, for now.") + return + if TEST_WITH_ROCM and seq_len_q * seq_len_k * head_dim * batch_size > 1024 * 1024 * 128: + torch.cuda.empty_cache() # Prevent memory fragmentation seed = 42 scale = scale if scale is None else (1 / head_dim) n_heads = 4 @@ -2772,6 +2795,8 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(key_ref.grad, key_ref_lp.grad, key_fudge_factor) value_fudge_factor = 7 if not SM80OrLater and dtype == torch.float16 else 1.0 + if TEST_WITH_ROCM: + value_fudge_factor = max(2.0, value_fudge_factor) grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(value_ref.grad, value_ref_lp.grad, value_fudge_factor) mask_fudge_factor = 12 if attn_mask.numel() > 512 else 22 @@ -2806,6 +2831,8 @@ def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_le self.skipTest("Flash attention on sm86, sm87, and sm89 for headdim > 192 currently disabled") if is_causal and seq_len_q != seq_len_k: self.skipTest("Flash V2 does not accept is_casual when seq_len_q != seq_len_k") + if TEST_WITH_ROCM and seq_len_q >= 1024 and seq_len_k >= 1024 and batch_size > 1: + torch.cuda.empty_cache() # Prevent memory fragmentation scale = scale if scale is None else (1 / head_dim) n_heads = 4 @@ -3191,6 +3218,7 @@ def _broadcast(t, batch_broadcasted, num_heads_broadcasted): self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2) + @skipIfRocm # Nested tensor @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system") def test_fused_kernels_nested_broadcasting_query_dense(self, device): rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=torch.float32) diff --git a/torch/_inductor/fx_passes/fuse_attention.py b/torch/_inductor/fx_passes/fuse_attention.py index 2a646bc4c4ce..fad49d404827 100644 --- a/torch/_inductor/fx_passes/fuse_attention.py +++ b/torch/_inductor/fx_passes/fuse_attention.py @@ -5,6 +5,7 @@ import math import torch +from torch.nn.attention import sdpa_kernel, SDPBackend from ..._dynamo.utils import counters from ..pattern_matcher import ( filter_nodes, @@ -17,6 +18,16 @@ aten = torch.ops.aten +if torch.version.hip: + + def _scaled_dot_product_attention(*args, **kwargs): + with sdpa_kernel(backends=[SDPBackend.MATH, SDPBackend.FLASH_ATTENTION]): + return aten.scaled_dot_product_attention(*args, **kwargs) + +else: + _scaled_dot_product_attention = aten.scaled_dot_product_attention + + def _sfdp_pattern_1(query, key, value, inv_scale): return ( torch.matmul(query, key.transpose(-2, -1)) @@ -28,7 +39,7 @@ def _sfdp_pattern_1(query, key, value, inv_scale): def _sfdp_replacement_1(query, key, value, inv_scale): counters["inductor"]["fuse_attention"] += 1 - return aten.scaled_dot_product_attention( + return _scaled_dot_product_attention( query.contiguous(), key.contiguous(), value.contiguous(), @@ -50,7 +61,7 @@ def _sfdp_pattern_2(query, key, value, scale_factor): def _sfdp_replacement_2(query, key, value, scale_factor): counters["inductor"]["fuse_attention"] += 1 - return aten.scaled_dot_product_attention( + return _scaled_dot_product_attention( query.contiguous(), key.contiguous(), value.contiguous(), @@ -72,7 +83,7 @@ def _sfdp_pattern_3(query, key, value, inv_scale_factor, dropout_p): def _sfdp_replacement_3(query, key, value, inv_scale_factor, dropout_p): counters["inductor"]["fuse_attention"] += 1 - return aten.scaled_dot_product_attention( + return _scaled_dot_product_attention( query.contiguous(), key.contiguous(), value.contiguous(), @@ -92,7 +103,7 @@ def _sfdp_pattern_4(query, key, value, scale_factor, dropout_p): def _sfdp_replacement_4(query, key, value, scale_factor, dropout_p): counters["inductor"]["fuse_attention"] += 1 - return aten.scaled_dot_product_attention( + return _scaled_dot_product_attention( query.contiguous(), key.contiguous(), value.contiguous(), @@ -113,7 +124,7 @@ def _sfdp_pattern_5(query, key, value, attn_mask): def _sfdp_replacement_5(query, key, value, attn_mask): counters["inductor"]["fuse_attention"] += 1 - return aten.scaled_dot_product_attention( + return _scaled_dot_product_attention( query.contiguous(), key.contiguous(), value.contiguous(), @@ -133,7 +144,7 @@ def _sfdp_pattern_6(query, key, value, attn_mask, dropout_p): def _sfdp_replacement_6(query, key, value, attn_mask, dropout_p): counters["inductor"]["fuse_attention"] += 1 - return aten.scaled_dot_product_attention( + return _scaled_dot_product_attention( query.contiguous(), key.contiguous(), value.contiguous(), @@ -168,7 +179,7 @@ def _sfdp_replacement_7(query, key, value, dropout_p): q = query.permute(0, 2, 1, 3) k = key.permute(0, 2, 1, 3) v = value.permute(0, 2, 1, 3) - return aten.scaled_dot_product_attention( + return _scaled_dot_product_attention( q, k, v, @@ -195,7 +206,7 @@ def _sfdp_replacement_8(query, key, value): q = query.permute(0, 2, 1, 3) k = key.permute(0, 2, 1, 3) v = value.permute(0, 2, 1, 3) - return aten.scaled_dot_product_attention( + return _scaled_dot_product_attention( q, k, v, @@ -223,7 +234,7 @@ def _sfdp_replacement_9(query, key, value, dropout_p): q = query.permute(0, 2, 1, 3) k = key.permute(0, 2, 1, 3) v = value.permute(0, 2, 1, 3) - return aten.scaled_dot_product_attention( + return _scaled_dot_product_attention( q, k, v, @@ -251,7 +262,7 @@ def _sfdp_replacement_10(query, key, value): q = query.permute(0, 2, 1, 3) k = key.permute(0, 2, 1, 3) v = value.permute(0, 2, 1, 3) - return aten.scaled_dot_product_attention( + return _scaled_dot_product_attention( q, k, v, @@ -271,7 +282,7 @@ def _sfdp_pattern_11(query, key, value, inv_scale): def _sfdp_replacement_11(query, key, value, inv_scale): counters["inductor"]["fuse_attention"] += 1 - return aten.scaled_dot_product_attention( + return _scaled_dot_product_attention( query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), @@ -294,7 +305,7 @@ def _sfdp_pattern_12(query, key, value, inv_scale_factor, dropout_p): def _sfdp_replacement_12(query, key, value, inv_scale_factor, dropout_p): counters["inductor"]["fuse_attention"] += 1 - return aten.scaled_dot_product_attention( + return _scaled_dot_product_attention( query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), @@ -313,7 +324,7 @@ def _sfdp_pattern_13(query, key, value, dropout_p): def _sfdp_replacement_13(query, key, value, dropout_p): counters["inductor"]["fuse_attention"] += 1 - return aten.scaled_dot_product_attention( + return _scaled_dot_product_attention( query.unsqueeze(0), key.unsqueeze(0), value.unsqueeze(0), @@ -337,7 +348,7 @@ def _sfdp_pattern_14(query, key, value, attn_mask, inv_scale): def _sfdp_replacement_14(query, key, value, attn_mask, inv_scale): counters["inductor"]["fuse_attention"] += 1 - return aten.scaled_dot_product_attention( + return _scaled_dot_product_attention( query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), @@ -370,11 +381,11 @@ def _sfdp_replacement_15(query, key, value, attn_mask, inv_scale): n_head = query.size(2) q_len = query.size(1) k_len = key.size(1) - # do attn_mask->logical_not() in aten.scaled_dot_product_attention + # do attn_mask->logical_not() in _scaled_dot_product_attention attn_mask = ( (attn_mask == 1).view((bs, 1, 1, k_len)).expand((bs, n_head, q_len, k_len)) ) - return aten.scaled_dot_product_attention( + return _scaled_dot_product_attention( query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), @@ -404,7 +415,7 @@ def _sfdp_pattern_16(query, key, value, attn_mask, inv_scale, dropout_p): def _sfdp_replacement_16(query, key, value, attn_mask, inv_scale, dropout_p): counters["inductor"]["fuse_attention"] += 1 - return aten.scaled_dot_product_attention( + return _scaled_dot_product_attention( query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), @@ -440,11 +451,11 @@ def _sfdp_replacement_17(query, key, value, attn_mask, inv_scale, dropout_p): n_head = query.size(2) q_len = query.size(1) k_len = key.size(1) - # do attn_mask->logical_not() in aten.scaled_dot_product_attention + # do attn_mask->logical_not() in _scaled_dot_product_attention attn_mask = ( (attn_mask == 1).view((bs, 1, 1, k_len)).expand((bs, n_head, q_len, k_len)) ) - return aten.scaled_dot_product_attention( + return _scaled_dot_product_attention( query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), @@ -489,7 +500,7 @@ def _sfdp_replacement_18(query, key, value, causal_mask, dropout_p): permuted_key = key.transpose(1, 2) permuted_value = value.transpose(1, 2) return ( - aten.scaled_dot_product_attention( + _scaled_dot_product_attention( query.transpose(1, 2), permuted_key, permuted_value, @@ -526,7 +537,7 @@ def _sfdp_replacement_19(query, key, value, causal_mask, attn_mask, dropout_p): counters["inductor"]["fuse_attention"] += 1 fill_value = torch.full((), -float("inf"), dtype=query.dtype, device=query.device) attn_mask = torch.where(causal_mask, attn_mask, fill_value) - return aten.scaled_dot_product_attention( + return _scaled_dot_product_attention( query, key, value, diff --git a/torch/testing/_internal/common_cuda.py b/torch/testing/_internal/common_cuda.py index 054f1a135740..189be09d8ba9 100644 --- a/torch/testing/_internal/common_cuda.py +++ b/torch/testing/_internal/common_cuda.py @@ -50,8 +50,15 @@ def evaluate_platform_supports_flash_attention(): return not IS_WINDOWS and SM80OrLater return False +def evaluate_platform_supports_efficient_attention(): + if TEST_WITH_ROCM: + return evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-') or evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-') + if TEST_CUDA: + return True + return False + PLATFORM_SUPPORTS_FLASH_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_flash_attention()) -PLATFORM_SUPPORTS_MEM_EFF_ATTENTION: bool = LazyVal(lambda: TEST_CUDA and not TEST_WITH_ROCM) +PLATFORM_SUPPORTS_MEM_EFF_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_efficient_attention()) # TODO(eqy): gate this against a cuDNN version PLATFORM_SUPPORTS_CUDNN_ATTENTION: bool = LazyVal(lambda: TEST_CUDA and not TEST_WITH_ROCM and torch.backends.cuda.cudnn_sdp_enabled()) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 151210cf9f53..476d85d5de6f 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -15923,15 +15923,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1): device_type='cpu'), DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace', device_type='cpu'), - # TODO: Do not work even on MI200 because of stride mismatching. - DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace', - device_type='cuda', dtypes=[torch.float16, torch.bfloat16], - active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION), - DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_meta_outplace', - device_type='cuda', dtypes=[torch.float16, torch.bfloat16], - active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION), - DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake_crossref_backward_amp', - device_type='cuda', active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION), # When changing input from Tensor to CompositeCompliantTensor, input.requires_grad() changes from true to false DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_backward', device_type='cpu'), @@ -15951,6 +15942,19 @@ def reference_flatten(input, start_dim=0, end_dim=-1): device_type='cuda', dtypes=(torch.bfloat16,), active_if=not SM80OrLater), DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace', device_type='cuda', dtypes=(torch.bfloat16,), active_if=not SM80OrLater), + # FIXME + DecorateInfo(unittest.skip('test_cow_input does not work with efficient attention on ROCM'), + 'TestCompositeCompliance', 'test_cow_input', + device_type='cuda', dtypes=(torch.bfloat16, torch.float16, torch.float32), + active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_MEM_EFF_ATTENTION), + DecorateInfo(unittest.skip('test_fake_crossref_backward_amp does not work with efficient attention on ROCM'), + 'TestFakeTensor', 'test_fake_crossref_backward_amp', + device_type='cuda', dtypes=(torch.bfloat16, torch.float16, torch.float32), + active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_MEM_EFF_ATTENTION), + DecorateInfo(unittest.skip('test_fake_crossref_backward_no_amp does not work with efficient attention on ROCM'), + 'TestFakeTensor', 'test_fake_crossref_backward_no_amp', + device_type='cuda', dtypes=(torch.bfloat16, torch.float16, torch.float32), + active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_MEM_EFF_ATTENTION), # registered in fake_impls.py instead of _meta_registrations.py, so meta kernels will fail. # However, for implementations that fall back to the constituent ops, the meta kernels may not # fail. Fused kernels will fail, whereas unfused kernels will not fail. @@ -15958,6 +15962,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1): # mem_eff_attention also supports fp32 - so if it is supported the test will fail. DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", dtypes=(torch.bfloat16, torch.float16), active_if=PLATFORM_SUPPORTS_FUSED_ATTENTION), + # TODO: float32 support in ROCM efficient attention DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", dtypes=(torch.float32,), active_if=PLATFORM_SUPPORTS_MEM_EFF_ATTENTION), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", @@ -15997,13 +16002,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1): DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', device_type='cuda'), # None Mismatch Tensor DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward', device_type='cuda'), - # TODO: Do not work on MI200 because of stride mismatching. - DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace', - device_type='cuda', dtypes=[torch.float16, torch.bfloat16], - active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION), - DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_meta_outplace', - device_type='cuda', dtypes=[torch.float16, torch.bfloat16], - active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION), ) ), OpInfo( @@ -16020,7 +16018,10 @@ def reference_flatten(input, start_dim=0, end_dim=-1): check_batched_forward_grad=False, # TODO: Skip because it produces a CUDA illegal memory access for some reason skip_cow_input_backward=True, - decorators=[skipCUDAIf(TEST_WITH_ROCM, "ROCm doesn't support efficient attention")], + # FIXME: mask_type == 2 (LowerRight) + decorators=[ + skipCUDAIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "This platform doesn't support efficient attention"), + skipCUDAIf(TEST_WITH_ROCM, "Efficient attention on ROCM doesn't support custom_mask_type==2")], skips=( # Device mismatch due to philox seed and offset DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_autocast', device_type='cuda'), From 2c2cf1d7799369724771468667e354ec7567d388 Mon Sep 17 00:00:00 2001 From: Anshul Sinha <50644008+sinhaanshul@users.noreply.github.com> Date: Fri, 7 Jun 2024 14:31:14 -0700 Subject: [PATCH 552/706] [dtensor][experiment] experimenting with displaying model parameters (#127630) Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): **Summary** Example code to display model parameters and verify them against ground truth. Also expanded on moduletracker to accomplish this. **Test Plan** python3 torch/distributed/_tensor/examples/display_sharding_example.py * #127987 * __->__ #127630 * #127360 * #127358 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127630 Approved by: https://github.com/XilunWu ghstack dependencies: #127358, #127360 --- torch/distributed/_tensor/debug/comm_mode.py | 48 +++++++++ .../examples/display_sharding_example.py | 97 +++++++++++++++++++ 2 files changed, 145 insertions(+) create mode 100644 torch/distributed/_tensor/examples/display_sharding_example.py diff --git a/torch/distributed/_tensor/debug/comm_mode.py b/torch/distributed/_tensor/debug/comm_mode.py index 150ef9250c2d..81f544434131 100644 --- a/torch/distributed/_tensor/debug/comm_mode.py +++ b/torch/distributed/_tensor/debug/comm_mode.py @@ -3,9 +3,17 @@ from typing import Any, Dict import torch +from torch.autograd.graph import register_multi_grad_hook from torch.distributed._tensor.api import DTensor + +from torch.nn.modules.module import ( + register_module_forward_hook, + register_module_forward_pre_hook, +) from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils._pytree import tree_flatten +from torch.utils.module_tracker import ModuleTracker funcol_native = torch.ops._c10d_functional funcol_py = torch.ops.c10d_functional @@ -44,6 +52,40 @@ } +class ModuleParamaterShardingTracker(ModuleTracker): + """ + Inherits ModuleTracker and expands on its functionality to track the + parameters and sharding information of a model at a module-level + """ + + def __init__(self): + super().__init__() + self.module_parameters_dict = {} + + def _fw_pre_hook(self, mod, input): + name = super()._get_mod_name(mod) + super()._get_append_fn(name, False)() + + args, _ = tree_flatten(input) + tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad] + if tensors: + register_multi_grad_hook(tensors, super()._get_pop_fn(name, True)) + + for param_name, param in mod.named_parameters(recurse=False): + if name not in self.module_parameters_dict: + self.module_parameters_dict[name] = {} + + self.module_parameters_dict[name][param_name] = param.data + + def __enter__(self): + self.module_parameters_dict.clear() + self._fw_pre_handle = register_module_forward_pre_hook(self._fw_pre_hook) + self._fw_post_handle = register_module_forward_hook(super()._fw_post_hook) + + def __exit__(self, *args): + super().__exit__(*args) + + class CommDebugMode(TorchDispatchMode): """ ``CommDebugMode`` is a context manager that counts the number of @@ -72,6 +114,7 @@ def __init__(self): self.comm_registry.add(py_op) self.comm_registry.add(torch.ops._dtensor.shard_dim_alltoall) + self.advanced_module_tracker = ModuleParamaterShardingTracker() def get_total_counts(self) -> int: return sum(self.comm_counts.values()) @@ -84,12 +127,17 @@ def get_comm_counts(self) -> Dict[Any, int]: """ return self.comm_counts + def get_parameter_info(self) -> Dict[str, Dict[str, Any]]: + return self.advanced_module_tracker.module_parameters_dict + def __enter__(self): self.comm_counts.clear() super().__enter__() + self.advanced_module_tracker.__enter__() return self def __exit__(self, *args): + self.advanced_module_tracker.__exit__() super().__exit__(*args) def __torch_dispatch__(self, func, types, args=(), kwargs=None): diff --git a/torch/distributed/_tensor/examples/display_sharding_example.py b/torch/distributed/_tensor/examples/display_sharding_example.py new file mode 100644 index 000000000000..95d8d73f77f4 --- /dev/null +++ b/torch/distributed/_tensor/examples/display_sharding_example.py @@ -0,0 +1,97 @@ +from typing import Any, Dict + +import torch + +from torch.distributed._tensor.debug import CommDebugMode + +from torch.distributed._tensor.debug.comm_mode import ModuleParamaterShardingTracker + +from torch.testing._internal.distributed._tensor.common_dtensor import ( + MLPModule, + MLPStacked, +) + + +class DisplayShardingExample: + """ + Checks if the set of keys in ground truth dictionary and the set + produced in advanced_module_tracker are in the same order + """ + + def same_set_of_keys(self, dict1, dict2): + dict1_keys = [] + dict2_keys = [] + + for key in dict1: + for nested_key in dict1[key]: + dict1_keys.append((key, nested_key)) + + for key in dict2: + for nested_key in dict2[key]: + dict2_keys.append((key, nested_key)) + + if len(dict1_keys) != len(dict2_keys): + return False + + for i in range(len(dict1_keys)): + if dict1_keys[i] != dict2_keys[i]: + return False + + return True + + def ground_truth(self, model): + module_parameters_dict: Dict[str, Any] = {} + + for name, parameters in model.named_parameters(): + module_name = model.__class__.__name__ + "." + name.rsplit(".", 1)[0] + parameter_name = name.rsplit(".", 1)[1] + + if module_name not in module_parameters_dict: + module_parameters_dict[module_name] = {} + + module_parameters_dict[module_name][parameter_name] = parameters.data + + return module_parameters_dict + + def test_display_parameters_MLP(self): + """ + Example of using obtaining all module's FQN and parameters for a given model + """ + + inp_size = [8, 10] + + rng_seed = 0 + torch.manual_seed(rng_seed) + inp = torch.rand(*inp_size) + model = MLPModule(None) + + LR = 0.25 + + optim = torch.optim.SGD(model.parameters(), lr=LR) + comm_mode = CommDebugMode() + module_tracker = ModuleParamaterShardingTracker() + + with comm_mode, module_tracker: + output = model(inp) + output.sum().backward() + + print( + self.same_set_of_keys( + self.ground_truth(model), module_tracker.module_parameters_dict + ) + ) + + model2 = MLPStacked(None) + with comm_mode, module_tracker: + output = model2(inp) + + print( + self.same_set_of_keys( + self.ground_truth(model2), module_tracker.module_parameters_dict + ) + ) + + +if __name__ == "__main__": + instantiated_test = DisplayShardingExample() + instantiated_test.test_display_parameters_MLP() From f681e3689b857b8811f19d60d439bfb3fb2dd2d3 Mon Sep 17 00:00:00 2001 From: Anshul Sinha <50644008+sinhaanshul@users.noreply.github.com> Date: Fri, 7 Jun 2024 14:31:14 -0700 Subject: [PATCH 553/706] [dtensor][experiment] experimenting with displaying distributed model parameters and printing sharding info (#127987) **Summary** Example code to display distributed model parameters and verify them against ground truth. Also prints sharding information. **Test Plan** torchrun --standalone --nnodes=1 --nproc-per-node=4 torch/distributed/_tensor/examples/display_sharding_example.py Pull Request resolved: https://github.com/pytorch/pytorch/pull/127987 Approved by: https://github.com/XilunWu ghstack dependencies: #127358, #127360, #127630 --- torch/distributed/_tensor/debug/comm_mode.py | 22 +++++ .../examples/display_sharding_example.py | 92 +++++++++++++++++-- 2 files changed, 107 insertions(+), 7 deletions(-) diff --git a/torch/distributed/_tensor/debug/comm_mode.py b/torch/distributed/_tensor/debug/comm_mode.py index 81f544434131..cc28498d766c 100644 --- a/torch/distributed/_tensor/debug/comm_mode.py +++ b/torch/distributed/_tensor/debug/comm_mode.py @@ -61,6 +61,7 @@ class ModuleParamaterShardingTracker(ModuleTracker): def __init__(self): super().__init__() self.module_parameters_dict = {} + self.sharding_dict = {} def _fw_pre_hook(self, mod, input): name = super()._get_mod_name(mod) @@ -77,14 +78,26 @@ def _fw_pre_hook(self, mod, input): self.module_parameters_dict[name][param_name] = param.data + if isinstance(param.data, DTensor): + key_name = name + "." + param_name + self.sharding_dict[key_name] = param.data.placements + def __enter__(self): self.module_parameters_dict.clear() + self.sharding_dict.clear() self._fw_pre_handle = register_module_forward_pre_hook(self._fw_pre_hook) self._fw_post_handle = register_module_forward_hook(super()._fw_post_hook) def __exit__(self, *args): super().__exit__(*args) + def print_paramater_info(self): + print(self.module_parameters_dict) + + def print_sharding_info(self): + for key, value in self.sharding_dict.items(): + print(key + ": " + str(value)) + class CommDebugMode(TorchDispatchMode): """ @@ -130,6 +143,9 @@ def get_comm_counts(self) -> Dict[Any, int]: def get_parameter_info(self) -> Dict[str, Dict[str, Any]]: return self.advanced_module_tracker.module_parameters_dict + def get_sharding_info(self) -> Dict[str, Dict[str, Any]]: + return self.advanced_module_tracker.sharding_dict + def __enter__(self): self.comm_counts.clear() super().__enter__() @@ -140,6 +156,12 @@ def __exit__(self, *args): self.advanced_module_tracker.__exit__() super().__exit__(*args) + def print_paramater_info(self): + self.advanced_module_tracker.print_paramater_info() + + def print_sharding_info(self): + self.advanced_module_tracker.print_sharding_info() + def __torch_dispatch__(self, func, types, args=(), kwargs=None): # When running this mode with DTensor, ordinarily all modes will # run **before** subclasses get a chance to run. diff --git a/torch/distributed/_tensor/examples/display_sharding_example.py b/torch/distributed/_tensor/examples/display_sharding_example.py index 95d8d73f77f4..0e32ed074534 100644 --- a/torch/distributed/_tensor/examples/display_sharding_example.py +++ b/torch/distributed/_tensor/examples/display_sharding_example.py @@ -1,23 +1,50 @@ +import os from typing import Any, Dict import torch +from torch.distributed._tensor import DeviceMesh, Shard from torch.distributed._tensor.debug import CommDebugMode - from torch.distributed._tensor.debug.comm_mode import ModuleParamaterShardingTracker +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + parallelize_module, + RowwiseParallel, +) + from torch.testing._internal.distributed._tensor.common_dtensor import ( MLPModule, MLPStacked, + NUM_DEVICES, ) +def get_device_type(): + return ( + "cuda" + if torch.cuda.is_available() and torch.cuda.device_count() >= 4 + else "cpu" + ) + + +c10d_functional = torch.ops.c10d_functional + +aten = torch.ops.aten +supported_ops = [aten.view.default, aten._to_copy.default] + + class DisplayShardingExample: """ Checks if the set of keys in ground truth dictionary and the set produced in advanced_module_tracker are in the same order """ + def __init__(self, world_size, rank): + self.world_size = world_size + self.rank = rank + self.device_type = get_device_type() + def same_set_of_keys(self, dict1, dict2): dict1_keys = [] dict2_keys = [] @@ -54,9 +81,7 @@ def ground_truth(self, model): return module_parameters_dict def test_display_parameters_MLP(self): - """ - Example of using obtaining all module's FQN and parameters for a given model - """ + """Example of obtaining all module's FQN and parameters for a given model""" inp_size = [8, 10] @@ -67,7 +92,6 @@ def test_display_parameters_MLP(self): LR = 0.25 - optim = torch.optim.SGD(model.parameters(), lr=LR) comm_mode = CommDebugMode() module_tracker = ModuleParamaterShardingTracker() @@ -91,7 +115,61 @@ def test_display_parameters_MLP(self): ) ) + def test_display_parameters_MLP_distributed( + self, is_seq_parallel=False, recompute_activation=False + ): + "Example of obtaining all module's FQN and parameters for a given distributed model and printing the sharding info" + device_mesh = DeviceMesh( + self.device_type, + torch.arange(0, NUM_DEVICES), + ) + inp_size = [8, 10] + rng_seed = self.rank if is_seq_parallel else 0 + torch.manual_seed(rng_seed) + inp = torch.rand(*inp_size, device=self.device_type) + model = MLPModule(self.device_type) + + LR = 0.25 + + parallelize_plan = { + "net1": ColwiseParallel(input_layouts=Shard(0)) + if is_seq_parallel + else ColwiseParallel(), + "net2": RowwiseParallel(output_layouts=Shard(0)) + if is_seq_parallel + else RowwiseParallel(), + } + + model = parallelize_module(model, device_mesh, parallelize_plan) + + comm_mode = CommDebugMode() + + with comm_mode: + output_tp = model(inp) + output_tp.sum().backward() + + print( + self.same_set_of_keys( + self.ground_truth(model), comm_mode.get_parameter_info() + ) + ) + + comm_mode.print_sharding_info() + + +def run_example(world_size, rank): + # set manual seed + torch.manual_seed(0) + + # run the example + instantiated_test = DisplayShardingExample(world_size, rank) + instantiated_test.test_display_parameters_MLP_distributed() + if __name__ == "__main__": - instantiated_test = DisplayShardingExample() - instantiated_test.test_display_parameters_MLP() + # this script is launched via torchrun which automatically manages ProcessGroup + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + assert world_size == 4 # our example uses 4 worker ranks + + run_example(world_size, rank) From 7bfd1db53a83f1aa77eac1f4ebba5245765f6163 Mon Sep 17 00:00:00 2001 From: cyy Date: Sun, 9 Jun 2024 03:08:53 +0000 Subject: [PATCH 554/706] [4/N] Change static functions in headers to inline (#128286) Follows #128194. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128286 Approved by: https://github.com/Skylion007, https://github.com/XuehaiPan --- aten/src/ATen/native/Activation.h | 4 +- aten/src/ATen/native/AdaptivePooling.h | 6 +- aten/src/ATen/native/Distributions.h | 10 +- aten/src/ATen/native/FractionalMaxPooling.h | 4 +- aten/src/ATen/native/LinearAlgebraUtils.h | 72 +++++----- aten/src/ATen/native/Math.h | 144 ++++++++++---------- aten/src/ATen/native/Padding.h | 2 +- aten/src/ATen/native/Pool.h | 18 +-- aten/src/ATen/native/Pow.h | 6 +- aten/src/ATen/native/ReduceOpsUtils.h | 42 +++--- aten/src/ATen/native/ReductionType.h | 4 +- aten/src/ATen/native/Resize.h | 6 +- aten/src/ATen/native/UpSample.h | 28 ++-- aten/src/ATen/native/im2col_shape_check.h | 4 +- aten/src/ATen/native/mkldnn/Utils.h | 4 +- 15 files changed, 177 insertions(+), 177 deletions(-) diff --git a/aten/src/ATen/native/Activation.h b/aten/src/ATen/native/Activation.h index dca6a39a0970..dc84547b7fe1 100644 --- a/aten/src/ATen/native/Activation.h +++ b/aten/src/ATen/native/Activation.h @@ -23,7 +23,7 @@ enum class GeluType { END }; -static GeluType get_gelutype_enum(const c10::string_view approximate) { +inline GeluType get_gelutype_enum(const c10::string_view approximate) { if (approximate == "none") { return GeluType::None; } else if (approximate == "tanh") { @@ -33,7 +33,7 @@ static GeluType get_gelutype_enum(const c10::string_view approximate) { } } -static std::string gelutype_to_string(const GeluType type) { +inline std::string gelutype_to_string(const GeluType type) { switch(type) { case GeluType::None: return "none"; case GeluType::Tanh: return "tanh"; diff --git a/aten/src/ATen/native/AdaptivePooling.h b/aten/src/ATen/native/AdaptivePooling.h index bb2fda9906ab..6c49fd38d940 100644 --- a/aten/src/ATen/native/AdaptivePooling.h +++ b/aten/src/ATen/native/AdaptivePooling.h @@ -28,15 +28,15 @@ using adaptive_max_pooling3d_backward_fn = void(*)(const Tensor& grad_input, con DECLARE_DISPATCH(adaptive_max_pooling3d_fn, adaptive_max_pool3d_kernel); DECLARE_DISPATCH(adaptive_max_pooling3d_backward_fn, adaptive_max_pool3d_backward_kernel); -static inline int64_t start_index(int64_t a, int64_t b, int64_t c) { +inline int64_t start_index(int64_t a, int64_t b, int64_t c) { return (a / b) * c + ((a % b) * c) / b; } -static inline int64_t end_index(int64_t a, int64_t b, int64_t c) { +inline int64_t end_index(int64_t a, int64_t b, int64_t c) { return 1 + ((a + 1) * c - 1) / b; } -static inline void adaptive_pool_empty_output_check(const Tensor& gradOutput_, const char* arg_name) { +inline void adaptive_pool_empty_output_check(const Tensor& gradOutput_, const char* arg_name) { int64_t ndim = gradOutput_.ndimension(); for (const auto i : c10::irange(1, ndim)) { TORCH_CHECK(gradOutput_.size(i) > 0, diff --git a/aten/src/ATen/native/Distributions.h b/aten/src/ATen/native/Distributions.h index 2c334157eba9..664e2db3b2dc 100644 --- a/aten/src/ATen/native/Distributions.h +++ b/aten/src/ATen/native/Distributions.h @@ -254,7 +254,7 @@ C10_DEVICE scalar_t sample_binomial(scalar_t count, scalar_t prob, BaseSampler -C10_DEVICE static inline scalar_t digamma_one(scalar_t x) { +C10_DEVICE inline scalar_t digamma_one(scalar_t x) { constexpr accscalar_t PSI_10 = 2.25175258906672110764; if (x == 0) { return INFINITY; @@ -376,7 +376,7 @@ C10_HOST_DEVICE scalar_t standard_gamma_grad_one(scalar_t alpha_, scalar_t x_) { // Approximate reparameterized gradient of Beta(x,alpha,beta) wrt alpha. // Assumes x is close to zero and uses a Taylor expansion. template -C10_DEVICE static inline scalar_t _beta_grad_alpha_small(scalar_t x, scalar_t alpha, scalar_t beta) { +C10_DEVICE inline scalar_t _beta_grad_alpha_small(scalar_t x, scalar_t alpha, scalar_t beta) { const scalar_t factor = digamma_one(alpha) - digamma_one(alpha + beta) - compat_log(x); scalar_t numer = 1; @@ -394,7 +394,7 @@ C10_DEVICE static inline scalar_t _beta_grad_alpha_small(scalar_t x, scalar_t al // Approximate reparameterized gradient of Beta(x,alpha,beta) wrt beta. // Assumes x is close to zero and uses a Taylor expansion. template -C10_DEVICE static inline scalar_t _beta_grad_beta_small(scalar_t x, scalar_t alpha, scalar_t beta) { +C10_DEVICE inline scalar_t _beta_grad_beta_small(scalar_t x, scalar_t alpha, scalar_t beta) { const scalar_t factor = digamma_one(alpha + beta) - digamma_one(beta); scalar_t numer = 1, betas = 1, dbetas = 0, series = factor / alpha; for (int i = 1; i <= 8; ++i) { @@ -412,7 +412,7 @@ C10_DEVICE static inline scalar_t _beta_grad_beta_small(scalar_t x, scalar_t alp // Assumes alpha and beta are both large and uses a Rice saddle point expansion. // To ensure numerical stability, this computation is performed at higher precision. template -C10_DEVICE static inline scalar_t _beta_grad_alpha_mid(accscalar_t x, accscalar_t alpha, accscalar_t beta) { +C10_DEVICE inline scalar_t _beta_grad_alpha_mid(accscalar_t x, accscalar_t alpha, accscalar_t beta) { const accscalar_t total = alpha + beta; const accscalar_t mean = alpha / total; const accscalar_t std = compat_sqrt(alpha * beta / (total + 1)) / total; @@ -452,7 +452,7 @@ C10_DEVICE static inline scalar_t _beta_grad_alpha_mid(accscalar_t x, accscalar_ // This function inputs total=alpha+beta to make it easy to implement // Dirichlet reparameterized gradients in terms of Betas. template -C10_HOST_DEVICE static inline scalar_t dirichlet_grad_one(scalar_t x, scalar_t alpha, scalar_t total) { +C10_HOST_DEVICE inline scalar_t dirichlet_grad_one(scalar_t x, scalar_t alpha, scalar_t total) { accscalar_t x_ = static_cast(x); accscalar_t alpha_ = static_cast(alpha); accscalar_t total_ = static_cast(total); diff --git a/aten/src/ATen/native/FractionalMaxPooling.h b/aten/src/ATen/native/FractionalMaxPooling.h index cb5438a03e70..95c05618caef 100644 --- a/aten/src/ATen/native/FractionalMaxPooling.h +++ b/aten/src/ATen/native/FractionalMaxPooling.h @@ -6,7 +6,7 @@ namespace at::native { template -static inline std::vector generate_intervals( +inline std::vector generate_intervals( scalar_t sample, int64_t inputSize, int64_t outputSize, @@ -28,7 +28,7 @@ static inline std::vector generate_intervals( } template -static inline void fractional_max_pool_check_shape( +inline void fractional_max_pool_check_shape( const Tensor& input, const Tensor& randomSamples) { diff --git a/aten/src/ATen/native/LinearAlgebraUtils.h b/aten/src/ATen/native/LinearAlgebraUtils.h index 0b05d5162e66..52f5e1cb6555 100644 --- a/aten/src/ATen/native/LinearAlgebraUtils.h +++ b/aten/src/ATen/native/LinearAlgebraUtils.h @@ -27,7 +27,7 @@ namespace at::native { -static inline c10::MaybeOwned expect_resolved_conj(const Tensor& tensor) { +inline c10::MaybeOwned expect_resolved_conj(const Tensor& tensor) { if (tensor.is_conj()) { return c10::MaybeOwned::owned(tensor.resolve_conj()); } else { @@ -35,7 +35,7 @@ static inline c10::MaybeOwned expect_resolved_conj(const Tensor& tensor) } } -static inline DimVector batched_matrix_contiguous_strides( +inline DimVector batched_matrix_contiguous_strides( const IntArrayRef sizes, const bool f_contig = false) { // f_contig chooses between the strides of a batch of Fortran (F-contiguous) @@ -62,7 +62,7 @@ static inline DimVector batched_matrix_contiguous_strides( * P.data_ptr()[B * M * N] is of the same corresponding batch as the M' by N' * matrix starting at Q.data_ptr()[B * M' * N']. */ -static inline Tensor cloneBatchedColumnMajor(const Tensor& src) { +inline Tensor cloneBatchedColumnMajor(const Tensor& src) { // If src is already in batched column major format, then // this will be efficient (no reordering of the data will occur) // because the first transpose will make the tensor contiguous, @@ -75,7 +75,7 @@ static inline Tensor cloneBatchedColumnMajor(const Tensor& src) { /* * contig chooses between C-contig (true) and F-contig (false) */ -static inline c10::MaybeOwned borrow_else_clone(const bool cond, const Tensor& borrow, const Tensor& clone, const bool contig) { +inline c10::MaybeOwned borrow_else_clone(const bool cond, const Tensor& borrow, const Tensor& clone, const bool contig) { return cond ? c10::MaybeOwned::borrowed(borrow) : c10::MaybeOwned::owned(contig ? clone.clone(MemoryFormat::Contiguous) : cloneBatchedColumnMajor(clone)); @@ -92,7 +92,7 @@ static inline c10::MaybeOwned borrow_else_clone(const bool cond, const T * which is either the original batch size of the input, or its larger * broadcasted shape. */ -static inline Tensor copyBatchedColumnMajor(const Tensor& src, int64_t nrows = -1, +inline Tensor copyBatchedColumnMajor(const Tensor& src, int64_t nrows = -1, at::OptionalIntArrayRef desired_batch_sizes = c10::nullopt) { nrows = (nrows == -1) ? src.size(-2) : nrows; auto copy_sizes = desired_batch_sizes.has_value() @@ -109,7 +109,7 @@ static inline Tensor copyBatchedColumnMajor(const Tensor& src, int64_t nrows = - * Given batches of matrices with arbitrary batch dim, * computes the number of batches. */ -static inline int64_t batchCount(const Tensor& batched_matrices) { +inline int64_t batchCount(const Tensor& batched_matrices) { int64_t result = 1; for (int64_t i = 0; i < batched_matrices.ndimension() - 2; i++) { result *= batched_matrices.size(i); @@ -118,15 +118,15 @@ static inline int64_t batchCount(const Tensor& batched_matrices) { } // Computes the number of elements of a matrix in a batched matrix tensor -static inline int64_t matrixStride(const Tensor& batched_matrices) { +inline int64_t matrixStride(const Tensor& batched_matrices) { return batched_matrices.size(-1) * batched_matrices.size(-2); } // Validates input shapes for operations on batches of square matrices (inverse, cholesky, symeig, eig) -static inline void checkIsMatrix(const Tensor& A, const char* const f_name, const char* const arg_name = "A") { +inline void checkIsMatrix(const Tensor& A, const char* const f_name, const char* const arg_name = "A") { TORCH_CHECK(A.dim() >= 2, f_name, ": The input tensor ", arg_name, " must have at least 2 dimensions."); } -static inline void squareCheckInputs(const Tensor& self, const char* const f_name, const char* const arg_name = "A") { +inline void squareCheckInputs(const Tensor& self, const char* const f_name, const char* const arg_name = "A") { checkIsMatrix(self, f_name, arg_name); TORCH_CHECK(self.sym_size(-1) == self.sym_size(-2), f_name, @@ -134,7 +134,7 @@ static inline void squareCheckInputs(const Tensor& self, const char* const f_nam "but they are ", self.sym_size(-2), " by ", self.sym_size(-1), " matrices"); } -static inline void checkInputsSolver(const Tensor& A, +inline void checkInputsSolver(const Tensor& A, const Tensor& B, const bool left, const char* const f_name) { @@ -146,14 +146,14 @@ static inline void checkInputsSolver(const Tensor& A, " (", A.size(-2), "x", A.size(-1), " and ", B.size(-2), "x", B.size(-1), ")"); } -static inline bool is_row_or_column_contiguous(const Tensor& t) { +inline bool is_row_or_column_contiguous(const Tensor& t) { // This could be made more general, similar to how it's checked in matmul, which would allow to // ellide the copy with strides such as (6, 12, 1, 3) or (3, 1, 9), but this is quite tricky. // We choose to be conservative for simplicity return t.is_contiguous() || t.transpose(-2, -1).is_contiguous(); } -static inline TransposeType to_transpose_type(const bool contig, const bool conj) { +inline TransposeType to_transpose_type(const bool contig, const bool conj) { if (conj) { if (contig) { TORCH_INTERNAL_ASSERT(false, "Invalid transpose type"); } else { return TransposeType::ConjTranspose; } @@ -261,7 +261,7 @@ void batch_iterator_with_broadcasting(const Tensor& a, const Tensor& b, const fu } // Returns the epsilon value for floating types except half -static inline double _get_epsilon(const ScalarType& sc_type) { +inline double _get_epsilon(const ScalarType& sc_type) { switch (sc_type) { case at::ScalarType::Float: return static_cast(std::numeric_limits::epsilon()); @@ -274,7 +274,7 @@ static inline double _get_epsilon(const ScalarType& sc_type) { // Validates input shapes and devices // for linear solve methods (solve, cholesky_solve, lu_solve, triangular_solve) -static inline void linearSolveCheckInputs(const Tensor& self, const Tensor& A, const char* name) { +inline void linearSolveCheckInputs(const Tensor& self, const Tensor& A, const char* name) { TORCH_CHECK(self.device() == A.device(), "Expected b and A to be on the same device, but found b on ", self.device(), " and A on ", A.device(), " instead."); @@ -293,7 +293,7 @@ static inline void linearSolveCheckInputs(const Tensor& self, const Tensor& A, c " but each b matrix is ", self.size(-2), " by ", self.size(-1)); } -static inline void checkFloatingOrComplex(const Tensor& t, const char* const f_name, const bool allow_low_precision_dtypes=true) { +inline void checkFloatingOrComplex(const Tensor& t, const char* const f_name, const bool allow_low_precision_dtypes=true) { auto dtype = t.scalar_type(); TORCH_CHECK((at::isFloatingType(dtype) || at::isComplexType(dtype)), f_name, ": Expected a floating point or complex tensor as input. Got ", dtype); @@ -305,13 +305,13 @@ static inline void checkFloatingOrComplex(const Tensor& t, const char* const f_n // Checks if all the Tensors in a TensorList are of the same dimensions -static inline void checkAllSameDim(TensorList tensors, int64_t dim) { +inline void checkAllSameDim(TensorList tensors, int64_t dim) { for (auto &t : tensors) { TORCH_CHECK(t.dim() == dim, "Tensor dimension is ", t.dim(), ", expected ", dim, " instead."); } } -static inline std::tuple, std::vector> _linalg_broadcast_batch_dims(const Tensor& arg1, const Tensor& arg2) { +inline std::tuple, std::vector> _linalg_broadcast_batch_dims(const Tensor& arg1, const Tensor& arg2) { // broadcast the batch dimensions of arg1 and arg2. IntArrayRef arg1_batch_sizes(arg1.sizes().data(), arg1.ndimension() - 2); IntArrayRef arg2_batch_sizes(arg2.sizes().data(), arg2.ndimension() - 2); @@ -325,7 +325,7 @@ static inline std::tuple, std::vector> _linalg_bro return std::make_tuple(std::move(arg1_expand_size), std::move(arg2_expand_size)); } -static inline std::tuple _linalg_broadcast_batch_dims(const Tensor& arg1, const Tensor& arg2, const char* name) { +inline std::tuple _linalg_broadcast_batch_dims(const Tensor& arg1, const Tensor& arg2, const char* name) { // If there's no name we assume we don't want to check the errors if (name != nullptr) { linearSolveCheckInputs(arg1, arg2, name); @@ -338,7 +338,7 @@ static inline std::tuple _linalg_broadcast_batch_dims(const Tenso return std::make_tuple(arg1_broadcasted, arg2_broadcasted); } -static inline std::vector broadcast_batch_size(const Tensor& t1, const Tensor& t2, int64_t n_batch_dims) { +inline std::vector broadcast_batch_size(const Tensor& t1, const Tensor& t2, int64_t n_batch_dims) { IntArrayRef t1_batch_sizes(t1.sizes().data(), n_batch_dims); IntArrayRef t2_batch_sizes(t2.sizes().data(), n_batch_dims); auto broadcasted_batch_sizes = infer_size(t1_batch_sizes, t2_batch_sizes); @@ -346,7 +346,7 @@ static inline std::vector broadcast_batch_size(const Tensor& t1, const } // Return a permutation with the given axes moved to the end. -static inline Tensor _move_to_end(const Tensor& self, IntArrayRef axes) { +inline Tensor _move_to_end(const Tensor& self, IntArrayRef axes) { const std::vector a = axes.vec(); const int64_t ndim = self.ndimension(); std::vector perm; @@ -368,7 +368,7 @@ static inline Tensor _move_to_end(const Tensor& self, IntArrayRef axes) { } // parse the "mode" param in linalg_qr: return a tuple of bools (compute_q, reduced) -static inline std::tuple _parse_qr_mode(c10::string_view mode) { +inline std::tuple _parse_qr_mode(c10::string_view mode) { bool compute_q; bool reduced; if (mode == "reduced") { @@ -388,7 +388,7 @@ static inline std::tuple _parse_qr_mode(c10::string_view mode) { } // Function to compute sizes, strides and the extra columns for the Q matrix in the QR Decomposition -static inline std::tuple _compute_geometry_for_Q( +inline std::tuple _compute_geometry_for_Q( const Tensor& input, bool reduced) { int64_t m = input.size(-2), n = input.size(-1); @@ -407,7 +407,7 @@ static inline std::tuple _compute_geometry_for_Q( return std::make_tuple(q_sizes, q_strides, n_columns_q); } -static inline bool svd_uses_cusolver(const Tensor& A) { +inline bool svd_uses_cusolver(const Tensor& A) { // if cusolver is available, it is used unconditionally return A.is_cuda() && at::globalContext().hasCuSOLVER() @@ -417,7 +417,7 @@ static inline bool svd_uses_cusolver(const Tensor& A) { // Function used instead of .to so that the original strides are retained // .to doesn't retain strides and make the output tensor contiguous -static inline Tensor same_stride_to(const Tensor& original_tensor, const at::TensorOptions& options) { +inline Tensor same_stride_to(const Tensor& original_tensor, const at::TensorOptions& options) { auto strided_to = at::empty_strided(original_tensor.sizes(), original_tensor.strides(), options); @@ -433,7 +433,7 @@ static inline Tensor same_stride_to(const Tensor& original_tensor, const at::Ten // For instance, given a 4-D tensor, dimensions 1 and 3 can be shifted to the end by // calling `create_dim_backshift_permutation(1, 3, 4)`. The resulting vector will // be `vec(0, 2, 1, 3)`. -static inline std::vector create_dim_backshift_permutation(int64_t dim0, int64_t dim1, int64_t ndim) { +inline std::vector create_dim_backshift_permutation(int64_t dim0, int64_t dim1, int64_t ndim) { TORCH_CHECK( (dim0 != dim1) && (dim0 < ndim) && (dim0 >= 0) && (dim1 < ndim) && (dim1 >= 0), "duplicate or invalid dimensions"); @@ -453,7 +453,7 @@ static inline std::vector create_dim_backshift_permutation(int64_t dim0 // will reverse a given permutation. // The reverse permutation array is created by swapping the indices and their // associated values from the given permutation array. -static inline std::vector create_reverse_permutation(std::vector permutation) { +inline std::vector create_reverse_permutation(std::vector permutation) { int64_t ndim = permutation.size(); std::vector reverse_permutation(ndim); for (const auto dim_ind : c10::irange(ndim)) { @@ -464,7 +464,7 @@ static inline std::vector create_reverse_permutation(std::vector(std::toupper(static_cast(uplo[0]))); TORCH_CHECK(uplo.size() == 1 && (uplo_uppercase == 'U' || uplo_uppercase == 'L'), "Expected UPLO argument to be 'L' or 'U', but got ", uplo); } -static inline void checkSameDevice(const std::string& fn_name, Tensor result, Tensor input, const std::string& result_name = "result") { +inline void checkSameDevice(const std::string& fn_name, Tensor result, Tensor input, const std::string& result_name = "result") { TORCH_CHECK( result.device() == input.device(), fn_name, @@ -504,7 +504,7 @@ static inline void checkSameDevice(const std::string& fn_name, Tensor result, Te // (either floating or complex type input), so we can check whether input's dtype can be casted to result's dtype. // According to https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-does-out-work-in-pytorch // c10::canCast is used for checking the "safe copy" dtype requirements. -static inline void checkLinalgCompatibleDtype(const std::string& fn_name, Tensor result, Tensor input, const std::string& result_name = "result") { +inline void checkLinalgCompatibleDtype(const std::string& fn_name, Tensor result, Tensor input, const std::string& result_name = "result") { bool can_cast = c10::canCast(input.scalar_type(), result.scalar_type()); TORCH_CHECK( can_cast, @@ -514,7 +514,7 @@ static inline void checkLinalgCompatibleDtype(const std::string& fn_name, Tensor } // Alternatively, we can check whether the specific expected output type (result_type) can be safely casted to out tensor dtype (out_type) -static inline void checkLinalgCompatibleDtype(const std::string& fn_name, ScalarType out_type, ScalarType result_type, const std::string& out_name = "result") { +inline void checkLinalgCompatibleDtype(const std::string& fn_name, ScalarType out_type, ScalarType result_type, const std::string& out_name = "result") { bool can_cast = c10::canCast(result_type, out_type); TORCH_CHECK( can_cast, @@ -523,7 +523,7 @@ static inline void checkLinalgCompatibleDtype(const std::string& fn_name, Scalar out_name, " with dtype ", out_type); } -static inline void checkNotComplexTolerance(const Tensor& tol, const c10::string_view f_name, const c10::string_view tol_name) { +inline void checkNotComplexTolerance(const Tensor& tol, const c10::string_view f_name, const c10::string_view tol_name) { TORCH_CHECK(!at::isComplexType(tol.scalar_type()), f_name, ": ", tol_name, " tensor of complex type is not supported. Got ", tol.scalar_type()); } @@ -538,7 +538,7 @@ static inline void checkNotComplexTolerance(const Tensor& tol, const c10::string Let input.shape = (batch_dimensions, m, n), then 'other' is of vector type if other.shape == (batch_dimensions, m). This rule is compatible with NumPy, see https://github.com/numpy/numpy/blob/v1.20.0/numpy/linalg/linalg.py#L384-L389 */ -static inline bool linalg_solve_is_vector_rhs(const Tensor& input, const Tensor& other) { +inline bool linalg_solve_is_vector_rhs(const Tensor& input, const Tensor& other) { auto expected_batched_rhs_shape = SymIntArrayRef(input.sym_sizes().data(), input.dim() - 1); // input.shape[:-1] bool vector_case = other.dim() == 1 || (input.dim() - 1 == other.dim() && other.sym_sizes().equals(expected_batched_rhs_shape)); return vector_case; @@ -547,7 +547,7 @@ static inline bool linalg_solve_is_vector_rhs(const Tensor& input, const Tensor& /* Computes linear indices for a tensor with original_shape to access its elements like it was a materialized broadcast tensor. */ -static inline Tensor get_linear_indices(int64_t numel, IntArrayRef original_shape, IntArrayRef broadcast_shape) { +inline Tensor get_linear_indices(int64_t numel, IntArrayRef original_shape, IntArrayRef broadcast_shape) { TensorOptions options = at::TensorOptions().dtype(at::kLong).device(at::kCPU); return at::arange(numel, options).view(original_shape).broadcast_to(broadcast_shape).contiguous(); } @@ -578,7 +578,7 @@ class BroadcastLinearIndices { } }; -static inline bool is_blas_compatible_column_major_order(const Tensor& input) { +inline bool is_blas_compatible_column_major_order(const Tensor& input) { IntArrayRef input_strides = input.strides(); IntArrayRef input_sizes = input.sizes(); auto ndim = input.dim(); @@ -599,7 +599,7 @@ static inline bool is_blas_compatible_column_major_order(const Tensor& input) { batch_stride_compatible; } -static inline bool is_blas_compatible_row_major_order(const Tensor& input) { +inline bool is_blas_compatible_row_major_order(const Tensor& input) { IntArrayRef input_strides = input.strides(); IntArrayRef input_sizes = input.sizes(); auto ndim = input.dim(); diff --git a/aten/src/ATen/native/Math.h b/aten/src/ATen/native/Math.h index 8296d6cf60a2..e86a9aea411a 100644 --- a/aten/src/ATen/native/Math.h +++ b/aten/src/ATen/native/Math.h @@ -147,7 +147,7 @@ jiterator_also_stringify_as(jiterator_code( #define CENTRAL_RANGE 0.7 template -static inline typename std::enable_if::value, T>::type +inline typename std::enable_if::value, T>::type calc_erfinv(T y) { /* Function to calculate inverse error function. Rational approximation is used to generate an initial approximation, which is then improved to @@ -232,7 +232,7 @@ Date: February 1996 * See note [3-Clause BSD License for the Cephes Math Library]. */ template -C10_HOST_DEVICE static inline scalar_t zeta(scalar_t x, scalar_t q) __ubsan_ignore_float_divide_by_zero__ { +C10_HOST_DEVICE inline scalar_t zeta(scalar_t x, scalar_t q) __ubsan_ignore_float_divide_by_zero__ { using acc_t = at::acc_type; const acc_t MACHEP = acc_t{1.11022302462515654042E-16}; constexpr acc_t zero = acc_t{0.0}; @@ -324,7 +324,7 @@ C10_HOST_DEVICE static inline scalar_t zeta(scalar_t x, scalar_t q) __ubsan_igno * N 0 */ template -C10_HOST_DEVICE static inline T polevl(const T x, const T A[], size_t len) { +C10_HOST_DEVICE inline T polevl(const T x, const T A[], size_t len) { T result = 0; for (size_t i = 0; i <= len; i++) { result = result * x + A[i]; @@ -332,7 +332,7 @@ C10_HOST_DEVICE static inline T polevl(const T x, const T A[], size_t len) { return result; } -static inline double trigamma(double x) __ubsan_ignore_float_divide_by_zero__ { +inline double trigamma(double x) __ubsan_ignore_float_divide_by_zero__ { double sign = +1; double result = 0; if (x < 0.5) { @@ -350,7 +350,7 @@ static inline double trigamma(double x) __ubsan_ignore_float_divide_by_zero__ { return sign * result; } -static inline float trigamma(float x) __ubsan_ignore_float_divide_by_zero__ { +inline float trigamma(float x) __ubsan_ignore_float_divide_by_zero__ { float sign = +1; float result = 0; if (x < 0.5f) { @@ -372,7 +372,7 @@ static inline float trigamma(float x) __ubsan_ignore_float_divide_by_zero__ { * This function is derived from the implementation of the digamma function in the Cephes Math Library. * See note [3-Clause BSD License for the Cephes Math Library]. */ -static inline double calc_digamma(double x) { +inline double calc_digamma(double x) { // [C++ Standard Reference: Gamma Function] https://en.cppreference.com/w/cpp/numeric/math/tgamma static double PSI_10 = 2.25175258906672110764; if (x == 0) { @@ -430,7 +430,7 @@ static inline double calc_digamma(double x) { * This function is derived from the implementation of the digamma function in the Cephes Math Library. * See note [3-Clause BSD License for the Cephes Math Library]. */ -static inline float calc_digamma(float x) { +inline float calc_digamma(float x) { // See [C++ Standard Reference: Gamma Function] static float PSI_10 = 2.25175258906672110764f; if (x == 0) { @@ -485,16 +485,16 @@ static inline float calc_digamma(float x) { return result + logf(x) - (0.5f / x) - y; } -static inline c10::BFloat16 calc_digamma(c10::BFloat16 a) { +inline c10::BFloat16 calc_digamma(c10::BFloat16 a) { return calc_digamma(static_cast(a)); } -static inline c10::Half calc_digamma(c10::Half a) { +inline c10::Half calc_digamma(c10::Half a) { return calc_digamma(static_cast(a)); } template -static inline C10_HOST_DEVICE scalar_t calc_polygamma(scalar_t x, int n) { +inline C10_HOST_DEVICE scalar_t calc_polygamma(scalar_t x, int n) { // already blocked if n <= 1 const auto one = scalar_t{1}; return ((n % 2) ? one : -one) * @@ -519,7 +519,7 @@ static inline C10_HOST_DEVICE scalar_t calc_polygamma(scalar_t x, int n) { * See NOTICE for the licenses. */ template -static scalar_t ratevl(scalar_t x, const scalar_t num[], int64_t M, +scalar_t ratevl(scalar_t x, const scalar_t num[], int64_t M, const scalar_t denom[], int64_t N) { // evaluating rational function, i.e., the ratio of two polynomials // the coefficients for numerator are given by `num` while coeffs for @@ -1061,7 +1061,7 @@ static scalar_t _igamc_helper_continued_fraction(scalar_t a, scalar_t x) { } template -static inline scalar_t calc_igammac(scalar_t a, scalar_t x) { +inline scalar_t calc_igammac(scalar_t a, scalar_t x) { /* the calculation of the regularized upper incomplete gamma function * is done differently based on the values of a and x: * - if x and/or a is at the boundary of defined region, then assign the @@ -1141,7 +1141,7 @@ static inline scalar_t calc_igammac(scalar_t a, scalar_t x) { } template -static inline scalar_t calc_igamma(scalar_t a, scalar_t x) { +scalar_t calc_igamma(scalar_t a, scalar_t x) { /* the calculation of the regularized lower incomplete gamma function * is done differently based on the values of a and x: * - if x and/or a is at the boundary of defined region, then assign the @@ -1203,39 +1203,39 @@ static inline scalar_t calc_igamma(scalar_t a, scalar_t x) { } template <> -C10_UNUSED c10::BFloat16 calc_igamma(c10::BFloat16 a, c10::BFloat16 x) { +C10_UNUSED inline c10::BFloat16 calc_igamma(c10::BFloat16 a, c10::BFloat16 x) { return calc_igamma(float(a), float(x)); } template <> -C10_UNUSED c10::Half calc_igamma(c10::Half a, c10::Half x) { +C10_UNUSED inline c10::Half calc_igamma(c10::Half a, c10::Half x) { return calc_igamma(float(a), float(x)); } template <> -C10_UNUSED c10::BFloat16 calc_igammac(c10::BFloat16 a, c10::BFloat16 x) { +C10_UNUSED inline c10::BFloat16 calc_igammac(c10::BFloat16 a, c10::BFloat16 x) { return calc_igammac(float(a), float(x)); } template <> -C10_UNUSED c10::Half calc_igammac(c10::Half a, c10::Half x) { +C10_UNUSED inline c10::Half calc_igammac(c10::Half a, c10::Half x) { return calc_igammac(float(a), float(x)); } inline c10::BFloat16 calc_erfinv(c10::BFloat16 a) { return calc_erfinv(float(a)); } template -static T abs_impl(T v) { +inline T abs_impl(T v) { return std::abs(v); } template <> -C10_UNUSED uint8_t abs_impl(uint8_t v) { +C10_UNUSED inline uint8_t abs_impl(uint8_t v) { return v; } template -static inline typename std::enable_if::value, T>::type +inline typename std::enable_if::value, T>::type calc_gcd(T a, T b) { a = abs_impl(a); b = abs_impl(b); @@ -1284,7 +1284,7 @@ C10_HOST_DEVICE c10::complex exp2_impl(c10::complex x) { * required is x -> 2(2ab/x - b - a)/(b-a). If b is infinity, this becomes x -> 4a/x - 1. */ template -static inline typename std::enable_if::value, T>::type +inline typename std::enable_if::value, T>::type chbevl(const T x, const T array[], size_t len) { T b0, b1, b2; @@ -1310,7 +1310,7 @@ chbevl(const T x, const T array[], size_t len) { * of all inputs to convert them into the domain of the approximation. */ template -static inline std::tuple chebyshev_coefficients_i0e_A() { +inline std::tuple chebyshev_coefficients_i0e_A() { /* Chebyshev coefficients for exp(-x) I0(x) * in the interval [0,8]. * @@ -1336,7 +1336,7 @@ static inline std::tuple chebyshev_coefficients_i0e_A() { }; template -static inline std::tuple chebyshev_coefficients_i0e_B() { +inline std::tuple chebyshev_coefficients_i0e_B() { /* Chebyshev coefficients for exp(-x) sqrt(x) I0(x) * in the inverted interval [8,infinity]. * @@ -1361,7 +1361,7 @@ static inline std::tuple chebyshev_coefficients_i0e_B() { }; template -static inline typename std::enable_if::value, std::tuple>::type +inline typename std::enable_if::value, std::tuple>::type chebyshev_coefficients_i1e_A() { /* Chebyshev coefficients for exp(-x) I1(x) * in the interval [0,8]. @@ -1388,7 +1388,7 @@ chebyshev_coefficients_i1e_A() { }; template -static inline typename std::enable_if::value, std::tuple>::type +inline typename std::enable_if::value, std::tuple>::type chebyshev_coefficients_i1e_A() { /* Chebyshev coefficients for exp(-x) I1(x) * in the interval [0,8]. @@ -1417,7 +1417,7 @@ chebyshev_coefficients_i1e_A() { }; template -static inline typename std::enable_if::value, std::tuple>::type +inline typename std::enable_if::value, std::tuple>::type chebyshev_coefficients_i1e_B() { /* Chebyshev coefficients for exp(-x) sqrt(x) I1(x) * in the inverted interval [8,infinity]. @@ -1443,7 +1443,7 @@ chebyshev_coefficients_i1e_B() { }; template -static inline typename std::enable_if::value, std::tuple>::type +inline typename std::enable_if::value, std::tuple>::type chebyshev_coefficients_i1e_B() { /* Chebyshev coefficients for exp(-x) sqrt(x) I1(x) * in the inverted interval [8,infinity]. @@ -1463,7 +1463,7 @@ chebyshev_coefficients_i1e_B() { }; template -static inline typename std::enable_if::value, T>::type +inline typename std::enable_if::value, T>::type calc_i0(T _x) { T x = std::abs(_x); @@ -1481,7 +1481,7 @@ calc_i0(T _x) { } // Upcast bfloat16 input to float for numerical accuracy purposes -static inline c10::BFloat16 calc_i0(c10::BFloat16 a) { return calc_i0(static_cast(a)); } +inline c10::BFloat16 calc_i0(c10::BFloat16 a) { return calc_i0(static_cast(a)); } /* * This function is derived from the implementation of the i1 function in the Cephes Math Library. @@ -1493,7 +1493,7 @@ static inline c10::BFloat16 calc_i0(c10::BFloat16 a) { return calc_i0(static_cas * of all inputs to convert them into the domain of the approximation. */ template -static inline typename std::enable_if::value, T>::type +inline typename std::enable_if::value, T>::type calc_i1(T _x) { T x = std::abs(_x); @@ -1522,7 +1522,7 @@ calc_i1(T _x) { * of all inputs to convert them into the domain of the approximation. */ template -static inline typename std::enable_if::value, T>::type +inline typename std::enable_if::value, T>::type calc_i1e(T _x) { T x = std::abs(_x); @@ -1549,7 +1549,7 @@ calc_i1e(T _x) { * (integrated from minus infinity to x) is equal to y. */ template -static inline C10_HOST_DEVICE T calc_ndtri(T y0) { +inline C10_HOST_DEVICE T calc_ndtri(T y0) { /* sqrt(2pi) */ constexpr T s2pi = 2.50662827463100050242E0; @@ -1737,7 +1737,7 @@ static inline C10_HOST_DEVICE T calc_ndtri(T y0) { template -C10_HOST_DEVICE static inline typename std::enable_if::value, T>::type +C10_HOST_DEVICE inline typename std::enable_if::value, T>::type erfcx_y100(T y100) { switch (static_cast(y100)) { @@ -2148,7 +2148,7 @@ return 0.97771701335885035464e0 + (0.22000938572830479551e-1 + (0.27951610702682 } template -C10_HOST_DEVICE static inline typename std::enable_if::value, T>::type +C10_HOST_DEVICE inline typename std::enable_if::value, T>::type calc_erfcx(T x) { if (at::_isnan(x)) { @@ -2188,7 +2188,7 @@ calc_erfcx(T x) * See NOTICE for the licenses. */ template -static inline C10_HOST_DEVICE T calc_log_ndtr(T x) { +inline C10_HOST_DEVICE T calc_log_ndtr(T x) { T t = x * c10::frac_sqrt_2; if (x < T{-1.0}) { return std::log(calc_erfcx(-t) / 2) - t * t; @@ -2198,7 +2198,7 @@ static inline C10_HOST_DEVICE T calc_log_ndtr(T x) { } template -static inline C10_HOST_DEVICE T airy_ai_forward(T x) { +inline C10_HOST_DEVICE T airy_ai_forward(T x) { static const T AN[] = { +3.46538101525629032477e-01, +1.20075952739645805542e+01, @@ -2377,7 +2377,7 @@ static inline C10_HOST_DEVICE T airy_ai_forward(T x) { } // T airy_ai(T x) template -static inline C10_HOST_DEVICE T bessel_j0_forward(T x) { +inline C10_HOST_DEVICE T bessel_j0_forward(T x) { static const T PP[] = { +7.96936729297347051624e-04, +8.28352392107440799803e-02, @@ -2489,7 +2489,7 @@ static inline C10_HOST_DEVICE T bessel_j0_forward(T x) { } // bessel_j0_forward(T x) template -static inline C10_HOST_DEVICE T bessel_j1_forward(T x) { +inline C10_HOST_DEVICE T bessel_j1_forward(T x) { static const T PP[] = { +7.62125616208173112003e-04, +7.31397056940917570436e-02, @@ -2597,7 +2597,7 @@ static inline C10_HOST_DEVICE T bessel_j1_forward(T x) { } // bessel_j1_forward(T x) template -static inline C10_HOST_DEVICE T bessel_y0_forward(T x) { +inline C10_HOST_DEVICE T bessel_y0_forward(T x) { static const T PP[] = { +7.96936729297347051624e-04, +8.28352392107440799803e-02, @@ -2712,7 +2712,7 @@ static inline C10_HOST_DEVICE T bessel_y0_forward(T x) { } // bessel_y0_forward(T x) template -static inline C10_HOST_DEVICE T bessel_y1_forward(T x) { +inline C10_HOST_DEVICE T bessel_y1_forward(T x) { static const T PP[] = { +7.62125616208173112003e-04, +7.31397056940917570436e-02, @@ -2826,7 +2826,7 @@ static inline C10_HOST_DEVICE T bessel_y1_forward(T x) { } // bessel_y1_forward(T x) template -static inline C10_HOST_DEVICE T chebyshev_polynomial_t_forward(T x, int64_t n) { +inline C10_HOST_DEVICE T chebyshev_polynomial_t_forward(T x, int64_t n) { if (n < 0) { return T(0.0); } @@ -2865,12 +2865,12 @@ static inline C10_HOST_DEVICE T chebyshev_polynomial_t_forward(T x, int64_t n) { } // chebyshev_polynomial_t_forward(T x, int64_t n) template -static inline C10_HOST_DEVICE T chebyshev_polynomial_t_forward(T x, T n) { +inline C10_HOST_DEVICE T chebyshev_polynomial_t_forward(T x, T n) { return chebyshev_polynomial_t_forward(x, static_cast(n)); } // chebyshev_polynomial_t_forward(T x, T n) template -static inline C10_HOST_DEVICE T chebyshev_polynomial_u_forward(T x, int64_t n) { +inline C10_HOST_DEVICE T chebyshev_polynomial_u_forward(T x, int64_t n) { if (n < 0) { return T(0.0); } @@ -2913,12 +2913,12 @@ static inline C10_HOST_DEVICE T chebyshev_polynomial_u_forward(T x, int64_t n) { } // chebyshev_polynomial_u_forward(T x, int64_t n) template -static inline C10_HOST_DEVICE T chebyshev_polynomial_u_forward(T x, T n) { +inline C10_HOST_DEVICE T chebyshev_polynomial_u_forward(T x, T n) { return chebyshev_polynomial_u_forward(x, static_cast(n)); } // chebyshev_polynomial_u_forward(T x, T n) template -static inline C10_HOST_DEVICE T chebyshev_polynomial_v_forward(T x, int64_t n) { +inline C10_HOST_DEVICE T chebyshev_polynomial_v_forward(T x, int64_t n) { if (n < 0) { return T(0.0); } @@ -2969,12 +2969,12 @@ static inline C10_HOST_DEVICE T chebyshev_polynomial_v_forward(T x, int64_t n) { } // chebyshev_polynomial_v_forward(T x, int64_t n) template -static inline C10_HOST_DEVICE T chebyshev_polynomial_v_forward(T x, T n) { +inline C10_HOST_DEVICE T chebyshev_polynomial_v_forward(T x, T n) { return chebyshev_polynomial_v_forward(x, static_cast(n)); } // chebyshev_polynomial_v_forward(T x, T n) template -static inline C10_HOST_DEVICE T chebyshev_polynomial_w_forward(T x, int64_t n) { +inline C10_HOST_DEVICE T chebyshev_polynomial_w_forward(T x, int64_t n) { if (n < 0) { return T(0.0); } @@ -3029,12 +3029,12 @@ static inline C10_HOST_DEVICE T chebyshev_polynomial_w_forward(T x, int64_t n) { } // chebyshev_polynomial_w_forward(T x, int64_t n) template -static inline C10_HOST_DEVICE T chebyshev_polynomial_w_forward(T x, T n) { +inline C10_HOST_DEVICE T chebyshev_polynomial_w_forward(T x, T n) { return chebyshev_polynomial_w_forward(x, static_cast(n)); } // chebyshev_polynomial_w_forward(T x, T n) template -static inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, int64_t n) { +inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, int64_t n) { if (n < 0) { return T(0.0); } @@ -3061,17 +3061,17 @@ static inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, int64_t n) { } // hermite_polynomial_h_forward(T x, int64_t n) template::value, int> = 0> -static inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, T n) { +inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, T n) { return hermite_polynomial_h_forward(x, static_cast(n)); } // hermite_polynomial_h_forward(T x, T n) template::value, int> = 0> -static inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, T n) { +inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, T n) { return hermite_polynomial_h_forward(x, ((!std::isinf(n)) && (!std::isnan(n))) ? static_cast(n) : static_cast(-1)); } // hermite_polynomial_h_forward(T x, T n) template -static inline C10_HOST_DEVICE T hermite_polynomial_he_forward(T x, int64_t n) { +inline C10_HOST_DEVICE T hermite_polynomial_he_forward(T x, int64_t n) { if (n < 0) { return T(0.0); } @@ -3098,12 +3098,12 @@ static inline C10_HOST_DEVICE T hermite_polynomial_he_forward(T x, int64_t n) { } // hermite_polynomial_he_forward(T x, int64_t n) template -static inline C10_HOST_DEVICE T hermite_polynomial_he_forward(T x, T n) { +inline C10_HOST_DEVICE T hermite_polynomial_he_forward(T x, T n) { return hermite_polynomial_he_forward(x, static_cast(n)); } // hermite_polynomial_he_forward(T x, T n) template -static inline C10_HOST_DEVICE T laguerre_polynomial_l_forward(T x, int64_t n) { +inline C10_HOST_DEVICE T laguerre_polynomial_l_forward(T x, int64_t n) { if (n < 0) { return T(0.0); } @@ -3134,12 +3134,12 @@ static inline C10_HOST_DEVICE T laguerre_polynomial_l_forward(T x, int64_t n) { } // laguerre_polynomial_l_forward(T x, int64_t n) template -static inline C10_HOST_DEVICE T laguerre_polynomial_l_forward(T x, T n) { +inline C10_HOST_DEVICE T laguerre_polynomial_l_forward(T x, T n) { return laguerre_polynomial_l_forward(x, static_cast(n)); } // laguerre_polynomial_l_forward(T x, T n) template -static inline C10_HOST_DEVICE T legendre_polynomial_p_forward(T x, int64_t n) { +inline C10_HOST_DEVICE T legendre_polynomial_p_forward(T x, int64_t n) { if (n < 0) { return T(0.0); } @@ -3174,12 +3174,12 @@ static inline C10_HOST_DEVICE T legendre_polynomial_p_forward(T x, int64_t n) { } // legendre_polynomial_p_forward(T x, int64_t n) template -static inline C10_HOST_DEVICE T legendre_polynomial_p_forward(T x, T n) { +inline C10_HOST_DEVICE T legendre_polynomial_p_forward(T x, T n) { return legendre_polynomial_p_forward(x, static_cast(n)); } // legendre_polynomial_p_forward(T x, T n) template -static inline C10_HOST_DEVICE T modified_bessel_i0_forward(T x) { +inline C10_HOST_DEVICE T modified_bessel_i0_forward(T x) { static const T A[] = { -4.41534164647933937950e-18, +3.33079451882223809783e-17, @@ -3268,7 +3268,7 @@ static inline C10_HOST_DEVICE T modified_bessel_i0_forward(T x) { } // modified_bessel_i0_forward(T x) template -static inline C10_HOST_DEVICE T modified_bessel_i1_forward(T x) { +inline C10_HOST_DEVICE T modified_bessel_i1_forward(T x) { static const T A[] = { +2.77791411276104639959e-18, -2.11142121435816608115e-17, @@ -3364,7 +3364,7 @@ static inline C10_HOST_DEVICE T modified_bessel_i1_forward(T x) { } // modified_bessel_i1_forward(T x) template -static inline C10_HOST_DEVICE T modified_bessel_k0_forward(T x) { +inline C10_HOST_DEVICE T modified_bessel_k0_forward(T x) { static const T A[] = { +1.37446543561352307156e-16, +4.25981614279661018399e-14, @@ -3441,7 +3441,7 @@ static inline C10_HOST_DEVICE T modified_bessel_k0_forward(T x) { } // modified_bessel_k0_forward(T x) template -static inline C10_HOST_DEVICE T modified_bessel_k1_forward(T x) { +inline C10_HOST_DEVICE T modified_bessel_k1_forward(T x) { static const T A[] = { -7.02386347938628759343e-18, -2.42744985051936593393e-15, @@ -3519,7 +3519,7 @@ static inline C10_HOST_DEVICE T modified_bessel_k1_forward(T x) { } // modified_bessel_k1_forward(T x) template -static inline C10_HOST_DEVICE T scaled_modified_bessel_k0_forward(T x) { +inline C10_HOST_DEVICE T scaled_modified_bessel_k0_forward(T x) { static const T A[] = { +1.37446543561352307156e-16, +4.25981614279661018399e-14, @@ -3596,7 +3596,7 @@ static inline C10_HOST_DEVICE T scaled_modified_bessel_k0_forward(T x) { } // T scaled_modified_bessel_k0_forward(T x) template -static inline C10_HOST_DEVICE T scaled_modified_bessel_k1_forward(T x) { +inline C10_HOST_DEVICE T scaled_modified_bessel_k1_forward(T x) { static const T A[] = { -7.02386347938628759343e-18, -2.42744985051936593393e-15, @@ -3674,7 +3674,7 @@ static inline C10_HOST_DEVICE T scaled_modified_bessel_k1_forward(T x) { } // T scaled_modified_bessel_k1_forward(T x) template -static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_t_forward(T x, int64_t n) { +inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_t_forward(T x, int64_t n) { if (n < 0) { return T(0.0); } @@ -3717,12 +3717,12 @@ static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_t_forward(T x, int6 } // shifted_chebyshev_polynomial_t_forward(T x, int64_t n) template -static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_t_forward(T x, T n) { +inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_t_forward(T x, T n) { return shifted_chebyshev_polynomial_t_forward(x, static_cast(n)); } // shifted_chebyshev_polynomial_t_forward(T x, T n) template -static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_u_forward(T x, int64_t n) { +inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_u_forward(T x, int64_t n) { if (n < 0) { return T(0.0); } @@ -3769,12 +3769,12 @@ static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_u_forward(T x, int6 } // shifted_chebyshev_polynomial_u_forward(T x, int64_t n) template -static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_u_forward(T x, T n) { +inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_u_forward(T x, T n) { return shifted_chebyshev_polynomial_u_forward(x, static_cast(n)); } // shifted_chebyshev_polynomial_u_forward(T x, T n) template -static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_v_forward(T x, int64_t n) { +inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_v_forward(T x, int64_t n) { if (n < 0) { return T(0.0); } @@ -3825,12 +3825,12 @@ static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_v_forward(T x, int6 } // shifted_chebyshev_polynomial_v_forward(T x, int64_t n) template -static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_v_forward(T x, T n) { +inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_v_forward(T x, T n) { return shifted_chebyshev_polynomial_v_forward(x, static_cast(n)); } // shifted_chebyshev_polynomial_v_forward(T x, T n) template -static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_w_forward(T x, int64_t n) { +inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_w_forward(T x, int64_t n) { if (n < 0) { return T(0.0); } @@ -3881,12 +3881,12 @@ static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_w_forward(T x, int6 } // shifted_chebyshev_polynomial_w_forward(T x, int64_t n) template -static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_w_forward(T x, T n) { +inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_w_forward(T x, T n) { return shifted_chebyshev_polynomial_w_forward(x, static_cast(n)); } // shifted_chebyshev_polynomial_w_forward(T x, T n) template -static inline C10_HOST_DEVICE T spherical_bessel_j0_forward(T x) { +inline C10_HOST_DEVICE T spherical_bessel_j0_forward(T x) { if (std::isinf(x)) { return T(0.0); } diff --git a/aten/src/ATen/native/Padding.h b/aten/src/ATen/native/Padding.h index 083436134282..53a054027f33 100644 --- a/aten/src/ATen/native/Padding.h +++ b/aten/src/ATen/native/Padding.h @@ -26,7 +26,7 @@ DECLARE_DISPATCH(padding_fn, replication_pad3d_backward_kernel); namespace padding { template -static inline void check_valid_input(const Tensor& input, IntArrayRef padding) { +inline void check_valid_input(const Tensor& input, IntArrayRef padding) { TORCH_CHECK(padding.size() == 2 * dim, "padding size is expected to be ", 2 * dim, diff --git a/aten/src/ATen/native/Pool.h b/aten/src/ATen/native/Pool.h index df73299ea230..df677019e897 100644 --- a/aten/src/ATen/native/Pool.h +++ b/aten/src/ATen/native/Pool.h @@ -48,7 +48,7 @@ DECLARE_DISPATCH(max_pool3d_backward_fn, max_pool3d_backward_kernel); namespace { template -static inline dest_t +inline dest_t safe_downcast(src_t v) { TORCH_CHECK(std::numeric_limits::min() <= v && v <= std::numeric_limits::max(), @@ -58,7 +58,7 @@ safe_downcast(src_t v) } template -static inline T pooling_output_shape_pad_lr( +inline T pooling_output_shape_pad_lr( T inputSize, T kernelSize, T pad_l, T pad_r, T stride, T dilation, bool ceil_mode) { T outputSize = div_rtn( @@ -75,7 +75,7 @@ static inline T pooling_output_shape_pad_lr( } template -static inline T pooling_output_shape( +inline T pooling_output_shape( T inputSize, T kernelSize, T pad, T stride, T dilation, bool ceil_mode) { TORCH_CHECK(stride != 0, "stride should not be zero"); TORCH_CHECK(pad >= 0, @@ -117,7 +117,7 @@ inline std::pair pooling_same_mode_padding_lr( } // AveragePool2d/DilatedMaxPool2d (forward) -static inline void +inline void pool2d_shape_check( const Tensor& input, int kH, int kW, int dH, int dW, int padH, int padW, int dilationH, int dilationW, @@ -164,7 +164,7 @@ pool2d_shape_check( } // DilatedMaxPool2d (backward) -static inline void +inline void max_pool2d_backward_shape_check( const Tensor& input, const Tensor& gradOutput, @@ -192,7 +192,7 @@ max_pool2d_backward_shape_check( } // AveragePool2d (backward) -static inline void +inline void avg_pool2d_backward_shape_check( const Tensor& input, const Tensor& gradOutput, @@ -218,7 +218,7 @@ avg_pool2d_backward_shape_check( } // AveragePool3d/DilatedMaxPool3d (forward) -static inline void +inline void pool3d_shape_check( const Tensor& input, int64_t nslices, @@ -280,7 +280,7 @@ pool3d_shape_check( "Output size is too small"); } -static inline void +inline void max_pool3d_backward_shape_check( const Tensor& input, const Tensor& gradOutput, @@ -317,7 +317,7 @@ max_pool3d_backward_shape_check( check_dim_size(indices, ndim, ndim-1, owidth); } -static inline void +inline void avg_pool3d_backward_shape_check( const Tensor& input, const Tensor& gradOutput, diff --git a/aten/src/ATen/native/Pow.h b/aten/src/ATen/native/Pow.h index 068482ee300c..76ddda846a59 100644 --- a/aten/src/ATen/native/Pow.h +++ b/aten/src/ATen/native/Pow.h @@ -24,7 +24,7 @@ namespace native { // only non-zero result. template ::value, T>::type* = nullptr> -static inline HOST_DEVICE __ubsan_ignore_signed_int_overflow__ T powi_impl(T a, T b) { +inline HOST_DEVICE __ubsan_ignore_signed_int_overflow__ T powi_impl(T a, T b) { T result = 1; while (b) { if (b & 1) { @@ -38,13 +38,13 @@ static inline HOST_DEVICE __ubsan_ignore_signed_int_overflow__ T powi_impl(T a, template ::value && !std::is_signed::value, T>::type* = nullptr> -static inline HOST_DEVICE T powi(T a, T b) { +inline HOST_DEVICE T powi(T a, T b) { return powi_impl(a, b); } template ::value && std::is_signed::value, T>::type* = nullptr> -static inline HOST_DEVICE T powi(T a, T b) { +inline HOST_DEVICE T powi(T a, T b) { if ( b < 0 ) { if ( a == 1 ) { return 1; diff --git a/aten/src/ATen/native/ReduceOpsUtils.h b/aten/src/ATen/native/ReduceOpsUtils.h index 505cf3bb3a77..cfb4776fa846 100644 --- a/aten/src/ATen/native/ReduceOpsUtils.h +++ b/aten/src/ATen/native/ReduceOpsUtils.h @@ -31,7 +31,7 @@ constexpr scalar_t lower_bound() { return lim::has_infinity ? -lim::infinity() : lim::lowest(); } -static inline Tensor restride_dim( +inline Tensor restride_dim( const Tensor& src, int64_t dim, IntArrayRef replacement_shape ) { @@ -96,13 +96,13 @@ inline std::optional _allreduce_return_trivial( " but found ", out.option())\ } -static inline void check_scalar_type_device_layout_equal(const Tensor& out, const Tensor& self) { +inline void check_scalar_type_device_layout_equal(const Tensor& out, const Tensor& self) { OPTION_TYPE_EQUALITY_CHECK(scalar_type, out, self); OPTION_TYPE_EQUALITY_CHECK(device, out.options(), self.options()); OPTION_TYPE_EQUALITY_CHECK(layout, out.options(), self.options()); } -static inline Tensor integer_upcast(const Tensor& self, std::optional dtype) { +inline Tensor integer_upcast(const Tensor& self, std::optional dtype) { ScalarType scalarType = self.scalar_type(); TORCH_CHECK(!isBarebonesUnsignedType(scalarType), "integer upcasting for uint16, uint32 and uint64 is not currently implemented"); ScalarType upcast_scalarType = dtype.value_or(at::isIntegralType(scalarType, /*includeBool=*/true) ? ScalarType::Long : scalarType); @@ -111,7 +111,7 @@ static inline Tensor integer_upcast(const Tensor& self, std::optional get_zero_numel_tensor_size( +inline std::vector get_zero_numel_tensor_size( const Tensor& self, const int64_t dim, const bool keepdim, @@ -313,7 +313,7 @@ static std::vector get_zero_numel_tensor_size( // This function should be called when you are reducing a zero-numel tensor and want to // resize the output and return it. This function exists for resizing zero-numel // tensors when the size of the reduction dimension is non-zero. -static C10_UNUSED void zero_numel_tensor_resize(Tensor& result, Tensor& result_indices, +inline C10_UNUSED void zero_numel_tensor_resize(Tensor& result, Tensor& result_indices, const Tensor& self, const int64_t dim, const bool keepdim, const char *fn_name) { auto sizes = get_zero_numel_tensor_size(self, dim, keepdim, fn_name); @@ -349,7 +349,7 @@ inline ScalarType get_dtype_from_result(Tensor& result, std::optional inline int64_t maybe_convert_symint(c10::SymInt x) { return x.guard_int(__FILE__, __LINE__); } template -static inline void checkInBoundsForStorage( +inline void checkInBoundsForStorage( ArrayRef size, ArrayRef stride, T storage_offset, @@ -111,7 +111,7 @@ static inline void checkInBoundsForStorage( } template -static inline void checkSetStorage(Tensor& result, Storage storage, T storage_offset, +inline void checkSetStorage(Tensor& result, Storage storage, T storage_offset, ArrayRef size, ArrayRef stride) { // FIXME: stride should be optional if (stride.data()) { diff --git a/aten/src/ATen/native/UpSample.h b/aten/src/ATen/native/UpSample.h index 6f063a0dc2fb..275d2028f764 100644 --- a/aten/src/ATen/native/UpSample.h +++ b/aten/src/ATen/native/UpSample.h @@ -210,7 +210,7 @@ std::array upsample_3d_common_check(IntArrayRef input_size, IntArray return {nbatch, channels, output_depth, output_height, output_width}; } -static inline void upsample_2d_shape_check( +inline void upsample_2d_shape_check( const Tensor& input, const Tensor& grad_output, int64_t nbatch, @@ -251,7 +251,7 @@ static inline void upsample_2d_shape_check( } template -static inline scalar_t compute_scales_value( +inline scalar_t compute_scales_value( const std::optional scale, int64_t input_size, int64_t output_size) { @@ -263,7 +263,7 @@ static inline scalar_t compute_scales_value( } template -static inline scalar_t area_pixel_compute_scale( +inline scalar_t area_pixel_compute_scale( int64_t input_size, int64_t output_size, bool align_corners, @@ -281,7 +281,7 @@ static inline scalar_t area_pixel_compute_scale( } template -static inline scalar_t area_pixel_compute_source_index( +inline scalar_t area_pixel_compute_source_index( scalar_t scale, int64_t dst_index, bool align_corners, @@ -308,7 +308,7 @@ static inline scalar_t area_pixel_compute_source_index( } } -static inline int64_t nearest_neighbor_compute_source_index( +inline int64_t nearest_neighbor_compute_source_index( const float scale, int64_t dst_index, int64_t input_size) { @@ -319,7 +319,7 @@ static inline int64_t nearest_neighbor_compute_source_index( return src_index; } -static inline int64_t nearest_neighbor_exact_compute_source_index( +inline int64_t nearest_neighbor_exact_compute_source_index( const float scale, int64_t dst_index, int64_t input_size) { @@ -331,7 +331,7 @@ static inline int64_t nearest_neighbor_exact_compute_source_index( return src_index; } -static inline int64_t nearest_idx( +inline int64_t nearest_idx( int64_t output_index, int64_t input_size, int64_t output_size, @@ -352,7 +352,7 @@ static inline int64_t nearest_idx( } } -static inline int64_t nearest_exact_idx( +inline int64_t nearest_exact_idx( int64_t output_index, int64_t input_size, int64_t output_size, @@ -392,17 +392,17 @@ static void upsample_increment_value_bounded( // Based on // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm template -static inline scalar_t cubic_convolution1(scalar_t x, scalar_t A) { +inline scalar_t cubic_convolution1(scalar_t x, scalar_t A) { return ((A + 2) * x - (A + 3)) * x * x + 1; } template -static inline scalar_t cubic_convolution2(scalar_t x, scalar_t A) { +inline scalar_t cubic_convolution2(scalar_t x, scalar_t A) { return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A; } template -static inline void get_cubic_upsample_coefficients( +inline void get_cubic_upsample_coefficients( scalar_t coeffs[4], scalar_t t) { scalar_t A = -0.75; @@ -418,7 +418,7 @@ static inline void get_cubic_upsample_coefficients( } template -static inline scalar_t cubic_interp1d( +inline scalar_t cubic_interp1d( scalar_t x0, scalar_t x1, scalar_t x2, @@ -434,7 +434,7 @@ static inline scalar_t cubic_interp1d( // type can accurately represent, the type casting to `int64_t` might exceed // `input_size`, causing overflow. So we guard it with `std::min` below. template -static inline void guard_index_and_lambda(const opmath_t& real_input_index, const int64_t& input_size, int64_t& input_index, scalar_t& lambda) { +inline void guard_index_and_lambda(const opmath_t& real_input_index, const int64_t& input_size, int64_t& input_index, scalar_t& lambda) { input_index = std::min(static_cast(floorf(real_input_index)), input_size - 1); lambda = std::min( std::max(real_input_index - input_index, static_cast(0)), @@ -443,7 +443,7 @@ static inline void guard_index_and_lambda(const opmath_t& real_input_index, cons } template -static inline void compute_source_index_and_lambda( +inline void compute_source_index_and_lambda( int64_t& input_index0, int64_t& input_index1, scalar_t& lambda0, diff --git a/aten/src/ATen/native/im2col_shape_check.h b/aten/src/ATen/native/im2col_shape_check.h index f7ae0854f78e..8a6fa47ba10f 100644 --- a/aten/src/ATen/native/im2col_shape_check.h +++ b/aten/src/ATen/native/im2col_shape_check.h @@ -5,7 +5,7 @@ namespace at::native { -static inline void col2im_shape_check( +inline void col2im_shape_check( const Tensor& input, const Tensor& grad_output, int64_t output_height, @@ -135,7 +135,7 @@ static inline void col2im_shape_check( } } -static inline void im2col_shape_check( +inline void im2col_shape_check( const Tensor& input, const Tensor& grad_output, int64_t kernel_height, diff --git a/aten/src/ATen/native/mkldnn/Utils.h b/aten/src/ATen/native/mkldnn/Utils.h index 75f1b2c1b709..a63d9ebfa2c1 100644 --- a/aten/src/ATen/native/mkldnn/Utils.h +++ b/aten/src/ATen/native/mkldnn/Utils.h @@ -36,7 +36,7 @@ void check_mkldnn_binary_fusion_inputs( const Tensor& weight, const Tensor& bias); -static inline std::vector padding_r( +inline std::vector padding_r( IntArrayRef padding, IntArrayRef output_padding) { // ConvTranpose padding adjustment @@ -60,7 +60,7 @@ static inline std::vector padding_r( // Make sure input has default contiguous strides if it's contiguous tensors for better performance. // For example, for tensor of size = [1, 1280], stride = [0, 1], we'll convert it to size = [1, 1280], stride = [1280, 1] // before calling oneDNN for better performance. -static inline Tensor may_convert_to_default_contiguous_strides(const Tensor& input) { +inline Tensor may_convert_to_default_contiguous_strides(const Tensor& input) { auto input_size = input.sizes().vec(); auto input_stride = input.strides().vec(); auto input_default_contiguous_strides = c10::contiguous_strides(input_size); From 31c3fa6cf57cdea99f2da1bd106c7bfaff2dd0bd Mon Sep 17 00:00:00 2001 From: PyTorch UpdateBot Date: Sun, 9 Jun 2024 04:29:02 +0000 Subject: [PATCH 555/706] [audio hash update] update the pinned audio hash (#128178) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml). Update the pinned audio hash. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128178 Approved by: https://github.com/pytorchbot --- .github/ci_commit_pins/audio.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ci_commit_pins/audio.txt b/.github/ci_commit_pins/audio.txt index 98cd949f9713..a8141b25ecdd 100644 --- a/.github/ci_commit_pins/audio.txt +++ b/.github/ci_commit_pins/audio.txt @@ -1 +1 @@ -1980f8af5bcd0bb2ce51965cf79d8d4c25dad8a0 +b829e936f7cc61b48149f5f957a451a38bf2a178 From 3964a3ec7351fb51460351d8914e562817d8cd72 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Sat, 8 Jun 2024 18:12:54 -0700 Subject: [PATCH 556/706] Complete revamp of float/promotion sympy handling (#126905) At a high level, the idea behind this PR is: * Make it clearer what the promotion and int/float rules for various Sympy operations are. Operators that previously were polymorphic over int/float are now split into separate operators for clarity. We never do mixed int/float addition/multiplication etc in sympy, instead, we always promote to the appropriate operator. (However, equality is currently not done correctly.) * Enforce strict typing on ValueRanges: if you have a ValueRange for a float, the lower and upper MUST be floats, and so forth for integers. The story begins in **torch/utils/_sympy/functions.py**. Here, I make some changes to how we represent certain operations in sympy expressions: * FloorDiv now only supports integer inputs; to do float floor division, do a truediv and then a trunc. Additionally, we remove the divide out addition by gcd optimization, because sympy gcd is over fields and is willing to generate rationals (but rationals are bad for ValueRange strict typing). * ModularIndexing, LShift, RShift now assert they are given integer inputs. * Mod only supports integer inputs; eventually we will support FloatMod (left for later work, when we build out Sympy support for floating operations). Unfortunately, I couldn't assert integer inputs here, because of a bad interaction with sympy's inequality solver that is used by the offline solver * TrueDiv is split into FloatTrueDiv and IntTrueDiv. This allows for us to eventually generate accurate code for Python semantics IntTrueDiv, which is written in a special way to preserve precision when the inputs are >= 2**53 beyond what first coercing the integer to floats and then doing true division. * Trunc is split to TruncToFloat and TruncToInt. * Round is updated to return a float, not an int, making it consistent with the round op handler in Inductor. To get Python-style conversion to int, we call TruncToInt on the result. * RoundDecimal updated to consistently only ever return a float * Add ToFloat for explicit coercion to float (required so we can enforce strict ValueRanges typing) In **torch/__init__.py**, we modify SymInt and SymFloat to appropriately call into new bindings that route to these refined sympy operations. Also, we modify `torch.sym_min` and `torch.sym_max` to have promotion semantics (if one argument is a float, the return result is always a float), making them inconsistent with builtins.min/max, but possible to do type analysis without runtime information. We also need to introduce some new op handlers in **torch/_inductor/ops_handler.py**: * `to_int` for truncation to int64, directly corresponding to TruncToInt; this can be implemented by trunc and dtype, but with a dedicated handler it is more convenient for roundtripping in Sympy * `int_truediv` for Python-style integer true division, which has higher precision than casting to floats and then running `truediv` These changes have consequences. First, we need to make some administrative changes: * Actually wire up these Sympy functions from SymInt/SymFloat in **torch/fx/experimental/sym_node.py**, including the new promotion rules (promote2) * Add support for new Sympy functions in **torch/utils/_sympy/interp.py**, **torch/utils/_sympy/reference.py** * In particular, in torch.utils._sympy.reference, we have a strong preference to NOT do nontrivial compute, instead, everything in ops handler should map to a singular sympy function * TODO: I chose to roundtrip mod back to our Mod function, but I think I'm going to have to deal with the C/Python inconsistency this to fix tests here * Add printer support for the Sympy functions in **torch/_inductor/codegen/common.py**, **torch/_inductor/codegen/cpp_utils.py**, **torch/_inductor/codegen/triton.py**. `int_truediv` and mixed precision equality is currently not implemented soundly, so we will lose precision in codegen for large values. TODO: The additions here are not exhaustive yet * Update ValueRanges logic to use new sympy functions in **torch/utils/_sympy/value_ranges.py**. In general, we prefer to use the new Sympy function rather than try to roll things by hand, which is what was done previously for many VR analysis functions. In **torch/fx/experimental/symbolic_shapes.py** we need to make some symbolic reasoning adjustments: * Avoid generation of rational subexpressions by removing simplification of `x // y` into `floor(x / y)`. This simplification then triggers an addition simplification rule `(x + y) / c --> x / c + y / c` which is bad because x / c is a rational number now * `_assert_bound_is_rational` is no more, we no longer generate rational bounds * Don't intersect non-int value ranges with the `int_range` * Support more sympy Functions for guard SYMPY_INTERP * Assert the type of value range is consistent with the variable type The new asserts uncovered necessary bug fixes: * **torch/_inductor/codegen/cpp.py**, **torch/_inductor/select_algorithm.py**, **torch/_inductor/sizevars.py** - Ensure Wild/Symbol manually allocated in Inductor is marked `is_integer` so it's accepted to build expressions * **torch/_inductor/utils.py** - make sure you actually pass in sympy.Expr to these functions * **torch/_inductor/ir.py** - make_contiguous_strides_for takes int/SymInt, not sympy.Expr! * **torch/export/dynamic_shapes.py** - don't use infinity to represent int ranges, instead use sys.maxsize - 1 Because of the removal of some symbolic reasoning that produced rationals, some of our symbolic reasoning has gotten worse and we are unable to simplify some guards. Check the TODO at **test/test_proxy_tensor.py** **Reland notes.** This requires this internal fbcode diff https://www.internalfb.com/phabricator/paste/view/P1403322587 but I cannot prepare the diff codev due to https://fb.workplace.com/groups/osssupport/posts/26343544518600814/ It also requires this Executorch PR https://github.com/pytorch/executorch/pull/3911 but the ET PR can be landed prior to this landing. Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/126905 Approved by: https://github.com/xadupre, https://github.com/lezcano --- c10/core/SymNodeImpl.h | 18 + test/dynamo/test_dynamic_shapes.py | 7 - test/dynamo/test_misc.py | 17 +- test/inductor/test_indexing.py | 72 +--- .../test_torchinductor_dynamic_shapes.py | 28 ++ test/onnx/test_fx_to_onnx_with_onnxruntime.py | 8 +- test/test_dynamic_shapes.py | 208 +++------ test/test_proxy_tensor.py | 3 +- test/test_sympy_utils.py | 122 +++--- torch/__init__.py | 162 ++++++- torch/_export/pass_base.py | 1 + torch/_export/serde/serialize.py | 10 +- torch/_export/verifier.py | 1 + torch/_inductor/bounds.py | 5 + torch/_inductor/codegen/common.py | 176 ++++++-- torch/_inductor/codegen/cpp.py | 4 +- torch/_inductor/codegen/cpp_utils.py | 55 ++- torch/_inductor/codegen/triton.py | 64 ++- torch/_inductor/graph.py | 5 +- torch/_inductor/ir.py | 16 +- torch/_inductor/kernel/flex_attention.py | 5 +- torch/_inductor/lowering.py | 6 +- torch/_inductor/ops_handler.py | 60 ++- torch/_inductor/select_algorithm.py | 4 +- torch/_inductor/sizevars.py | 20 +- torch/_inductor/utils.py | 2 +- torch/_subclasses/fake_tensor.py | 2 +- torch/csrc/jit/python/init.cpp | 5 + torch/csrc/utils/python_symnode.h | 20 + torch/export/dynamic_shapes.py | 9 +- torch/fx/experimental/recording.py | 8 +- torch/fx/experimental/sym_node.py | 210 +++++++-- torch/fx/experimental/symbolic_shapes.py | 103 +++-- torch/fx/experimental/validator.py | 32 +- torch/utils/_sympy/functions.py | 398 ++++++++++++++---- torch/utils/_sympy/interp.py | 71 +++- torch/utils/_sympy/reference.py | 151 ++++--- torch/utils/_sympy/solve.py | 1 + torch/utils/_sympy/value_ranges.py | 278 ++++++++---- 39 files changed, 1697 insertions(+), 670 deletions(-) diff --git a/c10/core/SymNodeImpl.h b/c10/core/SymNodeImpl.h index 9ffab5065109..bb92b09775b7 100644 --- a/c10/core/SymNodeImpl.h +++ b/c10/core/SymNodeImpl.h @@ -49,15 +49,33 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target { virtual SymNode mul(const SymNode& other) { TORCH_CHECK(false, "NYI"); } + // NB: legacy, prefer float_truediv or int_truediv virtual SymNode truediv(const SymNode& other) { TORCH_CHECK(false, "NYI"); } + virtual SymNode float_truediv(const SymNode& other) { + return truediv(other); + } + virtual SymNode int_truediv(const SymNode& other) { + return truediv(other); + } + // NB: legacy, prefer float_pow or pow_by_natural virtual SymNode pow(const SymNode& other) { TORCH_CHECK(false, "NYI"); } + virtual SymNode float_pow(const SymNode& other) { + return pow(other); + } + virtual SymNode pow_by_natural(const SymNode& other) { + return pow(other); + } + // NB: legacy, prefer int_floordiv virtual SymNode floordiv(const SymNode& other) { TORCH_CHECK(false, "NYI"); } + virtual SymNode int_floordiv(const SymNode& other) { + return floordiv(other); + } virtual SymNode mod(const SymNode& other) { TORCH_CHECK(false, "NYI"); } diff --git a/test/dynamo/test_dynamic_shapes.py b/test/dynamo/test_dynamic_shapes.py index 0bead6e47e48..a3c63ef66152 100644 --- a/test/dynamo/test_dynamic_shapes.py +++ b/test/dynamo/test_dynamic_shapes.py @@ -78,13 +78,6 @@ def make_dynamic_cls(cls): del test if TEST_Z3: - # this only fails when z3 is available - unittest.expectedFailure( - # SymPy is incorrectly transforming 's0 / 6 == 0.5' into 'False'. - # Ref: https://github.com/sympy/sympy/issues/25146 - DynamicShapesReproTests.test_dynamic_shapes_float_guard_dynamic_shapes # noqa: F821 - ) - if not config.inline_inbuilt_nn_modules: # TODO model is somehow not being freed when z3 is available unittest.expectedFailure( diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index ce423eab7d8a..02f7c68aa1a9 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -9309,7 +9309,7 @@ def test_shape_env_equal_create_symbolic_sizes_strides_storage_offset(self): > Left: {0: 0, 1: 1, 2: s1, 3: s0} > Right: {0: 0, 1: 1} ==> var_to_range: values don't match. - > Left: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)} + > Left: {s0: VR[2, 9223372036854775806], s1: VR[2, 9223372036854775806]} > Right: {} ==> var_to_sources: values don't match. > Left: {s0: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=, idx=0)], s1: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=, idx=1)]} @@ -9343,7 +9343,7 @@ def test_shape_env_equal_unbacked(self): > Left: 2 > Right: 0 ==> var_to_range: values don't match. - > Left: {u0: ValueRanges(lower=-9223372036854775808, upper=9223372036854775807, is_bool=False), u1: ValueRanges(lower=0, upper=1, is_bool=False), zuf0: ValueRanges(lower=-oo, upper=oo, is_bool=False)} + > Left: {u0: VR[-9223372036854775808, 9223372036854775807], u1: VR[0, 1], zuf0: VR[-oo, oo]} > Right: {} """, ) @@ -9420,8 +9420,8 @@ def test_shape_env_equal_evaluate_expr_replacement(self): > Left: {s0: 3} > Right: {} ==> var_to_range: values don't match. - > Left: {s0: ValueRanges(lower=3, upper=3, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)} - > Right: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)} + > Left: {s0: VR[3, 3], s1: VR[2, 9223372036854775806]} + > Right: {s0: VR[2, 9223372036854775806], s1: VR[2, 9223372036854775806]} """, ) self._replay_and_check(main) @@ -9458,8 +9458,8 @@ def test_shape_env_equal_evaluate_expr_refinement(self): > Left: {_assert, ge, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} > Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} ==> var_to_range: values don't match. - > Left: {s0: ValueRanges(lower=3, upper=9223372036854775806, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)} - > Right: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)} + > Left: {s0: VR[3, 9223372036854775806], s1: VR[2, 9223372036854775806]} + > Right: {s0: VR[2, 9223372036854775806], s1: VR[2, 9223372036854775806]} """, ) self._replay_and_check(main) @@ -9484,10 +9484,7 @@ def test_shape_env_equal_runtime_assert(self): ShapeEnv not equal: field values don't match: ==> deferred_runtime_asserts: values don't match. - > Left: {u0: [Eq(Mod(u0, 3), 0)]} - > Right: {} -==> divisible: values don't match. - > Left: {Mod(u0, 3)} + > Left: {u0: [Eq(PythonMod(u0, 3), 0)]} > Right: {} ==> name_to_node: values don't match. > Left: {_assert, eq, mod, u0} diff --git a/test/inductor/test_indexing.py b/test/inductor/test_indexing.py index a3a7bf4b83ab..19a736160908 100644 --- a/test/inductor/test_indexing.py +++ b/test/inductor/test_indexing.py @@ -19,7 +19,12 @@ parametrize, ) from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA -from torch.utils._sympy.functions import FloorDiv, ModularIndexing, Round, RoundDecimal +from torch.utils._sympy.functions import ( + FloorDiv, + ModularIndexing, + RoundDecimal, + RoundToInt, +) DO_PERF_TEST = os.environ.get("DO_PERF_TEST") == "1" @@ -245,21 +250,11 @@ def test_print_pow(self): common_cases = [ # expr, result - # Test exprs. - ( - s1 / (2 * s1 - 1) - 1 / (2 * s1 - 1), - lambda c, L: f"((-1{L})*({c}/((-1{L}) + (2{L}*foo)))) + (foo*({c}/((-1{L}) + (2{L}*foo))))", - ), - (s1 / (s2 - s3), lambda c, L: f"foo*({c}/(bar + ((-1{L})*baz)))"), # Test Pow directly. ( sympy.Pow(s1 + s2, 0), lambda _, L: f"1{L}", ), # note: simplified before _print_Pow - ( - sympy.Pow(s1 + s2, -3), - lambda c, _: f"{c}/((bar + foo)*(bar + foo)*(bar + foo))", - ), ] gpu_cases = common_cases + [ @@ -308,12 +303,10 @@ def test_print_ceil(self): self.assertExpectedInline(cexpr(expr), """std::ceil((1.0/2.0)*s1)""") def test_print_round(self): - expr = Round(sympy.Symbol("x", integer=True) / 2) + expr = RoundToInt(sympy.Symbol("x", integer=True) / 2) self.assertExpectedInline(pexpr(expr), """round((1/2)*x)""") self.assertExpectedInline(cexpr(expr), """std::lrint((1.0/2.0)*x)""") - self.assertExpectedInline( - texpr(expr), """libdevice.llrint((1/2)*x).to(tl.int64)""" - ) + self.assertExpectedInline(texpr(expr), """libdevice.llrint((1/2)*x)""") @parametrize("ndigits", [-1, 0, 1]) def test_print_round_decimal(self, ndigits): @@ -328,45 +321,18 @@ def test_print_round_decimal(self, ndigits): f"libdevice.nearbyint(1e{ndigits} * ((1/2)*x)) * 1e{-ndigits}", ) - expr = RoundDecimal(sympy.Symbol("x", integer=True), ndigits) - if ndigits >= 0: - for do_print in [pexpr, cexpr, texpr]: - self.assertEqual(do_print(expr), "x") - else: - self.assertEqual(pexpr(expr), f"round(x, {ndigits})") - for do_print in [cexpr, texpr]: - with self.assertRaisesRegex( - ValueError, "only non-negative ndigits are currently supported" - ): - do_print(expr) - def test_print_floor_div(self): - for integer in [True, False]: - s1 = sympy.Symbol("s1", integer=integer) - s2 = sympy.Symbol("s2", integer=integer) - expr = FloorDiv(s1, s2) - self.assertEqual(pexpr(expr), "(s1 // s2)") - if integer: - self.assertEqual(cexpr(expr), "c10::div_floor_integer(s1, s2)") - else: - self.assertEqual( - cexpr(expr), - "c10::div_floor_floating(static_cast(s1), static_cast(s2))", - ) - - for integer in [True, False]: - s1 = sympy.Symbol("s1", integer=integer) - s2 = sympy.S(-1) - expr = FloorDiv(s1, s2) - if integer: - self.assertEqual(pexpr(expr), "(-1)*s1") - self.assertEqual(cexpr(expr), "(-1L)*s1") - else: - self.assertEqual(pexpr(expr), "(s1 // (-1))") - self.assertEqual( - cexpr(expr), - "c10::div_floor_floating(static_cast(s1), static_cast((-1L)))", - ) + s1 = sympy.Symbol("s1", integer=True) + s2 = sympy.Symbol("s2", integer=True) + expr = FloorDiv(s1, s2) + self.assertEqual(pexpr(expr), "(s1 // s2)") + self.assertEqual(cexpr(expr), "c10::div_floor_integer(s1, s2)") + + s1 = sympy.Symbol("s1", integer=True) + s2 = sympy.S(-1) + expr = FloorDiv(s1, s2) + self.assertEqual(pexpr(expr), "(-1)*s1") + self.assertEqual(cexpr(expr), "(-1L)*s1") def test_print_Min_Max(self): cases = ( diff --git a/test/inductor/test_torchinductor_dynamic_shapes.py b/test/inductor/test_torchinductor_dynamic_shapes.py index 8513e928c412..2f9506a9d561 100644 --- a/test/inductor/test_torchinductor_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_dynamic_shapes.py @@ -3,6 +3,7 @@ import importlib import math +import operator import os import sys import unittest @@ -649,6 +650,33 @@ def fn(a): actual = cfn(5) self.assertEqual(expect, actual) + def test_interpolate_ceil_eq(self, device): + ceiling = math.ceil + IntTrueDiv = operator.truediv + + def fn(t): + s0, s2, s3 = t.size() + x = torch.zeros( + ( + s0, + 2048, + ceiling(IntTrueDiv(2 * ((s2 - 1) // 8) + 2, 1)), + ceiling(IntTrueDiv(2 * ((s3 - 1) // 8) + 2, 1)), + ), + dtype=torch.bfloat16, + ) + return torch.nn.functional.interpolate( + x, + scale_factor=2, + mode="nearest", + ) + + cfn = self.compile_fn(fn) + arg = torch.randn(4, 16, 18) + expect = fn(arg) + actual = cfn(arg) + self.assertEqual(expect, actual) + def test_full_recompiles(self, device): def fn(x): _, L = x.shape diff --git a/test/onnx/test_fx_to_onnx_with_onnxruntime.py b/test/onnx/test_fx_to_onnx_with_onnxruntime.py index b70bfbf9c4a7..0f0e01bc0dc2 100644 --- a/test/onnx/test_fx_to_onnx_with_onnxruntime.py +++ b/test/onnx/test_fx_to_onnx_with_onnxruntime.py @@ -158,8 +158,12 @@ def forward(self, x, y): torch.tensor([operator.sub(x.item(), y.item())]), torch.tensor([operator.mul(x.item(), y.item())]), torch.tensor([operator.truediv(x.item(), y.item())]), - torch.tensor([operator.floordiv(x.item(), y.item())]), - torch.tensor([operator.pow(x.item(), y.item())]), + # This requires torch.sym_float, probably easy to lower to + # ONNX but I don't know where to put it + # torch.tensor([operator.floordiv(x.item(), y.item())]), + # NB: abs so that the base and exponent are provably + # non-negative, so we don't generate runtime asserts + torch.tensor([operator.pow(abs(x.item()), abs(y.item()))]), torch.tensor([operator.abs(x.item())]), torch.tensor([operator.neg(x.item())]), torch.tensor([math.ceil(x.item())]), diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index d548e9df0707..3b47f12198d5 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -205,15 +205,15 @@ def create_symtype(cls, pytype, shape_env, val, duck=True): # TODO: default duck to False -def create_symint(shape_env, i: int, duck=True): +def create_symint(shape_env, i: int, duck=True) -> SymInt: return create_symtype(SymInt, int, shape_env, i, duck=duck) -def create_symbool(shape_env, b: bool): +def create_symbool(shape_env, b: bool) -> SymBool: return create_symtype(SymBool, bool, shape_env, b) -def create_symfloat(shape_env, f: float): +def create_symfloat(shape_env, f: float) -> SymFloat: return create_symtype(SymFloat, float, shape_env, f) @@ -457,14 +457,16 @@ def test_sym_int(self): r = sym_int(a1 / 2) self.assertEqual(guard_int(r), 3) self.assertIsInstance(r, torch.SymInt, msg=type(r)) - self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(Trunc(s1/2), 3)""") + self.assertExpectedInline( + str(shape_env.guards[1][0]), """Eq(TruncToInt(IntTrueDiv(s1, 2)), 3)""" + ) a3 = create_symint(shape_env, 3) r = sym_int(2.0 * torch.sym_float(a3)) self.assertEqual(guard_int(r), 6) self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertExpectedInline( - str(shape_env.guards[2][0]), """Eq(Trunc(2.0*s2), 6)""" + str(shape_env.guards[2][0]), """Eq(TruncToInt(2.0*ToFloat(s2)), 6)""" ) def test_sym_sqrt(self): @@ -474,7 +476,7 @@ def test_sym_sqrt(self): self.assertEqual(r, 2) self.assertIsInstance(r, torch.SymFloat, msg=type(r)) self.assertExpectedInline( - str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_sqrt(s0), 2)""" + str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_sqrt(s0), 2.0)""" ) def test_sym_floor(self): @@ -483,11 +485,17 @@ def test_sym_floor(self): r = math.floor(a0 / 2) self.assertEqual(r, 2) self.assertIsInstance(r, torch.SymInt, msg=type(r)) - self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(floor(s0/2), 2)""") + self.assertExpectedInline( + str(shape_env.guards[0][0]), + """Eq(FloorToInt(IntTrueDiv(s0, 2)), 2)""", + ) r = math.floor(3.0 * a0) self.assertEqual(r, 15) self.assertIsInstance(r, torch.SymInt, msg=type(r)) - self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(3*s0, 15)""") + self.assertExpectedInline( + str(shape_env.guards[1][0]), + """Eq(FloorToInt(3.0*ToFloat(s0)), 15)""", + ) def test_sym_trunc(self): shape_env = ShapeEnv() @@ -495,12 +503,14 @@ def test_sym_trunc(self): r = math.trunc(a0 / 2) self.assertEqual(r, 2) self.assertIsInstance(r, torch.SymInt, msg=type(r)) - self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(Trunc(s0/2), 2)""") + self.assertExpectedInline( + str(shape_env.guards[0][0]), """Eq(TruncToInt(IntTrueDiv(s0, 2)), 2)""" + ) r = torch.sym_int(torch.sym_sqrt(a0)) self.assertEqual(r, 2) self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertExpectedInline( - str(shape_env.guards[1][0]), """Eq(Trunc(OpaqueUnaryFn_sqrt(s0)), 2)""" + str(shape_env.guards[1][0]), """Eq(TruncToInt(OpaqueUnaryFn_sqrt(s0)), 2)""" ) def test_sym_ceil(self): @@ -510,12 +520,17 @@ def test_sym_ceil(self): self.assertEqual(r, 3) self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertExpectedInline( - str(shape_env.guards[0][0]), """Eq(ceiling(s0/2), 3)""" + str(shape_env.guards[0][0]), + """Eq(CeilToInt(IntTrueDiv(s0, 2)), 3)""", ) - r = math.floor(3.0 * a0) + r1 = 3.0 * a0 + r = math.floor(r1) self.assertEqual(r, 15) self.assertIsInstance(r, torch.SymInt, msg=type(r)) - self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(3*s0, 15)""") + self.assertExpectedInline( + str(shape_env.guards[1][0]), + """Eq(FloorToInt(3.0*ToFloat(s0)), 15)""", + ) def test_sym_ite(self): shape_env = ShapeEnv() @@ -962,8 +977,14 @@ def test_ephemeral_source_unified_with_non_ephemeral_source(self): ) class TestSymNumberMagicMethods(TestCase): def _do_test(self, fn, inp1, inp2, shape_env, is_unary_fn): + with self.subTest(fn=fn, inp1=inp1, inp2=inp2, is_unary_fn=is_unary_fn): + return self._do_test2(fn, inp1, inp2, shape_env, is_unary_fn) + + def _do_test2(self, fn, inp1, inp2, shape_env, is_unary_fn): # Helper function # NB: don't use one as that will get specialized + # TODO: We don't have to circuitously create the float, can just + # create a symfloat directly seed_node = (create_symint(shape_env, 2) / 2.0).node bool_seed_node = (create_symint(shape_env, 2) == 2).node @@ -976,27 +997,42 @@ def get_sym_inp(inp): else: return torch.SymFloat(to_node(seed_node, inp)) + if fn == "float_pow": + if inp1 < 0: + return + + if fn == "pow_by_natural": + if isinstance(inp1, float) or isinstance(inp2, float): + return + if inp2 < 0: + return + def maybe_xfail(inp1, inp2): if fn == "sym_sqrt" and inp1 < 0: # ValueError: math domain error return self.assertRaises((ValueError,)) - elif fn in ("truediv", "floordiv", "mod") and inp2 == 0: + elif ( + fn in ("float_truediv", "int_truediv", "int_floordiv", "mod") + and inp2 == 0 + ): # ZeroDivisionError: division by zero return self.assertRaises((ZeroDivisionError,)) - elif fn == "pow" and inp1 == 0 and inp2 < 0: + elif fn in ["float_pow", "pow_by_natural"] and inp1 == 0 and inp2 < 0: # ZeroDivisionError: 0.0 cannot be raised to a negative power return self.assertRaises((ZeroDivisionError,)) elif ( - fn == "pow" + # TODO: dear catastrophe waitress, + # this doesn't work + fn in ["float_pow", "pow_by_natural"] and inp1 < 0 - and inp2 in (2.5, -2.5) and ( - type(inp1) in (SymFloat, SymInt) or type(inp2) in (SymFloat, SymInt) + type(inp1) is (SymInt, SymFloat) or type(inp2) is (SymInt, SymFloat) ) + and (type(inp1) is (SymFloat, float) or type(inp2) is (SymFloat, float)) ): # Complex result, which we do not support: # TypeError: Cannot convert complex to float - return self.assertRaises((TypeError,)) + return self.assertRaises((RuntimeError,)) elif fn in ("lshift", "rshift") and not ( isinstance(inp1, (SymInt, int)) and isinstance(inp2, (SymInt, int)) ): @@ -1080,6 +1116,9 @@ def test_method(self, fn, first_type, second_type): ) and fn in sym_node.only_float_magic_methods: self.skipTest(f"{fn} is not an int method") + if second_type == "float" and fn in ["mod"]: + self.skipTest(f"{fn} only handles int") + is_unary_fn = fn in sym_node.unary_methods or fn == "round" # Second argument is ignored for unary function. So only run for one type if is_unary_fn and second_type == "float": @@ -1251,112 +1290,15 @@ def yield_test_cases(values, negate=True): yield (-x, -y) def test_floordiv_float_int(self): - values = ( - (2.5, 2.1), - (2.1, 2.5), - (2.0, 2.1), - (7, 2.5), - (2.1, 7), - (7, 2), - ) + values = ((7, 2),) for x, y in TestFloorDiv.yield_test_cases(values): self.assertEqual( TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y) ) - def test_floordiv_bool(self): - values = ( - (False, True), - (True, 2.5), - (2.5, True), - (False, 7), - (7, True), - ) - - for x, y in TestFloorDiv.yield_test_cases(values, negate=False): - # Compares to int since our FloorDiv has no bool support - self.assertEqual( - TestFloorDiv.python_floordiv(x, y), - TestFloorDiv.torch_floordiv(int(x), int(y)), - ) - # Tests that our impl throws - self.assertRaisesRegex( - TypeError, - ( - rf"unsupported operand type\(s\) for //: " - rf"'{type(sympy.sympify(x)).__name__}' and '{type(sympy.sympify(y)).__name__}'" - rf", expected integer or real" - ), - lambda: TestFloorDiv.torch_floordiv(x, y), - ) - - def test_floordiv_complex(self): - values = ( - (1.5 + 2.5j, 1.3 + 3.5j), - (1.5 + 2.5j, 2.5), - (2.5, 1.5 + 2.5j), - (1.5 + 2.5j, 7), - (7, 1.5 + 2.5j), - ) - - for x, y in TestFloorDiv.yield_test_cases(values): - # We don't test error messages to avoid depending on Python - # interpreter version - self.assertRaises(TypeError, lambda: TestFloorDiv.python_floordiv(x, y)) - self.assertRaisesRegex( - TypeError, - ( - rf"unsupported operand type\(s\) for //: " - rf"'{type(sympy.sympify(x)).__name__}' and '{type(sympy.sympify(y)).__name__}'" - rf", expected integer or real" - ), - lambda: TestFloorDiv.torch_floordiv(x, y), - ) - - def test_floordiv_div_by_zero(self): - values = ( - (2.5, 0), - (2.1, 0.0), - (2.3, sympy.Symbol("s", zero=True)), - ) - - for x, y in TestFloorDiv.yield_test_cases(values, negate=False): - # We don't test error messages to avoid depending on Python - # interpreter version - if type(y) is not sympy.Symbol: - self.assertRaises( - ZeroDivisionError, lambda: TestFloorDiv.python_floordiv(x, y) - ) - self.assertRaisesRegex( - ZeroDivisionError, - "division by zero", - lambda: TestFloorDiv.torch_floordiv(x, y), - ) - - def test_floordiv_zero_base(self): - values = ( - (0, 2.5), - (0.0, 2.1), - (sympy.Symbol("s", zero=True), 2.3), - ) - - for x, y in TestFloorDiv.yield_test_cases(values, negate=False): - if type(x) is not sympy.Symbol: - self.assertEqual( - TestFloorDiv.python_floordiv(x, y), - TestFloorDiv.torch_floordiv(x, y), - ) - else: - self.assertEqual(0, TestFloorDiv.torch_floordiv(x, y)) - def test_floordiv_div_by_one(self): - values = ( - (2.5, 1), - (2.1, 1.0), - (2, 1.0), - (2, 1), - ) + values = ((2, 1),) for x, y in TestFloorDiv.yield_test_cases(values): self.assertEqual( @@ -1367,12 +1309,7 @@ def test_floordiv_simplify(self): # Tests how we simplify or evaluate FloorDiv without free variables shape_env = ShapeEnv() result = 21 - exprs = ( - 7 * FloorDiv(6, 2), - 7 * FloorDiv(6.28, 2), - 7 * FloorDiv(6.28, 2.0), - 7 * FloorDiv(6.28, (FloorDiv(6.28, 3.14))), - ) + exprs = (7 * FloorDiv(6, 2),) for expr in exprs: self.assertEqual(expr, result) @@ -1382,33 +1319,10 @@ def test_floordiv_simplify(self): self.assertEqual(shape_env.simplify(expr), result) self.assertEqual(shape_env.evaluate_expr(expr), result) - def test_floordiv_simplify_rational(self): - result = 21 - - a = sympy.Symbol("a", integer=True) - b = sympy.Symbol("b") - - cases = [ - (FloorDiv(a, sympy.Rational(1, 8)), 8 * a), - (FloorDiv(b, sympy.Rational(1, 8)), sympy.floor(8 * b)), - ] - - for expr, expected in cases: - self.assertEqual(expr, expected) - def test_floordiv_assumptions(self): - # We define two Symbols (with different names) for each type to make - # sure the behavior is consistent regardless of whether both arguments - # are the same object or not. cases = ( sympy.Symbol("i1", integer=True), sympy.Symbol("i2", integer=True), - sympy.Symbol("r1", real=True), - sympy.Symbol("r2", real=True), - sympy.Symbol("c1", complex=True, real=False, integer=False), - sympy.Symbol("c2", complex=True, real=False, integer=False), - sympy.Symbol("s1"), - sympy.Symbol("s2"), ) for base, divisor in itertools.product(cases, repeat=2): diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index c7b2e51ced20..04483ffba0fc 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1618,7 +1618,8 @@ def f(a): self.assertExpectedInline(r, """\ def forward(self, a_1): sym_size_int = torch.ops.aten.sym_size.int(a_1, 0) - pow_1 = sym_size_int ** 0.5; sym_size_int = None + sym_float = torch.sym_float(sym_size_int); sym_size_int = None + pow_1 = sym_float ** 0.5; sym_float = None div = torch.ops.aten.div.Tensor(a_1, pow_1); a_1 = pow_1 = None return div""") diff --git a/test/test_sympy_utils.py b/test/test_sympy_utils.py index c5da8f7fc0da..8b16b2c620fd 100644 --- a/test/test_sympy_utils.py +++ b/test/test_sympy_utils.py @@ -36,7 +36,12 @@ "floor", "ceil", ] -BINARY_OPS = ["truediv", "div", "floordiv", "truncdiv", "add", "mul", "sub", "pow", "minimum", "maximum", "mod"] +BINARY_OPS = [ + "truediv", "floordiv", + # "truncdiv", # TODO + # NB: pow is float_pow + "add", "mul", "sub", "pow", "pow_by_natural", "minimum", "maximum", "mod" +] UNARY_BOOL_OPS = ["not_"] BINARY_BOOL_OPS = ["or_", "and_"] @@ -81,16 +86,24 @@ def valid_unary(fn, v): def valid_binary(fn, a, b): if fn == "pow" and ( + # sympy will expand to x*x*... for integral b; don't do it if it's big b > 4 - or ( # sympy will expand to x*x*... for integral b; don't do it if it's big - a <= 0 and b == -1 - ) - or (a == b == 0) # no imaginary numbers # 0**0 is undefined + # no imaginary numbers + or a <= 0 + # 0**0 is undefined + or (a == b == 0) ): return False - elif fn == "mod" and b == 0: + elif fn == "pow_by_natural" and ( + # sympy will expand to x*x*... for integral b; don't do it if it's big + b > 4 + or b < 0 + or (a == b == 0) + ): return False - elif (fn == "div" or fn == "truediv") and b == 0: + elif fn == "mod" and (a < 0 or b <= 0): + return False + elif (fn in ["div", "truediv", "floordiv"]) and b == 0: return False return True @@ -130,27 +143,26 @@ def test_pow_half(self): ValueRangeAnalysis.pow(ValueRanges.unknown(), ValueRanges.wrap(0.5)) @parametrize("fn", BINARY_OPS) - @parametrize("dtype_a", ("int", "float")) - @parametrize("dtype_b", ("int", "float")) - def test_binary_ref(self, fn, dtype_a, dtype_b): + @parametrize("dtype", ("int", "float")) + def test_binary_ref(self, fn, dtype): to_dtype = {"int": sympy.Integer, "float": sympy.Float} - dtype_a = to_dtype[dtype_a] - dtype_b = to_dtype[dtype_b] + # Don't test float on int only methods + if dtype == "float" and fn in ["pow_by_natural", "mod"]: + return + dtype = to_dtype[dtype] for a, b in itertools.product(CONSTANTS, repeat=2): if not valid_binary(fn, a, b): continue - a = dtype_a(a) - b = dtype_b(b) + a = dtype(a) + b = dtype(b) with self.subTest(a=a, b=b): r = getattr(ValueRangeAnalysis, fn)(a, b) if r == ValueRanges.unknown(): continue ref_r = getattr(ReferenceAnalysis, fn)(a, b) - # sympy.floordiv does 1.0 // 1.0 == 1 rather than 1.0. wtf - if fn != "floordiv": - self.assertEqual(r.lower.is_integer, r.upper.is_integer) - self.assertEqual(ref_r.is_integer, r.upper.is_integer) + self.assertEqual(r.lower.is_integer, r.upper.is_integer) + self.assertEqual(ref_r.is_integer, r.upper.is_integer) self.assertEqual(r.lower, r.upper) self.assertEqual(ref_r, r.lower) @@ -200,7 +212,8 @@ def test_binary_bool_ref_range(self, fn): @parametrize("fn", UNARY_OPS) def test_unary_ref_range(self, fn): - vals = [-sympy.oo, *CONSTANTS, sympy.oo] + # TODO: bring back sympy.oo testing for float unary fns + vals = CONSTANTS for a in generate_range(vals): with self.subTest(a=a): ref_r = getattr(ValueRangeAnalysis, fn)(a) @@ -216,40 +229,26 @@ def test_unary_ref_range(self, fn): # This takes about 4s for all the variants @parametrize("fn", BINARY_OPS + COMPARE_OPS) def test_binary_ref_range(self, fn): - vals = [-sympy.oo, *LESS_CONSTANTS, sympy.oo] + # TODO: bring back sympy.oo testing for float unary fns + vals = LESS_CONSTANTS for a, b in itertools.product(generate_range(vals), repeat=2): # don't attempt pow on exponents that are too large (but oo is OK) if fn == "pow" and b.upper > 4 and b.upper != sympy.oo: continue with self.subTest(a=a, b=b): - ref_r = getattr(ValueRangeAnalysis, fn)(a, b) for a0, b0 in itertools.product(LESS_CONSTANTS, repeat=2): if a0 not in a or b0 not in b: continue if not valid_binary(fn, a0, b0): continue with self.subTest(a0=a0, b0=b0): + ref_r = getattr(ValueRangeAnalysis, fn)(a, b) r = getattr(ReferenceAnalysis, fn)( sympy.Integer(a0), sympy.Integer(b0) ) if r.is_finite: self.assertIn(r, ref_r) - def test_rational_bounds(self): - # Repro from https://github.com/pytorch/pytorch/issues/105097 - from sympy import floor, Eq - shape_0 = sympy.Symbol('shape_0', positive=True, integer=True) - new_expr = ( - Eq(30 * floor(4 * ((shape_0 + 1) // 96) * - ((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646)) / 647 + - 2584 * ((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646)) / 647), - 2880 * floor(((shape_0 + 1) // 96) * - ((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646)) / 15528 + - 323 * ((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646)) / 7764))) - new_range_env = {shape_0: ValueRanges(lower=1, upper=190)} - self.assertTrue(new_expr.subs({shape_0: 95})) - self.assertIn(True, sympy_interp(ValueRangeAnalysis, new_range_env, new_expr)) - class TestSympyInterp(TestCase): @parametrize("fn", UNARY_OPS + BINARY_OPS + UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS) @@ -258,7 +257,13 @@ def test_interp(self, fn): if fn in ("div", "truncdiv", "minimum", "maximum", "mod"): return - from sympy.abc import x, y + is_integer = None + if fn == "pow_by_natural": + is_integer = True + + x = sympy.Dummy('x', integer=is_integer) + y = sympy.Dummy('y', integer=is_integer) + vals = CONSTANTS if fn in {*UNARY_BOOL_OPS, *BINARY_BOOL_OPS}: vals = [True, False] @@ -300,29 +305,17 @@ def test_python_interp_fx(self, fn): if fn in {*BINARY_OPS, *BINARY_BOOL_OPS, *COMPARE_OPS}: arity = 2 - from sympy.abc import x, y + is_integer = None + if fn == "pow_by_natural": + is_integer = True + + x = sympy.Dummy('x', integer=is_integer) + y = sympy.Dummy('y', integer=is_integer) symbols = [x] if arity == 2: symbols = [x, y] - # Workaround mpf from symbol error - if fn == "minimum": - sympy_expr = sympy.Min(x, y) - elif fn == "maximum": - sympy_expr = sympy.Max(x, y) - else: - sympy_expr = getattr(ReferenceAnalysis, fn)(*symbols) - - if arity == 1: - def trace_f(px): - return sympy_interp(PythonReferenceAnalysis, {x: px}, sympy_expr) - else: - def trace_f(px, py): - return sympy_interp(PythonReferenceAnalysis, {x: px, y: py}, sympy_expr) - - gm = fx.symbolic_trace(trace_f) - for args in itertools.product(vals, repeat=arity): if arity == 1 and not valid_unary(fn, *args): continue @@ -330,11 +323,28 @@ def trace_f(px, py): continue if fn == "truncdiv" and args[1] == 0: continue - elif fn == "pow" and (args[0] == 0 and args[1] <= 0): + elif fn in ("pow", "pow_by_natural") and (args[0] == 0 and args[1] <= 0): continue elif fn == "floordiv" and args[1] == 0: continue with self.subTest(args=args): + # Workaround mpf from symbol error + if fn == "minimum": + sympy_expr = sympy.Min(x, y) + elif fn == "maximum": + sympy_expr = sympy.Max(x, y) + else: + sympy_expr = getattr(ReferenceAnalysis, fn)(*symbols) + + if arity == 1: + def trace_f(px): + return sympy_interp(PythonReferenceAnalysis, {x: px}, sympy_expr) + else: + def trace_f(px, py): + return sympy_interp(PythonReferenceAnalysis, {x: px, y: py}, sympy_expr) + + gm = fx.symbolic_trace(trace_f) + self.assertEqual( sympy_interp(PythonReferenceAnalysis, dict(zip(symbols, args)), sympy_expr), gm(*args) diff --git a/torch/__init__.py b/torch/__init__.py index b07d4ea1c180..1c4d5e45b305 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -317,6 +317,75 @@ def __index__(self): # Magic methods installed by torch.fx.experimental.sym_node + def __round__(self, ndigits=None): + return self + + def __truediv__(self, other): + if isinstance(other, (builtins.float, SymFloat)): + return sym_float(self).__float_truediv__(other) + if not isinstance(other, (builtins.int, SymInt)): + return NotImplemented + return self.__int_truediv__(other) + + def __rtruediv__(self, other): + if isinstance(other, (builtins.float, SymFloat)): + return sym_float(self).__rfloat_truediv__(other) + if not isinstance(other, (builtins.int, SymInt)): + return NotImplemented + return self.__rint_truediv__(other) + + def __floordiv__(self, other): + if isinstance(other, (builtins.float, SymFloat)): + return torch.sym_float(math.floor(sym_float(self) / other)) + if not isinstance(other, (builtins.int, SymInt)): + return NotImplemented + return self.__int_floordiv__(other) + + def __rfloordiv__(self, other): + if isinstance(other, (builtins.float, SymFloat)): + return torch.sym_float(math.floor(other / sym_float(self))) + if not isinstance(other, (builtins.int, SymInt)): + return NotImplemented + return self.__rint_floordiv__(other) + + # nb: complex is impossible to handle correctly lol, with + # negative base and integral float need to diverge semantics and + # just always return complex. Neener neener pretend this problem + # doesn't exist + def __pow__(self, other): + if isinstance(other, (builtins.float, SymFloat)): + return sym_float(self).__pow__(other) + if not isinstance(other, (builtins.int, SymInt)): + return NotImplemented + # Guards! This guard is necessary because we need to know it to + # determine the output type of this operation + if other >= 0: + return self.__pow_by_natural__(other) + else: + # Mercifully, when the exponent is negative, Python just promotes + # to doubles and does a float pow: + # + # if (Py_SIZE(b) < 0 && c == NULL) { + # /* if exponent is negative and there's no modulus: + # return a float. This works because we know + # that this calls float_pow() which converts its + # arguments to double. */ + # Py_DECREF(a); + # Py_DECREF(b); + # return PyFloat_Type.tp_as_number->nb_power(v, w, x); + # } + return sym_float(self).__pow__(sym_float(other)) + + def __rpow__(self, other): + if isinstance(other, (builtins.float, SymFloat)): + return sym_float(self).__rpow__(other) + if not isinstance(other, (builtins.int, SymInt)): + return NotImplemented + if self >= 0: # self is exponent + return self.__rpow_by_natural__(other) + else: + return sym_float(self).__rpow__(sym_float(other)) + def __eq__(self, other: object) -> builtins.bool: raise AssertionError("type stub not overridden") @@ -338,6 +407,24 @@ def __add__(self, other) -> "SymInt": def __mul__(self, other) -> "SymInt": raise AssertionError("type stub not overridden") + def __pow_by_natural__(self, other) -> "SymInt": + raise AssertionError("type stub not overridden") + + def __rpow_by_natural__(self, other) -> "SymInt": + raise AssertionError("type stub not overridden") + + def __int_truediv__(self, other) -> "SymFloat": + raise AssertionError("type stub not overridden") + + def __rint_truediv__(self, other) -> "SymFloat": + raise AssertionError("type stub not overridden") + + def __int_floordiv__(self, other) -> "SymFloat": + raise AssertionError("type stub not overridden") + + def __rint_floordiv__(self, other) -> "SymFloat": + raise AssertionError("type stub not overridden") + def __sym_max__(self, other): raise AssertionError("type stub not overridden") @@ -372,9 +459,43 @@ def __init__(self, node): # class has a field named node that stores SymNode self.node = node + def __truediv__(self, other): + if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): + return NotImplemented + return self.__float_truediv__(sym_float(other)) + + def __rtruediv__(self, other): + if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): + return NotImplemented + return self.__rfloat_truediv__(sym_float(other)) + + def __floordiv__(self, other): + if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): + return NotImplemented + return torch.sym_float(math.floor(self / sym_float(other))) + + def __rfloordiv__(self, other): + if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): + return NotImplemented + return torch.sym_float(math.floor(sym_float(other) / self)) + def __bool__(self): return self.node.bool_() + # Symbolic power does NOT work with negative base, this is to avoid + # potential complex outputs + def __pow__(self, other): + if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): + return NotImplemented + torch._check(self >= 0) + return self.__float_pow__(other) + + def __rpow__(self, other): + if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): + return NotImplemented + torch._check(other >= 0) + return self.__rfloat_pow__(other) + # Magic methods installed by torch.fx.experimental.sym_node def __eq__(self, other: object) -> builtins.bool: @@ -392,6 +513,18 @@ def __le__(self, other) -> builtins.bool: def __ge__(self, other) -> builtins.bool: raise AssertionError("type stub not overridden") + def __float_pow__(self, other) -> "SymFloat": + raise AssertionError("type stub not overridden") + + def __rfloat_pow__(self, other) -> "SymFloat": + raise AssertionError("type stub not overridden") + + def __float_truediv__(self, other) -> "SymFloat": + raise AssertionError("type stub not overridden") + + def __rfloat_truediv__(self, other) -> "SymFloat": + raise AssertionError("type stub not overridden") + def __trunc__(self): raise AssertionError("type stub not overridden") @@ -525,7 +658,12 @@ def sym_int(a): return py_int(a) # type: ignore[operator] def sym_max(a, b): - """ SymInt-aware utility for max().""" + """ + SymInt-aware utility for max which avoids branching on a < b. + Unlike builtins.max(), this only works for int/float, and it always + promotes to float if any argument is float (unlike builtins.max, which + will faithfully preserve the type of the input argument). + """ from .overrides import has_torch_function, handle_torch_function if has_torch_function((a, b)): @@ -533,14 +671,19 @@ def sym_max(a, b): if isinstance(a, (SymInt, SymFloat)): return a.__sym_max__(b) elif isinstance(b, (SymInt, SymFloat)): - # NB: If you actually care about preserving output type exactly - # if you do something like max(0, 0.0), it is NOT sound to treat - # min/max as commutative + # Due to promotion semantics, this is operator is commutative: + # max(1, 1.0) === max(1.0, 1) === 1.0 return b.__sym_max__(a) - return builtins.max(a, b) # type: ignore[operator] + # TODO: Probably can make bool work too, just lazy + assert isinstance(a, (builtins.int, builtins.float)), type(a) + assert isinstance(b, (builtins.int, builtins.float)), type(b) + if isinstance(a, builtins.float) or isinstance(b, builtins.float): + return builtins.float(builtins.max(a, b)) + else: + return builtins.max(a, b) def sym_min(a, b): - """ SymInt-aware utility for max().""" + """ SymInt-aware utility for min().""" from .overrides import has_torch_function, handle_torch_function if has_torch_function((a, b)): @@ -549,7 +692,12 @@ def sym_min(a, b): return a.__sym_min__(b) elif isinstance(b, (SymInt, SymFloat)): return b.__sym_min__(a) - return builtins.min(a, b) # type: ignore[operator] + assert isinstance(a, (builtins.int, builtins.float)), type(a) + assert isinstance(b, (builtins.int, builtins.float)), type(b) + if isinstance(a, builtins.float) or isinstance(b, builtins.float): + return builtins.float(builtins.min(a, b)) + else: + return builtins.min(a, b) # Drop in replacement for math.sqrt, math.sin, math.cos etc def _get_sym_math_fn(name): diff --git a/torch/_export/pass_base.py b/torch/_export/pass_base.py index 2200193e78a5..840fc663f3ea 100644 --- a/torch/_export/pass_base.py +++ b/torch/_export/pass_base.py @@ -33,6 +33,7 @@ _TORCH_SYM_OPS: Set[Callable] = { torch.sym_int, + torch.sym_float, torch.sym_ite, torch.sym_max, torch.sym_min, diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index 51dbc435deaf..f8fdc1011b52 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -171,6 +171,7 @@ def _reverse_map(d: Dict[Any, Enum]): operator.floordiv, operator.mod, torch.sym_int, + torch.sym_float, torch.sym_ite, torch.sym_max, torch.sym_min, @@ -1475,10 +1476,15 @@ def deserialize_sym_int(self, s: SymInt) -> Union[int, torch.SymInt]: # Here we force symbols corresponding to SymInts to be at least integers. # Otherwise some expressions that the shape env would otherwise evaluate to False, # e.g., 2*s = 9, can have rational solutions, e.g., 9/2. + # TODO: This is HIGHLY SUSPICIOUS ezyang(May 2024) sym = sym.subs( {s: sympy.Symbol(s.name, integer=True) for s in sym.free_symbols} ) - if isinstance(sym, sympy.Symbol): + # We need to check if the symbol has already been allocated, + # self.symbol_name_to_symbol is not enough because the + # integer-ification of symbols can induce simplification; + # e.g., (2**s0 + 1) // 2 --> s0 when we know s0 is integral + if isinstance(sym, sympy.Symbol) and sym not in self.shape_env.var_to_val: self.symbol_name_to_symbol[val.expr_str] = sym if hint is not None: self.shape_env.add_var_to_val(sym, hint) @@ -1497,7 +1503,7 @@ def deserialize_sym_int(self, s: SymInt) -> Union[int, torch.SymInt]: free_symbols = sym.free_symbols for s in free_symbols: if s.name not in self.symbol_name_to_symbol: - self.symbol_name_to_symbol[s.name] = s + self.symbol_name_to_symbol[s.name] = s # type: ignore[assignment] if vr := self.symbol_name_to_range.get(s.name): self.shape_env.constrain_symbol_range( s, diff --git a/torch/_export/verifier.py b/torch/_export/verifier.py index 07b5ca097400..8ee7c8926834 100644 --- a/torch/_export/verifier.py +++ b/torch/_export/verifier.py @@ -176,6 +176,7 @@ def _allowed_op_types() -> Tuple[Type[Any], ...]: _allowed_torch_functions = ( torch.autograd.grad_mode.set_grad_enabled, torch.sym_int, + torch.sym_float, torch.sym_ite, torch.sym_max, torch.sym_min, diff --git a/torch/_inductor/bounds.py b/torch/_inductor/bounds.py index a1412adb505d..8c62ef2ba3c9 100644 --- a/torch/_inductor/bounds.py +++ b/torch/_inductor/bounds.py @@ -1,4 +1,5 @@ # mypy: allow-untyped-defs +import logging import operator from functools import partial from typing import Any, Callable, Dict @@ -12,6 +13,9 @@ from .virtualized import V +log = logging.getLogger(__name__) + + class BoundVars: """ Performs Value Range Analysis on LoopBody's fx graph by calling BoundVars.run() @@ -56,6 +60,7 @@ def get_bounds(self) -> Dict[torch.fx.Node, ValueRanges[Expr]]: with V.set_ops_handler(ValueRangeAnalysis()): interpreter = InterpreterShim(self.loop_body.root_block.graph, submodules) + log.debug("get_bounds:\n%s", self.loop_body.root_block.graph) interpreter.run(V.get_ops_handler(), initial_env=self._bounds) return self._bounds diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 29d6db791672..8ca6dc2b9153 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -341,6 +341,8 @@ def propagate_scheduler_node(cls, node): DataTypePropagation.propagate_loopbody(node._body) +# This printer contains rules that are supposed to be generic for both C/C++ and +# Python class ExprPrinter(Printer): @staticmethod def paren(string): @@ -370,12 +372,6 @@ def all_in_parens(string): return string return f"({string})" - def _print_Infinity(self, expr): - return "math.inf" - - def _print_NegativeInfinity(self, expr): - return "-math.inf" - def _print_Relational(self, expr): return f" {expr.rel_op} ".join(map(self.paren, map(self._print, expr.args))) @@ -385,11 +381,14 @@ def _print_Mul(self, expr): def _print_Add(self, expr): return " + ".join(map(self.paren, map(self._print, expr.args))) + # NB: this is OK to put here, because Mod is only defined for positive + # numbers, and so across C/Python its behavior is consistent def _print_Mod(self, expr): return " % ".join(map(self.paren, map(self._print, expr.args))) - def _print_FloorDiv(self, expr): - raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}") + def _print_FloatTrueDiv(self, expr): + lhs, rhs = expr.args + return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}" def _print_CleanDiv(self, expr): return self._print_FloorDiv(expr) @@ -400,10 +399,84 @@ def _print_GreaterThan(self, expr): # Go figure... return " >= ".join(map(self.paren, map(self._print, expr.args))) + # NB: The C implementation is injected into codegen at + # torch/_inductor/codegen/wrapper.py def _print_align(self, expr): assert len(expr.args) == 1 return f"align({self._print(expr.args[0])})" + # This must be implemented because sympy will collect x * x into Pow(x, 2), without + # any explicit intervention. We print it just like x * x, notably, we + # never generate sympy.Pow with floats. + # + # NB: this pow by natural, you should never have used builtin sympy.pow + # for FloatPow, and a symbolic exponent should be PowByNatural. These + # means exp is guaranteed to be integer. + def _print_Pow(self, expr): + base, exp = expr.args + base = self._print(base) + assert exp == int(exp), exp + exp = int(exp) + assert exp >= 0 + if exp > 0: + return "*".join([self.paren(base)] * exp) + else: # exp == 0 + return "1" + + # Explicit NotImplemented functions are to prevent default sympy printing + # behavior, which will just barf out ToFloat(...) to your IR. The error + # message is better here because it tells you which printer class it needs + # to go in. + + def _print_ToFloat(self, expr): + raise NotImplementedError(f"_print_ToFloat not implemented for {type(self)}") + + def _print_Infinity(self, expr): + raise NotImplementedError(f"_print_Infinity not implemented for {type(self)}") + + def _print_NegativeInfinity(self, expr): + raise NotImplementedError( + f"_print_NegativeInfinity not implemented for {type(self)}" + ) + + def _print_FloorDiv(self, expr): + raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}") + + def _print_PythonMod(self, expr): + raise NotImplementedError(f"_print_PythonMod not implemented for {type(self)}") + + def _print_IntTrueDiv(self, expr): + raise NotImplementedError(f"_print_IntTrueDiv not implemented for {type(self)}") + + def _print_PowByNatural(self, expr): + raise NotImplementedError( + f"_print_PowByNatural not implemented for {type(self)}" + ) + + def _print_FloatPow(self, expr): + raise NotImplementedError(f"_print_FloatPow not implemented for {type(self)}") + + def _print_TruncToInt(self, expr): + raise NotImplementedError(f"_print_TruncToInt not implemented for {type(self)}") + + def _print_RoundToInt(self, expr): + raise NotImplementedError(f"_print_RoundToInt not implemented for {type(self)}") + + def _print_RoundDecimal(self, expr): + raise NotImplementedError( + f"_print_RoundDecimal not implemented for {type(self)}" + ) + + # NB: Some float operations are INTENTIONALLY not implemented for + # printers. You can implement them as a quick unblock, but it is better + # to ask yourself why we haven't done this computation in the Tensor + # universe instead + + def _print_TruncToFloat(self, expr): + raise NotImplementedError( + f"_print_TruncToFloat not implemented for {type(self)}" + ) + def doprint(self, expr, *, simplify: bool = True): # TODO: why are people passing strings to the printer here :think: if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"): @@ -412,6 +485,10 @@ def doprint(self, expr, *, simplify: bool = True): class PythonPrinter(ExprPrinter): + def _print_ToFloat(self, expr): + assert len(expr.args) == 1 + return f"float({self._print(expr.args[0])})" + def _print_ModularIndexing(self, expr): x, div, mod = expr.args x = self.paren(self.doprint(x)) @@ -421,56 +498,72 @@ def _print_ModularIndexing(self, expr): x = f"({x} // {div})" return f"{x} % {mod}" + def _print_Infinity(self, expr): + return "math.inf" + + def _print_NegativeInfinity(self, expr): + return "-math.inf" + + # WARNING: this is dangerous for Triton, which has C-style modulus + def _print_PythonMod(self, expr): + return " % ".join(map(self.paren, map(self._print, expr.args))) + + # WARNING: this is dangerous for Triton, which has C-style modulus def _print_FloorDiv(self, expr): x, div = expr.args x = self.paren(self.doprint(x)) div = self.paren(self.doprint(div)) return f"({x} // {div})" + # WARNING: this is dangerous for Triton, when lhs, rhs > 2**53, Python + # does a special algorithm + def _print_IntTrueDiv(self, expr): + lhs, rhs = expr.args + return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}" + def _helper_sqrt(self, expr): return f"math.sqrt({self._print(expr)})" def _print_OpaqueUnaryFn_sqrt(self, expr): return self._helper_sqrt(expr.args[0]) - def _print_Pow(self, expr): - # Pow() confuses triton + def _print_FloatPow(self, expr): base, exp = expr.args - # NB: Remember this is sizevar computation! You don't typically - # expect to have to do floating point computation including exponents - # in sizevar compute. Instead of adding support for floating - # point pow, you should make upstream retranslate the Sympy expression - # into Tensor expressions earlier and do that instead. - if exp == 0.5: - return self._helper_sqrt(base) - elif exp == -0.5: - return "1/" + self._helper_sqrt(base) - base = self._print(base) - assert exp == int(exp), exp - exp = int(exp) - if exp > 0: - return "*".join([self.paren(base)] * exp) - elif exp < 0: - return "1/" + self.paren("*".join([self.paren(base)] * abs(exp))) - else: # exp == 0 - return "1" + return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}" + + # TODO: Not sure this works with Triton, even when base/exp are integral + def _print_PowByNatural(self, expr): + base, exp = expr.args + return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}" def _print_floor(self, expr): assert len(expr.args) == 1 return f"math.floor({self._print(expr.args[0])})" - def _print_Trunc(self, expr): + def _print_FloorToInt(self, expr): + assert len(expr.args) == 1 + return f"math.floor({self._print(expr.args[0])})" + + def _print_TruncToInt(self, expr): assert len(expr.args) == 1 + # This also could have been int(), they'll do the same thing for float return f"math.trunc({self._print(expr.args[0])})" def _print_ceiling(self, expr): assert len(expr.args) == 1 return f"math.ceil({self._print(expr.args[0])})" + def _print_CeilToInt(self, expr): + assert len(expr.args) == 1 + return f"math.ceil({self._print(expr.args[0])})" + def _print_Abs(self, expr): assert len(expr.args) == 1 return f"abs({self._print(expr.args[0])})" + # NB: It's expected that we've made explicit any promotion in the sympy + # expression, so it doesn't matter that Python max/min doesn't perform + # promotion def _print_Max(self, expr): assert len(expr.args) >= 2 return f"max({', '.join(map(self._print, expr.args))})" @@ -515,7 +608,7 @@ def _print_OpaqueUnaryFn_atan(self, expr): assert len(expr.args) == 1 return f"math.atan({self._print(expr.args[0])})" - def _print_Round(self, expr): + def _print_RoundToInt(self, expr): assert len(expr.args) == 1 return f"round({self._print(expr.args[0])})" @@ -654,6 +747,29 @@ def remainder(a, b): ) return ops.where(cond, ops.add(r, b), r) + @staticmethod + def trunc_to_int(a, dtype): + return ops.to_dtype(ops.trunc(a), dtype) + + @staticmethod + def floor_to_int(a, dtype): + return ops.to_dtype(ops.floor(a), dtype) + + @staticmethod + def ceil_to_int(a, dtype): + return ops.to_dtype(ops.ceil(a), dtype) + + @staticmethod + def round_to_int(a, dtype): + return ops.to_dtype(ops.round(a), dtype) + + @staticmethod + def int_truediv(a, b): + # TODO: this is wrong + # TODO: an easy bandaid is to generate runtime asserts that it's + # <= 2**53, which is when this equation is correct + return ops.truediv(a, b) + @staticmethod def load_seed(name, offset): return ops.load(name, sympy.Integer(offset)) diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 35e604a35e48..749a6e6d4cab 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -276,11 +276,11 @@ def visit_modular_indexing(divisor, modulus): original_index = index - div = sympy.Wild("divisor") + div = sympy.Wild("divisor", integer=True) if index.has(FloorDiv): index = index.replace(FloorDiv(var, div), visit_indexing_div) - mod = sympy.Wild("modulus") + mod = sympy.Wild("modulus", integer=True) if index.has(ModularIndexing): index = index.replace(ModularIndexing(var, div, mod), visit_modular_indexing) diff --git a/torch/_inductor/codegen/cpp_utils.py b/torch/_inductor/codegen/cpp_utils.py index 438ac908486a..ef7566c8bcba 100644 --- a/torch/_inductor/codegen/cpp_utils.py +++ b/torch/_inductor/codegen/cpp_utils.py @@ -101,11 +101,54 @@ def _print_floor(self, expr): r = f"std::floor({self._print(expr.args[0])})" return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r - def _print_Trunc(self, expr): + def _print_FloorToInt(self, expr): assert len(expr.args) == 1 - r = f"std::trunc({self._print(expr.args[0])})" + r = f"std::floor({self._print(expr.args[0])})" return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + def _print_TruncToInt(self, expr): + assert len(expr.args) == 1 + r = f"std::trunc({self._print(expr.args[0])})" + return f"static_cast<{INDEX_TYPE}>({r})" + + def _print_TruncToFloat(self, expr): + assert len(expr.args) == 1 + return f"std::trunc({self._print(expr.args[0])})" + + def _print_ToFloat(self, expr): + assert len(expr.args) == 1 + return f"static_cast({self._print(expr.args[0])})" + + # TODO: This is wrong if one of the inputs is negative. This is hard to + # tickle though, as the inputs are typically positive (and if we can prove + # they are positive, we will have used Mod instead, for which this codegen + # is right). + def _print_PythonMod(self, expr): + return " % ".join(map(self.paren, map(self._print, expr.args))) + + def _print_CMod(self, expr): + return " % ".join(map(self.paren, map(self._print, expr.args))) + + def _print_IntTrueDiv(self, expr): + lhs, rhs = expr.args + # TODO: This is only accurate up to 2**53 + return f"static_cast({self._print(lhs)}) / static_cast({self._print(rhs)})" + + # TODO: PowByNatural: we need to implement our own int-int pow. Do NOT + # use std::pow, that operates on floats + def _print_PowByNatural(self, expr): + raise NotImplementedError( + f"_print_PowByNatural not implemented for {type(self)}" + ) + + def _print_FloatTrueDiv(self, expr): + lhs, rhs = expr.args + return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}" + + def _print_FloatPow(self, expr): + base, exp = expr.args + return f"std::pow({self._print(base)}, {self._print(exp)})" + def _print_Pow(self, expr): # Uses float constants to perform FP div base, exp = expr.args @@ -140,6 +183,11 @@ def _print_ceiling(self, expr): r = f"std::ceil({self._print(expr.args[0])})" return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + def _print_CeilToInt(self, expr): + assert len(expr.args) == 1 + r = f"std::ceil({self._print(expr.args[0])})" + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + def _print_Min(self, expr): args = [self._print(a) for a in expr.args] if len(args) == 2: @@ -201,8 +249,9 @@ def _print_OpaqueUnaryFn_atan(self, expr): def _print_OpaqueUnaryFn_sqrt(self, expr): return f"std::sqrt({self._print(expr.args[0])})" - def _print_Round(self, expr): + def _print_RoundToInt(self, expr): assert len(expr.args) == 1 + # TODO: dispatch to llrint depending on index type return f"std::lrint({self._print(expr.args[0])})" def _print_RoundDecimal(self, expr): diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index f366533e3b94..9b6184f7e185 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -286,23 +286,68 @@ def triton_reshape(value: str, old_shape: List[str], new_shape: List[str]): return f"{value}[{', '.join(expand)}]" +# NB: Inheriting from PythonPrinter is somewhat dangerous, because there are a +# number of operators which Triton "implements", but in a way that is +# inconsistent with Python semantics (and consistent with C semantics). We +# must override all of these, or it is potential silent correctness problem class TritonPrinter(PythonPrinter): + def _print_TruncToInt(self, expr): + assert len(expr.args) == 1 + return ( + f"libdevice.trunc({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + ) + + def _print_ToFloat(self, expr): + assert len(expr.args) == 1 + return f"{self.paren(self._print(expr.args[0]))}.to(tl.float64)" + + # TODO: This is wrong if one of the inputs is negative. This is hard to + # tickle though, as the inputs are typically positive (and if we can prove + # they are positive, we will have used Mod instead, for which this codegen + # is right). If you are trying to hit this, maybe try something like + # torch.arange(n, device="cuda") - 1 and then do a modulus on it + def _print_PythonMod(self, expr): + return " % ".join(map(self.paren, map(self._print, expr.args))) + + # TODO: This is wrong, see + # https://github.com/triton-lang/triton/issues/955 + # But for Sympy expressions, things will /mostly/ work out because we + # don't usually deal with negative numbers in the division + def _print_FloorDiv(self, expr): + assert expr.is_integer + x, div = expr.args + x = self.paren(self.doprint(x)) + div = self.paren(self.doprint(div)) + return f"({x} // {div})" + + # TODO: This is wrong, when lhs, rhs > 2**53, Python does a higher + # precision algorithm, which we would need to replicate here + def _print_IntTrueDiv(self, expr): + lhs, rhs = expr.args + return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}" + + # NB: sympy.floor/ceiling produce integers, so we have to do the + # conversion to index dtype def _print_floor(self, expr): assert len(expr.args) == 1 return ( f"libdevice.floor({self._print(expr.args[0])}).to({V.kernel.index_dtype})" ) - def _print_Trunc(self, expr): + def _print_FloorToInt(self, expr): assert len(expr.args) == 1 return ( - f"libdevice.trunc({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + f"libdevice.floor({self._print(expr.args[0])}).to({V.kernel.index_dtype})" ) def _print_ceiling(self, expr): assert len(expr.args) == 1 return f"libdevice.ceil({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + def _print_CeilToInt(self, expr): + assert len(expr.args) == 1 + return f"libdevice.ceil({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + def _helper_sqrt(self, expr): return f"libdevice.sqrt({self._print(expr)}.to(tl.float32))" @@ -373,20 +418,9 @@ def _print_OpaqueUnaryFn_atan(self, expr): assert len(expr.args) == 1 return f"libdevice.atan(({self._print(expr.args[0])}).to(tl.float32))" - def _print_FloorDiv(self, expr): - if expr.is_integer: - return super()._print_FloorDiv(expr) - - x, div = expr.args - x = self.paren(self.doprint(x)) - div = self.paren(self.doprint(div)) - return f"libdevice.floor({x} / {div}).to({V.kernel.index_dtype})" - - def _print_Round(self, expr): + def _print_RoundToInt(self, expr): assert len(expr.args) == 1 - return ( - f"libdevice.llrint({self._print(expr.args[0])}).to({V.kernel.index_dtype})" - ) + return f"libdevice.llrint({self._print(expr.args[0])})" def _print_RoundDecimal(self, expr): assert len(expr.args) == 2 diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 4c5ea746f3f5..f2bdf22e2d96 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -1197,8 +1197,11 @@ def debug(msg): elif is_magic_method(n.target): # TODO: this is sus, it probably should be handled in the # lowerings themselves similarly to sym_size/sym-stride + # https://github.com/pytorch/pytorch/issues/127789 debug("is_magic_method") - if isinstance(n.meta["val"], torch.SymInt): + if isinstance( + n.meta["val"], (torch.SymInt, torch.SymFloat, torch.SymBool) + ): result = n.meta["val"].node.expr else: result = super().run_node(n) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 8bf3bb22f93f..7b2cf76e7943 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -45,7 +45,6 @@ is_boolean_dtype, is_float_dtype, make_channels_last_strides_for, - make_contiguous_strides_for, StrideType, ) from torch._subclasses.fake_tensor import get_schema_info @@ -237,7 +236,7 @@ def ir_node_to_tensor(x, guard_shape=True): if is_storage_and_layout(x): stride = [shape_fn(s) for s in x.get_layout().stride] # type: ignore[misc] else: - stride = make_contiguous_strides_for(size) # type: ignore[arg-type] + stride = FlexibleLayout.contiguous_strides(size) # type: ignore[arg-type] dtype = x.get_dtype() device = x.get_device() size = convert_shape_to_symint(size) @@ -2767,6 +2766,7 @@ class FlexibleLayout(Layout): allow_indexing = False + # WARNING! This doesn't handle zero size tensors correctly @staticmethod def contiguous_strides(sizes): if len(sizes) == 0: @@ -5916,7 +5916,7 @@ def _original_deconv_weight_size( # To align the behavior of the Conv kernel, we set the output_stride in such case to be contiguous instead of channels last. dynamic_shapes = not all(isinstance(i, int) for i in (output_size)) if dynamic_shapes and is_contiguous_storage_and_layout(x): - output_stride = make_contiguous_strides_for(output_size) + output_stride = FlexibleLayout.contiguous_strides(output_size) else: output_stride = make_channels_last_strides_for(output_size) @@ -5968,7 +5968,7 @@ def _prepare_linear_fusion_create( assert x.get_device().type == "cpu" and weight.get_device().type == "cpu" inputs = [x, weight] - output_stride = make_contiguous_strides_for(output_size) + output_stride = FlexibleLayout.contiguous_strides(output_size) kernel_layout = FixedLayout( x.get_device(), x.get_dtype(), @@ -6284,7 +6284,7 @@ def create(cls, x, packed_w, orig_w, B, batch_size): *m, _ = x.get_size() oc, _ = orig_w.get_size() output_size = list(m) + [oc] - output_stride = make_contiguous_strides_for(output_size) + output_stride = FlexibleLayout.contiguous_strides(output_size) inputs = [x, packed_w, orig_w] constant_args = [batch_size] if B is not None: @@ -6602,13 +6602,13 @@ def create( def get_strides_of_lstm_output(output_shape, batch_first): assert len(output_shape) == 3, "Expect output_shape to be 3D" - return make_contiguous_strides_for(output_shape) + return FlexibleLayout.contiguous_strides(output_shape) output_sizes = [output_shape, hy_shape, cy_shape] output_strides = [ get_strides_of_lstm_output(output_shape, batch_first), - make_contiguous_strides_for(hy_shape), - make_contiguous_strides_for(cy_shape), + FlexibleLayout.contiguous_strides(hy_shape), + FlexibleLayout.contiguous_strides(cy_shape), ] output_ir = [ MultiOutput( diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index b2c5bb271501..932bcd50b920 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -6,7 +6,6 @@ from typing import Any, List, Tuple import torch -from torch._prims_common import make_contiguous_strides_for from .. import config from ..ir import ( ComputedBuffer, @@ -390,7 +389,7 @@ def flex_attention(*args, **kwargs): query.get_device(), query.get_dtype(), query.get_size(), - make_contiguous_strides_for(query.get_size()), + FlexibleLayout.contiguous_strides(query.get_size()), ) # see NOTE:[TritonTemplates with multiple outputs] logsumexp_shape = query.get_size()[:-1] # [B, H, M] @@ -746,7 +745,7 @@ def flex_attention_backward(*args, **kwargs): key.get_device(), key.get_dtype(), key.get_size(), - make_contiguous_strides_for(key.get_size()), + FlexibleLayout.contiguous_strides(key.get_size()), ) # Create delta which will is needed for the bwd's kernel diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 0519211c01aa..0461cc3683d5 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -35,7 +35,7 @@ Number, ) from torch.fx.experimental.sym_node import magic_methods, method_to_operator -from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing +from torch.utils._sympy.functions import CeilDiv, FloorDiv, IntTrueDiv, ModularIndexing from .._dynamo.utils import import_submodule from . import config, inductor_prims, ir, test_operators # NOQA: F401 @@ -4263,7 +4263,7 @@ def _fractional_pooling_offsets(samples, in_sz, out_sz, kernel_sz, dim): out_sz = out_sz[dim] in_sz = in_sz[dim] kernel_sz = kernel_sz[dim] - alpha = (in_sz - kernel_sz) / (out_sz - 1) + alpha = IntTrueDiv(in_sz - kernel_sz, out_sz - 1) samples_loader = samples.make_loader() def load(prefix, i): @@ -4373,7 +4373,7 @@ def upsample_nearest2d_backward( w_kernel_max = ceildiv(inp_w, out_w) def start_index(index, out_dim, inp_dim): - return CeilDiv(index * inp_dim, out_dim) + return CeilDiv(index * inp_dim, sympy.sympify(out_dim)) def end_index(index, out_dim, inp_dim): return start_index((index + 1), out_dim, inp_dim) diff --git a/torch/_inductor/ops_handler.py b/torch/_inductor/ops_handler.py index 20d652019372..1f0a0bc1a6b3 100644 --- a/torch/_inductor/ops_handler.py +++ b/torch/_inductor/ops_handler.py @@ -139,6 +139,38 @@ def to_dtype( """ ... + def trunc_to_int(self, x: T, dtype: torch.dtype) -> T: + """ + Convert x to dtype with truncation semantics (similar to how the int + constructor works in Python). In Inductor codegen, this just decays + to trunc and then to_dtype, but this composite operation helps + roundtrips for Sympy evaluation. + + dtype is taken as an explicit parameter because the desired output + dtype is typically the index dtype, which may vary between int32 and + int64 depending on if we've shown that all the indexing operations can + be done in int32. + """ + ... + + def ceil_to_int(self, x: T, dtype: torch.dtype) -> T: + """ + Convert x to dtype with ceiling semantics. See also trunc_to_int. + """ + ... + + def floor_to_int(self, x: T, dtype: torch.dtype) -> T: + """ + Convert x to dtype with ceiling semantics. See also trunc_to_int. + """ + ... + + def round_to_int(self, x: T, dtype: torch.dtype) -> T: + """ + Convert x to dtype with round-to-even semantics. See also trunc_to_int. + """ + ... + def to_dtype_bitcast(self, x: T, dtype: torch.dtype, src_dtype: torch.dtype) -> T: """ Reinterpret cast x to dtype (reinterpreting the bits in memory as another dtype.) @@ -399,21 +431,23 @@ def isinf(self, x0: T) -> T: def isnan(self, x0: T) -> T: ... + # NB: this returns a float, like the torch operation + # This rounds half to even to break ties def round(self, x0: T) -> T: ... + # NB: this returns a float, like the torch operation def floor(self, x0: T) -> T: ... def sign(self, x0: T) -> T: ... - def to_int(self, x0: T) -> T: - ... - + # NB: this returns a float, like the torch operation def trunc(self, x0: T) -> T: ... + # NB: this returns a float, like the torch operation def ceil(self, x0: T) -> T: ... @@ -450,6 +484,7 @@ def sub(self, x0: T, x1: T) -> T: def mul(self, x0: T, x1: T) -> T: ... + # NB: this returns a float, like the torch operation def pow(self, x0: T, x1: T) -> T: ... @@ -618,14 +653,21 @@ def truncdiv(self, x0: T, x1: T) -> T: def floordiv(self, x0: T, x1: T) -> T: """Python-style floor division between integers only. Computes the - true division of two numbers and floors the result. + true division of two numbers and floors the result. If you want + floor division for floats, do regular truediv and floor the result. """ ... def truediv(self, x0: T, x1: T) -> T: - """True division between floats. Integer inputs are NOT valid: to do - Python style (int, int) -> float division, promote the inputs to float - first.""" + """True division between floats. Integer inputs are NOT valid. To + do Python-style (int, int) -> float division, use int_truediv""" + ... + + def int_truediv(self, x0: T, x1: T) -> T: + """True division between integers. This is NOT the same as promoting + to float and doing integer division, there is a bespoke algorithm for + doing the division in higher precision than the above. + """ ... def div(self, x0: T, x1: T) -> T: @@ -641,6 +683,10 @@ def remainder(self, x0: T, x1: T) -> T: """Python-style modulus, take sign from RHS (x1).""" ... + def round_decimal(self, x0: T, x1: T) -> T: + """Python-style round with decimal argument""" + ... + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # In CUDA, optimized implementations of other mathematical operations are # offered separately via libdevice for double precision computation (in diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 4f9d12b13c64..467af6f57812 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -387,7 +387,7 @@ def store_output( assert isinstance(mask, (str, type(None))) assert self.template_mask is None indices = list(map(TritonPrinter.paren, indices)) - index_symbols = [sympy.Symbol(x) for x in indices] + index_symbols = [sympy.Symbol(x, integer=True) for x in indices] lengths = [ V.graph.sizevars.simplify(s) for s in self.output_node.get_size() ] @@ -411,7 +411,7 @@ def store_output( output_index = self.output_node.get_layout().make_indexer()(index_symbols) output_index = self.rename_indexing(output_index) if output_index == contiguous_index: - output_index = sympy.Symbol("xindex") + output_index = sympy.Symbol("xindex", integer=True) epilogue_args = [val] for input_node in itertools.chain( diff --git a/torch/_inductor/sizevars.py b/torch/_inductor/sizevars.py index 1a863bc5485d..8ea55b024c59 100644 --- a/torch/_inductor/sizevars.py +++ b/torch/_inductor/sizevars.py @@ -162,9 +162,9 @@ def visit_modular_indexing(base, divisor, modulus): if expr.has(ModularIndexing): expr = expr.replace( ModularIndexing( - sympy.Wild("base"), - sympy.Wild("divisor"), - sympy.Wild("modulus"), + sympy.Wild("base", integer=True), + sympy.Wild("divisor", integer=True), + sympy.Wild("modulus", integer=True), ), visit_modular_indexing, ) @@ -172,8 +172,8 @@ def visit_modular_indexing(base, divisor, modulus): if expr.has(FloorDiv): expr = expr.replace( FloorDiv( - sympy.Wild("base"), - sympy.Wild("divisor"), + sympy.Wild("base", integer=True), + sympy.Wild("divisor", integer=True), ), visit_indexing_div, ) @@ -736,11 +736,11 @@ def _join_dimensions_cached(expr: Expr) -> Expr: """ assert isinstance(expr, sympy.Add) - scale = sympy.Wild("scale", exclude=[0]) - base = sympy.Wild("base") - divisor = sympy.Wild("divisor") - mod1 = sympy.Wild("modulus") - mod2 = sympy.Wild("modulus2") + scale = sympy.Wild("scale", exclude=[0], integer=True) + base = sympy.Wild("base", integer=True) + divisor = sympy.Wild("divisor", integer=True) + mod1 = sympy.Wild("modulus", integer=True) + mod2 = sympy.Wild("modulus2", integer=True) for term1 in expr.args: m1 = term1.match(scale * ModularIndexing(base, divisor, mod1)) if m1: diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 9a83b3d10d40..c77576451264 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -222,7 +222,7 @@ def ceildiv( numer: Union[int, sympy.Expr], denom: Union[int, sympy.Expr] ) -> Union[int, sympy.Expr]: if isinstance(numer, sympy.Expr) or isinstance(denom, sympy.Expr): - return CeilDiv(numer, denom) + return CeilDiv(sympy.sympify(numer), sympy.sympify(denom)) # TODO: There is a bug in a call to this function, to repro: # python benchmarks/dynamo/huggingface.py --inductor -d cuda --accuracy # --amp --only YituTechConvBert --dynamic-shapes diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index f9075c603f11..c5a549860f47 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -1728,7 +1728,7 @@ def go(t, real_t): for run_impl_check, op_impl in op_implementations_checks: if run_impl_check(func): op_impl_out = op_impl(self, func, *args, **kwargs) - if op_impl_out != NotImplemented: + if op_impl_out is not NotImplemented: return maybe_propagate_real_tensors(op_impl_out) def maybe_run_unsafe_fallback(error=None): diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index a7ce337f9ac8..2a3cb62c56d7 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -1200,8 +1200,13 @@ void initJITBindings(PyObject* module) { SYMNODE_BINARY(sub) SYMNODE_BINARY(mul) SYMNODE_BINARY(truediv) + SYMNODE_BINARY(int_truediv) + SYMNODE_BINARY(float_truediv) SYMNODE_BINARY(pow) + SYMNODE_BINARY(float_pow) + SYMNODE_BINARY(pow_by_natural) SYMNODE_BINARY(floordiv) + SYMNODE_BINARY(int_floordiv) SYMNODE_BINARY(mod) SYMNODE_BINARY(eq) SYMNODE_BINARY(ne) diff --git a/torch/csrc/utils/python_symnode.h b/torch/csrc/utils/python_symnode.h index f8c710cf6579..15738b1a67e1 100644 --- a/torch/csrc/utils/python_symnode.h +++ b/torch/csrc/utils/python_symnode.h @@ -198,14 +198,34 @@ class PythonSymNodeImpl : public c10::SymNodeImpl { return dispatch_common_(__func__, other); } + c10::SymNode float_truediv(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + + c10::SymNode int_truediv(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + c10::SymNode pow(const c10::SymNode& other) override { return dispatch_common_(__func__, other); } + c10::SymNode float_pow(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + + c10::SymNode pow_by_natural(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + c10::SymNode floordiv(const c10::SymNode& other) override { return dispatch_common_(__func__, other); } + c10::SymNode int_floordiv(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + c10::SymNode mod(const c10::SymNode& other) override { return dispatch_common_(__func__, other); } diff --git a/torch/export/dynamic_shapes.py b/torch/export/dynamic_shapes.py index e351df8d622c..a5ce066faa47 100644 --- a/torch/export/dynamic_shapes.py +++ b/torch/export/dynamic_shapes.py @@ -2,7 +2,6 @@ import builtins import dataclasses import inspect -import math import sys import weakref from collections import defaultdict @@ -267,11 +266,14 @@ class _Constraint(_ConstraintTarget, metaclass=_ConstraintFactory): shared: Optional[_ConstraintTarget] = None debug_name: Optional[str] = None - def _clone_with_range(self, lower=0, upper=math.inf): + def _clone_with_range(self, lower=0, upper=None): # Import sympy locally from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint from torch.utils._sympy.value_ranges import ValueRanges + if upper is None: + upper = sys.maxsize - 1 + constraint_range = StrictMinMaxConstraint( vr=self.constraint_range.vr & ValueRanges(lower=lower, upper=upper), warn_only=False, @@ -499,7 +501,6 @@ def dynamic_dim(t: torch.Tensor, index: int, debug_name: Optional[str] = None): ) # Import sympy locally - import sympy from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint from torch.utils._sympy.value_ranges import ValueRanges @@ -509,7 +510,7 @@ def dynamic_dim(t: torch.Tensor, index: int, debug_name: Optional[str] = None): id(t), index, StrictMinMaxConstraint( - vr=ValueRanges(lower=0, upper=sympy.oo), warn_only=False + vr=ValueRanges(lower=0, upper=sys.maxsize - 1), warn_only=False ), debug_name=debug_name, ) diff --git a/torch/fx/experimental/recording.py b/torch/fx/experimental/recording.py index 3eeb7ad02602..1c384d9dfbeb 100644 --- a/torch/fx/experimental/recording.py +++ b/torch/fx/experimental/recording.py @@ -278,7 +278,13 @@ def wrapper(*args, **kwargs): raise except Exception: - log.error("failed while running %s(*%s, **%s)", name, args[1:], kwargs) + log.error( # noqa: G201 + "failed while running %s(*%s, **%s)", + name, + args[1:], + kwargs, + exc_info=log.isEnabledFor(logging.INFO), + ) raise return wrapper diff --git a/torch/fx/experimental/sym_node.py b/torch/fx/experimental/sym_node.py index 559c3f8ed4cd..8f270c56e6c1 100644 --- a/torch/fx/experimental/sym_node.py +++ b/torch/fx/experimental/sym_node.py @@ -268,8 +268,11 @@ def mul(self, other) -> "SymNode": def mod(self, other) -> "SymNode": return self._mod(other) # type: ignore[attr-defined] - def pow(self, other) -> "SymNode": - return self._pow(other) # type: ignore[attr-defined] + def float_pow(self, other) -> "SymNode": + return self._float_pow(other) # type: ignore[attr-defined] + + def pow_by_natural(self, other) -> "SymNode": + return self._pow_by_natural(other) # type: ignore[attr-defined] def and_(self, other) -> "SymNode": return self._and_(other) # type: ignore[attr-defined] @@ -277,11 +280,14 @@ def and_(self, other) -> "SymNode": def or_(self, other) -> "SymNode": return self._or_(other) # type: ignore[attr-defined] - def truediv(self, other) -> "SymNode": - return self._truediv(other) # type: ignore[attr-defined] + def float_truediv(self, other) -> "SymNode": + return self._float_truediv(other) # type: ignore[attr-defined] - def floordiv(self, other) -> "SymNode": - return self._floordiv(other) # type: ignore[attr-defined] + def int_truediv(self, other) -> "SymNode": + return self._int_truediv(other) # type: ignore[attr-defined] + + def int_floordiv(self, other) -> "SymNode": + return self._int_floordiv(other) # type: ignore[attr-defined] def lshift(self, other) -> "SymNode": return self._lshift(other) # type: ignore[attr-defined] @@ -362,6 +368,17 @@ def sym_or(self, other): def sym_and(self, other): return self.and_(other) + # There is no int_truediv available from C++ + def truediv(self, other): + return self.float_truediv(other) + + def floordiv(self, other) -> "SymNode": + return self.int_floordiv(other) + + # We didn't bind integer pow in C++ + def pow(self, other): + return self.float_pow(other) + def is_non_overlapping_and_dense(self, sizes, strides): return self.is_non_overlapping_and_dense_indicator(sizes, strides).eq(to_node(self, 1)) # type: ignore[attr-defined] @@ -478,7 +495,7 @@ def is_constant(self): "eq": operator.eq, "floor": math.floor, "trunc": math.trunc, - "floordiv": operator.floordiv, + "int_floordiv": operator.floordiv, "ge": operator.ge, "gt": operator.gt, "is_integer": lambda x: x.is_integer(), @@ -490,7 +507,8 @@ def is_constant(self): "ne": operator.ne, "neg": operator.neg, "or": operator.or_, - "pow": operator.pow, + "float_pow": operator.pow, + "pow_by_natural": operator.pow, "round": builtins.round, "rshift": operator.rshift, "sub": operator.sub, @@ -499,12 +517,14 @@ def is_constant(self): "sym_max": sym_max, "sym_min": sym_min, "sym_not": sym_not, - "truediv": operator.truediv, + "float_truediv": operator.truediv, + "int_truediv": operator.truediv, } unary_magic_methods = { "abs", "sym_float", + "sym_int", "ceil", "floor", "neg", @@ -560,20 +580,20 @@ def fn(self): bool_magic_methods = only_bool_magic_methods | also_bool_magic_methods # Methods that are only for float -only_float_magic_methods = {"is_integer"} +only_float_magic_methods = {"is_integer", "round", "sym_int"} magic_methods_on_operator_with_trailing_underscore = {"and", "or"} -always_float_magic_methods = {"truediv", "sym_float", "pow"} +always_float_magic_methods = {"int_truediv", "float_truediv", "sym_float", "float_pow"} for name in math_op_names: sym_name = f"sym_{name}" always_float_magic_methods.add(sym_name) -always_int_magic_methods = {"ceil", "floor", "trunc"} +always_int_magic_methods = {"ceil", "floor", "trunc", "pow_by_natural"} always_bool_magic_methods = { "eq", "ne", @@ -591,10 +611,16 @@ def fn(self): # Methods that have a `__foo__` as well as `__rfoo__` -def _sympy_truediv(a, b): - from torch.utils._sympy.functions import TrueDiv +def _sympy_float_truediv(a, b): + from torch.utils._sympy.functions import FloatTrueDiv - return TrueDiv(a, b) + return FloatTrueDiv(a, b) + + +def _sympy_int_truediv(a, b): + from torch.utils._sympy.functions import IntTrueDiv + + return IntTrueDiv(a, b) def _sympy_floordiv(a, b): @@ -604,15 +630,24 @@ def _sympy_floordiv(a, b): def _sympy_mod(a, b): - from torch.utils._sympy.functions import Mod + from torch.utils._sympy.functions import Mod, PythonMod + + if a.is_nonnegative and b.is_nonnegative: + return Mod(a, b) + else: + return PythonMod(a, b) + - return Mod(a, b) +def _sympy_pow_by_natural(a, b): + from torch.utils._sympy.functions import PowByNatural + return PowByNatural(a, b) -def _sympy_pow(a, b): - from torch.utils._sympy.functions import Pow - return Pow(a, b) +def _sympy_float_pow(a, b): + from torch.utils._sympy.functions import FloatPow + + return FloatPow(a, b) def _sympy_and(a, b): @@ -644,11 +679,13 @@ def _sympy_rshift(a, b): "sub": operator.sub, "mul": operator.mul, "mod": _sympy_mod, - "pow": _sympy_pow, + "pow_by_natural": _sympy_pow_by_natural, + "float_pow": _sympy_float_pow, "and": _sympy_and, "or": _sympy_or, - "truediv": _sympy_truediv, - "floordiv": _sympy_floordiv, + "float_truediv": _sympy_float_truediv, + "int_truediv": _sympy_int_truediv, + "int_floordiv": _sympy_floordiv, "lshift": _sympy_lshift, "rshift": _sympy_rshift, } @@ -673,21 +710,23 @@ def _floor_ceil_helper(a, fn): def _sympy_floor(a): - import sympy + from torch.utils._sympy.functions import FloorToInt - return _floor_ceil_helper(a, sympy.floor) + return FloorToInt(a) +# NB: this is Python trunc semantics which returns an int. Do NOT use this to +# represent torch.trunc (which is float to float) def _sympy_trunc(a): - from torch.utils._sympy.functions import Trunc + from torch.utils._sympy.functions import TruncToInt - return Trunc(a) + return TruncToInt(a) def _sympy_ceil(a): - import sympy + from torch.utils._sympy.functions import CeilToInt - return _floor_ceil_helper(a, sympy.ceiling) + return CeilToInt(a) def _sympy_eq(a, b): @@ -772,26 +811,28 @@ def _sympy_abs(a): def _sympy_round(number, ndigits=None): - from torch.utils._sympy.functions import Round, RoundDecimal + from torch.utils._sympy.functions import RoundDecimal, RoundToInt if ndigits is None: - return Round(number) + return RoundToInt(number) else: return RoundDecimal(number, ndigits) def _sympy_sym_float(a): - # Cannot use sympy.Float(a) here, coz it expects python literals - # Multiply by 1.0 to cast to float. This is needed when the input - # is a SymInt which has the assumption that it is integer and - # SymPy will otherwise assume that return value cannot be a float. - return a * 1.0 + from torch.utils._sympy.functions import ToFloat + + # NB: Cannot use a * 1.0 here, because 0 * 1.0 is 0 which incorrectly + # reports that it is an integer + return ToFloat(a) def _sympy_is_integer(a): import sympy - return sympy.Eq(sympy.floor(a), a) + from torch.utils._sympy.functions import ToFloat + + return sympy.Eq(ToFloat(sympy.floor(a)), a) magic_methods = { @@ -990,9 +1031,26 @@ def binary_magic_impl(self, other): self, handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {}) ) assert isinstance(other, SymNode) - # TODO: consider constant prop here try: - out = func(self.expr, other.expr) + if method == "mod": + from torch.utils._sympy.functions import Mod, PythonMod + + # Special handling for mod that requires access to the value + # ranges + shape_env = self.shape_env + if ( + self.expr.is_nonnegative + or shape_env.bound_sympy(self.expr).lower >= 0 + ) and ( + other.expr.is_nonnegative + or shape_env.bound_sympy(other.expr).lower >= 0 + ): + out = Mod(self.expr, other.expr) + else: + out = PythonMod(self.expr, other.expr) + else: + # TODO: consider constant prop here + out = func(self.expr, other.expr) except Exception: log.warning("failed to eval %s(%s, %s)", method, self.expr, other.expr) raise @@ -1123,9 +1181,13 @@ def round_impl(self, ndigits=None): except Exception: log.warning("failed to eval %s(%s, ndigits=%s)", method, expr, ndigits) raise + out = safe_expand(out) - pytype = int if ndigits is None else self.pytype + if ndigits is None: + pytype = int + else: + pytype = self.pytype out_hint = None if self.hint is not None: @@ -1137,6 +1199,7 @@ def round_impl(self, ndigits=None): # hack down below works, because all round function down the line all take ndigits=None as default in their # signature. # TODO: Remove the args construction below if a different sentinel is used by FX. + # ezyang(May 2024): LOL args = [self.fx_node] if ndigits is not None: args.append(ndigits) @@ -1260,6 +1323,32 @@ def is_constant(x): return x.node.is_constant() return False + # Promotion rules for binary operations. NB: we preserve PYTHON semantics + # - if args are same type, do nothing + # - if one arg is float, promote other arg to float + # - nb: this applies to floordiv, even though output is integral + # (it's still float) + # - pow is funny business + # - if both ints + # - trigger a guard on exponent >= 0 + # - if non-negative, output is int + # - otherwise, output is float + # - otherwise, promote other arg to float + # - nb: complex is impossible to handle correctly lol, with + # negative base and integral float need to diverge semantics and + # just always return complex. Neener neener pretend this problem + # doesn't exist + # - equality is pain: Python does the fancy thing where it unpacks the + # mantissa from the float and then compares that against the int. + # Which means it is able to tell that + # 9007199254740993 != 9007199254740992. (rather than if the LHS was + # promoted to float, in which case it would have truncated to the RHS + # and subsequently been equal). We'll model this exactly by having + # special mixed type equality operations. Unfortunately, we need to + # do this for all comparison operations (maybe I'll only implement + # compare) + # - sym_ite mumble mumble really shouldn't allow mixed but whatever + if method in bool_becomes_int_magic_methods: def promote(x): @@ -1273,6 +1362,41 @@ def promote(x): def promote(x): return x + def promote2(self, other): + # TODO: Remove eq and other relations from this list. + # CPython has fancy implementations for these to get as much precision + # as possible instead of just promoting to float64 and praying, so we + # need to handle them specially too. + # Also, note that int_truediv doesn't go through this path: both + # arguments are "int" so there isn't any promotion + if method not in [ + "add", + "sub", + "mul", + "mod", + "float_pow", + "float_truediv", + "int_floordiv", + "sym_min", + "sym_max", + # TODO: remove these + "eq", + "ne", + "gt", + "lt", + "le", + "ge", + ]: + return self, other + f_self = isinstance(self, (float, torch.SymFloat)) + f_other = isinstance(other, (float, torch.SymFloat)) + if f_self or f_other: + if not f_self: + self = torch.sym_float(self) + if not f_other: + other = torch.sym_float(other) + return self, other + # Before and after performing the operation, check if any operands are constant. # If so, extract out the constant values first. If `self` itself is a # constant, then "redispatch" by calling back into the operator. Sometimes @@ -1287,9 +1411,12 @@ def unary_magic_impl(self): return wrap_node(getattr(self.node, method_attr)()) def binary_magic_impl(self, other): + if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)): + return NotImplemented sym_node_log.debug("MAGIC %s %s %s", method, self, other) self = promote(self) other = promote(other) + self, other = promote2(self, other) if is_constant(self): return (method_to_operator(method))(get_constant(self), other) if is_constant(other): @@ -1301,8 +1428,11 @@ def binary_magic_impl(self, other): return get_constant(ret) if is_constant(ret) else ret def rbinary_magic_impl(self, other): + if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)): + return NotImplemented self = promote(self) other = promote(other) + self, other = promote2(self, other) if is_constant(self): return (method_to_operator(method))(get_constant(self), other) if is_constant(other): diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 544950298861..0e05d88f6756 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -61,7 +61,9 @@ from torch import SymBool, SymFloat, SymInt from torch._guards import ShapeGuard, Source, TracingContext from torch.utils._python_dispatch import is_traceable_wrapper_subclass -from torch.utils._sympy.functions import FloorDiv, Mod, IsNonOverlappingAndDenseIndicator +from torch.utils._sympy.functions import ( + FloorDiv, Mod, PythonMod, IsNonOverlappingAndDenseIndicator, CleanDiv, FloorToInt, CeilToInt +) from torch.utils._sympy.solve import try_solve from torch.utils._sympy.value_ranges import bound_sympy, SymPyValueRangeAnalysis, ValueRanges, ValueRangeError from torch.utils._sympy.singleton_int import SingletonInt @@ -869,9 +871,9 @@ def constrain_range(a, *, min: Optional[int], max: Optional[int] = None): for N=1. """ if min is None: - min = -sympy.oo + min = -sys.maxsize - 1 if max is None: - max = sympy.oo + max = sys.maxsize - 1 if max < min: raise ValueError( @@ -979,16 +981,6 @@ def eval_guards(gm, *args, ignore_static=True): def bind_symbols(gm, *args): return gm.shape_env.bind_symbols(fx_placeholder_vals(gm), args) -def _assert_bound_is_rational(expr: sympy.Expr, bound: ValueRanges): - """ - We assert that the bounds are either Boolean, or not finite, or can be computed - in exact prevision via rational arithmetic. - The only exception to this is the rare case when the user calls `sqrt(s0)` - sqrt is turned into sympy.Pow so we just match for that (it matches more things, but still) - """ - assert bound.lower.is_rational or bound.lower.is_Boolean or not bound.lower.is_finite or expr.has(sympy.Pow), (bound, expr) - assert bound.upper.is_rational or bound.upper.is_Boolean or not bound.upper.is_finite or expr.has(sympy.Pow), (bound, expr) - class DimDynamic(Enum): """ Controls how to perform symbol allocation for a dimension. It is always @@ -1387,14 +1379,19 @@ def cast_symbool_to_symint_guardless(symbool: torch.SymBool) -> torch.SymInt: 'Min': min, 'Max': max, 'Mod': operator.mod, + 'PythonMod': operator.mod, 'FloorDiv': operator.floordiv, 'TrueDiv': operator.truediv, 'IsNonOverlappingAndDenseIndicator': eval_is_non_overlapping_and_dense, 'floor': math.floor, 'ceiling': math.ceil, + 'FloorToInt': math.floor, + 'CeilToInt': math.ceil, 'cast_symbool_to_symint_guardless': cast_symbool_to_symint_guardless, - 'Round': builtins.round, + 'RoundToInt': builtins.round, 'RoundDecimal': builtins.round, + 'TruncToInt': math.trunc, + 'IntTrueDiv': operator.truediv, } @@ -1642,10 +1639,17 @@ def floor_div_handler(*args): congruence = (base - mod_reduced) % divisor if congruence != 0: self._congruences[s].add(congruence) + # NB: Must not be CleanDiv, it needs to be regular sympy division + # so inequality solver works. This is sort of problematic for + # is_integer tests though haha return (base - mod_reduced) / divisor if expr.has(Mod): expr = expr.replace(Mod, mod_handler) + # 7 // -3 is -3, 7 % -3 is -2, and 7 - (-2) / -3 is -3.0 so negative + # arguments should be OK. + if expr.has(PythonMod): + expr = expr.replace(PythonMod, mod_handler) if expr.has(FloorDiv): expr = expr.replace(FloorDiv, floor_div_handler) return expr @@ -3325,6 +3329,7 @@ def create_unbacked_symfloat(self): self.pending_fresh_unbacked_symbols.append(symbol) self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1) vr = self.var_to_range[symbol] = ValueRanges.unknown() + assert vr.is_float # Create a new FX placeholder and Z3 variable for 'symbol'. fx_node = self._create_fx_placeholder_and_z3var(symbol, float) @@ -3343,6 +3348,7 @@ def create_unbacked_symint(self): self.counter["create_unbacked_symbol"] += 1 self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1) vr = self.var_to_range[symbol] = self._default_unspecified_value_range() + assert vr.is_int # Create a new FX placeholder and Z3 variable for 'symbol'. fx_node = self._create_fx_placeholder_and_z3var(symbol, int) @@ -3366,6 +3372,7 @@ def create_unbacked_symbool(self): self.counter["create_unbacked_symbol"] += 1 self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1) vr = self.var_to_range[symbol] = ValueRanges(0, 1) + assert vr.is_int # Create a new FX placeholder and Z3 variable for 'symbol'. fx_node = self._create_fx_placeholder_and_z3var(symbol, bool) @@ -3511,6 +3518,7 @@ def create_symbol( self.var_to_range[sympy_expr] &= constraint_dim.vr vr = self.var_to_range[sympy_expr] + assert vr.is_int if val not in vr: raise ConstraintViolationError(f"{val} not in range [{vr.lower}, {vr.upper}]") @@ -3519,6 +3527,7 @@ def create_symbol( elif isinstance(val, float): self.var_to_range[sympy_expr] = vr = ValueRanges(-sympy.oo, sympy.oo) range_str = f"[{vr.lower}, {vr.upper}]" + assert vr.is_float else: # Skip var_range logic for SingletonInt # Only used for jagged layout nested tensors @@ -3568,6 +3577,7 @@ def create_symbol( def add_var_to_val(self, expr: sympy.Symbol, val: int): """ Adds a new symbol to the symbolic environment. """ + log.debug("add_var_to_val %s %s", expr, val, stack_info=True) assert expr not in self.var_to_val, f"{expr} already exists" self.var_to_val[expr] = sympy.Integer(val) @@ -4322,7 +4332,8 @@ def bound_sympy(self, expr: sympy.Expr, size_oblivious: bool = False) -> ValueRa # Clamp values of size-like variables for x in self.size_like & var_to_range.keys(): if var_to_range[x] is not None: - var_to_range[x] = ValueRanges(2, sympy.oo) + var_to_range[x] = ValueRanges(2, sys.maxsize - 1) + assert var_to_range[x].is_int return bound_sympy(expr, var_to_range) @_lru_cache @@ -4439,6 +4450,11 @@ def _maybe_evaluate_static( vr = self._default_unspecified_value_range() if size_oblivious and k in self.size_like: lower = max(2, vr.lower) + # This is a bit dodgy: what this means is that there was a + # size-like unbacked symbol whose upper bound < 2. This + # causes... problems. + if lower <= vr.upper: + vr = ValueRanges(lower, vr.upper) else: lower = vr.lower # Don't do anything if we don't have a nontrivial lower bound @@ -4446,10 +4462,17 @@ def _maybe_evaluate_static( # SymInt if ( lower < (-sys.maxsize - 1) // 2 or - (unbacked_only and k in self.var_to_val) + (unbacked_only and k in self.var_to_val) or + not vr.is_int ): new_range_env[k] = vr continue + # The goal is to take our symbols which have various lower bounds + # and reallocate them into new symbols which are exactly positive; + # e.g., if we have s0 in [2, inf], we want to turn it into ess0 in + # [1, inf], where s0 = ess0 + 1. This gives the most information + # to sympy for subsequent simplifications. + # # Positive means >= 1 # Positive - 1 means >= 0 # Positive + lower - 1 means >= lower @@ -4481,6 +4504,14 @@ def replace(expr, repl): self.counter["sympy_recursion_error"] += 1 return None + new_expr = safe_expand(new_expr) + if new_expr.is_number: + return new_expr + + # This is bad to do, the replacement with division leaves us with + # rationals when atom.args[0] is addition, e.g., sympy will happily + # turn (s0 + s1) // 2 into s0 / 2 + s1 / 2. Needless complication! + """ floor_div_replace = {} for atom in new_expr.atoms(FloorDiv): floor_div_replace[atom] = sympy.floor(atom.args[0] / atom.args[1]) @@ -4489,13 +4520,12 @@ def replace(expr, repl): # are still free symbols if new_expr.is_number: return new_expr + """ # Check if the range can solve it statically out = bound_sympy(new_expr, new_range_env) - if expect_rational: - _assert_bound_is_rational(new_expr, out) - if out.is_singleton(): - return out.lower + if out.is_singleton(): + return out.lower return new_expr if unbacked_only else None @@ -4547,7 +4577,7 @@ def simplify(self, expr: "sympy.Expr") -> "sympy.Expr": for fd in expr.atoms(FloorDiv): base, divisor = fd.args if self.replace(Mod(base, divisor)) in self.divisible: - div_replacements[fd] = base / divisor + div_replacements[fd] = CleanDiv(base, divisor) new_expr = expr.xreplace(div_replacements) new_expr = safe_expand(new_expr) new_pows = new_expr.atoms(sympy.Pow) @@ -4656,9 +4686,15 @@ def _update_var_to_range(self, symbol, vr): # Updates the range and the guards corresponding to each bound of the symbol. if symbol not in self.var_to_range: - self.var_to_range[symbol] = ValueRanges(lower, upper) + r = ValueRanges(lower, upper) + self.log.debug("_update_var_to_range %s = %s (new)", symbol, r) + self.var_to_range[symbol] = r else: - self.var_to_range[symbol] &= ValueRanges(lower, upper) + old = self.var_to_range[symbol] + new = old & ValueRanges(lower, upper) + if new != old: + self.var_to_range[symbol] = new + self.log.debug("_update_var_to_range %s = %s (update)", symbol, new) def _set_replacement(self, a: "sympy.Symbol", tgt: "sympy.Expr", msg: str) -> None: """ @@ -4691,7 +4727,10 @@ def _set_replacement(self, a: "sympy.Symbol", tgt: "sympy.Expr", msg: str) -> No int_range = ValueRanges(-sys.maxsize - 1, sys.maxsize - 1) def issubset(x, y): - return (x & int_range).issubset(y & int_range) + if x.is_int and y.is_int: + return (x & int_range).issubset(y & int_range) + else: + return x.issubset(y) # First, refine the value range of a based on the computed value range # of tgt. This is always OK to do, even if we decide not to do the @@ -4710,8 +4749,15 @@ def issubset(x, y): # Try to invert the equality r = try_solve(sympy.Eq(a, tgt), b, floordiv_inequality=False) if r is not None: - b_bound = self.bound_sympy(r[1]) - self.var_to_range[b] = b_bound & self.var_to_range[b] + self.log.debug("set_replacement: solve for %s in %s == %s gives %s", b, a, tgt, r) + # The solution here can be non-integral, for example, if + # we have s0 = 2*s1, then s1 = s0/2. What we would like + # to do is calculated the bounds in arbitrary precision, + # and then requantize the bound to integers when we are + # done. + rat_b_bound = self.bound_sympy(r[1]) + b_bound = ValueRanges(CeilToInt(rat_b_bound.lower), FloorToInt(rat_b_bound.upper)) + self._update_var_to_range(b, b_bound) tgt_bound = self.bound_sympy(tgt) assert issubset(tgt_bound, src_bound) @@ -4920,12 +4966,12 @@ def trivial_solve(lhs, rhs): ): # We have Mod(i0, q / c) == 0, which means we can # rewrite i0 as (q / gcd(q, c)) * i1 - d = q / sympy.gcd(q, c) + d = q / sympy.gcd(q, c) # TODO: CleanDiv? i1 = self.create_unbacked_symint().node.expr # Propagate the value ranges. It doesn't really # matter if we use truediv or floordiv, because we # have established divisibility. - self._update_var_to_range(i1, SymPyValueRangeAnalysis.truediv( + self._update_var_to_range(i1, SymPyValueRangeAnalysis.floordiv( self.var_to_range[i0], ValueRanges.wrap(d) )) # Propagate size-like-ness @@ -5362,7 +5408,6 @@ def _refine_ranges(self, expr: sympy.Expr) -> None: lower, upper = vr.lower, vr.upper rhs_vr = bound_sympy(rhs, self.var_to_range) - _assert_bound_is_rational(rhs, rhs_vr) # Let's suppose that we have a preexisting range for x [0, 100]. # Now, we issue a guard x > y, where the range for y is [50, 150]. diff --git a/torch/fx/experimental/validator.py b/torch/fx/experimental/validator.py index f9219fa4d551..871b8dd4709b 100644 --- a/torch/fx/experimental/validator.py +++ b/torch/fx/experimental/validator.py @@ -217,10 +217,7 @@ def sqrt(self, number: z3.ArithRef) -> z3.ArithRef: def abs(self, number: z3.ArithRef) -> z3.ArithRef: return z3.Abs(number) - def round(self, number: z3.ArithRef, ndigits: Optional[z3.ArithRef] = None) -> z3.ArithRef: - if ndigits is not None: - raise ValueError("round(..., ndigits=) is currently not supported by shape validations.") - + def round_to_int(self, number: z3.ArithRef) -> z3.ArithRef: # Pythons builtin 'round' implements the 'round half to even' strategy # See https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even # z3 has an equivalent z3.fpRoundToIntegral(z3.RoundNearestTiesToEven(), ...), but this only applies to @@ -285,7 +282,7 @@ def wrapper(*args): operator.truediv: lift(ops.div), operator.mod: lift(ops.mod), operator.abs: lift(ops.abs), - builtins.round: lift(ops.round), + builtins.round: lift(ops.round_to_int), # Math module. math.ceil: lift(ops.ceil), @@ -351,6 +348,7 @@ def __init__( self._ops = _Z3Ops(self._validator) def constant(self, value: Any, dtype: torch.dtype) -> z3.ExprRef: + # TODO: Probably OK to relax this and allow lower precision if dtype is torch.int64: return z3.IntVal(int(value)) if dtype is torch.double: @@ -359,6 +357,20 @@ def constant(self, value: Any, dtype: torch.dtype) -> z3.ExprRef: return z3.BoolVal(bool(value)) raise ValueError(f"unsupported dtype (SympyToZ3): {dtype}") + def to_dtype(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: + if dtype == torch.float64: + return z3.ToReal(x) + raise NotImplementedError(f"to_dtype {dtype} NYI") + + def trunc_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: + return z3.ToInt(x) + + def round_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: + return self._ops.round_to_int(x) + + def int_truediv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: + return self._ops.div(numerator, denominator) + def truediv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: return self._ops.div(numerator, denominator) @@ -371,11 +383,17 @@ def div(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: def pow(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef: return self._ops.pow(base, exp) + def pow_by_natural(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef: + return self._ops.pow(base, exp) + def mod(self, p: z3.ArithRef, q: z3.ArithRef) -> z3.ArithRef: return self._ops.mod(p, q) - def round(self, number: z3.ArithRef, ndigits: Optional[z3.ArithRef] = None) -> z3.ArithRef: - return self._ops.round(number, ndigits) + def ceil_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: + return self._ops.ceil(x) + + def floor_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: + return self._ops.floor(x) def __getattr__(self, name: str) -> Any: REPLACEMENT = { diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index 5109bc38ffcf..86a5b32aabb9 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -1,44 +1,79 @@ # mypy: allow-untyped-defs +import functools import math +import sys import sympy from sympy import S -from sympy.core.logic import fuzzy_and, fuzzy_not, fuzzy_or __all__ = [ "FloorDiv", "ModularIndexing", "CleanDiv", "CeilDiv", - "Pow", - "TrueDiv", + "IntTrueDiv", + "FloatTrueDiv", "LShift", "RShift", "IsNonOverlappingAndDenseIndicator", - "Round", + "RoundToInt", "RoundDecimal", + "ToFloat", + "FloatPow", + "PowByNatural", ] +def _keep_float(f): + @functools.wraps(f) + def inner(*args): + r = f(*args) + if any(isinstance(a, sympy.Float) for a in args) and not isinstance( + r, sympy.Float + ): + r = sympy.Float(float(r)) + return r + + return inner + + def fuzzy_eq(x, y): if None in (x, y): return None return x == y +# It would be nice to have assertions on whether or not inputs is_integer +# However, with bugs like https://github.com/sympy/sympy/issues/26620 sympy +# sometimes inconsistently reports floats an integers. +# +# What we can assume from sympy is that if something is an int, it +# definitely is is_integer, but if it is a float it may or may not +# be is_integer. So we are unable to do strong asserts that things +# are NOT integers. + + +# TODO: In Triton, // rounds to zero, but in Python, it is floor division. +# When we can prove both arguments are non-negative, we should just have a +# GenericFloorDiv (name pending) which can codegen efficiently in Python/C, +# and then PythonFloorDiv and CIntDiv which have the appropriate rounding +# semantics. +# +# Right now, FloorDiv de facto changes behavior if arguments are negative or +# not, this can potentially cause correctness issues. class FloorDiv(sympy.Function): """ We maintain this so that: 1. We can use divisibility guards to simplify FloorDiv(a, b) to a / b. 2. Printing out the expression is nicer (compared to say, representing a//b as (a - a % b) / b) + + NB: This is Python-style floor division, round to -Inf """ nargs = (2,) precedence = 50 # precedence of mul # noqa: F811 - # Default return type for SymPy assumptions. - # https://docs.sympy.org/latest/guides/assumptions.html#implementing-assumptions-handlers - is_real = True + is_integer = True @property def base(self): @@ -53,29 +88,14 @@ def _sympystr(self, printer): divisor = printer.parenthesize(self.divisor, self.precedence) return f"({base}//{divisor})" - # SymPy assumptions based on argument types. - def _eval_is_real(self): - return fuzzy_or([self.base.is_real, self.divisor.is_real]) - - def _eval_is_integer(self): - return fuzzy_and([self.base.is_integer, self.divisor.is_integer]) - # Automatic evaluation. # https://docs.sympy.org/latest/guides/custom-functions.html#best-practices-for-eval @classmethod def eval(cls, base, divisor): - def check_supported_type(x): - if ( - x.is_integer is False and x.is_real is False and x.is_complex - ) or x.is_Boolean: - raise TypeError( - f"unsupported operand type(s) for //: " - f"'{type(base).__name__}' and '{type(divisor).__name__}'" - f", expected integer or real" - ) - - check_supported_type(base) - check_supported_type(divisor) + # python test/test_dynamic_shapes.py -k TestDimConstraints.test_dim_constraints_solve_full + # Assert triggered by inequality solver + # assert base.is_integer, base + # assert divisor.is_integer, divisor # We don't provide the same error message as in Python because SymPy # makes it difficult to check the types. @@ -86,26 +106,22 @@ def check_supported_type(x): return sympy.S.Zero if base.is_integer and divisor == 1: return base - if base.is_real and divisor == 1: - return sympy.floor(base) if base.is_integer and divisor == -1: return sympy.Mul(base, -1) if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer): - return base // divisor - if isinstance(base, (sympy.Integer, sympy.Float)) and isinstance( - divisor, (sympy.Integer, sympy.Float) - ): - return sympy.floor(base / divisor) + return sympy.Integer(int(base) // int(divisor)) if isinstance(base, FloorDiv): return FloorDiv(base.args[0], base.args[1] * divisor) - if isinstance(divisor, sympy.Rational) and divisor.p == 1: - return sympy.floor(base * divisor.q) + # gcd in sympy is over polynomials, so you'll end up with rationals if + # you do this. Don't. + """ if isinstance(base, sympy.Add): for a in base.args: gcd = sympy.gcd(a, divisor) if gcd == divisor: return FloorDiv(base - a, divisor) + a / gcd + """ try: gcd = sympy.gcd(base, divisor) @@ -190,6 +206,19 @@ class Where(sympy.Function): nargs = (3,) + def _eval_is_integer(self): + return True if self.args[1].is_integer and self.args[2].is_integer else None # type: ignore[attr-defined] + + def _eval_is_nonnegative(self): + return ( + True + if self.args[1].is_nonnegative and self.args[2].is_nonnegative # type: ignore[attr-defined] + else None + ) + + def _eval_is_positive(self): + return True if self.args[1].is_positive and self.args[2].is_positive else None # type: ignore[attr-defined] + @classmethod def eval(cls, c, p, q): if c == sympy.true: @@ -198,28 +227,27 @@ def eval(cls, c, p, q): return q -class Mod(sympy.Function): - """ - We maintain this so that we avoid SymPy correctness issues, such as: - https://github.com/sympy/sympy/issues/25146 - """ - +# Python-style modulus: take sign from RHS +class PythonMod(sympy.Function): nargs = (2,) + is_integer = True + @classmethod def eval(cls, p, q): - # This was adapted from: sympy/core/mod.py + # python test/dynamo/test_export.py -k ExportTests.test_trivial_constraint + # Triggered by sympy.solvers.inequalities.reduce_inequalities + # assert p.is_integer, p + # assert q.is_integer, q if q.is_zero: raise ZeroDivisionError("Modulo by zero") - # If either of them is NaN or infinite. - if p is S.NaN or q is S.NaN or p.is_finite is False or q.is_finite is False: - return S.NaN + # Three cases: # 1. p == 0 # 2. p is either q or -q # 3. p is integer and q == 1 - if p is S.Zero or p in (q, -q) or (p.is_integer and q == 1): + if p is S.Zero or p in (q, -q) or q == 1: return S.Zero # Evaluate if they are both literals. @@ -248,10 +276,7 @@ def eval(cls, p, q): if sympy.Mod(p, q) == 0: return S.Zero - def _eval_is_integer(self): - p, q = self.args - return fuzzy_and([p.is_integer, q.is_integer, fuzzy_not(q.is_zero)]) # type: ignore[attr-defined] - + # NB: args[1] for PythonMod def _eval_is_nonnegative(self): return True if self.args[1].is_positive else None # type: ignore[attr-defined] @@ -259,6 +284,58 @@ def _eval_is_nonpositive(self): return True if self.args[1].is_negative else None # type: ignore[attr-defined] +# Generic modulus: only defined on non-negative arguments +class Mod(sympy.Function): + nargs = (2,) + + is_integer = True + is_nonnegative = True + + @classmethod + def eval(cls, p, q): + # This was adapted from: sympy/core/mod.py + + # Triggered by + # python test/test_dynamic_shapes.py -k TestDimConstraints.test_dim_constraints_solve_full + # assert p.is_integer, p + # assert q.is_integer, q + + if q.is_zero: + raise ZeroDivisionError("Modulo by zero") + + # Three cases: + # 1. p == 0 + # 2. p is either q or -q + # 3. p is integer and q == 1 + if p is S.Zero or p in (q, -q) or q == 1: + return S.Zero + + # Evaluate if they are both literals. + if q.is_Number and p.is_Number: + assert p >= 0, p + assert q >= 1, q + return p % q + + # If q == 2, it's a matter of whether p is odd or even. + if q.is_Number and q == 2: + if p.is_even: + return S.Zero + if p.is_odd: + return S.One + + # If p is a multiple of q. + r = p / q + if r.is_integer: + return S.Zero + + # If p < q and its ratio is positive, then: + # - floor(p / q) = 0 + # - p % q = p - floor(p / q) * q = p + less = p < q + if less.is_Boolean and bool(less) and r.is_positive: + return p + + class CleanDiv(FloorDiv): """ Div where we can assume no rounding. @@ -268,6 +345,36 @@ class CleanDiv(FloorDiv): pass +# Don't use sympy ceiling/floor as they will attempt simplifications involving +# frac +class CeilToInt(sympy.Function): + is_integer = True + + @classmethod + def eval(cls, number): + # assert number.is_integer is not True, number + if number == sympy.oo: + return sympy.Integer(sys.maxsize - 1) + if number == -sympy.oo: + return sympy.Integer(-sys.maxsize - 1) + if isinstance(number, sympy.Number): + return sympy.Integer(math.ceil(float(number))) + + +class FloorToInt(sympy.Function): + is_integer = True + + @classmethod + def eval(cls, number): + # assert number.is_integer is not True, number + if number == sympy.oo: + return sympy.Integer(sys.maxsize - 1) + if number == -sympy.oo: + return sympy.Integer(-sys.maxsize - 1) + if isinstance(number, sympy.Number): + return sympy.Integer(math.floor(float(number))) + + class CeilDiv(sympy.Function): """ Div used in indexing that rounds up. @@ -276,6 +383,8 @@ class CeilDiv(sympy.Function): is_integer = True def __new__(cls, base, divisor): + base = sympy.sympify(base) + divisor = sympy.sympify(divisor) if sympy.gcd(base, divisor) == divisor: return CleanDiv(base, divisor) else: @@ -283,6 +392,8 @@ def __new__(cls, base, divisor): class LShift(sympy.Function): + is_integer = True + @classmethod def eval(cls, base, shift): if shift < 0: @@ -291,6 +402,8 @@ def eval(cls, base, shift): class RShift(sympy.Function): + is_integer = True + @classmethod def eval(cls, base, shift): if shift < 0: @@ -298,28 +411,107 @@ def eval(cls, base, shift): return base // 2**shift -# Overloaded to be compatible with regular Python. -# https://github.com/pytorch/pytorch/issues/90900 -class Pow(sympy.Function): +def safe_pow(base, exp): + sign = 1 + if base < 0: + base = -base + sign = 1 if exp % 2 == 0 else -1 + return sign * _safe_pow(base, exp) + + +def _safe_pow(base, exponent): + if exponent < 0: + raise ValueError("Exponent must be non-negative.") + + if exponent == 0: + return 1 + + half_exp = safe_pow(base, exponent // 2) + if half_exp > sys.maxsize - 1: + return sys.maxsize - 1 + + result = half_exp * half_exp + if result > sys.maxsize - 1: + return sys.maxsize - 1 + + if exponent % 2 == 1: + result *= base + if result > sys.maxsize - 1: + return sys.maxsize - 1 + + return result + + +class PowByNatural(sympy.Function): + is_integer = True + @classmethod def eval(cls, base, exp): - if exp.is_zero: - return sympy.Integer(1) - elif base.is_zero and exp < 0: - raise ZeroDivisionError(f"{base} cannot be raised to a negative power") - else: - return base**exp + if isinstance(base, sympy.Number) and isinstance(exp, sympy.Number): + return sympy.Integer(safe_pow(base, exp)) + if isinstance(exp, sympy.Integer): + # Translate power into iterated multiplication + r = sympy.Integer(1) + for _ in range(int(exp)): + r *= base + return r + # NB: do NOT translate into sympy.Pow, we will lose knowledge that exp + # is a natural number if we do + + +# base is assumed to be nonnegative, thereby prevent complex numbers from +# occuring +class FloatPow(sympy.Function): + is_integer = False + is_real = True + + @classmethod + def eval(cls, base, exp): + if isinstance(base, sympy.Number) and isinstance(exp, sympy.Number): + return sympy.Float(float(base) ** float(exp)) + # NB: do not do any nontrivial reasoning # Overloaded to be compatible with regular Python. # https://github.com/pytorch/pytorch/issues/90900 -class TrueDiv(sympy.Function): +# +# In particular, sympy division is willing to simplify x/x == 1 +# where 1 is an integer, but this must be a float if x was float. +class FloatTrueDiv(sympy.Function): + is_integer = False + is_real = True + @classmethod def eval(cls, base, divisor): + # assert base.is_integer is not True, base + # assert divisor.is_integer is not True, divisor + if divisor.is_zero: raise ZeroDivisionError("division by zero") - else: - return base / divisor + + if isinstance(base, sympy.Number) and isinstance(divisor, sympy.Number): + return sympy.Float(float(base) / float(divisor)) + + +# Overloaded to be compatible with regular Python. We distinguish this from +# FloatTrueDiv, because the code generation has to be different for this case: +# Python has a fancy algorithm for integer true division that isn't just +# "promote both arguments to float and use float division", so you need to +# codegen it differently. While technically you can work it out from the +# types of the input, this is often inconvenient to do in Inductor codegen, +# so just have a different operator +# NB: Right now, Inductor codegen doesn't implement this correctly lol +class IntTrueDiv(sympy.Function): + is_integer = False + is_real = True + + @classmethod + def eval(cls, base, divisor): + if divisor.is_zero: + raise ZeroDivisionError("division by zero") + + if isinstance(base, sympy.Number) and isinstance(divisor, sympy.Number): + return sympy.Float(int(base) / int(divisor)) # TODO: As an indicator, this != 0 implies == 1 (and vice versa). @@ -354,45 +546,85 @@ def eval(cls, *args): return None -class Trunc(sympy.Function): +# NB: this is inconsistent with math.trunc in Python +class TruncToFloat(sympy.Function): + is_integer = False + is_real = True + + @classmethod + def eval(cls, number): + # assert number.is_integer is not True, number + if isinstance(number, sympy.Number): + # NB: It is safe to use truncation to integer, which is what + # math.trunc does, as Python integers are arbitrary precision and + # so we are guaranteed not to lose precision when we do this + return sympy.Float(math.trunc(float(number))) + + +class TruncToInt(sympy.Function): is_integer = True @classmethod def eval(cls, number): - if number.is_integer: - return number - elif isinstance(number, sympy.Number): + # assert number.is_integer is not True, number + if number == sympy.oo: + return sympy.Integer(sys.maxsize - 1) + if number == -sympy.oo: + return sympy.Integer(-sys.maxsize - 1) + if isinstance(number, sympy.Number): return sympy.Integer(math.trunc(float(number))) -class Round(sympy.Function): +# This is float -> int +class RoundToInt(sympy.Function): is_integer = True @classmethod def eval(cls, number): - if number.is_integer: - return number - elif isinstance(number, sympy.Number): - return sympy.Integer(round(float(number))) + # assert number.is_integer is not True, number + + if isinstance(number, sympy.Float): + return sympy.Integer(round(float(number), 0)) - def __int__(self): - # This will only ever be called when computing size hints. At that point, self.args[0] should be a number and - # no longer an expression. If it were, the float call would fail and the caller would handle this further. - return round(float(self.args[0])) # type: ignore[arg-type] +# To get float -> int, Python style round semantics. +# +# x = PyFloat_AsDouble(self); +# if (o_ndigits == Py_None) { +# /* single-argument round or with None ndigits: +# * round to nearest integer */ +# rounded = round(x); +# if (fabs(x-rounded) == 0.5) +# /* halfway case: round to even */ +# rounded = 2.0*round(x/2.0); +# return PyLong_FromDouble(rounded); +# } + +# NB: Like Round, this only ever returns floats. ndigits cannot be None class RoundDecimal(sympy.Function): + is_integer = False + is_real = True + @classmethod def eval(cls, number, ndigits): - if number.is_integer and ndigits >= 0: + # assert number.is_integer is not True, number + + if isinstance(number, sympy.Float) and isinstance(ndigits, sympy.Integer): + return sympy.Float(round(float(number), int(ndigits))) + + +class ToFloat(sympy.Function): + is_integer = False + is_real = True + + @classmethod + def eval(cls, number): + if number in [sympy.oo, -sympy.oo]: return number - elif isinstance(number, sympy.Number) and isinstance(ndigits, sympy.Integer): - value_type, output_type = ( - (int, sympy.Integer) - if isinstance(number, sympy.Integer) - else (float, sympy.Float) - ) - return output_type(round(value_type(number), int(ndigits))) + + if isinstance(number, sympy.Integer): + return sympy.Float(int(number)) def make_opaque_unary_fn(name): diff --git a/torch/utils/_sympy/interp.py b/torch/utils/_sympy/interp.py index e5c3c1aa43a7..640b991cd104 100644 --- a/torch/utils/_sympy/interp.py +++ b/torch/utils/_sympy/interp.py @@ -16,16 +16,23 @@ import torch from .functions import ( + CeilToInt, CleanDiv, + FloatPow, + FloatTrueDiv, FloorDiv, + FloorToInt, + IntTrueDiv, IsNonOverlappingAndDenseIndicator, Mod, ModularIndexing, - Pow, - Round, + PowByNatural, + PythonMod, RoundDecimal, - TrueDiv, - Trunc, + RoundToInt, + ToFloat, + TruncToFloat, + TruncToInt, Where, ) @@ -50,30 +57,39 @@ def handlers(): sympy.Le: "le", sympy.Ge: "ge", sympy.Not: "not_", - TrueDiv: "truediv", + IntTrueDiv: "int_truediv", + FloatTrueDiv: "truediv", FloorDiv: "floordiv", - CleanDiv: "div", - Trunc: "trunc", + CleanDiv: "floordiv", # TODO: hmm? + TruncToFloat: "trunc", Where: "where", sympy.Add: "add", sympy.Mul: "mul", - Pow: "pow", - sympy.Pow: "pow", + FloatPow: "pow", + PowByNatural: "pow_by_natural", + # sympy simplifies x * x into Pow(x, 2), so we need to handle this. + # Do NOT use builtin Pow for floats + # TODO: There is a hazard here, if we have float * float it will + # also get turned into Pow(float, 2) but we don't want this because + # pow_by_natural is assumed to only be integers. Probably the fix is + # to add a FloatMul to impede this optimization + sympy.Pow: "pow_by_natural", Mod: "mod", + PythonMod: "mod", # TODO: this is wrong + # TODO: Inductor can generate these, but it's ill-specified which + # semantics were intended here. Needs to be cleaned up along with + # FloorDiv in a bigger cleanup sympy.Mod: "mod", sympy.Abs: "abs", sympy.log: "log", sympy.exp: "exp", - sympy.floor: "floor", - sympy.ceiling: "ceil", sympy.Min: "minimum", sympy.Max: "maximum", ModularIndexing: "modular_indexing", sympy.functions.elementary.piecewise.ExprCondPair: "expr_cond_pair", sympy.Piecewise: "piecewise", IsNonOverlappingAndDenseIndicator: "is_non_overlapping_and_dense_indicator", - Round: "round", - RoundDecimal: "round", + RoundDecimal: "round_decimal", } for name in ["cos", "sin", "tan", "sinh", "cosh", "tanh", "asin", "acos", "atan"]: HANDLERS[getattr(sympy, name)] = name @@ -85,7 +101,11 @@ def handlers(): def sympy_interp( - analysis, env: Dict[sympy.Symbol, Any], expr: Union[sympy.Expr, SympyBoolean] + analysis, + env: Dict[sympy.Symbol, Any], + expr: Union[sympy.Expr, SympyBoolean], + *, + index_dtype=torch.int64, ): # Handle base cases dtype = None @@ -106,9 +126,32 @@ def sympy_interp( expr.args[1], sympy.core.numbers.Half ): return analysis.sqrt(sympy_interp(analysis, env, expr.args[0])) + if isinstance(expr, ToFloat): + return analysis.to_dtype( + sympy_interp(analysis, env, expr.args[0]), torch.float64 + ) # Recursive case args = [sympy_interp(analysis, env, arg) for arg in expr.args] # type: ignore[arg-type] + + # These handlers are special because they take an extra dtype argument + # specifying what they should convert to, and we need to appropriately set + # this up when we convert from Sympy. A reasonable default when you + # are translating is to conservatively do int64, and then narrow these + # arguments later when you discover you can narrow the index range. But + # if you already know that 32-bit indexing is OK, you can directly do the + # sympy translation with index_dtype=torch.int32 + INDEX_DTYPE_HANDLERS = { + TruncToInt: "trunc_to_int", + sympy.floor: "floor_to_int", + sympy.ceiling: "ceil_to_int", + FloorToInt: "floor_to_int", + CeilToInt: "ceil_to_int", + RoundToInt: "round_to_int", + } + if (handler_name := INDEX_DTYPE_HANDLERS.get(expr.func)) is not None: + return getattr(analysis, handler_name)(*args, index_dtype) + if hasattr(expr.func, "_torch_handler_name"): handler_name = expr.func._torch_handler_name else: diff --git a/torch/utils/_sympy/reference.py b/torch/utils/_sympy/reference.py index eea543a30943..156891ac5497 100644 --- a/torch/utils/_sympy/reference.py +++ b/torch/utils/_sympy/reference.py @@ -1,13 +1,26 @@ # mypy: allow-untyped-defs import math +import operator + import sympy import torch from torch.utils._sympy.functions import ( + _keep_float, + FloatPow, + FloatTrueDiv, + FloorDiv, + IntTrueDiv, + Mod, OpaqueUnaryFn_exp, OpaqueUnaryFn_log, OpaqueUnaryFn_sqrt, + PowByNatural, + RoundDecimal, + RoundToInt, + ToFloat, + TruncToInt, ) @@ -63,20 +76,41 @@ def not_(a): @staticmethod def reciprocal(x): - return 1 / x + return FloatTrueDiv(1.0, x) @staticmethod def square(x): - return x * x + return PowByNatural(x, 2) + + @staticmethod + def trunc_to_int(x, dtype): + return TruncToInt(x) + + @staticmethod + def ceil_to_int(x, dtype): + return sympy.ceiling(x) + + @staticmethod + def floor_to_int(x, dtype): + return sympy.floor(x) + + @staticmethod + def floor(x): + return _keep_float(sympy.floor)(x) + + @staticmethod + def ceil(x): + return _keep_float(sympy.ceiling)(x) + + @staticmethod + def to_dtype(x, dtype): + if dtype == torch.float64: + return ToFloat(x) + raise NotImplementedError(f"to_dtype {dtype} NYI") @staticmethod def mod(x, y): - ret = abs(x) % abs(y) - # without check: - # tracing will fail trying to go through control-flow if x is Proxy() - if isinstance(x, (int, sympy.Number)) and x < 0: - ret *= -1 - return ret + return Mod(x, y) @staticmethod def abs(x): @@ -88,37 +122,31 @@ def neg(x): @staticmethod def truediv(a, b): - return a / b + return FloatTrueDiv(a, b) @staticmethod - def div(a, b): - return ReferenceAnalysis.truediv(a, b) + def int_truediv(a, b): + return IntTrueDiv(a, b) @staticmethod def floordiv(a, b): - if b == 0: - return sympy.nan if a == 0 else sympy.zoo - return a // b + return FloorDiv(a, b) @staticmethod def truncdiv(a, b): - result = a / b - if result.is_finite: - result = sympy.Integer(result) - - return result + raise NotImplementedError("TODO: truncdiv") @staticmethod def add(a, b): - return a + b + return _keep_float(operator.add)(a, b) @staticmethod def mul(a, b): - return a * b + return _keep_float(operator.mul)(a, b) @staticmethod def sub(a, b): - return a - b + return _keep_float(operator.sub)(a, b) @staticmethod def exp(x): @@ -134,39 +162,27 @@ def sqrt(x): @staticmethod def pow(a, b): - return a**b + return _keep_float(FloatPow)(a, b) + + @staticmethod + def pow_by_natural(a, b): + return PowByNatural(a, b) @staticmethod def minimum(a, b): - # Poorman's version of upcasting in Sympy - # This won't do for sympy.Expr as the casting does nothing for those - if a.is_Float or not a.is_finite or b.is_Float or not b.is_finite: - result_type = sympy.Float - else: - assert a.is_Integer - assert b.is_Integer - result_type = sympy.Integer - return sympy.Min(result_type(a), result_type(b)) + return sympy.Min(a, b) @staticmethod def maximum(a, b): - # Poorman's version of upcasting in Sympy - # This won't do for sympy.Expr as the casting does nothing for those - if a.is_Float or not a.is_finite or b.is_Float or not b.is_finite: - result_type = sympy.Float - else: - assert a.is_Integer - assert b.is_Integer - result_type = sympy.Integer - return sympy.Max(result_type(a), result_type(b)) + return sympy.Max(a, b) @staticmethod - def floor(x): - return sympy.floor(x) + def round_to_int(a, dtype): + return RoundToInt(a) @staticmethod - def ceil(x): - return sympy.ceiling(x) + def round_decimal(a, b): + return RoundDecimal(a, b) # Unlike ReferenceAnalysis, does NOT sympyify, instead, works with plain @@ -192,10 +208,20 @@ def not_(a): def floordiv(a, b): return a // b + @staticmethod + def mod(x, y): + return x % y + @staticmethod def truncdiv(a, b): return a / b + @staticmethod + def to_dtype(x, dtype): + if dtype == torch.float64: + return float(x) + raise NotImplementedError(f"to_dtype {dtype} NYI") + @staticmethod def exp(x): raise AssertionError("exp is not valid shape sympy expr") @@ -217,9 +243,40 @@ def maximum(a, b): return torch.sym_max(a, b) @staticmethod - def floor(x): + def floor_to_int(x, dtype): return math.floor(x) @staticmethod - def ceil(x): + def ceil_to_int(x, dtype): return math.ceil(x) + + @staticmethod + def floor(x): + return float(math.floor(x)) + + @staticmethod + def ceil(x): + return float(math.ceil(x)) + + @staticmethod + def truediv(a, b): + return a / b + + @staticmethod + def pow(a, b): + return a**b + + @staticmethod + def pow_by_natural(a, b): + # Pray that safe_pow is not needed here lol. In particular, this + # never participates in VR low/high ranges, so overflow should be + # unlikely + return a**b + + @staticmethod + def round_to_int(a, dtype): + return round(a) + + @staticmethod + def round_decimal(a, b): + return round(a, ndigits=b) diff --git a/torch/utils/_sympy/solve.py b/torch/utils/_sympy/solve.py index 6276c696293c..02ddf7c34219 100644 --- a/torch/utils/_sympy/solve.py +++ b/torch/utils/_sympy/solve.py @@ -88,6 +88,7 @@ def try_solve( # Return if we were able to isolate 'thing' on the left-hand side. if isinstance(e, sympy.Rel) and e.lhs == thing: + log.debug("solved: %s ---> %s", expr, e) return e, e.rhs return None diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py index 619a9046796d..48f846c2fd72 100644 --- a/torch/utils/_sympy/value_ranges.py +++ b/torch/utils/_sympy/value_ranges.py @@ -6,6 +6,7 @@ import logging import math import operator +import sys from typing import ( Callable, Dict, @@ -26,17 +27,20 @@ from torch._prims_common import dtype_to_type from .functions import ( - OpaqueUnaryFn_acos, - OpaqueUnaryFn_asinh, - OpaqueUnaryFn_atan, - OpaqueUnaryFn_cosh, + _keep_float, + FloatTrueDiv, + FloorDiv, + IntTrueDiv, OpaqueUnaryFn_exp, OpaqueUnaryFn_log, - OpaqueUnaryFn_sinh, OpaqueUnaryFn_sqrt, - OpaqueUnaryFn_tanh, - Round, + PowByNatural, RoundDecimal, + RoundToInt, + safe_pow, + ToFloat, + TruncToFloat, + TruncToInt, ) from .interp import sympy_interp @@ -121,6 +125,11 @@ class ValueRanges(Generic[_T]): lower: _T upper: _T is_bool: bool + is_int: bool + is_float: bool + + def __repr__(self) -> str: + return f"VR[{self.lower}, {self.upper}]" @overload def __init__(self: ValueRanges[sympy.Expr], lower: ExprIn, upper: ExprIn) -> None: @@ -143,8 +152,39 @@ def __init__(self, lower: AllIn, upper: AllIn) -> None: # Because this is a frozen class object.__setattr__(self, "lower", lower) object.__setattr__(self, "upper", upper) + # Unlike bool/int in Python, we don't report bools are ints object.__setattr__(self, "is_bool", isinstance(lower, SympyBoolean)) - assert isinstance(upper, SympyBoolean) == self.is_bool + if self.is_bool: + assert isinstance(upper, SympyBoolean), (lower, upper) + + # Warning: is_int/is_float is best effort. We do pretty well in + # Dynamo, but in Inductor these attributes are often wrong because we + # are not very rigorous in dtype analysis. This is also why we need + # the flexible analysis for is_int: sometimes a sympy.oo pops in for + # an integer bound. I would /like/ for us not to do this, but it's + # too hard to push the invariant through right now. + + object.__setattr__( + self, + "is_int", + not self.is_bool + and (isinstance(lower, sympy.Integer) or isinstance(upper, sympy.Integer)), + ) + """ + # This assert is just impossible right now, too many sympy bugs + if self.is_int: + # NB: sympy will sometimes randomly lose the float-ness of zero, + # so we also need to account for that in the assertion here. + # See also https://github.com/sympy/sympy/issues/26620 + assert isinstance(lower, sympy.Integer) or lower in [-sympy.oo, 0], ( + lower, + upper, + ) + assert isinstance(upper, sympy.Integer) or upper in [sympy.oo, 0], (lower, upper) + """ + # NB: [-oo, oo] always advertises as float! + object.__setattr__(self, "is_float", not self.is_bool and not self.is_int) + assert self.is_bool or self.is_int or self.is_float, (lower, upper) def boolify(self) -> ValueRanges[SympyBoolean]: if vr_is_bool(self): @@ -185,6 +225,8 @@ def __and__(self: AllVR, other: AllVR) -> AllVR: if self == ValueRanges.unknown(): return other assert self.is_bool == other.is_bool, (self, other) + assert self.is_int == other.is_int, (self, other) + assert self.is_float == other.is_float, (self, other) if self.is_bool: return ValueRanges( sympy.Or(self.lower, other.lower), sympy.And(self.upper, other.upper) @@ -354,7 +396,12 @@ def constant(value, dtype): # using nan makes subsequent computation throw, and for the purposes of optimization # returning -math.inf - math.inf is equivalent to giving up if isinstance(value, SupportsFloat) and math.isnan(value): - return ValueRanges.unknown() + if dtype == torch.bool: + return ValueRanges.unknown_bool() + elif dtype.is_floating_point: + return ValueRanges.unknown() + else: + return ValueRanges(-sys.maxsize - 1, sys.maxsize) if is_python: type_ = dtype_to_type(dtype) @@ -370,7 +417,18 @@ def constant(value, dtype): # dtype is intXX assert value.is_integer - return ValueRanges.wrap(value) + r = ValueRanges.wrap(value) + return r + + @staticmethod + def to_dtype(a, dtype, src_dtype=None): + if dtype == torch.float64: + return ValueRanges.increasing_map(a, ToFloat) + return ValueRanges.unknown() + + @staticmethod + def trunc_to_int(a, dtype): + return ValueRanges.increasing_map(a, TruncToInt) @staticmethod def not_(a): @@ -429,7 +487,9 @@ def ge(cls, a, b): @staticmethod def add(a, b): - return ValueRanges.coordinatewise_increasing_map(a, b, operator.add) + return ValueRanges.coordinatewise_increasing_map( + a, b, _keep_float(operator.add) + ) @classmethod def mul(cls, a, b): @@ -449,11 +509,20 @@ def safe_mul(a, b): else: return a * b - return ValueRanges.coordinatewise_monotone_map(a, b, safe_mul) + return ValueRanges.coordinatewise_monotone_map(a, b, _keep_float(safe_mul)) - @classmethod - def div(cls, a, b): - return cls.truediv(a, b) + @staticmethod + def int_truediv(a, b): + a = ValueRanges.wrap(a) + b = ValueRanges.wrap(b) + if 0 in b or ( + (-sympy.oo in a or sympy.oo in a) and (-sympy.oo in b or sympy.oo in b) + ): + return ValueRanges.unknown() + else: + return ValueRanges.coordinatewise_monotone_map( + a, b, _keep_float(IntTrueDiv) + ) @staticmethod def truediv(a, b): @@ -464,18 +533,22 @@ def truediv(a, b): ): return ValueRanges.unknown() else: - return ValueRanges.coordinatewise_monotone_map(a, b, operator.truediv) + return ValueRanges.coordinatewise_monotone_map( + a, b, _keep_float(FloatTrueDiv) + ) @staticmethod def floordiv(a, b): a = ValueRanges.wrap(a) b = ValueRanges.wrap(b) if 0 in b or ( - (-sympy.oo in a or sympy.oo in a) and (-sympy.oo in b or sympy.oo in b) + # TODO: make this more precise + (-sympy.oo in a or sympy.oo in a) + or (-sympy.oo in b or sympy.oo in b) ): return ValueRanges.unknown() else: - return ValueRanges.coordinatewise_monotone_map(a, b, operator.floordiv) + return ValueRanges.coordinatewise_monotone_map(a, b, FloorDiv) @classmethod def mod(cls, x, y): @@ -524,17 +597,51 @@ def modular_indexing(cls, a, b, c): @classmethod def is_non_overlapping_and_dense_indicator(cls, *args): - return ValueRanges.unknown() + return ValueRanges.unknown() # TODO: type here is wrong @classmethod - def pow(cls, a, b): - def is_integer(val): - return isinstance(val, int) or ( - hasattr(val, "is_integer") and val.is_integer + def pow_by_natural(cls, a, b): + a = ValueRanges.wrap(a) + b = ValueRanges.wrap(b) + if a.is_singleton() and b.is_singleton(): + return ValueRanges.wrap(safe_pow(a.lower, b.lower)) + # NB: Exclude zero, because zero is special + elif a.lower >= 1: + # We should know that b >= 0 but we may have forgotten this fact due + # to replacements, so don't assert it, but DO clamp it to prevent + # degenerate problems + return ValueRanges.coordinatewise_increasing_map( + a, b & ValueRanges(0, sys.maxsize - 1), PowByNatural ) + elif b.is_singleton(): + if b.lower % 2 == 0: + # x^n where n is even + return ValueRanges.convex_min_zero_map( + a, lambda x: safe_pow(x, b.lower) + ) + else: + # x^n where n is odd + return ValueRanges.increasing_map(a, lambda x: safe_pow(x, b.lower)) + else: + # a is potentially negative, and we don't know if the exponent is + # even or odd. So just conservatively set the upper and lower + # bound based on what the maximum absolute value could be, in both + # directions + max_base = max(a.upper, -a.lower) + return ValueRanges( + -(safe_pow(max_base, b.upper)), safe_pow(max_base, b.upper) + ) + + @classmethod + def pow(cls, a, b): + return ValueRanges.unknown() + # We could implement all this, but for floating point pow, is there + # really a point? + """ a = ValueRanges.wrap(a) b = ValueRanges.wrap(b) + # Not implemented yet. It's a bit tricky # If you want to implement it, compute the partial derivatives of a ** b # and check the ranges where the function is increasing / decreasing @@ -554,8 +661,7 @@ def is_integer(val): if b == 0: if not a.lower.is_finite: return ValueRanges.unknown() - type_ = sympy.Float if a.lower.is_real else sympy.Integer - return ValueRanges.wrap(type_(1)) + return ValueRanges.wrap(1.0) if b < 0: a = cls.reciprocal(a) @@ -564,21 +670,12 @@ def is_integer(val): if a == ValueRanges.unknown(): return ValueRanges.unknown() - # Here b > 0 - if not is_integer(b): - # If the base is positive, then we're good, otherwise nothing's defined - if a.lower >= 0: - return ValueRanges.increasing_map(a, lambda x: x**b) - else: - return ValueRanges.unknown() + # If the base is positive, then we're good, otherwise nothing's defined + if a.lower >= 0: + return ValueRanges.increasing_map(a, lambda x: x**b) else: - # b > 0 integer - if b % 2 == 0: - # x^n where n is even - return ValueRanges.convex_min_zero_map(a, lambda x: x**b) - else: - # x^n where n is odd - return ValueRanges.increasing_map(a, lambda x: x**b) + return ValueRanges.unknown() + """ @staticmethod def reciprocal(x): @@ -587,7 +684,7 @@ def reciprocal(x): if 0 in x: return ValueRanges.unknown() else: - return ValueRanges.decreasing_map(x, lambda y: 1 / y) # type: ignore[operator] + return ValueRanges.decreasing_map(x, lambda y: FloatTrueDiv(1.0, y)) # type: ignore[operator] @staticmethod def abs(x): @@ -616,45 +713,64 @@ def maximum(cls, a, b): def min_or_max(a, b, fn): a = ValueRanges.wrap(a) b = ValueRanges.wrap(b) + return ValueRanges.coordinatewise_increasing_map(a, b, fn) - # Performs upcasting first - def fn_(x: sympy.Expr, y: sympy.Expr) -> sympy.Expr: - # Poorman's version of upcasting in Sympy - # Inf is not a float... - if x.is_Integer and y.is_Integer: - result_type = sympy.Integer - elif x.is_rational and y.is_rational: - result_type = sympy.Rational - else: - assert x.is_real or not x.is_finite or y.is_real or not y.is_finite - result_type = sympy.Float - return fn(result_type(x), result_type(y)) + @classmethod + def floor_to_int(cls, x, dtype): + return ValueRanges.increasing_map(x, sympy.functions.elementary.integers.floor) - return ValueRanges.coordinatewise_increasing_map(a, b, fn_) + @classmethod + def ceil_to_int(cls, x, dtype): + return ValueRanges.increasing_map( + x, sympy.functions.elementary.integers.ceiling + ) + + # I think these implementations are sound. The hazard here is that sympy + # will carry out the floor/ceil at too high precision and then something + # bad will happen when we convert it to float. + # + # For truncation, the implementation is clearly sound, because the desired + # target float is always exactly representable, since you're just chopping + # off bits the mantissa. But what about ceil/floor? + # + # The important constraint here is that we're not defining floor on + # arbitrary real numbers, only representable float numbers. So we can + # take advantage of the fact that before we reach the first + # unrepresentable integer in floating point space, we have the range of + # numbers corresponding to exponent zero: all integers, with no fractional + # amounts. floor/ceil is an identity operation in this case. In the + # range below here, representable floating point numbers are spaced + # exactly 1/2 apart, and notably, both the floor/ceil are defined floating + # point numbers. There is no "gap" as you step up to the next exponent. @classmethod def floor(cls, x): - return ValueRanges.increasing_map(x, sympy.functions.elementary.integers.floor) + return ValueRanges.increasing_map( + x, _keep_float(sympy.functions.elementary.integers.floor) + ) @classmethod def ceil(cls, x): return ValueRanges.increasing_map( - x, sympy.functions.elementary.integers.ceiling + x, _keep_float(sympy.functions.elementary.integers.ceiling) ) @classmethod - def round(cls, number, ndigits=None): - if ndigits is None: - fn = Round - else: - assert ndigits.is_singleton() - ndigits = ndigits.lower - # We can't use functools.partial here since sympy doesn't support keyword arguments, but we have to bind - # the second parameter. - fn = lambda number: RoundDecimal(number, ndigits) # type: ignore[misc, assignment] # noqa: E731 + def round_decimal(cls, number, ndigits): + if not ndigits.is_singleton(): + return ValueRanges.unknown() + + ndigits = ndigits.lower + # We can't use functools.partial here since sympy doesn't support keyword arguments, but we have to bind + # the second parameter. + fn = lambda number: RoundDecimal(number, ndigits) # type: ignore[misc, assignment] # noqa: E731 return ValueRanges.increasing_map(number, fn) + @classmethod + def round_to_int(cls, number, dtype): + return ValueRanges.increasing_map(number, RoundToInt) + # It's used in some models on symints @staticmethod def sqrt(x): @@ -709,12 +825,15 @@ def cos(x): @staticmethod def cosh(x): + return ValueRanges(0.0, sympy.oo) + """ x = ValueRanges.wrap(x) if x.lower > 0: return ValueRanges.increasing_map(x, OpaqueUnaryFn_cosh) elif x.upper < 0: return ValueRanges.decreasing_map(x, OpaqueUnaryFn_cosh) return ValueRanges(0.0, sympy.oo) + """ @staticmethod def sin(x): @@ -724,7 +843,8 @@ def sin(x): @staticmethod def sinh(x): - return ValueRanges.increasing_map(x, OpaqueUnaryFn_sinh) + # return ValueRanges.increasing_map(x, OpaqueUnaryFn_sinh) + return ValueRanges(-sympy.oo, sympy.oo) @staticmethod def tan(x): @@ -732,32 +852,37 @@ def tan(x): @staticmethod def tanh(x): - return ValueRanges.increasing_map(x, OpaqueUnaryFn_tanh) + # return ValueRanges.increasing_map(x, OpaqueUnaryFn_tanh) + return ValueRanges(-sympy.oo, sympy.oo) @staticmethod def asin(x): + return ValueRanges(-sympy.oo, sympy.oo) + """ x = ValueRanges.wrap(x) if -1 <= x.lower and x.upper <= 1: return ValueRanges.increasing_map(x, OpaqueUnaryFn_asinh) return ValueRanges.unknown() + """ @staticmethod def acos(x): + return ValueRanges(-sympy.oo, sympy.oo) + """ x = ValueRanges.wrap(x) if -1 <= x.lower and x.upper <= 1: return ValueRanges.decreasing_map(x, OpaqueUnaryFn_acos) return ValueRanges.unknown() + """ @staticmethod def atan(x): - return ValueRanges.increasing_map(x, OpaqueUnaryFn_atan) + return ValueRanges(-sympy.oo, sympy.oo) + # return ValueRanges.increasing_map(x, OpaqueUnaryFn_atan) @staticmethod def trunc(x): - def trunc(x): - return sympy.Integer(x) if x.is_finite else x - - return ValueRanges.increasing_map(x, trunc) + return ValueRanges.increasing_map(x, TruncToFloat) class ValueRangeAnalysis(SymPyValueRangeAnalysis): @@ -792,9 +917,10 @@ def store(self, name, index, value, mode=None): def reduction(self, name, dtype, src_dtype, reduction_type, index, value): return ValueRanges.unknown() - def index_expr(self, index, dtype): + @classmethod + def index_expr(cls, index, dtype): assert isinstance(index, ValueRanges) - return index + return cls.to_dtype(index, dtype) @staticmethod def to_dtype(x, dtype: torch.dtype, src_dtype: Optional[torch.dtype] = None): @@ -831,12 +957,15 @@ def cast(x, dtype): @staticmethod def square(x): - return ValueRanges.convex_min_zero_map(x, lambda y: y * y) + return ValueRanges.convex_min_zero_map(x, lambda y: PowByNatural(y, 2)) @staticmethod def neg(x): return ValueRanges.decreasing_map(x, operator.neg) + # TODO: this is slightly inaccurate because truncdiv operates at integer + # precision, but we're going through float truediv which means we can + # potentially lose precision on the bounds @classmethod def truncdiv(cls, a, b): x = cls.truediv(a, b) @@ -857,6 +986,7 @@ def __getattr__(self, name): def bound_sympy( expr: sympy.Expr, ranges: Optional[Dict[sympy.Symbol, ValueRanges]] = None ) -> ValueRanges: + log.debug("bound_sympy(%s, %s)", expr, ranges) if isinstance(expr, sympy.Number): return ValueRanges.wrap(expr) From 4c971932e839fc5da2b91906ad028d4654932bca Mon Sep 17 00:00:00 2001 From: eqy Date: Sun, 9 Jun 2024 06:53:34 +0000 Subject: [PATCH 557/706] [cuDNN][SDPA] Remove `TORCH_CUDNN_SDPA_ENABLED=1`, enable cuDNN SDPA by default on H100 and 2nd on other archs >= sm80 (#125343) Looks like one of the first failures seen is `test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` when `test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` passes. What seems interesting here is that the `torch.compile` version fails while the eager version passes. Not sure what the difference would be here... Nevertheless, is there a recommended mechanism to skip cuDNN SDPA as a backend for this test? CC @drisspg Pull Request resolved: https://github.com/pytorch/pytorch/pull/125343 Approved by: https://github.com/Skylion007 --- aten/src/ATen/Context.h | 2 +- aten/src/ATen/native/cudnn/MHA.cpp | 9 +- aten/src/ATen/native/native_functions.yaml | 4 +- .../ATen/native/transformers/attention.cpp | 2 +- .../native/transformers/cuda/attention.cu | 51 +++++++- .../transformers/cuda/attention_backward.cu | 28 +++- .../native/transformers/cuda/sdp_utils.cpp | 82 ++++++------ .../ATen/native/transformers/sdp_utils_cpp.h | 19 ++- docs/source/backends.rst | 2 + test/test_flop_counter.py | 51 ++++++-- test/test_transformers.py | 121 +++++++++++------- tools/autograd/derivatives.yaml | 6 +- torch/_C/__init__.pyi.in | 1 + torch/backends/cuda/__init__.py | 21 +++ torch/csrc/Module.cpp | 9 ++ torch/testing/_internal/common_cuda.py | 12 +- torch/utils/flop_counter.py | 10 +- 17 files changed, 297 insertions(+), 133 deletions(-) diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index a922bcd5922f..4f6eb0d4a109 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -364,7 +364,7 @@ class TORCH_API Context { bool enabled_flashSDP = true; bool enabled_mem_efficientSDP = true; bool enabled_mathSDP = true; - bool enabled_cudnnSDP = false; + bool enabled_cudnnSDP = true; #ifdef USE_ROCM bool benchmark_cudnn = true; #else diff --git a/aten/src/ATen/native/cudnn/MHA.cpp b/aten/src/ATen/native/cudnn/MHA.cpp index 1f6bdbf5305a..4f992098aea8 100644 --- a/aten/src/ATen/native/cudnn/MHA.cpp +++ b/aten/src/ATen/native/cudnn/MHA.cpp @@ -614,6 +614,13 @@ void run_cudnn_SDP_bprop( Tensor& dV, const Tensor& dropoutseed, const Tensor& dropoutoffset) { + Tensor dO_ = dO; + if (!dO.strides()[dO.strides().size() - 1]) { + TORCH_WARN( + "cuDNN SDPA backward got an innermost stride of 0 in grad_out, which is unsupported. Materializing a contiguous\ + tensor which will increase memory usage..."); + dO_ = dO.contiguous(); + } cudnnHandle_t handle = getCudnnHandle(); auto key = MHACacheKeyWrapper( b, h, s_q, s_kv, d, q, k, v, dropout_probability, is_causal, true); @@ -635,7 +642,7 @@ void run_cudnn_SDP_bprop( k, v, o, - dO, + dO_, softmaxstats, dQ, dK, diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index b7314756cec5..7970e17eb960 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -14728,12 +14728,12 @@ CUDA: _scaled_dot_product_efficient_attention_backward_cuda tags: nondeterministic_seeded -- func: _scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) +- func: _scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset) dispatch: CUDA: _scaled_dot_product_cudnn_attention_cuda tags: nondeterministic_seeded -- func: _scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor, Tensor, Tensor) +- func: _scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor) dispatch: CUDA: _scaled_dot_product_cudnn_attention_backward_cuda tags: nondeterministic_seeded diff --git a/aten/src/ATen/native/transformers/attention.cpp b/aten/src/ATen/native/transformers/attention.cpp index 50b47e5b1731..6a83175a15fb 100644 --- a/aten/src/ATen/native/transformers/attention.cpp +++ b/aten/src/ATen/native/transformers/attention.cpp @@ -666,7 +666,7 @@ Tensor scaled_dot_product_attention( case sdp::SDPBackend::cudnn_attention: { bool compute_logsumexp = should_compute_logsumexp(query_, key, value); auto out_lse_softmax = at::_scaled_dot_product_cudnn_attention( - query_, key, value, dropout_p, is_causal, compute_logsumexp, scale); + query_, key, value, compute_logsumexp, dropout_p, is_causal, scale); return std::get<0>(out_lse_softmax); } case sdp::SDPBackend::flash_attention: { diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu index 1a5dbe3a6911..655efeec5b42 100644 --- a/aten/src/ATen/native/transformers/cuda/attention.cu +++ b/aten/src/ATen/native/transformers/cuda/attention.cu @@ -735,14 +735,27 @@ std::tuple _scaled_dot_product_cudnn_attention_cuda( +// Adapted from TE +// extract seed and offset from PhiloxCudaState +__global__ void unpack_cudnn(at::PhiloxCudaState arg, int64_t* seed_ptr, int64_t* offset_ptr) { + if (arg.captured_) { + *seed_ptr = static_cast(*arg.seed_.ptr); + *offset_ptr = static_cast( + *(arg.offset_.ptr) + static_cast(arg.offset_intragraph_)); + } else { + *seed_ptr = static_cast(arg.seed_.val); + *offset_ptr = static_cast(arg.offset_.val); + } +} + +std::tuple _scaled_dot_product_cudnn_attention_cuda( const Tensor& query, const Tensor& key, const Tensor& value, + bool compute_logsumexp, double dropout_p, bool is_causal, - bool training, - std::optional scale) { + c10::optional scale) { // Used for tracking usage statistics C10_LOG_API_USAGE_ONCE("torch.sdpa.flash_attention_cudnn"); // Query (Batch x Num_heads x Q_seq_len x Dim_per_head) @@ -761,9 +774,33 @@ std::tuple( + c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); + + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + // if using dropout, we produce 1 random number for each element of the + // attention tensor + // TODO(eqy): should state be advanced per thread (local) amount or per call/launch (global) amount + philox_state = gen->philox_cuda_state(batch_size * num_heads * max_seqlen_batch_q * max_seqlen_batch_k); + unpack_cudnn<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>( + philox_state, static_cast(cudnn_seed.data_ptr()), static_cast(cudnn_offset.data_ptr())); + } + const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked(); + Tensor debugmask; run_cudnn_SDP_fprop(batch_size/*int64_t b*/, num_heads/*int64_t h*/, @@ -771,7 +808,7 @@ std::tuple _scaled_dot_product_efficient_attention_cuda( diff --git a/aten/src/ATen/native/transformers/cuda/attention_backward.cu b/aten/src/ATen/native/transformers/cuda/attention_backward.cu index af9da7b8835b..bc0ce3d25c03 100644 --- a/aten/src/ATen/native/transformers/cuda/attention_backward.cu +++ b/aten/src/ATen/native/transformers/cuda/attention_backward.cu @@ -171,18 +171,32 @@ std::tuple _scaled_dot_product_cudnn_attention_backward_ const Tensor& value, const Tensor& out, const Tensor& logsumexp, - const Tensor& cumulative_sequence_length_q, - const Tensor& cumulative_sequence_length_k, - const int64_t max_seqlen_batch_q, - const int64_t max_seqlen_batch_k, - double dropout_p, - bool is_causal, const Tensor& philox_seed, const Tensor& philox_offset, - std::optional scale) { +// const Tensor& cumulative_sequence_length_q, +// const Tensor& cumulative_sequence_length_k, +// const int64_t max_seqlen_batch_q, +// const int64_t max_seqlen_batch_k, + double dropout_p, + bool is_causal, + c10::optional scale) { + + + auto& ctx = at::globalContext(); + if (ctx.deterministicAlgorithms()) { + if (ctx.deterministicAlgorithmsWarnOnly()) { + TORCH_WARN_ONCE( + "cuDNN Attention defaults to a non-deterministic algorithm. ", + "To explicitly enable determinism call torch.use_deterministic_algorithms(True, warn_only=False)."); + } + } + + const int64_t batch_size = query.size(0); const int64_t num_heads = query.size(1); const int64_t head_dim = query.size(3); + const int64_t max_seqlen_batch_q = query.size(1); + const int64_t max_seqlen_batch_k = key.size(1); const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked(); diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp index b474e4ee2a7c..389c08b152ba 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -44,14 +45,28 @@ namespace sdp { namespace { + +// TODO(eqy): more benchmarking to determine whether this should include sm86/89 +// Needs to be kept in-sync with test_fused_chocie in test_transformers.py +bool check_prefer_cudnn_attention() { + auto dprops = at::cuda::getCurrentDeviceProperties(); + return dprops->major >= 9; +} + // flash_attention V2 is universally faster than efficient_attention and Math std::array priority_order(sdp_params const& params) { constexpr std::array default_order{ + SDPBackend::flash_attention, + SDPBackend::cudnn_attention, + SDPBackend::efficient_attention, + SDPBackend::math}; + constexpr std::array cudnn_order{ SDPBackend::cudnn_attention, SDPBackend::flash_attention, SDPBackend::efficient_attention, SDPBackend::math}; - return default_order; + static const bool prefer_cudnn = check_prefer_cudnn_attention(); + return prefer_cudnn ? cudnn_order : default_order; } bool use_tensor_cores(sdp_params const& params, cudaDeviceProp* dprops, bool is_half) { @@ -451,17 +466,6 @@ bool check_cudnn_hardware_support(sdp_params const& params, bool debug) { return true; } -bool check_is_causal(sdp_params const& params, bool debug) { - // Check that the input is causal - if (!params.is_causal) { - if (debug) { - TORCH_WARN("CuDNN requires is_causal=True."); - } - return false; - } - return true; -} - bool check_for_nested_inputs(sdp_params const& params, bool debug) { // Check that the input is nested if (has_for_nested_inputs(params)) { @@ -485,22 +489,6 @@ bool check_dtypes_low_precision(sdp_params const& params, bool debug) { } } -bool check_runtime_enabled_cudnn(sdp_params const& params, bool debug) { - static c10::once_flag supported_flag; - static bool supported = false; - c10::call_once(supported_flag, []() { - supported = (c10::utils::check_env("TORCH_CUDNN_SDPA_ENABLED") == true); - }); - if (!supported) { - if (debug) { - TORCH_WARN( - "The CuDNN backend needs to be enabled by setting the enviornment variable`TORCH_CUDNN_SDPA_ENABLED=1`"); - } - return false; - } - return true; -} - bool check_runtime_disabled_cudnn(sdp_params const& params, bool debug) { // We check the global context to see if user has explicitly turned of cudnn // sdp kernels @@ -513,13 +501,15 @@ bool check_runtime_disabled_cudnn(sdp_params const& params, bool debug) { return true; } -bool check_cudnn_requires_grad(sdp_params const& params, bool debug) { - // Check that the input is causal - if (input_requires_grad(params)) { - if (debug) { - TORCH_WARN("CuDNN does not currently support inputs with requires_grad=True."); +bool check_cudnn_deterministic(const sdp_params& params, bool debug) { + auto& ctx = at::globalContext(); + if (ctx.deterministicAlgorithms()) { + if (!ctx.deterministicAlgorithmsWarnOnly()) { + if (debug) { + TORCH_WARN("cuDNN SDPA is not deterministic."); + } + return false; } - return false; } return true; } @@ -527,21 +517,29 @@ bool check_cudnn_requires_grad(sdp_params const& params, bool debug) { } // namespace bool can_use_cudnn_attention(const sdp_params& params, bool debug) { - +#if defined(USE_ROCM) || !AT_CUDNN_ENABLED() || \ + (defined(CUDNN_VERSION) && CUDNN_VERSION < 8900) + TORCH_WARN_ONCE(!debug, "Torch was not compiled with cuDNN attention."); + return false; +#endif // Define gate functions that determine if a flash kernel can be ran // Replace with std::to_array when we migrate to c++20 constexpr auto general_constraints = array_of( - check_runtime_enabled_cudnn, - check_runtime_disabled_cudnn, - check_cudnn_hardware_support, + check_for_nested_inputs, + check_nonzero_sequence_lengths_dense, + check_last_dim_stride_equals_1_dense*/>, check_all_tensors_on_device, + check_tensor_shapes, check_cudnn_tensor_shapes, - check_cudnn_layout, + check_runtime_disabled_cudnn, + check_cudnn_deterministic, + // check_cudnn_layout, // check_is_causal, - check_for_nested_inputs, - check_cudnn_requires_grad, - check_dtypes_low_precision); + check_dtypes_low_precision, + check_for_attn_mask_cudnn, + check_cudnn_hardware_support + ); for (auto& constraint : general_constraints) { if (!constraint(params, debug)) { return false; diff --git a/aten/src/ATen/native/transformers/sdp_utils_cpp.h b/aten/src/ATen/native/transformers/sdp_utils_cpp.h index 7c56a1f617db..70d9be903ce9 100644 --- a/aten/src/ATen/native/transformers/sdp_utils_cpp.h +++ b/aten/src/ATen/native/transformers/sdp_utils_cpp.h @@ -266,7 +266,18 @@ inline bool check_requires_grad_and_nested(sdp_params const& params, bool debug) inline bool check_for_attn_mask(sdp_params const& params, bool debug) { if (params.attn_mask.has_value()) { if (debug) { - TORCH_WARN("Flash Attention does not support non-null attn_mask."); + TORCH_WARN("Flash Attention do not support non-null attn_mask."); + } + return false; + } + return true; +} + +// TODO(eqy): remove this once support is added +inline bool check_for_attn_mask_cudnn(sdp_params const& params, bool debug) { + if (params.attn_mask.has_value()) { + if (debug) { + TORCH_WARN("cuDNN Attention does not support non-null attn_mask."); } return false; } @@ -313,7 +324,7 @@ inline bool check_tensor_shapes(sdp_params const& params, bool debug) { (query_dim == 4))) { if (debug) { TORCH_WARN( - "Both fused kernels requires query, key and value to be 4 dimensional, but got Query dim: ", + "All fused kernels requires query, key and value to be 4 dimensional, but got Query dim: ", query_dim, ", Key dim: ", params.key.dim(), @@ -425,7 +436,7 @@ inline bool check_nonzero_sequence_lengths_dense(sdp_params const& params, bool if (zero_seq_len_q || zero_seq_len_k) { if (debug) { TORCH_WARN( - "Both fused kernels do not support zero seq_len_q or seq_len_kv."); + "All fused kernels do not support zero seq_len_q or seq_len_kv."); } return false; } @@ -460,7 +471,7 @@ inline bool check_last_dim_stride_equals_1_dense(sdp_params const& params, bool } epilogue_message << " instead."; TORCH_WARN( - "Both fused kernels require the last dimension of the input to have stride 1. ", + "All fused kernels require the last dimension of the input to have stride 1. ", "Got Query.stride(-1): ", params.query.sym_stride(-1), ", Key.stride(-1): ", diff --git a/docs/source/backends.rst b/docs/source/backends.rst index ef3c720e8335..bd83e49f5f2d 100644 --- a/docs/source/backends.rst +++ b/docs/source/backends.rst @@ -92,6 +92,8 @@ torch.backends.cuda .. autofunction:: torch.backends.cuda.can_use_efficient_attention +.. autofunction:: torch.backends.cuda.can_use_cudnn_attention + .. autofunction:: torch.backends.cuda.sdp_kernel torch.backends.cudnn diff --git a/test/test_flop_counter.py b/test/test_flop_counter.py index 3f09a85a6d97..9f3a8ce223e5 100644 --- a/test/test_flop_counter.py +++ b/test/test_flop_counter.py @@ -9,6 +9,7 @@ from torch.testing._internal.common_cuda import ( PLATFORM_SUPPORTS_FLASH_ATTENTION, PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, + PLATFORM_SUPPORTS_CUDNN_ATTENTION ) from torch.testing._internal.common_utils import ( run_tests, @@ -300,7 +301,8 @@ def test_noop(self): @unittest.skipIf(not HAS_CUDA, "CUDA not available") @unittest.skipIf( not PLATFORM_SUPPORTS_FLASH_ATTENTION - or not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, + or not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION + or not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "Does not support all SDPA backends (pre-SM80 hardware on CUDA)", ) def test_sdpa(self): @@ -355,15 +357,31 @@ def get_flops( if backend == "math": backend = torch.backends.cuda.sdp_kernel( - enable_flash=False, enable_math=True, enable_mem_efficient=False + enable_flash=False, + enable_math=True, + enable_mem_efficient=False, + enable_cudnn=False, ) elif backend == "flash": backend = torch.backends.cuda.sdp_kernel( - enable_flash=True, enable_math=False, enable_mem_efficient=False + enable_flash=True, + enable_math=False, + enable_mem_efficient=False, + enable_cudnn=False, ) elif backend == "mem_efficient": backend = torch.backends.cuda.sdp_kernel( - enable_flash=False, enable_math=False, enable_mem_efficient=True + enable_flash=False, + enable_math=False, + enable_mem_efficient=True, + enable_cudnn=False, + ) + elif backend == "cudnn": + backend = torch.backends.cuda.sdp_kernel( + enable_flash=False, + enable_math=False, + enable_mem_efficient=False, + enable_cudnn=True, ) mode = FlopCounterMode() @@ -389,22 +407,24 @@ def get_flops( flops = [ run_uniform_flops(backend, with_backward=False) - for backend in ["math", "flash", "mem_efficient"] + for backend in ["math", "flash", "mem_efficient", "cudnn"] ] - flops_fw_math, flops_fw_flash, flops_fw_efficient = flops + flops_fw_math, flops_fw_flash, flops_fw_efficient, flops_fw_cudnn = flops self.assertEqual(flops_fw_math, flops_fw_flash) self.assertEqual(flops_fw_math, flops_fw_efficient) + self.assertEqual(flops_fw_math, flops_fw_cudnn) self.assertExpectedInline(str(flops_fw_math), """134217728""") flops = [ run_uniform_flops(backend, with_backward=True) - for backend in ["math", "flash", "mem_efficient"] + for backend in ["math", "flash", "mem_efficient", "cudnn"] ] - flops_fw_bw_math, flops_fw_bw_flash, flops_fw_bw_efficient = flops + flops_fw_bw_math, flops_fw_bw_flash, flops_fw_bw_efficient, flops_fw_bw_cudnn = flops self.assertEqual(flops_fw_math * 3, flops_fw_bw_math) self.assertEqual(flops_fw_math * 7 // 2, flops_fw_bw_flash) self.assertEqual(flops_fw_bw_flash, flops_fw_bw_efficient) + self.assertEqual(flops_fw_bw_flash, flops_fw_bw_cudnn) run_nonuniform_flops = functools.partial( get_flops, @@ -448,15 +468,24 @@ def get_flops(q, k, v, backend, with_backward=False): if backend == "math": backend = torch.backends.cuda.sdp_kernel( - enable_flash=False, enable_math=True, enable_mem_efficient=False + enable_flash=False, + enable_math=True, + enable_mem_efficient=False, + enable_cudnn=False, ) elif backend == "flash": backend = torch.backends.cuda.sdp_kernel( - enable_flash=True, enable_math=False, enable_mem_efficient=False + enable_flash=True, + enable_math=False, + enable_mem_efficient=False, + enable_cudnn=False, ) elif backend == "mem_efficient": backend = torch.backends.cuda.sdp_kernel( - enable_flash=False, enable_math=False, enable_mem_efficient=True + enable_flash=False, + enable_math=False, + enable_mem_efficient=True, + enable_cudnn=False, ) with backend, mode: diff --git a/test/test_transformers.py b/test/test_transformers.py index 774cb60ee94d..fdf64f11aed6 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -43,7 +43,8 @@ IS_JETSON, SM80OrLater, PLATFORM_SUPPORTS_FLASH_ATTENTION, PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, PLATFORM_SUPPORTS_FUSED_ATTENTION, - PLATFORM_SUPPORTS_CUDNN_ATTENTION + PLATFORM_SUPPORTS_CUDNN_ATTENTION, + tf32_on_and_off ) if TEST_FAIRSEQ: @@ -315,6 +316,7 @@ def test_transformerencoderlayer_src_mask(self, device, nhead): with torch.no_grad(): model(src, src_mask=src_mask) + @tf32_on_and_off(0.001) @parametrize("use_torchscript", [False]) @parametrize("enable_nested_tensor", [True, False]) @parametrize("use_autocast", [True, False]) @@ -405,8 +407,9 @@ def test_transformerencoder_fastpath(self, device, use_torchscript, enable_neste # no garauntees on output corresponding to masked tokens, so they may vary between slow/fast path. set all to 0. fastpath_output_expanded = fastpath_output_expanded.masked_fill(src_key_padding_mask.unsqueeze(-1), 0) slowpath_output = slowpath_output.masked_fill(src_key_padding_mask.unsqueeze(-1), 0) - torch.testing.assert_close(fastpath_output_expanded, slowpath_output, rtol=1e-7, atol=1e-5) + self.assertEqual(fastpath_output_expanded, slowpath_output) + @tf32_on_and_off(0.001) @parametrize("with_no_grad", [True, False]) @parametrize("training", [True, False]) @parametrize("enable_nested_tensor", [False]) @@ -450,7 +453,7 @@ def test_transformerencoder_square_input(self, with_no_grad, training, enable_ne [2.419836044311523, 0.017548924311996, -0.608187675476074, -0.085347734391689]]] ).to(device) self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) - torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5) + self.assertEqual(result, ref_output) @parametrize("batch_first", [True, False]) @parametrize("training", [True, False]) @@ -1397,7 +1400,7 @@ def test_invalid_fused_inputs_dim_3(self, device, kernel: SDPBackend): q = torch.randn(size, device=device, dtype=dtype) k = torch.randn(size, device=device, dtype=dtype) v = torch.randn(size, device=device, dtype=dtype) - with self.assertWarnsRegex(UserWarning, "Both fused kernels requires query, key and value to be 4 dimensional"): + with self.assertWarnsRegex(UserWarning, "All fused kernels requires query, key and value to be 4 dimensional"): self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( q, k, v, None, 0.0, False)) @@ -1429,7 +1432,7 @@ def test_invalid_sequence_lengths(self, device, kernel: SDPBackend): make_tensor = partial(torch.rand, device=device, dtype=dtype) size = SdpaShape(2, 2, 0, 8) q, k, v = make_tensor(size), make_tensor(size), make_tensor(size) - with self.assertWarnsRegex(UserWarning, "Both fused kernels do not support zero seq_len_q or seq_len_kv."): + with self.assertWarnsRegex(UserWarning, "All fused kernels do not support zero seq_len_q or seq_len_kv."): self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( q, k, v, None, 0.0, False)) @@ -1444,7 +1447,7 @@ def test_invalid_last_dim_stride(self, device, kernel: SDPBackend): size = SdpaShape(2, 2, 8, 8) q, k, v = make_tensor(size), make_tensor(size), make_tensor(size) q.as_strided_(size, [2, 2, 2, 2]) - with self.assertWarnsRegex(UserWarning, "Both fused kernels require the last dimension of the input to have stride 1."): + with self.assertWarnsRegex(UserWarning, "All fused kernels require the last dimension of the input to have stride 1."): self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( q, k, v, None, 0.0, False)) @@ -2353,7 +2356,7 @@ def rand_tensor(shape): math_ref_lp_test = math_ref_lp_test.to(dtype=torch.float32).contiguous() self.assertEqual(math_ref_test, math_ref_lp_test, atol=7e-3, rtol=7e-3) - self.assertEqual(actual_test, math_ref_test, atol=5e-3, rtol=5e-3) + self.assertEqual(actual_test, math_ref_test, atol=7e-3, rtol=7e-3) @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Efficient Attention was not built for this system") @parametrize("contiguous_inputs", [True, False]) @@ -2471,7 +2474,12 @@ def test_fused_sdp_choice(self, device, type: str): value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) - if PLATFORM_SUPPORTS_FLASH_ATTENTION: + major, minor = torch.cuda.get_device_capability(device) + is_sm90_or_newer = major >= 9 + + if type != "nested" and PLATFORM_SUPPORTS_CUDNN_ATTENTION and is_sm90_or_newer: + assert torch._fused_sdp_choice(query, key, value) == SDPBackend.CUDNN_ATTENTION.value + elif PLATFORM_SUPPORTS_FLASH_ATTENTION: assert torch._fused_sdp_choice(query, key, value) == SDPBackend.FLASH_ATTENTION.value else: assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION.value @@ -2511,7 +2519,8 @@ def test_fused_backwards_throws_determinism_warning(self, device, warn_only, fus make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=torch.float16, packed=False, requires_grad=True) query, key, value = make_tensor(shape), make_tensor(shape), make_tensor(shape) - kernel_name = "Memory Efficient attention" if fused_kernel == SDPBackend.EFFICIENT_ATTENTION else "Flash Attention" + kernel_name = "Memory Efficient attention" if fused_kernel == SDPBackend.EFFICIENT_ATTENTION else \ + "Flash Attention" if fused_kernel == SDPBackend.FLASH_ATTENTION else "cuDNN Attention" warning_context = ( self.assertWarnsRegex( UserWarning, @@ -2523,7 +2532,12 @@ def test_fused_backwards_throws_determinism_warning(self, device, warn_only, fus with use_deterministic_algorithims(True, warn_only=warn_only): with sdpa_kernel(backends=[fused_kernel]): with warning_context: - torch.nn.functional.scaled_dot_product_attention(query, key, value).sum().backward() + if warn_only or fused_kernel != SDPBackend.CUDNN_ATTENTION: + torch.nn.functional.scaled_dot_product_attention(query, key, value).sum().backward() + else: + # cuDNN attention has no deterministic fallback + self.assertRaises(RuntimeError, lambda: + torch.nn.functional.scaled_dot_product_attention(query, key, value).sum().backward()) @unittest.skip("This test is not behaving deterministaclly non-deterministaclly on CI/CD") @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Platform does not support fused SDPA") @@ -2663,7 +2677,7 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref) # Fudge Factor when dropout is enabled - dropout_fudge_factor = 1.0 if dropout_p == 0.0 else 2.0 + dropout_fudge_factor = 1.5 if dropout_p == 0.0 else 2.0 query_fudge_factor = dropout_fudge_factor grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(query_ref.grad, query_ref_lp.grad, query_fudge_factor) @@ -2786,8 +2800,8 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, # Fudge Factor when dropout is enabled dropout_fudge_factor = 1.0 if dropout_p == 0.0 else 1.75 mask_fudge_factor = 1.0 if attn_mask is None else 1.5 + query_fudge_factor = 2.0 - query_fudge_factor = dropout_fudge_factor grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(query_ref.grad, query_ref_lp.grad, query_fudge_factor) # TODO: Investigate why grad_k needs larger tolerances @@ -2989,7 +3003,8 @@ def get_dropout_mask(output, fused_kernel, batch_size, n_heads, q_len, kv_len, d device=device, dtype=dtype, requires_grad=True) fused_op = (torch.ops.aten._scaled_dot_product_efficient_attention - if fused_kernel == SDPBackend.EFFICIENT_ATTENTION else torch.ops.aten._scaled_dot_product_flash_attention) + if fused_kernel == SDPBackend.EFFICIENT_ATTENTION else torch.ops.aten._scaled_dot_product_flash_attention + if fused_kernel == SDPBackend.FLASH_ATTENTION else torch.ops.aten._scaled_dot_product_cudnn_attention) # Run the math kernel on low precision references query_ref_lp, key_ref_lp, value_ref_lp = query_key_value_clones(query, key, value, dtype=dtype) @@ -3007,6 +3022,10 @@ def get_dropout_mask(output, fused_kernel, batch_size, n_heads, q_len, kv_len, d kwargs["attn_bias"] = None if fused_kernel == SDPBackend.FLASH_ATTENTION: kwargs['return_debug_mask'] = dropout_p > 0.0 + if fused_kernel == SDPBackend.CUDNN_ATTENTION: + kwargs["compute_log_sumexp"] = True + if "return_debug_mask" in kwargs: + kwargs.pop("return_debug_mask") with torch.cuda.stream(s): # Create real output output_tuple = fused_op(query, key, value, **kwargs) @@ -3044,7 +3063,8 @@ def get_dropout_mask(output, fused_kernel, batch_size, n_heads, q_len, kv_len, d # Low Precision Math Reference out_lp_ref = F.scaled_dot_product_attention(query_ref_lp, key_ref_lp, value_ref_lp, dropout_p=dropout_p, is_causal=is_causal, scale=scale) - else: + # cuDNN attention doesn't support returning dropout mask + elif fused_kernel != SDPBackend.CUDNN_ATTENTION: # Create the dropout_mask dropout_mask = get_dropout_mask(output_tuple, fused_kernel, batch_size, n_heads, seq_len_q, seq_len_k, dropout_p, device) @@ -3062,37 +3082,38 @@ def get_dropout_mask(output, fused_kernel, batch_size, n_heads, q_len, kv_len, d with torch.cuda.graph(g1): out.backward(upstream_grad) g1.replay() - out_ref.backward(upstream_grad.to(out_ref.dtype)) - out_lp_ref.backward(upstream_grad.to(out_lp_ref.dtype)) - - # [Note] Fused Tolerances - # Establish the numerical error between the "true" high precision math output - # and the low precision math reference. We use this reference for the atol - # And we use the default rtol for the low precision type. - # We then provide a fudge factor for gradients respectively to account - # for the use of the fused kernel rather than the eager implemntation. - output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref) - - # Fudge Factor when dropout is enabled - dropout_fudge_factor = 1.0 if dropout_p == 0.0 else 1.5 - - query_fudge_factor = dropout_fudge_factor - grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(query_ref.grad, query_ref_lp.grad, query_fudge_factor) - - # TODO: Investigate why grad_k needs larger tolerances - key_fudge_factor = 8 * dropout_fudge_factor - grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(key_ref.grad, key_ref_lp.grad, key_fudge_factor) - - value_fudge_factor = 7 if not SM80OrLater and dtype == torch.float16 else 1.0 - grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(value_ref.grad, value_ref_lp.grad, value_fudge_factor) - - self.assertEqual(out, out_ref.to(out.dtype), atol=output_ref_atol, rtol=output_ref_rtol) - self.assertEqual(query.grad, query_ref.grad.to(query.grad.dtype), - atol=grad_q_ref_atol, rtol=grad_q_ref_rtol) - self.assertEqual(key.grad, key_ref.grad.to(key.grad.dtype), - atol=grad_k_ref_atol, rtol=grad_k_ref_rtol) - self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype), - atol=grad_v_ref_atol, rtol=grad_v_ref_rtol) + if fused_kernel != SDPBackend.CUDNN_ATTENTION or dropout_p == 0.0: + out_ref.backward(upstream_grad.to(out_ref.dtype)) + out_lp_ref.backward(upstream_grad.to(out_lp_ref.dtype)) + + # [Note] Fused Tolerances + # Establish the numerical error between the "true" high precision math output + # and the low precision math reference. We use this reference for the atol + # And we use the default rtol for the low precision type. + # We then provide a fudge factor for gradients respectively to account + # for the use of the fused kernel rather than the eager implemntation. + output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref) + + # Fudge Factor when dropout is enabled + dropout_fudge_factor = 1.0 if dropout_p == 0.0 else 1.5 + + query_fudge_factor = dropout_fudge_factor + grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(query_ref.grad, query_ref_lp.grad, query_fudge_factor) + + # TODO: Investigate why grad_k needs larger tolerances + key_fudge_factor = 8 * dropout_fudge_factor + grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(key_ref.grad, key_ref_lp.grad, key_fudge_factor) + + value_fudge_factor = 7 if not SM80OrLater and dtype == torch.float16 else 1.0 + grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(value_ref.grad, value_ref_lp.grad, value_fudge_factor) + + self.assertEqual(out, out_ref.to(out.dtype), atol=output_ref_atol, rtol=output_ref_rtol) + self.assertEqual(query.grad, query_ref.grad.to(query.grad.dtype), + atol=grad_q_ref_atol, rtol=grad_q_ref_rtol) + self.assertEqual(key.grad, key_ref.grad.to(key.grad.dtype), + atol=grad_k_ref_atol, rtol=grad_k_ref_rtol) + self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype), + atol=grad_v_ref_atol, rtol=grad_v_ref_rtol) @skipIfRocm # Nested Tensor @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system") @@ -3216,7 +3237,7 @@ def _broadcast(t, batch_broadcasted, num_heads_broadcasted): query_expanded.contiguous(), key_expanded.contiguous(), value_expanded.contiguous(), attn_mask=None, dropout_p=0.0, is_causal=False) - self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2) + self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1.5e-3, rtol=1e-2) @skipIfRocm # Nested tensor @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system") @@ -3379,6 +3400,7 @@ def run_test( forw_tolerances: Optional[Tolerances] = None, grad_tolerances: Optional[Tolerances] = None, backend=None, + causal_variant=None, ): if backend is not None: torch._dynamo.reset() @@ -3446,9 +3468,11 @@ def test_causal_variants(self, device, causal_variant: CausalVariant, shape: Lis if causal_variant == CausalVariant.UPPER_LEFT: attn_bias = causal_upper_left(seq_len_q, seq_len_kv) else: + print(seq_len_q, seq_len_kv) attn_bias = causal_lower_right(seq_len_q, seq_len_kv) - self.run_test(device, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol, backend=None) + with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION, SDPBackend.MATH]): + self.run_test(device, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol, backend=None) @skipIfRocm # CausalVariant @parametrize("causal_variant", [CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT]) @@ -3479,7 +3503,8 @@ def test_causal_variants_compile(self, device, causal_variant: CausalVariant, sh else: attn_bias = causal_lower_right(seq_len_q, seq_len_kv) - self.run_test(device, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol, backend=cnts) + with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION, SDPBackend.MATH]): + self.run_test(device, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol, backend=cnts) self.assertEqual(cnts.frame_count, 1, "Compiled graph should have 1 frame!") @parametrize("shape", [(16, 16, 128, 128, 16), (16, 16, 128, 256, 32), (16, 16, 256, 128, 32), (1, 1, 23, 56, 15)]) diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 1e9b9091a20e..81bd19b8e185 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -2824,9 +2824,9 @@ output_differentiability: [True, False, False, False, False, False] query, key, value, bias: _efficient_attention_backward_symint(grad, query, key, value, bias, output, cu_seqlens_q, cu_seqlens_k, max_seqlen_batch_q, max_seqlen_batch_k, logsumexp, dropout_p, philox_seed, philox_offset, custom_mask_type, bias.requires_grad(), scale) -- name: _scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) - output_differentiability: [True, False, False, False, False, False, False, False, False] - query, key, value: _scaled_dot_product_cudnn_attention_backward_symint(grad, query, key, value, output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale) +- name: _scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset) + output_differentiability: [True, False, False, False] + query, key, value: _scaled_dot_product_cudnn_attention_backward(grad, query, key, value, output, logsumexp, philox_seed, philox_offset, dropout_p, is_causal, scale) # fft - name: _fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 4326cd3c71da..6f719d7db20a 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -1936,6 +1936,7 @@ class _SDPBackend(Enum): EFFICIENT_ATTENTION = 2 CUDNN_ATTENTION = 3 +def _can_use_cudnn_attention(params: _SDPAParams, debug: _bool) -> _bool: ... def _can_use_flash_attention(params: _SDPAParams, debug: _bool) -> _bool: ... def _can_use_mem_efficient_attention(params: _SDPAParams, debug: _bool) -> _bool: ... diff --git a/torch/backends/cuda/__init__.py b/torch/backends/cuda/__init__.py index cb5f511bc5db..00f511a544e6 100644 --- a/torch/backends/cuda/__init__.py +++ b/torch/backends/cuda/__init__.py @@ -27,6 +27,7 @@ "enable_math_sdp", "can_use_flash_attention", "can_use_efficient_attention", + "can_use_cudnn_attention", "sdp_kernel", ] @@ -359,6 +360,26 @@ def can_use_efficient_attention(params: SDPAParams, debug: bool = False) -> bool return torch._C._can_use_mem_efficient_attention(params, debug) +def can_use_cudnn_attention(params: SDPAParams, debug: bool = False) -> bool: + r"""Check if cudnn_attention can be utilized in scaled_dot_product_attention. + + Args: + params: An instance of SDPAParams containing the tensors for query, + key, value, an optional attention mask, dropout rate, and + a flag indicating if the attention is causal. + debug: Whether to logging.warn with information as to why cuDNN attention could not be run. + Defaults to False. + + Returns: + True if cuDNN can be used with the given parameters; otherwise, False. + + Note: + This function is dependent on a CUDA-enabled build of PyTorch. It will return False + in non-CUDA environments. + """ + return torch._C._can_use_cudnn_attention(params, debug) + + def cudnn_sdp_enabled(): r""" .. warning:: This flag is beta and subject to change. diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index dbd58657b951..7e509fce7af2 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -1918,6 +1918,15 @@ Call this whenever a new thread is created in order to propagate values from return sdp::can_use_mem_efficient_attention(params, debug); #else return false; +#endif + }); + py_module.def( + "_can_use_cudnn_attention", + [](const sdp::sdp_params& params, bool debug) { +#ifdef USE_CUDA + return sdp::can_use_cudnn_attention(params, debug); +#else + return false; #endif }); diff --git a/torch/testing/_internal/common_cuda.py b/torch/testing/_internal/common_cuda.py index 189be09d8ba9..02b38bf9351a 100644 --- a/torch/testing/_internal/common_cuda.py +++ b/torch/testing/_internal/common_cuda.py @@ -50,6 +50,9 @@ def evaluate_platform_supports_flash_attention(): return not IS_WINDOWS and SM80OrLater return False +def evaluate_platform_supports_cudnn_attention(): + return (not TEST_WITH_ROCM) and (not IS_WINDOWS) and TEST_CUDA and SM80OrLater + def evaluate_platform_supports_efficient_attention(): if TEST_WITH_ROCM: return evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-') or evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-') @@ -59,11 +62,12 @@ def evaluate_platform_supports_efficient_attention(): PLATFORM_SUPPORTS_FLASH_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_flash_attention()) PLATFORM_SUPPORTS_MEM_EFF_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_efficient_attention()) -# TODO(eqy): gate this against a cuDNN version -PLATFORM_SUPPORTS_CUDNN_ATTENTION: bool = LazyVal(lambda: TEST_CUDA and not TEST_WITH_ROCM and - torch.backends.cuda.cudnn_sdp_enabled()) +PLATFORM_SUPPORTS_CUDNN_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_cudnn_attention()) + # This condition always evaluates to PLATFORM_SUPPORTS_MEM_EFF_ATTENTION but for logical clarity we keep it separate -PLATFORM_SUPPORTS_FUSED_ATTENTION: bool = LazyVal(lambda: PLATFORM_SUPPORTS_FLASH_ATTENTION or PLATFORM_SUPPORTS_MEM_EFF_ATTENTION) +PLATFORM_SUPPORTS_FUSED_ATTENTION: bool = LazyVal(lambda: PLATFORM_SUPPORTS_FLASH_ATTENTION or + PLATFORM_SUPPORTS_CUDNN_ATTENTION or + PLATFORM_SUPPORTS_MEM_EFF_ATTENTION) PLATFORM_SUPPORTS_FUSED_SDPA: bool = TEST_CUDA and not TEST_WITH_ROCM diff --git a/torch/utils/flop_counter.py b/torch/utils/flop_counter.py index 93c1cf78e710..a4f05c6c720b 100644 --- a/torch/utils/flop_counter.py +++ b/torch/utils/flop_counter.py @@ -243,7 +243,9 @@ def sdpa_flop_count(query_shape, key_shape, value_shape): return total_flops -@register_flop_formula([aten._scaled_dot_product_efficient_attention, aten._scaled_dot_product_flash_attention]) +@register_flop_formula([aten._scaled_dot_product_efficient_attention, + aten._scaled_dot_product_flash_attention, + aten._scaled_dot_product_cudnn_attention]) def sdpa_flop(query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> int: """Count flops for self-attention.""" # NB: We aren't accounting for causal attention here @@ -435,7 +437,9 @@ def sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape return total_flops -@register_flop_formula([aten._scaled_dot_product_efficient_attention_backward, aten._scaled_dot_product_flash_attention_backward]) +@register_flop_formula([aten._scaled_dot_product_efficient_attention_backward, + aten._scaled_dot_product_flash_attention_backward, + aten._scaled_dot_product_cudnn_attention_backward]) def sdpa_backward_flop(grad_out_shape, query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> int: """Count flops for self-attention backward.""" return sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape) @@ -516,8 +520,10 @@ def _efficient_attention_backward_flop( aten.convolution_backward: conv_backward_flop, aten._scaled_dot_product_efficient_attention: sdpa_flop, aten._scaled_dot_product_flash_attention: sdpa_flop, + aten._scaled_dot_product_cudnn_attention: sdpa_flop, aten._scaled_dot_product_efficient_attention_backward: sdpa_backward_flop, aten._scaled_dot_product_flash_attention_backward: sdpa_backward_flop, + aten._scaled_dot_product_cudnn_attention_backward: sdpa_backward_flop, aten._flash_attention_forward: _flash_attention_forward_flop, aten._efficient_attention_forward: _efficient_attention_forward_flop, aten._flash_attention_backward: _flash_attention_backward_flop, From 75b0720a97ac5d82e8a7a1a6ae7c5f7a87d7183d Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sun, 9 Jun 2024 09:05:17 +0000 Subject: [PATCH 558/706] Revert "Use hidden visibility in OBJECTCXX files (#127265)" This reverts commit 669560d51aa1e81ebd09e2aa8288d0d314407d82. Reverted https://github.com/pytorch/pytorch/pull/127265 on behalf of https://github.com/huydhn due to Sorry for reverting your change, but I suspect that it causes this failure https://github.com/pytorch/vision/issues/8478 on vision where its C++ extension could not be loaded on macOS ([comment](https://github.com/pytorch/pytorch/pull/127265#issuecomment-2156401838)) --- cmake/public/utils.cmake | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/cmake/public/utils.cmake b/cmake/public/utils.cmake index 0f5da8e6cae2..c4adccf3b61b 100644 --- a/cmake/public/utils.cmake +++ b/cmake/public/utils.cmake @@ -479,9 +479,7 @@ function(torch_compile_options libname) # templated classes crossing library boundary get duplicated (but identical) # definitions. It's easier to just disable it. target_compile_options(${libname} PRIVATE - $<$: -fvisibility=hidden> - $<$: -fvisibility=hidden> - $<$: -fvisibility=hidden>) + $<$: -fvisibility=hidden>) endif() # Use -O2 for release builds (-O3 doesn't improve perf, and -Os results in perf regression) From 0bf2fe522ad7dde7a6a226971f4a31e9479bf46c Mon Sep 17 00:00:00 2001 From: Chirag Pandya Date: Thu, 6 Jun 2024 10:25:05 -0700 Subject: [PATCH 559/706] [RFC] Provide optional switches to _dump_nccl_trace (#127651) Summary: Data from PyTorch distributed is mostly useful during initial stages of model development. Provide options to reduce data sent/dumped. `_dump_nccl_trace` takes 3 optional switches. Default as before returns everything - `includeCollectives`: option to also include collectives: Default is True. - `includeStacktraces`: option to include stack traces in collectives. Default is True. - `onlyActive`: option to only send active collective work - i.e. not completed. Default is False (i.e. send everything) Test Plan: Unit tests Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/127651 Approved by: https://github.com/wconstab --- test/distributed/test_c10d_nccl.py | 73 ++++++++------- .../distributed/c10d/ProcessGroupNCCL.cpp | 25 +++-- .../distributed/c10d/ProcessGroupNCCL.hpp | 15 ++- torch/csrc/distributed/c10d/TraceUtils.h | 92 ++++++++++++------- torch/csrc/distributed/c10d/init.cpp | 25 ++++- 5 files changed, 152 insertions(+), 78 deletions(-) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index baf2adb1fb2d..21a8a632bade 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -3523,7 +3523,8 @@ class NCCLTraceTest(NCCLTraceTestBase): @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize("timing_enabled", [True, False]) - def test_short(self, timing_enabled): + @parametrize("include_collectives", [True, False]) + def test_short(self, timing_enabled, include_collectives): if self.rank == self.MAIN_PROCESS_RANK: return pg = self._create_process_group_nccl() @@ -3538,8 +3539,14 @@ def test_short(self, timing_enabled): # gah ok so now the duration_ms is populated best-effort since it can only happen outside "dump()" api time.sleep(1) - - t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace()) + if include_collectives: + t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace()) + else: + t = pickle.loads( + torch._C._distributed_c10d._dump_nccl_trace( + includeCollectives=False, includeStackTraces=None, onlyActive=None + ) + ) ver = t["version"] self.assertEqual(ver, "2.1") pg_config = t["pg_config"] @@ -3550,35 +3557,39 @@ def test_short(self, timing_enabled): self.assertIn("ranks", default_pg_info) global_ranks = pg_config["0"]["ranks"] self.assertEqual(len(json.loads(global_ranks)), self.world_size) - t = t["entries"] - self.assertEqual(len(t), 2) - last = t[-1] - self.assertEqual(last["process_group"], ("0", "default_pg")) - self.assertEqual(last["state"], "completed") - s = last["time_discovered_started_ns"] - f = last["time_discovered_completed_ns"] - self.assertEqual(last["record_id"], 1) - self.assertIsNotNone(f) - if timing_enabled: - self.assertIsNotNone(s) - self.assertTrue(s <= f) - self.assertIn("test_c10d_nccl.py", str(last["frames"])) - self.assertEqual(last["input_sizes"], ((3, 4),)) - self.assertEqual(last["input_dtypes"], ["Float"]) - self.assertEqual(last["output_sizes"], ((3, 4),)) - self.assertEqual(last["output_dtypes"], ["Float"]) - self.assertEqual(last["collective_seq_id"], 2) - now = datetime.now() - event_created_time = datetime.fromtimestamp( - last["time_created_ns"] / 1000000000 - ) - before_test = now - timedelta(minutes=1) - self.assertTrue(before_test < event_created_time < now) - if timing_enabled: - # very loose bounds, measured 0.036 ms on devgpu - self.assertTrue(0 < last["duration_ms"] < 100) + if include_collectives: + self.assertEqual(len(t["entries"]), 2) + t = t["entries"] + self.assertEqual(len(t), 2) + last = t[-1] + self.assertEqual(last["process_group"], ("0", "default_pg")) + self.assertEqual(last["state"], "completed") + s = last["time_discovered_started_ns"] + f = last["time_discovered_completed_ns"] + self.assertEqual(last["record_id"], 1) + self.assertIsNotNone(f) + if timing_enabled: + self.assertIsNotNone(s) + self.assertTrue(s <= f) + self.assertIn("test_c10d_nccl.py", str(last["frames"])) + self.assertEqual(last["input_sizes"], ((3, 4),)) + self.assertEqual(last["input_dtypes"], ["Float"]) + self.assertEqual(last["output_sizes"], ((3, 4),)) + self.assertEqual(last["output_dtypes"], ["Float"]) + self.assertEqual(last["collective_seq_id"], 2) + now = datetime.now() + event_created_time = datetime.fromtimestamp( + last["time_created_ns"] / 1000000000 + ) + before_test = now - timedelta(minutes=1) + self.assertTrue(before_test < event_created_time < now) + if timing_enabled: + # very loose bounds, measured 0.036 ms on devgpu + self.assertTrue(0 < last["duration_ms"] < 100) + else: + self.assertTrue("duration_ms" not in last) else: - self.assertTrue("duration_ms" not in last) + self.assertTrue("entries" not in t) @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 26381207ca7d..8adf1e02c1a0 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -342,7 +342,10 @@ void cacheAllocatorDeregisterHook( } #if defined(IS_NCCLX) && defined(NCCL_COMM_DUMP) -std::string dump_nccl_trace() { +std::string dump_nccl_trace( + bool includeCollectives, + bool includeStackTraces, + bool onlyActive) { std::unordered_map< std::string /* ncclUniqueID */, std::unordered_map /* dump from this comm */> @@ -362,19 +365,27 @@ std::string dump_nccl_trace() { std::string ncclUniqueIDStr = buildNcclUniqueIdStr(ncclComm->getNcclId()); ncclDumpMap[ncclUniqueIDStr] = ncclComm->ncclCommDump(); } - return NCCLTraceBuffer::get()->dump(ncclDumpMap); + return NCCLTraceBuffer::get()->dump( + ncclDumpMap, includeCollectives, includeStackTraces, onlyActive); } + #else -std::string dump_nccl_trace() { - return NCCLTraceBuffer::get()->dump(c10::nullopt); +std::string dump_nccl_trace( + bool includeCollectives, + bool includeStackTraces, + bool onlyActive) { + return NCCLTraceBuffer::get()->dump( + c10::nullopt, includeCollectives, includeStackTraces, onlyActive); } #endif // TODO(c-p-i-o): add a JSON endpoint. control_plane::RegisterHandler dumpHandler{ "dump_nccl_trace_pickle", - [](const control_plane::Request&, control_plane::Response& res) { - res.setContent(dump_nccl_trace(), "application/octet-stream"); + [](const control_plane::Request& req, control_plane::Response& res) { + // TODO: c-p-i-o: params from the request need to go to dump_nccl_trace. + res.setContent( + dump_nccl_trace(true, true, false), "application/octet-stream"); }}; std::optional)>>& @@ -1197,7 +1208,7 @@ bool ProcessGroupNCCL::dumpDebuggingInfo() { // We dump nccl trace into local disk by default and users can register // their customized writer by inheriting `DebugInfoWriter` via // `registerDebugInfoWriter`. - auto ncclTrace = dump_nccl_trace(); + auto ncclTrace = dump_nccl_trace(true, true, false); DebugInfoWriter& writer = DebugInfoWriter::getWriter(globalRank()); LOG(INFO) << logPrefix() << "ProcessGroupNCCL dumping nccl trace to " << writer.getWriterTarget(); diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index f36ebdeb16e9..faaabe411bfc 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -1114,11 +1114,16 @@ class TORCH_API ProcessGroupNCCL : public Backend { ProcessGroupStatus pgStatus_; }; -TORCH_API std::string dump_nccl_trace(); - -// Gets a mutable reference to a global optional function. Heartbeat Monitor -// will use this function to dump traces, if available. Inside fbcode, we store -// a function here that uses an internal tool for process tracing +// Dumps the NCCL comm traces and additional information about the Process +// Group. +TORCH_API std::string dump_nccl_trace( + bool includeCollectives, + bool includeStackTraces, + bool onlyActive); + +// Gets a mutable reference to a global optional function.Heartbeat Monitor +// will use this function to dump traces, if available. Inside fbcode, we +// store a function here that uses an internal tool for process tracing TORCH_API std::optional< std::function)>>& get_cpp_trace_dumper(); diff --git a/torch/csrc/distributed/c10d/TraceUtils.h b/torch/csrc/distributed/c10d/TraceUtils.h index e8dadb6537e0..c3b0464cf992 100644 --- a/torch/csrc/distributed/c10d/TraceUtils.h +++ b/torch/csrc/distributed/c10d/TraceUtils.h @@ -655,31 +655,44 @@ struct NCCLTraceBuffer { entry->start_ = entry->end_ = nullptr; } - std::string dump( - const std::optional>>& ncclDumpMap) { - auto result = dump_entries(); + const c10::List getCollectiveTrace( + bool includeStacktraces, + bool onlyActive) { auto entries = new_list(); - + auto result = dump_entries(); std::vector tracebacks; - for (auto& e : result) { - tracebacks.push_back(e.traceback_.get()); - } - torch::SymbolizedTracebacks stracebacks = torch::symbolize(tracebacks); + torch::SymbolizedTracebacks stracebacks; std::vector all_frames; - for (const auto& f : stracebacks.all_frames) { - auto d = new_dict(); - d.insert(name_key, f.funcname); - d.insert(filename_key, f.filename); - d.insert(line_key, int64_t(f.lineno)); - all_frames.emplace_back(std::move(d)); + if (includeStacktraces) { + for (auto& e : result) { + tracebacks.push_back(e.traceback_.get()); + } + stracebacks = torch::symbolize(tracebacks); + for (const auto& f : stracebacks.all_frames) { + auto d = new_dict(); + d.insert(name_key, f.funcname); + d.insert(filename_key, f.filename); + d.insert(line_key, int64_t(f.lineno)); + all_frames.emplace_back(std::move(d)); + } } - for (auto i : c10::irange(result.size())) { - auto& e = result.at(i); - auto& tb = stracebacks.tracebacks.at(i); auto dict = new_dict(); + auto& e = result.at(i); + // Skip completed events + if (onlyActive && e.time_discovered_completed_.has_value()) { + continue; + } + + if (includeStacktraces) { + auto& tb = stracebacks.tracebacks.at(i); + auto frames = new_list(); + for (int64_t frame : tb) { + frames.push_back(all_frames.at(frame)); + } + dict.insert(frames_key, frames); + } + dict.insert(record_id_key, int64_t(e.id_)); dict.insert(pg_id_key, int64_t(e.pg_id_)); dict.insert(pg_name_key, e.pg_name_); @@ -741,13 +754,13 @@ struct NCCLTraceBuffer { dict.insert(retired_key, e.retired_); dict.insert(is_p2p_key, e.isP2P_); - auto frames = new_list(); - for (int64_t frame : tb) { - frames.push_back(all_frames.at(frame)); - } - dict.insert(frames_key, frames); entries.push_back(dict); } + return entries; + } + + // dump pg_entries + const c10::Dict getPgConfig() { auto pg_config = new_dict(); for (const auto& [pg_name, ranks] : pg_name_to_ranks_) { auto pg_info = new_dict(); @@ -756,6 +769,27 @@ struct NCCLTraceBuffer { pg_info.insert("ranks", ranks_str(ranks)); pg_config.insert(std::get<0>(pg_name), pg_info); } + return pg_config; + } + + // dump all collectives + ncclDumpMap + std::string dump( + const std::optional>>& ncclDumpMap, + bool includeCollectives, + bool includeStackTraces, + bool onlyActive) { + auto result = new_dict(); + // common values + result.insert(version_key, version_val); + result.insert(pg_config_key, getPgConfig()); + + // collective trace + if (includeCollectives) { + result.insert( + entries_key, getCollectiveTrace(includeStackTraces, onlyActive)); + } // convert ncclDumpMap into a dictionary auto per_comm_dict = new_dict(); @@ -768,16 +802,10 @@ struct NCCLTraceBuffer { per_comm_dict.insert(ncclId, inner_dict); } } - - auto dict = new_dict(); - dict.insert(entries_key, entries); - dict.insert(version_key, version_val); if (per_comm_dict.size() > 0) { - dict.insert(nccl_comm_key, per_comm_dict); + result.insert(nccl_comm_key, per_comm_dict); } - dict.insert(pg_config_key, pg_config); - - return pickle_str(dict); + return pickle_str(result); } }; diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index f0284c0a3bb7..6f1b28886b98 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -3164,9 +3164,28 @@ such as `dist.all_reduce(tensor, async_op=True)`. Arguments: tensors(List[torch.Tensor]): List of tensors we want to hash. )"); - module.def("_dump_nccl_trace", []() { - return py::bytes(::c10d::dump_nccl_trace()); - }); + module.def( + "_dump_nccl_trace", + [](std::optional includeCollectives, + std::optional includeStackTraces, + std::optional onlyActive) { + return py::bytes(::c10d::dump_nccl_trace( + includeCollectives.value_or(true), + includeStackTraces.value_or(true), + onlyActive.value_or(false))); + }, + py::arg("includeCollectives") = std::optional(), + py::arg("includeStackTraces") = std::optional(), + py::arg("onlyActive") = std::optional(), + R"( + Arguments: + includeCollectives(bool, optional): Whether to include collective work traces. Default is True. + includeStackTraces(bool, optional): Whether to include stacktraces in the collective work traces. Default is True. + onlyActive (bool, optional): Whether to only include active collective work traces. Default is False. + Returns: + Stringified pickle work traces. + Default settings return everything - i.e. contains NCCL comm dumps and collective traces. + )"); #endif intrusive_ptr_class_<::c10d::control_plane::WorkerServer>( From c7e2c9c37eac0b408dea1cb21eb506e4ba539582 Mon Sep 17 00:00:00 2001 From: Shuqiang Zhang Date: Fri, 7 Jun 2024 18:16:53 -0700 Subject: [PATCH 560/706] [c10d][doc] add a doc page for NCCL ENVs (#128235) Addressing issue: https://github.com/pytorch/pytorch/issues/128204 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128235 Approved by: https://github.com/wconstab --- docs/source/torch_environment_variables.rst | 1 + .../torch_nccl_environment_variables.rst | 35 +++++++++++++++++++ 2 files changed, 36 insertions(+) create mode 100644 docs/source/torch_nccl_environment_variables.rst diff --git a/docs/source/torch_environment_variables.rst b/docs/source/torch_environment_variables.rst index f63760de87e9..04feed91de4a 100644 --- a/docs/source/torch_environment_variables.rst +++ b/docs/source/torch_environment_variables.rst @@ -24,3 +24,4 @@ If you find anything in this documentation that is missing, incorrect, or could debugging_environment_variables miscellaneous_environment_variables logging + torch_nccl_environment_variables diff --git a/docs/source/torch_nccl_environment_variables.rst b/docs/source/torch_nccl_environment_variables.rst new file mode 100644 index 000000000000..a2498027e7ff --- /dev/null +++ b/docs/source/torch_nccl_environment_variables.rst @@ -0,0 +1,35 @@ +.. _torch_nccl_environment_variables: + +PYTORCH ProcessGroupNCCL Environment Variables +============================================== +For more information on the environment variables, see `ProcessGroupNCCL Environment Variables `_. + +.. list-table:: + :header-rows: 1 + + * - Variable + - Description + * - ``TORCH_NCCL_HIGH_PRIORITY`` + - Control whether to use high priority stream for the NCCL communicator. + * - ``TORCH_NCCL_BLOCKING_WAIT`` + - Control whether or not wait() is blocking or non-blocking. + * - ``TORCH_NCCL_DUMP_ON_TIMEOUT`` + - Control whether dumping debug info on watchdog timeout or exception is detected. This variable must be set together with TORCH_NCCL_TRACE_BUFFER_SIZE larger than 0. + * - ``TORCH_NCCL_DESYNC_DEBUG`` + - Control whether Desync Debug is enabled. This is helpful in figuring out the culprit rank of collective desync. + * - ``TORCH_NCCL_ENABLE_TIMING`` + - If set to ``1``, enable recording start-events for all ProcessGroupNCCL collectives, and compute accurate collective timing per-collective. + * - ``TORCH_NCCL_ENABLE_MONITORING`` + - If set to ``1``,enable monitoring thread which aborts the process when the ProcessGroupNCCL Watchdog thread gets stuck and no heartbeat is detected after TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC. This can happen due to calling CUDA/NCCL APIs that may hang. It is Useful to prevent jobs being stuck for a prolonged time than necessary tying up cluster resources. + * - ``TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC`` + - Control the watchdog heartbeat timeout period after which the monitoring thread will abort the process. + * - ``TORCH_NCCL_TRACE_BUFFER_SIZE`` + - The maximum number of events we store in the flight recorder's ring buffer. One event could be the start or end of a collective, for example. Set to 0 to disable the tracebuffer and debugging info dump. + * - ``TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC`` + - Control how much extra time we will wait for dumping the debugging info before we exit and throws timeout exception. + * - ``TORCH_NCCL_DEBUG_INFO_TEMP_FILE`` + - The file into which the debugging info would be dumped. + * - ``TORCH_NCCL_DEBUG_INFO_PIPE_FILE`` + - The pipe file to trigger debugging dump manually, write anything into the pipe would trigger the dump. + * - ``TORCH_NCCL_NAN_CHECK`` + - Control whether to enable NAN check for the input, Error would be thrown if NAN is detected. From 5e7377e044adae33bedcaa18428587b8055cc754 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sun, 9 Jun 2024 21:38:00 +0000 Subject: [PATCH 561/706] [Dynamo][TVM] Make the `opt_level` parameter adjustable (#127876) Fixes #127874 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127876 Approved by: https://github.com/jansel --- test/dynamo/test_backends.py | 1 + torch/_dynamo/backends/tvm.py | 9 ++++++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/test/dynamo/test_backends.py b/test/dynamo/test_backends.py index f94a4a6e5283..ca935ea69bc8 100644 --- a/test/dynamo/test_backends.py +++ b/test/dynamo/test_backends.py @@ -135,6 +135,7 @@ def test_onnxrt(self): def test_tvm(self): self._check_backend_works("tvm") self._check_backend_works("tvm", options={"scheduler": None}) + self._check_backend_works("tvm", options={"opt_level": 0}) def test_list_backends(self): self.assertIn("inductor", torch._dynamo.list_backends()) diff --git a/torch/_dynamo/backends/tvm.py b/torch/_dynamo/backends/tvm.py index bf4413690a1c..6c024b114fe2 100644 --- a/torch/_dynamo/backends/tvm.py +++ b/torch/_dynamo/backends/tvm.py @@ -24,7 +24,7 @@ def tvm( example_inputs, *, options: Optional[MappingProxyType] = MappingProxyType( - {"scheduler": None, "trials": 20000} + {"scheduler": None, "trials": 20000, "opt_level": 3} ), ): import tvm # type: ignore[import] @@ -51,6 +51,7 @@ def tvm( scheduler = os.environ.get("TVM_SCHEDULER", None) trials = options.get("trials", 20000) + opt_level = options.get("opt_level", 3) if scheduler == "auto_scheduler": from tvm import auto_scheduler @@ -83,7 +84,7 @@ def tvm( with auto_scheduler.ApplyHistoryBest(log_file): with tvm.transform.PassContext( - opt_level=3, config={"relay.backend.use_auto_scheduler": True} + opt_level=opt_level, config={"relay.backend.use_auto_scheduler": True} ): lib = relay.build(mod, target=target, params=params) elif scheduler == "meta_schedule": @@ -107,16 +108,18 @@ def tvm( num_trials_per_iter=64, params=params, strategy="evolutionary", + opt_level=opt_level, ) lib = ms.relay_integration.compile_relay( database=database, mod=mod, target=target, params=params, + opt_level=opt_level, ) elif scheduler == "default" or not scheduler: # no autotuning - with tvm.transform.PassContext(opt_level=10): + with tvm.transform.PassContext(opt_level=opt_level): lib = relay.build(mod, target=target, params=params) else: raise NotImplementedError( From 55b2a0a002bc3eb75027d0717909ed2d58fd7748 Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Fri, 7 Jun 2024 13:01:07 +0100 Subject: [PATCH 562/706] [AOTAutograd] Use _set_grad_enabled instead of no_grad (#128183) This saves ~1us of overhead from each inductor graph call. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128183 Approved by: https://github.com/lezcano --- .../_aot_autograd/runtime_wrappers.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py index 3293db8f8a93..9dc606113d84 100644 --- a/torch/_functorch/_aot_autograd/runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py @@ -205,19 +205,21 @@ def runtime_wrapper(args: List[Any]): compiled_fn, args_, disable_amp=disable_amp, steal_args=True ) else: - # When we have an inference graph, we run with torch.no_grad. + # When we have an inference graph, we run with grad disabled. # It's possible to get an inference graph with inputs that require grad, # in which case we want to make sure autograd is disabled # (since e.g., inductor will generate aten.addmm.out calls which autograd will complain on) - if torch.is_grad_enabled(): - with torch.no_grad(): - all_outs = call_func_at_runtime_with_args( - compiled_fn, args, disable_amp=disable_amp, steal_args=True - ) - else: + # NOTE: We use _set_grad_enabled directly to reduce runtime overhead + grad_enabled = torch.is_grad_enabled() + try: + if grad_enabled: + torch._C._set_grad_enabled(False) all_outs = call_func_at_runtime_with_args( compiled_fn, args, disable_amp=disable_amp, steal_args=True ) + finally: + if grad_enabled: + torch._C._set_grad_enabled(True) del args num_mutated_runtime_inps = runtime_metadata.num_mutated_inp_runtime_indices @@ -390,7 +392,7 @@ def runtime_wrapper(args: List[Any]): else: t._dynamo_weak_dynamic_indices = o.dynamic_dims.copy() if runtime_metadata.grad_enabled_mutation is not None: - torch.set_grad_enabled(runtime_metadata.grad_enabled_mutation) + torch._C._set_grad_enabled(runtime_metadata.grad_enabled_mutation) return ret_outs return runtime_wrapper From 253fa9c7111132a2c86e7ebd836d14c8975a7b07 Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Fri, 7 Jun 2024 13:01:08 +0100 Subject: [PATCH 563/706] [AOTAutograd] Remove runtime import from view replay function (#128184) `gen_alias_from_base` spends about ~0.5 us in this import statement, which is called for each view in the graph output. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128184 Approved by: https://github.com/lezcano ghstack dependencies: #128183 --- .../_aot_autograd/functional_utils.py | 31 ++++++++++++++----- torch/_functorch/_aot_autograd/schemas.py | 26 +++------------- 2 files changed, 28 insertions(+), 29 deletions(-) diff --git a/torch/_functorch/_aot_autograd/functional_utils.py b/torch/_functorch/_aot_autograd/functional_utils.py index a8af6f0366cc..02cf2ab0f428 100644 --- a/torch/_functorch/_aot_autograd/functional_utils.py +++ b/torch/_functorch/_aot_autograd/functional_utils.py @@ -6,7 +6,9 @@ 3. regenerating/replaying views from their base 4. checking if a graph is functional i.e. whether it contains any mutation ops """ +from __future__ import annotations +from typing import Optional import torch from torch import Tensor @@ -220,10 +222,7 @@ def gen_alias_from_base( aliased_base_tensor, target_meta_tensor, target_requires_grad, - # Actual type: Optional[FunctionalTensorMetadataEq] - # Can't use it here because it lives inside schemas.py. Importing that class would lead - # to an error due to an import cycle. - target_functional_tensor=None, + target_functional_tensor: Optional[FunctionalTensorMetadataEq] = None, ): # Patch the correct requires_grad field of the output tensor, depending on whether: # (i) the reconstructed output (out) was came from a tensor that requires grad or not; @@ -245,9 +244,6 @@ def patch_requires_grad(out): and target_functional_tensor is not None and not torch._functionalize_is_symbolic(target_functional_tensor.tensor) ): - from .schemas import FunctionalTensorMetadataEq - - assert isinstance(target_functional_tensor, FunctionalTensorMetadataEq) functional_tensor = target_functional_tensor.tensor out = torch._functionalize_apply_view_metas( @@ -322,6 +318,27 @@ def has_same_metadata(t1, t2): ) +# Wrapper around a FunctionalTensorWrapper for comparing only the resulting metadata +# after applying all the ViewMeta operations. +class FunctionalTensorMetadataEq: + def __init__(self, tensor: torch.Tensor) -> None: + assert torch._is_functional_tensor(tensor) + self.tensor = tensor + + def __eq__(self, other: object) -> bool: + # If other is None, then it probably means that we weren't able to recreate + # the FunctionalTensorMetadataEq. One of this cases is when we update the + # view metadata by calling: create_synthetic_base_metadata. + if other is None: + return True + + # Comparison agains any other type is not implemented. + if not isinstance(other, FunctionalTensorMetadataEq): + return NotImplemented + + return has_same_metadata(self.tensor, other.tensor) + + # new_arg and arg here are either: # (1) both a FakeTensor # (2) both a traceable tensor subclass that holds a FakeTensor diff --git a/torch/_functorch/_aot_autograd/schemas.py b/torch/_functorch/_aot_autograd/schemas.py index 338bff655b66..d5588a6e912c 100644 --- a/torch/_functorch/_aot_autograd/schemas.py +++ b/torch/_functorch/_aot_autograd/schemas.py @@ -18,7 +18,10 @@ from .. import config -from .functional_utils import _check_if_mutation_can_be_in_graph, has_same_metadata +from .functional_utils import ( + _check_if_mutation_can_be_in_graph, + FunctionalTensorMetadataEq, +) from .utils import strict_zip zip = strict_zip @@ -55,27 +58,6 @@ ) -# Wrapper around a FunctionalTensorWrapper for comparing only the resulting metadata -# after applying all the ViewMeta operations. -class FunctionalTensorMetadataEq: - def __init__(self, tensor: torch.Tensor) -> None: - assert torch._is_functional_tensor(tensor) - self.tensor = tensor - - def __eq__(self, other: object) -> bool: - # If other is None, then it probably means that we weren't able to recreate - # the FunctionalTensorMetadataEq. One of this cases is when we update the - # view metadata by calling: create_synthetic_base_metadata. - if other is None: - return True - - # Comparison agains any other type is not implemented. - if not isinstance(other, FunctionalTensorMetadataEq): - return NotImplemented - - return has_same_metadata(self.tensor, other.tensor) - - # This class stores info about every user output. @dataclass(frozen=True) class OutputAliasInfo: From cd2ad29afe583ebecf746da43988399c87b1f07e Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Fri, 7 Jun 2024 13:01:09 +0100 Subject: [PATCH 564/706] [inductor] Reduce binding overhead of _reinterpret_tensor (#128185) Going through the dispatcher + pybind11 + torch.ops adds about 2 us overhead per call compared to `PyArgParser`. Note that views of inputs are reconstructed by AOTAutograd before being returned to the python code, so dispatching for autograd's sake shouldn't be required here. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128185 Approved by: https://github.com/lezcano ghstack dependencies: #128183, #128184 --- torch/_inductor/codegen/wrapper.py | 2 +- torch/csrc/dynamo/guards.cpp | 25 +++++++++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 41b9fdc180bc..092dfd4e0b9c 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -512,8 +512,8 @@ def write_header(self) -> None: assert_size_stride = torch._C._dynamo.guards.assert_size_stride empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda + reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor alloc_from_pool = torch.ops.inductor._alloc_from_pool - reinterpret_tensor = torch.ops.inductor._reinterpret_tensor async_compile = AsyncCompile() """ diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp index c3321b244735..b7fde50a9f1a 100644 --- a/torch/csrc/dynamo/guards.cpp +++ b/torch/csrc/dynamo/guards.cpp @@ -2,8 +2,11 @@ #include #include #include +#include #include +#include #include +#include #include #include #include @@ -742,6 +745,27 @@ static PyObject* _empty_strided_cuda(PyObject* dummy, PyObject* args) { END_HANDLE_TH_ERRORS; } +static PyObject* _reinterpret_tensor(PyObject* dummy, PyObject* args) { + HANDLE_TH_ERRORS; + static PythonArgParser parser( + {"_reinterpret_tensor(Tensor base, IntArrayRef sizes, IntArrayRef strides, int64_t offset_increment=0)"}, + /*traceable=*/true); + + ParsedArgs<4> parsed_args; + auto r = parser.parse(args, /*kwargs=*/nullptr, parsed_args); + + Tensor self = r.tensor(0); + auto sizes = r.intlist(1); + auto strides = r.intlist(2); + auto offset_increment = r.toInt64(3); + + auto res = torch::inductor::_reinterpret_tensor( + self, sizes, strides, offset_increment); + return torch::autograd::utils::wrap(res); + + END_HANDLE_TH_ERRORS; +} + // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays) static PyMethodDef _methods[] = { {"check_type_id", check_type_id, METH_VARARGS, nullptr}, @@ -750,6 +774,7 @@ static PyMethodDef _methods[] = { {"dict_version", dict_version, METH_VARARGS, nullptr}, {"_empty_strided_cpu", _empty_strided_cpu, METH_VARARGS, nullptr}, {"_empty_strided_cuda", _empty_strided_cuda, METH_VARARGS, nullptr}, + {"_reinterpret_tensor", _reinterpret_tensor, METH_VARARGS, nullptr}, {nullptr, nullptr, 0, nullptr}}; static struct PyModuleDef _module = { From d3817d8a60acb990f7ca3c067b289382ff16990e Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Sun, 9 Jun 2024 21:50:54 +0100 Subject: [PATCH 565/706] Don't create python tuple when _maybe_handle_torch_function is called from C++ (#128187) Marginal overhead reduction when calling through the `torch.ops` API. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128187 Approved by: https://github.com/lezcano ghstack dependencies: #128183, #128184, #128185 --- torch/csrc/jit/python/init.cpp | 7 ++++++- torch/csrc/jit/python/pybind_utils.cpp | 27 +++++++++++++------------- torch/csrc/jit/python/pybind_utils.h | 2 +- 3 files changed, 20 insertions(+), 16 deletions(-) diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index 2a3cb62c56d7..818f09bee7bc 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -1779,13 +1779,18 @@ void initJITBindings(PyObject* module) { [](py::handle op_overload_packet, py::args args, py::kwargs kwargs) { py::list ns_method = op_overload_packet.attr("_qualified_op_name").attr("split")("::"); - return _maybe_handle_torch_function( + auto res = _maybe_handle_torch_function( py::cast(ns_method[0]), py::cast(ns_method[1]), "", false, args, kwargs); + if (res) { + return py::make_tuple(true, *res); + } else { + return py::make_tuple(false, py::none()); + } }); m.def( diff --git a/torch/csrc/jit/python/pybind_utils.cpp b/torch/csrc/jit/python/pybind_utils.cpp index 4cfe3309a766..a731640223c0 100644 --- a/torch/csrc/jit/python/pybind_utils.cpp +++ b/torch/csrc/jit/python/pybind_utils.cpp @@ -13,6 +13,7 @@ #include #include +#include namespace torch::jit { @@ -816,7 +817,7 @@ py::object invokeOperatorFromPython( return createPyObjectForStack(std::move(stack)); } -py::tuple _maybe_handle_torch_function( +std::optional _maybe_handle_torch_function( const std::string& ns, const std::string& method_name, const std::string& overload_name, @@ -861,18 +862,16 @@ py::tuple _maybe_handle_torch_function( } std::string module_name("torch.ops"); module_name.append(ns); - return py::make_tuple( - true, - pybind11::reinterpret_steal( - handle_torch_function_no_python_arg_parser( - overloaded_args, - args.ptr(), - kwargs.ptr(), - method_name.c_str(), - self_func.ptr(), - module_name.c_str()))); + return {pybind11::reinterpret_steal( + handle_torch_function_no_python_arg_parser( + overloaded_args, + args.ptr(), + kwargs.ptr(), + method_name.c_str(), + self_func.ptr(), + module_name.c_str()))}; } - return py::make_tuple(false, py::none()); + return std::nullopt; } py::object _get_operation_for_overload_or_packet( @@ -887,9 +886,9 @@ py::object _get_operation_for_overload_or_packet( std::string overload_name = operations[0]->schema().overload_name(); auto res = _maybe_handle_torch_function( ns, method_name, overload_name, is_overload, args, kwargs); - auto torch_function_called = py::cast(res[0]); + auto torch_function_called = res.has_value(); return torch_function_called - ? res[1] + ? *res : invokeOperatorFromPython(operations, args, kwargs, dk); } diff --git a/torch/csrc/jit/python/pybind_utils.h b/torch/csrc/jit/python/pybind_utils.h index 242da11af7c0..23fda5b0d784 100644 --- a/torch/csrc/jit/python/pybind_utils.h +++ b/torch/csrc/jit/python/pybind_utils.h @@ -1257,7 +1257,7 @@ TORCH_PYTHON_API py::object invokeOperatorFromPython( const py::kwargs& kwargs, std::optional dk = c10::nullopt); -TORCH_PYTHON_API py::tuple _maybe_handle_torch_function( +TORCH_PYTHON_API std::optional _maybe_handle_torch_function( const std::string& ns, const std::string& method_name, const std::string& overload_name, From 26f6a87ae9d0077dc9327e72b2c0f654e56d86ab Mon Sep 17 00:00:00 2001 From: cyy Date: Mon, 10 Jun 2024 01:57:49 +0000 Subject: [PATCH 566/706] [5/N] Remove unused functions (#127185) Follows #128193 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127185 Approved by: https://github.com/ezyang --- .../src/ATen/functorch/BatchRulesUnaryOps.cpp | 7 ------- aten/src/ATen/native/LossNLL.cpp | 9 --------- aten/src/ATen/native/TriangularOps.cpp | 4 ---- aten/src/ATen/native/mkldnn/Normalization.cpp | 4 ++-- .../ATen/native/quantized/cpu/BinaryOps.cpp | 3 ++- .../quantized/cpu/UpSampleNearest2d.cpp | 20 ------------------- aten/src/ATen/native/sparse/SoftMax.cpp | 9 --------- 7 files changed, 4 insertions(+), 52 deletions(-) diff --git a/aten/src/ATen/functorch/BatchRulesUnaryOps.cpp b/aten/src/ATen/functorch/BatchRulesUnaryOps.cpp index d8213a1b9e0d..85210d0b214c 100644 --- a/aten/src/ATen/functorch/BatchRulesUnaryOps.cpp +++ b/aten/src/ATen/functorch/BatchRulesUnaryOps.cpp @@ -59,13 +59,6 @@ view_as_complex_batch_rule(const Tensor& self, optional self_bdim) { return std::make_tuple(result, 0); } -std::tuple> -to_other_batch_rule(const Tensor& self, optional self_bdim, - const Tensor& other, optional other_bdim, - bool non_blocking, - bool copy, std::optional memory_format) { - return std::make_tuple(self.to(other, non_blocking, copy, memory_format), self_bdim); -} } TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { diff --git a/aten/src/ATen/native/LossNLL.cpp b/aten/src/ATen/native/LossNLL.cpp index b7809ab21dd5..35ae21c32736 100644 --- a/aten/src/ATen/native/LossNLL.cpp +++ b/aten/src/ATen/native/LossNLL.cpp @@ -675,15 +675,6 @@ Tensor nll_loss_symint(const Tensor & self, const Tensor & target, const std::op return std::get<0>(at::nll_loss_forward_symint(self, target, weight, reduction, std::move(ignore_index))); } -// Duplicate of above code for non-symbolic ints. Kept for BC purposes and to minimize breakages. -static Tensor nll_loss(const Tensor & self, const Tensor & target, const std::optional& weight_opt, int64_t reduction, int64_t ignore_index) { - // See [Note: hacky wrapper removal for optional tensor] - c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); - const Tensor& weight = *weight_maybe_owned; - - return std::get<0>(at::nll_loss_forward_symint(self, target, weight, reduction, ignore_index)); -} - Tensor nll_loss_nd_symint( const Tensor& self, const Tensor& target, diff --git a/aten/src/ATen/native/TriangularOps.cpp b/aten/src/ATen/native/TriangularOps.cpp index 9cb75a0eccf4..62440b956c80 100644 --- a/aten/src/ATen/native/TriangularOps.cpp +++ b/aten/src/ATen/native/TriangularOps.cpp @@ -180,10 +180,6 @@ TORCH_IMPL_FUNC(triu_cpu)(const Tensor& self, int64_t k, const Tensor &result) { compute_triu_tril(self, k, result); } -static Tensor trace_backward(const Tensor& grad, at::IntArrayRef sizes) { - return at::native::trace_backward_symint(grad, c10::fromIntArrayRefSlow(sizes)); -} - Tensor trace_backward_symint(const Tensor& grad, c10::SymIntArrayRef sizes) { if (sizes.size() != 2) { throw std::runtime_error("expected matrix input"); diff --git a/aten/src/ATen/native/mkldnn/Normalization.cpp b/aten/src/ATen/native/mkldnn/Normalization.cpp index 47dbe792d73a..6ed703c3b5fd 100644 --- a/aten/src/ATen/native/mkldnn/Normalization.cpp +++ b/aten/src/ATen/native/mkldnn/Normalization.cpp @@ -14,6 +14,7 @@ #include #include #endif +#include #if !AT_MKLDNN_ENABLED() @@ -37,7 +38,7 @@ std::tuple mkldnn_batch_norm_backward( TORCH_CHECK(false, "mkldnn_batch_norm_backward: ATen not compiled with MKLDNN support"); } -static std::tuple mkldnn_layer_norm_last_index_weight_bias_f32( +std::tuple mkldnn_layer_norm_last_index_weight_bias_f32( const Tensor& input, IntArrayRef normalized_shape, const Tensor& weight, const Tensor& bias, double eps, bool inplace) { @@ -81,7 +82,6 @@ std::tuple _new_batch_norm_backward_mkldnn( #else // AT_MKLDNN_ENABLED #include -#include #include #include diff --git a/aten/src/ATen/native/quantized/cpu/BinaryOps.cpp b/aten/src/ATen/native/quantized/cpu/BinaryOps.cpp index 8b5fb286ec61..be39a7db2cfa 100644 --- a/aten/src/ATen/native/quantized/cpu/BinaryOps.cpp +++ b/aten/src/ATen/native/quantized/cpu/BinaryOps.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -500,7 +501,7 @@ TORCH_LIBRARY_IMPL(_quantized, QuantizedCPU, m) { } // namespace -static Tensor quantized_add(Tensor qa, Tensor qb, double scale, int64_t zero_point){ +Tensor quantized_add(Tensor qa, Tensor qb, double scale, int64_t zero_point){ return qadd(std::move(qa), std::move(qb), scale, zero_point); } diff --git a/aten/src/ATen/native/quantized/cpu/UpSampleNearest2d.cpp b/aten/src/ATen/native/quantized/cpu/UpSampleNearest2d.cpp index 191407bed66a..03cbb080d558 100644 --- a/aten/src/ATen/native/quantized/cpu/UpSampleNearest2d.cpp +++ b/aten/src/ATen/native/quantized/cpu/UpSampleNearest2d.cpp @@ -218,25 +218,5 @@ Tensor _upsample_nearest_exact2d_quantized_cpu( return _upsample_nearest2d_quantized_cpu(input, osize, scale_h, scale_w); } -static Tensor upsample_nearest2d_quantized_cpu( - const Tensor& input, - at::OptionalIntArrayRef output_size, - std::optional> scale_factors) { - auto osize = compute_output_size(input.sizes(), output_size, scale_factors); - auto scale_h = get_scale_value(scale_factors, 0); - auto scale_w = get_scale_value(scale_factors, 1); - return upsample_nearest2d_quantized_cpu(input, osize, scale_h, scale_w); -} - -static Tensor _upsample_nearest_exact2d_quantized_cpu( - const Tensor& input, - at::OptionalIntArrayRef output_size, - std::optional> scale_factors) { - auto osize = compute_output_size(input.sizes(), output_size, scale_factors); - auto scale_h = get_scale_value(scale_factors, 0); - auto scale_w = get_scale_value(scale_factors, 1); - return _upsample_nearest_exact2d_quantized_cpu(input, osize, scale_h, scale_w); -} - } // namespace native } // namespace at diff --git a/aten/src/ATen/native/sparse/SoftMax.cpp b/aten/src/ATen/native/sparse/SoftMax.cpp index 668032cb588e..33ac3d176e6c 100644 --- a/aten/src/ATen/native/sparse/SoftMax.cpp +++ b/aten/src/ATen/native/sparse/SoftMax.cpp @@ -624,15 +624,6 @@ Tensor _sparse_softmax(const Tensor& self, Dimname dim, optional dty return at::_sparse_softmax(self, dimname_to_position(self, dim), dtype); } -static Tensor _sparse_log_softmax(const Tensor& input_, const int64_t dim_) { - auto result = [&]() { - NoNamesGuard guard; - return at::_sparse_log_softmax(input_, dim_, false); - }(); - namedinference::propagate_names(result, input_); - return result; -} - Tensor _sparse_log_softmax(const Tensor& input_, const int64_t dim_, std::optional dtype) { auto result = [&]() { NoNamesGuard guard; From df43d5843edd9abd95d2c039670bce51375d9c06 Mon Sep 17 00:00:00 2001 From: Xu Han Date: Mon, 10 Jun 2024 02:45:46 +0000 Subject: [PATCH 567/706] fix miss isa bool check (#128274) New cpp builder missed ISA bool(dry-compile) check. image @jgong5 Found this missing and then I submit this PR to fix it. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128274 Approved by: https://github.com/jgong5, https://github.com/ezyang --- torch/_inductor/codecache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 71815a31718e..ae8453660813 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1494,7 +1494,7 @@ def valid_vec_isa_list() -> List[VecISA]: isa_list = [] _cpu_supported_isa = x86_isa_checker() for isa in supported_vec_isa_list: - if str(isa) in _cpu_supported_isa: + if str(isa) in _cpu_supported_isa and isa: isa_list.append(isa) return isa_list From b66e3f0957b96b058c9b632ca60833d9717a9d8a Mon Sep 17 00:00:00 2001 From: CaoE Date: Thu, 6 Jun 2024 07:52:26 -0700 Subject: [PATCH 568/706] Set simdlen based on ATEN_CPU_CAPABILITY (#123514) It is part of https://github.com/pytorch/pytorch/issues/123224. Set simdlen based on the environment ATEN_CPU_CAPABILITY to control CPU vec ISA like eager. Pull Request resolved: https://github.com/pytorch/pytorch/pull/123514 Approved by: https://github.com/jgong5, https://github.com/peterbell10 --- test/inductor/test_cpu_repro.py | 137 ++++++++++++++++++++++-- test/inductor/test_extension_backend.py | 4 + test/inductor/test_torchinductor.py | 42 +++++++- torch/_dynamo/testing.py | 6 ++ torch/_inductor/codecache.py | 34 +++++- torch/_inductor/codegen/cpp_prefix.h | 1 + 6 files changed, 205 insertions(+), 19 deletions(-) diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py index b2ab30832e06..1f04b71a961b 100644 --- a/test/inductor/test_cpu_repro.py +++ b/test/inductor/test_cpu_repro.py @@ -4,6 +4,7 @@ import functools import itertools import math +import os import platform import sys import unittest @@ -66,12 +67,13 @@ check_model = test_torchinductor.check_model requires_vectorization = unittest.skipUnless( - codecache.valid_vec_isa_list(), "Does not support vectorization" + codecache.valid_vec_isa_list() and os.getenv("ATEN_CPU_CAPABILITY") != "default", + "Does not support vectorization", ) def check_metrics_vec_kernel_count(num_expected_vec_kernels): - if codecache.valid_vec_isa_list(): + if codecache.valid_vec_isa_list() and os.getenv("ATEN_CPU_CAPABILITY") != "default": assert metrics.generated_cpp_vec_kernel_count == num_expected_vec_kernels @@ -1580,6 +1582,71 @@ def fn(x): metrics.reset() self.common(fn, (value,)) + @unittest.skipIf( + not codecache.valid_vec_isa_list() + or "avx2" in [str(vec_isa) for vec_isa in codecache.valid_vec_isa_list()], + "Does not support vectorization or not s390x/neon machine", + ) + @patch("torch.cuda.is_available", lambda: False) + def test_auto_zvec_neon_simd(self): + vec_zvec_neon = codecache.valid_vec_isa_list()[0] + self.assertTrue(vec_zvec_neon.bit_width() == 256) + + with config.patch({"cpp.simdlen": 0}): + isa = codecache.pick_vec_isa() + self.assertFalse(isa) + + with config.patch({"cpp.simdlen": 1}): + isa = codecache.pick_vec_isa() + self.assertFalse(isa) + + with config.patch({"cpp.simdlen": 257}): + isa = codecache.pick_vec_isa() + self.assertFalse(isa) + + with config.patch({"cpp.simdlen": 256}): + isa = codecache.pick_vec_isa() + self.assertTrue(isa == vec_zvec_neon) + + pre_var = os.getenv("ATEN_CPU_CAPABILITY") + if pre_var: + os.environ.pop("ATEN_CPU_CAPABILITY") + + try: + with config.patch({"cpp.simdlen": None}): + isa = codecache.pick_vec_isa() + self.assertTrue(isa == vec_zvec_neon) + + with config.patch({"cpp.simdlen": None}): + os.environ["ATEN_CPU_CAPABILITY"] = "avx2" + isa = codecache.pick_vec_isa() + self.assertTrue(isa == vec_zvec_neon) + + with config.patch({"cpp.simdlen": None}): + os.environ["ATEN_CPU_CAPABILITY"] = "avx512" + isa = codecache.pick_vec_isa() + self.assertTrue(isa == vec_zvec_neon) + + with config.patch({"cpp.simdlen": None}): + os.environ["ATEN_CPU_CAPABILITY"] = "default" + isa = codecache.pick_vec_isa() + self.assertFalse(isa) + + with config.patch({"cpp.simdlen": None}): + os.environ["ATEN_CPU_CAPABILITY"] = "neon" + isa = codecache.pick_vec_isa() + self.assertTrue(isa == vec_zvec_neon) + + with config.patch({"cpp.simdlen": None}): + os.environ["ATEN_CPU_CAPABILITY"] = "zvector" + isa = codecache.pick_vec_isa() + self.assertTrue(isa == vec_zvec_neon) + finally: + if pre_var: + os.environ["ATEN_CPU_CAPABILITY"] = pre_var + elif os.getenv("ATEN_CPU_CAPABILITY"): + os.environ.pop("ATEN_CPU_CAPABILITY") + @unittest.skipIf( platform.machine() != "x86_64" or not codecache.valid_vec_isa_list(), "Does not support vectorization or not x86_64 machine", @@ -1595,13 +1662,6 @@ def test_auto_simd(self): self.assertTrue(vec_avx512.nelements(torch.bfloat16) == 32) self.assertTrue(vec_avx2.nelements(torch.bfloat16) == 16) - with config.patch({"cpp.simdlen": None}): - isa = codecache.pick_vec_isa() - if vec_avx512 in codecache.valid_vec_isa_list(): - self.assertTrue(isa == vec_avx512) - else: - self.assertTrue(isa == vec_avx2) - with config.patch({"cpp.simdlen": 0}): isa = codecache.pick_vec_isa() self.assertFalse(isa) @@ -1631,6 +1691,60 @@ def test_auto_simd(self): isa = codecache.pick_vec_isa() self.assertTrue(isa == vec_avx2) + pre_var = os.getenv("ATEN_CPU_CAPABILITY") + if pre_var: + os.environ.pop("ATEN_CPU_CAPABILITY") + + try: + with config.patch({"cpp.simdlen": None}): + isa = codecache.pick_vec_isa() + if vec_avx512 in codecache.valid_vec_isa_list(): + self.assertTrue(isa == vec_avx512) + else: + self.assertTrue(isa == vec_avx2) + + with config.patch({"cpp.simdlen": None}): + os.environ["ATEN_CPU_CAPABILITY"] = "avx2" + isa = codecache.pick_vec_isa() + if vec_avx512 in codecache.valid_vec_isa_list(): + self.assertTrue(isa == vec_avx2) + elif vec_avx2 in codecache.valid_vec_isa_list(): + self.assertTrue(isa == vec_avx2) + + with config.patch({"cpp.simdlen": None}): + os.environ["ATEN_CPU_CAPABILITY"] = "avx512" + isa = codecache.pick_vec_isa() + if vec_avx512 in codecache.valid_vec_isa_list(): + self.assertTrue(isa == vec_avx512) + else: + self.assertTrue(isa == vec_avx2) + + with config.patch({"cpp.simdlen": None}): + os.environ["ATEN_CPU_CAPABILITY"] = "default" + isa = codecache.pick_vec_isa() + self.assertFalse(isa) + + with config.patch({"cpp.simdlen": None}): + os.environ["ATEN_CPU_CAPABILITY"] = "neon" + isa = codecache.pick_vec_isa() + if vec_avx512 in codecache.valid_vec_isa_list(): + self.assertTrue(isa == vec_avx512) + else: + self.assertTrue(isa == vec_avx2) + + with config.patch({"cpp.simdlen": None}): + os.environ["ATEN_CPU_CAPABILITY"] = "zvector" + isa = codecache.pick_vec_isa() + if vec_avx512 in codecache.valid_vec_isa_list(): + self.assertTrue(isa == vec_avx512) + else: + self.assertTrue(isa == vec_avx2) + finally: + if pre_var: + os.environ["ATEN_CPU_CAPABILITY"] = pre_var + elif os.getenv("ATEN_CPU_CAPABILITY"): + os.environ.pop("ATEN_CPU_CAPABILITY") + @requires_vectorization @patch("torch.cuda.is_available", lambda: False) def test_masked_fill_softmax(self): @@ -3371,6 +3485,7 @@ def forward(self, idx, x): self.common(m, (idx, x)) check_metrics_vec_kernel_count(1) + @requires_vectorization def test_embedding_vec_bf16(self): class M(torch.nn.Module): def __init__(self): @@ -3655,7 +3770,7 @@ def fn(x): x = torch.randint(0, 100, (819,), dtype=torch.int64) metrics.reset() self.common(fn, (x,)) - assert metrics.generated_cpp_vec_kernel_count == 1 + check_metrics_vec_kernel_count(1) def test_reduction_float_to_int64(self): # https://github.com/pytorch/pytorch/issues/124821 @@ -3665,7 +3780,7 @@ def fn(x): x = torch.randint(0, 100, (22, 51), dtype=torch.int64) metrics.reset() self.common(fn, (x,)) - assert metrics.generated_cpp_vec_kernel_count == 1 + check_metrics_vec_kernel_count(1) @config.patch({"cpp.dynamic_threads": True}) def test_reduction_with_dynamic_threads(self): diff --git a/test/inductor/test_extension_backend.py b/test/inductor/test_extension_backend.py index 3cb473255e74..a3bad9582d8c 100644 --- a/test/inductor/test_extension_backend.py +++ b/test/inductor/test_extension_backend.py @@ -8,6 +8,7 @@ import torch._dynamo import torch.utils.cpp_extension from torch._C import FileCheck +from torch._dynamo.testing import expectedFailureScalar try: from extension_backends.cpp.extension_codegen_backend import ( @@ -103,6 +104,9 @@ def tearDown(self): # return the working directory (see setUp) os.chdir(self.old_working_dir) + # Fails when testing the scalar version + # See https://github.com/pytorch/pytorch/issues/126372. + @expectedFailureScalar def test_open_device_registration(self): torch.utils.rename_privateuse1_backend("extension_device") torch._register_device_module("extension_device", self.module) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 53167a83ecd8..fe23b07e12da 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -34,6 +34,7 @@ from torch._dynamo.testing import ( CompileCounterWithBackend, expectedFailureCodegenDynamic, + expectedFailureScalar, rand_strided, same, skipIfPy312, @@ -1315,6 +1316,9 @@ def fn(a): self.common(fn, (torch.randn(1024),)) + # Fails when testing the scalar version + # See https://github.com/pytorch/pytorch/issues/128029. + @expectedFailureScalar @skipIfRocm @config.patch(debug_index_asserts=False) def test_neg_index(self): @@ -1577,16 +1581,40 @@ def test_multilayer_var(self): def fn(a): return torch.var(a) - self.common(fn, ((torch.rand((10, 3, 352, 352), dtype=torch.float32),))) - self.common(fn, ((torch.rand((14923), dtype=torch.float32),))) + atol = None + rtol = None + if self.device == "cpu" and os.getenv("ATEN_CPU_CAPABILITY") == "default": + atol = 1e-4 + rtol = 1e-4 + self.common( + fn, + ((torch.rand((10, 3, 352, 352), dtype=torch.float32),)), + rtol=rtol, + atol=atol, + ) + self.common( + fn, ((torch.rand((14923), dtype=torch.float32),)), rtol=rtol, atol=atol + ) @skipCPUIf(IS_MACOS, "fails on macos") def test_multilayer_var_lowp(self): def fn(a): return torch.var(a) - self.common(fn, (torch.rand((16, 16, 352, 352), dtype=torch.float16),)) - self.common(fn, (torch.rand((14923), dtype=torch.float16),)) + atol = None + rtol = None + if self.device == "cpu" and os.getenv("ATEN_CPU_CAPABILITY") == "default": + atol = 1e-3 + rtol = 1e-3 + self.common( + fn, + (torch.rand((16, 16, 352, 352), dtype=torch.float16),), + rtol=rtol, + atol=atol, + ) + self.common( + fn, (torch.rand((14923), dtype=torch.float16),), rtol=rtol, atol=atol + ) def test_split_cumsum(self): def fn(a): @@ -8199,7 +8227,7 @@ def forward(arg38_1, arg81_1, getitem_17, new_zeros_default_4): rand_strided(shape, stride, dtype).requires_grad_(True).add(1) for shape, stride, dtype in args ] - self.common(forward, args) + self.common(forward, args, atol=1e-05, rtol=1e-05) @requires_gpu() def test_tmp_not_defined_issue3(self): @@ -9281,6 +9309,7 @@ def func(arg0_1): # To support this behavior, we need to allow const-propping tensors that store symint data. # For now, dynamo will explicitly graph break when it encounters user code with this behavior. @expectedFailureCodegenDynamic + @expectedFailureScalar def test_AllenaiLongformerBase_repro(self): def fn(query, scores, window_overlap): batch_size, seq_len, num_heads, _ = query.size() @@ -9316,6 +9345,9 @@ def fn(query, scores, window_overlap): opt_fn = torch._dynamo.optimize("inductor")(fn) _, code = run_and_get_cpp_code(opt_fn, *args) print(code) + # When testing the scalar version, i.e., ATEN_CPU_CAPABILITY=default, + # static_cast(256) is not found, but static_cast(256). + # See https://github.com/pytorch/pytorch/issues/126262. FileCheck().check_count( "static_cast(256)", 1, diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py index 527e0138fc25..d254a5e261ed 100644 --- a/torch/_dynamo/testing.py +++ b/torch/_dynamo/testing.py @@ -381,6 +381,12 @@ def expectedFailureDynamicWrapper(fn): return fn +def expectedFailureScalar(fn): + if os.getenv("ATEN_CPU_CAPABILITY") == "default": + return unittest.expectedFailure(fn) + return fn + + def reset_rng_state(use_xla=False): torch.manual_seed(1337) random.seed(1337) diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index ae8453660813..d497272d00a3 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1465,6 +1465,31 @@ def _check_and_append_supported_isa( supported_vec_isa_list = [VecAVX512(), VecAVX2(), VecNEON()] +def get_isa_from_cpu_capability( + capability: str | None, vec_isa_list: List[VecISA], invalid_vec_isa: InvalidVecISA +): + # VSX is not supported in inductor + capability_to_isa_str = { + "default": "INVALID_VEC_ISA", + "neon": "asimd", + "zvector": "zvector", + "avx2": "avx2", + "avx512": "avx512", + } + if capability in capability_to_isa_str.keys(): + isa_str = capability_to_isa_str[capability] + if isa_str == "INVALID_VEC_ISA": + return invalid_vec_isa + for vec_isa in vec_isa_list: + if isa_str == str(vec_isa): + return vec_isa + + if capability: + warnings.warn(f"ignoring invalid value for ATEN_CPU_CAPABILITY {capability}") + + return vec_isa_list[0] + + # Cache the cpuinfo to avoid I/O overhead. Meanwhile, the cpuinfo content # might have too much redundant content that is useless for ISA check. Hence, # we only cache some key isa information. @@ -1507,10 +1532,13 @@ def pick_vec_isa() -> VecISA: if not _valid_vec_isa_list: return invalid_vec_isa - # If the simdlen is None, it indicates determine the vectorization length automatically + # If the simdlen is None, set simdlen based on the environment ATEN_CPU_CAPABILITY + # to control CPU vec ISA + if config.cpp.simdlen is None: - assert _valid_vec_isa_list - return _valid_vec_isa_list[0] + return get_isa_from_cpu_capability( + os.getenv("ATEN_CPU_CAPABILITY"), _valid_vec_isa_list, invalid_vec_isa + ) for isa in _valid_vec_isa_list: if config.cpp.simdlen == isa.bit_width(): diff --git a/torch/_inductor/codegen/cpp_prefix.h b/torch/_inductor/codegen/cpp_prefix.h index 6898a8a52112..1492023eed38 100644 --- a/torch/_inductor/codegen/cpp_prefix.h +++ b/torch/_inductor/codegen/cpp_prefix.h @@ -24,6 +24,7 @@ #include #include #include +#include #if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON) #define INDUCTOR_USE_VECTOR_TYPES() 1 From 04da6aeb61f4d57bf73ed1054dd897abbcceca83 Mon Sep 17 00:00:00 2001 From: Tom Ritchford Date: Sun, 9 Jun 2024 15:20:55 +0000 Subject: [PATCH 569/706] Add OpInfo entry for alias_copy (#127232) (#128142) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128142 Approved by: https://github.com/lezcano --- .../ATen/functorch/BatchRulesDecompositions.cpp | 1 + test/distributed/_tensor/test_dtensor_ops.py | 1 + .../HasDecompTest.test_has_decomposition.expect | 2 -- test/functorch/test_vmap_registrations.py | 1 + test/onnx/test_fx_op_consistency.py | 4 ++++ tools/autograd/gen_variable_type.py | 1 + torch/_decomp/__init__.py | 1 + torch/_inductor/exc.py | 2 +- torch/_refs/__init__.py | 4 ++++ .../_internal/common_methods_invocations.py | 15 +++++++++++++++ 10 files changed, 29 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp index 3e064d6c39dc..a0007aa18a00 100644 --- a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp +++ b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp @@ -324,6 +324,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) { OP_DECOMPOSE(type_as); OP_DECOMPOSE(linalg_diagonal); OP_DECOMPOSE(diagonal_copy); + OP_DECOMPOSE(alias_copy); m.impl("pad", native::pad_symint); m.impl("_pad_circular", native::_pad_circular_symint); OP_DECOMPOSE(swapdims_); diff --git a/test/distributed/_tensor/test_dtensor_ops.py b/test/distributed/_tensor/test_dtensor_ops.py index 83f0bb875167..07f8bfedc615 100644 --- a/test/distributed/_tensor/test_dtensor_ops.py +++ b/test/distributed/_tensor/test_dtensor_ops.py @@ -102,6 +102,7 @@ def wrapped(fn): xfail("addr"), xfail("all"), xfail("allclose"), + xfail("alias_copy"), xfail("amax"), xfail("amin"), xfail("aminmax"), diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index ad9cf07d7550..eeee3685e1fb 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -647,8 +647,6 @@ aten::adaptive_max_pool3d_backward.grad_input aten::addbmm aten::addbmm.out aten::addr_ -aten::alias_copy -aten::alias_copy.out aten::allclose aten::angle aten::angle.out diff --git a/test/functorch/test_vmap_registrations.py b/test/functorch/test_vmap_registrations.py index 967152945af5..737927a60f80 100644 --- a/test/functorch/test_vmap_registrations.py +++ b/test/functorch/test_vmap_registrations.py @@ -25,6 +25,7 @@ } xfail_functorch_batched_decomposition = { + "aten::alias_copy", "aten::diagonal_copy", "aten::is_same_size", "aten::unfold_copy", diff --git a/test/onnx/test_fx_op_consistency.py b/test/onnx/test_fx_op_consistency.py index e72c4206d578..6d675d446030 100644 --- a/test/onnx/test_fx_op_consistency.py +++ b/test/onnx/test_fx_op_consistency.py @@ -218,6 +218,10 @@ def skip_torchlib_forward_compatibility( dtypes=onnx_test_common.COMPLEX_TYPES, reason=onnx_test_common.reason_dynamo_does_not_support("Addr", "complex64") ), + xfail( + "alias_copy", + reason="OnnxExporterError: Failed to export model", + ), xfail( "allclose", reason=onnx_test_common.reason_dynamo_does_not_support("Allclose") diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index b9651ea2da80..6abb13d244e9 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -305,6 +305,7 @@ "linalg_eig", "diagonal_copy", "diagonal_scatter", + "alias_copy", "select_backward", "diagonal_backward", "slice_backward", diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py index e0c7e5b6f49d..7674e5f466a8 100644 --- a/torch/_decomp/__init__.py +++ b/torch/_decomp/__init__.py @@ -261,6 +261,7 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]: aten.addcmul_, aten.addr, aten.affine_grid_generator, + aten.alias_copy, aten.all, aten.aminmax, aten.arange.default, diff --git a/torch/_inductor/exc.py b/torch/_inductor/exc.py index 27dcc6d8ef2d..8a172d8c29b1 100644 --- a/torch/_inductor/exc.py +++ b/torch/_inductor/exc.py @@ -46,7 +46,7 @@ def __init__(self, target, args, kwargs): There is a decomposition available for {target} in torch._decomp.get_decompositions(). Please add this operator to the - `decompositions` list in torch._inductor.decompositions + `decompositions` list in torch._inductor.decomposition """ ) ) diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index db1f2a99d3d4..e0157368c62c 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -233,6 +233,7 @@ # View & Shape Ops # "alias", + "alias_copy", "atleast_1d", "atleast_2d", "atleast_3d", @@ -4462,6 +4463,9 @@ def alias(a: TensorLikeType) -> TensorLikeType: return prims.view_of(a) +alias_copy = _make_copy_from_view(alias) + + @register_decomposition(aten.transpose) def transpose(a: TensorLikeType, dim0: int, dim1: int) -> TensorLikeType: _dim0, _dim1 = utils.canonicalize_dims(a.ndim, (dim0, dim1)) # type: ignore[misc] diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 476d85d5de6f..a0b5d91ac67e 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -11584,6 +11584,12 @@ def reference_flatten(input, start_dim=0, end_dim=-1): out_shape = in_shape[:start_dim] + (flatten_bit_dim,) + in_shape[end_dim + 1:] return np.reshape(input, out_shape) + +def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): + yield SampleInput(make_tensor((S,), dtype=dtype, device=device, requires_grad=requires_grad)) + yield SampleInput(make_tensor((), dtype=dtype, device=device, requires_grad=requires_grad)) + + # Operator database (sorted alphabetically) op_db: List[OpInfo] = [ UnaryUfuncInfo('abs', @@ -13087,6 +13093,11 @@ def reference_flatten(input, start_dim=0, end_dim=-1): supports_forward_ad=True, supports_fwgrad_bwgrad=True, sample_inputs_func=sample_inputs_diagonal_scatter), + OpInfo('alias_copy', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + sample_inputs_func=sample_inputs_alias_copy, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True), BinaryUfuncInfo('eq', ref=np.equal, dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), @@ -23224,6 +23235,10 @@ def reference_flatten(input, start_dim=0, end_dim=-1): # # View & Shape OpInfos # + PythonRefInfo( + "_refs.alias_copy", + torch_opinfo_name="alias_copy", + ), PythonRefInfo( "_refs.atleast_1d", torch_opinfo_name="atleast_1d", From c993f1b37fe167c186911885dd680bba52471aeb Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Wed, 5 Jun 2024 15:05:55 +0000 Subject: [PATCH 570/706] Fix edge cases for gather in inductor (#126893) Pull Request resolved: https://github.com/pytorch/pytorch/pull/126893 Approved by: https://github.com/peterbell10 ghstack dependencies: #126876 --- test/inductor/test_torchinductor_opinfo.py | 1 + torch/_inductor/lowering.py | 18 ++++++++++++++++-- .../_internal/common_methods_invocations.py | 4 ++++ 3 files changed, 21 insertions(+), 2 deletions(-) diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index b66c0ce0832f..1d9c733a7302 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -441,6 +441,7 @@ def wrapper_noop_set_seed(op, *args, **kwargs): "cummax", "cummin", "nextafter", + "gather", "_chunk_cat", "constant_pad_nd", } diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 0461cc3683d5..e432fd45cd94 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -2781,18 +2781,29 @@ def gather(x, dim, index, sparse_grad=False): # sparse_grad doesn't affect forward computation, # and backward tracing is taken care of by AOT Autograd assert isinstance(x, TensorBox) + if index.get_numel() == 0: + # Empty index case. Return an empty array with the same shape + return new_empty(x, index.get_size()) + assert index.get_dtype() == torch.int64 size = x.get_size() offset = len(size) == 0 dim = _validate_dim(x, dim, offset) + if offset: + x = expand(x, [1]) + size = [1] + x_loader = x.make_loader() index_loader = index.make_loader() def fn(idx): idx = list(idx) - if len(idx) != 0: - idx[dim] = ops.indirect_indexing(index_loader(idx), size[dim]) + gather_idx = ops.indirect_indexing(index_loader(idx), size[dim]) + if len(idx) == 0: + idx = [gather_idx] + else: + idx[dim] = gather_idx return x_loader(idx) return Pointwise.create( @@ -3272,6 +3283,9 @@ def scatter_reduce_(self, dim: int, index, src, reduce, *, include_self: bool = if isinstance(index, TensorBox) and len(index.get_size()) == 0: index = view(index, [1]) + if index.get_numel() == 0: + return self + dim = _validate_dim(self, dim) self.realize() diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index a0b5d91ac67e..5c32d1a11aff 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -2623,6 +2623,10 @@ def sample_inputs_gather(op_info, device, dtype, requires_grad, **kwargs): make_arg((S,)), 0, torch.tensor([], dtype=torch.uint8, device=device)) + yield SampleInput( + make_arg((S,)), + 0, + torch.tensor([[], []], dtype=torch.uint8, device=device)) # 0D tensor case yield SampleInput( make_arg(()), From 3b73f5de3a022b423a2a90fbdd9109997474b155 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 10 Jun 2024 16:17:16 +0000 Subject: [PATCH 571/706] Revert "Add OpInfo entry for alias_copy (#127232) (#128142)" This reverts commit 04da6aeb61f4d57bf73ed1054dd897abbcceca83. Reverted https://github.com/pytorch/pytorch/pull/128142 on behalf of https://github.com/DanilBaibak due to The changes broke the test_output_match_alias_copy_cpu_complex64 test. ([comment](https://github.com/pytorch/pytorch/pull/128142#issuecomment-2158793878)) --- .../ATen/functorch/BatchRulesDecompositions.cpp | 1 - test/distributed/_tensor/test_dtensor_ops.py | 1 - .../HasDecompTest.test_has_decomposition.expect | 2 ++ test/functorch/test_vmap_registrations.py | 1 - test/onnx/test_fx_op_consistency.py | 4 ---- tools/autograd/gen_variable_type.py | 1 - torch/_decomp/__init__.py | 1 - torch/_inductor/exc.py | 2 +- torch/_refs/__init__.py | 4 ---- .../_internal/common_methods_invocations.py | 15 --------------- 10 files changed, 3 insertions(+), 29 deletions(-) diff --git a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp index a0007aa18a00..3e064d6c39dc 100644 --- a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp +++ b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp @@ -324,7 +324,6 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) { OP_DECOMPOSE(type_as); OP_DECOMPOSE(linalg_diagonal); OP_DECOMPOSE(diagonal_copy); - OP_DECOMPOSE(alias_copy); m.impl("pad", native::pad_symint); m.impl("_pad_circular", native::_pad_circular_symint); OP_DECOMPOSE(swapdims_); diff --git a/test/distributed/_tensor/test_dtensor_ops.py b/test/distributed/_tensor/test_dtensor_ops.py index 07f8bfedc615..83f0bb875167 100644 --- a/test/distributed/_tensor/test_dtensor_ops.py +++ b/test/distributed/_tensor/test_dtensor_ops.py @@ -102,7 +102,6 @@ def wrapped(fn): xfail("addr"), xfail("all"), xfail("allclose"), - xfail("alias_copy"), xfail("amax"), xfail("amin"), xfail("aminmax"), diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index eeee3685e1fb..ad9cf07d7550 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -647,6 +647,8 @@ aten::adaptive_max_pool3d_backward.grad_input aten::addbmm aten::addbmm.out aten::addr_ +aten::alias_copy +aten::alias_copy.out aten::allclose aten::angle aten::angle.out diff --git a/test/functorch/test_vmap_registrations.py b/test/functorch/test_vmap_registrations.py index 737927a60f80..967152945af5 100644 --- a/test/functorch/test_vmap_registrations.py +++ b/test/functorch/test_vmap_registrations.py @@ -25,7 +25,6 @@ } xfail_functorch_batched_decomposition = { - "aten::alias_copy", "aten::diagonal_copy", "aten::is_same_size", "aten::unfold_copy", diff --git a/test/onnx/test_fx_op_consistency.py b/test/onnx/test_fx_op_consistency.py index 6d675d446030..e72c4206d578 100644 --- a/test/onnx/test_fx_op_consistency.py +++ b/test/onnx/test_fx_op_consistency.py @@ -218,10 +218,6 @@ def skip_torchlib_forward_compatibility( dtypes=onnx_test_common.COMPLEX_TYPES, reason=onnx_test_common.reason_dynamo_does_not_support("Addr", "complex64") ), - xfail( - "alias_copy", - reason="OnnxExporterError: Failed to export model", - ), xfail( "allclose", reason=onnx_test_common.reason_dynamo_does_not_support("Allclose") diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index 6abb13d244e9..b9651ea2da80 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -305,7 +305,6 @@ "linalg_eig", "diagonal_copy", "diagonal_scatter", - "alias_copy", "select_backward", "diagonal_backward", "slice_backward", diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py index 7674e5f466a8..e0c7e5b6f49d 100644 --- a/torch/_decomp/__init__.py +++ b/torch/_decomp/__init__.py @@ -261,7 +261,6 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]: aten.addcmul_, aten.addr, aten.affine_grid_generator, - aten.alias_copy, aten.all, aten.aminmax, aten.arange.default, diff --git a/torch/_inductor/exc.py b/torch/_inductor/exc.py index 8a172d8c29b1..27dcc6d8ef2d 100644 --- a/torch/_inductor/exc.py +++ b/torch/_inductor/exc.py @@ -46,7 +46,7 @@ def __init__(self, target, args, kwargs): There is a decomposition available for {target} in torch._decomp.get_decompositions(). Please add this operator to the - `decompositions` list in torch._inductor.decomposition + `decompositions` list in torch._inductor.decompositions """ ) ) diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index e0157368c62c..db1f2a99d3d4 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -233,7 +233,6 @@ # View & Shape Ops # "alias", - "alias_copy", "atleast_1d", "atleast_2d", "atleast_3d", @@ -4463,9 +4462,6 @@ def alias(a: TensorLikeType) -> TensorLikeType: return prims.view_of(a) -alias_copy = _make_copy_from_view(alias) - - @register_decomposition(aten.transpose) def transpose(a: TensorLikeType, dim0: int, dim1: int) -> TensorLikeType: _dim0, _dim1 = utils.canonicalize_dims(a.ndim, (dim0, dim1)) # type: ignore[misc] diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 5c32d1a11aff..edacc3c4023e 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -11588,12 +11588,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1): out_shape = in_shape[:start_dim] + (flatten_bit_dim,) + in_shape[end_dim + 1:] return np.reshape(input, out_shape) - -def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): - yield SampleInput(make_tensor((S,), dtype=dtype, device=device, requires_grad=requires_grad)) - yield SampleInput(make_tensor((), dtype=dtype, device=device, requires_grad=requires_grad)) - - # Operator database (sorted alphabetically) op_db: List[OpInfo] = [ UnaryUfuncInfo('abs', @@ -13097,11 +13091,6 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): supports_forward_ad=True, supports_fwgrad_bwgrad=True, sample_inputs_func=sample_inputs_diagonal_scatter), - OpInfo('alias_copy', - dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), - sample_inputs_func=sample_inputs_alias_copy, - supports_forward_ad=True, - supports_fwgrad_bwgrad=True), BinaryUfuncInfo('eq', ref=np.equal, dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), @@ -23239,10 +23228,6 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): # # View & Shape OpInfos # - PythonRefInfo( - "_refs.alias_copy", - torch_opinfo_name="alias_copy", - ), PythonRefInfo( "_refs.atleast_1d", torch_opinfo_name="atleast_1d", From d22287d1ad3406a547a150b9504ae762e175f8f1 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 10 Jun 2024 16:29:20 +0000 Subject: [PATCH 572/706] Revert "Fix 'get_real_value' on placeholder nodes (#127698)" This reverts commit 19b31d899a78a6806314bcc73b88172dabf0c26e. Reverted https://github.com/pytorch/pytorch/pull/127698 on behalf of https://github.com/clee2000 due to broke (executorch?) internal tests D58295865 ([comment](https://github.com/pytorch/pytorch/pull/127696#issuecomment-2158820093)) --- test/dynamo/test_decorators.py | 16 ---------------- torch/_dynamo/utils.py | 3 --- 2 files changed, 19 deletions(-) diff --git a/test/dynamo/test_decorators.py b/test/dynamo/test_decorators.py index 440872ecc7bc..94b0d5bf3bb6 100644 --- a/test/dynamo/test_decorators.py +++ b/test/dynamo/test_decorators.py @@ -487,22 +487,6 @@ def fn(B): fn(B), torch.compile(fn, backend="eager", fullgraph=True, dynamic=True)(B) ) - def test_assume_constant_result_on_computation_with_graph_input(self): - @torch._dynamo.assume_constant_result - def check(y): - return y[0].item() == 1 - - def fn(x, y): - if check(y): - return x + 2 - else: - return x + 1 - - y = torch.tensor([1]) - x = torch.tensor(1) - - self.assertEqual(fn(x, y), torch.compile(fn)(x, y)) - if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 6da8b514f16b..63f339fe90ec 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -1941,9 +1941,6 @@ def get_real_value(node, tracer): lambda n: get_real_value(n, tracer), ) - if op == "placeholder" and "grapharg" in node.meta: - return node.meta["grapharg"].example - if op == "call_module": nn_module = tracer.output_graph.nn_modules[node.target] if not is_lazy_module(nn_module): From ca561d639b40a9fce088262e2fb35c7dfb61d588 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 10 Jun 2024 16:29:20 +0000 Subject: [PATCH 573/706] Revert "Fix 'get_attr' call in dynamo 'run_node' (#127696)" This reverts commit b741819b0580204e6a6b60c62ce44dacaf7787c8. Reverted https://github.com/pytorch/pytorch/pull/127696 on behalf of https://github.com/clee2000 due to broke (executorch?) internal tests D58295865 ([comment](https://github.com/pytorch/pytorch/pull/127696#issuecomment-2158820093)) --- test/dynamo/test_decorators.py | 22 ---------------------- torch/_dynamo/utils.py | 2 +- 2 files changed, 1 insertion(+), 23 deletions(-) diff --git a/test/dynamo/test_decorators.py b/test/dynamo/test_decorators.py index 94b0d5bf3bb6..890edca40ccc 100644 --- a/test/dynamo/test_decorators.py +++ b/test/dynamo/test_decorators.py @@ -465,28 +465,6 @@ def fn(a, b, c): self.assertEqual(cnt.frame_count, 1) - def test_assume_constant_result_on_user_defined_fn(self): - @torch._dynamo.assume_constant_result - def const_fn(n, s): - return torch.full([n], s) - - def fn(B): - B = const_fn(B.size(0), 13) - X = B * 2 - return X.tolist() - - B_list = [8] * 32 - - B = torch.tensor(B_list, dtype=torch.int32) - torch._dynamo.decorators.mark_static(B, 0) - - torch._dynamo.config.capture_scalar_outputs = True - torch._dynamo.config.capture_dynamic_output_shape_ops = True - - self.assertEqual( - fn(B), torch.compile(fn, backend="eager", fullgraph=True, dynamic=True)(B) - ) - if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 63f339fe90ec..4dfddcc2cdf8 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -1906,7 +1906,7 @@ def make_error_message(e): assert nnmodule is not None return nnmodule(*args, **kwargs) elif op == "get_attr": - return tracer.output_graph.get_submodule(node.target) + return tracer.get_submodule(node.target) elif op == "placeholder": assert "example_value" in node.meta return node.meta["example_value"] From 7b9c5e0e3fb860ed7be4e3e7c74296a452178c00 Mon Sep 17 00:00:00 2001 From: Sheng Fu Date: Mon, 10 Jun 2024 16:48:58 +0000 Subject: [PATCH 574/706] Turn on GraphTransformObserver for inductor (#127962) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The FX graphs for some PT2 models are very complicated, Inductor usually goes through many passes of graph optimization to generate the final FX graph. It’s very difficult to see the change in each pass, and check if the optimized graph is correct and optimal. GraphTransformObserver is an observer listening to all add/erase node events on GraphModule during a graph transform pass, and save the changed nodes. When the pass is done and if there is any change in the graph, GraphTransformObserver will save the SVG files of the input graph and the output graph for that pass. This PR is to enable GraphTransformObserver for inductor. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127962 Approved by: https://github.com/jansel --- .../inductor/test_graph_transform_observer.py | 72 ++++++++++++++++++ torch/_inductor/fx_passes/ddp_fusion.py | 29 ++++--- .../_inductor/fx_passes/group_batch_fusion.py | 10 ++- torch/_inductor/fx_passes/joint_graph.py | 20 +++-- torch/_inductor/fx_passes/post_grad.py | 11 ++- torch/_inductor/fx_passes/pre_grad.py | 31 ++++++-- torch/_inductor/fx_passes/replace_random.py | 7 +- torch/_inductor/pattern_matcher.py | 75 +++++++++++-------- torch/_inductor/utils.py | 4 +- torch/fx/passes/graph_transform_observer.py | 2 + 10 files changed, 201 insertions(+), 60 deletions(-) create mode 100644 test/inductor/test_graph_transform_observer.py diff --git a/test/inductor/test_graph_transform_observer.py b/test/inductor/test_graph_transform_observer.py new file mode 100644 index 000000000000..678458284c4f --- /dev/null +++ b/test/inductor/test_graph_transform_observer.py @@ -0,0 +1,72 @@ +# Owner(s): ["module: inductor"] +import glob +import math +import os +import shutil +import tempfile + +import torch +import torch._dynamo +import torch._inductor.config as inductor_config +from torch._inductor.test_case import run_tests, TestCase +from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FUSED_ATTENTION +from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm +from torch.testing._internal.inductor_utils import HAS_CUDA + +try: + import pydot # noqa: F401 + + HAS_PYDOT = True +except ImportError: + HAS_PYDOT = False + + +HAS_DOT = True if shutil.which("dot") is not None else False + + +class TestGraphTransformObserver(TestCase): + @skipIfRocm + def test_sdpa_rewriter(self): + if not ( + HAS_CUDA and PLATFORM_SUPPORTS_FUSED_ATTENTION and HAS_PYDOT and HAS_DOT + ): + return + + def dot_prod_attention( + query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> torch.Tensor: + """Input tensors assumed to have shape (batch_size, n_head, seq_len, embed_dim)""" + return ( + torch.matmul(query, key.transpose(-2, -1)) + .div(math.sqrt(key.shape[-1])) + .softmax(dim=-1) + .matmul(value) + ) + + log_url = tempfile.mkdtemp() + inductor_config.trace.log_url_for_graph_xform = log_url + inductor_config.force_disable_caches = True + compiled_fn = torch.compile(dot_prod_attention, fullgraph=True) + + tensor_shape = (4, 2, 16, 32) + q = torch.randn(tensor_shape, device="cuda") + k = torch.randn(tensor_shape, device="cuda") + v = torch.randn(tensor_shape, device="cuda") + compiled_fn(q, k, v) + + found_input_svg = False + found_output_svg = False + for filepath_object in glob.glob(log_url + "/*"): + if os.path.isfile(filepath_object): + if filepath_object.endswith("input_graph.svg"): + found_input_svg = True + elif filepath_object.endswith("output_graph.svg"): + found_output_svg = True + + self.assertTrue(found_input_svg) + self.assertTrue(found_output_svg) + + +if __name__ == "__main__": + if IS_LINUX: + run_tests() diff --git a/torch/_inductor/fx_passes/ddp_fusion.py b/torch/_inductor/fx_passes/ddp_fusion.py index 532a546dd4b6..6ef0f71a807c 100644 --- a/torch/_inductor/fx_passes/ddp_fusion.py +++ b/torch/_inductor/fx_passes/ddp_fusion.py @@ -22,9 +22,11 @@ import torch import torch.fx as fx from torch._dynamo.utils import counters +from torch.fx.passes.graph_transform_observer import GraphTransformObserver from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten +from .. import config from ..fx_utils import get_fake_args_kwargs from ..virtualized import V @@ -578,14 +580,19 @@ def schedule_comm_wait(graph: fx.Graph) -> None: def fuse_ddp_communication( graph: fx.Graph, passes: List[Union[Callable[..., None], str]], bucket_size_mb: int ) -> None: - for pa in passes: - if isinstance(pa, str): - func = globals()[pa] - else: - func = pa - if "bucket_size_mb" in { - v.name for v in inspect.signature(func).parameters.values() - }: - func(graph, bucket_size_mb=bucket_size_mb) - else: - func(graph) + for i, pa in enumerate(passes): + with GraphTransformObserver( + graph.owning_module, + f"fuse_ddp_communication_pass_{i}", + config.trace.log_url_for_graph_xform, + ): + if isinstance(pa, str): + func = globals()[pa] + else: + func = pa + if "bucket_size_mb" in { + v.name for v in inspect.signature(func).parameters.values() + }: + func(graph, bucket_size_mb=bucket_size_mb) + else: + func(graph) diff --git a/torch/_inductor/fx_passes/group_batch_fusion.py b/torch/_inductor/fx_passes/group_batch_fusion.py index 7c095841140d..9a9d4cd136da 100644 --- a/torch/_inductor/fx_passes/group_batch_fusion.py +++ b/torch/_inductor/fx_passes/group_batch_fusion.py @@ -19,6 +19,7 @@ import torch from torch._dynamo.utils import counters, optimus_scuba_log from torch._utils_internal import upload_graph +from torch.fx.passes.graph_transform_observer import GraphTransformObserver from .. import config from ..pattern_matcher import ( @@ -1242,5 +1243,10 @@ def group_batch_fusion_passes(graph: torch.fx.Graph, pre_grad=True): if has_fbgemm: fusions += generate_fusion_from_config(fbgemm_fusions, pre_grad=False) - for rule in fusions: - apply_group_batch_fusion(graph, rule) # type: ignore[arg-type] + for i, rule in enumerate(fusions): + with GraphTransformObserver( + graph.owning_module, + f"group_batch_fusion_{i}", + config.trace.log_url_for_graph_xform, + ): + apply_group_batch_fusion(graph, rule) # type: ignore[arg-type] diff --git a/torch/_inductor/fx_passes/joint_graph.py b/torch/_inductor/fx_passes/joint_graph.py index 8358fdd0bd31..ad134decd228 100644 --- a/torch/_inductor/fx_passes/joint_graph.py +++ b/torch/_inductor/fx_passes/joint_graph.py @@ -9,6 +9,7 @@ import torch._guards from torch._inductor.constant_folding import ConstantFolder from torch.fx.experimental.symbolic_shapes import statically_known_true +from torch.fx.passes.graph_transform_observer import GraphTransformObserver from torch.multiprocessing.reductions import StorageWeakRef from .. import config @@ -311,15 +312,21 @@ def joint_graph_passes(graph: torch.fx.GraphModule): lazy_init() count = 0 if config.joint_custom_pre_pass is not None: - config.joint_custom_pre_pass(graph.graph) - count += 1 + with GraphTransformObserver( + graph, "joint_custom_pre_pass", config.trace.log_url_for_graph_xform + ): + config.joint_custom_pre_pass(graph.graph) + count += 1 from .post_grad import remove_noop_ops remove_noop_ops(graph.graph) if config.joint_graph_constant_folding: - constant_fold_uniform_value(graph) + with GraphTransformObserver( + graph, "constant_fold_uniform_value", config.trace.log_url_for_graph_xform + ): + constant_fold_uniform_value(graph) if config.pattern_matcher: for patterns in pass_patterns: @@ -329,8 +336,11 @@ def joint_graph_passes(graph: torch.fx.GraphModule): count += replace_random_passes(graph) if config.joint_custom_post_pass is not None: - config.joint_custom_post_pass(graph.graph) - count += 1 + with GraphTransformObserver( + graph, "joint_custom_post_pass", config.trace.log_url_for_graph_xform + ): + config.joint_custom_post_pass(graph.graph) + count += 1 if count: stable_topological_sort(graph.graph) diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index 3f36c2e7918f..4d1dfe830e01 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -17,6 +17,7 @@ from torch._utils_internal import upload_graph from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq +from torch.fx.passes.graph_transform_observer import GraphTransformObserver from .. import config, ir, pattern_matcher from ..fx_utils import FakeTensorUpdater, get_fake_args_kwargs, get_node_storage @@ -82,7 +83,10 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool): fake_tensor_updater = FakeTensorUpdater(gm.graph) if config.post_grad_custom_pre_pass is not None: - config.post_grad_custom_pre_pass(gm.graph) + with GraphTransformObserver( + gm, "post_grad_custom_pre_pass", config.trace.log_url_for_graph_xform + ): + config.post_grad_custom_pre_pass(gm.graph) if config.pattern_matcher: lazy_init() @@ -116,7 +120,10 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool): ) if config.post_grad_custom_post_pass is not None: - config.post_grad_custom_post_pass(gm.graph) + with GraphTransformObserver( + gm, "post_grad_custom_post_pass", config.trace.log_url_for_graph_xform + ): + config.post_grad_custom_post_pass(gm.graph) stable_topological_sort(gm.graph) diff --git a/torch/_inductor/fx_passes/pre_grad.py b/torch/_inductor/fx_passes/pre_grad.py index 717a46811802..a93c987fe051 100644 --- a/torch/_inductor/fx_passes/pre_grad.py +++ b/torch/_inductor/fx_passes/pre_grad.py @@ -12,6 +12,7 @@ matches_module_pattern, replace_node_module, ) +from torch.fx.passes.graph_transform_observer import GraphTransformObserver from torch.fx.passes.shape_prop import ShapeProp from torch.nn import functional as F from torch.nn.utils.fusion import fuse_conv_bn_eval, fuse_conv_bn_weights @@ -220,7 +221,10 @@ def shape_prop(mod) -> None: efficient_conv_bn_eval_pass.apply(gm.graph) # type: ignore[arg-type] if config.pre_grad_custom_pass is not None: - config.pre_grad_custom_pass(gm.graph) + with GraphTransformObserver( + gm, "pre_grad_custom_pass", config.trace.log_url_for_graph_xform + ): + config.pre_grad_custom_pass(gm.graph) stable_topological_sort(gm.graph) from .quantization import quant_lift_up @@ -261,16 +265,31 @@ def fuse_fx(gm: torch.fx.GraphModule, example_inputs) -> torch.fx.GraphModule: # For linear permute fusion, we need to check input info to identify # and perform proper permutation/transpose ShapeProp(gm, fake_mode=fake_mode).propagate(*example_inputs) - gm = linear_permute_fusion(gm) - gm = permute_linear_fusion(gm) - gm = permute_matmul_fusion(gm) + with GraphTransformObserver( + gm, "linear_permute_fusion", config.trace.log_url_for_graph_xform + ): + gm = linear_permute_fusion(gm) + with GraphTransformObserver( + gm, "permute_linear_fusion", config.trace.log_url_for_graph_xform + ): + gm = permute_linear_fusion(gm) + with GraphTransformObserver( + gm, "permute_matmul_fusion", config.trace.log_url_for_graph_xform + ): + gm = permute_matmul_fusion(gm) # make sure the autograd is disabled. if torch.is_grad_enabled() or not is_cpu: return gm if config.freezing: - gm = remove_identity(gm) - gm = fuse_conv_bn(gm) + with GraphTransformObserver( + gm, "remove_identity", config.trace.log_url_for_graph_xform + ): + gm = remove_identity(gm) + with GraphTransformObserver( + gm, "fuse_conv_bn", config.trace.log_url_for_graph_xform + ): + gm = fuse_conv_bn(gm) return gm diff --git a/torch/_inductor/fx_passes/replace_random.py b/torch/_inductor/fx_passes/replace_random.py index 4265bf7f26bd..c028eb353791 100644 --- a/torch/_inductor/fx_passes/replace_random.py +++ b/torch/_inductor/fx_passes/replace_random.py @@ -3,7 +3,7 @@ import logging import torch - +from torch.fx.passes.graph_transform_observer import GraphTransformObserver from torch.fx.passes.shape_prop import _extract_tensor_metadata from .. import config, inductor_prims from ..pattern_matcher import ( @@ -25,7 +25,10 @@ def replace_random_passes(gm: torch.fx.GraphModule): return 0 count = patterns.apply(gm) - count += fuse_seed_creation_pass(gm.graph) + with GraphTransformObserver( + gm, "fuse_seed_creation_pass", config.trace.log_url_for_graph_xform + ): + count += fuse_seed_creation_pass(gm.graph) return count diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py index b9f4e1e18c93..7c43b23efdd2 100644 --- a/torch/_inductor/pattern_matcher.py +++ b/torch/_inductor/pattern_matcher.py @@ -80,10 +80,12 @@ import torch.utils._pytree as pytree from torch._dispatch.python import enable_python_dispatcher from torch._dynamo.utils import counters +from torch._inductor.config import trace as trace_config from torch._prims_common import is_integer_dtype from torch.fx.experimental.proxy_tensor import make_fx, maybe_disable_fake_tensor_mode from torch.fx.experimental.symbolic_shapes import guard_size_oblivious from torch.fx.immutable_collections import immutable_dict, immutable_list +from torch.fx.passes.graph_transform_observer import GraphTransformObserver from .._functorch import config as functorch_config from .._functorch.aot_autograd import aot_function, make_boxed_func @@ -1649,11 +1651,18 @@ def __init__( def __getitem__(self, item: Tuple[str, torch.fx.node.Target]) -> List[PatternEntry]: return self.patterns[item] - def apply(self, graph: torch.fx.GraphModule) -> int: + def apply(self, gm: torch.fx.GraphModule) -> int: if not self.patterns: return 0 - if isinstance(graph, torch.fx.GraphModule): - graph = graph.graph + if isinstance(gm, torch.fx.GraphModule): + graph = gm.graph + elif isinstance(gm, torch.fx.Graph): + graph = gm + gm = graph.owning_module + else: + raise RuntimeError( + f"The input to PatternMatcherPass must be a GraphModule or a Graph, but got {type(gm)}" + ) if self.prevent_match_across_mutations: if should_compute_mutation_region_ids(graph): compute_mutation_region_ids(graph) @@ -1670,36 +1679,40 @@ def apply(self, graph: torch.fx.GraphModule) -> int: nodes.append(graph.find_nodes(op=op, target=target, sort=False)) if has_call_module: nodes.append(graph.find_nodes(op="call_module", sort=False)) - for node in sorted(itertools.chain.from_iterable(nodes), reverse=True): - target = extract_target(node) - if node.op == "call_module": - if (node.op, target) not in self.patterns: - continue - - # conservatively not applying pattern for cpu input, - # since some of the patterns induce codegen and split nodes. - # Note: we will only skip cpu compute if disable_cpp_codegen=True - if fallback_node_due_to_unsupported_type(node, allow_cpu_inputs=False): - continue + pass_name = self.pass_name if self.pass_name is not None else "pattern_matcher" + with GraphTransformObserver( + gm, pass_name, trace_config.log_url_for_graph_xform + ): + for node in sorted(itertools.chain.from_iterable(nodes), reverse=True): + target = extract_target(node) + if node.op == "call_module": + if (node.op, target) not in self.patterns: + continue - for entry in self.patterns[(node.op, target)]: - if node._erased: - break - m = entry.pattern.match(node) - # pattern match crosses mutation barrier - discard - if ( - self.prevent_match_across_mutations - and is_match(m) - and len(set(map(get_mutation_region_id_partial, m.nodes))) != 1 # type: ignore[possibly-undefined] - ): + # conservatively not applying pattern for cpu input, + # since some of the patterns induce codegen and split nodes. + # Note: we will only skip cpu compute if disable_cpp_codegen=True + if fallback_node_due_to_unsupported_type(node, allow_cpu_inputs=False): continue - if os.environ.get("TORCHINDUCTOR_PATTERN_MATCH_DEBUG") == node.name: - log.warning("%s%s %s %s", node, node.args, m, entry.pattern) - if is_match(m) and entry.extra_check(m): - count += 1 - entry.apply(m, graph, node) # type: ignore[arg-type] - counters["inductor"]["pattern_matcher_count"] += 1 - counters["inductor"]["pattern_matcher_nodes"] += len(m.nodes) + + for entry in self.patterns[(node.op, target)]: + if node._erased: + break + m = entry.pattern.match(node) + # pattern match crosses mutation barrier - discard + if ( + self.prevent_match_across_mutations + and is_match(m) + and len(set(map(get_mutation_region_id_partial, m.nodes))) != 1 # type: ignore[possibly-undefined] + ): + continue + if os.environ.get("TORCHINDUCTOR_PATTERN_MATCH_DEBUG") == node.name: + log.warning("%s%s %s %s", node, node.args, m, entry.pattern) + if is_match(m) and entry.extra_check(m): + count += 1 + entry.apply(m, graph, node) # type: ignore[arg-type] + counters["inductor"]["pattern_matcher_count"] += 1 + counters["inductor"]["pattern_matcher_nodes"] += len(m.nodes) return count def clear(self) -> None: diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index c77576451264..64dc283843f4 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -52,6 +52,7 @@ from torch._dynamo.utils import detect_fake_mode from torch.autograd import DeviceType from torch.autograd.profiler_util import EventList +from torch.fx.passes.graph_transform_observer import GraphTransformObserver from torch.fx.passes.shape_prop import ShapeProp from torch.utils._sympy.functions import CeilDiv, CleanDiv, FloorDiv, ModularIndexing from torch.utils._sympy.symbol import make_symbol, SymT @@ -1397,7 +1398,8 @@ def pass_execution_and_save(func, gm, inp, msg): print(f"Before:\n{gm.graph}", file=f) print(gm.graph, file=before_io) start_time = datetime.now() - func(gm.graph) + with GraphTransformObserver(gm, msg, config.trace.log_url_for_graph_xform): + func(gm.graph) time_elapsed = datetime.now() - start_time # recompile graph stable_topological_sort(gm.graph) diff --git a/torch/fx/passes/graph_transform_observer.py b/torch/fx/passes/graph_transform_observer.py index 83975a930115..503844a97aa9 100644 --- a/torch/fx/passes/graph_transform_observer.py +++ b/torch/fx/passes/graph_transform_observer.py @@ -2,6 +2,7 @@ import os from typing import Optional +from torch.fx._compatibility import compatibility from torch.fx.graph_module import GraphModule from .graph_drawer import FxGraphDrawer @@ -9,6 +10,7 @@ __all__ = ["GraphTransformObserver"] +@compatibility(is_backward_compatible=False) class GraphTransformObserver: __pass_count = 0 From 8e482e909bd35813d045e8e0150ee3f4a973c485 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Mon, 10 Jun 2024 07:12:04 -0700 Subject: [PATCH 575/706] Add some guard to size oblivious has_internal_overlap (#128328) This doesn't actually help on https://github.com/pytorch/pytorch/issues/122477 but I noticed this modest improvement so sure, why not. Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/128328 Approved by: https://github.com/Skylion007 --- aten/src/ATen/MemoryOverlap.cpp | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/MemoryOverlap.cpp b/aten/src/ATen/MemoryOverlap.cpp index 7689934e4113..2e6792d5ca69 100644 --- a/aten/src/ATen/MemoryOverlap.cpp +++ b/aten/src/ATen/MemoryOverlap.cpp @@ -19,7 +19,13 @@ MemOverlap has_internal_overlap(TensorImpl* t) { auto strides = t->sym_strides(); auto sizes = t->sym_sizes(); for (const auto i : c10::irange(strides.size())) { - if (strides[i] == 0 && sizes[i] > 1) { + // NB: The size oblivious test is written very carefully here. When + // unbacked SymInts are involved, we should try to conservatively report + // if memory overlap /could/ happen under some setting of unbacked + // SymInts. Thus, if I have u0 size, we should assume that this has > 1 + // elements (first expression), but if I have a u0 stride, I should NOT + // assume that it is not zero (second expression) + if (TORCH_GUARD_SIZE_OBLIVIOUS(sizes[i].sym_gt(1)) && strides[i] == 0) { return MemOverlap::Yes; } } From ab3a0b192aaa0ac4f31c4bc9d896fb8c379616c4 Mon Sep 17 00:00:00 2001 From: Chirag Pandya Date: Fri, 7 Jun 2024 10:27:36 -0700 Subject: [PATCH 576/706] [RFC] add per-collective timeout value in flight recorder (#128190) Summary: Add timeout value field on every collected record. Test Plan: Unit tests Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/128190 Approved by: https://github.com/wconstab --- test/distributed/test_c10d_nccl.py | 5 ++++- torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp | 3 +++ torch/csrc/distributed/c10d/TraceUtils.h | 10 +++++++++- 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index 21a8a632bade..f45600c5d17d 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -3548,7 +3548,7 @@ def test_short(self, timing_enabled, include_collectives): ) ) ver = t["version"] - self.assertEqual(ver, "2.1") + self.assertEqual(ver, "2.2") pg_config = t["pg_config"] self.assertEqual(len(pg_config), 1) default_pg_info = pg_config["0"] @@ -3577,6 +3577,7 @@ def test_short(self, timing_enabled, include_collectives): self.assertEqual(last["output_sizes"], ((3, 4),)) self.assertEqual(last["output_dtypes"], ["Float"]) self.assertEqual(last["collective_seq_id"], 2) + self.assertEqual(last["timeout_ms"], 600000) now = datetime.now() event_created_time = datetime.fromtimestamp( last["time_created_ns"] / 1000000000 @@ -3661,6 +3662,7 @@ def test_long(self): self.assertEqual(last["input_dtypes"], ["Float"]) self.assertEqual(last["output_sizes"], ((3, 4),)) self.assertEqual(last["output_dtypes"], ["Float"]) + self.assertEqual(last["timeout_ms"], 600000) self.assertEqual(last["collective_seq_id"] - first["collective_seq_id"], 9) @requires_nccl() @@ -3865,6 +3867,7 @@ def test_batched_send_recv(self, op_sizes_per_coalesce, timing_enabled): self.assertTrue(0.001 < duration < 10000, duration) else: self.assertTrue("duration_ms" not in t["entries"][coalesced_op]) + self.assertEqual(t["entries"][coalesced_op]["timeout_ms"], 600000) @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 8adf1e02c1a0..07bbcd5a0af4 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -2356,6 +2356,7 @@ c10::intrusive_ptr ProcessGroupNCCL::initWork( outputs, r->ncclStartEvent_.get(), r->ncclEndEvent_.get(), + options_->timeout, isP2P); } return r; @@ -2966,6 +2967,7 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( {tensor}, nullptr, nullptr, + options_->timeout, /*isP2P=*/true); // TODO(whc) if we want to make the per-p2p-op flightrecorder entries get // their timings/states updated by proxy when the Work obj representing the @@ -2999,6 +3001,7 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( {tensor}, work->ncclStartEvent_.get(), work->ncclEndEvent_.get(), + options_->timeout, /*isP2P=*/true); } diff --git a/torch/csrc/distributed/c10d/TraceUtils.h b/torch/csrc/distributed/c10d/TraceUtils.h index c3b0464cf992..de623d77fe9e 100644 --- a/torch/csrc/distributed/c10d/TraceUtils.h +++ b/torch/csrc/distributed/c10d/TraceUtils.h @@ -8,6 +8,7 @@ #include #include #include +#include #ifdef USE_C10D_NCCL #include @@ -28,7 +29,7 @@ static c10::IValue nccl_comm_key = "nccl_comm_state"; static c10::IValue version_key = "version"; // Update whenever changing contents or formatting of the dump // (minor when adding fields, major when changing existing fields) -static c10::IValue version_val = "2.1"; +static c10::IValue version_val = "2.2"; static c10::IValue pg_config_key = "pg_config"; static c10::IValue record_id_key = "record_id"; static c10::IValue pg_id_key = "pg_id"; @@ -44,6 +45,7 @@ static c10::IValue output_sizes_key = "output_sizes"; static c10::IValue output_dtypes_key = "output_dtypes"; static c10::IValue time_created_key = "time_created_ns"; static c10::IValue duration_key = "duration_ms"; +static c10::IValue timeout_key = "timeout_ms"; static c10::IValue frames_key = "frames"; static c10::IValue state_key = "state"; @@ -461,6 +463,9 @@ struct NCCLTraceBuffer { // was 'enqueued'- not necessarily started c10::time_t time_created_; + // configured timeout for this entry + c10::time_t timeout_ms_; + // Is this a P2P event? bool isP2P_; @@ -508,6 +513,7 @@ struct NCCLTraceBuffer { const std::vector& outputs, Event* start, Event* end, + std::chrono::milliseconds timeout_ms, bool isP2P) { if (!enabled_) { return c10::nullopt; @@ -528,6 +534,7 @@ struct NCCLTraceBuffer { std::move(start), std::move(end), c10::getTime(), + timeout_ms.count(), isP2P}; for (const auto& input : inputs) { @@ -752,6 +759,7 @@ struct NCCLTraceBuffer { ? int64_t(*e.time_discovered_completed_) : c10::IValue()); dict.insert(retired_key, e.retired_); + dict.insert(timeout_key, e.timeout_ms_); dict.insert(is_p2p_key, e.isP2P_); entries.push_back(dict); From 46948300a25e27d002ef3e068c08877abf432102 Mon Sep 17 00:00:00 2001 From: Shengbao Zheng Date: Mon, 10 Jun 2024 17:20:03 +0000 Subject: [PATCH 577/706] [c10d] integrate PMI NCCL initialization to NCCL-PG (#128243) Summary: Move broadcastUniqueID check to NCCLUtils Differential Revision: D58273755 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128243 Approved by: https://github.com/wconstab --- test/distributed/test_c10d_nccl.py | 10 ++-------- torch/csrc/distributed/c10d/NCCLUtils.cpp | 7 +++++++ torch/csrc/distributed/c10d/NCCLUtils.hpp | 1 + torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp | 3 +-- 4 files changed, 11 insertions(+), 10 deletions(-) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index f45600c5d17d..feaa649e5851 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -617,18 +617,12 @@ def test_comm_split_subgroup(self): # rank 0 hasn't split yet, but rank 1 did for the # nocolor... so split count matches rank count coincidentally # in each of the proceses this test spawned! - # when using ncclCommCreateFromRanks() in version 2.21+, - # unused ranks are not included in split - version = torch.cuda.nccl.version() - is_nccl_2_21 = version >= (2, 21) - exp_count = 0 if (is_nccl_2_21 or self.rank == 0) else 1 - self.assertEqual(backend.comm_split_count(), exp_count) + self.assertEqual(backend.comm_split_count(), self.rank) if self.rank == 0: dist.broadcast(tensor, 0, group=ng) # now everyone has split because rank 0 has performed a comm - exp_count = 1 if not is_nccl_2_21 else (1 if self.rank == 0 else 0) - self.assertEqual(backend.comm_split_count(), exp_count) + self.assertEqual(backend.comm_split_count(), 1) self.assertEqual(tensor, original_tensor) @requires_nccl_version((2, 18), "Need NCCL 2.18+ for ncclCommSplit") diff --git a/torch/csrc/distributed/c10d/NCCLUtils.cpp b/torch/csrc/distributed/c10d/NCCLUtils.cpp index e2771641af69..bc820fc1c8d5 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.cpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.cpp @@ -83,6 +83,13 @@ std::shared_ptr NCCLComm::split( } #endif +#ifndef FBCODE_CAFFE2 +bool shouldBroadcastNCCLUniqueID(bool isSendRecvSelf) { + // For point-to-point communication on the same process, don't need broadcast. + return !isSendRecvSelf; +} +#endif + std::string getNcclVersion() { static c10::once_flag ncclGetVersionFlag; static std::string versionString; diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp index 7617f929feb3..9ce25b55dc13 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.hpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp @@ -177,6 +177,7 @@ TORCH_API std::string getNcclVersion(); TORCH_API std::string ncclGetErrorWithVersion(ncclResult_t error); bool nccl_use_nonblocking(); int nccl_nonblocking_timeout(); +bool shouldBroadcastNCCLUniqueID(bool isSendRecvSelf); // Provides additional detail into NCCL error codes based on when these are // thrown in the NCCL codebase. diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 07bbcd5a0af4..bb9198f22200 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -2044,8 +2044,7 @@ std::shared_ptr ProcessGroupNCCL::getNCCLComm( C10D_NCCL_CHECK(ncclGetUniqueId(&ncclID), c10::nullopt); } - // For point-to-point communication on the same process, don't need broadcast. - if (!isSendRecvSelf) { + if (shouldBroadcastNCCLUniqueID(isSendRecvSelf)) { // Broadcast so that each process can have a unique NCCL ID auto timeStarted = std::chrono::steady_clock::now(); broadcastUniqueNCCLID(&ncclID, singleP2POp, deviceKey, p2pRank); From 08d038f8a8587e4e87ffd33c9c982bd294498d27 Mon Sep 17 00:00:00 2001 From: Menglu Yu Date: Mon, 10 Jun 2024 18:03:40 +0000 Subject: [PATCH 578/706] [PT2] Fix a typo and lint problem (#128258) Summary: Titled Test Plan: see signal Differential Revision: D58310169 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128258 Approved by: https://github.com/dshi7, https://github.com/Yuzhen11 --- torch/_inductor/fx_passes/decompose_mem_bound_mm.py | 6 +++--- torch/_inductor/fx_passes/split_cat.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/torch/_inductor/fx_passes/decompose_mem_bound_mm.py b/torch/_inductor/fx_passes/decompose_mem_bound_mm.py index 66f0afed9e7d..dba2f62e7d6f 100644 --- a/torch/_inductor/fx_passes/decompose_mem_bound_mm.py +++ b/torch/_inductor/fx_passes/decompose_mem_bound_mm.py @@ -20,12 +20,12 @@ min_first_dimension_decomposition = MIN_FIRST_DIMENSION_DECOMPOSITION max_other_dimention_decomposition = MAX_OTHER_DIMENSION_DECOMPOSITION -if "decompose_mem_bound_mm" in config.post_grad_fusion_options: +if "decompose_mm_pass" in config.post_grad_fusion_options: min_first_dimension_decomposition = config.post_grad_fusion_options[ - "decompose_mem_bound_mm" + "decompose_mm_pass" ].get("min_first_dimension_decomposition", MIN_FIRST_DIMENSION_DECOMPOSITION) max_other_dimention_decomposition = config.post_grad_fusion_options[ - "decompose_mem_bound_mm" + "decompose_mm_pass" ].get("max_other_dimention_decomposition", MAX_OTHER_DIMENSION_DECOMPOSITION) diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py index b5014a8780f1..8a2c571ee612 100644 --- a/torch/_inductor/fx_passes/split_cat.py +++ b/torch/_inductor/fx_passes/split_cat.py @@ -81,7 +81,7 @@ ) -def construct_pattern_matcher_pass(pass_name: str) -> PatternMatcherPass: +def construct_pattern_matcher_pass(pass_name: str): """ Return the specific pattern_matcher_pass given the pass name. """ From 83941482f74efa926dc700026e6b0e195205c06a Mon Sep 17 00:00:00 2001 From: Andrea Frittoli Date: Mon, 10 Jun 2024 18:10:55 +0000 Subject: [PATCH 579/706] Add docstring for the torch.distributed.elastic.utils.distributed.get_free_port function (#128133) Fixes: #127914 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128133 Approved by: https://github.com/H-Huang --- torch/distributed/elastic/utils/distributed.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/torch/distributed/elastic/utils/distributed.py b/torch/distributed/elastic/utils/distributed.py index a1ad1acca796..04ff2fe680f1 100644 --- a/torch/distributed/elastic/utils/distributed.py +++ b/torch/distributed/elastic/utils/distributed.py @@ -114,6 +114,24 @@ def _check_full_rank(store, world_size, timeout): def get_free_port(): + """ + Returns an unused port on localhost. + + This function finds an unused port on localhost by opening to socket to bind + to a port and then closing it. + + Returns: + int: an unused port on localhost + + Example: + >>> # xdoctest: +SKIP("Nondeterministic") + >>> get_free_port() + 63976 + + ..note: + The port returned by :func:`get_free_port` is not reserved and may be + taken by another process after this function returns. + """ sock = get_socket_with_port() with closing(sock): return sock.getsockname()[1] From 136bdb96cb648fd8ebce8f13a5ecd1bece16bb4f Mon Sep 17 00:00:00 2001 From: Aaron Enye Shi Date: Mon, 10 Jun 2024 18:12:32 +0000 Subject: [PATCH 580/706] Update Kineto submodule with fix to test_basic_chrome_trace (#128333) Summary: We've updated the sort_index in Kineto chrome traces to support device ids up to 16 devices. This should make chrome trace rows be ordered in the same way as CUDA. We need to update the unit test as well. Test Plan: Ran locally the changing test: ``` $ buck2 test 'fbcode//mode/opt' fbcode//caffe2/test:test_profiler_cuda -- --exact 'caffe2/test:test_profiler_cuda - test_basic_chrome_trace (profiler.test_profiler.TestProfiler)' File changed: fbcode//caffe2/third_party/kineto.submodule.txt Buck UI: https://www.internalfb.com/buck2/f4fd1e9a-99f1-4422-aeed-b54903c64146 Test UI: https://www.internalfb.com/intern/testinfra/testrun/16888498639845776 Network: Up: 5.4KiB Down: 8.6KiB (reSessionID-0329120e-7fa2-4bc0-b539-7e58058f8fce) Jobs completed: 6. Time elapsed: 1:01.2s. Tests finished: Pass 2. Fail 0. Fatal 0. Skip 0. Build failure 0 ``` Differential Revision: D58362964 Pulled By: aaronenyeshi Pull Request resolved: https://github.com/pytorch/pytorch/pull/128333 Approved by: https://github.com/Skylion007 --- test/profiler/test_profiler.py | 5 ++++- third_party/kineto | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/test/profiler/test_profiler.py b/test/profiler/test_profiler.py index 38e83d448fdd..ca0481922e4f 100644 --- a/test/profiler/test_profiler.py +++ b/test/profiler/test_profiler.py @@ -1723,9 +1723,12 @@ def _validate_basic_json(self, traceEvents, cuda_available=False): gpu_value = traceEvent.get("args", {}).get("labels", None) if gpu_value and "GPU" in gpu_value: gpu_dict[gpu_value] += 1 + # Max PID offset is 5M, based from pytorch/kineto include header: + # https://github.com/pytorch/kineto/blob/8681ff11e1fa54da39023076c5c43eddd87b7a8a/libkineto/include/output_base.h#L35 + kExceedMaxPid = 5000000 self.assertTrue( traceEvents[i + 1]["args"]["sort_index"] - == 0x1000000 + int(gpu_value.split()[1]) + == kExceedMaxPid + int(gpu_value.split()[1]) ) # TODO add checking gpu count if cpuOnly_ is true or not diff --git a/third_party/kineto b/third_party/kineto index be1317644c68..8681ff11e1fa 160000 --- a/third_party/kineto +++ b/third_party/kineto @@ -1 +1 @@ -Subproject commit be1317644c68b4bfc4646024a6b221066e430031 +Subproject commit 8681ff11e1fa54da39023076c5c43eddd87b7a8a From fa8ec8e718999656bef956625467293b4978a087 Mon Sep 17 00:00:00 2001 From: Sam Larsen Date: Thu, 6 Jun 2024 16:18:49 +0000 Subject: [PATCH 581/706] [dynamo] handle hashable exceptions in trace_rules lookup (#128078) Summary: Found during user empathy day when attempting to hash a fractions.Fraction object before it was fully constructed. See https://github.com/pytorch/pytorch/issues/128075 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128078 Approved by: https://github.com/anijain2305 --- torch/_dynamo/trace_rules.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 585ddb04eda1..4d3f5b11edb0 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -3524,7 +3524,11 @@ def lookup_inner( # The rules defined in `torch_name_rule_map` mainly includes two parts: # - Manually defined rules for any functions. # - The list of torch in graph functions. - if not hashable(obj): + try: + can_hash = hashable(obj) + except Exception: + can_hash = False + if not can_hash: if reasons is not None: reasons.add("obj is not hashable") return None From 093a4ff5f859ccbbd8ba62dd189f76e5faadfb04 Mon Sep 17 00:00:00 2001 From: angelayi Date: Mon, 10 Jun 2024 18:39:33 +0000 Subject: [PATCH 582/706] [export] FIx unflattener for preserving modules containing unused inputs (#128260) Currently unflattener fails if the module its preserving the module signature for contains unused inputs/outputs. This also fixes unflattener issues in D57829276. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128260 Approved by: https://github.com/pianpwk --- test/export/test_unflatten.py | 49 +++++++++++++++++++++++++++++++++++ torch/export/unflatten.py | 36 ++++++++++++++++++++----- 2 files changed, 78 insertions(+), 7 deletions(-) diff --git a/test/export/test_unflatten.py b/test/export/test_unflatten.py index 3940cde45234..618155e34622 100644 --- a/test/export/test_unflatten.py +++ b/test/export/test_unflatten.py @@ -312,6 +312,55 @@ def forward(self, x): export_module.module(), unflattened, (torch.randn((2, 3)),) ) + def test_unflatten_preserve_with_alias(self): + class M1(torch.nn.Module): + def forward(self, x, y): + return x + y, x + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.m1 = M1() + + def forward(self, x, y): + return self.m1(x, y)[0] + + ep = torch.export.export( + M(), + (torch.randn(3, 3), torch.randn(3, 3)), + preserve_module_call_signature=("m1",), + ) + unflattened = unflatten(ep) + self.compare_outputs( + ep.module(), unflattened, (torch.randn(3, 3), torch.randn(3, 3)) + ) + + def test_unflatten_preserve_with_unused_input(self): + class M1(torch.nn.Module): + def forward(self, x, a, b): + return x + a, b + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.m1 = M1() + + def forward(self, x, y): + a, b = torch.topk(y, 2) + return self.m1(x, a, b)[0] + + ep = torch.export.export( + M(), + (torch.randn(2), torch.randn(5)), + preserve_module_call_signature=("m1",), + strict=False, + ) + print(ep.graph) + ep.graph.eliminate_dead_code() + print(ep.graph) + unflattened = unflatten(ep) + self.compare_outputs(ep.module(), unflattened, (torch.randn(2), torch.randn(5))) + def test_unflatten_wrong_input(self): class Mod(torch.nn.Module): def __init__(self): diff --git a/torch/export/unflatten.py b/torch/export/unflatten.py index 4de95dad2c8d..11075058a0e9 100644 --- a/torch/export/unflatten.py +++ b/torch/export/unflatten.py @@ -731,14 +731,20 @@ def __init__( ) if isinstance(arg, ConstantArgument): continue - flat_arg_node.meta = copy.copy(self.seen_nodes[arg.name].meta) - self.node_to_placeholder[self.seen_nodes[arg.name]] = flat_arg_node + + if arg.name in self.seen_nodes: + flat_arg_node.meta = copy.copy(self.seen_nodes[arg.name].meta) + self.node_to_placeholder[ + self.seen_nodes[arg.name] + ] = flat_arg_node with self.parent.graph.inserting_before(self.parent_call_module): input_nodes: List[Optional[torch.fx.Node]] = [] for input in signature.inputs: if isinstance(input, ConstantArgument) and input.value is None: input_nodes.append(None) + elif input.name not in self.seen_nodes: + input_nodes.append(None) else: assert isinstance(input, (TensorArgument, SymIntArgument)) input_nodes.append( @@ -801,18 +807,32 @@ def finalize_outputs(self): if signature is not None and self.parent is not None: for output in signature.outputs: if isinstance(output, (TensorArgument, SymIntArgument)): - orig_outputs.append(self.seen_nodes[output.name]) + if output.name in self.seen_nodes: + orig_outputs.append(self.seen_nodes[output.name]) + else: + orig_outputs.append(None) else: raise RuntimeError( f"Unsupported data type for output node: {output}" ) + def get_actual_output_node(output): + if output is None: + return None + + seen_node = self.seen_nodes[output.name] + if seen_node in self.node_map: + return self.node_map[seen_node] + elif seen_node in self.node_to_placeholder: + return self.node_to_placeholder[seen_node] + else: + raise RuntimeError( + f"Could not find output node {output}. Graph: {self.graph}" + ) + tree_out_node = _generate_unflatten( self.module, - tuple( - self.node_map[self.seen_nodes[output.name]] - for output in orig_outputs - ), + tuple(get_actual_output_node(output) for output in orig_outputs), signature.out_spec, ) parent_out: Optional[torch.fx.Node] = _generate_flatten( @@ -852,6 +872,8 @@ def finalize_outputs(self): self.parent.node_map[orig_outputs[0]] = parent_out else: for i, orig_output in enumerate(orig_outputs): + if orig_output is None: + continue # Use Proxy to record getitem access. proxy_out = torch.fx.Proxy(parent_out)[i].node # type: ignore[index] proxy_out.meta["val"] = orig_output.meta.get("val") From db2fa7b827cdc5b49d60aa094268583a2ab7cf92 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 10 Jun 2024 18:42:33 +0000 Subject: [PATCH 583/706] Revert "[export] FIx unflattener for preserving modules containing unused inputs (#128260)" This reverts commit 093a4ff5f859ccbbd8ba62dd189f76e5faadfb04. Reverted https://github.com/pytorch/pytorch/pull/128260 on behalf of https://github.com/angelayi due to breaking windows test ([comment](https://github.com/pytorch/pytorch/pull/128260#issuecomment-2159050726)) --- test/export/test_unflatten.py | 49 ----------------------------------- torch/export/unflatten.py | 36 +++++-------------------- 2 files changed, 7 insertions(+), 78 deletions(-) diff --git a/test/export/test_unflatten.py b/test/export/test_unflatten.py index 618155e34622..3940cde45234 100644 --- a/test/export/test_unflatten.py +++ b/test/export/test_unflatten.py @@ -312,55 +312,6 @@ def forward(self, x): export_module.module(), unflattened, (torch.randn((2, 3)),) ) - def test_unflatten_preserve_with_alias(self): - class M1(torch.nn.Module): - def forward(self, x, y): - return x + y, x - - class M(torch.nn.Module): - def __init__(self): - super().__init__() - self.m1 = M1() - - def forward(self, x, y): - return self.m1(x, y)[0] - - ep = torch.export.export( - M(), - (torch.randn(3, 3), torch.randn(3, 3)), - preserve_module_call_signature=("m1",), - ) - unflattened = unflatten(ep) - self.compare_outputs( - ep.module(), unflattened, (torch.randn(3, 3), torch.randn(3, 3)) - ) - - def test_unflatten_preserve_with_unused_input(self): - class M1(torch.nn.Module): - def forward(self, x, a, b): - return x + a, b - - class M(torch.nn.Module): - def __init__(self): - super().__init__() - self.m1 = M1() - - def forward(self, x, y): - a, b = torch.topk(y, 2) - return self.m1(x, a, b)[0] - - ep = torch.export.export( - M(), - (torch.randn(2), torch.randn(5)), - preserve_module_call_signature=("m1",), - strict=False, - ) - print(ep.graph) - ep.graph.eliminate_dead_code() - print(ep.graph) - unflattened = unflatten(ep) - self.compare_outputs(ep.module(), unflattened, (torch.randn(2), torch.randn(5))) - def test_unflatten_wrong_input(self): class Mod(torch.nn.Module): def __init__(self): diff --git a/torch/export/unflatten.py b/torch/export/unflatten.py index 11075058a0e9..4de95dad2c8d 100644 --- a/torch/export/unflatten.py +++ b/torch/export/unflatten.py @@ -731,20 +731,14 @@ def __init__( ) if isinstance(arg, ConstantArgument): continue - - if arg.name in self.seen_nodes: - flat_arg_node.meta = copy.copy(self.seen_nodes[arg.name].meta) - self.node_to_placeholder[ - self.seen_nodes[arg.name] - ] = flat_arg_node + flat_arg_node.meta = copy.copy(self.seen_nodes[arg.name].meta) + self.node_to_placeholder[self.seen_nodes[arg.name]] = flat_arg_node with self.parent.graph.inserting_before(self.parent_call_module): input_nodes: List[Optional[torch.fx.Node]] = [] for input in signature.inputs: if isinstance(input, ConstantArgument) and input.value is None: input_nodes.append(None) - elif input.name not in self.seen_nodes: - input_nodes.append(None) else: assert isinstance(input, (TensorArgument, SymIntArgument)) input_nodes.append( @@ -807,32 +801,18 @@ def finalize_outputs(self): if signature is not None and self.parent is not None: for output in signature.outputs: if isinstance(output, (TensorArgument, SymIntArgument)): - if output.name in self.seen_nodes: - orig_outputs.append(self.seen_nodes[output.name]) - else: - orig_outputs.append(None) + orig_outputs.append(self.seen_nodes[output.name]) else: raise RuntimeError( f"Unsupported data type for output node: {output}" ) - def get_actual_output_node(output): - if output is None: - return None - - seen_node = self.seen_nodes[output.name] - if seen_node in self.node_map: - return self.node_map[seen_node] - elif seen_node in self.node_to_placeholder: - return self.node_to_placeholder[seen_node] - else: - raise RuntimeError( - f"Could not find output node {output}. Graph: {self.graph}" - ) - tree_out_node = _generate_unflatten( self.module, - tuple(get_actual_output_node(output) for output in orig_outputs), + tuple( + self.node_map[self.seen_nodes[output.name]] + for output in orig_outputs + ), signature.out_spec, ) parent_out: Optional[torch.fx.Node] = _generate_flatten( @@ -872,8 +852,6 @@ def get_actual_output_node(output): self.parent.node_map[orig_outputs[0]] = parent_out else: for i, orig_output in enumerate(orig_outputs): - if orig_output is None: - continue # Use Proxy to record getitem access. proxy_out = torch.fx.Proxy(parent_out)[i].node # type: ignore[index] proxy_out.meta["val"] = orig_output.meta.get("val") From 9cab5987bdeb66df8efbc581b3469bfe300e168c Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Sun, 9 Jun 2024 09:48:47 -0700 Subject: [PATCH 584/706] Introduce int_oo (#127693) In a previous life, we used sympy.oo to represent the lower/upper bounds of integer ranges. Later, we changed this to be sys.maxsize - 1 for a few reasons: (1) sometimes we do tests on a value being exactly sys.maxsize, and we wanted to avoid a data dependent guard in this case, (2) sympy.oo corresponds to floating point infinity, so you get incorrect types for value ranges with oo, and (3) you can do slightly better reasoning if you assume that input sizes fall within representable 64-bit integer range. After working in the sys.maxsize regime for a bit, I've concluded that this was actually a bad idea. Specifically, the problem is that you end up with sys.maxsize in your upper bound, and then whenever you do any sort of size-increasing computation like size * 2, you end up with 2 * sys.maxsize, and you end up doing a ton of arbitrary precision int computation that is totally unnecessary. A symbolic bound is better. But especially after #126905, we can't go back to using sympy.oo, because that advertises that it's not an integer, and now your ValueRanges is typed incorrectly. So what do we do? We define a new numeric constant `int_oo`, which is like `sympy.oo` but it advertises `is_integer`. **test/test_sympy_utils.py** describes some basic properties of the number, and **torch/utils/_sympy/numbers.py** has the actual implementation. The rest of the changes of the PR are working out the implications of this change. I'll give more commentary as inline comments. Fixes https://github.com/pytorch/pytorch/issues/127396 Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/127693 Approved by: https://github.com/lezcano ghstack dependencies: #126905 --- test/dynamo/test_exc.py | 9 +- test/dynamo/test_export.py | 1 - test/dynamo/test_misc.py | 12 +- test/export/test_export.py | 15 +- test/onnx/test_fx_to_onnx_with_onnxruntime.py | 4 - test/test_dynamic_shapes.py | 11 + test/test_proxy_tensor.py | 4 +- test/test_sympy_utils.py | 70 ++++ torch/_decomp/decompositions.py | 9 +- ...runtime_assertions_for_constraints_pass.py | 5 +- torch/_export/serde/serialize.py | 11 +- torch/_inductor/graph.py | 14 +- torch/export/dynamic_shapes.py | 30 +- torch/fx/experimental/symbolic_shapes.py | 83 ++-- torch/fx/passes/runtime_assert.py | 5 +- torch/utils/_sympy/functions.py | 124 ++++-- torch/utils/_sympy/interp.py | 27 +- torch/utils/_sympy/numbers.py | 394 ++++++++++++++++++ torch/utils/_sympy/value_ranges.py | 63 ++- 19 files changed, 746 insertions(+), 145 deletions(-) create mode 100644 torch/utils/_sympy/numbers.py diff --git a/test/dynamo/test_exc.py b/test/dynamo/test_exc.py index 953e8ecd0a35..b7b17ed4a1dd 100644 --- a/test/dynamo/test_exc.py +++ b/test/dynamo/test_exc.py @@ -253,7 +253,6 @@ def fn(x, shape): ==> (>= 0 s1) ==> (>= 0 s2) ==> (>= 0 s3) - ==> (>= 9223372036854775806 s0) Failed Source Expressions: ==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""", @@ -287,14 +286,14 @@ def fn(x, shape): Model: ==> L['shape'][0]: 1 ==> L['shape'][1]: 1 - ==> L['shape'][2]: 2 + ==> L['shape'][2]: 0 ==> L['x'].size()[0]: 3 ==> L['x'].storage_offset(): 0 ==> L['x'].stride()[0]: 1 ==> s0: 3 ==> s1: 1 ==> s2: 1 - ==> s3: 2 + ==> s3: 0 Assertions: ==> (== 0 L['x'].storage_offset()) @@ -318,10 +317,6 @@ def fn(x, shape): ==> (== L['shape'][2] s3) ==> (== L['x'].size()[0] s0) ==> (> s0 0) - ==> (>= 9223372036854775806 s0) - ==> (>= 9223372036854775807 s1) - ==> (>= 9223372036854775807 s2) - ==> (>= 9223372036854775807 s3) Failed Source Expressions: ==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""", diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py index dbf983faabb7..776e8ef85cbf 100644 --- a/test/dynamo/test_export.py +++ b/test/dynamo/test_export.py @@ -3473,7 +3473,6 @@ def forward(self, pred, x): ] false_guard_code = [ "Ne(cast_symbool_to_symint_guardless(L['pred']), 1)", - "-9223372036854775808 <= cast_symbool_to_symint_guardless(L['pred'])", ] test_symbool_guards( f, diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 02f7c68aa1a9..68c38089c53a 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -9309,7 +9309,7 @@ def test_shape_env_equal_create_symbolic_sizes_strides_storage_offset(self): > Left: {0: 0, 1: 1, 2: s1, 3: s0} > Right: {0: 0, 1: 1} ==> var_to_range: values don't match. - > Left: {s0: VR[2, 9223372036854775806], s1: VR[2, 9223372036854775806]} + > Left: {s0: VR[2, int_oo], s1: VR[2, int_oo]} > Right: {} ==> var_to_sources: values don't match. > Left: {s0: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=, idx=0)], s1: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=, idx=1)]} @@ -9343,7 +9343,7 @@ def test_shape_env_equal_unbacked(self): > Left: 2 > Right: 0 ==> var_to_range: values don't match. - > Left: {u0: VR[-9223372036854775808, 9223372036854775807], u1: VR[0, 1], zuf0: VR[-oo, oo]} + > Left: {u0: VR[-int_oo, int_oo], u1: VR[0, 1], zuf0: VR[-oo, oo]} > Right: {} """, ) @@ -9420,8 +9420,8 @@ def test_shape_env_equal_evaluate_expr_replacement(self): > Left: {s0: 3} > Right: {} ==> var_to_range: values don't match. - > Left: {s0: VR[3, 3], s1: VR[2, 9223372036854775806]} - > Right: {s0: VR[2, 9223372036854775806], s1: VR[2, 9223372036854775806]} + > Left: {s0: VR[3, 3], s1: VR[2, int_oo]} + > Right: {s0: VR[2, int_oo], s1: VR[2, int_oo]} """, ) self._replay_and_check(main) @@ -9458,8 +9458,8 @@ def test_shape_env_equal_evaluate_expr_refinement(self): > Left: {_assert, ge, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} > Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} ==> var_to_range: values don't match. - > Left: {s0: VR[3, 9223372036854775806], s1: VR[2, 9223372036854775806]} - > Right: {s0: VR[2, 9223372036854775806], s1: VR[2, 9223372036854775806]} + > Left: {s0: VR[3, int_oo], s1: VR[2, int_oo]} + > Right: {s0: VR[2, int_oo], s1: VR[2, int_oo]} """, ) self._replay_and_check(main) diff --git a/test/export/test_export.py b/test/export/test_export.py index 19acbbca39f1..c3458ff8003a 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -201,6 +201,19 @@ def forward(self, x): dynamic_shapes={"x": {0: dim_x}}, ) + def test_export_slice_maxsize(self): + class Slice(torch.nn.Module): + def forward(self, *args): + return torch.ops.aten.slice.Tensor(*args) + + inp = (torch.rand((10, 3, 224, 224)), 0, 0, 9223372036854775807) + dynamic_shapes = (({0: Dim("dim")}, None, None, None),) + torch.export.export( + Slice(), + inp, + dynamic_shapes=dynamic_shapes, + ) + def test_export_constraints_error(self): class ConflictingConstraints(torch.nn.Module): def forward(self, x): @@ -5183,7 +5196,7 @@ def forward(self, x): } export(f, (inputs,), dynamic_shapes=dynamic_shapes) - def test_disable_forced_specializations(self): + def test_disable_forced_specializations_ok(self): # check that _disable_forced_specializations and _allow_complex_guards_as_runtime_asserts flags # both behave correctly, avoiding forced specializations and deferring to runtime. # case 1: modulo guards diff --git a/test/onnx/test_fx_to_onnx_with_onnxruntime.py b/test/onnx/test_fx_to_onnx_with_onnxruntime.py index 0f0e01bc0dc2..f8154a149b41 100644 --- a/test/onnx/test_fx_to_onnx_with_onnxruntime.py +++ b/test/onnx/test_fx_to_onnx_with_onnxruntime.py @@ -633,10 +633,6 @@ def forward(self, x): func, (torch.randn(3, 4),) ) - @pytorch_test_common.xfail_if_model_type_is_exportedprogram( - error_message="Unsupported FX nodes: {'call_function': ['aten._assert_async.msg']}.", - reason="https://github.com/pytorch/pytorch/issues/112622", - ) def test_operator_with_scalar_output(self): class Foo(torch.nn.Module): def forward(self, x, y): diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index 3b47f12198d5..b064d1896f0e 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -381,6 +381,17 @@ def test_size_expressions(self): self.assertTrue(str(expand_x.shape[1]), str(x.shape[0])) self.assertTrue(str(expand_x.shape[1]), str(result.shape[0])) + def test_floordiv_static(self): + shape_env = ShapeEnv() + s0 = create_symint(shape_env, 8) + # This was extracted from + # python test/inductor/test_cuda_cpp_wrapper.py -k + # DynamicShapesCudaWrapperCudaTests.test_insignificant_strides_cuda_dynamic_shapes_cuda_wrapper + bool(s0 % 2 == 0) + bool(s0 % (s0 // 2) == 0) + bool(2 * (s0 // 2) == s0) + self.assertTrue(statically_known_true(s0 // (s0 // 2) == 2)) + def test_numel(self): shape_env = ShapeEnv() x = create_symbolic_tensor("x", torch.randn(5), shape_env) diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 04483ffba0fc..3985eea7d5b9 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1201,7 +1201,9 @@ def f(src_tokens): batch_size = 4 src_tokens = torch.randint(1, vocab_size, (batch_size, prompt_size)) gm = make_fx(f, tracing_mode="symbolic")(src_tokens) - self.assertEqual(len(gm.shape_env.guards), 0) + # Guards to rule out batch_size == sys.maxsize (wobbling between 2 and + # 1 ok) + self.assertEqual(len(gm.shape_env.guards), 1) @unittest.skipIf(not HAS_CUDA, 'CUDA-only test') def test_cpu_scalar_cuda(self): diff --git a/test/test_sympy_utils.py b/test/test_sympy_utils.py index 8b16b2c620fd..569c06291331 100644 --- a/test/test_sympy_utils.py +++ b/test/test_sympy_utils.py @@ -1,6 +1,7 @@ # Owner(s): ["oncall: pt2"] import itertools +import math import sys import sympy @@ -19,6 +20,7 @@ from torch.utils._sympy.reference import ReferenceAnalysis, PythonReferenceAnalysis from torch.utils._sympy.interp import sympy_interp from torch.utils._sympy.singleton_int import SingletonInt +from torch.utils._sympy.numbers import int_oo, IntInfinity, NegativeIntInfinity from sympy.core.relational import is_ge, is_le, is_gt, is_lt import functools import torch.fx as fx @@ -122,6 +124,74 @@ def generate_range(vals): yield ValueRanges(a1, a2) +class TestNumbers(TestCase): + def test_int_infinity(self): + self.assertIsInstance(int_oo, IntInfinity) + self.assertIsInstance(-int_oo, NegativeIntInfinity) + self.assertTrue(int_oo.is_integer) + # is tests here are for singleton-ness, don't use it for comparisons + # against numbers + self.assertIs(int_oo + int_oo, int_oo) + self.assertIs(int_oo + 1, int_oo) + self.assertIs(int_oo - 1, int_oo) + self.assertIs(-int_oo - 1, -int_oo) + self.assertIs(-int_oo + 1, -int_oo) + self.assertIs(-int_oo + (-int_oo), -int_oo) + self.assertIs(-int_oo - int_oo, -int_oo) + self.assertIs(1 + int_oo, int_oo) + self.assertIs(1 - int_oo, -int_oo) + self.assertIs(int_oo * int_oo, int_oo) + self.assertIs(2 * int_oo, int_oo) + self.assertIs(int_oo * 2, int_oo) + self.assertIs(-1 * int_oo, -int_oo) + self.assertIs(-int_oo * int_oo, -int_oo) + self.assertIs(2 * -int_oo, -int_oo) + self.assertIs(-int_oo * 2, -int_oo) + self.assertIs(-1 * -int_oo, int_oo) + self.assertIs(int_oo / 2, sympy.oo) + self.assertIs(-(-int_oo), int_oo) # noqa: B002 + self.assertIs(abs(int_oo), int_oo) + self.assertIs(abs(-int_oo), int_oo) + self.assertIs(int_oo ** 2, int_oo) + self.assertIs((-int_oo) ** 2, int_oo) + self.assertIs((-int_oo) ** 3, -int_oo) + self.assertEqual(int_oo ** -1, 0) + self.assertEqual((-int_oo) ** -1, 0) + self.assertIs(int_oo ** int_oo, int_oo) + self.assertTrue(int_oo == int_oo) + self.assertFalse(int_oo != int_oo) + self.assertTrue(-int_oo == -int_oo) + self.assertFalse(int_oo == 2) + self.assertTrue(int_oo != 2) + self.assertFalse(int_oo == sys.maxsize) + self.assertTrue(int_oo >= sys.maxsize) + self.assertTrue(int_oo >= 2) + self.assertTrue(int_oo >= -int_oo) + + def test_relation(self): + self.assertIs(sympy.Add(2, int_oo), int_oo) + self.assertFalse(-int_oo > 2) + + def test_lt_self(self): + self.assertFalse(int_oo < int_oo) + self.assertIs(min(-int_oo, -4), -int_oo) + self.assertIs(min(-int_oo, -int_oo), -int_oo) + + def test_float_cast(self): + self.assertEqual(float(int_oo), math.inf) + self.assertEqual(float(-int_oo), -math.inf) + + def test_mixed_oo_int_oo(self): + # Arbitrary choice + self.assertTrue(int_oo < sympy.oo) + self.assertFalse(int_oo > sympy.oo) + self.assertTrue(sympy.oo > int_oo) + self.assertFalse(sympy.oo < int_oo) + self.assertIs(max(int_oo, sympy.oo), sympy.oo) + self.assertTrue(-int_oo > -sympy.oo) + self.assertIs(min(-int_oo, -sympy.oo), -sympy.oo) + + class TestValueRanges(TestCase): @parametrize("fn", UNARY_OPS) @parametrize("dtype", ("int", "float")) diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index 7c9d342ea0f0..7ebc69462fa1 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -734,6 +734,11 @@ def slice_forward( end: Optional[int] = None, step: int = 1, ): + from torch.fx.experimental.symbolic_shapes import ( + guard_size_oblivious, + statically_known_true, + ) + ndim = self.dim() if ndim == 0: raise RuntimeError("slice() cannot be applied to a 0-dim tensor.") @@ -760,7 +765,9 @@ def slice_forward( if end_val < start_val: end_val = start_val - elif end_val > sizes[dim]: + elif statically_known_true(end_val == sys.maxsize) or guard_size_oblivious( + end_val > sizes[dim] + ): end_val = sizes[dim] storage_offset = self.storage_offset() + start_val * strides[dim] diff --git a/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py b/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py index 44f0ea270212..e3bfb0f3de55 100644 --- a/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py +++ b/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py @@ -10,6 +10,7 @@ import torch import torch.fx from torch.utils._sympy.value_ranges import ValueRanges +from torch.utils._sympy.numbers import int_oo from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols from torch.fx.passes.infra.pass_base import PassBase, PassResult @@ -23,9 +24,9 @@ class InputDim(NamedTuple): def _convert_to_int(val): # Convert simple sympy Integers into concrete int - if val == sympy.oo: + if val in (sympy.oo, int_oo): return math.inf - if val == -sympy.oo: + if val in (-sympy.oo, -int_oo): return -math.inf if isinstance(val, sympy.Integer): return int(val) diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index f8fdc1011b52..ff729ddb3c5c 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -42,6 +42,7 @@ from torch.utils import _pytree as pytree from torch.utils._pytree import treespec_dumps, treespec_loads from torch.utils._sympy.value_ranges import ValueRanges +from torch.utils._sympy.numbers import int_oo from .schema import ( # type: ignore[attr-defined] Argument, @@ -321,9 +322,9 @@ def deserialize_torch_artifact(serialized: Union[Dict[str, Any], Tuple[Any, ...] def _sympy_int_to_int(val: sympy.Expr, adjust: str): # Convert simple sympy Integers into concrete int - if val == sympy.oo: + if val in (sympy.oo, int_oo): return math.inf - if val == -sympy.oo: + if val in (-sympy.oo, -int_oo): return -math.inf if isinstance(val, sympy.Integer): return int(val) @@ -346,9 +347,9 @@ def _sympy_int_to_int(val: sympy.Expr, adjust: str): def _int_to_sympy_int(val) -> sympy.Expr: # Convert concrete int into simple sympy Integers if val == math.inf: - return sympy.oo + return int_oo if val == -math.inf: - return -sympy.oo + return -int_oo return sympy.Integer(val) @@ -1826,7 +1827,7 @@ def deserialize( self.symbol_name_to_range = {} if symbol_name_to_range: for k, vr in symbol_name_to_range.items(): - lower = int(vr.lower) + lower = vr.lower if vr.upper >= 2: # max is >= 2, not sym bool range lower = max(2, lower) self.symbol_name_to_range[k] = symbolic_shapes.ValueRanges(_int_to_sympy_int(lower), vr.upper) diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index f2bdf22e2d96..19e81b236ad9 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -42,6 +42,7 @@ SymTypes, ) from torch.utils._mode_utils import no_dispatch +from torch.utils._sympy.numbers import int_oo from . import config, ir from .codegen.common import ( @@ -1427,18 +1428,21 @@ def format_buffers(): vr = shape_env.var_to_range[i0] if not shape_env._default_unspecified_value_range().issubset(vr): - def convert(s): + def is_convertible(s): + if s in (int_oo, -int_oo): + return False try: - return int(s) + int(s) + return True except TypeError: - return None + return False - if (lower := convert(vr.lower)) is not None: + if is_convertible(vr.lower): self.register_buffer( ir.AssertScalar(i0 >= vr.lower, f"{i0} >= {vr.lower}"), set_name=True, ) - if (upper := convert(vr.upper)) is not None: + if is_convertible(vr.upper): self.register_buffer( ir.AssertScalar(i0 <= vr.upper, f"{i0} <= {vr.upper}"), set_name=True, diff --git a/torch/export/dynamic_shapes.py b/torch/export/dynamic_shapes.py index a5ce066faa47..8572e069f536 100644 --- a/torch/export/dynamic_shapes.py +++ b/torch/export/dynamic_shapes.py @@ -1,5 +1,4 @@ # mypy: allow-untyped-defs -import builtins import dataclasses import inspect import sys @@ -41,9 +40,11 @@ class _Dim(type): @staticmethod def readable(name, min_, max_): + from torch.utils._sympy.numbers import int_oo + if min_ == 2: min_ = None - if max_ == sys.maxsize - 1: + if max_ == int_oo: max_ = None if min_ is None and max_ is None: return f"Dim('{name}')" @@ -140,6 +141,11 @@ def min(self): # TODO(avik): use sympy value range analysis instead? from sympy import Integer + from torch.utils._sympy.numbers import int_oo + + if self.root.min is -int_oo: # type: ignore[attr-defined] + return -int_oo # fn not needed cuz increasing + _min_symint = self.fn(Integer(self.root.min)) # type: ignore[attr-defined] root = self.root # type: ignore[attr-defined] assert _min_symint >= 0, ( @@ -155,6 +161,11 @@ def max(self): # TODO(avik): use sympy value range analysis instead? from sympy import Integer + from torch.utils._sympy.numbers import int_oo + + if self.root.max is int_oo: # type: ignore[attr-defined] + return int_oo # fn not needed cuz increasing + _max_symint = self.fn(Integer(self.root.max)) # type: ignore[attr-defined] root = self.root # type: ignore[attr-defined] assert _max_symint <= sys.maxsize - 1, ( @@ -190,8 +201,10 @@ def Dim(name: str, *, min: Optional[int] = None, max: Optional[int] = None): Returns: A type that can be used in dynamic shape specifications for tensors. """ + from torch.utils._sympy.numbers import int_oo + _min = 0 if min is None else min - _max = sys.maxsize - 1 if max is None else builtins.min(max, sys.maxsize - 1) + _max = int_oo if max is None else max assert _max > _min, f"Cannot create Dim with inconsistent min={min}, max={max}" dim = _Dim(name, (int,), {"min": _min, "max": _max}) dim.__module__ = getattr( @@ -269,10 +282,11 @@ class _Constraint(_ConstraintTarget, metaclass=_ConstraintFactory): def _clone_with_range(self, lower=0, upper=None): # Import sympy locally from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint + from torch.utils._sympy.numbers import int_oo from torch.utils._sympy.value_ranges import ValueRanges if upper is None: - upper = sys.maxsize - 1 + upper = int_oo constraint_range = StrictMinMaxConstraint( vr=self.constraint_range.vr & ValueRanges(lower=lower, upper=upper), @@ -503,15 +517,14 @@ def dynamic_dim(t: torch.Tensor, index: int, debug_name: Optional[str] = None): # Import sympy locally from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint + from torch.utils._sympy.numbers import int_oo from torch.utils._sympy.value_ranges import ValueRanges return _create_constraint( weakref.ref(t), id(t), index, - StrictMinMaxConstraint( - vr=ValueRanges(lower=0, upper=sys.maxsize - 1), warn_only=False - ), + StrictMinMaxConstraint(vr=ValueRanges(lower=0, upper=int_oo), warn_only=False), debug_name=debug_name, ) @@ -725,6 +738,7 @@ def to_constraint(dim, tensor, i): import sympy from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint + from torch.utils._sympy.numbers import int_oo from torch.utils._sympy.solve import try_solve from torch.utils._sympy.value_ranges import ValueRanges @@ -799,7 +813,7 @@ def root_value(): constraint = dynamic_dim(tensor, i, debug_name=dim.__name__) if dim.min != 0: constraint = constraint >= dim.min - if dim.max != sys.maxsize - 1: + if dim.max != int_oo: constraint = constraint <= dim.max return constraint diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 0e05d88f6756..3852dc44e7ea 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -65,6 +65,7 @@ FloorDiv, Mod, PythonMod, IsNonOverlappingAndDenseIndicator, CleanDiv, FloorToInt, CeilToInt ) from torch.utils._sympy.solve import try_solve +from torch.utils._sympy.numbers import int_oo from torch.utils._sympy.value_ranges import bound_sympy, SymPyValueRangeAnalysis, ValueRanges, ValueRangeError from torch.utils._sympy.singleton_int import SingletonInt from torch.utils._traceback import format_frame, CapturedTraceback @@ -871,9 +872,9 @@ def constrain_range(a, *, min: Optional[int], max: Optional[int] = None): for N=1. """ if min is None: - min = -sys.maxsize - 1 + min = -int_oo if max is None: - max = sys.maxsize - 1 + max = int_oo if max < min: raise ValueError( @@ -1382,6 +1383,7 @@ def cast_symbool_to_symint_guardless(symbool: torch.SymBool) -> torch.SymInt: 'PythonMod': operator.mod, 'FloorDiv': operator.floordiv, 'TrueDiv': operator.truediv, + 'PowByNatural': operator.pow, 'IsNonOverlappingAndDenseIndicator': eval_is_non_overlapping_and_dense, 'floor': math.floor, 'ceiling': math.ceil, @@ -1994,7 +1996,7 @@ def _check_same_range(c, dim): (dim.min < 2 and c.get("min", 2) == 2) or dim.min == c.get("min", 2) ) # let pass if analysis min = 2 and specified min = 0/1 - and dim.max == c.get("max", sys.maxsize - 1) + and dim.max == c.get("max", int_oo) ) # 1) newly introduced roots @@ -2017,7 +2019,7 @@ def _check_same_range(c, dim): modulus, remainder = sympy.polys.polytools.div(c["eq"], root) c_min = c.get("min", 2) min_ = math.ceil((c_min - remainder) / modulus) - c_max = c.get("max", sys.maxsize - 1) + c_max = c.get("max", int_oo) max_ = math.floor((c_max - remainder) / modulus) # create result & dim results[str(root)] = {"min": min_, "max": max_} @@ -2765,7 +2767,7 @@ def _constrain_range_for_size(self, a: sympy.Symbol, min: Optional[int] = None, if min is None: min = 0 if max is None: - max = sys.maxsize - 1 + max = int_oo if max < min: raise ValueError( @@ -4094,7 +4096,7 @@ def issue_guard(guard: ShapeGuard) -> None: assert sources bounds = [] - if r.lower != -sympy.oo: + if r.lower not in (-sympy.oo, -int_oo): if any(is_dim(source) for source in sources): self.dim_constraints.add(sympy.Ge(symbol, r.lower)) # Only print lower bound in simplified mode if it is not the @@ -4102,14 +4104,7 @@ def issue_guard(guard: ShapeGuard) -> None: if not _simplified or r.lower != self._default_value_range().lower: bounds.append(str(r.lower)) bounds.append(source_ref(sources[0])) - # NB: This looks like an off-by-one error but it's not: the - # upper bound may be sys.maxsize - 1 because we intentionally - # exclude sys.maxsize from our bounds to deal with direct - # == INT_MAX guards, but it's still dumb to actually test it. - # Note that you can be off by a pretty large constant and it - # won't matter because sizes in practice will be no where near - # the 64-bit limit. - if r.upper != sympy.oo and r.upper < sys.maxsize - 1: + if r.upper not in (sympy.oo, int_oo): if any(is_dim(source) for source in sources): self.dim_constraints.add(sympy.Le(symbol, r.upper)) # nontrivial upper bound is always interesting @@ -4121,9 +4116,8 @@ def issue_guard(guard: ShapeGuard) -> None: constraints = symbol_to_constraints[symbol] for c in constraints: if isinstance(c, StrictMinMaxConstraint): - # NB: By default, we have a restrictive range - # 2 <= s0 <= sys.maxsize - 1. But export users generally - # expect to be able to specify nice ranges like [0, oo] + # TODO: With int_oo, I think this condition is a noop + # now if not (c.vr & self._default_value_range()).issubset(r): source = sources[0] @@ -4196,9 +4190,9 @@ def issue_guard(guard: ShapeGuard) -> None: # Reason: '_maybe_evaluate_static' may eliminate guards based on the # refined value ranges. for sym, vr in self.var_to_range.items(): - if vr.lower != -sympy.oo: + if vr.lower not in (-sympy.oo, -int_oo): self._add_target_expr(sympy.Le(vr.lower, sym)) - if vr.upper != sympy.oo: + if vr.upper not in (sympy.oo, int_oo): self._add_target_expr(sympy.Le(sym, vr.upper)) # Before validating, populate the input of the validator with the @@ -4330,9 +4324,14 @@ def bound_sympy(self, expr: sympy.Expr, size_oblivious: bool = False) -> ValueRa var_to_range = {x: self.var_to_range.get(x, None) for x in expr.free_symbols} if size_oblivious: # Clamp values of size-like variables + # NB: discarding the old upper bound in intentional, per + # https://github.com/pytorch/pytorch/pull/123675 for x in self.size_like & var_to_range.keys(): if var_to_range[x] is not None: - var_to_range[x] = ValueRanges(2, sys.maxsize - 1) + # NB: do NOT set upper to 2 ** 48, we're using this solely + # to determine if we can do size-like replacement, the + # upper bound is irrelevant here + var_to_range[x] = ValueRanges(2, int_oo) assert var_to_range[x].is_int return bound_sympy(expr, var_to_range) @@ -4450,18 +4449,25 @@ def _maybe_evaluate_static( vr = self._default_unspecified_value_range() if size_oblivious and k in self.size_like: lower = max(2, vr.lower) + # Clamping size-oblivious to some quantity below sys.maxsize + # helps us determine that f(u0) != sys.maxsize, which is a + # test that is looking for sys.maxsize as a sentinel, but you + # don't really want to worry about it for unbacked SymInts. + # This is similar to the flavor where size oblivious omits + # 0/1, it changes semantics but in a benign way. + upper = min(2 ** 48, vr.upper) # This is a bit dodgy: what this means is that there was a # size-like unbacked symbol whose upper bound < 2. This # causes... problems. - if lower <= vr.upper: - vr = ValueRanges(lower, vr.upper) + if lower <= upper: + vr = ValueRanges(lower, upper) else: lower = vr.lower # Don't do anything if we don't have a nontrivial lower bound # Also don't do anything if we asked only to simplify unbacked # SymInt if ( - lower < (-sys.maxsize - 1) // 2 or + lower is -int_oo or (unbacked_only and k in self.var_to_val) or not vr.is_int ): @@ -4717,21 +4723,6 @@ def _set_replacement(self, a: "sympy.Symbol", tgt: "sympy.Expr", msg: str) -> No if a in self.var_to_range: src_bound = self.var_to_range[a] - # If you have x in [2, maxint], then 2*x in [4, 2*maxint]. - # But we don't really care that the max bound says we can - # go beyond the maximum integer size, because we aren't - # using bigints anyway. Arguably, ValueRanges should know - # to do this truncation automaticaly (to avoid doing - # bigint compute in range analysis), but right now it doesn't - # so we need to get rid of some unnecessary precision. - int_range = ValueRanges(-sys.maxsize - 1, sys.maxsize - 1) - - def issubset(x, y): - if x.is_int and y.is_int: - return (x & int_range).issubset(y & int_range) - else: - return x.issubset(y) - # First, refine the value range of a based on the computed value range # of tgt. This is always OK to do, even if we decide not to do the # substitution in the end. This might be a no-op, if a already has @@ -4744,7 +4735,7 @@ def issubset(x, y): # - the source bound non-trivially improves over what we get out of # the existing bounds. # - the replacement is univariate and we can invert the tgt expression - if not issubset(tgt_bound, src_bound) and len(tgt.free_symbols) == 1: + if not tgt_bound.issubset(src_bound) and len(tgt.free_symbols) == 1: b = next(iter(tgt.free_symbols)) # Try to invert the equality r = try_solve(sympy.Eq(a, tgt), b, floordiv_inequality=False) @@ -4759,7 +4750,7 @@ def issubset(x, y): b_bound = ValueRanges(CeilToInt(rat_b_bound.lower), FloorToInt(rat_b_bound.upper)) self._update_var_to_range(b, b_bound) tgt_bound = self.bound_sympy(tgt) - assert issubset(tgt_bound, src_bound) + assert tgt_bound.issubset(src_bound) # TODO: Should we propagate size-like-ness? # @@ -4797,13 +4788,13 @@ def issubset(x, y): # - If the variable is unbacked, only substitute if the substitution # would preserve the bounds also under size-like-ness conditions. - if not issubset(tgt_bound, src_bound): + if not tgt_bound.issubset(src_bound): self.log.debug("skipped set_replacement %s = %s (%s) [%s not subset of %s]", a, tgt, msg, tgt_bound, src_bound) return elif a in self.size_like: tgt_bound_so = self.bound_sympy(tgt, size_oblivious=True) src_bound_so = self.bound_sympy(a, size_oblivious=True) - if not issubset(tgt_bound_so, src_bound_so): + if not tgt_bound_so.issubset(src_bound_so): self.log.debug("skipped set_replacement %s = %s (%s) " "[%s not subset of %s (size-oblivious conditions)]", a, tgt, msg, tgt_bound_so, src_bound_so) return @@ -4888,6 +4879,7 @@ def _smart_symbol_sort(x): has_only_ephemeral_sources = ( x in self.var_to_sources and all(s.is_ephemeral() for s in self.var_to_sources[x]) ) + # NB: size_hint is int, not sympy.Expr, do not use int_oo here size = self.size_hint(x, allow_none=True) or sys.maxsize name = x.name # 1 puts ephemeral sourced symbols first when sorting in reverse @@ -4984,15 +4976,12 @@ def trivial_solve(lhs, rhs): return # See: Note - On 0/1 specialization - # NB: sys.maxsize is NOT allowed for sizes, because we use MAX_INT - # as a sentinel sometimes. Your sizevar isn't going to be - # anywhere near the max 64-bit integer anyway. def _default_value_range(self) -> ValueRanges: lower = 2 if self.specialize_zero_one else 0 - return ValueRanges(lower, sys.maxsize - 1) + return ValueRanges(lower, int_oo) def _default_unspecified_value_range(self) -> ValueRanges: - return ValueRanges(-sys.maxsize - 1, sys.maxsize) + return ValueRanges(-int_oo, int_oo) @_lru_cache def _simplify_floor_div(self, expr): diff --git a/torch/fx/passes/runtime_assert.py b/torch/fx/passes/runtime_assert.py index 66b8fbe29d9f..d1d206eff63e 100644 --- a/torch/fx/passes/runtime_assert.py +++ b/torch/fx/passes/runtime_assert.py @@ -65,7 +65,7 @@ def insert_deferred_runtime_asserts( ): assert len(node.args) == 1 nodes_that_already_have_sym_constraint_range.add( - (node.args[0], node.kwargs["min"], node.kwargs["max"]) + (node.args[0], node.kwargs.get("min"), node.kwargs.get("max")) ) if ( node.op == "call_function" @@ -86,6 +86,7 @@ def insert_deferred_runtime_asserts( InnerTensorKey, ) from torch.utils._sympy.interp import sympy_interp + from torch.utils._sympy.numbers import int_oo from torch.utils._sympy.reference import PythonReferenceAnalysis # TODO: Request simplification on runtime asserts before emitting them @@ -367,6 +368,8 @@ def go(node, keypath): # (refinement should not be necessary once runtime # asserts cause refinement, but that's NYI) def convert(s): + if s in (int_oo, -int_oo): + return None try: return int(s) except TypeError: diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index 86a5b32aabb9..0d7c5a784c63 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -6,6 +6,8 @@ import sympy from sympy import S +from .numbers import int_oo + __all__ = [ "FloorDiv", "ModularIndexing", @@ -101,6 +103,15 @@ def eval(cls, base, divisor): # makes it difficult to check the types. if divisor.is_zero: raise ZeroDivisionError("division by zero") + if base in (int_oo, -int_oo, sympy.oo, -sympy.oo) and divisor in ( + int_oo, + -int_oo, + sympy.oo, + -sympy.oo, + ): + return sympy.nan + if base is sympy.nan or divisor is sympy.nan: + return sympy.nan if base.is_zero: return sympy.S.Zero @@ -108,6 +119,23 @@ def eval(cls, base, divisor): return base if base.is_integer and divisor == -1: return sympy.Mul(base, -1) + if ( + isinstance(base, sympy.Number) + and isinstance(divisor, sympy.Number) + and ( + base in (int_oo, -int_oo, sympy.oo, -sympy.oo) + or divisor in (int_oo, -int_oo, sympy.oo, -sympy.oo) + ) + ): + r = float(base) / float(divisor) + if r == math.inf: + return int_oo + elif r == -math.inf: + return -int_oo + elif math.isnan(r): + return sympy.nan + else: + return sympy.Integer(math.floor(r)) if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer): return sympy.Integer(int(base) // int(divisor)) if isinstance(base, FloorDiv): @@ -353,10 +381,10 @@ class CeilToInt(sympy.Function): @classmethod def eval(cls, number): # assert number.is_integer is not True, number - if number == sympy.oo: - return sympy.Integer(sys.maxsize - 1) - if number == -sympy.oo: - return sympy.Integer(-sys.maxsize - 1) + if number in (sympy.oo, int_oo): + return int_oo + if number in (-sympy.oo, -int_oo): + return -int_oo if isinstance(number, sympy.Number): return sympy.Integer(math.ceil(float(number))) @@ -367,10 +395,10 @@ class FloorToInt(sympy.Function): @classmethod def eval(cls, number): # assert number.is_integer is not True, number - if number == sympy.oo: - return sympy.Integer(sys.maxsize - 1) - if number == -sympy.oo: - return sympy.Integer(-sys.maxsize - 1) + if number in (sympy.oo, int_oo): + return int_oo + if number in (-sympy.oo, int_oo): + return -int_oo if isinstance(number, sympy.Number): return sympy.Integer(math.floor(float(number))) @@ -419,6 +447,7 @@ def safe_pow(base, exp): return sign * _safe_pow(base, exp) +# Prevent people from overflowing pow def _safe_pow(base, exponent): if exponent < 0: raise ValueError("Exponent must be non-negative.") @@ -427,17 +456,20 @@ def _safe_pow(base, exponent): return 1 half_exp = safe_pow(base, exponent // 2) - if half_exp > sys.maxsize - 1: - return sys.maxsize - 1 + if half_exp is int_oo: + return int_oo + + # TODO: microoptimization is to avoid overflowing into arbitrary precision + # and detect overflow prior to doing operations result = half_exp * half_exp - if result > sys.maxsize - 1: - return sys.maxsize - 1 + if result > sys.maxsize: + return int_oo if exponent % 2 == 1: result *= base - if result > sys.maxsize - 1: - return sys.maxsize - 1 + if result > sys.maxsize: + return int_oo return result @@ -447,14 +479,20 @@ class PowByNatural(sympy.Function): @classmethod def eval(cls, base, exp): - if isinstance(base, sympy.Number) and isinstance(exp, sympy.Number): - return sympy.Integer(safe_pow(base, exp)) + if isinstance(base, sympy.Integer) and isinstance(exp, sympy.Integer): + r = safe_pow(base, exp) + if r in (-int_oo, int_oo): + return r + return sympy.Integer(r) if isinstance(exp, sympy.Integer): - # Translate power into iterated multiplication - r = sympy.Integer(1) - for _ in range(int(exp)): - r *= base - return r + # Rely on regular sympy Pow for this (note that iterated + # multiplication turns into a Pow anyway, you can't escape!!) + return sympy.Pow(base, exp) + if exp in (int_oo, sympy.oo): + if base.is_nonnegative: + return int_oo + elif base.is_negative: + return sympy.zoo # this is apparently what (-2)**sympy.oo does # NB: do NOT translate into sympy.Pow, we will lose knowledge that exp # is a natural number if we do @@ -467,6 +505,11 @@ class FloatPow(sympy.Function): @classmethod def eval(cls, base, exp): + # NB: These test sympy.Number, not sympy.Float, because: + # - Sometimes we may have sympy.oo or int_oo, and that's not a Float + # (but coerces to math.Inf) + # - Sometimes Float(0.0) will unpredictably decay to Integer(0), + # but we should still accept it in floatey contexts if isinstance(base, sympy.Number) and isinstance(exp, sympy.Number): return sympy.Float(float(base) ** float(exp)) # NB: do not do any nontrivial reasoning @@ -510,7 +553,18 @@ def eval(cls, base, divisor): if divisor.is_zero: raise ZeroDivisionError("division by zero") - if isinstance(base, sympy.Number) and isinstance(divisor, sympy.Number): + if ( + isinstance(base, sympy.Number) + and isinstance(divisor, sympy.Number) + and ( + base in (int_oo, -int_oo, sympy.oo, -sympy.oo) + or divisor in (int_oo, -int_oo, sympy.oo, -sympy.oo) + ) + ): + # Don't have to worry about precision here, you're getting zero or + # inf from the division + return sympy.Float(float(base) / float(divisor)) + if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer): return sympy.Float(int(base) / int(divisor)) @@ -567,10 +621,10 @@ class TruncToInt(sympy.Function): @classmethod def eval(cls, number): # assert number.is_integer is not True, number - if number == sympy.oo: - return sympy.Integer(sys.maxsize - 1) - if number == -sympy.oo: - return sympy.Integer(-sys.maxsize - 1) + if number in (sympy.oo, int_oo): + return int_oo + if number in (-sympy.oo, -int_oo): + return -int_oo if isinstance(number, sympy.Number): return sympy.Integer(math.trunc(float(number))) @@ -583,7 +637,11 @@ class RoundToInt(sympy.Function): def eval(cls, number): # assert number.is_integer is not True, number - if isinstance(number, sympy.Float): + if number is sympy.oo: + return int_oo + if number is -sympy.oo: + return -int_oo + if isinstance(number, sympy.Number): return sympy.Integer(round(float(number), 0)) @@ -610,7 +668,7 @@ class RoundDecimal(sympy.Function): def eval(cls, number, ndigits): # assert number.is_integer is not True, number - if isinstance(number, sympy.Float) and isinstance(ndigits, sympy.Integer): + if isinstance(number, sympy.Number) and isinstance(ndigits, sympy.Integer): return sympy.Float(round(float(number), int(ndigits))) @@ -625,6 +683,10 @@ def eval(cls, number): if isinstance(number, sympy.Integer): return sympy.Float(int(number)) + if number is int_oo: + return sympy.oo + if number is -int_oo: + return -sympy.oo def make_opaque_unary_fn(name): @@ -655,7 +717,11 @@ def eval(cls, a): # weird objects but ask silly questions, get silly answers except OverflowError: return getattr(sympy, name)(a) - elif a in [sympy.oo, -sympy.oo, sympy.zoo, -sympy.zoo]: + elif a in [sympy.oo, -sympy.oo, sympy.zoo, -sympy.zoo, int_oo, -int_oo]: + if a is int_oo: + a = sympy.oo + if a is -int_oo: + a = -sympy.oo return getattr(sympy, name)(a) return None diff --git a/torch/utils/_sympy/interp.py b/torch/utils/_sympy/interp.py index 640b991cd104..36ff6fc23d4a 100644 --- a/torch/utils/_sympy/interp.py +++ b/torch/utils/_sympy/interp.py @@ -9,6 +9,7 @@ """ import functools +import logging from typing import Any, Dict, Union import sympy @@ -37,6 +38,9 @@ ) +log = logging.getLogger(__name__) + + # TODO: Dedupe this with SYMPY_INTERP @@ -157,11 +161,18 @@ def sympy_interp( else: handler_name = handlers()[expr.func] handler = getattr(analysis, handler_name) - if handler_name in ASSOCIATIVE_OPS: - assert len(args) > 1 - acc = handler(args[0], args[1]) - for i in range(2, len(args)): - acc = handler(acc, args[i]) - return acc - else: - return handler(*args) + try: + if handler_name in ASSOCIATIVE_OPS: + assert len(args) > 1 + acc = handler(args[0], args[1]) + for i in range(2, len(args)): + acc = handler(acc, args[i]) + log.debug("%s(%s) -> %s", handler_name, args, acc) + return acc + else: + r = handler(*args) + log.debug("%s(%s) -> %s", handler_name, args, r) + return r + except Exception: + log.warning("failed while executing %s(%s)", handler_name, args) + raise diff --git a/torch/utils/_sympy/numbers.py b/torch/utils/_sympy/numbers.py new file mode 100644 index 000000000000..89dac14fddf3 --- /dev/null +++ b/torch/utils/_sympy/numbers.py @@ -0,0 +1,394 @@ +import mpmath.libmp as mlib # type: ignore[import-untyped] +import sympy +from sympy import Expr +from sympy.core.decorators import _sympifyit +from sympy.core.expr import AtomicExpr +from sympy.core.numbers import Number +from sympy.core.parameters import global_parameters +from sympy.core.singleton import S, Singleton + + +class IntInfinity(Number, metaclass=Singleton): + r"""Positive integer infinite quantity. + + Integer infinity is a value in an extended integers which + is greater than all other integers. We distinguish it from + sympy's existing notion of infinity in that it reports that + it is_integer. + + Infinity is a singleton, and can be accessed by ``S.IntInfinity``, + or can be imported as ``int_oo``. + """ + + # NB: We can't actually mark this as infinite, as integer and infinite are + # inconsistent assumptions in sympy. We also report that we are complex, + # different from sympy.oo + + is_integer = True + is_commutative = True + is_number = True + is_extended_real = True + is_comparable = True + is_extended_positive = True + is_prime = False + + # Ensure we get dispatched to before plain numbers + _op_priority = 100.0 + + __slots__ = () + + def __new__(cls): + return AtomicExpr.__new__(cls) + + def _sympystr(self, printer): + return "int_oo" + + def _eval_subs(self, old, new): + if self == old: + return new + + # We could do these, not sure about it + """ + def _eval_evalf(self, prec=None): + return Float('inf') + + def evalf(self, prec=None, **options): + return self._eval_evalf(prec) + """ + + @_sympifyit("other", NotImplemented) + def __add__(self, other): + if isinstance(other, Number) and global_parameters.evaluate: + if other is S.NegativeInfinity: + return S.NegativeInfinity + if other in (S.NegativeIntInfinity, S.NaN): + return S.NaN + return self + return Number.__add__(self, other) + + __radd__ = __add__ + + @_sympifyit("other", NotImplemented) + def __sub__(self, other): + if isinstance(other, Number) and global_parameters.evaluate: + if other is S.Infinity: + return S.NegativeInfinity + if other in (S.IntInfinity, S.NaN): + return S.NaN + return self + return Number.__sub__(self, other) + + @_sympifyit("other", NotImplemented) + def __rsub__(self, other): + return (-self).__add__(other) + + @_sympifyit("other", NotImplemented) + def __mul__(self, other): + if isinstance(other, Number) and global_parameters.evaluate: + if other.is_zero or other is S.NaN: + return S.NaN + if other.is_extended_positive: + return self + return S.NegativeIntInfinity + return Number.__mul__(self, other) + + __rmul__ = __mul__ + + @_sympifyit("other", NotImplemented) + def __truediv__(self, other): + if isinstance(other, Number) and global_parameters.evaluate: + if other in ( + S.Infinity, + S.IntInfinity, + S.NegativeInfinity, + S.NegativeIntInfinity, + S.NaN, + ): + return S.NaN + if other.is_extended_nonnegative: + return S.Infinity # truediv produces float + return S.NegativeInfinity # truediv produces float + return Number.__truediv__(self, other) + + def __abs__(self): + return S.IntInfinity + + def __neg__(self): + return S.NegativeIntInfinity + + def _eval_power(self, expt): + if expt.is_extended_positive: + return S.IntInfinity + if expt.is_extended_negative: + return S.Zero + if expt is S.NaN: + return S.NaN + if expt is S.ComplexInfinity: + return S.NaN + if expt.is_extended_real is False and expt.is_number: + from sympy.functions.elementary.complexes import re + + expt_real = re(expt) + if expt_real.is_positive: + return S.ComplexInfinity + if expt_real.is_negative: + return S.Zero + if expt_real.is_zero: + return S.NaN + + return self ** expt.evalf() + + def _as_mpf_val(self, prec): + return mlib.finf + + def __hash__(self): + return super().__hash__() + + def __eq__(self, other): + return other is S.IntInfinity + + def __ne__(self, other): + return other is not S.IntInfinity + + def __gt__(self, other): + if other is S.Infinity: + return sympy.false # sympy.oo > int_oo + elif other is S.IntInfinity: + return sympy.false # consistency with sympy.oo + else: + return sympy.true + + def __ge__(self, other): + if other is S.Infinity: + return sympy.false # sympy.oo > int_oo + elif other is S.IntInfinity: + return sympy.true # consistency with sympy.oo + else: + return sympy.true + + def __lt__(self, other): + if other is S.Infinity: + return sympy.true # sympy.oo > int_oo + elif other is S.IntInfinity: + return sympy.false # consistency with sympy.oo + else: + return sympy.false + + def __le__(self, other): + if other is S.Infinity: + return sympy.true # sympy.oo > int_oo + elif other is S.IntInfinity: + return sympy.true # consistency with sympy.oo + else: + return sympy.false + + @_sympifyit("other", NotImplemented) + def __mod__(self, other): + if not isinstance(other, Expr): + return NotImplemented + return S.NaN + + __rmod__ = __mod__ + + def floor(self): + return self + + def ceiling(self): + return self + + +int_oo = S.IntInfinity + + +class NegativeIntInfinity(Number, metaclass=Singleton): + """Negative integer infinite quantity. + + NegativeInfinity is a singleton, and can be accessed + by ``S.NegativeInfinity``. + + See Also + ======== + + IntInfinity + """ + + # Ensure we get dispatched to before plain numbers + _op_priority = 100.0 + + is_integer = True + is_extended_real = True + is_commutative = True + is_comparable = True + is_extended_negative = True + is_number = True + is_prime = False + + __slots__ = () + + def __new__(cls): + return AtomicExpr.__new__(cls) + + def _eval_subs(self, old, new): + if self == old: + return new + + def _sympystr(self, printer): + return "-int_oo" + + """ + def _eval_evalf(self, prec=None): + return Float('-inf') + + def evalf(self, prec=None, **options): + return self._eval_evalf(prec) + """ + + @_sympifyit("other", NotImplemented) + def __add__(self, other): + if isinstance(other, Number) and global_parameters.evaluate: + if other is S.Infinity: + return S.Infinity + if other in (S.IntInfinity, S.NaN): + return S.NaN + return self + return Number.__add__(self, other) + + __radd__ = __add__ + + @_sympifyit("other", NotImplemented) + def __sub__(self, other): + if isinstance(other, Number) and global_parameters.evaluate: + if other is S.NegativeInfinity: + return S.Infinity + if other in (S.NegativeIntInfinity, S.NaN): + return S.NaN + return self + return Number.__sub__(self, other) + + @_sympifyit("other", NotImplemented) + def __rsub__(self, other): + return (-self).__add__(other) + + @_sympifyit("other", NotImplemented) + def __mul__(self, other): + if isinstance(other, Number) and global_parameters.evaluate: + if other.is_zero or other is S.NaN: + return S.NaN + if other.is_extended_positive: + return self + return S.IntInfinity + return Number.__mul__(self, other) + + __rmul__ = __mul__ + + @_sympifyit("other", NotImplemented) + def __truediv__(self, other): + if isinstance(other, Number) and global_parameters.evaluate: + if other in ( + S.Infinity, + S.IntInfinity, + S.NegativeInfinity, + S.NegativeIntInfinity, + S.NaN, + ): + return S.NaN + if other.is_extended_nonnegative: + return self + return S.Infinity # truediv returns float + return Number.__truediv__(self, other) + + def __abs__(self): + return S.IntInfinity + + def __neg__(self): + return S.IntInfinity + + def _eval_power(self, expt): + if expt.is_number: + if expt in ( + S.NaN, + S.Infinity, + S.NegativeInfinity, + S.IntInfinity, + S.NegativeIntInfinity, + ): + return S.NaN + + if isinstance(expt, sympy.Integer) and expt.is_extended_positive: + if expt.is_odd: + return S.NegativeIntInfinity + else: + return S.IntInfinity + + inf_part = S.IntInfinity**expt + s_part = S.NegativeOne**expt + if inf_part == 0 and s_part.is_finite: + return inf_part + if ( + inf_part is S.ComplexInfinity + and s_part.is_finite + and not s_part.is_zero + ): + return S.ComplexInfinity + return s_part * inf_part + + def _as_mpf_val(self, prec): + return mlib.fninf + + def __hash__(self): + return super().__hash__() + + def __eq__(self, other): + return other is S.NegativeIntInfinity + + def __ne__(self, other): + return other is not S.NegativeIntInfinity + + def __gt__(self, other): + if other is S.NegativeInfinity: + return sympy.true # -sympy.oo < -int_oo + elif other is S.NegativeIntInfinity: + return sympy.false # consistency with sympy.oo + else: + return sympy.false + + def __ge__(self, other): + if other is S.NegativeInfinity: + return sympy.true # -sympy.oo < -int_oo + elif other is S.NegativeIntInfinity: + return sympy.true # consistency with sympy.oo + else: + return sympy.false + + def __lt__(self, other): + if other is S.NegativeInfinity: + return sympy.false # -sympy.oo < -int_oo + elif other is S.NegativeIntInfinity: + return sympy.false # consistency with sympy.oo + else: + return sympy.true + + def __le__(self, other): + if other is S.NegativeInfinity: + return sympy.false # -sympy.oo < -int_oo + elif other is S.NegativeIntInfinity: + return sympy.true # consistency with sympy.oo + else: + return sympy.true + + @_sympifyit("other", NotImplemented) + def __mod__(self, other): + if not isinstance(other, Expr): + return NotImplemented + return S.NaN + + __rmod__ = __mod__ + + def floor(self): + return self + + def ceiling(self): + return self + + def as_powers_dict(self): + return {S.NegativeOne: 1, S.IntInfinity: 1} diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py index 48f846c2fd72..087e741a72ec 100644 --- a/torch/utils/_sympy/value_ranges.py +++ b/torch/utils/_sympy/value_ranges.py @@ -6,7 +6,6 @@ import logging import math import operator -import sys from typing import ( Callable, Dict, @@ -24,6 +23,7 @@ from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom import torch +from torch._logging import LazyString from torch._prims_common import dtype_to_type from .functions import ( @@ -43,6 +43,7 @@ TruncToInt, ) from .interp import sympy_interp +from .numbers import int_oo, IntInfinity, NegativeIntInfinity log = logging.getLogger(__name__) @@ -168,7 +169,10 @@ def __init__(self, lower: AllIn, upper: AllIn) -> None: self, "is_int", not self.is_bool - and (isinstance(lower, sympy.Integer) or isinstance(upper, sympy.Integer)), + and ( + isinstance(lower, (sympy.Integer, NegativeIntInfinity)) + or isinstance(upper, (sympy.Integer, IntInfinity)) + ), ) """ # This assert is just impossible right now, too many sympy bugs @@ -265,11 +269,14 @@ def __or__(self: AllVR, other: AllVR) -> AllVR: def is_singleton(self) -> bool: return self.lower == self.upper - # TODO: this doesn't work with bools but arguably it should @staticmethod def unknown() -> ValueRanges[sympy.Expr]: return ValueRanges(-sympy.oo, sympy.oo) + @staticmethod + def unknown_int() -> ValueRanges[sympy.Expr]: + return ValueRanges(-int_oo, int_oo) + @staticmethod def unknown_bool() -> ValueRanges[SympyBoolean]: return ValueRanges(sympy.false, sympy.true) @@ -401,7 +408,7 @@ def constant(value, dtype): elif dtype.is_floating_point: return ValueRanges.unknown() else: - return ValueRanges(-sys.maxsize - 1, sys.maxsize) + return ValueRanges(-int_oo, int_oo) if is_python: type_ = dtype_to_type(dtype) @@ -424,6 +431,10 @@ def constant(value, dtype): def to_dtype(a, dtype, src_dtype=None): if dtype == torch.float64: return ValueRanges.increasing_map(a, ToFloat) + elif dtype == torch.bool: + return ValueRanges.unknown_bool() + elif not dtype.is_floating_point: + return ValueRanges.unknown_int() return ValueRanges.unknown() @staticmethod @@ -515,9 +526,7 @@ def safe_mul(a, b): def int_truediv(a, b): a = ValueRanges.wrap(a) b = ValueRanges.wrap(b) - if 0 in b or ( - (-sympy.oo in a or sympy.oo in a) and (-sympy.oo in b or sympy.oo in b) - ): + if 0 in b or ((-int_oo in a or int_oo in a) and (-int_oo in b or int_oo in b)): return ValueRanges.unknown() else: return ValueRanges.coordinatewise_monotone_map( @@ -541,14 +550,17 @@ def truediv(a, b): def floordiv(a, b): a = ValueRanges.wrap(a) b = ValueRanges.wrap(b) - if 0 in b or ( - # TODO: make this more precise - (-sympy.oo in a or sympy.oo in a) - or (-sympy.oo in b or sympy.oo in b) - ): + if 0 in b: return ValueRanges.unknown() - else: - return ValueRanges.coordinatewise_monotone_map(a, b, FloorDiv) + products = [] + for x, y in itertools.product([a.lower, a.upper], [b.lower, b.upper]): + r = FloorDiv(x, y) + if r is sympy.nan: + products.append((sympy.sign(x) * sympy.sign(y)) * int_oo) + else: + products.append(r) + + return ValueRanges(min(products), max(products)) @classmethod def mod(cls, x, y): @@ -564,10 +576,10 @@ def c_mod(a, b): def c_div(a, b): x = a / b - return sympy.Integer(x) if x.is_finite else x + return sympy.Integer(x) if x.is_finite and x not in (int_oo, -int_oo) else x if 0 in y: - return ValueRanges.unknown() + return ValueRanges.unknown_int() elif y.is_singleton(): y_val = abs(y.lower) # If it wraps, we need to take the whole interval @@ -597,7 +609,7 @@ def modular_indexing(cls, a, b, c): @classmethod def is_non_overlapping_and_dense_indicator(cls, *args): - return ValueRanges.unknown() # TODO: type here is wrong + return ValueRanges.unknown_int() @classmethod def pow_by_natural(cls, a, b): @@ -611,7 +623,7 @@ def pow_by_natural(cls, a, b): # to replacements, so don't assert it, but DO clamp it to prevent # degenerate problems return ValueRanges.coordinatewise_increasing_map( - a, b & ValueRanges(0, sys.maxsize - 1), PowByNatural + a, b & ValueRanges(0, int_oo), PowByNatural ) elif b.is_singleton(): if b.lower % 2 == 0: @@ -939,6 +951,8 @@ def cast(x, dtype): if dtype.is_floating_point: return sympy.Float(x) else: + if x in (int_oo, -int_oo): + return x try: return sympy.Integer(x) except TypeError: @@ -986,7 +1000,18 @@ def __getattr__(self, name): def bound_sympy( expr: sympy.Expr, ranges: Optional[Dict[sympy.Symbol, ValueRanges]] = None ) -> ValueRanges: - log.debug("bound_sympy(%s, %s)", expr, ranges) + log.debug( + "bound_sympy(%s)%s", + expr, + LazyString( + lambda: "\n" + + "\n".join( + f" {k}: {r}" for k, r in ranges.items() if k in expr.free_symbols + ) + if ranges + else "" + ), + ) if isinstance(expr, sympy.Number): return ValueRanges.wrap(expr) From 55646554b7fd2c6019b90c3b0cba7f8348b19f37 Mon Sep 17 00:00:00 2001 From: Nikita Shulga <2453524+malfet@users.noreply.github.com> Date: Mon, 10 Jun 2024 19:21:39 +0000 Subject: [PATCH 585/706] [EZ] Fix typos in SECURITY.md (#128340) permisisons -> permissions lates -> latest Pull Request resolved: https://github.com/pytorch/pytorch/pull/128340 Approved by: https://github.com/clee2000, https://github.com/atalman, https://github.com/kit1980 --- SECURITY.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/SECURITY.md b/SECURITY.md index a6f676ef39be..119a2b7615ac 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -40,7 +40,7 @@ Important Note: The trustworthiness of a model is not binary. You must always de ### Untrusted inputs during training and prediction -If you plan to open your model to untrusted inputs, be aware that inputs can also be used as vectors by malicious agents. To minimize risks, make sure to give your model only the permisisons strictly required, and keep your libraries updated with the lates security patches. +If you plan to open your model to untrusted inputs, be aware that inputs can also be used as vectors by malicious agents. To minimize risks, make sure to give your model only the permissions strictly required, and keep your libraries updated with the latest security patches. If applicable, prepare your model against bad inputs and prompt injections. Some recommendations: - Pre-analysis: check how the model performs by default when exposed to prompt injection (e.g. using fuzzing for prompt injection). From 946f554c8fc2b99877f8f593783c3f2ef27089b9 Mon Sep 17 00:00:00 2001 From: Aaron Orenstein Date: Sat, 8 Jun 2024 12:18:05 -0700 Subject: [PATCH 586/706] Flip default value for mypy disallow_untyped_defs [10+1/11] (#128293) See #127836 for details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128293 Approved by: https://github.com/oulgen --- torch/_export/passes/constant_folding.py | 1 + torch/_inductor/cpp_builder.py | 1 + torch/_inductor/fx_passes/micro_pipeline_tp.py | 1 + 3 files changed, 3 insertions(+) diff --git a/torch/_export/passes/constant_folding.py b/torch/_export/passes/constant_folding.py index 54b7a1565924..684fe07b0ec3 100644 --- a/torch/_export/passes/constant_folding.py +++ b/torch/_export/passes/constant_folding.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections from collections import defaultdict from typing import Any, Callable, Dict, Optional diff --git a/torch/_inductor/cpp_builder.py b/torch/_inductor/cpp_builder.py index 8c9d2c18cabd..413270edc314 100644 --- a/torch/_inductor/cpp_builder.py +++ b/torch/_inductor/cpp_builder.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # This CPP JIT builder is designed to support both Windows and Linux OS. # The design document please check this RFC: https://github.com/pytorch/pytorch/issues/124245 diff --git a/torch/_inductor/fx_passes/micro_pipeline_tp.py b/torch/_inductor/fx_passes/micro_pipeline_tp.py index 20a864377787..fdac76f75e43 100644 --- a/torch/_inductor/fx_passes/micro_pipeline_tp.py +++ b/torch/_inductor/fx_passes/micro_pipeline_tp.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import operator from dataclasses import dataclass from typing import cast, List, Set, Tuple, Union From 38e0a0440c2249974c88c7eb4b056298b553a296 Mon Sep 17 00:00:00 2001 From: Xiaodong Wang Date: Mon, 10 Jun 2024 19:55:21 +0000 Subject: [PATCH 587/706] [AMD] Default to hipblaslt in gemm (#127944) Summary: It has been a constant pain that we have to specify env var to go with the hipblaslt path. The default path is very slow on MI300. Therefore, let's default to hipblaslt. Differential Revision: D58150764 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127944 Approved by: https://github.com/aaronenyeshi, https://github.com/houseroad --- aten/src/ATen/Context.h | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index 4f6eb0d4a109..bb6b0611b743 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -385,8 +385,11 @@ class TORCH_API Context { ? at::LinalgBackend::Cusolver : at::LinalgBackend::Default; at::BlasBackend blas_preferred_backend = - (c10::utils::check_env("TORCH_BLAS_PREFER_CUBLASLT") == true || - c10::utils::check_env("TORCH_BLAS_PREFER_HIPBLASLT") == true) +#ifdef USE_ROCM + (c10::utils::check_env("TORCH_BLAS_PREFER_HIPBLASLT") != false) +#else + (c10::utils::check_env("TORCH_BLAS_PREFER_CUBLASLT") == true) +#endif ? at::BlasBackend::Cublaslt : at::BlasBackend::Cublas; #ifdef C10_MOBILE From 90bb510ece29ff505fe8fdf13220e622485a8e9b Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 10 Jun 2024 20:44:42 +0000 Subject: [PATCH 588/706] Revert "Deprecate `torch._utils.is_compiling()` and `torch._dynamo.external_utils.is_compiling()` (#127690)" This reverts commit 348b181a97abc2e636a6c18e5880a78e5d1dab94. Reverted https://github.com/pytorch/pytorch/pull/127690 on behalf of https://github.com/clee2000 due to sorry I think https://github.com/pytorch/pytorch/pull/126898#issuecomment-2142884456 is still relevant, I will reach out to them to see what needs to be done in internal to get this remerged ([comment](https://github.com/pytorch/pytorch/pull/127690#issuecomment-2159248859)) --- test/test_optim.py | 2 +- torch/_dynamo/decorators.py | 3 ++- torch/_dynamo/external_utils.py | 5 ----- torch/_functorch/apis.py | 6 +++--- torch/_functorch/eager_transforms.py | 4 ++-- torch/_higher_order_ops/associative_scan.py | 2 +- torch/_utils.py | 6 +----- .../algorithms/ddp_comm_hooks/default_hooks.py | 4 ++-- torch/distributed/tensor/parallel/_utils.py | 2 +- torch/nn/parallel/distributed.py | 4 ++-- torch/optim/adadelta.py | 6 +++--- torch/optim/adam.py | 6 +++--- torch/optim/adamax.py | 6 +++--- torch/optim/adamw.py | 6 +++--- torch/optim/asgd.py | 4 ++-- torch/optim/nadam.py | 4 ++-- torch/optim/optimizer.py | 11 ++++++----- torch/optim/radam.py | 4 ++-- torch/optim/rmsprop.py | 6 +++--- torch/optim/rprop.py | 6 +++--- torch/optim/sgd.py | 2 +- torch/testing/_internal/optests/generate_tests.py | 2 +- 22 files changed, 47 insertions(+), 54 deletions(-) diff --git a/test/test_optim.py b/test/test_optim.py index 3ab57fecd833..d61c33e2adce 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -287,7 +287,7 @@ def test_param_group_with_lrscheduler_goes_right_direction( inpt = torch.randn(5, device=device, dtype=dtype) # avoid endless recompiles by wrapping LR in a tensor if we're compiling - lr = torch.tensor(0.01) if torch.compiler.is_compiling() else 0.01 + lr = torch.tensor(0.01) if torch._utils.is_compiling() else 0.01 optimizer = optim_cls([{"params": [weight]}, {"params": [bias], "lr": lr}]) schedulers = [scheduler_c(optimizer) for scheduler_c in schedulers_c] diff --git a/torch/_dynamo/decorators.py b/torch/_dynamo/decorators.py index 557b9a72dde1..ec25d06281fc 100644 --- a/torch/_dynamo/decorators.py +++ b/torch/_dynamo/decorators.py @@ -9,6 +9,7 @@ from .comptime import comptime from .eval_frame import DisableContext, innermost_fn, RunOnlyContext from .exc import IncorrectUsage +from .external_utils import is_compiling if TYPE_CHECKING: from torch._C._dynamo.eval_frame import ( # noqa: F401 @@ -264,7 +265,7 @@ def mark_static(t, index=None): Unlike mark_dynamic, this can be done inside a graph, in which case it induces specialization on the tensor. """ - if torch.compiler.is_compiling(): + if is_compiling(): if index is None: for s in t.size(): comptime.force_static(s) diff --git a/torch/_dynamo/external_utils.py b/torch/_dynamo/external_utils.py index 7982e77f6feb..caea92bc6be0 100644 --- a/torch/_dynamo/external_utils.py +++ b/torch/_dynamo/external_utils.py @@ -3,7 +3,6 @@ import functools from typing import List -from typing_extensions import deprecated import torch import torch.utils._pytree as pytree @@ -14,10 +13,6 @@ np = None # type: ignore[assignment] -@deprecated( - "`torch._dynamo.external_utils.is_compiling` is deprecated. Use `torch.compiler.is_compiling` instead.", - category=FutureWarning, -) def is_compiling() -> bool: """ Indicates whether we are tracing/compiling with torch.compile() or torch.export(). diff --git a/torch/_functorch/apis.py b/torch/_functorch/apis.py index 8d4a77457867..1b755550a8bf 100644 --- a/torch/_functorch/apis.py +++ b/torch/_functorch/apis.py @@ -189,7 +189,7 @@ def vmap( vmap does not provide general autobatching or handle variable-length sequences out of the box. """ - from torch.compiler import is_compiling + from torch._dynamo import is_compiling _check_randomness_arg(randomness) if not (chunk_size is None or chunk_size > 0): @@ -391,7 +391,7 @@ def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Calla """ # To avoid cyclical dependency. import torch._functorch.eager_transforms as eager_transforms - from torch.compiler import is_compiling + from torch._dynamo import is_compiling def wrapper(*args, **kwargs): return eager_transforms.grad_impl(func, argnums, has_aux, args, kwargs) @@ -433,8 +433,8 @@ def grad_and_value( See :func:`grad` for examples """ + from torch._dynamo import is_compiling from torch._functorch import eager_transforms - from torch.compiler import is_compiling def wrapper(*args, **kwargs): return eager_transforms.grad_and_value_impl( diff --git a/torch/_functorch/eager_transforms.py b/torch/_functorch/eager_transforms.py index 80751c9694fd..fff6bd67838f 100644 --- a/torch/_functorch/eager_transforms.py +++ b/torch/_functorch/eager_transforms.py @@ -765,7 +765,7 @@ def compute_jacobian_preallocate_and_copy(): # Dynamo does not support HOP composition if their inner function is # annotated with @functools.wraps(...). We circumvent this issue by applying # wraps only if we're not tracing with dynamo. - if not torch.compiler.is_compiling(): + if not torch._dynamo.is_compiling(): wrapper_fn = wraps(func)(wrapper_fn) return wrapper_fn @@ -1346,7 +1346,7 @@ def push_jvp(basis): # Dynamo does not support HOP composition if their inner function is # annotated with @functools.wraps(...). We circumvent this issue by applying # wraps only if we're not tracing with dynamo. - if not torch.compiler.is_compiling(): + if not torch._dynamo.is_compiling(): wrapper_fn = wraps(func)(wrapper_fn) return wrapper_fn diff --git a/torch/_higher_order_ops/associative_scan.py b/torch/_higher_order_ops/associative_scan.py index 540d5c1a77f9..0d88aa0db2c6 100644 --- a/torch/_higher_order_ops/associative_scan.py +++ b/torch/_higher_order_ops/associative_scan.py @@ -77,7 +77,7 @@ def add(x: torch.Tensor, y: torch.Tensor): assert callable(combine_fn), "combine_fn must be a callable, but got {combine_fn}" assert isinstance(dim, int), "dim must be an int, but got {type(dim)}" - if not torch.compiler.is_compiling(): + if not torch._dynamo.is_compiling(): with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(): return torch.compile(associative_scan, fullgraph=True)( combine_fn, input, dim diff --git a/torch/_utils.py b/torch/_utils.py index b0dcb448092a..5096b62618df 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -8,7 +8,7 @@ import warnings from collections import defaultdict from typing import Any, Callable, DefaultDict, Generic, List, Optional -from typing_extensions import deprecated, ParamSpec +from typing_extensions import ParamSpec import torch @@ -853,10 +853,6 @@ def classproperty(func): return _ClassPropertyDescriptor(func) -@deprecated( - "`torch._utils.is_compiling` is deprecated. Use `torch.compiler.is_compiling` instead.", - category=FutureWarning, -) def is_compiling() -> bool: """ Indicates whether we are tracing/compiling with torch.compile() or torch.export(). diff --git a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py index 6ad4280e95ae..621e46fc1989 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py @@ -86,7 +86,7 @@ def decompress(fut): decompressed_tensor.copy_(value) return decompressed_tensor - if torch.compiler.is_compiling(): + if torch._utils.is_compiling(): grad = dist._functional_collectives.all_reduce( compressed_tensor, "sum", group_to_use ) @@ -135,7 +135,7 @@ def decompress(fut): decompressed_tensor.copy_(value) return decompressed_tensor - if torch.compiler.is_compiling(): + if torch._utils.is_compiling(): grad = dist._functional_collectives.all_reduce( compressed_tensor, "sum", group_to_use ) diff --git a/torch/distributed/tensor/parallel/_utils.py b/torch/distributed/tensor/parallel/_utils.py index 013a2a9d1723..394fde457bb2 100644 --- a/torch/distributed/tensor/parallel/_utils.py +++ b/torch/distributed/tensor/parallel/_utils.py @@ -6,7 +6,7 @@ from torch.distributed._tensor.placement_types import Placement from torch.distributed.device_mesh import _mesh_resources try: - from torch.compiler import is_compiling as is_torchdynamo_compiling + from torch._dynamo.external_utils import is_compiling as is_torchdynamo_compiling except Exception: def is_torchdynamo_compiling(): # type: ignore[misc] return False diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py index c71c838cfb85..34c593cd2c14 100644 --- a/torch/nn/parallel/distributed.py +++ b/torch/nn/parallel/distributed.py @@ -1482,7 +1482,7 @@ def _lazy_init(self): def _should_disable_cpp_reducer(self) -> bool: return self._use_python_reducer and ( - torch.compiler.is_compiling() or self._force_to_disable_cpp_reducer + torch._utils.is_compiling() or self._force_to_disable_cpp_reducer ) def _pre_forward(self, *inputs, **kwargs): @@ -1495,7 +1495,7 @@ def _pre_forward(self, *inputs, **kwargs): h.remove() self._accum_grad_hooks.clear() - if not self._lazy_init_ran and not torch.compiler.is_compiling(): + if not self._lazy_init_ran and not torch._utils.is_compiling(): self._lazy_init() if self._delay_all_reduce_all_params: diff --git a/torch/optim/adadelta.py b/torch/optim/adadelta.py index eff24f159213..d6f19fb069ae 100644 --- a/torch/optim/adadelta.py +++ b/torch/optim/adadelta.py @@ -255,7 +255,7 @@ def _single_tensor_adadelta( has_complex: bool, ): # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices( supports_xla=False ) @@ -311,7 +311,7 @@ def _multi_tensor_adadelta( assert not differentiable, "_foreach ops don't support autograd" # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices( supports_xla=False ) @@ -414,7 +414,7 @@ def adadelta( # this check is slow during compilation, so we skip it # if it's strictly needed we can add this check back in dynamo - if not torch.compiler.is_compiling() and not all( + if not torch._utils.is_compiling() and not all( isinstance(t, torch.Tensor) for t in state_steps ): raise RuntimeError( diff --git a/torch/optim/adam.py b/torch/optim/adam.py index bff29613175a..86785be4ed17 100644 --- a/torch/optim/adam.py +++ b/torch/optim/adam.py @@ -354,7 +354,7 @@ def _single_tensor_adam( step_t = state_steps[i] # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices() assert ( param.device.type == step_t.device.type @@ -467,7 +467,7 @@ def _multi_tensor_adam( ) # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices( supports_xla=False ) @@ -744,7 +744,7 @@ def adam( # this check is slow during compilation, so we skip it # if it's strictly needed we can add this check back in dynamo - if not torch.compiler.is_compiling() and not all( + if not torch._utils.is_compiling() and not all( isinstance(t, torch.Tensor) for t in state_steps ): raise RuntimeError( diff --git a/torch/optim/adamax.py b/torch/optim/adamax.py index c2e39b788014..27caa5f9d81c 100644 --- a/torch/optim/adamax.py +++ b/torch/optim/adamax.py @@ -244,7 +244,7 @@ def _single_tensor_adamax( step_t = state_steps[i] # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices() assert ( param.device.type == step_t.device.type @@ -316,7 +316,7 @@ def _multi_tensor_adamax( return # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices( supports_xla=False ) @@ -425,7 +425,7 @@ def adamax( See :class:`~torch.optim.Adamax` for details. """ - if not torch.compiler.is_compiling() and not all( + if not torch._utils.is_compiling() and not all( isinstance(t, torch.Tensor) for t in state_steps ): raise RuntimeError( diff --git a/torch/optim/adamw.py b/torch/optim/adamw.py index 2292c17f9e0f..00931bed0227 100644 --- a/torch/optim/adamw.py +++ b/torch/optim/adamw.py @@ -355,7 +355,7 @@ def _single_tensor_adamw( step_t = state_steps[i] # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices() assert ( param.device.type == step_t.device.type @@ -468,7 +468,7 @@ def _multi_tensor_adamw( ) # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices( supports_xla=False ) @@ -729,7 +729,7 @@ def adamw( See :class:`~torch.optim.AdamW` for details. """ - if not torch.compiler.is_compiling() and not all( + if not torch._utils.is_compiling() and not all( isinstance(t, torch.Tensor) for t in state_steps ): raise RuntimeError( diff --git a/torch/optim/asgd.py b/torch/optim/asgd.py index 454772670904..84c7602912d0 100644 --- a/torch/optim/asgd.py +++ b/torch/optim/asgd.py @@ -215,7 +215,7 @@ def _single_tensor_asgd( step_t = state_steps[i] # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices() assert ( param.device.type @@ -288,7 +288,7 @@ def _multi_tensor_asgd( assert not differentiable, "_foreach ops don't support autograd" # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices( supports_xla=False ) diff --git a/torch/optim/nadam.py b/torch/optim/nadam.py index a4a6d07b2ca8..cd2eeff92c05 100644 --- a/torch/optim/nadam.py +++ b/torch/optim/nadam.py @@ -305,7 +305,7 @@ def _single_tensor_nadam( exp_avg_sq = torch.view_as_real(exp_avg_sq) # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices() assert ( param.device.type == mu_product.device.type == step_t.device.type @@ -391,7 +391,7 @@ def _multi_tensor_nadam( assert not differentiable, "_foreach ops don't support autograd" # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices( supports_xla=False ) diff --git a/torch/optim/optimizer.py b/torch/optim/optimizer.py index 498669e65fbd..582dc2105a5a 100644 --- a/torch/optim/optimizer.py +++ b/torch/optim/optimizer.py @@ -25,6 +25,7 @@ import torch import torch.utils.hooks as hooks +from torch._utils import is_compiling from torch.utils._foreach_utils import ( _get_foreach_kernels_supported_devices, _get_fused_kernels_supported_devices, @@ -97,14 +98,14 @@ def _use_grad(self, *args, **kwargs): def _get_value(x): # item is significantly faster than a cpu tensor in eager mode - if not torch.jit.is_scripting() and torch.compiler.is_compiling(): + if not torch.jit.is_scripting() and is_compiling(): return x else: return x.item() if isinstance(x, torch.Tensor) else x def _stack_if_compiling(x): - if not torch.jit.is_scripting() and torch.compiler.is_compiling(): + if not torch.jit.is_scripting() and is_compiling(): return torch.stack(x) else: return x @@ -145,7 +146,7 @@ def wrapper(func): # the capturable flag. If capturable=True, this is not a problem. @functools.wraps(func) def maybe_fallback(*args, **kwargs): - if torch.compiler.is_compiling() and ( + if is_compiling() and ( not kwargs.get("capturable", False) and has_state_steps and (args[state_steps_ind] and args[state_steps_ind][0].is_cuda) @@ -418,7 +419,7 @@ def _cuda_graph_capture_health_check(self) -> None: # Thus, when compiling, inductor will determine if cudagraphs # can be enabled based on whether there is input mutation or CPU tensors. if ( - not torch.compiler.is_compiling() + not is_compiling() and torch.backends.cuda.is_built() and torch.cuda.is_available() ): @@ -505,7 +506,7 @@ def _group_tensors_by_device_and_dtype( """Groups a list of lists of tensors by device and dtype. Skips this step if we are compiling since this will occur during inductor lowering. """ - if torch.compiler.is_compiling(): + if is_compiling(): return {(None, None): (tensorlistlist, list(range(len(tensorlistlist[0]))))} else: return _group_tensors_by_device_and_dtype(tensorlistlist, with_indices) # type: ignore[return-value, arg-type] diff --git a/torch/optim/radam.py b/torch/optim/radam.py index d21973caeceb..1ecf20ffde86 100644 --- a/torch/optim/radam.py +++ b/torch/optim/radam.py @@ -272,7 +272,7 @@ def _single_tensor_radam( step_t = state_steps[i] # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices() assert ( param.device.type == step_t.device.type @@ -370,7 +370,7 @@ def _multi_tensor_radam( assert not differentiable, "_foreach ops don't support autograd" # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices( supports_xla=False ) diff --git a/torch/optim/rmsprop.py b/torch/optim/rmsprop.py index 30b56779fc75..5311aa2fd6b8 100644 --- a/torch/optim/rmsprop.py +++ b/torch/optim/rmsprop.py @@ -277,7 +277,7 @@ def _single_tensor_rmsprop( step = state_steps[i] # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices() assert ( param.device.type == step.device.type @@ -350,7 +350,7 @@ def _multi_tensor_rmsprop( assert not differentiable, "_foreach ops don't support autograd" # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices() assert all( p.device.type == step.device.type @@ -468,7 +468,7 @@ def rmsprop( """ # this check is slow during compilation, so we skip it # if it's strictly needed we can add this check back in dynamo - if not torch.compiler.is_compiling() and not all( + if not torch._utils.is_compiling() and not all( isinstance(t, torch.Tensor) for t in state_steps ): raise RuntimeError( diff --git a/torch/optim/rprop.py b/torch/optim/rprop.py index 69043b48673e..ae34865f1c15 100644 --- a/torch/optim/rprop.py +++ b/torch/optim/rprop.py @@ -237,7 +237,7 @@ def _single_tensor_rprop( step = state_steps[i] # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices() assert ( param.device.type == step.device.type @@ -303,7 +303,7 @@ def _multi_tensor_rprop( assert not differentiable, "_foreach ops don't support autograd" # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices() assert all( p.device.type == step.device.type @@ -415,7 +415,7 @@ def rprop( """ # this check is slow during compilation, so we skip it # if it's strictly needed we can add this check back in dynamo - if not torch.compiler.is_compiling() and not all( + if not torch._utils.is_compiling() and not all( isinstance(t, torch.Tensor) for t in state_steps ): raise RuntimeError( diff --git a/torch/optim/sgd.py b/torch/optim/sgd.py index e682d83701d5..8cf26cfcf95c 100644 --- a/torch/optim/sgd.py +++ b/torch/optim/sgd.py @@ -430,7 +430,7 @@ def _multi_tensor_sgd( if not device_has_sparse_grad: # handle internal item() call if lr is a tensor - if isinstance(lr, torch.Tensor) and torch.compiler.is_compiling(): + if isinstance(lr, torch.Tensor) and torch._utils.is_compiling(): grads_x_lr = torch._foreach_mul(device_grads, -lr) torch._foreach_add_(device_params, grads_x_lr) else: diff --git a/torch/testing/_internal/optests/generate_tests.py b/torch/testing/_internal/optests/generate_tests.py index d01f91563c92..70ee48274800 100644 --- a/torch/testing/_internal/optests/generate_tests.py +++ b/torch/testing/_internal/optests/generate_tests.py @@ -569,7 +569,7 @@ def __torch_function__(self, func, types, args=(), kwargs=None): if ( torch.jit.is_tracing() or torch.jit.is_scripting() - or torch.compiler.is_compiling() + or torch._dynamo.is_compiling() ): return func(*args, **kwargs) # Pre-existing code may not use the .default overload. If we see an From 4460e481bcac8df5f8af8bdd29d8d6a81e310c0d Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Mon, 10 Jun 2024 11:35:40 -0300 Subject: [PATCH 589/706] Disable jacrev/jacfwd/hessian if compiling with dynamo (#128255) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128255 Approved by: https://github.com/zou3519 --- test/dynamo/test_higher_order_ops.py | 69 +++++++++++++++++++ ...ion_no_setup_context_transform_hessian_cpu | 0 ...tion_no_setup_context_transform_jacfwd_cpu | 0 ...essianCPU.test_jacfwd_different_levels_cpu | 0 test/functorch/test_eager_transforms.py | 4 +- torch/_functorch/eager_transforms.py | 4 ++ torch/testing/_internal/common_utils.py | 1 + 7 files changed, 76 insertions(+), 2 deletions(-) delete mode 100644 test/dynamo_expected_failures/TestComposabilityCPU.test_autograd_function_no_setup_context_transform_hessian_cpu delete mode 100644 test/dynamo_expected_failures/TestComposabilityCPU.test_autograd_function_no_setup_context_transform_jacfwd_cpu delete mode 100644 test/dynamo_expected_failures/TestHessianCPU.test_jacfwd_different_levels_cpu diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index 9b86a90b02f3..30dff83e12dd 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -2746,6 +2746,26 @@ def _compile_check(self, fn, inputs, fullgraph=True, graph_idx=0): wrapped_gm = backend.graphs[graph_idx] return wrapped_gm + def test_hessian_graph_break(self): + counters.clear() + + def wrapper_fn(x): + return torch.func.hessian(torch.sin)(x) + + x = torch.randn(4, 3) + expected = wrapper_fn(x) + got = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x) + self.assertEqual(expected, got) + self.assertEqual(len(counters["graph_break"]), 2) + self.assertEqual( + { + "'skip function disable in file _dynamo/decorators.py'": 1, + "call torch._dynamo.disable() wrapped function .wrapper_fn at 0xN>": 1, + }, + {munge_exc(k): v for k, v in counters["graph_break"].items()}, + ) + + @unittest.expectedFailure def test_hessian(self): counters.clear() @@ -2880,6 +2900,7 @@ def forward(self, L_x_: "f32[4, 3]"): """, ) + @unittest.expectedFailure def test_hessian_argnums(self): counters.clear() @@ -3032,6 +3053,7 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): """ return (unflatten, child_3, child_2, _wrap_for_grad_1, child_4, o)""", ) + @unittest.expectedFailure def test_hessian_disable_capture(self): counters.clear() @@ -3058,6 +3080,26 @@ def wrapper_fn(x): ) self.assertEqual(actual, expected) + def test_jacrev_graph_break(self): + counters.clear() + + def wrapper_fn(x): + return torch.func.jacrev(torch.sin)(x) + + x = torch.randn(4, 3) + expected = wrapper_fn(x) + got = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x) + self.assertEqual(expected, got) + self.assertEqual(len(counters["graph_break"]), 2) + self.assertEqual( + { + "'skip function disable in file _dynamo/decorators.py'": 1, + "call torch._dynamo.disable() wrapped function .wrapper_fn at 0xN>": 1, + }, + {munge_exc(k): v for k, v in counters["graph_break"].items()}, + ) + + @unittest.expectedFailure def test_jacrev(self): counters.clear() @@ -3134,6 +3176,7 @@ def forward(self, L_x_: "f32[4, 3]"): """, ) + @unittest.expectedFailure def test_jacrev_two_tensors_argnums(self): counters.clear() @@ -3216,6 +3259,7 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): """, ) + @unittest.expectedFailure def test_jacrev_has_aux(self): counters.clear() @@ -3300,6 +3344,7 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): """, ) + @unittest.expectedFailure def test_jacrev_disable_capture(self): counters.clear() @@ -4246,6 +4291,26 @@ def wrapper_fn(x, y): self.assertEqual(len(counters["graph_break"]), 0) self.assertEqual(actual, expected) + def test_jacfwd_graph_break(self): + counters.clear() + + def wrapper_fn(x): + return torch.func.jacfwd(torch.sin)(x) + + x = torch.randn(4, 3) + expected = wrapper_fn(x) + got = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x) + self.assertEqual(expected, got) + self.assertEqual(len(counters["graph_break"]), 2) + self.assertEqual( + { + "'skip function disable in file _dynamo/decorators.py'": 1, + "call torch._dynamo.disable() wrapped function .wrapper_fn at 0xN>": 1, + }, + {munge_exc(k): v for k, v in counters["graph_break"].items()}, + ) + + @unittest.expectedFailure def test_jacfwd(self): counters.clear() @@ -4329,6 +4394,7 @@ def forward(self, L_x_: "f32[4, 3]"): """, ) + @unittest.expectedFailure def test_jacfwd_two_tensors_argnums(self): counters.clear() @@ -4418,6 +4484,7 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): """, ) + @unittest.expectedFailure def test_jacfwd_has_aux(self): counters.clear() @@ -4512,6 +4579,7 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): """, ) + @unittest.expectedFailure def test_jacfwd_randomness(self): counters.clear() @@ -4615,6 +4683,7 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): """, ) + @unittest.expectedFailure def test_jacfwd_disable_capture(self): counters.clear() diff --git a/test/dynamo_expected_failures/TestComposabilityCPU.test_autograd_function_no_setup_context_transform_hessian_cpu b/test/dynamo_expected_failures/TestComposabilityCPU.test_autograd_function_no_setup_context_transform_hessian_cpu deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestComposabilityCPU.test_autograd_function_no_setup_context_transform_jacfwd_cpu b/test/dynamo_expected_failures/TestComposabilityCPU.test_autograd_function_no_setup_context_transform_jacfwd_cpu deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestHessianCPU.test_jacfwd_different_levels_cpu b/test/dynamo_expected_failures/TestHessianCPU.test_jacfwd_different_levels_cpu deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py index c767810beb85..8107f865f7bc 100644 --- a/test/functorch/test_eager_transforms.py +++ b/test/functorch/test_eager_transforms.py @@ -77,7 +77,6 @@ subtest, TEST_WITH_TORCHDYNAMO, TestCase, - xfailIfTorchDynamo, ) from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten @@ -2342,7 +2341,8 @@ def f(x): self.assertEqual(actual, expected) # https://github.com/pytorch/pytorch/issues/127036 - @xfailIfTorchDynamo + # it won't fail as jacrev/jacfwd were not inlined (see #128255) + # @xfailIfTorchDynamo @parametrize("_preallocate_and_copy", (True, False)) def test_chunk_jacrev_chunksize_one(self, device, _preallocate_and_copy): # With chunk_size=1, we shouldn't `vmap` and hence not be limited diff --git a/torch/_functorch/eager_transforms.py b/torch/_functorch/eager_transforms.py index fff6bd67838f..fbea5164014b 100644 --- a/torch/_functorch/eager_transforms.py +++ b/torch/_functorch/eager_transforms.py @@ -767,6 +767,8 @@ def compute_jacobian_preallocate_and_copy(): # wraps only if we're not tracing with dynamo. if not torch._dynamo.is_compiling(): wrapper_fn = wraps(func)(wrapper_fn) + else: + wrapper_fn = torch._dynamo.disable(wrapper_fn) return wrapper_fn @@ -1348,6 +1350,8 @@ def push_jvp(basis): # wraps only if we're not tracing with dynamo. if not torch._dynamo.is_compiling(): wrapper_fn = wraps(func)(wrapper_fn) + else: + wrapper_fn = torch._dynamo.disable(wrapper_fn) return wrapper_fn diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index fbfb5cdfa02b..5d72a444fdda 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -5013,6 +5013,7 @@ def repl_frame(m): return m.group(0) s = re.sub(r' File "([^"]+)", line \d+, in (.+)\n .+\n( +[~^]+ *\n)?', repl_frame, s) + s = re.sub(r'( Date: Mon, 10 Jun 2024 09:54:57 -0700 Subject: [PATCH 590/706] [aota] compiled forward outputs requires_grad alignment with eager (#128016) Original issue: https://github.com/pytorch/pytorch/issues/114338 We assume only two possible mutually exclusive scenarios: 1. Running compiled region for training (Any of inputs has requires_grad) - Produced differentiable outputs should have requires_grad. 2. Running compiled region for inference (None of inputs has requires_grad) - All outputs do not have requires_grad. Even if user runs the region under no_grad(), but has an input Tensor with requires_grad - we go Training scenario (1). With current state that means: 1/ needs_autograd should not check torch.is_grad_enabled(), only that any of inputs requires_grad 2/ if needs_autograd => trace_joint (We are in training scenario 1.) => always run compiled region under with.enable_grad() Pull Request resolved: https://github.com/pytorch/pytorch/pull/128016 Approved by: https://github.com/bdhirsh --- test/functorch/test_aotdispatch.py | 28 +++++++++++++++++++ .../_aot_autograd/runtime_wrappers.py | 7 ++++- torch/_functorch/aot_autograd.py | 5 ++-- 3 files changed, 36 insertions(+), 4 deletions(-) diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 5046347c8d0c..7bce7d558abb 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -5331,6 +5331,34 @@ def f(a, b): self.assertEqual(a_ref_base.grad.a, a_test_base.grad.a) self.assertEqual(a_ref_base.grad.b, a_test_base.grad.b) + def test_aot_dispatch_output_requires_grad_in_no_grad(self): + def fn(x): + out1 = x.sin() + with torch.enable_grad(): + out2 = x.cos() + return out1, out2 + + inp_fns = [ + lambda: torch.ones(10, requires_grad=True), + lambda: torch.ones(10, requires_grad=False), + ] + + compiled_f = aot_function(fn, nop) + for inp_fn in inp_fns: + with torch.no_grad(): + ref_x = inp_fn() + ref_out = fn(ref_x) + x = inp_fn() + out = compiled_f(x) + for r, o in zip(ref_out, out): + self.assertEqual(r.requires_grad, o.requires_grad) + if ref_x.requires_grad: + with torch.enable_grad(): + (ref_out[0] + ref_out[1]).sum().backward() + (out[0] + out[1]).sum().backward() + self.assertEqual(ref_x.grad, x.grad) + assert torch.allclose(ref_x.grad, x.grad, atol=1e-3, rtol=1e-3) + class TestAOTModuleSimplified(AOTTestCase): def test_aot_module_simplified(self): diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py index 9dc606113d84..0afa24ce4ee8 100644 --- a/torch/_functorch/_aot_autograd/runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py @@ -200,7 +200,12 @@ def runtime_wrapper(args: List[Any]): for idx in indices_of_inps_to_detach: if isinstance(args_[idx], torch.Tensor): args_[idx] = args_[idx].detach() - with torch.autograd._force_original_view_tracking(True): + # It's possible to have trace_joint inside user specified with no_grad() region, + # if there is a nested with enable_grad(), that forces some outputs to require gradients. + # Therefore, we unconditionally turn on enable_grad() for compiled_fn execution. + with torch.autograd._force_original_view_tracking( + True + ), torch.enable_grad(): all_outs = call_func_at_runtime_with_args( compiled_fn, args_, disable_amp=disable_amp, steal_args=True ) diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index d6b084537567..c52a9cde0d55 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -544,9 +544,8 @@ def convert(idx, x): fake_flat_args = process_inputs(flat_args) - needs_autograd = ( - any(x.requires_grad for x in fake_flat_args if isinstance(x, Tensor)) - and torch.is_grad_enabled() + needs_autograd = any( + x.requires_grad for x in fake_flat_args if isinstance(x, Tensor) ) with enable_python_dispatcher(): From 3a2d0755a431a11edb3d7c63a9a0a93897ae1a00 Mon Sep 17 00:00:00 2001 From: laithsakka Date: Sun, 9 Jun 2024 07:17:30 -0700 Subject: [PATCH 591/706] enable test_ParameterList with dynamo if nn module inlining enabled only (#128308) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128308 Approved by: https://github.com/anijain2305 --- test/dynamo_expected_failures/TestNN.test_ParameterList | 0 test/test_nn.py | 1 + 2 files changed, 1 insertion(+) delete mode 100644 test/dynamo_expected_failures/TestNN.test_ParameterList diff --git a/test/dynamo_expected_failures/TestNN.test_ParameterList b/test/dynamo_expected_failures/TestNN.test_ParameterList deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/test_nn.py b/test/test_nn.py index 6dfac4f7ca1b..e1468a5d8328 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -1063,6 +1063,7 @@ def check(): self.assertRaises(NotImplementedError, module_dict) self.assertRaises(NotImplementedError, module_dict, torch.rand(1, 3)) + @skipIfTorchDynamo() def test_ParameterList(self): def make_param(): return Parameter(torch.randn(2, 2)) From 6630dcd53c6699b7189f917aa5a0935a8cca82fe Mon Sep 17 00:00:00 2001 From: Andrea Frittoli Date: Mon, 10 Jun 2024 21:33:54 +0000 Subject: [PATCH 592/706] Add docstring for the torch.serialization.default_restore_location function (#128132) Fixes: #127887 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128132 Approved by: https://github.com/mikaylagawarecki --- torch/serialization.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/torch/serialization.py b/torch/serialization.py index 1cab9b92c550..311aac28c8c5 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -391,6 +391,25 @@ def location_tag(storage: Union[Storage, torch.storage.TypedStorage, torch.Untyp def default_restore_location(storage, location): + """ + Restores `storage` using a deserializer function registered for the `location`. + + This function looks in the registry for deserializer functions that match the `location`. + If found, it attempts to use them, in priority order, to restore `storage` until one + returns a not `None` result. If no deserializer can be found in the registry, or all found fail + to bear a result, it raises a `RuntimeError`. + + Args: + storage (STORAGE): the storage object to restore + location (str): the location tag associated with the storage object + + Returns: + storage: Optional[STORAGE] + + Raises: + RuntimeError: If no deserializer matching `location` is found in the registry or if + all matching ones return `None`. + """ for _, _, fn in _package_registry: result = fn(storage, location) if result is not None: From 58083ffb106fe48384f3d4767afe8555e7bbf8da Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Mon, 10 Jun 2024 11:28:49 -0700 Subject: [PATCH 593/706] Improve unbacked reasoning involving has internal overlap (#128332) Fixes https://github.com/pytorch/pytorch/issues/122477 Partially addresses https://github.com/pytorch/pytorch/issues/116336 This PR is slightly overkill: not only does it disable the overlap test when there are unbacked SymInts, it also improves the is non-overlapping and dense test for some more unbacked situations. We technically don't need the latter change, but I was already deep in the sauce and just went ahead and did it. Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/128332 Approved by: https://github.com/lezcano --- test/test_dynamic_shapes.py | 70 ++++++++++++++++++++++++++++++++- torch/_meta_registrations.py | 10 ++++- torch/utils/_sympy/functions.py | 54 +++++++++++++++++++------ 3 files changed, 119 insertions(+), 15 deletions(-) diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index b064d1896f0e..156b23742900 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -41,7 +41,11 @@ ) from torch.utils import _pytree as pytree from torch.utils._python_dispatch import TorchDispatchMode -from torch.utils._sympy.functions import FloorDiv, Mod +from torch.utils._sympy.functions import ( + FloorDiv, + IsNonOverlappingAndDenseIndicator, + Mod, +) aten = torch.ops.aten @@ -777,6 +781,70 @@ def test_non_overlapping_and_dense(self): r = torch.empty_strided((a0, 7), (1, a0), device="meta") self.assertTrue(torch.ops.aten.is_non_overlapping_and_dense.default(r)) + def test_non_overlapping_and_dense_unbacked(self): + shape_env = ShapeEnv() + u0 = shape_env.create_unbacked_symint() + torch._check_is_size(u0) + cf = torch.ops.aten.is_non_overlapping_and_dense.default + + self.assertEqual(IsNonOverlappingAndDenseIndicator(u0.node.expr, 2, 2, 1), 1) + self.assertEqual(IsNonOverlappingAndDenseIndicator(2, u0.node.expr, 1, 2), 1) + self.assertTrue(cf(torch.empty_strided((u0, 2), (2, 1), device="meta"))) + self.assertTrue(cf(torch.empty_strided((2, u0), (1, 2), device="meta"))) + + self.assertEqual(IsNonOverlappingAndDenseIndicator(u0.node.expr, 1), 1) + self.assertEqual(IsNonOverlappingAndDenseIndicator(1, u0.node.expr), 1) + self.assertTrue(cf(torch.empty_strided((u0,), (1,), device="meta"))) + self.assertTrue(cf(torch.empty_strided((1,), (u0,), device="meta"))) + + Max = torch.sym_max + # NB: This only works because we're able to determine this tensor is + # contiguous. transpose(0, 1) makes it stop working + self.assertTrue( + cf( + torch.empty_strided( + ( + 2, + 3, + 1, + u0, + ), + (3 * Max(1, u0), Max(1, u0), Max(1, u0), 1), + device="meta", + ) + ) + ) + + def test_debug_has_internal_overlap_unbacked(self): + shape_env = ShapeEnv() + u0 = shape_env.create_unbacked_symint() + torch._check_is_size(u0) + cf = torch._debug_has_internal_overlap + self.assertEqual(cf(torch.empty_strided((u0, 2), (2, 1), device="meta")), 0) + self.assertEqual(cf(torch.empty_strided((2, u0), (1, 2), device="meta")), 0) + self.assertEqual(cf(torch.empty_strided((u0,), (1,), device="meta")), 0) + self.assertEqual(cf(torch.empty_strided((1,), (u0,), device="meta")), 0) + Max = torch.sym_max + self.assertEqual( + cf( + torch.empty_strided( + ( + 2, + 3, + 1, + u0, + ), + (3 * Max(1, u0), Max(1, u0), Max(1, u0), 1), + device="meta", + ) + ), + 0, + ) + + # Wobbling these to zero is OK too + self.assertEqual(cf(torch.empty_strided((u0, 2), (3, 1), device="meta")), 2) + self.assertEqual(cf(torch.empty_strided((2, u0), (1, 3), device="meta")), 2) + def test_specialize_zero_one(self): shape_env = ShapeEnv(specialize_zero_one=True) a0 = create_symint(shape_env, 5) diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 3afe3a98d102..89262a7a203c 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -368,8 +368,14 @@ def meta_copy_(self, src, non_blocking=False): # which runs most of the meta checks that we care about. # In theory, we should make this more robust by carefully # auditing our C++ copy_() kernel and copying the checks here. - - if torch._debug_has_internal_overlap(self) == 1: # 1 == MemOverlap::Yes + from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols + + # TODO: Ideally, we'd insert a deferred runtime assert here, but if we are + # calling an actual copy_, you'll get that automatically + # https://github.com/pytorch/pytorch/issues/122477 + if ( + not free_unbacked_symbols(self) and torch._debug_has_internal_overlap(self) == 1 + ): # 1 == MemOverlap::Yes raise RuntimeError( "more than one element of the written-to tensor refers to a single memory location" ) diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index 0d7c5a784c63..fd9921848d60 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -1,6 +1,7 @@ # mypy: allow-untyped-defs import functools import math +import operator import sys import sympy @@ -581,22 +582,51 @@ class IsNonOverlappingAndDenseIndicator(sympy.Function): def eval(cls, *args): assert len(args) % 2 == 0 dim = len(args) // 2 - # TODO: it is possible to make progress evaluating this guard - # even if not all of the inputs are known. For example, a 2D - # tensor with non-0/1 sizes but strides (0, 1) is definitely - # false, because we know its numel > 1 but it's broadcasted - # in dim 0. + sizes = args[0:dim] + strides = args[dim:] + + # sym_node imported in torch.__init__. Local import to avoid an import cycle + from torch.fx.experimental.symbolic_shapes import ( + eval_is_non_overlapping_and_dense, + ) + if all(isinstance(a, sympy.Integer) for a in args): - # sym_node imported in torch.__init__. Local import to avoid an import cycle - from torch.fx.experimental.symbolic_shapes import ( - eval_is_non_overlapping_and_dense, + return eval_is_non_overlapping_and_dense( + [int(a) for a in sizes], [int(a) for a in strides] ) - size_args = args[0:dim] - stride_args = args[dim:] - return eval_is_non_overlapping_and_dense( - [int(a) for a in size_args], [int(a) for a in stride_args] + if dim == 1: + # Manually implement the rank one short circuit + if strides[0].is_Number and strides[0] == 1: + return 1 + + if sizes[0].is_Number and sizes[0] < 2: + return 1 + + # return 0 case covered by case above + + # TODO: Inability to access size-obliviousness sucks: if we have a + # size oblivious test on a size-like unbacked SymInt, we could + # confidently return zero when we have a size-like u0 stride + # and a size-like u1 size. Maybe a fancy ValueRanges analysis for + # this function could help figure this out. + + if all(isinstance(a, sympy.Integer) for a in strides): + assert dim != 0 + # When all strides are integral, we can sort, and the size for the + # largest stride doesn't matter and can be arbitrarily symbolic + s_sizes, s_strides = zip( + *sorted(zip(sizes, strides), key=operator.itemgetter(1)) ) + # Put something arbitrary in the max size spot, it'll be ignored + if all(isinstance(a, sympy.Integer) for a in s_sizes[:-1]): + s_sizes = s_sizes[:-1] + (42,) + # We can reuse the regular eval, because it is invariant to + # permutation of dimensions + return eval_is_non_overlapping_and_dense( + [int(a) for a in s_sizes], [int(a) for a in s_strides] + ) + return None From a2d4fea87261cfb68d39e52eb67d17ed7323f19f Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Mon, 10 Jun 2024 11:12:32 -0700 Subject: [PATCH 594/706] [easy] Move state_dict hooks tests to test_module_hooks and decorate tests that call load_state_dict with swap (#126906) Pull Request resolved: https://github.com/pytorch/pytorch/pull/126906 Approved by: https://github.com/albanD --- test/nn/test_load_state_dict.py | 29 ----------- test/nn/test_module_hooks.py | 89 +++++++++++++++++++++++++++++++++ test/test_nn.py | 52 ------------------- 3 files changed, 89 insertions(+), 81 deletions(-) diff --git a/test/nn/test_load_state_dict.py b/test/nn/test_load_state_dict.py index cd9540382cc1..1bb9f7e82572 100644 --- a/test/nn/test_load_state_dict.py +++ b/test/nn/test_load_state_dict.py @@ -3,14 +3,12 @@ import unittest from copy import deepcopy from itertools import product -from tempfile import NamedTemporaryFile import torch import torch.nn as nn from torch.testing._internal.common_nn import NNTestCase from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, - IS_WINDOWS, parametrize, run_tests, skipIfCrossRef, @@ -206,33 +204,6 @@ def hook_fn( model[0][0]._register_load_state_dict_pre_hook(hook_fn, with_module=True) model.load_state_dict(model.state_dict(), strict=True) - @unittest.skipIf(IS_WINDOWS, "Tempfile permission issue on windows") - @swap([True, False]) - def test_register_state_dict_pre_hook_backward_compat(self): - called = False - - def my_state_dict_pre_hook(*args, **kwargs): - nonlocal called - called = True - - m = nn.Linear(1, 1) - self.assertTrue(hasattr(m, "_state_dict_pre_hooks")) - delattr(m, "_state_dict_pre_hooks") - # Save and load, ensure we can still call state_dict - # without running into issues. - with NamedTemporaryFile() as f: - # Note that torch.save / torch.load is not recommended - # to save / load modules. - torch.save(m, f.name) - m = torch.load(f.name) - - # Ensure we can run state_dict without issues - _ = m.state_dict() - self.assertFalse(called) - m.register_state_dict_pre_hook(my_state_dict_pre_hook) - _ = m.state_dict() - self.assertTrue(called) - # fails swapping as LSTM installs weak references on the parameters @swap([False]) @skipIfTorchDynamo("TorchDynamo fails here for unknown reasons") diff --git a/test/nn/test_module_hooks.py b/test/nn/test_module_hooks.py index f76837660302..dc4bead78242 100644 --- a/test/nn/test_module_hooks.py +++ b/test/nn/test_module_hooks.py @@ -21,6 +21,7 @@ parametrize as parametrize_test, run_tests, skipIfTorchDynamo, + swap, TestCase, ) @@ -549,6 +550,7 @@ def _hook_to_pickle(*args, **kwargs): class TestStateDictHooks(TestCase): + @swap([True, False]) def test_load_state_dict_pre_hook(self): m = nn.Linear(10, 10) m_state_dict = m.state_dict() @@ -613,6 +615,7 @@ def test_pickled_hook(self): m._register_load_state_dict_pre_hook(_hook_to_pickle, True) pickle.loads(pickle.dumps(m)) + @swap([True, False]) def test_load_state_dict_module_pre_hook(self): hook_called = 0 @@ -686,6 +689,7 @@ def __init__(self, mod): m.load_state_dict(state_dict) self.assertEqual(2, hook_called) + @swap([True, False]) def test_load_state_dict_post_hook(self): hook_called = 0 @@ -743,6 +747,7 @@ def load_hook_clear_incompatible(module, incompatible_keys): self.assertEqual([], ret.unexpected_keys) @unittest.skipIf(IS_WINDOWS, "Tempfile permission issue on windows") + @swap([True, False]) def test_load_state_dict_post_hook_backward_compatibility(self): def my_post_load_hook(mod, _): nonlocal called @@ -771,6 +776,89 @@ def my_post_load_hook(mod, _): m.load_state_dict(sd) self.assertTrue(called) + def _test_register_state_dict_pre_hook(self, model, submodule): + _state_dict_prefix = "foo." + state_dict_pre_hook_count = 0 + keep_var_setting = False + + def my_state_dict_pre_hook(module, prefix, keep_vars): + self.assertEqual(keep_vars, keep_var_setting) + nonlocal state_dict_pre_hook_count + state_dict_pre_hook_count += 1 + self.assertTrue(prefix.startswith(_state_dict_prefix)) + + model.register_state_dict_pre_hook(my_state_dict_pre_hook) + # Test to ensure submodules run the hook as well. + submodule.register_state_dict_pre_hook(my_state_dict_pre_hook) + + def check_results(model): + nonlocal state_dict_pre_hook_count, keep_var_setting + for keep_var_setting in [True, False]: + _ = model.state_dict( + prefix=_state_dict_prefix, keep_vars=keep_var_setting + ) + self.assertEqual(2, state_dict_pre_hook_count) + state_dict_pre_hook_count = 0 + + # Test state dict works as expected after model construction + check_results(model) + # Test state dict works as expected after forward + model(torch.ones(10, 3)) + check_results(model) + + def test_register_state_dict_pre_hook(self): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.a = nn.Sequential( + nn.Linear(3, 3), nn.Linear(3, 3), nn.Linear(3, 3) + ) + + def forward(self, x): + return self.a(x) + + mod = MyModule() + self._test_register_state_dict_pre_hook(mod, mod.a) + + def test_register_state_dict_pre_hook_lazy_module(self): + class MyLazyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.layer1 = nn.LazyLinear(8) + self.layer2 = nn.LazyLinear(5) + + def forward(self, x): + return self.layer2(self.layer1(x)) + + mod = MyLazyModule() + self._test_register_state_dict_pre_hook(mod, mod.layer1) + + @unittest.skipIf(IS_WINDOWS, "Tempfile permission issue on windows") + def test_register_state_dict_pre_hook_backward_compat(self): + called = False + + def my_state_dict_pre_hook(*args, **kwargs): + nonlocal called + called = True + + m = nn.Linear(1, 1) + self.assertTrue(hasattr(m, "_state_dict_pre_hooks")) + delattr(m, "_state_dict_pre_hooks") + # Save and load, ensure we can still call state_dict + # without running into issues. + with NamedTemporaryFile() as f: + # Note that torch.save / torch.load is not recommended + # to save / load modules. + torch.save(m, f.name) + m = torch.load(f.name) + + # Ensure we can run state_dict without issues + _ = m.state_dict() + self.assertFalse(called) + m.register_state_dict_pre_hook(my_state_dict_pre_hook) + _ = m.state_dict() + self.assertTrue(called) + class TestModuleGlobalHooks(TestCase): def tearDown(self): @@ -1553,6 +1641,7 @@ def parameter_registration_hook(module, name, parameter): instantiate_parametrized_tests(TestModuleHooks) +instantiate_parametrized_tests(TestStateDictHooks) if __name__ == "__main__": run_tests() diff --git a/test/test_nn.py b/test/test_nn.py index e1468a5d8328..2553db01ee6b 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -2282,58 +2282,6 @@ def test_state_dict(self): # Reference https://github.com/pytorch/pytorch/pull/75507#issuecomment-1110291545 self.assertNotWarn(lambda: l.state_dict(destination=dict()), "Should not warn kwarg destination w/o _metadata") - def _test_register_state_dict_pre_hook(self, model, submodule): - _state_dict_prefix = "foo." - state_dict_pre_hook_count = 0 - keep_var_setting = False - - def my_state_dict_pre_hook(module, prefix, keep_vars): - self.assertEqual(keep_vars, keep_var_setting) - nonlocal state_dict_pre_hook_count - state_dict_pre_hook_count += 1 - self.assertTrue(prefix.startswith(_state_dict_prefix)) - - model.register_state_dict_pre_hook(my_state_dict_pre_hook) - # Test to ensure submodules run the hook as well. - submodule.register_state_dict_pre_hook(my_state_dict_pre_hook) - - def check_results(model): - nonlocal state_dict_pre_hook_count, keep_var_setting - for keep_var_setting in [True, False]: - _ = model.state_dict(prefix=_state_dict_prefix, keep_vars=keep_var_setting) - self.assertEqual(2, state_dict_pre_hook_count) - state_dict_pre_hook_count = 0 - # Test state dict works as expected after model construction - check_results(model) - # Test state dict works as expected after forward - model(torch.ones(10, 3)) - check_results(model) - - def test_register_state_dict_pre_hook(self): - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.a = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3), nn.Linear(3, 3)) - - def forward(self, x): - return self.a(x) - - mod = MyModule() - self._test_register_state_dict_pre_hook(mod, mod.a) - - def test_register_state_dict_pre_hook_lazy_module(self): - class MyLazyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.layer1 = nn.LazyLinear(8) - self.layer2 = nn.LazyLinear(5) - - def forward(self, x): - return self.layer2(self.layer1(x)) - - mod = MyLazyModule() - self._test_register_state_dict_pre_hook(mod, mod.layer1) - def test_extra_state(self): class SubModule(torch.nn.Module): From c38b3381a12a0ec033dd417827c530c4474b8165 Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Mon, 10 Jun 2024 11:12:33 -0700 Subject: [PATCH 595/706] Make nn.Module state_dict load_state_dict pre-hook and state_dict post hook public (#126704) Fixes https://github.com/pytorch/pytorch/issues/75287 and https://github.com/pytorch/pytorch/issues/117437 - `nn.Module._register_state_dict_hook` --> add public `nn.Module.register_state_dict_post_hook` - Add a test as this API was previously untested - `nn.Module._register_load_state_dict_pre_hook` --> add public `nn.Module.register_load_state_dict_pre_hook` (remove the `with_module` flag, default it to `True` ~- For consistency with optimizer `load_state_dict_pre_hook` raised by @janeyx99, allow the pre-hook to return a new `state_dict`~ - Document issue pointed out by https://github.com/pytorch/pytorch/issues/117437 regarding `_register_state_dict_hook` semantic of returning a new state_dict only being respected for the root for private hook - Remove this for the public `register_state_dict_post_hook` Pull Request resolved: https://github.com/pytorch/pytorch/pull/126704 Approved by: https://github.com/albanD ghstack dependencies: #126906 --- test/nn/test_load_state_dict.py | 2 +- test/nn/test_module_hooks.py | 42 +++++++++-- .../fx/_lower_to_native_backend.py | 2 +- .../_checkpoint/checkpoint_wrapper.py | 4 +- torch/nn/modules/module.py | 71 +++++++++++++------ 5 files changed, 89 insertions(+), 32 deletions(-) diff --git a/test/nn/test_load_state_dict.py b/test/nn/test_load_state_dict.py index 1bb9f7e82572..3ad7e9c3a639 100644 --- a/test/nn/test_load_state_dict.py +++ b/test/nn/test_load_state_dict.py @@ -201,7 +201,7 @@ def hook_fn( module_state_dict = module.state_dict() self.assertEqual(len(module_state_dict.keys()), len(state_dict.keys())) - model[0][0]._register_load_state_dict_pre_hook(hook_fn, with_module=True) + model[0][0].register_load_state_dict_pre_hook(hook_fn) model.load_state_dict(model.state_dict(), strict=True) # fails swapping as LSTM installs weak references on the parameters diff --git a/test/nn/test_module_hooks.py b/test/nn/test_module_hooks.py index dc4bead78242..b2fddcdf0cbd 100644 --- a/test/nn/test_module_hooks.py +++ b/test/nn/test_module_hooks.py @@ -588,21 +588,28 @@ def hook_with_module( hook_called += 1 hook_called = 0 + # Test private API since this sets with_module=False which diverges from public API m_load._register_load_state_dict_pre_hook(hook_without_module) m_load.load_state_dict(m_state_dict) self.assertEqual(1, hook_called) hook_called = 0 - m_load._register_load_state_dict_pre_hook(hook_with_module, True) + m_load.register_load_state_dict_pre_hook(hook_with_module) m_load.load_state_dict(m_state_dict) self.assertEqual(2, hook_called) + # Test private API with with_module=True + hook_called = 0 + m_load._register_load_state_dict_pre_hook(hook_with_module, True) + m_load.load_state_dict(m_state_dict) + self.assertEqual(3, hook_called) + def test_no_extra_ref_to_module(self): try: gc.disable() m = nn.Linear(10, 10) - m._register_load_state_dict_pre_hook(_hook_to_pickle, True) + m.register_load_state_dict_pre_hook(_hook_to_pickle) weak_m = weakref.ref(m) del m @@ -612,7 +619,7 @@ def test_no_extra_ref_to_module(self): def test_pickled_hook(self): m = nn.Linear(10, 10) - m._register_load_state_dict_pre_hook(_hook_to_pickle, True) + m.register_load_state_dict_pre_hook(_hook_to_pickle) pickle.loads(pickle.dumps(m)) @swap([True, False]) @@ -678,14 +685,13 @@ def __init__(self, mod): mod = m hook_called = 0 + # Test private API since this sets with_module=False which diverges from public API mod._register_load_state_dict_pre_hook(mod.my_pre_load_hook) m.load_state_dict(state_dict) self.assertEqual(1, hook_called) hook_called = 0 - mod._register_load_state_dict_pre_hook( - mod.my_pre_load_hook_with_module, True - ) + mod.register_load_state_dict_pre_hook(mod.my_pre_load_hook_with_module) m.load_state_dict(state_dict) self.assertEqual(2, hook_called) @@ -859,6 +865,30 @@ def my_state_dict_pre_hook(*args, **kwargs): _ = m.state_dict() self.assertTrue(called) + def test_register_state_dict_post_hook(self): + def state_dict_post_hook(module, state_dict, prefix, local_metadata): + for name, param in module.named_parameters(recurse=False): + state_dict[prefix + name] = torch.nn.Parameter( + state_dict[prefix + name] + ) + + def register_linear_hook(module): + if isinstance(module, nn.Linear): + module.register_state_dict_post_hook(state_dict_post_hook) + + m = nn.Transformer( + d_model=4, nhead=2, num_encoder_layers=2, num_decoder_layers=2 + ) + m.apply(register_linear_hook) + + sd = m.state_dict() + + for k, v in m.state_dict().items(): + if "linear" in k or "out_proj" in k: + self.assertTrue(isinstance(v, torch.nn.Parameter)) + else: + self.assertFalse(isinstance(v, torch.nn.Parameter)) + class TestModuleGlobalHooks(TestCase): def tearDown(self): diff --git a/torch/ao/quantization/fx/_lower_to_native_backend.py b/torch/ao/quantization/fx/_lower_to_native_backend.py index 92620a169383..f36904c3f587 100644 --- a/torch/ao/quantization/fx/_lower_to_native_backend.py +++ b/torch/ao/quantization/fx/_lower_to_native_backend.py @@ -442,7 +442,7 @@ def load_arg(a): quantized_model = GraphModule(quantized_model, folded_graph) quantized_model._register_state_dict_hook(_save_packed_weight) - quantized_model._register_load_state_dict_pre_hook(_load_packed_weight, with_module=True) + quantized_model.register_load_state_dict_pre_hook(_load_packed_weight) return quantized_model def _get_module(node: Node, modules: Dict[str, nn.Module]) -> Optional[nn.Module]: diff --git a/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py b/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py index 86ab1de003db..a39082e7ea49 100644 --- a/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py +++ b/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py @@ -34,8 +34,8 @@ def __init__(self, mod): self._register_state_dict_hook(self._post_state_dict_hook) # load_state_dict pre-hook to allow loading back into # checkpoint-wrapped module. - self._register_load_state_dict_pre_hook( - self._pre_load_state_dict_hook, with_module=True + self.register_load_state_dict_pre_hook( + self._pre_load_state_dict_hook ) def forward(self, *args, **kwargs): diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index f803d3f02a17..942d5b8a8f95 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -418,7 +418,9 @@ def forward(self, x): # As JIT does not support Set[int], this dict is used as a set, where all # hooks represented in this dict accept kwargs. _forward_pre_hooks_with_kwargs: Dict[int, bool] - _state_dict_hooks: Dict[int, Callable] + # The bool indicates whether the hook comes from the private method + # or the public method. + _state_dict_hooks: Dict[int, Tuple[Callable, bool]] _load_state_dict_pre_hooks: Dict[int, Callable] _state_dict_pre_hooks: Dict[int, Callable] _load_state_dict_post_hooks: Dict[int, Callable] @@ -1799,24 +1801,40 @@ def __delattr__(self, name): super().__delattr__(name) def _register_state_dict_hook(self, hook): - r"""Register a state-dict hook. + r"""Register a post-hook for the :meth:`~torch.nn.Module.state_dict` method. - These hooks will be called with arguments: `self`, `state_dict`, - `prefix`, `local_metadata`, after the `state_dict` of `self` is set. - Note that only parameters and buffers of `self` or its children are - guaranteed to exist in `state_dict`. The hooks may modify `state_dict` - inplace or return a new one. + It should have the following signature:: + hook(module, state_dict, prefix, local_metadata) -> None or state_dict + + The registered hooks can modify the ``state_dict`` inplace or return a new one. + If a new ``state_dict`` is returned, it will only be respected if it is the root + module that :meth:`~nn.Module.state_dict` is called from. + """ + handle = hooks.RemovableHandle(self._state_dict_hooks) + # True indicates that the hook was registered via the private method + self._state_dict_hooks[handle.id] = (hook, True) + return handle + + def register_state_dict_post_hook(self, hook): + r"""Register a post-hook for the :meth:`~torch.nn.Module.state_dict` method. + + It should have the following signature:: + hook(module, state_dict, prefix, local_metadata) -> None + + The registered hooks can modify the ``state_dict`` inplace. """ handle = hooks.RemovableHandle(self._state_dict_hooks) - self._state_dict_hooks[handle.id] = hook + # False indicates that the hook was registered via the public method + self._state_dict_hooks[handle.id] = (hook, False) return handle def register_state_dict_pre_hook(self, hook): r"""Register a pre-hook for the :meth:`~torch.nn.Module.state_dict` method. - These hooks will be called with arguments: ``self``, ``prefix``, - and ``keep_vars`` before calling ``state_dict`` on ``self``. The registered - hooks can be used to perform pre-processing before the ``state_dict`` + It should have the following signature:: + hook(module, prefix, keep_vars) -> None + + The registered hooks can be used to perform pre-processing before the ``state_dict`` call is made. """ handle = hooks.RemovableHandle(self._state_dict_pre_hooks) @@ -1937,22 +1955,19 @@ def state_dict(self, *args, destination=None, prefix='', keep_vars=False): for name, module in self._modules.items(): if module is not None: module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars) - for hook in self._state_dict_hooks.values(): + for (hook, from_private) in self._state_dict_hooks.values(): hook_result = hook(self, destination, prefix, local_metadata) - if hook_result is not None: + if from_private and hook_result is not None: destination = hook_result return destination def _register_load_state_dict_pre_hook(self, hook, with_module=False): - r"""Register a pre-hook for the :meth:`~torch.nn.Module.load_state_dict` method. - - These hooks will be called with arguments: `state_dict`, `prefix`, - `local_metadata`, `strict`, `missing_keys`, `unexpected_keys`, - `error_msgs`, before loading `state_dict` into `self`. These arguments - are exactly the same as those of `_load_from_state_dict`. + r"""See :meth:`~torch.nn.Module.register_load_state_dict_pre_hook` for details. - If ``with_module`` is ``True``, then the first argument to the hook is - an instance of the module. + A subtle difference is that if ``with_module`` is set to ``False``, then the + hook will not take the ``module`` as the first argument whereas + :meth:`~torch.nn.Module.register_load_state_dict_pre_hook` always takes the + ``module`` as the first argument. Arguments: hook (Callable): Callable hook that will be invoked before @@ -1964,8 +1979,20 @@ def _register_load_state_dict_pre_hook(self, hook, with_module=False): self._load_state_dict_pre_hooks[handle.id] = _WrappedHook(hook, self if with_module else None) return handle + def register_load_state_dict_pre_hook(self, hook): + r"""Register a pre-hook to be run before module's :meth:`~nn.Module.load_state_dict` is called. + + It should have the following signature:: + hook(module, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -> None # noqa: B950 + + Arguments: + hook (Callable): Callable hook that will be invoked before + loading the state dict. + """ + return self._register_load_state_dict_pre_hook(hook, with_module=True) + def register_load_state_dict_post_hook(self, hook): - r"""Register a post hook to be run after module's ``load_state_dict`` is called. + r"""Register a post-hook to be run after module's :meth:`~nn.Module.load_state_dict` is called. It should have the following signature:: hook(module, incompatible_keys) -> None From 583a56d5a8ed35fa69841526905700114547c927 Mon Sep 17 00:00:00 2001 From: loganthomas Date: Mon, 10 Jun 2024 22:17:31 +0000 Subject: [PATCH 596/706] DOC: add docstring to construct_and_record_rdzv_event() (#128189) Fixes #127902 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128189 Approved by: https://github.com/kurman --- docs/source/elastic/events.rst | 2 ++ torch/distributed/elastic/events/__init__.py | 34 ++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/docs/source/elastic/events.rst b/docs/source/elastic/events.rst index 86d0be8dad52..c32136d00302 100644 --- a/docs/source/elastic/events.rst +++ b/docs/source/elastic/events.rst @@ -10,6 +10,8 @@ API Methods .. autofunction:: torch.distributed.elastic.events.record +.. autofunction:: torch.distributed.elastic.events.construct_and_record_rdzv_event + .. autofunction:: torch.distributed.elastic.events.get_logging_handler Event Objects diff --git a/torch/distributed/elastic/events/__init__.py b/torch/distributed/elastic/events/__init__.py index db6cb639ef1c..9f6e1733518a 100644 --- a/torch/distributed/elastic/events/__init__.py +++ b/torch/distributed/elastic/events/__init__.py @@ -86,6 +86,40 @@ def construct_and_record_rdzv_event( local_id: Optional[int] = None, rank: Optional[int] = None, ) -> None: + """ + Initialize rendezvous event object and record its operations. + + Args: + run_id (str): The run id of the rendezvous. + message (str): The message describing the event. + node_state (NodeState): The state of the node (INIT, RUNNING, SUCCEEDED, FAILED). + name (str): Event name. (E.g. Current action being performed). + hostname (str): Hostname of the node. + pid (Optional[int]): The process id of the node. + master_endpoint (str): The master endpoint for the rendezvous store, if known. + local_id (Optional[int]): The local_id of the node, if defined in dynamic_rendezvous.py + rank (Optional[int]): The rank of the node, if known. + Returns: + None + Example: + >>> # See DynamicRendezvousHandler class + >>> def _record( + ... self, + ... message: str, + ... node_state: NodeState = NodeState.RUNNING, + ... rank: Optional[int] = None, + ... ) -> None: + ... construct_and_record_rdzv_event( + ... name=f"{self.__class__.__name__}.{get_method_name()}", + ... run_id=self._settings.run_id, + ... message=message, + ... node_state=node_state, + ... hostname=self._this_node.addr, + ... pid=self._this_node.pid, + ... local_id=self._this_node.local_id, + ... rank=rank, + ... ) + """ # We don't want to perform an extra computation if not needed. if isinstance(get_logging_handler("dynamic_rendezvous"), logging.NullHandler): return From 2176ef7dfaf02dd6dbb8484a50c99d5fadf3ea0b Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Fri, 7 Jun 2024 16:26:47 -0700 Subject: [PATCH 597/706] [compiled autograd] support .backward(inputs=) (#128252) autograd already marks nodes as needed or not before calling calling compiled autograd. so our worklist already skips nodes not specified in the `inputs` kwarg. For the .backward(inputs=) case, I'm keeping the grads as outputs, just like for .grad(inputs=), this is to still guard on graph_output when we collect the nodes. This does not get DCE'd rn, and is ignored in the post graph bytecode. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128252 Approved by: https://github.com/jansel --- test/inductor/test_compiled_autograd.py | 150 +++++++++++++++--- test/test_autograd.py | 17 ++ .../csrc/dynamo/python_compiled_autograd.cpp | 3 - 3 files changed, 142 insertions(+), 28 deletions(-) diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index 776496f9331f..e09928cf5576 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -581,7 +581,7 @@ def model(x): self.check_output_and_recompiles(fn) - def test_output_nodes(self): + def test_output_nodes_all_leaves(self): def fn(): y = torch.randn(1, 4, requires_grad=True) z = torch.randn(1, 4, requires_grad=True) @@ -593,7 +593,7 @@ def model(x): x = torch.randn([1, 4]) result = model(x).sum() - gy, gz = torch.autograd.grad(result, [y, z]) + gy, gz = torch.autograd.grad(result, inputs=[y, z]) assert y.grad is None assert z.grad is None yield gy @@ -601,6 +601,111 @@ def model(x): self.check_output_and_recompiles(fn) + def test_output_nodes_some_leaves(self): + def fn(): + class UnreachableBwd(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + return x + + @staticmethod + def backward(ctx, gO): + raise RuntimeError + + y = torch.randn(1, 4, requires_grad=True) + z = torch.randn(1, 4, requires_grad=True) + + def model(x): + return torch.sigmoid(UnreachableBwd.apply(y) * z) + + for _ in range(3): + x = torch.randn([1, 4]) + + result = model(x).sum() + gz = torch.autograd.grad(result, inputs=[z]) + assert y.grad is None + assert z.grad is None + yield gz + + self.check_output_and_recompiles(fn) + + def test_no_output_nodes_all_leaves(self): + def fn(): + y = torch.randn(1, 4, requires_grad=True) + z = torch.randn(1, 4, requires_grad=True) + + def model(x): + return torch.sigmoid(x * z + torch.sin(y) + torch.cos(y)) + + for _ in range(3): + x = torch.randn([1, 4]) + result = model(x).sum() + out = result.backward() + assert out is None + assert y.grad is not None + assert z.grad is not None + yield y.grad + yield z.grad + y.grad = None + z.grad = None + + self.check_output_and_recompiles(fn) + + def test_no_output_nodes_some_leaves(self): + def fn(): + class UnreachableBwd(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + return x + + @staticmethod + def backward(ctx, gO): + raise RuntimeError + + y = torch.randn(1, 4, requires_grad=True) + z = torch.randn(1, 4, requires_grad=True) + a = torch.randn(1, 4, requires_grad=True) + + def model(x): + return torch.sigmoid(x * y * z * UnreachableBwd.apply(a)) + + for _ in range(3): + x = torch.randn([1, 4]) + result = model(x).sum() + out = result.backward(inputs=[y, z]) + assert out is None + assert y.grad is not None + assert z.grad is not None + assert a.grad is None + yield y.grad + yield z.grad + y.grad = None + z.grad = None + + self.check_output_and_recompiles(fn) + + def test_no_output_nodes_different_leaves_will_recompile(self): + def fn(): + def fwd(x, y, z): + out = x * y # MulBackward0 + out2 = out * z # MulBackward0 + return out2.sum() # SumBackward0 + + x = torch.randn(5, requires_grad=True) + y = torch.randn(5, requires_grad=True) + z = torch.randn(5, requires_grad=True) + loss = fwd(x, y, z) + torch.compile(lambda: torch.autograd.backward(loss, inputs=[x]))() + yield x.grad + x.grad = None + + loss = fwd(x, y, z) + torch.compile(lambda: torch.autograd.backward(loss, inputs=[y]))() + yield y.grad + + # Guarded by TensorArg id, mismatch on last MulBackward0 + self.check_output_and_recompiles(fn, 2) + def test_dynamic_shapes(self): def fn(): model = torch.nn.Sequential( @@ -1986,7 +2091,18 @@ def wrap_test_class(orig_cls): return cls -known_graph_breaks_tests = {} +known_graph_breaks_tests = { + "test_hook_none", # uses assert in hook + "test_post_accumulate_grad_hook_e2e", # optim.Adam manually graph breaks + "test_tensor_hooks_inplace", # uses assert in hook + "test_tensor_hooks_inplace_over_view", # uses assert in hook + "test_grad_fn_prehooks", # uses assert in hook + "test_grad_fn_prehooks_multiple_outputs", # uses assert in hook + "test_grad_fn_prehooks_remove_hooks", # uses handle.remove() in hook + "test_tensor_hooks_inplace_multiple_outputs", # uses assert in hook + "test_hooks", # uses assert in hook + "test_accumulate_grad_posthooks_can_observe_tensor_prehook", # allclose +} # These groups of tests aren't supported yet known_failures_re = re.compile( @@ -2004,23 +2120,14 @@ def wrap_test_class(orig_cls): "test_saved_variable_saved_original_inplace_detach", # AssertionError: RuntimeError not raised "test_saving_variable_to_disk", # Cannot call numel() on tensor with symbolic sizes/strides "test_setitem_mask", # torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: It appears that you're - "test_tensor_hooks_inplace_over_view", # torch._dynamo.exc.Unsupported: call_function UserDefinedClassVariable() [] {} - "test_tensor_hooks_inplace", # torch._dynamo.exc.Unsupported: call_function UserDefinedClassVariable() [] {} "test_wrapped_number_saved_variable_hooks", # RuntimeError: this hook should not be called - "test_accumulate_grad_posthooks_can_observe_tensor_prehook", # data dependent operator: aten.allclose.default "test_accumulate_grad_tensor_reference", # backend='inner_compiler' raised: "test_anomaly_grad_warnings", # "one of the variables needed for gradient computation has been modified by an... "test_autograd_inplace_views_cross_dtype", # view_fn not supported by compiled autograd - "test_backward_with_inputs", # specifying inputs= with .backward() not yet implemented for compiled autograd "test_current_node", # TorchDispatchMode not yet implemented for compiled autograd "test_custom_function_exception", # "Simulate error on backward pass" does not match "type object 'SimulateBackwa... "test_grad_batched_grad", # Cannot access storage of BatchedTensorImpl - "test_grad_unreachable_discovery", # specifying inputs= with .backward() not yet implemented for compiled autograd "test_index_backward_does_not_save_tensor", # dynamic shape operator: aten.nonzero.default - "test_post_accumulate_grad_hook_e2e", # tensor_post_acc_grad_hooks not implemented for compiled autograd - "test_post_accumulate_grad_hook_gets_cleaned_up", # tensor_post_acc_grad_hooks not implemented for compiled autograd - "test_post_accumulate_grad_hook_multiple_hooks", # tensor_post_acc_grad_hooks not implemented for compiled autograd - "test_post_accumulate_grad_hook_multiple_tensors", # tensor_post_acc_grad_hooks not implemented for compiled autograd "test_post_accumulate_grad_hook_ordering", # tensor_post_acc_grad_hooks not implemented for compiled autograd "test_post_accumulate_grad_hook_returns_not_None", # "hooks should return None." does not match "test_reentrant_child_error", # "Simulate error" does not match "type object 'ReentrantFunc' has no attribute... @@ -2052,21 +2159,20 @@ def wrap_test_class(orig_cls): "test_hessian_vector", # RuntimeError: compiled_autograd does not support create_graph "test_hook_closure_cycle_use_custom_function_True_use_tensor_hook_False", # AttributeError: type object "test_hook_closure_cycle_use_custom_function_True_use_tensor_hook_True", # AttributeError: type object - "test_hook_edge_case_when_called_with_grad", # RuntimeError: specifying inputs= with .backward() not yet - "test_hooks", # torch._dynamo.exc.Unsupported: inline in skipfiles + "test_hook_edge_case_when_called_with_grad", # retains_grad_hooks NYI "test_inplace_on_view_backward", # RuntimeError: compiled_autograd does not support create_graph - "test_multi_grad_any_hooks", # RuntimeError: specifying inputs= with .backward() not yet implemented for compiled autograd - "test_multi_grad_all_hooks", # RuntimeError: specifying inputs= with .backward() not yet implemented for compiled autograd + "test_multi_grad_any_hooks", # register_multi_grad_hook NYI + "test_multi_grad_all_hooks", # retains_grad_hooks NYI "test_nested_anomaly_detect_nan", # RuntimeError: compiled_autograd does not support create_graph "test_nested_anomaly_printstack_cleanup", # RuntimeError: compiled_autograd does not support create_graph "test_once_differentiable", # RuntimeError: compiled_autograd does not support create_graph - "test_prehook_ordering", # RuntimeError: specifying inputs= with .backward() not yet implemented for compiled autograd + "test_prehook_ordering", # retains_grad_hooks NYI "test_retain_grad", # RuntimeError: retains_grad_hooks not implemented for compiled autograd "test_saved_variable_packing_unpacking_saved_original_with_hooks", # RuntimeError: compiled_autograd "test_select_sum", # torch.autograd.gradcheck.GradcheckError: While computing batched gradients "test_unrelated_inputs", # torch.autograd.gradcheck.GradcheckError: While computing batched gradients - "test_will_engine_execute_node", # RuntimeError: specifying inputs= with .backward() not yet implemented for compiled autograd - "test_backward_to_node", # RuntimeError: specifying inputs= with .backward() not yet implemented for compiled autograd + "test_will_engine_execute_node", # retains_grad_hooks NYI + "test_backward_to_node", # retains_grad_hooks NYI "test_anomaly_detect_nan", # torch._dynamo.exc.TorchRuntimeError: Failed running call_function aten.add.Tensor( "test_autograd_multiple_views_python", # torch._dynamo.exc.Unsupported: call_function args: TensorVariable( "test_autograd_node_isinstance", # torch._dynamo.exc.Unsupported: 'inline in skipfiles: TestCase.assertIsInstance @@ -2083,11 +2189,7 @@ def wrap_test_class(orig_cls): "test_deep_reentrant", # torch._dynamo.exc.InternalTorchDynamoError: '<' not supported between instances of "test_dont_materialize_grads", # torch._dynamo.exc.Unsupported: 'inline in skipfiles: TestCase.assertIsNone "test_function_returns_undefined_tensor", # torch._dynamo.exc.TorchRuntimeError: Failed running call_function - "test_grad_fn_prehooks", # torch._dynamo.exc.Unsupported: call_function UserDefinedClassVariable() [] {} - "test_grad_fn_prehooks_multiple_outputs", # torch._dynamo.exc.Unsupported: 'inline in skipfiles: - "test_grad_fn_prehooks_remove_hooks", # torch._dynamo.exc.Unsupported: 'inline in skipfiles: RemovableHandle.remove "test_grad_mode_restored_reentrant", # torch._dynamo.exc.Unsupported: 'inline in skipfiles: TestCase.assertTrue - "test_hook_none", # torch._dynamo.exc.Unsupported: 'inline in skipfiles: TestCase.assertIsNotNone "test_invalid_gradients", # AssertionError: "expected shape" does not match "The size of tensor a (5) must match "test_mark_non_differentiable_mixed", # torch._dynamo.exc.Unsupported: 'inline in skipfiles: TestCase.assertTrue "test_materialize_grads", # torch._dynamo.exc.Unsupported: call_function UserDefinedClassVariable() [] {} @@ -2107,7 +2209,6 @@ def wrap_test_class(orig_cls): "test_set_materialize_non_diff_grads", # torch._dynamo.exc.Unsupported: 'inline in skipfiles: TestCase.assertIsNone "test_setup_context_when_forward_has_default_args", # torch._dynamo.exc.Unsupported: call_function args "test_simple_reentrant", # torch._dynamo.exc.Unsupported: call_method SkipFunctionVariable() sum [] {} - "test_tensor_hooks_inplace_multiple_outputs", # torch._dynamo.exc.Unsupported: call_function UserDefinedClassVariable() [] {} "test_lobpcg", # torch._dynamo.exc.Unsupported: 'call_function LOBPCGAutogradFunction.backward in skip_files "test_backward_dict_grad_for_nontensor", # AssertionError: "non-Tensor-like types" does not match "'skip function "test_backward_dict_invalid_keys", # AssertionError: "to have keys {'x'}" does not match "'skip function @@ -2120,7 +2221,6 @@ def wrap_test_class(orig_cls): "test_backward_tensorlist_input_requires_list_grads_none_or_Tensor", # AssertionError: "None or Tensor" "test_backward_tensorlist_input_requires_list_grads_with_same_numel", # AssertionError: "3 gradients "test_save_for_backward_inputs_are_namedtuple", # torch._dynamo.exc.Unsupported: 'skip function - "test_autograd_function_backed_op", # RuntimeError: compiled_args not implemented "test_setitem", # AssertionError: Tensor-likes are not close! "test_grad_nonleaf_register_hook", # IndexError: list index out of range (NB: x.grad = y where both x and y are input tensors) "test_scalar_grad_mixed_device", # Fake Tensors aren't propagating device properly for 0-dim grads diff --git a/test/test_autograd.py b/test/test_autograd.py index c032319fa160..ce5b4234b829 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -1342,6 +1342,23 @@ def prehook(gI): b.backward() + def test_accumulate_grad_posthooks_should_not_execute(self): + def tensor_prehook(g): + raise RuntimeError + + def posthook(gO, gI): + raise RuntimeError + + a = torch.tensor(1.0, requires_grad=True) + a.register_hook(tensor_prehook) + b = torch.tensor(1.0, requires_grad=True) + c = a.clone() + acc = c.grad_fn.next_functions[0][0] + acc.register_hook(posthook) + + out = a + b + c + out.sum().backward(inputs=[b]) + def test_hook_edge_case_when_called_with_grad(self): # grad executes the tensor hooks of the next node but not # grad_fn pre hooks or the post hooks diff --git a/torch/csrc/dynamo/python_compiled_autograd.cpp b/torch/csrc/dynamo/python_compiled_autograd.cpp index 6cdce255d7df..2e5cb3bfab02 100644 --- a/torch/csrc/dynamo/python_compiled_autograd.cpp +++ b/torch/csrc/dynamo/python_compiled_autograd.cpp @@ -630,9 +630,6 @@ variable_list compiled_autograd( GraphTask& graph_task, bool accumulate_grad, const edge_list& output_edges) { - TORCH_CHECK( - output_edges.empty() || !accumulate_grad, - "specifying inputs= with .backward() not yet implemented for compiled autograd") TORCH_CHECK( c10::impl::TorchDispatchModeTLS::stack_len() == 0, "TorchDispatchMode not yet implemented for compiled autograd") From 4bbadeee8af837e95fc7742f36639a5710c38247 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 10 Jun 2024 22:46:01 +0000 Subject: [PATCH 598/706] Revert "Set simdlen based on ATEN_CPU_CAPABILITY (#123514)" This reverts commit b66e3f0957b96b058c9b632ca60833d9717a9d8a. Reverted https://github.com/pytorch/pytorch/pull/123514 on behalf of https://github.com/clee2000 due to broke test/inductor/test_torchinductor.py::CpuTests::test_new_cpp_build_logical_cpu on periodic test on the no gpu tests https://hud.pytorch.org/pytorch/pytorch/commit/b66e3f0957b96b058c9b632ca60833d9717a9d8a https://github.com/pytorch/pytorch/actions/runs/9453518547/job/26040077301 ([comment](https://github.com/pytorch/pytorch/pull/123514#issuecomment-2159433432)) --- test/inductor/test_cpu_repro.py | 137 ++---------------------- test/inductor/test_extension_backend.py | 4 - test/inductor/test_torchinductor.py | 42 +------- torch/_dynamo/testing.py | 6 -- torch/_inductor/codecache.py | 34 +----- torch/_inductor/codegen/cpp_prefix.h | 1 - 6 files changed, 19 insertions(+), 205 deletions(-) diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py index 1f04b71a961b..b2ab30832e06 100644 --- a/test/inductor/test_cpu_repro.py +++ b/test/inductor/test_cpu_repro.py @@ -4,7 +4,6 @@ import functools import itertools import math -import os import platform import sys import unittest @@ -67,13 +66,12 @@ check_model = test_torchinductor.check_model requires_vectorization = unittest.skipUnless( - codecache.valid_vec_isa_list() and os.getenv("ATEN_CPU_CAPABILITY") != "default", - "Does not support vectorization", + codecache.valid_vec_isa_list(), "Does not support vectorization" ) def check_metrics_vec_kernel_count(num_expected_vec_kernels): - if codecache.valid_vec_isa_list() and os.getenv("ATEN_CPU_CAPABILITY") != "default": + if codecache.valid_vec_isa_list(): assert metrics.generated_cpp_vec_kernel_count == num_expected_vec_kernels @@ -1582,71 +1580,6 @@ def fn(x): metrics.reset() self.common(fn, (value,)) - @unittest.skipIf( - not codecache.valid_vec_isa_list() - or "avx2" in [str(vec_isa) for vec_isa in codecache.valid_vec_isa_list()], - "Does not support vectorization or not s390x/neon machine", - ) - @patch("torch.cuda.is_available", lambda: False) - def test_auto_zvec_neon_simd(self): - vec_zvec_neon = codecache.valid_vec_isa_list()[0] - self.assertTrue(vec_zvec_neon.bit_width() == 256) - - with config.patch({"cpp.simdlen": 0}): - isa = codecache.pick_vec_isa() - self.assertFalse(isa) - - with config.patch({"cpp.simdlen": 1}): - isa = codecache.pick_vec_isa() - self.assertFalse(isa) - - with config.patch({"cpp.simdlen": 257}): - isa = codecache.pick_vec_isa() - self.assertFalse(isa) - - with config.patch({"cpp.simdlen": 256}): - isa = codecache.pick_vec_isa() - self.assertTrue(isa == vec_zvec_neon) - - pre_var = os.getenv("ATEN_CPU_CAPABILITY") - if pre_var: - os.environ.pop("ATEN_CPU_CAPABILITY") - - try: - with config.patch({"cpp.simdlen": None}): - isa = codecache.pick_vec_isa() - self.assertTrue(isa == vec_zvec_neon) - - with config.patch({"cpp.simdlen": None}): - os.environ["ATEN_CPU_CAPABILITY"] = "avx2" - isa = codecache.pick_vec_isa() - self.assertTrue(isa == vec_zvec_neon) - - with config.patch({"cpp.simdlen": None}): - os.environ["ATEN_CPU_CAPABILITY"] = "avx512" - isa = codecache.pick_vec_isa() - self.assertTrue(isa == vec_zvec_neon) - - with config.patch({"cpp.simdlen": None}): - os.environ["ATEN_CPU_CAPABILITY"] = "default" - isa = codecache.pick_vec_isa() - self.assertFalse(isa) - - with config.patch({"cpp.simdlen": None}): - os.environ["ATEN_CPU_CAPABILITY"] = "neon" - isa = codecache.pick_vec_isa() - self.assertTrue(isa == vec_zvec_neon) - - with config.patch({"cpp.simdlen": None}): - os.environ["ATEN_CPU_CAPABILITY"] = "zvector" - isa = codecache.pick_vec_isa() - self.assertTrue(isa == vec_zvec_neon) - finally: - if pre_var: - os.environ["ATEN_CPU_CAPABILITY"] = pre_var - elif os.getenv("ATEN_CPU_CAPABILITY"): - os.environ.pop("ATEN_CPU_CAPABILITY") - @unittest.skipIf( platform.machine() != "x86_64" or not codecache.valid_vec_isa_list(), "Does not support vectorization or not x86_64 machine", @@ -1662,6 +1595,13 @@ def test_auto_simd(self): self.assertTrue(vec_avx512.nelements(torch.bfloat16) == 32) self.assertTrue(vec_avx2.nelements(torch.bfloat16) == 16) + with config.patch({"cpp.simdlen": None}): + isa = codecache.pick_vec_isa() + if vec_avx512 in codecache.valid_vec_isa_list(): + self.assertTrue(isa == vec_avx512) + else: + self.assertTrue(isa == vec_avx2) + with config.patch({"cpp.simdlen": 0}): isa = codecache.pick_vec_isa() self.assertFalse(isa) @@ -1691,60 +1631,6 @@ def test_auto_simd(self): isa = codecache.pick_vec_isa() self.assertTrue(isa == vec_avx2) - pre_var = os.getenv("ATEN_CPU_CAPABILITY") - if pre_var: - os.environ.pop("ATEN_CPU_CAPABILITY") - - try: - with config.patch({"cpp.simdlen": None}): - isa = codecache.pick_vec_isa() - if vec_avx512 in codecache.valid_vec_isa_list(): - self.assertTrue(isa == vec_avx512) - else: - self.assertTrue(isa == vec_avx2) - - with config.patch({"cpp.simdlen": None}): - os.environ["ATEN_CPU_CAPABILITY"] = "avx2" - isa = codecache.pick_vec_isa() - if vec_avx512 in codecache.valid_vec_isa_list(): - self.assertTrue(isa == vec_avx2) - elif vec_avx2 in codecache.valid_vec_isa_list(): - self.assertTrue(isa == vec_avx2) - - with config.patch({"cpp.simdlen": None}): - os.environ["ATEN_CPU_CAPABILITY"] = "avx512" - isa = codecache.pick_vec_isa() - if vec_avx512 in codecache.valid_vec_isa_list(): - self.assertTrue(isa == vec_avx512) - else: - self.assertTrue(isa == vec_avx2) - - with config.patch({"cpp.simdlen": None}): - os.environ["ATEN_CPU_CAPABILITY"] = "default" - isa = codecache.pick_vec_isa() - self.assertFalse(isa) - - with config.patch({"cpp.simdlen": None}): - os.environ["ATEN_CPU_CAPABILITY"] = "neon" - isa = codecache.pick_vec_isa() - if vec_avx512 in codecache.valid_vec_isa_list(): - self.assertTrue(isa == vec_avx512) - else: - self.assertTrue(isa == vec_avx2) - - with config.patch({"cpp.simdlen": None}): - os.environ["ATEN_CPU_CAPABILITY"] = "zvector" - isa = codecache.pick_vec_isa() - if vec_avx512 in codecache.valid_vec_isa_list(): - self.assertTrue(isa == vec_avx512) - else: - self.assertTrue(isa == vec_avx2) - finally: - if pre_var: - os.environ["ATEN_CPU_CAPABILITY"] = pre_var - elif os.getenv("ATEN_CPU_CAPABILITY"): - os.environ.pop("ATEN_CPU_CAPABILITY") - @requires_vectorization @patch("torch.cuda.is_available", lambda: False) def test_masked_fill_softmax(self): @@ -3485,7 +3371,6 @@ def forward(self, idx, x): self.common(m, (idx, x)) check_metrics_vec_kernel_count(1) - @requires_vectorization def test_embedding_vec_bf16(self): class M(torch.nn.Module): def __init__(self): @@ -3770,7 +3655,7 @@ def fn(x): x = torch.randint(0, 100, (819,), dtype=torch.int64) metrics.reset() self.common(fn, (x,)) - check_metrics_vec_kernel_count(1) + assert metrics.generated_cpp_vec_kernel_count == 1 def test_reduction_float_to_int64(self): # https://github.com/pytorch/pytorch/issues/124821 @@ -3780,7 +3665,7 @@ def fn(x): x = torch.randint(0, 100, (22, 51), dtype=torch.int64) metrics.reset() self.common(fn, (x,)) - check_metrics_vec_kernel_count(1) + assert metrics.generated_cpp_vec_kernel_count == 1 @config.patch({"cpp.dynamic_threads": True}) def test_reduction_with_dynamic_threads(self): diff --git a/test/inductor/test_extension_backend.py b/test/inductor/test_extension_backend.py index a3bad9582d8c..3cb473255e74 100644 --- a/test/inductor/test_extension_backend.py +++ b/test/inductor/test_extension_backend.py @@ -8,7 +8,6 @@ import torch._dynamo import torch.utils.cpp_extension from torch._C import FileCheck -from torch._dynamo.testing import expectedFailureScalar try: from extension_backends.cpp.extension_codegen_backend import ( @@ -104,9 +103,6 @@ def tearDown(self): # return the working directory (see setUp) os.chdir(self.old_working_dir) - # Fails when testing the scalar version - # See https://github.com/pytorch/pytorch/issues/126372. - @expectedFailureScalar def test_open_device_registration(self): torch.utils.rename_privateuse1_backend("extension_device") torch._register_device_module("extension_device", self.module) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index fe23b07e12da..53167a83ecd8 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -34,7 +34,6 @@ from torch._dynamo.testing import ( CompileCounterWithBackend, expectedFailureCodegenDynamic, - expectedFailureScalar, rand_strided, same, skipIfPy312, @@ -1316,9 +1315,6 @@ def fn(a): self.common(fn, (torch.randn(1024),)) - # Fails when testing the scalar version - # See https://github.com/pytorch/pytorch/issues/128029. - @expectedFailureScalar @skipIfRocm @config.patch(debug_index_asserts=False) def test_neg_index(self): @@ -1581,40 +1577,16 @@ def test_multilayer_var(self): def fn(a): return torch.var(a) - atol = None - rtol = None - if self.device == "cpu" and os.getenv("ATEN_CPU_CAPABILITY") == "default": - atol = 1e-4 - rtol = 1e-4 - self.common( - fn, - ((torch.rand((10, 3, 352, 352), dtype=torch.float32),)), - rtol=rtol, - atol=atol, - ) - self.common( - fn, ((torch.rand((14923), dtype=torch.float32),)), rtol=rtol, atol=atol - ) + self.common(fn, ((torch.rand((10, 3, 352, 352), dtype=torch.float32),))) + self.common(fn, ((torch.rand((14923), dtype=torch.float32),))) @skipCPUIf(IS_MACOS, "fails on macos") def test_multilayer_var_lowp(self): def fn(a): return torch.var(a) - atol = None - rtol = None - if self.device == "cpu" and os.getenv("ATEN_CPU_CAPABILITY") == "default": - atol = 1e-3 - rtol = 1e-3 - self.common( - fn, - (torch.rand((16, 16, 352, 352), dtype=torch.float16),), - rtol=rtol, - atol=atol, - ) - self.common( - fn, (torch.rand((14923), dtype=torch.float16),), rtol=rtol, atol=atol - ) + self.common(fn, (torch.rand((16, 16, 352, 352), dtype=torch.float16),)) + self.common(fn, (torch.rand((14923), dtype=torch.float16),)) def test_split_cumsum(self): def fn(a): @@ -8227,7 +8199,7 @@ def forward(arg38_1, arg81_1, getitem_17, new_zeros_default_4): rand_strided(shape, stride, dtype).requires_grad_(True).add(1) for shape, stride, dtype in args ] - self.common(forward, args, atol=1e-05, rtol=1e-05) + self.common(forward, args) @requires_gpu() def test_tmp_not_defined_issue3(self): @@ -9309,7 +9281,6 @@ def func(arg0_1): # To support this behavior, we need to allow const-propping tensors that store symint data. # For now, dynamo will explicitly graph break when it encounters user code with this behavior. @expectedFailureCodegenDynamic - @expectedFailureScalar def test_AllenaiLongformerBase_repro(self): def fn(query, scores, window_overlap): batch_size, seq_len, num_heads, _ = query.size() @@ -9345,9 +9316,6 @@ def fn(query, scores, window_overlap): opt_fn = torch._dynamo.optimize("inductor")(fn) _, code = run_and_get_cpp_code(opt_fn, *args) print(code) - # When testing the scalar version, i.e., ATEN_CPU_CAPABILITY=default, - # static_cast(256) is not found, but static_cast(256). - # See https://github.com/pytorch/pytorch/issues/126262. FileCheck().check_count( "static_cast(256)", 1, diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py index d254a5e261ed..527e0138fc25 100644 --- a/torch/_dynamo/testing.py +++ b/torch/_dynamo/testing.py @@ -381,12 +381,6 @@ def expectedFailureDynamicWrapper(fn): return fn -def expectedFailureScalar(fn): - if os.getenv("ATEN_CPU_CAPABILITY") == "default": - return unittest.expectedFailure(fn) - return fn - - def reset_rng_state(use_xla=False): torch.manual_seed(1337) random.seed(1337) diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index d497272d00a3..ae8453660813 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1465,31 +1465,6 @@ def _check_and_append_supported_isa( supported_vec_isa_list = [VecAVX512(), VecAVX2(), VecNEON()] -def get_isa_from_cpu_capability( - capability: str | None, vec_isa_list: List[VecISA], invalid_vec_isa: InvalidVecISA -): - # VSX is not supported in inductor - capability_to_isa_str = { - "default": "INVALID_VEC_ISA", - "neon": "asimd", - "zvector": "zvector", - "avx2": "avx2", - "avx512": "avx512", - } - if capability in capability_to_isa_str.keys(): - isa_str = capability_to_isa_str[capability] - if isa_str == "INVALID_VEC_ISA": - return invalid_vec_isa - for vec_isa in vec_isa_list: - if isa_str == str(vec_isa): - return vec_isa - - if capability: - warnings.warn(f"ignoring invalid value for ATEN_CPU_CAPABILITY {capability}") - - return vec_isa_list[0] - - # Cache the cpuinfo to avoid I/O overhead. Meanwhile, the cpuinfo content # might have too much redundant content that is useless for ISA check. Hence, # we only cache some key isa information. @@ -1532,13 +1507,10 @@ def pick_vec_isa() -> VecISA: if not _valid_vec_isa_list: return invalid_vec_isa - # If the simdlen is None, set simdlen based on the environment ATEN_CPU_CAPABILITY - # to control CPU vec ISA - + # If the simdlen is None, it indicates determine the vectorization length automatically if config.cpp.simdlen is None: - return get_isa_from_cpu_capability( - os.getenv("ATEN_CPU_CAPABILITY"), _valid_vec_isa_list, invalid_vec_isa - ) + assert _valid_vec_isa_list + return _valid_vec_isa_list[0] for isa in _valid_vec_isa_list: if config.cpp.simdlen == isa.bit_width(): diff --git a/torch/_inductor/codegen/cpp_prefix.h b/torch/_inductor/codegen/cpp_prefix.h index 1492023eed38..6898a8a52112 100644 --- a/torch/_inductor/codegen/cpp_prefix.h +++ b/torch/_inductor/codegen/cpp_prefix.h @@ -24,7 +24,6 @@ #include #include #include -#include #if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON) #define INDUCTOR_USE_VECTOR_TYPES() 1 From a287ff75d079e39e0c20d1bb5e26aa01ad86b8eb Mon Sep 17 00:00:00 2001 From: Yidi Wu Date: Mon, 10 Jun 2024 23:02:48 +0000 Subject: [PATCH 599/706] Use init_torchbind_implementations in inductor torchbind tests. (#128341) Summary: To unify how we load the torch bind libraries for testing. Test Plan: Existing tests. Differential Revision: D58372372 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128341 Approved by: https://github.com/angelayi --- test/inductor/test_torchbind.py | 24 +++--------------------- 1 file changed, 3 insertions(+), 21 deletions(-) diff --git a/test/inductor/test_torchbind.py b/test/inductor/test_torchbind.py index e1bb0ad36d0b..3350e8e895f3 100644 --- a/test/inductor/test_torchbind.py +++ b/test/inductor/test_torchbind.py @@ -1,6 +1,4 @@ # Owner(s): ["module: functorch"] -import unittest - import torch import torch._dynamo import torch._functorch @@ -8,30 +6,14 @@ import torch._inductor.decomposition from torch._higher_order_ops.torchbind import enable_torchbind_tracing from torch._inductor.test_case import run_tests, TestCase -from torch.testing._internal.common_utils import ( - find_library_location, - IS_FBCODE, - IS_MACOS, - IS_SANDCASTLE, - IS_WINDOWS, -) + +from torch.testing._internal.torchbind_impls import init_torchbind_implementations class TestTorchbind(TestCase): def setUp(self): super().setUp() - if IS_MACOS: - raise unittest.SkipTest("non-portable load_library call used in test") - elif IS_SANDCASTLE or IS_FBCODE: - torch.ops.load_library( - "//caffe2/test/cpp/jit:test_custom_class_registrations" - ) - elif IS_WINDOWS: - lib_file_path = find_library_location("torchbind_test.dll") - torch.ops.load_library(str(lib_file_path)) - else: - lib_file_path = find_library_location("libtorchbind_test.so") - torch.ops.load_library(str(lib_file_path)) + init_torchbind_implementations() def get_exported_model(self): """ From 05711eece92ae051d3f0942f1fb84a38028bab15 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Mon, 10 Jun 2024 11:29:33 -0700 Subject: [PATCH 600/706] [dynamo][inlining inbuilt modules] Ensure BC for nn_module_stack (#128295) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128295 Approved by: https://github.com/ydwu4 --- test/dynamo/test_repros.py | 48 +++++++++++++++ torch/_dynamo/variables/builder.py | 2 +- torch/_dynamo/variables/nn_module.py | 77 ++++++++++++++----------- torch/_dynamo/variables/user_defined.py | 19 +++++- 4 files changed, 110 insertions(+), 36 deletions(-) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 8515b6a7f735..8dd1b91f43f7 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -32,6 +32,7 @@ import torch._functorch.config import torch.library +import torch.utils._pytree as pytree from torch import nn from torch._dynamo.debug_utils import same_two_models from torch._dynamo.testing import CompileCounter, rand_strided, same @@ -5054,6 +5055,53 @@ def fn(x, y): opt_fn = torch.compile(fn, backend="eager") self.assertEqual(fn(x, y), opt_fn(x, y)) + def test_nn_module_stack_bc(self): + from torch._dynamo.mutation_guard import GenerationTracker + + def compiler(gm, *args): + module_stacks = [ + node.meta.get("nn_module_stack", None) for node in gm.graph.nodes + ] + module_stacks, _ = pytree.tree_flatten(module_stacks) + module_stacks = [x for x in module_stacks if isinstance(x, str)] + for stack in module_stacks: + self.assertTrue("_module" not in stack) + return gm.forward + + class SubMod(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(2, 2) + + def forward(self, x): + return self.linear(x) + + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + self.submod1 = SubMod() + self.submod2 = SubMod() + + def forward(self, x): + return self.submod1(x) + self.submod2(x) + + mod = Mod() + opt_mod = torch.compile(mod, backend=compiler) + opt_mod(torch.randn(2, 2)) + + with torch._dynamo.config.patch(inline_inbuilt_nn_modules=True): + mod = Mod() + opt_mod = torch.compile(mod, backend=compiler) + opt_mod(torch.randn(2, 2)) + + # an example similar to Pippy usecase + mod = Mod() + GenerationTracker.tag(mod.submod1) + GenerationTracker.mark_class_dynamic(type(mod.submod1)) + mod = Mod() + opt_mod = torch.compile(mod, backend=compiler) + opt_mod(torch.randn(2, 2)) + instantiate_parametrized_tests(ReproTests) diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 2d0543f8b147..f31f5c97eb62 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -1137,7 +1137,7 @@ def wrap_module(self, value: torch.nn.Module): value.__class__, torch.nn.parallel.distributed.DistributedDataParallel ): self.install_guards(GuardBuilder.TYPE_MATCH) - return UnspecializedNNModuleVariable(value) + return UnspecializedNNModuleVariable(value, source=self.get_source()) elif getattr(value, "_is_fsdp_managed_module", False): # See note [Dynamo treats FSDP wrapped modules as UnspecializedNNModule] # in fully_sharded_data_parallel.py for more information diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py index 0a6bad4730dd..37c0bc17697a 100644 --- a/torch/_dynamo/variables/nn_module.py +++ b/torch/_dynamo/variables/nn_module.py @@ -67,38 +67,9 @@ def convert_to_fake(x): mod._infer_parameters(mod, fake_args, fake_kwargs) -def cleanup_source_for_nn_module_stack(source): - # TODO(anijain2305, export-team) This is a bad hack to fix the nn module - # fully_qualified_name to work with export/unflatten. It converts - # mod._modules['net1'] to mod.net1. - - # This type of source occurs when we use UnspecializedNNModule variable - # because unspecialized nn module variable inlines module __getattr__ calls. - # For export, we rely heavily on NNModuleVariable and do not support - # UnspecializedNNModule. But there is one case where this gets exposed - - # Pippy. Pippy uses export/unflatten (an export feature) and also - # monkepatches the `forward` method of a mod that forces Dynamo to use - # UnspecializedNNModule. Therefore, we will need proper work to retain the - # nn module stack when we let export rely on UnspecializedNNModule variable. - - # This does not work if we have recursively UnspecializedNNModule variables - # e.g. mod._modules['net1']._modules['net2']. This is unlikely to happen in - # Pippy so the hotfix is enough for Pippy. - - if ( - isinstance(source, GetItemSource) - and isinstance(source.base, AttrSource) - and isinstance(source.base.base, NNModuleSource) - and source.base.member == "_modules" - ): - return AttrSource(source.base.base, source.index) - return source - - @contextmanager def record_nn_module_stack(module_key: str, source, tx, mod: torch.nn.Module): - source_for_nn_module_stack = cleanup_source_for_nn_module_stack(source) - fully_qualified_name = source_for_nn_module_stack.name() + fully_qualified_name = source.name() try: tx.nn_module_stack[module_key] = (fully_qualified_name, mod.__class__) yield @@ -144,6 +115,7 @@ class NNModuleVariable(VariableTracker): "module_type", "module_key", "module", + "nn_module_stack_source", *VariableTracker._nonvar_fields, } @@ -155,6 +127,13 @@ def __init__( self.module_key = module_key self.module = module assert self.source + self.nn_module_stack_source = self.source + + def get_nn_module_stack_source(self): + return self.nn_module_stack_source or self.source + + def set_nn_module_stack_source(self, source): + self.nn_module_stack_source = source def python_type(self): return self.module_type @@ -301,7 +280,17 @@ def var_getattr(self, tx, name): return variables.UserDefinedClassVariable(base.__class__, source=source) if object_member: - return VariableBuilder(tx, NNModuleSource(source))(subobj) + out = VariableBuilder(tx, NNModuleSource(source))(subobj) + + if isinstance(out, (NNModuleVariable, UnspecializedNNModuleVariable)): + # nn_module_stack source is BC surface area. Ensure that + # mod._modules["linear"] is reflected as mod.linear for + # nn_module_stack. + out.set_nn_module_stack_source( + AttrSource(self.get_nn_module_stack_source(), name) + ) + return out + else: if istype(subobj, property): if self.source: @@ -343,7 +332,9 @@ def call_function( ) -> "VariableTracker": mod = tx.output.get_submodule(self.module_key) - with record_nn_module_stack(self.module_key, self.source, tx, mod): + with record_nn_module_stack( + self.module_key, self.get_nn_module_stack_source(), tx, mod + ): is_lazy = is_lazy_module(mod) if ( isinstance(mod, torch.nn.Sequential) @@ -487,7 +478,9 @@ def generic_call_method_helper(name): # Example: `self.layer.forward(x)` # This is used for explicit calling `forward` in a forward function. # Dynamo puts `call_method` node in FX, doesn't trigger hooks. - with record_nn_module_stack(self.module_key, self.source, tx, module): + with record_nn_module_stack( + self.module_key, self.get_nn_module_stack_source(), tx, module + ): return generic_call_method_helper(name) if name == "_check_input_dim" and trace_rules.is_torch_inline_allowed( @@ -750,6 +743,7 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable): _nonvar_fields = { "value_type", "is_state_mutated", + "nn_module_stack_source", *UserDefinedObjectVariable._nonvar_fields, } @@ -778,6 +772,19 @@ def __init__(self, value, **kwargs): super().__init__(value=value, **kwargs) self.is_state_mutated = False + # nn_module_stack_source is used to ensure BC for nn_module_stack. + # Downstream users prefer mod.linear instead of mod._modules['linear'] + # as the module stack. When Dynamo inlines the __getattr__ method, we + # cannot use self.source for nn_module_stack because it will be similar + # to mod._modules['linear']. In these cases, we set the + # nn_module_stack_source appropriately to resemble mod.linear. + self.nn_module_stack_source = self.source + + def get_nn_module_stack_source(self): + return self.nn_module_stack_source or self.source + + def set_nn_module_stack_source(self, source): + self.nn_module_stack_source = source @staticmethod @functools.lru_cache(None) @@ -830,7 +837,9 @@ def call_function( guard_to_detect_forward_monkeypatching(self.source, mod) ctx = ( - record_nn_module_stack(str(id(mod)), self.source, tx, mod) + record_nn_module_stack( + str(id(mod)), self.get_nn_module_stack_source(), tx, mod + ) if self.source else nullcontext() ) diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 5b785293911f..7c7673a103fd 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -853,9 +853,26 @@ def var_getattr(self, tx, name): new_source = None if self.source: new_source = AttrSource(self.source, "__getattr__") - return variables.UserMethodVariable( + out = variables.UserMethodVariable( getattr_fn, self, source=new_source ).call_function(tx, [ConstantVariable.create(name)], {}) + + if self.source and getattr_fn is torch.nn.Module.__getattr__: + if isinstance( + out, + ( + variables.UnspecializedNNModuleVariable, + variables.NNModuleVariable, + ), + ): + # nn_module_stack source is BC surface area. Ensure that + # mod._modules["linear"] is reflected as mod.linear for + # nn_module_stack. + out.set_nn_module_stack_source( + AttrSource(self.get_nn_module_stack_source(), name) + ) + return out + elif getattr_fn is not None: unimplemented("UserDefined with non-function __getattr__") From b2d602306a9eb19e30328cbaee941c874f8148a9 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Mon, 10 Jun 2024 11:29:33 -0700 Subject: [PATCH 601/706] [RELAND][dynamo][nn-modules] Trace through nn.Module dunder methods for UnspecializedNNModule (#126578) Tracing through `__init__` is important because it initializes (calls STORE_ATTR) on members. By doing that, we kick in the mutation tracking for these objects. So, things like mutating `_modules` etc is tracked automatically. Fixes https://github.com/pytorch/pytorch/issues/111837 Pull Request resolved: https://github.com/pytorch/pytorch/pull/126578 Approved by: https://github.com/jansel ghstack dependencies: #128295 --- test/distributed/test_dynamo_distributed.py | 10 +-- test/dynamo/test_higher_order_ops.py | 16 ++--- ...ddingNN.test_embedding_sparse_empty_tensor | 0 ...ngNN.test_embeddingbag_include_last_offset | 0 ....test_profiler_pattern_matcher_json_report | 0 .../TestJitGeneratedModule.test_nn_Bilinear | 0 .../TestJitGeneratedModule.test_nn_Embedding | 0 ...dModule.test_nn_EmbeddingBag_discontiguous | 0 ...itGeneratedModule.test_nn_EmbeddingBag_max | 0 ...odule.test_nn_EmbeddingBag_max_padding_idx | 0 ...tGeneratedModule.test_nn_EmbeddingBag_mean | 0 ...dule.test_nn_EmbeddingBag_mean_padding_idx | 0 ...eneratedModule.test_nn_EmbeddingBag_sparse | 0 ...itGeneratedModule.test_nn_EmbeddingBag_sum | 0 ...odule.test_nn_EmbeddingBag_sum_padding_idx | 0 ...atedModule.test_nn_Embedding_discontiguous | 0 ...itGeneratedModule.test_nn_Embedding_sparse | 0 .../TestJitGeneratedModule.test_nn_Linear | 0 ...eneratedModule.test_nn_Linear_no_batch_dim | 0 ...GeneratedModule.test_nn_PReLU_no_batch_dim | 0 .../TestNN.test_ParameterDict | 0 .../TestNN.test_Sequential_iadd | 0 .../TestNN.test_bilinear_broadcasting | 0 ...st_layer_norm_grads_with_create_graph_flag | 0 ..._linear_autograd_device_cpu_bias_weightCOO | 0 ..._linear_autograd_device_cpu_bias_weightCSC | 0 ..._linear_autograd_device_cpu_bias_weightCSR | 0 .../TestNN.test_linear_broadcasting | 0 .../TestNN.test_module_apply_inplace_op | 0 ...est_overwrite_module_params_on_conversion} | 0 ...metrized_tensor_parametrization_swap_False | 0 ....test_new_spectral_norm_forward_swap_True} | 0 ...rization.test_new_spectral_norm_swap_True} | 0 ...weight_norm_parametrization_swap_False_cpu | 0 ..._weight_norm_parametrization_swap_True_cpu | 0 ...sorDeviceTypeCPU.test_embedding_jagged_cpu | 0 .../TestPruningNN.test_identity_pruning | 0 ...TestPruningNN.test_pruning_id_consistency} | 0 .../TestPruningNN.test_random_pruning_0perc | 0 test/profiler/test_profiler.py | 1 + torch/_dynamo/create_parameter_op.py | 20 ++++++ torch/_dynamo/mutation_guard.py | 3 + torch/_dynamo/side_effects.py | 32 ++++++---- torch/_dynamo/symbolic_convert.py | 11 +++- torch/_dynamo/utils.py | 4 +- torch/_dynamo/variables/dicts.py | 6 +- torch/_dynamo/variables/misc.py | 26 +++++--- torch/_dynamo/variables/nn_module.py | 40 ++++++++---- torch/_dynamo/variables/torch.py | 9 ++- torch/_dynamo/variables/user_defined.py | 63 ++++++++++++------- 50 files changed, 169 insertions(+), 72 deletions(-) delete mode 100644 test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_sparse_empty_tensor delete mode 100644 test/dynamo_expected_failures/TestEmbeddingNN.test_embeddingbag_include_last_offset delete mode 100644 test/dynamo_expected_failures/TestExperimentalUtils.test_profiler_pattern_matcher_json_report delete mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Bilinear delete mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding delete mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_discontiguous delete mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_max delete mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_max_padding_idx delete mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_mean delete mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_mean_padding_idx delete mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sparse delete mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sum delete mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sum_padding_idx delete mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding_discontiguous delete mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding_sparse delete mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Linear delete mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Linear_no_batch_dim delete mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_PReLU_no_batch_dim delete mode 100644 test/dynamo_expected_failures/TestNN.test_ParameterDict delete mode 100644 test/dynamo_expected_failures/TestNN.test_Sequential_iadd delete mode 100644 test/dynamo_expected_failures/TestNN.test_bilinear_broadcasting delete mode 100644 test/dynamo_expected_failures/TestNN.test_layer_norm_grads_with_create_graph_flag delete mode 100644 test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCOO delete mode 100644 test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCSC delete mode 100644 test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCSR delete mode 100644 test/dynamo_expected_failures/TestNN.test_linear_broadcasting delete mode 100644 test/dynamo_expected_failures/TestNN.test_module_apply_inplace_op rename test/dynamo_expected_failures/{FakeTensorTest.test_embedding_bag_meta => TestNN.test_overwrite_module_params_on_conversion} (100%) delete mode 100644 test/dynamo_expected_failures/TestNNParametrization.test_errors_unparametrized_tensor_parametrization_swap_False rename test/dynamo_expected_failures/{TestCompileTransformsCPU.test_compile_vmap_hessian_cpu => TestNNParametrization.test_new_spectral_norm_forward_swap_True} (100%) rename test/dynamo_expected_failures/{TestEmbeddingNN.test_embedding_max_norm => TestNNParametrization.test_new_spectral_norm_swap_True} (100%) delete mode 100644 test/dynamo_expected_failures/TestNNParametrizationDeviceCPU.test_weight_norm_parametrization_swap_False_cpu delete mode 100644 test/dynamo_expected_failures/TestNNParametrizationDeviceCPU.test_weight_norm_parametrization_swap_True_cpu delete mode 100644 test/dynamo_expected_failures/TestNestedTensorDeviceTypeCPU.test_embedding_jagged_cpu delete mode 100644 test/dynamo_expected_failures/TestPruningNN.test_identity_pruning rename test/dynamo_expected_failures/{TestEmbeddingNN.test_embedding_sparse_basic => TestPruningNN.test_pruning_id_consistency} (100%) delete mode 100644 test/dynamo_expected_failures/TestPruningNN.test_random_pruning_0perc diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py index b31a2f717537..db44f1ce915d 100644 --- a/test/distributed/test_dynamo_distributed.py +++ b/test/distributed/test_dynamo_distributed.py @@ -1084,12 +1084,14 @@ def _(ctx): # far from an exhaustive check of all the expected guards, just check a couple of them. FileCheck().check("""local "L['self']" TYPE_MATCH""").check( """local "L['self']" ID_MATCH""" - ).check(f"""{expected_guard_source} "L['self'].net" TYPE_MATCH""").check( - f"""{expected_guard_source} "L['self'].net" ID_MATCH""" ).check( - f"""{expected_guard_source} "L['self'].net[0]" TYPE_MATCH""" + f"""{expected_guard_source} "L['self']._modules['net']" TYPE_MATCH""" ).check( - f"""{expected_guard_source} "L['self'].net[0]" ID_MATCH""" + f"""{expected_guard_source} "L['self']._modules['net']" ID_MATCH""" + ).check( + f"""{expected_guard_source} "L['self']._modules['net']._modules['0']" TYPE_MATCH""" + ).check( + f"""{expected_guard_source} "L['self']._modules['net']._modules['1']" ID_MATCH""" ).run( GUARDS_FILE.getvalue() ) diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index 30dff83e12dd..c934cf55e8f5 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -5187,10 +5187,10 @@ def wrapper_fn(x): actual, """\ class GraphModule(torch.nn.Module): - def forward(self, L_self_tensor_constant0: "f32[3, 3, 3]"): - l_self_tensor_constant0 = L_self_tensor_constant0 + def forward(self, L_self_buffers_tensor_constant0_: "f32[3, 3, 3]"): + l_self_buffers_tensor_constant0_ = L_self_buffers_tensor_constant0_ - alias_default: "f32[3, 3, 3]" = torch.ops.aten.alias.default(l_self_tensor_constant0); l_self_tensor_constant0 = None + alias_default: "f32[3, 3, 3]" = torch.ops.aten.alias.default(l_self_buffers_tensor_constant0_); l_self_buffers_tensor_constant0_ = None sin_default: "f32[3, 3, 3]" = torch.ops.aten.sin.default(alias_default) @@ -5209,16 +5209,16 @@ def forward(self, L_self_tensor_constant0: "f32[3, 3, 3]"): actual, """\ class GraphModule(torch.nn.Module): - def forward(self, getattr_L_self_FX_CONST_FOLDED_ATTRS_0_: "f32[3, 3, 3]", getattr_L_self_FX_CONST_FOLDED_ATTRS_1_: "f32[3, 3, 3]", L_flat_tangents_1_: "f32[3, 3, 3]"): - getattr_l_self_fx_const_folded_attrs_0_ = getattr_L_self_FX_CONST_FOLDED_ATTRS_0_ - getattr_l_self_fx_const_folded_attrs_1_ = getattr_L_self_FX_CONST_FOLDED_ATTRS_1_ + def forward(self, L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_0_: "f32[3, 3, 3]", L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_1_: "f32[3, 3, 3]", L_flat_tangents_1_: "f32[3, 3, 3]"): + l_self_modules_fx_const_folded_attrs_parameters_0_ = L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_0_ + l_self_modules_fx_const_folded_attrs_parameters_1_ = L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_1_ l_flat_tangents_1_ = L_flat_tangents_1_ - _new_zeros_with_same_feature_meta_default: "f32[3, 3, 3]" = torch.ops.aten._new_zeros_with_same_feature_meta.default(l_flat_tangents_1_, getattr_l_self_fx_const_folded_attrs_0_); getattr_l_self_fx_const_folded_attrs_0_ = None + _new_zeros_with_same_feature_meta_default: "f32[3, 3, 3]" = torch.ops.aten._new_zeros_with_same_feature_meta.default(l_flat_tangents_1_, l_self_modules_fx_const_folded_attrs_parameters_0_); l_self_modules_fx_const_folded_attrs_parameters_0_ = None copy__default: "f32[3, 3, 3]" = torch.ops.aten.copy_.default(_new_zeros_with_same_feature_meta_default, l_flat_tangents_1_); _new_zeros_with_same_feature_meta_default = l_flat_tangents_1_ = None - mul_tensor: "f32[3, 3, 3]" = torch.ops.aten.mul.Tensor(copy__default, getattr_l_self_fx_const_folded_attrs_1_); copy__default = getattr_l_self_fx_const_folded_attrs_1_ = None + mul_tensor: "f32[3, 3, 3]" = torch.ops.aten.mul.Tensor(copy__default, l_self_modules_fx_const_folded_attrs_parameters_1_); copy__default = l_self_modules_fx_const_folded_attrs_parameters_1_ = None return (mul_tensor,) """, ) diff --git a/test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_sparse_empty_tensor b/test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_sparse_empty_tensor deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestEmbeddingNN.test_embeddingbag_include_last_offset b/test/dynamo_expected_failures/TestEmbeddingNN.test_embeddingbag_include_last_offset deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestExperimentalUtils.test_profiler_pattern_matcher_json_report b/test/dynamo_expected_failures/TestExperimentalUtils.test_profiler_pattern_matcher_json_report deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Bilinear b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Bilinear deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_discontiguous b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_discontiguous deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_max b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_max deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_max_padding_idx b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_max_padding_idx deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_mean b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_mean deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_mean_padding_idx b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_mean_padding_idx deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sparse b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sparse deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sum b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sum deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sum_padding_idx b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sum_padding_idx deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding_discontiguous b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding_discontiguous deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding_sparse b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding_sparse deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Linear b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Linear deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Linear_no_batch_dim b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Linear_no_batch_dim deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_PReLU_no_batch_dim b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_PReLU_no_batch_dim deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNN.test_ParameterDict b/test/dynamo_expected_failures/TestNN.test_ParameterDict deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNN.test_Sequential_iadd b/test/dynamo_expected_failures/TestNN.test_Sequential_iadd deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNN.test_bilinear_broadcasting b/test/dynamo_expected_failures/TestNN.test_bilinear_broadcasting deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNN.test_layer_norm_grads_with_create_graph_flag b/test/dynamo_expected_failures/TestNN.test_layer_norm_grads_with_create_graph_flag deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCOO b/test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCOO deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCSC b/test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCSC deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCSR b/test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCSR deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNN.test_linear_broadcasting b/test/dynamo_expected_failures/TestNN.test_linear_broadcasting deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNN.test_module_apply_inplace_op b/test/dynamo_expected_failures/TestNN.test_module_apply_inplace_op deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/FakeTensorTest.test_embedding_bag_meta b/test/dynamo_expected_failures/TestNN.test_overwrite_module_params_on_conversion similarity index 100% rename from test/dynamo_expected_failures/FakeTensorTest.test_embedding_bag_meta rename to test/dynamo_expected_failures/TestNN.test_overwrite_module_params_on_conversion diff --git a/test/dynamo_expected_failures/TestNNParametrization.test_errors_unparametrized_tensor_parametrization_swap_False b/test/dynamo_expected_failures/TestNNParametrization.test_errors_unparametrized_tensor_parametrization_swap_False deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestCompileTransformsCPU.test_compile_vmap_hessian_cpu b/test/dynamo_expected_failures/TestNNParametrization.test_new_spectral_norm_forward_swap_True similarity index 100% rename from test/dynamo_expected_failures/TestCompileTransformsCPU.test_compile_vmap_hessian_cpu rename to test/dynamo_expected_failures/TestNNParametrization.test_new_spectral_norm_forward_swap_True diff --git a/test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_max_norm b/test/dynamo_expected_failures/TestNNParametrization.test_new_spectral_norm_swap_True similarity index 100% rename from test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_max_norm rename to test/dynamo_expected_failures/TestNNParametrization.test_new_spectral_norm_swap_True diff --git a/test/dynamo_expected_failures/TestNNParametrizationDeviceCPU.test_weight_norm_parametrization_swap_False_cpu b/test/dynamo_expected_failures/TestNNParametrizationDeviceCPU.test_weight_norm_parametrization_swap_False_cpu deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNNParametrizationDeviceCPU.test_weight_norm_parametrization_swap_True_cpu b/test/dynamo_expected_failures/TestNNParametrizationDeviceCPU.test_weight_norm_parametrization_swap_True_cpu deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNestedTensorDeviceTypeCPU.test_embedding_jagged_cpu b/test/dynamo_expected_failures/TestNestedTensorDeviceTypeCPU.test_embedding_jagged_cpu deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestPruningNN.test_identity_pruning b/test/dynamo_expected_failures/TestPruningNN.test_identity_pruning deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_sparse_basic b/test/dynamo_expected_failures/TestPruningNN.test_pruning_id_consistency similarity index 100% rename from test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_sparse_basic rename to test/dynamo_expected_failures/TestPruningNN.test_pruning_id_consistency diff --git a/test/dynamo_expected_failures/TestPruningNN.test_random_pruning_0perc b/test/dynamo_expected_failures/TestPruningNN.test_random_pruning_0perc deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/profiler/test_profiler.py b/test/profiler/test_profiler.py index ca0481922e4f..81d158635c0e 100644 --- a/test/profiler/test_profiler.py +++ b/test/profiler/test_profiler.py @@ -2411,6 +2411,7 @@ def test_profiler_matmul_dim_fp16_pattern(self): num_matched.append(len(pattern.matched_events())) self.assertEqual(num_matched, [i for i, _ in cases]) + @skipIfTorchDynamo("profiler gets ignored if dynamo activated") def test_profiler_pattern_matcher_json_report(self): x = torch.ones((100, 100)) model = nn.Sequential( diff --git a/torch/_dynamo/create_parameter_op.py b/torch/_dynamo/create_parameter_op.py index f6cd12de2021..d30e4a37f003 100644 --- a/torch/_dynamo/create_parameter_op.py +++ b/torch/_dynamo/create_parameter_op.py @@ -1,4 +1,7 @@ # mypy: allow-untyped-defs +import threading +from contextlib import contextmanager + import torch doc = """ @@ -37,3 +40,20 @@ def new_parameter_placeholder(size, dtype, device, requires_grad): # Allocating a zero tensor would causes assert failures in autograd. result.untyped_storage().resize_(0) return result + + +_TLS = threading.local() + + +@contextmanager +def do_not_convert_to_tracable_parameter(): + old_flag = getattr(_TLS, "convert_tracable_parameter", True) + _TLS.convert_tracable_parameter = False + try: + yield False + finally: + _TLS.convert_tracable_parameter = old_flag + + +def can_convert_to_tracable_parameter(): + return getattr(_TLS, "convert_tracable_parameter", True) diff --git a/torch/_dynamo/mutation_guard.py b/torch/_dynamo/mutation_guard.py index 22e2b9999e03..9077ecd3d57f 100644 --- a/torch/_dynamo/mutation_guard.py +++ b/torch/_dynamo/mutation_guard.py @@ -11,6 +11,9 @@ from .utils import ExactWeakKeyDictionary, is_lazy_module, nn_module_has_global_hooks +unpatched_nn_module_init = torch.nn.Module.__init__ + + class MutationTracker: db = ExactWeakKeyDictionary() diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index 229282f709cb..94797251c866 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -347,13 +347,7 @@ def codegen_save_tempvars(self, cg: PyCodegen): elif isinstance(var.mutable_local, AttributeMutationNew): if isinstance(var, variables.AutogradFunctionContextVariable): unimplemented("AutogradFunctionContextVariable escaped") - if "__call_nn_module_init" in self.store_attr_mutations.get( - var.mutable_local, {} - ): - assert isinstance(var, variables.UnspecializedNNModuleVariable) - cg.load_import_from(utils.__name__, "nn_module_new") - else: - cg.load_import_from(utils.__name__, "object_new") + cg.load_import_from(utils.__name__, "object_new") cg(var.mutable_local.cls_source) cg.extend_output(create_call_function(1, True)) cg.add_cache(var) @@ -480,9 +474,25 @@ def codegen_update_mutated(self, cg: PyCodegen): ] ) elif self.is_attribute_mutation(var): - for name, value in self.store_attr_mutations.get( - var.mutable_local, {} - ).items(): + # Applying mutations involves two steps: 1) Push all + # reconstructed objects onto the stack. 2) Call STORE_ATTR to + # apply the mutations. + # + # Dynamo must ensure that mutations are applied in the same + # order as in the original program. Therefore, two reverse + # operations occur below. + # + # The first reverse operation concerns `suffixes`. We apply + # suffixes in reverse order due to the way Python handles the + # stack. In Step 1, we push all reconstructed objects onto the + # stack, but the item at the top of the stack refers to the last + # attribute in the mutation order. If not fixed, this will apply + # the mutations of attributes in the reverse order. To account + # for this reversal, we iterate through the mutable attributes + # in reverse order. + for name, value in reversed( + self.store_attr_mutations.get(var.mutable_local, {}).items() + ): if isinstance(var, variables.NewGlobalVariable): cg.tx.output.update_co_names(name) cg(value) @@ -490,8 +500,6 @@ def codegen_update_mutated(self, cg: PyCodegen): suffixes.append( [create_instruction("STORE_GLOBAL", argval=name)] ) - elif name == "__call_nn_module_init": - pass # handled in codegen_save_tempvars elif isinstance(value, variables.DeletedVariable): if isinstance( var.mutable_local, AttributeMutationExisting diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 41ceaa615916..678a0497c8a2 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -416,10 +416,15 @@ def inner(self: "InstructionTranslatorBase", inst: Instruction): self.push(value) self.jump(inst) elif isinstance(value, UserDefinedObjectVariable): - x = value.var_getattr(self, "__bool__") - # if __bool__ is missing, trying __len__ to infer a truth value. - if isinstance(x, GetAttrVariable): + try: + x = value.var_getattr(self, "__bool__") + except exc.ObservedException: + # if __bool__ is missing, trying __len__ to infer a truth value. x = value.var_getattr(self, "__len__") + else: + if isinstance(x, GetAttrVariable): + # if __bool__ is missing, trying __len__ to infer a truth value. + x = value.var_getattr(self, "__len__") # __bool__ or __len__ is function if isinstance(x, UserMethodVariable): diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 4dfddcc2cdf8..59500f0338a1 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -2016,12 +2016,12 @@ def object_has_getattribute(value: Any): return False -def get_custom_getattr(value: Any): +def get_custom_getattr(value: Any, ignore_nn_module_getattr: bool = False): try: getattr_fn = inspect.getattr_static(type(value), "__getattr__") except AttributeError: getattr_fn = None - if getattr_fn is torch.nn.Module.__getattr__: + if ignore_nn_module_getattr and getattr_fn is torch.nn.Module.__getattr__: # ignore this case of getattr getattr_fn = None return getattr_fn diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 0724a80621f7..8391563c8e76 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -174,7 +174,11 @@ def python_type(self): def __contains__(self, vt): assert isinstance(vt, VariableTracker) Hashable = ConstDictVariable._HashableTracker - return is_hashable(vt) and Hashable(vt) in self.items + return ( + is_hashable(vt) + and Hashable(vt) in self.items + and not isinstance(self.items[Hashable(vt)], variables.DeletedVariable) + ) def reconstruct(self, codegen): # instructions to load collections.OrderedDict if necessary diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index cc0fb7096701..9ef36eb7f29f 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -14,8 +14,10 @@ import torch.utils._pytree as pytree from .. import config, variables from ..bytecode_transformation import create_call_function, create_instruction +from ..create_parameter_op import do_not_convert_to_tracable_parameter from ..exc import unimplemented from ..guards import GuardBuilder, install_guard +from ..mutation_guard import unpatched_nn_module_init from ..source import AttrSource, GetItemSource, ODictGetItemSource, TypeSource from ..utils import ( check_unspec_or_constant_args, @@ -121,7 +123,6 @@ def call_method( kwargs: "Dict[str, VariableTracker]", ) -> "VariableTracker": inner_fn, source = self._resolved_getattr_and_source(self, name) - if inner_fn is object.__init__: return LambdaVariable(identity) elif inner_fn is torch.nn.Module.__init__: @@ -133,12 +134,10 @@ def call_method( and isinstance(objvar.mutable_local, AttributeMutationNew) and not (args or kwargs) ): - tx.output.side_effects.store_attr( - objvar, - "__call_nn_module_init", - variables.ConstantVariable.create(True), - ) - return variables.ConstantVariable.create(None) + with do_not_convert_to_tracable_parameter(): + return variables.UserFunctionVariable( + unpatched_nn_module_init, source=source + ).call_function(tx, [self.objvar] + args, kwargs) else: unimplemented("super() nn.Module.__init__") elif isinstance(inner_fn, types.FunctionType): @@ -175,6 +174,19 @@ def call_method( self.objvar, UserDefinedObjectVariable ): return self.objvar.method_setattr_standard(tx, *args, **kwargs) + elif inner_fn is object.__delattr__: + attr = args[0] + try: + attr = attr.as_python_constant() + except NotImplementedError: + unimplemented(f"non-const delattr attr: {attr}") + if not tx.output.side_effects.is_attribute_mutation(self.objvar): + unimplemented(f"delattr({self.objvar}, {attr}, ...)") + + tx.output.side_effects.store_attr( + self.objvar, attr, variables.DeletedVariable() + ) + return variables.ConstantVariable(None) unimplemented(f"non-function or method super: {inner_fn}") diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py index 37c0bc17697a..d3f7052a9445 100644 --- a/torch/_dynamo/variables/nn_module.py +++ b/torch/_dynamo/variables/nn_module.py @@ -215,7 +215,7 @@ def _custom_getattr_fallback(self, base, tx, name, options): if object_has_getattribute(base): unimplemented("torch.nn.Module with a custom __getattribute__ defined") - getattr_fn = get_custom_getattr(base) + getattr_fn = get_custom_getattr(base, ignore_nn_module_getattr=True) if getattr_fn is None: return None @@ -665,7 +665,6 @@ def gen_source(source, name): if isinstance(args[0], SliceVariable): # Build a TupleVariable of NNModules result = [] - submods = [] # Turn the slice into the list of integers keys = list(range(len(module)))[args[0].as_python_constant()] @@ -679,9 +678,8 @@ def gen_source(source, name): source=src, ) ) - submods.append(submod) - new_module = torch.nn.Sequential(*submods) + new_module = module[args[0].as_python_constant()] new_module_variable = tx.output.register_attr_or_module( new_module, f"{self}.__getitem__(slice)", @@ -695,8 +693,10 @@ def gen_source(source, name): if isinstance(args[0], SymNodeVariable): key = args[0].evaluate_expr(tx.output) - else: + elif args[0].is_python_constant(): key = args[0].as_python_constant() + else: + unimplemented(f"getitem on NNModuleVariable with key {args[0]}") submod = module[key] return tx.output.register_attr_or_module( @@ -790,7 +790,7 @@ def set_nn_module_stack_source(self, source): @functools.lru_cache(None) def _nn_module_method_ids(): # Allow __setattr__ to fall through to base class handler - supported = {torch.nn.Module.__setattr__} + supported = {torch.nn.Module.__setattr__, torch.nn.Module.__init__} return { id(x.__code__) for x in torch.nn.Module.__dict__.values() @@ -798,8 +798,6 @@ def _nn_module_method_ids(): } def unpack_var_sequence(self, tx): - from .builder import VariableBuilder - try: fn = inspect.getattr_static(self.value_type, "__iter__") except AttributeError as e: @@ -810,11 +808,16 @@ def unpack_var_sequence(self, tx): torch.nn.ParameterList.__iter__, torch.nn.Sequential.__iter__, ): - assert self.source - return [ - VariableBuilder(tx, source=GetItemSource(self.source, idx))(item) - for idx, item in enumerate(self.value) - ] + # The program can mutate the nn module object but the saved `value` + # will not reflect the mutations. So, trace through the `__iter__` + # function to reflect any tracked mutations. + return tx.inline_user_function_return( + variables.UserFunctionVariable(fn), + [ + self, + ], + {}, + ).unpack_var_sequence(tx) return super().unpack_var_sequence(tx) @@ -943,6 +946,17 @@ def call_method( # Handle submodules self.is_state_mutated = True + if method is torch.nn.Module.__setattr__ and isinstance( + args[1], variables.DeletedVariable + ): + # Trace through __delattr__ to track mutations on the module + # members like `_modules``. + return tx.inline_user_function_return( + variables.UserFunctionVariable(torch.nn.Module.__delattr__), + [self, args[0]], + kwargs, + ) + return super().call_method(tx, name, args, kwargs) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 4d7b96b6a320..934e9a316a4b 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -18,7 +18,11 @@ from ..._guards import TracingContext from .. import config, polyfill, variables from ..codegen import PyCodegen -from ..create_parameter_op import new_parameter_placeholder, tracable_create_parameter +from ..create_parameter_op import ( + can_convert_to_tracable_parameter, + new_parameter_placeholder, + tracable_create_parameter, +) from ..device_interface import get_registered_device_interfaces from ..exc import unimplemented from ..guards import GuardBuilder, install_guard @@ -871,6 +875,9 @@ def call_nn_parameter(cls, tx, data=None, requires_grad=True): if data.source: return cls._nn_param_via_prefix_insert(tx, data, requires_grad) + if not can_convert_to_tracable_parameter(): + unimplemented("Workaround for issues with nn_parameter construction") + try: shape = tuple(data.var_getattr(tx, "shape").as_python_constant()) dtype = data.var_getattr(tx, "dtype").as_python_constant() diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 7c7673a103fd..6c79d9cfcbef 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -34,7 +34,8 @@ from torch._guards import TracingContext from .. import variables -from ..exc import unimplemented +from ..create_parameter_op import do_not_convert_to_tracable_parameter +from ..exc import ObservedException, unimplemented from ..guards import GuardBuilder, install_guard from ..source import AttrSource, GetItemSource, ODictGetItemSource, RandomValueSource from ..utils import ( @@ -57,10 +58,7 @@ def is_standard_setattr(val): - return val in ( - object.__setattr__, - torch.nn.Module.__setattr__, - ) + return val in (object.__setattr__,) class UserDefinedVariable(VariableTracker): @@ -378,17 +376,7 @@ def call_function( else UserDefinedObjectVariable, {}, ) - if ( - inspect.getattr_static(self.value, "__init__", None) - is torch.nn.Module.__init__ - ): - tx.output.side_effects.store_attr( - var, - "__call_nn_module_init", - variables.ConstantVariable.create(True), - ) - return var - else: + with do_not_convert_to_tracable_parameter(): var.call_method(tx, "__init__", args, kwargs) return var elif variables.CustomizedDictVariable.is_matching_cls(self.value): @@ -638,6 +626,10 @@ def call_method( else AttrSource(AttrSource(self.source, "__class__"), name) ) # TODO(jansel): add a guard to check for monkey patching? + from ..mutation_guard import unpatched_nn_module_init + + if method is torch.nn.Module.__init__: + method = unpatched_nn_module_init return UserMethodVariable(method, self, source=source).call_function( tx, args, kwargs ) @@ -799,7 +791,7 @@ def _check_for_getattr(self): def _getattr_static(self, name): if ( - isinstance(self.value, (torch.nn.Module, PyTreeSpec)) + isinstance(self.value, PyTreeSpec) or "__slots__" in self.value.__class__.__dict__ or type(self.value) == threading.local ): @@ -812,7 +804,6 @@ def _getattr_static(self, name): return cls_var except AttributeError: pass # __slots__ - # this might call torch.nn.Module.__getattr__ subobj = getattr(self.value, name) else: subobj = inspect.getattr_static(self.value, name) @@ -1018,14 +1009,35 @@ def call_hasattr(self, tx, name: str) -> "VariableTracker": install_guard( AttrSource(self.source, name).make_guard(GuardBuilder.HASATTR) ) - if self._check_for_getattribute() or self._check_for_getattr(): - unimplemented("hasattr with custom __getattr__") + if self._check_for_getattribute(): + unimplemented("hasattr with custom __getattribute__") try: self._getattr_static(name) return variables.ConstantVariable.create(True) except AttributeError: - return variables.ConstantVariable.create(False) + # Now check in __getattr__ function + getattr_fn = self._check_for_getattr() + if isinstance(getattr_fn, types.FunctionType): + # Dynamo is going to trace the __getattr__ function with + # args=name. Set the source accordingly. + new_source = None + if self.source: + new_source = AttrSource(self.source, "__getattr__") + try: + result = variables.UserMethodVariable( + getattr_fn, self, source=new_source + ).call_function(tx, [variables.ConstantVariable.create(name)], {}) + + return variables.ConstantVariable.create( + not isinstance(result, variables.DeletedVariable) + ) + except ObservedException: + return variables.ConstantVariable.create(False) + elif getattr_fn is None: + return variables.ConstantVariable.create(False) + else: + unimplemented("UserDefined with non-function __getattr__") def odict_getitem(self, tx, key): from .builder import VariableBuilder @@ -1092,6 +1104,12 @@ def var_getattr(self, tx, name): return super().var_getattr(tx, name) +class RemovableHandleClass: + # Dummy class to pass to python_type of RemovableHandleVariable + # Useful for isinstance check on hooks + pass + + class RemovableHandleVariable(VariableTracker): REMOVED = -1 @@ -1122,3 +1140,6 @@ def reconstruct(self, codegen): return # unreachable due to codegen.add_cache() when the hook is installed super().reconstruct(codegen) + + def python_type(self): + return RemovableHandleClass From 739aa224ec1aa8777cbeb215f6f668ab10d86803 Mon Sep 17 00:00:00 2001 From: Jiashen Cao Date: Mon, 10 Jun 2024 23:24:16 +0000 Subject: [PATCH 602/706] [Fix] Parameter un/lifting issues in the TorchScript to ExportedProgram converter (#127975) This PR fixes issues related to parameters and inputs lifting in the converter. #### Issue 1 ``` > Graph[linear.weights, bias.weights, x.1] %1 ... %2 ... %3 = CreateObject() > Block 0[] %linear.0 = GetAttr(linear)[%3] > Block 0.0[] %weight.0 = GetAttr(weights)[%linear.0] > Block 1[] ... ``` * Model parameters for the top level module should be unlifted, while parameters from sub-blocks should be lifted. #### Fixes * Bottom-up traversal (i.e., start from the inner most block) to figure out which parameters to be lifted for sub-blocks. #### Test Plan * Add test cases for nested block without control flow `pytest test/export/test_converter.py -s -k test_convert_nn_module_with_nested_param` * Add test cases for nested block with control flow `pytest test/export/test_converter.py -s -k test_convert_nn_module_with_nested_if_and_param` #### Outcome ##### TorchScript ``` graph(%x.1 : Float(3, strides=[1], requires_grad=0, device=cpu), %m1.m1.linear.weight : Float(3, 3, strides=[3, 1], requires_grad=0, device=cpu), %m1.m1.linear.bias : Float(3, strides=[1], requires_grad=0, device=cpu), %m1.linear.weight : Float(3, 3, strides=[3, 1], requires_grad=0, device=cpu), %m1.linear.bias : Float(3, strides=[1], requires_grad=0, device=cpu), %m1.m2.linear.weight : Float(3, 3, strides=[3, 1], requires_grad=0, device=cpu), %m1.m2.linear.bias : Float(3, strides=[1], requires_grad=0, device=cpu), %linear.weight : Float(3, 3, strides=[3, 1], requires_grad=0, device=cpu), %linear.bias : Float(3, strides=[1], requires_grad=0, device=cpu), %m2.m1.linear.weight : Float(3, 3, strides=[3, 1], requires_grad=0, device=cpu), %m2.m1.linear.bias : Float(3, strides=[1], requires_grad=0, device=cpu), %m2.linear.weight : Float(3, 3, strides=[3, 1], requires_grad=0, device=cpu), %m2.linear.bias : Float(3, strides=[1], requires_grad=0, device=cpu), %m2.m2.linear.weight : Float(3, 3, strides=[3, 1], requires_grad=0, device=cpu), %m2.m2.linear.bias : Float(3, strides=[1], requires_grad=0, device=cpu)): %15 : __torch__.export.test_converter.___torch_mangle_14.SuperNestedM1 = prim::CreateObject() %16 : NoneType = prim::Constant(), scope: export.test_converter.SuperNestedM1::/export.test_converter.NestedM::m1 %17 : int = prim::Constant[value=1](), scope: export.test_converter.SuperNestedM1:: # /data/users/jiashenc/pytorch/test/export/test_converter.py:342:34 %18 : Tensor = aten::max(%x.1), scope: export.test_converter.SuperNestedM1:: # /data/users/jiashenc/pytorch/test/export/test_converter.py:342:19 %19 : Tensor = aten::gt(%18, %17), scope: export.test_converter.SuperNestedM1:: # /data/users/jiashenc/pytorch/test/export/test_converter.py:342:19 %20 : bool = aten::Bool(%19), scope: export.test_converter.SuperNestedM1:: # /data/users/jiashenc/pytorch/test/export/test_converter.py:342:19 %21 : Tensor = prim::If(%20), scope: export.test_converter.SuperNestedM1:: # /data/users/jiashenc/pytorch/test/export/test_converter.py:342:16 block0(): %linear.6 : __torch__.torch.nn.modules.linear.___torch_mangle_17.Linear = prim::GetAttr[name="linear"](%15), scope: export.test_converter.SuperNestedM1:: %m1.1 : __torch__.export.test_converter.___torch_mangle_15.NestedM = prim::GetAttr[name="m1"](%15), scope: export.test_converter.SuperNestedM1:: %24 : Tensor = aten::sum(%x.1, %16), scope: export.test_converter.SuperNestedM1::/export.test_converter.NestedM::m1 # /data/users/jiashenc/pytorch/test/export/test_converter.py:327:19 %25 : Tensor = aten::gt(%24, %17), scope: export.test_converter.SuperNestedM1::/export.test_converter.NestedM::m1 # /data/users/jiashenc/pytorch/test/export/test_converter.py:327:19 %26 : bool = aten::Bool(%25), scope: export.test_converter.SuperNestedM1::/export.test_converter.NestedM::m1 # /data/users/jiashenc/pytorch/test/export/test_converter.py:327:19 %27 : Tensor = prim::If(%26), scope: export.test_converter.SuperNestedM1::/export.test_converter.NestedM::m1 # /data/users/jiashenc/pytorch/test/export/test_converter.py:327:16 block0(): %linear.10 : __torch__.torch.nn.modules.linear.___torch_mangle_17.Linear = prim::GetAttr[name="linear"](%m1.1), scope: export.test_converter.SuperNestedM1::/export.test_converter.NestedM::m1 %m1.3 : __torch__.export.test_converter.___torch_mangle_16.M = prim::GetAttr[name="m1"](%m1.1), scope: export.test_converter.SuperNestedM1::/export.test_converter.NestedM::m1 %linear.12 : __torch__.torch.nn.modules.linear.___torch_mangle_17.Linear = prim::GetAttr[name="linear"](%m1.3), scope: export.test_converter.SuperNestedM1::/export.test_converter.NestedM::m1 %weight.4 : Tensor = prim::GetAttr[name="weight"](%linear.12), scope: export.test_converter.SuperNestedM1::/export.test_converter.NestedM::m1 %bias.4 : Tensor = prim::GetAttr[name="bias"](%linear.12), scope: export.test_converter.SuperNestedM1::/export.test_converter.NestedM::m1 %33 : Tensor = aten::linear(%x.1, %weight.4, %bias.4), scope: export.test_converter.SuperNestedM1::/export.test_converter.NestedM::m1 # /data/users/jiashenc/pytorch/torch/nn/modules/linear.py:116:15 %weight.6 : Tensor = prim::GetAttr[name="weight"](%linear.10), scope: export.test_converter.SuperNestedM1::/export.test_converter.NestedM::m1 %bias.6 : Tensor = prim::GetAttr[name="bias"](%linear.10), scope: export.test_converter.SuperNestedM1::/export.test_converter.NestedM::m1 %36 : Tensor = aten::linear(%33, %weight.6, %bias.6), scope: export.test_converter.SuperNestedM1::/export.test_converter.NestedM::m1 # /data/users/jiashenc/pytorch/torch/nn/modules/linear.py:116:15 -> (%36) block1(): %linear.14 : __torch__.torch.nn.modules.linear.___torch_mangle_17.Linear = prim::GetAttr[name="linear"](%m1.1), scope: export.test_converter.SuperNestedM1::/export.test_converter.NestedM::m1 %m2.3 : __torch__.export.test_converter.___torch_mangle_16.M = prim::GetAttr[name="m2"](%m1.1), scope: export.test_converter.SuperNestedM1::/export.test_converter.NestedM::m1 %linear.16 : __torch__.torch.nn.modules.linear.___torch_mangle_17.Linear = prim::GetAttr[name="linear"](%m2.3), scope: export.test_converter.SuperNestedM1::/export.test_converter.NestedM::m1 %weight.8 : Tensor = prim::GetAttr[name="weight"](%linear.16), scope: export.test_converter.SuperNestedM1::/export.test_converter.NestedM::m1 %bias.8 : Tensor = prim::GetAttr[name="bias"](%linear.16), scope: export.test_converter.SuperNestedM1::/export.test_converter.NestedM::m1 %42 : Tensor = aten::linear(%x.1, %weight.8, %bias.8), scope: export.test_converter.SuperNestedM1::/export.test_converter.NestedM::m1 # /data/users/jiashenc/pytorch/torch/nn/modules/linear.py:116:15 %weight.2 : Tensor = prim::GetAttr[name="weight"](%linear.14), scope: export.test_converter.SuperNestedM1::/export.test_converter.NestedM::m1 %bias.2 : Tensor = prim::GetAttr[name="bias"](%linear.14), scope: export.test_converter.SuperNestedM1::/export.test_converter.NestedM::m1 %45 : Tensor = aten::linear(%42, %weight.2, %bias.2), scope: export.test_converter.SuperNestedM1::/export.test_converter.NestedM::m1 # /data/users/jiashenc/pytorch/torch/nn/modules/linear.py:116:15 -> (%45) %weight.10 : Tensor = prim::GetAttr[name="weight"](%linear.6), scope: export.test_converter.SuperNestedM1::/torch.nn.modules.linear.Linear::linear %bias.10 : Tensor = prim::GetAttr[name="bias"](%linear.6), scope: export.test_converter.SuperNestedM1::/torch.nn.modules.linear.Linear::linear %48 : Tensor = aten::linear(%27, %weight.10, %bias.10), scope: export.test_converter.SuperNestedM1::/torch.nn.modules.linear.Linear::linear # /data/users/jiashenc/pytorch/torch/nn/modules/linear.py:116:15 -> (%48) block1(): %linear.8 : __torch__.torch.nn.modules.linear.___torch_mangle_17.Linear = prim::GetAttr[name="linear"](%15), scope: export.test_converter.SuperNestedM1:: %m2.1 : __torch__.export.test_converter.___torch_mangle_15.NestedM = prim::GetAttr[name="m2"](%15), scope: export.test_converter.SuperNestedM1:: %51 : Tensor = aten::sum(%x.1, %16), scope: export.test_converter.SuperNestedM1::/export.test_converter.NestedM::m2 # /data/users/jiashenc/pytorch/test/export/test_converter.py:327:19 %52 : Tensor = aten::gt(%51, %17), scope: export.test_converter.SuperNestedM1::/export.test_converter.NestedM::m2 # /data/users/jiashenc/pytorch/test/export/test_converter.py:327:19 %53 : bool = aten::Bool(%52), scope: export.test_converter.SuperNestedM1::/export.test_converter.NestedM::m2 # /data/users/jiashenc/pytorch/test/export/test_converter.py:327:19 %54 : Tensor = prim::If(%53), scope: export.test_converter.SuperNestedM1::/export.test_converter.NestedM::m2 # /data/users/jiashenc/pytorch/test/export/test_converter.py:327:16 block0(): %linear.1 : __torch__.torch.nn.modules.linear.___torch_mangle_17.Linear = prim::GetAttr[name="linear"](%m2.1), scope: export.test_converter.SuperNestedM1::/export.test_converter.NestedM::m2 %m1 : __torch__.export.test_converter.___torch_mangle_16.M = prim::GetAttr[name="m1"](%m2.1), scope: export.test_converter.SuperNestedM1::/export.test_converter.NestedM::m2 %linear.5 : __torch__.torch.nn.modules.linear.___torch_mangle_17.Linear = prim::GetAttr[name="linear"](%m1), scope: export.test_converter.SuperNestedM1::/export.test_converter.NestedM::m2 %weight.1 : Tensor = prim::GetAttr[name="weight"](%linear.5), scope: export.test_converter.SuperNestedM1::/export.test_converter.NestedM::m2 %bias.1 : Tensor = prim::GetAttr[name="bias"](%linear.5), scope: export.test_converter.SuperNestedM1::/export.test_converter.NestedM::m2 %60 : Tensor = aten::linear(%x.1, %weight.1, %bias.1), scope: export.test_converter.SuperNestedM1::/export.test_converter.NestedM::m2 # /data/users/jiashenc/pytorch/torch/nn/modules/linear.py:116:15 %weight.3 : Tensor = prim::GetAttr[name="weight"](%linear.1), scope: export.test_converter.SuperNestedM1::/export.test_converter.NestedM::m2 %bias.3 : Tensor = prim::GetAttr[name="bias"](%linear.1), scope: export.test_converter.SuperNestedM1::/export.test_converter.NestedM::m2 %63 : Tensor = aten::linear(%60, %weight.3, %bias.3), scope: export.test_converter.SuperNestedM1::/export.test_converter.NestedM::m2 # /data/users/jiashenc/pytorch/torch/nn/modules/linear.py:116:15 -> (%63) block1(): %linear.3 : __torch__.torch.nn.modules.linear.___torch_mangle_17.Linear = prim::GetAttr[name="linear"](%m2.1), scope: export.test_converter.SuperNestedM1::/export.test_converter.NestedM::m2 %m2 : __torch__.export.test_converter.___torch_mangle_16.M = prim::GetAttr[name="m2"](%m2.1), scope: export.test_converter.SuperNestedM1::/export.test_converter.NestedM::m2 %linear : __torch__.torch.nn.modules.linear.___torch_mangle_17.Linear = prim::GetAttr[name="linear"](%m2), scope: export.test_converter.SuperNestedM1::/export.test_converter.NestedM::m2 %weight.5 : Tensor = prim::GetAttr[name="weight"](%linear), scope: export.test_converter.SuperNestedM1::/export.test_converter.NestedM::m2 %bias.5 : Tensor = prim::GetAttr[name="bias"](%linear), scope: export.test_converter.SuperNestedM1::/export.test_converter.NestedM::m2 %69 : Tensor = aten::linear(%x.1, %weight.5, %bias.5), scope: export.test_converter.SuperNestedM1::/export.test_converter.NestedM::m2 # /data/users/jiashenc/pytorch/torch/nn/modules/linear.py:116:15 %weight.12 : Tensor = prim::GetAttr[name="weight"](%linear.3), scope: export.test_converter.SuperNestedM1::/export.test_converter.NestedM::m2 %bias.12 : Tensor = prim::GetAttr[name="bias"](%linear.3), scope: export.test_converter.SuperNestedM1::/export.test_converter.NestedM::m2 %72 : Tensor = aten::linear(%69, %weight.12, %bias.12), scope: export.test_converter.SuperNestedM1::/export.test_converter.NestedM::m2 # /data/users/jiashenc/pytorch/torch/nn/modules/linear.py:116:15 -> (%72) %weight : Tensor = prim::GetAttr[name="weight"](%linear.8), scope: export.test_converter.SuperNestedM1::/torch.nn.modules.linear.Linear::linear %bias : Tensor = prim::GetAttr[name="bias"](%linear.8), scope: export.test_converter.SuperNestedM1::/torch.nn.modules.linear.Linear::linear %75 : Tensor = aten::linear(%54, %weight, %bias), scope: export.test_converter.SuperNestedM1::/torch.nn.modules.linear.Linear::linear # /data/users/jiashenc/pytorch/torch/nn/modules/linear.py:116:15 -> (%75) return (%21) ``` ##### ExportedProgram ``` ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, p_linear_weight: "f32[3, 3]", p_linear_bias: "f32[3]", p_m1_linear_weight: "f32[3, 3]", p_m1_linear_bias: "f32[3]", p_m1_m1_linear_weight: "f32[3, 3]", p_m1_m1_linear_bias: "f32[3]", p_m1_m2_linear_weight: "f32[3, 3]", p_m1_m2_linear_bias: "f32[3]", p_m2_linear_weight: "f32[3, 3]", p_m2_linear_bias: "f32[3]", p_m2_m1_linear_weight: "f32[3, 3]", p_m2_m1_linear_bias: "f32[3]", p_m2_m2_linear_weight: "f32[3, 3]", p_m2_m2_linear_bias: "f32[3]", x_1: "f32[3]"): # No stacktrace found for following nodes max_1: "f32[]" = torch.ops.aten.max.default(x_1) gt: "b8[]" = torch.ops.aten.gt.Scalar(max_1, 1); max_1 = None # File: .137:23 in forward, code: cond = torch.ops.higher_order.cond(l_args_0_, cond_true_2, cond_false_2, [l_args_3_0_, l_args_3_13_, l_args_3_5_, l_args_3_12_, l_args_3_14_, l_args_3_1_, l_args_3_3_, l_args_3_4_, l_args_3_7_, l_args_3_10_, l_args_3_11_, l_args_3_2_, l_args_3_6_, l_args_3_8_, l_args_3_9_]); l_args_0_ = cond_true_2 = cond_false_2 = l_args_3_0_ = l_args_3_13_ = l_args_3_5_ = l_args_3_12_ = l_args_3_14_ = l_args_3_1_ = l_args_3_3_ = l_args_3_4_ = l_args_3_7_ = l_args_3_10_ = l_args_3_11_ = l_args_3_2_ = l_args_3_6_ = l_args_3_8_ = l_args_3_9_ = None true_graph_0 = self.true_graph_0 false_graph_0 = self.false_graph_0 conditional = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [p_linear_weight, p_linear_bias, x_1, p_m1_linear_weight, p_m1_m1_linear_bias, p_m1_linear_bias, p_m1_m2_linear_weight, p_m1_m2_linear_bias, p_m1_m1_linear_weight, p_m2_m2_linear_bias, p_m2_m1_linear_weight, p_m2_linear_weight, p_m2_m1_linear_bias, p_m2_m2_linear_weight, p_m2_linear_bias]); gt = true_graph_0 = false_graph_0 = p_linear_weight = p_linear_bias = x_1 = p_m1_linear_weight = p_m1_m1_linear_bias = p_m1_linear_bias = p_m1_m2_linear_weight = p_m1_m2_linear_bias = p_m1_m1_linear_weight = p_m2_m2_linear_bias = p_m2_m1_linear_weight = p_m2_linear_weight = p_m2_m1_linear_bias = p_m2_m2_linear_weight = p_m2_linear_bias = None getitem: "f32[3]" = conditional[0]; conditional = None return (getitem,) class (torch.nn.Module): def forward(self, p_linear_weight: "f32[3, 3]", p_linear_bias: "f32[3]", x_1: "f32[3]", p_m1_linear_weight: "f32[3, 3]", p_m1_m1_linear_bias: "f32[3]", p_m1_linear_bias: "f32[3]", p_m1_m2_linear_weight: "f32[3, 3]", p_m1_m2_linear_bias: "f32[3]", p_m1_m1_linear_weight: "f32[3, 3]", p_m2_m2_linear_bias: "f32[3]", p_m2_m1_linear_weight: "f32[3, 3]", p_m2_linear_weight: "f32[3, 3]", p_m2_m1_linear_bias: "f32[3]", p_m2_m2_linear_weight: "f32[3, 3]", p_m2_linear_bias: "f32[3]"): # File: .134:8 in forward, code: sum_default = torch.ops.aten.sum.default(l_args_3_5__1, dtype = None) sum_1: "f32[]" = torch.ops.aten.sum.default(x_1) # File: .134:9 in forward, code: gt_scalar = torch.ops.aten.gt.Scalar(sum_default, 1); sum_default = None gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 1); sum_1 = None # File: .134:12 in forward, code: cond = torch.ops.higher_order.cond(gt_scalar, cond_true_0, cond_false_0, [l_args_3_12__true_branch, l_args_3_1__true_branch, l_args_3_5__1, l_args_3_14__true_branch, l_args_3_7__true_branch, l_args_3_3__true_branch, l_args_3_4__true_branch]); gt_scalar = cond_true_0 = cond_false_0 = l_args_3_12__true_branch = l_args_3_1__true_branch = l_args_3_5__1 = l_args_3_14__true_branch = l_args_3_7__true_branch = l_args_3_3__true_branch = l_args_3_4__true_branch = None true_graph_0 = self.true_graph_0 false_graph_0 = self.false_graph_0 conditional = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [p_m1_linear_weight, p_m1_linear_bias, x_1, p_m1_m1_linear_bias, p_m1_m1_linear_weight, p_m1_m2_linear_weight, p_m1_m2_linear_bias]); gt = true_graph_0 = false_graph_0 = p_m1_linear_weight = p_m1_linear_bias = x_1 = p_m1_m1_linear_bias = p_m1_m1_linear_weight = p_m1_m2_linear_weight = p_m1_m2_linear_bias = None getitem: "f32[3]" = conditional[0]; conditional = None # File: .134:14 in forward, code: linear_default = torch.ops.aten.linear.default(getitem, l_args_3_0__1, l_args_3_13__1); getitem = l_args_3_0__1 = l_args_3_13__1 = None linear: "f32[3]" = torch.ops.aten.linear.default(getitem, p_linear_weight, p_linear_bias); getitem = p_linear_weight = p_linear_bias = None return (linear,) class (torch.nn.Module): def forward(self, p_m1_linear_weight: "f32[3, 3]", p_m1_linear_bias: "f32[3]", x_1: "f32[3]", p_m1_m1_linear_bias: "f32[3]", p_m1_m1_linear_weight: "f32[3, 3]", p_m1_m2_linear_weight: "f32[3, 3]", p_m1_m2_linear_bias: "f32[3]"): # File: .130:8 in forward, code: linear_default = torch.ops.aten.linear.default(l_args_3_5__1, l_args_3_7__true_branch, l_args_3_14__true_branch); l_args_3_5__1 = l_args_3_7__true_branch = l_args_3_14__true_branch = None linear: "f32[3]" = torch.ops.aten.linear.default(x_1, p_m1_m1_linear_weight, p_m1_m1_linear_bias); x_1 = p_m1_m1_linear_weight = p_m1_m1_linear_bias = None # File: .130:9 in forward, code: linear_default_1 = torch.ops.aten.linear.default(linear_default, l_args_3_12__1, l_args_3_1__1); linear_default = l_args_3_12__1 = l_args_3_1__1 = None linear_1: "f32[3]" = torch.ops.aten.linear.default(linear, p_m1_linear_weight, p_m1_linear_bias); linear = p_m1_linear_weight = p_m1_linear_bias = None return (linear_1,) class (torch.nn.Module): def forward(self, p_m1_linear_weight: "f32[3, 3]", p_m1_linear_bias: "f32[3]", x_1: "f32[3]", p_m1_m1_linear_bias: "f32[3]", p_m1_m1_linear_weight: "f32[3, 3]", p_m1_m2_linear_weight: "f32[3, 3]", p_m1_m2_linear_bias: "f32[3]"): # File: .131:8 in forward, code: linear_default = torch.ops.aten.linear.default(l_args_3_5__1, l_args_3_3__false_branch, l_args_3_4__false_branch); l_args_3_5__1 = l_args_3_3__false_branch = l_args_3_4__false_branch = None linear: "f32[3]" = torch.ops.aten.linear.default(x_1, p_m1_m2_linear_weight, p_m1_m2_linear_bias); x_1 = p_m1_m2_linear_weight = p_m1_m2_linear_bias = None # File: .131:9 in forward, code: linear_default_1 = torch.ops.aten.linear.default(linear_default, l_args_3_12__1, l_args_3_1__1); linear_default = l_args_3_12__1 = l_args_3_1__1 = None linear_1: "f32[3]" = torch.ops.aten.linear.default(linear, p_m1_linear_weight, p_m1_linear_bias); linear = p_m1_linear_weight = p_m1_linear_bias = None return (linear_1,) class (torch.nn.Module): def forward(self, p_linear_weight: "f32[3, 3]", p_linear_bias: "f32[3]", x_1: "f32[3]", p_m1_linear_weight: "f32[3, 3]", p_m1_m1_linear_bias: "f32[3]", p_m1_linear_bias: "f32[3]", p_m1_m2_linear_weight: "f32[3, 3]", p_m1_m2_linear_bias: "f32[3]", p_m1_m1_linear_weight: "f32[3, 3]", p_m2_m2_linear_bias: "f32[3]", p_m2_m1_linear_weight: "f32[3, 3]", p_m2_linear_weight: "f32[3, 3]", p_m2_m1_linear_bias: "f32[3]", p_m2_m2_linear_weight: "f32[3, 3]", p_m2_linear_bias: "f32[3]"): # File: .135:8 in forward, code: sum_default = torch.ops.aten.sum.default(l_args_3_5__1, dtype = None) sum_1: "f32[]" = torch.ops.aten.sum.default(x_1) # File: .135:9 in forward, code: gt_scalar = torch.ops.aten.gt.Scalar(sum_default, 1); sum_default = None gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 1); sum_1 = None # File: .135:12 in forward, code: cond = torch.ops.higher_order.cond(gt_scalar, cond_true_1, cond_false_1, [l_args_3_2__false_branch, l_args_3_5__1, l_args_3_9__false_branch, l_args_3_11__false_branch, l_args_3_6__false_branch, l_args_3_10__false_branch, l_args_3_8__false_branch]); gt_scalar = cond_true_1 = cond_false_1 = l_args_3_2__false_branch = l_args_3_5__1 = l_args_3_9__false_branch = l_args_3_11__false_branch = l_args_3_6__false_branch = l_args_3_10__false_branch = l_args_3_8__false_branch = None true_graph_0 = self.true_graph_0 false_graph_0 = self.false_graph_0 conditional = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [p_m2_linear_weight, x_1, p_m2_linear_bias, p_m2_m1_linear_weight, p_m2_m1_linear_bias, p_m2_m2_linear_bias, p_m2_m2_linear_weight]); gt = true_graph_0 = false_graph_0 = p_m2_linear_weight = x_1 = p_m2_linear_bias = p_m2_m1_linear_weight = p_m2_m1_linear_bias = p_m2_m2_linear_bias = p_m2_m2_linear_weight = None getitem: "f32[3]" = conditional[0]; conditional = None # File: .135:14 in forward, code: linear_default = torch.ops.aten.linear.default(getitem, l_args_3_0__1, l_args_3_13__1); getitem = l_args_3_0__1 = l_args_3_13__1 = None linear: "f32[3]" = torch.ops.aten.linear.default(getitem, p_linear_weight, p_linear_bias); getitem = p_linear_weight = p_linear_bias = None return (linear,) class (torch.nn.Module): def forward(self, p_m2_linear_weight: "f32[3, 3]", x_1: "f32[3]", p_m2_linear_bias: "f32[3]", p_m2_m1_linear_weight: "f32[3, 3]", p_m2_m1_linear_bias: "f32[3]", p_m2_m2_linear_bias: "f32[3]", p_m2_m2_linear_weight: "f32[3, 3]"): # File: .132:8 in forward, code: linear_default = torch.ops.aten.linear.default(l_args_3_5__1, l_args_3_11__true_branch, l_args_3_6__true_branch); l_args_3_5__1 = l_args_3_11__true_branch = l_args_3_6__true_branch = None linear: "f32[3]" = torch.ops.aten.linear.default(x_1, p_m2_m1_linear_weight, p_m2_m1_linear_bias); x_1 = p_m2_m1_linear_weight = p_m2_m1_linear_bias = None # File: .132:9 in forward, code: linear_default_1 = torch.ops.aten.linear.default(linear_default, l_args_3_2__1, l_args_3_9__1); linear_default = l_args_3_2__1 = l_args_3_9__1 = None linear_1: "f32[3]" = torch.ops.aten.linear.default(linear, p_m2_linear_weight, p_m2_linear_bias); linear = p_m2_linear_weight = p_m2_linear_bias = None return (linear_1,) class (torch.nn.Module): def forward(self, p_m2_linear_weight: "f32[3, 3]", x_1: "f32[3]", p_m2_linear_bias: "f32[3]", p_m2_m1_linear_weight: "f32[3, 3]", p_m2_m1_linear_bias: "f32[3]", p_m2_m2_linear_bias: "f32[3]", p_m2_m2_linear_weight: "f32[3, 3]"): # File: .133:8 in forward, code: linear_default = torch.ops.aten.linear.default(l_args_3_5__1, l_args_3_8__false_branch, l_args_3_10__false_branch); l_args_3_5__1 = l_args_3_8__false_branch = l_args_3_10__false_branch = None linear: "f32[3]" = torch.ops.aten.linear.default(x_1, p_m2_m2_linear_weight, p_m2_m2_linear_bias); x_1 = p_m2_m2_linear_weight = p_m2_m2_linear_bias = None # File: .133:9 in forward, code: linear_default_1 = torch.ops.aten.linear.default(linear_default, l_args_3_2__1, l_args_3_9__1); linear_default = l_args_3_2__1 = l_args_3_9__1 = None linear_1: "f32[3]" = torch.ops.aten.linear.default(linear, p_m2_linear_weight, p_m2_linear_bias); linear = p_m2_linear_weight = p_m2_linear_bias = None return (linear_1,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=, arg=TensorArgument(name='p_linear_weight'), target='linear.weight', persistent=None), InputSpec(kind=, arg=TensorArgument(name='p_linear_bias'), target='linear.bias', persistent=None), InputSpec(kind=, arg=TensorArgument(name='p_m1_linear_weight'), target='m1.linear.weight', persistent=None), InputSpec(kind=, arg=TensorArgument(name='p_m1_linear_bias'), target='m1.linear.bias', persistent=None), InputSpec(kind=, arg=TensorArgument(name='p_m1_m1_linear_weight'), target='m1.m1.linear.weight', persistent=None), InputSpec(kind=, arg=TensorArgument(name='p_m1_m1_linear_bias'), target='m1.m1.linear.bias', persistent=None), InputSpec(kind=, arg=TensorArgument(name='p_m1_m2_linear_weight'), target='m1.m2.linear.weight', persistent=None), InputSpec(kind=, arg=TensorArgument(name='p_m1_m2_linear_bias'), target='m1.m2.linear.bias', persistent=None), InputSpec(kind=, arg=TensorArgument(name='p_m2_linear_weight'), target='m2.linear.weight', persistent=None), InputSpec(kind=, arg=TensorArgument(name='p_m2_linear_bias'), target='m2.linear.bias', persistent=None), InputSpec(kind=, arg=TensorArgument(name='p_m2_m1_linear_weight'), target='m2.m1.linear.weight', persistent=None), InputSpec(kind=, arg=TensorArgument(name='p_m2_m1_linear_bias'), target='m2.m1.linear.bias', persistent=None), InputSpec(kind=, arg=TensorArgument(name='p_m2_m2_linear_weight'), target='m2.m2.linear.weight', persistent=None), InputSpec(kind=, arg=TensorArgument(name='p_m2_m2_linear_bias'), target='m2.m2.linear.bias', persistent=None), InputSpec(kind=, arg=TensorArgument(name='x_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=, arg=TensorArgument(name='getitem'), target=None)]) Range constraints: {} ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/127975 Approved by: https://github.com/angelayi, https://github.com/ydwu4 --- test/export/test_converter.py | 251 +++++++++++++++++++++++++++++++++- torch/_export/converter.py | 171 ++++++++++++++++++++--- 2 files changed, 399 insertions(+), 23 deletions(-) diff --git a/test/export/test_converter.py b/test/export/test_converter.py index 90e92f183746..362e0a6b2ba3 100644 --- a/test/export/test_converter.py +++ b/test/export/test_converter.py @@ -21,6 +21,15 @@ def _check_equal_ts_ep_converter(self, mod, inp) -> ExportedProgram: ep = TS2EPConverter(ts_model, inp).convert() ep_out, _ = pytree.tree_flatten(ep.module()(*inp)) orig_out, _ = pytree.tree_flatten(mod(*inp)) + + # Check module. + if isinstance(mod, torch.nn.Module): + self.assertEqual( + ep.module().state_dict().keys(), + mod.state_dict().keys(), + ) + + # Check results. self.assertEqual(len(ep_out), len(orig_out)) for ep_t, orig_t in zip(ep_out, orig_out): if isinstance(ep_t, torch.Tensor): @@ -259,11 +268,226 @@ def forward(self, x_tuple: Tuple[torch.Tensor, torch.Tensor]): x = x.cos() return x + y - inp = torch.ones(1, 4) + inp = (torch.ones(4),) self._check_equal_ts_ep_converter(MUnpackList(), inp) inp = ((torch.zeros(1, 4), torch.ones(1, 4)),) self._check_equal_ts_ep_converter(MUnpackTuple(), inp) + def test_convert_nn_module_with_nested_param(self): + class M(torch.nn.Module): + def __init__(self, dim: int) -> None: + super().__init__() + self.linear = torch.nn.Linear(dim, dim) + + def forward(self, x: torch.Tensor): + return self.linear(x) + + class NestedM(torch.nn.Module): + def __init__(self, dim: int) -> None: + super().__init__() + self.linear = torch.nn.Linear(dim, dim) + self.m = M(dim) + + def forward(self, x: torch.Tensor): + return self.linear(self.m(x)) + + class SuperNestedM(torch.nn.Module): + def __init__(self, dim: int) -> None: + super().__init__() + self.linear = torch.nn.Linear(dim, dim) + self.m = NestedM(dim) + + def forward(self, x: torch.Tensor): + return self.linear(self.m(x)) + + inp = (torch.ones(3),) + orig_m = NestedM(3) + ep = self._check_equal_ts_ep_converter(orig_m, inp) + orig_m = SuperNestedM(3) + ep = self._check_equal_ts_ep_converter(orig_m, inp) + + def test_convert_nn_module_with_nested_buffer(self): + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.register_buffer("w", torch.randn(1)) + + def forward(self, x: torch.Tensor): + return self.w + x + + class NestedM(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.m = M() + self.register_buffer("w", torch.randn(1)) + + def forward(self, x: torch.Tensor): + return self.w + self.m(x) + + class SuperNestedM(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.m = NestedM() + self.register_buffer("w", torch.randn(1)) + + def forward(self, x: torch.Tensor): + return self.w + self.m(x) + + inp = (torch.ones(1),) + orig_m = NestedM() + ep = self._check_equal_ts_ep_converter(orig_m, inp) + orig_m = SuperNestedM() + ep = self._check_equal_ts_ep_converter(orig_m, inp) + + def test_convert_nn_module_with_nested_if_and_buffer(self): + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.register_buffer("w", torch.randn(1)) + + def forward(self, x: torch.Tensor): + return self.w + x + + class NestedM(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.m1 = M() + self.m2 = M() + self.register_buffer("w", torch.randn(1)) + + def forward(self, x: torch.Tensor): + if torch.sum(x) > 1: + return self.w + self.m1(x) + else: + return self.w + self.m2(x) + + # Super nested, parameters neeed to lifted + # multiple times. + class SuperNestedM(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.m1 = NestedM() + self.m2 = NestedM() + self.register_buffer("w", torch.randn(1)) + + def forward(self, x: torch.Tensor): + if torch.max(x) > 1: + return self.w + self.m1(x) + else: + return self.w + self.m2(x) + + # Super nested module testing. + inp = (torch.ones(1),) + orig_m = SuperNestedM() + ep = self._check_equal_ts_ep_converter(orig_m, inp) + + t = inp[0] + t -= 1 + torch.testing.assert_close( + ep.module()(*inp), + orig_m(*inp), + ) + + def test_convert_nn_module_with_nested_if_and_param(self): + class M(torch.nn.Module): + def __init__(self, dim: int) -> None: + super().__init__() + self.linear = torch.nn.Linear(dim, dim) + + def forward(self, x: torch.Tensor): + return self.linear(x) + + class NestedM(torch.nn.Module): + def __init__(self, dim: int) -> None: + super().__init__() + self.m1 = M(dim) + self.m2 = M(dim) + self.linear = torch.nn.Linear(dim, dim) + + def forward(self, x: torch.Tensor): + if torch.sum(x) > 1: + return self.linear(self.m1(x)) + else: + return self.linear(self.m2(x)) + + # Super nested, parameters neeed to lifted + # multiple times. + class SuperNestedM1(torch.nn.Module): + def __init__(self, dim: int) -> None: + super().__init__() + self.m1 = NestedM(dim) + self.m2 = NestedM(dim) + self.linear = torch.nn.Linear(dim, dim) + + def forward(self, x: torch.Tensor): + if torch.max(x) > 1: + return self.linear(self.m1(x)) + else: + return self.linear(self.m2(x)) + + # Super nested, even the input needs to be + # lifted recursively due to value propogation optimiztaion. + class SuperNestedM2(torch.nn.Module): + def __init__(self, dim: int) -> None: + super().__init__() + self.m1 = NestedM(dim) + self.m2 = NestedM(dim) + self.linear = torch.nn.Linear(dim, dim) + + def forward(self, x: torch.Tensor): + if torch.sum(x) > 1: + return self.linear(self.m1(x)) + else: + return self.linear(self.m2(x)) + + # Basic module testing. + inp = (torch.ones(3),) + orig_m = M(3) + ep = self._check_equal_ts_ep_converter(orig_m, inp) + + t = inp[0] + t -= 0.8 + torch.testing.assert_close( + ep.module()(*inp), + orig_m(*inp), + ) + + # Nested module testing. + inp = (torch.ones(3),) + orig_m = NestedM(3) + ep = self._check_equal_ts_ep_converter(orig_m, inp) + + t = inp[0] + t -= 0.8 + torch.testing.assert_close( + ep.module()(*inp), + orig_m(*inp), + ) + + # Super nested module testing. + inp = (torch.ones(3),) + orig_m = SuperNestedM1(3) + ep = self._check_equal_ts_ep_converter(orig_m, inp) + + t = inp[0] + t -= 0.8 + torch.testing.assert_close( + ep.module()(*inp), + orig_m(*inp), + ) + + # # Super nested module testing. + # inp = (torch.ones(3),) + # orig_m = SuperNestedM2(3) + # ep = self._check_equal_ts_ep_converter(orig_m, inp) + + # t = inp[0] + # t -= 0.8 + # torch.testing.assert_close( + # ep.module()(*inp), + # orig_m(*inp), + # ) + def test_ts2ep_converter_contains(self): class MIn(torch.nn.Module): def forward(self, x: torch.Tensor): @@ -322,6 +546,31 @@ def forward(self, x): m = M() self._check_equal_ts_ep_converter(m, inp) + def test_convert_func_without_param(self): + def func1(x, y): + return x + y + + def func2(x, y): + if x.sum() > 0: + return x + y + else: + return x - y + + inp = ( + torch.tensor(1), + torch.tensor(1), + ) + self._check_equal_ts_ep_converter(func1, inp) + + ep = self._check_equal_ts_ep_converter(func2, inp) + + t = inp[0] + t -= 1 + torch.testing.assert_close( + ep.module()(*inp), + func2(*inp), + ) + if __name__ == "__main__": run_tests() diff --git a/torch/_export/converter.py b/torch/_export/converter.py index 20b6101948de..e2b108a658e0 100644 --- a/torch/_export/converter.py +++ b/torch/_export/converter.py @@ -48,6 +48,12 @@ def ir_name_to_func_name(name: str) -> str: return "convert_" + "_".join(name_list) +def get_node_for_param_and_buffer(fx_graph, name, is_top_level_graph): + if is_top_level_graph: + return fx_graph.get_attr(name) + return fx_graph.placeholder(name) + + # Those operators will be automatically populated to a instance method # of TS2FXGraphConverter with name convert__(). # Please check __init__ for method population implementations. @@ -60,6 +66,97 @@ def ir_name_to_func_name(name: str) -> str: } +def get_ir_value_parent_name_and_attr_name(node): + irv_parent_name, irv_name = node.input().debugName(), node.output().debugName() + attr_name = node.s("name") + return irv_name, irv_parent_name, attr_name + + +def construct_fqn(ir, ref_map, name_map): + name_list = [] + while ir in ref_map: + name_list.append(name_map[ir]) + ir = ref_map[ir] + return ".".join(reversed(name_list)) + + +def get_block_to_lifted_attrs(graph: torch._C.Graph) -> Dict[torch._C.Block, Set[str]]: + """ + Perform two passes to get a mapping of blocks to a set of FQNs of its lifted attributes. + When a graph has control flow, the graph will be divided into multiple blocks. We want to convert + each block to a graph which will be passed into torch.cond. A restriction for torch.cond is that model + parameters/buffers are expected to be lifted as inputs to the subgraphs. Before converting the model, + we will run this pass which will: + 1. Figure out which params/buffers are used within blocks through tracing the GetAttr calls. + 2. Process the graph bottom up to find the lifted attributes of each block by taking the union + of the attributes used in the current block, and the lifted attributes of all its child blocks. + + Returns: + A mapping of blocks to a set of FQNs of its lifted attributes. + """ + + # A map from a block to its expected to be lifted arguments. + blocks_to_lifted_attrs: Dict[torch._C.Block, Set[str]] = dict() + + # Reference map stores the input (i.e., src) and output (i.e., dest) IR of a + # GetAttr node. By traversing this reference map, we can figure out the + # full IR aliasing pass and figure out the FQN of an attribute. + # E.g., %2 = GetAttr(linear)[%1] --> node_to_parent_map["%2"] = "%1" + node_to_parent_map: Dict[str, str] = dict() + + # Used for reconstructing the FQN of an attribute based on the reference map. + # In nutshell, for each GetAttr call, GetAttr(input IR, attribute name) -> output IR + # This name map stores which attribute name is called for a src IR --> dest IR action. + # E.g., %2 = GetAttr(linear)[%1] --> node_to_attr_name["%2"] = "linear" + node_to_attr_name: Dict[str, str] = dict() + + def _dfs_get_attr_dependency(entry): + """ + First DFS path to construct reference map and name map. + """ + for node in entry.nodes(): + if node.kind() == "prim::GetAttr": + ( + irv_name, + irv_parent_name, + attr_name, + ) = get_ir_value_parent_name_and_attr_name(node) + node_to_parent_map[irv_name] = irv_parent_name + node_to_attr_name[irv_name] = attr_name + for block in node.blocks(): + _dfs_get_attr_dependency(block) + + def _map_blocks_to_lifted_attrs(entry): + """ + Walk the graph in a bottom-up fashion to build the expected to be + lifted arguments for each block. + """ + arguments: Set[str] = set() + for node in entry.nodes(): + for block in node.blocks(): + # Recursively build. + arguments = arguments.union(_map_blocks_to_lifted_attrs(block)) + if node.kind() == "prim::GetAttr": + irv_name = node.output().debugName() + # Skip for intermediate GetAttr, which will anyway not result a FQN. + # E.g., node_to_parent_name: {"%3": "%2", "%2": "%1"} + # node_to_attr_name: {"%3": "weight", "%2": "linear", "%1": "self"} + # There is only one FQN %3-->%2-->%1: self.linear.weight + # %2-->%1 is not a FQN: self.linear + if irv_name not in set(node_to_parent_map.values()): + arguments.add( + construct_fqn(irv_name, node_to_parent_map, node_to_attr_name) + ) + if not isinstance(entry, torch._C.Graph): # Skip the top level. + blocks_to_lifted_attrs[entry] = arguments + return arguments + + _dfs_get_attr_dependency(graph) + _map_blocks_to_lifted_attrs(graph) + + return blocks_to_lifted_attrs + + def get_op_overload(node: torch._C.Node): schema_str = node.schema() schema = FunctionSchema.parse(schema_str) @@ -85,12 +182,13 @@ class TS2FXGraphConverter: def __init__( self, ts_graph: Union[torch._C.Graph, torch._C.Block], - param_names: Set[str], - buffer_names: Set[str], + name_to_param_map: Dict[str, torch.Tensor], + name_to_buffer_map: Dict[str, torch.Tensor], + blocks_to_lifted_attrs: Dict[torch._C.Block, Set[str]], ): self.ts_graph = ts_graph - self.param_names = param_names - self.buffer_names = buffer_names + self.name_to_param_map = name_to_param_map + self.name_to_buffer_map = name_to_buffer_map self.fx_graph: torch.fx.Graph = torch.fx.Graph() self.input_specs: List[InputSpec] = [] @@ -105,6 +203,8 @@ def __init__( self.subgraphs: Dict[str, torch.fx.GraphModule] = {} + self.blocks_to_lifted_attrs = blocks_to_lifted_attrs + # Populate methods for the standard operators. for k in kind_to_standard_operators.keys(): handler_func_name = ir_name_to_func_name(k) @@ -116,6 +216,9 @@ def __init__( lambda node: self._convert_standard_operators(node), ) + def is_top_level_graph(self): + return isinstance(self.ts_graph, torch._C.Graph) + def add_subgraph(self, subgraph) -> str: name = f"subgraph_{len(self.subgraphs)}" self.subgraphs[name] = subgraph @@ -157,7 +260,11 @@ def convert(self) -> torch.fx.GraphModule: self.convert_graph_outputs() - gm = torch.fx.GraphModule(self.subgraphs, self.fx_graph) + # Pass parameter and buffer to the root for lookup. + gm = torch.fx.GraphModule( + {**self.subgraphs, **self.name_to_param_map, **self.name_to_buffer_map}, + self.fx_graph, + ) inplace_optimize_sym_size_div(gm) @@ -170,14 +277,7 @@ def convert_graph_inputs(self): name = graph_input.debugName() normalized_name = normalize_name(name) - fx_node = self.fx_graph.placeholder(normalized_name) - - # fx_node.meta["val"] = FakeTensor() - # TODO: set fx_node.meta["val"] - - self.name_to_node[name] = fx_node - - if name in self.param_names: + if name in self.name_to_param_map: self.input_specs.append( InputSpec( InputKind.PARAMETER, @@ -185,7 +285,10 @@ def convert_graph_inputs(self): target=name, ) ) - elif name in self.buffer_names: + fx_node = get_node_for_param_and_buffer( + self.fx_graph, name, self.is_top_level_graph() + ) + elif name in self.name_to_buffer_map: self.input_specs.append( InputSpec( InputKind.BUFFER, @@ -194,6 +297,9 @@ def convert_graph_inputs(self): persistent=True, ) ) + fx_node = get_node_for_param_and_buffer( + self.fx_graph, name, self.is_top_level_graph() + ) else: self.input_specs.append( InputSpec( @@ -202,6 +308,9 @@ def convert_graph_inputs(self): target=name, ) ) + fx_node = self.fx_graph.placeholder(normalized_name) + + self.name_to_node[name] = fx_node def convert_prim_Constant(self, node: torch._C.Node): name = node.output().debugName() @@ -439,13 +548,20 @@ def convert_prim_If(self, node: torch._C.Node): arguments.update(block_args) + # Lift parameters as inputs. + for block in node.blocks(): + arguments = arguments.union(self.blocks_to_lifted_attrs[block]) + arguments = list(arguments) # Convert blocks to subgraphs subgraph_nodes = [] for block in node.blocks(): - subgraph_converter = TS2FXGraphConverter(block, set(), set()) + subgraph_converter = TS2FXGraphConverter( + block, dict(), dict(), self.blocks_to_lifted_attrs + ) subgraph_converter.constant_map = self.constant_map + subgraph_converter.attribute_map = self.attribute_map for block_arg in arguments: normalized_block_arg_name = normalize_name(block_arg) @@ -555,7 +671,7 @@ class TS2EPConverter: # TorchScript model to ExportedProgram converter def __init__( self, - ts_model, + ts_model: Union[torch.jit.ScriptModule, torch.jit.ScriptFunction], sample_args: Tuple[Any, ...], sample_kwargs: Optional[Dict[str, Any]] = None, ): @@ -565,12 +681,25 @@ def __init__( self.sample_args = sample_args self.sample_kwargs = sample_kwargs - self.param_names: Set[str] = {name for name, _ in ts_model.named_parameters()} - self.buffer_names: Set[str] = {name for name, _ in ts_model.named_buffers()} + self.name_to_param_map: Dict[str, torch.Tensor] = ( + dict(ts_model.named_parameters()) + if isinstance(ts_model, torch.jit.ScriptModule) + else dict() + ) + self.name_to_buffer_map: Dict[str, torch.Tensor] = ( + dict(ts_model.named_buffers()) + if isinstance(ts_model, torch.jit.ScriptModule) + else dict() + ) def convert(self) -> ExportedProgram: + blocks_to_lifted_attrs = get_block_to_lifted_attrs(self.ts_graph) + graph_converter = TS2FXGraphConverter( - self.ts_graph, self.param_names, self.buffer_names + self.ts_graph, + self.name_to_param_map, + self.name_to_buffer_map, + blocks_to_lifted_attrs, ) gm = graph_converter.convert() ep = self.retrace_as_exported_program(gm, graph_converter.tensor_constants) @@ -578,11 +707,9 @@ def convert(self) -> ExportedProgram: def retrace_as_exported_program(self, gm: torch.fx.GraphModule, tensor_constants): # TODO: adjust input orders to match GraphSignature convention - inputs = [*self.sample_args, *self.params, *tensor_constants.values()] - ep = torch.export._trace._export( gm, - tuple(inputs), + self.sample_args, strict=False, pre_dispatch=True, ) From 2126ae186e30b6820151baeba051308be8c44659 Mon Sep 17 00:00:00 2001 From: cyy Date: Mon, 10 Jun 2024 23:40:18 +0000 Subject: [PATCH 603/706] Remove caffe2/perfkernels files (#128186) These files are not used. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128186 Approved by: https://github.com/ezyang, https://github.com/r-barnes --- BUILD.bazel | 7 - caffe2/perfkernels/adagrad.cc | 186 -------- caffe2/perfkernels/adagrad.h | 205 --------- caffe2/perfkernels/adagrad_avx2.cc | 125 ------ caffe2/perfkernels/adagrad_avx512.cc | 45 -- caffe2/perfkernels/batch_box_cox.cc | 113 ----- caffe2/perfkernels/batch_box_cox.h | 35 -- caffe2/perfkernels/batch_box_cox_avx2.cc | 399 ------------------ caffe2/perfkernels/cvtsh_ss_bugfix.h | 75 ---- .../fused_8bit_rowwise_embedding_lookup.cc | 211 --------- .../fused_8bit_rowwise_embedding_lookup.h | 55 --- ...fused_8bit_rowwise_embedding_lookup_idx.cc | 213 ---------- .../fused_8bit_rowwise_embedding_lookup_idx.h | 57 --- .../fused_nbit_rowwise_conversion.cc | 214 ---------- .../fused_nbit_rowwise_conversion.h | 39 -- caffe2/perfkernels/lstm_unit_cpu-impl.h | 141 ------- caffe2/perfkernels/lstm_unit_cpu.h | 73 ---- caffe2/perfkernels/lstm_unit_cpu_avx2.cc | 123 ------ caffe2/perfkernels/lstm_unit_cpu_common.cc | 125 ------ caffe2/perfkernels/lstm_unit_cpu_common.h | 71 ---- caffe2/perfkernels/math.h | 35 -- caffe2/perfkernels/math_cpu_avx2.cc | 246 ----------- caffe2/perfkernels/math_cpu_base.cc | 168 -------- caffe2/perfkernels/typed_axpy.cc | 88 ---- caffe2/perfkernels/typed_axpy.h | 12 - caffe2/perfkernels/typed_axpy_avx.cc | 68 --- caffe2/perfkernels/typed_axpy_avx2.cc | 104 ----- caffe2/perfkernels/vectorizer.h | 28 -- 28 files changed, 3261 deletions(-) delete mode 100644 caffe2/perfkernels/adagrad.cc delete mode 100644 caffe2/perfkernels/adagrad.h delete mode 100644 caffe2/perfkernels/adagrad_avx2.cc delete mode 100644 caffe2/perfkernels/adagrad_avx512.cc delete mode 100644 caffe2/perfkernels/batch_box_cox.cc delete mode 100644 caffe2/perfkernels/batch_box_cox.h delete mode 100644 caffe2/perfkernels/batch_box_cox_avx2.cc delete mode 100644 caffe2/perfkernels/cvtsh_ss_bugfix.h delete mode 100644 caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.cc delete mode 100644 caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.h delete mode 100644 caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.cc delete mode 100644 caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.h delete mode 100644 caffe2/perfkernels/fused_nbit_rowwise_conversion.cc delete mode 100644 caffe2/perfkernels/fused_nbit_rowwise_conversion.h delete mode 100644 caffe2/perfkernels/lstm_unit_cpu-impl.h delete mode 100644 caffe2/perfkernels/lstm_unit_cpu.h delete mode 100644 caffe2/perfkernels/lstm_unit_cpu_avx2.cc delete mode 100644 caffe2/perfkernels/lstm_unit_cpu_common.cc delete mode 100644 caffe2/perfkernels/lstm_unit_cpu_common.h delete mode 100644 caffe2/perfkernels/math.h delete mode 100644 caffe2/perfkernels/math_cpu_avx2.cc delete mode 100644 caffe2/perfkernels/math_cpu_base.cc delete mode 100644 caffe2/perfkernels/typed_axpy.cc delete mode 100644 caffe2/perfkernels/typed_axpy.h delete mode 100644 caffe2/perfkernels/typed_axpy_avx.cc delete mode 100644 caffe2/perfkernels/typed_axpy_avx2.cc delete mode 100644 caffe2/perfkernels/vectorizer.h diff --git a/BUILD.bazel b/BUILD.bazel index b58fb57199f3..f2f3be210e93 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -461,15 +461,8 @@ filegroup( filegroup( name = "caffe2_perfkernels_srcs", srcs = [ - "caffe2/perfkernels/adagrad.cc", "caffe2/perfkernels/embedding_lookup.cc", "caffe2/perfkernels/embedding_lookup_idx.cc", - "caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.cc", - "caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.cc", - "caffe2/perfkernels/fused_nbit_rowwise_conversion.cc", - "caffe2/perfkernels/lstm_unit_cpu_common.cc", - "caffe2/perfkernels/math_cpu_base.cc", - "caffe2/perfkernels/typed_axpy.cc", ], ) diff --git a/caffe2/perfkernels/adagrad.cc b/caffe2/perfkernels/adagrad.cc deleted file mode 100644 index c589187cb2eb..000000000000 --- a/caffe2/perfkernels/adagrad.cc +++ /dev/null @@ -1,186 +0,0 @@ -#include "caffe2/perfkernels/adagrad.h" - -#include - -#include "caffe2/perfkernels/common.h" - -namespace caffe2 { - -void adagrad_update__base( - int N, - const float* w, - const float* g, - const float* h, - float* nw, - float* nh, - float epsilon, - float decay, - const float lr, - const float weight_decay = 0.f) { - internal::adagrad_update_base_inlined( - N, w, g, h, nw, nh, decay, epsilon, lr, weight_decay); -} - -void adagrad_update_prefetch__base( - int N, - const float* w, - const float* /* w_n */, // prefetch ptr - - const float* g, - - const float* h, - const float* /* h_n */, // prefetch ptr - - float* nw, - float* /* nw_n */, // prefetch ptr - - float* nh, - float* /* nh_n */, // prefetch ptr - - float epsilon, - float lr, - float weight_decay = 0.f) { - adagrad_update__base(N, w, g, h, nw, nh, epsilon, 1.0f, lr, weight_decay); -} - -void adagrad_fp16_update_prefetch__base( - int N, - const at::Half* w, - const at::Half* /* w_n */, // prefetch ptr - const float* g, - const at::Half* h, - const at::Half* /* h_n */, // prefetch ptr - at::Half* nw, - at::Half* /* nw_n */, // prefetch ptr - at::Half* nh, - at::Half* /* nh_n */, // prefetch ptr - float epsilon, - float lr, - float weight_decay = 0.f) { - internal::adagrad_update_base_inlined( - N, w, g, h, nw, nh, 1.0f, epsilon, lr, weight_decay); -} - -// version without prefetching -decltype(adagrad_update__base) adagrad_update__avx2_fma; -decltype(adagrad_update__base) adagrad_update__avx512; -void adagrad_update( - int N, - const float* w, - const float* g, - const float* h, - float* nw, - float* nh, - float epsilon, - float decay, - float lr, - float weight_decay) { - AVX512_DO(adagrad_update, N, w, g, h, nw, nh, epsilon, decay, lr, weight_decay); - AVX2_FMA_DO( - adagrad_update, N, w, g, h, nw, nh, epsilon, decay, lr, weight_decay); - BASE_DO(adagrad_update, N, w, g, h, nw, nh, epsilon, decay, lr, weight_decay); -} - -decltype(adagrad_update_prefetch__base) adagrad_update_prefetch__avx2_fma; -void adagrad_update_prefetch( - int N, - const float* w, - const float* w_n, // prefetch ptr - - const float* g, - - const float* h, - const float* h_n, // prefetch ptr - - float* nw, - float* nw_n, // prefetch ptr - - float* nh, - float* nh_n, // prefetch ptr - - float epsilon, - float lr, - float weight_decay) { - AVX2_FMA_DO( - adagrad_update_prefetch, - N, - w, - w_n, - g, - h, - h_n, - nw, - nw_n, - nh, - nh_n, - epsilon, - lr, - weight_decay); - BASE_DO( - adagrad_update_prefetch, - N, - w, - w_n, - g, - h, - h_n, - nw, - nw_n, - nh, - nh_n, - epsilon, - lr, - weight_decay); -} - -// Version with prefetching for embeddings and -// momentum using fp16 -decltype(adagrad_fp16_update_prefetch__base) - adagrad_fp16_update_prefetch__avx2_fma; -void adagrad_fp16_update_prefetch( - int N, - const at::Half* w, - const at::Half* w_n, // prefetch ptr - const float* g, - const at::Half* h, - const at::Half* h_n, // prefetch ptr - at::Half* nw, - at::Half* nw_n, // prefetch ptr - at::Half* nh, - at::Half* nh_n, // prefetch ptr - float epsilon, - float lr, - float weight_decay) { - AVX2_FMA_DO( - adagrad_fp16_update_prefetch, - N, - w, - w_n, - g, - h, - h_n, - nw, - nw_n, - nh, - nh_n, - epsilon, - lr, - weight_decay); - BASE_DO( - adagrad_fp16_update_prefetch, - N, - w, - w_n, - g, - h, - h_n, - nw, - nw_n, - nh, - nh_n, - epsilon, - lr, - weight_decay); -} - -} // namespace caffe2 diff --git a/caffe2/perfkernels/adagrad.h b/caffe2/perfkernels/adagrad.h deleted file mode 100644 index f030e3e09d60..000000000000 --- a/caffe2/perfkernels/adagrad.h +++ /dev/null @@ -1,205 +0,0 @@ -#pragma once - -#if defined(__AVX__) && !defined(__NVCC__) && \ - (defined(__x86_64__) || defined(_M_X64) || defined(__i386__)) -#define CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC -#include -#endif -#include -#include - -namespace caffe2 { - -namespace internal { - -// The following functions inside internal namespace are inlined because they -// are performance critical. - -template -static inline void adagrad_update_base_inlined( - int N, - const T* w, - const float* g, - const T* h, - T* nw, - T* nh, - float decay, - float epsilon, - float lr, - float weight_decay = 0.f) { - for (const auto i : c10::irange(N)) { - float gi = std::fma(weight_decay, w[i], g[i]); - float hi = decay * h[i] + gi * gi; - nh[i] = hi; - nw[i] = w[i] + lr * gi / (std::sqrt(hi) + epsilon); - } -} - -// version with prefetching -// TODO(msmelyan) -// Crux of the computation is computing a / (sqrt(b) + epsilon), -// where a and b are vectors and epsilon is very small (eg., 10^-5) and does not -// change. Today it's computed using two vector sqrt and vector divide simd -// instructions. It is slow. We can take advantage of existing fast vector -// VRSQRTPS instruction that computes approximate reciprocals of square roots -// of the vector. It is 6x faster than vsrt and vdiv combinations. Since the -// addition of epsilon is just done to avoid division by zero, we approximate a -// / (sqrt(b) + epsilon) by a / (sqrt(b + sqrt(epsilon)) If we do that, we can -// use VRSQRTPS instead now. VRSQRTPS is not very accurate. Specifically, for -// the test on random numbers between 0.1 and 1 the absolute error was about -// 10^-3 compared to using slower but more accurate combination of vsqrt and -// vdiv. Extend Marat's function with more NR iterations to get more accuracy -// for training -// TODO(msmelyan) -// explore streaming stores, but need to have unique indices (deduplication) -inline void adagrad_update_prefetch_inlined( - int N, - const float* w, -#ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC - const float* w_n, // prefetch ptr -#else - const float* /* unused */, -#endif - - const float* g, - - const float* h, -#ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC - const float* h_n, // prefetch ptr -#else - const float* /* unused */, -#endif - - float* nw, -#ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC - float* nw_n, // prefetch ptr -#else - float* /* unused */, -#endif - - float* nh, -#ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC - float* nh_n, // prefetch ptr -#else - float* /* unused */, -#endif - - float epsilon, - float lr, - float weight_decay = 0.f) { - auto i = 0; - -#ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC - constexpr int kSize = 8; - for (; i + kSize <= N; i += kSize) { - _mm_prefetch(reinterpret_cast(&w_n[i]), _MM_HINT_T0); - _mm_prefetch(reinterpret_cast(&h_n[i]), _MM_HINT_T0); - _mm_prefetch(reinterpret_cast(&nw_n[i]), _MM_HINT_T0); - _mm_prefetch(reinterpret_cast(&nh_n[i]), _MM_HINT_T0); - - __m256 gi = _mm256_loadu_ps(g + i); - __m256 hi = _mm256_loadu_ps(h + i); - __m256 wi = _mm256_loadu_ps(w + i); -#ifdef __FMA__ - gi = _mm256_fmadd_ps(_mm256_set1_ps(weight_decay), wi, gi); - -#else - gi = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(weight_decay), wi), gi); -#endif - - __m256 nhi = _mm256_add_ps(hi, _mm256_mul_ps(gi, gi)); - _mm256_storeu_ps(nh + i, nhi); - __m256 vtmp = _mm256_div_ps( - _mm256_mul_ps(_mm256_set1_ps(lr), gi), - _mm256_add_ps(_mm256_sqrt_ps(nhi), _mm256_set1_ps(epsilon))); - _mm256_storeu_ps(nw + i, _mm256_add_ps(wi, vtmp)); - } -#endif - - adagrad_update_base_inlined( - N - i, - w + i, - g + i, - h + i, - nw + i, - nh + i, - 1.0f, - epsilon, - lr, - weight_decay); -} - -} // namespace internal - -// version with prefetching -// TODO(msmelyan) -// Crux of the computation is computing a / (sqrt(b) + epsilon), -// where a and b are vectors and epsilon is very small (eg., 10^-5) and does not -// change. Today it's computed using two vector sqrt and vector divide simd -// instructions. It is slow. We can take advantage of existing fast vector -// VRSQRTPS instruction that computes approximate reciprocals of square roots -// of the vector. It is 6x faster than vsrt and vdiv combinations. Since the -// addition of epsilon is just done to avoid division by zero, we approximate a -// / (sqrt(b) + epsilon) by a / (sqrt(b + sqrt(epsilon)) If we do that, we can -// use VRSQRTPS instead now. VRSQRTPS is not very accurate. Specifically, for -// the test on random numbers between 0.1 and 1 the absolute error was about -// 10^-3 compared to using slower but more accurate combination of vsqrt and -// vdiv. Extend Marat's function with more NR iterations to get more accuracy -// for training -// TODO(msmelyan) -// explore streaming stores, but need to have inuque indices (deduplication) -void adagrad_update_prefetch( - int N, - const float* w, - const float* w_n, // prefetch ptr - - const float* g, - - const float* h, - const float* h_n, // prefetch ptr - - float* nw, - float* nw_n, // prefetch ptr - - float* nh, - float* nh_n, // prefetch ptr - - float epsilon, - float lr, - float weight_decay = 0.f); - -// Version with prefetching for embeddings and -// momentum using fp16 -void adagrad_fp16_update_prefetch( - int N, - const at::Half* w, - const at::Half* w_n, // prefetch ptr - const float* g, - const at::Half* h, - const at::Half* h_n, // prefetch ptr - at::Half* nw, - at::Half* nw_n, // prefetch ptr - at::Half* nh, - at::Half* nh_n, // prefetch ptr - float epsilon, - float lr, - float weight_decay = 0.f); - -// version without prefetching -void adagrad_update( - int N, - const float* w, - const float* g, - const float* h, - float* nw, - float* nh, - float epsilon, - float decay, - float lr, - float weight_decay = 0.f); - -} // namespace caffe2 - -#ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC -#undef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC -#endif diff --git a/caffe2/perfkernels/adagrad_avx2.cc b/caffe2/perfkernels/adagrad_avx2.cc deleted file mode 100644 index 08c9fd00d9a0..000000000000 --- a/caffe2/perfkernels/adagrad_avx2.cc +++ /dev/null @@ -1,125 +0,0 @@ -#include "caffe2/perfkernels/adagrad.h" -#include "caffe2/perfkernels/cvtsh_ss_bugfix.h" - -#include -#include - -namespace caffe2 { - -// version without prefetching -void adagrad_update__avx2_fma( - int N, - const float* w, - const float* g, - const float* h, - float* nw, - float* nh, - float epsilon, - float decay, - float lr, - float weight_decay = 0.f) { - constexpr int kSize = 8; - auto i = 0; - for (; i + kSize <= N; i += kSize) { - __m256 gi = _mm256_loadu_ps(g + i); - __m256 hi = _mm256_loadu_ps(h + i); - __m256 wi = _mm256_loadu_ps(w + i); - gi = _mm256_fmadd_ps(_mm256_set1_ps(weight_decay), wi, gi); - - __m256 nhi = _mm256_add_ps( - _mm256_mul_ps(_mm256_set1_ps(decay), hi), _mm256_mul_ps(gi, gi)); - _mm256_storeu_ps(nh + i, nhi); - __m256 vtmp = _mm256_div_ps( - _mm256_mul_ps(_mm256_set1_ps(lr), gi), - _mm256_add_ps(_mm256_sqrt_ps(nhi), _mm256_set1_ps(epsilon))); - _mm256_storeu_ps(nw + i, _mm256_add_ps(wi, vtmp)); - } - - for (; i < N; ++i) { - float gi = std::fma(weight_decay, w[i], g[i]); - float hi = nh[i] = decay * h[i] + gi * gi; - nw[i] = w[i] + lr * gi / (std::sqrt(hi) + epsilon); - } -} - -void adagrad_update_prefetch__avx2_fma( - int N, - const float* w, - const float* w_n, // prefetch ptr - - const float* g, - - const float* h, - const float* h_n, // prefetch ptr - - float* nw, - float* nw_n, // prefetch ptr - - float* nh, - float* nh_n, // prefetch ptr - - float epsilon, - float lr, - float weight_decay = 0.f) { - internal::adagrad_update_prefetch_inlined( - N, w, w_n, g, h, h_n, nw, nw_n, nh, nh_n, epsilon, lr, weight_decay); -} - -// Compute adagrad sparse, assumes embedding and momentum are at::Half -void adagrad_fp16_update_prefetch__avx2_fma( - int N, - const at::Half* w, - const at::Half* w_n, // prefetch ptr - const float* g, - const at::Half* h, - const at::Half* h_n, // prefetch ptr - at::Half* nw, - at::Half* nw_n, // prefetch ptr - at::Half* nh, - at::Half* nh_n, // prefetch ptr - float epsilon, - float lr, - float weight_decay = 0.f) { - constexpr int kSize = 8; - auto i = 0; - for (; i + kSize <= N; i += kSize) { - _mm_prefetch(reinterpret_cast(&w_n[i]), _MM_HINT_T0); - _mm_prefetch(reinterpret_cast(&h_n[i]), _MM_HINT_T0); - _mm_prefetch(reinterpret_cast(&nw_n[i]), _MM_HINT_T0); - _mm_prefetch(reinterpret_cast(&nh_n[i]), _MM_HINT_T0); - - // only convert momentum and embedding, gradient is fp32 - __m256 gi = _mm256_loadu_ps(g + i); - __m128i hhi = _mm_loadu_si128(reinterpret_cast(h + i)); - __m256 hi = _mm256_cvtph_ps(hhi); - __m128i whi = _mm_loadu_si128(reinterpret_cast(w + i)); - __m256 wi = _mm256_cvtph_ps(whi); - gi = _mm256_fmadd_ps(_mm256_set1_ps(weight_decay), wi, gi); - - __m256 nhi = _mm256_add_ps(hi, _mm256_mul_ps(gi, gi)); - __m128i nhhi = _mm256_cvtps_ph(nhi, 0); - _mm_storeu_si128(reinterpret_cast<__m128i*>(nh + i), nhhi); - - __m256 vtmp = _mm256_div_ps( - _mm256_mul_ps(_mm256_set1_ps(lr), gi), - _mm256_add_ps(_mm256_sqrt_ps(nhi), _mm256_set1_ps(epsilon))); - __m256 nwi = _mm256_add_ps(wi, vtmp); - __m128i nhwi = _mm256_cvtps_ph(nwi, 0); - _mm_storeu_si128(reinterpret_cast<__m128i*>(nw + i), nhwi); - } - - for (; i < N; ++i) { - float gi = std::fma( - weight_decay, - _cvtsh_ss(reinterpret_cast(w)[i]), - g[i]); - float nhi = - _cvtsh_ss(reinterpret_cast(h)[i]) + gi * gi; - reinterpret_cast(nh)[i] = _cvtss_sh(nhi, 0); - float nwi = _cvtsh_ss(reinterpret_cast(w)[i]) + - lr * gi / (std::sqrt(nhi) + epsilon); - reinterpret_cast(nw)[i] = _cvtss_sh(nwi, 0); - } -} - -} // namespace caffe2 diff --git a/caffe2/perfkernels/adagrad_avx512.cc b/caffe2/perfkernels/adagrad_avx512.cc deleted file mode 100644 index 417dd1ca8bab..000000000000 --- a/caffe2/perfkernels/adagrad_avx512.cc +++ /dev/null @@ -1,45 +0,0 @@ -#include "caffe2/perfkernels/adagrad.h" -#include "caffe2/perfkernels/cvtsh_ss_bugfix.h" - -#include -#include - -namespace caffe2 { - -// version without prefetching -void adagrad_update__avx512( - int N, - const float* w, - const float* g, - const float* h, - float* nw, - float* nh, - float epsilon, - float decay, - float lr, - float weight_decay = 0.f) { - constexpr int kSize = 16; - auto i = 0; - for (; i + kSize <= N; i += kSize) { - __m512 gi = _mm512_loadu_ps(g + i); - __m512 hi = _mm512_loadu_ps(h + i); - __m512 wi = _mm512_loadu_ps(w + i); - gi = _mm512_fmadd_ps(_mm512_set1_ps(weight_decay), wi, gi); - - __m512 nhi = _mm512_add_ps( - _mm512_mul_ps(_mm512_set1_ps(decay), hi), _mm512_mul_ps(gi, gi)); - _mm512_storeu_ps(nh + i, nhi); - __m512 vtmp = _mm512_div_ps( - _mm512_mul_ps(_mm512_set1_ps(lr), gi), - _mm512_add_ps(_mm512_sqrt_ps(nhi), _mm512_set1_ps(epsilon))); - _mm512_storeu_ps(nw + i, _mm512_add_ps(wi, vtmp)); - } - - for (; i < N; ++i) { - float gi = std::fma(weight_decay, w[i], g[i]); - float hi = nh[i] = decay * h[i] + gi * gi; - nw[i] = w[i] + lr * gi / (std::sqrt(hi) + epsilon); - } -} - -} // namespace caffe2 diff --git a/caffe2/perfkernels/batch_box_cox.cc b/caffe2/perfkernels/batch_box_cox.cc deleted file mode 100644 index 7172f4b9d8cd..000000000000 --- a/caffe2/perfkernels/batch_box_cox.cc +++ /dev/null @@ -1,113 +0,0 @@ -#include "caffe2/perfkernels/common.h" - -#include -#include -#include - -namespace caffe2 { - -namespace { -template -void BoxCoxNaive( - std::size_t N, - std::size_t D, - const T* data_ptr, - const T* __restrict lambda1_ptr, - const T* __restrict lambda2_ptr, - T* output_ptr) { - constexpr T k_eps = static_cast(1e-6); - - for (std::size_t i = 0; i < N; i++) { - for (std::size_t j = 0; j < D; j++, data_ptr++, output_ptr++) { - T lambda1_v = lambda1_ptr[j]; - T lambda2_v = lambda2_ptr[j]; - T tmp = std::max(*data_ptr + lambda2_v, k_eps); - if (lambda1_v == 0) { - *output_ptr = std::log(tmp); - } else { - T lambda_1 = 1 / lambda1_v; - T pow = std::pow(tmp, lambda1_v); - *output_ptr = lambda_1 * pow - lambda_1; - } - } - } - -} -} - -#if defined(CAFFE2_PERF_WITH_AVX2) && defined(CAFFE2_PERF_USE_MKL) -namespace details { -template -void compute_batch_box_cox__avx2_fma( - std::size_t N, - std::size_t D, - std::size_t block_size, - const T* data_ptr, - const T* __restrict lambda1_ptr, - const T* __restrict lambda2_ptr, - T* output_ptr); - -extern template -void compute_batch_box_cox__avx2_fma( - std::size_t N, - std::size_t D, - std::size_t block_size, - const float* self_data, - const float* __restrict lambda1_data, - const float* __restrict lambda2_data, - float* output_data); - -extern template -void compute_batch_box_cox__avx2_fma( - std::size_t N, - std::size_t D, - std::size_t block_size, - const double* self_data, - const double* __restrict lambda1_data, - const double* __restrict lambda2_data, - double* output_data); -} // namespace detail -#endif - -template -void compute_batch_box_cox( - std::size_t N, - std::size_t D, - std::size_t block_size, - const T* data, - const T* lambda1_data, - const T* lambda2_data, - T* output_data) { -#ifdef CAFFE2_PERF_WITH_AVX2 - AVX2_FMA_DO( - details::compute_batch_box_cox, - N, - D, - block_size, - data, - lambda1_data, - lambda2_data, - output_data); -#endif - BoxCoxNaive(N, D, data, lambda1_data, lambda2_data, output_data); -} - -template void compute_batch_box_cox( - std::size_t N, - std::size_t D, - std::size_t block_size, - const float* data, - const float* lambda1_data, - const float* lambda2_data, - float* output_data); - -template void compute_batch_box_cox( - std::size_t N, - std::size_t D, - std::size_t block_size, - const double* data, - const double* lambda1_data, - const double* lambda2_data, - double* output_data); - -} // namespace caffe2 diff --git a/caffe2/perfkernels/batch_box_cox.h b/caffe2/perfkernels/batch_box_cox.h deleted file mode 100644 index 60c973bbf8ea..000000000000 --- a/caffe2/perfkernels/batch_box_cox.h +++ /dev/null @@ -1,35 +0,0 @@ -// Impmenets BoxCox operator for CPU -#pragma once -#include - -namespace caffe2 { - -template -void compute_batch_box_cox( - std::size_t N, - std::size_t D, - std::size_t block_size, - const T* self_data, - const T* lambda1_data, - const T* lambda2_data, - T* output_data); - -extern template void compute_batch_box_cox( - std::size_t N, - std::size_t D, - std::size_t block_size, - const float* data, - const float* lambda1_data, - const float* lambda2_data, - float* output_data); - -extern template void compute_batch_box_cox( - std::size_t N, - std::size_t D, - std::size_t block_size, - const double* data, - const double* lambda1_data, - const double* lambda2_data, - double* output_data); - -} // namespace caffe2 diff --git a/caffe2/perfkernels/batch_box_cox_avx2.cc b/caffe2/perfkernels/batch_box_cox_avx2.cc deleted file mode 100644 index 6171b5bfd032..000000000000 --- a/caffe2/perfkernels/batch_box_cox_avx2.cc +++ /dev/null @@ -1,399 +0,0 @@ -#include -#ifdef CAFFE2_PERF_USE_MKL -#include -#include -#include - -#include "vectorizer.h" - -// Enable compiler vectorized version only if numerical consistency is not -// required between dev and opt versions - disabled for now -#ifndef FAST_VECTORIZED_KERNEL -#define CPU_CAPABILITY_AVX2 -#include - -namespace at::vec { - -// Implements the vectorized version of std::max() operation, -// which DOESNOT propagates NaN for second argument -template -Vectorized max(const Vectorized& a, const Vectorized& b); - -template <> -Vectorized max(const Vectorized& a, const Vectorized& b) { - // std::max(NaN, nonNan) -> NaN - return _mm256_max_pd(b, a); -} - -template <> -Vectorized max(const Vectorized& a, const Vectorized& b) { - // std::max(NaN, nonNan) -> NaN - return _mm256_max_ps(b, a); -} - -// Implements recieprocal method based on newton-rapson method -// 1. user RCP approximiation -// 2. update with RCP = RCP * (2 - X * RCP) -template -Vectorized fast_recieprocal(const Vectorized& b); -template -scalar_t fast_recieprocal(scalar_t b); - -template<> -Vectorized fast_recieprocal(const Vectorized& b) { - auto minus2 = _mm256_set1_ps(-2.f); - auto rcp = _mm256_rcp_ps(b); - rcp = _mm256_mul_ps(rcp, _mm256_fnmsub_ps(rcp, b, minus2)); - rcp = _mm256_mul_ps(rcp, _mm256_fnmsub_ps(rcp, b, minus2)); - return rcp; -} - -template <> -float fast_recieprocal(float b) { - auto minus2 = _mm_set_ss(-2.f); - auto b_reg = _mm_set_ss(b); - auto rcp = _mm_rcp_ss(b_reg); - rcp = _mm_mul_ss(rcp, _mm_fnmsub_ss(rcp, b_reg, minus2)); - rcp = _mm_mul_ss(rcp, _mm_fnmsub_ss(rcp, b_reg, minus2)); - return _mm_cvtss_f32(rcp); -} - -template<> -Vectorized fast_recieprocal(const Vectorized& b) { - return b.reciprocal(); -} - -template <> -double fast_recieprocal(double b) { - return 1./b; -} - -} -#endif - -#include -#include -#include - -#include - -namespace caffe2::details { - -// MKL VML function templates. -template -void PackV(const int N, const T* a, const int* ia, T* y); -template -void UnpackV(const int N, const T* a, T* y, const int* iy); - -#define DELEGATE_PACKV_FUNCTION(T, OriginalFunc) \ - template <> \ - void PackV(const int N, const T* a, const int* ia, T* y) { \ - OriginalFunc(N, a, ia, y); \ - } -DELEGATE_PACKV_FUNCTION(float, vsPackV) -DELEGATE_PACKV_FUNCTION(double, vdPackV) -#undef DELEGATE_PACKV_FUNCTION - -#define DELEGATE_UNPACKV_FUNCTION(T, OriginalFunc) \ - template <> \ - void UnpackV(const int N, const T* a, T* y, const int* iy) { \ - OriginalFunc(N, a, y, iy); \ - } -DELEGATE_UNPACKV_FUNCTION(float, vsUnpackV) -DELEGATE_UNPACKV_FUNCTION(double, vdUnpackV) -#undef DELEGATE_UNPACKV_FUNCTION - -#ifndef FAST_VECTORIZED_KERNEL -template -void box_cox_zero_lambda( - size_t D, - const T* const self_data, - const T* const lambda2_data, - T k_eps, - T* const output_data) { - int j = 0; - using Vec = at::vec::Vectorized; - constexpr int64_t VLEN = Vec::size(); - auto k_eps_vec = Vec(k_eps); - for(; j + VLEN < D; j += VLEN) { - auto data = Vec::loadu(self_data + j); - auto lambda2 = Vec::loadu(lambda2_data + j); - auto sum = data + lambda2; - auto max = at::vec::max(sum, k_eps_vec); - auto res = max.log(); - res.store(output_data + j); - } - for ( ;j < D; ++j) { - auto sum = self_data[j] + lambda2_data[j]; - auto max = std::max(sum, k_eps); - output_data[j] = std::log(max); - } -} - -template -void box_cox_nonzero_lambda( - int64_t D, - const T* data_ptr, - const T* lambda1_ptr, - const T* lambda2_ptr, - T k_eps, - T* out) { - - int j = 0; - using Vec = at::vec::Vectorized; - constexpr int64_t VLEN = Vec::size(); - auto k_eps_vec = Vec(k_eps); - for(; j + VLEN < D; j += VLEN) { - auto data = Vec::loadu(data_ptr + j); - auto lambda2 = Vec::loadu(lambda2_ptr + j); - auto sum = data + lambda2; - auto max = at::vec::max(sum, k_eps_vec); - auto lambda1 = Vec::loadu(lambda1_ptr + j); - auto lambda_over_1 = at::vec::fast_recieprocal(lambda1); - auto pow = max.pow(lambda1); - auto res = at::vec::fmsub(pow, lambda_over_1, lambda_over_1); - res.store(out + j); - } - for ( ;j < D; ++j) { - auto sum = data_ptr[j] + lambda2_ptr[j]; - auto max = std::max(sum, k_eps); - auto lambda_over_1 = at::vec::fast_recieprocal(lambda1_ptr[j]); - auto pow = std::pow(max, lambda1_ptr[j]); - out[j] = pow * lambda_over_1 - lambda_over_1; - } -} -#else -template -void box_cox_zero_lambda( - size_t D, - const T* const self_data, - const T* const lambda2_data, - T k_eps, - T* const output_data) { - VECTOR_LOOP for (auto j=0 ;j < D; ++j) { - auto sum = self_data[j] + lambda2_data[j]; - auto max = std::max(sum, k_eps); - output_data[j] = std::log(max); - } -} - -template -void box_cox_nonzero_lambda( - int64_t D, - const T* data_ptr, - const T* lambda1_ptr, - const T* lambda2_ptr, - T k_eps, - T* out) { - - VECTOR_LOOP for (auto j=0 ;j < D; ++j) { - FAST_MATH - auto sum = data_ptr[j] + lambda2_ptr[j]; - auto max = std::max(sum, k_eps); - auto lamda1 = lambda1_ptr[j]; - auto lambda_over_1 = 1 / lamda1; - if constexpr (std::is_same::value) { - lambda_over_1 = lambda_over_1 * (T{2} - lambda_over_1 * lamda1); - lambda_over_1 = lambda_over_1 * (T{2} - lambda_over_1 * lamda1); - } - auto pow = std::pow(max, lamda1); - out[j] = pow * lambda_over_1 - lambda_over_1; - } -} -#endif - -template -void box_cox_mixed_lambda( - const T* const self_data, - const std::vector& nonzeros, - const std::vector& zeros, - const T* const lambda1, - const T* const lambda2, - const T* const lambda2_z_, - T k_eps, - T* const buffer, - T* const output_data) { - PackV(nonzeros.size(), self_data, nonzeros.data(), buffer); - box_cox_nonzero_lambda( - nonzeros.size(), buffer, lambda1, lambda2, k_eps, buffer); - UnpackV(nonzeros.size(), buffer, output_data, nonzeros.data()); - - PackV(zeros.size(), self_data, zeros.data(), buffer); - box_cox_zero_lambda( - zeros.size(), buffer, lambda2_z_, k_eps, buffer); - UnpackV(zeros.size(), buffer, output_data, zeros.data()); -} - -template -void TileArrayIntoVector( - const T* const a, - const size_t D, - const int K, - std::vector& b) { - b.resize(K * D); - for (const auto k : c10::irange(K)) { - std::copy(a, a + D, b.begin() + k * D); - } -} - -void TileIndicesInPlace(std::vector& v, const std::size_t D, const std::size_t K) { - auto n = v.size(); - v.resize(K * n); - for (const auto k : c10::irange(1, K)) { - for (const auto j : c10::irange(n)) { - v[k * n + j] = v[j] + k * D; - } - } -} - -template -void compute_batch_box_cox__avx2_fma( - std::size_t N, - std::size_t D, - std::size_t block_size, - const T* self_data, - const T* __restrict lambda1_data, - const T* __restrict lambda2_data, - T* output_data) { - constexpr T k_eps = static_cast(1e-6); - - FOLLY_DECLARE_REUSED(zeros, std::vector); - FOLLY_DECLARE_REUSED(nonzeros, std::vector); - // Don't bother calling reserve; calls after the first will get a - // correctly-sized allocation anyway. - for (const auto j : c10::irange(D)) { - if (lambda1_data[j] == 0) { - zeros.push_back(j); - } else { - nonzeros.push_back(j); - } - } - - // Process K rows at a time for effective vectorization with small rows. - const auto K = std::min(N, (block_size + D - 1) / D); - - FOLLY_DECLARE_REUSED(lambda1_, std::vector); - FOLLY_DECLARE_REUSED(lambda2_, std::vector); - FOLLY_DECLARE_REUSED(lambda2_z_, std::vector); - - if (nonzeros.size() == D) { - // ((x + lambda2)^lambda1 - 1)/lambda1, if lambda1 != 0 - size_t i = 0; - if (K > 1) { - TileArrayIntoVector(lambda1_data, D, K, lambda1_); - TileArrayIntoVector(lambda2_data, D, K, lambda2_); - DCHECK_EQ(K * D, lambda1_.size()); - DCHECK_EQ(K * D, lambda2_.size()); - for (; i < N - K + 1; i += K, self_data += K * D, output_data += K * D) { - box_cox_nonzero_lambda( - K * D, - self_data, - lambda1_.data(), - lambda2_.data(), - k_eps, - output_data); - } - } - for (; i < N; i++, self_data += D, output_data += D) { - box_cox_nonzero_lambda( - D, self_data, lambda1_data, lambda2_data, k_eps, output_data); - } - } else if (zeros.size() == D) { - // ln(x + lambda2), if lambda1 == 0 - size_t i = 0; - if (K > 1) { - TileArrayIntoVector(lambda2_data, D, K, lambda2_z_); - DCHECK_EQ(K * D, lambda2_z_.size()); - for (; i < N - K + 1; i += K, self_data += K * D, output_data += K * D) { - box_cox_zero_lambda( - K * D, self_data, lambda2_z_.data(), k_eps, output_data); - } - } - for (; i < N; i++, self_data += D, output_data += D) { - box_cox_zero_lambda( - D, self_data, lambda2_data, k_eps, output_data); - } - } else { - // mix zeros and nonzeros - const size_t n = nonzeros.size(); - if (K > 1) { - TileIndicesInPlace(nonzeros, 0, K); - TileIndicesInPlace(zeros, 0, K); - } - - FOLLY_DECLARE_REUSED(buffer, std::vector); - - buffer.resize(std::max(nonzeros.size(), zeros.size())); - lambda1_.resize(nonzeros.size()); - lambda2_.resize(nonzeros.size()); - lambda2_z_.resize(zeros.size()); - PackV(nonzeros.size(), lambda1_data, nonzeros.data(), lambda1_.data()); - PackV(nonzeros.size(), lambda2_data, nonzeros.data(), lambda2_.data()); - PackV(zeros.size(), lambda2_data, zeros.data(), lambda2_z_.data()); - - size_t i = 0; - if (K > 1) { - // Truncate to original size, and re-tile with offsets this time. - nonzeros.resize(n); - DCHECK_GT(D, n); - zeros.resize(D - n); - TileIndicesInPlace(nonzeros, D, K); - TileIndicesInPlace(zeros, D, K); - DCHECK_EQ(nonzeros.size(), lambda1_.size()); - DCHECK_EQ(nonzeros.size(), lambda2_.size()); - DCHECK_EQ(zeros.size(), lambda2_z_.size()); - - for (; i < N - K + 1; i += K, self_data += K * D, output_data += K * D) { - box_cox_mixed_lambda( - self_data, - nonzeros, - zeros, - lambda1_.data(), - lambda2_.data(), - lambda2_z_.data(), - k_eps, - buffer.data(), - output_data); - } - // Truncate to original size. - nonzeros.resize(n); - zeros.resize(D - n); - } - for (; i < N; i++, self_data += D, output_data += D) { - box_cox_mixed_lambda( - self_data, - nonzeros, - zeros, - lambda1_.data(), - lambda2_.data(), - lambda2_z_.data(), - k_eps, - buffer.data(), - output_data); - } - } -}; - - -template -void compute_batch_box_cox__avx2_fma( - std::size_t N, - std::size_t D, - std::size_t block_size, - const float* self_data, - const float* __restrict lambda1_data, - const float* __restrict lambda2_data, - float* output_data); - -template -void compute_batch_box_cox__avx2_fma( - std::size_t N, - std::size_t D, - std::size_t block_size, - const double* self_data, - const double* __restrict lambda1_data, - const double* __restrict lambda2_data, - double* output_data); - -} // namespace caffe2::detail -#endif diff --git a/caffe2/perfkernels/cvtsh_ss_bugfix.h b/caffe2/perfkernels/cvtsh_ss_bugfix.h deleted file mode 100644 index 6a748faa0e57..000000000000 --- a/caffe2/perfkernels/cvtsh_ss_bugfix.h +++ /dev/null @@ -1,75 +0,0 @@ -#pragma once - -// Apple clang was fixed in 8.1 -#if defined(__apple_build_version__) && \ - ((__clang_major__ < 8) || \ - ((__clang_major__ == 8) && (__clang_minor__ < 1))) -#define CAFFE2_INTERNAL_APPLE_NEED_FIX 1 -#endif - -// Regular clang was fixed in 3.9 -#if defined(__clang__) && (__clang_major__ < 4) && (__clang_minor__ < 9) -#define CAFFE2_INTERNAL_CLANG_NEED_FIX 1 -#endif - -#if defined(CAFFE2_INTERNAL_APPLE_NEED_FIX) || \ - defined(CAFFE2_INTERNAL_CLANG_NEED_FIX) - -#include -#include - -// This version of clang has a bug that _cvtsh_ss is not defined, see -// https://reviews.llvm.org/D16177 -static __inline float - __attribute__((__always_inline__, __nodebug__, __target__("f16c"))) - _cvtsh_ss(unsigned short a) { - __v8hi v = {(short)a, 0, 0, 0, 0, 0, 0, 0}; - __v4sf r = __builtin_ia32_vcvtph2ps(v); - return r[0]; -} - -static __inline unsigned short - __attribute__((__always_inline__, __nodebug__, __target__("f16c"))) - _cvtss_sh(float a, int imm8) { - unsigned short ret; - *reinterpret_cast(&ret) = a; - return ret; -} - -#endif // __APPLE_NEED_FIX || __CLANG_NEED_FIX - -#undef __APPLE_NEED_FIX -#undef __CLANG_NEED_FIX - -#if defined(_MSC_VER) && !defined(__clang__) - -#include -#include - -// It seems that microsoft msvc does not have a _cvtsh_ss implementation so -// we will add a dummy version to it. - -static inline float _cvtsh_ss(unsigned short x) { - union { - std::uint32_t intval; - float floatval; - } t1; - std::uint32_t t2, t3; - t1.intval = x & 0x7fff; // Non-sign bits - t2 = x & 0x8000; // Sign bit - t3 = x & 0x7c00; // Exponent - t1.intval <<= 13; // Align mantissa on MSB - t2 <<= 16; // Shift sign bit into position - t1.intval += 0x38000000; // Adjust bias - t1.intval = (t3 == 0 ? 0 : t1.intval); // Denormals-as-zero - t1.intval |= t2; // Re-insert sign bit - return t1.floatval; -} - -static inline unsigned short _cvtss_sh(float x, int imm8) { - unsigned short ret; - *reinterpret_cast(&ret) = x; - return ret; -} - -#endif // _MSC_VER diff --git a/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.cc b/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.cc deleted file mode 100644 index d919f22c5795..000000000000 --- a/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.cc +++ /dev/null @@ -1,211 +0,0 @@ -#include "caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.h" - -#include "caffe2/perfkernels/common.h" - -#include -#include - -namespace caffe2 { - -/** - * Base implementation does runtime dispatch for each segment of reduction - * @return false if there is an out-of-bound error - */ -template < - typename IndexType, - typename InType, - typename OutType, - bool IS_WEIGHT_POSITIONAL = false> -static bool Fused8BitRowwiseEmbeddingLookupGenericSlow( - const int64_t block_size, - const int64_t output_size, - const int64_t index_size, - const int64_t data_size, - const InType* input, - const IndexType* indices, - const int* lengths, - const float* weights, // optional, can be null for sum reducer - bool normalize_by_lengths, - OutType* out) { - // block_size is the number of elements and fused_block_size is the size of - // an entire row, including scale and bias. - const auto scale_bias_offset = 8 / sizeof(InType); - const int64_t fused_block_size = block_size + scale_bias_offset; - int64_t current = 0; - for (const auto m : c10::irange(output_size)) { - memset(out, 0, sizeof(OutType) * block_size); - if (current + lengths[m] > index_size) { - return false; - } - for (int i = 0; i < lengths[m]; ++i) { - int64_t idx = indices[current]; - if (idx < 0 || idx >= data_size) { - return false; - } -#ifdef __GNUC__ - if (current + 1 < index_size) { - __builtin_prefetch( - input + fused_block_size * indices[current + 1], 0, 1); - } -#endif // __GNUC__ - - const float* scale_bias = reinterpret_cast( - input + fused_block_size * indices[current] + block_size); - - float weight = 1.0f; - if (weights) { - weight = weights[IS_WEIGHT_POSITIONAL ? i : current]; - } - const float scale = weight * scale_bias[0]; - const float bias = weight * scale_bias[1]; - - for (const auto j : c10::irange(block_size)) { - out[j] += scale * input[fused_block_size * indices[current] + j] + bias; - } - - ++current; - } - if (normalize_by_lengths && lengths[m]) { - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - float scale = 1.f / lengths[m]; - for (const auto j : c10::irange(block_size)) { - out[j] *= scale; - } - } - out += block_size; - } - return current == index_size; -} - -// clang-format off -// Proxy back to generic implementation -#define FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION(IndexType, OutType) \ - bool \ - Fused8BitRowwiseEmbeddingLookup_##IndexType##_uint8_t_##OutType##_false__base( \ - const int64_t block_size, \ - const int64_t output_size, \ - const int64_t index_size, \ - const int64_t data_size, \ - const uint8_t* input, \ - const IndexType* indices, \ - const int* lengths, \ - const float* weights, \ - bool normalize_by_lengths, \ - OutType* out) { \ - return Fused8BitRowwiseEmbeddingLookupGenericSlow< \ - IndexType, \ - uint8_t, \ - OutType, \ - false>( \ - block_size, \ - output_size, \ - index_size, \ - data_size, \ - input, \ - indices, \ - lengths, \ - weights, \ - normalize_by_lengths, \ - out); \ - } \ - decltype( \ - Fused8BitRowwiseEmbeddingLookup_##IndexType##_uint8_t_##OutType##_false__base) \ - Fused8BitRowwiseEmbeddingLookup_##IndexType##_uint8_t_##OutType##_false__avx2_fma; \ - bool Fused8BitRowwiseEmbeddingLookup_##IndexType##_uint8_t_##OutType( \ - const int64_t block_size, \ - const int64_t output_size, \ - const int64_t index_size, \ - const int64_t data_size, \ - const uint8_t* input, \ - const IndexType* indices, \ - const int* lengths, \ - const float* weights, \ - bool normalize_by_lengths, \ - OutType* out) { \ - const int32_t one = 1; \ - CAFFE_ENFORCE_EQ( \ - reinterpret_cast(&one)[0], \ - 1, \ - "Fused8BitRowwiseEmbeddingLookup is not supported on this platform"); \ - AVX2_FMA_DO( \ - Fused8BitRowwiseEmbeddingLookup_##IndexType##_uint8_t_##OutType##_false, \ - block_size, \ - output_size, \ - index_size, \ - data_size, \ - input, \ - indices, \ - lengths, \ - weights, \ - normalize_by_lengths, \ - out); \ - BASE_DO( \ - Fused8BitRowwiseEmbeddingLookup_##IndexType##_uint8_t_##OutType##_false, \ - block_size, \ - output_size, \ - index_size, \ - data_size, \ - input, \ - indices, \ - lengths, \ - weights, \ - normalize_by_lengths, \ - out); \ - } \ - template <> \ - void Fused8BitRowwiseEmbeddingLookup( \ - const int64_t block_size, \ - const int64_t output_size, \ - const int64_t index_size, \ - const int64_t data_size, \ - const uint8_t* input, \ - const IndexType* indices, \ - const int* lengths, \ - const float* weights, \ - bool normalize_by_lengths, \ - OutType* out) { \ - bool success = \ - Fused8BitRowwiseEmbeddingLookup_##IndexType##_uint8_t_##OutType( \ - block_size, \ - output_size, \ - index_size, \ - data_size, \ - input, \ - indices, \ - lengths, \ - weights, \ - normalize_by_lengths, \ - out); \ - if (success) { \ - return; \ - } \ - int64_t current = 0; \ - for (int m = 0; m < output_size; ++m) { \ - for (int i = 0; i < lengths[m]; ++i) { \ - CAFFE_ENFORCE_LT(current, index_size); \ - IndexType idx = indices[current]; \ - CAFFE_ENFORCE( \ - 0 <= idx && idx < data_size, \ - "Index ", \ - current, \ - " is out of bounds: ", \ - idx, \ - ", range 0 to ", \ - data_size); \ - ++current; \ - } \ - } \ - CAFFE_ENFORCE_EQ( \ - current, \ - index_size, \ - "Your input seems to be incorrect: the sum of lengths values should be " \ - "the size of the indices tensor, but it appears not."); \ - } -// clang-format on - -FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION(int32_t, float); -FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION(int64_t, float); - -#undef FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION - -} // namespace caffe2 diff --git a/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.h b/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.h deleted file mode 100644 index cfaab0d361b1..000000000000 --- a/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.h +++ /dev/null @@ -1,55 +0,0 @@ -#pragma once - -#include - -namespace caffe2 { - -/** - * Embedding lookup with reduction. - * - * `input` of size data_size * (block_size + 8B) - * `indices` of size index_size - * `lengths` of size output_size - * `weights` nullptr or array of size index_size - * `out` of size output_size * block_size - * sum(lengths[i]) == index_size - * - * Note that block_size should be the number of quantized values per row in the - * data, i.e. excluding the scale and bias. The total (fused) block size is - * assumed to be this block_size, plus 4 bytes for scale and 4 bytes for bias. - * - * Behavior is roughly equivalent to pseudocode: - * - * pos = 0 - * fused_block_size = block_size + 8B // quantized values and scale and bias - * for (i = 0..output_size-1) - * for (k = 0..block_size-1) - * out[i*block_size + k] = 0 - * for (j = 0..lengths[i]-1) - * for (k = 0..block_size-1) - * out[i*block_size + k] += input[indices[pos]*(fused_block_size) + k] * - * (weights ? weights[IS_WEIGHT_POSITIONAL ? j : pos] : 1.0) - * pos += 1 - * if (normalize_weights && lengths[i] > 0) - * for (k = 0..block_size-1) - * out[i*block_size + k] /= lengths[i] - * - */ - -template < - typename IndexType, - typename InType, - typename OutType, - bool IS_WEIGHT_POSITIONAL = false> -void Fused8BitRowwiseEmbeddingLookup( - const std::int64_t block_size, - const std::int64_t output_size, - const std::int64_t index_size, - const std::int64_t data_size, - const InType* input, - const IndexType* indices, - const int* lengths, - const float* weights, // optional, can be null for non-weighted sum - bool normalize_by_lengths, - OutType* out); -} // namespace caffe2 diff --git a/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.cc b/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.cc deleted file mode 100644 index 8f7e926c0e9c..000000000000 --- a/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.cc +++ /dev/null @@ -1,213 +0,0 @@ -#include "caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.h" - -#include "caffe2/perfkernels/common.h" - -#include -#include - -namespace caffe2 { - -/** - * Base implementation does runtime dispatch for each segment of reduction - * @return false if there is an out-of-bound error - */ -template < - typename IndexType, - typename InType, - typename OutType, - bool IS_WEIGHT_POSITIONAL = false> -static bool Fused8BitRowwiseEmbeddingLookupGenericSlowIdx( - const int64_t block_size, - const int64_t output_size, - const int64_t index_size, - const int64_t data_size, - const InType* input, - const IndexType* indices, - const IndexType* offsets, - const float* weights, // optional, can be null for sum reducer - bool normalize_by_lengths, - OutType* out) { - // block_size is the number of elements and fused_block_size is the size of - // an entire row, including scale and bias. - const auto scale_bias_offset = 8 / sizeof(InType); - const int64_t fused_block_size = block_size + scale_bias_offset; - int64_t current = 0; - for (const auto m : c10::irange(output_size)) { - memset(out, 0, sizeof(OutType) * block_size); - if (current != offsets[m] - offsets[0]) { - return false; - } - int64_t start_offset = offsets[m]; - int64_t end_offset = offsets[m + 1]; - int64_t length = end_offset - start_offset; - for (const auto i : c10::irange(start_offset, end_offset)) { - int64_t idx = indices[current]; - if (idx < 0 || idx >= data_size) { - return false; - } -#ifdef __GNUC__ - if (current + 1 < index_size) { - __builtin_prefetch( - input + fused_block_size * indices[current + 1], 0, 1); - } -#endif // __GNUC__ - - const float* scale_bias = reinterpret_cast( - input + fused_block_size * indices[current] + block_size); - - float weight = 1.0f; - if (weights) { - weight = weights[IS_WEIGHT_POSITIONAL ? i : current]; - } - const float scale = weight * scale_bias[0]; - const float bias = weight * scale_bias[1]; - - for (const auto j : c10::irange(block_size)) { - out[j] += scale * input[fused_block_size * indices[current] + j] + bias; - } - - ++current; - } - if (normalize_by_lengths && length) { - float scale = 1.f / length; - for (const auto j : c10::irange(block_size)) { - out[j] *= scale; - } - } - out += block_size; - } - return current == index_size; -} - -// clang-format off -// Proxy back to generic implementation -#define FUSED_8BIT_ROWWISE_EMBEDDING_IDX_SPECIALIZATION(IndexType, OutType) \ - bool \ - Fused8BitRowwiseEmbeddingLookupIdx_##IndexType##_uint8_t_##OutType##_false__base( \ - const int64_t block_size, \ - const int64_t output_size, \ - const int64_t index_size, \ - const int64_t data_size, \ - const uint8_t* input, \ - const IndexType* indices, \ - const IndexType* offsets, \ - const float* weights, \ - bool normalize_by_lengths, \ - OutType* out) { \ - return Fused8BitRowwiseEmbeddingLookupGenericSlowIdx< \ - IndexType, \ - uint8_t, \ - OutType, \ - false>( \ - block_size, \ - output_size, \ - index_size, \ - data_size, \ - input, \ - indices, \ - offsets, \ - weights, \ - normalize_by_lengths, \ - out); \ - } \ - decltype( \ - Fused8BitRowwiseEmbeddingLookupIdx_##IndexType##_uint8_t_##OutType##_false__base) \ - Fused8BitRowwiseEmbeddingLookupIdx_##IndexType##_uint8_t_##OutType##_false__avx2_fma; \ - bool Fused8BitRowwiseEmbeddingLookupIdx_##IndexType##_uint8_t_##OutType( \ - const int64_t block_size, \ - const int64_t output_size, \ - const int64_t index_size, \ - const int64_t data_size, \ - const uint8_t* input, \ - const IndexType* indices, \ - const IndexType* offsets, \ - const float* weights, \ - bool normalize_by_lengths, \ - OutType* out) { \ - const int32_t one = 1; \ - CAFFE_ENFORCE_EQ( \ - reinterpret_cast(&one)[0], \ - 1, \ - "Fused8BitRowwiseEmbeddingLookup is not supported on this platform"); \ - AVX2_FMA_DO( \ - Fused8BitRowwiseEmbeddingLookupIdx_##IndexType##_uint8_t_##OutType##_false, \ - block_size, \ - output_size, \ - index_size, \ - data_size, \ - input, \ - indices, \ - offsets, \ - weights, \ - normalize_by_lengths, \ - out); \ - BASE_DO( \ - Fused8BitRowwiseEmbeddingLookupIdx_##IndexType##_uint8_t_##OutType##_false, \ - block_size, \ - output_size, \ - index_size, \ - data_size, \ - input, \ - indices, \ - offsets, \ - weights, \ - normalize_by_lengths, \ - out); \ - } \ - template <> \ - void Fused8BitRowwiseEmbeddingLookupIdx( \ - const int64_t block_size, \ - const int64_t output_size, \ - const int64_t index_size, \ - const int64_t data_size, \ - const uint8_t* input, \ - const IndexType* indices, \ - const IndexType* offsets, \ - const float* weights, \ - bool normalize_by_lengths, \ - OutType* out) { \ - bool success = \ - Fused8BitRowwiseEmbeddingLookupIdx_##IndexType##_uint8_t_##OutType( \ - block_size, \ - output_size, \ - index_size, \ - data_size, \ - input, \ - indices, \ - offsets, \ - weights, \ - normalize_by_lengths, \ - out); \ - if (success) { \ - return; \ - } \ - int64_t current = 0; \ - for (int m = 0; m < output_size; ++m) { \ - for (int64_t i = offsets[m]; i < offsets[m + 1]; ++i) { \ - CAFFE_ENFORCE_LT(current, index_size); \ - IndexType idx = indices[current]; \ - CAFFE_ENFORCE( \ - 0 <= idx && idx < data_size, \ - "Index ", \ - current, \ - " is out of bounds: ", \ - idx, \ - ", range 0 to ", \ - data_size); \ - ++current; \ - } \ - } \ - CAFFE_ENFORCE_EQ( \ - current, \ - index_size, \ - "Your input seems to be incorrect: the sum of lengths values should be " \ - "the size of the indices tensor, but it appears not."); \ - } -// clang-format on - -FUSED_8BIT_ROWWISE_EMBEDDING_IDX_SPECIALIZATION(int32_t, float); -FUSED_8BIT_ROWWISE_EMBEDDING_IDX_SPECIALIZATION(int64_t, float); - -#undef FUSED_8BIT_ROWWISE_EMBEDDING_IDX_SPECIALIZATION - -} // namespace caffe2 diff --git a/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.h b/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.h deleted file mode 100644 index f7422bd7b752..000000000000 --- a/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.h +++ /dev/null @@ -1,57 +0,0 @@ -#pragma once - -#include - -namespace caffe2 { - -/** - * Embedding lookup with reduction. - * - * `input` of size data_size * (block_size + 8B) - * `indices` of size index_size - * `offsets` of size output_size - * `weights` nullptr or array of size index_size - * `out` of size output_size * block_size - * - * Note that block_size should be the number of quantized values per row in the - * data, i.e. excluding the scale and bias. The total (fused) block size is - * assumed to be this block_size, plus 4 bytes for scale and 4 bytes for bias. - * - * Behavior is roughly equivalent to pseudocode: - * - * pos = 0 - * fused_block_size = block_size + 8B // quantized values and scale and bias - * for (i = 0..output_size-1) - * for (k = 0..block_size-1) - * out[i*block_size + k] = 0 - * start_offset = offsets[i] - * end_offset = i == output_size-1 ? index_size : offsets[i+1] - 1 - * length = end_offset - start_offset - * for (j = start_offset..end_offset) - * for (k = 0..block_size-1) - * out[i*block_size + k] += input[indices[pos]*(fused_block_size) + k] * - * (weights ? weights[IS_WEIGHT_POSITIONAL ? j : pos] : 1.0) - * pos += 1 - * if (normalize_weights && length > 0) - * for (k = 0..block_size-1) - * out[i*block_size + k] /= length - * - */ - -template < - typename IndexType, - typename InType, - typename OutType, - bool IS_WEIGHT_POSITIONAL = false> -void Fused8BitRowwiseEmbeddingLookupIdx( - const std::int64_t block_size, - const std::int64_t output_size, - const std::int64_t index_size, - const std::int64_t data_size, - const InType* input, - const IndexType* indices, - const IndexType* offsets, - const float* weights, // optional, can be null for non-weighted sum - bool normalize_by_lengths, - OutType* out); -} // namespace caffe2 diff --git a/caffe2/perfkernels/fused_nbit_rowwise_conversion.cc b/caffe2/perfkernels/fused_nbit_rowwise_conversion.cc deleted file mode 100644 index 05cae2e280be..000000000000 --- a/caffe2/perfkernels/fused_nbit_rowwise_conversion.cc +++ /dev/null @@ -1,214 +0,0 @@ -#include "./fused_nbit_rowwise_conversion.h" - -#include -#include -#include - -#include "common.h" - -#ifdef USE_FBGEMM -#include "fbgemm/QuantUtils.h" -#endif - -namespace caffe2 { - -void FloatToFused8BitRowwiseQuantized__base( - const float* input, - size_t input_rows, - int input_columns, - std::uint8_t* output) { - constexpr float kEpsilon = 1e-8f; - - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - int output_columns = input_columns + 2 * sizeof(float); - for (std::size_t row = 0; row < input_rows; ++row) { - const float* input_row = input + row * input_columns; - std::uint8_t* output_row = output + row * output_columns; - float* output_row_scale_bias = - reinterpret_cast(output_row + input_columns); - - float minimum_element = - *std::min_element(input_row, input_row + input_columns); - float maximum_element = - *std::max_element(input_row, input_row + input_columns); - float range = maximum_element - minimum_element; - - output_row_scale_bias[0] = range / 255.0f; - output_row_scale_bias[1] = minimum_element; - const auto inverse_scale = 255.0f / (range + kEpsilon); - for (std::size_t col = 0; col < static_cast(input_columns); ++col) { - output_row[col] = - std::lrintf((input_row[col] - minimum_element) * inverse_scale); - } - } -} - -void Fused8BitRowwiseQuantizedToFloat__base( - const std::uint8_t* input, - size_t input_rows, - int input_columns, - float* output) { - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - int output_columns = input_columns - 2 * sizeof(float); - - for (std::size_t row = 0; row < input_rows; ++row) { - const std::uint8_t* input_row = input + row * input_columns; - const float* input_row_scale_bias = - reinterpret_cast(input_row + output_columns); - float* output_row = output + row * output_columns; - - for (std::size_t col = 0; col < static_cast(output_columns); ++col) { - output_row[col] = - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - input_row[col] * input_row_scale_bias[0] + input_row_scale_bias[1]; - } - } -} - -void FloatToFused8BitRowwiseQuantized( - const float* input, - size_t input_rows, - int input_columns, - std::uint8_t* output) { -#ifdef USE_FBGEMM - fbgemm::FloatOrHalfToFused8BitRowwiseQuantizedSBFloat( - input, input_rows, input_columns, output); -#else - FloatToFused8BitRowwiseQuantized__base( - input, input_rows, input_columns, output); -#endif -} - -void Fused8BitRowwiseQuantizedToFloat( - const std::uint8_t* input, - size_t input_rows, - int input_columns, - float* output) { -#ifdef USE_FBGEMM - fbgemm::Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf( - input, input_rows, input_columns, output); -#else - Fused8BitRowwiseQuantizedToFloat__base( - input, input_rows, input_columns, output); -#endif -} - -void FloatToFusedNBitRowwiseQuantizedSBHalf__base( - int bit_rate, - const float* input, - size_t input_rows, - int input_columns, - std::uint8_t* output) { - int num_elem_per_byte = 8 / bit_rate; - int output_columns = - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - (input_columns + num_elem_per_byte - 1) / num_elem_per_byte + - 2 * sizeof(at::Half); - for (std::size_t row = 0; row < input_rows; ++row) { - const float* input_row = input + row * input_columns; - std::uint8_t* output_row = output + row * output_columns; - at::Half* output_row_scale_bias = reinterpret_cast( - output_row + - (input_columns + num_elem_per_byte - 1) / num_elem_per_byte); - - float minimum_element = - *std::min_element(input_row, input_row + input_columns); - float maximum_element = - *std::max_element(input_row, input_row + input_columns); - - minimum_element = static_cast(minimum_element); - const float range = maximum_element - minimum_element; - - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - at::Half scale = range == 0 ? 1.0f : range / ((1 << bit_rate) - 1); - if (scale == 0) { - // Corner case handling when maximum_element == minimum_element - // Any scale would work because X - minimum_element will be 0 for all X - scale = 1.0f; - } - float inverse_scale = 1.0f / scale; - if (std::isinf(inverse_scale)) { - scale = 1.0f; - inverse_scale = 1.0f; - } - - output_row_scale_bias[0] = scale; - output_row_scale_bias[1] = minimum_element; - for (std::size_t col = 0; col < static_cast(input_columns); ++col) { - float X = input_row[col]; - std::uint8_t quantized = std::max( - 0, - std::min( - std::lrintf((X - minimum_element) * inverse_scale), - (1 << bit_rate) - 1)); - if (col % num_elem_per_byte == 0) { - output_row[col / num_elem_per_byte] = quantized; - } else { - output_row[col / num_elem_per_byte] |= - (quantized << ((col % num_elem_per_byte) * bit_rate)); - } - } - } -} - -void FusedNBitRowwiseQuantizedSBHalfToFloat__base( - int bit_rate, - const std::uint8_t* input, - size_t input_rows, - int input_columns, - float* output) { - int num_elem_per_byte = 8 / bit_rate; - int output_columns = - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - (input_columns - 2 * sizeof(at::Half)) * num_elem_per_byte; - - for (std::size_t row = 0; row < static_cast(input_rows); ++row) { - const std::uint8_t* input_row = input + row * input_columns; - const at::Half* input_row_scale_bias = reinterpret_cast( - input_row + - (output_columns + num_elem_per_byte - 1) / num_elem_per_byte); - float scale = input_row_scale_bias[0]; - float bias = input_row_scale_bias[1]; - float* output_row = output + row * output_columns; - - for (std::size_t col = 0; col < static_cast(output_columns); ++col) { - std::uint8_t quantized = input_row[col / num_elem_per_byte]; - quantized >>= (col % num_elem_per_byte) * bit_rate; - quantized &= (1 << bit_rate) - 1; - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - output_row[col] = scale * quantized + bias; - } - } -} - -void FloatToFusedNBitRowwiseQuantizedSBHalf( - int bit_rate, - const float* input, - size_t input_rows, - int input_columns, - std::uint8_t* output) { -#ifdef USE_FBGEMM - fbgemm::FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf( - bit_rate, input, input_rows, input_columns, output); -#else - FloatToFusedNBitRowwiseQuantizedSBHalf__base( - bit_rate, input, input_rows, input_columns, output); -#endif -} - -void FusedNBitRowwiseQuantizedSBHalfToFloat( - int bit_rate, - const std::uint8_t* input, - size_t input_rows, - int input_columns, - float* output) { -#ifdef USE_FBGEMM - fbgemm::FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf( - bit_rate, input, input_rows, input_columns, output); -#else - FusedNBitRowwiseQuantizedSBHalfToFloat__base( - bit_rate, input, input_rows, input_columns, output); -#endif -} - -} // namespace caffe2 diff --git a/caffe2/perfkernels/fused_nbit_rowwise_conversion.h b/caffe2/perfkernels/fused_nbit_rowwise_conversion.h deleted file mode 100644 index da9ec5c6cdd6..000000000000 --- a/caffe2/perfkernels/fused_nbit_rowwise_conversion.h +++ /dev/null @@ -1,39 +0,0 @@ -#pragma once - -#include -#include - -namespace caffe2 { - -void FloatToFused8BitRowwiseQuantized( - const float* input, - size_t input_rows, - int input_columns, - std::uint8_t* output); - -void Fused8BitRowwiseQuantizedToFloat( - const std::uint8_t* input, - size_t input_rows, - int input_columns, - float* output); - -/** - * Row-wise quantization with fp16 scale and bias - * - * @param bit_rate can be 2, 4, or 8 - */ -void FloatToFusedNBitRowwiseQuantizedSBHalf( - int bit_rate, - const float* input, - size_t input_rows, - int input_columns, - std::uint8_t* output); - -void FusedNBitRowwiseQuantizedSBHalfToFloat( - int bit_rate, - const std::uint8_t* input, - size_t input_rows, - int input_columns, - float* output); - -} // namespace caffe2 diff --git a/caffe2/perfkernels/lstm_unit_cpu-impl.h b/caffe2/perfkernels/lstm_unit_cpu-impl.h deleted file mode 100644 index 239d2807f778..000000000000 --- a/caffe2/perfkernels/lstm_unit_cpu-impl.h +++ /dev/null @@ -1,141 +0,0 @@ -#pragma once -#include -#include -#include -#include "c10/util/irange.h" -#include "caffe2/utils/conversions.h" - -#include "vectorizer.h" - -namespace caffe2 { -namespace perfkernels { -namespace { -template -inline T sigmoid(T x) { - return 1 / (1 + std::exp(-x)); -} - -template -inline T host_tanh(T x) { - return 2 * sigmoid(2 * x) - 1; -} - -template -inline void LstmUnitImpl( - const int N, - const int D, - const int t, - const T* H_prev, - const T* C_prev, - const T* X, - const int32_t* seqLengths, - const bool drop_states, - T* C, - T* H, - const float forget_bias) { - const T forgetBias = convert::To(forget_bias); - for (const auto n : c10::irange(N)) { - const bool valid = seqLengths == nullptr || t < seqLengths[n]; - if (!valid) { - if (drop_states) { - memset(H, 0, sizeof(T) * D); - memset(C, 0, sizeof(T) * D); - } else { - memcpy(H, H_prev, sizeof(T) * D); - memcpy(C, C_prev, sizeof(T) * D); - } - } else { - const T* X_D = &X[D]; - const T* X_2D = &X[2 * D]; - const T* X_3D = &X[3 * D]; - VECTOR_LOOP for (const auto d : c10::irange(D)) { - const T i = sigmoid(X[d]); - const T f = sigmoid(X_D[d] + forgetBias); - const T o = sigmoid(X_2D[d]); - const T g = host_tanh(X_3D[d]); - const T c_prev = C_prev[d]; - const T c = f * c_prev + i * g; - C[d] = c; - const T host_tanh_c = host_tanh(c); - H[d] = o * host_tanh_c; - } - } - H_prev += D; - C_prev += D; - X += 4 * D; - C += D; - H += D; - } -} - -template -inline void LstmUnitGradientImpl( - int N, - int D, - int t, - const T* C_prev, - const T* X, - const int32_t* seqLengths, - const T* C, - const T* H, - const T* C_diff, - const T* H_diff, - bool drop_states, - T* H_prev_diff, - T* C_prev_diff, - T* X_diff, - const float forget_bias) { - const T localForgetBias = convert::To(forget_bias); - for (const auto n : c10::irange(N)) { - const bool valid = seqLengths == nullptr || t < seqLengths[n]; - - if (!valid) { - if (drop_states) { - memset(C_prev_diff, 0, sizeof(T) * D); - memset(H_prev_diff, 0, sizeof(T) * D); - } else { - memcpy(H_prev_diff, H_diff, sizeof(T) * D); - memcpy(C_prev_diff, C_diff, sizeof(T) * D); - } - memset(X_diff, 0, 4 * sizeof(T) * D); - } else { - VECTOR_LOOP for (const auto d : c10::irange(D)) { - T* c_prev_diff = C_prev_diff + d; - T* h_prev_diff = H_prev_diff + d; - T* i_diff = X_diff + d; - T* f_diff = X_diff + 1 * D + d; - T* o_diff = X_diff + 2 * D + d; - T* g_diff = X_diff + 3 * D + d; - - const T i = sigmoid(X[d]); - const T f = sigmoid(X[1 * D + d] + localForgetBias); - const T o = sigmoid(X[2 * D + d]); - const T g = host_tanh(X[3 * D + d]); - const T c_prev = C_prev[d]; - const T c = C[d]; - const T host_tanh_c = host_tanh(c); - const T c_term_diff = - C_diff[d] + H_diff[d] * o * (1 - host_tanh_c * host_tanh_c); - *c_prev_diff = c_term_diff * f; - *h_prev_diff = 0; // not used in 'valid' case - *i_diff = c_term_diff * g * i * (1 - i); - *f_diff = c_term_diff * c_prev * f * (1 - f); - *o_diff = H_diff[d] * host_tanh_c * o * (1 - o); - *g_diff = c_term_diff * i * (1 - g * g); - } - } - C_prev += D; - X += 4 * D; - C += D; - H += D; - C_diff += D; - H_diff += D; - X_diff += 4 * D; - H_prev_diff += D; - C_prev_diff += D; - } -} - -} // namespace -} // namespace perfkernels -} // namespace caffe2 diff --git a/caffe2/perfkernels/lstm_unit_cpu.h b/caffe2/perfkernels/lstm_unit_cpu.h deleted file mode 100644 index e9c87f3082f9..000000000000 --- a/caffe2/perfkernels/lstm_unit_cpu.h +++ /dev/null @@ -1,73 +0,0 @@ -#pragma once -#include - -namespace caffe2 { -namespace detail { - -// Forward declration of the LSTMUnit templated -// implementation -template -void LstmUnitCpu( - const int N, - const int D, - const int t, - const T* H_prev, - const T* C_prev, - const T* X, - const int32_t* seqLengths, - const bool drop_states, - T* C, - T* H, - const float forget_bias); - -// Forward specialization -extern template void LstmUnitCpu( - const int N, - const int D, - const int t, - const float* H_prev, - const float* C_prev, - const float* X, - const int32_t* seqLengths, - const bool drop_states, - float* C, - float* H, - const float forget_bias); - -template -void LstmUnitGradientCpu( - int N, - int D, - int t, - const T* C_prev, - const T* X, - const int32_t* seqLengths, - const T* C, - const T* H, - const T* C_diff, - const T* H_diff, - bool drop_states, - T* H_prev_diff, - T* C_prev_diff, - T* X_diff, - const float forget_bias); - -extern template void LstmUnitGradientCpu( - int N, - int D, - int t, - const float* C_prev, - const float* X, - const int32_t* seqLengths, - const float* C, - const float* H, - const float* C_diff, - const float* H_diff, - bool drop_states, - float* H_prev_diff, - float* C_prev_diff, - float* X_diff, - const float forget_bias); - -} // namespace detail -} // namespace caffe2 diff --git a/caffe2/perfkernels/lstm_unit_cpu_avx2.cc b/caffe2/perfkernels/lstm_unit_cpu_avx2.cc deleted file mode 100644 index ac66c6bd3f52..000000000000 --- a/caffe2/perfkernels/lstm_unit_cpu_avx2.cc +++ /dev/null @@ -1,123 +0,0 @@ -#include "caffe2/perfkernels/lstm_unit_cpu-impl.h" - -namespace caffe2 { -namespace perfkernels { -namespace { -// Explicit initialize for float and AVX2 vectorization -template void LstmUnitImpl( - const int N, - const int D, - const int t, - const float* H_prev, - const float* C_prev, - const float* X, - const int32_t* seqLengths, - const bool drop_states, - float* C, - float* H, - const float forget_bias); - -template void LstmUnitGradientImpl( - int N, - int D, - int t, - const float* C_prev, - const float* X, - const int32_t* seqLengths, - const float* C, - const float* H, - const float* C_diff, - const float* H_diff, - bool drop_states, - float* H_prev_diff, - float* C_prev_diff, - float* X_diff, - const float forget_bias); -} // namespace - -// Define templated implementation fo LSTM kernels on CPU supporting AVX2 -template -void LstmUnitImpl__avx2_fma( - const int N, - const int D, - const int t, - const T* H_prev, - const T* C_prev, - const T* X, - const int32_t* seqLengths, - const bool drop_states, - T* C, - T* H, - const float forget_bias) { - LstmUnitImpl( - N, D, t, H_prev, C_prev, X, seqLengths, drop_states, C, H, forget_bias); -} - -template -void LstmUnitGradientImpl__avx2_fma( - int N, - int D, - int t, - const T* C_prev, - const T* X, - const int32_t* seqLengths, - const T* C, - const T* H, - const T* C_diff, - const T* H_diff, - bool drop_states, - T* H_prev_diff, - T* C_prev_diff, - T* X_diff, - const float forget_bias) { - LstmUnitGradientImpl( - N, - D, - t, - C_prev, - X, - seqLengths, - C, - H, - C_diff, - H_diff, - drop_states, - H_prev_diff, - C_prev_diff, - X_diff, - forget_bias); -} - -// Explicit initialize for float -template void LstmUnitImpl__avx2_fma( - const int N, - const int D, - const int t, - const float* H_prev, - const float* C_prev, - const float* X, - const int32_t* seqLengths, - const bool drop_states, - float* C, - float* H, - const float forget_bias); - -template void LstmUnitGradientImpl__avx2_fma( - int N, - int D, - int t, - const float* C_prev, - const float* X, - const int32_t* seqLengths, - const float* C, - const float* H, - const float* C_diff, - const float* H_diff, - bool drop_states, - float* H_prev_diff, - float* C_prev_diff, - float* X_diff, - const float forget_bias); - -} // namespace perfkernels -} // namespace caffe2 diff --git a/caffe2/perfkernels/lstm_unit_cpu_common.cc b/caffe2/perfkernels/lstm_unit_cpu_common.cc deleted file mode 100644 index 72d97d832625..000000000000 --- a/caffe2/perfkernels/lstm_unit_cpu_common.cc +++ /dev/null @@ -1,125 +0,0 @@ -#include "caffe2/perfkernels/lstm_unit_cpu_common.h" -#include "caffe2/perfkernels/common.h" -#include "caffe2/perfkernels/lstm_unit_cpu-impl.h" - -namespace caffe2 { -namespace detail { - -// Define templated implementation fo LSTM kernels on CPU -template -void LstmUnitCpu( - const int N, - const int D, - const int t, - const T* H_prev, - const T* C_prev, - const T* X, - const int32_t* seqLengths, - const bool drop_states, - T* C, - T* H, - const float forget_bias) { - // Do CPU dispatching - AVX2_FMA_DO( - perfkernels::LstmUnitImpl, - N, - D, - t, - H_prev, - C_prev, - X, - seqLengths, - drop_states, - C, - H, - forget_bias); - perfkernels::LstmUnitImpl( - N, D, t, H_prev, C_prev, X, seqLengths, drop_states, C, H, forget_bias); -} - -template -void LstmUnitGradientCpu( - int N, - int D, - int t, - const T* C_prev, - const T* X, - const int32_t* seqLengths, - const T* C, - const T* H, - const T* C_diff, - const T* H_diff, - bool drop_states, - T* H_prev_diff, - T* C_prev_diff, - T* X_diff, - const float forget_bias) { - // Do CPU dispatching - AVX2_FMA_DO( - perfkernels::LstmUnitGradientImpl, - N, - D, - t, - C_prev, - X, - seqLengths, - C, - H, - C_diff, - H_diff, - drop_states, - H_prev_diff, - C_prev_diff, - X_diff, - forget_bias); - perfkernels::LstmUnitGradientImpl( - N, - D, - t, - C_prev, - X, - seqLengths, - C, - H, - C_diff, - H_diff, - drop_states, - H_prev_diff, - C_prev_diff, - X_diff, - forget_bias); -} - -// Explicit initialize for float -template void LstmUnitCpu( - const int N, - const int D, - const int t, - const float* H_prev, - const float* C_prev, - const float* X, - const int32_t* seqLengths, - const bool drop_states, - float* C, - float* H, - const float forget_bias); - -template void LstmUnitGradientCpu( - int N, - int D, - int t, - const float* C_prev, - const float* X, - const int32_t* seqLengths, - const float* C, - const float* H, - const float* C_diff, - const float* H_diff, - bool drop_states, - float* H_prev_diff, - float* C_prev_diff, - float* X_diff, - const float forget_bias); - -} // namespace detail -} // namespace caffe2 diff --git a/caffe2/perfkernels/lstm_unit_cpu_common.h b/caffe2/perfkernels/lstm_unit_cpu_common.h deleted file mode 100644 index d8680adf7d1d..000000000000 --- a/caffe2/perfkernels/lstm_unit_cpu_common.h +++ /dev/null @@ -1,71 +0,0 @@ -#pragma once -#include - -namespace caffe2 { -namespace perfkernels { - -template -void LstmUnitImpl__avx2_fma( - const int N, - const int D, - const int t, - const T* H_prev, - const T* C_prev, - const T* X, - const int32_t* seqLengths, - const bool drop_states, - T* C, - T* H, - const float forget_bias); - -template -void LstmUnitGradientImpl__avx2_fma( - int N, - int D, - int t, - const T* C_prev, - const T* X, - const int32_t* seqLengths, - const T* C, - const T* H, - const T* C_diff, - const T* H_diff, - bool drop_states, - T* H_prev_diff, - T* C_prev_diff, - T* X_diff, - const float forget_bias); - -// Forward declaration of specialized functions -extern template void LstmUnitImpl__avx2_fma( - const int N, - const int D, - const int t, - const float* H_prev, - const float* C_prev, - const float* X, - const int32_t* seqLengths, - const bool drop_states, - float* C, - float* H, - const float forget_bias); - -extern template void LstmUnitGradientImpl__avx2_fma( - int N, - int D, - int t, - const float* C_prev, - const float* X, - const int32_t* seqLengths, - const float* C, - const float* H, - const float* C_diff, - const float* H_diff, - bool drop_states, - float* H_prev_diff, - float* C_prev_diff, - float* X_diff, - const float forget_bias); - -} // namespace perfkernels -} // namespace caffe2 diff --git a/caffe2/perfkernels/math.h b/caffe2/perfkernels/math.h deleted file mode 100644 index 63380fc3f9a1..000000000000 --- a/caffe2/perfkernels/math.h +++ /dev/null @@ -1,35 +0,0 @@ -#pragma once - -#include - -namespace caffe2 { - -namespace math { - -// Returns the quantized and compressed values of floating inputs -// The "fused" representation stores the [bitwidth][tail][min][max] -// with the quantized data in one array. Since we store 8/bitwidth -// quantized data in one byte, the last buckets of some bytes may have -// unused bits. There are totally tail buckets are unused. -// We encode *bitwidth* and *tail* at the beginning, -// following by 32-bit floating data respresenting min and max. -// | bitwidth | tail | min | max | ... int8 data ... | -// | 1B | 1B | 4B | 4B | ...output_data....| -// In output_data: the b-th bucket of the i-th byte stores -// the i-th data of the b-th segment of input row - -void quantize_and_compress( - const float* input_data, - std::uint8_t* output_data, - std::uint64_t input_size, - std::uint64_t bitwidth, - bool random, - const float* random_buffer); - -void decompress_and_dequantize( - const std::uint8_t* input_data, - float* output_data, - std::uint64_t input_size); - -} // namespace math -} // namespace caffe2 diff --git a/caffe2/perfkernels/math_cpu_avx2.cc b/caffe2/perfkernels/math_cpu_avx2.cc deleted file mode 100644 index 325d9c4591ef..000000000000 --- a/caffe2/perfkernels/math_cpu_avx2.cc +++ /dev/null @@ -1,246 +0,0 @@ -// Implements the math functions for CPU. -// The implementation in this file allows us to route the underlying numerical -// computation library to different compiler options (-mno-avx2 or -mavx2). - -#include -#include -#include - -#include - -using std::uint64_t; -using std::uint8_t; - -namespace caffe2 { - -namespace math { - -static constexpr double QEPSILON = 1e-8; - -void quantize_and_compress__avx2( - const float* input_data, - uint8_t* output_data, - uint64_t input_size, - uint64_t bitwidth, - bool random, - const float* random_buffer) { - __m256i shuffle_mask_v = _mm256_set_epi8( - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - 0x0c, - 0x08, - 0x04, - 0x00, - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - 0x0c, - 0x08, - 0x04, - 0x00); - __m256i permute_mask_v = - _mm256_set_epi32(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00); - - uint64_t data_per_byte = 8 / bitwidth; - uint64_t tail = input_size % data_per_byte; - tail = tail ? data_per_byte - tail : 0; - uint64_t segment_size = (input_size + data_per_byte - 1) / data_per_byte; - - // basic info - float minimum_element = INFINITY, maximum_element = -INFINITY; - for (const auto i : c10::irange(input_size)) { - minimum_element = - (input_data[i] < minimum_element) ? input_data[i] : minimum_element; - maximum_element = - (input_data[i] > maximum_element) ? input_data[i] : maximum_element; - } - output_data[0] = bitwidth; - output_data[1] = tail; - reinterpret_cast(output_data + 2)[0] = minimum_element; - reinterpret_cast(output_data + 2)[1] = maximum_element; - - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - float gap = (maximum_element - minimum_element) / ((1 << bitwidth) - 1.0f); - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - float gap_inverse = 1. / (gap + QEPSILON); - uint8_t max_q = (1 << bitwidth) - 1; - uint64_t bit_start = 0; - if (random) { - for (uint64_t start = 0; start < input_size; start += segment_size) { - uint64_t stride = start + segment_size <= input_size ? segment_size - : input_size - start; - uint64_t i = 0; - constexpr int VLEN = 8; - for (; i < stride / VLEN * VLEN; i += VLEN) { - __m256 r_v = _mm256_loadu_ps(&random_buffer[start + i]); - __m256 fval_v = _mm256_loadu_ps(input_data + start + i); - __m256 thetimes_v = _mm256_mul_ps( - _mm256_sub_ps(fval_v, _mm256_set1_ps(minimum_element)), - _mm256_set1_ps(gap_inverse)); - __m256 rounded_v = _mm256_floor_ps(_mm256_add_ps(thetimes_v, r_v)); - rounded_v = _mm256_max_ps( - _mm256_setzero_ps(), - _mm256_min_ps(_mm256_set1_ps(max_q), rounded_v)); - __m256i qval_v = _mm256_cvtps_epi32(rounded_v); - __m256i orval_v = _mm256_cvtepu8_epi32(_mm_lddqu_si128( - reinterpret_cast(output_data + 10 + i))); - orval_v = - _mm256_or_si256(orval_v, _mm256_slli_epi32(qval_v, bit_start)); - orval_v = _mm256_shuffle_epi8(orval_v, shuffle_mask_v); - orval_v = _mm256_permutevar8x32_epi32(orval_v, permute_mask_v); - *reinterpret_cast(output_data + 10 + i) = - _mm256_extract_epi64(orval_v, 0); - } - for (; i < stride; ++i) { - float fval = input_data[start + i]; - float thetimes = (fval - minimum_element) * gap_inverse; - float rounded = floor(thetimes + random_buffer[start + i]); - rounded = rounded < static_cast(max_q) - ? rounded - : static_cast(max_q); - rounded = rounded > 0.0f ? rounded : 0.0f; - uint8_t qval = rounded; - - uint8_t orval = output_data[10 + i]; - output_data[10 + i] = orval | static_cast(qval << bit_start); - } - bit_start += bitwidth; - } - } else { - // !random - for (uint64_t start = 0; start < input_size; start += segment_size) { - uint64_t stride = start + segment_size <= input_size ? segment_size - : input_size - start; - uint64_t i = 0; - constexpr int VLEN = 8; - for (; i < stride / VLEN * VLEN; i += VLEN) { - __m256 fval_v = _mm256_loadu_ps(input_data + start + i); - __m256 thetimes_v = _mm256_mul_ps( - _mm256_sub_ps(fval_v, _mm256_set1_ps(minimum_element)), - _mm256_set1_ps(gap_inverse)); - thetimes_v = _mm256_max_ps( - _mm256_setzero_ps(), - _mm256_min_ps(_mm256_set1_ps(max_q), thetimes_v)); - __m256i qval_v = _mm256_cvtps_epi32(_mm256_round_ps( - thetimes_v, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - __m256i orval_v = _mm256_cvtepu8_epi32(_mm_lddqu_si128( - reinterpret_cast(output_data + 10 + i))); - orval_v = - _mm256_or_si256(orval_v, _mm256_slli_epi32(qval_v, bit_start)); - orval_v = _mm256_shuffle_epi8(orval_v, shuffle_mask_v); - orval_v = _mm256_permutevar8x32_epi32(orval_v, permute_mask_v); - *reinterpret_cast(output_data + 10 + i) = - _mm256_extract_epi64(orval_v, 0); - } - for (; i < stride; ++i) { - float fval = input_data[start + i]; - float thetimes = (fval - minimum_element) * gap_inverse; - thetimes = thetimes < static_cast(max_q) - ? thetimes - : static_cast(max_q); - thetimes = thetimes > 0.0f ? thetimes : 0.0f; - uint8_t qval = nearbyint(thetimes); - - uint8_t orval = output_data[10 + i]; - output_data[10 + i] = orval | static_cast(qval << bit_start); - } - bit_start += bitwidth; - } - } // !random -} - -void decompress_and_dequantize__avx2( - const uint8_t* input_data, - float* output_data, - uint64_t input_size) { - // basic info - const float minimum_element = - reinterpret_cast(input_data + 2)[0]; - const float maximum_element = - reinterpret_cast(input_data + 2)[1]; - const uint64_t bitwidth = input_data[0]; - const float gap = - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - (maximum_element - minimum_element) / ((1 << bitwidth) - 1.f) + - QEPSILON; // for exact recovering - - const uint64_t tail = input_data[1]; - - const uint64_t output_size = (input_size - 10) * (8 / bitwidth) - tail; - // decoding - uint64_t bit_start = 0; - const uint64_t segment_size = input_size - 10; - for (uint64_t start = 0; start < output_size; start += segment_size) { - uint64_t stride = start + segment_size <= output_size ? segment_size - : output_size - start; - uint8_t mask = (1 << bitwidth) - 1; - uint64_t i = 0; - // Can process 8 elements at a time because we need to expand uint8_t - // to int32_t to use epi32 vector instructions. - constexpr int VLEN = 8; - for (; i < stride / VLEN * VLEN; i += VLEN) { - __m128i in_v = _mm_lddqu_si128( - reinterpret_cast(input_data + 10 + i)); - __m256i out_epi32_v = _mm256_and_si256( - _mm256_srli_epi32(_mm256_cvtepu8_epi32(in_v), bit_start), - _mm256_set1_epi32(mask)); - __m256 out_v = _mm256_fmadd_ps( - _mm256_cvtepi32_ps(out_epi32_v), - _mm256_set1_ps(gap), - _mm256_set1_ps(minimum_element)); - _mm256_storeu_ps(output_data + start + i, out_v); - } - for (; i < stride; ++i) { - output_data[start + i] = - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - ((input_data[10 + i] >> bit_start) & mask) * gap + minimum_element; - } - bit_start += bitwidth; - } -} - -} // namespace math -} // namespace caffe2 diff --git a/caffe2/perfkernels/math_cpu_base.cc b/caffe2/perfkernels/math_cpu_base.cc deleted file mode 100644 index fd3ba83cd4a9..000000000000 --- a/caffe2/perfkernels/math_cpu_base.cc +++ /dev/null @@ -1,168 +0,0 @@ -// Implements the math functions for CPU. -// The implementation in this file allows us to route the underlying numerical -// computation library to different compiler options (-mno-avx2 or -mavx2). - -#include -#include -#include - -#include "common.h" -// NOLINTNEXTLINE(modernize-deprecated-headers) -#include "math.h" - -#include - -using std::uint64_t; -using std::uint8_t; - -namespace caffe2 { - -namespace math { - -static constexpr double QEPSILON = 1e-8; - -void quantize_and_compress__base( - const float* input_data, - uint8_t* output_data, - uint64_t input_size, - uint64_t bitwidth, - bool random, - const float* random_buffer) { - uint64_t data_per_byte = 8 / bitwidth; - uint64_t tail = input_size % data_per_byte; - tail = tail ? data_per_byte - tail : 0; - uint64_t segment_size = (input_size + data_per_byte - 1) / data_per_byte; - - // basic info - float minimum_element = INFINITY, maximum_element = -INFINITY; - for (const auto i : c10::irange(input_size)) { - minimum_element = - input_data[i] < minimum_element ? input_data[i] : minimum_element; - maximum_element = - input_data[i] > maximum_element ? input_data[i] : maximum_element; - } - output_data[0] = bitwidth; - output_data[1] = tail; - reinterpret_cast(output_data + 2)[0] = minimum_element; - reinterpret_cast(output_data + 2)[1] = maximum_element; - - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - float gap = (maximum_element - minimum_element) / ((1 << bitwidth) - 1.0f); - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - float gap_inverse = 1. / (gap + QEPSILON); - uint8_t max_q = (1 << bitwidth) - 1; - uint64_t bit_start = 0; - if (random) { - for (uint64_t start = 0; start < input_size; start += segment_size) { - uint64_t stride = start + segment_size <= input_size ? segment_size - : input_size - start; - uint64_t i = 0; - for (; i < stride; ++i) { - float fval = input_data[start + i]; - float thetimes = (fval - minimum_element) * gap_inverse; - float rounded = floor(thetimes + random_buffer[start + i]); - rounded = rounded < static_cast(max_q) - ? rounded - : static_cast(max_q); - rounded = rounded > 0.0f ? rounded : 0.0f; - uint8_t qval = rounded; - - uint8_t orval = output_data[10 + i]; - output_data[10 + i] = orval | static_cast(qval << bit_start); - } - bit_start += bitwidth; - } - } else { - for (uint64_t start = 0; start < input_size; start += segment_size) { - uint64_t stride = start + segment_size <= input_size ? segment_size - : input_size - start; - uint64_t i = 0; - for (; i < stride; ++i) { - float fval = input_data[start + i]; - float thetimes = (fval - minimum_element) * gap_inverse; - thetimes = thetimes < static_cast(max_q) - ? thetimes - : static_cast(max_q); - thetimes = thetimes > 0.0f ? thetimes : 0.0f; - uint8_t qval = nearbyint(thetimes); - - uint8_t orval = output_data[10 + i]; - output_data[10 + i] = orval | static_cast(qval << bit_start); - } - bit_start += bitwidth; - } - } -} - -decltype(quantize_and_compress__base) quantize_and_compress__avx2; -void quantize_and_compress( - const float* input_data, - uint8_t* output_data, - uint64_t input_size, - uint64_t bitwidth, - bool random, - const float* random_buffer) { - AVX2_DO( - quantize_and_compress, - input_data, - output_data, - input_size, - bitwidth, - random, - random_buffer); - BASE_DO( - quantize_and_compress, - input_data, - output_data, - input_size, - bitwidth, - random, - random_buffer); -} - -void decompress_and_dequantize__base( - const uint8_t* input_data, - float* output_data, - uint64_t input_size) { - // basic info - const float minimum_element = - reinterpret_cast(input_data + 2)[0]; - const float maximum_element = - reinterpret_cast(input_data + 2)[1]; - const uint64_t bitwidth = input_data[0]; - const float gap = - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - (maximum_element - minimum_element) / ((1 << bitwidth) - 1.f) + - QEPSILON; // for exact recovering - - const uint64_t tail = input_data[1]; - - const uint64_t output_size = (input_size - 10) * (8 / bitwidth) - tail; - // decoding - uint64_t bit_start = 0; - const uint64_t segment_size = input_size - 10; - for (uint64_t start = 0; start < output_size; start += segment_size) { - uint64_t stride = start + segment_size <= output_size ? segment_size - : output_size - start; - uint8_t mask = (1 << bitwidth) - 1; - uint64_t i = 0; - for (; i < stride; ++i) { - output_data[start + i] = - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - ((input_data[10 + i] >> bit_start) & mask) * gap + minimum_element; - } - bit_start += bitwidth; - } -} - -decltype(decompress_and_dequantize__base) decompress_and_dequantize__avx2; -void decompress_and_dequantize( - const uint8_t* input_data, - float* output_data, - uint64_t input_size) { - AVX2_DO(decompress_and_dequantize, input_data, output_data, input_size); - BASE_DO(decompress_and_dequantize, input_data, output_data, input_size); -} - -} // namespace math -} // namespace caffe2 diff --git a/caffe2/perfkernels/typed_axpy.cc b/caffe2/perfkernels/typed_axpy.cc deleted file mode 100644 index 400041766e61..000000000000 --- a/caffe2/perfkernels/typed_axpy.cc +++ /dev/null @@ -1,88 +0,0 @@ -#include -#include "caffe2/perfkernels/typed_axpy.h" -#include "caffe2/perfkernels/common.h" - -namespace caffe2 { - -void TypedAxpy__base(int N, const float a, const float* x, float* y) { - for (int i = 0; i < N; ++i) { - y[i] += a * x[i]; - } -} - -decltype(TypedAxpy__base) TypedAxpy__avx2_fma; -decltype(TypedAxpy__base) TypedAxpy__avx_f16c; -template <> -void TypedAxpy(int N, const float a, const float* x, float* y) { - AVX2_FMA_DO(TypedAxpy, N, a, x, y); - AVX_F16C_DO(TypedAxpy, N, a, x, y); - BASE_DO(TypedAxpy, N, a, x, y); -} - -void TypedAxpyHalffloat__base( - int N, - const float a, - const at::Half* x, - float* y) { - for (int i = 0; i < N; ++i) { - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - union { - uint32_t intval; - float floatval; - } t1; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - uint32_t t2, t3; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - t1.intval = x[i].x & 0x7fff; // Non-sign bits - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - t2 = x[i].x & 0x8000; // Sign bit - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - t3 = x[i].x & 0x7c00; // Exponent - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - t1.intval <<= 13; // Align mantissa on MSB - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - t2 <<= 16; // Shift sign bit into position - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - t1.intval += 0x38000000; // Adjust bias - t1.intval = (t3 == 0 ? 0 : t1.intval); // Denormals-as-zero - t1.intval |= t2; // Re-insert sign bit - y[i] += t1.floatval * a; - } -} - -decltype(TypedAxpyHalffloat__base) TypedAxpyHalffloat__avx2_fma; -decltype(TypedAxpyHalffloat__base) TypedAxpyHalffloat__avx_f16c; -template <> -void TypedAxpy( - int N, - const float a, - const at::Half* x, - float* y) { - AVX2_FMA_DO(TypedAxpyHalffloat, N, a, x, y); - AVX_F16C_DO(TypedAxpyHalffloat, N, a, x, y); - BASE_DO(TypedAxpyHalffloat, N, a, x, y); -} - -void TypedAxpy_uint8_float__base( - int N, - const float a, - const std::uint8_t* x, - float* y) { - for (int i = 0; i < N; ++i) { - y[i] += (float)(x[i]) * a; - } -} - -decltype(TypedAxpy_uint8_float__base) TypedAxpy_uint8_float__avx2_fma; -decltype(TypedAxpy_uint8_float__base) TypedAxpy_uint8_float__avx_f16c; -template <> -void TypedAxpy( - int N, - const float a, - const std::uint8_t* x, - float* y) { - AVX2_FMA_DO(TypedAxpy_uint8_float, N, a, x, y); - BASE_DO(TypedAxpy_uint8_float, N, a, x, y); -} - -} // namespace caffe2 diff --git a/caffe2/perfkernels/typed_axpy.h b/caffe2/perfkernels/typed_axpy.h deleted file mode 100644 index 85b1adda0b9b..000000000000 --- a/caffe2/perfkernels/typed_axpy.h +++ /dev/null @@ -1,12 +0,0 @@ -#pragma once - -namespace caffe2 { - -// Similar to Axpy that calculate y = a * x + y, but allowing x and y to be -// of different data types. -// It also provides a performance optimization hint (use_a) to see if a is going -// to be 1 or not. -template -void TypedAxpy(int N, const OUT a, const IN* x, OUT* y); - -} // namespace caffe2 diff --git a/caffe2/perfkernels/typed_axpy_avx.cc b/caffe2/perfkernels/typed_axpy_avx.cc deleted file mode 100644 index 2663cbc3ec79..000000000000 --- a/caffe2/perfkernels/typed_axpy_avx.cc +++ /dev/null @@ -1,68 +0,0 @@ -#include "caffe2/perfkernels/cvtsh_ss_bugfix.h" - -#include -#include -#include - -namespace caffe2 { - -void TypedAxpy__avx_f16c(int N, const float a, const float* x, float* y) { - int current = 0; - const int bound = (N % 8) ? N - 8 : N; - __m256 mma = _mm256_set1_ps(a); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - for (; current < bound; current += 8) { - _mm256_storeu_ps( - y + current, - _mm256_add_ps( - _mm256_mul_ps(mma, _mm256_loadu_ps(x + current)), - _mm256_loadu_ps(y + current))); - } - - if (bound != N) { - while (current < N) { - y[current] += x[current] * a; - ++current; - } - } -} - -void TypedAxpyHalffloat__avx_f16c( - int N, - const float a, - const at::Half* x, - float* y) { - // if x does not start at the 16 byte boundary, we will process the first few. - // before we get to a real one. - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - while ((reinterpret_cast(x) % 16) && N) { - *(y++) += _cvtsh_ss((*(x++)).x) * a; - --N; - } - - // From now on we can do vectorized additions using __m256, which is 8 floats, - // so we will vectorize every 8 element and then resort to cvtsh_ss. - __m256 mma = _mm256_set1_ps(a); - int current = 0; - const int bound = (N % 8) ? N - 8 : N; - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - for (; current < bound; current += 8) { - __m128i mmx_16 = - _mm_loadu_si128(reinterpret_cast(x + current)); - __m256 mmx_32 = _mm256_cvtph_ps(mmx_16); - __m256 mmy_in = _mm256_loadu_ps(y + current); - __m256 mmmul = _mm256_mul_ps(mmx_32, mma); - __m256 mmy_out = _mm256_add_ps(mmmul, mmy_in); - _mm256_storeu_ps(y + current, mmy_out); - } - - if (bound != N) { - while (current < N) { - y[current] += _cvtsh_ss(x[current].x) * a; - ++current; - } - } -} - -} // namespace caffe2 diff --git a/caffe2/perfkernels/typed_axpy_avx2.cc b/caffe2/perfkernels/typed_axpy_avx2.cc deleted file mode 100644 index 2da1e7e379bd..000000000000 --- a/caffe2/perfkernels/typed_axpy_avx2.cc +++ /dev/null @@ -1,104 +0,0 @@ -#include "caffe2/perfkernels/cvtsh_ss_bugfix.h" - -#include -#include -#include - -namespace caffe2 { - -void TypedAxpy__avx2_fma(int N, const float a, const float* x, float* y) { - int current = 0; - const int bound = (N % 8) ? N - 8 : N; - __m256 mma = _mm256_set1_ps(a); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - for (; current < bound; current += 8) { - _mm256_storeu_ps( - y + current, - _mm256_fmadd_ps( - mma, _mm256_loadu_ps(x + current), _mm256_loadu_ps(y + current))); - } - - if (bound != N) { - while (current < N) { - y[current] += x[current] * a; - ++current; - } - } -} - -void TypedAxpyHalffloat__avx2_fma( - int N, - const float a, - const at::Half* x, - float* y) { - // if x does not start at the 16 byte boundary, we will process the first few. - // before we get to a real one. - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - while ((reinterpret_cast(x) % 16) && N) { - *(y++) += _cvtsh_ss((*(x++)).x) * a; - --N; - } - - // From now on we can do vectorized additions using __m256, which is 8 floats, - // so we will vectorize every 8 element and then resort to cvtsh_ss. - __m256 mma = _mm256_set1_ps(a); - int current = 0; - const int bound = (N % 8) ? N - 8 : N; - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - for (; current < bound; current += 8) { - __m128i mmx_16 = - _mm_loadu_si128(reinterpret_cast(x + current)); - __m256 mmx_32 = _mm256_cvtph_ps(mmx_16); - __m256 mmy = _mm256_loadu_ps(y + current); - mmy = _mm256_fmadd_ps(mmx_32, mma, mmy); - _mm256_storeu_ps(y + current, mmy); - } - - if (bound != N) { - while (current < N) { - y[current] += _cvtsh_ss(x[current].x) * a; - ++current; - } - } -} - -void TypedAxpy_uint8_float__avx2_fma( - int N, - const float a, - const std::uint8_t* x, - float* y) { - // if x does not start at the 16 byte boundary, we will process the first few. - // before we get to a real one. - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - while ((reinterpret_cast(x) % 16) && N) { - *(y++) += static_cast(*(x++)) * a; - --N; - } - - // From now on we can do vectorized additions using __m256, which is 8 floats, - // so we will vectorize every 8 element and then resort to cvtsh_ss. - __m256 mma = _mm256_set1_ps(a); - int current = 0; - const int bound = (N % 8) ? N - 8 : N; - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - for (; current < bound; current += 8) { - __m256i mmx_int32 = _mm256_cvtepi8_epi32( - _mm_loadu_si128(reinterpret_cast(x + current))); - __m256 mmx_fp32 = _mm256_cvtepi32_ps(mmx_int32); - - __m256 mmy = _mm256_loadu_ps(y + current); - mmy = _mm256_fmadd_ps(mmx_fp32, mma, mmy); - _mm256_storeu_ps(y + current, mmy); - } - - if (bound != N) { - while (current < N) { - y[current] += (float)(x[current]) * a; - ++current; - } - } -} - -} // namespace caffe2 diff --git a/caffe2/perfkernels/vectorizer.h b/caffe2/perfkernels/vectorizer.h deleted file mode 100644 index be4e6bbc280f..000000000000 --- a/caffe2/perfkernels/vectorizer.h +++ /dev/null @@ -1,28 +0,0 @@ -#pragma once - -#if (ENABLE_VECTORIZATION > 0) && !defined(_DEBUG) && !defined(DEBUG) -#if defined(__clang__) && (__clang_major__ > 7) -#define IS_SANITIZER \ - ((__has_feature(address_sanitizer) == 1) || \ - (__has_feature(memory_sanitizer) == 1) || \ - (__has_feature(thread_sanitizer) == 1) || \ - (__has_feature(undefined_sanitizer) == 1)) - -#if IS_SANITIZER == 0 -#define VECTOR_LOOP _Pragma("clang loop vectorize(enable)") -#define FAST_MATH _Pragma("clang fp contract(fast)") -#define VECTORIZED_KERNEL 1 -#endif -#elif defined(_OPENMP) && (_OPENMP >= 201511) -// Support with OpenMP4.5 and above -#define VECTOR_LOOP _Pragma("omp for simd") -#define VECTORIZED_KERNEL 1 -#define FAST_MATH -#endif -#endif - -#ifndef VECTOR_LOOP -// Not supported -#define VECTOR_LOOP -#define FAST_MATH -#endif From 30875953a4baeb34952cd726f25100a713821ffd Mon Sep 17 00:00:00 2001 From: cyy Date: Mon, 10 Jun 2024 23:40:45 +0000 Subject: [PATCH 604/706] [1/N] Remove inclusion of c10/util/string_utils.h (#128300) As a first step to remove it. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128300 Approved by: https://github.com/ezyang, https://github.com/eqy --- aten/src/ATen/DLConvertor.cpp | 20 +++++++++---------- aten/src/ATen/TensorIterator.cpp | 3 +-- aten/src/ATen/code_template.h | 2 +- aten/src/ATen/native/cuda/jit_utils.cpp | 2 +- aten/src/ATen/native/quantized/cpu/qconv.cpp | 14 ++++++------- .../quantized/cpu/qembeddingbag_prepack.cpp | 16 ++++++++------- .../torch/nn/modules/container/modulelist.h | 6 +++--- .../nn/modules/container/parameterlist.h | 10 +++++----- .../torch/nn/modules/container/sequential.h | 8 ++++---- torch/csrc/autograd/cpp_hook.cpp | 2 +- torch/csrc/autograd/custom_function.h | 8 ++++---- torch/csrc/distributed/rpc/agent_utils.cpp | 8 ++++---- torch/csrc/distributed/rpc/rref_context.cpp | 8 ++++---- .../csrc/distributed/rpc/tensorpipe_agent.cpp | 12 +++++------ torch/csrc/distributed/rpc/utils.cpp | 4 ++-- torch/csrc/jit/api/function_impl.cpp | 2 +- torch/csrc/jit/codegen/fuser/codegen.cpp | 14 ++++++------- torch/csrc/jit/codegen/fuser/compiler.cpp | 2 +- .../jit/frontend/function_schema_parser.cpp | 9 ++++----- torch/csrc/jit/frontend/ir_emitter.cpp | 6 +++--- torch/csrc/jit/frontend/name_mangler.cpp | 4 ++-- .../csrc/jit/frontend/schema_type_parser.cpp | 1 - torch/csrc/jit/frontend/tree_views.h | 3 +-- torch/csrc/jit/ir/ir.h | 2 +- .../mobile/compatibility/backport_manager.cpp | 2 +- torch/csrc/jit/mobile/train/export_data.cpp | 2 +- .../jit/passes/fixup_trace_scope_blocks.cpp | 2 +- .../jit/passes/hoist_conv_packed_params.cpp | 4 ++-- torch/csrc/jit/passes/onnx/peephole.cpp | 2 +- torch/csrc/jit/passes/prepack_folding.cpp | 2 +- .../passes/quantization/dedup_module_uses.cpp | 4 ++-- .../passes/quantization/insert_observers.cpp | 4 ++-- .../quantization/insert_quant_dequant.cpp | 4 ++-- .../quantization/register_packed_params.cpp | 4 ++-- .../csrc/jit/passes/utils/subgraph_utils.cpp | 2 +- torch/csrc/jit/runtime/register_ops_utils.cpp | 2 +- torch/csrc/jit/runtime/register_ops_utils.h | 1 - torch/csrc/jit/runtime/static/impl.cpp | 8 ++++---- torch/csrc/jit/tensorexpr/block_codegen.cpp | 2 +- torch/csrc/jit/tensorexpr/eval.cpp | 8 ++++---- torch/csrc/jit/tensorexpr/eval.h | 1 - torch/csrc/jit/tensorexpr/ir.cpp | 2 +- torch/csrc/jit/tensorexpr/ir.h | 3 +-- torch/csrc/jit/tensorexpr/kernel.cpp | 7 +++---- torch/csrc/jit/tensorexpr/loopnest.cpp | 3 +-- torch/csrc/jit/tensorexpr/operators/misc.cpp | 2 +- torch/csrc/jit/tensorexpr/registerizer.cpp | 2 +- .../jit/tensorexpr/unique_name_manager.cpp | 3 +-- 48 files changed, 116 insertions(+), 126 deletions(-) diff --git a/aten/src/ATen/DLConvertor.cpp b/aten/src/ATen/DLConvertor.cpp index 3d2350d26101..6fb966f66713 100644 --- a/aten/src/ATen/DLConvertor.cpp +++ b/aten/src/ATen/DLConvertor.cpp @@ -143,7 +143,7 @@ static Device getATenDevice(const DLDevice& ctx, void* data) { return at::detail::getXPUHooks().getDeviceFromPtr(data); default: TORCH_CHECK( - false, "Unsupported device_type: " + c10::to_string(ctx.device_type)); + false, "Unsupported device_type: ", std::to_string(ctx.device_type)); } } @@ -167,7 +167,7 @@ ScalarType toScalarType(const DLDataType& dtype) { break; default: TORCH_CHECK( - false, "Unsupported kUInt bits " + c10::to_string(dtype.bits)); + false, "Unsupported kUInt bits ", std::to_string(dtype.bits)); } break; case DLDataTypeCode::kDLInt: @@ -186,7 +186,7 @@ ScalarType toScalarType(const DLDataType& dtype) { break; default: TORCH_CHECK( - false, "Unsupported kInt bits " + c10::to_string(dtype.bits)); + false, "Unsupported kInt bits ", std::to_string(dtype.bits)); } break; case DLDataTypeCode::kDLFloat: @@ -202,7 +202,7 @@ ScalarType toScalarType(const DLDataType& dtype) { break; default: TORCH_CHECK( - false, "Unsupported kFloat bits " + c10::to_string(dtype.bits)); + false, "Unsupported kFloat bits ", std::to_string(dtype.bits)); } break; case DLDataTypeCode::kDLBfloat: @@ -212,7 +212,7 @@ ScalarType toScalarType(const DLDataType& dtype) { break; default: TORCH_CHECK( - false, "Unsupported kFloat bits " + c10::to_string(dtype.bits)); + false, "Unsupported kFloat bits ", std::to_string(dtype.bits)); } break; case DLDataTypeCode::kDLComplex: @@ -228,7 +228,7 @@ ScalarType toScalarType(const DLDataType& dtype) { break; default: TORCH_CHECK( - false, "Unsupported kFloat bits " + c10::to_string(dtype.bits)); + false, "Unsupported kFloat bits ", std::to_string(dtype.bits)); } break; case DLDataTypeCode::kDLBool: @@ -238,11 +238,11 @@ ScalarType toScalarType(const DLDataType& dtype) { break; default: TORCH_CHECK( - false, "Unsupported kDLBool bits " + c10::to_string(dtype.bits)); + false, "Unsupported kDLBool bits ", std::to_string(dtype.bits)); } break; default: - TORCH_CHECK(false, "Unsupported code " + c10::to_string(dtype.code)); + TORCH_CHECK(false, "Unsupported code ", std::to_string(dtype.code)); } return stype; } @@ -298,9 +298,7 @@ Tensor fromDLPack(DLManagedTensor* src) { return fromDLPack(src, std::move(deleter)); } -Tensor fromDLPack( - DLManagedTensor* src, - std::function deleter) { +Tensor fromDLPack(DLManagedTensor* src, std::function deleter) { Device device = getATenDevice(src->dl_tensor.device, src->dl_tensor.data); ScalarType stype = toScalarType(src->dl_tensor.dtype); if (!src->dl_tensor.strides) { diff --git a/aten/src/ATen/TensorIterator.cpp b/aten/src/ATen/TensorIterator.cpp index c4a68a33e306..ecc90ace61e6 100644 --- a/aten/src/ATen/TensorIterator.cpp +++ b/aten/src/ATen/TensorIterator.cpp @@ -22,7 +22,6 @@ #endif #include -#include #include #include @@ -1398,7 +1397,7 @@ bool TensorIteratorBase::fast_set_up(const TensorIteratorConfig& config) { break; } default: - TORCH_INTERNAL_ASSERT(false, "Unsupported fast setup type", c10::to_string((int)setup_type)); + TORCH_INTERNAL_ASSERT(false, "Unsupported fast setup type", std::to_string((int)setup_type)); } //coalescing dimensions consists of collapsing dimensions to 1 (we are limited to contiguous no-broadcast cases here) if (ndim() > 1){ diff --git a/aten/src/ATen/code_template.h b/aten/src/ATen/code_template.h index 393e322e6fe6..ebf113e9d226 100644 --- a/aten/src/ATen/code_template.h +++ b/aten/src/ATen/code_template.h @@ -31,7 +31,7 @@ struct TemplateEnv { // Add a number 'v' to the map at key 'k' template void d(const std::string& k, const T& v) { - strings_[k] = c10::to_string(v); + strings_[k] = std::to_string(v); lists_.erase(k); } diff --git a/aten/src/ATen/native/cuda/jit_utils.cpp b/aten/src/ATen/native/cuda/jit_utils.cpp index 0d870cef5870..67b8d3e54ba5 100644 --- a/aten/src/ATen/native/cuda/jit_utils.cpp +++ b/aten/src/ATen/native/cuda/jit_utils.cpp @@ -1002,7 +1002,7 @@ std::string generate_code( std::string extra_args = ""; for (size_t i = 0; i < extra_args_typenames.size(); i++) { auto type = std::string(extra_args_typenames[i]); - auto name = "extra_arg_" + std::string(to_string(i)); + auto name = "extra_arg_" + std::to_string(i); extra_params += "," + type + " " + name; extra_args += ", " + name; } diff --git a/aten/src/ATen/native/quantized/cpu/qconv.cpp b/aten/src/ATen/native/quantized/cpu/qconv.cpp index 6b9cbc4a92c1..25b9b2b4e92c 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv.cpp @@ -1,6 +1,7 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include #include +#include #include #include @@ -35,7 +36,6 @@ #endif #include -#include namespace { // To have a sanity check for maximum matrix size. @@ -1848,15 +1848,15 @@ class QConvInt8ForBC final { int64_t output_zero_point) { if (kReluFused) { TORCH_WARN_ONCE( - "Arguments [stride, padding, dilation, groups] in ops.quantized.conv" - + c10::to_string(kSpatialDim) + "d_relu, " + - "have been removed, please update your model to remove these arguments."); + "Arguments [stride, padding, dilation, groups] in ops.quantized.conv" + + std::to_string(kSpatialDim), + "d_relu, have been removed, please update your model to remove these arguments."); return packed_weight->apply_relu(act, output_scale, output_zero_point); } else { TORCH_WARN_ONCE( - "Arguments [stride, padding, dilation, groups] in ops.quantized.conv" - + c10::to_string(kSpatialDim) + "d, " + - "have been removed, please update your model to remove these arguments."); + "Arguments [stride, padding, dilation, groups] in ops.quantized.conv", + std::to_string(kSpatialDim), + "d, have been removed, please update your model to remove these arguments."); return packed_weight->apply(act, output_scale, output_zero_point); } } diff --git a/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp index 9cfbce72e31d..4b9c8ea2bdc9 100644 --- a/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp @@ -342,7 +342,10 @@ Tensor qembeddingbag_byte_prepack_meta(const Tensor& weight) { output_shape[cols_dim] = output_columns; at::SymDimVector output_shape_vec(output_shape); - return at::empty_symint(output_shape_vec, weight.options().dtype(weight.scalar_type()), weight.suggest_memory_format()); + return at::empty_symint( + output_shape_vec, + weight.options().dtype(weight.scalar_type()), + weight.suggest_memory_format()); } namespace { @@ -373,9 +376,10 @@ Tensor _qembeddingbag_nbit_prepack_helper( int NUM_ELEM_PER_BYTE = 8 / bit_width; TORCH_CHECK( weight_contig.size(weight.dim() - 1) % NUM_ELEM_PER_BYTE == 0, - "qembeddingbag_" + c10::to_string(bit_width) + - "bit_prepack only works for the number of columns a multiple of " + - c10::to_string(NUM_ELEM_PER_BYTE)); + "qembeddingbag_", + std::to_string(bit_width), + "bit_prepack only works for the number of columns a multiple of ", + std::to_string(NUM_ELEM_PER_BYTE)); // The "fused" representation stores the scale and bias with the // row-wise quantized data in one tensor. @@ -551,11 +555,9 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) { TORCH_FN(QEmbeddingPackWeights::run)); } - TORCH_LIBRARY_IMPL(quantized, Meta, m) { m.impl( - "quantized::embedding_bag_byte_prepack", - qembeddingbag_byte_prepack_meta); + "quantized::embedding_bag_byte_prepack", qembeddingbag_byte_prepack_meta); } } // namespace diff --git a/torch/csrc/api/include/torch/nn/modules/container/modulelist.h b/torch/csrc/api/include/torch/nn/modules/container/modulelist.h index 72a76163ac03..683b6416b04f 100644 --- a/torch/csrc/api/include/torch/nn/modules/container/modulelist.h +++ b/torch/csrc/api/include/torch/nn/modules/container/modulelist.h @@ -91,7 +91,7 @@ class ModuleListImpl : public Cloneable { void push_back(std::shared_ptr module) { modules_.push_back(std::move(module)); const auto index = modules_.size() - 1; - register_module(c10::to_string(index), modules_[index]); + register_module(std::to_string(index), modules_[index]); } /// Adds a new `Module` to the `ModuleList` container, moving or copying @@ -224,9 +224,9 @@ class ModuleListImpl : public Cloneable { for (const auto i : c10::irange(index, size() - 1)) { (void)i; // Suppress unused variable warning - replace_module(c10::to_string(index), modules_[index]); + replace_module(std::to_string(index), modules_[index]); } - register_module(c10::to_string(size() - 1), modules_.back()); + register_module(std::to_string(size() - 1), modules_.back()); } } diff --git a/torch/csrc/api/include/torch/nn/modules/container/parameterlist.h b/torch/csrc/api/include/torch/nn/modules/container/parameterlist.h index 30b7eb89e48b..cb816d1bb2a1 100644 --- a/torch/csrc/api/include/torch/nn/modules/container/parameterlist.h +++ b/torch/csrc/api/include/torch/nn/modules/container/parameterlist.h @@ -50,14 +50,14 @@ class ParameterListImpl : public Cloneable { void append(torch::Tensor&& param) { bool requires_grad = param.requires_grad(); register_parameter( - c10::to_string(parameters_.size()), std::move(param), requires_grad); + std::to_string(parameters_.size()), std::move(param), requires_grad); } /// push the a given parameter at the end of the list void append(const torch::Tensor& param) { bool requires_grad = param.requires_grad(); register_parameter( - c10::to_string(parameters_.size()), param, requires_grad); + std::to_string(parameters_.size()), param, requires_grad); } /// push the a given parameter at the end of the list @@ -65,7 +65,7 @@ class ParameterListImpl : public Cloneable { /// will be added into the `ParameterList` void append(const OrderedDict::Item& pair) { register_parameter( - c10::to_string(parameters_.size()), + std::to_string(parameters_.size()), pair.value(), pair.value().requires_grad()); } @@ -111,7 +111,7 @@ class ParameterListImpl : public Cloneable { /// for a non-throwing way of access at::Tensor& at(size_t idx) { TORCH_CHECK(idx < size(), "Index out of range"); - return parameters_[c10::to_string(idx)]; + return parameters_[std::to_string(idx)]; } /// Returns the value associated with the given `key`. Throws an exception if @@ -119,7 +119,7 @@ class ParameterListImpl : public Cloneable { /// for a non-throwing way of access const at::Tensor& at(size_t idx) const { TORCH_CHECK(idx < size(), "Index out of range"); - return parameters_[c10::to_string(idx)]; + return parameters_[std::to_string(idx)]; } /// Returns the value associated with the given `key`. Throws an exception if diff --git a/torch/csrc/api/include/torch/nn/modules/container/sequential.h b/torch/csrc/api/include/torch/nn/modules/container/sequential.h index 4007e2cfd801..acefa23d49e5 100644 --- a/torch/csrc/api/include/torch/nn/modules/container/sequential.h +++ b/torch/csrc/api/include/torch/nn/modules/container/sequential.h @@ -195,7 +195,7 @@ class SequentialImpl : public Cloneable { /// Adds a new (boxed) `Module` to the `Sequential` container. template void push_back(std::shared_ptr module_ptr) { - push_back(c10::to_string(modules_.size()), std::move(module_ptr)); + push_back(std::to_string(modules_.size()), std::move(module_ptr)); } /// Adds a new named (boxed) `Module` to the `Sequential` container. @@ -211,7 +211,7 @@ class SequentialImpl : public Cloneable { /// `Sequential(std::make_shared(3, 4))`. template > void push_back(M&& module) { - push_back(c10::to_string(modules_.size()), std::forward(module)); + push_back(std::to_string(modules_.size()), std::forward(module)); } /// Adds a new named `Module` to the `Sequential` container, moving or copying @@ -227,7 +227,7 @@ class SequentialImpl : public Cloneable { /// `Sequential`. template void push_back(const ModuleHolder& module_holder) { - push_back(c10::to_string(modules_.size()), module_holder); + push_back(std::to_string(modules_.size()), module_holder); } /// Unwraps the contained named module of a `ModuleHolder` and adds it to the @@ -247,7 +247,7 @@ class SequentialImpl : public Cloneable { /// Adds a type-erased `AnyModule` to the `Sequential`. void push_back(AnyModule any_module) { - push_back(c10::to_string(modules_.size()), std::move(any_module)); + push_back(std::to_string(modules_.size()), std::move(any_module)); } void push_back(std::string name, AnyModule any_module) { diff --git a/torch/csrc/autograd/cpp_hook.cpp b/torch/csrc/autograd/cpp_hook.cpp index 36f4671ee2e6..b851078b5280 100644 --- a/torch/csrc/autograd/cpp_hook.cpp +++ b/torch/csrc/autograd/cpp_hook.cpp @@ -41,7 +41,7 @@ variable_list CppFunctionTensorPreHook::operator()( // Don't change gradient continue; } - check_single_result(value, res, c10::to_string(i)); + check_single_result(value, res, std::to_string(i)); value = std::move(res); } variable_list results(values); diff --git a/torch/csrc/autograd/custom_function.h b/torch/csrc/autograd/custom_function.h index 8c20bd807820..aed3eaa3e558 100644 --- a/torch/csrc/autograd/custom_function.h +++ b/torch/csrc/autograd/custom_function.h @@ -444,8 +444,8 @@ variable_list CppNode::apply(variable_list&& inputs) { if (num_outputs != num_forward_inputs) { std::string msg("function "); msg += name() + " returned an incorrect number of gradients (expected "; - msg += c10::to_string(num_forward_inputs) + ", got "; - msg += c10::to_string(num_outputs) + ")"; + msg += std::to_string(num_forward_inputs) + ", got "; + msg += std::to_string(num_outputs) + ")"; throw std::runtime_error(msg); } @@ -458,8 +458,8 @@ variable_list CppNode::apply(variable_list&& inputs) { std::string msg("function "); msg += name() + " returned a gradient different that is defined at position "; - msg += c10::to_string(i + 1) + - ", but the corresponding forward input was not a Variable"; + msg += std::to_string(i + 1) + + ", std the corresponding forward input was not a Variable"; throw std::runtime_error(msg); } continue; diff --git a/torch/csrc/distributed/rpc/agent_utils.cpp b/torch/csrc/distributed/rpc/agent_utils.cpp index 8eaae18cb209..89cb878755d9 100644 --- a/torch/csrc/distributed/rpc/agent_utils.cpp +++ b/torch/csrc/distributed/rpc/agent_utils.cpp @@ -13,7 +13,7 @@ std::unordered_map collectNames( std::vector selfNameVector( (uint8_t*)selfName.c_str(), (uint8_t*)selfName.c_str() + selfName.length()); - store.set(c10::to_string(selfId), selfNameVector); + store.set(std::to_string(selfId), selfNameVector); std::unordered_map nameToId; nameToId.reserve(worldSize); @@ -22,7 +22,7 @@ std::unordered_map collectNames( if (workerId == selfId) { continue; } - std::vector workerNameVector = store.get(c10::to_string(workerId)); + std::vector workerNameVector = store.get(std::to_string(workerId)); std::string workerName( (char*)workerNameVector.data(), workerNameVector.size()); @@ -69,7 +69,7 @@ std::unordered_map collectCurrentNames( // Check that ID does not already exist and set {ID : NAME} std::vector resultVector = store.compareSet( - c10::to_string(selfId), std::vector(), selfNameVector); + std::to_string(selfId), std::vector(), selfNameVector); TORCH_CHECK( resultVector == selfNameVector, "RPC worker id ", @@ -80,7 +80,7 @@ std::unordered_map collectCurrentNames( selfNameVector, " cannot be added."); - store.set(c10::to_string(selfId), selfNameVector); + store.set(std::to_string(selfId), selfNameVector); std::unordered_map nameToId; nameToId.emplace(selfName, selfId); diff --git a/torch/csrc/distributed/rpc/rref_context.cpp b/torch/csrc/distributed/rpc/rref_context.cpp index 73b66f954541..bba751e08917 100644 --- a/torch/csrc/distributed/rpc/rref_context.cpp +++ b/torch/csrc/distributed/rpc/rref_context.cpp @@ -143,10 +143,10 @@ std::unordered_map RRefContext::getDebugInfo() { numForks += owner.second.size(); } lock.unlock(); - info[kNumOwnerRRefs] = c10::to_string(ownerSize); - info[kNumPendingFutures] = c10::to_string(numPendingFutures_.load()); - info[kNumPendingUsers] = c10::to_string(numPendingUsers); - info[kNumForks] = c10::to_string(numForks); + info[kNumOwnerRRefs] = std::to_string(ownerSize); + info[kNumPendingFutures] = std::to_string(numPendingFutures_.load()); + info[kNumPendingUsers] = std::to_string(numPendingUsers); + info[kNumForks] = std::to_string(numForks); return info; } diff --git a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp index 8af4336c0746..2de6bacb7ee4 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp +++ b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp @@ -1279,13 +1279,13 @@ void TensorPipeAgent::updateGroupMembership( } std::unordered_map TensorPipeAgent::getMetrics() { std::unordered_map metrics; - metrics[kThreadPoolSize] = c10::to_string(threadPool_.size()); - metrics[kNumIdleThreads] = c10::to_string(threadPool_.numAvailable()); + metrics[kThreadPoolSize] = std::to_string(threadPool_.size()); + metrics[kNumIdleThreads] = std::to_string(threadPool_.numAvailable()); { std::unique_lock lock(callCountMutex_); - metrics[kClientActiveCalls] = c10::to_string(clientActiveCalls_); - metrics[kServerActiveCalls] = c10::to_string(serverActiveCalls_); - metrics[kServerActiveAsyncCalls] = c10::to_string(serverActiveAsyncCalls_); + metrics[kClientActiveCalls] = std::to_string(clientActiveCalls_); + metrics[kServerActiveCalls] = std::to_string(serverActiveCalls_); + metrics[kServerActiveAsyncCalls] = std::to_string(serverActiveAsyncCalls_); } if (isGILProfilingEnabled()) { { @@ -1295,7 +1295,7 @@ std::unordered_map TensorPipeAgent::getMetrics() { auto averageGilWaitTime = timeSeriesMetrics_[kGilAverageWaitTime].computeAverage(); lock.unlock(); - metrics[kGilAverageWaitTime] = c10::to_string(averageGilWaitTime); + metrics[kGilAverageWaitTime] = std::to_string(averageGilWaitTime); } } diff --git a/torch/csrc/distributed/rpc/utils.cpp b/torch/csrc/distributed/rpc/utils.cpp index 822079b12ecf..bde9d1ad61ad 100644 --- a/torch/csrc/distributed/rpc/utils.cpp +++ b/torch/csrc/distributed/rpc/utils.cpp @@ -389,7 +389,7 @@ std::string wireSerialize( // out of scope of this loop. auto writeableTensorData = jit::getWriteableTensorData(tensorData[i]); entries.push_back( - {c10::to_string(i), + {std::to_string(i), writeableTensorData.data(), writeableTensorData.sizeInBytes()}); } @@ -401,7 +401,7 @@ std::string wireSerialize( tot += e.size; header.append(e.name) .append(" ") - .append(c10::to_string(e.size)) + .append(std::to_string(e.size)) .append("\n"); } header.push_back('\n'); diff --git a/torch/csrc/jit/api/function_impl.cpp b/torch/csrc/jit/api/function_impl.cpp index c0f0b4e486b4..5f25ce51702a 100644 --- a/torch/csrc/jit/api/function_impl.cpp +++ b/torch/csrc/jit/api/function_impl.cpp @@ -28,7 +28,7 @@ c10::FunctionSchema defaultSchemaFor(const GraphFunction& function) { for (const auto i : c10::irange(num_inputs)) { const Value* v = g.inputs().at(i); std::string name = v->hasDebugName() ? v->debugNameBase() - : ("argument_" + c10::to_string(i)); + : ("argument_" + std::to_string(i)); args.emplace_back(std::move(name), unshapedType(g.inputs()[i]->type())); } for (const auto i : c10::irange(g.outputs().size())) { diff --git a/torch/csrc/jit/codegen/fuser/codegen.cpp b/torch/csrc/jit/codegen/fuser/codegen.cpp index 2f9217e13369..a2d26979c1e0 100644 --- a/torch/csrc/jit/codegen/fuser/codegen.cpp +++ b/torch/csrc/jit/codegen/fuser/codegen.cpp @@ -30,15 +30,15 @@ size_t ${tensor}_dimIndex${d} = ${tensor}_linearIndex ${mod_sizes}; )"); static std::string valueName(const Value* n) { - return "n" + c10::to_string(n->unique()); + return "n" + std::to_string(n->unique()); } static std::string scalarValue(const int64_t v) { - return c10::to_string(v); + return std::to_string(v); } static std::string scalarValue(const bool v) { - return c10::to_string(v); + return std::to_string(v); } // Note: The NAN, NEG_INFINITY and POS_INFINITY strings map to device-specific @@ -274,10 +274,10 @@ static std::string encodeRHS(const Node* n) { // PyTorch converts (scalar) argument types to result before applying the // operator e.g. 1.4-torch.tensor(3) = -2 env.s( - c10::to_string(i), + std::to_string(i), typeCastedValueName(*in->type(), *outtype, valueName(in))); // Uncasted operands only used for comparison operators - env.s(c10::to_string(i) + "_nocast", valueName(in)); + env.s(std::to_string(i) + "_nocast", valueName(in)); i++; } @@ -391,7 +391,7 @@ std::string generateKernel( 1); // + 1 because the first argument is the linearIndex std::string tensor = "t" + - c10::to_string( + std::to_string( formals.size()); // can't be unique() because Param may be an output const auto nDim = desc.nDim(); emitCheckFor(tensorChecks, tensor, nDim, desc); @@ -413,7 +413,7 @@ std::string generateKernel( 1); // + 1 because the first argument is the linearIndex std::string scalar = "s" + - c10::to_string( + std::to_string( formals.size()); // can't be unique() because Param may be an output env.d( "formal_index", diff --git a/torch/csrc/jit/codegen/fuser/compiler.cpp b/torch/csrc/jit/codegen/fuser/compiler.cpp index 3c05b70e8341..b4bc3e8f4727 100644 --- a/torch/csrc/jit/codegen/fuser/compiler.cpp +++ b/torch/csrc/jit/codegen/fuser/compiler.cpp @@ -281,7 +281,7 @@ std::shared_ptr compileKernel( } const bool use_cuda = device.is_cuda(); - const std::string name = "kernel_" + c10::to_string(next_kernel_id++); + const std::string name = "kernel_" + std::to_string(next_kernel_id++); std::string code = generateKernel(name, *graph, flat_inputs, flat_outputs, use_cuda); const FusedKernelConstructor& kernel_ctor = diff --git a/torch/csrc/jit/frontend/function_schema_parser.cpp b/torch/csrc/jit/frontend/function_schema_parser.cpp index 13497c20e15c..ba86a891d31d 100644 --- a/torch/csrc/jit/frontend/function_schema_parser.cpp +++ b/torch/csrc/jit/frontend/function_schema_parser.cpp @@ -3,7 +3,6 @@ #include #include #include -#include #include #include #include @@ -157,7 +156,7 @@ struct SchemaParser { // note: an array with a size hint can only occur at the Argument level fake_type = ListType::create(std::move(fake_type)); real_type = ListType::create(std::move(real_type)); - N = c10::stoll(L.expect(TK_NUMBER).text()); + N = std::stoll(L.expect(TK_NUMBER).text()); L.expect(']'); auto container = type_parser.parseAliasAnnotation(); if (alias_info) { @@ -244,14 +243,14 @@ struct SchemaParser { n = L.expect(TK_NUMBER).text(); if (kind == TypeKind::ComplexType || n.find('j') != std::string::npos) { - auto imag = c10::stod(n.substr(0, n.size() - 1)); + auto imag = std::stod(n.substr(0, n.size() - 1)); return c10::complex(0, imag); } else if ( kind == TypeKind::FloatType || n.find('.') != std::string::npos || n.find('e') != std::string::npos) { - return c10::stod(n); + return std::stod(n); } else { - int64_t v = c10::stoll(n); + int64_t v = std::stoll(n); return v; } } diff --git a/torch/csrc/jit/frontend/ir_emitter.cpp b/torch/csrc/jit/frontend/ir_emitter.cpp index 0aca3ea80062..350305b83567 100644 --- a/torch/csrc/jit/frontend/ir_emitter.cpp +++ b/torch/csrc/jit/frontend/ir_emitter.cpp @@ -722,7 +722,7 @@ struct to_ir { std::vector def_stack_; size_t temp_name_count_ = 0; std::string createTempName(const std::string& prefix) { - return prefix + c10::to_string(temp_name_count_++); + return prefix + std::to_string(temp_name_count_++); } void pushFrame(Block* b, bool starts_def = false) { @@ -3222,7 +3222,7 @@ struct to_ir { case TK_IN: return aten::__contains__; default: - throw std::runtime_error("unknown kind " + c10::to_string(kind)); + throw std::runtime_error("unknown kind " + std::to_string(kind)); } } @@ -3269,7 +3269,7 @@ struct to_ir { case TK_RSHIFT: return "__rshift__"; default: - throw std::runtime_error("unknown kind " + c10::to_string(kind)); + throw std::runtime_error("unknown kind " + std::to_string(kind)); } } diff --git a/torch/csrc/jit/frontend/name_mangler.cpp b/torch/csrc/jit/frontend/name_mangler.cpp index fbf1d24932e8..698bdd1e67b7 100644 --- a/torch/csrc/jit/frontend/name_mangler.cpp +++ b/torch/csrc/jit/frontend/name_mangler.cpp @@ -21,7 +21,7 @@ c10::QualifiedName NameMangler::mangle(const c10::QualifiedName& name) { // Append the part of the name up to the end of the prefix newAtomPrefix.append(atom, 0, pos); newAtomPrefix.append(manglePrefix); - atom = newAtomPrefix + c10::to_string(mangleIndex_++); + atom = newAtomPrefix + std::to_string(mangleIndex_++); // increment mangleIndex_ until the type is not defined return c10::QualifiedName(atoms); } @@ -29,7 +29,7 @@ c10::QualifiedName NameMangler::mangle(const c10::QualifiedName& name) { // Otherwise add a mangle namespace right before the basename TORCH_INTERNAL_ASSERT(!atoms.empty()); - atoms.insert(atoms.end() - 1, manglePrefix + c10::to_string(mangleIndex_++)); + atoms.insert(atoms.end() - 1, manglePrefix + std::to_string(mangleIndex_++)); return c10::QualifiedName(atoms); } diff --git a/torch/csrc/jit/frontend/schema_type_parser.cpp b/torch/csrc/jit/frontend/schema_type_parser.cpp index b81a6c720770..2adacb976a04 100644 --- a/torch/csrc/jit/frontend/schema_type_parser.cpp +++ b/torch/csrc/jit/frontend/schema_type_parser.cpp @@ -4,7 +4,6 @@ #include #include #include -#include #include #include #include diff --git a/torch/csrc/jit/frontend/tree_views.h b/torch/csrc/jit/frontend/tree_views.h index a6488c92f406..77d06bee94a9 100644 --- a/torch/csrc/jit/frontend/tree_views.h +++ b/torch/csrc/jit/frontend/tree_views.h @@ -1,5 +1,4 @@ #pragma once -#include #include #include #include @@ -1032,7 +1031,7 @@ struct SliceExpr : public Expr { private: Expr createInt(int64_t value) const { - return Expr(Const::create(range(), c10::to_string(value))); + return Expr(Const::create(range(), std::to_string(value))); } }; diff --git a/torch/csrc/jit/ir/ir.h b/torch/csrc/jit/ir/ir.h index 549f4a11001f..859da3cb3cae 100644 --- a/torch/csrc/jit/ir/ir.h +++ b/torch/csrc/jit/ir/ir.h @@ -224,7 +224,7 @@ struct Value { if (hasDebugName()) { return unique_name_; } - return c10::to_string(unique()); + return std::to_string(unique()); } TORCH_API std::string debugNameBase() const; Node* node() { diff --git a/torch/csrc/jit/mobile/compatibility/backport_manager.cpp b/torch/csrc/jit/mobile/compatibility/backport_manager.cpp index 09c5df58f0be..f0dd562cc1cd 100644 --- a/torch/csrc/jit/mobile/compatibility/backport_manager.cpp +++ b/torch/csrc/jit/mobile/compatibility/backport_manager.cpp @@ -348,7 +348,7 @@ std::stringstream backport_v5_to_v4(std::stringstream& input_model_stream) { for (const auto& td : data_pickle.tensorData()) { WriteableTensorData writable_td = getWriteableTensorData(td); - std::string fname = prefix + c10::to_string(i++); + std::string fname = prefix + std::to_string(i++); writer.writeRecord(fname, writable_td.data(), writable_td.sizeInBytes()); } std::string fname = archive_name + ".pkl"; diff --git a/torch/csrc/jit/mobile/train/export_data.cpp b/torch/csrc/jit/mobile/train/export_data.cpp index 731ffef15424..aeb9f95dad67 100644 --- a/torch/csrc/jit/mobile/train/export_data.cpp +++ b/torch/csrc/jit/mobile/train/export_data.cpp @@ -61,7 +61,7 @@ class IValuePickler final { std::string prefix = archive_name + "/"; for (const auto& td : data_pickle.tensorData()) { WriteableTensorData writable_td = getWriteableTensorData(td); - std::string fname = prefix + c10::to_string(i++); + std::string fname = prefix + std::to_string(i++); writer_.writeRecord(fname, writable_td.data(), writable_td.sizeInBytes()); } std::string fname = archive_name + ".pkl"; diff --git a/torch/csrc/jit/passes/fixup_trace_scope_blocks.cpp b/torch/csrc/jit/passes/fixup_trace_scope_blocks.cpp index 6f1aa4aee308..b4c0fd053511 100644 --- a/torch/csrc/jit/passes/fixup_trace_scope_blocks.cpp +++ b/torch/csrc/jit/passes/fixup_trace_scope_blocks.cpp @@ -388,7 +388,7 @@ std::string mangleMethodName( for (size_t method_idx = 0;; method_idx++) { auto mangled = method_name; if (method_idx != 0) { - mangled += c10::to_string(method_idx); + mangled += std::to_string(method_idx); } bool found = false; for (Function* fn : mod_type->methods()) { diff --git a/torch/csrc/jit/passes/hoist_conv_packed_params.cpp b/torch/csrc/jit/passes/hoist_conv_packed_params.cpp index c3db2373f2a3..5034626923b5 100644 --- a/torch/csrc/jit/passes/hoist_conv_packed_params.cpp +++ b/torch/csrc/jit/passes/hoist_conv_packed_params.cpp @@ -64,10 +64,10 @@ static void hoistConvPackedParams( } std::string newNameBase = prefix + "." + suffix + "_packed_params"; nameUniqueCounter++; - std::string newName = newNameBase + "." + c10::to_string(nameUniqueCounter); + std::string newName = newNameBase + "." + std::to_string(nameUniqueCounter); while (rootModule.hasattr(newName)) { nameUniqueCounter++; - newName = newNameBase + "." + c10::to_string(nameUniqueCounter); + newName = newNameBase + "." + std::to_string(nameUniqueCounter); } // copy the packed params diff --git a/torch/csrc/jit/passes/onnx/peephole.cpp b/torch/csrc/jit/passes/onnx/peephole.cpp index 73c19851e569..b468e739a03f 100644 --- a/torch/csrc/jit/passes/onnx/peephole.cpp +++ b/torch/csrc/jit/passes/onnx/peephole.cpp @@ -710,7 +710,7 @@ static void eraseListUnpack(Node* n, int opset_version) { // onnx::SequenceAt was introduced in onnx opset version 11 throw std::runtime_error( "Unsupported: ONNX export of prim::ListUnpack in opset " + - c10::to_string(opset_version) + ". Please try opset version 11."); + std::to_string(opset_version) + ". Please try opset version 11."); } auto g = n->owningGraph(); diff --git a/torch/csrc/jit/passes/prepack_folding.cpp b/torch/csrc/jit/passes/prepack_folding.cpp index 1c7372e23633..d37201c5b3d5 100644 --- a/torch/csrc/jit/passes/prepack_folding.cpp +++ b/torch/csrc/jit/passes/prepack_folding.cpp @@ -30,7 +30,7 @@ void PrePackingOpsFolder( if (optional_outputs) { auto outputs = optional_outputs.value(); TORCH_CHECK(outputs.size() == 1, "Prepack ops have single output"); - auto attr_name = attr_name_base + c10::to_string(uid++); + auto attr_name = attr_name_base + std::to_string(uid++); TORCH_CHECK( !(m.type()->findAttributeSlot(attr_name)), "Attribute name ", diff --git a/torch/csrc/jit/passes/quantization/dedup_module_uses.cpp b/torch/csrc/jit/passes/quantization/dedup_module_uses.cpp index 65e900d3888a..2c83bcbc10e1 100644 --- a/torch/csrc/jit/passes/quantization/dedup_module_uses.cpp +++ b/torch/csrc/jit/passes/quantization/dedup_module_uses.cpp @@ -97,9 +97,9 @@ class ModuleUseDeduper { // Original name of the child module const std::string& original_name = path[path.size() - 1]; int uid = 0; - std::string child_name = original_name + "_" + c10::to_string(uid++); + std::string child_name = original_name + "_" + std::to_string(uid++); while (parent_of_leaf.hasattr(child_name)) { - child_name = original_name + "_" + c10::to_string(uid++); + child_name = original_name + "_" + std::to_string(uid++); } parent_of_leaf.register_module(child_name, child_module.deepcopy()); return child_name; diff --git a/torch/csrc/jit/passes/quantization/insert_observers.cpp b/torch/csrc/jit/passes/quantization/insert_observers.cpp index de1cff1ba9d1..145448210958 100644 --- a/torch/csrc/jit/passes/quantization/insert_observers.cpp +++ b/torch/csrc/jit/passes/quantization/insert_observers.cpp @@ -953,9 +953,9 @@ void InsertObserversHelper::insertObserverFor( } GRAPH_DEBUG("Inserting observer for:", v->debugName()); Module observer = observer_module.deepcopy(); - std::string observer_name = "_observer_" + c10::to_string(uid_++); + std::string observer_name = "_observer_" + std::to_string(uid_++); while (module.hasattr(observer_name)) { - observer_name = "_observer_" + c10::to_string(uid_++); + observer_name = "_observer_" + std::to_string(uid_++); } module.register_module(observer_name, observer); observer_name_and_modules.emplace_back(observer_name, observer); diff --git a/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp b/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp index 02f4f1096976..92fb2fc79bcc 100644 --- a/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp +++ b/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp @@ -1042,10 +1042,10 @@ void InsertQuantDeQuantHelper::quantizeTensors( const auto& qparam = pr.second; size_t uid = 0; auto qparam_name = - original_value->debugName() + name + "_" + c10::to_string(uid++); + original_value->debugName() + name + "_" + std::to_string(uid++); while (module.hasattr(qparam_name)) { qparam_name = - original_value->debugName() + name + "_" + c10::to_string(uid++); + original_value->debugName() + name + "_" + std::to_string(uid++); } qparam_name_map_for_node_[n][name] = qparam_name; module.register_attribute(qparam_name, qparam.type(), qparam); diff --git a/torch/csrc/jit/passes/quantization/register_packed_params.cpp b/torch/csrc/jit/passes/quantization/register_packed_params.cpp index bd93c6535e61..1d7dcfe72eea 100644 --- a/torch/csrc/jit/passes/quantization/register_packed_params.cpp +++ b/torch/csrc/jit/passes/quantization/register_packed_params.cpp @@ -73,13 +73,13 @@ std::unordered_set RegisterPrePackParams( WithInsertPoint ins(n->next()); Value* packed_param_value = n->output(0); TORCH_CHECK(n->outputs().size() == 1, "Prepack ops have single output"); - auto attr_name = attr_name_base + c10::to_string(uid++); + auto attr_name = attr_name_base + std::to_string(uid++); TORCH_CHECK( packed_param_value->uses().size() == 1, "Packed param must be used by exactly one op."); auto use = packed_param_value->uses()[0]; while (m.hasattr(attr_name)) { - attr_name = attr_name_base + "_" + c10::to_string(uid++); + attr_name = attr_name_base + "_" + std::to_string(uid++); } // Now register attribute for this packed param but dont set it to any // value. No value because we dont know what the value is at this point. diff --git a/torch/csrc/jit/passes/utils/subgraph_utils.cpp b/torch/csrc/jit/passes/utils/subgraph_utils.cpp index 1bb82432e218..377621c04b6d 100644 --- a/torch/csrc/jit/passes/utils/subgraph_utils.cpp +++ b/torch/csrc/jit/passes/utils/subgraph_utils.cpp @@ -606,7 +606,7 @@ static std::string truncateStrWithHash(const std::string& s, size_t maxlen) { if (s.size() <= maxlen) { return s; } - std::string hash_str = c10::to_string(c10::hash{}(s)); + std::string hash_str = std::to_string(c10::hash{}(s)); // If hash-string plus '_' can fit into maxlen, then truncate the original // string correspondingly so that the final string with the hash included fits // into maxlen. If that's not possible, at least truncate the original string diff --git a/torch/csrc/jit/runtime/register_ops_utils.cpp b/torch/csrc/jit/runtime/register_ops_utils.cpp index 7335f132dfbf..a057367af81c 100644 --- a/torch/csrc/jit/runtime/register_ops_utils.cpp +++ b/torch/csrc/jit/runtime/register_ops_utils.cpp @@ -133,7 +133,7 @@ void checkDoubleInRange(double a) { a > double(std::numeric_limits::max()) || a < double(std::numeric_limits::min())) { throw c10::Error( - "Cannot convert float " + c10::to_string(a) + " to integer"); + "Cannot convert float " + std::to_string(a) + " to integer"); return; } } diff --git a/torch/csrc/jit/runtime/register_ops_utils.h b/torch/csrc/jit/runtime/register_ops_utils.h index 15e59acb9fe6..3386bc3e4a49 100644 --- a/torch/csrc/jit/runtime/register_ops_utils.h +++ b/torch/csrc/jit/runtime/register_ops_utils.h @@ -32,7 +32,6 @@ #include #include #include -#include namespace torch::jit { constexpr inline c10::AliasAnalysisKind aliasAnalysisFromSchema() { diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp index a8ff315c9173..9dc31446d1e1 100644 --- a/torch/csrc/jit/runtime/static/impl.cpp +++ b/torch/csrc/jit/runtime/static/impl.cpp @@ -1870,8 +1870,8 @@ bool BlockRunner::check_for_memory_leak( // `BlockRunner::deallocateOutputTensors`. continue; } - const std::string error_msg = "Output " + c10::to_string(i) + ", %" + - val->debugName() + " of node " + c10::to_string(n) + + const std::string error_msg = "Output " + std::to_string(i) + ", %" + + val->debugName() + " of node " + std::to_string(n) + " which has kind " + pnode.node()->kind().toQualString() + " was not cleaned up"; if (output_ivalues.count(ival) == 0) { @@ -1947,8 +1947,8 @@ bool BlockRunner::checkOutputTensorMemoryLeaks() { const auto& t = ival->toTensor(); if (t.defined()) { auto* storage_impl = t.storage().unsafeGetStorageImpl(); - const std::string error_msg = "Output " + c10::to_string(i) + ", %" + - val->debugName() + " of node " + c10::to_string(n) + + const std::string error_msg = "Output " + std::to_string(i) + ", %" + + val->debugName() + " of node " + std::to_string(n) + " was not cleaned up"; TORCH_CHECK(storage_impl->data() == nullptr, error_msg); } diff --git a/torch/csrc/jit/tensorexpr/block_codegen.cpp b/torch/csrc/jit/tensorexpr/block_codegen.cpp index 1b32600426ca..1237120cc806 100644 --- a/torch/csrc/jit/tensorexpr/block_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/block_codegen.cpp @@ -321,7 +321,7 @@ std::string BlockCodeGen::GetUniqueFuncName(const std::string& func_prefix) { static int64_t counter = 0; ++counter; int64_t value = counter; - return func_prefix + "_" + c10::to_string(value); + return func_prefix + "_" + std::to_string(value); } void BlockCodeGen::Initialize() { diff --git a/torch/csrc/jit/tensorexpr/eval.cpp b/torch/csrc/jit/tensorexpr/eval.cpp index d0b9abaa1fa6..5666097f2dd4 100644 --- a/torch/csrc/jit/tensorexpr/eval.cpp +++ b/torch/csrc/jit/tensorexpr/eval.cpp @@ -1178,7 +1178,7 @@ class SimpleIREvaluatorImpl : public IRVisitor { case kIsNan: return std::isnan(v); default: - throw std::runtime_error("Invalid op_type: " + c10::to_string(op_type)); + throw std::runtime_error("Invalid op_type: " + std::to_string(op_type)); } } @@ -1198,7 +1198,7 @@ class SimpleIREvaluatorImpl : public IRVisitor { } default: throw std::runtime_error( - "Invalid integral op_type: " + c10::to_string(op_type)); + "Invalid integral op_type: " + std::to_string(op_type)); } } @@ -1208,7 +1208,7 @@ class SimpleIREvaluatorImpl : public IRVisitor { case kIsNan: return std::isnan(v); default: - throw std::runtime_error("Invalid op_type: " + c10::to_string(op_type)); + throw std::runtime_error("Invalid op_type: " + std::to_string(op_type)); } } @@ -1224,7 +1224,7 @@ class SimpleIREvaluatorImpl : public IRVisitor { case kAtan2: return std::atan2(v1, v2); default: - throw std::runtime_error("Invalid op_type: " + c10::to_string(op_type)); + throw std::runtime_error("Invalid op_type: " + std::to_string(op_type)); } } diff --git a/torch/csrc/jit/tensorexpr/eval.h b/torch/csrc/jit/tensorexpr/eval.h index 9bbea1bd28a4..0959151fb734 100644 --- a/torch/csrc/jit/tensorexpr/eval.h +++ b/torch/csrc/jit/tensorexpr/eval.h @@ -9,7 +9,6 @@ #include #include -#include #include #include #include diff --git a/torch/csrc/jit/tensorexpr/ir.cpp b/torch/csrc/jit/tensorexpr/ir.cpp index cea5170afcfe..889eeafc028f 100644 --- a/torch/csrc/jit/tensorexpr/ir.cpp +++ b/torch/csrc/jit/tensorexpr/ir.cpp @@ -175,7 +175,7 @@ int Intrinsics::OpArgCount(IntrinsicsOp op_type) { case kRemainder: return 2; default: - throw std::runtime_error("invalid op_type: " + c10::to_string(op_type)); + throw std::runtime_error("invalid op_type: " + std::to_string(op_type)); } } diff --git a/torch/csrc/jit/tensorexpr/ir.h b/torch/csrc/jit/tensorexpr/ir.h index f35bafb332ea..89c3f96aba6e 100644 --- a/torch/csrc/jit/tensorexpr/ir.h +++ b/torch/csrc/jit/tensorexpr/ir.h @@ -4,7 +4,6 @@ #include #include -#include #include #include #include @@ -827,7 +826,7 @@ class TORCH_API Intrinsics : public ExprNode { return "isnan"; default: throw std::runtime_error( - "invalid op_type: " + c10::to_string(op_type())); + "invalid op_type: " + std::to_string(op_type())); } } diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index 50578a041457..d18a3d65f21e 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -5,7 +5,6 @@ #include #include #include -#include #include #include #include @@ -885,7 +884,7 @@ StmtPtr TensorExprKernel::transformLoops(BackendType backendType, StmtPtr st) { inner1->set_gpu_thread_index(0); } else { throw std::runtime_error( - "Invalid loop-level: " + c10::to_string(loopLevels)); + "Invalid loop-level: " + std::to_string(loopLevels)); } } } @@ -953,7 +952,7 @@ std::string TensorExprKernel::getCodeGenName(BackendType backendType) { default: throw std::runtime_error( "invalid backend type: " + - c10::to_string(static_cast(backendType))); + std::to_string(static_cast(backendType))); } } @@ -1190,7 +1189,7 @@ Tensor TensorExprKernel::bindInput(const torch::jit::Value* input) { ToDtype(static_cast(*tt->scalarType()))); result = Compute( - "input" + c10::to_string(bufs_.size() + 1), + "input" + std::to_string(bufs_.size() + 1), size_handles, [&](const std::vector& axes) { ExprHandle idx = 0; diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp index 1b08286fbd9f..62a67af7fb14 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.cpp +++ b/torch/csrc/jit/tensorexpr/loopnest.cpp @@ -11,7 +11,6 @@ #include #include -#include #include #include @@ -3140,7 +3139,7 @@ void LoopNest::computeAt(StmtPtr s, ForPtr f) { for (const auto i : c10::irange(dims.size())) { // TODO: Use name-hint of the producer indices instead of 'idx' temp_indices[i] = - alloc(std::string("idx") + c10::to_string(i), dims[i]->dtype()); + alloc(std::string("idx") + std::to_string(i), dims[i]->dtype()); } // Prepare substitute rules for constructing the temp statement from the prod diff --git a/torch/csrc/jit/tensorexpr/operators/misc.cpp b/torch/csrc/jit/tensorexpr/operators/misc.cpp index 70991f6db1f4..938cab6ffd88 100644 --- a/torch/csrc/jit/tensorexpr/operators/misc.cpp +++ b/torch/csrc/jit/tensorexpr/operators/misc.cpp @@ -576,7 +576,7 @@ static Tensor computeCatWoConditionals( std::vector store_indices(dims.size()); for (int64_t i = 0; i < static_cast(dims.size()); ++i) { for_vars[i] = alloc( - "i" + c10::to_string(inp_pos) + "_" + c10::to_string(i), + "i" + std::to_string(inp_pos) + "_" + std::to_string(i), dims[i].dtype()); load_indices[i] = for_vars[i]; if (i == norm_concat_dim) { diff --git a/torch/csrc/jit/tensorexpr/registerizer.cpp b/torch/csrc/jit/tensorexpr/registerizer.cpp index 939f82c616dc..5e57209f39e2 100644 --- a/torch/csrc/jit/tensorexpr/registerizer.cpp +++ b/torch/csrc/jit/tensorexpr/registerizer.cpp @@ -732,7 +732,7 @@ void RegisterizerReplacer::buildReplacements() { for (auto& info : infoSet_) { VarPtr v = alloc( info->buf()->name_hint() + "_" + - c10::to_string(getBufferAccessCount(info->buf())), + std::to_string(getBufferAccessCount(info->buf())), info->buf()->dtype()); info->replacement().var = v; diff --git a/torch/csrc/jit/tensorexpr/unique_name_manager.cpp b/torch/csrc/jit/tensorexpr/unique_name_manager.cpp index 01065f5eff5b..1307e53577f4 100644 --- a/torch/csrc/jit/tensorexpr/unique_name_manager.cpp +++ b/torch/csrc/jit/tensorexpr/unique_name_manager.cpp @@ -1,6 +1,5 @@ #include -#include #include #include @@ -28,7 +27,7 @@ const std::string& UniqueNameManager::get_unique_name(VarPtr v) { int count_v = count++; std::string unique_name = name_hint; if (count_v > 0) { - unique_name += "_" + c10::to_string(count_v); + unique_name += "_" + std::to_string(count_v); } if (all_unique_names_.count(unique_name) == 0) { all_unique_names_.insert(unique_name); From f843ccbb1ab1b5264bc4ee06d90b9f796a891b1a Mon Sep 17 00:00:00 2001 From: Jun Luo Date: Mon, 10 Jun 2024 23:42:50 +0000 Subject: [PATCH 605/706] [MTIA] Add set_device support (#128040) Summary: Support set_device API in MTIA backend. Reviewed By: gnahzg Differential Revision: D58089498 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128040 Approved by: https://github.com/gnahzg --- docs/source/mtia.rst | 1 + torch/mtia/__init__.py | 13 +++++++++++++ 2 files changed, 14 insertions(+) diff --git a/docs/source/mtia.rst b/docs/source/mtia.rst index f2f5b5195dcb..b729c061b2a6 100644 --- a/docs/source/mtia.rst +++ b/docs/source/mtia.rst @@ -18,6 +18,7 @@ The MTIA backend is implemented out of the tree, only interfaces are be defined init is_available is_initialized + set_device set_stream stream synchronize diff --git a/torch/mtia/__init__.py b/torch/mtia/__init__.py index b68a25bdb61b..f9554a9bcb27 100644 --- a/torch/mtia/__init__.py +++ b/torch/mtia/__init__.py @@ -160,6 +160,18 @@ def set_stream(stream: Stream): torch._C._mtia_setCurrentStream(stream) +def set_device(device: _device_t) -> None: + r"""Set the current device. + + Args: + device (torch.device or int): selected device. This function is a no-op + if this argument is negative. + """ + device = _get_device_index(device) + if device >= 0: + torch._C._accelerator_hooks_set_current_device(device) + + class device: r"""Context-manager that changes the selected device. @@ -257,6 +269,7 @@ def stream(stream: Optional["torch.mtia.Stream"]) -> StreamContext: "current_device", "current_stream", "default_stream", + "set_device", "set_stream", "stream", "device", From 99f5a85a09596b3b2c83329f5e686ce4d8775efb Mon Sep 17 00:00:00 2001 From: cyy Date: Mon, 10 Jun 2024 23:49:58 +0000 Subject: [PATCH 606/706] [Clang Tidy] Fix misc-header-include-cycle errors in clang-tidy and ignore some files (#127233) Since there are such cycles in libfmt and PyTorch, which are detected by clang-tidy. ``` /home/cyy/pytorch/third_party/fmt/include/fmt/format-inl.h:25:10: error: circular header file dependency detected while including 'format.h', please check the include path [misc-header-include-cycle,-warnings-as-errors] 25 | #include "format.h" | ^ /home/cyy/pytorch/third_party/fmt/include/fmt/format.h:4530:12: note: 'format-inl.h' included from here 4530 | # include "format-inl.h" ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/127233 Approved by: https://github.com/ezyang --- .clang-tidy | 2 ++ torch/csrc/jit/python/python_custom_class.cpp | 1 + torch/csrc/jit/python/python_custom_class.h | 1 - 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/.clang-tidy b/.clang-tidy index fef154d4b0c1..1f7521ce7600 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -62,4 +62,6 @@ readability-string-compare, ' HeaderFilterRegex: '^(aten/|c10/|torch/).*$' WarningsAsErrors: '*' +CheckOptions: + misc-header-include-cycle.IgnoredFilesList: 'format.h;ivalue.h;custom_class.h;Dict.h;List.h' ... diff --git a/torch/csrc/jit/python/python_custom_class.cpp b/torch/csrc/jit/python/python_custom_class.cpp index 55cde36c0e62..bf9e516566e5 100644 --- a/torch/csrc/jit/python/python_custom_class.cpp +++ b/torch/csrc/jit/python/python_custom_class.cpp @@ -1,3 +1,4 @@ +#include #include #include diff --git a/torch/csrc/jit/python/python_custom_class.h b/torch/csrc/jit/python/python_custom_class.h index d7cff488f273..1033fc008f27 100644 --- a/torch/csrc/jit/python/python_custom_class.h +++ b/torch/csrc/jit/python/python_custom_class.h @@ -1,6 +1,5 @@ #pragma once -#include #include #include From 734e8f6ad7e7f0fa0341fb658f1f986225173f5f Mon Sep 17 00:00:00 2001 From: Sam Larsen Date: Fri, 7 Jun 2024 12:19:27 -0700 Subject: [PATCH 607/706] [inductor] enable fx graph cache on torchbench (#128239) Summary: We've already enabled for timm and huggingface, but we had failures saving cache entries for moco. It looks like https://github.com/pytorch/pytorch/pull/128052 has fixed that issue, so we can enable for torchbench. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128239 Approved by: https://github.com/oulgen --- benchmarks/dynamo/torchbench.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/benchmarks/dynamo/torchbench.py b/benchmarks/dynamo/torchbench.py index fd11e984bbdc..d7877c5a3fac 100755 --- a/benchmarks/dynamo/torchbench.py +++ b/benchmarks/dynamo/torchbench.py @@ -25,6 +25,10 @@ # We are primarily interested in tf32 datatype torch.backends.cuda.matmul.allow_tf32 = True +# Enable FX graph caching +if "TORCHINDUCTOR_FX_GRAPH_CACHE" not in os.environ: + torch._inductor.config.fx_graph_cache = True + def _reassign_parameters(model): # torch_geometric models register parameter as tensors due to From 3b555ba47713d489975a9bb6cb6c31975f805e3f Mon Sep 17 00:00:00 2001 From: Arun Pa Date: Tue, 11 Jun 2024 01:32:42 +0000 Subject: [PATCH 608/706] Add docstring for torch.utils.data.datapipes.decoder.basicandlers (#128018) Fixes #127912 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128018 Approved by: https://github.com/andrewkho --- torch/utils/data/datapipes/utils/decoder.py | 29 ++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/torch/utils/data/datapipes/utils/decoder.py b/torch/utils/data/datapipes/utils/decoder.py index 7c055c567295..b465f3a0aaa6 100644 --- a/torch/utils/data/datapipes/utils/decoder.py +++ b/torch/utils/data/datapipes/utils/decoder.py @@ -29,7 +29,34 @@ ################################################################ # handle basic datatypes ################################################################ -def basichandlers(extension, data): +def basichandlers(extension: str, data): + """Transforms raw data (byte stream) into python objects. + + Looks at the extension and loads the data into a python object supporting + the corresponding extension. + + Args: + extension (str): The file extension + data (byte stream): Data to load into a python object. + + Returns: + object: The data loaded into a corresponding python object + supporting the extension. + + Example: + >>> import pickle + >>> data = pickle.dumps('some data') + >>> new_data = basichandlers('pickle', data) + >>> new_data + some data + + The transformation of data for extensions are: + - txt, text, transcript: utf-8 decoded data of str format + - cls, cls2, class, count, index, inx, id: int + - json, jsn: json loaded data + - pickle, pyd: pickle loaded data + - pt: torch loaded data + """ if extension in "txt text transcript": return data.decode("utf-8") From 841d87177a900c2bbd59b6589165189141c4e8bb Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Mon, 10 Jun 2024 13:24:10 -0700 Subject: [PATCH 609/706] Make sure #126704 is BC for torch.save-ed `nn.Module` (#128344) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128344 Approved by: https://github.com/albanD ghstack dependencies: #126906, #126704 --- torch/nn/modules/module.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index 942d5b8a8f95..2c2c2865687c 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -1955,7 +1955,12 @@ def state_dict(self, *args, destination=None, prefix='', keep_vars=False): for name, module in self._modules.items(): if module is not None: module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars) - for (hook, from_private) in self._state_dict_hooks.values(): + for value in self._state_dict_hooks.values(): + # For BC reasons + if isinstance(value, tuple): + hook, from_private = value + else: + hook, from_private = value, True hook_result = hook(self, destination, prefix, local_metadata) if from_private and hook_result is not None: destination = hook_result From d1d9bc7aa65450cd35a72d4357e1089f871944e4 Mon Sep 17 00:00:00 2001 From: Andrew Hoblitzell Date: Tue, 11 Jun 2024 02:37:01 +0000 Subject: [PATCH 610/706] init add comment (#128083) Fixes #127898 ### Description Add docstring to torch/onnx/symbolic_opset9.py:sigmoid function ### Checklist - [x] The issue that is being fixed is referred in the description - [x] Only one issue is addressed in this pull request - [x] Labels from the issue that this PR is fixing are added to this pull request - [x] No unnecessary issues are included into this pull request Pull Request resolved: https://github.com/pytorch/pytorch/pull/128083 Approved by: https://github.com/titaiwangms --- torch/onnx/symbolic_opset9.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py index f71ef713636a..21ce701d26fe 100644 --- a/torch/onnx/symbolic_opset9.py +++ b/torch/onnx/symbolic_opset9.py @@ -345,6 +345,20 @@ def reshape_as(g: jit_utils.GraphContext, self, other): @_onnx_symbolic("aten::add") @_beartype.beartype def add(g: jit_utils.GraphContext, self, other, alpha=None): + """ + This function takes the add function and returns the corresponding ONNX operator. + + This function is not meant to be called directly by the user. + + Args: + g (GraphContext): The graph context. + self (Tensor): The first operand. + other (Tensor): The second operand. + alpha (float, optional): The scaling factor for the second operand. Defaults to None. + + Returns: + ONNX operator. + """ if symbolic_helper._is_value(self) and symbolic_helper._is_tensor_list(self): return symbolic_helper._onnx_opset_unsupported_detailed( "Add", 9, 11, "Add between list of tensors not supported", self From 793df7b7cb1473004837f5867f4c1c4b2b0f751d Mon Sep 17 00:00:00 2001 From: eellison Date: Mon, 10 Jun 2024 14:40:39 -0700 Subject: [PATCH 611/706] Prevent expansion of cat indexing to avoid int64 intermediate (#127815) Fix for https://github.com/pytorch/pytorch/issues/127652 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127815 Approved by: https://github.com/shunting314, https://github.com/peterbell10 --- test/inductor/test_cuda_repro.py | 40 ++++++++++++++++++++++++++++++ torch/_inductor/bounds.py | 9 +++++++ torch/_inductor/codegen/common.py | 3 +++ torch/_inductor/lowering.py | 12 +++++++-- torch/_inductor/utils.py | 10 ++++++-- torch/utils/_sympy/functions.py | 16 ++++++++++++ torch/utils/_sympy/interp.py | 2 ++ torch/utils/_sympy/value_ranges.py | 4 +++ 8 files changed, 92 insertions(+), 4 deletions(-) diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py index 8365d216f82c..23243b7db5b5 100644 --- a/test/inductor/test_cuda_repro.py +++ b/test/inductor/test_cuda_repro.py @@ -1238,6 +1238,46 @@ def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr): tl.store(out_ptr0 + (x3), tmp2, xmask)""", # noqa: B950 ) + def test_int64_index_intermediate(self): + def foo(inp): + view_23 = torch.ops.aten.view.default(inp, [-1, 8192, 8192]) + split_1 = torch.ops.aten.split.Tensor(view_23, 1024, 1) + view_23 = None + getitem_17 = split_1[0] + getitem_18 = split_1[1] + getitem_19 = split_1[2] + getitem_20 = split_1[3] + getitem_21 = split_1[4] + getitem_22 = split_1[5] + getitem_23 = split_1[6] + getitem_24 = split_1[7] + split_1 = None + cat_1 = torch.ops.aten.cat.default( + [ + getitem_17, + getitem_18, + getitem_19, + getitem_20, + getitem_21, + getitem_22, + getitem_23, + getitem_24, + ] + ) + getitem_17 = ( + getitem_18 + ) = ( + getitem_19 + ) = getitem_20 = getitem_21 = getitem_22 = getitem_23 = getitem_24 = None + return cat_1 + + for mark_dynamic in [False, True]: + inp = torch.rand((65536, 8192), dtype=torch.bfloat16, device="cuda") + if mark_dynamic: + torch._dynamo.mark_dynamic(inp, 0) + foo_c = torch.compile(foo) + torch.testing.assert_allclose(foo(inp), foo_c(inp)) + if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/torch/_inductor/bounds.py b/torch/_inductor/bounds.py index 8c62ef2ba3c9..b7bb37e5ee68 100644 --- a/torch/_inductor/bounds.py +++ b/torch/_inductor/bounds.py @@ -45,6 +45,15 @@ def upper_bound(v): # To access this variable call `get_bounds()` self._bounds: Dict[torch.fx.Node, ValueRanges[Expr]] = {} + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f"loop_body={self.loop_body},\n " + f"replacement_vals={self.replacement_vals}, \n" + f"unbounded_vars={self.unbounded_vars}, \n" + f"_bounds={self._bounds})" + ) + @cache_on_self def get_bounds(self) -> Dict[torch.fx.Node, ValueRanges[Expr]]: submodules = self.swap_submodules(self.loop_body.submodules) diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 8ca6dc2b9153..02aa3e7395f7 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -393,6 +393,9 @@ def _print_FloatTrueDiv(self, expr): def _print_CleanDiv(self, expr): return self._print_FloorDiv(expr) + def _print_Identity(self, expr): + return self._print(expr.args[0]) + def _print_GreaterThan(self, expr): # GreaterThan: >= # StrictlyGreaterThan: > diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index e432fd45cd94..f9f1bca3d920 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -35,7 +35,13 @@ Number, ) from torch.fx.experimental.sym_node import magic_methods, method_to_operator -from torch.utils._sympy.functions import CeilDiv, FloorDiv, IntTrueDiv, ModularIndexing +from torch.utils._sympy.functions import ( + CeilDiv, + FloorDiv, + Identity, + IntTrueDiv, + ModularIndexing, +) from .._dynamo.utils import import_submodule from . import config, inductor_prims, ir, test_operators # NOQA: F401 @@ -1016,7 +1022,9 @@ def inner_fn(idx): # if we're concatting [4], [2] # when we index the second tensor for 5 we want to index 5 - 4 - idx_load[dim] -= inputs_ranges[i][0] + # Use Identity to prevent expansion of index * stride to keep expression + # in same int bitwidth as shape + idx_load[dim] = Identity(idx_load[dim] - inputs_ranges[i][0]) masked_loads.append( ops.masked( diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 64dc283843f4..ea3826855f59 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -54,7 +54,13 @@ from torch.autograd.profiler_util import EventList from torch.fx.passes.graph_transform_observer import GraphTransformObserver from torch.fx.passes.shape_prop import ShapeProp -from torch.utils._sympy.functions import CeilDiv, CleanDiv, FloorDiv, ModularIndexing +from torch.utils._sympy.functions import ( + CeilDiv, + CleanDiv, + FloorDiv, + Identity, + ModularIndexing, +) from torch.utils._sympy.symbol import make_symbol, SymT from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges from . import config @@ -575,7 +581,7 @@ def sympy_str(expr: sympy.Expr) -> str: if isinstance(expr, sympy.Mul): return " * ".join(map(sympy_str, expr.args)) - if isinstance(expr, (ModularIndexing, CleanDiv, FloorDiv)): + if isinstance(expr, (ModularIndexing, CleanDiv, FloorDiv, Identity)): return f"{expr.func.__name__}({', '.join(map(sympy_str, expr.args))})" return str(expr) diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index fd9921848d60..3c845f58117b 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -24,6 +24,7 @@ "ToFloat", "FloatPow", "PowByNatural", + "Identity", ] @@ -719,6 +720,21 @@ def eval(cls, number): return -sympy.oo +class Identity(sympy.Function): + """ + Prevents expansion and other optimizations + """ + + def __repr__(self): + return f"Identity({self.args[0]})" + + def _eval_is_real(self): + return self.args[0].is_real + + def _eval_is_integer(self): + return self.args[0].is_integer # type: ignore[attr-defined] + + def make_opaque_unary_fn(name): class OpaqueUnaryFn(sympy.Function): """ diff --git a/torch/utils/_sympy/interp.py b/torch/utils/_sympy/interp.py index 36ff6fc23d4a..3bcb369bcebc 100644 --- a/torch/utils/_sympy/interp.py +++ b/torch/utils/_sympy/interp.py @@ -23,6 +23,7 @@ FloatTrueDiv, FloorDiv, FloorToInt, + Identity, IntTrueDiv, IsNonOverlappingAndDenseIndicator, Mod, @@ -92,6 +93,7 @@ def handlers(): ModularIndexing: "modular_indexing", sympy.functions.elementary.piecewise.ExprCondPair: "expr_cond_pair", sympy.Piecewise: "piecewise", + Identity: "identity", IsNonOverlappingAndDenseIndicator: "is_non_overlapping_and_dense_indicator", RoundDecimal: "round_decimal", } diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py index 087e741a72ec..d16da832459a 100644 --- a/torch/utils/_sympy/value_ranges.py +++ b/torch/utils/_sympy/value_ranges.py @@ -470,6 +470,10 @@ def eq(a, b): def ne(cls, a, b): return cls.not_(cls.eq(a, b)) + @classmethod + def identity(cls, a): + return ValueRanges.wrap(a) + @classmethod def lt(cls, a, b): a = ValueRanges.wrap(a) From e4bd0adca5a22e32c0a3946a8d94591573e82343 Mon Sep 17 00:00:00 2001 From: cyy Date: Tue, 11 Jun 2024 02:46:31 +0000 Subject: [PATCH 612/706] [6/N] Remove unused functions (#128309) Follows #127185 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128309 Approved by: https://github.com/ezyang --- aten/src/ATen/native/SoftMax.cpp | 18 -------- aten/src/ATen/native/TensorFactories.cpp | 8 ---- aten/src/ATen/native/TensorProperties.cpp | 4 -- aten/src/ATen/native/TensorShape.cpp | 16 ------- aten/src/ATen/native/mkldnn/Conv.cpp | 46 ------------------- .../ATen/native/quantized/cpu/ReduceOps.cpp | 10 ---- .../quantized/cpu/UpSampleBilinear2d.cpp | 14 ------ .../ATen/native/sparse/SparseTensorMath.cpp | 12 ----- 8 files changed, 128 deletions(-) diff --git a/aten/src/ATen/native/SoftMax.cpp b/aten/src/ATen/native/SoftMax.cpp index fa7be5a698e9..aa9173154a14 100644 --- a/aten/src/ATen/native/SoftMax.cpp +++ b/aten/src/ATen/native/SoftMax.cpp @@ -440,15 +440,6 @@ TORCH_IMPL_FUNC(log_softmax_backward_cpu_out) ( } } -static Tensor softmax(const Tensor& input_, const int64_t dim_) { - auto result = [&]() { - NoNamesGuard guard; - return at::_softmax(input_, dim_, false); - }(); - namedinference::propagate_names(result, input_); - return result; -} - Tensor softmax(const Tensor& input_, const int64_t dim_, std::optional dtype) { auto result = [&]() { NoNamesGuard guard; @@ -505,15 +496,6 @@ Tensor special_softmax(const Tensor& input_, const int64_t dim_, std::optional dtype) { auto result = [&]() { NoNamesGuard guard; diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp index 5e7c9cf8a5f8..55961d9e0be9 100644 --- a/aten/src/ATen/native/TensorFactories.cpp +++ b/aten/src/ATen/native/TensorFactories.cpp @@ -172,18 +172,10 @@ Tensor arange( return at::arange_out(result, start, end, step); } -static Tensor& arange_start_out(const Scalar& start, const Scalar& end, Tensor& result) { - return at::arange_out(result, start, end, /*step=*/1); -} - Tensor& arange_out(const Scalar& end, Tensor& result) { return at::arange_out(result, /*start=*/0, end, /*step=*/1); } -static Tensor& arange_out(Tensor& result, const Scalar& start, const Scalar& end) { - return at::arange_out(result, start, end, /*step=*/1); -} - Tensor _dim_arange(const Tensor& like, int64_t dim) { return at::arange(like.size(dim), like.options().dtype(at::kLong)); } diff --git a/aten/src/ATen/native/TensorProperties.cpp b/aten/src/ATen/native/TensorProperties.cpp index 899cf68a7a5a..95c88f4572cb 100644 --- a/aten/src/ATen/native/TensorProperties.cpp +++ b/aten/src/ATen/native/TensorProperties.cpp @@ -105,10 +105,6 @@ Tensor & detach_(Tensor & self) { return self; } -static Tensor contiguous(const Tensor & self) { - return contiguous(self, MemoryFormat::Contiguous); -} - Tensor contiguous(const Tensor& self, MemoryFormat memory_format) { if (self.is_contiguous(memory_format)) { return self; diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index bdab4ce24551..bf1b6e5fa262 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -1181,14 +1181,6 @@ Tensor as_strided_tensorimpl(const Tensor& self, IntArrayRef size, IntArrayRef s return result; } -static Tensor as_strided_tensorimpl_meta(const Tensor& self, IntArrayRef size, IntArrayRef stride, optional storage_offset_) { - auto storage_offset = storage_offset_.value_or(self.storage_offset()); - auto result = at::detail::make_tensor( - c10::TensorImpl::VIEW, Storage(self.storage()), self.key_set(), self.dtype()); - setStrided(result, size, stride, storage_offset); - return result; -} - template inline void setStridedUnchecked( const Tensor& self, @@ -1249,10 +1241,6 @@ const Tensor &as_strided__symint(const Tensor& self, SymIntArrayRef size, SymInt return self; } -static Tensor narrow_copy_dense(const Tensor& self, int64_t dim, int64_t start, int64_t length) { - return self.narrow(dim, start, length).clone(at::MemoryFormat::Contiguous); -} - // Should just use narrow_copy_out, but this API is used internally at Meta: // https://github.com/pytorch/pytorch/pull/87045#issuecomment-1309353561 Tensor narrow_copy_dense_cpu(const Tensor& self, int64_t dim, int64_t start, int64_t length){ @@ -3587,10 +3575,6 @@ Tensor view_as(const Tensor& self, const Tensor& other) { return self.view_symint(other.sym_sizes()); } -static int64_t numel(const Tensor& self) { - return self.unsafeGetTensorImpl()->numel(); -} - std::vector unbind(const Tensor &self, int64_t dim) { dim = maybe_wrap_dim(dim, self.dim()); int64_t size = self.size(dim); diff --git a/aten/src/ATen/native/mkldnn/Conv.cpp b/aten/src/ATen/native/mkldnn/Conv.cpp index 09dca06e2b5a..643bd7eed0a2 100644 --- a/aten/src/ATen/native/mkldnn/Conv.cpp +++ b/aten/src/ATen/native/mkldnn/Conv.cpp @@ -27,53 +27,7 @@ Tensor mkldnn_convolution( TORCH_CHECK(false, "mkldnn_convolution_forward: ATen not compiled with MKLDNN support"); } -static Tensor mkldnn_convolution_backward_input( - IntArrayRef input_size, const Tensor& grad_output, const Tensor& weight, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) { - TORCH_CHECK(false, "mkldnn_convolution_backward_input: ATen not compiled with MKLDNN support"); -} - -static std::tuple mkldnn_convolution_backward_weights( - IntArrayRef weight_size, const Tensor& grad_output, const Tensor& input, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) { - TORCH_CHECK(false, "mkldnn_convolution_backward_weights: ATen not compiled with MKLDNN support"); -} - -static std::tuple mkldnn_convolution_backward( - const Tensor& input, const Tensor& grad_output_t, const Tensor& weight, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, std::array output_mask) { - TORCH_CHECK(false, "mkldnn_convolution_backward: ATen not compiled with MKLDNN support"); -} - REGISTER_NO_CPU_DISPATCH(mkldnn_convolution_backward_stub); - -static Tensor mkldnn_convolution_transpose( - const Tensor& input, const Tensor& weight, const std::optional& bias_opt, - IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups) { - TORCH_CHECK(false, "mkldnn_convolution_transpose: ATen not compiled with MKLDNN support"); -} - -static Tensor mkldnn_convolution_transpose_backward_input( - IntArrayRef input_size, const Tensor& grad_output, const Tensor& weight, - IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, - int64_t groups, bool bias_defined) { - TORCH_CHECK(false, "mkldnn_convolution_transpose_backward_input: ATen not compiled with MKLDNN support"); -} - -static std::tuple mkldnn_convolution_transpose_backward_weights( - IntArrayRef weight_size, const Tensor& grad_output, const Tensor& input, - IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, - int64_t groups, bool bias_defined) { - TORCH_CHECK(false, "mkldnn_convolution_transpose_backward_weights: ATen not compiled with MKLDNN support"); -} - -static std::tuple mkldnn_convolution_transpose_backward( - const Tensor& input, const Tensor& grad_output_t, const Tensor& weight, - IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, - int64_t groups, std::array output_mask) { - TORCH_CHECK(false, "mkldnn_convolution_transpose_backward: ATen not compiled with MKLDNN support"); -} - REGISTER_NO_CPU_DISPATCH(mkldnn_convolution_transpose_stub); REGISTER_NO_CPU_DISPATCH(mkldnn_convolution_transpose_backward_stub); diff --git a/aten/src/ATen/native/quantized/cpu/ReduceOps.cpp b/aten/src/ATen/native/quantized/cpu/ReduceOps.cpp index 573b2ffff4b4..5d471d235275 100644 --- a/aten/src/ATen/native/quantized/cpu/ReduceOps.cpp +++ b/aten/src/ATen/native/quantized/cpu/ReduceOps.cpp @@ -172,16 +172,6 @@ Tensor mean_quantized_cpu( return result; } -static Tensor& mean_out_quantized_cpu( - Tensor& result, - const Tensor& self, - DimnameList dim, - bool keepdim, - std::optional opt_dtype) { - return mean_out_quantized_cpu( - self, dimnames_to_positions(self, dim), keepdim, opt_dtype, result); -} - // qstd inline bool is_std_inner_dim_fast_path( const Tensor& self, diff --git a/aten/src/ATen/native/quantized/cpu/UpSampleBilinear2d.cpp b/aten/src/ATen/native/quantized/cpu/UpSampleBilinear2d.cpp index d4dfa7ff08c9..947f9f1696dd 100644 --- a/aten/src/ATen/native/quantized/cpu/UpSampleBilinear2d.cpp +++ b/aten/src/ATen/native/quantized/cpu/UpSampleBilinear2d.cpp @@ -216,20 +216,6 @@ Tensor upsample_bilinear2d_quantized_cpu( } } -using at::native::upsample::compute_output_size; -using at::native::upsample::get_scale_value; - -static Tensor upsample_bilinear2d_quantized_cpu( - const Tensor& input, - at::OptionalIntArrayRef output_size, - bool align_corners, - std::optional> scale_factors) { - auto osize = compute_output_size(input.sizes(), output_size, scale_factors); - auto scale_h = get_scale_value(scale_factors, 0); - auto scale_w = get_scale_value(scale_factors, 1); - return upsample_bilinear2d_quantized_cpu(input, osize, align_corners, scale_h, scale_w); -} - DEFINE_DISPATCH(qupsample_bilinear2d_nhwc_stub); } // namespace native } // namespace at diff --git a/aten/src/ATen/native/sparse/SparseTensorMath.cpp b/aten/src/ATen/native/sparse/SparseTensorMath.cpp index f058c68579f8..fff755c7b418 100644 --- a/aten/src/ATen/native/sparse/SparseTensorMath.cpp +++ b/aten/src/ATen/native/sparse/SparseTensorMath.cpp @@ -270,10 +270,6 @@ Tensor& div_sparse_(Tensor& self, const Tensor& value) { return div_out_sparse_zerodim(self, value, self); } -static SparseTensor& div_out_sparse_scalar(const SparseTensor& t, Scalar value, SparseTensor& r) { - return div_out_sparse_zerodim(t, wrapped_scalar_tensor(value), r); -} - Tensor div_sparse(const Tensor& self, const Tensor& value, std::optional rounding_mode) { auto commonDtype = at::result_type(self, value); if (c10::isIntegralType(commonDtype, /*includeBool=*/true) && !rounding_mode.has_value()) { @@ -287,10 +283,6 @@ Tensor& div_sparse_(Tensor& self, const Tensor& value, std::optional rounding_mode, SparseTensor& r) { - return div_out_sparse_zerodim(t, wrapped_scalar_tensor(value), std::move(rounding_mode), r); -} - // -------------------------------------------------------------------- // floor_divide(SparseTensor, Scalar) // -------------------------------------------------------------------- @@ -350,10 +342,6 @@ Tensor& floor_divide_sparse_(Tensor& self, const Tensor& value) { return floor_divide_out_sparse_zerodim(self, value, self); } -static SparseTensor& floor_divide_out_sparse_scalar(SparseTensor& r, const SparseTensor& t, const Scalar& value) { - return floor_divide_out_sparse_zerodim(t, wrapped_scalar_tensor(value), r); -} - // -------------------------------------------------------------------- // norm(SparseTensor, Scalar) // -------------------------------------------------------------------- From 4077cdd589f97b63cae4d84455e24b45e41f09fb Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Mon, 10 Jun 2024 16:10:53 -0700 Subject: [PATCH 613/706] [pipelining][doc] Update arg list of pipeline API (#128361) And document the use of `build_stage` API. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128361 Approved by: https://github.com/wconstab --- docs/source/distributed.pipelining.rst | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/docs/source/distributed.pipelining.rst b/docs/source/distributed.pipelining.rst index 2e5e9f74c662..36e90d3aa0f7 100644 --- a/docs/source/distributed.pipelining.rst +++ b/docs/source/distributed.pipelining.rst @@ -261,11 +261,12 @@ Let us see how the ``pipeline`` API works: from torch.distributed.pipelining import pipeline, SplitPoint + # An example micro-batch input x = torch.LongTensor([1, 2, 4, 5]) + pipe = pipeline( module=mod, - num_chunks=1, - example_args=(x,), + mb_args=(x,), split_spec={ "layers.1": SplitPoint.BEGINNING, } @@ -306,7 +307,7 @@ If we ``print(pipe)``, we can see:: The "model partitions" are represented by submodules (``submod_0``, -``submod_1``), each of which is reconstructed with original model operations +``submod_1``), each of which is reconstructed with original model operations, weights and hierarchies. In addition, a "root-level" ``forward`` function is reconstructed to capture the data flow between those partitions. Such data flow will be replayed by the pipeline runtime later, in a distributed fashion. @@ -317,12 +318,29 @@ The ``Pipe`` object provides a method for retrieving the "model partitions": stage_mod : nn.Module = pipe.get_stage_module(stage_idx) -You can also create a distributed stage runtime on a device using ``Pipe``: +The returned ``stage_mod`` is a ``nn.Module``, with which you can create an +optimizer, save or load checkpoints, or apply other parallelisms. + +``Pipe`` also allows you to create a distributed stage runtime on a device given +a ``ProcessGroup``: .. code-block:: python stage = pipe.build_stage(stage_idx, device, group) +Alternatively, if you would like to build the stage runtime later after some +modification to the ``stage_mod``, you can use a functional version of the +``build_stage`` API. For example: + +.. code-block:: python + + from torch.distributed.pipelining import build_stage + from torch.nn.parallel import DistributedDataParallel + + dp_mod = DistributedDataParallel(stage_mod) + info = pipe.info() + stage = build_stage(dp_mod, stage_idx, info, device, group) + .. note:: The ``pipeline`` frontend uses a tracer (``torch.export``) to capture your model into a single graph. If your model is not full-graph'able, you can use From 665e568381eedd417dc4b7cf293ca8a7fe825ac4 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Mon, 10 Jun 2024 16:27:39 -0700 Subject: [PATCH 614/706] [inductor][inlining nn module] Skip batchnorm version check test for inlining (#128268) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128268 Approved by: https://github.com/zou3519 ghstack dependencies: #128295, #126578 --- test/inductor/test_torchinductor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 53167a83ecd8..1ae24869e4df 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -80,6 +80,7 @@ IS_X86, parametrize, serialTest, + skipIfNNModuleInlined, skipIfRocm, skipIfXpu, subtest, @@ -3972,6 +3973,7 @@ def forward(self, x): self.assertEqual(eager_delta, compile_delta) + @skipIfNNModuleInlined("https://github.com/pytorch/pytorch/issues/128198") def test_buffer_batch_norm(self): class MyModel(torch.nn.Module): def __init__(self): From ca45649eb5810af86508962bef3ff66f60cc457d Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Mon, 10 Jun 2024 16:35:48 -0700 Subject: [PATCH 615/706] [easy][dynamo][inline work] Fix test with inlining inbuilt nn modules (#128254) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128254 Approved by: https://github.com/williamwen42 ghstack dependencies: #128295, #126578, #128268 --- test/dynamo/test_decorators.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/test/dynamo/test_decorators.py b/test/dynamo/test_decorators.py index 890edca40ccc..16853bbfb3ad 100644 --- a/test/dynamo/test_decorators.py +++ b/test/dynamo/test_decorators.py @@ -304,13 +304,12 @@ def f3(x): self.assertEqual(cnt.frame_count, 0) def test_torch_guards_stack_frame_register_inlining_disable(self): - y = torch.nn.Parameter(torch.tensor([0.25, 0.25])) x = torch.tensor([0.5, 0.5]) class encoder(torch.nn.Module): def __init__(self, y): super().__init__() - self.register_parameter("param", y) + self.a = y @torch._dynamo.disable def helper(self, x, y): @@ -318,9 +317,9 @@ def helper(self, x, y): def forward(self, a, *args): x = a + a - return self.helper(x, self.param) + return self.helper(x, self.a) - e = encoder(y) + e = encoder(2.0) seen_frames = [] import contextlib From 7afffdf48b596d6d8b7e71fb72ded6402d0dce41 Mon Sep 17 00:00:00 2001 From: zengxian Date: Tue, 11 Jun 2024 03:12:11 +0000 Subject: [PATCH 616/706] [CI] Comment hf_T5_generate, hf_GPT2 and timm_efficientnet in inductor cpu smoketest for performance unstable issue (#127588) Fixes #126993 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127588 Approved by: https://github.com/chuanqi129, https://github.com/jgong5, https://github.com/desertfire --- .../dynamo/expected_ci_speedup_inductor_torchbench_cpu.csv | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/benchmarks/dynamo/expected_ci_speedup_inductor_torchbench_cpu.csv b/benchmarks/dynamo/expected_ci_speedup_inductor_torchbench_cpu.csv index f2f8c1b26176..e26d3b97864f 100644 --- a/benchmarks/dynamo/expected_ci_speedup_inductor_torchbench_cpu.csv +++ b/benchmarks/dynamo/expected_ci_speedup_inductor_torchbench_cpu.csv @@ -4,12 +4,11 @@ phlippe_densenet,float32,static,default,1.3988316 basic_gnn_gcn,float32,dynamic,default,1.074576405 llama_v2_7b_16h,float32,dynamic,default,1.211740245 resnet50,float32,dynamic,default,1.65984261 -timm_efficientnet,float32,static,cpp,2.271561735 +#timm_efficientnet,float32,static,cpp,2.1938112 mobilenet_v3_large,float32,static,cpp,2.63375628 timm_resnest,float32,dynamic,cpp,1.67998548 pyhpc_turbulent_kinetic_energy,float32,dynamic,cpp,1.59968463 -#hf_GPT2,float32,dynamic,cpp, -hf_GPT2,float32,dynamic,cpp,1.379885175 +#hf_GPT2,float32,dynamic,cpp,1.292704418 resnext50_32x4d,amp,static,default,1.461687045 vgg16,amp,static,default,1.267194285 hf_Longformer,amp,dynamic,default,0.997006035 @@ -17,6 +16,6 @@ hf_Bert_large,amp,dynamic,default,0.99391146 llama,amp,static,default,1.32950568 timm_regnet,amp,static,cpp,1.157188305 lennard_jones,amp,static,cpp,2.240104485 -hf_T5_generate,amp,dynamic,cpp,1.447656135 +#hf_T5_generate,amp,dynamic,cpp,1.29339502 timm_vovnet,amp,dynamic,cpp,1.07856471 mobilenet_v2,amp,dynamic,cpp,2.27774577 From 16e67be7f1a49e4cfd354f584b60ce1cdd9e3ca0 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Mon, 10 Jun 2024 12:08:23 -0700 Subject: [PATCH 617/706] Also preserve unbacked SymInts when partitioning as backward inputs (#128338) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/128338 Approved by: https://github.com/IvanKobzarev --- .../test_torchinductor_dynamic_shapes.py | 42 +++++++++++++++++++ torch/fx/experimental/symbolic_shapes.py | 15 +++---- 2 files changed, 50 insertions(+), 7 deletions(-) diff --git a/test/inductor/test_torchinductor_dynamic_shapes.py b/test/inductor/test_torchinductor_dynamic_shapes.py index 2f9506a9d561..5608adc94e2f 100644 --- a/test/inductor/test_torchinductor_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_dynamic_shapes.py @@ -8,6 +8,7 @@ import sys import unittest from functools import partial +from typing import List import torch import torch.library @@ -369,6 +370,47 @@ def f(x): arg = torch.tensor(5, device=device) self.assertEqual(f(arg), cf(arg)) + @torch._dynamo.config.patch( + capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True + ) + @torch._inductor.config.patch(implicit_fallbacks=True) + def test_unbacked_save_for_backwards(self, device) -> None: + @torch.library.custom_op("_test::_cat", mutates_args=()) + def _cat(t: torch.Tensor, ds: List[int]) -> torch.Tensor: + return t * t.new_ones([sum(ds)]) + + @torch.library.register_fake("_test::_cat") + def _cat_fake(t: torch.Tensor, ds: List[int]) -> torch.Tensor: + [torch._check_is_size(d) for d in ds] + return t.new_empty([sum(ds)]) + + def _cat_setup_context(ctx, inputs, output): + pass + + def _cat_backward(ctx, grad): + return grad.sum(), None + + torch.library.register_autograd( + "_test::_cat", + _cat_backward, + setup_context=_cat_setup_context, + ) + + def fn(t, sizes): + r = torch.ops._test._cat(t, sizes.tolist()) + return r * t + + t = torch.randn((), requires_grad=True, device=device) + sizes = torch.tensor([4, 8], dtype=torch.int64, device="cpu") + out = fn(t, sizes) + out.sum().backward() + expect = t.grad + t.grad = None + torch.compile(fn, backend="inductor", fullgraph=True, dynamic=True)( + t, sizes + ).sum().backward() + self.assertEqual(t.grad, expect) + @torch._dynamo.config.patch(capture_scalar_outputs=True) def test_unbacked_reduction(self, device): expect_fail = device == "cpu" and not IS_ARM64 diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 3852dc44e7ea..800b92b2e07b 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -453,20 +453,21 @@ def free_unbacked_symbols(x): # setup! def is_symbol_binding_fx_node(node) -> Optional[sympy.Symbol]: if ( - node.op == "placeholder" and "val" in node.meta and isinstance(node.meta["val"], torch.SymInt) and - isinstance(node.meta["val"].node.expr, sympy.Symbol) + isinstance(node.meta["val"].node.expr, sympy.Symbol) and + (node.op == "placeholder" or free_unbacked_symbols(node.meta["val"].node.expr)) ): return node.meta["val"].node.expr return None def find_symbol_binding_fx_nodes(graph): - return { - node.meta["val"].node.expr: node - for node in graph.nodes - if is_symbol_binding_fx_node(node) - } + r = {} + # NB: Prefer first occurrence of symbol + for node in graph.nodes: + if is_symbol_binding_fx_node(node) and node.meta["val"].node.expr not in r: + r[node.meta["val"].node.expr] = node + return r # Analogous to ConvertIntSource From cba195c8edd6c7149036ef0767772d11fff5390e Mon Sep 17 00:00:00 2001 From: "Wang, Eikan" Date: Fri, 7 Jun 2024 01:25:11 +0000 Subject: [PATCH 618/706] Support aten operations with out tensor (#124926) This PR intends to support the aten operations with the `out` tensor. Currently, the AOT compile always does **NOT** keep input tensor mutations. According to the comments, this is because it has not encountered such a use case. > For now there's no use case involving keeping input mutations in the graph (which we can only do in the inference case anyway). We can add this later if we need to. However, for aten operations, it is popular that the `out` tensor is an input parameter and needs to be mutated. This PR intends to support it by adding a `keep_inference_input_mutations` flag to `aot_inductor.keep_inference_input_mutations`. This flag can provide flexibility to the callee in deciding whether the AOT compile needs to keep input tensor mutations in the graph. Take `clamp` as an example as follows. ```python out_tensor = torch.randn(128, dtype=torch.float, device=device).fill_(-2.0) inp_tensor = torch.randn(128, dtype=torch.float, device=device).fill_(1.0) min_tensor = inp_tensor - 0.05 max_tensor = inp_tensor + 0.05 torch.clamp(input=inp_tensor, min=min_tensor, max=max_tensor, out=out_tensor) ``` W/O this PR ```python def forward(self): arg0_1: "f32[128]"; arg1_1: "f32[128]"; arg2_1: "f32[128]"; arg3_1: "f32[128]"; arg0_1, arg1_1, arg2_1, arg3_1, = fx_pytree.tree_flatten_spec([], self._in_spec) clamp_min: "f32[128]" = torch.ops.aten.clamp_min.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None clamp_max: "f32[128]" = torch.ops.aten.clamp_max.Tensor(clamp_min, arg2_1); clamp_min = arg2_1 = None return (clamp_max, clamp_max) ``` W/ this PR ```python def forward(self): arg0_1: "f32[128]"; arg1_1: "f32[128]"; arg2_1: "f32[128]"; arg3_1: "f32[128]"; arg0_1, arg1_1, arg2_1, arg3_1, = fx_pytree.tree_flatten_spec([], self._in_spec) clamp_min: "f32[128]" = torch.ops.aten.clamp_min.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None clamp_max: "f32[128]" = torch.ops.aten.clamp_max.Tensor(clamp_min, arg2_1); clamp_min = arg2_1 = None copy_: "f32[128]" = torch.ops.aten.copy_.default(arg3_1, clamp_max); arg3_1 = clamp_max = None return (copy_,) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/124926 Approved by: https://github.com/jgong5, https://github.com/jansel, https://github.com/angelayi --- .../aot_inductor_torchbench_inference.csv | 2 +- .../aot_inductor_torchbench_inference.csv | 2 +- test/inductor/test_torchinductor.py | 131 +++++++++++------- torch/_inductor/compile_fx.py | 18 ++- torch/export/_unlift.py | 10 +- 5 files changed, 106 insertions(+), 57 deletions(-) diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_torchbench_inference.csv index 40382a4f277c..65c905837c2a 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_torchbench_inference.csv @@ -242,7 +242,7 @@ pyhpc_equation_of_state,pass,0 -pyhpc_isoneutral_mixing,fail_to_run,0 +pyhpc_isoneutral_mixing,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_inductor_torchbench_inference.csv index 40382a4f277c..65c905837c2a 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_inductor_torchbench_inference.csv @@ -242,7 +242,7 @@ pyhpc_equation_of_state,pass,0 -pyhpc_isoneutral_mixing,fail_to_run,0 +pyhpc_isoneutral_mixing,pass,0 diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 1ae24869e4df..cc111cf83898 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -233,6 +233,23 @@ def run_with_backward(): return run_and_get_code(run_with_backward) +def register_ops_with_aoti_compile(ns, op_set, dispatch_key, torch_compile_op_lib_impl): + for _op_name in op_set: + qualified_op_name = f"{ns}::{_op_name}" + _, overload_names = torch._C._jit_get_operation(qualified_op_name) + for overload_name in overload_names: + try: + reg_op_name = qualified_op_name + schema = torch._C._get_schema(qualified_op_name, overload_name) + if schema.overload_name: + reg_op_name = f"{qualified_op_name}.{schema.overload_name}" + torch_compile_op_lib_impl._impl_with_aoti_compile( # noqa: F821 + reg_op_name, dispatch_key + ) + except Exception as e: + continue + + class TestCase(InductorTestCase): @classmethod def setUpClass(cls): @@ -751,6 +768,58 @@ def fn(a, b): ), ) + @skipCUDAIf(not SM80OrLater, "Requires sm80") + def test_eager_aoti_support_out(self): + ns = "aten" + op_name = "clamp" + dispatch_key = "CPU" + device = "cpu" + if self.device.lower() == "cuda": + dispatch_key = "CUDA" + device = "cuda" + + inp_tensor = torch.randn(128, dtype=torch.float, device=device).fill_(1.0) + min_tensor = inp_tensor - 0.05 + max_tensor = inp_tensor + 0.05 + with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl: + ref_out_tensor = torch.randn(128, dtype=torch.float, device=device).fill_( + -1 + ) + ref_tensor = torch.clamp( + max=max_tensor, min=min_tensor, input=inp_tensor, out=ref_out_tensor + ) + + ref_out_tensor1 = torch.randn(128, dtype=torch.float, device=device).fill_( + -1 + ) + ref_tensor1 = torch.clamp( + max=max_tensor, out=ref_out_tensor1, min=min_tensor, input=inp_tensor + ) + + register_ops_with_aoti_compile( + ns, [op_name], dispatch_key, torch_compile_op_lib_impl + ) + + res_out_tensor = torch.randn(128, dtype=torch.float, device=device).fill_( + -1 + ) + res_tensor = torch.clamp( + max=max_tensor, min=min_tensor, input=inp_tensor, out=res_out_tensor + ) + + self.assertEqual(ref_tensor, res_tensor) + self.assertEqual(ref_out_tensor, res_out_tensor) + + res_out_tensor1 = torch.randn(128, dtype=torch.float, device=device).fill_( + -1 + ) + res_tensor1 = torch.clamp( + max=max_tensor, out=res_out_tensor1, min=min_tensor, input=inp_tensor + ) + + self.assertEqual(ref_tensor1, res_tensor1) + self.assertEqual(ref_out_tensor1, res_out_tensor1) + @skipCUDAIf(not SM80OrLater, "Requires sm80") def test_eager_aoti_cache_hit(self): ns = "aten" @@ -779,24 +848,13 @@ def test_eager_aoti_cache_hit(self): with mock.patch( "torch._inductor.utils.aoti_compile_with_persistent_cache", None ): - qualified_op_name = f"{ns}::{op_name}" - _, overload_names = torch._C._jit_get_operation(qualified_op_name) - with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl: # Get ref result from eager ref_value = getattr(torch.ops.aten, op_name)(input_tensor) - for overload_name in overload_names: - try: - reg_op_name = qualified_op_name - schema = torch._C._get_schema(qualified_op_name, overload_name) - if schema.overload_name: - reg_op_name = f"{qualified_op_name}.{schema.overload_name}" - torch_compile_op_lib_impl._impl_with_aoti_compile( # noqa: F821 - reg_op_name, dispatch_key - ) - except Exception as e: - continue + register_ops_with_aoti_compile( + ns, [op_name], dispatch_key, torch_compile_op_lib_impl + ) # Invoke the pre-compiled kernel and get result. res_value = getattr(torch.ops.aten, op_name)(input_tensor) @@ -804,7 +862,7 @@ def test_eager_aoti_cache_hit(self): self.assertEqual(ref_value, res_value) @skipCUDAIf(not SM80OrLater, "Requires sm80") - def test_aoti_compile_with_persistent_cache(self): + def test_eager_aoti_with_persistent_cache(self): def fn(a): return torch.abs(a) @@ -906,19 +964,9 @@ def test_eager_aoti_with_scalar(self): for scalar_value in scalar_values: ref_values.append(torch.add(a, b, alpha=scalar_value)) - qualified_op_name = f"{namespace_name}::{op_name}" - _, overload_names = torch._C._jit_get_operation(qualified_op_name) - for overload_name in overload_names: - try: - reg_op_name = qualified_op_name - schema = torch._C._get_schema(reg_op_name, overload_name) - if schema.overload_name: - reg_op_name = f"{reg_op_name}.{schema.overload_name}" - torch_compile_op_lib_impl._impl_with_aoti_compile( # noqa: F821 - reg_op_name, dispatch_key - ) - except Exception as e: - continue + register_ops_with_aoti_compile( + namespace_name, [op_name], dispatch_key, torch_compile_op_lib_impl + ) res_values = [] for scalar_value in scalar_values: @@ -928,8 +976,7 @@ def test_eager_aoti_with_scalar(self): self.assertEqual(ref_values, res_values) @skipCUDAIf(not SM80OrLater, "Requires sm80") - def test_torch_compile_override_registration(self): - dynamic = False + def test_eager_aoti_override_registration(self): namespace_name = "aten" dispatch_key = "CPU" device = torch.device("cpu") @@ -951,24 +998,10 @@ def fn(x, op_name=""): ref = opt_fn(x) ref_array.append(ref) - def register_ops(op_set, dispatch_key, torch_compile_op_lib_impl): - for _op_name in op_set: - qualified_op_name = f"{namespace_name}::{_op_name}" - _, overload_names = torch._C._jit_get_operation(qualified_op_name) - for overload_name in overload_names: - try: - reg_op_name = qualified_op_name - schema = torch._C._get_schema(qualified_op_name, overload_name) - if schema.overload_name: - reg_op_name = f"{qualified_op_name}.{schema.overload_name}" - torch_compile_op_lib_impl._impl_with_aoti_compile( # noqa: F821 - reg_op_name, dispatch_key - ) - except Exception as e: - continue - with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl: - register_ops(unary_op_set, dispatch_key, torch_compile_op_lib_impl) + register_ops_with_aoti_compile( + namespace_name, unary_op_set, dispatch_key, torch_compile_op_lib_impl + ) res_array = [] for unary_op_name in unary_op_set: @@ -985,7 +1018,9 @@ def register_ops(op_set, dispatch_key, torch_compile_op_lib_impl): ref_with_min_max = torch.ops.aten.clamp(a, min_tensor, max_tensor) with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl: - register_ops(["clamp"], dispatch_key, torch_compile_op_lib_impl) + register_ops_with_aoti_compile( + namespace_name, ["clamp"], dispatch_key, torch_compile_op_lib_impl + ) res_with_min = torch.ops.aten.clamp(a, min_tensor) res_with_min_max = torch.ops.aten.clamp(a, min_tensor, max_tensor) self.assertEqual(ref_with_min, res_with_min) diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index d0069c5cc219..dc3b1b811a6b 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -201,11 +201,19 @@ def _unlift_graph(mod, gm, graph_signature): outputs = list(gm.graph.nodes)[-1].args[0] mutated_outputs = [] - for out in outputs: - if out.name in graph_signature.buffers_to_mutate: - mutated_outputs.append(graph_signature.buffers_to_mutate[out.name]) - else: - mutated_outputs.append(None) + buffer_mutations = graph_signature.buffers_to_mutate + user_input_mutations = graph_signature.user_inputs_to_mutate + output_tokens = graph_signature.output_tokens + for idx, out in enumerate(outputs): + value = None + + if idx < len(buffer_mutations) + len(user_input_mutations) + len(output_tokens): + if out.name in buffer_mutations: + value = buffer_mutations[out.name] + elif out.name in user_input_mutations: + value = user_input_mutations[out.name] + + mutated_outputs.append(value) unlifted_gm = _unlift( gm, diff --git a/torch/export/_unlift.py b/torch/export/_unlift.py index 97df0562caa7..5a8f144b04e0 100644 --- a/torch/export/_unlift.py +++ b/torch/export/_unlift.py @@ -86,6 +86,7 @@ def _insert_copy_for_mutations( assert len(outputs) == len(mutated_outputs) user_output_nodes = [] + return_nodes_to_copy = {} for return_node, mutated_node_name in zip(outputs, mutated_outputs): if mutated_node_name is None: user_output_nodes.append(return_node) @@ -101,13 +102,18 @@ def _insert_copy_for_mutations( ) with gm.graph.inserting_before(output_node): - _ = gm.graph.call_function( + copy_node = gm.graph.call_function( torch.ops.aten.copy_.default, (mutated_node, return_node) ) + return_nodes_to_copy[return_node] = copy_node + output_args = [ + return_nodes_to_copy[node] if node in return_nodes_to_copy else node + for node in user_output_nodes + ] with gm.graph.inserting_before(output_node): # Only return user outputs - new_output = gm.graph.output(tuple(user_output_nodes)) + new_output = gm.graph.output(tuple(output_args)) output_node.replace_all_uses_with(new_output) gm.graph.erase_node(output_node) From fe39c07826bae984f4de84d12dc048253d0a7d53 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Mon, 10 Jun 2024 16:50:00 -0700 Subject: [PATCH 619/706] [pipelining][doc] Remove duplicated words (#128368) "for execution" is used in both step titles Pull Request resolved: https://github.com/pytorch/pytorch/pull/128368 Approved by: https://github.com/wconstab ghstack dependencies: #128361 --- docs/source/distributed.pipelining.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/distributed.pipelining.rst b/docs/source/distributed.pipelining.rst index 36e90d3aa0f7..e1d66d223b2b 100644 --- a/docs/source/distributed.pipelining.rst +++ b/docs/source/distributed.pipelining.rst @@ -62,8 +62,8 @@ Overall, the ``pipelining`` package provides the following features: application on the Llama model. -Step 1: build ``PipelineStage`` for execution -********************************************* +Step 1: build ``PipelineStage`` +******************************* Before we can use a ``PipelineSchedule``, we need to create ``PipelineStage`` objects that wrap the part of the model running in that stage. The From fa88f390a04bc59aefad555b1fc631a33b42f2cb Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 11 Jun 2024 04:53:37 +0000 Subject: [PATCH 620/706] Revert "[inductor] enable fx graph cache on torchbench (#128239)" This reverts commit 734e8f6ad7e7f0fa0341fb658f1f986225173f5f. Reverted https://github.com/pytorch/pytorch/pull/128239 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it seems to surface a bunch of inductor failures in trunk https://hud.pytorch.org/pytorch/pytorch/commit/734e8f6ad7e7f0fa0341fb658f1f986225173f5f ([comment](https://github.com/pytorch/pytorch/pull/128239#issuecomment-2159789242)) --- benchmarks/dynamo/torchbench.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/benchmarks/dynamo/torchbench.py b/benchmarks/dynamo/torchbench.py index d7877c5a3fac..fd11e984bbdc 100755 --- a/benchmarks/dynamo/torchbench.py +++ b/benchmarks/dynamo/torchbench.py @@ -25,10 +25,6 @@ # We are primarily interested in tf32 datatype torch.backends.cuda.matmul.allow_tf32 = True -# Enable FX graph caching -if "TORCHINDUCTOR_FX_GRAPH_CACHE" not in os.environ: - torch._inductor.config.fx_graph_cache = True - def _reassign_parameters(model): # torch_geometric models register parameter as tensors due to From 5b5d269d341b61d0d6dbae86a4f3bca24f630e72 Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Mon, 10 Jun 2024 15:23:28 -0700 Subject: [PATCH 621/706] Speed up fx graph iteration by implementing it in C++ (#128288) Before this change ``` python benchmarks/dynamo/microbenchmarks/fx_microbenchmarks.py iterating over 100000000 FX nodes took 19.5s (5132266 nodes/s) ``` After this change ``` python benchmarks/dynamo/microbenchmarks/fx_microbenchmarks.py iterating over 100000000 FX nodes took 3.4s (29114001 nodes/s) ``` 5.7x improvement Differential Revision: [D58343997](https://our.internmc.facebook.com/intern/diff/D58343997) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128288 Approved by: https://github.com/jansel, https://github.com/albanD --- build_variables.bzl | 1 + torch/_C/__init__.pyi.in | 11 ++ torch/csrc/Module.cpp | 3 + torch/csrc/fx/node.cpp | 257 +++++++++++++++++++++++++++++++++++++++ torch/csrc/fx/node.h | 6 + torch/fx/graph.py | 17 +-- torch/fx/node.py | 23 +++- 7 files changed, 300 insertions(+), 18 deletions(-) create mode 100644 torch/csrc/fx/node.cpp create mode 100644 torch/csrc/fx/node.h diff --git a/build_variables.bzl b/build_variables.bzl index 20822ba95cf2..323588c15b4c 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -827,6 +827,7 @@ libtorch_python_core_sources = [ "torch/csrc/dynamo/guards.cpp", "torch/csrc/dynamo/init.cpp", "torch/csrc/functorch/init.cpp", + "torch/csrc/fx/node.cpp", "torch/csrc/mps/Module.cpp", "torch/csrc/mtia/Module.cpp", "torch/csrc/inductor/aoti_runner/pybind.cpp", diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 6f719d7db20a..30a4fb6c36c6 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -2333,3 +2333,14 @@ def _save_pickle(obj: Any) -> bytes: ... # Defined in torch/csrc/jit/runtime/static/init.cpp def _jit_to_static_module(graph_or_module: Union[Graph,ScriptModule]) -> Any: ... def _fuse_to_static_module(graph_or_module: Union[Graph,ScriptModule], min_size: _int) -> Any: ... + +# Defined in torch/csrc/fx/node.cpp +class _NodeBase: + _erased: _bool + _prev: "_NodeBase" + _next: "_NodeBase" + +class _NodeIter(Iterator): + def __init__(self, root: _NodeBase, reversed: _bool) -> None: ... + def __iter__(self) -> Iterator[_NodeBase]: ... + def __next__(self) -> _NodeBase: ... diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 7e509fce7af2..57b28d676484 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -67,6 +67,7 @@ #include #include #include +#include #include #include #include @@ -1602,6 +1603,8 @@ PyObject* initModule() { THPDevice_init(module); THPStream_init(module); THPEvent_init(module); + NodeBase_init(module); + NodeIter_init(module); ASSERT_TRUE(THPVariable_initModule(module)); ASSERT_TRUE(THPFunction_initModule(module)); ASSERT_TRUE(THPEngine_initModule(module)); diff --git a/torch/csrc/fx/node.cpp b/torch/csrc/fx/node.cpp new file mode 100644 index 000000000000..dc96737abdab --- /dev/null +++ b/torch/csrc/fx/node.cpp @@ -0,0 +1,257 @@ +#include + +#include +#include + +//////////////////////////////// +// NodeBase +/////////////////////////////// + +struct NodeBase { + PyObject_HEAD bool _erased; + NodeBase* _prev; + NodeBase* _next; +}; + +static PyObject* NodeBase_new( + PyTypeObject* type, + PyObject* args, + PyObject* kwds) { + PyObject* self = type->tp_alloc(type, 0); + if (!self) + return nullptr; + return self; +} + +static int NodeBase_init_fn(NodeBase* self, PyObject* args, PyObject* kwds) { + self->_erased = false; + Py_INCREF(self); + self->_prev = self; + Py_INCREF(self); + self->_next = self; + return 0; +} + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) +static struct PyMemberDef NodeBase_members[] = { + {"_erased", T_BOOL, offsetof(NodeBase, _erased), 0, nullptr}, + {"_prev", T_OBJECT_EX, offsetof(NodeBase, _prev), 0, nullptr}, + {"_next", T_OBJECT_EX, offsetof(NodeBase, _next), 0, nullptr}, + {nullptr} /* Sentinel */ +}; + +static int NodeBase_traverse(NodeBase* self, visitproc visit, void* arg) { + Py_VISIT(self->_prev); + Py_VISIT(self->_next); + return 0; +} + +static int NodeBase_clear(NodeBase* self) { + Py_CLEAR(self->_prev); + Py_CLEAR(self->_next); + return 0; +} + +static void NodeBase_dealloc(PyObject* self) { + PyObject_GC_UnTrack(self); + (void)NodeBase_clear((NodeBase*)self); + Py_TYPE(self)->tp_free(self); +} + +static PyTypeObject NodeBaseType = { + PyVarObject_HEAD_INIT(nullptr, 0) "torch._C._NodeBase", /* tp_name */ + sizeof(NodeBase), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)NodeBase_dealloc, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + nullptr, /* tp_getattr */ + nullptr, /* tp_setattr */ + nullptr, /* tp_reserved */ + nullptr, /* tp_repr */ + nullptr, /* tp_as_number */ + nullptr, /* tp_as_sequence */ + nullptr, /* tp_as_mapping */ + nullptr, /* tp_hash */ + nullptr, /* tp_call */ + nullptr, /* tp_str */ + nullptr, /* tp_getattro */ + nullptr, /* tp_setattro */ + nullptr, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | + Py_TPFLAGS_HAVE_GC, /* tp_flags */ + nullptr, /* tp_doc */ + (traverseproc)NodeBase_traverse, /* tp_traverse */ + (inquiry)NodeBase_clear, /* tp_clear */ + nullptr, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + nullptr, /* tp_iter */ + nullptr, /* tp_iternext */ + nullptr, /* tp_methods */ + NodeBase_members, /* tp_members */ + nullptr, /* tp_getset */ + nullptr, /* tp_base */ + nullptr, /* tp_dict */ + nullptr, /* tp_descr_get */ + nullptr, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)NodeBase_init_fn, /* tp_init */ + nullptr, /* tp_alloc */ + NodeBase_new, /* tp_new */ +}; + +bool NodeBase_init(PyObject* module) { + if (PyModule_AddType(module, &NodeBaseType) < 0) { + return false; + } + return true; +} + +//////////////////////////////// +// NodeIter +//////////////////////////////// + +struct NodeIter { + PyObject_HEAD bool _reversed; + NodeBase* _root; + NodeBase* _cur; +}; + +static PyObject* NodeIter_new( + PyTypeObject* type, + PyObject* args, + PyObject* kwds) { + PyObject* self = type->tp_alloc(type, 0); + if (!self) + return nullptr; + return self; +} + +static int NodeIter_init_fn(NodeIter* self, PyObject* args, PyObject* kwargs) { + NodeBase* root = nullptr; + bool reversed = false; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) + constexpr const char* keywords[] = {"root", "reversed", nullptr}; + if (!PyArg_ParseTupleAndKeywords( + args, + kwargs, + "Ob|", + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + const_cast(keywords), + &root, + &reversed)) { + return -1; + } + self->_reversed = reversed; + Py_INCREF(root); + self->_root = root; + Py_INCREF(root); + self->_cur = root; + return 0; +} + +template +PyObject* NodeIter_iternext_helper(NodeIter* self) { + // It should be possible to relax the ref counting here + // but in practice, we do not have that many _erased Nodes, + // so probably not worth it. + if constexpr (reversed) { + NodeBase* prev = (NodeBase*)Py_NewRef(self->_cur->_prev); + Py_CLEAR(self->_cur); + self->_cur = prev; + } else { + NodeBase* next = (NodeBase*)Py_NewRef(self->_cur->_next); + Py_CLEAR(self->_cur); + self->_cur = next; + } + while (self->_cur != self->_root) { + if (!self->_cur->_erased) { + Py_INCREF(self->_cur); + return (PyObject*)self->_cur; + } + if constexpr (reversed) { + NodeBase* prev = (NodeBase*)Py_NewRef(self->_cur->_prev); + Py_CLEAR(self->_cur); + self->_cur = prev; + } else { + NodeBase* next = (NodeBase*)Py_NewRef(self->_cur->_next); + Py_CLEAR(self->_cur); + self->_cur = next; + } + } + PyErr_SetNone(PyExc_StopIteration); + return nullptr; +} + +PyObject* NodeIter_iternext(PyObject* _self) { + NodeIter* self = (NodeIter*)_self; + if (self->_reversed) { + return NodeIter_iternext_helper(self); + } else { + return NodeIter_iternext_helper(self); + } +} + +static int NodeIter_traverse(NodeIter* self, visitproc visit, void* arg) { + Py_VISIT(self->_root); + Py_VISIT(self->_cur); + return 0; +} + +static int NodeIter_clear(NodeIter* self) { + Py_CLEAR(self->_root); + Py_CLEAR(self->_cur); + return 0; +} + +static void NodeIter_dealloc(PyObject* self) { + PyObject_GC_UnTrack(self); + (void)NodeIter_clear((NodeIter*)self); + Py_TYPE(self)->tp_free(self); +} + +static PyTypeObject NodeIterType = { + PyVarObject_HEAD_INIT(nullptr, 0) "torch._C._NodeIter", /* tp_name */ + sizeof(NodeIter), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)NodeIter_dealloc, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + nullptr, /* tp_getattr */ + nullptr, /* tp_setattr */ + nullptr, /* tp_reserved */ + nullptr, /* tp_repr */ + nullptr, /* tp_as_number */ + nullptr, /* tp_as_sequence */ + nullptr, /* tp_as_mapping */ + nullptr, /* tp_hash */ + nullptr, /* tp_call */ + nullptr, /* tp_str */ + nullptr, /* tp_getattro */ + nullptr, /* tp_setattro */ + nullptr, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, /* tp_flags */ + nullptr, /* tp_doc */ + (traverseproc)NodeIter_traverse, /* tp_traverse */ + (inquiry)NodeIter_clear, /* tp_clear */ + nullptr, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + PyObject_SelfIter, /* tp_iter */ + NodeIter_iternext, /* tp_iternext */ + nullptr, /* tp_methods */ + nullptr, /* tp_members */ + nullptr, /* tp_getset */ + nullptr, /* tp_base */ + nullptr, /* tp_dict */ + nullptr, /* tp_descr_get */ + nullptr, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)NodeIter_init_fn, /* tp_init */ + nullptr, /* tp_alloc */ + NodeIter_new, /* tp_new */ +}; + +bool NodeIter_init(PyObject* module) { + if (PyModule_AddType(module, &NodeIterType) < 0) { + return false; + } + return true; +} diff --git a/torch/csrc/fx/node.h b/torch/csrc/fx/node.h new file mode 100644 index 000000000000..2ea74e839f25 --- /dev/null +++ b/torch/csrc/fx/node.h @@ -0,0 +1,6 @@ +#pragma once + +#include + +bool NodeBase_init(PyObject* module); +bool NodeIter_init(PyObject* module); diff --git a/torch/fx/graph.py b/torch/fx/graph.py index dea8265f134d..9e034278ccb1 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -4,6 +4,7 @@ import torch.utils._pytree as pytree from . import _pytree as fx_pytree from ._compatibility import compatibility +from torch._C import _NodeIter import os import contextlib @@ -271,20 +272,8 @@ def __len__(self): return self.graph._len def __iter__(self): - root = self.graph._root - if self.direction == "_next": - cur = root._next - while cur is not root: - if not cur._erased: - yield cur - cur = cur._next - else: - assert self.direction == "_prev" - cur = root._prev - while cur is not root: - if not cur._erased: - yield cur - cur = cur._prev + assert self.direction == "_prev" or self.direction == "_next" + yield from _NodeIter(self.graph._root, self.direction == "_prev") def __reversed__(self): return _node_list(self.graph, '_next' if self.direction == '_prev' else '_prev') diff --git a/torch/fx/node.py b/torch/fx/node.py index 8b4768aa497a..2e400158b551 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -11,6 +11,7 @@ import warnings from torch.fx.operator_schemas import normalize_function, normalize_module, ArgsKwargsPair from .._ops import ops as _ops +from torch._C import _NodeBase if TYPE_CHECKING: from .graph import Graph @@ -139,7 +140,7 @@ def _format_arg(arg, max_list_len=float('inf')) -> str: return str(arg) @compatibility(is_backward_compatible=True) -class Node: +class Node(_NodeBase): """ ``Node`` is the data structure that represents individual operations within a ``Graph``. For the most part, Nodes represent callsites to various entities, @@ -197,6 +198,7 @@ def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target', annotation of values in the generated code or for other types of analyses. """ + super().__init__() self.graph = graph self.name = name # unique name of value being created assert op in ['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output', 'root'] @@ -235,9 +237,6 @@ def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target', # does not produce a value, it's more of a notation. Thus, this value # describes the type of args[0] in the ``return`` node. self.type : Optional[Any] = return_type - self._prev = self - self._next = self - self._erased = False self._sort_key: Any = () # If set, use this fn to print this node @@ -247,6 +246,22 @@ def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target', # transformations. This metadata is preserved across node copies self.meta : Dict[str, Any] = {} + def __getstate__(self): + state = self.__dict__.copy() + state["_erased"] = self._erased + state["_prev"] = self._prev + state["_next"] = self._next + return state + + def __setstate__(self, state): + _erased = state.pop("_erased") + _prev = state.pop("_prev") + _next = state.pop("_next") + self.__dict__.update(state) + self._erased = _erased + self._prev = _prev + self._next = _next + @property def next(self) -> 'Node': """ From 24e7f290993c842dd3680eade1011fd33cb8d108 Mon Sep 17 00:00:00 2001 From: Lourenco Matos Date: Tue, 11 Jun 2024 06:39:02 +0000 Subject: [PATCH 622/706] Lowering for avg_pool_3d_backward (Fixes:#127101) (#127722) We implemented a lowering for the avg_pool3d_backward operation and created tests for it. We ran some benchmarks and achieved the following results: ``` [-------------- avgpool_3d_backwards --------------] | Decomposed | Eager 16 threads: ---------------------------------------- (3, 5, 400, 200, 200) | 6061 | 11160 (3, 5, 300, 200, 200) | 4547 | 8372 (3, 5, 200, 200, 200) | 3032 | 5585 (3, 5, 300, 300, 300) | 10100 | 18840 (3, 5, 100, 100, 100) | 381 | 703 (3, 5, 100, 300, 200) | 2270 | 4190 (8, 8, 128, 128, 128) | 3397 | 6253 (2, 3, 150, 150, 150) | 520 | 947 (1, 3, 128, 128, 128) | 161 | 299 (8, 16, 64, 64, 64) | 851 | 1569 (1, 1, 50, 50, 50) | 17 | 11 (3, 5, 20, 40, 40) | 17 | 30 (3, 5, 10, 20, 20) | 17 | 11 (1, 1, 10, 10, 10) | 16 | 11 (3, 5, 5, 10, 10) | 17 | 11 (3, 5, 2, 5, 5) | 17 | 11 ``` These were run on an RTX 3050, so we were not able to allocate larger tensors due to memory limitations. We believe it would be beneficial to benchmark this on more recent hardware, just to check if the performance holds up with larger sizes. Furthermore, we also refactored code from adaptive_avg_pool2d and adaptive_max_pool2d, to reduce code duplication. We diffed the kernels and they are identical. Fixes #127101 Co-authored-by: Martim Mendes Pull Request resolved: https://github.com/pytorch/pytorch/pull/127722 Approved by: https://github.com/jansel --- test/inductor/test_torchinductor.py | 89 ++++ ...st_torchinductor_codegen_dynamic_shapes.py | 1 + torch/_inductor/lowering.py | 424 ++++++++++++++---- 3 files changed, 416 insertions(+), 98 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index cc111cf83898..a12d68c18c74 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -7851,6 +7851,95 @@ def fn(a, b): ) assertGeneratedKernelCountEqual(self, 0) + def test_avg_pool3d_backward(self): + def fn(a, b): + return aten.avg_pool3d_backward( + a, + b, + [2, 2, 2], + [2, 2, 2], + [0, 0, 0], + True, + False, + None, + ) + + self.common( + fn, + [ + torch.randn([2, 4, 7, 7, 7]), + torch.randn([2, 4, 14, 14, 14]), + ], + ) + + def test_avg_pool3d_backward2(self): + def fn(a, b): + return aten.avg_pool3d_backward( + a, + b, + [3, 3, 3], + [1, 1, 1], + [1, 1, 1], + True, + False, + None, + ) + + self.common( + fn, + [ + torch.randn([1, 1, 20, 20, 15]), + torch.randn([1, 1, 20, 20, 15]), + ], + ) + + def test_avg_pool3d_backward3(self): + def fn(a, b): + return aten.avg_pool3d_backward( + a, + b, + [1, 1, 1], + [2, 2, 2], + [0, 0, 0], + False, + False, + None, + ) + + torch._inductor.metrics.generated_kernel_count = 0 + self.common( + fn, + [ + torch.randn([1, 2016, 11, 11, 11]), + torch.randn([1, 2016, 21, 21, 21]), + ], + ) + assertGeneratedKernelCountEqual(self, 1) + + def test_avg_pool3d_backward4(self): + def fn(a, b): + return aten.avg_pool3d_backward( + a, + b, + [13, 13, 13], + [1, 1, 1], + [0, 0, 0], + True, + False, + None, + ) + + torch._inductor.metrics.generated_kernel_count = 0 + self.common( + fn, + [ + torch.randn([1, 16, 12, 12, 12]), + torch.randn([1, 16, 24, 24, 24]), + ], + check_lowp=False, + ) + assertGeneratedKernelCountEqual(self, 0) + @config.patch(search_autotune_cache=False) def test_mm_views(self): def fn(a, b): diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py index 58deed4460d8..1f641389bfd0 100644 --- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py @@ -146,6 +146,7 @@ def run(*ex, **kwargs): "test_argmax_to_float_dynamic_shapes": TestFailure(("cpu", "cuda")), "test_avg_pool2d7_dynamic_shapes": TestFailure(("cpu", "cuda")), "test_avg_pool2d_backward4_dynamic_shapes": TestFailure(("cpu", "cuda")), + "test_avg_pool3d_backward4_dynamic_shapes": TestFailure(("cpu", "cuda")), "test_baddbmm_dynamic_shapes": TestFailure(("cpu", "cuda")), "test_bmm2_dynamic_shapes": TestFailure(("cpu", "cuda")), "test_both_scalars_dynamic_shapes": TestFailure(("cpu", "cuda")), diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index f9f1bca3d920..bff5fb4073af 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -2155,7 +2155,6 @@ def is_aligned(x): # 4) Backwards (try py_impl'ing them) when fwd is written as a decomp -make_fallback(aten.avg_pool3d_backward) make_fallback(aten.max_pool3d_with_indices_backward) make_fallback(aten._adaptive_avg_pool2d_backward, require_dense) make_fallback(aten._adaptive_avg_pool3d_backward) @@ -4034,11 +4033,32 @@ def load(prefix, increments, start_indices, end_indices): return load -def _adaptive_pooling_idx_sum(kernel_maxes, start_index_fns, end_index_fns): - h_start_index_fn, w_start_index_fn = start_index_fns - h_end_index_fn, w_end_index_fn = end_index_fns +def compute_indices_adaptive_pooling(start_index, end_index, h_in, w_in, h_out, w_out): + h_start_index = functools.partial(start_index, out_dim=h_out, inp_dim=h_in) + h_end_index = functools.partial(end_index, out_dim=h_out, inp_dim=h_in) - def fn_sum(idx, loader): + w_start_index = functools.partial(start_index, out_dim=w_out, inp_dim=w_in) + w_end_index = functools.partial(end_index, out_dim=w_out, inp_dim=w_in) + + return h_start_index, h_end_index, w_start_index, w_end_index + + +def _adaptive_pooling_fn( + start_index, end_index, kernel_maxes, in_sizes, out_sizes, pooling_fn +): + h_in, w_in = in_sizes + h_out, w_out = out_sizes + + ( + h_start_index_fn, + h_end_index_fn, + w_start_index_fn, + w_end_index_fn, + ) = compute_indices_adaptive_pooling( + start_index, end_index, h_in, w_in, h_out, w_out + ) + + def fn(idx, loader): *prefix, bh, bw = idx h_start_index = h_start_index_fn(bh) @@ -4047,7 +4067,7 @@ def fn_sum(idx, loader): w_start_index = w_start_index_fn(bw) w_end_index = w_end_index_fn(bw) - total = None + result = None for ih, iw in itertools.product(range(kernel_maxes[0]), range(kernel_maxes[1])): val = loader( prefix, @@ -4055,13 +4075,66 @@ def fn_sum(idx, loader): [h_start_index, w_start_index], [h_end_index, w_end_index], ) - if total is None: - total = val + if result is None: + result = val else: - total = ops.add(val, total) - return total + result = pooling_fn(val, result) + return result + + return fn + + +def _adaptive_pooling_fn_with_idx( + start_index, end_index, kernel_maxes, in_sizes, out_sizes, pooling_fn +): + h_in, w_in = in_sizes + h_out, w_out = out_sizes + + ( + h_start_index_fn, + h_end_index_fn, + w_start_index_fn, + w_end_index_fn, + ) = compute_indices_adaptive_pooling( + start_index, end_index, h_in, w_in, h_out, w_out + ) + + def fn(idx, loader): + *prefix, bh, bw = idx + + h_start_index = h_start_index_fn(bh) + h_end_index = h_end_index_fn(bh) + + w_start_index = w_start_index_fn(bw) + w_end_index = w_end_index_fn(bw) + + maxval = None + maxindex = None + for ih, iw in itertools.product(range(kernel_maxes[0]), range(kernel_maxes[1])): + val = loader( + prefix, + [ih, iw], + [h_start_index, w_start_index], + [h_end_index, w_end_index], + ) + + index = ops.index_expr( + (h_start_index + ih) * w_in + w_start_index + iw, torch.int64 + ) + + if maxindex is None: + maxindex = index + else: + maxindex = ops.where(ops.gt(val, maxval), index, maxindex) - return fn_sum + if maxval is None: + maxval = val + else: + maxval = pooling_fn(val, maxval) + + return maxindex + + return fn fallback_adaptive_avg_pool2d = fallback_handler( @@ -4099,27 +4172,24 @@ def _adaptive_avg_pool2d(x, output_size): new_size = list(batch) + [h_out, w_out] dtype = x.get_dtype() + window_size = h_kernel_max * w_kernel_max + if window_size > 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. + return fallback_adaptive_avg_pool2d(x, output_size) + def start_index(index, out_dim, inp_dim): return FloorDiv((index * inp_dim), out_dim) def end_index(index, out_dim, inp_dim): return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim) - h_start_index = functools.partial(start_index, out_dim=h_out, inp_dim=h_in) - h_end_index = functools.partial(end_index, out_dim=h_out, inp_dim=h_in) - - w_start_index = functools.partial(start_index, out_dim=w_out, inp_dim=w_in) - w_end_index = functools.partial(end_index, out_dim=w_out, inp_dim=w_in) - - window_size = h_kernel_max * w_kernel_max - if window_size > 25: - # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. - return fallback_adaptive_avg_pool2d(x, output_size) - - fn_sum = _adaptive_pooling_idx_sum( - [h_kernel_max, w_kernel_max], - [h_start_index, w_start_index], - [h_end_index, w_end_index], + fn_sum = _adaptive_pooling_fn( + start_index=start_index, + end_index=end_index, + kernel_maxes=[h_kernel_max, w_kernel_max], + in_sizes=[h_in, w_in], + out_sizes=[h_out, w_out], + pooling_fn=ops.add, ) ones_loader = pad_adaptive_loader(ones_like(x)) @@ -4139,60 +4209,6 @@ def fn(idx): return rv -def _adaptive_pooling_idx_max(kernel_maxes, in_sizes, out_sizes, return_index, loader): - # NOTE: There is some duplication between this and addaptive_avg_pool2d and max_pool2d - # Look into refactoring/deduplication after #116418 is merged. - h_in, w_in = in_sizes - h_out, w_out = out_sizes - - def start_index(index, out_dim, inp_dim): - return FloorDiv((index * inp_dim), out_dim) - - def end_index(index, out_dim, inp_dim): - return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim) - - h_start_index_fn = functools.partial(start_index, out_dim=h_out, inp_dim=h_in) - h_end_index_fn = functools.partial(end_index, out_dim=h_out, inp_dim=h_in) - w_start_index_fn = functools.partial(start_index, out_dim=w_out, inp_dim=w_in) - w_end_index_fn = functools.partial(end_index, out_dim=w_out, inp_dim=w_in) - - def fn_max(idx): - *prefix, bh, bw = idx - - h_start_index = h_start_index_fn(bh) - h_end_index = h_end_index_fn(bh) - - w_start_index = w_start_index_fn(bw) - w_end_index = w_end_index_fn(bw) - maxval = None - maxindex = None - for ih, iw in itertools.product(range(kernel_maxes[0]), range(kernel_maxes[1])): - val = loader( - prefix, - [ih, iw], - [h_start_index, w_start_index], - [h_end_index, w_end_index], - ) - index = ops.index_expr( - (h_start_index + ih) * w_in + w_start_index + iw, torch.int64 - ) - if return_index: - if maxindex is None: - maxindex = index - else: - maxindex = ops.where(ops.gt(val, maxval), index, maxindex) - if maxval is None: - maxval = val - else: - maxval = ops.maximum(val, maxval) - if return_index: - return maxindex - else: - return maxval - - return fn_max - - fallback_adaptive_max_pool2d = fallback_handler( aten.adaptive_max_pool2d.default, add_to_fallback_set=False ) @@ -4245,32 +4261,46 @@ def adaptive_max_pool2d(x, output_size): # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. return fallback_adaptive_max_pool2d(x, output_size) - inner_func_max_val = _adaptive_pooling_idx_max( + def start_index(index, out_dim, inp_dim): + return FloorDiv((index * inp_dim), out_dim) + + def end_index(index, out_dim, inp_dim): + return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim) + + inner_func_max_val = _adaptive_pooling_fn( + start_index=start_index, + end_index=end_index, kernel_maxes=[h_kernel_max, w_kernel_max], in_sizes=[h_in, w_in], out_sizes=[h_out, w_out], - return_index=False, - loader=pad_adaptive_loader(x, float("-inf")), + pooling_fn=ops.maximum, ) - inner_func_max_idx = _adaptive_pooling_idx_max( + inner_func_max_idx = _adaptive_pooling_fn_with_idx( + start_index=start_index, + end_index=end_index, kernel_maxes=[h_kernel_max, w_kernel_max], in_sizes=[h_in, w_in], out_sizes=[h_out, w_out], - return_index=True, - loader=pad_adaptive_loader(x, float("-inf")), + pooling_fn=ops.maximum, ) + def inner_fn_max_val(idx): + return inner_func_max_val(idx, pad_adaptive_loader(x, float("-inf"))) + + def inner_fn_max_idx(idx): + return inner_func_max_idx(idx, pad_adaptive_loader(x, float("-inf"))) + rv = Pointwise.create( device=x.get_device(), dtype=dtype, - inner_fn=inner_func_max_val, + inner_fn=inner_fn_max_val, ranges=new_size, ) ri = Pointwise.create( device=x.get_device(), dtype=torch.int64, - inner_fn=inner_func_max_idx, + inner_fn=inner_fn_max_idx, ranges=new_size, ) return rv, ri @@ -4400,16 +4430,13 @@ def start_index(index, out_dim, inp_dim): def end_index(index, out_dim, inp_dim): return start_index((index + 1), out_dim, inp_dim) - h_start_index = functools.partial(start_index, out_dim=out_h, inp_dim=inp_h) - h_end_index = functools.partial(end_index, out_dim=out_h, inp_dim=inp_h) - - w_start_index = functools.partial(start_index, out_dim=out_w, inp_dim=inp_w) - w_end_index = functools.partial(end_index, out_dim=out_w, inp_dim=inp_w) - - fn_sum = _adaptive_pooling_idx_sum( - [h_kernel_max, w_kernel_max], - [h_start_index, w_start_index], - [h_end_index, w_end_index], + fn_sum = _adaptive_pooling_fn( + start_index=start_index, + end_index=end_index, + kernel_maxes=[h_kernel_max, w_kernel_max], + in_sizes=[inp_h, inp_w], + out_sizes=[out_h, out_w], + pooling_fn=ops.add, ) def fn(idx): @@ -4761,6 +4788,207 @@ def fn(idx): return rv +fallback_avg_pool3d_backward = fallback_handler( + aten.avg_pool3d_backward.default, add_to_fallback_set=False +) + + +@register_lowering(aten.avg_pool3d_backward, type_promotion_kind=None) +def avg_pool3d_backward( + grad_output, + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override=None, +): + assert divisor_override is None or divisor_override != 0, "divisor must be not zero" + if not stride: + stride = kernel_size + if not padding: + padding = [0, 0, 0] + + assert isinstance(grad_output, TensorBox) + assert isinstance(x, TensorBox) + assert len(kernel_size) == 3 + assert len(stride) == 3 + assert len(padding) == 3 + assert len(x.get_size()) in (4, 5) + + grad_output.realize_hint() + + *batch, depth, height, width = x.get_size() + + d_out, ceil_mode_d = pooling_size(depth, 0, kernel_size, stride, padding, ceil_mode) + h_out, ceil_mode_h = pooling_size( + height, 1, kernel_size, stride, padding, ceil_mode + ) + w_out, ceil_mode_w = pooling_size(width, 2, kernel_size, stride, padding, ceil_mode) + + grad_loader = grad_output.make_loader() + had_padding = any(padding) or ceil_mode_d or ceil_mode_h or ceil_mode_w + + *_, pooled_depth, pooled_height, pooled_width = grad_output.get_size() + new_size = list(x.get_size()) + dtype = x.get_dtype() + + d_window_size, h_window_size, w_window_size = ( + max( + max(d // stride[i] - max(0, (d - kernel_size[i]) // stride[i]), 1) + for d in range(kernel_size[i] * 2) + ) + for i in range(3) + ) + + window_size = d_window_size * h_window_size * w_window_size + if window_size > 125: + # Kernel size too big. Results in hard-to-optimize Triton code. + return fallback_avg_pool3d_backward( + grad_output, + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + ) + + def compute_pool_size_without_padding(pd, ph, pw): + stride_d, stride_h, stride_w = (ops.constant(s, torch.int32) for s in stride) + pad_d, pad_h, pad_w = (ops.constant(p, torch.int32) for p in padding) + kernel_d, kernel_h, kernel_w = ( + ops.constant(k, torch.int32) for k in kernel_size + ) + + dstart, hstart, wstart = ( + ops.sub(ops.mul(p, s), pad) + for p, s, pad in zip( + [pd, ph, pw], [stride_d, stride_h, stride_w], [pad_d, pad_h, pad_w] + ) + ) + dend, hend, wend = ( + ops.minimum( + ops.add(start, k), ops.add(ops.index_expr(dim, torch.int32), pad) + ) + for start, k, dim, pad in zip( + [dstart, hstart, wstart], + [kernel_d, kernel_h, kernel_w], + [depth, height, width], + [pad_d, pad_h, pad_w], + ) + ) + dstart, hstart, wstart = ( + ops.maximum(start, ops.constant(0, torch.int32)) + for start in [dstart, hstart, wstart] + ) + dend, hend, wend = ( + ops.minimum(end, ops.index_expr(dim, torch.int32)) + for end, dim in zip([dend, hend, wend], [depth, height, width]) + ) + divide_factor = ops.mul( + ops.mul(ops.sub(dend, dstart), ops.sub(hend, hstart)), ops.sub(wend, wstart) + ) + return divide_factor + + def fn(idx): + *prefix, d, h, w = idx + d, h, w = (v + pad for v, pad in zip([d, h, w], padding)) + + pdstart, phstart, pwstart = ( + ops.index_expr(FloorDiv(v - k + s, s), torch.int32) + for v, k, s in zip([d, h, w], kernel_size, stride) + ) + + pdend, phend, pwend = ( + ops.index_expr(FloorDiv(v, s) + 1, torch.int32) + for v, s in zip([d, h, w], stride) + ) + + pdstart, phstart, pwstart = ( + ops.maximum(pstart, ops.constant(0, torch.int32)) + for pstart in [pdstart, phstart, pwstart] + ) + pdend, phend, pwend = ( + ops.minimum(pend, ops.index_expr(pooled_dim, torch.int32)) + for pend, pooled_dim in zip( + [pdend, phend, pwend], [pooled_depth, pooled_height, pooled_width] + ) + ) + + gradient = None + # Iterate over the 3D region to accumulate gradients + for pd_ in range(d_window_size): + for ph_ in range(h_window_size): + for pw_ in range(w_window_size): + pd, ph, pw = ( + ops.add(pstart, ops.constant(p_, torch.int32)) + for pstart, p_ in zip( + [pdstart, phstart, pwstart], [pd_, ph_, pw_] + ) + ) + + if divisor_override is not None: + scale = divisor_override + elif count_include_pad or not had_padding: + scale = kernel_size[0] * kernel_size[1] * kernel_size[2] + else: + scale = compute_pool_size_without_padding(pd, ph, pw) + + part = ops.truediv( + grad_loader( + [ + *prefix, + ops.indirect_indexing( + ops.minimum( + pd, ops.sub(pdend, ops.constant(1, torch.int32)) + ), + pooled_depth, + check=False, + ), + ops.indirect_indexing( + ops.minimum( + ph, ops.sub(phend, ops.constant(1, torch.int32)) + ), + pooled_height, + check=False, + ), + ops.indirect_indexing( + ops.minimum( + pw, ops.sub(pwend, ops.constant(1, torch.int32)) + ), + pooled_width, + check=False, + ), + ] + ), + scale, + ) + + mask = ops.and_( + ops.and_(ops.lt(pd, pdend), ops.lt(ph, phend)), + ops.lt(pw, pwend), + ) + if gradient is None: + gradient = ops.where( + mask, part, ops.constant(0.0, torch.float32) + ) + else: + gradient = ops.where(mask, ops.add(gradient, part), gradient) + assert gradient is not None + return gradient + + rv = Pointwise.create( + device=grad_output.get_device(), + dtype=dtype, + inner_fn=fn, + ranges=new_size, + ) + return rv + + def _validate_reduction_axis(x, axis): size = x.get_size() if isinstance(axis, int): From a32157c67c179ba13d20da0a635a0d4c6d179deb Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Mon, 10 Jun 2024 19:21:00 -0700 Subject: [PATCH 623/706] Mark params static if inlining modules and freezing (#128355) Today inlining builtin nn modules is not compatible with parameter freezing. Freezing parameters and then constant folding them through the graph relies on the assumption that they will not be inputs and will be static across calls to the same graph. When inlining builtin nn modules this assumption is broken and we reuse the same graph for different instances of the same nn module. There are three options 1) abandon constant folding, 2) create a dispatcher layer (like cudagraphs) which will dispatch to the correct constant-folded graph for each distinct set of parameters or 3) recompile This PR implements 3 by introducing guards on the parameter pointers. This was due to freezing being relatively rare and performance sensistive. 2 Had many more unknowns and 1 is not a viable option due to the drop in performance. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128355 Approved by: https://github.com/anijain2305 --- torch/_dynamo/variables/builder.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index f31f5c97eb62..2f10fba1b370 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -1128,6 +1128,19 @@ def wrap_module(self, value: torch.nn.Module): if mutation_guard.is_dynamic_nn_module(value, self.tx.export): # created dynamically, don't specialize on it self.install_guards(GuardBuilder.TYPE_MATCH) + if ( + torch._dynamo.config.inline_inbuilt_nn_modules + and torch._inductor.config.freezing + and not torch.is_grad_enabled() + ): + from ..decorators import mark_static_address + + for p in value.parameters(): + mark_static_address(p) + + for b in value.buffers(): + mark_static_address(b) + result = UnspecializedNNModuleVariable(value, source=self.source) if not SideEffects.cls_supports_mutation_side_effects(type(value)): # don't allow STORE_ATTR mutation with custom __setattr__ From 402b289f3b8e1aa07bdc6ce4922e073477f9827c Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Mon, 10 Jun 2024 19:21:01 -0700 Subject: [PATCH 624/706] Properly register parameter for binary folding test (#128356) This PR properly registers the tensor used in the module compute as a parameter. This bug was hidden previously because all tensors on the nn modules would be considered constant by dynamo, with inlining NN modules, this is no longer the case. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128356 Approved by: https://github.com/anijain2305 ghstack dependencies: #128355 --- test/inductor/test_binary_folding.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/test/inductor/test_binary_folding.py b/test/inductor/test_binary_folding.py index a8e6392892f7..1a25e81ebf27 100644 --- a/test/inductor/test_binary_folding.py +++ b/test/inductor/test_binary_folding.py @@ -56,7 +56,7 @@ def __init__(self, in_channels, out_channels, device, **kwargs): self.use_scalar = scalar tensor_size = [1 for _ in range(self.conv.weight.ndim)] tensor_size[1] = self.conv.weight.size(0) - self.tensor = ( + self.tensor = torch.nn.Parameter( add_tensor if add_tensor is not None else torch.rand(tensor_size).to(device) @@ -136,7 +136,11 @@ def my_inner_compile(gm, example_inputs, *args, **kwargs): nn.Conv2d, pytorch_op, False, - add_tensor=torch.rand(32, 1, 32).to(self.device), + add_tensor=torch.rand( + 32, + 1, + 32, + ).to(self.device), expect_success=False, ) @@ -156,7 +160,7 @@ def my_inner_compile(gm, example_inputs, *args, **kwargs): nn.Conv2d, pytorch_op, False, - add_tensor=torch.tensor([2]).to(torch.int).to(self.device), + add_tensor=torch.tensor([2]).to(torch.float64).to(self.device), expect_success=False, ) From f2d7f235a684c593f5a1ff2ca0b47b47274bfe85 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Mon, 10 Jun 2024 16:38:15 -0700 Subject: [PATCH 625/706] [dynamo][yolov3] Track UnspecializedNNModuleVariable for mutation (#128269) Fixes https://github.com/pytorch/pytorch/issues/101168 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128269 Approved by: https://github.com/jansel ghstack dependencies: #128295, #126578, #128268, #128254 --- .../aot_eager_torchbench_inference.csv | 2 +- .../aot_eager_torchbench_training.csv | 2 +- .../aot_inductor_torchbench_inference.csv | 2 +- .../cpu_inductor_torchbench_freezing_inference.csv | 2 +- .../cpu_inductor_torchbench_inference.csv | 2 +- .../cu124/aot_inductor_torchbench_inference.csv | 2 +- .../dynamic_aot_eager_torchbench_inference.csv | 2 +- .../dynamic_aot_eager_torchbench_training.csv | 2 +- .../dynamic_cpu_inductor_torchbench_inference.csv | 2 +- .../dynamic_inductor_torchbench_inference.csv | 2 +- .../dynamic_inductor_torchbench_training.csv | 2 +- .../dynamo_eager_torchbench_inference.csv | 2 +- .../dynamo_eager_torchbench_training.csv | 2 +- .../inductor_torchbench_inference.csv | 2 +- .../ci_expected_accuracy/inductor_torchbench_training.csv | 2 +- torch/_dynamo/output_graph.py | 8 +++++++- 16 files changed, 22 insertions(+), 16 deletions(-) diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv index 20fb340690ac..68331f317995 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv @@ -378,4 +378,4 @@ vision_maskrcnn,pass,17 -yolov3,pass,2 +yolov3,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv index 5131c2e9ade4..20a5e024ece5 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv @@ -286,4 +286,4 @@ vision_maskrcnn,pass,34 -yolov3,pass,9 +yolov3,fail_accuracy,8 diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_torchbench_inference.csv index 65c905837c2a..1624d6dc7973 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_torchbench_inference.csv @@ -350,4 +350,4 @@ vision_maskrcnn,fail_to_run,0 -yolov3,fail_to_run,0 +yolov3,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_freezing_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_freezing_inference.csv index 3942e3a2f343..3af215541c1d 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_freezing_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_freezing_inference.csv @@ -338,4 +338,4 @@ vision_maskrcnn,pass,28 -yolov3,pass,2 +yolov3,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv index fcd87f4d2454..a497fb45d7d4 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv @@ -338,4 +338,4 @@ vision_maskrcnn,pass,28 -yolov3,pass,2 +yolov3,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_inductor_torchbench_inference.csv index 65c905837c2a..1624d6dc7973 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_inductor_torchbench_inference.csv @@ -350,4 +350,4 @@ vision_maskrcnn,fail_to_run,0 -yolov3,fail_to_run,0 +yolov3,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv index bcdf06917b64..fd84df653db1 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv @@ -374,4 +374,4 @@ vision_maskrcnn,pass,17 -yolov3,pass,2 +yolov3,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv index 1e1a4be4149e..c010e129c19b 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv @@ -282,4 +282,4 @@ vision_maskrcnn,pass,34 -yolov3,pass,9 +yolov3,fail_accuracy,8 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv index ce271939b18c..5ffc870a8dec 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv @@ -298,4 +298,4 @@ vision_maskrcnn,pass,28 -yolov3,pass,2 +yolov3,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv index 3f60be5afd97..f0417110484e 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv @@ -374,4 +374,4 @@ vision_maskrcnn,pass,17 -yolov3,pass,2 +yolov3,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv index ee58808c0bb0..82c4c1da2317 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv @@ -282,4 +282,4 @@ vision_maskrcnn,pass,34 -yolov3,pass,9 +yolov3,pass,8 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv index 20fb340690ac..68331f317995 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv @@ -378,4 +378,4 @@ vision_maskrcnn,pass,17 -yolov3,pass,2 +yolov3,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv index cfc524426644..30808bc6bcd4 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv @@ -286,4 +286,4 @@ vision_maskrcnn,pass,34 -yolov3,pass,9 +yolov3,pass,8 diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv index 108bc6543aa9..b4700da57b25 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv @@ -378,4 +378,4 @@ vision_maskrcnn,pass,17 -yolov3,pass,2 +yolov3,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv index cfc524426644..30808bc6bcd4 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv @@ -286,4 +286,4 @@ vision_maskrcnn,pass,34 -yolov3,pass,9 +yolov3,pass,8 diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index ee43db8524a7..946bc52d7182 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -752,7 +752,13 @@ def register_attr_or_module( **options, ): if is_dynamic_nn_module(target, self.root_tx.export): - return variables.UnspecializedNNModuleVariable(target, **options) + result = variables.UnspecializedNNModuleVariable(target, **options) + if not SideEffects.cls_supports_mutation_side_effects(type(target)): + # don't allow STORE_ATTR mutation with custom __setattr__ + return result + return self.root_tx.output.side_effects.track_object_existing( + target, result + ) options = dict(options) assert "source" in options From a206dcc79e048b169962e41d1b14f5fef946dd03 Mon Sep 17 00:00:00 2001 From: Colin L Reliability Rice Date: Tue, 11 Jun 2024 07:46:12 +0000 Subject: [PATCH 626/706] fb_memcache: Move to fbcode from thirdparty (#128174) Summary: The fb_memcache injections location and path is changing. Test Plan: Existing tests should pass. Reviewed By: bertmaher, oulgen Differential Revision: D57973772 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128174 Approved by: https://github.com/oulgen --- test/inductor/test_codecache.py | 2 +- test/inductor/test_max_autotune.py | 2 +- torch/_inductor/codecache.py | 2 +- torch/_inductor/compile_fx.py | 2 +- torch/_inductor/runtime/triton_heuristics.py | 10 +++++++--- 5 files changed, 11 insertions(+), 7 deletions(-) diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py index 21d70d90d290..3ef39adeed3d 100644 --- a/test/inductor/test_codecache.py +++ b/test/inductor/test_codecache.py @@ -195,7 +195,7 @@ def put(self, filename, data): num_put += 1 cache_module = ( - "triton.runtime.fb_memcache.FbMemcacheRemoteFxGraphCacheBackend" + "triton.fb.fb_memcache.FbMemcacheRemoteFxGraphCacheBackend" if config.is_fbcode() else "torch._inductor.remote_cache.RedisRemoteCacheBackend" ) diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index bd74ea58ad59..176f0dda606d 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -267,7 +267,7 @@ def put(self, filename, data): num_put += 1 cache_module = ( - "triton.runtime.fb_memcache.FbMemcacheRemoteAutotuneCacheBackend" + "triton.fb.fb_memcache.FbMemcacheRemoteAutotuneCacheBackend" if config.is_fbcode() else "torch._inductor.remote_cache.RedisRemoteCacheBackend" ) diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index ae8453660813..d151e3673474 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1021,7 +1021,7 @@ def load( cache_id = "fx-graph-v1" try: if config.is_fbcode(): - from triton.runtime.fb_memcache import ( + from triton.fb.fb_memcache import ( FbMemcacheRemoteFxGraphCacheBackend, ) diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index dc3b1b811a6b..d49ed38902cb 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -400,7 +400,7 @@ def should_use_remote_fx_graph_cache(): return False try: - from triton.runtime.fb_memcache import MEMCACHE_VERSION + from triton.fb.fb_memcache import MEMCACHE_VERSION except ModuleNotFoundError: return False diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 5e05368e0a11..5396ccf3e70d 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -1031,7 +1031,7 @@ def should_use_remote_autotune_cache(inductor_meta): if inductor_meta.get("is_hip"): return False - from triton.runtime.fb_memcache import MEMCACHE_VERSION + from triton.fb.fb_memcache import MEMCACHE_VERSION return MEMCACHE_VERSION >= torch._utils_internal.justknobs_getval_int( "pytorch/remote_cache:autotune_memcache_version" @@ -1075,8 +1075,12 @@ def cached_autotune( try: if inductor_meta.get("is_fbcode"): - remote_cache = triton.runtime.fb_memcache.FbMemcacheRemoteAutotuneCacheBackend( - key + import triton.fb.fb_memcache + + remote_cache = ( + triton.fb.fb_memcache.FbMemcacheRemoteAutotuneCacheBackend( + key + ) ) else: from torch._inductor.remote_cache import RedisRemoteCacheBackend From 207c2248a881d261ce42566cde7ca25b134fd382 Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Tue, 11 Jun 2024 00:16:15 +0100 Subject: [PATCH 627/706] [inductor] Fix lowering full with SymBool value (#128213) Fixes #128161, fixes #128095 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128213 Approved by: https://github.com/lezcano --- test/inductor/test_torchinductor.py | 8 ++++++++ ...st_torchinductor_codegen_dynamic_shapes.py | 1 + torch/_inductor/lowering.py | 16 +++++++-------- torch/_inductor/sizevars.py | 11 +++++++++- torch/_prims_common/__init__.py | 20 +++++++++++-------- torch/utils/_sympy/value_ranges.py | 7 ++++++- 6 files changed, 45 insertions(+), 18 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index a12d68c18c74..b33e01aebbdc 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -5537,6 +5537,14 @@ def fn(a): for dtype in all_types(): self.common(fn, (make_tensor(8, dtype=dtype, device=self.device),)) + def test_full_boolean(self): + def fn(n): + x = torch.full((1,), n >= 1024, device=self.device) + return x, x + 1 + + self.common(fn, (1024,)) + self.common(fn, (1023,)) + def test_index1(self): def fn(a, b, c): return aten.index(a, [b, c]) diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py index 1f641389bfd0..bd036810d4c1 100644 --- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py @@ -121,6 +121,7 @@ def run(*ex, **kwargs): "test_conv2d_channels_last_dynamic_shapes": TestFailure(("cpu",)), "test_conv3d_channels_last_dynamic_shapes": TestFailure(("cpu",)), "test_expand_dynamic_shapes": TestFailure(("cpu",)), + "test_full_boolean_dynamic_shapes": TestFailure(("cpu",)), "test_glu_dynamic_shapes": TestFailure(("cpu",)), "test_isinf2_dynamic_shapes": TestFailure(("cpu",)), "test_linspace1_dynamic_shapes": TestFailure(("cpu",)), diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index bff5fb4073af..e3457a27aa94 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -177,7 +177,7 @@ def is_boolean_type(x): def get_promoted_dtype(*args, type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND): def construct_input(inp): - if isinstance(inp, (Number, sympy.Expr)): + if isinstance(inp, (Number, sympy.Basic)): return inp else: assert hasattr(inp, "get_dtype") @@ -216,7 +216,7 @@ def transform_args(args, broadcast, type_promotion_kind, convert_input_to_bool): promoting_args = [ a for a in args - if isinstance(a, (Number, sympy.Expr)) + if isinstance(a, (Number, sympy.Basic)) or getattr(a, "dtype", None) is not None ] dtype = get_promoted_dtype( @@ -368,15 +368,15 @@ def promote_constants(inputs, override_return_dtype=None, type_promotion_kind=No if override_return_dtype is None and type_promotion_kind is None: type_promotion_kind = ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT - if not any(isinstance(x, (sympy.Expr, int, float)) for x in inputs): + if not any(isinstance(x, (sympy.Basic, int, float)) for x in inputs): return inputs - if all(isinstance(x, (int, float, sympy.Expr)) for x in inputs): + if all(isinstance(x, (int, float, sympy.Basic)) for x in inputs): dtype = override_return_dtype or get_promoted_dtype( *inputs, type_promotion_kind=type_promotion_kind ) def const_func(x): - if isinstance(x, sympy.Expr): + if isinstance(x, sympy.Basic): return ir.IndexingConstant(x, dtype, decode_device(None)) else: return ir.Constant(x, dtype, decode_device(None)) @@ -391,7 +391,7 @@ def const_func(x): ir.Constant(x, ex.get_dtype(), ex.get_device()), list(ex.get_size()) ) ) - elif isinstance(x, sympy.Expr): + elif isinstance(x, sympy.Basic): out.append( ExpandView.create( IndexingConstant(x, ex.get_dtype(), ex.get_device()), @@ -2470,7 +2470,7 @@ def tensor(data, *, dtype=None, device=None, layout=None, pin_memory=False): ranges: List[sympy.Expr] = [] - if isinstance(data, sympy.Expr): + if isinstance(data, sympy.Basic): def inner_fn(index): return ops.index_expr(data, dtype) @@ -2596,7 +2596,7 @@ def _full(fill_value, device, dtype, size): def inner_fn(index): return ops.constant(value, dtype) - elif isinstance(value, sympy.Expr): + elif isinstance(value, sympy.Basic): def inner_fn(index): return ops.index_expr(value, dtype) diff --git a/torch/_inductor/sizevars.py b/torch/_inductor/sizevars.py index 8ea55b024c59..f48c0884d3ad 100644 --- a/torch/_inductor/sizevars.py +++ b/torch/_inductor/sizevars.py @@ -193,7 +193,16 @@ def _simplify_loops_impl( """ sizes = list(map(self.simplify, sizes)) - strides = [self.stride_vars(x, index_vars) for x in index_formulas] + strides = [ + # index_formulas may contain boolean expressions (e.g. s0 < 10), + # for which "strides" don't make sense so we ignore them here. + # NOTE: These expressions may still block merging dims in the sound + # substitution test performed in can_merge_dims. + self.stride_vars(x, index_vars) + if isinstance(x, sympy.Expr) + else [0] * len(index_vars) + for x in index_formulas + ] assert len(sizes) == len(strides[0]), (len(sizes), len(strides[0])) for i in range(len(sizes)): diff --git a/torch/_prims_common/__init__.py b/torch/_prims_common/__init__.py index 11b97403f308..c05b0ebf10e7 100644 --- a/torch/_prims_common/__init__.py +++ b/torch/_prims_common/__init__.py @@ -1046,17 +1046,17 @@ def type_to_dtype(typ: type) -> torch.dtype: assert isinstance(typ, type) - if typ is bool: + if typ in (bool, torch.SymBool): return torch.bool - if typ in [int, torch.SymInt]: + if typ in (int, torch.SymInt): return torch.long - if typ in [float, torch.SymFloat]: + if typ in (float, torch.SymFloat): return torch.get_default_dtype() # TODO: sym_complex_float? if typ is complex: return corresponding_complex_dtype(torch.get_default_dtype()) - raise ValueError("Invalid type!") + raise ValueError(f"Invalid type {typ}!") def get_dtype(x: Union[torch.Tensor, NumberType]): @@ -1363,8 +1363,12 @@ def number_type( return type(x) -def expr_type(x: sympy.Expr) -> Type: - if x.is_integer: # type: ignore[attr-defined] +def expr_type(x: sympy.Basic) -> Type: + import sympy + + if x.kind is sympy.core.kind.BooleanKind: + return bool + elif x.is_integer: # type: ignore[attr-defined] return int else: # NB: Not strictly correct, but we don't support SymPy complex or bool. @@ -1471,13 +1475,13 @@ def elementwise_dtypes( import sympy for x in args: - if not isinstance(x, (Number, TensorLike, sympy.Expr)): + if not isinstance(x, (Number, TensorLike, sympy.Basic)): msg = f"Unexpected type {str(type(x))} when computing elementwise type promotion!" raise ValueError(msg) if isinstance(x, Number): highest_type = get_higher_type(highest_type, number_type(x)) - elif isinstance(x, sympy.Expr): + elif isinstance(x, sympy.Basic): highest_type = get_higher_type(highest_type, expr_type(x)) else: # x is a TensorLike diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py index d16da832459a..c1ed0b02946d 100644 --- a/torch/utils/_sympy/value_ranges.py +++ b/torch/utils/_sympy/value_ranges.py @@ -89,7 +89,10 @@ def sympy_generic_le(lower, upper): return lower <= upper else: # only negative condition is True > False - assert isinstance(lower, SympyBoolean) and isinstance(upper, SympyBoolean) + assert isinstance(lower, SympyBoolean) and isinstance(upper, SympyBoolean), ( + lower, + upper, + ) return not (lower and not upper) @@ -945,6 +948,8 @@ def to_dtype(x, dtype: torch.dtype, src_dtype: Optional[torch.dtype] = None): if dtype == torch.bool: if x.is_singleton(): return ValueRanges.wrap(x.lower != 0) + elif x.is_bool: + return x elif 0 not in x: return ValueRanges.wrap(sympy.true) else: From 648625b230e8e6e7478fb219ff4f0aa6a45070f5 Mon Sep 17 00:00:00 2001 From: FFFrog Date: Tue, 11 Jun 2024 11:48:57 +0800 Subject: [PATCH 628/706] Make TraceUtils.h to be device-agnostic (#126969) Some features of third-party devices depend on TraceUtils.h, so some of the CUDA code was removed and split into NCCLUtils files. In addition, some common functions still remain in TraceUtils.h since I'm not sure if other devices will use them later. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126969 Approved by: https://github.com/c-p-i-o --- torch/csrc/distributed/c10d/NCCLUtils.cpp | 50 +++ torch/csrc/distributed/c10d/NCCLUtils.hpp | 436 ++++++++++++++++++- torch/csrc/distributed/c10d/TraceUtils.h | 497 ---------------------- 3 files changed, 485 insertions(+), 498 deletions(-) diff --git a/torch/csrc/distributed/c10d/NCCLUtils.cpp b/torch/csrc/distributed/c10d/NCCLUtils.cpp index bc820fc1c8d5..db268371ea0f 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.cpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.cpp @@ -238,6 +238,56 @@ std::string getNcclErrorDetailStr( return interpret + err; } +void DebugInfoWriter::write(const std::string& ncclTrace) { + // Open a file for writing. The ios::binary flag is used to write data as + // binary. + std::ofstream file(filename_, std::ios::binary); + + // Check if the file was opened successfully. + if (!file.is_open()) { + LOG(ERROR) << "Error opening file for writing NCCLPG debug info: " + << filename_; + return; + } + + file.write(ncclTrace.data(), ncclTrace.size()); + LOG(INFO) << "Finished writing NCCLPG debug info to " << filename_; +} + +DebugInfoWriter& DebugInfoWriter::getWriter(int rank) { + if (writer_ == nullptr) { + std::string fileNamePrefix = getCvarString( + {"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/nccl_trace_rank_"); + // Using std::unique_ptr here to auto-delete the writer object + // when the pointer itself is destroyed. + std::unique_ptr writerPtr( + new DebugInfoWriter(fileNamePrefix, rank)); + DebugInfoWriter::registerWriter(std::move(writerPtr)); + } + return *writer_; +} + +void DebugInfoWriter::registerWriter(std::unique_ptr writer) { + TORCH_CHECK_WITH( + DistBackendError, + hasWriterRegistered_.load() == false, + "debugInfoWriter already registered"); + hasWriterRegistered_.store(true); + writer_ = std::move(writer); +} + +std::unique_ptr DebugInfoWriter::writer_ = nullptr; +std::atomic DebugInfoWriter::hasWriterRegistered_(false); + +float getDurationFromEvent( + at::cuda::CUDAEvent& ncclStartEvent, + at::cuda::CUDAEvent& ncclEndEvent) { + TORCH_CHECK( + ncclEndEvent.query(), + "getDuration can only be called after work is succeeded.") + return ncclStartEvent.elapsed_time(ncclEndEvent); +} + } // namespace c10d #endif // USE_C10D_NCCL diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp index 9ce25b55dc13..4aa4b15b2917 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.hpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp @@ -10,9 +10,11 @@ #include #include +#include #include #include #include +#include #if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \ (NCCL_MINOR >= 14) @@ -172,6 +174,39 @@ namespace c10d { +static c10::IValue entries_key = "entries"; +static c10::IValue nccl_comm_key = "nccl_comm_state"; +static c10::IValue version_key = "version"; +// Update whenever changing contents or formatting of the dump +// (minor when adding fields, major when changing existing fields) +static c10::IValue version_val = "2.2"; +static c10::IValue pg_config_key = "pg_config"; +static c10::IValue record_id_key = "record_id"; +static c10::IValue pg_id_key = "pg_id"; +static c10::IValue pg_name_key = "process_group"; +static c10::IValue collective_seq_id_key = "collective_seq_id"; +static c10::IValue p2p_seq_id_key = "p2p_seq_id"; +static c10::IValue is_p2p_key = "is_p2p"; +static c10::IValue op_id_key = "op_id"; +static c10::IValue profiling_name_key = "profiling_name"; +static c10::IValue input_sizes_key = "input_sizes"; +static c10::IValue input_dtypes_key = "input_dtypes"; +static c10::IValue output_sizes_key = "output_sizes"; +static c10::IValue output_dtypes_key = "output_dtypes"; +static c10::IValue time_created_key = "time_created_ns"; +static c10::IValue duration_key = "duration_ms"; +static c10::IValue timeout_key = "timeout_ms"; + +static c10::IValue frames_key = "frames"; +static c10::IValue state_key = "state"; +static c10::IValue line_key = "line"; +static c10::IValue name_key = "name"; +static c10::IValue filename_key = "filename"; +static c10::IValue retired_key = "retired"; +static c10::IValue time_discovered_started_key = "time_discovered_started_ns"; +static c10::IValue time_discovered_completed_key = + "time_discovered_completed_ns"; + TORCH_API size_t hashTensors(const std::vector& tensors); TORCH_API std::string getNcclVersion(); TORCH_API std::string ncclGetErrorWithVersion(ncclResult_t error); @@ -195,7 +230,7 @@ TORCH_API std::string getNcclErrorDetailStr( // auto-registered). class TORCH_API DebugInfoWriter { public: - virtual ~DebugInfoWriter(); + virtual ~DebugInfoWriter() = default; virtual void write(const std::string& ncclTrace); static DebugInfoWriter& getWriter(int rank); static void registerWriter(std::unique_ptr writer); @@ -518,6 +553,405 @@ struct ncclRedOpRAII { bool premul_sum_ = false; }; +/* Helper used by work::getDuration() and nccl flight recorder */ +float getDurationFromEvent( + at::cuda::CUDAEvent& ncclStartEvent, + at::cuda::CUDAEvent& ncclEndEvent); + +struct NCCLTraceBuffer { + static NCCLTraceBuffer* get() { + // intentionally leak on exit + // because this will hold python state that may get destructed + static NCCLTraceBuffer* instance = new NCCLTraceBuffer(); + return instance; + } + NCCLTraceBuffer() { + max_entries_ = getCvarInt({"TORCH_NCCL_TRACE_BUFFER_SIZE"}, 0); + capture_cpp_stack_ = getCvarBool({"TORCH_NCCL_TRACE_CPP_STACK"}, false); + enabled_ = max_entries_ > 0; + } + using Event = at::cuda::CUDAEvent; + struct Entry { + size_t id_; // incremented id in the trace buffer + // used to figure out where in the circular entries + // buffer this entry will be located to + // update state information + size_t pg_id_; + std::tuple pg_name_; // + + // collective_seq_id and p2p_seq_id refer to actual kernel launches (e.g. 1 + // per coalesced group). + // collective_seq_id only increments for true collective operations (over + // all ranks in the group). p2p_seq_id only increments over non-collective + // operations in the group. op_id refers to logical operations (e.g. one per + // op inside coalesced group) + size_t collective_seq_id_; + size_t p2p_seq_id_; + size_t op_id_; + std::string profiling_name_; + + std::shared_ptr traceback_; + // we borrow pointers to start_ and end_ so we can query the state + // on reporting. However, once the event is completed, the call + // to `complete` will clear these. + Event *start_, *end_; + + // timestamp when the entry was created, likely close to the time the work + // was 'enqueued'- not necessarily started + c10::time_t time_created_; + + // configured timeout for this entry + c10::time_t timeout_ms_; + + // Is this a P2P event? + bool isP2P_; + + std::optional duration_; + + // timestamp when our CPU threads discovered that the kernel started. + // will always be _after_ it actually started, and can be very late + // if the watchdog thread got stuck on CUDA APIs. + std::optional time_discovered_started_; + + // timestamp when our CPU threads discovered that the kernel completed. + // will always be _after_ it actually complated, and can be the same time + // as the discovery of the start if the watchdog thread is stuck on CUDA + // APIs + std::optional time_discovered_completed_; + + // size information for input/output tensors + c10::SmallVector input_dims_; + std::vector input_dtypes_; + c10::SmallVector output_dims_; + std::vector output_dtypes_; + c10::SmallVector sizes_; // flattened from inputs, outputs + bool retired_ = false; // is this work entry no longer in the workMetaList_? + // a retired but not completed event has timed out + }; + + bool enabled_ = false; + bool capture_cpp_stack_ = false; + std::mutex mutex_; + std::vector entries_; + size_t max_entries_ = 0; + size_t next_ = 0; + size_t id_ = 0; + std::map, std::vector> + pg_name_to_ranks_ = {}; + + std::optional record( + size_t pg_id, + const std::tuple& pg_name, + size_t collective_seq_id, + size_t p2p_seq_id, + size_t op_id, + std::string profiling_name, + const std::vector& inputs, + const std::vector& outputs, + Event* start, + Event* end, + std::chrono::milliseconds timeout_ms, + bool isP2P) { + if (!enabled_) { + return c10::nullopt; + } + auto traceback = + torch::CapturedTraceback::gather(true, true, capture_cpp_stack_); + std::lock_guard guard(mutex_); + + auto te = Entry{ + id_, + pg_id, + pg_name, + collective_seq_id, + p2p_seq_id, + op_id, + std::move(profiling_name), + std::move(traceback), + std::move(start), + std::move(end), + c10::getTime(), + timeout_ms.count(), + isP2P}; + + for (const auto& input : inputs) { + c10::IntArrayRef sizes = input.sizes(); + te.input_dtypes_.push_back(input.dtype().toScalarType()); + te.input_dims_.push_back(sizes.size()); + te.sizes_.insert(te.sizes_.end(), sizes.begin(), sizes.end()); + } + + for (const auto& output : outputs) { + c10::IntArrayRef sizes = output.sizes(); + te.output_dtypes_.push_back(output.dtype().toScalarType()); + te.output_dims_.push_back(sizes.size()); + te.sizes_.insert(te.sizes_.end(), sizes.begin(), sizes.end()); + } + + if (entries_.size() < max_entries_) { + entries_.emplace_back(std::move(te)); + } else { + entries_[next_++] = std::move(te); + if (next_ == max_entries_) { + next_ = 0; + } + } + return id_++; + } + + void record_pg_ranks( + const std::tuple& pg_name, + std::vector ranks) { + if (!enabled_) { + return; + } + std::lock_guard guard(mutex_); + pg_name_to_ranks_[pg_name] = ranks; + } + + void update_state(Entry& r) { + if (r.start_ != nullptr) { + bool started = r.start_->query(); + if (started && !r.time_discovered_started_) { + r.time_discovered_started_ = c10::getTime(); + } + } + if (r.end_ != nullptr) { + bool completed = r.end_->query(); + if (completed && !r.time_discovered_completed_) { + r.time_discovered_completed_ = c10::getTime(); + } + } + } + + std::vector dump_entries() { + std::lock_guard guard(mutex_); + std::vector result; + result.reserve(entries_.size()); + result.insert(result.end(), entries_.begin() + next_, entries_.end()); + result.insert(result.end(), entries_.begin(), entries_.begin() + next_); + // query any remaining events + for (auto& r : result) { + update_state(r); + r.start_ = r.end_ = nullptr; + } + return result; + } + + /* + Mark an Event as completed and free its events. + This is called by the watchdog thread, and is asynchronous from the + perspective of the main thread. + compute_duration defaults to true since retire_id is only called in the + watchdog thread, which is currently a place we call cuda APIs which may hang, + but care should be taken to avoid computing duration in any function that must + never hang. (timing must also be enabled for compute_duration - see + TORCH_NCCL_ENABLE_TIMING). + */ + void retire_id(std::optional id, bool compute_duration = true) { + if (!enabled_ || !id) { + return; + } + + bool can_compute_duration = false; + Event* startEvent = nullptr; + Event* endEvent = nullptr; + std::optional duration = c10::nullopt; + + std::unique_lock guard(mutex_); + + Entry* entry = &entries_.at(*id % max_entries_); + if (entry->id_ == *id) { + update_state(*entry); + + if (compute_duration) { + can_compute_duration = entry->time_discovered_completed_.has_value() && + entry->start_ && entry->end_; + startEvent = entry->start_; + endEvent = entry->end_; + } + } + + if (can_compute_duration) { + // Compute duration without without holding the lock, because + // cudaEventDuration() can hang, and we need to acquire the lock before we + // can dump(), which we never want to block. + guard.unlock(); + duration = getDurationFromEvent(*startEvent, *endEvent); + guard.lock(); + + // Refresh the entry pointer, see if the entry has been overwritten + entry = &entries_.at(*id % max_entries_); + if (entry->id_ != *id) { + LOG(INFO) + << "retire_id abandoned for id " << *id + << ", event was overwritten while waiting to compute duration."; + return; + } + if (duration.has_value()) { + entry->duration_ = duration.value(); + } + } + + entry->retired_ = true; + entry->start_ = entry->end_ = nullptr; + } + + const c10::List getCollectiveTrace( + bool includeStacktraces, + bool onlyActive) { + auto entries = new_list(); + auto result = dump_entries(); + std::vector tracebacks; + torch::SymbolizedTracebacks stracebacks; + std::vector all_frames; + if (includeStacktraces) { + for (auto& e : result) { + tracebacks.push_back(e.traceback_.get()); + } + stracebacks = torch::symbolize(tracebacks); + for (const auto& f : stracebacks.all_frames) { + auto d = new_dict(); + d.insert(name_key, f.funcname); + d.insert(filename_key, f.filename); + d.insert(line_key, int64_t(f.lineno)); + all_frames.emplace_back(std::move(d)); + } + } + for (auto i : c10::irange(result.size())) { + auto dict = new_dict(); + auto& e = result.at(i); + // Skip completed events + if (onlyActive && e.time_discovered_completed_.has_value()) { + continue; + } + + if (includeStacktraces) { + auto& tb = stracebacks.tracebacks.at(i); + auto frames = new_list(); + for (int64_t frame : tb) { + frames.push_back(all_frames.at(frame)); + } + dict.insert(frames_key, frames); + } + + dict.insert(record_id_key, int64_t(e.id_)); + dict.insert(pg_id_key, int64_t(e.pg_id_)); + dict.insert(pg_name_key, e.pg_name_); + dict.insert(collective_seq_id_key, int64_t(e.collective_seq_id_)); + dict.insert(p2p_seq_id_key, int64_t(e.p2p_seq_id_)); + dict.insert(op_id_key, int64_t(e.op_id_)); + dict.insert(profiling_name_key, e.profiling_name_); + dict.insert(time_created_key, int64_t(e.time_created_)); + if (e.duration_) { + dict.insert(duration_key, *e.duration_); + } + + auto it = e.sizes_.begin(); + auto read_sizes = [&](const c10::SmallVector& dims) { + auto sizes = new_list(); + for (auto dim : dims) { + auto arg_sizes = new_list(); + for (auto i : c10::irange(dim)) { + (void)i; + arg_sizes.push_back(*it++); + } + sizes.push_back(arg_sizes); + } + return sizes; + }; + + dict.insert(input_sizes_key, read_sizes(e.input_dims_)); + std::vector input_dtypes_strs; + input_dtypes_strs.reserve(e.input_dtypes_.size()); + for (const auto& input_dtype : e.input_dtypes_) { + input_dtypes_strs.push_back(c10::toString(input_dtype)); + } + dict.insert(input_dtypes_key, input_dtypes_strs); + dict.insert(output_sizes_key, read_sizes(e.output_dims_)); + std::vector output_dtypes_strs; + output_dtypes_strs.reserve(e.output_dtypes_.size()); + for (const auto& output_dtype : e.output_dtypes_) { + output_dtypes_strs.push_back(c10::toString(output_dtype)); + } + dict.insert(output_dtypes_key, output_dtypes_strs); + if (e.time_discovered_completed_.has_value()) { + dict.insert(state_key, "completed"); + } else if (e.time_discovered_started_.has_value()) { + dict.insert(state_key, "started"); + } else { + dict.insert(state_key, "scheduled"); + } + + dict.insert( + time_discovered_started_key, + e.time_discovered_started_.has_value() + ? int64_t(*e.time_discovered_started_) + : c10::IValue()); + dict.insert( + time_discovered_completed_key, + e.time_discovered_completed_.has_value() + ? int64_t(*e.time_discovered_completed_) + : c10::IValue()); + dict.insert(retired_key, e.retired_); + dict.insert(timeout_key, e.timeout_ms_); + dict.insert(is_p2p_key, e.isP2P_); + + entries.push_back(dict); + } + return entries; + } + + // dump pg_entries + const c10::Dict getPgConfig() { + auto pg_config = new_dict(); + for (const auto& [pg_name, ranks] : pg_name_to_ranks_) { + auto pg_info = new_dict(); + pg_info.insert("name", std::get<0>(pg_name)); + pg_info.insert("desc", std::get<1>(pg_name)); + pg_info.insert("ranks", ranks_str(ranks)); + pg_config.insert(std::get<0>(pg_name), pg_info); + } + return pg_config; + } + + // dump all collectives + ncclDumpMap + std::string dump( + const std::optional>>& ncclDumpMap, + bool includeCollectives, + bool includeStackTraces, + bool onlyActive) { + auto result = new_dict(); + // common values + result.insert(version_key, version_val); + result.insert(pg_config_key, getPgConfig()); + + // collective trace + if (includeCollectives) { + result.insert( + entries_key, getCollectiveTrace(includeStackTraces, onlyActive)); + } + + // convert ncclDumpMap into a dictionary + auto per_comm_dict = new_dict(); + if (ncclDumpMap.has_value()) { + for (const auto& [ncclId, ncclDump] : ncclDumpMap.value()) { + auto inner_dict = new_dict(); + for (const auto& [key, value] : ncclDump) { + inner_dict.insert(key, value); + } + per_comm_dict.insert(ncclId, inner_dict); + } + } + if (per_comm_dict.size() > 0) { + result.insert(nccl_comm_key, per_comm_dict); + } + return pickle_str(result); + } +}; + } // namespace c10d #endif // USE_C10D_NCCL diff --git a/torch/csrc/distributed/c10d/TraceUtils.h b/torch/csrc/distributed/c10d/TraceUtils.h index de623d77fe9e..9c469dbd5bc6 100644 --- a/torch/csrc/distributed/c10d/TraceUtils.h +++ b/torch/csrc/distributed/c10d/TraceUtils.h @@ -10,11 +10,6 @@ #include #include -#ifdef USE_C10D_NCCL -#include -#include -#endif - #include #include #include @@ -24,41 +19,6 @@ namespace c10d { -static c10::IValue entries_key = "entries"; -static c10::IValue nccl_comm_key = "nccl_comm_state"; -static c10::IValue version_key = "version"; -// Update whenever changing contents or formatting of the dump -// (minor when adding fields, major when changing existing fields) -static c10::IValue version_val = "2.2"; -static c10::IValue pg_config_key = "pg_config"; -static c10::IValue record_id_key = "record_id"; -static c10::IValue pg_id_key = "pg_id"; -static c10::IValue pg_name_key = "process_group"; -static c10::IValue collective_seq_id_key = "collective_seq_id"; -static c10::IValue p2p_seq_id_key = "p2p_seq_id"; -static c10::IValue is_p2p_key = "is_p2p"; -static c10::IValue op_id_key = "op_id"; -static c10::IValue profiling_name_key = "profiling_name"; -static c10::IValue input_sizes_key = "input_sizes"; -static c10::IValue input_dtypes_key = "input_dtypes"; -static c10::IValue output_sizes_key = "output_sizes"; -static c10::IValue output_dtypes_key = "output_dtypes"; -static c10::IValue time_created_key = "time_created_ns"; -static c10::IValue duration_key = "duration_ms"; -static c10::IValue timeout_key = "timeout_ms"; - -static c10::IValue frames_key = "frames"; -static c10::IValue state_key = "state"; -static c10::IValue line_key = "line"; -static c10::IValue name_key = "name"; -static c10::IValue filename_key = "filename"; -static c10::IValue retired_key = "retired"; -static c10::IValue time_discovered_started_key = "time_discovered_started_ns"; -static c10::IValue time_discovered_completed_key = - "time_discovered_completed_ns"; - -/* Trace Utils Related to TORCH_NCCL_DESYNC_DEBUG */ - inline std::string getTraceStartKey(const std::string& pgName, int rank) { return pgName + "_" + std::to_string(rank) + "_trace_start"; } @@ -303,66 +263,6 @@ inline std::string retrieveDesyncReport( return report; } -/* Trace Utils Related to Flight Recorder */ - -/* Note: this is only used by PGNCCL (could be generalized in an ideal world but - * wasn't done that way, so isn't expected to be fully general at the moment) */ - -#ifdef USE_C10D_NCCL - -/* Helper used by work::getDuration() and nccl flight recorder */ -float getDurationFromEvent( - at::cuda::CUDAEvent& ncclStartEvent, - at::cuda::CUDAEvent& ncclEndEvent) { - TORCH_CHECK( - ncclEndEvent.query(), - "getDuration can only be called after work is succeeded.") - return ncclStartEvent.elapsed_time(ncclEndEvent); -} - -DebugInfoWriter::~DebugInfoWriter() = default; - -void DebugInfoWriter::write(const std::string& ncclTrace) { - // Open a file for writing. The ios::binary flag is used to write data as - // binary. - std::ofstream file(filename_, std::ios::binary); - - // Check if the file was opened successfully. - if (!file.is_open()) { - LOG(ERROR) << "Error opening file for writing NCCLPG debug info: " - << filename_; - return; - } - - file.write(ncclTrace.data(), ncclTrace.size()); - LOG(INFO) << "Finished writing NCCLPG debug info to " << filename_; -} - -DebugInfoWriter& DebugInfoWriter::getWriter(int rank) { - if (writer_ == nullptr) { - std::string fileNamePrefix = getCvarString( - {"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/nccl_trace_rank_"); - // Using std::unique_ptr here to auto-delete the writer object - // when the pointer itself is destroyed. - std::unique_ptr writerPtr( - new DebugInfoWriter(fileNamePrefix, rank)); - DebugInfoWriter::registerWriter(std::move(writerPtr)); - } - return *writer_; -} - -void DebugInfoWriter::registerWriter(std::unique_ptr writer) { - TORCH_CHECK_WITH( - DistBackendError, - hasWriterRegistered_.load() == false, - "debugInfoWriter already registered"); - hasWriterRegistered_.store(true); - writer_ = std::move(writer); -} - -std::unique_ptr DebugInfoWriter::writer_ = nullptr; -std::atomic DebugInfoWriter::hasWriterRegistered_(false); - inline std::string pickle_str(const c10::IValue& v) { std::vector result; { @@ -421,401 +321,4 @@ inline std::string ranks_str(const std::vector& ranks) { return c10::str("[", str, "]"); } -struct NCCLTraceBuffer { - static NCCLTraceBuffer* get() { - // intentionally leak on exit - // because this will hold python state that may get destructed - static NCCLTraceBuffer* instance = new NCCLTraceBuffer(); - return instance; - } - NCCLTraceBuffer() { - max_entries_ = getCvarInt({"TORCH_NCCL_TRACE_BUFFER_SIZE"}, 0); - capture_cpp_stack_ = getCvarBool({"TORCH_NCCL_TRACE_CPP_STACK"}, false); - enabled_ = max_entries_ > 0; - } - using Event = at::cuda::CUDAEvent; - struct Entry { - size_t id_; // incremented id in the trace buffer - // used to figure out where in the circular entries - // buffer this entry will be located to - // update state information - size_t pg_id_; - std::tuple pg_name_; // - - // collective_seq_id and p2p_seq_id refer to actual kernel launches (e.g. 1 - // per coalesced group). - // collective_seq_id only increments for true collective operations (over - // all ranks in the group). p2p_seq_id only increments over non-collective - // operations in the group. op_id refers to logical operations (e.g. one per - // op inside coalesced group) - size_t collective_seq_id_; - size_t p2p_seq_id_; - size_t op_id_; - std::string profiling_name_; - - std::shared_ptr traceback_; - // we borrow pointers to start_ and end_ so we can query the state - // on reporting. However, once the event is completed, the call - // to `complete` will clear these. - Event *start_, *end_; - - // timestamp when the entry was created, likely close to the time the work - // was 'enqueued'- not necessarily started - c10::time_t time_created_; - - // configured timeout for this entry - c10::time_t timeout_ms_; - - // Is this a P2P event? - bool isP2P_; - - std::optional duration_; - - // timestamp when our CPU threads discovered that the kernel started. - // will always be _after_ it actually started, and can be very late - // if the watchdog thread got stuck on CUDA APIs. - std::optional time_discovered_started_; - - // timestamp when our CPU threads discovered that the kernel completed. - // will always be _after_ it actually complated, and can be the same time - // as the discovery of the start if the watchdog thread is stuck on CUDA - // APIs - std::optional time_discovered_completed_; - - // size information for input/output tensors - c10::SmallVector input_dims_; - std::vector input_dtypes_; - c10::SmallVector output_dims_; - std::vector output_dtypes_; - c10::SmallVector sizes_; // flattened from inputs, outputs - bool retired_ = false; // is this work entry no longer in the workMetaList_? - // a retired but not completed event has timed out - }; - - bool enabled_ = false; - bool capture_cpp_stack_ = false; - std::mutex mutex_; - std::vector entries_; - size_t max_entries_ = 0; - size_t next_ = 0; - size_t id_ = 0; - std::map, std::vector> - pg_name_to_ranks_ = {}; - - std::optional record( - size_t pg_id, - const std::tuple& pg_name, - size_t collective_seq_id, - size_t p2p_seq_id, - size_t op_id, - std::string profiling_name, - const std::vector& inputs, - const std::vector& outputs, - Event* start, - Event* end, - std::chrono::milliseconds timeout_ms, - bool isP2P) { - if (!enabled_) { - return c10::nullopt; - } - auto traceback = - torch::CapturedTraceback::gather(true, true, capture_cpp_stack_); - std::lock_guard guard(mutex_); - - auto te = Entry{ - id_, - pg_id, - pg_name, - collective_seq_id, - p2p_seq_id, - op_id, - std::move(profiling_name), - std::move(traceback), - std::move(start), - std::move(end), - c10::getTime(), - timeout_ms.count(), - isP2P}; - - for (const auto& input : inputs) { - c10::IntArrayRef sizes = input.sizes(); - te.input_dtypes_.push_back(input.dtype().toScalarType()); - te.input_dims_.push_back(sizes.size()); - te.sizes_.insert(te.sizes_.end(), sizes.begin(), sizes.end()); - } - - for (const auto& output : outputs) { - c10::IntArrayRef sizes = output.sizes(); - te.output_dtypes_.push_back(output.dtype().toScalarType()); - te.output_dims_.push_back(sizes.size()); - te.sizes_.insert(te.sizes_.end(), sizes.begin(), sizes.end()); - } - - if (entries_.size() < max_entries_) { - entries_.emplace_back(std::move(te)); - } else { - entries_[next_++] = std::move(te); - if (next_ == max_entries_) { - next_ = 0; - } - } - return id_++; - } - - void record_pg_ranks( - const std::tuple& pg_name, - std::vector ranks) { - if (!enabled_) { - return; - } - std::lock_guard guard(mutex_); - pg_name_to_ranks_[pg_name] = ranks; - } - - void update_state(Entry& r) { - if (r.start_ != nullptr) { - bool started = r.start_->query(); - if (started && !r.time_discovered_started_) { - r.time_discovered_started_ = c10::getTime(); - } - } - if (r.end_ != nullptr) { - bool completed = r.end_->query(); - if (completed && !r.time_discovered_completed_) { - r.time_discovered_completed_ = c10::getTime(); - } - } - } - - std::vector dump_entries() { - std::lock_guard guard(mutex_); - std::vector result; - result.reserve(entries_.size()); - result.insert(result.end(), entries_.begin() + next_, entries_.end()); - result.insert(result.end(), entries_.begin(), entries_.begin() + next_); - // query any remaining events - for (auto& r : result) { - update_state(r); - r.start_ = r.end_ = nullptr; - } - return result; - } - - /* - Mark an Event as completed and free its events. - - This is called by the watchdog thread, and is asynchronous from the - perspective of the main thread. - - compute_duration defaults to true since retire_id is only called in the - watchdog thread, which is currently a place we call cuda APIs which may hang, - but care should be taken to avoid computing duration in any function that must - never hang. (timing must also be enabled for compute_duration - see - TORCH_NCCL_ENABLE_TIMING). - */ - void retire_id(std::optional id, bool compute_duration = true) { - if (!enabled_ || !id) { - return; - } - - bool can_compute_duration = false; - Event* startEvent = nullptr; - Event* endEvent = nullptr; - std::optional duration = c10::nullopt; - - std::unique_lock guard(mutex_); - - Entry* entry = &entries_.at(*id % max_entries_); - if (entry->id_ == *id) { - update_state(*entry); - - if (compute_duration) { - can_compute_duration = entry->time_discovered_completed_.has_value() && - entry->start_ && entry->end_; - startEvent = entry->start_; - endEvent = entry->end_; - } - } - - if (can_compute_duration) { - // Compute duration without without holding the lock, because - // cudaEventDuration() can hang, and we need to acquire the lock before we - // can dump(), which we never want to block. - guard.unlock(); - duration = getDurationFromEvent(*startEvent, *endEvent); - guard.lock(); - - // Refresh the entry pointer, see if the entry has been overwritten - entry = &entries_.at(*id % max_entries_); - if (entry->id_ != *id) { - LOG(INFO) - << "retire_id abandoned for id " << *id - << ", event was overwritten while waiting to compute duration."; - return; - } - if (duration.has_value()) { - entry->duration_ = duration.value(); - } - } - - entry->retired_ = true; - entry->start_ = entry->end_ = nullptr; - } - - const c10::List getCollectiveTrace( - bool includeStacktraces, - bool onlyActive) { - auto entries = new_list(); - auto result = dump_entries(); - std::vector tracebacks; - torch::SymbolizedTracebacks stracebacks; - std::vector all_frames; - if (includeStacktraces) { - for (auto& e : result) { - tracebacks.push_back(e.traceback_.get()); - } - stracebacks = torch::symbolize(tracebacks); - for (const auto& f : stracebacks.all_frames) { - auto d = new_dict(); - d.insert(name_key, f.funcname); - d.insert(filename_key, f.filename); - d.insert(line_key, int64_t(f.lineno)); - all_frames.emplace_back(std::move(d)); - } - } - for (auto i : c10::irange(result.size())) { - auto dict = new_dict(); - auto& e = result.at(i); - // Skip completed events - if (onlyActive && e.time_discovered_completed_.has_value()) { - continue; - } - - if (includeStacktraces) { - auto& tb = stracebacks.tracebacks.at(i); - auto frames = new_list(); - for (int64_t frame : tb) { - frames.push_back(all_frames.at(frame)); - } - dict.insert(frames_key, frames); - } - - dict.insert(record_id_key, int64_t(e.id_)); - dict.insert(pg_id_key, int64_t(e.pg_id_)); - dict.insert(pg_name_key, e.pg_name_); - dict.insert(collective_seq_id_key, int64_t(e.collective_seq_id_)); - dict.insert(p2p_seq_id_key, int64_t(e.p2p_seq_id_)); - dict.insert(op_id_key, int64_t(e.op_id_)); - dict.insert(profiling_name_key, e.profiling_name_); - dict.insert(time_created_key, int64_t(e.time_created_)); - if (e.duration_) { - dict.insert(duration_key, *e.duration_); - } - - auto it = e.sizes_.begin(); - auto read_sizes = [&](const c10::SmallVector& dims) { - auto sizes = new_list(); - for (auto dim : dims) { - auto arg_sizes = new_list(); - for (auto i : c10::irange(dim)) { - (void)i; - arg_sizes.push_back(*it++); - } - sizes.push_back(arg_sizes); - } - return sizes; - }; - - dict.insert(input_sizes_key, read_sizes(e.input_dims_)); - std::vector input_dtypes_strs; - input_dtypes_strs.reserve(e.input_dtypes_.size()); - for (const auto& input_dtype : e.input_dtypes_) { - input_dtypes_strs.push_back(c10::toString(input_dtype)); - } - dict.insert(input_dtypes_key, input_dtypes_strs); - dict.insert(output_sizes_key, read_sizes(e.output_dims_)); - std::vector output_dtypes_strs; - output_dtypes_strs.reserve(e.output_dtypes_.size()); - for (const auto& output_dtype : e.output_dtypes_) { - output_dtypes_strs.push_back(c10::toString(output_dtype)); - } - dict.insert(output_dtypes_key, output_dtypes_strs); - if (e.time_discovered_completed_.has_value()) { - dict.insert(state_key, "completed"); - } else if (e.time_discovered_started_.has_value()) { - dict.insert(state_key, "started"); - } else { - dict.insert(state_key, "scheduled"); - } - - dict.insert( - time_discovered_started_key, - e.time_discovered_started_.has_value() - ? int64_t(*e.time_discovered_started_) - : c10::IValue()); - dict.insert( - time_discovered_completed_key, - e.time_discovered_completed_.has_value() - ? int64_t(*e.time_discovered_completed_) - : c10::IValue()); - dict.insert(retired_key, e.retired_); - dict.insert(timeout_key, e.timeout_ms_); - dict.insert(is_p2p_key, e.isP2P_); - - entries.push_back(dict); - } - return entries; - } - - // dump pg_entries - const c10::Dict getPgConfig() { - auto pg_config = new_dict(); - for (const auto& [pg_name, ranks] : pg_name_to_ranks_) { - auto pg_info = new_dict(); - pg_info.insert("name", std::get<0>(pg_name)); - pg_info.insert("desc", std::get<1>(pg_name)); - pg_info.insert("ranks", ranks_str(ranks)); - pg_config.insert(std::get<0>(pg_name), pg_info); - } - return pg_config; - } - - // dump all collectives + ncclDumpMap - std::string dump( - const std::optional>>& ncclDumpMap, - bool includeCollectives, - bool includeStackTraces, - bool onlyActive) { - auto result = new_dict(); - // common values - result.insert(version_key, version_val); - result.insert(pg_config_key, getPgConfig()); - - // collective trace - if (includeCollectives) { - result.insert( - entries_key, getCollectiveTrace(includeStackTraces, onlyActive)); - } - - // convert ncclDumpMap into a dictionary - auto per_comm_dict = new_dict(); - if (ncclDumpMap.has_value()) { - for (const auto& [ncclId, ncclDump] : ncclDumpMap.value()) { - auto inner_dict = new_dict(); - for (const auto& [key, value] : ncclDump) { - inner_dict.insert(key, value); - } - per_comm_dict.insert(ncclId, inner_dict); - } - } - if (per_comm_dict.size() > 0) { - result.insert(nccl_comm_key, per_comm_dict); - } - return pickle_str(result); - } -}; - -#endif } // namespace c10d From fc77fdca6f7b1272e2a512eac022a334d38b26a8 Mon Sep 17 00:00:00 2001 From: IvanKobzarev Date: Mon, 10 Jun 2024 14:50:55 -0700 Subject: [PATCH 629/706] [guard_size_oblivious] Add gso ExpandUtils:_sym_to (#128224) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128224 Approved by: https://github.com/ezyang --- aten/src/ATen/ExpandUtils.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aten/src/ATen/ExpandUtils.h b/aten/src/ATen/ExpandUtils.h index 03cfca36e722..66973031c431 100644 --- a/aten/src/ATen/ExpandUtils.h +++ b/aten/src/ATen/ExpandUtils.h @@ -462,7 +462,7 @@ inline Tensor _sum_to( reduce_dims.push_back(i); } for (int64_t i = leading_dims; i < static_cast(sizes.size()); ++i) { - if (shape[i - leading_dims] == 1 && + if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(shape[i - leading_dims], 1)) && TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(sizes[i], 1))) { reduce_dims.push_back(i); } From 55901fb3da53e6c33834696d2a101c21b87145fa Mon Sep 17 00:00:00 2001 From: kareem mohiddeen shaik Date: Tue, 11 Jun 2024 14:04:52 +0000 Subject: [PATCH 630/706] [fx] Preserve Fx graph node order in partitioner across runs (#115621) Fixes #ISSUE_NUMBER partitioner generates different graph in recompilation on each run Co-authored-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/115621 Approved by: https://github.com/ezyang --- test/fx/test_partitioner_order.py | 53 ++++++++++++++++++++++++++++ torch/fx/passes/infra/partitioner.py | 13 +++---- 2 files changed, 60 insertions(+), 6 deletions(-) create mode 100644 test/fx/test_partitioner_order.py diff --git a/test/fx/test_partitioner_order.py b/test/fx/test_partitioner_order.py new file mode 100644 index 000000000000..ff6418238f8e --- /dev/null +++ b/test/fx/test_partitioner_order.py @@ -0,0 +1,53 @@ +# Owner(s): ["module: fx"] + +import unittest + +from typing import Mapping + +import torch +from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner +from torch.fx.passes.operator_support import OperatorSupport +from torch.testing._internal.common_utils import TestCase + + +class DummyDevOperatorSupport(OperatorSupport): + def is_node_supported( + self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node + ) -> bool: + return True + + +class DummyPartitioner(CapabilityBasedPartitioner): + def __init__(self, graph_module: torch.fx.GraphModule): + super().__init__( + graph_module, + DummyDevOperatorSupport(), + allows_single_node_partition=True, + ) + + +class AddModule(torch.nn.Module): + def forward(self, x): + y = torch.add(x, x) + z = torch.add(y, x) + return z + + +class TestPartitionerOrder(TestCase): + # partitoner test to check graph node order + def test_partitioner_order(self): + m = AddModule() + traced_m = torch.fx.symbolic_trace(m) + partions = DummyPartitioner(traced_m).propose_partitions() + partion_nodes = [list(partition.nodes) for partition in partions] + node_order = [n.name for n in partion_nodes[0]] + for _ in range(10): + traced_m = torch.fx.symbolic_trace(m) + new_partion = DummyPartitioner(traced_m).propose_partitions() + new_partion_nodes = [list(partition.nodes) for partition in new_partion] + new_node_order = [n.name for n in new_partion_nodes[0]] + self.assertTrue(node_order == new_node_order) + + +if __name__ == "__main__": + unittest.main() diff --git a/torch/fx/passes/infra/partitioner.py b/torch/fx/passes/infra/partitioner.py index 095be545eb54..58e4e9dd09e8 100644 --- a/torch/fx/passes/infra/partitioner.py +++ b/torch/fx/passes/infra/partitioner.py @@ -18,16 +18,16 @@ class Partition: def __init__(self, id: Optional[int] = None, nodes: Optional[Iterable[Node]] = None): self.id = id - self.nodes: Set[Node] = set(nodes) if nodes is not None else set() + self.nodes = {node: None for node in nodes} if nodes is not None else dict() def __repr__(self) -> str: return str(self.nodes) def add_node(self, node: Node): - self.nodes.add(node) + self.nodes.update({node: None}) def remove_node(self, node: Node): - self.nodes.remove(node) + del self.nodes[node] def size(self): return len(self.nodes) @@ -321,12 +321,13 @@ def is_transparent_output_node(node: Node, partition: Set[Node], removed_nodes: remove_node: Set[Node] = set() for node in partition.nodes: if is_non_compute_node(node) and \ - (is_transparent_input_node(node, partition.nodes, remove_node) or - is_transparent_output_node(node, partition.nodes, remove_node)): + (is_transparent_input_node(node, set(partition.nodes), remove_node) or + is_transparent_output_node(node, set(partition.nodes), remove_node)): remove_node.add(node) if len(remove_node) != 0: - partition.nodes = partition.nodes - remove_node + for node in remove_node: + partition.nodes.pop(node, None) def partition_and_fuse(self, prefix: str = "fused_") -> GraphModule: partitions = self.propose_partitions() From 9a38cae299e5ffd8143182bec878c28f96cfd72a Mon Sep 17 00:00:00 2001 From: Huamin Li Date: Tue, 11 Jun 2024 15:01:25 +0000 Subject: [PATCH 631/706] [AOTI] Switch to use shim v2 (#127674) Differential Revision: D56709309 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127674 Approved by: https://github.com/desertfire --- torch/_inductor/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 2ea60000d265..10374f577edf 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -43,7 +43,7 @@ def is_fbcode(): ) c_shim_version = os.environ.get( - "TORCHINDUCTOR_C_SHIM_VERSION", "1" if is_fbcode() else "2" + "TORCHINDUCTOR_C_SHIM_VERSION", "1" if (is_fbcode() and torch.version.hip) else "2" ) # dead code elimination From 053930e194211173c9e029d71914fd5974b60a8f Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Mon, 10 Jun 2024 17:25:13 -0700 Subject: [PATCH 632/706] [MPS][BE] Remove code duplication (#128373) Use `scalarToMetalTypeString` instead of `getMetalType` Pull Request resolved: https://github.com/pytorch/pytorch/pull/128373 Approved by: https://github.com/Skylion007 --- .../ATen/native/mps/operations/UnaryKernel.mm | 29 ++----------------- 1 file changed, 2 insertions(+), 27 deletions(-) diff --git a/aten/src/ATen/native/mps/operations/UnaryKernel.mm b/aten/src/ATen/native/mps/operations/UnaryKernel.mm index 540fc6a26cd8..5c894efb89fd 100644 --- a/aten/src/ATen/native/mps/operations/UnaryKernel.mm +++ b/aten/src/ATen/native/mps/operations/UnaryKernel.mm @@ -13,32 +13,6 @@ #include namespace at::native { -static const std::string& getMetalType(const c10::ScalarType& t) { - // Mapping from c10::ScalarType to integral type that can be used for unary ops - static std::unordered_map scalar_to_metal_type = { - {c10::ScalarType::Half, "half"}, - {c10::ScalarType::Float, "float"}, - {c10::ScalarType::Long, "long"}, - {c10::ScalarType::Int, "int"}, - {c10::ScalarType::Short, "short"}, - {c10::ScalarType::Bool, "bool"}, - {c10::ScalarType::Char, "int8_t"}, - {c10::ScalarType::Byte, "uint8_t"}, - }; - - auto it = scalar_to_metal_type.find(t); - TORCH_CHECK(it != scalar_to_metal_type.end(), "Unsupported type ", t); - return it->second; -} - -static const std::string& getMetalType(const c10::Scalar& s) { - return getMetalType(s.type()); -} - -static const std::string& getMetalType(const Tensor& t) { - return getMetalType(t.scalar_type()); -} - static mps::MetalShaderLibrary lib(UNARY_KERNEL_TEMPLATE, 2); TORCH_IMPL_FUNC(erfinv_out_mps)(const Tensor& self, const Tensor& output_) { @@ -57,7 +31,8 @@ } using namespace mps; @autoreleasepool { - auto cplState = lib.getPipelineStateForFunc("erfinv_mps_kernel", {getMetalType(outputTensor), getMetalType(self)}); + auto cplState = lib.getPipelineStateForFunc("erfinv_mps_kernel", + {scalarToMetalTypeString(outputTensor), scalarToMetalTypeString(self)}); if (!self.is_contiguous()) { inputTensor = inputTensor.contiguous(); From c13e03c87428b986972a48d8fc78dbffc2579f63 Mon Sep 17 00:00:00 2001 From: Aaron Orenstein Date: Mon, 10 Jun 2024 21:39:43 -0700 Subject: [PATCH 633/706] Flip default value for mypy disallow_untyped_defs [10+2/11] (#128374) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128374 Approved by: https://github.com/Skylion007 --- torch/_C/return_types.pyi.in | 1 + torch/distributed/_tensor/examples/display_sharding_example.py | 1 + torch/nn/functional.pyi.in | 1 + torch/utils/_sympy/numbers.py | 1 + torch/utils/data/datapipes/datapipe.py | 1 + torch/utils/data/datapipes/datapipe.pyi.in | 1 + 6 files changed, 6 insertions(+) diff --git a/torch/_C/return_types.pyi.in b/torch/_C/return_types.pyi.in index 458a076d7bfe..ce37323f7b33 100644 --- a/torch/_C/return_types.pyi.in +++ b/torch/_C/return_types.pyi.in @@ -1,4 +1,5 @@ # ${generated_comment} +# mypy: allow-untyped-defs from typing import ( Any, diff --git a/torch/distributed/_tensor/examples/display_sharding_example.py b/torch/distributed/_tensor/examples/display_sharding_example.py index 0e32ed074534..1ce3962b9545 100644 --- a/torch/distributed/_tensor/examples/display_sharding_example.py +++ b/torch/distributed/_tensor/examples/display_sharding_example.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import os from typing import Any, Dict diff --git a/torch/nn/functional.pyi.in b/torch/nn/functional.pyi.in index 5bb847a0a727..9dec24809e24 100644 --- a/torch/nn/functional.pyi.in +++ b/torch/nn/functional.pyi.in @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import ( Any, Callable, diff --git a/torch/utils/_sympy/numbers.py b/torch/utils/_sympy/numbers.py index 89dac14fddf3..6a93255df852 100644 --- a/torch/utils/_sympy/numbers.py +++ b/torch/utils/_sympy/numbers.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import mpmath.libmp as mlib # type: ignore[import-untyped] import sympy from sympy import Expr diff --git a/torch/utils/data/datapipes/datapipe.py b/torch/utils/data/datapipes/datapipe.py index 1c99fe79e406..8add81987837 100644 --- a/torch/utils/data/datapipes/datapipe.py +++ b/torch/utils/data/datapipes/datapipe.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import pickle from typing import Dict, Callable, Optional, TypeVar, Generic, Iterator diff --git a/torch/utils/data/datapipes/datapipe.pyi.in b/torch/utils/data/datapipes/datapipe.pyi.in index 6b3cbe34b46a..4d03665d5d66 100644 --- a/torch/utils/data/datapipes/datapipe.pyi.in +++ b/torch/utils/data/datapipes/datapipe.pyi.in @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # This base template ("datapipe.pyi.in") is generated from mypy stubgen with minimal editing for code injection # The output file will be "datapipe.pyi". This is executed as part of torch/CMakeLists.txt # Note that, for mypy, .pyi file takes precedent over .py file, such that we must define the interface for other From f8c45996d517d16845782de0af9e7530d6c4bb4d Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Mon, 10 Jun 2024 17:48:21 -0700 Subject: [PATCH 634/706] [MPS] Make erfinv compilable for bfloat16 (#128375) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128375 Approved by: https://github.com/Skylion007 ghstack dependencies: #128373 --- aten/src/ATen/native/mps/UnaryConstants.h | 25 +++++++++-------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/aten/src/ATen/native/mps/UnaryConstants.h b/aten/src/ATen/native/mps/UnaryConstants.h index 4adf1d0e333e..b1a92f688d12 100644 --- a/aten/src/ATen/native/mps/UnaryConstants.h +++ b/aten/src/ATen/native/mps/UnaryConstants.h @@ -18,26 +18,21 @@ kernel void erfinv_mps_kernel( device {0} *output [[buffer(0)]], /* coefficients in rational expansion */ float y_abs = abs(y); - if(y_abs > 1.0f){{ - output[index] = NAN; + if (y_abs >= 1.0f) {{ + output[index] = {0}( y_abs > 1.0f ? NAN : copysign(INFINITY, y)); return; }} - if(y_abs == 1.0f){{ - output[index] = copysign(INFINITY, y); - return; - }} - if(y_abs <= 0.7f) {{ + if (y_abs <= 0.7f) {{ z = y * y; - num = (((a[3]*z + a[2])*z + a[1])*z + a[0]); - dem = ((((b[3]*z + b[2])*z + b[1])*z +b[0]) * z + 1.0f); + num = ((a[3] * z + a[2]) * z + a[1])*z + a[0]; + dem = (((b[3] * z + b[2]) * z + b[1]) * z +b[0]) * z + 1.0f; x = y * num / dem; - }} - else{{ + }} else {{ z = sqrt(-1.0f*log((1.0-y_abs)/2.0)); - num = ((c[3]*z + c[2])*z + c[1]) * z + c[0]; - dem = (d[1]*z + d[0])*z + 1.0f; + num = ((c[3] * z + c[2]) * z + c[1]) * z + c[0]; + dem = (d[1] * z + d[0]) * z + 1.0f; x = copysign(num, y) / dem; }} - output[index] = x; -}})METAL"; \ No newline at end of file + output[index] = {0}(x); +}})METAL"; From 29081059b6534377698db4a1086d745c22f2f6eb Mon Sep 17 00:00:00 2001 From: David Berard Date: Mon, 10 Jun 2024 15:52:01 -0700 Subject: [PATCH 635/706] [Static Runtime] Fix & run gen_static_runtime_ops (#128299) gen_static_runtime_ops hasn't been updated in a while. In preparation for https://github.com/pytorch/pytorch/pull/127675 in which I need to re-run the codegen step for cumprod, I want to land these changes beforehand in case there are any other issues that arise. I added a number of ops to the blocklist: ``` + "_nested_tensor_storage_offsets", + "_nested_get_values", # no CPU backend + "_nested_get_values_copy", # no CPU backend + "_nested_view_from_jagged", # testing needs to be patched + "_nested_view_from_jagged_copy", # testing needs to be patched + "_nested_view_from_buffer", # testing needs to be patched + "_nested_view_from_buffer_copy", # testing needs to be patched + "_int_mm", # testing needs to be patched + "_to_sparse_csc", # testing needs to be patched + "_to_sparse_csr", # testing needs to be patched + "segment_reduce", # testing needs to be patched ``` Most of these are added just because testing doesn't work right now. Additionally, a few `fft` ops seem to have been removed from native_functions.yaml; I'm guessing it's unlikely FFT would have been used in many real models though. Differential Revision: [D58329403](https://our.internmc.facebook.com/intern/diff/D58329403/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128299 Approved by: https://github.com/YuqingJ --- .../static_runtime/test_generated_ops.cc | 282 ++++-------------- .../csrc/jit/runtime/static/generated_ops.cpp | 198 +++++------- torchgen/static_runtime/config.py | 2 +- torchgen/static_runtime/generator.py | 11 + 4 files changed, 148 insertions(+), 345 deletions(-) diff --git a/benchmarks/static_runtime/test_generated_ops.cc b/benchmarks/static_runtime/test_generated_ops.cc index 415bf464fbd1..bdf0585404ed 100644 --- a/benchmarks/static_runtime/test_generated_ops.cc +++ b/benchmarks/static_runtime/test_generated_ops.cc @@ -272,6 +272,38 @@ TEST(StaticRuntime, autogen_addr) { /*check_resize=*/true); } +TEST(StaticRuntime, autogen__test_functorch_fallback) { + const std::string script = R"IR( + graph(%self: Tensor, %other: Tensor): + %bias: None = prim::Constant() + %ret = aten::_test_functorch_fallback(%self, %other) + %cloned = aten::clone(%ret, %bias) + return (%cloned) + )IR"; + + auto self0 = at::rand({6, 6, 6}); + auto other0 = at::rand({6, 6, 6}); + std::vector args{self0, other0}; + testStaticRuntime( + script, + args, + {}, + /*use_allclose=*/false, + /*use_equalnan=*/false, + /*check_resize=*/true); + + auto self1 = at::rand({22, 22, 22}); + auto other1 = at::rand({22, 22, 22}); + std::vector args2{self1, other1}; + testStaticRuntime( + script, + args, + args2, + /*use_allclose=*/false, + /*use_equalnan=*/false, + /*check_resize=*/true); +} + TEST(StaticRuntime, autogen_argmax) { const std::string script = R"IR( graph(%self: Tensor, %dim: int?, %keepdim: bool): @@ -4440,6 +4472,40 @@ TEST(StaticRuntime, autogen_masked_select) { /*check_resize=*/true); } +TEST(StaticRuntime, autogen_nonzero_static) { + const std::string script = R"IR( + graph(%self: Tensor, %size: int, %fill_value: int): + %bias: None = prim::Constant() + %ret = aten::nonzero_static(%self, %size, %fill_value) + %cloned = aten::clone(%ret, %bias) + return (%cloned) + )IR"; + + auto self0 = at::rand({6, 6, 6}); + auto size0 = 1; + auto fill_value0 = 1; + std::vector args{self0, size0, fill_value0}; + testStaticRuntime( + script, + args, + {}, + /*use_allclose=*/false, + /*use_equalnan=*/false, + /*check_resize=*/true); + + auto self1 = at::rand({22, 22, 22}); + auto size1 = 1; + auto fill_value1 = 1; + std::vector args2{self1, size1, fill_value1}; + testStaticRuntime( + script, + args, + args2, + /*use_allclose=*/false, + /*use_equalnan=*/false, + /*check_resize=*/true); +} + TEST(StaticRuntime, autogen_gather) { const std::string script = R"IR( graph(%self: Tensor, %dim: int, %index: Tensor, %sparse_grad: bool): @@ -7106,222 +7172,6 @@ TEST(StaticRuntime, autogen_special_multigammaln) { /*check_resize=*/true); } -TEST(StaticRuntime, autogen_fft_fft) { - const std::string script = R"IR( - graph(%self: Tensor, %n: int?, %dim: int, %norm: str?): - %bias: None = prim::Constant() - %ret = aten::fft_fft(%self, %n, %dim, %norm) - %cloned = aten::clone(%ret, %bias) - return (%cloned) - )IR"; - - auto self0 = at::rand({6, 6, 6}); - auto n0 = 1; - auto dim0 = 1; - auto norm0 = "forward"; - std::vector args{self0, n0, dim0, norm0}; - testStaticRuntime( - script, - args, - {}, - /*use_allclose=*/false, - /*use_equalnan=*/false, - /*check_resize=*/true); - - auto self1 = at::rand({22, 22, 22}); - auto n1 = 1; - auto dim1 = 1; - auto norm1 = "forward"; - std::vector args2{self1, n1, dim1, norm1}; - testStaticRuntime( - script, - args, - args2, - /*use_allclose=*/false, - /*use_equalnan=*/false, - /*check_resize=*/true); -} - -TEST(StaticRuntime, autogen_fft_ifft) { - const std::string script = R"IR( - graph(%self: Tensor, %n: int?, %dim: int, %norm: str?): - %bias: None = prim::Constant() - %ret = aten::fft_ifft(%self, %n, %dim, %norm) - %cloned = aten::clone(%ret, %bias) - return (%cloned) - )IR"; - - auto self0 = at::rand({6, 6, 6}); - auto n0 = 1; - auto dim0 = 1; - auto norm0 = "forward"; - std::vector args{self0, n0, dim0, norm0}; - testStaticRuntime( - script, - args, - {}, - /*use_allclose=*/false, - /*use_equalnan=*/false, - /*check_resize=*/true); - - auto self1 = at::rand({22, 22, 22}); - auto n1 = 1; - auto dim1 = 1; - auto norm1 = "forward"; - std::vector args2{self1, n1, dim1, norm1}; - testStaticRuntime( - script, - args, - args2, - /*use_allclose=*/false, - /*use_equalnan=*/false, - /*check_resize=*/true); -} - -TEST(StaticRuntime, autogen_fft_rfft) { - const std::string script = R"IR( - graph(%self: Tensor, %n: int?, %dim: int, %norm: str?): - %bias: None = prim::Constant() - %ret = aten::fft_rfft(%self, %n, %dim, %norm) - %cloned = aten::clone(%ret, %bias) - return (%cloned) - )IR"; - - auto self0 = at::rand({6, 6, 6}); - auto n0 = 1; - auto dim0 = 1; - auto norm0 = "forward"; - std::vector args{self0, n0, dim0, norm0}; - testStaticRuntime( - script, - args, - {}, - /*use_allclose=*/false, - /*use_equalnan=*/false, - /*check_resize=*/true); - - auto self1 = at::rand({22, 22, 22}); - auto n1 = 1; - auto dim1 = 1; - auto norm1 = "forward"; - std::vector args2{self1, n1, dim1, norm1}; - testStaticRuntime( - script, - args, - args2, - /*use_allclose=*/false, - /*use_equalnan=*/false, - /*check_resize=*/true); -} - -TEST(StaticRuntime, autogen_fft_irfft) { - const std::string script = R"IR( - graph(%self: Tensor, %n: int?, %dim: int, %norm: str?): - %bias: None = prim::Constant() - %ret = aten::fft_irfft(%self, %n, %dim, %norm) - %cloned = aten::clone(%ret, %bias) - return (%cloned) - )IR"; - - auto self0 = at::rand({6, 6, 6}); - auto n0 = 1; - auto dim0 = 1; - auto norm0 = "forward"; - std::vector args{self0, n0, dim0, norm0}; - testStaticRuntime( - script, - args, - {}, - /*use_allclose=*/false, - /*use_equalnan=*/false, - /*check_resize=*/true); - - auto self1 = at::rand({22, 22, 22}); - auto n1 = 1; - auto dim1 = 1; - auto norm1 = "forward"; - std::vector args2{self1, n1, dim1, norm1}; - testStaticRuntime( - script, - args, - args2, - /*use_allclose=*/false, - /*use_equalnan=*/false, - /*check_resize=*/true); -} - -TEST(StaticRuntime, autogen_fft_hfft) { - const std::string script = R"IR( - graph(%self: Tensor, %n: int?, %dim: int, %norm: str?): - %bias: None = prim::Constant() - %ret = aten::fft_hfft(%self, %n, %dim, %norm) - %cloned = aten::clone(%ret, %bias) - return (%cloned) - )IR"; - - auto self0 = at::rand({6, 6, 6}); - auto n0 = 1; - auto dim0 = 1; - auto norm0 = "forward"; - std::vector args{self0, n0, dim0, norm0}; - testStaticRuntime( - script, - args, - {}, - /*use_allclose=*/false, - /*use_equalnan=*/false, - /*check_resize=*/true); - - auto self1 = at::rand({22, 22, 22}); - auto n1 = 1; - auto dim1 = 1; - auto norm1 = "forward"; - std::vector args2{self1, n1, dim1, norm1}; - testStaticRuntime( - script, - args, - args2, - /*use_allclose=*/false, - /*use_equalnan=*/false, - /*check_resize=*/true); -} - -TEST(StaticRuntime, autogen_fft_ihfft) { - const std::string script = R"IR( - graph(%self: Tensor, %n: int?, %dim: int, %norm: str?): - %bias: None = prim::Constant() - %ret = aten::fft_ihfft(%self, %n, %dim, %norm) - %cloned = aten::clone(%ret, %bias) - return (%cloned) - )IR"; - - auto self0 = at::rand({6, 6, 6}); - auto n0 = 1; - auto dim0 = 1; - auto norm0 = "forward"; - std::vector args{self0, n0, dim0, norm0}; - testStaticRuntime( - script, - args, - {}, - /*use_allclose=*/false, - /*use_equalnan=*/false, - /*check_resize=*/true); - - auto self1 = at::rand({22, 22, 22}); - auto n1 = 1; - auto dim1 = 1; - auto norm1 = "forward"; - std::vector args2{self1, n1, dim1, norm1}; - testStaticRuntime( - script, - args, - args2, - /*use_allclose=*/false, - /*use_equalnan=*/false, - /*check_resize=*/true); -} - TEST(StaticRuntime, autogen_linalg_cross) { const std::string script = R"IR( graph(%self: Tensor, %other: Tensor, %dim: int): diff --git a/torch/csrc/jit/runtime/static/generated_ops.cpp b/torch/csrc/jit/runtime/static/generated_ops.cpp index af61ee72a00e..4597e1298cd6 100644 --- a/torch/csrc/jit/runtime/static/generated_ops.cpp +++ b/torch/csrc/jit/runtime/static/generated_ops.cpp @@ -36,7 +36,8 @@ #include #include -namespace torch::jit { +namespace torch { +namespace jit { REGISTER_OPERATOR_FUNCTOR( aten::absolute, @@ -190,6 +191,29 @@ REGISTER_OPERATOR_FUNCTOR(aten::addr, aten_addr, [](Node* n) -> SROperator { return nullptr; }); +REGISTER_OPERATOR_FUNCTOR( + aten::_test_functorch_fallback, + aten__test_functorch_fallback, + [](Node* n) -> SROperator { + if (n->matches(torch::schema( + "aten::_test_functorch_fallback(Tensor self, Tensor other) -> Tensor"))) { + return [](ProcessedNode* p_node) { + const auto& self = p_node->Input(0).toTensor(); + const auto& other = p_node->Input(1).toTensor(); + if (p_node->Output(0).isNone()) { + p_node->Output(0) = + at::native::_test_functorch_fallback(self, other); + return; + } + auto& out = p_node->Output(0).toTensor(); + fastResizeToZero(out); + at::native::_test_functorch_fallback_out(self, other, out); + }; + } + LogAndDumpSchema(n); + return nullptr; + }); + REGISTER_OPERATOR_FUNCTOR(aten::argmax, aten_argmax, [](Node* n) -> SROperator { if (n->matches(torch::schema( "aten::argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"))) { @@ -2430,6 +2454,25 @@ REGISTER_OPERATOR_FUNCTOR(aten::addbmm, aten_addbmm, [](Node* n) -> SROperator { return nullptr; }); +REGISTER_OPERATOR_FUNCTOR(aten::diag, aten_diag, [](Node* n) -> SROperator { + if (n->matches( + torch::schema("aten::diag(Tensor self, int diagonal=0) -> Tensor"))) { + return [](ProcessedNode* p_node) { + const auto& self = p_node->Input(0).toTensor(); + const auto diagonal = p_node->Input(1).toInt(); + if (p_node->Output(0).isNone()) { + p_node->Output(0) = at::native::diag(self, diagonal); + return; + } + auto& out = p_node->Output(0).toTensor(); + fastResizeToZero(out); + at::native::diag_out(self, diagonal, out); + }; + } + LogAndDumpSchema(n); + return nullptr; +}); + REGISTER_OPERATOR_FUNCTOR(aten::cross, aten_cross, [](Node* n) -> SROperator { if (n->matches(torch::schema( "aten::cross(Tensor self, Tensor other, int? dim=None) -> Tensor"))) { @@ -2684,6 +2727,30 @@ REGISTER_OPERATOR_FUNCTOR( return nullptr; }); +REGISTER_OPERATOR_FUNCTOR( + aten::nonzero_static, + aten_nonzero_static, + [](Node* n) -> SROperator { + if (n->matches(torch::schema( + "aten::nonzero_static(Tensor self, *, int size, int fill_value=-1) -> Tensor"))) { + return [](ProcessedNode* p_node) { + const auto& self = p_node->Input(0).toTensor(); + const auto size = p_node->Input(1).toInt(); + const auto fill_value = p_node->Input(2).toInt(); + if (p_node->Output(0).isNone()) { + p_node->Output(0) = + at::native::nonzero_static_cpu(self, size, fill_value); + return; + } + auto& out = p_node->Output(0).toTensor(); + fastResizeToZero(out); + at::native::nonzero_static_out_cpu(self, size, fill_value, out); + }; + } + LogAndDumpSchema(n); + return nullptr; + }); + REGISTER_OPERATOR_FUNCTOR(aten::gather, aten_gather, [](Node* n) -> SROperator { if (n->matches(torch::schema( "aten::gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor"))) { @@ -4463,132 +4530,6 @@ REGISTER_OPERATOR_FUNCTOR( return nullptr; }); -REGISTER_OPERATOR_FUNCTOR(aten::fft_fft, aten_fft_fft, [](Node* n) -> SROperator { - if (n->matches(torch::schema( - "aten::fft_fft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor"))) { - return [](ProcessedNode* p_node) { - const auto& self = p_node->Input(0).toTensor(); - const auto n = p_node->Input(1).toOptional(); - const auto dim = p_node->Input(2).toInt(); - const auto norm = p_node->Input(3).toOptional(); - if (p_node->Output(0).isNone()) { - p_node->Output(0) = at::native::fft_fft_symint(self, n, dim, norm); - return; - } - auto& out = p_node->Output(0).toTensor(); - fastResizeToZero(out); - at::native::fft_fft_symint_out(self, n, dim, norm, out); - }; - } - LogAndDumpSchema(n); - return nullptr; -}); - -REGISTER_OPERATOR_FUNCTOR(aten::fft_ifft, aten_fft_ifft, [](Node* n) -> SROperator { - if (n->matches(torch::schema( - "aten::fft_ifft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor"))) { - return [](ProcessedNode* p_node) { - const auto& self = p_node->Input(0).toTensor(); - const auto n = p_node->Input(1).toOptional(); - const auto dim = p_node->Input(2).toInt(); - const auto norm = p_node->Input(3).toOptional(); - if (p_node->Output(0).isNone()) { - p_node->Output(0) = at::native::fft_ifft_symint(self, n, dim, norm); - return; - } - auto& out = p_node->Output(0).toTensor(); - fastResizeToZero(out); - at::native::fft_ifft_symint_out(self, n, dim, norm, out); - }; - } - LogAndDumpSchema(n); - return nullptr; -}); - -REGISTER_OPERATOR_FUNCTOR(aten::fft_rfft, aten_fft_rfft, [](Node* n) -> SROperator { - if (n->matches(torch::schema( - "aten::fft_rfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor"))) { - return [](ProcessedNode* p_node) { - const auto& self = p_node->Input(0).toTensor(); - const auto n = p_node->Input(1).toOptional(); - const auto dim = p_node->Input(2).toInt(); - const auto norm = p_node->Input(3).toOptional(); - if (p_node->Output(0).isNone()) { - p_node->Output(0) = at::native::fft_rfft_symint(self, n, dim, norm); - return; - } - auto& out = p_node->Output(0).toTensor(); - fastResizeToZero(out); - at::native::fft_rfft_symint_out(self, n, dim, norm, out); - }; - } - LogAndDumpSchema(n); - return nullptr; -}); - -REGISTER_OPERATOR_FUNCTOR(aten::fft_irfft, aten_fft_irfft, [](Node* n) -> SROperator { - if (n->matches(torch::schema( - "aten::fft_irfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor"))) { - return [](ProcessedNode* p_node) { - const auto& self = p_node->Input(0).toTensor(); - const auto n = p_node->Input(1).toOptional(); - const auto dim = p_node->Input(2).toInt(); - const auto norm = p_node->Input(3).toOptional(); - if (p_node->Output(0).isNone()) { - p_node->Output(0) = at::native::fft_irfft_symint(self, n, dim, norm); - return; - } - auto& out = p_node->Output(0).toTensor(); - fastResizeToZero(out); - at::native::fft_irfft_symint_out(self, n, dim, norm, out); - }; - } - LogAndDumpSchema(n); - return nullptr; -}); - -REGISTER_OPERATOR_FUNCTOR(aten::fft_hfft, aten_fft_hfft, [](Node* n) -> SROperator { - if (n->matches(torch::schema( - "aten::fft_hfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor"))) { - return [](ProcessedNode* p_node) { - const auto& self = p_node->Input(0).toTensor(); - const auto n = p_node->Input(1).toOptional(); - const auto dim = p_node->Input(2).toInt(); - const auto norm = p_node->Input(3).toOptional(); - if (p_node->Output(0).isNone()) { - p_node->Output(0) = at::native::fft_hfft_symint(self, n, dim, norm); - return; - } - auto& out = p_node->Output(0).toTensor(); - fastResizeToZero(out); - at::native::fft_hfft_symint_out(self, n, dim, norm, out); - }; - } - LogAndDumpSchema(n); - return nullptr; -}); - -REGISTER_OPERATOR_FUNCTOR(aten::fft_ihfft, aten_fft_ihfft, [](Node* n) -> SROperator { - if (n->matches(torch::schema( - "aten::fft_ihfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor"))) { - return [](ProcessedNode* p_node) { - const auto& self = p_node->Input(0).toTensor(); - const auto n = p_node->Input(1).toOptional(); - const auto dim = p_node->Input(2).toInt(); - const auto norm = p_node->Input(3).toOptional(); - if (p_node->Output(0).isNone()) { - p_node->Output(0) = at::native::fft_ihfft_symint(self, n, dim, norm); - return; - } - auto& out = p_node->Output(0).toTensor(); - fastResizeToZero(out); - at::native::fft_ihfft_symint_out(self, n, dim, norm, out); - }; - } - LogAndDumpSchema(n); - return nullptr; -}); - REGISTER_OPERATOR_FUNCTOR( aten::linalg_cross, aten_linalg_cross, @@ -5281,4 +5222,5 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR( return nullptr; }); -} // namespace torch::jit +} // namespace jit +} // namespace torch diff --git a/torchgen/static_runtime/config.py b/torchgen/static_runtime/config.py index 407165147e35..da6e2a21c2a3 100644 --- a/torchgen/static_runtime/config.py +++ b/torchgen/static_runtime/config.py @@ -383,6 +383,6 @@ def override_test_values(arg_map: Dict[str, str], op_name: str, index: int) -> N return if op_name in ("diagonal", "linalg_diagonal"): arg_map["offset"] = "0" - arg_map["dim0"] = "1" arg_map["dim1"] = "2" + arg_map["dim2"] = "1" return diff --git a/torchgen/static_runtime/generator.py b/torchgen/static_runtime/generator.py index e709450b48d3..b068af7728aa 100644 --- a/torchgen/static_runtime/generator.py +++ b/torchgen/static_runtime/generator.py @@ -222,6 +222,17 @@ def has_alias( "special_spherical_bessel_j0", "_foobar", "_nested_tensor_strides", + "_nested_tensor_storage_offsets", + "_nested_get_values", # no CPU backend + "_nested_get_values_copy", # no CPU backend + "_nested_view_from_jagged", # testing needs to be patched + "_nested_view_from_jagged_copy", # testing needs to be patched + "_nested_view_from_buffer", # testing needs to be patched + "_nested_view_from_buffer_copy", # testing needs to be patched + "_int_mm", # testing needs to be patched + "_to_sparse_csc", # testing needs to be patched + "_to_sparse_csr", # testing needs to be patched + "segment_reduce", # testing needs to be patched ) ) From a838e909644a3b6d811619be9d60825c16f9006e Mon Sep 17 00:00:00 2001 From: ankurneog Date: Tue, 11 Jun 2024 16:35:17 +0000 Subject: [PATCH 636/706] Add Intel Gaudi device/HPU to auto load in instantiate_device_type_tests (#126970) ### Motivation Intel Gaudi accelerator (device name hpu) is seen to have good pass rate with the pytorch framework UTs , however being an out-of-tree device, we face challenges in adapting the device to natively run the existing pytorch UTs under pytorch/test. The UTs however is a good indicator of the device stack health and as such we run them regularly with adaptations. Although we can add Gaudi/HPU device to generate the device specific tests using the TORCH_TEST_DEVICES environment variable, we miss out on lot of features such as executing for specific dtypes, skipping and overriding opInfo. With significant changes introduced every Pytorch release maintaining these adaptations become difficult and time consuming. Hence with this PR we introduce Gaudi device in common_device_type framework, so that the tests are instantiated for Gaudi when the library is loaded. The eventual goal is to introduce Gaudi out-of-tree support as equivalent to in-tree devices ### Changes Add HPUTestBase of type DeviceTypeTestBase specifying appropriate attributes for Gaudi/HPU. Include code to check if intel Gaudi Software library is loaded and if so, add the device to the list of devices considered for instantiation of device type tests ### Additional Context please refer the following RFC : https://github.com/pytorch/rfcs/pull/63/ Pull Request resolved: https://github.com/pytorch/pytorch/pull/126970 Approved by: https://github.com/albanD --- torch/testing/_internal/common_device_type.py | 29 ++++++++++++++++++- torch/testing/_internal/common_utils.py | 10 +++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py index 07caa0ac3eee..2e2a379a501e 100644 --- a/torch/testing/_internal/common_device_type.py +++ b/torch/testing/_internal/common_device_type.py @@ -15,7 +15,7 @@ import torch from torch.testing._internal.common_utils import TestCase, TEST_WITH_ROCM, TEST_MKL, \ skipCUDANonDefaultStreamIf, TEST_WITH_ASAN, TEST_WITH_UBSAN, TEST_WITH_TSAN, \ - IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, IS_WINDOWS, TEST_MPS, TEST_XPU, \ + IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, IS_WINDOWS, TEST_MPS, TEST_XPU, TEST_HPU, \ _TestParametrizer, compose_parametrize_fns, dtype_name, \ TEST_WITH_MIOPEN_SUGGEST_NHWC, NATIVE_DEVICES, skipIfTorchDynamo, \ get_tracked_input, clear_tracked_input, PRINT_REPRO_ON_FAILURE, \ @@ -590,6 +590,18 @@ def setUpClass(cls): def _should_stop_test_suite(self): return False +class HPUTestBase(DeviceTypeTestBase): + device_type = 'hpu' + primary_device: ClassVar[str] + + @classmethod + def get_primary_device(cls): + return cls.primary_device + + @classmethod + def setUpClass(cls): + cls.primary_device = 'hpu:0' + class PrivateUse1TestBase(DeviceTypeTestBase): primary_device: ClassVar[str] device_mod = None @@ -701,6 +713,8 @@ def get_desired_device_type_test_bases(except_for=None, only_for=None, include_l test_bases.append(MPSTestBase) if only_for == 'xpu' and TEST_XPU and XPUTestBase not in test_bases: test_bases.append(XPUTestBase) + if TEST_HPU and HPUTestBase not in test_bases: + test_bases.append(HPUTestBase) # Filter out the device types based on user inputs desired_device_type_test_bases = filter_desired_device_types(test_bases, except_for, only_for) if include_lazy: @@ -1060,6 +1074,10 @@ class skipMPSIf(skipIf): def __init__(self, dep, reason): super().__init__(dep, reason, device_type='mps') +class skipHPUIf(skipIf): + def __init__(self, dep, reason): + super().__init__(dep, reason, device_type='hpu') + # Skips a test on XLA if the condition is true. class skipXLAIf(skipIf): @@ -1343,6 +1361,9 @@ def onlyMPS(fn): def onlyXPU(fn): return onlyOn('xpu')(fn) +def onlyHPU(fn): + return onlyOn('hpu')(fn) + def onlyPRIVATEUSE1(fn): device_type = torch._C._get_privateuse1_backend_name() device_mod = getattr(torch, device_type, None) @@ -1401,6 +1422,9 @@ def expectedFailureMeta(fn): def expectedFailureXLA(fn): return expectedFailure('xla')(fn) +def expectedFailureHPU(fn): + return expectedFailure('hpu')(fn) + # Skips a test on CPU if LAPACK is not available. def skipCPUIfNoLapack(fn): return skipCPUIf(not torch._C.has_lapack, "PyTorch compiled without Lapack")(fn) @@ -1578,6 +1602,9 @@ def skipXLA(fn): def skipMPS(fn): return skipMPSIf(True, "test doesn't work on MPS backend")(fn) +def skipHPU(fn): + return skipHPUIf(True, "test doesn't work on HPU backend")(fn) + def skipPRIVATEUSE1(fn): return skipPRIVATEUSE1If(True, "test doesn't work on privateuse1 backend")(fn) diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 5d72a444fdda..2097e25bdaa8 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -1236,6 +1236,7 @@ def TemporaryDirectoryName(suffix=None): TEST_MKL = torch.backends.mkl.is_available() TEST_MPS = torch.backends.mps.is_available() TEST_XPU = torch.xpu.is_available() +TEST_HPU = True if (hasattr(torch, "hpu") and torch.hpu.is_available()) else False TEST_CUDA = torch.cuda.is_available() custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name(), None) custom_device_is_available = hasattr(custom_device_mod, "is_available") and custom_device_mod.is_available() @@ -1622,6 +1623,15 @@ def wrapper(*args, **kwargs): fn(*args, **kwargs) return wrapper +def skipIfHpu(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + if TEST_HPU: + raise unittest.SkipTest("test doesn't currently work with HPU") + else: + fn(*args, **kwargs) + return wrapper + # Skips a test on CUDA if ROCm is available and its version is lower than requested. def skipIfRocmVersionLessThan(version=None): def dec_fn(fn): From 4345d98663d31f23492cafc0062f515a47d96a78 Mon Sep 17 00:00:00 2001 From: Angela Yi Date: Tue, 11 Jun 2024 16:43:15 +0000 Subject: [PATCH 637/706] [dynamo] Fix for #127696 (#128358) Test Plan: `buck2 test @//mode/dev-nosan //executorch/exir/backend/...` https://www.internalfb.com/intern/testinfra/testrun/12666373989243932 Differential Revision: D58384518 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128358 Approved by: https://github.com/ydwu4 --- torch/_dynamo/variables/higher_order_ops.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index 00932f984f38..59f8c26ce62d 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -12,7 +12,7 @@ import torch.fx import torch.nn import torch.onnx.operators -from torch._dynamo.utils import deepcopy_to_fake_tensor, get_fake_value, get_real_value +from torch._dynamo.utils import get_fake_value from torch._dynamo.variables import ConstantVariable from torch._dynamo.variables.base import VariableTracker from torch._dynamo.variables.builtin import BuiltinVariable @@ -1149,17 +1149,15 @@ def call_function( p_args = tuple(arg.as_proxy() for arg in args[1:]) real_sub_args = pytree.tree_map_only( - torch.fx.Proxy, lambda a: get_real_value(a.node, tx.output), p_args + torch.fx.Proxy, lambda a: get_fake_value(a.node, tx), p_args ) - example_res = lowered_module.original_module.module()(*real_sub_args) + example_value = lowered_module.original_module.module()(*real_sub_args) # NOTE [Guaranteeing the 1-1 correspondence of FakeTensors and real tensors]: # executorch modules promise not to alias inputs and outputs. # Thus, output FakeTensors will correctly not alias input FakeTensors. - _assert_tensors_nonaliasing(real_sub_args, example_res) - - example_value = deepcopy_to_fake_tensor(example_res, tx.fake_mode) + _assert_tensors_nonaliasing(real_sub_args, example_value) p_args = (lowered_node,) + p_args From 491c4a5dcbad4a5cea1735cd42072766827ab5b0 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 11 Jun 2024 17:45:20 +0000 Subject: [PATCH 638/706] Revert "Make sure #126704 is BC for torch.save-ed `nn.Module` (#128344)" This reverts commit 841d87177a900c2bbd59b6589165189141c4e8bb. Reverted https://github.com/pytorch/pytorch/pull/128344 on behalf of https://github.com/clee2000 due to broke internal typecheck D58394110 (which probably means the code wouldn't work either but I guess it didn't run on the diff). Probably an easy fix? ([comment](https://github.com/pytorch/pytorch/pull/126704#issuecomment-2161299193)) --- torch/nn/modules/module.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index 2c2c2865687c..942d5b8a8f95 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -1955,12 +1955,7 @@ def state_dict(self, *args, destination=None, prefix='', keep_vars=False): for name, module in self._modules.items(): if module is not None: module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars) - for value in self._state_dict_hooks.values(): - # For BC reasons - if isinstance(value, tuple): - hook, from_private = value - else: - hook, from_private = value, True + for (hook, from_private) in self._state_dict_hooks.values(): hook_result = hook(self, destination, prefix, local_metadata) if from_private and hook_result is not None: destination = hook_result From 1d233b8f500f3fafb96e944953579a33a3c1c24e Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 11 Jun 2024 17:45:20 +0000 Subject: [PATCH 639/706] Revert "Make nn.Module state_dict load_state_dict pre-hook and state_dict post hook public (#126704)" This reverts commit c38b3381a12a0ec033dd417827c530c4474b8165. Reverted https://github.com/pytorch/pytorch/pull/126704 on behalf of https://github.com/clee2000 due to broke internal typecheck D58394110 (which probably means the code wouldn't work either but I guess it didn't run on the diff). Probably an easy fix? ([comment](https://github.com/pytorch/pytorch/pull/126704#issuecomment-2161299193)) --- test/nn/test_load_state_dict.py | 2 +- test/nn/test_module_hooks.py | 42 ++--------- .../fx/_lower_to_native_backend.py | 2 +- .../_checkpoint/checkpoint_wrapper.py | 4 +- torch/nn/modules/module.py | 71 ++++++------------- 5 files changed, 32 insertions(+), 89 deletions(-) diff --git a/test/nn/test_load_state_dict.py b/test/nn/test_load_state_dict.py index 3ad7e9c3a639..1bb9f7e82572 100644 --- a/test/nn/test_load_state_dict.py +++ b/test/nn/test_load_state_dict.py @@ -201,7 +201,7 @@ def hook_fn( module_state_dict = module.state_dict() self.assertEqual(len(module_state_dict.keys()), len(state_dict.keys())) - model[0][0].register_load_state_dict_pre_hook(hook_fn) + model[0][0]._register_load_state_dict_pre_hook(hook_fn, with_module=True) model.load_state_dict(model.state_dict(), strict=True) # fails swapping as LSTM installs weak references on the parameters diff --git a/test/nn/test_module_hooks.py b/test/nn/test_module_hooks.py index b2fddcdf0cbd..dc4bead78242 100644 --- a/test/nn/test_module_hooks.py +++ b/test/nn/test_module_hooks.py @@ -588,28 +588,21 @@ def hook_with_module( hook_called += 1 hook_called = 0 - # Test private API since this sets with_module=False which diverges from public API m_load._register_load_state_dict_pre_hook(hook_without_module) m_load.load_state_dict(m_state_dict) self.assertEqual(1, hook_called) - hook_called = 0 - m_load.register_load_state_dict_pre_hook(hook_with_module) - m_load.load_state_dict(m_state_dict) - self.assertEqual(2, hook_called) - - # Test private API with with_module=True hook_called = 0 m_load._register_load_state_dict_pre_hook(hook_with_module, True) m_load.load_state_dict(m_state_dict) - self.assertEqual(3, hook_called) + self.assertEqual(2, hook_called) def test_no_extra_ref_to_module(self): try: gc.disable() m = nn.Linear(10, 10) - m.register_load_state_dict_pre_hook(_hook_to_pickle) + m._register_load_state_dict_pre_hook(_hook_to_pickle, True) weak_m = weakref.ref(m) del m @@ -619,7 +612,7 @@ def test_no_extra_ref_to_module(self): def test_pickled_hook(self): m = nn.Linear(10, 10) - m.register_load_state_dict_pre_hook(_hook_to_pickle) + m._register_load_state_dict_pre_hook(_hook_to_pickle, True) pickle.loads(pickle.dumps(m)) @swap([True, False]) @@ -685,13 +678,14 @@ def __init__(self, mod): mod = m hook_called = 0 - # Test private API since this sets with_module=False which diverges from public API mod._register_load_state_dict_pre_hook(mod.my_pre_load_hook) m.load_state_dict(state_dict) self.assertEqual(1, hook_called) hook_called = 0 - mod.register_load_state_dict_pre_hook(mod.my_pre_load_hook_with_module) + mod._register_load_state_dict_pre_hook( + mod.my_pre_load_hook_with_module, True + ) m.load_state_dict(state_dict) self.assertEqual(2, hook_called) @@ -865,30 +859,6 @@ def my_state_dict_pre_hook(*args, **kwargs): _ = m.state_dict() self.assertTrue(called) - def test_register_state_dict_post_hook(self): - def state_dict_post_hook(module, state_dict, prefix, local_metadata): - for name, param in module.named_parameters(recurse=False): - state_dict[prefix + name] = torch.nn.Parameter( - state_dict[prefix + name] - ) - - def register_linear_hook(module): - if isinstance(module, nn.Linear): - module.register_state_dict_post_hook(state_dict_post_hook) - - m = nn.Transformer( - d_model=4, nhead=2, num_encoder_layers=2, num_decoder_layers=2 - ) - m.apply(register_linear_hook) - - sd = m.state_dict() - - for k, v in m.state_dict().items(): - if "linear" in k or "out_proj" in k: - self.assertTrue(isinstance(v, torch.nn.Parameter)) - else: - self.assertFalse(isinstance(v, torch.nn.Parameter)) - class TestModuleGlobalHooks(TestCase): def tearDown(self): diff --git a/torch/ao/quantization/fx/_lower_to_native_backend.py b/torch/ao/quantization/fx/_lower_to_native_backend.py index f36904c3f587..92620a169383 100644 --- a/torch/ao/quantization/fx/_lower_to_native_backend.py +++ b/torch/ao/quantization/fx/_lower_to_native_backend.py @@ -442,7 +442,7 @@ def load_arg(a): quantized_model = GraphModule(quantized_model, folded_graph) quantized_model._register_state_dict_hook(_save_packed_weight) - quantized_model.register_load_state_dict_pre_hook(_load_packed_weight) + quantized_model._register_load_state_dict_pre_hook(_load_packed_weight, with_module=True) return quantized_model def _get_module(node: Node, modules: Dict[str, nn.Module]) -> Optional[nn.Module]: diff --git a/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py b/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py index a39082e7ea49..86ab1de003db 100644 --- a/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py +++ b/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py @@ -34,8 +34,8 @@ def __init__(self, mod): self._register_state_dict_hook(self._post_state_dict_hook) # load_state_dict pre-hook to allow loading back into # checkpoint-wrapped module. - self.register_load_state_dict_pre_hook( - self._pre_load_state_dict_hook + self._register_load_state_dict_pre_hook( + self._pre_load_state_dict_hook, with_module=True ) def forward(self, *args, **kwargs): diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index 942d5b8a8f95..f803d3f02a17 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -418,9 +418,7 @@ def forward(self, x): # As JIT does not support Set[int], this dict is used as a set, where all # hooks represented in this dict accept kwargs. _forward_pre_hooks_with_kwargs: Dict[int, bool] - # The bool indicates whether the hook comes from the private method - # or the public method. - _state_dict_hooks: Dict[int, Tuple[Callable, bool]] + _state_dict_hooks: Dict[int, Callable] _load_state_dict_pre_hooks: Dict[int, Callable] _state_dict_pre_hooks: Dict[int, Callable] _load_state_dict_post_hooks: Dict[int, Callable] @@ -1801,40 +1799,24 @@ def __delattr__(self, name): super().__delattr__(name) def _register_state_dict_hook(self, hook): - r"""Register a post-hook for the :meth:`~torch.nn.Module.state_dict` method. + r"""Register a state-dict hook. - It should have the following signature:: - hook(module, state_dict, prefix, local_metadata) -> None or state_dict - - The registered hooks can modify the ``state_dict`` inplace or return a new one. - If a new ``state_dict`` is returned, it will only be respected if it is the root - module that :meth:`~nn.Module.state_dict` is called from. - """ - handle = hooks.RemovableHandle(self._state_dict_hooks) - # True indicates that the hook was registered via the private method - self._state_dict_hooks[handle.id] = (hook, True) - return handle - - def register_state_dict_post_hook(self, hook): - r"""Register a post-hook for the :meth:`~torch.nn.Module.state_dict` method. - - It should have the following signature:: - hook(module, state_dict, prefix, local_metadata) -> None - - The registered hooks can modify the ``state_dict`` inplace. + These hooks will be called with arguments: `self`, `state_dict`, + `prefix`, `local_metadata`, after the `state_dict` of `self` is set. + Note that only parameters and buffers of `self` or its children are + guaranteed to exist in `state_dict`. The hooks may modify `state_dict` + inplace or return a new one. """ handle = hooks.RemovableHandle(self._state_dict_hooks) - # False indicates that the hook was registered via the public method - self._state_dict_hooks[handle.id] = (hook, False) + self._state_dict_hooks[handle.id] = hook return handle def register_state_dict_pre_hook(self, hook): r"""Register a pre-hook for the :meth:`~torch.nn.Module.state_dict` method. - It should have the following signature:: - hook(module, prefix, keep_vars) -> None - - The registered hooks can be used to perform pre-processing before the ``state_dict`` + These hooks will be called with arguments: ``self``, ``prefix``, + and ``keep_vars`` before calling ``state_dict`` on ``self``. The registered + hooks can be used to perform pre-processing before the ``state_dict`` call is made. """ handle = hooks.RemovableHandle(self._state_dict_pre_hooks) @@ -1955,19 +1937,22 @@ def state_dict(self, *args, destination=None, prefix='', keep_vars=False): for name, module in self._modules.items(): if module is not None: module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars) - for (hook, from_private) in self._state_dict_hooks.values(): + for hook in self._state_dict_hooks.values(): hook_result = hook(self, destination, prefix, local_metadata) - if from_private and hook_result is not None: + if hook_result is not None: destination = hook_result return destination def _register_load_state_dict_pre_hook(self, hook, with_module=False): - r"""See :meth:`~torch.nn.Module.register_load_state_dict_pre_hook` for details. + r"""Register a pre-hook for the :meth:`~torch.nn.Module.load_state_dict` method. + + These hooks will be called with arguments: `state_dict`, `prefix`, + `local_metadata`, `strict`, `missing_keys`, `unexpected_keys`, + `error_msgs`, before loading `state_dict` into `self`. These arguments + are exactly the same as those of `_load_from_state_dict`. - A subtle difference is that if ``with_module`` is set to ``False``, then the - hook will not take the ``module`` as the first argument whereas - :meth:`~torch.nn.Module.register_load_state_dict_pre_hook` always takes the - ``module`` as the first argument. + If ``with_module`` is ``True``, then the first argument to the hook is + an instance of the module. Arguments: hook (Callable): Callable hook that will be invoked before @@ -1979,20 +1964,8 @@ def _register_load_state_dict_pre_hook(self, hook, with_module=False): self._load_state_dict_pre_hooks[handle.id] = _WrappedHook(hook, self if with_module else None) return handle - def register_load_state_dict_pre_hook(self, hook): - r"""Register a pre-hook to be run before module's :meth:`~nn.Module.load_state_dict` is called. - - It should have the following signature:: - hook(module, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -> None # noqa: B950 - - Arguments: - hook (Callable): Callable hook that will be invoked before - loading the state dict. - """ - return self._register_load_state_dict_pre_hook(hook, with_module=True) - def register_load_state_dict_post_hook(self, hook): - r"""Register a post-hook to be run after module's :meth:`~nn.Module.load_state_dict` is called. + r"""Register a post hook to be run after module's ``load_state_dict`` is called. It should have the following signature:: hook(module, incompatible_keys) -> None From 8a09940a543d4c2fd23a5c78edbf1ac24d481b45 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Mon, 10 Jun 2024 16:47:51 -0700 Subject: [PATCH 640/706] [inductor] fix compile time regression by caching get_gpu_type (#128363) We observed signficant compile time regression in torchtitan when turning on 2D parallel + torch.compile recently. So I decided to get a deeper understanding why. It turns out this is affecting **all the trainings** that have functional collectives captured in the graph, not only 2D parallel (2D parallel was just the job that happen to have collectives captured in the TP region). The root cause is because when doing inductor lowering, we are calling the comm analysis pass to get a estimated collective time for each collective node in the graph, for each call to check the collective node, we are calling `get_gpu_type()`, which under the hood calls a `torch.utils.collect_env.run` to get the GPU info. However, this call is super expensive! The reason is that this call effectively spawns a new process and call `nvidia-smi` to get the GPU info, so the cost is **linear** to the number of collective nodes in the graph. see https://github.com/pytorch/pytorch/blob/main/torch/utils/collect_env.py#L75 The fix is to add a lru cache to the function, so that we only call this once and reuse the cached results afterwards torchtitan benchmark shows: * before this fix: 2D parallel + fp8 compile time: 6min + * after this fix: 2D parallel + fp8 compile time: 2min 48s (more than 100% improvement) There're more room to improve the compile time, but this PR is trying to fix the biggest regression I found so far. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128363 Approved by: https://github.com/yf225 --- torch/_inductor/comm_analysis.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch/_inductor/comm_analysis.py b/torch/_inductor/comm_analysis.py index 334ccf5b7e18..71e8740a5fd7 100644 --- a/torch/_inductor/comm_analysis.py +++ b/torch/_inductor/comm_analysis.py @@ -1,3 +1,4 @@ +import functools import math from enum import IntEnum @@ -22,6 +23,7 @@ class NVIDIA_GPU_TYPE(IntEnum): HOPPER = 2 +@functools.lru_cache def get_gpu_type() -> NVIDIA_GPU_TYPE: gpu_info = torch.utils.collect_env.get_gpu_info(torch.utils.collect_env.run) or "" if "V100" in gpu_info: From cac7a22b92478d897488688010e562b7bd36b97f Mon Sep 17 00:00:00 2001 From: Eddie Yan Date: Tue, 11 Jun 2024 18:09:25 +0000 Subject: [PATCH 641/706] [cuDNN][Quantization] Don't print when plan finalization fails in cuDNN quantization backend (#128177) Similar in spirit to #125790, hopefully addresses failures seen for cuDNN 9.1 upgrade: #https://github.com/pytorch/pytorch/pull/128166 CC @nWEIdia @atalman Pull Request resolved: https://github.com/pytorch/pytorch/pull/128177 Approved by: https://github.com/nWEIdia, https://github.com/Skylion007 --- aten/src/ATen/native/quantized/cudnn/BinaryOps.cpp | 2 +- aten/src/ATen/native/quantized/cudnn/Conv.cpp | 2 +- aten/src/ATen/native/quantized/cudnn/Linear.cpp | 2 +- test/quantization/core/test_quantized_op.py | 1 - 4 files changed, 3 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/native/quantized/cudnn/BinaryOps.cpp b/aten/src/ATen/native/quantized/cudnn/BinaryOps.cpp index 07ccc19c4828..9e9e675e7a3c 100644 --- a/aten/src/ATen/native/quantized/cudnn/BinaryOps.cpp +++ b/aten/src/ATen/native/quantized/cudnn/BinaryOps.cpp @@ -242,7 +242,7 @@ Tensor add(Tensor qa, Tensor qb, double output_scale, int64_t output_zero_point) run(plan_desc); execution_plan_cache[key] = plan_desc; return quantized_output.view(orig_sizes); - } catch (cudnn_frontend::cudnnException &e) {std::cout << "cudnn error:" << e.what() << std::endl;} catch(c10::CuDNNError &e) { std::cout << "other error" << e.what() << std::endl;} + } catch (cudnn_frontend::cudnnException &e) {} catch(c10::CuDNNError &e) {} } TORCH_CHECK(false, "Unable to find an engine to execute this computation in Quantized Add Cudnn"); diff --git a/aten/src/ATen/native/quantized/cudnn/Conv.cpp b/aten/src/ATen/native/quantized/cudnn/Conv.cpp index 606d769fe6eb..8823038da48b 100644 --- a/aten/src/ATen/native/quantized/cudnn/Conv.cpp +++ b/aten/src/ATen/native/quantized/cudnn/Conv.cpp @@ -252,7 +252,7 @@ void PackedConvWeightCudnn::apply_impl_helper(const at::Tensor& qua run(plan); execution_plan_cache.emplace(key, plan); return; - } catch (cudnn_frontend::cudnnException &e) {std::cout << "cudnn error:" << e.what() << std::endl;} catch(c10::CuDNNError &e) { std::cout << "other error" << e.what() << std::endl;} + } catch (cudnn_frontend::cudnnException &e) {} catch(c10::CuDNNError &e) {} } TORCH_CHECK(false, "Unable to find an engine to execute this computation in Quantized Conv2D Cudnn"); diff --git a/aten/src/ATen/native/quantized/cudnn/Linear.cpp b/aten/src/ATen/native/quantized/cudnn/Linear.cpp index d3219592e25b..54eb08443c48 100644 --- a/aten/src/ATen/native/quantized/cudnn/Linear.cpp +++ b/aten/src/ATen/native/quantized/cudnn/Linear.cpp @@ -286,7 +286,7 @@ void PackedLinearWeightCudnn::apply_impl_helper(const at::Tensor& quantized_outp run(plan); execution_plan_cache.emplace(key, plan); return; - } catch (cudnn_frontend::cudnnException &e) {std::cout << "cudnn error:" << e.what() << std::endl;} catch(c10::CuDNNError &e) { std::cout << "other error" << e.what() << std::endl;} + } catch (cudnn_frontend::cudnnException &e) {} catch(c10::CuDNNError &e) {} } TORCH_CHECK(false, "Unable to find an engine to execute this computation Quantized Linear Cudnn"); diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py index 5b86693e11c1..6671b6634e00 100644 --- a/test/quantization/core/test_quantized_op.py +++ b/test/quantization/core/test_quantized_op.py @@ -4052,7 +4052,6 @@ def test_qlinear_with_input_q_dq_qweight_dq_output_fp32( use_channelwise=st.sampled_from([False])) # channelwise currently not supported for qlinear cudnn @skipIfNoFBGEMM @unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.") - @unittest.skipIf(TEST_CUDNN and torch.backends.cudnn.version() == 90100, "expected failure on cuDNN 9.1.0") @unittest.skipIf(not SM80OrLater, "requires sm80 or later.") @unittest.skipIf(TEST_ROCM, "not supported on rocm.") # TODO: check with yang regarding CUDNN flags From 205410cb44efafadbd3af4a9238b30d2a03a10b8 Mon Sep 17 00:00:00 2001 From: Jing Xu Date: Tue, 11 Jun 2024 18:13:01 +0000 Subject: [PATCH 642/706] add xpu to torch.tensors (#127280) As support for Intel GPU has been upstreamed, this PR is to add the XPU-related contents to torch.tensors doc. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127280 Approved by: https://github.com/svekars --- docs/source/tensors.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst index 7bfa8704f5e5..3f9a96ac7da6 100644 --- a/docs/source/tensors.rst +++ b/docs/source/tensors.rst @@ -779,4 +779,5 @@ Tensor class reference Tensor.where Tensor.xlogy Tensor.xlogy_ + Tensor.xpu Tensor.zero_ From 984b1a8c354a0e10f9994a0803f969ab0f6131a3 Mon Sep 17 00:00:00 2001 From: BowenBao Date: Thu, 6 Jun 2024 18:52:15 -0700 Subject: [PATCH 643/706] Fix 'get_attr' call in dynamo 'run_node' (#127696) Fixes #124858 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127696 Approved by: https://github.com/jansel ghstack dependencies: #127695 --- test/dynamo/test_decorators.py | 22 ++++++++++++++++++++++ torch/_dynamo/utils.py | 2 +- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_decorators.py b/test/dynamo/test_decorators.py index 16853bbfb3ad..487d8bc52312 100644 --- a/test/dynamo/test_decorators.py +++ b/test/dynamo/test_decorators.py @@ -464,6 +464,28 @@ def fn(a, b, c): self.assertEqual(cnt.frame_count, 1) + def test_assume_constant_result_on_user_defined_fn(self): + @torch._dynamo.assume_constant_result + def const_fn(n, s): + return torch.full([n], s) + + def fn(B): + B = const_fn(B.size(0), 13) + X = B * 2 + return X.tolist() + + B_list = [8] * 32 + + B = torch.tensor(B_list, dtype=torch.int32) + torch._dynamo.decorators.mark_static(B, 0) + + torch._dynamo.config.capture_scalar_outputs = True + torch._dynamo.config.capture_dynamic_output_shape_ops = True + + self.assertEqual( + fn(B), torch.compile(fn, backend="eager", fullgraph=True, dynamic=True)(B) + ) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 59500f0338a1..41131be10554 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -1906,7 +1906,7 @@ def make_error_message(e): assert nnmodule is not None return nnmodule(*args, **kwargs) elif op == "get_attr": - return tracer.get_submodule(node.target) + return tracer.output_graph.get_submodule(node.target) elif op == "placeholder": assert "example_value" in node.meta return node.meta["example_value"] From 61f922c2cab1fddc9ff8dd0c7612990a2d9ba2dc Mon Sep 17 00:00:00 2001 From: BowenBao Date: Thu, 6 Jun 2024 18:52:15 -0700 Subject: [PATCH 644/706] Fix 'get_real_value' on placeholder nodes (#127698) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127698 Approved by: https://github.com/jansel ghstack dependencies: #127695, #127696 --- test/dynamo/test_decorators.py | 16 ++++++++++++++++ torch/_dynamo/utils.py | 3 +++ 2 files changed, 19 insertions(+) diff --git a/test/dynamo/test_decorators.py b/test/dynamo/test_decorators.py index 487d8bc52312..c13fcd31dab7 100644 --- a/test/dynamo/test_decorators.py +++ b/test/dynamo/test_decorators.py @@ -486,6 +486,22 @@ def fn(B): fn(B), torch.compile(fn, backend="eager", fullgraph=True, dynamic=True)(B) ) + def test_assume_constant_result_on_computation_with_graph_input(self): + @torch._dynamo.assume_constant_result + def check(y): + return y[0].item() == 1 + + def fn(x, y): + if check(y): + return x + 2 + else: + return x + 1 + + y = torch.tensor([1]) + x = torch.tensor(1) + + self.assertEqual(fn(x, y), torch.compile(fn)(x, y)) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 41131be10554..fe2f096ec488 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -1941,6 +1941,9 @@ def get_real_value(node, tracer): lambda n: get_real_value(n, tracer), ) + if op == "placeholder" and "grapharg" in node.meta: + return node.meta["grapharg"].example + if op == "call_module": nn_module = tracer.output_graph.nn_modules[node.target] if not is_lazy_module(nn_module): From 3e091237974719f00879a1f4621d3ce6549d2f01 Mon Sep 17 00:00:00 2001 From: yuqingj Date: Mon, 10 Jun 2024 20:48:01 -0700 Subject: [PATCH 645/706] Enable UFMT on test_nestedtensor.py (#128359) split it into two PRs since it is more than 2k lines of change Pull Request resolved: https://github.com/pytorch/pytorch/pull/128359 Approved by: https://github.com/davidberard98 --- .lintrunner.toml | 1 - test/test_nestedtensor.py | 2059 +++++++++++++++++++++++++------------ 2 files changed, 1402 insertions(+), 658 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index 7d9cfa39916e..5ccab63f487e 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1099,7 +1099,6 @@ exclude_patterns = [ 'test/test_namedtuple_return_api.py', 'test/test_native_functions.py', 'test/test_native_mha.py', - 'test/test_nestedtensor.py', 'test/test_nn.py', 'test/test_out_dtype_op.py', 'test/test_overrides.py', diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index ca50c93dd260..5524658b0123 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -2,20 +2,31 @@ import io import itertools +import math import sys -from typing import Optional, Tuple import unittest from functools import partial -import math +from typing import Optional, Tuple import numpy as np + import torch import torch._dynamo import torch._dynamo.testing import torch.nn import torch.nn.functional as F + +from torch.nested._internal.nested_tensor import ( + buffer_from_jagged, + jagged_from_list, + nested_view_from_values_offsets, + NestedTensor, + ViewNestedFromBuffer, +) from torch.testing._internal.common_cuda import ( - SM70OrLater, SM80OrLater, PLATFORM_SUPPORTS_FUSED_ATTENTION, + PLATFORM_SUPPORTS_FUSED_ATTENTION, + SM70OrLater, + SM80OrLater, ) from torch.testing._internal.common_device_type import ( dtypes, @@ -23,10 +34,10 @@ instantiate_device_type_tests, onlyCPU, onlyCUDA, + PYTORCH_CUDA_MEMCHECK, skipCUDAIf, skipCUDAIfRocm, skipMeta, - PYTORCH_CUDA_MEMCHECK, ) from torch.testing._internal.common_dtype import floating_types_and_half from torch.testing._internal.common_utils import ( @@ -36,23 +47,15 @@ instantiate_parametrized_tests, IS_FBCODE, IS_WINDOWS, + markDynamoStrictTest, parametrize, run_tests, skipIfSlowGradcheckEnv, skipIfTorchDynamo, - markDynamoStrictTest, - xfailIfTorchDynamo, subtest, TEST_WITH_ROCM, TestCase, -) - -from torch.nested._internal.nested_tensor import ( - buffer_from_jagged, - jagged_from_list, - NestedTensor, - nested_view_from_values_offsets, - ViewNestedFromBuffer, + xfailIfTorchDynamo, ) # Tests are ported from pytorch/nestedtensor. @@ -63,6 +66,7 @@ def _iter_constructors(): # yield as_nested_tensor yield torch.nested.nested_tensor + # Helper function to generate a pair of random nested tensors # one is contiguous, the other is not, but they appear to have same entries # an output nested tensor consists of @@ -84,6 +88,7 @@ def random_nt_noncontiguous_pair(ragged_sizes, device="cpu", dtype=torch.float16 nt_noncontiguous = torch.nested.nested_tensor(xs).transpose(-1, -2) return nt_contiguous, nt_noncontiguous + # Helper functions to pad a noncontiguous nested tensor # can be replaced once to_padded_tensor supports noncontiguous memory @@ -110,10 +115,19 @@ def noncontiguous_to_padded_tensor(input, shape=None): view.copy_(tensor) return result + # Helper function to generate a random nested tensor -def random_nt(device, dtype, num_tensors, max_dims, min_dims=None, layout=torch.strided, require_non_empty=True): +def random_nt( + device, + dtype, + num_tensors, + max_dims, + min_dims=None, + layout=torch.strided, + require_non_empty=True, +): if min_dims is None: min_dims = tuple([0] * len(max_dims)) @@ -122,9 +136,9 @@ def random_nt(device, dtype, num_tensors, max_dims, min_dims=None, layout=torch. assert max_dim > min_dim, "random_nt: max_dim must be greater than min_dim" assert min_dim >= 0, "random_nt: min_dim must be non-negative" if require_non_empty: - assert not (min_dim == 0 and max_dim == 1), ( - "random_nt: zero cannot be the only possible value if require_non_empty is True" - ) + assert not ( + min_dim == 0 and max_dim == 1 + ), "random_nt: zero cannot be the only possible value if require_non_empty is True" if require_non_empty: # Select a random idx that will be required to be non-empty @@ -137,7 +151,9 @@ def random_nt(device, dtype, num_tensors, max_dims, min_dims=None, layout=torch. new_min_dim = min_dim if require_non_empty and i == non_zero_idx and min_dim == 0: new_min_dim = 1 - tensor_dims.append(torch.randint(low=new_min_dim, high=max_dim, size=(1,)).item()) + tensor_dims.append( + torch.randint(low=new_min_dim, high=max_dim, size=(1,)).item() + ) t1 = torch.randn(tensor_dims, device=device, dtype=dtype) ts1.append(t1) @@ -147,14 +163,23 @@ def random_nt(device, dtype, num_tensors, max_dims, min_dims=None, layout=torch. # Alternate approach to generating a random NT. # dims should be something like [5, None, 10], with None indicating that a # random ragged structure should be used -def random_nt_from_dims(dims, device=None, dtype=None, layout=torch.strided, requires_grad=False): +def random_nt_from_dims( + dims, device=None, dtype=None, layout=torch.strided, requires_grad=False +): sizes = [ - [d if d is not None else torch.randint(2, 10, size=(1,)).item() for d in dims[1:]] + [ + d if d is not None else torch.randint(2, 10, size=(1,)).item() + for d in dims[1:] + ] for d in range(dims[0]) ] - return torch.nested.nested_tensor([ - torch.randn(*size) for size in sizes - ], device=device, dtype=dtype, layout=layout, requires_grad=requires_grad) + return torch.nested.nested_tensor( + [torch.randn(*size) for size in sizes], + device=device, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + ) # Creates an NT matching another NT's number of components and @@ -176,9 +201,9 @@ def random_nt_from_similar(other, dims=None): ret_size.append(d) ret_sizes.append(ret_size) - return torch.nested.nested_tensor([ - torch.randn(*size) for size in ret_sizes - ], device=other.device) + return torch.nested.nested_tensor( + [torch.randn(*size) for size in ret_sizes], device=other.device + ) # makes naming nice for tests that parametrize over layout. @@ -236,8 +261,7 @@ def test_2d_nested_tensor(self, batch_size, max_seq_len, vocab_size): nested_tensor_list = nested_tensor.unbind() for id in range(batch_size): self.assertEqual( - nested_tensor_list[id], - nested_tensor_ref_list[id].type(torch.int64) + nested_tensor_list[id], nested_tensor_ref_list[id].type(torch.int64) ) @parametrize("batch_size", [2, 4]) @@ -259,8 +283,7 @@ def test_3d_nested_tensor(self, batch_size, max_seq_len, vocab_size): nested_tensor_list = nested_tensor.unbind() for id in range(batch_size): self.assertEqual( - nested_tensor_list[id], - nested_tensor_ref_list[id].type(torch.int64) + nested_tensor_list[id], nested_tensor_ref_list[id].type(torch.int64) ) @parametrize("batch_size", [2, 4]) @@ -284,11 +307,9 @@ def test_3d_nested_tensor_float(self, batch_size, max_seq_len, vocab_size): nested_tensor_list = nested_tensor.unbind() for id in range(batch_size): self.assertEqual( - nested_tensor_list[id], - nested_tensor_ref_list[id].type(torch.float) + nested_tensor_list[id], nested_tensor_ref_list[id].type(torch.float) ) - @torch.inference_mode() def _test_unbind_case(self, a, b): nt = torch.nested.nested_tensor([a, b]) @@ -308,25 +329,29 @@ def _test_unbind_case(self, a, b): @torch.inference_mode() def test_unbind_0(self): self._test_unbind_case( - torch.tensor([1, 2]), torch.tensor([7, 8]), + torch.tensor([1, 2]), + torch.tensor([7, 8]), ) @torch.inference_mode() def test_unbind_1(self): self._test_unbind_case( - torch.tensor([1]), torch.tensor([7]), + torch.tensor([1]), + torch.tensor([7]), ) @torch.inference_mode() def test_unbind_3(self): self._test_unbind_case( - torch.tensor([1.0]), torch.tensor([]), + torch.tensor([1.0]), + torch.tensor([]), ) @torch.inference_mode() def test_unbind_4(self): self._test_unbind_case( - torch.tensor([]), torch.tensor([]), + torch.tensor([]), + torch.tensor([]), ) @torch.inference_mode() @@ -345,7 +370,9 @@ def _test_fn(unbind_fn): @torch.inference_mode() def test_nested_tensor(self): - self.assertRaises(TypeError, lambda: torch.nested.nested_tensor(torch.tensor([3.0]))) + self.assertRaises( + TypeError, lambda: torch.nested.nested_tensor(torch.tensor([3.0])) + ) self.assertRaises(TypeError, lambda: torch.nested.nested_tensor(4.0)) @torch.inference_mode() @@ -434,18 +461,22 @@ def test_size_dim(self): a = torch.nested.nested_tensor([torch.tensor(1), torch.tensor(2)]) self.assertEqual(a.size(0), 2) - a = torch.nested.nested_tensor([torch.rand(1, 2), - torch.rand(1, 8)]) + a = torch.nested.nested_tensor([torch.rand(1, 2), torch.rand(1, 8)]) self.assertEqual(a.size(0), 2) self.assertEqual(a.size(1), 1) self.assertRaisesRegex( - RuntimeError, "Given dimension 2 is irregular and does not have a size", lambda: a.size(2)) + RuntimeError, + "Given dimension 2 is irregular and does not have a size", + lambda: a.size(2), + ) - a = torch.nested.nested_tensor([torch.rand(3, 4), - torch.rand(5, 4)]) + a = torch.nested.nested_tensor([torch.rand(3, 4), torch.rand(5, 4)]) self.assertEqual(a.size(0), 2) self.assertRaisesRegex( - RuntimeError, "Given dimension 1 is irregular and does not have a size", lambda: a.size(1)) + RuntimeError, + "Given dimension 1 is irregular and does not have a size", + lambda: a.size(1), + ) self.assertEqual(a.size(2), 4) @unittest.skipIf(IS_FBCODE, "stride is not virtual in fbcode.") @@ -478,8 +509,12 @@ def test_is_contiguous(self): self.assertEqual(nt_contiguous, nt_noncontiguous.contiguous()) # Test querying by memory_format - self.assertTrue(nt_contiguous.is_contiguous(memory_format=torch.contiguous_format)) - self.assertTrue(not nt_noncontiguous.is_contiguous(memory_format=torch.contiguous_format)) + self.assertTrue( + nt_contiguous.is_contiguous(memory_format=torch.contiguous_format) + ) + self.assertTrue( + not nt_noncontiguous.is_contiguous(memory_format=torch.contiguous_format) + ) @torch.inference_mode() def test_repr_string(self): @@ -499,7 +534,6 @@ def test_repr_string(self): self.assertEqual(repr(a), expected) def test_to_padded_tensor_on_empty_tensor(self): - nt = torch.nested.nested_tensor([]) empty = torch.nested.to_padded_tensor(nt, 4) self.assertEqual(empty, torch.tensor([])) @@ -512,7 +546,7 @@ def test_nested_namespace(self): def test_to(self): ntensors = 4 - nt = random_nt(torch.device('cpu'), torch.float32, ntensors, (4, 4)) + nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4)) def test_copy_behavior(t, non_blocking=False): self.assertIs(t, t.to(t, non_blocking=non_blocking)) @@ -520,113 +554,141 @@ def test_copy_behavior(t, non_blocking=False): self.assertIs(t, t.to(torch.empty_like(t), non_blocking=non_blocking)) self.assertIsNot(t, t.to(t, non_blocking=non_blocking, copy=True)) self.assertIsNot(t, t.to(t.dtype, non_blocking=non_blocking, copy=True)) - self.assertIsNot(t, t.to(torch.empty_like(t), non_blocking=non_blocking, copy=True)) + self.assertIsNot( + t, t.to(torch.empty_like(t), non_blocking=non_blocking, copy=True) + ) devices = [t.device] - if t.device.type == 'cuda': + if t.device.type == "cuda": if t.device.index == -1: - devices.append(f'cuda:{torch.cuda.current_device()}') + devices.append(f"cuda:{torch.cuda.current_device()}") elif t.device.index == torch.cuda.current_device(): - devices.append('cuda') + devices.append("cuda") for device in devices: self.assertIs(t, t.to(device, non_blocking=non_blocking)) self.assertIs(t, t.to(device, t.dtype, non_blocking=non_blocking)) self.assertIsNot(t, t.to(device, non_blocking=non_blocking, copy=True)) - self.assertIsNot(t, t.to(device, t.dtype, non_blocking=non_blocking, copy=True)) + self.assertIsNot( + t, t.to(device, t.dtype, non_blocking=non_blocking, copy=True) + ) test_copy_behavior(nt) - self.assertEqual(nt.device, nt.to('cpu').device) - self.assertEqual(nt.device, nt.to('cpu', dtype=torch.float32).device) - self.assertIs(torch.float32, nt.to('cpu', dtype=torch.float32).dtype) + self.assertEqual(nt.device, nt.to("cpu").device) + self.assertEqual(nt.device, nt.to("cpu", dtype=torch.float32).device) + self.assertIs(torch.float32, nt.to("cpu", dtype=torch.float32).dtype) self.assertEqual(nt.device, nt.to(torch.float32).device) self.assertIs(torch.float32, nt.to(dtype=torch.float32).dtype) def test_data_ptr(getter): - self.assertEqual(getter(nt), getter(nt.to('cpu'))) - self.assertEqual(getter(nt), getter(nt.to(dtype=nt.dtype, device=nt.device, copy=False))) - self.assertEqual(getter(nt), getter(nt.to('cpu', copy=False))) - self.assertNotEqual(getter(nt), getter(nt.to('cpu', copy=True))) + self.assertEqual(getter(nt), getter(nt.to("cpu"))) + self.assertEqual( + getter(nt), getter(nt.to(dtype=nt.dtype, device=nt.device, copy=False)) + ) + self.assertEqual(getter(nt), getter(nt.to("cpu", copy=False))) + self.assertNotEqual(getter(nt), getter(nt.to("cpu", copy=True))) test_data_ptr(lambda nt: nt.data_ptr()) if torch.cuda.is_available(): for non_blocking in [True, False]: - for cuda in ['cuda', 'cuda:0' if torch.cuda.device_count() == 1 else 'cuda:1']: + for cuda in [ + "cuda", + "cuda:0" if torch.cuda.device_count() == 1 else "cuda:1", + ]: nt2 = random_nt(cuda, torch.float32, ntensors, (4, 4)) test_copy_behavior(nt2, non_blocking) - self.assertEqual(nt2.device, nt2.to(cuda, non_blocking=non_blocking).device) - self.assertEqual(nt.device, nt2.to('cpu', non_blocking=non_blocking).device) - self.assertEqual(nt2.device, nt.to(cuda, non_blocking=non_blocking).device) - self.assertIs(torch.int32, nt2.to('cpu', dtype=torch.int32, non_blocking=non_blocking).dtype) - self.assertEqual(nt.device, nt2.to('cpu', dtype=torch.int32, non_blocking=non_blocking).device) + self.assertEqual( + nt2.device, nt2.to(cuda, non_blocking=non_blocking).device + ) + self.assertEqual( + nt.device, nt2.to("cpu", non_blocking=non_blocking).device + ) + self.assertEqual( + nt2.device, nt.to(cuda, non_blocking=non_blocking).device + ) + self.assertIs( + torch.int32, + nt2.to( + "cpu", dtype=torch.int32, non_blocking=non_blocking + ).dtype, + ) + self.assertEqual( + nt.device, + nt2.to( + "cpu", dtype=torch.int32, non_blocking=non_blocking + ).device, + ) self.assertIs(torch.int32, nt2.to(dtype=torch.int32).dtype) self.assertEqual(nt2.device, nt2.to(dtype=torch.int32).device) def test_copy_(self): ntensors = 4 - nt = random_nt(torch.device('cpu'), torch.float32, ntensors, (4, 4)) + nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4)) nt_copy = torch.empty_like(nt) nt_copy.copy_(nt) - for (nt_ub, nt_copy_ub) in zip(nt.unbind(), nt_copy): + for nt_ub, nt_copy_ub in zip(nt.unbind(), nt_copy): self.assertEqual(nt_ub, nt_copy_ub) nt_error = torch.nested.nested_tensor([torch.tensor([0, 0])]) self.assertRaisesRegex( RuntimeError, "copy_ only supports tensors that are the same size for Nested implementations", - lambda: nt_error.copy_(nt) + lambda: nt_error.copy_(nt), ) if torch.cuda.is_available(): - nt = random_nt(torch.device('cuda'), torch.float32, ntensors, (4, 4)) - nt_copy = torch.empty_like(nt, device=torch.device('cpu')) + nt = random_nt(torch.device("cuda"), torch.float32, ntensors, (4, 4)) + nt_copy = torch.empty_like(nt, device=torch.device("cpu")) nt_copy.copy_(nt, non_blocking=True) torch.cuda.current_stream(torch.cuda.current_device()).synchronize() - for (nt_ub, nt_copy_ub) in zip(nt.unbind(), nt_copy): + for nt_ub, nt_copy_ub in zip(nt.unbind(), nt_copy): self.assertEqual(nt_ub, nt_copy_ub) - nt_copy = torch.empty_like(nt, device=torch.device('cpu')) + nt_copy = torch.empty_like(nt, device=torch.device("cpu")) nt_copy.copy_(nt, non_blocking=False) - for (nt_ub, nt_copy_ub) in zip(nt.unbind(), nt_copy): + for nt_ub, nt_copy_ub in zip(nt.unbind(), nt_copy): self.assertEqual(nt_ub, nt_copy_ub) def test_fill_(self): ntensors = 4 - nt = random_nt(torch.device('cpu'), torch.float32, ntensors, (4, 4)) - nt.fill_(10.) + nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4)) + nt.fill_(10.0) for nt_ub in nt.unbind(): t = torch.empty_like(nt_ub) - t.fill_(10.) + t.fill_(10.0) self.assertEqual(nt_ub, t) - fill_tensor = torch.tensor([11.]) + fill_tensor = torch.tensor([11.0]) self.assertRaisesRegex( RuntimeError, "fill_ only supports 0-dimension value tensor", - lambda: nt.fill_(fill_tensor) + lambda: nt.fill_(fill_tensor), ) nt.fill_(fill_tensor[0]) for nt_ub in nt.unbind(): t = torch.empty_like(nt_ub) - t.fill_(11.) + t.fill_(11.0) self.assertEqual(nt_ub, t) def test_zero_(self): ntensors = 4 - nt = random_nt(torch.device('cpu'), torch.float32, ntensors, (4, 4)) + nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4)) nt.zero_() for nt_ub in nt.unbind(): t = torch.empty_like(nt_ub) - t.fill_(0.) + t.fill_(0.0) self.assertEqual(nt_ub, t) - @parametrize("func", [torch.ones_like, torch.zeros_like, torch.randn_like], - name_fn=lambda f: f.__name__) + @parametrize( + "func", + [torch.ones_like, torch.zeros_like, torch.randn_like], + name_fn=lambda f: f.__name__, + ) def test_like_functions(self, func): ntensors = 4 - nt = random_nt(torch.device('cpu'), torch.float32, ntensors, (4, 4)) + nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4)) torch.manual_seed(1) nt_like = func(nt) @@ -642,7 +704,8 @@ def test_cat(self): y = random_nt_from_dims([3, 4, None]) output = torch.cat([x, y], dim=0) for out_component, xy_component in zip( - output.unbind(), itertools.chain(x.unbind(), y.unbind())): + output.unbind(), itertools.chain(x.unbind(), y.unbind()) + ): self.assertEqual(out_component, xy_component) # dim=-1 success case @@ -652,29 +715,40 @@ def test_cat(self): y = random_nt_from_similar(x, dims=[-1, -1, 8]) # should be shape (B, *, D + D') when supported output = torch.cat([x, y], dim=-1) - for out_component, x_component, y_component in zip(output.unbind(), x.unbind(), y.unbind()): - self.assertEqual(out_component, torch.cat([x_component, y_component], dim=-1)) + for out_component, x_component, y_component in zip( + output.unbind(), x.unbind(), y.unbind() + ): + self.assertEqual( + out_component, torch.cat([x_component, y_component], dim=-1) + ) # dim between 0 and -1 success case x = random_nt_from_dims([5, None, 2, 3]) # same structure as x but dim=2 differs y = random_nt_from_similar(x, dims=[-1, -1, 4, -1]) output = torch.cat([x, y], dim=2) - for out_component, x_component, y_component in zip(output.unbind(), x.unbind(), y.unbind()): - self.assertEqual(out_component, torch.cat([x_component, y_component], dim=1)) + for out_component, x_component, y_component in zip( + output.unbind(), x.unbind(), y.unbind() + ): + self.assertEqual( + out_component, torch.cat([x_component, y_component], dim=1) + ) # error case: mixed NT / dense inputs x = random_nt_from_dims([5, None, 2]) y = torch.randn(5, 3, 2) with self.assertRaisesRegex( - RuntimeError, "expected each tensor in given list to be nested"): + RuntimeError, "expected each tensor in given list to be nested" + ): torch.cat([x, y], dim=-1) # error case: NTs with different dims x = random_nt_from_dims([5, None, 2]) y = random_nt_from_dims([5, None, 2, 3]) with self.assertRaisesRegex( - RuntimeError, "expected all nested tensors to have matching ragged structures outside of the concatenated dim"): + RuntimeError, + "expected all nested tensors to have matching ragged structures outside of the concatenated dim", + ): torch.cat([x, y], dim=-1) # error case: non-contiguous NT @@ -682,43 +756,56 @@ def test_cat(self): # transpose to put ragged dim next to batch dim x, y = x.transpose(-2, -1), y.transpose(-2, -1) with self.assertRaisesRegex( - RuntimeError, "only contiguous nested tensors are supported"): + RuntimeError, "only contiguous nested tensors are supported" + ): torch.cat([x, y], dim=-1) # error case: multiple ragged dims in inputs x = random_nt_from_dims([5, None, None, 2]) y = random_nt_from_similar(x) with self.assertRaisesRegex( - RuntimeError, "only nested tensors with a single ragged dim next to the batch dim are supported"): + RuntimeError, + "only nested tensors with a single ragged dim next to the batch dim are supported", + ): torch.cat([x, y], dim=-1) # error case: ragged dim not next to batch dim x = random_nt_from_dims([5, 2, None]) y = random_nt_from_similar(x) with self.assertRaisesRegex( - RuntimeError, "only nested tensors with a single ragged dim next to the batch dim are supported"): + RuntimeError, + "only nested tensors with a single ragged dim next to the batch dim are supported", + ): torch.cat([x, y], dim=1) # error case: NTs with different batch sizes x = random_nt_from_dims([5, None, 2]) y = random_nt_from_dims([3, None, 2]) with self.assertRaisesRegex( - RuntimeError, "expected all nested tensors to have matching ragged structures outside of the concatenated dim"): + RuntimeError, + "expected all nested tensors to have matching ragged structures outside of the concatenated dim", + ): torch.cat([x, y], dim=-1) # error case: NTs with different ragged structures - x = torch.nested.nested_tensor([ - torch.randn(2, 6), - torch.randn(4, 6), - torch.randn(5, 6), - ]) - y = torch.nested.nested_tensor([ - torch.randn(5, 6), - torch.randn(4, 6), - torch.randn(2, 6), - ]) + x = torch.nested.nested_tensor( + [ + torch.randn(2, 6), + torch.randn(4, 6), + torch.randn(5, 6), + ] + ) + y = torch.nested.nested_tensor( + [ + torch.randn(5, 6), + torch.randn(4, 6), + torch.randn(2, 6), + ] + ) with self.assertRaisesRegex( - RuntimeError, "expected all nested tensors to have matching ragged structures outside of the concatenated dim"): + RuntimeError, + "expected all nested tensors to have matching ragged structures outside of the concatenated dim", + ): torch.cat([x, y], dim=-1) @@ -730,13 +817,20 @@ def random_nt_pair(self, device, dtype, num_tensors, max_dims): ts1 = [] ts2 = [] for _ in range(num_tensors): - tensor_dims = tuple([torch.randint(low=0, high=max_dim, size=(1,)).item() for max_dim in max_dims]) + tensor_dims = tuple( + [ + torch.randint(low=0, high=max_dim, size=(1,)).item() + for max_dim in max_dims + ] + ) t1 = torch.randn(tensor_dims, device=device, dtype=dtype) t2 = torch.randn(tensor_dims, device=device, dtype=dtype) ts1.append(t1) ts2.append(t2) - return (torch.nested.nested_tensor(ts1, device=device, dtype=dtype), - torch.nested.nested_tensor(ts2, device=device, dtype=dtype)) + return ( + torch.nested.nested_tensor(ts1, device=device, dtype=dtype), + torch.nested.nested_tensor(ts2, device=device, dtype=dtype), + ) @dtypes(*floating_types_and_half()) def test_detach(self, device, dtype): @@ -768,7 +862,9 @@ def test_detach(self, device, dtype): @dtypes(torch.float, torch.float16, torch.double) def test_unbind_noncontiguous(self, device, dtype): - nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7), device, dtype) + nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair( + (2, 3, 6, 7), device, dtype + ) ub_contiguous = nt_contiguous.unbind() ub_noncontiguous = nt_noncontiguous.unbind() self.assertEqual(len(ub_contiguous), len(ub_noncontiguous)) @@ -787,7 +883,7 @@ def test_to_then_from_padded_tensor_no_transform0213(self, device, dtype): nt_to = torch._nested_from_padded_and_nested_example(padded, nt) - for (t1, t2) in zip(nt.unbind(), nt_to.unbind()): + for t1, t2 in zip(nt.unbind(), nt_to.unbind()): self.assertEqual(t1, t2) self.assertEqual(nt.device, nt_to.device) @@ -804,7 +900,7 @@ def _test(size): nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype) layer_norm = torch.nn.LayerNorm(size, device=device, dtype=dtype) nt_result = layer_norm(nt) - for (nt_subresult, t) in zip(nt_result.unbind(), ts): + for nt_subresult, t in zip(nt_result.unbind(), ts): t_result = layer_norm(t.reshape(1, -1, size).squeeze(0)) self.assertEqual(nt_subresult, t_result) @@ -816,28 +912,36 @@ def _test(size): nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype) layer_norm = torch.nn.LayerNorm(size, device=device, dtype=dtype) nt_result = layer_norm(nt) - for (nt_subresult, t) in zip(nt_result.unbind(), ts): + for nt_subresult, t in zip(nt_result.unbind(), ts): t_result = layer_norm(t.reshape(1, -1, size).squeeze(0)) self.assertEqual(nt_subresult, t_result) if size <= 128: # Test with multidimensional tensors after irregular dim # (run only with smaller dimensions to ensure fast execution) - t0 = torch.randn(4, size, size, 4, device=device, dtype=dtype, requires_grad=False) - t1 = torch.randn(10, size, size, 4, device=device, dtype=dtype, requires_grad=False) - t2 = torch.randn(7, size, size, 4, device=device, dtype=dtype, requires_grad=False) + t0 = torch.randn( + 4, size, size, 4, device=device, dtype=dtype, requires_grad=False + ) + t1 = torch.randn( + 10, size, size, 4, device=device, dtype=dtype, requires_grad=False + ) + t2 = torch.randn( + 7, size, size, 4, device=device, dtype=dtype, requires_grad=False + ) ts = [t0, t1, t2, t0, t2] nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype) - layer_norm = torch.nn.LayerNorm((size, size, 4), device=device, dtype=dtype) + layer_norm = torch.nn.LayerNorm( + (size, size, 4), device=device, dtype=dtype + ) nt_result = layer_norm(nt) - for (nt_subresult, t) in zip(nt_result.unbind(), ts): + for nt_subresult, t in zip(nt_result.unbind(), ts): t_result = layer_norm(t.reshape(1, -1, size, size, 4).squeeze(0)) self.assertEqual(nt_subresult, t_result) # Test where the normalizing dimensions are not all layer_norm = torch.nn.LayerNorm((size, 4), device=device, dtype=dtype) nt_result = layer_norm(nt) - for (nt_subresult, t) in zip(nt_result.unbind(), ts): + for nt_subresult, t in zip(nt_result.unbind(), ts): t_result = layer_norm(t.reshape(1, -1, size, size, 4).squeeze(0)) self.assertEqual(nt_subresult, t_result) @@ -850,9 +954,15 @@ def _test(size): @torch.inference_mode() def test_layer_norm_breaking(self, device, dtype): size = 128 - t0 = torch.randn(4, size, size, 4, device=device, dtype=dtype, requires_grad=False) - t1 = torch.randn(10, size, size, 4, device=device, dtype=dtype, requires_grad=False) - t2 = torch.randn(7, size, size, 4, device=device, dtype=dtype, requires_grad=False) + t0 = torch.randn( + 4, size, size, 4, device=device, dtype=dtype, requires_grad=False + ) + t1 = torch.randn( + 10, size, size, 4, device=device, dtype=dtype, requires_grad=False + ) + t2 = torch.randn( + 7, size, size, 4, device=device, dtype=dtype, requires_grad=False + ) ts = [t0, t1, t2, t0, t2] nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype) layer_norm = torch.nn.LayerNorm((4, size, size, 4), device=device, dtype=dtype) @@ -871,7 +981,7 @@ def test_layer_norm_breaking(self, device, dtype): @decorateIf( xfailIfTorchDynamo, # only fails in python 3.11. TODO: Ensure this is fixed once views work! - lambda params: params["layout"] == torch.jagged and sys.version_info >= (3, 11) + lambda params: params["layout"] == torch.jagged and sys.version_info >= (3, 11), ) @parametrize("layout", [torch.strided, torch.jagged], name_fn=layout_name) def test_embedding(self, device, layout): @@ -879,14 +989,15 @@ def test_embedding(self, device, layout): torch.randint(100, (L,), device=device, dtype=torch.int64) for L in torch.randint(5, 50, (8,)) ] - x = torch.nested.nested_tensor(inputs, device=device, dtype=torch.int64, layout=layout) + x = torch.nested.nested_tensor( + inputs, device=device, dtype=torch.int64, layout=layout + ) emb = torch.nn.Embedding(100, 8, device=device) y = emb(x) ys = y.unbind() for i, inp in enumerate(inputs): self.assertEqual(emb(inp), ys[i]) - @skipMeta @torch.inference_mode() @dtypes(*floating_types_and_half()) @@ -894,11 +1005,12 @@ def test_masked_fill(self, device, dtype): # nested tensor * nested tensor (nt, mask) = self.random_nt_pair(device, dtype, 4, (4, 4)) mask = torch.nested.nested_tensor([m < 0 for m in mask.unbind()]) - ref = torch.nested.nested_tensor([t.masked_fill(m, 0) for (t, m) in zip(nt.unbind(), mask.unbind())]) + ref = torch.nested.nested_tensor( + [t.masked_fill(m, 0) for (t, m) in zip(nt.unbind(), mask.unbind())] + ) out = nt.masked_fill(mask, 0) self.assertEqual(ref, out) - @dtypes(torch.float, torch.float16) def test_to_padded_tensor_simple(self, device, dtype): t = torch.randn(4, 4, 4, device=device, dtype=dtype) @@ -926,8 +1038,12 @@ def test_to_padded_tensor_output_size(self, device, dtype): ts[0] = ts[0][:-1] nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype) for padding_value in (0, 1): - padded = torch.nested.to_padded_tensor(nt, padding_value, output_size=output_size) - correct_output = torch.ones(output_size, device=device, dtype=dtype) * padding_value + padded = torch.nested.to_padded_tensor( + nt, padding_value, output_size=output_size + ) + correct_output = ( + torch.ones(output_size, device=device, dtype=dtype) * padding_value + ) correct_output[:4:, :4, :4] = t.clone() if padding_value == 0: correct_output[0][3] = torch.zeros_like(correct_output[0][3]) @@ -951,7 +1067,7 @@ def test_to_padded_tensor_dim2(self, device, dtype): for t in ts: next_output = torch.ones_like(ts[2]) * pad correct_output.append(next_output) - next_output[:t.size(0)].copy_(t) + next_output[: t.size(0)].copy_(t) correct_output = torch.stack(correct_output) padded = torch.nested.to_padded_tensor(nt, pad) self.assertEqual(padded, correct_output) @@ -969,7 +1085,7 @@ def test_to_padded_tensor_dim3(self, device, dtype): for t in ts: next_output = torch.ones_like(ts[2]) * pad correct_output.append(next_output) - next_output[:t.size(0), :t.size(1)].copy_(t) + next_output[: t.size(0), : t.size(1)].copy_(t) correct_output = torch.stack(correct_output) padded = torch.nested.to_padded_tensor(nt, pad) self.assertEqual(padded, correct_output) @@ -987,7 +1103,7 @@ def test_to_padded_tensor_dim4(self, device, dtype): for t in ts: next_output = torch.ones_like(ts[2]) * pad correct_output.append(next_output) - next_output[:t.size(0), :t.size(1), :t.size(2)].copy_(t) + next_output[: t.size(0), : t.size(1), : t.size(2)].copy_(t) correct_output = torch.stack(correct_output) padded = torch.nested.to_padded_tensor(nt, pad) self.assertEqual(padded, correct_output) @@ -999,22 +1115,25 @@ def test_to_padded_tensor_dim4(self, device, dtype): @dtypes(torch.float, torch.float16, torch.double) @torch.inference_mode() def test_to_padded_tensor_noncontiguous(self, device, dtype): - nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7), device, dtype) + nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair( + (2, 3, 6, 7), device, dtype + ) # test noncontiguous_to_padded_tensor functionality self.assertEqual( torch.nested.to_padded_tensor(nt_contiguous, 0.0), - noncontiguous_to_padded_tensor(nt_noncontiguous)) + noncontiguous_to_padded_tensor(nt_noncontiguous), + ) # test to_padded_tensor error message self.assertRaisesRegex( RuntimeError, r"for now to_padded_tensor only supports contiguous nested tensor", - lambda: torch.nested.to_padded_tensor(nt_noncontiguous, 0.0) + lambda: torch.nested.to_padded_tensor(nt_noncontiguous, 0.0), ) @skipMeta def test_device_checks(self, device): nt = torch.nested.nested_tensor([], device=device) - is_cuda = 'cuda' in str(device) + is_cuda = "cuda" in str(device) self.assertEqual(nt.is_cuda, is_cuda) @dtypes(torch.float, torch.float16, torch.double) @@ -1062,26 +1181,35 @@ def test_nested_tensor_indexing(self, device, dtype): self.assertEqual(nt[-1], x1) grad_x0 = torch.randn((2, 5), device=device, dtype=dtype) nt[0].backward(grad_x0) - expected_grad = torch.nested.nested_tensor([grad_x0, torch.zeros((3, 4), device=device, dtype=dtype)]) + expected_grad = torch.nested.nested_tensor( + [grad_x0, torch.zeros((3, 4), device=device, dtype=dtype)] + ) self.assertEqual(nt.grad, expected_grad) - @parametrize("func", [subtest(torch.nn.functional.relu, name='relu'), - subtest(torch.nn.functional.relu_, name='relu_'), - subtest(torch.nn.functional.gelu, name='gelu'), - subtest(torch._C._nn.gelu_, name='gelu_'), - subtest(torch.tanh, name='tanh'), - subtest(torch.tanh_, name='tanh_'), - subtest(torch.neg, name='neg'), - subtest(torch.nn.functional.silu, name='silu'), - subtest(partial(torch.nn.functional.silu, inplace=True), name='silu_'), - subtest(torch.abs, name="abs"), - subtest(torch.abs_, name="abs_"), - subtest(torch.sgn, name="sgn"), - subtest(torch.logical_not, name='logical_not'), - subtest(torch.sin, name='sin'), - subtest(torch.cos, name='cos')]) + @parametrize( + "func", + [ + subtest(torch.nn.functional.relu, name="relu"), + subtest(torch.nn.functional.relu_, name="relu_"), + subtest(torch.nn.functional.gelu, name="gelu"), + subtest(torch._C._nn.gelu_, name="gelu_"), + subtest(torch.tanh, name="tanh"), + subtest(torch.tanh_, name="tanh_"), + subtest(torch.neg, name="neg"), + subtest(torch.nn.functional.silu, name="silu"), + subtest(partial(torch.nn.functional.silu, inplace=True), name="silu_"), + subtest(torch.abs, name="abs"), + subtest(torch.abs_, name="abs_"), + subtest(torch.sgn, name="sgn"), + subtest(torch.logical_not, name="logical_not"), + subtest(torch.sin, name="sin"), + subtest(torch.cos, name="cos"), + ], + ) def test_activations(self, device, func): - nt, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7), device=device, dtype=torch.float32) + nt, nt_noncontiguous = random_nt_noncontiguous_pair( + (2, 3, 6, 7), device=device, dtype=torch.float32 + ) nested_result = func(nt) self.assertTrue(nested_result.is_nested) for t, t_res in zip(nt.unbind(), nested_result.unbind()): @@ -1089,13 +1217,14 @@ def test_activations(self, device, func): self.assertRaisesRegex( RuntimeError, "NestedTensor must be contiguous to get buffer.", - lambda: func(nt_noncontiguous)) + lambda: func(nt_noncontiguous), + ) - @parametrize("func", [subtest(torch.ge, name='ge'), - subtest(torch.eq, name='eq')]) + @parametrize("func", [subtest(torch.ge, name="ge"), subtest(torch.eq, name="eq")]) def test_binary_ops_with_scalar(self, device, func): nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair( - (2, 3, 6, 7), device=device, dtype=torch.float32) + (2, 3, 6, 7), device=device, dtype=torch.float32 + ) scalar = 0.0 # should work regardless of contiguity @@ -1131,30 +1260,42 @@ def test_nested_tensor_chunk(self, device, dtype): # Failure chunking on ragged dimensions self.assertRaisesRegex( - RuntimeError, "Chunk for nested tensors is currently only supported for the last dimension.", - lambda: torch.chunk(nt, 5, dim=1)) + RuntimeError, + "Chunk for nested tensors is currently only supported for the last dimension.", + lambda: torch.chunk(nt, 5, dim=1), + ) self.assertRaisesRegex( - RuntimeError, "Chunk for nested tensors is currently only supported for the last dimension.", - lambda: torch.chunk(nt, 5, dim=0)) + RuntimeError, + "Chunk for nested tensors is currently only supported for the last dimension.", + lambda: torch.chunk(nt, 5, dim=0), + ) # Failure on non-contiguous nt _, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype) self.assertRaisesRegex( - RuntimeError, "chunk expects `self` to be contiguous.", lambda: torch.chunk(nt_noncontiguous, 5, dim=-1)) + RuntimeError, + "chunk expects `self` to be contiguous.", + lambda: torch.chunk(nt_noncontiguous, 5, dim=-1), + ) # Failure when calling non divisible n_chunks self.assertRaisesRegex( - RuntimeError, "Chunk for nested tensors is only supported for " + RuntimeError, + "Chunk for nested tensors is only supported for " "nested tensors with trailing dimension divisible by chunks.", - lambda: torch.chunk(nt, 5, dim=-1)) + lambda: torch.chunk(nt, 5, dim=-1), + ) # Failure when calling backward on a chunk a = torch.randn(3, 3 * 4, device=device, dtype=dtype, requires_grad=True) b = torch.randn(2, 3 * 4, device=device, dtype=dtype, requires_grad=True) nt_grad = torch.nested.as_nested_tensor([a, b]) chunked = torch.chunk(nt_grad, 2, dim=-1) - self.assertRaisesRegex(RuntimeError, "derivative for aten::chunk is not implemented", - lambda: chunked[0].backward(chunked[0].clone())) + self.assertRaisesRegex( + RuntimeError, + "derivative for aten::chunk is not implemented", + lambda: chunked[0].backward(chunked[0].clone()), + ) @dtypes(*floating_types_and_half()) def test_nested_tensor_split_with_sizes(self, device, dtype): @@ -1171,42 +1312,56 @@ def test_nested_tensor_split_with_sizes(self, device, dtype): nt_splits = nt.split_with_sizes(split_sizes, dim=-1) for i, nt_split in enumerate(nt_splits): - self.assertEqual(nt_split, torch.nested.nested_tensor( - [a_splits[i], b_splits[i], c_splits[i]])) - dense_strides = torch.stack([ - torch.tensor(a_splits[i].stride()), - torch.tensor(b_splits[i].stride()), - torch.tensor(c_splits[i].stride()) - ]) + self.assertEqual( + nt_split, + torch.nested.nested_tensor([a_splits[i], b_splits[i], c_splits[i]]), + ) + dense_strides = torch.stack( + [ + torch.tensor(a_splits[i].stride()), + torch.tensor(b_splits[i].stride()), + torch.tensor(c_splits[i].stride()), + ] + ) self.assertEqual(nt_split._nested_tensor_strides(), dense_strides) self.assertFalse(nt_split.is_contiguous()) # Failure calling on ragged dimensions self.assertRaisesRegex( - RuntimeError, "split_with_sizes for nested tensors is currently only supported for the last dimension.", - lambda: torch.split_with_sizes(nt, split_sizes, dim=1)) + RuntimeError, + "split_with_sizes for nested tensors is currently only supported for the last dimension.", + lambda: torch.split_with_sizes(nt, split_sizes, dim=1), + ) # Failure calling on non-last dimension self.assertRaisesRegex( - RuntimeError, "split_with_sizes for nested tensors is currently only supported for the last dimension.", - lambda: torch.split_with_sizes(nt, split_sizes, dim=0)) + RuntimeError, + "split_with_sizes for nested tensors is currently only supported for the last dimension.", + lambda: torch.split_with_sizes(nt, split_sizes, dim=0), + ) # Failure on non-contiguous nt _, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype) self.assertRaisesRegex( - RuntimeError, "split_with_sizes expects `self` to be contiguous.", - lambda: torch.split_with_sizes(nt_noncontiguous, split_sizes, dim=-1)) + RuntimeError, + "split_with_sizes expects `self` to be contiguous.", + lambda: torch.split_with_sizes(nt_noncontiguous, split_sizes, dim=-1), + ) # Failure when calling with split_sizes that don't cover the full dim size bad_split_sizes = [4, 6, 9] # don't add up to 20 self.assertRaisesRegex( - RuntimeError, "split_with_sizes expects split_sizes to sum exactly to 20", - lambda: torch.split_with_sizes(nt, bad_split_sizes, dim=-1)) + RuntimeError, + "split_with_sizes expects split_sizes to sum exactly to 20", + lambda: torch.split_with_sizes(nt, bad_split_sizes, dim=-1), + ) @dtypes(torch.float, torch.float16, torch.double) @torch.inference_mode() def test_nested_tensor_indexing_noncontiguous(self, device, dtype): - nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7), device, dtype) + nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair( + (2, 3, 6, 7), device, dtype + ) self.assertEqual(nt_contiguous.size(0), nt_noncontiguous.size(0)) n = nt_contiguous.size(0) for i in range(n): @@ -1226,7 +1381,9 @@ def test_nested_tensor_add(self, device, dtype, transpose): nt2 = torch.nested.nested_tensor([c, d, c, d]).transpose(-1, -2) else: (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4)) - ref = torch.nested.nested_tensor([t1 + t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())]) + ref = torch.nested.nested_tensor( + [t1 + t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())] + ) out = nt1 + nt2 self.assertEqual(ref, out) @@ -1244,7 +1401,9 @@ def test_nested_tensor_sub(self, device, dtype, transpose): nt2 = torch.nested.nested_tensor([c, d, c, d]).transpose(-1, -2) else: (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4)) - ref = torch.nested.nested_tensor([t1 - t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())]) + ref = torch.nested.nested_tensor( + [t1 - t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())] + ) out = nt1 - nt2 self.assertEqual(ref, out) @@ -1255,9 +1414,11 @@ def test_nested_tensor_sub(self, device, dtype, transpose): def test_nested_tensor_dense_elementwise(self, device, dtype, embedding_dim): def _test_add_mul(nt, t): ref_add = torch.nested.nested_tensor( - [t1 + t2 for (t1, t2) in zip(nt.unbind(), t.unbind())]) + [t1 + t2 for (t1, t2) in zip(nt.unbind(), t.unbind())] + ) ref_mul = torch.nested.nested_tensor( - [t1 * t2 for (t1, t2) in zip(nt.unbind(), t.unbind())]) + [t1 * t2 for (t1, t2) in zip(nt.unbind(), t.unbind())] + ) self.assertEqual(nt.add(t), ref_add) self.assertEqual(nt.mul(t), ref_mul) @@ -1282,7 +1443,9 @@ def _test_add_mul(nt, t): def test_nested_tensor_mul(self, device, dtype): # nested tensor * nested tensor (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4)) - ref = torch.nested.nested_tensor([t1 * t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())]) + ref = torch.nested.nested_tensor( + [t1 * t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())] + ) out = nt1 * nt2 self.assertEqual(ref, out) # nested tensor * scalar @@ -1302,12 +1465,12 @@ def test_nested_tensor_mul(self, device, dtype): self.assertRaisesRegex( RuntimeError, "Expected both self and other to be nested, but got a nested self and non-nested other", - lambda: nt1.mul(vector) + lambda: nt1.mul(vector), ) self.assertRaisesRegex( RuntimeError, "Expected both self and other to be nested, but got a non-nested self and nested other", - lambda: vector.mul(nt1) + lambda: vector.mul(nt1), ) @dtypes(torch.float, torch.float16) @@ -1323,31 +1486,43 @@ def test_nested_tensor_div(self, device, dtype): out = nt.transpose(1, 2) / 4.0 self.assertEqual(ref_transposed, out) - ref = torch.nested.nested_tensor([t / t2 for (t, t2) in zip(nt.unbind(), nt2.unbind())]) + ref = torch.nested.nested_tensor( + [t / t2 for (t, t2) in zip(nt.unbind(), nt2.unbind())] + ) out = nt / nt2 self.assertEqual(ref, out) out = nt.transpose(1, 2) / nt2.transpose(1, 2) self.assertEqual(ref.transpose(1, 2), out) - nt_transpose_copy = torch.nested.nested_tensor([t.transpose(0, 1) for t in nt.unbind()]) + nt_transpose_copy = torch.nested.nested_tensor( + [t.transpose(0, 1) for t in nt.unbind()] + ) self.assertRaisesRegex( - RuntimeError, "div requires strides to match when given NestedTensors", - lambda: nt_transpose_copy.transpose(1, 2) / nt2) + RuntimeError, + "div requires strides to match when given NestedTensors", + lambda: nt_transpose_copy.transpose(1, 2) / nt2, + ) - nt = torch.nested.nested_tensor([torch.randn(i, 4) for i in [3, 4, 5]], device=device, dtype=dtype) + nt = torch.nested.nested_tensor( + [torch.randn(i, 4) for i in [3, 4, 5]], device=device, dtype=dtype + ) nt_chunks = nt.chunk(2, -1) self.assertRaisesRegex( - RuntimeError, "div requires offsets to match when given NestedTensors", - lambda: nt_chunks[0] / nt_chunks[1]) + RuntimeError, + "div requires offsets to match when given NestedTensors", + lambda: nt_chunks[0] / nt_chunks[1], + ) @dtypes(torch.float, torch.float16) @skipMeta @torch.inference_mode() def test_nested_tensor_add_in_place(self, device, dtype): (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4)) - ref = torch.nested.nested_tensor([t1 + t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())]) + ref = torch.nested.nested_tensor( + [t1 + t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())] + ) nt1 += nt2 self.assertEqual(ref, nt1) @@ -1357,7 +1532,9 @@ def test_nested_tensor_add_in_place(self, device, dtype): def test_nested_tensor_mul_in_place(self, device, dtype): # nested tensor * nested tensor (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4)) - ref = torch.nested.nested_tensor([t1 * t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())]) + ref = torch.nested.nested_tensor( + [t1 * t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())] + ) nt1 *= nt2 self.assertEqual(ref, nt1) # nested tensor * scalar @@ -1373,19 +1550,19 @@ def test_nested_tensor_mul_in_place(self, device, dtype): self.assertRaisesRegex( RuntimeError, r"output with shape \[.*\] doesn't match the broadcast shape \[.*\]", - lambda: scalar.mul_(nt1) + lambda: scalar.mul_(nt1), ) # error case: numel == 1 but dim > 0 vector = torch.tensor([number]).to(dtype).to(device) self.assertRaisesRegex( RuntimeError, "Expected both self and other to be nested, but got a nested self and non-nested other", - lambda: nt1.mul_(vector) + lambda: nt1.mul_(vector), ) self.assertRaisesRegex( RuntimeError, "Expected both self and other to be nested, but got a non-nested self and nested other", - lambda: vector.mul_(nt1) + lambda: vector.mul_(nt1), ) @onlyCPU @@ -1421,14 +1598,26 @@ def test_sum(device, dtype, ntensors, max_sizes, dim, keepdim=True): test_sum(device, dtype, ntensors, max_sizes, len(max_sizes)) # Test error inputs - with self.assertRaisesRegex(RuntimeError, "NestedTensor can only be reduced across the last"): - torch.nested.nested_tensor([torch.tensor([3, 4, 5]), torch.tensor([1, 2])]).sum(0, keepdim=True) + with self.assertRaisesRegex( + RuntimeError, "NestedTensor can only be reduced across the last" + ): + torch.nested.nested_tensor( + [torch.tensor([3, 4, 5]), torch.tensor([1, 2])] + ).sum(0, keepdim=True) - with self.assertRaisesRegex(RuntimeError, "NestedTensor only allows reduction of a single"): - torch.nested.nested_tensor([torch.tensor([[3, 4, 5]]), torch.tensor([[1, 2]])]).sum([0, 1], keepdim=True) + with self.assertRaisesRegex( + RuntimeError, "NestedTensor only allows reduction of a single" + ): + torch.nested.nested_tensor( + [torch.tensor([[3, 4, 5]]), torch.tensor([[1, 2]])] + ).sum([0, 1], keepdim=True) - with self.assertRaisesRegex(RuntimeError, "NestedTensor always requires keepdim=True for now."): - torch.nested.nested_tensor([torch.tensor([3, 4, 5]), torch.tensor([1, 2])]).sum(-1) + with self.assertRaisesRegex( + RuntimeError, "NestedTensor always requires keepdim=True for now." + ): + torch.nested.nested_tensor( + [torch.tensor([3, 4, 5]), torch.tensor([1, 2])] + ).sum(-1) @dtypes(torch.float, torch.float16) def test_contiguous(self, device, dtype): @@ -1438,8 +1627,12 @@ def test_contiguous(self, device, dtype): # whose numels is now less than the size of the buffer. Clone was # previously creating a new NT with a buffer that was the same size as the # original. - nt_contiguous = torch.nested.nested_tensor([torch.randn(2, 20, device=device, dtype=dtype), - torch.randn(4, 20, device=device, dtype=dtype)]) + nt_contiguous = torch.nested.nested_tensor( + [ + torch.randn(2, 20, device=device, dtype=dtype), + torch.randn(4, 20, device=device, dtype=dtype), + ] + ) # Split up the last dimension which has a consistent size of 20 into 5 chunks chunks = nt_contiguous.chunk(5, dim=-1) @@ -1551,12 +1744,12 @@ def test_softmax(self, device, dtype): self.assertRaisesRegex( RuntimeError, "Cannot apply softmax across nested dimension 0", - lambda: torch.nn.functional.softmax(nt, 0) + lambda: torch.nn.functional.softmax(nt, 0), ) self.assertRaisesRegex( RuntimeError, "Cannot apply softmax across nested dimension 0", - lambda: torch.nn.functional.softmax(nt, -3) + lambda: torch.nn.functional.softmax(nt, -3), ) # error case: dimension out of range self.assertRaises(IndexError, lambda: torch.nn.functional.softmax(nt, 3)) @@ -1583,91 +1776,95 @@ def test_softmax(self, device, dtype): @dtypes(torch.float, torch.double) @torch.inference_mode() def test_softmax_noncontiguous(self, device, dtype): - nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7), device, dtype) + nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair( + (2, 3, 6, 7), device, dtype + ) self.assertEqual( torch.nn.functional.softmax(nt_contiguous, -1), - torch.nn.functional.softmax(nt_noncontiguous, -1)) + torch.nn.functional.softmax(nt_noncontiguous, -1), + ) def _test_bmm(self, device, dtype): # error case: not 3D tensors nt0 = torch.nested.nested_tensor([], device=device, dtype=dtype) - nt1 = torch.nested.nested_tensor([torch.randn(2), torch.randn(3)], device=device, dtype=dtype) - nt2 = torch.nested.nested_tensor([torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype) + nt1 = torch.nested.nested_tensor( + [torch.randn(2), torch.randn(3)], device=device, dtype=dtype + ) + nt2 = torch.nested.nested_tensor( + [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype + ) self.assertRaisesRegex( - RuntimeError, - "batch1 must be a 3D tensor", - lambda: nt0.bmm(nt0) + RuntimeError, "batch1 must be a 3D tensor", lambda: nt0.bmm(nt0) ) self.assertRaisesRegex( - RuntimeError, - "batch1 must be a 3D tensor", - lambda: nt0.bmm(nt1) + RuntimeError, "batch1 must be a 3D tensor", lambda: nt0.bmm(nt1) ) self.assertRaisesRegex( - RuntimeError, - "batch1 must be a 3D tensor", - lambda: nt0.bmm(nt2) + RuntimeError, "batch1 must be a 3D tensor", lambda: nt0.bmm(nt2) ) self.assertRaisesRegex( - RuntimeError, - "batch1 must be a 3D tensor", - lambda: nt1.bmm(nt0) + RuntimeError, "batch1 must be a 3D tensor", lambda: nt1.bmm(nt0) ) self.assertRaisesRegex( - RuntimeError, - "batch1 must be a 3D tensor", - lambda: nt1.bmm(nt1) + RuntimeError, "batch1 must be a 3D tensor", lambda: nt1.bmm(nt1) ) self.assertRaisesRegex( - RuntimeError, - "batch1 must be a 3D tensor", - lambda: nt1.bmm(nt2) + RuntimeError, "batch1 must be a 3D tensor", lambda: nt1.bmm(nt2) ) self.assertRaisesRegex( - RuntimeError, - "batch2 must be a 3D tensor", - lambda: nt2.bmm(nt0) + RuntimeError, "batch2 must be a 3D tensor", lambda: nt2.bmm(nt0) ) self.assertRaisesRegex( - RuntimeError, - "batch2 must be a 3D tensor", - lambda: nt2.bmm(nt1) + RuntimeError, "batch2 must be a 3D tensor", lambda: nt2.bmm(nt1) ) # error case: incompatible batch size - nt0 = torch.nested.nested_tensor([torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype) - nt1 = torch.nested.nested_tensor([torch.randn((4, 6)), - torch.randn((4, 5)), - torch.randn((4, 7))], - device=device, dtype=dtype) + nt0 = torch.nested.nested_tensor( + [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype + ) + nt1 = torch.nested.nested_tensor( + [torch.randn((4, 6)), torch.randn((4, 5)), torch.randn((4, 7))], + device=device, + dtype=dtype, + ) self.assertRaisesRegex( RuntimeError, "Expected size for the 1st dimension of batch2 tensor to be: 2 but got: 3.", - lambda: nt0.bmm(nt1) + lambda: nt0.bmm(nt1), ) self.assertRaisesRegex( RuntimeError, "Expected size for the 1st dimension of batch2 tensor to be: 3 but got: 2.", - lambda: nt1.bmm(nt0) + lambda: nt1.bmm(nt0), ) # error case: underlying matrices cannot be multiplied - nt0 = torch.nested.nested_tensor([torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype) + nt0 = torch.nested.nested_tensor( + [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype + ) self.assertRaisesRegex( RuntimeError, r"0-th nested matrices in batch cannot be multiplied \(2x4 and 2x4\)", - lambda: nt0.bmm(nt0) + lambda: nt0.bmm(nt0), ) # normal nested tensor - nt0 = torch.nested.nested_tensor([torch.randn((2, 4)), torch.randn((3, 7))], device=device, dtype=dtype) - nt1 = torch.nested.nested_tensor([torch.randn((4, 6)), torch.randn((7, 5))], device=device, dtype=dtype) + nt0 = torch.nested.nested_tensor( + [torch.randn((2, 4)), torch.randn((3, 7))], device=device, dtype=dtype + ) + nt1 = torch.nested.nested_tensor( + [torch.randn((4, 6)), torch.randn((7, 5))], device=device, dtype=dtype + ) actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0) - expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm(torch.nested.to_padded_tensor(nt1, 0.0)) + expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm( + torch.nested.to_padded_tensor(nt1, 0.0) + ) if dtype == torch.float16: self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3) else: self.assertEqual(actual, expect) # nested tensor bmm normal tensor - nt0 = torch.nested.nested_tensor([torch.randn((2, 7)), torch.randn((3, 7))], device=device, dtype=dtype) + nt0 = torch.nested.nested_tensor( + [torch.randn((2, 7)), torch.randn((3, 7))], device=device, dtype=dtype + ) nt1 = torch.rand(2, 7, 5, dtype=dtype, device=device) actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0) expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm(nt1) @@ -1686,10 +1883,11 @@ def _test_bmm(self, device, dtype): else: self.assertEqual(actual, expect) - # normal tensor bmm nested tensor nt0 = torch.rand(2, 5, 7, dtype=dtype, device=device) - nt1 = torch.nested.nested_tensor([torch.randn((7, 6)), torch.randn((7, 5))], device=device, dtype=dtype) + nt1 = torch.nested.nested_tensor( + [torch.randn((7, 6)), torch.randn((7, 5))], device=device, dtype=dtype + ) actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0) expect = nt0.bmm(torch.nested.to_padded_tensor(nt1, 0.0)) if dtype == torch.float16: @@ -1698,10 +1896,16 @@ def _test_bmm(self, device, dtype): self.assertEqual(actual, expect) # test tensorcore path - nt0 = torch.nested.nested_tensor([torch.randn((2, 8)), torch.randn((3, 16))], device=device, dtype=dtype) - nt1 = torch.nested.nested_tensor([torch.randn((8, 8)), torch.randn((16, 8))], device=device, dtype=dtype) + nt0 = torch.nested.nested_tensor( + [torch.randn((2, 8)), torch.randn((3, 16))], device=device, dtype=dtype + ) + nt1 = torch.nested.nested_tensor( + [torch.randn((8, 8)), torch.randn((16, 8))], device=device, dtype=dtype + ) actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0) - expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm(torch.nested.to_padded_tensor(nt1, 0.0)) + expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm( + torch.nested.to_padded_tensor(nt1, 0.0) + ) if dtype == torch.float16: self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3) else: @@ -1721,11 +1925,16 @@ def test_bmm_cpu(self, device, dtype): # cannot test torch.float16 because: RuntimeError: "addmm_impl_cpu_" not implemented for 'Half' @dtypes(torch.float, torch.double) def test_bmm_noncontiguous(self, device, dtype): - nt0_contiguous, nt0_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype) - nt1_contiguous, nt1_noncontiguous = random_nt_noncontiguous_pair((6, 7), device, dtype) + nt0_contiguous, nt0_noncontiguous = random_nt_noncontiguous_pair( + (2, 3), device, dtype + ) + nt1_contiguous, nt1_noncontiguous = random_nt_noncontiguous_pair( + (6, 7), device, dtype + ) self.assertEqual( nt0_contiguous.transpose(-1, -2).bmm(nt1_contiguous), - nt0_noncontiguous.transpose(-1, -2).bmm(nt1_noncontiguous)) + nt0_noncontiguous.transpose(-1, -2).bmm(nt1_noncontiguous), + ) @dtypes(torch.float, torch.double) def test_matmul_with_bmm_path(self, device, dtype): @@ -1758,142 +1967,176 @@ def unbind_rebind_matmul(nt1, nt2): seq_len = np.random.randint(2, 5) t3s.append(torch.randn(seq_len, n_heads, head_dim)) t4s.append(torch.randn(seq_len, n_heads, head_dim)) - nt3 = torch.nested.nested_tensor(t3s, device=device, dtype=dtype).transpose(1, 2) - nt4 = torch.nested.nested_tensor(t4s, device=device, dtype=dtype).transpose(1, 2).transpose(2, 3) + nt3 = torch.nested.nested_tensor(t3s, device=device, dtype=dtype).transpose( + 1, 2 + ) + nt4 = ( + torch.nested.nested_tensor(t4s, device=device, dtype=dtype) + .transpose(1, 2) + .transpose(2, 3) + ) self.assertEqual(torch.matmul(nt3, nt4), unbind_rebind_matmul(nt3, nt4)) # cannot test torch.float16 because: RuntimeError: "bmm" not implemented for 'Half' @dtypes(torch.float, torch.double) def test_matmul(self, device, dtype): # error case: one is nested but the other is not - nt = torch.nested.nested_tensor([torch.randn(2), torch.randn(3)], device=device, dtype=dtype) + nt = torch.nested.nested_tensor( + [torch.randn(2), torch.randn(3)], device=device, dtype=dtype + ) t = torch.randn(4, device=device, dtype=dtype) self.assertRaisesRegex( RuntimeError, "Expected both to be nested, but got a nested self and non-nested other", - lambda: torch.matmul(nt, t) + lambda: torch.matmul(nt, t), ) self.assertRaisesRegex( RuntimeError, "Expected both to be nested, but got a non-nested self and nested other", - lambda: torch.matmul(t, nt) + lambda: torch.matmul(t, nt), ) # error case: not 3+D tensors nt0 = torch.nested.nested_tensor([], device=device, dtype=dtype) - nt1 = torch.nested.nested_tensor([torch.randn(2), torch.randn(3)], device=device, dtype=dtype) - nt2 = torch.nested.nested_tensor([torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype) + nt1 = torch.nested.nested_tensor( + [torch.randn(2), torch.randn(3)], device=device, dtype=dtype + ) + nt2 = torch.nested.nested_tensor( + [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype + ) self.assertRaisesRegex( RuntimeError, r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+", - lambda: torch.matmul(nt0, nt0) + lambda: torch.matmul(nt0, nt0), ) self.assertRaisesRegex( RuntimeError, r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+", - lambda: torch.matmul(nt0, nt1) + lambda: torch.matmul(nt0, nt1), ) self.assertRaisesRegex( RuntimeError, r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+", - lambda: torch.matmul(nt0, nt2) + lambda: torch.matmul(nt0, nt2), ) self.assertRaisesRegex( RuntimeError, r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+", - lambda: torch.matmul(nt1, nt0) + lambda: torch.matmul(nt1, nt0), ) self.assertRaisesRegex( RuntimeError, r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+", - lambda: torch.matmul(nt1, nt1) + lambda: torch.matmul(nt1, nt1), ) self.assertRaisesRegex( RuntimeError, r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+", - lambda: torch.matmul(nt1, nt2) + lambda: torch.matmul(nt1, nt2), ) self.assertRaisesRegex( RuntimeError, r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 2nd input has rank: [0-9]+", - lambda: torch.matmul(nt2, nt0) + lambda: torch.matmul(nt2, nt0), ) self.assertRaisesRegex( RuntimeError, r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 2nd input has rank: [0-9]+", - lambda: torch.matmul(nt2, nt1) + lambda: torch.matmul(nt2, nt1), ) # error case: incompatible batch size - nt0 = torch.nested.nested_tensor([torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype) - nt1 = torch.nested.nested_tensor([torch.randn((4, 6)), - torch.randn((4, 5)), - torch.randn((4, 7))], - device=device, dtype=dtype) + nt0 = torch.nested.nested_tensor( + [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype + ) + nt1 = torch.nested.nested_tensor( + [torch.randn((4, 6)), torch.randn((4, 5)), torch.randn((4, 7))], + device=device, + dtype=dtype, + ) self.assertRaisesRegex( RuntimeError, r"matmul: Expected size for the 1st dimension of 2nd input tensor to be: [0-9]+ but got: [0-9]+.", - lambda: torch.matmul(nt0, nt1) + lambda: torch.matmul(nt0, nt1), ) self.assertRaisesRegex( RuntimeError, r"matmul: Expected size for the 1st dimension of 2nd input tensor to be: [0-9]+ but got: [0-9]+.", - lambda: torch.matmul(nt1, nt0) + lambda: torch.matmul(nt1, nt0), ) # error case: incompatible (wrong) batch sizes that shouldn't even broadcast? - nt0 = torch.nested.nested_tensor([torch.randn((2, 2, 4)), - torch.randn((2, 3, 4))], - device=device, dtype=dtype) - nt1 = torch.nested.nested_tensor([torch.randn((3, 4, 6)), - torch.randn((3, 4, 5))], - device=device, dtype=dtype) + nt0 = torch.nested.nested_tensor( + [torch.randn((2, 2, 4)), torch.randn((2, 3, 4))], device=device, dtype=dtype + ) + nt1 = torch.nested.nested_tensor( + [torch.randn((3, 4, 6)), torch.randn((3, 4, 5))], device=device, dtype=dtype + ) self.assertRaisesRegex( RuntimeError, "matmul(): For nested tensors, batch dimensions must have the same sizes,", - lambda: torch.matmul(nt0, nt1) + lambda: torch.matmul(nt0, nt1), ) # error case: incompatible batch sizes that should technically broadcast - nt0 = torch.nested.nested_tensor([torch.randn((2, 2, 4)), - torch.randn((1, 3, 4))], - device=device, dtype=dtype) - nt1 = torch.nested.nested_tensor([torch.randn((1, 4, 6)), - torch.randn((3, 4, 5))], - device=device, dtype=dtype) + nt0 = torch.nested.nested_tensor( + [torch.randn((2, 2, 4)), torch.randn((1, 3, 4))], device=device, dtype=dtype + ) + nt1 = torch.nested.nested_tensor( + [torch.randn((1, 4, 6)), torch.randn((3, 4, 5))], device=device, dtype=dtype + ) self.assertRaisesRegex( RuntimeError, "matmul(): For nested tensors, batch dimensions must have the same sizes,", - lambda: torch.matmul(nt0, nt1) + lambda: torch.matmul(nt0, nt1), ) # error case: underlying matrices cannot be multiplied - nt0 = torch.nested.nested_tensor([torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype) + nt0 = torch.nested.nested_tensor( + [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype + ) self.assertRaisesRegex( RuntimeError, "matmul(): Nested tensors cannot be matrix multiplied", - lambda: torch.matmul(nt0, nt0) + lambda: torch.matmul(nt0, nt0), ) # normal nested tensor: 3D - nt0 = torch.nested.nested_tensor([torch.randn((2, 4)), torch.randn((3, 7))], device=device, dtype=dtype) - nt1 = torch.nested.nested_tensor([torch.randn((4, 6)), torch.randn((7, 5))], device=device, dtype=dtype) + nt0 = torch.nested.nested_tensor( + [torch.randn((2, 4)), torch.randn((3, 7))], device=device, dtype=dtype + ) + nt1 = torch.nested.nested_tensor( + [torch.randn((4, 6)), torch.randn((7, 5))], device=device, dtype=dtype + ) actual = torch.nested.to_padded_tensor(torch.matmul(nt0, nt1), 0.0) - expect = torch.matmul(torch.nested.to_padded_tensor(nt0, 0.0), torch.nested.to_padded_tensor(nt1, 0.0)) + expect = torch.matmul( + torch.nested.to_padded_tensor(nt0, 0.0), + torch.nested.to_padded_tensor(nt1, 0.0), + ) self.assertEqual(actual, expect) # normal nested tensor: 4D (with testing for batch_size=1) - nt0 = torch.nested.nested_tensor([torch.randn((1, 2, 4)), - torch.randn((8, 3, 7))], - device=device, dtype=dtype) - nt1 = torch.nested.nested_tensor([torch.randn((1, 4, 6)), - torch.randn((8, 7, 5))], - device=device, dtype=dtype) + nt0 = torch.nested.nested_tensor( + [torch.randn((1, 2, 4)), torch.randn((8, 3, 7))], device=device, dtype=dtype + ) + nt1 = torch.nested.nested_tensor( + [torch.randn((1, 4, 6)), torch.randn((8, 7, 5))], device=device, dtype=dtype + ) actual = torch.nested.to_padded_tensor(torch.matmul(nt0, nt1), 0.0) - expect = torch.matmul(torch.nested.to_padded_tensor(nt0, 0.0), torch.nested.to_padded_tensor(nt1, 0.0)) + expect = torch.matmul( + torch.nested.to_padded_tensor(nt0, 0.0), + torch.nested.to_padded_tensor(nt1, 0.0), + ) self.assertEqual(actual, expect) # normal nested tensor: 5D - nt0 = torch.nested.nested_tensor([torch.randn((8, 9, 2, 4)), - torch.randn((8, 9, 3, 7))], - device=device, dtype=dtype) - nt1 = torch.nested.nested_tensor([torch.randn((8, 9, 4, 6)), - torch.randn((8, 9, 7, 5))], - device=device, dtype=dtype) + nt0 = torch.nested.nested_tensor( + [torch.randn((8, 9, 2, 4)), torch.randn((8, 9, 3, 7))], + device=device, + dtype=dtype, + ) + nt1 = torch.nested.nested_tensor( + [torch.randn((8, 9, 4, 6)), torch.randn((8, 9, 7, 5))], + device=device, + dtype=dtype, + ) actual = torch.nested.to_padded_tensor(torch.matmul(nt0, nt1), 0.0) - expect = torch.matmul(torch.nested.to_padded_tensor(nt0, 0.0), torch.nested.to_padded_tensor(nt1, 0.0)) + expect = torch.matmul( + torch.nested.to_padded_tensor(nt0, 0.0), + torch.nested.to_padded_tensor(nt1, 0.0), + ) self.assertEqual(actual, expect) # only supported on CUDA for now @@ -1912,11 +2155,16 @@ def test_matmul_nt_with_broadcasted_t(self, device, dtype): # cannot test torch.float16 because: RuntimeError: "bmm" not implemented for 'Half' @dtypes(torch.float, torch.double) def test_matmul_noncontiguous(self, device, dtype): - nt0_contiguous, nt0_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype) - nt1_contiguous, nt1_noncontiguous = random_nt_noncontiguous_pair((6, 7), device, dtype) + nt0_contiguous, nt0_noncontiguous = random_nt_noncontiguous_pair( + (2, 3), device, dtype + ) + nt1_contiguous, nt1_noncontiguous = random_nt_noncontiguous_pair( + (6, 7), device, dtype + ) self.assertEqual( torch.matmul(nt0_contiguous.transpose(-1, -2), nt1_contiguous), - torch.matmul(nt0_noncontiguous.transpose(-1, -2), nt1_noncontiguous)) + torch.matmul(nt0_noncontiguous.transpose(-1, -2), nt1_noncontiguous), + ) @dtypes(torch.float, torch.double) def test_linear(self, device, dtype): @@ -1931,29 +2179,39 @@ def test_linear(self, device, dtype): torch.functional.F.linear(nt, weight, bias) # invalid nested tensor dimension - msg = r'Linear requires nested_tensor.dim == 3 and dense_matrix.dim == 2. Nested tensor dim: 2. Dense tensor dim: 2' - nt1 = torch.nested.nested_tensor([torch.randn(1, device=device, dtype=dtype), - torch.randn(2, device=device, dtype=dtype)]) + msg = r"Linear requires nested_tensor.dim == 3 and dense_matrix.dim == 2. Nested tensor dim: 2. Dense tensor dim: 2" + nt1 = torch.nested.nested_tensor( + [ + torch.randn(1, device=device, dtype=dtype), + torch.randn(2, device=device, dtype=dtype), + ] + ) with self.assertRaisesRegex(RuntimeError, msg): torch.functional.F.linear(nt1, weight, bias) # invalid weight shape - msg = r'Linear requires nested_tensor.dim == 3 and dense_matrix.dim == 2. Nested tensor dim: 3. Dense tensor dim: 3' + msg = r"Linear requires nested_tensor.dim == 3 and dense_matrix.dim == 2. Nested tensor dim: 3. Dense tensor dim: 3" weight1 = torch.randn(2, 2, 3, device=device, dtype=dtype) with self.assertRaisesRegex(RuntimeError, msg): torch.functional.F.linear(nt, weight1, bias) # inconsistent last dim of nested tensor msg = r"Expected all tensors in nested tensor to have the same trailing dimension, instead last dimension equals:" - nt2 = torch.nested.nested_tensor([torch.randn(1, 2, device=device, dtype=dtype), - torch.randn(2, 3, device=device, dtype=dtype)]) + nt2 = torch.nested.nested_tensor( + [ + torch.randn(1, 2, device=device, dtype=dtype), + torch.randn(2, 3, device=device, dtype=dtype), + ] + ) with self.assertRaisesRegex(RuntimeError, msg): torch.functional.F.linear(nt2, weight, bias) # Mismatch of nested tensor last dim and weight dimension weight2 = torch.randn(2, 4, device=device, dtype=dtype) - msg = r"Shape mismatch for NestedTensor Linear: Expected input's \(a nested tensor\) 'last_dim'" \ + msg = ( + r"Shape mismatch for NestedTensor Linear: Expected input's \(a nested tensor\) 'last_dim'" r" to equal 'weight.size\(1\), but got: last_dim = 2, and weight.size\(1\) = 4" + ) with self.assertRaisesRegex(RuntimeError, msg): torch.functional.F.linear(nt, weight2, bias) @@ -1968,22 +2226,26 @@ def test_linear(self, device, dtype): # since linear does not support noncontiguous buffer yet @dtypes(torch.float, torch.double) def test_linear_noncontiguous(self, device, dtype): - nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7), device, dtype) + nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair( + (2, 3, 6, 7), device, dtype + ) weight = torch.randn((8, 5), device=device, dtype=dtype) self.assertRaisesRegex( RuntimeError, r"for now linear only supports contiguous nested tensor", - lambda: torch.nn.functional.linear(nt_noncontiguous, weight) + lambda: torch.nn.functional.linear(nt_noncontiguous, weight), ) @dtypes(torch.float, torch.float16, torch.double) def test_to_padded_tensor_zero_numel_errors(self, device, dtype): ts = [torch.ones(1, 0), torch.ones(0, 0)] - nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype, layout=torch.strided) + nt = torch.nested.nested_tensor( + ts, device=device, dtype=dtype, layout=torch.strided + ) self.assertRaisesRegex( RuntimeError, r"at least one constituent tensor should have non-zero numel", - lambda: torch.nested.to_padded_tensor(nt, 0.0) + lambda: torch.nested.to_padded_tensor(nt, 0.0), ) @dtypes(torch.float, torch.float16, torch.double) @@ -1993,12 +2255,12 @@ def test_transpose(self, device, dtype): self.assertRaisesRegex( RuntimeError, "Nested tensor dimension 0 cannot be transposed", - lambda: nt.transpose(0, 1) + lambda: nt.transpose(0, 1), ) self.assertRaisesRegex( RuntimeError, "Nested tensor dimension 0 cannot be transposed", - lambda: nt.transpose(1, -3) + lambda: nt.transpose(1, -3), ) # error case: dimension out of range self.assertRaises(IndexError, lambda: nt.transpose(1, 3)) @@ -2019,13 +2281,13 @@ def test_squeeze_unsqueeze(self, device, dtype): self.assertRaisesRegex( RuntimeError, "For nested tensors, squeeze without the dim argument", - lambda: nt.squeeze() + lambda: nt.squeeze(), ) # error case: squeeze nested dimension self.assertRaisesRegex( RuntimeError, "For nested tensors, squeezing dimension 0", - lambda: nt.squeeze(0) + lambda: nt.squeeze(0), ) # error case: dimension out of range self.assertRaises(IndexError, lambda: nt.squeeze(3)) @@ -2035,7 +2297,7 @@ def test_squeeze_unsqueeze(self, device, dtype): self.assertRaisesRegex( RuntimeError, "For nested tensors, squeezing a nested tensor of singleton", - lambda: nt_singleton.squeeze(1) + lambda: nt_singleton.squeeze(1), ) # squeezing a dim which does not have size 1 should be a no-op @@ -2046,7 +2308,7 @@ def test_squeeze_unsqueeze(self, device, dtype): nt_sizes = nt._nested_tensor_size() nt_strides = nt._nested_tensor_strides() for i in range(-2, 4): - if (i == 0): + if i == 0: # cannot unsqueeze batch dim continue nt_unsqueezed = nt.unsqueeze(i) @@ -2054,9 +2316,12 @@ def test_squeeze_unsqueeze(self, device, dtype): wrapped_i = i + nt.dim() + 1 if i < 0 else i # col_index into nt size tensor is requires subtraction of 1 to ignore batch dim size_idx = wrapped_i - 1 - self.assertEqual(nt_unsqueezed._nested_tensor_size()[:, size_idx], torch.ones(2, dtype=torch.long)) + self.assertEqual( + nt_unsqueezed._nested_tensor_size()[:, size_idx], + torch.ones(2, dtype=torch.long), + ) unsqueezed_stride = nt_unsqueezed._nested_tensor_strides()[:, size_idx] - if (i == nt.ndim or i == -1): + if i == nt.ndim or i == -1: self.assertEqual(unsqueezed_stride, torch.ones(2, dtype=torch.long)) else: stride_col_after = nt_strides[:, size_idx] @@ -2094,25 +2359,25 @@ def test_view(self, device, dtype): self.assertRaisesRegex( RuntimeError, r"shape '\[\]' is invalid for a nested tensor", - lambda: nt.view(()) + lambda: nt.view(()), ) # error case: empty nested tensor nt_empty = torch.nested.nested_tensor([]) self.assertRaisesRegex( RuntimeError, "empty nested tensor cannot be reshaped", - lambda: nt_empty.view(-1) + lambda: nt_empty.view(-1), ) # error case: -1 for batch size self.assertRaisesRegex( RuntimeError, r"view: For now nested view cannot change or infer the implicit batch dimension", - lambda: nt.view(-1, 2, 3) + lambda: nt.view(-1, 2, 3), ) self.assertRaisesRegex( RuntimeError, r"shape '\[.*\]' is invalid for input of size [0-9]+", - lambda: nt.view(4, 2, 3) + lambda: nt.view(4, 2, 3), ) # normal case x0 = torch.randn((2, 20), device=device, dtype=dtype) @@ -2123,7 +2388,7 @@ def test_view(self, device, dtype): self.assertRaisesRegex( RuntimeError, r"For now nested view cannot change or infer the implicit batch dimension", - lambda: nt.transpose(-1, -2).view(40, -1) + lambda: nt.transpose(-1, -2).view(40, -1), ) # inherit only the ragged dimension # (2, 20) -> (2, 5, 4) @@ -2139,13 +2404,15 @@ def test_view(self, device, dtype): self.assertRaisesRegex( RuntimeError, r"only one dimension can be inferred", - lambda: nt1.view(2, -1, -1, 2, 2) + lambda: nt1.view(2, -1, -1, 2, 2), ) @dtypes(torch.float, torch.float16, torch.double) def test_view_inference_mode_interaction(self, device, dtype): # Construct in default mode and view while in inference mode - nt = torch.nested.nested_tensor([torch.randn((2, 20)), torch.randn((3, 20))], device=device, dtype=dtype) + nt = torch.nested.nested_tensor( + [torch.randn((2, 20)), torch.randn((3, 20))], device=device, dtype=dtype + ) with torch.inference_mode(): ntT = nt.view(2, -1, 4, 5) ptT_from_ntT = noncontiguous_to_padded_tensor(ntT) @@ -2154,7 +2421,9 @@ def test_view_inference_mode_interaction(self, device, dtype): self.assertEqual(ptT, ptT_from_ntT) # Construct and view while in inference mode with torch.inference_mode(): - nt = torch.nested.nested_tensor([torch.randn((2, 20)), torch.randn((3, 20))], device=device, dtype=dtype) + nt = torch.nested.nested_tensor( + [torch.randn((2, 20)), torch.randn((3, 20))], device=device, dtype=dtype + ) ntT = nt.view(2, -1, 4, 5) ptT_from_ntT = noncontiguous_to_padded_tensor(ntT) pt = torch.nested.to_padded_tensor(nt, 0.0) @@ -2168,25 +2437,25 @@ def test_reshape(self, device, dtype): self.assertRaisesRegex( RuntimeError, r"shape '\[\]' is invalid for a nested tensor", - lambda: nt.reshape(()) + lambda: nt.reshape(()), ) # error case: empty nested tensor nt_empty = torch.nested.nested_tensor([]) self.assertRaisesRegex( RuntimeError, "empty nested tensor cannot be reshaped", - lambda: nt_empty.reshape(-1) + lambda: nt_empty.reshape(-1), ) # error case: -1 for batch size self.assertRaisesRegex( RuntimeError, r"reshape: For now nested reshape cannot change or infer the implicit batch dimension", - lambda: nt.reshape(-1, 2, 3) + lambda: nt.reshape(-1, 2, 3), ) self.assertRaisesRegex( RuntimeError, r"shape '\[.*\]' is invalid for input of size [0-9]+", - lambda: nt.reshape(4, 2, 3) + lambda: nt.reshape(4, 2, 3), ) # normal case x0 = torch.randn((2, 20), device=device, dtype=dtype) @@ -2197,7 +2466,7 @@ def test_reshape(self, device, dtype): self.assertRaisesRegex( RuntimeError, r"reshape: For now nested reshape cannot change or infer the implicit batch dimension", - lambda: nt.transpose(-1, -2).reshape(40, -1) + lambda: nt.transpose(-1, -2).reshape(40, -1), ) # inherit only the ragged dimension # (2, 20) -> (2, 5, 4) @@ -2213,7 +2482,7 @@ def test_reshape(self, device, dtype): self.assertRaisesRegex( RuntimeError, r"only one dimension can be inferred", - lambda: nt1.reshape(2, -1, -1, 2, 2) + lambda: nt1.reshape(2, -1, -1, 2, 2), ) @dtypes(torch.float, torch.float16, torch.double) @@ -2232,35 +2501,50 @@ def test_narrow(self, device, dtype): # dim != 0 is not supported for dim in range(1, nt.dim()): - with self.assertRaisesRegex(RuntimeError, "only dim=0 supported for nested tensors"): + with self.assertRaisesRegex( + RuntimeError, "only dim=0 supported for nested tensors" + ): nt.narrow(dim=dim, start=0, length=1) # error case: non-contiguous NT _, nt_noncont = random_nt_noncontiguous_pair((2, 3, 4)) - with self.assertRaisesRegex(RuntimeError, "only contiguous nested tensors supported"): + with self.assertRaisesRegex( + RuntimeError, "only contiguous nested tensors supported" + ): nt_noncont.narrow(dim=0, start=0, length=1) @parametrize("input_dim", [3, 4]) def test_scaled_dot_product_attention(self, device, input_dim): - def rand_tensor(*shape): return torch.randn(shape, device=device) E = 8 if input_dim == 3: # Shape: (N, L, E); ragged L - query = torch.nested.nested_tensor([rand_tensor(2, E), rand_tensor(3, E), rand_tensor(4, E)]) + query = torch.nested.nested_tensor( + [rand_tensor(2, E), rand_tensor(3, E), rand_tensor(4, E)] + ) # Shape: (N, S, E); ragged S - key = torch.nested.nested_tensor([rand_tensor(3, E), rand_tensor(4, E), rand_tensor(5, E)]) - value = torch.nested.nested_tensor([rand_tensor(3, E), rand_tensor(4, E), rand_tensor(5, E)]) + key = torch.nested.nested_tensor( + [rand_tensor(3, E), rand_tensor(4, E), rand_tensor(5, E)] + ) + value = torch.nested.nested_tensor( + [rand_tensor(3, E), rand_tensor(4, E), rand_tensor(5, E)] + ) elif input_dim == 4: # In the 4D case the L and S is ragged # Shape: (N, N', L, E); ragged N' and L - query = torch.nested.nested_tensor([rand_tensor(2, 2, E), rand_tensor(3, 3, E), rand_tensor(4, 4, E)]) + query = torch.nested.nested_tensor( + [rand_tensor(2, 2, E), rand_tensor(3, 3, E), rand_tensor(4, 4, E)] + ) # Shape: (N, N', S, E); ragged N' and S - key = torch.nested.nested_tensor([rand_tensor(2, 3, E), rand_tensor(3, 4, E), rand_tensor(4, 5, E)]) - value = torch.nested.nested_tensor([rand_tensor(2, 3, E), rand_tensor(3, 4, E), rand_tensor(4, 5, E)]) + key = torch.nested.nested_tensor( + [rand_tensor(2, 3, E), rand_tensor(3, 4, E), rand_tensor(4, 5, E)] + ) + value = torch.nested.nested_tensor( + [rand_tensor(2, 3, E), rand_tensor(3, 4, E), rand_tensor(4, 5, E)] + ) else: self.fail(f"Invalid input_dim {input_dim} encountered in SDP test") @@ -2268,31 +2552,43 @@ def rand_mask(size): return torch.randint(0, 2, size=size, dtype=torch.bool, device=device) # Shape: (N, L, S); ragged L and S matching above - attn_mask = torch.nested.nested_tensor([rand_mask((2, 3)), rand_mask((3, 4)), rand_mask((4, 5))]) + attn_mask = torch.nested.nested_tensor( + [rand_mask((2, 3)), rand_mask((3, 4)), rand_mask((4, 5))] + ) dropout_p = 0.0 # no dropout for reproducibility # Success case: no attn_mask set and is_causal=False. actual = torch.nn.functional.scaled_dot_product_attention( - query, key, value, attn_mask=None, is_causal=False, dropout_p=dropout_p) + query, key, value, attn_mask=None, is_causal=False, dropout_p=dropout_p + ) expected_outputs = [] for q, k, v in zip(query.unbind(), key.unbind(), value.unbind()): output = torch.nn.functional.scaled_dot_product_attention( - q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0), attn_mask=None, dropout_p=dropout_p) + q.unsqueeze(0), + k.unsqueeze(0), + v.unsqueeze(0), + attn_mask=None, + dropout_p=dropout_p, + ) expected_outputs.append(output.squeeze(0)) expected_output_nested = torch.nested.nested_tensor(expected_outputs) self.assertEqual(actual, expected_output_nested) # Error case: explicit attn_mask set. - with self.assertRaisesRegex(RuntimeError, "not supported when an explicit attn_mask is set"): + with self.assertRaisesRegex( + RuntimeError, "not supported when an explicit attn_mask is set" + ): torch.nn.functional.scaled_dot_product_attention( - query, key, value, attn_mask=attn_mask, dropout_p=dropout_p) + query, key, value, attn_mask=attn_mask, dropout_p=dropout_p + ) # Error case: is_causal=True. with self.assertRaisesRegex(RuntimeError, "not supported when is_causal=True"): torch.nn.functional.scaled_dot_product_attention( - query, key, value, dropout_p=dropout_p, is_causal=True) + query, key, value, dropout_p=dropout_p, is_causal=True + ) @dtypes(torch.float, torch.float16, torch.double) def test_empty_like(self, device, dtype): @@ -2308,10 +2604,10 @@ def test_empty_like(self, device, dtype): if torch.cuda.is_available(): if device == "cpu": - nt_cuda = torch.empty_like(nt, device='cuda') + nt_cuda = torch.empty_like(nt, device="cuda") self.assertEqual(torch.device("cuda").type, nt_cuda.device.type) else: - nt_cpu = torch.empty_like(nt, device='cpu') + nt_cpu = torch.empty_like(nt, device="cpu") self.assertEqual(torch.device("cpu").type, nt_cpu.device.type) # Check changing dtype of empty_like nested tensor output @@ -2335,19 +2631,36 @@ def test_empty_like(self, device, dtype): assert nt_noncont.is_same_size(nt_empty_non_contig) # Test the contiguous memory format option - nt_empty_contig = torch.empty_like(nt_cont, memory_format=torch.contiguous_format) + nt_empty_contig = torch.empty_like( + nt_cont, memory_format=torch.contiguous_format + ) assert nt_cont.is_same_size(nt_empty_contig) assert nt_empty_contig.is_contiguous() - nt_empty_non_contig = torch.empty_like(nt_noncont, memory_format=torch.contiguous_format) + nt_empty_non_contig = torch.empty_like( + nt_noncont, memory_format=torch.contiguous_format + ) assert nt_noncont.is_same_size(nt_empty_non_contig) assert nt_empty_non_contig.is_contiguous() # Test other memory formats fail - self.assertRaises(RuntimeError, lambda: torch.empty_like(nt_cont, memory_format=torch.channels_last)) - self.assertRaises(RuntimeError, lambda: torch.empty_like(nt_noncont, memory_format=torch.channels_last)) - self.assertRaises(RuntimeError, lambda: torch.empty_like(nt_cont, memory_format=torch.channels_last_3d)) - self.assertRaises(RuntimeError, lambda: torch.empty_like(nt_noncont, memory_format=torch.channels_last_3d)) + self.assertRaises( + RuntimeError, + lambda: torch.empty_like(nt_cont, memory_format=torch.channels_last), + ) + self.assertRaises( + RuntimeError, + lambda: torch.empty_like(nt_noncont, memory_format=torch.channels_last), + ) + self.assertRaises( + RuntimeError, + lambda: torch.empty_like(nt_cont, memory_format=torch.channels_last_3d), + ) + self.assertRaises( + RuntimeError, + lambda: torch.empty_like(nt_noncont, memory_format=torch.channels_last_3d), + ) + @markDynamoStrictTest class TestNestedTensorAutograd(TestCase): @@ -2355,12 +2668,26 @@ class TestNestedTensorAutograd(TestCase): # includes the default parameters used for testing ops with gradcheck. However nested tensor # does not support the stack op therefore we turn it off for these tests def _create_leaf_nested_tensor_from_list(self, tensor_device, requires_grad=False): - return torch.nested.nested_tensor([torch.randn(1, 2,), - torch.randn(7, 8)], requires_grad=requires_grad, device=tensor_device) + return torch.nested.nested_tensor( + [ + torch.randn( + 1, + 2, + ), + torch.randn(7, 8), + ], + requires_grad=requires_grad, + device=tensor_device, + ) def _create_nested_tensor_from_list(self, tensor_device, requires_grad=False): - return torch.nested.as_nested_tensor([torch.randn(1, 2, requires_grad=requires_grad), - torch.randn(7, 8, requires_grad=requires_grad)], device=tensor_device) + return torch.nested.as_nested_tensor( + [ + torch.randn(1, 2, requires_grad=requires_grad), + torch.randn(7, 8, requires_grad=requires_grad), + ], + device=tensor_device, + ) def _create_nested_tensor_from_mask(self, tensor_device, requires_grad=False): data = torch.randn(2, 3, 4, requires_grad=requires_grad, device=tensor_device) @@ -2378,7 +2705,9 @@ def test_as_nested_tensor_propagates_gradients(self, device): a = torch.arange(3, dtype=torch.float, requires_grad=True, device=device) b = torch.arange(5, dtype=torch.float, requires_grad=True, device=device) nt2 = torch.nested.as_nested_tensor([a, b]) - fake_grad = torch.nested.nested_tensor([torch.ones_like(a), torch.zeros_like(b)], device=device) + fake_grad = torch.nested.nested_tensor( + [torch.ones_like(a), torch.zeros_like(b)], device=device + ) nt2.backward(fake_grad) self.assertEqual(a.grad, fake_grad[0]) self.assertEqual(b.grad, fake_grad[1]) @@ -2395,7 +2724,9 @@ def test_nested_tensor_generates_leaf(self, device): self.assertTrue(nt2.is_leaf) self.assertTrue(nt2.requires_grad) - fake_grad = torch.nested.nested_tensor([torch.ones_like(a), torch.zeros_like(b)], device=device) + fake_grad = torch.nested.nested_tensor( + [torch.ones_like(a), torch.zeros_like(b)], device=device + ) nt2.backward(fake_grad) self.assertEqual(nt2.grad, fake_grad) self.assertEqual(a.grad, None) @@ -2445,8 +2776,16 @@ def test_backward_for_sub_op(self, device): self.assertEqual(nt_2.grad, -1 * grad_output) def test_backward_sub_strided(self, device): - a = torch.nested.nested_tensor([torch.randn(9, 2, 4), torch.randn(12, 2, 4)], requires_grad=True, device=device) - b = torch.nested.nested_tensor([torch.randn(9, 4, 2), torch.randn(12, 4, 2)], requires_grad=True, device=device) + a = torch.nested.nested_tensor( + [torch.randn(9, 2, 4), torch.randn(12, 2, 4)], + requires_grad=True, + device=device, + ) + b = torch.nested.nested_tensor( + [torch.randn(9, 4, 2), torch.randn(12, 4, 2)], + requires_grad=True, + device=device, + ) c = a - b.transpose(-1, -2) grad_output = c.clone() c.backward(grad_output) @@ -2454,8 +2793,16 @@ def test_backward_sub_strided(self, device): self.assertEqual(b.grad, -1 * grad_output.transpose(-1, -2)) def test_backward_add_strided(self, device): - a = torch.nested.nested_tensor([torch.randn(9, 2, 4), torch.randn(12, 2, 4)], requires_grad=True, device=device) - b = torch.nested.nested_tensor([torch.randn(9, 4, 2), torch.randn(12, 4, 2)], requires_grad=True, device=device) + a = torch.nested.nested_tensor( + [torch.randn(9, 2, 4), torch.randn(12, 2, 4)], + requires_grad=True, + device=device, + ) + b = torch.nested.nested_tensor( + [torch.randn(9, 4, 2), torch.randn(12, 4, 2)], + requires_grad=True, + device=device, + ) c = a + b.transpose(-1, -2) grad_output = c.clone() c.backward(grad_output) @@ -2465,13 +2812,20 @@ def test_backward_add_strided(self, device): # Test Factory Functions def test_nested_tensor_to_padded_tensor(self, device): for padding_val in [0, 1]: - nt = self._create_leaf_nested_tensor_from_list(tensor_device=device, requires_grad=True) + nt = self._create_leaf_nested_tensor_from_list( + tensor_device=device, requires_grad=True + ) out = torch.nested.to_padded_tensor(nt, padding_val) grad_output = torch.ones(out.shape, device=device) out.backward(grad_output) - self.assertEqual(nt.grad, torch.nested.nested_tensor([torch.ones(1, 2), torch.ones(7, 8)], device=device)) + self.assertEqual( + nt.grad, + torch.nested.nested_tensor( + [torch.ones(1, 2), torch.ones(7, 8)], device=device + ), + ) def test_nested_tensor_from_mask_and_to_padded(self, device): N, L, D = 2, 4, 4 @@ -2483,12 +2837,15 @@ def test_nested_tensor_from_mask_and_to_padded(self, device): mask[0, :] = 1 mask = mask.bool() - data = torch.randn(N, L, D, requires_grad=True, dtype=torch.float64, device=device) + data = torch.randn( + N, L, D, requires_grad=True, dtype=torch.float64, device=device + ) def grad_test_func(inpt): nt = torch._nested_tensor_from_mask(inpt, mask) # This implicitly tests to_padded_tensor grads return torch.nested.to_padded_tensor(nt, 0) + assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) def test_nested_tensor_from_padded(self, device): @@ -2498,7 +2855,9 @@ def test_nested_tensor_from_padded(self, device): padded_tensor.requires_grad_() def grad_test_func(tensor, nested_size): - nt = torch._nested_from_padded(tensor, nested_size, fuse_transform_0213=False) + nt = torch._nested_from_padded( + tensor, nested_size, fuse_transform_0213=False + ) # This implicitly tests to_padded_tensor grads return torch.nested.to_padded_tensor(nt, 0) @@ -2512,14 +2871,16 @@ def test_nested_tensor_from_padded_fused(self, device): padded_tensor.requires_grad_() def grad_test_func(tensor, nested_size): - nt = torch._nested_from_padded(tensor, nested_size, fuse_transform_0213=True) + nt = torch._nested_from_padded( + tensor, nested_size, fuse_transform_0213=True + ) # This implicitly tests to_padded_tensor grads return torch.nested.to_padded_tensor(nt, 0) + data = (padded_tensor, nested_size) assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) def test_nested_tensor_from_list(self, device): - a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, device=device) b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device) c = torch.randn(10, 2, requires_grad=True, dtype=torch.float64, device=device) @@ -2528,20 +2889,29 @@ def grad_test_func(a, b, c): c = torch.nested.as_nested_tensor([a, b, c]) # This implictily tests to_padded_tensor grads return torch.nested.to_padded_tensor(c, 0) + data = (a, b, c) assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) @decorateIf( xfailIfTorchDynamo, # only fails in python 3.11. TODO: Debug this! - lambda params: params["layout"] == torch.jagged and sys.version_info >= (3, 11) + lambda params: params["layout"] == torch.jagged and sys.version_info >= (3, 11), ) @parametrize("layout", [torch.strided, torch.jagged], name_fn=layout_name) def test_dropout_backward(self, layout): if layout == torch.jagged: - nt = torch.nested.nested_tensor([torch.randn((2, 5)), torch.randn((3, 5))], requires_grad=True, layout=layout) + nt = torch.nested.nested_tensor( + [torch.randn((2, 5)), torch.randn((3, 5))], + requires_grad=True, + layout=layout, + ) else: - nt = torch.nested.nested_tensor([torch.randn((2, 5)), torch.randn((3, 4))], requires_grad=True, layout=layout) + nt = torch.nested.nested_tensor( + [torch.randn((2, 5)), torch.randn((3, 4))], + requires_grad=True, + layout=layout, + ) p = 0.2 y = torch.nn.functional.dropout(nt, p) y.backward(nt.clone().detach()) @@ -2563,8 +2933,16 @@ def grad_test_func(a, b, c, d): assert torch.autograd.gradcheck(grad_test_func, inputs=data) def test_nested_tensor_bmm_backward(self, device): - nt0 = torch.nested.nested_tensor([torch.randn((2, 6)), torch.randn((3, 6))], requires_grad=True, device=device) - nt1 = torch.nested.nested_tensor([torch.randn((6, 4)), torch.randn((6, 5))], requires_grad=True, device=device) + nt0 = torch.nested.nested_tensor( + [torch.randn((2, 6)), torch.randn((3, 6))], + requires_grad=True, + device=device, + ) + nt1 = torch.nested.nested_tensor( + [torch.randn((6, 4)), torch.randn((6, 5))], + requires_grad=True, + device=device, + ) with torch.no_grad(): pt0 = torch.nested.to_padded_tensor(nt0, 0.0).requires_grad_(True) pt1 = torch.nested.to_padded_tensor(nt1, 0.0).requires_grad_(True) @@ -2593,8 +2971,16 @@ def grad_test_func(a, b, c, d): assert torch.autograd.gradcheck(grad_test_func, inputs=data) def test_nested_tensor_matmul_backward(self, device): - nt0 = torch.nested.nested_tensor([torch.randn((7, 2, 6)), torch.randn((7, 3, 6))], requires_grad=True, device=device) - nt1 = torch.nested.nested_tensor([torch.randn((7, 6, 4)), torch.randn((7, 6, 5))], requires_grad=True, device=device) + nt0 = torch.nested.nested_tensor( + [torch.randn((7, 2, 6)), torch.randn((7, 3, 6))], + requires_grad=True, + device=device, + ) + nt1 = torch.nested.nested_tensor( + [torch.randn((7, 6, 4)), torch.randn((7, 6, 5))], + requires_grad=True, + device=device, + ) with torch.no_grad(): pt0 = torch.nested.to_padded_tensor(nt0, 0.0).requires_grad_(True) pt1 = torch.nested.to_padded_tensor(nt1, 0.0).requires_grad_(True) @@ -2620,7 +3006,11 @@ def grad_test_func(a, b): assert torch.autograd.gradcheck(grad_test_func, inputs=data, eps=1e-3) def test_nested_tensor_transpose_backward(self, device): - nt = torch.nested.nested_tensor([torch.randn((2, 5)), torch.randn((3, 4))], requires_grad=True, device=device) + nt = torch.nested.nested_tensor( + [torch.randn((2, 5)), torch.randn((3, 4))], + requires_grad=True, + device=device, + ) with torch.no_grad(): pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True) @@ -2644,7 +3034,9 @@ def grad_test_func(a, b): assert torch.autograd.gradcheck(grad_test_func, inputs=data, eps=1e-3) def test_nested_tensor_reshape_backward(self): - nt = torch.nested.nested_tensor([torch.randn((2, 6)), torch.randn((3, 6))], requires_grad=True) + nt = torch.nested.nested_tensor( + [torch.randn((2, 6)), torch.randn((3, 6))], requires_grad=True + ) with torch.no_grad(): pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True) @@ -2656,7 +3048,11 @@ def test_nested_tensor_reshape_backward(self): self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad) def test_nested_tensor_squeeze_backward(self, device): - nt = torch.nested.nested_tensor([torch.randn((2, 6, 1)), torch.randn((3, 6, 1))], requires_grad=True, device=device) + nt = torch.nested.nested_tensor( + [torch.randn((2, 6, 1)), torch.randn((3, 6, 1))], + requires_grad=True, + device=device, + ) with torch.no_grad(): pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True) @@ -2668,8 +3064,12 @@ def test_nested_tensor_squeeze_backward(self, device): self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad) def test_nested_tensor_squeeze_gradcheck(self, device): - a = torch.randn((2, 6, 1), dtype=torch.float64, requires_grad=True, device=device) - b = torch.randn((3, 6, 1), dtype=torch.float64, requires_grad=True, device=device) + a = torch.randn( + (2, 6, 1), dtype=torch.float64, requires_grad=True, device=device + ) + b = torch.randn( + (3, 6, 1), dtype=torch.float64, requires_grad=True, device=device + ) def grad_test_func(a, b): nt = torch.nested.as_nested_tensor([a, b]) @@ -2679,7 +3079,11 @@ def grad_test_func(a, b): assert torch.autograd.gradcheck(grad_test_func, inputs=(a, b), eps=1e-3) def test_nested_tensor_unsqueeze_backward(self, device): - nt = torch.nested.nested_tensor([torch.randn((2, 6)), torch.randn((3, 6))], requires_grad=True, device=device) + nt = torch.nested.nested_tensor( + [torch.randn((2, 6)), torch.randn((3, 6))], + requires_grad=True, + device=device, + ) with torch.no_grad(): pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True) @@ -2702,12 +3106,13 @@ def grad_test_func(a, b): assert torch.autograd.gradcheck(grad_test_func, inputs=(a, b), eps=1e-3) def test_nested_tensor_linear(self, device): - a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, device=device) b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device) c = torch.randn(3, 2, requires_grad=True, dtype=torch.float64, device=device) - weight = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device) + weight = torch.randn( + 2, 2, requires_grad=True, dtype=torch.float64, device=device + ) bias = torch.randn(2, requires_grad=True, dtype=torch.float64, device=device) def grad_test_func(a, b, c, weight, bias=None): @@ -2715,6 +3120,7 @@ def grad_test_func(a, b, c, weight, bias=None): # This implicitly tests to_padded_tensor grads d = torch.functional.F.linear(nt, weight, bias) return torch.nested.to_padded_tensor(d, 0) + data = (a, b, c, weight, bias) assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) @@ -2727,7 +3133,9 @@ def test_nested_tensor_linear_plus_transpose(self, device): b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device) c = torch.randn(3, 2, requires_grad=True, dtype=torch.float64, device=device) - weight = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device) + weight = torch.randn( + 2, 2, requires_grad=True, dtype=torch.float64, device=device + ) bias = torch.randn(2, requires_grad=True, dtype=torch.float64, device=device) def grad_test_func(a, b, c, weight, bias=None): @@ -2736,6 +3144,7 @@ def grad_test_func(a, b, c, weight, bias=None): d = torch.functional.F.linear(nt, weight, bias) d = d.transpose(-1, -2).contiguous() return torch.nested.to_padded_tensor(d, 0) + data = (a, b, c, weight, bias) assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) @@ -2845,7 +3254,9 @@ def test_indexing_backward(self, device): self.assertEqual(nt[-1], x1) grad_x0 = torch.randn((2, 5), device=device) nt[0].backward(grad_x0) - expected_grad = torch.nested.nested_tensor([grad_x0, torch.zeros((3, 4), device=device)]) + expected_grad = torch.nested.nested_tensor( + [grad_x0, torch.zeros((3, 4), device=device)] + ) self.assertEqual(nt.grad, expected_grad) def test_masked_fill_backward(self, device): @@ -2859,6 +3270,7 @@ def grad_test_func(a, b, c): out = nt.masked_fill(mask, 0) out = torch.nested.to_padded_tensor(out, 0) return out + data = (a, b, c) assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) @@ -2918,9 +3330,13 @@ def grad_test_func(a, b, c): # NotImplementedError: Cannot access storage of UndefinedTensorImpl def test_layer_norm_backward_edge_case(self, device): size = 4 - a = torch.randn(1, 2, size, requires_grad=False, dtype=torch.float64, device=device) + a = torch.randn( + 1, 2, size, requires_grad=False, dtype=torch.float64, device=device + ) nt = torch.nested.nested_tensor([a]) - nt_layer_norm = torch.nn.LayerNorm(nt.size(-1), device=device, dtype=torch.float64) + nt_layer_norm = torch.nn.LayerNorm( + nt.size(-1), device=device, dtype=torch.float64 + ) out = nt_layer_norm(nt) out.backward(out.clone()) @@ -2941,13 +3357,21 @@ def grad_test_func(a, b): @skipIfSlowGradcheckEnv @parametrize("size", [1024, 1023, 513, 512, 256, 128, 32, 4, 2]) def test_layer_norm_backward(self, device, size): - a = torch.randn(1, 2, size, requires_grad=True, dtype=torch.float64, device=device) - b = torch.randn(2, 2, size, requires_grad=True, dtype=torch.float64, device=device) - c = torch.randn(3, 2, size, requires_grad=True, dtype=torch.float64, device=device) + a = torch.randn( + 1, 2, size, requires_grad=True, dtype=torch.float64, device=device + ) + b = torch.randn( + 2, 2, size, requires_grad=True, dtype=torch.float64, device=device + ) + c = torch.randn( + 3, 2, size, requires_grad=True, dtype=torch.float64, device=device + ) def grad_test_func(a, b, c): nt = torch.nested.as_nested_tensor([a, b, c]) - layer_norm = torch.nn.LayerNorm(nt.size(-1), device=device, dtype=torch.float64) + layer_norm = torch.nn.LayerNorm( + nt.size(-1), device=device, dtype=torch.float64 + ) nt_layer_norm = layer_norm(nt) return torch.nested.to_padded_tensor(nt_layer_norm, 0) @@ -2959,23 +3383,33 @@ def grad_test_func(a, b, c): # Could either mark slow or reduce size @parametrize("size", [128, 32, 4, 2]) def test_layer_norm_backward_5d(self, device, size): - a = torch.randn(4, size, size, 4, requires_grad=True, dtype=torch.float64, device=device) - b = torch.randn(7, size, size, 4, requires_grad=True, dtype=torch.float64, device=device) - c = torch.randn(10, size, size, 4, requires_grad=True, dtype=torch.float64, device=device) + a = torch.randn( + 4, size, size, 4, requires_grad=True, dtype=torch.float64, device=device + ) + b = torch.randn( + 7, size, size, 4, requires_grad=True, dtype=torch.float64, device=device + ) + c = torch.randn( + 10, size, size, 4, requires_grad=True, dtype=torch.float64, device=device + ) def grad_test_func(a, b, c): nt = torch.nested.as_nested_tensor([a, b, c]) - layer_norm = torch.nn.LayerNorm((size, size, nt.size(-1)), device=device, dtype=torch.float64) + layer_norm = torch.nn.LayerNorm( + (size, size, nt.size(-1)), device=device, dtype=torch.float64 + ) nt_layer_norm = layer_norm(nt) return torch.nested.to_padded_tensor(nt_layer_norm, 0) data = (a, b, c) assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) + # Found in torch/testing/_comparison.py default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float32: 1e-5} default_rtol = {torch.float16: 1e-3, torch.bfloat16: 1.6e-2, torch.float32: 1.3e-6} + def get_rtol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float: deviation = true_value - computed_value deviation = torch.abs(deviation / true_value) @@ -3008,6 +3442,7 @@ def get_tolerances( rtol = default_rtol[computed_value.dtype] return atol, rtol + # We can probably parametrizing existing tests instead of having a separate # test class as we begin to support more ops. Also maybe rewrite with OpInfos. @markDynamoStrictTest @@ -3018,16 +3453,25 @@ def _get_list_for_jagged_tensor(self, nested_size, device, requires_grad=True): out = [] for s in nested_size[0]: out.append( - torch.randn(s, *Ds, requires_grad=requires_grad, device=device, dtype=torch.float64) + torch.randn( + s, + *Ds, + requires_grad=requires_grad, + device=device, + dtype=torch.float64, + ) ) return out - def _get_example_tensor_lists(self, include_list_of_lists=True, include_requires_grad=True): - - def _make_tensor(*shape, include_requires_grad=include_requires_grad, requires_grad=True): + def _get_example_tensor_lists( + self, include_list_of_lists=True, include_requires_grad=True + ): + def _make_tensor( + *shape, include_requires_grad=include_requires_grad, requires_grad=True + ): return torch.randn( *shape, - requires_grad=(requires_grad if include_requires_grad else False) + requires_grad=(requires_grad if include_requires_grad else False), ) # Purposefully introduce mixed requires_grad settings for the components @@ -3038,7 +3482,7 @@ def _make_tensor(*shape, include_requires_grad=include_requires_grad, requires_g _make_tensor(2, 5), _make_tensor(3, 5, requires_grad=False), _make_tensor(4, 5, requires_grad=False), - _make_tensor(6, 5) + _make_tensor(6, 5), ], # (B, *, D_0, D_1) with B=5 [ @@ -3066,7 +3510,8 @@ def _make_tensor(*shape, include_requires_grad=include_requires_grad, requires_g _make_tensor(2, 5, requires_grad=False).tolist(), _make_tensor(3, 5).tolist(), _make_tensor(4, 5).tolist(), - ]) + ] + ) return example_lists @@ -3088,11 +3533,14 @@ def test_tensor_attributes(self, device): ): op(nt) - with self.assertRaisesRegex(RuntimeError, - "directly calling torch.ops.aten.size"): + with self.assertRaisesRegex( + RuntimeError, "directly calling torch.ops.aten.size" + ): torch.ops.aten.size.default(nt) - nested_int = torch.nested._internal.nested_tensor.get_tensor_symint(_offsets, coeff=1) + nested_int = torch.nested._internal.nested_tensor.get_tensor_symint( + _offsets, coeff=1 + ) self.assertEqual(nt.size(), (3, nested_int, 3)) self.assertEqual(nt.shape, (3, nested_int, 3)) self.assertEqual(nt.dim(), 3) @@ -3102,7 +3550,9 @@ def test_linear(self, device): a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device) b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device) c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device) - weight = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device) + weight = torch.randn( + 4, 3, requires_grad=True, dtype=torch.float64, device=device + ) def grad_test_func(a, b, c, weight): nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged) @@ -3125,19 +3575,30 @@ def grad_test_func(a, b, c): def test_unary_pointwise_transposed_inputs(self, device): a, b, c = ( - torch.randn(i + 2, 5, requires_grad=True, dtype=torch.float64, device=device) for i in range(3) + torch.randn( + i + 2, 5, requires_grad=True, dtype=torch.float64, device=device + ) + for i in range(3) ) - nt = torch.nested.nested_tensor([a.detach(), b.detach(), c.detach()], layout=torch.jagged) + nt = torch.nested.nested_tensor( + [a.detach(), b.detach(), c.detach()], layout=torch.jagged + ) nt_t = nt.transpose(1, 2) self.assertFalse(nt_t.is_contiguous()) out = torch.nn.functional.silu(nt_t.sin().cos()) - self.assertEqual(out.is_contiguous(), torch.nn.functional.silu(b.transpose(-1, -2).sin().cos()).is_contiguous()) + self.assertEqual( + out.is_contiguous(), + torch.nn.functional.silu(b.transpose(-1, -2).sin().cos()).is_contiguous(), + ) self.assertEqual(nt_t.shape, out.shape) a, b, c = ( - torch.randn(i + 2, 5, requires_grad=True, dtype=torch.float64, device=device) for i in range(3) + torch.randn( + i + 2, 5, requires_grad=True, dtype=torch.float64, device=device + ) + for i in range(3) ) def grad_test_func(a, b, c): @@ -3148,7 +3609,6 @@ def grad_test_func(a, b, c): gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False) - def test_binary_pointwise(self, device): a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device) b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device) @@ -3162,7 +3622,8 @@ def test_binary_pointwise(self, device): self.assertRaisesRegex( RuntimeError, "cannot call binary pointwise function .* with inputs of shapes", - lambda: nt1 * nt2) + lambda: nt1 * nt2, + ) # Correct usage: chain the calls using the same offsets tensor object def grad_test_func(a, b, c): @@ -3197,7 +3658,10 @@ def test_binary_pointwise_transposed(self, device): ) a, b, c = ( - torch.randn(i + 2, 5, requires_grad=True, dtype=torch.float64, device=device) for i in range(3) + torch.randn( + i + 2, 5, requires_grad=True, dtype=torch.float64, device=device + ) + for i in range(3) ) # Correct usage: chain the calls using the same offsets tensor object @@ -3221,11 +3685,15 @@ def test_split(self, device): self.assertEqual(len(out), 2) self.assertEqual( out[0], - torch.nested.as_nested_tensor([a[:, 0:2], b[:, 0:2], c[:, 0:2]], layout=torch.jagged) + torch.nested.as_nested_tensor( + [a[:, 0:2], b[:, 0:2], c[:, 0:2]], layout=torch.jagged + ), ) self.assertEqual( out[1], - torch.nested.as_nested_tensor([a[:, 2:], b[:, 2:], c[:, 2:]], layout=torch.jagged) + torch.nested.as_nested_tensor( + [a[:, 2:], b[:, 2:], c[:, 2:]], layout=torch.jagged + ), ) with self.assertRaisesRegex( @@ -3244,11 +3712,15 @@ def test_split_with_sizes(self, device): self.assertEqual(len(out), 2) self.assertEqual( out[0], - torch.nested.as_nested_tensor([a[:, 0:1], b[:, 0:1], c[:, 0:1]], layout=torch.jagged) + torch.nested.as_nested_tensor( + [a[:, 0:1], b[:, 0:1], c[:, 0:1]], layout=torch.jagged + ), ) self.assertEqual( out[1], - torch.nested.as_nested_tensor([a[:, 1:], b[:, 1:], c[:, 1:]], layout=torch.jagged) + torch.nested.as_nested_tensor( + [a[:, 1:], b[:, 1:], c[:, 1:]], layout=torch.jagged + ), ) with self.assertRaisesRegex( RuntimeError, @@ -3259,7 +3731,8 @@ def test_split_with_sizes(self, device): def test_views_inherit_ragged_dim(self, device): # view nt = random_nt_from_dims( - [4, None, 8, 10], device=device, dtype=torch.float32, layout=torch.jagged) + [4, None, 8, 10], device=device, dtype=torch.float32, layout=torch.jagged + ) # inherit ragged dim via -1 view = nt.view(4, -1, 80) self.assertEqual(nt.shape[1], view.shape[1]) @@ -3269,20 +3742,25 @@ def test_views_inherit_ragged_dim(self, device): # expand nt = random_nt_from_dims( - [3, None, 1], device=device, dtype=torch.float32, layout=torch.jagged) + [3, None, 1], device=device, dtype=torch.float32, layout=torch.jagged + ) # inherit batch and ragged dims via -1 view = nt.expand(-1, -1, 5) self.assertEqual(nt.shape[:2], view.shape[:2]) def test_view_ragged_idx_not_one(self, device): - nt = random_nt_from_dims([2, None, 20], device=device, dtype=torch.float32, layout=torch.jagged) + nt = random_nt_from_dims( + [2, None, 20], device=device, dtype=torch.float32, layout=torch.jagged + ) view_transposed = nt.transpose(1, 2).view(2, 20, nt.size(1)) self.assertEqual((2, 20, nt.size(1)), (view_transposed.size())) self.assertEqual(view_transposed._base, nt._base) def test_unsafe_view(self, device): - nt = random_nt_from_dims([4, None, 8, 10], device=device, dtype=torch.float32, layout=torch.jagged) + nt = random_nt_from_dims( + [4, None, 8, 10], device=device, dtype=torch.float32, layout=torch.jagged + ) # basic view view1 = torch.ops.aten._unsafe_view(nt, (4, -1, 80)) self.assertEqual((4, nt.size(1), 80), tuple(view1.size())) @@ -3299,12 +3777,16 @@ def test_unsafe_view(self, device): @parametrize("requires_grad", [False, True]) def test_reshape_decomp(self, device, requires_grad): # contiguous NT should result in view. - nt = random_nt_from_dims( - [3, None, 10], - device=device, - dtype=torch.float32, - layout=torch.jagged, - ).detach().requires_grad_(requires_grad) + nt = ( + random_nt_from_dims( + [3, None, 10], + device=device, + dtype=torch.float32, + layout=torch.jagged, + ) + .detach() + .requires_grad_(requires_grad) + ) view = nt.reshape(-1, -1, 5, 2) self.assertEqual(view.shape[:2], nt.shape[:2]) self.assertTrue(view._is_view() and view._base is nt) @@ -3319,7 +3801,7 @@ def test_reshape_decomp(self, device, requires_grad): device=device, dtype=torch.float32, layout=torch.jagged, - requires_grad=requires_grad + requires_grad=requires_grad, ) nt_noncontig = nt.transpose(-1, -2) self.assertFalse(nt_noncontig.is_contiguous()) @@ -3333,12 +3815,14 @@ def test_reshape_decomp(self, device, requires_grad): def test_flatten_decomp(self, device): nt = random_nt_from_dims( - [3, None, 5, 2], device=device, dtype=torch.float32, layout=torch.jagged) + [3, None, 5, 2], device=device, dtype=torch.float32, layout=torch.jagged + ) flattened = nt.flatten(-2, -1) self.assertEqual(flattened.shape, nt.view(3, -1, 10).shape) nt = random_nt_from_dims( - [3, None, 5, 2, 6], device=device, dtype=torch.float32, layout=torch.jagged) + [3, None, 5, 2, 6], device=device, dtype=torch.float32, layout=torch.jagged + ) flattened = nt.flatten(-3, -2) self.assertEqual(flattened.shape, nt.view(3, -1, 10, 6).shape) @@ -3346,7 +3830,9 @@ def test_chunk(self, device): # normal case D = 30 B = 8 - nt = random_nt_from_dims([B, None, D], device=device, dtype=torch.float32, layout=torch.jagged) + nt = random_nt_from_dims( + [B, None, D], device=device, dtype=torch.float32, layout=torch.jagged + ) NUM_CHUNKS = 3 chunks = nt.chunk(NUM_CHUNKS, dim=-1) self.assertEqual(len(chunks), NUM_CHUNKS) @@ -3362,12 +3848,17 @@ def test_chunk(self, device): self.assertEqual(chunks[i].shape[0], chunk_size) else: self.assertEqual(chunks[i].shape[0], B - chunk_size * (NUM_CHUNKS - 1)) - offsets_expected = nt._offsets[i * chunk_size + 1 : (i + 1) * chunk_size + 1] - nt._offsets[i * chunk_size] + offsets_expected = ( + nt._offsets[i * chunk_size + 1 : (i + 1) * chunk_size + 1] + - nt._offsets[i * chunk_size] + ) self.assertEqual(chunks[i]._offsets[1:], offsets_expected) self.assertEqual(nt._values, torch.cat([x._values for x in chunks], dim=0)) # chunk on ragged dim not supported - with self.assertRaisesRegex(RuntimeError, "chunk.* not supported for NestedTensor on dim=1"): + with self.assertRaisesRegex( + RuntimeError, "chunk.* not supported for NestedTensor on dim=1" + ): nt.chunk(2, dim=1) def test_squeeze(self, device): @@ -3375,7 +3866,8 @@ def test_squeeze(self, device): D = 6 # squeeze middle dim nt = random_nt_from_dims( - [B, None, 1, D], device=device, dtype=torch.float32, layout=torch.jagged) + [B, None, 1, D], device=device, dtype=torch.float32, layout=torch.jagged + ) j0 = nt.shape[1] for dim_arg in [-2, 2]: @@ -3385,7 +3877,8 @@ def test_squeeze(self, device): # squeeze last dim nt = random_nt_from_dims( - [B, None, 1], device=device, dtype=torch.float32, layout=torch.jagged) + [B, None, 1], device=device, dtype=torch.float32, layout=torch.jagged + ) j1 = nt.shape[1] for dim_arg in [-1, 2]: @@ -3395,17 +3888,21 @@ def test_squeeze(self, device): # squeeze on batch dim not supported with self.assertRaisesRegex( - RuntimeError, "squeeze.* not supported for NestedTensor on dim=0"): + RuntimeError, "squeeze.* not supported for NestedTensor on dim=0" + ): nt.squeeze(0) # squeeze on ragged dim not supported with self.assertRaisesRegex( - RuntimeError, "squeeze.* not supported for NestedTensor on dim=1"): + RuntimeError, "squeeze.* not supported for NestedTensor on dim=1" + ): nt.squeeze(1) def test_binary_pointwise_broadcasting(self, device): # (B, j0, 3, 4) - ts = self._get_list_for_jagged_tensor(((2, 3, 4), 3, 4), device, requires_grad=True) + ts = self._get_list_for_jagged_tensor( + ((2, 3, 4), 3, 4), device, requires_grad=True + ) # (B, j0, ?, ?) + (?) -> (B, j0, ?, ?) # (B, j0, ?, ?) + (?, ?) -> (B, j0, ?, ?) # (B, j0, ?, ?) + (1, ?, ?) -> (B, j0, ?, ?) @@ -3425,12 +3922,18 @@ def grad_test_func(t, *ts): return out.values() for t_size in t_sizes: - t = torch.rand(t_size, requires_grad=True, device=device, dtype=torch.float64) + t = torch.rand( + t_size, requires_grad=True, device=device, dtype=torch.float64 + ) gradcheck(grad_test_func, inputs=(t, *ts), check_batched_grad=False) def test_threshold_backward(self, device): - ts1 = self._get_list_for_jagged_tensor(((2, 3, 4), 16), device=device, requires_grad=False) - ts2 = self._get_list_for_jagged_tensor(((2, 3, 4), 16), device=device, requires_grad=False) + ts1 = self._get_list_for_jagged_tensor( + ((2, 3, 4), 16), device=device, requires_grad=False + ) + ts2 = self._get_list_for_jagged_tensor( + ((2, 3, 4), 16), device=device, requires_grad=False + ) nt1, offsets = jagged_from_list(ts1, None) nt2, offsets = jagged_from_list(ts2, offsets) @@ -3442,11 +3945,12 @@ def test_threshold_backward(self, device): self.assertEqual(res_dense, res_nt.values()) - @parametrize("keepdim", [False, True]) def test_sum_int_DimList(self, device, keepdim): # (B, j0, 3, 4) - ts = self._get_list_for_jagged_tensor(((2, 3, 4), 3, 4), device=device, requires_grad=True) + ts = self._get_list_for_jagged_tensor( + ((2, 3, 4), 3, 4), device=device, requires_grad=True + ) # Check shape correctness reduce_dims = ( @@ -3462,8 +3966,9 @@ def test_sum_int_DimList(self, device, keepdim): for rd, ref_shape_no_keepdim, ref_shape_keepdim in reduce_dims: if (0 in rd) ^ (1 in rd): with self.assertRaisesRegex( - RuntimeError, - "applying over the ragged dimension, but not the batch dimension"): + RuntimeError, + "applying over the ragged dimension, but not the batch dimension", + ): nt = torch.nested.as_nested_tensor(ts, layout=torch.jagged) out = torch.sum(nt, dim=rd, keepdim=keepdim) continue @@ -3494,18 +3999,17 @@ def test_sum_int_DimList(self, device, keepdim): self.assertNotIsInstance(out, NestedTensor) self.assertTrue(torch.allclose(out, out_ref)) - - @dtypes(torch.float, torch.double, torch.half) @parametrize("requires_grad", [False, True]) @parametrize("weights_only", [False, True]) def test_serialization(self, device, dtype, requires_grad, weights_only): - def compare_metadata(nt1, nt2): self.assertEqual(nt1._nested_tensor_size(), nt2._nested_tensor_size()) self.assertEqual(nt1._nested_tensor_strides(), nt2._nested_tensor_strides()) - self.assertEqual(nt1._nested_tensor_storage_offsets(), - nt2._nested_tensor_storage_offsets()) + self.assertEqual( + nt1._nested_tensor_storage_offsets(), + nt2._nested_tensor_storage_offsets(), + ) nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7)) for a in [nt_contiguous, nt_noncontiguous]: @@ -3520,7 +4024,9 @@ def compare_metadata(nt1, nt2): self.assertEqual(b, nt_contiguous) self.assertEqual(b, nt_noncontiguous) - @unittest.skipIf(PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property") + @unittest.skipIf( + PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property" + ) @onlyCUDA def test_pin_memory(self, device): nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7)) @@ -3535,7 +4041,9 @@ def test_pin_memory(self, device): self.assertEqual(pinned.data_ptr(), pinned.pin_memory().data_ptr()) @torch.compiler.disable - def _validate_nt(self, nt, device, dtype, layout, requires_grad, dim, batch_size, base=None): + def _validate_nt( + self, nt, device, dtype, layout, requires_grad, dim, batch_size, base=None + ): # Validate a bunch of properties after NT construction. device = torch.device(device) self.assertEqual(nt.dim(), dim) @@ -3557,20 +4065,30 @@ def _validate_nt(self, nt, device, dtype, layout, requires_grad, dim, batch_size @parametrize("requires_grad", [False, True]) @parametrize("components_require_grad", [False, True]) def test_jagged_layout_construction_nested_tensor( - self, device, dtype, requires_grad, components_require_grad): + self, device, dtype, requires_grad, components_require_grad + ): for tensor_list in self._get_example_tensor_lists( - include_list_of_lists=True, include_requires_grad=components_require_grad): + include_list_of_lists=True, include_requires_grad=components_require_grad + ): nt = torch.nested.nested_tensor( tensor_list, device=device, dtype=dtype, layout=torch.jagged, - requires_grad=requires_grad) + requires_grad=requires_grad, + ) expected_dim = torch.as_tensor(tensor_list[0]).dim() + 1 expected_batch_size = len(tensor_list) self._validate_nt( - nt, device, dtype, torch.jagged, requires_grad, expected_dim, expected_batch_size) + nt, + device, + dtype, + torch.jagged, + requires_grad, + expected_dim, + expected_batch_size, + ) # Make sure grads -don't- flow back into original tensors for nested_tensor() if requires_grad: @@ -3582,15 +4100,15 @@ def test_jagged_layout_construction_nested_tensor( @dtypes(torch.float, torch.double, torch.half) @parametrize("components_require_grad", [False, True]) def test_jagged_layout_construction_as_nested_tensor( - self, device, dtype, components_require_grad): + self, device, dtype, components_require_grad + ): # NB: as_nested_tensor(tensor_list) doesn't support lists of lists for tensor_list for tensor_list in self._get_example_tensor_lists( - include_list_of_lists=False, include_requires_grad=components_require_grad): + include_list_of_lists=False, include_requires_grad=components_require_grad + ): nt = torch.nested.as_nested_tensor( - tensor_list, - device=device, - dtype=dtype, - layout=torch.jagged) + tensor_list, device=device, dtype=dtype, layout=torch.jagged + ) # nt.requires_grad=True should be set if at least one component requires grad expected_dim = tensor_list[0].dim() + 1 @@ -3602,7 +4120,8 @@ def test_jagged_layout_construction_as_nested_tensor( torch.jagged, components_require_grad, expected_dim, - expected_batch_size) + expected_batch_size, + ) # Make sure grads flow back into original tensors for as_nested_tensor() if components_require_grad: @@ -3614,15 +4133,15 @@ def test_jagged_layout_construction_as_nested_tensor( self.assertTrue(t.grad is None) @xfailIfTorchDynamo - @unittest.skipIf(PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property") + @unittest.skipIf( + PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property" + ) @onlyCUDA def test_jagged_layout_construction_with_pinned_memory(self, device): for tensor_list in self._get_example_tensor_lists(): nt = torch.nested.nested_tensor( - tensor_list, - layout=torch.jagged, - device="cpu", - pin_memory=True) + tensor_list, layout=torch.jagged, device="cpu", pin_memory=True + ) expected_dim = torch.as_tensor(tensor_list[0]).dim() + 1 expected_batch_size = len(tensor_list) @@ -3633,20 +4152,26 @@ def test_jagged_layout_construction_with_pinned_memory(self, device): layout=torch.jagged, requires_grad=False, dim=expected_dim, - batch_size=expected_batch_size) + batch_size=expected_batch_size, + ) self.assertTrue(nt.is_pinned()) @dtypes(torch.float, torch.double, torch.half) @parametrize("requires_grad", [False, True]) @parametrize("values_is_view", [False, True]) - def test_jagged_view_from_values_offsets(self, device, dtype, requires_grad, values_is_view): + def test_jagged_view_from_values_offsets( + self, device, dtype, requires_grad, values_is_view + ): if values_is_view: # make values a view of base base = torch.randn( - 2, 3, 4, 5, 6, device=device, dtype=dtype, requires_grad=requires_grad) + 2, 3, 4, 5, 6, device=device, dtype=dtype, requires_grad=requires_grad + ) values = base.flatten(0, -2) else: - values = torch.randn(10, 5, device=device, dtype=dtype, requires_grad=requires_grad) + values = torch.randn( + 10, 5, device=device, dtype=dtype, requires_grad=requires_grad + ) offsets = torch.tensor([0, 2, 4, 6, 10], device=device, dtype=torch.int64) nt = nested_view_from_values_offsets(values, offsets) @@ -3655,9 +4180,15 @@ def test_jagged_view_from_values_offsets(self, device, dtype, requires_grad, val expected_batch_size = offsets.shape[0] - 1 expected_base = base if values_is_view else values self._validate_nt( - nt, device, dtype, torch.jagged, requires_grad, expected_dim, expected_batch_size, + nt, + device, + dtype, + torch.jagged, + requires_grad, + expected_dim, + expected_batch_size, # ensure NT is a proper view - base=expected_base + base=expected_base, ) if requires_grad: @@ -3687,7 +4218,9 @@ def test_nested_tensor_from_jagged(self, device, dtype): # construct from (values, offsets, lengths) lengths = torch.tensor([2, 1, 1, 2], device=device) - nt = torch.nested.nested_tensor_from_jagged(values, offsets=offsets, lengths=lengths) + nt = torch.nested.nested_tensor_from_jagged( + values, offsets=offsets, lengths=lengths + ) self.assertTrue(isinstance(nt, NestedTensor)) self.assertTrue(nt._is_view() and nt._base is values) self.assertEqual(nt.dim(), 3) @@ -3709,32 +4242,44 @@ def test_nested_tensor_from_jagged(self, device, dtype): # for now, if only lengths is specified, convert to offsets to integrate best with the # existing kernels expected_offsets = torch.tensor([0, 2, 5, 9, 14], device=device) - expected_nt = torch.nested.nested_tensor_from_jagged(values, offsets=expected_offsets) + expected_nt = torch.nested.nested_tensor_from_jagged( + values, offsets=expected_offsets + ) for n1, n2 in zip(nt.unbind(), expected_nt.unbind()): self.assertEqual(n1, n2) # error case: no offsets or lengths - with self.assertRaisesRegex(RuntimeError, "At least one of offsets or lengths is required"): + with self.assertRaisesRegex( + RuntimeError, "At least one of offsets or lengths is required" + ): torch.nested.nested_tensor_from_jagged(values, offsets=None, lengths=None) @dtypes(torch.float, torch.double, torch.half) @parametrize("dim", range(5)) - @parametrize("layout", [torch.strided, torch.jagged], - name_fn=lambda l: f"layout_{str(l).split('.')[1]}") + @parametrize( + "layout", + [torch.strided, torch.jagged], + name_fn=lambda l: f"layout_{str(l).split('.')[1]}", + ) @parametrize("requires_grad", [False, True]) @parametrize("contiguous", [False, True]) def test_as_nested_tensor_from_tensor( - self, device, dtype, dim, layout, requires_grad, contiguous): + self, device, dtype, dim, layout, requires_grad, contiguous + ): if dim == 0: - t = torch.tensor(3., requires_grad=requires_grad) + t = torch.tensor(3.0, requires_grad=requires_grad) else: t = torch.randn(*(3 for _ in range(dim)), requires_grad=requires_grad) assert t.dim() == dim if dim < 2: # 0-1 dim tensors can't be converted to NTs - with self.assertRaisesRegex(RuntimeError, "Expected tensor argument to have dim"): - nt = torch.nested.as_nested_tensor(t, device=device, dtype=dtype, layout=layout) + with self.assertRaisesRegex( + RuntimeError, "Expected tensor argument to have dim" + ): + nt = torch.nested.as_nested_tensor( + t, device=device, dtype=dtype, layout=layout + ) return orig_t = t @@ -3745,7 +4290,8 @@ def test_as_nested_tensor_from_tensor( expected_dim = t.dim() expected_batch_size = t.size(0) self._validate_nt( - nt, device, dtype, layout, requires_grad, expected_dim, expected_batch_size) + nt, device, dtype, layout, requires_grad, expected_dim, expected_batch_size + ) if torch.device(device) == t.device and dtype == t.dtype and contiguous: # should be the non-copying (view) case @@ -3753,18 +4299,24 @@ def test_as_nested_tensor_from_tensor( # should be equivalent to construction from unbound tensor list nt_from_unbind = torch.nested.as_nested_tensor( - list(t.unbind(0)), device=device, dtype=dtype, layout=layout) + list(t.unbind(0)), device=device, dtype=dtype, layout=layout + ) self.assertEqual(nt, nt_from_unbind) # ensure call on a NT with the same properties returns the NT directly - nt2 = torch.nested.as_nested_tensor(nt, device=device, dtype=dtype, layout=layout) + nt2 = torch.nested.as_nested_tensor( + nt, device=device, dtype=dtype, layout=layout + ) self.assertTrue(nt is nt2) # we don't support conversion between layouts this way atm other_layout = torch.strided if layout == torch.jagged else torch.jagged with self.assertRaisesRegex( - RuntimeError, "Converting between nested tensor layouts is not supported"): - torch.nested.as_nested_tensor(nt, device=device, dtype=dtype, layout=other_layout) + RuntimeError, "Converting between nested tensor layouts is not supported" + ): + torch.nested.as_nested_tensor( + nt, device=device, dtype=dtype, layout=other_layout + ) if requires_grad: # make sure gradients flow back into inputs @@ -3778,10 +4330,8 @@ def test_device_dtype_transfer_updates_offsets(self, device, dtype): orig_device = torch.device("cpu") orig_dtype = torch.float32 nt = torch.nested.nested_tensor( - tensor_list, - layout=torch.jagged, - device=orig_device, - dtype=orig_dtype) + tensor_list, layout=torch.jagged, device=orig_device, dtype=orig_dtype + ) self.assertEqual(torch.int64, nt.offsets().dtype) nt = nt.to(device=device).to(dtype=dtype) @@ -3793,9 +4343,8 @@ def test_device_dtype_transfer_updates_offsets(self, device, dtype): def test_unbind(self, device): for tensor_list in self._get_example_tensor_lists(): nt = torch.nested.nested_tensor( - tensor_list, - layout=torch.jagged, - device=device) # ragged_idx = 1 + tensor_list, layout=torch.jagged, device=device + ) # ragged_idx = 1 out = nt.unbind() self.assertEqual(len(out), len(tensor_list)) for i, t in enumerate(out): @@ -3805,35 +4354,38 @@ def test_unbind(self, device): def test_unbind_transpose(self, device, ragged_idx): for tensor_list in self._get_example_tensor_lists(): nt = torch.nested.nested_tensor( - tensor_list, - layout=torch.jagged, - device=device) + tensor_list, layout=torch.jagged, device=device + ) if ragged_idx < nt.dim(): nt = nt.transpose(1, ragged_idx) # set ragged_idx out = nt.unbind() self.assertEqual(len(out), len(tensor_list)) for i, t in enumerate(out): - self.assertEqual(t.transpose(0, ragged_idx - 1), tensor_list[i]) # transpose back each element of result + self.assertEqual( + t.transpose(0, ragged_idx - 1), tensor_list[i] + ) # transpose back each element of result def test_unbind_transpose_ragged_idx_last_dim(self, device): for tensor_list in self._get_example_tensor_lists(): nt = torch.nested.nested_tensor( - tensor_list, - layout=torch.jagged, - device=device).transpose(1, -1) # set ragged_idx = last dimension + tensor_list, layout=torch.jagged, device=device + ).transpose( + 1, -1 + ) # set ragged_idx = last dimension out = nt.unbind() self.assertEqual(len(out), len(tensor_list)) for i, t in enumerate(out): - self.assertEqual(t.transpose(0, -1), tensor_list[i]) # transpose back each element of result + self.assertEqual( + t.transpose(0, -1), tensor_list[i] + ) # transpose back each element of result def test_unbind_lengths(self, device): values = torch.randn(16, 128, device=device) offsets = torch.tensor([0, 8, 12, 13, 16], device=device) lengths = torch.tensor([6, 2, 1, 2], device=device) nt = torch.nested.nested_tensor_from_jagged( - values, - offsets=offsets, - lengths=lengths) # 3D nested tensor + values, offsets=offsets, lengths=lengths + ) # 3D nested tensor tensor_list = [] for i in range(offsets.shape[0] - 1): @@ -3850,10 +4402,8 @@ def test_unbind_lengths_ragged_idx_1(self, device): lengths = torch.tensor([6, 2, 1, 2], device=device) ragged_idx = 1 nt = torch.nested._internal.nested_tensor.NestedTensor( - values, - offsets=offsets, - lengths=lengths, - _ragged_idx=ragged_idx) # 4D nested tensor + values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx + ) # 4D nested tensor tensor_list = [] for i in range(offsets.shape[0] - 1): @@ -3871,28 +4421,23 @@ def test_unbind_lengths_ragged_idx_equals_2_bad_dim(self, device): lengths = torch.tensor([6, 2, 1, 2], device=device) ragged_idx = 2 nt = torch.nested._internal.nested_tensor.NestedTensor( - values, - offsets=offsets, - lengths=lengths, - _ragged_idx=ragged_idx) # 4D nested tensor + values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx + ) # 4D nested tensor self.assertRaisesRegex( RuntimeError, r"unbind\(\): nested tensor offsets and lengths.*", - lambda: nt.unbind() + lambda: nt.unbind(), ) - def test_unbind_lengths_ragged_idx_2(self, device): values = torch.randn(16, 8, 128, device=device) offsets = torch.tensor([0, 2, 4, 8], device=device) lengths = torch.tensor([2, 1, 3], device=device) ragged_idx = 2 nt = torch.nested._internal.nested_tensor.NestedTensor( - values, - offsets=offsets, - lengths=lengths, - _ragged_idx=ragged_idx) # 4D nested tensor + values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx + ) # 4D nested tensor tensor_list = [] for i in range(offsets.shape[0] - 1): @@ -3910,10 +4455,8 @@ def test_unbind_lengths_ragged_idx_3(self, device): lengths = torch.tensor([50, 28], device=device) ragged_idx = 3 nt = torch.nested._internal.nested_tensor.NestedTensor( - values, - offsets=offsets, - lengths=lengths, - _ragged_idx=ragged_idx) # 4D nested tensor + values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx + ) # 4D nested tensor tensor_list = [] for i in range(offsets.shape[0] - 1): @@ -3925,17 +4468,17 @@ def test_unbind_lengths_ragged_idx_3(self, device): for i, t in enumerate(out): self.assertEqual(t, tensor_list[i]) - @skipIfTorchDynamo("TorchDynamo raises an error for ragged_idx == 0 earlier than Torch") + @skipIfTorchDynamo( + "TorchDynamo raises an error for ragged_idx == 0 earlier than Torch" + ) def test_unbind_lengths_ragged_idx_0(self, device): values = torch.randn(16, 8, 128, device=device) offsets = torch.tensor([0, 100, 128], device=device) lengths = torch.tensor([50, 28], device=device) ragged_idx = 0 nt = torch.nested._internal.nested_tensor.NestedTensor( - values, - offsets=offsets, - lengths=lengths, - _ragged_idx=ragged_idx) # 4D nested tensor + values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx + ) # 4D nested tensor tensor_list = [] for i in range(offsets.shape[0] - 1): @@ -3944,7 +4487,7 @@ def test_unbind_lengths_ragged_idx_0(self, device): self.assertRaisesRegex( RuntimeError, r"unbind\(\): nested tensor.*out of bounds", - lambda: nt.unbind() + lambda: nt.unbind(), ) @xfailIfTorchDynamo @@ -3975,15 +4518,12 @@ def test_narrow(self, device): lengths = torch.tensor([3, 2, 2, 1, 5], device=device, dtype=torch.int64) buffer = ( torch.arange(0, 10, device=device, dtype=torch.int64) - .unsqueeze(0).expand(5, -1).clone().detach() - ) - nt = torch.nested.narrow( - buffer, - 1, - starts, - lengths, - layout=torch.jagged + .unsqueeze(0) + .expand(5, -1) + .clone() + .detach() ) + nt = torch.nested.narrow(buffer, 1, starts, lengths, layout=torch.jagged) self.assertTrue(nt._is_view() and nt._base is buffer) @@ -3993,8 +4533,10 @@ def test_narrow(self, device): # self.assertEqual(torch.arange(starts[i], starts[i] + lengths[i], device=device, dtype=torch.int64), unbinded_nt[i]) for i in range(starts.shape[0]): self.assertEqual( - torch.arange(starts[i], starts[i] + lengths[i], device=device, dtype=torch.int64), - nt.values()[nt.offsets()[i]:(nt.offsets()[i] + nt.lengths()[i])] + torch.arange( + starts[i], starts[i] + lengths[i], device=device, dtype=torch.int64 + ), + nt.values()[nt.offsets()[i] : (nt.offsets()[i] + nt.lengths()[i])], ) def test_is_contiguous(self, device): @@ -4005,23 +4547,20 @@ def test_is_contiguous(self, device): starts_nc = torch.tensor([0, 1, 2, 3, 4], device=device, dtype=torch.int64) lengths_nc = torch.tensor([3, 2, 2, 1, 5], device=device, dtype=torch.int64) - narrow_base = torch.arange(0, 10, device=device, dtype=torch.int64).unsqueeze(0).expand(5, -1).clone() + narrow_base = ( + torch.arange(0, 10, device=device, dtype=torch.int64) + .unsqueeze(0) + .expand(5, -1) + .clone() + ) nt_noncontiguous = torch.nested.narrow( - narrow_base, - 1, - starts_nc, - lengths_nc, - layout=torch.jagged + narrow_base, 1, starts_nc, lengths_nc, layout=torch.jagged ) starts_c = torch.tensor([1, 0, 0, 0, 0], device=device, dtype=torch.int64) lengths_c = torch.tensor([9, 10, 10, 10, 8], device=device, dtype=torch.int64) nt_contiguous_narrow = torch.nested.narrow( - narrow_base, - 1, - starts_c, - lengths_c, - layout=torch.jagged + narrow_base, 1, starts_c, lengths_c, layout=torch.jagged ) # Test contiguous case @@ -4032,23 +4571,36 @@ def test_is_contiguous(self, device): assert nt_contiguous_narrow.is_contiguous() # Test querying by memory_format - self.assertTrue(nt_contiguous.is_contiguous(memory_format=torch.contiguous_format)) - self.assertTrue(not nt_noncontiguous.is_contiguous(memory_format=torch.contiguous_format)) - self.assertTrue(nt_contiguous_narrow.is_contiguous(memory_format=torch.contiguous_format)) + self.assertTrue( + nt_contiguous.is_contiguous(memory_format=torch.contiguous_format) + ) + self.assertTrue( + not nt_noncontiguous.is_contiguous(memory_format=torch.contiguous_format) + ) + self.assertTrue( + nt_contiguous_narrow.is_contiguous(memory_format=torch.contiguous_format) + ) def test_layout_under_torch_dispatch_mode(self): - from torch.testing._internal.logging_tensor import capture_logs_with_logging_tensor_mode + from torch.testing._internal.logging_tensor import ( + capture_logs_with_logging_tensor_mode, + ) - nt = random_nt_from_dims([2, None, 3], torch.device('cpu'), torch.float32, layout=torch.jagged) + nt = random_nt_from_dims( + [2, None, 3], torch.device("cpu"), torch.float32, layout=torch.jagged + ) with capture_logs_with_logging_tensor_mode(): self.assertEqual(nt.layout, torch.jagged) @skipIfTorchDynamo("Not a suitable test for TorchDynamo") - @parametrize("func", [torch.empty_like, torch.randn_like], - name_fn=lambda f: f.__name__) + @parametrize( + "func", [torch.empty_like, torch.randn_like], name_fn=lambda f: f.__name__ + ) def test_like_shape(self, func): - nt = random_nt_from_dims([2, None, 3], torch.device('cpu'), torch.float32, layout=torch.jagged) + nt = random_nt_from_dims( + [2, None, 3], torch.device("cpu"), torch.float32, layout=torch.jagged + ) nt_like = func(nt) for nt_ub in nt_like.unbind(): @@ -4056,10 +4608,13 @@ def test_like_shape(self, func): self.assertEqual(nt_ub.shape, t_like.shape) @skipIfTorchDynamo("Not a suitable test for TorchDynamo") - @parametrize("func", [torch.ones_like, torch.zeros_like], - name_fn=lambda f: f.__name__) + @parametrize( + "func", [torch.ones_like, torch.zeros_like], name_fn=lambda f: f.__name__ + ) def test_like_value(self, func): - nt = random_nt_from_dims([2, None, 3], torch.device('cpu'), torch.float32, layout=torch.jagged) + nt = random_nt_from_dims( + [2, None, 3], torch.device("cpu"), torch.float32, layout=torch.jagged + ) nt_like = func(nt) for nt_ub in nt_like.unbind(): @@ -4095,8 +4650,13 @@ def check_nt_equality(x, y): def test_to_copy(self, device): nt = torch.nested.nested_tensor( - [torch.randn(i + 2, 3, 4, requires_grad=True, dtype=torch.float64, device=device) - for i in range(3)], layout=torch.jagged + [ + torch.randn( + i + 2, 3, 4, requires_grad=True, dtype=torch.float64, device=device + ) + for i in range(3) + ], + layout=torch.jagged, ) nt_copy_dtype = torch.ops.aten._to_copy(nt, dtype=torch.float16) @@ -4123,16 +4683,20 @@ def test_profiler_sequence_nr(self): fwd_seq_nrs = [] for evt in prof.events(): - if "linear" in evt.name.lower() and "backward" not in evt.name.lower() and evt.sequence_nr != -1: + if ( + "linear" in evt.name.lower() + and "backward" not in evt.name.lower() + and evt.sequence_nr != -1 + ): fwd_seq_nrs.append(evt.sequence_nr) bwd_seq_nrs = [] for evt in prof.events(): if ( - "linear" in evt.name.lower() and - "backward" in evt.name.lower() and - "evaluate_function" not in evt.name.lower() and - evt.sequence_nr != -1 + "linear" in evt.name.lower() + and "backward" in evt.name.lower() + and "evaluate_function" not in evt.name.lower() + and evt.sequence_nr != -1 ): bwd_seq_nrs.append(evt.sequence_nr) @@ -4147,7 +4711,12 @@ def test_profiler_sequence_nr(self): def test_is_same_size(self, device): def get_3_tensors(): - return [torch.randn(i + 2, 3, 4, requires_grad=True, dtype=torch.float64, device=device) for i in range(3)] + return [ + torch.randn( + i + 2, 3, 4, requires_grad=True, dtype=torch.float64, device=device + ) + for i in range(3) + ] nt1, offsets1 = jagged_from_list(get_3_tensors(), None) nt2, offsets1 = jagged_from_list(get_3_tensors(), offsets1) @@ -4209,7 +4778,6 @@ def check_results(ref_fn, res_fn, args): res_fn(values, like_values), ) - def fn(values, same_size): return values + same_size @@ -4237,8 +4805,12 @@ def fn(values, same_size): TEST_WITH_ROCM, "ROCm doesn't support flash attention or mem_efficient attention for NT", ) - @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32] if - SM80OrLater else [torch.float16, torch.float32]) + @parametrize( + "dtype", + [torch.float16, torch.bfloat16, torch.float32] + if SM80OrLater + else [torch.float16, torch.float32], + ) def test_sdpa(self, device, dtype): batch_size = 1 emb_dims = 128 @@ -4248,27 +4820,63 @@ def test_sdpa(self, device, dtype): sen1 = torch.randn(11, emb_dims, dtype=dtype, device=device) sen2 = torch.randn(13, emb_dims, dtype=dtype, device=device) - query = torch.nn.Linear(emb_dims, emb_dims, bias=False, device=device, dtype=dtype) - key = torch.nn.Linear(emb_dims, emb_dims, bias=False, device=device, dtype=dtype) - value = torch.nn.Linear(emb_dims, emb_dims, bias=False, device=device, dtype=dtype) + query = torch.nn.Linear( + emb_dims, emb_dims, bias=False, device=device, dtype=dtype + ) + key = torch.nn.Linear( + emb_dims, emb_dims, bias=False, device=device, dtype=dtype + ) + value = torch.nn.Linear( + emb_dims, emb_dims, bias=False, device=device, dtype=dtype + ) # Simplest case: 1 sentence, no batching x_d1 = sen1.unsqueeze(0) x_nt = torch.nested.as_nested_tensor([sen1], layout=torch.jagged) # See note below for why we detach here. - q_d1 = query(x_d1).view(batch_size, -1, n_heads, head_dims).detach().requires_grad_(True) + q_d1 = ( + query(x_d1) + .view(batch_size, -1, n_heads, head_dims) + .detach() + .requires_grad_(True) + ) q_d1_t = q_d1.transpose(1, 2) - k_d1 = key(x_d1).view(batch_size, -1, n_heads, head_dims).detach().requires_grad_(True) + k_d1 = ( + key(x_d1) + .view(batch_size, -1, n_heads, head_dims) + .detach() + .requires_grad_(True) + ) k_d1_t = k_d1.transpose(1, 2) - v_d1 = value(x_d1).view(batch_size, -1, n_heads, head_dims).detach().requires_grad_(True) + v_d1 = ( + value(x_d1) + .view(batch_size, -1, n_heads, head_dims) + .detach() + .requires_grad_(True) + ) v_d1_t = v_d1.transpose(1, 2) - q_nt = query(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).detach().requires_grad_(True) + q_nt = ( + query(x_nt) + .view(*x_nt.size()[0:2], n_heads, head_dims) + .detach() + .requires_grad_(True) + ) q_nt_t = q_nt.transpose(1, 2) - k_nt = key(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).detach().requires_grad_(True) + k_nt = ( + key(x_nt) + .view(*x_nt.size()[0:2], n_heads, head_dims) + .detach() + .requires_grad_(True) + ) k_nt_t = k_nt.transpose(1, 2) - v_nt = value(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).detach().requires_grad_(True) + v_nt = ( + value(x_nt) + .view(*x_nt.size()[0:2], n_heads, head_dims) + .detach() + .requires_grad_(True) + ) v_nt_t = v_nt.transpose(1, 2) # High Precision Math Reference @@ -4278,11 +4886,15 @@ def test_sdpa(self, device, dtype): q_d1_f32_t = q_d1_f32.transpose(1, 2) k_d1_f32_t = k_d1_f32.transpose(1, 2) v_d1_f32_t = v_d1_f32.transpose(1, 2) - out_ref = torch.ops.aten._scaled_dot_product_attention_math(q_d1_f32_t, k_d1_f32_t, v_d1_f32_t)[0] + out_ref = torch.ops.aten._scaled_dot_product_attention_math( + q_d1_f32_t, k_d1_f32_t, v_d1_f32_t + )[0] grads_ref = torch.autograd.grad(out_ref.sum(), (q_d1_f32, k_d1_f32, v_d1_f32)) # Low Precision Math Reference - out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(q_d1_t, k_d1_t, v_d1_t)[0] + out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math( + q_d1_t, k_d1_t, v_d1_t + )[0] grads_lp_ref = torch.autograd.grad(out_lp_ref.sum(), (q_d1, k_d1, v_d1)) # Compute tolerances @@ -4293,10 +4905,19 @@ def test_sdpa(self, device, dtype): grad_atols = [grad_q_ref_atol, grad_k_ref_atol, grad_v_ref_atol] grad_rtols = [grad_q_ref_rtol, grad_k_ref_rtol, grad_v_ref_rtol] - attn_d1 = torch.nn.functional.scaled_dot_product_attention(q_d1_t, k_d1_t, v_d1_t).transpose(1, 2) - attn_nt = torch.nn.functional.scaled_dot_product_attention(q_nt_t, k_nt_t, v_nt_t).transpose(1, 2) + attn_d1 = torch.nn.functional.scaled_dot_product_attention( + q_d1_t, k_d1_t, v_d1_t + ).transpose(1, 2) + attn_nt = torch.nn.functional.scaled_dot_product_attention( + q_nt_t, k_nt_t, v_nt_t + ).transpose(1, 2) - self.assertEqual(attn_d1, attn_nt.unbind()[0].unsqueeze(0), atol=output_ref_atol, rtol=output_ref_rtol) + self.assertEqual( + attn_d1, + attn_nt.unbind()[0].unsqueeze(0), + atol=output_ref_atol, + rtol=output_ref_rtol, + ) # Simple case: 2 sentences, no extra params x_d2 = sen2.unsqueeze(0) @@ -4305,46 +4926,106 @@ def test_sdpa(self, device, dtype): # NB: we make sure the leaf tensor we compute gradients for is the view-ed tensor before # it is transposed. This is because today we cannot backward through view or unbind a # transposed tensor. - q_d2 = query(x_d2).view(batch_size, -1, n_heads, head_dims).detach().requires_grad_(True) + q_d2 = ( + query(x_d2) + .view(batch_size, -1, n_heads, head_dims) + .detach() + .requires_grad_(True) + ) q_d2_t = q_d2.transpose(1, 2) - k_d2 = key(x_d2).view(batch_size, -1, n_heads, head_dims).detach().requires_grad_(True) + k_d2 = ( + key(x_d2) + .view(batch_size, -1, n_heads, head_dims) + .detach() + .requires_grad_(True) + ) k_d2_t = k_d2.transpose(1, 2) - v_d2 = value(x_d2).view(batch_size, -1, n_heads, head_dims).detach().requires_grad_(True) + v_d2 = ( + value(x_d2) + .view(batch_size, -1, n_heads, head_dims) + .detach() + .requires_grad_(True) + ) v_d2_t = v_d2.transpose(1, 2) - q_nt = query(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).detach().requires_grad_(True) + q_nt = ( + query(x_nt) + .view(*x_nt.size()[0:2], n_heads, head_dims) + .detach() + .requires_grad_(True) + ) q_nt_t = q_nt.transpose(1, 2) - k_nt = key(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).detach().requires_grad_(True) + k_nt = ( + key(x_nt) + .view(*x_nt.size()[0:2], n_heads, head_dims) + .detach() + .requires_grad_(True) + ) k_nt_t = k_nt.transpose(1, 2) - v_nt = value(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).detach().requires_grad_(True) + v_nt = ( + value(x_nt) + .view(*x_nt.size()[0:2], n_heads, head_dims) + .detach() + .requires_grad_(True) + ) v_nt_t = v_nt.transpose(1, 2) - attn_d2 = torch.nn.functional.scaled_dot_product_attention(q_d2_t, k_d2_t, v_d2_t).transpose(1, 2) + attn_d2 = torch.nn.functional.scaled_dot_product_attention( + q_d2_t, k_d2_t, v_d2_t + ).transpose(1, 2) d1_grads = torch.autograd.grad(attn_d1.sum(), (q_d1, k_d1, v_d1)) d2_grads = torch.autograd.grad(attn_d2.sum(), (q_d2, k_d2, v_d2)) def check_forward_backward(): - attn_nt = torch.nn.functional.scaled_dot_product_attention(q_nt_t, k_nt_t, v_nt_t).transpose(1, 2) + attn_nt = torch.nn.functional.scaled_dot_product_attention( + q_nt_t, k_nt_t, v_nt_t + ).transpose(1, 2) attn_nts = attn_nt.unbind() - self.assertEqual(attn_d1, attn_nts[0].unsqueeze(0), atol=output_ref_atol, rtol=output_ref_rtol) - self.assertEqual(attn_d2, attn_nts[1].unsqueeze(0), atol=output_ref_atol, rtol=output_ref_rtol) + self.assertEqual( + attn_d1, + attn_nts[0].unsqueeze(0), + atol=output_ref_atol, + rtol=output_ref_rtol, + ) + self.assertEqual( + attn_d2, + attn_nts[1].unsqueeze(0), + atol=output_ref_atol, + rtol=output_ref_rtol, + ) nt_grads = torch.autograd.grad(attn_nt.values().sum(), (q_nt, k_nt, v_nt)) - for nt_grad, d1_grad, d2_grad, grad_atol, grad_rtol in zip(nt_grads, d1_grads, d2_grads, grad_atols, grad_rtols): + for nt_grad, d1_grad, d2_grad, grad_atol, grad_rtol in zip( + nt_grads, d1_grads, d2_grads, grad_atols, grad_rtols + ): unbound_nt_grads = nt_grad.unbind() - self.assertEqual(d1_grad, unbound_nt_grads[0].unsqueeze(0), atol=grad_atol, rtol=grad_rtol) - self.assertEqual(d2_grad, unbound_nt_grads[1].unsqueeze(0), atol=grad_atol, rtol=grad_rtol) + self.assertEqual( + d1_grad, + unbound_nt_grads[0].unsqueeze(0), + atol=grad_atol, + rtol=grad_rtol, + ) + self.assertEqual( + d2_grad, + unbound_nt_grads[1].unsqueeze(0), + atol=grad_atol, + rtol=grad_rtol, + ) # Default check_forward_backward() # Test dispatcher works by calling only mem-effn and math (as they are safe for all devices) - with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=True, enable_math=True): + with torch.backends.cuda.sdp_kernel( + enable_flash=False, enable_mem_efficient=True, enable_math=True + ): check_forward_backward() # Test math fallback - with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): + with torch.backends.cuda.sdp_kernel( + enable_flash=False, enable_mem_efficient=False, enable_math=True + ): # Math fallback doesn't work with bfloat16 on CUDA because # "group_gemm_dispatch" not implemented for 'BFloat16' if not (str(device).startswith("cuda") and dtype == torch.bfloat16): @@ -4356,8 +5037,13 @@ def check_forward_backward(): # Guarding with sqrt() doesn't work on ROCm? @skipCUDAIfRocm @onlyCUDA - @dtypes(*([torch.float16, torch.bfloat16, torch.float32] if SM80OrLater - else [torch.float16, torch.float32])) + @dtypes( + *( + [torch.float16, torch.bfloat16, torch.float32] + if SM80OrLater + else [torch.float16, torch.float32] + ) + ) def test_sdpa_compile(self, device, dtype): batch_size = 1 emb_dims = 1024 @@ -4367,9 +5053,15 @@ def test_sdpa_compile(self, device, dtype): sen1 = torch.randn(11, emb_dims, dtype=dtype, device=device) sen2 = torch.randn(13, emb_dims, dtype=dtype, device=device) - query = torch.nn.Linear(emb_dims, emb_dims, bias=False, device=device, dtype=dtype) - key = torch.nn.Linear(emb_dims, emb_dims, bias=False, device=device, dtype=dtype) - value = torch.nn.Linear(emb_dims, emb_dims, bias=False, device=device, dtype=dtype) + query = torch.nn.Linear( + emb_dims, emb_dims, bias=False, device=device, dtype=dtype + ) + key = torch.nn.Linear( + emb_dims, emb_dims, bias=False, device=device, dtype=dtype + ) + value = torch.nn.Linear( + emb_dims, emb_dims, bias=False, device=device, dtype=dtype + ) # Simplest case: 1 sentence, no batching x_d1 = sen1.unsqueeze(0) @@ -4383,28 +5075,61 @@ def test_sdpa_compile(self, device, dtype): k_d2 = key(x_d2).view(batch_size, -1, n_heads, head_dims).transpose(1, 2) v_d2 = value(x_d2).view(batch_size, -1, n_heads, head_dims).transpose(1, 2) - q_nt = query(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).detach().transpose(1, 2) - k_nt = key(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).detach().transpose(1, 2) - v_nt = value(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).detach().transpose(1, 2) + q_nt = ( + query(x_nt) + .view(*x_nt.size()[0:2], n_heads, head_dims) + .detach() + .transpose(1, 2) + ) + k_nt = ( + key(x_nt) + .view(*x_nt.size()[0:2], n_heads, head_dims) + .detach() + .transpose(1, 2) + ) + v_nt = ( + value(x_nt) + .view(*x_nt.size()[0:2], n_heads, head_dims) + .detach() + .transpose(1, 2) + ) # High Precision Math Reference q_d1_f32 = q_d1.to(torch.float32) k_d1_f32 = k_d1.to(torch.float32) v_d1_f32 = v_d1.to(torch.float32) - out_ref = torch.ops.aten._scaled_dot_product_attention_math(q_d1_f32, k_d1_f32, v_d1_f32)[0] + out_ref = torch.ops.aten._scaled_dot_product_attention_math( + q_d1_f32, k_d1_f32, v_d1_f32 + )[0] # Low Precision Math Reference - out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(q_d1, k_d1, v_d1)[0] + out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math( + q_d1, k_d1, v_d1 + )[0] output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref) - attn_d1 = torch.nn.functional.scaled_dot_product_attention(q_d1, k_d1, v_d1).transpose(1, 2) - attn_d2 = torch.nn.functional.scaled_dot_product_attention(q_d2, k_d2, v_d2).transpose(1, 2) + attn_d1 = torch.nn.functional.scaled_dot_product_attention( + q_d1, k_d1, v_d1 + ).transpose(1, 2) + attn_d2 = torch.nn.functional.scaled_dot_product_attention( + q_d2, k_d2, v_d2 + ).transpose(1, 2) compiled_sdpa = torch.compile(torch.nn.functional.scaled_dot_product_attention) attn_nt = compiled_sdpa(q_nt, k_nt, v_nt).transpose(1, 2) attn_nts = attn_nt.unbind() - self.assertEqual(attn_d1, attn_nts[0].unsqueeze(0), atol=output_ref_atol, rtol=output_ref_rtol) - self.assertEqual(attn_d2, attn_nts[1].unsqueeze(0), atol=output_ref_atol, rtol=output_ref_rtol) + self.assertEqual( + attn_d1, + attn_nts[0].unsqueeze(0), + atol=output_ref_atol, + rtol=output_ref_rtol, + ) + self.assertEqual( + attn_d2, + attn_nts[1].unsqueeze(0), + atol=output_ref_atol, + rtol=output_ref_rtol, + ) @dtypes(torch.float32, torch.double, torch.half) def test_sdpa_with_constant_sequence_length(self, device, dtype): @@ -4414,14 +5139,17 @@ def test_sdpa_with_constant_sequence_length(self, device, dtype): # S: (constant) sequence length # D: embedding size query = random_nt_from_dims( - [4, None, 8, 10], device=device, dtype=dtype, layout=torch.jagged) + [4, None, 8, 10], device=device, dtype=dtype, layout=torch.jagged + ) key = random_nt_from_similar(query) value = random_nt_from_similar(query) output = F.scaled_dot_product_attention(query, key, value) self.assertTrue(isinstance(output, NestedTensor)) # should be equivalent to just running the buffers through - output_dense = F.scaled_dot_product_attention(query._values, key._values, value._values) + output_dense = F.scaled_dot_product_attention( + query._values, key._values, value._values + ) self.assertEqual(output._values, output_dense) # Doesn't work until we have real views @@ -4429,20 +5157,28 @@ def test_sdpa_with_constant_sequence_length(self, device, dtype): @onlyCUDA @unittest.skipIf( not PLATFORM_SUPPORTS_FUSED_ATTENTION, - "Platform doesn't support flash or mem-efficient attention" + "Platform doesn't support flash or mem-efficient attention", + ) + @dtypes( + *( + [torch.float16, torch.bfloat16, torch.float32] + if SM80OrLater + else [torch.float16, torch.float32] + ) ) - @dtypes(*([torch.float16, torch.bfloat16, torch.float32] if SM80OrLater - else [torch.float16, torch.float32])) def test_sdpa_with_packed_in_proj(self, device, dtype): # shape (B, *, D) input_packed = random_nt_from_dims( - [5, None, 10], device=device, dtype=dtype, layout=torch.jagged) + [5, None, 10], device=device, dtype=dtype, layout=torch.jagged + ) # Do input projection. num_heads = 2 # should be multiple of 4 for efficient kernels (e.g. flash / mem-efficient) head_dim = 8 - qkv_linear = torch.nn.Linear(10, num_heads * head_dim * 3).to(device=device, dtype=dtype) + qkv_linear = torch.nn.Linear(10, num_heads * head_dim * 3).to( + device=device, dtype=dtype + ) def in_proj(input_packed, qkv_linear=qkv_linear): qkv_post_proj = qkv_linear(input_packed) @@ -4458,18 +5194,22 @@ def in_proj(input_packed, qkv_linear=qkv_linear): # compare to individually running unbound components through for in_component, out_component in zip( - input_packed.unbind(), - output.transpose(-2, -3).unbind() + input_packed.unbind(), output.transpose(-2, -3).unbind() ): q, k, v = in_proj(in_component) out = F.scaled_dot_product_attention(q, k, v).transpose(-2, -3) # Low Precision Math Reference - out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math( - q, k, v)[0].transpose(-2, -3) - output_ref_atol, output_ref_rtol = get_tolerances(out, out_lp_ref, fudge_factor=2) + out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(q, k, v)[ + 0 + ].transpose(-2, -3) + output_ref_atol, output_ref_rtol = get_tolerances( + out, out_lp_ref, fudge_factor=2 + ) - self.assertEqual(out, out_component, atol=output_ref_atol, rtol=output_ref_rtol) + self.assertEqual( + out, out_component, atol=output_ref_atol, rtol=output_ref_rtol + ) @skipIfTorchDynamo("SDPA test compiles internally") @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @@ -4477,8 +5217,13 @@ def in_proj(input_packed, qkv_linear=qkv_linear): # mha_varlen_fwd not supported on ROCm @skipCUDAIfRocm @onlyCUDA - @dtypes(*([torch.float16, torch.bfloat16, torch.float32] if SM80OrLater - else [torch.float16, torch.float32])) + @dtypes( + *( + [torch.float16, torch.bfloat16, torch.float32] + if SM80OrLater + else [torch.float16, torch.float32] + ) + ) def test_sdpa_backwards(self, device, dtype): values = torch.randn(9, 3, 256, requires_grad=True, device=device, dtype=dtype) offsets = torch.tensor([0, 1, 3, 5, 9], device=device, dtype=torch.int64) @@ -4520,7 +5265,6 @@ def __init__(self): self.linear = torch.nn.Linear(d2, d3, device=device) def forward(self, query, value, offsets): - value = self.linear(value) key = convert_jagged_to_nested_tensor(value, offsets, max_length_1) value = convert_jagged_to_nested_tensor(value, offsets, max_length_2) @@ -4595,10 +5339,11 @@ def test_fbgemm_jagged_to_padded_dense_kernels(self, device, dtype): # should be equivalent to the original values self.assertEqual(values, output_jagged) + instantiate_parametrized_tests(TestNestedTensor) instantiate_device_type_tests(TestNestedTensorDeviceType, globals()) instantiate_device_type_tests(TestNestedTensorAutograd, globals()) instantiate_device_type_tests(TestNestedTensorSubclass, globals()) -if __name__ == '__main__': +if __name__ == "__main__": run_tests() From 45dccfddcd8fce804f50075484421ade27f1f021 Mon Sep 17 00:00:00 2001 From: eqy Date: Tue, 11 Jun 2024 19:22:18 +0000 Subject: [PATCH 646/706] [cuDNN][SDPA] Support different key, value dimension in cuDNN SDPA (#128350) CC @vedaanta-nvidia @drisspg Pull Request resolved: https://github.com/pytorch/pytorch/pull/128350 Approved by: https://github.com/Skylion007 --- aten/src/ATen/native/cudnn/MHA.cpp | 58 +++++++++++++------ aten/src/ATen/native/cudnn/MHA.h | 6 +- .../native/transformers/cuda/attention.cu | 7 ++- .../transformers/cuda/attention_backward.cu | 8 +-- test/test_transformers.py | 29 +++++++++- 5 files changed, 80 insertions(+), 28 deletions(-) diff --git a/aten/src/ATen/native/cudnn/MHA.cpp b/aten/src/ATen/native/cudnn/MHA.cpp index 4f992098aea8..ab19b5d68a90 100644 --- a/aten/src/ATen/native/cudnn/MHA.cpp +++ b/aten/src/ATen/native/cudnn/MHA.cpp @@ -13,7 +13,8 @@ void run_cudnn_SDP_fprop( int64_t h, int64_t s_q, int64_t s_kv, - int64_t d, + int64_t d_qk, + int64_t d_v, float scaling_factor, bool isTraining, bool is_causal, @@ -34,7 +35,8 @@ void run_cudnn_SDP_bprop( int64_t h, int64_t s_q, int64_t s_kv, - int64_t d, + int64_t d_qk, + int64_t d_v, float scaling_factor, bool is_causal, float dropout_probability, @@ -128,7 +130,8 @@ struct MHAParams { int64_t h; int64_t s_q; int64_t s_kv; - int64_t d; + int64_t d_qk; + int64_t d_v; double dropout_probability; bool is_causal; bool return_softmaxstats; @@ -140,7 +143,8 @@ void setMHAParams( int64_t h, int64_t s_q, int64_t s_kv, - int64_t d, + int64_t d_qk, + int64_t d_v, const Tensor& q, const Tensor& k, const Tensor& v, @@ -155,7 +159,8 @@ void setMHAParams( } params.b = b; params.h = h; - params.d = d; + params.d_qk = d_qk; + params.d_v = d_v; params.s_q = s_q; params.s_kv = s_kv; params.dropout_probability = dropout_probability; @@ -193,7 +198,8 @@ struct MHACacheKeyWrapper : ParamsWrapper { int64_t h, int64_t s_q, int64_t s_kv, - int64_t d, + int64_t d_qk, + int64_t d_v, const Tensor& q, const Tensor& k, const Tensor& v, @@ -206,7 +212,8 @@ struct MHACacheKeyWrapper : ParamsWrapper { h, s_q, s_kv, - d, + d_qk, + d_v, q, k, v, @@ -249,7 +256,8 @@ auto build_graph_and_tensors( int64_t h, int64_t s_q, int64_t s_kv, - int64_t d, + int64_t d_qk, + int64_t d_v, float scaling_factor, bool return_softmaxstats, bool is_causal, @@ -383,7 +391,8 @@ auto build_graph_and_tensors_backward( int64_t h, int64_t s_q, int64_t s_kv, - int64_t d, + int64_t d_qk, + int64_t d_v, float scaling_factor, bool is_causal, float dropout_probability, @@ -514,7 +523,8 @@ void run_cudnn_SDP_fprop( int64_t h, int64_t s_q, int64_t s_kv, - int64_t d, + int64_t d_qk, + int64_t d_v, float scaling_factor, bool return_softmaxstats, bool is_causal, @@ -528,7 +538,7 @@ void run_cudnn_SDP_fprop( Tensor& dropoutoffset) { cudnnHandle_t handle = getCudnnHandle(); o = at::empty_strided( - {b, h, s_q, d}, {s_q * h * d, d, h * d, 1}, q.options()); + {b, h, s_q, d_v}, {s_q * h * d_v, d_v, h * d_v, 1}, q.options()); if (return_softmaxstats) { // TODO(eqy): verify that this is correct softmaxstats = at::empty({b, h, s_q}, q.options().dtype(kFloat)); @@ -539,7 +549,8 @@ void run_cudnn_SDP_fprop( h, s_q, s_kv, - d, + d_qk, + d_v, q, k, v, @@ -556,7 +567,8 @@ void run_cudnn_SDP_fprop( h, s_q, s_kv, - d, + d_qk, + d_v, scaling_factor, return_softmaxstats, is_causal, @@ -599,7 +611,8 @@ void run_cudnn_SDP_bprop( int64_t h, int64_t s_q, int64_t s_kv, - int64_t d, + int64_t d_qk, + int64_t d_v, float scaling_factor, bool is_causal, float dropout_probability, @@ -623,7 +636,18 @@ void run_cudnn_SDP_bprop( } cudnnHandle_t handle = getCudnnHandle(); auto key = MHACacheKeyWrapper( - b, h, s_q, s_kv, d, q, k, v, dropout_probability, is_causal, true); + b, + h, + s_q, + s_kv, + d_qk, + d_v, + q, + k, + v, + dropout_probability, + is_causal, + true); auto graph_and_tensors_backward_ptr = mhagraphbackwardcache.find(key); graph_and_tensors_backward graph_and_tensors_backward_values; if (graph_and_tensors_backward_ptr) { @@ -634,7 +658,8 @@ void run_cudnn_SDP_bprop( h, s_q, s_kv, - d, + d_qk, + d_v, scaling_factor, is_causal, dropout_probability, @@ -684,5 +709,4 @@ void run_cudnn_SDP_bprop( } // namespace native } // namespace at - #endif diff --git a/aten/src/ATen/native/cudnn/MHA.h b/aten/src/ATen/native/cudnn/MHA.h index 0406cf783dc5..8b9315a5a3d8 100644 --- a/aten/src/ATen/native/cudnn/MHA.h +++ b/aten/src/ATen/native/cudnn/MHA.h @@ -9,7 +9,8 @@ void run_cudnn_SDP_fprop( int64_t h, int64_t s_q, int64_t s_kv, - int64_t d, + int64_t d_k, + int64_t d_v, float scaling_factor, bool isTraining, bool is_causal, @@ -27,7 +28,8 @@ void run_cudnn_SDP_bprop( int64_t h, int64_t s_q, int64_t s_kv, - int64_t d, + int64_t d_k, + int64_t d_v, float scaling_factor, bool is_causal, float dropout_probability, diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu index 655efeec5b42..3e307b29512f 100644 --- a/aten/src/ATen/native/transformers/cuda/attention.cu +++ b/aten/src/ATen/native/transformers/cuda/attention.cu @@ -764,8 +764,8 @@ std::tuple _scaled_dot_product_cudnn_attention_c const int64_t batch_size = query.size(0); const int64_t num_heads = query.size(1); const int64_t max_seqlen_batch_q = query.size(2); - const int64_t head_dim = query.size(3); - + const int64_t head_dim_qk = query.size(3); + const int64_t head_dim_v = value.size(3); const int64_t max_seqlen_batch_k = key.size(2); const int64_t max_seqlen_batch_v = value.size(2); TORCH_CHECK( @@ -806,7 +806,8 @@ std::tuple _scaled_dot_product_cudnn_attention_c num_heads/*int64_t h*/, max_seqlen_batch_q/*int64_t s_q*/, max_seqlen_batch_k/*int64_t s_kv*/, - head_dim/*int64_t d*/, + head_dim_qk/*int64_t d_qk*/, + head_dim_v/*int64_t d_v*/, softmax_scale/*float scaling_factor*/, compute_logsumexp/* bool */, is_causal/* bool */, diff --git a/aten/src/ATen/native/transformers/cuda/attention_backward.cu b/aten/src/ATen/native/transformers/cuda/attention_backward.cu index bc0ce3d25c03..14d389bf8653 100644 --- a/aten/src/ATen/native/transformers/cuda/attention_backward.cu +++ b/aten/src/ATen/native/transformers/cuda/attention_backward.cu @@ -194,12 +194,11 @@ std::tuple _scaled_dot_product_cudnn_attention_backward_ const int64_t batch_size = query.size(0); const int64_t num_heads = query.size(1); - const int64_t head_dim = query.size(3); + const int64_t head_dim_qk = query.size(3); + const int64_t head_dim_v = value.size(3); const int64_t max_seqlen_batch_q = query.size(1); const int64_t max_seqlen_batch_k = key.size(1); - const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked(); - auto dq = at::empty_like(query); auto dk = at::empty_like(key); auto dv = at::empty_like(value); @@ -207,7 +206,8 @@ std::tuple _scaled_dot_product_cudnn_attention_backward_ num_heads /*int64_t h*/, max_seqlen_batch_q /*int64_t s_q*/, max_seqlen_batch_k /*int64_t s_kv*/, - head_dim /*int64_t d*/, + head_dim_qk /*int64_t d_qk*/, + head_dim_v /*int64_t d_v*/, softmax_scale /*float scaling_factor*/, is_causal /*bool is_causal*/, dropout_p /*float dropout_probability*/, diff --git a/test/test_transformers.py b/test/test_transformers.py index fdf64f11aed6..eea3b3fab8d9 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -1980,7 +1980,7 @@ def ref(x): class TestSDPACudaOnly(NNTestCase): """ Used to test CUDA only functionality of scaled_dot_product_attention Quarks: - There is some trickiness with this function. It's runtime behavior + There is some trickiness with this function. Its runtime behavior is dependent on the CUDA architecture you are testing it on. See `PLATFORM_SUPPORTS_FUSED_ATTENTION` at the top of the file. Summary: @@ -2147,9 +2147,34 @@ def convert_flash_attn_S_to_softmax( S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k_rounded)) return S_converted[:, :, :seqlen_q, :seqlen_k] + @skipIfRocm # No cuDNN Attention + @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system") + def test_cudnn_attention_different_dk_dv(self, device): + dtype = torch.bfloat16 + make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True) + batch, num_heads, head_dim_k, head_dim_v = 32, 16, 128, 64 + seq_len = 640 + q_shape = SdpaShape(batch, num_heads, seq_len, head_dim_k) + k_shape = SdpaShape(batch, num_heads, seq_len, head_dim_k) + v_shape = SdpaShape(batch, num_heads, seq_len, head_dim_v) + query, key, value = make_tensor(q_shape), make_tensor(k_shape), make_tensor(v_shape) + + with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]): + actual = torch.nn.functional.scaled_dot_product_attention( + query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) + with sdpa_kernel(backends=[SDPBackend.MATH]): + math_ref = torch.nn.functional.scaled_dot_product_attention( + query.contiguous().to(torch.float32), + key.contiguous().to(torch.float32), + value.contiguous().to(torch.float32), + attn_mask=None, dropout_p=0.0, is_causal=False) + + self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2) + + @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system") @parametrize("mask_dim", [1, 2, 3, 4]) - def test_mem_efficient_attetntion_mask_variants(self, device, mask_dim: List[int]): + def test_mem_efficient_attention_mask_variants(self, device, mask_dim: List[int]): dtype = torch.float16 make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True) batch, num_heads, head_dim = 8, 8, 64 From adb699189b9d2de7cfbd71e59c70d916483b23dd Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 11 Jun 2024 19:41:41 +0000 Subject: [PATCH 647/706] Revert "[RELAND][dynamo][nn-modules] Trace through nn.Module dunder methods for UnspecializedNNModule (#126578)" This reverts commit b2d602306a9eb19e30328cbaee941c874f8148a9. Reverted https://github.com/pytorch/pytorch/pull/126578 on behalf of https://github.com/clee2000 due to failed internal test D58394084. Author has forward fix but includes external changes so reverting is a bit easier to coordinate ([comment](https://github.com/pytorch/pytorch/pull/126578#issuecomment-2161481839)) --- test/distributed/test_dynamo_distributed.py | 10 ++- test/dynamo/test_higher_order_ops.py | 16 ++--- ...=> FakeTensorTest.test_embedding_bag_meta} | 0 ...ansformsCPU.test_compile_vmap_hessian_cpu} | 0 ...> TestEmbeddingNN.test_embedding_max_norm} | 0 ...stEmbeddingNN.test_embedding_sparse_basic} | 0 ...ddingNN.test_embedding_sparse_empty_tensor | 0 ...ngNN.test_embeddingbag_include_last_offset | 0 ....test_profiler_pattern_matcher_json_report | 0 .../TestJitGeneratedModule.test_nn_Bilinear | 0 .../TestJitGeneratedModule.test_nn_Embedding | 0 ...dModule.test_nn_EmbeddingBag_discontiguous | 0 ...itGeneratedModule.test_nn_EmbeddingBag_max | 0 ...odule.test_nn_EmbeddingBag_max_padding_idx | 0 ...tGeneratedModule.test_nn_EmbeddingBag_mean | 0 ...dule.test_nn_EmbeddingBag_mean_padding_idx | 0 ...eneratedModule.test_nn_EmbeddingBag_sparse | 0 ...itGeneratedModule.test_nn_EmbeddingBag_sum | 0 ...odule.test_nn_EmbeddingBag_sum_padding_idx | 0 ...atedModule.test_nn_Embedding_discontiguous | 0 ...itGeneratedModule.test_nn_Embedding_sparse | 0 .../TestJitGeneratedModule.test_nn_Linear | 0 ...eneratedModule.test_nn_Linear_no_batch_dim | 0 ...GeneratedModule.test_nn_PReLU_no_batch_dim | 0 .../TestNN.test_ParameterDict | 0 .../TestNN.test_Sequential_iadd | 0 .../TestNN.test_bilinear_broadcasting | 0 ...st_layer_norm_grads_with_create_graph_flag | 0 ..._linear_autograd_device_cpu_bias_weightCOO | 0 ..._linear_autograd_device_cpu_bias_weightCSC | 0 ..._linear_autograd_device_cpu_bias_weightCSR | 0 .../TestNN.test_linear_broadcasting | 0 .../TestNN.test_module_apply_inplace_op | 0 ...metrized_tensor_parametrization_swap_False | 0 ...weight_norm_parametrization_swap_False_cpu | 0 ..._weight_norm_parametrization_swap_True_cpu | 0 ...sorDeviceTypeCPU.test_embedding_jagged_cpu | 0 .../TestPruningNN.test_identity_pruning | 0 .../TestPruningNN.test_random_pruning_0perc | 0 test/profiler/test_profiler.py | 1 - torch/_dynamo/create_parameter_op.py | 20 ------ torch/_dynamo/mutation_guard.py | 3 - torch/_dynamo/side_effects.py | 32 ++++------ torch/_dynamo/symbolic_convert.py | 11 +--- torch/_dynamo/utils.py | 4 +- torch/_dynamo/variables/dicts.py | 6 +- torch/_dynamo/variables/misc.py | 26 +++----- torch/_dynamo/variables/nn_module.py | 40 ++++-------- torch/_dynamo/variables/torch.py | 9 +-- torch/_dynamo/variables/user_defined.py | 63 +++++++------------ 50 files changed, 72 insertions(+), 169 deletions(-) rename test/dynamo_expected_failures/{TestNN.test_overwrite_module_params_on_conversion => FakeTensorTest.test_embedding_bag_meta} (100%) rename test/dynamo_expected_failures/{TestNNParametrization.test_new_spectral_norm_forward_swap_True => TestCompileTransformsCPU.test_compile_vmap_hessian_cpu} (100%) rename test/dynamo_expected_failures/{TestNNParametrization.test_new_spectral_norm_swap_True => TestEmbeddingNN.test_embedding_max_norm} (100%) rename test/dynamo_expected_failures/{TestPruningNN.test_pruning_id_consistency => TestEmbeddingNN.test_embedding_sparse_basic} (100%) create mode 100644 test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_sparse_empty_tensor create mode 100644 test/dynamo_expected_failures/TestEmbeddingNN.test_embeddingbag_include_last_offset create mode 100644 test/dynamo_expected_failures/TestExperimentalUtils.test_profiler_pattern_matcher_json_report create mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Bilinear create mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding create mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_discontiguous create mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_max create mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_max_padding_idx create mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_mean create mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_mean_padding_idx create mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sparse create mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sum create mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sum_padding_idx create mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding_discontiguous create mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding_sparse create mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Linear create mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Linear_no_batch_dim create mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_PReLU_no_batch_dim create mode 100644 test/dynamo_expected_failures/TestNN.test_ParameterDict create mode 100644 test/dynamo_expected_failures/TestNN.test_Sequential_iadd create mode 100644 test/dynamo_expected_failures/TestNN.test_bilinear_broadcasting create mode 100644 test/dynamo_expected_failures/TestNN.test_layer_norm_grads_with_create_graph_flag create mode 100644 test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCOO create mode 100644 test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCSC create mode 100644 test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCSR create mode 100644 test/dynamo_expected_failures/TestNN.test_linear_broadcasting create mode 100644 test/dynamo_expected_failures/TestNN.test_module_apply_inplace_op create mode 100644 test/dynamo_expected_failures/TestNNParametrization.test_errors_unparametrized_tensor_parametrization_swap_False create mode 100644 test/dynamo_expected_failures/TestNNParametrizationDeviceCPU.test_weight_norm_parametrization_swap_False_cpu create mode 100644 test/dynamo_expected_failures/TestNNParametrizationDeviceCPU.test_weight_norm_parametrization_swap_True_cpu create mode 100644 test/dynamo_expected_failures/TestNestedTensorDeviceTypeCPU.test_embedding_jagged_cpu create mode 100644 test/dynamo_expected_failures/TestPruningNN.test_identity_pruning create mode 100644 test/dynamo_expected_failures/TestPruningNN.test_random_pruning_0perc diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py index db44f1ce915d..b31a2f717537 100644 --- a/test/distributed/test_dynamo_distributed.py +++ b/test/distributed/test_dynamo_distributed.py @@ -1084,14 +1084,12 @@ def _(ctx): # far from an exhaustive check of all the expected guards, just check a couple of them. FileCheck().check("""local "L['self']" TYPE_MATCH""").check( """local "L['self']" ID_MATCH""" + ).check(f"""{expected_guard_source} "L['self'].net" TYPE_MATCH""").check( + f"""{expected_guard_source} "L['self'].net" ID_MATCH""" ).check( - f"""{expected_guard_source} "L['self']._modules['net']" TYPE_MATCH""" + f"""{expected_guard_source} "L['self'].net[0]" TYPE_MATCH""" ).check( - f"""{expected_guard_source} "L['self']._modules['net']" ID_MATCH""" - ).check( - f"""{expected_guard_source} "L['self']._modules['net']._modules['0']" TYPE_MATCH""" - ).check( - f"""{expected_guard_source} "L['self']._modules['net']._modules['1']" ID_MATCH""" + f"""{expected_guard_source} "L['self'].net[0]" ID_MATCH""" ).run( GUARDS_FILE.getvalue() ) diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index c934cf55e8f5..30dff83e12dd 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -5187,10 +5187,10 @@ def wrapper_fn(x): actual, """\ class GraphModule(torch.nn.Module): - def forward(self, L_self_buffers_tensor_constant0_: "f32[3, 3, 3]"): - l_self_buffers_tensor_constant0_ = L_self_buffers_tensor_constant0_ + def forward(self, L_self_tensor_constant0: "f32[3, 3, 3]"): + l_self_tensor_constant0 = L_self_tensor_constant0 - alias_default: "f32[3, 3, 3]" = torch.ops.aten.alias.default(l_self_buffers_tensor_constant0_); l_self_buffers_tensor_constant0_ = None + alias_default: "f32[3, 3, 3]" = torch.ops.aten.alias.default(l_self_tensor_constant0); l_self_tensor_constant0 = None sin_default: "f32[3, 3, 3]" = torch.ops.aten.sin.default(alias_default) @@ -5209,16 +5209,16 @@ def forward(self, L_self_buffers_tensor_constant0_: "f32[3, 3, 3]"): actual, """\ class GraphModule(torch.nn.Module): - def forward(self, L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_0_: "f32[3, 3, 3]", L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_1_: "f32[3, 3, 3]", L_flat_tangents_1_: "f32[3, 3, 3]"): - l_self_modules_fx_const_folded_attrs_parameters_0_ = L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_0_ - l_self_modules_fx_const_folded_attrs_parameters_1_ = L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_1_ + def forward(self, getattr_L_self_FX_CONST_FOLDED_ATTRS_0_: "f32[3, 3, 3]", getattr_L_self_FX_CONST_FOLDED_ATTRS_1_: "f32[3, 3, 3]", L_flat_tangents_1_: "f32[3, 3, 3]"): + getattr_l_self_fx_const_folded_attrs_0_ = getattr_L_self_FX_CONST_FOLDED_ATTRS_0_ + getattr_l_self_fx_const_folded_attrs_1_ = getattr_L_self_FX_CONST_FOLDED_ATTRS_1_ l_flat_tangents_1_ = L_flat_tangents_1_ - _new_zeros_with_same_feature_meta_default: "f32[3, 3, 3]" = torch.ops.aten._new_zeros_with_same_feature_meta.default(l_flat_tangents_1_, l_self_modules_fx_const_folded_attrs_parameters_0_); l_self_modules_fx_const_folded_attrs_parameters_0_ = None + _new_zeros_with_same_feature_meta_default: "f32[3, 3, 3]" = torch.ops.aten._new_zeros_with_same_feature_meta.default(l_flat_tangents_1_, getattr_l_self_fx_const_folded_attrs_0_); getattr_l_self_fx_const_folded_attrs_0_ = None copy__default: "f32[3, 3, 3]" = torch.ops.aten.copy_.default(_new_zeros_with_same_feature_meta_default, l_flat_tangents_1_); _new_zeros_with_same_feature_meta_default = l_flat_tangents_1_ = None - mul_tensor: "f32[3, 3, 3]" = torch.ops.aten.mul.Tensor(copy__default, l_self_modules_fx_const_folded_attrs_parameters_1_); copy__default = l_self_modules_fx_const_folded_attrs_parameters_1_ = None + mul_tensor: "f32[3, 3, 3]" = torch.ops.aten.mul.Tensor(copy__default, getattr_l_self_fx_const_folded_attrs_1_); copy__default = getattr_l_self_fx_const_folded_attrs_1_ = None return (mul_tensor,) """, ) diff --git a/test/dynamo_expected_failures/TestNN.test_overwrite_module_params_on_conversion b/test/dynamo_expected_failures/FakeTensorTest.test_embedding_bag_meta similarity index 100% rename from test/dynamo_expected_failures/TestNN.test_overwrite_module_params_on_conversion rename to test/dynamo_expected_failures/FakeTensorTest.test_embedding_bag_meta diff --git a/test/dynamo_expected_failures/TestNNParametrization.test_new_spectral_norm_forward_swap_True b/test/dynamo_expected_failures/TestCompileTransformsCPU.test_compile_vmap_hessian_cpu similarity index 100% rename from test/dynamo_expected_failures/TestNNParametrization.test_new_spectral_norm_forward_swap_True rename to test/dynamo_expected_failures/TestCompileTransformsCPU.test_compile_vmap_hessian_cpu diff --git a/test/dynamo_expected_failures/TestNNParametrization.test_new_spectral_norm_swap_True b/test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_max_norm similarity index 100% rename from test/dynamo_expected_failures/TestNNParametrization.test_new_spectral_norm_swap_True rename to test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_max_norm diff --git a/test/dynamo_expected_failures/TestPruningNN.test_pruning_id_consistency b/test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_sparse_basic similarity index 100% rename from test/dynamo_expected_failures/TestPruningNN.test_pruning_id_consistency rename to test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_sparse_basic diff --git a/test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_sparse_empty_tensor b/test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_sparse_empty_tensor new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestEmbeddingNN.test_embeddingbag_include_last_offset b/test/dynamo_expected_failures/TestEmbeddingNN.test_embeddingbag_include_last_offset new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestExperimentalUtils.test_profiler_pattern_matcher_json_report b/test/dynamo_expected_failures/TestExperimentalUtils.test_profiler_pattern_matcher_json_report new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Bilinear b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Bilinear new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_discontiguous b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_discontiguous new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_max b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_max new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_max_padding_idx b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_max_padding_idx new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_mean b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_mean new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_mean_padding_idx b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_mean_padding_idx new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sparse b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sparse new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sum b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sum new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sum_padding_idx b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sum_padding_idx new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding_discontiguous b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding_discontiguous new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding_sparse b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding_sparse new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Linear b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Linear new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Linear_no_batch_dim b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Linear_no_batch_dim new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_PReLU_no_batch_dim b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_PReLU_no_batch_dim new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestNN.test_ParameterDict b/test/dynamo_expected_failures/TestNN.test_ParameterDict new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestNN.test_Sequential_iadd b/test/dynamo_expected_failures/TestNN.test_Sequential_iadd new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestNN.test_bilinear_broadcasting b/test/dynamo_expected_failures/TestNN.test_bilinear_broadcasting new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestNN.test_layer_norm_grads_with_create_graph_flag b/test/dynamo_expected_failures/TestNN.test_layer_norm_grads_with_create_graph_flag new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCOO b/test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCOO new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCSC b/test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCSC new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCSR b/test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCSR new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestNN.test_linear_broadcasting b/test/dynamo_expected_failures/TestNN.test_linear_broadcasting new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestNN.test_module_apply_inplace_op b/test/dynamo_expected_failures/TestNN.test_module_apply_inplace_op new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestNNParametrization.test_errors_unparametrized_tensor_parametrization_swap_False b/test/dynamo_expected_failures/TestNNParametrization.test_errors_unparametrized_tensor_parametrization_swap_False new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestNNParametrizationDeviceCPU.test_weight_norm_parametrization_swap_False_cpu b/test/dynamo_expected_failures/TestNNParametrizationDeviceCPU.test_weight_norm_parametrization_swap_False_cpu new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestNNParametrizationDeviceCPU.test_weight_norm_parametrization_swap_True_cpu b/test/dynamo_expected_failures/TestNNParametrizationDeviceCPU.test_weight_norm_parametrization_swap_True_cpu new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestNestedTensorDeviceTypeCPU.test_embedding_jagged_cpu b/test/dynamo_expected_failures/TestNestedTensorDeviceTypeCPU.test_embedding_jagged_cpu new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestPruningNN.test_identity_pruning b/test/dynamo_expected_failures/TestPruningNN.test_identity_pruning new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestPruningNN.test_random_pruning_0perc b/test/dynamo_expected_failures/TestPruningNN.test_random_pruning_0perc new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/profiler/test_profiler.py b/test/profiler/test_profiler.py index 81d158635c0e..ca0481922e4f 100644 --- a/test/profiler/test_profiler.py +++ b/test/profiler/test_profiler.py @@ -2411,7 +2411,6 @@ def test_profiler_matmul_dim_fp16_pattern(self): num_matched.append(len(pattern.matched_events())) self.assertEqual(num_matched, [i for i, _ in cases]) - @skipIfTorchDynamo("profiler gets ignored if dynamo activated") def test_profiler_pattern_matcher_json_report(self): x = torch.ones((100, 100)) model = nn.Sequential( diff --git a/torch/_dynamo/create_parameter_op.py b/torch/_dynamo/create_parameter_op.py index d30e4a37f003..f6cd12de2021 100644 --- a/torch/_dynamo/create_parameter_op.py +++ b/torch/_dynamo/create_parameter_op.py @@ -1,7 +1,4 @@ # mypy: allow-untyped-defs -import threading -from contextlib import contextmanager - import torch doc = """ @@ -40,20 +37,3 @@ def new_parameter_placeholder(size, dtype, device, requires_grad): # Allocating a zero tensor would causes assert failures in autograd. result.untyped_storage().resize_(0) return result - - -_TLS = threading.local() - - -@contextmanager -def do_not_convert_to_tracable_parameter(): - old_flag = getattr(_TLS, "convert_tracable_parameter", True) - _TLS.convert_tracable_parameter = False - try: - yield False - finally: - _TLS.convert_tracable_parameter = old_flag - - -def can_convert_to_tracable_parameter(): - return getattr(_TLS, "convert_tracable_parameter", True) diff --git a/torch/_dynamo/mutation_guard.py b/torch/_dynamo/mutation_guard.py index 9077ecd3d57f..22e2b9999e03 100644 --- a/torch/_dynamo/mutation_guard.py +++ b/torch/_dynamo/mutation_guard.py @@ -11,9 +11,6 @@ from .utils import ExactWeakKeyDictionary, is_lazy_module, nn_module_has_global_hooks -unpatched_nn_module_init = torch.nn.Module.__init__ - - class MutationTracker: db = ExactWeakKeyDictionary() diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index 94797251c866..229282f709cb 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -347,7 +347,13 @@ def codegen_save_tempvars(self, cg: PyCodegen): elif isinstance(var.mutable_local, AttributeMutationNew): if isinstance(var, variables.AutogradFunctionContextVariable): unimplemented("AutogradFunctionContextVariable escaped") - cg.load_import_from(utils.__name__, "object_new") + if "__call_nn_module_init" in self.store_attr_mutations.get( + var.mutable_local, {} + ): + assert isinstance(var, variables.UnspecializedNNModuleVariable) + cg.load_import_from(utils.__name__, "nn_module_new") + else: + cg.load_import_from(utils.__name__, "object_new") cg(var.mutable_local.cls_source) cg.extend_output(create_call_function(1, True)) cg.add_cache(var) @@ -474,25 +480,9 @@ def codegen_update_mutated(self, cg: PyCodegen): ] ) elif self.is_attribute_mutation(var): - # Applying mutations involves two steps: 1) Push all - # reconstructed objects onto the stack. 2) Call STORE_ATTR to - # apply the mutations. - # - # Dynamo must ensure that mutations are applied in the same - # order as in the original program. Therefore, two reverse - # operations occur below. - # - # The first reverse operation concerns `suffixes`. We apply - # suffixes in reverse order due to the way Python handles the - # stack. In Step 1, we push all reconstructed objects onto the - # stack, but the item at the top of the stack refers to the last - # attribute in the mutation order. If not fixed, this will apply - # the mutations of attributes in the reverse order. To account - # for this reversal, we iterate through the mutable attributes - # in reverse order. - for name, value in reversed( - self.store_attr_mutations.get(var.mutable_local, {}).items() - ): + for name, value in self.store_attr_mutations.get( + var.mutable_local, {} + ).items(): if isinstance(var, variables.NewGlobalVariable): cg.tx.output.update_co_names(name) cg(value) @@ -500,6 +490,8 @@ def codegen_update_mutated(self, cg: PyCodegen): suffixes.append( [create_instruction("STORE_GLOBAL", argval=name)] ) + elif name == "__call_nn_module_init": + pass # handled in codegen_save_tempvars elif isinstance(value, variables.DeletedVariable): if isinstance( var.mutable_local, AttributeMutationExisting diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 678a0497c8a2..41ceaa615916 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -416,15 +416,10 @@ def inner(self: "InstructionTranslatorBase", inst: Instruction): self.push(value) self.jump(inst) elif isinstance(value, UserDefinedObjectVariable): - try: - x = value.var_getattr(self, "__bool__") - except exc.ObservedException: - # if __bool__ is missing, trying __len__ to infer a truth value. + x = value.var_getattr(self, "__bool__") + # if __bool__ is missing, trying __len__ to infer a truth value. + if isinstance(x, GetAttrVariable): x = value.var_getattr(self, "__len__") - else: - if isinstance(x, GetAttrVariable): - # if __bool__ is missing, trying __len__ to infer a truth value. - x = value.var_getattr(self, "__len__") # __bool__ or __len__ is function if isinstance(x, UserMethodVariable): diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index fe2f096ec488..6da8b514f16b 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -2019,12 +2019,12 @@ def object_has_getattribute(value: Any): return False -def get_custom_getattr(value: Any, ignore_nn_module_getattr: bool = False): +def get_custom_getattr(value: Any): try: getattr_fn = inspect.getattr_static(type(value), "__getattr__") except AttributeError: getattr_fn = None - if ignore_nn_module_getattr and getattr_fn is torch.nn.Module.__getattr__: + if getattr_fn is torch.nn.Module.__getattr__: # ignore this case of getattr getattr_fn = None return getattr_fn diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 8391563c8e76..0724a80621f7 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -174,11 +174,7 @@ def python_type(self): def __contains__(self, vt): assert isinstance(vt, VariableTracker) Hashable = ConstDictVariable._HashableTracker - return ( - is_hashable(vt) - and Hashable(vt) in self.items - and not isinstance(self.items[Hashable(vt)], variables.DeletedVariable) - ) + return is_hashable(vt) and Hashable(vt) in self.items def reconstruct(self, codegen): # instructions to load collections.OrderedDict if necessary diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 9ef36eb7f29f..cc0fb7096701 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -14,10 +14,8 @@ import torch.utils._pytree as pytree from .. import config, variables from ..bytecode_transformation import create_call_function, create_instruction -from ..create_parameter_op import do_not_convert_to_tracable_parameter from ..exc import unimplemented from ..guards import GuardBuilder, install_guard -from ..mutation_guard import unpatched_nn_module_init from ..source import AttrSource, GetItemSource, ODictGetItemSource, TypeSource from ..utils import ( check_unspec_or_constant_args, @@ -123,6 +121,7 @@ def call_method( kwargs: "Dict[str, VariableTracker]", ) -> "VariableTracker": inner_fn, source = self._resolved_getattr_and_source(self, name) + if inner_fn is object.__init__: return LambdaVariable(identity) elif inner_fn is torch.nn.Module.__init__: @@ -134,10 +133,12 @@ def call_method( and isinstance(objvar.mutable_local, AttributeMutationNew) and not (args or kwargs) ): - with do_not_convert_to_tracable_parameter(): - return variables.UserFunctionVariable( - unpatched_nn_module_init, source=source - ).call_function(tx, [self.objvar] + args, kwargs) + tx.output.side_effects.store_attr( + objvar, + "__call_nn_module_init", + variables.ConstantVariable.create(True), + ) + return variables.ConstantVariable.create(None) else: unimplemented("super() nn.Module.__init__") elif isinstance(inner_fn, types.FunctionType): @@ -174,19 +175,6 @@ def call_method( self.objvar, UserDefinedObjectVariable ): return self.objvar.method_setattr_standard(tx, *args, **kwargs) - elif inner_fn is object.__delattr__: - attr = args[0] - try: - attr = attr.as_python_constant() - except NotImplementedError: - unimplemented(f"non-const delattr attr: {attr}") - if not tx.output.side_effects.is_attribute_mutation(self.objvar): - unimplemented(f"delattr({self.objvar}, {attr}, ...)") - - tx.output.side_effects.store_attr( - self.objvar, attr, variables.DeletedVariable() - ) - return variables.ConstantVariable(None) unimplemented(f"non-function or method super: {inner_fn}") diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py index d3f7052a9445..37c0bc17697a 100644 --- a/torch/_dynamo/variables/nn_module.py +++ b/torch/_dynamo/variables/nn_module.py @@ -215,7 +215,7 @@ def _custom_getattr_fallback(self, base, tx, name, options): if object_has_getattribute(base): unimplemented("torch.nn.Module with a custom __getattribute__ defined") - getattr_fn = get_custom_getattr(base, ignore_nn_module_getattr=True) + getattr_fn = get_custom_getattr(base) if getattr_fn is None: return None @@ -665,6 +665,7 @@ def gen_source(source, name): if isinstance(args[0], SliceVariable): # Build a TupleVariable of NNModules result = [] + submods = [] # Turn the slice into the list of integers keys = list(range(len(module)))[args[0].as_python_constant()] @@ -678,8 +679,9 @@ def gen_source(source, name): source=src, ) ) + submods.append(submod) - new_module = module[args[0].as_python_constant()] + new_module = torch.nn.Sequential(*submods) new_module_variable = tx.output.register_attr_or_module( new_module, f"{self}.__getitem__(slice)", @@ -693,10 +695,8 @@ def gen_source(source, name): if isinstance(args[0], SymNodeVariable): key = args[0].evaluate_expr(tx.output) - elif args[0].is_python_constant(): - key = args[0].as_python_constant() else: - unimplemented(f"getitem on NNModuleVariable with key {args[0]}") + key = args[0].as_python_constant() submod = module[key] return tx.output.register_attr_or_module( @@ -790,7 +790,7 @@ def set_nn_module_stack_source(self, source): @functools.lru_cache(None) def _nn_module_method_ids(): # Allow __setattr__ to fall through to base class handler - supported = {torch.nn.Module.__setattr__, torch.nn.Module.__init__} + supported = {torch.nn.Module.__setattr__} return { id(x.__code__) for x in torch.nn.Module.__dict__.values() @@ -798,6 +798,8 @@ def _nn_module_method_ids(): } def unpack_var_sequence(self, tx): + from .builder import VariableBuilder + try: fn = inspect.getattr_static(self.value_type, "__iter__") except AttributeError as e: @@ -808,16 +810,11 @@ def unpack_var_sequence(self, tx): torch.nn.ParameterList.__iter__, torch.nn.Sequential.__iter__, ): - # The program can mutate the nn module object but the saved `value` - # will not reflect the mutations. So, trace through the `__iter__` - # function to reflect any tracked mutations. - return tx.inline_user_function_return( - variables.UserFunctionVariable(fn), - [ - self, - ], - {}, - ).unpack_var_sequence(tx) + assert self.source + return [ + VariableBuilder(tx, source=GetItemSource(self.source, idx))(item) + for idx, item in enumerate(self.value) + ] return super().unpack_var_sequence(tx) @@ -946,17 +943,6 @@ def call_method( # Handle submodules self.is_state_mutated = True - if method is torch.nn.Module.__setattr__ and isinstance( - args[1], variables.DeletedVariable - ): - # Trace through __delattr__ to track mutations on the module - # members like `_modules``. - return tx.inline_user_function_return( - variables.UserFunctionVariable(torch.nn.Module.__delattr__), - [self, args[0]], - kwargs, - ) - return super().call_method(tx, name, args, kwargs) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 934e9a316a4b..4d7b96b6a320 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -18,11 +18,7 @@ from ..._guards import TracingContext from .. import config, polyfill, variables from ..codegen import PyCodegen -from ..create_parameter_op import ( - can_convert_to_tracable_parameter, - new_parameter_placeholder, - tracable_create_parameter, -) +from ..create_parameter_op import new_parameter_placeholder, tracable_create_parameter from ..device_interface import get_registered_device_interfaces from ..exc import unimplemented from ..guards import GuardBuilder, install_guard @@ -875,9 +871,6 @@ def call_nn_parameter(cls, tx, data=None, requires_grad=True): if data.source: return cls._nn_param_via_prefix_insert(tx, data, requires_grad) - if not can_convert_to_tracable_parameter(): - unimplemented("Workaround for issues with nn_parameter construction") - try: shape = tuple(data.var_getattr(tx, "shape").as_python_constant()) dtype = data.var_getattr(tx, "dtype").as_python_constant() diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 6c79d9cfcbef..7c7673a103fd 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -34,8 +34,7 @@ from torch._guards import TracingContext from .. import variables -from ..create_parameter_op import do_not_convert_to_tracable_parameter -from ..exc import ObservedException, unimplemented +from ..exc import unimplemented from ..guards import GuardBuilder, install_guard from ..source import AttrSource, GetItemSource, ODictGetItemSource, RandomValueSource from ..utils import ( @@ -58,7 +57,10 @@ def is_standard_setattr(val): - return val in (object.__setattr__,) + return val in ( + object.__setattr__, + torch.nn.Module.__setattr__, + ) class UserDefinedVariable(VariableTracker): @@ -376,7 +378,17 @@ def call_function( else UserDefinedObjectVariable, {}, ) - with do_not_convert_to_tracable_parameter(): + if ( + inspect.getattr_static(self.value, "__init__", None) + is torch.nn.Module.__init__ + ): + tx.output.side_effects.store_attr( + var, + "__call_nn_module_init", + variables.ConstantVariable.create(True), + ) + return var + else: var.call_method(tx, "__init__", args, kwargs) return var elif variables.CustomizedDictVariable.is_matching_cls(self.value): @@ -626,10 +638,6 @@ def call_method( else AttrSource(AttrSource(self.source, "__class__"), name) ) # TODO(jansel): add a guard to check for monkey patching? - from ..mutation_guard import unpatched_nn_module_init - - if method is torch.nn.Module.__init__: - method = unpatched_nn_module_init return UserMethodVariable(method, self, source=source).call_function( tx, args, kwargs ) @@ -791,7 +799,7 @@ def _check_for_getattr(self): def _getattr_static(self, name): if ( - isinstance(self.value, PyTreeSpec) + isinstance(self.value, (torch.nn.Module, PyTreeSpec)) or "__slots__" in self.value.__class__.__dict__ or type(self.value) == threading.local ): @@ -804,6 +812,7 @@ def _getattr_static(self, name): return cls_var except AttributeError: pass # __slots__ + # this might call torch.nn.Module.__getattr__ subobj = getattr(self.value, name) else: subobj = inspect.getattr_static(self.value, name) @@ -1009,35 +1018,14 @@ def call_hasattr(self, tx, name: str) -> "VariableTracker": install_guard( AttrSource(self.source, name).make_guard(GuardBuilder.HASATTR) ) - if self._check_for_getattribute(): - unimplemented("hasattr with custom __getattribute__") + if self._check_for_getattribute() or self._check_for_getattr(): + unimplemented("hasattr with custom __getattr__") try: self._getattr_static(name) return variables.ConstantVariable.create(True) except AttributeError: - # Now check in __getattr__ function - getattr_fn = self._check_for_getattr() - if isinstance(getattr_fn, types.FunctionType): - # Dynamo is going to trace the __getattr__ function with - # args=name. Set the source accordingly. - new_source = None - if self.source: - new_source = AttrSource(self.source, "__getattr__") - try: - result = variables.UserMethodVariable( - getattr_fn, self, source=new_source - ).call_function(tx, [variables.ConstantVariable.create(name)], {}) - - return variables.ConstantVariable.create( - not isinstance(result, variables.DeletedVariable) - ) - except ObservedException: - return variables.ConstantVariable.create(False) - elif getattr_fn is None: - return variables.ConstantVariable.create(False) - else: - unimplemented("UserDefined with non-function __getattr__") + return variables.ConstantVariable.create(False) def odict_getitem(self, tx, key): from .builder import VariableBuilder @@ -1104,12 +1092,6 @@ def var_getattr(self, tx, name): return super().var_getattr(tx, name) -class RemovableHandleClass: - # Dummy class to pass to python_type of RemovableHandleVariable - # Useful for isinstance check on hooks - pass - - class RemovableHandleVariable(VariableTracker): REMOVED = -1 @@ -1140,6 +1122,3 @@ def reconstruct(self, codegen): return # unreachable due to codegen.add_cache() when the hook is installed super().reconstruct(codegen) - - def python_type(self): - return RemovableHandleClass From 70a1e8571802c22c0f09279b77876e6e85c81325 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Mon, 10 Jun 2024 20:19:35 -0700 Subject: [PATCH 648/706] [Traceable FSDP2] Use custom ops for AllGather copy-in / copy-out and ReduceScatter copy-in (#127856) Making these operations into custom ops helps Inductor identify these ops and enforce the FSDP communication op ordering. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127856 Approved by: https://github.com/awgu --- .../_composable/fsdp/_fsdp_collectives.py | 111 ++++++++++++++++-- 1 file changed, 102 insertions(+), 9 deletions(-) diff --git a/torch/distributed/_composable/fsdp/_fsdp_collectives.py b/torch/distributed/_composable/fsdp/_fsdp_collectives.py index ac5084813ee1..1423cfd600fc 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_collectives.py +++ b/torch/distributed/_composable/fsdp/_fsdp_collectives.py @@ -26,6 +26,98 @@ class AllGatherResult(NamedTuple): all_gather_input_split_sizes: List[int] +lib = torch.library.Library("fsdp", "FRAGMENT") # noqa: TOR901 + +lib.define( + """ + all_gather_copy_in( + Tensor[] all_gather_inputs, + SymInt[] inp_split_sizes, + SymInt all_gather_input_numel, + SymInt world_size, + SymInt rank, + ScalarType dtype, + Device device + ) -> (Tensor, Tensor) + """ +) + + +@torch.library.impl(lib, "all_gather_copy_in", "Meta") +def all_gather_copy_in_meta( + all_gather_inputs: List[torch.Tensor], + inp_split_sizes: List[int], + all_gather_input_numel: int, + world_size: int, + rank: int, + dtype: torch.dtype, + device: torch.device, +) -> Tuple[torch.Tensor, torch.Tensor]: + all_gather_output = torch.empty( + (all_gather_input_numel * world_size,), dtype=dtype, device="meta" + ) + all_gather_input = all_gather_output.narrow( + 0, all_gather_input_numel * rank, all_gather_input_numel + ) + return all_gather_input, all_gather_output + + +@torch.library.impl(lib, "all_gather_copy_in", "CUDA") +def all_gather_copy_in_cuda( + all_gather_inputs: List[torch.Tensor], + inp_split_sizes: List[int], + all_gather_input_numel: int, + world_size: int, + rank: int, + dtype: torch.dtype, + device: torch.device, +) -> Tuple[torch.Tensor, torch.Tensor]: + all_gather_output = torch.empty( + (all_gather_input_numel * world_size,), dtype=dtype, device=device + ) + all_gather_input = all_gather_output.narrow( + 0, all_gather_input_numel * rank, all_gather_input_numel + ) + foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes) + with torch.no_grad(): + torch._foreach_copy_(foreach_copy_dsts, all_gather_inputs) + return all_gather_input, all_gather_output + + +lib.define( + "split_with_sizes_copy(Tensor all_gather_output, SymInt[] all_gather_input_split_sizes, int dim=0, *, Tensor(a!)[] out) -> ()" +) + + +@torch.library.impl(lib, "split_with_sizes_copy", "Meta") +@torch.library.impl(lib, "split_with_sizes_copy", "CUDA") +def split_with_sizes_copy( + all_gather_output: torch.Tensor, + all_gather_input_split_sizes: List[int], + dim: int, + out: List[torch.Tensor], +) -> None: + torch.split_with_sizes_copy( + all_gather_output, all_gather_input_split_sizes, dim=dim, out=out + ) + + +lib.define( + "chunk_cat(Tensor[] tensors, int dim, int num_chunks, *, Tensor(a!) out) -> ()" +) + + +@torch.library.impl(lib, "chunk_cat", "Meta") +@torch.library.impl(lib, "chunk_cat", "CUDA") +def chunk_cat( + tensors: List[torch.Tensor], + dim: int, + num_chunks: int, + out: torch.Tensor, +) -> None: + torch._chunk_cat(tensors, dim, num_chunks, out=out) + + @torch.no_grad() def foreach_all_gather( fsdp_params: List[FSDPParam], @@ -53,14 +145,15 @@ def foreach_all_gather( all_gather_inputs = [t for ts in param_all_gather_inputs for t in ts] inp_split_sizes = [t.numel() for t in all_gather_inputs] all_gather_input_numel = sum(inp_split_sizes) - all_gather_output = torch.empty( - (all_gather_input_numel * world_size,), dtype=dtype, device=device - ) - all_gather_input = all_gather_output.narrow( - 0, all_gather_input_numel * rank, all_gather_input_numel + all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in( + all_gather_inputs, + inp_split_sizes, + all_gather_input_numel, + world_size, + rank, + dtype, + device, ) - foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes) - torch._foreach_copy_(foreach_copy_dsts, all_gather_inputs) del param_all_gather_inputs all_gather_stream.wait_stream(all_gather_copy_in_stream) with torch.cuda.stream(all_gather_stream): @@ -124,7 +217,7 @@ def foreach_all_gather_copy_out( out = [t.view(world_size, -1).view(torch.uint8) for t in gen] else: out = [t.view(world_size, -1) for t in gen] - torch.split_with_sizes_copy( + torch.ops.fsdp.split_with_sizes_copy( all_gather_output, all_gather_input_split_sizes, dim=1, out=out ) @@ -259,7 +352,7 @@ def foreach_reduce_scatter_copy_in( world_size: int, ) -> None: reduce_scatter_input = reduce_scatter_input.view(world_size, -1) - torch._chunk_cat( + torch.ops.fsdp.chunk_cat( unsharded_grads, dim=0, num_chunks=world_size, out=reduce_scatter_input ) From 8c1247cffb7117da3d4db3a203c727983194c767 Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Mon, 3 Jun 2024 21:50:42 -0700 Subject: [PATCH 649/706] [BE] Fixed CPU autocast warning (#127774) This PR fixes ``` /data/users/andgu/pytorch/torch/utils/checkpoint.py:1398: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead. ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/127774 Approved by: https://github.com/soulitzer, https://github.com/Skylion007, https://github.com/tianyu-l --- torch/utils/checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/utils/checkpoint.py b/torch/utils/checkpoint.py index 38b747d8cd69..5cbfd1543cf4 100644 --- a/torch/utils/checkpoint.py +++ b/torch/utils/checkpoint.py @@ -1396,7 +1396,7 @@ def recompute_fn(*inputs): device_autocast_ctx = torch.amp.autocast( device_type=device, **device_autocast_kwargs ) if torch.amp.is_autocast_available(device) else contextlib.nullcontext() - with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined] + with device_autocast_ctx, torch.amp.autocast("cpu", **cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined] fn(*args, **kwargs) new_frame = _CheckpointFrame( From a55d0d9718c11eb2897423c78eff18b168dd0a06 Mon Sep 17 00:00:00 2001 From: rzou Date: Fri, 7 Jun 2024 07:27:26 -0700 Subject: [PATCH 650/706] Fix side effect pruning (#128028) Summary: The previous side effect pruning algorithm would keep many dead cell variables alive. For example, in https://github.com/pytorch/pytorch/issues/125078, the compiled function has one return but there were three in the Dynamo graph due to two dead cell variables not being pruned away. This PR adds a corrected algorithm. "new cell variables" are alive if they can be reached from one of the following: 1. any of the tx.symbolic_locals or tx.stack (that is, if they are involved in a return from the function or intermediate variable during a graph break). Example: an alive NestedUserFunctionVariable 2. "mutations to pre-existing objects". Example: appending a NestedUserFunctionVariable to a global list The new algorithm reflects this, but please let me know if there are more cases to handle. Test Plan: - existing tests (afaict, test/dynamo/test_python_autograd is the best SideEffects test case we have) - see in test/dynamo/test_higher_order_ops that the expecttests changed -- the functorch dynamo graphs no longer return dead cellvars. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128028 Approved by: https://github.com/jansel --- test/dynamo/test_higher_order_ops.py | 40 ++++++++++--------------- torch/_dynamo/side_effects.py | 44 ++++++++++++++++++++-------- 2 files changed, 48 insertions(+), 36 deletions(-) diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index 30dff83e12dd..7a746a9b1d08 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -3,7 +3,6 @@ import functools import pprint import re -import sys import unittest import warnings @@ -2860,7 +2859,7 @@ def forward(self, L_x_: "f32[4, 3]"): _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim_1) - _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim_1], retain_graph = True, create_graph = True); _add_batch_dim_1 = None + _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim_1], retain_graph = True, create_graph = True); o = diff_primals = _add_batch_dim_1 = None batched_outputs = _autograd_grad[0]; _autograd_grad = None chunked_result = torch._C._functorch._remove_batch_dim(batched_outputs, 3, 12, 0); batched_outputs = None @@ -2896,7 +2895,7 @@ def forward(self, L_x_: "f32[4, 3]"): jac_out_in: "f32[4, 3, 4, 3, 12]" = split_2[0]; split_2 = None unflatten: "f32[4, 3, 4, 3, 4, 3]" = jac_out_in.unflatten(-1, (4, 3)); jac_out_in = None - return (unflatten, diff_primals, o) + return (unflatten,) """, ) @@ -2964,8 +2963,8 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): _saved_tensors_hooks_disable_2 = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.") _grad_increment_nesting = torch._C._functorch._grad_increment_nesting() - _wrap_for_grad_2 = torch._C._functorch._wrap_for_grad(child_2, 3) - child_4 = torch._C._functorch._wrap_for_grad(child_3, 3) + _wrap_for_grad_2 = torch._C._functorch._wrap_for_grad(child_2, 3); child_2 = None + child_4 = torch._C._functorch._wrap_for_grad(child_3, 3); child_3 = None set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True) @@ -3002,7 +3001,7 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim_1) - _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [child_4], [_add_batch_dim_1], retain_graph = True, create_graph = True); _add_batch_dim_1 = None + _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [child_4], [_add_batch_dim_1], retain_graph = True, create_graph = True); o = child_4 = _add_batch_dim_1 = None child_5 = _autograd_grad[0]; _autograd_grad = None child_6 = torch._C._functorch._remove_batch_dim(child_5, 3, 12, 0); child_5 = None @@ -3041,17 +3040,10 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): unflatten: "f32[4, 3, 3, 4, 3, 4]" = jac_out_in.unflatten(-1, (3, 4)); jac_out_in = None""", ) - # Python 3.10 and 3.11 produces slightly different graphs - if sys.version_info[:2] > (3, 10): - self.assertExpectedInline( - actual.split("\n")[-2], - """ return (unflatten, child_2, _wrap_for_grad_1, child_3, child_4, o)""", - ) - else: - self.assertExpectedInline( - actual.split("\n")[-2], - """ return (unflatten, child_3, child_2, _wrap_for_grad_1, child_4, o)""", - ) + self.assertExpectedInline( + actual.split("\n")[-2], + """ return (unflatten,)""", + ) @unittest.expectedFailure def test_hessian_disable_capture(self): @@ -3160,7 +3152,7 @@ def forward(self, L_x_: "f32[4, 3]"): _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim) - _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); _add_batch_dim = None + _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); o = diff_primals = _add_batch_dim = None batched_outputs = _autograd_grad[0]; _autograd_grad = None chunked_result: "f32[12, 4, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 12, 0); batched_outputs = None @@ -3172,7 +3164,7 @@ def forward(self, L_x_: "f32[4, 3]"): split_1: "f32[12, 4, 3]" = split[0]; split = None output_input: "f32[4, 3, 4, 3]" = split_1.view((4, 3, 4, 3)); split_1 = None - return (output_input, diff_primals, o) + return (output_input,) """, ) @@ -3243,7 +3235,7 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim) - _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); _add_batch_dim = None + _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); o = diff_primals = _add_batch_dim = None batched_outputs = _autograd_grad[0]; _autograd_grad = None chunked_result: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 12, 0); batched_outputs = None @@ -3255,7 +3247,7 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): split_1: "f32[12, 3, 4]" = split[0]; split = None output_input: "f32[3, 4, 3, 4]" = split_1.view((3, 4, 3, 4)); split_1 = None - return (output_input, diff_primals, o) + return (output_input,) """, ) @@ -3328,7 +3320,7 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim) - _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); _add_batch_dim = None + _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); o = diff_primals = _add_batch_dim = None batched_outputs = _autograd_grad[0]; _autograd_grad = None chunked_result: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 12, 0); batched_outputs = None @@ -3340,7 +3332,7 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): split_1: "f32[12, 3, 4]" = split[0]; split = None output_input: "f32[3, 4, 3, 4]" = split_1.view((3, 4, 3, 4)); split_1 = None - return (output_input, aux_1, diff_primals, o) + return (output_input, aux_1) """, ) @@ -3776,7 +3768,7 @@ def forward(self, L_x_: "f32[3, 3, 3]"): _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting() _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() - return (grad_input_1, y) + return (y, grad_input_1) """, ) diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index 229282f709cb..9c3bf7d28711 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -295,14 +295,25 @@ def track_tensor_variables_from_runahead_side_effects(self, other): def prune_dead_object_new(self, tx): live_new_objects = set() - skip_obj = None + + # use this to avoid cycles in mutable_local (though I'm not sure if that + # can actually happen). + visited: Any = set({}) def visit(var: VariableTracker): - if ( - isinstance(var.mutable_local, AttributeMutationNew) - and var.mutable_local is not skip_obj - ): + if isinstance(var.mutable_local, AttributeMutationNew): + if var in visited: + return + visited.add(var) + # Object may have been mutated, store this mutation. live_new_objects.add(var.mutable_local) + # It's possible that we have mutated the value of this variable + # to be another one. The new value is in store_attr_mutations. + # Also recurse through the new value to detect alive AttributeMutationNew. + if var.mutable_local in self.store_attr_mutations: + VariableTracker.visit( + visit, self.store_attr_mutations[var.mutable_local] + ) def is_live(var: Union[MutableLocalBase, VariableTracker]): if isinstance(var, AttributeMutationNew): @@ -311,13 +322,22 @@ def is_live(var: Union[MutableLocalBase, VariableTracker]): return is_live(var.mutable_local) return True - VariableTracker.visit(visit, (tx.stack, tx.symbolic_locals)) - for var in self.id_to_variable.values(): - if not isinstance(var.mutable_local, AttributeMutationNew): - VariableTracker.visit(visit, var) - - for skip_obj, setattrs in self.store_attr_mutations.items(): - VariableTracker.visit(visit, setattrs) + pre_existing_vars = [ + var + for var in self.id_to_variable.values() + if not isinstance(var.mutable_local, AttributeMutationNew) + ] + + # The only live side effects come from returns (tx.stack), any intermediates + # during a graph break (tx.symbolic_locals), and mutation on pre-existing variables. + # Recursively visit Variables and see if any of them have been mutated. + VariableTracker.visit(visit, (tx.stack, tx.symbolic_locals, pre_existing_vars)) + + # NB: cell variable handling.is tricky. + # cell variables must stay alive if any NestedUserFunctionVariable + # are live. "visit"-ing the NestedUserFunctionVariable visits + # the .closures field, from which we will see if we need to keep + # any mutations to cell variables alive. self.id_to_variable = { k: v for k, v in self.id_to_variable.items() if is_live(v) From 5fcb5f0c8b1fc19951f45f191b87684ca01f8782 Mon Sep 17 00:00:00 2001 From: Andrew Hoblitzell Date: Tue, 11 Jun 2024 21:56:31 +0000 Subject: [PATCH 651/706] init reshape_from_tensor_shape comment (#128171) Fixes #127897 ### Description Add docstring to torch/onnx/symbolic_opset9.py:sigmoid function ### Checklist - [x] The issue that is being fixed is referred in the description - [x] Only one issue is addressed in this pull request - [x] Labels from the issue that this PR is fixing are added to this pull request - [x] No unnecessary issues are included into this pull request Pull Request resolved: https://github.com/pytorch/pytorch/pull/128171 Approved by: https://github.com/titaiwangms --- torch/onnx/operators.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/torch/onnx/operators.py b/torch/onnx/operators.py index 1e7532e8451d..88ac6779f91c 100644 --- a/torch/onnx/operators.py +++ b/torch/onnx/operators.py @@ -32,4 +32,16 @@ def shape_as_tensor(x): def reshape_from_tensor_shape(x, shape): + """Reshape a tensor to the given shape. + + This function is used to make dynamic size operations traceable when exporting models via ONNX. + This function is kept for backward-compatibility. It is implemented directly in ATen. + + Parameters: + x (Tensor): the tensor to be reshaped. + shape (Tensor): the target shape. + + Returns: + Tensor: the reshaped tensor. + """ return torch._reshape_from_tensor(x, shape) From 1dd2431f863b459b5dfe219c6102640081b95c37 Mon Sep 17 00:00:00 2001 From: Chirag Pandya Date: Tue, 11 Jun 2024 10:23:47 -0700 Subject: [PATCH 652/706] [Test] Add test for only_active flag (#128191) Summary: Add a unit test for the only_active flag to _dump_nccl_trace API call. With this flag, we only expect active records to be returned. Test Plan: Unit test. Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/128191 Approved by: https://github.com/d4l3k --- test/distributed/test_c10d_nccl.py | 32 ++++++++++++++++++++---------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index feaa649e5851..50ec40291cd7 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -3662,7 +3662,8 @@ def test_long(self): @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize("timing_enabled", [True, False]) - def test_trace_while_active(self, timing_enabled): + @parametrize("only_active", [True, False]) + def test_trace_while_active(self, timing_enabled, only_active): if self.rank == self.MAIN_PROCESS_RANK: for c in self.children_pipes: self.assertEqual(c.recv(), "next") @@ -3683,17 +3684,26 @@ def test_trace_while_active(self, timing_enabled): if self.rank != 0: pg.allreduce(a).wait() e.synchronize() - t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace()) + t = pickle.loads( + torch._C._distributed_c10d._dump_nccl_trace(onlyActive=only_active) + ) t = t["entries"] - self.assertEqual(t[-1]["profiling_name"], "nccl:all_reduce") - if self.rank == 0: - self.assertEqual(t[-1]["collective_seq_id"], 1) - self.assertEqual(t[-1]["state"], "completed") - else: - self.assertEqual(t[-1]["collective_seq_id"], 2) - self.assertEqual( - t[-1]["state"], self.started_or_scheduled(timing_enabled) - ) + if only_active: + if self.rank == 0: + self.assertEqual(len(t), 0) + else: + self.assertEqual(len(t), 1) + if not only_active: + if self.rank == 0: + self.assertEqual(t[-1]["profiling_name"], "nccl:all_reduce") + self.assertEqual(t[-1]["collective_seq_id"], 1) + self.assertEqual(t[-1]["state"], "completed") + else: + self.assertEqual(t[-1]["profiling_name"], "nccl:all_reduce") + self.assertEqual(t[-1]["collective_seq_id"], 2) + self.assertEqual( + t[-1]["state"], self.started_or_scheduled(timing_enabled) + ) self.parent.send("next") self.assertEqual("next", self.parent.recv()) From eb567b1f40233667b982f81e3a75deec0fdfd9ca Mon Sep 17 00:00:00 2001 From: Chirag Pandya Date: Tue, 11 Jun 2024 10:23:48 -0700 Subject: [PATCH 653/706] Pass params to dump_nccl_trace_pickle (#128307) Summary: Pass parameters from request to dump_nccl_trace_pickle handler. The supported parameters + value are all lowercase. includecollectives={true, false} includestacktraces={true, false} onlyactive={true, false} Example post is: /handler/dump_nccl_trace_pickle?includecollectives=true&includestacktraces=false&onlyactive=true Test Plan: unit tests Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/128307 Approved by: https://github.com/d4l3k ghstack dependencies: #128191 --- .../distributed/elastic/test_control_plane.py | 42 +++++++++++++++ torch/csrc/distributed/c10d/NCCLUtils.cpp | 51 +++++++++++++++++++ .../distributed/c10d/ProcessGroupNCCL.cpp | 10 ---- .../c10d/control_plane/Handlers.hpp | 3 ++ .../c10d/control_plane/WorkerServer.cpp | 4 ++ 5 files changed, 100 insertions(+), 10 deletions(-) diff --git a/test/distributed/elastic/test_control_plane.py b/test/distributed/elastic/test_control_plane.py index 775b062451b1..971099e32f6d 100644 --- a/test/distributed/elastic/test_control_plane.py +++ b/test/distributed/elastic/test_control_plane.py @@ -80,6 +80,48 @@ def test_dump_nccl_trace_pickle(self) -> None: resp = pool.request("POST", "/handler/dump_nccl_trace_pickle") self.assertEqual(resp.status, 200) out = pickle.loads(resp.data) + self.assertIsInstance(out, dict) + self.assertIn("version", out) + + @requires_cuda + def test_dump_nccl_trace_pickle_with_params(self) -> None: + with local_worker_server() as pool: + # bad key - not lower case + resp = pool.request( + "POST", "/handler/dump_nccl_trace_pickle?includeCollectives=true" + ) + self.assertEqual(resp.status, 400) + # unknown key + resp = pool.request( + "POST", "/handler/dump_nccl_trace_pickle?unknownkey=true" + ) + self.assertEqual(resp.status, 400) + # bad value - not a bool + resp = pool.request( + "POST", "/handler/dump_nccl_trace_pickle?includecollectives=notabool" + ) + self.assertEqual(resp.status, 400) + # bad value - value not lowercase + resp = pool.request( + "POST", "/handler/dump_nccl_trace_pickle?includecollectives=True" + ) + self.assertEqual(resp.status, 400) + # good key and value + resp = pool.request( + "POST", "/handler/dump_nccl_trace_pickle?includecollectives=true" + ) + self.assertEqual(resp.status, 200) + # good key and value + resp = pool.request( + "POST", "/handler/dump_nccl_trace_pickle?includestacktraces=true" + ) + self.assertEqual(resp.status, 200) + # multiple good keys and values + resp = pool.request( + "POST", + "/handler/dump_nccl_trace_pickle?includecollectives=true&includestacktraces=false&onlyactive=true", + ) + self.assertEqual(resp.status, 200) def test_tcp(self) -> None: import requests diff --git a/torch/csrc/distributed/c10d/NCCLUtils.cpp b/torch/csrc/distributed/c10d/NCCLUtils.cpp index db268371ea0f..01edfdaf9292 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.cpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.cpp @@ -1,9 +1,12 @@ #include +#include +#include #include #include #ifdef USE_C10D_NCCL +#include #include #include @@ -288,6 +291,54 @@ float getDurationFromEvent( return ncclStartEvent.elapsed_time(ncclEndEvent); } +control_plane::RegisterHandler dumpHandler{ + "dump_nccl_trace_pickle", + [](const control_plane::Request& req, control_plane::Response& res) { + const auto params = req.params(); + size_t validParamCount = 0; + + // valid params + const std::string includeCollectivesStr = "includecollectives"; + const std::string includeStackTracesStr = "includestacktraces"; + const std::string onlyActiveStr = "onlyactive"; + + std::unordered_map expectedParams = { + {includeCollectivesStr, true}, + {includeStackTracesStr, true}, + {onlyActiveStr, false}}; + + for (const auto& [paramName, paramValue] : params) { + auto it = expectedParams.find(paramName); + if (it != expectedParams.end()) { + validParamCount++; + if (paramValue == "true") { + it->second = true; + } else if (paramValue == "false") { + it->second = false; + } else { + res.setStatus(400); + res.setContent( + "Invalid value for " + paramName + + " valid values are true or false", + "text/plain"); + return; + } + } + } + if (validParamCount < params.size()) { + res.setStatus(400); + res.setContent( + "Invalid parameters - unexpected param passed in", "text/plain"); + return; + } + res.setContent( + dump_nccl_trace( + expectedParams[includeCollectivesStr], + expectedParams[includeStackTracesStr], + expectedParams[onlyActiveStr]), + "application/octet-stream"); + }}; + } // namespace c10d #endif // USE_C10D_NCCL diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index bb9198f22200..158522063ab7 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -28,7 +28,6 @@ #include #include #include -#include #include #include @@ -379,15 +378,6 @@ std::string dump_nccl_trace( } #endif -// TODO(c-p-i-o): add a JSON endpoint. -control_plane::RegisterHandler dumpHandler{ - "dump_nccl_trace_pickle", - [](const control_plane::Request& req, control_plane::Response& res) { - // TODO: c-p-i-o: params from the request need to go to dump_nccl_trace. - res.setContent( - dump_nccl_trace(true, true, false), "application/octet-stream"); - }}; - std::optional)>>& get_cpp_trace_dumper() { static std::optional< diff --git a/torch/csrc/distributed/c10d/control_plane/Handlers.hpp b/torch/csrc/distributed/c10d/control_plane/Handlers.hpp index 0c1063054931..fef4776713e2 100644 --- a/torch/csrc/distributed/c10d/control_plane/Handlers.hpp +++ b/torch/csrc/distributed/c10d/control_plane/Handlers.hpp @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include @@ -14,6 +15,8 @@ class TORCH_API Request { public: virtual ~Request() = default; + virtual const std::multimap& params() const = 0; + virtual const std::string& body() = 0; }; diff --git a/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp b/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp index e4b649d888dd..b99b9210eb54 100644 --- a/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp +++ b/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp @@ -22,6 +22,10 @@ class RequestImpl : public Request { return req_.body; } + const std::multimap& params() const override { + return req_.params; + } + private: const httplib::Request& req_; }; From b79d056e76ac2644e134c016053ec15c119b53f8 Mon Sep 17 00:00:00 2001 From: angelayi Date: Tue, 11 Jun 2024 22:32:08 +0000 Subject: [PATCH 654/706] [export] FIx unflattener for preserving modules containing unused inputs (#128260) Currently unflattener fails if the module its preserving the module signature for contains unused inputs/outputs. This also fixes unflattener issues in D57829276. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128260 Approved by: https://github.com/pianpwk --- test/export/test_unflatten.py | 25 ++++++++++++++++++++++++ torch/export/unflatten.py | 36 ++++++++++++++++++++++++++++------- 2 files changed, 54 insertions(+), 7 deletions(-) diff --git a/test/export/test_unflatten.py b/test/export/test_unflatten.py index 3940cde45234..383287db421a 100644 --- a/test/export/test_unflatten.py +++ b/test/export/test_unflatten.py @@ -312,6 +312,31 @@ def forward(self, x): export_module.module(), unflattened, (torch.randn((2, 3)),) ) + @unittest.skipIf(IS_WINDOWS, "Windows not supported for this test") + def test_unflatten_preserve_with_unused_input(self): + class M1(torch.nn.Module): + def forward(self, x, a, b): + return x + a, b + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.m1 = M1() + + def forward(self, x, y): + a, b = torch.topk(y, 2) + return self.m1(x, a, b)[0] + + ep = torch.export.export( + M(), + (torch.randn(2), torch.randn(5)), + preserve_module_call_signature=("m1",), + strict=False, + ) + ep.graph.eliminate_dead_code() + unflattened = unflatten(ep) + self.compare_outputs(ep.module(), unflattened, (torch.randn(2), torch.randn(5))) + def test_unflatten_wrong_input(self): class Mod(torch.nn.Module): def __init__(self): diff --git a/torch/export/unflatten.py b/torch/export/unflatten.py index 4de95dad2c8d..11075058a0e9 100644 --- a/torch/export/unflatten.py +++ b/torch/export/unflatten.py @@ -731,14 +731,20 @@ def __init__( ) if isinstance(arg, ConstantArgument): continue - flat_arg_node.meta = copy.copy(self.seen_nodes[arg.name].meta) - self.node_to_placeholder[self.seen_nodes[arg.name]] = flat_arg_node + + if arg.name in self.seen_nodes: + flat_arg_node.meta = copy.copy(self.seen_nodes[arg.name].meta) + self.node_to_placeholder[ + self.seen_nodes[arg.name] + ] = flat_arg_node with self.parent.graph.inserting_before(self.parent_call_module): input_nodes: List[Optional[torch.fx.Node]] = [] for input in signature.inputs: if isinstance(input, ConstantArgument) and input.value is None: input_nodes.append(None) + elif input.name not in self.seen_nodes: + input_nodes.append(None) else: assert isinstance(input, (TensorArgument, SymIntArgument)) input_nodes.append( @@ -801,18 +807,32 @@ def finalize_outputs(self): if signature is not None and self.parent is not None: for output in signature.outputs: if isinstance(output, (TensorArgument, SymIntArgument)): - orig_outputs.append(self.seen_nodes[output.name]) + if output.name in self.seen_nodes: + orig_outputs.append(self.seen_nodes[output.name]) + else: + orig_outputs.append(None) else: raise RuntimeError( f"Unsupported data type for output node: {output}" ) + def get_actual_output_node(output): + if output is None: + return None + + seen_node = self.seen_nodes[output.name] + if seen_node in self.node_map: + return self.node_map[seen_node] + elif seen_node in self.node_to_placeholder: + return self.node_to_placeholder[seen_node] + else: + raise RuntimeError( + f"Could not find output node {output}. Graph: {self.graph}" + ) + tree_out_node = _generate_unflatten( self.module, - tuple( - self.node_map[self.seen_nodes[output.name]] - for output in orig_outputs - ), + tuple(get_actual_output_node(output) for output in orig_outputs), signature.out_spec, ) parent_out: Optional[torch.fx.Node] = _generate_flatten( @@ -852,6 +872,8 @@ def finalize_outputs(self): self.parent.node_map[orig_outputs[0]] = parent_out else: for i, orig_output in enumerate(orig_outputs): + if orig_output is None: + continue # Use Proxy to record getitem access. proxy_out = torch.fx.Proxy(parent_out)[i].node # type: ignore[index] proxy_out.meta["val"] = orig_output.meta.get("val") From 447173198b9ff16a908c402b0a077da7489bbb81 Mon Sep 17 00:00:00 2001 From: Andrea Frittoli Date: Tue, 11 Jun 2024 22:42:11 +0000 Subject: [PATCH 655/706] =?UTF-8?q?Add=20docstring=20for=20the=20torch.fx.?= =?UTF-8?q?operator=5Fschemas.create=5Ftype=5Fhint=20func=E2=80=A6=20(#128?= =?UTF-8?q?139)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes: #127916 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128139 Approved by: https://github.com/SherlockNoMad --- torch/fx/operator_schemas.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/torch/fx/operator_schemas.py b/torch/fx/operator_schemas.py index becd1ffcd6f4..04be7d139da4 100644 --- a/torch/fx/operator_schemas.py +++ b/torch/fx/operator_schemas.py @@ -184,6 +184,17 @@ def get_signature_for_torch_op(op : Callable, return_schemas : bool = False): @compatibility(is_backward_compatible=False) def create_type_hint(x): + """ + Produces a type hint for the given argument. + + The :func:`create_type_hint` looks for a type hint compatible with the input argument `x`. + + If `x` is a `list` or `tuple`, it looks for an object in the list whose type is a superclass + of the rest, and uses that as `base_type` for the `List` or `Tuple` to be returned. + If no such object is found, it defaults to `List[Any]`. + + If `x` is neither a `list` nor a `tuple`, it returns `x`. + """ try: if isinstance(x, (list, tuple)): # todo(chilli): Figure out the right way for mypy to handle this From 94fea82d6646c3f55f80d6a4e84a4104e18387f4 Mon Sep 17 00:00:00 2001 From: Andrew Hoblitzell Date: Tue, 11 Jun 2024 22:42:33 +0000 Subject: [PATCH 656/706] init sub comment (#128082) Fixes #127905 ### Description Add docstring to torch/onnx/symbolic_opset9.py:sigmoid function ### Checklist - [x] The issue that is being fixed is referred in the description - [x] Only one issue is addressed in this pull request - [x] Labels from the issue that this PR is fixing are added to this pull request - [x] No unnecessary issues are included into this pull request Pull Request resolved: https://github.com/pytorch/pytorch/pull/128082 Approved by: https://github.com/titaiwangms --- torch/onnx/symbolic_opset9.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py index 21ce701d26fe..b4c937ed3f66 100644 --- a/torch/onnx/symbolic_opset9.py +++ b/torch/onnx/symbolic_opset9.py @@ -371,6 +371,21 @@ def add(g: jit_utils.GraphContext, self, other, alpha=None): @_onnx_symbolic("aten::sub") @_beartype.beartype def sub(g: jit_utils.GraphContext, self, other, alpha=None): + """ + Consumes sub function and returns the corresponding ONNX operator. + + This function is not meant to be called directly by the user. + + Args: + g (GraphContext): The graph context. + self (Tensor): The first operand. + other (Tensor): The second operand. + alpha (Optional[Tensor]): A scaling factor to apply to the second operand. + If `alpha` is not provided, it defaults to 1. + + Returns: + ONNX operator + """ if alpha and symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha)) != 1: other = g.op("Mul", other, alpha) return g.op("Sub", self, other) From c9c1fed06549c7a0e7eb69f3aa8220c9e24f7629 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 11 Jun 2024 23:34:03 +0000 Subject: [PATCH 657/706] Revert "Flip default value for mypy disallow_untyped_defs [10+2/11] (#128374)" This reverts commit c13e03c87428b986972a48d8fc78dbffc2579f63. Reverted https://github.com/pytorch/pytorch/pull/128374 on behalf of https://github.com/clee2000 due to sorry I need to revert this in order to revert something else, to remerge, just rebase and fix the merge conflict ([comment](https://github.com/pytorch/pytorch/pull/128374#issuecomment-2161772864)) --- torch/_C/return_types.pyi.in | 1 - torch/distributed/_tensor/examples/display_sharding_example.py | 1 - torch/nn/functional.pyi.in | 1 - torch/utils/_sympy/numbers.py | 1 - torch/utils/data/datapipes/datapipe.py | 1 - torch/utils/data/datapipes/datapipe.pyi.in | 1 - 6 files changed, 6 deletions(-) diff --git a/torch/_C/return_types.pyi.in b/torch/_C/return_types.pyi.in index ce37323f7b33..458a076d7bfe 100644 --- a/torch/_C/return_types.pyi.in +++ b/torch/_C/return_types.pyi.in @@ -1,5 +1,4 @@ # ${generated_comment} -# mypy: allow-untyped-defs from typing import ( Any, diff --git a/torch/distributed/_tensor/examples/display_sharding_example.py b/torch/distributed/_tensor/examples/display_sharding_example.py index 1ce3962b9545..0e32ed074534 100644 --- a/torch/distributed/_tensor/examples/display_sharding_example.py +++ b/torch/distributed/_tensor/examples/display_sharding_example.py @@ -1,4 +1,3 @@ -# mypy: allow-untyped-defs import os from typing import Any, Dict diff --git a/torch/nn/functional.pyi.in b/torch/nn/functional.pyi.in index 9dec24809e24..5bb847a0a727 100644 --- a/torch/nn/functional.pyi.in +++ b/torch/nn/functional.pyi.in @@ -1,4 +1,3 @@ -# mypy: allow-untyped-defs from typing import ( Any, Callable, diff --git a/torch/utils/_sympy/numbers.py b/torch/utils/_sympy/numbers.py index 6a93255df852..89dac14fddf3 100644 --- a/torch/utils/_sympy/numbers.py +++ b/torch/utils/_sympy/numbers.py @@ -1,4 +1,3 @@ -# mypy: allow-untyped-defs import mpmath.libmp as mlib # type: ignore[import-untyped] import sympy from sympy import Expr diff --git a/torch/utils/data/datapipes/datapipe.py b/torch/utils/data/datapipes/datapipe.py index 8add81987837..1c99fe79e406 100644 --- a/torch/utils/data/datapipes/datapipe.py +++ b/torch/utils/data/datapipes/datapipe.py @@ -1,4 +1,3 @@ -# mypy: allow-untyped-defs import functools import pickle from typing import Dict, Callable, Optional, TypeVar, Generic, Iterator diff --git a/torch/utils/data/datapipes/datapipe.pyi.in b/torch/utils/data/datapipes/datapipe.pyi.in index 4d03665d5d66..6b3cbe34b46a 100644 --- a/torch/utils/data/datapipes/datapipe.pyi.in +++ b/torch/utils/data/datapipes/datapipe.pyi.in @@ -1,4 +1,3 @@ -# mypy: allow-untyped-defs # This base template ("datapipe.pyi.in") is generated from mypy stubgen with minimal editing for code injection # The output file will be "datapipe.pyi". This is executed as part of torch/CMakeLists.txt # Note that, for mypy, .pyi file takes precedent over .py file, such that we must define the interface for other From 5d8c7f39d46699d8f8e92512309ea3499a29c08a Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 11 Jun 2024 23:36:08 +0000 Subject: [PATCH 658/706] Revert "Introduce int_oo (#127693)" This reverts commit 9cab5987bdeb66df8efbc581b3469bfe300e168c. Reverted https://github.com/pytorch/pytorch/pull/127693 on behalf of https://github.com/clee2000 due to sorry executorch CI is a bit weird regarding pins, I'll make a chat with mergen with the choices of what to do and how it'll affect executorch CI, reverting for now to prevent more divergences in the meantime ([comment](https://github.com/pytorch/pytorch/pull/127693#issuecomment-2161775400)) --- test/dynamo/test_exc.py | 9 +- test/dynamo/test_export.py | 1 + test/dynamo/test_misc.py | 12 +- test/export/test_export.py | 15 +- test/onnx/test_fx_to_onnx_with_onnxruntime.py | 4 + test/test_dynamic_shapes.py | 11 - test/test_proxy_tensor.py | 4 +- test/test_sympy_utils.py | 70 ---- torch/_decomp/decompositions.py | 9 +- ...runtime_assertions_for_constraints_pass.py | 5 +- torch/_export/serde/serialize.py | 11 +- torch/_inductor/graph.py | 14 +- torch/export/dynamic_shapes.py | 30 +- torch/fx/experimental/symbolic_shapes.py | 83 ++-- torch/fx/passes/runtime_assert.py | 5 +- torch/utils/_sympy/functions.py | 124 ++---- torch/utils/_sympy/interp.py | 27 +- torch/utils/_sympy/numbers.py | 394 ------------------ torch/utils/_sympy/value_ranges.py | 63 +-- 19 files changed, 145 insertions(+), 746 deletions(-) delete mode 100644 torch/utils/_sympy/numbers.py diff --git a/test/dynamo/test_exc.py b/test/dynamo/test_exc.py index b7b17ed4a1dd..953e8ecd0a35 100644 --- a/test/dynamo/test_exc.py +++ b/test/dynamo/test_exc.py @@ -253,6 +253,7 @@ def fn(x, shape): ==> (>= 0 s1) ==> (>= 0 s2) ==> (>= 0 s3) + ==> (>= 9223372036854775806 s0) Failed Source Expressions: ==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""", @@ -286,14 +287,14 @@ def fn(x, shape): Model: ==> L['shape'][0]: 1 ==> L['shape'][1]: 1 - ==> L['shape'][2]: 0 + ==> L['shape'][2]: 2 ==> L['x'].size()[0]: 3 ==> L['x'].storage_offset(): 0 ==> L['x'].stride()[0]: 1 ==> s0: 3 ==> s1: 1 ==> s2: 1 - ==> s3: 0 + ==> s3: 2 Assertions: ==> (== 0 L['x'].storage_offset()) @@ -317,6 +318,10 @@ def fn(x, shape): ==> (== L['shape'][2] s3) ==> (== L['x'].size()[0] s0) ==> (> s0 0) + ==> (>= 9223372036854775806 s0) + ==> (>= 9223372036854775807 s1) + ==> (>= 9223372036854775807 s2) + ==> (>= 9223372036854775807 s3) Failed Source Expressions: ==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""", diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py index 776e8ef85cbf..dbf983faabb7 100644 --- a/test/dynamo/test_export.py +++ b/test/dynamo/test_export.py @@ -3473,6 +3473,7 @@ def forward(self, pred, x): ] false_guard_code = [ "Ne(cast_symbool_to_symint_guardless(L['pred']), 1)", + "-9223372036854775808 <= cast_symbool_to_symint_guardless(L['pred'])", ] test_symbool_guards( f, diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 68c38089c53a..02f7c68aa1a9 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -9309,7 +9309,7 @@ def test_shape_env_equal_create_symbolic_sizes_strides_storage_offset(self): > Left: {0: 0, 1: 1, 2: s1, 3: s0} > Right: {0: 0, 1: 1} ==> var_to_range: values don't match. - > Left: {s0: VR[2, int_oo], s1: VR[2, int_oo]} + > Left: {s0: VR[2, 9223372036854775806], s1: VR[2, 9223372036854775806]} > Right: {} ==> var_to_sources: values don't match. > Left: {s0: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=, idx=0)], s1: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=, idx=1)]} @@ -9343,7 +9343,7 @@ def test_shape_env_equal_unbacked(self): > Left: 2 > Right: 0 ==> var_to_range: values don't match. - > Left: {u0: VR[-int_oo, int_oo], u1: VR[0, 1], zuf0: VR[-oo, oo]} + > Left: {u0: VR[-9223372036854775808, 9223372036854775807], u1: VR[0, 1], zuf0: VR[-oo, oo]} > Right: {} """, ) @@ -9420,8 +9420,8 @@ def test_shape_env_equal_evaluate_expr_replacement(self): > Left: {s0: 3} > Right: {} ==> var_to_range: values don't match. - > Left: {s0: VR[3, 3], s1: VR[2, int_oo]} - > Right: {s0: VR[2, int_oo], s1: VR[2, int_oo]} + > Left: {s0: VR[3, 3], s1: VR[2, 9223372036854775806]} + > Right: {s0: VR[2, 9223372036854775806], s1: VR[2, 9223372036854775806]} """, ) self._replay_and_check(main) @@ -9458,8 +9458,8 @@ def test_shape_env_equal_evaluate_expr_refinement(self): > Left: {_assert, ge, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} > Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} ==> var_to_range: values don't match. - > Left: {s0: VR[3, int_oo], s1: VR[2, int_oo]} - > Right: {s0: VR[2, int_oo], s1: VR[2, int_oo]} + > Left: {s0: VR[3, 9223372036854775806], s1: VR[2, 9223372036854775806]} + > Right: {s0: VR[2, 9223372036854775806], s1: VR[2, 9223372036854775806]} """, ) self._replay_and_check(main) diff --git a/test/export/test_export.py b/test/export/test_export.py index c3458ff8003a..19acbbca39f1 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -201,19 +201,6 @@ def forward(self, x): dynamic_shapes={"x": {0: dim_x}}, ) - def test_export_slice_maxsize(self): - class Slice(torch.nn.Module): - def forward(self, *args): - return torch.ops.aten.slice.Tensor(*args) - - inp = (torch.rand((10, 3, 224, 224)), 0, 0, 9223372036854775807) - dynamic_shapes = (({0: Dim("dim")}, None, None, None),) - torch.export.export( - Slice(), - inp, - dynamic_shapes=dynamic_shapes, - ) - def test_export_constraints_error(self): class ConflictingConstraints(torch.nn.Module): def forward(self, x): @@ -5196,7 +5183,7 @@ def forward(self, x): } export(f, (inputs,), dynamic_shapes=dynamic_shapes) - def test_disable_forced_specializations_ok(self): + def test_disable_forced_specializations(self): # check that _disable_forced_specializations and _allow_complex_guards_as_runtime_asserts flags # both behave correctly, avoiding forced specializations and deferring to runtime. # case 1: modulo guards diff --git a/test/onnx/test_fx_to_onnx_with_onnxruntime.py b/test/onnx/test_fx_to_onnx_with_onnxruntime.py index f8154a149b41..0f0e01bc0dc2 100644 --- a/test/onnx/test_fx_to_onnx_with_onnxruntime.py +++ b/test/onnx/test_fx_to_onnx_with_onnxruntime.py @@ -633,6 +633,10 @@ def forward(self, x): func, (torch.randn(3, 4),) ) + @pytorch_test_common.xfail_if_model_type_is_exportedprogram( + error_message="Unsupported FX nodes: {'call_function': ['aten._assert_async.msg']}.", + reason="https://github.com/pytorch/pytorch/issues/112622", + ) def test_operator_with_scalar_output(self): class Foo(torch.nn.Module): def forward(self, x, y): diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index 156b23742900..60ce1fb764ec 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -385,17 +385,6 @@ def test_size_expressions(self): self.assertTrue(str(expand_x.shape[1]), str(x.shape[0])) self.assertTrue(str(expand_x.shape[1]), str(result.shape[0])) - def test_floordiv_static(self): - shape_env = ShapeEnv() - s0 = create_symint(shape_env, 8) - # This was extracted from - # python test/inductor/test_cuda_cpp_wrapper.py -k - # DynamicShapesCudaWrapperCudaTests.test_insignificant_strides_cuda_dynamic_shapes_cuda_wrapper - bool(s0 % 2 == 0) - bool(s0 % (s0 // 2) == 0) - bool(2 * (s0 // 2) == s0) - self.assertTrue(statically_known_true(s0 // (s0 // 2) == 2)) - def test_numel(self): shape_env = ShapeEnv() x = create_symbolic_tensor("x", torch.randn(5), shape_env) diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 3985eea7d5b9..04483ffba0fc 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1201,9 +1201,7 @@ def f(src_tokens): batch_size = 4 src_tokens = torch.randint(1, vocab_size, (batch_size, prompt_size)) gm = make_fx(f, tracing_mode="symbolic")(src_tokens) - # Guards to rule out batch_size == sys.maxsize (wobbling between 2 and - # 1 ok) - self.assertEqual(len(gm.shape_env.guards), 1) + self.assertEqual(len(gm.shape_env.guards), 0) @unittest.skipIf(not HAS_CUDA, 'CUDA-only test') def test_cpu_scalar_cuda(self): diff --git a/test/test_sympy_utils.py b/test/test_sympy_utils.py index 569c06291331..8b16b2c620fd 100644 --- a/test/test_sympy_utils.py +++ b/test/test_sympy_utils.py @@ -1,7 +1,6 @@ # Owner(s): ["oncall: pt2"] import itertools -import math import sys import sympy @@ -20,7 +19,6 @@ from torch.utils._sympy.reference import ReferenceAnalysis, PythonReferenceAnalysis from torch.utils._sympy.interp import sympy_interp from torch.utils._sympy.singleton_int import SingletonInt -from torch.utils._sympy.numbers import int_oo, IntInfinity, NegativeIntInfinity from sympy.core.relational import is_ge, is_le, is_gt, is_lt import functools import torch.fx as fx @@ -124,74 +122,6 @@ def generate_range(vals): yield ValueRanges(a1, a2) -class TestNumbers(TestCase): - def test_int_infinity(self): - self.assertIsInstance(int_oo, IntInfinity) - self.assertIsInstance(-int_oo, NegativeIntInfinity) - self.assertTrue(int_oo.is_integer) - # is tests here are for singleton-ness, don't use it for comparisons - # against numbers - self.assertIs(int_oo + int_oo, int_oo) - self.assertIs(int_oo + 1, int_oo) - self.assertIs(int_oo - 1, int_oo) - self.assertIs(-int_oo - 1, -int_oo) - self.assertIs(-int_oo + 1, -int_oo) - self.assertIs(-int_oo + (-int_oo), -int_oo) - self.assertIs(-int_oo - int_oo, -int_oo) - self.assertIs(1 + int_oo, int_oo) - self.assertIs(1 - int_oo, -int_oo) - self.assertIs(int_oo * int_oo, int_oo) - self.assertIs(2 * int_oo, int_oo) - self.assertIs(int_oo * 2, int_oo) - self.assertIs(-1 * int_oo, -int_oo) - self.assertIs(-int_oo * int_oo, -int_oo) - self.assertIs(2 * -int_oo, -int_oo) - self.assertIs(-int_oo * 2, -int_oo) - self.assertIs(-1 * -int_oo, int_oo) - self.assertIs(int_oo / 2, sympy.oo) - self.assertIs(-(-int_oo), int_oo) # noqa: B002 - self.assertIs(abs(int_oo), int_oo) - self.assertIs(abs(-int_oo), int_oo) - self.assertIs(int_oo ** 2, int_oo) - self.assertIs((-int_oo) ** 2, int_oo) - self.assertIs((-int_oo) ** 3, -int_oo) - self.assertEqual(int_oo ** -1, 0) - self.assertEqual((-int_oo) ** -1, 0) - self.assertIs(int_oo ** int_oo, int_oo) - self.assertTrue(int_oo == int_oo) - self.assertFalse(int_oo != int_oo) - self.assertTrue(-int_oo == -int_oo) - self.assertFalse(int_oo == 2) - self.assertTrue(int_oo != 2) - self.assertFalse(int_oo == sys.maxsize) - self.assertTrue(int_oo >= sys.maxsize) - self.assertTrue(int_oo >= 2) - self.assertTrue(int_oo >= -int_oo) - - def test_relation(self): - self.assertIs(sympy.Add(2, int_oo), int_oo) - self.assertFalse(-int_oo > 2) - - def test_lt_self(self): - self.assertFalse(int_oo < int_oo) - self.assertIs(min(-int_oo, -4), -int_oo) - self.assertIs(min(-int_oo, -int_oo), -int_oo) - - def test_float_cast(self): - self.assertEqual(float(int_oo), math.inf) - self.assertEqual(float(-int_oo), -math.inf) - - def test_mixed_oo_int_oo(self): - # Arbitrary choice - self.assertTrue(int_oo < sympy.oo) - self.assertFalse(int_oo > sympy.oo) - self.assertTrue(sympy.oo > int_oo) - self.assertFalse(sympy.oo < int_oo) - self.assertIs(max(int_oo, sympy.oo), sympy.oo) - self.assertTrue(-int_oo > -sympy.oo) - self.assertIs(min(-int_oo, -sympy.oo), -sympy.oo) - - class TestValueRanges(TestCase): @parametrize("fn", UNARY_OPS) @parametrize("dtype", ("int", "float")) diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index 7ebc69462fa1..7c9d342ea0f0 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -734,11 +734,6 @@ def slice_forward( end: Optional[int] = None, step: int = 1, ): - from torch.fx.experimental.symbolic_shapes import ( - guard_size_oblivious, - statically_known_true, - ) - ndim = self.dim() if ndim == 0: raise RuntimeError("slice() cannot be applied to a 0-dim tensor.") @@ -765,9 +760,7 @@ def slice_forward( if end_val < start_val: end_val = start_val - elif statically_known_true(end_val == sys.maxsize) or guard_size_oblivious( - end_val > sizes[dim] - ): + elif end_val > sizes[dim]: end_val = sizes[dim] storage_offset = self.storage_offset() + start_val * strides[dim] diff --git a/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py b/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py index e3bfb0f3de55..44f0ea270212 100644 --- a/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py +++ b/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py @@ -10,7 +10,6 @@ import torch import torch.fx from torch.utils._sympy.value_ranges import ValueRanges -from torch.utils._sympy.numbers import int_oo from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols from torch.fx.passes.infra.pass_base import PassBase, PassResult @@ -24,9 +23,9 @@ class InputDim(NamedTuple): def _convert_to_int(val): # Convert simple sympy Integers into concrete int - if val in (sympy.oo, int_oo): + if val == sympy.oo: return math.inf - if val in (-sympy.oo, -int_oo): + if val == -sympy.oo: return -math.inf if isinstance(val, sympy.Integer): return int(val) diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index ff729ddb3c5c..f8fdc1011b52 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -42,7 +42,6 @@ from torch.utils import _pytree as pytree from torch.utils._pytree import treespec_dumps, treespec_loads from torch.utils._sympy.value_ranges import ValueRanges -from torch.utils._sympy.numbers import int_oo from .schema import ( # type: ignore[attr-defined] Argument, @@ -322,9 +321,9 @@ def deserialize_torch_artifact(serialized: Union[Dict[str, Any], Tuple[Any, ...] def _sympy_int_to_int(val: sympy.Expr, adjust: str): # Convert simple sympy Integers into concrete int - if val in (sympy.oo, int_oo): + if val == sympy.oo: return math.inf - if val in (-sympy.oo, -int_oo): + if val == -sympy.oo: return -math.inf if isinstance(val, sympy.Integer): return int(val) @@ -347,9 +346,9 @@ def _sympy_int_to_int(val: sympy.Expr, adjust: str): def _int_to_sympy_int(val) -> sympy.Expr: # Convert concrete int into simple sympy Integers if val == math.inf: - return int_oo + return sympy.oo if val == -math.inf: - return -int_oo + return -sympy.oo return sympy.Integer(val) @@ -1827,7 +1826,7 @@ def deserialize( self.symbol_name_to_range = {} if symbol_name_to_range: for k, vr in symbol_name_to_range.items(): - lower = vr.lower + lower = int(vr.lower) if vr.upper >= 2: # max is >= 2, not sym bool range lower = max(2, lower) self.symbol_name_to_range[k] = symbolic_shapes.ValueRanges(_int_to_sympy_int(lower), vr.upper) diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 19e81b236ad9..f2bdf22e2d96 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -42,7 +42,6 @@ SymTypes, ) from torch.utils._mode_utils import no_dispatch -from torch.utils._sympy.numbers import int_oo from . import config, ir from .codegen.common import ( @@ -1428,21 +1427,18 @@ def format_buffers(): vr = shape_env.var_to_range[i0] if not shape_env._default_unspecified_value_range().issubset(vr): - def is_convertible(s): - if s in (int_oo, -int_oo): - return False + def convert(s): try: - int(s) - return True + return int(s) except TypeError: - return False + return None - if is_convertible(vr.lower): + if (lower := convert(vr.lower)) is not None: self.register_buffer( ir.AssertScalar(i0 >= vr.lower, f"{i0} >= {vr.lower}"), set_name=True, ) - if is_convertible(vr.upper): + if (upper := convert(vr.upper)) is not None: self.register_buffer( ir.AssertScalar(i0 <= vr.upper, f"{i0} <= {vr.upper}"), set_name=True, diff --git a/torch/export/dynamic_shapes.py b/torch/export/dynamic_shapes.py index 8572e069f536..a5ce066faa47 100644 --- a/torch/export/dynamic_shapes.py +++ b/torch/export/dynamic_shapes.py @@ -1,4 +1,5 @@ # mypy: allow-untyped-defs +import builtins import dataclasses import inspect import sys @@ -40,11 +41,9 @@ class _Dim(type): @staticmethod def readable(name, min_, max_): - from torch.utils._sympy.numbers import int_oo - if min_ == 2: min_ = None - if max_ == int_oo: + if max_ == sys.maxsize - 1: max_ = None if min_ is None and max_ is None: return f"Dim('{name}')" @@ -141,11 +140,6 @@ def min(self): # TODO(avik): use sympy value range analysis instead? from sympy import Integer - from torch.utils._sympy.numbers import int_oo - - if self.root.min is -int_oo: # type: ignore[attr-defined] - return -int_oo # fn not needed cuz increasing - _min_symint = self.fn(Integer(self.root.min)) # type: ignore[attr-defined] root = self.root # type: ignore[attr-defined] assert _min_symint >= 0, ( @@ -161,11 +155,6 @@ def max(self): # TODO(avik): use sympy value range analysis instead? from sympy import Integer - from torch.utils._sympy.numbers import int_oo - - if self.root.max is int_oo: # type: ignore[attr-defined] - return int_oo # fn not needed cuz increasing - _max_symint = self.fn(Integer(self.root.max)) # type: ignore[attr-defined] root = self.root # type: ignore[attr-defined] assert _max_symint <= sys.maxsize - 1, ( @@ -201,10 +190,8 @@ def Dim(name: str, *, min: Optional[int] = None, max: Optional[int] = None): Returns: A type that can be used in dynamic shape specifications for tensors. """ - from torch.utils._sympy.numbers import int_oo - _min = 0 if min is None else min - _max = int_oo if max is None else max + _max = sys.maxsize - 1 if max is None else builtins.min(max, sys.maxsize - 1) assert _max > _min, f"Cannot create Dim with inconsistent min={min}, max={max}" dim = _Dim(name, (int,), {"min": _min, "max": _max}) dim.__module__ = getattr( @@ -282,11 +269,10 @@ class _Constraint(_ConstraintTarget, metaclass=_ConstraintFactory): def _clone_with_range(self, lower=0, upper=None): # Import sympy locally from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint - from torch.utils._sympy.numbers import int_oo from torch.utils._sympy.value_ranges import ValueRanges if upper is None: - upper = int_oo + upper = sys.maxsize - 1 constraint_range = StrictMinMaxConstraint( vr=self.constraint_range.vr & ValueRanges(lower=lower, upper=upper), @@ -517,14 +503,15 @@ def dynamic_dim(t: torch.Tensor, index: int, debug_name: Optional[str] = None): # Import sympy locally from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint - from torch.utils._sympy.numbers import int_oo from torch.utils._sympy.value_ranges import ValueRanges return _create_constraint( weakref.ref(t), id(t), index, - StrictMinMaxConstraint(vr=ValueRanges(lower=0, upper=int_oo), warn_only=False), + StrictMinMaxConstraint( + vr=ValueRanges(lower=0, upper=sys.maxsize - 1), warn_only=False + ), debug_name=debug_name, ) @@ -738,7 +725,6 @@ def to_constraint(dim, tensor, i): import sympy from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint - from torch.utils._sympy.numbers import int_oo from torch.utils._sympy.solve import try_solve from torch.utils._sympy.value_ranges import ValueRanges @@ -813,7 +799,7 @@ def root_value(): constraint = dynamic_dim(tensor, i, debug_name=dim.__name__) if dim.min != 0: constraint = constraint >= dim.min - if dim.max != int_oo: + if dim.max != sys.maxsize - 1: constraint = constraint <= dim.max return constraint diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 800b92b2e07b..bf21ef7ffb2c 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -65,7 +65,6 @@ FloorDiv, Mod, PythonMod, IsNonOverlappingAndDenseIndicator, CleanDiv, FloorToInt, CeilToInt ) from torch.utils._sympy.solve import try_solve -from torch.utils._sympy.numbers import int_oo from torch.utils._sympy.value_ranges import bound_sympy, SymPyValueRangeAnalysis, ValueRanges, ValueRangeError from torch.utils._sympy.singleton_int import SingletonInt from torch.utils._traceback import format_frame, CapturedTraceback @@ -873,9 +872,9 @@ def constrain_range(a, *, min: Optional[int], max: Optional[int] = None): for N=1. """ if min is None: - min = -int_oo + min = -sys.maxsize - 1 if max is None: - max = int_oo + max = sys.maxsize - 1 if max < min: raise ValueError( @@ -1384,7 +1383,6 @@ def cast_symbool_to_symint_guardless(symbool: torch.SymBool) -> torch.SymInt: 'PythonMod': operator.mod, 'FloorDiv': operator.floordiv, 'TrueDiv': operator.truediv, - 'PowByNatural': operator.pow, 'IsNonOverlappingAndDenseIndicator': eval_is_non_overlapping_and_dense, 'floor': math.floor, 'ceiling': math.ceil, @@ -1997,7 +1995,7 @@ def _check_same_range(c, dim): (dim.min < 2 and c.get("min", 2) == 2) or dim.min == c.get("min", 2) ) # let pass if analysis min = 2 and specified min = 0/1 - and dim.max == c.get("max", int_oo) + and dim.max == c.get("max", sys.maxsize - 1) ) # 1) newly introduced roots @@ -2020,7 +2018,7 @@ def _check_same_range(c, dim): modulus, remainder = sympy.polys.polytools.div(c["eq"], root) c_min = c.get("min", 2) min_ = math.ceil((c_min - remainder) / modulus) - c_max = c.get("max", int_oo) + c_max = c.get("max", sys.maxsize - 1) max_ = math.floor((c_max - remainder) / modulus) # create result & dim results[str(root)] = {"min": min_, "max": max_} @@ -2768,7 +2766,7 @@ def _constrain_range_for_size(self, a: sympy.Symbol, min: Optional[int] = None, if min is None: min = 0 if max is None: - max = int_oo + max = sys.maxsize - 1 if max < min: raise ValueError( @@ -4097,7 +4095,7 @@ def issue_guard(guard: ShapeGuard) -> None: assert sources bounds = [] - if r.lower not in (-sympy.oo, -int_oo): + if r.lower != -sympy.oo: if any(is_dim(source) for source in sources): self.dim_constraints.add(sympy.Ge(symbol, r.lower)) # Only print lower bound in simplified mode if it is not the @@ -4105,7 +4103,14 @@ def issue_guard(guard: ShapeGuard) -> None: if not _simplified or r.lower != self._default_value_range().lower: bounds.append(str(r.lower)) bounds.append(source_ref(sources[0])) - if r.upper not in (sympy.oo, int_oo): + # NB: This looks like an off-by-one error but it's not: the + # upper bound may be sys.maxsize - 1 because we intentionally + # exclude sys.maxsize from our bounds to deal with direct + # == INT_MAX guards, but it's still dumb to actually test it. + # Note that you can be off by a pretty large constant and it + # won't matter because sizes in practice will be no where near + # the 64-bit limit. + if r.upper != sympy.oo and r.upper < sys.maxsize - 1: if any(is_dim(source) for source in sources): self.dim_constraints.add(sympy.Le(symbol, r.upper)) # nontrivial upper bound is always interesting @@ -4117,8 +4122,9 @@ def issue_guard(guard: ShapeGuard) -> None: constraints = symbol_to_constraints[symbol] for c in constraints: if isinstance(c, StrictMinMaxConstraint): - # TODO: With int_oo, I think this condition is a noop - # now + # NB: By default, we have a restrictive range + # 2 <= s0 <= sys.maxsize - 1. But export users generally + # expect to be able to specify nice ranges like [0, oo] if not (c.vr & self._default_value_range()).issubset(r): source = sources[0] @@ -4191,9 +4197,9 @@ def issue_guard(guard: ShapeGuard) -> None: # Reason: '_maybe_evaluate_static' may eliminate guards based on the # refined value ranges. for sym, vr in self.var_to_range.items(): - if vr.lower not in (-sympy.oo, -int_oo): + if vr.lower != -sympy.oo: self._add_target_expr(sympy.Le(vr.lower, sym)) - if vr.upper not in (sympy.oo, int_oo): + if vr.upper != sympy.oo: self._add_target_expr(sympy.Le(sym, vr.upper)) # Before validating, populate the input of the validator with the @@ -4325,14 +4331,9 @@ def bound_sympy(self, expr: sympy.Expr, size_oblivious: bool = False) -> ValueRa var_to_range = {x: self.var_to_range.get(x, None) for x in expr.free_symbols} if size_oblivious: # Clamp values of size-like variables - # NB: discarding the old upper bound in intentional, per - # https://github.com/pytorch/pytorch/pull/123675 for x in self.size_like & var_to_range.keys(): if var_to_range[x] is not None: - # NB: do NOT set upper to 2 ** 48, we're using this solely - # to determine if we can do size-like replacement, the - # upper bound is irrelevant here - var_to_range[x] = ValueRanges(2, int_oo) + var_to_range[x] = ValueRanges(2, sys.maxsize - 1) assert var_to_range[x].is_int return bound_sympy(expr, var_to_range) @@ -4450,25 +4451,18 @@ def _maybe_evaluate_static( vr = self._default_unspecified_value_range() if size_oblivious and k in self.size_like: lower = max(2, vr.lower) - # Clamping size-oblivious to some quantity below sys.maxsize - # helps us determine that f(u0) != sys.maxsize, which is a - # test that is looking for sys.maxsize as a sentinel, but you - # don't really want to worry about it for unbacked SymInts. - # This is similar to the flavor where size oblivious omits - # 0/1, it changes semantics but in a benign way. - upper = min(2 ** 48, vr.upper) # This is a bit dodgy: what this means is that there was a # size-like unbacked symbol whose upper bound < 2. This # causes... problems. - if lower <= upper: - vr = ValueRanges(lower, upper) + if lower <= vr.upper: + vr = ValueRanges(lower, vr.upper) else: lower = vr.lower # Don't do anything if we don't have a nontrivial lower bound # Also don't do anything if we asked only to simplify unbacked # SymInt if ( - lower is -int_oo or + lower < (-sys.maxsize - 1) // 2 or (unbacked_only and k in self.var_to_val) or not vr.is_int ): @@ -4724,6 +4718,21 @@ def _set_replacement(self, a: "sympy.Symbol", tgt: "sympy.Expr", msg: str) -> No if a in self.var_to_range: src_bound = self.var_to_range[a] + # If you have x in [2, maxint], then 2*x in [4, 2*maxint]. + # But we don't really care that the max bound says we can + # go beyond the maximum integer size, because we aren't + # using bigints anyway. Arguably, ValueRanges should know + # to do this truncation automaticaly (to avoid doing + # bigint compute in range analysis), but right now it doesn't + # so we need to get rid of some unnecessary precision. + int_range = ValueRanges(-sys.maxsize - 1, sys.maxsize - 1) + + def issubset(x, y): + if x.is_int and y.is_int: + return (x & int_range).issubset(y & int_range) + else: + return x.issubset(y) + # First, refine the value range of a based on the computed value range # of tgt. This is always OK to do, even if we decide not to do the # substitution in the end. This might be a no-op, if a already has @@ -4736,7 +4745,7 @@ def _set_replacement(self, a: "sympy.Symbol", tgt: "sympy.Expr", msg: str) -> No # - the source bound non-trivially improves over what we get out of # the existing bounds. # - the replacement is univariate and we can invert the tgt expression - if not tgt_bound.issubset(src_bound) and len(tgt.free_symbols) == 1: + if not issubset(tgt_bound, src_bound) and len(tgt.free_symbols) == 1: b = next(iter(tgt.free_symbols)) # Try to invert the equality r = try_solve(sympy.Eq(a, tgt), b, floordiv_inequality=False) @@ -4751,7 +4760,7 @@ def _set_replacement(self, a: "sympy.Symbol", tgt: "sympy.Expr", msg: str) -> No b_bound = ValueRanges(CeilToInt(rat_b_bound.lower), FloorToInt(rat_b_bound.upper)) self._update_var_to_range(b, b_bound) tgt_bound = self.bound_sympy(tgt) - assert tgt_bound.issubset(src_bound) + assert issubset(tgt_bound, src_bound) # TODO: Should we propagate size-like-ness? # @@ -4789,13 +4798,13 @@ def _set_replacement(self, a: "sympy.Symbol", tgt: "sympy.Expr", msg: str) -> No # - If the variable is unbacked, only substitute if the substitution # would preserve the bounds also under size-like-ness conditions. - if not tgt_bound.issubset(src_bound): + if not issubset(tgt_bound, src_bound): self.log.debug("skipped set_replacement %s = %s (%s) [%s not subset of %s]", a, tgt, msg, tgt_bound, src_bound) return elif a in self.size_like: tgt_bound_so = self.bound_sympy(tgt, size_oblivious=True) src_bound_so = self.bound_sympy(a, size_oblivious=True) - if not tgt_bound_so.issubset(src_bound_so): + if not issubset(tgt_bound_so, src_bound_so): self.log.debug("skipped set_replacement %s = %s (%s) " "[%s not subset of %s (size-oblivious conditions)]", a, tgt, msg, tgt_bound_so, src_bound_so) return @@ -4880,7 +4889,6 @@ def _smart_symbol_sort(x): has_only_ephemeral_sources = ( x in self.var_to_sources and all(s.is_ephemeral() for s in self.var_to_sources[x]) ) - # NB: size_hint is int, not sympy.Expr, do not use int_oo here size = self.size_hint(x, allow_none=True) or sys.maxsize name = x.name # 1 puts ephemeral sourced symbols first when sorting in reverse @@ -4977,12 +4985,15 @@ def trivial_solve(lhs, rhs): return # See: Note - On 0/1 specialization + # NB: sys.maxsize is NOT allowed for sizes, because we use MAX_INT + # as a sentinel sometimes. Your sizevar isn't going to be + # anywhere near the max 64-bit integer anyway. def _default_value_range(self) -> ValueRanges: lower = 2 if self.specialize_zero_one else 0 - return ValueRanges(lower, int_oo) + return ValueRanges(lower, sys.maxsize - 1) def _default_unspecified_value_range(self) -> ValueRanges: - return ValueRanges(-int_oo, int_oo) + return ValueRanges(-sys.maxsize - 1, sys.maxsize) @_lru_cache def _simplify_floor_div(self, expr): diff --git a/torch/fx/passes/runtime_assert.py b/torch/fx/passes/runtime_assert.py index d1d206eff63e..66b8fbe29d9f 100644 --- a/torch/fx/passes/runtime_assert.py +++ b/torch/fx/passes/runtime_assert.py @@ -65,7 +65,7 @@ def insert_deferred_runtime_asserts( ): assert len(node.args) == 1 nodes_that_already_have_sym_constraint_range.add( - (node.args[0], node.kwargs.get("min"), node.kwargs.get("max")) + (node.args[0], node.kwargs["min"], node.kwargs["max"]) ) if ( node.op == "call_function" @@ -86,7 +86,6 @@ def insert_deferred_runtime_asserts( InnerTensorKey, ) from torch.utils._sympy.interp import sympy_interp - from torch.utils._sympy.numbers import int_oo from torch.utils._sympy.reference import PythonReferenceAnalysis # TODO: Request simplification on runtime asserts before emitting them @@ -368,8 +367,6 @@ def go(node, keypath): # (refinement should not be necessary once runtime # asserts cause refinement, but that's NYI) def convert(s): - if s in (int_oo, -int_oo): - return None try: return int(s) except TypeError: diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index 3c845f58117b..25aa07cd5a5c 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -7,8 +7,6 @@ import sympy from sympy import S -from .numbers import int_oo - __all__ = [ "FloorDiv", "ModularIndexing", @@ -105,15 +103,6 @@ def eval(cls, base, divisor): # makes it difficult to check the types. if divisor.is_zero: raise ZeroDivisionError("division by zero") - if base in (int_oo, -int_oo, sympy.oo, -sympy.oo) and divisor in ( - int_oo, - -int_oo, - sympy.oo, - -sympy.oo, - ): - return sympy.nan - if base is sympy.nan or divisor is sympy.nan: - return sympy.nan if base.is_zero: return sympy.S.Zero @@ -121,23 +110,6 @@ def eval(cls, base, divisor): return base if base.is_integer and divisor == -1: return sympy.Mul(base, -1) - if ( - isinstance(base, sympy.Number) - and isinstance(divisor, sympy.Number) - and ( - base in (int_oo, -int_oo, sympy.oo, -sympy.oo) - or divisor in (int_oo, -int_oo, sympy.oo, -sympy.oo) - ) - ): - r = float(base) / float(divisor) - if r == math.inf: - return int_oo - elif r == -math.inf: - return -int_oo - elif math.isnan(r): - return sympy.nan - else: - return sympy.Integer(math.floor(r)) if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer): return sympy.Integer(int(base) // int(divisor)) if isinstance(base, FloorDiv): @@ -383,10 +355,10 @@ class CeilToInt(sympy.Function): @classmethod def eval(cls, number): # assert number.is_integer is not True, number - if number in (sympy.oo, int_oo): - return int_oo - if number in (-sympy.oo, -int_oo): - return -int_oo + if number == sympy.oo: + return sympy.Integer(sys.maxsize - 1) + if number == -sympy.oo: + return sympy.Integer(-sys.maxsize - 1) if isinstance(number, sympy.Number): return sympy.Integer(math.ceil(float(number))) @@ -397,10 +369,10 @@ class FloorToInt(sympy.Function): @classmethod def eval(cls, number): # assert number.is_integer is not True, number - if number in (sympy.oo, int_oo): - return int_oo - if number in (-sympy.oo, int_oo): - return -int_oo + if number == sympy.oo: + return sympy.Integer(sys.maxsize - 1) + if number == -sympy.oo: + return sympy.Integer(-sys.maxsize - 1) if isinstance(number, sympy.Number): return sympy.Integer(math.floor(float(number))) @@ -449,7 +421,6 @@ def safe_pow(base, exp): return sign * _safe_pow(base, exp) -# Prevent people from overflowing pow def _safe_pow(base, exponent): if exponent < 0: raise ValueError("Exponent must be non-negative.") @@ -458,20 +429,17 @@ def _safe_pow(base, exponent): return 1 half_exp = safe_pow(base, exponent // 2) - if half_exp is int_oo: - return int_oo - - # TODO: microoptimization is to avoid overflowing into arbitrary precision - # and detect overflow prior to doing operations + if half_exp > sys.maxsize - 1: + return sys.maxsize - 1 result = half_exp * half_exp - if result > sys.maxsize: - return int_oo + if result > sys.maxsize - 1: + return sys.maxsize - 1 if exponent % 2 == 1: result *= base - if result > sys.maxsize: - return int_oo + if result > sys.maxsize - 1: + return sys.maxsize - 1 return result @@ -481,20 +449,14 @@ class PowByNatural(sympy.Function): @classmethod def eval(cls, base, exp): - if isinstance(base, sympy.Integer) and isinstance(exp, sympy.Integer): - r = safe_pow(base, exp) - if r in (-int_oo, int_oo): - return r - return sympy.Integer(r) + if isinstance(base, sympy.Number) and isinstance(exp, sympy.Number): + return sympy.Integer(safe_pow(base, exp)) if isinstance(exp, sympy.Integer): - # Rely on regular sympy Pow for this (note that iterated - # multiplication turns into a Pow anyway, you can't escape!!) - return sympy.Pow(base, exp) - if exp in (int_oo, sympy.oo): - if base.is_nonnegative: - return int_oo - elif base.is_negative: - return sympy.zoo # this is apparently what (-2)**sympy.oo does + # Translate power into iterated multiplication + r = sympy.Integer(1) + for _ in range(int(exp)): + r *= base + return r # NB: do NOT translate into sympy.Pow, we will lose knowledge that exp # is a natural number if we do @@ -507,11 +469,6 @@ class FloatPow(sympy.Function): @classmethod def eval(cls, base, exp): - # NB: These test sympy.Number, not sympy.Float, because: - # - Sometimes we may have sympy.oo or int_oo, and that's not a Float - # (but coerces to math.Inf) - # - Sometimes Float(0.0) will unpredictably decay to Integer(0), - # but we should still accept it in floatey contexts if isinstance(base, sympy.Number) and isinstance(exp, sympy.Number): return sympy.Float(float(base) ** float(exp)) # NB: do not do any nontrivial reasoning @@ -555,18 +512,7 @@ def eval(cls, base, divisor): if divisor.is_zero: raise ZeroDivisionError("division by zero") - if ( - isinstance(base, sympy.Number) - and isinstance(divisor, sympy.Number) - and ( - base in (int_oo, -int_oo, sympy.oo, -sympy.oo) - or divisor in (int_oo, -int_oo, sympy.oo, -sympy.oo) - ) - ): - # Don't have to worry about precision here, you're getting zero or - # inf from the division - return sympy.Float(float(base) / float(divisor)) - if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer): + if isinstance(base, sympy.Number) and isinstance(divisor, sympy.Number): return sympy.Float(int(base) / int(divisor)) @@ -652,10 +598,10 @@ class TruncToInt(sympy.Function): @classmethod def eval(cls, number): # assert number.is_integer is not True, number - if number in (sympy.oo, int_oo): - return int_oo - if number in (-sympy.oo, -int_oo): - return -int_oo + if number == sympy.oo: + return sympy.Integer(sys.maxsize - 1) + if number == -sympy.oo: + return sympy.Integer(-sys.maxsize - 1) if isinstance(number, sympy.Number): return sympy.Integer(math.trunc(float(number))) @@ -668,11 +614,7 @@ class RoundToInt(sympy.Function): def eval(cls, number): # assert number.is_integer is not True, number - if number is sympy.oo: - return int_oo - if number is -sympy.oo: - return -int_oo - if isinstance(number, sympy.Number): + if isinstance(number, sympy.Float): return sympy.Integer(round(float(number), 0)) @@ -699,7 +641,7 @@ class RoundDecimal(sympy.Function): def eval(cls, number, ndigits): # assert number.is_integer is not True, number - if isinstance(number, sympy.Number) and isinstance(ndigits, sympy.Integer): + if isinstance(number, sympy.Float) and isinstance(ndigits, sympy.Integer): return sympy.Float(round(float(number), int(ndigits))) @@ -714,10 +656,6 @@ def eval(cls, number): if isinstance(number, sympy.Integer): return sympy.Float(int(number)) - if number is int_oo: - return sympy.oo - if number is -int_oo: - return -sympy.oo class Identity(sympy.Function): @@ -763,11 +701,7 @@ def eval(cls, a): # weird objects but ask silly questions, get silly answers except OverflowError: return getattr(sympy, name)(a) - elif a in [sympy.oo, -sympy.oo, sympy.zoo, -sympy.zoo, int_oo, -int_oo]: - if a is int_oo: - a = sympy.oo - if a is -int_oo: - a = -sympy.oo + elif a in [sympy.oo, -sympy.oo, sympy.zoo, -sympy.zoo]: return getattr(sympy, name)(a) return None diff --git a/torch/utils/_sympy/interp.py b/torch/utils/_sympy/interp.py index 3bcb369bcebc..1bb60da4f234 100644 --- a/torch/utils/_sympy/interp.py +++ b/torch/utils/_sympy/interp.py @@ -9,7 +9,6 @@ """ import functools -import logging from typing import Any, Dict, Union import sympy @@ -39,9 +38,6 @@ ) -log = logging.getLogger(__name__) - - # TODO: Dedupe this with SYMPY_INTERP @@ -163,18 +159,11 @@ def sympy_interp( else: handler_name = handlers()[expr.func] handler = getattr(analysis, handler_name) - try: - if handler_name in ASSOCIATIVE_OPS: - assert len(args) > 1 - acc = handler(args[0], args[1]) - for i in range(2, len(args)): - acc = handler(acc, args[i]) - log.debug("%s(%s) -> %s", handler_name, args, acc) - return acc - else: - r = handler(*args) - log.debug("%s(%s) -> %s", handler_name, args, r) - return r - except Exception: - log.warning("failed while executing %s(%s)", handler_name, args) - raise + if handler_name in ASSOCIATIVE_OPS: + assert len(args) > 1 + acc = handler(args[0], args[1]) + for i in range(2, len(args)): + acc = handler(acc, args[i]) + return acc + else: + return handler(*args) diff --git a/torch/utils/_sympy/numbers.py b/torch/utils/_sympy/numbers.py deleted file mode 100644 index 89dac14fddf3..000000000000 --- a/torch/utils/_sympy/numbers.py +++ /dev/null @@ -1,394 +0,0 @@ -import mpmath.libmp as mlib # type: ignore[import-untyped] -import sympy -from sympy import Expr -from sympy.core.decorators import _sympifyit -from sympy.core.expr import AtomicExpr -from sympy.core.numbers import Number -from sympy.core.parameters import global_parameters -from sympy.core.singleton import S, Singleton - - -class IntInfinity(Number, metaclass=Singleton): - r"""Positive integer infinite quantity. - - Integer infinity is a value in an extended integers which - is greater than all other integers. We distinguish it from - sympy's existing notion of infinity in that it reports that - it is_integer. - - Infinity is a singleton, and can be accessed by ``S.IntInfinity``, - or can be imported as ``int_oo``. - """ - - # NB: We can't actually mark this as infinite, as integer and infinite are - # inconsistent assumptions in sympy. We also report that we are complex, - # different from sympy.oo - - is_integer = True - is_commutative = True - is_number = True - is_extended_real = True - is_comparable = True - is_extended_positive = True - is_prime = False - - # Ensure we get dispatched to before plain numbers - _op_priority = 100.0 - - __slots__ = () - - def __new__(cls): - return AtomicExpr.__new__(cls) - - def _sympystr(self, printer): - return "int_oo" - - def _eval_subs(self, old, new): - if self == old: - return new - - # We could do these, not sure about it - """ - def _eval_evalf(self, prec=None): - return Float('inf') - - def evalf(self, prec=None, **options): - return self._eval_evalf(prec) - """ - - @_sympifyit("other", NotImplemented) - def __add__(self, other): - if isinstance(other, Number) and global_parameters.evaluate: - if other is S.NegativeInfinity: - return S.NegativeInfinity - if other in (S.NegativeIntInfinity, S.NaN): - return S.NaN - return self - return Number.__add__(self, other) - - __radd__ = __add__ - - @_sympifyit("other", NotImplemented) - def __sub__(self, other): - if isinstance(other, Number) and global_parameters.evaluate: - if other is S.Infinity: - return S.NegativeInfinity - if other in (S.IntInfinity, S.NaN): - return S.NaN - return self - return Number.__sub__(self, other) - - @_sympifyit("other", NotImplemented) - def __rsub__(self, other): - return (-self).__add__(other) - - @_sympifyit("other", NotImplemented) - def __mul__(self, other): - if isinstance(other, Number) and global_parameters.evaluate: - if other.is_zero or other is S.NaN: - return S.NaN - if other.is_extended_positive: - return self - return S.NegativeIntInfinity - return Number.__mul__(self, other) - - __rmul__ = __mul__ - - @_sympifyit("other", NotImplemented) - def __truediv__(self, other): - if isinstance(other, Number) and global_parameters.evaluate: - if other in ( - S.Infinity, - S.IntInfinity, - S.NegativeInfinity, - S.NegativeIntInfinity, - S.NaN, - ): - return S.NaN - if other.is_extended_nonnegative: - return S.Infinity # truediv produces float - return S.NegativeInfinity # truediv produces float - return Number.__truediv__(self, other) - - def __abs__(self): - return S.IntInfinity - - def __neg__(self): - return S.NegativeIntInfinity - - def _eval_power(self, expt): - if expt.is_extended_positive: - return S.IntInfinity - if expt.is_extended_negative: - return S.Zero - if expt is S.NaN: - return S.NaN - if expt is S.ComplexInfinity: - return S.NaN - if expt.is_extended_real is False and expt.is_number: - from sympy.functions.elementary.complexes import re - - expt_real = re(expt) - if expt_real.is_positive: - return S.ComplexInfinity - if expt_real.is_negative: - return S.Zero - if expt_real.is_zero: - return S.NaN - - return self ** expt.evalf() - - def _as_mpf_val(self, prec): - return mlib.finf - - def __hash__(self): - return super().__hash__() - - def __eq__(self, other): - return other is S.IntInfinity - - def __ne__(self, other): - return other is not S.IntInfinity - - def __gt__(self, other): - if other is S.Infinity: - return sympy.false # sympy.oo > int_oo - elif other is S.IntInfinity: - return sympy.false # consistency with sympy.oo - else: - return sympy.true - - def __ge__(self, other): - if other is S.Infinity: - return sympy.false # sympy.oo > int_oo - elif other is S.IntInfinity: - return sympy.true # consistency with sympy.oo - else: - return sympy.true - - def __lt__(self, other): - if other is S.Infinity: - return sympy.true # sympy.oo > int_oo - elif other is S.IntInfinity: - return sympy.false # consistency with sympy.oo - else: - return sympy.false - - def __le__(self, other): - if other is S.Infinity: - return sympy.true # sympy.oo > int_oo - elif other is S.IntInfinity: - return sympy.true # consistency with sympy.oo - else: - return sympy.false - - @_sympifyit("other", NotImplemented) - def __mod__(self, other): - if not isinstance(other, Expr): - return NotImplemented - return S.NaN - - __rmod__ = __mod__ - - def floor(self): - return self - - def ceiling(self): - return self - - -int_oo = S.IntInfinity - - -class NegativeIntInfinity(Number, metaclass=Singleton): - """Negative integer infinite quantity. - - NegativeInfinity is a singleton, and can be accessed - by ``S.NegativeInfinity``. - - See Also - ======== - - IntInfinity - """ - - # Ensure we get dispatched to before plain numbers - _op_priority = 100.0 - - is_integer = True - is_extended_real = True - is_commutative = True - is_comparable = True - is_extended_negative = True - is_number = True - is_prime = False - - __slots__ = () - - def __new__(cls): - return AtomicExpr.__new__(cls) - - def _eval_subs(self, old, new): - if self == old: - return new - - def _sympystr(self, printer): - return "-int_oo" - - """ - def _eval_evalf(self, prec=None): - return Float('-inf') - - def evalf(self, prec=None, **options): - return self._eval_evalf(prec) - """ - - @_sympifyit("other", NotImplemented) - def __add__(self, other): - if isinstance(other, Number) and global_parameters.evaluate: - if other is S.Infinity: - return S.Infinity - if other in (S.IntInfinity, S.NaN): - return S.NaN - return self - return Number.__add__(self, other) - - __radd__ = __add__ - - @_sympifyit("other", NotImplemented) - def __sub__(self, other): - if isinstance(other, Number) and global_parameters.evaluate: - if other is S.NegativeInfinity: - return S.Infinity - if other in (S.NegativeIntInfinity, S.NaN): - return S.NaN - return self - return Number.__sub__(self, other) - - @_sympifyit("other", NotImplemented) - def __rsub__(self, other): - return (-self).__add__(other) - - @_sympifyit("other", NotImplemented) - def __mul__(self, other): - if isinstance(other, Number) and global_parameters.evaluate: - if other.is_zero or other is S.NaN: - return S.NaN - if other.is_extended_positive: - return self - return S.IntInfinity - return Number.__mul__(self, other) - - __rmul__ = __mul__ - - @_sympifyit("other", NotImplemented) - def __truediv__(self, other): - if isinstance(other, Number) and global_parameters.evaluate: - if other in ( - S.Infinity, - S.IntInfinity, - S.NegativeInfinity, - S.NegativeIntInfinity, - S.NaN, - ): - return S.NaN - if other.is_extended_nonnegative: - return self - return S.Infinity # truediv returns float - return Number.__truediv__(self, other) - - def __abs__(self): - return S.IntInfinity - - def __neg__(self): - return S.IntInfinity - - def _eval_power(self, expt): - if expt.is_number: - if expt in ( - S.NaN, - S.Infinity, - S.NegativeInfinity, - S.IntInfinity, - S.NegativeIntInfinity, - ): - return S.NaN - - if isinstance(expt, sympy.Integer) and expt.is_extended_positive: - if expt.is_odd: - return S.NegativeIntInfinity - else: - return S.IntInfinity - - inf_part = S.IntInfinity**expt - s_part = S.NegativeOne**expt - if inf_part == 0 and s_part.is_finite: - return inf_part - if ( - inf_part is S.ComplexInfinity - and s_part.is_finite - and not s_part.is_zero - ): - return S.ComplexInfinity - return s_part * inf_part - - def _as_mpf_val(self, prec): - return mlib.fninf - - def __hash__(self): - return super().__hash__() - - def __eq__(self, other): - return other is S.NegativeIntInfinity - - def __ne__(self, other): - return other is not S.NegativeIntInfinity - - def __gt__(self, other): - if other is S.NegativeInfinity: - return sympy.true # -sympy.oo < -int_oo - elif other is S.NegativeIntInfinity: - return sympy.false # consistency with sympy.oo - else: - return sympy.false - - def __ge__(self, other): - if other is S.NegativeInfinity: - return sympy.true # -sympy.oo < -int_oo - elif other is S.NegativeIntInfinity: - return sympy.true # consistency with sympy.oo - else: - return sympy.false - - def __lt__(self, other): - if other is S.NegativeInfinity: - return sympy.false # -sympy.oo < -int_oo - elif other is S.NegativeIntInfinity: - return sympy.false # consistency with sympy.oo - else: - return sympy.true - - def __le__(self, other): - if other is S.NegativeInfinity: - return sympy.false # -sympy.oo < -int_oo - elif other is S.NegativeIntInfinity: - return sympy.true # consistency with sympy.oo - else: - return sympy.true - - @_sympifyit("other", NotImplemented) - def __mod__(self, other): - if not isinstance(other, Expr): - return NotImplemented - return S.NaN - - __rmod__ = __mod__ - - def floor(self): - return self - - def ceiling(self): - return self - - def as_powers_dict(self): - return {S.NegativeOne: 1, S.IntInfinity: 1} diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py index c1ed0b02946d..e1ef17f3d340 100644 --- a/torch/utils/_sympy/value_ranges.py +++ b/torch/utils/_sympy/value_ranges.py @@ -6,6 +6,7 @@ import logging import math import operator +import sys from typing import ( Callable, Dict, @@ -23,7 +24,6 @@ from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom import torch -from torch._logging import LazyString from torch._prims_common import dtype_to_type from .functions import ( @@ -43,7 +43,6 @@ TruncToInt, ) from .interp import sympy_interp -from .numbers import int_oo, IntInfinity, NegativeIntInfinity log = logging.getLogger(__name__) @@ -172,10 +171,7 @@ def __init__(self, lower: AllIn, upper: AllIn) -> None: self, "is_int", not self.is_bool - and ( - isinstance(lower, (sympy.Integer, NegativeIntInfinity)) - or isinstance(upper, (sympy.Integer, IntInfinity)) - ), + and (isinstance(lower, sympy.Integer) or isinstance(upper, sympy.Integer)), ) """ # This assert is just impossible right now, too many sympy bugs @@ -272,14 +268,11 @@ def __or__(self: AllVR, other: AllVR) -> AllVR: def is_singleton(self) -> bool: return self.lower == self.upper + # TODO: this doesn't work with bools but arguably it should @staticmethod def unknown() -> ValueRanges[sympy.Expr]: return ValueRanges(-sympy.oo, sympy.oo) - @staticmethod - def unknown_int() -> ValueRanges[sympy.Expr]: - return ValueRanges(-int_oo, int_oo) - @staticmethod def unknown_bool() -> ValueRanges[SympyBoolean]: return ValueRanges(sympy.false, sympy.true) @@ -411,7 +404,7 @@ def constant(value, dtype): elif dtype.is_floating_point: return ValueRanges.unknown() else: - return ValueRanges(-int_oo, int_oo) + return ValueRanges(-sys.maxsize - 1, sys.maxsize) if is_python: type_ = dtype_to_type(dtype) @@ -434,10 +427,6 @@ def constant(value, dtype): def to_dtype(a, dtype, src_dtype=None): if dtype == torch.float64: return ValueRanges.increasing_map(a, ToFloat) - elif dtype == torch.bool: - return ValueRanges.unknown_bool() - elif not dtype.is_floating_point: - return ValueRanges.unknown_int() return ValueRanges.unknown() @staticmethod @@ -533,7 +522,9 @@ def safe_mul(a, b): def int_truediv(a, b): a = ValueRanges.wrap(a) b = ValueRanges.wrap(b) - if 0 in b or ((-int_oo in a or int_oo in a) and (-int_oo in b or int_oo in b)): + if 0 in b or ( + (-sympy.oo in a or sympy.oo in a) and (-sympy.oo in b or sympy.oo in b) + ): return ValueRanges.unknown() else: return ValueRanges.coordinatewise_monotone_map( @@ -557,17 +548,14 @@ def truediv(a, b): def floordiv(a, b): a = ValueRanges.wrap(a) b = ValueRanges.wrap(b) - if 0 in b: + if 0 in b or ( + # TODO: make this more precise + (-sympy.oo in a or sympy.oo in a) + or (-sympy.oo in b or sympy.oo in b) + ): return ValueRanges.unknown() - products = [] - for x, y in itertools.product([a.lower, a.upper], [b.lower, b.upper]): - r = FloorDiv(x, y) - if r is sympy.nan: - products.append((sympy.sign(x) * sympy.sign(y)) * int_oo) - else: - products.append(r) - - return ValueRanges(min(products), max(products)) + else: + return ValueRanges.coordinatewise_monotone_map(a, b, FloorDiv) @classmethod def mod(cls, x, y): @@ -583,10 +571,10 @@ def c_mod(a, b): def c_div(a, b): x = a / b - return sympy.Integer(x) if x.is_finite and x not in (int_oo, -int_oo) else x + return sympy.Integer(x) if x.is_finite else x if 0 in y: - return ValueRanges.unknown_int() + return ValueRanges.unknown() elif y.is_singleton(): y_val = abs(y.lower) # If it wraps, we need to take the whole interval @@ -616,7 +604,7 @@ def modular_indexing(cls, a, b, c): @classmethod def is_non_overlapping_and_dense_indicator(cls, *args): - return ValueRanges.unknown_int() + return ValueRanges.unknown() # TODO: type here is wrong @classmethod def pow_by_natural(cls, a, b): @@ -630,7 +618,7 @@ def pow_by_natural(cls, a, b): # to replacements, so don't assert it, but DO clamp it to prevent # degenerate problems return ValueRanges.coordinatewise_increasing_map( - a, b & ValueRanges(0, int_oo), PowByNatural + a, b & ValueRanges(0, sys.maxsize - 1), PowByNatural ) elif b.is_singleton(): if b.lower % 2 == 0: @@ -960,8 +948,6 @@ def cast(x, dtype): if dtype.is_floating_point: return sympy.Float(x) else: - if x in (int_oo, -int_oo): - return x try: return sympy.Integer(x) except TypeError: @@ -1009,18 +995,7 @@ def __getattr__(self, name): def bound_sympy( expr: sympy.Expr, ranges: Optional[Dict[sympy.Symbol, ValueRanges]] = None ) -> ValueRanges: - log.debug( - "bound_sympy(%s)%s", - expr, - LazyString( - lambda: "\n" - + "\n".join( - f" {k}: {r}" for k, r in ranges.items() if k in expr.free_symbols - ) - if ranges - else "" - ), - ) + log.debug("bound_sympy(%s, %s)", expr, ranges) if isinstance(expr, sympy.Number): return ValueRanges.wrap(expr) From 786c24a4cd84a085173e72b59e9d0c356f923249 Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Tue, 11 Jun 2024 09:48:36 -0700 Subject: [PATCH 659/706] [inductor] Always realize sigmoid for CPU (#128339) Summary: Currently the cpu backend prefers to always realize exp because it's a heavy op on CPU. For the same reason, we need to realize sigmoid as well. This solves a problem in llama2 inference where exp was repeated in an inner loop for many times. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128339 Approved by: https://github.com/eellison, https://github.com/helloguo, https://github.com/jansel, https://github.com/jgong5, https://github.com/peterbell10 --- test/inductor/test_cpu_repro.py | 14 ++++++++++++++ torch/_inductor/ir.py | 2 +- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py index b2ab30832e06..79ace98c1b96 100644 --- a/test/inductor/test_cpu_repro.py +++ b/test/inductor/test_cpu_repro.py @@ -3761,6 +3761,20 @@ def fn(arg0_1): exactly=True, ).run(code) + def test_repeated_exp(self): + def fn(x): + y = x.sigmoid() + return y + 1, y.sum(-1) + + x = torch.randn(1000, 1000) + opt_fn = torch.compile(fn) + _, code = run_and_get_cpp_code(opt_fn, x) + FileCheck().check_count( + ".exp()", + 1, + exactly=True, + ).run(code) + if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 7b2cf76e7943..9255ee94fe83 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -7547,7 +7547,7 @@ def should_realize_on_cpu(loops: Union[Pointwise, Reduction]): """ The heuristic for realizing reused result of heavy ops on cpu """ - heavy_ops = ["exp"] # a list of heavy ops + heavy_ops = ["exp", "sigmoid"] # a list of heavy ops fn_str = loops.inner_fn_str() return any((op + "(") in fn_str for op in heavy_ops) From 6af4c6acad5a352f2f974c73aba3a03535b20a0b Mon Sep 17 00:00:00 2001 From: Kurman Karabukaev Date: Wed, 12 Jun 2024 01:03:40 +0000 Subject: [PATCH 660/706] Migrate test to internal base class, fixes (#128367) Summary: ## Remove etc deps converted tests to non-etcd based rdzv handler so that tests don't have dependency on etcd server ## Adopt pytorch test convetions - test starts with `test_TESTS.py` - Test base class is torch.testing._internal.common_utils.TestCase - include __main__ handler ## reduce test timing (used to take > 300 seconds): 3.05s call test/distributed/launcher/run_test.py::ElasticLaunchTest::test_init_method_env_with_torchelastic 2.59s call test/distributed/launcher/run_test.py::ElasticLaunchTest::test_init_method_tcp_with_torchelastic 2.33s call test/distributed/launcher/run_test.py::ElasticLaunchTest::test_launch_elastic_worker_raise_exception 2.33s call test/distributed/launcher/run_test.py::ElasticLaunchTest::test_launch_run_path 2.30s call test/distributed/launcher/run_test.py::ElasticLaunchTest::test_nproc_launch_auto_configurations 2.24s call test/distributed/launcher/run_test.py::ElasticLaunchTest::test_is_torchelastic_launched_with_logs_spec_defined 2.24s call test/distributed/launcher/run_test.py::ElasticLaunchTest::test_is_torchelastic_launched 2.17s call test/distributed/launcher/run_test.py::ElasticLaunchTest::test_launch_elastic_multiple_agents 2.12s call test/distributed/launcher/run_test.py::ElasticLaunchTest::test_launch_elastic 2.08s call test/distributed/launcher/run_test.py::ElasticLaunchTest::test_nproc_gpu_launch_configurations 1.32s call test/distributed/launcher/run_test.py::ElasticLaunchTest::test_launch_standalone 1.05s call test/distributed/launcher/run_test.py::ElasticLaunchTest::test_nproc_launch_number_configurations 1.05s call test/distributed/launcher/run_test.py::ElasticLaunchTest::test_launch_with_env_vars 1.05s call test/distributed/launcher/run_test.py::ElasticLaunchTest::test_launch_user_script_python 1.05s call test/distributed/launcher/run_test.py::ElasticLaunchTest::test_launch_user_script_python_caffe2_bc 1.04s call test/distributed/launcher/run_test.py::ElasticLaunchTest::test_launch_user_script_bash 1.03s call test/distributed/launcher/run_test.py::ElasticLaunchTest::test_launch_user_script_default_nproc 0.04s call test/distributed/launcher/run_test.py::ElasticLaunchTest::test_logs_logs_spec_entrypoint_must_be_defined 0.01s call test/distributed/launcher/run_test.py::ElasticLaunchTest::test_launch_elastic_agent_raise_exception 0.01s call test/distributed/launcher/run_test.py::ElasticLaunchTest::test_launch_shutdown Test Plan: pytest --durations=0 test/distributed/launcher/run_test.py Differential Revision: D58388182 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128367 Approved by: https://github.com/d4l3k --- .../launcher/{run_test.py => test_run.py} | 91 +++++-------------- 1 file changed, 21 insertions(+), 70 deletions(-) rename test/distributed/launcher/{run_test.py => test_run.py} (89%) diff --git a/test/distributed/launcher/run_test.py b/test/distributed/launcher/test_run.py similarity index 89% rename from test/distributed/launcher/run_test.py rename to test/distributed/launcher/test_run.py index c816042e3e46..ba58aec43871 100644 --- a/test/distributed/launcher/run_test.py +++ b/test/distributed/launcher/test_run.py @@ -13,7 +13,6 @@ import subprocess import sys import tempfile -import unittest import uuid from contextlib import closing from unittest import mock @@ -23,12 +22,13 @@ from torch.distributed.elastic.agent.server.api import RunResult, WorkerState from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs from torch.distributed.elastic.multiprocessing.errors import ChildFailedError -from torch.distributed.elastic.rendezvous.etcd_server import EtcdServer from torch.distributed.elastic.utils import get_socket_with_port from torch.distributed.elastic.utils.distributed import get_free_port from torch.testing._internal.common_utils import ( + run_tests, skip_but_pass_in_sandcastle_if, TEST_WITH_DEV_DBG_ASAN, + TestCase, ) @@ -63,19 +63,7 @@ class MockException(Exception): pass -class ElasticLaunchTest(unittest.TestCase): - @classmethod - def setUpClass(cls): - # start a standalone, single process etcd server to use for all tests - cls._etcd_server = EtcdServer() - cls._etcd_server.start() - cls._etcd_endpoint = cls._etcd_server.get_endpoint() - - @classmethod - def tearDownClass(cls): - # stop the standalone etcd server - cls._etcd_server.stop() - +class ElasticLaunchTest(TestCase): def setUp(self): self.test_dir = tempfile.mkdtemp() @@ -103,8 +91,6 @@ def _test_launch_user_script_python(self): args = [ f"--nnodes={nnodes}", f"--nproc-per-node={nproc_per_node}", - "--rdzv-backend=etcd", - f"--rdzv-endpoint={self._etcd_endpoint}", f"--rdzv-id={run_id}", "--monitor-interval=1", "--start-method=spawn", @@ -156,8 +142,6 @@ def test_launch_user_script_bash(self): args = [ f"--nnodes={nnodes}", f"--nproc-per-node={nproc_per_node}", - "--rdzv-backend=etcd", - f"--rdzv-endpoint={self._etcd_endpoint}", f"--rdzv-id={run_id}", "--monitor-interval=1", "--start-method=spawn", @@ -187,8 +171,6 @@ def test_launch_user_script_default_nproc(self): world_size = 1 args = [ f"--nnodes={nnodes}", - "--rdzv-backend=etcd", - f"--rdzv-endpoint={self._etcd_endpoint}", f"--rdzv-id={run_id}", "--monitor-interval=1", "--start-method=spawn", @@ -220,8 +202,6 @@ def test_launch_with_env_vars(self): os.environ["PET_NNODES"] = str(nnodes) os.environ["PET_NPROC_PER_NODE"] = str(nproc_per_node) - os.environ["PET_RDZV_BACKEND"] = "etcd" - os.environ["PET_RDZV_ENDPOINT"] = self._etcd_endpoint os.environ["PET_RDZV_ID"] = run_id os.environ["PET_MONITOR_INTERVAL"] = "1" os.environ["PET_START_METHOD"] = "spawn" @@ -250,8 +230,6 @@ def _test_nproc_launch_configuration(self, nproc_type, expected_number): args = [ f"--nnodes={nnodes}", f"--nproc-per-node={nproc_type}", - "--rdzv-backend=etcd", - f"--rdzv-endpoint={self._etcd_endpoint}", f"--rdzv-id={run_id}", "--monitor-interval=1", "--start-method=spawn", @@ -272,7 +250,8 @@ def _test_nproc_launch_configuration(self, nproc_type, expected_number): @skip_but_pass_in_sandcastle_if( TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) - def test_nproc_launch_auto_configurations(self): + @patch("torch.cuda.is_available", return_value=False) + def test_nproc_launch_auto_configurations(self, _mock1): self._test_nproc_launch_configuration("auto", os.cpu_count()) @skip_but_pass_in_sandcastle_if( @@ -310,8 +289,9 @@ def test_launch_elastic(self): args = [ f"--nnodes={min_nodes}:{max_nodes}", f"--nproc-per-node={nproc_per_node}", - "--rdzv-backend=etcd", - f"--rdzv-endpoint={self._etcd_endpoint}", + "--rdzv-backend=c10d", + f"--rdzv-endpoint=localhost:{get_free_port()}", + "--rdzv-conf='join_timeout=5,last_call_timeout=1,timeout=5'", f"--rdzv-id={run_id}", "--monitor-interval=1", "--start-method=spawn", @@ -343,8 +323,9 @@ def test_launch_elastic_worker_raise_exception(self, record_mock): args = [ f"--nnodes={min_nodes}:{max_nodes}", f"--nproc-per-node={nproc_per_node}", - "--rdzv-backend=etcd", - f"--rdzv-endpoint={self._etcd_endpoint}", + "--rdzv-backend=c10d", + f"--rdzv-endpoint=localhost:{get_free_port()}", + "--rdzv-conf='join_timeout=5,last_call_timeout=1,timeout=5'", f"--rdzv-id={run_id}", "--monitor-interval=1", "--max-restarts=0", @@ -376,8 +357,9 @@ def test_launch_elastic_agent_raise_exception(self, record_mock, mock_agent_run) args = [ f"--nnodes={min_nodes}:{max_nodes}", f"--nproc-per-node={nproc_per_node}", - "--rdzv-backend=etcd", - f"--rdzv-endpoint={self._etcd_endpoint}", + "--rdzv-backend=c10d", + f"--rdzv-endpoint=localhost:{get_free_port()}", + "--rdzv_conf=timeout=5", f"--rdzv-id={run_id}", "--monitor-interval=1", "--max-restarts=0", @@ -452,8 +434,9 @@ def test_launch_elastic_multiple_agents(self): args = [ f"--nnodes={min_nodes}:{max_nodes}", f"--nproc-per-node={nproc_per_node}", - "--rdzv-backend=etcd", - f"--rdzv-endpoint={self._etcd_endpoint}", + "--rdzv-backend=c10d", + f"--rdzv-endpoint=localhost:{get_free_port()}", + "--rdzv_conf=timeout=5", f"--rdzv-id={run_id}", "--monitor-interval=1", "--start-method=spawn", @@ -608,21 +591,6 @@ def test_is_not_torchelastic_launched(self): is_torchelastic_launched = fp.readline() self.assertEqual("False", is_torchelastic_launched) - def test_init_method_tcp(self): - port = get_free_port() - with patch.object( - sys, - "argv", - [ - path("bin/test_script_init_method.py"), - f"--init-method=tcp://localhost:{port}", - "--rank=0", - "--world-size=1", - ], - ): - runpy.run_path(sys.argv[0], run_name="__main__") - # nothing to validate, just make sure it runs - @skip_but_pass_in_sandcastle_if( TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) @@ -642,27 +610,6 @@ def test_init_method_tcp_with_torchelastic(self): ) # nothing to validate, just make sure it runs - def test_init_method_env(self): - port = get_free_port() - with patch.dict( - os.environ, - { - "RANK": "0", - "WORLD_SIZE": "1", - "MASTER_ADDR": "localhost", - "MASTER_PORT": str(port), - }, - ), patch.object( - sys, - "argv", - [ - path("bin/test_script_init_method.py"), - "--init-method=env://", - ], - ): - runpy.run_path(sys.argv[0], run_name="__main__") - # nothing to validate, just make sure it runs - @skip_but_pass_in_sandcastle_if( TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) @@ -681,3 +628,7 @@ def test_init_method_env_with_torchelastic(self): ] ) # nothing to validate, just make sure it runs + + +if __name__ == "__main__": + run_tests() From fb013ecb241c4feb858aa60350d2c03083051dbe Mon Sep 17 00:00:00 2001 From: cyy Date: Wed, 12 Jun 2024 01:07:14 +0000 Subject: [PATCH 661/706] Remove unused private List::ptr_to_first_element (#128405) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128405 Approved by: https://github.com/ezyang --- aten/src/ATen/core/List.h | 2 -- aten/src/ATen/core/List_inl.h | 7 ------- 2 files changed, 9 deletions(-) diff --git a/aten/src/ATen/core/List.h b/aten/src/ATen/core/List.h index 53560b9666ae..7f65551fbe70 100644 --- a/aten/src/ATen/core/List.h +++ b/aten/src/ATen/core/List.h @@ -478,8 +478,6 @@ namespace impl { // (maybe except for some internal prim ops). using GenericList = List; -const IValue* ptr_to_first_element(const GenericList& list); - } } diff --git a/aten/src/ATen/core/List_inl.h b/aten/src/ATen/core/List_inl.h index 64760b5f782b..0d223122599c 100644 --- a/aten/src/ATen/core/List_inl.h +++ b/aten/src/ATen/core/List_inl.h @@ -350,11 +350,4 @@ void List::unsafeSetElementType(TypePtr t) { impl_->elementType = std::move(t); } -namespace impl { - -inline const IValue* ptr_to_first_element(const GenericList& list) { - return &list.impl_->list[0]; -} - -} } From 219da29dfd8fd39b783b0a25aef693e25bbe6c8a Mon Sep 17 00:00:00 2001 From: cyy Date: Wed, 12 Jun 2024 01:10:33 +0000 Subject: [PATCH 662/706] [7/N] Remove unused functions (#128407) Follows #128309 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128407 Approved by: https://github.com/ezyang --- aten/src/ATen/native/SpectralOps.cpp | 9 --------- torch/csrc/jit/mobile/nnc/aot_compiler.cpp | 4 +++- torch/csrc/jit/passes/batch_mm.cpp | 12 ------------ 3 files changed, 3 insertions(+), 22 deletions(-) diff --git a/aten/src/ATen/native/SpectralOps.cpp b/aten/src/ATen/native/SpectralOps.cpp index 5f9ff1b83822..db8acf193199 100644 --- a/aten/src/ATen/native/SpectralOps.cpp +++ b/aten/src/ATen/native/SpectralOps.cpp @@ -1195,15 +1195,6 @@ Tensor istft(const Tensor& self, const int64_t n_fft, const optional ho #undef REPR } -static Tensor istft(const Tensor& self, const int64_t n_fft, const optional hop_lengthOpt, - const optional win_lengthOpt, const Tensor& window, - const bool center, const bool normalized, const optional onesidedOpt, - const optional lengthOpt) { - return at::native::istft( - self, n_fft, hop_lengthOpt, win_lengthOpt, window, center, normalized, - onesidedOpt, lengthOpt, /*return_complex=*/false); -} - void _fft_fill_with_conjugate_symmetry_(const Tensor& input, IntArrayRef dim_) { const auto input_sizes = input.sizes(); const auto input_strides = input.strides(); diff --git a/torch/csrc/jit/mobile/nnc/aot_compiler.cpp b/torch/csrc/jit/mobile/nnc/aot_compiler.cpp index 1f7ba264048f..98638ff62e26 100644 --- a/torch/csrc/jit/mobile/nnc/aot_compiler.cpp +++ b/torch/csrc/jit/mobile/nnc/aot_compiler.cpp @@ -383,6 +383,8 @@ static std::vector> generateExampleInputs( return example_inputs; } +// TODO(mvz): temporarily disable NNC backend in mobile builds. +/* static c10::IValue preprocess( const torch::jit::Module& mod, const c10::Dict& compile_spec, @@ -440,8 +442,8 @@ static c10::IValue preprocess( } return cu.serialize(); } +*/ -// TODO(mvz): temporarily disable NNC backend in mobile builds. // static auto reg = torch::jit::backend_preprocess_register("nnc", preprocess); } // namespace nnc diff --git a/torch/csrc/jit/passes/batch_mm.cpp b/torch/csrc/jit/passes/batch_mm.cpp index 052ba45ceb40..7fac68aec4d7 100644 --- a/torch/csrc/jit/passes/batch_mm.cpp +++ b/torch/csrc/jit/passes/batch_mm.cpp @@ -464,18 +464,6 @@ static void BatchMMSide(Block* block, AliasDb& alias_db) { } } -static bool hasMutableOperators(Block* block) { - for (auto n : block->nodes()) { - if (n->kind().is_aten() && n->schema().is_mutable()) - return true; - for (auto b : n->blocks()) { - if (hasMutableOperators(b)) - return true; - } - } - return false; -} - static bool hasMMOperators(std::shared_ptr& graph) { DepthFirstGraphNodeIterator it(graph); Node* n = nullptr; From 9538bf4e7c568885a11570e1bd781bfddcbc7405 Mon Sep 17 00:00:00 2001 From: cyy Date: Wed, 12 Jun 2024 01:18:20 +0000 Subject: [PATCH 663/706] [2/N] Remove inclusion of c10/util/string_utils.h (#128372) Follows #128300. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128372 Approved by: https://github.com/aaronenyeshi --- aten/src/ATen/native/TensorShape.cpp | 1 - .../ATen/native/mps/operations/Activation.mm | 32 ++++++------ aten/src/ATen/native/mps/operations/Blas.mm | 4 +- .../ATen/native/mps/operations/ConstantOps.mm | 2 +- .../ATen/native/mps/operations/Convolution.mm | 50 +++++++++---------- .../native/mps/operations/Distributions.mm | 4 +- aten/src/ATen/native/mps/operations/Linear.mm | 2 +- .../native/mps/operations/LinearAlgebra.mm | 12 ++--- .../src/ATen/native/mps/operations/LossOps.mm | 16 +++--- .../native/mps/operations/RangeFactories.mm | 8 +-- .../ATen/native/mps/operations/ReduceOps.mm | 18 +++---- aten/src/ATen/native/mps/operations/Shape.mm | 10 ++-- aten/src/ATen/native/mps/operations/Sort.mm | 4 +- .../native/mps/operations/TensorCompare.mm | 4 +- aten/src/ATen/native/mps/operations/Unique.mm | 4 +- .../ATen/native/mps/operations/UpSample.mm | 2 +- aten/src/ATen/native/mps/operations/View.mm | 2 +- torch/csrc/autograd/profiler_python.cpp | 2 - 18 files changed, 87 insertions(+), 90 deletions(-) diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index bf1b6e5fa262..adcddead041b 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -210,7 +210,6 @@ #include #endif -#include #include #include #include diff --git a/aten/src/ATen/native/mps/operations/Activation.mm b/aten/src/ATen/native/mps/operations/Activation.mm index da11401c948d..741789c7eac9 100644 --- a/aten/src/ATen/native/mps/operations/Activation.mm +++ b/aten/src/ATen/native/mps/operations/Activation.mm @@ -143,7 +143,7 @@ Tensor relu_mps(const Tensor& self) { Tensor output_ = at::empty_like(self, executeGatherOp ? MemoryFormat::Contiguous : MemoryFormat::Preserve); @autoreleasepool { - string key = "leaky_relu" + getTensorsStringKey({self}) + ":" + to_string(negative_slope.to()); + string key = "leaky_relu" + getTensorsStringKey({self}) + ":" + std::to_string(negative_slope.to()); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); @@ -193,8 +193,8 @@ Tensor relu_mps(const Tensor& self) { Tensor output_ = at::empty_like(self, self.suggest_memory_format()); @autoreleasepool { - string key = - "leaky_relu_backward" + getTensorsStringKey({self, grad_output}) + ":" + to_string(negative_slope.to()); + string key = "leaky_relu_backward" + getTensorsStringKey({self, grad_output}) + ":" + + std::to_string(negative_slope.to()); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); @@ -242,7 +242,7 @@ Tensor relu_mps(const Tensor& self) { MPSStream* stream = at::mps::getCurrentMPSStream(); @autoreleasepool { - string key = "log_softmax_mps_out" + getTensorsStringKey({self}) + ":" + to_string(dim); + string key = "log_softmax_mps_out" + getTensorsStringKey({self}) + ":" + std::to_string(dim); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); @@ -285,7 +285,7 @@ Tensor relu_mps(const Tensor& self) { MPSStream* stream = at::mps::getCurrentMPSStream(); @autoreleasepool { - string key = "log_softmax_backward_mps_out:" + getMPSTypeString(grad_output) + ":" + to_string(dim); + string key = "log_softmax_backward_mps_out:" + getMPSTypeString(grad_output) + ":" + std::to_string(dim); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* gradOutputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(grad_output)); MPSGraphTensor* outputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(output)); @@ -539,8 +539,8 @@ Tensor log_sigmoid_backward_mps(const Tensor& grad_output, const Tensor& self, c MPSStream* stream = getCurrentMPSStream(); @autoreleasepool { - string key = "threshold_out_mps" + getTensorsStringKey({self}) + ":" + to_string(threshold.to()) + ":" + - to_string(value.to()); + string key = "threshold_out_mps" + getTensorsStringKey({self}) + ":" + std::to_string(threshold.to()) + + ":" + std::to_string(value.to()); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); @@ -587,7 +587,7 @@ Tensor log_sigmoid_backward_mps(const Tensor& grad_output, const Tensor& self, c @autoreleasepool { string key = - "threshold_backward_out_mps" + getTensorsStringKey({self, grad}) + ":" + to_string(threshold.to()); + "threshold_backward_out_mps" + getTensorsStringKey({self, grad}) + ":" + std::to_string(threshold.to()); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); @@ -826,8 +826,8 @@ static void elu_variants_out_mps(const Tensor& self, MPSStream* stream = getCurrentMPSStream(); @autoreleasepool { - string key = func_name + ":" + getTensorsStringKey({self}) + ":" + to_string(alpha.to()) + ":" + - to_string(scale.to()) + ":" + to_string(input_scale.to()); + string key = func_name + ":" + getTensorsStringKey({self}) + ":" + std::to_string(alpha.to()) + ":" + + std::to_string(scale.to()) + ":" + std::to_string(input_scale.to()); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); @@ -916,8 +916,8 @@ static void elu_variants_out_mps(const Tensor& self, @autoreleasepool { string key = "elu_backward_out_mps:" + getTensorsStringKey({grad_output, self_or_result}) + ":" + - to_string(alpha.to()) + ":" + to_string(scale.to()) + ":" + - to_string(input_scale.to()) + ":" + to_string(is_result); + std::to_string(alpha.to()) + ":" + std::to_string(scale.to()) + ":" + + std::to_string(input_scale.to()) + ":" + std::to_string(is_result); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); @@ -1010,7 +1010,7 @@ static void elu_variants_out_mps(const Tensor& self, MPSStream* stream = getCurrentMPSStream(); @autoreleasepool { - string key = "glu_out_mps" + getTensorsStringKey({self}) + ":" + to_string(dim); + string key = "glu_out_mps" + getTensorsStringKey({self}) + ":" + std::to_string(dim); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), getMPSShape(self)); NSArray* outputTensorsArray = [mpsGraph splitTensor:inputTensor @@ -1052,7 +1052,7 @@ static void elu_variants_out_mps(const Tensor& self, MPSStream* stream = getCurrentMPSStream(); @autoreleasepool { - string key = "glu_backward_mps_out" + getTensorsStringKey({grad_output, self}) + ":" + to_string(dim); + string key = "glu_backward_mps_out" + getTensorsStringKey({grad_output, self}) + ":" + std::to_string(dim); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), getMPSShape(self)); MPSGraphTensor* gradOutputTensor = @@ -1855,8 +1855,8 @@ Tensor hardtanh_backward_mps(const Tensor& grad_output, const Tensor& self, cons MPSStream* stream = getCurrentMPSStream(); @autoreleasepool { - string key = "hardtanh_backward_out_mps:" + getTensorsStringKey({grad_output}) + ":" + to_string(min.to()) + - ":" + to_string(max.to()); + string key = "hardtanh_backward_out_mps:" + getTensorsStringKey({grad_output}) + ":" + + std::to_string(min.to()) + ":" + std::to_string(max.to()); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); diff --git a/aten/src/ATen/native/mps/operations/Blas.mm b/aten/src/ATen/native/mps/operations/Blas.mm index 1714a8e7e2f8..25cc732c1e62 100644 --- a/aten/src/ATen/native/mps/operations/Blas.mm +++ b/aten/src/ATen/native/mps/operations/Blas.mm @@ -136,8 +136,8 @@ Tensor dot_mps(const Tensor& self, const Tensor& other) { Tensor matMulVec = at::mm(mat, vec.unsqueeze(1)).squeeze(1); @autoreleasepool { - string key = "addmv_out_mps_impl" + getTensorsStringKey({self, matMulVec}) + ":" + to_string(beta_.toDouble()) + - ":" + to_string(alpha_.toDouble()); + string key = "addmv_out_mps_impl" + getTensorsStringKey({self, matMulVec}) + ":" + + std::to_string(beta_.toDouble()) + ":" + std::to_string(alpha_.toDouble()); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* matMulVecTensor = mpsGraphRankedPlaceHolder(mpsGraph, matMulVec); MPSGraphTensor* selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); diff --git a/aten/src/ATen/native/mps/operations/ConstantOps.mm b/aten/src/ATen/native/mps/operations/ConstantOps.mm index 2e7d0881bb60..353978547186 100644 --- a/aten/src/ATen/native/mps/operations/ConstantOps.mm +++ b/aten/src/ATen/native/mps/operations/ConstantOps.mm @@ -33,7 +33,7 @@ }; @autoreleasepool { - string key = "fill_scalar_mps_impl" + getTensorsStringKey(self) + ":" + to_string(value.toDouble()); + string key = "fill_scalar_mps_impl" + getTensorsStringKey(self) + ":" + std::to_string(value.toDouble()); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* inputTensor = mpsGraphScalarPlaceHolder(mpsGraph, getMPSDataType(self.scalar_type())); diff --git a/aten/src/ATen/native/mps/operations/Convolution.mm b/aten/src/ATen/native/mps/operations/Convolution.mm index fbf5a67262be..08ad620a2028 100644 --- a/aten/src/ATen/native/mps/operations/Convolution.mm +++ b/aten/src/ATen/native/mps/operations/Convolution.mm @@ -193,24 +193,24 @@ static Tensor _mps_convolution_impl(const Tensor& input_t, string bias_shape_key; if (bias_defined) { - bias_shape_key = to_string(bias_shape[0]); + bias_shape_key = std::to_string(bias_shape[0]); } else { bias_shape_key = "nobias"; } string key; if (is3DConv) { - key = "mps_3d_convolution:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" + to_string(stride[2]) + - ":" + to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":" + to_string(dilation[2]) + ":" + - to_string(padding[0]) + ":" + to_string(padding[1]) + ":" + to_string(padding[2]) + ":" + to_string(groups) + - ":" + mem_format_key + mps::getTensorsStringKey({input_t, weight_t}) + ":" + to_string(bias_defined) + ":" + - bias_shape_key; + key = "mps_3d_convolution:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" + + std::to_string(stride[2]) + ":" + std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" + + std::to_string(dilation[2]) + ":" + std::to_string(padding[0]) + ":" + std::to_string(padding[1]) + ":" + + std::to_string(padding[2]) + ":" + std::to_string(groups) + ":" + mem_format_key + + mps::getTensorsStringKey({input_t, weight_t}) + ":" + std::to_string(bias_defined) + ":" + bias_shape_key; } else { - key = "mps_convolution:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" + to_string(dilation[0]) + - ":" + to_string(dilation[1]) + ":" + to_string(padding[0]) + ":" + to_string(padding[1]) + ":" + - to_string(groups) + ":" + mem_format_key + mps::getTensorsStringKey({input_t, weight_t}) + ":" + - to_string(bias_defined) + ":" + bias_shape_key; + key = "mps_convolution:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" + + std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" + std::to_string(padding[0]) + ":" + + std::to_string(padding[1]) + ":" + std::to_string(groups) + ":" + mem_format_key + + mps::getTensorsStringKey({input_t, weight_t}) + ":" + std::to_string(bias_defined) + ":" + bias_shape_key; } MPSShape* inputShape = mps::getMPSShape(input_t, memory_format); @@ -388,16 +388,16 @@ static Tensor mps_convolution_backward_input(IntArrayRef input_size, NSString* ns_shape_key = [[gradOutputShape valueForKey:@"description"] componentsJoinedByString:@","]; string key; if (is3DConv) { - key = "mps_3d_convolution_backward_input:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" + ":" + - to_string(stride[2]) + to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":" + to_string(dilation[2]) + - ":" + to_string(padding[0]) + ":" + to_string(padding[1]) + ":" + to_string(padding[2]) + ":" + - to_string(groups) + ":" + mem_format_key + getTensorsStringKey({grad_output_t, weight_t}) + ":" + - string([ns_shape_key UTF8String]); + key = "mps_3d_convolution_backward_input:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" + + ":" + std::to_string(stride[2]) + std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" + + std::to_string(dilation[2]) + ":" + std::to_string(padding[0]) + ":" + std::to_string(padding[1]) + ":" + + std::to_string(padding[2]) + ":" + std::to_string(groups) + ":" + mem_format_key + + getTensorsStringKey({grad_output_t, weight_t}) + ":" + string([ns_shape_key UTF8String]); } else { - key = "mps_convolution_backward_input:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" + - to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":" + to_string(padding[0]) + ":" + - to_string(padding[1]) + ":" + to_string(groups) + ":" + mem_format_key + + key = "mps_convolution_backward_input:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" + + std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" + std::to_string(padding[0]) + ":" + + std::to_string(padding[1]) + ":" + std::to_string(groups) + ":" + mem_format_key + getTensorsStringKey({grad_output_t, weight_t}) + ":" + string([ns_shape_key UTF8String]); } auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { @@ -547,15 +547,15 @@ static Tensor mps_convolution_backward_weights(IntArrayRef weight_size, NSString* ns_shape_key = [[gradOutputShape valueForKey:@"description"] componentsJoinedByString:@","]; string key; if (is3DConv) { - key = "mps_3d_convolution_backward_weights:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" + - to_string(stride[2]) + ":" + to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":" + - to_string(dilation[2]) + ":" + to_string(padding[0]) + ":" + to_string(padding[1]) + ":" + - to_string(padding[2]) + ":" + to_string(groups) + ":" + mem_format_key + + key = "mps_3d_convolution_backward_weights:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" + + std::to_string(stride[2]) + ":" + std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" + + std::to_string(dilation[2]) + ":" + std::to_string(padding[0]) + ":" + std::to_string(padding[1]) + ":" + + std::to_string(padding[2]) + ":" + std::to_string(groups) + ":" + mem_format_key + getTensorsStringKey({grad_output_t, input_t, grad_weight_t}) + ":" + string([ns_shape_key UTF8String]); } else { - key = "mps_convolution_backward_weights:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" + - to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":" + to_string(padding[0]) + ":" + - to_string(padding[1]) + ":" + to_string(groups) + ":" + mem_format_key + + key = "mps_convolution_backward_weights:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" + + std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" + std::to_string(padding[0]) + ":" + + std::to_string(padding[1]) + ":" + std::to_string(groups) + ":" + mem_format_key + getTensorsStringKey({grad_output_t, input_t, grad_weight_t}) + ":" + string([ns_shape_key UTF8String]); } auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { diff --git a/aten/src/ATen/native/mps/operations/Distributions.mm b/aten/src/ATen/native/mps/operations/Distributions.mm index 7ed06c8bf437..303a7bda99f7 100644 --- a/aten/src/ATen/native/mps/operations/Distributions.mm +++ b/aten/src/ATen/native/mps/operations/Distributions.mm @@ -63,7 +63,7 @@ @autoreleasepool { string key = op_name + getTensorsStringKey({self, mean_opt.value_or(Tensor()), std_opt.value_or(Tensor())}) + ":" + - to_string(val1) + ":" + to_string(val2); + std::to_string(val1) + ":" + std::to_string(val2); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { newCachedGraph->stateTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[ @(at::mps::detail::PHILOX_STATE_N) ]); @@ -469,7 +469,7 @@ Tensor normal_mps(const Tensor& mean, const Tensor& std, c10::optional(key, [&](auto mpsGraph, auto newCachedGraph) { MPSShape* prob_shape = getMPSShape(self_v); newCachedGraph->stateTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[ @7 ]); diff --git a/aten/src/ATen/native/mps/operations/Linear.mm b/aten/src/ATen/native/mps/operations/Linear.mm index 6686c2bed06e..fc8253e341f2 100644 --- a/aten/src/ATen/native/mps/operations/Linear.mm +++ b/aten/src/ATen/native/mps/operations/Linear.mm @@ -236,7 +236,7 @@ static Tensor _mps_linear_backward_input(IntArrayRef input_size, const Tensor& g MPSStream* stream = getCurrentMPSStream(); @autoreleasepool { - string key = "mps_linear_backward_weights:" + to_string(bias_defined) + ":" + + string key = "mps_linear_backward_weights:" + std::to_string(bias_defined) + ":" + getTensorsStringKey({input_reshaped, weight, grad_output_reshaped}); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_reshaped); diff --git a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm index e0db2c1e8b9b..25405cf4d395 100644 --- a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm +++ b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm @@ -229,8 +229,8 @@ bool use_metal_mm(const Tensor& self, const Tensor& other, const Tensor& output) @autoreleasepool { string key = (opType == ADDBMM_OP_TYPE) ? ("addbmm_out_mps_impl") : ("baddbmm_out_mps_impl"); - key += getTensorsStringKey({batch1, batch2, input}) + ":" + to_string(beta.toDouble()) + ":" + - to_string(alpha.toDouble()); + key += getTensorsStringKey({batch1, batch2, input}) + ":" + std::to_string(beta.toDouble()) + ":" + + std::to_string(alpha.toDouble()); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* inputTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, input); @@ -331,8 +331,8 @@ bool use_metal_mm(const Tensor& self, const Tensor& other, const Tensor& output) }; @autoreleasepool { - string key = "addmm_out_mps_impl" + getTensorsStringKey({self, other, *bias_}) + ":" + to_string(beta.toDouble()) + - ":" + to_string(alpha.toDouble()); + string key = "addmm_out_mps_impl" + getTensorsStringKey({self, other, *bias_}) + ":" + + std::to_string(beta.toDouble()) + ":" + std::to_string(alpha.toDouble()); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* selfTensor = nil; MPSGraphTensor* otherTensor = nil; @@ -615,8 +615,8 @@ Tensor addr_mps(const Tensor& self, const Tensor& vec1, const Tensor& vec2, cons }; @autoreleasepool { - string key = "addr_out_mps_impl" + getTensorsStringKey({vec1, vec2, *self_}) + ":" + to_string(beta.toDouble()) + - ":" + to_string(alpha.toDouble()); + string key = "addr_out_mps_impl" + getTensorsStringKey({vec1, vec2, *self_}) + ":" + + std::to_string(beta.toDouble()) + ":" + std::to_string(alpha.toDouble()); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* t1 = mps::mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(vec1), inputShape); MPSGraphTensor* t2 = mps::mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(vec2), otherShape); diff --git a/aten/src/ATen/native/mps/operations/LossOps.mm b/aten/src/ATen/native/mps/operations/LossOps.mm index 3e58d2ca8a4b..65540c770db4 100644 --- a/aten/src/ATen/native/mps/operations/LossOps.mm +++ b/aten/src/ATen/native/mps/operations/LossOps.mm @@ -69,7 +69,7 @@ static string reductionToString(int64_t reduction) { }; @autoreleasepool { - string key = op_name + reductionToString(reduction) + ":" + to_string(grad_input.sizes()[1]) + + string key = op_name + reductionToString(reduction) + ":" + std::to_string(grad_input.sizes()[1]) + getTensorsStringKey({input, target, grad_output}); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input); @@ -327,8 +327,8 @@ static void nllnd_loss_backward_impl(Tensor& grad_input_arg, } @autoreleasepool { string key = "nllnd_loss_backward" + getTensorsStringKey({input, grad_output, target, weight, total_weight}) + - to_string(numClasses) + ":" + to_string(ignore_index) + ":" + to_string(isWeightsArrayValid) + ":" + - to_string(isTargetCasted) + ":" + reductionToString(reduction); + std::to_string(numClasses) + ":" + std::to_string(ignore_index) + ":" + std::to_string(isWeightsArrayValid) + + ":" + std::to_string(isTargetCasted) + ":" + reductionToString(reduction); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input); @@ -463,9 +463,9 @@ static void nllnd_loss_forward_impl(Tensor& output, NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","]; // TODO: Make the key - string key = "nllnd_loss_forward_impl:" + to_string(ignore_index) + ":" + to_string(isWeightsArrayValid) + ":" + - reductionToString(reduction) + ":" + [ns_shape_key UTF8String] + ":" + getMPSTypeString(input) + ":" + - getMPSTypeString(target) + ":" + to_string(isTargetCasted) + ":" + getMPSTypeString(weight); + string key = "nllnd_loss_forward_impl:" + std::to_string(ignore_index) + ":" + std::to_string(isWeightsArrayValid) + + ":" + reductionToString(reduction) + ":" + [ns_shape_key UTF8String] + ":" + getMPSTypeString(input) + ":" + + getMPSTypeString(target) + ":" + std::to_string(isTargetCasted) + ":" + getMPSTypeString(weight); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), input_shape); MPSGraphTensor* targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(target), target_shape); @@ -598,7 +598,7 @@ static void smooth_l1_loss_impl(const Tensor& input, NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","]; string key = "smooth_l1_loss_impl:" + reductionToString(reduction) + ":" + [ns_shape_key UTF8String] + ":" + - to_string(beta) + ":" + getMPSTypeString(input) + ":" + getMPSTypeString(target); + std::to_string(beta) + ":" + getMPSTypeString(input) + ":" + getMPSTypeString(target); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { // smooth_l1_loss_mps: // ln = 0.5 * ( xn - yn ) ^ 2 / beta, if |xn - yn| < beta @@ -734,7 +734,7 @@ static void smooth_l1_loss_backward_impl(const Tensor& grad_output, @autoreleasepool { string key = "smooth_l1_loss_backward" + getTensorsStringKey({input, grad_output, grad_input, target}) + ":" + - reductionToString(reduction) + ":" + to_string(beta); + reductionToString(reduction) + ":" + std::to_string(beta); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input); diff --git a/aten/src/ATen/native/mps/operations/RangeFactories.mm b/aten/src/ATen/native/mps/operations/RangeFactories.mm index 102c54c251db..e558cb1d0d15 100644 --- a/aten/src/ATen/native/mps/operations/RangeFactories.mm +++ b/aten/src/ATen/native/mps/operations/RangeFactories.mm @@ -106,7 +106,7 @@ auto stream = getCurrentMPSStream(); auto mpsDataType = getMPSDataType(result); @autoreleasepool { - string key = "arange_mps_out" + getTensorsStringKey({result}) + ":" + to_string(size); + string key = "arange_mps_out" + getTensorsStringKey({result}) + ":" + std::to_string(size); auto cachedGraph = cache_->LookUpAs(key); if (!cachedGraph) { cachedGraph = cache_->CreateCachedGraphAs(key, ^MPSCachedGraph*() { @@ -173,7 +173,7 @@ auto stream = getCurrentMPSStream(); auto mpsDataType = getMPSDataType(result); @autoreleasepool { - string key = "arange_mps_out" + getTensorsStringKey({result}) + ":" + to_string(size); + string key = "arange_mps_out" + getTensorsStringKey({result}) + ":" + std::to_string(size); auto cachedGraph = cache_->LookUpAs(key); if (!cachedGraph) { cachedGraph = cache_->CreateCachedGraphAs(key, ^MPSCachedGraph*() { @@ -221,8 +221,8 @@ bool start_less_end = (start.to() <= end.to()); @autoreleasepool { - string key = - "linspace_out_mps:" + getTensorsStringKey({result}) + ":" + to_string(steps) + to_string(start_less_end); + string key = "linspace_out_mps:" + getTensorsStringKey({result}) + ":" + std::to_string(steps) + + std::to_string(start_less_end); auto cachedGraph = cache_->LookUpAs(key); if (!cachedGraph) { diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm index 416c83f0d3b3..b5ebd959932d 100644 --- a/aten/src/ATen/native/mps/operations/ReduceOps.mm +++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm @@ -359,8 +359,8 @@ static void impl_func_norm_mps(const Tensor& input_tensor, NSString* ns_key = [[wrappedAxes valueForKey:@"description"] componentsJoinedByString:@","]; string keepdim_info = (keepdim) ? "keepdim=1" : "keepdim=0"; string tensor_key = cdist ? getTensorsStringKey({input_tensor, other_tensor}) : getTensorsStringKey({input_t}); - string key = string("norm_out_mps:") + [ns_key UTF8String] + ":" + tensor_key + ":p" + to_string(p) + ":" + - keepdim_info + ":" + toString(in_dtype) + ":" + to_string(castInputData); + string key = string("norm_out_mps:") + [ns_key UTF8String] + ":" + tensor_key + ":p" + std::to_string(p) + ":" + + keepdim_info + ":" + toString(in_dtype) + ":" + std::to_string(castInputData); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, input_tensor); @@ -572,7 +572,7 @@ static Tensor std_var_common_impl_mps(const Tensor& input_t, string op_key = (stdVarType == STANDARD_DEVIATION) ? "std_mps" : "var_mps"; NSString* ns_key = [[wrappedAxes valueForKey:@"description"] componentsJoinedByString:@","]; string bessel_corrected = (use_correction && correction_value) ? "unbiased " : "biased "; - string use_dim_info = (use_dim) ? "use_dim=1:" + to_string(dim_value.size()) : "use_dim=0"; + string use_dim_info = (use_dim) ? "use_dim=1:" + std::to_string(dim_value.size()) : "use_dim=0"; string keepdim_info = (keepdim) ? "keepdim=1" : "keepdim=0"; string key = op_key + ":" + getTensorsStringKey(input_t) + ":" + use_dim_info + ":" + keepdim_info + ":" + string([ns_key UTF8String]) + ":" + bessel_corrected + ":" + std::to_string(correction_value); @@ -700,7 +700,7 @@ static void min_max_out_mps(const Tensor& input_t, auto stream = at::mps::getCurrentMPSStream(); @autoreleasepool { - string key = func_name + getTensorsStringKey({input_t, indices_t}) + ":" + to_string(dim_); + string key = func_name + getTensorsStringKey({input_t, indices_t}) + ":" + std::to_string(dim_); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t); MPSGraphTensor* outputTensor = nil; @@ -860,7 +860,7 @@ static void argmax_argmin_out_mps(const Tensor& input_t, @autoreleasepool { NSString* ns_key = [[apparent_in_shape valueForKey:@"description"] componentsJoinedByString:@","]; string key = - func_name + ":" + to_string(dim_) + ":" + getTensorsStringKey(input_t) + ":" + string([ns_key UTF8String]); + func_name + ":" + std::to_string(dim_) + ":" + getTensorsStringKey(input_t) + ":" + string([ns_key UTF8String]); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { auto inputScalarType = input_t.scalar_type(); MPSGraphTensor* inputTensor = @@ -1217,7 +1217,7 @@ Tensor std_mps(const Tensor& input_t, @autoreleasepool { MPSShape* input_t_shape = getMPSShape(input_t); - string key = string("any_out_mps:") + getMPSShapeString(input_t_shape) + ":" + to_string(dim_) + ":" + + string key = string("any_out_mps:") + getMPSShapeString(input_t_shape) + ":" + std::to_string(dim_) + ":" + getMPSTypeString(input_t); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSDataType input_type = getMPSDataType(input_t); @@ -1313,7 +1313,7 @@ Tensor std_mps(const Tensor& input_t, @autoreleasepool { MPSShape* input_t_shape = getMPSShape(input_t); - string key = string("all_out_mps:") + getMPSShapeString(input_t_shape) + ":" + to_string(dim_) + ":" + + string key = string("all_out_mps:") + getMPSShapeString(input_t_shape) + ":" + std::to_string(dim_) + ":" + getMPSTypeString(input_t); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSDataType input_type = getMPSDataType(input_t); @@ -1531,8 +1531,8 @@ static void median_out_mps(const Tensor& input_t, auto stream = at::mps::getCurrentMPSStream(); @autoreleasepool { - string key = - func_name + ":" + to_string(dim_) + ":" + getTensorsStringKey(input_t) + ":" + getTensorsStringKey(indices_t); + string key = func_name + ":" + std::to_string(dim_) + ":" + getTensorsStringKey(input_t) + ":" + + getTensorsStringKey(indices_t); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t); MPSGraphTensor* castInputTensor = diff --git a/aten/src/ATen/native/mps/operations/Shape.mm b/aten/src/ATen/native/mps/operations/Shape.mm index 135041be1f41..c32553094855 100644 --- a/aten/src/ATen/native/mps/operations/Shape.mm +++ b/aten/src/ATen/native/mps/operations/Shape.mm @@ -108,8 +108,8 @@ static void check_shape_except_dim(const Tensor& first, const Tensor& second, in // Input as placeholders MPSShape* input_shape = getMPSShape(self); NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","]; - string key = string("topk:") + [ns_shape_key UTF8String] + ":" + getMPSTypeString(self) + ":k" + to_string(k) + - ":dim" + to_string(dim_) + ":largest" + to_string(largest); + string key = string("topk:") + [ns_shape_key UTF8String] + ":" + getMPSTypeString(self) + ":k" + std::to_string(k) + + ":dim" + std::to_string(dim_) + ":largest" + std::to_string(largest); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { newCachedGraph->selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), input_shape); @@ -320,12 +320,12 @@ static void check_shape_except_dim(const Tensor& first, const Tensor& second, in }; @autoreleasepool { - string key = - "cat_out_mps:" + to_string(dimension) + ":" + (memory_format == MemoryFormat::ChannelsLast ? "NHWC" : "NCHW"); + string key = "cat_out_mps:" + std::to_string(dimension) + ":" + + (memory_format == MemoryFormat::ChannelsLast ? "NHWC" : "NCHW"); if (!all_same_dtype) { key += getTensorsStringKey(input_tensors, true, all_same_sizes_and_stride); } else { - key += ":" + getMPSTypeString(input_tensors[0].scalar_type(), true) + ":" + to_string(inputs.size()); + key += ":" + getMPSTypeString(input_tensors[0].scalar_type(), true) + ":" + std::to_string(inputs.size()); } for (auto idx : skipped_tensor_indices) { key += "," + std::to_string(idx); diff --git a/aten/src/ATen/native/mps/operations/Sort.mm b/aten/src/ATen/native/mps/operations/Sort.mm index e3ee85cfe230..5b94240846da 100644 --- a/aten/src/ATen/native/mps/operations/Sort.mm +++ b/aten/src/ATen/native/mps/operations/Sort.mm @@ -60,8 +60,8 @@ // Input as placeholders MPSShape* input_shape = getMPSShape(self); NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","]; - string key = string("sort:") + [ns_shape_key UTF8String] + ":" + getMPSTypeString(self) + ":dim" + to_string(dim) + - ":descending" + to_string(descending); + string key = string("sort:") + [ns_shape_key UTF8String] + ":" + getMPSTypeString(self) + ":dim" + + std::to_string(dim) + ":descending" + std::to_string(descending); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { newCachedGraph->selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), input_shape); diff --git a/aten/src/ATen/native/mps/operations/TensorCompare.mm b/aten/src/ATen/native/mps/operations/TensorCompare.mm index 4da5c302214d..6f8bfff53b8c 100644 --- a/aten/src/ATen/native/mps/operations/TensorCompare.mm +++ b/aten/src/ATen/native/mps/operations/TensorCompare.mm @@ -240,8 +240,8 @@ static void clamp_scalar_out_mps(const Tensor& input_t, @autoreleasepool { // the optional min/max refs could affect how we build the cached graph - string key = op_name + (has_min ? ("_min:" + to_string(min_scalar)) : "") + - (has_max ? ("_max:" + to_string(max_scalar)) : "") + "_scalar:" + getTensorsStringKey({input_t}); + string key = op_name + (has_min ? ("_min:" + std::to_string(min_scalar)) : "") + + (has_max ? ("_max:" + std::to_string(max_scalar)) : "") + "_scalar:" + getTensorsStringKey({input_t}); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { if (has_min) newCachedGraph->minTensor = [mpsGraph diff --git a/aten/src/ATen/native/mps/operations/Unique.mm b/aten/src/ATen/native/mps/operations/Unique.mm index fc30c2d0b797..a9948183b04c 100644 --- a/aten/src/ATen/native/mps/operations/Unique.mm +++ b/aten/src/ATen/native/mps/operations/Unique.mm @@ -36,8 +36,8 @@ const bool consecutive, c10::optional dimOpt) { return "_unique2_mps:" + getMPSTypeString(dtype) + "[" + getArrayRefString(base_shape) + "]:[" + - (dimOpt.has_value() ? to_string(dimOpt.value()) : "None") + "]:[" + to_string(return_inverse) + "]:[" + - to_string(return_counts) + "]:[" + to_string(consecutive) + "]"; + (dimOpt.has_value() ? std::to_string(dimOpt.value()) : "None") + "]:[" + std::to_string(return_inverse) + "]:[" + + std::to_string(return_counts) + "]:[" + std::to_string(consecutive) + "]"; } // dim arg not supported when non consecutive, ie sorted diff --git a/aten/src/ATen/native/mps/operations/UpSample.mm b/aten/src/ATen/native/mps/operations/UpSample.mm index f4973f600015..fca71ed346c5 100644 --- a/aten/src/ATen/native/mps/operations/UpSample.mm +++ b/aten/src/ATen/native/mps/operations/UpSample.mm @@ -99,7 +99,7 @@ static void upsample_out_template(const Tensor& input, @autoreleasepool { string key = "upsample_" + std::string(resize_mode_str) + (align_corners ? "_aligned_corners" : "") + - getTensorsStringKey({input}) + ":[" + to_string(scale_h) + "," + to_string(scale_w) + "]:[" + + getTensorsStringKey({input}) + ":[" + std::to_string(scale_h) + "," + std::to_string(scale_w) + "]:[" + (is_backward_pass ? getArrayRefString(input_size) : "Undefined") + "]"; auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { diff --git a/aten/src/ATen/native/mps/operations/View.mm b/aten/src/ATen/native/mps/operations/View.mm index b583a19ef5e6..ae530ad12bde 100644 --- a/aten/src/ATen/native/mps/operations/View.mm +++ b/aten/src/ATen/native/mps/operations/View.mm @@ -42,7 +42,7 @@ } return (is_scatter ? "scatter:" : "gather:") + dtype_key + "[" + getArrayRefString(base_shape) + "]:[" + - getArrayRefString(new_shape) + "]:[" + getArrayRefString(stride) + "]:[" + to_string(storage_offset) + "]"; + getArrayRefString(new_shape) + "]:[" + getArrayRefString(stride) + "]:[" + std::to_string(storage_offset) + "]"; } // initializes the MTLBuffers for tensor data and runs the MPSGraph for the view op diff --git a/torch/csrc/autograd/profiler_python.cpp b/torch/csrc/autograd/profiler_python.cpp index 799188be9a68..5fcc7b86a2fa 100644 --- a/torch/csrc/autograd/profiler_python.cpp +++ b/torch/csrc/autograd/profiler_python.cpp @@ -3,7 +3,6 @@ #include #include #include -#include #include #include #include @@ -20,7 +19,6 @@ #include #include #include -#include #include #include #include From bb2a9955297fe064b99308a64a1ac43ab1a212c8 Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Wed, 12 Jun 2024 01:34:29 +0000 Subject: [PATCH 664/706] Back out "[Dynamo] Treat integers stored on nn.Modules as dynamic (#126466)" (#128432) Summary: Original commit changeset: c7d2e6b13922 Original Phabricator Diff: D57618942 Differential Revision: D58383241 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128432 Approved by: https://github.com/ezyang, https://github.com/Yuzhen11 --- test/dynamo/test_modules.py | 57 ------------------- .../test_dynamo_with_onnxruntime_backend.py | 6 +- torch/_dynamo/variables/builder.py | 4 ++ 3 files changed, 7 insertions(+), 60 deletions(-) diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py index c38dc7c7b892..dbfef8af4386 100644 --- a/test/dynamo/test_modules.py +++ b/test/dynamo/test_modules.py @@ -22,7 +22,6 @@ from torch._dynamo.eval_frame import unsupported from torch._dynamo.mutation_guard import GenerationTracker from torch._dynamo.testing import expectedFailureDynamic, same -from torch._dynamo.utils import ifdynstaticdefault from torch.nn.modules.lazy import LazyModuleMixin from torch.nn.parameter import Parameter, UninitializedParameter @@ -1108,37 +1107,6 @@ def forward(self, x): return self.m(x) -class ModuleWithIntAttr(torch.nn.Module): - def __init__(self): - super().__init__() - self.layer = torch.nn.Linear(4, 4) - self.step = 10 - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = x + 1 - self.step += 1 - return self.layer(x) + self.step - - -class UnspecInlinableModule(torch.nn.Module): - torchdynamo_force_dynamic = True # forced to be a UnspecializedNNModule - - def forward(self, x): - return torch.sin(x) - - -class UnspecModuleWithIntAttr(torch.nn.Module): - def __init__(self): - super().__init__() - self.layer = UnspecInlinableModule() - self.step = 10 - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = x + 1 - self.step += 1 - return self.layer(x) + self.step - - def make_test(fn, expected_ops=None): def test_fn(self): return torch._dynamo.testing.standard_test( @@ -1392,31 +1360,6 @@ def forward(self, x): self.assertTrue(torch._dynamo.testing.same(pre, opt_pre)) self.assertTrue(torch._dynamo.testing.same(out1, out_post)) - def test_nn_module_unspec_int_attr(self): - for module_class in [ModuleWithIntAttr, UnspecModuleWithIntAttr]: - mod = module_class() - cnt = torch._dynamo.testing.CompileCounter() - opt_mod = torch.compile(backend=cnt)(copy.deepcopy(mod)) - x = torch.randn(3, 4) - - # Compiling self.step as static. - ref1 = mod(x) - res1 = opt_mod(x) - self.assertTrue(torch.allclose(ref1, res1)) - self.assertEqual(cnt.frame_count, 1) - - # Compiling self.step as dynamic. - ref2 = mod(x) - res2 = opt_mod(x) - self.assertTrue(torch.allclose(ref2, res2)) - self.assertEqual(cnt.frame_count, ifdynstaticdefault(2, 1)) - - # No re-compilation! - ref3 = mod(x) - res3 = opt_mod(x) - self.assertTrue(torch.allclose(ref3, res3)) - self.assertEqual(cnt.frame_count, ifdynstaticdefault(2, 1)) - # RuntimeError: SymIntArrayRef expected to contain only concrete integers @expectedFailureDynamic def test_lazy_module1(self): diff --git a/test/onnx/dynamo/test_dynamo_with_onnxruntime_backend.py b/test/onnx/dynamo/test_dynamo_with_onnxruntime_backend.py index 951e7cfd7c54..0c7a141d6a7a 100644 --- a/test/onnx/dynamo/test_dynamo_with_onnxruntime_backend.py +++ b/test/onnx/dynamo/test_dynamo_with_onnxruntime_backend.py @@ -471,7 +471,7 @@ def generate_example_inputs(batch: int, seq: int, hidden_size: int): if test_local_backend: assert local_ort is not None - number_of_captured_graphs = 3 if test_backward else 2 + number_of_captured_graphs = 2 if test_backward else 1 execution_count = len(example_args_collection) * number_of_captured_graphs self._assert_counting_information( local_ort, @@ -564,7 +564,7 @@ def generate_example_inputs(batch: int, seq: int, hidden_size: int): if test_local_backend: assert local_ort is not None - number_of_captured_graphs = 3 if test_backward else 2 + number_of_captured_graphs = 2 if test_backward else 1 execution_count = len(example_args_collection) * number_of_captured_graphs self._assert_counting_information( local_ort, @@ -649,7 +649,7 @@ def generate_example_inputs(batch: int, seq: int): if test_local_backend: assert local_ort is not None - number_of_captured_graphs = 3 if test_backward else 2 + number_of_captured_graphs = 2 if test_backward else 1 execution_count = len(example_args_collection) * number_of_captured_graphs self._assert_counting_information( local_ort, diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 2f10fba1b370..03b29f9e04c8 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -1195,6 +1195,10 @@ def wrap_literal(self, value): value in self._common_constants() # Assume integers from global variables want to be specialized or not self.source.guard_source().is_local() + # Assume that integers that came from NN modules want to be + # specialized (as we don't expect users to be changing the + # NN modules on the fly) + or self.source.guard_source().is_nn_module() or is_from_defaults(self.source) or is_cell_contents(self.source) ): From 3d55d84ec2271b58e731f180788c3e564fc69cc8 Mon Sep 17 00:00:00 2001 From: Jiashen Cao Date: Wed, 12 Jun 2024 01:52:09 +0000 Subject: [PATCH 665/706] [Fix] Check tensor dtype before using torch.allclose in _trace log (#128438) #### Issue `torch.allclose` errors out during logging due to different dtypes. #### Test * `pytest test/test_jit.py` Pull Request resolved: https://github.com/pytorch/pytorch/pull/128438 Approved by: https://github.com/angelayi --- torch/jit/_trace.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch/jit/_trace.py b/torch/jit/_trace.py index 5bdd71f94381..7db856024287 100644 --- a/torch/jit/_trace.py +++ b/torch/jit/_trace.py @@ -656,6 +656,8 @@ def analyze_ts_result_with_export_result(export, trace): return False if isinstance(orig, torch.Tensor): + if orig.dtype != loaded.dtype: + return False if not torch.allclose(orig, loaded): return False else: From 7f6daf289b62cf459a70c6bd4be13a21e086d211 Mon Sep 17 00:00:00 2001 From: Sam Larsen Date: Wed, 12 Jun 2024 01:55:53 +0000 Subject: [PATCH 666/706] [inductor] parallel compile: set LD_LIBRARY_PATH for sub-processes in internal (#128376) Test Plan: `TORCHINDUCTOR_WORKER_START=subprocess TORCHINDUCTOR_COMPILE_THREADS=16 buck run mode/opt scripts/slarsen/torch_compile:run` Differential Revision: D58371264 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128376 Approved by: https://github.com/eellison --- torch/_inductor/compile_worker/subproc_pool.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/torch/_inductor/compile_worker/subproc_pool.py b/torch/_inductor/compile_worker/subproc_pool.py index 03bfe6c3f203..5aba18707b41 100644 --- a/torch/_inductor/compile_worker/subproc_pool.py +++ b/torch/_inductor/compile_worker/subproc_pool.py @@ -13,6 +13,7 @@ from concurrent.futures import Future, ProcessPoolExecutor from typing import Any, Callable, Dict +from torch._inductor import config from torch._inductor.compile_worker.watchdog import _async_compile_initializer log = logging.getLogger(__name__) @@ -59,6 +60,19 @@ def _recv_msg(read_pipe): return job_id, data +def _get_ld_library_path(): + path = os.environ.get("LD_LIBRARY_PATH", "") + if config.is_fbcode(): + from libfb.py.parutil import get_runtime_path + + runtime_path = get_runtime_path() + if runtime_path: + lib_path = os.path.join(runtime_path, "runtime", "lib") + path = os.pathsep.join([lib_path, path]) if path else lib_path + + return path + + class SubprocPool: """ Mimic a concurrent.futures.ProcessPoolExecutor, but wrap it in @@ -85,6 +99,8 @@ def __init__(self, nprocs: int): # torch._inductor.codecache since the warming process is what # creates the SubprocPool in the first place. "TORCH_WARM_POOL": "0", + # Some internal usages need a modified LD_LIBRARY_PATH. + "LD_LIBRARY_PATH": _get_ld_library_path(), }, ) self.write_pipe: Pipe = typing.cast(Pipe, self.process.stdin) From 85eeb90d2c4b4bd5a75450f2ff3f25796f73f5a7 Mon Sep 17 00:00:00 2001 From: William Wen Date: Tue, 11 Jun 2024 10:32:10 -0700 Subject: [PATCH 667/706] [dynamo] Fix graph breaks related to HF ModelOutput (#127780) Fixes https://github.com/pytorch/pytorch/issues/126028 and https://github.com/pytorch/pytorch/issues/126027. Changes: - Support building `CustomizedDictVariable` in` VariableBuilder` (but only for HF `ModelOutput` subclasses) - Remove `DataClassVariable` since it's not really being used anywhere (`CustomizedDictVariable` can be used instead) - Support side effects for `CustomizedDictVariable` - Allow `NO_HASATTR` leaf guard on `DictSubclassGuardManager` Pull Request resolved: https://github.com/pytorch/pytorch/pull/127780 Approved by: https://github.com/jansel, https://github.com/anijain2305 --- .../aot_eager_huggingface_training.csv | 28 +-- .../aot_eager_torchbench_inference.csv | 2 +- .../aot_eager_torchbench_training.csv | 2 +- .../dynamic_inductor_huggingface_training.csv | 28 +-- .../cu124/inductor_huggingface_training.csv | 28 +-- ...dynamic_aot_eager_huggingface_training.csv | 28 +-- ...dynamic_aot_eager_torchbench_inference.csv | 2 +- .../dynamic_aot_eager_torchbench_training.csv | 2 +- .../dynamic_inductor_huggingface_training.csv | 28 +-- .../dynamic_inductor_torchbench_inference.csv | 2 +- .../dynamic_inductor_torchbench_training.csv | 2 +- .../dynamo_eager_huggingface_training.csv | 28 +-- .../dynamo_eager_torchbench_inference.csv | 2 +- .../dynamo_eager_torchbench_training.csv | 2 +- .../inductor_huggingface_training.csv | 28 +-- .../inductor_torchbench_inference.csv | 2 +- .../inductor_torchbench_training.csv | 2 +- test/dynamo/test_model_output.py | 62 ++++++ torch/_dynamo/side_effects.py | 39 ++++ torch/_dynamo/variables/__init__.py | 2 - torch/_dynamo/variables/builder.py | 12 +- torch/_dynamo/variables/builtin.py | 1 - torch/_dynamo/variables/dicts.py | 204 +++++------------- torch/_dynamo/variables/misc.py | 6 + torch/_dynamo/variables/user_defined.py | 3 - torch/csrc/dynamo/guards.cpp | 14 +- 26 files changed, 292 insertions(+), 267 deletions(-) diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_huggingface_training.csv index a5e00513153d..08dad9b4a06a 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_huggingface_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_huggingface_training.csv @@ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9 -BartForCausalLM,pass,12 +BartForCausalLM,pass,6 -BartForConditionalGeneration,pass,24 +BartForConditionalGeneration,pass,8 @@ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0 -BlenderbotSmallForCausalLM,pass,12 +BlenderbotSmallForCausalLM,pass,6 -BlenderbotSmallForConditionalGeneration,pass,24 +BlenderbotSmallForConditionalGeneration,pass,8 @@ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4 -MBartForCausalLM,pass,12 +MBartForCausalLM,pass,6 -MBartForConditionalGeneration,pass,24 +MBartForConditionalGeneration,pass,8 @@ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3 -OPTForCausalLM,pass,12 +OPTForCausalLM,pass,6 -PLBartForCausalLM,pass,12 +PLBartForCausalLM,pass,6 -PLBartForConditionalGeneration,pass,29 +PLBartForConditionalGeneration,pass,8 -PegasusForCausalLM,pass,12 +PegasusForCausalLM,pass,6 -PegasusForConditionalGeneration,pass,23 +PegasusForConditionalGeneration,pass,7 @@ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5 -Speech2Text2ForCausalLM,pass,12 +Speech2Text2ForCausalLM,pass,6 @@ -170,11 +170,11 @@ T5Small,pass,5 -TrOCRForCausalLM,pass,12 +TrOCRForCausalLM,pass,6 -XGLMForCausalLM,pass,12 +XGLMForCausalLM,pass,6 diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv index 68331f317995..9863aa7da6a2 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv @@ -150,7 +150,7 @@ hf_Bert_large,pass,0 -hf_BigBird,pass,46 +hf_BigBird,pass,43 diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv index 20a5e024ece5..82048af8775a 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv @@ -98,7 +98,7 @@ hf_Bert_large,pass,6 -hf_BigBird,pass, 52 +hf_BigBird,pass,49 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_huggingface_training.csv index a5e00513153d..08dad9b4a06a 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_huggingface_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_huggingface_training.csv @@ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9 -BartForCausalLM,pass,12 +BartForCausalLM,pass,6 -BartForConditionalGeneration,pass,24 +BartForConditionalGeneration,pass,8 @@ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0 -BlenderbotSmallForCausalLM,pass,12 +BlenderbotSmallForCausalLM,pass,6 -BlenderbotSmallForConditionalGeneration,pass,24 +BlenderbotSmallForConditionalGeneration,pass,8 @@ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4 -MBartForCausalLM,pass,12 +MBartForCausalLM,pass,6 -MBartForConditionalGeneration,pass,24 +MBartForConditionalGeneration,pass,8 @@ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3 -OPTForCausalLM,pass,12 +OPTForCausalLM,pass,6 -PLBartForCausalLM,pass,12 +PLBartForCausalLM,pass,6 -PLBartForConditionalGeneration,pass,29 +PLBartForConditionalGeneration,pass,8 -PegasusForCausalLM,pass,12 +PegasusForCausalLM,pass,6 -PegasusForConditionalGeneration,pass,23 +PegasusForConditionalGeneration,pass,7 @@ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5 -Speech2Text2ForCausalLM,pass,12 +Speech2Text2ForCausalLM,pass,6 @@ -170,11 +170,11 @@ T5Small,pass,5 -TrOCRForCausalLM,pass,12 +TrOCRForCausalLM,pass,6 -XGLMForCausalLM,pass,12 +XGLMForCausalLM,pass,6 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_huggingface_training.csv index a5e00513153d..08dad9b4a06a 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_huggingface_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_huggingface_training.csv @@ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9 -BartForCausalLM,pass,12 +BartForCausalLM,pass,6 -BartForConditionalGeneration,pass,24 +BartForConditionalGeneration,pass,8 @@ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0 -BlenderbotSmallForCausalLM,pass,12 +BlenderbotSmallForCausalLM,pass,6 -BlenderbotSmallForConditionalGeneration,pass,24 +BlenderbotSmallForConditionalGeneration,pass,8 @@ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4 -MBartForCausalLM,pass,12 +MBartForCausalLM,pass,6 -MBartForConditionalGeneration,pass,24 +MBartForConditionalGeneration,pass,8 @@ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3 -OPTForCausalLM,pass,12 +OPTForCausalLM,pass,6 -PLBartForCausalLM,pass,12 +PLBartForCausalLM,pass,6 -PLBartForConditionalGeneration,pass,29 +PLBartForConditionalGeneration,pass,8 -PegasusForCausalLM,pass,12 +PegasusForCausalLM,pass,6 -PegasusForConditionalGeneration,pass,23 +PegasusForConditionalGeneration,pass,7 @@ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5 -Speech2Text2ForCausalLM,pass,12 +Speech2Text2ForCausalLM,pass,6 @@ -170,11 +170,11 @@ T5Small,pass,5 -TrOCRForCausalLM,pass,12 +TrOCRForCausalLM,pass,6 -XGLMForCausalLM,pass,12 +XGLMForCausalLM,pass,6 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_huggingface_training.csv index a5e00513153d..08dad9b4a06a 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_huggingface_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_huggingface_training.csv @@ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9 -BartForCausalLM,pass,12 +BartForCausalLM,pass,6 -BartForConditionalGeneration,pass,24 +BartForConditionalGeneration,pass,8 @@ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0 -BlenderbotSmallForCausalLM,pass,12 +BlenderbotSmallForCausalLM,pass,6 -BlenderbotSmallForConditionalGeneration,pass,24 +BlenderbotSmallForConditionalGeneration,pass,8 @@ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4 -MBartForCausalLM,pass,12 +MBartForCausalLM,pass,6 -MBartForConditionalGeneration,pass,24 +MBartForConditionalGeneration,pass,8 @@ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3 -OPTForCausalLM,pass,12 +OPTForCausalLM,pass,6 -PLBartForCausalLM,pass,12 +PLBartForCausalLM,pass,6 -PLBartForConditionalGeneration,pass,29 +PLBartForConditionalGeneration,pass,8 -PegasusForCausalLM,pass,12 +PegasusForCausalLM,pass,6 -PegasusForConditionalGeneration,pass,23 +PegasusForConditionalGeneration,pass,7 @@ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5 -Speech2Text2ForCausalLM,pass,12 +Speech2Text2ForCausalLM,pass,6 @@ -170,11 +170,11 @@ T5Small,pass,5 -TrOCRForCausalLM,pass,12 +TrOCRForCausalLM,pass,6 -XGLMForCausalLM,pass,12 +XGLMForCausalLM,pass,6 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv index fd84df653db1..3aecea06b530 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv @@ -150,7 +150,7 @@ hf_Bert_large,pass,0 -hf_BigBird,pass,46 +hf_BigBird,pass,43 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv index c010e129c19b..c87a07a8c294 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv @@ -98,7 +98,7 @@ hf_Bert_large,pass,6 -hf_BigBird,pass,52 +hf_BigBird,pass,49 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_training.csv index a5e00513153d..08dad9b4a06a 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_training.csv @@ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9 -BartForCausalLM,pass,12 +BartForCausalLM,pass,6 -BartForConditionalGeneration,pass,24 +BartForConditionalGeneration,pass,8 @@ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0 -BlenderbotSmallForCausalLM,pass,12 +BlenderbotSmallForCausalLM,pass,6 -BlenderbotSmallForConditionalGeneration,pass,24 +BlenderbotSmallForConditionalGeneration,pass,8 @@ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4 -MBartForCausalLM,pass,12 +MBartForCausalLM,pass,6 -MBartForConditionalGeneration,pass,24 +MBartForConditionalGeneration,pass,8 @@ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3 -OPTForCausalLM,pass,12 +OPTForCausalLM,pass,6 -PLBartForCausalLM,pass,12 +PLBartForCausalLM,pass,6 -PLBartForConditionalGeneration,pass,29 +PLBartForConditionalGeneration,pass,8 -PegasusForCausalLM,pass,12 +PegasusForCausalLM,pass,6 -PegasusForConditionalGeneration,pass,23 +PegasusForConditionalGeneration,pass,7 @@ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5 -Speech2Text2ForCausalLM,pass,12 +Speech2Text2ForCausalLM,pass,6 @@ -170,11 +170,11 @@ T5Small,pass,5 -TrOCRForCausalLM,pass,12 +TrOCRForCausalLM,pass,6 -XGLMForCausalLM,pass,12 +XGLMForCausalLM,pass,6 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv index f0417110484e..c167ea680d2c 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv @@ -150,7 +150,7 @@ hf_Bert_large,pass,0 -hf_BigBird,fail_accuracy,46 +hf_BigBird,fail_accuracy,43 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv index 82c4c1da2317..c25fa9471337 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv @@ -98,7 +98,7 @@ hf_Bert_large,pass,6 -hf_BigBird,pass,52 +hf_BigBird,pass,49 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_huggingface_training.csv index a5e00513153d..08dad9b4a06a 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_huggingface_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_huggingface_training.csv @@ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9 -BartForCausalLM,pass,12 +BartForCausalLM,pass,6 -BartForConditionalGeneration,pass,24 +BartForConditionalGeneration,pass,8 @@ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0 -BlenderbotSmallForCausalLM,pass,12 +BlenderbotSmallForCausalLM,pass,6 -BlenderbotSmallForConditionalGeneration,pass,24 +BlenderbotSmallForConditionalGeneration,pass,8 @@ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4 -MBartForCausalLM,pass,12 +MBartForCausalLM,pass,6 -MBartForConditionalGeneration,pass,24 +MBartForConditionalGeneration,pass,8 @@ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3 -OPTForCausalLM,pass,12 +OPTForCausalLM,pass,6 -PLBartForCausalLM,pass,12 +PLBartForCausalLM,pass,6 -PLBartForConditionalGeneration,pass,29 +PLBartForConditionalGeneration,pass,8 -PegasusForCausalLM,pass,12 +PegasusForCausalLM,pass,6 -PegasusForConditionalGeneration,pass,23 +PegasusForConditionalGeneration,pass,7 @@ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5 -Speech2Text2ForCausalLM,pass,12 +Speech2Text2ForCausalLM,pass,6 @@ -170,11 +170,11 @@ T5Small,pass,5 -TrOCRForCausalLM,pass,12 +TrOCRForCausalLM,pass,6 -XGLMForCausalLM,pass,12 +XGLMForCausalLM,pass,6 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv index 68331f317995..9863aa7da6a2 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv @@ -150,7 +150,7 @@ hf_Bert_large,pass,0 -hf_BigBird,pass,46 +hf_BigBird,pass,43 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv index 30808bc6bcd4..4055eda462c5 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv @@ -98,7 +98,7 @@ hf_Bert_large,pass,6 -hf_BigBird,pass,52 +hf_BigBird,pass,49 diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_training.csv index a5e00513153d..08dad9b4a06a 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_training.csv @@ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9 -BartForCausalLM,pass,12 +BartForCausalLM,pass,6 -BartForConditionalGeneration,pass,24 +BartForConditionalGeneration,pass,8 @@ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0 -BlenderbotSmallForCausalLM,pass,12 +BlenderbotSmallForCausalLM,pass,6 -BlenderbotSmallForConditionalGeneration,pass,24 +BlenderbotSmallForConditionalGeneration,pass,8 @@ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4 -MBartForCausalLM,pass,12 +MBartForCausalLM,pass,6 -MBartForConditionalGeneration,pass,24 +MBartForConditionalGeneration,pass,8 @@ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3 -OPTForCausalLM,pass,12 +OPTForCausalLM,pass,6 -PLBartForCausalLM,pass,12 +PLBartForCausalLM,pass,6 -PLBartForConditionalGeneration,pass,29 +PLBartForConditionalGeneration,pass,8 -PegasusForCausalLM,pass,12 +PegasusForCausalLM,pass,6 -PegasusForConditionalGeneration,pass,23 +PegasusForConditionalGeneration,pass,7 @@ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5 -Speech2Text2ForCausalLM,pass,12 +Speech2Text2ForCausalLM,pass,6 @@ -170,11 +170,11 @@ T5Small,pass,5 -TrOCRForCausalLM,pass,12 +TrOCRForCausalLM,pass,6 -XGLMForCausalLM,pass,12 +XGLMForCausalLM,pass,6 diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv index b4700da57b25..74549205d747 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv @@ -150,7 +150,7 @@ hf_Bert_large,pass,0 -hf_BigBird,fail_accuracy,46 +hf_BigBird,fail_accuracy,43 diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv index 30808bc6bcd4..4055eda462c5 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv @@ -98,7 +98,7 @@ hf_Bert_large,pass,6 -hf_BigBird,pass,52 +hf_BigBird,pass,49 diff --git a/test/dynamo/test_model_output.py b/test/dynamo/test_model_output.py index b2c1581d7e86..e6a6fc6dab58 100644 --- a/test/dynamo/test_model_output.py +++ b/test/dynamo/test_model_output.py @@ -101,6 +101,15 @@ def fn(obj: BaseModelOutput): self._common(fn, 2) + @maybe_skip + def test_mo_getattr_missing(self): + def fn(obj: BaseModelOutput): + if getattr(obj, "asdf", None) is not None: + obj.asdf += 1 + return obj.attentions + 1 + + self._common(fn, 1) + @maybe_skip def test_mo_getitem(self): def fn(obj: BaseModelOutput): @@ -166,6 +175,59 @@ def fn(obj): self.assertEqual(cnts.frame_count, 1) self.assertEqual(cnts.op_count, 2) + @maybe_skip + def test_mo_init2(self): + # this ModelOutput subclass runs a different __post_init__ codepath + @dataclasses.dataclass + class MyDataClass(ModelOutput): + x: torch.FloatTensor = None + + def fn(x): + obj = MyDataClass(x=x) + return obj + + inp = torch.randn(3, 3) + opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) + self.assertEqual(fn(inp).x, opt_fn(inp).x) + + @maybe_skip + def test_mo_init_with_disable(self): + # Can result in "non-function or method super: " + # graph breaks (although it may not be the first) + # Minimal repro for https://github.com/pytorch/pytorch/issues/126028 + @dataclasses.dataclass + class MyDataClass(ModelOutput): + x: torch.FloatTensor = None + + @torch._dynamo.disable(recursive=False) + def fn(x): + return MyDataClass(x=x) + + inp = torch.randn(3, 3) + opt_fn = torch._dynamo.optimize("eager")(fn) + self.assertEqual(fn(inp).x, opt_fn(inp).x) + + @maybe_skip + def test_mo_newkey(self): + obj = BaseModelOutput() + + def fn(obj): + return obj["wwww"] + 1 + + inp = torch.randn(3, 3) + obj["wwww"] = inp + opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) + self.assertEqual(fn(obj), opt_fn(obj)) + + @maybe_skip + def test_mo_from_outside(self): + def fn(obj): + return obj.attentions + 1 + + obj = BaseModelOutput(attentions=torch.randn(3, 3)) + opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) + self.assertEqual(fn(obj), opt_fn(obj)) + @maybe_skip def test_HF_bert_model_output(self): class BertPooler(torch.nn.Module): diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index 9c3bf7d28711..2dedf71a66d9 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -7,6 +7,7 @@ from . import utils, variables from .bytecode_transformation import ( + bytecode_from_template, create_call_function, create_call_method, create_instruction, @@ -59,6 +60,11 @@ def __init__(self, source: Optional[Source], cls_source: Optional[Source]): self.cls_source = cls_source +def _manual_update_dict(dict_from, dict_to): + for k, v in dict_from.items(): + dict_to[k] = v + + class SideEffects: """ Track side effects (list mutation, setattr, etc) that need to be @@ -480,6 +486,39 @@ def codegen_update_mutated(self, cg: PyCodegen): ] ) suffixes.append([create_instruction("STORE_SUBSCR")]) + elif isinstance(var, variables.CustomizedDictVariable): + # need to update the dict manually since update method may be invalid + varname_map = {} + for name in _manual_update_dict.__code__.co_varnames: + varname_map[name] = cg.tx.output.new_var() + + cg(var.mutable_local.source) # type: ignore[attr-defined] + cg.extend_output( + [create_instruction("STORE_FAST", argval=varname_map["dict_to"])] + ) + + cg(var, allow_cache=False) + cg.extend_output( + [create_instruction("STORE_FAST", argval=varname_map["dict_from"])] + ) + + cg(var.mutable_local.source) # type: ignore[attr-defined] + cg.extend_output([create_load_method("clear")]) + + # unfortunately can't just use DICT_MERGE due to possible custom behaviors + dict_update_insts = bytecode_from_template( + _manual_update_dict, varname_map=varname_map + ) + + suffixes.append( + [ + *create_call_method(0), # clear + create_instruction("POP_TOP"), + *dict_update_insts, + create_instruction("POP_TOP"), + ] + ) + elif isinstance(var, variables.ConstDictVariable): cg.tx.output.update_co_names("clear") cg.tx.output.update_co_names("update") diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py index 25bda3769eb4..9ffdd64fbc96 100644 --- a/torch/_dynamo/variables/__init__.py +++ b/torch/_dynamo/variables/__init__.py @@ -23,7 +23,6 @@ from .dicts import ( ConstDictVariable, CustomizedDictVariable, - DataClassVariable, DefaultDictVariable, SetVariable, ) @@ -113,7 +112,6 @@ "CountIteratorVariable", "CustomizedDictVariable", "CycleIteratorVariable", - "DataClassVariable", "DefaultDictVariable", "DeletedVariable", "DeterministicAlgorithmsVariable", diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 03b29f9e04c8..478fd3eb4010 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -111,7 +111,7 @@ ) from .dicts import ( ConstDictVariable, - DataClassVariable, + CustomizedDictVariable, DefaultDictVariable, HFPretrainedConfigVariable, PythonSysModulesVariable, @@ -493,6 +493,11 @@ class Autotuner: elif value is sys.modules: self.install_guards(GuardBuilder.FUNCTION_MATCH) return PythonSysModulesVariable(source=self.source) + elif CustomizedDictVariable.is_matching_cls_hf(type(value)): + self.install_guards(GuardBuilder.TYPE_MATCH) + result = CustomizedDictVariable.wrap(self, value) + result.source = self.source + return self.tx.output.side_effects.track_object_existing(value, result) elif istype(value, (dict, collections.defaultdict, collections.OrderedDict)): if not value and self.get_source().is_nn_module(): # It is faster to guard on 'false' property than to guard @@ -711,9 +716,6 @@ def build_key_value(i, k, v): ) elif np and isinstance(value, np.number): return self.wrap_unspecialized_primitive(value) - elif DataClassVariable.is_matching_object(value): - self.install_guards(GuardBuilder.TYPE_MATCH) - return DataClassVariable.wrap(self, value) elif HFPretrainedConfigVariable.is_matching_object(value): self.install_guards(GuardBuilder.TYPE_MATCH) return HFPretrainedConfigVariable(value) @@ -1701,7 +1703,7 @@ def wrap_unspecialized_primitive(self, value): def _dataclasses_fields_lambda(obj): if isinstance(obj, UserDefinedObjectVariable): value = obj.value - elif isinstance(obj, DataClassVariable): + elif isinstance(obj, CustomizedDictVariable): value = obj.user_cls else: unimplemented(f"Dataclass fields handling fails for type {obj}") diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 2586c8deab94..71744e95277f 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1633,7 +1633,6 @@ def call_setattr( if isinstance( obj, ( - variables.DataClassVariable, variables.CustomizedDictVariable, variables.PlacementVariable, variables.UserDefinedObjectVariable, diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 0724a80621f7..ea599af95cdc 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -545,6 +545,8 @@ def python_type(self): def _is_matching_transformers_cls(cls) -> bool: mod = sys.modules.get("transformers.file_utils") + if mod is None: + mod = sys.modules.get("transformers.utils.generic") return mod is not None and issubclass(cls, mod.ModelOutput) @@ -555,12 +557,20 @@ def _is_matching_diffusers_cls(cls) -> bool: def _call_hasattr_customobj(self, tx, name: str) -> "VariableTracker": """Shared method between DataClassVariable and CustomizedDictVariable where items are attrs""" + if tx.output.side_effects.is_attribute_mutation(self): + try: + result = tx.output.side_effects.load_attr(self, name, deleted_ok=True) + return variables.ConstantVariable.create( + not isinstance(result, variables.DeletedVariable) + ) + except KeyError: + pass if name in self.items or hasattr(self.user_cls, name): return ConstantVariable(True) elif istype(self.mutable_local, MutableLocal) and self.source is None: # Something created locally can't have any extra fields on it return ConstantVariable(False) - elif self.mutable_local is None and self.source: + elif self.source: # Maybe add a guard try: example = tx.output.root_tx.get_example_value(self.source) @@ -577,152 +587,27 @@ def _call_hasattr_customobj(self, tx, name: str) -> "VariableTracker": class DataClassVariable(ConstDictVariable): """ - This is a bit of a hack to deal with - transformers.file_utils.ModelOutput() from huggingface. + This class doesn't appear to be used anywhere. + It used to be used to deal with transformers.file_utils.ModelOutput + from huggingface. - ModelOutput causes trouble because it a a mix of a dataclass and a - OrderedDict and it calls super() methods implemented in C. + Keeping since we wish to support dataclasses in general in the future """ - # ModelOutput() excludes None, though generic datclasses don't - include_none = False - - @staticmethod - @functools.lru_cache(None) - def _patch_once(): - try: - from transformers.file_utils import ModelOutput - - for obj in ModelOutput.__dict__.values(): - if callable(obj): - skip_code(obj.__code__) - except ImportError: - pass + pass - try: - from diffusers.utils import BaseOutput - - for obj in BaseOutput.__dict__.values(): - if callable(obj): - skip_code(obj.__code__) - except ImportError: - pass +class CustomizedDictVariable(ConstDictVariable): @staticmethod - def is_matching_cls(cls): + def is_matching_cls_hf(cls): return _is_matching_transformers_cls(cls) or _is_matching_diffusers_cls(cls) - @classmethod - def is_matching_object(cls, obj): - return cls.is_matching_cls(type(obj)) - - @classmethod - def create(cls, user_cls, args, kwargs, options): - DataClassVariable._patch_once() - - skip_code(user_cls.__init__.__code__) - keys = [f.name for f in dataclasses.fields(user_cls)] - bound = inspect.signature(user_cls).bind(*args, **kwargs) - bound.apply_defaults() - assert set(bound.arguments.keys()) == set(keys) - items = {} - for key in keys: - val = bound.arguments[key] - key = ConstantVariable.create(key) - if isinstance(val, VariableTracker): - items[key] = val - else: - if cls.include_none: - assert variables.ConstantVariable.is_literal(val) - items[key] = variables.ConstantVariable.create(val) - else: - assert val is None, f"unexpected {val}" - - if len(items) == 1 and not isinstance(items[keys[0]], variables.TensorVariable): - unimplemented("DataClassVariable iterator constructor") - # TODO(jansel): implement unpacking logic in ModelOutput.__post_init__ - - return cls(items, user_cls, **options) - - @classmethod - def wrap(cls, builder, obj): - user_cls = type(obj) - keys = [f.name for f in dataclasses.fields(user_cls)] - - excluded = [] - items = {} - for key in keys: - # __init__ function of a dataclass might not have yet defined the key - if hasattr(obj, key): - val = getattr(obj, key) - var = builder.__class__( - tx=builder.tx, source=AttrSource(builder.source, key) - )(val) - if val is not None or cls.include_none: - key = ConstantVariable.create(key) - items[key] = var - else: - excluded.append(var) - return cls(items, user_cls) - - def __init__(self, items, user_cls, **options): - super().__init__(items, user_cls, **options) - assert self.is_matching_cls(user_cls) - - def as_proxy(self): - raise NotImplementedError - - def reconstruct(self, codegen): - codegen.extend_output([codegen._create_load_const(self.user_cls)]) - # All the keys are just wrapped strings - d = self.keys_as_python_constant() - codegen.foreach(d.values()) - keys = tuple(d.keys()) - codegen.extend_output(codegen.create_call_function_kw(len(keys), keys, True)) - - def call_method( - self, - tx, - name, - args: "List[VariableTracker]", - kwargs: "Dict[str, VariableTracker]", - ) -> "VariableTracker": - if name == "__getitem__": - assert not kwargs and len(args) == 1 - val = args[0] - if val.python_type() == str: - return self.getitem_const(val) - else: - return self.call_method(tx, "to_tuple", [], {}).call_method( - tx, "__getitem__", args, kwargs - ) - elif name == "to_tuple": - assert not (args or kwargs) - return variables.TupleVariable(list(self.items.values())) - elif name == "__setattr__": - name = "__setitem__" - return super().call_method(tx, name, args, kwargs) - - def var_getattr(self, tx, name: str) -> "VariableTracker": - name_vt = ConstantVariable.create(name) - if name_vt in self: - return self.call_method(tx, "__getitem__", [name_vt], {}) - elif not self.include_none: - defaults = {f.name: f.default for f in dataclasses.fields(self.user_cls)} - if name in defaults: - assert variables.ConstantVariable.is_literal(defaults[name]) - return variables.ConstantVariable.create(defaults[name]) - super().var_getattr(tx, name) - - call_hasattr = _call_hasattr_customobj - - -class CustomizedDictVariable(ConstDictVariable): @staticmethod def is_matching_cls(cls): # True if using default OrderedDict.__init__ and did not implement __post_init__ if ( issubclass(cls, collections.OrderedDict) + and cls is not collections.OrderedDict and cls.__init__ is collections.OrderedDict.__init__ and not hasattr(cls, "__post_init__") ): @@ -730,7 +615,7 @@ def is_matching_cls(cls): # hack for HF usecase: # assume dataclass annotation for ModelOutput subclass # assume self.create is AA to ModelOutput.__post_init__ - return _is_matching_transformers_cls(cls) or _is_matching_diffusers_cls(cls) + return CustomizedDictVariable.is_matching_cls_hf(cls) @classmethod def is_matching_object(cls, obj): @@ -764,9 +649,7 @@ def make_var(x): ) bound_args = {} - if _is_matching_transformers_cls(user_cls) or _is_matching_diffusers_cls( - user_cls - ): + if cls.is_matching_cls_hf(user_cls): # Skip none for k, v in bound.arguments.items(): if isinstance(v, ConstantVariable) and v.value is None or v is None: @@ -792,7 +675,27 @@ def make_var(x): # called from builder.py @classmethod def wrap(cls, builder, obj): - raise NotImplementedError + user_cls = type(obj) + + if not cls.is_matching_cls_hf(user_cls): + unimplemented("custom non-hf dict subclass wrap unimplemented") + + items = builder.__class__(tx=builder.tx, source=builder.source)( + collections.OrderedDict(obj) + ).items + + keys = [f.name for f in dataclasses.fields(user_cls)] + for key in keys: + # __init__ function of a dataclass might not have yet defined the key + if hasattr(obj, key): + val = getattr(obj, key) + var = builder.__class__( + tx=builder.tx, source=AttrSource(builder.source, key) + )(val) + if val is not None: + key = ConstantVariable.create(key) + items[key] = var + return cls(items, user_cls) def __init__(self, items, user_cls, **options): super().__init__(items, user_cls, **options) @@ -804,9 +707,7 @@ def as_proxy(self): # 'RETURN_VALUE triggered compile' # called from torch/_dynamo/codegen.py def reconstruct(self, codegen): - is_hf_model_output = _is_matching_transformers_cls( - self.user_cls - ) or _is_matching_diffusers_cls(self.user_cls) + is_hf_model_output = self.is_matching_cls_hf(self.user_cls) # If the user class is a ModelOutput, then wrap the instance creation in # torch._dynamo.disable(). Even though we mark the __post_init__ as skip @@ -848,21 +749,34 @@ def call_method( ): # for python dict method without overridden return super().call_method(tx, name, args, kwargs) - elif name in ("__getitem__", "to_tuple", "__setitem__", "__setattr__"): + elif name in ( + "__getitem__", + "to_tuple", + "__setitem__", + "__setattr__", + "__post_init__", + ): # for user overridden method return tx.inline_user_function_return( variables.UserFunctionVariable(fn, source=source), [self] + list(args), kwargs, ) + elif fn is getattr(collections.OrderedDict, name, None): + return super().call_method(tx, name, args, kwargs) - unimplemented("custom dict: call_method unimplemented name=%s", name) + unimplemented(f"custom dict: call_method unimplemented name={name}") def var_getattr(self, tx, name: str) -> "VariableTracker": name_vt = ConstantVariable.create(name) if name_vt in self: return self.call_method(tx, "__getitem__", [name_vt], {}) - super().var_getattr(tx, name) + if dataclasses.is_dataclass(self.user_cls): + defaults = {f.name: f.default for f in dataclasses.fields(self.user_cls)} + if name in defaults: + assert variables.ConstantVariable.is_literal(defaults[name]) + return variables.ConstantVariable.create(defaults[name]) + return super().var_getattr(tx, name) call_hasattr = _call_hasattr_customobj diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index cc0fb7096701..83270e8fe6c4 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -171,6 +171,12 @@ def call_method( return super(variables.CustomizedDictVariable, self.objvar).call_method( tx, "__setitem__", args, kwargs ) + elif inner_fn is collections.OrderedDict.__getitem__ and isinstance( + self.objvar, variables.CustomizedDictVariable + ): + return super(variables.CustomizedDictVariable, self.objvar).call_method( + tx, "__getitem__", args, kwargs + ) elif is_standard_setattr(inner_fn) and isinstance( self.objvar, UserDefinedObjectVariable ): diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 7c7673a103fd..2f0bf7530304 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -396,9 +396,6 @@ def call_function( return variables.CustomizedDictVariable.create( self.value, args, kwargs, options ) - elif variables.DataClassVariable.is_matching_cls(self.value): - options = {"mutable_local": MutableLocal()} - return variables.DataClassVariable.create(self.value, args, kwargs, options) elif ( variables.RestrictedListSubclassVariable.is_matching_cls(self.value) and self.source diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp index b7fde50a9f1a..d2eb41f51115 100644 --- a/torch/csrc/dynamo/guards.cpp +++ b/torch/csrc/dynamo/guards.cpp @@ -3247,13 +3247,13 @@ void install_tensor_aliasing_guard( void install_no_tensor_aliasing_guard( const py::list& guard_managers, - py::list tensor_names, + const py::list& tensor_names, py::object verbose_code_parts) { // Adds a guard that checks none of tensors alias. This is a an example of // relational guard. There is one guard object that is shared between multiple // guard managers. std::shared_ptr guard = std::make_shared( - std::move(tensor_names), std::move(verbose_code_parts)); + tensor_names, std::move(verbose_code_parts)); // Register the resetter on the toor guard mananger, so that it can reset // the newly added relational guard when the guard eval fails. @@ -4006,7 +4006,15 @@ PyObject* torch_c_dynamo_guards_init() { DictSubclassGuardManager, DictGuardManager, std::unique_ptr>( - py_m, "DictSubclassGuardManager"); // NOLINT + py_m, "DictSubclassGuardManager") // NOLINT + .def( + "add_no_hasattr_guard", + [](DictSubclassGuardManager& self, + py::object attr_name, + py::object verbose_code_parts) -> void { + self.add_permitted_leaf_guard(std::make_shared( + std::move(attr_name), std::move(verbose_code_parts))); + }); py_m.def("install_tensor_aliasing_guard", install_tensor_aliasing_guard); py_m.def( From 3ddec713b81c671cdeec5d59b7b8f554ca684fc0 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 12 Jun 2024 02:20:15 +0000 Subject: [PATCH 668/706] Revert "[cuDNN][Quantization] Don't print when plan finalization fails in cuDNN quantization backend (#128177)" This reverts commit cac7a22b92478d897488688010e562b7bd36b97f. Reverted https://github.com/pytorch/pytorch/pull/128177 on behalf of https://github.com/clee2000 due to broke test/test_quantization.py::TestQuantizedLinear::test_qlinear_cudnn on sm86 tests https://hud.pytorch.org/pytorch/pytorch/commit/cac7a22b92478d897488688010e562b7bd36b97f https://github.com/pytorch/pytorch/actions/runs/9470648757/job/26100448913. Probably a landrace, test ran on the PR and succeed ([comment](https://github.com/pytorch/pytorch/pull/128177#issuecomment-2161977110)) --- aten/src/ATen/native/quantized/cudnn/BinaryOps.cpp | 2 +- aten/src/ATen/native/quantized/cudnn/Conv.cpp | 2 +- aten/src/ATen/native/quantized/cudnn/Linear.cpp | 2 +- test/quantization/core/test_quantized_op.py | 1 + 4 files changed, 4 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/native/quantized/cudnn/BinaryOps.cpp b/aten/src/ATen/native/quantized/cudnn/BinaryOps.cpp index 9e9e675e7a3c..07ccc19c4828 100644 --- a/aten/src/ATen/native/quantized/cudnn/BinaryOps.cpp +++ b/aten/src/ATen/native/quantized/cudnn/BinaryOps.cpp @@ -242,7 +242,7 @@ Tensor add(Tensor qa, Tensor qb, double output_scale, int64_t output_zero_point) run(plan_desc); execution_plan_cache[key] = plan_desc; return quantized_output.view(orig_sizes); - } catch (cudnn_frontend::cudnnException &e) {} catch(c10::CuDNNError &e) {} + } catch (cudnn_frontend::cudnnException &e) {std::cout << "cudnn error:" << e.what() << std::endl;} catch(c10::CuDNNError &e) { std::cout << "other error" << e.what() << std::endl;} } TORCH_CHECK(false, "Unable to find an engine to execute this computation in Quantized Add Cudnn"); diff --git a/aten/src/ATen/native/quantized/cudnn/Conv.cpp b/aten/src/ATen/native/quantized/cudnn/Conv.cpp index 8823038da48b..606d769fe6eb 100644 --- a/aten/src/ATen/native/quantized/cudnn/Conv.cpp +++ b/aten/src/ATen/native/quantized/cudnn/Conv.cpp @@ -252,7 +252,7 @@ void PackedConvWeightCudnn::apply_impl_helper(const at::Tensor& qua run(plan); execution_plan_cache.emplace(key, plan); return; - } catch (cudnn_frontend::cudnnException &e) {} catch(c10::CuDNNError &e) {} + } catch (cudnn_frontend::cudnnException &e) {std::cout << "cudnn error:" << e.what() << std::endl;} catch(c10::CuDNNError &e) { std::cout << "other error" << e.what() << std::endl;} } TORCH_CHECK(false, "Unable to find an engine to execute this computation in Quantized Conv2D Cudnn"); diff --git a/aten/src/ATen/native/quantized/cudnn/Linear.cpp b/aten/src/ATen/native/quantized/cudnn/Linear.cpp index 54eb08443c48..d3219592e25b 100644 --- a/aten/src/ATen/native/quantized/cudnn/Linear.cpp +++ b/aten/src/ATen/native/quantized/cudnn/Linear.cpp @@ -286,7 +286,7 @@ void PackedLinearWeightCudnn::apply_impl_helper(const at::Tensor& quantized_outp run(plan); execution_plan_cache.emplace(key, plan); return; - } catch (cudnn_frontend::cudnnException &e) {} catch(c10::CuDNNError &e) {} + } catch (cudnn_frontend::cudnnException &e) {std::cout << "cudnn error:" << e.what() << std::endl;} catch(c10::CuDNNError &e) { std::cout << "other error" << e.what() << std::endl;} } TORCH_CHECK(false, "Unable to find an engine to execute this computation Quantized Linear Cudnn"); diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py index 6671b6634e00..5b86693e11c1 100644 --- a/test/quantization/core/test_quantized_op.py +++ b/test/quantization/core/test_quantized_op.py @@ -4052,6 +4052,7 @@ def test_qlinear_with_input_q_dq_qweight_dq_output_fp32( use_channelwise=st.sampled_from([False])) # channelwise currently not supported for qlinear cudnn @skipIfNoFBGEMM @unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.") + @unittest.skipIf(TEST_CUDNN and torch.backends.cudnn.version() == 90100, "expected failure on cuDNN 9.1.0") @unittest.skipIf(not SM80OrLater, "requires sm80 or later.") @unittest.skipIf(TEST_ROCM, "not supported on rocm.") # TODO: check with yang regarding CUDNN flags From 7c2058338a99a4867683747b0f04b971ddc8749c Mon Sep 17 00:00:00 2001 From: Tuan Trieu Date: Wed, 12 Jun 2024 02:50:37 +0000 Subject: [PATCH 669/706] Improve convert fp32 to fp16 fx pass (#127829) Summary: Improve the convert fp32 to fp16 fx pass to use to_dtype node and const folding instead of inplace conversion. Test Plan: ``` buck2 test @//mode/{opt,inplace} //glow/fb/fx/fba/tests:test_fba_pass_manager_builder ``` Differential Revision: D57803843 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127829 Approved by: https://github.com/Skylion007 --- torch/fx/experimental/const_fold.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/fx/experimental/const_fold.py b/torch/fx/experimental/const_fold.py index cb94ed3930ed..dca495b7f691 100644 --- a/torch/fx/experimental/const_fold.py +++ b/torch/fx/experimental/const_fold.py @@ -259,7 +259,7 @@ def mod_partition(node: torch.fx.Node): # worry about whether this is one or more tensors because the original graph # correctly uses getitem to extract individual tensors if there are multiple folded. fx_const_folded_attrs_name = get_unique_attr_name_in_module( - split, "_FX_CONST_FOLDED_ATTRS" + mod_traced, "_FX_CONST_FOLDED_ATTRS" ) setattr( split, From 86b5df3e71e6b786347ee5fa69daa054849bea2e Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 12 Jun 2024 03:06:30 +0000 Subject: [PATCH 670/706] Documenting the torch.fx.annotate.annotate function (#128337) Fixes #127903 This PR adds docstring to the `torch.fx.annotate.annotate` function. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128337 Approved by: https://github.com/malfet --- torch/fx/annotate.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/torch/fx/annotate.py b/torch/fx/annotate.py index ab5c6d0acd61..d1b5b5f2d376 100644 --- a/torch/fx/annotate.py +++ b/torch/fx/annotate.py @@ -4,8 +4,18 @@ @compatibility(is_backward_compatible=False) def annotate(val, type): - # val could be either a regular value (not tracing) - # or fx.Proxy (tracing) + """ + Annotates a Proxy object with a given type. + + This function annotates a val with a given type if a type of the val is a torch.fx.Proxy object + Args: + val (object): An object to be annotated if its type is torch.fx.Proxy. + type (object): A type to be assigned to a given proxy object as val. + Returns: + The given val. + Raises: + RuntimeError: If a val already has a type in its node. + """ if isinstance(val, Proxy): if val.node.type: raise RuntimeError(f"Tried to annotate a value that already had a type on it!" From 8cf302dce4c639df6af81124c1475dd9e8d67533 Mon Sep 17 00:00:00 2001 From: cyy Date: Wed, 12 Jun 2024 03:25:52 +0000 Subject: [PATCH 671/706] [5/N] Change static functions in headers to inline (#128406) Follows #128286 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128406 Approved by: https://github.com/ezyang --- aten/src/ATen/native/GridSamplerUtils.h | 12 ++++-------- aten/src/ATen/native/LossMulti.h | 7 ++----- aten/src/ATen/native/MaxPooling.h | 2 +- aten/src/ATen/native/UpSample.h | 16 ++++++++-------- aten/src/ATen/native/vol2col.h | 4 ++-- 5 files changed, 17 insertions(+), 24 deletions(-) diff --git a/aten/src/ATen/native/GridSamplerUtils.h b/aten/src/ATen/native/GridSamplerUtils.h index eea21ddf5e37..f783043c7961 100644 --- a/aten/src/ATen/native/GridSamplerUtils.h +++ b/aten/src/ATen/native/GridSamplerUtils.h @@ -18,10 +18,8 @@ enum class GridSamplerPadding {Zeros, Border, Reflection}; using detail::GridSamplerInterpolation; using detail::GridSamplerPadding; -namespace { - // See NOTE [ grid_sampler Native Functions ]. -void check_grid_sampler_common( +inline void check_grid_sampler_common( const TensorBase& input, const TensorBase& grid ) { @@ -60,7 +58,7 @@ void check_grid_sampler_common( } // See NOTE [ grid_sampler Native Functions ]. -void check_grid_sampler_2d( +inline void check_grid_sampler_2d( const TensorBase& input, const TensorBase& grid ) { @@ -72,7 +70,7 @@ void check_grid_sampler_2d( } // See NOTE [ grid_sampler Native Functions ]. -void check_grid_sampler_3d( +inline void check_grid_sampler_3d( const TensorBase& input, const TensorBase& grid, int64_t interpolation_mode @@ -91,7 +89,7 @@ void check_grid_sampler_3d( // See NOTE [ grid_sampler Native Functions ]. // cudnn does not support inputs larger than 1024. -bool cond_cudnn_grid_sampler( +inline bool cond_cudnn_grid_sampler( const TensorBase& input, const TensorBase& grid ) { @@ -104,6 +102,4 @@ bool cond_cudnn_grid_sampler( input.sym_size(1) <= 1024); } -} // anonymous namespace - } // namespace at::native diff --git a/aten/src/ATen/native/LossMulti.h b/aten/src/ATen/native/LossMulti.h index 27697815ad59..8877b05a54cc 100644 --- a/aten/src/ATen/native/LossMulti.h +++ b/aten/src/ATen/native/LossMulti.h @@ -5,8 +5,7 @@ #include namespace at::native { -namespace { - static C10_UNUSED void multilabel_margin_loss_shape_check( + inline void multilabel_margin_loss_shape_check( int64_t& nframe, int64_t& dim, const int64_t& ndims, @@ -35,7 +34,7 @@ namespace { } } - static C10_UNUSED void multi_margin_loss_shape_check( + inline void multi_margin_loss_shape_check( int64_t& nframe, int64_t& dim, const int64_t& ndims, @@ -67,6 +66,4 @@ namespace { } } - -} // anonymous namespace } // namespace at::native diff --git a/aten/src/ATen/native/MaxPooling.h b/aten/src/ATen/native/MaxPooling.h index 3c6760ca6886..7044b6ee3dc2 100644 --- a/aten/src/ATen/native/MaxPooling.h +++ b/aten/src/ATen/native/MaxPooling.h @@ -7,7 +7,7 @@ namespace at::native { -static void check_max_pool1d( +inline void check_max_pool1d( const Tensor& self, IntArrayRef kernel_size, IntArrayRef stride, diff --git a/aten/src/ATen/native/UpSample.h b/aten/src/ATen/native/UpSample.h index 275d2028f764..9542d9953ed6 100644 --- a/aten/src/ATen/native/UpSample.h +++ b/aten/src/ATen/native/UpSample.h @@ -103,7 +103,7 @@ DECLARE_DISPATCH(upsampling_bicubic2d, upsample_bicubic2d_kernel); DECLARE_DISPATCH(_upsampling_bicubic2d_aa, _upsample_bicubic2d_aa_kernel); DECLARE_DISPATCH(_upsampling_bicubic2d_aa, _upsample_bicubic2d_aa_backward_kernel); -static C10_UNUSED std::array upsample_1d_common_check(IntArrayRef input_size, IntArrayRef output_size) { +inline C10_UNUSED std::array upsample_1d_common_check(IntArrayRef input_size, IntArrayRef output_size) { TORCH_CHECK( output_size.size() == 1, "It is expected output_size equals to 1, but got size ", @@ -131,7 +131,7 @@ static C10_UNUSED std::array upsample_1d_common_check(IntArrayRef in return {nbatch, channels, output_width}; } -static C10_UNUSED std::array upsample_2d_common_check(IntArrayRef input_size, IntArrayRef output_size) { +inline C10_UNUSED std::array upsample_2d_common_check(IntArrayRef input_size, IntArrayRef output_size) { TORCH_CHECK( output_size.size() == 2, "It is expected output_size equals to 2, but got size ", @@ -167,7 +167,7 @@ static C10_UNUSED std::array upsample_2d_common_check(IntArrayRef in return {nbatch, channels, output_height, output_width}; } -static C10_UNUSED +inline C10_UNUSED std::array upsample_3d_common_check(IntArrayRef input_size, IntArrayRef output_size) { TORCH_CHECK( output_size.size() == 3, @@ -365,7 +365,7 @@ inline int64_t nearest_exact_idx( typedef int64_t (*nearest_idx_fn_t)(int64_t, int64_t, int64_t, std::optional); template -static scalar_t upsample_get_value_bounded( +scalar_t upsample_get_value_bounded( scalar_t* data, int64_t width, int64_t height, @@ -377,7 +377,7 @@ static scalar_t upsample_get_value_bounded( } template -static void upsample_increment_value_bounded( +void upsample_increment_value_bounded( scalar_t* data, int64_t width, int64_t height, @@ -392,17 +392,17 @@ static void upsample_increment_value_bounded( // Based on // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm template -inline scalar_t cubic_convolution1(scalar_t x, scalar_t A) { +scalar_t cubic_convolution1(scalar_t x, scalar_t A) { return ((A + 2) * x - (A + 3)) * x * x + 1; } template -inline scalar_t cubic_convolution2(scalar_t x, scalar_t A) { +scalar_t cubic_convolution2(scalar_t x, scalar_t A) { return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A; } template -inline void get_cubic_upsample_coefficients( +void get_cubic_upsample_coefficients( scalar_t coeffs[4], scalar_t t) { scalar_t A = -0.75; diff --git a/aten/src/ATen/native/vol2col.h b/aten/src/ATen/native/vol2col.h index ccbfc69ce3c6..fa5c46b8c52e 100644 --- a/aten/src/ATen/native/vol2col.h +++ b/aten/src/ATen/native/vol2col.h @@ -5,7 +5,7 @@ namespace at::native { template -static void vol2col( +void vol2col( const T* data_vol, const int64_t channels, const int64_t depth, @@ -56,7 +56,7 @@ static void vol2col( } template -static void col2vol( +void col2vol( const T* data_col, const int64_t channels, const int64_t depth, From 02e7519ac3cd4c4b043c9a0f672464d3797c0622 Mon Sep 17 00:00:00 2001 From: loganthomas Date: Wed, 12 Jun 2024 03:57:45 +0000 Subject: [PATCH 672/706] DOC: strip inaccurate either float32 or float64 statement from set_default_type (#128192) Fixes #126647 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128192 Approved by: https://github.com/malfet --- torch/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torch/__init__.py b/torch/__init__.py index 1c4d5e45b305..a5d22d2ebf72 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -976,7 +976,6 @@ def set_default_dtype(d): Args: d (:class:`torch.dtype`): the floating point dtype to make the default. - Either torch.float32 or torch.float64. Example: >>> # xdoctest: +SKIP("Other tests may have changed the default type. Can we reset it?") From c0b87afcade3fc93d73bcbe4b61d732565019340 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Tue, 11 Jun 2024 15:31:57 -0700 Subject: [PATCH 673/706] [RELAND2][dynamo][nn-modules] Trace through nn.Module dunder methods for UnspecializedNNModule (#126578) Tracing through `__init__` is important because it initializes (calls STORE_ATTR) on members. By doing that, we kick in the mutation tracking for these objects. So, things like mutating `_modules` etc is tracked automatically. Fixes https://github.com/pytorch/pytorch/issues/111837 Pull Request resolved: https://github.com/pytorch/pytorch/pull/126578 Approved by: https://github.com/jansel --- test/distributed/test_dynamo_distributed.py | 10 +-- test/dynamo/test_higher_order_ops.py | 16 ++--- ...ddingNN.test_embedding_sparse_empty_tensor | 0 ...ngNN.test_embeddingbag_include_last_offset | 0 ....test_profiler_pattern_matcher_json_report | 0 .../TestJitGeneratedModule.test_nn_Bilinear | 0 .../TestJitGeneratedModule.test_nn_Embedding | 0 ...dModule.test_nn_EmbeddingBag_discontiguous | 0 ...itGeneratedModule.test_nn_EmbeddingBag_max | 0 ...odule.test_nn_EmbeddingBag_max_padding_idx | 0 ...tGeneratedModule.test_nn_EmbeddingBag_mean | 0 ...dule.test_nn_EmbeddingBag_mean_padding_idx | 0 ...eneratedModule.test_nn_EmbeddingBag_sparse | 0 ...itGeneratedModule.test_nn_EmbeddingBag_sum | 0 ...odule.test_nn_EmbeddingBag_sum_padding_idx | 0 ...atedModule.test_nn_Embedding_discontiguous | 0 ...itGeneratedModule.test_nn_Embedding_sparse | 0 .../TestJitGeneratedModule.test_nn_Linear | 0 ...eneratedModule.test_nn_Linear_no_batch_dim | 0 ...GeneratedModule.test_nn_PReLU_no_batch_dim | 0 .../TestNN.test_ParameterDict | 0 .../TestNN.test_Sequential_iadd | 0 .../TestNN.test_bilinear_broadcasting | 0 ...st_layer_norm_grads_with_create_graph_flag | 0 ..._linear_autograd_device_cpu_bias_weightCOO | 0 ..._linear_autograd_device_cpu_bias_weightCSC | 0 ..._linear_autograd_device_cpu_bias_weightCSR | 0 .../TestNN.test_linear_broadcasting | 0 .../TestNN.test_module_apply_inplace_op | 0 ...est_overwrite_module_params_on_conversion} | 0 ...metrized_tensor_parametrization_swap_False | 0 ....test_new_spectral_norm_forward_swap_True} | 0 ...rization.test_new_spectral_norm_swap_True} | 0 ...weight_norm_parametrization_swap_False_cpu | 0 ..._weight_norm_parametrization_swap_True_cpu | 0 ...sorDeviceTypeCPU.test_embedding_jagged_cpu | 0 .../TestPruningNN.test_identity_pruning | 0 ...TestPruningNN.test_pruning_id_consistency} | 0 .../TestPruningNN.test_random_pruning_0perc | 0 test/profiler/test_profiler.py | 1 + torch/_dynamo/create_parameter_op.py | 20 ++++++ torch/_dynamo/mutation_guard.py | 3 + torch/_dynamo/side_effects.py | 32 ++++++---- torch/_dynamo/symbolic_convert.py | 19 ++++-- torch/_dynamo/utils.py | 4 +- torch/_dynamo/variables/dicts.py | 6 +- torch/_dynamo/variables/misc.py | 26 +++++--- torch/_dynamo/variables/nn_module.py | 40 ++++++++---- torch/_dynamo/variables/torch.py | 9 ++- torch/_dynamo/variables/user_defined.py | 63 ++++++++++++------- 50 files changed, 176 insertions(+), 73 deletions(-) delete mode 100644 test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_sparse_empty_tensor delete mode 100644 test/dynamo_expected_failures/TestEmbeddingNN.test_embeddingbag_include_last_offset delete mode 100644 test/dynamo_expected_failures/TestExperimentalUtils.test_profiler_pattern_matcher_json_report delete mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Bilinear delete mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding delete mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_discontiguous delete mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_max delete mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_max_padding_idx delete mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_mean delete mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_mean_padding_idx delete mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sparse delete mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sum delete mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sum_padding_idx delete mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding_discontiguous delete mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding_sparse delete mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Linear delete mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Linear_no_batch_dim delete mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_PReLU_no_batch_dim delete mode 100644 test/dynamo_expected_failures/TestNN.test_ParameterDict delete mode 100644 test/dynamo_expected_failures/TestNN.test_Sequential_iadd delete mode 100644 test/dynamo_expected_failures/TestNN.test_bilinear_broadcasting delete mode 100644 test/dynamo_expected_failures/TestNN.test_layer_norm_grads_with_create_graph_flag delete mode 100644 test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCOO delete mode 100644 test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCSC delete mode 100644 test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCSR delete mode 100644 test/dynamo_expected_failures/TestNN.test_linear_broadcasting delete mode 100644 test/dynamo_expected_failures/TestNN.test_module_apply_inplace_op rename test/dynamo_expected_failures/{FakeTensorTest.test_embedding_bag_meta => TestNN.test_overwrite_module_params_on_conversion} (100%) delete mode 100644 test/dynamo_expected_failures/TestNNParametrization.test_errors_unparametrized_tensor_parametrization_swap_False rename test/dynamo_expected_failures/{TestCompileTransformsCPU.test_compile_vmap_hessian_cpu => TestNNParametrization.test_new_spectral_norm_forward_swap_True} (100%) rename test/dynamo_expected_failures/{TestEmbeddingNN.test_embedding_max_norm => TestNNParametrization.test_new_spectral_norm_swap_True} (100%) delete mode 100644 test/dynamo_expected_failures/TestNNParametrizationDeviceCPU.test_weight_norm_parametrization_swap_False_cpu delete mode 100644 test/dynamo_expected_failures/TestNNParametrizationDeviceCPU.test_weight_norm_parametrization_swap_True_cpu delete mode 100644 test/dynamo_expected_failures/TestNestedTensorDeviceTypeCPU.test_embedding_jagged_cpu delete mode 100644 test/dynamo_expected_failures/TestPruningNN.test_identity_pruning rename test/dynamo_expected_failures/{TestEmbeddingNN.test_embedding_sparse_basic => TestPruningNN.test_pruning_id_consistency} (100%) delete mode 100644 test/dynamo_expected_failures/TestPruningNN.test_random_pruning_0perc diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py index b31a2f717537..db44f1ce915d 100644 --- a/test/distributed/test_dynamo_distributed.py +++ b/test/distributed/test_dynamo_distributed.py @@ -1084,12 +1084,14 @@ def _(ctx): # far from an exhaustive check of all the expected guards, just check a couple of them. FileCheck().check("""local "L['self']" TYPE_MATCH""").check( """local "L['self']" ID_MATCH""" - ).check(f"""{expected_guard_source} "L['self'].net" TYPE_MATCH""").check( - f"""{expected_guard_source} "L['self'].net" ID_MATCH""" ).check( - f"""{expected_guard_source} "L['self'].net[0]" TYPE_MATCH""" + f"""{expected_guard_source} "L['self']._modules['net']" TYPE_MATCH""" ).check( - f"""{expected_guard_source} "L['self'].net[0]" ID_MATCH""" + f"""{expected_guard_source} "L['self']._modules['net']" ID_MATCH""" + ).check( + f"""{expected_guard_source} "L['self']._modules['net']._modules['0']" TYPE_MATCH""" + ).check( + f"""{expected_guard_source} "L['self']._modules['net']._modules['1']" ID_MATCH""" ).run( GUARDS_FILE.getvalue() ) diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index 7a746a9b1d08..410317d33a14 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -5179,10 +5179,10 @@ def wrapper_fn(x): actual, """\ class GraphModule(torch.nn.Module): - def forward(self, L_self_tensor_constant0: "f32[3, 3, 3]"): - l_self_tensor_constant0 = L_self_tensor_constant0 + def forward(self, L_self_buffers_tensor_constant0_: "f32[3, 3, 3]"): + l_self_buffers_tensor_constant0_ = L_self_buffers_tensor_constant0_ - alias_default: "f32[3, 3, 3]" = torch.ops.aten.alias.default(l_self_tensor_constant0); l_self_tensor_constant0 = None + alias_default: "f32[3, 3, 3]" = torch.ops.aten.alias.default(l_self_buffers_tensor_constant0_); l_self_buffers_tensor_constant0_ = None sin_default: "f32[3, 3, 3]" = torch.ops.aten.sin.default(alias_default) @@ -5201,16 +5201,16 @@ def forward(self, L_self_tensor_constant0: "f32[3, 3, 3]"): actual, """\ class GraphModule(torch.nn.Module): - def forward(self, getattr_L_self_FX_CONST_FOLDED_ATTRS_0_: "f32[3, 3, 3]", getattr_L_self_FX_CONST_FOLDED_ATTRS_1_: "f32[3, 3, 3]", L_flat_tangents_1_: "f32[3, 3, 3]"): - getattr_l_self_fx_const_folded_attrs_0_ = getattr_L_self_FX_CONST_FOLDED_ATTRS_0_ - getattr_l_self_fx_const_folded_attrs_1_ = getattr_L_self_FX_CONST_FOLDED_ATTRS_1_ + def forward(self, L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_0_: "f32[3, 3, 3]", L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_1_: "f32[3, 3, 3]", L_flat_tangents_1_: "f32[3, 3, 3]"): + l_self_modules_fx_const_folded_attrs_parameters_0_ = L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_0_ + l_self_modules_fx_const_folded_attrs_parameters_1_ = L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_1_ l_flat_tangents_1_ = L_flat_tangents_1_ - _new_zeros_with_same_feature_meta_default: "f32[3, 3, 3]" = torch.ops.aten._new_zeros_with_same_feature_meta.default(l_flat_tangents_1_, getattr_l_self_fx_const_folded_attrs_0_); getattr_l_self_fx_const_folded_attrs_0_ = None + _new_zeros_with_same_feature_meta_default: "f32[3, 3, 3]" = torch.ops.aten._new_zeros_with_same_feature_meta.default(l_flat_tangents_1_, l_self_modules_fx_const_folded_attrs_parameters_0_); l_self_modules_fx_const_folded_attrs_parameters_0_ = None copy__default: "f32[3, 3, 3]" = torch.ops.aten.copy_.default(_new_zeros_with_same_feature_meta_default, l_flat_tangents_1_); _new_zeros_with_same_feature_meta_default = l_flat_tangents_1_ = None - mul_tensor: "f32[3, 3, 3]" = torch.ops.aten.mul.Tensor(copy__default, getattr_l_self_fx_const_folded_attrs_1_); copy__default = getattr_l_self_fx_const_folded_attrs_1_ = None + mul_tensor: "f32[3, 3, 3]" = torch.ops.aten.mul.Tensor(copy__default, l_self_modules_fx_const_folded_attrs_parameters_1_); copy__default = l_self_modules_fx_const_folded_attrs_parameters_1_ = None return (mul_tensor,) """, ) diff --git a/test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_sparse_empty_tensor b/test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_sparse_empty_tensor deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestEmbeddingNN.test_embeddingbag_include_last_offset b/test/dynamo_expected_failures/TestEmbeddingNN.test_embeddingbag_include_last_offset deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestExperimentalUtils.test_profiler_pattern_matcher_json_report b/test/dynamo_expected_failures/TestExperimentalUtils.test_profiler_pattern_matcher_json_report deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Bilinear b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Bilinear deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_discontiguous b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_discontiguous deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_max b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_max deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_max_padding_idx b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_max_padding_idx deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_mean b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_mean deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_mean_padding_idx b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_mean_padding_idx deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sparse b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sparse deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sum b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sum deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sum_padding_idx b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sum_padding_idx deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding_discontiguous b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding_discontiguous deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding_sparse b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding_sparse deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Linear b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Linear deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Linear_no_batch_dim b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Linear_no_batch_dim deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_PReLU_no_batch_dim b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_PReLU_no_batch_dim deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNN.test_ParameterDict b/test/dynamo_expected_failures/TestNN.test_ParameterDict deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNN.test_Sequential_iadd b/test/dynamo_expected_failures/TestNN.test_Sequential_iadd deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNN.test_bilinear_broadcasting b/test/dynamo_expected_failures/TestNN.test_bilinear_broadcasting deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNN.test_layer_norm_grads_with_create_graph_flag b/test/dynamo_expected_failures/TestNN.test_layer_norm_grads_with_create_graph_flag deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCOO b/test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCOO deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCSC b/test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCSC deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCSR b/test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCSR deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNN.test_linear_broadcasting b/test/dynamo_expected_failures/TestNN.test_linear_broadcasting deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNN.test_module_apply_inplace_op b/test/dynamo_expected_failures/TestNN.test_module_apply_inplace_op deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/FakeTensorTest.test_embedding_bag_meta b/test/dynamo_expected_failures/TestNN.test_overwrite_module_params_on_conversion similarity index 100% rename from test/dynamo_expected_failures/FakeTensorTest.test_embedding_bag_meta rename to test/dynamo_expected_failures/TestNN.test_overwrite_module_params_on_conversion diff --git a/test/dynamo_expected_failures/TestNNParametrization.test_errors_unparametrized_tensor_parametrization_swap_False b/test/dynamo_expected_failures/TestNNParametrization.test_errors_unparametrized_tensor_parametrization_swap_False deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestCompileTransformsCPU.test_compile_vmap_hessian_cpu b/test/dynamo_expected_failures/TestNNParametrization.test_new_spectral_norm_forward_swap_True similarity index 100% rename from test/dynamo_expected_failures/TestCompileTransformsCPU.test_compile_vmap_hessian_cpu rename to test/dynamo_expected_failures/TestNNParametrization.test_new_spectral_norm_forward_swap_True diff --git a/test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_max_norm b/test/dynamo_expected_failures/TestNNParametrization.test_new_spectral_norm_swap_True similarity index 100% rename from test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_max_norm rename to test/dynamo_expected_failures/TestNNParametrization.test_new_spectral_norm_swap_True diff --git a/test/dynamo_expected_failures/TestNNParametrizationDeviceCPU.test_weight_norm_parametrization_swap_False_cpu b/test/dynamo_expected_failures/TestNNParametrizationDeviceCPU.test_weight_norm_parametrization_swap_False_cpu deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNNParametrizationDeviceCPU.test_weight_norm_parametrization_swap_True_cpu b/test/dynamo_expected_failures/TestNNParametrizationDeviceCPU.test_weight_norm_parametrization_swap_True_cpu deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNestedTensorDeviceTypeCPU.test_embedding_jagged_cpu b/test/dynamo_expected_failures/TestNestedTensorDeviceTypeCPU.test_embedding_jagged_cpu deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestPruningNN.test_identity_pruning b/test/dynamo_expected_failures/TestPruningNN.test_identity_pruning deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_sparse_basic b/test/dynamo_expected_failures/TestPruningNN.test_pruning_id_consistency similarity index 100% rename from test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_sparse_basic rename to test/dynamo_expected_failures/TestPruningNN.test_pruning_id_consistency diff --git a/test/dynamo_expected_failures/TestPruningNN.test_random_pruning_0perc b/test/dynamo_expected_failures/TestPruningNN.test_random_pruning_0perc deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/profiler/test_profiler.py b/test/profiler/test_profiler.py index ca0481922e4f..81d158635c0e 100644 --- a/test/profiler/test_profiler.py +++ b/test/profiler/test_profiler.py @@ -2411,6 +2411,7 @@ def test_profiler_matmul_dim_fp16_pattern(self): num_matched.append(len(pattern.matched_events())) self.assertEqual(num_matched, [i for i, _ in cases]) + @skipIfTorchDynamo("profiler gets ignored if dynamo activated") def test_profiler_pattern_matcher_json_report(self): x = torch.ones((100, 100)) model = nn.Sequential( diff --git a/torch/_dynamo/create_parameter_op.py b/torch/_dynamo/create_parameter_op.py index f6cd12de2021..d30e4a37f003 100644 --- a/torch/_dynamo/create_parameter_op.py +++ b/torch/_dynamo/create_parameter_op.py @@ -1,4 +1,7 @@ # mypy: allow-untyped-defs +import threading +from contextlib import contextmanager + import torch doc = """ @@ -37,3 +40,20 @@ def new_parameter_placeholder(size, dtype, device, requires_grad): # Allocating a zero tensor would causes assert failures in autograd. result.untyped_storage().resize_(0) return result + + +_TLS = threading.local() + + +@contextmanager +def do_not_convert_to_tracable_parameter(): + old_flag = getattr(_TLS, "convert_tracable_parameter", True) + _TLS.convert_tracable_parameter = False + try: + yield False + finally: + _TLS.convert_tracable_parameter = old_flag + + +def can_convert_to_tracable_parameter(): + return getattr(_TLS, "convert_tracable_parameter", True) diff --git a/torch/_dynamo/mutation_guard.py b/torch/_dynamo/mutation_guard.py index 22e2b9999e03..9077ecd3d57f 100644 --- a/torch/_dynamo/mutation_guard.py +++ b/torch/_dynamo/mutation_guard.py @@ -11,6 +11,9 @@ from .utils import ExactWeakKeyDictionary, is_lazy_module, nn_module_has_global_hooks +unpatched_nn_module_init = torch.nn.Module.__init__ + + class MutationTracker: db = ExactWeakKeyDictionary() diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index 2dedf71a66d9..c3d23728093a 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -373,13 +373,7 @@ def codegen_save_tempvars(self, cg: PyCodegen): elif isinstance(var.mutable_local, AttributeMutationNew): if isinstance(var, variables.AutogradFunctionContextVariable): unimplemented("AutogradFunctionContextVariable escaped") - if "__call_nn_module_init" in self.store_attr_mutations.get( - var.mutable_local, {} - ): - assert isinstance(var, variables.UnspecializedNNModuleVariable) - cg.load_import_from(utils.__name__, "nn_module_new") - else: - cg.load_import_from(utils.__name__, "object_new") + cg.load_import_from(utils.__name__, "object_new") cg(var.mutable_local.cls_source) cg.extend_output(create_call_function(1, True)) cg.add_cache(var) @@ -539,9 +533,25 @@ def codegen_update_mutated(self, cg: PyCodegen): ] ) elif self.is_attribute_mutation(var): - for name, value in self.store_attr_mutations.get( - var.mutable_local, {} - ).items(): + # Applying mutations involves two steps: 1) Push all + # reconstructed objects onto the stack. 2) Call STORE_ATTR to + # apply the mutations. + # + # Dynamo must ensure that mutations are applied in the same + # order as in the original program. Therefore, two reverse + # operations occur below. + # + # The first reverse operation concerns `suffixes`. We apply + # suffixes in reverse order due to the way Python handles the + # stack. In Step 1, we push all reconstructed objects onto the + # stack, but the item at the top of the stack refers to the last + # attribute in the mutation order. If not fixed, this will apply + # the mutations of attributes in the reverse order. To account + # for this reversal, we iterate through the mutable attributes + # in reverse order. + for name, value in reversed( + self.store_attr_mutations.get(var.mutable_local, {}).items() + ): if isinstance(var, variables.NewGlobalVariable): cg.tx.output.update_co_names(name) cg(value) @@ -549,8 +559,6 @@ def codegen_update_mutated(self, cg: PyCodegen): suffixes.append( [create_instruction("STORE_GLOBAL", argval=name)] ) - elif name == "__call_nn_module_init": - pass # handled in codegen_save_tempvars elif isinstance(value, variables.DeletedVariable): if isinstance( var.mutable_local, AttributeMutationExisting diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 41ceaa615916..7e129a05a090 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -102,7 +102,7 @@ PythonModuleVariable, UnknownVariable, ) -from .variables.nn_module import NNModuleVariable +from .variables.nn_module import NNModuleVariable, UnspecializedNNModuleVariable from .variables.tensor import supported_comparison_ops, SymNodeVariable, TensorVariable from .variables.user_defined import ( RemovableHandleVariable, @@ -415,11 +415,22 @@ def inner(self: "InstructionTranslatorBase", inst: Instruction): if push: self.push(value) self.jump(inst) + elif isinstance(value, UnspecializedNNModuleVariable): + mod = value.value + if truth_fn(mod): + if push: + self.push(value) + self.jump(inst) elif isinstance(value, UserDefinedObjectVariable): - x = value.var_getattr(self, "__bool__") - # if __bool__ is missing, trying __len__ to infer a truth value. - if isinstance(x, GetAttrVariable): + try: + x = value.var_getattr(self, "__bool__") + except exc.ObservedException: + # if __bool__ is missing, trying __len__ to infer a truth value. x = value.var_getattr(self, "__len__") + else: + if isinstance(x, GetAttrVariable): + # if __bool__ is missing, trying __len__ to infer a truth value. + x = value.var_getattr(self, "__len__") # __bool__ or __len__ is function if isinstance(x, UserMethodVariable): diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 6da8b514f16b..fe2f096ec488 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -2019,12 +2019,12 @@ def object_has_getattribute(value: Any): return False -def get_custom_getattr(value: Any): +def get_custom_getattr(value: Any, ignore_nn_module_getattr: bool = False): try: getattr_fn = inspect.getattr_static(type(value), "__getattr__") except AttributeError: getattr_fn = None - if getattr_fn is torch.nn.Module.__getattr__: + if ignore_nn_module_getattr and getattr_fn is torch.nn.Module.__getattr__: # ignore this case of getattr getattr_fn = None return getattr_fn diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index ea599af95cdc..50ea3f96379c 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -174,7 +174,11 @@ def python_type(self): def __contains__(self, vt): assert isinstance(vt, VariableTracker) Hashable = ConstDictVariable._HashableTracker - return is_hashable(vt) and Hashable(vt) in self.items + return ( + is_hashable(vt) + and Hashable(vt) in self.items + and not isinstance(self.items[Hashable(vt)], variables.DeletedVariable) + ) def reconstruct(self, codegen): # instructions to load collections.OrderedDict if necessary diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 83270e8fe6c4..179bb9a52bf9 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -14,8 +14,10 @@ import torch.utils._pytree as pytree from .. import config, variables from ..bytecode_transformation import create_call_function, create_instruction +from ..create_parameter_op import do_not_convert_to_tracable_parameter from ..exc import unimplemented from ..guards import GuardBuilder, install_guard +from ..mutation_guard import unpatched_nn_module_init from ..source import AttrSource, GetItemSource, ODictGetItemSource, TypeSource from ..utils import ( check_unspec_or_constant_args, @@ -121,7 +123,6 @@ def call_method( kwargs: "Dict[str, VariableTracker]", ) -> "VariableTracker": inner_fn, source = self._resolved_getattr_and_source(self, name) - if inner_fn is object.__init__: return LambdaVariable(identity) elif inner_fn is torch.nn.Module.__init__: @@ -133,12 +134,10 @@ def call_method( and isinstance(objvar.mutable_local, AttributeMutationNew) and not (args or kwargs) ): - tx.output.side_effects.store_attr( - objvar, - "__call_nn_module_init", - variables.ConstantVariable.create(True), - ) - return variables.ConstantVariable.create(None) + with do_not_convert_to_tracable_parameter(): + return variables.UserFunctionVariable( + unpatched_nn_module_init, source=source + ).call_function(tx, [self.objvar] + args, kwargs) else: unimplemented("super() nn.Module.__init__") elif isinstance(inner_fn, types.FunctionType): @@ -181,6 +180,19 @@ def call_method( self.objvar, UserDefinedObjectVariable ): return self.objvar.method_setattr_standard(tx, *args, **kwargs) + elif inner_fn is object.__delattr__: + attr = args[0] + try: + attr = attr.as_python_constant() + except NotImplementedError: + unimplemented(f"non-const delattr attr: {attr}") + if not tx.output.side_effects.is_attribute_mutation(self.objvar): + unimplemented(f"delattr({self.objvar}, {attr}, ...)") + + tx.output.side_effects.store_attr( + self.objvar, attr, variables.DeletedVariable() + ) + return variables.ConstantVariable(None) unimplemented(f"non-function or method super: {inner_fn}") diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py index 37c0bc17697a..d3f7052a9445 100644 --- a/torch/_dynamo/variables/nn_module.py +++ b/torch/_dynamo/variables/nn_module.py @@ -215,7 +215,7 @@ def _custom_getattr_fallback(self, base, tx, name, options): if object_has_getattribute(base): unimplemented("torch.nn.Module with a custom __getattribute__ defined") - getattr_fn = get_custom_getattr(base) + getattr_fn = get_custom_getattr(base, ignore_nn_module_getattr=True) if getattr_fn is None: return None @@ -665,7 +665,6 @@ def gen_source(source, name): if isinstance(args[0], SliceVariable): # Build a TupleVariable of NNModules result = [] - submods = [] # Turn the slice into the list of integers keys = list(range(len(module)))[args[0].as_python_constant()] @@ -679,9 +678,8 @@ def gen_source(source, name): source=src, ) ) - submods.append(submod) - new_module = torch.nn.Sequential(*submods) + new_module = module[args[0].as_python_constant()] new_module_variable = tx.output.register_attr_or_module( new_module, f"{self}.__getitem__(slice)", @@ -695,8 +693,10 @@ def gen_source(source, name): if isinstance(args[0], SymNodeVariable): key = args[0].evaluate_expr(tx.output) - else: + elif args[0].is_python_constant(): key = args[0].as_python_constant() + else: + unimplemented(f"getitem on NNModuleVariable with key {args[0]}") submod = module[key] return tx.output.register_attr_or_module( @@ -790,7 +790,7 @@ def set_nn_module_stack_source(self, source): @functools.lru_cache(None) def _nn_module_method_ids(): # Allow __setattr__ to fall through to base class handler - supported = {torch.nn.Module.__setattr__} + supported = {torch.nn.Module.__setattr__, torch.nn.Module.__init__} return { id(x.__code__) for x in torch.nn.Module.__dict__.values() @@ -798,8 +798,6 @@ def _nn_module_method_ids(): } def unpack_var_sequence(self, tx): - from .builder import VariableBuilder - try: fn = inspect.getattr_static(self.value_type, "__iter__") except AttributeError as e: @@ -810,11 +808,16 @@ def unpack_var_sequence(self, tx): torch.nn.ParameterList.__iter__, torch.nn.Sequential.__iter__, ): - assert self.source - return [ - VariableBuilder(tx, source=GetItemSource(self.source, idx))(item) - for idx, item in enumerate(self.value) - ] + # The program can mutate the nn module object but the saved `value` + # will not reflect the mutations. So, trace through the `__iter__` + # function to reflect any tracked mutations. + return tx.inline_user_function_return( + variables.UserFunctionVariable(fn), + [ + self, + ], + {}, + ).unpack_var_sequence(tx) return super().unpack_var_sequence(tx) @@ -943,6 +946,17 @@ def call_method( # Handle submodules self.is_state_mutated = True + if method is torch.nn.Module.__setattr__ and isinstance( + args[1], variables.DeletedVariable + ): + # Trace through __delattr__ to track mutations on the module + # members like `_modules``. + return tx.inline_user_function_return( + variables.UserFunctionVariable(torch.nn.Module.__delattr__), + [self, args[0]], + kwargs, + ) + return super().call_method(tx, name, args, kwargs) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 4d7b96b6a320..934e9a316a4b 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -18,7 +18,11 @@ from ..._guards import TracingContext from .. import config, polyfill, variables from ..codegen import PyCodegen -from ..create_parameter_op import new_parameter_placeholder, tracable_create_parameter +from ..create_parameter_op import ( + can_convert_to_tracable_parameter, + new_parameter_placeholder, + tracable_create_parameter, +) from ..device_interface import get_registered_device_interfaces from ..exc import unimplemented from ..guards import GuardBuilder, install_guard @@ -871,6 +875,9 @@ def call_nn_parameter(cls, tx, data=None, requires_grad=True): if data.source: return cls._nn_param_via_prefix_insert(tx, data, requires_grad) + if not can_convert_to_tracable_parameter(): + unimplemented("Workaround for issues with nn_parameter construction") + try: shape = tuple(data.var_getattr(tx, "shape").as_python_constant()) dtype = data.var_getattr(tx, "dtype").as_python_constant() diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 2f0bf7530304..6c6d3182b660 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -34,7 +34,8 @@ from torch._guards import TracingContext from .. import variables -from ..exc import unimplemented +from ..create_parameter_op import do_not_convert_to_tracable_parameter +from ..exc import ObservedException, unimplemented from ..guards import GuardBuilder, install_guard from ..source import AttrSource, GetItemSource, ODictGetItemSource, RandomValueSource from ..utils import ( @@ -57,10 +58,7 @@ def is_standard_setattr(val): - return val in ( - object.__setattr__, - torch.nn.Module.__setattr__, - ) + return val in (object.__setattr__,) class UserDefinedVariable(VariableTracker): @@ -378,17 +376,7 @@ def call_function( else UserDefinedObjectVariable, {}, ) - if ( - inspect.getattr_static(self.value, "__init__", None) - is torch.nn.Module.__init__ - ): - tx.output.side_effects.store_attr( - var, - "__call_nn_module_init", - variables.ConstantVariable.create(True), - ) - return var - else: + with do_not_convert_to_tracable_parameter(): var.call_method(tx, "__init__", args, kwargs) return var elif variables.CustomizedDictVariable.is_matching_cls(self.value): @@ -635,6 +623,10 @@ def call_method( else AttrSource(AttrSource(self.source, "__class__"), name) ) # TODO(jansel): add a guard to check for monkey patching? + from ..mutation_guard import unpatched_nn_module_init + + if method is torch.nn.Module.__init__: + method = unpatched_nn_module_init return UserMethodVariable(method, self, source=source).call_function( tx, args, kwargs ) @@ -796,7 +788,7 @@ def _check_for_getattr(self): def _getattr_static(self, name): if ( - isinstance(self.value, (torch.nn.Module, PyTreeSpec)) + isinstance(self.value, PyTreeSpec) or "__slots__" in self.value.__class__.__dict__ or type(self.value) == threading.local ): @@ -809,7 +801,6 @@ def _getattr_static(self, name): return cls_var except AttributeError: pass # __slots__ - # this might call torch.nn.Module.__getattr__ subobj = getattr(self.value, name) else: subobj = inspect.getattr_static(self.value, name) @@ -1015,14 +1006,35 @@ def call_hasattr(self, tx, name: str) -> "VariableTracker": install_guard( AttrSource(self.source, name).make_guard(GuardBuilder.HASATTR) ) - if self._check_for_getattribute() or self._check_for_getattr(): - unimplemented("hasattr with custom __getattr__") + if self._check_for_getattribute(): + unimplemented("hasattr with custom __getattribute__") try: self._getattr_static(name) return variables.ConstantVariable.create(True) except AttributeError: - return variables.ConstantVariable.create(False) + # Now check in __getattr__ function + getattr_fn = self._check_for_getattr() + if isinstance(getattr_fn, types.FunctionType): + # Dynamo is going to trace the __getattr__ function with + # args=name. Set the source accordingly. + new_source = None + if self.source: + new_source = AttrSource(self.source, "__getattr__") + try: + result = variables.UserMethodVariable( + getattr_fn, self, source=new_source + ).call_function(tx, [variables.ConstantVariable.create(name)], {}) + + return variables.ConstantVariable.create( + not isinstance(result, variables.DeletedVariable) + ) + except ObservedException: + return variables.ConstantVariable.create(False) + elif getattr_fn is None: + return variables.ConstantVariable.create(False) + else: + unimplemented("UserDefined with non-function __getattr__") def odict_getitem(self, tx, key): from .builder import VariableBuilder @@ -1089,6 +1101,12 @@ def var_getattr(self, tx, name): return super().var_getattr(tx, name) +class RemovableHandleClass: + # Dummy class to pass to python_type of RemovableHandleVariable + # Useful for isinstance check on hooks + pass + + class RemovableHandleVariable(VariableTracker): REMOVED = -1 @@ -1119,3 +1137,6 @@ def reconstruct(self, codegen): return # unreachable due to codegen.add_cache() when the hook is installed super().reconstruct(codegen) + + def python_type(self): + return RemovableHandleClass From 77a0ca66e4eb6919ed14a9491fa7579d06a29f3c Mon Sep 17 00:00:00 2001 From: Natalia Gimelshein Date: Wed, 12 Jun 2024 04:13:33 +0000 Subject: [PATCH 674/706] Add threadfence to 2-stage reduction for correct writes visibility (#128455) Final block accumulating 2-stage reduction result has to complete acquire pattern to make sure the writes of all other blocks are visible to it, see https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=atom#release-and-acquire-patterns Pull Request resolved: https://github.com/pytorch/pytorch/pull/128455 Approved by: https://github.com/eqy, https://github.com/ezyang --- aten/src/ATen/native/cuda/Reduce.cuh | 1 + aten/src/ATen/native/cuda/reduction_template.cuh | 1 + 2 files changed, 2 insertions(+) diff --git a/aten/src/ATen/native/cuda/Reduce.cuh b/aten/src/ATen/native/cuda/Reduce.cuh index 1f67ee3ea63e..85bde8b5990f 100644 --- a/aten/src/ATen/native/cuda/Reduce.cuh +++ b/aten/src/ATen/native/cuda/Reduce.cuh @@ -807,6 +807,7 @@ struct ReduceOp { bool is_last_block_done = mark_block_finished(); if (is_last_block_done) { + __threadfence(); // complete the acquire pattern after atomic value = ident; if (config.should_block_x_reduce()) { index_t input_offset = threadIdx.x + threadIdx.y * blockDim.x; diff --git a/aten/src/ATen/native/cuda/reduction_template.cuh b/aten/src/ATen/native/cuda/reduction_template.cuh index a38edb538256..6d1e861493d4 100644 --- a/aten/src/ATen/native/cuda/reduction_template.cuh +++ b/aten/src/ATen/native/cuda/reduction_template.cuh @@ -595,6 +595,7 @@ struct ReduceJitOp { bool is_last_block_done = mark_block_finished(); if (is_last_block_done) { + __threadfence(); //complete acquire pattern value = ident; if (config.should_block_x_reduce()) { uint32_t input_offset = threadIdx.x + threadIdx.y * blockDim.x; From 089f9a116ac8b2c14d6351b52614b529caba126b Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Tue, 11 Jun 2024 22:14:04 +0000 Subject: [PATCH 675/706] [tp] refactor and fix PrepareModuleInput for DTensor inputs (#128431) as titled, this PR refactors the PrepareModuleInput style to have common method prepare_input_arg, allow both args/kwargs to reuse this logic This also fixes https://github.com/pytorch/pytorch/issues/128365 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128431 Approved by: https://github.com/awgu --- .../tensor/parallel/test_tp_style.py | 12 +++++ torch/distributed/tensor/parallel/style.py | 52 ++++++++----------- 2 files changed, 35 insertions(+), 29 deletions(-) diff --git a/test/distributed/tensor/parallel/test_tp_style.py b/test/distributed/tensor/parallel/test_tp_style.py index e2a9a01da85b..776bdc9b50b4 100644 --- a/test/distributed/tensor/parallel/test_tp_style.py +++ b/test/distributed/tensor/parallel/test_tp_style.py @@ -317,6 +317,18 @@ def forward(self, *, x, y=2, z=None): self.assertEqual(comm_mode.get_total_counts(), 2) self.assertEqual(output.shape, (1 * self.world_size, 8)) + # test the case where x is a DTensor + x_dt = DTensor.from_local( + torch.randn(1, 8, device=self.device_type), mesh, [Shard(0)] + ) + with comm_mode: + output = test_kwonly_mod( + x=x_dt, z=torch.ones(1, 8, device=self.device_type) + ) + + self.assertEqual(comm_mode.get_total_counts(), 2) + self.assertEqual(output.shape, (1 * self.world_size, 8)) + @with_comms def test_prepare_module_output(self): mesh = init_device_mesh(self.device_type, (self.world_size,)) diff --git a/torch/distributed/tensor/parallel/style.py b/torch/distributed/tensor/parallel/style.py index f532b97e97d0..00d85bf5d499 100644 --- a/torch/distributed/tensor/parallel/style.py +++ b/torch/distributed/tensor/parallel/style.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates from abc import ABC, abstractmethod -from typing import Optional, Union, Tuple, Dict +from typing import Optional, Union, Tuple, Dict, Any from functools import partial import torch @@ -400,6 +400,23 @@ def __init__( assert len(self.input_kwarg_layouts) == len(self.desired_input_kwarg_layouts), \ "input_kwarg_layouts and desired_input_kwarg_layouts should have same length!" + def _prepare_input_arg(self, input: Any, mesh: DeviceMesh, input_layout: Placement, desired_layout: Placement): + if input_layout is not None: + if isinstance(input, DTensor): + # TODO: re-enable the check once we fix the compile path + # assert inp.placements[0] == input_layout + dt_inp = input + else: + assert isinstance(input, torch.Tensor), "expecting input to be a torch.Tensor!" + dt_inp = DTensor.from_local(input, mesh, (input_layout,), run_check=False) + + if desired_layout is not None and input_layout != desired_layout: + dt_inp = dt_inp.redistribute(placements=(desired_layout,)) + + return dt_inp.to_local() if self.use_local_output else dt_inp + else: + return input + def _prepare_input_fn(self, inputs, device_mesh): if self.input_layouts is None: return inputs @@ -409,21 +426,8 @@ def _prepare_input_fn(self, inputs, device_mesh): if len(inputs) != len(self.input_layouts): raise ValueError("module inputs and input_layouts should have same length!") - assert self.desired_input_layouts is not None, "desired module inputs should not be None!" for inp, input_layout, desired_layout in zip(inputs, self.input_layouts, self.desired_input_layouts): - if input_layout is not None: - if isinstance(inp, DTensor): - # TODO: re-enable the check once we fix the compile path - # assert inp.placements[0] == input_layout - dt_inp = inp - else: - dt_inp = DTensor.from_local(inp, device_mesh, (input_layout,), run_check=False) - - if desired_layout is not None and input_layout != desired_layout: - dt_inp = dt_inp.redistribute(placements=(desired_layout,)) - prepared_inputs.append(dt_inp.to_local() if self.use_local_output else dt_inp) - else: - prepared_inputs.append(inp) + prepared_inputs.append(self._prepare_input_arg(inp, device_mesh, input_layout, desired_layout)) return tuple(prepared_inputs) def _prepare_input_kwarg_fn(self, inputs, kwarg_inputs, device_mesh): @@ -431,20 +435,10 @@ def _prepare_input_kwarg_fn(self, inputs, kwarg_inputs, device_mesh): prepared_kwarg_inputs = {} for kwarg_key in kwarg_inputs.keys(): kwarg_val = kwarg_inputs[kwarg_key] - input_layout = None - if kwarg_key in self.input_kwarg_layouts: - input_layout = self.input_kwarg_layouts[kwarg_key] - assert isinstance(kwarg_val, torch.Tensor), f"input of key {kwarg_key} to the module should be a Tensor!" - kwarg_val = DTensor.from_local(kwarg_val, device_mesh, (input_layout,), run_check=False) - - if kwarg_key in self.desired_input_kwarg_layouts: - desired_layout = self.desired_input_kwarg_layouts[kwarg_key] - if desired_layout != input_layout: - kwarg_val = kwarg_val.redistribute(placements=(desired_layout,)) - - prepared_kwarg_inputs[kwarg_key] = kwarg_val.to_local() if self.use_local_output else kwarg_val - else: - prepared_kwarg_inputs[kwarg_key] = kwarg_val + input_layout = self.input_kwarg_layouts.get(kwarg_key) + desired_input_layout = self.desired_input_kwarg_layouts.get(kwarg_key) + + prepared_kwarg_inputs[kwarg_key] = self._prepare_input_arg(kwarg_val, device_mesh, input_layout, desired_input_layout) return (prepared_arg_inputs, prepared_kwarg_inputs) From 62311257adb902d6a4ea98809c88895af1dbbf2b Mon Sep 17 00:00:00 2001 From: diwei sun Date: Wed, 12 Jun 2024 05:33:54 +0000 Subject: [PATCH 676/706] Add 1 test case for Convtranspose1D in op microbenchmark (#127216) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Operator Convtransposd1d suffers performance regression with specific shape, #120982. Then we'd like to have this shape included into op level benchmark in this PR. I reproduced the regression that convtranspos1d with shape [2016, 1026, 1024, 256, 1, 224]. Here is the summary: Hardware info: Intel SPR8480-56cores per socket with frequency=2.1G. Performance comparison between torch 1.13 vs. torch 2.2 Benchmarking **PyTorch1.13**: ConvTranspose1d Mode: Eager Name: ConvTranspose1d_IC2016_OC1026_kernel1024_stride256_N1_L224_cpu Input: IC: 2016, OC: 1026, kernel: 1024, stride: 256, N: 1, L: 224, device: cpu Forward Execution Time (s) : **0.96s** Benchmarking **PyTorch2.2:** ConvTranspose1d Mode: Eager Name: ConvTranspose1d_IC2016_OC1026_kernel1024_stride256_N1_L224_cpu Input: IC: 2016, OC: 1026, kernel: 1024, stride: 256, N: 1, L: 224, device: cpu Forward Execution Time (s) : **7.988s** Also benchmarking for 7 rounds to check the variance.   | Round1 | Round2 | Round3 | Round4 | Round5 | Round6 | Round7 | Normalized Variance -- | -- | -- | -- | -- | -- | -- | -- | -- Pytorch1.13 | 0.971 | 0.972 | 0.969 | 0.970 | 0.972 | 0.970 | 0.971 | 0.0002% Pytorch 2.2 | 8.064 | 8.053 | 8.027 | 7.927 | 7.971 | 7.929 | 7.902 | 0.0059% Ratio v2.2 vs. v1.13(Lower is better) | 8.31 | 8.28 | 8.29 | 8.18 | 8.20 | 8.18 | 8.14 |   Reproduce script: numctl -N 0 python -m pt.conv_test Pull Request resolved: https://github.com/pytorch/pytorch/pull/127216 Approved by: https://github.com/chuanqi129, https://github.com/jgong5, https://github.com/atalman --- benchmarks/operator_benchmark/pt/configs.py | 11 +++++++++++ benchmarks/operator_benchmark/pt/conv_test.py | 4 +++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/benchmarks/operator_benchmark/pt/configs.py b/benchmarks/operator_benchmark/pt/configs.py index 3add77db1a87..ccfa62cc364e 100644 --- a/benchmarks/operator_benchmark/pt/configs.py +++ b/benchmarks/operator_benchmark/pt/configs.py @@ -34,6 +34,17 @@ def remove_cuda(config_list): tags=["long"], ) +convtranspose_1d_configs_short = op_bench.config_list( + attr_names=["IC", "OC", "kernel", "stride", "N", "L"], + attrs=[ + [2016, 1026, 1024, 256, 1, 224], + ], + cross_product_configs={ + "device": ["cpu", "cuda"], + }, + tags=["short"], +) + # Configs for Conv2d and ConvTranspose1d conv_2d_configs_short = op_bench.config_list( attr_names=[ diff --git a/benchmarks/operator_benchmark/pt/conv_test.py b/benchmarks/operator_benchmark/pt/conv_test.py index e01473a04f5b..ad315d8a0bb8 100644 --- a/benchmarks/operator_benchmark/pt/conv_test.py +++ b/benchmarks/operator_benchmark/pt/conv_test.py @@ -37,7 +37,9 @@ def forward(self, input): configs.conv_1d_configs_short + configs.conv_1d_configs_long, Conv1dBenchmark ) op_bench.generate_pt_test( - configs.conv_1d_configs_short + configs.conv_1d_configs_long, + configs.convtranspose_1d_configs_short + + configs.conv_1d_configs_short + + configs.conv_1d_configs_long, ConvTranspose1dBenchmark, ) From dcc0093dba163df47e67ba676294541503ea84bd Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Mon, 10 Jun 2024 19:16:54 +0000 Subject: [PATCH 677/706] [BE][Easy] export explicitly imported public submodules (#127703) Add top-level submodules `torch.{storage,serialization,functional,amp,overrides,types}` Pull Request resolved: https://github.com/pytorch/pytorch/pull/127703 Approved by: https://github.com/ezyang --- torch/__init__.py | 84 ++++++++++++++++--------------------- torch/optim/__init__.py | 34 ++++++++------- torch/utils/data/sampler.py | 3 +- 3 files changed, 57 insertions(+), 64 deletions(-) diff --git a/torch/__init__.py b/torch/__init__.py index a5d22d2ebf72..fbe0e59c4017 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -1,6 +1,4 @@ -# mypy: allow-untyped-defs - -r""" +""" The torch package contains data structures for multi-dimensional tensors and defines mathematical operations over these tensors. Additionally, it provides many utilities for efficient serialization of @@ -10,6 +8,8 @@ on an NVIDIA GPU with compute capability >= 3.0. """ +# mypy: allow-untyped-defs + import math import os import sys @@ -289,10 +289,6 @@ def load_shared_libraries(library_path): _load_global_deps() from torch._C import * # noqa: F403 -# Appease the type checker; ordinarily this binding is inserted by the -# torch._C module initialization code in C -if TYPE_CHECKING: - from . import _C as _C # noqa: TCH004 class SymInt: """ @@ -614,10 +610,9 @@ def sym_not(a): a (SymBool or bool): Object to negate """ import sympy - from .overrides import has_torch_function_unary, handle_torch_function - if has_torch_function_unary(a): - return handle_torch_function(sym_not, (a,), a) + if overrides.has_torch_function_unary(a): + return overrides.handle_torch_function(sym_not, (a,), a) if hasattr(a, '__sym_not__'): return a.__sym_not__() if isinstance(a, sympy.Basic): @@ -630,10 +625,8 @@ def sym_float(a): Args: a (SymInt, SymFloat, or object): Object to cast """ - from .overrides import has_torch_function_unary, handle_torch_function - - if has_torch_function_unary(a): - return handle_torch_function(sym_float, (a,), a) + if overrides.has_torch_function_unary(a): + return overrides.handle_torch_function(sym_float, (a,), a) if isinstance(a, SymFloat): return a elif hasattr(a, '__sym_float__'): @@ -647,10 +640,8 @@ def sym_int(a): Args: a (SymInt, SymFloat, or object): Object to cast """ - from .overrides import has_torch_function_unary, handle_torch_function - - if has_torch_function_unary(a): - return handle_torch_function(sym_int, (a,), a) + if overrides.has_torch_function_unary(a): + return overrides.handle_torch_function(sym_int, (a,), a) if isinstance(a, SymInt): return a elif isinstance(a, SymFloat): @@ -664,10 +655,8 @@ def sym_max(a, b): promotes to float if any argument is float (unlike builtins.max, which will faithfully preserve the type of the input argument). """ - from .overrides import has_torch_function, handle_torch_function - - if has_torch_function((a, b)): - return handle_torch_function(sym_max, (a, b), a, b) + if overrides.has_torch_function((a, b)): + return overrides.handle_torch_function(sym_max, (a, b), a, b) if isinstance(a, (SymInt, SymFloat)): return a.__sym_max__(b) elif isinstance(b, (SymInt, SymFloat)): @@ -683,11 +672,9 @@ def sym_max(a, b): return builtins.max(a, b) def sym_min(a, b): - """ SymInt-aware utility for min().""" - from .overrides import has_torch_function, handle_torch_function - - if has_torch_function((a, b)): - return handle_torch_function(sym_min, (a, b), a, b) + """SymInt-aware utility for min().""" + if overrides.has_torch_function((a, b)): + return overrides.handle_torch_function(sym_min, (a, b), a, b) if isinstance(a, (SymInt, SymFloat)): return a.__sym_min__(b) elif isinstance(b, (SymInt, SymFloat)): @@ -702,10 +689,8 @@ def sym_min(a, b): # Drop in replacement for math.sqrt, math.sin, math.cos etc def _get_sym_math_fn(name): def fn(a): - from .overrides import has_torch_function_unary, handle_torch_function - - if has_torch_function_unary(a): - return handle_torch_function(fn, (a,), a) + if overrides.has_torch_function_unary(a): + return overrides.handle_torch_function(fn, (a,), a) if hasattr(a, f"__sym_{name}__"): return getattr(a, f"__sym_{name}__")() return getattr(math, name)(a) @@ -727,10 +712,8 @@ def fn(a): def sym_ite(b, t, f): - from .overrides import has_torch_function, handle_torch_function - - if has_torch_function((b, t, f)): - return handle_torch_function(sym_ite, (b, t, f), b, t, f) + if overrides.has_torch_function((b, t, f)): + return overrides.handle_torch_function(sym_ite, (b, t, f), b, t, f) assert isinstance(b, (SymBool, builtins.bool)) and type(t) == type(f) if isinstance(b, SymBool): return b.__sym_ite__(t, f) @@ -760,16 +743,20 @@ def sym_ite(b, t, f): ''').strip()) from None raise # If __file__ is not None the cause is unknown, so just re-raise. +# The torch._C submodule is already loaded via `from torch._C import *` above +# Make an explicit reference to the _C submodule to appease linters +from torch import _C as _C + __name, __obj = '', None for __name in dir(_C): if __name[0] != '_' and not __name.endswith('Base'): __all__.append(__name) __obj = getattr(_C, __name) if callable(__obj) or inspect.isclass(__obj): - if __obj.__module__ != __name__: + if __obj.__module__ != __name__: # "torch" # TODO: fix their module from C++ side if __name not in ['DisableTorchFunctionSubclass', 'DisableTorchFunction', 'Generator']: - __obj.__module__ = __name__ + __obj.__module__ = __name__ # "torch" elif __name == 'TensorBase': # issue 109438 / pr 109940. Prevent TensorBase from being copied into torch. delattr(sys.modules[__name__], __name) @@ -1478,6 +1465,7 @@ def _check_tensor_all(cond, message=None): # noqa: F811 ################################################################################ from ._tensor import Tensor +from torch import storage as storage from .storage import _StorageBase, TypedStorage, _LegacyStorage, UntypedStorage, _warn_typed_storage_removal # NOTE: New Storage classes should never be added. When adding a new @@ -1665,7 +1653,9 @@ def _dtype(self): _tensor_classes: Set[Type] = set() # If you edit these imports, please update torch/__init__.py.in as well +from torch import random as random from .random import set_rng_state, get_rng_state, manual_seed, initial_seed, seed +from torch import serialization as serialization from .serialization import save, load from ._tensor_str import set_printoptions @@ -1682,6 +1672,7 @@ def _manager_path(): raise RuntimeError("Unable to find torch_shm_manager at " + path) return path.encode('utf-8') +from torch import amp as amp from torch.amp import autocast, GradScaler # Initializing the extension shadows the built-in python float / int classes; @@ -1717,7 +1708,7 @@ def _manager_path(): if __name.startswith('__') or __name in PRIVATE_OPS: continue __obj = getattr(_C._VariableFunctions, __name) - __obj.__module__ = __name__ + __obj.__module__ = __name__ # "torch" # Hide some APIs that should not be public if __name == "segment_reduce": # TODO: Once the undocumented FC window is passed, remove the line bellow @@ -1751,6 +1742,7 @@ def _manager_path(): ################################################################################ # needs to be after the above ATen bindings so we can overwrite from Python side +from torch import functional as functional from .functional import * # noqa: F403 @@ -1769,10 +1761,8 @@ def _manager_path(): def _assert(condition, message): r"""A wrapper around Python's assert which is symbolically traceable. """ - from .overrides import has_torch_function, handle_torch_function - - if type(condition) is not torch.Tensor and has_torch_function((condition,)): - return handle_torch_function(_assert, (condition,), condition, message) + if type(condition) is not torch.Tensor and overrides.has_torch_function((condition,)): + return overrides.handle_torch_function(_assert, (condition,), condition, message) assert condition, message ################################################################################ @@ -1801,7 +1791,6 @@ def _assert(condition, message): from torch import nn as nn from torch.signal import windows as windows from torch import optim as optim -import torch.optim._multi_tensor from torch import multiprocessing as multiprocessing from torch import sparse as sparse from torch import special as special @@ -1809,7 +1798,6 @@ def _assert(condition, message): from torch import jit as jit from torch import linalg as linalg from torch import hub as hub -from torch import random as random from torch import distributions as distributions from torch import testing as testing from torch import backends as backends @@ -1817,6 +1805,8 @@ def _assert(condition, message): from torch import __config__ as __config__ from torch import __future__ as __future__ from torch import profiler as profiler +from torch import overrides as overrides +from torch import types as types # Quantized, sparse, AO, etc. should be last to get imported, as nothing # is expected to depend on them. @@ -1827,7 +1817,7 @@ def _assert(condition, message): import torch.nn.qat import torch.nn.intrinsic -_C._init_names(list(torch._storage_classes)) +_C._init_names(list(_storage_classes)) # attach docstrings to torch and tensor functions from . import _torch_docs, _tensor_docs, _storage_docs, _size_docs @@ -1854,7 +1844,7 @@ def compiled_with_cxx11_abi() -> builtins.bool: # If you are seeing this, it means that this call site was not checked if # the memory format could be preserved, and it was switched to old default # behaviour of contiguous -legacy_contiguous_format = contiguous_format +legacy_contiguous_format = contiguous_format # defined by _C._initExtension() # Register fork handler to initialize OpenMP in child processes (see gh-28389) from torch.multiprocessing._atfork import register_after_fork @@ -1876,7 +1866,7 @@ def compiled_with_cxx11_abi() -> builtins.bool: # Import experimental masked operations support. See # [RFC-0016](https://github.com/pytorch/rfcs/pull/27) for more # information. -from . import masked +from torch import masked as masked # Import removed ops with error message about removal from ._linalg_utils import ( # type: ignore[misc] diff --git a/torch/optim/__init__.py b/torch/optim/__init__.py index 58d9c948416b..341d07b1a2e8 100644 --- a/torch/optim/__init__.py +++ b/torch/optim/__init__.py @@ -6,21 +6,22 @@ future. """ -from . import lr_scheduler, swa_utils -from .adadelta import Adadelta -from .adagrad import Adagrad -from .adam import Adam -from .adamax import Adamax -from .adamw import AdamW -from .asgd import ASGD -from .lbfgs import LBFGS -from .nadam import NAdam -from .optimizer import Optimizer -from .radam import RAdam -from .rmsprop import RMSprop -from .rprop import Rprop -from .sgd import SGD -from .sparse_adam import SparseAdam +from torch.optim import lr_scheduler, swa_utils +from torch.optim.adadelta import Adadelta +from torch.optim.adagrad import Adagrad +from torch.optim.adam import Adam +from torch.optim.adamax import Adamax +from torch.optim.adamw import AdamW +from torch.optim.asgd import ASGD +from torch.optim.lbfgs import LBFGS +from torch.optim.nadam import NAdam +from torch.optim.optimizer import Optimizer +from torch.optim.radam import RAdam +from torch.optim.rmsprop import RMSprop +from torch.optim.rprop import Rprop +from torch.optim.sgd import SGD +from torch.optim.sparse_adam import SparseAdam + del adadelta # type: ignore[name-defined] # noqa: F821 del adagrad # type: ignore[name-defined] # noqa: F821 @@ -36,3 +37,6 @@ del optimizer # type: ignore[name-defined] # noqa: F821 del nadam # type: ignore[name-defined] # noqa: F821 del lbfgs # type: ignore[name-defined] # noqa: F821 + + +import torch.optim._multi_tensor diff --git a/torch/utils/data/sampler.py b/torch/utils/data/sampler.py index 476d8dfadd41..c6ad6933fb49 100644 --- a/torch/utils/data/sampler.py +++ b/torch/utils/data/sampler.py @@ -1,6 +1,5 @@ # mypy: allow-untyped-defs import torch -from torch import Tensor from typing import Iterator, Iterable, Optional, Sequence, List, TypeVar, Generic, Sized, Union @@ -213,7 +212,7 @@ class WeightedRandomSampler(Sampler[int]): [0, 1, 4, 3, 2] """ - weights: Tensor + weights: torch.Tensor num_samples: int replacement: bool From a42169999822b8911cd834bf18e39fb5168e6d40 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 12 Jun 2024 06:25:53 +0000 Subject: [PATCH 678/706] Revert "[tp] refactor and fix PrepareModuleInput for DTensor inputs (#128431)" This reverts commit 089f9a116ac8b2c14d6351b52614b529caba126b. Reverted https://github.com/pytorch/pytorch/pull/128431 on behalf of https://github.com/DanilBaibak due to Sorry for the revert. Your changes broke the linter. Here you can find more details - https://hud.pytorch.org/pytorch/pytorch/commit/089f9a116ac8b2c14d6351b52614b529caba126b ([comment](https://github.com/pytorch/pytorch/pull/128431#issuecomment-2162197858)) --- .../tensor/parallel/test_tp_style.py | 12 ----- torch/distributed/tensor/parallel/style.py | 52 +++++++++++-------- 2 files changed, 29 insertions(+), 35 deletions(-) diff --git a/test/distributed/tensor/parallel/test_tp_style.py b/test/distributed/tensor/parallel/test_tp_style.py index 776bdc9b50b4..e2a9a01da85b 100644 --- a/test/distributed/tensor/parallel/test_tp_style.py +++ b/test/distributed/tensor/parallel/test_tp_style.py @@ -317,18 +317,6 @@ def forward(self, *, x, y=2, z=None): self.assertEqual(comm_mode.get_total_counts(), 2) self.assertEqual(output.shape, (1 * self.world_size, 8)) - # test the case where x is a DTensor - x_dt = DTensor.from_local( - torch.randn(1, 8, device=self.device_type), mesh, [Shard(0)] - ) - with comm_mode: - output = test_kwonly_mod( - x=x_dt, z=torch.ones(1, 8, device=self.device_type) - ) - - self.assertEqual(comm_mode.get_total_counts(), 2) - self.assertEqual(output.shape, (1 * self.world_size, 8)) - @with_comms def test_prepare_module_output(self): mesh = init_device_mesh(self.device_type, (self.world_size,)) diff --git a/torch/distributed/tensor/parallel/style.py b/torch/distributed/tensor/parallel/style.py index 00d85bf5d499..f532b97e97d0 100644 --- a/torch/distributed/tensor/parallel/style.py +++ b/torch/distributed/tensor/parallel/style.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates from abc import ABC, abstractmethod -from typing import Optional, Union, Tuple, Dict, Any +from typing import Optional, Union, Tuple, Dict from functools import partial import torch @@ -400,23 +400,6 @@ def __init__( assert len(self.input_kwarg_layouts) == len(self.desired_input_kwarg_layouts), \ "input_kwarg_layouts and desired_input_kwarg_layouts should have same length!" - def _prepare_input_arg(self, input: Any, mesh: DeviceMesh, input_layout: Placement, desired_layout: Placement): - if input_layout is not None: - if isinstance(input, DTensor): - # TODO: re-enable the check once we fix the compile path - # assert inp.placements[0] == input_layout - dt_inp = input - else: - assert isinstance(input, torch.Tensor), "expecting input to be a torch.Tensor!" - dt_inp = DTensor.from_local(input, mesh, (input_layout,), run_check=False) - - if desired_layout is not None and input_layout != desired_layout: - dt_inp = dt_inp.redistribute(placements=(desired_layout,)) - - return dt_inp.to_local() if self.use_local_output else dt_inp - else: - return input - def _prepare_input_fn(self, inputs, device_mesh): if self.input_layouts is None: return inputs @@ -426,8 +409,21 @@ def _prepare_input_fn(self, inputs, device_mesh): if len(inputs) != len(self.input_layouts): raise ValueError("module inputs and input_layouts should have same length!") + assert self.desired_input_layouts is not None, "desired module inputs should not be None!" for inp, input_layout, desired_layout in zip(inputs, self.input_layouts, self.desired_input_layouts): - prepared_inputs.append(self._prepare_input_arg(inp, device_mesh, input_layout, desired_layout)) + if input_layout is not None: + if isinstance(inp, DTensor): + # TODO: re-enable the check once we fix the compile path + # assert inp.placements[0] == input_layout + dt_inp = inp + else: + dt_inp = DTensor.from_local(inp, device_mesh, (input_layout,), run_check=False) + + if desired_layout is not None and input_layout != desired_layout: + dt_inp = dt_inp.redistribute(placements=(desired_layout,)) + prepared_inputs.append(dt_inp.to_local() if self.use_local_output else dt_inp) + else: + prepared_inputs.append(inp) return tuple(prepared_inputs) def _prepare_input_kwarg_fn(self, inputs, kwarg_inputs, device_mesh): @@ -435,10 +431,20 @@ def _prepare_input_kwarg_fn(self, inputs, kwarg_inputs, device_mesh): prepared_kwarg_inputs = {} for kwarg_key in kwarg_inputs.keys(): kwarg_val = kwarg_inputs[kwarg_key] - input_layout = self.input_kwarg_layouts.get(kwarg_key) - desired_input_layout = self.desired_input_kwarg_layouts.get(kwarg_key) - - prepared_kwarg_inputs[kwarg_key] = self._prepare_input_arg(kwarg_val, device_mesh, input_layout, desired_input_layout) + input_layout = None + if kwarg_key in self.input_kwarg_layouts: + input_layout = self.input_kwarg_layouts[kwarg_key] + assert isinstance(kwarg_val, torch.Tensor), f"input of key {kwarg_key} to the module should be a Tensor!" + kwarg_val = DTensor.from_local(kwarg_val, device_mesh, (input_layout,), run_check=False) + + if kwarg_key in self.desired_input_kwarg_layouts: + desired_layout = self.desired_input_kwarg_layouts[kwarg_key] + if desired_layout != input_layout: + kwarg_val = kwarg_val.redistribute(placements=(desired_layout,)) + + prepared_kwarg_inputs[kwarg_key] = kwarg_val.to_local() if self.use_local_output else kwarg_val + else: + prepared_kwarg_inputs[kwarg_key] = kwarg_val return (prepared_arg_inputs, prepared_kwarg_inputs) From 8b3daf1768fdf8c1c54771fd738552c4fa984b12 Mon Sep 17 00:00:00 2001 From: Sam Larsen Date: Tue, 11 Jun 2024 07:44:10 -0700 Subject: [PATCH 679/706] Add FloatTrueDiv and ToFloat to SYMPY_INTERP (#128418) Summary: I admit I'm not 100% sure what I'm doing here. I'm hitting a bug in the FX graph cache when we try to evaluate a guards expression. We're creating guards that look like this: ``` Ne(CeilToInt(FloatTrueDiv(ToFloat(8*L['t0']) - 4.0, 8.0))*CeilToInt(FloatTrueDiv(ToFloat(8*L['t1']) - 4.0, 8.0)), CeilToInt(FloatTrueDiv(ToFloat(8*L['t1']) - 4.0, 8.0))) and ... ``` It looks like we have a facility to define these operators in the SYMPY_INTERP map and we're just missing FloatTrueDiv and ToFloat. What's surprsing to me is that we're only hitting this problem with the FX graph enabled. We can create such guards, but we've never actually evaluated any? Test Plan: `TORCHINDUCTOR_FX_GRAPH_CACHE=1 python benchmarks/dynamo/torchbench.py --ci --accuracy --timing --explain --inductor --device cuda --inference --bfloat16 --only detectron2_fcos_r_50_fpn` Pull Request resolved: https://github.com/pytorch/pytorch/pull/128418 Approved by: https://github.com/ezyang --- torch/fx/experimental/symbolic_shapes.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index bf21ef7ffb2c..e1170fd49f8c 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -1393,6 +1393,8 @@ def cast_symbool_to_symint_guardless(symbool: torch.SymBool) -> torch.SymInt: 'RoundDecimal': builtins.round, 'TruncToInt': math.trunc, 'IntTrueDiv': operator.truediv, + 'FloatTrueDiv': operator.truediv, + 'ToFloat': builtins.float, } From 0b331fd5d75ae2c3eed43293f016247bf37ddab5 Mon Sep 17 00:00:00 2001 From: Eddie Yan Date: Wed, 12 Jun 2024 07:47:12 +0000 Subject: [PATCH 680/706] [CUDA] Abate `SoftMax.cu` compiler warning spam (#128468) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Avoids excessively spammy warnings such as ``` pytorch/aten/src/ATen/native/cuda/SoftMax.cu(844): warning #191-D: type qualifier is meaningless on cast type [&] { const auto& the_type = input.scalar_type(); constexpr const char* at_dispatch_name = "host_softmax"; at::ScalarType _st = ::detail::scalar_type(the_type); ; switch (_st) { case at::ScalarType::Double: { do { if constexpr (!at::should_include_kernel_dtype( at_dispatch_name, at::ScalarType::Double)) { do { ::c10::detail::deprecated_AT_ERROR(); if (!(false)) { ::c10::detail::torchCheckFail( __func__, "/workspace/pytorch/aten/src/ATen/native/cuda/SoftMax.cu", static_cast(844), (::c10::detail::torchCheckMsgImpl( "Expected " "false" " to be true, but got false. " "(Could this error message be improved? If so, " "please report an enhancement request to PyTorch.)", ::c10::str("dtype '", toString(at::ScalarType::Double), "' not selected for kernel tag ", at_dispatch_name)))); }; } while (false); } } while (0); using scalar_t __attribute__((__unused__)) = c10::impl::ScalarTypeToCPPTypeT; return [&] { using accscalar_t = acc_type; if (!half_to_float) { auto output_ptr = output.mutable_data_ptr(); auto input_ptr = input.const_data_ptr(); if (dim_size <= 1024 && dim_size*sizeof(scalar_t) <= 4096) { int64_t remaining = outer_size; int64_t chunk_size = (1L << 30L) / dim_size; while(remaining > 0) { dispatch_softmax_forward( output_ptr, input_ptr, dim_size, dim_size, std::min(remaining, chunk_size), nullptr ); input_ptr += chunk_size * dim_size; output_ptr += chunk_size * dim_size; remaining -= chunk_size; } } else { constexpr int ILP = sizeof(float4) / sizeof(scalar_t); dim3 block = SoftMaxForward_getBlockSize(dim_size); size_t smem_reduction_sz = block.x / 32 * sizeof(accscalar_t); auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock - smem_reduction_sz) / sizeof(scalar_t); bool can_use_smem = dim_size < max_elements_per_smem; can_use_smem &= !(reinterpret_cast(input_ptr) % ALIGN_BYTES); can_use_smem &= (!(reinterpret_cast(output_ptr) % ALIGN_BYTES)); can_use_smem &= !(dim_size % ILP); if (can_use_smem) { size_t smem_sz = dim_size * sizeof(scalar_t) + smem_reduction_sz; cunn_SoftMaxForwardSmem <<>>(output_ptr, input_ptr, dim_size); } else { cunn_SoftMaxForward <<>>(output_ptr, input_ptr, dim_size); } do { const cudaError_t __err = cudaGetLastError(); c10::cuda::c10_cuda_check_implementation( static_cast(__err), "/workspace/pytorch/aten/src/ATen/native/cuda/SoftMax.cu", __func__, static_cast(880), true); } while (0); } } else { auto output_ptr = output.mutable_data_ptr(); auto input_ptr = input.const_data_ptr(); if (dim_size <= 1024 && dim_size*sizeof(scalar_t) <= 4096) { int64_t remaining = outer_size; int64_t chunk_size = (1<<30) / dim_size; while(remaining > 0) { dispatch_softmax_forward( output_ptr, input_ptr, dim_size, dim_size, std::min(remaining, chunk_size), nullptr ); input_ptr += chunk_size * dim_size; output_ptr += chunk_size * dim_size; remaining -= chunk_size; } } else { constexpr int ILP = sizeof(float4) / sizeof(scalar_t); dim3 block = SoftMaxForward_getBlockSize(dim_size); size_t smem_reduction_sz = block.x / 32 * sizeof(accscalar_t); auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock - smem_reduction_sz) / sizeof(scalar_t); bool can_use_smem = dim_size < max_elements_per_smem; can_use_smem &= !(reinterpret_cast(input_ptr) % ALIGN_BYTES); can_use_smem &= (!(reinterpret_cast(output_ptr) % ALIGN_BYTES)); can_use_smem &= !(dim_size % ILP); if (can_use_smem) { size_t smem_sz = dim_size * sizeof(scalar_t) + smem_reduction_sz; cunn_SoftMaxForwardSmem <<>>(output_ptr, input_ptr, dim_size); } else { cunn_SoftMaxForward <<>>(output_ptr, input_ptr, dim_size); } do { const cudaError_t __err = cudaGetLastError(); c10::cuda::c10_cuda_check_implementation( static_cast(__err), "/workspace/pytorch/aten/src/ATen/native/cuda/SoftMax.cu", __func__, static_cast(916), true); } while (0); } } }(); } case at::ScalarType::Float: { do { if constexpr (!at::should_include_kernel_dtype( at_dispatch_name, at::ScalarType::Float)) { do { ::c10::detail::deprecated_AT_ERROR(); if (!(false)) { ::c10::detail::torchCheckFail( __func__, "/workspace/pytorch/aten/src/ATen/native/cuda/SoftMax.cu", static_cast(844), (::c10::detail::torchCheckMsgImpl( "Expected " "false" " to be true, but got false. " "(Could this error message be improved? If so, " "please report an enhancement request to PyTorch.)", ::c10::str("dtype '", toString(at::ScalarType::Float), "' not selected for kernel tag ", at_dispatch_name)))); }; } while (false); } } while (0); using scalar_t __attribute__((__unused__)) = c10::impl::ScalarTypeToCPPTypeT; return [&] { using accscalar_t = acc_type; if (!half_to_float) { auto output_ptr = output.mutable_data_ptr(); auto input_ptr = input.const_data_ptr(); if (dim_size <= 1024 && dim_size*sizeof(scalar_t) <= 4096) { int64_t remaining = outer_size; int64_t chunk_size = (1L << 30L) / dim_size; while(remaining > 0) { dispatch_softmax_forward( output_ptr, input_ptr, dim_size, dim_size, std::min(remaining, chunk_size), nullptr ); input_ptr += chunk_size * dim_size; output_ptr += chunk_size * dim_size; remaining -= chunk_size; } } else { constexpr int ILP = sizeof(float4) / sizeof(scalar_t); dim3 block = SoftMaxForward_getBlockSize(dim_size); size_t smem_reduction_sz = block.x / 32 * sizeof(accscalar_t); auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock - smem_reduction_sz) / sizeof(scalar_t); bool can_use_smem = dim_size < max_elements_per_smem; can_use_smem &= !(reinterpret_cast(input_ptr) % ALIGN_BYTES); can_use_smem &= (!(reinterpret_cast(output_ptr) % ALIGN_BYTES)); can_use_smem &= !(dim_size % ILP); if (can_use_smem) { size_t smem_sz = dim_size * sizeof(scalar_t) + smem_reduction_sz; cunn_SoftMaxForwardSmem <<>>(output_ptr, input_ptr, dim_size); } else { cunn_SoftMaxForward <<>>(output_ptr, input_ptr, dim_size); } do { const cudaError_t __err = cudaGetLastError(); c10::cuda::c10_cuda_check_implementation( static_cast(__err), "/workspace/pytorch/aten/src/ATen/native/cuda/SoftMax.cu", __func__, static_cast(880), true); } while (0); } } else { auto output_ptr = output.mutable_data_ptr(); auto input_ptr = input.const_data_ptr(); if (dim_size <= 1024 && dim_size*sizeof(scalar_t) <= 4096) { int64_t remaining = outer_size; int64_t chunk_size = (1<<30) / dim_size; while(remaining > 0) { dispatch_softmax_forward( output_ptr, input_ptr, dim_size, dim_size, std::min(remaining, chunk_size), nullptr ); input_ptr += chunk_size * dim_size; output_ptr += chunk_size * dim_size; remaining -= chunk_size; } } else { constexpr int ILP = sizeof(float4) / sizeof(scalar_t); dim3 block = SoftMaxForward_getBlockSize(dim_size); size_t smem_reduction_sz = block.x / 32 * sizeof(accscalar_t); auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock - smem_reduction_sz) / sizeof(scalar_t); bool can_use_smem = dim_size < max_elements_per_smem; can_use_smem &= !(reinterpret_cast(input_ptr) % ALIGN_BYTES); can_use_smem &= (!(reinterpret_cast(output_ptr) % ALIGN_BYTES)); can_use_smem &= !(dim_size % ILP); if (can_use_smem) { size_t smem_sz = dim_size * sizeof(scalar_t) + smem_reduction_sz; cunn_SoftMaxForwardSmem <<>>(output_ptr, input_ptr, dim_size); } else { cunn_SoftMaxForward <<>>(output_ptr, input_ptr, dim_size); } do { const cudaError_t __err = cudaGetLastError(); c10::cuda::c10_cuda_check_implementation( static_cast(__err), "/workspace/pytorch/aten/src/ATen/native/cuda/SoftMax.cu", __func__, static_cast(916), true); } while (0); } } }(); } case at::ScalarType::Half: { do { if constexpr (!at::should_include_kernel_dtype( at_dispatch_name, at::ScalarType::Half)) { do { ::c10::detail::deprecated_AT_ERROR(); if (!(false)) { ::c10::detail::torchCheckFail( __func__, "/workspace/pytorch/aten/src/ATen/native/cuda/SoftMax.cu", static_cast(844), (::c10::detail::torchCheckMsgImpl( "Expected " "false" " to be true, but got false. " "(Could this error message be improved? If so, " "please report an enhancement request to PyTorch.)", ::c10::str("dtype '", toString(at::ScalarType::Half), "' not selected for kernel tag ", at_dispatch_name)))); }; } while (false); } } while (0); using scalar_t __attribute__((__unused__)) = c10::impl::ScalarTypeToCPPTypeT; return [&] { using accscalar_t = acc_type; if (!half_to_float) { auto output_ptr = output.mutable_data_ptr(); auto input_ptr = input.const_data_ptr(); if (dim_size <= 1024 && dim_size*sizeof(scalar_t) <= 4096) { int64_t remaining = outer_size; int64_t chunk_size = (1L << 30L) / dim_size; while(remaining > 0) { dispatch_softmax_forward( output_ptr, input_ptr, dim_size, dim_size, std::min(remaining, chunk_size), nullptr ); input_ptr += chunk_size * dim_size; output_ptr += chunk_size * dim_size; remaining -= chunk_size; } } else { constexpr int ILP = sizeof(float4) / sizeof(scalar_t); dim3 block = SoftMaxForward_getBlockSize(dim_size); size_t smem_reduction_sz = block.x / 32 * sizeof(accscalar_t); auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock - smem_reduction_sz) / sizeof(scalar_t); bool can_use_smem = dim_size < max_elements_per_smem; can_use_smem &= !(reinterpret_cast(input_ptr) % ALIGN_BYTES); can_use_smem &= (!(reinterpret_cast(output_ptr) % ALIGN_BYTES)); can_use_smem &= !(dim_size % ILP); if (can_use_smem) { size_t smem_sz = dim_size * sizeof(scalar_t) + smem_reduction_sz; cunn_SoftMaxForwardSmem <<>>(output_ptr, input_ptr, dim_size); } else { cunn_SoftMaxForward <<>>(output_ptr, input_ptr, dim_size); } do { const cudaError_t __err = cudaGetLastError(); c10::cuda::c10_cuda_check_implementation( static_cast(__err), "/workspace/pytorch/aten/src/ATen/native/cuda/SoftMax.cu", __func__, static_cast(880), true); } while (0); } } else { auto output_ptr = output.mutable_data_ptr(); auto input_ptr = input.const_data_ptr(); if (dim_size <= 1024 && dim_size*sizeof(scalar_t) <= 4096) { int64_t remaining = outer_size; int64_t chunk_size = (1<<30) / dim_size; while(remaining > 0) { dispatch_softmax_forward( output_ptr, input_ptr, dim_size, dim_size, std::min(remaining, chunk_size), nullptr ); input_ptr += chunk_size * dim_size; output_ptr += chunk_size * dim_size; remaining -= chunk_size; } } else { constexpr int ILP = sizeof(float4) / sizeof(scalar_t); dim3 block = SoftMaxForward_getBlockSize(dim_size); size_t smem_reduction_sz = block.x / 32 * sizeof(accscalar_t); auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock - smem_reduction_sz) / sizeof(scalar_t); bool can_use_smem = dim_size < max_elements_per_smem; can_use_smem &= !(reinterpret_cast(input_ptr) % ALIGN_BYTES); can_use_smem &= (!(reinterpret_cast(output_ptr) % ALIGN_BYTES)); can_use_smem &= !(dim_size % ILP); if (can_use_smem) { size_t smem_sz = dim_size * sizeof(scalar_t) + smem_reduction_sz; cunn_SoftMaxForwardSmem <<>>(output_ptr, input_ptr, dim_size); } else { cunn_SoftMaxForward <<>>(output_ptr, input_ptr, dim_size); } do { const cudaError_t __err = cudaGetLastError(); c10::cuda::c10_cuda_check_implementation( static_cast(__err), "/workspace/pytorch/aten/src/ATen/native/cuda/SoftMax.cu", __func__, static_cast(916), true); } while (0); } } }(); } case at::ScalarType::BFloat16: { do { if constexpr (!at::should_include_kernel_dtype( at_dispatch_name, at::ScalarType::BFloat16)) { do { ::c10::detail::deprecated_AT_ERROR(); if (!(false)) { ::c10::detail::torchCheckFail( __func__, "/workspace/pytorch/aten/src/ATen/native/cuda/SoftMax.cu", static_cast(844), (::c10::detail::torchCheckMsgImpl( "Expected " "false" " to be true, but got false. " "(Could this error message be improved? If so, " "please report an enhancement request to PyTorch.)", ::c10::str("dtype '", toString(at::ScalarType::BFloat16), "' not selected for kernel tag ", at_dispatch_name)))); }; } while (false); } } while (0); using scalar_t __attribute__((__unused__)) = c10::impl::ScalarTypeToCPPTypeT; return [&] { using accscalar_t = acc_type; if (!half_to_float) { auto output_ptr = output.mutable_data_ptr(); auto input_ptr = input.const_data_ptr(); if (dim_size <= 1024 && dim_size*sizeof(scalar_t) <= 4096) { int64_t remaining = outer_size; int64_t chunk_size = (1L << 30L) / dim_size; while(remaining > 0) { dispatch_softmax_forward( output_ptr, input_ptr, dim_size, dim_size, std::min(remaining, chunk_size), nullptr ); input_ptr += chunk_size * dim_size; output_ptr += chunk_size * dim_size; remaining -= chunk_size; } } else { constexpr int ILP = sizeof(float4) / sizeof(scalar_t); dim3 block = SoftMaxForward_getBlockSize(dim_size); size_t smem_reduction_sz = block.x / 32 * sizeof(accscalar_t); auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock - smem_reduction_sz) / sizeof(scalar_t); bool can_use_smem = dim_size < max_elements_per_smem; can_use_smem &= !(reinterpret_cast(input_ptr) % ALIGN_BYTES); can_use_smem &= (!(reinterpret_cast(output_ptr) % ALIGN_BYTES)); can_use_smem &= !(dim_size % ILP); if (can_use_smem) { size_t smem_sz = dim_size * sizeof(scalar_t) + smem_reduction_sz; cunn_SoftMaxForwardSmem <<>>(output_ptr, input_ptr, dim_size); } else { cunn_SoftMaxForward <<>>(output_ptr, input_ptr, dim_size); } do { const cudaError_t __err = cudaGetLastError(); c10::cuda::c10_cuda_check_implementation( static_cast(__err), "/workspace/pytorch/aten/src/ATen/native/cuda/SoftMax.cu", __func__, static_cast(880), true); } while (0); } } else { auto output_ptr = output.mutable_data_ptr(); auto input_ptr = input.const_data_ptr(); if (dim_size <= 1024 && dim_size*sizeof(scalar_t) <= 4096) { int64_t remaining = outer_size; int64_t chunk_size = (1<<30) / dim_size; while(remaining > 0) { dispatch_softmax_forward( output_ptr, input_ptr, dim_size, dim_size, std::min(remaining, chunk_size), nullptr ); input_ptr += chunk_size * dim_size; output_ptr += chunk_size * dim_size; remaining -= chunk_size; } } else { constexpr int ILP = sizeof(float4) / sizeof(scalar_t); dim3 block = SoftMaxForward_getBlockSize(dim_size); size_t smem_reduction_sz = block.x / 32 * sizeof(accscalar_t); auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock - smem_reduction_sz) / sizeof(scalar_t); bool can_use_smem = dim_size < max_elements_per_smem; can_use_smem &= !(reinterpret_cast(input_ptr) % ALIGN_BYTES); can_use_smem &= (!(reinterpret_cast(output_ptr) % ALIGN_BYTES)); can_use_smem &= !(dim_size % ILP); if (can_use_smem) { size_t smem_sz = dim_size * sizeof(scalar_t) + smem_reduction_sz; cunn_SoftMaxForwardSmem <<>>(output_ptr, input_ptr, dim_size); } else { cunn_SoftMaxForward <<>>(output_ptr, input_ptr, dim_size); } do { const cudaError_t __err = cudaGetLastError(); c10::cuda::c10_cuda_check_implementation( static_cast(__err), "/workspace/pytorch/aten/src/ATen/native/cuda/SoftMax.cu", __func__, static_cast(916), true); } while (0); } } }(); } default: do { ::c10::detail::deprecated_AT_ERROR(); if (!(false)) { ::c10::detail::torchCheckFail( __func__, "/workspace/pytorch/aten/src/ATen/native/cuda/SoftMax.cu", static_cast(844), (::c10::detail::torchCheckMsgImpl( "Expected " "false" " to be true, but got false. " "(Could this error message be improved? If so, " "please report an enhancement request to PyTorch.)", ::c10::str('"', at_dispatch_name, "\" not implemented for '", toString(_st), "'")))); }; } while (false); } }() ``` and ``` SoftMax.cu:844: warning: comparison of integer expressions of different signedness: ‘int64_t’ {aka ‘long int’} and ‘long unsigned int’ [-Wsign-compare] ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/128468 Approved by: https://github.com/valentinandrei --- aten/src/ATen/native/cuda/SoftMax.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/native/cuda/SoftMax.cu b/aten/src/ATen/native/cuda/SoftMax.cu index 4aca753a510b..7616b7bdcc01 100644 --- a/aten/src/ATen/native/cuda/SoftMax.cu +++ b/aten/src/ATen/native/cuda/SoftMax.cu @@ -863,8 +863,8 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock - smem_reduction_sz) / sizeof(scalar_t); - bool can_use_smem = dim_size < max_elements_per_smem; - can_use_smem &= !(reinterpret_cast(input_ptr) % ALIGN_BYTES); + bool can_use_smem = (size_t) dim_size < max_elements_per_smem; + can_use_smem &= !(reinterpret_cast(input_ptr) % ALIGN_BYTES); can_use_smem &= (!(reinterpret_cast(output_ptr) % ALIGN_BYTES)); can_use_smem &= !(dim_size % ILP); @@ -899,8 +899,8 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock - smem_reduction_sz) / sizeof(scalar_t); - bool can_use_smem = dim_size < max_elements_per_smem; - can_use_smem &= !(reinterpret_cast(input_ptr) % ALIGN_BYTES); + bool can_use_smem = (size_t) dim_size < max_elements_per_smem; + can_use_smem &= !(reinterpret_cast(input_ptr) % ALIGN_BYTES); can_use_smem &= (!(reinterpret_cast(output_ptr) % ALIGN_BYTES)); can_use_smem &= !(dim_size % ILP); From 04037f3d22f00c55adf6f2b46b910928c4d44f24 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Mon, 10 Jun 2024 19:16:55 +0000 Subject: [PATCH 681/706] [BE] sort imports in `torch/__init__.py` (#127708) ---- - Sort import via `usort` - Change relative import `from . import xxx` to absolute import `from torch import xxx` Pull Request resolved: https://github.com/pytorch/pytorch/pull/127708 Approved by: https://github.com/ezyang ghstack dependencies: #127703 --- torch/__init__.py | 358 +++++++++++++++++++++++++--------------------- 1 file changed, 192 insertions(+), 166 deletions(-) diff --git a/torch/__init__.py b/torch/__init__.py index fbe0e59c4017..156f35939d02 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -10,17 +10,20 @@ # mypy: allow-untyped-defs +import builtins +import ctypes +import glob +import importlib +import importlib.util +import inspect import math import os -import sys import platform +import sys import textwrap -import ctypes -import inspect import threading -import pdb -import importlib -import importlib.util +from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, TYPE_CHECKING, Union + # multipy/deploy is setting this import before importing torch, this is the most # reliable way we have to detect if we're running within deploy. @@ -28,19 +31,25 @@ def _running_with_deploy(): return sys.modules.get("torch._meta_registrations", None) is object -from ._utils import _import_dotted_name, classproperty -from ._utils import _functionalize_sync as _sync -from ._utils_internal import get_file_path, prepare_multiprocessing_environment, \ - USE_RTLD_GLOBAL_WITH_LIBTORCH, USE_GLOBAL_DEPS + +from torch._utils import ( + _functionalize_sync as _sync, + _import_dotted_name, + classproperty, +) +from torch._utils_internal import ( + get_file_path, + prepare_multiprocessing_environment, + USE_GLOBAL_DEPS, + USE_RTLD_GLOBAL_WITH_LIBTORCH, +) # TODO(torch_deploy) figure out how to freeze version.py in fbcode build if _running_with_deploy(): __version__ = "torch-deploy-1.8" else: - from .torch_version import __version__ as __version__ + from torch.torch_version import __version__ as __version__ -from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, TYPE_CHECKING, Union, List -import builtins __all__ = [ 'typename', 'is_tensor', 'is_storage', @@ -70,91 +79,97 @@ def _running_with_deploy(): ################################################################################ if sys.platform == 'win32': - import sysconfig - pfiles_path = os.getenv('ProgramFiles', 'C:\\Program Files') - py_dll_path = os.path.join(sys.exec_prefix, 'Library', 'bin') - th_dll_path = os.path.join(os.path.dirname(__file__), 'lib') - usebase_path = os.path.join(sysconfig.get_config_var("userbase"), 'Library', 'bin') - - # When users create a virtualenv that inherits the base environment, - # we will need to add the corresponding library directory into - # DLL search directories. Otherwise, it will rely on `PATH` which - # is dependent on user settings. - if sys.exec_prefix != sys.base_exec_prefix: - base_py_dll_path = os.path.join(sys.base_exec_prefix, 'Library', 'bin') - else: - base_py_dll_path = '' - dll_paths = list(filter(os.path.exists, [th_dll_path, py_dll_path, base_py_dll_path, usebase_path])) + def _load_dll_libraries(): + import sysconfig - if all(not os.path.exists(os.path.join(p, 'nvToolsExt64_1.dll')) for p in dll_paths): - nvtoolsext_dll_path = os.path.join( - os.getenv('NVTOOLSEXT_PATH', os.path.join(pfiles_path, 'NVIDIA Corporation', 'NvToolsExt')), 'bin', 'x64') - else: - nvtoolsext_dll_path = '' - - from .version import cuda as cuda_version - import glob - if cuda_version and all(not glob.glob(os.path.join(p, 'cudart64*.dll')) for p in dll_paths): - cuda_version_1 = cuda_version.replace('.', '_') - cuda_path_var = 'CUDA_PATH_V' + cuda_version_1 - default_path = os.path.join(pfiles_path, 'NVIDIA GPU Computing Toolkit', 'CUDA', 'v' + cuda_version) - cuda_path = os.path.join(os.getenv(cuda_path_var, default_path), 'bin') - else: - cuda_path = '' + from torch.version import cuda as cuda_version - dll_paths.extend(filter(os.path.exists, [nvtoolsext_dll_path, cuda_path])) + pfiles_path = os.getenv('ProgramFiles', r'C:\Program Files') + py_dll_path = os.path.join(sys.exec_prefix, 'Library', 'bin') + th_dll_path = os.path.join(os.path.dirname(__file__), 'lib') + usebase_path = os.path.join(sysconfig.get_config_var("userbase"), 'Library', 'bin') - kernel32 = ctypes.WinDLL('kernel32.dll', use_last_error=True) - with_load_library_flags = hasattr(kernel32, 'AddDllDirectory') - prev_error_mode = kernel32.SetErrorMode(0x0001) + # When users create a virtualenv that inherits the base environment, + # we will need to add the corresponding library directory into + # DLL search directories. Otherwise, it will rely on `PATH` which + # is dependent on user settings. + if sys.exec_prefix != sys.base_exec_prefix: + base_py_dll_path = os.path.join(sys.base_exec_prefix, 'Library', 'bin') + else: + base_py_dll_path = '' - kernel32.LoadLibraryW.restype = ctypes.c_void_p - if with_load_library_flags: - kernel32.LoadLibraryExW.restype = ctypes.c_void_p + dll_paths = [p for p in (th_dll_path, py_dll_path, base_py_dll_path, usebase_path) if os.path.exists(p)] - for dll_path in dll_paths: - os.add_dll_directory(dll_path) + if not builtins.any(os.path.exists(os.path.join(p, 'nvToolsExt64_1.dll')) for p in dll_paths): + nvtoolsext_dll_path = os.path.join( + os.getenv('NVTOOLSEXT_PATH', os.path.join(pfiles_path, 'NVIDIA Corporation', 'NvToolsExt')), 'bin', 'x64') + else: + nvtoolsext_dll_path = '' - try: - ctypes.CDLL('vcruntime140.dll') - ctypes.CDLL('msvcp140.dll') - ctypes.CDLL('vcruntime140_1.dll') - except OSError: - print('''Microsoft Visual C++ Redistributable is not installed, this may lead to the DLL load failure. - It can be downloaded at https://aka.ms/vs/16/release/vc_redist.x64.exe''') - - dlls = glob.glob(os.path.join(th_dll_path, '*.dll')) - path_patched = False - for dll in dlls: - is_loaded = False + if cuda_version and builtins.all(not glob.glob(os.path.join(p, 'cudart64*.dll')) for p in dll_paths): + cuda_version_1 = cuda_version.replace('.', '_') + cuda_path_var = 'CUDA_PATH_V' + cuda_version_1 + default_path = os.path.join(pfiles_path, 'NVIDIA GPU Computing Toolkit', 'CUDA', 'v' + cuda_version) + cuda_path = os.path.join(os.getenv(cuda_path_var, default_path), 'bin') + else: + cuda_path = '' + + dll_paths.extend(p for p in (nvtoolsext_dll_path, cuda_path) if os.path.exists(p)) + + kernel32 = ctypes.WinDLL('kernel32.dll', use_last_error=True) + with_load_library_flags = hasattr(kernel32, 'AddDllDirectory') + prev_error_mode = kernel32.SetErrorMode(0x0001) + + kernel32.LoadLibraryW.restype = ctypes.c_void_p if with_load_library_flags: - res = kernel32.LoadLibraryExW(dll, None, 0x00001100) - last_error = ctypes.get_last_error() - if res is None and last_error != 126: - err = ctypes.WinError(last_error) - err.strerror += f' Error loading "{dll}" or one of its dependencies.' - raise err - elif res is not None: - is_loaded = True - if not is_loaded: - if not path_patched: - os.environ['PATH'] = ';'.join(dll_paths + [os.environ['PATH']]) - path_patched = True - res = kernel32.LoadLibraryW(dll) - if res is None: - err = ctypes.WinError(ctypes.get_last_error()) - err.strerror += f' Error loading "{dll}" or one of its dependencies.' - raise err - - kernel32.SetErrorMode(prev_error_mode) + kernel32.LoadLibraryExW.restype = ctypes.c_void_p + + for dll_path in dll_paths: + os.add_dll_directory(dll_path) + + try: + ctypes.CDLL('vcruntime140.dll') + ctypes.CDLL('msvcp140.dll') + ctypes.CDLL('vcruntime140_1.dll') + except OSError: + print('''Microsoft Visual C++ Redistributable is not installed, this may lead to the DLL load failure. + It can be downloaded at https://aka.ms/vs/16/release/vc_redist.x64.exe''') + + dlls = glob.glob(os.path.join(th_dll_path, '*.dll')) + path_patched = False + for dll in dlls: + is_loaded = False + if with_load_library_flags: + res = kernel32.LoadLibraryExW(dll, None, 0x00001100) + last_error = ctypes.get_last_error() + if res is None and last_error != 126: + err = ctypes.WinError(last_error) + err.strerror += f' Error loading "{dll}" or one of its dependencies.' + raise err + elif res is not None: + is_loaded = True + if not is_loaded: + if not path_patched: + os.environ['PATH'] = ';'.join(dll_paths + [os.environ['PATH']]) + path_patched = True + res = kernel32.LoadLibraryW(dll) + if res is None: + err = ctypes.WinError(ctypes.get_last_error()) + err.strerror += f' Error loading "{dll}" or one of its dependencies.' + raise err + + kernel32.SetErrorMode(prev_error_mode) + + _load_dll_libraries() + del _load_dll_libraries def _preload_cuda_deps(lib_folder, lib_name): """Preloads cuda deps if they could not be found otherwise.""" # Should only be called on Linux if default path resolution have failed assert platform.system() == 'Linux', 'Should only be called on Linux' - import glob + lib_path = None for path in sys.path: nvidia_path = os.path.join(path, 'nvidia') @@ -1456,7 +1471,7 @@ def _check_tensor_all(cond, message=None): # noqa: F811 # For Python Array API (https://data-apis.org/array-api/latest/API_specification/constants.html) and # NumPy consistency (https://numpy.org/devdocs/reference/constants.html) -from math import e, nan , inf , pi +from math import e, inf, nan, pi newaxis: None = None __all__.extend(['e', 'pi', 'nan', 'inf', 'newaxis']) @@ -1464,9 +1479,17 @@ def _check_tensor_all(cond, message=None): # noqa: F811 # Define Storage and Tensor classes ################################################################################ -from ._tensor import Tensor -from torch import storage as storage -from .storage import _StorageBase, TypedStorage, _LegacyStorage, UntypedStorage, _warn_typed_storage_removal +from torch._tensor import Tensor # usort: skip + +# needs to be after torch.Tensor is defined to avoid circular dependencies +from torch import storage as storage # usort: skip +from torch.storage import ( + _LegacyStorage, + _StorageBase, + _warn_typed_storage_removal, + TypedStorage, + UntypedStorage, +) # NOTE: New Storage classes should never be added. When adding a new # dtype, use torch.storage.TypedStorage directly. @@ -1653,16 +1676,22 @@ def _dtype(self): _tensor_classes: Set[Type] = set() # If you edit these imports, please update torch/__init__.py.in as well -from torch import random as random -from .random import set_rng_state, get_rng_state, manual_seed, initial_seed, seed -from torch import serialization as serialization -from .serialization import save, load -from ._tensor_str import set_printoptions +from torch import amp as amp, random as random, serialization as serialization +from torch._tensor_str import set_printoptions +from torch.amp import autocast, GradScaler +from torch.random import get_rng_state, initial_seed, manual_seed, seed, set_rng_state +from torch.serialization import load, save + +# Initializing the extension shadows the built-in python float / int classes; +# store them for later use by SymInt / SymFloat. +py_float = float +py_int = int ################################################################################ # Initialize extension ################################################################################ +# Shared memory manager needs to know the exact location of manager executable def _manager_path(): if _running_with_deploy() or platform.system() == 'Windows': return b"" @@ -1672,16 +1701,8 @@ def _manager_path(): raise RuntimeError("Unable to find torch_shm_manager at " + path) return path.encode('utf-8') -from torch import amp as amp -from torch.amp import autocast, GradScaler - -# Initializing the extension shadows the built-in python float / int classes; -# store them for later use by SymInt / SymFloat. -py_float = float -py_int = int - -# Shared memory manager needs to know the exact location of manager executable _C._initExtension(_manager_path()) + del _manager_path # Appease the type checker: it can't deal with direct setting of globals(). @@ -1734,17 +1755,16 @@ def _manager_path(): # Import TorchDynamo's lazy APIs to avoid circular dependenices ################################################################################ -# needs to be before from .functional import * to avoid circular dependencies -from ._compile import _disable_dynamo +# needs to be before from torch.functional import * to avoid circular dependencies +from torch._compile import _disable_dynamo # usort: skip ################################################################################ # Import interface functions defined in Python ################################################################################ # needs to be after the above ATen bindings so we can overwrite from Python side -from torch import functional as functional -from .functional import * # noqa: F403 - +from torch import functional as functional # usort: skip +from torch.functional import * # usort: skip # noqa: F403 ################################################################################ # Remove unnecessary members @@ -1772,55 +1792,61 @@ def _assert(condition, message): # Use the redundant form so that type checkers know that these are a part of # the public API. The "regular" import lines are there solely for the runtime # side effect of adding to the imported module's members for other users. -from torch import cuda as cuda -from torch import cpu as cpu -from torch import mps as mps -from torch import xpu as xpu -from torch import mtia as mtia -from torch import autograd as autograd -from torch.autograd import ( - no_grad as no_grad, + +# needs to be before import torch.nn as nn to avoid circular dependencies +from torch.autograd import ( # usort: skip enable_grad as enable_grad, - set_grad_enabled as set_grad_enabled, inference_mode as inference_mode, + no_grad as no_grad, + set_grad_enabled as set_grad_enabled, ) -from torch import fft as fft -from torch import futures as futures -from torch import _awaits as _awaits -from torch import nested as nested -from torch import nn as nn -from torch.signal import windows as windows -from torch import optim as optim -from torch import multiprocessing as multiprocessing -from torch import sparse as sparse -from torch import special as special + import torch.utils.backcompat -from torch import jit as jit -from torch import linalg as linalg -from torch import hub as hub -from torch import distributions as distributions -from torch import testing as testing -from torch import backends as backends import torch.utils.data -from torch import __config__ as __config__ -from torch import __future__ as __future__ -from torch import profiler as profiler -from torch import overrides as overrides -from torch import types as types +from torch import ( + __config__ as __config__, + __future__ as __future__, + _awaits as _awaits, + autograd as autograd, + backends as backends, + cpu as cpu, + cuda as cuda, + distributions as distributions, + fft as fft, + futures as futures, + hub as hub, + jit as jit, + linalg as linalg, + mps as mps, + mtia as mtia, + multiprocessing as multiprocessing, + nested as nested, + nn as nn, + optim as optim, + overrides as overrides, + profiler as profiler, + sparse as sparse, + special as special, + testing as testing, + types as types, + xpu as xpu, +) +from torch.signal import windows as windows # Quantized, sparse, AO, etc. should be last to get imported, as nothing # is expected to depend on them. -from torch import ao as ao +from torch import ao as ao # usort: skip + # nn.quant* depends on ao -- so should be after those. +import torch.nn.intrinsic +import torch.nn.qat import torch.nn.quantizable import torch.nn.quantized -import torch.nn.qat -import torch.nn.intrinsic _C._init_names(list(_storage_classes)) # attach docstrings to torch and tensor functions -from . import _torch_docs, _tensor_docs, _storage_docs, _size_docs +from torch import _size_docs, _storage_docs, _tensor_docs, _torch_docs del _torch_docs, _tensor_docs, _storage_docs, _size_docs @@ -1829,17 +1855,18 @@ def compiled_with_cxx11_abi() -> builtins.bool: return _C._GLIBCXX_USE_CXX11_ABI -# Import the ops "namespace" -from torch._ops import ops -from torch._classes import classes import torch._library -# quantization depends on torch.fx +# Import the ops "namespace" +from torch._classes import classes as classes +from torch._ops import ops as ops # usort: skip + +# quantization depends on torch.fx and torch.ops # Import quantization -from torch import quantization as quantization +from torch import quantization as quantization # usort: skip # Import the quasi random sampler -from torch import quasirandom as quasirandom +from torch import quasirandom as quasirandom # usort: skip # If you are seeing this, it means that this call site was not checked if # the memory format could be preserved, and it was switched to old default @@ -1853,15 +1880,13 @@ def compiled_with_cxx11_abi() -> builtins.bool: # Import tools that require fully imported torch (for applying # torch.jit.script as a decorator, for instance): -from ._lobpcg import lobpcg as lobpcg +from torch._lobpcg import lobpcg as lobpcg # These were previously defined in native_functions.yaml and appeared on the # `torch` namespace, but we moved them to c10 dispatch to facilitate custom # class usage. We add these lines here to preserve backward compatibility. -quantized_lstm = torch.ops.aten.quantized_lstm -quantized_gru = torch.ops.aten.quantized_gru - -from torch.utils.dlpack import from_dlpack, to_dlpack +quantized_lstm = ops.aten.quantized_lstm +quantized_gru = ops.aten.quantized_gru # Import experimental masked operations support. See # [RFC-0016](https://github.com/pytorch/rfcs/pull/27) for more @@ -1869,13 +1894,16 @@ def compiled_with_cxx11_abi() -> builtins.bool: from torch import masked as masked # Import removed ops with error message about removal -from ._linalg_utils import ( # type: ignore[misc] - matrix_rank, +from torch._linalg_utils import ( # type: ignore[misc] + _symeig as symeig, eig, - solve, lstsq, + matrix_rank, + solve, ) -from ._linalg_utils import _symeig as symeig # type: ignore[misc] + +from torch.utils.dlpack import from_dlpack, to_dlpack + class _TorchCompileInductorWrapper: compiler_name = "inductor" @@ -2106,7 +2134,8 @@ def fn(model: Callable): from torch import export as export -from torch._higher_order_ops import cond +from torch._higher_order_ops import cond as cond + def _register_device_module(device_type, module): r"""Register an external runtime module of the specific :attr:`device_type` @@ -2126,10 +2155,10 @@ def _register_device_module(device_type, module): sys.modules[torch_module_name] = module # expose return_types -from . import return_types -from . import library +from torch import library as library, return_types as return_types + if not TYPE_CHECKING: - from . import _meta_registrations + from torch import _meta_registrations # Enable CUDA Sanitizer if 'TORCH_CUDA_SANITIZER' in os.environ: @@ -2141,7 +2170,7 @@ def _register_device_module(device_type, module): import torch.fx.experimental.sym_node from torch import func as func -from torch.func import vmap +from torch.func import vmap as vmap # Register MPS specific decomps @@ -2176,9 +2205,7 @@ def registerOp(cls, op_key, full_schema, op_impl, dispatch_key): # Import the following modules during type checking to enable code intelligence features, # such as auto-completion in tools like pylance, even when these modules are not explicitly # imported in user code. - from torch import _dynamo as _dynamo - from torch import _inductor as _inductor - from torch import onnx as onnx + from torch import _dynamo as _dynamo, _inductor as _inductor, onnx as onnx else: _lazy_modules = { @@ -2199,7 +2226,6 @@ def __getattr__(name): # Lazy modules if name in _lazy_modules: - import importlib return importlib.import_module(f".{name}", __name__) raise AttributeError(f"module '{__name__}' has no attribute '{name}'") @@ -2248,5 +2274,5 @@ def _constrain_as_size(symbol, min: Optional[builtins.int] = None, max: Optional torch.sym_constrain_range_for_size(symbol, min=min, max=max) -from . import _logging +from torch import _logging _logging._init_logs() From 1602c7d0c861a4382746ccb18c76d8703a636f4e Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Tue, 11 Jun 2024 15:31:57 -0700 Subject: [PATCH 682/706] [dynamo] Enable some inlining inbuilt nn module tests (#128440) Co-authored-by: Laith Sakka Pull Request resolved: https://github.com/pytorch/pytorch/pull/128440 Approved by: https://github.com/williamwen42, https://github.com/jansel ghstack dependencies: #126578 --- test/dynamo/test_inline_inbuilt_nn_modules.py | 62 +++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 test/dynamo/test_inline_inbuilt_nn_modules.py diff --git a/test/dynamo/test_inline_inbuilt_nn_modules.py b/test/dynamo/test_inline_inbuilt_nn_modules.py new file mode 100644 index 000000000000..f7ba32bc15f3 --- /dev/null +++ b/test/dynamo/test_inline_inbuilt_nn_modules.py @@ -0,0 +1,62 @@ +# Owner(s): ["module: dynamo"] + +from torch._dynamo import config +from torch._dynamo.testing import make_test_cls_with_patches + +try: + from . import ( + test_aot_autograd, + test_functions, + test_higher_order_ops, + test_misc, + test_modules, + # test_repros, + ) +except ImportError: + import test_aot_autograd + import test_functions + import test_higher_order_ops + import test_misc + import test_modules + + +test_classes = {} + + +def make_inline_inbuilt_nn_modules_cls(cls): + suffix = "_inline_inbuilt_nn_modules" + + cls_prefix = "InlineInbuiltNNModules" + + test_class = make_test_cls_with_patches( + cls, + cls_prefix, + suffix, + (config, "inline_inbuilt_nn_modules", True), + xfail_prop="_expected_failure_inline_inbuilt_nn_modules", + ) + + test_classes[test_class.__name__] = test_class + # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING + globals()[test_class.__name__] = test_class + test_class.__module__ = __name__ + return test_class + + +tests = [ + test_misc.MiscTests, + test_functions.FunctionTests, + test_modules.NNModuleTests, + test_higher_order_ops.HigherOrderOpTests, + test_higher_order_ops.FuncTorchHigherOrderOpTests, + test_aot_autograd.AotAutogradFallbackTests, + # test_repros.ReproTests, +] +for test in tests: + make_inline_inbuilt_nn_modules_cls(test) +del test + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() From ebb00a92bd940a7010410068e1130f12b5e0a7a6 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Tue, 11 Jun 2024 21:23:28 -0700 Subject: [PATCH 683/706] [dynamo] Skip freezing expect failure for inlining inbuilt nn modules (#128470) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128470 Approved by: https://github.com/mlazos ghstack dependencies: #126578, #128440 --- test/inductor/test_inductor_freezing.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/inductor/test_inductor_freezing.py b/test/inductor/test_inductor_freezing.py index 7d1688b366c4..4b6c04403002 100644 --- a/test/inductor/test_inductor_freezing.py +++ b/test/inductor/test_inductor_freezing.py @@ -338,6 +338,9 @@ def foo(mod, inp): ).run(code[0]) self.assertEqual(out_eager, out) + # With inlining of inbuilt nn modules, Dynamo traces the innards of inbuilt + # module and does not modify the eager module. + @torch._dynamo.config.patch(inline_inbuilt_nn_modules=False) def test_error_on_eager(self): mod = ConvBN(3, 32, kernel_size=3, stride=2).eval().to(self.device) From 1edcb31d34ef012d828bb9f39a8aef6020f580b2 Mon Sep 17 00:00:00 2001 From: Jiong Gong Date: Tue, 11 Jun 2024 17:36:39 -0700 Subject: [PATCH 684/706] [RELAND][inductor][cpp] bf16/fp16 gemm template computed with fp32 (#128472) reland for https://github.com/pytorch/pytorch/pull/126068 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128472 Approved by: https://github.com/desertfire --- test/inductor/test_cpu_select_algorithm.py | 15 ++- torch/_inductor/codecache.py | 2 +- torch/_inductor/codegen/cpp.py | 3 +- torch/_inductor/codegen/cpp_gemm_template.py | 69 ++++++++--- torch/_inductor/codegen/cpp_micro_gemm.py | 90 ++++++++++---- .../_inductor/codegen/cpp_template_kernel.py | 113 ++++++++++-------- torch/_inductor/codegen/cpp_utils.py | 62 +++++++++- torch/_inductor/ir.py | 8 +- torch/_inductor/mkldnn_lowerings.py | 102 ++++++++++++++-- torch/_inductor/utils.py | 10 +- 10 files changed, 361 insertions(+), 113 deletions(-) diff --git a/test/inductor/test_cpu_select_algorithm.py b/test/inductor/test_cpu_select_algorithm.py index aabd5bd08b15..0147ca8e24cd 100644 --- a/test/inductor/test_cpu_select_algorithm.py +++ b/test/inductor/test_cpu_select_algorithm.py @@ -77,11 +77,11 @@ class TestSelectAlgorithm(TestCase): @torch.no_grad @unittest.skipIf(not TEST_MKL, "Test requires MKL") @parametrize("batch_size", (1, 2, 1000)) - @parametrize("in_features", (1, 2, 1000)) - @parametrize("out_features", (1, 32, 1024)) + @parametrize("in_features", (1, 1000)) + @parametrize("out_features", (1, 1024)) @parametrize("bias", (True, False)) @parametrize("input_3d", (True, False)) - @dtypes(torch.float) + @dtypes(torch.float, torch.bfloat16, torch.half) def test_linear_static_shapes( self, batch_size, in_features, out_features, bias, input_3d, dtype ): @@ -97,7 +97,14 @@ def forward(self, x): mod = M(bias=bias).to(dtype=dtype).eval() B = (2, batch_size) if input_3d else (batch_size,) v = torch.randn(*B, in_features).to(dtype=dtype) - self.common(mod, (v,)) + # For bfloat16 and half, we have to relax the tolerance + # due to the difference associave orders in different + # kernel implementations + atol, rtol = 1e-4, 1e-4 + if dtype == torch.half or dtype == torch.bfloat16: + atol, rtol = 1e-2, 1e-2 + with patch.object(select_algorithm, "VERIFY", dict(atol=atol, rtol=rtol)): + self.common(mod, (v,), atol=atol, rtol=rtol) if ( counters["inductor"]["decompose_mm"] > 0 or counters["inductor"]["decompose_addmm"] > 0 diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index d151e3673474..574511d004a4 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1394,7 +1394,7 @@ class VecAVX2(VecISA): _bit_width = 256 _macro = ["CPU_CAPABILITY_AVX2"] _arch_flags = ( - "-mavx2 -mfma" if not _IS_WINDOWS else "/arch:AVX2" + "-mavx2 -mfma -mf16c" if not _IS_WINDOWS else "/arch:AVX2" ) # TODO: use cflags _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16} diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 749a6e6d4cab..3370001aa429 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -2850,9 +2850,8 @@ def store_reduction(self, name, index, value): return self.simd_vec def __exit__(self, exc_type, exc_val, exc_tb): - assert self._orig_wrapper_code is not None # Restore the wrapper_code - V.graph.wrapper_code = self._orig_wrapper_code + V.graph.wrapper_code = self._orig_wrapper_code # type: ignore[assignment] self.exit_stack.__exit__(exc_type, exc_val, exc_tb) def __enter__(self): diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py index ce45ada78eba..cc8fcb699691 100644 --- a/torch/_inductor/codegen/cpp_gemm_template.py +++ b/torch/_inductor/codegen/cpp_gemm_template.py @@ -148,6 +148,7 @@ def __init__( beta=1, alpha=1, ): + assert layout.dtype in [torch.float, torch.bfloat16, torch.half] super().__init__("packed_gemm", input_nodes, layout) self.beta = beta self.alpha = alpha @@ -213,7 +214,13 @@ def cache_blocking(self) -> GemmBlocking: @staticmethod def add_choices( - choices, layout, input_nodes, beta=1, alpha=1, trans_w=False, input_indices=None + choices, + layout, + input_nodes, + beta=1, + alpha=1, + trans_w=False, + input_indices=None, ): if input_indices is None: input_indices = list(range(len(input_nodes))) @@ -233,28 +240,58 @@ def reorder_and_filter(inputs, layout_or_out): w_idx = input_indices[2] return [inputs[x_idx], inputs[w_idx], inputs[inp_idx]], layout_or_out - def transpose_weight(inputs, layout_or_out): + def maybe_to_dense(inputs, layout_or_out): + new_inputs = list(inputs) + if isinstance(inputs[1], torch.Tensor): + W = inputs[1] + new_inputs[1] = W.to_dense() if W.is_mkldnn else W + return new_inputs, layout_or_out + + def normalize_shapes(inputs, layout_or_out): if not trans_w: return inputs, layout_or_out new_inputs = list(inputs) + X = inputs[0] W = inputs[1] + B = inputs[2] if len(inputs) > 2 else None if isinstance(W, ir.IRNode): - if not isinstance(W, ir.TensorBox): - W = ir.TensorBox(W) - new_inputs[1] = L.permute(W, [1, 0]) - return new_inputs, layout_or_out + if trans_w: + if not isinstance(W, ir.TensorBox): + W = ir.TensorBox(W) + W = L.permute(W, [1, 0]) else: - assert isinstance(W, torch.Tensor) - new_inputs[1] = W.transpose(0, 1) + if trans_w: + assert isinstance(W, torch.Tensor) + W = W.transpose(0, 1) + if B is not None: + if isinstance(B, ir.IRNode): + if not isinstance(B, ir.TensorBox): + B = ir.TensorBox(B) + B = L.expand(B, (X.get_size()[0], B.get_size()[-1])) + else: + assert isinstance(B, torch.Tensor) + B = B.expand(X.shape[0], B.shape[-1]) + new_inputs[1] = W + if B is not None: + new_inputs[2] = B return new_inputs, layout_or_out # TODO(jgong5): decide proper number of threads per problem size num_threads = parallel_num_threads() - new_inputs, _ = transpose_weight(*reorder_and_filter(input_nodes, layout)) + new_inputs, _ = normalize_shapes( + *maybe_to_dense(*reorder_and_filter(input_nodes, layout)) + ) m, n, k, *_ = mm_args(new_inputs[0], new_inputs[1]) micro_gemm = create_micro_gemm( - "micro_gemm", m, n, k, layout.dtype, alpha=alpha, num_threads=num_threads + "micro_gemm", + m, + n, + k, + input_dtype=layout.dtype, + output_dtype=torch.float, + alpha=alpha, + num_threads=num_threads, ) assert micro_gemm is not None _, block_n, _ = micro_gemm.register_blocking @@ -301,7 +338,9 @@ def pack_weight(inputs, layout_or_out): return new_inputs, layout_or_out def preprocessor(inputs, layout): - return pack_weight(*transpose_weight(*reorder_and_filter(inputs, layout))) + return pack_weight( + *normalize_shapes(*maybe_to_dense(*reorder_and_filter(inputs, layout))) + ) def postprocessor(output): if isinstance(output, ir.TensorBox): @@ -316,7 +355,7 @@ def postprocessor(output): W = V.graph.constants[W_node.get_name()] new_input_nodes[1] = W new_input_nodes, _ = pack_weight( - *transpose_weight(new_input_nodes, layout) + *normalize_shapes(*maybe_to_dense(new_input_nodes, layout)) ) W_packed = new_input_nodes[1] W_packed_constant = V.graph.add_tensor_constant(W_packed) @@ -359,8 +398,7 @@ def render( # type: ignore[override] template_buffer = Y Y_is_transposed = False - # TODO(jgong5): support local accumulation - use_local_acc = False + use_local_acc = self.layout.dtype != torch.float if epilogue_nodes: Y = cast(ir.Buffer, epilogue_nodes[-1]) assert Y.get_name() in V.kernel.inplace_update_buffers @@ -374,7 +412,8 @@ def render( # type: ignore[override] self.m, self.n, self.k, - self.layout.dtype, + input_dtype=self.layout.dtype, + output_dtype=torch.float, alpha=self.alpha, num_threads=self.num_threads, ) diff --git a/torch/_inductor/codegen/cpp_micro_gemm.py b/torch/_inductor/codegen/cpp_micro_gemm.py index 65b270285f47..47d6e87e5a70 100644 --- a/torch/_inductor/codegen/cpp_micro_gemm.py +++ b/torch/_inductor/codegen/cpp_micro_gemm.py @@ -60,7 +60,11 @@ def __init__( def get_common_options(self): return { + "torch": torch, "kernel_name": self.name, + "input_dtype": self.input_dtype, + "output_dtype": self.output_dtype, + "compute_dtype": self.compute_dtype, "input_t": DTYPE_TO_CPP[self.input_dtype], "output_t": DTYPE_TO_CPP[self.output_dtype], "compute_t": DTYPE_TO_CPP[self.compute_dtype], @@ -137,6 +141,29 @@ def inner(cls): return inner +def generate_gemm_config( + vec_isa_cls, + register_blockings, + input_dtype=torch.float, + output_dtype=None, + compute_dtype=None, +): + if output_dtype is None: + output_dtype = input_dtype + if compute_dtype is None: + compute_dtype = output_dtype + return [ + CppMicroGemmConfig( + input_dtype, + output_dtype, + compute_dtype, + vec_isa_cls, + GemmBlocking(*blocking), + ) + for blocking in register_blockings + ] + + class CppMicroGemmRef(CppMicroGemm): """ A reference implementation of the CppMicroGemm class with naive C++ code. @@ -171,28 +198,41 @@ def codegen_define(self, kernel: CppTemplateKernel) -> str: @register_micro_gemm( - CppMicroGemmConfig( - torch.float32, torch.float32, torch.float32, VecAVX512, GemmBlocking(8, 48, 1) + *generate_gemm_config( + VecAVX512, [(8, 48, 1), (8, 32, 1), (16, 16, 1)], input_dtype=torch.float ), - CppMicroGemmConfig( - torch.float32, torch.float32, torch.float32, VecAVX512, GemmBlocking(8, 32, 1) + *generate_gemm_config( + VecAVX512, + [(8, 48, 1), (8, 32, 1), (16, 16, 1)], + input_dtype=torch.bfloat16, + output_dtype=torch.float, ), - CppMicroGemmConfig( - torch.float32, torch.float32, torch.float32, VecAVX512, GemmBlocking(16, 16, 1) + *generate_gemm_config( + VecAVX512, + [(8, 48, 1), (8, 32, 1), (16, 16, 1)], + input_dtype=torch.half, + output_dtype=torch.float, ), - CppMicroGemmConfig( - torch.float32, torch.float32, torch.float32, VecAVX2, GemmBlocking(4, 24, 1) + *generate_gemm_config( + VecAVX2, [(4, 24, 1), (4, 16, 1), (8, 8, 1)], input_dtype=torch.float ), - CppMicroGemmConfig( - torch.float32, torch.float32, torch.float32, VecAVX2, GemmBlocking(4, 16, 1) + *generate_gemm_config( + VecAVX2, + [(4, 24, 1), (4, 16, 1), (8, 8, 1)], + input_dtype=torch.bfloat16, + output_dtype=torch.float, ), - CppMicroGemmConfig( - torch.float32, torch.float32, torch.float32, VecAVX2, GemmBlocking(8, 8, 1) + *generate_gemm_config( + VecAVX2, + [(4, 24, 1), (4, 16, 1), (8, 8, 1)], + input_dtype=torch.half, + output_dtype=torch.float, ), ) class CppMicroGemmFP32Vec(CppMicroGemm): """ - This class generates the code for fp32 micro gemm using vec instructions. + This class generates the code for micro gemm using fp32 vec instructions for compute. + It supports input types of torch.float, torch.bfloat16, and torch.half with fp32 output. """ TEMPLATE_ENTRY = r""" @@ -240,22 +280,23 @@ class CppMicroGemmFP32Vec(CppMicroGemm): TEMPLATE_KERNEL = r""" template inline void {{kernel_name}}_kernel( - const float* __restrict__ A, - const float* __restrict__ B, - float* __restrict__ C, + const {{input_t}}* __restrict__ A, + const {{input_t}}* __restrict__ B, + {{output_t}}* __restrict__ C, int64_t K, int64_t lda, int64_t ldb, int64_t ldc ) { - using Vectorized = at::vec::Vectorized; + using Vectorized = at::vec::Vectorized<{{compute_t}}>; + using VectorizedIn = at::vec::Vectorized<{{input_t}}>; constexpr auto VLEN = Vectorized::size(); constexpr auto ROWS = BLOCK_M; constexpr auto COLS = BLOCK_N / VLEN; Vectorized va; - at::vec::VectorizedN vb; - at::vec::VectorizedN vc; + at::vec::VectorizedN<{{compute_t}}, COLS> vb; + at::vec::VectorizedN<{{compute_t}}, ROWS*COLS> vc; auto loadc = [&](auto i) { if constexpr (accum) { @@ -274,14 +315,19 @@ class CppMicroGemmFP32Vec(CppMicroGemm): if constexpr (col == 0) { {%- if alpha != 1 %} - va = Vectorized(A[row * lda + k] * {{alpha}}); + va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + k]) * {{alpha}}); {%- else %} - va = Vectorized(A[row * lda + k]); + va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + k])); {%- endif %} } if constexpr (row == 0) { + {%- if input_dtype == torch.bfloat16 or input_dtype == torch.float16 %} + auto b = VectorizedIn::loadu(B + k * ldb + col * VLEN, VLEN); + vb[col] = at::vec::convert<{{compute_t}}>(b); + {%- else %} vb[col] = Vectorized::loadu(B + k * ldb + col * VLEN); + {%- endif %} } constexpr int idx = row * COLS + col; @@ -350,7 +396,7 @@ def create_from_config(cls, config: CppMicroGemmConfig): if output_dtype is None: output_dtype = input_dtype if compute_dtype is None: - compute_dtype = input_dtype + compute_dtype = output_dtype if num_threads < 0: num_threads = parallel_num_threads() vec_isa = pick_vec_isa() diff --git a/torch/_inductor/codegen/cpp_template_kernel.py b/torch/_inductor/codegen/cpp_template_kernel.py index 34065e412f84..04bc8f1ec3d9 100644 --- a/torch/_inductor/codegen/cpp_template_kernel.py +++ b/torch/_inductor/codegen/cpp_template_kernel.py @@ -14,7 +14,7 @@ from ..virtualized import V from .common import Kernel, OpOverrides from .cpp import CppKernelProxy, KernelGroup -from .cpp_utils import cexpr_index, DTYPE_TO_CPP +from .cpp_utils import cexpr_index, DTYPE_TO_CPP, LocalBufferScope def parse_expr_with_index_symbols(expr): @@ -111,7 +111,13 @@ def index(self, node: ir.Buffer, indices: List[Any]) -> str: indexer = node.layout.as_fixed().make_indexer() index = indexer(parse_expr_with_index_symbols(indices)) index = self.rename_indexing(index) - return f"{self.args.input(node.get_name())}[{cexpr_index(index)}]" + outer_name = node.get_name() + inner_name = ( + outer_name + if outer_name in self.local_buffers + else self.args.input(node.get_name()) + ) + return f"{inner_name}[{cexpr_index(index)}]" def slice_nd(self, node, ranges: List[Tuple[Any, Any]]) -> ir.ReinterpretView: """ @@ -170,6 +176,50 @@ def define_buffer(self, name, sizes: List[Any], dtype=torch.float) -> str: numel = f"{cexpr_index(buf.get_numel())}" return f"auto _{name} = std::make_unique<{ctype}[]>({numel}); auto {name} = _{name}.get();" + def store_pointwise_nodes( + self, + dst: ir.Buffer, + nodes: List[ir.IRNode], + offsets: Optional[List[sympy.Expr]] = None, + reindexer: Optional[Callable[[List[Any]], List[Any]]] = None, + ) -> str: + var_sizes = (tuple(dst.get_size()), ()) + var_ranges = {sympy.Symbol(f"z{i}"): sz for i, sz in enumerate(var_sizes[0])} + if not offsets: + offsets = [sympy.Integer(0)] * len(var_sizes[0]) + assert len(offsets) == len(var_sizes[0]) + output_index = dst.get_layout().make_indexer()(var_ranges.keys()) + kernel_group = KernelGroup() + kernel_group.args = self.args + cpp_kernel_proxy = CppKernelProxy(kernel_group) + bodies = [] + var_sizes_list = [] + for i, node in enumerate(nodes): + output_name = node.get_name() if i < len(nodes) - 1 else dst.get_name() + node = node.data if isinstance(node, ir.ComputedBuffer) else node + assert isinstance(node, ir.Pointwise), node + + def fn(*args): + assert len(args) == 2 + assert len(args[0]) == len(var_sizes[0]) + assert len(args[1]) == 0 + new_args = [arg + offset for arg, offset in zip(args[0], offsets)] # type: ignore[arg-type] + if reindexer is not None: + new_args = reindexer(new_args) + V.ops.store( + output_name, + output_index, + node.make_loader()(new_args).value, + ) + + body = ir.LoopBody(fn, (list(var_ranges.keys()), ()), var_ranges) + bodies.append(body) + var_sizes_list.append(var_sizes) + + cpp_kernel_proxy.codegen_loop_bodies(bodies, var_sizes_list) + kernel_group.finalize_kernel(cpp_kernel_proxy, []) + return kernel_group.loops_code.getvalue() + def store_output( self, dst: ir.Buffer, @@ -197,55 +247,20 @@ def store_output( needed on the indices to `epilogue_nodes` to match the indexing of `dst`. """ assert dst.get_size() == src.get_size() - if epilogue_nodes: - var_sizes = (tuple(dst.get_size()), ()) - var_ranges = { - sympy.Symbol(f"z{i}"): sz for i, sz in enumerate(var_sizes[0]) - } - - # epilogues are all pointwises, hence all indexed the same way as dst - output_index = dst.get_layout().make_indexer()(var_ranges.keys()) - - if not offsets: - offsets = [0] * len(var_sizes[0]) - assert len(offsets) == len(var_sizes[0]) + if offsets: offsets = parse_expr_with_index_symbols(offsets) - - kernel_group = KernelGroup() - kernel_group.args = self.args - cpp_kernel_proxy = CppKernelProxy(kernel_group) - bodies = [] - var_sizes_list = [] - for i, node in enumerate(epilogue_nodes): - assert isinstance(node, ir.ComputedBuffer) - output_name = ( - node.get_name() if i < len(epilogue_nodes) - 1 else dst.get_name() - ) - - def fn(*args): - assert len(args) == 2 - assert len(args[0]) == len(var_sizes[0]) - assert len(args[1]) == 0 - new_args = [arg + offset for arg, offset in zip(args[0], offsets)] # type: ignore[arg-type] - if reindexer is not None: - new_args = reindexer(new_args) - V.ops.store( - output_name, - output_index, - node.data.make_loader()(new_args).value, - ) - - body = ir.LoopBody(fn, (list(var_ranges.keys()), ()), var_ranges) - bodies.append(body) - var_sizes_list.append(var_sizes) - - cpp_kernel_proxy.codegen_loop_bodies(bodies, var_sizes_list) - kernel_group.finalize_kernel(cpp_kernel_proxy, []) - return kernel_group.loops_code.getvalue() + if epilogue_nodes: + return self.store_pointwise_nodes(dst, epilogue_nodes, offsets, reindexer) else: - # TODO(jgong5): support local acc buffer to avoid assertion below - assert dst.get_name() == src.get_name() and dst.layout == src.layout - return "" + if dst.get_name() != src.get_name(): + # src is local + copy = L.copy(dst, src).data.data + with LocalBufferScope(self) as scope: + scope.add_local_buffer(src) + return self.store_pointwise_nodes(dst, [copy]) + else: + assert dst.layout == src.layout + return "" class CppTemplateCaller(ir.ChoiceCaller): diff --git a/torch/_inductor/codegen/cpp_utils.py b/torch/_inductor/codegen/cpp_utils.py index ef7566c8bcba..9534ff8e5d09 100644 --- a/torch/_inductor/codegen/cpp_utils.py +++ b/torch/_inductor/codegen/cpp_utils.py @@ -1,11 +1,16 @@ # mypy: allow-untyped-defs +import contextlib import math from collections import namedtuple +from typing import Dict +from unittest.mock import patch import torch +from .. import ir +from ..virtualized import V -from .common import ExprPrinter +from .common import ExprPrinter, Kernel DTYPE_TO_CPP = { torch.float32: "float", @@ -291,3 +296,58 @@ def value_to_cpp(value, cpp_type): return f"std::numeric_limits<{cpp_type}>::quiet_NaN()" else: return f"static_cast<{cpp_type}>({repr(value)})" + + +class LocalBufferScope: + """ + This class creates a context that helps to generate code involving Inductor IR with + function local buffers. These buffers are constructed during the codegen process and + are used to store intermediate results such as local accumulators. We do not want to + add them to `V.graph` since they are not global and we do not want to add them as + function arguments either. So we patch the codegen processes under this scope to support + these buffers without exposure to the outside world. + """ + + def __init__(self, kernel: Kernel): + self.kernel = kernel + self.exit_stack = contextlib.ExitStack() + self.local_buffers: Dict[str, ir.Buffer] = {} + + def __enter__(self): + self.exit_stack.__enter__() + original_get_dtype = V.graph.get_dtype + + def get_dtype(name): + if name in self.local_buffers: + return self.local_buffers[name].get_dtype() + return original_get_dtype(name) + + self.exit_stack.enter_context(patch.object(V.graph, "get_dtype", get_dtype)) + + original_input = self.kernel.args.input + + def input(name): + if name in self.local_buffers: + return name + return original_input(name) + + self.exit_stack.enter_context(patch.object(self.kernel.args, "input", input)) + + original_output = self.kernel.args.output + + def output(name): + if name in self.local_buffers: + return name + return original_output(name) + + self.exit_stack.enter_context(patch.object(self.kernel.args, "output", output)) + + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.local_buffers.clear() + self.exit_stack.__exit__(exc_type, exc_val, exc_tb) + + def add_local_buffer(self, buffer: ir.Buffer): + assert buffer.get_name() not in self.local_buffers + self.local_buffers[buffer.get_name()] = buffer diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 9255ee94fe83..1edafadd68f0 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -6337,7 +6337,7 @@ def codegen(self, wrapper): ) @classmethod - def create(cls, x, w, b, attr, scalars, algorithm): + def create(cls, x, w, B, attr, scalars, algorithm): x = cls.require_contiguous(cls.realize_input(x)) w = cls.require_contiguous(cls.realize_input(w)) @@ -6345,9 +6345,9 @@ def create(cls, x, w, b, attr, scalars, algorithm): oc, ic = w.get_size() inputs = [x, w] constant_args = [attr, scalars if scalars else [-1], algorithm] - if b is not None: - b = cls.require_contiguous(cls.realize_input(b)) - inputs.append(b) + if B is not None: + B = cls.require_contiguous(cls.realize_input(B)) + inputs.append(B) else: constant_args.insert(0, None) diff --git a/torch/_inductor/mkldnn_lowerings.py b/torch/_inductor/mkldnn_lowerings.py index f1d82dcf7d60..721c54385d33 100644 --- a/torch/_inductor/mkldnn_lowerings.py +++ b/torch/_inductor/mkldnn_lowerings.py @@ -14,14 +14,25 @@ permute, register_lowering, to_dtype, + view, +) +from .select_algorithm import ( + autotune_select_algorithm, + ChoiceCaller, + ExternKernelChoice, ) -from .select_algorithm import autotune_select_algorithm, ExternKernelChoice from .utils import use_aten_gemm_kernels, use_cpp_packed_gemm_template, use_max_autotune from .virtualized import V def register_onednn_fusion_ops(): if torch._C._has_mkldnn: + aten_mkldnn_linear_unary = ExternKernelChoice( + torch.ops.mkldnn._linear_pointwise, + "mkldnn::_linear_pointwise", + has_out_variant=False, + kernel_creator=ir.LinearUnary.create, + ) cpu_needs_realized_inputs = [ torch.ops.mkldnn._convolution_pointwise, torch.ops.mkldnn._convolution_pointwise_, @@ -129,11 +140,77 @@ def convolution_binary_inplace( @register_lowering(torch.ops.mkldnn._linear_pointwise) def linear_unary( - x: TensorBox, w: TensorBox, b: TensorBox, attr, scalars, algorithm + x: TensorBox, + w: TensorBox, + b: TensorBox, + attr, + scalars, + algorithm, + layout=None, ): - return TensorBox.create( - ir.LinearUnary.create(x, w, b, attr, scalars, algorithm) + x_size = x.get_size() + if len(x_size) > 2: + # GEMM template needs 2D input, normalize input shape here + x = view(x, [-1, x_size[-1]]) + choices: List[ChoiceCaller] = [] + if len(choices) == 0 or use_aten_gemm_kernels(): + choices.append( + aten_mkldnn_linear_unary.bind( + (x, w), + layout, + B=None, + attr=attr, + scalars=scalars, + algorithm=algorithm, + ) + if b is None + else aten_mkldnn_linear_unary.bind( + (x, w, b), + layout, + attr=attr, + scalars=scalars, + algorithm=algorithm, + ) + ) + if use_max_autotune(): + transposed_w = permute(w, [1, 0]) + *_, layout, x, transposed_w = mm_args(x, transposed_w, layout=layout) + if b is not None: + b = ir.ExternKernel.realize_input(b) + # TODO(jgong5): support epilogue fusion + if ( + use_cpp_packed_gemm_template(layout, x, transposed_w) + and attr == "none" + ): + if b is None: + CppPackedGemmTemplate.add_choices( + choices, + layout, + [x, w], + trans_w=True, + ) + else: + CppPackedGemmTemplate.add_choices( + choices, + layout, + [x, w, b], + trans_w=True, + input_indices=[2, 0, 1], + ) + assert w.get_name() in V.graph.constants + input_gen_fns = { + 1: lambda x: V.graph.constants[x.get_name()], + } + result = autotune_select_algorithm( + "linear_unary", + choices, + [x, w] if b is None else [x, w, b], + layout, + input_gen_fns=input_gen_fns, ) + if len(x_size) > 2: + result = view(result, (*x_size[:-1], result.get_size()[-1])) + return result @register_lowering(torch.ops.mkldnn._linear_pointwise.binary) def linear_binary(x: TensorBox, y: TensorBox, w: TensorBox, b: TensorBox, attr): @@ -434,15 +511,7 @@ def mkl_packed_linear( *, layout=None, ): - choices = ( - [ - aten_mkl_linear.bind( - (x, packed_w, orig_w), layout, B=None, batch_size=batch_size - ) - ] - if use_aten_gemm_kernels() - else [] - ) + choices: List[ChoiceCaller] = [] if use_max_autotune(): transposed_w = permute(orig_w, [1, 0]) *_, layout, x, transposed_w = mm_args( @@ -457,6 +526,13 @@ def mkl_packed_linear( input_indices=[0, 2], ) + if len(choices) == 0 or use_aten_gemm_kernels(): + choices.append( + aten_mkl_linear.bind( + (x, packed_w, orig_w), layout, B=None, batch_size=batch_size + ) + ) + assert packed_w.get_name() in V.graph.constants assert orig_w.get_name() in V.graph.constants # packed_w is a mkldnn tensor which we can't generate directly diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index ea3826855f59..f59fa34fe9c0 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1062,7 +1062,7 @@ def use_cpp_packed_gemm_template(layout, mat1, mat2): if not config.cpp.weight_prepack: return False - layout_dtypes = [torch.float32] + layout_dtypes = [torch.float32, torch.bfloat16, torch.half] m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2) # TODO(jgong5): support dynamic shapes for n or k if has_free_symbols((n, k)): @@ -1070,7 +1070,13 @@ def use_cpp_packed_gemm_template(layout, mat1, mat2): if isinstance(mat2, ir.BaseView): mat2 = mat2.unwrap_view() micro_gemm = create_micro_gemm( - "micro_gemm", m, n, k, layout.dtype, num_threads=parallel_num_threads() + "micro_gemm", + m, + n, + k, + input_dtype=layout.dtype, + output_dtype=torch.float, + num_threads=parallel_num_threads(), ) # TODO(jgong5): support n % n_block_size != 0 return ( From 2386045e4f023466bbfbddddd83c89e3248a2a58 Mon Sep 17 00:00:00 2001 From: Tom Ritchford Date: Tue, 11 Jun 2024 12:54:06 +0000 Subject: [PATCH 685/706] Add OpInfo entry for alias_copy (#127232) (#128142) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128142 Approved by: https://github.com/lezcano --- .../ATen/functorch/BatchRulesDecompositions.cpp | 1 + test/distributed/_tensor/test_dtensor_ops.py | 1 + .../HasDecompTest.test_has_decomposition.expect | 2 -- test/functorch/test_vmap_registrations.py | 1 + test/onnx/test_fx_op_consistency.py | 4 ++++ test/test_mps.py | 1 + tools/autograd/gen_variable_type.py | 1 + torch/_decomp/__init__.py | 1 + torch/_inductor/exc.py | 2 +- torch/_refs/__init__.py | 4 ++++ .../_internal/common_methods_invocations.py | 15 +++++++++++++++ 11 files changed, 30 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp index 3e064d6c39dc..a0007aa18a00 100644 --- a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp +++ b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp @@ -324,6 +324,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) { OP_DECOMPOSE(type_as); OP_DECOMPOSE(linalg_diagonal); OP_DECOMPOSE(diagonal_copy); + OP_DECOMPOSE(alias_copy); m.impl("pad", native::pad_symint); m.impl("_pad_circular", native::_pad_circular_symint); OP_DECOMPOSE(swapdims_); diff --git a/test/distributed/_tensor/test_dtensor_ops.py b/test/distributed/_tensor/test_dtensor_ops.py index 83f0bb875167..07f8bfedc615 100644 --- a/test/distributed/_tensor/test_dtensor_ops.py +++ b/test/distributed/_tensor/test_dtensor_ops.py @@ -102,6 +102,7 @@ def wrapped(fn): xfail("addr"), xfail("all"), xfail("allclose"), + xfail("alias_copy"), xfail("amax"), xfail("amin"), xfail("aminmax"), diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index ad9cf07d7550..eeee3685e1fb 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -647,8 +647,6 @@ aten::adaptive_max_pool3d_backward.grad_input aten::addbmm aten::addbmm.out aten::addr_ -aten::alias_copy -aten::alias_copy.out aten::allclose aten::angle aten::angle.out diff --git a/test/functorch/test_vmap_registrations.py b/test/functorch/test_vmap_registrations.py index 967152945af5..737927a60f80 100644 --- a/test/functorch/test_vmap_registrations.py +++ b/test/functorch/test_vmap_registrations.py @@ -25,6 +25,7 @@ } xfail_functorch_batched_decomposition = { + "aten::alias_copy", "aten::diagonal_copy", "aten::is_same_size", "aten::unfold_copy", diff --git a/test/onnx/test_fx_op_consistency.py b/test/onnx/test_fx_op_consistency.py index e72c4206d578..6d675d446030 100644 --- a/test/onnx/test_fx_op_consistency.py +++ b/test/onnx/test_fx_op_consistency.py @@ -218,6 +218,10 @@ def skip_torchlib_forward_compatibility( dtypes=onnx_test_common.COMPLEX_TYPES, reason=onnx_test_common.reason_dynamo_does_not_support("Addr", "complex64") ), + xfail( + "alias_copy", + reason="OnnxExporterError: Failed to export model", + ), xfail( "allclose", reason=onnx_test_common.reason_dynamo_does_not_support("Allclose") diff --git a/test/test_mps.py b/test/test_mps.py index 93437fd5509d..00fc5c01c78d 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -243,6 +243,7 @@ def mps_ops_modifier(ops): '__getitem__', 'abs', 'add', + 'alias_copy', 'argwhere', 'atleast_1d', 'atleast_2d', diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index b9651ea2da80..6abb13d244e9 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -305,6 +305,7 @@ "linalg_eig", "diagonal_copy", "diagonal_scatter", + "alias_copy", "select_backward", "diagonal_backward", "slice_backward", diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py index e0c7e5b6f49d..7674e5f466a8 100644 --- a/torch/_decomp/__init__.py +++ b/torch/_decomp/__init__.py @@ -261,6 +261,7 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]: aten.addcmul_, aten.addr, aten.affine_grid_generator, + aten.alias_copy, aten.all, aten.aminmax, aten.arange.default, diff --git a/torch/_inductor/exc.py b/torch/_inductor/exc.py index 27dcc6d8ef2d..8a172d8c29b1 100644 --- a/torch/_inductor/exc.py +++ b/torch/_inductor/exc.py @@ -46,7 +46,7 @@ def __init__(self, target, args, kwargs): There is a decomposition available for {target} in torch._decomp.get_decompositions(). Please add this operator to the - `decompositions` list in torch._inductor.decompositions + `decompositions` list in torch._inductor.decomposition """ ) ) diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index db1f2a99d3d4..e0157368c62c 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -233,6 +233,7 @@ # View & Shape Ops # "alias", + "alias_copy", "atleast_1d", "atleast_2d", "atleast_3d", @@ -4462,6 +4463,9 @@ def alias(a: TensorLikeType) -> TensorLikeType: return prims.view_of(a) +alias_copy = _make_copy_from_view(alias) + + @register_decomposition(aten.transpose) def transpose(a: TensorLikeType, dim0: int, dim1: int) -> TensorLikeType: _dim0, _dim1 = utils.canonicalize_dims(a.ndim, (dim0, dim1)) # type: ignore[misc] diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index edacc3c4023e..5c32d1a11aff 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -11588,6 +11588,12 @@ def reference_flatten(input, start_dim=0, end_dim=-1): out_shape = in_shape[:start_dim] + (flatten_bit_dim,) + in_shape[end_dim + 1:] return np.reshape(input, out_shape) + +def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): + yield SampleInput(make_tensor((S,), dtype=dtype, device=device, requires_grad=requires_grad)) + yield SampleInput(make_tensor((), dtype=dtype, device=device, requires_grad=requires_grad)) + + # Operator database (sorted alphabetically) op_db: List[OpInfo] = [ UnaryUfuncInfo('abs', @@ -13091,6 +13097,11 @@ def reference_flatten(input, start_dim=0, end_dim=-1): supports_forward_ad=True, supports_fwgrad_bwgrad=True, sample_inputs_func=sample_inputs_diagonal_scatter), + OpInfo('alias_copy', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + sample_inputs_func=sample_inputs_alias_copy, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True), BinaryUfuncInfo('eq', ref=np.equal, dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), @@ -23228,6 +23239,10 @@ def reference_flatten(input, start_dim=0, end_dim=-1): # # View & Shape OpInfos # + PythonRefInfo( + "_refs.alias_copy", + torch_opinfo_name="alias_copy", + ), PythonRefInfo( "_refs.atleast_1d", torch_opinfo_name="atleast_1d", From 26433b86dea3b697f8d1132e2f452da9c674fdee Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Mon, 10 Jun 2024 19:16:55 +0000 Subject: [PATCH 686/706] [BE][Easy] sort `__all__` in `torch/__init__.py` (#127709) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127709 Approved by: https://github.com/ezyang ghstack dependencies: #127703, #127708 --- torch/__init__.py | 211 +++++++++++++++++++++++++++++----------------- 1 file changed, 134 insertions(+), 77 deletions(-) diff --git a/torch/__init__.py b/torch/__init__.py index 156f35939d02..5936d950e1f7 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -52,28 +52,80 @@ def _running_with_deploy(): __all__ = [ - 'typename', 'is_tensor', 'is_storage', - 'set_default_tensor_type', 'set_default_device', 'get_default_device', - 'set_rng_state', 'get_rng_state', 'manual_seed', 'initial_seed', 'seed', - 'save', 'load', 'set_printoptions', 'chunk', 'split', 'stack', 'matmul', - 'no_grad', 'enable_grad', 'rand', 'randn', 'inference_mode', - 'DoubleStorage', 'FloatStorage', 'LongStorage', 'IntStorage', - 'ShortStorage', 'CharStorage', 'ByteStorage', 'BoolStorage', - 'TypedStorage', 'UntypedStorage', - 'DoubleTensor', 'FloatTensor', 'LongTensor', 'IntTensor', - 'ShortTensor', 'CharTensor', 'ByteTensor', 'BoolTensor', 'Tensor', - 'lobpcg', 'use_deterministic_algorithms', - 'are_deterministic_algorithms_enabled', - 'is_deterministic_algorithms_warn_only_enabled', - 'set_deterministic_debug_mode', 'get_deterministic_debug_mode', - 'set_float32_matmul_precision', 'get_float32_matmul_precision', - 'set_warn_always', 'is_warn_always_enabled', 'SymInt', 'SymFloat', - 'SymBool', 'sym_not', 'unravel_index', - 'sym_int', 'sym_float', 'sym_max', 'sym_min', 'sym_ite', 'compile', 'vmap', - 'export', 'autocast', 'cond', 'GradScaler', - 'get_device_module', + "BoolStorage", + "BoolTensor", + "ByteStorage", + "ByteTensor", + "CharStorage", + "CharTensor", + "DoubleStorage", + "DoubleTensor", + "FloatStorage", + "FloatTensor", + "GradScaler", + "IntStorage", + "IntTensor", + "LongStorage", + "LongTensor", + "ShortStorage", + "ShortTensor", + "SymBool", + "SymFloat", + "SymInt", + "Tensor", + "TypedStorage", + "UntypedStorage", + "are_deterministic_algorithms_enabled", + "autocast", + "chunk", + "compile", + "cond", + "enable_grad", + "export", + "get_default_device", + "get_deterministic_debug_mode", + "get_device_module", + "get_float32_matmul_precision", + "get_rng_state", + "inference_mode", + "initial_seed", + "is_deterministic_algorithms_warn_only_enabled", + "is_storage", + "is_tensor", + "is_warn_always_enabled", + "load", + "lobpcg", + "manual_seed", + "matmul", + "no_grad", + "rand", + "randn", + "save", + "seed", + "set_default_device", + "set_default_tensor_type", + "set_deterministic_debug_mode", + "set_float32_matmul_precision", + "set_printoptions", + "set_rng_state", + "set_warn_always", + "split", + "stack", + "sym_float", + "sym_int", + "sym_ite", + "sym_max", + "sym_min", + "sym_not", + "typename", + "unravel_index", + "use_deterministic_algorithms", + "vmap", ] +# Please keep this list sorted +assert __all__ == sorted(__all__) + ################################################################################ # Load the extension module ################################################################################ @@ -133,8 +185,14 @@ def _load_dll_libraries(): ctypes.CDLL('msvcp140.dll') ctypes.CDLL('vcruntime140_1.dll') except OSError: - print('''Microsoft Visual C++ Redistributable is not installed, this may lead to the DLL load failure. - It can be downloaded at https://aka.ms/vs/16/release/vc_redist.x64.exe''') + print( + textwrap.dedent( + """ + Microsoft Visual C++ Redistributable is not installed, this may lead to the DLL load failure. + It can be downloaded at https://aka.ms/vs/16/release/vc_redist.x64.exe + """ + ).strip() + ) dlls = glob.glob(os.path.join(th_dll_path, '*.dll')) path_patched = False @@ -398,55 +456,55 @@ def __rpow__(self, other): return sym_float(self).__rpow__(sym_float(other)) def __eq__(self, other: object) -> builtins.bool: - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __lt__(self, other) -> builtins.bool: - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __gt__(self, other) -> builtins.bool: - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __le__(self, other) -> builtins.bool: - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __ge__(self, other) -> builtins.bool: - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __add__(self, other) -> "SymInt": - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __mul__(self, other) -> "SymInt": - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __pow_by_natural__(self, other) -> "SymInt": - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __rpow_by_natural__(self, other) -> "SymInt": - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __int_truediv__(self, other) -> "SymFloat": - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __rint_truediv__(self, other) -> "SymFloat": - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __int_floordiv__(self, other) -> "SymFloat": - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __rint_floordiv__(self, other) -> "SymFloat": - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __sym_max__(self, other): - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __sym_min__(self, other): - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __sym_float__(self): - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __neg__(self): - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __repr__(self): return str(self.node) @@ -510,47 +568,47 @@ def __rpow__(self, other): # Magic methods installed by torch.fx.experimental.sym_node def __eq__(self, other: object) -> builtins.bool: - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __lt__(self, other) -> builtins.bool: - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __gt__(self, other) -> builtins.bool: - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __le__(self, other) -> builtins.bool: - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __ge__(self, other) -> builtins.bool: - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __float_pow__(self, other) -> "SymFloat": - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __rfloat_pow__(self, other) -> "SymFloat": - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __float_truediv__(self, other) -> "SymFloat": - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __rfloat_truediv__(self, other) -> "SymFloat": - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __trunc__(self): - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __sym_max__(self, other): - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __sym_min__(self, other): - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __sym_int__(self): - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def is_integer(self): """Return True if the float is an integer.""" - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __repr__(self): return self.node.str() @@ -578,10 +636,10 @@ def __int__(self): # Magic methods installed by torch.fx.experimental.sym_node def __and__(self, other) -> "SymBool": - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __or__(self, other) -> "SymBool": - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") # We very carefully define __sym_not__, and not a number of other # plausible alternatives: @@ -601,13 +659,13 @@ def __or__(self, other) -> "SymBool": # so we reuse the conventional operators there for readability. # def __sym_not__(self) -> "SymBool": - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __sym_ite__(self, then_val, else_val): - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __eq__(self, other) -> builtins.bool: - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __repr__(self): return str(self.node) @@ -646,7 +704,7 @@ def sym_float(a): return a elif hasattr(a, '__sym_float__'): return a.__sym_float__() - return py_float(a) # type: ignore[operator] + return builtins.float(a) # type: ignore[operator] def sym_int(a): @@ -661,7 +719,7 @@ def sym_int(a): return a elif isinstance(a, SymFloat): return math.trunc(a) - return py_int(a) # type: ignore[operator] + return builtins.int(a) # type: ignore[operator] def sym_max(a, b): """ @@ -744,18 +802,22 @@ def sym_ite(b, t, f): # The __file__ check only works for Python 3.7 and above. if _C_for_compiled_check.__file__ is None: - raise ImportError(textwrap.dedent(''' - Failed to load PyTorch C extensions: - It appears that PyTorch has loaded the `torch/_C` folder - of the PyTorch repository rather than the C extensions which - are expected in the `torch._C` namespace. This can occur when - using the `install` workflow. e.g. - $ python setup.py install && python -c "import torch" - - This error can generally be solved using the `develop` workflow - $ python setup.py develop && python -c "import torch" # This should succeed - or by running Python from a different directory. - ''').strip()) from None + raise ImportError( + textwrap.dedent( + """ + Failed to load PyTorch C extensions: + It appears that PyTorch has loaded the `torch/_C` folder + of the PyTorch repository rather than the C extensions which + are expected in the `torch._C` namespace. This can occur when + using the `install` workflow. e.g. + $ python setup.py install && python -c "import torch" + + This error can generally be solved using the `develop` workflow + $ python setup.py develop && python -c "import torch" # This should succeed + or by running Python from a different directory. + """ + ).strip() + ) from None raise # If __file__ is not None the cause is unknown, so just re-raise. # The torch._C submodule is already loaded via `from torch._C import *` above @@ -1682,11 +1744,6 @@ def _dtype(self): from torch.random import get_rng_state, initial_seed, manual_seed, seed, set_rng_state from torch.serialization import load, save -# Initializing the extension shadows the built-in python float / int classes; -# store them for later use by SymInt / SymFloat. -py_float = float -py_int = int - ################################################################################ # Initialize extension ################################################################################ From 46a35a1ed4fe0b13a2730043318ce77f1aa02542 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Mon, 10 Jun 2024 19:16:56 +0000 Subject: [PATCH 687/706] [BE] enable UFMT for `torch/__init__.py` (#127710) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127710 Approved by: https://github.com/ezyang ghstack dependencies: #127703, #127708, #127709 --- .lintrunner.toml | 1 - torch/__init__.py | 462 +++++++++++++++++++++++++++++++--------------- 2 files changed, 316 insertions(+), 147 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index 5ccab63f487e..92a7fc0b1d8e 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -999,7 +999,6 @@ command = [ ] exclude_patterns = [ 'tools/gen_vulkan_spv.py', - 'torch/__init__.py', # Skip this file to format because it's part of the public API # We don't care too much about files in this directory, don't enforce # formatting on them 'caffe2/**/*.py', diff --git a/torch/__init__.py b/torch/__init__.py index 5936d950e1f7..aa68247ed3a8 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -130,47 +130,67 @@ def _running_with_deploy(): # Load the extension module ################################################################################ -if sys.platform == 'win32': +if sys.platform == "win32": def _load_dll_libraries(): import sysconfig from torch.version import cuda as cuda_version - pfiles_path = os.getenv('ProgramFiles', r'C:\Program Files') - py_dll_path = os.path.join(sys.exec_prefix, 'Library', 'bin') - th_dll_path = os.path.join(os.path.dirname(__file__), 'lib') - usebase_path = os.path.join(sysconfig.get_config_var("userbase"), 'Library', 'bin') + pfiles_path = os.getenv("ProgramFiles", r"C:\Program Files") + py_dll_path = os.path.join(sys.exec_prefix, "Library", "bin") + th_dll_path = os.path.join(os.path.dirname(__file__), "lib") + usebase_path = os.path.join( + sysconfig.get_config_var("userbase"), "Library", "bin" + ) # When users create a virtualenv that inherits the base environment, # we will need to add the corresponding library directory into # DLL search directories. Otherwise, it will rely on `PATH` which # is dependent on user settings. if sys.exec_prefix != sys.base_exec_prefix: - base_py_dll_path = os.path.join(sys.base_exec_prefix, 'Library', 'bin') + base_py_dll_path = os.path.join(sys.base_exec_prefix, "Library", "bin") else: - base_py_dll_path = '' + base_py_dll_path = "" - dll_paths = [p for p in (th_dll_path, py_dll_path, base_py_dll_path, usebase_path) if os.path.exists(p)] + dll_paths = [ + p + for p in (th_dll_path, py_dll_path, base_py_dll_path, usebase_path) + if os.path.exists(p) + ] - if not builtins.any(os.path.exists(os.path.join(p, 'nvToolsExt64_1.dll')) for p in dll_paths): + if not builtins.any( + os.path.exists(os.path.join(p, "nvToolsExt64_1.dll")) for p in dll_paths + ): nvtoolsext_dll_path = os.path.join( - os.getenv('NVTOOLSEXT_PATH', os.path.join(pfiles_path, 'NVIDIA Corporation', 'NvToolsExt')), 'bin', 'x64') + os.getenv( + "NVTOOLSEXT_PATH", + os.path.join(pfiles_path, "NVIDIA Corporation", "NvToolsExt"), + ), + "bin", + "x64", + ) else: - nvtoolsext_dll_path = '' - - if cuda_version and builtins.all(not glob.glob(os.path.join(p, 'cudart64*.dll')) for p in dll_paths): - cuda_version_1 = cuda_version.replace('.', '_') - cuda_path_var = 'CUDA_PATH_V' + cuda_version_1 - default_path = os.path.join(pfiles_path, 'NVIDIA GPU Computing Toolkit', 'CUDA', 'v' + cuda_version) - cuda_path = os.path.join(os.getenv(cuda_path_var, default_path), 'bin') + nvtoolsext_dll_path = "" + + if cuda_version and builtins.all( + not glob.glob(os.path.join(p, "cudart64*.dll")) for p in dll_paths + ): + cuda_version_1 = cuda_version.replace(".", "_") + cuda_path_var = "CUDA_PATH_V" + cuda_version_1 + default_path = os.path.join( + pfiles_path, "NVIDIA GPU Computing Toolkit", "CUDA", f"v{cuda_version}" + ) + cuda_path = os.path.join(os.getenv(cuda_path_var, default_path), "bin") else: - cuda_path = '' + cuda_path = "" - dll_paths.extend(p for p in (nvtoolsext_dll_path, cuda_path) if os.path.exists(p)) + dll_paths.extend( + p for p in (nvtoolsext_dll_path, cuda_path) if os.path.exists(p) + ) - kernel32 = ctypes.WinDLL('kernel32.dll', use_last_error=True) - with_load_library_flags = hasattr(kernel32, 'AddDllDirectory') + kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True) + with_load_library_flags = hasattr(kernel32, "AddDllDirectory") prev_error_mode = kernel32.SetErrorMode(0x0001) kernel32.LoadLibraryW.restype = ctypes.c_void_p @@ -181,9 +201,9 @@ def _load_dll_libraries(): os.add_dll_directory(dll_path) try: - ctypes.CDLL('vcruntime140.dll') - ctypes.CDLL('msvcp140.dll') - ctypes.CDLL('vcruntime140_1.dll') + ctypes.CDLL("vcruntime140.dll") + ctypes.CDLL("msvcp140.dll") + ctypes.CDLL("vcruntime140_1.dll") except OSError: print( textwrap.dedent( @@ -194,7 +214,7 @@ def _load_dll_libraries(): ).strip() ) - dlls = glob.glob(os.path.join(th_dll_path, '*.dll')) + dlls = glob.glob(os.path.join(th_dll_path, "*.dll")) path_patched = False for dll in dlls: is_loaded = False @@ -203,18 +223,22 @@ def _load_dll_libraries(): last_error = ctypes.get_last_error() if res is None and last_error != 126: err = ctypes.WinError(last_error) - err.strerror += f' Error loading "{dll}" or one of its dependencies.' + err.strerror += ( + f' Error loading "{dll}" or one of its dependencies.' + ) raise err elif res is not None: is_loaded = True if not is_loaded: if not path_patched: - os.environ['PATH'] = ';'.join(dll_paths + [os.environ['PATH']]) + os.environ["PATH"] = ";".join(dll_paths + [os.environ["PATH"]]) path_patched = True res = kernel32.LoadLibraryW(dll) if res is None: err = ctypes.WinError(ctypes.get_last_error()) - err.strerror += f' Error loading "{dll}" or one of its dependencies.' + err.strerror += ( + f' Error loading "{dll}" or one of its dependencies.' + ) raise err kernel32.SetErrorMode(prev_error_mode) @@ -226,14 +250,16 @@ def _load_dll_libraries(): def _preload_cuda_deps(lib_folder, lib_name): """Preloads cuda deps if they could not be found otherwise.""" # Should only be called on Linux if default path resolution have failed - assert platform.system() == 'Linux', 'Should only be called on Linux' + assert platform.system() == "Linux", "Should only be called on Linux" lib_path = None for path in sys.path: - nvidia_path = os.path.join(path, 'nvidia') + nvidia_path = os.path.join(path, "nvidia") if not os.path.exists(nvidia_path): continue - candidate_lib_paths = glob.glob(os.path.join(nvidia_path, lib_folder, 'lib', lib_name)) + candidate_lib_paths = glob.glob( + os.path.join(nvidia_path, lib_folder, "lib", lib_name) + ) if candidate_lib_paths and not lib_path: lib_path = candidate_lib_paths[0] if lib_path: @@ -242,9 +268,9 @@ def _preload_cuda_deps(lib_folder, lib_name): raise ValueError(f"{lib_name} not found in the system path {sys.path}") ctypes.CDLL(lib_path) + # See Note [Global dependencies] def _load_global_deps() -> None: - LIBTORCH_PKG_NAME = "libtorchsplit" def find_package_path(package_name): @@ -261,16 +287,10 @@ def find_package_path(package_name): return None def load_shared_libraries(library_path): - lib_dir = os.path.join(library_path, 'lib') + lib_dir = os.path.join(library_path, "lib") if not os.path.exists(lib_dir): return - # Determine the file extension based on the platform - if platform.system() == 'Darwin': - lib_ext = '.dylib' - else: - lib_ext = '.so' - # Find all shared library files with the appropriate extension library_files = [f for f in os.listdir(lib_dir) if f.endswith(lib_ext)] if not library_files: @@ -283,37 +303,41 @@ def load_shared_libraries(library_path): except OSError as err: print(f"Failed to load {lib_path}: {err}") - if _running_with_deploy() or platform.system() == 'Windows': + if _running_with_deploy() or platform.system() == "Windows": return - lib_name = 'libtorch_global_deps' + ('.dylib' if platform.system() == 'Darwin' else '.so') + # Determine the file extension based on the platform + lib_ext = ".dylib" if platform.system() == "Darwin" else ".so" + lib_name = f"libtorch_global_deps{lib_ext}" here = os.path.abspath(__file__) - global_deps_lib_path = os.path.join(os.path.dirname(here), 'lib', lib_name) + global_deps_lib_path = os.path.join(os.path.dirname(here), "lib", lib_name) split_build_lib_name = LIBTORCH_PKG_NAME library_path = find_package_path(split_build_lib_name) if library_path: - global_deps_lib_path = os.path.join(library_path, 'lib', lib_name) + global_deps_lib_path = os.path.join(library_path, "lib", lib_name) try: ctypes.CDLL(global_deps_lib_path, mode=ctypes.RTLD_GLOBAL) except OSError as err: # Can only happen for wheel with cuda libs as PYPI deps # As PyTorch is not purelib, but nvidia-*-cu12 is cuda_libs: Dict[str, str] = { - 'cublas': 'libcublas.so.*[0-9]', - 'cudnn': 'libcudnn.so.*[0-9]', - 'cuda_nvrtc': 'libnvrtc.so.*[0-9]', - 'cuda_runtime': 'libcudart.so.*[0-9]', - 'cuda_cupti': 'libcupti.so.*[0-9]', - 'cufft': 'libcufft.so.*[0-9]', - 'curand': 'libcurand.so.*[0-9]', - 'cusolver': 'libcusolver.so.*[0-9]', - 'cusparse': 'libcusparse.so.*[0-9]', - 'nccl': 'libnccl.so.*[0-9]', - 'nvtx': 'libnvToolsExt.so.*[0-9]', + "cublas": "libcublas.so.*[0-9]", + "cudnn": "libcudnn.so.*[0-9]", + "cuda_nvrtc": "libnvrtc.so.*[0-9]", + "cuda_runtime": "libcudart.so.*[0-9]", + "cuda_cupti": "libcupti.so.*[0-9]", + "cufft": "libcufft.so.*[0-9]", + "curand": "libcurand.so.*[0-9]", + "cusolver": "libcusolver.so.*[0-9]", + "cusparse": "libcusparse.so.*[0-9]", + "nccl": "libnccl.so.*[0-9]", + "nvtx": "libnvToolsExt.so.*[0-9]", } - is_cuda_lib_err = [lib for lib in cuda_libs.values() if lib.split('.')[0] in err.args[0]] + is_cuda_lib_err = [ + lib for lib in cuda_libs.values() if lib.split(".")[0] in err.args[0] + ] if not is_cuda_lib_err: raise err for lib_folder, lib_name in cuda_libs.items(): @@ -324,8 +348,10 @@ def load_shared_libraries(library_path): # loading libtorch_global_deps first due its special logic load_shared_libraries(library_path) -if (USE_RTLD_GLOBAL_WITH_LIBTORCH or os.getenv('TORCH_USE_RTLD_GLOBAL')) and \ - (_running_with_deploy() or platform.system() != 'Windows'): + +if (USE_RTLD_GLOBAL_WITH_LIBTORCH or os.getenv("TORCH_USE_RTLD_GLOBAL")) and ( + _running_with_deploy() or platform.system() != "Windows" +): # Do it the hard way. You might want to load libtorch with RTLD_GLOBAL in a # few circumstances: # @@ -344,7 +370,9 @@ def load_shared_libraries(library_path): # old_flags = sys.getdlopenflags() sys.setdlopenflags(os.RTLD_GLOBAL | os.RTLD_LAZY) + from torch._C import * # noqa: F403 + sys.setdlopenflags(old_flags) del old_flags @@ -516,6 +544,7 @@ def __hash__(self) -> builtins.int: # We could support constant SymInts as well, but not doing it for now raise TypeError("unhashable type: non-nested SymInt") + class SymFloat: """ Like an float (including magic methods), but redirects all operations on the @@ -613,6 +642,7 @@ def is_integer(self): def __repr__(self): return self.node.str() + class SymBool: """ Like an bool (including magic methods), but redirects all operations on the @@ -676,8 +706,9 @@ def __hash__(self): else: raise TypeError("unhashable type: SymBool") + def sym_not(a): - r""" SymInt-aware utility for logical negation. + r"""SymInt-aware utility for logical negation. Args: a (SymBool or bool): Object to negate @@ -686,14 +717,15 @@ def sym_not(a): if overrides.has_torch_function_unary(a): return overrides.handle_torch_function(sym_not, (a,), a) - if hasattr(a, '__sym_not__'): + if hasattr(a, "__sym_not__"): return a.__sym_not__() if isinstance(a, sympy.Basic): return ~a # type: ignore[operator] return not a + def sym_float(a): - r""" SymInt-aware utility for float casting. + r"""SymInt-aware utility for float casting. Args: a (SymInt, SymFloat, or object): Object to cast @@ -702,13 +734,13 @@ def sym_float(a): return overrides.handle_torch_function(sym_float, (a,), a) if isinstance(a, SymFloat): return a - elif hasattr(a, '__sym_float__'): + elif hasattr(a, "__sym_float__"): return a.__sym_float__() return builtins.float(a) # type: ignore[operator] def sym_int(a): - r""" SymInt-aware utility for int casting. + r"""SymInt-aware utility for int casting. Args: a (SymInt, SymFloat, or object): Object to cast @@ -721,6 +753,7 @@ def sym_int(a): return math.trunc(a) return builtins.int(a) # type: ignore[operator] + def sym_max(a, b): """ SymInt-aware utility for max which avoids branching on a < b. @@ -744,6 +777,7 @@ def sym_max(a, b): else: return builtins.max(a, b) + def sym_min(a, b): """SymInt-aware utility for min().""" if overrides.has_torch_function((a, b)): @@ -759,6 +793,7 @@ def sym_min(a, b): else: return builtins.min(a, b) + # Drop in replacement for math.sqrt, math.sin, math.cos etc def _get_sym_math_fn(name): def fn(a): @@ -770,8 +805,20 @@ def fn(a): return fn -__fn, __name, __sym_name = None, '', '' -for __name in ("sqrt", "cos", "cosh", "sin", "sinh", "tan", "tanh", "asin", "acos", "atan"): + +__fn, __name, __sym_name = None, "", "" +for __name in ( + "sqrt", + "cos", + "cosh", + "sin", + "sinh", + "tan", + "tanh", + "asin", + "acos", + "atan", +): __sym_name = f"_sym_{__name}" __fn = _get_sym_math_fn(__name) __fn.__qualname__ = __fn.__name__ = __sym_name @@ -792,6 +839,7 @@ def sym_ite(b, t, f): return b.__sym_ite__(t, f) return t if b else f + # Check to see if we can load C extensions, and if not provide some guidance # on what the problem might be. try: @@ -824,17 +872,21 @@ def sym_ite(b, t, f): # Make an explicit reference to the _C submodule to appease linters from torch import _C as _C -__name, __obj = '', None +__name, __obj = "", None for __name in dir(_C): - if __name[0] != '_' and not __name.endswith('Base'): + if __name[0] != "_" and not __name.endswith("Base"): __all__.append(__name) __obj = getattr(_C, __name) if callable(__obj) or inspect.isclass(__obj): if __obj.__module__ != __name__: # "torch" # TODO: fix their module from C++ side - if __name not in ['DisableTorchFunctionSubclass', 'DisableTorchFunction', 'Generator']: + if __name not in { + "DisableTorchFunctionSubclass", + "DisableTorchFunction", + "Generator", + }: __obj.__module__ = __name__ # "torch" - elif __name == 'TensorBase': + elif __name == "TensorBase": # issue 109438 / pr 109940. Prevent TensorBase from being copied into torch. delattr(sys.modules[__name__], __name) @@ -845,7 +897,7 @@ def sym_ite(b, t, f): # non-standard, and attributes of those submodules cannot be pickled since # pickle expect to be able to import them as "from _C.sub import attr" # which fails with "_C is not a package - __name, __candidate = '', None + __name, __candidate = "", None for __name in dir(_C): __candidate = getattr(_C, __name) if type(__candidate) is type(_C): @@ -864,15 +916,19 @@ def typename(o): if isinstance(o, torch.Tensor): return o.type() - module = '' - class_name = '' - if hasattr(o, '__module__') and o.__module__ != 'builtins' \ - and o.__module__ != '__builtin__' and o.__module__ is not None: - module = o.__module__ + '.' - - if hasattr(o, '__qualname__'): + module = "" + class_name = "" + if ( + hasattr(o, "__module__") + and o.__module__ != "builtins" + and o.__module__ != "__builtin__" + and o.__module__ is not None + ): + module = o.__module__ + "." + + if hasattr(o, "__qualname__"): class_name = o.__qualname__ - elif hasattr(o, '__name__'): + elif hasattr(o, "__name__"): class_name = o.__name__ else: class_name = o.__class__.__name__ @@ -983,6 +1039,7 @@ def set_default_device(device): device_context = None else: from torch.utils._device import DeviceContext + device_context = DeviceContext(device) device_context.__enter__() _GLOBAL_DEVICE_CONTEXT.device_context = device_context @@ -1071,8 +1128,13 @@ def set_default_dtype(d): """ _C._set_default_dtype(d) -def use_deterministic_algorithms(mode: builtins.bool, *, warn_only: builtins.bool = False) -> None: - r""" Sets whether PyTorch operations must use "deterministic" + +def use_deterministic_algorithms( + mode: builtins.bool, + *, + warn_only: builtins.bool = False, +) -> None: + r"""Sets whether PyTorch operations must use "deterministic" algorithms. That is, algorithms which, given the same input, and when run on the same software and hardware, always produce the same output. When enabled, operations will use deterministic algorithms when available, @@ -1208,12 +1270,14 @@ def use_deterministic_algorithms(mode: builtins.bool, *, warn_only: builtins.boo """ _C._set_deterministic_algorithms(mode, warn_only=warn_only) + def are_deterministic_algorithms_enabled() -> builtins.bool: r"""Returns True if the global deterministic flag is turned on. Refer to :func:`torch.use_deterministic_algorithms` documentation for more details. """ return _C._get_deterministic_algorithms() + def is_deterministic_algorithms_warn_only_enabled() -> builtins.bool: r"""Returns True if the global deterministic flag is set to warn only. Refer to :func:`torch.use_deterministic_algorithms` documentation for more @@ -1221,6 +1285,7 @@ def is_deterministic_algorithms_warn_only_enabled() -> builtins.bool: """ return _C._get_deterministic_algorithms_warn_only() + def set_deterministic_debug_mode(debug_mode: Union[builtins.int, str]) -> None: r"""Sets the debug mode for deterministic operations. @@ -1238,19 +1303,20 @@ def set_deterministic_debug_mode(debug_mode: Union[builtins.int, str]) -> None: # NOTE: builtins.int is used here because int in this scope resolves # to torch.int if not isinstance(debug_mode, (builtins.int, str)): - raise TypeError(f'debug_mode must be str or int, but got {type(debug_mode)}') + raise TypeError(f"debug_mode must be str or int, but got {type(debug_mode)}") if isinstance(debug_mode, str): - if debug_mode == 'default': + if debug_mode == "default": debug_mode = 0 - elif debug_mode == 'warn': + elif debug_mode == "warn": debug_mode = 1 - elif debug_mode == 'error': + elif debug_mode == "error": debug_mode = 2 else: raise RuntimeError( - 'invalid value of debug_mode, expected one of `default`, ' - f'`warn`, `error`, but got {debug_mode}') + "invalid value of debug_mode, expected one of `default`, " + f"`warn`, `error`, but got {debug_mode}" + ) if debug_mode == 0: _C._set_deterministic_algorithms(False) @@ -1260,8 +1326,9 @@ def set_deterministic_debug_mode(debug_mode: Union[builtins.int, str]) -> None: _C._set_deterministic_algorithms(True) else: raise RuntimeError( - 'invalid value of debug_mode, expected 0, 1, or 2, ' - f'but got {debug_mode}') + "invalid value of debug_mode, expected 0, 1, or 2, " f"but got {debug_mode}" + ) + def get_deterministic_debug_mode() -> builtins.int: r"""Returns the current value of the debug mode for deterministic @@ -1277,12 +1344,14 @@ def get_deterministic_debug_mode() -> builtins.int: else: return 0 + def get_float32_matmul_precision() -> builtins.str: r"""Returns the current value of float32 matrix multiplication precision. Refer to :func:`torch.set_float32_matmul_precision` documentation for more details. """ return _C._get_float32_matmul_precision() + def set_float32_matmul_precision(precision: str) -> None: r"""Sets the internal precision of float32 matrix multiplications. @@ -1348,6 +1417,7 @@ def set_float32_matmul_precision(precision: str) -> None: """ _C._set_float32_matmul_precision(precision) + def set_warn_always(b: builtins.bool) -> None: r"""When this flag is False (default) then some PyTorch warnings may only appear once per process. This helps avoid excessive warning information. @@ -1360,12 +1430,14 @@ def set_warn_always(b: builtins.bool) -> None: """ _C._set_warnAlways(b) + def is_warn_always_enabled() -> builtins.bool: r"""Returns True if the global warn_always flag is turned on. Refer to :func:`torch.set_warn_always` documentation for more details. """ return _C._get_warnAlways() + ################################################################################ # Define error checking functions ################################################################################ @@ -1373,11 +1445,17 @@ def is_warn_always_enabled() -> builtins.bool: # These error checking functions must be kept consistent with their C++ # equivalents. Their C++ equivalents are mentioned where applicable. -def _check_with(error_type, cond: Union[builtins.bool, SymBool], message: Callable[[], str]): # noqa: F811 + +def _check_with( + error_type, + cond: Union[builtins.bool, SymBool], + message: Callable[[], str], +): # noqa: F811 if not isinstance(cond, (builtins.bool, torch.SymBool)): - raise TypeError(f'cond must be a bool, but got {type(cond)}') + raise TypeError(f"cond must be a bool, but got {type(cond)}") from torch.fx.experimental.symbolic_shapes import expect_true + if expect_true(cond): return @@ -1386,18 +1464,20 @@ def _check_with(error_type, cond: Union[builtins.bool, SymBool], message: Callab if message is None: message_evaluated = ( - 'Expected cond to be True, but got False. (Could this error ' - 'message be improved? If so, please report an enhancement request ' - 'to PyTorch.)') + "Expected cond to be True, but got False. (Could this error " + "message be improved? If so, please report an enhancement request " + "to PyTorch.)" + ) else: if not callable(message): - raise TypeError('message must be a callable') + raise TypeError("message must be a callable") message_evaluated = str(message()) raise error_type(message_evaluated) + def _check(cond, message=None): # noqa: F811 r"""Throws error containing an optional message if the specified condition is False. @@ -1415,6 +1495,7 @@ def _check(cond, message=None): # noqa: F811 """ _check_with(RuntimeError, cond, message) + def _check_is_size(i, message=None): """Checks that a given integer is a valid size (i.e., is non-negative). You should use this over _check(i >= 0) because we can use the semantic @@ -1428,8 +1509,10 @@ def _check_is_size(i, message=None): # This is responsible for the expect_true _check(i >= 0, message) from torch.fx.experimental.symbolic_shapes import _advise_is_size + _advise_is_size(i) + def _check_index(cond, message=None): # noqa: F811 r"""Throws error containing an optional message if the specified condition is False. @@ -1447,6 +1530,7 @@ def _check_index(cond, message=None): # noqa: F811 """ _check_with(IndexError, cond, message) + def _check_value(cond, message=None): # noqa: F811 r"""Throws error containing an optional message if the specified condition is False. @@ -1464,6 +1548,7 @@ def _check_value(cond, message=None): # noqa: F811 """ _check_with(ValueError, cond, message) + def _check_type(cond, message=None): # noqa: F811 r"""Throws error containing an optional message if the specified condition is False. @@ -1481,6 +1566,7 @@ def _check_type(cond, message=None): # noqa: F811 """ _check_with(TypeError, cond, message) + def _check_not_implemented(cond, message=None): # noqa: F811 r"""Throws error containing an optional message if the specified condition is False. @@ -1498,16 +1584,17 @@ def _check_not_implemented(cond, message=None): # noqa: F811 """ _check_with(NotImplementedError, cond, message) + def _check_tensor_all_with(error_type, cond, message=None): # noqa: F811 if not torch.is_tensor(cond): - raise TypeError(f'cond must be a tensor, but got {type(cond)}') + raise TypeError(f"cond must be a tensor, but got {type(cond)}") if not cond.dtype == torch.bool: - raise TypeError( - f'cond tensor must have dtype torch.bool, but got {cond.dtype}') + raise TypeError(f"cond tensor must have dtype torch.bool, but got {cond.dtype}") _check_with(error_type, cond._is_all_true().item(), message) + # C++ equivalent: `TORCH_CHECK_TENSOR_ALL` def _check_tensor_all(cond, message=None): # noqa: F811 r"""Throws error containing an optional message if the specified condition @@ -1527,6 +1614,7 @@ def _check_tensor_all(cond, message=None): # noqa: F811 """ _check_tensor_all_with(RuntimeError, cond, message) + ################################################################################ # Define numeric constants ################################################################################ @@ -1534,8 +1622,10 @@ def _check_tensor_all(cond, message=None): # noqa: F811 # For Python Array API (https://data-apis.org/array-api/latest/API_specification/constants.html) and # NumPy consistency (https://numpy.org/devdocs/reference/constants.html) from math import e, inf, nan, pi + newaxis: None = None -__all__.extend(['e', 'pi', 'nan', 'inf', 'newaxis']) + +__all__.extend(["e", "pi", "nan", "inf", "newaxis"]) ################################################################################ # Define Storage and Tensor classes @@ -1556,6 +1646,7 @@ def _check_tensor_all(cond, message=None): # noqa: F811 # NOTE: New Storage classes should never be added. When adding a new # dtype, use torch.storage.TypedStorage directly. + class ByteStorage(_LegacyStorage): @classproperty def dtype(self): @@ -1566,6 +1657,7 @@ def dtype(self): def _dtype(self): return torch.uint8 + class DoubleStorage(_LegacyStorage): @classproperty def dtype(self): @@ -1576,6 +1668,7 @@ def dtype(self): def _dtype(self): return torch.double + class FloatStorage(_LegacyStorage): @classproperty def dtype(self): @@ -1586,6 +1679,7 @@ def dtype(self): def _dtype(self): return torch.float + class HalfStorage(_LegacyStorage): @classproperty def dtype(self): @@ -1596,6 +1690,7 @@ def dtype(self): def _dtype(self): return torch.half + class LongStorage(_LegacyStorage): @classproperty def dtype(self): @@ -1606,6 +1701,7 @@ def dtype(self): def _dtype(self): return torch.long + class IntStorage(_LegacyStorage): @classproperty def dtype(self): @@ -1616,6 +1712,7 @@ def dtype(self): def _dtype(self): return torch.int + class ShortStorage(_LegacyStorage): @classproperty def dtype(self): @@ -1626,6 +1723,7 @@ def dtype(self): def _dtype(self): return torch.short + class CharStorage(_LegacyStorage): @classproperty def dtype(self): @@ -1636,6 +1734,7 @@ def dtype(self): def _dtype(self): return torch.int8 + class BoolStorage(_LegacyStorage): @classproperty def dtype(self): @@ -1646,6 +1745,7 @@ def dtype(self): def _dtype(self): return torch.bool + class BFloat16Storage(_LegacyStorage): @classproperty def dtype(self): @@ -1656,6 +1756,7 @@ def dtype(self): def _dtype(self): return torch.bfloat16 + class ComplexDoubleStorage(_LegacyStorage): @classproperty def dtype(self): @@ -1666,6 +1767,7 @@ def dtype(self): def _dtype(self): return torch.cdouble + class ComplexFloatStorage(_LegacyStorage): @classproperty def dtype(self): @@ -1676,6 +1778,7 @@ def dtype(self): def _dtype(self): return torch.cfloat + class QUInt8Storage(_LegacyStorage): @classproperty def dtype(self): @@ -1686,6 +1789,7 @@ def dtype(self): def _dtype(self): return torch.quint8 + class QInt8Storage(_LegacyStorage): @classproperty def dtype(self): @@ -1696,6 +1800,7 @@ def dtype(self): def _dtype(self): return torch.qint8 + class QInt32Storage(_LegacyStorage): @classproperty def dtype(self): @@ -1706,6 +1811,7 @@ def dtype(self): def _dtype(self): return torch.qint32 + class QUInt4x2Storage(_LegacyStorage): @classproperty def dtype(self): @@ -1716,6 +1822,7 @@ def dtype(self): def _dtype(self): return torch.quint4x2 + class QUInt2x4Storage(_LegacyStorage): @classproperty def dtype(self): @@ -1726,12 +1833,27 @@ def dtype(self): def _dtype(self): return torch.quint2x4 + _storage_classes = { - UntypedStorage, DoubleStorage, FloatStorage, LongStorage, IntStorage, - ShortStorage, CharStorage, ByteStorage, HalfStorage, BoolStorage, - QUInt8Storage, QInt8Storage, QInt32Storage, BFloat16Storage, - ComplexFloatStorage, ComplexDoubleStorage, QUInt4x2Storage, QUInt2x4Storage, - TypedStorage + UntypedStorage, + DoubleStorage, + FloatStorage, + LongStorage, + IntStorage, + ShortStorage, + CharStorage, + ByteStorage, + HalfStorage, + BoolStorage, + QUInt8Storage, + QInt8Storage, + QInt32Storage, + BFloat16Storage, + ComplexFloatStorage, + ComplexDoubleStorage, + QUInt4x2Storage, + QUInt2x4Storage, + TypedStorage, } # The _tensor_classes set is initialized by the call to initialize_python_bindings. @@ -1748,15 +1870,17 @@ def _dtype(self): # Initialize extension ################################################################################ + # Shared memory manager needs to know the exact location of manager executable def _manager_path(): - if _running_with_deploy() or platform.system() == 'Windows': + if _running_with_deploy() or platform.system() == "Windows": return b"" - path = get_file_path('torch', 'bin', 'torch_shm_manager') - prepare_multiprocessing_environment(get_file_path('torch')) + path = get_file_path("torch", "bin", "torch_shm_manager") + prepare_multiprocessing_environment(get_file_path("torch")) if not os.path.exists(path): raise RuntimeError("Unable to find torch_shm_manager at " + path) - return path.encode('utf-8') + return path.encode("utf-8") + _C._initExtension(_manager_path()) @@ -1771,19 +1895,18 @@ def _manager_path(): # signatures already imported. For now these clashes are ignored; see # PR #43339 for details. from torch._C._VariableFunctions import * # type: ignore[assignment, misc] # noqa: F403 + # Fixup segment_reduce visibility _segment_reduce = segment_reduce del segment_reduce # noqa: F821 # Ops not to be exposed in `torch` namespace, # mostly helper ops. -PRIVATE_OPS = ( - 'unique_dim', -) +PRIVATE_OPS = ("unique_dim",) -__name, __obj = '', None +__name, __obj = "", None for __name in dir(_C._VariableFunctions): - if __name.startswith('__') or __name in PRIVATE_OPS: + if __name.startswith("__") or __name in PRIVATE_OPS: continue __obj = getattr(_C._VariableFunctions, __name) __obj.__module__ = __name__ # "torch" @@ -1834,14 +1957,19 @@ def _manager_path(): # Define _assert ################################################################################ + # needs to be before the submodule imports to avoid circular dependencies def _assert(condition, message): - r"""A wrapper around Python's assert which is symbolically traceable. - """ - if type(condition) is not torch.Tensor and overrides.has_torch_function((condition,)): - return overrides.handle_torch_function(_assert, (condition,), condition, message) + r"""A wrapper around Python's assert which is symbolically traceable.""" + if type(condition) is not torch.Tensor and overrides.has_torch_function( + (condition,) + ): + return overrides.handle_torch_function( + _assert, (condition,), condition, message + ) assert condition, message + ################################################################################ # Import most common subpackages ################################################################################ @@ -1904,6 +2032,7 @@ def _assert(condition, message): # attach docstrings to torch and tensor functions from torch import _size_docs, _storage_docs, _tensor_docs, _torch_docs + del _torch_docs, _tensor_docs, _storage_docs, _size_docs @@ -1932,6 +2061,7 @@ def compiled_with_cxx11_abi() -> builtins.bool: # Register fork handler to initialize OpenMP in child processes (see gh-28389) from torch.multiprocessing._atfork import register_after_fork + register_after_fork(torch.get_num_threads) del register_after_fork @@ -1973,6 +2103,7 @@ def __init__(self, mode, options, dynamic): # Stash the compiler_fn to be used for backend match guard. from torch._inductor.compile_fx import compile_fx + self.compiler_fn = compile_fx if self.config.get("triton.cudagraphs", False): os.environ["DISABLE_CUPTI_LAZY_REINIT"] = "1" @@ -1983,15 +2114,18 @@ def __init__(self, mode, options, dynamic): os.environ["TEARDOWN_CUPTI"] = "0" def __eq__(self, other): - return (isinstance(other, _TorchCompileInductorWrapper) and - self.config == other.config and - self.dynamic == other.dynamic) + return ( + isinstance(other, _TorchCompileInductorWrapper) + and self.config == other.config + and self.dynamic == other.dynamic + ) def apply_mode(self, mode: Optional[str]): if mode is None or mode == "default": pass - elif mode in ("reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"): + elif mode in {"reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"}: from torch._inductor import list_mode_options + self.apply_options(list_mode_options(mode, self.dynamic)) else: raise RuntimeError( @@ -2003,6 +2137,7 @@ def apply_options(self, options: Optional[Dict[str, Any]]): return from torch._inductor import config + current_config: Dict[str, Any] = config.shallow_copy_dict() for key, val in options.items(): @@ -2026,15 +2161,19 @@ def __call__(self, model_, inputs_): def get_compiler_config(self): from torch._inductor.compile_fx import get_patched_config_dict + return get_patched_config_dict(config_patches=self.config) def reset(self): from torch._inductor import config + if "triton.cudagraphs" in self.config or config.triton.cudagraphs: if self.config.get("triton.cudagraphs", True): from torch._inductor.cudagraph_trees import reset_cudagraph_trees + reset_cudagraph_trees() + class _TorchCompileWrapper: def __init__(self, backend, mode, options, dynamic): from torch._dynamo.backends.registry import lookup_backend @@ -2055,10 +2194,12 @@ def __init__(self, backend, mode, options, dynamic): self.kwargs["options"] = options def __eq__(self, other): - return (isinstance(other, _TorchCompileWrapper) and - self.compiler_fn == other.compiler_fn and - self.kwargs == other.kwargs and - self.dynamic == other.dynamic) + return ( + isinstance(other, _TorchCompileWrapper) + and self.compiler_fn == other.compiler_fn + and self.kwargs == other.kwargs + and self.dynamic == other.dynamic + ) def __call__(self, model_, inputs_): return self.compiler_fn(model_, inputs_, **self.kwargs) @@ -2068,13 +2209,16 @@ def reset(self): self.compiler_fn.reset() -def compile(model: Optional[Callable] = None, *, - fullgraph: builtins.bool = False, - dynamic: Optional[builtins.bool] = None, - backend: Union[str, Callable] = "inductor", - mode: Union[str, None] = None, - options: Optional[Dict[str, Union[str, builtins.int, builtins.bool]]] = None, - disable: builtins.bool = False) -> Callable: +def compile( + model: Optional[Callable] = None, + *, + fullgraph: builtins.bool = False, + dynamic: Optional[builtins.bool] = None, + backend: Union[str, Callable] = "inductor", + mode: Union[str, None] = None, + options: Optional[Dict[str, Union[str, builtins.int, builtins.bool]]] = None, + disable: builtins.bool = False, +) -> Callable: """ Optimizes given model/function using TorchDynamo and specified backend. If you are compiling an :class:`torch.nn.Module`, you can also use :meth:`torch.nn.Module.compile` @@ -2165,20 +2309,26 @@ def foo(x): # Decorator mode if model is None: + def fn(model: Callable): if model is None: raise RuntimeError("Model can't be None") - return compile(model, - fullgraph=fullgraph, - dynamic=dynamic, - backend=backend, - mode=mode, - options=options, - disable=disable) + return compile( + model, + fullgraph=fullgraph, + dynamic=dynamic, + backend=backend, + mode=mode, + options=options, + disable=disable, + ) + return fn if mode is not None and options is not None: - raise RuntimeError("Either mode or options can be specified, but both can't be specified at the same time.") + raise RuntimeError( + "Either mode or options can be specified, but both can't be specified at the same time." + ) if mode is None and options is None: mode = "default" if backend == "inductor": @@ -2186,7 +2336,12 @@ def fn(model: Callable): else: backend = _TorchCompileWrapper(backend, mode, options, dynamic) - return torch._dynamo.optimize(backend=backend, nopython=fullgraph, dynamic=dynamic, disable=disable)(model) + return torch._dynamo.optimize( + backend=backend, + nopython=fullgraph, + dynamic=dynamic, + disable=disable, + )(model) from torch import export as export @@ -2205,12 +2360,15 @@ def _register_device_module(device_type, module): device_type = torch.device(device_type).type m = sys.modules[__name__] if hasattr(m, device_type): - raise RuntimeError(f"The runtime module of '{device_type}' has already " - f"been registered with '{getattr(m, device_type)}'") + raise RuntimeError( + f"The runtime module of '{device_type}' has already " + f"been registered with '{getattr(m, device_type)}'" + ) setattr(m, device_type, module) - torch_module_name = '.'.join([__name__, device_type]) + torch_module_name = ".".join([__name__, device_type]) sys.modules[torch_module_name] = module + # expose return_types from torch import library as library, return_types as return_types @@ -2218,7 +2376,7 @@ def _register_device_module(device_type, module): from torch import _meta_registrations # Enable CUDA Sanitizer -if 'TORCH_CUDA_SANITIZER' in os.environ: +if "TORCH_CUDA_SANITIZER" in os.environ: import torch.cuda._sanitizer as csan csan.enable_cuda_sanitizer() @@ -2278,7 +2436,11 @@ def __getattr__(name): replacement = _deprecated_attrs.get(name) if replacement is not None: import warnings - warnings.warn(f"'{name}' is deprecated, please use '{replacement.__module__}.{replacement.__name__}()'", stacklevel=2) + + warnings.warn( + f"'{name}' is deprecated, please use '{replacement.__module__}.{replacement.__name__}()'", + stacklevel=2, + ) return replacement() # Lazy modules @@ -2287,6 +2449,7 @@ def __getattr__(name): raise AttributeError(f"module '{__name__}' has no attribute '{name}'") + def get_device_module(device: Optional[Union[torch.device, str]] = None): """ Returns the module associated with a given device(e.g., torch.device('cuda'), "mtia:0", "xpu", ...). @@ -2300,7 +2463,9 @@ def get_device_module(device: Optional[Union[torch.device, str]] = None): # Using default accelerator type. If no accelerator is available, it automatically returns CPU device. device_module_name = torch._C._get_accelerator().type else: - raise RuntimeError(f"Invalid value of device '{device}', expect torch.device, str, or None") + raise RuntimeError( + f"Invalid value of device '{device}', expect torch.device, str, or None" + ) device_module = getattr(torch, device_module_name, None) if device_module is None: raise RuntimeError( @@ -2309,7 +2474,11 @@ def get_device_module(device: Optional[Union[torch.device, str]] = None): return device_module -def _constrain_as_size(symbol, min: Optional[builtins.int] = None, max: Optional[builtins.int] = None): +def _constrain_as_size( + symbol, + min: Optional[builtins.int] = None, + max: Optional[builtins.int] = None, +): """ This indicates that a given int is size-like, and can be used in any context where a size is expected. You will typically use this when reading out integers from Tensors, e.g., max.item() or lengths.tolist() @@ -2332,4 +2501,5 @@ def _constrain_as_size(symbol, min: Optional[builtins.int] = None, max: Optional from torch import _logging + _logging._init_logs() From 2e065f2486e92d8b89b058f1d561bb801f226179 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Wed, 12 Jun 2024 10:49:16 +0000 Subject: [PATCH 688/706] [Quant][Inductor] Bug fix: mutation nodes not handled correctly for QLinearPointwiseBinaryPT2E (#127592) Fixes #127402 - Revert some changes to `ir.MutationOutput` and inductor/test_flex_attention.py - Add checks of mutation for QLinearPointwiseBinaryPT2E Pull Request resolved: https://github.com/pytorch/pytorch/pull/127592 Approved by: https://github.com/leslie-fang-intel, https://github.com/Chillee --- test/inductor/test_flex_attention.py | 10 +++--- test/inductor/test_mkldnn_pattern_matcher.py | 33 ++++++++++++++++++++ torch/_inductor/ir.py | 7 +---- 3 files changed, 40 insertions(+), 10 deletions(-) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index 4e8eecef0f41..c6f03052f37a 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -776,11 +776,13 @@ def f(q, k, v): metrics.reset() f(q, k, v) accessed_bytes = 1 * 8 * 1024 * 64 * torch.float32.itemsize - logsumexp_bytes = 1 * 8 * 1024 * torch.float32.itemsize num_accesses = 4 # q, k, v reads, one output. - self.assertEqual( - metrics.num_bytes_accessed, accessed_bytes * num_accesses + logsumexp_bytes - ) + # TODO: Get rid of this fudge factor + # We need this fudge factor for now, since + # 1. For some reason we materialize the output of the attention unnecessarily (it's related to the mutation somehow) + # 2. We also write the extraneous logsumexp + num_accesses += 2 + self.assertLess(metrics.num_bytes_accessed, accessed_bytes * num_accesses) @supported_platform @skip("Triton bug ") # https://github.com/pytorch/pytorch/issues/124571 diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py index 8932fcfc4afd..0490c3bcb9f3 100644 --- a/test/inductor/test_mkldnn_pattern_matcher.py +++ b/test/inductor/test_mkldnn_pattern_matcher.py @@ -233,6 +233,7 @@ def _test_code_common( rtol=1.3e-6, check_quantization=False, check_dynamic=None, + num_include_ops=None, ): with torch.no_grad(): clone_inputs = self._clone_inputs(inputs) @@ -245,6 +246,12 @@ def _test_code_common( ) for op in include_ops: self.assertIn(op, source_code) + if num_include_ops is not None: + assert len(include_ops) == len(num_include_ops) + for i in range(len(include_ops)): + self.assertEqual( + source_code.count(include_ops[i]), num_include_ops[i] + ) for op in exclude_ops: self.assertNotIn(op, source_code) if check_dynamic is not None: @@ -1808,6 +1815,32 @@ def matcher_check_fn(): matcher_check_fn=matcher_check_fn, is_qat=is_qat, ) + if torch._inductor.config.cpp_wrapper: + # For CPP wrapper + self._test_code_common( + mod, + (v,), + [ + "op_qlinear_pointwise.call", + "op_qlinear_pointwise_binary.call", + ], + [], + check_quantization=True, + num_include_ops=[2, 2], + ) + else: + # For python wrapper + self._test_code_common( + mod, + (v,), + [ + "torch.ops.onednn.qlinear_pointwise.default", + "torch.ops.onednn.qlinear_pointwise.binary", + ], + [], + check_quantization=True, + num_include_ops=[2, 2], + ) @skipIfNoDynamoSupport @skipIfNoONEDNN diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 1edafadd68f0..da0c1b120676 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -4775,15 +4775,10 @@ def get_mutation_names(self): def __init__(self, layout, mutated_node, node_doing_mutating): # NB: Do not directly construct this - use `mark_node_as_mutating` - super().__init__(None, layout, [mutated_node], ()) + super().__init__(None, layout, [mutated_node, node_doing_mutating], ()) self.node_doing_mutating = node_doing_mutating self.name = V.graph.register_buffer(self) - def get_read_writes(self): - read_writes = super().get_read_writes() - read_writes.reads.add(dependencies.WeakDep(self.node_doing_mutating.get_name())) - return read_writes - def should_allocate(self): return False From abc3eec22d38079bee855fbcb75da62a9558284c Mon Sep 17 00:00:00 2001 From: James Wu Date: Tue, 11 Jun 2024 12:46:50 -0700 Subject: [PATCH 689/706] First version of AOTAutogradCache (#126791) This PR implements "V0" of AOTAutogradCache. Given an input to AOTAutograd, we calculate a cache key, then save an AOTAutogradCacheEntry. Each AOTAutogradCacheEntry has: - A CompiledForward and optionally a CompiledBackward - A bunch of metadata. CompiledForward and CompiledBackward each save the *key* to the FXGraphCache associated with the compiled object. FXGraphCache populates this key field as long as it's able to return a compiled graph given a set of inputs. We then load the same object from the FXGraphCache on an AOTAutogradCache hit. On cache miss: - Run AOTAutograd, up to AOTAutogradDispatch.post_compile. - Save an AOTAutogradCacheEntry to the cache after compiling the necessary portions and receiving a cache key from FXGraphCache. In this we *always* compile the backwards ahead of time. The PR above this one implements backward lazy caching, so that we only save to the cache after compiling the backward in a lazy backward scenario. - Return the resulting object On cache hit: - Run AOTAutogradCacheEntry.post_compile() on the cache key. - This attempts to load the forward and backward graphs from FXGraphCache - As long as we successfully load from FXGraphCache, it's a hit. We then rewrap the callable with post compile wrappers using our saved metadata. For now, we ignore the fakified out and debug wrappers. We only save to the cache if Fakified out is turned off. V0 Guards behavior: FXGraphCache serializes guards that are needed in the shape_env based on the symint inputs to the graph. The invariant that AOTAutograd uses here is that the sources for symints given to it by dynamo are exactly the same as the ones it passes to inductor, for both the forward and backward passes. (This does *not* mean that the tensor values passed in are the same: only that their symints are). That is, AOTAutograd and Inductor never create new guards based on symints with *different sources* than those passed to it by inductor. We don't currently store any AOTAutograd specific guards: my hypothesis is that FXGraphCache already stores these, as any guards generated by AOTAutograd should already be in the shape_env before calling into inductor, and we don't generate new guards post inductor. If this is needed, I'll add it in another diff. Testing: We'll start with some basic unit tests, but I'll be adding more and more complicated testing as the next step. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126791 Approved by: https://github.com/bdhirsh --- test/dynamo/test_aot_autograd_cache.py | 259 ++++++++++- .../_aot_autograd/autograd_cache.py | 409 +++++++++++++++++- .../jit_compile_runtime_wrappers.py | 41 +- .../_aot_autograd/runtime_wrappers.py | 6 +- torch/_functorch/_aot_autograd/schemas.py | 3 + torch/_functorch/aot_autograd.py | 29 +- torch/_functorch/config.py | 3 + torch/_inductor/codecache.py | 10 +- torch/_inductor/utils.py | 5 +- 9 files changed, 741 insertions(+), 24 deletions(-) diff --git a/test/dynamo/test_aot_autograd_cache.py b/test/dynamo/test_aot_autograd_cache.py index c57527d9f6cd..c34b30b7191d 100644 --- a/test/dynamo/test_aot_autograd_cache.py +++ b/test/dynamo/test_aot_autograd_cache.py @@ -1,19 +1,270 @@ # Owner(s): ["module: dynamo"] +import os +import unittest + import torch import torch._dynamo import torch._dynamo.test_case import torch._functorch._aot_autograd +from torch._dynamo.utils import counters from torch._functorch import config as functorch_config from torch._functorch._aot_autograd.autograd_cache import ( - autograd_cache_hash, + AOTAutogradCache, + autograd_cache_key, BypassAOTAutogradCache, ) from torch._functorch._aot_autograd.schemas import AOTConfig from torch._inductor import config as inductor_config - - +from torch.testing._internal.common_cuda import SM80OrLater +from torch.testing._internal.common_device_type import largeTensorTest +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, +) +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU + + +@instantiate_parametrized_tests +class AOTAutogradCacheTests(torch._dynamo.test_case.TestCase): + def setUp(self): + """ + Reset all counters and caches before each unit test + """ + super().setUp() + counters.clear() + self._clear_all_caches() + + def _clear_all_caches(self): + """ + Clear every cache, including AOTAutogradCache and FXCache + """ + torch._inductor.codecache.FxGraphCache.clear() + AOTAutogradCache.clear() + self._clear_dynamo_and_codecache() + + def _clear_dynamo_and_codecache(self): + """ + Clear unrelated caches, like dynamo and PyCodeCache + """ + torch._dynamo.reset() + for m in torch._inductor.codecache.PyCodeCache.cache.values(): + os.remove(m.__file__) + torch._inductor.codecache.PyCodeCache.cache_clear() + + @inductor_config.patch("fx_graph_cache", True) + @functorch_config.patch({"enable_autograd_cache": True}) + def test_basic(self): + """ + Verify the interactions between FXGraphCache and AOTAutogradCache. + """ + + def fn(x, y): + return (x * 2, y @ y) + + a = torch.rand(25) + b = torch.rand(5, 5) + + compiled_fn = torch.compile(fn, backend="inductor") + + # A first call should miss in the cache. + self.assertEqual(fn(a, b), compiled_fn(a, b)) + self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) + self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0) + self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1) + + # A second call should hit. (First reset so in-memory guards + # don't prevent compilation). + self._clear_dynamo_and_codecache() + self.assertEqual(fn(a, b), compiled_fn(a, b)) + + self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) + self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1) + self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1) + + @inductor_config.patch("fx_graph_cache", True) + @functorch_config.patch({"enable_autograd_cache": True}) + def test_clear_fx_graph_cache(self): + """ + Verify the interactions between FXGraphCache and AOTAutogradCache. + """ + + def fn(x, y): + return (x * 2, y @ y) + + a = torch.rand(25) + b = torch.rand(5, 5) + + compiled_fn = torch.compile(fn, backend="inductor") + + # A first call should miss in the cache. + self.assertEqual(fn(a, b), compiled_fn(a, b)) + self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) + self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0) + self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1) + + # Clear FX graph cache: second call should also be a miss + self._clear_dynamo_and_codecache() + torch._inductor.codecache.FxGraphCache.clear() + self.assertEqual(fn(a, b), compiled_fn(a, b)) + self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2) + self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0) + # We save again into the cache + self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 2) + + @inductor_config.patch("fx_graph_cache", False) + @functorch_config.patch({"enable_autograd_cache": True}) + def test_fx_graph_cache_off(self): + """ + Should not use cache if FXGraphCache is not enabled + """ + + def fn(x, y): + return (x * 2, y @ y) + + a = torch.rand(25) + b = torch.rand(5, 5) + + compiled_fn = torch.compile(fn, backend="inductor") + + # A first call should miss in the cache. + self.assertEqual(fn(a, b), compiled_fn(a, b)) + self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1) + self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0) + self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 0) + + # Clear FX graph cache: second call should also be a miss + self._clear_dynamo_and_codecache() + + self.assertEqual(fn(a, b), compiled_fn(a, b)) + self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 2) + self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0) + self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 0) + + @inductor_config.patch("fx_graph_cache", True) + @functorch_config.patch({"enable_autograd_cache": True}) + def test_autograd_function(self): + """ + Tests autograd cache hits + """ + + def fn(a, b): + return a.sin() + b + + a = torch.randn(25, requires_grad=True) + b = torch.randn(25, requires_grad=True) + a2 = a.detach().clone().requires_grad_(True) + b2 = b.detach().clone().requires_grad_(True) + + compiled_fn = torch.compile(fn, backend="inductor") + + # A first call should miss in the cache. + self.assertEqual(fn(a, b), compiled_fn(a2, b2)) + fn(a, b).sum().backward() + compiled_fn(a2, b2).sum().backward() + self.assertEqual(a.grad, a2.grad) + self.assertEqual(b.grad, b2.grad) + + self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) + self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0) + self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1) + + # Reset all tensors + a = torch.randn(25, requires_grad=True) + b = torch.randn(25, requires_grad=True) + a2 = a.detach().clone().requires_grad_(True) + b2 = b.detach().clone().requires_grad_(True) + + # A second call should hit. (First reset so in-memory guards + # don't prevent compilation). + self._clear_dynamo_and_codecache() + self.assertEqual(fn(a, b), compiled_fn(a2, b2)) + fn(a, b).sum().backward() + compiled_fn(a2, b2).sum().backward() + self.assertEqual(a.grad, a2.grad) + self.assertEqual(b.grad, b2.grad) + + self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) + self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1) + self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1) + + @largeTensorTest("64GB", device=GPU_TYPE) + @parametrize("device", (GPU_TYPE,)) + @parametrize("dtype", (torch.float16, torch.bfloat16)) + @inductor_config.patch("fx_graph_cache", True) + @functorch_config.patch({"enable_autograd_cache": True}) + def test_autograd_inductor_guards(self, device, dtype): + """ + Tests that functions that would add inductor guards are cached properly + """ + if device == GPU_TYPE and not HAS_GPU: + raise unittest.SkipTest(f"requires {GPU_TYPE}") + if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater: + raise unittest.SkipTest("requires CUDA SM80 or later") + + def fn(x, y): + return (x + x, y + y) + + compiled_fn = torch.compile(fn, dynamic=True) + + # Iterate over different shapes, varying whether the total + # size is below or above int32. For each combination, we expect + # different guards around whether the symbolic sizes do or do + # not exceed int32. + shapes = ( + ((5, 6), (7, 8)), + ((5, 6), (47000, 47001)), + ((47000, 47001), (5, 6)), + ) + expected_hits = expected_misses = expected_saves = 0 + for a_shape, b_shape in shapes: + a = torch.rand(a_shape, device=device, dtype=dtype) + b = torch.rand(b_shape, device=device, dtype=dtype) + + # AVOID a dynamo reset here. We expect guards to have been + # added that will be violated with the new shape. We should + # see a recompilation (along with a cache miss). + res1 = compiled_fn(a, b) + # A first call should miss in the cache. + # NOTE: Currently, this cache miss is *not* due to guards, + # but instead because the AOTAutogradCache key calculation specializes on input shapes. + # Once we allow tensors with symints as part of the cache key calculation, it will + # instead cache miss because of guard failure. + expected_misses += 1 + expected_saves += 1 + self.assertEqual( + counters["aot_autograd"]["autograd_cache_miss"], expected_misses + ) + self.assertEqual( + counters["aot_autograd"]["autograd_cache_hit"], expected_hits + ) + self.assertEqual( + counters["aot_autograd"]["autograd_cache_saved"], expected_saves + ) + + # A second call should hit. (First reset so in-memory guards + # don't prevent compilation). + + # Now clear dynamo and we should see a cache hit + # This should populate guards to dynamo's cache, so that a subsequent run with a different + # shape will still trigger a second call to autograd_cache. + self._clear_dynamo_and_codecache() + res2 = compiled_fn(a, b) + expected_hits += 1 + self.assertEqual( + counters["aot_autograd"]["autograd_cache_miss"], expected_misses + ) + self.assertEqual( + counters["aot_autograd"]["autograd_cache_hit"], expected_hits + ) + self.assertEqual( + counters["aot_autograd"]["autograd_cache_saved"], expected_saves + ) + self.assertEqual(res1, res2) + + +@inductor_config.patch("fx_graph_cache", True) class AOTAutogradCachePicklerTests(torch._dynamo.test_case.TestCase): @property def device_type(self) -> str: @@ -57,7 +308,7 @@ def gen_cache_key(self, f, config, inputs=None): if inputs is None: inputs = [torch.ones(3)] _, fx_g, example_inputs = self._get_dynamo_output(f, *inputs) - return autograd_cache_hash(fx_g, example_inputs, config) + return autograd_cache_key(fx_g, example_inputs, config) def test_basic_hash_key(self): def fn(x): diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py index 8144a47f057a..9814ee9ec250 100644 --- a/torch/_functorch/_aot_autograd/autograd_cache.py +++ b/torch/_functorch/_aot_autograd/autograd_cache.py @@ -4,22 +4,56 @@ """ from __future__ import annotations +import contextlib +import copyreg + import functools +import io import logging import os -from typing import TYPE_CHECKING +import pickle +import shutil + +from dataclasses import dataclass + +from typing import Any, Callable, List, Optional, TYPE_CHECKING import torch +from torch._dynamo.utils import counters from torch._functorch import config +from torch._guards import detect_fake_mode + from torch._inductor.codecache import ( _ident, BypassFxGraphCache, + CompiledFxGraph, + FxGraphCache, FxGraphCachePickler, FxGraphHashDetails, get_code_hash, + write_atomic, +) + +from torch._inductor.runtime.runtime_utils import cache_dir +from torch._subclasses.fake_tensor import ( + extract_tensor_metadata, + FakeTensor, + FakeTensorConverter, + in_kernel_invocation_manager, + TensorMetadata, ) -from .schemas import AOTConfig # noqa: F401 +from .runtime_wrappers import ( + AOTDispatchAutograd, + AOTDispatchSubclassWrapper, + CompilerWrapper, + FunctionalizedRngRuntimeWrapper, + post_compile, + RuntimeWrapper, + SubclassMeta, +) + +from .schemas import AOTConfig, ViewAndMutationMeta # noqa: F401 if TYPE_CHECKING: from torch.fx.node import Node @@ -31,6 +65,11 @@ class BypassAOTAutogradCache(Exception): pass +# Used to signify when FXGraphCache missed when AOTAutogradCache uses it +class FXGraphCacheMiss(BypassAOTAutogradCache): + pass + + def check_node_safe(node: Node): """ Checks that the node only uses supported operators. We are starting with very @@ -97,6 +136,15 @@ def check_cacheable(gm: torch.fx.GraphModule): raise BypassAOTAutogradCache( "Cannot cache a graph with compiled autograd enabled" ) + + if not torch._inductor.config.fx_graph_cache: + raise BypassAOTAutogradCache("FX graph cache is not enabled") + + tracing_context = torch._guards.TracingContext.try_get() + if tracing_context and tracing_context.fakify_first_call: + raise BypassAOTAutogradCache( + "Won't cache a graph with fakify_first_call enabled" + ) for node in nodes: check_node_safe(node) @@ -113,7 +161,6 @@ def __init__( example_inputs, aot_config: AOTConfig, ): - check_cacheable(gm) # FxGraphHashDetails contains all the keys related to inductor. Also includes some system info self.aot_config = aot_config self.grad_enabled = torch.is_grad_enabled() @@ -122,7 +169,18 @@ def __init__( self.code_hash = get_autograd_code_hash() self.autograd_config = config.save_config() try: - super().__init__(gm, example_inputs, {}, []) + # We don't use FxGraphHashDetails to hash example_inputs because it expects + # example_inputs to always be FakeTensors, but at AOTAutograd's entry point, + # they're still regular. So instead we store their metadata here. + # TODO: this currently causes more cache misses than necessary + # with dynamic shapes, because this is before we add + # symints to tensor metadata. Improve this later. + self.example_input_metadata = [ + extract_tensor_metadata(t) + for t in example_inputs + if isinstance(t, torch.Tensor) + ] + super().__init__(gm, [], {}, []) except BypassFxGraphCache as e: # Sometimes inductor configs are unpickleable and can fail raise BypassAOTAutogradCache from e @@ -155,7 +213,7 @@ class AOTAutogradCachePickler(FxGraphCachePickler): dispatch_table[AOTConfig] = _reduce_aot_config -def autograd_cache_hash( +def autograd_cache_key( gm: torch.fx.GraphModule, example_inputs, config: AOTConfig, @@ -164,8 +222,347 @@ def autograd_cache_hash( """ Generate a unique hash of the FX graph for caching. """ + check_cacheable(gm) details = AOTAutogradCacheDetails(gm, example_inputs, config) # The prefix distinguishes among the other kinds of objects we cache key = "a" + AOTAutogradCachePickler.get_hash(details) - log.debug("FX graph cache hash details for key %s:\n%s", key, details.debug_str()) + log.debug( + "Autograd graph cache hash details for key %s:\n%s", key, details.debug_str() + ) return key + + +@dataclass +class FXGraphCacheLoadable: + fx_graph_cache_key: str + + def load(self, example_inputs) -> CompiledFxGraph: + # [Note: AOTAutogradCache and FXGraphCache Guard interactions] + # As mentioned, AOTAutograd takes in the symint inputs from dynamo's list of arguments. + # FXGraphCache serializes guards that are needed in the shape_env based on these symint inputs to the graph. + # he invariant that AOTAutograd uses here is that the sources for symints given to it by dynamo are exactly + # the same as the ones it passes to inductor, for both the forward and backward passes. + # (This does not mean that the tensor values passed in are the same: only that their symints are). + # That is, AOTAutograd and Inductor never create new guards based on symints with different sources + # than those passed to it by inductor. + result = FxGraphCache._lookup_graph( + self.fx_graph_cache_key, example_inputs, local=True, remote_cache=False + ) + if result is None: + log.info("FXGraphCache cache miss for key %s", self.fx_graph_cache_key) + raise FXGraphCacheMiss + result._boxed_call = True + return result + + +@dataclass +class CompiledForward(FXGraphCacheLoadable): + """ + Cacheable entry for a forward function + """ + + pass + + +@dataclass +class CompiledBackward(FXGraphCacheLoadable): + """ + Cacheable entry for a forward function + """ + + # Used by AOTDispatchAutograd.post_compile + backward_state_indices: List[int] + num_symints_saved_for_bw_: int + + +@dataclass +class AOTAutogradCacheEntry: + """A single entry into the cache.""" + + # Forward and Backward info + compiled_fw: CompiledForward + compiled_bw: Optional[CompiledBackward] + + # Runtime_metadata saved right before compilation + runtime_metadata: ViewAndMutationMeta + + # Wrappers that run after each aot_dispatch_* function + dispatch_wrappers: List[CompilerWrapper] + + # Used by AOTSubclassWrapper + maybe_subclass_meta: Optional[SubclassMeta] + num_fw_outs_saved_for_bw: Optional[int] + + # Used by RuntimeWrapepr + indices_of_inps_to_detach: List[int] + + # Turn cache entry into the original callable + def wrap_post_compile( + self, args: List[torch.Tensor], aot_config: AOTConfig + ) -> Callable: + """ + This function takes a cache entry and carefully reconstructs the original callable + that AOTAutograd returned the first time it was run. It does this by running the various + post compile steps that AOTAutograd runs on its compiled artifact after running the fw/bw compilers. + + In the inference path, this consists of the Subclass, FunctionalzedRngRuntime, and RuntimeWrappers. + In the autograd path, this consists of AOTAutogradDispatch.post_compile. + + The steps here should match exactly the steps that are run in aot_dispatch_base and aot_dispatch_autograd. + + Notably absent from the cached path are: + - DebugAssertWrapper + - FakifiedOutWrapper + + Which we'll handle separately later on, if necessary. + """ + compiled_fw_func = self.compiled_fw.load(args) + compiled_bw_func = None + if self.compiled_bw is not None: + compiled_bw_func = self.compiled_bw.load(args) + needs_autograd = True + else: + needs_autograd = False + + # Wrap the forward function in post compile wrappers + compiled_fw_func = AOTDispatchSubclassWrapper( + trace_joint=needs_autograd, + fw_only=None, + maybe_subclass_meta=self.maybe_subclass_meta, + num_fw_outs_saved_for_bw=self.num_fw_outs_saved_for_bw, + ).post_compile( + compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata + ) + + # In autograd case, functionalizedRngWrapper should not modify outs + return_new_outs = not needs_autograd + compiled_fw_func = FunctionalizedRngRuntimeWrapper( + return_new_outs=return_new_outs + ).post_compile( + compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata + ) + disable_amp = torch._C._is_any_autocast_enabled() + + if needs_autograd: + assert self.compiled_bw is not None + compiled_function = AOTDispatchAutograd.post_compile( + compiled_fw_func, + compiled_bw_func, + self.maybe_subclass_meta, + self.compiled_bw.num_symints_saved_for_bw_, + self.compiled_bw.backward_state_indices, + disable_amp, + self.indices_of_inps_to_detach, + None, # lazy_backward_info + aot_config, + fw_metadata=self.runtime_metadata, + ) + else: + compiled_function = RuntimeWrapper( + indices_of_inps_to_detach=self.indices_of_inps_to_detach, + trace_joint=False, + disable_amp=disable_amp, + ).post_compile( + compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata + ) + + compiled_function, _ = post_compile( + self.dispatch_wrappers, + compiled_function, + aot_config, + runtime_metadata=self.runtime_metadata, + ) + + return compiled_function + + +def _fake_tensor_from_meta(metadata: TensorMetadata): + """ + Given a fake tensor metadata, reconstruct the fake tensor. + This should be used only on TensorMetadata that was serialized/unserialized by AOTAutogradCache. + """ + # Synthesize a new FakeTensor with the cached metadata. + # Based around FakeTensor._output_from_cache_entry + assert not metadata.is_sparse + fake_mode = detect_fake_mode() + empty = torch.empty_strided( + metadata.shape, + metadata.stride, + dtype=metadata.dtype, + layout=metadata.layout, + device="meta", + requires_grad=metadata.requires_grad, + ) + + if metadata.is_conj: + torch._C._set_conj(empty, True) + if metadata.is_neg: + torch._C._set_neg(empty, True) + + # TODO: can traced tangents ever have a storage offset or storage bytes? + maybe_suppress: Callable[[], Any] = contextlib.nullcontext + if fake_mode is not None and fake_mode.shape_env is not None: + maybe_suppress = fake_mode.shape_env.suppress_guards + + if metadata.storage_offset != 0: + storage = empty.untyped_storage() + with in_kernel_invocation_manager(fake_mode), maybe_suppress(): + empty.set_( + storage, metadata.storage_offset, metadata.shape, metadata.stride + ) + if metadata.storage_bytes == 0: + empty.untyped_storage().resize_(0) + + return FakeTensorConverter().from_meta_and_device(fake_mode, empty, metadata.device) + + +def _reduce_fake_tensor(t): + """ + Allows us to serialize and deserialize FakeTensors, which show up in various metadata in our cache entries + """ + metadata = extract_tensor_metadata(t) + if metadata.is_sparse: + raise BypassAOTAutogradCache( + "Sparse tensors in the FW metadata are not yet supported" + ) + return (_fake_tensor_from_meta, (metadata,)) + + +# TODO: We don't actually need to pickle FakeTensors in the cache. This is done for +# traced_tangents in this PR, but once we handle traced_tangents properly in the PR above, +# we can remove this. +class AOTAutogradCacheEntryPickler(pickle.Pickler): + dispatch_table = copyreg.dispatch_table.copy() + dispatch_table[FakeTensor] = _reduce_fake_tensor + + @staticmethod + def dumps(obj) -> bytes: + """ + Pickle an object using the FxGraphCachePickler. + """ + with io.BytesIO() as stream: + pickler = AOTAutogradCacheEntryPickler(stream) + pickler.dump(obj) + return stream.getvalue() + + +class AOTAutogradCacheEntryUnpickler(pickle.Unpickler): + dispatch_table = copyreg.dispatch_table.copy() + dispatch_table[FakeTensor] = _reduce_fake_tensor + + +class AOTAutogradCache: + """ + Caches the results of running AOTAutograd. This class mostly handles the save and load logic, whereas + AOTAutogradCacheEntry handles the wrapping/unwrapping logic. + + Cache Inputs (AOTAutogradCacheDetails) + - AOTAutogradCache takes in the following inputs, which are analogous to inputs given + to AOTAutograd by dynamo: + - A fx graph module generated by dynamo + - A list of args, which consists of: + - Symint inputs to the graph, generated by dynamo + - The **real tensor** inputs, which inductor uses for cudagraphs + - Notably, the real tensor inputs don't have symints in their metadata. + AOTAutograd then retraces those real tensor arguments into FakeTensors later during execution. + - A set of global configurations that affect AOTAutograd or Inductor behavior. + + It then generates a cache key given these values. Notably, this means AOTAutogradCache currently + specializes on the sizes and strides of the real tensor inputs when dynamic shapes are turned on. + In a later PR, we'll likely generate the cache key based on the FakeTensors AOTAutograd generates + based on the real tensor inputs, which can contain symints. + + # Cache Outputs (AOTAutogradCacheEntry) + - AOTAutogradCache caches the following values: + - The compiled forward and backward functions from inductor, via keys to the FXGraphCache + - Metadata to reconstruct the AOTModule from the compiled inductor artifacts + - See AOTAutogradCacheEntry for more info + + [Note: Caching guards generated by AOTAutograd and Inductor] + AOTAutograd and inductor both can introduce new guards to the shape environment. FXGraphCache saves guards with each + compiled graph inductor generates. On a cache hit, AOTAutograd reloads the compiled forward and backward functions + from FXGraphCache, giving it new symint arguments from the input args. + FXGraphCache uses those symints and its saved guards to repopulate the ShapeEnv with guards. + **No new guards are generated into the shape env after inductor finishes compiling**, so the guards + saved by inductor are sufficient for correctness for both AOTAutograd and Inductor's caches. + """ + + @staticmethod + def clear(): + """Clear the cache""" + try: + shutil.rmtree(AOTAutogradCache._get_tmp_dir()) + except FileNotFoundError: + pass + + @staticmethod + def load( + dispatch_and_compile: Callable, + gm: torch.fx.GraphModule, + args, + aot_config: AOTConfig, + ) -> Callable: + """ + Load a result from the cache, and reconstruct a runtime wrapper around the object + """ + compiled_fn = None + cache_key = None + try: + cache_key = autograd_cache_key(gm, args, aot_config) + entry: Optional[AOTAutogradCacheEntry] = AOTAutogradCache._lookup(cache_key) + if entry is not None: + compiled_fn = entry.wrap_post_compile(args, aot_config) + log.info("AOTAutograd cache hit for key %s", cache_key) + counters["aot_autograd"]["autograd_cache_hit"] += 1 + if compiled_fn is None: + log.info("AOTAutograd cache miss for key %s", cache_key) + counters["aot_autograd"]["autograd_cache_miss"] += 1 + # Count missing the FXGraphCache as a miss not a bypass + except FXGraphCacheMiss: + counters["aot_autograd"]["autograd_cache_miss"] += 1 + except BypassAOTAutogradCache: + cache_key = None + counters["aot_autograd"]["autograd_cache_bypass"] += 1 + if compiled_fn is None: + # Set the cache key so we can save a cache result later + aot_config.cache_key = cache_key + compiled_fn = dispatch_and_compile() + return compiled_fn + + @staticmethod + def _get_tmp_dir() -> str: + """ + Get the toplevel temporary directory for storing compiled graphs. + """ + return os.path.join(cache_dir(), "aotautograd") + + @staticmethod + def _lookup(key: str) -> Optional[AOTAutogradCacheEntry]: + """Given a key generated by AOTAutogradCachePickler, look up its location in the cache.""" + subdir = os.path.join(AOTAutogradCache._get_tmp_dir(), key) + if not os.path.exists(subdir): + return None + path = os.path.join(subdir, "entry") + try: + with open(path, "rb") as f: + entry: AOTAutogradCacheEntry = AOTAutogradCacheEntryUnpickler(f).load() + return entry + except Exception as e: + log.warning("AOTAutograd cache unable to load compiled graph: %s", e) + return None + + @staticmethod + def save(key: str, entry: AOTAutogradCacheEntry): + """Save a single entry into the cache.""" + try: + content = AOTAutogradCacheEntryPickler.dumps(entry) + except Exception as e: + log.warning("AOTAutograd cache unable to serialize compiled graph: %s", e) + raise e + subdir = os.path.join(AOTAutogradCache._get_tmp_dir(), key) + if not os.path.exists(subdir): + os.makedirs(subdir, exist_ok=True) + path = os.path.join(subdir, "entry") + log.info("Writing AOTAutograd cache entry to %s", path) + write_atomic(path, content) + counters["aot_autograd"]["autograd_cache_saved"] += 1 diff --git a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py index 5eb681889d8a..2c5c80f45032 100644 --- a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py @@ -24,6 +24,12 @@ from torch.fx.experimental.proxy_tensor import is_sym_node from torch.fx.experimental.symbolic_shapes import fx_placeholder_vals from .. import config +from .autograd_cache import ( + AOTAutogradCache, + AOTAutogradCacheEntry, + CompiledBackward, + CompiledForward, +) from .dispatch_and_compile_graph import ( aot_dispatch_autograd_graph, aot_dispatch_base_graph, @@ -180,11 +186,25 @@ def aot_dispatch_base( compiled_fw = functionalized_rng_wrapper.post_compile( compiled_fw, aot_config, runtime_metadata=fw_metadata ) + if config.enable_autograd_cache and aot_config.cache_key: + if fw_key := getattr(compiled_fw, "_fx_graph_cache_key", None): + entry = AOTAutogradCacheEntry( + compiled_fw=CompiledForward(fw_key), + compiled_bw=None, + runtime_metadata=fw_metadata, + dispatch_wrappers=wrappers, + maybe_subclass_meta=maybe_subclass_meta, + num_fw_outs_saved_for_bw=None, + indices_of_inps_to_detach=[], + ) + AOTAutogradCache.save(aot_config.cache_key, entry) + compiled_fw = fakified_out_wrapper.post_compile( compiled_fw, aot_config, runtime_metadata=fw_metadata, ) + # Why do we need to pass in num_fw_outs_saved_for_bw? # See Note: [Partitioner handling for Subclasses, Part 2] compiled_fw_func = AOTDispatchSubclassWrapper( @@ -540,7 +560,9 @@ def aot_dispatch_autograd( placeholder_list[i] = ph_arg.as_strided(ph_arg.size(), real_stride) compiled_bw_func = None - if num_symints_saved_for_bw > 0: + if num_symints_saved_for_bw > 0 or ( + config.enable_autograd_cache and aot_config.cache_key + ): context = torch._C._DisableAutocast if disable_amp else nullcontext with context(): try: @@ -585,6 +607,23 @@ def aot_dispatch_autograd( saved_context, saved_compile_context, ) + if config.enable_autograd_cache and aot_config.cache_key: + fw_key = getattr(compiled_fw_func, "_fx_graph_cache_key", None) + bw_key = getattr(compiled_bw_func, "_fx_graph_cache_key", None) + + if fw_key and bw_key: + entry = AOTAutogradCacheEntry( + CompiledForward(fw_key), + CompiledBackward( + bw_key, backward_state_indices, num_symints_saved_for_bw + ), + fw_metadata, + wrappers, + maybe_subclass_meta, + num_fw_outs_saved_for_bw, + _indices_of_inps_to_detach, + ) + AOTAutogradCache.save(aot_config.cache_key, entry) compiled_fn = AOTDispatchAutograd.post_compile( compiled_fw_func, diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py index 0afa24ce4ee8..a98bb5e0128b 100644 --- a/torch/_functorch/_aot_autograd/runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py @@ -1769,7 +1769,11 @@ def backward(ctx, *flat_args): def call_compiled_backward(): if ctx._is_compiled_autograd_tracing(): - assert lazy_backward_info is not None + if lazy_backward_info is None: + raise RuntimeError( + """This compiled backward function was saved by AOTAutogradCache, which does not support + compiled autograd. Please turn off AOTAutogradCache using `ENABLE_AOT_AUTOGRAD_CACHE=0` to continue.""" + ) bw_module = lazy_backward_info.bw_module # For compiled autograd, run raw FX graph so that it can be inlined into the larger graph symints = ctx._get_compiled_autograd_symints() diff --git a/torch/_functorch/_aot_autograd/schemas.py b/torch/_functorch/_aot_autograd/schemas.py index d5588a6e912c..c2db24d3544b 100644 --- a/torch/_functorch/_aot_autograd/schemas.py +++ b/torch/_functorch/_aot_autograd/schemas.py @@ -709,6 +709,9 @@ class AOTConfig: # this is always false outside of export. pre_dispatch: bool = False + # Key to use for AOTAutogradCache + cache_key: Optional[str] = None + def __post_init__(self): if self.pre_dispatch: assert self.is_export, "Can only have pre_dispatch IR for export." diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index c52a9cde0d55..97ef00858216 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -21,6 +21,11 @@ from torch.fx.experimental.symbolic_shapes import ShapeEnv from torch.utils._python_dispatch import is_traceable_wrapper_subclass from . import config +from ._aot_autograd.autograd_cache import ( # noqa: F401 + AOTAutogradCache, + autograd_cache_key, +) + from ._aot_autograd.collect_metadata_analysis import ( # noqa: F401 run_functionalized_fw_and_collect_metadata, ) @@ -880,8 +885,6 @@ def aot_module_simplified( params_flat = list(params_flat) params_len = len(params_flat) - functional_call = create_functional_call(mod, params_spec, params_len) - if bw_compiler is None: bw_compiler = fw_compiler if inference_compiler is None: @@ -947,14 +950,24 @@ def aot_module_simplified( aot_autograd_arg_pos_to_source=aot_autograd_arg_pos_to_source, is_export=False, no_tangents=False, + cache_key=None, ) - with compiled_autograd.disable(): - compiled_fn, _ = create_aot_dispatcher_function( - functional_call, - full_args, - aot_config, - ) + def dispatch_and_compile(): + functional_call = create_functional_call(mod, params_spec, params_len) + with compiled_autograd.disable(): + compiled_fn, _ = create_aot_dispatcher_function( + functional_call, + full_args, + aot_config, + ) + return compiled_fn + + # Autograd cache stuff + if config.enable_autograd_cache: + compiled_fn = AOTAutogradCache.load(dispatch_and_compile, mod, args, aot_config) + else: + compiled_fn = dispatch_and_compile() if isinstance(mod, torch._dynamo.utils.GmWrapper): # This function is called by the flatten_graph_inputs wrapper, which boxes diff --git a/torch/_functorch/config.py b/torch/_functorch/config.py index 60bbf1f21c66..554907fd1be8 100644 --- a/torch/_functorch/config.py +++ b/torch/_functorch/config.py @@ -173,6 +173,9 @@ # Supported formats are defined here https://graphviz.org/docs/outputs/ torch_compile_graph_format = os.environ.get("TORCH_COMPILE_GRAPH_FORMAT", "svg") +enable_autograd_cache = os.environ.get("ENABLE_AOT_AUTOGRAD_CACHE", "0") == "1" + + if TYPE_CHECKING: from torch.utils._config_typing import * # noqa: F401, F403 diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 574511d004a4..208ad73cce5a 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -541,6 +541,9 @@ def get_str(obj) -> str: return str(extract_tensor_metadata_for_cache_key(obj)) elif isinstance(obj, bytes): return "" + elif type(obj) in cls.dispatch_table: + # Run the reducer on the object + return str(cls.dispatch_table[type(obj)](obj)[1]) else: return str(obj) @@ -785,8 +788,8 @@ def _get_shape_env() -> Optional[ShapeEnv]: def _lookup_graph( key: str, example_inputs: List[torch.Tensor], - local, - remote_cache, + local: bool, + remote_cache: Optional[Any], ) -> Optional[CompiledFxGraph]: """ Lookup a compiled graph in the cache by key. On a hit, return the @@ -1037,6 +1040,7 @@ def load( compiled_graph = FxGraphCache._lookup_graph( key, example_inputs, local, remote_cache ) + if compiled_graph is None: log.debug("fx graph cache miss for key %s", key) counters["inductor"]["fxgraph_cache_miss"] += 1 @@ -1054,6 +1058,7 @@ def load( else: log.debug("fx graph cache hit for key %s", key) counters["inductor"]["fxgraph_cache_hit"] += 1 + compiled_graph._fx_graph_cache_key = key except BypassFxGraphCache: counters["inductor"]["fxgraph_cache_bypass"] += 1 if not compiled_graph: @@ -1100,6 +1105,7 @@ class CompiledFxGraph: guards_expr: Optional[str] _boxed_call: Optional[bool] = None + _fx_graph_cache_key: Optional[str] = None def __init__( self, diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index f59fa34fe9c0..d25381a3927b 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -46,7 +46,6 @@ import sympy import torch -import torch._export import torch.utils._pytree as pytree from torch._dynamo.device_interface import get_interface_for_device from torch._dynamo.utils import detect_fake_mode @@ -1685,6 +1684,8 @@ def aoti_compile_with_persistent_cache( Compile the given function with persistent cache for AOTI eager mode. """ assert not dynamic, "Only support static shape for now" + from torch._export import aot_compile + type_to_torch_dtype = {int: torch.int32, float: torch.float, bool: torch.bool} supported_scalar_types = tuple(type_to_torch_dtype.keys()) flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs) @@ -1707,7 +1708,7 @@ def aoti_compile_with_persistent_cache( {"TORCHINDUCTOR_CACHE_DIR": persistent_cache_lib.absolute().as_posix()}, ): try: - kernel_lib_path = torch._export.aot_compile( + kernel_lib_path = aot_compile( f, args, kwargs, From 71f491554c182e78c59ac6ed3e35ae9efefc7b1e Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 12 Jun 2024 13:59:29 +0000 Subject: [PATCH 690/706] Revert "First version of AOTAutogradCache (#126791)" This reverts commit abc3eec22d38079bee855fbcb75da62a9558284c. Reverted https://github.com/pytorch/pytorch/pull/126791 on behalf of https://github.com/DanilBaibak due to The changes broke a number of linux jobs ([comment](https://github.com/pytorch/pytorch/pull/126791#issuecomment-2163081643)) --- test/dynamo/test_aot_autograd_cache.py | 259 +---------- .../_aot_autograd/autograd_cache.py | 409 +----------------- .../jit_compile_runtime_wrappers.py | 41 +- .../_aot_autograd/runtime_wrappers.py | 6 +- torch/_functorch/_aot_autograd/schemas.py | 3 - torch/_functorch/aot_autograd.py | 29 +- torch/_functorch/config.py | 3 - torch/_inductor/codecache.py | 10 +- torch/_inductor/utils.py | 5 +- 9 files changed, 24 insertions(+), 741 deletions(-) diff --git a/test/dynamo/test_aot_autograd_cache.py b/test/dynamo/test_aot_autograd_cache.py index c34b30b7191d..c57527d9f6cd 100644 --- a/test/dynamo/test_aot_autograd_cache.py +++ b/test/dynamo/test_aot_autograd_cache.py @@ -1,270 +1,19 @@ # Owner(s): ["module: dynamo"] -import os -import unittest - import torch import torch._dynamo import torch._dynamo.test_case import torch._functorch._aot_autograd -from torch._dynamo.utils import counters from torch._functorch import config as functorch_config from torch._functorch._aot_autograd.autograd_cache import ( - AOTAutogradCache, - autograd_cache_key, + autograd_cache_hash, BypassAOTAutogradCache, ) from torch._functorch._aot_autograd.schemas import AOTConfig from torch._inductor import config as inductor_config -from torch.testing._internal.common_cuda import SM80OrLater -from torch.testing._internal.common_device_type import largeTensorTest -from torch.testing._internal.common_utils import ( - instantiate_parametrized_tests, - parametrize, -) -from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU - - -@instantiate_parametrized_tests -class AOTAutogradCacheTests(torch._dynamo.test_case.TestCase): - def setUp(self): - """ - Reset all counters and caches before each unit test - """ - super().setUp() - counters.clear() - self._clear_all_caches() - - def _clear_all_caches(self): - """ - Clear every cache, including AOTAutogradCache and FXCache - """ - torch._inductor.codecache.FxGraphCache.clear() - AOTAutogradCache.clear() - self._clear_dynamo_and_codecache() - - def _clear_dynamo_and_codecache(self): - """ - Clear unrelated caches, like dynamo and PyCodeCache - """ - torch._dynamo.reset() - for m in torch._inductor.codecache.PyCodeCache.cache.values(): - os.remove(m.__file__) - torch._inductor.codecache.PyCodeCache.cache_clear() - - @inductor_config.patch("fx_graph_cache", True) - @functorch_config.patch({"enable_autograd_cache": True}) - def test_basic(self): - """ - Verify the interactions between FXGraphCache and AOTAutogradCache. - """ - - def fn(x, y): - return (x * 2, y @ y) - - a = torch.rand(25) - b = torch.rand(5, 5) - - compiled_fn = torch.compile(fn, backend="inductor") - - # A first call should miss in the cache. - self.assertEqual(fn(a, b), compiled_fn(a, b)) - self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) - self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0) - self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1) - - # A second call should hit. (First reset so in-memory guards - # don't prevent compilation). - self._clear_dynamo_and_codecache() - self.assertEqual(fn(a, b), compiled_fn(a, b)) - - self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) - self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1) - self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1) - - @inductor_config.patch("fx_graph_cache", True) - @functorch_config.patch({"enable_autograd_cache": True}) - def test_clear_fx_graph_cache(self): - """ - Verify the interactions between FXGraphCache and AOTAutogradCache. - """ - - def fn(x, y): - return (x * 2, y @ y) - - a = torch.rand(25) - b = torch.rand(5, 5) - - compiled_fn = torch.compile(fn, backend="inductor") - - # A first call should miss in the cache. - self.assertEqual(fn(a, b), compiled_fn(a, b)) - self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) - self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0) - self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1) - - # Clear FX graph cache: second call should also be a miss - self._clear_dynamo_and_codecache() - torch._inductor.codecache.FxGraphCache.clear() - self.assertEqual(fn(a, b), compiled_fn(a, b)) - self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2) - self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0) - # We save again into the cache - self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 2) - - @inductor_config.patch("fx_graph_cache", False) - @functorch_config.patch({"enable_autograd_cache": True}) - def test_fx_graph_cache_off(self): - """ - Should not use cache if FXGraphCache is not enabled - """ - - def fn(x, y): - return (x * 2, y @ y) - - a = torch.rand(25) - b = torch.rand(5, 5) - - compiled_fn = torch.compile(fn, backend="inductor") - - # A first call should miss in the cache. - self.assertEqual(fn(a, b), compiled_fn(a, b)) - self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1) - self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0) - self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 0) - - # Clear FX graph cache: second call should also be a miss - self._clear_dynamo_and_codecache() - - self.assertEqual(fn(a, b), compiled_fn(a, b)) - self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 2) - self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0) - self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 0) - - @inductor_config.patch("fx_graph_cache", True) - @functorch_config.patch({"enable_autograd_cache": True}) - def test_autograd_function(self): - """ - Tests autograd cache hits - """ - - def fn(a, b): - return a.sin() + b - - a = torch.randn(25, requires_grad=True) - b = torch.randn(25, requires_grad=True) - a2 = a.detach().clone().requires_grad_(True) - b2 = b.detach().clone().requires_grad_(True) - - compiled_fn = torch.compile(fn, backend="inductor") - - # A first call should miss in the cache. - self.assertEqual(fn(a, b), compiled_fn(a2, b2)) - fn(a, b).sum().backward() - compiled_fn(a2, b2).sum().backward() - self.assertEqual(a.grad, a2.grad) - self.assertEqual(b.grad, b2.grad) - - self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) - self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0) - self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1) - - # Reset all tensors - a = torch.randn(25, requires_grad=True) - b = torch.randn(25, requires_grad=True) - a2 = a.detach().clone().requires_grad_(True) - b2 = b.detach().clone().requires_grad_(True) - - # A second call should hit. (First reset so in-memory guards - # don't prevent compilation). - self._clear_dynamo_and_codecache() - self.assertEqual(fn(a, b), compiled_fn(a2, b2)) - fn(a, b).sum().backward() - compiled_fn(a2, b2).sum().backward() - self.assertEqual(a.grad, a2.grad) - self.assertEqual(b.grad, b2.grad) - - self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) - self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1) - self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1) - - @largeTensorTest("64GB", device=GPU_TYPE) - @parametrize("device", (GPU_TYPE,)) - @parametrize("dtype", (torch.float16, torch.bfloat16)) - @inductor_config.patch("fx_graph_cache", True) - @functorch_config.patch({"enable_autograd_cache": True}) - def test_autograd_inductor_guards(self, device, dtype): - """ - Tests that functions that would add inductor guards are cached properly - """ - if device == GPU_TYPE and not HAS_GPU: - raise unittest.SkipTest(f"requires {GPU_TYPE}") - if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater: - raise unittest.SkipTest("requires CUDA SM80 or later") - - def fn(x, y): - return (x + x, y + y) - - compiled_fn = torch.compile(fn, dynamic=True) - - # Iterate over different shapes, varying whether the total - # size is below or above int32. For each combination, we expect - # different guards around whether the symbolic sizes do or do - # not exceed int32. - shapes = ( - ((5, 6), (7, 8)), - ((5, 6), (47000, 47001)), - ((47000, 47001), (5, 6)), - ) - expected_hits = expected_misses = expected_saves = 0 - for a_shape, b_shape in shapes: - a = torch.rand(a_shape, device=device, dtype=dtype) - b = torch.rand(b_shape, device=device, dtype=dtype) - - # AVOID a dynamo reset here. We expect guards to have been - # added that will be violated with the new shape. We should - # see a recompilation (along with a cache miss). - res1 = compiled_fn(a, b) - # A first call should miss in the cache. - # NOTE: Currently, this cache miss is *not* due to guards, - # but instead because the AOTAutogradCache key calculation specializes on input shapes. - # Once we allow tensors with symints as part of the cache key calculation, it will - # instead cache miss because of guard failure. - expected_misses += 1 - expected_saves += 1 - self.assertEqual( - counters["aot_autograd"]["autograd_cache_miss"], expected_misses - ) - self.assertEqual( - counters["aot_autograd"]["autograd_cache_hit"], expected_hits - ) - self.assertEqual( - counters["aot_autograd"]["autograd_cache_saved"], expected_saves - ) - - # A second call should hit. (First reset so in-memory guards - # don't prevent compilation). - - # Now clear dynamo and we should see a cache hit - # This should populate guards to dynamo's cache, so that a subsequent run with a different - # shape will still trigger a second call to autograd_cache. - self._clear_dynamo_and_codecache() - res2 = compiled_fn(a, b) - expected_hits += 1 - self.assertEqual( - counters["aot_autograd"]["autograd_cache_miss"], expected_misses - ) - self.assertEqual( - counters["aot_autograd"]["autograd_cache_hit"], expected_hits - ) - self.assertEqual( - counters["aot_autograd"]["autograd_cache_saved"], expected_saves - ) - self.assertEqual(res1, res2) - - -@inductor_config.patch("fx_graph_cache", True) + + class AOTAutogradCachePicklerTests(torch._dynamo.test_case.TestCase): @property def device_type(self) -> str: @@ -308,7 +57,7 @@ def gen_cache_key(self, f, config, inputs=None): if inputs is None: inputs = [torch.ones(3)] _, fx_g, example_inputs = self._get_dynamo_output(f, *inputs) - return autograd_cache_key(fx_g, example_inputs, config) + return autograd_cache_hash(fx_g, example_inputs, config) def test_basic_hash_key(self): def fn(x): diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py index 9814ee9ec250..8144a47f057a 100644 --- a/torch/_functorch/_aot_autograd/autograd_cache.py +++ b/torch/_functorch/_aot_autograd/autograd_cache.py @@ -4,56 +4,22 @@ """ from __future__ import annotations -import contextlib -import copyreg - import functools -import io import logging import os -import pickle -import shutil - -from dataclasses import dataclass - -from typing import Any, Callable, List, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING import torch -from torch._dynamo.utils import counters from torch._functorch import config -from torch._guards import detect_fake_mode - from torch._inductor.codecache import ( _ident, BypassFxGraphCache, - CompiledFxGraph, - FxGraphCache, FxGraphCachePickler, FxGraphHashDetails, get_code_hash, - write_atomic, -) - -from torch._inductor.runtime.runtime_utils import cache_dir -from torch._subclasses.fake_tensor import ( - extract_tensor_metadata, - FakeTensor, - FakeTensorConverter, - in_kernel_invocation_manager, - TensorMetadata, ) -from .runtime_wrappers import ( - AOTDispatchAutograd, - AOTDispatchSubclassWrapper, - CompilerWrapper, - FunctionalizedRngRuntimeWrapper, - post_compile, - RuntimeWrapper, - SubclassMeta, -) - -from .schemas import AOTConfig, ViewAndMutationMeta # noqa: F401 +from .schemas import AOTConfig # noqa: F401 if TYPE_CHECKING: from torch.fx.node import Node @@ -65,11 +31,6 @@ class BypassAOTAutogradCache(Exception): pass -# Used to signify when FXGraphCache missed when AOTAutogradCache uses it -class FXGraphCacheMiss(BypassAOTAutogradCache): - pass - - def check_node_safe(node: Node): """ Checks that the node only uses supported operators. We are starting with very @@ -136,15 +97,6 @@ def check_cacheable(gm: torch.fx.GraphModule): raise BypassAOTAutogradCache( "Cannot cache a graph with compiled autograd enabled" ) - - if not torch._inductor.config.fx_graph_cache: - raise BypassAOTAutogradCache("FX graph cache is not enabled") - - tracing_context = torch._guards.TracingContext.try_get() - if tracing_context and tracing_context.fakify_first_call: - raise BypassAOTAutogradCache( - "Won't cache a graph with fakify_first_call enabled" - ) for node in nodes: check_node_safe(node) @@ -161,6 +113,7 @@ def __init__( example_inputs, aot_config: AOTConfig, ): + check_cacheable(gm) # FxGraphHashDetails contains all the keys related to inductor. Also includes some system info self.aot_config = aot_config self.grad_enabled = torch.is_grad_enabled() @@ -169,18 +122,7 @@ def __init__( self.code_hash = get_autograd_code_hash() self.autograd_config = config.save_config() try: - # We don't use FxGraphHashDetails to hash example_inputs because it expects - # example_inputs to always be FakeTensors, but at AOTAutograd's entry point, - # they're still regular. So instead we store their metadata here. - # TODO: this currently causes more cache misses than necessary - # with dynamic shapes, because this is before we add - # symints to tensor metadata. Improve this later. - self.example_input_metadata = [ - extract_tensor_metadata(t) - for t in example_inputs - if isinstance(t, torch.Tensor) - ] - super().__init__(gm, [], {}, []) + super().__init__(gm, example_inputs, {}, []) except BypassFxGraphCache as e: # Sometimes inductor configs are unpickleable and can fail raise BypassAOTAutogradCache from e @@ -213,7 +155,7 @@ class AOTAutogradCachePickler(FxGraphCachePickler): dispatch_table[AOTConfig] = _reduce_aot_config -def autograd_cache_key( +def autograd_cache_hash( gm: torch.fx.GraphModule, example_inputs, config: AOTConfig, @@ -222,347 +164,8 @@ def autograd_cache_key( """ Generate a unique hash of the FX graph for caching. """ - check_cacheable(gm) details = AOTAutogradCacheDetails(gm, example_inputs, config) # The prefix distinguishes among the other kinds of objects we cache key = "a" + AOTAutogradCachePickler.get_hash(details) - log.debug( - "Autograd graph cache hash details for key %s:\n%s", key, details.debug_str() - ) + log.debug("FX graph cache hash details for key %s:\n%s", key, details.debug_str()) return key - - -@dataclass -class FXGraphCacheLoadable: - fx_graph_cache_key: str - - def load(self, example_inputs) -> CompiledFxGraph: - # [Note: AOTAutogradCache and FXGraphCache Guard interactions] - # As mentioned, AOTAutograd takes in the symint inputs from dynamo's list of arguments. - # FXGraphCache serializes guards that are needed in the shape_env based on these symint inputs to the graph. - # he invariant that AOTAutograd uses here is that the sources for symints given to it by dynamo are exactly - # the same as the ones it passes to inductor, for both the forward and backward passes. - # (This does not mean that the tensor values passed in are the same: only that their symints are). - # That is, AOTAutograd and Inductor never create new guards based on symints with different sources - # than those passed to it by inductor. - result = FxGraphCache._lookup_graph( - self.fx_graph_cache_key, example_inputs, local=True, remote_cache=False - ) - if result is None: - log.info("FXGraphCache cache miss for key %s", self.fx_graph_cache_key) - raise FXGraphCacheMiss - result._boxed_call = True - return result - - -@dataclass -class CompiledForward(FXGraphCacheLoadable): - """ - Cacheable entry for a forward function - """ - - pass - - -@dataclass -class CompiledBackward(FXGraphCacheLoadable): - """ - Cacheable entry for a forward function - """ - - # Used by AOTDispatchAutograd.post_compile - backward_state_indices: List[int] - num_symints_saved_for_bw_: int - - -@dataclass -class AOTAutogradCacheEntry: - """A single entry into the cache.""" - - # Forward and Backward info - compiled_fw: CompiledForward - compiled_bw: Optional[CompiledBackward] - - # Runtime_metadata saved right before compilation - runtime_metadata: ViewAndMutationMeta - - # Wrappers that run after each aot_dispatch_* function - dispatch_wrappers: List[CompilerWrapper] - - # Used by AOTSubclassWrapper - maybe_subclass_meta: Optional[SubclassMeta] - num_fw_outs_saved_for_bw: Optional[int] - - # Used by RuntimeWrapepr - indices_of_inps_to_detach: List[int] - - # Turn cache entry into the original callable - def wrap_post_compile( - self, args: List[torch.Tensor], aot_config: AOTConfig - ) -> Callable: - """ - This function takes a cache entry and carefully reconstructs the original callable - that AOTAutograd returned the first time it was run. It does this by running the various - post compile steps that AOTAutograd runs on its compiled artifact after running the fw/bw compilers. - - In the inference path, this consists of the Subclass, FunctionalzedRngRuntime, and RuntimeWrappers. - In the autograd path, this consists of AOTAutogradDispatch.post_compile. - - The steps here should match exactly the steps that are run in aot_dispatch_base and aot_dispatch_autograd. - - Notably absent from the cached path are: - - DebugAssertWrapper - - FakifiedOutWrapper - - Which we'll handle separately later on, if necessary. - """ - compiled_fw_func = self.compiled_fw.load(args) - compiled_bw_func = None - if self.compiled_bw is not None: - compiled_bw_func = self.compiled_bw.load(args) - needs_autograd = True - else: - needs_autograd = False - - # Wrap the forward function in post compile wrappers - compiled_fw_func = AOTDispatchSubclassWrapper( - trace_joint=needs_autograd, - fw_only=None, - maybe_subclass_meta=self.maybe_subclass_meta, - num_fw_outs_saved_for_bw=self.num_fw_outs_saved_for_bw, - ).post_compile( - compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata - ) - - # In autograd case, functionalizedRngWrapper should not modify outs - return_new_outs = not needs_autograd - compiled_fw_func = FunctionalizedRngRuntimeWrapper( - return_new_outs=return_new_outs - ).post_compile( - compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata - ) - disable_amp = torch._C._is_any_autocast_enabled() - - if needs_autograd: - assert self.compiled_bw is not None - compiled_function = AOTDispatchAutograd.post_compile( - compiled_fw_func, - compiled_bw_func, - self.maybe_subclass_meta, - self.compiled_bw.num_symints_saved_for_bw_, - self.compiled_bw.backward_state_indices, - disable_amp, - self.indices_of_inps_to_detach, - None, # lazy_backward_info - aot_config, - fw_metadata=self.runtime_metadata, - ) - else: - compiled_function = RuntimeWrapper( - indices_of_inps_to_detach=self.indices_of_inps_to_detach, - trace_joint=False, - disable_amp=disable_amp, - ).post_compile( - compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata - ) - - compiled_function, _ = post_compile( - self.dispatch_wrappers, - compiled_function, - aot_config, - runtime_metadata=self.runtime_metadata, - ) - - return compiled_function - - -def _fake_tensor_from_meta(metadata: TensorMetadata): - """ - Given a fake tensor metadata, reconstruct the fake tensor. - This should be used only on TensorMetadata that was serialized/unserialized by AOTAutogradCache. - """ - # Synthesize a new FakeTensor with the cached metadata. - # Based around FakeTensor._output_from_cache_entry - assert not metadata.is_sparse - fake_mode = detect_fake_mode() - empty = torch.empty_strided( - metadata.shape, - metadata.stride, - dtype=metadata.dtype, - layout=metadata.layout, - device="meta", - requires_grad=metadata.requires_grad, - ) - - if metadata.is_conj: - torch._C._set_conj(empty, True) - if metadata.is_neg: - torch._C._set_neg(empty, True) - - # TODO: can traced tangents ever have a storage offset or storage bytes? - maybe_suppress: Callable[[], Any] = contextlib.nullcontext - if fake_mode is not None and fake_mode.shape_env is not None: - maybe_suppress = fake_mode.shape_env.suppress_guards - - if metadata.storage_offset != 0: - storage = empty.untyped_storage() - with in_kernel_invocation_manager(fake_mode), maybe_suppress(): - empty.set_( - storage, metadata.storage_offset, metadata.shape, metadata.stride - ) - if metadata.storage_bytes == 0: - empty.untyped_storage().resize_(0) - - return FakeTensorConverter().from_meta_and_device(fake_mode, empty, metadata.device) - - -def _reduce_fake_tensor(t): - """ - Allows us to serialize and deserialize FakeTensors, which show up in various metadata in our cache entries - """ - metadata = extract_tensor_metadata(t) - if metadata.is_sparse: - raise BypassAOTAutogradCache( - "Sparse tensors in the FW metadata are not yet supported" - ) - return (_fake_tensor_from_meta, (metadata,)) - - -# TODO: We don't actually need to pickle FakeTensors in the cache. This is done for -# traced_tangents in this PR, but once we handle traced_tangents properly in the PR above, -# we can remove this. -class AOTAutogradCacheEntryPickler(pickle.Pickler): - dispatch_table = copyreg.dispatch_table.copy() - dispatch_table[FakeTensor] = _reduce_fake_tensor - - @staticmethod - def dumps(obj) -> bytes: - """ - Pickle an object using the FxGraphCachePickler. - """ - with io.BytesIO() as stream: - pickler = AOTAutogradCacheEntryPickler(stream) - pickler.dump(obj) - return stream.getvalue() - - -class AOTAutogradCacheEntryUnpickler(pickle.Unpickler): - dispatch_table = copyreg.dispatch_table.copy() - dispatch_table[FakeTensor] = _reduce_fake_tensor - - -class AOTAutogradCache: - """ - Caches the results of running AOTAutograd. This class mostly handles the save and load logic, whereas - AOTAutogradCacheEntry handles the wrapping/unwrapping logic. - - Cache Inputs (AOTAutogradCacheDetails) - - AOTAutogradCache takes in the following inputs, which are analogous to inputs given - to AOTAutograd by dynamo: - - A fx graph module generated by dynamo - - A list of args, which consists of: - - Symint inputs to the graph, generated by dynamo - - The **real tensor** inputs, which inductor uses for cudagraphs - - Notably, the real tensor inputs don't have symints in their metadata. - AOTAutograd then retraces those real tensor arguments into FakeTensors later during execution. - - A set of global configurations that affect AOTAutograd or Inductor behavior. - - It then generates a cache key given these values. Notably, this means AOTAutogradCache currently - specializes on the sizes and strides of the real tensor inputs when dynamic shapes are turned on. - In a later PR, we'll likely generate the cache key based on the FakeTensors AOTAutograd generates - based on the real tensor inputs, which can contain symints. - - # Cache Outputs (AOTAutogradCacheEntry) - - AOTAutogradCache caches the following values: - - The compiled forward and backward functions from inductor, via keys to the FXGraphCache - - Metadata to reconstruct the AOTModule from the compiled inductor artifacts - - See AOTAutogradCacheEntry for more info - - [Note: Caching guards generated by AOTAutograd and Inductor] - AOTAutograd and inductor both can introduce new guards to the shape environment. FXGraphCache saves guards with each - compiled graph inductor generates. On a cache hit, AOTAutograd reloads the compiled forward and backward functions - from FXGraphCache, giving it new symint arguments from the input args. - FXGraphCache uses those symints and its saved guards to repopulate the ShapeEnv with guards. - **No new guards are generated into the shape env after inductor finishes compiling**, so the guards - saved by inductor are sufficient for correctness for both AOTAutograd and Inductor's caches. - """ - - @staticmethod - def clear(): - """Clear the cache""" - try: - shutil.rmtree(AOTAutogradCache._get_tmp_dir()) - except FileNotFoundError: - pass - - @staticmethod - def load( - dispatch_and_compile: Callable, - gm: torch.fx.GraphModule, - args, - aot_config: AOTConfig, - ) -> Callable: - """ - Load a result from the cache, and reconstruct a runtime wrapper around the object - """ - compiled_fn = None - cache_key = None - try: - cache_key = autograd_cache_key(gm, args, aot_config) - entry: Optional[AOTAutogradCacheEntry] = AOTAutogradCache._lookup(cache_key) - if entry is not None: - compiled_fn = entry.wrap_post_compile(args, aot_config) - log.info("AOTAutograd cache hit for key %s", cache_key) - counters["aot_autograd"]["autograd_cache_hit"] += 1 - if compiled_fn is None: - log.info("AOTAutograd cache miss for key %s", cache_key) - counters["aot_autograd"]["autograd_cache_miss"] += 1 - # Count missing the FXGraphCache as a miss not a bypass - except FXGraphCacheMiss: - counters["aot_autograd"]["autograd_cache_miss"] += 1 - except BypassAOTAutogradCache: - cache_key = None - counters["aot_autograd"]["autograd_cache_bypass"] += 1 - if compiled_fn is None: - # Set the cache key so we can save a cache result later - aot_config.cache_key = cache_key - compiled_fn = dispatch_and_compile() - return compiled_fn - - @staticmethod - def _get_tmp_dir() -> str: - """ - Get the toplevel temporary directory for storing compiled graphs. - """ - return os.path.join(cache_dir(), "aotautograd") - - @staticmethod - def _lookup(key: str) -> Optional[AOTAutogradCacheEntry]: - """Given a key generated by AOTAutogradCachePickler, look up its location in the cache.""" - subdir = os.path.join(AOTAutogradCache._get_tmp_dir(), key) - if not os.path.exists(subdir): - return None - path = os.path.join(subdir, "entry") - try: - with open(path, "rb") as f: - entry: AOTAutogradCacheEntry = AOTAutogradCacheEntryUnpickler(f).load() - return entry - except Exception as e: - log.warning("AOTAutograd cache unable to load compiled graph: %s", e) - return None - - @staticmethod - def save(key: str, entry: AOTAutogradCacheEntry): - """Save a single entry into the cache.""" - try: - content = AOTAutogradCacheEntryPickler.dumps(entry) - except Exception as e: - log.warning("AOTAutograd cache unable to serialize compiled graph: %s", e) - raise e - subdir = os.path.join(AOTAutogradCache._get_tmp_dir(), key) - if not os.path.exists(subdir): - os.makedirs(subdir, exist_ok=True) - path = os.path.join(subdir, "entry") - log.info("Writing AOTAutograd cache entry to %s", path) - write_atomic(path, content) - counters["aot_autograd"]["autograd_cache_saved"] += 1 diff --git a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py index 2c5c80f45032..5eb681889d8a 100644 --- a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py @@ -24,12 +24,6 @@ from torch.fx.experimental.proxy_tensor import is_sym_node from torch.fx.experimental.symbolic_shapes import fx_placeholder_vals from .. import config -from .autograd_cache import ( - AOTAutogradCache, - AOTAutogradCacheEntry, - CompiledBackward, - CompiledForward, -) from .dispatch_and_compile_graph import ( aot_dispatch_autograd_graph, aot_dispatch_base_graph, @@ -186,25 +180,11 @@ def aot_dispatch_base( compiled_fw = functionalized_rng_wrapper.post_compile( compiled_fw, aot_config, runtime_metadata=fw_metadata ) - if config.enable_autograd_cache and aot_config.cache_key: - if fw_key := getattr(compiled_fw, "_fx_graph_cache_key", None): - entry = AOTAutogradCacheEntry( - compiled_fw=CompiledForward(fw_key), - compiled_bw=None, - runtime_metadata=fw_metadata, - dispatch_wrappers=wrappers, - maybe_subclass_meta=maybe_subclass_meta, - num_fw_outs_saved_for_bw=None, - indices_of_inps_to_detach=[], - ) - AOTAutogradCache.save(aot_config.cache_key, entry) - compiled_fw = fakified_out_wrapper.post_compile( compiled_fw, aot_config, runtime_metadata=fw_metadata, ) - # Why do we need to pass in num_fw_outs_saved_for_bw? # See Note: [Partitioner handling for Subclasses, Part 2] compiled_fw_func = AOTDispatchSubclassWrapper( @@ -560,9 +540,7 @@ def aot_dispatch_autograd( placeholder_list[i] = ph_arg.as_strided(ph_arg.size(), real_stride) compiled_bw_func = None - if num_symints_saved_for_bw > 0 or ( - config.enable_autograd_cache and aot_config.cache_key - ): + if num_symints_saved_for_bw > 0: context = torch._C._DisableAutocast if disable_amp else nullcontext with context(): try: @@ -607,23 +585,6 @@ def aot_dispatch_autograd( saved_context, saved_compile_context, ) - if config.enable_autograd_cache and aot_config.cache_key: - fw_key = getattr(compiled_fw_func, "_fx_graph_cache_key", None) - bw_key = getattr(compiled_bw_func, "_fx_graph_cache_key", None) - - if fw_key and bw_key: - entry = AOTAutogradCacheEntry( - CompiledForward(fw_key), - CompiledBackward( - bw_key, backward_state_indices, num_symints_saved_for_bw - ), - fw_metadata, - wrappers, - maybe_subclass_meta, - num_fw_outs_saved_for_bw, - _indices_of_inps_to_detach, - ) - AOTAutogradCache.save(aot_config.cache_key, entry) compiled_fn = AOTDispatchAutograd.post_compile( compiled_fw_func, diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py index a98bb5e0128b..0afa24ce4ee8 100644 --- a/torch/_functorch/_aot_autograd/runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py @@ -1769,11 +1769,7 @@ def backward(ctx, *flat_args): def call_compiled_backward(): if ctx._is_compiled_autograd_tracing(): - if lazy_backward_info is None: - raise RuntimeError( - """This compiled backward function was saved by AOTAutogradCache, which does not support - compiled autograd. Please turn off AOTAutogradCache using `ENABLE_AOT_AUTOGRAD_CACHE=0` to continue.""" - ) + assert lazy_backward_info is not None bw_module = lazy_backward_info.bw_module # For compiled autograd, run raw FX graph so that it can be inlined into the larger graph symints = ctx._get_compiled_autograd_symints() diff --git a/torch/_functorch/_aot_autograd/schemas.py b/torch/_functorch/_aot_autograd/schemas.py index c2db24d3544b..d5588a6e912c 100644 --- a/torch/_functorch/_aot_autograd/schemas.py +++ b/torch/_functorch/_aot_autograd/schemas.py @@ -709,9 +709,6 @@ class AOTConfig: # this is always false outside of export. pre_dispatch: bool = False - # Key to use for AOTAutogradCache - cache_key: Optional[str] = None - def __post_init__(self): if self.pre_dispatch: assert self.is_export, "Can only have pre_dispatch IR for export." diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index 97ef00858216..c52a9cde0d55 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -21,11 +21,6 @@ from torch.fx.experimental.symbolic_shapes import ShapeEnv from torch.utils._python_dispatch import is_traceable_wrapper_subclass from . import config -from ._aot_autograd.autograd_cache import ( # noqa: F401 - AOTAutogradCache, - autograd_cache_key, -) - from ._aot_autograd.collect_metadata_analysis import ( # noqa: F401 run_functionalized_fw_and_collect_metadata, ) @@ -885,6 +880,8 @@ def aot_module_simplified( params_flat = list(params_flat) params_len = len(params_flat) + functional_call = create_functional_call(mod, params_spec, params_len) + if bw_compiler is None: bw_compiler = fw_compiler if inference_compiler is None: @@ -950,24 +947,14 @@ def aot_module_simplified( aot_autograd_arg_pos_to_source=aot_autograd_arg_pos_to_source, is_export=False, no_tangents=False, - cache_key=None, ) - def dispatch_and_compile(): - functional_call = create_functional_call(mod, params_spec, params_len) - with compiled_autograd.disable(): - compiled_fn, _ = create_aot_dispatcher_function( - functional_call, - full_args, - aot_config, - ) - return compiled_fn - - # Autograd cache stuff - if config.enable_autograd_cache: - compiled_fn = AOTAutogradCache.load(dispatch_and_compile, mod, args, aot_config) - else: - compiled_fn = dispatch_and_compile() + with compiled_autograd.disable(): + compiled_fn, _ = create_aot_dispatcher_function( + functional_call, + full_args, + aot_config, + ) if isinstance(mod, torch._dynamo.utils.GmWrapper): # This function is called by the flatten_graph_inputs wrapper, which boxes diff --git a/torch/_functorch/config.py b/torch/_functorch/config.py index 554907fd1be8..60bbf1f21c66 100644 --- a/torch/_functorch/config.py +++ b/torch/_functorch/config.py @@ -173,9 +173,6 @@ # Supported formats are defined here https://graphviz.org/docs/outputs/ torch_compile_graph_format = os.environ.get("TORCH_COMPILE_GRAPH_FORMAT", "svg") -enable_autograd_cache = os.environ.get("ENABLE_AOT_AUTOGRAD_CACHE", "0") == "1" - - if TYPE_CHECKING: from torch.utils._config_typing import * # noqa: F401, F403 diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 208ad73cce5a..574511d004a4 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -541,9 +541,6 @@ def get_str(obj) -> str: return str(extract_tensor_metadata_for_cache_key(obj)) elif isinstance(obj, bytes): return "" - elif type(obj) in cls.dispatch_table: - # Run the reducer on the object - return str(cls.dispatch_table[type(obj)](obj)[1]) else: return str(obj) @@ -788,8 +785,8 @@ def _get_shape_env() -> Optional[ShapeEnv]: def _lookup_graph( key: str, example_inputs: List[torch.Tensor], - local: bool, - remote_cache: Optional[Any], + local, + remote_cache, ) -> Optional[CompiledFxGraph]: """ Lookup a compiled graph in the cache by key. On a hit, return the @@ -1040,7 +1037,6 @@ def load( compiled_graph = FxGraphCache._lookup_graph( key, example_inputs, local, remote_cache ) - if compiled_graph is None: log.debug("fx graph cache miss for key %s", key) counters["inductor"]["fxgraph_cache_miss"] += 1 @@ -1058,7 +1054,6 @@ def load( else: log.debug("fx graph cache hit for key %s", key) counters["inductor"]["fxgraph_cache_hit"] += 1 - compiled_graph._fx_graph_cache_key = key except BypassFxGraphCache: counters["inductor"]["fxgraph_cache_bypass"] += 1 if not compiled_graph: @@ -1105,7 +1100,6 @@ class CompiledFxGraph: guards_expr: Optional[str] _boxed_call: Optional[bool] = None - _fx_graph_cache_key: Optional[str] = None def __init__( self, diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index d25381a3927b..f59fa34fe9c0 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -46,6 +46,7 @@ import sympy import torch +import torch._export import torch.utils._pytree as pytree from torch._dynamo.device_interface import get_interface_for_device from torch._dynamo.utils import detect_fake_mode @@ -1684,8 +1685,6 @@ def aoti_compile_with_persistent_cache( Compile the given function with persistent cache for AOTI eager mode. """ assert not dynamic, "Only support static shape for now" - from torch._export import aot_compile - type_to_torch_dtype = {int: torch.int32, float: torch.float, bool: torch.bool} supported_scalar_types = tuple(type_to_torch_dtype.keys()) flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs) @@ -1708,7 +1707,7 @@ def aoti_compile_with_persistent_cache( {"TORCHINDUCTOR_CACHE_DIR": persistent_cache_lib.absolute().as_posix()}, ): try: - kernel_lib_path = aot_compile( + kernel_lib_path = torch._export.aot_compile( f, args, kwargs, From 5ef70faaa76364a73cd7f9da2d3f8e23da218b02 Mon Sep 17 00:00:00 2001 From: "Wu, Chunyuan" Date: Wed, 12 Jun 2024 05:55:27 +0000 Subject: [PATCH 691/706] Revert "Make torch_geometric models compatible with export (#123403)" (#128377) This reverts commit d78991a7381adb3df5e9b63c365db4506643edce. This PR reverts https://github.com/pytorch/pytorch/pull/123403 to fix the performance regression as discussed in https://github.com/pytorch/pytorch/issues/127513#issuecomment-2158835653. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128377 Approved by: https://github.com/jgong5, https://github.com/angelayi, https://github.com/desertfire --- benchmarks/dynamo/common.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index 39c3a3cda3e3..154651d4fbb7 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -1184,12 +1184,14 @@ def load(cls, model, example_inputs, device): else: _register_dataclass_output_as_pytree(example_outputs) - gm = torch.export._trace._export( + # TODO(angelayi): change this to predispatch + # https://github.com/pytorch/pytorch/issues/127513 needs to be fixed before changing + # to predispatch to avoid performance regressions + gm = torch.export._trace._export_to_torch_ir( model, example_args, example_kwargs, - pre_dispatch=True, - ).module() + ) with torch.no_grad(): so_path = torch._inductor.aot_compile( gm, example_args, example_kwargs From 15ab636007223803b7d77c14b881ea337ce75a32 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 12 Jun 2024 14:55:56 +0000 Subject: [PATCH 692/706] Revert "Fix side effect pruning (#128028)" This reverts commit a55d0d9718c11eb2897423c78eff18b168dd0a06. Reverted https://github.com/pytorch/pytorch/pull/128028 on behalf of https://github.com/clee2000 due to broke test in internal D58443816. Test exists in external too though ([comment](https://github.com/pytorch/pytorch/pull/128028#issuecomment-2163249251)) --- test/dynamo/test_higher_order_ops.py | 40 +++++++++++++++---------- torch/_dynamo/side_effects.py | 44 ++++++++-------------------- 2 files changed, 36 insertions(+), 48 deletions(-) diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index 410317d33a14..c934cf55e8f5 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -3,6 +3,7 @@ import functools import pprint import re +import sys import unittest import warnings @@ -2859,7 +2860,7 @@ def forward(self, L_x_: "f32[4, 3]"): _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim_1) - _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim_1], retain_graph = True, create_graph = True); o = diff_primals = _add_batch_dim_1 = None + _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim_1], retain_graph = True, create_graph = True); _add_batch_dim_1 = None batched_outputs = _autograd_grad[0]; _autograd_grad = None chunked_result = torch._C._functorch._remove_batch_dim(batched_outputs, 3, 12, 0); batched_outputs = None @@ -2895,7 +2896,7 @@ def forward(self, L_x_: "f32[4, 3]"): jac_out_in: "f32[4, 3, 4, 3, 12]" = split_2[0]; split_2 = None unflatten: "f32[4, 3, 4, 3, 4, 3]" = jac_out_in.unflatten(-1, (4, 3)); jac_out_in = None - return (unflatten,) + return (unflatten, diff_primals, o) """, ) @@ -2963,8 +2964,8 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): _saved_tensors_hooks_disable_2 = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.") _grad_increment_nesting = torch._C._functorch._grad_increment_nesting() - _wrap_for_grad_2 = torch._C._functorch._wrap_for_grad(child_2, 3); child_2 = None - child_4 = torch._C._functorch._wrap_for_grad(child_3, 3); child_3 = None + _wrap_for_grad_2 = torch._C._functorch._wrap_for_grad(child_2, 3) + child_4 = torch._C._functorch._wrap_for_grad(child_3, 3) set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True) @@ -3001,7 +3002,7 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim_1) - _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [child_4], [_add_batch_dim_1], retain_graph = True, create_graph = True); o = child_4 = _add_batch_dim_1 = None + _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [child_4], [_add_batch_dim_1], retain_graph = True, create_graph = True); _add_batch_dim_1 = None child_5 = _autograd_grad[0]; _autograd_grad = None child_6 = torch._C._functorch._remove_batch_dim(child_5, 3, 12, 0); child_5 = None @@ -3040,10 +3041,17 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): unflatten: "f32[4, 3, 3, 4, 3, 4]" = jac_out_in.unflatten(-1, (3, 4)); jac_out_in = None""", ) - self.assertExpectedInline( - actual.split("\n")[-2], - """ return (unflatten,)""", - ) + # Python 3.10 and 3.11 produces slightly different graphs + if sys.version_info[:2] > (3, 10): + self.assertExpectedInline( + actual.split("\n")[-2], + """ return (unflatten, child_2, _wrap_for_grad_1, child_3, child_4, o)""", + ) + else: + self.assertExpectedInline( + actual.split("\n")[-2], + """ return (unflatten, child_3, child_2, _wrap_for_grad_1, child_4, o)""", + ) @unittest.expectedFailure def test_hessian_disable_capture(self): @@ -3152,7 +3160,7 @@ def forward(self, L_x_: "f32[4, 3]"): _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim) - _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); o = diff_primals = _add_batch_dim = None + _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); _add_batch_dim = None batched_outputs = _autograd_grad[0]; _autograd_grad = None chunked_result: "f32[12, 4, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 12, 0); batched_outputs = None @@ -3164,7 +3172,7 @@ def forward(self, L_x_: "f32[4, 3]"): split_1: "f32[12, 4, 3]" = split[0]; split = None output_input: "f32[4, 3, 4, 3]" = split_1.view((4, 3, 4, 3)); split_1 = None - return (output_input,) + return (output_input, diff_primals, o) """, ) @@ -3235,7 +3243,7 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim) - _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); o = diff_primals = _add_batch_dim = None + _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); _add_batch_dim = None batched_outputs = _autograd_grad[0]; _autograd_grad = None chunked_result: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 12, 0); batched_outputs = None @@ -3247,7 +3255,7 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): split_1: "f32[12, 3, 4]" = split[0]; split = None output_input: "f32[3, 4, 3, 4]" = split_1.view((3, 4, 3, 4)); split_1 = None - return (output_input,) + return (output_input, diff_primals, o) """, ) @@ -3320,7 +3328,7 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim) - _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); o = diff_primals = _add_batch_dim = None + _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); _add_batch_dim = None batched_outputs = _autograd_grad[0]; _autograd_grad = None chunked_result: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 12, 0); batched_outputs = None @@ -3332,7 +3340,7 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): split_1: "f32[12, 3, 4]" = split[0]; split = None output_input: "f32[3, 4, 3, 4]" = split_1.view((3, 4, 3, 4)); split_1 = None - return (output_input, aux_1) + return (output_input, aux_1, diff_primals, o) """, ) @@ -3768,7 +3776,7 @@ def forward(self, L_x_: "f32[3, 3, 3]"): _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting() _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() - return (y, grad_input_1) + return (grad_input_1, y) """, ) diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index c3d23728093a..4072f7641f84 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -301,25 +301,14 @@ def track_tensor_variables_from_runahead_side_effects(self, other): def prune_dead_object_new(self, tx): live_new_objects = set() - - # use this to avoid cycles in mutable_local (though I'm not sure if that - # can actually happen). - visited: Any = set({}) + skip_obj = None def visit(var: VariableTracker): - if isinstance(var.mutable_local, AttributeMutationNew): - if var in visited: - return - visited.add(var) - # Object may have been mutated, store this mutation. + if ( + isinstance(var.mutable_local, AttributeMutationNew) + and var.mutable_local is not skip_obj + ): live_new_objects.add(var.mutable_local) - # It's possible that we have mutated the value of this variable - # to be another one. The new value is in store_attr_mutations. - # Also recurse through the new value to detect alive AttributeMutationNew. - if var.mutable_local in self.store_attr_mutations: - VariableTracker.visit( - visit, self.store_attr_mutations[var.mutable_local] - ) def is_live(var: Union[MutableLocalBase, VariableTracker]): if isinstance(var, AttributeMutationNew): @@ -328,22 +317,13 @@ def is_live(var: Union[MutableLocalBase, VariableTracker]): return is_live(var.mutable_local) return True - pre_existing_vars = [ - var - for var in self.id_to_variable.values() - if not isinstance(var.mutable_local, AttributeMutationNew) - ] - - # The only live side effects come from returns (tx.stack), any intermediates - # during a graph break (tx.symbolic_locals), and mutation on pre-existing variables. - # Recursively visit Variables and see if any of them have been mutated. - VariableTracker.visit(visit, (tx.stack, tx.symbolic_locals, pre_existing_vars)) - - # NB: cell variable handling.is tricky. - # cell variables must stay alive if any NestedUserFunctionVariable - # are live. "visit"-ing the NestedUserFunctionVariable visits - # the .closures field, from which we will see if we need to keep - # any mutations to cell variables alive. + VariableTracker.visit(visit, (tx.stack, tx.symbolic_locals)) + for var in self.id_to_variable.values(): + if not isinstance(var.mutable_local, AttributeMutationNew): + VariableTracker.visit(visit, var) + + for skip_obj, setattrs in self.store_attr_mutations.items(): + VariableTracker.visit(visit, setattrs) self.id_to_variable = { k: v for k, v in self.id_to_variable.items() if is_live(v) From 3c971d2ef30bb07c460d40b9d983acc42b96cb74 Mon Sep 17 00:00:00 2001 From: Aaron Orenstein Date: Wed, 12 Jun 2024 08:06:38 -0700 Subject: [PATCH 693/706] Flip default value for mypy disallow_untyped_defs [final] (#127836) Not requiring all functions to have types allows a lot of 'Any' types to slip in - which poison types and make mypy unable to properly typecheck the code. I want to flip the default so that new files are required to have fully typed defs and we can have a burndown list of files that fail to require full types. The preceding stack of PRs (cut up simply to limit the number of file changes per PR "reasonable") adds `# mypy: allow-untyped-defs` to any file which didn't immediately pass mypy with the flag flipped. Due to changing files and merge conflicts it will probably be necessary to have several passes through before landing this final PR which turns the option on. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127836 Approved by: https://github.com/oulgen, https://github.com/Skylion007 --- .github/scripts/lintrunner.sh | 1 + mypy.ini | 1 + test/test_utils.py | 8 +++++--- test/typing/pass/cuda_steam.py | 2 +- torch/_C/return_types.pyi.in | 1 + .../_tensor/examples/display_sharding_example.py | 4 +++- torch/distributed/fsdp/_flat_param.py | 2 +- torch/nn/functional.pyi.in | 1 + torch/nn/parallel/distributed.py | 2 +- torch/utils/data/datapipes/datapipe.pyi.in | 1 + 10 files changed, 16 insertions(+), 7 deletions(-) diff --git a/.github/scripts/lintrunner.sh b/.github/scripts/lintrunner.sh index 50a04ef487a0..ae3c203cf70f 100755 --- a/.github/scripts/lintrunner.sh +++ b/.github/scripts/lintrunner.sh @@ -29,6 +29,7 @@ python3 -m tools.pyi.gen_pyi \ --native-functions-path aten/src/ATen/native/native_functions.yaml \ --tags-path aten/src/ATen/native/tags.yaml \ --deprecated-functions-path "tools/autograd/deprecated.yaml" +python3 torch/utils/data/datapipes/gen_pyi.py RC=0 # Run lintrunner on all files diff --git a/mypy.ini b/mypy.ini index 7d51847da44b..c4fef0f5ba6f 100644 --- a/mypy.ini +++ b/mypy.ini @@ -11,6 +11,7 @@ warn_redundant_casts = True show_error_codes = True show_column_numbers = True check_untyped_defs = True +disallow_untyped_defs = True follow_imports = normal local_partial_types = True enable_error_code = possibly-undefined diff --git a/test/test_utils.py b/test/test_utils.py index b0435e548311..df41b9b538be 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1178,9 +1178,11 @@ def test_device_mode_ops(self, device, dtype, op): kwargs.pop("device", None) with torch.device("meta"): r = func(sample.input, *sample.args, **kwargs) - self.assertTrue( - tree_all_only(torch.Tensor, lambda x: x.device.type == "meta", r) - ) + + def is_meta_device(x: torch.Tensor) -> bool: + return x.device.type == "meta" + + self.assertTrue(tree_all_only(torch.Tensor, is_meta_device, r)) instantiate_device_type_tests(TestDeviceUtils, globals()) diff --git a/test/typing/pass/cuda_steam.py b/test/typing/pass/cuda_steam.py index 0953effebbc2..bf9a40481b16 100644 --- a/test/typing/pass/cuda_steam.py +++ b/test/typing/pass/cuda_steam.py @@ -1,6 +1,6 @@ import torch -def foo(x: torch.Tensor): +def foo(x: torch.Tensor) -> None: stream = torch.cuda.current_stream() x.record_stream(stream) diff --git a/torch/_C/return_types.pyi.in b/torch/_C/return_types.pyi.in index 458a076d7bfe..fc1e2974bd4d 100644 --- a/torch/_C/return_types.pyi.in +++ b/torch/_C/return_types.pyi.in @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # ${generated_comment} from typing import ( diff --git a/torch/distributed/_tensor/examples/display_sharding_example.py b/torch/distributed/_tensor/examples/display_sharding_example.py index 0e32ed074534..4a0eb113e9c3 100644 --- a/torch/distributed/_tensor/examples/display_sharding_example.py +++ b/torch/distributed/_tensor/examples/display_sharding_example.py @@ -1,4 +1,4 @@ -import os +# mypy: allow-untyped-defs from typing import Any, Dict import torch @@ -168,6 +168,8 @@ def run_example(world_size, rank): if __name__ == "__main__": # this script is launched via torchrun which automatically manages ProcessGroup + import os + rank = int(os.environ["RANK"]) world_size = int(os.environ["WORLD_SIZE"]) assert world_size == 4 # our example uses 4 worker ranks diff --git a/torch/distributed/fsdp/_flat_param.py b/torch/distributed/fsdp/_flat_param.py index f3e918349af7..816b91433063 100644 --- a/torch/distributed/fsdp/_flat_param.py +++ b/torch/distributed/fsdp/_flat_param.py @@ -1141,7 +1141,7 @@ def shard_metadata( tuple(fqns_list), tuple(shapes_list), tuple(numels_list), - shard_param_offsets, + tuple(shard_param_offsets), ) @no_type_check diff --git a/torch/nn/functional.pyi.in b/torch/nn/functional.pyi.in index 5bb847a0a727..9dec24809e24 100644 --- a/torch/nn/functional.pyi.in +++ b/torch/nn/functional.pyi.in @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import ( Any, Callable, diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py index 34c593cd2c14..80ed52d9a0b6 100644 --- a/torch/nn/parallel/distributed.py +++ b/torch/nn/parallel/distributed.py @@ -804,7 +804,7 @@ def __init__( ) # Initialize gradient buffers and register all reduce hook - self._delay_grad_buffer = None + self._delay_grad_buffer: Optional[torch.Tensor] = None self._delay_grad_views: List[torch.Tensor] = [] self._delay_all_reduce_all_params = False if len(self._delay_all_reduce_params) != 0: diff --git a/torch/utils/data/datapipes/datapipe.pyi.in b/torch/utils/data/datapipes/datapipe.pyi.in index 6b3cbe34b46a..4d03665d5d66 100644 --- a/torch/utils/data/datapipes/datapipe.pyi.in +++ b/torch/utils/data/datapipes/datapipe.pyi.in @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # This base template ("datapipe.pyi.in") is generated from mypy stubgen with minimal editing for code injection # The output file will be "datapipe.pyi". This is executed as part of torch/CMakeLists.txt # Note that, for mypy, .pyi file takes precedent over .py file, such that we must define the interface for other From b19c2319e46d4d5dc01ae1dd166070cd103b0e47 Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Wed, 12 Jun 2024 15:53:39 +0000 Subject: [PATCH 694/706] [ROCm] TunableOp for gemm_and_bias (#128143) Thus far TunableOp was implemented for gemm, bgemm, and scaled_mm. gemm_and_bias was notably missing. This PR closes that gap. This PR also fixes a regression after #124362 disabled the numerical check by default. The env var to enable it no longer worked. CC @xw285cornell Pull Request resolved: https://github.com/pytorch/pytorch/pull/128143 Approved by: https://github.com/Skylion007 --- aten/src/ATen/cuda/tunable/GemmCommon.h | 76 +++++++++++- aten/src/ATen/cuda/tunable/GemmHipblaslt.h | 133 +++++++++++++++++---- aten/src/ATen/cuda/tunable/Tunable.cpp | 4 +- aten/src/ATen/cuda/tunable/TunableGemm.h | 68 ++++++++++- aten/src/ATen/native/cuda/Blas.cpp | 63 ++++++++-- 5 files changed, 306 insertions(+), 38 deletions(-) diff --git a/aten/src/ATen/cuda/tunable/GemmCommon.h b/aten/src/ATen/cuda/tunable/GemmCommon.h index a2c7c734a551..64a482bc2781 100644 --- a/aten/src/ATen/cuda/tunable/GemmCommon.h +++ b/aten/src/ATen/cuda/tunable/GemmCommon.h @@ -81,7 +81,8 @@ struct GemmParams : OpParams { } std::string Signature() const override { - return c10::str(transa, transb, "_", m, "_", n, "_", k); + static std::string val = c10::str(transa, transb, "_", m, "_", n, "_", k); + return val; } size_t GetSize(bool duplicate_inputs) const { @@ -143,6 +144,73 @@ struct GemmParams : OpParams { bool duplicate_inputs_; }; +template +struct GemmAndBiasParams : OpParams { + std::string Signature() const override { + static std::string val = c10::str(transa, transb, "_", m, "_", n, "_", k); + return val; + } + + size_t GetSize(bool duplicate_inputs) const { + size_t size = sizeof(T) * ldc * n; + if (duplicate_inputs) { + size += sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m); + size += sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k); + } + return size; + } + + GemmAndBiasParams* DeepCopy(bool duplicate_inputs) const { + GemmAndBiasParams* copy = new GemmAndBiasParams; + *copy = *this; + c10::DeviceIndex device = 0; + AT_CUDA_CHECK(c10::cuda::GetDevice(&device)); + size_t c_size = ldc * n * sizeof(T); + copy->c = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(c_size)); + AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync( + copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true)); + if (duplicate_inputs) { + size_t a_size = sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m); + size_t b_size = sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k); + copy->a = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(a_size)); + copy->b = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(b_size)); + copy->duplicate_inputs_ = true; + } + return copy; + } + + // only call on object returned by DeepCopy + void Delete() { + c10::cuda::CUDACachingAllocator::raw_delete(c); + if (duplicate_inputs_) { + c10::cuda::CUDACachingAllocator::raw_delete(const_cast(a)); + c10::cuda::CUDACachingAllocator::raw_delete(const_cast(b)); + } + } + + TuningStatus NumericalCheck(GemmAndBiasParams *other) { + auto c_dtype = c10::CppTypeToScalarType::value; + return detail::NumericalCheck(c_dtype, c, other->c, ldc*n) ? OK : FAIL; + } + + char transa; + char transb; + int64_t m; + int64_t n; + int64_t k; + at::opmath_type alpha; + const T* a; + int64_t lda; + const T* b; + int64_t ldb; + T* c; + int64_t ldc; + const T* bias; + at::cuda::blas::GEMMAndBiasActivationEpilogue activation; +private: + bool duplicate_inputs_; +}; + template struct GemmStridedBatchedParams : OpParams { GemmStridedBatchedParams() { @@ -150,7 +218,8 @@ struct GemmStridedBatchedParams : OpParams { } std::string Signature() const override { - return c10::str(transa, transb, "_", m, "_", n, "_", k, "_B_", batch); + static std::string val = c10::str(transa, transb, "_", m, "_", n, "_", k, "_B_", batch); + return val; } size_t GetSize(bool duplicate_inputs) const { @@ -223,7 +292,8 @@ struct ScaledGemmParams : OpParams { } std::string Signature() const override { - return c10::str(transa, transb, "_", m, "_", n, "_", k); + static std::string val = c10::str(transa, transb, "_", m, "_", n, "_", k); + return val; } size_t GetSize(bool duplicate_inputs) const { diff --git a/aten/src/ATen/cuda/tunable/GemmHipblaslt.h b/aten/src/ATen/cuda/tunable/GemmHipblaslt.h index a9c420700275..ab1525bef652 100644 --- a/aten/src/ATen/cuda/tunable/GemmHipblaslt.h +++ b/aten/src/ATen/cuda/tunable/GemmHipblaslt.h @@ -25,35 +25,35 @@ namespace at::cuda::tunable { template -constexpr hipblasDatatype_t HipBlasDataTypeFor(); +constexpr hipblasDatatype_t HipDataTypeFor(); template <> -constexpr hipblasDatatype_t HipBlasDataTypeFor() { - return HIPBLAS_R_32F; +constexpr hipblasDatatype_t HipDataTypeFor() { + return HIP_R_32F; } template <> -constexpr hipblasDatatype_t HipBlasDataTypeFor() { - return HIPBLAS_R_16F; +constexpr hipblasDatatype_t HipDataTypeFor() { + return HIP_R_16F; } template <> -constexpr hipblasDatatype_t HipBlasDataTypeFor() { - return HIPBLAS_R_16B; +constexpr hipblasDatatype_t HipDataTypeFor() { + return HIP_R_16BF; } template <> -constexpr hipblasDatatype_t HipBlasDataTypeFor() { - return HIPBLAS_R_64F; +constexpr hipblasDatatype_t HipDataTypeFor() { + return HIP_R_64F; } template <> -constexpr hipblasDatatype_t HipBlasDataTypeFor() { +constexpr hipblasDatatype_t HipDataTypeFor() { return HIP_R_8F_E4M3_FNUZ; } template <> -constexpr hipblasDatatype_t HipBlasDataTypeFor() { +constexpr hipblasDatatype_t HipDataTypeFor() { return HIP_R_8F_E5M2_FNUZ; } @@ -62,6 +62,11 @@ int GetBatchFromParams(const GemmParams* params) { return 1; } +template +int GetBatchFromParams(const GemmAndBiasParams* params) { + return 1; +} + template int GetBatchFromParams(const GemmStridedBatchedParams* params) { return params->batch; @@ -77,6 +82,11 @@ int GetStrideAFromParams(const GemmParams* params) { return 1; } +template +int GetStrideAFromParams(const GemmAndBiasParams* params) { + return 1; +} + template int GetStrideAFromParams(const GemmStridedBatchedParams* params) { return params->stride_a; @@ -92,6 +102,11 @@ int GetStrideBFromParams(const GemmParams* params) { return 1; } +template +int GetStrideBFromParams(const GemmAndBiasParams* params) { + return 1; +} + template int GetStrideBFromParams(const GemmStridedBatchedParams* params) { return params->stride_b; @@ -107,6 +122,11 @@ int GetStrideCFromParams(const GemmParams* params) { return 1; } +template +int GetStrideCFromParams(const GemmAndBiasParams* params) { + return 1; +} + template int GetStrideCFromParams(const GemmStridedBatchedParams* params) { return params->stride_c; @@ -122,6 +142,11 @@ float GetAlphaFromParams(const GemmParams* params) { return params->alpha; } +template +float GetAlphaFromParams(const GemmAndBiasParams* params) { + return params->alpha; +} + template float GetAlphaFromParams(const GemmStridedBatchedParams* params) { return params->alpha; @@ -137,6 +162,11 @@ float GetBetaFromParams(const GemmParams* params) { return params->beta; } +template +float GetBetaFromParams(const GemmAndBiasParams* params) { + return 0.0; +} + template float GetBetaFromParams(const GemmStridedBatchedParams* params) { return params->beta; @@ -152,6 +182,11 @@ const void* GetAScalePointerFromParams(const GemmParams* params) { return nullptr; } +template +const void* GetAScalePointerFromParams(const GemmAndBiasParams* params) { + return nullptr; +} + template const void* GetAScalePointerFromParams(const GemmStridedBatchedParams* params) { return nullptr; @@ -167,6 +202,11 @@ const void* GetBScalePointerFromParams(const GemmParams* params) { return nullptr; } +template +const void* GetBScalePointerFromParams(const GemmAndBiasParams* params) { + return nullptr; +} + template const void* GetBScalePointerFromParams(const GemmStridedBatchedParams* params) { return nullptr; @@ -182,6 +222,11 @@ const void* GetDScalePointerFromParams(const GemmParams* params) { return nullptr; } +template +const void* GetDScalePointerFromParams(const GemmAndBiasParams* params) { + return nullptr; +} + template const void* GetDScalePointerFromParams(const GemmStridedBatchedParams* params) { return nullptr; @@ -197,6 +242,11 @@ const void* GetBiasPointerFromParams(const GemmParams* params) { return nullptr; } +template +const void* GetBiasPointerFromParams(const GemmAndBiasParams* params) { + return params->bias; +} + template const void* GetBiasPointerFromParams(const GemmStridedBatchedParams* params) { return nullptr; @@ -212,6 +262,11 @@ hipDataType GetBiasTypeFromParams(const GemmParams* params) { return HIP_R_32F; } +template +hipDataType GetBiasTypeFromParams(const GemmAndBiasParams* params) { + return HipDataTypeFor(); +} + template hipDataType GetBiasTypeFromParams(const GemmStridedBatchedParams* params) { return HIP_R_32F; @@ -222,6 +277,26 @@ hipDataType GetBiasTypeFromParams(const ScaledGemmParams* params) { return at::cuda::ScalarTypeToCudaDataType(params->bias_dtype); } +template +at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmParams* params) { + return at::cuda::blas::GEMMAndBiasActivationEpilogue::None; +} + +template +at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmAndBiasParams* params) { + return params->activation; +} + +template +at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmStridedBatchedParams* params) { + return at::cuda::blas::GEMMAndBiasActivationEpilogue::None; +} + +template +at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const ScaledGemmParams* params) { + return at::cuda::blas::GEMMAndBiasActivationEpilogue::None; +} + static hipblasOperation_t _hipblasOpFromChar(char op) { switch (op) { case 'n': @@ -327,9 +402,9 @@ class HipblasltGemmOp : public Callable { TuningStatus Call(const ParamsT* params) override { hipblasOperation_t transa_outer = MapLayoutToHipBlasLt(ALayout); hipblasOperation_t transb_outer = MapLayoutToHipBlasLt(BLayout); - auto a_datatype = HipBlasDataTypeFor(); - auto b_datatype = HipBlasDataTypeFor(); - auto in_out_datatype = HipBlasDataTypeFor(); + auto a_datatype = HipDataTypeFor(); + auto b_datatype = HipDataTypeFor(); + auto in_out_datatype = HipDataTypeFor(); auto opa = _hipblasOpFromChar(params->transa); auto opb = _hipblasOpFromChar(params->transb); @@ -385,13 +460,22 @@ class HipblasltGemmOp : public Callable { matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr); matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr); matmul.setAttribute(HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr); + } - const void* bias_ptr = GetBiasPointerFromParams(params); - auto bias_datatype = GetBiasTypeFromParams(params); - if (bias_ptr) { - matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_POINTER, bias_ptr); + const void* bias_ptr = GetBiasPointerFromParams(params); + auto bias_datatype = GetBiasTypeFromParams(params); + if (bias_ptr) { + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_POINTER, bias_ptr); + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, bias_datatype); + auto activation = GetActivationFromParams(params); + if (activation == at::cuda::blas::GEMMAndBiasActivationEpilogue::RELU) { + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_RELU_BIAS); + } + else if (activation == at::cuda::blas::GEMMAndBiasActivationEpilogue::GELU) { + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_GELU_BIAS); + } + else { matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_BIAS); - matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, bias_datatype); } } @@ -460,9 +544,9 @@ template (); - auto b_datatype = HipBlasDataTypeFor(); - auto in_out_datatype = HipBlasDataTypeFor(); + auto a_datatype = HipDataTypeFor(); + auto b_datatype = HipDataTypeFor(); + auto in_out_datatype = HipDataTypeFor(); std::vector heuristic_result; hipblasLtHandle_t handle; @@ -505,6 +589,11 @@ auto GetHipBlasLtGemmTypeStringAndOps() { return GetHipBlasLtTypeStringAndOps>(); } +template +auto GetHipBlasLtGemmAndBiasTypeStringAndOps() { + return GetHipBlasLtTypeStringAndOps>(); +} + template auto GetHipBlasLtGemmStridedBatchedTypeStringAndOps() { return GetHipBlasLtTypeStringAndOps>(); diff --git a/aten/src/ATen/cuda/tunable/Tunable.cpp b/aten/src/ATen/cuda/tunable/Tunable.cpp index fc27fab77d79..d3d2333323e7 100644 --- a/aten/src/ATen/cuda/tunable/Tunable.cpp +++ b/aten/src/ATen/cuda/tunable/Tunable.cpp @@ -376,8 +376,8 @@ void TuningContext::EnableNumericsCheck(bool value) { bool TuningContext::IsNumericsCheckEnabled() const { static const char *env = getenv("PYTORCH_TUNABLEOP_NUMERICAL_CHECK"); - if (env != nullptr && strcmp(env, "0") == 0) { - return false; + if (env != nullptr && strcmp(env, "1") == 0) { + return true; } return numerics_check_enable_; } diff --git a/aten/src/ATen/cuda/tunable/TunableGemm.h b/aten/src/ATen/cuda/tunable/TunableGemm.h index 53e6154120c9..6b02e26ade4d 100644 --- a/aten/src/ATen/cuda/tunable/TunableGemm.h +++ b/aten/src/ATen/cuda/tunable/TunableGemm.h @@ -48,6 +48,28 @@ class DefaultGemmOp : public Callable> { } }; +static bool _transposeBoolFromChar(char op) { + return op == 't' || op == 'T'; +} + +template +class DefaultGemmAndBiasOp : public Callable> { + public: + TuningStatus Call(const GemmAndBiasParams* params) override { + at::cuda::blas::gemm_and_bias( + _transposeBoolFromChar(params->transa), + _transposeBoolFromChar(params->transb), + params->m, params->n, params->k, + params->alpha, + params->a, params->lda, + params->b, params->ldb, + params->bias, + params->c, params->ldc, + params->activation); + return OK; + } +}; + template class DefaultGemmStridedBatchedOp : public Callable> { public: @@ -265,7 +287,45 @@ class GemmTunableOp : public TunableOp, StreamTimer> { } std::string Signature() override { - return c10::str("GemmTunableOp_", TypeName(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout)); + static std::string val = c10::str("GemmTunableOp_", TypeName(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout)); + return val; + } +}; + +template +class GemmAndBiasTunableOp : public TunableOp, StreamTimer> { + public: + GemmAndBiasTunableOp() { + this->RegisterOp(std::string("Default"), std::make_unique>()); + + auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators(); + +#if defined(USE_ROCM) + bool rocm_validators = false; + + static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED"); + if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) { + rocm_validators = true; + // disallow tuning of hipblaslt with c10::complex + if constexpr ( + !std::is_same_v> && + !std::is_same_v>) { + for (auto&& [name, op] : GetHipBlasLtGemmAndBiasTypeStringAndOps()) { + this->RegisterOp(std::move(name), std::move(op)); + } + } + AddHipblasltValidator(); + } + + if (rocm_validators) { + AddRocmValidator(); + } +#endif + } + + std::string Signature() override { + static std::string val = c10::str("GemmAndBiasTunableOp_", TypeName(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout)); + return val; } }; @@ -308,7 +368,8 @@ class GemmStridedBatchedTunableOp : public TunableOp } std::string Signature() override { - return c10::str("GemmStridedBatchedTunableOp_", TypeName(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout)); + static std::string val = c10::str("GemmStridedBatchedTunableOp_", TypeName(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout)); + return val; } }; @@ -330,11 +391,12 @@ class ScaledGemmTunableOp : public TunableOp, StreamTimer> } std::string Signature() override { - return c10::str("ScaledGemmTunableOp", + static std::string val = c10::str("ScaledGemmTunableOp", "_", TypeName(AT{}), "_", TypeName(BT{}), "_", TypeName(CT{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout)); + return val; } }; diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index 84c59a4fd0d7..f7997fe72712 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -175,12 +175,6 @@ cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activa static bool getDisableAddmmCudaLt() { static const char* env_value = std::getenv("DISABLE_ADDMM_CUDA_LT"); #ifdef USE_ROCM - // if we enable tunable op, it'll take priority over just hipblaslt (heuristics) - // note the current tunable op is not the hipblaslt path (gemm_and_bias) - auto tuning_ctx = at::cuda::tunable::getTuningContext(); - if (tuning_ctx->IsTunableOpEnabled()) { - return true; - } // allow both CUDA and HIP env var names for ROCm builds // also, current default for ROCm builds is disable by default if (env_value == nullptr) { @@ -214,6 +208,49 @@ static bool isSupportedHipLtROCmArch(int index) { } #endif +template +static void launchTunableGemmAndBias(cublasCommonArgs &args, Tensor& result, const Tensor& self, bool is_rocm) { + bool transa_ = ((args.transa != 'n') && (args.transa != 'N')); + bool transb_ = ((args.transb != 'n') && (args.transb != 'N')); + at::cuda::tunable::GemmAndBiasParams params; + params.transa = args.transa; + params.transb = args.transb; + params.m = args.m; + params.n = args.n; + params.k = args.k; + params.a = args.mata->const_data_ptr(); + params.lda = args.lda; + params.b = args.matb->const_data_ptr(); + params.ldb = args.ldb; + if (is_rocm) { + params.bias = (&result != &self) ? self.const_data_ptr() : nullptr; + } + else { + params.bias = self.const_data_ptr(); + } + params.c = args.result->data_ptr(); + params.ldc = args.result_ld; + if (transa_ && transb_) { + static at::cuda::tunable::GemmAndBiasTunableOp gemm{}; + gemm(¶ms); + } + else if (transa_ && !transb_) { + static at::cuda::tunable::GemmAndBiasTunableOp gemm{}; + gemm(¶ms); + } + else if (!transa_ && transb_) { + static at::cuda::tunable::GemmAndBiasTunableOp gemm{}; + gemm(¶ms); + } + else if (!transa_ && !transb_) { + static at::cuda::tunable::GemmAndBiasTunableOp gemm{}; + gemm(¶ms); + } + else { + TORCH_CHECK(false, "unreachable"); + } +} + Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, Activation activation=Activation::None) { // Make sure to keep addmm_cuda below in sync with this code; it // preflights a check to try to avoid actually needing to call @@ -341,6 +378,11 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma scalar_type, "addmm_cuda_lt", [&] { + auto tuning_ctx = at::cuda::tunable::getTuningContext(); + if (tuning_ctx->IsTunableOpEnabled()) { + launchTunableGemmAndBias(args, result, self, true); + } + else { at::cuda::blas::gemm_and_bias( args.transa == 't', args.transb == 't', @@ -359,7 +401,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma args.result_ld, activation_to_gemm_and_blas_arg(activation) ); - }); + }}); #else auto activation_epilogue = activation_to_gemm_and_blas_arg(activation); #if (defined(CUDA_VERSION) && (CUDA_VERSION < 11080)) @@ -377,6 +419,11 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma scalar_type, "addmm_cuda_lt", [&] { + auto tuning_ctx = at::cuda::tunable::getTuningContext(); + if (tuning_ctx->IsTunableOpEnabled()) { + launchTunableGemmAndBias(args, result, self, false); + } + else { at::cuda::blas::gemm_and_bias( args.transa == 't', args.transb == 't', @@ -393,7 +440,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma args.result_ld, activation_epilogue ); - }); + }}); #endif } else { From 8df56afc200d1684d97add573b90dcc4978daceb Mon Sep 17 00:00:00 2001 From: Kulin Seth Date: Wed, 12 Jun 2024 16:03:57 +0000 Subject: [PATCH 695/706] Add support in Python API for the recommended max working set size. (#128289) Adds ways for users to request recommended max size for Metal on Mac. It plumbs through https://developer.apple.com/documentation/metal/mtldevice/2369280-recommendedmaxworkingsetsize?language=objc Can be used like ``` max_memory = torch.mps.recommended_max_memory() print ("Recommended Max Memory : ", (max_memory/(1024*1024*1024)), "GB") ``` Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/128289 Approved by: https://github.com/malfet --- aten/src/ATen/detail/MPSHooksInterface.h | 3 +++ aten/src/ATen/mps/MPSAllocator.h | 2 ++ aten/src/ATen/mps/MPSAllocator.mm | 3 +++ aten/src/ATen/mps/MPSAllocatorInterface.h | 1 + aten/src/ATen/mps/MPSHooks.h | 1 + aten/src/ATen/mps/MPSHooks.mm | 4 ++++ docs/source/mps.rst | 1 + test/test_mps.py | 5 +++++ torch/_C/__init__.pyi.in | 1 + torch/_dynamo/trace_rules.py | 1 + torch/csrc/mps/Module.cpp | 13 +++++++++++++ torch/mps/__init__.py | 11 +++++++++++ 12 files changed, 46 insertions(+) diff --git a/aten/src/ATen/detail/MPSHooksInterface.h b/aten/src/ATen/detail/MPSHooksInterface.h index dea35e671267..814b5aeb72d8 100644 --- a/aten/src/ATen/detail/MPSHooksInterface.h +++ b/aten/src/ATen/detail/MPSHooksInterface.h @@ -57,6 +57,9 @@ struct TORCH_API MPSHooksInterface : AcceleratorHooksInterface { virtual size_t getDriverAllocatedMemory() const { FAIL_MPSHOOKS_FUNC(__func__); } + virtual size_t getRecommendedMaxMemory() const { + FAIL_MPSHOOKS_FUNC(__func__); + } virtual void setMemoryFraction(double /*ratio*/) const { FAIL_MPSHOOKS_FUNC(__func__); } diff --git a/aten/src/ATen/mps/MPSAllocator.h b/aten/src/ATen/mps/MPSAllocator.h index bdf19e8d7362..1dc8c434f85b 100644 --- a/aten/src/ATen/mps/MPSAllocator.h +++ b/aten/src/ATen/mps/MPSAllocator.h @@ -308,6 +308,8 @@ class MPSHeapAllocatorImpl { // total GPU memory allocated in the process by Metal driver; including // implicit allocations from MPS/MPSGraph frameworks and MPSHeapAllocatorImpl. size_t getDriverAllocatedMemory() const { return current_allocated_size(); } + // recommended Max memory for Metal + size_t getRecommendedMaxMemory() const { return max_device_size(); } // (see enum DebugVerbosity for description) uint32_t getDebugVerbosity() const { return m_debug_verbosity; } // returns the device that we allocate from diff --git a/aten/src/ATen/mps/MPSAllocator.mm b/aten/src/ATen/mps/MPSAllocator.mm index 76280fb469e5..0c2a86948a4c 100644 --- a/aten/src/ATen/mps/MPSAllocator.mm +++ b/aten/src/ATen/mps/MPSAllocator.mm @@ -794,6 +794,9 @@ size_t getCurrentAllocatedMemory() const override { size_t getDriverAllocatedMemory() const override { return _getAllocImpl().getDriverAllocatedMemory(); } + size_t getRecommendedMaxMemory() const override { + return _getAllocImpl().getRecommendedMaxMemory(); + } ssize_t getLowWatermarkValue() const override { return _getAllocImpl().getLowWatermarkValue(); } diff --git a/aten/src/ATen/mps/MPSAllocatorInterface.h b/aten/src/ATen/mps/MPSAllocatorInterface.h index e30a02c3fb21..cce232fd6937 100644 --- a/aten/src/ATen/mps/MPSAllocatorInterface.h +++ b/aten/src/ATen/mps/MPSAllocatorInterface.h @@ -33,6 +33,7 @@ class IMPSAllocator : public c10::Allocator { virtual size_t getTotalAllocatedMemory() const = 0; virtual size_t getCurrentAllocatedMemory() const = 0; virtual size_t getDriverAllocatedMemory() const = 0; + virtual size_t getRecommendedMaxMemory() const = 0; virtual std::pair getSharedBufferPtr(const void* ptr) const = 0; virtual bool recordEvents(c10::ArrayRef buffers) const = 0; virtual bool waitForEvents(c10::ArrayRef buffers) const = 0; diff --git a/aten/src/ATen/mps/MPSHooks.h b/aten/src/ATen/mps/MPSHooks.h index 667430eaf811..dea8f25fa7fd 100644 --- a/aten/src/ATen/mps/MPSHooks.h +++ b/aten/src/ATen/mps/MPSHooks.h @@ -32,6 +32,7 @@ struct MPSHooks : public at::MPSHooksInterface { void emptyCache() const override; size_t getCurrentAllocatedMemory() const override; size_t getDriverAllocatedMemory() const override; + size_t getRecommendedMaxMemory() const override; void setMemoryFraction(double ratio) const override; // MPSProfiler interface diff --git a/aten/src/ATen/mps/MPSHooks.mm b/aten/src/ATen/mps/MPSHooks.mm index 387359592a74..285c0771c3c6 100644 --- a/aten/src/ATen/mps/MPSHooks.mm +++ b/aten/src/ATen/mps/MPSHooks.mm @@ -80,6 +80,10 @@ return at::mps::getIMPSAllocator()->getDriverAllocatedMemory(); } +size_t MPSHooks::getRecommendedMaxMemory() const { + return at::mps::getIMPSAllocator()->getRecommendedMaxMemory(); +} + void MPSHooks::setMemoryFraction(double ratio) const { at::mps::getIMPSAllocator()->setHighWatermarkRatio(ratio); } diff --git a/docs/source/mps.rst b/docs/source/mps.rst index bab0d3378ea8..86195242566f 100644 --- a/docs/source/mps.rst +++ b/docs/source/mps.rst @@ -17,6 +17,7 @@ torch.mps set_per_process_memory_fraction current_allocated_memory driver_allocated_memory + recommended_max_memory MPS Profiler ------------ diff --git a/test/test_mps.py b/test/test_mps.py index 00fc5c01c78d..c59a598facc4 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -7893,6 +7893,11 @@ def test_mps_allocator_module(self): self.assertTrue(current_alloc_after > current_alloc_before) self.assertTrue(driver_alloc_after > driver_alloc_before) + def test_mps_allocator_stats(self): + max_memory = torch.mps.recommended_max_memory() + print(f"Recommended Max Memory : {max_memory/ 1024 ** 3} GB") + self.assertTrue(max_memory > 0) + # to verify this test, run XCode Instruments "Metal System Trace" or "Logging" tool, # press record, then run this python test, and press stop. Next expand # the os_signposts->PyTorchMPS and check if events or intervals are logged diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 30a4fb6c36c6..135ba3c27757 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -1749,6 +1749,7 @@ def _mps_emptyCache() -> None: ... def _mps_setMemoryFraction(fraction: _float) -> None: ... def _mps_currentAllocatedMemory() -> _int: ... def _mps_driverAllocatedMemory() -> _int: ... +def _mps_recommendedMaxMemory() -> _int: ... def _mps_is_available() -> _bool: ... def _mps_is_on_macos_or_newer(major: _int, minor: _int) -> _bool: ... def _mps_profilerStartTrace(mode: str, wait_until_completed: _bool) -> None: ... diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 4d3f5b11edb0..6078e3d4b76e 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -941,6 +941,7 @@ "torch._C._mps_currentAllocatedMemory", "torch._C._mps_deviceSynchronize", "torch._C._mps_driverAllocatedMemory", + "torch._C._mps_recommendedMaxMemory", "torch._C._mps_elapsedTimeOfEvents", "torch._C._mps_emptyCache", "torch._C._mps_get_default_generator", diff --git a/torch/csrc/mps/Module.cpp b/torch/csrc/mps/Module.cpp index 2dcb215c574b..415fda6165dd 100644 --- a/torch/csrc/mps/Module.cpp +++ b/torch/csrc/mps/Module.cpp @@ -121,6 +121,15 @@ static PyObject* MPSModule_driverAllocatedMemory( END_HANDLE_TH_ERRORS } +static PyObject* MPSModule_recommendedMaxMemory( + PyObject* _unused, + PyObject* noargs) { + HANDLE_TH_ERRORS + return THPUtils_packUInt64( + at::detail::getMPSHooks().getRecommendedMaxMemory()); + END_HANDLE_TH_ERRORS +} + static PyObject* MPSModule_profilerStartTrace( PyObject* _unused, PyObject* args) { @@ -244,6 +253,10 @@ static struct PyMethodDef _MPSModule_methods[] = { MPSModule_driverAllocatedMemory, METH_NOARGS, nullptr}, + {"_mps_recommendedMaxMemory", + MPSModule_recommendedMaxMemory, + METH_NOARGS, + nullptr}, {"_mps_profilerStartTrace", MPSModule_profilerStartTrace, METH_VARARGS, diff --git a/torch/mps/__init__.py b/torch/mps/__init__.py index 0538ae50d1ad..5c61eaf91bd0 100644 --- a/torch/mps/__init__.py +++ b/torch/mps/__init__.py @@ -129,6 +129,16 @@ def driver_allocated_memory() -> int: return torch._C._mps_driverAllocatedMemory() +def recommended_max_memory() -> int: + r"""Returns recommended max Working set size for GPU memory in bytes. + + .. note:: + Recommended max working set size for Metal. + returned from device.recommendedMaxWorkingSetSize. + """ + return torch._C._mps_recommendedMaxMemory() + + from . import profiler from .event import Event @@ -145,4 +155,5 @@ def driver_allocated_memory() -> int: "driver_allocated_memory", "Event", "profiler", + "recommended_max_memory", ] From f2dcbe89d6613d211ff8a053b1e5f4b1f7f92952 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 12 Jun 2024 16:09:22 +0000 Subject: [PATCH 696/706] Revert "Prevent expansion of cat indexing to avoid int64 intermediate (#127815)" This reverts commit 793df7b7cb1473004837f5867f4c1c4b2b0f751d. Reverted https://github.com/pytorch/pytorch/pull/127815 on behalf of https://github.com/clee2000 due to the newly added test is failing internally D58444153. Test exists in opensource and passed in OSS CI, maybe env difference? ([comment](https://github.com/pytorch/pytorch/pull/127815#issuecomment-2163421968)) --- test/inductor/test_cuda_repro.py | 40 ------------------------------ torch/_inductor/bounds.py | 9 ------- torch/_inductor/codegen/common.py | 3 --- torch/_inductor/lowering.py | 12 ++------- torch/_inductor/utils.py | 10 ++------ torch/utils/_sympy/functions.py | 16 ------------ torch/utils/_sympy/interp.py | 2 -- torch/utils/_sympy/value_ranges.py | 4 --- 8 files changed, 4 insertions(+), 92 deletions(-) diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py index 23243b7db5b5..8365d216f82c 100644 --- a/test/inductor/test_cuda_repro.py +++ b/test/inductor/test_cuda_repro.py @@ -1238,46 +1238,6 @@ def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr): tl.store(out_ptr0 + (x3), tmp2, xmask)""", # noqa: B950 ) - def test_int64_index_intermediate(self): - def foo(inp): - view_23 = torch.ops.aten.view.default(inp, [-1, 8192, 8192]) - split_1 = torch.ops.aten.split.Tensor(view_23, 1024, 1) - view_23 = None - getitem_17 = split_1[0] - getitem_18 = split_1[1] - getitem_19 = split_1[2] - getitem_20 = split_1[3] - getitem_21 = split_1[4] - getitem_22 = split_1[5] - getitem_23 = split_1[6] - getitem_24 = split_1[7] - split_1 = None - cat_1 = torch.ops.aten.cat.default( - [ - getitem_17, - getitem_18, - getitem_19, - getitem_20, - getitem_21, - getitem_22, - getitem_23, - getitem_24, - ] - ) - getitem_17 = ( - getitem_18 - ) = ( - getitem_19 - ) = getitem_20 = getitem_21 = getitem_22 = getitem_23 = getitem_24 = None - return cat_1 - - for mark_dynamic in [False, True]: - inp = torch.rand((65536, 8192), dtype=torch.bfloat16, device="cuda") - if mark_dynamic: - torch._dynamo.mark_dynamic(inp, 0) - foo_c = torch.compile(foo) - torch.testing.assert_allclose(foo(inp), foo_c(inp)) - if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/torch/_inductor/bounds.py b/torch/_inductor/bounds.py index b7bb37e5ee68..8c62ef2ba3c9 100644 --- a/torch/_inductor/bounds.py +++ b/torch/_inductor/bounds.py @@ -45,15 +45,6 @@ def upper_bound(v): # To access this variable call `get_bounds()` self._bounds: Dict[torch.fx.Node, ValueRanges[Expr]] = {} - def __repr__(self): - return ( - f"{self.__class__.__name__}(" - f"loop_body={self.loop_body},\n " - f"replacement_vals={self.replacement_vals}, \n" - f"unbounded_vars={self.unbounded_vars}, \n" - f"_bounds={self._bounds})" - ) - @cache_on_self def get_bounds(self) -> Dict[torch.fx.Node, ValueRanges[Expr]]: submodules = self.swap_submodules(self.loop_body.submodules) diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 02aa3e7395f7..8ca6dc2b9153 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -393,9 +393,6 @@ def _print_FloatTrueDiv(self, expr): def _print_CleanDiv(self, expr): return self._print_FloorDiv(expr) - def _print_Identity(self, expr): - return self._print(expr.args[0]) - def _print_GreaterThan(self, expr): # GreaterThan: >= # StrictlyGreaterThan: > diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index e3457a27aa94..3b59620c7b89 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -35,13 +35,7 @@ Number, ) from torch.fx.experimental.sym_node import magic_methods, method_to_operator -from torch.utils._sympy.functions import ( - CeilDiv, - FloorDiv, - Identity, - IntTrueDiv, - ModularIndexing, -) +from torch.utils._sympy.functions import CeilDiv, FloorDiv, IntTrueDiv, ModularIndexing from .._dynamo.utils import import_submodule from . import config, inductor_prims, ir, test_operators # NOQA: F401 @@ -1022,9 +1016,7 @@ def inner_fn(idx): # if we're concatting [4], [2] # when we index the second tensor for 5 we want to index 5 - 4 - # Use Identity to prevent expansion of index * stride to keep expression - # in same int bitwidth as shape - idx_load[dim] = Identity(idx_load[dim] - inputs_ranges[i][0]) + idx_load[dim] -= inputs_ranges[i][0] masked_loads.append( ops.masked( diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index f59fa34fe9c0..d39713be81dd 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -54,13 +54,7 @@ from torch.autograd.profiler_util import EventList from torch.fx.passes.graph_transform_observer import GraphTransformObserver from torch.fx.passes.shape_prop import ShapeProp -from torch.utils._sympy.functions import ( - CeilDiv, - CleanDiv, - FloorDiv, - Identity, - ModularIndexing, -) +from torch.utils._sympy.functions import CeilDiv, CleanDiv, FloorDiv, ModularIndexing from torch.utils._sympy.symbol import make_symbol, SymT from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges from . import config @@ -581,7 +575,7 @@ def sympy_str(expr: sympy.Expr) -> str: if isinstance(expr, sympy.Mul): return " * ".join(map(sympy_str, expr.args)) - if isinstance(expr, (ModularIndexing, CleanDiv, FloorDiv, Identity)): + if isinstance(expr, (ModularIndexing, CleanDiv, FloorDiv)): return f"{expr.func.__name__}({', '.join(map(sympy_str, expr.args))})" return str(expr) diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index 25aa07cd5a5c..7b8387303336 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -22,7 +22,6 @@ "ToFloat", "FloatPow", "PowByNatural", - "Identity", ] @@ -658,21 +657,6 @@ def eval(cls, number): return sympy.Float(int(number)) -class Identity(sympy.Function): - """ - Prevents expansion and other optimizations - """ - - def __repr__(self): - return f"Identity({self.args[0]})" - - def _eval_is_real(self): - return self.args[0].is_real - - def _eval_is_integer(self): - return self.args[0].is_integer # type: ignore[attr-defined] - - def make_opaque_unary_fn(name): class OpaqueUnaryFn(sympy.Function): """ diff --git a/torch/utils/_sympy/interp.py b/torch/utils/_sympy/interp.py index 1bb60da4f234..640b991cd104 100644 --- a/torch/utils/_sympy/interp.py +++ b/torch/utils/_sympy/interp.py @@ -22,7 +22,6 @@ FloatTrueDiv, FloorDiv, FloorToInt, - Identity, IntTrueDiv, IsNonOverlappingAndDenseIndicator, Mod, @@ -89,7 +88,6 @@ def handlers(): ModularIndexing: "modular_indexing", sympy.functions.elementary.piecewise.ExprCondPair: "expr_cond_pair", sympy.Piecewise: "piecewise", - Identity: "identity", IsNonOverlappingAndDenseIndicator: "is_non_overlapping_and_dense_indicator", RoundDecimal: "round_decimal", } diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py index e1ef17f3d340..97f47c4f28ac 100644 --- a/torch/utils/_sympy/value_ranges.py +++ b/torch/utils/_sympy/value_ranges.py @@ -462,10 +462,6 @@ def eq(a, b): def ne(cls, a, b): return cls.not_(cls.eq(a, b)) - @classmethod - def identity(cls, a): - return ValueRanges.wrap(a) - @classmethod def lt(cls, a, b): a = ValueRanges.wrap(a) From 9e39c62908de4831439397e30e9102f63c338682 Mon Sep 17 00:00:00 2001 From: Xu Han Date: Wed, 12 Jun 2024 16:12:49 +0000 Subject: [PATCH 697/706] correct avx512_vnni isa name. (#128318) `x86` has two vnni isa currently: `avx2_vnni` and `avx512_vnni`. This PR correct the function name to `avx512_vnni`. Co-authored-by: Jiong Gong Pull Request resolved: https://github.com/pytorch/pytorch/pull/128318 Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5, https://github.com/desertfire --- aten/src/ATen/cpu/Utils.cpp | 2 +- aten/src/ATen/cpu/Utils.h | 2 +- torch/_C/_cpu.pyi | 2 +- torch/_dynamo/trace_rules.py | 4 ++-- torch/cpu/__init__.py | 3 ++- torch/csrc/cpu/Module.cpp | 2 +- 6 files changed, 8 insertions(+), 7 deletions(-) diff --git a/aten/src/ATen/cpu/Utils.cpp b/aten/src/ATen/cpu/Utils.cpp index 21b6f33877ed..fbf861dcabcf 100644 --- a/aten/src/ATen/cpu/Utils.cpp +++ b/aten/src/ATen/cpu/Utils.cpp @@ -20,7 +20,7 @@ bool is_cpu_support_avx512() { #endif } -bool is_cpu_support_vnni() { +bool is_cpu_support_avx512_vnni() { #if !defined(__s390x__) && !defined(__powerpc__) return cpuinfo_initialize() && cpuinfo_has_x86_avx512vnni(); #else diff --git a/aten/src/ATen/cpu/Utils.h b/aten/src/ATen/cpu/Utils.h index 805c7c64a21b..0ad6f8e893ca 100644 --- a/aten/src/ATen/cpu/Utils.h +++ b/aten/src/ATen/cpu/Utils.h @@ -8,6 +8,6 @@ TORCH_API bool is_cpu_support_avx2(); TORCH_API bool is_cpu_support_avx512(); // Detect if CPU support Vector Neural Network Instruction. -TORCH_API bool is_cpu_support_vnni(); +TORCH_API bool is_cpu_support_avx512_vnni(); } // namespace at::cpu diff --git a/torch/_C/_cpu.pyi b/torch/_C/_cpu.pyi index 641ba00312e0..37794bd7c10b 100644 --- a/torch/_C/_cpu.pyi +++ b/torch/_C/_cpu.pyi @@ -4,4 +4,4 @@ from torch.types import _bool def _is_cpu_support_avx2() -> _bool: ... def _is_cpu_support_avx512() -> _bool: ... -def _is_cpu_support_vnni() -> _bool: ... +def _is_cpu_support_avx512_vnni() -> _bool: ... diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 6078e3d4b76e..c6e2a848adec 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -408,7 +408,7 @@ "torch._C._conv_determine_backend_memory_format", "torch._C._cpu._is_cpu_support_avx2", "torch._C._cpu._is_cpu_support_avx512", - "torch._C._cpu._is_cpu_support_vnni", + "torch._C._cpu._is_cpu_support_avx512_vnni", "torch._C._crash_if_aten_asan", "torch._C._crash_if_csrc_asan", "torch._C._crash_if_csrc_ubsan", @@ -2422,7 +2422,7 @@ "torch.compiled_with_cxx11_abi", "torch.cpu._is_cpu_support_avx2", "torch.cpu._is_cpu_support_avx512", - "torch.cpu._is_cpu_support_vnni", + "torch.cpu._is_cpu_support_avx512_vnni", "torch.cpu.current_device", "torch.cpu.current_stream", "torch.cpu.device_count", diff --git a/torch/cpu/__init__.py b/torch/cpu/__init__.py index d2b8069048cc..d404ad4ba3b9 100644 --- a/torch/cpu/__init__.py +++ b/torch/cpu/__init__.py @@ -41,7 +41,8 @@ def _is_cpu_support_avx512() -> bool: def _is_cpu_support_vnni() -> bool: r"""Returns a bool indicating if CPU supports VNNI.""" - return torch._C._cpu._is_cpu_support_vnni() + # Note: Currently, it only checks avx512_vnni, will add the support of avx2_vnni later. + return torch._C._cpu._is_cpu_support_avx512_vnni() def is_available() -> bool: diff --git a/torch/csrc/cpu/Module.cpp b/torch/csrc/cpu/Module.cpp index b6c931eae0fe..3485f2a991cb 100644 --- a/torch/csrc/cpu/Module.cpp +++ b/torch/csrc/cpu/Module.cpp @@ -10,7 +10,7 @@ void initModule(PyObject* module) { auto cpu = m.def_submodule("_cpu", "cpu related pybind."); cpu.def("_is_cpu_support_avx2", at::cpu::is_cpu_support_avx2); cpu.def("_is_cpu_support_avx512", at::cpu::is_cpu_support_avx512); - cpu.def("_is_cpu_support_vnni", at::cpu::is_cpu_support_vnni); + cpu.def("_is_cpu_support_avx512_vnni", at::cpu::is_cpu_support_avx512_vnni); } } // namespace torch::cpu From c5172b8de851c0c534c415aac51c381950eb2905 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 12 Jun 2024 16:17:07 +0000 Subject: [PATCH 698/706] Revert "[AOTI] Switch to use shim v2 (#127674)" This reverts commit 9a38cae299e5ffd8143182bec878c28f96cfd72a. Reverted https://github.com/pytorch/pytorch/pull/127674 on behalf of https://github.com/clee2000 due to tests failed internally D56709309 ([comment](https://github.com/pytorch/pytorch/pull/127674#issuecomment-2163436728)) --- torch/_inductor/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 10374f577edf..2ea60000d265 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -43,7 +43,7 @@ def is_fbcode(): ) c_shim_version = os.environ.get( - "TORCHINDUCTOR_C_SHIM_VERSION", "1" if (is_fbcode() and torch.version.hip) else "2" + "TORCHINDUCTOR_C_SHIM_VERSION", "1" if is_fbcode() else "2" ) # dead code elimination From 81e4e12f02618777ba8784c3e91e33d48ab3dfca Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 12 Jun 2024 16:20:04 +0000 Subject: [PATCH 699/706] Revert "Support aten operations with out tensor (#124926)" This reverts commit cba195c8edd6c7149036ef0767772d11fff5390e. Reverted https://github.com/pytorch/pytorch/pull/124926 on behalf of https://github.com/clee2000 due to newly added test broke in internal D58444103. Test passed in OSS CI though ([comment](https://github.com/pytorch/pytorch/pull/124926#issuecomment-2163441547)) --- .../aot_inductor_torchbench_inference.csv | 2 +- .../aot_inductor_torchbench_inference.csv | 2 +- test/inductor/test_torchinductor.py | 131 +++++++----------- torch/_inductor/compile_fx.py | 18 +-- torch/export/_unlift.py | 10 +- 5 files changed, 57 insertions(+), 106 deletions(-) diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_torchbench_inference.csv index 1624d6dc7973..e29c62dd5b71 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_torchbench_inference.csv @@ -242,7 +242,7 @@ pyhpc_equation_of_state,pass,0 -pyhpc_isoneutral_mixing,pass,0 +pyhpc_isoneutral_mixing,fail_to_run,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_inductor_torchbench_inference.csv index 1624d6dc7973..e29c62dd5b71 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_inductor_torchbench_inference.csv @@ -242,7 +242,7 @@ pyhpc_equation_of_state,pass,0 -pyhpc_isoneutral_mixing,pass,0 +pyhpc_isoneutral_mixing,fail_to_run,0 diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index b33e01aebbdc..f9d736bcd413 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -233,23 +233,6 @@ def run_with_backward(): return run_and_get_code(run_with_backward) -def register_ops_with_aoti_compile(ns, op_set, dispatch_key, torch_compile_op_lib_impl): - for _op_name in op_set: - qualified_op_name = f"{ns}::{_op_name}" - _, overload_names = torch._C._jit_get_operation(qualified_op_name) - for overload_name in overload_names: - try: - reg_op_name = qualified_op_name - schema = torch._C._get_schema(qualified_op_name, overload_name) - if schema.overload_name: - reg_op_name = f"{qualified_op_name}.{schema.overload_name}" - torch_compile_op_lib_impl._impl_with_aoti_compile( # noqa: F821 - reg_op_name, dispatch_key - ) - except Exception as e: - continue - - class TestCase(InductorTestCase): @classmethod def setUpClass(cls): @@ -768,58 +751,6 @@ def fn(a, b): ), ) - @skipCUDAIf(not SM80OrLater, "Requires sm80") - def test_eager_aoti_support_out(self): - ns = "aten" - op_name = "clamp" - dispatch_key = "CPU" - device = "cpu" - if self.device.lower() == "cuda": - dispatch_key = "CUDA" - device = "cuda" - - inp_tensor = torch.randn(128, dtype=torch.float, device=device).fill_(1.0) - min_tensor = inp_tensor - 0.05 - max_tensor = inp_tensor + 0.05 - with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl: - ref_out_tensor = torch.randn(128, dtype=torch.float, device=device).fill_( - -1 - ) - ref_tensor = torch.clamp( - max=max_tensor, min=min_tensor, input=inp_tensor, out=ref_out_tensor - ) - - ref_out_tensor1 = torch.randn(128, dtype=torch.float, device=device).fill_( - -1 - ) - ref_tensor1 = torch.clamp( - max=max_tensor, out=ref_out_tensor1, min=min_tensor, input=inp_tensor - ) - - register_ops_with_aoti_compile( - ns, [op_name], dispatch_key, torch_compile_op_lib_impl - ) - - res_out_tensor = torch.randn(128, dtype=torch.float, device=device).fill_( - -1 - ) - res_tensor = torch.clamp( - max=max_tensor, min=min_tensor, input=inp_tensor, out=res_out_tensor - ) - - self.assertEqual(ref_tensor, res_tensor) - self.assertEqual(ref_out_tensor, res_out_tensor) - - res_out_tensor1 = torch.randn(128, dtype=torch.float, device=device).fill_( - -1 - ) - res_tensor1 = torch.clamp( - max=max_tensor, out=res_out_tensor1, min=min_tensor, input=inp_tensor - ) - - self.assertEqual(ref_tensor1, res_tensor1) - self.assertEqual(ref_out_tensor1, res_out_tensor1) - @skipCUDAIf(not SM80OrLater, "Requires sm80") def test_eager_aoti_cache_hit(self): ns = "aten" @@ -848,13 +779,24 @@ def test_eager_aoti_cache_hit(self): with mock.patch( "torch._inductor.utils.aoti_compile_with_persistent_cache", None ): + qualified_op_name = f"{ns}::{op_name}" + _, overload_names = torch._C._jit_get_operation(qualified_op_name) + with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl: # Get ref result from eager ref_value = getattr(torch.ops.aten, op_name)(input_tensor) - register_ops_with_aoti_compile( - ns, [op_name], dispatch_key, torch_compile_op_lib_impl - ) + for overload_name in overload_names: + try: + reg_op_name = qualified_op_name + schema = torch._C._get_schema(qualified_op_name, overload_name) + if schema.overload_name: + reg_op_name = f"{qualified_op_name}.{schema.overload_name}" + torch_compile_op_lib_impl._impl_with_aoti_compile( # noqa: F821 + reg_op_name, dispatch_key + ) + except Exception as e: + continue # Invoke the pre-compiled kernel and get result. res_value = getattr(torch.ops.aten, op_name)(input_tensor) @@ -862,7 +804,7 @@ def test_eager_aoti_cache_hit(self): self.assertEqual(ref_value, res_value) @skipCUDAIf(not SM80OrLater, "Requires sm80") - def test_eager_aoti_with_persistent_cache(self): + def test_aoti_compile_with_persistent_cache(self): def fn(a): return torch.abs(a) @@ -964,9 +906,19 @@ def test_eager_aoti_with_scalar(self): for scalar_value in scalar_values: ref_values.append(torch.add(a, b, alpha=scalar_value)) - register_ops_with_aoti_compile( - namespace_name, [op_name], dispatch_key, torch_compile_op_lib_impl - ) + qualified_op_name = f"{namespace_name}::{op_name}" + _, overload_names = torch._C._jit_get_operation(qualified_op_name) + for overload_name in overload_names: + try: + reg_op_name = qualified_op_name + schema = torch._C._get_schema(reg_op_name, overload_name) + if schema.overload_name: + reg_op_name = f"{reg_op_name}.{schema.overload_name}" + torch_compile_op_lib_impl._impl_with_aoti_compile( # noqa: F821 + reg_op_name, dispatch_key + ) + except Exception as e: + continue res_values = [] for scalar_value in scalar_values: @@ -976,7 +928,8 @@ def test_eager_aoti_with_scalar(self): self.assertEqual(ref_values, res_values) @skipCUDAIf(not SM80OrLater, "Requires sm80") - def test_eager_aoti_override_registration(self): + def test_torch_compile_override_registration(self): + dynamic = False namespace_name = "aten" dispatch_key = "CPU" device = torch.device("cpu") @@ -998,10 +951,24 @@ def fn(x, op_name=""): ref = opt_fn(x) ref_array.append(ref) + def register_ops(op_set, dispatch_key, torch_compile_op_lib_impl): + for _op_name in op_set: + qualified_op_name = f"{namespace_name}::{_op_name}" + _, overload_names = torch._C._jit_get_operation(qualified_op_name) + for overload_name in overload_names: + try: + reg_op_name = qualified_op_name + schema = torch._C._get_schema(qualified_op_name, overload_name) + if schema.overload_name: + reg_op_name = f"{qualified_op_name}.{schema.overload_name}" + torch_compile_op_lib_impl._impl_with_aoti_compile( # noqa: F821 + reg_op_name, dispatch_key + ) + except Exception as e: + continue + with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl: - register_ops_with_aoti_compile( - namespace_name, unary_op_set, dispatch_key, torch_compile_op_lib_impl - ) + register_ops(unary_op_set, dispatch_key, torch_compile_op_lib_impl) res_array = [] for unary_op_name in unary_op_set: @@ -1018,9 +985,7 @@ def fn(x, op_name=""): ref_with_min_max = torch.ops.aten.clamp(a, min_tensor, max_tensor) with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl: - register_ops_with_aoti_compile( - namespace_name, ["clamp"], dispatch_key, torch_compile_op_lib_impl - ) + register_ops(["clamp"], dispatch_key, torch_compile_op_lib_impl) res_with_min = torch.ops.aten.clamp(a, min_tensor) res_with_min_max = torch.ops.aten.clamp(a, min_tensor, max_tensor) self.assertEqual(ref_with_min, res_with_min) diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index d49ed38902cb..f29e05a723ec 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -201,19 +201,11 @@ def _unlift_graph(mod, gm, graph_signature): outputs = list(gm.graph.nodes)[-1].args[0] mutated_outputs = [] - buffer_mutations = graph_signature.buffers_to_mutate - user_input_mutations = graph_signature.user_inputs_to_mutate - output_tokens = graph_signature.output_tokens - for idx, out in enumerate(outputs): - value = None - - if idx < len(buffer_mutations) + len(user_input_mutations) + len(output_tokens): - if out.name in buffer_mutations: - value = buffer_mutations[out.name] - elif out.name in user_input_mutations: - value = user_input_mutations[out.name] - - mutated_outputs.append(value) + for out in outputs: + if out.name in graph_signature.buffers_to_mutate: + mutated_outputs.append(graph_signature.buffers_to_mutate[out.name]) + else: + mutated_outputs.append(None) unlifted_gm = _unlift( gm, diff --git a/torch/export/_unlift.py b/torch/export/_unlift.py index 5a8f144b04e0..97df0562caa7 100644 --- a/torch/export/_unlift.py +++ b/torch/export/_unlift.py @@ -86,7 +86,6 @@ def _insert_copy_for_mutations( assert len(outputs) == len(mutated_outputs) user_output_nodes = [] - return_nodes_to_copy = {} for return_node, mutated_node_name in zip(outputs, mutated_outputs): if mutated_node_name is None: user_output_nodes.append(return_node) @@ -102,18 +101,13 @@ def _insert_copy_for_mutations( ) with gm.graph.inserting_before(output_node): - copy_node = gm.graph.call_function( + _ = gm.graph.call_function( torch.ops.aten.copy_.default, (mutated_node, return_node) ) - return_nodes_to_copy[return_node] = copy_node - output_args = [ - return_nodes_to_copy[node] if node in return_nodes_to_copy else node - for node in user_output_nodes - ] with gm.graph.inserting_before(output_node): # Only return user outputs - new_output = gm.graph.output(tuple(output_args)) + new_output = gm.graph.output(tuple(user_output_nodes)) output_node.replace_all_uses_with(new_output) gm.graph.erase_node(output_node) From f89574fa2376e53031a582d0a9ab8a49772ae46c Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 12 Jun 2024 16:29:51 +0000 Subject: [PATCH 700/706] Revert "Pass params to dump_nccl_trace_pickle (#128307)" This reverts commit eb567b1f40233667b982f81e3a75deec0fdfd9ca. Reverted https://github.com/pytorch/pytorch/pull/128307 on behalf of https://github.com/clee2000 due to sorry need to revert this in order to revert 126969 ([comment](https://github.com/pytorch/pytorch/pull/128307#issuecomment-2163459399)) --- .../distributed/elastic/test_control_plane.py | 42 --------------- torch/csrc/distributed/c10d/NCCLUtils.cpp | 51 ------------------- .../distributed/c10d/ProcessGroupNCCL.cpp | 10 ++++ .../c10d/control_plane/Handlers.hpp | 3 -- .../c10d/control_plane/WorkerServer.cpp | 4 -- 5 files changed, 10 insertions(+), 100 deletions(-) diff --git a/test/distributed/elastic/test_control_plane.py b/test/distributed/elastic/test_control_plane.py index 971099e32f6d..775b062451b1 100644 --- a/test/distributed/elastic/test_control_plane.py +++ b/test/distributed/elastic/test_control_plane.py @@ -80,48 +80,6 @@ def test_dump_nccl_trace_pickle(self) -> None: resp = pool.request("POST", "/handler/dump_nccl_trace_pickle") self.assertEqual(resp.status, 200) out = pickle.loads(resp.data) - self.assertIsInstance(out, dict) - self.assertIn("version", out) - - @requires_cuda - def test_dump_nccl_trace_pickle_with_params(self) -> None: - with local_worker_server() as pool: - # bad key - not lower case - resp = pool.request( - "POST", "/handler/dump_nccl_trace_pickle?includeCollectives=true" - ) - self.assertEqual(resp.status, 400) - # unknown key - resp = pool.request( - "POST", "/handler/dump_nccl_trace_pickle?unknownkey=true" - ) - self.assertEqual(resp.status, 400) - # bad value - not a bool - resp = pool.request( - "POST", "/handler/dump_nccl_trace_pickle?includecollectives=notabool" - ) - self.assertEqual(resp.status, 400) - # bad value - value not lowercase - resp = pool.request( - "POST", "/handler/dump_nccl_trace_pickle?includecollectives=True" - ) - self.assertEqual(resp.status, 400) - # good key and value - resp = pool.request( - "POST", "/handler/dump_nccl_trace_pickle?includecollectives=true" - ) - self.assertEqual(resp.status, 200) - # good key and value - resp = pool.request( - "POST", "/handler/dump_nccl_trace_pickle?includestacktraces=true" - ) - self.assertEqual(resp.status, 200) - # multiple good keys and values - resp = pool.request( - "POST", - "/handler/dump_nccl_trace_pickle?includecollectives=true&includestacktraces=false&onlyactive=true", - ) - self.assertEqual(resp.status, 200) def test_tcp(self) -> None: import requests diff --git a/torch/csrc/distributed/c10d/NCCLUtils.cpp b/torch/csrc/distributed/c10d/NCCLUtils.cpp index 01edfdaf9292..db268371ea0f 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.cpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.cpp @@ -1,12 +1,9 @@ #include -#include -#include #include #include #ifdef USE_C10D_NCCL -#include #include #include @@ -291,54 +288,6 @@ float getDurationFromEvent( return ncclStartEvent.elapsed_time(ncclEndEvent); } -control_plane::RegisterHandler dumpHandler{ - "dump_nccl_trace_pickle", - [](const control_plane::Request& req, control_plane::Response& res) { - const auto params = req.params(); - size_t validParamCount = 0; - - // valid params - const std::string includeCollectivesStr = "includecollectives"; - const std::string includeStackTracesStr = "includestacktraces"; - const std::string onlyActiveStr = "onlyactive"; - - std::unordered_map expectedParams = { - {includeCollectivesStr, true}, - {includeStackTracesStr, true}, - {onlyActiveStr, false}}; - - for (const auto& [paramName, paramValue] : params) { - auto it = expectedParams.find(paramName); - if (it != expectedParams.end()) { - validParamCount++; - if (paramValue == "true") { - it->second = true; - } else if (paramValue == "false") { - it->second = false; - } else { - res.setStatus(400); - res.setContent( - "Invalid value for " + paramName + - " valid values are true or false", - "text/plain"); - return; - } - } - } - if (validParamCount < params.size()) { - res.setStatus(400); - res.setContent( - "Invalid parameters - unexpected param passed in", "text/plain"); - return; - } - res.setContent( - dump_nccl_trace( - expectedParams[includeCollectivesStr], - expectedParams[includeStackTracesStr], - expectedParams[onlyActiveStr]), - "application/octet-stream"); - }}; - } // namespace c10d #endif // USE_C10D_NCCL diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 158522063ab7..bb9198f22200 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -28,6 +28,7 @@ #include #include #include +#include #include #include @@ -378,6 +379,15 @@ std::string dump_nccl_trace( } #endif +// TODO(c-p-i-o): add a JSON endpoint. +control_plane::RegisterHandler dumpHandler{ + "dump_nccl_trace_pickle", + [](const control_plane::Request& req, control_plane::Response& res) { + // TODO: c-p-i-o: params from the request need to go to dump_nccl_trace. + res.setContent( + dump_nccl_trace(true, true, false), "application/octet-stream"); + }}; + std::optional)>>& get_cpp_trace_dumper() { static std::optional< diff --git a/torch/csrc/distributed/c10d/control_plane/Handlers.hpp b/torch/csrc/distributed/c10d/control_plane/Handlers.hpp index fef4776713e2..0c1063054931 100644 --- a/torch/csrc/distributed/c10d/control_plane/Handlers.hpp +++ b/torch/csrc/distributed/c10d/control_plane/Handlers.hpp @@ -1,7 +1,6 @@ #pragma once #include -#include #include #include @@ -15,8 +14,6 @@ class TORCH_API Request { public: virtual ~Request() = default; - virtual const std::multimap& params() const = 0; - virtual const std::string& body() = 0; }; diff --git a/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp b/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp index b99b9210eb54..e4b649d888dd 100644 --- a/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp +++ b/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp @@ -22,10 +22,6 @@ class RequestImpl : public Request { return req_.body; } - const std::multimap& params() const override { - return req_.params; - } - private: const httplib::Request& req_; }; From 5001f41b9047dfb8d97a63fd20d53a80b2706537 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 12 Jun 2024 16:32:57 +0000 Subject: [PATCH 701/706] Revert "Make TraceUtils.h to be device-agnostic (#126969)" This reverts commit 648625b230e8e6e7478fb219ff4f0aa6a45070f5. Reverted https://github.com/pytorch/pytorch/pull/126969 on behalf of https://github.com/clee2000 due to failing internal builds D58443769 ([comment](https://github.com/pytorch/pytorch/pull/126969#issuecomment-2163462600)) --- torch/csrc/distributed/c10d/NCCLUtils.cpp | 50 --- torch/csrc/distributed/c10d/NCCLUtils.hpp | 436 +------------------ torch/csrc/distributed/c10d/TraceUtils.h | 497 ++++++++++++++++++++++ 3 files changed, 498 insertions(+), 485 deletions(-) diff --git a/torch/csrc/distributed/c10d/NCCLUtils.cpp b/torch/csrc/distributed/c10d/NCCLUtils.cpp index db268371ea0f..bc820fc1c8d5 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.cpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.cpp @@ -238,56 +238,6 @@ std::string getNcclErrorDetailStr( return interpret + err; } -void DebugInfoWriter::write(const std::string& ncclTrace) { - // Open a file for writing. The ios::binary flag is used to write data as - // binary. - std::ofstream file(filename_, std::ios::binary); - - // Check if the file was opened successfully. - if (!file.is_open()) { - LOG(ERROR) << "Error opening file for writing NCCLPG debug info: " - << filename_; - return; - } - - file.write(ncclTrace.data(), ncclTrace.size()); - LOG(INFO) << "Finished writing NCCLPG debug info to " << filename_; -} - -DebugInfoWriter& DebugInfoWriter::getWriter(int rank) { - if (writer_ == nullptr) { - std::string fileNamePrefix = getCvarString( - {"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/nccl_trace_rank_"); - // Using std::unique_ptr here to auto-delete the writer object - // when the pointer itself is destroyed. - std::unique_ptr writerPtr( - new DebugInfoWriter(fileNamePrefix, rank)); - DebugInfoWriter::registerWriter(std::move(writerPtr)); - } - return *writer_; -} - -void DebugInfoWriter::registerWriter(std::unique_ptr writer) { - TORCH_CHECK_WITH( - DistBackendError, - hasWriterRegistered_.load() == false, - "debugInfoWriter already registered"); - hasWriterRegistered_.store(true); - writer_ = std::move(writer); -} - -std::unique_ptr DebugInfoWriter::writer_ = nullptr; -std::atomic DebugInfoWriter::hasWriterRegistered_(false); - -float getDurationFromEvent( - at::cuda::CUDAEvent& ncclStartEvent, - at::cuda::CUDAEvent& ncclEndEvent) { - TORCH_CHECK( - ncclEndEvent.query(), - "getDuration can only be called after work is succeeded.") - return ncclStartEvent.elapsed_time(ncclEndEvent); -} - } // namespace c10d #endif // USE_C10D_NCCL diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp index 4aa4b15b2917..9ce25b55dc13 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.hpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp @@ -10,11 +10,9 @@ #include #include -#include #include #include #include -#include #if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \ (NCCL_MINOR >= 14) @@ -174,39 +172,6 @@ namespace c10d { -static c10::IValue entries_key = "entries"; -static c10::IValue nccl_comm_key = "nccl_comm_state"; -static c10::IValue version_key = "version"; -// Update whenever changing contents or formatting of the dump -// (minor when adding fields, major when changing existing fields) -static c10::IValue version_val = "2.2"; -static c10::IValue pg_config_key = "pg_config"; -static c10::IValue record_id_key = "record_id"; -static c10::IValue pg_id_key = "pg_id"; -static c10::IValue pg_name_key = "process_group"; -static c10::IValue collective_seq_id_key = "collective_seq_id"; -static c10::IValue p2p_seq_id_key = "p2p_seq_id"; -static c10::IValue is_p2p_key = "is_p2p"; -static c10::IValue op_id_key = "op_id"; -static c10::IValue profiling_name_key = "profiling_name"; -static c10::IValue input_sizes_key = "input_sizes"; -static c10::IValue input_dtypes_key = "input_dtypes"; -static c10::IValue output_sizes_key = "output_sizes"; -static c10::IValue output_dtypes_key = "output_dtypes"; -static c10::IValue time_created_key = "time_created_ns"; -static c10::IValue duration_key = "duration_ms"; -static c10::IValue timeout_key = "timeout_ms"; - -static c10::IValue frames_key = "frames"; -static c10::IValue state_key = "state"; -static c10::IValue line_key = "line"; -static c10::IValue name_key = "name"; -static c10::IValue filename_key = "filename"; -static c10::IValue retired_key = "retired"; -static c10::IValue time_discovered_started_key = "time_discovered_started_ns"; -static c10::IValue time_discovered_completed_key = - "time_discovered_completed_ns"; - TORCH_API size_t hashTensors(const std::vector& tensors); TORCH_API std::string getNcclVersion(); TORCH_API std::string ncclGetErrorWithVersion(ncclResult_t error); @@ -230,7 +195,7 @@ TORCH_API std::string getNcclErrorDetailStr( // auto-registered). class TORCH_API DebugInfoWriter { public: - virtual ~DebugInfoWriter() = default; + virtual ~DebugInfoWriter(); virtual void write(const std::string& ncclTrace); static DebugInfoWriter& getWriter(int rank); static void registerWriter(std::unique_ptr writer); @@ -553,405 +518,6 @@ struct ncclRedOpRAII { bool premul_sum_ = false; }; -/* Helper used by work::getDuration() and nccl flight recorder */ -float getDurationFromEvent( - at::cuda::CUDAEvent& ncclStartEvent, - at::cuda::CUDAEvent& ncclEndEvent); - -struct NCCLTraceBuffer { - static NCCLTraceBuffer* get() { - // intentionally leak on exit - // because this will hold python state that may get destructed - static NCCLTraceBuffer* instance = new NCCLTraceBuffer(); - return instance; - } - NCCLTraceBuffer() { - max_entries_ = getCvarInt({"TORCH_NCCL_TRACE_BUFFER_SIZE"}, 0); - capture_cpp_stack_ = getCvarBool({"TORCH_NCCL_TRACE_CPP_STACK"}, false); - enabled_ = max_entries_ > 0; - } - using Event = at::cuda::CUDAEvent; - struct Entry { - size_t id_; // incremented id in the trace buffer - // used to figure out where in the circular entries - // buffer this entry will be located to - // update state information - size_t pg_id_; - std::tuple pg_name_; // - - // collective_seq_id and p2p_seq_id refer to actual kernel launches (e.g. 1 - // per coalesced group). - // collective_seq_id only increments for true collective operations (over - // all ranks in the group). p2p_seq_id only increments over non-collective - // operations in the group. op_id refers to logical operations (e.g. one per - // op inside coalesced group) - size_t collective_seq_id_; - size_t p2p_seq_id_; - size_t op_id_; - std::string profiling_name_; - - std::shared_ptr traceback_; - // we borrow pointers to start_ and end_ so we can query the state - // on reporting. However, once the event is completed, the call - // to `complete` will clear these. - Event *start_, *end_; - - // timestamp when the entry was created, likely close to the time the work - // was 'enqueued'- not necessarily started - c10::time_t time_created_; - - // configured timeout for this entry - c10::time_t timeout_ms_; - - // Is this a P2P event? - bool isP2P_; - - std::optional duration_; - - // timestamp when our CPU threads discovered that the kernel started. - // will always be _after_ it actually started, and can be very late - // if the watchdog thread got stuck on CUDA APIs. - std::optional time_discovered_started_; - - // timestamp when our CPU threads discovered that the kernel completed. - // will always be _after_ it actually complated, and can be the same time - // as the discovery of the start if the watchdog thread is stuck on CUDA - // APIs - std::optional time_discovered_completed_; - - // size information for input/output tensors - c10::SmallVector input_dims_; - std::vector input_dtypes_; - c10::SmallVector output_dims_; - std::vector output_dtypes_; - c10::SmallVector sizes_; // flattened from inputs, outputs - bool retired_ = false; // is this work entry no longer in the workMetaList_? - // a retired but not completed event has timed out - }; - - bool enabled_ = false; - bool capture_cpp_stack_ = false; - std::mutex mutex_; - std::vector entries_; - size_t max_entries_ = 0; - size_t next_ = 0; - size_t id_ = 0; - std::map, std::vector> - pg_name_to_ranks_ = {}; - - std::optional record( - size_t pg_id, - const std::tuple& pg_name, - size_t collective_seq_id, - size_t p2p_seq_id, - size_t op_id, - std::string profiling_name, - const std::vector& inputs, - const std::vector& outputs, - Event* start, - Event* end, - std::chrono::milliseconds timeout_ms, - bool isP2P) { - if (!enabled_) { - return c10::nullopt; - } - auto traceback = - torch::CapturedTraceback::gather(true, true, capture_cpp_stack_); - std::lock_guard guard(mutex_); - - auto te = Entry{ - id_, - pg_id, - pg_name, - collective_seq_id, - p2p_seq_id, - op_id, - std::move(profiling_name), - std::move(traceback), - std::move(start), - std::move(end), - c10::getTime(), - timeout_ms.count(), - isP2P}; - - for (const auto& input : inputs) { - c10::IntArrayRef sizes = input.sizes(); - te.input_dtypes_.push_back(input.dtype().toScalarType()); - te.input_dims_.push_back(sizes.size()); - te.sizes_.insert(te.sizes_.end(), sizes.begin(), sizes.end()); - } - - for (const auto& output : outputs) { - c10::IntArrayRef sizes = output.sizes(); - te.output_dtypes_.push_back(output.dtype().toScalarType()); - te.output_dims_.push_back(sizes.size()); - te.sizes_.insert(te.sizes_.end(), sizes.begin(), sizes.end()); - } - - if (entries_.size() < max_entries_) { - entries_.emplace_back(std::move(te)); - } else { - entries_[next_++] = std::move(te); - if (next_ == max_entries_) { - next_ = 0; - } - } - return id_++; - } - - void record_pg_ranks( - const std::tuple& pg_name, - std::vector ranks) { - if (!enabled_) { - return; - } - std::lock_guard guard(mutex_); - pg_name_to_ranks_[pg_name] = ranks; - } - - void update_state(Entry& r) { - if (r.start_ != nullptr) { - bool started = r.start_->query(); - if (started && !r.time_discovered_started_) { - r.time_discovered_started_ = c10::getTime(); - } - } - if (r.end_ != nullptr) { - bool completed = r.end_->query(); - if (completed && !r.time_discovered_completed_) { - r.time_discovered_completed_ = c10::getTime(); - } - } - } - - std::vector dump_entries() { - std::lock_guard guard(mutex_); - std::vector result; - result.reserve(entries_.size()); - result.insert(result.end(), entries_.begin() + next_, entries_.end()); - result.insert(result.end(), entries_.begin(), entries_.begin() + next_); - // query any remaining events - for (auto& r : result) { - update_state(r); - r.start_ = r.end_ = nullptr; - } - return result; - } - - /* - Mark an Event as completed and free its events. - This is called by the watchdog thread, and is asynchronous from the - perspective of the main thread. - compute_duration defaults to true since retire_id is only called in the - watchdog thread, which is currently a place we call cuda APIs which may hang, - but care should be taken to avoid computing duration in any function that must - never hang. (timing must also be enabled for compute_duration - see - TORCH_NCCL_ENABLE_TIMING). - */ - void retire_id(std::optional id, bool compute_duration = true) { - if (!enabled_ || !id) { - return; - } - - bool can_compute_duration = false; - Event* startEvent = nullptr; - Event* endEvent = nullptr; - std::optional duration = c10::nullopt; - - std::unique_lock guard(mutex_); - - Entry* entry = &entries_.at(*id % max_entries_); - if (entry->id_ == *id) { - update_state(*entry); - - if (compute_duration) { - can_compute_duration = entry->time_discovered_completed_.has_value() && - entry->start_ && entry->end_; - startEvent = entry->start_; - endEvent = entry->end_; - } - } - - if (can_compute_duration) { - // Compute duration without without holding the lock, because - // cudaEventDuration() can hang, and we need to acquire the lock before we - // can dump(), which we never want to block. - guard.unlock(); - duration = getDurationFromEvent(*startEvent, *endEvent); - guard.lock(); - - // Refresh the entry pointer, see if the entry has been overwritten - entry = &entries_.at(*id % max_entries_); - if (entry->id_ != *id) { - LOG(INFO) - << "retire_id abandoned for id " << *id - << ", event was overwritten while waiting to compute duration."; - return; - } - if (duration.has_value()) { - entry->duration_ = duration.value(); - } - } - - entry->retired_ = true; - entry->start_ = entry->end_ = nullptr; - } - - const c10::List getCollectiveTrace( - bool includeStacktraces, - bool onlyActive) { - auto entries = new_list(); - auto result = dump_entries(); - std::vector tracebacks; - torch::SymbolizedTracebacks stracebacks; - std::vector all_frames; - if (includeStacktraces) { - for (auto& e : result) { - tracebacks.push_back(e.traceback_.get()); - } - stracebacks = torch::symbolize(tracebacks); - for (const auto& f : stracebacks.all_frames) { - auto d = new_dict(); - d.insert(name_key, f.funcname); - d.insert(filename_key, f.filename); - d.insert(line_key, int64_t(f.lineno)); - all_frames.emplace_back(std::move(d)); - } - } - for (auto i : c10::irange(result.size())) { - auto dict = new_dict(); - auto& e = result.at(i); - // Skip completed events - if (onlyActive && e.time_discovered_completed_.has_value()) { - continue; - } - - if (includeStacktraces) { - auto& tb = stracebacks.tracebacks.at(i); - auto frames = new_list(); - for (int64_t frame : tb) { - frames.push_back(all_frames.at(frame)); - } - dict.insert(frames_key, frames); - } - - dict.insert(record_id_key, int64_t(e.id_)); - dict.insert(pg_id_key, int64_t(e.pg_id_)); - dict.insert(pg_name_key, e.pg_name_); - dict.insert(collective_seq_id_key, int64_t(e.collective_seq_id_)); - dict.insert(p2p_seq_id_key, int64_t(e.p2p_seq_id_)); - dict.insert(op_id_key, int64_t(e.op_id_)); - dict.insert(profiling_name_key, e.profiling_name_); - dict.insert(time_created_key, int64_t(e.time_created_)); - if (e.duration_) { - dict.insert(duration_key, *e.duration_); - } - - auto it = e.sizes_.begin(); - auto read_sizes = [&](const c10::SmallVector& dims) { - auto sizes = new_list(); - for (auto dim : dims) { - auto arg_sizes = new_list(); - for (auto i : c10::irange(dim)) { - (void)i; - arg_sizes.push_back(*it++); - } - sizes.push_back(arg_sizes); - } - return sizes; - }; - - dict.insert(input_sizes_key, read_sizes(e.input_dims_)); - std::vector input_dtypes_strs; - input_dtypes_strs.reserve(e.input_dtypes_.size()); - for (const auto& input_dtype : e.input_dtypes_) { - input_dtypes_strs.push_back(c10::toString(input_dtype)); - } - dict.insert(input_dtypes_key, input_dtypes_strs); - dict.insert(output_sizes_key, read_sizes(e.output_dims_)); - std::vector output_dtypes_strs; - output_dtypes_strs.reserve(e.output_dtypes_.size()); - for (const auto& output_dtype : e.output_dtypes_) { - output_dtypes_strs.push_back(c10::toString(output_dtype)); - } - dict.insert(output_dtypes_key, output_dtypes_strs); - if (e.time_discovered_completed_.has_value()) { - dict.insert(state_key, "completed"); - } else if (e.time_discovered_started_.has_value()) { - dict.insert(state_key, "started"); - } else { - dict.insert(state_key, "scheduled"); - } - - dict.insert( - time_discovered_started_key, - e.time_discovered_started_.has_value() - ? int64_t(*e.time_discovered_started_) - : c10::IValue()); - dict.insert( - time_discovered_completed_key, - e.time_discovered_completed_.has_value() - ? int64_t(*e.time_discovered_completed_) - : c10::IValue()); - dict.insert(retired_key, e.retired_); - dict.insert(timeout_key, e.timeout_ms_); - dict.insert(is_p2p_key, e.isP2P_); - - entries.push_back(dict); - } - return entries; - } - - // dump pg_entries - const c10::Dict getPgConfig() { - auto pg_config = new_dict(); - for (const auto& [pg_name, ranks] : pg_name_to_ranks_) { - auto pg_info = new_dict(); - pg_info.insert("name", std::get<0>(pg_name)); - pg_info.insert("desc", std::get<1>(pg_name)); - pg_info.insert("ranks", ranks_str(ranks)); - pg_config.insert(std::get<0>(pg_name), pg_info); - } - return pg_config; - } - - // dump all collectives + ncclDumpMap - std::string dump( - const std::optional>>& ncclDumpMap, - bool includeCollectives, - bool includeStackTraces, - bool onlyActive) { - auto result = new_dict(); - // common values - result.insert(version_key, version_val); - result.insert(pg_config_key, getPgConfig()); - - // collective trace - if (includeCollectives) { - result.insert( - entries_key, getCollectiveTrace(includeStackTraces, onlyActive)); - } - - // convert ncclDumpMap into a dictionary - auto per_comm_dict = new_dict(); - if (ncclDumpMap.has_value()) { - for (const auto& [ncclId, ncclDump] : ncclDumpMap.value()) { - auto inner_dict = new_dict(); - for (const auto& [key, value] : ncclDump) { - inner_dict.insert(key, value); - } - per_comm_dict.insert(ncclId, inner_dict); - } - } - if (per_comm_dict.size() > 0) { - result.insert(nccl_comm_key, per_comm_dict); - } - return pickle_str(result); - } -}; - } // namespace c10d #endif // USE_C10D_NCCL diff --git a/torch/csrc/distributed/c10d/TraceUtils.h b/torch/csrc/distributed/c10d/TraceUtils.h index 9c469dbd5bc6..de623d77fe9e 100644 --- a/torch/csrc/distributed/c10d/TraceUtils.h +++ b/torch/csrc/distributed/c10d/TraceUtils.h @@ -10,6 +10,11 @@ #include #include +#ifdef USE_C10D_NCCL +#include +#include +#endif + #include #include #include @@ -19,6 +24,41 @@ namespace c10d { +static c10::IValue entries_key = "entries"; +static c10::IValue nccl_comm_key = "nccl_comm_state"; +static c10::IValue version_key = "version"; +// Update whenever changing contents or formatting of the dump +// (minor when adding fields, major when changing existing fields) +static c10::IValue version_val = "2.2"; +static c10::IValue pg_config_key = "pg_config"; +static c10::IValue record_id_key = "record_id"; +static c10::IValue pg_id_key = "pg_id"; +static c10::IValue pg_name_key = "process_group"; +static c10::IValue collective_seq_id_key = "collective_seq_id"; +static c10::IValue p2p_seq_id_key = "p2p_seq_id"; +static c10::IValue is_p2p_key = "is_p2p"; +static c10::IValue op_id_key = "op_id"; +static c10::IValue profiling_name_key = "profiling_name"; +static c10::IValue input_sizes_key = "input_sizes"; +static c10::IValue input_dtypes_key = "input_dtypes"; +static c10::IValue output_sizes_key = "output_sizes"; +static c10::IValue output_dtypes_key = "output_dtypes"; +static c10::IValue time_created_key = "time_created_ns"; +static c10::IValue duration_key = "duration_ms"; +static c10::IValue timeout_key = "timeout_ms"; + +static c10::IValue frames_key = "frames"; +static c10::IValue state_key = "state"; +static c10::IValue line_key = "line"; +static c10::IValue name_key = "name"; +static c10::IValue filename_key = "filename"; +static c10::IValue retired_key = "retired"; +static c10::IValue time_discovered_started_key = "time_discovered_started_ns"; +static c10::IValue time_discovered_completed_key = + "time_discovered_completed_ns"; + +/* Trace Utils Related to TORCH_NCCL_DESYNC_DEBUG */ + inline std::string getTraceStartKey(const std::string& pgName, int rank) { return pgName + "_" + std::to_string(rank) + "_trace_start"; } @@ -263,6 +303,66 @@ inline std::string retrieveDesyncReport( return report; } +/* Trace Utils Related to Flight Recorder */ + +/* Note: this is only used by PGNCCL (could be generalized in an ideal world but + * wasn't done that way, so isn't expected to be fully general at the moment) */ + +#ifdef USE_C10D_NCCL + +/* Helper used by work::getDuration() and nccl flight recorder */ +float getDurationFromEvent( + at::cuda::CUDAEvent& ncclStartEvent, + at::cuda::CUDAEvent& ncclEndEvent) { + TORCH_CHECK( + ncclEndEvent.query(), + "getDuration can only be called after work is succeeded.") + return ncclStartEvent.elapsed_time(ncclEndEvent); +} + +DebugInfoWriter::~DebugInfoWriter() = default; + +void DebugInfoWriter::write(const std::string& ncclTrace) { + // Open a file for writing. The ios::binary flag is used to write data as + // binary. + std::ofstream file(filename_, std::ios::binary); + + // Check if the file was opened successfully. + if (!file.is_open()) { + LOG(ERROR) << "Error opening file for writing NCCLPG debug info: " + << filename_; + return; + } + + file.write(ncclTrace.data(), ncclTrace.size()); + LOG(INFO) << "Finished writing NCCLPG debug info to " << filename_; +} + +DebugInfoWriter& DebugInfoWriter::getWriter(int rank) { + if (writer_ == nullptr) { + std::string fileNamePrefix = getCvarString( + {"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/nccl_trace_rank_"); + // Using std::unique_ptr here to auto-delete the writer object + // when the pointer itself is destroyed. + std::unique_ptr writerPtr( + new DebugInfoWriter(fileNamePrefix, rank)); + DebugInfoWriter::registerWriter(std::move(writerPtr)); + } + return *writer_; +} + +void DebugInfoWriter::registerWriter(std::unique_ptr writer) { + TORCH_CHECK_WITH( + DistBackendError, + hasWriterRegistered_.load() == false, + "debugInfoWriter already registered"); + hasWriterRegistered_.store(true); + writer_ = std::move(writer); +} + +std::unique_ptr DebugInfoWriter::writer_ = nullptr; +std::atomic DebugInfoWriter::hasWriterRegistered_(false); + inline std::string pickle_str(const c10::IValue& v) { std::vector result; { @@ -321,4 +421,401 @@ inline std::string ranks_str(const std::vector& ranks) { return c10::str("[", str, "]"); } +struct NCCLTraceBuffer { + static NCCLTraceBuffer* get() { + // intentionally leak on exit + // because this will hold python state that may get destructed + static NCCLTraceBuffer* instance = new NCCLTraceBuffer(); + return instance; + } + NCCLTraceBuffer() { + max_entries_ = getCvarInt({"TORCH_NCCL_TRACE_BUFFER_SIZE"}, 0); + capture_cpp_stack_ = getCvarBool({"TORCH_NCCL_TRACE_CPP_STACK"}, false); + enabled_ = max_entries_ > 0; + } + using Event = at::cuda::CUDAEvent; + struct Entry { + size_t id_; // incremented id in the trace buffer + // used to figure out where in the circular entries + // buffer this entry will be located to + // update state information + size_t pg_id_; + std::tuple pg_name_; // + + // collective_seq_id and p2p_seq_id refer to actual kernel launches (e.g. 1 + // per coalesced group). + // collective_seq_id only increments for true collective operations (over + // all ranks in the group). p2p_seq_id only increments over non-collective + // operations in the group. op_id refers to logical operations (e.g. one per + // op inside coalesced group) + size_t collective_seq_id_; + size_t p2p_seq_id_; + size_t op_id_; + std::string profiling_name_; + + std::shared_ptr traceback_; + // we borrow pointers to start_ and end_ so we can query the state + // on reporting. However, once the event is completed, the call + // to `complete` will clear these. + Event *start_, *end_; + + // timestamp when the entry was created, likely close to the time the work + // was 'enqueued'- not necessarily started + c10::time_t time_created_; + + // configured timeout for this entry + c10::time_t timeout_ms_; + + // Is this a P2P event? + bool isP2P_; + + std::optional duration_; + + // timestamp when our CPU threads discovered that the kernel started. + // will always be _after_ it actually started, and can be very late + // if the watchdog thread got stuck on CUDA APIs. + std::optional time_discovered_started_; + + // timestamp when our CPU threads discovered that the kernel completed. + // will always be _after_ it actually complated, and can be the same time + // as the discovery of the start if the watchdog thread is stuck on CUDA + // APIs + std::optional time_discovered_completed_; + + // size information for input/output tensors + c10::SmallVector input_dims_; + std::vector input_dtypes_; + c10::SmallVector output_dims_; + std::vector output_dtypes_; + c10::SmallVector sizes_; // flattened from inputs, outputs + bool retired_ = false; // is this work entry no longer in the workMetaList_? + // a retired but not completed event has timed out + }; + + bool enabled_ = false; + bool capture_cpp_stack_ = false; + std::mutex mutex_; + std::vector entries_; + size_t max_entries_ = 0; + size_t next_ = 0; + size_t id_ = 0; + std::map, std::vector> + pg_name_to_ranks_ = {}; + + std::optional record( + size_t pg_id, + const std::tuple& pg_name, + size_t collective_seq_id, + size_t p2p_seq_id, + size_t op_id, + std::string profiling_name, + const std::vector& inputs, + const std::vector& outputs, + Event* start, + Event* end, + std::chrono::milliseconds timeout_ms, + bool isP2P) { + if (!enabled_) { + return c10::nullopt; + } + auto traceback = + torch::CapturedTraceback::gather(true, true, capture_cpp_stack_); + std::lock_guard guard(mutex_); + + auto te = Entry{ + id_, + pg_id, + pg_name, + collective_seq_id, + p2p_seq_id, + op_id, + std::move(profiling_name), + std::move(traceback), + std::move(start), + std::move(end), + c10::getTime(), + timeout_ms.count(), + isP2P}; + + for (const auto& input : inputs) { + c10::IntArrayRef sizes = input.sizes(); + te.input_dtypes_.push_back(input.dtype().toScalarType()); + te.input_dims_.push_back(sizes.size()); + te.sizes_.insert(te.sizes_.end(), sizes.begin(), sizes.end()); + } + + for (const auto& output : outputs) { + c10::IntArrayRef sizes = output.sizes(); + te.output_dtypes_.push_back(output.dtype().toScalarType()); + te.output_dims_.push_back(sizes.size()); + te.sizes_.insert(te.sizes_.end(), sizes.begin(), sizes.end()); + } + + if (entries_.size() < max_entries_) { + entries_.emplace_back(std::move(te)); + } else { + entries_[next_++] = std::move(te); + if (next_ == max_entries_) { + next_ = 0; + } + } + return id_++; + } + + void record_pg_ranks( + const std::tuple& pg_name, + std::vector ranks) { + if (!enabled_) { + return; + } + std::lock_guard guard(mutex_); + pg_name_to_ranks_[pg_name] = ranks; + } + + void update_state(Entry& r) { + if (r.start_ != nullptr) { + bool started = r.start_->query(); + if (started && !r.time_discovered_started_) { + r.time_discovered_started_ = c10::getTime(); + } + } + if (r.end_ != nullptr) { + bool completed = r.end_->query(); + if (completed && !r.time_discovered_completed_) { + r.time_discovered_completed_ = c10::getTime(); + } + } + } + + std::vector dump_entries() { + std::lock_guard guard(mutex_); + std::vector result; + result.reserve(entries_.size()); + result.insert(result.end(), entries_.begin() + next_, entries_.end()); + result.insert(result.end(), entries_.begin(), entries_.begin() + next_); + // query any remaining events + for (auto& r : result) { + update_state(r); + r.start_ = r.end_ = nullptr; + } + return result; + } + + /* + Mark an Event as completed and free its events. + + This is called by the watchdog thread, and is asynchronous from the + perspective of the main thread. + + compute_duration defaults to true since retire_id is only called in the + watchdog thread, which is currently a place we call cuda APIs which may hang, + but care should be taken to avoid computing duration in any function that must + never hang. (timing must also be enabled for compute_duration - see + TORCH_NCCL_ENABLE_TIMING). + */ + void retire_id(std::optional id, bool compute_duration = true) { + if (!enabled_ || !id) { + return; + } + + bool can_compute_duration = false; + Event* startEvent = nullptr; + Event* endEvent = nullptr; + std::optional duration = c10::nullopt; + + std::unique_lock guard(mutex_); + + Entry* entry = &entries_.at(*id % max_entries_); + if (entry->id_ == *id) { + update_state(*entry); + + if (compute_duration) { + can_compute_duration = entry->time_discovered_completed_.has_value() && + entry->start_ && entry->end_; + startEvent = entry->start_; + endEvent = entry->end_; + } + } + + if (can_compute_duration) { + // Compute duration without without holding the lock, because + // cudaEventDuration() can hang, and we need to acquire the lock before we + // can dump(), which we never want to block. + guard.unlock(); + duration = getDurationFromEvent(*startEvent, *endEvent); + guard.lock(); + + // Refresh the entry pointer, see if the entry has been overwritten + entry = &entries_.at(*id % max_entries_); + if (entry->id_ != *id) { + LOG(INFO) + << "retire_id abandoned for id " << *id + << ", event was overwritten while waiting to compute duration."; + return; + } + if (duration.has_value()) { + entry->duration_ = duration.value(); + } + } + + entry->retired_ = true; + entry->start_ = entry->end_ = nullptr; + } + + const c10::List getCollectiveTrace( + bool includeStacktraces, + bool onlyActive) { + auto entries = new_list(); + auto result = dump_entries(); + std::vector tracebacks; + torch::SymbolizedTracebacks stracebacks; + std::vector all_frames; + if (includeStacktraces) { + for (auto& e : result) { + tracebacks.push_back(e.traceback_.get()); + } + stracebacks = torch::symbolize(tracebacks); + for (const auto& f : stracebacks.all_frames) { + auto d = new_dict(); + d.insert(name_key, f.funcname); + d.insert(filename_key, f.filename); + d.insert(line_key, int64_t(f.lineno)); + all_frames.emplace_back(std::move(d)); + } + } + for (auto i : c10::irange(result.size())) { + auto dict = new_dict(); + auto& e = result.at(i); + // Skip completed events + if (onlyActive && e.time_discovered_completed_.has_value()) { + continue; + } + + if (includeStacktraces) { + auto& tb = stracebacks.tracebacks.at(i); + auto frames = new_list(); + for (int64_t frame : tb) { + frames.push_back(all_frames.at(frame)); + } + dict.insert(frames_key, frames); + } + + dict.insert(record_id_key, int64_t(e.id_)); + dict.insert(pg_id_key, int64_t(e.pg_id_)); + dict.insert(pg_name_key, e.pg_name_); + dict.insert(collective_seq_id_key, int64_t(e.collective_seq_id_)); + dict.insert(p2p_seq_id_key, int64_t(e.p2p_seq_id_)); + dict.insert(op_id_key, int64_t(e.op_id_)); + dict.insert(profiling_name_key, e.profiling_name_); + dict.insert(time_created_key, int64_t(e.time_created_)); + if (e.duration_) { + dict.insert(duration_key, *e.duration_); + } + + auto it = e.sizes_.begin(); + auto read_sizes = [&](const c10::SmallVector& dims) { + auto sizes = new_list(); + for (auto dim : dims) { + auto arg_sizes = new_list(); + for (auto i : c10::irange(dim)) { + (void)i; + arg_sizes.push_back(*it++); + } + sizes.push_back(arg_sizes); + } + return sizes; + }; + + dict.insert(input_sizes_key, read_sizes(e.input_dims_)); + std::vector input_dtypes_strs; + input_dtypes_strs.reserve(e.input_dtypes_.size()); + for (const auto& input_dtype : e.input_dtypes_) { + input_dtypes_strs.push_back(c10::toString(input_dtype)); + } + dict.insert(input_dtypes_key, input_dtypes_strs); + dict.insert(output_sizes_key, read_sizes(e.output_dims_)); + std::vector output_dtypes_strs; + output_dtypes_strs.reserve(e.output_dtypes_.size()); + for (const auto& output_dtype : e.output_dtypes_) { + output_dtypes_strs.push_back(c10::toString(output_dtype)); + } + dict.insert(output_dtypes_key, output_dtypes_strs); + if (e.time_discovered_completed_.has_value()) { + dict.insert(state_key, "completed"); + } else if (e.time_discovered_started_.has_value()) { + dict.insert(state_key, "started"); + } else { + dict.insert(state_key, "scheduled"); + } + + dict.insert( + time_discovered_started_key, + e.time_discovered_started_.has_value() + ? int64_t(*e.time_discovered_started_) + : c10::IValue()); + dict.insert( + time_discovered_completed_key, + e.time_discovered_completed_.has_value() + ? int64_t(*e.time_discovered_completed_) + : c10::IValue()); + dict.insert(retired_key, e.retired_); + dict.insert(timeout_key, e.timeout_ms_); + dict.insert(is_p2p_key, e.isP2P_); + + entries.push_back(dict); + } + return entries; + } + + // dump pg_entries + const c10::Dict getPgConfig() { + auto pg_config = new_dict(); + for (const auto& [pg_name, ranks] : pg_name_to_ranks_) { + auto pg_info = new_dict(); + pg_info.insert("name", std::get<0>(pg_name)); + pg_info.insert("desc", std::get<1>(pg_name)); + pg_info.insert("ranks", ranks_str(ranks)); + pg_config.insert(std::get<0>(pg_name), pg_info); + } + return pg_config; + } + + // dump all collectives + ncclDumpMap + std::string dump( + const std::optional>>& ncclDumpMap, + bool includeCollectives, + bool includeStackTraces, + bool onlyActive) { + auto result = new_dict(); + // common values + result.insert(version_key, version_val); + result.insert(pg_config_key, getPgConfig()); + + // collective trace + if (includeCollectives) { + result.insert( + entries_key, getCollectiveTrace(includeStackTraces, onlyActive)); + } + + // convert ncclDumpMap into a dictionary + auto per_comm_dict = new_dict(); + if (ncclDumpMap.has_value()) { + for (const auto& [ncclId, ncclDump] : ncclDumpMap.value()) { + auto inner_dict = new_dict(); + for (const auto& [key, value] : ncclDump) { + inner_dict.insert(key, value); + } + per_comm_dict.insert(ncclId, inner_dict); + } + } + if (per_comm_dict.size() > 0) { + result.insert(nccl_comm_key, per_comm_dict); + } + return pickle_str(result); + } +}; + +#endif } // namespace c10d From 3d8dab99c8a9857aed422efcca958a24f8a69e0b Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Thu, 13 Jun 2024 13:35:10 +0800 Subject: [PATCH 702/706] correct the UT's doc Signed-off-by: yiliu30 --- test/quantization/pt2e/test_x86inductor_quantizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py index a64c12a8f2a1..bd3eefc21480 100644 --- a/test/quantization/pt2e/test_x86inductor_quantizer.py +++ b/test/quantization/pt2e/test_x86inductor_quantizer.py @@ -2220,7 +2220,7 @@ def test_set_module_name_with_mixed_configs(self): def test_set_module_name_and_module_type_with_mixed_configs(self): """Test that set `module_name_qconfig` and `module_type_qconfig` at the same time with mixed the configs. - Expect that all linear layers are quantized with static quantization except the last one. + Expect that that only the last linear(`sub`) is quantized using static quantization. """ class M(torch.nn.Module): From 73b2bc8468c9dd70f90beff674def40246410121 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Thu, 13 Jun 2024 13:56:58 +0800 Subject: [PATCH 703/706] fix lint Signed-off-by: yiliu30 --- .../pt2e/test_x86inductor_quantizer.py | 60 +++++++++++-------- 1 file changed, 36 insertions(+), 24 deletions(-) diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py index bd3eefc21480..f386412443b4 100644 --- a/test/quantization/pt2e/test_x86inductor_quantizer.py +++ b/test/quantization/pt2e/test_x86inductor_quantizer.py @@ -754,9 +754,11 @@ def test_conv2d_binary2(self): torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.conv2d.default, torch.ops.quantized_decomposed.quantize_per_tensor.default, - torch.ops.aten.add_.Tensor - if inplace_add - else torch.ops.aten.add.Tensor, + ( + torch.ops.aten.add_.Tensor + if inplace_add + else torch.ops.aten.add.Tensor + ), ] self._test_quantizer( m, @@ -1349,9 +1351,11 @@ def test_linear_binary(self): torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.linear.default, - torch.ops.aten.add_.Tensor - if inplace_add - else torch.ops.aten.add.Tensor, + ( + torch.ops.aten.add_.Tensor + if inplace_add + else torch.ops.aten.add.Tensor + ), ] fq_m = self._test_quantizer( m, @@ -1404,9 +1408,11 @@ def test_linear_binary2(self): torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.linear.default, torch.ops.quantized_decomposed.quantize_per_tensor.default, - torch.ops.aten.add_.Tensor - if inplace_add - else torch.ops.aten.add.Tensor, + ( + torch.ops.aten.add_.Tensor + if inplace_add + else torch.ops.aten.add.Tensor + ), ] fq_m = self._test_quantizer( m, @@ -1475,9 +1481,11 @@ def test_linear_binary_unary(self): torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.linear.default, - torch.ops.aten.add_.Tensor - if inplace_add - else torch.ops.aten.add.Tensor, + ( + torch.ops.aten.add_.Tensor + if inplace_add + else torch.ops.aten.add.Tensor + ), ] fq_m = self._test_quantizer( m, @@ -1697,9 +1705,11 @@ def test_qat_conv2d_binary(self): torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.conv2d.default, - torch.ops.aten.add_.Tensor - if inplace_add - else torch.ops.aten.add.Tensor, + ( + torch.ops.aten.add_.Tensor + if inplace_add + else torch.ops.aten.add.Tensor + ), torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, ] @@ -1744,9 +1754,11 @@ def test_qat_conv2d_binary2(self): torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.conv2d.default, torch.ops.quantized_decomposed.quantize_per_tensor.default, - torch.ops.aten.add_.Tensor - if inplace_add - else torch.ops.aten.add.Tensor, + ( + torch.ops.aten.add_.Tensor + if inplace_add + else torch.ops.aten.add.Tensor + ), ] self._test_quantizer( m, @@ -2401,12 +2413,12 @@ def test_attention_block(self): ) node_occurrence = { - torch.ops.quantized_decomposed.quantize_per_tensor.default: 5 - if annotate_matmul - else 1, - torch.ops.quantized_decomposed.dequantize_per_tensor.default: 7 - if annotate_matmul - else 3, + torch.ops.quantized_decomposed.quantize_per_tensor.default: ( + 5 if annotate_matmul else 1 + ), + torch.ops.quantized_decomposed.dequantize_per_tensor.default: ( + 7 if annotate_matmul else 3 + ), # quantize_per_channel for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 3, From 688c0b02128c7949d25f8dcd4ff8350d5d93e259 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Thu, 13 Jun 2024 13:59:10 +0800 Subject: [PATCH 704/706] fix docstring Signed-off-by: yiliu30 --- test/quantization/pt2e/test_x86inductor_quantizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py index f386412443b4..fb8182b21dda 100644 --- a/test/quantization/pt2e/test_x86inductor_quantizer.py +++ b/test/quantization/pt2e/test_x86inductor_quantizer.py @@ -2232,7 +2232,7 @@ def test_set_module_name_with_mixed_configs(self): def test_set_module_name_and_module_type_with_mixed_configs(self): """Test that set `module_name_qconfig` and `module_type_qconfig` at the same time with mixed the configs. - Expect that that only the last linear(`sub`) is quantized using static quantization. + Expect that only the last linear(`sub`) is quantized using static quantization. """ class M(torch.nn.Module): From 621ff110f89ad1a084272f5973f86e8db1291fcd Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Thu, 13 Jun 2024 15:21:41 +0800 Subject: [PATCH 705/706] update the cur mode Signed-off-by: yiliu30 --- .../quantizer/x86_inductor_quantizer.py | 136 ++++++++++-------- 1 file changed, 78 insertions(+), 58 deletions(-) diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index 178de27a864f..06f855fea20f 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -420,16 +420,17 @@ class _CurrentQuantizationMode: All possible current quantization modes are listed below: ---------------------------------------------------------------------------------------------------------- - | is_dynamic - is_qat |--------------------------------------------------------------------------------------------- + | dynamic_state + qat_state |--------------------------------------------------------------------------------------------- | None | True | False ---------------------------------------------------------------------------------------------------------- None | quantizer does not receive a non-None `quantization_config` | \ | \ False | quantizer will not do QAT | dynamic | static True | quantizer will do QAT | QAT + dynamic | QAT + static """ - is_qat: Optional[bool] - is_dynamic: Optional[bool] + + qat_state: Optional[bool] + dynamic_state: Optional[bool] class X86InductorQuantizer(Quantizer): @@ -468,20 +469,39 @@ def get_supported_operator_for_quantization_config( def _get_current_quantization_mode(self) -> _CurrentQuantizationMode: """Retrieves the current quantization mode based on all configurations.""" - is_qat = None - is_dynamic = None + qat_state = None + dynamic_state = None + # As we use `_need_skip_config` to skip all invalid configurations, + # we can safely assume that the all existing non-None configurations + # have the same quantization mode. for qconfig in ( list(self.module_name_qconfig.values()) + list(self.operator_type_qconfig.values()) + [self.global_config] ): if qconfig is not None: - is_qat = qconfig.is_qat + # Query the `is_qat` state + if qat_state is None: + qat_state = qconfig.is_qat + else: + assert qat_state == qconfig.is_qat, ( + f"All non-None quantization configs should have the same `is_qat`," + f"but got {qat_state} and {qconfig.is_qat}." + ) + # Query the `is_dynamic` state input_activation_spec = qconfig.input_activation if input_activation_spec is not None: - is_dynamic = input_activation_spec.is_dynamic - return _CurrentQuantizationMode(is_qat=is_qat, is_dynamic=is_dynamic) + if dynamic_state is None: + dynamic_state = input_activation_spec.is_dynamic + else: + assert dynamic_state == input_activation_spec.is_dynamic, ( + f"All non-None `input_activation_spec` should have the same `is_dynamic`," + f"but got {dynamic_state} and {input_activation_spec.is_dynamic}." + ) + return _CurrentQuantizationMode( + qat_state=qat_state, dynamic_state=dynamic_state + ) def _need_skip_config( self, quantization_config: Optional[QuantizationConfig] @@ -498,16 +518,16 @@ def _need_skip_config( need_skip = False current_mode = self._get_current_quantization_mode() if ( - current_mode.is_qat is not None - and current_mode.is_qat != quantization_config.is_qat + current_mode.qat_state is not None + and current_mode.qat_state != quantization_config.is_qat ): warnings.warn("Mixed QAT and Non-QAT quantization config is not supported.") need_skip = True - if current_mode.is_dynamic is not None: + if current_mode.dynamic_state is not None: input_activation_spec = quantization_config.input_activation if ( input_activation_spec is not None - and current_mode.is_dynamic != input_activation_spec.is_dynamic + and current_mode.dynamic_state != input_activation_spec.is_dynamic ): warnings.warn( "Mixed dynamic and static quantization config is not supported." @@ -844,19 +864,19 @@ def _annotate_qat_conv2d_bn_binary_unary( binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec( quantization_config ) - binary_node.meta[ - QUANT_ANNOTATION_KEY - ] = _X86InductorQuantizationAnnotation( - input_qspec_map=binary_node_input_qspec_map, - _annotated=True, + binary_node.meta[QUANT_ANNOTATION_KEY] = ( + _X86InductorQuantizationAnnotation( + input_qspec_map=binary_node_input_qspec_map, + _annotated=True, + ) ) - unary_node.meta[ - QUANT_ANNOTATION_KEY - ] = _X86InductorQuantizationAnnotation( - # TODO Remove the annotate of output in QAT when qat util support pattern matcher. - output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] - _annotated=True, - _is_output_of_quantized_pattern=True, + unary_node.meta[QUANT_ANNOTATION_KEY] = ( + _X86InductorQuantizationAnnotation( + # TODO Remove the annotate of output in QAT when qat util support pattern matcher. + output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + _is_output_of_quantized_pattern=True, + ) ) else: _annotate_nodes_not_quantize([binary_node, unary_node]) @@ -914,14 +934,14 @@ def _annotate_qat_conv2d_bn_binary( binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec( quantization_config ) - binary_node.meta[ - QUANT_ANNOTATION_KEY - ] = _X86InductorQuantizationAnnotation( - input_qspec_map=binary_node_input_qspec_map, - # TODO Remove the annotate of output in QAT when qat util support pattern matcher. - output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] - _annotated=True, - _is_output_of_quantized_pattern=True, + binary_node.meta[QUANT_ANNOTATION_KEY] = ( + _X86InductorQuantizationAnnotation( + input_qspec_map=binary_node_input_qspec_map, + # TODO Remove the annotate of output in QAT when qat util support pattern matcher. + output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + _is_output_of_quantized_pattern=True, + ) ) else: _annotate_nodes_not_quantize(binary_node) @@ -971,13 +991,13 @@ def _annotate_qat_conv2d_bn_unary( self._annotate_conv_node_helper(conv_node, False, quantization_config) if quantization_config is not None: - unary_node.meta[ - QUANT_ANNOTATION_KEY - ] = _X86InductorQuantizationAnnotation( - # TODO Remove the annotate of output in QAT when qat util support pattern matcher. - output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] - _annotated=True, - _is_output_of_quantized_pattern=True, + unary_node.meta[QUANT_ANNOTATION_KEY] = ( + _X86InductorQuantizationAnnotation( + # TODO Remove the annotate of output in QAT when qat util support pattern matcher. + output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + _is_output_of_quantized_pattern=True, + ) ) else: _annotate_nodes_not_quantize(unary_node) @@ -1012,13 +1032,13 @@ def _annotate_qat_conv2d_bn( self._annotate_conv_node_helper(conv_node, False, quantization_config) if quantization_config is not None: - bn_output_node.meta[ - QUANT_ANNOTATION_KEY - ] = _X86InductorQuantizationAnnotation( - # TODO Remove the annotate of output in QAT when qat util support pattern matcher. - output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] - _annotated=True, - _is_output_of_quantized_pattern=True, + bn_output_node.meta[QUANT_ANNOTATION_KEY] = ( + _X86InductorQuantizationAnnotation( + # TODO Remove the annotate of output in QAT when qat util support pattern matcher. + output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + _is_output_of_quantized_pattern=True, + ) ) else: _annotate_nodes_not_quantize(bn_output_node) @@ -1588,19 +1608,19 @@ def _annotate_linear_binary_unary( linear_node, False, quantization_config ) # We don't insert q-dq before the binary input node due to accuracy issues - binary_node.meta[ - QUANT_ANNOTATION_KEY - ] = _X86InductorQuantizationAnnotation( - input_qspec_map={}, - _annotated=True, - _is_output_of_quantized_pattern=(not has_unary), + binary_node.meta[QUANT_ANNOTATION_KEY] = ( + _X86InductorQuantizationAnnotation( + input_qspec_map={}, + _annotated=True, + _is_output_of_quantized_pattern=(not has_unary), + ) ) if unary_node is not None: - unary_node.meta[ - QUANT_ANNOTATION_KEY - ] = _X86InductorQuantizationAnnotation( - _annotated=True, - _is_output_of_quantized_pattern=True, + unary_node.meta[QUANT_ANNOTATION_KEY] = ( + _X86InductorQuantizationAnnotation( + _annotated=True, + _is_output_of_quantized_pattern=True, + ) ) def validate(self, model: torch.fx.GraphModule) -> None: From 5378b56b2c334fd6c4dca0d6707ada2fdfef4a86 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Fri, 14 Jun 2024 09:21:58 +0800 Subject: [PATCH 706/706] fix litn Signed-off-by: yiliu30 --- .../quantizer/x86_inductor_quantizer.py | 90 +++++++++---------- 1 file changed, 45 insertions(+), 45 deletions(-) diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index 06f855fea20f..6eecabb6fee0 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -864,19 +864,19 @@ def _annotate_qat_conv2d_bn_binary_unary( binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec( quantization_config ) - binary_node.meta[QUANT_ANNOTATION_KEY] = ( - _X86InductorQuantizationAnnotation( - input_qspec_map=binary_node_input_qspec_map, - _annotated=True, - ) + binary_node.meta[ + QUANT_ANNOTATION_KEY + ] = _X86InductorQuantizationAnnotation( + input_qspec_map=binary_node_input_qspec_map, + _annotated=True, ) - unary_node.meta[QUANT_ANNOTATION_KEY] = ( - _X86InductorQuantizationAnnotation( - # TODO Remove the annotate of output in QAT when qat util support pattern matcher. - output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] - _annotated=True, - _is_output_of_quantized_pattern=True, - ) + unary_node.meta[ + QUANT_ANNOTATION_KEY + ] = _X86InductorQuantizationAnnotation( + # TODO Remove the annotate of output in QAT when qat util support pattern matcher. + output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + _is_output_of_quantized_pattern=True, ) else: _annotate_nodes_not_quantize([binary_node, unary_node]) @@ -934,14 +934,14 @@ def _annotate_qat_conv2d_bn_binary( binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec( quantization_config ) - binary_node.meta[QUANT_ANNOTATION_KEY] = ( - _X86InductorQuantizationAnnotation( - input_qspec_map=binary_node_input_qspec_map, - # TODO Remove the annotate of output in QAT when qat util support pattern matcher. - output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] - _annotated=True, - _is_output_of_quantized_pattern=True, - ) + binary_node.meta[ + QUANT_ANNOTATION_KEY + ] = _X86InductorQuantizationAnnotation( + input_qspec_map=binary_node_input_qspec_map, + # TODO Remove the annotate of output in QAT when qat util support pattern matcher. + output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + _is_output_of_quantized_pattern=True, ) else: _annotate_nodes_not_quantize(binary_node) @@ -991,13 +991,13 @@ def _annotate_qat_conv2d_bn_unary( self._annotate_conv_node_helper(conv_node, False, quantization_config) if quantization_config is not None: - unary_node.meta[QUANT_ANNOTATION_KEY] = ( - _X86InductorQuantizationAnnotation( - # TODO Remove the annotate of output in QAT when qat util support pattern matcher. - output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] - _annotated=True, - _is_output_of_quantized_pattern=True, - ) + unary_node.meta[ + QUANT_ANNOTATION_KEY + ] = _X86InductorQuantizationAnnotation( + # TODO Remove the annotate of output in QAT when qat util support pattern matcher. + output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + _is_output_of_quantized_pattern=True, ) else: _annotate_nodes_not_quantize(unary_node) @@ -1032,13 +1032,13 @@ def _annotate_qat_conv2d_bn( self._annotate_conv_node_helper(conv_node, False, quantization_config) if quantization_config is not None: - bn_output_node.meta[QUANT_ANNOTATION_KEY] = ( - _X86InductorQuantizationAnnotation( - # TODO Remove the annotate of output in QAT when qat util support pattern matcher. - output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] - _annotated=True, - _is_output_of_quantized_pattern=True, - ) + bn_output_node.meta[ + QUANT_ANNOTATION_KEY + ] = _X86InductorQuantizationAnnotation( + # TODO Remove the annotate of output in QAT when qat util support pattern matcher. + output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + _is_output_of_quantized_pattern=True, ) else: _annotate_nodes_not_quantize(bn_output_node) @@ -1608,19 +1608,19 @@ def _annotate_linear_binary_unary( linear_node, False, quantization_config ) # We don't insert q-dq before the binary input node due to accuracy issues - binary_node.meta[QUANT_ANNOTATION_KEY] = ( - _X86InductorQuantizationAnnotation( - input_qspec_map={}, - _annotated=True, - _is_output_of_quantized_pattern=(not has_unary), - ) + binary_node.meta[ + QUANT_ANNOTATION_KEY + ] = _X86InductorQuantizationAnnotation( + input_qspec_map={}, + _annotated=True, + _is_output_of_quantized_pattern=(not has_unary), ) if unary_node is not None: - unary_node.meta[QUANT_ANNOTATION_KEY] = ( - _X86InductorQuantizationAnnotation( - _annotated=True, - _is_output_of_quantized_pattern=True, - ) + unary_node.meta[ + QUANT_ANNOTATION_KEY + ] = _X86InductorQuantizationAnnotation( + _annotated=True, + _is_output_of_quantized_pattern=True, ) def validate(self, model: torch.fx.GraphModule) -> None: